1 //===- DXILLegalizePass.cpp - Legalizes llvm IR for DXIL ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===---------------------------------------------------------------------===// 8 9 #include "DXILLegalizePass.h" 10 #include "DirectX.h" 11 #include "llvm/ADT/APInt.h" 12 #include "llvm/IR/Constants.h" 13 #include "llvm/IR/Function.h" 14 #include "llvm/IR/IRBuilder.h" 15 #include "llvm/IR/InstIterator.h" 16 #include "llvm/IR/Instruction.h" 17 #include "llvm/IR/Instructions.h" 18 #include "llvm/IR/Module.h" 19 #include "llvm/Pass.h" 20 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 21 #include <functional> 22 23 #define DEBUG_TYPE "dxil-legalize" 24 25 using namespace llvm; 26 27 static void legalizeFreeze(Instruction &I, 28 SmallVectorImpl<Instruction *> &ToRemove, 29 DenseMap<Value *, Value *>) { 30 auto *FI = dyn_cast<FreezeInst>(&I); 31 if (!FI) 32 return; 33 34 FI->replaceAllUsesWith(FI->getOperand(0)); 35 ToRemove.push_back(FI); 36 } 37 38 static void fixI8UseChain(Instruction &I, 39 SmallVectorImpl<Instruction *> &ToRemove, 40 DenseMap<Value *, Value *> &ReplacedValues) { 41 42 auto ProcessOperands = [&](SmallVector<Value *> &NewOperands) { 43 Type *InstrType = IntegerType::get(I.getContext(), 32); 44 45 for (unsigned OpIdx = 0; OpIdx < I.getNumOperands(); ++OpIdx) { 46 Value *Op = I.getOperand(OpIdx); 47 if (ReplacedValues.count(Op) && 48 ReplacedValues[Op]->getType()->isIntegerTy()) 49 InstrType = ReplacedValues[Op]->getType(); 50 } 51 52 for (unsigned OpIdx = 0; OpIdx < I.getNumOperands(); ++OpIdx) { 53 Value *Op = I.getOperand(OpIdx); 54 if (ReplacedValues.count(Op)) 55 NewOperands.push_back(ReplacedValues[Op]); 56 else if (auto *Imm = dyn_cast<ConstantInt>(Op)) { 57 APInt Value = Imm->getValue(); 58 unsigned NewBitWidth = InstrType->getIntegerBitWidth(); 59 // Note: options here are sext or sextOrTrunc. 60 // Since i8 isn't supported, we assume new values 61 // will always have a higher bitness. 62 assert(NewBitWidth > Value.getBitWidth() && 63 "Replacement's BitWidth should be larger than Current."); 64 APInt NewValue = Value.sext(NewBitWidth); 65 NewOperands.push_back(ConstantInt::get(InstrType, NewValue)); 66 } else { 67 assert(!Op->getType()->isIntegerTy(8)); 68 NewOperands.push_back(Op); 69 } 70 } 71 }; 72 IRBuilder<> Builder(&I); 73 if (auto *Trunc = dyn_cast<TruncInst>(&I)) { 74 if (Trunc->getDestTy()->isIntegerTy(8)) { 75 ReplacedValues[Trunc] = Trunc->getOperand(0); 76 ToRemove.push_back(Trunc); 77 return; 78 } 79 } 80 81 if (auto *Store = dyn_cast<StoreInst>(&I)) { 82 if (!Store->getValueOperand()->getType()->isIntegerTy(8)) 83 return; 84 SmallVector<Value *> NewOperands; 85 ProcessOperands(NewOperands); 86 Value *NewStore = Builder.CreateStore(NewOperands[0], NewOperands[1]); 87 ReplacedValues[Store] = NewStore; 88 ToRemove.push_back(Store); 89 return; 90 } 91 92 if (auto *Load = dyn_cast<LoadInst>(&I); 93 Load && I.getType()->isIntegerTy(8)) { 94 SmallVector<Value *> NewOperands; 95 ProcessOperands(NewOperands); 96 Type *ElementType = NewOperands[0]->getType(); 97 if (auto *AI = dyn_cast<AllocaInst>(NewOperands[0])) 98 ElementType = AI->getAllocatedType(); 99 if (auto *GEP = dyn_cast<GetElementPtrInst>(NewOperands[0])) { 100 ElementType = GEP->getSourceElementType(); 101 if (ElementType->isArrayTy()) 102 ElementType = ElementType->getArrayElementType(); 103 } 104 LoadInst *NewLoad = Builder.CreateLoad(ElementType, NewOperands[0]); 105 ReplacedValues[Load] = NewLoad; 106 ToRemove.push_back(Load); 107 return; 108 } 109 110 if (auto *Load = dyn_cast<LoadInst>(&I); 111 Load && isa<ConstantExpr>(Load->getPointerOperand())) { 112 auto *CE = dyn_cast<ConstantExpr>(Load->getPointerOperand()); 113 if (!(CE->getOpcode() == Instruction::GetElementPtr)) 114 return; 115 auto *GEP = dyn_cast<GEPOperator>(CE); 116 if (!GEP->getSourceElementType()->isIntegerTy(8)) 117 return; 118 119 Type *ElementType = Load->getType(); 120 ConstantInt *Offset = dyn_cast<ConstantInt>(GEP->getOperand(1)); 121 uint32_t ByteOffset = Offset->getZExtValue(); 122 uint32_t ElemSize = Load->getDataLayout().getTypeAllocSize(ElementType); 123 uint32_t Index = ByteOffset / ElemSize; 124 125 Value *PtrOperand = GEP->getPointerOperand(); 126 Type *GEPType = GEP->getPointerOperandType(); 127 128 if (auto *GV = dyn_cast<GlobalVariable>(PtrOperand)) 129 GEPType = GV->getValueType(); 130 if (auto *AI = dyn_cast<AllocaInst>(PtrOperand)) 131 GEPType = AI->getAllocatedType(); 132 133 if (auto *ArrTy = dyn_cast<ArrayType>(GEPType)) 134 GEPType = ArrTy; 135 else 136 GEPType = ArrayType::get(ElementType, 1); // its a scalar 137 138 Value *NewGEP = Builder.CreateGEP( 139 GEPType, PtrOperand, {Builder.getInt32(0), Builder.getInt32(Index)}, 140 GEP->getName(), GEP->getNoWrapFlags()); 141 142 LoadInst *NewLoad = Builder.CreateLoad(ElementType, NewGEP); 143 ReplacedValues[Load] = NewLoad; 144 Load->replaceAllUsesWith(NewLoad); 145 ToRemove.push_back(Load); 146 return; 147 } 148 149 if (auto *BO = dyn_cast<BinaryOperator>(&I)) { 150 if (!I.getType()->isIntegerTy(8)) 151 return; 152 SmallVector<Value *> NewOperands; 153 ProcessOperands(NewOperands); 154 Value *NewInst = 155 Builder.CreateBinOp(BO->getOpcode(), NewOperands[0], NewOperands[1]); 156 if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(&I)) { 157 auto *NewBO = dyn_cast<BinaryOperator>(NewInst); 158 if (NewBO && OBO->hasNoSignedWrap()) 159 NewBO->setHasNoSignedWrap(); 160 if (NewBO && OBO->hasNoUnsignedWrap()) 161 NewBO->setHasNoUnsignedWrap(); 162 } 163 ReplacedValues[BO] = NewInst; 164 ToRemove.push_back(BO); 165 return; 166 } 167 168 if (auto *Sel = dyn_cast<SelectInst>(&I)) { 169 if (!I.getType()->isIntegerTy(8)) 170 return; 171 SmallVector<Value *> NewOperands; 172 ProcessOperands(NewOperands); 173 Value *NewInst = Builder.CreateSelect(Sel->getCondition(), NewOperands[1], 174 NewOperands[2]); 175 ReplacedValues[Sel] = NewInst; 176 ToRemove.push_back(Sel); 177 return; 178 } 179 180 if (auto *Cmp = dyn_cast<CmpInst>(&I)) { 181 if (!Cmp->getOperand(0)->getType()->isIntegerTy(8)) 182 return; 183 SmallVector<Value *> NewOperands; 184 ProcessOperands(NewOperands); 185 Value *NewInst = 186 Builder.CreateCmp(Cmp->getPredicate(), NewOperands[0], NewOperands[1]); 187 Cmp->replaceAllUsesWith(NewInst); 188 ReplacedValues[Cmp] = NewInst; 189 ToRemove.push_back(Cmp); 190 return; 191 } 192 193 if (auto *Cast = dyn_cast<CastInst>(&I)) { 194 if (!Cast->getSrcTy()->isIntegerTy(8)) 195 return; 196 197 ToRemove.push_back(Cast); 198 auto *Replacement = ReplacedValues[Cast->getOperand(0)]; 199 if (Cast->getType() == Replacement->getType()) { 200 Cast->replaceAllUsesWith(Replacement); 201 return; 202 } 203 204 Value *AdjustedCast = nullptr; 205 if (Cast->getOpcode() == Instruction::ZExt) 206 AdjustedCast = Builder.CreateZExtOrTrunc(Replacement, Cast->getType()); 207 if (Cast->getOpcode() == Instruction::SExt) 208 AdjustedCast = Builder.CreateSExtOrTrunc(Replacement, Cast->getType()); 209 210 if (AdjustedCast) 211 Cast->replaceAllUsesWith(AdjustedCast); 212 } 213 if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) { 214 if (!GEP->getType()->isPointerTy() || 215 !GEP->getSourceElementType()->isIntegerTy(8)) 216 return; 217 218 Value *BasePtr = GEP->getPointerOperand(); 219 if (ReplacedValues.count(BasePtr)) 220 BasePtr = ReplacedValues[BasePtr]; 221 222 Type *ElementType = BasePtr->getType(); 223 224 if (auto *AI = dyn_cast<AllocaInst>(BasePtr)) 225 ElementType = AI->getAllocatedType(); 226 if (auto *GV = dyn_cast<GlobalVariable>(BasePtr)) 227 ElementType = GV->getValueType(); 228 229 Type *GEPType = ElementType; 230 if (auto *ArrTy = dyn_cast<ArrayType>(ElementType)) 231 ElementType = ArrTy->getArrayElementType(); 232 else 233 GEPType = ArrayType::get(ElementType, 1); // its a scalar 234 235 ConstantInt *Offset = dyn_cast<ConstantInt>(GEP->getOperand(1)); 236 // Note: i8 to i32 offset conversion without emitting IR requires constant 237 // ints. Since offset conversion is common, we can safely assume Offset is 238 // always a ConstantInt, so no need to have a conditional bail out on 239 // nullptr, instead assert this is the case. 240 assert(Offset && "Offset is expected to be a ConstantInt"); 241 uint32_t ByteOffset = Offset->getZExtValue(); 242 uint32_t ElemSize = GEP->getDataLayout().getTypeAllocSize(ElementType); 243 assert(ElemSize > 0 && "ElementSize must be set"); 244 uint32_t Index = ByteOffset / ElemSize; 245 Value *NewGEP = Builder.CreateGEP( 246 GEPType, BasePtr, {Builder.getInt32(0), Builder.getInt32(Index)}, 247 GEP->getName(), GEP->getNoWrapFlags()); 248 ReplacedValues[GEP] = NewGEP; 249 GEP->replaceAllUsesWith(NewGEP); 250 ToRemove.push_back(GEP); 251 } 252 } 253 254 static void upcastI8AllocasAndUses(Instruction &I, 255 SmallVectorImpl<Instruction *> &ToRemove, 256 DenseMap<Value *, Value *> &ReplacedValues) { 257 auto *AI = dyn_cast<AllocaInst>(&I); 258 if (!AI || !AI->getAllocatedType()->isIntegerTy(8)) 259 return; 260 261 Type *SmallestType = nullptr; 262 263 auto ProcessLoad = [&](LoadInst *Load) { 264 for (User *LU : Load->users()) { 265 Type *Ty = nullptr; 266 if (CastInst *Cast = dyn_cast<CastInst>(LU)) 267 Ty = Cast->getType(); 268 else if (CallInst *CI = dyn_cast<CallInst>(LU)) { 269 if (CI->getIntrinsicID() == Intrinsic::memset) 270 Ty = Type::getInt32Ty(CI->getContext()); 271 } 272 273 if (!Ty) 274 continue; 275 276 if (!SmallestType || 277 Ty->getPrimitiveSizeInBits() < SmallestType->getPrimitiveSizeInBits()) 278 SmallestType = Ty; 279 } 280 }; 281 282 for (User *U : AI->users()) { 283 if (auto *Load = dyn_cast<LoadInst>(U)) 284 ProcessLoad(Load); 285 else if (auto *GEP = dyn_cast<GetElementPtrInst>(U)) { 286 for (User *GU : GEP->users()) { 287 if (auto *Load = dyn_cast<LoadInst>(GU)) 288 ProcessLoad(Load); 289 } 290 } 291 } 292 293 if (!SmallestType) 294 return; // no valid casts found 295 296 // Replace alloca 297 IRBuilder<> Builder(AI); 298 auto *NewAlloca = Builder.CreateAlloca(SmallestType); 299 ReplacedValues[AI] = NewAlloca; 300 ToRemove.push_back(AI); 301 } 302 303 static void 304 downcastI64toI32InsertExtractElements(Instruction &I, 305 SmallVectorImpl<Instruction *> &ToRemove, 306 DenseMap<Value *, Value *> &) { 307 308 if (auto *Extract = dyn_cast<ExtractElementInst>(&I)) { 309 Value *Idx = Extract->getIndexOperand(); 310 auto *CI = dyn_cast<ConstantInt>(Idx); 311 if (CI && CI->getBitWidth() == 64) { 312 IRBuilder<> Builder(Extract); 313 int64_t IndexValue = CI->getSExtValue(); 314 auto *Idx32 = 315 ConstantInt::get(Type::getInt32Ty(I.getContext()), IndexValue); 316 Value *NewExtract = Builder.CreateExtractElement( 317 Extract->getVectorOperand(), Idx32, Extract->getName()); 318 319 Extract->replaceAllUsesWith(NewExtract); 320 ToRemove.push_back(Extract); 321 } 322 } 323 324 if (auto *Insert = dyn_cast<InsertElementInst>(&I)) { 325 Value *Idx = Insert->getOperand(2); 326 auto *CI = dyn_cast<ConstantInt>(Idx); 327 if (CI && CI->getBitWidth() == 64) { 328 int64_t IndexValue = CI->getSExtValue(); 329 auto *Idx32 = 330 ConstantInt::get(Type::getInt32Ty(I.getContext()), IndexValue); 331 IRBuilder<> Builder(Insert); 332 Value *Insert32Index = Builder.CreateInsertElement( 333 Insert->getOperand(0), Insert->getOperand(1), Idx32, 334 Insert->getName()); 335 336 Insert->replaceAllUsesWith(Insert32Index); 337 ToRemove.push_back(Insert); 338 } 339 } 340 } 341 342 static void emitMemcpyExpansion(IRBuilder<> &Builder, Value *Dst, Value *Src, 343 ConstantInt *Length) { 344 345 uint64_t ByteLength = Length->getZExtValue(); 346 // If length to copy is zero, no memcpy is needed. 347 if (ByteLength == 0) 348 return; 349 350 LLVMContext &Ctx = Builder.getContext(); 351 const DataLayout &DL = Builder.GetInsertBlock()->getModule()->getDataLayout(); 352 353 auto GetArrTyFromVal = [](Value *Val) -> ArrayType * { 354 assert(isa<AllocaInst>(Val) || 355 isa<GlobalVariable>(Val) && 356 "Expected Val to be an Alloca or Global Variable"); 357 if (auto *Alloca = dyn_cast<AllocaInst>(Val)) 358 return dyn_cast<ArrayType>(Alloca->getAllocatedType()); 359 if (auto *GlobalVar = dyn_cast<GlobalVariable>(Val)) 360 return dyn_cast<ArrayType>(GlobalVar->getValueType()); 361 return nullptr; 362 }; 363 364 ArrayType *DstArrTy = GetArrTyFromVal(Dst); 365 assert(DstArrTy && "Expected Dst of memcpy to be a Pointer to an Array Type"); 366 if (auto *DstGlobalVar = dyn_cast<GlobalVariable>(Dst)) 367 assert(!DstGlobalVar->isConstant() && 368 "The Dst of memcpy must not be a constant Global Variable"); 369 [[maybe_unused]] ArrayType *SrcArrTy = GetArrTyFromVal(Src); 370 assert(SrcArrTy && "Expected Src of memcpy to be a Pointer to an Array Type"); 371 372 Type *DstElemTy = DstArrTy->getElementType(); 373 uint64_t DstElemByteSize = DL.getTypeStoreSize(DstElemTy); 374 assert(DstElemByteSize > 0 && "Dst element type store size must be set"); 375 Type *SrcElemTy = SrcArrTy->getElementType(); 376 [[maybe_unused]] uint64_t SrcElemByteSize = DL.getTypeStoreSize(SrcElemTy); 377 assert(SrcElemByteSize > 0 && "Src element type store size must be set"); 378 379 // This assumption simplifies implementation and covers currently-known 380 // use-cases for DXIL. It may be relaxed in the future if required. 381 assert(DstElemTy == SrcElemTy && 382 "The element types of Src and Dst arrays must match"); 383 384 [[maybe_unused]] uint64_t DstArrNumElems = DstArrTy->getArrayNumElements(); 385 assert(DstElemByteSize * DstArrNumElems >= ByteLength && 386 "Dst array size must be at least as large as the memcpy length"); 387 [[maybe_unused]] uint64_t SrcArrNumElems = SrcArrTy->getArrayNumElements(); 388 assert(SrcElemByteSize * SrcArrNumElems >= ByteLength && 389 "Src array size must be at least as large as the memcpy length"); 390 391 uint64_t NumElemsToCopy = ByteLength / DstElemByteSize; 392 assert(ByteLength % DstElemByteSize == 0 && 393 "memcpy length must be divisible by array element type"); 394 for (uint64_t I = 0; I < NumElemsToCopy; ++I) { 395 Value *Offset = ConstantInt::get(Type::getInt32Ty(Ctx), I); 396 Value *SrcPtr = Builder.CreateInBoundsGEP(SrcElemTy, Src, Offset, "gep"); 397 Value *SrcVal = Builder.CreateLoad(SrcElemTy, SrcPtr); 398 Value *DstPtr = Builder.CreateInBoundsGEP(DstElemTy, Dst, Offset, "gep"); 399 Builder.CreateStore(SrcVal, DstPtr); 400 } 401 } 402 403 static void emitMemsetExpansion(IRBuilder<> &Builder, Value *Dst, Value *Val, 404 ConstantInt *SizeCI, 405 DenseMap<Value *, Value *> &ReplacedValues) { 406 LLVMContext &Ctx = Builder.getContext(); 407 [[maybe_unused]] const DataLayout &DL = 408 Builder.GetInsertBlock()->getModule()->getDataLayout(); 409 [[maybe_unused]] uint64_t OrigSize = SizeCI->getZExtValue(); 410 411 AllocaInst *Alloca = dyn_cast<AllocaInst>(Dst); 412 413 assert(Alloca && "Expected memset on an Alloca"); 414 assert(OrigSize == Alloca->getAllocationSize(DL)->getFixedValue() && 415 "Expected for memset size to match DataLayout size"); 416 417 Type *AllocatedTy = Alloca->getAllocatedType(); 418 ArrayType *ArrTy = dyn_cast<ArrayType>(AllocatedTy); 419 assert(ArrTy && "Expected Alloca for an Array Type"); 420 421 Type *ElemTy = ArrTy->getElementType(); 422 uint64_t Size = ArrTy->getArrayNumElements(); 423 424 [[maybe_unused]] uint64_t ElemSize = DL.getTypeStoreSize(ElemTy); 425 426 assert(ElemSize > 0 && "Size must be set"); 427 assert(OrigSize == ElemSize * Size && "Size in bytes must match"); 428 429 Value *TypedVal = Val; 430 431 if (Val->getType() != ElemTy) { 432 if (ReplacedValues[Val]) { 433 // Note for i8 replacements if we know them we should use them. 434 // Further if this is a constant ReplacedValues will return null 435 // so we will stick to TypedVal = Val 436 TypedVal = ReplacedValues[Val]; 437 438 } else { 439 // This case Val is a ConstantInt so the cast folds away. 440 // However if we don't do the cast the store below ends up being 441 // an i8. 442 TypedVal = Builder.CreateIntCast(Val, ElemTy, false); 443 } 444 } 445 446 for (uint64_t I = 0; I < Size; ++I) { 447 Value *Offset = ConstantInt::get(Type::getInt32Ty(Ctx), I); 448 Value *Ptr = Builder.CreateGEP(ElemTy, Dst, Offset, "gep"); 449 Builder.CreateStore(TypedVal, Ptr); 450 } 451 } 452 453 // Expands the instruction `I` into corresponding loads and stores if it is a 454 // memcpy call. In that case, the call instruction is added to the `ToRemove` 455 // vector. `ReplacedValues` is unused. 456 static void legalizeMemCpy(Instruction &I, 457 SmallVectorImpl<Instruction *> &ToRemove, 458 DenseMap<Value *, Value *> &ReplacedValues) { 459 460 CallInst *CI = dyn_cast<CallInst>(&I); 461 if (!CI) 462 return; 463 464 Intrinsic::ID ID = CI->getIntrinsicID(); 465 if (ID != Intrinsic::memcpy) 466 return; 467 468 IRBuilder<> Builder(&I); 469 Value *Dst = CI->getArgOperand(0); 470 Value *Src = CI->getArgOperand(1); 471 ConstantInt *Length = dyn_cast<ConstantInt>(CI->getArgOperand(2)); 472 assert(Length && "Expected Length to be a ConstantInt"); 473 [[maybe_unused]] ConstantInt *IsVolatile = 474 dyn_cast<ConstantInt>(CI->getArgOperand(3)); 475 assert(IsVolatile && "Expected IsVolatile to be a ConstantInt"); 476 assert(IsVolatile->getZExtValue() == 0 && "Expected IsVolatile to be false"); 477 emitMemcpyExpansion(Builder, Dst, Src, Length); 478 ToRemove.push_back(CI); 479 } 480 481 static void removeMemSet(Instruction &I, 482 SmallVectorImpl<Instruction *> &ToRemove, 483 DenseMap<Value *, Value *> &ReplacedValues) { 484 485 CallInst *CI = dyn_cast<CallInst>(&I); 486 if (!CI) 487 return; 488 489 Intrinsic::ID ID = CI->getIntrinsicID(); 490 if (ID != Intrinsic::memset) 491 return; 492 493 IRBuilder<> Builder(&I); 494 Value *Dst = CI->getArgOperand(0); 495 Value *Val = CI->getArgOperand(1); 496 ConstantInt *Size = dyn_cast<ConstantInt>(CI->getArgOperand(2)); 497 assert(Size && "Expected Size to be a ConstantInt"); 498 emitMemsetExpansion(Builder, Dst, Val, Size, ReplacedValues); 499 ToRemove.push_back(CI); 500 } 501 502 static void updateFnegToFsub(Instruction &I, 503 SmallVectorImpl<Instruction *> &ToRemove, 504 DenseMap<Value *, Value *> &) { 505 const Intrinsic::ID ID = I.getOpcode(); 506 if (ID != Instruction::FNeg) 507 return; 508 509 IRBuilder<> Builder(&I); 510 Value *In = I.getOperand(0); 511 Value *Zero = ConstantFP::get(In->getType(), -0.0); 512 I.replaceAllUsesWith(Builder.CreateFSub(Zero, In)); 513 ToRemove.push_back(&I); 514 } 515 516 static void 517 legalizeGetHighLowi64Bytes(Instruction &I, 518 SmallVectorImpl<Instruction *> &ToRemove, 519 DenseMap<Value *, Value *> &ReplacedValues) { 520 if (auto *BitCast = dyn_cast<BitCastInst>(&I)) { 521 if (BitCast->getDestTy() == 522 FixedVectorType::get(Type::getInt32Ty(I.getContext()), 2) && 523 BitCast->getSrcTy()->isIntegerTy(64)) { 524 ToRemove.push_back(BitCast); 525 ReplacedValues[BitCast] = BitCast->getOperand(0); 526 return; 527 } 528 } 529 530 if (auto *Extract = dyn_cast<ExtractElementInst>(&I)) { 531 if (!dyn_cast<BitCastInst>(Extract->getVectorOperand())) 532 return; 533 auto *VecTy = dyn_cast<FixedVectorType>(Extract->getVectorOperandType()); 534 if (VecTy && VecTy->getElementType()->isIntegerTy(32) && 535 VecTy->getNumElements() == 2) { 536 if (auto *Index = dyn_cast<ConstantInt>(Extract->getIndexOperand())) { 537 unsigned Idx = Index->getZExtValue(); 538 IRBuilder<> Builder(&I); 539 540 auto *Replacement = ReplacedValues[Extract->getVectorOperand()]; 541 assert(Replacement && "The BitCast replacement should have been set " 542 "before working on ExtractElementInst."); 543 if (Idx == 0) { 544 Value *LowBytes = Builder.CreateTrunc( 545 Replacement, Type::getInt32Ty(I.getContext())); 546 ReplacedValues[Extract] = LowBytes; 547 } else { 548 assert(Idx == 1); 549 Value *LogicalShiftRight = Builder.CreateLShr( 550 Replacement, 551 ConstantInt::get( 552 Replacement->getType(), 553 APInt(Replacement->getType()->getIntegerBitWidth(), 32))); 554 Value *HighBytes = Builder.CreateTrunc( 555 LogicalShiftRight, Type::getInt32Ty(I.getContext())); 556 ReplacedValues[Extract] = HighBytes; 557 } 558 ToRemove.push_back(Extract); 559 Extract->replaceAllUsesWith(ReplacedValues[Extract]); 560 } 561 } 562 } 563 } 564 565 namespace { 566 class DXILLegalizationPipeline { 567 568 public: 569 DXILLegalizationPipeline() { initializeLegalizationPipeline(); } 570 571 bool runLegalizationPipeline(Function &F) { 572 bool MadeChange = false; 573 SmallVector<Instruction *> ToRemove; 574 DenseMap<Value *, Value *> ReplacedValues; 575 for (int Stage = 0; Stage < NumStages; ++Stage) { 576 ToRemove.clear(); 577 ReplacedValues.clear(); 578 for (auto &I : instructions(F)) { 579 for (auto &LegalizationFn : LegalizationPipeline[Stage]) 580 LegalizationFn(I, ToRemove, ReplacedValues); 581 } 582 583 for (auto *Inst : reverse(ToRemove)) 584 Inst->eraseFromParent(); 585 586 MadeChange |= !ToRemove.empty(); 587 } 588 return MadeChange; 589 } 590 591 private: 592 enum LegalizationStage { Stage1 = 0, Stage2 = 1, NumStages }; 593 594 using LegalizationFnTy = 595 std::function<void(Instruction &, SmallVectorImpl<Instruction *> &, 596 DenseMap<Value *, Value *> &)>; 597 598 SmallVector<LegalizationFnTy> LegalizationPipeline[NumStages]; 599 600 void initializeLegalizationPipeline() { 601 LegalizationPipeline[Stage1].push_back(upcastI8AllocasAndUses); 602 LegalizationPipeline[Stage1].push_back(fixI8UseChain); 603 LegalizationPipeline[Stage1].push_back(legalizeGetHighLowi64Bytes); 604 LegalizationPipeline[Stage1].push_back(legalizeFreeze); 605 LegalizationPipeline[Stage1].push_back(legalizeMemCpy); 606 LegalizationPipeline[Stage1].push_back(removeMemSet); 607 LegalizationPipeline[Stage1].push_back(updateFnegToFsub); 608 // Note: legalizeGetHighLowi64Bytes and 609 // downcastI64toI32InsertExtractElements both modify extractelement, so they 610 // must run staggered stages. legalizeGetHighLowi64Bytes runs first b\c it 611 // removes extractelements, reducing the number that 612 // downcastI64toI32InsertExtractElements needs to handle. 613 LegalizationPipeline[Stage2].push_back( 614 downcastI64toI32InsertExtractElements); 615 } 616 }; 617 618 class DXILLegalizeLegacy : public FunctionPass { 619 620 public: 621 bool runOnFunction(Function &F) override; 622 DXILLegalizeLegacy() : FunctionPass(ID) {} 623 624 static char ID; // Pass identification. 625 }; 626 } // namespace 627 628 PreservedAnalyses DXILLegalizePass::run(Function &F, 629 FunctionAnalysisManager &FAM) { 630 DXILLegalizationPipeline DXLegalize; 631 bool MadeChanges = DXLegalize.runLegalizationPipeline(F); 632 if (!MadeChanges) 633 return PreservedAnalyses::all(); 634 PreservedAnalyses PA; 635 return PA; 636 } 637 638 bool DXILLegalizeLegacy::runOnFunction(Function &F) { 639 DXILLegalizationPipeline DXLegalize; 640 return DXLegalize.runLegalizationPipeline(F); 641 } 642 643 char DXILLegalizeLegacy::ID = 0; 644 645 INITIALIZE_PASS_BEGIN(DXILLegalizeLegacy, DEBUG_TYPE, "DXIL Legalizer", false, 646 false) 647 INITIALIZE_PASS_END(DXILLegalizeLegacy, DEBUG_TYPE, "DXIL Legalizer", false, 648 false) 649 650 FunctionPass *llvm::createDXILLegalizeLegacyPass() { 651 return new DXILLegalizeLegacy(); 652 } 653