1 //===- DXILDataScalarization.cpp - Perform DXIL Data Legalization ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===---------------------------------------------------------------------===//
8
9 #include "DXILDataScalarization.h"
10 #include "DirectX.h"
11 #include "llvm/ADT/PostOrderIterator.h"
12 #include "llvm/ADT/STLExtras.h"
13 #include "llvm/IR/DerivedTypes.h"
14 #include "llvm/IR/GlobalVariable.h"
15 #include "llvm/IR/IRBuilder.h"
16 #include "llvm/IR/InstVisitor.h"
17 #include "llvm/IR/Instructions.h"
18 #include "llvm/IR/Module.h"
19 #include "llvm/IR/Operator.h"
20 #include "llvm/IR/PassManager.h"
21 #include "llvm/IR/ReplaceConstant.h"
22 #include "llvm/IR/Type.h"
23 #include "llvm/Support/Casting.h"
24 #include "llvm/Transforms/Utils/Cloning.h"
25 #include "llvm/Transforms/Utils/Local.h"
26
27 #define DEBUG_TYPE "dxil-data-scalarization"
28 static const int MaxVecSize = 4;
29
30 using namespace llvm;
31
32 // Recursively creates an array-like version of a given vector type.
equivalentArrayTypeFromVector(Type * T)33 static Type *equivalentArrayTypeFromVector(Type *T) {
34 if (auto *VecTy = dyn_cast<VectorType>(T))
35 return ArrayType::get(VecTy->getElementType(),
36 dyn_cast<FixedVectorType>(VecTy)->getNumElements());
37 if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
38 Type *NewElementType =
39 equivalentArrayTypeFromVector(ArrayTy->getElementType());
40 return ArrayType::get(NewElementType, ArrayTy->getNumElements());
41 }
42 // If it's not a vector or array, return the original type.
43 return T;
44 }
45
46 class DXILDataScalarizationLegacy : public ModulePass {
47
48 public:
49 bool runOnModule(Module &M) override;
DXILDataScalarizationLegacy()50 DXILDataScalarizationLegacy() : ModulePass(ID) {}
51
52 static char ID; // Pass identification.
53 };
54
55 static bool findAndReplaceVectors(Module &M);
56
57 class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
58 public:
DataScalarizerVisitor()59 DataScalarizerVisitor() : GlobalMap() {}
60 bool visit(Function &F);
61 // InstVisitor methods. They return true if the instruction was scalarized,
62 // false if nothing changed.
63 bool visitAllocaInst(AllocaInst &AI);
visitInstruction(Instruction & I)64 bool visitInstruction(Instruction &I) { return false; }
visitSelectInst(SelectInst & SI)65 bool visitSelectInst(SelectInst &SI) { return false; }
visitICmpInst(ICmpInst & ICI)66 bool visitICmpInst(ICmpInst &ICI) { return false; }
visitFCmpInst(FCmpInst & FCI)67 bool visitFCmpInst(FCmpInst &FCI) { return false; }
visitUnaryOperator(UnaryOperator & UO)68 bool visitUnaryOperator(UnaryOperator &UO) { return false; }
visitBinaryOperator(BinaryOperator & BO)69 bool visitBinaryOperator(BinaryOperator &BO) { return false; }
70 bool visitGetElementPtrInst(GetElementPtrInst &GEPI);
visitCastInst(CastInst & CI)71 bool visitCastInst(CastInst &CI) { return false; }
visitBitCastInst(BitCastInst & BCI)72 bool visitBitCastInst(BitCastInst &BCI) { return false; }
73 bool visitInsertElementInst(InsertElementInst &IEI);
74 bool visitExtractElementInst(ExtractElementInst &EEI);
visitShuffleVectorInst(ShuffleVectorInst & SVI)75 bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; }
visitPHINode(PHINode & PHI)76 bool visitPHINode(PHINode &PHI) { return false; }
77 bool visitLoadInst(LoadInst &LI);
78 bool visitStoreInst(StoreInst &SI);
visitCallInst(CallInst & ICI)79 bool visitCallInst(CallInst &ICI) { return false; }
visitFreezeInst(FreezeInst & FI)80 bool visitFreezeInst(FreezeInst &FI) { return false; }
81 friend bool findAndReplaceVectors(llvm::Module &M);
82
83 private:
84 typedef std::pair<AllocaInst *, SmallVector<Value *, 4>> AllocaAndGEPs;
85 typedef SmallDenseMap<Value *, AllocaAndGEPs>
86 VectorToArrayMap; // A map from a vector-typed Value to its corresponding
87 // AllocaInst and GEPs to each element of an array
88 VectorToArrayMap VectorAllocaMap;
89 AllocaAndGEPs createArrayFromVector(IRBuilder<> &Builder, Value *Vec,
90 const Twine &Name);
91 bool replaceDynamicInsertElementInst(InsertElementInst &IEI);
92 bool replaceDynamicExtractElementInst(ExtractElementInst &EEI);
93
94 GlobalVariable *lookupReplacementGlobal(Value *CurrOperand);
95 DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
96 };
97
visit(Function & F)98 bool DataScalarizerVisitor::visit(Function &F) {
99 bool MadeChange = false;
100 ReversePostOrderTraversal<Function *> RPOT(&F);
101 for (BasicBlock *BB : make_early_inc_range(RPOT)) {
102 for (Instruction &I : make_early_inc_range(*BB))
103 MadeChange |= InstVisitor::visit(I);
104 }
105 VectorAllocaMap.clear();
106 return MadeChange;
107 }
108
109 GlobalVariable *
lookupReplacementGlobal(Value * CurrOperand)110 DataScalarizerVisitor::lookupReplacementGlobal(Value *CurrOperand) {
111 if (GlobalVariable *OldGlobal = dyn_cast<GlobalVariable>(CurrOperand)) {
112 auto It = GlobalMap.find(OldGlobal);
113 if (It != GlobalMap.end()) {
114 return It->second; // Found, return the new global
115 }
116 }
117 return nullptr; // Not found
118 }
119
120 // Helper function to check if a type is a vector or an array of vectors
isVectorOrArrayOfVectors(Type * T)121 static bool isVectorOrArrayOfVectors(Type *T) {
122 if (isa<VectorType>(T))
123 return true;
124 if (ArrayType *ArrType = dyn_cast<ArrayType>(T))
125 return isa<VectorType>(ArrType->getElementType()) ||
126 isVectorOrArrayOfVectors(ArrType->getElementType());
127 return false;
128 }
129
visitAllocaInst(AllocaInst & AI)130 bool DataScalarizerVisitor::visitAllocaInst(AllocaInst &AI) {
131 Type *AllocatedType = AI.getAllocatedType();
132 if (!isVectorOrArrayOfVectors(AllocatedType))
133 return false;
134
135 IRBuilder<> Builder(&AI);
136 Type *NewType = equivalentArrayTypeFromVector(AllocatedType);
137 AllocaInst *ArrAlloca =
138 Builder.CreateAlloca(NewType, nullptr, AI.getName() + ".scalarize");
139 ArrAlloca->setAlignment(AI.getAlign());
140 AI.replaceAllUsesWith(ArrAlloca);
141 AI.eraseFromParent();
142 return true;
143 }
144
visitLoadInst(LoadInst & LI)145 bool DataScalarizerVisitor::visitLoadInst(LoadInst &LI) {
146 Value *PtrOperand = LI.getPointerOperand();
147 ConstantExpr *CE = dyn_cast<ConstantExpr>(PtrOperand);
148 if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
149 GetElementPtrInst *OldGEP = cast<GetElementPtrInst>(CE->getAsInstruction());
150 OldGEP->insertBefore(LI.getIterator());
151 IRBuilder<> Builder(&LI);
152 LoadInst *NewLoad = Builder.CreateLoad(LI.getType(), OldGEP, LI.getName());
153 NewLoad->setAlignment(LI.getAlign());
154 LI.replaceAllUsesWith(NewLoad);
155 LI.eraseFromParent();
156 visitGetElementPtrInst(*OldGEP);
157 return true;
158 }
159 if (GlobalVariable *NewGlobal = lookupReplacementGlobal(PtrOperand))
160 LI.setOperand(LI.getPointerOperandIndex(), NewGlobal);
161 return false;
162 }
163
visitStoreInst(StoreInst & SI)164 bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) {
165
166 Value *PtrOperand = SI.getPointerOperand();
167 ConstantExpr *CE = dyn_cast<ConstantExpr>(PtrOperand);
168 if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
169 GetElementPtrInst *OldGEP = cast<GetElementPtrInst>(CE->getAsInstruction());
170 OldGEP->insertBefore(SI.getIterator());
171 IRBuilder<> Builder(&SI);
172 StoreInst *NewStore = Builder.CreateStore(SI.getValueOperand(), OldGEP);
173 NewStore->setAlignment(SI.getAlign());
174 SI.replaceAllUsesWith(NewStore);
175 SI.eraseFromParent();
176 visitGetElementPtrInst(*OldGEP);
177 return true;
178 }
179 if (GlobalVariable *NewGlobal = lookupReplacementGlobal(PtrOperand))
180 SI.setOperand(SI.getPointerOperandIndex(), NewGlobal);
181
182 return false;
183 }
184
185 DataScalarizerVisitor::AllocaAndGEPs
createArrayFromVector(IRBuilder<> & Builder,Value * Vec,const Twine & Name="")186 DataScalarizerVisitor::createArrayFromVector(IRBuilder<> &Builder, Value *Vec,
187 const Twine &Name = "") {
188 // If there is already an alloca for this vector, return it
189 if (VectorAllocaMap.contains(Vec))
190 return VectorAllocaMap[Vec];
191
192 auto InsertPoint = Builder.GetInsertPoint();
193
194 // Allocate the array to hold the vector elements
195 Builder.SetInsertPointPastAllocas(Builder.GetInsertBlock()->getParent());
196 Type *ArrTy = equivalentArrayTypeFromVector(Vec->getType());
197 AllocaInst *ArrAlloca =
198 Builder.CreateAlloca(ArrTy, nullptr, Name + ".alloca");
199 const uint64_t ArrNumElems = ArrTy->getArrayNumElements();
200
201 // Create loads and stores to populate the array immediately after the
202 // original vector's defining instruction if available, else immediately after
203 // the alloca
204 if (auto *Instr = dyn_cast<Instruction>(Vec))
205 Builder.SetInsertPoint(Instr->getNextNonDebugInstruction());
206 SmallVector<Value *, 4> GEPs(ArrNumElems);
207 for (unsigned I = 0; I < ArrNumElems; ++I) {
208 Value *EE = Builder.CreateExtractElement(Vec, I, Name + ".extract");
209 GEPs[I] = Builder.CreateInBoundsGEP(
210 ArrTy, ArrAlloca, {Builder.getInt32(0), Builder.getInt32(I)},
211 Name + ".index");
212 Builder.CreateStore(EE, GEPs[I]);
213 }
214
215 VectorAllocaMap.insert({Vec, {ArrAlloca, GEPs}});
216 Builder.SetInsertPoint(InsertPoint);
217 return {ArrAlloca, GEPs};
218 }
219
220 /// Returns a pair of Value* with the first being a GEP into ArrAlloca using
221 /// indices {0, Index}, and the second Value* being a Load of the GEP
222 static std::pair<Value *, Value *>
dynamicallyLoadArray(IRBuilder<> & Builder,AllocaInst * ArrAlloca,Value * Index,const Twine & Name="")223 dynamicallyLoadArray(IRBuilder<> &Builder, AllocaInst *ArrAlloca, Value *Index,
224 const Twine &Name = "") {
225 Type *ArrTy = ArrAlloca->getAllocatedType();
226 Value *GEP = Builder.CreateInBoundsGEP(
227 ArrTy, ArrAlloca, {Builder.getInt32(0), Index}, Name + ".index");
228 Value *Load =
229 Builder.CreateLoad(ArrTy->getArrayElementType(), GEP, Name + ".load");
230 return std::make_pair(GEP, Load);
231 }
232
replaceDynamicInsertElementInst(InsertElementInst & IEI)233 bool DataScalarizerVisitor::replaceDynamicInsertElementInst(
234 InsertElementInst &IEI) {
235 IRBuilder<> Builder(&IEI);
236
237 Value *Vec = IEI.getOperand(0);
238 Value *Val = IEI.getOperand(1);
239 Value *Index = IEI.getOperand(2);
240
241 AllocaAndGEPs ArrAllocaAndGEPs =
242 createArrayFromVector(Builder, Vec, IEI.getName());
243 AllocaInst *ArrAlloca = ArrAllocaAndGEPs.first;
244 Type *ArrTy = ArrAlloca->getAllocatedType();
245 SmallVector<Value *, 4> &ArrGEPs = ArrAllocaAndGEPs.second;
246
247 auto GEPAndLoad =
248 dynamicallyLoadArray(Builder, ArrAlloca, Index, IEI.getName());
249 Value *GEP = GEPAndLoad.first;
250 Value *Load = GEPAndLoad.second;
251
252 Builder.CreateStore(Val, GEP);
253 Value *NewIEI = PoisonValue::get(Vec->getType());
254 for (unsigned I = 0; I < ArrTy->getArrayNumElements(); ++I) {
255 Value *Load = Builder.CreateLoad(ArrTy->getArrayElementType(), ArrGEPs[I],
256 IEI.getName() + ".load");
257 NewIEI = Builder.CreateInsertElement(NewIEI, Load, Builder.getInt32(I),
258 IEI.getName() + ".insert");
259 }
260
261 // Store back the original value so the Alloca can be reused for subsequent
262 // insertelement instructions on the same vector
263 Builder.CreateStore(Load, GEP);
264
265 IEI.replaceAllUsesWith(NewIEI);
266 IEI.eraseFromParent();
267 return true;
268 }
269
visitInsertElementInst(InsertElementInst & IEI)270 bool DataScalarizerVisitor::visitInsertElementInst(InsertElementInst &IEI) {
271 // If the index is a constant then we don't need to scalarize it
272 Value *Index = IEI.getOperand(2);
273 if (isa<ConstantInt>(Index))
274 return false;
275 return replaceDynamicInsertElementInst(IEI);
276 }
277
replaceDynamicExtractElementInst(ExtractElementInst & EEI)278 bool DataScalarizerVisitor::replaceDynamicExtractElementInst(
279 ExtractElementInst &EEI) {
280 IRBuilder<> Builder(&EEI);
281
282 AllocaAndGEPs ArrAllocaAndGEPs =
283 createArrayFromVector(Builder, EEI.getVectorOperand(), EEI.getName());
284 AllocaInst *ArrAlloca = ArrAllocaAndGEPs.first;
285
286 auto GEPAndLoad = dynamicallyLoadArray(Builder, ArrAlloca,
287 EEI.getIndexOperand(), EEI.getName());
288 Value *Load = GEPAndLoad.second;
289
290 EEI.replaceAllUsesWith(Load);
291 EEI.eraseFromParent();
292 return true;
293 }
294
visitExtractElementInst(ExtractElementInst & EEI)295 bool DataScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) {
296 // If the index is a constant then we don't need to scalarize it
297 Value *Index = EEI.getIndexOperand();
298 if (isa<ConstantInt>(Index))
299 return false;
300 return replaceDynamicExtractElementInst(EEI);
301 }
302
visitGetElementPtrInst(GetElementPtrInst & GEPI)303 bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
304 Value *PtrOperand = GEPI.getPointerOperand();
305 Type *OrigGEPType = GEPI.getSourceElementType();
306 Type *NewGEPType = OrigGEPType;
307 bool NeedsTransform = false;
308
309 if (GlobalVariable *NewGlobal = lookupReplacementGlobal(PtrOperand)) {
310 NewGEPType = NewGlobal->getValueType();
311 PtrOperand = NewGlobal;
312 NeedsTransform = true;
313 } else if (AllocaInst *Alloca = dyn_cast<AllocaInst>(PtrOperand)) {
314 Type *AllocatedType = Alloca->getAllocatedType();
315 // Only transform if the allocated type is an array
316 if (AllocatedType != OrigGEPType && isa<ArrayType>(AllocatedType)) {
317 NewGEPType = AllocatedType;
318 NeedsTransform = true;
319 }
320 }
321
322 // Scalar geps should remain scalars geps. The dxil-flatten-arrays pass will
323 // convert these scalar geps into flattened array geps
324 if (!isa<ArrayType>(OrigGEPType))
325 NewGEPType = OrigGEPType;
326
327 // Note: We bail if this isn't a gep touched via alloca or global
328 // transformations
329 if (!NeedsTransform)
330 return false;
331
332 IRBuilder<> Builder(&GEPI);
333 SmallVector<Value *, MaxVecSize> Indices(GEPI.indices());
334
335 Value *NewGEP = Builder.CreateGEP(NewGEPType, PtrOperand, Indices,
336 GEPI.getName(), GEPI.getNoWrapFlags());
337 GEPI.replaceAllUsesWith(NewGEP);
338 GEPI.eraseFromParent();
339 return true;
340 }
341
transformInitializer(Constant * Init,Type * OrigType,Type * NewType,LLVMContext & Ctx)342 static Constant *transformInitializer(Constant *Init, Type *OrigType,
343 Type *NewType, LLVMContext &Ctx) {
344 // Handle ConstantAggregateZero (zero-initialized constants)
345 if (isa<ConstantAggregateZero>(Init)) {
346 return ConstantAggregateZero::get(NewType);
347 }
348
349 // Handle UndefValue (undefined constants)
350 if (isa<UndefValue>(Init)) {
351 return UndefValue::get(NewType);
352 }
353
354 // Handle vector to array transformation
355 if (isa<VectorType>(OrigType) && isa<ArrayType>(NewType)) {
356 // Convert vector initializer to array initializer
357 SmallVector<Constant *, MaxVecSize> ArrayElements;
358 if (ConstantVector *ConstVecInit = dyn_cast<ConstantVector>(Init)) {
359 for (unsigned I = 0; I < ConstVecInit->getNumOperands(); ++I)
360 ArrayElements.push_back(ConstVecInit->getOperand(I));
361 } else if (ConstantDataVector *ConstDataVecInit =
362 llvm::dyn_cast<llvm::ConstantDataVector>(Init)) {
363 for (unsigned I = 0; I < ConstDataVecInit->getNumElements(); ++I)
364 ArrayElements.push_back(ConstDataVecInit->getElementAsConstant(I));
365 } else {
366 assert(false && "Expected a ConstantVector or ConstantDataVector for "
367 "vector initializer!");
368 }
369
370 return ConstantArray::get(cast<ArrayType>(NewType), ArrayElements);
371 }
372
373 // Handle array of vectors transformation
374 if (auto *ArrayTy = dyn_cast<ArrayType>(OrigType)) {
375 auto *ArrayInit = dyn_cast<ConstantArray>(Init);
376 assert(ArrayInit && "Expected a ConstantArray for array initializer!");
377
378 SmallVector<Constant *, MaxVecSize> NewArrayElements;
379 for (unsigned I = 0; I < ArrayTy->getNumElements(); ++I) {
380 // Recursively transform array elements
381 Constant *NewElemInit = transformInitializer(
382 ArrayInit->getOperand(I), ArrayTy->getElementType(),
383 cast<ArrayType>(NewType)->getElementType(), Ctx);
384 NewArrayElements.push_back(NewElemInit);
385 }
386
387 return ConstantArray::get(cast<ArrayType>(NewType), NewArrayElements);
388 }
389
390 // If not a vector or array, return the original initializer
391 return Init;
392 }
393
findAndReplaceVectors(Module & M)394 static bool findAndReplaceVectors(Module &M) {
395 bool MadeChange = false;
396 LLVMContext &Ctx = M.getContext();
397 IRBuilder<> Builder(Ctx);
398 DataScalarizerVisitor Impl;
399 for (GlobalVariable &G : M.globals()) {
400 Type *OrigType = G.getValueType();
401
402 Type *NewType = equivalentArrayTypeFromVector(OrigType);
403 if (OrigType != NewType) {
404 // Create a new global variable with the updated type
405 // Note: Initializer is set via transformInitializer
406 GlobalVariable *NewGlobal = new GlobalVariable(
407 M, NewType, G.isConstant(), G.getLinkage(),
408 /*Initializer=*/nullptr, G.getName() + ".scalarized", &G,
409 G.getThreadLocalMode(), G.getAddressSpace(),
410 G.isExternallyInitialized());
411
412 // Copy relevant attributes
413 NewGlobal->setUnnamedAddr(G.getUnnamedAddr());
414 if (G.getAlignment() > 0) {
415 NewGlobal->setAlignment(G.getAlign());
416 }
417
418 if (G.hasInitializer()) {
419 Constant *Init = G.getInitializer();
420 Constant *NewInit = transformInitializer(Init, OrigType, NewType, Ctx);
421 NewGlobal->setInitializer(NewInit);
422 }
423
424 // Note: we want to do G.replaceAllUsesWith(NewGlobal);, but it assumes
425 // type equality. Instead we will use the visitor pattern.
426 Impl.GlobalMap[&G] = NewGlobal;
427 }
428 }
429
430 for (auto &F : make_early_inc_range(M.functions())) {
431 if (F.isDeclaration())
432 continue;
433 MadeChange |= Impl.visit(F);
434 }
435
436 // Remove the old globals after the iteration
437 for (auto &[Old, New] : Impl.GlobalMap) {
438 Old->eraseFromParent();
439 MadeChange = true;
440 }
441 return MadeChange;
442 }
443
run(Module & M,ModuleAnalysisManager &)444 PreservedAnalyses DXILDataScalarization::run(Module &M,
445 ModuleAnalysisManager &) {
446 bool MadeChanges = findAndReplaceVectors(M);
447 if (!MadeChanges)
448 return PreservedAnalyses::all();
449 PreservedAnalyses PA;
450 return PA;
451 }
452
runOnModule(Module & M)453 bool DXILDataScalarizationLegacy::runOnModule(Module &M) {
454 return findAndReplaceVectors(M);
455 }
456
457 char DXILDataScalarizationLegacy::ID = 0;
458
459 INITIALIZE_PASS_BEGIN(DXILDataScalarizationLegacy, DEBUG_TYPE,
460 "DXIL Data Scalarization", false, false)
461 INITIALIZE_PASS_END(DXILDataScalarizationLegacy, DEBUG_TYPE,
462 "DXIL Data Scalarization", false, false)
463
createDXILDataScalarizationLegacyPass()464 ModulePass *llvm::createDXILDataScalarizationLegacyPass() {
465 return new DXILDataScalarizationLegacy();
466 }
467