xref: /freebsd/contrib/llvm-project/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1*700637cbSDimitry Andric //===- DXILFlattenArrays.cpp - Flattens DXIL Arrays-----------------------===//
2*700637cbSDimitry Andric //
3*700637cbSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*700637cbSDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5*700637cbSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*700637cbSDimitry Andric //
7*700637cbSDimitry Andric //===---------------------------------------------------------------------===//
8*700637cbSDimitry Andric ///
9*700637cbSDimitry Andric /// \file This file contains a pass to flatten arrays for the DirectX Backend.
10*700637cbSDimitry Andric ///
11*700637cbSDimitry Andric //===----------------------------------------------------------------------===//
12*700637cbSDimitry Andric 
13*700637cbSDimitry Andric #include "DXILFlattenArrays.h"
14*700637cbSDimitry Andric #include "DirectX.h"
15*700637cbSDimitry Andric #include "llvm/ADT/PostOrderIterator.h"
16*700637cbSDimitry Andric #include "llvm/ADT/STLExtras.h"
17*700637cbSDimitry Andric #include "llvm/IR/BasicBlock.h"
18*700637cbSDimitry Andric #include "llvm/IR/DerivedTypes.h"
19*700637cbSDimitry Andric #include "llvm/IR/IRBuilder.h"
20*700637cbSDimitry Andric #include "llvm/IR/InstVisitor.h"
21*700637cbSDimitry Andric #include "llvm/IR/ReplaceConstant.h"
22*700637cbSDimitry Andric #include "llvm/Support/Casting.h"
23*700637cbSDimitry Andric #include "llvm/Support/MathExtras.h"
24*700637cbSDimitry Andric #include "llvm/Transforms/Utils/Local.h"
25*700637cbSDimitry Andric #include <cassert>
26*700637cbSDimitry Andric #include <cstddef>
27*700637cbSDimitry Andric #include <cstdint>
28*700637cbSDimitry Andric #include <utility>
29*700637cbSDimitry Andric 
30*700637cbSDimitry Andric #define DEBUG_TYPE "dxil-flatten-arrays"
31*700637cbSDimitry Andric 
32*700637cbSDimitry Andric using namespace llvm;
33*700637cbSDimitry Andric namespace {
34*700637cbSDimitry Andric 
35*700637cbSDimitry Andric class DXILFlattenArraysLegacy : public ModulePass {
36*700637cbSDimitry Andric 
37*700637cbSDimitry Andric public:
38*700637cbSDimitry Andric   bool runOnModule(Module &M) override;
DXILFlattenArraysLegacy()39*700637cbSDimitry Andric   DXILFlattenArraysLegacy() : ModulePass(ID) {}
40*700637cbSDimitry Andric 
41*700637cbSDimitry Andric   static char ID; // Pass identification.
42*700637cbSDimitry Andric };
43*700637cbSDimitry Andric 
44*700637cbSDimitry Andric struct GEPInfo {
45*700637cbSDimitry Andric   ArrayType *RootFlattenedArrayType;
46*700637cbSDimitry Andric   Value *RootPointerOperand;
47*700637cbSDimitry Andric   SmallMapVector<Value *, APInt, 4> VariableOffsets;
48*700637cbSDimitry Andric   APInt ConstantOffset;
49*700637cbSDimitry Andric };
50*700637cbSDimitry Andric 
51*700637cbSDimitry Andric class DXILFlattenArraysVisitor
52*700637cbSDimitry Andric     : public InstVisitor<DXILFlattenArraysVisitor, bool> {
53*700637cbSDimitry Andric public:
DXILFlattenArraysVisitor(SmallDenseMap<GlobalVariable *,GlobalVariable * > & GlobalMap)54*700637cbSDimitry Andric   DXILFlattenArraysVisitor(
55*700637cbSDimitry Andric       SmallDenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap)
56*700637cbSDimitry Andric       : GlobalMap(GlobalMap) {}
57*700637cbSDimitry Andric   bool visit(Function &F);
58*700637cbSDimitry Andric   // InstVisitor methods.  They return true if the instruction was scalarized,
59*700637cbSDimitry Andric   // false if nothing changed.
60*700637cbSDimitry Andric   bool visitGetElementPtrInst(GetElementPtrInst &GEPI);
61*700637cbSDimitry Andric   bool visitAllocaInst(AllocaInst &AI);
visitInstruction(Instruction & I)62*700637cbSDimitry Andric   bool visitInstruction(Instruction &I) { return false; }
visitSelectInst(SelectInst & SI)63*700637cbSDimitry Andric   bool visitSelectInst(SelectInst &SI) { return false; }
visitICmpInst(ICmpInst & ICI)64*700637cbSDimitry Andric   bool visitICmpInst(ICmpInst &ICI) { return false; }
visitFCmpInst(FCmpInst & FCI)65*700637cbSDimitry Andric   bool visitFCmpInst(FCmpInst &FCI) { return false; }
visitUnaryOperator(UnaryOperator & UO)66*700637cbSDimitry Andric   bool visitUnaryOperator(UnaryOperator &UO) { return false; }
visitBinaryOperator(BinaryOperator & BO)67*700637cbSDimitry Andric   bool visitBinaryOperator(BinaryOperator &BO) { return false; }
visitCastInst(CastInst & CI)68*700637cbSDimitry Andric   bool visitCastInst(CastInst &CI) { return false; }
visitBitCastInst(BitCastInst & BCI)69*700637cbSDimitry Andric   bool visitBitCastInst(BitCastInst &BCI) { return false; }
visitInsertElementInst(InsertElementInst & IEI)70*700637cbSDimitry Andric   bool visitInsertElementInst(InsertElementInst &IEI) { return false; }
visitExtractElementInst(ExtractElementInst & EEI)71*700637cbSDimitry Andric   bool visitExtractElementInst(ExtractElementInst &EEI) { return false; }
visitShuffleVectorInst(ShuffleVectorInst & SVI)72*700637cbSDimitry Andric   bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; }
visitPHINode(PHINode & PHI)73*700637cbSDimitry Andric   bool visitPHINode(PHINode &PHI) { return false; }
74*700637cbSDimitry Andric   bool visitLoadInst(LoadInst &LI);
75*700637cbSDimitry Andric   bool visitStoreInst(StoreInst &SI);
visitCallInst(CallInst & ICI)76*700637cbSDimitry Andric   bool visitCallInst(CallInst &ICI) { return false; }
visitFreezeInst(FreezeInst & FI)77*700637cbSDimitry Andric   bool visitFreezeInst(FreezeInst &FI) { return false; }
78*700637cbSDimitry Andric   static bool isMultiDimensionalArray(Type *T);
79*700637cbSDimitry Andric   static std::pair<unsigned, Type *> getElementCountAndType(Type *ArrayTy);
80*700637cbSDimitry Andric 
81*700637cbSDimitry Andric private:
82*700637cbSDimitry Andric   SmallVector<WeakTrackingVH> PotentiallyDeadInstrs;
83*700637cbSDimitry Andric   SmallDenseMap<GEPOperator *, GEPInfo> GEPChainInfoMap;
84*700637cbSDimitry Andric   SmallDenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap;
85*700637cbSDimitry Andric   bool finish();
86*700637cbSDimitry Andric   ConstantInt *genConstFlattenIndices(ArrayRef<Value *> Indices,
87*700637cbSDimitry Andric                                       ArrayRef<uint64_t> Dims,
88*700637cbSDimitry Andric                                       IRBuilder<> &Builder);
89*700637cbSDimitry Andric   Value *genInstructionFlattenIndices(ArrayRef<Value *> Indices,
90*700637cbSDimitry Andric                                       ArrayRef<uint64_t> Dims,
91*700637cbSDimitry Andric                                       IRBuilder<> &Builder);
92*700637cbSDimitry Andric };
93*700637cbSDimitry Andric } // namespace
94*700637cbSDimitry Andric 
finish()95*700637cbSDimitry Andric bool DXILFlattenArraysVisitor::finish() {
96*700637cbSDimitry Andric   GEPChainInfoMap.clear();
97*700637cbSDimitry Andric   RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs);
98*700637cbSDimitry Andric   return true;
99*700637cbSDimitry Andric }
100*700637cbSDimitry Andric 
isMultiDimensionalArray(Type * T)101*700637cbSDimitry Andric bool DXILFlattenArraysVisitor::isMultiDimensionalArray(Type *T) {
102*700637cbSDimitry Andric   if (ArrayType *ArrType = dyn_cast<ArrayType>(T))
103*700637cbSDimitry Andric     return isa<ArrayType>(ArrType->getElementType());
104*700637cbSDimitry Andric   return false;
105*700637cbSDimitry Andric }
106*700637cbSDimitry Andric 
107*700637cbSDimitry Andric std::pair<unsigned, Type *>
getElementCountAndType(Type * ArrayTy)108*700637cbSDimitry Andric DXILFlattenArraysVisitor::getElementCountAndType(Type *ArrayTy) {
109*700637cbSDimitry Andric   unsigned TotalElements = 1;
110*700637cbSDimitry Andric   Type *CurrArrayTy = ArrayTy;
111*700637cbSDimitry Andric   while (auto *InnerArrayTy = dyn_cast<ArrayType>(CurrArrayTy)) {
112*700637cbSDimitry Andric     TotalElements *= InnerArrayTy->getNumElements();
113*700637cbSDimitry Andric     CurrArrayTy = InnerArrayTy->getElementType();
114*700637cbSDimitry Andric   }
115*700637cbSDimitry Andric   return std::make_pair(TotalElements, CurrArrayTy);
116*700637cbSDimitry Andric }
117*700637cbSDimitry Andric 
genConstFlattenIndices(ArrayRef<Value * > Indices,ArrayRef<uint64_t> Dims,IRBuilder<> & Builder)118*700637cbSDimitry Andric ConstantInt *DXILFlattenArraysVisitor::genConstFlattenIndices(
119*700637cbSDimitry Andric     ArrayRef<Value *> Indices, ArrayRef<uint64_t> Dims, IRBuilder<> &Builder) {
120*700637cbSDimitry Andric   assert(Indices.size() == Dims.size() &&
121*700637cbSDimitry Andric          "Indicies and dimmensions should be the same");
122*700637cbSDimitry Andric   unsigned FlatIndex = 0;
123*700637cbSDimitry Andric   unsigned Multiplier = 1;
124*700637cbSDimitry Andric 
125*700637cbSDimitry Andric   for (int I = Indices.size() - 1; I >= 0; --I) {
126*700637cbSDimitry Andric     unsigned DimSize = Dims[I];
127*700637cbSDimitry Andric     ConstantInt *CIndex = dyn_cast<ConstantInt>(Indices[I]);
128*700637cbSDimitry Andric     assert(CIndex && "This function expects all indicies to be ConstantInt");
129*700637cbSDimitry Andric     FlatIndex += CIndex->getZExtValue() * Multiplier;
130*700637cbSDimitry Andric     Multiplier *= DimSize;
131*700637cbSDimitry Andric   }
132*700637cbSDimitry Andric   return Builder.getInt32(FlatIndex);
133*700637cbSDimitry Andric }
134*700637cbSDimitry Andric 
genInstructionFlattenIndices(ArrayRef<Value * > Indices,ArrayRef<uint64_t> Dims,IRBuilder<> & Builder)135*700637cbSDimitry Andric Value *DXILFlattenArraysVisitor::genInstructionFlattenIndices(
136*700637cbSDimitry Andric     ArrayRef<Value *> Indices, ArrayRef<uint64_t> Dims, IRBuilder<> &Builder) {
137*700637cbSDimitry Andric   if (Indices.size() == 1)
138*700637cbSDimitry Andric     return Indices[0];
139*700637cbSDimitry Andric 
140*700637cbSDimitry Andric   Value *FlatIndex = Builder.getInt32(0);
141*700637cbSDimitry Andric   unsigned Multiplier = 1;
142*700637cbSDimitry Andric 
143*700637cbSDimitry Andric   for (int I = Indices.size() - 1; I >= 0; --I) {
144*700637cbSDimitry Andric     unsigned DimSize = Dims[I];
145*700637cbSDimitry Andric     Value *VMultiplier = Builder.getInt32(Multiplier);
146*700637cbSDimitry Andric     Value *ScaledIndex = Builder.CreateMul(Indices[I], VMultiplier);
147*700637cbSDimitry Andric     FlatIndex = Builder.CreateAdd(FlatIndex, ScaledIndex);
148*700637cbSDimitry Andric     Multiplier *= DimSize;
149*700637cbSDimitry Andric   }
150*700637cbSDimitry Andric   return FlatIndex;
151*700637cbSDimitry Andric }
152*700637cbSDimitry Andric 
visitLoadInst(LoadInst & LI)153*700637cbSDimitry Andric bool DXILFlattenArraysVisitor::visitLoadInst(LoadInst &LI) {
154*700637cbSDimitry Andric   unsigned NumOperands = LI.getNumOperands();
155*700637cbSDimitry Andric   for (unsigned I = 0; I < NumOperands; ++I) {
156*700637cbSDimitry Andric     Value *CurrOpperand = LI.getOperand(I);
157*700637cbSDimitry Andric     ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
158*700637cbSDimitry Andric     if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
159*700637cbSDimitry Andric       GetElementPtrInst *OldGEP =
160*700637cbSDimitry Andric           cast<GetElementPtrInst>(CE->getAsInstruction());
161*700637cbSDimitry Andric       OldGEP->insertBefore(LI.getIterator());
162*700637cbSDimitry Andric 
163*700637cbSDimitry Andric       IRBuilder<> Builder(&LI);
164*700637cbSDimitry Andric       LoadInst *NewLoad =
165*700637cbSDimitry Andric           Builder.CreateLoad(LI.getType(), OldGEP, LI.getName());
166*700637cbSDimitry Andric       NewLoad->setAlignment(LI.getAlign());
167*700637cbSDimitry Andric       LI.replaceAllUsesWith(NewLoad);
168*700637cbSDimitry Andric       LI.eraseFromParent();
169*700637cbSDimitry Andric       visitGetElementPtrInst(*OldGEP);
170*700637cbSDimitry Andric       return true;
171*700637cbSDimitry Andric     }
172*700637cbSDimitry Andric   }
173*700637cbSDimitry Andric   return false;
174*700637cbSDimitry Andric }
175*700637cbSDimitry Andric 
visitStoreInst(StoreInst & SI)176*700637cbSDimitry Andric bool DXILFlattenArraysVisitor::visitStoreInst(StoreInst &SI) {
177*700637cbSDimitry Andric   unsigned NumOperands = SI.getNumOperands();
178*700637cbSDimitry Andric   for (unsigned I = 0; I < NumOperands; ++I) {
179*700637cbSDimitry Andric     Value *CurrOpperand = SI.getOperand(I);
180*700637cbSDimitry Andric     ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
181*700637cbSDimitry Andric     if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
182*700637cbSDimitry Andric       GetElementPtrInst *OldGEP =
183*700637cbSDimitry Andric           cast<GetElementPtrInst>(CE->getAsInstruction());
184*700637cbSDimitry Andric       OldGEP->insertBefore(SI.getIterator());
185*700637cbSDimitry Andric 
186*700637cbSDimitry Andric       IRBuilder<> Builder(&SI);
187*700637cbSDimitry Andric       StoreInst *NewStore = Builder.CreateStore(SI.getValueOperand(), OldGEP);
188*700637cbSDimitry Andric       NewStore->setAlignment(SI.getAlign());
189*700637cbSDimitry Andric       SI.replaceAllUsesWith(NewStore);
190*700637cbSDimitry Andric       SI.eraseFromParent();
191*700637cbSDimitry Andric       visitGetElementPtrInst(*OldGEP);
192*700637cbSDimitry Andric       return true;
193*700637cbSDimitry Andric     }
194*700637cbSDimitry Andric   }
195*700637cbSDimitry Andric   return false;
196*700637cbSDimitry Andric }
197*700637cbSDimitry Andric 
visitAllocaInst(AllocaInst & AI)198*700637cbSDimitry Andric bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) {
199*700637cbSDimitry Andric   if (!isMultiDimensionalArray(AI.getAllocatedType()))
200*700637cbSDimitry Andric     return false;
201*700637cbSDimitry Andric 
202*700637cbSDimitry Andric   ArrayType *ArrType = cast<ArrayType>(AI.getAllocatedType());
203*700637cbSDimitry Andric   IRBuilder<> Builder(&AI);
204*700637cbSDimitry Andric   auto [TotalElements, BaseType] = getElementCountAndType(ArrType);
205*700637cbSDimitry Andric 
206*700637cbSDimitry Andric   ArrayType *FattenedArrayType = ArrayType::get(BaseType, TotalElements);
207*700637cbSDimitry Andric   AllocaInst *FlatAlloca =
208*700637cbSDimitry Andric       Builder.CreateAlloca(FattenedArrayType, nullptr, AI.getName() + ".1dim");
209*700637cbSDimitry Andric   FlatAlloca->setAlignment(AI.getAlign());
210*700637cbSDimitry Andric   AI.replaceAllUsesWith(FlatAlloca);
211*700637cbSDimitry Andric   AI.eraseFromParent();
212*700637cbSDimitry Andric   return true;
213*700637cbSDimitry Andric }
214*700637cbSDimitry Andric 
visitGetElementPtrInst(GetElementPtrInst & GEP)215*700637cbSDimitry Andric bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
216*700637cbSDimitry Andric   // Do not visit GEPs more than once
217*700637cbSDimitry Andric   if (GEPChainInfoMap.contains(cast<GEPOperator>(&GEP)))
218*700637cbSDimitry Andric     return false;
219*700637cbSDimitry Andric 
220*700637cbSDimitry Andric   Value *PtrOperand = GEP.getPointerOperand();
221*700637cbSDimitry Andric   // It shouldn't(?) be possible for the pointer operand of a GEP to be a PHI
222*700637cbSDimitry Andric   // node unless HLSL has pointers. If this assumption is incorrect or HLSL gets
223*700637cbSDimitry Andric   // pointer types, then the handling of this case can be implemented later.
224*700637cbSDimitry Andric   assert(!isa<PHINode>(PtrOperand) &&
225*700637cbSDimitry Andric          "Pointer operand of GEP should not be a PHI Node");
226*700637cbSDimitry Andric 
227*700637cbSDimitry Andric   // Replace a GEP ConstantExpr pointer operand with a GEP instruction so that
228*700637cbSDimitry Andric   // it can be visited
229*700637cbSDimitry Andric   if (auto *PtrOpGEPCE = dyn_cast<ConstantExpr>(PtrOperand);
230*700637cbSDimitry Andric       PtrOpGEPCE && PtrOpGEPCE->getOpcode() == Instruction::GetElementPtr) {
231*700637cbSDimitry Andric     GetElementPtrInst *OldGEPI =
232*700637cbSDimitry Andric         cast<GetElementPtrInst>(PtrOpGEPCE->getAsInstruction());
233*700637cbSDimitry Andric     OldGEPI->insertBefore(GEP.getIterator());
234*700637cbSDimitry Andric 
235*700637cbSDimitry Andric     IRBuilder<> Builder(&GEP);
236*700637cbSDimitry Andric     SmallVector<Value *> Indices(GEP.indices());
237*700637cbSDimitry Andric     Value *NewGEP =
238*700637cbSDimitry Andric         Builder.CreateGEP(GEP.getSourceElementType(), OldGEPI, Indices,
239*700637cbSDimitry Andric                           GEP.getName(), GEP.getNoWrapFlags());
240*700637cbSDimitry Andric     assert(isa<GetElementPtrInst>(NewGEP) &&
241*700637cbSDimitry Andric            "Expected newly-created GEP to be an instruction");
242*700637cbSDimitry Andric     GetElementPtrInst *NewGEPI = cast<GetElementPtrInst>(NewGEP);
243*700637cbSDimitry Andric 
244*700637cbSDimitry Andric     GEP.replaceAllUsesWith(NewGEPI);
245*700637cbSDimitry Andric     GEP.eraseFromParent();
246*700637cbSDimitry Andric     visitGetElementPtrInst(*OldGEPI);
247*700637cbSDimitry Andric     visitGetElementPtrInst(*NewGEPI);
248*700637cbSDimitry Andric     return true;
249*700637cbSDimitry Andric   }
250*700637cbSDimitry Andric 
251*700637cbSDimitry Andric   // Construct GEPInfo for this GEP
252*700637cbSDimitry Andric   GEPInfo Info;
253*700637cbSDimitry Andric 
254*700637cbSDimitry Andric   // Obtain the variable and constant byte offsets computed by this GEP
255*700637cbSDimitry Andric   const DataLayout &DL = GEP.getDataLayout();
256*700637cbSDimitry Andric   unsigned BitWidth = DL.getIndexTypeSizeInBits(GEP.getType());
257*700637cbSDimitry Andric   Info.ConstantOffset = {BitWidth, 0};
258*700637cbSDimitry Andric   [[maybe_unused]] bool Success = GEP.collectOffset(
259*700637cbSDimitry Andric       DL, BitWidth, Info.VariableOffsets, Info.ConstantOffset);
260*700637cbSDimitry Andric   assert(Success && "Failed to collect offsets for GEP");
261*700637cbSDimitry Andric 
262*700637cbSDimitry Andric   // If there is a parent GEP, inherit the root array type and pointer, and
263*700637cbSDimitry Andric   // merge the byte offsets. Otherwise, this GEP is itself the root of a GEP
264*700637cbSDimitry Andric   // chain and we need to deterine the root array type
265*700637cbSDimitry Andric   if (auto *PtrOpGEP = dyn_cast<GEPOperator>(PtrOperand)) {
266*700637cbSDimitry Andric     assert(GEPChainInfoMap.contains(PtrOpGEP) &&
267*700637cbSDimitry Andric            "Expected parent GEP to be visited before this GEP");
268*700637cbSDimitry Andric     GEPInfo &PGEPInfo = GEPChainInfoMap[PtrOpGEP];
269*700637cbSDimitry Andric     Info.RootFlattenedArrayType = PGEPInfo.RootFlattenedArrayType;
270*700637cbSDimitry Andric     Info.RootPointerOperand = PGEPInfo.RootPointerOperand;
271*700637cbSDimitry Andric     for (auto &VariableOffset : PGEPInfo.VariableOffsets)
272*700637cbSDimitry Andric       Info.VariableOffsets.insert(VariableOffset);
273*700637cbSDimitry Andric     Info.ConstantOffset += PGEPInfo.ConstantOffset;
274*700637cbSDimitry Andric   } else {
275*700637cbSDimitry Andric     Info.RootPointerOperand = PtrOperand;
276*700637cbSDimitry Andric 
277*700637cbSDimitry Andric     // We should try to determine the type of the root from the pointer rather
278*700637cbSDimitry Andric     // than the GEP's source element type because this could be a scalar GEP
279*700637cbSDimitry Andric     // into an array-typed pointer from an Alloca or Global Variable.
280*700637cbSDimitry Andric     Type *RootTy = GEP.getSourceElementType();
281*700637cbSDimitry Andric     if (auto *GlobalVar = dyn_cast<GlobalVariable>(PtrOperand)) {
282*700637cbSDimitry Andric       if (GlobalMap.contains(GlobalVar))
283*700637cbSDimitry Andric         GlobalVar = GlobalMap[GlobalVar];
284*700637cbSDimitry Andric       Info.RootPointerOperand = GlobalVar;
285*700637cbSDimitry Andric       RootTy = GlobalVar->getValueType();
286*700637cbSDimitry Andric     } else if (auto *Alloca = dyn_cast<AllocaInst>(PtrOperand))
287*700637cbSDimitry Andric       RootTy = Alloca->getAllocatedType();
288*700637cbSDimitry Andric     assert(!isMultiDimensionalArray(RootTy) &&
289*700637cbSDimitry Andric            "Expected root array type to be flattened");
290*700637cbSDimitry Andric 
291*700637cbSDimitry Andric     // If the root type is not an array, we don't need to do any flattening
292*700637cbSDimitry Andric     if (!isa<ArrayType>(RootTy))
293*700637cbSDimitry Andric       return false;
294*700637cbSDimitry Andric 
295*700637cbSDimitry Andric     Info.RootFlattenedArrayType = cast<ArrayType>(RootTy);
296*700637cbSDimitry Andric   }
297*700637cbSDimitry Andric 
298*700637cbSDimitry Andric   // GEPs without users or GEPs with non-GEP users should be replaced such that
299*700637cbSDimitry Andric   // the chain of GEPs they are a part of are collapsed to a single GEP into a
300*700637cbSDimitry Andric   // flattened array.
301*700637cbSDimitry Andric   bool ReplaceThisGEP = GEP.users().empty();
302*700637cbSDimitry Andric   for (Value *User : GEP.users())
303*700637cbSDimitry Andric     if (!isa<GetElementPtrInst>(User))
304*700637cbSDimitry Andric       ReplaceThisGEP = true;
305*700637cbSDimitry Andric 
306*700637cbSDimitry Andric   if (ReplaceThisGEP) {
307*700637cbSDimitry Andric     unsigned BytesPerElem =
308*700637cbSDimitry Andric         DL.getTypeAllocSize(Info.RootFlattenedArrayType->getArrayElementType());
309*700637cbSDimitry Andric     assert(isPowerOf2_32(BytesPerElem) &&
310*700637cbSDimitry Andric            "Bytes per element should be a power of 2");
311*700637cbSDimitry Andric 
312*700637cbSDimitry Andric     // Compute the 32-bit index for this flattened GEP from the constant and
313*700637cbSDimitry Andric     // variable byte offsets in the GEPInfo
314*700637cbSDimitry Andric     IRBuilder<> Builder(&GEP);
315*700637cbSDimitry Andric     Value *ZeroIndex = Builder.getInt32(0);
316*700637cbSDimitry Andric     uint64_t ConstantOffset =
317*700637cbSDimitry Andric         Info.ConstantOffset.udiv(BytesPerElem).getZExtValue();
318*700637cbSDimitry Andric     assert(ConstantOffset < UINT32_MAX &&
319*700637cbSDimitry Andric            "Constant byte offset for flat GEP index must fit within 32 bits");
320*700637cbSDimitry Andric     Value *FlattenedIndex = Builder.getInt32(ConstantOffset);
321*700637cbSDimitry Andric     for (auto [VarIndex, Multiplier] : Info.VariableOffsets) {
322*700637cbSDimitry Andric       assert(Multiplier.getActiveBits() <= 32 &&
323*700637cbSDimitry Andric              "The multiplier for a flat GEP index must fit within 32 bits");
324*700637cbSDimitry Andric       assert(VarIndex->getType()->isIntegerTy(32) &&
325*700637cbSDimitry Andric              "Expected i32-typed GEP indices");
326*700637cbSDimitry Andric       Value *VI;
327*700637cbSDimitry Andric       if (Multiplier.getZExtValue() % BytesPerElem != 0) {
328*700637cbSDimitry Andric         // This can happen, e.g., with i8 GEPs. To handle this we just divide
329*700637cbSDimitry Andric         // by BytesPerElem using an instruction after multiplying VarIndex by
330*700637cbSDimitry Andric         // Multiplier.
331*700637cbSDimitry Andric         VI = Builder.CreateMul(VarIndex,
332*700637cbSDimitry Andric                                Builder.getInt32(Multiplier.getZExtValue()));
333*700637cbSDimitry Andric         VI = Builder.CreateLShr(VI, Builder.getInt32(Log2_32(BytesPerElem)));
334*700637cbSDimitry Andric       } else
335*700637cbSDimitry Andric         VI = Builder.CreateMul(
336*700637cbSDimitry Andric             VarIndex,
337*700637cbSDimitry Andric             Builder.getInt32(Multiplier.getZExtValue() / BytesPerElem));
338*700637cbSDimitry Andric       FlattenedIndex = Builder.CreateAdd(FlattenedIndex, VI);
339*700637cbSDimitry Andric     }
340*700637cbSDimitry Andric 
341*700637cbSDimitry Andric     // Construct a new GEP for the flattened array to replace the current GEP
342*700637cbSDimitry Andric     Value *NewGEP = Builder.CreateGEP(
343*700637cbSDimitry Andric         Info.RootFlattenedArrayType, Info.RootPointerOperand,
344*700637cbSDimitry Andric         {ZeroIndex, FlattenedIndex}, GEP.getName(), GEP.getNoWrapFlags());
345*700637cbSDimitry Andric 
346*700637cbSDimitry Andric     // Replace the current GEP with the new GEP. Store GEPInfo into the map
347*700637cbSDimitry Andric     // for later use in case this GEP was not the end of the chain
348*700637cbSDimitry Andric     GEPChainInfoMap.insert({cast<GEPOperator>(NewGEP), std::move(Info)});
349*700637cbSDimitry Andric     GEP.replaceAllUsesWith(NewGEP);
350*700637cbSDimitry Andric     GEP.eraseFromParent();
351*700637cbSDimitry Andric     return true;
352*700637cbSDimitry Andric   }
353*700637cbSDimitry Andric 
354*700637cbSDimitry Andric   // This GEP is potentially dead at the end of the pass since it may not have
355*700637cbSDimitry Andric   // any users anymore after GEP chains have been collapsed. We retain store
356*700637cbSDimitry Andric   // GEPInfo for GEPs down the chain to use to compute their indices.
357*700637cbSDimitry Andric   GEPChainInfoMap.insert({cast<GEPOperator>(&GEP), std::move(Info)});
358*700637cbSDimitry Andric   PotentiallyDeadInstrs.emplace_back(&GEP);
359*700637cbSDimitry Andric   return false;
360*700637cbSDimitry Andric }
361*700637cbSDimitry Andric 
visit(Function & F)362*700637cbSDimitry Andric bool DXILFlattenArraysVisitor::visit(Function &F) {
363*700637cbSDimitry Andric   bool MadeChange = false;
364*700637cbSDimitry Andric   ReversePostOrderTraversal<Function *> RPOT(&F);
365*700637cbSDimitry Andric   for (BasicBlock *BB : make_early_inc_range(RPOT)) {
366*700637cbSDimitry Andric     for (Instruction &I : make_early_inc_range(*BB))
367*700637cbSDimitry Andric       MadeChange |= InstVisitor::visit(I);
368*700637cbSDimitry Andric   }
369*700637cbSDimitry Andric   finish();
370*700637cbSDimitry Andric   return MadeChange;
371*700637cbSDimitry Andric }
372*700637cbSDimitry Andric 
collectElements(Constant * Init,SmallVectorImpl<Constant * > & Elements)373*700637cbSDimitry Andric static void collectElements(Constant *Init,
374*700637cbSDimitry Andric                             SmallVectorImpl<Constant *> &Elements) {
375*700637cbSDimitry Andric   // Base case: If Init is not an array, add it directly to the vector.
376*700637cbSDimitry Andric   auto *ArrayTy = dyn_cast<ArrayType>(Init->getType());
377*700637cbSDimitry Andric   if (!ArrayTy) {
378*700637cbSDimitry Andric     Elements.push_back(Init);
379*700637cbSDimitry Andric     return;
380*700637cbSDimitry Andric   }
381*700637cbSDimitry Andric   unsigned ArrSize = ArrayTy->getNumElements();
382*700637cbSDimitry Andric   if (isa<ConstantAggregateZero>(Init)) {
383*700637cbSDimitry Andric     for (unsigned I = 0; I < ArrSize; ++I)
384*700637cbSDimitry Andric       Elements.push_back(Constant::getNullValue(ArrayTy->getElementType()));
385*700637cbSDimitry Andric     return;
386*700637cbSDimitry Andric   }
387*700637cbSDimitry Andric 
388*700637cbSDimitry Andric   // Recursive case: Process each element in the array.
389*700637cbSDimitry Andric   if (auto *ArrayConstant = dyn_cast<ConstantArray>(Init)) {
390*700637cbSDimitry Andric     for (unsigned I = 0; I < ArrayConstant->getNumOperands(); ++I) {
391*700637cbSDimitry Andric       collectElements(ArrayConstant->getOperand(I), Elements);
392*700637cbSDimitry Andric     }
393*700637cbSDimitry Andric   } else if (auto *DataArrayConstant = dyn_cast<ConstantDataArray>(Init)) {
394*700637cbSDimitry Andric     for (unsigned I = 0; I < DataArrayConstant->getNumElements(); ++I) {
395*700637cbSDimitry Andric       collectElements(DataArrayConstant->getElementAsConstant(I), Elements);
396*700637cbSDimitry Andric     }
397*700637cbSDimitry Andric   } else {
398*700637cbSDimitry Andric     llvm_unreachable(
399*700637cbSDimitry Andric         "Expected a ConstantArray or ConstantDataArray for array initializer!");
400*700637cbSDimitry Andric   }
401*700637cbSDimitry Andric }
402*700637cbSDimitry Andric 
transformInitializer(Constant * Init,Type * OrigType,ArrayType * FlattenedType,LLVMContext & Ctx)403*700637cbSDimitry Andric static Constant *transformInitializer(Constant *Init, Type *OrigType,
404*700637cbSDimitry Andric                                       ArrayType *FlattenedType,
405*700637cbSDimitry Andric                                       LLVMContext &Ctx) {
406*700637cbSDimitry Andric   // Handle ConstantAggregateZero (zero-initialized constants)
407*700637cbSDimitry Andric   if (isa<ConstantAggregateZero>(Init))
408*700637cbSDimitry Andric     return ConstantAggregateZero::get(FlattenedType);
409*700637cbSDimitry Andric 
410*700637cbSDimitry Andric   // Handle UndefValue (undefined constants)
411*700637cbSDimitry Andric   if (isa<UndefValue>(Init))
412*700637cbSDimitry Andric     return UndefValue::get(FlattenedType);
413*700637cbSDimitry Andric 
414*700637cbSDimitry Andric   if (!isa<ArrayType>(OrigType))
415*700637cbSDimitry Andric     return Init;
416*700637cbSDimitry Andric 
417*700637cbSDimitry Andric   SmallVector<Constant *> FlattenedElements;
418*700637cbSDimitry Andric   collectElements(Init, FlattenedElements);
419*700637cbSDimitry Andric   assert(FlattenedType->getNumElements() == FlattenedElements.size() &&
420*700637cbSDimitry Andric          "The number of collected elements should match the FlattenedType");
421*700637cbSDimitry Andric   return ConstantArray::get(FlattenedType, FlattenedElements);
422*700637cbSDimitry Andric }
423*700637cbSDimitry Andric 
flattenGlobalArrays(Module & M,SmallDenseMap<GlobalVariable *,GlobalVariable * > & GlobalMap)424*700637cbSDimitry Andric static void flattenGlobalArrays(
425*700637cbSDimitry Andric     Module &M, SmallDenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap) {
426*700637cbSDimitry Andric   LLVMContext &Ctx = M.getContext();
427*700637cbSDimitry Andric   for (GlobalVariable &G : M.globals()) {
428*700637cbSDimitry Andric     Type *OrigType = G.getValueType();
429*700637cbSDimitry Andric     if (!DXILFlattenArraysVisitor::isMultiDimensionalArray(OrigType))
430*700637cbSDimitry Andric       continue;
431*700637cbSDimitry Andric 
432*700637cbSDimitry Andric     ArrayType *ArrType = cast<ArrayType>(OrigType);
433*700637cbSDimitry Andric     auto [TotalElements, BaseType] =
434*700637cbSDimitry Andric         DXILFlattenArraysVisitor::getElementCountAndType(ArrType);
435*700637cbSDimitry Andric     ArrayType *FattenedArrayType = ArrayType::get(BaseType, TotalElements);
436*700637cbSDimitry Andric 
437*700637cbSDimitry Andric     // Create a new global variable with the updated type
438*700637cbSDimitry Andric     // Note: Initializer is set via transformInitializer
439*700637cbSDimitry Andric     GlobalVariable *NewGlobal =
440*700637cbSDimitry Andric         new GlobalVariable(M, FattenedArrayType, G.isConstant(), G.getLinkage(),
441*700637cbSDimitry Andric                            /*Initializer=*/nullptr, G.getName() + ".1dim", &G,
442*700637cbSDimitry Andric                            G.getThreadLocalMode(), G.getAddressSpace(),
443*700637cbSDimitry Andric                            G.isExternallyInitialized());
444*700637cbSDimitry Andric 
445*700637cbSDimitry Andric     // Copy relevant attributes
446*700637cbSDimitry Andric     NewGlobal->setUnnamedAddr(G.getUnnamedAddr());
447*700637cbSDimitry Andric     if (G.getAlignment() > 0) {
448*700637cbSDimitry Andric       NewGlobal->setAlignment(G.getAlign());
449*700637cbSDimitry Andric     }
450*700637cbSDimitry Andric 
451*700637cbSDimitry Andric     if (G.hasInitializer()) {
452*700637cbSDimitry Andric       Constant *Init = G.getInitializer();
453*700637cbSDimitry Andric       Constant *NewInit =
454*700637cbSDimitry Andric           transformInitializer(Init, OrigType, FattenedArrayType, Ctx);
455*700637cbSDimitry Andric       NewGlobal->setInitializer(NewInit);
456*700637cbSDimitry Andric     }
457*700637cbSDimitry Andric     GlobalMap[&G] = NewGlobal;
458*700637cbSDimitry Andric   }
459*700637cbSDimitry Andric }
460*700637cbSDimitry Andric 
flattenArrays(Module & M)461*700637cbSDimitry Andric static bool flattenArrays(Module &M) {
462*700637cbSDimitry Andric   bool MadeChange = false;
463*700637cbSDimitry Andric   SmallDenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
464*700637cbSDimitry Andric   flattenGlobalArrays(M, GlobalMap);
465*700637cbSDimitry Andric   DXILFlattenArraysVisitor Impl(GlobalMap);
466*700637cbSDimitry Andric   for (auto &F : make_early_inc_range(M.functions())) {
467*700637cbSDimitry Andric     if (F.isDeclaration())
468*700637cbSDimitry Andric       continue;
469*700637cbSDimitry Andric     MadeChange |= Impl.visit(F);
470*700637cbSDimitry Andric   }
471*700637cbSDimitry Andric   for (auto &[Old, New] : GlobalMap) {
472*700637cbSDimitry Andric     Old->replaceAllUsesWith(New);
473*700637cbSDimitry Andric     Old->eraseFromParent();
474*700637cbSDimitry Andric     MadeChange = true;
475*700637cbSDimitry Andric   }
476*700637cbSDimitry Andric   return MadeChange;
477*700637cbSDimitry Andric }
478*700637cbSDimitry Andric 
run(Module & M,ModuleAnalysisManager &)479*700637cbSDimitry Andric PreservedAnalyses DXILFlattenArrays::run(Module &M, ModuleAnalysisManager &) {
480*700637cbSDimitry Andric   bool MadeChanges = flattenArrays(M);
481*700637cbSDimitry Andric   if (!MadeChanges)
482*700637cbSDimitry Andric     return PreservedAnalyses::all();
483*700637cbSDimitry Andric   PreservedAnalyses PA;
484*700637cbSDimitry Andric   return PA;
485*700637cbSDimitry Andric }
486*700637cbSDimitry Andric 
runOnModule(Module & M)487*700637cbSDimitry Andric bool DXILFlattenArraysLegacy::runOnModule(Module &M) {
488*700637cbSDimitry Andric   return flattenArrays(M);
489*700637cbSDimitry Andric }
490*700637cbSDimitry Andric 
491*700637cbSDimitry Andric char DXILFlattenArraysLegacy::ID = 0;
492*700637cbSDimitry Andric 
493*700637cbSDimitry Andric INITIALIZE_PASS_BEGIN(DXILFlattenArraysLegacy, DEBUG_TYPE,
494*700637cbSDimitry Andric                       "DXIL Array Flattener", false, false)
495*700637cbSDimitry Andric INITIALIZE_PASS_END(DXILFlattenArraysLegacy, DEBUG_TYPE, "DXIL Array Flattener",
496*700637cbSDimitry Andric                     false, false)
497*700637cbSDimitry Andric 
createDXILFlattenArraysLegacyPass()498*700637cbSDimitry Andric ModulePass *llvm::createDXILFlattenArraysLegacyPass() {
499*700637cbSDimitry Andric   return new DXILFlattenArraysLegacy();
500*700637cbSDimitry Andric }
501