//===- AMDGPUInstCombineIntrinsic.cpp - AMDGPU specific InstCombine pass --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// \file
// This file implements the AMDGPU specific InstCombine rules for target
// intrinsics. It uses the target's detailed information to fold and simplify
// intrinsic calls, while letting the target-independent InstCombine patterns
// handle the rest.
//
//===----------------------------------------------------------------------===//

#include "AMDGPUInstrInfo.h"
#include "AMDGPUTargetTransformInfo.h"
#include "GCNSubtarget.h"
#include "llvm/ADT/FloatingPointMode.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/Transforms/InstCombine/InstCombiner.h"
#include <optional>

using namespace llvm;
using namespace llvm::PatternMatch;

#define DEBUG_TYPE "AMDGPUtti"

namespace {

struct AMDGPUImageDMaskIntrinsic {
  unsigned Intr;
};

#define GET_AMDGPUImageDMaskIntrinsicTable_IMPL
#include "InstCombineTables.inc"

} // end anonymous namespace

// Constant fold llvm.amdgcn.fmed3 intrinsics for standard inputs.
//
// A single NaN input is folded to minnum, so we rely on that folding for
// handling NaNs.
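// For example, fmed3(1.0, 5.0, 3.0): Max3 = 5.0 compares equal to Src1, so the
// result is maxnum(Src0, Src2) = maxnum(1.0, 3.0) = 3.0, the median value.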
static APFloat fmed3AMDGCN(const APFloat &Src0, const APFloat &Src1,
                           const APFloat &Src2) {
  APFloat Max3 = maxnum(maxnum(Src0, Src1), Src2);

  APFloat::cmpResult Cmp0 = Max3.compare(Src0);
  assert(Cmp0 != APFloat::cmpUnordered && "nans handled separately");
  if (Cmp0 == APFloat::cmpEqual)
    return maxnum(Src1, Src2);

  APFloat::cmpResult Cmp1 = Max3.compare(Src1);
  assert(Cmp1 != APFloat::cmpUnordered && "nans handled separately");
  if (Cmp1 == APFloat::cmpEqual)
    return maxnum(Src0, Src2);

  return maxnum(Src0, Src1);
}

// Check if a value can be converted to a 16-bit value without losing
// precision.
// The value is expected to be either a float (IsFloat = true) or an unsigned
// integer (IsFloat = false).
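// For example, a ConstantFP of 0.5 converts to half exactly, and a value
// produced by (fpext half %x to float) is convertible by construction, while
// 1.0e10 or an i32 constant with more than 16 active bits is not.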
static bool canSafelyConvertTo16Bit(Value &V, bool IsFloat) {
  Type *VTy = V.getType();
  if (VTy->isHalfTy() || VTy->isIntegerTy(16)) {
    // The value is already 16-bit, so we don't want to convert to 16-bit again!
    return false;
  }
  if (IsFloat) {
    if (ConstantFP *ConstFloat = dyn_cast<ConstantFP>(&V)) {
      // We need to check that if we cast the index down to a half, we do not
      // lose precision.
      APFloat FloatValue(ConstFloat->getValueAPF());
      bool LosesInfo = true;
      FloatValue.convert(APFloat::IEEEhalf(), APFloat::rmTowardZero,
                         &LosesInfo);
      return !LosesInfo;
    }
  } else {
    if (ConstantInt *ConstInt = dyn_cast<ConstantInt>(&V)) {
      // We need to check that if we cast the index down to an i16, we do not
      // lose precision.
      APInt IntValue(ConstInt->getValue());
      return IntValue.getActiveBits() <= 16;
    }
  }

  Value *CastSrc;
  bool IsExt = IsFloat ? match(&V, m_FPExt(PatternMatch::m_Value(CastSrc)))
                       : match(&V, m_ZExt(PatternMatch::m_Value(CastSrc)));
  if (IsExt) {
    Type *CastSrcTy = CastSrc->getType();
    if (CastSrcTy->isHalfTy() || CastSrcTy->isIntegerTy(16))
      return true;
  }

  return false;
}

// Convert a value to 16-bit.
static Value *convertTo16Bit(Value &V, InstCombiner::BuilderTy &Builder) {
  Type *VTy = V.getType();
  if (isa<FPExtInst, SExtInst, ZExtInst>(&V))
    return cast<Instruction>(&V)->getOperand(0);
  if (VTy->isIntegerTy())
    return Builder.CreateIntCast(&V, Type::getInt16Ty(V.getContext()), false);
  if (VTy->isFloatingPointTy())
    return Builder.CreateFPCast(&V, Type::getHalfTy(V.getContext()));

  llvm_unreachable("Should never be called!");
}

/// Applies Func(OldIntr.Args, OldIntr.ArgTys), creates an intrinsic call with
/// modified arguments (based on OldIntr), and replaces InstToReplace with
/// this newly created intrinsic call.
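// A typical use (illustrative) drops one operand while retargeting the call,
// e.g. removing the LOD argument when switching an image intrinsic to its
// _LZ variant:
//   modifyIntrinsicCall(II, II, NewIntr, IC, [&](auto &Args, auto &ArgTys) {
//     Args.erase(Args.begin() + ImageDimIntr->LodIndex);
//   });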
static std::optional<Instruction *> modifyIntrinsicCall(
    IntrinsicInst &OldIntr, Instruction &InstToReplace, unsigned NewIntr,
    InstCombiner &IC,
    std::function<void(SmallVectorImpl<Value *> &, SmallVectorImpl<Type *> &)>
        Func) {
  SmallVector<Type *, 4> ArgTys;
  if (!Intrinsic::getIntrinsicSignature(OldIntr.getCalledFunction(), ArgTys))
    return std::nullopt;

  SmallVector<Value *, 8> Args(OldIntr.args());

  // Modify arguments and types
  Func(Args, ArgTys);

  CallInst *NewCall = IC.Builder.CreateIntrinsic(NewIntr, ArgTys, Args);
  NewCall->takeName(&OldIntr);
  NewCall->copyMetadata(OldIntr);
  if (isa<FPMathOperator>(NewCall))
    NewCall->copyFastMathFlags(&OldIntr);

  // Erase and replace uses
  if (!InstToReplace.getType()->isVoidTy())
    IC.replaceInstUsesWith(InstToReplace, NewCall);

  bool RemoveOldIntr = &OldIntr != &InstToReplace;

  auto *RetValue = IC.eraseInstFromFunction(InstToReplace);
  if (RemoveOldIntr)
    IC.eraseInstFromFunction(OldIntr);

  return RetValue;
}

static std::optional<Instruction *>
simplifyAMDGCNImageIntrinsic(const GCNSubtarget *ST,
                             const AMDGPU::ImageDimIntrinsicInfo *ImageDimIntr,
                             IntrinsicInst &II, InstCombiner &IC) {
  // Optimize _L to _LZ when _L is zero
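  // e.g. (illustrative) a sample.l call whose LOD operand is constant 0.0,
  // such as llvm.amdgcn.image.sample.l.2d, becomes the equivalent
  // llvm.amdgcn.image.sample.lz.2d call with the LOD operand dropped.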
  if (const auto *LZMappingInfo =
          AMDGPU::getMIMGLZMappingInfo(ImageDimIntr->BaseOpcode)) {
    if (auto *ConstantLod =
            dyn_cast<ConstantFP>(II.getOperand(ImageDimIntr->LodIndex))) {
      if (ConstantLod->isZero() || ConstantLod->isNegative()) {
        const AMDGPU::ImageDimIntrinsicInfo *NewImageDimIntr =
            AMDGPU::getImageDimIntrinsicByBaseOpcode(LZMappingInfo->LZ,
                                                     ImageDimIntr->Dim);
        return modifyIntrinsicCall(
            II, II, NewImageDimIntr->Intr, IC, [&](auto &Args, auto &ArgTys) {
              Args.erase(Args.begin() + ImageDimIntr->LodIndex);
            });
      }
    }
  }

  // Optimize _mip away when 'lod' is zero
  if (const auto *MIPMappingInfo =
          AMDGPU::getMIMGMIPMappingInfo(ImageDimIntr->BaseOpcode)) {
    if (auto *ConstantMip =
            dyn_cast<ConstantInt>(II.getOperand(ImageDimIntr->MipIndex))) {
      if (ConstantMip->isZero()) {
        const AMDGPU::ImageDimIntrinsicInfo *NewImageDimIntr =
            AMDGPU::getImageDimIntrinsicByBaseOpcode(MIPMappingInfo->NONMIP,
                                                     ImageDimIntr->Dim);
        return modifyIntrinsicCall(
            II, II, NewImageDimIntr->Intr, IC, [&](auto &Args, auto &ArgTys) {
              Args.erase(Args.begin() + ImageDimIntr->MipIndex);
            });
      }
    }
  }

  // Optimize _bias away when 'bias' is zero
  if (const auto *BiasMappingInfo =
          AMDGPU::getMIMGBiasMappingInfo(ImageDimIntr->BaseOpcode)) {
    if (auto *ConstantBias =
            dyn_cast<ConstantFP>(II.getOperand(ImageDimIntr->BiasIndex))) {
      if (ConstantBias->isZero()) {
        const AMDGPU::ImageDimIntrinsicInfo *NewImageDimIntr =
            AMDGPU::getImageDimIntrinsicByBaseOpcode(BiasMappingInfo->NoBias,
                                                     ImageDimIntr->Dim);
        return modifyIntrinsicCall(
            II, II, NewImageDimIntr->Intr, IC, [&](auto &Args, auto &ArgTys) {
              Args.erase(Args.begin() + ImageDimIntr->BiasIndex);
              ArgTys.erase(ArgTys.begin() + ImageDimIntr->BiasTyArg);
            });
      }
    }
  }

  // Optimize _offset away when 'offset' is zero
  if (const auto *OffsetMappingInfo =
          AMDGPU::getMIMGOffsetMappingInfo(ImageDimIntr->BaseOpcode)) {
    if (auto *ConstantOffset =
            dyn_cast<ConstantInt>(II.getOperand(ImageDimIntr->OffsetIndex))) {
      if (ConstantOffset->isZero()) {
        const AMDGPU::ImageDimIntrinsicInfo *NewImageDimIntr =
            AMDGPU::getImageDimIntrinsicByBaseOpcode(
                OffsetMappingInfo->NoOffset, ImageDimIntr->Dim);
        return modifyIntrinsicCall(
            II, II, NewImageDimIntr->Intr, IC, [&](auto &Args, auto &ArgTys) {
              Args.erase(Args.begin() + ImageDimIntr->OffsetIndex);
            });
      }
    }
  }

  // Try to use D16
  if (ST->hasD16Images()) {

    const AMDGPU::MIMGBaseOpcodeInfo *BaseOpcode =
        AMDGPU::getMIMGBaseOpcodeInfo(ImageDimIntr->BaseOpcode);

    if (BaseOpcode->HasD16) {

      // If the only use of the image intrinsic is a fptrunc (with conversion
      // to half), then both the fptrunc and the image intrinsic are replaced
      // with an image intrinsic carrying the D16 flag.
      if (II.hasOneUse()) {
        Instruction *User = II.user_back();

        if (User->getOpcode() == Instruction::FPTrunc &&
            User->getType()->getScalarType()->isHalfTy()) {

          return modifyIntrinsicCall(II, *User, ImageDimIntr->Intr, IC,
                                     [&](auto &Args, auto &ArgTys) {
                                       // Change return type of image intrinsic.
                                       // Set it to return type of fptrunc.
                                       ArgTys[0] = User->getType();
                                     });
        }
      }

      // Only perform D16 folding if every user of the image sample is
      // an ExtractElementInst immediately followed by an FPTrunc to half.
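      // e.g. (illustrative):
      //   %v = call <4 x float> @llvm.amdgcn.image.sample...
      //   %e = extractelement <4 x float> %v, i64 0
      //   %h = fptrunc float %e to half
      // can be rewritten to sample <4 x half> directly with D16.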
      SmallVector<std::pair<ExtractElementInst *, FPTruncInst *>, 4>
          ExtractTruncPairs;
      bool AllHalfExtracts = true;

      for (User *U : II.users()) {
        auto *Ext = dyn_cast<ExtractElementInst>(U);
        if (!Ext || !Ext->hasOneUse()) {
          AllHalfExtracts = false;
          break;
        }

        auto *Tr = dyn_cast<FPTruncInst>(*Ext->user_begin());
        if (!Tr || !Tr->getType()->isHalfTy()) {
          AllHalfExtracts = false;
          break;
        }

        ExtractTruncPairs.emplace_back(Ext, Tr);
      }

      if (!ExtractTruncPairs.empty() && AllHalfExtracts) {
        auto *VecTy = cast<VectorType>(II.getType());
        Type *HalfVecTy =
            VecTy->getWithNewType(Type::getHalfTy(II.getContext()));

        // Obtain the original image sample intrinsic's signature
        // and replace its return type with the half-vector for D16 folding
        SmallVector<Type *, 8> SigTys;
        Intrinsic::getIntrinsicSignature(II.getCalledFunction(), SigTys);
        SigTys[0] = HalfVecTy;

        Module *M = II.getModule();
        Function *HalfDecl =
            Intrinsic::getOrInsertDeclaration(M, ImageDimIntr->Intr, SigTys);

        II.mutateType(HalfVecTy);
        II.setCalledFunction(HalfDecl);

        IRBuilder<> Builder(II.getContext());
        for (auto &[Ext, Tr] : ExtractTruncPairs) {
          Value *Idx = Ext->getIndexOperand();

          Builder.SetInsertPoint(Tr);

          Value *HalfExtract = Builder.CreateExtractElement(&II, Idx);
          HalfExtract->takeName(Tr);

          Tr->replaceAllUsesWith(HalfExtract);
        }

        for (auto &[Ext, Tr] : ExtractTruncPairs) {
          IC.eraseInstFromFunction(*Tr);
          IC.eraseInstFromFunction(*Ext);
        }

        return &II;
      }
    }
  }

  // Try to use A16 or G16
  if (!ST->hasA16() && !ST->hasG16())
    return std::nullopt;

  // Address is interpreted as float if the instruction has a sampler or as
  // unsigned int if there is no sampler.
  bool HasSampler =
      AMDGPU::getMIMGBaseOpcodeInfo(ImageDimIntr->BaseOpcode)->Sampler;
  bool FloatCoord = false;
  // true means derivatives can be converted to 16 bit, coordinates not
  bool OnlyDerivatives = false;

  for (unsigned OperandIndex = ImageDimIntr->GradientStart;
       OperandIndex < ImageDimIntr->VAddrEnd; OperandIndex++) {
    Value *Coord = II.getOperand(OperandIndex);
    // If the values are not derived from 16-bit values, we cannot optimize.
    if (!canSafelyConvertTo16Bit(*Coord, HasSampler)) {
      if (OperandIndex < ImageDimIntr->CoordStart ||
          ImageDimIntr->GradientStart == ImageDimIntr->CoordStart) {
        return std::nullopt;
      }
      // All gradients can be converted, so convert only them
      OnlyDerivatives = true;
      break;
    }

    assert(OperandIndex == ImageDimIntr->GradientStart ||
           FloatCoord == Coord->getType()->isFloatingPointTy());
    FloatCoord = Coord->getType()->isFloatingPointTy();
  }

  if (!OnlyDerivatives && !ST->hasA16())
    OnlyDerivatives = true; // Only supports G16

  // Check if there is a bias parameter and if it can be converted to f16
  if (!OnlyDerivatives && ImageDimIntr->NumBiasArgs != 0) {
    Value *Bias = II.getOperand(ImageDimIntr->BiasIndex);
    assert(HasSampler &&
           "Only image instructions with a sampler can have a bias");
    if (!canSafelyConvertTo16Bit(*Bias, HasSampler))
      OnlyDerivatives = true;
  }

  if (OnlyDerivatives && (!ST->hasG16() || ImageDimIntr->GradientStart ==
                                               ImageDimIntr->CoordStart))
    return std::nullopt;

  Type *CoordType = FloatCoord ? Type::getHalfTy(II.getContext())
                               : Type::getInt16Ty(II.getContext());

  return modifyIntrinsicCall(
      II, II, II.getIntrinsicID(), IC, [&](auto &Args, auto &ArgTys) {
        ArgTys[ImageDimIntr->GradientTyArg] = CoordType;
        if (!OnlyDerivatives) {
          ArgTys[ImageDimIntr->CoordTyArg] = CoordType;

          // Change the bias type
          if (ImageDimIntr->NumBiasArgs != 0)
            ArgTys[ImageDimIntr->BiasTyArg] = Type::getHalfTy(II.getContext());
        }

        unsigned EndIndex =
            OnlyDerivatives ? ImageDimIntr->CoordStart : ImageDimIntr->VAddrEnd;
        for (unsigned OperandIndex = ImageDimIntr->GradientStart;
             OperandIndex < EndIndex; OperandIndex++) {
          Args[OperandIndex] =
              convertTo16Bit(*II.getOperand(OperandIndex), IC.Builder);
        }

        // Convert the bias
        if (!OnlyDerivatives && ImageDimIntr->NumBiasArgs != 0) {
          Value *Bias = II.getOperand(ImageDimIntr->BiasIndex);
          Args[ImageDimIntr->BiasIndex] = convertTo16Bit(*Bias, IC.Builder);
        }
      });
}

bool GCNTTIImpl::canSimplifyLegacyMulToMul(const Instruction &I,
                                           const Value *Op0, const Value *Op1,
                                           InstCombiner &IC) const {
  // The legacy behaviour is that multiplying +/-0.0 by anything, even NaN or
  // infinity, gives +0.0. If we can prove we don't have one of the special
  // cases then we can use a normal multiply instead.
  // TODO: Create and use isKnownFiniteNonZero instead of just matching
  // constants here.
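  // e.g. v_mul_legacy_f32(+0.0, nan) yields +0.0, while a plain fmul would
  // yield nan, so the rewrite is only valid once such inputs are ruled out.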
  if (match(Op0, PatternMatch::m_FiniteNonZero()) ||
      match(Op1, PatternMatch::m_FiniteNonZero())) {
    // One operand is not zero or infinity or NaN.
    return true;
  }

  SimplifyQuery SQ = IC.getSimplifyQuery().getWithInstruction(&I);
  if (isKnownNeverInfOrNaN(Op0, SQ) && isKnownNeverInfOrNaN(Op1, SQ)) {
    // Neither operand is infinity or NaN.
    return true;
  }
  return false;
}

/// Match an fpext from half to float, or a constant we can convert.
static Value *matchFPExtFromF16(Value *Arg) {
  Value *Src = nullptr;
  ConstantFP *CFP = nullptr;
  if (match(Arg, m_OneUse(m_FPExt(m_Value(Src))))) {
    if (Src->getType()->isHalfTy())
      return Src;
  } else if (match(Arg, m_ConstantFP(CFP))) {
    bool LosesInfo;
    APFloat Val(CFP->getValueAPF());
    Val.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven, &LosesInfo);
    if (!LosesInfo)
      return ConstantFP::get(Type::getHalfTy(Arg->getContext()), Val);
  }
  return nullptr;
}

// Trim all zero components from the end of the vector \p UseV and return
// an appropriate bitset with known elements.
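// For example, <4 x float> <x, y, 0.0, 0.0> yields DemandedElts = 0b0011;
// element 0 is never trimmed since the loop stops before index 0.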
static APInt trimTrailingZerosInVector(InstCombiner &IC, Value *UseV,
                                       Instruction *I) {
  auto *VTy = cast<FixedVectorType>(UseV->getType());
  unsigned VWidth = VTy->getNumElements();
  APInt DemandedElts = APInt::getAllOnes(VWidth);

  for (int i = VWidth - 1; i > 0; --i) {
    auto *Elt = findScalarElement(UseV, i);
    if (!Elt)
      break;

    if (auto *ConstElt = dyn_cast<Constant>(Elt)) {
      if (!ConstElt->isNullValue() && !isa<UndefValue>(Elt))
        break;
    } else {
      break;
    }

    DemandedElts.clearBit(i);
  }

  return DemandedElts;
}

// Trim elements from the end of the vector \p V if they are
// equal to the first element of the vector.
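// For example, <4 x float> <x, x, x, x> yields DemandedElts = 0b0001, since
// the trailing lanes merely repeat element 0.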
static APInt defaultComponentBroadcast(Value *V) {
  auto *VTy = cast<FixedVectorType>(V->getType());
  unsigned VWidth = VTy->getNumElements();
  APInt DemandedElts = APInt::getAllOnes(VWidth);
  Value *FirstComponent = findScalarElement(V, 0);

  SmallVector<int> ShuffleMask;
  if (auto *SVI = dyn_cast<ShuffleVectorInst>(V))
    SVI->getShuffleMask(ShuffleMask);

  for (int I = VWidth - 1; I > 0; --I) {
    if (ShuffleMask.empty()) {
      auto *Elt = findScalarElement(V, I);
      if (!Elt || (Elt != FirstComponent && !isa<UndefValue>(Elt)))
        break;
    } else {
      // Detect identical elements in the shufflevector result, even though
      // findScalarElement cannot tell us what that element is.
      if (ShuffleMask[I] != ShuffleMask[0] && ShuffleMask[I] != PoisonMaskElem)
        break;
    }
    DemandedElts.clearBit(I);
  }

  return DemandedElts;
}

static Value *simplifyAMDGCNMemoryIntrinsicDemanded(InstCombiner &IC,
                                                    IntrinsicInst &II,
                                                    APInt DemandedElts,
                                                    int DMaskIdx = -1,
                                                    bool IsLoad = true);

/// Return true if it's legal to contract llvm.amdgcn.rcp(llvm.sqrt)
static bool canContractSqrtToRsq(const FPMathOperator *SqrtOp) {
  return (SqrtOp->getType()->isFloatTy() &&
          (SqrtOp->hasApproxFunc() || SqrtOp->getFPAccuracy() >= 1.0f)) ||
         SqrtOp->getType()->isHalfTy();
}

/// Return true if we can easily prove that use U is uniform.
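/// Constants, kernel arguments passed in SGPRs, and same-block results of
/// always-uniform intrinsics (e.g. llvm.amdgcn.readfirstlane) qualify.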
static bool isTriviallyUniform(const Use &U) {
  Value *V = U.get();
  if (isa<Constant>(V))
    return true;
  if (const auto *A = dyn_cast<Argument>(V))
    return AMDGPU::isArgPassedInSGPR(A);
  if (const auto *II = dyn_cast<IntrinsicInst>(V)) {
    if (!AMDGPU::isIntrinsicAlwaysUniform(II->getIntrinsicID()))
      return false;
    // If II and U are in different blocks then there is a possibility of
    // temporal divergence.
    return II->getParent() == cast<Instruction>(U.getUser())->getParent();
  }
  return false;
}

/// Simplify a lane index operand (e.g. llvm.amdgcn.readlane src1).
///
/// The instruction only reads the low 5 bits for wave32, and 6 bits for wave64.
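/// For example, on wave32 only 5 bits are demanded, so a constant lane index
/// of 33 (0b100001) is equivalent to lane 1.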
bool GCNTTIImpl::simplifyDemandedLaneMaskArg(InstCombiner &IC,
                                             IntrinsicInst &II,
                                             unsigned LaneArgIdx) const {
  unsigned MaskBits = ST->getWavefrontSizeLog2();
  APInt DemandedMask(32, maskTrailingOnes<unsigned>(MaskBits));

  KnownBits Known(32);
  if (IC.SimplifyDemandedBits(&II, LaneArgIdx, DemandedMask, Known))
    return true;

  if (!Known.isConstant())
    return false;

  // Out of bounds indexes may appear in wave64 code compiled for wave32.
  // Unlike the DAG version, SimplifyDemandedBits does not change constants, so
  // manually fix it up.
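  // e.g. a known-constant lane of 63 in code running at wave32 is rewritten
  // to 63 & 31 == 31 here, matching the instruction's masking behavior.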

  Value *LaneArg = II.getArgOperand(LaneArgIdx);
  Constant *MaskedConst =
      ConstantInt::get(LaneArg->getType(), Known.getConstant() & DemandedMask);
  if (MaskedConst != LaneArg) {
    II.getOperandUse(LaneArgIdx).set(MaskedConst);
    return true;
  }

  return false;
}

static CallInst *rewriteCall(IRBuilderBase &B, CallInst &Old,
                             Function &NewCallee, ArrayRef<Value *> Ops) {
  SmallVector<OperandBundleDef, 2> OpBundles;
  Old.getOperandBundlesAsDefs(OpBundles);

  CallInst *NewCall = B.CreateCall(&NewCallee, Ops, OpBundles);
  NewCall->takeName(&Old);
  return NewCall;
}

Instruction *
GCNTTIImpl::hoistLaneIntrinsicThroughOperand(InstCombiner &IC,
                                             IntrinsicInst &II) const {
  const auto IID = II.getIntrinsicID();
  assert(IID == Intrinsic::amdgcn_readlane ||
         IID == Intrinsic::amdgcn_readfirstlane ||
         IID == Intrinsic::amdgcn_permlane64);

  Instruction *OpInst = dyn_cast<Instruction>(II.getOperand(0));

  // Only do this if both instructions are in the same block
  // (so the exec mask won't change) and the readlane is the only user of its
  // operand.
  if (!OpInst || !OpInst->hasOneUser() || OpInst->getParent() != II.getParent())
    return nullptr;

  const bool IsReadLane = (IID == Intrinsic::amdgcn_readlane);

  // If this is a readlane, check that the second operand is a constant, or is
  // defined before OpInst so we know it's safe to move this intrinsic higher.
  Value *LaneID = nullptr;
  if (IsReadLane) {
    LaneID = II.getOperand(1);

    // readlane takes an extra operand for the lane ID, so we must check if
    // that LaneID value can be used at the point where we want to move the
    // intrinsic.
    if (auto *LaneIDInst = dyn_cast<Instruction>(LaneID)) {
      if (!IC.getDominatorTree().dominates(LaneIDInst, OpInst))
        return nullptr;
    }
  }

  // Hoist the intrinsic (II) through OpInst.
  //
  // (II (OpInst x)) -> (OpInst (II x))
  const auto DoIt = [&](unsigned OpIdx,
                        Function *NewIntrinsic) -> Instruction * {
    SmallVector<Value *, 2> Ops{OpInst->getOperand(OpIdx)};
    if (IsReadLane)
      Ops.push_back(LaneID);

    // Rewrite the intrinsic call.
    CallInst *NewII = rewriteCall(IC.Builder, II, *NewIntrinsic, Ops);

    // Rewrite OpInst so it takes the result of the intrinsic now.
    Instruction &NewOp = *OpInst->clone();
    NewOp.setOperand(OpIdx, NewII);
    return &NewOp;
  };

  // TODO(?): Should we do more with permlane64?
  if (IID == Intrinsic::amdgcn_permlane64 && !isa<BitCastInst>(OpInst))
    return nullptr;

  if (isa<UnaryOperator>(OpInst))
    return DoIt(0, II.getCalledFunction());

  if (isa<CastInst>(OpInst)) {
    Value *Src = OpInst->getOperand(0);
    Type *SrcTy = Src->getType();
    if (!isTypeLegal(SrcTy))
      return nullptr;

    Function *Remangled =
        Intrinsic::getOrInsertDeclaration(II.getModule(), IID, {SrcTy});
    return DoIt(0, Remangled);
  }

  // We can also hoist through binary operators if the other operand is uniform.
  if (isa<BinaryOperator>(OpInst)) {
    // FIXME: If we had access to UniformityInfo here we could just check
    // if the operand is uniform.
    if (isTriviallyUniform(OpInst->getOperandUse(0)))
      return DoIt(1, II.getCalledFunction());
    if (isTriviallyUniform(OpInst->getOperandUse(1)))
      return DoIt(0, II.getCalledFunction());
  }

  return nullptr;
}

std::optional<Instruction *>
GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
  Intrinsic::ID IID = II.getIntrinsicID();
  switch (IID) {
  case Intrinsic::amdgcn_rcp: {
    Value *Src = II.getArgOperand(0);
    if (isa<PoisonValue>(Src))
      return IC.replaceInstUsesWith(II, Src);

    // TODO: Move to ConstantFolding/InstSimplify?
    if (isa<UndefValue>(Src)) {
      Type *Ty = II.getType();
      auto *QNaN = ConstantFP::get(Ty, APFloat::getQNaN(Ty->getFltSemantics()));
      return IC.replaceInstUsesWith(II, QNaN);
    }

    if (II.isStrictFP())
      break;

    if (const ConstantFP *C = dyn_cast<ConstantFP>(Src)) {
      const APFloat &ArgVal = C->getValueAPF();
      APFloat Val(ArgVal.getSemantics(), 1);
      Val.divide(ArgVal, APFloat::rmNearestTiesToEven);

      // This is more precise than the instruction may give.
      //
      // TODO: The instruction always flushes denormal results (except for f16),
      // should this also?
      return IC.replaceInstUsesWith(II, ConstantFP::get(II.getContext(), Val));
    }

    FastMathFlags FMF = cast<FPMathOperator>(II).getFastMathFlags();
    if (!FMF.allowContract())
      break;
    auto *SrcCI = dyn_cast<IntrinsicInst>(Src);
    if (!SrcCI)
      break;

    auto IID = SrcCI->getIntrinsicID();
    // llvm.amdgcn.rcp(llvm.amdgcn.sqrt(x)) -> llvm.amdgcn.rsq(x) if contractable
    //
    // llvm.amdgcn.rcp(llvm.sqrt(x)) -> llvm.amdgcn.rsq(x) if contractable and
    // relaxed.
    if (IID == Intrinsic::amdgcn_sqrt || IID == Intrinsic::sqrt) {
      const FPMathOperator *SqrtOp = cast<FPMathOperator>(SrcCI);
      FastMathFlags InnerFMF = SqrtOp->getFastMathFlags();
      if (!InnerFMF.allowContract() || !SrcCI->hasOneUse())
        break;

      if (IID == Intrinsic::sqrt && !canContractSqrtToRsq(SqrtOp))
        break;

      Function *NewDecl = Intrinsic::getOrInsertDeclaration(
          SrcCI->getModule(), Intrinsic::amdgcn_rsq, {SrcCI->getType()});

      InnerFMF |= FMF;
      II.setFastMathFlags(InnerFMF);

      II.setCalledFunction(NewDecl);
      return IC.replaceOperand(II, 0, SrcCI->getArgOperand(0));
    }

    break;
  }
  case Intrinsic::amdgcn_sqrt:
  case Intrinsic::amdgcn_rsq:
  case Intrinsic::amdgcn_tanh: {
    Value *Src = II.getArgOperand(0);
    if (isa<PoisonValue>(Src))
      return IC.replaceInstUsesWith(II, Src);

    // TODO: Move to ConstantFolding/InstSimplify?
    if (isa<UndefValue>(Src)) {
      Type *Ty = II.getType();
      auto *QNaN = ConstantFP::get(Ty, APFloat::getQNaN(Ty->getFltSemantics()));
      return IC.replaceInstUsesWith(II, QNaN);
    }

    // f16 amdgcn.sqrt is identical to regular sqrt.
    if (IID == Intrinsic::amdgcn_sqrt && Src->getType()->isHalfTy()) {
      Function *NewDecl = Intrinsic::getOrInsertDeclaration(
          II.getModule(), Intrinsic::sqrt, {II.getType()});
      II.setCalledFunction(NewDecl);
      return &II;
    }

    break;
  }
  case Intrinsic::amdgcn_log:
  case Intrinsic::amdgcn_exp2: {
    const bool IsLog = IID == Intrinsic::amdgcn_log;
    const bool IsExp = IID == Intrinsic::amdgcn_exp2;
    Value *Src = II.getArgOperand(0);
    Type *Ty = II.getType();

    if (isa<PoisonValue>(Src))
      return IC.replaceInstUsesWith(II, Src);

    if (IC.getSimplifyQuery().isUndefValue(Src))
      return IC.replaceInstUsesWith(II, ConstantFP::getNaN(Ty));

    if (ConstantFP *C = dyn_cast<ConstantFP>(Src)) {
      if (C->isInfinity()) {
        // exp2(+inf) -> +inf
        // log2(+inf) -> +inf
        if (!C->isNegative())
          return IC.replaceInstUsesWith(II, C);

        // exp2(-inf) -> 0
        if (IsExp && C->isNegative())
          return IC.replaceInstUsesWith(II, ConstantFP::getZero(Ty));
      }

      if (II.isStrictFP())
        break;

      if (C->isNaN()) {
        Constant *Quieted = ConstantFP::get(Ty, C->getValue().makeQuiet());
        return IC.replaceInstUsesWith(II, Quieted);
      }

      // f32 instruction doesn't handle denormals, f16 does.
      if (C->isZero() || (C->getValue().isDenormal() && Ty->isFloatTy())) {
        Constant *FoldedValue = IsLog ? ConstantFP::getInfinity(Ty, true)
                                      : ConstantFP::get(Ty, 1.0);
        return IC.replaceInstUsesWith(II, FoldedValue);
      }

      if (IsLog && C->isNegative())
        return IC.replaceInstUsesWith(II, ConstantFP::getNaN(Ty));

      // TODO: Full constant folding matching hardware behavior.
    }

    break;
  }
  case Intrinsic::amdgcn_frexp_mant:
  case Intrinsic::amdgcn_frexp_exp: {
    Value *Src = II.getArgOperand(0);
    if (const ConstantFP *C = dyn_cast<ConstantFP>(Src)) {
      int Exp;
      APFloat Significand =
          frexp(C->getValueAPF(), Exp, APFloat::rmNearestTiesToEven);

      if (IID == Intrinsic::amdgcn_frexp_mant) {
        return IC.replaceInstUsesWith(
            II, ConstantFP::get(II.getContext(), Significand));
      }

      // Match instruction special case behavior.
      if (Exp == APFloat::IEK_NaN || Exp == APFloat::IEK_Inf)
        Exp = 0;

      return IC.replaceInstUsesWith(II, ConstantInt::get(II.getType(), Exp));
    }

    if (isa<PoisonValue>(Src))
      return IC.replaceInstUsesWith(II, PoisonValue::get(II.getType()));

    if (isa<UndefValue>(Src)) {
      return IC.replaceInstUsesWith(II, UndefValue::get(II.getType()));
    }

    break;
  }
  case Intrinsic::amdgcn_class: {
    Value *Src0 = II.getArgOperand(0);
    Value *Src1 = II.getArgOperand(1);
    const ConstantInt *CMask = dyn_cast<ConstantInt>(Src1);
    if (CMask) {
      II.setCalledOperand(Intrinsic::getOrInsertDeclaration(
          II.getModule(), Intrinsic::is_fpclass, Src0->getType()));

      // Clamp any excess bits, as they're illegal for the generic intrinsic.
      II.setArgOperand(1, ConstantInt::get(Src1->getType(),
                                           CMask->getZExtValue() & fcAllFlags));
      return &II;
    }

    // Propagate poison.
    if (isa<PoisonValue>(Src0) || isa<PoisonValue>(Src1))
      return IC.replaceInstUsesWith(II, PoisonValue::get(II.getType()));

    // llvm.amdgcn.class(_, undef) -> false
    if (IC.getSimplifyQuery().isUndefValue(Src1))
      return IC.replaceInstUsesWith(II, ConstantInt::get(II.getType(), false));

    // llvm.amdgcn.class(undef, mask) -> mask != 0
    if (IC.getSimplifyQuery().isUndefValue(Src0)) {
      Value *CmpMask = IC.Builder.CreateICmpNE(
          Src1, ConstantInt::getNullValue(Src1->getType()));
      return IC.replaceInstUsesWith(II, CmpMask);
    }
    break;
  }
  case Intrinsic::amdgcn_cvt_pkrtz: {
    auto foldFPTruncToF16RTZ = [](Value *Arg) -> Value * {
      Type *HalfTy = Type::getHalfTy(Arg->getContext());

      if (isa<PoisonValue>(Arg))
        return PoisonValue::get(HalfTy);
      if (isa<UndefValue>(Arg))
        return UndefValue::get(HalfTy);

      ConstantFP *CFP = nullptr;
      if (match(Arg, m_ConstantFP(CFP))) {
        bool LosesInfo;
        APFloat Val(CFP->getValueAPF());
        Val.convert(APFloat::IEEEhalf(), APFloat::rmTowardZero, &LosesInfo);
        return ConstantFP::get(HalfTy, Val);
      }

      Value *Src = nullptr;
      if (match(Arg, m_FPExt(m_Value(Src)))) {
        if (Src->getType()->isHalfTy())
          return Src;
      }

      return nullptr;
    };

    if (Value *Src0 = foldFPTruncToF16RTZ(II.getArgOperand(0))) {
      if (Value *Src1 = foldFPTruncToF16RTZ(II.getArgOperand(1))) {
        Value *V = PoisonValue::get(II.getType());
        V = IC.Builder.CreateInsertElement(V, Src0, (uint64_t)0);
        V = IC.Builder.CreateInsertElement(V, Src1, (uint64_t)1);
        return IC.replaceInstUsesWith(II, V);
      }
    }

    break;
  }
  case Intrinsic::amdgcn_cvt_pknorm_i16:
  case Intrinsic::amdgcn_cvt_pknorm_u16:
  case Intrinsic::amdgcn_cvt_pk_i16:
  case Intrinsic::amdgcn_cvt_pk_u16: {
    Value *Src0 = II.getArgOperand(0);
    Value *Src1 = II.getArgOperand(1);

    // TODO: Replace call with scalar operation if only one element is poison.
    if (isa<PoisonValue>(Src0) && isa<PoisonValue>(Src1))
      return IC.replaceInstUsesWith(II, PoisonValue::get(II.getType()));

    if (isa<UndefValue>(Src0) && isa<UndefValue>(Src1)) {
      return IC.replaceInstUsesWith(II, UndefValue::get(II.getType()));
    }

    break;
  }
  case Intrinsic::amdgcn_cvt_off_f32_i4: {
    Value *Arg = II.getArgOperand(0);
    Type *Ty = II.getType();

    if (isa<PoisonValue>(Arg))
      return IC.replaceInstUsesWith(II, PoisonValue::get(Ty));

    if (IC.getSimplifyQuery().isUndefValue(Arg))
      return IC.replaceInstUsesWith(II, Constant::getNullValue(Ty));

    ConstantInt *CArg = dyn_cast<ConstantInt>(II.getArgOperand(0));
    if (!CArg)
      break;

    // Tabulated 0.0625 * (sext (CArg & 0xf)).
    constexpr size_t ResValsSize = 16;
    static constexpr float ResVals[ResValsSize] = {
        0.0,  0.0625,  0.125,  0.1875,  0.25,  0.3125,  0.375,  0.4375,
        -0.5, -0.4375, -0.375, -0.3125, -0.25, -0.1875, -0.125, -0.0625};
    Constant *Res =
        ConstantFP::get(Ty, ResVals[CArg->getZExtValue() & (ResValsSize - 1)]);
    return IC.replaceInstUsesWith(II, Res);
  }
  case Intrinsic::amdgcn_ubfe:
  case Intrinsic::amdgcn_sbfe: {
    // Decompose simple cases into standard shifts.
    Value *Src = II.getArgOperand(0);
    if (isa<UndefValue>(Src)) {
      return IC.replaceInstUsesWith(II, Src);
    }

    unsigned Width;
    Type *Ty = II.getType();
    unsigned IntSize = Ty->getIntegerBitWidth();

    ConstantInt *CWidth = dyn_cast<ConstantInt>(II.getArgOperand(2));
    if (CWidth) {
      Width = CWidth->getZExtValue();
      if ((Width & (IntSize - 1)) == 0) {
        return IC.replaceInstUsesWith(II, ConstantInt::getNullValue(Ty));
      }

      // Hardware ignores high bits, so remove those.
931         return IC.replaceOperand(
932             II, 2, ConstantInt::get(CWidth->getType(), Width & (IntSize - 1)));
933       }
934     }
935 
936     unsigned Offset;
937     ConstantInt *COffset = dyn_cast<ConstantInt>(II.getArgOperand(1));
938     if (COffset) {
939       Offset = COffset->getZExtValue();
940       if (Offset >= IntSize) {
941         return IC.replaceOperand(
942             II, 1,
943             ConstantInt::get(COffset->getType(), Offset & (IntSize - 1)));
944       }
945     }
946 
947     bool Signed = IID == Intrinsic::amdgcn_sbfe;
948 
949     if (!CWidth || !COffset)
950       break;
951 
952     // The case of Width == 0 is handled above, which makes this transformation
953     // safe.  If Width == 0, then the ashr and lshr instructions become poison
954     // value since the shift amount would be equal to the bit size.
955     assert(Width != 0);
956 
957     // TODO: This allows folding to undef when the hardware has specific
958     // behavior?
959     if (Offset + Width < IntSize) {
960       Value *Shl = IC.Builder.CreateShl(Src, IntSize - Offset - Width);
961       Value *RightShift = Signed ? IC.Builder.CreateAShr(Shl, IntSize - Width)
962                                  : IC.Builder.CreateLShr(Shl, IntSize - Width);
963       RightShift->takeName(&II);
964       return IC.replaceInstUsesWith(II, RightShift);
965     }
966 
967     Value *RightShift = Signed ? IC.Builder.CreateAShr(Src, Offset)
968                                : IC.Builder.CreateLShr(Src, Offset);
969 
970     RightShift->takeName(&II);
971     return IC.replaceInstUsesWith(II, RightShift);
972   }
973   case Intrinsic::amdgcn_exp:
974   case Intrinsic::amdgcn_exp_row:
975   case Intrinsic::amdgcn_exp_compr: {
976     ConstantInt *En = cast<ConstantInt>(II.getArgOperand(1));
977     unsigned EnBits = En->getZExtValue();
978     if (EnBits == 0xf)
979       break; // All inputs enabled.
980 
981     bool IsCompr = IID == Intrinsic::amdgcn_exp_compr;
982     bool Changed = false;
983     for (int I = 0; I < (IsCompr ? 2 : 4); ++I) {
984       if ((!IsCompr && (EnBits & (1 << I)) == 0) ||
985           (IsCompr && ((EnBits & (0x3 << (2 * I))) == 0))) {
986         Value *Src = II.getArgOperand(I + 2);
987         if (!isa<PoisonValue>(Src)) {
988           IC.replaceOperand(II, I + 2, PoisonValue::get(Src->getType()));
989           Changed = true;
990         }
991       }
992     }
993 
994     if (Changed) {
995       return &II;
996     }
997 
998     break;
999   }
1000   case Intrinsic::amdgcn_fmed3: {
1001     Value *Src0 = II.getArgOperand(0);
1002     Value *Src1 = II.getArgOperand(1);
1003     Value *Src2 = II.getArgOperand(2);
1004 
1005     for (Value *Src : {Src0, Src1, Src2}) {
1006       if (isa<PoisonValue>(Src))
1007         return IC.replaceInstUsesWith(II, Src);
1008     }
1009 
1010     if (II.isStrictFP())
1011       break;
1012 
1013     // med3 with a nan input acts like
1014     // v_min_f32(v_min_f32(s0, s1), s2)
1015     //
1016     // Signalingness is ignored with ieee=0, so we fold to
1017     // minimumnum/maximumnum. With ieee=1, the v_min_f32 acts like llvm.minnum
1018     // with signaling nan handling. With ieee=0, like llvm.minimumnum except a
1019     // returned signaling nan will not be quieted.
1020 
1021     // ieee=1
1022     // s0 snan: s2
1023     // s1 snan: s2
1024     // s2 snan: qnan
1025 
1026     // s0 qnan: min(s1, s2)
1027     // s1 qnan: min(s0, s2)
1028     // s2 qnan: min(s0, s1)
1029 
1030     // ieee=0
1031     // s0 _nan: min(s1, s2)
1032     // s1 _nan: min(s0, s2)
1033     // s2 _nan: min(s0, s1)
1034 
1035     // med3 behavior with infinity
1036     // s0 +inf: max(s1, s2)
1037     // s1 +inf: max(s0, s2)
1038     // s2 +inf: max(s0, s1)
1039     // s0 -inf: min(s1, s2)
1040     // s1 -inf: min(s0, s2)
1041     // s2 -inf: min(s0, s1)
1042 
1043     // Checking for NaN before canonicalization provides better fidelity when
1044     // mapping other operations onto fmed3 since the order of operands is
1045     // unchanged.
1046     Value *V = nullptr;
1047     const APFloat *ConstSrc0 = nullptr;
1048     const APFloat *ConstSrc1 = nullptr;
1049     const APFloat *ConstSrc2 = nullptr;
1050 
1051     if ((match(Src0, m_APFloat(ConstSrc0)) &&
1052          (ConstSrc0->isNaN() || ConstSrc0->isInfinity())) ||
1053         isa<UndefValue>(Src0)) {
1054       const bool IsPosInfinity = ConstSrc0 && ConstSrc0->isPosInfinity();
1055       switch (fpenvIEEEMode(II)) {
1056       case KnownIEEEMode::On:
1057         // TODO: If Src2 is snan, does it need quieting?
1058         if (ConstSrc0 && ConstSrc0->isNaN() && ConstSrc0->isSignaling())
1059           return IC.replaceInstUsesWith(II, Src2);
1060 
1061         V = IsPosInfinity ? IC.Builder.CreateMaxNum(Src1, Src2)
1062                           : IC.Builder.CreateMinNum(Src1, Src2);
1063         break;
1064       case KnownIEEEMode::Off:
1065         V = IsPosInfinity ? IC.Builder.CreateMaximumNum(Src1, Src2)
1066                           : IC.Builder.CreateMinimumNum(Src1, Src2);
1067         break;
1068       case KnownIEEEMode::Unknown:
1069         break;
1070       }
1071     } else if ((match(Src1, m_APFloat(ConstSrc1)) &&
1072                 (ConstSrc1->isNaN() || ConstSrc1->isInfinity())) ||
1073                isa<UndefValue>(Src1)) {
1074       const bool IsPosInfinity = ConstSrc1 && ConstSrc1->isPosInfinity();
1075       switch (fpenvIEEEMode(II)) {
1076       case KnownIEEEMode::On:
1077         // TODO: If Src2 is snan, does it need quieting?
1078         if (ConstSrc1 && ConstSrc1->isNaN() && ConstSrc1->isSignaling())
1079           return IC.replaceInstUsesWith(II, Src2);
1080 
1081         V = IsPosInfinity ? IC.Builder.CreateMaxNum(Src0, Src2)
1082                           : IC.Builder.CreateMinNum(Src0, Src2);
1083         break;
1084       case KnownIEEEMode::Off:
1085         V = IsPosInfinity ? IC.Builder.CreateMaximumNum(Src0, Src2)
1086                           : IC.Builder.CreateMinimumNum(Src0, Src2);
1087         break;
1088       case KnownIEEEMode::Unknown:
1089         break;
1090       }
1091     } else if ((match(Src2, m_APFloat(ConstSrc2)) &&
1092                 (ConstSrc2->isNaN() || ConstSrc2->isInfinity())) ||
1093                isa<UndefValue>(Src2)) {
1094       switch (fpenvIEEEMode(II)) {
1095       case KnownIEEEMode::On:
1096         if (ConstSrc2 && ConstSrc2->isNaN() && ConstSrc2->isSignaling()) {
1097           auto *Quieted = ConstantFP::get(II.getType(), ConstSrc2->makeQuiet());
1098           return IC.replaceInstUsesWith(II, Quieted);
1099         }
1100 
1101         V = (ConstSrc2 && ConstSrc2->isPosInfinity())
1102                 ? IC.Builder.CreateMaxNum(Src0, Src1)
1103                 : IC.Builder.CreateMinNum(Src0, Src1);
1104         break;
1105       case KnownIEEEMode::Off:
1106         V = (ConstSrc2 && ConstSrc2->isNegInfinity())
1107                 ? IC.Builder.CreateMinimumNum(Src0, Src1)
1108                 : IC.Builder.CreateMaximumNum(Src0, Src1);
1109         break;
1110       case KnownIEEEMode::Unknown:
1111         break;
1112       }
1113     }
1114 
1115     if (V) {
1116       if (auto *CI = dyn_cast<CallInst>(V)) {
1117         CI->copyFastMathFlags(&II);
1118         CI->takeName(&II);
1119       }
1120       return IC.replaceInstUsesWith(II, V);
1121     }
1122 
1123     bool Swap = false;
1124     // Canonicalize constants to RHS operands.
1125     //
1126     // fmed3(c0, x, c1) -> fmed3(x, c0, c1)
1127     if (isa<Constant>(Src0) && !isa<Constant>(Src1)) {
1128       std::swap(Src0, Src1);
1129       Swap = true;
1130     }
1131 
1132     if (isa<Constant>(Src1) && !isa<Constant>(Src2)) {
1133       std::swap(Src1, Src2);
1134       Swap = true;
1135     }
1136 
1137     if (isa<Constant>(Src0) && !isa<Constant>(Src1)) {
1138       std::swap(Src0, Src1);
1139       Swap = true;
1140     }
1141 
1142     if (Swap) {
1143       II.setArgOperand(0, Src0);
1144       II.setArgOperand(1, Src1);
1145       II.setArgOperand(2, Src2);
1146       return &II;
1147     }
1148 
1149     if (const ConstantFP *C0 = dyn_cast<ConstantFP>(Src0)) {
1150       if (const ConstantFP *C1 = dyn_cast<ConstantFP>(Src1)) {
1151         if (const ConstantFP *C2 = dyn_cast<ConstantFP>(Src2)) {
1152           APFloat Result = fmed3AMDGCN(C0->getValueAPF(), C1->getValueAPF(),
1153                                        C2->getValueAPF());
1154           return IC.replaceInstUsesWith(II,
1155                                         ConstantFP::get(II.getType(), Result));
1156         }
1157       }
1158     }
1159 
1160     if (!ST->hasMed3_16())
1161       break;
1162 
1163     // Repeat floating-point width reduction done for minnum/maxnum.
1164     // fmed3((fpext X), (fpext Y), (fpext Z)) -> fpext (fmed3(X, Y, Z))
1165     if (Value *X = matchFPExtFromF16(Src0)) {
1166       if (Value *Y = matchFPExtFromF16(Src1)) {
1167         if (Value *Z = matchFPExtFromF16(Src2)) {
1168           Value *NewCall = IC.Builder.CreateIntrinsic(
1169               IID, {X->getType()}, {X, Y, Z}, &II, II.getName());
1170           return new FPExtInst(NewCall, II.getType());
1171         }
1172       }
1173     }
1174 
1175     break;
1176   }
1177   case Intrinsic::amdgcn_icmp:
1178   case Intrinsic::amdgcn_fcmp: {
1179     const ConstantInt *CC = cast<ConstantInt>(II.getArgOperand(2));
1180     // Guard against invalid arguments.
1181     int64_t CCVal = CC->getZExtValue();
1182     bool IsInteger = IID == Intrinsic::amdgcn_icmp;
1183     if ((IsInteger && (CCVal < CmpInst::FIRST_ICMP_PREDICATE ||
1184                        CCVal > CmpInst::LAST_ICMP_PREDICATE)) ||
1185         (!IsInteger && (CCVal < CmpInst::FIRST_FCMP_PREDICATE ||
1186                         CCVal > CmpInst::LAST_FCMP_PREDICATE)))
1187       break;
1188 
1189     Value *Src0 = II.getArgOperand(0);
1190     Value *Src1 = II.getArgOperand(1);
1191 
1192     if (auto *CSrc0 = dyn_cast<Constant>(Src0)) {
1193       if (auto *CSrc1 = dyn_cast<Constant>(Src1)) {
1194         Constant *CCmp = ConstantFoldCompareInstOperands(
1195             (ICmpInst::Predicate)CCVal, CSrc0, CSrc1, DL);
1196         if (CCmp && CCmp->isNullValue()) {
1197           return IC.replaceInstUsesWith(
1198               II, IC.Builder.CreateSExt(CCmp, II.getType()));
1199         }
1200 
1201         // The result of V_ICMP/V_FCMP assembly instructions (which this
1202         // intrinsic exposes) is one bit per thread, masked with the EXEC
1203         // register (which contains the bitmask of live threads). So a
1204         // comparison that always returns true is the same as a read of the
1205         // EXEC register.
1206         Metadata *MDArgs[] = {MDString::get(II.getContext(), "exec")};
1207         MDNode *MD = MDNode::get(II.getContext(), MDArgs);
1208         Value *Args[] = {MetadataAsValue::get(II.getContext(), MD)};
1209         CallInst *NewCall = IC.Builder.CreateIntrinsic(Intrinsic::read_register,
1210                                                        II.getType(), Args);
1211         NewCall->addFnAttr(Attribute::Convergent);
1212         NewCall->takeName(&II);
1213         return IC.replaceInstUsesWith(II, NewCall);
1214       }
1215 
1216       // Canonicalize constants to RHS.
1217       CmpInst::Predicate SwapPred =
1218           CmpInst::getSwappedPredicate(static_cast<CmpInst::Predicate>(CCVal));
1219       II.setArgOperand(0, Src1);
1220       II.setArgOperand(1, Src0);
1221       II.setArgOperand(
1222           2, ConstantInt::get(CC->getType(), static_cast<int>(SwapPred)));
1223       return &II;
1224     }
1225 
1226     if (CCVal != CmpInst::ICMP_EQ && CCVal != CmpInst::ICMP_NE)
1227       break;
1228 
1229     // Canonicalize compare eq with true value to compare != 0
1230     // llvm.amdgcn.icmp(zext (i1 x), 1, eq)
1231     //   -> llvm.amdgcn.icmp(zext (i1 x), 0, ne)
1232     // llvm.amdgcn.icmp(sext (i1 x), -1, eq)
1233     //   -> llvm.amdgcn.icmp(sext (i1 x), 0, ne)
1234     Value *ExtSrc;
1235     if (CCVal == CmpInst::ICMP_EQ &&
1236         ((match(Src1, PatternMatch::m_One()) &&
1237           match(Src0, m_ZExt(PatternMatch::m_Value(ExtSrc)))) ||
1238          (match(Src1, PatternMatch::m_AllOnes()) &&
1239           match(Src0, m_SExt(PatternMatch::m_Value(ExtSrc))))) &&
1240         ExtSrc->getType()->isIntegerTy(1)) {
1241       IC.replaceOperand(II, 1, ConstantInt::getNullValue(Src1->getType()));
1242       IC.replaceOperand(II, 2,
1243                         ConstantInt::get(CC->getType(), CmpInst::ICMP_NE));
1244       return &II;
1245     }
1246 
1247     CmpPredicate SrcPred;
1248     Value *SrcLHS;
1249     Value *SrcRHS;
1250 
1251     // Fold compare eq/ne with 0 from a compare result as the predicate to the
1252     // intrinsic. The typical use is a wave vote function in the library, which
1253     // will be fed from a user code condition compared with 0. Fold in the
1254     // redundant compare.
1255 
1256     // llvm.amdgcn.icmp([sz]ext ([if]cmp pred a, b), 0, ne)
1257     //   -> llvm.amdgcn.[if]cmp(a, b, pred)
1258     //
1259     // llvm.amdgcn.icmp([sz]ext ([if]cmp pred a, b), 0, eq)
1260     //   -> llvm.amdgcn.[if]cmp(a, b, inv pred)
1261     if (match(Src1, PatternMatch::m_Zero()) &&
1262         match(Src0, PatternMatch::m_ZExtOrSExt(
1263                         m_Cmp(SrcPred, PatternMatch::m_Value(SrcLHS),
1264                               PatternMatch::m_Value(SrcRHS))))) {
1265       if (CCVal == CmpInst::ICMP_EQ)
1266         SrcPred = CmpInst::getInversePredicate(SrcPred);
1267 
1268       Intrinsic::ID NewIID = CmpInst::isFPPredicate(SrcPred)
1269                                  ? Intrinsic::amdgcn_fcmp
1270                                  : Intrinsic::amdgcn_icmp;
1271 
1272       Type *Ty = SrcLHS->getType();
1273       if (auto *CmpType = dyn_cast<IntegerType>(Ty)) {
1274         // Promote to next legal integer type.
1275         unsigned Width = CmpType->getBitWidth();
1276         unsigned NewWidth = Width;
1277 
1278         // Don't do anything for i1 comparisons.
1279         if (Width == 1)
1280           break;
1281 
1282         if (Width <= 16)
1283           NewWidth = 16;
1284         else if (Width <= 32)
1285           NewWidth = 32;
1286         else if (Width <= 64)
1287           NewWidth = 64;
1288         else
1289           break; // Can't handle this.
1290 
1291         if (Width != NewWidth) {
1292           IntegerType *CmpTy = IC.Builder.getIntNTy(NewWidth);
1293           if (CmpInst::isSigned(SrcPred)) {
1294             SrcLHS = IC.Builder.CreateSExt(SrcLHS, CmpTy);
1295             SrcRHS = IC.Builder.CreateSExt(SrcRHS, CmpTy);
1296           } else {
1297             SrcLHS = IC.Builder.CreateZExt(SrcLHS, CmpTy);
1298             SrcRHS = IC.Builder.CreateZExt(SrcRHS, CmpTy);
1299           }
1300         }
1301       } else if (!Ty->isFloatTy() && !Ty->isDoubleTy() && !Ty->isHalfTy())
1302         break;
1303 
1304       Value *Args[] = {SrcLHS, SrcRHS,
1305                        ConstantInt::get(CC->getType(), SrcPred)};
1306       CallInst *NewCall = IC.Builder.CreateIntrinsic(
1307           NewIID, {II.getType(), SrcLHS->getType()}, Args);
1308       NewCall->takeName(&II);
1309       return IC.replaceInstUsesWith(II, NewCall);
1310     }
1311 
1312     break;
1313   }
1314   case Intrinsic::amdgcn_mbcnt_hi: {
1315     // exec_hi is all 0, so this is just a copy.
1316     if (ST->isWave32())
1317       return IC.replaceInstUsesWith(II, II.getArgOperand(1));
1318     break;
1319   }
1320   case Intrinsic::amdgcn_ballot: {
1321     Value *Arg = II.getArgOperand(0);
1322     if (isa<PoisonValue>(Arg))
1323       return IC.replaceInstUsesWith(II, PoisonValue::get(II.getType()));
1324 
1325     if (auto *Src = dyn_cast<ConstantInt>(Arg)) {
1326       if (Src->isZero()) {
1327         // amdgcn.ballot(i1 0) is zero.
1328         return IC.replaceInstUsesWith(II, Constant::getNullValue(II.getType()));
1329       }
1330     }
1331     if (ST->isWave32() && II.getType()->getIntegerBitWidth() == 64) {
1332       // %b64 = call i64 ballot.i64(...)
1333       // =>
1334       // %b32 = call i32 ballot.i32(...)
1335       // %b64 = zext i32 %b32 to i64
1336       Value *Call = IC.Builder.CreateZExt(
1337           IC.Builder.CreateIntrinsic(Intrinsic::amdgcn_ballot,
1338                                      {IC.Builder.getInt32Ty()},
1339                                      {II.getArgOperand(0)}),
1340           II.getType());
1341       Call->takeName(&II);
1342       return IC.replaceInstUsesWith(II, Call);
1343     }
1344     break;
1345   }
1346   case Intrinsic::amdgcn_wavefrontsize: {
1347     if (ST->isWaveSizeKnown())
1348       return IC.replaceInstUsesWith(
1349           II, ConstantInt::get(II.getType(), ST->getWavefrontSize()));
1350     break;
1351   }
1352   case Intrinsic::amdgcn_wqm_vote: {
1353     // wqm_vote is identity when the argument is constant.
1354     if (!isa<Constant>(II.getArgOperand(0)))
1355       break;
1356 
1357     return IC.replaceInstUsesWith(II, II.getArgOperand(0));
1358   }
1359   case Intrinsic::amdgcn_kill: {
1360     const ConstantInt *C = dyn_cast<ConstantInt>(II.getArgOperand(0));
1361     if (!C || !C->getZExtValue())
1362       break;
1363 
1364     // amdgcn.kill(i1 1) is a no-op
1365     return IC.eraseInstFromFunction(II);
1366   }
1367   case Intrinsic::amdgcn_update_dpp: {
1368     Value *Old = II.getArgOperand(0);
1369 
1370     auto *BC = cast<ConstantInt>(II.getArgOperand(5));
1371     auto *RM = cast<ConstantInt>(II.getArgOperand(3));
1372     auto *BM = cast<ConstantInt>(II.getArgOperand(4));
1373     if (BC->isZeroValue() || RM->getZExtValue() != 0xF ||
1374         BM->getZExtValue() != 0xF || isa<PoisonValue>(Old))
1375       break;
1376 
1377     // If bound_ctrl = 1, row mask = bank mask = 0xf we can omit old value.
1378     return IC.replaceOperand(II, 0, PoisonValue::get(Old->getType()));
1379   }
1380   case Intrinsic::amdgcn_permlane16:
1381   case Intrinsic::amdgcn_permlane16_var:
1382   case Intrinsic::amdgcn_permlanex16:
1383   case Intrinsic::amdgcn_permlanex16_var: {
1384     // Discard vdst_in if it's not going to be read.
1385     Value *VDstIn = II.getArgOperand(0);
1386     if (isa<PoisonValue>(VDstIn))
1387       break;
1388 
1389     // FetchInvalid operand idx.
1390     unsigned int FiIdx = (IID == Intrinsic::amdgcn_permlane16 ||
1391                           IID == Intrinsic::amdgcn_permlanex16)
1392                              ? 4  /* for permlane16 and permlanex16 */
1393                              : 3; /* for permlane16_var and permlanex16_var */
1394 
1395     // BoundCtrl operand idx.
1396     // For permlane16 and permlanex16 it should be 5
1397     // For Permlane16_var and permlanex16_var it should be 4
1398     unsigned int BcIdx = FiIdx + 1;
1399 
1400     ConstantInt *FetchInvalid = cast<ConstantInt>(II.getArgOperand(FiIdx));
1401     ConstantInt *BoundCtrl = cast<ConstantInt>(II.getArgOperand(BcIdx));
1402     if (!FetchInvalid->getZExtValue() && !BoundCtrl->getZExtValue())
1403       break;
1404 
1405     return IC.replaceOperand(II, 0, PoisonValue::get(VDstIn->getType()));
1406   }
1407   case Intrinsic::amdgcn_permlane64:
1408   case Intrinsic::amdgcn_readfirstlane:
1409   case Intrinsic::amdgcn_readlane:
1410   case Intrinsic::amdgcn_ds_bpermute: {
1411     // If the data argument is uniform, these intrinsics return it unchanged.
1412     unsigned SrcIdx = IID == Intrinsic::amdgcn_ds_bpermute ? 1 : 0;
1413     const Use &Src = II.getArgOperandUse(SrcIdx);
1414     if (isTriviallyUniform(Src))
1415       return IC.replaceInstUsesWith(II, Src.get());
1416 
1417     if (IID == Intrinsic::amdgcn_readlane &&
1418         simplifyDemandedLaneMaskArg(IC, II, 1))
1419       return &II;
1420 
1421     // If the lane argument of bpermute is uniform, change it to readlane. This
1422     // generates better code and can enable further optimizations because
1423     // readlane is AlwaysUniform.
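      // e.g. (illustrative):
      //   %v = call i32 @llvm.amdgcn.ds.bpermute(i32 %uniform_addr, i32 %data)
      // becomes a readlane of %data from lane (%uniform_addr >> 2); the shift
      // is needed because ds_bpermute addresses are in bytes.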
1424     if (IID == Intrinsic::amdgcn_ds_bpermute) {
1425       const Use &Lane = II.getArgOperandUse(0);
1426       if (isTriviallyUniform(Lane)) {
1427         Value *NewLane = IC.Builder.CreateLShr(Lane, 2);
1428         Function *NewDecl = Intrinsic::getOrInsertDeclaration(
1429             II.getModule(), Intrinsic::amdgcn_readlane, II.getType());
1430         II.setCalledFunction(NewDecl);
1431         II.setOperand(0, Src);
1432         II.setOperand(1, NewLane);
1433         return &II;
1434       }
1435     }
1436 
1437     if (IID != Intrinsic::amdgcn_ds_bpermute) {
1438       if (Instruction *Res = hoistLaneIntrinsicThroughOperand(IC, II))
1439         return Res;
1440     }
1441 
1442     return std::nullopt;
1443   }
1444   case Intrinsic::amdgcn_writelane: {
1445     // TODO: Fold bitcast like readlane.
1446     if (simplifyDemandedLaneMaskArg(IC, II, 1))
1447       return &II;
1448     return std::nullopt;
1449   }
1450   case Intrinsic::amdgcn_trig_preop: {
1451     // The intrinsic is declared with name mangling, but currently the
1452     // instruction only exists for f64.
1453     if (!II.getType()->isDoubleTy())
1454       break;
1455 
1456     Value *Src = II.getArgOperand(0);
1457     Value *Segment = II.getArgOperand(1);
1458     if (isa<PoisonValue>(Src) || isa<PoisonValue>(Segment))
1459       return IC.replaceInstUsesWith(II, PoisonValue::get(II.getType()));
1460 
1461     if (isa<UndefValue>(Src)) {
1462       auto *QNaN = ConstantFP::get(
1463           II.getType(), APFloat::getQNaN(II.getType()->getFltSemantics()));
1464       return IC.replaceInstUsesWith(II, QNaN);
1465     }
1466 
1467     const ConstantFP *Csrc = dyn_cast<ConstantFP>(Src);
1468     if (!Csrc)
1469       break;
1470 
1471     if (II.isStrictFP())
1472       break;
1473 
1474     const APFloat &Fsrc = Csrc->getValueAPF();
1475     if (Fsrc.isNaN()) {
1476       auto *Quieted = ConstantFP::get(II.getType(), Fsrc.makeQuiet());
1477       return IC.replaceInstUsesWith(II, Quieted);
1478     }
1479 
1480     const ConstantInt *Cseg = dyn_cast<ConstantInt>(Segment);
1481     if (!Cseg)
1482       break;
1483 
1484     unsigned Exponent = (Fsrc.bitcastToAPInt().getZExtValue() >> 52) & 0x7ff;
1485     unsigned SegmentVal = Cseg->getValue().trunc(5).getZExtValue();
1486     unsigned Shift = SegmentVal * 53;
1487     if (Exponent > 1077)
1488       Shift += Exponent - 1077;
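    // e.g. (illustrative) for Src = 0x1.0p+1100 the biased exponent is
    // 1100 + 1023 = 2123 > 1077, so Shift gains an extra 2123 - 1077 = 1046
    // on top of SegmentVal * 53.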
1489 
1490     // Table of the bits of 2.0/PI, from which a 53-bit segment is extracted.
1491     static const uint32_t TwoByPi[] = {
1492         0xa2f9836e, 0x4e441529, 0xfc2757d1, 0xf534ddc0, 0xdb629599, 0x3c439041,
1493         0xfe5163ab, 0xdebbc561, 0xb7246e3a, 0x424dd2e0, 0x06492eea, 0x09d1921c,
1494         0xfe1deb1c, 0xb129a73e, 0xe88235f5, 0x2ebb4484, 0xe99c7026, 0xb45f7e41,
1495         0x3991d639, 0x835339f4, 0x9c845f8b, 0xbdf9283b, 0x1ff897ff, 0xde05980f,
1496         0xef2f118b, 0x5a0a6d1f, 0x6d367ecf, 0x27cb09b7, 0x4f463f66, 0x9e5fea2d,
1497         0x7527bac7, 0xebe5f17b, 0x3d0739f7, 0x8a5292ea, 0x6bfb5fb1, 0x1f8d5d08,
1498         0x56033046};
1499 
1500     // Return 0 for an out-of-bounds segment, matching the hardware behavior.
1501     unsigned Idx = Shift >> 5;
1502     if (Idx + 2 >= std::size(TwoByPi)) {
1503       APFloat Zero = APFloat::getZero(II.getType()->getFltSemantics());
1504       return IC.replaceInstUsesWith(II, ConstantFP::get(II.getType(), Zero));
1505     }
1506 
1507     unsigned BShift = Shift & 0x1f;
1508     uint64_t Thi = Make_64(TwoByPi[Idx], TwoByPi[Idx + 1]);
1509     uint64_t Tlo = Make_64(TwoByPi[Idx + 2], 0);
1510     if (BShift)
1511       Thi = (Thi << BShift) | (Tlo >> (64 - BShift));
1512     Thi = Thi >> 11;
1513     APFloat Result = APFloat((double)Thi);
1514 
1515     int Scale = -53 - Shift;
1516     if (Exponent >= 1968)
1517       Scale += 128;
1518 
1519     Result = scalbn(Result, Scale, RoundingMode::NearestTiesToEven);
1520     return IC.replaceInstUsesWith(II, ConstantFP::get(Src->getType(), Result));
1521   }
1522   case Intrinsic::amdgcn_fmul_legacy: {
1523     Value *Op0 = II.getArgOperand(0);
1524     Value *Op1 = II.getArgOperand(1);
1525 
1526     for (Value *Src : {Op0, Op1}) {
1527       if (isa<PoisonValue>(Src))
1528         return IC.replaceInstUsesWith(II, Src);
1529     }
1530 
1531     // The legacy behaviour is that multiplying +/-0.0 by anything, even NaN or
1532     // infinity, gives +0.0.
1533     // TODO: Move to InstSimplify?
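    // e.g. (illustrative):
    //   %r = call float @llvm.amdgcn.fmul.legacy(float 0.0, float %nan)
    // folds to +0.0 regardless of the value of %nan.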
1534     if (match(Op0, PatternMatch::m_AnyZeroFP()) ||
1535         match(Op1, PatternMatch::m_AnyZeroFP()))
1536       return IC.replaceInstUsesWith(II, ConstantFP::getZero(II.getType()));
1537 
1538     // If we can prove we don't have one of the special cases then we can use a
1539     // normal fmul instruction instead.
1540     if (canSimplifyLegacyMulToMul(II, Op0, Op1, IC)) {
1541       auto *FMul = IC.Builder.CreateFMulFMF(Op0, Op1, &II);
1542       FMul->takeName(&II);
1543       return IC.replaceInstUsesWith(II, FMul);
1544     }
1545     break;
1546   }
1547   case Intrinsic::amdgcn_fma_legacy: {
1548     Value *Op0 = II.getArgOperand(0);
1549     Value *Op1 = II.getArgOperand(1);
1550     Value *Op2 = II.getArgOperand(2);
1551 
1552     for (Value *Src : {Op0, Op1, Op2}) {
1553       if (isa<PoisonValue>(Src))
1554         return IC.replaceInstUsesWith(II, Src);
1555     }
1556 
1557     // The legacy behaviour is that multiplying +/-0.0 by anything, even NaN or
1558     // infinity, gives +0.0.
1559     // TODO: Move to InstSimplify?
1560     if (match(Op0, PatternMatch::m_AnyZeroFP()) ||
1561         match(Op1, PatternMatch::m_AnyZeroFP())) {
1562       // It's tempting to just return Op2 here, but that would give the wrong
1563       // result if Op2 was -0.0.
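      // e.g. (illustrative) fma.legacy(0.0, %x, -0.0) must fold to
      // +0.0 + -0.0 == +0.0, whereas returning Op2 would yield -0.0.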
1564       auto *Zero = ConstantFP::getZero(II.getType());
1565       auto *FAdd = IC.Builder.CreateFAddFMF(Zero, Op2, &II);
1566       FAdd->takeName(&II);
1567       return IC.replaceInstUsesWith(II, FAdd);
1568     }
1569 
1570     // If we can prove we don't have one of the special cases then we can use a
1571     // normal fma instead.
1572     if (canSimplifyLegacyMulToMul(II, Op0, Op1, IC)) {
1573       II.setCalledOperand(Intrinsic::getOrInsertDeclaration(
1574           II.getModule(), Intrinsic::fma, II.getType()));
1575       return &II;
1576     }
1577     break;
1578   }
1579   case Intrinsic::amdgcn_is_shared:
1580   case Intrinsic::amdgcn_is_private: {
1581     Value *Src = II.getArgOperand(0);
1582     if (isa<PoisonValue>(Src))
1583       return IC.replaceInstUsesWith(II, PoisonValue::get(II.getType()));
1584     if (isa<UndefValue>(Src))
1585       return IC.replaceInstUsesWith(II, UndefValue::get(II.getType()));
1586 
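    // A constant null pointer lies in neither the shared nor the private
    // aperture, so both queries fold to false.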
1587     if (isa<ConstantPointerNull>(II.getArgOperand(0)))
1588       return IC.replaceInstUsesWith(II, ConstantInt::getFalse(II.getType()));
1589     break;
1590   }
1591   case Intrinsic::amdgcn_make_buffer_rsrc: {
1592     Value *Src = II.getArgOperand(0);
1593     if (isa<PoisonValue>(Src))
1594       return IC.replaceInstUsesWith(II, PoisonValue::get(II.getType()));
1595     return std::nullopt;
1596   }
1597   case Intrinsic::amdgcn_raw_buffer_store_format:
1598   case Intrinsic::amdgcn_struct_buffer_store_format:
1599   case Intrinsic::amdgcn_raw_tbuffer_store:
1600   case Intrinsic::amdgcn_struct_tbuffer_store:
1601   case Intrinsic::amdgcn_image_store_1d:
1602   case Intrinsic::amdgcn_image_store_1darray:
1603   case Intrinsic::amdgcn_image_store_2d:
1604   case Intrinsic::amdgcn_image_store_2darray:
1605   case Intrinsic::amdgcn_image_store_2darraymsaa:
1606   case Intrinsic::amdgcn_image_store_2dmsaa:
1607   case Intrinsic::amdgcn_image_store_3d:
1608   case Intrinsic::amdgcn_image_store_cube:
1609   case Intrinsic::amdgcn_image_store_mip_1d:
1610   case Intrinsic::amdgcn_image_store_mip_1darray:
1611   case Intrinsic::amdgcn_image_store_mip_2d:
1612   case Intrinsic::amdgcn_image_store_mip_2darray:
1613   case Intrinsic::amdgcn_image_store_mip_3d:
1614   case Intrinsic::amdgcn_image_store_mip_cube: {
1615     if (!isa<FixedVectorType>(II.getArgOperand(0)->getType()))
1616       break;
1617 
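    // Determine which stored components are actually needed, based on the
    // subtarget's behavior for unwritten components: roughly, trailing
    // components equal to the first can be omitted under default component
    // broadcast, and trailing zero components under default component zero.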
1618     APInt DemandedElts;
1619     if (ST->hasDefaultComponentBroadcast())
1620       DemandedElts = defaultComponentBroadcast(II.getArgOperand(0));
1621     else if (ST->hasDefaultComponentZero())
1622       DemandedElts = trimTrailingZerosInVector(IC, II.getArgOperand(0), &II);
1623     else
1624       break;
1625 
1626     int DMaskIdx = getAMDGPUImageDMaskIntrinsic(II.getIntrinsicID()) ? 1 : -1;
1627     if (simplifyAMDGCNMemoryIntrinsicDemanded(IC, II, DemandedElts, DMaskIdx,
1628                                               false)) {
1629       return IC.eraseInstFromFunction(II);
1630     }
1631 
1632     break;
1633   }
1634   case Intrinsic::amdgcn_prng_b32: {
1635     auto *Src = II.getArgOperand(0);
1636     if (isa<UndefValue>(Src)) {
1637       return IC.replaceInstUsesWith(II, Src);
1638     }
1639     return std::nullopt;
1640   }
1641   case Intrinsic::amdgcn_mfma_scale_f32_16x16x128_f8f6f4:
1642   case Intrinsic::amdgcn_mfma_scale_f32_32x32x64_f8f6f4: {
1643     Value *Src0 = II.getArgOperand(0);
1644     Value *Src1 = II.getArgOperand(1);
1645     uint64_t CBSZ = cast<ConstantInt>(II.getArgOperand(3))->getZExtValue();
1646     uint64_t BLGP = cast<ConstantInt>(II.getArgOperand(4))->getZExtValue();
1647     auto *Src0Ty = cast<FixedVectorType>(Src0->getType());
1648     auto *Src1Ty = cast<FixedVectorType>(Src1->getType());
1649 
1650     auto getFormatNumRegs = [](unsigned FormatVal) {
1651       switch (FormatVal) {
1652       case AMDGPU::MFMAScaleFormats::FP6_E2M3:
1653       case AMDGPU::MFMAScaleFormats::FP6_E3M2:
1654         return 6u;
1655       case AMDGPU::MFMAScaleFormats::FP4_E2M1:
1656         return 4u;
1657       case AMDGPU::MFMAScaleFormats::FP8_E4M3:
1658       case AMDGPU::MFMAScaleFormats::FP8_E5M2:
1659         return 8u;
1660       default:
1661         llvm_unreachable("invalid format value");
1662       }
1663     };
1664 
1665     bool MadeChange = false;
1666     unsigned Src0NumElts = getFormatNumRegs(CBSZ);
1667     unsigned Src1NumElts = getFormatNumRegs(BLGP);
1668 
1669     // Depending on the format in use, fewer registers may be required, so
1670     // shrink the vector type accordingly.
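    // e.g. (illustrative) if CBSZ selects FP4_E2M1, src0 needs only 4
    // registers, so a <8 x i32> operand is shrunk to <4 x i32>.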
1671     if (Src0Ty->getNumElements() > Src0NumElts) {
1672       Src0 = IC.Builder.CreateExtractVector(
1673           FixedVectorType::get(Src0Ty->getElementType(), Src0NumElts), Src0,
1674           uint64_t(0));
1675       MadeChange = true;
1676     }
1677 
1678     if (Src1Ty->getNumElements() > Src1NumElts) {
1679       Src1 = IC.Builder.CreateExtractVector(
1680           FixedVectorType::get(Src1Ty->getElementType(), Src1NumElts), Src1,
1681           uint64_t(0));
1682       MadeChange = true;
1683     }
1684 
1685     if (!MadeChange)
1686       return std::nullopt;
1687 
1688     SmallVector<Value *, 10> Args(II.args());
1689     Args[0] = Src0;
1690     Args[1] = Src1;
1691 
1692     CallInst *NewII = IC.Builder.CreateIntrinsic(
1693         IID, {Src0->getType(), Src1->getType()}, Args, &II);
1694     NewII->takeName(&II);
1695     return IC.replaceInstUsesWith(II, NewII);
1696   }
1697   }
1698   if (const AMDGPU::ImageDimIntrinsicInfo *ImageDimIntr =
1699             AMDGPU::getImageDimIntrinsicInfo(II.getIntrinsicID())) {
1700     return simplifyAMDGCNImageIntrinsic(ST, ImageDimIntr, II, IC);
1701   }
1702   return std::nullopt;
1703 }
1704 
1705 /// Implement SimplifyDemandedVectorElts for amdgcn buffer and image intrinsics.
1706 ///
1707 /// For amdgcn image and buffer store intrinsics, simplification updates the
1708 /// definition of the intrinsic's vector data argument; for image and buffer
1709 /// loads it instead updates the uses of the result.
1710 /// Note: This only supports non-TFE/LWE image intrinsic calls; calls with
1711 /// TFE/LWE enabled have struct returns.
1712 static Value *simplifyAMDGCNMemoryIntrinsicDemanded(InstCombiner &IC,
1713                                                     IntrinsicInst &II,
1714                                                     APInt DemandedElts,
1715                                                     int DMaskIdx, bool IsLoad) {
1716 
1717   auto *IIVTy = cast<FixedVectorType>(IsLoad ? II.getType()
1718                                              : II.getOperand(0)->getType());
1719   unsigned VWidth = IIVTy->getNumElements();
1720   if (VWidth == 1)
1721     return nullptr;
1722   Type *EltTy = IIVTy->getElementType();
1723 
1724   IRBuilderBase::InsertPointGuard Guard(IC.Builder);
1725   IC.Builder.SetInsertPoint(&II);
1726 
1727   // Assume the arguments are unchanged and later override them, if needed.
1728   SmallVector<Value *, 16> Args(II.args());
1729 
1730   if (DMaskIdx < 0) {
1731     // Buffer case.
1732 
1733     const unsigned ActiveBits = DemandedElts.getActiveBits();
1734     const unsigned UnusedComponentsAtFront = DemandedElts.countr_zero();
1735 
1736     // Start by assuming the whole prefix of elements is demanded, then clear
1737     // the low bits and bump the offset if there are unused components at the
1738     // front (trailing zeros in the demanded mask).
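    // e.g. (illustrative) a <4 x float> raw buffer load whose first two
    // components are unused becomes a <2 x float> load with the byte offset
    // increased by 2 * sizeof(float) = 8.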
1739     DemandedElts = (1 << ActiveBits) - 1;
1740 
1741     if (UnusedComponentsAtFront > 0) {
1742       static const unsigned InvalidOffsetIdx = 0xf;
1743 
1744       unsigned OffsetIdx;
1745       switch (II.getIntrinsicID()) {
1746       case Intrinsic::amdgcn_raw_buffer_load:
1747       case Intrinsic::amdgcn_raw_ptr_buffer_load:
1748         OffsetIdx = 1;
1749         break;
1750       case Intrinsic::amdgcn_s_buffer_load:
1751         // If the resulting type is vec3, there is no point in trimming the
1752         // load with updated offset, as the vec3 would most likely be widened to
1753         // vec4 anyway during lowering.
1754         if (ActiveBits == 4 && UnusedComponentsAtFront == 1)
1755           OffsetIdx = InvalidOffsetIdx;
1756         else
1757           OffsetIdx = 1;
1758         break;
1759       case Intrinsic::amdgcn_struct_buffer_load:
1760       case Intrinsic::amdgcn_struct_ptr_buffer_load:
1761         OffsetIdx = 2;
1762         break;
1763       default:
1764         // TODO: handle tbuffer* intrinsics.
1765         OffsetIdx = InvalidOffsetIdx;
1766         break;
1767       }
1768 
1769       if (OffsetIdx != InvalidOffsetIdx) {
1770         // Clear demanded bits and update the offset.
1771         DemandedElts &= ~((1 << UnusedComponentsAtFront) - 1);
1772         auto *Offset = Args[OffsetIdx];
1773         unsigned SingleComponentSizeInBits =
1774             IC.getDataLayout().getTypeSizeInBits(EltTy);
1775         unsigned OffsetAdd =
1776             UnusedComponentsAtFront * SingleComponentSizeInBits / 8;
1777         auto *OffsetAddVal = ConstantInt::get(Offset->getType(), OffsetAdd);
1778         Args[OffsetIdx] = IC.Builder.CreateAdd(Offset, OffsetAddVal);
1779       }
1780     }
1781   } else {
1782     // Image case.
1783 
1784     ConstantInt *DMask = cast<ConstantInt>(Args[DMaskIdx]);
1785     unsigned DMaskVal = DMask->getZExtValue() & 0xf;
1786 
1787     // dmask 0 has special semantics, do not simplify.
1788     if (DMaskVal == 0)
1789       return nullptr;
1790 
1791     // Mask off values that are undefined because the dmask doesn't cover them.
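    // e.g. (illustrative) dmask = 0b0101 covers two components, so only the
    // low two bits of DemandedElts can be meaningful.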
1792     DemandedElts &= (1 << llvm::popcount(DMaskVal)) - 1;
1793 
1794     unsigned NewDMaskVal = 0;
1795     unsigned OrigLdStIdx = 0;
1796     for (unsigned SrcIdx = 0; SrcIdx < 4; ++SrcIdx) {
1797       const unsigned Bit = 1 << SrcIdx;
1798       if (!!(DMaskVal & Bit)) {
1799         if (!!DemandedElts[OrigLdStIdx])
1800           NewDMaskVal |= Bit;
1801         OrigLdStIdx++;
1802       }
1803     }
1804 
1805     if (DMaskVal != NewDMaskVal)
1806       Args[DMaskIdx] = ConstantInt::get(DMask->getType(), NewDMaskVal);
1807   }
1808 
1809   unsigned NewNumElts = DemandedElts.popcount();
1810   if (!NewNumElts)
1811     return PoisonValue::get(IIVTy);
1812 
1813   if (NewNumElts >= VWidth && DemandedElts.isMask()) {
1814     if (DMaskIdx >= 0)
1815       II.setArgOperand(DMaskIdx, Args[DMaskIdx]);
1816     return nullptr;
1817   }
1818 
1819   // Validate function argument and return types, extracting overloaded types
1820   // along the way.
1821   SmallVector<Type *, 6> OverloadTys;
1822   if (!Intrinsic::getIntrinsicSignature(II.getCalledFunction(), OverloadTys))
1823     return nullptr;
1824 
1825   Type *NewTy =
1826       (NewNumElts == 1) ? EltTy : FixedVectorType::get(EltTy, NewNumElts);
1827   OverloadTys[0] = NewTy;
1828 
1829   if (!IsLoad) {
1830     SmallVector<int, 8> EltMask;
1831     for (unsigned OrigStoreIdx = 0; OrigStoreIdx < VWidth; ++OrigStoreIdx)
1832       if (DemandedElts[OrigStoreIdx])
1833         EltMask.push_back(OrigStoreIdx);
1834 
1835     if (NewNumElts == 1)
1836       Args[0] = IC.Builder.CreateExtractElement(II.getOperand(0), EltMask[0]);
1837     else
1838       Args[0] = IC.Builder.CreateShuffleVector(II.getOperand(0), EltMask);
1839   }
1840 
1841   CallInst *NewCall =
1842       IC.Builder.CreateIntrinsic(II.getIntrinsicID(), OverloadTys, Args);
1843   NewCall->takeName(&II);
1844   NewCall->copyMetadata(II);
1845 
1846   if (IsLoad) {
1847     if (NewNumElts == 1) {
1848       return IC.Builder.CreateInsertElement(PoisonValue::get(IIVTy), NewCall,
1849                                             DemandedElts.countr_zero());
1850     }
1851 
1852     SmallVector<int, 8> EltMask;
1853     unsigned NewLoadIdx = 0;
1854     for (unsigned OrigLoadIdx = 0; OrigLoadIdx < VWidth; ++OrigLoadIdx) {
1855       if (!!DemandedElts[OrigLoadIdx])
1856         EltMask.push_back(NewLoadIdx++);
1857       else
1858         EltMask.push_back(NewNumElts);
1859     }
1860 
1861     auto *Shuffle = IC.Builder.CreateShuffleVector(NewCall, EltMask);
1862 
1863     return Shuffle;
1864   }
1865 
1866   return NewCall;
1867 }
1868 
1869 Value *GCNTTIImpl::simplifyAMDGCNLaneIntrinsicDemanded(
1870     InstCombiner &IC, IntrinsicInst &II, const APInt &DemandedElts,
1871     APInt &UndefElts) const {
1872   auto *VT = dyn_cast<FixedVectorType>(II.getType());
1873   if (!VT)
1874     return nullptr;
1875 
1876   const unsigned FirstElt = DemandedElts.countr_zero();
1877   const unsigned LastElt = DemandedElts.getActiveBits() - 1;
1878   const unsigned MaskLen = LastElt - FirstElt + 1;
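  // e.g. (illustrative) DemandedElts = 0b0110 on a <4 x i32> readfirstlane
  // gives FirstElt = 1, LastElt = 2, MaskLen = 2, so the operation is
  // rewritten on <2 x i32>.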
1879 
1880   unsigned OldNumElts = VT->getNumElements();
1881   if (MaskLen == OldNumElts && MaskLen != 1)
1882     return nullptr;
1883 
1884   Type *EltTy = VT->getElementType();
1885   Type *NewVT = MaskLen == 1 ? EltTy : FixedVectorType::get(EltTy, MaskLen);
1886 
1887   // Theoretically we should support these intrinsics for any legal type. Avoid
1888   // introducing cases that aren't direct register types like v3i16.
1889   if (!isTypeLegal(NewVT))
1890     return nullptr;
1891 
1892   Value *Src = II.getArgOperand(0);
1893 
1894   // Make sure convergence tokens are preserved.
1895   // TODO: CreateIntrinsic should allow directly copying bundles
1896   SmallVector<OperandBundleDef, 2> OpBundles;
1897   II.getOperandBundlesAsDefs(OpBundles);
1898 
1899   Module *M = IC.Builder.GetInsertBlock()->getModule();
1900   Function *Remangled =
1901       Intrinsic::getOrInsertDeclaration(M, II.getIntrinsicID(), {NewVT});
1902 
1903   if (MaskLen == 1) {
1904     Value *Extract = IC.Builder.CreateExtractElement(Src, FirstElt);
1905 
1906     // TODO: Preserve callsite attributes?
1907     CallInst *NewCall = IC.Builder.CreateCall(Remangled, {Extract}, OpBundles);
1908 
1909     return IC.Builder.CreateInsertElement(PoisonValue::get(II.getType()),
1910                                           NewCall, FirstElt);
1911   }
1912 
1913   SmallVector<int> ExtractMask(MaskLen, -1);
1914   for (unsigned I = 0; I != MaskLen; ++I) {
1915     if (DemandedElts[FirstElt + I])
1916       ExtractMask[I] = FirstElt + I;
1917   }
1918 
1919   Value *Extract = IC.Builder.CreateShuffleVector(Src, ExtractMask);
1920 
1921   // TODO: Preserve callsite attributes?
1922   CallInst *NewCall = IC.Builder.CreateCall(Remangled, {Extract}, OpBundles);
1923 
1924   SmallVector<int> InsertMask(OldNumElts, -1);
1925   for (unsigned I = 0; I != MaskLen; ++I) {
1926     if (DemandedElts[FirstElt + I])
1927       InsertMask[FirstElt + I] = I;
1928   }
1929 
1930   // FIXME: If the call has a convergence bundle, we end up leaving the dead
1931   // call behind.
1932   return IC.Builder.CreateShuffleVector(NewCall, InsertMask);
1933 }
1934 
1935 std::optional<Value *> GCNTTIImpl::simplifyDemandedVectorEltsIntrinsic(
1936     InstCombiner &IC, IntrinsicInst &II, APInt DemandedElts, APInt &UndefElts,
1937     APInt &UndefElts2, APInt &UndefElts3,
1938     std::function<void(Instruction *, unsigned, APInt, APInt &)>
1939         SimplifyAndSetOp) const {
1940   switch (II.getIntrinsicID()) {
1941   case Intrinsic::amdgcn_readfirstlane:
1942     SimplifyAndSetOp(&II, 0, DemandedElts, UndefElts);
1943     return simplifyAMDGCNLaneIntrinsicDemanded(IC, II, DemandedElts, UndefElts);
1944   case Intrinsic::amdgcn_raw_buffer_load:
1945   case Intrinsic::amdgcn_raw_ptr_buffer_load:
1946   case Intrinsic::amdgcn_raw_buffer_load_format:
1947   case Intrinsic::amdgcn_raw_ptr_buffer_load_format:
1948   case Intrinsic::amdgcn_raw_tbuffer_load:
1949   case Intrinsic::amdgcn_raw_ptr_tbuffer_load:
1950   case Intrinsic::amdgcn_s_buffer_load:
1951   case Intrinsic::amdgcn_struct_buffer_load:
1952   case Intrinsic::amdgcn_struct_ptr_buffer_load:
1953   case Intrinsic::amdgcn_struct_buffer_load_format:
1954   case Intrinsic::amdgcn_struct_ptr_buffer_load_format:
1955   case Intrinsic::amdgcn_struct_tbuffer_load:
1956   case Intrinsic::amdgcn_struct_ptr_tbuffer_load:
1957     return simplifyAMDGCNMemoryIntrinsicDemanded(IC, II, DemandedElts);
1958   default: {
1959     if (getAMDGPUImageDMaskIntrinsic(II.getIntrinsicID())) {
1960       return simplifyAMDGCNMemoryIntrinsicDemanded(IC, II, DemandedElts, 0);
1961     }
1962     break;
1963   }
1964   }
1965   return std::nullopt;
1966 }
1967