1 //===- DXILLegalizePass.cpp - Legalizes llvm IR for DXIL ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===---------------------------------------------------------------------===//
8
9 #include "DXILLegalizePass.h"
10 #include "DirectX.h"
11 #include "llvm/ADT/APInt.h"
12 #include "llvm/IR/Constants.h"
13 #include "llvm/IR/Function.h"
14 #include "llvm/IR/IRBuilder.h"
15 #include "llvm/IR/InstIterator.h"
16 #include "llvm/IR/Instruction.h"
17 #include "llvm/IR/Instructions.h"
18 #include "llvm/IR/Module.h"
19 #include "llvm/Pass.h"
20 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
21 #include <functional>
22
23 #define DEBUG_TYPE "dxil-legalize"
24
25 using namespace llvm;
26
legalizeFreeze(Instruction & I,SmallVectorImpl<Instruction * > & ToRemove,DenseMap<Value *,Value * >)27 static void legalizeFreeze(Instruction &I,
28 SmallVectorImpl<Instruction *> &ToRemove,
29 DenseMap<Value *, Value *>) {
30 auto *FI = dyn_cast<FreezeInst>(&I);
31 if (!FI)
32 return;
33
34 FI->replaceAllUsesWith(FI->getOperand(0));
35 ToRemove.push_back(FI);
36 }
37
fixI8UseChain(Instruction & I,SmallVectorImpl<Instruction * > & ToRemove,DenseMap<Value *,Value * > & ReplacedValues)38 static void fixI8UseChain(Instruction &I,
39 SmallVectorImpl<Instruction *> &ToRemove,
40 DenseMap<Value *, Value *> &ReplacedValues) {
41
42 auto ProcessOperands = [&](SmallVector<Value *> &NewOperands) {
43 Type *InstrType = IntegerType::get(I.getContext(), 32);
44
45 for (unsigned OpIdx = 0; OpIdx < I.getNumOperands(); ++OpIdx) {
46 Value *Op = I.getOperand(OpIdx);
47 if (ReplacedValues.count(Op) &&
48 ReplacedValues[Op]->getType()->isIntegerTy())
49 InstrType = ReplacedValues[Op]->getType();
50 }
51
52 for (unsigned OpIdx = 0; OpIdx < I.getNumOperands(); ++OpIdx) {
53 Value *Op = I.getOperand(OpIdx);
54 if (ReplacedValues.count(Op))
55 NewOperands.push_back(ReplacedValues[Op]);
56 else if (auto *Imm = dyn_cast<ConstantInt>(Op)) {
57 APInt Value = Imm->getValue();
58 unsigned NewBitWidth = InstrType->getIntegerBitWidth();
59 // Note: options here are sext or sextOrTrunc.
60 // Since i8 isn't supported, we assume new values
61 // will always have a higher bitness.
62 assert(NewBitWidth > Value.getBitWidth() &&
63 "Replacement's BitWidth should be larger than Current.");
64 APInt NewValue = Value.sext(NewBitWidth);
65 NewOperands.push_back(ConstantInt::get(InstrType, NewValue));
66 } else {
67 assert(!Op->getType()->isIntegerTy(8));
68 NewOperands.push_back(Op);
69 }
70 }
71 };
72 IRBuilder<> Builder(&I);
73 if (auto *Trunc = dyn_cast<TruncInst>(&I)) {
74 if (Trunc->getDestTy()->isIntegerTy(8)) {
75 ReplacedValues[Trunc] = Trunc->getOperand(0);
76 ToRemove.push_back(Trunc);
77 return;
78 }
79 }
80
81 if (auto *Store = dyn_cast<StoreInst>(&I)) {
82 if (!Store->getValueOperand()->getType()->isIntegerTy(8))
83 return;
84 SmallVector<Value *> NewOperands;
85 ProcessOperands(NewOperands);
86 Value *NewStore = Builder.CreateStore(NewOperands[0], NewOperands[1]);
87 ReplacedValues[Store] = NewStore;
88 ToRemove.push_back(Store);
89 return;
90 }
91
92 if (auto *Load = dyn_cast<LoadInst>(&I);
93 Load && I.getType()->isIntegerTy(8)) {
94 SmallVector<Value *> NewOperands;
95 ProcessOperands(NewOperands);
96 Type *ElementType = NewOperands[0]->getType();
97 if (auto *AI = dyn_cast<AllocaInst>(NewOperands[0]))
98 ElementType = AI->getAllocatedType();
99 if (auto *GEP = dyn_cast<GetElementPtrInst>(NewOperands[0])) {
100 ElementType = GEP->getSourceElementType();
101 if (ElementType->isArrayTy())
102 ElementType = ElementType->getArrayElementType();
103 }
104 LoadInst *NewLoad = Builder.CreateLoad(ElementType, NewOperands[0]);
105 ReplacedValues[Load] = NewLoad;
106 ToRemove.push_back(Load);
107 return;
108 }
109
110 if (auto *Load = dyn_cast<LoadInst>(&I);
111 Load && isa<ConstantExpr>(Load->getPointerOperand())) {
112 auto *CE = dyn_cast<ConstantExpr>(Load->getPointerOperand());
113 if (!(CE->getOpcode() == Instruction::GetElementPtr))
114 return;
115 auto *GEP = dyn_cast<GEPOperator>(CE);
116 if (!GEP->getSourceElementType()->isIntegerTy(8))
117 return;
118
119 Type *ElementType = Load->getType();
120 ConstantInt *Offset = dyn_cast<ConstantInt>(GEP->getOperand(1));
121 uint32_t ByteOffset = Offset->getZExtValue();
122 uint32_t ElemSize = Load->getDataLayout().getTypeAllocSize(ElementType);
123 uint32_t Index = ByteOffset / ElemSize;
124
125 Value *PtrOperand = GEP->getPointerOperand();
126 Type *GEPType = GEP->getPointerOperandType();
127
128 if (auto *GV = dyn_cast<GlobalVariable>(PtrOperand))
129 GEPType = GV->getValueType();
130 if (auto *AI = dyn_cast<AllocaInst>(PtrOperand))
131 GEPType = AI->getAllocatedType();
132
133 if (auto *ArrTy = dyn_cast<ArrayType>(GEPType))
134 GEPType = ArrTy;
135 else
136 GEPType = ArrayType::get(ElementType, 1); // its a scalar
137
138 Value *NewGEP = Builder.CreateGEP(
139 GEPType, PtrOperand, {Builder.getInt32(0), Builder.getInt32(Index)},
140 GEP->getName(), GEP->getNoWrapFlags());
141
142 LoadInst *NewLoad = Builder.CreateLoad(ElementType, NewGEP);
143 ReplacedValues[Load] = NewLoad;
144 Load->replaceAllUsesWith(NewLoad);
145 ToRemove.push_back(Load);
146 return;
147 }
148
149 if (auto *BO = dyn_cast<BinaryOperator>(&I)) {
150 if (!I.getType()->isIntegerTy(8))
151 return;
152 SmallVector<Value *> NewOperands;
153 ProcessOperands(NewOperands);
154 Value *NewInst =
155 Builder.CreateBinOp(BO->getOpcode(), NewOperands[0], NewOperands[1]);
156 if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(&I)) {
157 auto *NewBO = dyn_cast<BinaryOperator>(NewInst);
158 if (NewBO && OBO->hasNoSignedWrap())
159 NewBO->setHasNoSignedWrap();
160 if (NewBO && OBO->hasNoUnsignedWrap())
161 NewBO->setHasNoUnsignedWrap();
162 }
163 ReplacedValues[BO] = NewInst;
164 ToRemove.push_back(BO);
165 return;
166 }
167
168 if (auto *Sel = dyn_cast<SelectInst>(&I)) {
169 if (!I.getType()->isIntegerTy(8))
170 return;
171 SmallVector<Value *> NewOperands;
172 ProcessOperands(NewOperands);
173 Value *NewInst = Builder.CreateSelect(Sel->getCondition(), NewOperands[1],
174 NewOperands[2]);
175 ReplacedValues[Sel] = NewInst;
176 ToRemove.push_back(Sel);
177 return;
178 }
179
180 if (auto *Cmp = dyn_cast<CmpInst>(&I)) {
181 if (!Cmp->getOperand(0)->getType()->isIntegerTy(8))
182 return;
183 SmallVector<Value *> NewOperands;
184 ProcessOperands(NewOperands);
185 Value *NewInst =
186 Builder.CreateCmp(Cmp->getPredicate(), NewOperands[0], NewOperands[1]);
187 Cmp->replaceAllUsesWith(NewInst);
188 ReplacedValues[Cmp] = NewInst;
189 ToRemove.push_back(Cmp);
190 return;
191 }
192
193 if (auto *Cast = dyn_cast<CastInst>(&I)) {
194 if (!Cast->getSrcTy()->isIntegerTy(8))
195 return;
196
197 ToRemove.push_back(Cast);
198 auto *Replacement = ReplacedValues[Cast->getOperand(0)];
199 if (Cast->getType() == Replacement->getType()) {
200 Cast->replaceAllUsesWith(Replacement);
201 return;
202 }
203
204 Value *AdjustedCast = nullptr;
205 if (Cast->getOpcode() == Instruction::ZExt)
206 AdjustedCast = Builder.CreateZExtOrTrunc(Replacement, Cast->getType());
207 if (Cast->getOpcode() == Instruction::SExt)
208 AdjustedCast = Builder.CreateSExtOrTrunc(Replacement, Cast->getType());
209
210 if (AdjustedCast)
211 Cast->replaceAllUsesWith(AdjustedCast);
212 }
213 if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) {
214 if (!GEP->getType()->isPointerTy() ||
215 !GEP->getSourceElementType()->isIntegerTy(8))
216 return;
217
218 Value *BasePtr = GEP->getPointerOperand();
219 if (ReplacedValues.count(BasePtr))
220 BasePtr = ReplacedValues[BasePtr];
221
222 Type *ElementType = BasePtr->getType();
223
224 if (auto *AI = dyn_cast<AllocaInst>(BasePtr))
225 ElementType = AI->getAllocatedType();
226 if (auto *GV = dyn_cast<GlobalVariable>(BasePtr))
227 ElementType = GV->getValueType();
228
229 Type *GEPType = ElementType;
230 if (auto *ArrTy = dyn_cast<ArrayType>(ElementType))
231 ElementType = ArrTy->getArrayElementType();
232 else
233 GEPType = ArrayType::get(ElementType, 1); // its a scalar
234
235 ConstantInt *Offset = dyn_cast<ConstantInt>(GEP->getOperand(1));
236 // Note: i8 to i32 offset conversion without emitting IR requires constant
237 // ints. Since offset conversion is common, we can safely assume Offset is
238 // always a ConstantInt, so no need to have a conditional bail out on
239 // nullptr, instead assert this is the case.
240 assert(Offset && "Offset is expected to be a ConstantInt");
241 uint32_t ByteOffset = Offset->getZExtValue();
242 uint32_t ElemSize = GEP->getDataLayout().getTypeAllocSize(ElementType);
243 assert(ElemSize > 0 && "ElementSize must be set");
244 uint32_t Index = ByteOffset / ElemSize;
245 Value *NewGEP = Builder.CreateGEP(
246 GEPType, BasePtr, {Builder.getInt32(0), Builder.getInt32(Index)},
247 GEP->getName(), GEP->getNoWrapFlags());
248 ReplacedValues[GEP] = NewGEP;
249 GEP->replaceAllUsesWith(NewGEP);
250 ToRemove.push_back(GEP);
251 }
252 }
253
upcastI8AllocasAndUses(Instruction & I,SmallVectorImpl<Instruction * > & ToRemove,DenseMap<Value *,Value * > & ReplacedValues)254 static void upcastI8AllocasAndUses(Instruction &I,
255 SmallVectorImpl<Instruction *> &ToRemove,
256 DenseMap<Value *, Value *> &ReplacedValues) {
257 auto *AI = dyn_cast<AllocaInst>(&I);
258 if (!AI || !AI->getAllocatedType()->isIntegerTy(8))
259 return;
260
261 Type *SmallestType = nullptr;
262
263 auto ProcessLoad = [&](LoadInst *Load) {
264 for (User *LU : Load->users()) {
265 Type *Ty = nullptr;
266 if (CastInst *Cast = dyn_cast<CastInst>(LU))
267 Ty = Cast->getType();
268 else if (CallInst *CI = dyn_cast<CallInst>(LU)) {
269 if (CI->getIntrinsicID() == Intrinsic::memset)
270 Ty = Type::getInt32Ty(CI->getContext());
271 }
272
273 if (!Ty)
274 continue;
275
276 if (!SmallestType ||
277 Ty->getPrimitiveSizeInBits() < SmallestType->getPrimitiveSizeInBits())
278 SmallestType = Ty;
279 }
280 };
281
282 for (User *U : AI->users()) {
283 if (auto *Load = dyn_cast<LoadInst>(U))
284 ProcessLoad(Load);
285 else if (auto *GEP = dyn_cast<GetElementPtrInst>(U)) {
286 for (User *GU : GEP->users()) {
287 if (auto *Load = dyn_cast<LoadInst>(GU))
288 ProcessLoad(Load);
289 }
290 }
291 }
292
293 if (!SmallestType)
294 return; // no valid casts found
295
296 // Replace alloca
297 IRBuilder<> Builder(AI);
298 auto *NewAlloca = Builder.CreateAlloca(SmallestType);
299 ReplacedValues[AI] = NewAlloca;
300 ToRemove.push_back(AI);
301 }
302
303 static void
downcastI64toI32InsertExtractElements(Instruction & I,SmallVectorImpl<Instruction * > & ToRemove,DenseMap<Value *,Value * > &)304 downcastI64toI32InsertExtractElements(Instruction &I,
305 SmallVectorImpl<Instruction *> &ToRemove,
306 DenseMap<Value *, Value *> &) {
307
308 if (auto *Extract = dyn_cast<ExtractElementInst>(&I)) {
309 Value *Idx = Extract->getIndexOperand();
310 auto *CI = dyn_cast<ConstantInt>(Idx);
311 if (CI && CI->getBitWidth() == 64) {
312 IRBuilder<> Builder(Extract);
313 int64_t IndexValue = CI->getSExtValue();
314 auto *Idx32 =
315 ConstantInt::get(Type::getInt32Ty(I.getContext()), IndexValue);
316 Value *NewExtract = Builder.CreateExtractElement(
317 Extract->getVectorOperand(), Idx32, Extract->getName());
318
319 Extract->replaceAllUsesWith(NewExtract);
320 ToRemove.push_back(Extract);
321 }
322 }
323
324 if (auto *Insert = dyn_cast<InsertElementInst>(&I)) {
325 Value *Idx = Insert->getOperand(2);
326 auto *CI = dyn_cast<ConstantInt>(Idx);
327 if (CI && CI->getBitWidth() == 64) {
328 int64_t IndexValue = CI->getSExtValue();
329 auto *Idx32 =
330 ConstantInt::get(Type::getInt32Ty(I.getContext()), IndexValue);
331 IRBuilder<> Builder(Insert);
332 Value *Insert32Index = Builder.CreateInsertElement(
333 Insert->getOperand(0), Insert->getOperand(1), Idx32,
334 Insert->getName());
335
336 Insert->replaceAllUsesWith(Insert32Index);
337 ToRemove.push_back(Insert);
338 }
339 }
340 }
341
emitMemcpyExpansion(IRBuilder<> & Builder,Value * Dst,Value * Src,ConstantInt * Length)342 static void emitMemcpyExpansion(IRBuilder<> &Builder, Value *Dst, Value *Src,
343 ConstantInt *Length) {
344
345 uint64_t ByteLength = Length->getZExtValue();
346 // If length to copy is zero, no memcpy is needed.
347 if (ByteLength == 0)
348 return;
349
350 LLVMContext &Ctx = Builder.getContext();
351 const DataLayout &DL = Builder.GetInsertBlock()->getModule()->getDataLayout();
352
353 auto GetArrTyFromVal = [](Value *Val) -> ArrayType * {
354 assert(isa<AllocaInst>(Val) ||
355 isa<GlobalVariable>(Val) &&
356 "Expected Val to be an Alloca or Global Variable");
357 if (auto *Alloca = dyn_cast<AllocaInst>(Val))
358 return dyn_cast<ArrayType>(Alloca->getAllocatedType());
359 if (auto *GlobalVar = dyn_cast<GlobalVariable>(Val))
360 return dyn_cast<ArrayType>(GlobalVar->getValueType());
361 return nullptr;
362 };
363
364 ArrayType *DstArrTy = GetArrTyFromVal(Dst);
365 assert(DstArrTy && "Expected Dst of memcpy to be a Pointer to an Array Type");
366 if (auto *DstGlobalVar = dyn_cast<GlobalVariable>(Dst))
367 assert(!DstGlobalVar->isConstant() &&
368 "The Dst of memcpy must not be a constant Global Variable");
369 [[maybe_unused]] ArrayType *SrcArrTy = GetArrTyFromVal(Src);
370 assert(SrcArrTy && "Expected Src of memcpy to be a Pointer to an Array Type");
371
372 Type *DstElemTy = DstArrTy->getElementType();
373 uint64_t DstElemByteSize = DL.getTypeStoreSize(DstElemTy);
374 assert(DstElemByteSize > 0 && "Dst element type store size must be set");
375 Type *SrcElemTy = SrcArrTy->getElementType();
376 [[maybe_unused]] uint64_t SrcElemByteSize = DL.getTypeStoreSize(SrcElemTy);
377 assert(SrcElemByteSize > 0 && "Src element type store size must be set");
378
379 // This assumption simplifies implementation and covers currently-known
380 // use-cases for DXIL. It may be relaxed in the future if required.
381 assert(DstElemTy == SrcElemTy &&
382 "The element types of Src and Dst arrays must match");
383
384 [[maybe_unused]] uint64_t DstArrNumElems = DstArrTy->getArrayNumElements();
385 assert(DstElemByteSize * DstArrNumElems >= ByteLength &&
386 "Dst array size must be at least as large as the memcpy length");
387 [[maybe_unused]] uint64_t SrcArrNumElems = SrcArrTy->getArrayNumElements();
388 assert(SrcElemByteSize * SrcArrNumElems >= ByteLength &&
389 "Src array size must be at least as large as the memcpy length");
390
391 uint64_t NumElemsToCopy = ByteLength / DstElemByteSize;
392 assert(ByteLength % DstElemByteSize == 0 &&
393 "memcpy length must be divisible by array element type");
394 for (uint64_t I = 0; I < NumElemsToCopy; ++I) {
395 Value *Offset = ConstantInt::get(Type::getInt32Ty(Ctx), I);
396 Value *SrcPtr = Builder.CreateInBoundsGEP(SrcElemTy, Src, Offset, "gep");
397 Value *SrcVal = Builder.CreateLoad(SrcElemTy, SrcPtr);
398 Value *DstPtr = Builder.CreateInBoundsGEP(DstElemTy, Dst, Offset, "gep");
399 Builder.CreateStore(SrcVal, DstPtr);
400 }
401 }
402
emitMemsetExpansion(IRBuilder<> & Builder,Value * Dst,Value * Val,ConstantInt * SizeCI,DenseMap<Value *,Value * > & ReplacedValues)403 static void emitMemsetExpansion(IRBuilder<> &Builder, Value *Dst, Value *Val,
404 ConstantInt *SizeCI,
405 DenseMap<Value *, Value *> &ReplacedValues) {
406 LLVMContext &Ctx = Builder.getContext();
407 [[maybe_unused]] const DataLayout &DL =
408 Builder.GetInsertBlock()->getModule()->getDataLayout();
409 [[maybe_unused]] uint64_t OrigSize = SizeCI->getZExtValue();
410
411 AllocaInst *Alloca = dyn_cast<AllocaInst>(Dst);
412
413 assert(Alloca && "Expected memset on an Alloca");
414 assert(OrigSize == Alloca->getAllocationSize(DL)->getFixedValue() &&
415 "Expected for memset size to match DataLayout size");
416
417 Type *AllocatedTy = Alloca->getAllocatedType();
418 ArrayType *ArrTy = dyn_cast<ArrayType>(AllocatedTy);
419 assert(ArrTy && "Expected Alloca for an Array Type");
420
421 Type *ElemTy = ArrTy->getElementType();
422 uint64_t Size = ArrTy->getArrayNumElements();
423
424 [[maybe_unused]] uint64_t ElemSize = DL.getTypeStoreSize(ElemTy);
425
426 assert(ElemSize > 0 && "Size must be set");
427 assert(OrigSize == ElemSize * Size && "Size in bytes must match");
428
429 Value *TypedVal = Val;
430
431 if (Val->getType() != ElemTy) {
432 if (ReplacedValues[Val]) {
433 // Note for i8 replacements if we know them we should use them.
434 // Further if this is a constant ReplacedValues will return null
435 // so we will stick to TypedVal = Val
436 TypedVal = ReplacedValues[Val];
437
438 } else {
439 // This case Val is a ConstantInt so the cast folds away.
440 // However if we don't do the cast the store below ends up being
441 // an i8.
442 TypedVal = Builder.CreateIntCast(Val, ElemTy, false);
443 }
444 }
445
446 for (uint64_t I = 0; I < Size; ++I) {
447 Value *Offset = ConstantInt::get(Type::getInt32Ty(Ctx), I);
448 Value *Ptr = Builder.CreateGEP(ElemTy, Dst, Offset, "gep");
449 Builder.CreateStore(TypedVal, Ptr);
450 }
451 }
452
453 // Expands the instruction `I` into corresponding loads and stores if it is a
454 // memcpy call. In that case, the call instruction is added to the `ToRemove`
455 // vector. `ReplacedValues` is unused.
legalizeMemCpy(Instruction & I,SmallVectorImpl<Instruction * > & ToRemove,DenseMap<Value *,Value * > & ReplacedValues)456 static void legalizeMemCpy(Instruction &I,
457 SmallVectorImpl<Instruction *> &ToRemove,
458 DenseMap<Value *, Value *> &ReplacedValues) {
459
460 CallInst *CI = dyn_cast<CallInst>(&I);
461 if (!CI)
462 return;
463
464 Intrinsic::ID ID = CI->getIntrinsicID();
465 if (ID != Intrinsic::memcpy)
466 return;
467
468 IRBuilder<> Builder(&I);
469 Value *Dst = CI->getArgOperand(0);
470 Value *Src = CI->getArgOperand(1);
471 ConstantInt *Length = dyn_cast<ConstantInt>(CI->getArgOperand(2));
472 assert(Length && "Expected Length to be a ConstantInt");
473 [[maybe_unused]] ConstantInt *IsVolatile =
474 dyn_cast<ConstantInt>(CI->getArgOperand(3));
475 assert(IsVolatile && "Expected IsVolatile to be a ConstantInt");
476 assert(IsVolatile->getZExtValue() == 0 && "Expected IsVolatile to be false");
477 emitMemcpyExpansion(Builder, Dst, Src, Length);
478 ToRemove.push_back(CI);
479 }
480
removeMemSet(Instruction & I,SmallVectorImpl<Instruction * > & ToRemove,DenseMap<Value *,Value * > & ReplacedValues)481 static void removeMemSet(Instruction &I,
482 SmallVectorImpl<Instruction *> &ToRemove,
483 DenseMap<Value *, Value *> &ReplacedValues) {
484
485 CallInst *CI = dyn_cast<CallInst>(&I);
486 if (!CI)
487 return;
488
489 Intrinsic::ID ID = CI->getIntrinsicID();
490 if (ID != Intrinsic::memset)
491 return;
492
493 IRBuilder<> Builder(&I);
494 Value *Dst = CI->getArgOperand(0);
495 Value *Val = CI->getArgOperand(1);
496 ConstantInt *Size = dyn_cast<ConstantInt>(CI->getArgOperand(2));
497 assert(Size && "Expected Size to be a ConstantInt");
498 emitMemsetExpansion(Builder, Dst, Val, Size, ReplacedValues);
499 ToRemove.push_back(CI);
500 }
501
updateFnegToFsub(Instruction & I,SmallVectorImpl<Instruction * > & ToRemove,DenseMap<Value *,Value * > &)502 static void updateFnegToFsub(Instruction &I,
503 SmallVectorImpl<Instruction *> &ToRemove,
504 DenseMap<Value *, Value *> &) {
505 const Intrinsic::ID ID = I.getOpcode();
506 if (ID != Instruction::FNeg)
507 return;
508
509 IRBuilder<> Builder(&I);
510 Value *In = I.getOperand(0);
511 Value *Zero = ConstantFP::get(In->getType(), -0.0);
512 I.replaceAllUsesWith(Builder.CreateFSub(Zero, In));
513 ToRemove.push_back(&I);
514 }
515
516 static void
legalizeGetHighLowi64Bytes(Instruction & I,SmallVectorImpl<Instruction * > & ToRemove,DenseMap<Value *,Value * > & ReplacedValues)517 legalizeGetHighLowi64Bytes(Instruction &I,
518 SmallVectorImpl<Instruction *> &ToRemove,
519 DenseMap<Value *, Value *> &ReplacedValues) {
520 if (auto *BitCast = dyn_cast<BitCastInst>(&I)) {
521 if (BitCast->getDestTy() ==
522 FixedVectorType::get(Type::getInt32Ty(I.getContext()), 2) &&
523 BitCast->getSrcTy()->isIntegerTy(64)) {
524 ToRemove.push_back(BitCast);
525 ReplacedValues[BitCast] = BitCast->getOperand(0);
526 return;
527 }
528 }
529
530 if (auto *Extract = dyn_cast<ExtractElementInst>(&I)) {
531 if (!dyn_cast<BitCastInst>(Extract->getVectorOperand()))
532 return;
533 auto *VecTy = dyn_cast<FixedVectorType>(Extract->getVectorOperandType());
534 if (VecTy && VecTy->getElementType()->isIntegerTy(32) &&
535 VecTy->getNumElements() == 2) {
536 if (auto *Index = dyn_cast<ConstantInt>(Extract->getIndexOperand())) {
537 unsigned Idx = Index->getZExtValue();
538 IRBuilder<> Builder(&I);
539
540 auto *Replacement = ReplacedValues[Extract->getVectorOperand()];
541 assert(Replacement && "The BitCast replacement should have been set "
542 "before working on ExtractElementInst.");
543 if (Idx == 0) {
544 Value *LowBytes = Builder.CreateTrunc(
545 Replacement, Type::getInt32Ty(I.getContext()));
546 ReplacedValues[Extract] = LowBytes;
547 } else {
548 assert(Idx == 1);
549 Value *LogicalShiftRight = Builder.CreateLShr(
550 Replacement,
551 ConstantInt::get(
552 Replacement->getType(),
553 APInt(Replacement->getType()->getIntegerBitWidth(), 32)));
554 Value *HighBytes = Builder.CreateTrunc(
555 LogicalShiftRight, Type::getInt32Ty(I.getContext()));
556 ReplacedValues[Extract] = HighBytes;
557 }
558 ToRemove.push_back(Extract);
559 Extract->replaceAllUsesWith(ReplacedValues[Extract]);
560 }
561 }
562 }
563 }
564
565 namespace {
566 class DXILLegalizationPipeline {
567
568 public:
DXILLegalizationPipeline()569 DXILLegalizationPipeline() { initializeLegalizationPipeline(); }
570
runLegalizationPipeline(Function & F)571 bool runLegalizationPipeline(Function &F) {
572 bool MadeChange = false;
573 SmallVector<Instruction *> ToRemove;
574 DenseMap<Value *, Value *> ReplacedValues;
575 for (int Stage = 0; Stage < NumStages; ++Stage) {
576 ToRemove.clear();
577 ReplacedValues.clear();
578 for (auto &I : instructions(F)) {
579 for (auto &LegalizationFn : LegalizationPipeline[Stage])
580 LegalizationFn(I, ToRemove, ReplacedValues);
581 }
582
583 for (auto *Inst : reverse(ToRemove))
584 Inst->eraseFromParent();
585
586 MadeChange |= !ToRemove.empty();
587 }
588 return MadeChange;
589 }
590
591 private:
592 enum LegalizationStage { Stage1 = 0, Stage2 = 1, NumStages };
593
594 using LegalizationFnTy =
595 std::function<void(Instruction &, SmallVectorImpl<Instruction *> &,
596 DenseMap<Value *, Value *> &)>;
597
598 SmallVector<LegalizationFnTy> LegalizationPipeline[NumStages];
599
initializeLegalizationPipeline()600 void initializeLegalizationPipeline() {
601 LegalizationPipeline[Stage1].push_back(upcastI8AllocasAndUses);
602 LegalizationPipeline[Stage1].push_back(fixI8UseChain);
603 LegalizationPipeline[Stage1].push_back(legalizeGetHighLowi64Bytes);
604 LegalizationPipeline[Stage1].push_back(legalizeFreeze);
605 LegalizationPipeline[Stage1].push_back(legalizeMemCpy);
606 LegalizationPipeline[Stage1].push_back(removeMemSet);
607 LegalizationPipeline[Stage1].push_back(updateFnegToFsub);
608 // Note: legalizeGetHighLowi64Bytes and
609 // downcastI64toI32InsertExtractElements both modify extractelement, so they
610 // must run staggered stages. legalizeGetHighLowi64Bytes runs first b\c it
611 // removes extractelements, reducing the number that
612 // downcastI64toI32InsertExtractElements needs to handle.
613 LegalizationPipeline[Stage2].push_back(
614 downcastI64toI32InsertExtractElements);
615 }
616 };
617
618 class DXILLegalizeLegacy : public FunctionPass {
619
620 public:
621 bool runOnFunction(Function &F) override;
DXILLegalizeLegacy()622 DXILLegalizeLegacy() : FunctionPass(ID) {}
623
624 static char ID; // Pass identification.
625 };
626 } // namespace
627
run(Function & F,FunctionAnalysisManager & FAM)628 PreservedAnalyses DXILLegalizePass::run(Function &F,
629 FunctionAnalysisManager &FAM) {
630 DXILLegalizationPipeline DXLegalize;
631 bool MadeChanges = DXLegalize.runLegalizationPipeline(F);
632 if (!MadeChanges)
633 return PreservedAnalyses::all();
634 PreservedAnalyses PA;
635 return PA;
636 }
637
runOnFunction(Function & F)638 bool DXILLegalizeLegacy::runOnFunction(Function &F) {
639 DXILLegalizationPipeline DXLegalize;
640 return DXLegalize.runLegalizationPipeline(F);
641 }
642
// Definition of the legacy pass identifier that the DXILLegalizeLegacy
// constructor hands to FunctionPass.
char DXILLegalizeLegacy::ID = 0;

// Legacy pass registration boilerplate: exposes DXILLegalizeLegacy under the
// DEBUG_TYPE argument ("dxil-legalize") with the name "DXIL Legalizer". No pass
// dependencies are declared between BEGIN and END.
INITIALIZE_PASS_BEGIN(DXILLegalizeLegacy, DEBUG_TYPE, "DXIL Legalizer", false,
                      false)
INITIALIZE_PASS_END(DXILLegalizeLegacy, DEBUG_TYPE, "DXIL Legalizer", false,
                    false)
649
createDXILLegalizeLegacyPass()650 FunctionPass *llvm::createDXILLegalizeLegacyPass() {
651 return new DXILLegalizeLegacy();
652 }
653