xref: /freebsd/contrib/llvm-project/llvm/lib/Target/DirectX/DXILDataScalarization.cpp (revision 770cf0a5f02dc8983a89c6568d741fbc25baa999)
1 //===- DXILDataScalarization.cpp - Perform DXIL Data Legalization ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===---------------------------------------------------------------------===//
8 
9 #include "DXILDataScalarization.h"
10 #include "DirectX.h"
11 #include "llvm/ADT/PostOrderIterator.h"
12 #include "llvm/ADT/STLExtras.h"
13 #include "llvm/IR/DerivedTypes.h"
14 #include "llvm/IR/GlobalVariable.h"
15 #include "llvm/IR/IRBuilder.h"
16 #include "llvm/IR/InstVisitor.h"
17 #include "llvm/IR/Instructions.h"
18 #include "llvm/IR/Module.h"
19 #include "llvm/IR/Operator.h"
20 #include "llvm/IR/PassManager.h"
21 #include "llvm/IR/ReplaceConstant.h"
22 #include "llvm/IR/Type.h"
23 #include "llvm/Support/Casting.h"
24 #include "llvm/Transforms/Utils/Cloning.h"
25 #include "llvm/Transforms/Utils/Local.h"
26 
27 #define DEBUG_TYPE "dxil-data-scalarization"
// Inline capacity used for the SmallVectors in this file that hold GEP
// indices and per-lane initializer constants.
static constexpr int MaxVecSize = 4;
29 
30 using namespace llvm;
31 
32 // Recursively creates an array-like version of a given vector type.
33 static Type *equivalentArrayTypeFromVector(Type *T) {
34   if (auto *VecTy = dyn_cast<VectorType>(T))
35     return ArrayType::get(VecTy->getElementType(),
36                           dyn_cast<FixedVectorType>(VecTy)->getNumElements());
37   if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
38     Type *NewElementType =
39         equivalentArrayTypeFromVector(ArrayTy->getElementType());
40     return ArrayType::get(NewElementType, ArrayTy->getNumElements());
41   }
42   // If it's not a vector or array, return the original type.
43   return T;
44 }
45 
46 class DXILDataScalarizationLegacy : public ModulePass {
47 
48 public:
49   bool runOnModule(Module &M) override;
50   DXILDataScalarizationLegacy() : ModulePass(ID) {}
51 
52   static char ID; // Pass identification.
53 };
54 
55 static bool findAndReplaceVectors(Module &M);
56 
57 class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
58 public:
59   DataScalarizerVisitor() : GlobalMap() {}
60   bool visit(Function &F);
61   // InstVisitor methods.  They return true if the instruction was scalarized,
62   // false if nothing changed.
63   bool visitAllocaInst(AllocaInst &AI);
64   bool visitInstruction(Instruction &I) { return false; }
65   bool visitSelectInst(SelectInst &SI) { return false; }
66   bool visitICmpInst(ICmpInst &ICI) { return false; }
67   bool visitFCmpInst(FCmpInst &FCI) { return false; }
68   bool visitUnaryOperator(UnaryOperator &UO) { return false; }
69   bool visitBinaryOperator(BinaryOperator &BO) { return false; }
70   bool visitGetElementPtrInst(GetElementPtrInst &GEPI);
71   bool visitCastInst(CastInst &CI) { return false; }
72   bool visitBitCastInst(BitCastInst &BCI) { return false; }
73   bool visitInsertElementInst(InsertElementInst &IEI);
74   bool visitExtractElementInst(ExtractElementInst &EEI);
75   bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; }
76   bool visitPHINode(PHINode &PHI) { return false; }
77   bool visitLoadInst(LoadInst &LI);
78   bool visitStoreInst(StoreInst &SI);
79   bool visitCallInst(CallInst &ICI) { return false; }
80   bool visitFreezeInst(FreezeInst &FI) { return false; }
81   friend bool findAndReplaceVectors(llvm::Module &M);
82 
83 private:
84   typedef std::pair<AllocaInst *, SmallVector<Value *, 4>> AllocaAndGEPs;
85   typedef SmallDenseMap<Value *, AllocaAndGEPs>
86       VectorToArrayMap; // A map from a vector-typed Value to its corresponding
87                         // AllocaInst and GEPs to each element of an array
88   VectorToArrayMap VectorAllocaMap;
89   AllocaAndGEPs createArrayFromVector(IRBuilder<> &Builder, Value *Vec,
90                                       const Twine &Name);
91   bool replaceDynamicInsertElementInst(InsertElementInst &IEI);
92   bool replaceDynamicExtractElementInst(ExtractElementInst &EEI);
93 
94   GlobalVariable *lookupReplacementGlobal(Value *CurrOperand);
95   DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
96 };
97 
98 bool DataScalarizerVisitor::visit(Function &F) {
99   bool MadeChange = false;
100   ReversePostOrderTraversal<Function *> RPOT(&F);
101   for (BasicBlock *BB : make_early_inc_range(RPOT)) {
102     for (Instruction &I : make_early_inc_range(*BB))
103       MadeChange |= InstVisitor::visit(I);
104   }
105   VectorAllocaMap.clear();
106   return MadeChange;
107 }
108 
109 GlobalVariable *
110 DataScalarizerVisitor::lookupReplacementGlobal(Value *CurrOperand) {
111   if (GlobalVariable *OldGlobal = dyn_cast<GlobalVariable>(CurrOperand)) {
112     auto It = GlobalMap.find(OldGlobal);
113     if (It != GlobalMap.end()) {
114       return It->second; // Found, return the new global
115     }
116   }
117   return nullptr; // Not found
118 }
119 
120 // Helper function to check if a type is a vector or an array of vectors
121 static bool isVectorOrArrayOfVectors(Type *T) {
122   if (isa<VectorType>(T))
123     return true;
124   if (ArrayType *ArrType = dyn_cast<ArrayType>(T))
125     return isa<VectorType>(ArrType->getElementType()) ||
126            isVectorOrArrayOfVectors(ArrType->getElementType());
127   return false;
128 }
129 
130 bool DataScalarizerVisitor::visitAllocaInst(AllocaInst &AI) {
131   Type *AllocatedType = AI.getAllocatedType();
132   if (!isVectorOrArrayOfVectors(AllocatedType))
133     return false;
134 
135   IRBuilder<> Builder(&AI);
136   Type *NewType = equivalentArrayTypeFromVector(AllocatedType);
137   AllocaInst *ArrAlloca =
138       Builder.CreateAlloca(NewType, nullptr, AI.getName() + ".scalarize");
139   ArrAlloca->setAlignment(AI.getAlign());
140   AI.replaceAllUsesWith(ArrAlloca);
141   AI.eraseFromParent();
142   return true;
143 }
144 
145 bool DataScalarizerVisitor::visitLoadInst(LoadInst &LI) {
146   Value *PtrOperand = LI.getPointerOperand();
147   ConstantExpr *CE = dyn_cast<ConstantExpr>(PtrOperand);
148   if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
149     GetElementPtrInst *OldGEP = cast<GetElementPtrInst>(CE->getAsInstruction());
150     OldGEP->insertBefore(LI.getIterator());
151     IRBuilder<> Builder(&LI);
152     LoadInst *NewLoad = Builder.CreateLoad(LI.getType(), OldGEP, LI.getName());
153     NewLoad->setAlignment(LI.getAlign());
154     LI.replaceAllUsesWith(NewLoad);
155     LI.eraseFromParent();
156     visitGetElementPtrInst(*OldGEP);
157     return true;
158   }
159   if (GlobalVariable *NewGlobal = lookupReplacementGlobal(PtrOperand))
160     LI.setOperand(LI.getPointerOperandIndex(), NewGlobal);
161   return false;
162 }
163 
164 bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) {
165 
166   Value *PtrOperand = SI.getPointerOperand();
167   ConstantExpr *CE = dyn_cast<ConstantExpr>(PtrOperand);
168   if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
169     GetElementPtrInst *OldGEP = cast<GetElementPtrInst>(CE->getAsInstruction());
170     OldGEP->insertBefore(SI.getIterator());
171     IRBuilder<> Builder(&SI);
172     StoreInst *NewStore = Builder.CreateStore(SI.getValueOperand(), OldGEP);
173     NewStore->setAlignment(SI.getAlign());
174     SI.replaceAllUsesWith(NewStore);
175     SI.eraseFromParent();
176     visitGetElementPtrInst(*OldGEP);
177     return true;
178   }
179   if (GlobalVariable *NewGlobal = lookupReplacementGlobal(PtrOperand))
180     SI.setOperand(SI.getPointerOperandIndex(), NewGlobal);
181 
182   return false;
183 }
184 
185 DataScalarizerVisitor::AllocaAndGEPs
186 DataScalarizerVisitor::createArrayFromVector(IRBuilder<> &Builder, Value *Vec,
187                                              const Twine &Name = "") {
188   // If there is already an alloca for this vector, return it
189   if (VectorAllocaMap.contains(Vec))
190     return VectorAllocaMap[Vec];
191 
192   auto InsertPoint = Builder.GetInsertPoint();
193 
194   // Allocate the array to hold the vector elements
195   Builder.SetInsertPointPastAllocas(Builder.GetInsertBlock()->getParent());
196   Type *ArrTy = equivalentArrayTypeFromVector(Vec->getType());
197   AllocaInst *ArrAlloca =
198       Builder.CreateAlloca(ArrTy, nullptr, Name + ".alloca");
199   const uint64_t ArrNumElems = ArrTy->getArrayNumElements();
200 
201   // Create loads and stores to populate the array immediately after the
202   // original vector's defining instruction if available, else immediately after
203   // the alloca
204   if (auto *Instr = dyn_cast<Instruction>(Vec))
205     Builder.SetInsertPoint(Instr->getNextNonDebugInstruction());
206   SmallVector<Value *, 4> GEPs(ArrNumElems);
207   for (unsigned I = 0; I < ArrNumElems; ++I) {
208     Value *EE = Builder.CreateExtractElement(Vec, I, Name + ".extract");
209     GEPs[I] = Builder.CreateInBoundsGEP(
210         ArrTy, ArrAlloca, {Builder.getInt32(0), Builder.getInt32(I)},
211         Name + ".index");
212     Builder.CreateStore(EE, GEPs[I]);
213   }
214 
215   VectorAllocaMap.insert({Vec, {ArrAlloca, GEPs}});
216   Builder.SetInsertPoint(InsertPoint);
217   return {ArrAlloca, GEPs};
218 }
219 
220 /// Returns a pair of Value* with the first being a GEP into ArrAlloca using
221 /// indices {0, Index}, and the second Value* being a Load of the GEP
222 static std::pair<Value *, Value *>
223 dynamicallyLoadArray(IRBuilder<> &Builder, AllocaInst *ArrAlloca, Value *Index,
224                      const Twine &Name = "") {
225   Type *ArrTy = ArrAlloca->getAllocatedType();
226   Value *GEP = Builder.CreateInBoundsGEP(
227       ArrTy, ArrAlloca, {Builder.getInt32(0), Index}, Name + ".index");
228   Value *Load =
229       Builder.CreateLoad(ArrTy->getArrayElementType(), GEP, Name + ".load");
230   return std::make_pair(GEP, Load);
231 }
232 
233 bool DataScalarizerVisitor::replaceDynamicInsertElementInst(
234     InsertElementInst &IEI) {
235   IRBuilder<> Builder(&IEI);
236 
237   Value *Vec = IEI.getOperand(0);
238   Value *Val = IEI.getOperand(1);
239   Value *Index = IEI.getOperand(2);
240 
241   AllocaAndGEPs ArrAllocaAndGEPs =
242       createArrayFromVector(Builder, Vec, IEI.getName());
243   AllocaInst *ArrAlloca = ArrAllocaAndGEPs.first;
244   Type *ArrTy = ArrAlloca->getAllocatedType();
245   SmallVector<Value *, 4> &ArrGEPs = ArrAllocaAndGEPs.second;
246 
247   auto GEPAndLoad =
248       dynamicallyLoadArray(Builder, ArrAlloca, Index, IEI.getName());
249   Value *GEP = GEPAndLoad.first;
250   Value *Load = GEPAndLoad.second;
251 
252   Builder.CreateStore(Val, GEP);
253   Value *NewIEI = PoisonValue::get(Vec->getType());
254   for (unsigned I = 0; I < ArrTy->getArrayNumElements(); ++I) {
255     Value *Load = Builder.CreateLoad(ArrTy->getArrayElementType(), ArrGEPs[I],
256                                      IEI.getName() + ".load");
257     NewIEI = Builder.CreateInsertElement(NewIEI, Load, Builder.getInt32(I),
258                                          IEI.getName() + ".insert");
259   }
260 
261   // Store back the original value so the Alloca can be reused for subsequent
262   // insertelement instructions on the same vector
263   Builder.CreateStore(Load, GEP);
264 
265   IEI.replaceAllUsesWith(NewIEI);
266   IEI.eraseFromParent();
267   return true;
268 }
269 
270 bool DataScalarizerVisitor::visitInsertElementInst(InsertElementInst &IEI) {
271   // If the index is a constant then we don't need to scalarize it
272   Value *Index = IEI.getOperand(2);
273   if (isa<ConstantInt>(Index))
274     return false;
275   return replaceDynamicInsertElementInst(IEI);
276 }
277 
278 bool DataScalarizerVisitor::replaceDynamicExtractElementInst(
279     ExtractElementInst &EEI) {
280   IRBuilder<> Builder(&EEI);
281 
282   AllocaAndGEPs ArrAllocaAndGEPs =
283       createArrayFromVector(Builder, EEI.getVectorOperand(), EEI.getName());
284   AllocaInst *ArrAlloca = ArrAllocaAndGEPs.first;
285 
286   auto GEPAndLoad = dynamicallyLoadArray(Builder, ArrAlloca,
287                                          EEI.getIndexOperand(), EEI.getName());
288   Value *Load = GEPAndLoad.second;
289 
290   EEI.replaceAllUsesWith(Load);
291   EEI.eraseFromParent();
292   return true;
293 }
294 
295 bool DataScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) {
296   // If the index is a constant then we don't need to scalarize it
297   Value *Index = EEI.getIndexOperand();
298   if (isa<ConstantInt>(Index))
299     return false;
300   return replaceDynamicExtractElementInst(EEI);
301 }
302 
303 bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
304   Value *PtrOperand = GEPI.getPointerOperand();
305   Type *OrigGEPType = GEPI.getSourceElementType();
306   Type *NewGEPType = OrigGEPType;
307   bool NeedsTransform = false;
308 
309   if (GlobalVariable *NewGlobal = lookupReplacementGlobal(PtrOperand)) {
310     NewGEPType = NewGlobal->getValueType();
311     PtrOperand = NewGlobal;
312     NeedsTransform = true;
313   } else if (AllocaInst *Alloca = dyn_cast<AllocaInst>(PtrOperand)) {
314     Type *AllocatedType = Alloca->getAllocatedType();
315     // Only transform if the allocated type is an array
316     if (AllocatedType != OrigGEPType && isa<ArrayType>(AllocatedType)) {
317       NewGEPType = AllocatedType;
318       NeedsTransform = true;
319     }
320   }
321 
322   // Scalar geps should remain scalars geps. The dxil-flatten-arrays pass will
323   // convert these scalar geps into flattened array geps
324   if (!isa<ArrayType>(OrigGEPType))
325     NewGEPType = OrigGEPType;
326 
327   // Note: We bail if this isn't a gep touched via alloca or global
328   // transformations
329   if (!NeedsTransform)
330     return false;
331 
332   IRBuilder<> Builder(&GEPI);
333   SmallVector<Value *, MaxVecSize> Indices(GEPI.indices());
334 
335   Value *NewGEP = Builder.CreateGEP(NewGEPType, PtrOperand, Indices,
336                                     GEPI.getName(), GEPI.getNoWrapFlags());
337   GEPI.replaceAllUsesWith(NewGEP);
338   GEPI.eraseFromParent();
339   return true;
340 }
341 
342 static Constant *transformInitializer(Constant *Init, Type *OrigType,
343                                       Type *NewType, LLVMContext &Ctx) {
344   // Handle ConstantAggregateZero (zero-initialized constants)
345   if (isa<ConstantAggregateZero>(Init)) {
346     return ConstantAggregateZero::get(NewType);
347   }
348 
349   // Handle UndefValue (undefined constants)
350   if (isa<UndefValue>(Init)) {
351     return UndefValue::get(NewType);
352   }
353 
354   // Handle vector to array transformation
355   if (isa<VectorType>(OrigType) && isa<ArrayType>(NewType)) {
356     // Convert vector initializer to array initializer
357     SmallVector<Constant *, MaxVecSize> ArrayElements;
358     if (ConstantVector *ConstVecInit = dyn_cast<ConstantVector>(Init)) {
359       for (unsigned I = 0; I < ConstVecInit->getNumOperands(); ++I)
360         ArrayElements.push_back(ConstVecInit->getOperand(I));
361     } else if (ConstantDataVector *ConstDataVecInit =
362                    llvm::dyn_cast<llvm::ConstantDataVector>(Init)) {
363       for (unsigned I = 0; I < ConstDataVecInit->getNumElements(); ++I)
364         ArrayElements.push_back(ConstDataVecInit->getElementAsConstant(I));
365     } else {
366       assert(false && "Expected a ConstantVector or ConstantDataVector for "
367                       "vector initializer!");
368     }
369 
370     return ConstantArray::get(cast<ArrayType>(NewType), ArrayElements);
371   }
372 
373   // Handle array of vectors transformation
374   if (auto *ArrayTy = dyn_cast<ArrayType>(OrigType)) {
375     auto *ArrayInit = dyn_cast<ConstantArray>(Init);
376     assert(ArrayInit && "Expected a ConstantArray for array initializer!");
377 
378     SmallVector<Constant *, MaxVecSize> NewArrayElements;
379     for (unsigned I = 0; I < ArrayTy->getNumElements(); ++I) {
380       // Recursively transform array elements
381       Constant *NewElemInit = transformInitializer(
382           ArrayInit->getOperand(I), ArrayTy->getElementType(),
383           cast<ArrayType>(NewType)->getElementType(), Ctx);
384       NewArrayElements.push_back(NewElemInit);
385     }
386 
387     return ConstantArray::get(cast<ArrayType>(NewType), NewArrayElements);
388   }
389 
390   // If not a vector or array, return the original initializer
391   return Init;
392 }
393 
394 static bool findAndReplaceVectors(Module &M) {
395   bool MadeChange = false;
396   LLVMContext &Ctx = M.getContext();
397   IRBuilder<> Builder(Ctx);
398   DataScalarizerVisitor Impl;
399   for (GlobalVariable &G : M.globals()) {
400     Type *OrigType = G.getValueType();
401 
402     Type *NewType = equivalentArrayTypeFromVector(OrigType);
403     if (OrigType != NewType) {
404       // Create a new global variable with the updated type
405       // Note: Initializer is set via transformInitializer
406       GlobalVariable *NewGlobal = new GlobalVariable(
407           M, NewType, G.isConstant(), G.getLinkage(),
408           /*Initializer=*/nullptr, G.getName() + ".scalarized", &G,
409           G.getThreadLocalMode(), G.getAddressSpace(),
410           G.isExternallyInitialized());
411 
412       // Copy relevant attributes
413       NewGlobal->setUnnamedAddr(G.getUnnamedAddr());
414       if (G.getAlignment() > 0) {
415         NewGlobal->setAlignment(G.getAlign());
416       }
417 
418       if (G.hasInitializer()) {
419         Constant *Init = G.getInitializer();
420         Constant *NewInit = transformInitializer(Init, OrigType, NewType, Ctx);
421         NewGlobal->setInitializer(NewInit);
422       }
423 
424       // Note: we want to do G.replaceAllUsesWith(NewGlobal);, but it assumes
425       // type equality. Instead we will use the visitor pattern.
426       Impl.GlobalMap[&G] = NewGlobal;
427     }
428   }
429 
430   for (auto &F : make_early_inc_range(M.functions())) {
431     if (F.isDeclaration())
432       continue;
433     MadeChange |= Impl.visit(F);
434   }
435 
436   // Remove the old globals after the iteration
437   for (auto &[Old, New] : Impl.GlobalMap) {
438     Old->eraseFromParent();
439     MadeChange = true;
440   }
441   return MadeChange;
442 }
443 
444 PreservedAnalyses DXILDataScalarization::run(Module &M,
445                                              ModuleAnalysisManager &) {
446   bool MadeChanges = findAndReplaceVectors(M);
447   if (!MadeChanges)
448     return PreservedAnalyses::all();
449   PreservedAnalyses PA;
450   return PA;
451 }
452 
453 bool DXILDataScalarizationLegacy::runOnModule(Module &M) {
454   return findAndReplaceVectors(M);
455 }
456 
457 char DXILDataScalarizationLegacy::ID = 0;
458 
459 INITIALIZE_PASS_BEGIN(DXILDataScalarizationLegacy, DEBUG_TYPE,
460                       "DXIL Data Scalarization", false, false)
461 INITIALIZE_PASS_END(DXILDataScalarizationLegacy, DEBUG_TYPE,
462                     "DXIL Data Scalarization", false, false)
463 
464 ModulePass *llvm::createDXILDataScalarizationLegacyPass() {
465   return new DXILDataScalarizationLegacy();
466 }
467