1 //===- DXILDataScalarization.cpp - Perform DXIL Data Legalization ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===---------------------------------------------------------------------===// 8 9 #include "DXILDataScalarization.h" 10 #include "DirectX.h" 11 #include "llvm/ADT/PostOrderIterator.h" 12 #include "llvm/ADT/STLExtras.h" 13 #include "llvm/IR/DerivedTypes.h" 14 #include "llvm/IR/GlobalVariable.h" 15 #include "llvm/IR/IRBuilder.h" 16 #include "llvm/IR/InstVisitor.h" 17 #include "llvm/IR/Instructions.h" 18 #include "llvm/IR/Module.h" 19 #include "llvm/IR/Operator.h" 20 #include "llvm/IR/PassManager.h" 21 #include "llvm/IR/ReplaceConstant.h" 22 #include "llvm/IR/Type.h" 23 #include "llvm/Support/Casting.h" 24 #include "llvm/Transforms/Utils/Cloning.h" 25 #include "llvm/Transforms/Utils/Local.h" 26 27 #define DEBUG_TYPE "dxil-data-scalarization" 28 static const int MaxVecSize = 4; 29 30 using namespace llvm; 31 32 // Recursively creates an array-like version of a given vector type. 33 static Type *equivalentArrayTypeFromVector(Type *T) { 34 if (auto *VecTy = dyn_cast<VectorType>(T)) 35 return ArrayType::get(VecTy->getElementType(), 36 dyn_cast<FixedVectorType>(VecTy)->getNumElements()); 37 if (auto *ArrayTy = dyn_cast<ArrayType>(T)) { 38 Type *NewElementType = 39 equivalentArrayTypeFromVector(ArrayTy->getElementType()); 40 return ArrayType::get(NewElementType, ArrayTy->getNumElements()); 41 } 42 // If it's not a vector or array, return the original type. 43 return T; 44 } 45 46 class DXILDataScalarizationLegacy : public ModulePass { 47 48 public: 49 bool runOnModule(Module &M) override; 50 DXILDataScalarizationLegacy() : ModulePass(ID) {} 51 52 static char ID; // Pass identification. 53 }; 54 55 static bool findAndReplaceVectors(Module &M); 56 57 class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> { 58 public: 59 DataScalarizerVisitor() : GlobalMap() {} 60 bool visit(Function &F); 61 // InstVisitor methods. They return true if the instruction was scalarized, 62 // false if nothing changed. 63 bool visitAllocaInst(AllocaInst &AI); 64 bool visitInstruction(Instruction &I) { return false; } 65 bool visitSelectInst(SelectInst &SI) { return false; } 66 bool visitICmpInst(ICmpInst &ICI) { return false; } 67 bool visitFCmpInst(FCmpInst &FCI) { return false; } 68 bool visitUnaryOperator(UnaryOperator &UO) { return false; } 69 bool visitBinaryOperator(BinaryOperator &BO) { return false; } 70 bool visitGetElementPtrInst(GetElementPtrInst &GEPI); 71 bool visitCastInst(CastInst &CI) { return false; } 72 bool visitBitCastInst(BitCastInst &BCI) { return false; } 73 bool visitInsertElementInst(InsertElementInst &IEI); 74 bool visitExtractElementInst(ExtractElementInst &EEI); 75 bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; } 76 bool visitPHINode(PHINode &PHI) { return false; } 77 bool visitLoadInst(LoadInst &LI); 78 bool visitStoreInst(StoreInst &SI); 79 bool visitCallInst(CallInst &ICI) { return false; } 80 bool visitFreezeInst(FreezeInst &FI) { return false; } 81 friend bool findAndReplaceVectors(llvm::Module &M); 82 83 private: 84 typedef std::pair<AllocaInst *, SmallVector<Value *, 4>> AllocaAndGEPs; 85 typedef SmallDenseMap<Value *, AllocaAndGEPs> 86 VectorToArrayMap; // A map from a vector-typed Value to its corresponding 87 // AllocaInst and GEPs to each element of an array 88 VectorToArrayMap VectorAllocaMap; 89 AllocaAndGEPs createArrayFromVector(IRBuilder<> &Builder, Value *Vec, 90 const Twine &Name); 91 bool replaceDynamicInsertElementInst(InsertElementInst &IEI); 92 bool replaceDynamicExtractElementInst(ExtractElementInst &EEI); 93 94 GlobalVariable *lookupReplacementGlobal(Value *CurrOperand); 95 DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap; 96 }; 97 98 bool DataScalarizerVisitor::visit(Function &F) { 99 bool MadeChange = false; 100 ReversePostOrderTraversal<Function *> RPOT(&F); 101 for (BasicBlock *BB : make_early_inc_range(RPOT)) { 102 for (Instruction &I : make_early_inc_range(*BB)) 103 MadeChange |= InstVisitor::visit(I); 104 } 105 VectorAllocaMap.clear(); 106 return MadeChange; 107 } 108 109 GlobalVariable * 110 DataScalarizerVisitor::lookupReplacementGlobal(Value *CurrOperand) { 111 if (GlobalVariable *OldGlobal = dyn_cast<GlobalVariable>(CurrOperand)) { 112 auto It = GlobalMap.find(OldGlobal); 113 if (It != GlobalMap.end()) { 114 return It->second; // Found, return the new global 115 } 116 } 117 return nullptr; // Not found 118 } 119 120 // Helper function to check if a type is a vector or an array of vectors 121 static bool isVectorOrArrayOfVectors(Type *T) { 122 if (isa<VectorType>(T)) 123 return true; 124 if (ArrayType *ArrType = dyn_cast<ArrayType>(T)) 125 return isa<VectorType>(ArrType->getElementType()) || 126 isVectorOrArrayOfVectors(ArrType->getElementType()); 127 return false; 128 } 129 130 bool DataScalarizerVisitor::visitAllocaInst(AllocaInst &AI) { 131 Type *AllocatedType = AI.getAllocatedType(); 132 if (!isVectorOrArrayOfVectors(AllocatedType)) 133 return false; 134 135 IRBuilder<> Builder(&AI); 136 Type *NewType = equivalentArrayTypeFromVector(AllocatedType); 137 AllocaInst *ArrAlloca = 138 Builder.CreateAlloca(NewType, nullptr, AI.getName() + ".scalarize"); 139 ArrAlloca->setAlignment(AI.getAlign()); 140 AI.replaceAllUsesWith(ArrAlloca); 141 AI.eraseFromParent(); 142 return true; 143 } 144 145 bool DataScalarizerVisitor::visitLoadInst(LoadInst &LI) { 146 Value *PtrOperand = LI.getPointerOperand(); 147 ConstantExpr *CE = dyn_cast<ConstantExpr>(PtrOperand); 148 if (CE && CE->getOpcode() == Instruction::GetElementPtr) { 149 GetElementPtrInst *OldGEP = cast<GetElementPtrInst>(CE->getAsInstruction()); 150 OldGEP->insertBefore(LI.getIterator()); 151 IRBuilder<> Builder(&LI); 152 LoadInst *NewLoad = Builder.CreateLoad(LI.getType(), OldGEP, LI.getName()); 153 NewLoad->setAlignment(LI.getAlign()); 154 LI.replaceAllUsesWith(NewLoad); 155 LI.eraseFromParent(); 156 visitGetElementPtrInst(*OldGEP); 157 return true; 158 } 159 if (GlobalVariable *NewGlobal = lookupReplacementGlobal(PtrOperand)) 160 LI.setOperand(LI.getPointerOperandIndex(), NewGlobal); 161 return false; 162 } 163 164 bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) { 165 166 Value *PtrOperand = SI.getPointerOperand(); 167 ConstantExpr *CE = dyn_cast<ConstantExpr>(PtrOperand); 168 if (CE && CE->getOpcode() == Instruction::GetElementPtr) { 169 GetElementPtrInst *OldGEP = cast<GetElementPtrInst>(CE->getAsInstruction()); 170 OldGEP->insertBefore(SI.getIterator()); 171 IRBuilder<> Builder(&SI); 172 StoreInst *NewStore = Builder.CreateStore(SI.getValueOperand(), OldGEP); 173 NewStore->setAlignment(SI.getAlign()); 174 SI.replaceAllUsesWith(NewStore); 175 SI.eraseFromParent(); 176 visitGetElementPtrInst(*OldGEP); 177 return true; 178 } 179 if (GlobalVariable *NewGlobal = lookupReplacementGlobal(PtrOperand)) 180 SI.setOperand(SI.getPointerOperandIndex(), NewGlobal); 181 182 return false; 183 } 184 185 DataScalarizerVisitor::AllocaAndGEPs 186 DataScalarizerVisitor::createArrayFromVector(IRBuilder<> &Builder, Value *Vec, 187 const Twine &Name = "") { 188 // If there is already an alloca for this vector, return it 189 if (VectorAllocaMap.contains(Vec)) 190 return VectorAllocaMap[Vec]; 191 192 auto InsertPoint = Builder.GetInsertPoint(); 193 194 // Allocate the array to hold the vector elements 195 Builder.SetInsertPointPastAllocas(Builder.GetInsertBlock()->getParent()); 196 Type *ArrTy = equivalentArrayTypeFromVector(Vec->getType()); 197 AllocaInst *ArrAlloca = 198 Builder.CreateAlloca(ArrTy, nullptr, Name + ".alloca"); 199 const uint64_t ArrNumElems = ArrTy->getArrayNumElements(); 200 201 // Create loads and stores to populate the array immediately after the 202 // original vector's defining instruction if available, else immediately after 203 // the alloca 204 if (auto *Instr = dyn_cast<Instruction>(Vec)) 205 Builder.SetInsertPoint(Instr->getNextNonDebugInstruction()); 206 SmallVector<Value *, 4> GEPs(ArrNumElems); 207 for (unsigned I = 0; I < ArrNumElems; ++I) { 208 Value *EE = Builder.CreateExtractElement(Vec, I, Name + ".extract"); 209 GEPs[I] = Builder.CreateInBoundsGEP( 210 ArrTy, ArrAlloca, {Builder.getInt32(0), Builder.getInt32(I)}, 211 Name + ".index"); 212 Builder.CreateStore(EE, GEPs[I]); 213 } 214 215 VectorAllocaMap.insert({Vec, {ArrAlloca, GEPs}}); 216 Builder.SetInsertPoint(InsertPoint); 217 return {ArrAlloca, GEPs}; 218 } 219 220 /// Returns a pair of Value* with the first being a GEP into ArrAlloca using 221 /// indices {0, Index}, and the second Value* being a Load of the GEP 222 static std::pair<Value *, Value *> 223 dynamicallyLoadArray(IRBuilder<> &Builder, AllocaInst *ArrAlloca, Value *Index, 224 const Twine &Name = "") { 225 Type *ArrTy = ArrAlloca->getAllocatedType(); 226 Value *GEP = Builder.CreateInBoundsGEP( 227 ArrTy, ArrAlloca, {Builder.getInt32(0), Index}, Name + ".index"); 228 Value *Load = 229 Builder.CreateLoad(ArrTy->getArrayElementType(), GEP, Name + ".load"); 230 return std::make_pair(GEP, Load); 231 } 232 233 bool DataScalarizerVisitor::replaceDynamicInsertElementInst( 234 InsertElementInst &IEI) { 235 IRBuilder<> Builder(&IEI); 236 237 Value *Vec = IEI.getOperand(0); 238 Value *Val = IEI.getOperand(1); 239 Value *Index = IEI.getOperand(2); 240 241 AllocaAndGEPs ArrAllocaAndGEPs = 242 createArrayFromVector(Builder, Vec, IEI.getName()); 243 AllocaInst *ArrAlloca = ArrAllocaAndGEPs.first; 244 Type *ArrTy = ArrAlloca->getAllocatedType(); 245 SmallVector<Value *, 4> &ArrGEPs = ArrAllocaAndGEPs.second; 246 247 auto GEPAndLoad = 248 dynamicallyLoadArray(Builder, ArrAlloca, Index, IEI.getName()); 249 Value *GEP = GEPAndLoad.first; 250 Value *Load = GEPAndLoad.second; 251 252 Builder.CreateStore(Val, GEP); 253 Value *NewIEI = PoisonValue::get(Vec->getType()); 254 for (unsigned I = 0; I < ArrTy->getArrayNumElements(); ++I) { 255 Value *Load = Builder.CreateLoad(ArrTy->getArrayElementType(), ArrGEPs[I], 256 IEI.getName() + ".load"); 257 NewIEI = Builder.CreateInsertElement(NewIEI, Load, Builder.getInt32(I), 258 IEI.getName() + ".insert"); 259 } 260 261 // Store back the original value so the Alloca can be reused for subsequent 262 // insertelement instructions on the same vector 263 Builder.CreateStore(Load, GEP); 264 265 IEI.replaceAllUsesWith(NewIEI); 266 IEI.eraseFromParent(); 267 return true; 268 } 269 270 bool DataScalarizerVisitor::visitInsertElementInst(InsertElementInst &IEI) { 271 // If the index is a constant then we don't need to scalarize it 272 Value *Index = IEI.getOperand(2); 273 if (isa<ConstantInt>(Index)) 274 return false; 275 return replaceDynamicInsertElementInst(IEI); 276 } 277 278 bool DataScalarizerVisitor::replaceDynamicExtractElementInst( 279 ExtractElementInst &EEI) { 280 IRBuilder<> Builder(&EEI); 281 282 AllocaAndGEPs ArrAllocaAndGEPs = 283 createArrayFromVector(Builder, EEI.getVectorOperand(), EEI.getName()); 284 AllocaInst *ArrAlloca = ArrAllocaAndGEPs.first; 285 286 auto GEPAndLoad = dynamicallyLoadArray(Builder, ArrAlloca, 287 EEI.getIndexOperand(), EEI.getName()); 288 Value *Load = GEPAndLoad.second; 289 290 EEI.replaceAllUsesWith(Load); 291 EEI.eraseFromParent(); 292 return true; 293 } 294 295 bool DataScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) { 296 // If the index is a constant then we don't need to scalarize it 297 Value *Index = EEI.getIndexOperand(); 298 if (isa<ConstantInt>(Index)) 299 return false; 300 return replaceDynamicExtractElementInst(EEI); 301 } 302 303 bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) { 304 Value *PtrOperand = GEPI.getPointerOperand(); 305 Type *OrigGEPType = GEPI.getSourceElementType(); 306 Type *NewGEPType = OrigGEPType; 307 bool NeedsTransform = false; 308 309 if (GlobalVariable *NewGlobal = lookupReplacementGlobal(PtrOperand)) { 310 NewGEPType = NewGlobal->getValueType(); 311 PtrOperand = NewGlobal; 312 NeedsTransform = true; 313 } else if (AllocaInst *Alloca = dyn_cast<AllocaInst>(PtrOperand)) { 314 Type *AllocatedType = Alloca->getAllocatedType(); 315 // Only transform if the allocated type is an array 316 if (AllocatedType != OrigGEPType && isa<ArrayType>(AllocatedType)) { 317 NewGEPType = AllocatedType; 318 NeedsTransform = true; 319 } 320 } 321 322 // Scalar geps should remain scalars geps. The dxil-flatten-arrays pass will 323 // convert these scalar geps into flattened array geps 324 if (!isa<ArrayType>(OrigGEPType)) 325 NewGEPType = OrigGEPType; 326 327 // Note: We bail if this isn't a gep touched via alloca or global 328 // transformations 329 if (!NeedsTransform) 330 return false; 331 332 IRBuilder<> Builder(&GEPI); 333 SmallVector<Value *, MaxVecSize> Indices(GEPI.indices()); 334 335 Value *NewGEP = Builder.CreateGEP(NewGEPType, PtrOperand, Indices, 336 GEPI.getName(), GEPI.getNoWrapFlags()); 337 GEPI.replaceAllUsesWith(NewGEP); 338 GEPI.eraseFromParent(); 339 return true; 340 } 341 342 static Constant *transformInitializer(Constant *Init, Type *OrigType, 343 Type *NewType, LLVMContext &Ctx) { 344 // Handle ConstantAggregateZero (zero-initialized constants) 345 if (isa<ConstantAggregateZero>(Init)) { 346 return ConstantAggregateZero::get(NewType); 347 } 348 349 // Handle UndefValue (undefined constants) 350 if (isa<UndefValue>(Init)) { 351 return UndefValue::get(NewType); 352 } 353 354 // Handle vector to array transformation 355 if (isa<VectorType>(OrigType) && isa<ArrayType>(NewType)) { 356 // Convert vector initializer to array initializer 357 SmallVector<Constant *, MaxVecSize> ArrayElements; 358 if (ConstantVector *ConstVecInit = dyn_cast<ConstantVector>(Init)) { 359 for (unsigned I = 0; I < ConstVecInit->getNumOperands(); ++I) 360 ArrayElements.push_back(ConstVecInit->getOperand(I)); 361 } else if (ConstantDataVector *ConstDataVecInit = 362 llvm::dyn_cast<llvm::ConstantDataVector>(Init)) { 363 for (unsigned I = 0; I < ConstDataVecInit->getNumElements(); ++I) 364 ArrayElements.push_back(ConstDataVecInit->getElementAsConstant(I)); 365 } else { 366 assert(false && "Expected a ConstantVector or ConstantDataVector for " 367 "vector initializer!"); 368 } 369 370 return ConstantArray::get(cast<ArrayType>(NewType), ArrayElements); 371 } 372 373 // Handle array of vectors transformation 374 if (auto *ArrayTy = dyn_cast<ArrayType>(OrigType)) { 375 auto *ArrayInit = dyn_cast<ConstantArray>(Init); 376 assert(ArrayInit && "Expected a ConstantArray for array initializer!"); 377 378 SmallVector<Constant *, MaxVecSize> NewArrayElements; 379 for (unsigned I = 0; I < ArrayTy->getNumElements(); ++I) { 380 // Recursively transform array elements 381 Constant *NewElemInit = transformInitializer( 382 ArrayInit->getOperand(I), ArrayTy->getElementType(), 383 cast<ArrayType>(NewType)->getElementType(), Ctx); 384 NewArrayElements.push_back(NewElemInit); 385 } 386 387 return ConstantArray::get(cast<ArrayType>(NewType), NewArrayElements); 388 } 389 390 // If not a vector or array, return the original initializer 391 return Init; 392 } 393 394 static bool findAndReplaceVectors(Module &M) { 395 bool MadeChange = false; 396 LLVMContext &Ctx = M.getContext(); 397 IRBuilder<> Builder(Ctx); 398 DataScalarizerVisitor Impl; 399 for (GlobalVariable &G : M.globals()) { 400 Type *OrigType = G.getValueType(); 401 402 Type *NewType = equivalentArrayTypeFromVector(OrigType); 403 if (OrigType != NewType) { 404 // Create a new global variable with the updated type 405 // Note: Initializer is set via transformInitializer 406 GlobalVariable *NewGlobal = new GlobalVariable( 407 M, NewType, G.isConstant(), G.getLinkage(), 408 /*Initializer=*/nullptr, G.getName() + ".scalarized", &G, 409 G.getThreadLocalMode(), G.getAddressSpace(), 410 G.isExternallyInitialized()); 411 412 // Copy relevant attributes 413 NewGlobal->setUnnamedAddr(G.getUnnamedAddr()); 414 if (G.getAlignment() > 0) { 415 NewGlobal->setAlignment(G.getAlign()); 416 } 417 418 if (G.hasInitializer()) { 419 Constant *Init = G.getInitializer(); 420 Constant *NewInit = transformInitializer(Init, OrigType, NewType, Ctx); 421 NewGlobal->setInitializer(NewInit); 422 } 423 424 // Note: we want to do G.replaceAllUsesWith(NewGlobal);, but it assumes 425 // type equality. Instead we will use the visitor pattern. 426 Impl.GlobalMap[&G] = NewGlobal; 427 } 428 } 429 430 for (auto &F : make_early_inc_range(M.functions())) { 431 if (F.isDeclaration()) 432 continue; 433 MadeChange |= Impl.visit(F); 434 } 435 436 // Remove the old globals after the iteration 437 for (auto &[Old, New] : Impl.GlobalMap) { 438 Old->eraseFromParent(); 439 MadeChange = true; 440 } 441 return MadeChange; 442 } 443 444 PreservedAnalyses DXILDataScalarization::run(Module &M, 445 ModuleAnalysisManager &) { 446 bool MadeChanges = findAndReplaceVectors(M); 447 if (!MadeChanges) 448 return PreservedAnalyses::all(); 449 PreservedAnalyses PA; 450 return PA; 451 } 452 453 bool DXILDataScalarizationLegacy::runOnModule(Module &M) { 454 return findAndReplaceVectors(M); 455 } 456 457 char DXILDataScalarizationLegacy::ID = 0; 458 459 INITIALIZE_PASS_BEGIN(DXILDataScalarizationLegacy, DEBUG_TYPE, 460 "DXIL Data Scalarization", false, false) 461 INITIALIZE_PASS_END(DXILDataScalarizationLegacy, DEBUG_TYPE, 462 "DXIL Data Scalarization", false, false) 463 464 ModulePass *llvm::createDXILDataScalarizationLegacyPass() { 465 return new DXILDataScalarizationLegacy(); 466 } 467