1*700637cbSDimitry Andric //===- DXILFlattenArrays.cpp - Flattens DXIL Arrays-----------------------===//
2*700637cbSDimitry Andric //
3*700637cbSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*700637cbSDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5*700637cbSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*700637cbSDimitry Andric //
7*700637cbSDimitry Andric //===---------------------------------------------------------------------===//
8*700637cbSDimitry Andric ///
9*700637cbSDimitry Andric /// \file This file contains a pass to flatten arrays for the DirectX Backend.
10*700637cbSDimitry Andric ///
11*700637cbSDimitry Andric //===----------------------------------------------------------------------===//
12*700637cbSDimitry Andric
13*700637cbSDimitry Andric #include "DXILFlattenArrays.h"
14*700637cbSDimitry Andric #include "DirectX.h"
15*700637cbSDimitry Andric #include "llvm/ADT/PostOrderIterator.h"
16*700637cbSDimitry Andric #include "llvm/ADT/STLExtras.h"
17*700637cbSDimitry Andric #include "llvm/IR/BasicBlock.h"
18*700637cbSDimitry Andric #include "llvm/IR/DerivedTypes.h"
19*700637cbSDimitry Andric #include "llvm/IR/IRBuilder.h"
20*700637cbSDimitry Andric #include "llvm/IR/InstVisitor.h"
21*700637cbSDimitry Andric #include "llvm/IR/ReplaceConstant.h"
22*700637cbSDimitry Andric #include "llvm/Support/Casting.h"
23*700637cbSDimitry Andric #include "llvm/Support/MathExtras.h"
24*700637cbSDimitry Andric #include "llvm/Transforms/Utils/Local.h"
25*700637cbSDimitry Andric #include <cassert>
26*700637cbSDimitry Andric #include <cstddef>
27*700637cbSDimitry Andric #include <cstdint>
28*700637cbSDimitry Andric #include <utility>
29*700637cbSDimitry Andric
30*700637cbSDimitry Andric #define DEBUG_TYPE "dxil-flatten-arrays"
31*700637cbSDimitry Andric
32*700637cbSDimitry Andric using namespace llvm;
33*700637cbSDimitry Andric namespace {
34*700637cbSDimitry Andric
35*700637cbSDimitry Andric class DXILFlattenArraysLegacy : public ModulePass {
36*700637cbSDimitry Andric
37*700637cbSDimitry Andric public:
38*700637cbSDimitry Andric bool runOnModule(Module &M) override;
DXILFlattenArraysLegacy()39*700637cbSDimitry Andric DXILFlattenArraysLegacy() : ModulePass(ID) {}
40*700637cbSDimitry Andric
41*700637cbSDimitry Andric static char ID; // Pass identification.
42*700637cbSDimitry Andric };
43*700637cbSDimitry Andric
44*700637cbSDimitry Andric struct GEPInfo {
45*700637cbSDimitry Andric ArrayType *RootFlattenedArrayType;
46*700637cbSDimitry Andric Value *RootPointerOperand;
47*700637cbSDimitry Andric SmallMapVector<Value *, APInt, 4> VariableOffsets;
48*700637cbSDimitry Andric APInt ConstantOffset;
49*700637cbSDimitry Andric };
50*700637cbSDimitry Andric
51*700637cbSDimitry Andric class DXILFlattenArraysVisitor
52*700637cbSDimitry Andric : public InstVisitor<DXILFlattenArraysVisitor, bool> {
53*700637cbSDimitry Andric public:
DXILFlattenArraysVisitor(SmallDenseMap<GlobalVariable *,GlobalVariable * > & GlobalMap)54*700637cbSDimitry Andric DXILFlattenArraysVisitor(
55*700637cbSDimitry Andric SmallDenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap)
56*700637cbSDimitry Andric : GlobalMap(GlobalMap) {}
57*700637cbSDimitry Andric bool visit(Function &F);
58*700637cbSDimitry Andric // InstVisitor methods. They return true if the instruction was scalarized,
59*700637cbSDimitry Andric // false if nothing changed.
60*700637cbSDimitry Andric bool visitGetElementPtrInst(GetElementPtrInst &GEPI);
61*700637cbSDimitry Andric bool visitAllocaInst(AllocaInst &AI);
visitInstruction(Instruction & I)62*700637cbSDimitry Andric bool visitInstruction(Instruction &I) { return false; }
visitSelectInst(SelectInst & SI)63*700637cbSDimitry Andric bool visitSelectInst(SelectInst &SI) { return false; }
visitICmpInst(ICmpInst & ICI)64*700637cbSDimitry Andric bool visitICmpInst(ICmpInst &ICI) { return false; }
visitFCmpInst(FCmpInst & FCI)65*700637cbSDimitry Andric bool visitFCmpInst(FCmpInst &FCI) { return false; }
visitUnaryOperator(UnaryOperator & UO)66*700637cbSDimitry Andric bool visitUnaryOperator(UnaryOperator &UO) { return false; }
visitBinaryOperator(BinaryOperator & BO)67*700637cbSDimitry Andric bool visitBinaryOperator(BinaryOperator &BO) { return false; }
visitCastInst(CastInst & CI)68*700637cbSDimitry Andric bool visitCastInst(CastInst &CI) { return false; }
visitBitCastInst(BitCastInst & BCI)69*700637cbSDimitry Andric bool visitBitCastInst(BitCastInst &BCI) { return false; }
visitInsertElementInst(InsertElementInst & IEI)70*700637cbSDimitry Andric bool visitInsertElementInst(InsertElementInst &IEI) { return false; }
visitExtractElementInst(ExtractElementInst & EEI)71*700637cbSDimitry Andric bool visitExtractElementInst(ExtractElementInst &EEI) { return false; }
visitShuffleVectorInst(ShuffleVectorInst & SVI)72*700637cbSDimitry Andric bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; }
visitPHINode(PHINode & PHI)73*700637cbSDimitry Andric bool visitPHINode(PHINode &PHI) { return false; }
74*700637cbSDimitry Andric bool visitLoadInst(LoadInst &LI);
75*700637cbSDimitry Andric bool visitStoreInst(StoreInst &SI);
visitCallInst(CallInst & ICI)76*700637cbSDimitry Andric bool visitCallInst(CallInst &ICI) { return false; }
visitFreezeInst(FreezeInst & FI)77*700637cbSDimitry Andric bool visitFreezeInst(FreezeInst &FI) { return false; }
78*700637cbSDimitry Andric static bool isMultiDimensionalArray(Type *T);
79*700637cbSDimitry Andric static std::pair<unsigned, Type *> getElementCountAndType(Type *ArrayTy);
80*700637cbSDimitry Andric
81*700637cbSDimitry Andric private:
82*700637cbSDimitry Andric SmallVector<WeakTrackingVH> PotentiallyDeadInstrs;
83*700637cbSDimitry Andric SmallDenseMap<GEPOperator *, GEPInfo> GEPChainInfoMap;
84*700637cbSDimitry Andric SmallDenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap;
85*700637cbSDimitry Andric bool finish();
86*700637cbSDimitry Andric ConstantInt *genConstFlattenIndices(ArrayRef<Value *> Indices,
87*700637cbSDimitry Andric ArrayRef<uint64_t> Dims,
88*700637cbSDimitry Andric IRBuilder<> &Builder);
89*700637cbSDimitry Andric Value *genInstructionFlattenIndices(ArrayRef<Value *> Indices,
90*700637cbSDimitry Andric ArrayRef<uint64_t> Dims,
91*700637cbSDimitry Andric IRBuilder<> &Builder);
92*700637cbSDimitry Andric };
93*700637cbSDimitry Andric } // namespace
94*700637cbSDimitry Andric
finish()95*700637cbSDimitry Andric bool DXILFlattenArraysVisitor::finish() {
96*700637cbSDimitry Andric GEPChainInfoMap.clear();
97*700637cbSDimitry Andric RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs);
98*700637cbSDimitry Andric return true;
99*700637cbSDimitry Andric }
100*700637cbSDimitry Andric
isMultiDimensionalArray(Type * T)101*700637cbSDimitry Andric bool DXILFlattenArraysVisitor::isMultiDimensionalArray(Type *T) {
102*700637cbSDimitry Andric if (ArrayType *ArrType = dyn_cast<ArrayType>(T))
103*700637cbSDimitry Andric return isa<ArrayType>(ArrType->getElementType());
104*700637cbSDimitry Andric return false;
105*700637cbSDimitry Andric }
106*700637cbSDimitry Andric
107*700637cbSDimitry Andric std::pair<unsigned, Type *>
getElementCountAndType(Type * ArrayTy)108*700637cbSDimitry Andric DXILFlattenArraysVisitor::getElementCountAndType(Type *ArrayTy) {
109*700637cbSDimitry Andric unsigned TotalElements = 1;
110*700637cbSDimitry Andric Type *CurrArrayTy = ArrayTy;
111*700637cbSDimitry Andric while (auto *InnerArrayTy = dyn_cast<ArrayType>(CurrArrayTy)) {
112*700637cbSDimitry Andric TotalElements *= InnerArrayTy->getNumElements();
113*700637cbSDimitry Andric CurrArrayTy = InnerArrayTy->getElementType();
114*700637cbSDimitry Andric }
115*700637cbSDimitry Andric return std::make_pair(TotalElements, CurrArrayTy);
116*700637cbSDimitry Andric }
117*700637cbSDimitry Andric
genConstFlattenIndices(ArrayRef<Value * > Indices,ArrayRef<uint64_t> Dims,IRBuilder<> & Builder)118*700637cbSDimitry Andric ConstantInt *DXILFlattenArraysVisitor::genConstFlattenIndices(
119*700637cbSDimitry Andric ArrayRef<Value *> Indices, ArrayRef<uint64_t> Dims, IRBuilder<> &Builder) {
120*700637cbSDimitry Andric assert(Indices.size() == Dims.size() &&
121*700637cbSDimitry Andric "Indicies and dimmensions should be the same");
122*700637cbSDimitry Andric unsigned FlatIndex = 0;
123*700637cbSDimitry Andric unsigned Multiplier = 1;
124*700637cbSDimitry Andric
125*700637cbSDimitry Andric for (int I = Indices.size() - 1; I >= 0; --I) {
126*700637cbSDimitry Andric unsigned DimSize = Dims[I];
127*700637cbSDimitry Andric ConstantInt *CIndex = dyn_cast<ConstantInt>(Indices[I]);
128*700637cbSDimitry Andric assert(CIndex && "This function expects all indicies to be ConstantInt");
129*700637cbSDimitry Andric FlatIndex += CIndex->getZExtValue() * Multiplier;
130*700637cbSDimitry Andric Multiplier *= DimSize;
131*700637cbSDimitry Andric }
132*700637cbSDimitry Andric return Builder.getInt32(FlatIndex);
133*700637cbSDimitry Andric }
134*700637cbSDimitry Andric
genInstructionFlattenIndices(ArrayRef<Value * > Indices,ArrayRef<uint64_t> Dims,IRBuilder<> & Builder)135*700637cbSDimitry Andric Value *DXILFlattenArraysVisitor::genInstructionFlattenIndices(
136*700637cbSDimitry Andric ArrayRef<Value *> Indices, ArrayRef<uint64_t> Dims, IRBuilder<> &Builder) {
137*700637cbSDimitry Andric if (Indices.size() == 1)
138*700637cbSDimitry Andric return Indices[0];
139*700637cbSDimitry Andric
140*700637cbSDimitry Andric Value *FlatIndex = Builder.getInt32(0);
141*700637cbSDimitry Andric unsigned Multiplier = 1;
142*700637cbSDimitry Andric
143*700637cbSDimitry Andric for (int I = Indices.size() - 1; I >= 0; --I) {
144*700637cbSDimitry Andric unsigned DimSize = Dims[I];
145*700637cbSDimitry Andric Value *VMultiplier = Builder.getInt32(Multiplier);
146*700637cbSDimitry Andric Value *ScaledIndex = Builder.CreateMul(Indices[I], VMultiplier);
147*700637cbSDimitry Andric FlatIndex = Builder.CreateAdd(FlatIndex, ScaledIndex);
148*700637cbSDimitry Andric Multiplier *= DimSize;
149*700637cbSDimitry Andric }
150*700637cbSDimitry Andric return FlatIndex;
151*700637cbSDimitry Andric }
152*700637cbSDimitry Andric
visitLoadInst(LoadInst & LI)153*700637cbSDimitry Andric bool DXILFlattenArraysVisitor::visitLoadInst(LoadInst &LI) {
154*700637cbSDimitry Andric unsigned NumOperands = LI.getNumOperands();
155*700637cbSDimitry Andric for (unsigned I = 0; I < NumOperands; ++I) {
156*700637cbSDimitry Andric Value *CurrOpperand = LI.getOperand(I);
157*700637cbSDimitry Andric ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
158*700637cbSDimitry Andric if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
159*700637cbSDimitry Andric GetElementPtrInst *OldGEP =
160*700637cbSDimitry Andric cast<GetElementPtrInst>(CE->getAsInstruction());
161*700637cbSDimitry Andric OldGEP->insertBefore(LI.getIterator());
162*700637cbSDimitry Andric
163*700637cbSDimitry Andric IRBuilder<> Builder(&LI);
164*700637cbSDimitry Andric LoadInst *NewLoad =
165*700637cbSDimitry Andric Builder.CreateLoad(LI.getType(), OldGEP, LI.getName());
166*700637cbSDimitry Andric NewLoad->setAlignment(LI.getAlign());
167*700637cbSDimitry Andric LI.replaceAllUsesWith(NewLoad);
168*700637cbSDimitry Andric LI.eraseFromParent();
169*700637cbSDimitry Andric visitGetElementPtrInst(*OldGEP);
170*700637cbSDimitry Andric return true;
171*700637cbSDimitry Andric }
172*700637cbSDimitry Andric }
173*700637cbSDimitry Andric return false;
174*700637cbSDimitry Andric }
175*700637cbSDimitry Andric
visitStoreInst(StoreInst & SI)176*700637cbSDimitry Andric bool DXILFlattenArraysVisitor::visitStoreInst(StoreInst &SI) {
177*700637cbSDimitry Andric unsigned NumOperands = SI.getNumOperands();
178*700637cbSDimitry Andric for (unsigned I = 0; I < NumOperands; ++I) {
179*700637cbSDimitry Andric Value *CurrOpperand = SI.getOperand(I);
180*700637cbSDimitry Andric ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
181*700637cbSDimitry Andric if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
182*700637cbSDimitry Andric GetElementPtrInst *OldGEP =
183*700637cbSDimitry Andric cast<GetElementPtrInst>(CE->getAsInstruction());
184*700637cbSDimitry Andric OldGEP->insertBefore(SI.getIterator());
185*700637cbSDimitry Andric
186*700637cbSDimitry Andric IRBuilder<> Builder(&SI);
187*700637cbSDimitry Andric StoreInst *NewStore = Builder.CreateStore(SI.getValueOperand(), OldGEP);
188*700637cbSDimitry Andric NewStore->setAlignment(SI.getAlign());
189*700637cbSDimitry Andric SI.replaceAllUsesWith(NewStore);
190*700637cbSDimitry Andric SI.eraseFromParent();
191*700637cbSDimitry Andric visitGetElementPtrInst(*OldGEP);
192*700637cbSDimitry Andric return true;
193*700637cbSDimitry Andric }
194*700637cbSDimitry Andric }
195*700637cbSDimitry Andric return false;
196*700637cbSDimitry Andric }
197*700637cbSDimitry Andric
visitAllocaInst(AllocaInst & AI)198*700637cbSDimitry Andric bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) {
199*700637cbSDimitry Andric if (!isMultiDimensionalArray(AI.getAllocatedType()))
200*700637cbSDimitry Andric return false;
201*700637cbSDimitry Andric
202*700637cbSDimitry Andric ArrayType *ArrType = cast<ArrayType>(AI.getAllocatedType());
203*700637cbSDimitry Andric IRBuilder<> Builder(&AI);
204*700637cbSDimitry Andric auto [TotalElements, BaseType] = getElementCountAndType(ArrType);
205*700637cbSDimitry Andric
206*700637cbSDimitry Andric ArrayType *FattenedArrayType = ArrayType::get(BaseType, TotalElements);
207*700637cbSDimitry Andric AllocaInst *FlatAlloca =
208*700637cbSDimitry Andric Builder.CreateAlloca(FattenedArrayType, nullptr, AI.getName() + ".1dim");
209*700637cbSDimitry Andric FlatAlloca->setAlignment(AI.getAlign());
210*700637cbSDimitry Andric AI.replaceAllUsesWith(FlatAlloca);
211*700637cbSDimitry Andric AI.eraseFromParent();
212*700637cbSDimitry Andric return true;
213*700637cbSDimitry Andric }
214*700637cbSDimitry Andric
visitGetElementPtrInst(GetElementPtrInst & GEP)215*700637cbSDimitry Andric bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
216*700637cbSDimitry Andric // Do not visit GEPs more than once
217*700637cbSDimitry Andric if (GEPChainInfoMap.contains(cast<GEPOperator>(&GEP)))
218*700637cbSDimitry Andric return false;
219*700637cbSDimitry Andric
220*700637cbSDimitry Andric Value *PtrOperand = GEP.getPointerOperand();
221*700637cbSDimitry Andric // It shouldn't(?) be possible for the pointer operand of a GEP to be a PHI
222*700637cbSDimitry Andric // node unless HLSL has pointers. If this assumption is incorrect or HLSL gets
223*700637cbSDimitry Andric // pointer types, then the handling of this case can be implemented later.
224*700637cbSDimitry Andric assert(!isa<PHINode>(PtrOperand) &&
225*700637cbSDimitry Andric "Pointer operand of GEP should not be a PHI Node");
226*700637cbSDimitry Andric
227*700637cbSDimitry Andric // Replace a GEP ConstantExpr pointer operand with a GEP instruction so that
228*700637cbSDimitry Andric // it can be visited
229*700637cbSDimitry Andric if (auto *PtrOpGEPCE = dyn_cast<ConstantExpr>(PtrOperand);
230*700637cbSDimitry Andric PtrOpGEPCE && PtrOpGEPCE->getOpcode() == Instruction::GetElementPtr) {
231*700637cbSDimitry Andric GetElementPtrInst *OldGEPI =
232*700637cbSDimitry Andric cast<GetElementPtrInst>(PtrOpGEPCE->getAsInstruction());
233*700637cbSDimitry Andric OldGEPI->insertBefore(GEP.getIterator());
234*700637cbSDimitry Andric
235*700637cbSDimitry Andric IRBuilder<> Builder(&GEP);
236*700637cbSDimitry Andric SmallVector<Value *> Indices(GEP.indices());
237*700637cbSDimitry Andric Value *NewGEP =
238*700637cbSDimitry Andric Builder.CreateGEP(GEP.getSourceElementType(), OldGEPI, Indices,
239*700637cbSDimitry Andric GEP.getName(), GEP.getNoWrapFlags());
240*700637cbSDimitry Andric assert(isa<GetElementPtrInst>(NewGEP) &&
241*700637cbSDimitry Andric "Expected newly-created GEP to be an instruction");
242*700637cbSDimitry Andric GetElementPtrInst *NewGEPI = cast<GetElementPtrInst>(NewGEP);
243*700637cbSDimitry Andric
244*700637cbSDimitry Andric GEP.replaceAllUsesWith(NewGEPI);
245*700637cbSDimitry Andric GEP.eraseFromParent();
246*700637cbSDimitry Andric visitGetElementPtrInst(*OldGEPI);
247*700637cbSDimitry Andric visitGetElementPtrInst(*NewGEPI);
248*700637cbSDimitry Andric return true;
249*700637cbSDimitry Andric }
250*700637cbSDimitry Andric
251*700637cbSDimitry Andric // Construct GEPInfo for this GEP
252*700637cbSDimitry Andric GEPInfo Info;
253*700637cbSDimitry Andric
254*700637cbSDimitry Andric // Obtain the variable and constant byte offsets computed by this GEP
255*700637cbSDimitry Andric const DataLayout &DL = GEP.getDataLayout();
256*700637cbSDimitry Andric unsigned BitWidth = DL.getIndexTypeSizeInBits(GEP.getType());
257*700637cbSDimitry Andric Info.ConstantOffset = {BitWidth, 0};
258*700637cbSDimitry Andric [[maybe_unused]] bool Success = GEP.collectOffset(
259*700637cbSDimitry Andric DL, BitWidth, Info.VariableOffsets, Info.ConstantOffset);
260*700637cbSDimitry Andric assert(Success && "Failed to collect offsets for GEP");
261*700637cbSDimitry Andric
262*700637cbSDimitry Andric // If there is a parent GEP, inherit the root array type and pointer, and
263*700637cbSDimitry Andric // merge the byte offsets. Otherwise, this GEP is itself the root of a GEP
264*700637cbSDimitry Andric // chain and we need to deterine the root array type
265*700637cbSDimitry Andric if (auto *PtrOpGEP = dyn_cast<GEPOperator>(PtrOperand)) {
266*700637cbSDimitry Andric assert(GEPChainInfoMap.contains(PtrOpGEP) &&
267*700637cbSDimitry Andric "Expected parent GEP to be visited before this GEP");
268*700637cbSDimitry Andric GEPInfo &PGEPInfo = GEPChainInfoMap[PtrOpGEP];
269*700637cbSDimitry Andric Info.RootFlattenedArrayType = PGEPInfo.RootFlattenedArrayType;
270*700637cbSDimitry Andric Info.RootPointerOperand = PGEPInfo.RootPointerOperand;
271*700637cbSDimitry Andric for (auto &VariableOffset : PGEPInfo.VariableOffsets)
272*700637cbSDimitry Andric Info.VariableOffsets.insert(VariableOffset);
273*700637cbSDimitry Andric Info.ConstantOffset += PGEPInfo.ConstantOffset;
274*700637cbSDimitry Andric } else {
275*700637cbSDimitry Andric Info.RootPointerOperand = PtrOperand;
276*700637cbSDimitry Andric
277*700637cbSDimitry Andric // We should try to determine the type of the root from the pointer rather
278*700637cbSDimitry Andric // than the GEP's source element type because this could be a scalar GEP
279*700637cbSDimitry Andric // into an array-typed pointer from an Alloca or Global Variable.
280*700637cbSDimitry Andric Type *RootTy = GEP.getSourceElementType();
281*700637cbSDimitry Andric if (auto *GlobalVar = dyn_cast<GlobalVariable>(PtrOperand)) {
282*700637cbSDimitry Andric if (GlobalMap.contains(GlobalVar))
283*700637cbSDimitry Andric GlobalVar = GlobalMap[GlobalVar];
284*700637cbSDimitry Andric Info.RootPointerOperand = GlobalVar;
285*700637cbSDimitry Andric RootTy = GlobalVar->getValueType();
286*700637cbSDimitry Andric } else if (auto *Alloca = dyn_cast<AllocaInst>(PtrOperand))
287*700637cbSDimitry Andric RootTy = Alloca->getAllocatedType();
288*700637cbSDimitry Andric assert(!isMultiDimensionalArray(RootTy) &&
289*700637cbSDimitry Andric "Expected root array type to be flattened");
290*700637cbSDimitry Andric
291*700637cbSDimitry Andric // If the root type is not an array, we don't need to do any flattening
292*700637cbSDimitry Andric if (!isa<ArrayType>(RootTy))
293*700637cbSDimitry Andric return false;
294*700637cbSDimitry Andric
295*700637cbSDimitry Andric Info.RootFlattenedArrayType = cast<ArrayType>(RootTy);
296*700637cbSDimitry Andric }
297*700637cbSDimitry Andric
298*700637cbSDimitry Andric // GEPs without users or GEPs with non-GEP users should be replaced such that
299*700637cbSDimitry Andric // the chain of GEPs they are a part of are collapsed to a single GEP into a
300*700637cbSDimitry Andric // flattened array.
301*700637cbSDimitry Andric bool ReplaceThisGEP = GEP.users().empty();
302*700637cbSDimitry Andric for (Value *User : GEP.users())
303*700637cbSDimitry Andric if (!isa<GetElementPtrInst>(User))
304*700637cbSDimitry Andric ReplaceThisGEP = true;
305*700637cbSDimitry Andric
306*700637cbSDimitry Andric if (ReplaceThisGEP) {
307*700637cbSDimitry Andric unsigned BytesPerElem =
308*700637cbSDimitry Andric DL.getTypeAllocSize(Info.RootFlattenedArrayType->getArrayElementType());
309*700637cbSDimitry Andric assert(isPowerOf2_32(BytesPerElem) &&
310*700637cbSDimitry Andric "Bytes per element should be a power of 2");
311*700637cbSDimitry Andric
312*700637cbSDimitry Andric // Compute the 32-bit index for this flattened GEP from the constant and
313*700637cbSDimitry Andric // variable byte offsets in the GEPInfo
314*700637cbSDimitry Andric IRBuilder<> Builder(&GEP);
315*700637cbSDimitry Andric Value *ZeroIndex = Builder.getInt32(0);
316*700637cbSDimitry Andric uint64_t ConstantOffset =
317*700637cbSDimitry Andric Info.ConstantOffset.udiv(BytesPerElem).getZExtValue();
318*700637cbSDimitry Andric assert(ConstantOffset < UINT32_MAX &&
319*700637cbSDimitry Andric "Constant byte offset for flat GEP index must fit within 32 bits");
320*700637cbSDimitry Andric Value *FlattenedIndex = Builder.getInt32(ConstantOffset);
321*700637cbSDimitry Andric for (auto [VarIndex, Multiplier] : Info.VariableOffsets) {
322*700637cbSDimitry Andric assert(Multiplier.getActiveBits() <= 32 &&
323*700637cbSDimitry Andric "The multiplier for a flat GEP index must fit within 32 bits");
324*700637cbSDimitry Andric assert(VarIndex->getType()->isIntegerTy(32) &&
325*700637cbSDimitry Andric "Expected i32-typed GEP indices");
326*700637cbSDimitry Andric Value *VI;
327*700637cbSDimitry Andric if (Multiplier.getZExtValue() % BytesPerElem != 0) {
328*700637cbSDimitry Andric // This can happen, e.g., with i8 GEPs. To handle this we just divide
329*700637cbSDimitry Andric // by BytesPerElem using an instruction after multiplying VarIndex by
330*700637cbSDimitry Andric // Multiplier.
331*700637cbSDimitry Andric VI = Builder.CreateMul(VarIndex,
332*700637cbSDimitry Andric Builder.getInt32(Multiplier.getZExtValue()));
333*700637cbSDimitry Andric VI = Builder.CreateLShr(VI, Builder.getInt32(Log2_32(BytesPerElem)));
334*700637cbSDimitry Andric } else
335*700637cbSDimitry Andric VI = Builder.CreateMul(
336*700637cbSDimitry Andric VarIndex,
337*700637cbSDimitry Andric Builder.getInt32(Multiplier.getZExtValue() / BytesPerElem));
338*700637cbSDimitry Andric FlattenedIndex = Builder.CreateAdd(FlattenedIndex, VI);
339*700637cbSDimitry Andric }
340*700637cbSDimitry Andric
341*700637cbSDimitry Andric // Construct a new GEP for the flattened array to replace the current GEP
342*700637cbSDimitry Andric Value *NewGEP = Builder.CreateGEP(
343*700637cbSDimitry Andric Info.RootFlattenedArrayType, Info.RootPointerOperand,
344*700637cbSDimitry Andric {ZeroIndex, FlattenedIndex}, GEP.getName(), GEP.getNoWrapFlags());
345*700637cbSDimitry Andric
346*700637cbSDimitry Andric // Replace the current GEP with the new GEP. Store GEPInfo into the map
347*700637cbSDimitry Andric // for later use in case this GEP was not the end of the chain
348*700637cbSDimitry Andric GEPChainInfoMap.insert({cast<GEPOperator>(NewGEP), std::move(Info)});
349*700637cbSDimitry Andric GEP.replaceAllUsesWith(NewGEP);
350*700637cbSDimitry Andric GEP.eraseFromParent();
351*700637cbSDimitry Andric return true;
352*700637cbSDimitry Andric }
353*700637cbSDimitry Andric
354*700637cbSDimitry Andric // This GEP is potentially dead at the end of the pass since it may not have
355*700637cbSDimitry Andric // any users anymore after GEP chains have been collapsed. We retain store
356*700637cbSDimitry Andric // GEPInfo for GEPs down the chain to use to compute their indices.
357*700637cbSDimitry Andric GEPChainInfoMap.insert({cast<GEPOperator>(&GEP), std::move(Info)});
358*700637cbSDimitry Andric PotentiallyDeadInstrs.emplace_back(&GEP);
359*700637cbSDimitry Andric return false;
360*700637cbSDimitry Andric }
361*700637cbSDimitry Andric
visit(Function & F)362*700637cbSDimitry Andric bool DXILFlattenArraysVisitor::visit(Function &F) {
363*700637cbSDimitry Andric bool MadeChange = false;
364*700637cbSDimitry Andric ReversePostOrderTraversal<Function *> RPOT(&F);
365*700637cbSDimitry Andric for (BasicBlock *BB : make_early_inc_range(RPOT)) {
366*700637cbSDimitry Andric for (Instruction &I : make_early_inc_range(*BB))
367*700637cbSDimitry Andric MadeChange |= InstVisitor::visit(I);
368*700637cbSDimitry Andric }
369*700637cbSDimitry Andric finish();
370*700637cbSDimitry Andric return MadeChange;
371*700637cbSDimitry Andric }
372*700637cbSDimitry Andric
collectElements(Constant * Init,SmallVectorImpl<Constant * > & Elements)373*700637cbSDimitry Andric static void collectElements(Constant *Init,
374*700637cbSDimitry Andric SmallVectorImpl<Constant *> &Elements) {
375*700637cbSDimitry Andric // Base case: If Init is not an array, add it directly to the vector.
376*700637cbSDimitry Andric auto *ArrayTy = dyn_cast<ArrayType>(Init->getType());
377*700637cbSDimitry Andric if (!ArrayTy) {
378*700637cbSDimitry Andric Elements.push_back(Init);
379*700637cbSDimitry Andric return;
380*700637cbSDimitry Andric }
381*700637cbSDimitry Andric unsigned ArrSize = ArrayTy->getNumElements();
382*700637cbSDimitry Andric if (isa<ConstantAggregateZero>(Init)) {
383*700637cbSDimitry Andric for (unsigned I = 0; I < ArrSize; ++I)
384*700637cbSDimitry Andric Elements.push_back(Constant::getNullValue(ArrayTy->getElementType()));
385*700637cbSDimitry Andric return;
386*700637cbSDimitry Andric }
387*700637cbSDimitry Andric
388*700637cbSDimitry Andric // Recursive case: Process each element in the array.
389*700637cbSDimitry Andric if (auto *ArrayConstant = dyn_cast<ConstantArray>(Init)) {
390*700637cbSDimitry Andric for (unsigned I = 0; I < ArrayConstant->getNumOperands(); ++I) {
391*700637cbSDimitry Andric collectElements(ArrayConstant->getOperand(I), Elements);
392*700637cbSDimitry Andric }
393*700637cbSDimitry Andric } else if (auto *DataArrayConstant = dyn_cast<ConstantDataArray>(Init)) {
394*700637cbSDimitry Andric for (unsigned I = 0; I < DataArrayConstant->getNumElements(); ++I) {
395*700637cbSDimitry Andric collectElements(DataArrayConstant->getElementAsConstant(I), Elements);
396*700637cbSDimitry Andric }
397*700637cbSDimitry Andric } else {
398*700637cbSDimitry Andric llvm_unreachable(
399*700637cbSDimitry Andric "Expected a ConstantArray or ConstantDataArray for array initializer!");
400*700637cbSDimitry Andric }
401*700637cbSDimitry Andric }
402*700637cbSDimitry Andric
transformInitializer(Constant * Init,Type * OrigType,ArrayType * FlattenedType,LLVMContext & Ctx)403*700637cbSDimitry Andric static Constant *transformInitializer(Constant *Init, Type *OrigType,
404*700637cbSDimitry Andric ArrayType *FlattenedType,
405*700637cbSDimitry Andric LLVMContext &Ctx) {
406*700637cbSDimitry Andric // Handle ConstantAggregateZero (zero-initialized constants)
407*700637cbSDimitry Andric if (isa<ConstantAggregateZero>(Init))
408*700637cbSDimitry Andric return ConstantAggregateZero::get(FlattenedType);
409*700637cbSDimitry Andric
410*700637cbSDimitry Andric // Handle UndefValue (undefined constants)
411*700637cbSDimitry Andric if (isa<UndefValue>(Init))
412*700637cbSDimitry Andric return UndefValue::get(FlattenedType);
413*700637cbSDimitry Andric
414*700637cbSDimitry Andric if (!isa<ArrayType>(OrigType))
415*700637cbSDimitry Andric return Init;
416*700637cbSDimitry Andric
417*700637cbSDimitry Andric SmallVector<Constant *> FlattenedElements;
418*700637cbSDimitry Andric collectElements(Init, FlattenedElements);
419*700637cbSDimitry Andric assert(FlattenedType->getNumElements() == FlattenedElements.size() &&
420*700637cbSDimitry Andric "The number of collected elements should match the FlattenedType");
421*700637cbSDimitry Andric return ConstantArray::get(FlattenedType, FlattenedElements);
422*700637cbSDimitry Andric }
423*700637cbSDimitry Andric
flattenGlobalArrays(Module & M,SmallDenseMap<GlobalVariable *,GlobalVariable * > & GlobalMap)424*700637cbSDimitry Andric static void flattenGlobalArrays(
425*700637cbSDimitry Andric Module &M, SmallDenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap) {
426*700637cbSDimitry Andric LLVMContext &Ctx = M.getContext();
427*700637cbSDimitry Andric for (GlobalVariable &G : M.globals()) {
428*700637cbSDimitry Andric Type *OrigType = G.getValueType();
429*700637cbSDimitry Andric if (!DXILFlattenArraysVisitor::isMultiDimensionalArray(OrigType))
430*700637cbSDimitry Andric continue;
431*700637cbSDimitry Andric
432*700637cbSDimitry Andric ArrayType *ArrType = cast<ArrayType>(OrigType);
433*700637cbSDimitry Andric auto [TotalElements, BaseType] =
434*700637cbSDimitry Andric DXILFlattenArraysVisitor::getElementCountAndType(ArrType);
435*700637cbSDimitry Andric ArrayType *FattenedArrayType = ArrayType::get(BaseType, TotalElements);
436*700637cbSDimitry Andric
437*700637cbSDimitry Andric // Create a new global variable with the updated type
438*700637cbSDimitry Andric // Note: Initializer is set via transformInitializer
439*700637cbSDimitry Andric GlobalVariable *NewGlobal =
440*700637cbSDimitry Andric new GlobalVariable(M, FattenedArrayType, G.isConstant(), G.getLinkage(),
441*700637cbSDimitry Andric /*Initializer=*/nullptr, G.getName() + ".1dim", &G,
442*700637cbSDimitry Andric G.getThreadLocalMode(), G.getAddressSpace(),
443*700637cbSDimitry Andric G.isExternallyInitialized());
444*700637cbSDimitry Andric
445*700637cbSDimitry Andric // Copy relevant attributes
446*700637cbSDimitry Andric NewGlobal->setUnnamedAddr(G.getUnnamedAddr());
447*700637cbSDimitry Andric if (G.getAlignment() > 0) {
448*700637cbSDimitry Andric NewGlobal->setAlignment(G.getAlign());
449*700637cbSDimitry Andric }
450*700637cbSDimitry Andric
451*700637cbSDimitry Andric if (G.hasInitializer()) {
452*700637cbSDimitry Andric Constant *Init = G.getInitializer();
453*700637cbSDimitry Andric Constant *NewInit =
454*700637cbSDimitry Andric transformInitializer(Init, OrigType, FattenedArrayType, Ctx);
455*700637cbSDimitry Andric NewGlobal->setInitializer(NewInit);
456*700637cbSDimitry Andric }
457*700637cbSDimitry Andric GlobalMap[&G] = NewGlobal;
458*700637cbSDimitry Andric }
459*700637cbSDimitry Andric }
460*700637cbSDimitry Andric
flattenArrays(Module & M)461*700637cbSDimitry Andric static bool flattenArrays(Module &M) {
462*700637cbSDimitry Andric bool MadeChange = false;
463*700637cbSDimitry Andric SmallDenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
464*700637cbSDimitry Andric flattenGlobalArrays(M, GlobalMap);
465*700637cbSDimitry Andric DXILFlattenArraysVisitor Impl(GlobalMap);
466*700637cbSDimitry Andric for (auto &F : make_early_inc_range(M.functions())) {
467*700637cbSDimitry Andric if (F.isDeclaration())
468*700637cbSDimitry Andric continue;
469*700637cbSDimitry Andric MadeChange |= Impl.visit(F);
470*700637cbSDimitry Andric }
471*700637cbSDimitry Andric for (auto &[Old, New] : GlobalMap) {
472*700637cbSDimitry Andric Old->replaceAllUsesWith(New);
473*700637cbSDimitry Andric Old->eraseFromParent();
474*700637cbSDimitry Andric MadeChange = true;
475*700637cbSDimitry Andric }
476*700637cbSDimitry Andric return MadeChange;
477*700637cbSDimitry Andric }
478*700637cbSDimitry Andric
run(Module & M,ModuleAnalysisManager &)479*700637cbSDimitry Andric PreservedAnalyses DXILFlattenArrays::run(Module &M, ModuleAnalysisManager &) {
480*700637cbSDimitry Andric bool MadeChanges = flattenArrays(M);
481*700637cbSDimitry Andric if (!MadeChanges)
482*700637cbSDimitry Andric return PreservedAnalyses::all();
483*700637cbSDimitry Andric PreservedAnalyses PA;
484*700637cbSDimitry Andric return PA;
485*700637cbSDimitry Andric }
486*700637cbSDimitry Andric
runOnModule(Module & M)487*700637cbSDimitry Andric bool DXILFlattenArraysLegacy::runOnModule(Module &M) {
488*700637cbSDimitry Andric return flattenArrays(M);
489*700637cbSDimitry Andric }
490*700637cbSDimitry Andric
491*700637cbSDimitry Andric char DXILFlattenArraysLegacy::ID = 0;
492*700637cbSDimitry Andric
493*700637cbSDimitry Andric INITIALIZE_PASS_BEGIN(DXILFlattenArraysLegacy, DEBUG_TYPE,
494*700637cbSDimitry Andric "DXIL Array Flattener", false, false)
495*700637cbSDimitry Andric INITIALIZE_PASS_END(DXILFlattenArraysLegacy, DEBUG_TYPE, "DXIL Array Flattener",
496*700637cbSDimitry Andric false, false)
497*700637cbSDimitry Andric
createDXILFlattenArraysLegacyPass()498*700637cbSDimitry Andric ModulePass *llvm::createDXILFlattenArraysLegacyPass() {
499*700637cbSDimitry Andric return new DXILFlattenArraysLegacy();
500*700637cbSDimitry Andric }
501