//===-- AMDGPULowerBufferFatPointers.cpp ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass lowers operations on buffer fat pointers (addrspace 7) to
// operations on buffer resources (addrspace 8) and is needed for correct
// codegen.
//
// # Background
//
// Address space 7 (the buffer fat pointer) is a 160-bit pointer that consists
// of a 128-bit buffer descriptor and a 32-bit offset into that descriptor.
// The buffer resource part needs to be a "raw" buffer resource
// (it must have a stride of 0 and bounds checks must be in raw buffer mode
// or disabled).
//
// When these requirements are met, a buffer resource can be treated as a
// typical (though quite wide) pointer that follows typical LLVM pointer
// semantics. This allows the frontend to reason about such buffers (which are
// often encountered in the context of SPIR-V kernels).
//
// However, because of their non-power-of-2 size, these fat pointers cannot be
// present during translation to MIR (though this restriction may be lifted
// during the transition to GlobalISel). Therefore, this pass is needed in order
// to correctly implement these fat pointers.
//
// The resource intrinsics take the resource part (the address space 8 pointer)
// and the offset part (the 32-bit integer) as separate arguments. In addition,
// many users of these buffers manipulate the offset while leaving the resource
// part alone. For these reasons, we want to typically separate the resource
// and offset parts into separate variables, but combine them together when
// encountering cases where this is required, such as by inserting these values
// into aggregates or moving them to memory.
//
// Therefore, at a high level, `ptr addrspace(7) %x` becomes `ptr addrspace(8)
// %x.rsrc` and `i32 %x.off`, which will be combined into `{ptr addrspace(8),
// i32} %x = {%x.rsrc, %x.off}` if needed. Similarly, `vector<Nxp7>` becomes
// `{vector<Nxp8>, vector<Nxi32>}` and its component parts.
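//
// For example (an illustrative sketch; the names are hypothetical), advancing
// a fat pointer only touches its offset part:
// ```
//   %q = getelementptr i8, ptr addrspace(7) %x, i32 4
//   ; becomes, conceptually,
//   %q.rsrc = %x.rsrc
//   %q.off = add i32 %x.off, 4
// ```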
//
// # Implementation
//
// This pass proceeds in four main phases:
//
// ## Rewriting loads and stores of p7 and memcpy()-like handling
//
// The first phase is to rewrite away all loads and stores of `ptr
// addrspace(7)`, including aggregates containing such pointers, to ones that
// use `i160`. This is handled by `StoreFatPtrsAsIntsAndExpandMemcpyVisitor`,
// which visits loads, stores, and allocas and, if the loaded or stored type
// contains `ptr addrspace(7)`, rewrites that type to one where the p7s are
// replaced by i160s, copying other parts of aggregates as needed. In the case
// of a store, each pointer is `ptrtoint`d to i160 before storing, and loaded
// integers are `inttoptr`d back. This same transformation is applied to
// vectors of pointers.
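//
// A sketch of this rewrite (illustrative names):
// ```
//   store ptr addrspace(7) %p, ptr %slot
//   ; becomes
//   %p.int = ptrtoint ptr addrspace(7) %p to i160
//   store i160 %p.int, ptr %slot
// ```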
//
// Such a transformation allows the later phases of the pass to not need
// to handle buffer fat pointers moving to and from memory, where we would
// have to handle the incompatibility between a `{Nxp8, Nxi32}` representation
// and `Nxi160` directly. Instead, that transposing action (where the vectors
// of resources and vectors of offsets are concatenated before being stored to
// memory) is handled through implementing `inttoptr` and `ptrtoint` only.
//
// Atomic operations on `ptr addrspace(7)` values are not supported, as the
// hardware does not include a 160-bit atomic.
//
// In order to save on O(N) work and to ensure that the contents type
// legalizer correctly splits up wide loads, we also unconditionally lower
// memcpy-like intrinsics into loops here.
//
// ## Buffer contents type legalization
//
// The underlying buffer intrinsics only support types up to 128 bits long,
// and don't support complex types. If buffer operations were
// standard pointer operations that could be represented as MIR-level loads,
// this would be handled by the various legalization schemes in instruction
// selection. However, because we have to do the conversion from `load` and
// `store` to intrinsics at the LLVM IR level, we must perform that
// legalization ourselves.
//
// This involves a combination of
// - Converting arrays to vectors where possible
// - Otherwise, splitting loads and stores of aggregates into loads/stores of
//   each component.
// - Zero-extending things to fill a whole number of bytes
// - Casting values of types that don't neatly correspond to supported machine
//   value types (for example, an i96 or i256) into ones that would work
//   (like <3 x i32> and <8 x i32>, respectively; see the sketch below)
// - Splitting values that are too long (such as the aforementioned <8 x i32>)
//   into multiple operations.
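//
// For instance (an illustrative sketch of the i96 case):
// ```
//   %v = load i96, ptr addrspace(7) %p
//   ; becomes
//   %v.vec = load <3 x i32>, ptr addrspace(7) %p
//   %v = bitcast <3 x i32> %v.vec to i96
// ```
// The remaining legal-typed load is turned into a buffer intrinsic by the
// later phases.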
//
// ## Type remapping
//
// We use a `ValueMapper` to mangle uses of [vectors of] buffer fat pointers
// to the corresponding struct type, which has a resource part and an offset
// part.
//
// This uses a `BufferFatPtrToStructTypeMap` and a `FatPtrConstMaterializer`
// to perform this remapping, usually by way of `setType`ing values. Constants
// are handled here because there isn't a good way to fix them up later.
//
// This has the downside of leaving the IR in an invalid state (for example,
// the instruction `getelementptr {ptr addrspace(8), i32} %p, ...` will exist),
// but all such invalid states will be resolved by the third phase.
//
// Functions that don't take buffer fat pointers are modified in place. Those
// that do take such pointers have their basic blocks moved to a new function
// whose arguments and return values use {ptr addrspace(8), i32} in place of
// the fat pointers. This phase also records intrinsics so that they can be
// remangled or deleted later.
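//
// For example, the signature rewrite looks like:
// ```
//   define float @f(ptr addrspace(7) %p)
//   ; becomes
//   define float @f({ptr addrspace(8), i32} %p)
// ```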
//
// ## Splitting pointer structs
//
// The meat of this pass consists of defining semantics for operations that
// produce or consume [vectors of] buffer fat pointers in terms of their
// resource and offset parts. This is accomplished through the `SplitPtrStructs`
// visitor.
//
// In the first pass through each function that is being lowered, the splitter
// inserts new instructions to implement the split-structures behavior, which is
// needed for correctness and performance. It records a list of "split users",
// instructions that are being replaced by operations on the resource and offset
// parts.
//
// Split users do not necessarily need to produce parts themselves (a
// `load float, ptr addrspace(7)` does not, for example), but, if they do not
// generate fat buffer pointers, they must RAUW in their replacement
// instructions during the initial visit.
//
// When these new instructions are created, they use the split parts recorded
// for their initial arguments in order to generate their replacements, creating
// a parallel set of instructions that does not refer to the original fat
// pointer values but instead to their resource and offset components.
//
// Instructions, such as `extractvalue`, that produce buffer fat pointers from
// sources that do not have split parts, have such parts generated using
// `extractvalue`. This is also the initial handling of PHI nodes, which
// are then cleaned up.
//
// ### Conditionals
//
// PHI nodes are initially given resource parts via `extractvalue`. However,
// this is not an efficient rewrite of such nodes, as, in most cases, the
// resource part in a conditional or loop remains constant throughout the loop
// and only the offset varies. Failing to optimize away these constant resources
// would cause additional registers to be sent around loops and might lead to
// waterfall loops being generated for buffer operations due to the
// "non-uniform" resource argument.
//
// Therefore, after all instructions have been visited, the pointer splitter
// post-processes all encountered conditionals. Given a PHI node or select,
// getPossibleRsrcRoots() collects all values that the resource parts of that
// conditional's input could come from as well as collecting all conditional
// instructions encountered during the search. If, after filtering out the
// initial node itself, the set of encountered conditionals is a subset of the
// potential roots and there is a single potential resource that isn't in the
// conditional set, that value is the only possible value the resource argument
// could have throughout the control flow.
//
// If that condition is met, then a PHI node can have its resource part changed
// to the singleton value and then be replaced by a PHI on the offsets.
// Otherwise, each PHI node is split into two, one for the resource part and one
// for the offset part, which replace the temporary `extractvalue` instructions
// that were added during the first pass.
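//
// A sketch of the common case (hypothetical names), where the resource part
// is loop-invariant and only the offset needs a real PHI:
// ```
//   %p = phi ptr addrspace(7) [ %init, %entry ], [ %next, %loop ]
//   ; becomes, when %init and %next share the resource part %rsrc,
//   %p.rsrc = %rsrc
//   %p.off = phi i32 [ %init.off, %entry ], [ %next.off, %loop ]
// ```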
//
// Similar logic applies to `select`, where
// `%z = select i1 %cond, ptr addrspace(7) %x, ptr addrspace(7) %y`
// can be split into `%z.rsrc = %x.rsrc` and
// `%z.off = select i1 %cond, i32 %x.off, i32 %y.off`
// if both `%x` and `%y` have the same resource part, but two `select`
// operations will be needed if they do not.
//
// ### Final processing
//
// After conditionals have been cleaned up, the IR for each function is
// rewritten to remove all the old instructions that have been split up.
//
// Any instruction that used to produce a buffer fat pointer (and therefore now
// produces a resource-and-offset struct after type remapping) is
// replaced as follows:
// 1. All debug value annotations are cloned to reflect that the resource part
//    and offset parts are computed separately and constitute different
//    fragments of the underlying source language variable.
// 2. All uses that were themselves split are replaced by a `poison` of the
//    struct type, as they will themselves be erased soon. This rule, combined
//    with debug handling, should leave the use lists of split instructions
//    empty in almost all cases.
// 3. If a user of the original struct-valued result remains, the structure
//    needed for the new types to work is constructed out of the newly-defined
//    parts, and the original instruction is replaced by this structure
//    before being erased. Instructions requiring this construction include
//    `ret` and `insertvalue`; a sketch follows.
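//
// A sketch of rule 3 (hypothetical names): a surviving struct-typed use would
// see the original definition replaced by something like
// ```
//   %p.0 = insertvalue {ptr addrspace(8), i32} poison,
//                      ptr addrspace(8) %p.rsrc, 0
//   %p = insertvalue {ptr addrspace(8), i32} %p.0, i32 %p.off, 1
// ```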
//
// # Consequences
//
// This pass does not alter the CFG.
//
// Alias analysis information will become coarser, as the LLVM alias analyzer
// cannot handle the buffer intrinsics. Specifically, while we can determine
// that the following two loads do not alias:
// ```
//   %y = getelementptr i32, ptr addrspace(7) %x, i32 1
//   %a = load i32, ptr addrspace(7) %x
//   %b = load i32, ptr addrspace(7) %y
// ```
// we cannot (except through some code that runs during scheduling) determine
// that the rewritten loads below do not alias:
// ```
//   %y.off = add i32 %x.off, 1
//   %a = call @llvm.amdgcn.raw.ptr.buffer.load(ptr addrspace(8) %x.rsrc, i32
//     %x.off, ...)
//   %b = call @llvm.amdgcn.raw.ptr.buffer.load(ptr addrspace(8)
//     %x.rsrc, i32 %y.off, ...)
// ```
// However, existing alias information is preserved.
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "AMDGPUTargetMachine.h"
#include "GCNSubtarget.h"
#include "SIDefines.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/InstSimplifyFolder.h"
#include "llvm/Analysis/Utils/Local.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/AttributeMask.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DebugInfo.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/InstVisitor.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/ReplaceConstant.h"
#include "llvm/IR/ValueHandle.h"
#include "llvm/Pass.h"
#include "llvm/Support/AMDGPUAddrSpace.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Support/AtomicOrdering.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
#include "llvm/Transforms/Utils/ValueMapper.h"

#define DEBUG_TYPE "amdgpu-lower-buffer-fat-pointers"

using namespace llvm;

static constexpr unsigned BufferOffsetWidth = 32;

namespace {
/// Recursively replace instances of ptr addrspace(7) and vector<Nxptr
/// addrspace(7)> with some other type as defined by the relevant subclass.
class BufferFatPtrTypeLoweringBase : public ValueMapTypeRemapper {
  DenseMap<Type *, Type *> Map;

  Type *remapTypeImpl(Type *Ty);

protected:
  virtual Type *remapScalar(PointerType *PT) = 0;
  virtual Type *remapVector(VectorType *VT) = 0;

  const DataLayout &DL;

public:
  BufferFatPtrTypeLoweringBase(const DataLayout &DL) : DL(DL) {}
  Type *remapType(Type *SrcTy) override;
  void clear() { Map.clear(); }
};

/// Remap ptr addrspace(7) to i160 and vector<Nxptr addrspace(7)> to
/// vector<Nxi160> in order to correctly handle loading/storing these values
/// from memory.
class BufferFatPtrToIntTypeMap : public BufferFatPtrTypeLoweringBase {
  using BufferFatPtrTypeLoweringBase::BufferFatPtrTypeLoweringBase;

protected:
  Type *remapScalar(PointerType *PT) override { return DL.getIntPtrType(PT); }
  Type *remapVector(VectorType *VT) override { return DL.getIntPtrType(VT); }
};

/// Remap ptr addrspace(7) to {ptr addrspace(8), i32} (the resource and offset
/// parts of the pointer) so that we can easily rewrite operations on these
/// values that aren't loading them from or storing them to memory.
class BufferFatPtrToStructTypeMap : public BufferFatPtrTypeLoweringBase {
  using BufferFatPtrTypeLoweringBase::BufferFatPtrTypeLoweringBase;

protected:
  Type *remapScalar(PointerType *PT) override;
  Type *remapVector(VectorType *VT) override;
};
} // namespace
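
// For example, BufferFatPtrToStructTypeMap rewrites
//   ptr addrspace(7)        to  {ptr addrspace(8), i32}
//   <4 x ptr addrspace(7)>  to  {<4 x ptr addrspace(8)>, <4 x i32>}
//   [2 x ptr addrspace(7)]  to  [2 x {ptr addrspace(8), i32}]
// (the array case falls out of the recursion in remapTypeImpl() below).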

// This code is adapted from the type remapper in lib/Linker/IRMover.cpp
Type *BufferFatPtrTypeLoweringBase::remapTypeImpl(Type *Ty) {
  Type **Entry = &Map[Ty];
  if (*Entry)
    return *Entry;
  if (auto *PT = dyn_cast<PointerType>(Ty)) {
    if (PT->getAddressSpace() == AMDGPUAS::BUFFER_FAT_POINTER) {
      return *Entry = remapScalar(PT);
    }
  }
  if (auto *VT = dyn_cast<VectorType>(Ty)) {
    auto *PT = dyn_cast<PointerType>(VT->getElementType());
    if (PT && PT->getAddressSpace() == AMDGPUAS::BUFFER_FAT_POINTER) {
      return *Entry = remapVector(VT);
    }
    return *Entry = Ty;
  }
  // Whether the type is one that is structurally uniqued - that is, whether it
  // is not a named struct (named structs are the only kind of type where
  // multiple structurally identical types can have distinct `Type *`s).
  StructType *TyAsStruct = dyn_cast<StructType>(Ty);
  bool IsUniqued = !TyAsStruct || TyAsStruct->isLiteral();
  // Base case for ints, floats, opaque pointers, and so on, which don't
  // require recursion.
  if (Ty->getNumContainedTypes() == 0 && IsUniqued)
    return *Entry = Ty;
  bool Changed = false;
  SmallVector<Type *> ElementTypes(Ty->getNumContainedTypes(), nullptr);
  for (unsigned int I = 0, E = Ty->getNumContainedTypes(); I < E; ++I) {
    Type *OldElem = Ty->getContainedType(I);
    Type *NewElem = remapTypeImpl(OldElem);
    ElementTypes[I] = NewElem;
    Changed |= (OldElem != NewElem);
  }
  // Recursive calls to remapTypeImpl() may have invalidated the pointer.
  Entry = &Map[Ty];
  if (!Changed) {
    return *Entry = Ty;
  }
  if (auto *ArrTy = dyn_cast<ArrayType>(Ty))
    return *Entry = ArrayType::get(ElementTypes[0], ArrTy->getNumElements());
  if (auto *FnTy = dyn_cast<FunctionType>(Ty))
    return *Entry = FunctionType::get(ElementTypes[0],
                                      ArrayRef(ElementTypes).slice(1),
                                      FnTy->isVarArg());
  if (auto *STy = dyn_cast<StructType>(Ty)) {
    // Genuine opaque types don't have a remapping.
    if (STy->isOpaque())
      return *Entry = Ty;
    bool IsPacked = STy->isPacked();
    if (IsUniqued)
      return *Entry = StructType::get(Ty->getContext(), ElementTypes, IsPacked);
    SmallString<16> Name(STy->getName());
    STy->setName("");
    return *Entry = StructType::create(Ty->getContext(), ElementTypes, Name,
                                       IsPacked);
  }
  llvm_unreachable("Unknown type of type that contains elements");
}

Type *BufferFatPtrTypeLoweringBase::remapType(Type *SrcTy) {
  return remapTypeImpl(SrcTy);
}

Type *BufferFatPtrToStructTypeMap::remapScalar(PointerType *PT) {
  LLVMContext &Ctx = PT->getContext();
  return StructType::get(PointerType::get(Ctx, AMDGPUAS::BUFFER_RESOURCE),
                         IntegerType::get(Ctx, BufferOffsetWidth));
}

Type *BufferFatPtrToStructTypeMap::remapVector(VectorType *VT) {
  ElementCount EC = VT->getElementCount();
  LLVMContext &Ctx = VT->getContext();
  Type *RsrcVec =
      VectorType::get(PointerType::get(Ctx, AMDGPUAS::BUFFER_RESOURCE), EC);
  Type *OffVec = VectorType::get(IntegerType::get(Ctx, BufferOffsetWidth), EC);
  return StructType::get(RsrcVec, OffVec);
}

static bool isBufferFatPtrOrVector(Type *Ty) {
  if (auto *PT = dyn_cast<PointerType>(Ty->getScalarType()))
    return PT->getAddressSpace() == AMDGPUAS::BUFFER_FAT_POINTER;
  return false;
}

// True if the type is {ptr addrspace(8), i32} or a struct containing vectors
// of those types. Used to quickly skip instructions we don't need to process.
static bool isSplitFatPtr(Type *Ty) {
  auto *ST = dyn_cast<StructType>(Ty);
  if (!ST)
    return false;
  if (!ST->isLiteral() || ST->getNumElements() != 2)
    return false;
  auto *MaybeRsrc =
      dyn_cast<PointerType>(ST->getElementType(0)->getScalarType());
  auto *MaybeOff =
      dyn_cast<IntegerType>(ST->getElementType(1)->getScalarType());
  return MaybeRsrc && MaybeOff &&
         MaybeRsrc->getAddressSpace() == AMDGPUAS::BUFFER_RESOURCE &&
         MaybeOff->getBitWidth() == BufferOffsetWidth;
}

// True if the result type or any argument types are buffer fat pointers.
static bool isBufferFatPtrConst(Constant *C) {
  Type *T = C->getType();
  return isBufferFatPtrOrVector(T) || any_of(C->operands(), [](const Use &U) {
           return isBufferFatPtrOrVector(U.get()->getType());
         });
}

namespace {
/// Convert [vectors of] buffer fat pointers to integers when they are read
/// from or stored to memory. This ensures that these pointers will have the
/// same memory layout as before they are lowered, even though they will no
/// longer have their previous layout in registers/in the program (they'll be
/// broken down into resource and offset parts). This has the downside of
/// imposing marshalling costs when reading or storing these values, but since
/// placing such pointers into memory is an uncommon operation at best, we feel
/// that this cost is acceptable for better performance in the common case.
class StoreFatPtrsAsIntsAndExpandMemcpyVisitor
    : public InstVisitor<StoreFatPtrsAsIntsAndExpandMemcpyVisitor, bool> {
  BufferFatPtrToIntTypeMap *TypeMap;

  ValueToValueMapTy ConvertedForStore;

  IRBuilder<InstSimplifyFolder> IRB;

  const TargetMachine *TM;

  // Convert all the buffer fat pointers within the input value to integers
  // so that it can be stored in memory.
  Value *fatPtrsToInts(Value *V, Type *From, Type *To, const Twine &Name);
  // Convert all the i160s that need to be buffer fat pointers (as specified
  // by the To type) into those pointers to preserve the semantics of the rest
  // of the program.
  Value *intsToFatPtrs(Value *V, Type *From, Type *To, const Twine &Name);

public:
  StoreFatPtrsAsIntsAndExpandMemcpyVisitor(BufferFatPtrToIntTypeMap *TypeMap,
                                           const DataLayout &DL,
                                           LLVMContext &Ctx,
                                           const TargetMachine *TM)
      : TypeMap(TypeMap), IRB(Ctx, InstSimplifyFolder(DL)), TM(TM) {}
  bool processFunction(Function &F);

  bool visitInstruction(Instruction &I) { return false; }
  bool visitAllocaInst(AllocaInst &I);
  bool visitLoadInst(LoadInst &LI);
  bool visitStoreInst(StoreInst &SI);
  bool visitGetElementPtrInst(GetElementPtrInst &I);

  bool visitMemCpyInst(MemCpyInst &MCI);
  bool visitMemMoveInst(MemMoveInst &MMI);
  bool visitMemSetInst(MemSetInst &MSI);
  bool visitMemSetPatternInst(MemSetPatternInst &MSPI);
};
} // namespace

Value *StoreFatPtrsAsIntsAndExpandMemcpyVisitor::fatPtrsToInts(
    Value *V, Type *From, Type *To, const Twine &Name) {
  if (From == To)
    return V;
  ValueToValueMapTy::iterator Find = ConvertedForStore.find(V);
  if (Find != ConvertedForStore.end())
    return Find->second;
  if (isBufferFatPtrOrVector(From)) {
    Value *Cast = IRB.CreatePtrToInt(V, To, Name + ".int");
    ConvertedForStore[V] = Cast;
    return Cast;
  }
  if (From->getNumContainedTypes() == 0)
    return V;
  // Structs, arrays, and other compound types.
  Value *Ret = PoisonValue::get(To);
  if (auto *AT = dyn_cast<ArrayType>(From)) {
    Type *FromPart = AT->getArrayElementType();
    Type *ToPart = cast<ArrayType>(To)->getElementType();
    for (uint64_t I = 0, E = AT->getArrayNumElements(); I < E; ++I) {
      Value *Field = IRB.CreateExtractValue(V, I);
      Value *NewField =
          fatPtrsToInts(Field, FromPart, ToPart, Name + "." + Twine(I));
      Ret = IRB.CreateInsertValue(Ret, NewField, I);
    }
  } else {
    for (auto [Idx, FromPart, ToPart] :
         enumerate(From->subtypes(), To->subtypes())) {
      Value *Field = IRB.CreateExtractValue(V, Idx);
      Value *NewField =
          fatPtrsToInts(Field, FromPart, ToPart, Name + "." + Twine(Idx));
      Ret = IRB.CreateInsertValue(Ret, NewField, Idx);
    }
  }
  ConvertedForStore[V] = Ret;
  return Ret;
}

Value *StoreFatPtrsAsIntsAndExpandMemcpyVisitor::intsToFatPtrs(
    Value *V, Type *From, Type *To, const Twine &Name) {
  if (From == To)
    return V;
  if (isBufferFatPtrOrVector(To)) {
    Value *Cast = IRB.CreateIntToPtr(V, To, Name + ".ptr");
    return Cast;
  }
  if (From->getNumContainedTypes() == 0)
    return V;
  // Structs, arrays, and other compound types.
  Value *Ret = PoisonValue::get(To);
  if (auto *AT = dyn_cast<ArrayType>(From)) {
    Type *FromPart = AT->getArrayElementType();
    Type *ToPart = cast<ArrayType>(To)->getElementType();
    for (uint64_t I = 0, E = AT->getArrayNumElements(); I < E; ++I) {
      Value *Field = IRB.CreateExtractValue(V, I);
      Value *NewField =
          intsToFatPtrs(Field, FromPart, ToPart, Name + "." + Twine(I));
      Ret = IRB.CreateInsertValue(Ret, NewField, I);
    }
  } else {
    for (auto [Idx, FromPart, ToPart] :
         enumerate(From->subtypes(), To->subtypes())) {
      Value *Field = IRB.CreateExtractValue(V, Idx);
      Value *NewField =
          intsToFatPtrs(Field, FromPart, ToPart, Name + "." + Twine(Idx));
      Ret = IRB.CreateInsertValue(Ret, NewField, Idx);
    }
  }
  return Ret;
}

bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::processFunction(Function &F) {
  bool Changed = false;
  // Process memcpy-like instructions after the main iteration because they can
  // invalidate iterators.
  SmallVector<WeakTrackingVH> CanBecomeLoops;
  for (Instruction &I : make_early_inc_range(instructions(F))) {
    if (isa<MemTransferInst, MemSetInst, MemSetPatternInst>(I))
      CanBecomeLoops.push_back(&I);
    else
      Changed |= visit(I);
  }
  for (WeakTrackingVH VH : make_early_inc_range(CanBecomeLoops)) {
    Changed |= visit(cast<Instruction>(VH));
  }
  ConvertedForStore.clear();
  return Changed;
}

bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::visitAllocaInst(AllocaInst &I) {
  Type *Ty = I.getAllocatedType();
  Type *NewTy = TypeMap->remapType(Ty);
  if (Ty == NewTy)
    return false;
  I.setAllocatedType(NewTy);
  return true;
}

bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::visitGetElementPtrInst(
    GetElementPtrInst &I) {
  Type *Ty = I.getSourceElementType();
  Type *NewTy = TypeMap->remapType(Ty);
  if (Ty == NewTy)
    return false;
  // We'll be rewriting the type `ptr addrspace(7)` out of existence soon, so
  // make sure GEPs don't have different semantics with the new type.
  I.setSourceElementType(NewTy);
  I.setResultElementType(TypeMap->remapType(I.getResultElementType()));
  return true;
}

bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::visitLoadInst(LoadInst &LI) {
  Type *Ty = LI.getType();
  Type *IntTy = TypeMap->remapType(Ty);
  if (Ty == IntTy)
    return false;

  IRB.SetInsertPoint(&LI);
  auto *NLI = cast<LoadInst>(LI.clone());
  NLI->mutateType(IntTy);
  NLI = IRB.Insert(NLI);
  NLI->takeName(&LI);

  Value *CastBack = intsToFatPtrs(NLI, IntTy, Ty, NLI->getName());
  LI.replaceAllUsesWith(CastBack);
  LI.eraseFromParent();
  return true;
}

bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::visitStoreInst(StoreInst &SI) {
  Value *V = SI.getValueOperand();
  Type *Ty = V->getType();
  Type *IntTy = TypeMap->remapType(Ty);
  if (Ty == IntTy)
    return false;

  IRB.SetInsertPoint(&SI);
  Value *IntV = fatPtrsToInts(V, Ty, IntTy, V->getName());
  for (auto *Dbg : at::getAssignmentMarkers(&SI))
    Dbg->setValue(IntV);

  SI.setOperand(0, IntV);
  return true;
}

bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::visitMemCpyInst(
    MemCpyInst &MCI) {
  // TODO: Allow memcpy.p7.p3 as a synonym for the direct-to-LDS copy, which'll
  // need loop expansion here.
  if (MCI.getSourceAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER &&
      MCI.getDestAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER)
    return false;
  llvm::expandMemCpyAsLoop(&MCI,
                           TM->getTargetTransformInfo(*MCI.getFunction()));
  MCI.eraseFromParent();
  return true;
}

bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::visitMemMoveInst(
    MemMoveInst &MMI) {
  if (MMI.getSourceAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER &&
      MMI.getDestAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER)
    return false;
  reportFatalUsageError(
      "memmove() on buffer descriptors is not implemented because pointer "
      "comparison on buffer descriptors isn't implemented\n");
}

bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::visitMemSetInst(
    MemSetInst &MSI) {
  if (MSI.getDestAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER)
    return false;
  llvm::expandMemSetAsLoop(&MSI);
  MSI.eraseFromParent();
  return true;
}

bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::visitMemSetPatternInst(
    MemSetPatternInst &MSPI) {
  if (MSPI.getDestAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER)
    return false;
  llvm::expandMemSetPatternAsLoop(&MSPI);
  MSPI.eraseFromParent();
  return true;
}

namespace {
/// Convert loads/stores of types that the buffer intrinsics can't handle into
/// one or more such loads/stores that consist of legal types.
///
/// Do this by
/// 1. Recursing into structs (and arrays that don't share a memory layout with
/// vectors) since the intrinsics can't handle complex types.
/// 2. Converting arrays of non-aggregate, byte-sized types into their
/// corresponding vectors
/// 3. Bitcasting unsupported types, namely overly-long scalars and byte
/// vectors, into vectors of supported types.
/// 4. Splitting up excessively long reads/writes into multiple operations.
///
/// Note that this doesn't handle complex data structures, but, in the future,
/// the aggregate load splitter from SROA could be refactored to allow for that
/// case.
class LegalizeBufferContentTypesVisitor
    : public InstVisitor<LegalizeBufferContentTypesVisitor, bool> {
  friend class InstVisitor<LegalizeBufferContentTypesVisitor, bool>;

  IRBuilder<InstSimplifyFolder> IRB;

  const DataLayout &DL;

  /// If T is [N x U], where U is a scalar type, return the vector type
  /// <N x U>, otherwise, return T.
  Type *scalarArrayTypeAsVector(Type *MaybeArrayType);
  Value *arrayToVector(Value *V, Type *TargetType, const Twine &Name);
  Value *vectorToArray(Value *V, Type *OrigType, const Twine &Name);

  /// Break up the loads of a struct into the loads of its components

  /// Convert a vector or scalar type that can't be operated on by buffer
  /// intrinsics to one that would be legal through bitcasts and/or truncation.
  /// Uses the wider of i32, i16, or i8 where possible.
  Type *legalNonAggregateFor(Type *T);
  Value *makeLegalNonAggregate(Value *V, Type *TargetType, const Twine &Name);
  Value *makeIllegalNonAggregate(Value *V, Type *OrigType, const Twine &Name);

  struct VecSlice {
    uint64_t Index = 0;
    uint64_t Length = 0;
    VecSlice() = delete;
    // Needed for some Clangs
    VecSlice(uint64_t Index, uint64_t Length) : Index(Index), Length(Length) {}
  };
  /// Return the [index, length] pairs into which `T` needs to be cut to form
  /// legal buffer load or store operations. Clears `Slices`. Creates an empty
  /// `Slices` for non-vector inputs and creates one slice if no slicing will be
  /// needed.
  void getVecSlices(Type *T, SmallVectorImpl<VecSlice> &Slices);

  Value *extractSlice(Value *Vec, VecSlice S, const Twine &Name);
  Value *insertSlice(Value *Whole, Value *Part, VecSlice S, const Twine &Name);

  /// In most cases, return `LegalType`. However, when given an input that would
  /// normally be a legal type for the buffer intrinsics to return but that
  /// isn't hooked up through SelectionDAG, return a type of the same width that
  /// can be used with the relevant intrinsics. Specifically, handle the cases:
  /// - <1 x T> => T for all T
  /// - <N x i8> <=> i16, i32, 2xi32, 4xi32 (as needed)
  /// - <N x T> where T is under 32 bits and the total size is 96 bits <=> <3 x
  /// i32>
  Type *intrinsicTypeFor(Type *LegalType);

  bool visitLoadImpl(LoadInst &OrigLI, Type *PartType,
                     SmallVectorImpl<uint32_t> &AggIdxs, uint64_t AggByteOffset,
                     Value *&Result, const Twine &Name);
  /// Return value is (Changed, ModifiedInPlace)
  std::pair<bool, bool> visitStoreImpl(StoreInst &OrigSI, Type *PartType,
                                       SmallVectorImpl<uint32_t> &AggIdxs,
                                       uint64_t AggByteOffset,
                                       const Twine &Name);

  bool visitInstruction(Instruction &I) { return false; }
  bool visitLoadInst(LoadInst &LI);
  bool visitStoreInst(StoreInst &SI);

public:
  LegalizeBufferContentTypesVisitor(const DataLayout &DL, LLVMContext &Ctx)
      : IRB(Ctx, InstSimplifyFolder(DL)), DL(DL) {}
  bool processFunction(Function &F);
};
} // namespace

Type *LegalizeBufferContentTypesVisitor::scalarArrayTypeAsVector(Type *T) {
  ArrayType *AT = dyn_cast<ArrayType>(T);
  if (!AT)
    return T;
  Type *ET = AT->getElementType();
  if (!ET->isSingleValueType() || isa<VectorType>(ET))
    reportFatalUsageError("loading non-scalar arrays from buffer fat pointers "
                          "should have recursed");
  if (!DL.typeSizeEqualsStoreSize(AT))
    reportFatalUsageError(
        "loading padded arrays from buffer fat pointers should have recursed");
  return FixedVectorType::get(ET, AT->getNumElements());
}

Value *LegalizeBufferContentTypesVisitor::arrayToVector(Value *V,
                                                        Type *TargetType,
                                                        const Twine &Name) {
  Value *VectorRes = PoisonValue::get(TargetType);
  auto *VT = cast<FixedVectorType>(TargetType);
  unsigned EC = VT->getNumElements();
  for (auto I : iota_range<unsigned>(0, EC, /*Inclusive=*/false)) {
    Value *Elem = IRB.CreateExtractValue(V, I, Name + ".elem." + Twine(I));
    VectorRes = IRB.CreateInsertElement(VectorRes, Elem, I,
                                        Name + ".as.vec." + Twine(I));
  }
  return VectorRes;
}

Value *LegalizeBufferContentTypesVisitor::vectorToArray(Value *V,
                                                        Type *OrigType,
                                                        const Twine &Name) {
  Value *ArrayRes = PoisonValue::get(OrigType);
  ArrayType *AT = cast<ArrayType>(OrigType);
  unsigned EC = AT->getNumElements();
  for (auto I : iota_range<unsigned>(0, EC, /*Inclusive=*/false)) {
    Value *Elem = IRB.CreateExtractElement(V, I, Name + ".elem." + Twine(I));
    ArrayRes = IRB.CreateInsertValue(ArrayRes, Elem, I,
                                     Name + ".as.array." + Twine(I));
  }
  return ArrayRes;
}

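// Illustrative examples of what legalNonAggregateFor() computes (not an
// exhaustive list):
//   i96      -> <3 x i32>   (96 bits, a multiple of 32)
//   i48      -> <3 x i16>   (a multiple of 16 but not of 32)
//   <4 x i8> -> i32         (one 32-bit word)
//   <6 x i8> -> <3 x i16>   (48 bits)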
Type *LegalizeBufferContentTypesVisitor::legalNonAggregateFor(Type *T) {
  TypeSize Size = DL.getTypeStoreSizeInBits(T);
  // Implicitly zero-extend to the next byte if needed
  if (!DL.typeSizeEqualsStoreSize(T))
    T = IRB.getIntNTy(Size.getFixedValue());
  Type *ElemTy = T->getScalarType();
  if (isa<PointerType, ScalableVectorType>(ElemTy)) {
    // Pointers are always big enough, and we'll let scalable vectors through to
    // fail in codegen.
    return T;
  }
  unsigned ElemSize = DL.getTypeSizeInBits(ElemTy).getFixedValue();
  if (isPowerOf2_32(ElemSize) && ElemSize >= 16 && ElemSize <= 128) {
    // [vectors of] anything that's 16/32/64/128 bits can be cast and split into
    // legal buffer operations.
    return T;
  }
  Type *BestVectorElemType = nullptr;
  if (Size.isKnownMultipleOf(32))
    BestVectorElemType = IRB.getInt32Ty();
  else if (Size.isKnownMultipleOf(16))
    BestVectorElemType = IRB.getInt16Ty();
  else
    BestVectorElemType = IRB.getInt8Ty();
  unsigned NumCastElems =
      Size.getFixedValue() / BestVectorElemType->getIntegerBitWidth();
  if (NumCastElems == 1)
    return BestVectorElemType;
  return FixedVectorType::get(BestVectorElemType, NumCastElems);
}

Value *LegalizeBufferContentTypesVisitor::makeLegalNonAggregate(
    Value *V, Type *TargetType, const Twine &Name) {
  Type *SourceType = V->getType();
  TypeSize SourceSize = DL.getTypeSizeInBits(SourceType);
  TypeSize TargetSize = DL.getTypeSizeInBits(TargetType);
  if (SourceSize != TargetSize) {
    Type *ShortScalarTy = IRB.getIntNTy(SourceSize.getFixedValue());
    Type *ByteScalarTy = IRB.getIntNTy(TargetSize.getFixedValue());
    Value *AsScalar = IRB.CreateBitCast(V, ShortScalarTy, Name + ".as.scalar");
    Value *Zext = IRB.CreateZExt(AsScalar, ByteScalarTy, Name + ".zext");
    V = Zext;
    SourceType = ByteScalarTy;
  }
  return IRB.CreateBitCast(V, TargetType, Name + ".legal");
}

Value *LegalizeBufferContentTypesVisitor::makeIllegalNonAggregate(
    Value *V, Type *OrigType, const Twine &Name) {
  Type *LegalType = V->getType();
  TypeSize LegalSize = DL.getTypeSizeInBits(LegalType);
  TypeSize OrigSize = DL.getTypeSizeInBits(OrigType);
  if (LegalSize != OrigSize) {
    Type *ShortScalarTy = IRB.getIntNTy(OrigSize.getFixedValue());
    Type *ByteScalarTy = IRB.getIntNTy(LegalSize.getFixedValue());
    Value *AsScalar = IRB.CreateBitCast(V, ByteScalarTy, Name + ".bytes.cast");
    Value *Trunc = IRB.CreateTrunc(AsScalar, ShortScalarTy, Name + ".trunc");
    return IRB.CreateBitCast(Trunc, OrigType, Name + ".orig");
  }
  return IRB.CreateBitCast(V, OrigType, Name + ".real.ty");
}

Type *LegalizeBufferContentTypesVisitor::intrinsicTypeFor(Type *LegalType) {
  auto *VT = dyn_cast<FixedVectorType>(LegalType);
  if (!VT)
    return LegalType;
  Type *ET = VT->getElementType();
  // Explicitly return the element type of 1-element vectors because the
  // underlying intrinsics don't like <1 x T> even though it's a synonym for T.
  if (VT->getNumElements() == 1)
    return ET;
  if (DL.getTypeSizeInBits(LegalType) == 96 && DL.getTypeSizeInBits(ET) < 32)
    return FixedVectorType::get(IRB.getInt32Ty(), 3);
  if (ET->isIntegerTy(8)) {
    switch (VT->getNumElements()) {
    default:
      return LegalType; // Let it crash later
    case 1:
      return IRB.getInt8Ty();
    case 2:
      return IRB.getInt16Ty();
    case 4:
      return IRB.getInt32Ty();
    case 8:
      return FixedVectorType::get(IRB.getInt32Ty(), 2);
    case 16:
      return FixedVectorType::get(IRB.getInt32Ty(), 4);
    }
  }
  return LegalType;
}

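// An illustrative trace of getVecSlices(): for <7 x i16>, the greedy loop
// below produces {Index=0, Length=6} (three 32-bit words) followed by
// {Index=6, Length=1} (a single 16-bit element).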
void LegalizeBufferContentTypesVisitor::getVecSlices(
    Type *T, SmallVectorImpl<VecSlice> &Slices) {
  Slices.clear();
  auto *VT = dyn_cast<FixedVectorType>(T);
  if (!VT)
    return;

  uint64_t ElemBitWidth =
      DL.getTypeSizeInBits(VT->getElementType()).getFixedValue();

  uint64_t ElemsPer4Words = 128 / ElemBitWidth;
  uint64_t ElemsPer2Words = ElemsPer4Words / 2;
  uint64_t ElemsPerWord = ElemsPer2Words / 2;
  uint64_t ElemsPerShort = ElemsPerWord / 2;
  uint64_t ElemsPerByte = ElemsPerShort / 2;
  // If the elements evenly pack into 32-bit words, we can use 3-word stores,
  // such as for <6 x bfloat> or <3 x i32>, but we can't do this for, for
  // example, <3 x i64>, since that's not slicing.
  uint64_t ElemsPer3Words = ElemsPerWord * 3;

  uint64_t TotalElems = VT->getNumElements();
  uint64_t Index = 0;
  auto TrySlice = [&](unsigned MaybeLen) {
    if (MaybeLen > 0 && Index + MaybeLen <= TotalElems) {
      VecSlice Slice{/*Index=*/Index, /*Length=*/MaybeLen};
      Slices.push_back(Slice);
      Index += MaybeLen;
      return true;
    }
    return false;
  };
  while (Index < TotalElems) {
    TrySlice(ElemsPer4Words) || TrySlice(ElemsPer3Words) ||
        TrySlice(ElemsPer2Words) || TrySlice(ElemsPerWord) ||
        TrySlice(ElemsPerShort) || TrySlice(ElemsPerByte);
  }
}

Value *LegalizeBufferContentTypesVisitor::extractSlice(Value *Vec, VecSlice S,
                                                       const Twine &Name) {
  auto *VecVT = dyn_cast<FixedVectorType>(Vec->getType());
  if (!VecVT)
    return Vec;
  if (S.Length == VecVT->getNumElements() && S.Index == 0)
    return Vec;
  if (S.Length == 1)
    return IRB.CreateExtractElement(Vec, S.Index,
                                    Name + ".slice." + Twine(S.Index));
  SmallVector<int> Mask = llvm::to_vector(
      llvm::iota_range<int>(S.Index, S.Index + S.Length, /*Inclusive=*/false));
  return IRB.CreateShuffleVector(Vec, Mask, Name + ".slice." + Twine(S.Index));
}

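// An illustration of the shuffles in insertSlice(): inserting a slice with
// {Index=2, Length=2} into an 8-element vector first extends the slice with
// poison using the mask [0, 1, -1, -1, -1, -1, -1, -1], then merges it into
// the whole with the mask [0, 1, 8, 9, 4, 5, 6, 7].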
Value *LegalizeBufferContentTypesVisitor::insertSlice(Value *Whole, Value *Part,
                                                      VecSlice S,
                                                      const Twine &Name) {
  auto *WholeVT = dyn_cast<FixedVectorType>(Whole->getType());
  if (!WholeVT)
    return Part;
  if (S.Length == WholeVT->getNumElements() && S.Index == 0)
    return Part;
  if (S.Length == 1) {
    return IRB.CreateInsertElement(Whole, Part, S.Index,
                                   Name + ".slice." + Twine(S.Index));
  }
  int NumElems = cast<FixedVectorType>(Whole->getType())->getNumElements();

  // Extend the slice with poisons to make the main shufflevector happy.
  SmallVector<int> ExtPartMask(NumElems, -1);
  for (auto [I, E] : llvm::enumerate(
           MutableArrayRef<int>(ExtPartMask).take_front(S.Length))) {
    E = I;
  }
  Value *ExtPart = IRB.CreateShuffleVector(Part, ExtPartMask,
                                           Name + ".ext." + Twine(S.Index));

  SmallVector<int> Mask =
      llvm::to_vector(llvm::iota_range<int>(0, NumElems, /*Inclusive=*/false));
  for (auto [I, E] :
       llvm::enumerate(MutableArrayRef<int>(Mask).slice(S.Index, S.Length)))
    E = I + NumElems;
  return IRB.CreateShuffleVector(Whole, ExtPart, Mask,
                                 Name + ".parts." + Twine(S.Index));
}

bool LegalizeBufferContentTypesVisitor::visitLoadImpl(
    LoadInst &OrigLI, Type *PartType, SmallVectorImpl<uint32_t> &AggIdxs,
    uint64_t AggByteOff, Value *&Result, const Twine &Name) {
  if (auto *ST = dyn_cast<StructType>(PartType)) {
    const StructLayout *Layout = DL.getStructLayout(ST);
    bool Changed = false;
    for (auto [I, ElemTy, Offset] :
         llvm::enumerate(ST->elements(), Layout->getMemberOffsets())) {
      AggIdxs.push_back(I);
      Changed |= visitLoadImpl(OrigLI, ElemTy, AggIdxs,
                               AggByteOff + Offset.getFixedValue(), Result,
                               Name + "." + Twine(I));
      AggIdxs.pop_back();
    }
    return Changed;
  }
  if (auto *AT = dyn_cast<ArrayType>(PartType)) {
    Type *ElemTy = AT->getElementType();
    if (!ElemTy->isSingleValueType() || !DL.typeSizeEqualsStoreSize(ElemTy) ||
        ElemTy->isVectorTy()) {
      TypeSize ElemStoreSize = DL.getTypeStoreSize(ElemTy);
      bool Changed = false;
      for (auto I : llvm::iota_range<uint32_t>(0, AT->getNumElements(),
                                               /*Inclusive=*/false)) {
        AggIdxs.push_back(I);
        Changed |= visitLoadImpl(OrigLI, ElemTy, AggIdxs,
                                 AggByteOff + I * ElemStoreSize.getFixedValue(),
                                 Result, Name + Twine(I));
        AggIdxs.pop_back();
      }
      return Changed;
    }
  }

  // Typical case

  Type *ArrayAsVecType = scalarArrayTypeAsVector(PartType);
  Type *LegalType = legalNonAggregateFor(ArrayAsVecType);

  SmallVector<VecSlice> Slices;
  getVecSlices(LegalType, Slices);
  bool HasSlices = Slices.size() > 1;
  bool IsAggPart = !AggIdxs.empty();
  Value *LoadsRes;
  if (!HasSlices && !IsAggPart) {
    Type *LoadableType = intrinsicTypeFor(LegalType);
    if (LoadableType == PartType)
      return false;

    IRB.SetInsertPoint(&OrigLI);
    auto *NLI = cast<LoadInst>(OrigLI.clone());
    NLI->mutateType(LoadableType);
    NLI = IRB.Insert(NLI);
    NLI->setName(Name + ".loadable");

    LoadsRes = IRB.CreateBitCast(NLI, LegalType, Name + ".from.loadable");
  } else {
    IRB.SetInsertPoint(&OrigLI);
    LoadsRes = PoisonValue::get(LegalType);
    Value *OrigPtr = OrigLI.getPointerOperand();
    // If we need to split the value into more than one load, its legal type
    // will be a vector (ex. an i256 load will have LegalType = <8 x i32>).
    // But if we're already a scalar (which can happen if we're splitting up a
    // struct), the element type will be the legal type itself.
    Type *ElemType = LegalType->getScalarType();
    unsigned ElemBytes = DL.getTypeStoreSize(ElemType);
    AAMDNodes AANodes = OrigLI.getAAMetadata();
    if (IsAggPart && Slices.empty())
      Slices.push_back(VecSlice{/*Index=*/0, /*Length=*/1});
    for (VecSlice S : Slices) {
      Type *SliceType =
          S.Length != 1 ? FixedVectorType::get(ElemType, S.Length) : ElemType;
      int64_t ByteOffset = AggByteOff + S.Index * ElemBytes;
      // You can't reasonably expect loads to wrap around the edge of memory.
      Value *NewPtr = IRB.CreateGEP(
          IRB.getInt8Ty(), OrigLI.getPointerOperand(), IRB.getInt32(ByteOffset),
          OrigPtr->getName() + ".off.ptr." + Twine(ByteOffset),
          GEPNoWrapFlags::noUnsignedWrap());
      Type *LoadableType = intrinsicTypeFor(SliceType);
      LoadInst *NewLI = IRB.CreateAlignedLoad(
          LoadableType, NewPtr, commonAlignment(OrigLI.getAlign(), ByteOffset),
          Name + ".off." + Twine(ByteOffset));
      copyMetadataForLoad(*NewLI, OrigLI);
      NewLI->setAAMetadata(
          AANodes.adjustForAccess(ByteOffset, LoadableType, DL));
      NewLI->setAtomic(OrigLI.getOrdering(), OrigLI.getSyncScopeID());
      NewLI->setVolatile(OrigLI.isVolatile());
      Value *Loaded = IRB.CreateBitCast(NewLI, SliceType,
                                        NewLI->getName() + ".from.loadable");
      LoadsRes = insertSlice(LoadsRes, Loaded, S, Name);
    }
  }
  if (LegalType != ArrayAsVecType)
    LoadsRes = makeIllegalNonAggregate(LoadsRes, ArrayAsVecType, Name);
  if (ArrayAsVecType != PartType)
    LoadsRes = vectorToArray(LoadsRes, PartType, Name);

  if (IsAggPart)
    Result = IRB.CreateInsertValue(Result, LoadsRes, AggIdxs, Name);
  else
    Result = LoadsRes;
  return true;
}

bool LegalizeBufferContentTypesVisitor::visitLoadInst(LoadInst &LI) {
  if (LI.getPointerAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER)
    return false;

  SmallVector<uint32_t> AggIdxs;
  Type *OrigType = LI.getType();
  Value *Result = PoisonValue::get(OrigType);
  bool Changed = visitLoadImpl(LI, OrigType, AggIdxs, 0, Result, LI.getName());
  if (!Changed)
    return false;
  Result->takeName(&LI);
  LI.replaceAllUsesWith(Result);
  LI.eraseFromParent();
  return Changed;
}

std::pair<bool, bool> LegalizeBufferContentTypesVisitor::visitStoreImpl(
    StoreInst &OrigSI, Type *PartType, SmallVectorImpl<uint32_t> &AggIdxs,
    uint64_t AggByteOff, const Twine &Name) {
  if (auto *ST = dyn_cast<StructType>(PartType)) {
    const StructLayout *Layout = DL.getStructLayout(ST);
    bool Changed = false;
    for (auto [I, ElemTy, Offset] :
         llvm::enumerate(ST->elements(), Layout->getMemberOffsets())) {
      AggIdxs.push_back(I);
      Changed |= std::get<0>(visitStoreImpl(OrigSI, ElemTy, AggIdxs,
                                            AggByteOff + Offset.getFixedValue(),
                                            Name + "." + Twine(I)));
      AggIdxs.pop_back();
    }
    return std::make_pair(Changed, /*ModifiedInPlace=*/false);
  }
  if (auto *AT = dyn_cast<ArrayType>(PartType)) {
    Type *ElemTy = AT->getElementType();
    if (!ElemTy->isSingleValueType() || !DL.typeSizeEqualsStoreSize(ElemTy) ||
        ElemTy->isVectorTy()) {
      TypeSize ElemStoreSize = DL.getTypeStoreSize(ElemTy);
      bool Changed = false;
      for (auto I : llvm::iota_range<uint32_t>(0, AT->getNumElements(),
                                               /*Inclusive=*/false)) {
        AggIdxs.push_back(I);
        Changed |= std::get<0>(visitStoreImpl(
            OrigSI, ElemTy, AggIdxs,
            AggByteOff + I * ElemStoreSize.getFixedValue(), Name + Twine(I)));
        AggIdxs.pop_back();
      }
      return std::make_pair(Changed, /*ModifiedInPlace=*/false);
    }
  }

  Value *OrigData = OrigSI.getValueOperand();
  Value *NewData = OrigData;

  bool IsAggPart = !AggIdxs.empty();
  if (IsAggPart)
    NewData = IRB.CreateExtractValue(NewData, AggIdxs, Name);

  Type *ArrayAsVecType = scalarArrayTypeAsVector(PartType);
  if (ArrayAsVecType != PartType) {
    NewData = arrayToVector(NewData, ArrayAsVecType, Name);
  }

  Type *LegalType = legalNonAggregateFor(ArrayAsVecType);
  if (LegalType != ArrayAsVecType) {
    NewData = makeLegalNonAggregate(NewData, LegalType, Name);
  }

  SmallVector<VecSlice> Slices;
  getVecSlices(LegalType, Slices);
  bool NeedToSplit = Slices.size() > 1 || IsAggPart;
  if (!NeedToSplit) {
    Type *StorableType = intrinsicTypeFor(LegalType);
    if (StorableType == PartType)
      return std::make_pair(/*Changed=*/false, /*ModifiedInPlace=*/false);
    NewData = IRB.CreateBitCast(NewData, StorableType, Name + ".storable");
    OrigSI.setOperand(0, NewData);
    return std::make_pair(/*Changed=*/true, /*ModifiedInPlace=*/true);
  }

  Value *OrigPtr = OrigSI.getPointerOperand();
  Type *ElemType = LegalType->getScalarType();
  if (IsAggPart && Slices.empty())
    Slices.push_back(VecSlice{/*Index=*/0, /*Length=*/1});
  unsigned ElemBytes = DL.getTypeStoreSize(ElemType);
  AAMDNodes AANodes = OrigSI.getAAMetadata();
  for (VecSlice S : Slices) {
    Type *SliceType =
        S.Length != 1 ? FixedVectorType::get(ElemType, S.Length) : ElemType;
    int64_t ByteOffset = AggByteOff + S.Index * ElemBytes;
    Value *NewPtr =
        IRB.CreateGEP(IRB.getInt8Ty(), OrigPtr, IRB.getInt32(ByteOffset),
                      OrigPtr->getName() + ".part." + Twine(S.Index),
                      GEPNoWrapFlags::noUnsignedWrap());
    Value *DataSlice = extractSlice(NewData, S, Name);
    Type *StorableType = intrinsicTypeFor(SliceType);
    DataSlice = IRB.CreateBitCast(DataSlice, StorableType,
                                  DataSlice->getName() + ".storable");
    auto *NewSI = cast<StoreInst>(OrigSI.clone());
    NewSI->setAlignment(commonAlignment(OrigSI.getAlign(), ByteOffset));
    IRB.Insert(NewSI);
    NewSI->setOperand(0, DataSlice);
    NewSI->setOperand(1, NewPtr);
    NewSI->setAAMetadata(AANodes.adjustForAccess(ByteOffset, StorableType, DL));
  }
  return std::make_pair(/*Changed=*/true, /*ModifiedInPlace=*/false);
}

bool LegalizeBufferContentTypesVisitor::visitStoreInst(StoreInst &SI) {
  if (SI.getPointerAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER)
    return false;
  IRB.SetInsertPoint(&SI);
  SmallVector<uint32_t> AggIdxs;
  Value *OrigData = SI.getValueOperand();
  auto [Changed, ModifiedInPlace] =
      visitStoreImpl(SI, OrigData->getType(), AggIdxs, 0, OrigData->getName());
  if (Changed && !ModifiedInPlace)
    SI.eraseFromParent();
  return Changed;
}

bool LegalizeBufferContentTypesVisitor::processFunction(Function &F) {
  bool Changed = false;
  // Note, memory transfer intrinsics won't be seen here: the earlier phase
  // already expanded any that touched buffer fat pointers into loops.
  for (Instruction &I : make_early_inc_range(instructions(F))) {
    Changed |= visit(I);
  }
  return Changed;
}
1186 
1187 /// Return the ptr addrspace(8) and i32 (resource and offset parts) in a lowered
1188 /// buffer fat pointer constant.
1189 static std::pair<Constant *, Constant *>
splitLoweredFatBufferConst(Constant * C)1190 splitLoweredFatBufferConst(Constant *C) {
1191   assert(isSplitFatPtr(C->getType()) && "Not a split fat buffer pointer");
1192   return std::make_pair(C->getAggregateElement(0u), C->getAggregateElement(1u));
1193 }
1194 
1195 namespace {
1196 /// Handle the remapping of ptr addrspace(7) constants.
1197 class FatPtrConstMaterializer final : public ValueMaterializer {
1198   BufferFatPtrToStructTypeMap *TypeMap;
1199   // An internal mapper that is used to recurse into the arguments of constants.
1200   // While the documentation for `ValueMapper` specifies not to use it
1201   // recursively, examination of the logic in mapValue() shows that it can
1202   // safely be used recursively when handling constants, like it does in its own
1203   // logic.
1204   ValueMapper InternalMapper;
1205 
1206   Constant *materializeBufferFatPtrConst(Constant *C);
1207 
1208 public:
1209   // UnderlyingMap is the value map this materializer will be filling.
1210   FatPtrConstMaterializer(BufferFatPtrToStructTypeMap *TypeMap,
1211                           ValueToValueMapTy &UnderlyingMap)
1212       : TypeMap(TypeMap),
1213         InternalMapper(UnderlyingMap, RF_None, TypeMap, this) {}
1214   ~FatPtrConstMaterializer() = default;
1215 
1216   Value *materialize(Value *V) override;
1217 };
1218 } // namespace
1219 
1220 Constant *FatPtrConstMaterializer::materializeBufferFatPtrConst(Constant *C) {
1221   Type *SrcTy = C->getType();
1222   auto *NewTy = dyn_cast<StructType>(TypeMap->remapType(SrcTy));
1223   if (C->isNullValue())
1224     return ConstantAggregateZero::getNullValue(NewTy);
1225   if (isa<PoisonValue>(C)) {
1226     return ConstantStruct::get(NewTy,
1227                                {PoisonValue::get(NewTy->getElementType(0)),
1228                                 PoisonValue::get(NewTy->getElementType(1))});
1229   }
1230   if (isa<UndefValue>(C)) {
1231     return ConstantStruct::get(NewTy,
1232                                {UndefValue::get(NewTy->getElementType(0)),
1233                                 UndefValue::get(NewTy->getElementType(1))});
1234   }
1235 
1236   if (auto *VC = dyn_cast<ConstantVector>(C)) {
1237     if (Constant *S = VC->getSplatValue()) {
1238       Constant *NewS = InternalMapper.mapConstant(*S);
1239       if (!NewS)
1240         return nullptr;
1241       auto [Rsrc, Off] = splitLoweredFatBufferConst(NewS);
1242       auto EC = VC->getType()->getElementCount();
1243       return ConstantStruct::get(NewTy, {ConstantVector::getSplat(EC, Rsrc),
1244                                          ConstantVector::getSplat(EC, Off)});
1245     }
1246     SmallVector<Constant *> Rsrcs;
1247     SmallVector<Constant *> Offs;
1248     for (Value *Op : VC->operand_values()) {
1249       auto *NewOp = dyn_cast_or_null<Constant>(InternalMapper.mapValue(*Op));
1250       if (!NewOp)
1251         return nullptr;
1252       auto [Rsrc, Off] = splitLoweredFatBufferConst(NewOp);
1253       Rsrcs.push_back(Rsrc);
1254       Offs.push_back(Off);
1255     }
1256     Constant *RsrcVec = ConstantVector::get(Rsrcs);
1257     Constant *OffVec = ConstantVector::get(Offs);
1258     return ConstantStruct::get(NewTy, {RsrcVec, OffVec});
1259   }
1260 
1261   if (isa<GlobalValue>(C))
1262     reportFatalUsageError("global values containing ptr addrspace(7) (buffer "
1263                           "fat pointer) values are not supported");
1264 
1265   if (isa<ConstantExpr>(C))
1266     reportFatalUsageError(
1267         "constant exprs containing ptr addrspace(7) (buffer "
1268         "fat pointer) values should have been expanded earlier");
1269 
1270   return nullptr;
1271 }
1272 
1273 Value *FatPtrConstMaterializer::materialize(Value *V) {
1274   Constant *C = dyn_cast<Constant>(V);
1275   if (!C)
1276     return nullptr;
1277   // Structs and other types that happen to contain fat pointers get remapped
1278   // by the mapValue() logic.
1279   if (!isBufferFatPtrConst(C))
1280     return nullptr;
1281   return materializeBufferFatPtrConst(C);
1282 }
1283 
1284 using PtrParts = std::pair<Value *, Value *>;
1285 namespace {
1286 // The visitor returns the resource and offset parts for an instruction if they
1287 // can be computed, or (nullptr, nullptr) for cases that don't have a meaningful
1288 // value mapping.
1289 class SplitPtrStructs : public InstVisitor<SplitPtrStructs, PtrParts> {
1290   ValueToValueMapTy RsrcParts;
1291   ValueToValueMapTy OffParts;
1292 
1293   // Track instructions that have been rewritten into a user of the component
1294   // parts of their ptr addrspace(7) input. Instructions that produced
1295   // ptr addrspace(7) parts should **not** be RAUW'd before being added to this
1296   // set, as that replacement will be handled in a post-visit step. However,
1297   // instructions that yield values that aren't fat pointers (ex. ptrtoint)
1298   // should RAUW themselves with new instructions that use the split parts
1299   // of their arguments during processing.
1300   DenseSet<Instruction *> SplitUsers;
1301 
1302   // Nodes that need a second look once we've computed the parts for all other
1303   // instructions to see if, for example, we really need to phi on the resource
1304   // part.
1305   SmallVector<Instruction *> Conditionals;
1306   // Temporary instructions produced while lowering conditionals that should be
1307   // killed.
1308   SmallVector<Instruction *> ConditionalTemps;
1309 
1310   // Subtarget info, needed for determining what cache control bits to set.
1311   const TargetMachine *TM;
1312   const GCNSubtarget *ST = nullptr;
1313 
1314   IRBuilder<InstSimplifyFolder> IRB;
1315 
1316   // Copy metadata between instructions if applicable.
1317   void copyMetadata(Value *Dest, Value *Src);
1318 
1319   // Get the resource and offset parts of the value V, inserting appropriate
1320   // extractvalue calls if needed.
1321   PtrParts getPtrParts(Value *V);
1322 
1323   // Given an instruction that could produce multiple resource parts (a PHI or
1324   // select), collect the set of possible instructions that could have provided
1325   // its resource part (the `Roots`) and the set of
1326   // conditional instructions visited during the search (`Seen`). If, after
1327   // removing the root of the search from `Seen` and `Roots`, `Seen` is a subset
1328   // of `Roots` and `Roots - Seen` contains one element, the resource part of
1329   // that element can replace the resource part of all other elements in `Seen`.
1330   void getPossibleRsrcRoots(Instruction *I, SmallPtrSetImpl<Value *> &Roots,
1331                             SmallPtrSetImpl<Value *> &Seen);
1332   void processConditionals();
1333 
1334   // If an instruction has been split into resource and offset parts,
1335   // delete that instruction. If any of its uses have not themselves been split
1336   // into parts (for example, an insertvalue), construct the struct value
1337   // that the type rewrites declared the dying instruction should produce
1338   // and use that.
1339   // Also, kill the temporary extractvalue operations produced by the two-stage
1340   // lowering of PHIs and conditionals.
1341   void killAndReplaceSplitInstructions(SmallVectorImpl<Instruction *> &Origs);
1342 
1343   void setAlign(CallInst *Intr, Align A, unsigned RsrcArgIdx);
1344   void insertPreMemOpFence(AtomicOrdering Order, SyncScope::ID SSID);
1345   void insertPostMemOpFence(AtomicOrdering Order, SyncScope::ID SSID);
1346   Value *handleMemoryInst(Instruction *I, Value *Arg, Value *Ptr, Type *Ty,
1347                           Align Alignment, AtomicOrdering Order,
1348                           bool IsVolatile, SyncScope::ID SSID);
1349 
1350 public:
1351   SplitPtrStructs(const DataLayout &DL, LLVMContext &Ctx,
1352                   const TargetMachine *TM)
1353       : TM(TM), IRB(Ctx, InstSimplifyFolder(DL)) {}
1354 
1355   void processFunction(Function &F);
1356 
1357   PtrParts visitInstruction(Instruction &I);
1358   PtrParts visitLoadInst(LoadInst &LI);
1359   PtrParts visitStoreInst(StoreInst &SI);
1360   PtrParts visitAtomicRMWInst(AtomicRMWInst &AI);
1361   PtrParts visitAtomicCmpXchgInst(AtomicCmpXchgInst &AI);
1362   PtrParts visitGetElementPtrInst(GetElementPtrInst &GEP);
1363 
1364   PtrParts visitPtrToIntInst(PtrToIntInst &PI);
1365   PtrParts visitIntToPtrInst(IntToPtrInst &IP);
1366   PtrParts visitAddrSpaceCastInst(AddrSpaceCastInst &I);
1367   PtrParts visitICmpInst(ICmpInst &Cmp);
1368   PtrParts visitFreezeInst(FreezeInst &I);
1369 
1370   PtrParts visitExtractElementInst(ExtractElementInst &I);
1371   PtrParts visitInsertElementInst(InsertElementInst &I);
1372   PtrParts visitShuffleVectorInst(ShuffleVectorInst &I);
1373 
1374   PtrParts visitPHINode(PHINode &PHI);
1375   PtrParts visitSelectInst(SelectInst &SI);
1376 
1377   PtrParts visitIntrinsicInst(IntrinsicInst &II);
1378 };
1379 } // namespace
1380 
1381 void SplitPtrStructs::copyMetadata(Value *Dest, Value *Src) {
1382   auto *DestI = dyn_cast<Instruction>(Dest);
1383   auto *SrcI = dyn_cast<Instruction>(Src);
1384 
1385   if (!DestI || !SrcI)
1386     return;
1387 
1388   DestI->copyMetadata(*SrcI);
1389 }
1390 
1391 PtrParts SplitPtrStructs::getPtrParts(Value *V) {
1392   assert(isSplitFatPtr(V->getType()) && "it's not meaningful to get the parts "
1393                                         "of something that wasn't rewritten");
1394   auto *RsrcEntry = &RsrcParts[V];
1395   auto *OffEntry = &OffParts[V];
1396   if (*RsrcEntry && *OffEntry)
1397     return {*RsrcEntry, *OffEntry};
1398 
1399   if (auto *C = dyn_cast<Constant>(V)) {
1400     auto [Rsrc, Off] = splitLoweredFatBufferConst(C);
1401     return {*RsrcEntry = Rsrc, *OffEntry = Off};
1402   }
1403 
1404   IRBuilder<InstSimplifyFolder>::InsertPointGuard Guard(IRB);
1405   if (auto *I = dyn_cast<Instruction>(V)) {
1406     LLVM_DEBUG(dbgs() << "Recursing to split parts of " << *I << "\n");
1407     auto [Rsrc, Off] = visit(*I);
1408     if (Rsrc && Off)
1409       return {*RsrcEntry = Rsrc, *OffEntry = Off};
1410     // We'll be creating the new values after the relevant instruction.
1411     // This instruction generates a value and so isn't a terminator.
1412     IRB.SetInsertPoint(*I->getInsertionPointAfterDef());
1413     IRB.SetCurrentDebugLocation(I->getDebugLoc());
1414   } else if (auto *A = dyn_cast<Argument>(V)) {
1415     IRB.SetInsertPointPastAllocas(A->getParent());
1416     IRB.SetCurrentDebugLocation(DebugLoc());
1417   }
1418   Value *Rsrc = IRB.CreateExtractValue(V, 0, V->getName() + ".rsrc");
1419   Value *Off = IRB.CreateExtractValue(V, 1, V->getName() + ".off");
1420   return {*RsrcEntry = Rsrc, *OffEntry = Off};
1421 }
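
// For example, if `%p` has the split type `{ptr addrspace(8), i32}` and its
// parts aren't cached yet, this emits (illustrative IR; the names follow the
// suffix convention used throughout this pass):
//   %p.rsrc = extractvalue { ptr addrspace(8), i32 } %p, 0
//   %p.off = extractvalue { ptr addrspace(8), i32 } %p, 1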
1422 
1423 /// Returns the instruction that defines the resource part of the value V.
1424 /// Note that this is not getUnderlyingObject(), since that looks through
1425 /// operations like ptrmask which might modify the resource part.
1426 ///
1427 /// We can limit ourselves to just looking through GEPs followed by looking
1428 /// through addrspacecasts because only those two operations preserve the
1429 /// resource part, and because operations on an `addrspace(8)` (which is the
1430 /// legal input to this addrspacecast) would produce a different resource part.
1431 static Value *rsrcPartRoot(Value *V) {
1432   while (auto *GEP = dyn_cast<GEPOperator>(V))
1433     V = GEP->getPointerOperand();
1434   while (auto *ASC = dyn_cast<AddrSpaceCastOperator>(V))
1435     V = ASC->getPointerOperand();
1436   return V;
1437 }
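
// For example (hand-written IR, not from this file), given
//   %p = addrspacecast ptr addrspace(8) %r to ptr addrspace(7)
//   %q = getelementptr i8, ptr addrspace(7) %p, i32 4
// rsrcPartRoot(%q) looks through the GEP, then through the addrspacecast,
// and returns %r.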
1438 
1439 void SplitPtrStructs::getPossibleRsrcRoots(Instruction *I,
1440                                            SmallPtrSetImpl<Value *> &Roots,
1441                                            SmallPtrSetImpl<Value *> &Seen) {
1442   if (auto *PHI = dyn_cast<PHINode>(I)) {
1443     if (!Seen.insert(I).second)
1444       return;
1445     for (Value *In : PHI->incoming_values()) {
1446       In = rsrcPartRoot(In);
1447       Roots.insert(In);
1448       if (isa<PHINode, SelectInst>(In))
1449         getPossibleRsrcRoots(cast<Instruction>(In), Roots, Seen);
1450     }
1451   } else if (auto *SI = dyn_cast<SelectInst>(I)) {
1452     if (!Seen.insert(SI).second)
1453       return;
1454     Value *TrueVal = rsrcPartRoot(SI->getTrueValue());
1455     Value *FalseVal = rsrcPartRoot(SI->getFalseValue());
1456     Roots.insert(TrueVal);
1457     Roots.insert(FalseVal);
1458     if (isa<PHINode, SelectInst>(TrueVal))
1459       getPossibleRsrcRoots(cast<Instruction>(TrueVal), Roots, Seen);
1460     if (isa<PHINode, SelectInst>(FalseVal))
1461       getPossibleRsrcRoots(cast<Instruction>(FalseVal), Roots, Seen);
1462   } else {
1463     llvm_unreachable("getPossibleRsrcParts() only works on phi and select");
1464   }
1465 }
1466 
1467 void SplitPtrStructs::processConditionals() {
1468   SmallDenseMap<Value *, Value *> FoundRsrcs;
1469   SmallPtrSet<Value *, 4> Roots;
1470   SmallPtrSet<Value *, 4> Seen;
1471   for (Instruction *I : Conditionals) {
1472     // These have to exist by now because we've visited these nodes.
1473     Value *Rsrc = RsrcParts[I];
1474     Value *Off = OffParts[I];
1475     assert(Rsrc && Off && "must have visited conditionals by now");
1476 
1477     std::optional<Value *> MaybeRsrc;
1478     auto MaybeFoundRsrc = FoundRsrcs.find(I);
1479     if (MaybeFoundRsrc != FoundRsrcs.end()) {
1480       MaybeRsrc = MaybeFoundRsrc->second;
1481     } else {
1482       IRBuilder<InstSimplifyFolder>::InsertPointGuard Guard(IRB);
1483       Roots.clear();
1484       Seen.clear();
1485       getPossibleRsrcRoots(I, Roots, Seen);
1486       LLVM_DEBUG(dbgs() << "Processing conditional: " << *I << "\n");
1487 #ifndef NDEBUG
1488       for (Value *V : Roots)
1489         LLVM_DEBUG(dbgs() << "Root: " << *V << "\n");
1490       for (Value *V : Seen)
1491         LLVM_DEBUG(dbgs() << "Seen: " << *V << "\n");
1492 #endif
1493       // If we are our own possible root, then we shouldn't block our
1494       // replacement with a valid incoming value.
1495       Roots.erase(I);
1496       // We don't want to block the optimization for conditionals that don't
1497       // refer to themselves but did see themselves during the traversal.
1498       Seen.erase(I);
1499 
1500       if (set_is_subset(Seen, Roots)) {
1501         auto Diff = set_difference(Roots, Seen);
1502         if (Diff.size() == 1) {
1503           Value *RootVal = *Diff.begin();
1504           // Handle the case where the root-finding search already looked
1505           // through an addrspacecast.
1506           if (isSplitFatPtr(RootVal->getType()))
1507             MaybeRsrc = std::get<0>(getPtrParts(RootVal));
1508           else
1509             MaybeRsrc = RootVal;
1510         }
1511       }
1512     }
1513 
1514     if (auto *PHI = dyn_cast<PHINode>(I)) {
1515       Value *NewRsrc;
1516       StructType *PHITy = cast<StructType>(PHI->getType());
1517       IRB.SetInsertPoint(*PHI->getInsertionPointAfterDef());
1518       IRB.SetCurrentDebugLocation(PHI->getDebugLoc());
1519       if (MaybeRsrc) {
1520         NewRsrc = *MaybeRsrc;
1521       } else {
1522         Type *RsrcTy = PHITy->getElementType(0);
1523         auto *RsrcPHI = IRB.CreatePHI(RsrcTy, PHI->getNumIncomingValues());
1524         RsrcPHI->takeName(Rsrc);
1525         for (auto [V, BB] : llvm::zip(PHI->incoming_values(), PHI->blocks())) {
1526           Value *VRsrc = std::get<0>(getPtrParts(V));
1527           RsrcPHI->addIncoming(VRsrc, BB);
1528         }
1529         copyMetadata(RsrcPHI, PHI);
1530         NewRsrc = RsrcPHI;
1531       }
1532 
1533       Type *OffTy = PHITy->getElementType(1);
1534       auto *NewOff = IRB.CreatePHI(OffTy, PHI->getNumIncomingValues());
1535       NewOff->takeName(Off);
1536       for (auto [V, BB] : llvm::zip(PHI->incoming_values(), PHI->blocks())) {
1537         assert(OffParts.count(V) && "An offset part had to be created by now");
1538         Value *VOff = std::get<1>(getPtrParts(V));
1539         NewOff->addIncoming(VOff, BB);
1540       }
1541       copyMetadata(NewOff, PHI);
1542 
1543       // Note: We don't eraseFromParent() the temporaries because we don't want
1544   // to put the correction maps in an inconsistent state. That'll be handled
1545       // during the rest of the killing. Also, `ValueToValueMapTy` guarantees
1546       // that references in that map will be updated as well.
1547       // Note that if the temporary instruction got `InstSimplify`'d away, it
1548       // might be something like a block argument.
1549       if (auto *RsrcInst = dyn_cast<Instruction>(Rsrc)) {
1550         ConditionalTemps.push_back(RsrcInst);
1551         RsrcInst->replaceAllUsesWith(NewRsrc);
1552       }
1553       if (auto *OffInst = dyn_cast<Instruction>(Off)) {
1554         ConditionalTemps.push_back(OffInst);
1555         OffInst->replaceAllUsesWith(NewOff);
1556       }
1557 
1558       // Save on recomputing the cycle traversals in known-root cases.
1559       if (MaybeRsrc)
1560         for (Value *V : Seen)
1561           FoundRsrcs[V] = NewRsrc;
1562     } else if (isa<SelectInst>(I)) {
1563       if (MaybeRsrc) {
1564         if (auto *RsrcInst = dyn_cast<Instruction>(Rsrc)) {
1565           ConditionalTemps.push_back(RsrcInst);
1566           RsrcInst->replaceAllUsesWith(*MaybeRsrc);
1567         }
1568         for (Value *V : Seen)
1569           FoundRsrcs[V] = *MaybeRsrc;
1570       }
1571     } else {
1572       llvm_unreachable("Only PHIs and selects go in the conditionals list");
1573     }
1574   }
1575 }
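
// A sketch of the common case this optimizes (illustrative IR): in a loop like
//   %p = phi ptr addrspace(7) [ %base, %entry ], [ %p.next, %loop ]
//   %p.next = getelementptr i8, ptr addrspace(7) %p, i32 4
// every possible resource root traces back to %base, so no PHI on the
// resource part is needed: %base's resource part is reused directly and only
// an i32 PHI on the offset is created.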
1576 
1577 void SplitPtrStructs::killAndReplaceSplitInstructions(
1578     SmallVectorImpl<Instruction *> &Origs) {
1579   for (Instruction *I : ConditionalTemps)
1580     I->eraseFromParent();
1581 
1582   for (Instruction *I : Origs) {
1583     if (!SplitUsers.contains(I))
1584       continue;
1585 
1586     SmallVector<DbgValueInst *> Dbgs;
1587     findDbgValues(Dbgs, I);
1588     for (auto *Dbg : Dbgs) {
1589       IRB.SetInsertPoint(Dbg);
1590       auto &DL = I->getDataLayout();
1591       assert(isSplitFatPtr(I->getType()) &&
1592              "We should've RAUW'd away loads, stores, etc. at this point");
1593       auto *OffDbg = cast<DbgValueInst>(Dbg->clone());
1594       copyMetadata(OffDbg, Dbg);
1595       auto [Rsrc, Off] = getPtrParts(I);
1596 
1597       int64_t RsrcSz = DL.getTypeSizeInBits(Rsrc->getType());
1598       int64_t OffSz = DL.getTypeSizeInBits(Off->getType());
1599 
1600       std::optional<DIExpression *> RsrcExpr =
1601           DIExpression::createFragmentExpression(Dbg->getExpression(), 0,
1602                                                  RsrcSz);
1603       std::optional<DIExpression *> OffExpr =
1604           DIExpression::createFragmentExpression(Dbg->getExpression(), RsrcSz,
1605                                                  OffSz);
1606       if (OffExpr) {
1607         OffDbg->setExpression(*OffExpr);
1608         OffDbg->replaceVariableLocationOp(I, Off);
1609         IRB.Insert(OffDbg);
1610       } else {
1611         OffDbg->deleteValue();
1612       }
1613       if (RsrcExpr) {
1614         Dbg->setExpression(*RsrcExpr);
1615         Dbg->replaceVariableLocationOp(I, Rsrc);
1616       } else {
1617         Dbg->replaceVariableLocationOp(I, PoisonValue::get(I->getType()));
1618       }
1619     }
1620 
1621     Value *Poison = PoisonValue::get(I->getType());
1622     I->replaceUsesWithIf(Poison, [&](const Use &U) -> bool {
1623       if (const auto *UI = dyn_cast<Instruction>(U.getUser()))
1624         return SplitUsers.contains(UI);
1625       return false;
1626     });
1627 
1628     if (I->use_empty()) {
1629       I->eraseFromParent();
1630       continue;
1631     }
1632     IRB.SetInsertPoint(*I->getInsertionPointAfterDef());
1633     IRB.SetCurrentDebugLocation(I->getDebugLoc());
1634     auto [Rsrc, Off] = getPtrParts(I);
1635     Value *Struct = PoisonValue::get(I->getType());
1636     Struct = IRB.CreateInsertValue(Struct, Rsrc, 0);
1637     Struct = IRB.CreateInsertValue(Struct, Off, 1);
1638     copyMetadata(Struct, I);
1639     Struct->takeName(I);
1640     I->replaceAllUsesWith(Struct);
1641     I->eraseFromParent();
1642   }
1643 }
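
// For a plain `ptr addrspace(7)` variable, the fragment split above follows
// from the data layout (128-bit resource, 32-bit offset): bits [0, 128) of
// the original location become the resource fragment and bits [128, 160)
// become the offset fragment.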
1644 
1645 void SplitPtrStructs::setAlign(CallInst *Intr, Align A, unsigned RsrcArgIdx) {
1646   LLVMContext &Ctx = Intr->getContext();
1647   Intr->addParamAttr(RsrcArgIdx, Attribute::getWithAlignment(Ctx, A));
1648 }
1649 
1650 void SplitPtrStructs::insertPreMemOpFence(AtomicOrdering Order,
1651                                           SyncScope::ID SSID) {
1652   switch (Order) {
1653   case AtomicOrdering::Release:
1654   case AtomicOrdering::AcquireRelease:
1655   case AtomicOrdering::SequentiallyConsistent:
1656     IRB.CreateFence(AtomicOrdering::Release, SSID);
1657     break;
1658   default:
1659     break;
1660   }
1661 }
1662 
1663 void SplitPtrStructs::insertPostMemOpFence(AtomicOrdering Order,
1664                                            SyncScope::ID SSID) {
1665   switch (Order) {
1666   case AtomicOrdering::Acquire:
1667   case AtomicOrdering::AcquireRelease:
1668   case AtomicOrdering::SequentiallyConsistent:
1669     IRB.CreateFence(AtomicOrdering::Acquire, SSID);
1670     break;
1671   default:
1672     break;
1673   }
1674 }
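
// Taken together, these fences mean that, for example, a sequentially
// consistent load is lowered as (sketch):
//   fence syncscope(<scope>) release
//   <raw buffer load intrinsic>
//   fence syncscope(<scope>) acquire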
1675 
1676 Value *SplitPtrStructs::handleMemoryInst(Instruction *I, Value *Arg, Value *Ptr,
1677                                          Type *Ty, Align Alignment,
1678                                          AtomicOrdering Order, bool IsVolatile,
1679                                          SyncScope::ID SSID) {
1680   IRB.SetInsertPoint(I);
1681 
1682   auto [Rsrc, Off] = getPtrParts(Ptr);
1683   SmallVector<Value *, 5> Args;
1684   if (Arg)
1685     Args.push_back(Arg);
1686   Args.push_back(Rsrc);
1687   Args.push_back(Off);
1688   insertPreMemOpFence(Order, SSID);
1689   // soffset is always 0 for these cases, where we always want any offset to be
1690   // part of bounds checking and we don't know which parts of the GEPs are
1691   // uniform.
1692   Args.push_back(IRB.getInt32(0));
1693 
1694   uint32_t Aux = 0;
1695   if (IsVolatile)
1696     Aux |= AMDGPU::CPol::VOLATILE;
1697   Args.push_back(IRB.getInt32(Aux));
1698 
1699   Intrinsic::ID IID = Intrinsic::not_intrinsic;
1700   if (isa<LoadInst>(I))
1701     IID = Order == AtomicOrdering::NotAtomic
1702               ? Intrinsic::amdgcn_raw_ptr_buffer_load
1703               : Intrinsic::amdgcn_raw_ptr_atomic_buffer_load;
1704   else if (isa<StoreInst>(I))
1705     IID = Intrinsic::amdgcn_raw_ptr_buffer_store;
1706   else if (auto *RMW = dyn_cast<AtomicRMWInst>(I)) {
1707     switch (RMW->getOperation()) {
1708     case AtomicRMWInst::Xchg:
1709       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_swap;
1710       break;
1711     case AtomicRMWInst::Add:
1712       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_add;
1713       break;
1714     case AtomicRMWInst::Sub:
1715       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_sub;
1716       break;
1717     case AtomicRMWInst::And:
1718       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_and;
1719       break;
1720     case AtomicRMWInst::Or:
1721       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_or;
1722       break;
1723     case AtomicRMWInst::Xor:
1724       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_xor;
1725       break;
1726     case AtomicRMWInst::Max:
1727       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_smax;
1728       break;
1729     case AtomicRMWInst::Min:
1730       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_smin;
1731       break;
1732     case AtomicRMWInst::UMax:
1733       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_umax;
1734       break;
1735     case AtomicRMWInst::UMin:
1736       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_umin;
1737       break;
1738     case AtomicRMWInst::FAdd:
1739       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_fadd;
1740       break;
1741     case AtomicRMWInst::FMax:
1742       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_fmax;
1743       break;
1744     case AtomicRMWInst::FMin:
1745       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_fmin;
1746       break;
1747     case AtomicRMWInst::FSub: {
1748       reportFatalUsageError(
1749           "atomic floating point subtraction not supported for "
1750           "buffer resources and should've been expanded away");
1751       break;
1752     }
1753     case AtomicRMWInst::FMaximum: {
1754       reportFatalUsageError(
1755           "atomic floating point fmaximum not supported for "
1756           "buffer resources and should've been expanded away");
1757       break;
1758     }
1759     case AtomicRMWInst::FMinimum: {
1760       reportFatalUsageError(
1761           "atomic floating point fminimum not supported for "
1762           "buffer resources and should've been expanded away");
1763       break;
1764     }
1765     case AtomicRMWInst::Nand:
1766       reportFatalUsageError(
1767           "atomic nand not supported for buffer resources and "
1768           "should've been expanded away");
1769       break;
1770     case AtomicRMWInst::UIncWrap:
1771     case AtomicRMWInst::UDecWrap:
1772       reportFatalUsageError("wrapping increment/decrement not supported for "
1773                             "buffer resources and should've been expanded away");
1774       break;
1775     case AtomicRMWInst::BAD_BINOP:
1776       llvm_unreachable("Not sure how we got a bad binop");
1777     case AtomicRMWInst::USubCond:
1778     case AtomicRMWInst::USubSat:
1779       break;
1780     }
1781   }
1782 
1783   auto *Call = IRB.CreateIntrinsic(IID, Ty, Args);
1784   copyMetadata(Call, I);
1785   setAlign(Call, Alignment, Arg ? 1 : 0);
1786   Call->takeName(I);
1787 
1788   insertPostMemOpFence(Order, SSID);
1789   // The "no moving p7 directly" rewrites ensure that this load or store won't
1790   // itself need to be split into parts.
1791   SplitUsers.insert(I);
1792   I->replaceAllUsesWith(Call);
1793   return Call;
1794 }
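
// For example (illustrative IR; value names are hypothetical), a simple
//   %v = load i32, ptr addrspace(7) %p, align 4
// is rewritten to
//   %v = call i32 @llvm.amdgcn.raw.ptr.buffer.load.i32(
//            ptr addrspace(8) align 4 %p.rsrc, i32 %p.off, i32 0, i32 0)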
1795 
1796 PtrParts SplitPtrStructs::visitInstruction(Instruction &I) {
1797   return {nullptr, nullptr};
1798 }
1799 
1800 PtrParts SplitPtrStructs::visitLoadInst(LoadInst &LI) {
1801   if (!isSplitFatPtr(LI.getPointerOperandType()))
1802     return {nullptr, nullptr};
1803   handleMemoryInst(&LI, nullptr, LI.getPointerOperand(), LI.getType(),
1804                    LI.getAlign(), LI.getOrdering(), LI.isVolatile(),
1805                    LI.getSyncScopeID());
1806   return {nullptr, nullptr};
1807 }
1808 
1809 PtrParts SplitPtrStructs::visitStoreInst(StoreInst &SI) {
1810   if (!isSplitFatPtr(SI.getPointerOperandType()))
1811     return {nullptr, nullptr};
1812   Value *Arg = SI.getValueOperand();
1813   handleMemoryInst(&SI, Arg, SI.getPointerOperand(), Arg->getType(),
1814                    SI.getAlign(), SI.getOrdering(), SI.isVolatile(),
1815                    SI.getSyncScopeID());
1816   return {nullptr, nullptr};
1817 }
1818 
1819 PtrParts SplitPtrStructs::visitAtomicRMWInst(AtomicRMWInst &AI) {
1820   if (!isSplitFatPtr(AI.getPointerOperand()->getType()))
1821     return {nullptr, nullptr};
1822   Value *Arg = AI.getValOperand();
1823   handleMemoryInst(&AI, Arg, AI.getPointerOperand(), Arg->getType(),
1824                    AI.getAlign(), AI.getOrdering(), AI.isVolatile(),
1825                    AI.getSyncScopeID());
1826   return {nullptr, nullptr};
1827 }
1828 
1829 // Unlike load, store, and RMW, cmpxchg needs special handling to account
1830 // for the boolean argument.
1831 PtrParts SplitPtrStructs::visitAtomicCmpXchgInst(AtomicCmpXchgInst &AI) {
1832   Value *Ptr = AI.getPointerOperand();
1833   if (!isSplitFatPtr(Ptr->getType()))
1834     return {nullptr, nullptr};
1835   IRB.SetInsertPoint(&AI);
1836 
1837   Type *Ty = AI.getNewValOperand()->getType();
1838   AtomicOrdering Order = AI.getMergedOrdering();
1839   SyncScope::ID SSID = AI.getSyncScopeID();
1840   bool IsNonTemporal = AI.getMetadata(LLVMContext::MD_nontemporal);
1841 
1842   auto [Rsrc, Off] = getPtrParts(Ptr);
1843   insertPreMemOpFence(Order, SSID);
1844 
1845   uint32_t Aux = 0;
1846   if (IsNonTemporal)
1847     Aux |= AMDGPU::CPol::SLC;
1848   if (AI.isVolatile())
1849     Aux |= AMDGPU::CPol::VOLATILE;
1850   auto *Call =
1851       IRB.CreateIntrinsic(Intrinsic::amdgcn_raw_ptr_buffer_atomic_cmpswap, Ty,
1852                           {AI.getNewValOperand(), AI.getCompareOperand(), Rsrc,
1853                            Off, IRB.getInt32(0), IRB.getInt32(Aux)});
1854   copyMetadata(Call, &AI);
1855   setAlign(Call, AI.getAlign(), 2);
1856   Call->takeName(&AI);
1857   insertPostMemOpFence(Order, SSID);
1858 
1859   Value *Res = PoisonValue::get(AI.getType());
1860   Res = IRB.CreateInsertValue(Res, Call, 0);
1861   if (!AI.isWeak()) {
1862     Value *Succeeded = IRB.CreateICmpEQ(Call, AI.getCompareOperand());
1863     Res = IRB.CreateInsertValue(Res, Succeeded, 1);
1864   }
1865   SplitUsers.insert(&AI);
1866   AI.replaceAllUsesWith(Res);
1867   return {nullptr, nullptr};
1868 }
1869 
1870 PtrParts SplitPtrStructs::visitGetElementPtrInst(GetElementPtrInst &GEP) {
1871   using namespace llvm::PatternMatch;
1872   Value *Ptr = GEP.getPointerOperand();
1873   if (!isSplitFatPtr(Ptr->getType()))
1874     return {nullptr, nullptr};
1875   IRB.SetInsertPoint(&GEP);
1876 
1877   auto [Rsrc, Off] = getPtrParts(Ptr);
1878   const DataLayout &DL = GEP.getDataLayout();
1879   bool IsNUW = GEP.hasNoUnsignedWrap();
1880   bool IsNUSW = GEP.hasNoUnsignedSignedWrap();
1881 
1882   StructType *ResTy = cast<StructType>(GEP.getType());
1883   Type *ResRsrcTy = ResTy->getElementType(0);
1884   VectorType *ResRsrcVecTy = dyn_cast<VectorType>(ResRsrcTy);
1885   bool BroadcastsPtr = ResRsrcVecTy && !isa<VectorType>(Off->getType());
1886 
1887   // In order to call emitGEPOffset() and thus not have to reimplement it,
1888   // we need the GEP result to have ptr addrspace(7) type.
1889   Type *FatPtrTy =
1890       ResRsrcTy->getWithNewType(IRB.getPtrTy(AMDGPUAS::BUFFER_FAT_POINTER));
1891   GEP.mutateType(FatPtrTy);
1892   Value *OffAccum = emitGEPOffset(&IRB, DL, &GEP);
1893   GEP.mutateType(ResTy);
1894 
1895   if (BroadcastsPtr) {
1896     Rsrc = IRB.CreateVectorSplat(ResRsrcVecTy->getElementCount(), Rsrc,
1897                                  Rsrc->getName());
1898     Off = IRB.CreateVectorSplat(ResRsrcVecTy->getElementCount(), Off,
1899                                 Off->getName());
1900   }
1901   if (match(OffAccum, m_Zero())) { // Constant-zero offset
1902     SplitUsers.insert(&GEP);
1903     return {Rsrc, Off};
1904   }
1905 
1906   bool HasNonNegativeOff = false;
1907   if (auto *CI = dyn_cast<ConstantInt>(OffAccum)) {
1908     HasNonNegativeOff = !CI->isNegative();
1909   }
1910   Value *NewOff;
1911   if (match(Off, m_Zero())) {
1912     NewOff = OffAccum;
1913   } else {
1914     NewOff = IRB.CreateAdd(Off, OffAccum, "",
1915                            /*hasNUW=*/IsNUW || (IsNUSW && HasNonNegativeOff),
1916                            /*hasNSW=*/false);
1917   }
1918   copyMetadata(NewOff, &GEP);
1919   NewOff->takeName(&GEP);
1920   SplitUsers.insert(&GEP);
1921   return {Rsrc, NewOff};
1922 }
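
// For example (sketch), `%q = getelementptr i32, ptr addrspace(7) %p, i32 4`
// leaves the resource part of %p untouched and becomes
//   %q = add i32 %p.off, 16
// with `nuw` set when the GEP's wrap flags permit it.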
1923 
1924 PtrParts SplitPtrStructs::visitPtrToIntInst(PtrToIntInst &PI) {
1925   Value *Ptr = PI.getPointerOperand();
1926   if (!isSplitFatPtr(Ptr->getType()))
1927     return {nullptr, nullptr};
1928   IRB.SetInsertPoint(&PI);
1929 
1930   Type *ResTy = PI.getType();
1931   unsigned Width = ResTy->getScalarSizeInBits();
1932 
1933   auto [Rsrc, Off] = getPtrParts(Ptr);
1934   const DataLayout &DL = PI.getDataLayout();
1935   unsigned FatPtrWidth = DL.getPointerSizeInBits(AMDGPUAS::BUFFER_FAT_POINTER);
1936 
1937   Value *Res;
1938   if (Width <= BufferOffsetWidth) {
1939     Res = IRB.CreateIntCast(Off, ResTy, /*isSigned=*/false,
1940                             PI.getName() + ".off");
1941   } else {
1942     Value *RsrcInt = IRB.CreatePtrToInt(Rsrc, ResTy, PI.getName() + ".rsrc");
1943     Value *Shl = IRB.CreateShl(
1944         RsrcInt,
1945         ConstantExpr::getIntegerValue(ResTy, APInt(Width, BufferOffsetWidth)),
1946         "", Width >= FatPtrWidth, Width > FatPtrWidth);
1947     Value *OffCast = IRB.CreateIntCast(Off, ResTy, /*isSigned=*/false,
1948                                        PI.getName() + ".off");
1949     Res = IRB.CreateOr(Shl, OffCast);
1950   }
1951 
1952   copyMetadata(Res, &PI);
1953   Res->takeName(&PI);
1954   SplitUsers.insert(&PI);
1955   PI.replaceAllUsesWith(Res);
1956   return {nullptr, nullptr};
1957 }
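
// For example (sketch; value names are hypothetical),
// `%i = ptrtoint ptr addrspace(7) %p to i160` becomes
//   %i.rsrc = ptrtoint ptr addrspace(8) %p.rsrc to i160
//   %shl = shl nuw i160 %i.rsrc, 32
//   %i.off = zext i32 %p.off to i160
//   %i = or i160 %shl, %i.off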
1958 
1959 PtrParts SplitPtrStructs::visitIntToPtrInst(IntToPtrInst &IP) {
1960   if (!isSplitFatPtr(IP.getType()))
1961     return {nullptr, nullptr};
1962   IRB.SetInsertPoint(&IP);
1963   const DataLayout &DL = IP.getDataLayout();
1964   unsigned RsrcPtrWidth = DL.getPointerSizeInBits(AMDGPUAS::BUFFER_RESOURCE);
1965   Value *Int = IP.getOperand(0);
1966   Type *IntTy = Int->getType();
1967   Type *RsrcIntTy = IntTy->getWithNewBitWidth(RsrcPtrWidth);
1968   unsigned Width = IntTy->getScalarSizeInBits();
1969 
1970   auto *RetTy = cast<StructType>(IP.getType());
1971   Type *RsrcTy = RetTy->getElementType(0);
1972   Type *OffTy = RetTy->getElementType(1);
1973   Value *RsrcPart = IRB.CreateLShr(
1974       Int,
1975       ConstantExpr::getIntegerValue(IntTy, APInt(Width, BufferOffsetWidth)));
1976   Value *RsrcInt = IRB.CreateIntCast(RsrcPart, RsrcIntTy, /*isSigned=*/false);
1977   Value *Rsrc = IRB.CreateIntToPtr(RsrcInt, RsrcTy, IP.getName() + ".rsrc");
1978   Value *Off =
1979       IRB.CreateIntCast(Int, OffTy, /*IsSigned=*/false, IP.getName() + ".off");
1980 
1981   copyMetadata(Rsrc, &IP);
1982   SplitUsers.insert(&IP);
1983   return {Rsrc, Off};
1984 }
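
// For example (sketch; value names are hypothetical),
// `inttoptr i160 %int to ptr addrspace(7)` becomes
//   %hi = lshr i160 %int, 32
//   %rsrc.int = trunc i160 %hi to i128
//   %rsrc = inttoptr i128 %rsrc.int to ptr addrspace(8)
//   %off = trunc i160 %int to i32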
1985 
1986 PtrParts SplitPtrStructs::visitAddrSpaceCastInst(AddrSpaceCastInst &I) {
1987   // TODO(krzysz00): handle casts from ptr addrspace(7) to global pointers
1988   // by computing the effective address.
1989   if (!isSplitFatPtr(I.getType()))
1990     return {nullptr, nullptr};
1991   IRB.SetInsertPoint(&I);
1992   Value *In = I.getPointerOperand();
1993   // No-op casts preserve parts
1994   if (In->getType() == I.getType()) {
1995     auto [Rsrc, Off] = getPtrParts(In);
1996     SplitUsers.insert(&I);
1997     return {Rsrc, Off};
1998   }
1999 
2000   auto *ResTy = cast<StructType>(I.getType());
2001   Type *RsrcTy = ResTy->getElementType(0);
2002   Type *OffTy = ResTy->getElementType(1);
2003   Value *ZeroOff = Constant::getNullValue(OffTy);
2004 
2005   // Special case for null pointers, undef, and poison, which can be created by
2006   // address space propagation.
2007   auto *InConst = dyn_cast<Constant>(In);
2008   if (InConst && InConst->isNullValue()) {
2009     Value *NullRsrc = Constant::getNullValue(RsrcTy);
2010     SplitUsers.insert(&I);
2011     return {NullRsrc, ZeroOff};
2012   }
2013   if (isa<PoisonValue>(In)) {
2014     Value *PoisonRsrc = PoisonValue::get(RsrcTy);
2015     Value *PoisonOff = PoisonValue::get(OffTy);
2016     SplitUsers.insert(&I);
2017     return {PoisonRsrc, PoisonOff};
2018   }
2019   if (isa<UndefValue>(In)) {
2020     Value *UndefRsrc = UndefValue::get(RsrcTy);
2021     Value *UndefOff = UndefValue::get(OffTy);
2022     SplitUsers.insert(&I);
2023     return {UndefRsrc, UndefOff};
2024   }
2025 
2026   if (I.getSrcAddressSpace() != AMDGPUAS::BUFFER_RESOURCE)
2027     reportFatalUsageError(
2028         "only buffer resources (addrspace 8) and null/poison pointers can be "
2029         "cast to buffer fat pointers (addrspace 7)");
2030   SplitUsers.insert(&I);
2031   return {In, ZeroOff};
2032 }
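
// For example, `addrspacecast ptr addrspace(8) %r to ptr addrspace(7)` simply
// yields the pair (%r, i32 0): the resource is reused and the offset starts
// at zero.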
2033 
2034 PtrParts SplitPtrStructs::visitICmpInst(ICmpInst &Cmp) {
2035   Value *Lhs = Cmp.getOperand(0);
2036   if (!isSplitFatPtr(Lhs->getType()))
2037     return {nullptr, nullptr};
2038   Value *Rhs = Cmp.getOperand(1);
2039   IRB.SetInsertPoint(&Cmp);
2040   ICmpInst::Predicate Pred = Cmp.getPredicate();
2041 
2042   assert((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) &&
2043          "Pointer comparison is only equal or unequal");
2044   auto [LhsRsrc, LhsOff] = getPtrParts(Lhs);
2045   auto [RhsRsrc, RhsOff] = getPtrParts(Rhs);
2046   Value *RsrcCmp =
2047       IRB.CreateICmp(Pred, LhsRsrc, RhsRsrc, Cmp.getName() + ".rsrc");
2048   copyMetadata(RsrcCmp, &Cmp);
2049   Value *OffCmp = IRB.CreateICmp(Pred, LhsOff, RhsOff, Cmp.getName() + ".off");
2050   copyMetadata(OffCmp, &Cmp);
2051 
2052   Value *Res = nullptr;
2053   if (Pred == ICmpInst::ICMP_EQ)
2054     Res = IRB.CreateAnd(RsrcCmp, OffCmp);
2055   else if (Pred == ICmpInst::ICMP_NE)
2056     Res = IRB.CreateOr(RsrcCmp, OffCmp);
2057   copyMetadata(Res, &Cmp);
2058   Res->takeName(&Cmp);
2059   SplitUsers.insert(&Cmp);
2060   Cmp.replaceAllUsesWith(Res);
2061   return {nullptr, nullptr};
2062 }
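
// For example (sketch), `%c = icmp eq ptr addrspace(7) %a, %b` becomes
//   %c.rsrc = icmp eq ptr addrspace(8) %a.rsrc, %b.rsrc
//   %c.off = icmp eq i32 %a.off, %b.off
//   %c = and i1 %c.rsrc, %c.off
// with `ne` comparisons combined by `or` instead of `and`.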
2063 
2064 PtrParts SplitPtrStructs::visitFreezeInst(FreezeInst &I) {
2065   if (!isSplitFatPtr(I.getType()))
2066     return {nullptr, nullptr};
2067   IRB.SetInsertPoint(&I);
2068   auto [Rsrc, Off] = getPtrParts(I.getOperand(0));
2069 
2070   Value *RsrcRes = IRB.CreateFreeze(Rsrc, I.getName() + ".rsrc");
2071   copyMetadata(RsrcRes, &I);
2072   Value *OffRes = IRB.CreateFreeze(Off, I.getName() + ".off");
2073   copyMetadata(OffRes, &I);
2074   SplitUsers.insert(&I);
2075   return {RsrcRes, OffRes};
2076 }
2077 
2078 PtrParts SplitPtrStructs::visitExtractElementInst(ExtractElementInst &I) {
2079   if (!isSplitFatPtr(I.getType()))
2080     return {nullptr, nullptr};
2081   IRB.SetInsertPoint(&I);
2082   Value *Vec = I.getVectorOperand();
2083   Value *Idx = I.getIndexOperand();
2084   auto [Rsrc, Off] = getPtrParts(Vec);
2085 
2086   Value *RsrcRes = IRB.CreateExtractElement(Rsrc, Idx, I.getName() + ".rsrc");
2087   copyMetadata(RsrcRes, &I);
2088   Value *OffRes = IRB.CreateExtractElement(Off, Idx, I.getName() + ".off");
2089   copyMetadata(OffRes, &I);
2090   SplitUsers.insert(&I);
2091   return {RsrcRes, OffRes};
2092 }
2093 
2094 PtrParts SplitPtrStructs::visitInsertElementInst(InsertElementInst &I) {
2095   // The mutated instructions temporarily don't return vectors, and so
2096   // we need the generic getType() here to avoid crashes.
2097   if (!isSplitFatPtr(cast<Instruction>(I).getType()))
2098     return {nullptr, nullptr};
2099   IRB.SetInsertPoint(&I);
2100   Value *Vec = I.getOperand(0);
2101   Value *Elem = I.getOperand(1);
2102   Value *Idx = I.getOperand(2);
2103   auto [VecRsrc, VecOff] = getPtrParts(Vec);
2104   auto [ElemRsrc, ElemOff] = getPtrParts(Elem);
2105 
2106   Value *RsrcRes =
2107       IRB.CreateInsertElement(VecRsrc, ElemRsrc, Idx, I.getName() + ".rsrc");
2108   copyMetadata(RsrcRes, &I);
2109   Value *OffRes =
2110       IRB.CreateInsertElement(VecOff, ElemOff, Idx, I.getName() + ".off");
2111   copyMetadata(OffRes, &I);
2112   SplitUsers.insert(&I);
2113   return {RsrcRes, OffRes};
2114 }
2115 
2116 PtrParts SplitPtrStructs::visitShuffleVectorInst(ShuffleVectorInst &I) {
2117   // Cast is needed for the same reason as insertelement's.
2118   if (!isSplitFatPtr(cast<Instruction>(I).getType()))
2119     return {nullptr, nullptr};
2120   IRB.SetInsertPoint(&I);
2121 
2122   Value *V1 = I.getOperand(0);
2123   Value *V2 = I.getOperand(1);
2124   ArrayRef<int> Mask = I.getShuffleMask();
2125   auto [V1Rsrc, V1Off] = getPtrParts(V1);
2126   auto [V2Rsrc, V2Off] = getPtrParts(V2);
2127 
2128   Value *RsrcRes =
2129       IRB.CreateShuffleVector(V1Rsrc, V2Rsrc, Mask, I.getName() + ".rsrc");
2130   copyMetadata(RsrcRes, &I);
2131   Value *OffRes =
2132       IRB.CreateShuffleVector(V1Off, V2Off, Mask, I.getName() + ".off");
2133   copyMetadata(OffRes, &I);
2134   SplitUsers.insert(&I);
2135   return {RsrcRes, OffRes};
2136 }
2137 
2138 PtrParts SplitPtrStructs::visitPHINode(PHINode &PHI) {
2139   if (!isSplitFatPtr(PHI.getType()))
2140     return {nullptr, nullptr};
2141   IRB.SetInsertPoint(*PHI.getInsertionPointAfterDef());
2142   // Phi nodes will be handled in post-processing after we've visited every
2143   // instruction. However, instead of just returning {nullptr, nullptr},
2144   // we explicitly create the temporary extractvalue operations that are our
2145   // temporary results so that they end up at the beginning of the block with
2146   // the PHIs.
2147   Value *TmpRsrc = IRB.CreateExtractValue(&PHI, 0, PHI.getName() + ".rsrc");
2148   Value *TmpOff = IRB.CreateExtractValue(&PHI, 1, PHI.getName() + ".off");
2149   Conditionals.push_back(&PHI);
2150   SplitUsers.insert(&PHI);
2151   return {TmpRsrc, TmpOff};
2152 }
2153 
2154 PtrParts SplitPtrStructs::visitSelectInst(SelectInst &SI) {
2155   if (!isSplitFatPtr(SI.getType()))
2156     return {nullptr, nullptr};
2157   IRB.SetInsertPoint(&SI);
2158 
2159   Value *Cond = SI.getCondition();
2160   Value *True = SI.getTrueValue();
2161   Value *False = SI.getFalseValue();
2162   auto [TrueRsrc, TrueOff] = getPtrParts(True);
2163   auto [FalseRsrc, FalseOff] = getPtrParts(False);
2164 
2165   Value *RsrcRes =
2166       IRB.CreateSelect(Cond, TrueRsrc, FalseRsrc, SI.getName() + ".rsrc", &SI);
2167   copyMetadata(RsrcRes, &SI);
2168   Conditionals.push_back(&SI);
2169   Value *OffRes =
2170       IRB.CreateSelect(Cond, TrueOff, FalseOff, SI.getName() + ".off", &SI);
2171   copyMetadata(OffRes, &SI);
2172   SplitUsers.insert(&SI);
2173   return {RsrcRes, OffRes};
2174 }
2175 
2176 /// Returns true if this intrinsic needs to be removed when it is
2177 /// applied to `ptr addrspace(7)` values. Calls to these intrinsics are
2178 /// rewritten into calls to versions of that intrinsic on the resource
2179 /// descriptor.
2180 static bool isRemovablePointerIntrinsic(Intrinsic::ID IID) {
2181   switch (IID) {
2182   default:
2183     return false;
2184   case Intrinsic::amdgcn_make_buffer_rsrc:
2185   case Intrinsic::ptrmask:
2186   case Intrinsic::invariant_start:
2187   case Intrinsic::invariant_end:
2188   case Intrinsic::launder_invariant_group:
2189   case Intrinsic::strip_invariant_group:
2190   case Intrinsic::memcpy:
2191   case Intrinsic::memcpy_inline:
2192   case Intrinsic::memmove:
2193   case Intrinsic::memset:
2194   case Intrinsic::memset_inline:
2195   case Intrinsic::experimental_memset_pattern:
2196   case Intrinsic::amdgcn_load_to_lds:
2197     return true;
2198   }
2199 }
2200 
2201 PtrParts SplitPtrStructs::visitIntrinsicInst(IntrinsicInst &I) {
2202   Intrinsic::ID IID = I.getIntrinsicID();
2203   switch (IID) {
2204   default:
2205     break;
2206   case Intrinsic::amdgcn_make_buffer_rsrc: {
2207     if (!isSplitFatPtr(I.getType()))
2208       return {nullptr, nullptr};
2209     Value *Base = I.getArgOperand(0);
2210     Value *Stride = I.getArgOperand(1);
2211     Value *NumRecords = I.getArgOperand(2);
2212     Value *Flags = I.getArgOperand(3);
2213     auto *SplitType = cast<StructType>(I.getType());
2214     Type *RsrcType = SplitType->getElementType(0);
2215     Type *OffType = SplitType->getElementType(1);
2216     IRB.SetInsertPoint(&I);
2217     Value *Rsrc = IRB.CreateIntrinsic(IID, {RsrcType, Base->getType()},
2218                                       {Base, Stride, NumRecords, Flags});
2219     copyMetadata(Rsrc, &I);
2220     Rsrc->takeName(&I);
2221     Value *Zero = Constant::getNullValue(OffType);
2222     SplitUsers.insert(&I);
2223     return {Rsrc, Zero};
2224   }
2225   case Intrinsic::ptrmask: {
2226     Value *Ptr = I.getArgOperand(0);
2227     if (!isSplitFatPtr(Ptr->getType()))
2228       return {nullptr, nullptr};
2229     Value *Mask = I.getArgOperand(1);
2230     IRB.SetInsertPoint(&I);
2231     auto [Rsrc, Off] = getPtrParts(Ptr);
2232     if (Mask->getType() != Off->getType())
2233       reportFatalUsageError("offset width is not equal to index width of fat "
2234                             "pointer (data layout not set up correctly?)");
2235     Value *OffRes = IRB.CreateAnd(Off, Mask, I.getName() + ".off");
2236     copyMetadata(OffRes, &I);
2237     SplitUsers.insert(&I);
2238     return {Rsrc, OffRes};
2239   }
2240   // Pointer annotation intrinsics that, given their object-wide nature,
2241   // operate on the resource part.
2242   case Intrinsic::invariant_start: {
2243     Value *Ptr = I.getArgOperand(1);
2244     if (!isSplitFatPtr(Ptr->getType()))
2245       return {nullptr, nullptr};
2246     IRB.SetInsertPoint(&I);
2247     auto [Rsrc, Off] = getPtrParts(Ptr);
2248     Type *NewTy = PointerType::get(I.getContext(), AMDGPUAS::BUFFER_RESOURCE);
2249     auto *NewRsrc = IRB.CreateIntrinsic(IID, {NewTy}, {I.getOperand(0), Rsrc});
2250     copyMetadata(NewRsrc, &I);
2251     NewRsrc->takeName(&I);
2252     SplitUsers.insert(&I);
2253     I.replaceAllUsesWith(NewRsrc);
2254     return {nullptr, nullptr};
2255   }
2256   case Intrinsic::invariant_end: {
2257     Value *RealPtr = I.getArgOperand(2);
2258     if (!isSplitFatPtr(RealPtr->getType()))
2259       return {nullptr, nullptr};
2260     IRB.SetInsertPoint(&I);
2261     Value *RealRsrc = getPtrParts(RealPtr).first;
2262     Value *InvPtr = I.getArgOperand(0);
2263     Value *Size = I.getArgOperand(1);
2264     Value *NewRsrc = IRB.CreateIntrinsic(IID, {RealRsrc->getType()},
2265                                          {InvPtr, Size, RealRsrc});
2266     copyMetadata(NewRsrc, &I);
2267     NewRsrc->takeName(&I);
2268     SplitUsers.insert(&I);
2269     I.replaceAllUsesWith(NewRsrc);
2270     return {nullptr, nullptr};
2271   }
2272   case Intrinsic::launder_invariant_group:
2273   case Intrinsic::strip_invariant_group: {
2274     Value *Ptr = I.getArgOperand(0);
2275     if (!isSplitFatPtr(Ptr->getType()))
2276       return {nullptr, nullptr};
2277     IRB.SetInsertPoint(&I);
2278     auto [Rsrc, Off] = getPtrParts(Ptr);
2279     Value *NewRsrc = IRB.CreateIntrinsic(IID, {Rsrc->getType()}, {Rsrc});
2280     copyMetadata(NewRsrc, &I);
2281     NewRsrc->takeName(&I);
2282     SplitUsers.insert(&I);
2283     return {NewRsrc, Off};
2284   }
2285   case Intrinsic::amdgcn_load_to_lds: {
2286     Value *Ptr = I.getArgOperand(0);
2287     if (!isSplitFatPtr(Ptr->getType()))
2288       return {nullptr, nullptr};
2289     IRB.SetInsertPoint(&I);
2290     auto [Rsrc, Off] = getPtrParts(Ptr);
2291     Value *LDSPtr = I.getArgOperand(1);
2292     Value *LoadSize = I.getArgOperand(2);
2293     Value *ImmOff = I.getArgOperand(3);
2294     Value *Aux = I.getArgOperand(4);
2295     Value *SOffset = IRB.getInt32(0);
2296     Instruction *NewLoad = IRB.CreateIntrinsic(
2297         Intrinsic::amdgcn_raw_ptr_buffer_load_lds, {},
2298         {Rsrc, LDSPtr, LoadSize, Off, SOffset, ImmOff, Aux});
2299     copyMetadata(NewLoad, &I);
2300     SplitUsers.insert(&I);
2301     I.replaceAllUsesWith(NewLoad);
2302     return {nullptr, nullptr};
2303   }
2304   }
2305   return {nullptr, nullptr};
2306 }
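
// For example (sketch; value names are hypothetical),
// `@llvm.ptrmask.p7.i32(ptr addrspace(7) %p, i32 %m)` masks only the offset:
//   %p.masked.off = and i32 %p.off, %m
// while the resource part is passed through unchanged.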
2307 
2308 void SplitPtrStructs::processFunction(Function &F) {
2309   ST = &TM->getSubtarget<GCNSubtarget>(F);
2310   SmallVector<Instruction *, 0> Originals(
2311       llvm::make_pointer_range(instructions(F)));
2312   LLVM_DEBUG(dbgs() << "Splitting pointer structs in function: " << F.getName()
2313                     << "\n");
2314   for (Instruction *I : Originals) {
2315     auto [Rsrc, Off] = visit(I);
2316     assert(((Rsrc && Off) || (!Rsrc && !Off)) &&
2317            "Can't have a resource but no offset");
2318     if (Rsrc)
2319       RsrcParts[I] = Rsrc;
2320     if (Off)
2321       OffParts[I] = Off;
2322   }
2323   processConditionals();
2324   killAndReplaceSplitInstructions(Originals);
2325 
2326   // Clean up after ourselves to save on memory.
2327   RsrcParts.clear();
2328   OffParts.clear();
2329   SplitUsers.clear();
2330   Conditionals.clear();
2331   ConditionalTemps.clear();
2332 }
2333 
2334 namespace {
2335 class AMDGPULowerBufferFatPointers : public ModulePass {
2336 public:
2337   static char ID;
2338 
2339   AMDGPULowerBufferFatPointers() : ModulePass(ID) {}
2340 
2341   bool run(Module &M, const TargetMachine &TM);
2342   bool runOnModule(Module &M) override;
2343 
2344   void getAnalysisUsage(AnalysisUsage &AU) const override;
2345 };
2346 } // namespace
2347 
2348 /// Returns true if there are values that have a buffer fat pointer in them,
2349 /// which means we'll need to perform rewrites on this function. As a side
2350 /// effect, this will populate the type remapping cache.
2351 static bool containsBufferFatPointers(const Function &F,
2352                                       BufferFatPtrToStructTypeMap *TypeMap) {
2353   bool HasFatPointers = false;
2354   for (const BasicBlock &BB : F)
2355     for (const Instruction &I : BB)
2356       HasFatPointers |= (I.getType() != TypeMap->remapType(I.getType()));
2357   return HasFatPointers;
2358 }
2359 
2360 static bool hasFatPointerInterface(const Function &F,
2361                                    BufferFatPtrToStructTypeMap *TypeMap) {
2362   Type *Ty = F.getFunctionType();
2363   return Ty != TypeMap->remapType(Ty);
2364 }
2365 
2366 /// Move the body of `OldF` into a new function, returning it.
2367 static Function *moveFunctionAdaptingType(Function *OldF, FunctionType *NewTy,
2368                                           ValueToValueMapTy &CloneMap) {
2369   bool IsIntrinsic = OldF->isIntrinsic();
2370   Function *NewF =
2371       Function::Create(NewTy, OldF->getLinkage(), OldF->getAddressSpace());
2372   NewF->copyAttributesFrom(OldF);
2373   NewF->copyMetadata(OldF, 0);
2374   NewF->takeName(OldF);
2375   NewF->updateAfterNameChange();
2376   NewF->setDLLStorageClass(OldF->getDLLStorageClass());
2377   OldF->getParent()->getFunctionList().insertAfter(OldF->getIterator(), NewF);
2378 
2379   while (!OldF->empty()) {
2380     BasicBlock *BB = &OldF->front();
2381     BB->removeFromParent();
2382     BB->insertInto(NewF);
2383     CloneMap[BB] = BB;
2384     for (Instruction &I : *BB) {
2385       CloneMap[&I] = &I;
2386     }
2387   }
2388 
2389   SmallVector<AttributeSet> ArgAttrs;
2390   AttributeList OldAttrs = OldF->getAttributes();
2391 
2392   for (auto [I, OldArg, NewArg] : enumerate(OldF->args(), NewF->args())) {
2393     CloneMap[&NewArg] = &OldArg;
2394     NewArg.takeName(&OldArg);
2395     Type *OldArgTy = OldArg.getType(), *NewArgTy = NewArg.getType();
2396     // Temporarily mutate type of `NewArg` to allow RAUW to work.
2397     NewArg.mutateType(OldArgTy);
2398     OldArg.replaceAllUsesWith(&NewArg);
2399     NewArg.mutateType(NewArgTy);
2400 
2401     AttributeSet ArgAttr = OldAttrs.getParamAttrs(I);
2402     // Intrinsics get their attributes fixed later.
2403     if (OldArgTy != NewArgTy && !IsIntrinsic)
2404       ArgAttr = ArgAttr.removeAttributes(
2405           NewF->getContext(),
2406           AttributeFuncs::typeIncompatible(NewArgTy, ArgAttr));
2407     ArgAttrs.push_back(ArgAttr);
2408   }
2409   AttributeSet RetAttrs = OldAttrs.getRetAttrs();
2410   if (OldF->getReturnType() != NewF->getReturnType() && !IsIntrinsic)
2411     RetAttrs = RetAttrs.removeAttributes(
2412         NewF->getContext(),
2413         AttributeFuncs::typeIncompatible(NewF->getReturnType(), RetAttrs));
2414   NewF->setAttributes(AttributeList::get(
2415       NewF->getContext(), OldAttrs.getFnAttrs(), RetAttrs, ArgAttrs));
2416   return NewF;
2417 }
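
// For example, a function declared `void @f(ptr addrspace(7) %p)` is rebuilt
// here as `void @f({ptr addrspace(8), i32} %p)`, with any parameter or return
// attributes that are incompatible with the new types dropped.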
2418 
2419 static void makeCloneInPlaceMap(Function *F, ValueToValueMapTy &CloneMap) {
2420   for (Argument &A : F->args())
2421     CloneMap[&A] = &A;
2422   for (BasicBlock &BB : *F) {
2423     CloneMap[&BB] = &BB;
2424     for (Instruction &I : BB)
2425       CloneMap[&I] = &I;
2426   }
2427 }
2428 
2429 bool AMDGPULowerBufferFatPointers::run(Module &M, const TargetMachine &TM) {
2430   bool Changed = false;
2431   const DataLayout &DL = M.getDataLayout();
2432   // Record the functions which need to be remapped.
2433   // The second element of the pair indicates whether the function has to have
2434   // its arguments or return types adjusted.
2435   SmallVector<std::pair<Function *, bool>> NeedsRemap;
2436 
2437   LLVMContext &Ctx = M.getContext();
2438 
2439   BufferFatPtrToStructTypeMap StructTM(DL);
2440   BufferFatPtrToIntTypeMap IntTM(DL);
2441   for (const GlobalVariable &GV : M.globals()) {
2442     if (GV.getAddressSpace() == AMDGPUAS::BUFFER_FAT_POINTER) {
2443       // FIXME: Use DiagnosticInfo unsupported but it requires a Function
2444       Ctx.emitError("global variables with a buffer fat pointer address "
2445                     "space (7) are not supported");
2446       continue;
2447     }
2448 
2449     Type *VT = GV.getValueType();
2450     if (VT != StructTM.remapType(VT)) {
2451       // FIXME: Use DiagnosticInfo unsupported but it requires a Function
2452       Ctx.emitError("global variables that contain buffer fat pointers "
2453                     "(address space 7 pointers) are unsupported. Use "
2454                     "buffer resource pointers (address space 8) instead");
2455       continue;
2456     }
2457   }
2458 
2459   {
2460     // Collect all constant exprs and aggregates referenced by any function.
2461     SmallVector<Constant *, 8> Worklist;
2462     for (Function &F : M.functions())
2463       for (Instruction &I : instructions(F))
2464         for (Value *Op : I.operands())
2465           if (isa<ConstantExpr, ConstantAggregate>(Op))
2466             Worklist.push_back(cast<Constant>(Op));
2467 
2468     // Recursively look for any referenced buffer pointer constants.
2469     SmallPtrSet<Constant *, 8> Visited;
2470     SetVector<Constant *> BufferFatPtrConsts;
2471     while (!Worklist.empty()) {
2472       Constant *C = Worklist.pop_back_val();
2473       if (!Visited.insert(C).second)
2474         continue;
2475       if (isBufferFatPtrOrVector(C->getType()))
2476         BufferFatPtrConsts.insert(C);
2477       for (Value *Op : C->operands())
2478         if (isa<ConstantExpr, ConstantAggregate>(Op))
2479           Worklist.push_back(cast<Constant>(Op));
2480     }
2481 
2482     // Expand all constant expressions using fat buffer pointers to
2483     // instructions.
2484     Changed |= convertUsersOfConstantsToInstructions(
2485         BufferFatPtrConsts.getArrayRef(), /*RestrictToFunc=*/nullptr,
2486         /*RemoveDeadConstants=*/false, /*IncludeSelf=*/true);
2487   }
2488 
2489   StoreFatPtrsAsIntsAndExpandMemcpyVisitor MemOpsRewrite(&IntTM, DL,
2490                                                          M.getContext(), &TM);
2491   LegalizeBufferContentTypesVisitor BufferContentsTypeRewrite(DL,
2492                                                               M.getContext());
2493   for (Function &F : M.functions()) {
2494     bool InterfaceChange = hasFatPointerInterface(F, &StructTM);
2495     bool BodyChanges = containsBufferFatPointers(F, &StructTM);
2496     Changed |= MemOpsRewrite.processFunction(F);
2497     if (InterfaceChange || BodyChanges) {
2498       NeedsRemap.push_back(std::make_pair(&F, InterfaceChange));
2499       Changed |= BufferContentsTypeRewrite.processFunction(F);
2500     }
2501   }
2502   if (NeedsRemap.empty())
2503     return Changed;
2504 
2505   SmallVector<Function *> NeedsPostProcess;
2506   SmallVector<Function *> Intrinsics;
2507   // Keep one big map so as to memoize constants across functions.
2508   ValueToValueMapTy CloneMap;
2509   FatPtrConstMaterializer Materializer(&StructTM, CloneMap);
2510 
2511   ValueMapper LowerInFuncs(CloneMap, RF_None, &StructTM, &Materializer);
2512   for (auto [F, InterfaceChange] : NeedsRemap) {
2513     Function *NewF = F;
2514     if (InterfaceChange)
2515       NewF = moveFunctionAdaptingType(
2516           F, cast<FunctionType>(StructTM.remapType(F->getFunctionType())),
2517           CloneMap);
2518     else
2519       makeCloneInPlaceMap(F, CloneMap);
2520     LowerInFuncs.remapFunction(*NewF);
2521     if (NewF->isIntrinsic())
2522       Intrinsics.push_back(NewF);
2523     else
2524       NeedsPostProcess.push_back(NewF);
2525     if (InterfaceChange) {
2526       F->replaceAllUsesWith(NewF);
2527       F->eraseFromParent();
2528     }
2529     Changed = true;
2530   }
2531   StructTM.clear();
2532   IntTM.clear();
2533   CloneMap.clear();
2534 
2535   SplitPtrStructs Splitter(DL, M.getContext(), &TM);
2536   for (Function *F : NeedsPostProcess)
2537     Splitter.processFunction(*F);
2538   for (Function *F : Intrinsics) {
2539     if (isRemovablePointerIntrinsic(F->getIntrinsicID())) {
2540       F->eraseFromParent();
2541     } else {
2542       std::optional<Function *> NewF = Intrinsic::remangleIntrinsicFunction(F);
2543       if (NewF)
2544         F->replaceAllUsesWith(*NewF);
2545     }
2546   }
2547   return Changed;
2548 }
2549 
2550 bool AMDGPULowerBufferFatPointers::runOnModule(Module &M) {
2551   TargetPassConfig &TPC = getAnalysis<TargetPassConfig>();
2552   const TargetMachine &TM = TPC.getTM<TargetMachine>();
2553   return run(M, TM);
2554 }
2555 
2556 char AMDGPULowerBufferFatPointers::ID = 0;
2557 
2558 char &llvm::AMDGPULowerBufferFatPointersID = AMDGPULowerBufferFatPointers::ID;
2559 
2560 void AMDGPULowerBufferFatPointers::getAnalysisUsage(AnalysisUsage &AU) const {
2561   AU.addRequired<TargetPassConfig>();
2562 }
2563 
2564 #define PASS_DESC "Lower buffer fat pointer operations to buffer resources"
2565 INITIALIZE_PASS_BEGIN(AMDGPULowerBufferFatPointers, DEBUG_TYPE, PASS_DESC,
2566                       false, false)
2567 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
2568 INITIALIZE_PASS_END(AMDGPULowerBufferFatPointers, DEBUG_TYPE, PASS_DESC, false,
2569                     false)
2570 #undef PASS_DESC
2571 
2572 ModulePass *llvm::createAMDGPULowerBufferFatPointersPass() {
2573   return new AMDGPULowerBufferFatPointers();
2574 }
2575 
2576 PreservedAnalyses
2577 AMDGPULowerBufferFatPointersPass::run(Module &M, ModuleAnalysisManager &MA) {
2578   return AMDGPULowerBufferFatPointers().run(M, TM) ? PreservedAnalyses::none()
2579                                                    : PreservedAnalyses::all();
2580 }
2581