xref: /freebsd/contrib/llvm-project/llvm/lib/Target/DirectX/DXILLegalizePass.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===- DXILLegalizePass.cpp - Legalizes llvm IR for DXIL ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===---------------------------------------------------------------------===//
8 
9 #include "DXILLegalizePass.h"
10 #include "DirectX.h"
11 #include "llvm/ADT/APInt.h"
12 #include "llvm/IR/Constants.h"
13 #include "llvm/IR/Function.h"
14 #include "llvm/IR/IRBuilder.h"
15 #include "llvm/IR/InstIterator.h"
16 #include "llvm/IR/Instruction.h"
17 #include "llvm/IR/Instructions.h"
18 #include "llvm/IR/Module.h"
19 #include "llvm/Pass.h"
20 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
21 #include <functional>
22 
23 #define DEBUG_TYPE "dxil-legalize"
24 
25 using namespace llvm;
26 
legalizeFreeze(Instruction & I,SmallVectorImpl<Instruction * > & ToRemove,DenseMap<Value *,Value * >)27 static void legalizeFreeze(Instruction &I,
28                            SmallVectorImpl<Instruction *> &ToRemove,
29                            DenseMap<Value *, Value *>) {
30   auto *FI = dyn_cast<FreezeInst>(&I);
31   if (!FI)
32     return;
33 
34   FI->replaceAllUsesWith(FI->getOperand(0));
35   ToRemove.push_back(FI);
36 }
37 
fixI8UseChain(Instruction & I,SmallVectorImpl<Instruction * > & ToRemove,DenseMap<Value *,Value * > & ReplacedValues)38 static void fixI8UseChain(Instruction &I,
39                           SmallVectorImpl<Instruction *> &ToRemove,
40                           DenseMap<Value *, Value *> &ReplacedValues) {
41 
42   auto ProcessOperands = [&](SmallVector<Value *> &NewOperands) {
43     Type *InstrType = IntegerType::get(I.getContext(), 32);
44 
45     for (unsigned OpIdx = 0; OpIdx < I.getNumOperands(); ++OpIdx) {
46       Value *Op = I.getOperand(OpIdx);
47       if (ReplacedValues.count(Op) &&
48           ReplacedValues[Op]->getType()->isIntegerTy())
49         InstrType = ReplacedValues[Op]->getType();
50     }
51 
52     for (unsigned OpIdx = 0; OpIdx < I.getNumOperands(); ++OpIdx) {
53       Value *Op = I.getOperand(OpIdx);
54       if (ReplacedValues.count(Op))
55         NewOperands.push_back(ReplacedValues[Op]);
56       else if (auto *Imm = dyn_cast<ConstantInt>(Op)) {
57         APInt Value = Imm->getValue();
58         unsigned NewBitWidth = InstrType->getIntegerBitWidth();
59         // Note: options here are sext or sextOrTrunc.
60         // Since i8 isn't supported, we assume new values
61         // will always have a higher bitness.
62         assert(NewBitWidth > Value.getBitWidth() &&
63                "Replacement's BitWidth should be larger than Current.");
64         APInt NewValue = Value.sext(NewBitWidth);
65         NewOperands.push_back(ConstantInt::get(InstrType, NewValue));
66       } else {
67         assert(!Op->getType()->isIntegerTy(8));
68         NewOperands.push_back(Op);
69       }
70     }
71   };
72   IRBuilder<> Builder(&I);
73   if (auto *Trunc = dyn_cast<TruncInst>(&I)) {
74     if (Trunc->getDestTy()->isIntegerTy(8)) {
75       ReplacedValues[Trunc] = Trunc->getOperand(0);
76       ToRemove.push_back(Trunc);
77       return;
78     }
79   }
80 
81   if (auto *Store = dyn_cast<StoreInst>(&I)) {
82     if (!Store->getValueOperand()->getType()->isIntegerTy(8))
83       return;
84     SmallVector<Value *> NewOperands;
85     ProcessOperands(NewOperands);
86     Value *NewStore = Builder.CreateStore(NewOperands[0], NewOperands[1]);
87     ReplacedValues[Store] = NewStore;
88     ToRemove.push_back(Store);
89     return;
90   }
91 
92   if (auto *Load = dyn_cast<LoadInst>(&I);
93       Load && I.getType()->isIntegerTy(8)) {
94     SmallVector<Value *> NewOperands;
95     ProcessOperands(NewOperands);
96     Type *ElementType = NewOperands[0]->getType();
97     if (auto *AI = dyn_cast<AllocaInst>(NewOperands[0]))
98       ElementType = AI->getAllocatedType();
99     if (auto *GEP = dyn_cast<GetElementPtrInst>(NewOperands[0])) {
100       ElementType = GEP->getSourceElementType();
101       if (ElementType->isArrayTy())
102         ElementType = ElementType->getArrayElementType();
103     }
104     LoadInst *NewLoad = Builder.CreateLoad(ElementType, NewOperands[0]);
105     ReplacedValues[Load] = NewLoad;
106     ToRemove.push_back(Load);
107     return;
108   }
109 
110   if (auto *Load = dyn_cast<LoadInst>(&I);
111       Load && isa<ConstantExpr>(Load->getPointerOperand())) {
112     auto *CE = dyn_cast<ConstantExpr>(Load->getPointerOperand());
113     if (!(CE->getOpcode() == Instruction::GetElementPtr))
114       return;
115     auto *GEP = dyn_cast<GEPOperator>(CE);
116     if (!GEP->getSourceElementType()->isIntegerTy(8))
117       return;
118 
119     Type *ElementType = Load->getType();
120     ConstantInt *Offset = dyn_cast<ConstantInt>(GEP->getOperand(1));
121     uint32_t ByteOffset = Offset->getZExtValue();
122     uint32_t ElemSize = Load->getDataLayout().getTypeAllocSize(ElementType);
123     uint32_t Index = ByteOffset / ElemSize;
124 
125     Value *PtrOperand = GEP->getPointerOperand();
126     Type *GEPType = GEP->getPointerOperandType();
127 
128     if (auto *GV = dyn_cast<GlobalVariable>(PtrOperand))
129       GEPType = GV->getValueType();
130     if (auto *AI = dyn_cast<AllocaInst>(PtrOperand))
131       GEPType = AI->getAllocatedType();
132 
133     if (auto *ArrTy = dyn_cast<ArrayType>(GEPType))
134       GEPType = ArrTy;
135     else
136       GEPType = ArrayType::get(ElementType, 1); // its a scalar
137 
138     Value *NewGEP = Builder.CreateGEP(
139         GEPType, PtrOperand, {Builder.getInt32(0), Builder.getInt32(Index)},
140         GEP->getName(), GEP->getNoWrapFlags());
141 
142     LoadInst *NewLoad = Builder.CreateLoad(ElementType, NewGEP);
143     ReplacedValues[Load] = NewLoad;
144     Load->replaceAllUsesWith(NewLoad);
145     ToRemove.push_back(Load);
146     return;
147   }
148 
149   if (auto *BO = dyn_cast<BinaryOperator>(&I)) {
150     if (!I.getType()->isIntegerTy(8))
151       return;
152     SmallVector<Value *> NewOperands;
153     ProcessOperands(NewOperands);
154     Value *NewInst =
155         Builder.CreateBinOp(BO->getOpcode(), NewOperands[0], NewOperands[1]);
156     if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(&I)) {
157       auto *NewBO = dyn_cast<BinaryOperator>(NewInst);
158       if (NewBO && OBO->hasNoSignedWrap())
159         NewBO->setHasNoSignedWrap();
160       if (NewBO && OBO->hasNoUnsignedWrap())
161         NewBO->setHasNoUnsignedWrap();
162     }
163     ReplacedValues[BO] = NewInst;
164     ToRemove.push_back(BO);
165     return;
166   }
167 
168   if (auto *Sel = dyn_cast<SelectInst>(&I)) {
169     if (!I.getType()->isIntegerTy(8))
170       return;
171     SmallVector<Value *> NewOperands;
172     ProcessOperands(NewOperands);
173     Value *NewInst = Builder.CreateSelect(Sel->getCondition(), NewOperands[1],
174                                           NewOperands[2]);
175     ReplacedValues[Sel] = NewInst;
176     ToRemove.push_back(Sel);
177     return;
178   }
179 
180   if (auto *Cmp = dyn_cast<CmpInst>(&I)) {
181     if (!Cmp->getOperand(0)->getType()->isIntegerTy(8))
182       return;
183     SmallVector<Value *> NewOperands;
184     ProcessOperands(NewOperands);
185     Value *NewInst =
186         Builder.CreateCmp(Cmp->getPredicate(), NewOperands[0], NewOperands[1]);
187     Cmp->replaceAllUsesWith(NewInst);
188     ReplacedValues[Cmp] = NewInst;
189     ToRemove.push_back(Cmp);
190     return;
191   }
192 
193   if (auto *Cast = dyn_cast<CastInst>(&I)) {
194     if (!Cast->getSrcTy()->isIntegerTy(8))
195       return;
196 
197     ToRemove.push_back(Cast);
198     auto *Replacement = ReplacedValues[Cast->getOperand(0)];
199     if (Cast->getType() == Replacement->getType()) {
200       Cast->replaceAllUsesWith(Replacement);
201       return;
202     }
203 
204     Value *AdjustedCast = nullptr;
205     if (Cast->getOpcode() == Instruction::ZExt)
206       AdjustedCast = Builder.CreateZExtOrTrunc(Replacement, Cast->getType());
207     if (Cast->getOpcode() == Instruction::SExt)
208       AdjustedCast = Builder.CreateSExtOrTrunc(Replacement, Cast->getType());
209 
210     if (AdjustedCast)
211       Cast->replaceAllUsesWith(AdjustedCast);
212   }
213   if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) {
214     if (!GEP->getType()->isPointerTy() ||
215         !GEP->getSourceElementType()->isIntegerTy(8))
216       return;
217 
218     Value *BasePtr = GEP->getPointerOperand();
219     if (ReplacedValues.count(BasePtr))
220       BasePtr = ReplacedValues[BasePtr];
221 
222     Type *ElementType = BasePtr->getType();
223 
224     if (auto *AI = dyn_cast<AllocaInst>(BasePtr))
225       ElementType = AI->getAllocatedType();
226     if (auto *GV = dyn_cast<GlobalVariable>(BasePtr))
227       ElementType = GV->getValueType();
228 
229     Type *GEPType = ElementType;
230     if (auto *ArrTy = dyn_cast<ArrayType>(ElementType))
231       ElementType = ArrTy->getArrayElementType();
232     else
233       GEPType = ArrayType::get(ElementType, 1); // its a scalar
234 
235     ConstantInt *Offset = dyn_cast<ConstantInt>(GEP->getOperand(1));
236     // Note: i8 to i32 offset conversion without emitting IR requires constant
237     // ints. Since offset conversion is common, we can safely assume Offset is
238     // always a ConstantInt, so no need to have a conditional bail out on
239     // nullptr, instead assert this is the case.
240     assert(Offset && "Offset is expected to be a ConstantInt");
241     uint32_t ByteOffset = Offset->getZExtValue();
242     uint32_t ElemSize = GEP->getDataLayout().getTypeAllocSize(ElementType);
243     assert(ElemSize > 0 && "ElementSize must be set");
244     uint32_t Index = ByteOffset / ElemSize;
245     Value *NewGEP = Builder.CreateGEP(
246         GEPType, BasePtr, {Builder.getInt32(0), Builder.getInt32(Index)},
247         GEP->getName(), GEP->getNoWrapFlags());
248     ReplacedValues[GEP] = NewGEP;
249     GEP->replaceAllUsesWith(NewGEP);
250     ToRemove.push_back(GEP);
251   }
252 }
253 
upcastI8AllocasAndUses(Instruction & I,SmallVectorImpl<Instruction * > & ToRemove,DenseMap<Value *,Value * > & ReplacedValues)254 static void upcastI8AllocasAndUses(Instruction &I,
255                                    SmallVectorImpl<Instruction *> &ToRemove,
256                                    DenseMap<Value *, Value *> &ReplacedValues) {
257   auto *AI = dyn_cast<AllocaInst>(&I);
258   if (!AI || !AI->getAllocatedType()->isIntegerTy(8))
259     return;
260 
261   Type *SmallestType = nullptr;
262 
263   auto ProcessLoad = [&](LoadInst *Load) {
264     for (User *LU : Load->users()) {
265       Type *Ty = nullptr;
266       if (CastInst *Cast = dyn_cast<CastInst>(LU))
267         Ty = Cast->getType();
268       else if (CallInst *CI = dyn_cast<CallInst>(LU)) {
269         if (CI->getIntrinsicID() == Intrinsic::memset)
270           Ty = Type::getInt32Ty(CI->getContext());
271       }
272 
273       if (!Ty)
274         continue;
275 
276       if (!SmallestType ||
277           Ty->getPrimitiveSizeInBits() < SmallestType->getPrimitiveSizeInBits())
278         SmallestType = Ty;
279     }
280   };
281 
282   for (User *U : AI->users()) {
283     if (auto *Load = dyn_cast<LoadInst>(U))
284       ProcessLoad(Load);
285     else if (auto *GEP = dyn_cast<GetElementPtrInst>(U)) {
286       for (User *GU : GEP->users()) {
287         if (auto *Load = dyn_cast<LoadInst>(GU))
288           ProcessLoad(Load);
289       }
290     }
291   }
292 
293   if (!SmallestType)
294     return; // no valid casts found
295 
296   // Replace alloca
297   IRBuilder<> Builder(AI);
298   auto *NewAlloca = Builder.CreateAlloca(SmallestType);
299   ReplacedValues[AI] = NewAlloca;
300   ToRemove.push_back(AI);
301 }
302 
303 static void
downcastI64toI32InsertExtractElements(Instruction & I,SmallVectorImpl<Instruction * > & ToRemove,DenseMap<Value *,Value * > &)304 downcastI64toI32InsertExtractElements(Instruction &I,
305                                       SmallVectorImpl<Instruction *> &ToRemove,
306                                       DenseMap<Value *, Value *> &) {
307 
308   if (auto *Extract = dyn_cast<ExtractElementInst>(&I)) {
309     Value *Idx = Extract->getIndexOperand();
310     auto *CI = dyn_cast<ConstantInt>(Idx);
311     if (CI && CI->getBitWidth() == 64) {
312       IRBuilder<> Builder(Extract);
313       int64_t IndexValue = CI->getSExtValue();
314       auto *Idx32 =
315           ConstantInt::get(Type::getInt32Ty(I.getContext()), IndexValue);
316       Value *NewExtract = Builder.CreateExtractElement(
317           Extract->getVectorOperand(), Idx32, Extract->getName());
318 
319       Extract->replaceAllUsesWith(NewExtract);
320       ToRemove.push_back(Extract);
321     }
322   }
323 
324   if (auto *Insert = dyn_cast<InsertElementInst>(&I)) {
325     Value *Idx = Insert->getOperand(2);
326     auto *CI = dyn_cast<ConstantInt>(Idx);
327     if (CI && CI->getBitWidth() == 64) {
328       int64_t IndexValue = CI->getSExtValue();
329       auto *Idx32 =
330           ConstantInt::get(Type::getInt32Ty(I.getContext()), IndexValue);
331       IRBuilder<> Builder(Insert);
332       Value *Insert32Index = Builder.CreateInsertElement(
333           Insert->getOperand(0), Insert->getOperand(1), Idx32,
334           Insert->getName());
335 
336       Insert->replaceAllUsesWith(Insert32Index);
337       ToRemove.push_back(Insert);
338     }
339   }
340 }
341 
emitMemcpyExpansion(IRBuilder<> & Builder,Value * Dst,Value * Src,ConstantInt * Length)342 static void emitMemcpyExpansion(IRBuilder<> &Builder, Value *Dst, Value *Src,
343                                 ConstantInt *Length) {
344 
345   uint64_t ByteLength = Length->getZExtValue();
346   // If length to copy is zero, no memcpy is needed.
347   if (ByteLength == 0)
348     return;
349 
350   LLVMContext &Ctx = Builder.getContext();
351   const DataLayout &DL = Builder.GetInsertBlock()->getModule()->getDataLayout();
352 
353   auto GetArrTyFromVal = [](Value *Val) -> ArrayType * {
354     assert(isa<AllocaInst>(Val) ||
355            isa<GlobalVariable>(Val) &&
356                "Expected Val to be an Alloca or Global Variable");
357     if (auto *Alloca = dyn_cast<AllocaInst>(Val))
358       return dyn_cast<ArrayType>(Alloca->getAllocatedType());
359     if (auto *GlobalVar = dyn_cast<GlobalVariable>(Val))
360       return dyn_cast<ArrayType>(GlobalVar->getValueType());
361     return nullptr;
362   };
363 
364   ArrayType *DstArrTy = GetArrTyFromVal(Dst);
365   assert(DstArrTy && "Expected Dst of memcpy to be a Pointer to an Array Type");
366   if (auto *DstGlobalVar = dyn_cast<GlobalVariable>(Dst))
367     assert(!DstGlobalVar->isConstant() &&
368            "The Dst of memcpy must not be a constant Global Variable");
369   [[maybe_unused]] ArrayType *SrcArrTy = GetArrTyFromVal(Src);
370   assert(SrcArrTy && "Expected Src of memcpy to be a Pointer to an Array Type");
371 
372   Type *DstElemTy = DstArrTy->getElementType();
373   uint64_t DstElemByteSize = DL.getTypeStoreSize(DstElemTy);
374   assert(DstElemByteSize > 0 && "Dst element type store size must be set");
375   Type *SrcElemTy = SrcArrTy->getElementType();
376   [[maybe_unused]] uint64_t SrcElemByteSize = DL.getTypeStoreSize(SrcElemTy);
377   assert(SrcElemByteSize > 0 && "Src element type store size must be set");
378 
379   // This assumption simplifies implementation and covers currently-known
380   // use-cases for DXIL. It may be relaxed in the future if required.
381   assert(DstElemTy == SrcElemTy &&
382          "The element types of Src and Dst arrays must match");
383 
384   [[maybe_unused]] uint64_t DstArrNumElems = DstArrTy->getArrayNumElements();
385   assert(DstElemByteSize * DstArrNumElems >= ByteLength &&
386          "Dst array size must be at least as large as the memcpy length");
387   [[maybe_unused]] uint64_t SrcArrNumElems = SrcArrTy->getArrayNumElements();
388   assert(SrcElemByteSize * SrcArrNumElems >= ByteLength &&
389          "Src array size must be at least as large as the memcpy length");
390 
391   uint64_t NumElemsToCopy = ByteLength / DstElemByteSize;
392   assert(ByteLength % DstElemByteSize == 0 &&
393          "memcpy length must be divisible by array element type");
394   for (uint64_t I = 0; I < NumElemsToCopy; ++I) {
395     Value *Offset = ConstantInt::get(Type::getInt32Ty(Ctx), I);
396     Value *SrcPtr = Builder.CreateInBoundsGEP(SrcElemTy, Src, Offset, "gep");
397     Value *SrcVal = Builder.CreateLoad(SrcElemTy, SrcPtr);
398     Value *DstPtr = Builder.CreateInBoundsGEP(DstElemTy, Dst, Offset, "gep");
399     Builder.CreateStore(SrcVal, DstPtr);
400   }
401 }
402 
emitMemsetExpansion(IRBuilder<> & Builder,Value * Dst,Value * Val,ConstantInt * SizeCI,DenseMap<Value *,Value * > & ReplacedValues)403 static void emitMemsetExpansion(IRBuilder<> &Builder, Value *Dst, Value *Val,
404                                 ConstantInt *SizeCI,
405                                 DenseMap<Value *, Value *> &ReplacedValues) {
406   LLVMContext &Ctx = Builder.getContext();
407   [[maybe_unused]] const DataLayout &DL =
408       Builder.GetInsertBlock()->getModule()->getDataLayout();
409   [[maybe_unused]] uint64_t OrigSize = SizeCI->getZExtValue();
410 
411   AllocaInst *Alloca = dyn_cast<AllocaInst>(Dst);
412 
413   assert(Alloca && "Expected memset on an Alloca");
414   assert(OrigSize == Alloca->getAllocationSize(DL)->getFixedValue() &&
415          "Expected for memset size to match DataLayout size");
416 
417   Type *AllocatedTy = Alloca->getAllocatedType();
418   ArrayType *ArrTy = dyn_cast<ArrayType>(AllocatedTy);
419   assert(ArrTy && "Expected Alloca for an Array Type");
420 
421   Type *ElemTy = ArrTy->getElementType();
422   uint64_t Size = ArrTy->getArrayNumElements();
423 
424   [[maybe_unused]] uint64_t ElemSize = DL.getTypeStoreSize(ElemTy);
425 
426   assert(ElemSize > 0 && "Size must be set");
427   assert(OrigSize == ElemSize * Size && "Size in bytes must match");
428 
429   Value *TypedVal = Val;
430 
431   if (Val->getType() != ElemTy) {
432     if (ReplacedValues[Val]) {
433       // Note for i8 replacements if we know them we should use them.
434       // Further if this is a constant ReplacedValues will return null
435       // so we will stick to TypedVal = Val
436       TypedVal = ReplacedValues[Val];
437 
438     } else {
439       // This case Val is a ConstantInt so the cast folds away.
440       // However if we don't do the cast the store below ends up being
441       // an i8.
442       TypedVal = Builder.CreateIntCast(Val, ElemTy, false);
443     }
444   }
445 
446   for (uint64_t I = 0; I < Size; ++I) {
447     Value *Offset = ConstantInt::get(Type::getInt32Ty(Ctx), I);
448     Value *Ptr = Builder.CreateGEP(ElemTy, Dst, Offset, "gep");
449     Builder.CreateStore(TypedVal, Ptr);
450   }
451 }
452 
453 // Expands the instruction `I` into corresponding loads and stores if it is a
454 // memcpy call. In that case, the call instruction is added to the `ToRemove`
455 // vector. `ReplacedValues` is unused.
legalizeMemCpy(Instruction & I,SmallVectorImpl<Instruction * > & ToRemove,DenseMap<Value *,Value * > & ReplacedValues)456 static void legalizeMemCpy(Instruction &I,
457                            SmallVectorImpl<Instruction *> &ToRemove,
458                            DenseMap<Value *, Value *> &ReplacedValues) {
459 
460   CallInst *CI = dyn_cast<CallInst>(&I);
461   if (!CI)
462     return;
463 
464   Intrinsic::ID ID = CI->getIntrinsicID();
465   if (ID != Intrinsic::memcpy)
466     return;
467 
468   IRBuilder<> Builder(&I);
469   Value *Dst = CI->getArgOperand(0);
470   Value *Src = CI->getArgOperand(1);
471   ConstantInt *Length = dyn_cast<ConstantInt>(CI->getArgOperand(2));
472   assert(Length && "Expected Length to be a ConstantInt");
473   [[maybe_unused]] ConstantInt *IsVolatile =
474       dyn_cast<ConstantInt>(CI->getArgOperand(3));
475   assert(IsVolatile && "Expected IsVolatile to be a ConstantInt");
476   assert(IsVolatile->getZExtValue() == 0 && "Expected IsVolatile to be false");
477   emitMemcpyExpansion(Builder, Dst, Src, Length);
478   ToRemove.push_back(CI);
479 }
480 
removeMemSet(Instruction & I,SmallVectorImpl<Instruction * > & ToRemove,DenseMap<Value *,Value * > & ReplacedValues)481 static void removeMemSet(Instruction &I,
482                          SmallVectorImpl<Instruction *> &ToRemove,
483                          DenseMap<Value *, Value *> &ReplacedValues) {
484 
485   CallInst *CI = dyn_cast<CallInst>(&I);
486   if (!CI)
487     return;
488 
489   Intrinsic::ID ID = CI->getIntrinsicID();
490   if (ID != Intrinsic::memset)
491     return;
492 
493   IRBuilder<> Builder(&I);
494   Value *Dst = CI->getArgOperand(0);
495   Value *Val = CI->getArgOperand(1);
496   ConstantInt *Size = dyn_cast<ConstantInt>(CI->getArgOperand(2));
497   assert(Size && "Expected Size to be a ConstantInt");
498   emitMemsetExpansion(Builder, Dst, Val, Size, ReplacedValues);
499   ToRemove.push_back(CI);
500 }
501 
updateFnegToFsub(Instruction & I,SmallVectorImpl<Instruction * > & ToRemove,DenseMap<Value *,Value * > &)502 static void updateFnegToFsub(Instruction &I,
503                              SmallVectorImpl<Instruction *> &ToRemove,
504                              DenseMap<Value *, Value *> &) {
505   const Intrinsic::ID ID = I.getOpcode();
506   if (ID != Instruction::FNeg)
507     return;
508 
509   IRBuilder<> Builder(&I);
510   Value *In = I.getOperand(0);
511   Value *Zero = ConstantFP::get(In->getType(), -0.0);
512   I.replaceAllUsesWith(Builder.CreateFSub(Zero, In));
513   ToRemove.push_back(&I);
514 }
515 
516 static void
legalizeGetHighLowi64Bytes(Instruction & I,SmallVectorImpl<Instruction * > & ToRemove,DenseMap<Value *,Value * > & ReplacedValues)517 legalizeGetHighLowi64Bytes(Instruction &I,
518                            SmallVectorImpl<Instruction *> &ToRemove,
519                            DenseMap<Value *, Value *> &ReplacedValues) {
520   if (auto *BitCast = dyn_cast<BitCastInst>(&I)) {
521     if (BitCast->getDestTy() ==
522             FixedVectorType::get(Type::getInt32Ty(I.getContext()), 2) &&
523         BitCast->getSrcTy()->isIntegerTy(64)) {
524       ToRemove.push_back(BitCast);
525       ReplacedValues[BitCast] = BitCast->getOperand(0);
526       return;
527     }
528   }
529 
530   if (auto *Extract = dyn_cast<ExtractElementInst>(&I)) {
531     if (!dyn_cast<BitCastInst>(Extract->getVectorOperand()))
532       return;
533     auto *VecTy = dyn_cast<FixedVectorType>(Extract->getVectorOperandType());
534     if (VecTy && VecTy->getElementType()->isIntegerTy(32) &&
535         VecTy->getNumElements() == 2) {
536       if (auto *Index = dyn_cast<ConstantInt>(Extract->getIndexOperand())) {
537         unsigned Idx = Index->getZExtValue();
538         IRBuilder<> Builder(&I);
539 
540         auto *Replacement = ReplacedValues[Extract->getVectorOperand()];
541         assert(Replacement && "The BitCast replacement should have been set "
542                               "before working on ExtractElementInst.");
543         if (Idx == 0) {
544           Value *LowBytes = Builder.CreateTrunc(
545               Replacement, Type::getInt32Ty(I.getContext()));
546           ReplacedValues[Extract] = LowBytes;
547         } else {
548           assert(Idx == 1);
549           Value *LogicalShiftRight = Builder.CreateLShr(
550               Replacement,
551               ConstantInt::get(
552                   Replacement->getType(),
553                   APInt(Replacement->getType()->getIntegerBitWidth(), 32)));
554           Value *HighBytes = Builder.CreateTrunc(
555               LogicalShiftRight, Type::getInt32Ty(I.getContext()));
556           ReplacedValues[Extract] = HighBytes;
557         }
558         ToRemove.push_back(Extract);
559         Extract->replaceAllUsesWith(ReplacedValues[Extract]);
560       }
561     }
562   }
563 }
564 
565 namespace {
566 class DXILLegalizationPipeline {
567 
568 public:
DXILLegalizationPipeline()569   DXILLegalizationPipeline() { initializeLegalizationPipeline(); }
570 
runLegalizationPipeline(Function & F)571   bool runLegalizationPipeline(Function &F) {
572     bool MadeChange = false;
573     SmallVector<Instruction *> ToRemove;
574     DenseMap<Value *, Value *> ReplacedValues;
575     for (int Stage = 0; Stage < NumStages; ++Stage) {
576       ToRemove.clear();
577       ReplacedValues.clear();
578       for (auto &I : instructions(F)) {
579         for (auto &LegalizationFn : LegalizationPipeline[Stage])
580           LegalizationFn(I, ToRemove, ReplacedValues);
581       }
582 
583       for (auto *Inst : reverse(ToRemove))
584         Inst->eraseFromParent();
585 
586       MadeChange |= !ToRemove.empty();
587     }
588     return MadeChange;
589   }
590 
591 private:
592   enum LegalizationStage { Stage1 = 0, Stage2 = 1, NumStages };
593 
594   using LegalizationFnTy =
595       std::function<void(Instruction &, SmallVectorImpl<Instruction *> &,
596                          DenseMap<Value *, Value *> &)>;
597 
598   SmallVector<LegalizationFnTy> LegalizationPipeline[NumStages];
599 
initializeLegalizationPipeline()600   void initializeLegalizationPipeline() {
601     LegalizationPipeline[Stage1].push_back(upcastI8AllocasAndUses);
602     LegalizationPipeline[Stage1].push_back(fixI8UseChain);
603     LegalizationPipeline[Stage1].push_back(legalizeGetHighLowi64Bytes);
604     LegalizationPipeline[Stage1].push_back(legalizeFreeze);
605     LegalizationPipeline[Stage1].push_back(legalizeMemCpy);
606     LegalizationPipeline[Stage1].push_back(removeMemSet);
607     LegalizationPipeline[Stage1].push_back(updateFnegToFsub);
608     // Note: legalizeGetHighLowi64Bytes and
609     // downcastI64toI32InsertExtractElements both modify extractelement, so they
610     // must run staggered stages. legalizeGetHighLowi64Bytes runs first b\c it
611     // removes extractelements, reducing the number that
612     // downcastI64toI32InsertExtractElements needs to handle.
613     LegalizationPipeline[Stage2].push_back(
614         downcastI64toI32InsertExtractElements);
615   }
616 };
617 
618 class DXILLegalizeLegacy : public FunctionPass {
619 
620 public:
621   bool runOnFunction(Function &F) override;
DXILLegalizeLegacy()622   DXILLegalizeLegacy() : FunctionPass(ID) {}
623 
624   static char ID; // Pass identification.
625 };
626 } // namespace
627 
run(Function & F,FunctionAnalysisManager & FAM)628 PreservedAnalyses DXILLegalizePass::run(Function &F,
629                                         FunctionAnalysisManager &FAM) {
630   DXILLegalizationPipeline DXLegalize;
631   bool MadeChanges = DXLegalize.runLegalizationPipeline(F);
632   if (!MadeChanges)
633     return PreservedAnalyses::all();
634   PreservedAnalyses PA;
635   return PA;
636 }
637 
runOnFunction(Function & F)638 bool DXILLegalizeLegacy::runOnFunction(Function &F) {
639   DXILLegalizationPipeline DXLegalize;
640   return DXLegalize.runLegalizationPipeline(F);
641 }
642 
643 char DXILLegalizeLegacy::ID = 0;
644 
645 INITIALIZE_PASS_BEGIN(DXILLegalizeLegacy, DEBUG_TYPE, "DXIL Legalizer", false,
646                       false)
647 INITIALIZE_PASS_END(DXILLegalizeLegacy, DEBUG_TYPE, "DXIL Legalizer", false,
648                     false)
649 
createDXILLegalizeLegacyPass()650 FunctionPass *llvm::createDXILLegalizeLegacyPass() {
651   return new DXILLegalizeLegacy();
652 }
653