1 //===-- X86ISelLowering.cpp - X86 DAG Lowering Implementation -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the interfaces that X86 uses to lower LLVM code into a
10 // selection DAG.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "X86ISelLowering.h"
15 #include "MCTargetDesc/X86ShuffleDecode.h"
16 #include "X86.h"
17 #include "X86CallingConv.h"
18 #include "X86FrameLowering.h"
19 #include "X86InstrBuilder.h"
20 #include "X86IntrinsicsInfo.h"
21 #include "X86MachineFunctionInfo.h"
22 #include "X86TargetMachine.h"
23 #include "X86TargetObjectFile.h"
24 #include "llvm/ADT/SmallBitVector.h"
25 #include "llvm/ADT/SmallSet.h"
26 #include "llvm/ADT/Statistic.h"
27 #include "llvm/ADT/StringExtras.h"
28 #include "llvm/ADT/StringSwitch.h"
29 #include "llvm/Analysis/BlockFrequencyInfo.h"
30 #include "llvm/Analysis/ObjCARCUtil.h"
31 #include "llvm/Analysis/ProfileSummaryInfo.h"
32 #include "llvm/Analysis/VectorUtils.h"
33 #include "llvm/CodeGen/IntrinsicLowering.h"
34 #include "llvm/CodeGen/MachineFrameInfo.h"
35 #include "llvm/CodeGen/MachineFunction.h"
36 #include "llvm/CodeGen/MachineInstrBuilder.h"
37 #include "llvm/CodeGen/MachineJumpTableInfo.h"
38 #include "llvm/CodeGen/MachineLoopInfo.h"
39 #include "llvm/CodeGen/MachineModuleInfo.h"
40 #include "llvm/CodeGen/MachineRegisterInfo.h"
41 #include "llvm/CodeGen/SDPatternMatch.h"
42 #include "llvm/CodeGen/TargetLowering.h"
43 #include "llvm/CodeGen/WinEHFuncInfo.h"
44 #include "llvm/IR/CallingConv.h"
45 #include "llvm/IR/Constants.h"
46 #include "llvm/IR/DerivedTypes.h"
47 #include "llvm/IR/EHPersonalities.h"
48 #include "llvm/IR/Function.h"
49 #include "llvm/IR/GlobalAlias.h"
50 #include "llvm/IR/GlobalVariable.h"
51 #include "llvm/IR/IRBuilder.h"
52 #include "llvm/IR/Instructions.h"
53 #include "llvm/IR/Intrinsics.h"
54 #include "llvm/IR/PatternMatch.h"
55 #include "llvm/MC/MCAsmInfo.h"
56 #include "llvm/MC/MCContext.h"
57 #include "llvm/MC/MCExpr.h"
58 #include "llvm/MC/MCSymbol.h"
59 #include "llvm/Support/CommandLine.h"
60 #include "llvm/Support/Debug.h"
61 #include "llvm/Support/ErrorHandling.h"
62 #include "llvm/Support/KnownBits.h"
63 #include "llvm/Support/MathExtras.h"
64 #include "llvm/Target/TargetOptions.h"
65 #include <algorithm>
66 #include <bitset>
67 #include <cctype>
68 #include <numeric>
69 using namespace llvm;
70 
71 #define DEBUG_TYPE "x86-isel"
72 
73 static cl::opt<int> ExperimentalPrefInnermostLoopAlignment(
74     "x86-experimental-pref-innermost-loop-alignment", cl::init(4),
75     cl::desc(
76         "Sets the preferable loop alignment for experiments (as log2 bytes) "
77         "for innermost loops only. If specified, this option overrides "
78         "alignment set by x86-experimental-pref-loop-alignment."),
79     cl::Hidden);
80 
81 static cl::opt<int> BrMergingBaseCostThresh(
82     "x86-br-merging-base-cost", cl::init(2),
83     cl::desc(
84         "Sets the cost threshold for when multiple conditionals will be merged "
85         "into one branch versus be split in multiple branches. Merging "
86         "conditionals saves branches at the cost of additional instructions. "
87         "This value sets the instruction cost limit, below which conditionals "
88         "will be merged, and above which conditionals will be split. Set to -1 "
89         "to never merge branches."),
90     cl::Hidden);
91 
92 static cl::opt<int> BrMergingCcmpBias(
93     "x86-br-merging-ccmp-bias", cl::init(6),
94     cl::desc("Increases 'x86-br-merging-base-cost' in cases that the target "
95              "supports conditional compare instructions."),
96     cl::Hidden);
97 
98 static cl::opt<int> BrMergingLikelyBias(
99     "x86-br-merging-likely-bias", cl::init(0),
100     cl::desc("Increases 'x86-br-merging-base-cost' in cases that it is likely "
101              "that all conditionals will be executed. For example for merging "
102              "the conditionals (a == b && c > d), if it's known that a == b is "
103              "likely, then it is likely that if the conditionals are split "
104              "both sides will be executed, so it may be desirable to increase "
105              "the instruction cost threshold. Set to -1 to never merge likely "
106              "branches."),
107     cl::Hidden);
108 
109 static cl::opt<int> BrMergingUnlikelyBias(
110     "x86-br-merging-unlikely-bias", cl::init(-1),
111     cl::desc(
112         "Decreases 'x86-br-merging-base-cost' in cases that it is unlikely "
113         "that all conditionals will be executed. For example for merging "
114         "the conditionals (a == b && c > d), if it's known that a == b is "
115         "unlikely, then it is unlikely that if the conditionals are split "
116         "both sides will be executed, so it may be desirable to decrease "
117         "the instruction cost threshold. Set to -1 to never merge unlikely "
118         "branches."),
119     cl::Hidden);
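// A minimal illustration of the trade-off the thresholds above control
// (illustration only, not taken from the surrounding code): a condition like
//
//   if (a == b && c > d) f();
//
// can be split into two conditional branches (the second compare only runs
// when the first passes) or merged into one branch on a combined predicate,
// roughly 'bool t = (a == b) & (c > d); if (t) f();'.  Merging always
// executes both compares but removes a branch; the base cost and bias values
// bound how many extra always-executed instructions that may cost.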
120 
121 static cl::opt<bool> MulConstantOptimization(
122     "mul-constant-optimization", cl::init(true),
123     cl::desc("Replace 'mul x, Const' with more effective instructions like "
124              "SHIFT, LEA, etc."),
125     cl::Hidden);
126 
127 X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
128                                      const X86Subtarget &STI)
129     : TargetLowering(TM), Subtarget(STI) {
130   bool UseX87 = !Subtarget.useSoftFloat() && Subtarget.hasX87();
131   MVT PtrVT = MVT::getIntegerVT(TM.getPointerSizeInBits(0));
132 
133   // Set up the TargetLowering object.
134 
135   // X86 is weird. It always uses i8 for shift amounts and setcc results.
136   setBooleanContents(ZeroOrOneBooleanContent);
137   // X86-SSE is even stranger. It uses -1 or 0 for vector masks.
138   setBooleanVectorContents(ZeroOrNegativeOneBooleanContent);
139 
140   // X86 instruction cache is coherent with its data cache so we can use the
141   // default expansion to a no-op.
142   setOperationAction(ISD::CLEAR_CACHE, MVT::Other, Expand);
143 
144   // For 64-bit, since we have so many registers, use the ILP scheduler.
145   // For 32-bit, use the register pressure specific scheduling.
146   // For Atom, always use ILP scheduling.
147   if (Subtarget.isAtom())
148     setSchedulingPreference(Sched::ILP);
149   else if (Subtarget.is64Bit())
150     setSchedulingPreference(Sched::ILP);
151   else
152     setSchedulingPreference(Sched::RegPressure);
153   const X86RegisterInfo *RegInfo = Subtarget.getRegisterInfo();
154   setStackPointerRegisterToSaveRestore(RegInfo->getStackRegister());
155 
156   // Bypass expensive divides and use cheaper ones.
157   if (TM.getOptLevel() >= CodeGenOptLevel::Default) {
158     if (Subtarget.hasSlowDivide32())
159       addBypassSlowDiv(32, 8);
160     if (Subtarget.hasSlowDivide64() && Subtarget.is64Bit())
161       addBypassSlowDiv(64, 32);
162   }
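  // A rough sketch of what the bypass above buys (illustrative only, names
  // are placeholders): a guarded 64-bit division is emitted approximately as
  //
  //   uint64_t div(uint64_t a, uint64_t b) {
  //     if (((a | b) >> 32) == 0)            // both operands fit in 32 bits
  //       return uint32_t(a) / uint32_t(b);  // cheap 32-bit divide
  //     return a / b;                        // slow full-width divide
  //   }
  //
  // paying a test+branch to avoid the expensive wide DIV on CPUs that report
  // hasSlowDivide64().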
163 
164   // Setup Windows compiler runtime calls.
165   if (Subtarget.isTargetWindowsMSVC() || Subtarget.isTargetWindowsItanium()) {
166     static const struct {
167       const RTLIB::Libcall Op;
168       const char * const Name;
169       const CallingConv::ID CC;
170     } LibraryCalls[] = {
171       { RTLIB::SDIV_I64, "_alldiv", CallingConv::X86_StdCall },
172       { RTLIB::UDIV_I64, "_aulldiv", CallingConv::X86_StdCall },
173       { RTLIB::SREM_I64, "_allrem", CallingConv::X86_StdCall },
174       { RTLIB::UREM_I64, "_aullrem", CallingConv::X86_StdCall },
175       { RTLIB::MUL_I64, "_allmul", CallingConv::X86_StdCall },
176     };
177 
178     for (const auto &LC : LibraryCalls) {
179       setLibcallName(LC.Op, LC.Name);
180       setLibcallCallingConv(LC.Op, LC.CC);
181     }
182   }
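  // For illustration: with the table above, a plain 64-bit division on a
  // 32-bit MSVC target,
  //
  //   int64_t quot(int64_t a, int64_t b) { return a / b; }
  //
  // becomes a call to _alldiv using the X86_StdCall convention, because i64
  // SDIV has no legal inline lowering in 32-bit mode.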
183 
184   if (Subtarget.canUseCMPXCHG16B())
185     setMaxAtomicSizeInBitsSupported(128);
186   else if (Subtarget.canUseCMPXCHG8B())
187     setMaxAtomicSizeInBitsSupported(64);
188   else
189     setMaxAtomicSizeInBitsSupported(32);
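  // Note (illustrative): atomics wider than the limit chosen above are not
  // lowered inline; AtomicExpand turns them into __atomic_* library calls.
  // For example, without CMPXCHG16B a 128-bit exchange along the lines of
  //
  //   std::atomic<__int128> x;   // assumes 16-byte alignment
  //   __int128 old = x.exchange(v);
  //
  // ends up as a call into the atomic runtime (e.g. __atomic_exchange_16)
  // rather than an inline CMPXCHG16B loop.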
190 
191   setMaxDivRemBitWidthSupported(Subtarget.is64Bit() ? 128 : 64);
192 
193   setMaxLargeFPConvertBitWidthSupported(128);
194 
195   // Set up the register classes.
196   addRegisterClass(MVT::i8, &X86::GR8RegClass);
197   addRegisterClass(MVT::i16, &X86::GR16RegClass);
198   addRegisterClass(MVT::i32, &X86::GR32RegClass);
199   if (Subtarget.is64Bit())
200     addRegisterClass(MVT::i64, &X86::GR64RegClass);
201 
202   for (MVT VT : MVT::integer_valuetypes())
203     setLoadExtAction(ISD::SEXTLOAD, VT, MVT::i1, Promote);
204 
205   // We don't accept any truncstore of integer registers.
206   setTruncStoreAction(MVT::i64, MVT::i32, Expand);
207   setTruncStoreAction(MVT::i64, MVT::i16, Expand);
208   setTruncStoreAction(MVT::i64, MVT::i8 , Expand);
209   setTruncStoreAction(MVT::i32, MVT::i16, Expand);
210   setTruncStoreAction(MVT::i32, MVT::i8 , Expand);
211   setTruncStoreAction(MVT::i16, MVT::i8,  Expand);
212 
213   setTruncStoreAction(MVT::f64, MVT::f32, Expand);
214 
215   // SETOEQ and SETUNE require checking two conditions.
216   for (auto VT : {MVT::f32, MVT::f64, MVT::f80}) {
217     setCondCodeAction(ISD::SETOEQ, VT, Expand);
218     setCondCodeAction(ISD::SETUNE, VT, Expand);
219   }
220 
221   // Integer absolute.
222   if (Subtarget.canUseCMOV()) {
223     setOperationAction(ISD::ABS            , MVT::i16  , Custom);
224     setOperationAction(ISD::ABS            , MVT::i32  , Custom);
225     if (Subtarget.is64Bit())
226       setOperationAction(ISD::ABS          , MVT::i64  , Custom);
227   }
228 
229   // Absolute difference.
230   for (auto Op : {ISD::ABDS, ISD::ABDU}) {
231     setOperationAction(Op                  , MVT::i8   , Custom);
232     setOperationAction(Op                  , MVT::i16  , Custom);
233     setOperationAction(Op                  , MVT::i32  , Custom);
234     if (Subtarget.is64Bit())
235      setOperationAction(Op                 , MVT::i64  , Custom);
236   }
237 
238   // Signed saturation subtraction.
239   setOperationAction(ISD::SSUBSAT          , MVT::i8   , Custom);
240   setOperationAction(ISD::SSUBSAT          , MVT::i16  , Custom);
241   setOperationAction(ISD::SSUBSAT          , MVT::i32  , Custom);
242   if (Subtarget.is64Bit())
243     setOperationAction(ISD::SSUBSAT        , MVT::i64  , Custom);
244 
245   // Funnel shifts.
246   for (auto ShiftOp : {ISD::FSHL, ISD::FSHR}) {
247     // For slow shld targets we only lower for code size.
248     LegalizeAction ShiftDoubleAction = Subtarget.isSHLDSlow() ? Custom : Legal;
249 
250     setOperationAction(ShiftOp             , MVT::i8   , Custom);
251     setOperationAction(ShiftOp             , MVT::i16  , Custom);
252     setOperationAction(ShiftOp             , MVT::i32  , ShiftDoubleAction);
253     if (Subtarget.is64Bit())
254       setOperationAction(ShiftOp           , MVT::i64  , ShiftDoubleAction);
255   }
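  // Illustrative sketch: a 32-bit funnel shift written as
  //
  //   // assumes 1 <= s <= 31
  //   uint32_t fshl(uint32_t hi, uint32_t lo, uint32_t s) {
  //     return (hi << s) | (lo >> (32 - s));
  //   }
  //
  // matches ISD::FSHL and, when marked Legal above, selects to a single SHLD;
  // on isSHLDSlow() targets the Custom hook instead expands it to shifts+OR
  // unless optimizing for size.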
256 
257   if (!Subtarget.useSoftFloat()) {
258     // Promote all UINT_TO_FP to larger SINT_TO_FP's, as X86 doesn't have this
259     // operation.
260     setOperationAction(ISD::UINT_TO_FP,        MVT::i8, Promote);
261     setOperationAction(ISD::STRICT_UINT_TO_FP, MVT::i8, Promote);
262     setOperationAction(ISD::UINT_TO_FP,        MVT::i16, Promote);
263     setOperationAction(ISD::STRICT_UINT_TO_FP, MVT::i16, Promote);
264     // We have an algorithm for SSE2, and we turn this into a 64-bit
265     // FILD or VCVTUSI2SS/SD for other targets.
266     setOperationAction(ISD::UINT_TO_FP,        MVT::i32, Custom);
267     setOperationAction(ISD::STRICT_UINT_TO_FP, MVT::i32, Custom);
268     // We have an algorithm for SSE2->double, and we turn this into a
269     // 64-bit FILD followed by conditional FADD for other targets.
270     setOperationAction(ISD::UINT_TO_FP,        MVT::i64, Custom);
271     setOperationAction(ISD::STRICT_UINT_TO_FP, MVT::i64, Custom);
272 
273     // Promote i8 SINT_TO_FP to larger SINT_TO_FP's, as X86 doesn't have
274     // this operation.
275     setOperationAction(ISD::SINT_TO_FP,        MVT::i8, Promote);
276     setOperationAction(ISD::STRICT_SINT_TO_FP, MVT::i8, Promote);
277     // SSE has no i16 to fp conversion, only i32. We promote in the handler
278     // to allow f80 to use i16 and f64 to use i16 with SSE1 only.
279     setOperationAction(ISD::SINT_TO_FP,        MVT::i16, Custom);
280     setOperationAction(ISD::STRICT_SINT_TO_FP, MVT::i16, Custom);
281     // f32 and f64 cases are Legal with SSE1/SSE2, f80 case is not
282     setOperationAction(ISD::SINT_TO_FP,        MVT::i32, Custom);
283     setOperationAction(ISD::STRICT_SINT_TO_FP, MVT::i32, Custom);
284     // In 32-bit mode these are custom lowered.  In 64-bit mode F32 and F64
285     // are Legal, f80 is custom lowered.
286     setOperationAction(ISD::SINT_TO_FP,        MVT::i64, Custom);
287     setOperationAction(ISD::STRICT_SINT_TO_FP, MVT::i64, Custom);
288 
289     // Promote i8 FP_TO_SINT to larger FP_TO_SINTS's, as X86 doesn't have
290     // this operation.
291     setOperationAction(ISD::FP_TO_SINT,        MVT::i8,  Promote);
292     // FIXME: This doesn't generate invalid exception when it should. PR44019.
293     setOperationAction(ISD::STRICT_FP_TO_SINT, MVT::i8,  Promote);
294     setOperationAction(ISD::FP_TO_SINT,        MVT::i16, Custom);
295     setOperationAction(ISD::STRICT_FP_TO_SINT, MVT::i16, Custom);
296     setOperationAction(ISD::FP_TO_SINT,        MVT::i32, Custom);
297     setOperationAction(ISD::STRICT_FP_TO_SINT, MVT::i32, Custom);
298     // In 32-bit mode these are custom lowered.  In 64-bit mode F32 and F64
299     // are Legal, f80 is custom lowered.
300     setOperationAction(ISD::FP_TO_SINT,        MVT::i64, Custom);
301     setOperationAction(ISD::STRICT_FP_TO_SINT, MVT::i64, Custom);
302 
303     // Handle FP_TO_UINT by promoting the destination to a larger signed
304     // conversion.
305     setOperationAction(ISD::FP_TO_UINT,        MVT::i8,  Promote);
306     // FIXME: This doesn't generate invalid exception when it should. PR44019.
307     setOperationAction(ISD::STRICT_FP_TO_UINT, MVT::i8,  Promote);
308     setOperationAction(ISD::FP_TO_UINT,        MVT::i16, Promote);
309     // FIXME: This doesn't generate invalid exception when it should. PR44019.
310     setOperationAction(ISD::STRICT_FP_TO_UINT, MVT::i16, Promote);
311     setOperationAction(ISD::FP_TO_UINT,        MVT::i32, Custom);
312     setOperationAction(ISD::STRICT_FP_TO_UINT, MVT::i32, Custom);
313     setOperationAction(ISD::FP_TO_UINT,        MVT::i64, Custom);
314     setOperationAction(ISD::STRICT_FP_TO_UINT, MVT::i64, Custom);
315 
316     setOperationAction(ISD::LRINT,             MVT::f32, Custom);
317     setOperationAction(ISD::LRINT,             MVT::f64, Custom);
318     setOperationAction(ISD::LLRINT,            MVT::f32, Custom);
319     setOperationAction(ISD::LLRINT,            MVT::f64, Custom);
320 
321     if (!Subtarget.is64Bit()) {
322       setOperationAction(ISD::LRINT,  MVT::i64, Custom);
323       setOperationAction(ISD::LLRINT, MVT::i64, Custom);
324     }
325   }
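  // A sketch of the widening trick referred to above: an unsigned source that
  // fits in the positive range of a wider signed type can reuse the signed
  // conversion, e.g. for u32 -> double
  //
  //   double cvt(uint32_t u) {
  //     int64_t wide = (int64_t)u;   // always non-negative, no precision loss
  //     return (double)wide;         // one signed convert (FILD / CVTSI2SD)
  //   }
  //
  // i64 cannot widen further, so the Custom lowering converts it as signed
  // and conditionally adds 2^64 when the value was negative when viewed as
  // signed (the "FILD followed by conditional FADD" mentioned above).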
326 
327   if (Subtarget.hasSSE2()) {
328     // Custom lowering for saturating float to int conversions.
329     // We handle promotion to larger result types manually.
330     for (MVT VT : { MVT::i8, MVT::i16, MVT::i32 }) {
331       setOperationAction(ISD::FP_TO_UINT_SAT, VT, Custom);
332       setOperationAction(ISD::FP_TO_SINT_SAT, VT, Custom);
333     }
334     if (Subtarget.is64Bit()) {
335       setOperationAction(ISD::FP_TO_UINT_SAT, MVT::i64, Custom);
336       setOperationAction(ISD::FP_TO_SINT_SAT, MVT::i64, Custom);
337     }
338   }
339 
340   // Handle address space casts between mixed sized pointers.
341   setOperationAction(ISD::ADDRSPACECAST, MVT::i32, Custom);
342   setOperationAction(ISD::ADDRSPACECAST, MVT::i64, Custom);
343 
344   // TODO: when we have SSE, these could be more efficient, by using movd/movq.
345   if (!Subtarget.hasSSE2()) {
346     setOperationAction(ISD::BITCAST        , MVT::f32  , Expand);
347     setOperationAction(ISD::BITCAST        , MVT::i32  , Expand);
348     if (Subtarget.is64Bit()) {
349       setOperationAction(ISD::BITCAST      , MVT::f64  , Expand);
350       // Without SSE, i64->f64 goes through memory.
351       setOperationAction(ISD::BITCAST      , MVT::i64  , Expand);
352     }
353   } else if (!Subtarget.is64Bit())
354     setOperationAction(ISD::BITCAST      , MVT::i64  , Custom);
355 
356   // Scalar integer divide and remainder are lowered to use operations that
357   // produce two results, to match the available instructions. This exposes
358   // the two-result form to trivial CSE, which is able to combine x/y and x%y
359   // into a single instruction.
360   //
361   // Scalar integer multiply-high is also lowered to use two-result
362   // operations, to match the available instructions. However, plain multiply
363   // (low) operations are left as Legal, as there are single-result
364   // instructions for this in x86. Using the two-result multiply instructions
365   // when both high and low results are needed must be arranged by dagcombine.
366   for (auto VT : { MVT::i8, MVT::i16, MVT::i32, MVT::i64 }) {
367     setOperationAction(ISD::MULHS, VT, Expand);
368     setOperationAction(ISD::MULHU, VT, Expand);
369     setOperationAction(ISD::SDIV, VT, Expand);
370     setOperationAction(ISD::UDIV, VT, Expand);
371     setOperationAction(ISD::SREM, VT, Expand);
372     setOperationAction(ISD::UREM, VT, Expand);
373   }
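  // Example of the CSE benefit described above:
  //
  //   int quot = x / y;
  //   int rem  = x % y;
  //
  // both expand to SDIVREM nodes that CSE into one, so a single IDIV produces
  // the quotient in EAX and the remainder in EDX.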
374 
375   setOperationAction(ISD::BR_JT            , MVT::Other, Expand);
376   setOperationAction(ISD::BRCOND           , MVT::Other, Custom);
377   for (auto VT : { MVT::f32, MVT::f64, MVT::f80, MVT::f128,
378                    MVT::i8,  MVT::i16, MVT::i32, MVT::i64 }) {
379     setOperationAction(ISD::BR_CC,     VT, Expand);
380     setOperationAction(ISD::SELECT_CC, VT, Expand);
381   }
382   if (Subtarget.is64Bit())
383     setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i32, Legal);
384   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i16  , Legal);
385   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i8   , Legal);
386   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i1   , Expand);
387 
388   setOperationAction(ISD::FREM             , MVT::f32  , Expand);
389   setOperationAction(ISD::FREM             , MVT::f64  , Expand);
390   setOperationAction(ISD::FREM             , MVT::f80  , Expand);
391   setOperationAction(ISD::FREM             , MVT::f128 , Expand);
392 
393   if (!Subtarget.useSoftFloat() && Subtarget.hasX87()) {
394     setOperationAction(ISD::GET_ROUNDING   , MVT::i32  , Custom);
395     setOperationAction(ISD::SET_ROUNDING   , MVT::Other, Custom);
396     setOperationAction(ISD::GET_FPENV_MEM  , MVT::Other, Custom);
397     setOperationAction(ISD::SET_FPENV_MEM  , MVT::Other, Custom);
398     setOperationAction(ISD::RESET_FPENV    , MVT::Other, Custom);
399   }
400 
401   // Promote the i8 variants and force them on up to i32 which has a shorter
402   // encoding.
403   setOperationPromotedToType(ISD::CTTZ           , MVT::i8   , MVT::i32);
404   setOperationPromotedToType(ISD::CTTZ_ZERO_UNDEF, MVT::i8   , MVT::i32);
405   // Promoted i16. tzcntw has a false dependency on Intel CPUs. For BSF, we emit
406   // a REP prefix to encode it as TZCNT for modern CPUs so it makes sense to
407   // promote that too.
408   setOperationPromotedToType(ISD::CTTZ           , MVT::i16  , MVT::i32);
409   setOperationPromotedToType(ISD::CTTZ_ZERO_UNDEF, MVT::i16  , MVT::i32);
410 
411   if (!Subtarget.hasBMI()) {
412     setOperationAction(ISD::CTTZ           , MVT::i32  , Custom);
413     setOperationAction(ISD::CTTZ_ZERO_UNDEF, MVT::i32  , Legal);
414     if (Subtarget.is64Bit()) {
415       setOperationAction(ISD::CTTZ         , MVT::i64  , Custom);
416       setOperationAction(ISD::CTTZ_ZERO_UNDEF, MVT::i64, Legal);
417     }
418   }
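  // Conceptually (illustration only): plain CTTZ is Custom here because BSF
  // leaves its destination undefined for a zero input, so the lowering is
  // roughly
  //
  //   unsigned cttz32(unsigned x) {
  //     if (x == 0)
  //       return 32;              // selected with a CMOV, not a branch
  //     return __builtin_ctz(x);  // a single BSF
  //   }
  //
  // while CTTZ_ZERO_UNDEF (zero input already undefined) maps straight to BSF.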
419 
420   if (Subtarget.hasLZCNT()) {
421     // When promoting the i8 variants, force them to i32 for a shorter
422     // encoding.
423     setOperationPromotedToType(ISD::CTLZ           , MVT::i8   , MVT::i32);
424     setOperationPromotedToType(ISD::CTLZ_ZERO_UNDEF, MVT::i8   , MVT::i32);
425   } else {
426     for (auto VT : {MVT::i8, MVT::i16, MVT::i32, MVT::i64}) {
427       if (VT == MVT::i64 && !Subtarget.is64Bit())
428         continue;
429       setOperationAction(ISD::CTLZ           , VT, Custom);
430       setOperationAction(ISD::CTLZ_ZERO_UNDEF, VT, Custom);
431     }
432   }
433 
434   for (auto Op : {ISD::FP16_TO_FP, ISD::STRICT_FP16_TO_FP, ISD::FP_TO_FP16,
435                   ISD::STRICT_FP_TO_FP16}) {
436     // Special handling for half-precision floating point conversions.
437     // If we don't have F16C support, then lower half float conversions
438     // into library calls.
439     setOperationAction(
440         Op, MVT::f32,
441         (!Subtarget.useSoftFloat() && Subtarget.hasF16C()) ? Custom : Expand);
442     // There's never any support for operations beyond MVT::f32.
443     setOperationAction(Op, MVT::f64, Expand);
444     setOperationAction(Op, MVT::f80, Expand);
445     setOperationAction(Op, MVT::f128, Expand);
446   }
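  // For illustration: without F16C, a scalar half -> float conversion such as
  //
  //   float widen(_Float16 h) { return (float)h; }
  //
  // is lowered to a library call (e.g. __extendhfsf2); with F16C the Custom
  // lowering can instead use VCVTPH2PS/VCVTPS2PH on a single element.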
447 
448   for (auto VT : {MVT::f32, MVT::f64, MVT::f80, MVT::f128}) {
449     setOperationAction(ISD::STRICT_FP_TO_BF16, VT, Expand);
450     setOperationAction(ISD::STRICT_BF16_TO_FP, VT, Expand);
451   }
452 
453   for (MVT VT : {MVT::f32, MVT::f64, MVT::f80, MVT::f128}) {
454     setLoadExtAction(ISD::EXTLOAD, VT, MVT::f16, Expand);
455     setLoadExtAction(ISD::EXTLOAD, VT, MVT::bf16, Expand);
456     setTruncStoreAction(VT, MVT::f16, Expand);
457     setTruncStoreAction(VT, MVT::bf16, Expand);
458 
459     setOperationAction(ISD::BF16_TO_FP, VT, Expand);
460     setOperationAction(ISD::FP_TO_BF16, VT, Custom);
461   }
462 
463   setOperationAction(ISD::PARITY, MVT::i8, Custom);
464   setOperationAction(ISD::PARITY, MVT::i16, Custom);
465   setOperationAction(ISD::PARITY, MVT::i32, Custom);
466   if (Subtarget.is64Bit())
467     setOperationAction(ISD::PARITY, MVT::i64, Custom);
468   if (Subtarget.hasPOPCNT()) {
469     setOperationPromotedToType(ISD::CTPOP, MVT::i8, MVT::i32);
470     // popcntw is longer to encode than popcntl and also has a false dependency
471     // on the dest that popcntl hasn't had since Cannon Lake.
472     setOperationPromotedToType(ISD::CTPOP, MVT::i16, MVT::i32);
473   } else {
474     setOperationAction(ISD::CTPOP          , MVT::i8   , Custom);
475     setOperationAction(ISD::CTPOP          , MVT::i16  , Custom);
476     setOperationAction(ISD::CTPOP          , MVT::i32  , Custom);
477     setOperationAction(ISD::CTPOP          , MVT::i64  , Custom);
478   }
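  // Rough sketch: without POPCNT the Custom CTPOP lowering computes the count
  // with shift/mask arithmetic (or SSSE3 byte shuffles), conceptually the
  // classic parallel popcount
  //
  //   uint32_t popcount32(uint32_t v) {
  //     v = v - ((v >> 1) & 0x55555555u);
  //     v = (v & 0x33333333u) + ((v >> 2) & 0x33333333u);
  //     v = (v + (v >> 4)) & 0x0F0F0F0Fu;
  //     return (v * 0x01010101u) >> 24;
  //   }
  //
  // With POPCNT, i8/i16 are promoted to i32 so the shorter popcntl encoding
  // (without the popcntw issues noted above) is used.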
479 
480   setOperationAction(ISD::READCYCLECOUNTER , MVT::i64  , Custom);
481 
482   if (!Subtarget.hasMOVBE())
483     setOperationAction(ISD::BSWAP          , MVT::i16  , Expand);
484 
485   // X86 wants to expand cmov itself.
486   for (auto VT : { MVT::f32, MVT::f64, MVT::f80, MVT::f128 }) {
487     setOperationAction(ISD::SELECT, VT, Custom);
488     setOperationAction(ISD::SETCC, VT, Custom);
489     setOperationAction(ISD::STRICT_FSETCC, VT, Custom);
490     setOperationAction(ISD::STRICT_FSETCCS, VT, Custom);
491   }
492   for (auto VT : { MVT::i8, MVT::i16, MVT::i32, MVT::i64 }) {
493     if (VT == MVT::i64 && !Subtarget.is64Bit())
494       continue;
495     setOperationAction(ISD::SELECT, VT, Custom);
496     setOperationAction(ISD::SETCC,  VT, Custom);
497   }
498 
499   // Custom action for SELECT MMX and expand action for SELECT_CC MMX
500   setOperationAction(ISD::SELECT, MVT::x86mmx, Custom);
501   setOperationAction(ISD::SELECT_CC, MVT::x86mmx, Expand);
502 
503   setOperationAction(ISD::EH_RETURN       , MVT::Other, Custom);
504   // NOTE: EH_SJLJ_SETJMP/_LONGJMP are not recommended, since
505   // LLVM/Clang supports zero-cost DWARF and SEH exception handling.
506   setOperationAction(ISD::EH_SJLJ_SETJMP, MVT::i32, Custom);
507   setOperationAction(ISD::EH_SJLJ_LONGJMP, MVT::Other, Custom);
508   setOperationAction(ISD::EH_SJLJ_SETUP_DISPATCH, MVT::Other, Custom);
509   if (TM.Options.ExceptionModel == ExceptionHandling::SjLj)
510     setLibcallName(RTLIB::UNWIND_RESUME, "_Unwind_SjLj_Resume");
511 
512   // Darwin ABI issue.
513   for (auto VT : { MVT::i32, MVT::i64 }) {
514     if (VT == MVT::i64 && !Subtarget.is64Bit())
515       continue;
516     setOperationAction(ISD::ConstantPool    , VT, Custom);
517     setOperationAction(ISD::JumpTable       , VT, Custom);
518     setOperationAction(ISD::GlobalAddress   , VT, Custom);
519     setOperationAction(ISD::GlobalTLSAddress, VT, Custom);
520     setOperationAction(ISD::ExternalSymbol  , VT, Custom);
521     setOperationAction(ISD::BlockAddress    , VT, Custom);
522   }
523 
524   // 64-bit shl, sra, srl (iff 32-bit x86)
525   for (auto VT : { MVT::i32, MVT::i64 }) {
526     if (VT == MVT::i64 && !Subtarget.is64Bit())
527       continue;
528     setOperationAction(ISD::SHL_PARTS, VT, Custom);
529     setOperationAction(ISD::SRA_PARTS, VT, Custom);
530     setOperationAction(ISD::SRL_PARTS, VT, Custom);
531   }
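  // Note (illustrative): on 32-bit targets a variable 64-bit shift such as
  //
  //   uint64_t shl64(uint64_t v, unsigned n) { return v << n; }   // n < 64
  //
  // is legalized through the *_PARTS nodes above into a (lo, hi) register
  // pair shifted with SHLD/SHL plus a test of bit 5 of the shift amount to
  // handle the >= 32 case.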
532 
533   if (Subtarget.hasSSEPrefetch())
534     setOperationAction(ISD::PREFETCH      , MVT::Other, Custom);
535 
536   setOperationAction(ISD::ATOMIC_FENCE  , MVT::Other, Custom);
537 
538   // Expand certain atomics
539   for (auto VT : { MVT::i8, MVT::i16, MVT::i32, MVT::i64 }) {
540     setOperationAction(ISD::ATOMIC_CMP_SWAP_WITH_SUCCESS, VT, Custom);
541     setOperationAction(ISD::ATOMIC_LOAD_SUB, VT, Custom);
542     setOperationAction(ISD::ATOMIC_LOAD_ADD, VT, Custom);
543     setOperationAction(ISD::ATOMIC_LOAD_OR, VT, Custom);
544     setOperationAction(ISD::ATOMIC_LOAD_XOR, VT, Custom);
545     setOperationAction(ISD::ATOMIC_LOAD_AND, VT, Custom);
546     setOperationAction(ISD::ATOMIC_STORE, VT, Custom);
547   }
548 
549   if (!Subtarget.is64Bit())
550     setOperationAction(ISD::ATOMIC_LOAD, MVT::i64, Custom);
551 
552   if (Subtarget.is64Bit() && Subtarget.hasAVX()) {
553     // All CPUs supporting AVX will atomically load/store aligned 128-bit
554     // values, so we can emit [V]MOVAPS/[V]MOVDQA.
555     setOperationAction(ISD::ATOMIC_LOAD, MVT::i128, Custom);
556     setOperationAction(ISD::ATOMIC_STORE, MVT::i128, Custom);
557   }
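  // For illustration: with AVX on a 64-bit target, an aligned 16-byte atomic
  // access along the lines of
  //
  //   __int128 load(const std::atomic<__int128> &p) {
  //     return p.load(std::memory_order_acquire);
  //   }
  //
  // can be emitted as a single [V]MOVDQA/[V]MOVAPS rather than a CMPXCHG16B
  // loop, relying on the guarantee noted in the comment above.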
558 
559   if (Subtarget.canUseCMPXCHG16B())
560     setOperationAction(ISD::ATOMIC_CMP_SWAP_WITH_SUCCESS, MVT::i128, Custom);
561 
562   // FIXME - use subtarget debug flags
563   if (!Subtarget.isTargetDarwin() && !Subtarget.isTargetELF() &&
564       !Subtarget.isTargetCygMing() && !Subtarget.isTargetWin64() &&
565       TM.Options.ExceptionModel != ExceptionHandling::SjLj) {
566     setOperationAction(ISD::EH_LABEL, MVT::Other, Expand);
567   }
568 
569   setOperationAction(ISD::FRAME_TO_ARGS_OFFSET, MVT::i32, Custom);
570   setOperationAction(ISD::FRAME_TO_ARGS_OFFSET, MVT::i64, Custom);
571 
572   setOperationAction(ISD::INIT_TRAMPOLINE, MVT::Other, Custom);
573   setOperationAction(ISD::ADJUST_TRAMPOLINE, MVT::Other, Custom);
574 
575   setOperationAction(ISD::TRAP, MVT::Other, Legal);
576   setOperationAction(ISD::DEBUGTRAP, MVT::Other, Legal);
577   if (Subtarget.isTargetPS())
578     setOperationAction(ISD::UBSANTRAP, MVT::Other, Expand);
579   else
580     setOperationAction(ISD::UBSANTRAP, MVT::Other, Legal);
581 
582   // VASTART needs to be custom lowered to use the VarArgsFrameIndex
583   setOperationAction(ISD::VASTART           , MVT::Other, Custom);
584   setOperationAction(ISD::VAEND             , MVT::Other, Expand);
585   bool Is64Bit = Subtarget.is64Bit();
586   setOperationAction(ISD::VAARG,  MVT::Other, Is64Bit ? Custom : Expand);
587   setOperationAction(ISD::VACOPY, MVT::Other, Is64Bit ? Custom : Expand);
588 
589   setOperationAction(ISD::STACKSAVE,          MVT::Other, Expand);
590   setOperationAction(ISD::STACKRESTORE,       MVT::Other, Expand);
591 
592   setOperationAction(ISD::DYNAMIC_STACKALLOC, PtrVT, Custom);
593 
594   // GC_TRANSITION_START and GC_TRANSITION_END need custom lowering.
595   setOperationAction(ISD::GC_TRANSITION_START, MVT::Other, Custom);
596   setOperationAction(ISD::GC_TRANSITION_END, MVT::Other, Custom);
597 
598   setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f64, Legal);
599 
600   auto setF16Action = [&] (MVT VT, LegalizeAction Action) {
601     setOperationAction(ISD::FABS, VT, Action);
602     setOperationAction(ISD::FNEG, VT, Action);
603     setOperationAction(ISD::FCOPYSIGN, VT, Expand);
604     setOperationAction(ISD::FREM, VT, Action);
605     setOperationAction(ISD::FMA, VT, Action);
606     setOperationAction(ISD::FMINNUM, VT, Action);
607     setOperationAction(ISD::FMAXNUM, VT, Action);
608     setOperationAction(ISD::FMINIMUM, VT, Action);
609     setOperationAction(ISD::FMAXIMUM, VT, Action);
610     setOperationAction(ISD::FSIN, VT, Action);
611     setOperationAction(ISD::FCOS, VT, Action);
612     setOperationAction(ISD::FSINCOS, VT, Action);
613     setOperationAction(ISD::FTAN, VT, Action);
614     setOperationAction(ISD::FSQRT, VT, Action);
615     setOperationAction(ISD::FPOW, VT, Action);
616     setOperationAction(ISD::FLOG, VT, Action);
617     setOperationAction(ISD::FLOG2, VT, Action);
618     setOperationAction(ISD::FLOG10, VT, Action);
619     setOperationAction(ISD::FEXP, VT, Action);
620     setOperationAction(ISD::FEXP2, VT, Action);
621     setOperationAction(ISD::FEXP10, VT, Action);
622     setOperationAction(ISD::FCEIL, VT, Action);
623     setOperationAction(ISD::FFLOOR, VT, Action);
624     setOperationAction(ISD::FNEARBYINT, VT, Action);
625     setOperationAction(ISD::FRINT, VT, Action);
626     setOperationAction(ISD::BR_CC, VT, Action);
627     setOperationAction(ISD::SETCC, VT, Action);
628     setOperationAction(ISD::SELECT, VT, Custom);
629     setOperationAction(ISD::SELECT_CC, VT, Action);
630     setOperationAction(ISD::FROUND, VT, Action);
631     setOperationAction(ISD::FROUNDEVEN, VT, Action);
632     setOperationAction(ISD::FTRUNC, VT, Action);
633     setOperationAction(ISD::FLDEXP, VT, Action);
634   };
635 
636   if (!Subtarget.useSoftFloat() && Subtarget.hasSSE2()) {
637     // f16, f32 and f64 use SSE.
638     // Set up the FP register classes.
639     addRegisterClass(MVT::f16, Subtarget.hasAVX512() ? &X86::FR16XRegClass
640                                                      : &X86::FR16RegClass);
641     addRegisterClass(MVT::f32, Subtarget.hasAVX512() ? &X86::FR32XRegClass
642                                                      : &X86::FR32RegClass);
643     addRegisterClass(MVT::f64, Subtarget.hasAVX512() ? &X86::FR64XRegClass
644                                                      : &X86::FR64RegClass);
645 
646     // Disable f32->f64 extload as we can only generate this in one instruction
647     // under optsize. So it's easier to pattern match (fpext (load)) for that
648     // case instead of needing to emit 2 instructions for extload in the
649     // non-optsize case.
650     setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand);
651 
652     for (auto VT : { MVT::f32, MVT::f64 }) {
653       // Use ANDPD to simulate FABS.
654       setOperationAction(ISD::FABS, VT, Custom);
655 
656       // Use XORP to simulate FNEG.
657       setOperationAction(ISD::FNEG, VT, Custom);
658 
659       // Use ANDPD and ORPD to simulate FCOPYSIGN.
660       setOperationAction(ISD::FCOPYSIGN, VT, Custom);
661 
662       // These might be better off as horizontal vector ops.
663       setOperationAction(ISD::FADD, VT, Custom);
664       setOperationAction(ISD::FSUB, VT, Custom);
665 
666       // We don't support sin/cos/fmod
667       setOperationAction(ISD::FSIN   , VT, Expand);
668       setOperationAction(ISD::FCOS   , VT, Expand);
669       setOperationAction(ISD::FSINCOS, VT, Expand);
670     }
671 
672     // Half type will be promoted by default.
673     setF16Action(MVT::f16, Promote);
674     setOperationAction(ISD::FADD, MVT::f16, Promote);
675     setOperationAction(ISD::FSUB, MVT::f16, Promote);
676     setOperationAction(ISD::FMUL, MVT::f16, Promote);
677     setOperationAction(ISD::FDIV, MVT::f16, Promote);
678     setOperationAction(ISD::FP_ROUND, MVT::f16, Custom);
679     setOperationAction(ISD::FP_EXTEND, MVT::f32, Custom);
680     setOperationAction(ISD::FP_EXTEND, MVT::f64, Custom);
681 
682     setOperationAction(ISD::STRICT_FADD, MVT::f16, Promote);
683     setOperationAction(ISD::STRICT_FSUB, MVT::f16, Promote);
684     setOperationAction(ISD::STRICT_FMUL, MVT::f16, Promote);
685     setOperationAction(ISD::STRICT_FDIV, MVT::f16, Promote);
686     setOperationAction(ISD::STRICT_FMA, MVT::f16, Promote);
687     setOperationAction(ISD::STRICT_FMINNUM, MVT::f16, Promote);
688     setOperationAction(ISD::STRICT_FMAXNUM, MVT::f16, Promote);
689     setOperationAction(ISD::STRICT_FMINIMUM, MVT::f16, Promote);
690     setOperationAction(ISD::STRICT_FMAXIMUM, MVT::f16, Promote);
691     setOperationAction(ISD::STRICT_FSQRT, MVT::f16, Promote);
692     setOperationAction(ISD::STRICT_FPOW, MVT::f16, Promote);
693     setOperationAction(ISD::STRICT_FLDEXP, MVT::f16, Promote);
694     setOperationAction(ISD::STRICT_FLOG, MVT::f16, Promote);
695     setOperationAction(ISD::STRICT_FLOG2, MVT::f16, Promote);
696     setOperationAction(ISD::STRICT_FLOG10, MVT::f16, Promote);
697     setOperationAction(ISD::STRICT_FEXP, MVT::f16, Promote);
698     setOperationAction(ISD::STRICT_FEXP2, MVT::f16, Promote);
699     setOperationAction(ISD::STRICT_FCEIL, MVT::f16, Promote);
700     setOperationAction(ISD::STRICT_FFLOOR, MVT::f16, Promote);
701     setOperationAction(ISD::STRICT_FNEARBYINT, MVT::f16, Promote);
702     setOperationAction(ISD::STRICT_FRINT, MVT::f16, Promote);
703     setOperationAction(ISD::STRICT_FSETCC, MVT::f16, Promote);
704     setOperationAction(ISD::STRICT_FSETCCS, MVT::f16, Promote);
705     setOperationAction(ISD::STRICT_FROUND, MVT::f16, Promote);
706     setOperationAction(ISD::STRICT_FROUNDEVEN, MVT::f16, Promote);
707     setOperationAction(ISD::STRICT_FTRUNC, MVT::f16, Promote);
708     setOperationAction(ISD::STRICT_FP_ROUND, MVT::f16, Custom);
709     setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f32, Custom);
710     setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f64, Custom);
711 
712     setLibcallName(RTLIB::FPROUND_F32_F16, "__truncsfhf2");
713     setLibcallName(RTLIB::FPEXT_F16_F32, "__extendhfsf2");
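    // Note (illustrative): Promote for f16 arithmetic means the work is done
    // in f32, i.e.
    //
    //   _Float16 add(_Float16 a, _Float16 b) { return a + b; }
    //
    // is lowered as extend(a) + extend(b) rounded back to half, using F16C
    // conversions when available and the __truncsfhf2/__extendhfsf2 libcalls
    // named just above otherwise.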
714 
715     // Lower this to MOVMSK plus an AND.
716     setOperationAction(ISD::FGETSIGN, MVT::i64, Custom);
717     setOperationAction(ISD::FGETSIGN, MVT::i32, Custom);
718 
719   } else if (!Subtarget.useSoftFloat() && Subtarget.hasSSE1() &&
720              (UseX87 || Is64Bit)) {
721     // Use SSE for f32, x87 for f64.
722     // Set up the FP register classes.
723     addRegisterClass(MVT::f32, &X86::FR32RegClass);
724     if (UseX87)
725       addRegisterClass(MVT::f64, &X86::RFP64RegClass);
726 
727     // Use ANDPS to simulate FABS.
728     setOperationAction(ISD::FABS , MVT::f32, Custom);
729 
730     // Use XORP to simulate FNEG.
731     setOperationAction(ISD::FNEG , MVT::f32, Custom);
732 
733     if (UseX87)
734       setOperationAction(ISD::UNDEF, MVT::f64, Expand);
735 
736     // Use ANDPS and ORPS to simulate FCOPYSIGN.
737     if (UseX87)
738       setOperationAction(ISD::FCOPYSIGN, MVT::f64, Expand);
739     setOperationAction(ISD::FCOPYSIGN, MVT::f32, Custom);
740 
741     // We don't support sin/cos/fmod
742     setOperationAction(ISD::FSIN   , MVT::f32, Expand);
743     setOperationAction(ISD::FCOS   , MVT::f32, Expand);
744     setOperationAction(ISD::FSINCOS, MVT::f32, Expand);
745 
746     if (UseX87) {
747       // Always expand sin/cos functions even though x87 has an instruction.
748       setOperationAction(ISD::FSIN, MVT::f64, Expand);
749       setOperationAction(ISD::FCOS, MVT::f64, Expand);
750       setOperationAction(ISD::FSINCOS, MVT::f64, Expand);
751     }
752   } else if (UseX87) {
753     // f32 and f64 in x87.
754     // Set up the FP register classes.
755     addRegisterClass(MVT::f64, &X86::RFP64RegClass);
756     addRegisterClass(MVT::f32, &X86::RFP32RegClass);
757 
758     for (auto VT : { MVT::f32, MVT::f64 }) {
759       setOperationAction(ISD::UNDEF,     VT, Expand);
760       setOperationAction(ISD::FCOPYSIGN, VT, Expand);
761 
762       // Always expand sin/cos functions even though x87 has an instruction.
763       setOperationAction(ISD::FSIN   , VT, Expand);
764       setOperationAction(ISD::FCOS   , VT, Expand);
765       setOperationAction(ISD::FSINCOS, VT, Expand);
766     }
767   }
768 
769   // Expand FP32 immediates into loads from the stack, save special cases.
770   if (isTypeLegal(MVT::f32)) {
771     if (UseX87 && (getRegClassFor(MVT::f32) == &X86::RFP32RegClass)) {
772       addLegalFPImmediate(APFloat(+0.0f)); // FLD0
773       addLegalFPImmediate(APFloat(+1.0f)); // FLD1
774       addLegalFPImmediate(APFloat(-0.0f)); // FLD0/FCHS
775       addLegalFPImmediate(APFloat(-1.0f)); // FLD1/FCHS
776     } else // SSE immediates.
777       addLegalFPImmediate(APFloat(+0.0f)); // xorps
778   }
779   // Expand FP64 immediates into loads from the stack, save special cases.
780   if (isTypeLegal(MVT::f64)) {
781     if (UseX87 && getRegClassFor(MVT::f64) == &X86::RFP64RegClass) {
782       addLegalFPImmediate(APFloat(+0.0)); // FLD0
783       addLegalFPImmediate(APFloat(+1.0)); // FLD1
784       addLegalFPImmediate(APFloat(-0.0)); // FLD0/FCHS
785       addLegalFPImmediate(APFloat(-1.0)); // FLD1/FCHS
786     } else // SSE immediates.
787       addLegalFPImmediate(APFloat(+0.0)); // xorpd
788   }
789   // Support fp16 0 immediate.
790   if (isTypeLegal(MVT::f16))
791     addLegalFPImmediate(APFloat::getZero(APFloat::IEEEhalf()));
792 
793   // Handle constrained floating-point operations of scalar.
794   setOperationAction(ISD::STRICT_FADD,      MVT::f32, Legal);
795   setOperationAction(ISD::STRICT_FADD,      MVT::f64, Legal);
796   setOperationAction(ISD::STRICT_FSUB,      MVT::f32, Legal);
797   setOperationAction(ISD::STRICT_FSUB,      MVT::f64, Legal);
798   setOperationAction(ISD::STRICT_FMUL,      MVT::f32, Legal);
799   setOperationAction(ISD::STRICT_FMUL,      MVT::f64, Legal);
800   setOperationAction(ISD::STRICT_FDIV,      MVT::f32, Legal);
801   setOperationAction(ISD::STRICT_FDIV,      MVT::f64, Legal);
802   setOperationAction(ISD::STRICT_FP_ROUND,  MVT::f32, Legal);
803   setOperationAction(ISD::STRICT_FP_ROUND,  MVT::f64, Legal);
804   setOperationAction(ISD::STRICT_FSQRT,     MVT::f32, Legal);
805   setOperationAction(ISD::STRICT_FSQRT,     MVT::f64, Legal);
806 
807   // We don't support FMA.
808   setOperationAction(ISD::FMA, MVT::f64, Expand);
809   setOperationAction(ISD::FMA, MVT::f32, Expand);
810 
811   // f80 always uses X87.
812   if (UseX87) {
813     addRegisterClass(MVT::f80, &X86::RFP80RegClass);
814     setOperationAction(ISD::UNDEF,     MVT::f80, Expand);
815     setOperationAction(ISD::FCOPYSIGN, MVT::f80, Expand);
816     {
817       APFloat TmpFlt = APFloat::getZero(APFloat::x87DoubleExtended());
818       addLegalFPImmediate(TmpFlt);  // FLD0
819       TmpFlt.changeSign();
820       addLegalFPImmediate(TmpFlt);  // FLD0/FCHS
821 
822       bool ignored;
823       APFloat TmpFlt2(+1.0);
824       TmpFlt2.convert(APFloat::x87DoubleExtended(), APFloat::rmNearestTiesToEven,
825                       &ignored);
826       addLegalFPImmediate(TmpFlt2);  // FLD1
827       TmpFlt2.changeSign();
828       addLegalFPImmediate(TmpFlt2);  // FLD1/FCHS
829     }
830 
831     // Always expand sin/cos functions even though x87 has an instruction.
832     // clang-format off
833     setOperationAction(ISD::FSIN   , MVT::f80, Expand);
834     setOperationAction(ISD::FCOS   , MVT::f80, Expand);
835     setOperationAction(ISD::FSINCOS, MVT::f80, Expand);
836     setOperationAction(ISD::FTAN   , MVT::f80, Expand);
837     setOperationAction(ISD::FASIN  , MVT::f80, Expand);
838     setOperationAction(ISD::FACOS  , MVT::f80, Expand);
839     setOperationAction(ISD::FATAN  , MVT::f80, Expand);
840     setOperationAction(ISD::FSINH  , MVT::f80, Expand);
841     setOperationAction(ISD::FCOSH  , MVT::f80, Expand);
842     setOperationAction(ISD::FTANH  , MVT::f80, Expand);
843     // clang-format on
844 
845     setOperationAction(ISD::FFLOOR, MVT::f80, Expand);
846     setOperationAction(ISD::FCEIL,  MVT::f80, Expand);
847     setOperationAction(ISD::FTRUNC, MVT::f80, Expand);
848     setOperationAction(ISD::FRINT,  MVT::f80, Expand);
849     setOperationAction(ISD::FNEARBYINT, MVT::f80, Expand);
850     setOperationAction(ISD::FROUNDEVEN, MVT::f80, Expand);
851     setOperationAction(ISD::FMA, MVT::f80, Expand);
852     setOperationAction(ISD::LROUND, MVT::f80, Expand);
853     setOperationAction(ISD::LLROUND, MVT::f80, Expand);
854     setOperationAction(ISD::LRINT, MVT::f80, Custom);
855     setOperationAction(ISD::LLRINT, MVT::f80, Custom);
856 
857     // Handle constrained floating-point operations of scalar.
858     setOperationAction(ISD::STRICT_FADD     , MVT::f80, Legal);
859     setOperationAction(ISD::STRICT_FSUB     , MVT::f80, Legal);
860     setOperationAction(ISD::STRICT_FMUL     , MVT::f80, Legal);
861     setOperationAction(ISD::STRICT_FDIV     , MVT::f80, Legal);
862     setOperationAction(ISD::STRICT_FSQRT    , MVT::f80, Legal);
863     if (isTypeLegal(MVT::f16)) {
864       setOperationAction(ISD::FP_EXTEND, MVT::f80, Custom);
865       setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f80, Custom);
866     } else {
867       setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f80, Legal);
868     }
869     // FIXME: When the target is 64-bit, STRICT_FP_ROUND will be overwritten
870     // as Custom.
871     setOperationAction(ISD::STRICT_FP_ROUND, MVT::f80, Legal);
872   }
873 
874   // f128 uses xmm registers, but most operations require libcalls.
875   if (!Subtarget.useSoftFloat() && Subtarget.is64Bit() && Subtarget.hasSSE1()) {
876     addRegisterClass(MVT::f128, Subtarget.hasVLX() ? &X86::VR128XRegClass
877                                                    : &X86::VR128RegClass);
878 
879     addLegalFPImmediate(APFloat::getZero(APFloat::IEEEquad())); // xorps
880 
881     setOperationAction(ISD::FADD,        MVT::f128, LibCall);
882     setOperationAction(ISD::STRICT_FADD, MVT::f128, LibCall);
883     setOperationAction(ISD::FSUB,        MVT::f128, LibCall);
884     setOperationAction(ISD::STRICT_FSUB, MVT::f128, LibCall);
885     setOperationAction(ISD::FDIV,        MVT::f128, LibCall);
886     setOperationAction(ISD::STRICT_FDIV, MVT::f128, LibCall);
887     setOperationAction(ISD::FMUL,        MVT::f128, LibCall);
888     setOperationAction(ISD::STRICT_FMUL, MVT::f128, LibCall);
889     setOperationAction(ISD::FMA,         MVT::f128, LibCall);
890     setOperationAction(ISD::STRICT_FMA,  MVT::f128, LibCall);
891 
892     setOperationAction(ISD::FABS, MVT::f128, Custom);
893     setOperationAction(ISD::FNEG, MVT::f128, Custom);
894     setOperationAction(ISD::FCOPYSIGN, MVT::f128, Custom);
895 
896     // clang-format off
897     setOperationAction(ISD::FSIN,         MVT::f128, LibCall);
898     setOperationAction(ISD::STRICT_FSIN,  MVT::f128, LibCall);
899     setOperationAction(ISD::FCOS,         MVT::f128, LibCall);
900     setOperationAction(ISD::STRICT_FCOS,  MVT::f128, LibCall);
901     setOperationAction(ISD::FSINCOS,      MVT::f128, LibCall);
902     setOperationAction(ISD::FTAN,         MVT::f128, LibCall);
903     setOperationAction(ISD::STRICT_FTAN,  MVT::f128, LibCall);
904     // clang-format on
905     // No STRICT_FSINCOS
906     setOperationAction(ISD::FSQRT,        MVT::f128, LibCall);
907     setOperationAction(ISD::STRICT_FSQRT, MVT::f128, LibCall);
908 
909     setOperationAction(ISD::FP_EXTEND,        MVT::f128, Custom);
910     setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f128, Custom);
911     // We need to custom handle any FP_ROUND with an f128 input, but
912     // LegalizeDAG uses the result type to know when to run a custom handler.
913     // So we have to list all legal floating point result types here.
914     if (isTypeLegal(MVT::f32)) {
915       setOperationAction(ISD::FP_ROUND, MVT::f32, Custom);
916       setOperationAction(ISD::STRICT_FP_ROUND, MVT::f32, Custom);
917     }
918     if (isTypeLegal(MVT::f64)) {
919       setOperationAction(ISD::FP_ROUND, MVT::f64, Custom);
920       setOperationAction(ISD::STRICT_FP_ROUND, MVT::f64, Custom);
921     }
922     if (isTypeLegal(MVT::f80)) {
923       setOperationAction(ISD::FP_ROUND, MVT::f80, Custom);
924       setOperationAction(ISD::STRICT_FP_ROUND, MVT::f80, Custom);
925     }
926 
927     setOperationAction(ISD::SETCC, MVT::f128, Custom);
928 
929     setLoadExtAction(ISD::EXTLOAD, MVT::f128, MVT::f32, Expand);
930     setLoadExtAction(ISD::EXTLOAD, MVT::f128, MVT::f64, Expand);
931     setLoadExtAction(ISD::EXTLOAD, MVT::f128, MVT::f80, Expand);
932     setTruncStoreAction(MVT::f128, MVT::f32, Expand);
933     setTruncStoreAction(MVT::f128, MVT::f64, Expand);
934     setTruncStoreAction(MVT::f128, MVT::f80, Expand);
935   }
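  // For illustration: because f128 lives in XMM registers but has no native
  // arithmetic, something like
  //
  //   __float128 f(__float128 a, __float128 b) { return -a + b; }
  //
  // keeps the negation as an in-register sign-bit flip (the Custom FNEG
  // above, an XOR with a sign mask) while the addition becomes a soft-float
  // runtime call (e.g. __addtf3).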
936 
937   // Always use a library call for pow.
938   setOperationAction(ISD::FPOW             , MVT::f32  , Expand);
939   setOperationAction(ISD::FPOW             , MVT::f64  , Expand);
940   setOperationAction(ISD::FPOW             , MVT::f80  , Expand);
941   setOperationAction(ISD::FPOW             , MVT::f128 , Expand);
942 
943   setOperationAction(ISD::FLOG, MVT::f80, Expand);
944   setOperationAction(ISD::FLOG2, MVT::f80, Expand);
945   setOperationAction(ISD::FLOG10, MVT::f80, Expand);
946   setOperationAction(ISD::FEXP, MVT::f80, Expand);
947   setOperationAction(ISD::FEXP2, MVT::f80, Expand);
948   setOperationAction(ISD::FEXP10, MVT::f80, Expand);
949   setOperationAction(ISD::FMINNUM, MVT::f80, Expand);
950   setOperationAction(ISD::FMAXNUM, MVT::f80, Expand);
951 
952   // Some FP actions are always expanded for vector types.
953   for (auto VT : { MVT::v8f16, MVT::v16f16, MVT::v32f16,
954                    MVT::v4f32, MVT::v8f32,  MVT::v16f32,
955                    MVT::v2f64, MVT::v4f64,  MVT::v8f64 }) {
956     // clang-format off
957     setOperationAction(ISD::FSIN,      VT, Expand);
958     setOperationAction(ISD::FSINCOS,   VT, Expand);
959     setOperationAction(ISD::FCOS,      VT, Expand);
960     setOperationAction(ISD::FTAN,      VT, Expand);
961     setOperationAction(ISD::FREM,      VT, Expand);
962     setOperationAction(ISD::FCOPYSIGN, VT, Expand);
963     setOperationAction(ISD::FPOW,      VT, Expand);
964     setOperationAction(ISD::FLOG,      VT, Expand);
965     setOperationAction(ISD::FLOG2,     VT, Expand);
966     setOperationAction(ISD::FLOG10,    VT, Expand);
967     setOperationAction(ISD::FEXP,      VT, Expand);
968     setOperationAction(ISD::FEXP2,     VT, Expand);
969     setOperationAction(ISD::FEXP10,    VT, Expand);
970     // clang-format on
971   }
972 
973   // First set operation action for all vector types to either promote
974   // (for widening) or expand (for scalarization). Then we will selectively
975   // turn on ones that can be effectively codegen'd.
976   for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
977     setOperationAction(ISD::SDIV, VT, Expand);
978     setOperationAction(ISD::UDIV, VT, Expand);
979     setOperationAction(ISD::SREM, VT, Expand);
980     setOperationAction(ISD::UREM, VT, Expand);
981     setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT,Expand);
982     setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Expand);
983     setOperationAction(ISD::EXTRACT_SUBVECTOR, VT,Expand);
984     setOperationAction(ISD::INSERT_SUBVECTOR, VT,Expand);
985     setOperationAction(ISD::FMA,  VT, Expand);
986     setOperationAction(ISD::FFLOOR, VT, Expand);
987     setOperationAction(ISD::FCEIL, VT, Expand);
988     setOperationAction(ISD::FTRUNC, VT, Expand);
989     setOperationAction(ISD::FRINT, VT, Expand);
990     setOperationAction(ISD::FNEARBYINT, VT, Expand);
991     setOperationAction(ISD::FROUNDEVEN, VT, Expand);
992     setOperationAction(ISD::SMUL_LOHI, VT, Expand);
993     setOperationAction(ISD::MULHS, VT, Expand);
994     setOperationAction(ISD::UMUL_LOHI, VT, Expand);
995     setOperationAction(ISD::MULHU, VT, Expand);
996     setOperationAction(ISD::SDIVREM, VT, Expand);
997     setOperationAction(ISD::UDIVREM, VT, Expand);
998     setOperationAction(ISD::CTPOP, VT, Expand);
999     setOperationAction(ISD::CTTZ, VT, Expand);
1000     setOperationAction(ISD::CTLZ, VT, Expand);
1001     setOperationAction(ISD::ROTL, VT, Expand);
1002     setOperationAction(ISD::ROTR, VT, Expand);
1003     setOperationAction(ISD::BSWAP, VT, Expand);
1004     setOperationAction(ISD::SETCC, VT, Expand);
1005     setOperationAction(ISD::FP_TO_UINT, VT, Expand);
1006     setOperationAction(ISD::FP_TO_SINT, VT, Expand);
1007     setOperationAction(ISD::UINT_TO_FP, VT, Expand);
1008     setOperationAction(ISD::SINT_TO_FP, VT, Expand);
1009     setOperationAction(ISD::SIGN_EXTEND_INREG, VT,Expand);
1010     setOperationAction(ISD::TRUNCATE, VT, Expand);
1011     setOperationAction(ISD::SIGN_EXTEND, VT, Expand);
1012     setOperationAction(ISD::ZERO_EXTEND, VT, Expand);
1013     setOperationAction(ISD::ANY_EXTEND, VT, Expand);
1014     setOperationAction(ISD::SELECT_CC, VT, Expand);
1015     for (MVT InnerVT : MVT::fixedlen_vector_valuetypes()) {
1016       setTruncStoreAction(InnerVT, VT, Expand);
1017 
1018       setLoadExtAction(ISD::SEXTLOAD, InnerVT, VT, Expand);
1019       setLoadExtAction(ISD::ZEXTLOAD, InnerVT, VT, Expand);
1020 
1021       // N.b. ISD::EXTLOAD legality is basically ignored except for i1-like
1022       // types; we have to deal with them whether we ask for Expansion or not.
1023       // Setting Expand causes its own optimisation problems though, so leave
1024       // them legal.
1025       if (VT.getVectorElementType() == MVT::i1)
1026         setLoadExtAction(ISD::EXTLOAD, InnerVT, VT, Expand);
1027 
1028       // EXTLOAD for MVT::f16 vectors is not legal because f16 vectors are
1029       // split/scalarized right now.
1030       if (VT.getVectorElementType() == MVT::f16 ||
1031           VT.getVectorElementType() == MVT::bf16)
1032         setLoadExtAction(ISD::EXTLOAD, InnerVT, VT, Expand);
1033     }
1034   }
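  // Note (illustrative): anything not re-enabled below stays Expand and is
  // scalarized or unrolled.  For example a v4i32 division by a non-constant
  // divisor,
  //
  //   for (int i = 0; i != 4; ++i) r[i] = a[i] / b[i];
  //
  // remains four scalar divides even with AVX2, since no SSE/AVX level
  // provides a vector integer divide instruction.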
1035 
1036   // FIXME: In order to prevent SSE instructions being expanded to MMX ones
1037   // with -msoft-float, disable use of MMX as well.
1038   if (!Subtarget.useSoftFloat() && Subtarget.hasMMX()) {
1039     addRegisterClass(MVT::x86mmx, &X86::VR64RegClass);
1040     // No operations on x86mmx supported, everything uses intrinsics.
1041   }
1042 
1043   if (!Subtarget.useSoftFloat() && Subtarget.hasSSE1()) {
1044     addRegisterClass(MVT::v4f32, Subtarget.hasVLX() ? &X86::VR128XRegClass
1045                                                     : &X86::VR128RegClass);
1046 
1047     setOperationAction(ISD::FMAXIMUM,           MVT::f32, Custom);
1048     setOperationAction(ISD::FMINIMUM,           MVT::f32, Custom);
1049 
1050     setOperationAction(ISD::FNEG,               MVT::v4f32, Custom);
1051     setOperationAction(ISD::FABS,               MVT::v4f32, Custom);
1052     setOperationAction(ISD::FCOPYSIGN,          MVT::v4f32, Custom);
1053     setOperationAction(ISD::BUILD_VECTOR,       MVT::v4f32, Custom);
1054     setOperationAction(ISD::VECTOR_SHUFFLE,     MVT::v4f32, Custom);
1055     setOperationAction(ISD::VSELECT,            MVT::v4f32, Custom);
1056     setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v4f32, Custom);
1057     setOperationAction(ISD::SELECT,             MVT::v4f32, Custom);
1058 
1059     setOperationAction(ISD::LOAD,               MVT::v2f32, Custom);
1060     setOperationAction(ISD::STORE,              MVT::v2f32, Custom);
1061 
1062     setOperationAction(ISD::STRICT_FADD,        MVT::v4f32, Legal);
1063     setOperationAction(ISD::STRICT_FSUB,        MVT::v4f32, Legal);
1064     setOperationAction(ISD::STRICT_FMUL,        MVT::v4f32, Legal);
1065     setOperationAction(ISD::STRICT_FDIV,        MVT::v4f32, Legal);
1066     setOperationAction(ISD::STRICT_FSQRT,       MVT::v4f32, Legal);
1067   }
1068 
1069   if (!Subtarget.useSoftFloat() && Subtarget.hasSSE2()) {
1070     addRegisterClass(MVT::v2f64, Subtarget.hasVLX() ? &X86::VR128XRegClass
1071                                                     : &X86::VR128RegClass);
1072 
1073     // FIXME: Unfortunately, -soft-float and -no-implicit-float mean XMM
1074     // registers cannot be used even for integer operations.
1075     addRegisterClass(MVT::v16i8, Subtarget.hasVLX() ? &X86::VR128XRegClass
1076                                                     : &X86::VR128RegClass);
1077     addRegisterClass(MVT::v8i16, Subtarget.hasVLX() ? &X86::VR128XRegClass
1078                                                     : &X86::VR128RegClass);
1079     addRegisterClass(MVT::v8f16, Subtarget.hasVLX() ? &X86::VR128XRegClass
1080                                                     : &X86::VR128RegClass);
1081     addRegisterClass(MVT::v4i32, Subtarget.hasVLX() ? &X86::VR128XRegClass
1082                                                     : &X86::VR128RegClass);
1083     addRegisterClass(MVT::v2i64, Subtarget.hasVLX() ? &X86::VR128XRegClass
1084                                                     : &X86::VR128RegClass);
1085 
1086     for (auto VT : { MVT::f64, MVT::v4f32, MVT::v2f64 }) {
1087       setOperationAction(ISD::FMAXIMUM, VT, Custom);
1088       setOperationAction(ISD::FMINIMUM, VT, Custom);
1089     }
1090 
1091     for (auto VT : { MVT::v2i8, MVT::v4i8, MVT::v8i8,
1092                      MVT::v2i16, MVT::v4i16, MVT::v2i32 }) {
1093       setOperationAction(ISD::SDIV, VT, Custom);
1094       setOperationAction(ISD::SREM, VT, Custom);
1095       setOperationAction(ISD::UDIV, VT, Custom);
1096       setOperationAction(ISD::UREM, VT, Custom);
1097     }
1098 
1099     setOperationAction(ISD::MUL,                MVT::v2i8,  Custom);
1100     setOperationAction(ISD::MUL,                MVT::v4i8,  Custom);
1101     setOperationAction(ISD::MUL,                MVT::v8i8,  Custom);
1102 
1103     setOperationAction(ISD::MUL,                MVT::v16i8, Custom);
1104     setOperationAction(ISD::MUL,                MVT::v4i32, Custom);
1105     setOperationAction(ISD::MUL,                MVT::v2i64, Custom);
1106     setOperationAction(ISD::MULHU,              MVT::v4i32, Custom);
1107     setOperationAction(ISD::MULHS,              MVT::v4i32, Custom);
1108     setOperationAction(ISD::MULHU,              MVT::v16i8, Custom);
1109     setOperationAction(ISD::MULHS,              MVT::v16i8, Custom);
1110     setOperationAction(ISD::MULHU,              MVT::v8i16, Legal);
1111     setOperationAction(ISD::MULHS,              MVT::v8i16, Legal);
1112     setOperationAction(ISD::MUL,                MVT::v8i16, Legal);
1113     setOperationAction(ISD::AVGCEILU,           MVT::v16i8, Legal);
1114     setOperationAction(ISD::AVGCEILU,           MVT::v8i16, Legal);
1115 
1116     setOperationAction(ISD::SMULO,              MVT::v16i8, Custom);
1117     setOperationAction(ISD::UMULO,              MVT::v16i8, Custom);
1118     setOperationAction(ISD::UMULO,              MVT::v2i32, Custom);
1119 
1120     setOperationAction(ISD::FNEG,               MVT::v2f64, Custom);
1121     setOperationAction(ISD::FABS,               MVT::v2f64, Custom);
1122     setOperationAction(ISD::FCOPYSIGN,          MVT::v2f64, Custom);
1123 
1124     setOperationAction(ISD::LRINT, MVT::v4f32, Custom);
1125 
1126     for (auto VT : { MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v2i64 }) {
1127       setOperationAction(ISD::SMAX, VT, VT == MVT::v8i16 ? Legal : Custom);
1128       setOperationAction(ISD::SMIN, VT, VT == MVT::v8i16 ? Legal : Custom);
1129       setOperationAction(ISD::UMAX, VT, VT == MVT::v16i8 ? Legal : Custom);
1130       setOperationAction(ISD::UMIN, VT, VT == MVT::v16i8 ? Legal : Custom);
1131     }
1132 
1133     setOperationAction(ISD::UADDSAT,            MVT::v16i8, Legal);
1134     setOperationAction(ISD::SADDSAT,            MVT::v16i8, Legal);
1135     setOperationAction(ISD::USUBSAT,            MVT::v16i8, Legal);
1136     setOperationAction(ISD::SSUBSAT,            MVT::v16i8, Legal);
1137     setOperationAction(ISD::UADDSAT,            MVT::v8i16, Legal);
1138     setOperationAction(ISD::SADDSAT,            MVT::v8i16, Legal);
1139     setOperationAction(ISD::USUBSAT,            MVT::v8i16, Legal);
1140     setOperationAction(ISD::SSUBSAT,            MVT::v8i16, Legal);
1141     setOperationAction(ISD::USUBSAT,            MVT::v4i32, Custom);
1142     setOperationAction(ISD::USUBSAT,            MVT::v2i64, Custom);
1143 
1144     setOperationAction(ISD::INSERT_VECTOR_ELT,  MVT::v16i8, Custom);
1145     setOperationAction(ISD::INSERT_VECTOR_ELT,  MVT::v8i16, Custom);
1146     setOperationAction(ISD::INSERT_VECTOR_ELT,  MVT::v4i32, Custom);
1147     setOperationAction(ISD::INSERT_VECTOR_ELT,  MVT::v4f32, Custom);
1148 
1149     for (auto VT : { MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v2i64 }) {
1150       setOperationAction(ISD::SETCC, VT, Custom);
1151       setOperationAction(ISD::CTPOP, VT, Custom);
1152       setOperationAction(ISD::ABS, VT, Custom);
1153       setOperationAction(ISD::ABDS, VT, Custom);
1154       setOperationAction(ISD::ABDU, VT, Custom);
1155 
1156       // The condition codes aren't legal in SSE/AVX and under AVX512 we use
1157       // setcc all the way to isel and prefer SETGT in some isel patterns.
1158       setCondCodeAction(ISD::SETLT, VT, Custom);
1159       setCondCodeAction(ISD::SETLE, VT, Custom);
1160     }
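         // For example, an integer vector (setcc X, Y, setlt) is typically
         // rewritten with swapped operands as (setcc Y, X, setgt) so that a
         // single PCMPGT-style pattern can match (see the vector SETCC
         // lowering later in this file).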
1161 
1162     setOperationAction(ISD::SETCC,          MVT::v2f64, Custom);
1163     setOperationAction(ISD::SETCC,          MVT::v4f32, Custom);
1164     setOperationAction(ISD::STRICT_FSETCC,  MVT::v2f64, Custom);
1165     setOperationAction(ISD::STRICT_FSETCC,  MVT::v4f32, Custom);
1166     setOperationAction(ISD::STRICT_FSETCCS, MVT::v2f64, Custom);
1167     setOperationAction(ISD::STRICT_FSETCCS, MVT::v4f32, Custom);
1168 
1169     for (auto VT : { MVT::v16i8, MVT::v8i16, MVT::v4i32 }) {
1170       setOperationAction(ISD::SCALAR_TO_VECTOR,   VT, Custom);
1171       setOperationAction(ISD::BUILD_VECTOR,       VT, Custom);
1172       setOperationAction(ISD::VECTOR_SHUFFLE,     VT, Custom);
1173       setOperationAction(ISD::VSELECT,            VT, Custom);
1174       setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);
1175     }
1176 
1177     for (auto VT : { MVT::v8f16, MVT::v2f64, MVT::v2i64 }) {
1178       setOperationAction(ISD::BUILD_VECTOR,       VT, Custom);
1179       setOperationAction(ISD::VECTOR_SHUFFLE,     VT, Custom);
1180       setOperationAction(ISD::VSELECT,            VT, Custom);
1181 
1182       if (VT == MVT::v2i64 && !Subtarget.is64Bit())
1183         continue;
1184 
1185       setOperationAction(ISD::INSERT_VECTOR_ELT,  VT, Custom);
1186       setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);
1187     }
1188     setF16Action(MVT::v8f16, Expand);
1189     setOperationAction(ISD::FADD, MVT::v8f16, Expand);
1190     setOperationAction(ISD::FSUB, MVT::v8f16, Expand);
1191     setOperationAction(ISD::FMUL, MVT::v8f16, Expand);
1192     setOperationAction(ISD::FDIV, MVT::v8f16, Expand);
1193     setOperationAction(ISD::FNEG, MVT::v8f16, Custom);
1194     setOperationAction(ISD::FABS, MVT::v8f16, Custom);
1195     setOperationAction(ISD::FCOPYSIGN, MVT::v8f16, Custom);
1196 
1197     // Custom lower v2i64 and v2f64 selects.
1198     setOperationAction(ISD::SELECT,             MVT::v2f64, Custom);
1199     setOperationAction(ISD::SELECT,             MVT::v2i64, Custom);
1200     setOperationAction(ISD::SELECT,             MVT::v4i32, Custom);
1201     setOperationAction(ISD::SELECT,             MVT::v8i16, Custom);
1202     setOperationAction(ISD::SELECT,             MVT::v8f16, Custom);
1203     setOperationAction(ISD::SELECT,             MVT::v16i8, Custom);
1204 
1205     setOperationAction(ISD::FP_TO_SINT,         MVT::v4i32, Custom);
1206     setOperationAction(ISD::FP_TO_UINT,         MVT::v4i32, Custom);
1207     setOperationAction(ISD::FP_TO_SINT,         MVT::v2i32, Custom);
1208     setOperationAction(ISD::FP_TO_UINT,         MVT::v2i32, Custom);
1209     setOperationAction(ISD::STRICT_FP_TO_SINT,  MVT::v4i32, Custom);
1210     setOperationAction(ISD::STRICT_FP_TO_SINT,  MVT::v2i32, Custom);
1211 
1212     // Custom legalize these to avoid over promotion or custom promotion.
1213     for (auto VT : {MVT::v2i8, MVT::v4i8, MVT::v8i8, MVT::v2i16, MVT::v4i16}) {
1214       setOperationAction(ISD::FP_TO_SINT,        VT, Custom);
1215       setOperationAction(ISD::FP_TO_UINT,        VT, Custom);
1216       setOperationAction(ISD::STRICT_FP_TO_SINT, VT, Custom);
1217       setOperationAction(ISD::STRICT_FP_TO_UINT, VT, Custom);
1218     }
1219 
1220     setOperationAction(ISD::SINT_TO_FP,         MVT::v4i32, Custom);
1221     setOperationAction(ISD::STRICT_SINT_TO_FP,  MVT::v4i32, Custom);
1222     setOperationAction(ISD::SINT_TO_FP,         MVT::v2i32, Custom);
1223     setOperationAction(ISD::STRICT_SINT_TO_FP,  MVT::v2i32, Custom);
1224 
1225     setOperationAction(ISD::UINT_TO_FP,         MVT::v2i32, Custom);
1226     setOperationAction(ISD::STRICT_UINT_TO_FP,  MVT::v2i32, Custom);
1227 
1228     setOperationAction(ISD::UINT_TO_FP,         MVT::v4i32, Custom);
1229     setOperationAction(ISD::STRICT_UINT_TO_FP,  MVT::v4i32, Custom);
1230 
1231     // Fast v2f32 UINT_TO_FP( v2i32 ) custom conversion.
1232     setOperationAction(ISD::SINT_TO_FP,         MVT::v2f32, Custom);
1233     setOperationAction(ISD::STRICT_SINT_TO_FP,  MVT::v2f32, Custom);
1234     setOperationAction(ISD::UINT_TO_FP,         MVT::v2f32, Custom);
1235     setOperationAction(ISD::STRICT_UINT_TO_FP,  MVT::v2f32, Custom);
1236 
1237     setOperationAction(ISD::FP_EXTEND,          MVT::v2f32, Custom);
1238     setOperationAction(ISD::STRICT_FP_EXTEND,   MVT::v2f32, Custom);
1239     setOperationAction(ISD::FP_ROUND,           MVT::v2f32, Custom);
1240     setOperationAction(ISD::STRICT_FP_ROUND,    MVT::v2f32, Custom);
1241 
1242     // We want to legalize this to an f64 load rather than an i64 load on
1243     // 64-bit targets and two 32-bit loads on a 32-bit target. Similar for
1244     // store.
1245     setOperationAction(ISD::LOAD,               MVT::v2i32, Custom);
1246     setOperationAction(ISD::LOAD,               MVT::v4i16, Custom);
1247     setOperationAction(ISD::LOAD,               MVT::v8i8,  Custom);
1248     setOperationAction(ISD::STORE,              MVT::v2i32, Custom);
1249     setOperationAction(ISD::STORE,              MVT::v4i16, Custom);
1250     setOperationAction(ISD::STORE,              MVT::v8i8,  Custom);
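         // For example, on 64-bit targets a (v2i32 (load p)) can then be
         // selected as one 64-bit MOVQ/f64-style load feeding the vector
         // domain rather than an i64 GPR load (a sketch of the intent above).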
1251 
1252     // Add 32-bit vector stores to help vectorization opportunities.
1253     setOperationAction(ISD::STORE,              MVT::v2i16, Custom);
1254     setOperationAction(ISD::STORE,              MVT::v4i8,  Custom);
1255 
1256     setOperationAction(ISD::BITCAST,            MVT::v2i32, Custom);
1257     setOperationAction(ISD::BITCAST,            MVT::v4i16, Custom);
1258     setOperationAction(ISD::BITCAST,            MVT::v8i8,  Custom);
1259     if (!Subtarget.hasAVX512())
1260       setOperationAction(ISD::BITCAST, MVT::v16i1, Custom);
1261 
1262     setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, MVT::v2i64, Custom);
1263     setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, MVT::v4i32, Custom);
1264     setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, MVT::v8i16, Custom);
1265 
1266     setOperationAction(ISD::SIGN_EXTEND, MVT::v4i64, Custom);
1267 
1268     setOperationAction(ISD::TRUNCATE,    MVT::v2i8,  Custom);
1269     setOperationAction(ISD::TRUNCATE,    MVT::v2i16, Custom);
1270     setOperationAction(ISD::TRUNCATE,    MVT::v2i32, Custom);
1271     setOperationAction(ISD::TRUNCATE,    MVT::v2i64, Custom);
1272     setOperationAction(ISD::TRUNCATE,    MVT::v4i8,  Custom);
1273     setOperationAction(ISD::TRUNCATE,    MVT::v4i16, Custom);
1274     setOperationAction(ISD::TRUNCATE,    MVT::v4i32, Custom);
1275     setOperationAction(ISD::TRUNCATE,    MVT::v4i64, Custom);
1276     setOperationAction(ISD::TRUNCATE,    MVT::v8i8,  Custom);
1277     setOperationAction(ISD::TRUNCATE,    MVT::v8i16, Custom);
1278     setOperationAction(ISD::TRUNCATE,    MVT::v8i32, Custom);
1279     setOperationAction(ISD::TRUNCATE,    MVT::v8i64, Custom);
1280     setOperationAction(ISD::TRUNCATE,    MVT::v16i8, Custom);
1281     setOperationAction(ISD::TRUNCATE,    MVT::v16i16, Custom);
1282     setOperationAction(ISD::TRUNCATE,    MVT::v16i32, Custom);
1283     setOperationAction(ISD::TRUNCATE,    MVT::v16i64, Custom);
1284 
1285     // In the customized shift lowering, the v4i32/v2i64 cases that are
1286     // legal with AVX2 will be recognized.
1287     for (auto VT : { MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v2i64 }) {
1288       setOperationAction(ISD::SRL,              VT, Custom);
1289       setOperationAction(ISD::SHL,              VT, Custom);
1290       setOperationAction(ISD::SRA,              VT, Custom);
1291       if (VT == MVT::v2i64) continue;
1292       setOperationAction(ISD::ROTL,             VT, Custom);
1293       setOperationAction(ISD::ROTR,             VT, Custom);
1294       setOperationAction(ISD::FSHL,             VT, Custom);
1295       setOperationAction(ISD::FSHR,             VT, Custom);
1296     }
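         // For example, with AVX2 a per-element variable (shl v4i32 X, Y) can
         // be matched directly to VPSLLVD (and v2i64 to VPSLLVQ); the Custom
         // hook mainly handles the remaining pre-AVX2 and non-uniform cases.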
1297 
1298     setOperationAction(ISD::STRICT_FSQRT,       MVT::v2f64, Legal);
1299     setOperationAction(ISD::STRICT_FADD,        MVT::v2f64, Legal);
1300     setOperationAction(ISD::STRICT_FSUB,        MVT::v2f64, Legal);
1301     setOperationAction(ISD::STRICT_FMUL,        MVT::v2f64, Legal);
1302     setOperationAction(ISD::STRICT_FDIV,        MVT::v2f64, Legal);
1303   }
1304 
1305   if (Subtarget.hasGFNI()) {
1306     setOperationAction(ISD::BITREVERSE, MVT::i8, Custom);
1307     setOperationAction(ISD::BITREVERSE, MVT::i16, Custom);
1308     setOperationAction(ISD::BITREVERSE, MVT::i32, Custom);
1309     setOperationAction(ISD::BITREVERSE, MVT::i64, Custom);
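         // GF2P8AFFINEQB with a suitable constant matrix reverses the bits
         // within each byte, so these scalar BITREVERSEs can be lowered by
         // bouncing through a vector register (a rough sketch of the idea).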
1310   }
1311 
1312   if (!Subtarget.useSoftFloat() && Subtarget.hasSSSE3()) {
1313     setOperationAction(ISD::ABS,                MVT::v16i8, Legal);
1314     setOperationAction(ISD::ABS,                MVT::v8i16, Legal);
1315     setOperationAction(ISD::ABS,                MVT::v4i32, Legal);
1316 
1317     for (auto VT : {MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v2i64}) {
1318       setOperationAction(ISD::BITREVERSE,       VT, Custom);
1319       setOperationAction(ISD::CTLZ,             VT, Custom);
1320     }
1321 
1322     // These might be better off as horizontal vector ops.
1323     setOperationAction(ISD::ADD,                MVT::i16, Custom);
1324     setOperationAction(ISD::ADD,                MVT::i32, Custom);
1325     setOperationAction(ISD::SUB,                MVT::i16, Custom);
1326     setOperationAction(ISD::SUB,                MVT::i32, Custom);
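         // E.g. an i32 add of two extracted vector elements may be better
         // formed as an SSSE3 horizontal op (PHADDW/PHADDD, PHSUBW/PHSUBD for
         // subtracts); marking the scalar ops Custom lets lowering consider
         // that.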
1327   }
1328 
1329   if (!Subtarget.useSoftFloat() && Subtarget.hasSSE41()) {
1330     for (MVT RoundedTy : {MVT::f32, MVT::f64, MVT::v4f32, MVT::v2f64}) {
1331       setOperationAction(ISD::FFLOOR,            RoundedTy,  Legal);
1332       setOperationAction(ISD::STRICT_FFLOOR,     RoundedTy,  Legal);
1333       setOperationAction(ISD::FCEIL,             RoundedTy,  Legal);
1334       setOperationAction(ISD::STRICT_FCEIL,      RoundedTy,  Legal);
1335       setOperationAction(ISD::FTRUNC,            RoundedTy,  Legal);
1336       setOperationAction(ISD::STRICT_FTRUNC,     RoundedTy,  Legal);
1337       setOperationAction(ISD::FRINT,             RoundedTy,  Legal);
1338       setOperationAction(ISD::STRICT_FRINT,      RoundedTy,  Legal);
1339       setOperationAction(ISD::FNEARBYINT,        RoundedTy,  Legal);
1340       setOperationAction(ISD::STRICT_FNEARBYINT, RoundedTy,  Legal);
1341       setOperationAction(ISD::FROUNDEVEN,        RoundedTy,  Legal);
1342       setOperationAction(ISD::STRICT_FROUNDEVEN, RoundedTy,  Legal);
1343 
1344       setOperationAction(ISD::FROUND,            RoundedTy,  Custom);
1345     }
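         // FROUND (round halfway cases away from zero) has no direct
         // ROUNDPS/ROUNDPD immediate encoding, unlike the modes marked Legal
         // above, which is why it stays Custom.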
1346 
1347     setOperationAction(ISD::SMAX,               MVT::v16i8, Legal);
1348     setOperationAction(ISD::SMAX,               MVT::v4i32, Legal);
1349     setOperationAction(ISD::UMAX,               MVT::v8i16, Legal);
1350     setOperationAction(ISD::UMAX,               MVT::v4i32, Legal);
1351     setOperationAction(ISD::SMIN,               MVT::v16i8, Legal);
1352     setOperationAction(ISD::SMIN,               MVT::v4i32, Legal);
1353     setOperationAction(ISD::UMIN,               MVT::v8i16, Legal);
1354     setOperationAction(ISD::UMIN,               MVT::v4i32, Legal);
1355 
1356     setOperationAction(ISD::UADDSAT,            MVT::v4i32, Custom);
1357     setOperationAction(ISD::SADDSAT,            MVT::v2i64, Custom);
1358     setOperationAction(ISD::SSUBSAT,            MVT::v2i64, Custom);
1359 
1360     // FIXME: Do we need to handle scalar-to-vector here?
1361     setOperationAction(ISD::MUL,                MVT::v4i32, Legal);
1362     setOperationAction(ISD::SMULO,              MVT::v2i32, Custom);
1363 
1364     // We directly match byte blends in the backend as they match the VSELECT
1365     // condition form.
1366     setOperationAction(ISD::VSELECT,            MVT::v16i8, Legal);
1367 
1368     // SSE41 brings specific instructions for doing vector sign extend even in
1369     // cases where we don't have SRA.
1370     for (auto VT : { MVT::v8i16, MVT::v4i32, MVT::v2i64 }) {
1371       setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, VT, Legal);
1372       setOperationAction(ISD::ZERO_EXTEND_VECTOR_INREG, VT, Legal);
1373     }
1374 
1375     // SSE41 also has vector sign/zero extending loads, PMOV[SZ]X
1376     for (auto LoadExtOp : { ISD::SEXTLOAD, ISD::ZEXTLOAD }) {
1377       setLoadExtAction(LoadExtOp, MVT::v8i16, MVT::v8i8,  Legal);
1378       setLoadExtAction(LoadExtOp, MVT::v4i32, MVT::v4i8,  Legal);
1379       setLoadExtAction(LoadExtOp, MVT::v2i64, MVT::v2i8,  Legal);
1380       setLoadExtAction(LoadExtOp, MVT::v4i32, MVT::v4i16, Legal);
1381       setLoadExtAction(LoadExtOp, MVT::v2i64, MVT::v2i16, Legal);
1382       setLoadExtAction(LoadExtOp, MVT::v2i64, MVT::v2i32, Legal);
1383     }
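         // E.g. a sign-extending load from <4 x i8> to <4 x i32> can then be
         // selected as a single PMOVSXBD with a memory operand.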
1384 
1385     if (Subtarget.is64Bit() && !Subtarget.hasAVX512()) {
1386       // We need to scalarize v4i64->v4f32 uint_to_fp using cvtsi2ss, but we can
1387       // do the pre and post work in the vector domain.
1388       setOperationAction(ISD::UINT_TO_FP,        MVT::v4i64, Custom);
1389       setOperationAction(ISD::STRICT_UINT_TO_FP, MVT::v4i64, Custom);
1390       // We need to mark SINT_TO_FP as Custom even though we want to expand it
1391       // so that DAG combine doesn't try to turn it into uint_to_fp.
1392       setOperationAction(ISD::SINT_TO_FP,        MVT::v4i64, Custom);
1393       setOperationAction(ISD::STRICT_SINT_TO_FP, MVT::v4i64, Custom);
1394     }
1395   }
1396 
1397   if (!Subtarget.useSoftFloat() && Subtarget.hasSSE42()) {
1398     setOperationAction(ISD::UADDSAT,            MVT::v2i64, Custom);
1399   }
1400 
1401   if (!Subtarget.useSoftFloat() && Subtarget.hasXOP()) {
1402     for (auto VT : { MVT::v16i8, MVT::v8i16,  MVT::v4i32, MVT::v2i64,
1403                      MVT::v32i8, MVT::v16i16, MVT::v8i32, MVT::v4i64 }) {
1404       setOperationAction(ISD::ROTL, VT, Custom);
1405       setOperationAction(ISD::ROTR, VT, Custom);
1406     }
1407 
1408     // XOP can efficiently perform BITREVERSE with VPPERM.
1409     for (auto VT : { MVT::i8, MVT::i16, MVT::i32, MVT::i64 })
1410       setOperationAction(ISD::BITREVERSE, VT, Custom);
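         // (VPPERM's per-byte operation field includes a bit-reversed source
         // mode, which is what makes the byte-granular reverse cheap here.)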
1411   }
1412 
1413   if (!Subtarget.useSoftFloat() && Subtarget.hasAVX()) {
1414     bool HasInt256 = Subtarget.hasInt256();
1415 
1416     addRegisterClass(MVT::v32i8,  Subtarget.hasVLX() ? &X86::VR256XRegClass
1417                                                      : &X86::VR256RegClass);
1418     addRegisterClass(MVT::v16i16, Subtarget.hasVLX() ? &X86::VR256XRegClass
1419                                                      : &X86::VR256RegClass);
1420     addRegisterClass(MVT::v16f16, Subtarget.hasVLX() ? &X86::VR256XRegClass
1421                                                      : &X86::VR256RegClass);
1422     addRegisterClass(MVT::v8i32,  Subtarget.hasVLX() ? &X86::VR256XRegClass
1423                                                      : &X86::VR256RegClass);
1424     addRegisterClass(MVT::v8f32,  Subtarget.hasVLX() ? &X86::VR256XRegClass
1425                                                      : &X86::VR256RegClass);
1426     addRegisterClass(MVT::v4i64,  Subtarget.hasVLX() ? &X86::VR256XRegClass
1427                                                      : &X86::VR256RegClass);
1428     addRegisterClass(MVT::v4f64,  Subtarget.hasVLX() ? &X86::VR256XRegClass
1429                                                      : &X86::VR256RegClass);
1430 
1431     for (auto VT : { MVT::v8f32, MVT::v4f64 }) {
1432       setOperationAction(ISD::FFLOOR,            VT, Legal);
1433       setOperationAction(ISD::STRICT_FFLOOR,     VT, Legal);
1434       setOperationAction(ISD::FCEIL,             VT, Legal);
1435       setOperationAction(ISD::STRICT_FCEIL,      VT, Legal);
1436       setOperationAction(ISD::FTRUNC,            VT, Legal);
1437       setOperationAction(ISD::STRICT_FTRUNC,     VT, Legal);
1438       setOperationAction(ISD::FRINT,             VT, Legal);
1439       setOperationAction(ISD::STRICT_FRINT,      VT, Legal);
1440       setOperationAction(ISD::FNEARBYINT,        VT, Legal);
1441       setOperationAction(ISD::STRICT_FNEARBYINT, VT, Legal);
1442       setOperationAction(ISD::FROUNDEVEN,        VT, Legal);
1443       setOperationAction(ISD::STRICT_FROUNDEVEN, VT, Legal);
1444 
1445       setOperationAction(ISD::FROUND,            VT, Custom);
1446 
1447       setOperationAction(ISD::FNEG,              VT, Custom);
1448       setOperationAction(ISD::FABS,              VT, Custom);
1449       setOperationAction(ISD::FCOPYSIGN,         VT, Custom);
1450 
1451       setOperationAction(ISD::FMAXIMUM,          VT, Custom);
1452       setOperationAction(ISD::FMINIMUM,          VT, Custom);
1453     }
1454 
1455     setOperationAction(ISD::LRINT, MVT::v8f32, Custom);
1456     setOperationAction(ISD::LRINT, MVT::v4f64, Custom);
1457 
1458     // (fp_to_int:v8i16 (v8f32 ..)) requires the result type to be promoted
1459     // even though v8i16 is a legal type.
1460     setOperationPromotedToType(ISD::FP_TO_SINT,        MVT::v8i16, MVT::v8i32);
1461     setOperationPromotedToType(ISD::FP_TO_UINT,        MVT::v8i16, MVT::v8i32);
1462     setOperationPromotedToType(ISD::STRICT_FP_TO_SINT, MVT::v8i16, MVT::v8i32);
1463     setOperationPromotedToType(ISD::STRICT_FP_TO_UINT, MVT::v8i16, MVT::v8i32);
1464     setOperationAction(ISD::FP_TO_SINT,                MVT::v8i32, Custom);
1465     setOperationAction(ISD::FP_TO_UINT,                MVT::v8i32, Custom);
1466     setOperationAction(ISD::STRICT_FP_TO_SINT,         MVT::v8i32, Custom);
1467 
1468     setOperationAction(ISD::SINT_TO_FP,         MVT::v8i32, Custom);
1469     setOperationAction(ISD::STRICT_SINT_TO_FP,  MVT::v8i32, Custom);
1470     setOperationAction(ISD::FP_EXTEND,          MVT::v8f32, Expand);
1471     setOperationAction(ISD::FP_ROUND,           MVT::v8f16, Expand);
1472     setOperationAction(ISD::FP_EXTEND,          MVT::v4f64, Custom);
1473     setOperationAction(ISD::STRICT_FP_EXTEND,   MVT::v4f64, Custom);
1474 
1475     setOperationAction(ISD::STRICT_FP_ROUND,    MVT::v4f32, Legal);
1476     setOperationAction(ISD::STRICT_FADD,        MVT::v8f32, Legal);
1477     setOperationAction(ISD::STRICT_FADD,        MVT::v4f64, Legal);
1478     setOperationAction(ISD::STRICT_FSUB,        MVT::v8f32, Legal);
1479     setOperationAction(ISD::STRICT_FSUB,        MVT::v4f64, Legal);
1480     setOperationAction(ISD::STRICT_FMUL,        MVT::v8f32, Legal);
1481     setOperationAction(ISD::STRICT_FMUL,        MVT::v4f64, Legal);
1482     setOperationAction(ISD::STRICT_FDIV,        MVT::v8f32, Legal);
1483     setOperationAction(ISD::STRICT_FDIV,        MVT::v4f64, Legal);
1484     setOperationAction(ISD::STRICT_FSQRT,       MVT::v8f32, Legal);
1485     setOperationAction(ISD::STRICT_FSQRT,       MVT::v4f64, Legal);
1486 
1487     if (!Subtarget.hasAVX512())
1488       setOperationAction(ISD::BITCAST, MVT::v32i1, Custom);
1489 
1490     // In the customized shift lowering, the v8i32/v4i64 cases that are
1491     // legal with AVX2 will be recognized.
1492     for (auto VT : { MVT::v32i8, MVT::v16i16, MVT::v8i32, MVT::v4i64 }) {
1493       setOperationAction(ISD::SRL,             VT, Custom);
1494       setOperationAction(ISD::SHL,             VT, Custom);
1495       setOperationAction(ISD::SRA,             VT, Custom);
1496       setOperationAction(ISD::ABDS,            VT, Custom);
1497       setOperationAction(ISD::ABDU,            VT, Custom);
1498       if (VT == MVT::v4i64) continue;
1499       setOperationAction(ISD::ROTL,            VT, Custom);
1500       setOperationAction(ISD::ROTR,            VT, Custom);
1501       setOperationAction(ISD::FSHL,            VT, Custom);
1502       setOperationAction(ISD::FSHR,            VT, Custom);
1503     }
1504 
1505     // These types need custom splitting if their input is a 128-bit vector.
1506     setOperationAction(ISD::SIGN_EXTEND,       MVT::v8i64,  Custom);
1507     setOperationAction(ISD::SIGN_EXTEND,       MVT::v16i32, Custom);
1508     setOperationAction(ISD::ZERO_EXTEND,       MVT::v8i64,  Custom);
1509     setOperationAction(ISD::ZERO_EXTEND,       MVT::v16i32, Custom);
1510 
1511     setOperationAction(ISD::SELECT,            MVT::v4f64, Custom);
1512     setOperationAction(ISD::SELECT,            MVT::v4i64, Custom);
1513     setOperationAction(ISD::SELECT,            MVT::v8i32, Custom);
1514     setOperationAction(ISD::SELECT,            MVT::v16i16, Custom);
1515     setOperationAction(ISD::SELECT,            MVT::v16f16, Custom);
1516     setOperationAction(ISD::SELECT,            MVT::v32i8, Custom);
1517     setOperationAction(ISD::SELECT,            MVT::v8f32, Custom);
1518 
1519     for (auto VT : { MVT::v16i16, MVT::v8i32, MVT::v4i64 }) {
1520       setOperationAction(ISD::SIGN_EXTEND,     VT, Custom);
1521       setOperationAction(ISD::ZERO_EXTEND,     VT, Custom);
1522       setOperationAction(ISD::ANY_EXTEND,      VT, Custom);
1523     }
1524 
1525     setOperationAction(ISD::TRUNCATE,          MVT::v32i8, Custom);
1526     setOperationAction(ISD::TRUNCATE,          MVT::v32i16, Custom);
1527     setOperationAction(ISD::TRUNCATE,          MVT::v32i32, Custom);
1528     setOperationAction(ISD::TRUNCATE,          MVT::v32i64, Custom);
1529 
1530     for (auto VT : { MVT::v32i8, MVT::v16i16, MVT::v8i32, MVT::v4i64 }) {
1531       setOperationAction(ISD::SETCC,           VT, Custom);
1532       setOperationAction(ISD::CTPOP,           VT, Custom);
1533       setOperationAction(ISD::CTLZ,            VT, Custom);
1534       setOperationAction(ISD::BITREVERSE,      VT, Custom);
1535 
1536       // The condition codes aren't legal in SSE/AVX and under AVX512 we use
1537       // setcc all the way to isel and prefer SETGT in some isel patterns.
1538       setCondCodeAction(ISD::SETLT, VT, Custom);
1539       setCondCodeAction(ISD::SETLE, VT, Custom);
1540     }
1541 
1542     setOperationAction(ISD::SETCC,          MVT::v4f64, Custom);
1543     setOperationAction(ISD::SETCC,          MVT::v8f32, Custom);
1544     setOperationAction(ISD::STRICT_FSETCC,  MVT::v4f64, Custom);
1545     setOperationAction(ISD::STRICT_FSETCC,  MVT::v8f32, Custom);
1546     setOperationAction(ISD::STRICT_FSETCCS, MVT::v4f64, Custom);
1547     setOperationAction(ISD::STRICT_FSETCCS, MVT::v8f32, Custom);
1548 
1549     if (Subtarget.hasAnyFMA()) {
1550       for (auto VT : { MVT::f32, MVT::f64, MVT::v4f32, MVT::v8f32,
1551                        MVT::v2f64, MVT::v4f64 }) {
1552         setOperationAction(ISD::FMA, VT, Legal);
1553         setOperationAction(ISD::STRICT_FMA, VT, Legal);
1554       }
1555     }
1556 
1557     for (auto VT : { MVT::v32i8, MVT::v16i16, MVT::v8i32, MVT::v4i64 }) {
1558       setOperationAction(ISD::ADD, VT, HasInt256 ? Legal : Custom);
1559       setOperationAction(ISD::SUB, VT, HasInt256 ? Legal : Custom);
1560     }
1561 
1562     setOperationAction(ISD::MUL,       MVT::v4i64,  Custom);
1563     setOperationAction(ISD::MUL,       MVT::v8i32,  HasInt256 ? Legal : Custom);
1564     setOperationAction(ISD::MUL,       MVT::v16i16, HasInt256 ? Legal : Custom);
1565     setOperationAction(ISD::MUL,       MVT::v32i8,  Custom);
1566 
1567     setOperationAction(ISD::MULHU,     MVT::v8i32,  Custom);
1568     setOperationAction(ISD::MULHS,     MVT::v8i32,  Custom);
1569     setOperationAction(ISD::MULHU,     MVT::v16i16, HasInt256 ? Legal : Custom);
1570     setOperationAction(ISD::MULHS,     MVT::v16i16, HasInt256 ? Legal : Custom);
1571     setOperationAction(ISD::MULHU,     MVT::v32i8,  Custom);
1572     setOperationAction(ISD::MULHS,     MVT::v32i8,  Custom);
1573     setOperationAction(ISD::AVGCEILU,  MVT::v16i16, HasInt256 ? Legal : Custom);
1574     setOperationAction(ISD::AVGCEILU,  MVT::v32i8,  HasInt256 ? Legal : Custom);
1575 
1576     setOperationAction(ISD::SMULO,     MVT::v32i8, Custom);
1577     setOperationAction(ISD::UMULO,     MVT::v32i8, Custom);
1578 
1579     setOperationAction(ISD::ABS,       MVT::v4i64,  Custom);
1580     setOperationAction(ISD::SMAX,      MVT::v4i64,  Custom);
1581     setOperationAction(ISD::UMAX,      MVT::v4i64,  Custom);
1582     setOperationAction(ISD::SMIN,      MVT::v4i64,  Custom);
1583     setOperationAction(ISD::UMIN,      MVT::v4i64,  Custom);
1584 
1585     setOperationAction(ISD::UADDSAT,   MVT::v32i8,  HasInt256 ? Legal : Custom);
1586     setOperationAction(ISD::SADDSAT,   MVT::v32i8,  HasInt256 ? Legal : Custom);
1587     setOperationAction(ISD::USUBSAT,   MVT::v32i8,  HasInt256 ? Legal : Custom);
1588     setOperationAction(ISD::SSUBSAT,   MVT::v32i8,  HasInt256 ? Legal : Custom);
1589     setOperationAction(ISD::UADDSAT,   MVT::v16i16, HasInt256 ? Legal : Custom);
1590     setOperationAction(ISD::SADDSAT,   MVT::v16i16, HasInt256 ? Legal : Custom);
1591     setOperationAction(ISD::USUBSAT,   MVT::v16i16, HasInt256 ? Legal : Custom);
1592     setOperationAction(ISD::SSUBSAT,   MVT::v16i16, HasInt256 ? Legal : Custom);
1593     setOperationAction(ISD::UADDSAT,   MVT::v8i32, Custom);
1594     setOperationAction(ISD::USUBSAT,   MVT::v8i32, Custom);
1595     setOperationAction(ISD::UADDSAT,   MVT::v4i64, Custom);
1596     setOperationAction(ISD::USUBSAT,   MVT::v4i64, Custom);
1597 
1598     for (auto VT : { MVT::v32i8, MVT::v16i16, MVT::v8i32 }) {
1599       setOperationAction(ISD::ABS,  VT, HasInt256 ? Legal : Custom);
1600       setOperationAction(ISD::SMAX, VT, HasInt256 ? Legal : Custom);
1601       setOperationAction(ISD::UMAX, VT, HasInt256 ? Legal : Custom);
1602       setOperationAction(ISD::SMIN, VT, HasInt256 ? Legal : Custom);
1603       setOperationAction(ISD::UMIN, VT, HasInt256 ? Legal : Custom);
1604     }
1605 
1606     for (auto VT : {MVT::v16i16, MVT::v8i32, MVT::v4i64}) {
1607       setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, VT, Custom);
1608       setOperationAction(ISD::ZERO_EXTEND_VECTOR_INREG, VT, Custom);
1609     }
1610 
1611     if (HasInt256) {
1612       // The custom lowering of UINT_TO_FP for v8i32 becomes interesting
1613       // when we have a 256-bit-wide blend with immediate.
1614       setOperationAction(ISD::UINT_TO_FP, MVT::v8i32, Custom);
1615       setOperationAction(ISD::STRICT_UINT_TO_FP, MVT::v8i32, Custom);
1616 
1617       // AVX2 also has wider vector sign/zero extending loads, VPMOV[SZ]X
1618       for (auto LoadExtOp : { ISD::SEXTLOAD, ISD::ZEXTLOAD }) {
1619         setLoadExtAction(LoadExtOp, MVT::v16i16, MVT::v16i8, Legal);
1620         setLoadExtAction(LoadExtOp, MVT::v8i32,  MVT::v8i8,  Legal);
1621         setLoadExtAction(LoadExtOp, MVT::v4i64,  MVT::v4i8,  Legal);
1622         setLoadExtAction(LoadExtOp, MVT::v8i32,  MVT::v8i16, Legal);
1623         setLoadExtAction(LoadExtOp, MVT::v4i64,  MVT::v4i16, Legal);
1624         setLoadExtAction(LoadExtOp, MVT::v4i64,  MVT::v4i32, Legal);
1625       }
1626     }
1627 
1628     for (auto VT : { MVT::v4i32, MVT::v8i32, MVT::v2i64, MVT::v4i64,
1629                      MVT::v4f32, MVT::v8f32, MVT::v2f64, MVT::v4f64 }) {
1630       setOperationAction(ISD::MLOAD,  VT, Subtarget.hasVLX() ? Legal : Custom);
1631       setOperationAction(ISD::MSTORE, VT, Legal);
1632     }
1633 
1634     // Extract subvector is special because the value type
1635     // (result) is 128-bit but the source is 256-bit wide.
1636     for (auto VT : { MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v2i64,
1637                      MVT::v8f16, MVT::v4f32, MVT::v2f64 }) {
1638       setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Legal);
1639     }
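         // E.g. (extract_subvector (v8i32 Y), 4) is the upper 128-bit half and
         // can be selected as VEXTRACTI128/VEXTRACTF128, while index 0 is just
         // a subregister use; this is roughly why these remain Legal here.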
1640 
1641     // Custom lower several nodes for 256-bit types.
1642     for (MVT VT : { MVT::v32i8, MVT::v16i16, MVT::v8i32, MVT::v4i64,
1643                     MVT::v16f16, MVT::v8f32, MVT::v4f64 }) {
1644       setOperationAction(ISD::BUILD_VECTOR,       VT, Custom);
1645       setOperationAction(ISD::VECTOR_SHUFFLE,     VT, Custom);
1646       setOperationAction(ISD::VSELECT,            VT, Custom);
1647       setOperationAction(ISD::INSERT_VECTOR_ELT,  VT, Custom);
1648       setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);
1649       setOperationAction(ISD::SCALAR_TO_VECTOR,   VT, Custom);
1650       setOperationAction(ISD::INSERT_SUBVECTOR,   VT, Legal);
1651       setOperationAction(ISD::CONCAT_VECTORS,     VT, Custom);
1652       setOperationAction(ISD::STORE,              VT, Custom);
1653     }
1654     setF16Action(MVT::v16f16, Expand);
1655     setOperationAction(ISD::FNEG, MVT::v16f16, Custom);
1656     setOperationAction(ISD::FABS, MVT::v16f16, Custom);
1657     setOperationAction(ISD::FCOPYSIGN, MVT::v16f16, Custom);
1658     setOperationAction(ISD::FADD, MVT::v16f16, Expand);
1659     setOperationAction(ISD::FSUB, MVT::v16f16, Expand);
1660     setOperationAction(ISD::FMUL, MVT::v16f16, Expand);
1661     setOperationAction(ISD::FDIV, MVT::v16f16, Expand);
1662 
1663     if (HasInt256) {
1664       setOperationAction(ISD::VSELECT, MVT::v32i8, Legal);
1665 
1666       // Custom legalize 2x32 to get a little better code.
1667       setOperationAction(ISD::MGATHER, MVT::v2f32, Custom);
1668       setOperationAction(ISD::MGATHER, MVT::v2i32, Custom);
1669 
1670       for (auto VT : { MVT::v4i32, MVT::v8i32, MVT::v2i64, MVT::v4i64,
1671                        MVT::v4f32, MVT::v8f32, MVT::v2f64, MVT::v4f64 })
1672         setOperationAction(ISD::MGATHER,  VT, Custom);
1673     }
1674   }
1675 
1676   if (!Subtarget.useSoftFloat() && !Subtarget.hasFP16() &&
1677       Subtarget.hasF16C()) {
1678     for (MVT VT : { MVT::f16, MVT::v2f16, MVT::v4f16, MVT::v8f16 }) {
1679       setOperationAction(ISD::FP_ROUND,           VT, Custom);
1680       setOperationAction(ISD::STRICT_FP_ROUND,    VT, Custom);
1681     }
1682     for (MVT VT : { MVT::f32, MVT::v2f32, MVT::v4f32, MVT::v8f32 }) {
1683       setOperationAction(ISD::FP_EXTEND,          VT, Custom);
1684       setOperationAction(ISD::STRICT_FP_EXTEND,   VT, Custom);
1685     }
1686     for (unsigned Opc : {ISD::FADD, ISD::FSUB, ISD::FMUL, ISD::FDIV}) {
1687       setOperationPromotedToType(Opc, MVT::v8f16, MVT::v8f32);
1688       setOperationPromotedToType(Opc, MVT::v16f16, MVT::v16f32);
1689     }
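         // With F16C but no AVX512-FP16, f16 vector arithmetic is promoted as
         // set up above: operands are widened to f32 with VCVTPH2PS, the f32
         // operation runs, and the result is narrowed back with VCVTPS2PH.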
1690   }
1691 
1692   // This block controls legalization of the mask vector sizes that are
1693   // available with AVX512. 512-bit vectors are in a separate block controlled
1694   // by useAVX512Regs.
1695   if (!Subtarget.useSoftFloat() && Subtarget.hasAVX512()) {
1696     addRegisterClass(MVT::v1i1,   &X86::VK1RegClass);
1697     addRegisterClass(MVT::v2i1,   &X86::VK2RegClass);
1698     addRegisterClass(MVT::v4i1,   &X86::VK4RegClass);
1699     addRegisterClass(MVT::v8i1,   &X86::VK8RegClass);
1700     addRegisterClass(MVT::v16i1,  &X86::VK16RegClass);
1701 
1702     setOperationAction(ISD::SELECT,             MVT::v1i1, Custom);
1703     setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v1i1, Custom);
1704     setOperationAction(ISD::BUILD_VECTOR,       MVT::v1i1, Custom);
1705 
1706     setOperationPromotedToType(ISD::FP_TO_SINT,        MVT::v8i1,  MVT::v8i32);
1707     setOperationPromotedToType(ISD::FP_TO_UINT,        MVT::v8i1,  MVT::v8i32);
1708     setOperationPromotedToType(ISD::FP_TO_SINT,        MVT::v4i1,  MVT::v4i32);
1709     setOperationPromotedToType(ISD::FP_TO_UINT,        MVT::v4i1,  MVT::v4i32);
1710     setOperationPromotedToType(ISD::STRICT_FP_TO_SINT, MVT::v8i1,  MVT::v8i32);
1711     setOperationPromotedToType(ISD::STRICT_FP_TO_UINT, MVT::v8i1,  MVT::v8i32);
1712     setOperationPromotedToType(ISD::STRICT_FP_TO_SINT, MVT::v4i1,  MVT::v4i32);
1713     setOperationPromotedToType(ISD::STRICT_FP_TO_UINT, MVT::v4i1,  MVT::v4i32);
1714     setOperationAction(ISD::FP_TO_SINT,                MVT::v2i1,  Custom);
1715     setOperationAction(ISD::FP_TO_UINT,                MVT::v2i1,  Custom);
1716     setOperationAction(ISD::STRICT_FP_TO_SINT,         MVT::v2i1,  Custom);
1717     setOperationAction(ISD::STRICT_FP_TO_UINT,         MVT::v2i1,  Custom);
1718 
1719     // There is no byte-sized k-register load or store without AVX512DQ.
1720     if (!Subtarget.hasDQI()) {
1721       setOperationAction(ISD::LOAD, MVT::v1i1, Custom);
1722       setOperationAction(ISD::LOAD, MVT::v2i1, Custom);
1723       setOperationAction(ISD::LOAD, MVT::v4i1, Custom);
1724       setOperationAction(ISD::LOAD, MVT::v8i1, Custom);
1725 
1726       setOperationAction(ISD::STORE, MVT::v1i1, Custom);
1727       setOperationAction(ISD::STORE, MVT::v2i1, Custom);
1728       setOperationAction(ISD::STORE, MVT::v4i1, Custom);
1729       setOperationAction(ISD::STORE, MVT::v8i1, Custom);
1730     }
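         // (Base AVX512F only provides KMOVW for 16-bit mask moves; KMOVB
         // requires AVX512DQ, which is why the narrower mask loads/stores
         // above are Custom without DQI.)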
1731 
1732     // Extends of v16i1/v8i1/v4i1/v2i1 to 128-bit vectors.
1733     for (auto VT : { MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v2i64 }) {
1734       setOperationAction(ISD::SIGN_EXTEND, VT, Custom);
1735       setOperationAction(ISD::ZERO_EXTEND, VT, Custom);
1736       setOperationAction(ISD::ANY_EXTEND,  VT, Custom);
1737     }
1738 
1739     for (auto VT : { MVT::v1i1, MVT::v2i1, MVT::v4i1, MVT::v8i1, MVT::v16i1 })
1740       setOperationAction(ISD::VSELECT,          VT, Expand);
1741 
1742     for (auto VT : { MVT::v2i1, MVT::v4i1, MVT::v8i1, MVT::v16i1 }) {
1743       setOperationAction(ISD::SETCC,            VT, Custom);
1744       setOperationAction(ISD::SELECT,           VT, Custom);
1745       setOperationAction(ISD::TRUNCATE,         VT, Custom);
1746 
1747       setOperationAction(ISD::BUILD_VECTOR,     VT, Custom);
1748       setOperationAction(ISD::CONCAT_VECTORS,   VT, Custom);
1749       setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);
1750       setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
1751       setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom);
1752       setOperationAction(ISD::VECTOR_SHUFFLE,   VT,  Custom);
1753     }
1754 
1755     for (auto VT : { MVT::v1i1, MVT::v2i1, MVT::v4i1, MVT::v8i1 })
1756       setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
1757   }
1758   if (Subtarget.hasDQI() && Subtarget.hasVLX()) {
1759     for (MVT VT : {MVT::v4f32, MVT::v8f32, MVT::v2f64, MVT::v4f64}) {
1760       setOperationAction(ISD::LRINT, VT, Legal);
1761       setOperationAction(ISD::LLRINT, VT, Legal);
1762     }
1763   }
1764 
1765   // This block controls legalization for 512-bit operations with 8/16/32/64-bit
1766   // elements. 512-bit operations can be disabled based on the
1767   // prefer-vector-width and required-vector-width function attributes.
1768   if (!Subtarget.useSoftFloat() && Subtarget.useAVX512Regs()) {
1769     bool HasBWI = Subtarget.hasBWI();
1770 
1771     addRegisterClass(MVT::v16i32, &X86::VR512RegClass);
1772     addRegisterClass(MVT::v16f32, &X86::VR512RegClass);
1773     addRegisterClass(MVT::v8i64,  &X86::VR512RegClass);
1774     addRegisterClass(MVT::v8f64,  &X86::VR512RegClass);
1775     addRegisterClass(MVT::v32i16, &X86::VR512RegClass);
1776     addRegisterClass(MVT::v32f16, &X86::VR512RegClass);
1777     addRegisterClass(MVT::v64i8,  &X86::VR512RegClass);
1778 
1779     for (auto ExtType : {ISD::ZEXTLOAD, ISD::SEXTLOAD}) {
1780       setLoadExtAction(ExtType, MVT::v16i32, MVT::v16i8,  Legal);
1781       setLoadExtAction(ExtType, MVT::v16i32, MVT::v16i16, Legal);
1782       setLoadExtAction(ExtType, MVT::v8i64,  MVT::v8i8,   Legal);
1783       setLoadExtAction(ExtType, MVT::v8i64,  MVT::v8i16,  Legal);
1784       setLoadExtAction(ExtType, MVT::v8i64,  MVT::v8i32,  Legal);
1785       if (HasBWI)
1786         setLoadExtAction(ExtType, MVT::v32i16, MVT::v32i8, Legal);
1787     }
1788 
1789     for (MVT VT : { MVT::v16f32, MVT::v8f64 }) {
1790       setOperationAction(ISD::FMAXIMUM, VT, Custom);
1791       setOperationAction(ISD::FMINIMUM, VT, Custom);
1792       setOperationAction(ISD::FNEG,  VT, Custom);
1793       setOperationAction(ISD::FABS,  VT, Custom);
1794       setOperationAction(ISD::FMA,   VT, Legal);
1795       setOperationAction(ISD::STRICT_FMA, VT, Legal);
1796       setOperationAction(ISD::FCOPYSIGN, VT, Custom);
1797     }
1798     setOperationAction(ISD::LRINT, MVT::v16f32,
1799                        Subtarget.hasDQI() ? Legal : Custom);
1800     setOperationAction(ISD::LRINT, MVT::v8f64,
1801                        Subtarget.hasDQI() ? Legal : Custom);
1802     if (Subtarget.hasDQI())
1803       setOperationAction(ISD::LLRINT, MVT::v8f64, Legal);
1804 
1805     for (MVT VT : { MVT::v16i1, MVT::v16i8 }) {
1806       setOperationPromotedToType(ISD::FP_TO_SINT       , VT, MVT::v16i32);
1807       setOperationPromotedToType(ISD::FP_TO_UINT       , VT, MVT::v16i32);
1808       setOperationPromotedToType(ISD::STRICT_FP_TO_SINT, VT, MVT::v16i32);
1809       setOperationPromotedToType(ISD::STRICT_FP_TO_UINT, VT, MVT::v16i32);
1810     }
1811 
1812     for (MVT VT : { MVT::v16i16, MVT::v16i32 }) {
1813       setOperationAction(ISD::FP_TO_SINT,        VT, Custom);
1814       setOperationAction(ISD::FP_TO_UINT,        VT, Custom);
1815       setOperationAction(ISD::STRICT_FP_TO_SINT, VT, Custom);
1816       setOperationAction(ISD::STRICT_FP_TO_UINT, VT, Custom);
1817     }
1818 
1819     setOperationAction(ISD::SINT_TO_FP,        MVT::v16i32, Custom);
1820     setOperationAction(ISD::UINT_TO_FP,        MVT::v16i32, Custom);
1821     setOperationAction(ISD::STRICT_SINT_TO_FP, MVT::v16i32, Custom);
1822     setOperationAction(ISD::STRICT_UINT_TO_FP, MVT::v16i32, Custom);
1823     setOperationAction(ISD::FP_EXTEND,         MVT::v8f64,  Custom);
1824     setOperationAction(ISD::STRICT_FP_EXTEND,  MVT::v8f64,  Custom);
1825 
1826     setOperationAction(ISD::STRICT_FADD,      MVT::v16f32, Legal);
1827     setOperationAction(ISD::STRICT_FADD,      MVT::v8f64,  Legal);
1828     setOperationAction(ISD::STRICT_FSUB,      MVT::v16f32, Legal);
1829     setOperationAction(ISD::STRICT_FSUB,      MVT::v8f64,  Legal);
1830     setOperationAction(ISD::STRICT_FMUL,      MVT::v16f32, Legal);
1831     setOperationAction(ISD::STRICT_FMUL,      MVT::v8f64,  Legal);
1832     setOperationAction(ISD::STRICT_FDIV,      MVT::v16f32, Legal);
1833     setOperationAction(ISD::STRICT_FDIV,      MVT::v8f64,  Legal);
1834     setOperationAction(ISD::STRICT_FSQRT,     MVT::v16f32, Legal);
1835     setOperationAction(ISD::STRICT_FSQRT,     MVT::v8f64,  Legal);
1836     setOperationAction(ISD::STRICT_FP_ROUND,  MVT::v8f32,  Legal);
1837 
1838     setTruncStoreAction(MVT::v8i64,   MVT::v8i8,   Legal);
1839     setTruncStoreAction(MVT::v8i64,   MVT::v8i16,  Legal);
1840     setTruncStoreAction(MVT::v8i64,   MVT::v8i32,  Legal);
1841     setTruncStoreAction(MVT::v16i32,  MVT::v16i8,  Legal);
1842     setTruncStoreAction(MVT::v16i32,  MVT::v16i16, Legal);
1843     if (HasBWI)
1844       setTruncStoreAction(MVT::v32i16,  MVT::v32i8, Legal);
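         // E.g. a truncating store of v8i64 to v8i8 can be selected as a
         // single VPMOVQB with a memory destination.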
1845 
1846     // With 512-bit vectors and no VLX, we prefer to widen MLOAD/MSTORE
1847     // to 512-bit rather than use the AVX2 instructions so that we can use
1848     // k-masks.
1849     if (!Subtarget.hasVLX()) {
1850       for (auto VT : {MVT::v4i32, MVT::v8i32, MVT::v2i64, MVT::v4i64,
1851            MVT::v4f32, MVT::v8f32, MVT::v2f64, MVT::v4f64}) {
1852         setOperationAction(ISD::MLOAD,  VT, Custom);
1853         setOperationAction(ISD::MSTORE, VT, Custom);
1854       }
1855     }
1856 
1857     setOperationAction(ISD::TRUNCATE,    MVT::v8i32,  Legal);
1858     setOperationAction(ISD::TRUNCATE,    MVT::v16i16, Legal);
1859     setOperationAction(ISD::TRUNCATE,    MVT::v32i8,  HasBWI ? Legal : Custom);
1860     setOperationAction(ISD::ZERO_EXTEND, MVT::v32i16, Custom);
1861     setOperationAction(ISD::ZERO_EXTEND, MVT::v16i32, Custom);
1862     setOperationAction(ISD::ZERO_EXTEND, MVT::v8i64,  Custom);
1863     setOperationAction(ISD::ANY_EXTEND,  MVT::v32i16, Custom);
1864     setOperationAction(ISD::ANY_EXTEND,  MVT::v16i32, Custom);
1865     setOperationAction(ISD::ANY_EXTEND,  MVT::v8i64,  Custom);
1866     setOperationAction(ISD::SIGN_EXTEND, MVT::v32i16, Custom);
1867     setOperationAction(ISD::SIGN_EXTEND, MVT::v16i32, Custom);
1868     setOperationAction(ISD::SIGN_EXTEND, MVT::v8i64,  Custom);
1869 
1870     if (HasBWI) {
1871       // Extends from v64i1 masks to 512-bit vectors.
1872       setOperationAction(ISD::SIGN_EXTEND,        MVT::v64i8, Custom);
1873       setOperationAction(ISD::ZERO_EXTEND,        MVT::v64i8, Custom);
1874       setOperationAction(ISD::ANY_EXTEND,         MVT::v64i8, Custom);
1875     }
1876 
1877     for (auto VT : { MVT::v16f32, MVT::v8f64 }) {
1878       setOperationAction(ISD::FFLOOR,            VT, Legal);
1879       setOperationAction(ISD::STRICT_FFLOOR,     VT, Legal);
1880       setOperationAction(ISD::FCEIL,             VT, Legal);
1881       setOperationAction(ISD::STRICT_FCEIL,      VT, Legal);
1882       setOperationAction(ISD::FTRUNC,            VT, Legal);
1883       setOperationAction(ISD::STRICT_FTRUNC,     VT, Legal);
1884       setOperationAction(ISD::FRINT,             VT, Legal);
1885       setOperationAction(ISD::STRICT_FRINT,      VT, Legal);
1886       setOperationAction(ISD::FNEARBYINT,        VT, Legal);
1887       setOperationAction(ISD::STRICT_FNEARBYINT, VT, Legal);
1888       setOperationAction(ISD::FROUNDEVEN,        VT, Legal);
1889       setOperationAction(ISD::STRICT_FROUNDEVEN, VT, Legal);
1890 
1891       setOperationAction(ISD::FROUND,            VT, Custom);
1892     }
1893 
1894     for (auto VT : {MVT::v32i16, MVT::v16i32, MVT::v8i64}) {
1895       setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, VT, Custom);
1896       setOperationAction(ISD::ZERO_EXTEND_VECTOR_INREG, VT, Custom);
1897     }
1898 
1899     setOperationAction(ISD::ADD, MVT::v32i16, HasBWI ? Legal : Custom);
1900     setOperationAction(ISD::SUB, MVT::v32i16, HasBWI ? Legal : Custom);
1901     setOperationAction(ISD::ADD, MVT::v64i8,  HasBWI ? Legal : Custom);
1902     setOperationAction(ISD::SUB, MVT::v64i8,  HasBWI ? Legal : Custom);
1903 
1904     setOperationAction(ISD::MUL, MVT::v8i64,  Custom);
1905     setOperationAction(ISD::MUL, MVT::v16i32, Legal);
1906     setOperationAction(ISD::MUL, MVT::v32i16, HasBWI ? Legal : Custom);
1907     setOperationAction(ISD::MUL, MVT::v64i8,  Custom);
1908 
1909     setOperationAction(ISD::MULHU, MVT::v16i32, Custom);
1910     setOperationAction(ISD::MULHS, MVT::v16i32, Custom);
1911     setOperationAction(ISD::MULHS, MVT::v32i16, HasBWI ? Legal : Custom);
1912     setOperationAction(ISD::MULHU, MVT::v32i16, HasBWI ? Legal : Custom);
1913     setOperationAction(ISD::MULHS, MVT::v64i8,  Custom);
1914     setOperationAction(ISD::MULHU, MVT::v64i8,  Custom);
1915     setOperationAction(ISD::AVGCEILU, MVT::v32i16, HasBWI ? Legal : Custom);
1916     setOperationAction(ISD::AVGCEILU, MVT::v64i8,  HasBWI ? Legal : Custom);
1917 
1918     setOperationAction(ISD::SMULO, MVT::v64i8, Custom);
1919     setOperationAction(ISD::UMULO, MVT::v64i8, Custom);
1920 
1921     for (auto VT : { MVT::v64i8, MVT::v32i16, MVT::v16i32, MVT::v8i64 }) {
1922       setOperationAction(ISD::SRL,              VT, Custom);
1923       setOperationAction(ISD::SHL,              VT, Custom);
1924       setOperationAction(ISD::SRA,              VT, Custom);
1925       setOperationAction(ISD::ROTL,             VT, Custom);
1926       setOperationAction(ISD::ROTR,             VT, Custom);
1927       setOperationAction(ISD::SETCC,            VT, Custom);
1928       setOperationAction(ISD::ABDS,             VT, Custom);
1929       setOperationAction(ISD::ABDU,             VT, Custom);
1930       setOperationAction(ISD::BITREVERSE,       VT, Custom);
1931 
1932       // The condition codes aren't legal in SSE/AVX and under AVX512 we use
1933       // setcc all the way to isel and prefer SETGT in some isel patterns.
1934       setCondCodeAction(ISD::SETLT, VT, Custom);
1935       setCondCodeAction(ISD::SETLE, VT, Custom);
1936     }
1937 
1938     setOperationAction(ISD::SETCC,          MVT::v8f64, Custom);
1939     setOperationAction(ISD::SETCC,          MVT::v16f32, Custom);
1940     setOperationAction(ISD::STRICT_FSETCC,  MVT::v8f64, Custom);
1941     setOperationAction(ISD::STRICT_FSETCC,  MVT::v16f32, Custom);
1942     setOperationAction(ISD::STRICT_FSETCCS, MVT::v8f64, Custom);
1943     setOperationAction(ISD::STRICT_FSETCCS, MVT::v16f32, Custom);
1944 
1945     for (auto VT : { MVT::v16i32, MVT::v8i64 }) {
1946       setOperationAction(ISD::SMAX,             VT, Legal);
1947       setOperationAction(ISD::UMAX,             VT, Legal);
1948       setOperationAction(ISD::SMIN,             VT, Legal);
1949       setOperationAction(ISD::UMIN,             VT, Legal);
1950       setOperationAction(ISD::ABS,              VT, Legal);
1951       setOperationAction(ISD::CTPOP,            VT, Custom);
1952     }
1953 
1954     for (auto VT : { MVT::v64i8, MVT::v32i16 }) {
1955       setOperationAction(ISD::ABS,     VT, HasBWI ? Legal : Custom);
1956       setOperationAction(ISD::CTPOP,   VT, Subtarget.hasBITALG() ? Legal : Custom);
1957       setOperationAction(ISD::CTLZ,    VT, Custom);
1958       setOperationAction(ISD::SMAX,    VT, HasBWI ? Legal : Custom);
1959       setOperationAction(ISD::UMAX,    VT, HasBWI ? Legal : Custom);
1960       setOperationAction(ISD::SMIN,    VT, HasBWI ? Legal : Custom);
1961       setOperationAction(ISD::UMIN,    VT, HasBWI ? Legal : Custom);
1962       setOperationAction(ISD::UADDSAT, VT, HasBWI ? Legal : Custom);
1963       setOperationAction(ISD::SADDSAT, VT, HasBWI ? Legal : Custom);
1964       setOperationAction(ISD::USUBSAT, VT, HasBWI ? Legal : Custom);
1965       setOperationAction(ISD::SSUBSAT, VT, HasBWI ? Legal : Custom);
1966     }
1967 
1968     setOperationAction(ISD::FSHL,       MVT::v64i8, Custom);
1969     setOperationAction(ISD::FSHR,       MVT::v64i8, Custom);
1970     setOperationAction(ISD::FSHL,      MVT::v32i16, Custom);
1971     setOperationAction(ISD::FSHR,      MVT::v32i16, Custom);
1972     setOperationAction(ISD::FSHL,      MVT::v16i32, Custom);
1973     setOperationAction(ISD::FSHR,      MVT::v16i32, Custom);
1974 
1975     if (Subtarget.hasDQI()) {
1976       for (auto Opc : {ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::STRICT_SINT_TO_FP,
1977                        ISD::STRICT_UINT_TO_FP, ISD::FP_TO_SINT, ISD::FP_TO_UINT,
1978                        ISD::STRICT_FP_TO_SINT, ISD::STRICT_FP_TO_UINT})
1979         setOperationAction(Opc,           MVT::v8i64, Custom);
1980       setOperationAction(ISD::MUL,        MVT::v8i64, Legal);
1981     }
1982 
1983     if (Subtarget.hasCDI()) {
1984       // Non-VLX subtargets extend 128/256-bit vectors to use the 512-bit version.
1985       for (auto VT : { MVT::v16i32, MVT::v8i64} ) {
1986         setOperationAction(ISD::CTLZ,            VT, Legal);
1987       }
1988     } // Subtarget.hasCDI()
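         // (AVX512CD provides VPLZCNTD/VPLZCNTQ, which is why CTLZ is Legal
         // for these element types when CDI is available.)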
1989 
1990     if (Subtarget.hasVPOPCNTDQ()) {
1991       for (auto VT : { MVT::v16i32, MVT::v8i64 })
1992         setOperationAction(ISD::CTPOP, VT, Legal);
1993     }
1994 
1995     // Extract subvector is special because the value type
1996     // (result) is 256-bit but the source is 512-bit wide.
1997     // The 128-bit cases were already made Legal under AVX1.
1998     for (auto VT : { MVT::v32i8, MVT::v16i16, MVT::v8i32, MVT::v4i64,
1999                      MVT::v16f16, MVT::v8f32, MVT::v4f64 })
2000       setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Legal);
2001 
2002     for (auto VT : { MVT::v64i8, MVT::v32i16, MVT::v16i32, MVT::v8i64,
2003                      MVT::v32f16, MVT::v16f32, MVT::v8f64 }) {
2004       setOperationAction(ISD::CONCAT_VECTORS,     VT, Custom);
2005       setOperationAction(ISD::INSERT_SUBVECTOR,   VT, Legal);
2006       setOperationAction(ISD::SELECT,             VT, Custom);
2007       setOperationAction(ISD::VSELECT,            VT, Custom);
2008       setOperationAction(ISD::BUILD_VECTOR,       VT, Custom);
2009       setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);
2010       setOperationAction(ISD::VECTOR_SHUFFLE,     VT, Custom);
2011       setOperationAction(ISD::SCALAR_TO_VECTOR,   VT, Custom);
2012       setOperationAction(ISD::INSERT_VECTOR_ELT,  VT, Custom);
2013     }
2014     setF16Action(MVT::v32f16, Expand);
2015     setOperationAction(ISD::FP_ROUND, MVT::v16f16, Custom);
2016     setOperationAction(ISD::STRICT_FP_ROUND, MVT::v16f16, Custom);
2017     setOperationAction(ISD::FP_EXTEND, MVT::v16f32, Custom);
2018     setOperationAction(ISD::STRICT_FP_EXTEND, MVT::v16f32, Custom);
2019     for (unsigned Opc : {ISD::FADD, ISD::FSUB, ISD::FMUL, ISD::FDIV})
2020       setOperationPromotedToType(Opc, MVT::v32f16, MVT::v32f32);
2021 
2022     for (auto VT : { MVT::v16i32, MVT::v8i64, MVT::v16f32, MVT::v8f64 }) {
2023       setOperationAction(ISD::MLOAD,               VT, Legal);
2024       setOperationAction(ISD::MSTORE,              VT, Legal);
2025       setOperationAction(ISD::MGATHER,             VT, Custom);
2026       setOperationAction(ISD::MSCATTER,            VT, Custom);
2027     }
2028     if (HasBWI) {
2029       for (auto VT : { MVT::v64i8, MVT::v32i16 }) {
2030         setOperationAction(ISD::MLOAD,        VT, Legal);
2031         setOperationAction(ISD::MSTORE,       VT, Legal);
2032       }
2033     } else {
2034       setOperationAction(ISD::STORE, MVT::v32i16, Custom);
2035       setOperationAction(ISD::STORE, MVT::v64i8,  Custom);
2036     }
2037 
2038     if (Subtarget.hasVBMI2()) {
2039       for (auto VT : {MVT::v32i16, MVT::v16i32, MVT::v8i64}) {
2040         setOperationAction(ISD::FSHL, VT, Custom);
2041         setOperationAction(ISD::FSHR, VT, Custom);
2042       }
2043 
2044       setOperationAction(ISD::ROTL, MVT::v32i16, Custom);
2045       setOperationAction(ISD::ROTR, MVT::v32i16, Custom);
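           // VBMI2 supplies VPSHLD/VPSHRD (immediate) and VPSHLDV/VPSHRDV
           // (variable) double-shift instructions, which is what the FSHL/FSHR
           // custom lowering can target here.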
2046     }
2047 
2048     setOperationAction(ISD::FNEG, MVT::v32f16, Custom);
2049     setOperationAction(ISD::FABS, MVT::v32f16, Custom);
2050     setOperationAction(ISD::FCOPYSIGN, MVT::v32f16, Custom);
2051   } // useAVX512Regs
2052 
2053   if (!Subtarget.useSoftFloat() && Subtarget.hasVBMI2()) {
2054     for (auto VT : {MVT::v8i16, MVT::v4i32, MVT::v2i64, MVT::v16i16, MVT::v8i32,
2055                     MVT::v4i64}) {
2056       setOperationAction(ISD::FSHL, VT, Custom);
2057       setOperationAction(ISD::FSHR, VT, Custom);
2058     }
2059   }
2060 
2061   // This block controls legalization for operations that don't have
2062   // pre-AVX512 equivalents. Without VLX we use 512-bit operations for
2063   // narrower widths.
2064   if (!Subtarget.useSoftFloat() && Subtarget.hasAVX512()) {
2065     // These operations are handled on non-VLX by artificially widening in
2066     // isel patterns.
2067 
2068     setOperationAction(ISD::STRICT_FP_TO_UINT,  MVT::v8i32, Custom);
2069     setOperationAction(ISD::STRICT_FP_TO_UINT,  MVT::v4i32, Custom);
2070     setOperationAction(ISD::STRICT_FP_TO_UINT,  MVT::v2i32, Custom);
2071 
2072     if (Subtarget.hasDQI()) {
2073       // Fast v2f32 SINT_TO_FP( v2i64 ) custom conversion.
2074       // v2f32 UINT_TO_FP is already custom under SSE2.
2075       assert(isOperationCustom(ISD::UINT_TO_FP, MVT::v2f32) &&
2076              isOperationCustom(ISD::STRICT_UINT_TO_FP, MVT::v2f32) &&
2077              "Unexpected operation action!");
2078       // v2i64 FP_TO_S/UINT(v2f32) custom conversion.
2079       setOperationAction(ISD::FP_TO_SINT,        MVT::v2f32, Custom);
2080       setOperationAction(ISD::FP_TO_UINT,        MVT::v2f32, Custom);
2081       setOperationAction(ISD::STRICT_FP_TO_SINT, MVT::v2f32, Custom);
2082       setOperationAction(ISD::STRICT_FP_TO_UINT, MVT::v2f32, Custom);
2083     }
2084 
2085     for (auto VT : { MVT::v2i64, MVT::v4i64 }) {
2086       setOperationAction(ISD::SMAX, VT, Legal);
2087       setOperationAction(ISD::UMAX, VT, Legal);
2088       setOperationAction(ISD::SMIN, VT, Legal);
2089       setOperationAction(ISD::UMIN, VT, Legal);
2090       setOperationAction(ISD::ABS,  VT, Legal);
2091     }
2092 
2093     for (auto VT : { MVT::v4i32, MVT::v8i32, MVT::v2i64, MVT::v4i64 }) {
2094       setOperationAction(ISD::ROTL,     VT, Custom);
2095       setOperationAction(ISD::ROTR,     VT, Custom);
2096     }
2097 
2098     // Custom legalize 2x32 to get a little better code.
2099     setOperationAction(ISD::MSCATTER, MVT::v2f32, Custom);
2100     setOperationAction(ISD::MSCATTER, MVT::v2i32, Custom);
2101 
2102     for (auto VT : { MVT::v4i32, MVT::v8i32, MVT::v2i64, MVT::v4i64,
2103                      MVT::v4f32, MVT::v8f32, MVT::v2f64, MVT::v4f64 })
2104       setOperationAction(ISD::MSCATTER, VT, Custom);
2105 
2106     if (Subtarget.hasDQI()) {
2107       for (auto Opc : {ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::STRICT_SINT_TO_FP,
2108                        ISD::STRICT_UINT_TO_FP, ISD::FP_TO_SINT, ISD::FP_TO_UINT,
2109                        ISD::STRICT_FP_TO_SINT, ISD::STRICT_FP_TO_UINT}) {
2110         setOperationAction(Opc, MVT::v2i64, Custom);
2111         setOperationAction(Opc, MVT::v4i64, Custom);
2112       }
2113       setOperationAction(ISD::MUL, MVT::v2i64, Legal);
2114       setOperationAction(ISD::MUL, MVT::v4i64, Legal);
2115     }
2116 
2117     if (Subtarget.hasCDI()) {
2118       for (auto VT : { MVT::v4i32, MVT::v8i32, MVT::v2i64, MVT::v4i64 }) {
2119         setOperationAction(ISD::CTLZ,            VT, Legal);
2120       }
2121     } // Subtarget.hasCDI()
2122 
2123     if (Subtarget.hasVPOPCNTDQ()) {
2124       for (auto VT : { MVT::v4i32, MVT::v8i32, MVT::v2i64, MVT::v4i64 })
2125         setOperationAction(ISD::CTPOP, VT, Legal);
2126     }
2127   }
2128 
2129   // This block controls legalization of v32i1/v64i1, which are available with
2130   // AVX512BW.
2131   if (!Subtarget.useSoftFloat() && Subtarget.hasBWI()) {
2132     addRegisterClass(MVT::v32i1,  &X86::VK32RegClass);
2133     addRegisterClass(MVT::v64i1,  &X86::VK64RegClass);
2134 
2135     for (auto VT : { MVT::v32i1, MVT::v64i1 }) {
2136       setOperationAction(ISD::VSELECT,            VT, Expand);
2137       setOperationAction(ISD::TRUNCATE,           VT, Custom);
2138       setOperationAction(ISD::SETCC,              VT, Custom);
2139       setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);
2140       setOperationAction(ISD::INSERT_VECTOR_ELT,  VT, Custom);
2141       setOperationAction(ISD::SELECT,             VT, Custom);
2142       setOperationAction(ISD::BUILD_VECTOR,       VT, Custom);
2143       setOperationAction(ISD::VECTOR_SHUFFLE,     VT, Custom);
2144       setOperationAction(ISD::CONCAT_VECTORS,     VT, Custom);
2145       setOperationAction(ISD::INSERT_SUBVECTOR,   VT, Custom);
2146     }
2147 
2148     for (auto VT : { MVT::v16i1, MVT::v32i1 })
2149       setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
2150 
2151     // Extends from v32i1 masks to 256-bit vectors.
2152     setOperationAction(ISD::SIGN_EXTEND,        MVT::v32i8, Custom);
2153     setOperationAction(ISD::ZERO_EXTEND,        MVT::v32i8, Custom);
2154     setOperationAction(ISD::ANY_EXTEND,         MVT::v32i8, Custom);
2155 
2156     for (auto VT : { MVT::v32i8, MVT::v16i8, MVT::v16i16, MVT::v8i16 }) {
2157       setOperationAction(ISD::MLOAD,  VT, Subtarget.hasVLX() ? Legal : Custom);
2158       setOperationAction(ISD::MSTORE, VT, Subtarget.hasVLX() ? Legal : Custom);
2159     }
2160 
2161     // These operations are handled on non-VLX by artificially widening in
2162     // isel patterns.
2163     // TODO: Custom widen in lowering on non-VLX and drop the isel patterns?
2164 
2165     if (Subtarget.hasBITALG()) {
2166       for (auto VT : { MVT::v16i8, MVT::v32i8, MVT::v8i16, MVT::v16i16 })
2167         setOperationAction(ISD::CTPOP, VT, Legal);
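           // (BITALG adds VPOPCNTB/VPOPCNTW, making byte/word CTPOP a single
           // instruction at these widths.)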
2168     }
2169   }
2170 
2171   if (!Subtarget.useSoftFloat() && Subtarget.hasFP16()) {
2172     auto setGroup = [&] (MVT VT) {
2173       setOperationAction(ISD::FADD,               VT, Legal);
2174       setOperationAction(ISD::STRICT_FADD,        VT, Legal);
2175       setOperationAction(ISD::FSUB,               VT, Legal);
2176       setOperationAction(ISD::STRICT_FSUB,        VT, Legal);
2177       setOperationAction(ISD::FMUL,               VT, Legal);
2178       setOperationAction(ISD::STRICT_FMUL,        VT, Legal);
2179       setOperationAction(ISD::FDIV,               VT, Legal);
2180       setOperationAction(ISD::STRICT_FDIV,        VT, Legal);
2181       setOperationAction(ISD::FSQRT,              VT, Legal);
2182       setOperationAction(ISD::STRICT_FSQRT,       VT, Legal);
2183 
2184       setOperationAction(ISD::FFLOOR,             VT, Legal);
2185       setOperationAction(ISD::STRICT_FFLOOR,      VT, Legal);
2186       setOperationAction(ISD::FCEIL,              VT, Legal);
2187       setOperationAction(ISD::STRICT_FCEIL,       VT, Legal);
2188       setOperationAction(ISD::FTRUNC,             VT, Legal);
2189       setOperationAction(ISD::STRICT_FTRUNC,      VT, Legal);
2190       setOperationAction(ISD::FRINT,              VT, Legal);
2191       setOperationAction(ISD::STRICT_FRINT,       VT, Legal);
2192       setOperationAction(ISD::FNEARBYINT,         VT, Legal);
2193       setOperationAction(ISD::STRICT_FNEARBYINT,  VT, Legal);
2194       setOperationAction(ISD::FROUNDEVEN, VT, Legal);
2195       setOperationAction(ISD::STRICT_FROUNDEVEN, VT, Legal);
2196 
2197       setOperationAction(ISD::FROUND,             VT, Custom);
2198 
2199       setOperationAction(ISD::LOAD,               VT, Legal);
2200       setOperationAction(ISD::STORE,              VT, Legal);
2201 
2202       setOperationAction(ISD::FMA,                VT, Legal);
2203       setOperationAction(ISD::STRICT_FMA,         VT, Legal);
2204       setOperationAction(ISD::VSELECT,            VT, Legal);
2205       setOperationAction(ISD::BUILD_VECTOR,       VT, Custom);
2206       setOperationAction(ISD::SELECT,             VT, Custom);
2207 
2208       setOperationAction(ISD::FNEG,               VT, Custom);
2209       setOperationAction(ISD::FABS,               VT, Custom);
2210       setOperationAction(ISD::FCOPYSIGN,          VT, Custom);
2211       setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);
2212       setOperationAction(ISD::VECTOR_SHUFFLE,     VT, Custom);
2213 
2214       setOperationAction(ISD::SETCC,              VT, Custom);
2215       setOperationAction(ISD::STRICT_FSETCC,      VT, Custom);
2216       setOperationAction(ISD::STRICT_FSETCCS,     VT, Custom);
2217     };
2218 
2219     // AVX512_FP16 scalar operations
2220     setGroup(MVT::f16);
2221     setOperationAction(ISD::FREM,                 MVT::f16, Promote);
2222     setOperationAction(ISD::STRICT_FREM,          MVT::f16, Promote);
2223     setOperationAction(ISD::SELECT_CC,            MVT::f16, Expand);
2224     setOperationAction(ISD::BR_CC,                MVT::f16, Expand);
2225     setOperationAction(ISD::STRICT_FROUND,        MVT::f16, Promote);
2226     setOperationAction(ISD::FROUNDEVEN,           MVT::f16, Legal);
2227     setOperationAction(ISD::STRICT_FROUNDEVEN,    MVT::f16, Legal);
2228     setOperationAction(ISD::FP_ROUND,             MVT::f16, Custom);
2229     setOperationAction(ISD::STRICT_FP_ROUND,      MVT::f16, Custom);
2230     setOperationAction(ISD::FMAXIMUM,             MVT::f16, Custom);
2231     setOperationAction(ISD::FMINIMUM,             MVT::f16, Custom);
2232     setOperationAction(ISD::FP_EXTEND,            MVT::f32, Legal);
2233     setOperationAction(ISD::STRICT_FP_EXTEND,     MVT::f32, Legal);
2234 
2235     setCondCodeAction(ISD::SETOEQ, MVT::f16, Expand);
2236     setCondCodeAction(ISD::SETUNE, MVT::f16, Expand);
2237 
2238     if (Subtarget.useAVX512Regs()) {
2239       setGroup(MVT::v32f16);
2240       setOperationAction(ISD::SCALAR_TO_VECTOR,       MVT::v32f16, Custom);
2241       setOperationAction(ISD::SINT_TO_FP,             MVT::v32i16, Legal);
2242       setOperationAction(ISD::STRICT_SINT_TO_FP,      MVT::v32i16, Legal);
2243       setOperationAction(ISD::UINT_TO_FP,             MVT::v32i16, Legal);
2244       setOperationAction(ISD::STRICT_UINT_TO_FP,      MVT::v32i16, Legal);
2245       setOperationAction(ISD::FP_ROUND,               MVT::v16f16, Legal);
2246       setOperationAction(ISD::STRICT_FP_ROUND,        MVT::v16f16, Legal);
2247       setOperationAction(ISD::FP_EXTEND,              MVT::v16f32, Custom);
2248       setOperationAction(ISD::STRICT_FP_EXTEND,       MVT::v16f32, Legal);
2249       setOperationAction(ISD::FP_EXTEND,              MVT::v8f64,  Custom);
2250       setOperationAction(ISD::STRICT_FP_EXTEND,       MVT::v8f64,  Legal);
2251       setOperationAction(ISD::INSERT_VECTOR_ELT,      MVT::v32f16, Custom);
2252 
2253       setOperationAction(ISD::FP_TO_SINT,             MVT::v32i16, Custom);
2254       setOperationAction(ISD::STRICT_FP_TO_SINT,      MVT::v32i16, Custom);
2255       setOperationAction(ISD::FP_TO_UINT,             MVT::v32i16, Custom);
2256       setOperationAction(ISD::STRICT_FP_TO_UINT,      MVT::v32i16, Custom);
2257       setOperationPromotedToType(ISD::FP_TO_SINT,     MVT::v32i8,  MVT::v32i16);
2258       setOperationPromotedToType(ISD::STRICT_FP_TO_SINT, MVT::v32i8,
2259                                  MVT::v32i16);
2260       setOperationPromotedToType(ISD::FP_TO_UINT,     MVT::v32i8,  MVT::v32i16);
2261       setOperationPromotedToType(ISD::STRICT_FP_TO_UINT, MVT::v32i8,
2262                                  MVT::v32i16);
2263       setOperationPromotedToType(ISD::FP_TO_SINT,     MVT::v32i1,  MVT::v32i16);
2264       setOperationPromotedToType(ISD::STRICT_FP_TO_SINT, MVT::v32i1,
2265                                  MVT::v32i16);
2266       setOperationPromotedToType(ISD::FP_TO_UINT,     MVT::v32i1,  MVT::v32i16);
2267       setOperationPromotedToType(ISD::STRICT_FP_TO_UINT, MVT::v32i1,
2268                                  MVT::v32i16);
2269 
2270       setOperationAction(ISD::EXTRACT_SUBVECTOR,      MVT::v16f16, Legal);
2271       setOperationAction(ISD::INSERT_SUBVECTOR,       MVT::v32f16, Legal);
2272       setOperationAction(ISD::CONCAT_VECTORS,         MVT::v32f16, Custom);
2273 
2274       setLoadExtAction(ISD::EXTLOAD, MVT::v8f64,  MVT::v8f16,  Legal);
2275       setLoadExtAction(ISD::EXTLOAD, MVT::v16f32, MVT::v16f16, Legal);
2276     }
2277 
2278     if (Subtarget.hasVLX()) {
2279       setGroup(MVT::v8f16);
2280       setGroup(MVT::v16f16);
2281 
2282       setOperationAction(ISD::SCALAR_TO_VECTOR,   MVT::v8f16,  Legal);
2283       setOperationAction(ISD::SCALAR_TO_VECTOR,   MVT::v16f16, Custom);
2284       setOperationAction(ISD::SINT_TO_FP,         MVT::v16i16, Legal);
2285       setOperationAction(ISD::STRICT_SINT_TO_FP,  MVT::v16i16, Legal);
2286       setOperationAction(ISD::SINT_TO_FP,         MVT::v8i16,  Legal);
2287       setOperationAction(ISD::STRICT_SINT_TO_FP,  MVT::v8i16,  Legal);
2288       setOperationAction(ISD::UINT_TO_FP,         MVT::v16i16, Legal);
2289       setOperationAction(ISD::STRICT_UINT_TO_FP,  MVT::v16i16, Legal);
2290       setOperationAction(ISD::UINT_TO_FP,         MVT::v8i16,  Legal);
2291       setOperationAction(ISD::STRICT_UINT_TO_FP,  MVT::v8i16,  Legal);
2292 
2293       setOperationAction(ISD::FP_TO_SINT,         MVT::v8i16, Custom);
2294       setOperationAction(ISD::STRICT_FP_TO_SINT,  MVT::v8i16, Custom);
2295       setOperationAction(ISD::FP_TO_UINT,         MVT::v8i16, Custom);
2296       setOperationAction(ISD::STRICT_FP_TO_UINT,  MVT::v8i16, Custom);
2297       setOperationAction(ISD::FP_ROUND,           MVT::v8f16, Legal);
2298       setOperationAction(ISD::STRICT_FP_ROUND,    MVT::v8f16, Legal);
2299       setOperationAction(ISD::FP_EXTEND,          MVT::v8f32, Custom);
2300       setOperationAction(ISD::STRICT_FP_EXTEND,   MVT::v8f32, Legal);
2301       setOperationAction(ISD::FP_EXTEND,          MVT::v4f64, Custom);
2302       setOperationAction(ISD::STRICT_FP_EXTEND,   MVT::v4f64, Legal);
2303 
2304       // INSERT_VECTOR_ELT v8f16 extended to VECTOR_SHUFFLE
2305       setOperationAction(ISD::INSERT_VECTOR_ELT,    MVT::v8f16,  Custom);
2306       setOperationAction(ISD::INSERT_VECTOR_ELT,    MVT::v16f16, Custom);
2307 
2308       setOperationAction(ISD::EXTRACT_SUBVECTOR,    MVT::v8f16, Legal);
2309       setOperationAction(ISD::INSERT_SUBVECTOR,     MVT::v16f16, Legal);
2310       setOperationAction(ISD::CONCAT_VECTORS,       MVT::v16f16, Custom);
2311 
2312       setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f16, Legal);
2313       setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f16, Legal);
2314       setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8f16, Legal);
2315       setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4f16, Legal);
2316 
2317       // Need to custom widen these to prevent scalarization.
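      // Roughly: without custom handling, a v4f16 load/store would be
      // scalarized into individual f16 element accesses; widening keeps it
      // as a single 64-bit memory access instead.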
2318       setOperationAction(ISD::LOAD,  MVT::v4f16, Custom);
2319       setOperationAction(ISD::STORE, MVT::v4f16, Custom);
2320     }
2321   }
2322 
2323   if (!Subtarget.useSoftFloat() &&
2324       (Subtarget.hasAVXNECONVERT() || Subtarget.hasBF16())) {
2325     addRegisterClass(MVT::v8bf16, Subtarget.hasAVX512() ? &X86::VR128XRegClass
2326                                                         : &X86::VR128RegClass);
2327     addRegisterClass(MVT::v16bf16, Subtarget.hasAVX512() ? &X86::VR256XRegClass
2328                                                          : &X86::VR256RegClass);
2329     // We set the type action of bf16 to TypeSoftPromoteHalf, but we don't
2330     // provide the method to promote BUILD_VECTOR and INSERT_VECTOR_ELT.
2331     // Set the operation action to Custom so we can do the customization later.
2332     setOperationAction(ISD::BUILD_VECTOR, MVT::bf16, Custom);
2333     setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::bf16, Custom);
2334     for (auto VT : {MVT::v8bf16, MVT::v16bf16}) {
2335       setF16Action(VT, Expand);
2336       setOperationAction(ISD::BUILD_VECTOR, VT, Custom);
2337       setOperationAction(ISD::VECTOR_SHUFFLE, VT, Custom);
2338       setOperationAction(ISD::INSERT_SUBVECTOR, VT, Legal);
2339       setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
2340     }
2341     for (unsigned Opc : {ISD::FADD, ISD::FSUB, ISD::FMUL, ISD::FDIV}) {
2342       setOperationPromotedToType(Opc, MVT::v8bf16, MVT::v8f32);
2343       setOperationPromotedToType(Opc, MVT::v16bf16, MVT::v16f32);
2344     }
2345     setOperationAction(ISD::FP_ROUND, MVT::v8bf16, Custom);
2346     addLegalFPImmediate(APFloat::getZero(APFloat::BFloat()));
2347   }
2348 
2349   if (!Subtarget.useSoftFloat() && Subtarget.hasBF16()) {
2350     addRegisterClass(MVT::v32bf16, &X86::VR512RegClass);
2351     setF16Action(MVT::v32bf16, Expand);
2352     for (unsigned Opc : {ISD::FADD, ISD::FSUB, ISD::FMUL, ISD::FDIV})
2353       setOperationPromotedToType(Opc, MVT::v32bf16, MVT::v32f32);
2354     setOperationAction(ISD::BUILD_VECTOR, MVT::v32bf16, Custom);
2355     setOperationAction(ISD::FP_ROUND, MVT::v16bf16, Custom);
2356     setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v32bf16, Custom);
2357     setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v32bf16, Legal);
2358     setOperationAction(ISD::CONCAT_VECTORS, MVT::v32bf16, Custom);
2359   }
2360 
2361   if (!Subtarget.useSoftFloat() && Subtarget.hasVLX()) {
2362     setTruncStoreAction(MVT::v4i64, MVT::v4i8,  Legal);
2363     setTruncStoreAction(MVT::v4i64, MVT::v4i16, Legal);
2364     setTruncStoreAction(MVT::v4i64, MVT::v4i32, Legal);
2365     setTruncStoreAction(MVT::v8i32, MVT::v8i8,  Legal);
2366     setTruncStoreAction(MVT::v8i32, MVT::v8i16, Legal);
2367 
2368     setTruncStoreAction(MVT::v2i64, MVT::v2i8,  Legal);
2369     setTruncStoreAction(MVT::v2i64, MVT::v2i16, Legal);
2370     setTruncStoreAction(MVT::v2i64, MVT::v2i32, Legal);
2371     setTruncStoreAction(MVT::v4i32, MVT::v4i8,  Legal);
2372     setTruncStoreAction(MVT::v4i32, MVT::v4i16, Legal);
2373 
2374     if (Subtarget.hasBWI()) {
2375       setTruncStoreAction(MVT::v16i16,  MVT::v16i8, Legal);
2376       setTruncStoreAction(MVT::v8i16,   MVT::v8i8,  Legal);
2377     }
2378 
2379     if (Subtarget.hasFP16()) {
2380       // vcvttph2[u]dq v4f16 -> v4i32/64, v2f16 -> v2i32/64
2381       setOperationAction(ISD::FP_TO_SINT,        MVT::v2f16, Custom);
2382       setOperationAction(ISD::STRICT_FP_TO_SINT, MVT::v2f16, Custom);
2383       setOperationAction(ISD::FP_TO_UINT,        MVT::v2f16, Custom);
2384       setOperationAction(ISD::STRICT_FP_TO_UINT, MVT::v2f16, Custom);
2385       setOperationAction(ISD::FP_TO_SINT,        MVT::v4f16, Custom);
2386       setOperationAction(ISD::STRICT_FP_TO_SINT, MVT::v4f16, Custom);
2387       setOperationAction(ISD::FP_TO_UINT,        MVT::v4f16, Custom);
2388       setOperationAction(ISD::STRICT_FP_TO_UINT, MVT::v4f16, Custom);
2389       // vcvt[u]dq2ph v4i32/64 -> v4f16, v2i32/64 -> v2f16
2390       setOperationAction(ISD::SINT_TO_FP,        MVT::v2f16, Custom);
2391       setOperationAction(ISD::STRICT_SINT_TO_FP, MVT::v2f16, Custom);
2392       setOperationAction(ISD::UINT_TO_FP,        MVT::v2f16, Custom);
2393       setOperationAction(ISD::STRICT_UINT_TO_FP, MVT::v2f16, Custom);
2394       setOperationAction(ISD::SINT_TO_FP,        MVT::v4f16, Custom);
2395       setOperationAction(ISD::STRICT_SINT_TO_FP, MVT::v4f16, Custom);
2396       setOperationAction(ISD::UINT_TO_FP,        MVT::v4f16, Custom);
2397       setOperationAction(ISD::STRICT_UINT_TO_FP, MVT::v4f16, Custom);
2398       // vcvtps2phx v4f32 -> v4f16, v2f32 -> v2f16
2399       setOperationAction(ISD::FP_ROUND,          MVT::v2f16, Custom);
2400       setOperationAction(ISD::STRICT_FP_ROUND,   MVT::v2f16, Custom);
2401       setOperationAction(ISD::FP_ROUND,          MVT::v4f16, Custom);
2402       setOperationAction(ISD::STRICT_FP_ROUND,   MVT::v4f16, Custom);
2403       // vcvtph2psx v4f16 -> v4f32, v2f16 -> v2f32
2404       setOperationAction(ISD::FP_EXTEND,         MVT::v2f16, Custom);
2405       setOperationAction(ISD::STRICT_FP_EXTEND,  MVT::v2f16, Custom);
2406       setOperationAction(ISD::FP_EXTEND,         MVT::v4f16, Custom);
2407       setOperationAction(ISD::STRICT_FP_EXTEND,  MVT::v4f16, Custom);
2408     }
2409   }
2410 
2411   if (!Subtarget.useSoftFloat() && Subtarget.hasAMXTILE()) {
2412     addRegisterClass(MVT::x86amx, &X86::TILERegClass);
2413   }
2414 
2415   // We want to custom lower some of our intrinsics.
2416   setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::Other, Custom);
2417   setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom);
2418   setOperationAction(ISD::INTRINSIC_VOID, MVT::Other, Custom);
2419   if (!Subtarget.is64Bit()) {
2420     setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i64, Custom);
2421   }
2422 
2423   // Only custom-lower 64-bit SADDO and friends on 64-bit because we don't
2424   // handle type legalization for these operations here.
2425   //
2426   // FIXME: We really should do custom legalization for addition and
2427   // subtraction on x86-32 once PR3203 is fixed.  We really can't do much better
2428   // than generic legalization for 64-bit multiplication-with-overflow, though.
2429   for (auto VT : { MVT::i8, MVT::i16, MVT::i32, MVT::i64 }) {
2430     if (VT == MVT::i64 && !Subtarget.is64Bit())
2431       continue;
2432     // Add/Sub/Mul with overflow operations are custom lowered.
2433     setOperationAction(ISD::SADDO, VT, Custom);
2434     setOperationAction(ISD::UADDO, VT, Custom);
2435     setOperationAction(ISD::SSUBO, VT, Custom);
2436     setOperationAction(ISD::USUBO, VT, Custom);
2437     setOperationAction(ISD::SMULO, VT, Custom);
2438     setOperationAction(ISD::UMULO, VT, Custom);
2439 
2440     // Support carry in as value rather than glue.
2441     setOperationAction(ISD::UADDO_CARRY, VT, Custom);
2442     setOperationAction(ISD::USUBO_CARRY, VT, Custom);
2443     setOperationAction(ISD::SETCCCARRY, VT, Custom);
2444     setOperationAction(ISD::SADDO_CARRY, VT, Custom);
2445     setOperationAction(ISD::SSUBO_CARRY, VT, Custom);
2446   }
2447 
2448   // Combine sin / cos into _sincos_stret if it is available.
2449   if (getLibcallName(RTLIB::SINCOS_STRET_F32) != nullptr &&
2450       getLibcallName(RTLIB::SINCOS_STRET_F64) != nullptr) {
2451     setOperationAction(ISD::FSINCOS, MVT::f64, Custom);
2452     setOperationAction(ISD::FSINCOS, MVT::f32, Custom);
2453   }
2454 
2455   if (Subtarget.isTargetWin64()) {
2456     setOperationAction(ISD::SDIV, MVT::i128, Custom);
2457     setOperationAction(ISD::UDIV, MVT::i128, Custom);
2458     setOperationAction(ISD::SREM, MVT::i128, Custom);
2459     setOperationAction(ISD::UREM, MVT::i128, Custom);
2460     setOperationAction(ISD::FP_TO_SINT, MVT::i128, Custom);
2461     setOperationAction(ISD::FP_TO_UINT, MVT::i128, Custom);
2462     setOperationAction(ISD::SINT_TO_FP, MVT::i128, Custom);
2463     setOperationAction(ISD::UINT_TO_FP, MVT::i128, Custom);
2464     setOperationAction(ISD::STRICT_FP_TO_SINT, MVT::i128, Custom);
2465     setOperationAction(ISD::STRICT_FP_TO_UINT, MVT::i128, Custom);
2466     setOperationAction(ISD::STRICT_SINT_TO_FP, MVT::i128, Custom);
2467     setOperationAction(ISD::STRICT_UINT_TO_FP, MVT::i128, Custom);
2468   }
2469 
2470   // On 32-bit MSVC, `fmodf(f32)` is not defined - only `fmod(f64)`
2471   // is. We should promote the value to 64 bits to solve this.
2472   // This is what the CRT headers do - `fmodf` is an inline header
2473   // function casting to f64 and calling `fmod`.
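  // A rough sketch of such a wrapper (illustrative only, not the actual CRT
  // source):
  //   static inline float fmodf(float X, float Y) {
  //     return (float)fmod((double)X, (double)Y);
  //   }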
2474   if (Subtarget.is32Bit() &&
2475       (Subtarget.isTargetWindowsMSVC() || Subtarget.isTargetWindowsItanium()))
2476     // clang-format off
2477    for (ISD::NodeType Op :
2478          {ISD::FACOS,  ISD::STRICT_FACOS,
2479           ISD::FASIN,  ISD::STRICT_FASIN,
2480           ISD::FATAN,  ISD::STRICT_FATAN,
2481           ISD::FCEIL,  ISD::STRICT_FCEIL,
2482           ISD::FCOS,   ISD::STRICT_FCOS,
2483           ISD::FCOSH,  ISD::STRICT_FCOSH,
2484           ISD::FEXP,   ISD::STRICT_FEXP,
2485           ISD::FFLOOR, ISD::STRICT_FFLOOR,
2486           ISD::FREM,   ISD::STRICT_FREM,
2487           ISD::FLOG,   ISD::STRICT_FLOG,
2488           ISD::FLOG10, ISD::STRICT_FLOG10,
2489           ISD::FPOW,   ISD::STRICT_FPOW,
2490           ISD::FSIN,   ISD::STRICT_FSIN,
2491           ISD::FSINH,  ISD::STRICT_FSINH,
2492           ISD::FTAN,   ISD::STRICT_FTAN,
2493           ISD::FTANH,  ISD::STRICT_FTANH})
2494       if (isOperationExpand(Op, MVT::f32))
2495         setOperationAction(Op, MVT::f32, Promote);
2496   // clang-format on
2497 
2498   // On MSVC, both 32-bit and 64-bit, ldexpf(f32) is not defined.  MinGW has
2499   // it, but it's just a wrapper around ldexp.
2500   if (Subtarget.isOSWindows()) {
2501     for (ISD::NodeType Op : {ISD::FLDEXP, ISD::STRICT_FLDEXP, ISD::FFREXP})
2502       if (isOperationExpand(Op, MVT::f32))
2503         setOperationAction(Op, MVT::f32, Promote);
2504   }
2505 
2506   // We have target-specific dag combine patterns for the following nodes:
2507   setTargetDAGCombine({ISD::VECTOR_SHUFFLE,
2508                        ISD::SCALAR_TO_VECTOR,
2509                        ISD::INSERT_VECTOR_ELT,
2510                        ISD::EXTRACT_VECTOR_ELT,
2511                        ISD::CONCAT_VECTORS,
2512                        ISD::INSERT_SUBVECTOR,
2513                        ISD::EXTRACT_SUBVECTOR,
2514                        ISD::BITCAST,
2515                        ISD::VSELECT,
2516                        ISD::SELECT,
2517                        ISD::SHL,
2518                        ISD::SRA,
2519                        ISD::SRL,
2520                        ISD::OR,
2521                        ISD::AND,
2522                        ISD::AVGCEILS,
2523                        ISD::AVGCEILU,
2524                        ISD::AVGFLOORS,
2525                        ISD::AVGFLOORU,
2526                        ISD::BITREVERSE,
2527                        ISD::ADD,
2528                        ISD::FADD,
2529                        ISD::FSUB,
2530                        ISD::FNEG,
2531                        ISD::FMA,
2532                        ISD::STRICT_FMA,
2533                        ISD::FMINNUM,
2534                        ISD::FMAXNUM,
2535                        ISD::SUB,
2536                        ISD::LOAD,
2537                        ISD::LRINT,
2538                        ISD::LLRINT,
2539                        ISD::MLOAD,
2540                        ISD::STORE,
2541                        ISD::MSTORE,
2542                        ISD::TRUNCATE,
2543                        ISD::ZERO_EXTEND,
2544                        ISD::ANY_EXTEND,
2545                        ISD::SIGN_EXTEND,
2546                        ISD::SIGN_EXTEND_INREG,
2547                        ISD::ANY_EXTEND_VECTOR_INREG,
2548                        ISD::SIGN_EXTEND_VECTOR_INREG,
2549                        ISD::ZERO_EXTEND_VECTOR_INREG,
2550                        ISD::SINT_TO_FP,
2551                        ISD::UINT_TO_FP,
2552                        ISD::STRICT_SINT_TO_FP,
2553                        ISD::STRICT_UINT_TO_FP,
2554                        ISD::SETCC,
2555                        ISD::MUL,
2556                        ISD::XOR,
2557                        ISD::MSCATTER,
2558                        ISD::MGATHER,
2559                        ISD::FP16_TO_FP,
2560                        ISD::FP_EXTEND,
2561                        ISD::STRICT_FP_EXTEND,
2562                        ISD::FP_ROUND,
2563                        ISD::STRICT_FP_ROUND});
2564 
2565   computeRegisterProperties(Subtarget.getRegisterInfo());
2566 
2567   MaxStoresPerMemset = 16; // For @llvm.memset -> sequence of stores
2568   MaxStoresPerMemsetOptSize = 8;
2569   MaxStoresPerMemcpy = 8; // For @llvm.memcpy -> sequence of stores
2570   MaxStoresPerMemcpyOptSize = 4;
2571   MaxStoresPerMemmove = 8; // For @llvm.memmove -> sequence of stores
2572   MaxStoresPerMemmoveOptSize = 4;
2573 
2574   // TODO: These control memcmp expansion in CGP and could be raised higher, but
2575   // that needs to be benchmarked and balanced with the potential use of vector
2576   // load/store types (PR33329, PR33914).
2577   MaxLoadsPerMemcmp = 2;
2578   MaxLoadsPerMemcmpOptSize = 2;
2579 
2580   // Default loop alignment, which can be overridden by -align-loops.
2581   setPrefLoopAlignment(Align(16));
2582 
2583   // An out-of-order CPU can speculatively execute past a predictable branch,
2584   // but a conditional move could be stalled by an expensive earlier operation.
2585   PredictableSelectIsExpensive = Subtarget.getSchedModel().isOutOfOrder();
2586   EnableExtLdPromotion = true;
2587   setPrefFunctionAlignment(Align(16));
2588 
2589   verifyIntrinsicTables();
2590 
2591   // Default to having -disable-strictnode-mutation on
2592   IsStrictFPEnabled = true;
2593 }
2594 
2595 // This has so far only been implemented for 64-bit MachO.
2596 bool X86TargetLowering::useLoadStackGuardNode() const {
2597   return Subtarget.isTargetMachO() && Subtarget.is64Bit();
2598 }
2599 
2600 bool X86TargetLowering::useStackGuardXorFP() const {
2601   // Currently only MSVC CRTs XOR the frame pointer into the stack guard value.
2602   return Subtarget.getTargetTriple().isOSMSVCRT() && !Subtarget.isTargetMachO();
2603 }
2604 
2605 SDValue X86TargetLowering::emitStackGuardXorFP(SelectionDAG &DAG, SDValue Val,
2606                                                const SDLoc &DL) const {
2607   EVT PtrTy = getPointerTy(DAG.getDataLayout());
2608   unsigned XorOp = Subtarget.is64Bit() ? X86::XOR64_FP : X86::XOR32_FP;
2609   MachineSDNode *Node = DAG.getMachineNode(XorOp, DL, PtrTy, Val);
2610   return SDValue(Node, 0);
2611 }
2612 
2613 TargetLoweringBase::LegalizeTypeAction
2614 X86TargetLowering::getPreferredVectorAction(MVT VT) const {
2615   if ((VT == MVT::v32i1 || VT == MVT::v64i1) && Subtarget.hasAVX512() &&
2616       !Subtarget.hasBWI())
2617     return TypeSplitVector;
2618 
2619   if (!VT.isScalableVector() && VT.getVectorNumElements() != 1 &&
2620       !Subtarget.hasF16C() && VT.getVectorElementType() == MVT::f16)
2621     return TypeSplitVector;
2622 
2623   if (!VT.isScalableVector() && VT.getVectorNumElements() != 1 &&
2624       VT.getVectorElementType() != MVT::i1)
2625     return TypeWidenVector;
2626 
2627   return TargetLoweringBase::getPreferredVectorAction(VT);
2628 }
2629 
2630 FastISel *
2631 X86TargetLowering::createFastISel(FunctionLoweringInfo &funcInfo,
2632                                   const TargetLibraryInfo *libInfo) const {
2633   return X86::createFastISel(funcInfo, libInfo);
2634 }
2635 
2636 //===----------------------------------------------------------------------===//
2637 //                           Other Lowering Hooks
2638 //===----------------------------------------------------------------------===//
2639 
2640 bool X86::mayFoldLoad(SDValue Op, const X86Subtarget &Subtarget,
2641                       bool AssumeSingleUse) {
2642   if (!AssumeSingleUse && !Op.hasOneUse())
2643     return false;
2644   if (!ISD::isNormalLoad(Op.getNode()))
2645     return false;
2646 
2647   // If this is an unaligned vector load, make sure the target supports folding it.
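  // E.g. most legacy SSE (non-VEX) instructions fault on misaligned 128-bit
  // memory operands, so folding an under-aligned load into one of them is only
  // safe when the subtarget tolerates unaligned SSE memory accesses.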
2648   auto *Ld = cast<LoadSDNode>(Op.getNode());
2649   if (!Subtarget.hasAVX() && !Subtarget.hasSSEUnalignedMem() &&
2650       Ld->getValueSizeInBits(0) == 128 && Ld->getAlign() < Align(16))
2651     return false;
2652 
2653   // TODO: If this is a non-temporal load and the target has an instruction
2654   //       for it, it should not be folded. See "useNonTemporalLoad()".
2655 
2656   return true;
2657 }
2658 
2659 bool X86::mayFoldLoadIntoBroadcastFromMem(SDValue Op, MVT EltVT,
2660                                           const X86Subtarget &Subtarget,
2661                                           bool AssumeSingleUse) {
2662   assert(Subtarget.hasAVX() && "Expected AVX for broadcast from memory");
2663   if (!X86::mayFoldLoad(Op, Subtarget, AssumeSingleUse))
2664     return false;
2665 
2666   // We cannot replace a wide volatile load with a broadcast-from-memory,
2667   // because that would narrow the load, which isn't legal for volatiles.
2668   auto *Ld = cast<LoadSDNode>(Op.getNode());
2669   return !Ld->isVolatile() ||
2670          Ld->getValueSizeInBits(0) == EltVT.getScalarSizeInBits();
2671 }
2672 
2673 bool X86::mayFoldIntoStore(SDValue Op) {
2674   return Op.hasOneUse() && ISD::isNormalStore(*Op.getNode()->use_begin());
2675 }
2676 
2677 bool X86::mayFoldIntoZeroExtend(SDValue Op) {
2678   if (Op.hasOneUse()) {
2679     unsigned Opcode = Op.getNode()->use_begin()->getOpcode();
2680     return (ISD::ZERO_EXTEND == Opcode);
2681   }
2682   return false;
2683 }
2684 
2685 static bool isLogicOp(unsigned Opcode) {
2686   // TODO: Add support for X86ISD::FAND/FOR/FXOR/FANDN with test coverage.
2687   return ISD::isBitwiseLogicOp(Opcode) || X86ISD::ANDNP == Opcode;
2688 }
2689 
2690 static bool isTargetShuffle(unsigned Opcode) {
2691   switch(Opcode) {
2692   default: return false;
2693   case X86ISD::BLENDI:
2694   case X86ISD::PSHUFB:
2695   case X86ISD::PSHUFD:
2696   case X86ISD::PSHUFHW:
2697   case X86ISD::PSHUFLW:
2698   case X86ISD::SHUFP:
2699   case X86ISD::INSERTPS:
2700   case X86ISD::EXTRQI:
2701   case X86ISD::INSERTQI:
2702   case X86ISD::VALIGN:
2703   case X86ISD::PALIGNR:
2704   case X86ISD::VSHLDQ:
2705   case X86ISD::VSRLDQ:
2706   case X86ISD::MOVLHPS:
2707   case X86ISD::MOVHLPS:
2708   case X86ISD::MOVSHDUP:
2709   case X86ISD::MOVSLDUP:
2710   case X86ISD::MOVDDUP:
2711   case X86ISD::MOVSS:
2712   case X86ISD::MOVSD:
2713   case X86ISD::MOVSH:
2714   case X86ISD::UNPCKL:
2715   case X86ISD::UNPCKH:
2716   case X86ISD::VBROADCAST:
2717   case X86ISD::VPERMILPI:
2718   case X86ISD::VPERMILPV:
2719   case X86ISD::VPERM2X128:
2720   case X86ISD::SHUF128:
2721   case X86ISD::VPERMIL2:
2722   case X86ISD::VPERMI:
2723   case X86ISD::VPPERM:
2724   case X86ISD::VPERMV:
2725   case X86ISD::VPERMV3:
2726   case X86ISD::VZEXT_MOVL:
2727     return true;
2728   }
2729 }
2730 
2731 static bool isTargetShuffleVariableMask(unsigned Opcode) {
2732   switch (Opcode) {
2733   default: return false;
2734   // Target Shuffles.
2735   case X86ISD::PSHUFB:
2736   case X86ISD::VPERMILPV:
2737   case X86ISD::VPERMIL2:
2738   case X86ISD::VPPERM:
2739   case X86ISD::VPERMV:
2740   case X86ISD::VPERMV3:
2741     return true;
2742   // 'Faux' Target Shuffles.
2743   case ISD::OR:
2744   case ISD::AND:
2745   case X86ISD::ANDNP:
2746     return true;
2747   }
2748 }
2749 
2750 SDValue X86TargetLowering::getReturnAddressFrameIndex(SelectionDAG &DAG) const {
2751   MachineFunction &MF = DAG.getMachineFunction();
2752   const X86RegisterInfo *RegInfo = Subtarget.getRegisterInfo();
2753   X86MachineFunctionInfo *FuncInfo = MF.getInfo<X86MachineFunctionInfo>();
2754   int ReturnAddrIndex = FuncInfo->getRAIndex();
2755 
2756   if (ReturnAddrIndex == 0) {
2757     // Set up a frame object for the return address.
2758     unsigned SlotSize = RegInfo->getSlotSize();
2759     ReturnAddrIndex = MF.getFrameInfo().CreateFixedObject(SlotSize,
2760                                                           -(int64_t)SlotSize,
2761                                                           false);
2762     FuncInfo->setRAIndex(ReturnAddrIndex);
2763   }
2764 
2765   return DAG.getFrameIndex(ReturnAddrIndex, getPointerTy(DAG.getDataLayout()));
2766 }
2767 
2768 bool X86::isOffsetSuitableForCodeModel(int64_t Offset, CodeModel::Model CM,
2769                                        bool HasSymbolicDisplacement) {
2770   // The offset should fit into a 32-bit immediate field.
2771   if (!isInt<32>(Offset))
2772     return false;
2773 
2774   // If we don't have a symbolic displacement - we don't have any extra
2775   // restrictions.
2776   if (!HasSymbolicDisplacement)
2777     return true;
2778 
2779   // We can fold large offsets in the large code model because we always use
2780   // 64-bit offsets.
2781   if (CM == CodeModel::Large)
2782     return true;
2783 
2784   // For the kernel code model we know that all objects reside in the negative
2785   // half of the 32-bit address space. We must not accept negative offsets, since
2786   // they may be just out of range, but we may accept pretty large positive ones.
2787   if (CM == CodeModel::Kernel)
2788     return Offset >= 0;
2789 
2790   // For other non-large code models we assume that the last small object ends
2791   // at least 16MB before the 31-bit boundary. We may also accept pretty large
2792   // negative offsets, knowing that all objects are in the positive half of the
2793   // address space.
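  // E.g. an object ending 16MB below the 2GB boundary can still be addressed
  // with any positive displacement smaller than 16MB without overflowing the
  // signed 32-bit displacement field.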
2794   return Offset < 16 * 1024 * 1024;
2795 }
2796 
2797 /// Return true if the condition is a signed comparison operation.
2798 static bool isX86CCSigned(unsigned X86CC) {
2799   switch (X86CC) {
2800   default:
2801     llvm_unreachable("Invalid integer condition!");
2802   case X86::COND_E:
2803   case X86::COND_NE:
2804   case X86::COND_B:
2805   case X86::COND_A:
2806   case X86::COND_BE:
2807   case X86::COND_AE:
2808     return false;
2809   case X86::COND_G:
2810   case X86::COND_GE:
2811   case X86::COND_L:
2812   case X86::COND_LE:
2813     return true;
2814   }
2815 }
2816 
2817 static X86::CondCode TranslateIntegerX86CC(ISD::CondCode SetCCOpcode) {
2818   switch (SetCCOpcode) {
2819   // clang-format off
2820   default: llvm_unreachable("Invalid integer condition!");
2821   case ISD::SETEQ:  return X86::COND_E;
2822   case ISD::SETGT:  return X86::COND_G;
2823   case ISD::SETGE:  return X86::COND_GE;
2824   case ISD::SETLT:  return X86::COND_L;
2825   case ISD::SETLE:  return X86::COND_LE;
2826   case ISD::SETNE:  return X86::COND_NE;
2827   case ISD::SETULT: return X86::COND_B;
2828   case ISD::SETUGT: return X86::COND_A;
2829   case ISD::SETULE: return X86::COND_BE;
2830   case ISD::SETUGE: return X86::COND_AE;
2831   // clang-format on
2832   }
2833 }
2834 
2835 /// Do a one-to-one translation of an ISD::CondCode to the X86-specific
2836 /// condition code, returning the condition code and the LHS/RHS of the
2837 /// comparison to make.
2838 static X86::CondCode TranslateX86CC(ISD::CondCode SetCCOpcode, const SDLoc &DL,
2839                                     bool isFP, SDValue &LHS, SDValue &RHS,
2840                                     SelectionDAG &DAG) {
2841   if (!isFP) {
2842     if (ConstantSDNode *RHSC = dyn_cast<ConstantSDNode>(RHS)) {
2843       if (SetCCOpcode == ISD::SETGT && RHSC->isAllOnes()) {
2844         // X > -1   -> X == 0, jump !sign.
2845         RHS = DAG.getConstant(0, DL, RHS.getValueType());
2846         return X86::COND_NS;
2847       }
2848       if (SetCCOpcode == ISD::SETLT && RHSC->isZero()) {
2849         // X < 0   -> X == 0, jump on sign.
2850         return X86::COND_S;
2851       }
2852       if (SetCCOpcode == ISD::SETGE && RHSC->isZero()) {
2853         // X >= 0   -> X == 0, jump on !sign.
2854         return X86::COND_NS;
2855       }
2856       if (SetCCOpcode == ISD::SETLT && RHSC->isOne()) {
2857         // X < 1   -> X <= 0
2858         RHS = DAG.getConstant(0, DL, RHS.getValueType());
2859         return X86::COND_LE;
2860       }
2861     }
2862 
2863     return TranslateIntegerX86CC(SetCCOpcode);
2864   }
2865 
2866   // First determine if it is required or is profitable to flip the operands.
2867 
2868   // If LHS is a foldable load, but RHS is not, flip the condition.
2869   if (ISD::isNON_EXTLoad(LHS.getNode()) &&
2870       !ISD::isNON_EXTLoad(RHS.getNode())) {
2871     SetCCOpcode = getSetCCSwappedOperands(SetCCOpcode);
2872     std::swap(LHS, RHS);
2873   }
2874 
2875   switch (SetCCOpcode) {
2876   default: break;
2877   case ISD::SETOLT:
2878   case ISD::SETOLE:
2879   case ISD::SETUGT:
2880   case ISD::SETUGE:
2881     std::swap(LHS, RHS);
2882     break;
2883   }
2884 
2885   // On a floating point condition, the flags are set as follows:
2886   // ZF  PF  CF   op
2887   //  0 | 0 | 0 | X > Y
2888   //  0 | 0 | 1 | X < Y
2889   //  1 | 0 | 0 | X == Y
2890   //  1 | 1 | 1 | unordered
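  // E.g. SETOLT: after the operand swap above it behaves like "X > Y", which
  // holds exactly when ZF == 0 and CF == 0, i.e. X86::COND_A.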
2891   switch (SetCCOpcode) {
2892   // clang-format off
2893   default: llvm_unreachable("Condcode should be pre-legalized away");
2894   case ISD::SETUEQ:
2895   case ISD::SETEQ:   return X86::COND_E;
2896   case ISD::SETOLT:              // flipped
2897   case ISD::SETOGT:
2898   case ISD::SETGT:   return X86::COND_A;
2899   case ISD::SETOLE:              // flipped
2900   case ISD::SETOGE:
2901   case ISD::SETGE:   return X86::COND_AE;
2902   case ISD::SETUGT:              // flipped
2903   case ISD::SETULT:
2904   case ISD::SETLT:   return X86::COND_B;
2905   case ISD::SETUGE:              // flipped
2906   case ISD::SETULE:
2907   case ISD::SETLE:   return X86::COND_BE;
2908   case ISD::SETONE:
2909   case ISD::SETNE:   return X86::COND_NE;
2910   case ISD::SETUO:   return X86::COND_P;
2911   case ISD::SETO:    return X86::COND_NP;
2912   case ISD::SETOEQ:
2913   case ISD::SETUNE:  return X86::COND_INVALID;
2914   // clang-format on
2915   }
2916 }
2917 
2918 /// Is there a floating point cmov for the specific X86 condition code?
2919 /// The current x86 ISA includes the following FP cmov instructions:
2920 /// fcmovb, fcmovbe, fcmove, fcmovu, fcmovae, fcmova, fcmovne, fcmovnu.
2921 static bool hasFPCMov(unsigned X86CC) {
2922   switch (X86CC) {
2923   default:
2924     return false;
2925   case X86::COND_B:
2926   case X86::COND_BE:
2927   case X86::COND_E:
2928   case X86::COND_P:
2929   case X86::COND_A:
2930   case X86::COND_AE:
2931   case X86::COND_NE:
2932   case X86::COND_NP:
2933     return true;
2934   }
2935 }
2936 
2937 static bool useVPTERNLOG(const X86Subtarget &Subtarget, MVT VT) {
2938   return Subtarget.hasVLX() || Subtarget.canExtendTo512DQ() ||
2939          VT.is512BitVector();
2940 }
2941 
2942 bool X86TargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
2943                                            const CallInst &I,
2944                                            MachineFunction &MF,
2945                                            unsigned Intrinsic) const {
2946   Info.flags = MachineMemOperand::MONone;
2947   Info.offset = 0;
2948 
2949   const IntrinsicData* IntrData = getIntrinsicWithChain(Intrinsic);
2950   if (!IntrData) {
2951     switch (Intrinsic) {
2952     case Intrinsic::x86_aesenc128kl:
2953     case Intrinsic::x86_aesdec128kl:
2954       Info.opc = ISD::INTRINSIC_W_CHAIN;
2955       Info.ptrVal = I.getArgOperand(1);
2956       Info.memVT = EVT::getIntegerVT(I.getType()->getContext(), 48);
2957       Info.align = Align(1);
2958       Info.flags |= MachineMemOperand::MOLoad;
2959       return true;
2960     case Intrinsic::x86_aesenc256kl:
2961     case Intrinsic::x86_aesdec256kl:
2962       Info.opc = ISD::INTRINSIC_W_CHAIN;
2963       Info.ptrVal = I.getArgOperand(1);
2964       Info.memVT = EVT::getIntegerVT(I.getType()->getContext(), 64);
2965       Info.align = Align(1);
2966       Info.flags |= MachineMemOperand::MOLoad;
2967       return true;
2968     case Intrinsic::x86_aesencwide128kl:
2969     case Intrinsic::x86_aesdecwide128kl:
2970       Info.opc = ISD::INTRINSIC_W_CHAIN;
2971       Info.ptrVal = I.getArgOperand(0);
2972       Info.memVT = EVT::getIntegerVT(I.getType()->getContext(), 48);
2973       Info.align = Align(1);
2974       Info.flags |= MachineMemOperand::MOLoad;
2975       return true;
2976     case Intrinsic::x86_aesencwide256kl:
2977     case Intrinsic::x86_aesdecwide256kl:
2978       Info.opc = ISD::INTRINSIC_W_CHAIN;
2979       Info.ptrVal = I.getArgOperand(0);
2980       Info.memVT = EVT::getIntegerVT(I.getType()->getContext(), 64);
2981       Info.align = Align(1);
2982       Info.flags |= MachineMemOperand::MOLoad;
2983       return true;
2984     case Intrinsic::x86_cmpccxadd32:
2985     case Intrinsic::x86_cmpccxadd64:
2986     case Intrinsic::x86_atomic_bts:
2987     case Intrinsic::x86_atomic_btc:
2988     case Intrinsic::x86_atomic_btr: {
2989       Info.opc = ISD::INTRINSIC_W_CHAIN;
2990       Info.ptrVal = I.getArgOperand(0);
2991       unsigned Size = I.getType()->getScalarSizeInBits();
2992       Info.memVT = EVT::getIntegerVT(I.getType()->getContext(), Size);
2993       Info.align = Align(Size);
2994       Info.flags |= MachineMemOperand::MOLoad | MachineMemOperand::MOStore |
2995                     MachineMemOperand::MOVolatile;
2996       return true;
2997     }
2998     case Intrinsic::x86_atomic_bts_rm:
2999     case Intrinsic::x86_atomic_btc_rm:
3000     case Intrinsic::x86_atomic_btr_rm: {
3001       Info.opc = ISD::INTRINSIC_W_CHAIN;
3002       Info.ptrVal = I.getArgOperand(0);
3003       unsigned Size = I.getArgOperand(1)->getType()->getScalarSizeInBits();
3004       Info.memVT = EVT::getIntegerVT(I.getType()->getContext(), Size);
3005       Info.align = Align(Size);
3006       Info.flags |= MachineMemOperand::MOLoad | MachineMemOperand::MOStore |
3007                     MachineMemOperand::MOVolatile;
3008       return true;
3009     }
3010     case Intrinsic::x86_aadd32:
3011     case Intrinsic::x86_aadd64:
3012     case Intrinsic::x86_aand32:
3013     case Intrinsic::x86_aand64:
3014     case Intrinsic::x86_aor32:
3015     case Intrinsic::x86_aor64:
3016     case Intrinsic::x86_axor32:
3017     case Intrinsic::x86_axor64:
3018     case Intrinsic::x86_atomic_add_cc:
3019     case Intrinsic::x86_atomic_sub_cc:
3020     case Intrinsic::x86_atomic_or_cc:
3021     case Intrinsic::x86_atomic_and_cc:
3022     case Intrinsic::x86_atomic_xor_cc: {
3023       Info.opc = ISD::INTRINSIC_W_CHAIN;
3024       Info.ptrVal = I.getArgOperand(0);
3025       unsigned Size = I.getArgOperand(1)->getType()->getScalarSizeInBits();
3026       Info.memVT = EVT::getIntegerVT(I.getType()->getContext(), Size);
3027       Info.align = Align(Size);
3028       Info.flags |= MachineMemOperand::MOLoad | MachineMemOperand::MOStore |
3029                     MachineMemOperand::MOVolatile;
3030       return true;
3031     }
3032     }
3033     return false;
3034   }
3035 
3036   switch (IntrData->Type) {
3037   case TRUNCATE_TO_MEM_VI8:
3038   case TRUNCATE_TO_MEM_VI16:
3039   case TRUNCATE_TO_MEM_VI32: {
3040     Info.opc = ISD::INTRINSIC_VOID;
3041     Info.ptrVal = I.getArgOperand(0);
3042     MVT VT  = MVT::getVT(I.getArgOperand(1)->getType());
3043     MVT ScalarVT = MVT::INVALID_SIMPLE_VALUE_TYPE;
3044     if (IntrData->Type == TRUNCATE_TO_MEM_VI8)
3045       ScalarVT = MVT::i8;
3046     else if (IntrData->Type == TRUNCATE_TO_MEM_VI16)
3047       ScalarVT = MVT::i16;
3048     else if (IntrData->Type == TRUNCATE_TO_MEM_VI32)
3049       ScalarVT = MVT::i32;
3050 
3051     Info.memVT = MVT::getVectorVT(ScalarVT, VT.getVectorNumElements());
3052     Info.align = Align(1);
3053     Info.flags |= MachineMemOperand::MOStore;
3054     break;
3055   }
3056   case GATHER:
3057   case GATHER_AVX2: {
3058     Info.opc = ISD::INTRINSIC_W_CHAIN;
3059     Info.ptrVal = nullptr;
3060     MVT DataVT = MVT::getVT(I.getType());
3061     MVT IndexVT = MVT::getVT(I.getArgOperand(2)->getType());
3062     unsigned NumElts = std::min(DataVT.getVectorNumElements(),
3063                                 IndexVT.getVectorNumElements());
3064     Info.memVT = MVT::getVectorVT(DataVT.getVectorElementType(), NumElts);
3065     Info.align = Align(1);
3066     Info.flags |= MachineMemOperand::MOLoad;
3067     break;
3068   }
3069   case SCATTER: {
3070     Info.opc = ISD::INTRINSIC_VOID;
3071     Info.ptrVal = nullptr;
3072     MVT DataVT = MVT::getVT(I.getArgOperand(3)->getType());
3073     MVT IndexVT = MVT::getVT(I.getArgOperand(2)->getType());
3074     unsigned NumElts = std::min(DataVT.getVectorNumElements(),
3075                                 IndexVT.getVectorNumElements());
3076     Info.memVT = MVT::getVectorVT(DataVT.getVectorElementType(), NumElts);
3077     Info.align = Align(1);
3078     Info.flags |= MachineMemOperand::MOStore;
3079     break;
3080   }
3081   default:
3082     return false;
3083   }
3084 
3085   return true;
3086 }
3087 
3088 /// Returns true if the target can instruction select the
3089 /// specified FP immediate natively. If false, the legalizer will
3090 /// materialize the FP immediate as a load from a constant pool.
3091 bool X86TargetLowering::isFPImmLegal(const APFloat &Imm, EVT VT,
3092                                      bool ForCodeSize) const {
3093   for (const APFloat &FPImm : LegalFPImmediates)
3094     if (Imm.bitwiseIsEqual(FPImm))
3095       return true;
3096   return false;
3097 }
3098 
3099 bool X86TargetLowering::shouldReduceLoadWidth(SDNode *Load,
3100                                               ISD::LoadExtType ExtTy,
3101                                               EVT NewVT) const {
3102   assert(cast<LoadSDNode>(Load)->isSimple() && "illegal to narrow");
3103 
3104   // "ELF Handling for Thread-Local Storage" specifies that R_X86_64_GOTTPOFF
3105   // relocation target a movq or addq instruction: don't let the load shrink.
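  // E.g. the initial-exec TLS sequence "movq x@gottpoff(%rip), %reg" must keep
  // a 64-bit movq (or addq) so the linker can relax or resolve the relocation;
  // a narrowed load would no longer match the required instruction form.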
3106   SDValue BasePtr = cast<LoadSDNode>(Load)->getBasePtr();
3107   if (BasePtr.getOpcode() == X86ISD::WrapperRIP)
3108     if (const auto *GA = dyn_cast<GlobalAddressSDNode>(BasePtr.getOperand(0)))
3109       return GA->getTargetFlags() != X86II::MO_GOTTPOFF;
3110 
3111   // If this is (1) an AVX vector load with (2) multiple uses and (3) all of
3112   // those uses are extracted directly into a store, then the extract + store
3113   // can be store-folded. Therefore, it's probably not worth splitting the load.
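  // The use pattern being checked looks roughly like:
  //   t0: v8f32 = load ...
  //   store (extract_subvector t0, 0), addr
  //   store (extract_subvector t0, 4), addr + 16
  // where each extract can be folded into its store.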
3114   EVT VT = Load->getValueType(0);
3115   if ((VT.is256BitVector() || VT.is512BitVector()) && !Load->hasOneUse()) {
3116     for (auto UI = Load->use_begin(), UE = Load->use_end(); UI != UE; ++UI) {
3117       // Skip uses of the chain value. Result 0 of the node is the load value.
3118       if (UI.getUse().getResNo() != 0)
3119         continue;
3120 
3121       // If this use is not an extract + store, it's probably worth splitting.
3122       if (UI->getOpcode() != ISD::EXTRACT_SUBVECTOR || !UI->hasOneUse() ||
3123           UI->use_begin()->getOpcode() != ISD::STORE)
3124         return true;
3125     }
3126     // All non-chain uses are extract + store.
3127     return false;
3128   }
3129 
3130   return true;
3131 }
3132 
3133 /// Returns true if it is beneficial to convert a load of a constant
3134 /// to just the constant itself.
3135 bool X86TargetLowering::shouldConvertConstantLoadToIntImm(const APInt &Imm,
3136                                                           Type *Ty) const {
3137   assert(Ty->isIntegerTy());
3138 
3139   unsigned BitSize = Ty->getPrimitiveSizeInBits();
3140   if (BitSize == 0 || BitSize > 64)
3141     return false;
3142   return true;
3143 }
3144 
3145 bool X86TargetLowering::reduceSelectOfFPConstantLoads(EVT CmpOpVT) const {
3146   // If we are using XMM registers in the ABI and the condition of the select is
3147   // a floating-point compare and we have blendv or conditional move, then it is
3148   // cheaper to select instead of doing a cross-register move and creating a
3149   // load that depends on the compare result.
3150   bool IsFPSetCC = CmpOpVT.isFloatingPoint() && CmpOpVT != MVT::f128;
3151   return !IsFPSetCC || !Subtarget.isTarget64BitLP64() || !Subtarget.hasAVX();
3152 }
3153 
3154 bool X86TargetLowering::convertSelectOfConstantsToMath(EVT VT) const {
3155   // TODO: It might be a win to ease or lift this restriction, but the generic
3156   // folds in DAGCombiner conflict with vector folds for an AVX512 target.
3157   if (VT.isVector() && Subtarget.hasAVX512())
3158     return false;
3159 
3160   return true;
3161 }
3162 
3163 bool X86TargetLowering::decomposeMulByConstant(LLVMContext &Context, EVT VT,
3164                                                SDValue C) const {
3165   // TODO: We handle scalars using custom code, but generic combining could make
3166   // that unnecessary.
3167   APInt MulC;
3168   if (!ISD::isConstantSplatVector(C.getNode(), MulC))
3169     return false;
3170 
3171   // Find the type this will be legalized to. Otherwise we might prematurely
3172   // convert this to shl+add/sub and then still have to type legalize those ops.
3173   // Another choice would be to defer the decision for illegal types until
3174   // after type legalization. But constant splat vectors of i64 can't make it
3175   // through type legalization on 32-bit targets so we would need to special
3176   // case vXi64.
3177   while (getTypeAction(Context, VT) != TypeLegal)
3178     VT = getTypeToTransformTo(Context, VT);
3179 
3180   // If vector multiply is legal, assume that's faster than shl + add/sub.
3181   // Multiply is a complex op with higher latency and lower throughput in
3182   // most implementations; however, sub-vXi32 vector multiplies are always fast,
3183   // vXi32 is fine as long as the subtarget has no slow PMULLD, and anything
3184   // larger (vXi64) is always going to be slow.
3185   unsigned EltSizeInBits = VT.getScalarSizeInBits();
3186   if (isOperationLegal(ISD::MUL, VT) && EltSizeInBits <= 32 &&
3187       (EltSizeInBits != 32 || !Subtarget.isPMULLDSlow()))
3188     return false;
3189 
3190   // shl+add, shl+sub, shl+add+neg
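  // Illustrative scalar examples of the accepted forms:
  //   x *  9: (MulC - 1) is a power of 2  ->  (x << 3) + x
  //   x *  7: (MulC + 1) is a power of 2  ->  (x << 3) - x
  //   x * -7: (1 - MulC) is a power of 2  ->  x - (x << 3)
  //   x * -9: -(MulC + 1) is a power of 2 ->  -((x << 3) + x)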
3191   return (MulC + 1).isPowerOf2() || (MulC - 1).isPowerOf2() ||
3192          (1 - MulC).isPowerOf2() || (-(MulC + 1)).isPowerOf2();
3193 }
3194 
3195 bool X86TargetLowering::isExtractSubvectorCheap(EVT ResVT, EVT SrcVT,
3196                                                 unsigned Index) const {
3197   if (!isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, ResVT))
3198     return false;
3199 
3200   // Mask vectors support all subregister combinations and operations that
3201   // extract half of a vector.
3202   if (ResVT.getVectorElementType() == MVT::i1)
3203     return Index == 0 || ((ResVT.getSizeInBits() == SrcVT.getSizeInBits()*2) &&
3204                           (Index == ResVT.getVectorNumElements()));
3205 
3206   return (Index % ResVT.getVectorNumElements()) == 0;
3207 }
3208 
3209 bool X86TargetLowering::shouldScalarizeBinop(SDValue VecOp) const {
3210   unsigned Opc = VecOp.getOpcode();
3211 
3212   // Assume target opcodes can't be scalarized.
3213   // TODO - do we have any exceptions?
3214   if (Opc >= ISD::BUILTIN_OP_END)
3215     return false;
3216 
3217   // If the vector op is not supported, try to convert to scalar.
3218   EVT VecVT = VecOp.getValueType();
3219   if (!isOperationLegalOrCustomOrPromote(Opc, VecVT))
3220     return true;
3221 
3222   // If the vector op is supported, but the scalar op is not, the transform may
3223   // not be worthwhile.
3224   EVT ScalarVT = VecVT.getScalarType();
3225   return isOperationLegalOrCustomOrPromote(Opc, ScalarVT);
3226 }
3227 
3228 bool X86TargetLowering::shouldFormOverflowOp(unsigned Opcode, EVT VT,
3229                                              bool) const {
3230   // TODO: Allow vectors?
3231   if (VT.isVector())
3232     return false;
3233   return VT.isSimple() || !isOperationExpand(Opcode, VT);
3234 }
3235 
3236 bool X86TargetLowering::isCheapToSpeculateCttz(Type *Ty) const {
3237   // Speculate cttz only if we can directly use TZCNT or can promote to i32.
3238   return Subtarget.hasBMI() ||
3239          (!Ty->isVectorTy() && Ty->getScalarSizeInBits() < 32);
3240 }
3241 
3242 bool X86TargetLowering::isCheapToSpeculateCtlz(Type *Ty) const {
3243   // Speculate ctlz only if we can directly use LZCNT.
3244   return Subtarget.hasLZCNT();
3245 }
3246 
3247 bool X86TargetLowering::ShouldShrinkFPConstant(EVT VT) const {
3248   // Don't shrink FP constpool if SSE2 is available since cvtss2sd is more
3249   // expensive than a straight movsd. On the other hand, it's important to
3250   // shrink long double fp constant since fldt is very slow.
3251   return !Subtarget.hasSSE2() || VT == MVT::f80;
3252 }
3253 
3254 bool X86TargetLowering::isScalarFPTypeInSSEReg(EVT VT) const {
3255   return (VT == MVT::f64 && Subtarget.hasSSE2()) ||
3256          (VT == MVT::f32 && Subtarget.hasSSE1()) || VT == MVT::f16;
3257 }
3258 
3259 bool X86TargetLowering::isLoadBitCastBeneficial(EVT LoadVT, EVT BitcastVT,
3260                                                 const SelectionDAG &DAG,
3261                                                 const MachineMemOperand &MMO) const {
3262   if (!Subtarget.hasAVX512() && !LoadVT.isVector() && BitcastVT.isVector() &&
3263       BitcastVT.getVectorElementType() == MVT::i1)
3264     return false;
3265 
3266   if (!Subtarget.hasDQI() && BitcastVT == MVT::v8i1 && LoadVT == MVT::i8)
3267     return false;
3268 
3269   // If both types are legal vectors, it's always ok to convert them.
3270   if (LoadVT.isVector() && BitcastVT.isVector() &&
3271       isTypeLegal(LoadVT) && isTypeLegal(BitcastVT))
3272     return true;
3273 
3274   return TargetLowering::isLoadBitCastBeneficial(LoadVT, BitcastVT, DAG, MMO);
3275 }
3276 
3277 bool X86TargetLowering::canMergeStoresTo(unsigned AddressSpace, EVT MemVT,
3278                                          const MachineFunction &MF) const {
3279   // Do not merge to float value size (128 bits) if no implicit
3280   // float attribute is set.
3281   bool NoFloat = MF.getFunction().hasFnAttribute(Attribute::NoImplicitFloat);
3282 
3283   if (NoFloat) {
3284     unsigned MaxIntSize = Subtarget.is64Bit() ? 64 : 32;
3285     return (MemVT.getSizeInBits() <= MaxIntSize);
3286   }
3287   // Make sure we don't merge greater than our preferred vector
3288   // width.
3289   if (MemVT.getSizeInBits() > Subtarget.getPreferVectorWidth())
3290     return false;
3291 
3292   return true;
3293 }
3294 
3295 bool X86TargetLowering::isCtlzFast() const {
3296   return Subtarget.hasFastLZCNT();
3297 }
3298 
3299 bool X86TargetLowering::isMaskAndCmp0FoldingBeneficial(
3300     const Instruction &AndI) const {
3301   return true;
3302 }
3303 
3304 bool X86TargetLowering::hasAndNotCompare(SDValue Y) const {
3305   EVT VT = Y.getValueType();
3306 
3307   if (VT.isVector())
3308     return false;
3309 
3310   if (!Subtarget.hasBMI())
3311     return false;
3312 
3313   // There are only 32-bit and 64-bit forms for 'andn'.
3314   if (VT != MVT::i32 && VT != MVT::i64)
3315     return false;
3316 
3317   return !isa<ConstantSDNode>(Y) || cast<ConstantSDNode>(Y)->isOpaque();
3318 }
3319 
3320 bool X86TargetLowering::hasAndNot(SDValue Y) const {
3321   EVT VT = Y.getValueType();
3322 
3323   if (!VT.isVector())
3324     return hasAndNotCompare(Y);
3325 
3326   // Vector.
3327 
3328   if (!Subtarget.hasSSE1() || VT.getSizeInBits() < 128)
3329     return false;
3330 
3331   if (VT == MVT::v4i32)
3332     return true;
3333 
3334   return Subtarget.hasSSE2();
3335 }
3336 
3337 bool X86TargetLowering::hasBitTest(SDValue X, SDValue Y) const {
3338   return X.getValueType().isScalarInteger(); // 'bt'
3339 }
3340 
3341 bool X86TargetLowering::
3342     shouldProduceAndByConstByHoistingConstFromShiftsLHSOfAnd(
3343         SDValue X, ConstantSDNode *XC, ConstantSDNode *CC, SDValue Y,
3344         unsigned OldShiftOpcode, unsigned NewShiftOpcode,
3345         SelectionDAG &DAG) const {
3346   // Does baseline recommend not to perform the fold by default?
3347   if (!TargetLowering::shouldProduceAndByConstByHoistingConstFromShiftsLHSOfAnd(
3348           X, XC, CC, Y, OldShiftOpcode, NewShiftOpcode, DAG))
3349     return false;
3350   // For scalars this transform is always beneficial.
3351   if (X.getValueType().isScalarInteger())
3352     return true;
3353   // If all the shift amounts are identical, then transform is beneficial even
3354   // with rudimentary SSE2 shifts.
3355   if (DAG.isSplatValue(Y, /*AllowUndefs=*/true))
3356     return true;
3357   // If we have AVX2 with its powerful shift operations, then it's also good.
3358   if (Subtarget.hasAVX2())
3359     return true;
3360   // Pre-AVX2 vector codegen for this pattern is best for the variant with 'shl'.
3361   return NewShiftOpcode == ISD::SHL;
3362 }
3363 
3364 unsigned X86TargetLowering::preferedOpcodeForCmpEqPiecesOfOperand(
3365     EVT VT, unsigned ShiftOpc, bool MayTransformRotate,
3366     const APInt &ShiftOrRotateAmt, const std::optional<APInt> &AndMask) const {
3367   if (!VT.isInteger())
3368     return ShiftOpc;
3369 
3370   bool PreferRotate = false;
3371   if (VT.isVector()) {
3372     // For vectors, if we have rotate instruction support, then it's definitely
3373     // best. Otherwise it's not clear what's best, so just don't make changes.
3374     PreferRotate = Subtarget.hasAVX512() && (VT.getScalarType() == MVT::i32 ||
3375                                              VT.getScalarType() == MVT::i64);
3376   } else {
3377     // For scalars, if we have BMI2 prefer rotate for rorx. Otherwise prefer
3378     // rotate unless we have a zext mask+shr.
3379     PreferRotate = Subtarget.hasBMI2();
3380     if (!PreferRotate) {
3381       unsigned MaskBits =
3382           VT.getScalarSizeInBits() - ShiftOrRotateAmt.getZExtValue();
3383       PreferRotate = (MaskBits != 8) && (MaskBits != 16) && (MaskBits != 32);
3384     }
3385   }
3386 
3387   if (ShiftOpc == ISD::SHL || ShiftOpc == ISD::SRL) {
3388     assert(AndMask.has_value() && "Null andmask when querying about shift+and");
3389 
3390     if (PreferRotate && MayTransformRotate)
3391       return ISD::ROTL;
3392 
3393     // If vector, we don't really get much benefit from swapping around constants.
3394     // Maybe we could check if the DAG has the flipped node already in the
3395     // future.
3396     if (VT.isVector())
3397       return ShiftOpc;
3398 
3399     // See if it's beneficial to swap the shift type.
3400     if (ShiftOpc == ISD::SHL) {
3401       // If the current setup has an imm64 mask, then the inverse will have
3402       // at least an imm32 mask (or be zext i32 -> i64).
3403       if (VT == MVT::i64)
3404         return AndMask->getSignificantBits() > 32 ? (unsigned)ISD::SRL
3405                                                   : ShiftOpc;
3406 
3407       // We can only benefit if the mask requires at least 7 bits. We
3408       // don't want to replace shl by 1, 2 or 3 as those can be implemented
3409       // with lea/add.
3410       return ShiftOrRotateAmt.uge(7) ? (unsigned)ISD::SRL : ShiftOpc;
3411     }
3412 
3413     if (VT == MVT::i64)
3414       // Keep exactly 32-bit imm64, this is zext i32 -> i64 which is
3415       // extremely efficient.
3416       return AndMask->getSignificantBits() > 33 ? (unsigned)ISD::SHL : ShiftOpc;
3417 
3418     // Keep small shifts as shl so we can generate add/lea.
3419     return ShiftOrRotateAmt.ult(7) ? (unsigned)ISD::SHL : ShiftOpc;
3420   }
3421 
3422   // We prefer rotate for vectors, or if we won't get a zext mask with SRL
3423   // (PreferRotate will be set in the latter case).
3424   if (PreferRotate || !MayTransformRotate || VT.isVector())
3425     return ShiftOpc;
3426 
3427   // Non-vector type and we have a zext mask with SRL.
3428   return ISD::SRL;
3429 }
3430 
3431 TargetLoweringBase::CondMergingParams
3432 X86TargetLowering::getJumpConditionMergingParams(Instruction::BinaryOps Opc,
3433                                                  const Value *Lhs,
3434                                                  const Value *Rhs) const {
3435   using namespace llvm::PatternMatch;
3436   int BaseCost = BrMergingBaseCostThresh.getValue();
3437   // With CCMP, branches can be merged in a more efficient way.
3438   if (BaseCost >= 0 && Subtarget.hasCCMP())
3439     BaseCost += BrMergingCcmpBias;
3440   // a == b && a == c is a fast pattern on x86.
3441   ICmpInst::Predicate Pred;
3442   if (BaseCost >= 0 && Opc == Instruction::And &&
3443       match(Lhs, m_ICmp(Pred, m_Value(), m_Value())) &&
3444       Pred == ICmpInst::ICMP_EQ &&
3445       match(Rhs, m_ICmp(Pred, m_Value(), m_Value())) &&
3446       Pred == ICmpInst::ICMP_EQ)
3447     BaseCost += 1;
3448   return {BaseCost, BrMergingLikelyBias.getValue(),
3449           BrMergingUnlikelyBias.getValue()};
3450 }
3451 
3452 bool X86TargetLowering::preferScalarizeSplat(SDNode *N) const {
3453   return N->getOpcode() != ISD::FP_EXTEND;
3454 }
3455 
3456 bool X86TargetLowering::shouldFoldConstantShiftPairToMask(
3457     const SDNode *N, CombineLevel Level) const {
3458   assert(((N->getOpcode() == ISD::SHL &&
3459            N->getOperand(0).getOpcode() == ISD::SRL) ||
3460           (N->getOpcode() == ISD::SRL &&
3461            N->getOperand(0).getOpcode() == ISD::SHL)) &&
3462          "Expected shift-shift mask");
3463   // TODO: Should we always create i64 masks? Or only folded immediates?
3464   EVT VT = N->getValueType(0);
3465   if ((Subtarget.hasFastVectorShiftMasks() && VT.isVector()) ||
3466       (Subtarget.hasFastScalarShiftMasks() && !VT.isVector())) {
3467     // Only fold if the shift values are equal - so it folds to AND.
3468     // TODO - we should fold if either is a non-uniform vector but we don't do
3469     // the fold for non-splats yet.
3470     return N->getOperand(1) == N->getOperand(0).getOperand(1);
3471   }
3472   return TargetLoweringBase::shouldFoldConstantShiftPairToMask(N, Level);
3473 }
3474 
3475 bool X86TargetLowering::shouldFoldMaskToVariableShiftPair(SDValue Y) const {
3476   EVT VT = Y.getValueType();
3477 
3478   // For vectors, we don't have a preference, but we probably want a mask.
3479   if (VT.isVector())
3480     return false;
3481 
3482   // 64-bit shifts on 32-bit targets produce really bad bloated code.
3483   if (VT == MVT::i64 && !Subtarget.is64Bit())
3484     return false;
3485 
3486   return true;
3487 }
3488 
3489 TargetLowering::ShiftLegalizationStrategy
3490 X86TargetLowering::preferredShiftLegalizationStrategy(
3491     SelectionDAG &DAG, SDNode *N, unsigned ExpansionFactor) const {
3492   if (DAG.getMachineFunction().getFunction().hasMinSize() &&
3493       !Subtarget.isOSWindows())
3494     return ShiftLegalizationStrategy::LowerToLibcall;
3495   return TargetLowering::preferredShiftLegalizationStrategy(DAG, N,
3496                                                             ExpansionFactor);
3497 }
3498 
3499 bool X86TargetLowering::shouldSplatInsEltVarIndex(EVT VT) const {
3500   // Any legal vector type can be splatted more efficiently than
3501   // loading/spilling from memory.
3502   return isTypeLegal(VT);
3503 }
3504 
3505 MVT X86TargetLowering::hasFastEqualityCompare(unsigned NumBits) const {
3506   MVT VT = MVT::getIntegerVT(NumBits);
3507   if (isTypeLegal(VT))
3508     return VT;
3509 
3510   // PMOVMSKB can handle this.
3511   if (NumBits == 128 && isTypeLegal(MVT::v16i8))
3512     return MVT::v16i8;
3513 
3514   // VPMOVMSKB can handle this.
3515   if (NumBits == 256 && isTypeLegal(MVT::v32i8))
3516     return MVT::v32i8;
3517 
3518   // TODO: Allow 64-bit type for 32-bit target.
3519   // TODO: 512-bit types should be allowed, but make sure that those
3520   // cases are handled in combineVectorSizedSetCCEquality().
3521 
3522   return MVT::INVALID_SIMPLE_VALUE_TYPE;
3523 }
3524 
3525 /// Val is the undef sentinel value or equal to the specified value.
3526 static bool isUndefOrEqual(int Val, int CmpVal) {
3527   return ((Val == SM_SentinelUndef) || (Val == CmpVal));
3528 }
3529 
3530 /// Return true if every element in Mask is the undef sentinel value or equal to
3531 /// the specified value.
3532 static bool isUndefOrEqual(ArrayRef<int> Mask, int CmpVal) {
3533   return llvm::all_of(Mask, [CmpVal](int M) {
3534     return (M == SM_SentinelUndef) || (M == CmpVal);
3535   });
3536 }
3537 
3538 /// Return true if every element in Mask, beginning from position Pos and ending
3539 /// in Pos+Size is the undef sentinel value or equal to the specified value.
3540 static bool isUndefOrEqualInRange(ArrayRef<int> Mask, int CmpVal, unsigned Pos,
3541                                   unsigned Size) {
3542   return llvm::all_of(Mask.slice(Pos, Size),
3543                       [CmpVal](int M) { return isUndefOrEqual(M, CmpVal); });
3544 }
3545 
3546 /// Val is either the undef or zero sentinel value.
3547 static bool isUndefOrZero(int Val) {
3548   return ((Val == SM_SentinelUndef) || (Val == SM_SentinelZero));
3549 }
3550 
3551 /// Return true if every element in Mask, beginning from position Pos and ending
3552 /// in Pos+Size is the undef sentinel value.
3553 static bool isUndefInRange(ArrayRef<int> Mask, unsigned Pos, unsigned Size) {
3554   return llvm::all_of(Mask.slice(Pos, Size),
3555                       [](int M) { return M == SM_SentinelUndef; });
3556 }
3557 
3558 /// Return true if the mask creates a vector whose lower half is undefined.
3559 static bool isUndefLowerHalf(ArrayRef<int> Mask) {
3560   unsigned NumElts = Mask.size();
3561   return isUndefInRange(Mask, 0, NumElts / 2);
3562 }
3563 
3564 /// Return true if the mask creates a vector whose upper half is undefined.
3565 static bool isUndefUpperHalf(ArrayRef<int> Mask) {
3566   unsigned NumElts = Mask.size();
3567   return isUndefInRange(Mask, NumElts / 2, NumElts / 2);
3568 }
3569 
3570 /// Return true if Val falls within the specified range [Low, Hi).
3571 static bool isInRange(int Val, int Low, int Hi) {
3572   return (Val >= Low && Val < Hi);
3573 }
3574 
3575 /// Return true if the value of any element in Mask falls within the specified
3576 /// range [Low, Hi).
3577 static bool isAnyInRange(ArrayRef<int> Mask, int Low, int Hi) {
3578   return llvm::any_of(Mask, [Low, Hi](int M) { return isInRange(M, Low, Hi); });
3579 }
3580 
3581 /// Return true if the value of any element in Mask is the zero sentinel value.
3582 static bool isAnyZero(ArrayRef<int> Mask) {
3583   return llvm::any_of(Mask, [](int M) { return M == SM_SentinelZero; });
3584 }
3585 
3586 /// Return true if Val is undef or if its value falls within the
3587 /// specified range [Low, Hi).
3588 static bool isUndefOrInRange(int Val, int Low, int Hi) {
3589   return (Val == SM_SentinelUndef) || isInRange(Val, Low, Hi);
3590 }
3591 
3592 /// Return true if every element in Mask is undef or if its value
3593 /// falls within the specified range [Low, Hi).
3594 static bool isUndefOrInRange(ArrayRef<int> Mask, int Low, int Hi) {
3595   return llvm::all_of(
3596       Mask, [Low, Hi](int M) { return isUndefOrInRange(M, Low, Hi); });
3597 }
3598 
3599 /// Return true if Val is undef, zero or if its value falls within the
3600 /// specified range [Low, Hi).
3601 static bool isUndefOrZeroOrInRange(int Val, int Low, int Hi) {
3602   return isUndefOrZero(Val) || isInRange(Val, Low, Hi);
3603 }
3604 
3605 /// Return true if every element in Mask is undef, zero or if its value
3606 /// falls within the specified range [Low, Hi).
3607 static bool isUndefOrZeroOrInRange(ArrayRef<int> Mask, int Low, int Hi) {
3608   return llvm::all_of(
3609       Mask, [Low, Hi](int M) { return isUndefOrZeroOrInRange(M, Low, Hi); });
3610 }
3611 
3612 /// Return true if every element in Mask is an in-place blend/select mask or is
3613 /// undef.
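/// For illustration, with 4 elements the mask {0, 5, 2, 7} is a blend (each
/// lane I selects element I from one of the two inputs), while {1, 5, 2, 7}
/// is not.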
3614 LLVM_ATTRIBUTE_UNUSED static bool isBlendOrUndef(ArrayRef<int> Mask) {
3615   unsigned NumElts = Mask.size();
3616   for (auto [I, M] : enumerate(Mask))
3617     if (!isUndefOrEqual(M, I) && !isUndefOrEqual(M, I + NumElts))
3618       return false;
3619   return true;
3620 }
3621 
3622 /// Return true if every element in Mask, beginning
3623 /// from position Pos and ending in Pos + Size, falls within the specified
3624 /// sequence (Low, Low + Step, ..., Low + (Size - 1) * Step) or is undef.
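/// e.g. Mask = {4, -1, 6, 7} with Pos = 0, Size = 4, Low = 4 and Step = 1
/// matches the sequence 4, 5, 6, 7 (the undef element is skipped) and so
/// returns true.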
3625 static bool isSequentialOrUndefInRange(ArrayRef<int> Mask, unsigned Pos,
3626                                        unsigned Size, int Low, int Step = 1) {
3627   for (unsigned i = Pos, e = Pos + Size; i != e; ++i, Low += Step)
3628     if (!isUndefOrEqual(Mask[i], Low))
3629       return false;
3630   return true;
3631 }
3632 
3633 /// Return true if every element in Mask, beginning
3634 /// from position Pos and ending in Pos+Size, falls within the specified
3635 /// sequence (Low, Low + Step, ..., Low + (Size - 1) * Step), or is undef or is zero.
3636 static bool isSequentialOrUndefOrZeroInRange(ArrayRef<int> Mask, unsigned Pos,
3637                                              unsigned Size, int Low,
3638                                              int Step = 1) {
3639   for (unsigned i = Pos, e = Pos + Size; i != e; ++i, Low += Step)
3640     if (!isUndefOrZero(Mask[i]) && Mask[i] != Low)
3641       return false;
3642   return true;
3643 }
3644 
3645 /// Return true if every element in Mask, beginning
3646 /// from position Pos and ending in Pos+Size is undef or is zero.
3647 static bool isUndefOrZeroInRange(ArrayRef<int> Mask, unsigned Pos,
3648                                  unsigned Size) {
3649   return llvm::all_of(Mask.slice(Pos, Size), isUndefOrZero);
3650 }
3651 
3652 /// Return true if every element of a single input is referenced by the shuffle
3653 /// mask, i.e. it just permutes them all.
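/// e.g. {3, 1, 0, 2} is a complete permute of a 4-element input, while
/// {0, 0, 1, 2} is not (element 3 is never referenced).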
3654 static bool isCompletePermute(ArrayRef<int> Mask) {
3655   unsigned NumElts = Mask.size();
3656   APInt DemandedElts = APInt::getZero(NumElts);
3657   for (int M : Mask)
3658     if (isInRange(M, 0, NumElts))
3659       DemandedElts.setBit(M);
3660   return DemandedElts.isAllOnes();
3661 }
3662 
3663 /// Helper function to test whether a shuffle mask could be
3664 /// simplified by widening the elements being shuffled.
3665 ///
3666 /// Appends the mask for wider elements in WidenedMask if valid. Otherwise
3667 /// leaves it in an unspecified state.
3668 ///
3669 /// NOTE: This must handle normal vector shuffle masks and *target* vector
3670 /// shuffle masks. The latter have the special property of a '-2' representing
3671 /// a zero-ed lane of a vector.
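/// For illustration (using -1 for SM_SentinelUndef and -2 for SM_SentinelZero):
///   {0, 1, -1, 7, 4, 5, -2, -2} widens to {0, 3, 2, -2}
///   {0, 2, 1, 3} cannot be widened (0 and 2 are not an adjacent aligned pair).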
3672 static bool canWidenShuffleElements(ArrayRef<int> Mask,
3673                                     SmallVectorImpl<int> &WidenedMask) {
3674   WidenedMask.assign(Mask.size() / 2, 0);
3675   for (int i = 0, Size = Mask.size(); i < Size; i += 2) {
3676     int M0 = Mask[i];
3677     int M1 = Mask[i + 1];
3678 
3679     // If both elements are undef, it's trivial.
3680     if (M0 == SM_SentinelUndef && M1 == SM_SentinelUndef) {
3681       WidenedMask[i / 2] = SM_SentinelUndef;
3682       continue;
3683     }
3684 
3685     // Check for an undef mask and a mask value properly aligned to fit with
3686     // a pair of values. If we find such a case, use the non-undef mask's value.
3687     if (M0 == SM_SentinelUndef && M1 >= 0 && (M1 % 2) == 1) {
3688       WidenedMask[i / 2] = M1 / 2;
3689       continue;
3690     }
3691     if (M1 == SM_SentinelUndef && M0 >= 0 && (M0 % 2) == 0) {
3692       WidenedMask[i / 2] = M0 / 2;
3693       continue;
3694     }
3695 
3696     // When zeroing, we need to spread the zeroing across both lanes to widen.
3697     if (M0 == SM_SentinelZero || M1 == SM_SentinelZero) {
3698       if ((M0 == SM_SentinelZero || M0 == SM_SentinelUndef) &&
3699           (M1 == SM_SentinelZero || M1 == SM_SentinelUndef)) {
3700         WidenedMask[i / 2] = SM_SentinelZero;
3701         continue;
3702       }
3703       return false;
3704     }
3705 
3706     // Finally check if the two mask values are adjacent and aligned with
3707     // a pair.
3708     if (M0 != SM_SentinelUndef && (M0 % 2) == 0 && (M0 + 1) == M1) {
3709       WidenedMask[i / 2] = M0 / 2;
3710       continue;
3711     }
3712 
3713     // Otherwise we can't safely widen the elements used in this shuffle.
3714     return false;
3715   }
3716   assert(WidenedMask.size() == Mask.size() / 2 &&
3717          "Incorrect size of mask after widening the elements!");
3718 
3719   return true;
3720 }
3721 
3722 static bool canWidenShuffleElements(ArrayRef<int> Mask,
3723                                     const APInt &Zeroable,
3724                                     bool V2IsZero,
3725                                     SmallVectorImpl<int> &WidenedMask) {
3726   // Create an alternative mask with info about zeroable elements.
3727   // Here we do not set undef elements as zeroable.
3728   SmallVector<int, 64> ZeroableMask(Mask);
3729   if (V2IsZero) {
3730     assert(!Zeroable.isZero() && "V2's non-undef elements are used?!");
3731     for (int i = 0, Size = Mask.size(); i != Size; ++i)
3732       if (Mask[i] != SM_SentinelUndef && Zeroable[i])
3733         ZeroableMask[i] = SM_SentinelZero;
3734   }
3735   return canWidenShuffleElements(ZeroableMask, WidenedMask);
3736 }
3737 
3738 static bool canWidenShuffleElements(ArrayRef<int> Mask) {
3739   SmallVector<int, 32> WidenedMask;
3740   return canWidenShuffleElements(Mask, WidenedMask);
3741 }
3742 
3743 // Attempt to narrow/widen shuffle mask until it matches the target number of
3744 // elements.
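// e.g. {0, 1} scaled to 4 elements becomes {0, 1, 2, 3}; {0, 1, 2, 3} scaled
// to 2 elements becomes {0, 1}; {1, 0, 3, 2} cannot be scaled down to 2.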
3745 static bool scaleShuffleElements(ArrayRef<int> Mask, unsigned NumDstElts,
3746                                  SmallVectorImpl<int> &ScaledMask) {
3747   unsigned NumSrcElts = Mask.size();
3748   assert(((NumSrcElts % NumDstElts) == 0 || (NumDstElts % NumSrcElts) == 0) &&
3749          "Illegal shuffle scale factor");
3750 
3751   // Narrowing is guaranteed to work.
3752   if (NumDstElts >= NumSrcElts) {
3753     int Scale = NumDstElts / NumSrcElts;
3754     llvm::narrowShuffleMaskElts(Scale, Mask, ScaledMask);
3755     return true;
3756   }
3757 
3758   // We have to repeat the widening until we reach the target size, but we can
3759   // split out the first widening as it sets up ScaledMask for us.
3760   if (canWidenShuffleElements(Mask, ScaledMask)) {
3761     while (ScaledMask.size() > NumDstElts) {
3762       SmallVector<int, 16> WidenedMask;
3763       if (!canWidenShuffleElements(ScaledMask, WidenedMask))
3764         return false;
3765       ScaledMask = std::move(WidenedMask);
3766     }
3767     return true;
3768   }
3769 
3770   return false;
3771 }
3772 
3773 static bool canScaleShuffleElements(ArrayRef<int> Mask, unsigned NumDstElts) {
3774   SmallVector<int, 32> ScaledMask;
3775   return scaleShuffleElements(Mask, NumDstElts, ScaledMask);
3776 }
3777 
3778 /// Returns true if Elt is a constant zero or a floating point constant +0.0.
3779 bool X86::isZeroNode(SDValue Elt) {
3780   return isNullConstant(Elt) || isNullFPConstant(Elt);
3781 }
3782 
3783 // Build a vector of constants.
3784 // Use an UNDEF node if MaskElt == -1.
3785 // Split 64-bit constants in 32-bit mode.
3786 static SDValue getConstVector(ArrayRef<int> Values, MVT VT, SelectionDAG &DAG,
3787                               const SDLoc &dl, bool IsMask = false) {
3788 
3789   SmallVector<SDValue, 32> Ops;
3790   bool Split = false;
3791 
3792   MVT ConstVecVT = VT;
3793   unsigned NumElts = VT.getVectorNumElements();
3794   bool In64BitMode = DAG.getTargetLoweringInfo().isTypeLegal(MVT::i64);
3795   if (!In64BitMode && VT.getVectorElementType() == MVT::i64) {
3796     ConstVecVT = MVT::getVectorVT(MVT::i32, NumElts * 2);
3797     Split = true;
3798   }
3799 
3800   MVT EltVT = ConstVecVT.getVectorElementType();
3801   for (unsigned i = 0; i < NumElts; ++i) {
3802     bool IsUndef = Values[i] < 0 && IsMask;
3803     SDValue OpNode = IsUndef ? DAG.getUNDEF(EltVT) :
3804       DAG.getConstant(Values[i], dl, EltVT);
3805     Ops.push_back(OpNode);
3806     if (Split)
3807       Ops.push_back(IsUndef ? DAG.getUNDEF(EltVT) :
3808                     DAG.getConstant(0, dl, EltVT));
3809   }
3810   SDValue ConstsNode = DAG.getBuildVector(ConstVecVT, dl, Ops);
3811   if (Split)
3812     ConstsNode = DAG.getBitcast(VT, ConstsNode);
3813   return ConstsNode;
3814 }
3815 
3816 static SDValue getConstVector(ArrayRef<APInt> Bits, const APInt &Undefs,
3817                               MVT VT, SelectionDAG &DAG, const SDLoc &dl) {
3818   assert(Bits.size() == Undefs.getBitWidth() &&
3819          "Unequal constant and undef arrays");
3820   SmallVector<SDValue, 32> Ops;
3821   bool Split = false;
3822 
3823   MVT ConstVecVT = VT;
3824   unsigned NumElts = VT.getVectorNumElements();
3825   bool In64BitMode = DAG.getTargetLoweringInfo().isTypeLegal(MVT::i64);
3826   if (!In64BitMode && VT.getVectorElementType() == MVT::i64) {
3827     ConstVecVT = MVT::getVectorVT(MVT::i32, NumElts * 2);
3828     Split = true;
3829   }
3830 
3831   MVT EltVT = ConstVecVT.getVectorElementType();
3832   for (unsigned i = 0, e = Bits.size(); i != e; ++i) {
3833     if (Undefs[i]) {
3834       Ops.append(Split ? 2 : 1, DAG.getUNDEF(EltVT));
3835       continue;
3836     }
3837     const APInt &V = Bits[i];
3838     assert(V.getBitWidth() == VT.getScalarSizeInBits() && "Unexpected sizes");
3839     if (Split) {
3840       Ops.push_back(DAG.getConstant(V.trunc(32), dl, EltVT));
3841       Ops.push_back(DAG.getConstant(V.lshr(32).trunc(32), dl, EltVT));
3842     } else if (EltVT == MVT::f32) {
3843       APFloat FV(APFloat::IEEEsingle(), V);
3844       Ops.push_back(DAG.getConstantFP(FV, dl, EltVT));
3845     } else if (EltVT == MVT::f64) {
3846       APFloat FV(APFloat::IEEEdouble(), V);
3847       Ops.push_back(DAG.getConstantFP(FV, dl, EltVT));
3848     } else {
3849       Ops.push_back(DAG.getConstant(V, dl, EltVT));
3850     }
3851   }
3852 
3853   SDValue ConstsNode = DAG.getBuildVector(ConstVecVT, dl, Ops);
3854   return DAG.getBitcast(VT, ConstsNode);
3855 }
3856 
3857 static SDValue getConstVector(ArrayRef<APInt> Bits, MVT VT,
3858                               SelectionDAG &DAG, const SDLoc &dl) {
3859   APInt Undefs = APInt::getZero(Bits.size());
3860   return getConstVector(Bits, Undefs, VT, DAG, dl);
3861 }
3862 
3863 /// Returns a vector of specified type with all zero elements.
3864 static SDValue getZeroVector(MVT VT, const X86Subtarget &Subtarget,
3865                              SelectionDAG &DAG, const SDLoc &dl) {
3866   assert((VT.is128BitVector() || VT.is256BitVector() || VT.is512BitVector() ||
3867           VT.getVectorElementType() == MVT::i1) &&
3868          "Unexpected vector type");
3869 
3870   // Try to build SSE/AVX zero vectors as <N x i32> bitcasted to their dest
3871   // type. This ensures they get CSE'd. But if the integer type is not
3872   // available, use a floating-point +0.0 instead.
3873   SDValue Vec;
3874   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
3875   if (!Subtarget.hasSSE2() && VT.is128BitVector()) {
3876     Vec = DAG.getConstantFP(+0.0, dl, MVT::v4f32);
3877   } else if (VT.isFloatingPoint() &&
3878              TLI.isTypeLegal(VT.getVectorElementType())) {
3879     Vec = DAG.getConstantFP(+0.0, dl, VT);
3880   } else if (VT.getVectorElementType() == MVT::i1) {
3881     assert((Subtarget.hasBWI() || VT.getVectorNumElements() <= 16) &&
3882            "Unexpected vector type");
3883     Vec = DAG.getConstant(0, dl, VT);
3884   } else {
3885     unsigned Num32BitElts = VT.getSizeInBits() / 32;
3886     Vec = DAG.getConstant(0, dl, MVT::getVectorVT(MVT::i32, Num32BitElts));
3887   }
3888   return DAG.getBitcast(VT, Vec);
3889 }
3890 
3891 // Helper to determine if the ops are all subvectors extracted from a single
3892 // source. If we allow commute they don't have to be in order (Lo/Hi).
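// e.g. with LHS = extract_subvector(Src, 0) and RHS = extract_subvector(Src, N)
// where Src has 2*N elements, this returns Src; with AllowCommute the swapped
// order is accepted as well.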
3893 static SDValue getSplitVectorSrc(SDValue LHS, SDValue RHS, bool AllowCommute) {
3894   if (LHS.getOpcode() != ISD::EXTRACT_SUBVECTOR ||
3895       RHS.getOpcode() != ISD::EXTRACT_SUBVECTOR ||
3896       LHS.getValueType() != RHS.getValueType() ||
3897       LHS.getOperand(0) != RHS.getOperand(0))
3898     return SDValue();
3899 
3900   SDValue Src = LHS.getOperand(0);
3901   if (Src.getValueSizeInBits() != (LHS.getValueSizeInBits() * 2))
3902     return SDValue();
3903 
3904   unsigned NumElts = LHS.getValueType().getVectorNumElements();
3905   if ((LHS.getConstantOperandAPInt(1) == 0 &&
3906        RHS.getConstantOperandAPInt(1) == NumElts) ||
3907       (AllowCommute && RHS.getConstantOperandAPInt(1) == 0 &&
3908        LHS.getConstantOperandAPInt(1) == NumElts))
3909     return Src;
3910 
3911   return SDValue();
3912 }
3913 
3914 static SDValue extractSubVector(SDValue Vec, unsigned IdxVal, SelectionDAG &DAG,
3915                                 const SDLoc &dl, unsigned vectorWidth) {
3916   EVT VT = Vec.getValueType();
3917   EVT ElVT = VT.getVectorElementType();
3918   unsigned Factor = VT.getSizeInBits() / vectorWidth;
3919   EVT ResultVT = EVT::getVectorVT(*DAG.getContext(), ElVT,
3920                                   VT.getVectorNumElements() / Factor);
3921 
3922   // Extract the relevant vectorWidth bits.  Generate an EXTRACT_SUBVECTOR
3923   unsigned ElemsPerChunk = vectorWidth / ElVT.getSizeInBits();
3924   assert(isPowerOf2_32(ElemsPerChunk) && "Elements per chunk not power of 2");
3925 
3926   // This is the index of the first element of the vectorWidth-bit chunk
3927   // we want. Since ElemsPerChunk is a power of 2, we just need to clear bits.
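  // e.g. extracting 128 bits from a 256-bit v8i32 gives ElemsPerChunk = 4, so
  // an IdxVal of 5 is rounded down to 4 (the start of the upper chunk).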
3928   IdxVal &= ~(ElemsPerChunk - 1);
3929 
3930   // If the input is a buildvector just emit a smaller one.
3931   if (Vec.getOpcode() == ISD::BUILD_VECTOR)
3932     return DAG.getBuildVector(ResultVT, dl,
3933                               Vec->ops().slice(IdxVal, ElemsPerChunk));
3934 
3935   // Check if we're extracting the upper undef of a widening pattern.
3936   if (Vec.getOpcode() == ISD::INSERT_SUBVECTOR && Vec.getOperand(0).isUndef() &&
3937       Vec.getOperand(1).getValueType().getVectorNumElements() <= IdxVal &&
3938       isNullConstant(Vec.getOperand(2)))
3939     return DAG.getUNDEF(ResultVT);
3940 
3941   SDValue VecIdx = DAG.getIntPtrConstant(IdxVal, dl);
3942   return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, ResultVT, Vec, VecIdx);
3943 }
3944 
3945 /// Generate a DAG to grab 128-bits from a vector > 128 bits.  This
3946 /// sets things up to match to an AVX VEXTRACTF128 / VEXTRACTI128
3947 /// or AVX-512 VEXTRACTF32x4 / VEXTRACTI32x4
3948 /// instructions or a simple subregister reference. Idx is an index in the
3949 /// 128 bits we want.  It need not be aligned to a 128-bit boundary.  That makes
3950 /// lowering EXTRACT_VECTOR_ELT operations easier.
3951 static SDValue extract128BitVector(SDValue Vec, unsigned IdxVal,
3952                                    SelectionDAG &DAG, const SDLoc &dl) {
3953   assert((Vec.getValueType().is256BitVector() ||
3954           Vec.getValueType().is512BitVector()) && "Unexpected vector size!");
3955   return extractSubVector(Vec, IdxVal, DAG, dl, 128);
3956 }
3957 
3958 /// Generate a DAG to grab 256-bits from a 512-bit vector.
3959 static SDValue extract256BitVector(SDValue Vec, unsigned IdxVal,
3960                                    SelectionDAG &DAG, const SDLoc &dl) {
3961   assert(Vec.getValueType().is512BitVector() && "Unexpected vector size!");
3962   return extractSubVector(Vec, IdxVal, DAG, dl, 256);
3963 }
3964 
3965 static SDValue insertSubVector(SDValue Result, SDValue Vec, unsigned IdxVal,
3966                                SelectionDAG &DAG, const SDLoc &dl,
3967                                unsigned vectorWidth) {
3968   assert((vectorWidth == 128 || vectorWidth == 256) &&
3969          "Unsupported vector width");
3970   // Inserting UNDEF just returns Result.
3971   if (Vec.isUndef())
3972     return Result;
3973   EVT VT = Vec.getValueType();
3974   EVT ElVT = VT.getVectorElementType();
3975   EVT ResultVT = Result.getValueType();
3976 
3977   // Insert the relevant vectorWidth bits.
3978   unsigned ElemsPerChunk = vectorWidth/ElVT.getSizeInBits();
3979   assert(isPowerOf2_32(ElemsPerChunk) && "Elements per chunk not power of 2");
3980 
3981   // This is the index of the first element of the vectorWidth-bit chunk
3982   // we want. Since ElemsPerChunk is a power of 2, we just need to clear bits.
3983   IdxVal &= ~(ElemsPerChunk - 1);
3984 
3985   SDValue VecIdx = DAG.getIntPtrConstant(IdxVal, dl);
3986   return DAG.getNode(ISD::INSERT_SUBVECTOR, dl, ResultVT, Result, Vec, VecIdx);
3987 }
3988 
3989 /// Generate a DAG to put 128-bits into a vector > 128 bits.  This
3990 /// sets things up to match to an AVX VINSERTF128/VINSERTI128 or
3991 /// AVX-512 VINSERTF32x4/VINSERTI32x4 instructions or a
3992 /// simple superregister reference.  Idx is an index in the 128 bits
3993 /// we want.  It need not be aligned to a 128-bit boundary.  That makes
3994 /// lowering INSERT_VECTOR_ELT operations easier.
3995 static SDValue insert128BitVector(SDValue Result, SDValue Vec, unsigned IdxVal,
3996                                   SelectionDAG &DAG, const SDLoc &dl) {
3997   assert(Vec.getValueType().is128BitVector() && "Unexpected vector size!");
3998   return insertSubVector(Result, Vec, IdxVal, DAG, dl, 128);
3999 }
4000 
4001 /// Widen a vector to a larger size with the same scalar type, with the new
4002 /// elements either zero or undef.
4003 static SDValue widenSubVector(MVT VT, SDValue Vec, bool ZeroNewElements,
4004                               const X86Subtarget &Subtarget, SelectionDAG &DAG,
4005                               const SDLoc &dl) {
4006   assert(Vec.getValueSizeInBits().getFixedValue() <= VT.getFixedSizeInBits() &&
4007          Vec.getValueType().getScalarType() == VT.getScalarType() &&
4008          "Unsupported vector widening type");
4009   SDValue Res = ZeroNewElements ? getZeroVector(VT, Subtarget, DAG, dl)
4010                                 : DAG.getUNDEF(VT);
4011   return DAG.getNode(ISD::INSERT_SUBVECTOR, dl, VT, Res, Vec,
4012                      DAG.getIntPtrConstant(0, dl));
4013 }
4014 
4015 /// Widen a vector to a larger size with the same scalar type, with the new
4016 /// elements either zero or undef.
4017 static SDValue widenSubVector(SDValue Vec, bool ZeroNewElements,
4018                               const X86Subtarget &Subtarget, SelectionDAG &DAG,
4019                               const SDLoc &dl, unsigned WideSizeInBits) {
4020   assert(Vec.getValueSizeInBits() <= WideSizeInBits &&
4021          (WideSizeInBits % Vec.getScalarValueSizeInBits()) == 0 &&
4022          "Unsupported vector widening type");
4023   unsigned WideNumElts = WideSizeInBits / Vec.getScalarValueSizeInBits();
4024   MVT SVT = Vec.getSimpleValueType().getScalarType();
4025   MVT VT = MVT::getVectorVT(SVT, WideNumElts);
4026   return widenSubVector(VT, Vec, ZeroNewElements, Subtarget, DAG, dl);
4027 }
4028 
4029 /// Widen a mask vector type to a minimum of v8i1/v16i1 to allow use of KSHIFT
4030 /// and bitcast with integer types.
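/// e.g. v4i1 becomes v8i1 with DQI and v16i1 without it; v8i1 without DQI
/// becomes v16i1; v16i1 and wider types are returned unchanged.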
4031 static MVT widenMaskVectorType(MVT VT, const X86Subtarget &Subtarget) {
4032   assert(VT.getVectorElementType() == MVT::i1 && "Expected bool vector");
4033   unsigned NumElts = VT.getVectorNumElements();
4034   if ((!Subtarget.hasDQI() && NumElts == 8) || NumElts < 8)
4035     return Subtarget.hasDQI() ? MVT::v8i1 : MVT::v16i1;
4036   return VT;
4037 }
4038 
4039 /// Widen a mask vector to a minimum of v8i1/v16i1 to allow use of KSHIFT and
4040 /// bitcast with integer types.
4041 static SDValue widenMaskVector(SDValue Vec, bool ZeroNewElements,
4042                                const X86Subtarget &Subtarget, SelectionDAG &DAG,
4043                                const SDLoc &dl) {
4044   MVT VT = widenMaskVectorType(Vec.getSimpleValueType(), Subtarget);
4045   return widenSubVector(VT, Vec, ZeroNewElements, Subtarget, DAG, dl);
4046 }
4047 
4048 // Helper function to collect subvector ops that are concatenated together,
4049 // either by ISD::CONCAT_VECTORS or an ISD::INSERT_SUBVECTOR series.
4050 // The subvectors in Ops are guaranteed to be the same type.
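// e.g. concat_vectors(a, b) yields {a, b}; insert_subvector(undef, x, 0) of a
// double-width type yields {x, undef}; and the chain
// insert_subvector(insert_subvector(undef, x, 0), y, N/2) yields {x, y}.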
4051 static bool collectConcatOps(SDNode *N, SmallVectorImpl<SDValue> &Ops,
4052                              SelectionDAG &DAG) {
4053   assert(Ops.empty() && "Expected an empty ops vector");
4054 
4055   if (N->getOpcode() == ISD::CONCAT_VECTORS) {
4056     Ops.append(N->op_begin(), N->op_end());
4057     return true;
4058   }
4059 
4060   if (N->getOpcode() == ISD::INSERT_SUBVECTOR) {
4061     SDValue Src = N->getOperand(0);
4062     SDValue Sub = N->getOperand(1);
4063     const APInt &Idx = N->getConstantOperandAPInt(2);
4064     EVT VT = Src.getValueType();
4065     EVT SubVT = Sub.getValueType();
4066 
4067     if (VT.getSizeInBits() == (SubVT.getSizeInBits() * 2)) {
4068       // insert_subvector(undef, x, lo)
4069       if (Idx == 0 && Src.isUndef()) {
4070         Ops.push_back(Sub);
4071         Ops.push_back(DAG.getUNDEF(SubVT));
4072         return true;
4073       }
4074       if (Idx == (VT.getVectorNumElements() / 2)) {
4075         // insert_subvector(insert_subvector(undef, x, lo), y, hi)
4076         if (Src.getOpcode() == ISD::INSERT_SUBVECTOR &&
4077             Src.getOperand(1).getValueType() == SubVT &&
4078             isNullConstant(Src.getOperand(2))) {
4079           // Attempt to recurse into inner (matching) concats.
4080           SDValue Lo = Src.getOperand(1);
4081           SDValue Hi = Sub;
4082           SmallVector<SDValue, 2> LoOps, HiOps;
4083           if (collectConcatOps(Lo.getNode(), LoOps, DAG) &&
4084               collectConcatOps(Hi.getNode(), HiOps, DAG) &&
4085               LoOps.size() == HiOps.size()) {
4086             Ops.append(LoOps);
4087             Ops.append(HiOps);
4088             return true;
4089           }
4090           Ops.push_back(Lo);
4091           Ops.push_back(Hi);
4092           return true;
4093         }
4094         // insert_subvector(x, extract_subvector(x, lo), hi)
4095         if (Sub.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
4096             Sub.getOperand(0) == Src && isNullConstant(Sub.getOperand(1))) {
4097           Ops.append(2, Sub);
4098           return true;
4099         }
4100         // insert_subvector(undef, x, hi)
4101         if (Src.isUndef()) {
4102           Ops.push_back(DAG.getUNDEF(SubVT));
4103           Ops.push_back(Sub);
4104           return true;
4105         }
4106       }
4107     }
4108   }
4109 
4110   return false;
4111 }
4112 
4113 // Helper to check if \p V can be split into subvectors and the upper subvectors
4114 // are all undef, in which case return the lower subvector.
4115 static SDValue isUpperSubvectorUndef(SDValue V, const SDLoc &DL,
4116                                      SelectionDAG &DAG) {
4117   SmallVector<SDValue> SubOps;
4118   if (!collectConcatOps(V.getNode(), SubOps, DAG))
4119     return SDValue();
4120 
4121   unsigned NumSubOps = SubOps.size();
4122   unsigned HalfNumSubOps = NumSubOps / 2;
4123   assert((NumSubOps % 2) == 0 && "Unexpected number of subvectors");
4124 
4125   ArrayRef<SDValue> UpperOps(SubOps.begin() + HalfNumSubOps, SubOps.end());
4126   if (any_of(UpperOps, [](SDValue Op) { return !Op.isUndef(); }))
4127     return SDValue();
4128 
4129   EVT HalfVT = V.getValueType().getHalfNumVectorElementsVT(*DAG.getContext());
4130   ArrayRef<SDValue> LowerOps(SubOps.begin(), SubOps.begin() + HalfNumSubOps);
4131   return DAG.getNode(ISD::CONCAT_VECTORS, DL, HalfVT, LowerOps);
4132 }
4133 
4134 // Helper to check if we can access all the constituent subvectors without any
4135 // extract ops.
4136 static bool isFreeToSplitVector(SDNode *N, SelectionDAG &DAG) {
4137   SmallVector<SDValue> Ops;
4138   return collectConcatOps(N, Ops, DAG);
4139 }
4140 
4141 static std::pair<SDValue, SDValue> splitVector(SDValue Op, SelectionDAG &DAG,
4142                                                const SDLoc &dl) {
4143   EVT VT = Op.getValueType();
4144   unsigned NumElems = VT.getVectorNumElements();
4145   unsigned SizeInBits = VT.getSizeInBits();
4146   assert((NumElems % 2) == 0 && (SizeInBits % 2) == 0 &&
4147          "Can't split odd sized vector");
4148 
4149   // If this is a splat value (with no-undefs) then use the lower subvector,
4150   // which should be a free extraction.
4151   SDValue Lo = extractSubVector(Op, 0, DAG, dl, SizeInBits / 2);
4152   if (DAG.isSplatValue(Op, /*AllowUndefs*/ false))
4153     return std::make_pair(Lo, Lo);
4154 
4155   SDValue Hi = extractSubVector(Op, NumElems / 2, DAG, dl, SizeInBits / 2);
4156   return std::make_pair(Lo, Hi);
4157 }
4158 
4159 /// Break an operation into 2 half sized ops and then concatenate the results.
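/// e.g. a v16i32 ADD is lowered to two v8i32 ADDs whose results are
/// concatenated; non-vector operands are passed unchanged to both halves.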
4160 static SDValue splitVectorOp(SDValue Op, SelectionDAG &DAG, const SDLoc &dl) {
4161   unsigned NumOps = Op.getNumOperands();
4162   EVT VT = Op.getValueType();
4163 
4164   // Split each operand into Lo/Hi vectors; non-vector operands are passed through.
4165   SmallVector<SDValue> LoOps(NumOps, SDValue());
4166   SmallVector<SDValue> HiOps(NumOps, SDValue());
4167   for (unsigned I = 0; I != NumOps; ++I) {
4168     SDValue SrcOp = Op.getOperand(I);
4169     if (!SrcOp.getValueType().isVector()) {
4170       LoOps[I] = HiOps[I] = SrcOp;
4171       continue;
4172     }
4173     std::tie(LoOps[I], HiOps[I]) = splitVector(SrcOp, DAG, dl);
4174   }
4175 
4176   EVT LoVT, HiVT;
4177   std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(VT);
4178   return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT,
4179                      DAG.getNode(Op.getOpcode(), dl, LoVT, LoOps),
4180                      DAG.getNode(Op.getOpcode(), dl, HiVT, HiOps));
4181 }
4182 
4183 /// Break an unary integer operation into 2 half sized ops and then
4184 /// concatenate the result back.
4185 static SDValue splitVectorIntUnary(SDValue Op, SelectionDAG &DAG,
4186                                    const SDLoc &dl) {
4187   // Make sure we only try to split 256/512-bit types to avoid creating
4188   // narrow vectors.
4189   [[maybe_unused]] EVT VT = Op.getValueType();
4190   assert((Op.getOperand(0).getValueType().is256BitVector() ||
4191           Op.getOperand(0).getValueType().is512BitVector()) &&
4192          (VT.is256BitVector() || VT.is512BitVector()) && "Unsupported VT!");
4193   assert(Op.getOperand(0).getValueType().getVectorNumElements() ==
4194              VT.getVectorNumElements() &&
4195          "Unexpected VTs!");
4196   return splitVectorOp(Op, DAG, dl);
4197 }
4198 
4199 /// Break a binary integer operation into 2 half sized ops and then
4200 /// concatenate the result back.
4201 static SDValue splitVectorIntBinary(SDValue Op, SelectionDAG &DAG,
4202                                     const SDLoc &dl) {
4203   // Assert that all the types match.
4204   [[maybe_unused]] EVT VT = Op.getValueType();
4205   assert(Op.getOperand(0).getValueType() == VT &&
4206          Op.getOperand(1).getValueType() == VT && "Unexpected VTs!");
4207   assert((VT.is256BitVector() || VT.is512BitVector()) && "Unsupported VT!");
4208   return splitVectorOp(Op, DAG, dl);
4209 }
4210 
4211 // Helper for splitting operands of an operation to legal target size and
4212 // applying a function on each part.
4213 // Useful for operations that are available on SSE2 in 128-bit, on AVX2 in
4214 // 256-bit and on AVX512BW in 512-bit. The argument VT is the type used for
4215 // deciding if/how to split Ops. Ops elements do *not* have to be of type VT.
4216 // The argument Builder is a function that will be applied on each split part:
4217 // SDValue Builder(SelectionDAG&G, SDLoc, ArrayRef<SDValue>)
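//
// A minimal usage sketch (hypothetical call site; Op0/Op1 assumed to be wide
// vector operands of type VT):
//   auto AddBuilder = [](SelectionDAG &G, const SDLoc &DL,
//                        ArrayRef<SDValue> Ops) {
//     return G.getNode(ISD::ADD, DL, Ops[0].getValueType(), Ops);
//   };
//   SDValue Res =
//       SplitOpsAndApply(DAG, Subtarget, DL, VT, {Op0, Op1}, AddBuilder);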
4218 template <typename F>
4219 SDValue SplitOpsAndApply(SelectionDAG &DAG, const X86Subtarget &Subtarget,
4220                          const SDLoc &DL, EVT VT, ArrayRef<SDValue> Ops,
4221                          F Builder, bool CheckBWI = true) {
4222   assert(Subtarget.hasSSE2() && "Target assumed to support at least SSE2");
4223   unsigned NumSubs = 1;
4224   if ((CheckBWI && Subtarget.useBWIRegs()) ||
4225       (!CheckBWI && Subtarget.useAVX512Regs())) {
4226     if (VT.getSizeInBits() > 512) {
4227       NumSubs = VT.getSizeInBits() / 512;
4228       assert((VT.getSizeInBits() % 512) == 0 && "Illegal vector size");
4229     }
4230   } else if (Subtarget.hasAVX2()) {
4231     if (VT.getSizeInBits() > 256) {
4232       NumSubs = VT.getSizeInBits() / 256;
4233       assert((VT.getSizeInBits() % 256) == 0 && "Illegal vector size");
4234     }
4235   } else {
4236     if (VT.getSizeInBits() > 128) {
4237       NumSubs = VT.getSizeInBits() / 128;
4238       assert((VT.getSizeInBits() % 128) == 0 && "Illegal vector size");
4239     }
4240   }
4241 
4242   if (NumSubs == 1)
4243     return Builder(DAG, DL, Ops);
4244 
4245   SmallVector<SDValue, 4> Subs;
4246   for (unsigned i = 0; i != NumSubs; ++i) {
4247     SmallVector<SDValue, 2> SubOps;
4248     for (SDValue Op : Ops) {
4249       EVT OpVT = Op.getValueType();
4250       unsigned NumSubElts = OpVT.getVectorNumElements() / NumSubs;
4251       unsigned SizeSub = OpVT.getSizeInBits() / NumSubs;
4252       SubOps.push_back(extractSubVector(Op, i * NumSubElts, DAG, DL, SizeSub));
4253     }
4254     Subs.push_back(Builder(DAG, DL, SubOps));
4255   }
4256   return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Subs);
4257 }
4258 
4259 // Helper function that extends a non-512-bit vector op to 512-bits on non-VLX
4260 // targets.
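// e.g. on an AVX512F target without VLX, a v8i32 node is widened to v16i32,
// the 512-bit op is performed, and the low 256 bits are extracted; splatted
// 32/64-bit constant operands are rebuilt at the wide type so they can still
// be folded as broadcasts.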
4261 static SDValue getAVX512Node(unsigned Opcode, const SDLoc &DL, MVT VT,
4262                              ArrayRef<SDValue> Ops, SelectionDAG &DAG,
4263                              const X86Subtarget &Subtarget) {
4264   assert(Subtarget.hasAVX512() && "AVX512 target expected");
4265   MVT SVT = VT.getScalarType();
4266 
4267   // If we have a 32/64 splatted constant, splat it to DstTy to
4268   // encourage a foldable broadcast'd operand.
4269   auto MakeBroadcastOp = [&](SDValue Op, MVT OpVT, MVT DstVT) {
4270     unsigned OpEltSizeInBits = OpVT.getScalarSizeInBits();
4271     // AVX512 broadcasts 32/64-bit operands.
4272     // TODO: Support float once getAVX512Node is used by fp-ops.
4273     if (!OpVT.isInteger() || OpEltSizeInBits < 32 ||
4274         !DAG.getTargetLoweringInfo().isTypeLegal(SVT))
4275       return SDValue();
4276     // If we're not widening, don't bother if we're not bitcasting.
4277     if (OpVT == DstVT && Op.getOpcode() != ISD::BITCAST)
4278       return SDValue();
4279     if (auto *BV = dyn_cast<BuildVectorSDNode>(peekThroughBitcasts(Op))) {
4280       APInt SplatValue, SplatUndef;
4281       unsigned SplatBitSize;
4282       bool HasAnyUndefs;
4283       if (BV->isConstantSplat(SplatValue, SplatUndef, SplatBitSize,
4284                               HasAnyUndefs, OpEltSizeInBits) &&
4285           !HasAnyUndefs && SplatValue.getBitWidth() == OpEltSizeInBits)
4286         return DAG.getConstant(SplatValue, DL, DstVT);
4287     }
4288     return SDValue();
4289   };
4290 
4291   bool Widen = !(Subtarget.hasVLX() || VT.is512BitVector());
4292 
4293   MVT DstVT = VT;
4294   if (Widen)
4295     DstVT = MVT::getVectorVT(SVT, 512 / SVT.getSizeInBits());
4296 
4297   // Canonicalize src operands.
4298   SmallVector<SDValue> SrcOps(Ops.begin(), Ops.end());
4299   for (SDValue &Op : SrcOps) {
4300     MVT OpVT = Op.getSimpleValueType();
4301     // Just pass through scalar operands.
4302     if (!OpVT.isVector())
4303       continue;
4304     assert(OpVT == VT && "Vector type mismatch");
4305 
4306     if (SDValue BroadcastOp = MakeBroadcastOp(Op, OpVT, DstVT)) {
4307       Op = BroadcastOp;
4308       continue;
4309     }
4310 
4311     // Just widen the subvector by inserting into an undef wide vector.
4312     if (Widen)
4313       Op = widenSubVector(Op, false, Subtarget, DAG, DL, 512);
4314   }
4315 
4316   SDValue Res = DAG.getNode(Opcode, DL, DstVT, SrcOps);
4317 
4318   // Perform the 512-bit op then extract the bottom subvector.
4319   if (Widen)
4320     Res = extractSubVector(Res, 0, DAG, DL, VT.getSizeInBits());
4321   return Res;
4322 }
4323 
4324 /// Insert i1-subvector to i1-vector.
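/// The general strategy below is to widen to a type with a legal KSHIFT,
/// clear the destination bits with kshift pairs and/or an AND mask, shift the
/// zero-extended subvector into place, and OR the pieces back together. For
/// illustration, inserting a v8i1 at index 0 of a v16i1 uses kshiftr/kshiftl
/// by 8 to clear the low bits before the OR.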
4325 static SDValue insert1BitVector(SDValue Op, SelectionDAG &DAG,
4326                                 const X86Subtarget &Subtarget) {
4327 
4328   SDLoc dl(Op);
4329   SDValue Vec = Op.getOperand(0);
4330   SDValue SubVec = Op.getOperand(1);
4331   SDValue Idx = Op.getOperand(2);
4332   unsigned IdxVal = Op.getConstantOperandVal(2);
4333 
4334   // Inserting undef is a nop. We can just return the original vector.
4335   if (SubVec.isUndef())
4336     return Vec;
4337 
4338   if (IdxVal == 0 && Vec.isUndef()) // the operation is legal
4339     return Op;
4340 
4341   MVT OpVT = Op.getSimpleValueType();
4342   unsigned NumElems = OpVT.getVectorNumElements();
4343   SDValue ZeroIdx = DAG.getIntPtrConstant(0, dl);
4344 
4345   // Extend to natively supported kshift.
4346   MVT WideOpVT = widenMaskVectorType(OpVT, Subtarget);
4347 
4348   // Inserting into the lsbs of a zero vector is legal. ISel will insert shifts
4349   // if necessary.
4350   if (IdxVal == 0 && ISD::isBuildVectorAllZeros(Vec.getNode())) {
4351     // May need to promote to a legal type.
4352     Op = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, WideOpVT,
4353                      DAG.getConstant(0, dl, WideOpVT),
4354                      SubVec, Idx);
4355     return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, OpVT, Op, ZeroIdx);
4356   }
4357 
4358   MVT SubVecVT = SubVec.getSimpleValueType();
4359   unsigned SubVecNumElems = SubVecVT.getVectorNumElements();
4360   assert(IdxVal + SubVecNumElems <= NumElems &&
4361          IdxVal % SubVecVT.getSizeInBits() == 0 &&
4362          "Unexpected index value in INSERT_SUBVECTOR");
4363 
4364   SDValue Undef = DAG.getUNDEF(WideOpVT);
4365 
4366   if (IdxVal == 0) {
4367     // Zero the lower bits of Vec.
4368     SDValue ShiftBits = DAG.getTargetConstant(SubVecNumElems, dl, MVT::i8);
4369     Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, WideOpVT, Undef, Vec,
4370                       ZeroIdx);
4371     Vec = DAG.getNode(X86ISD::KSHIFTR, dl, WideOpVT, Vec, ShiftBits);
4372     Vec = DAG.getNode(X86ISD::KSHIFTL, dl, WideOpVT, Vec, ShiftBits);
4373     // Merge them together; SubVec should be zero extended.
4374     SubVec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, WideOpVT,
4375                          DAG.getConstant(0, dl, WideOpVT),
4376                          SubVec, ZeroIdx);
4377     Op = DAG.getNode(ISD::OR, dl, WideOpVT, Vec, SubVec);
4378     return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, OpVT, Op, ZeroIdx);
4379   }
4380 
4381   SubVec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, WideOpVT,
4382                        Undef, SubVec, ZeroIdx);
4383 
4384   if (Vec.isUndef()) {
4385     assert(IdxVal != 0 && "Unexpected index");
4386     SubVec = DAG.getNode(X86ISD::KSHIFTL, dl, WideOpVT, SubVec,
4387                          DAG.getTargetConstant(IdxVal, dl, MVT::i8));
4388     return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, OpVT, SubVec, ZeroIdx);
4389   }
4390 
4391   if (ISD::isBuildVectorAllZeros(Vec.getNode())) {
4392     assert(IdxVal != 0 && "Unexpected index");
4393     // If upper elements of Vec are known undef, then just shift into place.
4394     if (llvm::all_of(Vec->ops().slice(IdxVal + SubVecNumElems),
4395                      [](SDValue V) { return V.isUndef(); })) {
4396       SubVec = DAG.getNode(X86ISD::KSHIFTL, dl, WideOpVT, SubVec,
4397                            DAG.getTargetConstant(IdxVal, dl, MVT::i8));
4398     } else {
4399       NumElems = WideOpVT.getVectorNumElements();
4400       unsigned ShiftLeft = NumElems - SubVecNumElems;
4401       unsigned ShiftRight = NumElems - SubVecNumElems - IdxVal;
4402       SubVec = DAG.getNode(X86ISD::KSHIFTL, dl, WideOpVT, SubVec,
4403                            DAG.getTargetConstant(ShiftLeft, dl, MVT::i8));
4404       if (ShiftRight != 0)
4405         SubVec = DAG.getNode(X86ISD::KSHIFTR, dl, WideOpVT, SubVec,
4406                              DAG.getTargetConstant(ShiftRight, dl, MVT::i8));
4407     }
4408     return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, OpVT, SubVec, ZeroIdx);
4409   }
4410 
4411   // Simple case: we put the subvector in the upper part.
4412   if (IdxVal + SubVecNumElems == NumElems) {
4413     SubVec = DAG.getNode(X86ISD::KSHIFTL, dl, WideOpVT, SubVec,
4414                          DAG.getTargetConstant(IdxVal, dl, MVT::i8));
4415     if (SubVecNumElems * 2 == NumElems) {
4416       // Special case, use legal zero extending insert_subvector. This allows
4417       // isel to optimize when bits are known zero.
4418       Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, SubVecVT, Vec, ZeroIdx);
4419       Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, WideOpVT,
4420                         DAG.getConstant(0, dl, WideOpVT),
4421                         Vec, ZeroIdx);
4422     } else {
4423       // Otherwise use explicit shifts to zero the bits.
4424       Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, WideOpVT,
4425                         Undef, Vec, ZeroIdx);
4426       NumElems = WideOpVT.getVectorNumElements();
4427       SDValue ShiftBits = DAG.getTargetConstant(NumElems - IdxVal, dl, MVT::i8);
4428       Vec = DAG.getNode(X86ISD::KSHIFTL, dl, WideOpVT, Vec, ShiftBits);
4429       Vec = DAG.getNode(X86ISD::KSHIFTR, dl, WideOpVT, Vec, ShiftBits);
4430     }
4431     Op = DAG.getNode(ISD::OR, dl, WideOpVT, Vec, SubVec);
4432     return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, OpVT, Op, ZeroIdx);
4433   }
4434 
4435   // Inserting into the middle is more complicated.
4436 
4437   NumElems = WideOpVT.getVectorNumElements();
4438 
4439   // Widen the vector if needed.
4440   Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, WideOpVT, Undef, Vec, ZeroIdx);
4441 
4442   unsigned ShiftLeft = NumElems - SubVecNumElems;
4443   unsigned ShiftRight = NumElems - SubVecNumElems - IdxVal;
4444 
4445   // Do an optimization for the most frequently used types.
4446   if (WideOpVT != MVT::v64i1 || Subtarget.is64Bit()) {
4447     APInt Mask0 = APInt::getBitsSet(NumElems, IdxVal, IdxVal + SubVecNumElems);
4448     Mask0.flipAllBits();
4449     SDValue CMask0 = DAG.getConstant(Mask0, dl, MVT::getIntegerVT(NumElems));
4450     SDValue VMask0 = DAG.getNode(ISD::BITCAST, dl, WideOpVT, CMask0);
4451     Vec = DAG.getNode(ISD::AND, dl, WideOpVT, Vec, VMask0);
4452     SubVec = DAG.getNode(X86ISD::KSHIFTL, dl, WideOpVT, SubVec,
4453                          DAG.getTargetConstant(ShiftLeft, dl, MVT::i8));
4454     SubVec = DAG.getNode(X86ISD::KSHIFTR, dl, WideOpVT, SubVec,
4455                          DAG.getTargetConstant(ShiftRight, dl, MVT::i8));
4456     Op = DAG.getNode(ISD::OR, dl, WideOpVT, Vec, SubVec);
4457 
4458     // Reduce to original width if needed.
4459     return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, OpVT, Op, ZeroIdx);
4460   }
4461 
4462   // Clear the upper bits of the subvector and move it to its insert position.
4463   SubVec = DAG.getNode(X86ISD::KSHIFTL, dl, WideOpVT, SubVec,
4464                        DAG.getTargetConstant(ShiftLeft, dl, MVT::i8));
4465   SubVec = DAG.getNode(X86ISD::KSHIFTR, dl, WideOpVT, SubVec,
4466                        DAG.getTargetConstant(ShiftRight, dl, MVT::i8));
4467 
4468   // Isolate the bits below the insertion point.
4469   unsigned LowShift = NumElems - IdxVal;
4470   SDValue Low = DAG.getNode(X86ISD::KSHIFTL, dl, WideOpVT, Vec,
4471                             DAG.getTargetConstant(LowShift, dl, MVT::i8));
4472   Low = DAG.getNode(X86ISD::KSHIFTR, dl, WideOpVT, Low,
4473                     DAG.getTargetConstant(LowShift, dl, MVT::i8));
4474 
4475   // Isolate the bits after the last inserted bit.
4476   unsigned HighShift = IdxVal + SubVecNumElems;
4477   SDValue High = DAG.getNode(X86ISD::KSHIFTR, dl, WideOpVT, Vec,
4478                             DAG.getTargetConstant(HighShift, dl, MVT::i8));
4479   High = DAG.getNode(X86ISD::KSHIFTL, dl, WideOpVT, High,
4480                     DAG.getTargetConstant(HighShift, dl, MVT::i8));
4481 
4482   // Now OR all 3 pieces together.
4483   Vec = DAG.getNode(ISD::OR, dl, WideOpVT, Low, High);
4484   SubVec = DAG.getNode(ISD::OR, dl, WideOpVT, SubVec, Vec);
4485 
4486   // Reduce to original width if needed.
4487   return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, OpVT, SubVec, ZeroIdx);
4488 }
4489 
4490 static SDValue concatSubVectors(SDValue V1, SDValue V2, SelectionDAG &DAG,
4491                                 const SDLoc &dl) {
4492   assert(V1.getValueType() == V2.getValueType() && "subvector type mismatch");
4493   EVT SubVT = V1.getValueType();
4494   EVT SubSVT = SubVT.getScalarType();
4495   unsigned SubNumElts = SubVT.getVectorNumElements();
4496   unsigned SubVectorWidth = SubVT.getSizeInBits();
4497   EVT VT = EVT::getVectorVT(*DAG.getContext(), SubSVT, 2 * SubNumElts);
4498   SDValue V = insertSubVector(DAG.getUNDEF(VT), V1, 0, DAG, dl, SubVectorWidth);
4499   return insertSubVector(V, V2, SubNumElts, DAG, dl, SubVectorWidth);
4500 }
4501 
4502 /// Returns a vector of specified type with all bits set.
4503 /// Always build ones vectors as <4 x i32>, <8 x i32> or <16 x i32>.
4504 /// Then bitcast to their original type, ensuring they get CSE'd.
4505 static SDValue getOnesVector(EVT VT, SelectionDAG &DAG, const SDLoc &dl) {
4506   assert((VT.is128BitVector() || VT.is256BitVector() || VT.is512BitVector()) &&
4507          "Expected a 128/256/512-bit vector type");
4508   unsigned NumElts = VT.getSizeInBits() / 32;
4509   SDValue Vec = DAG.getAllOnesConstant(dl, MVT::getVectorVT(MVT::i32, NumElts));
4510   return DAG.getBitcast(VT, Vec);
4511 }
4512 
4513 static SDValue getEXTEND_VECTOR_INREG(unsigned Opcode, const SDLoc &DL, EVT VT,
4514                                       SDValue In, SelectionDAG &DAG) {
4515   EVT InVT = In.getValueType();
4516   assert(VT.isVector() && InVT.isVector() && "Expected vector VTs.");
4517   assert((ISD::ANY_EXTEND == Opcode || ISD::SIGN_EXTEND == Opcode ||
4518           ISD::ZERO_EXTEND == Opcode) &&
4519          "Unknown extension opcode");
4520 
4521   // For 256-bit vectors, we only need the lower (128-bit) input half.
4522   // For 512-bit vectors, we only need the lower input half or quarter.
4523   if (InVT.getSizeInBits() > 128) {
4524     assert(VT.getSizeInBits() == InVT.getSizeInBits() &&
4525            "Expected VTs to be the same size!");
4526     unsigned Scale = VT.getScalarSizeInBits() / InVT.getScalarSizeInBits();
4527     In = extractSubVector(In, 0, DAG, DL,
4528                           std::max(128U, (unsigned)VT.getSizeInBits() / Scale));
4529     InVT = In.getValueType();
4530   }
4531 
4532   if (VT.getVectorNumElements() != InVT.getVectorNumElements())
4533     Opcode = DAG.getOpcode_EXTEND_VECTOR_INREG(Opcode);
4534 
4535   return DAG.getNode(Opcode, DL, VT, In);
4536 }
4537 
4538 // Create OR(AND(LHS,MASK),AND(RHS,~MASK)) bit select pattern
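// Note: X86ISD::ANDNP inverts its first operand, i.e. ANDNP(Mask, RHS) == AND(~Mask, RHS).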
4539 static SDValue getBitSelect(const SDLoc &DL, MVT VT, SDValue LHS, SDValue RHS,
4540                             SDValue Mask, SelectionDAG &DAG) {
4541   LHS = DAG.getNode(ISD::AND, DL, VT, LHS, Mask);
4542   RHS = DAG.getNode(X86ISD::ANDNP, DL, VT, Mask, RHS);
4543   return DAG.getNode(ISD::OR, DL, VT, LHS, RHS);
4544 }
4545 
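/// Generate an unpacklo/unpackhi shuffle mask, e.g. for v8i16:
///   Lo, binary --> <0, 8, 1, 9, 2, 10, 3, 11>
///   Hi, binary --> <4, 12, 5, 13, 6, 14, 7, 15>
///   Lo, unary  --> <0, 0, 1, 1, 2, 2, 3, 3>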
4546 void llvm::createUnpackShuffleMask(EVT VT, SmallVectorImpl<int> &Mask,
4547                                    bool Lo, bool Unary) {
4548   assert(VT.getScalarType().isSimple() && (VT.getSizeInBits() % 128) == 0 &&
4549          "Illegal vector type to unpack");
4550   assert(Mask.empty() && "Expected an empty shuffle mask vector");
4551   int NumElts = VT.getVectorNumElements();
4552   int NumEltsInLane = 128 / VT.getScalarSizeInBits();
4553   for (int i = 0; i < NumElts; ++i) {
4554     unsigned LaneStart = (i / NumEltsInLane) * NumEltsInLane;
4555     int Pos = (i % NumEltsInLane) / 2 + LaneStart;
4556     Pos += (Unary ? 0 : NumElts * (i % 2));
4557     Pos += (Lo ? 0 : NumEltsInLane / 2);
4558     Mask.push_back(Pos);
4559   }
4560 }
4561 
4562 /// Similar to unpacklo/unpackhi, but without the 128-bit lane limitation
4563 /// imposed by AVX and specific to the unary pattern. Example:
4564 /// v8iX Lo --> <0, 0, 1, 1, 2, 2, 3, 3>
4565 /// v8iX Hi --> <4, 4, 5, 5, 6, 6, 7, 7>
4566 void llvm::createSplat2ShuffleMask(MVT VT, SmallVectorImpl<int> &Mask,
4567                                    bool Lo) {
4568   assert(Mask.empty() && "Expected an empty shuffle mask vector");
4569   int NumElts = VT.getVectorNumElements();
4570   for (int i = 0; i < NumElts; ++i) {
4571     int Pos = i / 2;
4572     Pos += (Lo ? 0 : NumElts / 2);
4573     Mask.push_back(Pos);
4574   }
4575 }
4576 
4577 // Attempt to constant fold, else just create a VECTOR_SHUFFLE.
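// If both inputs are constant build vectors (or undef), the shuffled result is
// built directly as a BUILD_VECTOR so it can be constant folded and CSE'd.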
4578 static SDValue getVectorShuffle(SelectionDAG &DAG, EVT VT, const SDLoc &dl,
4579                                 SDValue V1, SDValue V2, ArrayRef<int> Mask) {
4580   if ((ISD::isBuildVectorOfConstantSDNodes(V1.getNode()) || V1.isUndef()) &&
4581       (ISD::isBuildVectorOfConstantSDNodes(V2.getNode()) || V2.isUndef())) {
4582     SmallVector<SDValue> Ops(Mask.size(), DAG.getUNDEF(VT.getScalarType()));
4583     for (int I = 0, NumElts = Mask.size(); I != NumElts; ++I) {
4584       int M = Mask[I];
4585       if (M < 0)
4586         continue;
4587       SDValue V = (M < NumElts) ? V1 : V2;
4588       if (V.isUndef())
4589         continue;
4590       Ops[I] = V.getOperand(M % NumElts);
4591     }
4592     return DAG.getBuildVector(VT, dl, Ops);
4593   }
4594 
4595   return DAG.getVectorShuffle(VT, dl, V1, V2, Mask);
4596 }
4597 
4598 /// Returns a vector_shuffle node for an unpackl operation.
4599 static SDValue getUnpackl(SelectionDAG &DAG, const SDLoc &dl, EVT VT,
4600                           SDValue V1, SDValue V2) {
4601   SmallVector<int, 8> Mask;
4602   createUnpackShuffleMask(VT, Mask, /* Lo = */ true, /* Unary = */ false);
4603   return getVectorShuffle(DAG, VT, dl, V1, V2, Mask);
4604 }
4605 
4606 /// Returns a vector_shuffle node for an unpackh operation.
4607 static SDValue getUnpackh(SelectionDAG &DAG, const SDLoc &dl, EVT VT,
4608                           SDValue V1, SDValue V2) {
4609   SmallVector<int, 8> Mask;
4610   createUnpackShuffleMask(VT, Mask, /* Lo = */ false, /* Unary = */ false);
4611   return getVectorShuffle(DAG, VT, dl, V1, V2, Mask);
4612 }
4613 
4614 /// Returns a node that packs the LHS + RHS nodes together at half width.
4615 /// May return X86ISD::PACKSS/PACKUS, packing the top/bottom half.
4616 /// TODO: Add subvector splitting if/when we have a need for it.
4617 static SDValue getPack(SelectionDAG &DAG, const X86Subtarget &Subtarget,
4618                        const SDLoc &dl, MVT VT, SDValue LHS, SDValue RHS,
4619                        bool PackHiHalf = false) {
4620   MVT OpVT = LHS.getSimpleValueType();
4621   unsigned EltSizeInBits = VT.getScalarSizeInBits();
4622   bool UsePackUS = Subtarget.hasSSE41() || EltSizeInBits == 8;
4623   assert(OpVT == RHS.getSimpleValueType() &&
4624          VT.getSizeInBits() == OpVT.getSizeInBits() &&
4625          (EltSizeInBits * 2) == OpVT.getScalarSizeInBits() &&
4626          "Unexpected PACK operand types");
4627   assert((EltSizeInBits == 8 || EltSizeInBits == 16 || EltSizeInBits == 32) &&
4628          "Unexpected PACK result type");
4629 
4630   // Rely on vector shuffles for vXi64 -> vXi32 packing.
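  // e.g. for a v4i32 result (two v2i64 inputs), the mask is <0, 2, 4, 6> for
  // the low halves (PackHiHalf == false) and <1, 3, 5, 7> for the high halves.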
4631   if (EltSizeInBits == 32) {
4632     SmallVector<int> PackMask;
4633     int Offset = PackHiHalf ? 1 : 0;
4634     int NumElts = VT.getVectorNumElements();
4635     for (int I = 0; I != NumElts; I += 4) {
4636       PackMask.push_back(I + Offset);
4637       PackMask.push_back(I + Offset + 2);
4638       PackMask.push_back(I + Offset + NumElts);
4639       PackMask.push_back(I + Offset + NumElts + 2);
4640     }
4641     return DAG.getVectorShuffle(VT, dl, DAG.getBitcast(VT, LHS),
4642                                 DAG.getBitcast(VT, RHS), PackMask);
4643   }
4644 
4645   // See if we already have sufficient leading bits for PACKSS/PACKUS.
4646   if (!PackHiHalf) {
4647     if (UsePackUS &&
4648         DAG.computeKnownBits(LHS).countMaxActiveBits() <= EltSizeInBits &&
4649         DAG.computeKnownBits(RHS).countMaxActiveBits() <= EltSizeInBits)
4650       return DAG.getNode(X86ISD::PACKUS, dl, VT, LHS, RHS);
4651 
4652     if (DAG.ComputeMaxSignificantBits(LHS) <= EltSizeInBits &&
4653         DAG.ComputeMaxSignificantBits(RHS) <= EltSizeInBits)
4654       return DAG.getNode(X86ISD::PACKSS, dl, VT, LHS, RHS);
4655   }
4656 
4657   // Fall back to sign/zero-extending the requested half, then pack.
4658   SDValue Amt = DAG.getTargetConstant(EltSizeInBits, dl, MVT::i8);
4659   if (UsePackUS) {
4660     if (PackHiHalf) {
4661       LHS = DAG.getNode(X86ISD::VSRLI, dl, OpVT, LHS, Amt);
4662       RHS = DAG.getNode(X86ISD::VSRLI, dl, OpVT, RHS, Amt);
4663     } else {
4664       SDValue Mask = DAG.getConstant((1ULL << EltSizeInBits) - 1, dl, OpVT);
4665       LHS = DAG.getNode(ISD::AND, dl, OpVT, LHS, Mask);
4666       RHS = DAG.getNode(ISD::AND, dl, OpVT, RHS, Mask);
4667     }
4668     return DAG.getNode(X86ISD::PACKUS, dl, VT, LHS, RHS);
4669   }
4670 
4671   if (!PackHiHalf) {
4672     LHS = DAG.getNode(X86ISD::VSHLI, dl, OpVT, LHS, Amt);
4673     RHS = DAG.getNode(X86ISD::VSHLI, dl, OpVT, RHS, Amt);
4674   }
4675   LHS = DAG.getNode(X86ISD::VSRAI, dl, OpVT, LHS, Amt);
4676   RHS = DAG.getNode(X86ISD::VSRAI, dl, OpVT, RHS, Amt);
4677   return DAG.getNode(X86ISD::PACKSS, dl, VT, LHS, RHS);
4678 }
4679 
4680 /// Return a vector_shuffle of the specified vector and a zero or undef vector.
4681 /// This produces a shuffle where the low element of V2 is swizzled into the
4682 /// zero/undef vector, landing at element Idx.
4683 /// This produces a shuffle mask like 4,1,2,3 (idx=0) or 0,1,2,4 (idx=3).
4684 static SDValue getShuffleVectorZeroOrUndef(SDValue V2, int Idx,
4685                                            bool IsZero,
4686                                            const X86Subtarget &Subtarget,
4687                                            SelectionDAG &DAG) {
4688   MVT VT = V2.getSimpleValueType();
4689   SDValue V1 = IsZero
4690     ? getZeroVector(VT, Subtarget, DAG, SDLoc(V2)) : DAG.getUNDEF(VT);
4691   int NumElems = VT.getVectorNumElements();
4692   SmallVector<int, 16> MaskVec(NumElems);
4693   for (int i = 0; i != NumElems; ++i)
4694     // If this is the insertion idx, put the low elt of V2 here.
4695     MaskVec[i] = (i == Idx) ? NumElems : i;
4696   return DAG.getVectorShuffle(VT, SDLoc(V2), V1, V2, MaskVec);
4697 }
4698 
4699 static ConstantPoolSDNode *getTargetConstantPoolFromBasePtr(SDValue Ptr) {
4700   if (Ptr.getOpcode() == X86ISD::Wrapper ||
4701       Ptr.getOpcode() == X86ISD::WrapperRIP)
4702     Ptr = Ptr.getOperand(0);
4703   return dyn_cast<ConstantPoolSDNode>(Ptr);
4704 }
4705 
4706 // TODO: Add support for non-zero offsets.
4707 static const Constant *getTargetConstantFromBasePtr(SDValue Ptr) {
4708   ConstantPoolSDNode *CNode = getTargetConstantPoolFromBasePtr(Ptr);
4709   if (!CNode || CNode->isMachineConstantPoolEntry() || CNode->getOffset() != 0)
4710     return nullptr;
4711   return CNode->getConstVal();
4712 }
4713 
4714 static const Constant *getTargetConstantFromNode(LoadSDNode *Load) {
4715   if (!Load || !ISD::isNormalLoad(Load))
4716     return nullptr;
4717   return getTargetConstantFromBasePtr(Load->getBasePtr());
4718 }
4719 
4720 static const Constant *getTargetConstantFromNode(SDValue Op) {
4721   Op = peekThroughBitcasts(Op);
4722   return getTargetConstantFromNode(dyn_cast<LoadSDNode>(Op));
4723 }
4724 
4725 const Constant *
4726 X86TargetLowering::getTargetConstantFromLoad(LoadSDNode *LD) const {
4727   assert(LD && "Unexpected null LoadSDNode");
4728   return getTargetConstantFromNode(LD);
4729 }
4730 
4731 // Extract raw constant bits from constant pools.
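// e.g. splitting a v2i64 constant <i64 1, i64 2> at EltSizeInBits == 32 yields
// EltBits == { 1, 0, 2, 0 } (little-endian element order).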
4732 static bool getTargetConstantBitsFromNode(SDValue Op, unsigned EltSizeInBits,
4733                                           APInt &UndefElts,
4734                                           SmallVectorImpl<APInt> &EltBits,
4735                                           bool AllowWholeUndefs = true,
4736                                           bool AllowPartialUndefs = false) {
4737   assert(EltBits.empty() && "Expected an empty EltBits vector");
4738 
4739   Op = peekThroughBitcasts(Op);
4740 
4741   EVT VT = Op.getValueType();
4742   unsigned SizeInBits = VT.getSizeInBits();
4743   assert((SizeInBits % EltSizeInBits) == 0 && "Can't split constant!");
4744   unsigned NumElts = SizeInBits / EltSizeInBits;
4745 
4746   // Bitcast a source array of element bits to the target size.
4747   auto CastBitData = [&](APInt &UndefSrcElts, ArrayRef<APInt> SrcEltBits) {
4748     unsigned NumSrcElts = UndefSrcElts.getBitWidth();
4749     unsigned SrcEltSizeInBits = SrcEltBits[0].getBitWidth();
4750     assert((NumSrcElts * SrcEltSizeInBits) == SizeInBits &&
4751            "Constant bit sizes don't match");
4752 
4753     // Don't split if we don't allow undef bits.
4754     bool AllowUndefs = AllowWholeUndefs || AllowPartialUndefs;
4755     if (UndefSrcElts.getBoolValue() && !AllowUndefs)
4756       return false;
4757 
4758     // If we're already the right size, don't bother bitcasting.
4759     if (NumSrcElts == NumElts) {
4760       UndefElts = UndefSrcElts;
4761       EltBits.assign(SrcEltBits.begin(), SrcEltBits.end());
4762       return true;
4763     }
4764 
4765     // Extract all the undef/constant element data and pack into single bitsets.
4766     APInt UndefBits(SizeInBits, 0);
4767     APInt MaskBits(SizeInBits, 0);
4768 
4769     for (unsigned i = 0; i != NumSrcElts; ++i) {
4770       unsigned BitOffset = i * SrcEltSizeInBits;
4771       if (UndefSrcElts[i])
4772         UndefBits.setBits(BitOffset, BitOffset + SrcEltSizeInBits);
4773       MaskBits.insertBits(SrcEltBits[i], BitOffset);
4774     }
4775 
4776     // Split the undef/constant single bitset data into the target elements.
4777     UndefElts = APInt(NumElts, 0);
4778     EltBits.resize(NumElts, APInt(EltSizeInBits, 0));
4779 
4780     for (unsigned i = 0; i != NumElts; ++i) {
4781       unsigned BitOffset = i * EltSizeInBits;
4782       APInt UndefEltBits = UndefBits.extractBits(EltSizeInBits, BitOffset);
4783 
4784       // Only treat an element as UNDEF if all bits are UNDEF.
4785       if (UndefEltBits.isAllOnes()) {
4786         if (!AllowWholeUndefs)
4787           return false;
4788         UndefElts.setBit(i);
4789         continue;
4790       }
4791 
4792       // If only some bits are UNDEF then treat them as zero (or bail if not
4793       // supported).
4794       if (UndefEltBits.getBoolValue() && !AllowPartialUndefs)
4795         return false;
4796 
4797       EltBits[i] = MaskBits.extractBits(EltSizeInBits, BitOffset);
4798     }
4799     return true;
4800   };
4801 
4802   // Collect constant bits and insert into mask/undef bit masks.
4803   auto CollectConstantBits = [](const Constant *Cst, APInt &Mask, APInt &Undefs,
4804                                 unsigned UndefBitIndex) {
4805     if (!Cst)
4806       return false;
4807     if (isa<UndefValue>(Cst)) {
4808       Undefs.setBit(UndefBitIndex);
4809       return true;
4810     }
4811     if (auto *CInt = dyn_cast<ConstantInt>(Cst)) {
4812       Mask = CInt->getValue();
4813       return true;
4814     }
4815     if (auto *CFP = dyn_cast<ConstantFP>(Cst)) {
4816       Mask = CFP->getValueAPF().bitcastToAPInt();
4817       return true;
4818     }
4819     if (auto *CDS = dyn_cast<ConstantDataSequential>(Cst)) {
4820       Type *Ty = CDS->getType();
4821       Mask = APInt::getZero(Ty->getPrimitiveSizeInBits());
4822       Type *EltTy = CDS->getElementType();
4823       bool IsInteger = EltTy->isIntegerTy();
4824       bool IsFP =
4825           EltTy->isHalfTy() || EltTy->isFloatTy() || EltTy->isDoubleTy();
4826       if (!IsInteger && !IsFP)
4827         return false;
4828       unsigned EltBits = EltTy->getPrimitiveSizeInBits();
4829       for (unsigned I = 0, E = CDS->getNumElements(); I != E; ++I)
4830         if (IsInteger)
4831           Mask.insertBits(CDS->getElementAsAPInt(I), I * EltBits);
4832         else
4833           Mask.insertBits(CDS->getElementAsAPFloat(I).bitcastToAPInt(),
4834                           I * EltBits);
4835       return true;
4836     }
4837     return false;
4838   };
4839 
4840   // Handle UNDEFs.
4841   if (Op.isUndef()) {
4842     APInt UndefSrcElts = APInt::getAllOnes(NumElts);
4843     SmallVector<APInt, 64> SrcEltBits(NumElts, APInt(EltSizeInBits, 0));
4844     return CastBitData(UndefSrcElts, SrcEltBits);
4845   }
4846 
4847   // Extract scalar constant bits.
4848   if (auto *Cst = dyn_cast<ConstantSDNode>(Op)) {
4849     APInt UndefSrcElts = APInt::getZero(1);
4850     SmallVector<APInt, 64> SrcEltBits(1, Cst->getAPIntValue());
4851     return CastBitData(UndefSrcElts, SrcEltBits);
4852   }
4853   if (auto *Cst = dyn_cast<ConstantFPSDNode>(Op)) {
4854     APInt UndefSrcElts = APInt::getZero(1);
4855     APInt RawBits = Cst->getValueAPF().bitcastToAPInt();
4856     SmallVector<APInt, 64> SrcEltBits(1, RawBits);
4857     return CastBitData(UndefSrcElts, SrcEltBits);
4858   }
4859 
4860   // Extract constant bits from build vector.
4861   if (auto *BV = dyn_cast<BuildVectorSDNode>(Op)) {
4862     BitVector Undefs;
4863     SmallVector<APInt> SrcEltBits;
4864     unsigned SrcEltSizeInBits = VT.getScalarSizeInBits();
4865     if (BV->getConstantRawBits(true, SrcEltSizeInBits, SrcEltBits, Undefs)) {
4866       APInt UndefSrcElts = APInt::getZero(SrcEltBits.size());
4867       for (unsigned I = 0, E = SrcEltBits.size(); I != E; ++I)
4868         if (Undefs[I])
4869           UndefSrcElts.setBit(I);
4870       return CastBitData(UndefSrcElts, SrcEltBits);
4871     }
4872   }
4873 
4874   // Extract constant bits from constant pool vector.
4875   if (auto *Cst = getTargetConstantFromNode(Op)) {
4876     Type *CstTy = Cst->getType();
4877     unsigned CstSizeInBits = CstTy->getPrimitiveSizeInBits();
4878     if (!CstTy->isVectorTy() || (CstSizeInBits % SizeInBits) != 0)
4879       return false;
4880 
4881     unsigned SrcEltSizeInBits = CstTy->getScalarSizeInBits();
4882     unsigned NumSrcElts = SizeInBits / SrcEltSizeInBits;
4883     if ((SizeInBits % SrcEltSizeInBits) != 0)
4884       return false;
4885 
4886     APInt UndefSrcElts(NumSrcElts, 0);
4887     SmallVector<APInt, 64> SrcEltBits(NumSrcElts, APInt(SrcEltSizeInBits, 0));
4888     for (unsigned i = 0; i != NumSrcElts; ++i)
4889       if (!CollectConstantBits(Cst->getAggregateElement(i), SrcEltBits[i],
4890                                UndefSrcElts, i))
4891         return false;
4892 
4893     return CastBitData(UndefSrcElts, SrcEltBits);
4894   }
4895 
4896   // Extract constant bits from a broadcasted constant pool scalar.
4897   if (Op.getOpcode() == X86ISD::VBROADCAST_LOAD &&
4898       EltSizeInBits <= VT.getScalarSizeInBits()) {
4899     auto *MemIntr = cast<MemIntrinsicSDNode>(Op);
4900     if (MemIntr->getMemoryVT().getStoreSizeInBits() != VT.getScalarSizeInBits())
4901       return false;
4902 
4903     SDValue Ptr = MemIntr->getBasePtr();
4904     if (const Constant *C = getTargetConstantFromBasePtr(Ptr)) {
4905       unsigned SrcEltSizeInBits = VT.getScalarSizeInBits();
4906       unsigned NumSrcElts = SizeInBits / SrcEltSizeInBits;
4907 
4908       APInt UndefSrcElts(NumSrcElts, 0);
4909       SmallVector<APInt, 64> SrcEltBits(1, APInt(SrcEltSizeInBits, 0));
4910       if (CollectConstantBits(C, SrcEltBits[0], UndefSrcElts, 0)) {
4911         if (UndefSrcElts[0])
4912           UndefSrcElts.setBits(0, NumSrcElts);
4913         if (SrcEltBits[0].getBitWidth() != SrcEltSizeInBits)
4914           SrcEltBits[0] = SrcEltBits[0].trunc(SrcEltSizeInBits);
4915         SrcEltBits.append(NumSrcElts - 1, SrcEltBits[0]);
4916         return CastBitData(UndefSrcElts, SrcEltBits);
4917       }
4918     }
4919   }
4920 
4921   // Extract constant bits from a subvector broadcast.
4922   if (Op.getOpcode() == X86ISD::SUBV_BROADCAST_LOAD) {
4923     auto *MemIntr = cast<MemIntrinsicSDNode>(Op);
4924     SDValue Ptr = MemIntr->getBasePtr();
4925     // The source constant may be larger than the subvector broadcast, so
4926     // ensure we extract the correct subvector constants.
4927     if (const Constant *Cst = getTargetConstantFromBasePtr(Ptr)) {
4928       Type *CstTy = Cst->getType();
4929       unsigned CstSizeInBits = CstTy->getPrimitiveSizeInBits();
4930       unsigned SubVecSizeInBits = MemIntr->getMemoryVT().getStoreSizeInBits();
4931       if (!CstTy->isVectorTy() || (CstSizeInBits % SubVecSizeInBits) != 0 ||
4932           (SizeInBits % SubVecSizeInBits) != 0)
4933         return false;
4934       unsigned CstEltSizeInBits = CstTy->getScalarSizeInBits();
4935       unsigned NumSubElts = SubVecSizeInBits / CstEltSizeInBits;
4936       unsigned NumSubVecs = SizeInBits / SubVecSizeInBits;
4937       APInt UndefSubElts(NumSubElts, 0);
4938       SmallVector<APInt, 64> SubEltBits(NumSubElts * NumSubVecs,
4939                                         APInt(CstEltSizeInBits, 0));
4940       for (unsigned i = 0; i != NumSubElts; ++i) {
4941         if (!CollectConstantBits(Cst->getAggregateElement(i), SubEltBits[i],
4942                                  UndefSubElts, i))
4943           return false;
4944         for (unsigned j = 1; j != NumSubVecs; ++j)
4945           SubEltBits[i + (j * NumSubElts)] = SubEltBits[i];
4946       }
4947       UndefSubElts = APInt::getSplat(NumSubVecs * UndefSubElts.getBitWidth(),
4948                                      UndefSubElts);
4949       return CastBitData(UndefSubElts, SubEltBits);
4950     }
4951   }
4952 
4953   // Extract a rematerialized scalar constant insertion.
4954   if (Op.getOpcode() == X86ISD::VZEXT_MOVL &&
4955       Op.getOperand(0).getOpcode() == ISD::SCALAR_TO_VECTOR &&
4956       isa<ConstantSDNode>(Op.getOperand(0).getOperand(0))) {
4957     unsigned SrcEltSizeInBits = VT.getScalarSizeInBits();
4958     unsigned NumSrcElts = SizeInBits / SrcEltSizeInBits;
4959 
4960     APInt UndefSrcElts(NumSrcElts, 0);
4961     SmallVector<APInt, 64> SrcEltBits;
4962     const APInt &C = Op.getOperand(0).getConstantOperandAPInt(0);
4963     SrcEltBits.push_back(C.zextOrTrunc(SrcEltSizeInBits));
4964     SrcEltBits.append(NumSrcElts - 1, APInt(SrcEltSizeInBits, 0));
4965     return CastBitData(UndefSrcElts, SrcEltBits);
4966   }
4967 
4968   // Insert constant bits from base and subvector sources.
4969   if (Op.getOpcode() == ISD::INSERT_SUBVECTOR) {
4970     // If we bitcast to larger elements we might lose track of undefs - to be
4971     // safe, don't allow any.
4972     unsigned SrcEltSizeInBits = VT.getScalarSizeInBits();
4973     bool AllowUndefs = EltSizeInBits >= SrcEltSizeInBits;
4974 
4975     APInt UndefSrcElts, UndefSubElts;
4976     SmallVector<APInt, 32> EltSrcBits, EltSubBits;
4977     if (getTargetConstantBitsFromNode(Op.getOperand(1), SrcEltSizeInBits,
4978                                       UndefSubElts, EltSubBits,
4979                                       AllowWholeUndefs && AllowUndefs,
4980                                       AllowPartialUndefs && AllowUndefs) &&
4981         getTargetConstantBitsFromNode(Op.getOperand(0), SrcEltSizeInBits,
4982                                       UndefSrcElts, EltSrcBits,
4983                                       AllowWholeUndefs && AllowUndefs,
4984                                       AllowPartialUndefs && AllowUndefs)) {
4985       unsigned BaseIdx = Op.getConstantOperandVal(2);
4986       UndefSrcElts.insertBits(UndefSubElts, BaseIdx);
4987       for (unsigned i = 0, e = EltSubBits.size(); i != e; ++i)
4988         EltSrcBits[BaseIdx + i] = EltSubBits[i];
4989       return CastBitData(UndefSrcElts, EltSrcBits);
4990     }
4991   }
4992 
4993   // Extract constant bits from a subvector's source.
4994   if (Op.getOpcode() == ISD::EXTRACT_SUBVECTOR) {
4995     // TODO - support extract_subvector through bitcasts.
4996     if (EltSizeInBits != VT.getScalarSizeInBits())
4997       return false;
4998 
4999     if (getTargetConstantBitsFromNode(Op.getOperand(0), EltSizeInBits,
5000                                       UndefElts, EltBits, AllowWholeUndefs,
5001                                       AllowPartialUndefs)) {
5002       EVT SrcVT = Op.getOperand(0).getValueType();
5003       unsigned NumSrcElts = SrcVT.getVectorNumElements();
5004       unsigned NumSubElts = VT.getVectorNumElements();
5005       unsigned BaseIdx = Op.getConstantOperandVal(1);
5006       UndefElts = UndefElts.extractBits(NumSubElts, BaseIdx);
5007       if ((BaseIdx + NumSubElts) != NumSrcElts)
5008         EltBits.erase(EltBits.begin() + BaseIdx + NumSubElts, EltBits.end());
5009       if (BaseIdx != 0)
5010         EltBits.erase(EltBits.begin(), EltBits.begin() + BaseIdx);
5011       return true;
5012     }
5013   }
5014 
5015   // Extract constant bits from shuffle node sources.
5016   if (auto *SVN = dyn_cast<ShuffleVectorSDNode>(Op)) {
5017     // TODO - support shuffle through bitcasts.
5018     if (EltSizeInBits != VT.getScalarSizeInBits())
5019       return false;
5020 
5021     ArrayRef<int> Mask = SVN->getMask();
5022     if ((!AllowWholeUndefs || !AllowPartialUndefs) &&
5023         llvm::any_of(Mask, [](int M) { return M < 0; }))
5024       return false;
5025 
5026     APInt UndefElts0, UndefElts1;
5027     SmallVector<APInt, 32> EltBits0, EltBits1;
5028     if (isAnyInRange(Mask, 0, NumElts) &&
5029         !getTargetConstantBitsFromNode(Op.getOperand(0), EltSizeInBits,
5030                                        UndefElts0, EltBits0, AllowWholeUndefs,
5031                                        AllowPartialUndefs))
5032       return false;
5033     if (isAnyInRange(Mask, NumElts, 2 * NumElts) &&
5034         !getTargetConstantBitsFromNode(Op.getOperand(1), EltSizeInBits,
5035                                        UndefElts1, EltBits1, AllowWholeUndefs,
5036                                        AllowPartialUndefs))
5037       return false;
5038 
5039     UndefElts = APInt::getZero(NumElts);
5040     for (int i = 0; i != (int)NumElts; ++i) {
5041       int M = Mask[i];
5042       if (M < 0) {
5043         UndefElts.setBit(i);
5044         EltBits.push_back(APInt::getZero(EltSizeInBits));
5045       } else if (M < (int)NumElts) {
5046         if (UndefElts0[M])
5047           UndefElts.setBit(i);
5048         EltBits.push_back(EltBits0[M]);
5049       } else {
5050         if (UndefElts1[M - NumElts])
5051           UndefElts.setBit(i);
5052         EltBits.push_back(EltBits1[M - NumElts]);
5053       }
5054     }
5055     return true;
5056   }
5057 
5058   return false;
5059 }
5060 
5061 namespace llvm {
5062 namespace X86 {
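/// Returns true if every defined (non-undef) element of Op is the same
/// constant, returning that value in SplatVal.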
5063 bool isConstantSplat(SDValue Op, APInt &SplatVal, bool AllowPartialUndefs) {
5064   APInt UndefElts;
5065   SmallVector<APInt, 16> EltBits;
5066   if (getTargetConstantBitsFromNode(
5067           Op, Op.getScalarValueSizeInBits(), UndefElts, EltBits,
5068           /*AllowWholeUndefs*/ true, AllowPartialUndefs)) {
5069     int SplatIndex = -1;
5070     for (int i = 0, e = EltBits.size(); i != e; ++i) {
5071       if (UndefElts[i])
5072         continue;
5073       if (0 <= SplatIndex && EltBits[i] != EltBits[SplatIndex]) {
5074         SplatIndex = -1;
5075         break;
5076       }
5077       SplatIndex = i;
5078     }
5079     if (0 <= SplatIndex) {
5080       SplatVal = EltBits[SplatIndex];
5081       return true;
5082     }
5083   }
5084 
5085   return false;
5086 }
5087 } // namespace X86
5088 } // namespace llvm
5089 
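// Extract constant shuffle mask indices from a constant mask operand as raw
// zero-extended values, rejecting elements with partially undef bits.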
5090 static bool getTargetShuffleMaskIndices(SDValue MaskNode,
5091                                         unsigned MaskEltSizeInBits,
5092                                         SmallVectorImpl<uint64_t> &RawMask,
5093                                         APInt &UndefElts) {
5094   // Extract the raw target constant bits.
5095   SmallVector<APInt, 64> EltBits;
5096   if (!getTargetConstantBitsFromNode(MaskNode, MaskEltSizeInBits, UndefElts,
5097                                      EltBits, /* AllowWholeUndefs */ true,
5098                                      /* AllowPartialUndefs */ false))
5099     return false;
5100 
5101   // Insert the extracted elements into the mask.
5102   for (const APInt &Elt : EltBits)
5103     RawMask.push_back(Elt.getZExtValue());
5104 
5105   return true;
5106 }
5107 
5108 // Match not(xor X, -1) -> X.
5109 // Match not(pcmpgt(C, X)) -> pcmpgt(X, C - 1).
5110 // Match not(extract_subvector(xor X, -1)) -> extract_subvector(X).
5111 // Match not(concat_vectors(xor X, -1, xor Y, -1)) -> concat_vectors(X, Y).
5112 static SDValue IsNOT(SDValue V, SelectionDAG &DAG) {
5113   V = peekThroughBitcasts(V);
5114   if (V.getOpcode() == ISD::XOR &&
5115       (ISD::isBuildVectorAllOnes(V.getOperand(1).getNode()) ||
5116        isAllOnesConstant(V.getOperand(1))))
5117     return V.getOperand(0);
5118   if (V.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
5119       (isNullConstant(V.getOperand(1)) || V.getOperand(0).hasOneUse())) {
5120     if (SDValue Not = IsNOT(V.getOperand(0), DAG)) {
5121       Not = DAG.getBitcast(V.getOperand(0).getValueType(), Not);
5122       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(Not), V.getValueType(),
5123                          Not, V.getOperand(1));
5124     }
5125   }
5126   if (V.getOpcode() == X86ISD::PCMPGT &&
5127       !ISD::isBuildVectorAllZeros(V.getOperand(0).getNode()) &&
5128       !ISD::isBuildVectorAllOnes(V.getOperand(0).getNode()) &&
5129       V.getOperand(0).hasOneUse()) {
5130     APInt UndefElts;
5131     SmallVector<APInt> EltBits;
5132     if (getTargetConstantBitsFromNode(V.getOperand(0),
5133                                       V.getScalarValueSizeInBits(), UndefElts,
5134                                       EltBits)) {
5135       // Don't fold min_signed_value -> (min_signed_value - 1)
5136       bool MinSigned = false;
5137       for (APInt &Elt : EltBits) {
5138         MinSigned |= Elt.isMinSignedValue();
5139         Elt -= 1;
5140       }
5141       if (!MinSigned) {
5142         SDLoc DL(V);
5143         MVT VT = V.getSimpleValueType();
5144         return DAG.getNode(X86ISD::PCMPGT, DL, VT, V.getOperand(1),
5145                            getConstVector(EltBits, UndefElts, VT, DAG, DL));
5146       }
5147     }
5148   }
5149   SmallVector<SDValue, 2> CatOps;
5150   if (collectConcatOps(V.getNode(), CatOps, DAG)) {
5151     for (SDValue &CatOp : CatOps) {
5152       SDValue NotCat = IsNOT(CatOp, DAG);
5153       if (!NotCat) return SDValue();
5154       CatOp = DAG.getBitcast(CatOp.getValueType(), NotCat);
5155     }
5156     return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(V), V.getValueType(), CatOps);
5157   }
5158   return SDValue();
5159 }
5160 
5161 /// Create a shuffle mask that matches the PACKSS/PACKUS truncation.
5162 /// A multi-stage pack shuffle mask is created by specifying NumStages > 1.
5163 /// Note: This ignores saturation, so inputs must be checked first.
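/// e.g. a single-stage binary pack to v16i8 (two v8i16 inputs, both viewed as
/// v16i8) produces the mask <0,2,4,6,8,10,12,14,16,18,20,22,24,26,28,30>.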
5164 static void createPackShuffleMask(MVT VT, SmallVectorImpl<int> &Mask,
5165                                   bool Unary, unsigned NumStages = 1) {
5166   assert(Mask.empty() && "Expected an empty shuffle mask vector");
5167   unsigned NumElts = VT.getVectorNumElements();
5168   unsigned NumLanes = VT.getSizeInBits() / 128;
5169   unsigned NumEltsPerLane = 128 / VT.getScalarSizeInBits();
5170   unsigned Offset = Unary ? 0 : NumElts;
5171   unsigned Repetitions = 1u << (NumStages - 1);
5172   unsigned Increment = 1u << NumStages;
5173   assert((NumEltsPerLane >> NumStages) > 0 && "Illegal packing compaction");
5174 
5175   for (unsigned Lane = 0; Lane != NumLanes; ++Lane) {
5176     for (unsigned Stage = 0; Stage != Repetitions; ++Stage) {
5177       for (unsigned Elt = 0; Elt != NumEltsPerLane; Elt += Increment)
5178         Mask.push_back(Elt + (Lane * NumEltsPerLane));
5179       for (unsigned Elt = 0; Elt != NumEltsPerLane; Elt += Increment)
5180         Mask.push_back(Elt + (Lane * NumEltsPerLane) + Offset);
5181     }
5182   }
5183 }
5184 
5185 // Split the demanded elts of a PACKSS/PACKUS node between its operands.
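// e.g. for a 128-bit pack to v16i8, result elements 0-7 map to LHS (v8i16)
// elements 0-7 and result elements 8-15 map to RHS elements 0-7.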
5186 static void getPackDemandedElts(EVT VT, const APInt &DemandedElts,
5187                                 APInt &DemandedLHS, APInt &DemandedRHS) {
5188   int NumLanes = VT.getSizeInBits() / 128;
5189   int NumElts = DemandedElts.getBitWidth();
5190   int NumInnerElts = NumElts / 2;
5191   int NumEltsPerLane = NumElts / NumLanes;
5192   int NumInnerEltsPerLane = NumInnerElts / NumLanes;
5193 
5194   DemandedLHS = APInt::getZero(NumInnerElts);
5195   DemandedRHS = APInt::getZero(NumInnerElts);
5196 
5197   // Map DemandedElts to the packed operands.
5198   for (int Lane = 0; Lane != NumLanes; ++Lane) {
5199     for (int Elt = 0; Elt != NumInnerEltsPerLane; ++Elt) {
5200       int OuterIdx = (Lane * NumEltsPerLane) + Elt;
5201       int InnerIdx = (Lane * NumInnerEltsPerLane) + Elt;
5202       if (DemandedElts[OuterIdx])
5203         DemandedLHS.setBit(InnerIdx);
5204       if (DemandedElts[OuterIdx + NumInnerEltsPerLane])
5205         DemandedRHS.setBit(InnerIdx);
5206     }
5207   }
5208 }
5209 
5210 // Split the demanded elts of a HADD/HSUB node between its operands.
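// Each horizontal-op result element consumes a pair of adjacent source
// elements, hence each demanded bit is widened to cover both elements below.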
5211 static void getHorizDemandedElts(EVT VT, const APInt &DemandedElts,
5212                                  APInt &DemandedLHS, APInt &DemandedRHS) {
5213   getHorizDemandedEltsForFirstOperand(VT.getSizeInBits(), DemandedElts,
5214                                       DemandedLHS, DemandedRHS);
5215   DemandedLHS |= DemandedLHS << 1;
5216   DemandedRHS |= DemandedRHS << 1;
5217 }
5218 
5219 /// Calculates the shuffle mask corresponding to the target-specific opcode.
5220 /// If the mask could be calculated, returns it in \p Mask, returns the shuffle
5221 /// operands in \p Ops, and returns true.
5222 /// Sets \p IsUnary to true if only one source is used. Note that this will set
5223 /// IsUnary for shuffles which use a single input multiple times, and in those
5224 /// cases it will adjust the mask to only have indices within that single input.
5225 /// It is an error to call this with non-empty Mask/Ops vectors.
5226 static bool getTargetShuffleMask(SDValue N, bool AllowSentinelZero,
5227                                  SmallVectorImpl<SDValue> &Ops,
5228                                  SmallVectorImpl<int> &Mask, bool &IsUnary) {
5229   if (!isTargetShuffle(N.getOpcode()))
5230     return false;
5231 
5232   MVT VT = N.getSimpleValueType();
5233   unsigned NumElems = VT.getVectorNumElements();
5234   unsigned MaskEltSize = VT.getScalarSizeInBits();
5235   SmallVector<uint64_t, 32> RawMask;
5236   APInt RawUndefs;
5237   uint64_t ImmN;
5238 
5239   assert(Mask.empty() && "getTargetShuffleMask expects an empty Mask vector");
5240   assert(Ops.empty() && "getTargetShuffleMask expects an empty Ops vector");
5241 
5242   IsUnary = false;
5243   bool IsFakeUnary = false;
5244   switch (N.getOpcode()) {
5245   case X86ISD::BLENDI:
5246     assert(N.getOperand(0).getValueType() == VT && "Unexpected value type");
5247     assert(N.getOperand(1).getValueType() == VT && "Unexpected value type");
5248     ImmN = N.getConstantOperandVal(N.getNumOperands() - 1);
5249     DecodeBLENDMask(NumElems, ImmN, Mask);
5250     IsUnary = IsFakeUnary = N.getOperand(0) == N.getOperand(1);
5251     break;
5252   case X86ISD::SHUFP:
5253     assert(N.getOperand(0).getValueType() == VT && "Unexpected value type");
5254     assert(N.getOperand(1).getValueType() == VT && "Unexpected value type");
5255     ImmN = N.getConstantOperandVal(N.getNumOperands() - 1);
5256     DecodeSHUFPMask(NumElems, MaskEltSize, ImmN, Mask);
5257     IsUnary = IsFakeUnary = N.getOperand(0) == N.getOperand(1);
5258     break;
5259   case X86ISD::INSERTPS:
5260     assert(N.getOperand(0).getValueType() == VT && "Unexpected value type");
5261     assert(N.getOperand(1).getValueType() == VT && "Unexpected value type");
5262     ImmN = N.getConstantOperandVal(N.getNumOperands() - 1);
5263     DecodeINSERTPSMask(ImmN, Mask);
5264     IsUnary = IsFakeUnary = N.getOperand(0) == N.getOperand(1);
5265     break;
5266   case X86ISD::EXTRQI:
5267     assert(N.getOperand(0).getValueType() == VT && "Unexpected value type");
5268     if (isa<ConstantSDNode>(N.getOperand(1)) &&
5269         isa<ConstantSDNode>(N.getOperand(2))) {
5270       int BitLen = N.getConstantOperandVal(1);
5271       int BitIdx = N.getConstantOperandVal(2);
5272       DecodeEXTRQIMask(NumElems, MaskEltSize, BitLen, BitIdx, Mask);
5273       IsUnary = true;
5274     }
5275     break;
5276   case X86ISD::INSERTQI:
5277     assert(N.getOperand(0).getValueType() == VT && "Unexpected value type");
5278     assert(N.getOperand(1).getValueType() == VT && "Unexpected value type");
5279     if (isa<ConstantSDNode>(N.getOperand(2)) &&
5280         isa<ConstantSDNode>(N.getOperand(3))) {
5281       int BitLen = N.getConstantOperandVal(2);
5282       int BitIdx = N.getConstantOperandVal(3);
5283       DecodeINSERTQIMask(NumElems, MaskEltSize, BitLen, BitIdx, Mask);
5284       IsUnary = IsFakeUnary = N.getOperand(0) == N.getOperand(1);
5285     }
5286     break;
5287   case X86ISD::UNPCKH:
5288     assert(N.getOperand(0).getValueType() == VT && "Unexpected value type");
5289     assert(N.getOperand(1).getValueType() == VT && "Unexpected value type");
5290     DecodeUNPCKHMask(NumElems, MaskEltSize, Mask);
5291     IsUnary = IsFakeUnary = N.getOperand(0) == N.getOperand(1);
5292     break;
5293   case X86ISD::UNPCKL:
5294     assert(N.getOperand(0).getValueType() == VT && "Unexpected value type");
5295     assert(N.getOperand(1).getValueType() == VT && "Unexpected value type");
5296     DecodeUNPCKLMask(NumElems, MaskEltSize, Mask);
5297     IsUnary = IsFakeUnary = N.getOperand(0) == N.getOperand(1);
5298     break;
5299   case X86ISD::MOVHLPS:
5300     assert(N.getOperand(0).getValueType() == VT && "Unexpected value type");
5301     assert(N.getOperand(1).getValueType() == VT && "Unexpected value type");
5302     DecodeMOVHLPSMask(NumElems, Mask);
5303     IsUnary = IsFakeUnary = N.getOperand(0) == N.getOperand(1);
5304     break;
5305   case X86ISD::MOVLHPS:
5306     assert(N.getOperand(0).getValueType() == VT && "Unexpected value type");
5307     assert(N.getOperand(1).getValueType() == VT && "Unexpected value type");
5308     DecodeMOVLHPSMask(NumElems, Mask);
5309     IsUnary = IsFakeUnary = N.getOperand(0) == N.getOperand(1);
5310     break;
5311   case X86ISD::VALIGN:
5312     assert((VT.getScalarType() == MVT::i32 || VT.getScalarType() == MVT::i64) &&
5313            "Only 32-bit and 64-bit elements are supported!");
5314     assert(N.getOperand(0).getValueType() == VT && "Unexpected value type");
5315     assert(N.getOperand(1).getValueType() == VT && "Unexpected value type");
5316     ImmN = N.getConstantOperandVal(N.getNumOperands() - 1);
5317     DecodeVALIGNMask(NumElems, ImmN, Mask);
5318     IsUnary = IsFakeUnary = N.getOperand(0) == N.getOperand(1);
5319     Ops.push_back(N.getOperand(1));
5320     Ops.push_back(N.getOperand(0));
5321     break;
5322   case X86ISD::PALIGNR:
5323     assert(VT.getScalarType() == MVT::i8 && "Byte vector expected");
5324     assert(N.getOperand(0).getValueType() == VT && "Unexpected value type");
5325     assert(N.getOperand(1).getValueType() == VT && "Unexpected value type");
5326     ImmN = N.getConstantOperandVal(N.getNumOperands() - 1);
5327     DecodePALIGNRMask(NumElems, ImmN, Mask);
5328     IsUnary = IsFakeUnary = N.getOperand(0) == N.getOperand(1);
5329     Ops.push_back(N.getOperand(1));
5330     Ops.push_back(N.getOperand(0));
5331     break;
5332   case X86ISD::VSHLDQ:
5333     assert(VT.getScalarType() == MVT::i8 && "Byte vector expected");
5334     assert(N.getOperand(0).getValueType() == VT && "Unexpected value type");
5335     ImmN = N.getConstantOperandVal(N.getNumOperands() - 1);
5336     DecodePSLLDQMask(NumElems, ImmN, Mask);
5337     IsUnary = true;
5338     break;
5339   case X86ISD::VSRLDQ:
5340     assert(VT.getScalarType() == MVT::i8 && "Byte vector expected");
5341     assert(N.getOperand(0).getValueType() == VT && "Unexpected value type");
5342     ImmN = N.getConstantOperandVal(N.getNumOperands() - 1);
5343     DecodePSRLDQMask(NumElems, ImmN, Mask);
5344     IsUnary = true;
5345     break;
5346   case X86ISD::PSHUFD:
5347   case X86ISD::VPERMILPI:
5348     assert(N.getOperand(0).getValueType() == VT && "Unexpected value type");
5349     ImmN = N.getConstantOperandVal(N.getNumOperands() - 1);
5350     DecodePSHUFMask(NumElems, MaskEltSize, ImmN, Mask);
5351     IsUnary = true;
5352     break;
5353   case X86ISD::PSHUFHW:
5354     assert(N.getOperand(0).getValueType() == VT && "Unexpected value type");
5355     ImmN = N.getConstantOperandVal(N.getNumOperands() - 1);
5356     DecodePSHUFHWMask(NumElems, ImmN, Mask);
5357     IsUnary = true;
5358     break;
5359   case X86ISD::PSHUFLW:
5360     assert(N.getOperand(0).getValueType() == VT && "Unexpected value type");
5361     ImmN = N.getConstantOperandVal(N.getNumOperands() - 1);
5362     DecodePSHUFLWMask(NumElems, ImmN, Mask);
5363     IsUnary = true;
5364     break;
5365   case X86ISD::VZEXT_MOVL:
5366     assert(N.getOperand(0).getValueType() == VT && "Unexpected value type");
5367     DecodeZeroMoveLowMask(NumElems, Mask);
5368     IsUnary = true;
5369     break;
5370   case X86ISD::VBROADCAST:
5371     // We only decode broadcasts of same-sized vectors; peeking through to
5372     // extracted subvectors is likely to cause hasOneUse issues with
5373     // SimplifyDemandedBits etc.
5374     if (N.getOperand(0).getValueType() == VT) {
5375       DecodeVectorBroadcast(NumElems, Mask);
5376       IsUnary = true;
5377       break;
5378     }
5379     return false;
5380   case X86ISD::VPERMILPV: {
5381     assert(N.getOperand(0).getValueType() == VT && "Unexpected value type");
5382     IsUnary = true;
5383     SDValue MaskNode = N.getOperand(1);
5384     if (getTargetShuffleMaskIndices(MaskNode, MaskEltSize, RawMask,
5385                                     RawUndefs)) {
5386       DecodeVPERMILPMask(NumElems, MaskEltSize, RawMask, RawUndefs, Mask);
5387       break;
5388     }
5389     return false;
5390   }
5391   case X86ISD::PSHUFB: {
5392     assert(VT.getScalarType() == MVT::i8 && "Byte vector expected");
5393     assert(N.getOperand(0).getValueType() == VT && "Unexpected value type");
5394     assert(N.getOperand(1).getValueType() == VT && "Unexpected value type");
5395     IsUnary = true;
5396     SDValue MaskNode = N.getOperand(1);
5397     if (getTargetShuffleMaskIndices(MaskNode, 8, RawMask, RawUndefs)) {
5398       DecodePSHUFBMask(RawMask, RawUndefs, Mask);
5399       break;
5400     }
5401     return false;
5402   }
5403   case X86ISD::VPERMI:
5404     assert(N.getOperand(0).getValueType() == VT && "Unexpected value type");
5405     ImmN = N.getConstantOperandVal(N.getNumOperands() - 1);
5406     DecodeVPERMMask(NumElems, ImmN, Mask);
5407     IsUnary = true;
5408     break;
5409   case X86ISD::MOVSS:
5410   case X86ISD::MOVSD:
5411   case X86ISD::MOVSH:
5412     assert(N.getOperand(0).getValueType() == VT && "Unexpected value type");
5413     assert(N.getOperand(1).getValueType() == VT && "Unexpected value type");
5414     DecodeScalarMoveMask(NumElems, /* IsLoad */ false, Mask);
5415     break;
5416   case X86ISD::VPERM2X128:
5417     assert(N.getOperand(0).getValueType() == VT && "Unexpected value type");
5418     assert(N.getOperand(1).getValueType() == VT && "Unexpected value type");
5419     ImmN = N.getConstantOperandVal(N.getNumOperands() - 1);
5420     DecodeVPERM2X128Mask(NumElems, ImmN, Mask);
5421     IsUnary = IsFakeUnary = N.getOperand(0) == N.getOperand(1);
5422     break;
5423   case X86ISD::SHUF128:
5424     assert(N.getOperand(0).getValueType() == VT && "Unexpected value type");
5425     assert(N.getOperand(1).getValueType() == VT && "Unexpected value type");
5426     ImmN = N.getConstantOperandVal(N.getNumOperands() - 1);
5427     decodeVSHUF64x2FamilyMask(NumElems, MaskEltSize, ImmN, Mask);
5428     IsUnary = IsFakeUnary = N.getOperand(0) == N.getOperand(1);
5429     break;
5430   case X86ISD::MOVSLDUP:
5431     assert(N.getOperand(0).getValueType() == VT && "Unexpected value type");
5432     DecodeMOVSLDUPMask(NumElems, Mask);
5433     IsUnary = true;
5434     break;
5435   case X86ISD::MOVSHDUP:
5436     assert(N.getOperand(0).getValueType() == VT && "Unexpected value type");
5437     DecodeMOVSHDUPMask(NumElems, Mask);
5438     IsUnary = true;
5439     break;
5440   case X86ISD::MOVDDUP:
5441     assert(N.getOperand(0).getValueType() == VT && "Unexpected value type");
5442     DecodeMOVDDUPMask(NumElems, Mask);
5443     IsUnary = true;
5444     break;
5445   case X86ISD::VPERMIL2: {
5446     assert(N.getOperand(0).getValueType() == VT && "Unexpected value type");
5447     assert(N.getOperand(1).getValueType() == VT && "Unexpected value type");
5448     IsUnary = IsFakeUnary = N.getOperand(0) == N.getOperand(1);
5449     SDValue MaskNode = N.getOperand(2);
5450     SDValue CtrlNode = N.getOperand(3);
5451     if (ConstantSDNode *CtrlOp = dyn_cast<ConstantSDNode>(CtrlNode)) {
5452       unsigned CtrlImm = CtrlOp->getZExtValue();
5453       if (getTargetShuffleMaskIndices(MaskNode, MaskEltSize, RawMask,
5454                                       RawUndefs)) {
5455         DecodeVPERMIL2PMask(NumElems, MaskEltSize, CtrlImm, RawMask, RawUndefs,
5456                             Mask);
5457         break;
5458       }
5459     }
5460     return false;
5461   }
5462   case X86ISD::VPPERM: {
5463     assert(N.getOperand(0).getValueType() == VT && "Unexpected value type");
5464     assert(N.getOperand(1).getValueType() == VT && "Unexpected value type");
5465     IsUnary = IsFakeUnary = N.getOperand(0) == N.getOperand(1);
5466     SDValue MaskNode = N.getOperand(2);
5467     if (getTargetShuffleMaskIndices(MaskNode, 8, RawMask, RawUndefs)) {
5468       DecodeVPPERMMask(RawMask, RawUndefs, Mask);
5469       break;
5470     }
5471     return false;
5472   }
5473   case X86ISD::VPERMV: {
5474     assert(N.getOperand(1).getValueType() == VT && "Unexpected value type");
5475     IsUnary = true;
5476     // Unlike most shuffle nodes, VPERMV's mask operand is operand 0.
5477     Ops.push_back(N.getOperand(1));
5478     SDValue MaskNode = N.getOperand(0);
5479     if (getTargetShuffleMaskIndices(MaskNode, MaskEltSize, RawMask,
5480                                     RawUndefs)) {
5481       DecodeVPERMVMask(RawMask, RawUndefs, Mask);
5482       break;
5483     }
5484     return false;
5485   }
5486   case X86ISD::VPERMV3: {
5487     assert(N.getOperand(0).getValueType() == VT && "Unexpected value type");
5488     assert(N.getOperand(2).getValueType() == VT && "Unexpected value type");
5489     IsUnary = IsFakeUnary = N.getOperand(0) == N.getOperand(2);
5490     // Unlike most shuffle nodes, VPERMV3's mask operand is the middle one.
5491     Ops.push_back(N.getOperand(0));
5492     Ops.push_back(N.getOperand(2));
5493     SDValue MaskNode = N.getOperand(1);
5494     if (getTargetShuffleMaskIndices(MaskNode, MaskEltSize, RawMask,
5495                                     RawUndefs)) {
5496       DecodeVPERMV3Mask(RawMask, RawUndefs, Mask);
5497       break;
5498     }
5499     return false;
5500   }
5501   default:
5502     llvm_unreachable("unknown target shuffle node");
5503   }
5504 
5505   // Empty mask indicates the decode failed.
5506   if (Mask.empty())
5507     return false;
5508 
5509   // Check if we're getting a shuffle mask with zero'd elements.
5510   if (!AllowSentinelZero && isAnyZero(Mask))
5511     return false;
5512 
5513   // If we have a fake unary shuffle, the shuffle mask is spread across two
5514   // inputs that are actually the same node. Re-map the mask to always point
5515   // into the first input.
5516   if (IsFakeUnary)
5517     for (int &M : Mask)
5518       if (M >= (int)Mask.size())
5519         M -= Mask.size();
5520 
5521   // If we didn't already add operands in the opcode-specific code, default to
5522   // adding 1 or 2 operands starting at 0.
5523   if (Ops.empty()) {
5524     Ops.push_back(N.getOperand(0));
5525     if (!IsUnary || IsFakeUnary)
5526       Ops.push_back(N.getOperand(1));
5527   }
5528 
5529   return true;
5530 }
5531 
5532 // Wrapper for getTargetShuffleMask that ignores the IsUnary result.
5533 static bool getTargetShuffleMask(SDValue N, bool AllowSentinelZero,
5534                                  SmallVectorImpl<SDValue> &Ops,
5535                                  SmallVectorImpl<int> &Mask) {
5536   bool IsUnary;
5537   return getTargetShuffleMask(N, AllowSentinelZero, Ops, Mask, IsUnary);
5538 }
5539 
5540 /// Compute whether each element of a shuffle is zeroable.
5541 ///
5542 /// A "zeroable" vector shuffle element is one which can be lowered to zero.
5543 /// Either it is an undef element in the shuffle mask, the element of the input
5544 /// referenced is undef, or the element of the input referenced is known to be
5545 /// zero. Many x86 shuffles can zero lanes cheaply and we often want to handle
5546 /// as many lanes with this technique as possible to simplify the remaining
5547 /// shuffle.
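/// e.g. shuffling a v4i32 V1 with an all-zeros V2 using mask <0, 1, 4, 4>
/// makes elements 2 and 3 zeroable.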
5548 static void computeZeroableShuffleElements(ArrayRef<int> Mask,
5549                                            SDValue V1, SDValue V2,
5550                                            APInt &KnownUndef, APInt &KnownZero) {
5551   int Size = Mask.size();
5552   KnownUndef = KnownZero = APInt::getZero(Size);
5553 
5554   V1 = peekThroughBitcasts(V1);
5555   V2 = peekThroughBitcasts(V2);
5556 
5557   bool V1IsZero = ISD::isBuildVectorAllZeros(V1.getNode());
5558   bool V2IsZero = ISD::isBuildVectorAllZeros(V2.getNode());
5559 
5560   int VectorSizeInBits = V1.getValueSizeInBits();
5561   int ScalarSizeInBits = VectorSizeInBits / Size;
5562   assert(!(VectorSizeInBits % ScalarSizeInBits) && "Illegal shuffle mask size");
5563 
5564   for (int i = 0; i < Size; ++i) {
5565     int M = Mask[i];
5566     // Handle the easy cases.
5567     if (M < 0) {
5568       KnownUndef.setBit(i);
5569       continue;
5570     }
5571     if ((M >= 0 && M < Size && V1IsZero) || (M >= Size && V2IsZero)) {
5572       KnownZero.setBit(i);
5573       continue;
5574     }
5575 
5576     // Determine shuffle input and normalize the mask.
5577     SDValue V = M < Size ? V1 : V2;
5578     M %= Size;
5579 
5580     // Currently we can only search BUILD_VECTOR for UNDEF/ZERO elements.
5581     if (V.getOpcode() != ISD::BUILD_VECTOR)
5582       continue;
5583 
5584     // If the BUILD_VECTOR has fewer elements, then the bitcasted portion of
5585     // the (larger) source element must be UNDEF/ZERO.
5586     if ((Size % V.getNumOperands()) == 0) {
5587       int Scale = Size / V->getNumOperands();
5588       SDValue Op = V.getOperand(M / Scale);
5589       if (Op.isUndef())
5590         KnownUndef.setBit(i);
5591       if (X86::isZeroNode(Op))
5592         KnownZero.setBit(i);
5593       else if (ConstantSDNode *Cst = dyn_cast<ConstantSDNode>(Op)) {
5594         APInt Val = Cst->getAPIntValue();
5595         Val = Val.extractBits(ScalarSizeInBits, (M % Scale) * ScalarSizeInBits);
5596         if (Val == 0)
5597           KnownZero.setBit(i);
5598       } else if (ConstantFPSDNode *Cst = dyn_cast<ConstantFPSDNode>(Op)) {
5599         APInt Val = Cst->getValueAPF().bitcastToAPInt();
5600         Val = Val.extractBits(ScalarSizeInBits, (M % Scale) * ScalarSizeInBits);
5601         if (Val == 0)
5602           KnownZero.setBit(i);
5603       }
5604       continue;
5605     }
5606 
5607     // If the BUILD_VECTOR has more elements, then all the (smaller) source
5608     // elements must be UNDEF or ZERO.
5609     if ((V.getNumOperands() % Size) == 0) {
5610       int Scale = V->getNumOperands() / Size;
5611       bool AllUndef = true;
5612       bool AllZero = true;
5613       for (int j = 0; j < Scale; ++j) {
5614         SDValue Op = V.getOperand((M * Scale) + j);
5615         AllUndef &= Op.isUndef();
5616         AllZero &= X86::isZeroNode(Op);
5617       }
5618       if (AllUndef)
5619         KnownUndef.setBit(i);
5620       if (AllZero)
5621         KnownZero.setBit(i);
5622       continue;
5623     }
5624   }
5625 }
5626 
5627 /// Decode a target shuffle mask and inputs and see if any values are
5628 /// known to be undef or zero from their inputs.
5629 /// Returns true if the target shuffle mask was decoded.
5630 /// FIXME: Merge this with computeZeroableShuffleElements?
5631 static bool getTargetShuffleAndZeroables(SDValue N, SmallVectorImpl<int> &Mask,
5632                                          SmallVectorImpl<SDValue> &Ops,
5633                                          APInt &KnownUndef, APInt &KnownZero) {
5634   bool IsUnary;
5635   if (!isTargetShuffle(N.getOpcode()))
5636     return false;
5637 
5638   MVT VT = N.getSimpleValueType();
5639   if (!getTargetShuffleMask(N, true, Ops, Mask, IsUnary))
5640     return false;
5641 
5642   int Size = Mask.size();
5643   SDValue V1 = Ops[0];
5644   SDValue V2 = IsUnary ? V1 : Ops[1];
5645   KnownUndef = KnownZero = APInt::getZero(Size);
5646 
5647   V1 = peekThroughBitcasts(V1);
5648   V2 = peekThroughBitcasts(V2);
5649 
5650   assert((VT.getSizeInBits() % Size) == 0 &&
5651          "Illegal split of shuffle value type");
5652   unsigned EltSizeInBits = VT.getSizeInBits() / Size;
5653 
5654   // Extract known constant input data.
5655   APInt UndefSrcElts[2];
5656   SmallVector<APInt, 32> SrcEltBits[2];
5657   bool IsSrcConstant[2] = {
5658       getTargetConstantBitsFromNode(V1, EltSizeInBits, UndefSrcElts[0],
5659                                     SrcEltBits[0], /*AllowWholeUndefs*/ true,
5660                                     /*AllowPartialUndefs*/ false),
5661       getTargetConstantBitsFromNode(V2, EltSizeInBits, UndefSrcElts[1],
5662                                     SrcEltBits[1], /*AllowWholeUndefs*/ true,
5663                                     /*AllowPartialUndefs*/ false)};
5664 
5665   for (int i = 0; i < Size; ++i) {
5666     int M = Mask[i];
5667 
5668     // Already decoded as SM_SentinelZero / SM_SentinelUndef.
5669     if (M < 0) {
5670       assert(isUndefOrZero(M) && "Unknown shuffle sentinel value!");
5671       if (SM_SentinelUndef == M)
5672         KnownUndef.setBit(i);
5673       if (SM_SentinelZero == M)
5674         KnownZero.setBit(i);
5675       continue;
5676     }
5677 
5678     // Determine shuffle input and normalize the mask.
5679     unsigned SrcIdx = M / Size;
5680     SDValue V = M < Size ? V1 : V2;
5681     M %= Size;
5682 
5683     // We are referencing an UNDEF input.
5684     if (V.isUndef()) {
5685       KnownUndef.setBit(i);
5686       continue;
5687     }
5688 
5689     // SCALAR_TO_VECTOR - only the first element is defined, and the rest UNDEF.
5690     // TODO: We currently only set UNDEF for integer types - floats use the same
5691     // registers as vectors and many of the scalar folded loads rely on the
5692     // SCALAR_TO_VECTOR pattern.
5693     if (V.getOpcode() == ISD::SCALAR_TO_VECTOR &&
5694         (Size % V.getValueType().getVectorNumElements()) == 0) {
5695       int Scale = Size / V.getValueType().getVectorNumElements();
5696       int Idx = M / Scale;
5697       if (Idx != 0 && !VT.isFloatingPoint())
5698         KnownUndef.setBit(i);
5699       else if (Idx == 0 && X86::isZeroNode(V.getOperand(0)))
5700         KnownZero.setBit(i);
5701       continue;
5702     }
5703 
5704     // INSERT_SUBVECTOR - to widen vectors we often insert them into UNDEF
5705     // base vectors.
5706     if (V.getOpcode() == ISD::INSERT_SUBVECTOR) {
5707       SDValue Vec = V.getOperand(0);
5708       int NumVecElts = Vec.getValueType().getVectorNumElements();
5709       if (Vec.isUndef() && Size == NumVecElts) {
5710         int Idx = V.getConstantOperandVal(2);
5711         int NumSubElts = V.getOperand(1).getValueType().getVectorNumElements();
5712         if (M < Idx || (Idx + NumSubElts) <= M)
5713           KnownUndef.setBit(i);
5714       }
5715       continue;
5716     }
5717 
5718     // Attempt to extract from the source's constant bits.
5719     if (IsSrcConstant[SrcIdx]) {
5720       if (UndefSrcElts[SrcIdx][M])
5721         KnownUndef.setBit(i);
5722       else if (SrcEltBits[SrcIdx][M] == 0)
5723         KnownZero.setBit(i);
5724     }
5725   }
5726 
5727   assert(VT.getVectorNumElements() == (unsigned)Size &&
5728          "Different mask size from vector size!");
5729   return true;
5730 }
5731 
5732 // Replace target shuffle mask elements with known undef/zero sentinels.
5733 static void resolveTargetShuffleFromZeroables(SmallVectorImpl<int> &Mask,
5734                                               const APInt &KnownUndef,
5735                                               const APInt &KnownZero,
5736                                               bool ResolveKnownZeros = true) {
5737   unsigned NumElts = Mask.size();
5738   assert(KnownUndef.getBitWidth() == NumElts &&
5739          KnownZero.getBitWidth() == NumElts && "Shuffle mask size mismatch");
5740 
5741   for (unsigned i = 0; i != NumElts; ++i) {
5742     if (KnownUndef[i])
5743       Mask[i] = SM_SentinelUndef;
5744     else if (ResolveKnownZeros && KnownZero[i])
5745       Mask[i] = SM_SentinelZero;
5746   }
5747 }
5748 
5749 // Extract target shuffle mask sentinel elements to known undef/zero bitmasks.
5750 static void resolveZeroablesFromTargetShuffle(const SmallVectorImpl<int> &Mask,
5751                                               APInt &KnownUndef,
5752                                               APInt &KnownZero) {
5753   unsigned NumElts = Mask.size();
5754   KnownUndef = KnownZero = APInt::getZero(NumElts);
5755 
5756   for (unsigned i = 0; i != NumElts; ++i) {
5757     int M = Mask[i];
5758     if (SM_SentinelUndef == M)
5759       KnownUndef.setBit(i);
5760     if (SM_SentinelZero == M)
5761       KnownZero.setBit(i);
5762   }
5763 }
5764 
5765 // Attempt to create a shuffle mask from a VSELECT/BLENDV condition mask.
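// Illustrative example: for a v4i32 VSELECT whose constant condition is
// {-1, 0, -1, 0}, the resulting blend mask is {0, 5, 2, 7} - element i is taken
// from the first value operand when the condition element is non-zero (for
// BLENDV: when its sign bit is set) and from the second value operand otherwise.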
5766 static bool createShuffleMaskFromVSELECT(SmallVectorImpl<int> &Mask,
5767                                          SDValue Cond, bool IsBLENDV = false) {
5768   EVT CondVT = Cond.getValueType();
5769   unsigned EltSizeInBits = CondVT.getScalarSizeInBits();
5770   unsigned NumElts = CondVT.getVectorNumElements();
5771 
5772   APInt UndefElts;
5773   SmallVector<APInt, 32> EltBits;
5774   if (!getTargetConstantBitsFromNode(Cond, EltSizeInBits, UndefElts, EltBits,
5775                                      /*AllowWholeUndefs*/ true,
5776                                      /*AllowPartialUndefs*/ false))
5777     return false;
5778 
5779   Mask.resize(NumElts, SM_SentinelUndef);
5780 
5781   for (int i = 0; i != (int)NumElts; ++i) {
5782     Mask[i] = i;
5783     // Arbitrarily choose from the 2nd operand if the select condition element
5784     // is undef.
5785     // TODO: Can we do better by matching patterns such as even/odd?
5786     if (UndefElts[i] || (!IsBLENDV && EltBits[i].isZero()) ||
5787         (IsBLENDV && EltBits[i].isNonNegative()))
5788       Mask[i] += NumElts;
5789   }
5790 
5791   return true;
5792 }
5793 
5794 // Forward declaration (for getFauxShuffleMask recursive check).
5795 static bool getTargetShuffleInputs(SDValue Op, const APInt &DemandedElts,
5796                                    SmallVectorImpl<SDValue> &Inputs,
5797                                    SmallVectorImpl<int> &Mask,
5798                                    const SelectionDAG &DAG, unsigned Depth,
5799                                    bool ResolveKnownElts);
5800 
5801 // Attempt to decode ops that could be represented as a shuffle mask.
5802 // The decoded shuffle mask may contain a different number of elements than
5803 // the destination value type.
5804 // TODO: Merge into getTargetShuffleInputs()
5805 static bool getFauxShuffleMask(SDValue N, const APInt &DemandedElts,
5806                                SmallVectorImpl<int> &Mask,
5807                                SmallVectorImpl<SDValue> &Ops,
5808                                const SelectionDAG &DAG, unsigned Depth,
5809                                bool ResolveKnownElts) {
5810   Mask.clear();
5811   Ops.clear();
5812 
5813   MVT VT = N.getSimpleValueType();
5814   unsigned NumElts = VT.getVectorNumElements();
5815   unsigned NumSizeInBits = VT.getSizeInBits();
5816   unsigned NumBitsPerElt = VT.getScalarSizeInBits();
5817   if ((NumBitsPerElt % 8) != 0 || (NumSizeInBits % 8) != 0)
5818     return false;
5819   assert(NumElts == DemandedElts.getBitWidth() && "Unexpected vector size");
5820   unsigned NumSizeInBytes = NumSizeInBits / 8;
5821   unsigned NumBytesPerElt = NumBitsPerElt / 8;
5822 
5823   unsigned Opcode = N.getOpcode();
5824   switch (Opcode) {
5825   case ISD::VECTOR_SHUFFLE: {
5826     // Don't treat ISD::VECTOR_SHUFFLE as a target shuffle, so decode it here.
5827     ArrayRef<int> ShuffleMask = cast<ShuffleVectorSDNode>(N)->getMask();
5828     if (isUndefOrInRange(ShuffleMask, 0, 2 * NumElts)) {
5829       Mask.append(ShuffleMask.begin(), ShuffleMask.end());
5830       Ops.push_back(N.getOperand(0));
5831       Ops.push_back(N.getOperand(1));
5832       return true;
5833     }
5834     return false;
5835   }
5836   case ISD::AND:
5837   case X86ISD::ANDNP: {
5838     // Attempt to decode as a per-byte mask.
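    // Illustrative example: AND(X, C) where every byte of C is 0x00 or 0xFF acts
    // as a byte-level blend with zero, e.g. a v4i32 AND with 0x00FF00FF per
    // element keeps bytes 0 and 2 of each element and zeroes bytes 1 and 3. For
    // ANDNP the constant operand and the zero/keep roles are swapped.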
5839     APInt UndefElts;
5840     SmallVector<APInt, 32> EltBits;
5841     SDValue N0 = N.getOperand(0);
5842     SDValue N1 = N.getOperand(1);
5843     bool IsAndN = (X86ISD::ANDNP == Opcode);
5844     uint64_t ZeroMask = IsAndN ? 255 : 0;
5845     if (!getTargetConstantBitsFromNode(IsAndN ? N0 : N1, 8, UndefElts, EltBits,
5846                                        /*AllowWholeUndefs*/ false,
5847                                        /*AllowPartialUndefs*/ false))
5848       return false;
5849     // We can't assume an undef src element gives an undef dst - the other src
5850     // might be zero.
5851     assert(UndefElts.isZero() && "Unexpected UNDEF element in AND/ANDNP mask");
5852     for (int i = 0, e = (int)EltBits.size(); i != e; ++i) {
5853       const APInt &ByteBits = EltBits[i];
5854       if (ByteBits != 0 && ByteBits != 255)
5855         return false;
5856       Mask.push_back(ByteBits == ZeroMask ? SM_SentinelZero : i);
5857     }
5858     Ops.push_back(IsAndN ? N1 : N0);
5859     return true;
5860   }
5861   case ISD::OR: {
5862     // Handle OR(SHUFFLE,SHUFFLE) case where one source is zero and the other
5863     // is a valid shuffle index.
5864     SDValue N0 = peekThroughBitcasts(N.getOperand(0));
5865     SDValue N1 = peekThroughBitcasts(N.getOperand(1));
5866     if (!N0.getValueType().isVector() || !N1.getValueType().isVector())
5867       return false;
5868 
5869     SmallVector<int, 64> SrcMask0, SrcMask1;
5870     SmallVector<SDValue, 2> SrcInputs0, SrcInputs1;
5871     APInt Demand0 = APInt::getAllOnes(N0.getValueType().getVectorNumElements());
5872     APInt Demand1 = APInt::getAllOnes(N1.getValueType().getVectorNumElements());
5873     if (!getTargetShuffleInputs(N0, Demand0, SrcInputs0, SrcMask0, DAG,
5874                                 Depth + 1, true) ||
5875         !getTargetShuffleInputs(N1, Demand1, SrcInputs1, SrcMask1, DAG,
5876                                 Depth + 1, true))
5877       return false;
5878 
5879     size_t MaskSize = std::max(SrcMask0.size(), SrcMask1.size());
5880     SmallVector<int, 64> Mask0, Mask1;
5881     narrowShuffleMaskElts(MaskSize / SrcMask0.size(), SrcMask0, Mask0);
5882     narrowShuffleMaskElts(MaskSize / SrcMask1.size(), SrcMask1, Mask1);
5883     for (int i = 0; i != (int)MaskSize; ++i) {
5884       // NOTE: Don't handle SM_SentinelUndef, as we can end up in infinite
5885       // loops converting between OR and BLEND shuffles due to
5886       // canWidenShuffleElements merging away undef elements, meaning we
5887       // fail to recognise the OR as the undef element isn't known zero.
5888       if (Mask0[i] == SM_SentinelZero && Mask1[i] == SM_SentinelZero)
5889         Mask.push_back(SM_SentinelZero);
5890       else if (Mask1[i] == SM_SentinelZero)
5891         Mask.push_back(i);
5892       else if (Mask0[i] == SM_SentinelZero)
5893         Mask.push_back(i + MaskSize);
5894       else
5895         return false;
5896     }
5897     Ops.push_back(N0);
5898     Ops.push_back(N1);
5899     return true;
5900   }
5901   case ISD::INSERT_SUBVECTOR: {
5902     SDValue Src = N.getOperand(0);
5903     SDValue Sub = N.getOperand(1);
5904     EVT SubVT = Sub.getValueType();
5905     unsigned NumSubElts = SubVT.getVectorNumElements();
5906     if (!N->isOnlyUserOf(Sub.getNode()))
5907       return false;
5908     SDValue SubBC = peekThroughBitcasts(Sub);
5909     uint64_t InsertIdx = N.getConstantOperandVal(2);
5910     // Handle INSERT_SUBVECTOR(SRC0, EXTRACT_SUBVECTOR(SRC1)).
5911     if (SubBC.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
5912         SubBC.getOperand(0).getValueSizeInBits() == NumSizeInBits) {
5913       uint64_t ExtractIdx = SubBC.getConstantOperandVal(1);
5914       SDValue SubBCSrc = SubBC.getOperand(0);
5915       unsigned NumSubSrcBCElts = SubBCSrc.getValueType().getVectorNumElements();
5916       unsigned MaxElts = std::max(NumElts, NumSubSrcBCElts);
5917       assert((MaxElts % NumElts) == 0 && (MaxElts % NumSubSrcBCElts) == 0 &&
5918              "Subvector valuetype mismatch");
5919       InsertIdx *= (MaxElts / NumElts);
5920       ExtractIdx *= (MaxElts / NumSubSrcBCElts);
5921       NumSubElts *= (MaxElts / NumElts);
5922       bool SrcIsUndef = Src.isUndef();
5923       for (int i = 0; i != (int)MaxElts; ++i)
5924         Mask.push_back(SrcIsUndef ? SM_SentinelUndef : i);
5925       for (int i = 0; i != (int)NumSubElts; ++i)
5926         Mask[InsertIdx + i] = (SrcIsUndef ? 0 : MaxElts) + ExtractIdx + i;
5927       if (!SrcIsUndef)
5928         Ops.push_back(Src);
5929       Ops.push_back(SubBCSrc);
5930       return true;
5931     }
5932     // Handle CONCAT(SUB0, SUB1).
5933     // Limit this to vXi64 512-bit vector cases to make the most of AVX512
5934     // cross lane shuffles.
5935     if (Depth > 0 && InsertIdx == NumSubElts && NumElts == (2 * NumSubElts) &&
5936         NumBitsPerElt == 64 && NumSizeInBits == 512 &&
5937         Src.getOpcode() == ISD::INSERT_SUBVECTOR &&
5938         Src.getOperand(0).isUndef() &&
5939         Src.getOperand(1).getValueType() == SubVT &&
5940         Src.getConstantOperandVal(2) == 0) {
5941       for (int i = 0; i != (int)NumSubElts; ++i)
5942         Mask.push_back(i);
5943       for (int i = 0; i != (int)NumSubElts; ++i)
5944         Mask.push_back(i + NumElts);
5945       Ops.push_back(Src.getOperand(1));
5946       Ops.push_back(Sub);
5947       return true;
5948     }
5949     // Handle INSERT_SUBVECTOR(SRC0, SHUFFLE(SRC1)).
5950     SmallVector<int, 64> SubMask;
5951     SmallVector<SDValue, 2> SubInputs;
5952     SDValue SubSrc = peekThroughOneUseBitcasts(Sub);
5953     EVT SubSrcVT = SubSrc.getValueType();
5954     if (!SubSrcVT.isVector())
5955       return false;
5956 
5957     APInt SubDemand = APInt::getAllOnes(SubSrcVT.getVectorNumElements());
5958     if (!getTargetShuffleInputs(SubSrc, SubDemand, SubInputs, SubMask, DAG,
5959                                 Depth + 1, ResolveKnownElts))
5960       return false;
5961 
5962     // Subvector shuffle inputs must not be larger than the subvector.
5963     if (llvm::any_of(SubInputs, [SubVT](SDValue SubInput) {
5964           return SubVT.getFixedSizeInBits() <
5965                  SubInput.getValueSizeInBits().getFixedValue();
5966         }))
5967       return false;
5968 
5969     if (SubMask.size() != NumSubElts) {
5970       assert(((SubMask.size() % NumSubElts) == 0 ||
5971               (NumSubElts % SubMask.size()) == 0) && "Illegal submask scale");
5972       if ((NumSubElts % SubMask.size()) == 0) {
5973         int Scale = NumSubElts / SubMask.size();
5974         SmallVector<int,64> ScaledSubMask;
5975         narrowShuffleMaskElts(Scale, SubMask, ScaledSubMask);
5976         SubMask = ScaledSubMask;
5977       } else {
5978         int Scale = SubMask.size() / NumSubElts;
5979         NumSubElts = SubMask.size();
5980         NumElts *= Scale;
5981         InsertIdx *= Scale;
5982       }
5983     }
5984     Ops.push_back(Src);
5985     Ops.append(SubInputs.begin(), SubInputs.end());
5986     if (ISD::isBuildVectorAllZeros(Src.getNode()))
5987       Mask.append(NumElts, SM_SentinelZero);
5988     else
5989       for (int i = 0; i != (int)NumElts; ++i)
5990         Mask.push_back(i);
5991     for (int i = 0; i != (int)NumSubElts; ++i) {
5992       int M = SubMask[i];
5993       if (0 <= M) {
5994         int InputIdx = M / NumSubElts;
5995         M = (NumElts * (1 + InputIdx)) + (M % NumSubElts);
5996       }
5997       Mask[i + InsertIdx] = M;
5998     }
5999     return true;
6000   }
6001   case X86ISD::PINSRB:
6002   case X86ISD::PINSRW:
6003   case ISD::SCALAR_TO_VECTOR:
6004   case ISD::INSERT_VECTOR_ELT: {
6005     // Match against an insert_vector_elt/scalar_to_vector of an extract from a
6006     // vector, for matching src/dst vector types.
6007     SDValue Scl = N.getOperand(Opcode == ISD::SCALAR_TO_VECTOR ? 0 : 1);
6008 
6009     unsigned DstIdx = 0;
6010     if (Opcode != ISD::SCALAR_TO_VECTOR) {
6011       // Check we have an in-range constant insertion index.
6012       if (!isa<ConstantSDNode>(N.getOperand(2)) ||
6013           N.getConstantOperandAPInt(2).uge(NumElts))
6014         return false;
6015       DstIdx = N.getConstantOperandVal(2);
6016 
6017       // Attempt to recognise an INSERT*(VEC, 0, DstIdx) shuffle pattern.
6018       if (X86::isZeroNode(Scl)) {
6019         Ops.push_back(N.getOperand(0));
6020         for (unsigned i = 0; i != NumElts; ++i)
6021           Mask.push_back(i == DstIdx ? SM_SentinelZero : (int)i);
6022         return true;
6023       }
6024     }
6025 
6026     // Peek through trunc/aext/zext/bitcast.
6027     // TODO: aext shouldn't require SM_SentinelZero padding.
6028     // TODO: handle shift of scalars.
6029     unsigned MinBitsPerElt = Scl.getScalarValueSizeInBits();
6030     while (Scl.getOpcode() == ISD::TRUNCATE ||
6031            Scl.getOpcode() == ISD::ANY_EXTEND ||
6032            Scl.getOpcode() == ISD::ZERO_EXTEND ||
6033            (Scl.getOpcode() == ISD::BITCAST &&
6034             Scl.getScalarValueSizeInBits() ==
6035                 Scl.getOperand(0).getScalarValueSizeInBits())) {
6036       Scl = Scl.getOperand(0);
6037       MinBitsPerElt =
6038           std::min<unsigned>(MinBitsPerElt, Scl.getScalarValueSizeInBits());
6039     }
6040     if ((MinBitsPerElt % 8) != 0)
6041       return false;
6042 
6043     // Attempt to find the source vector the scalar was extracted from.
6044     SDValue SrcExtract;
6045     if ((Scl.getOpcode() == ISD::EXTRACT_VECTOR_ELT ||
6046          Scl.getOpcode() == X86ISD::PEXTRW ||
6047          Scl.getOpcode() == X86ISD::PEXTRB) &&
6048         Scl.getOperand(0).getValueSizeInBits() == NumSizeInBits) {
6049       SrcExtract = Scl;
6050     }
6051     if (!SrcExtract || !isa<ConstantSDNode>(SrcExtract.getOperand(1)))
6052       return false;
6053 
6054     SDValue SrcVec = SrcExtract.getOperand(0);
6055     EVT SrcVT = SrcVec.getValueType();
6056     if (!SrcVT.getScalarType().isByteSized())
6057       return false;
6058     unsigned SrcIdx = SrcExtract.getConstantOperandVal(1);
6059     unsigned SrcByte = SrcIdx * (SrcVT.getScalarSizeInBits() / 8);
6060     unsigned DstByte = DstIdx * NumBytesPerElt;
6061     MinBitsPerElt =
6062         std::min<unsigned>(MinBitsPerElt, SrcVT.getScalarSizeInBits());
6063 
6064     // Create 'identity' byte level shuffle mask and then add inserted bytes.
6065     if (Opcode == ISD::SCALAR_TO_VECTOR) {
6066       Ops.push_back(SrcVec);
6067       Mask.append(NumSizeInBytes, SM_SentinelUndef);
6068     } else {
6069       Ops.push_back(SrcVec);
6070       Ops.push_back(N.getOperand(0));
6071       for (int i = 0; i != (int)NumSizeInBytes; ++i)
6072         Mask.push_back(NumSizeInBytes + i);
6073     }
6074 
6075     unsigned MinBytesPerElts = MinBitsPerElt / 8;
6076     MinBytesPerElts = std::min(MinBytesPerElts, NumBytesPerElt);
6077     for (unsigned i = 0; i != MinBytesPerElts; ++i)
6078       Mask[DstByte + i] = SrcByte + i;
6079     for (unsigned i = MinBytesPerElts; i < NumBytesPerElt; ++i)
6080       Mask[DstByte + i] = SM_SentinelZero;
6081     return true;
6082   }
6083   case X86ISD::PACKSS:
6084   case X86ISD::PACKUS: {
6085     SDValue N0 = N.getOperand(0);
6086     SDValue N1 = N.getOperand(1);
6087     assert(N0.getValueType().getVectorNumElements() == (NumElts / 2) &&
6088            N1.getValueType().getVectorNumElements() == (NumElts / 2) &&
6089            "Unexpected input value type");
6090 
6091     APInt EltsLHS, EltsRHS;
6092     getPackDemandedElts(VT, DemandedElts, EltsLHS, EltsRHS);
6093 
6094     // If we know input saturation won't happen (or we don't care about particular
6095     // lanes), we can treat this as a truncation shuffle.
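    // Illustrative example: with no saturation, a 128-bit PACKUSWB behaves like
    // a byte shuffle that takes the even-indexed bytes of each source, i.e. mask
    // {0,2,...,14, 16,18,...,30} (or {0,2,...,14, 0,2,...,14} when unary).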
6096     bool Offset0 = false, Offset1 = false;
6097     if (Opcode == X86ISD::PACKSS) {
6098       if ((!(N0.isUndef() || EltsLHS.isZero()) &&
6099            DAG.ComputeNumSignBits(N0, EltsLHS, Depth + 1) <= NumBitsPerElt) ||
6100           (!(N1.isUndef() || EltsRHS.isZero()) &&
6101            DAG.ComputeNumSignBits(N1, EltsRHS, Depth + 1) <= NumBitsPerElt))
6102         return false;
6103       // We can't easily fold ASHR into a shuffle, but if it was feeding a
6104       // PACKSS then it was likely being used for sign-extension for a
6105       // truncation, so just peek through and adjust the mask accordingly.
6106       if (N0.getOpcode() == X86ISD::VSRAI && N->isOnlyUserOf(N0.getNode()) &&
6107           N0.getConstantOperandAPInt(1) == NumBitsPerElt) {
6108         Offset0 = true;
6109         N0 = N0.getOperand(0);
6110       }
6111       if (N1.getOpcode() == X86ISD::VSRAI && N->isOnlyUserOf(N1.getNode()) &&
6112           N1.getConstantOperandAPInt(1) == NumBitsPerElt) {
6113         Offset1 = true;
6114         N1 = N1.getOperand(0);
6115       }
6116     } else {
6117       APInt ZeroMask = APInt::getHighBitsSet(2 * NumBitsPerElt, NumBitsPerElt);
6118       if ((!(N0.isUndef() || EltsLHS.isZero()) &&
6119            !DAG.MaskedValueIsZero(N0, ZeroMask, EltsLHS, Depth + 1)) ||
6120           (!(N1.isUndef() || EltsRHS.isZero()) &&
6121            !DAG.MaskedValueIsZero(N1, ZeroMask, EltsRHS, Depth + 1)))
6122         return false;
6123     }
6124 
6125     bool IsUnary = (N0 == N1);
6126 
6127     Ops.push_back(N0);
6128     if (!IsUnary)
6129       Ops.push_back(N1);
6130 
6131     createPackShuffleMask(VT, Mask, IsUnary);
6132 
6133     if (Offset0 || Offset1) {
6134       for (int &M : Mask)
6135         if ((Offset0 && isInRange(M, 0, NumElts)) ||
6136             (Offset1 && isInRange(M, NumElts, 2 * NumElts)))
6137           ++M;
6138     }
6139     return true;
6140   }
6141   case ISD::VSELECT:
6142   case X86ISD::BLENDV: {
6143     SDValue Cond = N.getOperand(0);
6144     if (createShuffleMaskFromVSELECT(Mask, Cond, Opcode == X86ISD::BLENDV)) {
6145       Ops.push_back(N.getOperand(1));
6146       Ops.push_back(N.getOperand(2));
6147       return true;
6148     }
6149     return false;
6150   }
6151   case X86ISD::VTRUNC: {
6152     SDValue Src = N.getOperand(0);
6153     EVT SrcVT = Src.getValueType();
6154     // Truncated source must be a simple vector.
6155     if (!SrcVT.isSimple() || (SrcVT.getSizeInBits() % 128) != 0 ||
6156         (SrcVT.getScalarSizeInBits() % 8) != 0)
6157       return false;
6158     unsigned NumSrcElts = SrcVT.getVectorNumElements();
6159     unsigned NumBitsPerSrcElt = SrcVT.getScalarSizeInBits();
6160     unsigned Scale = NumBitsPerSrcElt / NumBitsPerElt;
6161     assert((NumBitsPerSrcElt % NumBitsPerElt) == 0 && "Illegal truncation");
6162     for (unsigned i = 0; i != NumSrcElts; ++i)
6163       Mask.push_back(i * Scale);
6164     Mask.append(NumElts - NumSrcElts, SM_SentinelZero);
6165     Ops.push_back(Src);
6166     return true;
6167   }
6168   case X86ISD::VSHLI:
6169   case X86ISD::VSRLI: {
6170     uint64_t ShiftVal = N.getConstantOperandVal(1);
6171     // Out of range bit shifts are guaranteed to be zero.
6172     if (NumBitsPerElt <= ShiftVal) {
6173       Mask.append(NumElts, SM_SentinelZero);
6174       return true;
6175     }
6176 
6177     // We can only decode 'whole byte' bit shifts as shuffles.
6178     if ((ShiftVal % 8) != 0)
6179       break;
6180 
6181     uint64_t ByteShift = ShiftVal / 8;
6182     Ops.push_back(N.getOperand(0));
6183 
6184     // Clear mask to all zeros and insert the shifted byte indices.
6185     Mask.append(NumSizeInBytes, SM_SentinelZero);
6186 
6187     if (X86ISD::VSHLI == Opcode) {
6188       for (unsigned i = 0; i != NumSizeInBytes; i += NumBytesPerElt)
6189         for (unsigned j = ByteShift; j != NumBytesPerElt; ++j)
6190           Mask[i + j] = i + j - ByteShift;
6191     } else {
6192       for (unsigned i = 0; i != NumSizeInBytes; i += NumBytesPerElt)
6193         for (unsigned j = ByteShift; j != NumBytesPerElt; ++j)
6194           Mask[i + j - ByteShift] = i + j;
6195     }
6196     return true;
6197   }
6198   case X86ISD::VROTLI:
6199   case X86ISD::VROTRI: {
6200     // We can only decode 'whole byte' bit rotates as shuffles.
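    // Illustrative example: a v4i32 VROTLI by 8 bits becomes the per-element
    // byte mask {3,0,1,2}, as rotating left by one byte moves the top byte of
    // each element into the lowest byte position.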
6201     uint64_t RotateVal = N.getConstantOperandAPInt(1).urem(NumBitsPerElt);
6202     if ((RotateVal % 8) != 0)
6203       return false;
6204     Ops.push_back(N.getOperand(0));
6205     int Offset = RotateVal / 8;
6206     Offset = (X86ISD::VROTLI == Opcode ? NumBytesPerElt - Offset : Offset);
6207     for (int i = 0; i != (int)NumElts; ++i) {
6208       int BaseIdx = i * NumBytesPerElt;
6209       for (int j = 0; j != (int)NumBytesPerElt; ++j) {
6210         Mask.push_back(BaseIdx + ((Offset + j) % NumBytesPerElt));
6211       }
6212     }
6213     return true;
6214   }
6215   case X86ISD::VBROADCAST: {
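    // A broadcast splats element 0 of its source vector (peeking through an
    // extract of element 0), so the decoded mask is all zeros.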
6216     SDValue Src = N.getOperand(0);
6217     if (!Src.getSimpleValueType().isVector()) {
6218       if (Src.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
6219           !isNullConstant(Src.getOperand(1)) ||
6220           Src.getOperand(0).getValueType().getScalarType() !=
6221               VT.getScalarType())
6222         return false;
6223       Src = Src.getOperand(0);
6224     }
6225     Ops.push_back(Src);
6226     Mask.append(NumElts, 0);
6227     return true;
6228   }
6229   case ISD::SIGN_EXTEND_VECTOR_INREG: {
6230     SDValue Src = N.getOperand(0);
6231     EVT SrcVT = Src.getValueType();
6232     unsigned NumBitsPerSrcElt = SrcVT.getScalarSizeInBits();
6233 
6234     // Extended source must be a simple vector.
6235     if (!SrcVT.isSimple() || (SrcVT.getSizeInBits() % 128) != 0 ||
6236         (NumBitsPerSrcElt % 8) != 0)
6237       return false;
6238 
6239     // We can only handle all-signbits extensions.
6240     APInt DemandedSrcElts =
6241         DemandedElts.zextOrTrunc(SrcVT.getVectorNumElements());
6242     if (DAG.ComputeNumSignBits(Src, DemandedSrcElts) != NumBitsPerSrcElt)
6243       return false;
6244 
6245     assert((NumBitsPerElt % NumBitsPerSrcElt) == 0 && "Unexpected extension");
6246     unsigned Scale = NumBitsPerElt / NumBitsPerSrcElt;
6247     for (unsigned I = 0; I != NumElts; ++I)
6248       Mask.append(Scale, I);
6249     Ops.push_back(Src);
6250     return true;
6251   }
6252   case ISD::ZERO_EXTEND:
6253   case ISD::ANY_EXTEND:
6254   case ISD::ZERO_EXTEND_VECTOR_INREG:
6255   case ISD::ANY_EXTEND_VECTOR_INREG: {
6256     SDValue Src = N.getOperand(0);
6257     EVT SrcVT = Src.getValueType();
6258 
6259     // Extended source must be a simple vector.
6260     if (!SrcVT.isSimple() || (SrcVT.getSizeInBits() % 128) != 0 ||
6261         (SrcVT.getScalarSizeInBits() % 8) != 0)
6262       return false;
6263 
6264     bool IsAnyExtend =
6265         (ISD::ANY_EXTEND == Opcode || ISD::ANY_EXTEND_VECTOR_INREG == Opcode);
6266     DecodeZeroExtendMask(SrcVT.getScalarSizeInBits(), NumBitsPerElt, NumElts,
6267                          IsAnyExtend, Mask);
6268     Ops.push_back(Src);
6269     return true;
6270   }
6271   }
6272 
6273   return false;
6274 }
6275 
6276 /// Removes unused/repeated shuffle source inputs and adjusts the shuffle mask.
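/// Illustrative example: with Inputs = {A, A} and a 4-element mask, indices in
/// [4,8) are remapped into [0,4) and the duplicate input is dropped; an input
/// that no mask element references is removed and all later indices are shifted
/// down by the mask width.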
6277 static void resolveTargetShuffleInputsAndMask(SmallVectorImpl<SDValue> &Inputs,
6278                                               SmallVectorImpl<int> &Mask) {
6279   int MaskWidth = Mask.size();
6280   SmallVector<SDValue, 16> UsedInputs;
6281   for (int i = 0, e = Inputs.size(); i < e; ++i) {
6282     int lo = UsedInputs.size() * MaskWidth;
6283     int hi = lo + MaskWidth;
6284 
6285     // Strip UNDEF input usage.
6286     if (Inputs[i].isUndef())
6287       for (int &M : Mask)
6288         if ((lo <= M) && (M < hi))
6289           M = SM_SentinelUndef;
6290 
6291     // Check for unused inputs.
6292     if (none_of(Mask, [lo, hi](int i) { return (lo <= i) && (i < hi); })) {
6293       for (int &M : Mask)
6294         if (lo <= M)
6295           M -= MaskWidth;
6296       continue;
6297     }
6298 
6299     // Check for repeated inputs.
6300     bool IsRepeat = false;
6301     for (int j = 0, ue = UsedInputs.size(); j != ue; ++j) {
6302       if (UsedInputs[j] != Inputs[i])
6303         continue;
6304       for (int &M : Mask)
6305         if (lo <= M)
6306           M = (M < hi) ? ((M - lo) + (j * MaskWidth)) : (M - MaskWidth);
6307       IsRepeat = true;
6308       break;
6309     }
6310     if (IsRepeat)
6311       continue;
6312 
6313     UsedInputs.push_back(Inputs[i]);
6314   }
6315   Inputs = UsedInputs;
6316 }
6317 
6318 /// Calls getTargetShuffleAndZeroables to resolve a target shuffle mask's inputs
6319 /// and then sets the SM_SentinelUndef and SM_SentinelZero values.
6320 /// Returns true if the target shuffle mask was decoded.
6321 static bool getTargetShuffleInputs(SDValue Op, const APInt &DemandedElts,
6322                                    SmallVectorImpl<SDValue> &Inputs,
6323                                    SmallVectorImpl<int> &Mask,
6324                                    APInt &KnownUndef, APInt &KnownZero,
6325                                    const SelectionDAG &DAG, unsigned Depth,
6326                                    bool ResolveKnownElts) {
6327   if (Depth >= SelectionDAG::MaxRecursionDepth)
6328     return false; // Limit search depth.
6329 
6330   EVT VT = Op.getValueType();
6331   if (!VT.isSimple() || !VT.isVector())
6332     return false;
6333 
6334   if (getTargetShuffleAndZeroables(Op, Mask, Inputs, KnownUndef, KnownZero)) {
6335     if (ResolveKnownElts)
6336       resolveTargetShuffleFromZeroables(Mask, KnownUndef, KnownZero);
6337     return true;
6338   }
6339   if (getFauxShuffleMask(Op, DemandedElts, Mask, Inputs, DAG, Depth,
6340                          ResolveKnownElts)) {
6341     resolveZeroablesFromTargetShuffle(Mask, KnownUndef, KnownZero);
6342     return true;
6343   }
6344   return false;
6345 }
6346 
6347 static bool getTargetShuffleInputs(SDValue Op, const APInt &DemandedElts,
6348                                    SmallVectorImpl<SDValue> &Inputs,
6349                                    SmallVectorImpl<int> &Mask,
6350                                    const SelectionDAG &DAG, unsigned Depth,
6351                                    bool ResolveKnownElts) {
6352   APInt KnownUndef, KnownZero;
6353   return getTargetShuffleInputs(Op, DemandedElts, Inputs, Mask, KnownUndef,
6354                                 KnownZero, DAG, Depth, ResolveKnownElts);
6355 }
6356 
6357 static bool getTargetShuffleInputs(SDValue Op, SmallVectorImpl<SDValue> &Inputs,
6358                                    SmallVectorImpl<int> &Mask,
6359                                    const SelectionDAG &DAG, unsigned Depth = 0,
6360                                    bool ResolveKnownElts = true) {
6361   EVT VT = Op.getValueType();
6362   if (!VT.isSimple() || !VT.isVector())
6363     return false;
6364 
6365   unsigned NumElts = Op.getValueType().getVectorNumElements();
6366   APInt DemandedElts = APInt::getAllOnes(NumElts);
6367   return getTargetShuffleInputs(Op, DemandedElts, Inputs, Mask, DAG, Depth,
6368                                 ResolveKnownElts);
6369 }
6370 
6371 // Attempt to create a scalar/subvector broadcast from the base MemSDNode.
6372 static SDValue getBROADCAST_LOAD(unsigned Opcode, const SDLoc &DL, EVT VT,
6373                                  EVT MemVT, MemSDNode *Mem, unsigned Offset,
6374                                  SelectionDAG &DAG) {
6375   assert((Opcode == X86ISD::VBROADCAST_LOAD ||
6376           Opcode == X86ISD::SUBV_BROADCAST_LOAD) &&
6377          "Unknown broadcast load type");
6378 
6379   // Ensure this is a simple (non-atomic, non-volatile), temporal read memop.
6380   if (!Mem || !Mem->readMem() || !Mem->isSimple() || Mem->isNonTemporal())
6381     return SDValue();
6382 
6383   SDValue Ptr = DAG.getMemBasePlusOffset(Mem->getBasePtr(),
6384                                          TypeSize::getFixed(Offset), DL);
6385   SDVTList Tys = DAG.getVTList(VT, MVT::Other);
6386   SDValue Ops[] = {Mem->getChain(), Ptr};
6387   SDValue BcstLd = DAG.getMemIntrinsicNode(
6388       Opcode, DL, Tys, Ops, MemVT,
6389       DAG.getMachineFunction().getMachineMemOperand(
6390           Mem->getMemOperand(), Offset, MemVT.getStoreSize()));
6391   DAG.makeEquivalentMemoryOrdering(SDValue(Mem, 1), BcstLd.getValue(1));
6392   return BcstLd;
6393 }
6394 
6395 /// Returns the scalar element that will make up the i'th
6396 /// element of the result of the vector shuffle.
6397 static SDValue getShuffleScalarElt(SDValue Op, unsigned Index,
6398                                    SelectionDAG &DAG, unsigned Depth) {
6399   if (Depth >= SelectionDAG::MaxRecursionDepth)
6400     return SDValue(); // Limit search depth.
6401 
6402   EVT VT = Op.getValueType();
6403   unsigned Opcode = Op.getOpcode();
6404   unsigned NumElems = VT.getVectorNumElements();
6405 
6406   // Recurse into ISD::VECTOR_SHUFFLE node to find scalars.
6407   if (auto *SV = dyn_cast<ShuffleVectorSDNode>(Op)) {
6408     int Elt = SV->getMaskElt(Index);
6409 
6410     if (Elt < 0)
6411       return DAG.getUNDEF(VT.getVectorElementType());
6412 
6413     SDValue Src = (Elt < (int)NumElems) ? SV->getOperand(0) : SV->getOperand(1);
6414     return getShuffleScalarElt(Src, Elt % NumElems, DAG, Depth + 1);
6415   }
6416 
6417   // Recurse into target specific vector shuffles to find scalars.
6418   if (isTargetShuffle(Opcode)) {
6419     MVT ShufVT = VT.getSimpleVT();
6420     MVT ShufSVT = ShufVT.getVectorElementType();
6421     int NumElems = (int)ShufVT.getVectorNumElements();
6422     SmallVector<int, 16> ShuffleMask;
6423     SmallVector<SDValue, 16> ShuffleOps;
6424     if (!getTargetShuffleMask(Op, true, ShuffleOps, ShuffleMask))
6425       return SDValue();
6426 
6427     int Elt = ShuffleMask[Index];
6428     if (Elt == SM_SentinelZero)
6429       return ShufSVT.isInteger() ? DAG.getConstant(0, SDLoc(Op), ShufSVT)
6430                                  : DAG.getConstantFP(+0.0, SDLoc(Op), ShufSVT);
6431     if (Elt == SM_SentinelUndef)
6432       return DAG.getUNDEF(ShufSVT);
6433 
6434     assert(0 <= Elt && Elt < (2 * NumElems) && "Shuffle index out of range");
6435     SDValue Src = (Elt < NumElems) ? ShuffleOps[0] : ShuffleOps[1];
6436     return getShuffleScalarElt(Src, Elt % NumElems, DAG, Depth + 1);
6437   }
6438 
6439   // Recurse into insert_subvector base/sub vector to find scalars.
6440   if (Opcode == ISD::INSERT_SUBVECTOR) {
6441     SDValue Vec = Op.getOperand(0);
6442     SDValue Sub = Op.getOperand(1);
6443     uint64_t SubIdx = Op.getConstantOperandVal(2);
6444     unsigned NumSubElts = Sub.getValueType().getVectorNumElements();
6445 
6446     if (SubIdx <= Index && Index < (SubIdx + NumSubElts))
6447       return getShuffleScalarElt(Sub, Index - SubIdx, DAG, Depth + 1);
6448     return getShuffleScalarElt(Vec, Index, DAG, Depth + 1);
6449   }
6450 
6451   // Recurse into concat_vectors sub vector to find scalars.
6452   if (Opcode == ISD::CONCAT_VECTORS) {
6453     EVT SubVT = Op.getOperand(0).getValueType();
6454     unsigned NumSubElts = SubVT.getVectorNumElements();
6455     uint64_t SubIdx = Index / NumSubElts;
6456     uint64_t SubElt = Index % NumSubElts;
6457     return getShuffleScalarElt(Op.getOperand(SubIdx), SubElt, DAG, Depth + 1);
6458   }
6459 
6460   // Recurse into extract_subvector src vector to find scalars.
6461   if (Opcode == ISD::EXTRACT_SUBVECTOR) {
6462     SDValue Src = Op.getOperand(0);
6463     uint64_t SrcIdx = Op.getConstantOperandVal(1);
6464     return getShuffleScalarElt(Src, Index + SrcIdx, DAG, Depth + 1);
6465   }
6466 
6467   // We only peek through bitcasts of the same vector width.
6468   if (Opcode == ISD::BITCAST) {
6469     SDValue Src = Op.getOperand(0);
6470     EVT SrcVT = Src.getValueType();
6471     if (SrcVT.isVector() && SrcVT.getVectorNumElements() == NumElems)
6472       return getShuffleScalarElt(Src, Index, DAG, Depth + 1);
6473     return SDValue();
6474   }
6475 
6476   // Actual nodes that may contain scalar elements
6477 
6478   // For insert_vector_elt - either return the index matching scalar or recurse
6479   // into the base vector.
6480   if (Opcode == ISD::INSERT_VECTOR_ELT &&
6481       isa<ConstantSDNode>(Op.getOperand(2))) {
6482     if (Op.getConstantOperandAPInt(2) == Index)
6483       return Op.getOperand(1);
6484     return getShuffleScalarElt(Op.getOperand(0), Index, DAG, Depth + 1);
6485   }
6486 
6487   if (Opcode == ISD::SCALAR_TO_VECTOR)
6488     return (Index == 0) ? Op.getOperand(0)
6489                         : DAG.getUNDEF(VT.getVectorElementType());
6490 
6491   if (Opcode == ISD::BUILD_VECTOR)
6492     return Op.getOperand(Index);
6493 
6494   return SDValue();
6495 }
6496 
6497 // Use PINSRB/PINSRW/PINSRD to create a build vector.
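// Illustrative example: a v8i16 build_vector whose only non-zero elements are at
// indices 2 and 5 starts from a zero vector (breaking any false dependency) and
// emits two PINSRW insertions; if element 0 is non-zero and no zeros are needed,
// the first element is materialized via SCALAR_TO_VECTOR instead.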
6498 static SDValue LowerBuildVectorAsInsert(SDValue Op, const SDLoc &DL,
6499                                         const APInt &NonZeroMask,
6500                                         unsigned NumNonZero, unsigned NumZero,
6501                                         SelectionDAG &DAG,
6502                                         const X86Subtarget &Subtarget) {
6503   MVT VT = Op.getSimpleValueType();
6504   unsigned NumElts = VT.getVectorNumElements();
6505   assert(((VT == MVT::v8i16 && Subtarget.hasSSE2()) ||
6506           ((VT == MVT::v16i8 || VT == MVT::v4i32) && Subtarget.hasSSE41())) &&
6507          "Illegal vector insertion");
6508 
6509   SDValue V;
6510   bool First = true;
6511 
6512   for (unsigned i = 0; i < NumElts; ++i) {
6513     bool IsNonZero = NonZeroMask[i];
6514     if (!IsNonZero)
6515       continue;
6516 
6517     // If the build vector contains zeros or our first insertion is not the
6518     // first index, then insert into a zero vector to break any register
6519     // dependency; else use SCALAR_TO_VECTOR.
6520     if (First) {
6521       First = false;
6522       if (NumZero || 0 != i)
6523         V = getZeroVector(VT, Subtarget, DAG, DL);
6524       else {
6525         assert(0 == i && "Expected insertion into zero-index");
6526         V = DAG.getAnyExtOrTrunc(Op.getOperand(i), DL, MVT::i32);
6527         V = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v4i32, V);
6528         V = DAG.getBitcast(VT, V);
6529         continue;
6530       }
6531     }
6532     V = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VT, V, Op.getOperand(i),
6533                     DAG.getIntPtrConstant(i, DL));
6534   }
6535 
6536   return V;
6537 }
6538 
6539 /// Custom lower build_vector of v16i8.
6540 static SDValue LowerBuildVectorv16i8(SDValue Op, const SDLoc &DL,
6541                                      const APInt &NonZeroMask,
6542                                      unsigned NumNonZero, unsigned NumZero,
6543                                      SelectionDAG &DAG,
6544                                      const X86Subtarget &Subtarget) {
6545   if (NumNonZero > 8 && !Subtarget.hasSSE41())
6546     return SDValue();
6547 
6548   // SSE4.1 - use PINSRB to insert each byte directly.
6549   if (Subtarget.hasSSE41())
6550     return LowerBuildVectorAsInsert(Op, DL, NonZeroMask, NumNonZero, NumZero,
6551                                     DAG, Subtarget);
6552 
6553   SDValue V;
6554 
6555   // Pre-SSE4.1 - merge byte pairs and insert with PINSRW.
6556   // If both of the lowest two 16-bit halves are non-zero, then convert to MOVD.
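  // Illustrative example: non-zero bytes at indices 0..3 are combined in a GPR
  // with shifts and ORs and moved into the vector via MOVD (VZEXT_MOVL); the
  // remaining bytes are merged in pairs (lo | hi << 8) and inserted as 16-bit
  // elements with PINSRW at index i/2.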
6557   if (!NonZeroMask.extractBits(2, 0).isZero() &&
6558       !NonZeroMask.extractBits(2, 2).isZero()) {
6559     for (unsigned I = 0; I != 4; ++I) {
6560       if (!NonZeroMask[I])
6561         continue;
6562       SDValue Elt = DAG.getZExtOrTrunc(Op.getOperand(I), DL, MVT::i32);
6563       if (I != 0)
6564         Elt = DAG.getNode(ISD::SHL, DL, MVT::i32, Elt,
6565                           DAG.getConstant(I * 8, DL, MVT::i8));
6566       V = V ? DAG.getNode(ISD::OR, DL, MVT::i32, V, Elt) : Elt;
6567     }
6568     assert(V && "Failed to fold v16i8 vector to zero");
6569     V = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v4i32, V);
6570     V = DAG.getNode(X86ISD::VZEXT_MOVL, DL, MVT::v4i32, V);
6571     V = DAG.getBitcast(MVT::v8i16, V);
6572   }
6573   for (unsigned i = V ? 4 : 0; i < 16; i += 2) {
6574     bool ThisIsNonZero = NonZeroMask[i];
6575     bool NextIsNonZero = NonZeroMask[i + 1];
6576     if (!ThisIsNonZero && !NextIsNonZero)
6577       continue;
6578 
6579     SDValue Elt;
6580     if (ThisIsNonZero) {
6581       if (NumZero || NextIsNonZero)
6582         Elt = DAG.getZExtOrTrunc(Op.getOperand(i), DL, MVT::i32);
6583       else
6584         Elt = DAG.getAnyExtOrTrunc(Op.getOperand(i), DL, MVT::i32);
6585     }
6586 
6587     if (NextIsNonZero) {
6588       SDValue NextElt = Op.getOperand(i + 1);
6589       if (i == 0 && NumZero)
6590         NextElt = DAG.getZExtOrTrunc(NextElt, DL, MVT::i32);
6591       else
6592         NextElt = DAG.getAnyExtOrTrunc(NextElt, DL, MVT::i32);
6593       NextElt = DAG.getNode(ISD::SHL, DL, MVT::i32, NextElt,
6594                             DAG.getConstant(8, DL, MVT::i8));
6595       if (ThisIsNonZero)
6596         Elt = DAG.getNode(ISD::OR, DL, MVT::i32, NextElt, Elt);
6597       else
6598         Elt = NextElt;
6599     }
6600 
6601     // If our first insertion is not the first index or zeros are needed, then
6602     // insert into a zero vector. Otherwise, use SCALAR_TO_VECTOR (leaves high
6603     // elements undefined).
6604     if (!V) {
6605       if (i != 0 || NumZero)
6606         V = getZeroVector(MVT::v8i16, Subtarget, DAG, DL);
6607       else {
6608         V = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v4i32, Elt);
6609         V = DAG.getBitcast(MVT::v8i16, V);
6610         continue;
6611       }
6612     }
6613     Elt = DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, Elt);
6614     V = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, MVT::v8i16, V, Elt,
6615                     DAG.getIntPtrConstant(i / 2, DL));
6616   }
6617 
6618   return DAG.getBitcast(MVT::v16i8, V);
6619 }
6620 
6621 /// Custom lower build_vector of v8i16.
6622 static SDValue LowerBuildVectorv8i16(SDValue Op, const SDLoc &DL,
6623                                      const APInt &NonZeroMask,
6624                                      unsigned NumNonZero, unsigned NumZero,
6625                                      SelectionDAG &DAG,
6626                                      const X86Subtarget &Subtarget) {
6627   if (NumNonZero > 4 && !Subtarget.hasSSE41())
6628     return SDValue();
6629 
6630   // Use PINSRW to insert each byte directly.
6631   return LowerBuildVectorAsInsert(Op, DL, NonZeroMask, NumNonZero, NumZero, DAG,
6632                                   Subtarget);
6633 }
6634 
6635 /// Custom lower build_vector of v4i32 or v4f32.
6636 static SDValue LowerBuildVectorv4x32(SDValue Op, const SDLoc &DL,
6637                                      SelectionDAG &DAG,
6638                                      const X86Subtarget &Subtarget) {
6639   // If this is a splat of a pair of elements, use MOVDDUP (unless the target
6640   // has XOP; in that case defer lowering to potentially use VPERMIL2PS).
6641   // Because we're creating a less complicated build vector here, we may enable
6642   // further folding of the MOVDDUP via shuffle transforms.
6643   if (Subtarget.hasSSE3() && !Subtarget.hasXOP() &&
6644       Op.getOperand(0) == Op.getOperand(2) &&
6645       Op.getOperand(1) == Op.getOperand(3) &&
6646       Op.getOperand(0) != Op.getOperand(1)) {
6647     MVT VT = Op.getSimpleValueType();
6648     MVT EltVT = VT.getVectorElementType();
6649     // Create a new build vector with the first 2 elements followed by undef
6650     // padding, bitcast to v2f64, duplicate, and bitcast back.
6651     SDValue Ops[4] = { Op.getOperand(0), Op.getOperand(1),
6652                        DAG.getUNDEF(EltVT), DAG.getUNDEF(EltVT) };
6653     SDValue NewBV = DAG.getBitcast(MVT::v2f64, DAG.getBuildVector(VT, DL, Ops));
6654     SDValue Dup = DAG.getNode(X86ISD::MOVDDUP, DL, MVT::v2f64, NewBV);
6655     return DAG.getBitcast(VT, Dup);
6656   }
6657 
6658   // Find all zeroable elements.
6659   std::bitset<4> Zeroable, Undefs;
6660   for (int i = 0; i < 4; ++i) {
6661     SDValue Elt = Op.getOperand(i);
6662     Undefs[i] = Elt.isUndef();
6663     Zeroable[i] = (Elt.isUndef() || X86::isZeroNode(Elt));
6664   }
6665   assert(Zeroable.size() - Zeroable.count() > 1 &&
6666          "We expect at least two non-zero elements!");
6667 
6668   // We only know how to deal with build_vector nodes where elements are either
6669   // zeroable or extract_vector_elt with constant index.
6670   SDValue FirstNonZero;
6671   unsigned FirstNonZeroIdx;
6672   for (unsigned i = 0; i < 4; ++i) {
6673     if (Zeroable[i])
6674       continue;
6675     SDValue Elt = Op.getOperand(i);
6676     if (Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
6677         !isa<ConstantSDNode>(Elt.getOperand(1)))
6678       return SDValue();
6679     // Make sure that this node is extracting from a 128-bit vector.
6680     MVT VT = Elt.getOperand(0).getSimpleValueType();
6681     if (!VT.is128BitVector())
6682       return SDValue();
6683     if (!FirstNonZero.getNode()) {
6684       FirstNonZero = Elt;
6685       FirstNonZeroIdx = i;
6686     }
6687   }
6688 
6689   assert(FirstNonZero.getNode() && "Unexpected build vector of all zeros!");
6690   SDValue V1 = FirstNonZero.getOperand(0);
6691   MVT VT = V1.getSimpleValueType();
6692 
6693   // See if this build_vector can be lowered as a blend with zero.
6694   SDValue Elt;
6695   unsigned EltMaskIdx, EltIdx;
6696   int Mask[4];
6697   for (EltIdx = 0; EltIdx < 4; ++EltIdx) {
6698     if (Zeroable[EltIdx]) {
6699       // The zero vector will be on the right hand side.
6700       Mask[EltIdx] = EltIdx+4;
6701       continue;
6702     }
6703 
6704     Elt = Op->getOperand(EltIdx);
6705     // By construction, Elt is a EXTRACT_VECTOR_ELT with constant index.
6706     EltMaskIdx = Elt.getConstantOperandVal(1);
6707     if (Elt.getOperand(0) != V1 || EltMaskIdx != EltIdx)
6708       break;
6709     Mask[EltIdx] = EltIdx;
6710   }
6711 
6712   if (EltIdx == 4) {
6713     // Let the shuffle legalizer deal with blend operations.
6714     SDValue VZeroOrUndef = (Zeroable == Undefs)
6715                                ? DAG.getUNDEF(VT)
6716                                : getZeroVector(VT, Subtarget, DAG, DL);
6717     if (V1.getSimpleValueType() != VT)
6718       V1 = DAG.getBitcast(VT, V1);
6719     return DAG.getVectorShuffle(VT, SDLoc(V1), V1, VZeroOrUndef, Mask);
6720   }
6721 
6722   // See if we can lower this build_vector to a INSERTPS.
6723   if (!Subtarget.hasSSE41())
6724     return SDValue();
6725 
6726   SDValue V2 = Elt.getOperand(0);
6727   if (Elt == FirstNonZero && EltIdx == FirstNonZeroIdx)
6728     V1 = SDValue();
6729 
6730   bool CanFold = true;
6731   for (unsigned i = EltIdx + 1; i < 4 && CanFold; ++i) {
6732     if (Zeroable[i])
6733       continue;
6734 
6735     SDValue Current = Op->getOperand(i);
6736     SDValue SrcVector = Current->getOperand(0);
6737     if (!V1.getNode())
6738       V1 = SrcVector;
6739     CanFold = (SrcVector == V1) && (Current.getConstantOperandAPInt(1) == i);
6740   }
6741 
6742   if (!CanFold)
6743     return SDValue();
6744 
6745   assert(V1.getNode() && "Expected at least two non-zero elements!");
6746   if (V1.getSimpleValueType() != MVT::v4f32)
6747     V1 = DAG.getBitcast(MVT::v4f32, V1);
6748   if (V2.getSimpleValueType() != MVT::v4f32)
6749     V2 = DAG.getBitcast(MVT::v4f32, V2);
6750 
6751   // Ok, we can emit an INSERTPS instruction.
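  // The INSERTPS immediate packs the source element index (bits 7:6), the
  // destination index (bits 5:4) and the zeroed-lane mask (bits 3:0); e.g.
  // copying source lane 2 into destination lane 1 while zeroing lane 3 gives 0x98.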
6752   unsigned ZMask = Zeroable.to_ulong();
6753 
6754   unsigned InsertPSMask = EltMaskIdx << 6 | EltIdx << 4 | ZMask;
6755   assert((InsertPSMask & ~0xFFu) == 0 && "Invalid mask!");
6756   SDValue Result = DAG.getNode(X86ISD::INSERTPS, DL, MVT::v4f32, V1, V2,
6757                                DAG.getIntPtrConstant(InsertPSMask, DL, true));
6758   return DAG.getBitcast(VT, Result);
6759 }
6760 
6761 /// Return a vector logical shift node.
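/// The shift is performed as a whole-register byte shift: the source is bitcast
/// to v16i8 and shifted by NumBits / 8 bytes with VSHLDQ/VSRLDQ (PSLLDQ/PSRLDQ);
/// e.g. getVShift(true, MVT::v2i64, X, 64, ...) shifts X left by 8 bytes.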
6762 static SDValue getVShift(bool isLeft, EVT VT, SDValue SrcOp, unsigned NumBits,
6763                          SelectionDAG &DAG, const TargetLowering &TLI,
6764                          const SDLoc &dl) {
6765   assert(VT.is128BitVector() && "Unknown type for VShift");
6766   MVT ShVT = MVT::v16i8;
6767   unsigned Opc = isLeft ? X86ISD::VSHLDQ : X86ISD::VSRLDQ;
6768   SrcOp = DAG.getBitcast(ShVT, SrcOp);
6769   assert(NumBits % 8 == 0 && "Only support byte sized shifts");
6770   SDValue ShiftVal = DAG.getTargetConstant(NumBits / 8, dl, MVT::i8);
6771   return DAG.getBitcast(VT, DAG.getNode(Opc, dl, ShVT, SrcOp, ShiftVal));
6772 }
6773 
6774 static SDValue LowerAsSplatVectorLoad(SDValue SrcOp, MVT VT, const SDLoc &dl,
6775                                       SelectionDAG &DAG) {
6776 
6777   // Check if the scalar load can be widened into a vector load. And if
6778   // the address is "base + cst", see if the cst can be "absorbed" into
6779   // the shuffle mask.
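  // Illustrative example: a v4i32 splat of an i32 load from FI+8 can become a
  // 16-byte aligned v4i32 load from FI followed by a splat of element 2
  // ((8 - 0) >> 2), provided the stack object's alignment can be raised.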
6780   if (LoadSDNode *LD = dyn_cast<LoadSDNode>(SrcOp)) {
6781     SDValue Ptr = LD->getBasePtr();
6782     if (!ISD::isNormalLoad(LD) || !LD->isSimple())
6783       return SDValue();
6784     EVT PVT = LD->getValueType(0);
6785     if (PVT != MVT::i32 && PVT != MVT::f32)
6786       return SDValue();
6787 
6788     int FI = -1;
6789     int64_t Offset = 0;
6790     if (FrameIndexSDNode *FINode = dyn_cast<FrameIndexSDNode>(Ptr)) {
6791       FI = FINode->getIndex();
6792       Offset = 0;
6793     } else if (DAG.isBaseWithConstantOffset(Ptr) &&
6794                isa<FrameIndexSDNode>(Ptr.getOperand(0))) {
6795       FI = cast<FrameIndexSDNode>(Ptr.getOperand(0))->getIndex();
6796       Offset = Ptr.getConstantOperandVal(1);
6797       Ptr = Ptr.getOperand(0);
6798     } else {
6799       return SDValue();
6800     }
6801 
6802     // FIXME: 256-bit vector instructions don't require a strict alignment,
6803     // improve this code to support it better.
6804     Align RequiredAlign(VT.getSizeInBits() / 8);
6805     SDValue Chain = LD->getChain();
6806     // Make sure the stack object alignment is at least 16 or 32.
6807     MachineFrameInfo &MFI = DAG.getMachineFunction().getFrameInfo();
6808     MaybeAlign InferredAlign = DAG.InferPtrAlign(Ptr);
6809     if (!InferredAlign || *InferredAlign < RequiredAlign) {
6810       if (MFI.isFixedObjectIndex(FI)) {
6811         // Can't change the alignment. FIXME: It's possible to compute
6812         // the exact stack offset and reference FI + adjust offset instead.
6813         // If someone *really* cares about this, that's the way to implement it.
6814         return SDValue();
6815       } else {
6816         MFI.setObjectAlignment(FI, RequiredAlign);
6817       }
6818     }
6819 
6820     // (Offset % 16 or 32) must be a multiple of 4. The address is then
6821     // Ptr + (Offset & ~15).
6822     if (Offset < 0)
6823       return SDValue();
6824     if ((Offset % RequiredAlign.value()) & 3)
6825       return SDValue();
6826     int64_t StartOffset = Offset & ~int64_t(RequiredAlign.value() - 1);
6827     if (StartOffset) {
6828       SDLoc DL(Ptr);
6829       Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr,
6830                         DAG.getConstant(StartOffset, DL, Ptr.getValueType()));
6831     }
6832 
6833     int EltNo = (Offset - StartOffset) >> 2;
6834     unsigned NumElems = VT.getVectorNumElements();
6835 
6836     EVT NVT = EVT::getVectorVT(*DAG.getContext(), PVT, NumElems);
6837     SDValue V1 = DAG.getLoad(NVT, dl, Chain, Ptr,
6838                              LD->getPointerInfo().getWithOffset(StartOffset));
6839 
6840     SmallVector<int, 8> Mask(NumElems, EltNo);
6841 
6842     return DAG.getVectorShuffle(NVT, dl, V1, DAG.getUNDEF(NVT), Mask);
6843   }
6844 
6845   return SDValue();
6846 }
6847 
6848 // Recurse to find a LoadSDNode source and the accumulated ByteOffset.
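// Illustrative example: Elt = (srl (load i64 p), 32) resolves to the i64 load
// with ByteOffset = 4; bitcast, truncate and scalar_to_vector are looked through
// without changing the offset.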
6849 static bool findEltLoadSrc(SDValue Elt, LoadSDNode *&Ld, int64_t &ByteOffset) {
6850   if (ISD::isNON_EXTLoad(Elt.getNode())) {
6851     auto *BaseLd = cast<LoadSDNode>(Elt);
6852     if (!BaseLd->isSimple())
6853       return false;
6854     Ld = BaseLd;
6855     ByteOffset = 0;
6856     return true;
6857   }
6858 
6859   switch (Elt.getOpcode()) {
6860   case ISD::BITCAST:
6861   case ISD::TRUNCATE:
6862   case ISD::SCALAR_TO_VECTOR:
6863     return findEltLoadSrc(Elt.getOperand(0), Ld, ByteOffset);
6864   case ISD::SRL:
6865     if (auto *AmtC = dyn_cast<ConstantSDNode>(Elt.getOperand(1))) {
6866       uint64_t Amt = AmtC->getZExtValue();
6867       if ((Amt % 8) == 0 && findEltLoadSrc(Elt.getOperand(0), Ld, ByteOffset)) {
6868         ByteOffset += Amt / 8;
6869         return true;
6870       }
6871     }
6872     break;
6873   case ISD::EXTRACT_VECTOR_ELT:
6874     if (auto *IdxC = dyn_cast<ConstantSDNode>(Elt.getOperand(1))) {
6875       SDValue Src = Elt.getOperand(0);
6876       unsigned SrcSizeInBits = Src.getScalarValueSizeInBits();
6877       unsigned DstSizeInBits = Elt.getScalarValueSizeInBits();
6878       if (DstSizeInBits == SrcSizeInBits && (SrcSizeInBits % 8) == 0 &&
6879           findEltLoadSrc(Src, Ld, ByteOffset)) {
6880         uint64_t Idx = IdxC->getZExtValue();
6881         ByteOffset += Idx * (SrcSizeInBits / 8);
6882         return true;
6883       }
6884     }
6885     break;
6886   }
6887 
6888   return false;
6889 }
6890 
6891 /// Given the initializing elements 'Elts' of a vector of type 'VT', see if the
6892 /// elements can be replaced by a single large load which has the same value as
6893 /// a build_vector or insert_subvector whose loaded operands are 'Elts'.
6894 ///
6895 /// Example: <load i32 *a, load i32 *a+4, zero, undef> -> zextload a
6896 static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
6897                                         const SDLoc &DL, SelectionDAG &DAG,
6898                                         const X86Subtarget &Subtarget,
6899                                         bool IsAfterLegalize) {
6900   if ((VT.getScalarSizeInBits() % 8) != 0)
6901     return SDValue();
6902 
6903   unsigned NumElems = Elts.size();
6904 
6905   int LastLoadedElt = -1;
6906   APInt LoadMask = APInt::getZero(NumElems);
6907   APInt ZeroMask = APInt::getZero(NumElems);
6908   APInt UndefMask = APInt::getZero(NumElems);
6909 
6910   SmallVector<LoadSDNode*, 8> Loads(NumElems, nullptr);
6911   SmallVector<int64_t, 8> ByteOffsets(NumElems, 0);
6912 
6913   // For each element in the initializer, see if we've found a load, zero or an
6914   // undef.
6915   for (unsigned i = 0; i < NumElems; ++i) {
6916     SDValue Elt = peekThroughBitcasts(Elts[i]);
6917     if (!Elt.getNode())
6918       return SDValue();
6919     if (Elt.isUndef()) {
6920       UndefMask.setBit(i);
6921       continue;
6922     }
6923     if (X86::isZeroNode(Elt) || ISD::isBuildVectorAllZeros(Elt.getNode())) {
6924       ZeroMask.setBit(i);
6925       continue;
6926     }
6927 
6928     // Each loaded element must be the correct fractional portion of the
6929     // requested vector load.
6930     unsigned EltSizeInBits = Elt.getValueSizeInBits();
6931     if ((NumElems * EltSizeInBits) != VT.getSizeInBits())
6932       return SDValue();
6933 
6934     if (!findEltLoadSrc(Elt, Loads[i], ByteOffsets[i]) || ByteOffsets[i] < 0)
6935       return SDValue();
6936     unsigned LoadSizeInBits = Loads[i]->getValueSizeInBits(0);
6937     if (((ByteOffsets[i] * 8) + EltSizeInBits) > LoadSizeInBits)
6938       return SDValue();
6939 
6940     LoadMask.setBit(i);
6941     LastLoadedElt = i;
6942   }
6943   assert((ZeroMask.popcount() + UndefMask.popcount() + LoadMask.popcount()) ==
6944              NumElems &&
6945          "Incomplete element masks");
6946 
6947   // Handle Special Cases - all undef or undef/zero.
6948   if (UndefMask.popcount() == NumElems)
6949     return DAG.getUNDEF(VT);
6950   if ((ZeroMask.popcount() + UndefMask.popcount()) == NumElems)
6951     return VT.isInteger() ? DAG.getConstant(0, DL, VT)
6952                           : DAG.getConstantFP(0.0, DL, VT);
6953 
6954   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
6955   int FirstLoadedElt = LoadMask.countr_zero();
6956   SDValue EltBase = peekThroughBitcasts(Elts[FirstLoadedElt]);
6957   EVT EltBaseVT = EltBase.getValueType();
6958   assert(EltBaseVT.getSizeInBits() == EltBaseVT.getStoreSizeInBits() &&
6959          "Register/Memory size mismatch");
6960   LoadSDNode *LDBase = Loads[FirstLoadedElt];
6961   assert(LDBase && "Did not find base load for merging consecutive loads");
6962   unsigned BaseSizeInBits = EltBaseVT.getStoreSizeInBits();
6963   unsigned BaseSizeInBytes = BaseSizeInBits / 8;
6964   int NumLoadedElts = (1 + LastLoadedElt - FirstLoadedElt);
6965   int LoadSizeInBits = NumLoadedElts * BaseSizeInBits;
6966   assert((BaseSizeInBits % 8) == 0 && "Sub-byte element loads detected");
6967 
6968   // TODO: Support offsetting the base load.
6969   if (ByteOffsets[FirstLoadedElt] != 0)
6970     return SDValue();
6971 
6972   // Check to see if the element's load is consecutive to the base load
6973   // or offset from a previous (already checked) load.
6974   auto CheckConsecutiveLoad = [&](LoadSDNode *Base, int EltIdx) {
6975     LoadSDNode *Ld = Loads[EltIdx];
6976     int64_t ByteOffset = ByteOffsets[EltIdx];
6977     if (ByteOffset && (ByteOffset % BaseSizeInBytes) == 0) {
6978       int64_t BaseIdx = EltIdx - (ByteOffset / BaseSizeInBytes);
6979       return (0 <= BaseIdx && BaseIdx < (int)NumElems && LoadMask[BaseIdx] &&
6980               Loads[BaseIdx] == Ld && ByteOffsets[BaseIdx] == 0);
6981     }
6982     return DAG.areNonVolatileConsecutiveLoads(Ld, Base, BaseSizeInBytes,
6983                                               EltIdx - FirstLoadedElt);
6984   };
6985 
6986   // Consecutive loads can contain UNDEFS but not ZERO elements.
6987   // Consecutive loads with UNDEF and ZERO elements require an
6988   // additional shuffle stage to clear the ZERO elements.
6989   bool IsConsecutiveLoad = true;
6990   bool IsConsecutiveLoadWithZeros = true;
6991   for (int i = FirstLoadedElt + 1; i <= LastLoadedElt; ++i) {
6992     if (LoadMask[i]) {
6993       if (!CheckConsecutiveLoad(LDBase, i)) {
6994         IsConsecutiveLoad = false;
6995         IsConsecutiveLoadWithZeros = false;
6996         break;
6997       }
6998     } else if (ZeroMask[i]) {
6999       IsConsecutiveLoad = false;
7000     }
7001   }
7002 
7003   auto CreateLoad = [&DAG, &DL, &Loads](EVT VT, LoadSDNode *LDBase) {
7004     auto MMOFlags = LDBase->getMemOperand()->getFlags();
7005     assert(LDBase->isSimple() &&
7006            "Cannot merge volatile or atomic loads.");
7007     SDValue NewLd =
7008         DAG.getLoad(VT, DL, LDBase->getChain(), LDBase->getBasePtr(),
7009                     LDBase->getPointerInfo(), LDBase->getOriginalAlign(),
7010                     MMOFlags);
7011     for (auto *LD : Loads)
7012       if (LD)
7013         DAG.makeEquivalentMemoryOrdering(LD, NewLd);
7014     return NewLd;
7015   };
7016 
7017   // Check if the base load is entirely dereferenceable.
7018   bool IsDereferenceable = LDBase->getPointerInfo().isDereferenceable(
7019       VT.getSizeInBits() / 8, *DAG.getContext(), DAG.getDataLayout());
7020 
7021   // LOAD - all consecutive load/undefs (must start/end with a load or be
7022   // entirely dereferenceable). If we have found an entire vector of loads and
7023   // undefs, then return a large load of the entire vector width starting at the
7024   // base pointer. If the vector contains zeros, then attempt to shuffle those
7025   // elements.
7026   if (FirstLoadedElt == 0 &&
7027       (NumLoadedElts == (int)NumElems || IsDereferenceable) &&
7028       (IsConsecutiveLoad || IsConsecutiveLoadWithZeros)) {
7029     if (IsAfterLegalize && !TLI.isOperationLegal(ISD::LOAD, VT))
7030       return SDValue();
7031 
7032     // Don't create 256-bit non-temporal aligned loads without AVX2 as these
7033     // will lower to regular temporal loads and use the cache.
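    // (A 256-bit non-temporal vector load, VMOVNTDQA ymm, requires AVX2.)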
7034     if (LDBase->isNonTemporal() && LDBase->getAlign() >= Align(32) &&
7035         VT.is256BitVector() && !Subtarget.hasInt256())
7036       return SDValue();
7037 
7038     if (NumElems == 1)
7039       return DAG.getBitcast(VT, Elts[FirstLoadedElt]);
7040 
7041     if (!ZeroMask)
7042       return CreateLoad(VT, LDBase);
7043 
7044     // IsConsecutiveLoadWithZeros - we need to create a shuffle of the loaded
7045     // vector and a zero vector to clear out the zero elements.
7046     if (!IsAfterLegalize && VT.isVector()) {
7047       unsigned NumMaskElts = VT.getVectorNumElements();
7048       if ((NumMaskElts % NumElems) == 0) {
7049         unsigned Scale = NumMaskElts / NumElems;
7050         SmallVector<int, 4> ClearMask(NumMaskElts, -1);
7051         for (unsigned i = 0; i < NumElems; ++i) {
7052           if (UndefMask[i])
7053             continue;
7054           int Offset = ZeroMask[i] ? NumMaskElts : 0;
7055           for (unsigned j = 0; j != Scale; ++j)
7056             ClearMask[(i * Scale) + j] = (i * Scale) + j + Offset;
7057         }
7058         SDValue V = CreateLoad(VT, LDBase);
7059         SDValue Z = VT.isInteger() ? DAG.getConstant(0, DL, VT)
7060                                    : DAG.getConstantFP(0.0, DL, VT);
7061         return DAG.getVectorShuffle(VT, DL, V, Z, ClearMask);
7062       }
7063     }
7064   }
7065 
7066   // If the upper half of a ymm/zmm load is undef then just load the lower half.
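  // E.g. a v8f32 build_vector whose top four elements are undef is lowered as
  // a v4f32 load inserted into the low half of an undef v8f32 vector.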
7067   if (VT.is256BitVector() || VT.is512BitVector()) {
7068     unsigned HalfNumElems = NumElems / 2;
7069     if (UndefMask.extractBits(HalfNumElems, HalfNumElems).isAllOnes()) {
7070       EVT HalfVT =
7071           EVT::getVectorVT(*DAG.getContext(), VT.getScalarType(), HalfNumElems);
7072       SDValue HalfLD =
7073           EltsFromConsecutiveLoads(HalfVT, Elts.drop_back(HalfNumElems), DL,
7074                                    DAG, Subtarget, IsAfterLegalize);
7075       if (HalfLD)
7076         return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, DAG.getUNDEF(VT),
7077                            HalfLD, DAG.getIntPtrConstant(0, DL));
7078     }
7079   }
7080 
7081   // VZEXT_LOAD - consecutive 32/64-bit load/undefs followed by zeros/undefs.
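  // E.g. v4i32 <ld+0, ld+4, zero, zero> can become a single 64-bit
  // zero-extending load (X86ISD::VZEXT_LOAD, a movq-style load of the low
  // elements with the upper elements zeroed).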
7082   if (IsConsecutiveLoad && FirstLoadedElt == 0 &&
7083       ((LoadSizeInBits == 16 && Subtarget.hasFP16()) || LoadSizeInBits == 32 ||
7084        LoadSizeInBits == 64) &&
7085       ((VT.is128BitVector() || VT.is256BitVector() || VT.is512BitVector()))) {
7086     MVT VecSVT = VT.isFloatingPoint() ? MVT::getFloatingPointVT(LoadSizeInBits)
7087                                       : MVT::getIntegerVT(LoadSizeInBits);
7088     MVT VecVT = MVT::getVectorVT(VecSVT, VT.getSizeInBits() / LoadSizeInBits);
7089     // Allow v4f32 on SSE1-only targets.
7090     // FIXME: Add more isel patterns so we can just use VT directly.
7091     if (!Subtarget.hasSSE2() && VT == MVT::v4f32)
7092       VecVT = MVT::v4f32;
7093     if (TLI.isTypeLegal(VecVT)) {
7094       SDVTList Tys = DAG.getVTList(VecVT, MVT::Other);
7095       SDValue Ops[] = { LDBase->getChain(), LDBase->getBasePtr() };
7096       SDValue ResNode = DAG.getMemIntrinsicNode(
7097           X86ISD::VZEXT_LOAD, DL, Tys, Ops, VecSVT, LDBase->getPointerInfo(),
7098           LDBase->getOriginalAlign(), MachineMemOperand::MOLoad);
7099       for (auto *LD : Loads)
7100         if (LD)
7101           DAG.makeEquivalentMemoryOrdering(LD, ResNode);
7102       return DAG.getBitcast(VT, ResNode);
7103     }
7104   }
7105 
7106   // BROADCAST - match the smallest possible repetition pattern, load that
7107   // scalar/subvector element and then broadcast to the entire vector.
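  // E.g. v8i32 <a,b,a,b,a,b,a,b> built from consecutive loads of a and b is
  // matched as a repeating 64-bit pattern and emitted as a broadcast of the
  // combined 64-bit scalar (e.g. vpbroadcastq).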
7108   if (ZeroMask.isZero() && isPowerOf2_32(NumElems) && Subtarget.hasAVX() &&
7109       (VT.is128BitVector() || VT.is256BitVector() || VT.is512BitVector())) {
7110     for (unsigned SubElems = 1; SubElems < NumElems; SubElems *= 2) {
7111       unsigned RepeatSize = SubElems * BaseSizeInBits;
7112       unsigned ScalarSize = std::min(RepeatSize, 64u);
7113       if (!Subtarget.hasAVX2() && ScalarSize < 32)
7114         continue;
7115 
7116       // Don't attempt a 1:N subvector broadcast - it should be caught by
7117       // combineConcatVectorOps, otherwise it will cause infinite loops.
7118       if (RepeatSize > ScalarSize && SubElems == 1)
7119         continue;
7120 
7121       bool Match = true;
7122       SmallVector<SDValue, 8> RepeatedLoads(SubElems, DAG.getUNDEF(EltBaseVT));
7123       for (unsigned i = 0; i != NumElems && Match; ++i) {
7124         if (!LoadMask[i])
7125           continue;
7126         SDValue Elt = peekThroughBitcasts(Elts[i]);
7127         if (RepeatedLoads[i % SubElems].isUndef())
7128           RepeatedLoads[i % SubElems] = Elt;
7129         else
7130           Match &= (RepeatedLoads[i % SubElems] == Elt);
7131       }
7132 
7133       // We must have loads at both ends of the repetition.
7134       Match &= !RepeatedLoads.front().isUndef();
7135       Match &= !RepeatedLoads.back().isUndef();
7136       if (!Match)
7137         continue;
7138 
7139       EVT RepeatVT =
7140           VT.isInteger() && (RepeatSize != 64 || TLI.isTypeLegal(MVT::i64))
7141               ? EVT::getIntegerVT(*DAG.getContext(), ScalarSize)
7142               : EVT::getFloatingPointVT(ScalarSize);
7143       if (RepeatSize > ScalarSize)
7144         RepeatVT = EVT::getVectorVT(*DAG.getContext(), RepeatVT,
7145                                     RepeatSize / ScalarSize);
7146       EVT BroadcastVT =
7147           EVT::getVectorVT(*DAG.getContext(), RepeatVT.getScalarType(),
7148                            VT.getSizeInBits() / ScalarSize);
7149       if (TLI.isTypeLegal(BroadcastVT)) {
7150         if (SDValue RepeatLoad = EltsFromConsecutiveLoads(
7151                 RepeatVT, RepeatedLoads, DL, DAG, Subtarget, IsAfterLegalize)) {
7152           SDValue Broadcast = RepeatLoad;
7153           if (RepeatSize > ScalarSize) {
7154             while (Broadcast.getValueSizeInBits() < VT.getSizeInBits())
7155               Broadcast = concatSubVectors(Broadcast, Broadcast, DAG, DL);
7156           } else {
7157             if (!Subtarget.hasAVX2() &&
7158                 !X86::mayFoldLoadIntoBroadcastFromMem(
7159                     RepeatLoad, RepeatVT.getScalarType().getSimpleVT(),
7160                     Subtarget,
7161                     /*AssumeSingleUse=*/true))
7162               return SDValue();
7163             Broadcast =
7164                 DAG.getNode(X86ISD::VBROADCAST, DL, BroadcastVT, RepeatLoad);
7165           }
7166           return DAG.getBitcast(VT, Broadcast);
7167         }
7168       }
7169     }
7170   }
7171 
7172   return SDValue();
7173 }
7174 
7175 // Combine a vector ops (shuffles etc.) that is equal to build_vector load1,
7176 // load2, load3, load4, <0, 1, 2, 3> into a vector load if the load addresses
7177 // are consecutive, non-overlapping, and in the right order.
7178 static SDValue combineToConsecutiveLoads(EVT VT, SDValue Op, const SDLoc &DL,
7179                                          SelectionDAG &DAG,
7180                                          const X86Subtarget &Subtarget,
7181                                          bool IsAfterLegalize) {
7182   SmallVector<SDValue, 64> Elts;
7183   for (unsigned i = 0, e = VT.getVectorNumElements(); i != e; ++i) {
7184     if (SDValue Elt = getShuffleScalarElt(Op, i, DAG, 0)) {
7185       Elts.push_back(Elt);
7186       continue;
7187     }
7188     return SDValue();
7189   }
7190   assert(Elts.size() == VT.getVectorNumElements());
7191   return EltsFromConsecutiveLoads(VT, Elts, DL, DAG, Subtarget,
7192                                   IsAfterLegalize);
7193 }
7194 
7195 static Constant *getConstantVector(MVT VT, ArrayRef<APInt> Bits,
7196                                    const APInt &Undefs, LLVMContext &C) {
7197   unsigned ScalarSize = VT.getScalarSizeInBits();
7198   Type *Ty = EVT(VT.getScalarType()).getTypeForEVT(C);
7199 
7200   auto getConstantScalar = [&](const APInt &Val) -> Constant * {
7201     if (VT.isFloatingPoint()) {
7202       if (ScalarSize == 16)
7203         return ConstantFP::get(C, APFloat(APFloat::IEEEhalf(), Val));
7204       if (ScalarSize == 32)
7205         return ConstantFP::get(C, APFloat(APFloat::IEEEsingle(), Val));
7206       assert(ScalarSize == 64 && "Unsupported floating point scalar size");
7207       return ConstantFP::get(C, APFloat(APFloat::IEEEdouble(), Val));
7208     }
7209     return Constant::getIntegerValue(Ty, Val);
7210   };
7211 
7212   SmallVector<Constant *, 32> ConstantVec;
7213   for (unsigned I = 0, E = Bits.size(); I != E; ++I)
7214     ConstantVec.push_back(Undefs[I] ? UndefValue::get(Ty)
7215                                     : getConstantScalar(Bits[I]));
7216 
7217   return ConstantVector::get(ArrayRef<Constant *>(ConstantVec));
7218 }
7219 
7220 static Constant *getConstantVector(MVT VT, const APInt &SplatValue,
7221                                    unsigned SplatBitSize, LLVMContext &C) {
7222   unsigned ScalarSize = VT.getScalarSizeInBits();
7223 
7224   auto getConstantScalar = [&](const APInt &Val) -> Constant * {
7225     if (VT.isFloatingPoint()) {
7226       if (ScalarSize == 16)
7227         return ConstantFP::get(C, APFloat(APFloat::IEEEhalf(), Val));
7228       if (ScalarSize == 32)
7229         return ConstantFP::get(C, APFloat(APFloat::IEEEsingle(), Val));
7230       assert(ScalarSize == 64 && "Unsupported floating point scalar size");
7231       return ConstantFP::get(C, APFloat(APFloat::IEEEdouble(), Val));
7232     }
7233     return Constant::getIntegerValue(Type::getIntNTy(C, ScalarSize), Val);
7234   };
7235 
7236   if (ScalarSize == SplatBitSize)
7237     return getConstantScalar(SplatValue);
7238 
7239   unsigned NumElm = SplatBitSize / ScalarSize;
7240   SmallVector<Constant *, 32> ConstantVec;
7241   for (unsigned I = 0; I != NumElm; ++I) {
7242     APInt Val = SplatValue.extractBits(ScalarSize, ScalarSize * I);
7243     ConstantVec.push_back(getConstantScalar(Val));
7244   }
7245   return ConstantVector::get(ArrayRef<Constant *>(ConstantVec));
7246 }
7247 
7248 static bool isFoldableUseOfShuffle(SDNode *N) {
7249   for (auto *U : N->uses()) {
7250     unsigned Opc = U->getOpcode();
7251     // VPERMV/VPERMV3 shuffles can never fold their index operands.
7252     if (Opc == X86ISD::VPERMV && U->getOperand(0).getNode() == N)
7253       return false;
7254     if (Opc == X86ISD::VPERMV3 && U->getOperand(1).getNode() == N)
7255       return false;
7256     if (isTargetShuffle(Opc))
7257       return true;
7258     if (Opc == ISD::BITCAST) // Ignore bitcasts
7259       return isFoldableUseOfShuffle(U);
7260     if (N->hasOneUse()) {
7261       // TODO: there may be some general way to know if an SDNode can
7262       // be folded. Currently we only know whether an MI is foldable.
7263       if (Opc == X86ISD::VPDPBUSD && U->getOperand(2).getNode() != N)
7264         return false;
7265       return true;
7266     }
7267   }
7268   return false;
7269 }
7270 
7271 /// Attempt to use the vbroadcast instruction to generate a splat value
7272 /// from a splat BUILD_VECTOR which uses:
7273 ///  a. A single scalar load, or a constant.
7274 ///  b. Repeated pattern of constants (e.g. <0,1,0,1> or <0,1,2,3,0,1,2,3>).
7275 ///
7276 /// The VBROADCAST node is returned when a pattern is found,
7277 /// or SDValue() otherwise.
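/// For example, (build_vector (load x), (load x), (load x), (load x)) can be
/// lowered to a single broadcast load of x.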
7278 static SDValue lowerBuildVectorAsBroadcast(BuildVectorSDNode *BVOp,
7279                                            const SDLoc &dl,
7280                                            const X86Subtarget &Subtarget,
7281                                            SelectionDAG &DAG) {
7282   // VBROADCAST requires AVX.
7283   // TODO: Splats could be generated for non-AVX CPUs using SSE
7284   // instructions, but there's less potential gain for only 128-bit vectors.
7285   if (!Subtarget.hasAVX())
7286     return SDValue();
7287 
7288   MVT VT = BVOp->getSimpleValueType(0);
7289   unsigned NumElts = VT.getVectorNumElements();
7290   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
7291   assert((VT.is128BitVector() || VT.is256BitVector() || VT.is512BitVector()) &&
7292          "Unsupported vector type for broadcast.");
7293 
7294   // See if the build vector is a repeating sequence of scalars (inc. splat).
7295   SDValue Ld;
7296   BitVector UndefElements;
7297   SmallVector<SDValue, 16> Sequence;
7298   if (BVOp->getRepeatedSequence(Sequence, &UndefElements)) {
7299     assert((NumElts % Sequence.size()) == 0 && "Sequence doesn't fit.");
7300     if (Sequence.size() == 1)
7301       Ld = Sequence[0];
7302   }
7303 
7304   // Attempt to use VBROADCASTM
7305   // From this pattern:
7306   // a. t0 = (zext_i64 (bitcast_i8 v2i1 X))
7307   // b. t1 = (build_vector t0 t0)
7308   //
7309   // Create (VBROADCASTM v2i1 X)
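  // (VBROADCASTM broadcasts a zero-extended mask register to every element of
  // the result, matching the AVX512CD broadcastmb2q/broadcastmw2d forms.)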
7310   if (!Sequence.empty() && Subtarget.hasCDI()) {
7311     // If not a splat, are the upper sequence values zeroable?
7312     unsigned SeqLen = Sequence.size();
7313     bool UpperZeroOrUndef =
7314         SeqLen == 1 ||
7315         llvm::all_of(ArrayRef(Sequence).drop_front(),
7316                      [](SDValue V) { return !V || isNullConstantOrUndef(V); });
7317     SDValue Op0 = Sequence[0];
7318     if (UpperZeroOrUndef && ((Op0.getOpcode() == ISD::BITCAST) ||
7319                              (Op0.getOpcode() == ISD::ZERO_EXTEND &&
7320                               Op0.getOperand(0).getOpcode() == ISD::BITCAST))) {
7321       SDValue BOperand = Op0.getOpcode() == ISD::BITCAST
7322                              ? Op0.getOperand(0)
7323                              : Op0.getOperand(0).getOperand(0);
7324       MVT MaskVT = BOperand.getSimpleValueType();
7325       MVT EltType = MVT::getIntegerVT(VT.getScalarSizeInBits() * SeqLen);
7326       if ((EltType == MVT::i64 && MaskVT == MVT::v8i1) ||  // for broadcastmb2q
7327           (EltType == MVT::i32 && MaskVT == MVT::v16i1)) { // for broadcastmw2d
7328         MVT BcstVT = MVT::getVectorVT(EltType, NumElts / SeqLen);
7329         if (!VT.is512BitVector() && !Subtarget.hasVLX()) {
7330           unsigned Scale = 512 / VT.getSizeInBits();
7331           BcstVT = MVT::getVectorVT(EltType, Scale * (NumElts / SeqLen));
7332         }
7333         SDValue Bcst = DAG.getNode(X86ISD::VBROADCASTM, dl, BcstVT, BOperand);
7334         if (BcstVT.getSizeInBits() != VT.getSizeInBits())
7335           Bcst = extractSubVector(Bcst, 0, DAG, dl, VT.getSizeInBits());
7336         return DAG.getBitcast(VT, Bcst);
7337       }
7338     }
7339   }
7340 
7341   unsigned NumUndefElts = UndefElements.count();
7342   if (!Ld || (NumElts - NumUndefElts) <= 1) {
7343     APInt SplatValue, Undef;
7344     unsigned SplatBitSize;
7345     bool HasUndef;
7346     // Check if this is a repeated constant pattern suitable for broadcasting.
7347     if (BVOp->isConstantSplat(SplatValue, Undef, SplatBitSize, HasUndef) &&
7348         SplatBitSize > VT.getScalarSizeInBits() &&
7349         SplatBitSize < VT.getSizeInBits()) {
7350       // Avoid replacing with broadcast when it's a use of a shuffle
7351       // instruction to preserve the present custom lowering of shuffles.
7352       if (isFoldableUseOfShuffle(BVOp))
7353         return SDValue();
7354       // Replace BUILD_VECTOR with a broadcast of the repeated constants.
7355       LLVMContext *Ctx = DAG.getContext();
7356       MVT PVT = TLI.getPointerTy(DAG.getDataLayout());
7357       if (SplatBitSize == 32 || SplatBitSize == 64 ||
7358           (SplatBitSize < 32 && Subtarget.hasAVX2())) {
7359         // Load the constant scalar/subvector and broadcast it.
7360         MVT CVT = MVT::getIntegerVT(SplatBitSize);
7361         Constant *C = getConstantVector(VT, SplatValue, SplatBitSize, *Ctx);
7362         SDValue CP = DAG.getConstantPool(C, PVT);
7363         unsigned Repeat = VT.getSizeInBits() / SplatBitSize;
7364 
7365         Align Alignment = cast<ConstantPoolSDNode>(CP)->getAlign();
7366         SDVTList Tys = DAG.getVTList(MVT::getVectorVT(CVT, Repeat), MVT::Other);
7367         SDValue Ops[] = {DAG.getEntryNode(), CP};
7368         MachinePointerInfo MPI =
7369             MachinePointerInfo::getConstantPool(DAG.getMachineFunction());
7370         SDValue Brdcst =
7371             DAG.getMemIntrinsicNode(X86ISD::VBROADCAST_LOAD, dl, Tys, Ops, CVT,
7372                                     MPI, Alignment, MachineMemOperand::MOLoad);
7373         return DAG.getBitcast(VT, Brdcst);
7374       }
7375       if (SplatBitSize > 64) {
7376         // Load the vector of constants and broadcast it.
7377         Constant *VecC = getConstantVector(VT, SplatValue, SplatBitSize, *Ctx);
7378         SDValue VCP = DAG.getConstantPool(VecC, PVT);
7379         unsigned NumElm = SplatBitSize / VT.getScalarSizeInBits();
7380         MVT VVT = MVT::getVectorVT(VT.getScalarType(), NumElm);
7381         Align Alignment = cast<ConstantPoolSDNode>(VCP)->getAlign();
7382         SDVTList Tys = DAG.getVTList(VT, MVT::Other);
7383         SDValue Ops[] = {DAG.getEntryNode(), VCP};
7384         MachinePointerInfo MPI =
7385             MachinePointerInfo::getConstantPool(DAG.getMachineFunction());
7386         return DAG.getMemIntrinsicNode(X86ISD::SUBV_BROADCAST_LOAD, dl, Tys,
7387                                        Ops, VVT, MPI, Alignment,
7388                                        MachineMemOperand::MOLoad);
7389       }
7390     }
7391 
7392     // If we are moving a scalar into a vector (Ld must be set and all elements
7393     // but 1 are undef) and that operation is not obviously supported by
7394     // vmovd/vmovq/vmovss/vmovsd, then keep trying to form a broadcast.
7395     // That's better than general shuffling and may eliminate a load to GPR and
7396     // move from scalar to vector register.
7397     if (!Ld || NumElts - NumUndefElts != 1)
7398       return SDValue();
7399     unsigned ScalarSize = Ld.getValueSizeInBits();
7400     if (!(UndefElements[0] || (ScalarSize != 32 && ScalarSize != 64)))
7401       return SDValue();
7402   }
7403 
7404   bool ConstSplatVal =
7405       (Ld.getOpcode() == ISD::Constant || Ld.getOpcode() == ISD::ConstantFP);
7406   bool IsLoad = ISD::isNormalLoad(Ld.getNode());
7407 
7408   // TODO: Handle broadcasts of non-constant sequences.
7409 
7410   // Make sure that all of the users of a non-constant load are from the
7411   // BUILD_VECTOR node.
7412   // FIXME: Is the use count needed for non-constant, non-load case?
7413   if (!ConstSplatVal && !IsLoad && !BVOp->isOnlyUserOf(Ld.getNode()))
7414     return SDValue();
7415 
7416   unsigned ScalarSize = Ld.getValueSizeInBits();
7417   bool IsGE256 = (VT.getSizeInBits() >= 256);
7418 
7419   // When optimizing for size, generate up to 5 extra bytes for a broadcast
7420   // instruction to save 8 or more bytes of constant pool data.
7421   // TODO: If multiple splats are generated to load the same constant,
7422   // it may be detrimental to overall size. There needs to be a way to detect
7423   // that condition to know if this is truly a size win.
7424   bool OptForSize = DAG.shouldOptForSize();
7425 
7426   // Handle broadcasting a single constant scalar from the constant pool
7427   // into a vector.
7428   // On Sandybridge (no AVX2), it is still better to load a constant vector
7429   // from the constant pool and not to broadcast it from a scalar.
7430   // But override that restriction when optimizing for size.
7431   // TODO: Check if splatting is recommended for other AVX-capable CPUs.
7432   if (ConstSplatVal && (Subtarget.hasAVX2() || OptForSize)) {
7433     EVT CVT = Ld.getValueType();
7434     assert(!CVT.isVector() && "Must not broadcast a vector type");
7435 
7436     // Splat f16, f32, i32, v4f64, v4i64 in all cases with AVX2.
7437     // For size optimization, also splat v2f64 and v2i64, and for size opt
7438     // with AVX2, also splat i8 and i16.
7439     // With pattern matching, the VBROADCAST node may become a VMOVDDUP.
7440     if (ScalarSize == 32 ||
7441         (ScalarSize == 64 && (IsGE256 || Subtarget.hasVLX())) ||
7442         (CVT == MVT::f16 && Subtarget.hasAVX2()) ||
7443         (OptForSize && (ScalarSize == 64 || Subtarget.hasAVX2()))) {
7444       const Constant *C = nullptr;
7445       if (ConstantSDNode *CI = dyn_cast<ConstantSDNode>(Ld))
7446         C = CI->getConstantIntValue();
7447       else if (ConstantFPSDNode *CF = dyn_cast<ConstantFPSDNode>(Ld))
7448         C = CF->getConstantFPValue();
7449 
7450       assert(C && "Invalid constant type");
7451 
7452       SDValue CP =
7453           DAG.getConstantPool(C, TLI.getPointerTy(DAG.getDataLayout()));
7454       Align Alignment = cast<ConstantPoolSDNode>(CP)->getAlign();
7455 
7456       SDVTList Tys = DAG.getVTList(VT, MVT::Other);
7457       SDValue Ops[] = {DAG.getEntryNode(), CP};
7458       MachinePointerInfo MPI =
7459           MachinePointerInfo::getConstantPool(DAG.getMachineFunction());
7460       return DAG.getMemIntrinsicNode(X86ISD::VBROADCAST_LOAD, dl, Tys, Ops, CVT,
7461                                      MPI, Alignment, MachineMemOperand::MOLoad);
7462     }
7463   }
7464 
7465   // Handle AVX2 in-register broadcasts.
7466   if (!IsLoad && Subtarget.hasInt256() &&
7467       (ScalarSize == 32 || (IsGE256 && ScalarSize == 64)))
7468     return DAG.getNode(X86ISD::VBROADCAST, dl, VT, Ld);
7469 
7470   // The scalar source must be a normal load.
7471   if (!IsLoad)
7472     return SDValue();
7473 
7474   // Make sure the non-chain result is only used by this build vector.
7475   if (!Ld->hasNUsesOfValue(NumElts - NumUndefElts, 0))
7476     return SDValue();
7477 
7478   if (ScalarSize == 32 || (IsGE256 && ScalarSize == 64) ||
7479       (Subtarget.hasVLX() && ScalarSize == 64)) {
7480     auto *LN = cast<LoadSDNode>(Ld);
7481     SDVTList Tys = DAG.getVTList(VT, MVT::Other);
7482     SDValue Ops[] = {LN->getChain(), LN->getBasePtr()};
7483     SDValue BCast =
7484         DAG.getMemIntrinsicNode(X86ISD::VBROADCAST_LOAD, dl, Tys, Ops,
7485                                 LN->getMemoryVT(), LN->getMemOperand());
7486     DAG.ReplaceAllUsesOfValueWith(SDValue(LN, 1), BCast.getValue(1));
7487     return BCast;
7488   }
7489 
7490   // The integer check is needed for the 64-bit into 128-bit case so that it
7491   // doesn't match double, since there is no vbroadcastsd xmm instruction.
7492   if (Subtarget.hasInt256() && Ld.getValueType().isInteger() &&
7493       (ScalarSize == 8 || ScalarSize == 16 || ScalarSize == 64)) {
7494     auto *LN = cast<LoadSDNode>(Ld);
7495     SDVTList Tys = DAG.getVTList(VT, MVT::Other);
7496     SDValue Ops[] = {LN->getChain(), LN->getBasePtr()};
7497     SDValue BCast =
7498         DAG.getMemIntrinsicNode(X86ISD::VBROADCAST_LOAD, dl, Tys, Ops,
7499                                 LN->getMemoryVT(), LN->getMemOperand());
7500     DAG.ReplaceAllUsesOfValueWith(SDValue(LN, 1), BCast.getValue(1));
7501     return BCast;
7502   }
7503 
7504   if (ScalarSize == 16 && Subtarget.hasFP16() && IsGE256)
7505     return DAG.getNode(X86ISD::VBROADCAST, dl, VT, Ld);
7506 
7507   // Unsupported broadcast.
7508   return SDValue();
7509 }
7510 
7511 /// For an EXTRACT_VECTOR_ELT with a constant index return the real
7512 /// underlying vector and index.
7513 ///
7514 /// Modifies \p ExtractedFromVec to the real vector and returns the real
7515 /// index.
7516 static int getUnderlyingExtractedFromVec(SDValue &ExtractedFromVec,
7517                                          SDValue ExtIdx) {
7518   int Idx = ExtIdx->getAsZExtVal();
7519   if (!isa<ShuffleVectorSDNode>(ExtractedFromVec))
7520     return Idx;
7521 
7522   // For 256-bit vectors, LowerEXTRACT_VECTOR_ELT_SSE4 may have already
7523   // lowered this:
7524   //   (extract_vector_elt (v8f32 %1), Constant<6>)
7525   // to:
7526   //   (extract_vector_elt (vector_shuffle<2,u,u,u>
7527   //                           (extract_subvector (v8f32 %0), Constant<4>),
7528   //                           undef)
7529   //                       Constant<0>)
7530   // In this case the vector is the extract_subvector expression and the index
7531   // is 2, as specified by the shuffle.
7532   ShuffleVectorSDNode *SVOp = cast<ShuffleVectorSDNode>(ExtractedFromVec);
7533   SDValue ShuffleVec = SVOp->getOperand(0);
7534   MVT ShuffleVecVT = ShuffleVec.getSimpleValueType();
7535   assert(ShuffleVecVT.getVectorElementType() ==
7536          ExtractedFromVec.getSimpleValueType().getVectorElementType());
7537 
7538   int ShuffleIdx = SVOp->getMaskElt(Idx);
7539   if (isUndefOrInRange(ShuffleIdx, 0, ShuffleVecVT.getVectorNumElements())) {
7540     ExtractedFromVec = ShuffleVec;
7541     return ShuffleIdx;
7542   }
7543   return Idx;
7544 }
7545 
7546 static SDValue buildFromShuffleMostly(SDValue Op, const SDLoc &DL,
7547                                       SelectionDAG &DAG) {
7548   MVT VT = Op.getSimpleValueType();
7549 
7550   // Skip if insert_vec_elt is not supported.
7551   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
7552   if (!TLI.isOperationLegalOrCustom(ISD::INSERT_VECTOR_ELT, VT))
7553     return SDValue();
7554 
7555   unsigned NumElems = Op.getNumOperands();
7556   SDValue VecIn1;
7557   SDValue VecIn2;
7558   SmallVector<unsigned, 4> InsertIndices;
7559   SmallVector<int, 8> Mask(NumElems, -1);
7560 
7561   for (unsigned i = 0; i != NumElems; ++i) {
7562     unsigned Opc = Op.getOperand(i).getOpcode();
7563 
7564     if (Opc == ISD::UNDEF)
7565       continue;
7566 
7567     if (Opc != ISD::EXTRACT_VECTOR_ELT) {
7568       // Quit if more than 1 element needs inserting.
7569       if (InsertIndices.size() > 1)
7570         return SDValue();
7571 
7572       InsertIndices.push_back(i);
7573       continue;
7574     }
7575 
7576     SDValue ExtractedFromVec = Op.getOperand(i).getOperand(0);
7577     SDValue ExtIdx = Op.getOperand(i).getOperand(1);
7578 
7579     // Quit if non-constant index.
7580     if (!isa<ConstantSDNode>(ExtIdx))
7581       return SDValue();
7582     int Idx = getUnderlyingExtractedFromVec(ExtractedFromVec, ExtIdx);
7583 
7584     // Quit if extracted from vector of different type.
7585     if (ExtractedFromVec.getValueType() != VT)
7586       return SDValue();
7587 
7588     if (!VecIn1.getNode())
7589       VecIn1 = ExtractedFromVec;
7590     else if (VecIn1 != ExtractedFromVec) {
7591       if (!VecIn2.getNode())
7592         VecIn2 = ExtractedFromVec;
7593       else if (VecIn2 != ExtractedFromVec)
7594         // Quit if there are more than 2 vectors to shuffle.
7595         return SDValue();
7596     }
7597 
7598     if (ExtractedFromVec == VecIn1)
7599       Mask[i] = Idx;
7600     else if (ExtractedFromVec == VecIn2)
7601       Mask[i] = Idx + NumElems;
7602   }
7603 
7604   if (!VecIn1.getNode())
7605     return SDValue();
7606 
7607   VecIn2 = VecIn2.getNode() ? VecIn2 : DAG.getUNDEF(VT);
7608   SDValue NV = DAG.getVectorShuffle(VT, DL, VecIn1, VecIn2, Mask);
7609 
7610   for (unsigned Idx : InsertIndices)
7611     NV = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VT, NV, Op.getOperand(Idx),
7612                      DAG.getIntPtrConstant(Idx, DL));
7613 
7614   return NV;
7615 }
7616 
7617 // Lower BUILD_VECTOR operation for v8bf16, v16bf16 and v32bf16 types.
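// The bf16 elements are bitcast to i16 (or to f16 when FP16 is available) so
// that the existing integer/half BUILD_VECTOR lowering can be reused, and the
// result is bitcast back to the original bf16 vector type.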
7618 static SDValue LowerBUILD_VECTORvXbf16(SDValue Op, SelectionDAG &DAG,
7619                                        const X86Subtarget &Subtarget) {
7620   MVT VT = Op.getSimpleValueType();
7621   MVT IVT =
7622       VT.changeVectorElementType(Subtarget.hasFP16() ? MVT::f16 : MVT::i16);
7623   SmallVector<SDValue, 16> NewOps;
7624   for (unsigned I = 0, E = Op.getNumOperands(); I != E; ++I)
7625     NewOps.push_back(DAG.getBitcast(Subtarget.hasFP16() ? MVT::f16 : MVT::i16,
7626                                     Op.getOperand(I)));
7627   SDValue Res = DAG.getNode(ISD::BUILD_VECTOR, SDLoc(), IVT, NewOps);
7628   return DAG.getBitcast(VT, Res);
7629 }
7630 
7631 // Lower BUILD_VECTOR operation for v8i1 and v16i1 types.
7632 static SDValue LowerBUILD_VECTORvXi1(SDValue Op, const SDLoc &dl,
7633                                      SelectionDAG &DAG,
7634                                      const X86Subtarget &Subtarget) {
7635 
7636   MVT VT = Op.getSimpleValueType();
7637   assert((VT.getVectorElementType() == MVT::i1) &&
7638          "Unexpected type in LowerBUILD_VECTORvXi1!");
7639   if (ISD::isBuildVectorAllZeros(Op.getNode()) ||
7640       ISD::isBuildVectorAllOnes(Op.getNode()))
7641     return Op;
7642 
7643   uint64_t Immediate = 0;
7644   SmallVector<unsigned, 16> NonConstIdx;
7645   bool IsSplat = true;
7646   bool HasConstElts = false;
7647   int SplatIdx = -1;
7648   for (unsigned idx = 0, e = Op.getNumOperands(); idx < e; ++idx) {
7649     SDValue In = Op.getOperand(idx);
7650     if (In.isUndef())
7651       continue;
7652     if (auto *InC = dyn_cast<ConstantSDNode>(In)) {
7653       Immediate |= (InC->getZExtValue() & 0x1) << idx;
7654       HasConstElts = true;
7655     } else {
7656       NonConstIdx.push_back(idx);
7657     }
7658     if (SplatIdx < 0)
7659       SplatIdx = idx;
7660     else if (In != Op.getOperand(SplatIdx))
7661       IsSplat = false;
7662   }
7663 
7664   // For a splat, use (select i1 splat_elt, all-ones, all-zeroes).
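  // E.g. a v16i1 splat of a non-constant i1 becomes a scalar select between
  // i16 all-ones and zero (so cmov can be used), then a bitcast to v16i1.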
7665   if (IsSplat) {
7666     // The build_vector allows the scalar element to be larger than the vector
7667     // element type. We need to mask it to use as a condition unless we know
7668     // the upper bits are zero.
7669     // FIXME: Use computeKnownBits instead of checking specific opcode?
7670     SDValue Cond = Op.getOperand(SplatIdx);
7671     assert(Cond.getValueType() == MVT::i8 && "Unexpected VT!");
7672     if (Cond.getOpcode() != ISD::SETCC)
7673       Cond = DAG.getNode(ISD::AND, dl, MVT::i8, Cond,
7674                          DAG.getConstant(1, dl, MVT::i8));
7675 
7676     // Perform the select in the scalar domain so we can use cmov.
7677     if (VT == MVT::v64i1 && !Subtarget.is64Bit()) {
7678       SDValue Select = DAG.getSelect(dl, MVT::i32, Cond,
7679                                      DAG.getAllOnesConstant(dl, MVT::i32),
7680                                      DAG.getConstant(0, dl, MVT::i32));
7681       Select = DAG.getBitcast(MVT::v32i1, Select);
7682       return DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v64i1, Select, Select);
7683     } else {
7684       MVT ImmVT = MVT::getIntegerVT(std::max((unsigned)VT.getSizeInBits(), 8U));
7685       SDValue Select = DAG.getSelect(dl, ImmVT, Cond,
7686                                      DAG.getAllOnesConstant(dl, ImmVT),
7687                                      DAG.getConstant(0, dl, ImmVT));
7688       MVT VecVT = VT.getSizeInBits() >= 8 ? VT : MVT::v8i1;
7689       Select = DAG.getBitcast(VecVT, Select);
7690       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT, Select,
7691                          DAG.getIntPtrConstant(0, dl));
7692     }
7693   }
7694 
7695   // Insert elements one by one.
7696   SDValue DstVec;
7697   if (HasConstElts) {
7698     if (VT == MVT::v64i1 && !Subtarget.is64Bit()) {
7699       SDValue ImmL = DAG.getConstant(Lo_32(Immediate), dl, MVT::i32);
7700       SDValue ImmH = DAG.getConstant(Hi_32(Immediate), dl, MVT::i32);
7701       ImmL = DAG.getBitcast(MVT::v32i1, ImmL);
7702       ImmH = DAG.getBitcast(MVT::v32i1, ImmH);
7703       DstVec = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v64i1, ImmL, ImmH);
7704     } else {
7705       MVT ImmVT = MVT::getIntegerVT(std::max((unsigned)VT.getSizeInBits(), 8U));
7706       SDValue Imm = DAG.getConstant(Immediate, dl, ImmVT);
7707       MVT VecVT = VT.getSizeInBits() >= 8 ? VT : MVT::v8i1;
7708       DstVec = DAG.getBitcast(VecVT, Imm);
7709       DstVec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT, DstVec,
7710                            DAG.getIntPtrConstant(0, dl));
7711     }
7712   } else
7713     DstVec = DAG.getUNDEF(VT);
7714 
7715   for (unsigned InsertIdx : NonConstIdx) {
7716     DstVec = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, VT, DstVec,
7717                          Op.getOperand(InsertIdx),
7718                          DAG.getIntPtrConstant(InsertIdx, dl));
7719   }
7720   return DstVec;
7721 }
7722 
7723 LLVM_ATTRIBUTE_UNUSED static bool isHorizOp(unsigned Opcode) {
7724   switch (Opcode) {
7725   case X86ISD::PACKSS:
7726   case X86ISD::PACKUS:
7727   case X86ISD::FHADD:
7728   case X86ISD::FHSUB:
7729   case X86ISD::HADD:
7730   case X86ISD::HSUB:
7731     return true;
7732   }
7733   return false;
7734 }
7735 
7736 /// This is a helper function of LowerToHorizontalOp().
7737 /// This function checks that the build_vector \p N in input implements a
7738 /// 128-bit partial horizontal operation on a 256-bit vector, but that operation
7739 /// may not match the layout of an x86 256-bit horizontal instruction.
7740 /// In other words, if this returns true, then some extraction/insertion will
7741 /// be required to produce a valid horizontal instruction.
7742 ///
7743 /// Parameter \p Opcode defines the kind of horizontal operation to match.
7744 /// For example, if \p Opcode is equal to ISD::ADD, then this function
7745 /// checks if \p N implements a horizontal arithmetic add; if instead \p Opcode
7746 /// is equal to ISD::SUB, then this function checks if this is a horizontal
7747 /// arithmetic sub.
7748 ///
7749 /// This function only analyzes elements of \p N whose indices are
7750 /// in range [BaseIdx, LastIdx).
7751 ///
7752 /// TODO: This function was originally used to match both real and fake partial
7753 /// horizontal operations, but the index-matching logic is incorrect for that.
7754 /// See the corrected implementation in isHopBuildVector(). Can we reduce this
7755 /// code because it is only used for partial h-op matching now?
7756 static bool isHorizontalBinOpPart(const BuildVectorSDNode *N, unsigned Opcode,
7757                                   const SDLoc &DL, SelectionDAG &DAG,
7758                                   unsigned BaseIdx, unsigned LastIdx,
7759                                   SDValue &V0, SDValue &V1) {
7760   EVT VT = N->getValueType(0);
7761   assert(VT.is256BitVector() && "Only use for matching partial 256-bit h-ops");
7762   assert(BaseIdx * 2 <= LastIdx && "Invalid Indices in input!");
7763   assert(VT.isVector() && VT.getVectorNumElements() >= LastIdx &&
7764          "Invalid Vector in input!");
7765 
7766   bool IsCommutable = (Opcode == ISD::ADD || Opcode == ISD::FADD);
7767   bool CanFold = true;
7768   unsigned ExpectedVExtractIdx = BaseIdx;
7769   unsigned NumElts = LastIdx - BaseIdx;
7770   V0 = DAG.getUNDEF(VT);
7771   V1 = DAG.getUNDEF(VT);
7772 
7773   // Check if N implements a horizontal binop.
7774   for (unsigned i = 0, e = NumElts; i != e && CanFold; ++i) {
7775     SDValue Op = N->getOperand(i + BaseIdx);
7776 
7777     // Skip UNDEFs.
7778     if (Op->isUndef()) {
7779       // Update the expected vector extract index.
7780       if (i * 2 == NumElts)
7781         ExpectedVExtractIdx = BaseIdx;
7782       ExpectedVExtractIdx += 2;
7783       continue;
7784     }
7785 
7786     CanFold = Op->getOpcode() == Opcode && Op->hasOneUse();
7787 
7788     if (!CanFold)
7789       break;
7790 
7791     SDValue Op0 = Op.getOperand(0);
7792     SDValue Op1 = Op.getOperand(1);
7793 
7794     // Try to match the following pattern:
7795     // (BINOP (extract_vector_elt A, I), (extract_vector_elt A, I+1))
7796     CanFold = (Op0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
7797         Op1.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
7798         Op0.getOperand(0) == Op1.getOperand(0) &&
7799         isa<ConstantSDNode>(Op0.getOperand(1)) &&
7800         isa<ConstantSDNode>(Op1.getOperand(1)));
7801     if (!CanFold)
7802       break;
7803 
7804     unsigned I0 = Op0.getConstantOperandVal(1);
7805     unsigned I1 = Op1.getConstantOperandVal(1);
7806 
7807     if (i * 2 < NumElts) {
7808       if (V0.isUndef()) {
7809         V0 = Op0.getOperand(0);
7810         if (V0.getValueType() != VT)
7811           return false;
7812       }
7813     } else {
7814       if (V1.isUndef()) {
7815         V1 = Op0.getOperand(0);
7816         if (V1.getValueType() != VT)
7817           return false;
7818       }
7819       if (i * 2 == NumElts)
7820         ExpectedVExtractIdx = BaseIdx;
7821     }
7822 
7823     SDValue Expected = (i * 2 < NumElts) ? V0 : V1;
7824     if (I0 == ExpectedVExtractIdx)
7825       CanFold = I1 == I0 + 1 && Op0.getOperand(0) == Expected;
7826     else if (IsCommutable && I1 == ExpectedVExtractIdx) {
7827       // Try to match the following dag sequence:
7828       // (BINOP (extract_vector_elt A, I+1), (extract_vector_elt A, I))
7829       CanFold = I0 == I1 + 1 && Op1.getOperand(0) == Expected;
7830     } else
7831       CanFold = false;
7832 
7833     ExpectedVExtractIdx += 2;
7834   }
7835 
7836   return CanFold;
7837 }
7838 
7839 /// Emit a sequence of two 128-bit horizontal add/sub followed by
7840 /// a concat_vector.
7841 ///
7842 /// This is a helper function of LowerToHorizontalOp().
7843 /// This function expects two 256-bit vectors called V0 and V1.
7844 /// At first, each vector is split into two separate 128-bit vectors.
7845 /// Then, the resulting 128-bit vectors are used to implement two
7846 /// horizontal binary operations.
7847 ///
7848 /// The kind of horizontal binary operation is defined by \p X86Opcode.
7849 ///
7850 /// \p Mode specifies how the 128-bit parts of V0 and V1 are passed as input to
7851 /// the two new horizontal binops.
7852 /// When Mode is set, the first horizontal binop dag node would take as input
7853 /// the lower 128-bit of V0 and the upper 128-bit of V0. The second
7854 /// horizontal binop dag node would take as input the lower 128-bit of V1
7855 /// and the upper 128-bit of V1.
7856 ///   Example:
7857 ///     HADD V0_LO, V0_HI
7858 ///     HADD V1_LO, V1_HI
7859 ///
7860 /// Otherwise, the first horizontal binop dag node takes as input the lower
7861 /// 128-bit of V0 and the lower 128-bit of V1, and the second horizontal binop
7862 /// dag node takes the upper 128-bit of V0 and the upper 128-bit of V1.
7863 ///   Example:
7864 ///     HADD V0_LO, V1_LO
7865 ///     HADD V0_HI, V1_HI
7866 ///
7867 /// If \p isUndefLO is set, then the algorithm propagates UNDEF to the lower
7868 /// 128-bits of the result. If \p isUndefHI is set, then UNDEF is propagated to
7869 /// the upper 128-bits of the result.
7870 static SDValue ExpandHorizontalBinOp(const SDValue &V0, const SDValue &V1,
7871                                      const SDLoc &DL, SelectionDAG &DAG,
7872                                      unsigned X86Opcode, bool Mode,
7873                                      bool isUndefLO, bool isUndefHI) {
7874   MVT VT = V0.getSimpleValueType();
7875   assert(VT.is256BitVector() && VT == V1.getSimpleValueType() &&
7876          "Invalid nodes in input!");
7877 
7878   unsigned NumElts = VT.getVectorNumElements();
7879   SDValue V0_LO = extract128BitVector(V0, 0, DAG, DL);
7880   SDValue V0_HI = extract128BitVector(V0, NumElts/2, DAG, DL);
7881   SDValue V1_LO = extract128BitVector(V1, 0, DAG, DL);
7882   SDValue V1_HI = extract128BitVector(V1, NumElts/2, DAG, DL);
7883   MVT NewVT = V0_LO.getSimpleValueType();
7884 
7885   SDValue LO = DAG.getUNDEF(NewVT);
7886   SDValue HI = DAG.getUNDEF(NewVT);
7887 
7888   if (Mode) {
7889     // Don't emit a horizontal binop if the result is expected to be UNDEF.
7890     if (!isUndefLO && !V0->isUndef())
7891       LO = DAG.getNode(X86Opcode, DL, NewVT, V0_LO, V0_HI);
7892     if (!isUndefHI && !V1->isUndef())
7893       HI = DAG.getNode(X86Opcode, DL, NewVT, V1_LO, V1_HI);
7894   } else {
7895     // Don't emit a horizontal binop if the result is expected to be UNDEF.
7896     if (!isUndefLO && (!V0_LO->isUndef() || !V1_LO->isUndef()))
7897       LO = DAG.getNode(X86Opcode, DL, NewVT, V0_LO, V1_LO);
7898 
7899     if (!isUndefHI && (!V0_HI->isUndef() || !V1_HI->isUndef()))
7900       HI = DAG.getNode(X86Opcode, DL, NewVT, V0_HI, V1_HI);
7901   }
7902 
7903   return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, LO, HI);
7904 }
7905 
7906 /// Returns true iff \p BV builds a vector with the result equivalent to
7907 /// the result of an ADDSUB/SUBADD operation.
7908 /// If true is returned then the operands of ADDSUB = Opnd0 +- Opnd1
7909 /// (SUBADD = Opnd0 -+ Opnd1) operation are written to the parameters
7910 /// \p Opnd0 and \p Opnd1.
7911 static bool isAddSubOrSubAdd(const BuildVectorSDNode *BV,
7912                              const X86Subtarget &Subtarget, SelectionDAG &DAG,
7913                              SDValue &Opnd0, SDValue &Opnd1,
7914                              unsigned &NumExtracts,
7915                              bool &IsSubAdd) {
7916 
7917   MVT VT = BV->getSimpleValueType(0);
7918   if (!Subtarget.hasSSE3() || !VT.isFloatingPoint())
7919     return false;
7920 
7921   unsigned NumElts = VT.getVectorNumElements();
7922   SDValue InVec0 = DAG.getUNDEF(VT);
7923   SDValue InVec1 = DAG.getUNDEF(VT);
7924 
7925   NumExtracts = 0;
7926 
7927   // Odd-numbered elements in the input build vector are obtained from
7928   // adding/subtracting two integer/float elements.
7929   // Even-numbered elements in the input build vector are obtained from
7930   // subtracting/adding two integer/float elements.
7931   unsigned Opc[2] = {0, 0};
7932   for (unsigned i = 0, e = NumElts; i != e; ++i) {
7933     SDValue Op = BV->getOperand(i);
7934 
7935     // Skip 'undef' values.
7936     unsigned Opcode = Op.getOpcode();
7937     if (Opcode == ISD::UNDEF)
7938       continue;
7939 
7940     // Early exit if we found an unexpected opcode.
7941     if (Opcode != ISD::FADD && Opcode != ISD::FSUB)
7942       return false;
7943 
7944     SDValue Op0 = Op.getOperand(0);
7945     SDValue Op1 = Op.getOperand(1);
7946 
7947     // Try to match the following pattern:
7948     // (BINOP (extract_vector_elt A, i), (extract_vector_elt B, i))
7949     // Early exit if we cannot match that sequence.
7950     if (Op0.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
7951         Op1.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
7952         !isa<ConstantSDNode>(Op0.getOperand(1)) ||
7953         Op0.getOperand(1) != Op1.getOperand(1))
7954       return false;
7955 
7956     unsigned I0 = Op0.getConstantOperandVal(1);
7957     if (I0 != i)
7958       return false;
7959 
7960     // We found a valid add/sub node; make sure it's the same opcode as previous
7961     // elements for this parity.
7962     if (Opc[i % 2] != 0 && Opc[i % 2] != Opcode)
7963       return false;
7964     Opc[i % 2] = Opcode;
7965 
7966     // Update InVec0 and InVec1.
7967     if (InVec0.isUndef()) {
7968       InVec0 = Op0.getOperand(0);
7969       if (InVec0.getSimpleValueType() != VT)
7970         return false;
7971     }
7972     if (InVec1.isUndef()) {
7973       InVec1 = Op1.getOperand(0);
7974       if (InVec1.getSimpleValueType() != VT)
7975         return false;
7976     }
7977 
7978     // Make sure that the operands of each add/sub node always
7979     // come from the same pair of vectors.
7980     if (InVec0 != Op0.getOperand(0)) {
7981       if (Opcode == ISD::FSUB)
7982         return false;
7983 
7984       // FADD is commutable. Try to commute the operands
7985       // and then test again.
7986       std::swap(Op0, Op1);
7987       if (InVec0 != Op0.getOperand(0))
7988         return false;
7989     }
7990 
7991     if (InVec1 != Op1.getOperand(0))
7992       return false;
7993 
7994     // Increment the number of extractions done.
7995     ++NumExtracts;
7996   }
7997 
7998   // Ensure we have found an opcode for both parities and that they are
7999   // different. Don't try to fold this build_vector into an ADDSUB/SUBADD if the
8000   // inputs are undef.
8001   if (!Opc[0] || !Opc[1] || Opc[0] == Opc[1] ||
8002       InVec0.isUndef() || InVec1.isUndef())
8003     return false;
8004 
8005   IsSubAdd = Opc[0] == ISD::FADD;
8006 
8007   Opnd0 = InVec0;
8008   Opnd1 = InVec1;
8009   return true;
8010 }
8011 
8012 /// Returns true if it is possible to fold MUL and an idiom that has already been
8013 /// recognized as ADDSUB/SUBADD(\p Opnd0, \p Opnd1) into
8014 /// FMADDSUB/FMSUBADD(x, y, \p Opnd1). If (and only if) true is returned, the
8015 /// operands of FMADDSUB/FMSUBADD are written to parameters \p Opnd0, \p Opnd1, \p Opnd2.
8016 ///
8017 /// Prior to calling this function it should be known that there is some
8018 /// SDNode that potentially can be replaced with an X86ISD::ADDSUB operation
8019 /// using \p Opnd0 and \p Opnd1 as operands. Also, this method is called
8020 /// before replacement of such SDNode with ADDSUB operation. Thus the number
8021 /// of \p Opnd0 uses is expected to be equal to 2.
8022 /// For example, this function may be called for the following IR:
8023 ///    %AB = fmul fast <2 x double> %A, %B
8024 ///    %Sub = fsub fast <2 x double> %AB, %C
8025 ///    %Add = fadd fast <2 x double> %AB, %C
8026 ///    %Addsub = shufflevector <2 x double> %Sub, <2 x double> %Add,
8027 ///                            <2 x i32> <i32 0, i32 3>
8028 /// There is a def for %Addsub here, which potentially can be replaced by
8029 /// X86ISD::ADDSUB operation:
8030 ///    %Addsub = X86ISD::ADDSUB %AB, %C
8031 /// and such ADDSUB can further be replaced with FMADDSUB:
8032 ///    %Addsub = FMADDSUB %A, %B, %C.
8033 ///
8034 /// The main reason why this method is called before the replacement of the
8035 /// recognized ADDSUB idiom with ADDSUB operation is that such replacement
8036 /// is illegal sometimes. E.g. 512-bit ADDSUB is not available, while 512-bit
8037 /// FMADDSUB is.
8038 static bool isFMAddSubOrFMSubAdd(const X86Subtarget &Subtarget,
8039                                  SelectionDAG &DAG,
8040                                  SDValue &Opnd0, SDValue &Opnd1, SDValue &Opnd2,
8041                                  unsigned ExpectedUses) {
8042   if (Opnd0.getOpcode() != ISD::FMUL ||
8043       !Opnd0->hasNUsesOfValue(ExpectedUses, 0) || !Subtarget.hasAnyFMA())
8044     return false;
8045 
8046   // FIXME: These checks must match the similar ones in
8047   // DAGCombiner::visitFADDForFMACombine. It would be good to have one
8048   // function that would answer if it is Ok to fuse MUL + ADD to FMADD
8049   // or MUL + ADDSUB to FMADDSUB.
8050   const TargetOptions &Options = DAG.getTarget().Options;
8051   bool AllowFusion =
8052       (Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath);
8053   if (!AllowFusion)
8054     return false;
8055 
8056   Opnd2 = Opnd1;
8057   Opnd1 = Opnd0.getOperand(1);
8058   Opnd0 = Opnd0.getOperand(0);
8059 
8060   return true;
8061 }
8062 
8063 /// Try to fold a build_vector that performs an 'addsub' or 'fmaddsub' or
8064 /// 'fsubadd' operation into an X86ISD::ADDSUB, X86ISD::FMADDSUB or
8065 /// X86ISD::FMSUBADD node, respectively.
8066 static SDValue lowerToAddSubOrFMAddSub(const BuildVectorSDNode *BV,
8067                                        const SDLoc &DL,
8068                                        const X86Subtarget &Subtarget,
8069                                        SelectionDAG &DAG) {
8070   SDValue Opnd0, Opnd1;
8071   unsigned NumExtracts;
8072   bool IsSubAdd;
8073   if (!isAddSubOrSubAdd(BV, Subtarget, DAG, Opnd0, Opnd1, NumExtracts,
8074                         IsSubAdd))
8075     return SDValue();
8076 
8077   MVT VT = BV->getSimpleValueType(0);
8078 
8079   // Try to generate X86ISD::FMADDSUB node here.
8080   SDValue Opnd2;
8081   if (isFMAddSubOrFMSubAdd(Subtarget, DAG, Opnd0, Opnd1, Opnd2, NumExtracts)) {
8082     unsigned Opc = IsSubAdd ? X86ISD::FMSUBADD : X86ISD::FMADDSUB;
8083     return DAG.getNode(Opc, DL, VT, Opnd0, Opnd1, Opnd2);
8084   }
8085 
8086   // We only support ADDSUB.
8087   if (IsSubAdd)
8088     return SDValue();
8089 
8090   // There are no known X86 targets with 512-bit ADDSUB instructions!
8091   // Convert to blend(fsub,fadd).
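  // E.g. for v8f64 the shuffle mask is <0,9,2,11,4,13,6,15>, taking the even
  // lanes from the FSUB result and the odd lanes from the FADD result.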
8092   if (VT.is512BitVector()) {
8093     SmallVector<int> Mask;
8094     for (int I = 0, E = VT.getVectorNumElements(); I != E; I += 2) {
8095       Mask.push_back(I);
8096       Mask.push_back(I + E + 1);
8097     }
8098     SDValue Sub = DAG.getNode(ISD::FSUB, DL, VT, Opnd0, Opnd1);
8099     SDValue Add = DAG.getNode(ISD::FADD, DL, VT, Opnd0, Opnd1);
8100     return DAG.getVectorShuffle(VT, DL, Sub, Add, Mask);
8101   }
8102 
8103   return DAG.getNode(X86ISD::ADDSUB, DL, VT, Opnd0, Opnd1);
8104 }
8105 
8106 static bool isHopBuildVector(const BuildVectorSDNode *BV, SelectionDAG &DAG,
8107                              unsigned &HOpcode, SDValue &V0, SDValue &V1) {
8108   // Initialize outputs to known values.
8109   MVT VT = BV->getSimpleValueType(0);
8110   HOpcode = ISD::DELETED_NODE;
8111   V0 = DAG.getUNDEF(VT);
8112   V1 = DAG.getUNDEF(VT);
8113 
8114   // x86 256-bit horizontal ops are defined in a non-obvious way. Each 128-bit
8115   // half of the result is calculated independently from the 128-bit halves of
8116   // the inputs, so that makes the index-checking logic below more complicated.
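  // E.g. a 256-bit HADD of v8i32 A, B produces:
  //   <A0+A1, A2+A3, B0+B1, B2+B3, A4+A5, A6+A7, B4+B5, B6+B7>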
8117   unsigned NumElts = VT.getVectorNumElements();
8118   unsigned GenericOpcode = ISD::DELETED_NODE;
8119   unsigned Num128BitChunks = VT.is256BitVector() ? 2 : 1;
8120   unsigned NumEltsIn128Bits = NumElts / Num128BitChunks;
8121   unsigned NumEltsIn64Bits = NumEltsIn128Bits / 2;
8122   for (unsigned i = 0; i != Num128BitChunks; ++i) {
8123     for (unsigned j = 0; j != NumEltsIn128Bits; ++j) {
8124       // Ignore undef elements.
8125       SDValue Op = BV->getOperand(i * NumEltsIn128Bits + j);
8126       if (Op.isUndef())
8127         continue;
8128 
8129       // If there's an opcode mismatch, we're done.
8130       if (HOpcode != ISD::DELETED_NODE && Op.getOpcode() != GenericOpcode)
8131         return false;
8132 
8133       // Initialize horizontal opcode.
8134       if (HOpcode == ISD::DELETED_NODE) {
8135         GenericOpcode = Op.getOpcode();
8136         switch (GenericOpcode) {
8137         // clang-format off
8138         case ISD::ADD: HOpcode = X86ISD::HADD; break;
8139         case ISD::SUB: HOpcode = X86ISD::HSUB; break;
8140         case ISD::FADD: HOpcode = X86ISD::FHADD; break;
8141         case ISD::FSUB: HOpcode = X86ISD::FHSUB; break;
8142         default: return false;
8143         // clang-format on
8144         }
8145       }
8146 
8147       SDValue Op0 = Op.getOperand(0);
8148       SDValue Op1 = Op.getOperand(1);
8149       if (Op0.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
8150           Op1.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
8151           Op0.getOperand(0) != Op1.getOperand(0) ||
8152           !isa<ConstantSDNode>(Op0.getOperand(1)) ||
8153           !isa<ConstantSDNode>(Op1.getOperand(1)) || !Op.hasOneUse())
8154         return false;
8155 
8156       // The source vector is chosen based on which 64-bit half of the
8157       // destination vector is being calculated.
8158       if (j < NumEltsIn64Bits) {
8159         if (V0.isUndef())
8160           V0 = Op0.getOperand(0);
8161       } else {
8162         if (V1.isUndef())
8163           V1 = Op0.getOperand(0);
8164       }
8165 
8166       SDValue SourceVec = (j < NumEltsIn64Bits) ? V0 : V1;
8167       if (SourceVec != Op0.getOperand(0))
8168         return false;
8169 
8170       // op (extract_vector_elt A, I), (extract_vector_elt A, I+1)
8171       unsigned ExtIndex0 = Op0.getConstantOperandVal(1);
8172       unsigned ExtIndex1 = Op1.getConstantOperandVal(1);
8173       unsigned ExpectedIndex = i * NumEltsIn128Bits +
8174                                (j % NumEltsIn64Bits) * 2;
8175       if (ExpectedIndex == ExtIndex0 && ExtIndex1 == ExtIndex0 + 1)
8176         continue;
8177 
8178       // If this is not a commutative op, this does not match.
8179       if (GenericOpcode != ISD::ADD && GenericOpcode != ISD::FADD)
8180         return false;
8181 
8182       // Addition is commutative, so try swapping the extract indexes.
8183       // op (extract_vector_elt A, I+1), (extract_vector_elt A, I)
8184       if (ExpectedIndex == ExtIndex1 && ExtIndex0 == ExtIndex1 + 1)
8185         continue;
8186 
8187       // Extract indexes do not match horizontal requirement.
8188       return false;
8189     }
8190   }
8191   // We matched. Opcode and operands are returned by reference as arguments.
8192   return true;
8193 }
8194 
8195 static SDValue getHopForBuildVector(const BuildVectorSDNode *BV,
8196                                     const SDLoc &DL, SelectionDAG &DAG,
8197                                     unsigned HOpcode, SDValue V0, SDValue V1) {
8198   // If either input vector is not the same size as the build vector,
8199   // extract/insert the low bits to the correct size.
8200   // This is free (examples: zmm --> xmm, xmm --> ymm).
8201   MVT VT = BV->getSimpleValueType(0);
8202   unsigned Width = VT.getSizeInBits();
8203   if (V0.getValueSizeInBits() > Width)
8204     V0 = extractSubVector(V0, 0, DAG, DL, Width);
8205   else if (V0.getValueSizeInBits() < Width)
8206     V0 = insertSubVector(DAG.getUNDEF(VT), V0, 0, DAG, DL, Width);
8207 
8208   if (V1.getValueSizeInBits() > Width)
8209     V1 = extractSubVector(V1, 0, DAG, DL, Width);
8210   else if (V1.getValueSizeInBits() < Width)
8211     V1 = insertSubVector(DAG.getUNDEF(VT), V1, 0, DAG, DL, Width);
8212 
8213   unsigned NumElts = VT.getVectorNumElements();
8214   APInt DemandedElts = APInt::getAllOnes(NumElts);
8215   for (unsigned i = 0; i != NumElts; ++i)
8216     if (BV->getOperand(i).isUndef())
8217       DemandedElts.clearBit(i);
8218 
8219   // If we don't need the upper xmm, then perform as an xmm hop.
8220   unsigned HalfNumElts = NumElts / 2;
8221   if (VT.is256BitVector() && DemandedElts.lshr(HalfNumElts) == 0) {
8222     MVT HalfVT = VT.getHalfNumVectorElementsVT();
8223     V0 = extractSubVector(V0, 0, DAG, DL, 128);
8224     V1 = extractSubVector(V1, 0, DAG, DL, 128);
8225     SDValue Half = DAG.getNode(HOpcode, DL, HalfVT, V0, V1);
8226     return insertSubVector(DAG.getUNDEF(VT), Half, 0, DAG, DL, 256);
8227   }
8228 
8229   return DAG.getNode(HOpcode, DL, VT, V0, V1);
8230 }
8231 
8232 /// Lower BUILD_VECTOR to a horizontal add/sub operation if possible.
8233 static SDValue LowerToHorizontalOp(const BuildVectorSDNode *BV, const SDLoc &DL,
8234                                    const X86Subtarget &Subtarget,
8235                                    SelectionDAG &DAG) {
8236   // We need at least 2 non-undef elements to make this worthwhile by default.
8237   unsigned NumNonUndefs =
8238       count_if(BV->op_values(), [](SDValue V) { return !V.isUndef(); });
8239   if (NumNonUndefs < 2)
8240     return SDValue();
8241 
8242   // There are 4 sets of horizontal math operations distinguished by type:
8243   // int/FP at 128-bit/256-bit. Each type was introduced with a different
8244   // subtarget feature. Try to match those "native" patterns first.
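  // For reference (illustrative, not exhaustive): the SSE3 forms are
  // HADDPS/HADDPD, the SSSE3 integer forms are PHADD(W|D)/PHSUB(W|D), and AVX /
  // AVX2 provide the corresponding 256-bit FP and integer variants.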
8245   MVT VT = BV->getSimpleValueType(0);
8246   if (((VT == MVT::v4f32 || VT == MVT::v2f64) && Subtarget.hasSSE3()) ||
8247       ((VT == MVT::v8i16 || VT == MVT::v4i32) && Subtarget.hasSSSE3()) ||
8248       ((VT == MVT::v8f32 || VT == MVT::v4f64) && Subtarget.hasAVX()) ||
8249       ((VT == MVT::v16i16 || VT == MVT::v8i32) && Subtarget.hasAVX2())) {
8250     unsigned HOpcode;
8251     SDValue V0, V1;
8252     if (isHopBuildVector(BV, DAG, HOpcode, V0, V1))
8253       return getHopForBuildVector(BV, DL, DAG, HOpcode, V0, V1);
8254   }
8255 
8256   // Try harder to match 256-bit ops by using extract/concat.
8257   if (!Subtarget.hasAVX() || !VT.is256BitVector())
8258     return SDValue();
8259 
8260   // Count the number of UNDEF operands in each half of the input build_vector.
8261   unsigned NumElts = VT.getVectorNumElements();
8262   unsigned Half = NumElts / 2;
8263   unsigned NumUndefsLO = 0;
8264   unsigned NumUndefsHI = 0;
8265   for (unsigned i = 0, e = Half; i != e; ++i)
8266     if (BV->getOperand(i)->isUndef())
8267       NumUndefsLO++;
8268 
8269   for (unsigned i = Half, e = NumElts; i != e; ++i)
8270     if (BV->getOperand(i)->isUndef())
8271       NumUndefsHI++;
8272 
8273   SDValue InVec0, InVec1;
8274   if (VT == MVT::v8i32 || VT == MVT::v16i16) {
8275     SDValue InVec2, InVec3;
8276     unsigned X86Opcode;
8277     bool CanFold = true;
8278 
8279     if (isHorizontalBinOpPart(BV, ISD::ADD, DL, DAG, 0, Half, InVec0, InVec1) &&
8280         isHorizontalBinOpPart(BV, ISD::ADD, DL, DAG, Half, NumElts, InVec2,
8281                               InVec3) &&
8282         ((InVec0.isUndef() || InVec2.isUndef()) || InVec0 == InVec2) &&
8283         ((InVec1.isUndef() || InVec3.isUndef()) || InVec1 == InVec3))
8284       X86Opcode = X86ISD::HADD;
8285     else if (isHorizontalBinOpPart(BV, ISD::SUB, DL, DAG, 0, Half, InVec0,
8286                                    InVec1) &&
8287              isHorizontalBinOpPart(BV, ISD::SUB, DL, DAG, Half, NumElts, InVec2,
8288                                    InVec3) &&
8289              ((InVec0.isUndef() || InVec2.isUndef()) || InVec0 == InVec2) &&
8290              ((InVec1.isUndef() || InVec3.isUndef()) || InVec1 == InVec3))
8291       X86Opcode = X86ISD::HSUB;
8292     else
8293       CanFold = false;
8294 
8295     if (CanFold) {
8296       // Do not try to expand this build_vector into a pair of horizontal
8297       // add/sub if we can emit a pair of scalar add/sub.
8298       if (NumUndefsLO + 1 == Half || NumUndefsHI + 1 == Half)
8299         return SDValue();
8300 
8301       // Convert this build_vector into a pair of horizontal binops followed by
8302       // a concat vector. We must adjust the outputs from the partial horizontal
8303       // matching calls above to account for undefined vector halves.
8304       SDValue V0 = InVec0.isUndef() ? InVec2 : InVec0;
8305       SDValue V1 = InVec1.isUndef() ? InVec3 : InVec1;
8306       assert((!V0.isUndef() || !V1.isUndef()) && "Horizontal-op of undefs?");
8307       bool isUndefLO = NumUndefsLO == Half;
8308       bool isUndefHI = NumUndefsHI == Half;
8309       return ExpandHorizontalBinOp(V0, V1, DL, DAG, X86Opcode, false, isUndefLO,
8310                                    isUndefHI);
8311     }
8312   }
8313 
8314   if (VT == MVT::v8f32 || VT == MVT::v4f64 || VT == MVT::v8i32 ||
8315       VT == MVT::v16i16) {
8316     unsigned X86Opcode;
8317     if (isHorizontalBinOpPart(BV, ISD::ADD, DL, DAG, 0, NumElts, InVec0,
8318                               InVec1))
8319       X86Opcode = X86ISD::HADD;
8320     else if (isHorizontalBinOpPart(BV, ISD::SUB, DL, DAG, 0, NumElts, InVec0,
8321                                    InVec1))
8322       X86Opcode = X86ISD::HSUB;
8323     else if (isHorizontalBinOpPart(BV, ISD::FADD, DL, DAG, 0, NumElts, InVec0,
8324                                    InVec1))
8325       X86Opcode = X86ISD::FHADD;
8326     else if (isHorizontalBinOpPart(BV, ISD::FSUB, DL, DAG, 0, NumElts, InVec0,
8327                                    InVec1))
8328       X86Opcode = X86ISD::FHSUB;
8329     else
8330       return SDValue();
8331 
8332     // Don't try to expand this build_vector into a pair of horizontal add/sub
8333     // if we can simply emit a pair of scalar add/sub.
8334     if (NumUndefsLO + 1 == Half || NumUndefsHI + 1 == Half)
8335       return SDValue();
8336 
8337     // Convert this build_vector into two horizontal add/sub followed by
8338     // a concat vector.
8339     bool isUndefLO = NumUndefsLO == Half;
8340     bool isUndefHI = NumUndefsHI == Half;
8341     return ExpandHorizontalBinOp(InVec0, InVec1, DL, DAG, X86Opcode, true,
8342                                  isUndefLO, isUndefHI);
8343   }
8344 
8345   return SDValue();
8346 }
8347 
8348 static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
8349                           SelectionDAG &DAG);
8350 
8351 /// If a BUILD_VECTOR's source elements all apply the same bit operation and
8352 /// one of their operands is constant, lower to a pair of BUILD_VECTORs and
8353 /// apply the bit operation to the whole vectors instead.
8354 /// NOTE: It's not in our interest to start making a general-purpose vectorizer
8355 /// from this, but enough scalar bit operations are created by the later
8356 /// legalization + scalarization stages to need basic support.
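/// Illustrative example of the transform:
///   (build_vector (and x0, 1), (and x1, 2), (and x2, 4), (and x3, 8))
///     --> (and (build_vector x0, x1, x2, x3), (build_vector 1, 2, 4, 8))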
8357 static SDValue lowerBuildVectorToBitOp(BuildVectorSDNode *Op, const SDLoc &DL,
8358                                        const X86Subtarget &Subtarget,
8359                                        SelectionDAG &DAG) {
8360   MVT VT = Op->getSimpleValueType(0);
8361   unsigned NumElems = VT.getVectorNumElements();
8362   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
8363 
8364   // Check that all elements have the same opcode.
8365   // TODO: Should we allow UNDEFS and if so how many?
8366   unsigned Opcode = Op->getOperand(0).getOpcode();
8367   for (unsigned i = 1; i < NumElems; ++i)
8368     if (Opcode != Op->getOperand(i).getOpcode())
8369       return SDValue();
8370 
8371   // TODO: We may be able to add support for other Ops (ADD/SUB + shifts).
8372   bool IsShift = false;
8373   switch (Opcode) {
8374   default:
8375     return SDValue();
8376   case ISD::SHL:
8377   case ISD::SRL:
8378   case ISD::SRA:
8379     IsShift = true;
8380     break;
8381   case ISD::AND:
8382   case ISD::XOR:
8383   case ISD::OR:
8384     // Don't do this if the buildvector is a splat - we'd replace one
8385     // constant with an entire vector.
8386     if (Op->getSplatValue())
8387       return SDValue();
8388     if (!TLI.isOperationLegalOrPromote(Opcode, VT))
8389       return SDValue();
8390     break;
8391   }
8392 
8393   SmallVector<SDValue, 4> LHSElts, RHSElts;
8394   for (SDValue Elt : Op->ops()) {
8395     SDValue LHS = Elt.getOperand(0);
8396     SDValue RHS = Elt.getOperand(1);
8397 
8398     // We expect the canonicalized RHS operand to be the constant.
8399     if (!isa<ConstantSDNode>(RHS))
8400       return SDValue();
8401 
8402     // Extend shift amounts.
8403     if (RHS.getValueSizeInBits() != VT.getScalarSizeInBits()) {
8404       if (!IsShift)
8405         return SDValue();
8406       RHS = DAG.getZExtOrTrunc(RHS, DL, VT.getScalarType());
8407     }
8408 
8409     LHSElts.push_back(LHS);
8410     RHSElts.push_back(RHS);
8411   }
8412 
8413   // Limit to shifts by uniform immediates.
8414   // TODO: Only accept vXi8/vXi64 special cases?
8415   // TODO: Permit non-uniform XOP/AVX2/MULLO cases?
8416   if (IsShift && any_of(RHSElts, [&](SDValue V) { return RHSElts[0] != V; }))
8417     return SDValue();
8418 
8419   SDValue LHS = DAG.getBuildVector(VT, DL, LHSElts);
8420   SDValue RHS = DAG.getBuildVector(VT, DL, RHSElts);
8421   SDValue Res = DAG.getNode(Opcode, DL, VT, LHS, RHS);
8422 
8423   if (!IsShift)
8424     return Res;
8425 
8426   // Immediately lower the shift to ensure the constant build vector doesn't
8427   // get converted to a constant pool before the shift is lowered.
8428   return LowerShift(Res, Subtarget, DAG);
8429 }
8430 
8431 /// Create a vector constant without a load. SSE/AVX provide the bare minimum
8432 /// functionality to do this, so it's all zeros, all ones, or some derivation
8433 /// that is cheap to calculate.
8434 static SDValue materializeVectorConstant(SDValue Op, const SDLoc &DL,
8435                                          SelectionDAG &DAG,
8436                                          const X86Subtarget &Subtarget) {
8437   MVT VT = Op.getSimpleValueType();
8438 
8439   // Vectors containing all zeros can be matched by pxor and xorps.
8440   if (ISD::isBuildVectorAllZeros(Op.getNode()))
8441     return Op;
8442 
8443   // Vectors containing all ones can be matched by pcmpeqd on 128-bit width
8444   // vectors or broken into v4i32 operations on 256-bit vectors. AVX2 can use
8445   // vpcmpeqd on 256-bit vectors.
8446   if (Subtarget.hasSSE2() && ISD::isBuildVectorAllOnes(Op.getNode())) {
8447     if (VT == MVT::v4i32 || VT == MVT::v8i32 || VT == MVT::v16i32)
8448       return Op;
8449 
8450     return getOnesVector(VT, DAG, DL);
8451   }
8452 
8453   return SDValue();
8454 }
8455 
8456 /// Look for opportunities to create a VPERMV/VPERMILPV/PSHUFB variable permute
8457 /// from a vector of source values and a vector of extraction indices.
8458 /// The vectors might be manipulated to match the type of the permute op.
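/// e.g. a v16i8 permute on SSSE3 becomes (PSHUFB SrcVec, IndicesVec), while a
/// v4i32 permute on AVX is performed as a VPERMILPV on the v4f32 bitcast of
/// the source.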
8459 static SDValue createVariablePermute(MVT VT, SDValue SrcVec, SDValue IndicesVec,
8460                                      const SDLoc &DL, SelectionDAG &DAG,
8461                                      const X86Subtarget &Subtarget) {
8462   MVT ShuffleVT = VT;
8463   EVT IndicesVT = EVT(VT).changeVectorElementTypeToInteger();
8464   unsigned NumElts = VT.getVectorNumElements();
8465   unsigned SizeInBits = VT.getSizeInBits();
8466 
8467   // Adjust IndicesVec to match VT size.
8468   assert(IndicesVec.getValueType().getVectorNumElements() >= NumElts &&
8469          "Illegal variable permute mask size");
8470   if (IndicesVec.getValueType().getVectorNumElements() > NumElts) {
8471     // Narrow/widen the indices vector to the correct size.
8472     if (IndicesVec.getValueSizeInBits() > SizeInBits)
8473       IndicesVec = extractSubVector(IndicesVec, 0, DAG, SDLoc(IndicesVec),
8474                                     NumElts * VT.getScalarSizeInBits());
8475     else if (IndicesVec.getValueSizeInBits() < SizeInBits)
8476       IndicesVec = widenSubVector(IndicesVec, false, Subtarget, DAG,
8477                                   SDLoc(IndicesVec), SizeInBits);
8478     // Zero-extend the index elements within the vector.
8479     if (IndicesVec.getValueType().getVectorNumElements() > NumElts)
8480       IndicesVec = DAG.getNode(ISD::ZERO_EXTEND_VECTOR_INREG, SDLoc(IndicesVec),
8481                                IndicesVT, IndicesVec);
8482   }
8483   IndicesVec = DAG.getZExtOrTrunc(IndicesVec, SDLoc(IndicesVec), IndicesVT);
8484 
8485   // Handle a SrcVec whose size does not match the VT size.
8486   if (SrcVec.getValueSizeInBits() != SizeInBits) {
8487     if ((SrcVec.getValueSizeInBits() % SizeInBits) == 0) {
8488       // Handle larger SrcVec by treating it as a larger permute.
8489       unsigned Scale = SrcVec.getValueSizeInBits() / SizeInBits;
8490       VT = MVT::getVectorVT(VT.getScalarType(), Scale * NumElts);
8491       IndicesVT = EVT(VT).changeVectorElementTypeToInteger();
8492       IndicesVec = widenSubVector(IndicesVT.getSimpleVT(), IndicesVec, false,
8493                                   Subtarget, DAG, SDLoc(IndicesVec));
8494       SDValue NewSrcVec =
8495           createVariablePermute(VT, SrcVec, IndicesVec, DL, DAG, Subtarget);
8496       if (NewSrcVec)
8497         return extractSubVector(NewSrcVec, 0, DAG, DL, SizeInBits);
8498       return SDValue();
8499     } else if (SrcVec.getValueSizeInBits() < SizeInBits) {
8500       // Widen smaller SrcVec to match VT.
8501       SrcVec = widenSubVector(VT, SrcVec, false, Subtarget, DAG, SDLoc(SrcVec));
8502     } else
8503       return SDValue();
8504   }
8505 
8506   auto ScaleIndices = [&DAG](SDValue Idx, uint64_t Scale) {
8507     assert(isPowerOf2_64(Scale) && "Illegal variable permute shuffle scale");
8508     EVT SrcVT = Idx.getValueType();
8509     unsigned NumDstBits = SrcVT.getScalarSizeInBits() / Scale;
8510     uint64_t IndexScale = 0;
8511     uint64_t IndexOffset = 0;
8512 
8513     // If we're scaling a smaller permute op, then we need to repeat the
8514     // indices, scaling and offsetting them as well.
8515     // e.g. v4i32 -> v16i8 (Scale = 4)
8516     // IndexScale = v4i32 Splat(4 << 24 | 4 << 16 | 4 << 8 | 4)
8517     // IndexOffset = v4i32 Splat(3 << 24 | 2 << 16 | 1 << 8 | 0)
8518     for (uint64_t i = 0; i != Scale; ++i) {
8519       IndexScale |= Scale << (i * NumDstBits);
8520       IndexOffset |= i << (i * NumDstBits);
8521     }
8522 
8523     Idx = DAG.getNode(ISD::MUL, SDLoc(Idx), SrcVT, Idx,
8524                       DAG.getConstant(IndexScale, SDLoc(Idx), SrcVT));
8525     Idx = DAG.getNode(ISD::ADD, SDLoc(Idx), SrcVT, Idx,
8526                       DAG.getConstant(IndexOffset, SDLoc(Idx), SrcVT));
8527     return Idx;
8528   };
8529 
8530   unsigned Opcode = 0;
8531   switch (VT.SimpleTy) {
8532   default:
8533     break;
8534   case MVT::v16i8:
8535     if (Subtarget.hasSSSE3())
8536       Opcode = X86ISD::PSHUFB;
8537     break;
8538   case MVT::v8i16:
8539     if (Subtarget.hasVLX() && Subtarget.hasBWI())
8540       Opcode = X86ISD::VPERMV;
8541     else if (Subtarget.hasSSSE3()) {
8542       Opcode = X86ISD::PSHUFB;
8543       ShuffleVT = MVT::v16i8;
8544     }
8545     break;
8546   case MVT::v4f32:
8547   case MVT::v4i32:
8548     if (Subtarget.hasAVX()) {
8549       Opcode = X86ISD::VPERMILPV;
8550       ShuffleVT = MVT::v4f32;
8551     } else if (Subtarget.hasSSSE3()) {
8552       Opcode = X86ISD::PSHUFB;
8553       ShuffleVT = MVT::v16i8;
8554     }
8555     break;
8556   case MVT::v2f64:
8557   case MVT::v2i64:
8558     if (Subtarget.hasAVX()) {
8559       // VPERMILPD selects using bit#1 of the index vector, so scale IndicesVec.
8560       IndicesVec = DAG.getNode(ISD::ADD, DL, IndicesVT, IndicesVec, IndicesVec);
8561       Opcode = X86ISD::VPERMILPV;
8562       ShuffleVT = MVT::v2f64;
8563     } else if (Subtarget.hasSSE41()) {
8564       // SSE41 can compare v2i64 - select between indices 0 and 1.
8565       return DAG.getSelectCC(
8566           DL, IndicesVec,
8567           getZeroVector(IndicesVT.getSimpleVT(), Subtarget, DAG, DL),
8568           DAG.getVectorShuffle(VT, DL, SrcVec, SrcVec, {0, 0}),
8569           DAG.getVectorShuffle(VT, DL, SrcVec, SrcVec, {1, 1}),
8570           ISD::CondCode::SETEQ);
8571     }
8572     break;
8573   case MVT::v32i8:
8574     if (Subtarget.hasVLX() && Subtarget.hasVBMI())
8575       Opcode = X86ISD::VPERMV;
8576     else if (Subtarget.hasXOP()) {
8577       SDValue LoSrc = extract128BitVector(SrcVec, 0, DAG, DL);
8578       SDValue HiSrc = extract128BitVector(SrcVec, 16, DAG, DL);
8579       SDValue LoIdx = extract128BitVector(IndicesVec, 0, DAG, DL);
8580       SDValue HiIdx = extract128BitVector(IndicesVec, 16, DAG, DL);
8581       return DAG.getNode(
8582           ISD::CONCAT_VECTORS, DL, VT,
8583           DAG.getNode(X86ISD::VPPERM, DL, MVT::v16i8, LoSrc, HiSrc, LoIdx),
8584           DAG.getNode(X86ISD::VPPERM, DL, MVT::v16i8, LoSrc, HiSrc, HiIdx));
8585     } else if (Subtarget.hasAVX()) {
8586       SDValue Lo = extract128BitVector(SrcVec, 0, DAG, DL);
8587       SDValue Hi = extract128BitVector(SrcVec, 16, DAG, DL);
8588       SDValue LoLo = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Lo, Lo);
8589       SDValue HiHi = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Hi, Hi);
8590       auto PSHUFBBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
8591                               ArrayRef<SDValue> Ops) {
8592         // Permute Lo and Hi and then select based on index range.
8593         // This works as PSHUFB uses bits[3:0] to permute elements and we don't
8594         // care about bit[7] as it's just an index vector.
8595         SDValue Idx = Ops[2];
8596         EVT VT = Idx.getValueType();
8597         return DAG.getSelectCC(DL, Idx, DAG.getConstant(15, DL, VT),
8598                                DAG.getNode(X86ISD::PSHUFB, DL, VT, Ops[1], Idx),
8599                                DAG.getNode(X86ISD::PSHUFB, DL, VT, Ops[0], Idx),
8600                                ISD::CondCode::SETGT);
8601       };
8602       SDValue Ops[] = {LoLo, HiHi, IndicesVec};
8603       return SplitOpsAndApply(DAG, Subtarget, DL, MVT::v32i8, Ops,
8604                               PSHUFBBuilder);
8605     }
8606     break;
8607   case MVT::v16i16:
8608     if (Subtarget.hasVLX() && Subtarget.hasBWI())
8609       Opcode = X86ISD::VPERMV;
8610     else if (Subtarget.hasAVX()) {
8611       // Scale to v32i8 and perform as v32i8.
8612       IndicesVec = ScaleIndices(IndicesVec, 2);
8613       return DAG.getBitcast(
8614           VT, createVariablePermute(
8615                   MVT::v32i8, DAG.getBitcast(MVT::v32i8, SrcVec),
8616                   DAG.getBitcast(MVT::v32i8, IndicesVec), DL, DAG, Subtarget));
8617     }
8618     break;
8619   case MVT::v8f32:
8620   case MVT::v8i32:
8621     if (Subtarget.hasAVX2())
8622       Opcode = X86ISD::VPERMV;
8623     else if (Subtarget.hasAVX()) {
8624       SrcVec = DAG.getBitcast(MVT::v8f32, SrcVec);
8625       SDValue LoLo = DAG.getVectorShuffle(MVT::v8f32, DL, SrcVec, SrcVec,
8626                                           {0, 1, 2, 3, 0, 1, 2, 3});
8627       SDValue HiHi = DAG.getVectorShuffle(MVT::v8f32, DL, SrcVec, SrcVec,
8628                                           {4, 5, 6, 7, 4, 5, 6, 7});
8629       if (Subtarget.hasXOP())
8630         return DAG.getBitcast(
8631             VT, DAG.getNode(X86ISD::VPERMIL2, DL, MVT::v8f32, LoLo, HiHi,
8632                             IndicesVec, DAG.getTargetConstant(0, DL, MVT::i8)));
8633       // Permute Lo and Hi and then select based on index range.
8634       // This works as VPERMILPS only uses index bits[0:1] to permute elements.
8635       SDValue Res = DAG.getSelectCC(
8636           DL, IndicesVec, DAG.getConstant(3, DL, MVT::v8i32),
8637           DAG.getNode(X86ISD::VPERMILPV, DL, MVT::v8f32, HiHi, IndicesVec),
8638           DAG.getNode(X86ISD::VPERMILPV, DL, MVT::v8f32, LoLo, IndicesVec),
8639           ISD::CondCode::SETGT);
8640       return DAG.getBitcast(VT, Res);
8641     }
8642     break;
8643   case MVT::v4i64:
8644   case MVT::v4f64:
8645     if (Subtarget.hasAVX512()) {
8646       if (!Subtarget.hasVLX()) {
8647         MVT WidenSrcVT = MVT::getVectorVT(VT.getScalarType(), 8);
8648         SrcVec = widenSubVector(WidenSrcVT, SrcVec, false, Subtarget, DAG,
8649                                 SDLoc(SrcVec));
8650         IndicesVec = widenSubVector(MVT::v8i64, IndicesVec, false, Subtarget,
8651                                     DAG, SDLoc(IndicesVec));
8652         SDValue Res = createVariablePermute(WidenSrcVT, SrcVec, IndicesVec, DL,
8653                                             DAG, Subtarget);
8654         return extract256BitVector(Res, 0, DAG, DL);
8655       }
8656       Opcode = X86ISD::VPERMV;
8657     } else if (Subtarget.hasAVX()) {
8658       SrcVec = DAG.getBitcast(MVT::v4f64, SrcVec);
8659       SDValue LoLo =
8660           DAG.getVectorShuffle(MVT::v4f64, DL, SrcVec, SrcVec, {0, 1, 0, 1});
8661       SDValue HiHi =
8662           DAG.getVectorShuffle(MVT::v4f64, DL, SrcVec, SrcVec, {2, 3, 2, 3});
8663       // VPERMIL2PD selects with bit#1 of the index vector, so scale IndicesVec.
8664       IndicesVec = DAG.getNode(ISD::ADD, DL, IndicesVT, IndicesVec, IndicesVec);
8665       if (Subtarget.hasXOP())
8666         return DAG.getBitcast(
8667             VT, DAG.getNode(X86ISD::VPERMIL2, DL, MVT::v4f64, LoLo, HiHi,
8668                             IndicesVec, DAG.getTargetConstant(0, DL, MVT::i8)));
8669       // Permute Lo and Hi and then select based on index range.
8670       // This works as VPERMILPD only uses index bit[1] to permute elements.
8671       SDValue Res = DAG.getSelectCC(
8672           DL, IndicesVec, DAG.getConstant(2, DL, MVT::v4i64),
8673           DAG.getNode(X86ISD::VPERMILPV, DL, MVT::v4f64, HiHi, IndicesVec),
8674           DAG.getNode(X86ISD::VPERMILPV, DL, MVT::v4f64, LoLo, IndicesVec),
8675           ISD::CondCode::SETGT);
8676       return DAG.getBitcast(VT, Res);
8677     }
8678     break;
8679   case MVT::v64i8:
8680     if (Subtarget.hasVBMI())
8681       Opcode = X86ISD::VPERMV;
8682     break;
8683   case MVT::v32i16:
8684     if (Subtarget.hasBWI())
8685       Opcode = X86ISD::VPERMV;
8686     break;
8687   case MVT::v16f32:
8688   case MVT::v16i32:
8689   case MVT::v8f64:
8690   case MVT::v8i64:
8691     if (Subtarget.hasAVX512())
8692       Opcode = X86ISD::VPERMV;
8693     break;
8694   }
8695   if (!Opcode)
8696     return SDValue();
8697 
8698   assert((VT.getSizeInBits() == ShuffleVT.getSizeInBits()) &&
8699          (VT.getScalarSizeInBits() % ShuffleVT.getScalarSizeInBits()) == 0 &&
8700          "Illegal variable permute shuffle type");
8701 
8702   uint64_t Scale = VT.getScalarSizeInBits() / ShuffleVT.getScalarSizeInBits();
8703   if (Scale > 1)
8704     IndicesVec = ScaleIndices(IndicesVec, Scale);
8705 
8706   EVT ShuffleIdxVT = EVT(ShuffleVT).changeVectorElementTypeToInteger();
8707   IndicesVec = DAG.getBitcast(ShuffleIdxVT, IndicesVec);
8708 
8709   SrcVec = DAG.getBitcast(ShuffleVT, SrcVec);
8710   SDValue Res = Opcode == X86ISD::VPERMV
8711                     ? DAG.getNode(Opcode, DL, ShuffleVT, IndicesVec, SrcVec)
8712                     : DAG.getNode(Opcode, DL, ShuffleVT, SrcVec, IndicesVec);
8713   return DAG.getBitcast(VT, Res);
8714 }
8715 
8716 // Tries to lower a BUILD_VECTOR composed of extract-extract chains that can be
8717 // reasoned to be a permutation of a vector by indices in a non-constant vector.
8718 // (build_vector (extract_elt V, (extract_elt I, 0)),
8719 //               (extract_elt V, (extract_elt I, 1)),
8720 //                    ...
8721 // ->
8722 // (vpermv I, V)
8723 //
8724 // TODO: Handle undefs
8725 // TODO: Utilize pshufb and zero mask blending to support more efficient
8726 // construction of vectors with constant-0 elements.
8727 static SDValue
8728 LowerBUILD_VECTORAsVariablePermute(SDValue V, const SDLoc &DL,
8729                                    SelectionDAG &DAG,
8730                                    const X86Subtarget &Subtarget) {
8731   SDValue SrcVec, IndicesVec;
8732   // Check for a match of the permute source vector and permute index elements.
8733   // This is done by checking that the i-th build_vector operand is of the form:
8734   // (extract_elt SrcVec, (extract_elt IndicesVec, i)).
8735   for (unsigned Idx = 0, E = V.getNumOperands(); Idx != E; ++Idx) {
8736     SDValue Op = V.getOperand(Idx);
8737     if (Op.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
8738       return SDValue();
8739 
8740     // If this is the first extract encountered in V, set the source vector,
8741     // otherwise verify the extract is from the previously defined source
8742     // vector.
8743     if (!SrcVec)
8744       SrcVec = Op.getOperand(0);
8745     else if (SrcVec != Op.getOperand(0))
8746       return SDValue();
8747     SDValue ExtractedIndex = Op->getOperand(1);
8748     // Peek through extends.
8749     if (ExtractedIndex.getOpcode() == ISD::ZERO_EXTEND ||
8750         ExtractedIndex.getOpcode() == ISD::SIGN_EXTEND)
8751       ExtractedIndex = ExtractedIndex.getOperand(0);
8752     if (ExtractedIndex.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
8753       return SDValue();
8754 
8755     // If this is the first extract from the index vector candidate, set the
8756     // indices vector, otherwise verify the extract is from the previously
8757     // defined indices vector.
8758     if (!IndicesVec)
8759       IndicesVec = ExtractedIndex.getOperand(0);
8760     else if (IndicesVec != ExtractedIndex.getOperand(0))
8761       return SDValue();
8762 
8763     auto *PermIdx = dyn_cast<ConstantSDNode>(ExtractedIndex.getOperand(1));
8764     if (!PermIdx || PermIdx->getAPIntValue() != Idx)
8765       return SDValue();
8766   }
8767 
8768   MVT VT = V.getSimpleValueType();
8769   return createVariablePermute(VT, SrcVec, IndicesVec, DL, DAG, Subtarget);
8770 }
8771 
8772 SDValue
8773 X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const {
8774   SDLoc dl(Op);
8775 
8776   MVT VT = Op.getSimpleValueType();
8777   MVT EltVT = VT.getVectorElementType();
8778   MVT OpEltVT = Op.getOperand(0).getSimpleValueType();
8779   unsigned NumElems = Op.getNumOperands();
8780 
8781   // Lower build vectors of i1 predicate elements separately.
8782   if (VT.getVectorElementType() == MVT::i1 && Subtarget.hasAVX512())
8783     return LowerBUILD_VECTORvXi1(Op, dl, DAG, Subtarget);
8784 
8785   if (VT.getVectorElementType() == MVT::bf16 &&
8786       (Subtarget.hasAVXNECONVERT() || Subtarget.hasBF16()))
8787     return LowerBUILD_VECTORvXbf16(Op, DAG, Subtarget);
8788 
8789   if (SDValue VectorCst = materializeVectorConstant(Op, dl, DAG, Subtarget))
8790     return VectorCst;
8791 
8792   unsigned EVTBits = EltVT.getSizeInBits();
8793   APInt UndefMask = APInt::getZero(NumElems);
8794   APInt FrozenUndefMask = APInt::getZero(NumElems);
8795   APInt ZeroMask = APInt::getZero(NumElems);
8796   APInt NonZeroMask = APInt::getZero(NumElems);
8797   bool IsAllConstants = true;
8798   bool OneUseFrozenUndefs = true;
8799   SmallSet<SDValue, 8> Values;
8800   unsigned NumConstants = NumElems;
8801   for (unsigned i = 0; i < NumElems; ++i) {
8802     SDValue Elt = Op.getOperand(i);
8803     if (Elt.isUndef()) {
8804       UndefMask.setBit(i);
8805       continue;
8806     }
8807     if (ISD::isFreezeUndef(Elt.getNode())) {
8808       OneUseFrozenUndefs = OneUseFrozenUndefs && Elt->hasOneUse();
8809       FrozenUndefMask.setBit(i);
8810       continue;
8811     }
8812     Values.insert(Elt);
8813     if (!isIntOrFPConstant(Elt)) {
8814       IsAllConstants = false;
8815       NumConstants--;
8816     }
8817     if (X86::isZeroNode(Elt)) {
8818       ZeroMask.setBit(i);
8819     } else {
8820       NonZeroMask.setBit(i);
8821     }
8822   }
8823 
8824   // All undef vector. Return an UNDEF.
8825   if (UndefMask.isAllOnes())
8826     return DAG.getUNDEF(VT);
8827 
8828   // All undef/freeze(undef) vector. Return a FREEZE UNDEF.
8829   if (OneUseFrozenUndefs && (UndefMask | FrozenUndefMask).isAllOnes())
8830     return DAG.getFreeze(DAG.getUNDEF(VT));
8831 
8832   // All undef/freeze(undef)/zero vector. Return a zero vector.
8833   if ((UndefMask | FrozenUndefMask | ZeroMask).isAllOnes())
8834     return getZeroVector(VT, Subtarget, DAG, dl);
8835 
8836   // If we have multiple FREEZE-UNDEF operands, we are likely going to end up
8837   // lowering into a suboptimal insertion sequence. Instead, thaw the UNDEF in
8838   // our source BUILD_VECTOR, create another FREEZE-UNDEF splat BUILD_VECTOR,
8839   // and blend the FREEZE-UNDEF operands back in.
8840   // FIXME: is this worthwhile even for a single FREEZE-UNDEF operand?
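  // e.g. (build_vector x, freeze(undef), y, freeze(undef)) is rebuilt as a
  // shuffle blending (build_vector x, undef, y, undef) with a splat
  // build_vector of a single freeze(undef) element.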
8841   if (unsigned NumFrozenUndefElts = FrozenUndefMask.popcount();
8842       NumFrozenUndefElts >= 2 && NumFrozenUndefElts < NumElems) {
8843     SmallVector<int, 16> BlendMask(NumElems, -1);
8844     SmallVector<SDValue, 16> Elts(NumElems, DAG.getUNDEF(OpEltVT));
8845     for (unsigned i = 0; i < NumElems; ++i) {
8846       if (UndefMask[i]) {
8847         BlendMask[i] = -1;
8848         continue;
8849       }
8850       BlendMask[i] = i;
8851       if (!FrozenUndefMask[i])
8852         Elts[i] = Op.getOperand(i);
8853       else
8854         BlendMask[i] += NumElems;
8855     }
8856     SDValue EltsBV = DAG.getBuildVector(VT, dl, Elts);
8857     SDValue FrozenUndefElt = DAG.getFreeze(DAG.getUNDEF(OpEltVT));
8858     SDValue FrozenUndefBV = DAG.getSplatBuildVector(VT, dl, FrozenUndefElt);
8859     return DAG.getVectorShuffle(VT, dl, EltsBV, FrozenUndefBV, BlendMask);
8860   }
8861 
8862   BuildVectorSDNode *BV = cast<BuildVectorSDNode>(Op.getNode());
8863 
8864   // If the upper elts of a ymm/zmm are undef/freeze(undef)/zero then we might
8865   // be better off lowering to a smaller build vector and padding with
8866   // undef/zero.
8867   if ((VT.is256BitVector() || VT.is512BitVector()) &&
8868       !isFoldableUseOfShuffle(BV)) {
8869     unsigned UpperElems = NumElems / 2;
8870     APInt UndefOrZeroMask = FrozenUndefMask | UndefMask | ZeroMask;
8871     unsigned NumUpperUndefsOrZeros = UndefOrZeroMask.countl_one();
8872     if (NumUpperUndefsOrZeros >= UpperElems) {
8873       if (VT.is512BitVector() &&
8874           NumUpperUndefsOrZeros >= (NumElems - (NumElems / 4)))
8875         UpperElems = NumElems - (NumElems / 4);
8876       // If freeze(undef) is in any upper elements, force to zero.
8877       bool UndefUpper = UndefMask.countl_one() >= UpperElems;
8878       MVT LowerVT = MVT::getVectorVT(EltVT, NumElems - UpperElems);
8879       SDValue NewBV =
8880           DAG.getBuildVector(LowerVT, dl, Op->ops().drop_back(UpperElems));
8881       return widenSubVector(VT, NewBV, !UndefUpper, Subtarget, DAG, dl);
8882     }
8883   }
8884 
8885   if (SDValue AddSub = lowerToAddSubOrFMAddSub(BV, dl, Subtarget, DAG))
8886     return AddSub;
8887   if (SDValue HorizontalOp = LowerToHorizontalOp(BV, dl, Subtarget, DAG))
8888     return HorizontalOp;
8889   if (SDValue Broadcast = lowerBuildVectorAsBroadcast(BV, dl, Subtarget, DAG))
8890     return Broadcast;
8891   if (SDValue BitOp = lowerBuildVectorToBitOp(BV, dl, Subtarget, DAG))
8892     return BitOp;
8893 
8894   unsigned NumZero = ZeroMask.popcount();
8895   unsigned NumNonZero = NonZeroMask.popcount();
8896 
8897   // If we are inserting one variable into a vector of non-zero constants, try
8898   // to avoid loading each constant element as a scalar. Load the constants as a
8899   // vector and then insert the variable scalar element. If insertion is not
8900   // supported, fall back to a shuffle to get the scalar blended with the
8901   // constants. Insertion into a zero vector is handled as a special case
8902   // further below.
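  // Illustrative example: (build_vector 1.0, x, 3.0, 4.0) becomes a constant
  // pool load of <1.0, undef, 3.0, 4.0> followed by inserting x at index 1
  // (or a blend-style shuffle when the index lies in the upper 128 bits).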
8903   if (NumConstants == NumElems - 1 && NumNonZero != 1 &&
8904       FrozenUndefMask.isZero() &&
8905       (isOperationLegalOrCustom(ISD::INSERT_VECTOR_ELT, VT) ||
8906        isOperationLegalOrCustom(ISD::VECTOR_SHUFFLE, VT))) {
8907     // Create an all-constant vector. The variable element in the old
8908     // build vector is replaced by undef in the constant vector. Save the
8909     // variable scalar element and its index for use in the insertelement.
8910     LLVMContext &Context = *DAG.getContext();
8911     Type *EltType = Op.getValueType().getScalarType().getTypeForEVT(Context);
8912     SmallVector<Constant *, 16> ConstVecOps(NumElems, UndefValue::get(EltType));
8913     SDValue VarElt;
8914     SDValue InsIndex;
8915     for (unsigned i = 0; i != NumElems; ++i) {
8916       SDValue Elt = Op.getOperand(i);
8917       if (auto *C = dyn_cast<ConstantSDNode>(Elt))
8918         ConstVecOps[i] = ConstantInt::get(Context, C->getAPIntValue());
8919       else if (auto *C = dyn_cast<ConstantFPSDNode>(Elt))
8920         ConstVecOps[i] = ConstantFP::get(Context, C->getValueAPF());
8921       else if (!Elt.isUndef()) {
8922         assert(!VarElt.getNode() && !InsIndex.getNode() &&
8923                "Expected one variable element in this vector");
8924         VarElt = Elt;
8925         InsIndex = DAG.getVectorIdxConstant(i, dl);
8926       }
8927     }
8928     Constant *CV = ConstantVector::get(ConstVecOps);
8929     SDValue DAGConstVec = DAG.getConstantPool(CV, VT);
8930 
8931     // The constants we just created may not be legal (e.g., floating point). We
8932     // must lower the vector right here because we cannot guarantee that we'll
8933     // legalize it before loading it. This is also why we could not just create
8934     // a new build vector here. If the build vector contains illegal constants,
8935     // it could get split back up into a series of insert elements.
8936     // TODO: Improve this by using shorter loads with broadcast/VZEXT_LOAD.
8937     SDValue LegalDAGConstVec = LowerConstantPool(DAGConstVec, DAG);
8938     MachineFunction &MF = DAG.getMachineFunction();
8939     MachinePointerInfo MPI = MachinePointerInfo::getConstantPool(MF);
8940     SDValue Ld = DAG.getLoad(VT, dl, DAG.getEntryNode(), LegalDAGConstVec, MPI);
8941     unsigned InsertC = InsIndex->getAsZExtVal();
8942     unsigned NumEltsInLow128Bits = 128 / VT.getScalarSizeInBits();
8943     if (InsertC < NumEltsInLow128Bits)
8944       return DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, VT, Ld, VarElt, InsIndex);
8945 
8946     // There's no good way to insert into the high elements of a >128-bit
8947     // vector, so use shuffles to avoid an extract/insert sequence.
8948     assert(VT.getSizeInBits() > 128 && "Invalid insertion index?");
8949     assert(Subtarget.hasAVX() && "Must have AVX with >16-byte vector");
8950     SmallVector<int, 8> ShuffleMask;
8951     unsigned NumElts = VT.getVectorNumElements();
8952     for (unsigned i = 0; i != NumElts; ++i)
8953       ShuffleMask.push_back(i == InsertC ? NumElts : i);
8954     SDValue S2V = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, VT, VarElt);
8955     return DAG.getVectorShuffle(VT, dl, Ld, S2V, ShuffleMask);
8956   }
8957 
8958   // Special case for a single non-zero, non-undef element.
8959   if (NumNonZero == 1) {
8960     unsigned Idx = NonZeroMask.countr_zero();
8961     SDValue Item = Op.getOperand(Idx);
8962 
8963     // If we have a constant or non-constant insertion into the low element of
8964     // a vector, we can do this with SCALAR_TO_VECTOR + shuffle of zero into
8965     // the rest of the elements.  This will be matched as movd/movq/movss/movsd
8966     // depending on what the source datatype is.
8967     if (Idx == 0) {
8968       if (NumZero == 0)
8969         return DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, VT, Item);
8970 
8971       if (EltVT == MVT::i32 || EltVT == MVT::f16 || EltVT == MVT::f32 ||
8972           EltVT == MVT::f64 || (EltVT == MVT::i64 && Subtarget.is64Bit()) ||
8973           (EltVT == MVT::i16 && Subtarget.hasFP16())) {
8974         assert((VT.is128BitVector() || VT.is256BitVector() ||
8975                 VT.is512BitVector()) &&
8976                "Expected an SSE value type!");
8977         Item = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, VT, Item);
8978         // Turn it into a MOVL (i.e. movsh, movss, movsd, movw or movd) to a
8979         // zero vector.
8980         return getShuffleVectorZeroOrUndef(Item, 0, true, Subtarget, DAG);
8981       }
8982 
8983       // We can't directly insert an i8 or i16 into a vector, so zero extend
8984       // it to i32 first.
8985       if (EltVT == MVT::i16 || EltVT == MVT::i8) {
8986         Item = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i32, Item);
8987         MVT ShufVT = MVT::getVectorVT(MVT::i32, VT.getSizeInBits() / 32);
8988         Item = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, ShufVT, Item);
8989         Item = getShuffleVectorZeroOrUndef(Item, 0, true, Subtarget, DAG);
8990         return DAG.getBitcast(VT, Item);
8991       }
8992     }
8993 
8994     // Is it a vector logical left shift?
8995     if (NumElems == 2 && Idx == 1 &&
8996         X86::isZeroNode(Op.getOperand(0)) &&
8997         !X86::isZeroNode(Op.getOperand(1))) {
8998       unsigned NumBits = VT.getSizeInBits();
8999       return getVShift(true, VT,
9000                        DAG.getNode(ISD::SCALAR_TO_VECTOR, dl,
9001                                    VT, Op.getOperand(1)),
9002                        NumBits/2, DAG, *this, dl);
9003     }
9004 
9005     if (IsAllConstants) // Otherwise, it's better to do a constpool load.
9006       return SDValue();
9007 
9008     // Otherwise, if this is a vector with i32 or f32 elements, and the element
9009     // is a non-constant being inserted into an element other than the low one,
9010     // we can't use a constant pool load.  Instead, use SCALAR_TO_VECTOR (aka
9011     // movd/movss) to move this into the low element, then shuffle it into
9012     // place.
9013     if (EVTBits == 32) {
9014       Item = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, VT, Item);
9015       return getShuffleVectorZeroOrUndef(Item, Idx, NumZero > 0, Subtarget, DAG);
9016     }
9017   }
9018 
9019   // Splat is obviously ok. Let legalizer expand it to a shuffle.
9020   if (Values.size() == 1) {
9021     if (EVTBits == 32) {
9022       // Instead of a shuffle like this:
9023       // shuffle (scalar_to_vector (load (ptr + 4))), undef, <0, 0, 0, 0>
9024       // Check if it's possible to issue this instead.
9025       // shuffle (vload ptr), undef, <1, 1, 1, 1>
9026       unsigned Idx = NonZeroMask.countr_zero();
9027       SDValue Item = Op.getOperand(Idx);
9028       if (Op.getNode()->isOnlyUserOf(Item.getNode()))
9029         return LowerAsSplatVectorLoad(Item, VT, dl, DAG);
9030     }
9031     return SDValue();
9032   }
9033 
9034   // A vector full of immediates; various special cases are already
9035   // handled, so this is best done with a single constant-pool load.
9036   if (IsAllConstants)
9037     return SDValue();
9038 
9039   if (SDValue V = LowerBUILD_VECTORAsVariablePermute(Op, dl, DAG, Subtarget))
9040     return V;
9041 
9042   // See if we can use a vector load to get all of the elements.
9043   {
9044     SmallVector<SDValue, 64> Ops(Op->op_begin(), Op->op_begin() + NumElems);
9045     if (SDValue LD =
9046             EltsFromConsecutiveLoads(VT, Ops, dl, DAG, Subtarget, false))
9047       return LD;
9048   }
9049 
9050   // If this is a splat of pairs of 32-bit elements, we can use a narrower
9051   // build_vector and broadcast it.
9052   // TODO: We could probably generalize this more.
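  // e.g. (build_vector a, b, a, b, a, b, a, b) is rebuilt as <a, b, undef,
  // undef>, bitcast to a 64-bit element type, and broadcast with VBROADCAST.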
9053   if (Subtarget.hasAVX2() && EVTBits == 32 && Values.size() == 2) {
9054     SDValue Ops[4] = { Op.getOperand(0), Op.getOperand(1),
9055                        DAG.getUNDEF(EltVT), DAG.getUNDEF(EltVT) };
9056     auto CanSplat = [](SDValue Op, unsigned NumElems, ArrayRef<SDValue> Ops) {
9057       // Make sure all the even/odd operands match.
9058       for (unsigned i = 2; i != NumElems; ++i)
9059         if (Ops[i % 2] != Op.getOperand(i))
9060           return false;
9061       return true;
9062     };
9063     if (CanSplat(Op, NumElems, Ops)) {
9064       MVT WideEltVT = VT.isFloatingPoint() ? MVT::f64 : MVT::i64;
9065       MVT NarrowVT = MVT::getVectorVT(EltVT, 4);
9066       // Create a new build vector and cast to v2i64/v2f64.
9067       SDValue NewBV = DAG.getBitcast(MVT::getVectorVT(WideEltVT, 2),
9068                                      DAG.getBuildVector(NarrowVT, dl, Ops));
9069       // Broadcast from v2i64/v2f64 and cast to final VT.
9070       MVT BcastVT = MVT::getVectorVT(WideEltVT, NumElems / 2);
9071       return DAG.getBitcast(VT, DAG.getNode(X86ISD::VBROADCAST, dl, BcastVT,
9072                                             NewBV));
9073     }
9074   }
9075 
9076   // For AVX-length vectors, build the individual 128-bit pieces and use
9077   // shuffles to put them in place.
9078   if (VT.getSizeInBits() > 128) {
9079     MVT HVT = MVT::getVectorVT(EltVT, NumElems / 2);
9080 
9081     // Build both the lower and upper subvector.
9082     SDValue Lower =
9083         DAG.getBuildVector(HVT, dl, Op->ops().slice(0, NumElems / 2));
9084     SDValue Upper = DAG.getBuildVector(
9085         HVT, dl, Op->ops().slice(NumElems / 2, NumElems /2));
9086 
9087     // Recreate the wider vector with the lower and upper part.
9088     return concatSubVectors(Lower, Upper, DAG, dl);
9089   }
9090 
9091   // Let legalizer expand 2-wide build_vectors.
9092   if (EVTBits == 64) {
9093     if (NumNonZero == 1) {
9094       // One half is zero or undef.
9095       unsigned Idx = NonZeroMask.countr_zero();
9096       SDValue V2 = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, VT,
9097                                Op.getOperand(Idx));
9098       return getShuffleVectorZeroOrUndef(V2, Idx, true, Subtarget, DAG);
9099     }
9100     return SDValue();
9101   }
9102 
9103   // If element VT is < 32 bits, convert it to inserts into a zero vector.
9104   if (EVTBits == 8 && NumElems == 16)
9105     if (SDValue V = LowerBuildVectorv16i8(Op, dl, NonZeroMask, NumNonZero,
9106                                           NumZero, DAG, Subtarget))
9107       return V;
9108 
9109   if (EltVT == MVT::i16 && NumElems == 8)
9110     if (SDValue V = LowerBuildVectorv8i16(Op, dl, NonZeroMask, NumNonZero,
9111                                           NumZero, DAG, Subtarget))
9112       return V;
9113 
9114   // If element VT is == 32 bits and has 4 elems, try to generate an INSERTPS
9115   if (EVTBits == 32 && NumElems == 4)
9116     if (SDValue V = LowerBuildVectorv4x32(Op, dl, DAG, Subtarget))
9117       return V;
9118 
9119   // If element VT is == 32 bits, turn it into a number of shuffles.
9120   if (NumElems == 4 && NumZero > 0) {
9121     SmallVector<SDValue, 8> Ops(NumElems);
9122     for (unsigned i = 0; i < 4; ++i) {
9123       bool isZero = !NonZeroMask[i];
9124       if (isZero)
9125         Ops[i] = getZeroVector(VT, Subtarget, DAG, dl);
9126       else
9127         Ops[i] = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, VT, Op.getOperand(i));
9128     }
9129 
9130     for (unsigned i = 0; i < 2; ++i) {
9131       switch (NonZeroMask.extractBitsAsZExtValue(2, i * 2)) {
9132         default: llvm_unreachable("Unexpected NonZero count");
9133         case 0:
9134           Ops[i] = Ops[i*2];  // Must be a zero vector.
9135           break;
9136         case 1:
9137           Ops[i] = getMOVL(DAG, dl, VT, Ops[i*2+1], Ops[i*2]);
9138           break;
9139         case 2:
9140           Ops[i] = getMOVL(DAG, dl, VT, Ops[i*2], Ops[i*2+1]);
9141           break;
9142         case 3:
9143           Ops[i] = getUnpackl(DAG, dl, VT, Ops[i*2], Ops[i*2+1]);
9144           break;
9145       }
9146     }
9147 
9148     bool Reverse1 = NonZeroMask.extractBitsAsZExtValue(2, 0) == 2;
9149     bool Reverse2 = NonZeroMask.extractBitsAsZExtValue(2, 2) == 2;
9150     int MaskVec[] = {
9151       Reverse1 ? 1 : 0,
9152       Reverse1 ? 0 : 1,
9153       static_cast<int>(Reverse2 ? NumElems+1 : NumElems),
9154       static_cast<int>(Reverse2 ? NumElems   : NumElems+1)
9155     };
9156     return DAG.getVectorShuffle(VT, dl, Ops[0], Ops[1], MaskVec);
9157   }
9158 
9159   assert(Values.size() > 1 && "Expected non-undef and non-splat vector");
9160 
9161   // Check for a build vector from mostly shuffle plus few inserting.
9162   if (SDValue Sh = buildFromShuffleMostly(Op, dl, DAG))
9163     return Sh;
9164 
9165   // For SSE 4.1, use insertps to build the vector one element at a time.
9166   if (Subtarget.hasSSE41() && EltVT != MVT::f16) {
9167     SDValue Result;
9168     if (!Op.getOperand(0).isUndef())
9169       Result = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, VT, Op.getOperand(0));
9170     else
9171       Result = DAG.getUNDEF(VT);
9172 
9173     for (unsigned i = 1; i < NumElems; ++i) {
9174       if (Op.getOperand(i).isUndef()) continue;
9175       Result = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, VT, Result,
9176                            Op.getOperand(i), DAG.getIntPtrConstant(i, dl));
9177     }
9178     return Result;
9179   }
9180 
9181   // Otherwise, expand into a number of unpckl*, start by extending each of
9182   // our (non-undef) elements to the full vector width with the element in the
9183   // bottom slot of the vector (which generates no code for SSE).
9184   SmallVector<SDValue, 8> Ops(NumElems);
9185   for (unsigned i = 0; i < NumElems; ++i) {
9186     if (!Op.getOperand(i).isUndef())
9187       Ops[i] = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, VT, Op.getOperand(i));
9188     else
9189       Ops[i] = DAG.getUNDEF(VT);
9190   }
9191 
9192   // Next, we iteratively mix elements, e.g. for v4f32:
9193   //   Step 1: unpcklps 0, 1 ==> X: <?, ?, 1, 0>
9194   //         : unpcklps 2, 3 ==> Y: <?, ?, 3, 2>
9195   //   Step 2: unpcklpd X, Y ==>    <3, 2, 1, 0>
9196   for (unsigned Scale = 1; Scale < NumElems; Scale *= 2) {
9197     // Generate scaled UNPCKL shuffle mask.
9198     SmallVector<int, 16> Mask;
9199     for (unsigned i = 0; i != Scale; ++i)
9200       Mask.push_back(i);
9201     for (unsigned i = 0; i != Scale; ++i)
9202       Mask.push_back(NumElems+i);
9203     Mask.append(NumElems - Mask.size(), SM_SentinelUndef);
9204 
9205     for (unsigned i = 0, e = NumElems / (2 * Scale); i != e; ++i)
9206       Ops[i] = DAG.getVectorShuffle(VT, dl, Ops[2*i], Ops[(2*i)+1], Mask);
9207   }
9208   return Ops[0];
9209 }
9210 
9211 // 256-bit AVX can use the vinsertf128 instruction
9212 // to create 256-bit vectors from two other 128-bit ones.
9213 // TODO: Detect subvector broadcast here instead of DAG combine?
9214 static SDValue LowerAVXCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG,
9215                                       const X86Subtarget &Subtarget) {
9216   SDLoc dl(Op);
9217   MVT ResVT = Op.getSimpleValueType();
9218 
9219   assert((ResVT.is256BitVector() ||
9220           ResVT.is512BitVector()) && "Value type must be 256-/512-bit wide");
9221 
9222   unsigned NumOperands = Op.getNumOperands();
9223   unsigned NumFreezeUndef = 0;
9224   unsigned NumZero = 0;
9225   unsigned NumNonZero = 0;
9226   unsigned NonZeros = 0;
9227   for (unsigned i = 0; i != NumOperands; ++i) {
9228     SDValue SubVec = Op.getOperand(i);
9229     if (SubVec.isUndef())
9230       continue;
9231     if (ISD::isFreezeUndef(SubVec.getNode())) {
9232       // If the freeze(undef) has multiple uses then we must fold to zero.
9233       if (SubVec.hasOneUse())
9234         ++NumFreezeUndef;
9235       else
9236         ++NumZero;
9237     }
9238     else if (ISD::isBuildVectorAllZeros(SubVec.getNode()))
9239       ++NumZero;
9240     else {
9241       assert(i < sizeof(NonZeros) * CHAR_BIT); // Ensure the shift is in range.
9242       NonZeros |= 1 << i;
9243       ++NumNonZero;
9244     }
9245   }
9246 
9247   // If we have more than 2 non-zeros, build each half separately.
9248   if (NumNonZero > 2) {
9249     MVT HalfVT = ResVT.getHalfNumVectorElementsVT();
9250     ArrayRef<SDUse> Ops = Op->ops();
9251     SDValue Lo = DAG.getNode(ISD::CONCAT_VECTORS, dl, HalfVT,
9252                              Ops.slice(0, NumOperands/2));
9253     SDValue Hi = DAG.getNode(ISD::CONCAT_VECTORS, dl, HalfVT,
9254                              Ops.slice(NumOperands/2));
9255     return DAG.getNode(ISD::CONCAT_VECTORS, dl, ResVT, Lo, Hi);
9256   }
9257 
9258   // Otherwise, build it up through insert_subvectors.
9259   SDValue Vec = NumZero ? getZeroVector(ResVT, Subtarget, DAG, dl)
9260                         : (NumFreezeUndef ? DAG.getFreeze(DAG.getUNDEF(ResVT))
9261                                           : DAG.getUNDEF(ResVT));
9262 
9263   MVT SubVT = Op.getOperand(0).getSimpleValueType();
9264   unsigned NumSubElems = SubVT.getVectorNumElements();
9265   for (unsigned i = 0; i != NumOperands; ++i) {
9266     if ((NonZeros & (1 << i)) == 0)
9267       continue;
9268 
9269     Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, ResVT, Vec,
9270                       Op.getOperand(i),
9271                       DAG.getIntPtrConstant(i * NumSubElems, dl));
9272   }
9273 
9274   return Vec;
9275 }
9276 
9277 // Lower a CONCAT_VECTORS of vXi1 mask vectors, using KSHIFTL /
9278 // INSERT_SUBVECTOR / KUNPCK patterns to avoid emitting redundant mask
9279 // shifts where possible.
9280 // TODO: Merge this with LowerAVXCONCAT_VECTORS?
9281 static SDValue LowerCONCAT_VECTORSvXi1(SDValue Op,
9282                                        const X86Subtarget &Subtarget,
9283                                        SelectionDAG & DAG) {
9284   SDLoc dl(Op);
9285   MVT ResVT = Op.getSimpleValueType();
9286   unsigned NumOperands = Op.getNumOperands();
9287 
9288   assert(NumOperands > 1 && isPowerOf2_32(NumOperands) &&
9289          "Unexpected number of operands in CONCAT_VECTORS");
9290 
9291   uint64_t Zeros = 0;
9292   uint64_t NonZeros = 0;
9293   for (unsigned i = 0; i != NumOperands; ++i) {
9294     SDValue SubVec = Op.getOperand(i);
9295     if (SubVec.isUndef())
9296       continue;
9297     assert(i < sizeof(NonZeros) * CHAR_BIT); // Ensure the shift is in range.
9298     if (ISD::isBuildVectorAllZeros(SubVec.getNode()))
9299       Zeros |= (uint64_t)1 << i;
9300     else
9301       NonZeros |= (uint64_t)1 << i;
9302   }
9303 
9304   unsigned NumElems = ResVT.getVectorNumElements();
9305 
9306   // If we are inserting a non-zero vector and there are zeros in the LSBs and
9307   // undefs in the MSBs, we need to emit a KSHIFTL. The generic lowering to
9308   // insert_subvector will give us two kshifts.
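  // e.g. concat(zeroinitializer, X, undef, undef) with v4i1 operands can widen
  // X and shift it left by 4 bit positions with a single KSHIFTL.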
9309   if (isPowerOf2_64(NonZeros) && Zeros != 0 && NonZeros > Zeros &&
9310       Log2_64(NonZeros) != NumOperands - 1) {
9311     unsigned Idx = Log2_64(NonZeros);
9312     SDValue SubVec = Op.getOperand(Idx);
9313     unsigned SubVecNumElts = SubVec.getSimpleValueType().getVectorNumElements();
9314     MVT ShiftVT = widenMaskVectorType(ResVT, Subtarget);
9315     Op = widenSubVector(ShiftVT, SubVec, false, Subtarget, DAG, dl);
9316     Op = DAG.getNode(X86ISD::KSHIFTL, dl, ShiftVT, Op,
9317                      DAG.getTargetConstant(Idx * SubVecNumElts, dl, MVT::i8));
9318     return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, ResVT, Op,
9319                        DAG.getIntPtrConstant(0, dl));
9320   }
9321 
9322   // If there are zero or one non-zeros we can handle this very simply.
9323   if (NonZeros == 0 || isPowerOf2_64(NonZeros)) {
9324     SDValue Vec = Zeros ? DAG.getConstant(0, dl, ResVT) : DAG.getUNDEF(ResVT);
9325     if (!NonZeros)
9326       return Vec;
9327     unsigned Idx = Log2_64(NonZeros);
9328     SDValue SubVec = Op.getOperand(Idx);
9329     unsigned SubVecNumElts = SubVec.getSimpleValueType().getVectorNumElements();
9330     return DAG.getNode(ISD::INSERT_SUBVECTOR, dl, ResVT, Vec, SubVec,
9331                        DAG.getIntPtrConstant(Idx * SubVecNumElts, dl));
9332   }
9333 
9334   if (NumOperands > 2) {
9335     MVT HalfVT = ResVT.getHalfNumVectorElementsVT();
9336     ArrayRef<SDUse> Ops = Op->ops();
9337     SDValue Lo = DAG.getNode(ISD::CONCAT_VECTORS, dl, HalfVT,
9338                              Ops.slice(0, NumOperands/2));
9339     SDValue Hi = DAG.getNode(ISD::CONCAT_VECTORS, dl, HalfVT,
9340                              Ops.slice(NumOperands/2));
9341     return DAG.getNode(ISD::CONCAT_VECTORS, dl, ResVT, Lo, Hi);
9342   }
9343 
9344   assert(llvm::popcount(NonZeros) == 2 && "Simple cases not handled?");
9345 
9346   if (ResVT.getVectorNumElements() >= 16)
9347     return Op; // The operation is legal with KUNPCK
9348 
9349   SDValue Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, ResVT,
9350                             DAG.getUNDEF(ResVT), Op.getOperand(0),
9351                             DAG.getIntPtrConstant(0, dl));
9352   return DAG.getNode(ISD::INSERT_SUBVECTOR, dl, ResVT, Vec, Op.getOperand(1),
9353                      DAG.getIntPtrConstant(NumElems/2, dl));
9354 }
9355 
9356 static SDValue LowerCONCAT_VECTORS(SDValue Op,
9357                                    const X86Subtarget &Subtarget,
9358                                    SelectionDAG &DAG) {
9359   MVT VT = Op.getSimpleValueType();
9360   if (VT.getVectorElementType() == MVT::i1)
9361     return LowerCONCAT_VECTORSvXi1(Op, Subtarget, DAG);
9362 
9363   assert((VT.is256BitVector() && Op.getNumOperands() == 2) ||
9364          (VT.is512BitVector() && (Op.getNumOperands() == 2 ||
9365           Op.getNumOperands() == 4)));
9366 
9367   // AVX can use the vinsertf128 instruction to create 256-bit vectors
9368   // from two other 128-bit ones.
9369 
9370   // 512-bit vector may contain 2 256-bit vectors or 4 128-bit vectors
9371   return LowerAVXCONCAT_VECTORS(Op, DAG, Subtarget);
9372 }
9373 
9374 //===----------------------------------------------------------------------===//
9375 // Vector shuffle lowering
9376 //
9377 // This is an experimental code path for lowering vector shuffles on x86. It is
9378 // designed to handle arbitrary vector shuffles and blends, gracefully
9379 // degrading performance as necessary. It works hard to recognize idiomatic
9380 // shuffles and lower them to optimal instruction patterns without leaving
9381 // a framework that allows reasonably efficient handling of all vector shuffle
9382 // patterns.
9383 //===----------------------------------------------------------------------===//
9384 
9385 /// Tiny helper function to identify a no-op mask.
9386 ///
9387 /// This is a somewhat boring predicate function. It checks whether the mask
9388 /// array input, which is assumed to be a single-input shuffle mask of the kind
9389 /// used by the X86 shuffle instructions (not a fully general
9390 /// ShuffleVectorSDNode mask) requires any shuffles to occur. Both undef and an
9391 /// in-place shuffle are 'no-op's.
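/// e.g. <0, 1, -1, 3> is a no-op mask, whereas <1, 0, 2, 3> is not.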
9392 static bool isNoopShuffleMask(ArrayRef<int> Mask) {
9393   for (int i = 0, Size = Mask.size(); i < Size; ++i) {
9394     assert(Mask[i] >= -1 && "Out of bound mask element!");
9395     if (Mask[i] >= 0 && Mask[i] != i)
9396       return false;
9397   }
9398   return true;
9399 }
9400 
9401 /// Test whether there are elements crossing LaneSizeInBits lanes in this
9402 /// shuffle mask.
9403 ///
9404 /// X86 divides up its shuffles into in-lane and cross-lane shuffle operations
9405 /// and we routinely test for these.
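/// e.g. with 128-bit lanes, the v8f32 mask <4, 1, 2, 3, 4, 5, 6, 7> crosses
/// lanes because result element 0 reads from lane 1 of the source.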
9406 static bool isLaneCrossingShuffleMask(unsigned LaneSizeInBits,
9407                                       unsigned ScalarSizeInBits,
9408                                       ArrayRef<int> Mask) {
9409   assert(LaneSizeInBits && ScalarSizeInBits &&
9410          (LaneSizeInBits % ScalarSizeInBits) == 0 &&
9411          "Illegal shuffle lane size");
9412   int LaneSize = LaneSizeInBits / ScalarSizeInBits;
9413   int Size = Mask.size();
9414   for (int i = 0; i < Size; ++i)
9415     if (Mask[i] >= 0 && (Mask[i] % Size) / LaneSize != i / LaneSize)
9416       return true;
9417   return false;
9418 }
9419 
9420 /// Test whether there are elements crossing 128-bit lanes in this
9421 /// shuffle mask.
9422 static bool is128BitLaneCrossingShuffleMask(MVT VT, ArrayRef<int> Mask) {
9423   return isLaneCrossingShuffleMask(128, VT.getScalarSizeInBits(), Mask);
9424 }
9425 
9426 /// Test whether elements in each LaneSizeInBits lane in this shuffle mask come
9427 /// from multiple lanes - this is different to isLaneCrossingShuffleMask to
9428 /// better support 'repeated mask + lane permute' style shuffles.
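/// e.g. the v8f32 mask <4, 5, 6, 7, 0, 1, 2, 3> is lane-crossing but not
/// multi-lane (each result lane reads from a single source lane), whereas
/// <0, 4, 1, 5, 2, 6, 3, 7> mixes two source lanes within one result lane.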
9429 static bool isMultiLaneShuffleMask(unsigned LaneSizeInBits,
9430                                    unsigned ScalarSizeInBits,
9431                                    ArrayRef<int> Mask) {
9432   assert(LaneSizeInBits && ScalarSizeInBits &&
9433          (LaneSizeInBits % ScalarSizeInBits) == 0 &&
9434          "Illegal shuffle lane size");
9435   int NumElts = Mask.size();
9436   int NumEltsPerLane = LaneSizeInBits / ScalarSizeInBits;
9437   int NumLanes = NumElts / NumEltsPerLane;
9438   if (NumLanes > 1) {
9439     for (int i = 0; i != NumLanes; ++i) {
9440       int SrcLane = -1;
9441       for (int j = 0; j != NumEltsPerLane; ++j) {
9442         int M = Mask[(i * NumEltsPerLane) + j];
9443         if (M < 0)
9444           continue;
9445         int Lane = (M % NumElts) / NumEltsPerLane;
9446         if (SrcLane >= 0 && SrcLane != Lane)
9447           return true;
9448         SrcLane = Lane;
9449       }
9450     }
9451   }
9452   return false;
9453 }
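// For example, with MVT::v8f32 the mask <4, 5, 6, 7, 0, 1, 2, 3> is
// lane-crossing, but each destination lane still reads from a single source
// lane, so this returns false; <0, 1, 4, 5, 2, 3, 6, 7> mixes both source
// lanes within a destination lane and returns true.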
9454 
9455 /// Test whether a shuffle mask is equivalent within each sub-lane.
9456 ///
9457 /// This checks a shuffle mask to see if it is performing the same
9458 /// lane-relative shuffle in each sub-lane. This trivially implies
9459 /// that it is also not lane-crossing. It may however involve a blend from the
9460 /// same lane of a second vector.
9461 ///
9462 /// The specific repeated shuffle mask is populated in \p RepeatedMask, as it is
9463 /// non-trivial to compute in the face of undef lanes. The representation is
9464 /// suitable for use with existing 128-bit shuffles as entries from the second
9465 /// vector have been remapped to [LaneSize, 2*LaneSize).
9466 static bool isRepeatedShuffleMask(unsigned LaneSizeInBits, MVT VT,
9467                                   ArrayRef<int> Mask,
9468                                   SmallVectorImpl<int> &RepeatedMask) {
9469   auto LaneSize = LaneSizeInBits / VT.getScalarSizeInBits();
9470   RepeatedMask.assign(LaneSize, -1);
9471   int Size = Mask.size();
9472   for (int i = 0; i < Size; ++i) {
9473     assert(Mask[i] == SM_SentinelUndef || Mask[i] >= 0);
9474     if (Mask[i] < 0)
9475       continue;
9476     if ((Mask[i] % Size) / LaneSize != i / LaneSize)
9477       // This entry crosses lanes, so there is no way to model this shuffle.
9478       return false;
9479 
9480     // Ok, handle the in-lane shuffles by detecting if and when they repeat.
9481     // Adjust second vector indices to start at LaneSize instead of Size.
9482     int LocalM = Mask[i] < Size ? Mask[i] % LaneSize
9483                                 : Mask[i] % LaneSize + LaneSize;
9484     if (RepeatedMask[i % LaneSize] < 0)
9485       // This is the first non-undef entry in this slot of a 128-bit lane.
9486       RepeatedMask[i % LaneSize] = LocalM;
9487     else if (RepeatedMask[i % LaneSize] != LocalM)
9488       // Found a mismatch with the repeated mask.
9489       return false;
9490   }
9491   return true;
9492 }
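// For example, with MVT::v8f32 and 128-bit lanes the two-input mask
// <1, 0, 11, 10, 5, 4, 15, 14> repeats per lane and yields the 4-element
// RepeatedMask <1, 0, 7, 6> (second-input entries remapped into [4, 8)).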
9493 
9494 /// Test whether a shuffle mask is equivalent within each 128-bit lane.
9495 static bool
9496 is128BitLaneRepeatedShuffleMask(MVT VT, ArrayRef<int> Mask,
9497                                 SmallVectorImpl<int> &RepeatedMask) {
9498   return isRepeatedShuffleMask(128, VT, Mask, RepeatedMask);
9499 }
9500 
9501 static bool
9502 is128BitLaneRepeatedShuffleMask(MVT VT, ArrayRef<int> Mask) {
9503   SmallVector<int, 32> RepeatedMask;
9504   return isRepeatedShuffleMask(128, VT, Mask, RepeatedMask);
9505 }
9506 
9507 /// Test whether a shuffle mask is equivalent within each 256-bit lane.
9508 static bool
9509 is256BitLaneRepeatedShuffleMask(MVT VT, ArrayRef<int> Mask,
9510                                 SmallVectorImpl<int> &RepeatedMask) {
9511   return isRepeatedShuffleMask(256, VT, Mask, RepeatedMask);
9512 }
9513 
9514 /// Test whether a target shuffle mask is equivalent within each sub-lane.
9515 /// Unlike isRepeatedShuffleMask we must respect SM_SentinelZero.
9516 static bool isRepeatedTargetShuffleMask(unsigned LaneSizeInBits,
9517                                         unsigned EltSizeInBits,
9518                                         ArrayRef<int> Mask,
9519                                         SmallVectorImpl<int> &RepeatedMask) {
9520   int LaneSize = LaneSizeInBits / EltSizeInBits;
9521   RepeatedMask.assign(LaneSize, SM_SentinelUndef);
9522   int Size = Mask.size();
9523   for (int i = 0; i < Size; ++i) {
9524     assert(isUndefOrZero(Mask[i]) || (Mask[i] >= 0));
9525     if (Mask[i] == SM_SentinelUndef)
9526       continue;
9527     if (Mask[i] == SM_SentinelZero) {
9528       if (!isUndefOrZero(RepeatedMask[i % LaneSize]))
9529         return false;
9530       RepeatedMask[i % LaneSize] = SM_SentinelZero;
9531       continue;
9532     }
9533     if ((Mask[i] % Size) / LaneSize != i / LaneSize)
9534       // This entry crosses lanes, so there is no way to model this shuffle.
9535       return false;
9536 
9537     // Handle the in-lane shuffles by detecting if and when they repeat. Adjust
9538     // later vector indices to start at multiples of LaneSize instead of Size.
9539     int LaneM = Mask[i] / Size;
9540     int LocalM = (Mask[i] % LaneSize) + (LaneM * LaneSize);
9541     if (RepeatedMask[i % LaneSize] == SM_SentinelUndef)
9542       // This is the first non-undef entry in this slot of a 128-bit lane.
9543       RepeatedMask[i % LaneSize] = LocalM;
9544     else if (RepeatedMask[i % LaneSize] != LocalM)
9545       // Found a mismatch with the repeated mask.
9546       return false;
9547   }
9548   return true;
9549 }
9550 
9551 /// Test whether a target shuffle mask is equivalent within each sub-lane.
9552 /// Unlike isRepeatedShuffleMask we must respect SM_SentinelZero.
9553 static bool isRepeatedTargetShuffleMask(unsigned LaneSizeInBits, MVT VT,
9554                                         ArrayRef<int> Mask,
9555                                         SmallVectorImpl<int> &RepeatedMask) {
9556   return isRepeatedTargetShuffleMask(LaneSizeInBits, VT.getScalarSizeInBits(),
9557                                      Mask, RepeatedMask);
9558 }
9559 
9560 /// Checks whether the vector elements referenced by two shuffle masks are
9561 /// equivalent.
9562 static bool IsElementEquivalent(int MaskSize, SDValue Op, SDValue ExpectedOp,
9563                                 int Idx, int ExpectedIdx) {
9564   assert(0 <= Idx && Idx < MaskSize && 0 <= ExpectedIdx &&
9565          ExpectedIdx < MaskSize && "Out of range element index");
9566   if (!Op || !ExpectedOp || Op.getOpcode() != ExpectedOp.getOpcode())
9567     return false;
9568 
9569   switch (Op.getOpcode()) {
9570   case ISD::BUILD_VECTOR:
9571     // If the values are build vectors, we can look through them to find
9572     // equivalent inputs that make the shuffles equivalent.
9573     // TODO: Handle MaskSize != Op.getNumOperands()?
9574     if (MaskSize == (int)Op.getNumOperands() &&
9575         MaskSize == (int)ExpectedOp.getNumOperands())
9576       return Op.getOperand(Idx) == ExpectedOp.getOperand(ExpectedIdx);
9577     break;
9578   case X86ISD::VBROADCAST:
9579   case X86ISD::VBROADCAST_LOAD:
9580     // TODO: Handle MaskSize != Op.getValueType().getVectorNumElements()?
9581     return (Op == ExpectedOp &&
9582             (int)Op.getValueType().getVectorNumElements() == MaskSize);
9583   case X86ISD::HADD:
9584   case X86ISD::HSUB:
9585   case X86ISD::FHADD:
9586   case X86ISD::FHSUB:
9587   case X86ISD::PACKSS:
9588   case X86ISD::PACKUS:
9589     // HOP(X,X) can refer to the elt from the lower/upper half of a lane.
9590     // TODO: Handle MaskSize != NumElts?
9591     // TODO: Handle HOP(X,Y) vs HOP(Y,X) equivalence cases.
9592     if (Op == ExpectedOp && Op.getOperand(0) == Op.getOperand(1)) {
9593       MVT VT = Op.getSimpleValueType();
9594       int NumElts = VT.getVectorNumElements();
9595       if (MaskSize == NumElts) {
9596         int NumLanes = VT.getSizeInBits() / 128;
9597         int NumEltsPerLane = NumElts / NumLanes;
9598         int NumHalfEltsPerLane = NumEltsPerLane / 2;
9599         bool SameLane =
9600             (Idx / NumEltsPerLane) == (ExpectedIdx / NumEltsPerLane);
9601         bool SameElt =
9602             (Idx % NumHalfEltsPerLane) == (ExpectedIdx % NumHalfEltsPerLane);
9603         return SameLane && SameElt;
9604       }
9605     }
9606     break;
9607   }
9608 
9609   return false;
9610 }
9611 
9612 /// Checks whether a shuffle mask is equivalent to an explicit list of
9613 /// arguments.
9614 ///
9615 /// This is a fast way to test a shuffle mask against a fixed pattern:
9616 ///
9617 ///   if (isShuffleEquivalent(Mask, 3, 2, {1, 0})) { ... }
9618 ///
9619 /// It returns true if the mask is exactly as wide as the argument list, and
9620 /// each element of the mask is either -1 (signifying undef) or the value given
9621 /// in the argument.
9622 static bool isShuffleEquivalent(ArrayRef<int> Mask, ArrayRef<int> ExpectedMask,
9623                                 SDValue V1 = SDValue(),
9624                                 SDValue V2 = SDValue()) {
9625   int Size = Mask.size();
9626   if (Size != (int)ExpectedMask.size())
9627     return false;
9628 
9629   for (int i = 0; i < Size; ++i) {
9630     assert(Mask[i] >= -1 && "Out of bound mask element!");
9631     int MaskIdx = Mask[i];
9632     int ExpectedIdx = ExpectedMask[i];
9633     if (0 <= MaskIdx && MaskIdx != ExpectedIdx) {
9634       SDValue MaskV = MaskIdx < Size ? V1 : V2;
9635       SDValue ExpectedV = ExpectedIdx < Size ? V1 : V2;
9636       MaskIdx = MaskIdx < Size ? MaskIdx : (MaskIdx - Size);
9637       ExpectedIdx = ExpectedIdx < Size ? ExpectedIdx : (ExpectedIdx - Size);
9638       if (!IsElementEquivalent(Size, MaskV, ExpectedV, MaskIdx, ExpectedIdx))
9639         return false;
9640     }
9641   }
9642   return true;
9643 }
9644 
9645 /// Checks whether a target shuffle mask is equivalent to an explicit pattern.
9646 ///
9647 /// The masks must be exactly the same width.
9648 ///
9649 /// If an element in Mask matches SM_SentinelUndef (-1) then the corresponding
9650 /// value in ExpectedMask is always accepted. Otherwise the indices must match.
9651 ///
9652 /// SM_SentinelZero is accepted as a valid negative index but must match in
9653 /// both, or via a known bits test.
9654 static bool isTargetShuffleEquivalent(MVT VT, ArrayRef<int> Mask,
9655                                       ArrayRef<int> ExpectedMask,
9656                                       const SelectionDAG &DAG,
9657                                       SDValue V1 = SDValue(),
9658                                       SDValue V2 = SDValue()) {
9659   int Size = Mask.size();
9660   if (Size != (int)ExpectedMask.size())
9661     return false;
9662   assert(llvm::all_of(ExpectedMask,
9663                       [Size](int M) { return isInRange(M, 0, 2 * Size); }) &&
9664          "Illegal target shuffle mask");
9665 
9666   // Check for out-of-range target shuffle mask indices.
9667   if (!isUndefOrZeroOrInRange(Mask, 0, 2 * Size))
9668     return false;
9669 
9670   // Don't use V1/V2 if they're not the same size as the shuffle mask type.
9671   if (V1 && (V1.getValueSizeInBits() != VT.getSizeInBits() ||
9672              !V1.getValueType().isVector()))
9673     V1 = SDValue();
9674   if (V2 && (V2.getValueSizeInBits() != VT.getSizeInBits() ||
9675              !V2.getValueType().isVector()))
9676     V2 = SDValue();
9677 
9678   APInt ZeroV1 = APInt::getZero(Size);
9679   APInt ZeroV2 = APInt::getZero(Size);
9680 
9681   for (int i = 0; i < Size; ++i) {
9682     int MaskIdx = Mask[i];
9683     int ExpectedIdx = ExpectedMask[i];
9684     if (MaskIdx == SM_SentinelUndef || MaskIdx == ExpectedIdx)
9685       continue;
9686     if (MaskIdx == SM_SentinelZero) {
9687       // If we need this expected index to be a zero element, then update the
9688       // relevant zero mask and perform the known-bits check at the end to
9689       // minimize repeated computations.
9690       SDValue ExpectedV = ExpectedIdx < Size ? V1 : V2;
9691       if (ExpectedV &&
9692           Size == (int)ExpectedV.getValueType().getVectorNumElements()) {
9693         int BitIdx = ExpectedIdx < Size ? ExpectedIdx : (ExpectedIdx - Size);
9694         APInt &ZeroMask = ExpectedIdx < Size ? ZeroV1 : ZeroV2;
9695         ZeroMask.setBit(BitIdx);
9696         continue;
9697       }
9698     }
9699     if (MaskIdx >= 0) {
9700       SDValue MaskV = MaskIdx < Size ? V1 : V2;
9701       SDValue ExpectedV = ExpectedIdx < Size ? V1 : V2;
9702       MaskIdx = MaskIdx < Size ? MaskIdx : (MaskIdx - Size);
9703       ExpectedIdx = ExpectedIdx < Size ? ExpectedIdx : (ExpectedIdx - Size);
9704       if (IsElementEquivalent(Size, MaskV, ExpectedV, MaskIdx, ExpectedIdx))
9705         continue;
9706     }
9707     return false;
9708   }
9709   return (ZeroV1.isZero() || DAG.MaskedVectorIsZero(V1, ZeroV1)) &&
9710          (ZeroV2.isZero() || DAG.MaskedVectorIsZero(V2, ZeroV2));
9711 }
9712 
9713 // Check if the shuffle mask is suitable for the AVX vpunpcklwd or vpunpckhwd
9714 // instructions.
9715 static bool isUnpackWdShuffleMask(ArrayRef<int> Mask, MVT VT,
9716                                   const SelectionDAG &DAG) {
9717   if (VT != MVT::v8i32 && VT != MVT::v8f32)
9718     return false;
9719 
9720   SmallVector<int, 8> Unpcklwd;
9721   createUnpackShuffleMask(MVT::v8i16, Unpcklwd, /* Lo = */ true,
9722                           /* Unary = */ false);
9723   SmallVector<int, 8> Unpckhwd;
9724   createUnpackShuffleMask(MVT::v8i16, Unpckhwd, /* Lo = */ false,
9725                           /* Unary = */ false);
9726   bool IsUnpackwdMask = (isTargetShuffleEquivalent(VT, Mask, Unpcklwd, DAG) ||
9727                          isTargetShuffleEquivalent(VT, Mask, Unpckhwd, DAG));
9728   return IsUnpackwdMask;
9729 }
9730 
9731 static bool is128BitUnpackShuffleMask(ArrayRef<int> Mask,
9732                                       const SelectionDAG &DAG) {
9733   // Create 128-bit vector type based on mask size.
9734   MVT EltVT = MVT::getIntegerVT(128 / Mask.size());
9735   MVT VT = MVT::getVectorVT(EltVT, Mask.size());
9736 
9737   // We can't assume a canonical shuffle mask, so try the commuted version too.
9738   SmallVector<int, 4> CommutedMask(Mask);
9739   ShuffleVectorSDNode::commuteMask(CommutedMask);
9740 
9741   // Match any of unary/binary or low/high.
9742   for (unsigned i = 0; i != 4; ++i) {
9743     SmallVector<int, 16> UnpackMask;
9744     createUnpackShuffleMask(VT, UnpackMask, (i >> 1) % 2, i % 2);
9745     if (isTargetShuffleEquivalent(VT, Mask, UnpackMask, DAG) ||
9746         isTargetShuffleEquivalent(VT, CommutedMask, UnpackMask, DAG))
9747       return true;
9748   }
9749   return false;
9750 }
9751 
9752 /// Return true if a shuffle mask chooses elements identically in its top and
9753 /// bottom halves. For example, any splat mask has the same top and bottom
9754 /// halves. If an element is undefined in only one half of the mask, the halves
9755 /// are not considered identical.
9756 static bool hasIdenticalHalvesShuffleMask(ArrayRef<int> Mask) {
9757   assert(Mask.size() % 2 == 0 && "Expecting even number of elements in mask");
9758   unsigned HalfSize = Mask.size() / 2;
9759   for (unsigned i = 0; i != HalfSize; ++i) {
9760     if (Mask[i] != Mask[i + HalfSize])
9761       return false;
9762   }
9763   return true;
9764 }
9765 
9766 /// Get a 4-lane 8-bit shuffle immediate for a mask.
9767 ///
9768 /// This helper function produces an 8-bit shuffle immediate corresponding to
9769 /// the ubiquitous shuffle encoding scheme used in x86 instructions for
9770 /// shuffling 4 lanes. It can be used with most of the PSHUF instructions for
9771 /// example.
9772 ///
9773 /// NB: We rely heavily on "undef" masks preserving the input lane.
9774 static unsigned getV4X86ShuffleImm(ArrayRef<int> Mask) {
9775   assert(Mask.size() == 4 && "Only 4-lane shuffle masks");
9776   assert(Mask[0] >= -1 && Mask[0] < 4 && "Out of bound mask element!");
9777   assert(Mask[1] >= -1 && Mask[1] < 4 && "Out of bound mask element!");
9778   assert(Mask[2] >= -1 && Mask[2] < 4 && "Out of bound mask element!");
9779   assert(Mask[3] >= -1 && Mask[3] < 4 && "Out of bound mask element!");
9780 
9781   // If the mask only uses one non-undef element, then fully 'splat' it to
9782   // improve later broadcast matching.
9783   int FirstIndex = find_if(Mask, [](int M) { return M >= 0; }) - Mask.begin();
9784   assert(0 <= FirstIndex && FirstIndex < 4 && "All undef shuffle mask");
9785 
9786   int FirstElt = Mask[FirstIndex];
9787   if (all_of(Mask, [FirstElt](int M) { return M < 0 || M == FirstElt; }))
9788     return (FirstElt << 6) | (FirstElt << 4) | (FirstElt << 2) | FirstElt;
9789 
9790   unsigned Imm = 0;
9791   Imm |= (Mask[0] < 0 ? 0 : Mask[0]) << 0;
9792   Imm |= (Mask[1] < 0 ? 1 : Mask[1]) << 2;
9793   Imm |= (Mask[2] < 0 ? 2 : Mask[2]) << 4;
9794   Imm |= (Mask[3] < 0 ? 3 : Mask[3]) << 6;
9795   return Imm;
9796 }
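// For example, the mask <3, 1, 2, 0> encodes as 3 | (1 << 2) | (2 << 4) |
// (0 << 6) = 0x27, and the single-element mask <1, -1, -1, -1> is splatted to
// 0x55 so that all four selectors pick element 1.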
9797 
9798 static SDValue getV4X86ShuffleImm8ForMask(ArrayRef<int> Mask, const SDLoc &DL,
9799                                           SelectionDAG &DAG) {
9800   return DAG.getTargetConstant(getV4X86ShuffleImm(Mask), DL, MVT::i8);
9801 }
9802 
9803 // The shuffle result is as follows:
9804 // 0*a[0] 0*a[1] ... 0*a[n], n >= 0, where the a[] elements appear in ascending order.
9805 // Each element of Zeroable corresponds to a particular element of Mask,
9806 // as described in the computeZeroableShuffleElements function.
9807 //
9808 // The function looks for a sub-mask whose nonzero elements are in
9809 // increasing order; if such a sub-mask exists, the function returns true.
9810 static bool isNonZeroElementsInOrder(const APInt &Zeroable,
9811                                      ArrayRef<int> Mask, const EVT &VectorType,
9812                                      bool &IsZeroSideLeft) {
9813   int NextElement = -1;
9814   // Check if the Mask's nonzero elements are in increasing order.
9815   for (int i = 0, e = Mask.size(); i < e; i++) {
9816     // Check that the mask's zero elements are built from only zeros.
9817     assert(Mask[i] >= -1 && "Out of bound mask element!");
9818     if (Mask[i] < 0)
9819       return false;
9820     if (Zeroable[i])
9821       continue;
9822     // Find the lowest non-zero element.
9823     if (NextElement < 0) {
9824       NextElement = Mask[i] != 0 ? VectorType.getVectorNumElements() : 0;
9825       IsZeroSideLeft = NextElement != 0;
9826     }
9827     // Exit if the mask's non-zero elements are not in increasing order.
9828     if (NextElement != Mask[i])
9829       return false;
9830     NextElement++;
9831   }
9832   return true;
9833 }
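// For example, with a v4i32 type, the mask <0, 4, 1, 2> with only element 1
// zeroable has its non-zero elements 0, 1, 2 in increasing order, so this
// returns true with IsZeroSideLeft == false (the non-zero elements come from
// the first input).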
9834 
9835 /// Try to lower a shuffle with a single PSHUFB of V1 or V2.
9836 static SDValue lowerShuffleWithPSHUFB(const SDLoc &DL, MVT VT,
9837                                       ArrayRef<int> Mask, SDValue V1,
9838                                       SDValue V2, const APInt &Zeroable,
9839                                       const X86Subtarget &Subtarget,
9840                                       SelectionDAG &DAG) {
9841   int Size = Mask.size();
9842   int LaneSize = 128 / VT.getScalarSizeInBits();
9843   const int NumBytes = VT.getSizeInBits() / 8;
9844   const int NumEltBytes = VT.getScalarSizeInBits() / 8;
9845 
9846   assert((Subtarget.hasSSSE3() && VT.is128BitVector()) ||
9847          (Subtarget.hasAVX2() && VT.is256BitVector()) ||
9848          (Subtarget.hasBWI() && VT.is512BitVector()));
9849 
9850   SmallVector<SDValue, 64> PSHUFBMask(NumBytes);
9851   // Sign bit set in i8 mask means zero element.
9852   SDValue ZeroMask = DAG.getConstant(0x80, DL, MVT::i8);
9853 
9854   SDValue V;
9855   for (int i = 0; i < NumBytes; ++i) {
9856     int M = Mask[i / NumEltBytes];
9857     if (M < 0) {
9858       PSHUFBMask[i] = DAG.getUNDEF(MVT::i8);
9859       continue;
9860     }
9861     if (Zeroable[i / NumEltBytes]) {
9862       PSHUFBMask[i] = ZeroMask;
9863       continue;
9864     }
9865 
9866     // We can only use a single input of V1 or V2.
9867     SDValue SrcV = (M >= Size ? V2 : V1);
9868     if (V && V != SrcV)
9869       return SDValue();
9870     V = SrcV;
9871     M %= Size;
9872 
9873     // PSHUFB can't cross lanes; ensure this doesn't happen.
9874     if ((M / LaneSize) != ((i / NumEltBytes) / LaneSize))
9875       return SDValue();
9876 
9877     M = M % LaneSize;
9878     M = M * NumEltBytes + (i % NumEltBytes);
9879     PSHUFBMask[i] = DAG.getConstant(M, DL, MVT::i8);
9880   }
9881   assert(V && "Failed to find a source input");
9882 
9883   MVT I8VT = MVT::getVectorVT(MVT::i8, NumBytes);
9884   return DAG.getBitcast(
9885       VT, DAG.getNode(X86ISD::PSHUFB, DL, I8VT, DAG.getBitcast(I8VT, V),
9886                       DAG.getBuildVector(I8VT, DL, PSHUFBMask)));
9887 }
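// For example, a v8i16 shuffle with the single-input mask
// <2, 3, 0, 1, 6, 7, 4, 5> and no zeroable elements expands to the PSHUFB
// byte mask <4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15, 8, 9, 10, 11>.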
9888 
9889 static SDValue getMaskNode(SDValue Mask, MVT MaskVT,
9890                            const X86Subtarget &Subtarget, SelectionDAG &DAG,
9891                            const SDLoc &dl);
9892 
9893 // X86 has a dedicated shuffle that can be lowered to VEXPAND.
9894 static SDValue lowerShuffleToEXPAND(const SDLoc &DL, MVT VT,
9895                                     const APInt &Zeroable,
9896                                     ArrayRef<int> Mask, SDValue &V1,
9897                                     SDValue &V2, SelectionDAG &DAG,
9898                                     const X86Subtarget &Subtarget) {
9899   bool IsLeftZeroSide = true;
9900   if (!isNonZeroElementsInOrder(Zeroable, Mask, V1.getValueType(),
9901                                 IsLeftZeroSide))
9902     return SDValue();
9903   unsigned VEXPANDMask = (~Zeroable).getZExtValue();
9904   MVT IntegerType =
9905       MVT::getIntegerVT(std::max((int)VT.getVectorNumElements(), 8));
9906   SDValue MaskNode = DAG.getConstant(VEXPANDMask, DL, IntegerType);
9907   unsigned NumElts = VT.getVectorNumElements();
9908   assert((NumElts == 4 || NumElts == 8 || NumElts == 16) &&
9909          "Unexpected number of vector elements");
9910   SDValue VMask = getMaskNode(MaskNode, MVT::getVectorVT(MVT::i1, NumElts),
9911                               Subtarget, DAG, DL);
9912   SDValue ZeroVector = getZeroVector(VT, Subtarget, DAG, DL);
9913   SDValue ExpandedVector = IsLeftZeroSide ? V2 : V1;
9914   return DAG.getNode(X86ISD::EXPAND, DL, VT, ExpandedVector, ZeroVector, VMask);
9915 }
9916 
9917 static bool matchShuffleWithUNPCK(MVT VT, SDValue &V1, SDValue &V2,
9918                                   unsigned &UnpackOpcode, bool IsUnary,
9919                                   ArrayRef<int> TargetMask, const SDLoc &DL,
9920                                   SelectionDAG &DAG,
9921                                   const X86Subtarget &Subtarget) {
9922   int NumElts = VT.getVectorNumElements();
9923 
9924   bool Undef1 = true, Undef2 = true, Zero1 = true, Zero2 = true;
9925   for (int i = 0; i != NumElts; i += 2) {
9926     int M1 = TargetMask[i + 0];
9927     int M2 = TargetMask[i + 1];
9928     Undef1 &= (SM_SentinelUndef == M1);
9929     Undef2 &= (SM_SentinelUndef == M2);
9930     Zero1 &= isUndefOrZero(M1);
9931     Zero2 &= isUndefOrZero(M2);
9932   }
9933   assert(!((Undef1 || Zero1) && (Undef2 || Zero2)) &&
9934          "Zeroable shuffle detected");
9935 
9936   // Attempt to match the target mask against the unpack lo/hi mask patterns.
9937   SmallVector<int, 64> Unpckl, Unpckh;
9938   createUnpackShuffleMask(VT, Unpckl, /* Lo = */ true, IsUnary);
9939   if (isTargetShuffleEquivalent(VT, TargetMask, Unpckl, DAG, V1,
9940                                 (IsUnary ? V1 : V2))) {
9941     UnpackOpcode = X86ISD::UNPCKL;
9942     V2 = (Undef2 ? DAG.getUNDEF(VT) : (IsUnary ? V1 : V2));
9943     V1 = (Undef1 ? DAG.getUNDEF(VT) : V1);
9944     return true;
9945   }
9946 
9947   createUnpackShuffleMask(VT, Unpckh, /* Lo = */ false, IsUnary);
9948   if (isTargetShuffleEquivalent(VT, TargetMask, Unpckh, DAG, V1,
9949                                 (IsUnary ? V1 : V2))) {
9950     UnpackOpcode = X86ISD::UNPCKH;
9951     V2 = (Undef2 ? DAG.getUNDEF(VT) : (IsUnary ? V1 : V2));
9952     V1 = (Undef1 ? DAG.getUNDEF(VT) : V1);
9953     return true;
9954   }
9955 
9956   // If an unary shuffle, attempt to match as an unpack lo/hi with zero.
9957   if (IsUnary && (Zero1 || Zero2)) {
9958     // Don't bother if we can blend instead.
9959     if ((Subtarget.hasSSE41() || VT == MVT::v2i64 || VT == MVT::v2f64) &&
9960         isSequentialOrUndefOrZeroInRange(TargetMask, 0, NumElts, 0))
9961       return false;
9962 
9963     bool MatchLo = true, MatchHi = true;
9964     for (int i = 0; (i != NumElts) && (MatchLo || MatchHi); ++i) {
9965       int M = TargetMask[i];
9966 
9967       // Ignore if the input is known to be zero or the index is undef.
9968       if ((((i & 1) == 0) && Zero1) || (((i & 1) == 1) && Zero2) ||
9969           (M == SM_SentinelUndef))
9970         continue;
9971 
9972       MatchLo &= (M == Unpckl[i]);
9973       MatchHi &= (M == Unpckh[i]);
9974     }
9975 
9976     if (MatchLo || MatchHi) {
9977       UnpackOpcode = MatchLo ? X86ISD::UNPCKL : X86ISD::UNPCKH;
9978       V2 = Zero2 ? getZeroVector(VT, Subtarget, DAG, DL) : V1;
9979       V1 = Zero1 ? getZeroVector(VT, Subtarget, DAG, DL) : V1;
9980       return true;
9981     }
9982   }
9983 
9984   // If a binary shuffle, commute and try again.
9985   if (!IsUnary) {
9986     ShuffleVectorSDNode::commuteMask(Unpckl);
9987     if (isTargetShuffleEquivalent(VT, TargetMask, Unpckl, DAG)) {
9988       UnpackOpcode = X86ISD::UNPCKL;
9989       std::swap(V1, V2);
9990       return true;
9991     }
9992 
9993     ShuffleVectorSDNode::commuteMask(Unpckh);
9994     if (isTargetShuffleEquivalent(VT, TargetMask, Unpckh, DAG)) {
9995       UnpackOpcode = X86ISD::UNPCKH;
9996       std::swap(V1, V2);
9997       return true;
9998     }
9999   }
10000 
10001   return false;
10002 }
10003 
10004 // X86 has dedicated unpack instructions that can handle specific blend
10005 // operations: UNPCKH and UNPCKL.
10006 static SDValue lowerShuffleWithUNPCK(const SDLoc &DL, MVT VT,
10007                                      ArrayRef<int> Mask, SDValue V1, SDValue V2,
10008                                      SelectionDAG &DAG) {
10009   SmallVector<int, 8> Unpckl;
10010   createUnpackShuffleMask(VT, Unpckl, /* Lo = */ true, /* Unary = */ false);
10011   if (isShuffleEquivalent(Mask, Unpckl, V1, V2))
10012     return DAG.getNode(X86ISD::UNPCKL, DL, VT, V1, V2);
10013 
10014   SmallVector<int, 8> Unpckh;
10015   createUnpackShuffleMask(VT, Unpckh, /* Lo = */ false, /* Unary = */ false);
10016   if (isShuffleEquivalent(Mask, Unpckh, V1, V2))
10017     return DAG.getNode(X86ISD::UNPCKH, DL, VT, V1, V2);
10018 
10019   // Commute and try again.
10020   ShuffleVectorSDNode::commuteMask(Unpckl);
10021   if (isShuffleEquivalent(Mask, Unpckl, V1, V2))
10022     return DAG.getNode(X86ISD::UNPCKL, DL, VT, V2, V1);
10023 
10024   ShuffleVectorSDNode::commuteMask(Unpckh);
10025   if (isShuffleEquivalent(Mask, Unpckh, V1, V2))
10026     return DAG.getNode(X86ISD::UNPCKH, DL, VT, V2, V1);
10027 
10028   return SDValue();
10029 }
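// For example, for v4i32 the binary mask <0, 4, 1, 5> matches UNPCKL and
// <2, 6, 3, 7> matches UNPCKH; the commuted forms <4, 0, 5, 1> and
// <6, 2, 7, 3> are matched by swapping the two inputs.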
10030 
10031 /// Check if the mask can be mapped to a preliminary shuffle (vperm 64-bit)
10032 /// followed by unpack 256-bit.
10033 static SDValue lowerShuffleWithUNPCK256(const SDLoc &DL, MVT VT,
10034                                         ArrayRef<int> Mask, SDValue V1,
10035                                         SDValue V2, SelectionDAG &DAG) {
10036   SmallVector<int, 32> Unpckl, Unpckh;
10037   createSplat2ShuffleMask(VT, Unpckl, /* Lo */ true);
10038   createSplat2ShuffleMask(VT, Unpckh, /* Lo */ false);
10039 
10040   unsigned UnpackOpcode;
10041   if (isShuffleEquivalent(Mask, Unpckl, V1, V2))
10042     UnpackOpcode = X86ISD::UNPCKL;
10043   else if (isShuffleEquivalent(Mask, Unpckh, V1, V2))
10044     UnpackOpcode = X86ISD::UNPCKH;
10045   else
10046     return SDValue();
10047 
10048   // This is a "natural" unpack operation (rather than the 128-bit sectored
10049   // operation implemented by AVX). We need to rearrange 64-bit chunks of the
10050   // input in order to use the x86 instruction.
10051   V1 = DAG.getVectorShuffle(MVT::v4f64, DL, DAG.getBitcast(MVT::v4f64, V1),
10052                             DAG.getUNDEF(MVT::v4f64), {0, 2, 1, 3});
10053   V1 = DAG.getBitcast(VT, V1);
10054   return DAG.getNode(UnpackOpcode, DL, VT, V1, V1);
10055 }
10056 
10057 // Check if the mask can be mapped to a TRUNCATE or VTRUNC, truncating the
10058 // source into the lower elements and zeroing the upper elements.
10059 static bool matchShuffleAsVTRUNC(MVT &SrcVT, MVT &DstVT, MVT VT,
10060                                  ArrayRef<int> Mask, const APInt &Zeroable,
10061                                  const X86Subtarget &Subtarget) {
10062   if (!VT.is512BitVector() && !Subtarget.hasVLX())
10063     return false;
10064 
10065   unsigned NumElts = Mask.size();
10066   unsigned EltSizeInBits = VT.getScalarSizeInBits();
10067   unsigned MaxScale = 64 / EltSizeInBits;
10068 
10069   for (unsigned Scale = 2; Scale <= MaxScale; Scale += Scale) {
10070     unsigned SrcEltBits = EltSizeInBits * Scale;
10071     if (SrcEltBits < 32 && !Subtarget.hasBWI())
10072       continue;
10073     unsigned NumSrcElts = NumElts / Scale;
10074     if (!isSequentialOrUndefInRange(Mask, 0, NumSrcElts, 0, Scale))
10075       continue;
10076     unsigned UpperElts = NumElts - NumSrcElts;
10077     if (!Zeroable.extractBits(UpperElts, NumSrcElts).isAllOnes())
10078       continue;
10079     SrcVT = MVT::getIntegerVT(EltSizeInBits * Scale);
10080     SrcVT = MVT::getVectorVT(SrcVT, NumSrcElts);
10081     DstVT = MVT::getIntegerVT(EltSizeInBits);
10082     if ((NumSrcElts * EltSizeInBits) >= 128) {
10083       // ISD::TRUNCATE
10084       DstVT = MVT::getVectorVT(DstVT, NumSrcElts);
10085     } else {
10086       // X86ISD::VTRUNC
10087       DstVT = MVT::getVectorVT(DstVT, 128 / EltSizeInBits);
10088     }
10089     return true;
10090   }
10091 
10092   return false;
10093 }
10094 
10095 // Helper to create TRUNCATE/VTRUNC nodes, optionally with zero/undef upper
10096 // element padding to the final DstVT.
10097 static SDValue getAVX512TruncNode(const SDLoc &DL, MVT DstVT, SDValue Src,
10098                                   const X86Subtarget &Subtarget,
10099                                   SelectionDAG &DAG, bool ZeroUppers) {
10100   MVT SrcVT = Src.getSimpleValueType();
10101   MVT DstSVT = DstVT.getScalarType();
10102   unsigned NumDstElts = DstVT.getVectorNumElements();
10103   unsigned NumSrcElts = SrcVT.getVectorNumElements();
10104   unsigned DstEltSizeInBits = DstVT.getScalarSizeInBits();
10105 
10106   if (!DAG.getTargetLoweringInfo().isTypeLegal(SrcVT))
10107     return SDValue();
10108 
10109   // Perform a direct ISD::TRUNCATE if possible.
10110   if (NumSrcElts == NumDstElts)
10111     return DAG.getNode(ISD::TRUNCATE, DL, DstVT, Src);
10112 
10113   if (NumSrcElts > NumDstElts) {
10114     MVT TruncVT = MVT::getVectorVT(DstSVT, NumSrcElts);
10115     SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, Src);
10116     return extractSubVector(Trunc, 0, DAG, DL, DstVT.getSizeInBits());
10117   }
10118 
10119   if ((NumSrcElts * DstEltSizeInBits) >= 128) {
10120     MVT TruncVT = MVT::getVectorVT(DstSVT, NumSrcElts);
10121     SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, Src);
10122     return widenSubVector(Trunc, ZeroUppers, Subtarget, DAG, DL,
10123                           DstVT.getSizeInBits());
10124   }
10125 
10126   // Non-VLX targets must truncate from a 512-bit type, so we need to
10127   // widen, truncate and then possibly extract the original subvector.
10128   if (!Subtarget.hasVLX() && !SrcVT.is512BitVector()) {
10129     SDValue NewSrc = widenSubVector(Src, ZeroUppers, Subtarget, DAG, DL, 512);
10130     return getAVX512TruncNode(DL, DstVT, NewSrc, Subtarget, DAG, ZeroUppers);
10131   }
10132 
10133   // Fallback to a X86ISD::VTRUNC, padding if necessary.
10134   MVT TruncVT = MVT::getVectorVT(DstSVT, 128 / DstEltSizeInBits);
10135   SDValue Trunc = DAG.getNode(X86ISD::VTRUNC, DL, TruncVT, Src);
10136   if (DstVT != TruncVT)
10137     Trunc = widenSubVector(Trunc, ZeroUppers, Subtarget, DAG, DL,
10138                            DstVT.getSizeInBits());
10139   return Trunc;
10140 }
10141 
10142 // Try to lower trunc+vector_shuffle to a vpmovdb or a vpmovdw instruction.
10143 //
10144 // An example is the following:
10145 //
10146 // t0: ch = EntryToken
10147 //           t2: v4i64,ch = CopyFromReg t0, Register:v4i64 %0
10148 //         t25: v4i32 = truncate t2
10149 //       t41: v8i16 = bitcast t25
10150 //       t21: v8i16 = BUILD_VECTOR undef:i16, undef:i16, undef:i16, undef:i16,
10151 //       Constant:i16<0>, Constant:i16<0>, Constant:i16<0>, Constant:i16<0>
10152 //     t51: v8i16 = vector_shuffle<0,2,4,6,12,13,14,15> t41, t21
10153 //   t18: v2i64 = bitcast t51
10154 //
10155 // One can just use a single vpmovdw instruction; without avx512vl we need to
10156 // use the zmm variant and extract the lower subvector, padding with zeroes.
10157 // TODO: Merge with lowerShuffleAsVTRUNC.
10158 static SDValue lowerShuffleWithVPMOV(const SDLoc &DL, MVT VT, SDValue V1,
10159                                      SDValue V2, ArrayRef<int> Mask,
10160                                      const APInt &Zeroable,
10161                                      const X86Subtarget &Subtarget,
10162                                      SelectionDAG &DAG) {
10163   assert((VT == MVT::v16i8 || VT == MVT::v8i16) && "Unexpected VTRUNC type");
10164   if (!Subtarget.hasAVX512())
10165     return SDValue();
10166 
10167   unsigned NumElts = VT.getVectorNumElements();
10168   unsigned EltSizeInBits = VT.getScalarSizeInBits();
10169   unsigned MaxScale = 64 / EltSizeInBits;
10170   for (unsigned Scale = 2; Scale <= MaxScale; Scale += Scale) {
10171     unsigned SrcEltBits = EltSizeInBits * Scale;
10172     unsigned NumSrcElts = NumElts / Scale;
10173     unsigned UpperElts = NumElts - NumSrcElts;
10174     if (!isSequentialOrUndefInRange(Mask, 0, NumSrcElts, 0, Scale) ||
10175         !Zeroable.extractBits(UpperElts, NumSrcElts).isAllOnes())
10176       continue;
10177 
10178     // Attempt to find a matching source truncation, but as a fallback VLX
10179     // cases can use the VPMOV directly.
10180     SDValue Src = peekThroughBitcasts(V1);
10181     if (Src.getOpcode() == ISD::TRUNCATE &&
10182         Src.getScalarValueSizeInBits() == SrcEltBits) {
10183       Src = Src.getOperand(0);
10184     } else if (Subtarget.hasVLX()) {
10185       MVT SrcSVT = MVT::getIntegerVT(SrcEltBits);
10186       MVT SrcVT = MVT::getVectorVT(SrcSVT, NumSrcElts);
10187       Src = DAG.getBitcast(SrcVT, Src);
10188       // Don't do this if PACKSS/PACKUS could perform it cheaper.
10189       if (Scale == 2 &&
10190           ((DAG.ComputeNumSignBits(Src) > EltSizeInBits) ||
10191            (DAG.computeKnownBits(Src).countMinLeadingZeros() >= EltSizeInBits)))
10192         return SDValue();
10193     } else
10194       return SDValue();
10195 
10196     // VPMOVWB is only available with avx512bw.
10197     if (!Subtarget.hasBWI() && Src.getScalarValueSizeInBits() < 32)
10198       return SDValue();
10199 
10200     bool UndefUppers = isUndefInRange(Mask, NumSrcElts, UpperElts);
10201     return getAVX512TruncNode(DL, VT, Src, Subtarget, DAG, !UndefUppers);
10202   }
10203 
10204   return SDValue();
10205 }
10206 
10207 // Attempt to match binary shuffle patterns as a truncate.
10208 static SDValue lowerShuffleAsVTRUNC(const SDLoc &DL, MVT VT, SDValue V1,
10209                                     SDValue V2, ArrayRef<int> Mask,
10210                                     const APInt &Zeroable,
10211                                     const X86Subtarget &Subtarget,
10212                                     SelectionDAG &DAG) {
10213   assert((VT.is128BitVector() || VT.is256BitVector()) &&
10214          "Unexpected VTRUNC type");
10215   if (!Subtarget.hasAVX512())
10216     return SDValue();
10217 
10218   unsigned NumElts = VT.getVectorNumElements();
10219   unsigned EltSizeInBits = VT.getScalarSizeInBits();
10220   unsigned MaxScale = 64 / EltSizeInBits;
10221   for (unsigned Scale = 2; Scale <= MaxScale; Scale += Scale) {
10222     // TODO: Support non-BWI VPMOVWB truncations?
10223     unsigned SrcEltBits = EltSizeInBits * Scale;
10224     if (SrcEltBits < 32 && !Subtarget.hasBWI())
10225       continue;
10226 
10227     // Match shuffle <Ofs,Ofs+Scale,Ofs+2*Scale,..,undef_or_zero,undef_or_zero>
10228     // Bail if the V2 elements are undef.
10229     unsigned NumHalfSrcElts = NumElts / Scale;
10230     unsigned NumSrcElts = 2 * NumHalfSrcElts;
10231     for (unsigned Offset = 0; Offset != Scale; ++Offset) {
10232       if (!isSequentialOrUndefInRange(Mask, 0, NumSrcElts, Offset, Scale) ||
10233           isUndefInRange(Mask, NumHalfSrcElts, NumHalfSrcElts))
10234         continue;
10235 
10236       // The elements beyond the truncation must be undef/zero.
10237       unsigned UpperElts = NumElts - NumSrcElts;
10238       if (UpperElts > 0 &&
10239           !Zeroable.extractBits(UpperElts, NumSrcElts).isAllOnes())
10240         continue;
10241       bool UndefUppers =
10242           UpperElts > 0 && isUndefInRange(Mask, NumSrcElts, UpperElts);
10243 
10244       // For offset truncations, ensure that the concat is cheap.
10245       if (Offset) {
10246         auto IsCheapConcat = [&](SDValue Lo, SDValue Hi) {
10247           if (Lo.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
10248               Hi.getOpcode() == ISD::EXTRACT_SUBVECTOR)
10249             return Lo.getOperand(0) == Hi.getOperand(0);
10250           if (ISD::isNormalLoad(Lo.getNode()) &&
10251               ISD::isNormalLoad(Hi.getNode())) {
10252             auto *LDLo = cast<LoadSDNode>(Lo);
10253             auto *LDHi = cast<LoadSDNode>(Hi);
10254             return DAG.areNonVolatileConsecutiveLoads(
10255                 LDHi, LDLo, Lo.getValueType().getStoreSize(), 1);
10256           }
10257           return false;
10258         };
10259         if (!IsCheapConcat(V1, V2))
10260           continue;
10261       }
10262 
10263       // As we're using both sources, we need to concat them together
10264       // and truncate from the double-sized src.
10265       MVT ConcatVT = MVT::getVectorVT(VT.getScalarType(), NumElts * 2);
10266       SDValue Src = DAG.getNode(ISD::CONCAT_VECTORS, DL, ConcatVT, V1, V2);
10267 
10268       MVT SrcSVT = MVT::getIntegerVT(SrcEltBits);
10269       MVT SrcVT = MVT::getVectorVT(SrcSVT, NumSrcElts);
10270       Src = DAG.getBitcast(SrcVT, Src);
10271 
10272       // Shift the offset'd elements into place for the truncation.
10273       // TODO: Use getTargetVShiftByConstNode.
10274       if (Offset)
10275         Src = DAG.getNode(
10276             X86ISD::VSRLI, DL, SrcVT, Src,
10277             DAG.getTargetConstant(Offset * EltSizeInBits, DL, MVT::i8));
10278 
10279       return getAVX512TruncNode(DL, VT, Src, Subtarget, DAG, !UndefUppers);
10280     }
10281   }
10282 
10283   return SDValue();
10284 }
10285 
10286 /// Check whether a compaction lowering can be done by dropping even/odd
10287 /// elements and compute how many times even/odd elements must be dropped.
10288 ///
10289 /// This handles shuffles which take every Nth element where N is a power of
10290 /// two. Example shuffle masks:
10291 ///
10292 /// (even)
10293 ///  N = 1:  0,  2,  4,  6,  8, 10, 12, 14,  0,  2,  4,  6,  8, 10, 12, 14
10294 ///  N = 1:  0,  2,  4,  6,  8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30
10295 ///  N = 2:  0,  4,  8, 12,  0,  4,  8, 12,  0,  4,  8, 12,  0,  4,  8, 12
10296 ///  N = 2:  0,  4,  8, 12, 16, 20, 24, 28,  0,  4,  8, 12, 16, 20, 24, 28
10297 ///  N = 3:  0,  8,  0,  8,  0,  8,  0,  8,  0,  8,  0,  8,  0,  8,  0,  8
10298 ///  N = 3:  0,  8, 16, 24,  0,  8, 16, 24,  0,  8, 16, 24,  0,  8, 16, 24
10299 ///
10300 /// (odd)
10301 ///  N = 1:  1,  3,  5,  7,  9, 11, 13, 15,  0,  2,  4,  6,  8, 10, 12, 14
10302 ///  N = 1:  1,  3,  5,  7,  9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31
10303 ///
10304 /// Any of these lanes can of course be undef.
10305 ///
10306 /// This routine only supports N <= 3.
10307 /// FIXME: Evaluate whether either AVX or AVX-512 have any opportunities here
10308 /// for larger N.
10309 ///
10310 /// \returns N above, or the number of times even/odd elements must be dropped
10311 /// if there is such a number. Otherwise returns zero.
10312 static int canLowerByDroppingElements(ArrayRef<int> Mask, bool MatchEven,
10313                                       bool IsSingleInput) {
10314   // The modulus for the shuffle vector entries is based on whether this is
10315   // a single input or not.
10316   int ShuffleModulus = Mask.size() * (IsSingleInput ? 1 : 2);
10317   assert(isPowerOf2_32((uint32_t)ShuffleModulus) &&
10318          "We should only be called with masks with a power-of-2 size!");
10319 
10320   uint64_t ModMask = (uint64_t)ShuffleModulus - 1;
10321   int Offset = MatchEven ? 0 : 1;
10322 
10323   // We track whether the input is viable for all power-of-2 strides 2^1, 2^2,
10324   // and 2^3 simultaneously. This is because we may have ambiguity with
10325   // partially undef inputs.
10326   bool ViableForN[3] = {true, true, true};
10327 
10328   for (int i = 0, e = Mask.size(); i < e; ++i) {
10329     // Ignore undef lanes, we'll optimistically collapse them to the pattern we
10330     // want.
10331     if (Mask[i] < 0)
10332       continue;
10333 
10334     bool IsAnyViable = false;
10335     for (unsigned j = 0; j != std::size(ViableForN); ++j)
10336       if (ViableForN[j]) {
10337         uint64_t N = j + 1;
10338 
10339         // The shuffle mask must be equal to (i * 2^N) % M.
10340         if ((uint64_t)(Mask[i] - Offset) == (((uint64_t)i << N) & ModMask))
10341           IsAnyViable = true;
10342         else
10343           ViableForN[j] = false;
10344       }
10345     // Early exit if we exhaust the possible powers of two.
10346     if (!IsAnyViable)
10347       break;
10348   }
10349 
10350   for (unsigned j = 0; j != std::size(ViableForN); ++j)
10351     if (ViableForN[j])
10352       return j + 1;
10353 
10354   // Return 0 as there is no viable power of two.
10355   return 0;
10356 }
10357 
10358 // X86 has dedicated pack instructions that can handle specific truncation
10359 // operations: PACKSS and PACKUS.
10360 // Checks for compaction shuffle masks if MaxStages > 1.
10361 // TODO: Add support for matching multiple PACKSS/PACKUS stages.
10362 static bool matchShuffleWithPACK(MVT VT, MVT &SrcVT, SDValue &V1, SDValue &V2,
10363                                  unsigned &PackOpcode, ArrayRef<int> TargetMask,
10364                                  const SelectionDAG &DAG,
10365                                  const X86Subtarget &Subtarget,
10366                                  unsigned MaxStages = 1) {
10367   unsigned NumElts = VT.getVectorNumElements();
10368   unsigned BitSize = VT.getScalarSizeInBits();
10369   assert(0 < MaxStages && MaxStages <= 3 && (BitSize << MaxStages) <= 64 &&
10370          "Illegal maximum compaction");
10371 
10372   auto MatchPACK = [&](SDValue N1, SDValue N2, MVT PackVT) {
10373     unsigned NumSrcBits = PackVT.getScalarSizeInBits();
10374     unsigned NumPackedBits = NumSrcBits - BitSize;
10375     N1 = peekThroughBitcasts(N1);
10376     N2 = peekThroughBitcasts(N2);
10377     unsigned NumBits1 = N1.getScalarValueSizeInBits();
10378     unsigned NumBits2 = N2.getScalarValueSizeInBits();
10379     bool IsZero1 = llvm::isNullOrNullSplat(N1, /*AllowUndefs*/ false);
10380     bool IsZero2 = llvm::isNullOrNullSplat(N2, /*AllowUndefs*/ false);
10381     if ((!N1.isUndef() && !IsZero1 && NumBits1 != NumSrcBits) ||
10382         (!N2.isUndef() && !IsZero2 && NumBits2 != NumSrcBits))
10383       return false;
10384     if (Subtarget.hasSSE41() || BitSize == 8) {
10385       APInt ZeroMask = APInt::getHighBitsSet(NumSrcBits, NumPackedBits);
10386       if ((N1.isUndef() || IsZero1 || DAG.MaskedValueIsZero(N1, ZeroMask)) &&
10387           (N2.isUndef() || IsZero2 || DAG.MaskedValueIsZero(N2, ZeroMask))) {
10388         V1 = N1;
10389         V2 = N2;
10390         SrcVT = PackVT;
10391         PackOpcode = X86ISD::PACKUS;
10392         return true;
10393       }
10394     }
10395     bool IsAllOnes1 = llvm::isAllOnesOrAllOnesSplat(N1, /*AllowUndefs*/ false);
10396     bool IsAllOnes2 = llvm::isAllOnesOrAllOnesSplat(N2, /*AllowUndefs*/ false);
10397     if ((N1.isUndef() || IsZero1 || IsAllOnes1 ||
10398          DAG.ComputeNumSignBits(N1) > NumPackedBits) &&
10399         (N2.isUndef() || IsZero2 || IsAllOnes2 ||
10400          DAG.ComputeNumSignBits(N2) > NumPackedBits)) {
10401       V1 = N1;
10402       V2 = N2;
10403       SrcVT = PackVT;
10404       PackOpcode = X86ISD::PACKSS;
10405       return true;
10406     }
10407     return false;
10408   };
10409 
10410   // Attempt to match against wider and wider compaction patterns.
10411   for (unsigned NumStages = 1; NumStages <= MaxStages; ++NumStages) {
10412     MVT PackSVT = MVT::getIntegerVT(BitSize << NumStages);
10413     MVT PackVT = MVT::getVectorVT(PackSVT, NumElts >> NumStages);
10414 
10415     // Try binary shuffle.
10416     SmallVector<int, 32> BinaryMask;
10417     createPackShuffleMask(VT, BinaryMask, false, NumStages);
10418     if (isTargetShuffleEquivalent(VT, TargetMask, BinaryMask, DAG, V1, V2))
10419       if (MatchPACK(V1, V2, PackVT))
10420         return true;
10421 
10422     // Try unary shuffle.
10423     SmallVector<int, 32> UnaryMask;
10424     createPackShuffleMask(VT, UnaryMask, true, NumStages);
10425     if (isTargetShuffleEquivalent(VT, TargetMask, UnaryMask, DAG, V1))
10426       if (MatchPACK(V1, V1, PackVT))
10427         return true;
10428   }
10429 
10430   return false;
10431 }
10432 
10433 static SDValue lowerShuffleWithPACK(const SDLoc &DL, MVT VT, ArrayRef<int> Mask,
10434                                     SDValue V1, SDValue V2, SelectionDAG &DAG,
10435                                     const X86Subtarget &Subtarget) {
10436   MVT PackVT;
10437   unsigned PackOpcode;
10438   unsigned SizeBits = VT.getSizeInBits();
10439   unsigned EltBits = VT.getScalarSizeInBits();
10440   unsigned MaxStages = Log2_32(64 / EltBits);
10441   if (!matchShuffleWithPACK(VT, PackVT, V1, V2, PackOpcode, Mask, DAG,
10442                             Subtarget, MaxStages))
10443     return SDValue();
10444 
10445   unsigned CurrentEltBits = PackVT.getScalarSizeInBits();
10446   unsigned NumStages = Log2_32(CurrentEltBits / EltBits);
10447 
10448   // Don't lower multi-stage packs on AVX512, truncation is better.
10449   if (NumStages != 1 && SizeBits == 128 && Subtarget.hasVLX())
10450     return SDValue();
10451 
10452   // Pack to the largest type possible:
10453   // vXi64/vXi32 -> PACK*SDW and vXi16 -> PACK*SWB.
10454   unsigned MaxPackBits = 16;
10455   if (CurrentEltBits > 16 &&
10456       (PackOpcode == X86ISD::PACKSS || Subtarget.hasSSE41()))
10457     MaxPackBits = 32;
10458 
10459   // Repeatedly pack down to the target size.
10460   SDValue Res;
10461   for (unsigned i = 0; i != NumStages; ++i) {
10462     unsigned SrcEltBits = std::min(MaxPackBits, CurrentEltBits);
10463     unsigned NumSrcElts = SizeBits / SrcEltBits;
10464     MVT SrcSVT = MVT::getIntegerVT(SrcEltBits);
10465     MVT DstSVT = MVT::getIntegerVT(SrcEltBits / 2);
10466     MVT SrcVT = MVT::getVectorVT(SrcSVT, NumSrcElts);
10467     MVT DstVT = MVT::getVectorVT(DstSVT, NumSrcElts * 2);
10468     Res = DAG.getNode(PackOpcode, DL, DstVT, DAG.getBitcast(SrcVT, V1),
10469                       DAG.getBitcast(SrcVT, V2));
10470     V1 = V2 = Res;
10471     CurrentEltBits /= 2;
10472   }
10473   assert(Res && Res.getValueType() == VT &&
10474          "Failed to lower compaction shuffle");
10475   return Res;
10476 }
10477 
10478 /// Try to emit a bitmask instruction for a shuffle.
10479 ///
10480 /// This handles cases where we can model a blend exactly as a bitmask due to
10481 /// one of the inputs being zeroable.
10482 static SDValue lowerShuffleAsBitMask(const SDLoc &DL, MVT VT, SDValue V1,
10483                                      SDValue V2, ArrayRef<int> Mask,
10484                                      const APInt &Zeroable,
10485                                      const X86Subtarget &Subtarget,
10486                                      SelectionDAG &DAG) {
10487   MVT MaskVT = VT;
10488   MVT EltVT = VT.getVectorElementType();
10489   SDValue Zero, AllOnes;
10490   // Use f64 if i64 isn't legal.
10491   if (EltVT == MVT::i64 && !Subtarget.is64Bit()) {
10492     EltVT = MVT::f64;
10493     MaskVT = MVT::getVectorVT(EltVT, Mask.size());
10494   }
10495 
10496   MVT LogicVT = VT;
10497   if (EltVT == MVT::f32 || EltVT == MVT::f64) {
10498     Zero = DAG.getConstantFP(0.0, DL, EltVT);
10499     APFloat AllOnesValue =
10500         APFloat::getAllOnesValue(SelectionDAG::EVTToAPFloatSemantics(EltVT));
10501     AllOnes = DAG.getConstantFP(AllOnesValue, DL, EltVT);
10502     LogicVT =
10503         MVT::getVectorVT(EltVT == MVT::f64 ? MVT::i64 : MVT::i32, Mask.size());
10504   } else {
10505     Zero = DAG.getConstant(0, DL, EltVT);
10506     AllOnes = DAG.getAllOnesConstant(DL, EltVT);
10507   }
10508 
10509   SmallVector<SDValue, 16> VMaskOps(Mask.size(), Zero);
10510   SDValue V;
10511   for (int i = 0, Size = Mask.size(); i < Size; ++i) {
10512     if (Zeroable[i])
10513       continue;
10514     if (Mask[i] % Size != i)
10515       return SDValue(); // Not a blend.
10516     if (!V)
10517       V = Mask[i] < Size ? V1 : V2;
10518     else if (V != (Mask[i] < Size ? V1 : V2))
10519       return SDValue(); // Can only let one input through the mask.
10520 
10521     VMaskOps[i] = AllOnes;
10522   }
10523   if (!V)
10524     return SDValue(); // No non-zeroable elements!
10525 
10526   SDValue VMask = DAG.getBuildVector(MaskVT, DL, VMaskOps);
10527   VMask = DAG.getBitcast(LogicVT, VMask);
10528   V = DAG.getBitcast(LogicVT, V);
10529   SDValue And = DAG.getNode(ISD::AND, DL, LogicVT, V, VMask);
10530   return DAG.getBitcast(VT, And);
10531 }
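// For example, a v4i32 shuffle with mask <0, 5, 2, 7> where elements 1 and 3
// are known zeroable reduces to a single AND of V1 with the constant vector
// <-1, 0, -1, 0>.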
10532 
10533 /// Try to emit a blend instruction for a shuffle using bit math.
10534 ///
10535 /// This is used as a fallback approach when first class blend instructions are
10536 /// unavailable. Currently it is only suitable for integer vectors, but could
10537 /// be generalized for floating point vectors if desirable.
10538 static SDValue lowerShuffleAsBitBlend(const SDLoc &DL, MVT VT, SDValue V1,
10539                                       SDValue V2, ArrayRef<int> Mask,
10540                                       SelectionDAG &DAG) {
10541   assert(VT.isInteger() && "Only supports integer vector types!");
10542   MVT EltVT = VT.getVectorElementType();
10543   SDValue Zero = DAG.getConstant(0, DL, EltVT);
10544   SDValue AllOnes = DAG.getAllOnesConstant(DL, EltVT);
10545   SmallVector<SDValue, 16> MaskOps;
10546   for (int i = 0, Size = Mask.size(); i < Size; ++i) {
10547     if (Mask[i] >= 0 && Mask[i] != i && Mask[i] != i + Size)
10548       return SDValue(); // Shuffled input!
10549     MaskOps.push_back(Mask[i] < Size ? AllOnes : Zero);
10550   }
10551 
10552   SDValue V1Mask = DAG.getBuildVector(VT, DL, MaskOps);
10553   return getBitSelect(DL, VT, V1, V2, V1Mask, DAG);
10554 }
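// For example, a v4i32 blend with mask <0, 5, 2, 7> builds
// V1Mask = <-1, 0, -1, 0> and selects bits as (V1 & V1Mask) | (V2 & ~V1Mask),
// assuming getBitSelect expands to the usual and/andn/or bit-select pattern.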
10555 
10556 static SDValue getVectorMaskingNode(SDValue Op, SDValue Mask,
10557                                     SDValue PreservedSrc,
10558                                     const X86Subtarget &Subtarget,
10559                                     SelectionDAG &DAG);
10560 
10561 static bool matchShuffleAsBlend(MVT VT, SDValue V1, SDValue V2,
10562                                 MutableArrayRef<int> Mask,
10563                                 const APInt &Zeroable, bool &ForceV1Zero,
10564                                 bool &ForceV2Zero, uint64_t &BlendMask) {
10565   bool V1IsZeroOrUndef =
10566       V1.isUndef() || ISD::isBuildVectorAllZeros(V1.getNode());
10567   bool V2IsZeroOrUndef =
10568       V2.isUndef() || ISD::isBuildVectorAllZeros(V2.getNode());
10569 
10570   BlendMask = 0;
10571   ForceV1Zero = false, ForceV2Zero = false;
10572   assert(Mask.size() <= 64 && "Shuffle mask too big for blend mask");
10573 
10574   int NumElts = Mask.size();
10575   int NumLanes = VT.getSizeInBits() / 128;
10576   int NumEltsPerLane = NumElts / NumLanes;
10577   assert((NumLanes * NumEltsPerLane) == NumElts && "Value type mismatch");
10578 
10579   // For 32/64-bit elements, if we only reference one input (plus any undefs),
10580   // then ensure the blend mask part for that lane just references that input.
10581   bool ForceWholeLaneMasks =
10582       VT.is256BitVector() && VT.getScalarSizeInBits() >= 32;
10583 
10584   // Attempt to generate the binary blend mask. If an input is zero then
10585   // we can use any lane.
10586   for (int Lane = 0; Lane != NumLanes; ++Lane) {
10587     // Keep track of the inputs used per lane.
10588     bool LaneV1InUse = false;
10589     bool LaneV2InUse = false;
10590     uint64_t LaneBlendMask = 0;
10591     for (int LaneElt = 0; LaneElt != NumEltsPerLane; ++LaneElt) {
10592       int Elt = (Lane * NumEltsPerLane) + LaneElt;
10593       int M = Mask[Elt];
10594       if (M == SM_SentinelUndef)
10595         continue;
10596       if (M == Elt || (0 <= M && M < NumElts &&
10597                      IsElementEquivalent(NumElts, V1, V1, M, Elt))) {
10598         Mask[Elt] = Elt;
10599         LaneV1InUse = true;
10600         continue;
10601       }
10602       if (M == (Elt + NumElts) ||
10603           (NumElts <= M &&
10604            IsElementEquivalent(NumElts, V2, V2, M - NumElts, Elt))) {
10605         LaneBlendMask |= 1ull << LaneElt;
10606         Mask[Elt] = Elt + NumElts;
10607         LaneV2InUse = true;
10608         continue;
10609       }
10610       if (Zeroable[Elt]) {
10611         if (V1IsZeroOrUndef) {
10612           ForceV1Zero = true;
10613           Mask[Elt] = Elt;
10614           LaneV1InUse = true;
10615           continue;
10616         }
10617         if (V2IsZeroOrUndef) {
10618           ForceV2Zero = true;
10619           LaneBlendMask |= 1ull << LaneElt;
10620           Mask[Elt] = Elt + NumElts;
10621           LaneV2InUse = true;
10622           continue;
10623         }
10624       }
10625       return false;
10626     }
10627 
10628     // If we only used V2 then splat the lane blend mask to avoid any demanded
10629     // elts from V1 in this lane (the V1 equivalent is implicit with a zero
10630     // blend mask bit).
10631     if (ForceWholeLaneMasks && LaneV2InUse && !LaneV1InUse)
10632       LaneBlendMask = (1ull << NumEltsPerLane) - 1;
10633 
10634     BlendMask |= LaneBlendMask << (Lane * NumEltsPerLane);
10635   }
10636   return true;
10637 }
10638 
10639 /// Try to emit a blend instruction for a shuffle.
10640 ///
10641 /// This doesn't do any checks for the availability of instructions for blending
10642 /// these values. It relies on the availability of the X86ISD::BLENDI pattern to
10643 /// be matched in the backend with the type given. What it does check for is
10644 /// that the shuffle mask is a blend, or convertible into a blend with zero.
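///
/// For example, a v4i32 mask <0,5,2,7> takes elements 0 and 2 from V1 and
/// elements 1 and 3 from V2, so matchShuffleAsBlend would produce the blend
/// immediate 0b1010 and the shuffle lowers to a single X86ISD::BLENDI node.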
10645 static SDValue lowerShuffleAsBlend(const SDLoc &DL, MVT VT, SDValue V1,
10646                                    SDValue V2, ArrayRef<int> Original,
10647                                    const APInt &Zeroable,
10648                                    const X86Subtarget &Subtarget,
10649                                    SelectionDAG &DAG) {
10650   uint64_t BlendMask = 0;
10651   bool ForceV1Zero = false, ForceV2Zero = false;
10652   SmallVector<int, 64> Mask(Original);
10653   if (!matchShuffleAsBlend(VT, V1, V2, Mask, Zeroable, ForceV1Zero, ForceV2Zero,
10654                            BlendMask))
10655     return SDValue();
10656 
10657   // Create a REAL zero vector - ISD::isBuildVectorAllZeros allows UNDEFs.
10658   if (ForceV1Zero)
10659     V1 = getZeroVector(VT, Subtarget, DAG, DL);
10660   if (ForceV2Zero)
10661     V2 = getZeroVector(VT, Subtarget, DAG, DL);
10662 
10663   unsigned NumElts = VT.getVectorNumElements();
10664 
10665   switch (VT.SimpleTy) {
10666   case MVT::v4i64:
10667   case MVT::v8i32:
10668     assert(Subtarget.hasAVX2() && "256-bit integer blends require AVX2!");
10669     [[fallthrough]];
10670   case MVT::v4f64:
10671   case MVT::v8f32:
10672     assert(Subtarget.hasAVX() && "256-bit float blends require AVX!");
10673     [[fallthrough]];
10674   case MVT::v2f64:
10675   case MVT::v2i64:
10676   case MVT::v4f32:
10677   case MVT::v4i32:
10678   case MVT::v8i16:
10679     assert(Subtarget.hasSSE41() && "128-bit blends require SSE41!");
10680     return DAG.getNode(X86ISD::BLENDI, DL, VT, V1, V2,
10681                        DAG.getTargetConstant(BlendMask, DL, MVT::i8));
10682   case MVT::v16i16: {
10683     assert(Subtarget.hasAVX2() && "v16i16 blends require AVX2!");
10684     SmallVector<int, 8> RepeatedMask;
10685     if (is128BitLaneRepeatedShuffleMask(MVT::v16i16, Mask, RepeatedMask)) {
10686       // We can lower these with PBLENDW which is mirrored across 128-bit lanes.
10687       assert(RepeatedMask.size() == 8 && "Repeated mask size doesn't match!");
10688       BlendMask = 0;
10689       for (int i = 0; i < 8; ++i)
10690         if (RepeatedMask[i] >= 8)
10691           BlendMask |= 1ull << i;
10692       return DAG.getNode(X86ISD::BLENDI, DL, MVT::v16i16, V1, V2,
10693                          DAG.getTargetConstant(BlendMask, DL, MVT::i8));
10694     }
10695     // Use PBLENDW for lower/upper lanes and then blend lanes.
10696     // TODO - we should allow 2 PBLENDW here and leave shuffle combine to
10697     // merge to VSELECT where useful.
10698     uint64_t LoMask = BlendMask & 0xFF;
10699     uint64_t HiMask = (BlendMask >> 8) & 0xFF;
10700     if (LoMask == 0 || LoMask == 255 || HiMask == 0 || HiMask == 255) {
10701       SDValue Lo = DAG.getNode(X86ISD::BLENDI, DL, MVT::v16i16, V1, V2,
10702                                DAG.getTargetConstant(LoMask, DL, MVT::i8));
10703       SDValue Hi = DAG.getNode(X86ISD::BLENDI, DL, MVT::v16i16, V1, V2,
10704                                DAG.getTargetConstant(HiMask, DL, MVT::i8));
10705       return DAG.getVectorShuffle(
10706           MVT::v16i16, DL, Lo, Hi,
10707           {0, 1, 2, 3, 4, 5, 6, 7, 24, 25, 26, 27, 28, 29, 30, 31});
10708     }
10709     [[fallthrough]];
10710   }
10711   case MVT::v32i8:
10712     assert(Subtarget.hasAVX2() && "256-bit byte-blends require AVX2!");
10713     [[fallthrough]];
10714   case MVT::v16i8: {
10715     assert(Subtarget.hasSSE41() && "128-bit byte-blends require SSE41!");
10716 
10717     // Attempt to lower to a bitmask if we can. VPAND is faster than VPBLENDVB.
10718     if (SDValue Masked = lowerShuffleAsBitMask(DL, VT, V1, V2, Mask, Zeroable,
10719                                                Subtarget, DAG))
10720       return Masked;
10721 
10722     if (Subtarget.hasBWI() && Subtarget.hasVLX()) {
10723       MVT IntegerType = MVT::getIntegerVT(std::max<unsigned>(NumElts, 8));
10724       SDValue MaskNode = DAG.getConstant(BlendMask, DL, IntegerType);
10725       return getVectorMaskingNode(V2, MaskNode, V1, Subtarget, DAG);
10726     }
10727 
10728     // If we have VPTERNLOG, we can use that as a bit blend.
10729     if (Subtarget.hasVLX())
10730       if (SDValue BitBlend =
10731               lowerShuffleAsBitBlend(DL, VT, V1, V2, Mask, DAG))
10732         return BitBlend;
10733 
10734     // Scale the blend by the number of bytes per element.
10735     int Scale = VT.getScalarSizeInBits() / 8;
10736 
10737     // This form of blend is always done on bytes. Compute the byte vector
10738     // type.
10739     MVT BlendVT = MVT::getVectorVT(MVT::i8, VT.getSizeInBits() / 8);
10740 
10741     // x86 allows load folding with blendvb from the 2nd source operand. But
10742     // we are still using LLVM select here (see comment below), so that's V1.
10743     // If V2 can be load-folded and V1 cannot be load-folded, then commute to
10744     // allow that load-folding possibility.
10745     if (!ISD::isNormalLoad(V1.getNode()) && ISD::isNormalLoad(V2.getNode())) {
10746       ShuffleVectorSDNode::commuteMask(Mask);
10747       std::swap(V1, V2);
10748     }
10749 
10750     // Compute the VSELECT mask. Note that VSELECT is really confusing in the
10751     // mix of LLVM's code generator and the x86 backend. We tell the code
10752     // generator that boolean values in the elements of an x86 vector register
10753     // are -1 for true and 0 for false. We then use the LLVM semantics of 'true'
10754     // mapping a select to operand #1, and 'false' mapping to operand #2. The
10755     // reality in x86 is that vector masks (pre-AVX-512) use only the high bit
10756     // of the element (the remaining are ignored) and 0 in that high bit would
10757     // mean operand #1 while 1 in the high bit would mean operand #2. So while
10758     // the LLVM model for boolean values in vector elements gets the relevant
10759     // bit set, it is set backwards and over-constrained relative to x86's
10760     // actual model.
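    // Concretely, an element taken from V1 (Mask[i] < Size) gets an all-ones
    // (-1) mask byte below, which the select reads as "take operand #1", i.e.
    // V1; elements taken from V2 get a zero mask byte.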
10761     SmallVector<SDValue, 32> VSELECTMask;
10762     for (int i = 0, Size = Mask.size(); i < Size; ++i)
10763       for (int j = 0; j < Scale; ++j)
10764         VSELECTMask.push_back(
10765             Mask[i] < 0 ? DAG.getUNDEF(MVT::i8)
10766                         : DAG.getConstant(Mask[i] < Size ? -1 : 0, DL,
10767                                           MVT::i8));
10768 
10769     V1 = DAG.getBitcast(BlendVT, V1);
10770     V2 = DAG.getBitcast(BlendVT, V2);
10771     return DAG.getBitcast(
10772         VT,
10773         DAG.getSelect(DL, BlendVT, DAG.getBuildVector(BlendVT, DL, VSELECTMask),
10774                       V1, V2));
10775   }
10776   case MVT::v16f32:
10777   case MVT::v8f64:
10778   case MVT::v8i64:
10779   case MVT::v16i32:
10780   case MVT::v32i16:
10781   case MVT::v64i8: {
10782     // Attempt to lower to a bitmask if we can. Only if not optimizing for size.
10783     bool OptForSize = DAG.shouldOptForSize();
10784     if (!OptForSize) {
10785       if (SDValue Masked = lowerShuffleAsBitMask(DL, VT, V1, V2, Mask, Zeroable,
10786                                                  Subtarget, DAG))
10787         return Masked;
10788     }
10789 
10790     // Otherwise load an immediate into a GPR, cast to k-register, and use a
10791     // masked move.
10792     MVT IntegerType = MVT::getIntegerVT(std::max<unsigned>(NumElts, 8));
10793     SDValue MaskNode = DAG.getConstant(BlendMask, DL, IntegerType);
10794     return getVectorMaskingNode(V2, MaskNode, V1, Subtarget, DAG);
10795   }
10796   default:
10797     llvm_unreachable("Not a supported integer vector type!");
10798   }
10799 }
10800 
10801 /// Try to lower as a blend of elements from two inputs followed by
10802 /// a single-input permutation.
10803 ///
10804 /// This matches the pattern where we can blend elements from two inputs and
10805 /// then reduce the shuffle to a single-input permutation.
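///
/// For example, a v4i32 mask <2,7,0,5> would first blend the two inputs with
/// the mask <0,5,2,7> and then apply the single-input permutation <2,3,0,1>
/// to the blended result.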
10806 static SDValue lowerShuffleAsBlendAndPermute(const SDLoc &DL, MVT VT,
10807                                              SDValue V1, SDValue V2,
10808                                              ArrayRef<int> Mask,
10809                                              SelectionDAG &DAG,
10810                                              bool ImmBlends = false) {
10811   // We build up the blend mask while checking whether a blend is a viable way
10812   // to reduce the shuffle.
10813   SmallVector<int, 32> BlendMask(Mask.size(), -1);
10814   SmallVector<int, 32> PermuteMask(Mask.size(), -1);
10815 
10816   for (int i = 0, Size = Mask.size(); i < Size; ++i) {
10817     if (Mask[i] < 0)
10818       continue;
10819 
10820     assert(Mask[i] < Size * 2 && "Shuffle input is out of bounds.");
10821 
10822     if (BlendMask[Mask[i] % Size] < 0)
10823       BlendMask[Mask[i] % Size] = Mask[i];
10824     else if (BlendMask[Mask[i] % Size] != Mask[i])
10825       return SDValue(); // Can't blend in the needed input!
10826 
10827     PermuteMask[i] = Mask[i] % Size;
10828   }
10829 
10830   // If only immediate blends, then bail if the blend mask can't be widened to
10831   // i16.
10832   unsigned EltSize = VT.getScalarSizeInBits();
10833   if (ImmBlends && EltSize == 8 && !canWidenShuffleElements(BlendMask))
10834     return SDValue();
10835 
10836   SDValue V = DAG.getVectorShuffle(VT, DL, V1, V2, BlendMask);
10837   return DAG.getVectorShuffle(VT, DL, V, DAG.getUNDEF(VT), PermuteMask);
10838 }
10839 
10840 /// Try to lower as an unpack of elements from two inputs followed by
10841 /// a single-input permutation.
10842 ///
10843 /// This matches the pattern where we can unpack elements from two inputs and
10844 /// then reduce the shuffle to a single-input (wider) permutation.
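///
/// For example, a v8i16 mask <0,8,2,10,1,9,3,11> would be lowered as an
/// X86ISD::UNPCKL of V1 and V2 followed by the single-input permutation
/// <0,1,4,5,2,3,6,7> of the unpacked result.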
10845 static SDValue lowerShuffleAsUNPCKAndPermute(const SDLoc &DL, MVT VT,
10846                                              SDValue V1, SDValue V2,
10847                                              ArrayRef<int> Mask,
10848                                              SelectionDAG &DAG) {
10849   int NumElts = Mask.size();
10850   int NumLanes = VT.getSizeInBits() / 128;
10851   int NumLaneElts = NumElts / NumLanes;
10852   int NumHalfLaneElts = NumLaneElts / 2;
10853 
10854   bool MatchLo = true, MatchHi = true;
10855   SDValue Ops[2] = {DAG.getUNDEF(VT), DAG.getUNDEF(VT)};
10856 
10857   // Determine UNPCKL/UNPCKH type and operand order.
10858   for (int Elt = 0; Elt != NumElts; ++Elt) {
10859     int M = Mask[Elt];
10860     if (M < 0)
10861       continue;
10862 
10863     // Normalize the mask value depending on whether it's V1 or V2.
10864     int NormM = M;
10865     SDValue &Op = Ops[Elt & 1];
10866     if (M < NumElts && (Op.isUndef() || Op == V1))
10867       Op = V1;
10868     else if (NumElts <= M && (Op.isUndef() || Op == V2)) {
10869       Op = V2;
10870       NormM -= NumElts;
10871     } else
10872       return SDValue();
10873 
10874     bool MatchLoAnyLane = false, MatchHiAnyLane = false;
10875     for (int Lane = 0; Lane != NumElts; Lane += NumLaneElts) {
10876       int Lo = Lane, Mid = Lane + NumHalfLaneElts, Hi = Lane + NumLaneElts;
10877       MatchLoAnyLane |= isUndefOrInRange(NormM, Lo, Mid);
10878       MatchHiAnyLane |= isUndefOrInRange(NormM, Mid, Hi);
10879       if (MatchLoAnyLane || MatchHiAnyLane) {
10880         assert((MatchLoAnyLane ^ MatchHiAnyLane) &&
10881                "Failed to match UNPCKLO/UNPCKHI");
10882         break;
10883       }
10884     }
10885     MatchLo &= MatchLoAnyLane;
10886     MatchHi &= MatchHiAnyLane;
10887     if (!MatchLo && !MatchHi)
10888       return SDValue();
10889   }
10890   assert((MatchLo ^ MatchHi) && "Failed to match UNPCKLO/UNPCKHI");
10891 
10892   // Element indices have changed after unpacking. Calculate permute mask
10893   // so that they will be put back to the position as dictated by the
10894   // original shuffle mask indices.
10895   SmallVector<int, 32> PermuteMask(NumElts, -1);
10896   for (int Elt = 0; Elt != NumElts; ++Elt) {
10897     int M = Mask[Elt];
10898     if (M < 0)
10899       continue;
10900     int NormM = M;
10901     if (NumElts <= M)
10902       NormM -= NumElts;
10903     bool IsFirstOp = M < NumElts;
10904     int BaseMaskElt =
10905         NumLaneElts * (NormM / NumLaneElts) + (2 * (NormM % NumHalfLaneElts));
10906     if ((IsFirstOp && V1 == Ops[0]) || (!IsFirstOp && V2 == Ops[0]))
10907       PermuteMask[Elt] = BaseMaskElt;
10908     else if ((IsFirstOp && V1 == Ops[1]) || (!IsFirstOp && V2 == Ops[1]))
10909       PermuteMask[Elt] = BaseMaskElt + 1;
10910     assert(PermuteMask[Elt] != -1 &&
10911            "Input mask element is defined but failed to assign permute mask");
10912   }
10913 
10914   unsigned UnpckOp = MatchLo ? X86ISD::UNPCKL : X86ISD::UNPCKH;
10915   SDValue Unpck = DAG.getNode(UnpckOp, DL, VT, Ops);
10916   return DAG.getVectorShuffle(VT, DL, Unpck, DAG.getUNDEF(VT), PermuteMask);
10917 }
10918 
10919 /// Try to lower a shuffle as a permute of the inputs followed by an
10920 /// UNPCK instruction.
10921 ///
10922 /// This specifically targets cases where we end up with alternating between
10923 /// the two inputs, and so can permute them into something that feeds a single
10924 /// UNPCK instruction. Note that this routine only targets integer vectors
10925 /// because for floating point vectors we have a generalized SHUFPS lowering
10926 /// strategy that handles everything that doesn't *exactly* match an unpack,
10927 /// making this clever lowering unnecessary.
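///
/// For example, a v8i16 mask <0,8,2,10,4,12,6,14> would first permute each
/// input with <0,2,4,6,u,u,u,u> to gather its even elements and then emit a
/// single UNPCKL to interleave the two shuffled inputs.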
10928 static SDValue lowerShuffleAsPermuteAndUnpack(const SDLoc &DL, MVT VT,
10929                                               SDValue V1, SDValue V2,
10930                                               ArrayRef<int> Mask,
10931                                               const X86Subtarget &Subtarget,
10932                                               SelectionDAG &DAG) {
10933   int Size = Mask.size();
10934   assert(Mask.size() >= 2 && "Single element masks are invalid.");
10935 
10936   // This routine only supports 128-bit integer dual input vectors.
10937   if (VT.isFloatingPoint() || !VT.is128BitVector() || V2.isUndef())
10938     return SDValue();
10939 
10940   int NumLoInputs =
10941       count_if(Mask, [Size](int M) { return M >= 0 && M % Size < Size / 2; });
10942   int NumHiInputs =
10943       count_if(Mask, [Size](int M) { return M % Size >= Size / 2; });
10944 
10945   bool UnpackLo = NumLoInputs >= NumHiInputs;
10946 
10947   auto TryUnpack = [&](int ScalarSize, int Scale) {
10948     SmallVector<int, 16> V1Mask((unsigned)Size, -1);
10949     SmallVector<int, 16> V2Mask((unsigned)Size, -1);
10950 
10951     for (int i = 0; i < Size; ++i) {
10952       if (Mask[i] < 0)
10953         continue;
10954 
10955       // Each element of the unpack contains Scale elements from this mask.
10956       int UnpackIdx = i / Scale;
10957 
10958       // We only handle the case where V1 feeds the first slots of the unpack.
10959       // We rely on canonicalization to ensure this is the case.
10960       if ((UnpackIdx % 2 == 0) != (Mask[i] < Size))
10961         return SDValue();
10962 
10963       // Setup the mask for this input. The indexing is tricky as we have to
10964       // handle the unpack stride.
10965       SmallVectorImpl<int> &VMask = (UnpackIdx % 2 == 0) ? V1Mask : V2Mask;
10966       VMask[(UnpackIdx / 2) * Scale + i % Scale + (UnpackLo ? 0 : Size / 2)] =
10967           Mask[i] % Size;
10968     }
10969 
10970     // If we will have to shuffle both inputs to use the unpack, check whether
10971     // we can just unpack first and shuffle the result. If so, skip this unpack.
10972     if ((NumLoInputs == 0 || NumHiInputs == 0) && !isNoopShuffleMask(V1Mask) &&
10973         !isNoopShuffleMask(V2Mask))
10974       return SDValue();
10975 
10976     // Shuffle the inputs into place.
10977     V1 = DAG.getVectorShuffle(VT, DL, V1, DAG.getUNDEF(VT), V1Mask);
10978     V2 = DAG.getVectorShuffle(VT, DL, V2, DAG.getUNDEF(VT), V2Mask);
10979 
10980     // Cast the inputs to the type we will use to unpack them.
10981     MVT UnpackVT =
10982         MVT::getVectorVT(MVT::getIntegerVT(ScalarSize), Size / Scale);
10983     V1 = DAG.getBitcast(UnpackVT, V1);
10984     V2 = DAG.getBitcast(UnpackVT, V2);
10985 
10986     // Unpack the inputs and cast the result back to the desired type.
10987     return DAG.getBitcast(
10988         VT, DAG.getNode(UnpackLo ? X86ISD::UNPCKL : X86ISD::UNPCKH, DL,
10989                         UnpackVT, V1, V2));
10990   };
10991 
10992   // We try each unpack from the largest to the smallest to try and find one
10993   // that fits this mask.
10994   int OrigScalarSize = VT.getScalarSizeInBits();
10995   for (int ScalarSize = 64; ScalarSize >= OrigScalarSize; ScalarSize /= 2)
10996     if (SDValue Unpack = TryUnpack(ScalarSize, ScalarSize / OrigScalarSize))
10997       return Unpack;
10998 
10999   // If we're shuffling with a zero vector then we're better off not doing
11000   // VECTOR_SHUFFLE(UNPCK()) as we lose track of those zero elements.
11001   if (ISD::isBuildVectorAllZeros(V1.getNode()) ||
11002       ISD::isBuildVectorAllZeros(V2.getNode()))
11003     return SDValue();
11004 
11005   // If none of the unpack-rooted lowerings worked (or were profitable) try an
11006   // initial unpack.
11007   if (NumLoInputs == 0 || NumHiInputs == 0) {
11008     assert((NumLoInputs > 0 || NumHiInputs > 0) &&
11009            "We have to have *some* inputs!");
11010     int HalfOffset = NumLoInputs == 0 ? Size / 2 : 0;
11011 
11012     // FIXME: We could consider the total complexity of the permute of each
11013     // possible unpacking. Or at the least we should consider how many
11014     // half-crossings are created.
11015     // FIXME: We could consider commuting the unpacks.
11016 
11017     SmallVector<int, 32> PermMask((unsigned)Size, -1);
11018     for (int i = 0; i < Size; ++i) {
11019       if (Mask[i] < 0)
11020         continue;
11021 
11022       assert(Mask[i] % Size >= HalfOffset && "Found input from wrong half!");
11023 
11024       PermMask[i] =
11025           2 * ((Mask[i] % Size) - HalfOffset) + (Mask[i] < Size ? 0 : 1);
11026     }
11027     return DAG.getVectorShuffle(
11028         VT, DL,
11029         DAG.getNode(NumLoInputs == 0 ? X86ISD::UNPCKH : X86ISD::UNPCKL, DL, VT,
11030                     V1, V2),
11031         DAG.getUNDEF(VT), PermMask);
11032   }
11033 
11034   return SDValue();
11035 }
11036 
11037 /// Helper to form a PALIGNR-based rotate+permute, merging 2 inputs and then
11038 /// permuting the elements of the result in place.
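/// This applies when the in-lane offsets requested from one input all sit
/// strictly below those requested from the other input: a single PALIGNR can
/// then gather both ranges into one register, and an in-lane permute restores
/// the requested element order.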
11039 static SDValue lowerShuffleAsByteRotateAndPermute(
11040     const SDLoc &DL, MVT VT, SDValue V1, SDValue V2, ArrayRef<int> Mask,
11041     const X86Subtarget &Subtarget, SelectionDAG &DAG) {
11042   if ((VT.is128BitVector() && !Subtarget.hasSSSE3()) ||
11043       (VT.is256BitVector() && !Subtarget.hasAVX2()) ||
11044       (VT.is512BitVector() && !Subtarget.hasBWI()))
11045     return SDValue();
11046 
11047   // We don't currently support lane crossing permutes.
11048   if (is128BitLaneCrossingShuffleMask(VT, Mask))
11049     return SDValue();
11050 
11051   int Scale = VT.getScalarSizeInBits() / 8;
11052   int NumLanes = VT.getSizeInBits() / 128;
11053   int NumElts = VT.getVectorNumElements();
11054   int NumEltsPerLane = NumElts / NumLanes;
11055 
11056   // Determine range of mask elts.
11057   bool Blend1 = true;
11058   bool Blend2 = true;
11059   std::pair<int, int> Range1 = std::make_pair(INT_MAX, INT_MIN);
11060   std::pair<int, int> Range2 = std::make_pair(INT_MAX, INT_MIN);
11061   for (int Lane = 0; Lane != NumElts; Lane += NumEltsPerLane) {
11062     for (int Elt = 0; Elt != NumEltsPerLane; ++Elt) {
11063       int M = Mask[Lane + Elt];
11064       if (M < 0)
11065         continue;
11066       if (M < NumElts) {
11067         Blend1 &= (M == (Lane + Elt));
11068         assert(Lane <= M && M < (Lane + NumEltsPerLane) && "Out of range mask");
11069         M = M % NumEltsPerLane;
11070         Range1.first = std::min(Range1.first, M);
11071         Range1.second = std::max(Range1.second, M);
11072       } else {
11073         M -= NumElts;
11074         Blend2 &= (M == (Lane + Elt));
11075         assert(Lane <= M && M < (Lane + NumEltsPerLane) && "Out of range mask");
11076         M = M % NumEltsPerLane;
11077         Range2.first = std::min(Range2.first, M);
11078         Range2.second = std::max(Range2.second, M);
11079       }
11080     }
11081   }
11082 
11083   // Bail if we don't need both elements.
11084   // TODO - it might be worth doing this for unary shuffles if the permute
11085   // can be widened.
11086   if (!(0 <= Range1.first && Range1.second < NumEltsPerLane) ||
11087       !(0 <= Range2.first && Range2.second < NumEltsPerLane))
11088     return SDValue();
11089 
11090   if (VT.getSizeInBits() > 128 && (Blend1 || Blend2))
11091     return SDValue();
11092 
11093   // Rotate the 2 ops so we can access both ranges, then permute the result.
11094   auto RotateAndPermute = [&](SDValue Lo, SDValue Hi, int RotAmt, int Ofs) {
11095     MVT ByteVT = MVT::getVectorVT(MVT::i8, VT.getSizeInBits() / 8);
11096     SDValue Rotate = DAG.getBitcast(
11097         VT, DAG.getNode(X86ISD::PALIGNR, DL, ByteVT, DAG.getBitcast(ByteVT, Hi),
11098                         DAG.getBitcast(ByteVT, Lo),
11099                         DAG.getTargetConstant(Scale * RotAmt, DL, MVT::i8)));
11100     SmallVector<int, 64> PermMask(NumElts, SM_SentinelUndef);
11101     for (int Lane = 0; Lane != NumElts; Lane += NumEltsPerLane) {
11102       for (int Elt = 0; Elt != NumEltsPerLane; ++Elt) {
11103         int M = Mask[Lane + Elt];
11104         if (M < 0)
11105           continue;
11106         if (M < NumElts)
11107           PermMask[Lane + Elt] = Lane + ((M + Ofs - RotAmt) % NumEltsPerLane);
11108         else
11109           PermMask[Lane + Elt] = Lane + ((M - Ofs - RotAmt) % NumEltsPerLane);
11110       }
11111     }
11112     return DAG.getVectorShuffle(VT, DL, Rotate, DAG.getUNDEF(VT), PermMask);
11113   };
11114 
11115   // Check if the ranges are small enough to rotate from either direction.
11116   if (Range2.second < Range1.first)
11117     return RotateAndPermute(V1, V2, Range1.first, 0);
11118   if (Range1.second < Range2.first)
11119     return RotateAndPermute(V2, V1, Range2.first, NumElts);
11120   return SDValue();
11121 }
11122 
11123 static bool isBroadcastShuffleMask(ArrayRef<int> Mask) {
11124   return isUndefOrEqual(Mask, 0);
11125 }
11126 
11127 static bool isNoopOrBroadcastShuffleMask(ArrayRef<int> Mask) {
11128   return isNoopShuffleMask(Mask) || isBroadcastShuffleMask(Mask);
11129 }
11130 
11131 /// Check if the Mask consists of the same element repeated multiple times.
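/// For example, <2,2,u,2> qualifies, while <2,3,2,3> (two distinct elements)
/// and <2,u,u,u> (too many undefs) do not.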
11132 static bool isSingleElementRepeatedMask(ArrayRef<int> Mask) {
11133   size_t NumUndefs = 0;
11134   std::optional<int> UniqueElt;
11135   for (int Elt : Mask) {
11136     if (Elt == SM_SentinelUndef) {
11137       NumUndefs++;
11138       continue;
11139     }
11140     if (UniqueElt.has_value() && UniqueElt.value() != Elt)
11141       return false;
11142     UniqueElt = Elt;
11143   }
11144   // Make sure the element is repeated enough times by checking the number of
11145   // undefs is small.
11146   return NumUndefs <= Mask.size() / 2 && UniqueElt.has_value();
11147 }
11148 
11149 /// Generic routine to decompose a shuffle and blend into independent
11150 /// blends and permutes.
11151 ///
11152 /// This matches the extremely common pattern for handling combined
11153 /// shuffle+blend operations on newer X86 ISAs where we have very fast blend
11154 /// operations. It will try to pick the best arrangement of shuffles and
11155 /// blends. For vXi8/vXi16 shuffles we may use unpack instead of blend.
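///
/// For example, a v4i32 mask <2,5,0,7> would decompose into the per-input
/// shuffles V1:<2,u,0,u> and V2:<u,1,u,3> followed by the blend <0,5,2,7>.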
11156 static SDValue lowerShuffleAsDecomposedShuffleMerge(
11157     const SDLoc &DL, MVT VT, SDValue V1, SDValue V2, ArrayRef<int> Mask,
11158     const X86Subtarget &Subtarget, SelectionDAG &DAG) {
11159   int NumElts = Mask.size();
11160   int NumLanes = VT.getSizeInBits() / 128;
11161   int NumEltsPerLane = NumElts / NumLanes;
11162 
11163   // Shuffle the input elements into the desired positions in V1 and V2 and
11164   // unpack/blend them together.
11165   bool IsAlternating = true;
11166   SmallVector<int, 32> V1Mask(NumElts, -1);
11167   SmallVector<int, 32> V2Mask(NumElts, -1);
11168   SmallVector<int, 32> FinalMask(NumElts, -1);
11169   for (int i = 0; i < NumElts; ++i) {
11170     int M = Mask[i];
11171     if (M >= 0 && M < NumElts) {
11172       V1Mask[i] = M;
11173       FinalMask[i] = i;
11174       IsAlternating &= (i & 1) == 0;
11175     } else if (M >= NumElts) {
11176       V2Mask[i] = M - NumElts;
11177       FinalMask[i] = i + NumElts;
11178       IsAlternating &= (i & 1) == 1;
11179     }
11180   }
11181 
11182   // If we effectively demand only the 0'th element of \p Input (and not
11183   // merely as a no-op where it stays in place), then broadcast said input
11184   // and change \p InputMask to be a no-op (identity) mask.
11185   auto canonicalizeBroadcastableInput = [DL, VT, &Subtarget,
11186                                          &DAG](SDValue &Input,
11187                                                MutableArrayRef<int> InputMask) {
11188     unsigned EltSizeInBits = Input.getScalarValueSizeInBits();
11189     if (!Subtarget.hasAVX2() && (!Subtarget.hasAVX() || EltSizeInBits < 32 ||
11190                                  !X86::mayFoldLoad(Input, Subtarget)))
11191       return;
11192     if (isNoopShuffleMask(InputMask))
11193       return;
11194     assert(isBroadcastShuffleMask(InputMask) &&
11195            "Expected to demand only the 0'th element.");
11196     Input = DAG.getNode(X86ISD::VBROADCAST, DL, VT, Input);
11197     for (auto I : enumerate(InputMask)) {
11198       int &InputMaskElt = I.value();
11199       if (InputMaskElt >= 0)
11200         InputMaskElt = I.index();
11201     }
11202   };
11203 
11204   // Currently, we may need to produce one shuffle per input, and blend results.
11205   // It is possible that the shuffle for one of the inputs is already a no-op.
11206   // See if we can simplify non-no-op shuffles into broadcasts,
11207   // which we consider to be strictly better than an arbitrary shuffle.
11208   if (isNoopOrBroadcastShuffleMask(V1Mask) &&
11209       isNoopOrBroadcastShuffleMask(V2Mask)) {
11210     canonicalizeBroadcastableInput(V1, V1Mask);
11211     canonicalizeBroadcastableInput(V2, V2Mask);
11212   }
11213 
11214   // Try to lower with the simpler initial blend/unpack/rotate strategies unless
11215   // one of the input shuffles would be a no-op. We prefer to shuffle inputs as
11216   // the shuffle may be able to fold with a load or other benefit. However, when
11217   // we'll have to do 2x as many shuffles in order to achieve this, a 2-input
11218   // pre-shuffle first is a better strategy.
11219   if (!isNoopShuffleMask(V1Mask) && !isNoopShuffleMask(V2Mask)) {
11220     // Only prefer immediate blends to unpack/rotate.
11221     if (SDValue BlendPerm = lowerShuffleAsBlendAndPermute(DL, VT, V1, V2, Mask,
11222                                                           DAG, true))
11223       return BlendPerm;
11224     // If either input vector provides only a single element which is repeated
11225     // multiple times, unpacking from both input vectors would generate worse
11226     // code. e.g. for
11227     // t5: v16i8 = vector_shuffle<16,0,16,1,16,2,16,3,16,4,16,5,16,6,16,7> t2, t4
11228     // it is better to process t4 first to create a vector of t4[0], then unpack
11229     // that vector with t2.
11230     if (!isSingleElementRepeatedMask(V1Mask) &&
11231         !isSingleElementRepeatedMask(V2Mask))
11232       if (SDValue UnpackPerm =
11233               lowerShuffleAsUNPCKAndPermute(DL, VT, V1, V2, Mask, DAG))
11234         return UnpackPerm;
11235     if (SDValue RotatePerm = lowerShuffleAsByteRotateAndPermute(
11236             DL, VT, V1, V2, Mask, Subtarget, DAG))
11237       return RotatePerm;
11238     // Unpack/rotate failed - try again with variable blends.
11239     if (SDValue BlendPerm = lowerShuffleAsBlendAndPermute(DL, VT, V1, V2, Mask,
11240                                                           DAG))
11241       return BlendPerm;
11242     if (VT.getScalarSizeInBits() >= 32)
11243       if (SDValue PermUnpack = lowerShuffleAsPermuteAndUnpack(
11244               DL, VT, V1, V2, Mask, Subtarget, DAG))
11245         return PermUnpack;
11246   }
11247 
11248   // If the final mask is an alternating blend of vXi8/vXi16, convert to an
11249   // UNPCKL(SHUFFLE, SHUFFLE) pattern.
11250   // TODO: It doesn't have to be alternating - but each lane mustn't have more
11251   // than half the elements coming from each source.
11252   if (IsAlternating && VT.getScalarSizeInBits() < 32) {
11253     V1Mask.assign(NumElts, -1);
11254     V2Mask.assign(NumElts, -1);
11255     FinalMask.assign(NumElts, -1);
11256     for (int i = 0; i != NumElts; i += NumEltsPerLane)
11257       for (int j = 0; j != NumEltsPerLane; ++j) {
11258         int M = Mask[i + j];
11259         if (M >= 0 && M < NumElts) {
11260           V1Mask[i + (j / 2)] = M;
11261           FinalMask[i + j] = i + (j / 2);
11262         } else if (M >= NumElts) {
11263           V2Mask[i + (j / 2)] = M - NumElts;
11264           FinalMask[i + j] = i + (j / 2) + NumElts;
11265         }
11266       }
11267   }
11268 
11269   V1 = DAG.getVectorShuffle(VT, DL, V1, DAG.getUNDEF(VT), V1Mask);
11270   V2 = DAG.getVectorShuffle(VT, DL, V2, DAG.getUNDEF(VT), V2Mask);
11271   return DAG.getVectorShuffle(VT, DL, V1, V2, FinalMask);
11272 }
11273 
11274 static int matchShuffleAsBitRotate(MVT &RotateVT, int EltSizeInBits,
11275                                    const X86Subtarget &Subtarget,
11276                                    ArrayRef<int> Mask) {
11277   assert(!isNoopShuffleMask(Mask) && "We shouldn't lower no-op shuffles!");
11278   assert(EltSizeInBits < 64 && "Can't rotate 64-bit integers");
11279 
11280   // AVX512 only has vXi32/vXi64 rotates, so limit the rotation sub group size.
11281   int MinSubElts = Subtarget.hasAVX512() ? std::max(32 / EltSizeInBits, 2) : 2;
11282   int MaxSubElts = 64 / EltSizeInBits;
11283   unsigned RotateAmt, NumSubElts;
11284   if (!ShuffleVectorInst::isBitRotateMask(Mask, EltSizeInBits, MinSubElts,
11285                                           MaxSubElts, NumSubElts, RotateAmt))
11286     return -1;
11287   unsigned NumElts = Mask.size();
11288   MVT RotateSVT = MVT::getIntegerVT(EltSizeInBits * NumSubElts);
11289   RotateVT = MVT::getVectorVT(RotateSVT, NumElts / NumSubElts);
11290   return RotateAmt;
11291 }
11292 
11293 /// Lower shuffle using X86ISD::VROTLI rotations.
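/// For example, on XOP targets the v16i8 mask <1,0,3,2,5,4,7,6,...> that
/// swaps adjacent bytes would be matched as a v8i16 rotate-left by 8; on
/// AVX512 targets only i32/i64 sub-group rotations are considered.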
11294 static SDValue lowerShuffleAsBitRotate(const SDLoc &DL, MVT VT, SDValue V1,
11295                                        ArrayRef<int> Mask,
11296                                        const X86Subtarget &Subtarget,
11297                                        SelectionDAG &DAG) {
11298   // Only XOP + AVX512 targets have bit rotation instructions.
11299   // If we at least have SSSE3 (PSHUFB) then we shouldn't attempt to use this.
11300   bool IsLegal =
11301       (VT.is128BitVector() && Subtarget.hasXOP()) || Subtarget.hasAVX512();
11302   if (!IsLegal && Subtarget.hasSSE3())
11303     return SDValue();
11304 
11305   MVT RotateVT;
11306   int RotateAmt = matchShuffleAsBitRotate(RotateVT, VT.getScalarSizeInBits(),
11307                                           Subtarget, Mask);
11308   if (RotateAmt < 0)
11309     return SDValue();
11310 
11311   // For pre-SSSE3 targets, if we are shuffling vXi8 elts then ISD::ROTL,
11312   // expanded to OR(SRL,SHL), will be more efficient, but if they can
11313   // widen to vXi16 or more then the existing lowering will be better.
11314   if (!IsLegal) {
11315     if ((RotateAmt % 16) == 0)
11316       return SDValue();
11317     // TODO: Use getTargetVShiftByConstNode.
11318     unsigned ShlAmt = RotateAmt;
11319     unsigned SrlAmt = RotateVT.getScalarSizeInBits() - RotateAmt;
11320     V1 = DAG.getBitcast(RotateVT, V1);
11321     SDValue SHL = DAG.getNode(X86ISD::VSHLI, DL, RotateVT, V1,
11322                               DAG.getTargetConstant(ShlAmt, DL, MVT::i8));
11323     SDValue SRL = DAG.getNode(X86ISD::VSRLI, DL, RotateVT, V1,
11324                               DAG.getTargetConstant(SrlAmt, DL, MVT::i8));
11325     SDValue Rot = DAG.getNode(ISD::OR, DL, RotateVT, SHL, SRL);
11326     return DAG.getBitcast(VT, Rot);
11327   }
11328 
11329   SDValue Rot =
11330       DAG.getNode(X86ISD::VROTLI, DL, RotateVT, DAG.getBitcast(RotateVT, V1),
11331                   DAG.getTargetConstant(RotateAmt, DL, MVT::i8));
11332   return DAG.getBitcast(VT, Rot);
11333 }
11334 
11335 /// Try to match a vector shuffle as an element rotation.
11336 ///
11337 /// This is used for support PALIGNR for SSSE3 or VALIGND/Q for AVX512.
11338 static int matchShuffleAsElementRotate(SDValue &V1, SDValue &V2,
11339                                        ArrayRef<int> Mask) {
11340   int NumElts = Mask.size();
11341 
11342   // We need to detect various ways of spelling a rotation:
11343   //   [11, 12, 13, 14, 15,  0,  1,  2]
11344   //   [-1, 12, 13, 14, -1, -1,  1, -1]
11345   //   [-1, -1, -1, -1, -1, -1,  1,  2]
11346   //   [ 3,  4,  5,  6,  7,  8,  9, 10]
11347   //   [-1,  4,  5,  6, -1, -1,  9, -1]
11348   //   [-1,  4,  5,  6, -1, -1, -1, -1]
11349   int Rotation = 0;
11350   SDValue Lo, Hi;
11351   for (int i = 0; i < NumElts; ++i) {
11352     int M = Mask[i];
11353     assert((M == SM_SentinelUndef || (0 <= M && M < (2*NumElts))) &&
11354            "Unexpected mask index.");
11355     if (M < 0)
11356       continue;
11357 
11358     // Determine where a rotated vector would have started.
11359     int StartIdx = i - (M % NumElts);
11360     if (StartIdx == 0)
11361       // The identity rotation isn't interesting, stop.
11362       return -1;
11363 
11364     // If we found the tail of a vector the rotation must be the missing
11365     // front. If we found the head of a vector, it must be how much of the
11366     // head.
11367     int CandidateRotation = StartIdx < 0 ? -StartIdx : NumElts - StartIdx;
11368 
11369     if (Rotation == 0)
11370       Rotation = CandidateRotation;
11371     else if (Rotation != CandidateRotation)
11372       // The rotations don't match, so we can't match this mask.
11373       return -1;
11374 
11375     // Compute which value this mask is pointing at.
11376     SDValue MaskV = M < NumElts ? V1 : V2;
11377 
11378     // Compute which of the two target values this index should be assigned
11379     // to. This reflects whether the high elements are remaining or the low
11380     // elements are remaining.
11381     SDValue &TargetV = StartIdx < 0 ? Hi : Lo;
11382 
11383     // Either set up this value if we've not encountered it before, or check
11384     // that it remains consistent.
11385     if (!TargetV)
11386       TargetV = MaskV;
11387     else if (TargetV != MaskV)
11388       // This may be a rotation, but it pulls from the inputs in some
11389       // unsupported interleaving.
11390       return -1;
11391   }
11392 
11393   // Check that we successfully analyzed the mask, and normalize the results.
11394   assert(Rotation != 0 && "Failed to locate a viable rotation!");
11395   assert((Lo || Hi) && "Failed to find a rotated input vector!");
11396   if (!Lo)
11397     Lo = Hi;
11398   else if (!Hi)
11399     Hi = Lo;
11400 
11401   V1 = Lo;
11402   V2 = Hi;
11403 
11404   return Rotation;
11405 }
11406 
11407 /// Try to lower a vector shuffle as a byte rotation.
11408 ///
11409 /// SSSE3 has a generic PALIGNR instruction in x86 that will do an arbitrary
11410 /// byte-rotation of the concatenation of two vectors; pre-SSSE3 can use
11411 /// a PSRLDQ/PSLLDQ/POR pattern to get a similar effect. This routine will
11412 /// try to generically lower a vector shuffle through such a pattern. It
11413 /// does not check for the profitability of lowering either as PALIGNR or
11414 /// PSRLDQ/PSLLDQ/POR, only whether the mask is valid to lower in that form.
11415 /// This matches shuffle vectors that look like:
11416 ///
11417 ///   v8i16 [11, 12, 13, 14, 15, 0, 1, 2]
11418 ///
11419 /// Essentially it concatenates V1 and V2, shifts right by some number of
11420 /// elements, and takes the low elements as the result. Note that while this is
11421 /// specified as a *right shift* because x86 is little-endian, it is a *left
11422 /// rotate* of the vector lanes.
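///
/// For the v8i16 example above the element rotation is 3, which
/// matchShuffleAsByteRotate scales by the two bytes per element to a PALIGNR
/// byte rotation of 6.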
11423 static int matchShuffleAsByteRotate(MVT VT, SDValue &V1, SDValue &V2,
11424                                     ArrayRef<int> Mask) {
11425   // Don't accept any shuffles with zero elements.
11426   if (isAnyZero(Mask))
11427     return -1;
11428 
11429   // PALIGNR works on 128-bit lanes.
11430   SmallVector<int, 16> RepeatedMask;
11431   if (!is128BitLaneRepeatedShuffleMask(VT, Mask, RepeatedMask))
11432     return -1;
11433 
11434   int Rotation = matchShuffleAsElementRotate(V1, V2, RepeatedMask);
11435   if (Rotation <= 0)
11436     return -1;
11437 
11438   // PALIGNR rotates bytes, so we need to scale the
11439   // rotation based on how many bytes are in the vector lane.
11440   int NumElts = RepeatedMask.size();
11441   int Scale = 16 / NumElts;
11442   return Rotation * Scale;
11443 }
11444 
11445 static SDValue lowerShuffleAsByteRotate(const SDLoc &DL, MVT VT, SDValue V1,
11446                                         SDValue V2, ArrayRef<int> Mask,
11447                                         const X86Subtarget &Subtarget,
11448                                         SelectionDAG &DAG) {
11449   assert(!isNoopShuffleMask(Mask) && "We shouldn't lower no-op shuffles!");
11450 
11451   SDValue Lo = V1, Hi = V2;
11452   int ByteRotation = matchShuffleAsByteRotate(VT, Lo, Hi, Mask);
11453   if (ByteRotation <= 0)
11454     return SDValue();
11455 
11456   // Cast the inputs to an i8 vector of the correct length to match PALIGNR
11457   // or PSLLDQ/PSRLDQ.
11458   MVT ByteVT = MVT::getVectorVT(MVT::i8, VT.getSizeInBits() / 8);
11459   Lo = DAG.getBitcast(ByteVT, Lo);
11460   Hi = DAG.getBitcast(ByteVT, Hi);
11461 
11462   // SSSE3 targets can use the palignr instruction.
11463   if (Subtarget.hasSSSE3()) {
11464     assert((!VT.is512BitVector() || Subtarget.hasBWI()) &&
11465            "512-bit PALIGNR requires BWI instructions");
11466     return DAG.getBitcast(
11467         VT, DAG.getNode(X86ISD::PALIGNR, DL, ByteVT, Lo, Hi,
11468                         DAG.getTargetConstant(ByteRotation, DL, MVT::i8)));
11469   }
11470 
11471   assert(VT.is128BitVector() &&
11472          "Rotate-based lowering only supports 128-bit lowering!");
11473   assert(Mask.size() <= 16 &&
11474          "Can shuffle at most 16 bytes in a 128-bit vector!");
11475   assert(ByteVT == MVT::v16i8 &&
11476          "SSE2 rotate lowering only needed for v16i8!");
11477 
11478   // Default SSE2 implementation
11479   int LoByteShift = 16 - ByteRotation;
11480   int HiByteShift = ByteRotation;
11481 
11482   SDValue LoShift =
11483       DAG.getNode(X86ISD::VSHLDQ, DL, MVT::v16i8, Lo,
11484                   DAG.getTargetConstant(LoByteShift, DL, MVT::i8));
11485   SDValue HiShift =
11486       DAG.getNode(X86ISD::VSRLDQ, DL, MVT::v16i8, Hi,
11487                   DAG.getTargetConstant(HiByteShift, DL, MVT::i8));
11488   return DAG.getBitcast(VT,
11489                         DAG.getNode(ISD::OR, DL, MVT::v16i8, LoShift, HiShift));
11490 }
11491 
11492 /// Try to lower a vector shuffle as a dword/qword rotation.
11493 ///
11494 /// AVX512 has a VALIGND/VALIGNQ instructions that will do an arbitrary
11495 /// rotation of the concatenation of two vectors; This routine will
11496 /// try to generically lower a vector shuffle through such a pattern.
11497 ///
11498 /// Essentially it concatenates V1 and V2, shifts right by some number of
11499 /// elements, and takes the low elements as the result. Note that while this is
11500 /// specified as a *right shift* because x86 is little-endian, it is a *left
11501 /// rotate* of the vector lanes.
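///
/// For example, a v8i32 mask <3,4,5,6,7,8,9,10> is an element rotation of the
/// two concatenated inputs by 3 and would lower to a single VALIGND with
/// immediate 3.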
11502 static SDValue lowerShuffleAsVALIGN(const SDLoc &DL, MVT VT, SDValue V1,
11503                                     SDValue V2, ArrayRef<int> Mask,
11504                                     const APInt &Zeroable,
11505                                     const X86Subtarget &Subtarget,
11506                                     SelectionDAG &DAG) {
11507   assert((VT.getScalarType() == MVT::i32 || VT.getScalarType() == MVT::i64) &&
11508          "Only 32-bit and 64-bit elements are supported!");
11509 
11510   // 128/256-bit vectors are only supported with VLX.
11511   assert((Subtarget.hasVLX() || (!VT.is128BitVector() && !VT.is256BitVector()))
11512          && "VLX required for 128/256-bit vectors");
11513 
11514   SDValue Lo = V1, Hi = V2;
11515   int Rotation = matchShuffleAsElementRotate(Lo, Hi, Mask);
11516   if (0 < Rotation)
11517     return DAG.getNode(X86ISD::VALIGN, DL, VT, Lo, Hi,
11518                        DAG.getTargetConstant(Rotation, DL, MVT::i8));
11519 
11520   // See if we can use VALIGN as a cross-lane version of VSHLDQ/VSRLDQ.
11521   // TODO: Pull this out as a matchShuffleAsElementShift helper?
11522   // TODO: We can probably make this more aggressive and use shift-pairs like
11523   // lowerShuffleAsByteShiftMask.
11524   unsigned NumElts = Mask.size();
11525   unsigned ZeroLo = Zeroable.countr_one();
11526   unsigned ZeroHi = Zeroable.countl_one();
11527   assert((ZeroLo + ZeroHi) < NumElts && "Zeroable shuffle detected");
11528   if (!ZeroLo && !ZeroHi)
11529     return SDValue();
11530 
11531   if (ZeroLo) {
11532     SDValue Src = Mask[ZeroLo] < (int)NumElts ? V1 : V2;
11533     int Low = Mask[ZeroLo] < (int)NumElts ? 0 : NumElts;
11534     if (isSequentialOrUndefInRange(Mask, ZeroLo, NumElts - ZeroLo, Low))
11535       return DAG.getNode(X86ISD::VALIGN, DL, VT, Src,
11536                          getZeroVector(VT, Subtarget, DAG, DL),
11537                          DAG.getTargetConstant(NumElts - ZeroLo, DL, MVT::i8));
11538   }
11539 
11540   if (ZeroHi) {
11541     SDValue Src = Mask[0] < (int)NumElts ? V1 : V2;
11542     int Low = Mask[0] < (int)NumElts ? 0 : NumElts;
11543     if (isSequentialOrUndefInRange(Mask, 0, NumElts - ZeroHi, Low + ZeroHi))
11544       return DAG.getNode(X86ISD::VALIGN, DL, VT,
11545                          getZeroVector(VT, Subtarget, DAG, DL), Src,
11546                          DAG.getTargetConstant(ZeroHi, DL, MVT::i8));
11547   }
11548 
11549   return SDValue();
11550 }
11551 
11552 /// Try to lower a vector shuffle as a byte shift sequence.
11553 static SDValue lowerShuffleAsByteShiftMask(const SDLoc &DL, MVT VT, SDValue V1,
11554                                            SDValue V2, ArrayRef<int> Mask,
11555                                            const APInt &Zeroable,
11556                                            const X86Subtarget &Subtarget,
11557                                            SelectionDAG &DAG) {
11558   assert(!isNoopShuffleMask(Mask) && "We shouldn't lower no-op shuffles!");
11559   assert(VT.is128BitVector() && "Only 128-bit vectors supported");
11560 
11561   // We need a shuffle that has zeros at one/both ends and a sequential
11562   // shuffle from one source within.
11563   unsigned ZeroLo = Zeroable.countr_one();
11564   unsigned ZeroHi = Zeroable.countl_one();
11565   if (!ZeroLo && !ZeroHi)
11566     return SDValue();
11567 
11568   unsigned NumElts = Mask.size();
11569   unsigned Len = NumElts - (ZeroLo + ZeroHi);
11570   if (!isSequentialOrUndefInRange(Mask, ZeroLo, Len, Mask[ZeroLo]))
11571     return SDValue();
11572 
11573   unsigned Scale = VT.getScalarSizeInBits() / 8;
11574   ArrayRef<int> StubMask = Mask.slice(ZeroLo, Len);
11575   if (!isUndefOrInRange(StubMask, 0, NumElts) &&
11576       !isUndefOrInRange(StubMask, NumElts, 2 * NumElts))
11577     return SDValue();
11578 
11579   SDValue Res = Mask[ZeroLo] < (int)NumElts ? V1 : V2;
11580   Res = DAG.getBitcast(MVT::v16i8, Res);
11581 
11582   // Use VSHLDQ/VSRLDQ ops to zero the ends of a vector and leave an
11583   // inner sequential set of elements, possibly offset:
11584   // 01234567 --> zzzzzz01 --> 1zzzzzzz
11585   // 01234567 --> 4567zzzz --> zzzzz456
11586   // 01234567 --> z0123456 --> 3456zzzz --> zz3456zz
11587   if (ZeroLo == 0) {
11588     unsigned Shift = (NumElts - 1) - (Mask[ZeroLo + Len - 1] % NumElts);
11589     Res = DAG.getNode(X86ISD::VSHLDQ, DL, MVT::v16i8, Res,
11590                       DAG.getTargetConstant(Scale * Shift, DL, MVT::i8));
11591     Res = DAG.getNode(X86ISD::VSRLDQ, DL, MVT::v16i8, Res,
11592                       DAG.getTargetConstant(Scale * ZeroHi, DL, MVT::i8));
11593   } else if (ZeroHi == 0) {
11594     unsigned Shift = Mask[ZeroLo] % NumElts;
11595     Res = DAG.getNode(X86ISD::VSRLDQ, DL, MVT::v16i8, Res,
11596                       DAG.getTargetConstant(Scale * Shift, DL, MVT::i8));
11597     Res = DAG.getNode(X86ISD::VSHLDQ, DL, MVT::v16i8, Res,
11598                       DAG.getTargetConstant(Scale * ZeroLo, DL, MVT::i8));
11599   } else if (!Subtarget.hasSSSE3()) {
11600     // If we don't have PSHUFB then it's worth avoiding an AND constant mask
11601     // by performing 3 byte shifts. Shuffle combining can kick in above that.
11602     // TODO: There may be some cases where VSH{LR}DQ+PAND is still better.
11603     unsigned Shift = (NumElts - 1) - (Mask[ZeroLo + Len - 1] % NumElts);
11604     Res = DAG.getNode(X86ISD::VSHLDQ, DL, MVT::v16i8, Res,
11605                       DAG.getTargetConstant(Scale * Shift, DL, MVT::i8));
11606     Shift += Mask[ZeroLo] % NumElts;
11607     Res = DAG.getNode(X86ISD::VSRLDQ, DL, MVT::v16i8, Res,
11608                       DAG.getTargetConstant(Scale * Shift, DL, MVT::i8));
11609     Res = DAG.getNode(X86ISD::VSHLDQ, DL, MVT::v16i8, Res,
11610                       DAG.getTargetConstant(Scale * ZeroLo, DL, MVT::i8));
11611   } else
11612     return SDValue();
11613 
11614   return DAG.getBitcast(VT, Res);
11615 }
11616 
11617 /// Try to lower a vector shuffle as a bit shift (shifts in zeros).
11618 ///
11619 /// Attempts to match a shuffle mask against the PSLL(W/D/Q/DQ) and
11620 /// PSRL(W/D/Q/DQ) SSE2 and AVX2 logical bit-shift instructions. The function
11621 /// matches elements from one of the input vectors shuffled to the left or
11622 /// right with zeroable elements 'shifted in'. It handles both the strictly
11623 /// bit-wise element shifts and the byte shift across an entire 128-bit double
11624 /// quad word lane.
11625 ///
11626 /// PSHL : (little-endian) left bit shift.
11627 /// [ zz, 0, zz,  2 ]
11628 /// [ -1, 4, zz, -1 ]
11629 /// PSRL : (little-endian) right bit shift.
11630 /// [  1, zz,  3, zz]
11631 /// [ -1, -1,  7, zz]
11632 /// PSLLDQ : (little-endian) left byte shift
11633 /// [ zz,  0,  1,  2,  3,  4,  5,  6]
11634 /// [ zz, zz, -1, -1,  2,  3,  4, -1]
11635 /// [ zz, zz, zz, zz, zz, zz, -1,  1]
11636 /// PSRLDQ : (little-endian) right byte shift
11637 /// [  5, 6,  7, zz, zz, zz, zz, zz]
11638 /// [ -1, 5,  6,  7, zz, zz, zz, zz]
11639 /// [  1, 2, -1, -1, -1, -1, zz, zz]
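///
/// For example, the first PSHL mask above ([ zz, 0, zz, 2 ] for v4i32) would
/// be matched by widening to v2i64 and shifting each 64-bit element left by
/// 32 bits (Opcode = X86ISD::VSHLI, ShiftVT = v2i64, ShiftAmt = 32).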
11640 static int matchShuffleAsShift(MVT &ShiftVT, unsigned &Opcode,
11641                                unsigned ScalarSizeInBits, ArrayRef<int> Mask,
11642                                int MaskOffset, const APInt &Zeroable,
11643                                const X86Subtarget &Subtarget) {
11644   int Size = Mask.size();
11645   unsigned SizeInBits = Size * ScalarSizeInBits;
11646 
11647   auto CheckZeros = [&](int Shift, int Scale, bool Left) {
11648     for (int i = 0; i < Size; i += Scale)
11649       for (int j = 0; j < Shift; ++j)
11650         if (!Zeroable[i + j + (Left ? 0 : (Scale - Shift))])
11651           return false;
11652 
11653     return true;
11654   };
11655 
11656   auto MatchShift = [&](int Shift, int Scale, bool Left) {
11657     for (int i = 0; i != Size; i += Scale) {
11658       unsigned Pos = Left ? i + Shift : i;
11659       unsigned Low = Left ? i : i + Shift;
11660       unsigned Len = Scale - Shift;
11661       if (!isSequentialOrUndefInRange(Mask, Pos, Len, Low + MaskOffset))
11662         return -1;
11663     }
11664 
11665     int ShiftEltBits = ScalarSizeInBits * Scale;
11666     bool ByteShift = ShiftEltBits > 64;
11667     Opcode = Left ? (ByteShift ? X86ISD::VSHLDQ : X86ISD::VSHLI)
11668                   : (ByteShift ? X86ISD::VSRLDQ : X86ISD::VSRLI);
11669     int ShiftAmt = Shift * ScalarSizeInBits / (ByteShift ? 8 : 1);
11670 
11671     // Normalize the scale for byte shifts to still produce an i64 element
11672     // type.
11673     Scale = ByteShift ? Scale / 2 : Scale;
11674 
11675     // We need to round trip through the appropriate type for the shift.
11676     MVT ShiftSVT = MVT::getIntegerVT(ScalarSizeInBits * Scale);
11677     ShiftVT = ByteShift ? MVT::getVectorVT(MVT::i8, SizeInBits / 8)
11678                         : MVT::getVectorVT(ShiftSVT, Size / Scale);
11679     return (int)ShiftAmt;
11680   };
11681 
11682   // SSE/AVX supports logical shifts up to 64-bit integers - so we can just
11683   // keep doubling the size of the integer elements up to that. We can
11684   // then shift the elements of the integer vector by whole multiples of
11685   // their width within the elements of the larger integer vector. Test each
11686   // multiple to see if we can find a match with the moved element indices
11687   // and that the shifted in elements are all zeroable.
11688   unsigned MaxWidth = ((SizeInBits == 512) && !Subtarget.hasBWI() ? 64 : 128);
11689   for (int Scale = 2; Scale * ScalarSizeInBits <= MaxWidth; Scale *= 2)
11690     for (int Shift = 1; Shift != Scale; ++Shift)
11691       for (bool Left : {true, false})
11692         if (CheckZeros(Shift, Scale, Left)) {
11693           int ShiftAmt = MatchShift(Shift, Scale, Left);
11694           if (0 < ShiftAmt)
11695             return ShiftAmt;
11696         }
11697 
11698   // no match
11699   return -1;
11700 }
11701 
11702 static SDValue lowerShuffleAsShift(const SDLoc &DL, MVT VT, SDValue V1,
11703                                    SDValue V2, ArrayRef<int> Mask,
11704                                    const APInt &Zeroable,
11705                                    const X86Subtarget &Subtarget,
11706                                    SelectionDAG &DAG, bool BitwiseOnly) {
11707   int Size = Mask.size();
11708   assert(Size == (int)VT.getVectorNumElements() && "Unexpected mask size");
11709 
11710   MVT ShiftVT;
11711   SDValue V = V1;
11712   unsigned Opcode;
11713 
11714   // Try to match shuffle against V1 shift.
11715   int ShiftAmt = matchShuffleAsShift(ShiftVT, Opcode, VT.getScalarSizeInBits(),
11716                                      Mask, 0, Zeroable, Subtarget);
11717 
11718   // If V1 failed, try to match shuffle against V2 shift.
11719   if (ShiftAmt < 0) {
11720     ShiftAmt = matchShuffleAsShift(ShiftVT, Opcode, VT.getScalarSizeInBits(),
11721                                    Mask, Size, Zeroable, Subtarget);
11722     V = V2;
11723   }
11724 
11725   if (ShiftAmt < 0)
11726     return SDValue();
11727 
11728   if (BitwiseOnly && (Opcode == X86ISD::VSHLDQ || Opcode == X86ISD::VSRLDQ))
11729     return SDValue();
11730 
11731   assert(DAG.getTargetLoweringInfo().isTypeLegal(ShiftVT) &&
11732          "Illegal integer vector type");
11733   V = DAG.getBitcast(ShiftVT, V);
11734   V = DAG.getNode(Opcode, DL, ShiftVT, V,
11735                   DAG.getTargetConstant(ShiftAmt, DL, MVT::i8));
11736   return DAG.getBitcast(VT, V);
11737 }
11738 
11739 // EXTRQ: Extract Len elements from lower half of source, starting at Idx.
11740 // Remainder of lower half result is zero and upper half is all undef.
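// For example, a v8i16 mask <2,3,zz,zz,u,u,u,u> would match with Len = 2 and
// Idx = 2, i.e. BitLen = 32 and BitIdx = 32.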
11741 static bool matchShuffleAsEXTRQ(MVT VT, SDValue &V1, SDValue &V2,
11742                                 ArrayRef<int> Mask, uint64_t &BitLen,
11743                                 uint64_t &BitIdx, const APInt &Zeroable) {
11744   int Size = Mask.size();
11745   int HalfSize = Size / 2;
11746   assert(Size == (int)VT.getVectorNumElements() && "Unexpected mask size");
11747   assert(!Zeroable.isAllOnes() && "Fully zeroable shuffle mask");
11748 
11749   // Upper half must be undefined.
11750   if (!isUndefUpperHalf(Mask))
11751     return false;
11752 
11753   // Determine the extraction length from the part of the
11754   // lower half that isn't zeroable.
11755   int Len = HalfSize;
11756   for (; Len > 0; --Len)
11757     if (!Zeroable[Len - 1])
11758       break;
11759   assert(Len > 0 && "Zeroable shuffle mask");
11760 
11761   // Attempt to match first Len sequential elements from the lower half.
11762   SDValue Src;
11763   int Idx = -1;
11764   for (int i = 0; i != Len; ++i) {
11765     int M = Mask[i];
11766     if (M == SM_SentinelUndef)
11767       continue;
11768     SDValue &V = (M < Size ? V1 : V2);
11769     M = M % Size;
11770 
11771     // The extracted elements must start at a valid index and all mask
11772     // elements must be in the lower half.
11773     if (i > M || M >= HalfSize)
11774       return false;
11775 
11776     if (Idx < 0 || (Src == V && Idx == (M - i))) {
11777       Src = V;
11778       Idx = M - i;
11779       continue;
11780     }
11781     return false;
11782   }
11783 
11784   if (!Src || Idx < 0)
11785     return false;
11786 
11787   assert((Idx + Len) <= HalfSize && "Illegal extraction mask");
11788   BitLen = (Len * VT.getScalarSizeInBits()) & 0x3f;
11789   BitIdx = (Idx * VT.getScalarSizeInBits()) & 0x3f;
11790   V1 = Src;
11791   return true;
11792 }
11793 
11794 // INSERTQ: Extract lowest Len elements from lower half of second source and
11795 // insert over first source, starting at Idx.
11796 // { A[0], .., A[Idx-1], B[0], .., B[Len-1], A[Idx+Len], .., UNDEF, ... }
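//
// For example, a v8i16 mask <0,8,9,3,U,U,U,U> matches with Idx = 1 and
// Len = 2 (BitIdx = 16, BitLen = 32): elements 0 and 3 are kept from the
// first source and two elements from the low half of the second source are
// inserted at element 1.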
11797 static bool matchShuffleAsINSERTQ(MVT VT, SDValue &V1, SDValue &V2,
11798                                   ArrayRef<int> Mask, uint64_t &BitLen,
11799                                   uint64_t &BitIdx) {
11800   int Size = Mask.size();
11801   int HalfSize = Size / 2;
11802   assert(Size == (int)VT.getVectorNumElements() && "Unexpected mask size");
11803 
11804   // Upper half must be undefined.
11805   if (!isUndefUpperHalf(Mask))
11806     return false;
11807 
11808   for (int Idx = 0; Idx != HalfSize; ++Idx) {
11809     SDValue Base;
11810 
11811     // Attempt to match first source from mask before insertion point.
11812     if (isUndefInRange(Mask, 0, Idx)) {
11813       /* EMPTY */
11814     } else if (isSequentialOrUndefInRange(Mask, 0, Idx, 0)) {
11815       Base = V1;
11816     } else if (isSequentialOrUndefInRange(Mask, 0, Idx, Size)) {
11817       Base = V2;
11818     } else {
11819       continue;
11820     }
11821 
11822     // Extend the extraction length looking to match both the insertion of
11823     // the second source and the remaining elements of the first.
11824     for (int Hi = Idx + 1; Hi <= HalfSize; ++Hi) {
11825       SDValue Insert;
11826       int Len = Hi - Idx;
11827 
11828       // Match insertion.
11829       if (isSequentialOrUndefInRange(Mask, Idx, Len, 0)) {
11830         Insert = V1;
11831       } else if (isSequentialOrUndefInRange(Mask, Idx, Len, Size)) {
11832         Insert = V2;
11833       } else {
11834         continue;
11835       }
11836 
11837       // Match the remaining elements of the lower half.
11838       if (isUndefInRange(Mask, Hi, HalfSize - Hi)) {
11839         /* EMPTY */
11840       } else if ((!Base || (Base == V1)) &&
11841                  isSequentialOrUndefInRange(Mask, Hi, HalfSize - Hi, Hi)) {
11842         Base = V1;
11843       } else if ((!Base || (Base == V2)) &&
11844                  isSequentialOrUndefInRange(Mask, Hi, HalfSize - Hi,
11845                                             Size + Hi)) {
11846         Base = V2;
11847       } else {
11848         continue;
11849       }
11850 
11851       BitLen = (Len * VT.getScalarSizeInBits()) & 0x3f;
11852       BitIdx = (Idx * VT.getScalarSizeInBits()) & 0x3f;
11853       V1 = Base;
11854       V2 = Insert;
11855       return true;
11856     }
11857   }
11858 
11859   return false;
11860 }
11861 
11862 /// Try to lower a vector shuffle using SSE4a EXTRQ/INSERTQ.
11863 static SDValue lowerShuffleWithSSE4A(const SDLoc &DL, MVT VT, SDValue V1,
11864                                      SDValue V2, ArrayRef<int> Mask,
11865                                      const APInt &Zeroable, SelectionDAG &DAG) {
11866   uint64_t BitLen, BitIdx;
11867   if (matchShuffleAsEXTRQ(VT, V1, V2, Mask, BitLen, BitIdx, Zeroable))
11868     return DAG.getNode(X86ISD::EXTRQI, DL, VT, V1,
11869                        DAG.getTargetConstant(BitLen, DL, MVT::i8),
11870                        DAG.getTargetConstant(BitIdx, DL, MVT::i8));
11871 
11872   if (matchShuffleAsINSERTQ(VT, V1, V2, Mask, BitLen, BitIdx))
11873     return DAG.getNode(X86ISD::INSERTQI, DL, VT, V1 ? V1 : DAG.getUNDEF(VT),
11874                        V2 ? V2 : DAG.getUNDEF(VT),
11875                        DAG.getTargetConstant(BitLen, DL, MVT::i8),
11876                        DAG.getTargetConstant(BitIdx, DL, MVT::i8));
11877 
11878   return SDValue();
11879 }
11880 
11881 /// Lower a vector shuffle as a zero or any extension.
11882 ///
11883 /// Given a specific number of elements, element bit width, and extension
11884 /// stride, produce either a zero or any extension based on the available
11885 /// features of the subtarget. The extended elements are consecutive and
11886 /// can start from an offsetted element index in the input; to avoid excess
11887 /// shuffling, the offset must either be in the bottom lane or at the start of
11888 /// a higher lane. All extended elements must be from
11889 /// the same lane.
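///
/// For example, with Scale = 2, Offset = 0 and AnyExt = false on a v16i8
/// input (a mask of the form <0,Z,1,Z,...,7,Z>), SSE4.1 targets emit a
/// ZERO_EXTEND_VECTOR_INREG to v8i16 (PMOVZXBW), while older targets fall
/// back to unpacking the input with a zero vector.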
11890 static SDValue lowerShuffleAsSpecificZeroOrAnyExtend(
11891     const SDLoc &DL, MVT VT, int Scale, int Offset, bool AnyExt, SDValue InputV,
11892     ArrayRef<int> Mask, const X86Subtarget &Subtarget, SelectionDAG &DAG) {
11893   assert(Scale > 1 && "Need a scale to extend.");
11894   int EltBits = VT.getScalarSizeInBits();
11895   int NumElements = VT.getVectorNumElements();
11896   int NumEltsPerLane = 128 / EltBits;
11897   int OffsetLane = Offset / NumEltsPerLane;
11898   assert((EltBits == 8 || EltBits == 16 || EltBits == 32) &&
11899          "Only 8, 16, and 32 bit elements can be extended.");
11900   assert(Scale * EltBits <= 64 && "Cannot zero extend past 64 bits.");
11901   assert(0 <= Offset && "Extension offset must be positive.");
11902   assert((Offset < NumEltsPerLane || Offset % NumEltsPerLane == 0) &&
11903          "Extension offset must be in the first lane or start an upper lane.");
11904 
11905   // Check that an index is in same lane as the base offset.
11906   auto SafeOffset = [&](int Idx) {
11907     return OffsetLane == (Idx / NumEltsPerLane);
11908   };
11909 
11910   // Shift along an input so that the offset base moves to the first element.
11911   auto ShuffleOffset = [&](SDValue V) {
11912     if (!Offset)
11913       return V;
11914 
11915     SmallVector<int, 8> ShMask((unsigned)NumElements, -1);
11916     for (int i = 0; i * Scale < NumElements; ++i) {
11917       int SrcIdx = i + Offset;
11918       ShMask[i] = SafeOffset(SrcIdx) ? SrcIdx : -1;
11919     }
11920     return DAG.getVectorShuffle(VT, DL, V, DAG.getUNDEF(VT), ShMask);
11921   };
11922 
11923   // Found a valid a/zext mask! Try various lowering strategies based on the
11924   // input type and available ISA extensions.
11925   if (Subtarget.hasSSE41()) {
11926     // Not worth offsetting 128-bit vectors if scale == 2, a pattern using
11927     // PUNPCK will catch this in a later shuffle match.
11928     if (Offset && Scale == 2 && VT.is128BitVector())
11929       return SDValue();
11930     MVT ExtVT = MVT::getVectorVT(MVT::getIntegerVT(EltBits * Scale),
11931                                  NumElements / Scale);
11932     InputV = DAG.getBitcast(VT, InputV);
11933     InputV = ShuffleOffset(InputV);
11934     InputV = getEXTEND_VECTOR_INREG(AnyExt ? ISD::ANY_EXTEND : ISD::ZERO_EXTEND,
11935                                     DL, ExtVT, InputV, DAG);
11936     return DAG.getBitcast(VT, InputV);
11937   }
11938 
11939   assert(VT.is128BitVector() && "Only 128-bit vectors can be extended.");
11940   InputV = DAG.getBitcast(VT, InputV);
11941 
11942   // For any extends we can cheat for larger element sizes and use shuffle
11943   // instructions that can fold with a load and/or copy.
11944   if (AnyExt && EltBits == 32) {
11945     int PSHUFDMask[4] = {Offset, -1, SafeOffset(Offset + 1) ? Offset + 1 : -1,
11946                          -1};
11947     return DAG.getBitcast(
11948         VT, DAG.getNode(X86ISD::PSHUFD, DL, MVT::v4i32,
11949                         DAG.getBitcast(MVT::v4i32, InputV),
11950                         getV4X86ShuffleImm8ForMask(PSHUFDMask, DL, DAG)));
11951   }
11952   if (AnyExt && EltBits == 16 && Scale > 2) {
11953     int PSHUFDMask[4] = {Offset / 2, -1,
11954                          SafeOffset(Offset + 1) ? (Offset + 1) / 2 : -1, -1};
11955     InputV = DAG.getNode(X86ISD::PSHUFD, DL, MVT::v4i32,
11956                          DAG.getBitcast(MVT::v4i32, InputV),
11957                          getV4X86ShuffleImm8ForMask(PSHUFDMask, DL, DAG));
11958     int PSHUFWMask[4] = {1, -1, -1, -1};
11959     unsigned OddEvenOp = (Offset & 1) ? X86ISD::PSHUFLW : X86ISD::PSHUFHW;
11960     return DAG.getBitcast(
11961         VT, DAG.getNode(OddEvenOp, DL, MVT::v8i16,
11962                         DAG.getBitcast(MVT::v8i16, InputV),
11963                         getV4X86ShuffleImm8ForMask(PSHUFWMask, DL, DAG)));
11964   }
11965 
11966   // The SSE4A EXTRQ instruction can efficiently extend the first 2 lanes
11967   // to 64-bits.
11968   if ((Scale * EltBits) == 64 && EltBits < 32 && Subtarget.hasSSE4A()) {
11969     assert(NumElements == (int)Mask.size() && "Unexpected shuffle mask size!");
11970     assert(VT.is128BitVector() && "Unexpected vector width!");
11971 
11972     int LoIdx = Offset * EltBits;
11973     SDValue Lo = DAG.getBitcast(
11974         MVT::v2i64, DAG.getNode(X86ISD::EXTRQI, DL, VT, InputV,
11975                                 DAG.getTargetConstant(EltBits, DL, MVT::i8),
11976                                 DAG.getTargetConstant(LoIdx, DL, MVT::i8)));
11977 
11978     if (isUndefUpperHalf(Mask) || !SafeOffset(Offset + 1))
11979       return DAG.getBitcast(VT, Lo);
11980 
11981     int HiIdx = (Offset + 1) * EltBits;
11982     SDValue Hi = DAG.getBitcast(
11983         MVT::v2i64, DAG.getNode(X86ISD::EXTRQI, DL, VT, InputV,
11984                                 DAG.getTargetConstant(EltBits, DL, MVT::i8),
11985                                 DAG.getTargetConstant(HiIdx, DL, MVT::i8)));
11986     return DAG.getBitcast(VT,
11987                           DAG.getNode(X86ISD::UNPCKL, DL, MVT::v2i64, Lo, Hi));
11988   }
11989 
11990   // If this would require more than 2 unpack instructions to expand, use
11991   // pshufb when available. We can only use more than 2 unpack instructions
11992   // when zero extending i8 elements which also makes it easier to use pshufb.
11993   if (Scale > 4 && EltBits == 8 && Subtarget.hasSSSE3()) {
11994     assert(NumElements == 16 && "Unexpected byte vector width!");
11995     SDValue PSHUFBMask[16];
11996     for (int i = 0; i < 16; ++i) {
11997       int Idx = Offset + (i / Scale);
11998       if ((i % Scale == 0 && SafeOffset(Idx))) {
11999         PSHUFBMask[i] = DAG.getConstant(Idx, DL, MVT::i8);
12000         continue;
12001       }
12002       PSHUFBMask[i] =
12003           AnyExt ? DAG.getUNDEF(MVT::i8) : DAG.getConstant(0x80, DL, MVT::i8);
12004     }
12005     InputV = DAG.getBitcast(MVT::v16i8, InputV);
12006     return DAG.getBitcast(
12007         VT, DAG.getNode(X86ISD::PSHUFB, DL, MVT::v16i8, InputV,
12008                         DAG.getBuildVector(MVT::v16i8, DL, PSHUFBMask)));
12009   }
12010 
12011   // If we are extending from an offset, ensure we start on a boundary that
12012   // we can unpack from.
12013   int AlignToUnpack = Offset % (NumElements / Scale);
12014   if (AlignToUnpack) {
12015     SmallVector<int, 8> ShMask((unsigned)NumElements, -1);
12016     for (int i = AlignToUnpack; i < NumElements; ++i)
12017       ShMask[i - AlignToUnpack] = i;
12018     InputV = DAG.getVectorShuffle(VT, DL, InputV, DAG.getUNDEF(VT), ShMask);
12019     Offset -= AlignToUnpack;
12020   }
12021 
12022   // Otherwise emit a sequence of unpacks.
12023   do {
12024     unsigned UnpackLoHi = X86ISD::UNPCKL;
12025     if (Offset >= (NumElements / 2)) {
12026       UnpackLoHi = X86ISD::UNPCKH;
12027       Offset -= (NumElements / 2);
12028     }
12029 
12030     MVT InputVT = MVT::getVectorVT(MVT::getIntegerVT(EltBits), NumElements);
12031     SDValue Ext = AnyExt ? DAG.getUNDEF(InputVT)
12032                          : getZeroVector(InputVT, Subtarget, DAG, DL);
12033     InputV = DAG.getBitcast(InputVT, InputV);
12034     InputV = DAG.getNode(UnpackLoHi, DL, InputVT, InputV, Ext);
12035     Scale /= 2;
12036     EltBits *= 2;
12037     NumElements /= 2;
12038   } while (Scale > 1);
12039   return DAG.getBitcast(VT, InputV);
12040 }
12041 
12042 /// Try to lower a vector shuffle as a zero extension on any microarch.
12043 ///
12044 /// This routine will try to do everything in its power to cleverly lower
12045 /// a shuffle which happens to match the pattern of a zero extend. It doesn't
12046 /// check for the profitability of this lowering,  it tries to aggressively
12047 /// match this pattern. It will use all of the micro-architectural details it
12048 /// can to emit an efficient lowering. It handles both blends with all-zero
12049 /// inputs to explicitly zero-extend and undef-lanes (sometimes undef due to
12050 /// masking out later).
12051 ///
12052 /// The reason we have dedicated lowering for zext-style shuffles is that they
12053 /// are both incredibly common and often quite performance sensitive.
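///
/// For example, the v4i32 mask <0,Z,1,Z> (Z = zeroable) is matched with
/// Scale = 2 and Offset = 0 and lowered as a zero extension of the low two
/// i32 elements to i64 (PMOVZXDQ on SSE4.1, an unpack with zeros otherwise),
/// while <0,1,Z,Z> is caught by the MOVQ special case at the end.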
12054 static SDValue lowerShuffleAsZeroOrAnyExtend(
12055     const SDLoc &DL, MVT VT, SDValue V1, SDValue V2, ArrayRef<int> Mask,
12056     const APInt &Zeroable, const X86Subtarget &Subtarget,
12057     SelectionDAG &DAG) {
12058   int Bits = VT.getSizeInBits();
12059   int NumLanes = Bits / 128;
12060   int NumElements = VT.getVectorNumElements();
12061   int NumEltsPerLane = NumElements / NumLanes;
12062   assert(VT.getScalarSizeInBits() <= 32 &&
12063          "Exceeds 32-bit integer zero extension limit");
12064   assert((int)Mask.size() == NumElements && "Unexpected shuffle mask size");
12065 
12066   // Define a helper function to check a particular ext-scale and lower to it if
12067   // valid.
12068   auto Lower = [&](int Scale) -> SDValue {
12069     SDValue InputV;
12070     bool AnyExt = true;
12071     int Offset = 0;
12072     int Matches = 0;
12073     for (int i = 0; i < NumElements; ++i) {
12074       int M = Mask[i];
12075       if (M < 0)
12076         continue; // Valid anywhere but doesn't tell us anything.
12077       if (i % Scale != 0) {
12078         // Each of the extended elements need to be zeroable.
12079         if (!Zeroable[i])
12080           return SDValue();
12081 
12082         // We are no longer in the anyext case.
12083         AnyExt = false;
12084         continue;
12085       }
12086 
12087       // The base elements need to be consecutive indices into the same
12088       // input vector.
12089       SDValue V = M < NumElements ? V1 : V2;
12090       M = M % NumElements;
12091       if (!InputV) {
12092         InputV = V;
12093         Offset = M - (i / Scale);
12094       } else if (InputV != V)
12095         return SDValue(); // Flip-flopping inputs.
12096 
12097       // Offset must start in the lowest 128-bit lane or at the start of an
12098       // upper lane.
12099       // FIXME: Is it ever worth allowing a negative base offset?
12100       if (!((0 <= Offset && Offset < NumEltsPerLane) ||
12101             (Offset % NumEltsPerLane) == 0))
12102         return SDValue();
12103 
12104       // If we are offsetting, all referenced entries must come from the same
12105       // lane.
12106       if (Offset && (Offset / NumEltsPerLane) != (M / NumEltsPerLane))
12107         return SDValue();
12108 
12109       if ((M % NumElements) != (Offset + (i / Scale)))
12110         return SDValue(); // Non-consecutive strided elements.
12111       Matches++;
12112     }
12113 
12114     // If we fail to find an input, we have a zero-shuffle which should always
12115     // have already been handled.
12116     // FIXME: Maybe handle this here in case during blending we end up with one?
12117     if (!InputV)
12118       return SDValue();
12119 
12120     // If we are offsetting, don't extend if we only match a single input; we
12121     // can always do better by using a basic PSHUF or PUNPCK.
12122     if (Offset != 0 && Matches < 2)
12123       return SDValue();
12124 
12125     return lowerShuffleAsSpecificZeroOrAnyExtend(DL, VT, Scale, Offset, AnyExt,
12126                                                  InputV, Mask, Subtarget, DAG);
12127   };
12128 
12129   // The widest scale possible for extending is to a 64-bit integer.
12130   assert(Bits % 64 == 0 &&
12131          "The number of bits in a vector must be divisible by 64 on x86!");
12132   int NumExtElements = Bits / 64;
12133 
12134   // Each iteration, try extending the elements half as much, but into twice as
12135   // many elements.
12136   for (; NumExtElements < NumElements; NumExtElements *= 2) {
12137     assert(NumElements % NumExtElements == 0 &&
12138            "The input vector size must be divisible by the extended size.");
12139     if (SDValue V = Lower(NumElements / NumExtElements))
12140       return V;
12141   }
12142 
12143   // General extends failed, but 128-bit vectors may be able to use MOVQ.
12144   if (Bits != 128)
12145     return SDValue();
12146 
12147   // Returns one of the source operands if the shuffle can be reduced to a
12148   // MOVQ, copying the lower 64-bits and zero-extending to the upper 64-bits.
12149   auto CanZExtLowHalf = [&]() {
12150     for (int i = NumElements / 2; i != NumElements; ++i)
12151       if (!Zeroable[i])
12152         return SDValue();
12153     if (isSequentialOrUndefInRange(Mask, 0, NumElements / 2, 0))
12154       return V1;
12155     if (isSequentialOrUndefInRange(Mask, 0, NumElements / 2, NumElements))
12156       return V2;
12157     return SDValue();
12158   };
12159 
12160   if (SDValue V = CanZExtLowHalf()) {
12161     V = DAG.getBitcast(MVT::v2i64, V);
12162     V = DAG.getNode(X86ISD::VZEXT_MOVL, DL, MVT::v2i64, V);
12163     return DAG.getBitcast(VT, V);
12164   }
12165 
12166   // No viable ext lowering found.
12167   return SDValue();
12168 }
12169 
12170 /// Try to get a scalar value for a specific element of a vector.
12171 ///
12172 /// Looks through BUILD_VECTOR and SCALAR_TO_VECTOR nodes to find a scalar.
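///
/// For example, if V is (v4i32 build_vector a, b, c, d) and Idx is 2, this
/// returns c (bitcast to the element type if necessary); for a
/// SCALAR_TO_VECTOR node only Idx == 0 can be resolved.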
12173 static SDValue getScalarValueForVectorElement(SDValue V, int Idx,
12174                                               SelectionDAG &DAG) {
12175   MVT VT = V.getSimpleValueType();
12176   MVT EltVT = VT.getVectorElementType();
12177   V = peekThroughBitcasts(V);
12178 
12179   // If the bitcasts shift the element size, we can't extract an equivalent
12180   // element from it.
12181   MVT NewVT = V.getSimpleValueType();
12182   if (!NewVT.isVector() || NewVT.getScalarSizeInBits() != VT.getScalarSizeInBits())
12183     return SDValue();
12184 
12185   if (V.getOpcode() == ISD::BUILD_VECTOR ||
12186       (Idx == 0 && V.getOpcode() == ISD::SCALAR_TO_VECTOR)) {
12187     // Ensure the scalar operand is the same size as the destination.
12188     // FIXME: Add support for scalar truncation where possible.
12189     SDValue S = V.getOperand(Idx);
12190     if (EltVT.getSizeInBits() == S.getSimpleValueType().getSizeInBits())
12191       return DAG.getBitcast(EltVT, S);
12192   }
12193 
12194   return SDValue();
12195 }
12196 
12197 /// Helper to test for a load that can be folded with x86 shuffles.
12198 ///
12199 /// This is particularly important because the set of instructions varies
12200 /// significantly based on whether the operand is a load or not.
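///
/// For example, a single-use, non-extending load (possibly behind a one-use
/// bitcast) qualifies; an extending load or a value with multiple uses does
/// not.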
12201 static bool isShuffleFoldableLoad(SDValue V) {
12202   return V->hasOneUse() &&
12203          ISD::isNON_EXTLoad(peekThroughOneUseBitcasts(V).getNode());
12204 }
12205 
12206 template<typename T>
12207 static bool isSoftF16(T VT, const X86Subtarget &Subtarget) {
12208   T EltVT = VT.getScalarType();
12209   return EltVT == MVT::bf16 || (EltVT == MVT::f16 && !Subtarget.hasFP16());
12210 }
12211 
12212 /// Try to lower insertion of a single element into a zero vector.
12213 ///
12214 /// This is a common pattern that we have especially efficient patterns to lower
12215 /// across all subtarget feature sets.
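///
/// For example, a v4f32 mask <4,Z,Z,Z> where V2 comes from a
/// SCALAR_TO_VECTOR node lowers to a single VZEXT_MOVL (a MOVSS-style
/// zeroing move), and a v2f64 mask <2,1> with a non-zeroable V1 lowers to
/// MOVSD.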
12216 static SDValue lowerShuffleAsElementInsertion(
12217     const SDLoc &DL, MVT VT, SDValue V1, SDValue V2, ArrayRef<int> Mask,
12218     const APInt &Zeroable, const X86Subtarget &Subtarget,
12219     SelectionDAG &DAG) {
12220   MVT ExtVT = VT;
12221   MVT EltVT = VT.getVectorElementType();
12222   unsigned NumElts = VT.getVectorNumElements();
12223   unsigned EltBits = VT.getScalarSizeInBits();
12224 
12225   if (isSoftF16(EltVT, Subtarget))
12226     return SDValue();
12227 
12228   int V2Index =
12229       find_if(Mask, [&Mask](int M) { return M >= (int)Mask.size(); }) -
12230       Mask.begin();
12231   bool IsV1Constant = getTargetConstantFromNode(V1) != nullptr;
12232   bool IsV1Zeroable = true;
12233   for (int i = 0, Size = Mask.size(); i < Size; ++i)
12234     if (i != V2Index && !Zeroable[i]) {
12235       IsV1Zeroable = false;
12236       break;
12237     }
12238 
12239   // Bail if a non-zero V1 isn't used in place.
12240   if (!IsV1Zeroable) {
12241     SmallVector<int, 8> V1Mask(Mask);
12242     V1Mask[V2Index] = -1;
12243     if (!isNoopShuffleMask(V1Mask))
12244       return SDValue();
12245   }
12246 
12247   // Check for a single input from a SCALAR_TO_VECTOR node.
12248   // FIXME: All of this should be canonicalized into INSERT_VECTOR_ELT and
12249   // all the smarts here sunk into that routine. However, the current
12250   // lowering of BUILD_VECTOR makes that nearly impossible until the old
12251   // vector shuffle lowering is dead.
12252   SDValue V2S = getScalarValueForVectorElement(V2, Mask[V2Index] - Mask.size(),
12253                                                DAG);
12254   if (V2S && DAG.getTargetLoweringInfo().isTypeLegal(V2S.getValueType())) {
12255     // We need to zext the scalar if it is smaller than an i32.
12256     V2S = DAG.getBitcast(EltVT, V2S);
12257     if (EltVT == MVT::i8 || (EltVT == MVT::i16 && !Subtarget.hasFP16())) {
12258       // Using zext to expand a narrow element won't work for non-zero
12259       // insertions. But we can use a masked constant vector if we're
12260       // inserting V2 into the bottom of V1.
12261       if (!IsV1Zeroable && !(IsV1Constant && V2Index == 0))
12262         return SDValue();
12263 
12264       // Zero-extend directly to i32.
12265       ExtVT = MVT::getVectorVT(MVT::i32, ExtVT.getSizeInBits() / 32);
12266       V2S = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, V2S);
12267 
12268       // If we're inserting into a constant, mask off the inserted index
12269       // and OR with the zero-extended scalar.
12270       if (!IsV1Zeroable) {
12271         SmallVector<APInt> Bits(NumElts, APInt::getAllOnes(EltBits));
12272         Bits[V2Index] = APInt::getZero(EltBits);
12273         SDValue BitMask = getConstVector(Bits, VT, DAG, DL);
12274         V1 = DAG.getNode(ISD::AND, DL, VT, V1, BitMask);
12275         V2 = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, ExtVT, V2S);
12276         V2 = DAG.getBitcast(VT, DAG.getNode(X86ISD::VZEXT_MOVL, DL, ExtVT, V2));
12277         return DAG.getNode(ISD::OR, DL, VT, V1, V2);
12278       }
12279     }
12280     V2 = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, ExtVT, V2S);
12281   } else if (Mask[V2Index] != (int)Mask.size() || EltVT == MVT::i8 ||
12282              EltVT == MVT::i16) {
12283     // Either not inserting from the low element of the input or the input
12284     // element size is too small to use VZEXT_MOVL to clear the high bits.
12285     return SDValue();
12286   }
12287 
12288   if (!IsV1Zeroable) {
12289     // If V1 can't be treated as a zero vector we have fewer options to lower
12290     // this. We can't support integer vectors or non-zero targets cheaply.
12291     assert(VT == ExtVT && "Cannot change extended type when non-zeroable!");
12292     if (!VT.isFloatingPoint() || V2Index != 0)
12293       return SDValue();
12294     if (!VT.is128BitVector())
12295       return SDValue();
12296 
12297     // Otherwise, use MOVSD, MOVSS or MOVSH.
12298     unsigned MovOpc = 0;
12299     if (EltVT == MVT::f16)
12300       MovOpc = X86ISD::MOVSH;
12301     else if (EltVT == MVT::f32)
12302       MovOpc = X86ISD::MOVSS;
12303     else if (EltVT == MVT::f64)
12304       MovOpc = X86ISD::MOVSD;
12305     else
12306       llvm_unreachable("Unsupported floating point element type to handle!");
12307     return DAG.getNode(MovOpc, DL, ExtVT, V1, V2);
12308   }
12309 
12310   // This lowering only works for the low element with floating point vectors.
12311   if (VT.isFloatingPoint() && V2Index != 0)
12312     return SDValue();
12313 
12314   V2 = DAG.getNode(X86ISD::VZEXT_MOVL, DL, ExtVT, V2);
12315   if (ExtVT != VT)
12316     V2 = DAG.getBitcast(VT, V2);
12317 
12318   if (V2Index != 0) {
12319     // If we have 4 or fewer lanes we can cheaply shuffle the element into
12320     // the desired position. Otherwise it is more efficient to do a vector
12321     // shift left. We know that we can do a vector shift left because all
12322     // the inputs are zero.
12323     if (VT.isFloatingPoint() || NumElts <= 4) {
12324       SmallVector<int, 4> V2Shuffle(Mask.size(), 1);
12325       V2Shuffle[V2Index] = 0;
12326       V2 = DAG.getVectorShuffle(VT, DL, V2, DAG.getUNDEF(VT), V2Shuffle);
12327     } else {
12328       V2 = DAG.getBitcast(MVT::v16i8, V2);
12329       V2 = DAG.getNode(
12330           X86ISD::VSHLDQ, DL, MVT::v16i8, V2,
12331           DAG.getTargetConstant(V2Index * EltBits / 8, DL, MVT::i8));
12332       V2 = DAG.getBitcast(VT, V2);
12333     }
12334   }
12335   return V2;
12336 }
12337 
12338 /// Try to lower broadcast of a single - truncated - integer element,
12339 /// coming from a scalar_to_vector/build_vector node \p V0 with larger elements.
12340 ///
12341 /// This assumes we have AVX2.
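///
/// For example, broadcasting v4i32 element 1 of a (v2i64 build_vector X, Y)
/// source shifts X right by 32 bits, truncates the result to i32 and then
/// emits a VBROADCAST of that scalar, so the shift/truncate can hopefully
/// fold away with a load.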
12342 static SDValue lowerShuffleAsTruncBroadcast(const SDLoc &DL, MVT VT, SDValue V0,
12343                                             int BroadcastIdx,
12344                                             const X86Subtarget &Subtarget,
12345                                             SelectionDAG &DAG) {
12346   assert(Subtarget.hasAVX2() &&
12347          "We can only lower integer broadcasts with AVX2!");
12348 
12349   MVT EltVT = VT.getVectorElementType();
12350   MVT V0VT = V0.getSimpleValueType();
12351 
12352   assert(VT.isInteger() && "Unexpected non-integer trunc broadcast!");
12353   assert(V0VT.isVector() && "Unexpected non-vector vector-sized value!");
12354 
12355   MVT V0EltVT = V0VT.getVectorElementType();
12356   if (!V0EltVT.isInteger())
12357     return SDValue();
12358 
12359   const unsigned EltSize = EltVT.getSizeInBits();
12360   const unsigned V0EltSize = V0EltVT.getSizeInBits();
12361 
12362   // This is only a truncation if the original element type is larger.
12363   if (V0EltSize <= EltSize)
12364     return SDValue();
12365 
12366   assert(((V0EltSize % EltSize) == 0) &&
12367          "Scalar type sizes must all be powers of 2 on x86!");
12368 
12369   const unsigned V0Opc = V0.getOpcode();
12370   const unsigned Scale = V0EltSize / EltSize;
12371   const unsigned V0BroadcastIdx = BroadcastIdx / Scale;
12372 
12373   if ((V0Opc != ISD::SCALAR_TO_VECTOR || V0BroadcastIdx != 0) &&
12374       V0Opc != ISD::BUILD_VECTOR)
12375     return SDValue();
12376 
12377   SDValue Scalar = V0.getOperand(V0BroadcastIdx);
12378 
12379   // If we're extracting non-least-significant bits, shift so we can truncate.
12380   // Hopefully, we can fold away the trunc/srl/load into the broadcast.
12381   // Even if we can't (and !isShuffleFoldableLoad(Scalar)), prefer
12382   // vpbroadcast+vmovd+shr to vpshufb(m)+vmovd.
12383   if (const int OffsetIdx = BroadcastIdx % Scale)
12384     Scalar = DAG.getNode(ISD::SRL, DL, Scalar.getValueType(), Scalar,
12385                          DAG.getConstant(OffsetIdx * EltSize, DL, MVT::i8));
12386 
12387   return DAG.getNode(X86ISD::VBROADCAST, DL, VT,
12388                      DAG.getNode(ISD::TRUNCATE, DL, EltVT, Scalar));
12389 }
12390 
12391 /// Test whether this can be lowered with a single SHUFPS instruction.
12392 ///
12393 /// This is used to disable more specialized lowerings when the shufps lowering
12394 /// will happen to be efficient.
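///
/// For example, <0,1,4,5> can be done with a single SHUFPS (each half of the
/// result uses one input), while <0,4,1,5> cannot because its low half mixes
/// both inputs.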
12395 static bool isSingleSHUFPSMask(ArrayRef<int> Mask) {
12396   // This routine only handles 128-bit shufps.
12397   assert(Mask.size() == 4 && "Unsupported mask size!");
12398   assert(Mask[0] >= -1 && Mask[0] < 8 && "Out of bound mask element!");
12399   assert(Mask[1] >= -1 && Mask[1] < 8 && "Out of bound mask element!");
12400   assert(Mask[2] >= -1 && Mask[2] < 8 && "Out of bound mask element!");
12401   assert(Mask[3] >= -1 && Mask[3] < 8 && "Out of bound mask element!");
12402 
12403   // To lower with a single SHUFPS we need to have the low half and high half
12404   // each requiring a single input.
12405   if (Mask[0] >= 0 && Mask[1] >= 0 && (Mask[0] < 4) != (Mask[1] < 4))
12406     return false;
12407   if (Mask[2] >= 0 && Mask[3] >= 0 && (Mask[2] < 4) != (Mask[3] < 4))
12408     return false;
12409 
12410   return true;
12411 }
12412 
12413 /// Test whether the specified input (0 or 1) is in-place blended by the
12414 /// given mask.
12415 ///
12416 /// This returns true if the elements from a particular input are already in the
12417 /// slot required by the given mask and require no permutation.
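///
/// For example, with Mask = <0,5,2,7> both inputs are in place: elements 0
/// and 2 of input 0 stay in slots 0 and 2, and elements 1 and 3 of input 1
/// land in slots 1 and 3.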
12418 static bool isShuffleMaskInputInPlace(int Input, ArrayRef<int> Mask) {
12419   assert((Input == 0 || Input == 1) && "Only two inputs to shuffles.");
12420   int Size = Mask.size();
12421   for (int i = 0; i < Size; ++i)
12422     if (Mask[i] >= 0 && Mask[i] / Size == Input && Mask[i] % Size != i)
12423       return false;
12424 
12425   return true;
12426 }
12427 
12428 /// If we are extracting two 128-bit halves of a vector and shuffling the
12429 /// result, match that to a 256-bit AVX2 vperm* instruction to avoid a
12430 /// multi-shuffle lowering.
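///
/// For example, shuffling the two 128-bit halves of a v8f32 value X with a
/// v4f32 mask such as <0,4,2,5> can instead become a single VPERMPS of X
/// with mask <0,4,2,5,u,u,u,u> followed by a free ymm-to-xmm extract,
/// provided the mask is not already a plain SHUFPS or UNPCK pattern.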
12431 static SDValue lowerShuffleOfExtractsAsVperm(const SDLoc &DL, SDValue N0,
12432                                              SDValue N1, ArrayRef<int> Mask,
12433                                              SelectionDAG &DAG) {
12434   MVT VT = N0.getSimpleValueType();
12435   assert((VT.is128BitVector() &&
12436           (VT.getScalarSizeInBits() == 32 || VT.getScalarSizeInBits() == 64)) &&
12437          "VPERM* family of shuffles requires 32-bit or 64-bit elements");
12438 
12439   // Check that both sources are extracts of the same source vector.
12440   if (N0.getOpcode() != ISD::EXTRACT_SUBVECTOR ||
12441       N1.getOpcode() != ISD::EXTRACT_SUBVECTOR ||
12442       N0.getOperand(0) != N1.getOperand(0) ||
12443       !N0.hasOneUse() || !N1.hasOneUse())
12444     return SDValue();
12445 
12446   SDValue WideVec = N0.getOperand(0);
12447   MVT WideVT = WideVec.getSimpleValueType();
12448   if (!WideVT.is256BitVector())
12449     return SDValue();
12450 
12451   // Match extracts of each half of the wide source vector. Commute the shuffle
12452   // if the extract of the low half is N1.
12453   unsigned NumElts = VT.getVectorNumElements();
12454   SmallVector<int, 4> NewMask(Mask);
12455   const APInt &ExtIndex0 = N0.getConstantOperandAPInt(1);
12456   const APInt &ExtIndex1 = N1.getConstantOperandAPInt(1);
12457   if (ExtIndex1 == 0 && ExtIndex0 == NumElts)
12458     ShuffleVectorSDNode::commuteMask(NewMask);
12459   else if (ExtIndex0 != 0 || ExtIndex1 != NumElts)
12460     return SDValue();
12461 
12462   // Final bailout: if the mask is simple, we are better off using an extract
12463   // and a simple narrow shuffle. Prefer extract+unpack(h/l)ps to vpermps
12464   // because that avoids a constant load from memory.
12465   if (NumElts == 4 &&
12466       (isSingleSHUFPSMask(NewMask) || is128BitUnpackShuffleMask(NewMask, DAG)))
12467     return SDValue();
12468 
12469   // Extend the shuffle mask with undef elements.
12470   NewMask.append(NumElts, -1);
12471 
12472   // shuf (extract X, 0), (extract X, 4), M --> extract (shuf X, undef, M'), 0
12473   SDValue Shuf = DAG.getVectorShuffle(WideVT, DL, WideVec, DAG.getUNDEF(WideVT),
12474                                       NewMask);
12475   // This is free: ymm -> xmm.
12476   return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Shuf,
12477                      DAG.getIntPtrConstant(0, DL));
12478 }
12479 
12480 /// Try to lower broadcast of a single element.
12481 ///
12482 /// For convenience, this code also bundles all of the subtarget feature set
12483 /// filtering. While a little annoying to re-dispatch on type here, there isn't
12484 /// a convenient way to factor it out.
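///
/// For example, if V1 is a simple vector load on AVX, a v4f32 splat of
/// element 1 is narrowed to a single VBROADCAST_LOAD (vbroadcastss) of the
/// scalar at offset 4 from the original address; with AVX2 the broadcast can
/// also be fed directly from a register.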
12485 static SDValue lowerShuffleAsBroadcast(const SDLoc &DL, MVT VT, SDValue V1,
12486                                        SDValue V2, ArrayRef<int> Mask,
12487                                        const X86Subtarget &Subtarget,
12488                                        SelectionDAG &DAG) {
12489   MVT EltVT = VT.getVectorElementType();
12490   if (!((Subtarget.hasSSE3() && VT == MVT::v2f64) ||
12491         (Subtarget.hasAVX() && (EltVT == MVT::f64 || EltVT == MVT::f32)) ||
12492         (Subtarget.hasAVX2() && (VT.isInteger() || EltVT == MVT::f16))))
12493     return SDValue();
12494 
12495   // With MOVDDUP (v2f64) we can broadcast from a register or a load, otherwise
12496   // we can only broadcast from a register with AVX2.
12497   unsigned NumEltBits = VT.getScalarSizeInBits();
12498   unsigned Opcode = (VT == MVT::v2f64 && !Subtarget.hasAVX2())
12499                         ? X86ISD::MOVDDUP
12500                         : X86ISD::VBROADCAST;
12501   bool BroadcastFromReg = (Opcode == X86ISD::MOVDDUP) || Subtarget.hasAVX2();
12502 
12503   // Check that the mask is a broadcast.
12504   int BroadcastIdx = getSplatIndex(Mask);
12505   if (BroadcastIdx < 0)
12506     return SDValue();
12507   assert(BroadcastIdx < (int)Mask.size() && "We only expect to be called with "
12508                                             "a sorted mask where the broadcast "
12509                                             "comes from V1.");
12510 
12511   // Go up the chain of (vector) values to find a scalar load that we can
12512   // combine with the broadcast.
12513   // TODO: Combine this logic with findEltLoadSrc() used by
12514   //       EltsFromConsecutiveLoads().
12515   int BitOffset = BroadcastIdx * NumEltBits;
12516   SDValue V = V1;
12517   for (;;) {
12518     switch (V.getOpcode()) {
12519     case ISD::BITCAST: {
12520       V = V.getOperand(0);
12521       continue;
12522     }
12523     case ISD::CONCAT_VECTORS: {
12524       int OpBitWidth = V.getOperand(0).getValueSizeInBits();
12525       int OpIdx = BitOffset / OpBitWidth;
12526       V = V.getOperand(OpIdx);
12527       BitOffset %= OpBitWidth;
12528       continue;
12529     }
12530     case ISD::EXTRACT_SUBVECTOR: {
12531       // The extraction index adds to the existing offset.
12532       unsigned EltBitWidth = V.getScalarValueSizeInBits();
12533       unsigned Idx = V.getConstantOperandVal(1);
12534       unsigned BeginOffset = Idx * EltBitWidth;
12535       BitOffset += BeginOffset;
12536       V = V.getOperand(0);
12537       continue;
12538     }
12539     case ISD::INSERT_SUBVECTOR: {
12540       SDValue VOuter = V.getOperand(0), VInner = V.getOperand(1);
12541       int EltBitWidth = VOuter.getScalarValueSizeInBits();
12542       int Idx = (int)V.getConstantOperandVal(2);
12543       int NumSubElts = (int)VInner.getSimpleValueType().getVectorNumElements();
12544       int BeginOffset = Idx * EltBitWidth;
12545       int EndOffset = BeginOffset + NumSubElts * EltBitWidth;
12546       if (BeginOffset <= BitOffset && BitOffset < EndOffset) {
12547         BitOffset -= BeginOffset;
12548         V = VInner;
12549       } else {
12550         V = VOuter;
12551       }
12552       continue;
12553     }
12554     }
12555     break;
12556   }
12557   assert((BitOffset % NumEltBits) == 0 && "Illegal bit-offset");
12558   BroadcastIdx = BitOffset / NumEltBits;
12559 
12560   // Do we need to bitcast the source to retrieve the original broadcast index?
12561   bool BitCastSrc = V.getScalarValueSizeInBits() != NumEltBits;
12562 
12563   // Check if this is a broadcast of a scalar. We special case lowering
12564   // for scalars so that we can more effectively fold with loads.
12565   // If the original value has a larger element type than the shuffle, the
12566   // broadcast element is in essence truncated. Make that explicit to ease
12567   // folding.
12568   if (BitCastSrc && VT.isInteger())
12569     if (SDValue TruncBroadcast = lowerShuffleAsTruncBroadcast(
12570             DL, VT, V, BroadcastIdx, Subtarget, DAG))
12571       return TruncBroadcast;
12572 
12573   // Also check the simpler case, where we can directly reuse the scalar.
12574   if (!BitCastSrc &&
12575       ((V.getOpcode() == ISD::BUILD_VECTOR && V.hasOneUse()) ||
12576        (V.getOpcode() == ISD::SCALAR_TO_VECTOR && BroadcastIdx == 0))) {
12577     V = V.getOperand(BroadcastIdx);
12578 
12579     // If we can't broadcast from a register, check that the input is a load.
12580     if (!BroadcastFromReg && !isShuffleFoldableLoad(V))
12581       return SDValue();
12582   } else if (ISD::isNormalLoad(V.getNode()) &&
12583              cast<LoadSDNode>(V)->isSimple()) {
12584     // We do not check for one-use of the vector load because a broadcast load
12585     // is expected to be a win for code size, register pressure, and possibly
12586     // uops even if the original vector load is not eliminated.
12587 
12588     // Reduce the vector load and shuffle to a broadcasted scalar load.
12589     LoadSDNode *Ld = cast<LoadSDNode>(V);
12590     SDValue BaseAddr = Ld->getOperand(1);
12591     MVT SVT = VT.getScalarType();
12592     unsigned Offset = BroadcastIdx * SVT.getStoreSize();
12593     assert((int)(Offset * 8) == BitOffset && "Unexpected bit-offset");
12594     SDValue NewAddr =
12595         DAG.getMemBasePlusOffset(BaseAddr, TypeSize::getFixed(Offset), DL);
12596 
12597     // Directly form VBROADCAST_LOAD if we're using VBROADCAST opcode rather
12598     // than MOVDDUP.
12599     // FIXME: Should we add VBROADCAST_LOAD isel patterns for pre-AVX?
12600     if (Opcode == X86ISD::VBROADCAST) {
12601       SDVTList Tys = DAG.getVTList(VT, MVT::Other);
12602       SDValue Ops[] = {Ld->getChain(), NewAddr};
12603       V = DAG.getMemIntrinsicNode(
12604           X86ISD::VBROADCAST_LOAD, DL, Tys, Ops, SVT,
12605           DAG.getMachineFunction().getMachineMemOperand(
12606               Ld->getMemOperand(), Offset, SVT.getStoreSize()));
12607       DAG.makeEquivalentMemoryOrdering(Ld, V);
12608       return DAG.getBitcast(VT, V);
12609     }
12610     assert(SVT == MVT::f64 && "Unexpected VT!");
12611     V = DAG.getLoad(SVT, DL, Ld->getChain(), NewAddr,
12612                     DAG.getMachineFunction().getMachineMemOperand(
12613                         Ld->getMemOperand(), Offset, SVT.getStoreSize()));
12614     DAG.makeEquivalentMemoryOrdering(Ld, V);
12615   } else if (!BroadcastFromReg) {
12616     // We can't broadcast from a vector register.
12617     return SDValue();
12618   } else if (BitOffset != 0) {
12619     // We can only broadcast from the zero-element of a vector register,
12620     // but it can be advantageous to broadcast from the zero-element of a
12621     // subvector.
12622     if (!VT.is256BitVector() && !VT.is512BitVector())
12623       return SDValue();
12624 
12625     // VPERMQ/VPERMPD can perform the cross-lane shuffle directly.
12626     if (VT == MVT::v4f64 || VT == MVT::v4i64)
12627       return SDValue();
12628 
12629     // Only broadcast the zero-element of a 128-bit subvector.
12630     if ((BitOffset % 128) != 0)
12631       return SDValue();
12632 
12633     assert((BitOffset % V.getScalarValueSizeInBits()) == 0 &&
12634            "Unexpected bit-offset");
12635     assert((V.getValueSizeInBits() == 256 || V.getValueSizeInBits() == 512) &&
12636            "Unexpected vector size");
12637     unsigned ExtractIdx = BitOffset / V.getScalarValueSizeInBits();
12638     V = extract128BitVector(V, ExtractIdx, DAG, DL);
12639   }
12640 
12641   // On AVX we can use VBROADCAST directly for scalar sources.
12642   if (Opcode == X86ISD::MOVDDUP && !V.getValueType().isVector()) {
12643     V = DAG.getBitcast(MVT::f64, V);
12644     if (Subtarget.hasAVX()) {
12645       V = DAG.getNode(X86ISD::VBROADCAST, DL, MVT::v2f64, V);
12646       return DAG.getBitcast(VT, V);
12647     }
12648     V = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v2f64, V);
12649   }
12650 
12651   // If this is a scalar, do the broadcast on this type and bitcast.
12652   if (!V.getValueType().isVector()) {
12653     assert(V.getScalarValueSizeInBits() == NumEltBits &&
12654            "Unexpected scalar size");
12655     MVT BroadcastVT = MVT::getVectorVT(V.getSimpleValueType(),
12656                                        VT.getVectorNumElements());
12657     return DAG.getBitcast(VT, DAG.getNode(Opcode, DL, BroadcastVT, V));
12658   }
12659 
12660   // We only support broadcasting from 128-bit vectors to minimize the
12661   // number of patterns we need to deal with in isel. So extract down to
12662   // 128-bits, removing as many bitcasts as possible.
12663   if (V.getValueSizeInBits() > 128)
12664     V = extract128BitVector(peekThroughBitcasts(V), 0, DAG, DL);
12665 
12666   // Otherwise cast V to a vector with the same element type as VT, but
12667   // possibly narrower than VT. Then perform the broadcast.
12668   unsigned NumSrcElts = V.getValueSizeInBits() / NumEltBits;
12669   MVT CastVT = MVT::getVectorVT(VT.getVectorElementType(), NumSrcElts);
12670   return DAG.getNode(Opcode, DL, VT, DAG.getBitcast(CastVT, V));
12671 }
12672 
12673 // Check for whether we can use INSERTPS to perform the shuffle. We only use
12674 // INSERTPS when the V1 elements are already in the correct locations
12675 // because otherwise we can just always use two SHUFPS instructions which
12676 // are much smaller to encode than a SHUFPS and an INSERTPS. We can also
12677 // perform INSERTPS if a single V1 element is out of place and all V2
12678 // elements are zeroable.
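//
// For example, the v4f32 mask <0,1,6,3> matches with element 2 of V2
// inserted into lane 2, giving an INSERTPS immediate of 0xA0 (src = 2,
// dst = 2, empty zero mask); zeroable lanes would additionally set bits in
// the low nibble of the immediate.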
12679 static bool matchShuffleAsInsertPS(SDValue &V1, SDValue &V2,
12680                                    unsigned &InsertPSMask,
12681                                    const APInt &Zeroable,
12682                                    ArrayRef<int> Mask, SelectionDAG &DAG) {
12683   assert(V1.getSimpleValueType().is128BitVector() && "Bad operand type!");
12684   assert(V2.getSimpleValueType().is128BitVector() && "Bad operand type!");
12685   assert(Mask.size() == 4 && "Unexpected mask size for v4 shuffle!");
12686 
12687   // Attempt to match INSERTPS with one element from VA or VB being
12688   // inserted into VA (or undef). If successful, V1, V2 and InsertPSMask
12689   // are updated.
12690   auto matchAsInsertPS = [&](SDValue VA, SDValue VB,
12691                              ArrayRef<int> CandidateMask) {
12692     unsigned ZMask = 0;
12693     int VADstIndex = -1;
12694     int VBDstIndex = -1;
12695     bool VAUsedInPlace = false;
12696 
12697     for (int i = 0; i < 4; ++i) {
12698       // Synthesize a zero mask from the zeroable elements (includes undefs).
12699       if (Zeroable[i]) {
12700         ZMask |= 1 << i;
12701         continue;
12702       }
12703 
12704       // Flag if we use any VA inputs in place.
12705       if (i == CandidateMask[i]) {
12706         VAUsedInPlace = true;
12707         continue;
12708       }
12709 
12710       // We can only insert a single non-zeroable element.
12711       if (VADstIndex >= 0 || VBDstIndex >= 0)
12712         return false;
12713 
12714       if (CandidateMask[i] < 4) {
12715         // VA input out of place for insertion.
12716         VADstIndex = i;
12717       } else {
12718         // VB input for insertion.
12719         VBDstIndex = i;
12720       }
12721     }
12722 
12723     // Don't bother if we have no (non-zeroable) element for insertion.
12724     if (VADstIndex < 0 && VBDstIndex < 0)
12725       return false;
12726 
12727     // Determine element insertion src/dst indices. The src index is from the
12728     // start of the inserted vector, not the start of the concatenated vector.
12729     unsigned VBSrcIndex = 0;
12730     if (VADstIndex >= 0) {
12731       // If we have a VA input out of place, we use VA as the V2 element
12732       // insertion and don't use the original V2 at all.
12733       VBSrcIndex = CandidateMask[VADstIndex];
12734       VBDstIndex = VADstIndex;
12735       VB = VA;
12736     } else {
12737       VBSrcIndex = CandidateMask[VBDstIndex] - 4;
12738     }
12739 
12740     // If no V1 inputs are used in place, then the result is created only from
12741     // the zero mask and the V2 insertion - so remove V1 dependency.
12742     if (!VAUsedInPlace)
12743       VA = DAG.getUNDEF(MVT::v4f32);
12744 
12745     // Update V1, V2 and InsertPSMask accordingly.
12746     V1 = VA;
12747     V2 = VB;
12748 
12749     // Insert the V2 element into the desired position.
12750     InsertPSMask = VBSrcIndex << 6 | VBDstIndex << 4 | ZMask;
12751     assert((InsertPSMask & ~0xFFu) == 0 && "Invalid mask!");
12752     return true;
12753   };
12754 
12755   if (matchAsInsertPS(V1, V2, Mask))
12756     return true;
12757 
12758   // Commute and try again.
12759   SmallVector<int, 4> CommutedMask(Mask);
12760   ShuffleVectorSDNode::commuteMask(CommutedMask);
12761   if (matchAsInsertPS(V2, V1, CommutedMask))
12762     return true;
12763 
12764   return false;
12765 }
12766 
12767 static SDValue lowerShuffleAsInsertPS(const SDLoc &DL, SDValue V1, SDValue V2,
12768                                       ArrayRef<int> Mask, const APInt &Zeroable,
12769                                       SelectionDAG &DAG) {
12770   assert(V1.getSimpleValueType() == MVT::v4f32 && "Bad operand type!");
12771   assert(V2.getSimpleValueType() == MVT::v4f32 && "Bad operand type!");
12772 
12773   // Attempt to match the insertps pattern.
12774   unsigned InsertPSMask = 0;
12775   if (!matchShuffleAsInsertPS(V1, V2, InsertPSMask, Zeroable, Mask, DAG))
12776     return SDValue();
12777 
12778   // Insert the V2 element into the desired position.
12779   return DAG.getNode(X86ISD::INSERTPS, DL, MVT::v4f32, V1, V2,
12780                      DAG.getTargetConstant(InsertPSMask, DL, MVT::i8));
12781 }
12782 
12783 /// Handle lowering of 2-lane 64-bit floating point shuffles.
12784 ///
12785 /// This is the basis function for the 2-lane 64-bit shuffles as we have full
12786 /// support for floating point shuffles but not integer shuffles. These
12787 /// instructions will incur a domain crossing penalty on some chips though so
12788 /// it is better to avoid lowering through this for integer vectors where
12789 /// possible.
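///
/// For example, the single-input splat <1,1> typically becomes VPERMILPD
/// with immediate 3 on AVX or a SHUFPD of V1 with itself otherwise, while
/// <0,0> is usually caught earlier by the broadcast path (MOVDDUP on SSE3).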
12790 static SDValue lowerV2F64Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
12791                                  const APInt &Zeroable, SDValue V1, SDValue V2,
12792                                  const X86Subtarget &Subtarget,
12793                                  SelectionDAG &DAG) {
12794   assert(V1.getSimpleValueType() == MVT::v2f64 && "Bad operand type!");
12795   assert(V2.getSimpleValueType() == MVT::v2f64 && "Bad operand type!");
12796   assert(Mask.size() == 2 && "Unexpected mask size for v2 shuffle!");
12797 
12798   if (V2.isUndef()) {
12799     // Check for being able to broadcast a single element.
12800     if (SDValue Broadcast = lowerShuffleAsBroadcast(DL, MVT::v2f64, V1, V2,
12801                                                     Mask, Subtarget, DAG))
12802       return Broadcast;
12803 
12804     // Straight shuffle of a single input vector. Simulate this by using the
12805     // single input as both of the "inputs" to this instruction.
12806     unsigned SHUFPDMask = (Mask[0] == 1) | ((Mask[1] == 1) << 1);
12807 
12808     if (Subtarget.hasAVX()) {
12809       // If we have AVX, we can use VPERMILPS which will allow folding a load
12810       // into the shuffle.
12811       return DAG.getNode(X86ISD::VPERMILPI, DL, MVT::v2f64, V1,
12812                          DAG.getTargetConstant(SHUFPDMask, DL, MVT::i8));
12813     }
12814 
12815     return DAG.getNode(
12816         X86ISD::SHUFP, DL, MVT::v2f64,
12817         Mask[0] == SM_SentinelUndef ? DAG.getUNDEF(MVT::v2f64) : V1,
12818         Mask[1] == SM_SentinelUndef ? DAG.getUNDEF(MVT::v2f64) : V1,
12819         DAG.getTargetConstant(SHUFPDMask, DL, MVT::i8));
12820   }
12821   assert(Mask[0] >= 0 && "No undef lanes in multi-input v2 shuffles!");
12822   assert(Mask[1] >= 0 && "No undef lanes in multi-input v2 shuffles!");
12823   assert(Mask[0] < 2 && "We sort V1 to be the first input.");
12824   assert(Mask[1] >= 2 && "We sort V2 to be the second input.");
12825 
12826   if (Subtarget.hasAVX2())
12827     if (SDValue Extract = lowerShuffleOfExtractsAsVperm(DL, V1, V2, Mask, DAG))
12828       return Extract;
12829 
12830   // When loading a scalar and then shuffling it into a vector we can often do
12831   // the insertion cheaply.
12832   if (SDValue Insertion = lowerShuffleAsElementInsertion(
12833           DL, MVT::v2f64, V1, V2, Mask, Zeroable, Subtarget, DAG))
12834     return Insertion;
12835   // Try inverting the insertion since for v2 masks it is easy to do and we
12836   // can't reliably sort the mask one way or the other.
12837   int InverseMask[2] = {Mask[0] < 0 ? -1 : (Mask[0] ^ 2),
12838                         Mask[1] < 0 ? -1 : (Mask[1] ^ 2)};
12839   if (SDValue Insertion = lowerShuffleAsElementInsertion(
12840           DL, MVT::v2f64, V2, V1, InverseMask, Zeroable, Subtarget, DAG))
12841     return Insertion;
12842 
12843   // Try to use one of the special instruction patterns to handle two common
12844   // blend patterns if a zero-blend above didn't work.
12845   if (isShuffleEquivalent(Mask, {0, 3}, V1, V2) ||
12846       isShuffleEquivalent(Mask, {1, 3}, V1, V2))
12847     if (SDValue V1S = getScalarValueForVectorElement(V1, Mask[0], DAG))
12848       // We can either use a special instruction to load over the low double or
12849       // to move just the low double.
12850       return DAG.getNode(
12851           X86ISD::MOVSD, DL, MVT::v2f64, V2,
12852           DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v2f64, V1S));
12853 
12854   if (Subtarget.hasSSE41())
12855     if (SDValue Blend = lowerShuffleAsBlend(DL, MVT::v2f64, V1, V2, Mask,
12856                                             Zeroable, Subtarget, DAG))
12857       return Blend;
12858 
12859   // Use dedicated unpack instructions for masks that match their pattern.
12860   if (SDValue V = lowerShuffleWithUNPCK(DL, MVT::v2f64, Mask, V1, V2, DAG))
12861     return V;
12862 
12863   unsigned SHUFPDMask = (Mask[0] == 1) | (((Mask[1] - 2) == 1) << 1);
12864   return DAG.getNode(X86ISD::SHUFP, DL, MVT::v2f64, V1, V2,
12865                      DAG.getTargetConstant(SHUFPDMask, DL, MVT::i8));
12866 }
12867 
12868 /// Handle lowering of 2-lane 64-bit integer shuffles.
12869 ///
12870 /// Tries to lower a 2-lane 64-bit shuffle using shuffle operations provided by
12871 /// the integer unit to minimize domain crossing penalties. However, for blends
12872 /// it falls back to the floating point shuffle operation with appropriate bit
12873 /// casting.
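///
/// For example, the single-input splat <1,1> is widened to the v4i32 mask
/// <2,3,2,3> and emitted as PSHUFD with immediate 0xEE, staying in the
/// integer domain (unless it was already matched as a broadcast on AVX2).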
12874 static SDValue lowerV2I64Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
12875                                  const APInt &Zeroable, SDValue V1, SDValue V2,
12876                                  const X86Subtarget &Subtarget,
12877                                  SelectionDAG &DAG) {
12878   assert(V1.getSimpleValueType() == MVT::v2i64 && "Bad operand type!");
12879   assert(V2.getSimpleValueType() == MVT::v2i64 && "Bad operand type!");
12880   assert(Mask.size() == 2 && "Unexpected mask size for v2 shuffle!");
12881 
12882   if (V2.isUndef()) {
12883     // Check for being able to broadcast a single element.
12884     if (SDValue Broadcast = lowerShuffleAsBroadcast(DL, MVT::v2i64, V1, V2,
12885                                                     Mask, Subtarget, DAG))
12886       return Broadcast;
12887 
12888     // Straight shuffle of a single input vector. For everything from SSE2
12889     // onward this has a single fast instruction with no scary immediates.
12890     // We have to map the mask as it is actually a v4i32 shuffle instruction.
12891     V1 = DAG.getBitcast(MVT::v4i32, V1);
12892     int WidenedMask[4] = {Mask[0] < 0 ? -1 : (Mask[0] * 2),
12893                           Mask[0] < 0 ? -1 : ((Mask[0] * 2) + 1),
12894                           Mask[1] < 0 ? -1 : (Mask[1] * 2),
12895                           Mask[1] < 0 ? -1 : ((Mask[1] * 2) + 1)};
12896     return DAG.getBitcast(
12897         MVT::v2i64,
12898         DAG.getNode(X86ISD::PSHUFD, DL, MVT::v4i32, V1,
12899                     getV4X86ShuffleImm8ForMask(WidenedMask, DL, DAG)));
12900   }
12901   assert(Mask[0] != -1 && "No undef lanes in multi-input v2 shuffles!");
12902   assert(Mask[1] != -1 && "No undef lanes in multi-input v2 shuffles!");
12903   assert(Mask[0] < 2 && "We sort V1 to be the first input.");
12904   assert(Mask[1] >= 2 && "We sort V2 to be the second input.");
12905 
12906   if (Subtarget.hasAVX2())
12907     if (SDValue Extract = lowerShuffleOfExtractsAsVperm(DL, V1, V2, Mask, DAG))
12908       return Extract;
12909 
12910   // Try to use shift instructions.
12911   if (SDValue Shift =
12912           lowerShuffleAsShift(DL, MVT::v2i64, V1, V2, Mask, Zeroable, Subtarget,
12913                               DAG, /*BitwiseOnly*/ false))
12914     return Shift;
12915 
12916   // When loading a scalar and then shuffling it into a vector we can often do
12917   // the insertion cheaply.
12918   if (SDValue Insertion = lowerShuffleAsElementInsertion(
12919           DL, MVT::v2i64, V1, V2, Mask, Zeroable, Subtarget, DAG))
12920     return Insertion;
12921   // Try inverting the insertion since for v2 masks it is easy to do and we
12922   // can't reliably sort the mask one way or the other.
12923   int InverseMask[2] = {Mask[0] ^ 2, Mask[1] ^ 2};
12924   if (SDValue Insertion = lowerShuffleAsElementInsertion(
12925           DL, MVT::v2i64, V2, V1, InverseMask, Zeroable, Subtarget, DAG))
12926     return Insertion;
12927 
12928   // We have different paths for blend lowering, but they all must use the
12929   // *exact* same predicate.
12930   bool IsBlendSupported = Subtarget.hasSSE41();
12931   if (IsBlendSupported)
12932     if (SDValue Blend = lowerShuffleAsBlend(DL, MVT::v2i64, V1, V2, Mask,
12933                                             Zeroable, Subtarget, DAG))
12934       return Blend;
12935 
12936   // Use dedicated unpack instructions for masks that match their pattern.
12937   if (SDValue V = lowerShuffleWithUNPCK(DL, MVT::v2i64, Mask, V1, V2, DAG))
12938     return V;
12939 
12940   // Try to use byte rotation instructions.
12941   // It's more profitable for pre-SSSE3 to use shuffles/unpacks.
12942   if (Subtarget.hasSSSE3()) {
12943     if (Subtarget.hasVLX())
12944       if (SDValue Rotate = lowerShuffleAsVALIGN(DL, MVT::v2i64, V1, V2, Mask,
12945                                                 Zeroable, Subtarget, DAG))
12946         return Rotate;
12947 
12948     if (SDValue Rotate = lowerShuffleAsByteRotate(DL, MVT::v2i64, V1, V2, Mask,
12949                                                   Subtarget, DAG))
12950       return Rotate;
12951   }
12952 
12953   // If we have direct support for blends, we should lower by decomposing into
12954   // a permute. That will be faster than the domain cross.
12955   if (IsBlendSupported)
12956     return lowerShuffleAsDecomposedShuffleMerge(DL, MVT::v2i64, V1, V2, Mask,
12957                                                 Subtarget, DAG);
12958 
12959   // We implement this with SHUFPD which is pretty lame because it will likely
12960   // incur 2 cycles of stall for integer vectors on Nehalem and older chips.
12961   // However, all the alternatives are still more cycles and newer chips don't
12962   // have this problem. It would be really nice if x86 had better shuffles here.
12963   V1 = DAG.getBitcast(MVT::v2f64, V1);
12964   V2 = DAG.getBitcast(MVT::v2f64, V2);
12965   return DAG.getBitcast(MVT::v2i64,
12966                         DAG.getVectorShuffle(MVT::v2f64, DL, V1, V2, Mask));
12967 }
12968 
12969 /// Lower a vector shuffle using the SHUFPS instruction.
12970 ///
12971 /// This is a helper routine dedicated to lowering vector shuffles using SHUFPS.
12972 /// It makes no assumptions about whether this is the *best* lowering, it simply
12973 /// uses it.
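/// Note that SHUFPS selects its low two result elements from the first operand
/// and its high two from the second, each via a 2-bit immediate field; the
/// LowV and HighV values below track which input feeds each half.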
12974 static SDValue lowerShuffleWithSHUFPS(const SDLoc &DL, MVT VT,
12975                                       ArrayRef<int> Mask, SDValue V1,
12976                                       SDValue V2, SelectionDAG &DAG) {
12977   SDValue LowV = V1, HighV = V2;
12978   SmallVector<int, 4> NewMask(Mask);
12979   int NumV2Elements = count_if(Mask, [](int M) { return M >= 4; });
12980 
12981   if (NumV2Elements == 1) {
12982     int V2Index = find_if(Mask, [](int M) { return M >= 4; }) - Mask.begin();
12983 
12984     // Compute the index adjacent to V2Index and in the same half by toggling
12985     // the low bit.
12986     int V2AdjIndex = V2Index ^ 1;
12987 
12988     if (Mask[V2AdjIndex] < 0) {
12989       // Handles all the cases where we have a single V2 element and an undef.
12990       // This will only ever happen in the high lanes because we commute the
12991       // vector otherwise.
12992       if (V2Index < 2)
12993         std::swap(LowV, HighV);
12994       NewMask[V2Index] -= 4;
12995     } else {
12996       // Handle the case where the V2 element ends up adjacent to a V1 element.
12997       // To make this work, blend them together as the first step.
12998       int V1Index = V2AdjIndex;
12999       int BlendMask[4] = {Mask[V2Index] - 4, 0, Mask[V1Index], 0};
13000       V2 = DAG.getNode(X86ISD::SHUFP, DL, VT, V2, V1,
13001                        getV4X86ShuffleImm8ForMask(BlendMask, DL, DAG));
13002 
13003       // Now proceed to reconstruct the final blend as we have the necessary
13004       // high or low half formed.
13005       if (V2Index < 2) {
13006         LowV = V2;
13007         HighV = V1;
13008       } else {
13009         HighV = V2;
13010       }
13011       NewMask[V1Index] = 2; // We put the V1 element in V2[2].
13012       NewMask[V2Index] = 0; // We shifted the V2 element into V2[0].
13013     }
13014   } else if (NumV2Elements == 2) {
13015     if (Mask[0] < 4 && Mask[1] < 4) {
13016       // Handle the easy case where we have V1 in the low lanes and V2 in the
13017       // high lanes.
13018       NewMask[2] -= 4;
13019       NewMask[3] -= 4;
13020     } else if (Mask[2] < 4 && Mask[3] < 4) {
13021       // We also handle the reversed case because this utility may get called
13022       // when we detect a SHUFPS pattern but can't easily commute the shuffle to
13023       // arrange things in the right direction.
13024       NewMask[0] -= 4;
13025       NewMask[1] -= 4;
13026       HighV = V1;
13027       LowV = V2;
13028     } else {
13029       // We have a mixture of V1 and V2 in both low and high lanes. Rather than
13030       // trying to place elements directly, just blend them and set up the final
13031       // shuffle to place them.
13032 
13033       // The first two blend mask elements are for V1, the second two are for
13034       // V2.
13035       int BlendMask[4] = {Mask[0] < 4 ? Mask[0] : Mask[1],
13036                           Mask[2] < 4 ? Mask[2] : Mask[3],
13037                           (Mask[0] >= 4 ? Mask[0] : Mask[1]) - 4,
13038                           (Mask[2] >= 4 ? Mask[2] : Mask[3]) - 4};
13039       V1 = DAG.getNode(X86ISD::SHUFP, DL, VT, V1, V2,
13040                        getV4X86ShuffleImm8ForMask(BlendMask, DL, DAG));
13041 
13042       // Now we do a normal shuffle of V1 by giving V1 as both operands to
13043       // a blend.
13044       LowV = HighV = V1;
13045       NewMask[0] = Mask[0] < 4 ? 0 : 2;
13046       NewMask[1] = Mask[0] < 4 ? 2 : 0;
13047       NewMask[2] = Mask[2] < 4 ? 1 : 3;
13048       NewMask[3] = Mask[2] < 4 ? 3 : 1;
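      // For example, Mask <0,4,1,5> gives BlendMask <0,1,0,1>, so the blend is
      // <V1[0],V1[1],V2[0],V2[1]>, and the final NewMask <0,2,1,3> shuffles the
      // blend with itself into <V1[0],V2[0],V1[1],V2[1]>.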
13049     }
13050   } else if (NumV2Elements == 3) {
13051     // Ideally canonicalizeShuffleMaskWithCommute should have caught this, but
13052   // we can get here via other paths (e.g. repeated mask matching) for which
13053   // we don't want to do another round of lowerVECTOR_SHUFFLE.
13054     ShuffleVectorSDNode::commuteMask(NewMask);
13055     return lowerShuffleWithSHUFPS(DL, VT, NewMask, V2, V1, DAG);
13056   }
13057   return DAG.getNode(X86ISD::SHUFP, DL, VT, LowV, HighV,
13058                      getV4X86ShuffleImm8ForMask(NewMask, DL, DAG));
13059 }
13060 
13061 /// Lower 4-lane 32-bit floating point shuffles.
13062 ///
13063 /// Uses instructions exclusively from the floating point unit to minimize
13064 /// domain crossing penalties, as these are sufficient to implement all v4f32
13065 /// shuffles.
13066 static SDValue lowerV4F32Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
13067                                  const APInt &Zeroable, SDValue V1, SDValue V2,
13068                                  const X86Subtarget &Subtarget,
13069                                  SelectionDAG &DAG) {
13070   assert(V1.getSimpleValueType() == MVT::v4f32 && "Bad operand type!");
13071   assert(V2.getSimpleValueType() == MVT::v4f32 && "Bad operand type!");
13072   assert(Mask.size() == 4 && "Unexpected mask size for v4 shuffle!");
13073 
13074   if (Subtarget.hasSSE41())
13075     if (SDValue Blend = lowerShuffleAsBlend(DL, MVT::v4f32, V1, V2, Mask,
13076                                             Zeroable, Subtarget, DAG))
13077       return Blend;
13078 
13079   int NumV2Elements = count_if(Mask, [](int M) { return M >= 4; });
13080 
13081   if (NumV2Elements == 0) {
13082     // Check for being able to broadcast a single element.
13083     if (SDValue Broadcast = lowerShuffleAsBroadcast(DL, MVT::v4f32, V1, V2,
13084                                                     Mask, Subtarget, DAG))
13085       return Broadcast;
13086 
13087     // Use even/odd duplicate instructions for masks that match their pattern.
13088     if (Subtarget.hasSSE3()) {
13089       if (isShuffleEquivalent(Mask, {0, 0, 2, 2}, V1, V2))
13090         return DAG.getNode(X86ISD::MOVSLDUP, DL, MVT::v4f32, V1);
13091       if (isShuffleEquivalent(Mask, {1, 1, 3, 3}, V1, V2))
13092         return DAG.getNode(X86ISD::MOVSHDUP, DL, MVT::v4f32, V1);
13093     }
13094 
13095     if (Subtarget.hasAVX()) {
13096       // If we have AVX, we can use VPERMILPS which will allow folding a load
13097       // into the shuffle.
13098       return DAG.getNode(X86ISD::VPERMILPI, DL, MVT::v4f32, V1,
13099                          getV4X86ShuffleImm8ForMask(Mask, DL, DAG));
13100     }
13101 
13102     // Use MOVLHPS/MOVHLPS to simulate unary shuffles. These are only valid
13103     // in SSE1 because otherwise they are widened to v2f64 and never get here.
13104     if (!Subtarget.hasSSE2()) {
13105       if (isShuffleEquivalent(Mask, {0, 1, 0, 1}, V1, V2))
13106         return DAG.getNode(X86ISD::MOVLHPS, DL, MVT::v4f32, V1, V1);
13107       if (isShuffleEquivalent(Mask, {2, 3, 2, 3}, V1, V2))
13108         return DAG.getNode(X86ISD::MOVHLPS, DL, MVT::v4f32, V1, V1);
13109     }
13110 
13111     // Otherwise, use a straight shuffle of a single input vector. We pass the
13112     // input vector to both operands to simulate this with a SHUFPS.
13113     return DAG.getNode(X86ISD::SHUFP, DL, MVT::v4f32, V1, V1,
13114                        getV4X86ShuffleImm8ForMask(Mask, DL, DAG));
13115   }
13116 
13117   if (Subtarget.hasSSE2())
13118     if (SDValue ZExt = lowerShuffleAsZeroOrAnyExtend(
13119             DL, MVT::v4i32, V1, V2, Mask, Zeroable, Subtarget, DAG)) {
13120       ZExt = DAG.getBitcast(MVT::v4f32, ZExt);
13121       return ZExt;
13122     }
13123 
13124   if (Subtarget.hasAVX2())
13125     if (SDValue Extract = lowerShuffleOfExtractsAsVperm(DL, V1, V2, Mask, DAG))
13126       return Extract;
13127 
13128   // There are special ways we can lower some single-element blends. However,
13129   // we also have more complex single-element blend lowerings below that we
13130   // defer to if both this and BLENDPS fail to match, so restrict this to the
13131   // case where the V2 input targets element 0 of the mask -- that is the fast
13132   // case here.
13133   if (NumV2Elements == 1 && Mask[0] >= 4)
13134     if (SDValue V = lowerShuffleAsElementInsertion(
13135             DL, MVT::v4f32, V1, V2, Mask, Zeroable, Subtarget, DAG))
13136       return V;
13137 
13138   if (Subtarget.hasSSE41()) {
13139     // Use INSERTPS if we can complete the shuffle efficiently.
13140     if (SDValue V = lowerShuffleAsInsertPS(DL, V1, V2, Mask, Zeroable, DAG))
13141       return V;
13142 
13143     if (!isSingleSHUFPSMask(Mask))
13144       if (SDValue BlendPerm = lowerShuffleAsBlendAndPermute(DL, MVT::v4f32, V1,
13145                                                             V2, Mask, DAG))
13146         return BlendPerm;
13147   }
13148 
13149   // Use low/high mov instructions. These are only valid in SSE1 because
13150   // otherwise they are widened to v2f64 and never get here.
13151   if (!Subtarget.hasSSE2()) {
13152     if (isShuffleEquivalent(Mask, {0, 1, 4, 5}, V1, V2))
13153       return DAG.getNode(X86ISD::MOVLHPS, DL, MVT::v4f32, V1, V2);
13154     if (isShuffleEquivalent(Mask, {2, 3, 6, 7}, V1, V2))
13155       return DAG.getNode(X86ISD::MOVHLPS, DL, MVT::v4f32, V2, V1);
13156   }
13157 
13158   // Use dedicated unpack instructions for masks that match their pattern.
13159   if (SDValue V = lowerShuffleWithUNPCK(DL, MVT::v4f32, Mask, V1, V2, DAG))
13160     return V;
13161 
13162   // Otherwise fall back to a SHUFPS lowering strategy.
13163   return lowerShuffleWithSHUFPS(DL, MVT::v4f32, Mask, V1, V2, DAG);
13164 }
13165 
13166 /// Lower 4-lane i32 vector shuffles.
13167 ///
13168 /// We try to handle these with integer-domain shuffles where we can, but for
13169 /// blends we use the floating point domain blend instructions.
13170 static SDValue lowerV4I32Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
13171                                  const APInt &Zeroable, SDValue V1, SDValue V2,
13172                                  const X86Subtarget &Subtarget,
13173                                  SelectionDAG &DAG) {
13174   assert(V1.getSimpleValueType() == MVT::v4i32 && "Bad operand type!");
13175   assert(V2.getSimpleValueType() == MVT::v4i32 && "Bad operand type!");
13176   assert(Mask.size() == 4 && "Unexpected mask size for v4 shuffle!");
13177 
13178   // Whenever we can lower this as a zext, that instruction is strictly faster
13179   // than any alternative. It also allows us to fold memory operands into the
13180   // shuffle in many cases.
13181   if (SDValue ZExt = lowerShuffleAsZeroOrAnyExtend(DL, MVT::v4i32, V1, V2, Mask,
13182                                                    Zeroable, Subtarget, DAG))
13183     return ZExt;
13184 
13185   int NumV2Elements = count_if(Mask, [](int M) { return M >= 4; });
13186 
13187   // Try to use shift instructions if fast.
13188   if (Subtarget.preferLowerShuffleAsShift()) {
13189     if (SDValue Shift =
13190             lowerShuffleAsShift(DL, MVT::v4i32, V1, V2, Mask, Zeroable,
13191                                 Subtarget, DAG, /*BitwiseOnly*/ true))
13192       return Shift;
13193     if (NumV2Elements == 0)
13194       if (SDValue Rotate =
13195               lowerShuffleAsBitRotate(DL, MVT::v4i32, V1, Mask, Subtarget, DAG))
13196         return Rotate;
13197   }
13198 
13199   if (NumV2Elements == 0) {
13200     // Try to use broadcast unless the mask only has one non-undef element.
13201     if (count_if(Mask, [](int M) { return M >= 0 && M < 4; }) > 1) {
13202       if (SDValue Broadcast = lowerShuffleAsBroadcast(DL, MVT::v4i32, V1, V2,
13203                                                       Mask, Subtarget, DAG))
13204         return Broadcast;
13205     }
13206 
13207     // Straight shuffle of a single input vector. For everything from SSE2
13208     // onward this has a single fast instruction with no scary immediates.
13209     // We coerce the shuffle pattern to be compatible with UNPCK instructions
13210     // but we aren't actually going to use the UNPCK instruction because doing
13211     // so prevents folding a load into this instruction or making a copy.
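    // For example, an undef-heavy mask such as <0,u,1,u> is canonicalized to
    // <0,0,1,1> here but is still emitted as a PSHUFD.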
13212     const int UnpackLoMask[] = {0, 0, 1, 1};
13213     const int UnpackHiMask[] = {2, 2, 3, 3};
13214     if (isShuffleEquivalent(Mask, {0, 0, 1, 1}, V1, V2))
13215       Mask = UnpackLoMask;
13216     else if (isShuffleEquivalent(Mask, {2, 2, 3, 3}, V1, V2))
13217       Mask = UnpackHiMask;
13218 
13219     return DAG.getNode(X86ISD::PSHUFD, DL, MVT::v4i32, V1,
13220                        getV4X86ShuffleImm8ForMask(Mask, DL, DAG));
13221   }
13222 
13223   if (Subtarget.hasAVX2())
13224     if (SDValue Extract = lowerShuffleOfExtractsAsVperm(DL, V1, V2, Mask, DAG))
13225       return Extract;
13226 
13227   // Try to use shift instructions.
13228   if (SDValue Shift =
13229           lowerShuffleAsShift(DL, MVT::v4i32, V1, V2, Mask, Zeroable, Subtarget,
13230                               DAG, /*BitwiseOnly*/ false))
13231     return Shift;
13232 
13233   // There are special ways we can lower some single-element blends.
13234   if (NumV2Elements == 1)
13235     if (SDValue V = lowerShuffleAsElementInsertion(
13236             DL, MVT::v4i32, V1, V2, Mask, Zeroable, Subtarget, DAG))
13237       return V;
13238 
13239   // We have different paths for blend lowering, but they all must use the
13240   // *exact* same predicate.
13241   bool IsBlendSupported = Subtarget.hasSSE41();
13242   if (IsBlendSupported)
13243     if (SDValue Blend = lowerShuffleAsBlend(DL, MVT::v4i32, V1, V2, Mask,
13244                                             Zeroable, Subtarget, DAG))
13245       return Blend;
13246 
13247   if (SDValue Masked = lowerShuffleAsBitMask(DL, MVT::v4i32, V1, V2, Mask,
13248                                              Zeroable, Subtarget, DAG))
13249     return Masked;
13250 
13251   // Use dedicated unpack instructions for masks that match their pattern.
13252   if (SDValue V = lowerShuffleWithUNPCK(DL, MVT::v4i32, Mask, V1, V2, DAG))
13253     return V;
13254 
13255   // Try to use byte rotation instructions.
13256   // It's more profitable for pre-SSSE3 to use shuffles/unpacks.
13257   if (Subtarget.hasSSSE3()) {
13258     if (Subtarget.hasVLX())
13259       if (SDValue Rotate = lowerShuffleAsVALIGN(DL, MVT::v4i32, V1, V2, Mask,
13260                                                 Zeroable, Subtarget, DAG))
13261         return Rotate;
13262 
13263     if (SDValue Rotate = lowerShuffleAsByteRotate(DL, MVT::v4i32, V1, V2, Mask,
13264                                                   Subtarget, DAG))
13265       return Rotate;
13266   }
13267 
13268   // Assume that a single SHUFPS is faster than an alternative sequence of
13269   // multiple instructions (even if the CPU has a domain penalty).
13270   // If some CPU is harmed by the domain switch, we can fix it in a later pass.
13271   if (!isSingleSHUFPSMask(Mask)) {
13272     // If we have direct support for blends, we should lower by decomposing into
13273     // a permute. That will be faster than the domain cross.
13274     if (IsBlendSupported)
13275       return lowerShuffleAsDecomposedShuffleMerge(DL, MVT::v4i32, V1, V2, Mask,
13276                                                   Subtarget, DAG);
13277 
13278     // Try to lower by permuting the inputs into an unpack instruction.
13279     if (SDValue Unpack = lowerShuffleAsPermuteAndUnpack(DL, MVT::v4i32, V1, V2,
13280                                                         Mask, Subtarget, DAG))
13281       return Unpack;
13282   }
13283 
13284   // We implement this with SHUFPS because it can blend from two vectors.
13285   // Because we're going to eventually use SHUFPS, we use SHUFPS even to build
13286   // up the inputs, bypassing domain shift penalties that we would incur if we
13287   // directly used PSHUFD on Nehalem and older. For newer chips, this isn't
13288   // relevant.
13289   SDValue CastV1 = DAG.getBitcast(MVT::v4f32, V1);
13290   SDValue CastV2 = DAG.getBitcast(MVT::v4f32, V2);
13291   SDValue ShufPS = DAG.getVectorShuffle(MVT::v4f32, DL, CastV1, CastV2, Mask);
13292   return DAG.getBitcast(MVT::v4i32, ShufPS);
13293 }
13294 
13295 /// Lowering of single-input v8i16 shuffles is the cornerstone of SSE2
13296 /// shuffle lowering, and the most complex part.
13297 ///
13298 /// The lowering strategy is to try to form pairs of input lanes which are
13299 /// targeted at the same half of the final vector, and then use a dword shuffle
13300 /// to place them onto the right half, and finally unpack the paired lanes into
13301 /// their final position.
13302 ///
13303 /// The exact breakdown of how to form these dword pairs and align them on the
13304 /// correct sides is really tricky. See the comments within the function for
13305 /// more of the details.
13306 ///
13307 /// This code also handles repeated 128-bit lanes of v8i16 shuffles, but each
13308 /// lane must shuffle the *exact* same way. In fact, you must pass a v8 Mask to
13309 /// this routine for it to work correctly. To shuffle a 256-bit or 512-bit i16
13310 /// vector, form the analogous 128-bit 8-element Mask.
13311 static SDValue lowerV8I16GeneralSingleInputShuffle(
13312     const SDLoc &DL, MVT VT, SDValue V, MutableArrayRef<int> Mask,
13313     const X86Subtarget &Subtarget, SelectionDAG &DAG) {
13314   assert(VT.getVectorElementType() == MVT::i16 && "Bad input type!");
13315   MVT PSHUFDVT = MVT::getVectorVT(MVT::i32, VT.getVectorNumElements() / 2);
13316 
13317   assert(Mask.size() == 8 && "Shuffle mask length doesn't match!");
13318   MutableArrayRef<int> LoMask = Mask.slice(0, 4);
13319   MutableArrayRef<int> HiMask = Mask.slice(4, 4);
13320 
13321   // Attempt to directly match PSHUFLW or PSHUFHW.
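  // For example, <2,1,0,3,4,5,6,7> is a PSHUFLW with immediate mask <2,1,0,3>,
  // and <0,1,2,3,7,6,5,4> is a PSHUFHW once its high half is rebased to
  // <3,2,1,0>.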
13322   if (isUndefOrInRange(LoMask, 0, 4) &&
13323       isSequentialOrUndefInRange(HiMask, 0, 4, 4)) {
13324     return DAG.getNode(X86ISD::PSHUFLW, DL, VT, V,
13325                        getV4X86ShuffleImm8ForMask(LoMask, DL, DAG));
13326   }
13327   if (isUndefOrInRange(HiMask, 4, 8) &&
13328       isSequentialOrUndefInRange(LoMask, 0, 4, 0)) {
13329     for (int i = 0; i != 4; ++i)
13330       HiMask[i] = (HiMask[i] < 0 ? HiMask[i] : (HiMask[i] - 4));
13331     return DAG.getNode(X86ISD::PSHUFHW, DL, VT, V,
13332                        getV4X86ShuffleImm8ForMask(HiMask, DL, DAG));
13333   }
13334 
13335   SmallVector<int, 4> LoInputs;
13336   copy_if(LoMask, std::back_inserter(LoInputs), [](int M) { return M >= 0; });
13337   array_pod_sort(LoInputs.begin(), LoInputs.end());
13338   LoInputs.erase(llvm::unique(LoInputs), LoInputs.end());
13339   SmallVector<int, 4> HiInputs;
13340   copy_if(HiMask, std::back_inserter(HiInputs), [](int M) { return M >= 0; });
13341   array_pod_sort(HiInputs.begin(), HiInputs.end());
13342   HiInputs.erase(llvm::unique(HiInputs), HiInputs.end());
13343   int NumLToL = llvm::lower_bound(LoInputs, 4) - LoInputs.begin();
13344   int NumHToL = LoInputs.size() - NumLToL;
13345   int NumLToH = llvm::lower_bound(HiInputs, 4) - HiInputs.begin();
13346   int NumHToH = HiInputs.size() - NumLToH;
13347   MutableArrayRef<int> LToLInputs(LoInputs.data(), NumLToL);
13348   MutableArrayRef<int> LToHInputs(HiInputs.data(), NumLToH);
13349   MutableArrayRef<int> HToLInputs(LoInputs.data() + NumLToL, NumHToL);
13350   MutableArrayRef<int> HToHInputs(HiInputs.data() + NumLToH, NumHToH);
13351 
13352   // If we are shuffling values from one half - check how many different DWORD
13353   // pairs we need to create. If only 1 or 2 then we can perform this as a
13354   // PSHUFLW/PSHUFHW + PSHUFD instead of the PSHUFD+PSHUFLW+PSHUFHW chain below.
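  // For example, the single-input mask <1,0,1,0,3,2,3,2> needs only the dword
  // pairs (1,0) and (3,2): PSHUFLW <1,0,3,2> forms them in dwords 0 and 1, and
  // PSHUFD <0,0,1,1> then replicates them into place.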
13355   auto ShuffleDWordPairs = [&](ArrayRef<int> PSHUFHalfMask,
13356                                ArrayRef<int> PSHUFDMask, unsigned ShufWOp) {
13357     V = DAG.getNode(ShufWOp, DL, VT, V,
13358                     getV4X86ShuffleImm8ForMask(PSHUFHalfMask, DL, DAG));
13359     V = DAG.getBitcast(PSHUFDVT, V);
13360     V = DAG.getNode(X86ISD::PSHUFD, DL, PSHUFDVT, V,
13361                     getV4X86ShuffleImm8ForMask(PSHUFDMask, DL, DAG));
13362     return DAG.getBitcast(VT, V);
13363   };
13364 
13365   if ((NumHToL + NumHToH) == 0 || (NumLToL + NumLToH) == 0) {
13366     int PSHUFDMask[4] = { -1, -1, -1, -1 };
13367     SmallVector<std::pair<int, int>, 4> DWordPairs;
13368     int DOffset = ((NumHToL + NumHToH) == 0 ? 0 : 2);
13369 
13370     // Collect the different DWORD pairs.
13371     for (int DWord = 0; DWord != 4; ++DWord) {
13372       int M0 = Mask[2 * DWord + 0];
13373       int M1 = Mask[2 * DWord + 1];
13374       M0 = (M0 >= 0 ? M0 % 4 : M0);
13375       M1 = (M1 >= 0 ? M1 % 4 : M1);
13376       if (M0 < 0 && M1 < 0)
13377         continue;
13378 
13379       bool Match = false;
13380       for (int j = 0, e = DWordPairs.size(); j < e; ++j) {
13381         auto &DWordPair = DWordPairs[j];
13382         if ((M0 < 0 || isUndefOrEqual(DWordPair.first, M0)) &&
13383             (M1 < 0 || isUndefOrEqual(DWordPair.second, M1))) {
13384           DWordPair.first = (M0 >= 0 ? M0 : DWordPair.first);
13385           DWordPair.second = (M1 >= 0 ? M1 : DWordPair.second);
13386           PSHUFDMask[DWord] = DOffset + j;
13387           Match = true;
13388           break;
13389         }
13390       }
13391       if (!Match) {
13392         PSHUFDMask[DWord] = DOffset + DWordPairs.size();
13393         DWordPairs.push_back(std::make_pair(M0, M1));
13394       }
13395     }
13396 
13397     if (DWordPairs.size() <= 2) {
13398       DWordPairs.resize(2, std::make_pair(-1, -1));
13399       int PSHUFHalfMask[4] = {DWordPairs[0].first, DWordPairs[0].second,
13400                               DWordPairs[1].first, DWordPairs[1].second};
13401       if ((NumHToL + NumHToH) == 0)
13402         return ShuffleDWordPairs(PSHUFHalfMask, PSHUFDMask, X86ISD::PSHUFLW);
13403       if ((NumLToL + NumLToH) == 0)
13404         return ShuffleDWordPairs(PSHUFHalfMask, PSHUFDMask, X86ISD::PSHUFHW);
13405     }
13406   }
13407 
13408   // Simplify the 1-into-3 and 3-into-1 cases with a single pshufd. For all
13409   // such inputs we can swap two of the dwords across the half mark and end up
13410   // with <=2 inputs to each half in each half. Once there, we can fall through
13411   // to the generic code below. For example:
13412   //
13413   // Input: [a, b, c, d, e, f, g, h] -PSHUFD[0,2,1,3]-> [a, b, e, f, c, d, g, h]
13414   // Mask:  [0, 1, 2, 7, 4, 5, 6, 3] -----------------> [0, 1, 4, 7, 2, 3, 6, 5]
13415   //
13416   // However in some very rare cases we have a 1-into-3 or 3-into-1 on one half
13417   // and an existing 2-into-2 on the other half. In this case we may have to
13418   // pre-shuffle the 2-into-2 half to avoid turning it into a 3-into-1 or
13419   // 1-into-3 which could cause us to cycle endlessly fixing each side in turn.
13420   // Fortunately, we don't have to handle anything but a 2-into-2 pattern
13421   // because any other situation (including a 3-into-1 or 1-into-3 in the other
13422   // half than the one we target for fixing) will be fixed when we re-enter this
13423   // path. We will also combine any resulting sequence of PSHUFD instructions
13424   // into a single instruction. Here is an example of the tricky case:
13425   //
13426   // Input: [a, b, c, d, e, f, g, h] -PSHUFD[0,2,1,3]-> [a, b, e, f, c, d, g, h]
13427   // Mask:  [3, 7, 1, 0, 2, 7, 3, 5] -THIS-IS-BAD!!!!-> [5, 7, 1, 0, 4, 7, 5, 3]
13428   //
13429   // This now has a 1-into-3 in the high half! Instead, we do two shuffles:
13430   //
13431   // Input: [a, b, c, d, e, f, g, h] PSHUFHW[0,2,1,3]-> [a, b, c, d, e, g, f, h]
13432   // Mask:  [3, 7, 1, 0, 2, 7, 3, 5] -----------------> [3, 7, 1, 0, 2, 7, 3, 6]
13433   //
13434   // Input: [a, b, c, d, e, g, f, h] -PSHUFD[0,2,1,3]-> [a, b, e, g, c, d, f, h]
13435   // Mask:  [3, 7, 1, 0, 2, 7, 3, 6] -----------------> [5, 7, 1, 0, 4, 7, 5, 6]
13436   //
13437   // The result is fine to be handled by the generic logic.
13438   auto balanceSides = [&](ArrayRef<int> AToAInputs, ArrayRef<int> BToAInputs,
13439                           ArrayRef<int> BToBInputs, ArrayRef<int> AToBInputs,
13440                           int AOffset, int BOffset) {
13441     assert((AToAInputs.size() == 3 || AToAInputs.size() == 1) &&
13442            "Must call this with A having 3 or 1 inputs from the A half.");
13443     assert((BToAInputs.size() == 1 || BToAInputs.size() == 3) &&
13444            "Must call this with B having 1 or 3 inputs from the B half.");
13445     assert(AToAInputs.size() + BToAInputs.size() == 4 &&
13446            "Must call this with either 3:1 or 1:3 inputs (summing to 4).");
13447 
13448     bool ThreeAInputs = AToAInputs.size() == 3;
13449 
13450     // Compute the index of the dword with only one word among the three
13451     // inputs in a half by taking the sum of the word indices in that half
13452     // and subtracting the sum of the actual three inputs. The difference is
13453     // the remaining slot.
13454     int ADWord = 0, BDWord = 0;
13455     int &TripleDWord = ThreeAInputs ? ADWord : BDWord;
13456     int &OneInputDWord = ThreeAInputs ? BDWord : ADWord;
13457     int TripleInputOffset = ThreeAInputs ? AOffset : BOffset;
13458     ArrayRef<int> TripleInputs = ThreeAInputs ? AToAInputs : BToAInputs;
13459     int OneInput = ThreeAInputs ? BToAInputs[0] : AToAInputs[0];
13460     int TripleInputSum = 0 + 1 + 2 + 3 + (4 * TripleInputOffset);
13461     int TripleNonInputIdx =
13462         TripleInputSum - std::accumulate(TripleInputs.begin(), TripleInputs.end(), 0);
13463     TripleDWord = TripleNonInputIdx / 2;
13464 
13465     // We use xor with one to compute the adjacent DWord to whichever one the
13466     // OneInput is in.
13467     OneInputDWord = (OneInput / 2) ^ 1;
13468 
13469     // Check for one tricky case: We're fixing a 3<-1 or a 1<-3 shuffle for AToA
13470     // and BToA inputs. If there is also such a problem with the BToB and AToB
13471     // inputs, we don't try to fix it necessarily -- we'll recurse and see it in
13472     // the next pass. However, if we have a 2<-2 in the BToB and AToB inputs, it
13473     // is essential that we don't *create* a 3<-1 as then we might oscillate.
13474     if (BToBInputs.size() == 2 && AToBInputs.size() == 2) {
13475       // Compute how many inputs will be flipped by swapping these DWords. We
13476       // need to balance this to ensure we don't form a 3-1 shuffle in the
13477       // other half.
13478 
13479       int NumFlippedAToBInputs = llvm::count(AToBInputs, 2 * ADWord) +
13480                                  llvm::count(AToBInputs, 2 * ADWord + 1);
13481       int NumFlippedBToBInputs = llvm::count(BToBInputs, 2 * BDWord) +
13482                                  llvm::count(BToBInputs, 2 * BDWord + 1);
13483       if ((NumFlippedAToBInputs == 1 &&
13484            (NumFlippedBToBInputs == 0 || NumFlippedBToBInputs == 2)) ||
13485           (NumFlippedBToBInputs == 1 &&
13486            (NumFlippedAToBInputs == 0 || NumFlippedAToBInputs == 2))) {
13487         // We choose whether to fix the A half or B half based on whether that
13488         // half has zero flipped inputs. At zero, we may not be able to fix it
13489         // with that half. We also bias towards fixing the B half because that
13490         // will more commonly be the high half, and we have to bias one way.
13491         auto FixFlippedInputs = [&V, &DL, &Mask, &DAG](int PinnedIdx, int DWord,
13492                                                        ArrayRef<int> Inputs) {
13493           int FixIdx = PinnedIdx ^ 1; // The adjacent slot to the pinned slot.
13494           bool IsFixIdxInput = is_contained(Inputs, PinnedIdx ^ 1);
13495           // Determine whether the free index is in the flipped dword or the
13496           // unflipped dword based on where the pinned index is. We use this bit
13497           // in an xor to conditionally select the adjacent dword.
13498           int FixFreeIdx = 2 * (DWord ^ (PinnedIdx / 2 == DWord));
13499           bool IsFixFreeIdxInput = is_contained(Inputs, FixFreeIdx);
13500           if (IsFixIdxInput == IsFixFreeIdxInput)
13501             FixFreeIdx += 1;
13502           IsFixFreeIdxInput = is_contained(Inputs, FixFreeIdx);
13503           assert(IsFixIdxInput != IsFixFreeIdxInput &&
13504                  "We need to be changing the number of flipped inputs!");
13505           int PSHUFHalfMask[] = {0, 1, 2, 3};
13506           std::swap(PSHUFHalfMask[FixFreeIdx % 4], PSHUFHalfMask[FixIdx % 4]);
13507           V = DAG.getNode(
13508               FixIdx < 4 ? X86ISD::PSHUFLW : X86ISD::PSHUFHW, DL,
13509               MVT::getVectorVT(MVT::i16, V.getValueSizeInBits() / 16), V,
13510               getV4X86ShuffleImm8ForMask(PSHUFHalfMask, DL, DAG));
13511 
13512           for (int &M : Mask)
13513             if (M >= 0 && M == FixIdx)
13514               M = FixFreeIdx;
13515             else if (M >= 0 && M == FixFreeIdx)
13516               M = FixIdx;
13517         };
13518         if (NumFlippedBToBInputs != 0) {
13519           int BPinnedIdx =
13520               BToAInputs.size() == 3 ? TripleNonInputIdx : OneInput;
13521           FixFlippedInputs(BPinnedIdx, BDWord, BToBInputs);
13522         } else {
13523           assert(NumFlippedAToBInputs != 0 && "Impossible given predicates!");
13524           int APinnedIdx = ThreeAInputs ? TripleNonInputIdx : OneInput;
13525           FixFlippedInputs(APinnedIdx, ADWord, AToBInputs);
13526         }
13527       }
13528     }
13529 
13530     int PSHUFDMask[] = {0, 1, 2, 3};
13531     PSHUFDMask[ADWord] = BDWord;
13532     PSHUFDMask[BDWord] = ADWord;
13533     V = DAG.getBitcast(
13534         VT,
13535         DAG.getNode(X86ISD::PSHUFD, DL, PSHUFDVT, DAG.getBitcast(PSHUFDVT, V),
13536                     getV4X86ShuffleImm8ForMask(PSHUFDMask, DL, DAG)));
13537 
13538     // Adjust the mask to match the new locations of A and B.
13539     for (int &M : Mask)
13540       if (M >= 0 && M/2 == ADWord)
13541         M = 2 * BDWord + M % 2;
13542       else if (M >= 0 && M/2 == BDWord)
13543         M = 2 * ADWord + M % 2;
13544 
13545     // Recurse back into this routine to re-compute state now that this isn't
13546     // a 3 and 1 problem.
13547     return lowerV8I16GeneralSingleInputShuffle(DL, VT, V, Mask, Subtarget, DAG);
13548   };
13549   if ((NumLToL == 3 && NumHToL == 1) || (NumLToL == 1 && NumHToL == 3))
13550     return balanceSides(LToLInputs, HToLInputs, HToHInputs, LToHInputs, 0, 4);
13551   if ((NumHToH == 3 && NumLToH == 1) || (NumHToH == 1 && NumLToH == 3))
13552     return balanceSides(HToHInputs, LToHInputs, LToLInputs, HToLInputs, 4, 0);
13553 
13554   // At this point there are at most two inputs to the low and high halves from
13555   // each half. That means the inputs can always be grouped into dwords and
13556   // those dwords can then be moved to the correct half with a dword shuffle.
13557   // We use at most one low and one high word shuffle to collect these paired
13558   // inputs into dwords, and finally a dword shuffle to place them.
13559   int PSHUFLMask[4] = {-1, -1, -1, -1};
13560   int PSHUFHMask[4] = {-1, -1, -1, -1};
13561   int PSHUFDMask[4] = {-1, -1, -1, -1};
13562 
13563   // First fix the masks for all the inputs that are staying in their
13564   // original halves. This will then dictate the targets of the cross-half
13565   // shuffles.
13566   auto fixInPlaceInputs =
13567       [&PSHUFDMask](ArrayRef<int> InPlaceInputs, ArrayRef<int> IncomingInputs,
13568                     MutableArrayRef<int> SourceHalfMask,
13569                     MutableArrayRef<int> HalfMask, int HalfOffset) {
13570     if (InPlaceInputs.empty())
13571       return;
13572     if (InPlaceInputs.size() == 1) {
13573       SourceHalfMask[InPlaceInputs[0] - HalfOffset] =
13574           InPlaceInputs[0] - HalfOffset;
13575       PSHUFDMask[InPlaceInputs[0] / 2] = InPlaceInputs[0] / 2;
13576       return;
13577     }
13578     if (IncomingInputs.empty()) {
13579       // Just fix all of the in place inputs.
13580       for (int Input : InPlaceInputs) {
13581         SourceHalfMask[Input - HalfOffset] = Input - HalfOffset;
13582         PSHUFDMask[Input / 2] = Input / 2;
13583       }
13584       return;
13585     }
13586 
13587     assert(InPlaceInputs.size() == 2 && "Cannot handle 3 or 4 inputs!");
13588     SourceHalfMask[InPlaceInputs[0] - HalfOffset] =
13589         InPlaceInputs[0] - HalfOffset;
13590     // Put the second input next to the first so that they are packed into
13591     // a dword. We find the adjacent index by toggling the low bit.
13592     int AdjIndex = InPlaceInputs[0] ^ 1;
13593     SourceHalfMask[AdjIndex - HalfOffset] = InPlaceInputs[1] - HalfOffset;
13594     std::replace(HalfMask.begin(), HalfMask.end(), InPlaceInputs[1], AdjIndex);
13595     PSHUFDMask[AdjIndex / 2] = AdjIndex / 2;
13596   };
13597   fixInPlaceInputs(LToLInputs, HToLInputs, PSHUFLMask, LoMask, 0);
13598   fixInPlaceInputs(HToHInputs, LToHInputs, PSHUFHMask, HiMask, 4);
13599 
13600   // Now gather the cross-half inputs and place them into a free dword of
13601   // their target half.
13602   // FIXME: This operation could almost certainly be simplified dramatically to
13603   // look more like the 3-1 fixing operation.
13604   auto moveInputsToRightHalf = [&PSHUFDMask](
13605       MutableArrayRef<int> IncomingInputs, ArrayRef<int> ExistingInputs,
13606       MutableArrayRef<int> SourceHalfMask, MutableArrayRef<int> HalfMask,
13607       MutableArrayRef<int> FinalSourceHalfMask, int SourceOffset,
13608       int DestOffset) {
13609     auto isWordClobbered = [](ArrayRef<int> SourceHalfMask, int Word) {
13610       return SourceHalfMask[Word] >= 0 && SourceHalfMask[Word] != Word;
13611     };
13612     auto isDWordClobbered = [&isWordClobbered](ArrayRef<int> SourceHalfMask,
13613                                                int Word) {
13614       int LowWord = Word & ~1;
13615       int HighWord = Word | 1;
13616       return isWordClobbered(SourceHalfMask, LowWord) ||
13617              isWordClobbered(SourceHalfMask, HighWord);
13618     };
13619 
13620     if (IncomingInputs.empty())
13621       return;
13622 
13623     if (ExistingInputs.empty()) {
13624       // Map any dwords with inputs from them into the right half.
13625       for (int Input : IncomingInputs) {
13626         // If the source half mask maps over the inputs, turn those into
13627         // swaps and use the swapped lane.
13628         if (isWordClobbered(SourceHalfMask, Input - SourceOffset)) {
13629           if (SourceHalfMask[SourceHalfMask[Input - SourceOffset]] < 0) {
13630             SourceHalfMask[SourceHalfMask[Input - SourceOffset]] =
13631                 Input - SourceOffset;
13632             // We have to swap the uses in our half mask in one sweep.
13633             for (int &M : HalfMask)
13634               if (M == SourceHalfMask[Input - SourceOffset] + SourceOffset)
13635                 M = Input;
13636               else if (M == Input)
13637                 M = SourceHalfMask[Input - SourceOffset] + SourceOffset;
13638           } else {
13639             assert(SourceHalfMask[SourceHalfMask[Input - SourceOffset]] ==
13640                        Input - SourceOffset &&
13641                    "Previous placement doesn't match!");
13642           }
13643           // Note that this correctly re-maps both when we do a swap and when
13644           // we observe the other side of the swap above. We rely on that to
13645           // avoid swapping the members of the input list directly.
13646           Input = SourceHalfMask[Input - SourceOffset] + SourceOffset;
13647         }
13648 
13649         // Map the input's dword into the correct half.
13650         if (PSHUFDMask[(Input - SourceOffset + DestOffset) / 2] < 0)
13651           PSHUFDMask[(Input - SourceOffset + DestOffset) / 2] = Input / 2;
13652         else
13653           assert(PSHUFDMask[(Input - SourceOffset + DestOffset) / 2] ==
13654                      Input / 2 &&
13655                  "Previous placement doesn't match!");
13656       }
13657 
13658       // And just directly shift any other-half mask elements to be same-half
13659       // as we will have mirrored the dword containing the element into the
13660       // same position within that half.
13661       for (int &M : HalfMask)
13662         if (M >= SourceOffset && M < SourceOffset + 4) {
13663           M = M - SourceOffset + DestOffset;
13664           assert(M >= 0 && "This should never wrap below zero!");
13665         }
13666       return;
13667     }
13668 
13669     // Ensure we have the input in a viable dword of its current half. This
13670     // is particularly tricky because the original position may be clobbered
13671     // by inputs being moved and *staying* in that half.
13672     if (IncomingInputs.size() == 1) {
13673       if (isWordClobbered(SourceHalfMask, IncomingInputs[0] - SourceOffset)) {
13674         int InputFixed = find(SourceHalfMask, -1) - std::begin(SourceHalfMask) +
13675                          SourceOffset;
13676         SourceHalfMask[InputFixed - SourceOffset] =
13677             IncomingInputs[0] - SourceOffset;
13678         std::replace(HalfMask.begin(), HalfMask.end(), IncomingInputs[0],
13679                      InputFixed);
13680         IncomingInputs[0] = InputFixed;
13681       }
13682     } else if (IncomingInputs.size() == 2) {
13683       if (IncomingInputs[0] / 2 != IncomingInputs[1] / 2 ||
13684           isDWordClobbered(SourceHalfMask, IncomingInputs[0] - SourceOffset)) {
13685         // We have two non-adjacent or clobbered inputs we need to extract from
13686         // the source half. To do this, we need to map them into some adjacent
13687         // dword slot in the source mask.
13688         int InputsFixed[2] = {IncomingInputs[0] - SourceOffset,
13689                               IncomingInputs[1] - SourceOffset};
13690 
13691         // If there is a free slot in the source half mask adjacent to one of
13692         // the inputs, place the other input in it. We use (Index XOR 1) to
13693         // compute an adjacent index.
13694         if (!isWordClobbered(SourceHalfMask, InputsFixed[0]) &&
13695             SourceHalfMask[InputsFixed[0] ^ 1] < 0) {
13696           SourceHalfMask[InputsFixed[0]] = InputsFixed[0];
13697           SourceHalfMask[InputsFixed[0] ^ 1] = InputsFixed[1];
13698           InputsFixed[1] = InputsFixed[0] ^ 1;
13699         } else if (!isWordClobbered(SourceHalfMask, InputsFixed[1]) &&
13700                    SourceHalfMask[InputsFixed[1] ^ 1] < 0) {
13701           SourceHalfMask[InputsFixed[1]] = InputsFixed[1];
13702           SourceHalfMask[InputsFixed[1] ^ 1] = InputsFixed[0];
13703           InputsFixed[0] = InputsFixed[1] ^ 1;
13704         } else if (SourceHalfMask[2 * ((InputsFixed[0] / 2) ^ 1)] < 0 &&
13705                    SourceHalfMask[2 * ((InputsFixed[0] / 2) ^ 1) + 1] < 0) {
13706           // The two inputs are in the same DWord but it is clobbered and the
13707           // adjacent DWord isn't used at all. Move both inputs to the free
13708           // slot.
13709           SourceHalfMask[2 * ((InputsFixed[0] / 2) ^ 1)] = InputsFixed[0];
13710           SourceHalfMask[2 * ((InputsFixed[0] / 2) ^ 1) + 1] = InputsFixed[1];
13711           InputsFixed[0] = 2 * ((InputsFixed[0] / 2) ^ 1);
13712           InputsFixed[1] = 2 * ((InputsFixed[0] / 2) ^ 1) + 1;
13713         } else {
13714           // The only way we hit this point is if there is no clobbering
13715           // (because there are no off-half inputs to this half) and there is no
13716           // free slot adjacent to one of the inputs. In this case, we have to
13717           // swap an input with a non-input.
13718           for (int i = 0; i < 4; ++i)
13719             assert((SourceHalfMask[i] < 0 || SourceHalfMask[i] == i) &&
13720                    "We can't handle any clobbers here!");
13721           assert(InputsFixed[1] != (InputsFixed[0] ^ 1) &&
13722                  "Cannot have adjacent inputs here!");
13723 
13724           SourceHalfMask[InputsFixed[0] ^ 1] = InputsFixed[1];
13725           SourceHalfMask[InputsFixed[1]] = InputsFixed[0] ^ 1;
13726 
13727           // We also have to update the final source mask in this case because
13728           // it may need to undo the above swap.
13729           for (int &M : FinalSourceHalfMask)
13730             if (M == (InputsFixed[0] ^ 1) + SourceOffset)
13731               M = InputsFixed[1] + SourceOffset;
13732             else if (M == InputsFixed[1] + SourceOffset)
13733               M = (InputsFixed[0] ^ 1) + SourceOffset;
13734 
13735           InputsFixed[1] = InputsFixed[0] ^ 1;
13736         }
13737 
13738         // Point everything at the fixed inputs.
13739         for (int &M : HalfMask)
13740           if (M == IncomingInputs[0])
13741             M = InputsFixed[0] + SourceOffset;
13742           else if (M == IncomingInputs[1])
13743             M = InputsFixed[1] + SourceOffset;
13744 
13745         IncomingInputs[0] = InputsFixed[0] + SourceOffset;
13746         IncomingInputs[1] = InputsFixed[1] + SourceOffset;
13747       }
13748     } else {
13749       llvm_unreachable("Unhandled input size!");
13750     }
13751 
13752     // Now hoist the DWord down to the right half.
13753     int FreeDWord = (PSHUFDMask[DestOffset / 2] < 0 ? 0 : 1) + DestOffset / 2;
13754     assert(PSHUFDMask[FreeDWord] < 0 && "DWord not free");
13755     PSHUFDMask[FreeDWord] = IncomingInputs[0] / 2;
13756     for (int &M : HalfMask)
13757       for (int Input : IncomingInputs)
13758         if (M == Input)
13759           M = FreeDWord * 2 + Input % 2;
13760   };
13761   moveInputsToRightHalf(HToLInputs, LToLInputs, PSHUFHMask, LoMask, HiMask,
13762                         /*SourceOffset*/ 4, /*DestOffset*/ 0);
13763   moveInputsToRightHalf(LToHInputs, HToHInputs, PSHUFLMask, HiMask, LoMask,
13764                         /*SourceOffset*/ 0, /*DestOffset*/ 4);
13765 
13766   // Now enact all the shuffles we've computed to move the inputs into their
13767   // target half.
13768   if (!isNoopShuffleMask(PSHUFLMask))
13769     V = DAG.getNode(X86ISD::PSHUFLW, DL, VT, V,
13770                     getV4X86ShuffleImm8ForMask(PSHUFLMask, DL, DAG));
13771   if (!isNoopShuffleMask(PSHUFHMask))
13772     V = DAG.getNode(X86ISD::PSHUFHW, DL, VT, V,
13773                     getV4X86ShuffleImm8ForMask(PSHUFHMask, DL, DAG));
13774   if (!isNoopShuffleMask(PSHUFDMask))
13775     V = DAG.getBitcast(
13776         VT,
13777         DAG.getNode(X86ISD::PSHUFD, DL, PSHUFDVT, DAG.getBitcast(PSHUFDVT, V),
13778                     getV4X86ShuffleImm8ForMask(PSHUFDMask, DL, DAG)));
13779 
13780   // At this point, each half should contain all its inputs, and we can then
13781   // just shuffle them into their final position.
13782   assert(count_if(LoMask, [](int M) { return M >= 4; }) == 0 &&
13783          "Failed to lift all the high half inputs to the low mask!");
13784   assert(count_if(HiMask, [](int M) { return M >= 0 && M < 4; }) == 0 &&
13785          "Failed to lift all the low half inputs to the high mask!");
13786 
13787   // Do a half shuffle for the low mask.
13788   if (!isNoopShuffleMask(LoMask))
13789     V = DAG.getNode(X86ISD::PSHUFLW, DL, VT, V,
13790                     getV4X86ShuffleImm8ForMask(LoMask, DL, DAG));
13791 
13792   // Do a half shuffle with the high mask after shifting its values down.
13793   for (int &M : HiMask)
13794     if (M >= 0)
13795       M -= 4;
13796   if (!isNoopShuffleMask(HiMask))
13797     V = DAG.getNode(X86ISD::PSHUFHW, DL, VT, V,
13798                     getV4X86ShuffleImm8ForMask(HiMask, DL, DAG));
13799 
13800   return V;
13801 }
13802 
13803 /// Helper to form a PSHUFB-based shuffle+blend, opportunistically avoiding the
13804 /// blend if only one input is used.
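/// The element mask is rescaled to bytes; e.g. for v8i16 a mask entry of 9
/// (element 1 of V2) becomes V2 byte selectors 2 and 3, while the matching V1
/// selector bytes are set to 0x80 so that PSHUFB zeroes them before the OR of
/// the two shuffled inputs.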
13805 static SDValue lowerShuffleAsBlendOfPSHUFBs(
13806     const SDLoc &DL, MVT VT, SDValue V1, SDValue V2, ArrayRef<int> Mask,
13807     const APInt &Zeroable, SelectionDAG &DAG, bool &V1InUse, bool &V2InUse) {
13808   assert(!is128BitLaneCrossingShuffleMask(VT, Mask) &&
13809          "Lane crossing shuffle masks not supported");
13810 
13811   int NumBytes = VT.getSizeInBits() / 8;
13812   int Size = Mask.size();
13813   int Scale = NumBytes / Size;
13814 
13815   SmallVector<SDValue, 64> V1Mask(NumBytes, DAG.getUNDEF(MVT::i8));
13816   SmallVector<SDValue, 64> V2Mask(NumBytes, DAG.getUNDEF(MVT::i8));
13817   V1InUse = false;
13818   V2InUse = false;
13819 
13820   for (int i = 0; i < NumBytes; ++i) {
13821     int M = Mask[i / Scale];
13822     if (M < 0)
13823       continue;
13824 
13825     const int ZeroMask = 0x80;
13826     int V1Idx = M < Size ? M * Scale + i % Scale : ZeroMask;
13827     int V2Idx = M < Size ? ZeroMask : (M - Size) * Scale + i % Scale;
13828     if (Zeroable[i / Scale])
13829       V1Idx = V2Idx = ZeroMask;
13830 
13831     V1Mask[i] = DAG.getConstant(V1Idx, DL, MVT::i8);
13832     V2Mask[i] = DAG.getConstant(V2Idx, DL, MVT::i8);
13833     V1InUse |= (ZeroMask != V1Idx);
13834     V2InUse |= (ZeroMask != V2Idx);
13835   }
13836 
13837   MVT ShufVT = MVT::getVectorVT(MVT::i8, NumBytes);
13838   if (V1InUse)
13839     V1 = DAG.getNode(X86ISD::PSHUFB, DL, ShufVT, DAG.getBitcast(ShufVT, V1),
13840                      DAG.getBuildVector(ShufVT, DL, V1Mask));
13841   if (V2InUse)
13842     V2 = DAG.getNode(X86ISD::PSHUFB, DL, ShufVT, DAG.getBitcast(ShufVT, V2),
13843                      DAG.getBuildVector(ShufVT, DL, V2Mask));
13844 
13845   // If we need shuffled inputs from both, blend the two.
13846   SDValue V;
13847   if (V1InUse && V2InUse)
13848     V = DAG.getNode(ISD::OR, DL, ShufVT, V1, V2);
13849   else
13850     V = V1InUse ? V1 : V2;
13851 
13852   // Cast the result back to the correct type.
13853   return DAG.getBitcast(VT, V);
13854 }
13855 
13856 /// Generic lowering of 8-lane i16 shuffles.
13857 ///
13858 /// This handles both single-input shuffles and combined shuffle/blends with
13859 /// two inputs. The single input shuffles are immediately delegated to
13860 /// a dedicated lowering routine.
13861 ///
13862 /// The blends are lowered in one of three fundamental ways. If there are few
13863 /// enough inputs, it delegates to a basic UNPCK-based strategy. If the shuffle
13864 /// of the input is significantly cheaper when lowered as an interleaving of
13865 /// the two inputs, try to interleave them. Otherwise, blend the low and high
13866 /// halves of the inputs separately (making them have relatively few inputs)
13867 /// and then concatenate them.
13868 static SDValue lowerV8I16Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
13869                                  const APInt &Zeroable, SDValue V1, SDValue V2,
13870                                  const X86Subtarget &Subtarget,
13871                                  SelectionDAG &DAG) {
13872   assert(V1.getSimpleValueType() == MVT::v8i16 && "Bad operand type!");
13873   assert(V2.getSimpleValueType() == MVT::v8i16 && "Bad operand type!");
13874   assert(Mask.size() == 8 && "Unexpected mask size for v8 shuffle!");
13875 
13876   // Whenever we can lower this as a zext, that instruction is strictly faster
13877   // than any alternative.
13878   if (SDValue ZExt = lowerShuffleAsZeroOrAnyExtend(DL, MVT::v8i16, V1, V2, Mask,
13879                                                    Zeroable, Subtarget, DAG))
13880     return ZExt;
13881 
13882   // Try to lower using a truncation.
13883   if (SDValue V = lowerShuffleWithVPMOV(DL, MVT::v8i16, V1, V2, Mask, Zeroable,
13884                                         Subtarget, DAG))
13885     return V;
13886 
13887   int NumV2Inputs = count_if(Mask, [](int M) { return M >= 8; });
13888 
13889   if (NumV2Inputs == 0) {
13890     // Try to use shift instructions.
13891     if (SDValue Shift =
13892             lowerShuffleAsShift(DL, MVT::v8i16, V1, V1, Mask, Zeroable,
13893                                 Subtarget, DAG, /*BitwiseOnly*/ false))
13894       return Shift;
13895 
13896     // Check for being able to broadcast a single element.
13897     if (SDValue Broadcast = lowerShuffleAsBroadcast(DL, MVT::v8i16, V1, V2,
13898                                                     Mask, Subtarget, DAG))
13899       return Broadcast;
13900 
13901     // Try to use bit rotation instructions.
13902     if (SDValue Rotate = lowerShuffleAsBitRotate(DL, MVT::v8i16, V1, Mask,
13903                                                  Subtarget, DAG))
13904       return Rotate;
13905 
13906     // Use dedicated unpack instructions for masks that match their pattern.
13907     if (SDValue V = lowerShuffleWithUNPCK(DL, MVT::v8i16, Mask, V1, V2, DAG))
13908       return V;
13909 
13910     // Use dedicated pack instructions for masks that match their pattern.
13911     if (SDValue V = lowerShuffleWithPACK(DL, MVT::v8i16, Mask, V1, V2, DAG,
13912                                          Subtarget))
13913       return V;
13914 
13915     // Try to use byte rotation instructions.
13916     if (SDValue Rotate = lowerShuffleAsByteRotate(DL, MVT::v8i16, V1, V1, Mask,
13917                                                   Subtarget, DAG))
13918       return Rotate;
13919 
13920     // Make a copy of the mask so it can be modified.
13921     SmallVector<int, 8> MutableMask(Mask);
13922     return lowerV8I16GeneralSingleInputShuffle(DL, MVT::v8i16, V1, MutableMask,
13923                                                Subtarget, DAG);
13924   }
13925 
13926   assert(llvm::any_of(Mask, [](int M) { return M >= 0 && M < 8; }) &&
13927          "All single-input shuffles should be canonicalized to be V1-input "
13928          "shuffles.");
13929 
13930   // Try to use shift instructions.
13931   if (SDValue Shift =
13932           lowerShuffleAsShift(DL, MVT::v8i16, V1, V2, Mask, Zeroable, Subtarget,
13933                               DAG, /*BitwiseOnly*/ false))
13934     return Shift;
13935 
13936   // See if we can use SSE4A Extraction / Insertion.
13937   if (Subtarget.hasSSE4A())
13938     if (SDValue V = lowerShuffleWithSSE4A(DL, MVT::v8i16, V1, V2, Mask,
13939                                           Zeroable, DAG))
13940       return V;
13941 
13942   // There are special ways we can lower some single-element blends.
13943   if (NumV2Inputs == 1)
13944     if (SDValue V = lowerShuffleAsElementInsertion(
13945             DL, MVT::v8i16, V1, V2, Mask, Zeroable, Subtarget, DAG))
13946       return V;
13947 
13948   // We have different paths for blend lowering, but they all must use the
13949   // *exact* same predicate.
13950   bool IsBlendSupported = Subtarget.hasSSE41();
13951   if (IsBlendSupported)
13952     if (SDValue Blend = lowerShuffleAsBlend(DL, MVT::v8i16, V1, V2, Mask,
13953                                             Zeroable, Subtarget, DAG))
13954       return Blend;
13955 
13956   if (SDValue Masked = lowerShuffleAsBitMask(DL, MVT::v8i16, V1, V2, Mask,
13957                                              Zeroable, Subtarget, DAG))
13958     return Masked;
13959 
13960   // Use dedicated unpack instructions for masks that match their pattern.
13961   if (SDValue V = lowerShuffleWithUNPCK(DL, MVT::v8i16, Mask, V1, V2, DAG))
13962     return V;
13963 
13964   // Use dedicated pack instructions for masks that match their pattern.
13965   if (SDValue V = lowerShuffleWithPACK(DL, MVT::v8i16, Mask, V1, V2, DAG,
13966                                        Subtarget))
13967     return V;
13968 
13969   // Try to lower using a truncation.
13970   if (SDValue V = lowerShuffleAsVTRUNC(DL, MVT::v8i16, V1, V2, Mask, Zeroable,
13971                                        Subtarget, DAG))
13972     return V;
13973 
13974   // Try to use byte rotation instructions.
13975   if (SDValue Rotate = lowerShuffleAsByteRotate(DL, MVT::v8i16, V1, V2, Mask,
13976                                                 Subtarget, DAG))
13977     return Rotate;
13978 
13979   if (SDValue BitBlend =
13980           lowerShuffleAsBitBlend(DL, MVT::v8i16, V1, V2, Mask, DAG))
13981     return BitBlend;
13982 
13983   // Try to use byte shift instructions to mask.
13984   if (SDValue V = lowerShuffleAsByteShiftMask(DL, MVT::v8i16, V1, V2, Mask,
13985                                               Zeroable, Subtarget, DAG))
13986     return V;
13987 
13988   // Attempt to lower using compaction; SSE41 is necessary for PACKUSDW.
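  // For example, Mask <0,2,4,6,8,10,12,14> can be lowered by clearing the upper
  // 16 bits of every dword in both inputs and packing them back together with
  // PACKUSDW.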
13989   int NumEvenDrops = canLowerByDroppingElements(Mask, true, false);
13990   if ((NumEvenDrops == 1 || (NumEvenDrops == 2 && Subtarget.hasSSE41())) &&
13991       !Subtarget.hasVLX()) {
13992     // Check if this is part of a 256-bit vector truncation.
13993     unsigned PackOpc = 0;
13994     if (NumEvenDrops == 2 && Subtarget.hasAVX2() &&
13995         peekThroughBitcasts(V1).getOpcode() == ISD::EXTRACT_SUBVECTOR &&
13996         peekThroughBitcasts(V2).getOpcode() == ISD::EXTRACT_SUBVECTOR) {
13997       SDValue V1V2 = concatSubVectors(V1, V2, DAG, DL);
13998       V1V2 = DAG.getNode(X86ISD::BLENDI, DL, MVT::v16i16, V1V2,
13999                          getZeroVector(MVT::v16i16, Subtarget, DAG, DL),
14000                          DAG.getTargetConstant(0xEE, DL, MVT::i8));
14001       V1V2 = DAG.getBitcast(MVT::v8i32, V1V2);
14002       V1 = extract128BitVector(V1V2, 0, DAG, DL);
14003       V2 = extract128BitVector(V1V2, 4, DAG, DL);
14004       PackOpc = X86ISD::PACKUS;
14005     } else if (Subtarget.hasSSE41()) {
14006       SmallVector<SDValue, 4> DWordClearOps(4,
14007                                             DAG.getConstant(0, DL, MVT::i32));
14008       for (unsigned i = 0; i != 4; i += 1 << (NumEvenDrops - 1))
14009         DWordClearOps[i] = DAG.getConstant(0xFFFF, DL, MVT::i32);
14010       SDValue DWordClearMask =
14011           DAG.getBuildVector(MVT::v4i32, DL, DWordClearOps);
14012       V1 = DAG.getNode(ISD::AND, DL, MVT::v4i32, DAG.getBitcast(MVT::v4i32, V1),
14013                        DWordClearMask);
14014       V2 = DAG.getNode(ISD::AND, DL, MVT::v4i32, DAG.getBitcast(MVT::v4i32, V2),
14015                        DWordClearMask);
14016       PackOpc = X86ISD::PACKUS;
14017     } else if (!Subtarget.hasSSSE3()) {
14018       SDValue ShAmt = DAG.getTargetConstant(16, DL, MVT::i8);
14019       V1 = DAG.getBitcast(MVT::v4i32, V1);
14020       V2 = DAG.getBitcast(MVT::v4i32, V2);
14021       V1 = DAG.getNode(X86ISD::VSHLI, DL, MVT::v4i32, V1, ShAmt);
14022       V2 = DAG.getNode(X86ISD::VSHLI, DL, MVT::v4i32, V2, ShAmt);
14023       V1 = DAG.getNode(X86ISD::VSRAI, DL, MVT::v4i32, V1, ShAmt);
14024       V2 = DAG.getNode(X86ISD::VSRAI, DL, MVT::v4i32, V2, ShAmt);
14025       PackOpc = X86ISD::PACKSS;
14026     }
14027     if (PackOpc) {
14028       // Now pack things back together.
14029       SDValue Result = DAG.getNode(PackOpc, DL, MVT::v8i16, V1, V2);
14030       if (NumEvenDrops == 2) {
14031         Result = DAG.getBitcast(MVT::v4i32, Result);
14032         Result = DAG.getNode(PackOpc, DL, MVT::v8i16, Result, Result);
14033       }
14034       return Result;
14035     }
14036   }
14037 
14038   // When compacting odd (upper) elements, use PACKSS pre-SSE41.
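  // For example, the mask <1,3,5,7,9,11,13,15> keeps every odd word. Shifting
  // each 32-bit lane right by 16 moves those words into the even positions;
  // with SSE4.1 a logical shift plus PACKUSDW suffices, while pre-SSE4.1 an
  // arithmetic shift keeps the values in signed range so PACKSSDW is lossless.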
14039   int NumOddDrops = canLowerByDroppingElements(Mask, false, false);
14040   if (NumOddDrops == 1) {
14041     bool HasSSE41 = Subtarget.hasSSE41();
14042     V1 = DAG.getNode(HasSSE41 ? X86ISD::VSRLI : X86ISD::VSRAI, DL, MVT::v4i32,
14043                      DAG.getBitcast(MVT::v4i32, V1),
14044                      DAG.getTargetConstant(16, DL, MVT::i8));
14045     V2 = DAG.getNode(HasSSE41 ? X86ISD::VSRLI : X86ISD::VSRAI, DL, MVT::v4i32,
14046                      DAG.getBitcast(MVT::v4i32, V2),
14047                      DAG.getTargetConstant(16, DL, MVT::i8));
14048     return DAG.getNode(HasSSE41 ? X86ISD::PACKUS : X86ISD::PACKSS, DL,
14049                        MVT::v8i16, V1, V2);
14050   }
14051 
14052   // Try to lower by permuting the inputs into an unpack instruction.
14053   if (SDValue Unpack = lowerShuffleAsPermuteAndUnpack(DL, MVT::v8i16, V1, V2,
14054                                                       Mask, Subtarget, DAG))
14055     return Unpack;
14056 
14057   // If we can't directly blend but can use PSHUFB, that will be better as it
14058   // can both shuffle and set up the inefficient blend.
14059   if (!IsBlendSupported && Subtarget.hasSSSE3()) {
14060     bool V1InUse, V2InUse;
14061     return lowerShuffleAsBlendOfPSHUFBs(DL, MVT::v8i16, V1, V2, Mask,
14062                                         Zeroable, DAG, V1InUse, V2InUse);
14063   }
14064 
14065   // We can always bit-blend if we have to so the fallback strategy is to
14066   // We can always bit-blend if we have to, so the fallback strategy is to
14067   return lowerShuffleAsDecomposedShuffleMerge(DL, MVT::v8i16, V1, V2,
14068                                               Mask, Subtarget, DAG);
14069 }
14070 
14071 /// Lower 8-lane 16-bit floating point shuffles.
14072 static SDValue lowerV8F16Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
14073                                  const APInt &Zeroable, SDValue V1, SDValue V2,
14074                                  const X86Subtarget &Subtarget,
14075                                  SelectionDAG &DAG) {
14076   assert(V1.getSimpleValueType() == MVT::v8f16 && "Bad operand type!");
14077   assert(V2.getSimpleValueType() == MVT::v8f16 && "Bad operand type!");
14078   assert(Mask.size() == 8 && "Unexpected mask size for v8 shuffle!");
14079   int NumV2Elements = count_if(Mask, [](int M) { return M >= 8; });
14080 
14081   if (Subtarget.hasFP16()) {
14082     if (NumV2Elements == 0) {
14083       // Check for being able to broadcast a single element.
14084       if (SDValue Broadcast = lowerShuffleAsBroadcast(DL, MVT::v8f16, V1, V2,
14085                                                       Mask, Subtarget, DAG))
14086         return Broadcast;
14087     }
14088     if (NumV2Elements == 1 && Mask[0] >= 8)
14089       if (SDValue V = lowerShuffleAsElementInsertion(
14090               DL, MVT::v8f16, V1, V2, Mask, Zeroable, Subtarget, DAG))
14091         return V;
14092   }
14093 
14094   V1 = DAG.getBitcast(MVT::v8i16, V1);
14095   V2 = DAG.getBitcast(MVT::v8i16, V2);
14096   return DAG.getBitcast(MVT::v8f16,
14097                         DAG.getVectorShuffle(MVT::v8i16, DL, V1, V2, Mask));
14098 }
14099 
14100 // Lowers unary/binary shuffle as VPERMV/VPERMV3, for non-VLX targets,
14101 // sub-512-bit shuffles are padded to 512-bits for the shuffle and then
14102 // the active subvector is extracted.
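// For example, a two-input v8i16 shuffle on AVX512BW without VLX is widened to
// v32i16: Scale is 512/128 = 4, so mask entries that referred to V2 (>= 8) are
// rebased by (Scale - 1) * 8 = 24 (index 9 becomes 33, i.e. element 1 of the
// widened V2 for VPERMV3), and the low 128 bits of the result are extracted.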
14103 static SDValue lowerShuffleWithPERMV(const SDLoc &DL, MVT VT,
14104                                      ArrayRef<int> Mask, SDValue V1, SDValue V2,
14105                                      const X86Subtarget &Subtarget,
14106                                      SelectionDAG &DAG) {
14107   MVT MaskVT = VT.changeTypeToInteger();
14108   SDValue MaskNode;
14109   MVT ShuffleVT = VT;
14110   if (!VT.is512BitVector() && !Subtarget.hasVLX()) {
14111     V1 = widenSubVector(V1, false, Subtarget, DAG, DL, 512);
14112     V2 = widenSubVector(V2, false, Subtarget, DAG, DL, 512);
14113     ShuffleVT = V1.getSimpleValueType();
14114 
14115     // Adjust mask to correct indices for the second input.
14116     int NumElts = VT.getVectorNumElements();
14117     unsigned Scale = 512 / VT.getSizeInBits();
14118     SmallVector<int, 32> AdjustedMask(Mask);
14119     for (int &M : AdjustedMask)
14120       if (NumElts <= M)
14121         M += (Scale - 1) * NumElts;
14122     MaskNode = getConstVector(AdjustedMask, MaskVT, DAG, DL, true);
14123     MaskNode = widenSubVector(MaskNode, false, Subtarget, DAG, DL, 512);
14124   } else {
14125     MaskNode = getConstVector(Mask, MaskVT, DAG, DL, true);
14126   }
14127 
14128   SDValue Result;
14129   if (V2.isUndef())
14130     Result = DAG.getNode(X86ISD::VPERMV, DL, ShuffleVT, MaskNode, V1);
14131   else
14132     Result = DAG.getNode(X86ISD::VPERMV3, DL, ShuffleVT, V1, MaskNode, V2);
14133 
14134   if (VT != ShuffleVT)
14135     Result = extractSubVector(Result, 0, DAG, DL, VT.getSizeInBits());
14136 
14137   return Result;
14138 }
14139 
14140 /// Generic lowering of v16i8 shuffles.
14141 ///
14142 /// This is a hybrid strategy to lower v16i8 vectors. It first attempts to
14143 /// detect any complexity reducing interleaving. If that doesn't help, it uses
14144 /// UNPCK to spread the i8 elements across two i16-element vectors, and uses
14145 /// the existing lowering for v8i16 blends on each half, finally PACK-ing them
14146 /// back together.
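/// For example, in the final fallback the bytes of V1 are interleaved with a
/// zero vector (UNPCKL/UNPCKH), giving two v8i16 vectors whose words are all in
/// [0,255]; those halves are shuffled with the v8i16 logic and recombined with
/// PACKUSWB, which is lossless for values in that range.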
14147 static SDValue lowerV16I8Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
14148                                  const APInt &Zeroable, SDValue V1, SDValue V2,
14149                                  const X86Subtarget &Subtarget,
14150                                  SelectionDAG &DAG) {
14151   assert(V1.getSimpleValueType() == MVT::v16i8 && "Bad operand type!");
14152   assert(V2.getSimpleValueType() == MVT::v16i8 && "Bad operand type!");
14153   assert(Mask.size() == 16 && "Unexpected mask size for v16 shuffle!");
14154 
14155   // Try to use shift instructions.
14156   if (SDValue Shift =
14157           lowerShuffleAsShift(DL, MVT::v16i8, V1, V2, Mask, Zeroable, Subtarget,
14158                               DAG, /*BitwiseOnly*/ false))
14159     return Shift;
14160 
14161   // Try to use byte rotation instructions.
14162   if (SDValue Rotate = lowerShuffleAsByteRotate(DL, MVT::v16i8, V1, V2, Mask,
14163                                                 Subtarget, DAG))
14164     return Rotate;
14165 
14166   // Use dedicated pack instructions for masks that match their pattern.
14167   if (SDValue V = lowerShuffleWithPACK(DL, MVT::v16i8, Mask, V1, V2, DAG,
14168                                        Subtarget))
14169     return V;
14170 
14171   // Try to use a zext lowering.
14172   if (SDValue ZExt = lowerShuffleAsZeroOrAnyExtend(DL, MVT::v16i8, V1, V2, Mask,
14173                                                    Zeroable, Subtarget, DAG))
14174     return ZExt;
14175 
14176   // Try to lower using a truncation.
14177   if (SDValue V = lowerShuffleWithVPMOV(DL, MVT::v16i8, V1, V2, Mask, Zeroable,
14178                                         Subtarget, DAG))
14179     return V;
14180 
14181   if (SDValue V = lowerShuffleAsVTRUNC(DL, MVT::v16i8, V1, V2, Mask, Zeroable,
14182                                        Subtarget, DAG))
14183     return V;
14184 
14185   // See if we can use SSE4A Extraction / Insertion.
14186   if (Subtarget.hasSSE4A())
14187     if (SDValue V = lowerShuffleWithSSE4A(DL, MVT::v16i8, V1, V2, Mask,
14188                                           Zeroable, DAG))
14189       return V;
14190 
14191   int NumV2Elements = count_if(Mask, [](int M) { return M >= 16; });
14192 
14193   // For single-input shuffles, there are some nicer lowering tricks we can use.
14194   if (NumV2Elements == 0) {
14195     // Check for being able to broadcast a single element.
14196     if (SDValue Broadcast = lowerShuffleAsBroadcast(DL, MVT::v16i8, V1, V2,
14197                                                     Mask, Subtarget, DAG))
14198       return Broadcast;
14199 
14200     // Try to use bit rotation instructions.
14201     if (SDValue Rotate = lowerShuffleAsBitRotate(DL, MVT::v16i8, V1, Mask,
14202                                                  Subtarget, DAG))
14203       return Rotate;
14204 
14205     if (SDValue V = lowerShuffleWithUNPCK(DL, MVT::v16i8, Mask, V1, V2, DAG))
14206       return V;
14207 
14208     // Check whether we can widen this to an i16 shuffle by duplicating bytes.
14209     // Notably, this handles splat and partial-splat shuffles more efficiently.
14210     // However, it only makes sense if the pre-duplication shuffle simplifies
14211     // things significantly. Currently, this means we need to be able to
14212     // express the pre-duplication shuffle as an i16 shuffle.
14213     //
14214     // FIXME: We should check for other patterns which can be widened into an
14215     // i16 shuffle as well.
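    // For example, the byte-splat mask <3,3,3,...> only needs word 1 of the
    // input: the pre-duplication i16 shuffle leaves that word in the low half,
    // the unpack then duplicates byte 3 into both bytes of a word, and the
    // post-duplication i16 shuffle splats that word across the vector.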
14216     auto canWidenViaDuplication = [](ArrayRef<int> Mask) {
14217       for (int i = 0; i < 16; i += 2)
14218         if (Mask[i] >= 0 && Mask[i + 1] >= 0 && Mask[i] != Mask[i + 1])
14219           return false;
14220 
14221       return true;
14222     };
14223     auto tryToWidenViaDuplication = [&]() -> SDValue {
14224       if (!canWidenViaDuplication(Mask))
14225         return SDValue();
14226       SmallVector<int, 4> LoInputs;
14227       copy_if(Mask, std::back_inserter(LoInputs),
14228               [](int M) { return M >= 0 && M < 8; });
14229       array_pod_sort(LoInputs.begin(), LoInputs.end());
14230       LoInputs.erase(llvm::unique(LoInputs), LoInputs.end());
14231       SmallVector<int, 4> HiInputs;
14232       copy_if(Mask, std::back_inserter(HiInputs), [](int M) { return M >= 8; });
14233       array_pod_sort(HiInputs.begin(), HiInputs.end());
14234       HiInputs.erase(llvm::unique(HiInputs), HiInputs.end());
14235 
14236       bool TargetLo = LoInputs.size() >= HiInputs.size();
14237       ArrayRef<int> InPlaceInputs = TargetLo ? LoInputs : HiInputs;
14238       ArrayRef<int> MovingInputs = TargetLo ? HiInputs : LoInputs;
14239 
14240       int PreDupI16Shuffle[] = {-1, -1, -1, -1, -1, -1, -1, -1};
14241       SmallDenseMap<int, int, 8> LaneMap;
14242       for (int I : InPlaceInputs) {
14243         PreDupI16Shuffle[I/2] = I/2;
14244         LaneMap[I] = I;
14245       }
14246       int j = TargetLo ? 0 : 4, je = j + 4;
14247       for (int i = 0, ie = MovingInputs.size(); i < ie; ++i) {
14248         // Check if j is already a shuffle of this input. This happens when
14249         // there are two adjacent bytes after we move the low one.
14250         if (PreDupI16Shuffle[j] != MovingInputs[i] / 2) {
14251           // If we haven't yet mapped the input, search for a slot into which
14252           // we can map it.
14253           while (j < je && PreDupI16Shuffle[j] >= 0)
14254             ++j;
14255 
14256           if (j == je)
14257             // We can't place the inputs into a single half with a simple i16 shuffle, so bail.
14258             return SDValue();
14259 
14260           // Map this input with the i16 shuffle.
14261           PreDupI16Shuffle[j] = MovingInputs[i] / 2;
14262         }
14263 
14264         // Update the lane map based on the mapping we ended up with.
14265         LaneMap[MovingInputs[i]] = 2 * j + MovingInputs[i] % 2;
14266       }
14267       V1 = DAG.getBitcast(
14268           MVT::v16i8,
14269           DAG.getVectorShuffle(MVT::v8i16, DL, DAG.getBitcast(MVT::v8i16, V1),
14270                                DAG.getUNDEF(MVT::v8i16), PreDupI16Shuffle));
14271 
14272       // Unpack the bytes to form the i16s that will be shuffled into place.
14273       bool EvenInUse = false, OddInUse = false;
14274       for (int i = 0; i < 16; i += 2) {
14275         EvenInUse |= (Mask[i + 0] >= 0);
14276         OddInUse |= (Mask[i + 1] >= 0);
14277         if (EvenInUse && OddInUse)
14278           break;
14279       }
14280       V1 = DAG.getNode(TargetLo ? X86ISD::UNPCKL : X86ISD::UNPCKH, DL,
14281                        MVT::v16i8, EvenInUse ? V1 : DAG.getUNDEF(MVT::v16i8),
14282                        OddInUse ? V1 : DAG.getUNDEF(MVT::v16i8));
14283 
14284       int PostDupI16Shuffle[8] = {-1, -1, -1, -1, -1, -1, -1, -1};
14285       for (int i = 0; i < 16; ++i)
14286         if (Mask[i] >= 0) {
14287           int MappedMask = LaneMap[Mask[i]] - (TargetLo ? 0 : 8);
14288           assert(MappedMask < 8 && "Invalid v8 shuffle mask!");
14289           if (PostDupI16Shuffle[i / 2] < 0)
14290             PostDupI16Shuffle[i / 2] = MappedMask;
14291           else
14292             assert(PostDupI16Shuffle[i / 2] == MappedMask &&
14293                    "Conflicting entries in the original shuffle!");
14294         }
14295       return DAG.getBitcast(
14296           MVT::v16i8,
14297           DAG.getVectorShuffle(MVT::v8i16, DL, DAG.getBitcast(MVT::v8i16, V1),
14298                                DAG.getUNDEF(MVT::v8i16), PostDupI16Shuffle));
14299     };
14300     if (SDValue V = tryToWidenViaDuplication())
14301       return V;
14302   }
14303 
14304   if (SDValue Masked = lowerShuffleAsBitMask(DL, MVT::v16i8, V1, V2, Mask,
14305                                              Zeroable, Subtarget, DAG))
14306     return Masked;
14307 
14308   // Use dedicated unpack instructions for masks that match their pattern.
14309   if (SDValue V = lowerShuffleWithUNPCK(DL, MVT::v16i8, Mask, V1, V2, DAG))
14310     return V;
14311 
14312   // Try to use byte shift instructions to mask.
14313   if (SDValue V = lowerShuffleAsByteShiftMask(DL, MVT::v16i8, V1, V2, Mask,
14314                                               Zeroable, Subtarget, DAG))
14315     return V;
14316 
14317   // Check for compaction patterns.
14318   bool IsSingleInput = V2.isUndef();
14319   int NumEvenDrops = canLowerByDroppingElements(Mask, true, IsSingleInput);
14320 
14321   // Check for SSSE3 which lets us lower all v16i8 shuffles much more directly
14322   // with PSHUFB. It is important to do this before we attempt to generate any
14323   // blends but after all of the single-input lowerings. If the single input
14324   // lowerings can find an instruction sequence that is faster than a PSHUFB, we
14325   // want to preserve that and we can DAG combine any longer sequences into
14326   // a PSHUFB in the end. But once we start blending from multiple inputs,
14327   // the complexity of DAG combining bad patterns back into PSHUFB is too high,
14328   // and there are *very* few patterns that would actually be faster than the
14329   // PSHUFB approach because of its ability to zero lanes.
14330   //
14331   // If the mask is a binary compaction, we can more efficiently perform this
14332   // as a PACKUS(AND(),AND()) - which is quicker than UNPACK(PSHUFB(),PSHUFB()).
14333   //
14334   // FIXME: The only exceptions to the above are blends which are exact
14335   // interleavings with direct instructions supporting them. We currently don't
14336   // handle those well here.
14337   if (Subtarget.hasSSSE3() && (IsSingleInput || NumEvenDrops != 1)) {
14338     bool V1InUse = false;
14339     bool V2InUse = false;
14340 
14341     SDValue PSHUFB = lowerShuffleAsBlendOfPSHUFBs(
14342         DL, MVT::v16i8, V1, V2, Mask, Zeroable, DAG, V1InUse, V2InUse);
14343 
14344     // If both V1 and V2 are in use and we can use a direct blend or an unpack,
14345     // do so. This avoids using them to handle blends-with-zero which is
14346     // important as a single pshufb is significantly faster for that.
14347     if (V1InUse && V2InUse) {
14348       if (Subtarget.hasSSE41())
14349         if (SDValue Blend = lowerShuffleAsBlend(DL, MVT::v16i8, V1, V2, Mask,
14350                                                 Zeroable, Subtarget, DAG))
14351           return Blend;
14352 
14353       // We can use an unpack to do the blending rather than an or in some
14354       // cases. Even though the or may be (very minorly) more efficient, we
14355       // prefer this lowering because there are common cases where part of
14356       // the complexity of the shuffles goes away when we do the final blend as
14357       // an unpack.
14358       // FIXME: It might be worth trying to detect if the unpack-feeding
14359       // shuffles will both be pshufb, in which case we shouldn't bother with
14360       // this.
14361       if (SDValue Unpack = lowerShuffleAsPermuteAndUnpack(
14362               DL, MVT::v16i8, V1, V2, Mask, Subtarget, DAG))
14363         return Unpack;
14364 
14365       // AVX512VBMI can lower to VPERMB (non-VLX will pad to v64i8).
14366       if (Subtarget.hasVBMI())
14367         return lowerShuffleWithPERMV(DL, MVT::v16i8, Mask, V1, V2, Subtarget,
14368                                      DAG);
14369 
14370       // If we have XOP we can use one VPPERM instead of multiple PSHUFBs.
14371       if (Subtarget.hasXOP()) {
14372         SDValue MaskNode = getConstVector(Mask, MVT::v16i8, DAG, DL, true);
14373         return DAG.getNode(X86ISD::VPPERM, DL, MVT::v16i8, V1, V2, MaskNode);
14374       }
14375 
14376       // Use PALIGNR+Permute if possible - permute might become PSHUFB but the
14377       // PALIGNR will be cheaper than the second PSHUFB+OR.
14378       if (SDValue V = lowerShuffleAsByteRotateAndPermute(
14379               DL, MVT::v16i8, V1, V2, Mask, Subtarget, DAG))
14380         return V;
14381     }
14382 
14383     return PSHUFB;
14384   }
14385 
14386   // There are special ways we can lower some single-element blends.
14387   if (NumV2Elements == 1)
14388     if (SDValue V = lowerShuffleAsElementInsertion(
14389             DL, MVT::v16i8, V1, V2, Mask, Zeroable, Subtarget, DAG))
14390       return V;
14391 
14392   if (SDValue Blend = lowerShuffleAsBitBlend(DL, MVT::v16i8, V1, V2, Mask, DAG))
14393     return Blend;
14394 
14395   // Check whether a compaction lowering can be done. This handles shuffles
14396   // which take every Nth element for some even N. See the helper function for
14397   // details.
14398   //
14399   // We special case these as they can be particularly efficiently handled with
14400   // the PACKUSWB instruction on x86 and they show up in common patterns of
14401   // rearranging bytes to truncate wide elements.
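  // For example, a mask that keeps every fourth byte (NumEvenDrops == 2) first
  // clears all but the low byte of every other word and then needs two PACKUS
  // rounds to compact the surviving bytes into the low end of the result.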
14402   if (NumEvenDrops) {
14403     // NumEvenDrops is the power of two stride of the elements. Another way of
14404     // thinking about it is that we need to drop the even elements this many
14405     // times to get the original input.
14406 
14407     // First we need to zero all the dropped bytes.
14408     assert(NumEvenDrops <= 3 &&
14409            "No support for dropping even elements more than 3 times.");
14410     SmallVector<SDValue, 8> WordClearOps(8, DAG.getConstant(0, DL, MVT::i16));
14411     for (unsigned i = 0; i != 8; i += 1 << (NumEvenDrops - 1))
14412       WordClearOps[i] = DAG.getConstant(0xFF, DL, MVT::i16);
14413     SDValue WordClearMask = DAG.getBuildVector(MVT::v8i16, DL, WordClearOps);
14414     V1 = DAG.getNode(ISD::AND, DL, MVT::v8i16, DAG.getBitcast(MVT::v8i16, V1),
14415                      WordClearMask);
14416     if (!IsSingleInput)
14417       V2 = DAG.getNode(ISD::AND, DL, MVT::v8i16, DAG.getBitcast(MVT::v8i16, V2),
14418                        WordClearMask);
14419 
14420     // Now pack things back together.
14421     SDValue Result = DAG.getNode(X86ISD::PACKUS, DL, MVT::v16i8, V1,
14422                                  IsSingleInput ? V1 : V2);
14423     for (int i = 1; i < NumEvenDrops; ++i) {
14424       Result = DAG.getBitcast(MVT::v8i16, Result);
14425       Result = DAG.getNode(X86ISD::PACKUS, DL, MVT::v16i8, Result, Result);
14426     }
14427     return Result;
14428   }
14429 
14430   int NumOddDrops = canLowerByDroppingElements(Mask, false, IsSingleInput);
14431   if (NumOddDrops == 1) {
14432     V1 = DAG.getNode(X86ISD::VSRLI, DL, MVT::v8i16,
14433                      DAG.getBitcast(MVT::v8i16, V1),
14434                      DAG.getTargetConstant(8, DL, MVT::i8));
14435     if (!IsSingleInput)
14436       V2 = DAG.getNode(X86ISD::VSRLI, DL, MVT::v8i16,
14437                        DAG.getBitcast(MVT::v8i16, V2),
14438                        DAG.getTargetConstant(8, DL, MVT::i8));
14439     return DAG.getNode(X86ISD::PACKUS, DL, MVT::v16i8, V1,
14440                        IsSingleInput ? V1 : V2);
14441   }
14442 
14443   // Handle multi-input cases by blending/unpacking single-input shuffles.
14444   if (NumV2Elements > 0)
14445     return lowerShuffleAsDecomposedShuffleMerge(DL, MVT::v16i8, V1, V2, Mask,
14446                                                 Subtarget, DAG);
14447 
14448   // The fallback path for single-input shuffles widens this into two v8i16
14449   // vectors with unpacks, shuffles those, and then pulls them back together
14450   // with a pack.
14451   SDValue V = V1;
14452 
14453   std::array<int, 8> LoBlendMask = {{-1, -1, -1, -1, -1, -1, -1, -1}};
14454   std::array<int, 8> HiBlendMask = {{-1, -1, -1, -1, -1, -1, -1, -1}};
14455   for (int i = 0; i < 16; ++i)
14456     if (Mask[i] >= 0)
14457       (i < 8 ? LoBlendMask[i] : HiBlendMask[i % 8]) = Mask[i];
14458 
14459   SDValue VLoHalf, VHiHalf;
14460   // Check if any of the odd lanes in the v16i8 are used. If not, we can mask
14461   // them out and avoid using UNPCK{L,H} to extract the elements of V as
14462   // i16s.
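  // For example, if only even-numbered bytes are requested, each wanted byte
  // already sits in the low half of its word, so an AND with 0x00FF replaces
  // the unpack-with-zero sequence and the mask indices are simply halved.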
14463   if (none_of(LoBlendMask, [](int M) { return M >= 0 && M % 2 == 1; }) &&
14464       none_of(HiBlendMask, [](int M) { return M >= 0 && M % 2 == 1; })) {
14465     // Use a mask to drop the high bytes.
14466     VLoHalf = DAG.getBitcast(MVT::v8i16, V);
14467     VLoHalf = DAG.getNode(ISD::AND, DL, MVT::v8i16, VLoHalf,
14468                           DAG.getConstant(0x00FF, DL, MVT::v8i16));
14469 
14470     // This will be a single vector shuffle instead of a blend so nuke VHiHalf.
14471     VHiHalf = DAG.getUNDEF(MVT::v8i16);
14472 
14473     // Squash the masks to point directly into VLoHalf.
14474     for (int &M : LoBlendMask)
14475       if (M >= 0)
14476         M /= 2;
14477     for (int &M : HiBlendMask)
14478       if (M >= 0)
14479         M /= 2;
14480   } else {
14481     // Otherwise just unpack the low half of V into VLoHalf and the high half into
14482     // VHiHalf so that we can blend them as i16s.
14483     SDValue Zero = getZeroVector(MVT::v16i8, Subtarget, DAG, DL);
14484 
14485     VLoHalf = DAG.getBitcast(
14486         MVT::v8i16, DAG.getNode(X86ISD::UNPCKL, DL, MVT::v16i8, V, Zero));
14487     VHiHalf = DAG.getBitcast(
14488         MVT::v8i16, DAG.getNode(X86ISD::UNPCKH, DL, MVT::v16i8, V, Zero));
14489   }
14490 
14491   SDValue LoV = DAG.getVectorShuffle(MVT::v8i16, DL, VLoHalf, VHiHalf, LoBlendMask);
14492   SDValue HiV = DAG.getVectorShuffle(MVT::v8i16, DL, VLoHalf, VHiHalf, HiBlendMask);
14493 
14494   return DAG.getNode(X86ISD::PACKUS, DL, MVT::v16i8, LoV, HiV);
14495 }
14496 
14497 /// Dispatching routine to lower various 128-bit x86 vector shuffles.
14498 ///
14499 /// This routine breaks down the specific type of 128-bit shuffle and
14500 /// dispatches to the lowering routines accordingly.
14501 static SDValue lower128BitShuffle(const SDLoc &DL, ArrayRef<int> Mask,
14502                                   MVT VT, SDValue V1, SDValue V2,
14503                                   const APInt &Zeroable,
14504                                   const X86Subtarget &Subtarget,
14505                                   SelectionDAG &DAG) {
14506   if (VT == MVT::v8bf16) {
14507     V1 = DAG.getBitcast(MVT::v8i16, V1);
14508     V2 = DAG.getBitcast(MVT::v8i16, V2);
14509     return DAG.getBitcast(VT,
14510                           DAG.getVectorShuffle(MVT::v8i16, DL, V1, V2, Mask));
14511   }
14512 
14513   switch (VT.SimpleTy) {
14514   case MVT::v2i64:
14515     return lowerV2I64Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
14516   case MVT::v2f64:
14517     return lowerV2F64Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
14518   case MVT::v4i32:
14519     return lowerV4I32Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
14520   case MVT::v4f32:
14521     return lowerV4F32Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
14522   case MVT::v8i16:
14523     return lowerV8I16Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
14524   case MVT::v8f16:
14525     return lowerV8F16Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
14526   case MVT::v16i8:
14527     return lowerV16I8Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
14528 
14529   default:
14530     llvm_unreachable("Unimplemented!");
14531   }
14532 }
14533 
14534 /// Generic routine to split vector shuffle into half-sized shuffles.
14535 ///
14536 /// This routine just extracts two subvectors, shuffles them independently, and
14537 /// then concatenates them back together. This should work effectively with all
14538 /// AVX vector shuffle types.
14539 static SDValue splitAndLowerShuffle(const SDLoc &DL, MVT VT, SDValue V1,
14540                                     SDValue V2, ArrayRef<int> Mask,
14541                                     SelectionDAG &DAG, bool SimpleOnly) {
14542   assert(VT.getSizeInBits() >= 256 &&
14543          "Only for 256-bit or wider vector shuffles!");
14544   assert(V1.getSimpleValueType() == VT && "Bad operand type!");
14545   assert(V2.getSimpleValueType() == VT && "Bad operand type!");
14546 
14547   ArrayRef<int> LoMask = Mask.slice(0, Mask.size() / 2);
14548   ArrayRef<int> HiMask = Mask.slice(Mask.size() / 2);
14549 
14550   int NumElements = VT.getVectorNumElements();
14551   int SplitNumElements = NumElements / 2;
14552   MVT ScalarVT = VT.getVectorElementType();
14553   MVT SplitVT = MVT::getVectorVT(ScalarVT, SplitNumElements);
14554 
14555   // Use splitVector/extractSubVector so that split build-vectors just build two
14556   // narrower build vectors. This helps shuffling with splats and zeros.
14557   auto SplitVector = [&](SDValue V) {
14558     SDValue LoV, HiV;
14559     std::tie(LoV, HiV) = splitVector(peekThroughBitcasts(V), DAG, DL);
14560     return std::make_pair(DAG.getBitcast(SplitVT, LoV),
14561                           DAG.getBitcast(SplitVT, HiV));
14562   };
14563 
14564   SDValue LoV1, HiV1, LoV2, HiV2;
14565   std::tie(LoV1, HiV1) = SplitVector(V1);
14566   std::tie(LoV2, HiV2) = SplitVector(V2);
14567 
14568   // Now create two 4-way blends of these half-width vectors.
14569   auto GetHalfBlendPiecesReq = [&](const ArrayRef<int> &HalfMask, bool &UseLoV1,
14570                                    bool &UseHiV1, bool &UseLoV2,
14571                                    bool &UseHiV2) {
14572     UseLoV1 = UseHiV1 = UseLoV2 = UseHiV2 = false;
14573     for (int i = 0; i < SplitNumElements; ++i) {
14574       int M = HalfMask[i];
14575       if (M >= NumElements) {
14576         if (M >= NumElements + SplitNumElements)
14577           UseHiV2 = true;
14578         else
14579           UseLoV2 = true;
14580       } else if (M >= 0) {
14581         if (M >= SplitNumElements)
14582           UseHiV1 = true;
14583         else
14584           UseLoV1 = true;
14585       }
14586     }
14587   };
14588 
14589   auto CheckHalfBlendUsable = [&](const ArrayRef<int> &HalfMask) -> bool {
14590     if (!SimpleOnly)
14591       return true;
14592 
14593     bool UseLoV1, UseHiV1, UseLoV2, UseHiV2;
14594     GetHalfBlendPiecesReq(HalfMask, UseLoV1, UseHiV1, UseLoV2, UseHiV2);
14595 
14596     return !(UseHiV1 || UseHiV2);
14597   };
14598 
14599   auto HalfBlend = [&](ArrayRef<int> HalfMask) {
14600     SmallVector<int, 32> V1BlendMask((unsigned)SplitNumElements, -1);
14601     SmallVector<int, 32> V2BlendMask((unsigned)SplitNumElements, -1);
14602     SmallVector<int, 32> BlendMask((unsigned)SplitNumElements, -1);
14603     for (int i = 0; i < SplitNumElements; ++i) {
14604       int M = HalfMask[i];
14605       if (M >= NumElements) {
14606         V2BlendMask[i] = M - NumElements;
14607         BlendMask[i] = SplitNumElements + i;
14608       } else if (M >= 0) {
14609         V1BlendMask[i] = M;
14610         BlendMask[i] = i;
14611       }
14612     }
14613 
14614     bool UseLoV1, UseHiV1, UseLoV2, UseHiV2;
14615     GetHalfBlendPiecesReq(HalfMask, UseLoV1, UseHiV1, UseLoV2, UseHiV2);
14616 
14617     // Because the lowering happens after all combining takes place, we need to
14618     // manually combine these blend masks as much as possible so that we create
14619     // a minimal number of high-level vector shuffle nodes.
14620     assert((!SimpleOnly || (!UseHiV1 && !UseHiV2)) && "Shuffle isn't simple");
14621 
14622     // First try just blending the halves of V1 or V2.
14623     if (!UseLoV1 && !UseHiV1 && !UseLoV2 && !UseHiV2)
14624       return DAG.getUNDEF(SplitVT);
14625     if (!UseLoV2 && !UseHiV2)
14626       return DAG.getVectorShuffle(SplitVT, DL, LoV1, HiV1, V1BlendMask);
14627     if (!UseLoV1 && !UseHiV1)
14628       return DAG.getVectorShuffle(SplitVT, DL, LoV2, HiV2, V2BlendMask);
14629 
14630     SDValue V1Blend, V2Blend;
14631     if (UseLoV1 && UseHiV1) {
14632       V1Blend = DAG.getVectorShuffle(SplitVT, DL, LoV1, HiV1, V1BlendMask);
14633     } else {
14634       // We only use half of V1 so map the usage down into the final blend mask.
14635       V1Blend = UseLoV1 ? LoV1 : HiV1;
14636       for (int i = 0; i < SplitNumElements; ++i)
14637         if (BlendMask[i] >= 0 && BlendMask[i] < SplitNumElements)
14638           BlendMask[i] = V1BlendMask[i] - (UseLoV1 ? 0 : SplitNumElements);
14639     }
14640     if (UseLoV2 && UseHiV2) {
14641       V2Blend = DAG.getVectorShuffle(SplitVT, DL, LoV2, HiV2, V2BlendMask);
14642     } else {
14643       // We only use half of V2 so map the usage down into the final blend mask.
14644       V2Blend = UseLoV2 ? LoV2 : HiV2;
14645       for (int i = 0; i < SplitNumElements; ++i)
14646         if (BlendMask[i] >= SplitNumElements)
14647           BlendMask[i] = V2BlendMask[i] + (UseLoV2 ? SplitNumElements : 0);
14648     }
14649     return DAG.getVectorShuffle(SplitVT, DL, V1Blend, V2Blend, BlendMask);
14650   };
14651 
14652   if (!CheckHalfBlendUsable(LoMask) || !CheckHalfBlendUsable(HiMask))
14653     return SDValue();
14654 
14655   SDValue Lo = HalfBlend(LoMask);
14656   SDValue Hi = HalfBlend(HiMask);
14657   return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Lo, Hi);
14658 }
14659 
14660 /// Either split a vector in halves or decompose the shuffles and the
14661 /// blend/unpack.
14662 ///
14663 /// This is provided as a good fallback for many lowerings of non-single-input
14664 /// shuffles with more than one 128-bit lane. In those cases, we want to select
14665 /// between splitting the shuffle into 128-bit components and stitching those
14666 /// back together vs. extracting the single-input shuffles and blending those
14667 /// results.
14668 static SDValue lowerShuffleAsSplitOrBlend(const SDLoc &DL, MVT VT, SDValue V1,
14669                                           SDValue V2, ArrayRef<int> Mask,
14670                                           const X86Subtarget &Subtarget,
14671                                           SelectionDAG &DAG) {
14672   assert(!V2.isUndef() && "This routine must not be used to lower single-input "
14673          "shuffles as it could then recurse on itself.");
14674   int Size = Mask.size();
14675 
14676   // If this can be modeled as a broadcast of two elements followed by a blend,
14677   // prefer that lowering. This is especially important because broadcasts can
14678   // often fold with memory operands.
14679   auto DoBothBroadcast = [&] {
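  // For example, a v8f32 mask such as <0,8,0,8,0,8,0,8> broadcasts element 0
  // of each input and then blends the two broadcasts together.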
14680     int V1BroadcastIdx = -1, V2BroadcastIdx = -1;
14681     for (int M : Mask)
14682       if (M >= Size) {
14683         if (V2BroadcastIdx < 0)
14684           V2BroadcastIdx = M - Size;
14685         else if (M - Size != V2BroadcastIdx)
14686           return false;
14687       } else if (M >= 0) {
14688         if (V1BroadcastIdx < 0)
14689           V1BroadcastIdx = M;
14690         else if (M != V1BroadcastIdx)
14691           return false;
14692       }
14693     return true;
14694   };
14695   if (DoBothBroadcast())
14696     return lowerShuffleAsDecomposedShuffleMerge(DL, VT, V1, V2, Mask, Subtarget,
14697                                                 DAG);
14698 
14699   // If the inputs all stem from a single 128-bit lane of each input, then we
14700   // split them rather than blending because the split will decompose to
14701   // unusually few instructions.
14702   int LaneCount = VT.getSizeInBits() / 128;
14703   int LaneSize = Size / LaneCount;
14704   SmallBitVector LaneInputs[2];
14705   LaneInputs[0].resize(LaneCount, false);
14706   LaneInputs[1].resize(LaneCount, false);
14707   for (int i = 0; i < Size; ++i)
14708     if (Mask[i] >= 0)
14709       LaneInputs[Mask[i] / Size][(Mask[i] % Size) / LaneSize] = true;
14710   if (LaneInputs[0].count() <= 1 && LaneInputs[1].count() <= 1)
14711     return splitAndLowerShuffle(DL, VT, V1, V2, Mask, DAG,
14712                                 /*SimpleOnly*/ false);
14713 
14714   // Otherwise, just fall back to decomposed shuffles and a blend/unpack. This
14715   // requires that the decomposed single-input shuffles don't end up here.
14716   return lowerShuffleAsDecomposedShuffleMerge(DL, VT, V1, V2, Mask, Subtarget,
14717                                               DAG);
14718 }
14719 
14720 // Lower as SHUFPD(VPERM2F128(V1, V2), VPERM2F128(V1, V2)).
14721 // TODO: Extend to support v8f32 (+ 512-bit shuffles).
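// For example, for the v4f64 mask <2,5,0,7> the LHS permute provides elements
// <2,_,0,_>, the RHS permute provides <_,5,_,7>, and SHUFPD with immediate
// 0b1010 interleaves them into the requested order.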
14722 static SDValue lowerShuffleAsLanePermuteAndSHUFP(const SDLoc &DL, MVT VT,
14723                                                  SDValue V1, SDValue V2,
14724                                                  ArrayRef<int> Mask,
14725                                                  SelectionDAG &DAG) {
14726   assert(VT == MVT::v4f64 && "Only for v4f64 shuffles");
14727 
14728   int LHSMask[4] = {-1, -1, -1, -1};
14729   int RHSMask[4] = {-1, -1, -1, -1};
14730   unsigned SHUFPMask = 0;
14731 
14732   // As SHUFPD uses a single LHS/RHS element per lane, we can always
14733   // perform the shuffle once the lanes have been shuffled in place.
14734   for (int i = 0; i != 4; ++i) {
14735     int M = Mask[i];
14736     if (M < 0)
14737       continue;
14738     int LaneBase = i & ~1;
14739     auto &LaneMask = (i & 1) ? RHSMask : LHSMask;
14740     LaneMask[LaneBase + (M & 1)] = M;
14741     SHUFPMask |= (M & 1) << i;
14742   }
14743 
14744   SDValue LHS = DAG.getVectorShuffle(VT, DL, V1, V2, LHSMask);
14745   SDValue RHS = DAG.getVectorShuffle(VT, DL, V1, V2, RHSMask);
14746   return DAG.getNode(X86ISD::SHUFP, DL, VT, LHS, RHS,
14747                      DAG.getTargetConstant(SHUFPMask, DL, MVT::i8));
14748 }
14749 
14750 /// Lower a vector shuffle crossing multiple 128-bit lanes as
14751 /// a lane permutation followed by a per-lane permutation.
14752 ///
14753 /// This is mainly for cases where we can have non-repeating permutes
14754 /// in each lane.
14755 ///
14756 /// TODO: This is very similar to lowerShuffleAsLanePermuteAndRepeatedMask;
14757 /// we should investigate merging them.
14758 static SDValue lowerShuffleAsLanePermuteAndPermute(
14759     const SDLoc &DL, MVT VT, SDValue V1, SDValue V2, ArrayRef<int> Mask,
14760     SelectionDAG &DAG, const X86Subtarget &Subtarget) {
14761   int NumElts = VT.getVectorNumElements();
14762   int NumLanes = VT.getSizeInBits() / 128;
14763   int NumEltsPerLane = NumElts / NumLanes;
14764   bool CanUseSublanes = Subtarget.hasAVX2() && V2.isUndef();
14765 
14766   /// Attempts to find a sublane permute with the given size
14767   /// that gets all elements into their target lanes.
14768   ///
14769   /// If successful, fills CrossLaneMask and InLaneMask and returns the new
14770   /// shuffle; if unsuccessful, returns SDValue() and may overwrite InLaneMask.
14771   auto getSublanePermute = [&](int NumSublanes) -> SDValue {
14772     int NumSublanesPerLane = NumSublanes / NumLanes;
14773     int NumEltsPerSublane = NumElts / NumSublanes;
14774 
14775     SmallVector<int, 16> CrossLaneMask;
14776     SmallVector<int, 16> InLaneMask(NumElts, SM_SentinelUndef);
14777     // CrossLaneMask but one entry == one sublane.
14778     SmallVector<int, 16> CrossLaneMaskLarge(NumSublanes, SM_SentinelUndef);
14779 
14780     for (int i = 0; i != NumElts; ++i) {
14781       int M = Mask[i];
14782       if (M < 0)
14783         continue;
14784 
14785       int SrcSublane = M / NumEltsPerSublane;
14786       int DstLane = i / NumEltsPerLane;
14787 
14788       // We only need to get the elements into the right lane, not sublane.
14789       // So search all sublanes that make up the destination lane.
14790       bool Found = false;
14791       int DstSubStart = DstLane * NumSublanesPerLane;
14792       int DstSubEnd = DstSubStart + NumSublanesPerLane;
14793       for (int DstSublane = DstSubStart; DstSublane < DstSubEnd; ++DstSublane) {
14794         if (!isUndefOrEqual(CrossLaneMaskLarge[DstSublane], SrcSublane))
14795           continue;
14796 
14797         Found = true;
14798         CrossLaneMaskLarge[DstSublane] = SrcSublane;
14799         int DstSublaneOffset = DstSublane * NumEltsPerSublane;
14800         InLaneMask[i] = DstSublaneOffset + M % NumEltsPerSublane;
14801         break;
14802       }
14803       if (!Found)
14804         return SDValue();
14805     }
14806 
14807     // Fill CrossLaneMask using CrossLaneMaskLarge.
14808     narrowShuffleMaskElts(NumEltsPerSublane, CrossLaneMaskLarge, CrossLaneMask);
14809 
14810     if (!CanUseSublanes) {
14811       // If we're only shuffling a single lowest lane and the rest are identity
14812       // then don't bother.
14813       // TODO - isShuffleMaskInputInPlace could be extended to something like
14814       // this.
14815       int NumIdentityLanes = 0;
14816       bool OnlyShuffleLowestLane = true;
14817       for (int i = 0; i != NumLanes; ++i) {
14818         int LaneOffset = i * NumEltsPerLane;
14819         if (isSequentialOrUndefInRange(InLaneMask, LaneOffset, NumEltsPerLane,
14820                                        i * NumEltsPerLane))
14821           NumIdentityLanes++;
14822         else if (CrossLaneMask[LaneOffset] != 0)
14823           OnlyShuffleLowestLane = false;
14824       }
14825       if (OnlyShuffleLowestLane && NumIdentityLanes == (NumLanes - 1))
14826         return SDValue();
14827     }
14828 
14829     // Avoid returning the same shuffle operation. For example,
14830     // t7: v16i16 = vector_shuffle<8,9,10,11,4,5,6,7,0,1,2,3,12,13,14,15> t5,
14831     //                             undef:v16i16
14832     if (CrossLaneMask == Mask || InLaneMask == Mask)
14833       return SDValue();
14834 
14835     SDValue CrossLane = DAG.getVectorShuffle(VT, DL, V1, V2, CrossLaneMask);
14836     return DAG.getVectorShuffle(VT, DL, CrossLane, DAG.getUNDEF(VT),
14837                                 InLaneMask);
14838   };
14839 
14840   // First attempt a solution with full lanes.
14841   if (SDValue V = getSublanePermute(/*NumSublanes=*/NumLanes))
14842     return V;
14843 
14844   // The rest of the solutions use sublanes.
14845   if (!CanUseSublanes)
14846     return SDValue();
14847 
14848   // Then attempt a solution with 64-bit sublanes (vpermq).
14849   if (SDValue V = getSublanePermute(/*NumSublanes=*/NumLanes * 2))
14850     return V;
14851 
14852   // If that doesn't work and we have fast variable cross-lane shuffle,
14853   // attempt 32-bit sublanes (vpermd).
14854   if (!Subtarget.hasFastVariableCrossLaneShuffle())
14855     return SDValue();
14856 
14857   return getSublanePermute(/*NumSublanes=*/NumLanes * 4);
14858 }
14859 
14860 /// Helper to compute the in-lane shuffle mask for a complete shuffle mask.
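/// e.g. with Mask.size() == 8 and LaneSize == 4, the cross-lane mask
/// <6,7,2,3,0,1,4,5> becomes <10,11,2,3,12,13,4,5>: any element that crosses a
/// lane is redirected to the same in-lane slot of a second, lane-swapped copy
/// of the input (offset by the mask size).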
14861 static void computeInLaneShuffleMask(const ArrayRef<int> &Mask, int LaneSize,
14862                                      SmallVector<int> &InLaneMask) {
14863   int Size = Mask.size();
14864   InLaneMask.assign(Mask.begin(), Mask.end());
14865   for (int i = 0; i < Size; ++i) {
14866     int &M = InLaneMask[i];
14867     if (M < 0)
14868       continue;
14869     if (((M % Size) / LaneSize) != (i / LaneSize))
14870       M = (M % LaneSize) + ((i / LaneSize) * LaneSize) + Size;
14871   }
14872 }
14873 
14874 /// Lower a vector shuffle crossing multiple 128-bit lanes by shuffling one
14875 /// source with a lane permutation.
14876 ///
14877 /// This lowering strategy results in four instructions in the worst case for a
14878 /// single-input cross lane shuffle which is lower than any other fully general
14879 /// cross-lane shuffle strategy I'm aware of. Special cases for each particular
14880 /// shuffle pattern should be handled prior to trying this lowering.
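/// In the single-input case handled below the sequence is roughly: swap the
/// two 128-bit lanes of the input (a VPERM2F128/VPERMQ-style permute) and then
/// perform an in-lane shuffle that picks each element from either the original
/// or the lane-swapped vector.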
14881 static SDValue lowerShuffleAsLanePermuteAndShuffle(
14882     const SDLoc &DL, MVT VT, SDValue V1, SDValue V2, ArrayRef<int> Mask,
14883     SelectionDAG &DAG, const X86Subtarget &Subtarget) {
14884   // FIXME: This should probably be generalized for 512-bit vectors as well.
14885   assert(VT.is256BitVector() && "Only for 256-bit vector shuffles!");
14886   int Size = Mask.size();
14887   int LaneSize = Size / 2;
14888 
14889   // Fold to SHUFPD(VPERM2F128(V1, V2), VPERM2F128(V1, V2)).
14890   // Only do this if the elements aren't all from the lower lane,
14891   // otherwise we're (probably) better off doing a split.
14892   if (VT == MVT::v4f64 &&
14893       !all_of(Mask, [LaneSize](int M) { return M < LaneSize; }))
14894     return lowerShuffleAsLanePermuteAndSHUFP(DL, VT, V1, V2, Mask, DAG);
14895 
14896   // If there are only inputs from one 128-bit lane, splitting will in fact be
14897   // less expensive. The flags track whether the given lane contains an element
14898   // that crosses to another lane.
14899   bool AllLanes;
14900   if (!Subtarget.hasAVX2()) {
14901     bool LaneCrossing[2] = {false, false};
14902     for (int i = 0; i < Size; ++i)
14903       if (Mask[i] >= 0 && ((Mask[i] % Size) / LaneSize) != (i / LaneSize))
14904         LaneCrossing[(Mask[i] % Size) / LaneSize] = true;
14905     AllLanes = LaneCrossing[0] && LaneCrossing[1];
14906   } else {
14907     bool LaneUsed[2] = {false, false};
14908     for (int i = 0; i < Size; ++i)
14909       if (Mask[i] >= 0)
14910         LaneUsed[(Mask[i] % Size) / LaneSize] = true;
14911     AllLanes = LaneUsed[0] && LaneUsed[1];
14912   }
14913 
14914   // TODO - we could support shuffling V2 in the Flipped input.
14915   assert(V2.isUndef() &&
14916          "This last part of this routine only works on single input shuffles");
14917 
14918   SmallVector<int> InLaneMask;
14919   computeInLaneShuffleMask(Mask, Mask.size() / 2, InLaneMask);
14920 
14921   assert(!is128BitLaneCrossingShuffleMask(VT, InLaneMask) &&
14922          "In-lane shuffle mask expected");
14923 
14924   // If we're not using both lanes in each lane and the inlane mask is not
14925   // repeating, then we're better off splitting.
14926   if (!AllLanes && !is128BitLaneRepeatedShuffleMask(VT, InLaneMask))
14927     return splitAndLowerShuffle(DL, VT, V1, V2, Mask, DAG,
14928                                 /*SimpleOnly*/ false);
14929 
14930   // Flip the lanes, and shuffle the results which should now be in-lane.
14931   MVT PVT = VT.isFloatingPoint() ? MVT::v4f64 : MVT::v4i64;
14932   SDValue Flipped = DAG.getBitcast(PVT, V1);
14933   Flipped =
14934       DAG.getVectorShuffle(PVT, DL, Flipped, DAG.getUNDEF(PVT), {2, 3, 0, 1});
14935   Flipped = DAG.getBitcast(VT, Flipped);
14936   return DAG.getVectorShuffle(VT, DL, V1, Flipped, InLaneMask);
14937 }
14938 
14939 /// Handle lowering 2-lane 128-bit shuffles.
14940 static SDValue lowerV2X128Shuffle(const SDLoc &DL, MVT VT, SDValue V1,
14941                                   SDValue V2, ArrayRef<int> Mask,
14942                                   const APInt &Zeroable,
14943                                   const X86Subtarget &Subtarget,
14944                                   SelectionDAG &DAG) {
14945   if (V2.isUndef()) {
14946     // Attempt to match VBROADCAST*128 subvector broadcast load.
14947     bool SplatLo = isShuffleEquivalent(Mask, {0, 1, 0, 1}, V1);
14948     bool SplatHi = isShuffleEquivalent(Mask, {2, 3, 2, 3}, V1);
14949     if ((SplatLo || SplatHi) && !Subtarget.hasAVX512() && V1.hasOneUse() &&
14950         X86::mayFoldLoad(peekThroughOneUseBitcasts(V1), Subtarget)) {
14951       MVT MemVT = VT.getHalfNumVectorElementsVT();
14952       unsigned Ofs = SplatLo ? 0 : MemVT.getStoreSize();
14953       auto *Ld = cast<LoadSDNode>(peekThroughOneUseBitcasts(V1));
14954       if (SDValue BcstLd = getBROADCAST_LOAD(X86ISD::SUBV_BROADCAST_LOAD, DL,
14955                                              VT, MemVT, Ld, Ofs, DAG))
14956         return BcstLd;
14957     }
14958 
14959     // With AVX2, use VPERMQ/VPERMPD for unary shuffles to allow memory folding.
14960     if (Subtarget.hasAVX2())
14961       return SDValue();
14962   }
14963 
14964   bool V2IsZero = !V2.isUndef() && ISD::isBuildVectorAllZeros(V2.getNode());
14965 
14966   SmallVector<int, 4> WidenedMask;
14967   if (!canWidenShuffleElements(Mask, Zeroable, V2IsZero, WidenedMask))
14968     return SDValue();
14969 
14970   bool IsLowZero = (Zeroable & 0x3) == 0x3;
14971   bool IsHighZero = (Zeroable & 0xc) == 0xc;
14972 
14973   // Try to use an insert into a zero vector.
14974   if (WidenedMask[0] == 0 && IsHighZero) {
14975     MVT SubVT = MVT::getVectorVT(VT.getVectorElementType(), 2);
14976     SDValue LoV = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVT, V1,
14977                               DAG.getIntPtrConstant(0, DL));
14978     return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT,
14979                        getZeroVector(VT, Subtarget, DAG, DL), LoV,
14980                        DAG.getIntPtrConstant(0, DL));
14981   }
14982 
14983   // TODO: If minimizing size and one of the inputs is a zero vector and
14984   // the zero vector has only one use, we could use a VPERM2X128 to save the
14985   // instruction bytes needed to explicitly generate the zero vector.
14986 
14987   // Blends are faster and handle all the non-lane-crossing cases.
14988   if (SDValue Blend = lowerShuffleAsBlend(DL, VT, V1, V2, Mask, Zeroable,
14989                                           Subtarget, DAG))
14990     return Blend;
14991 
14992   // If either input operand is a zero vector, use VPERM2X128 because its mask
14993   // allows us to replace the zero input with an implicit zero.
14994   if (!IsLowZero && !IsHighZero) {
14995     // Check for patterns which can be matched with a single insert of a 128-bit
14996     // subvector.
14997     bool OnlyUsesV1 = isShuffleEquivalent(Mask, {0, 1, 0, 1}, V1, V2);
14998     if (OnlyUsesV1 || isShuffleEquivalent(Mask, {0, 1, 4, 5}, V1, V2)) {
14999 
15000       // With AVX1, use vperm2f128 (below) to allow load folding. Otherwise,
15001       // this will likely become vinsertf128 which can't fold a 256-bit memop.
15002       if (!isa<LoadSDNode>(peekThroughBitcasts(V1))) {
15003         MVT SubVT = MVT::getVectorVT(VT.getVectorElementType(), 2);
15004         SDValue SubVec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVT,
15005                                      OnlyUsesV1 ? V1 : V2,
15006                                      DAG.getIntPtrConstant(0, DL));
15007         return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, V1, SubVec,
15008                            DAG.getIntPtrConstant(2, DL));
15009       }
15010     }
15011 
15012     // Try to use SHUF128 if possible.
15013     if (Subtarget.hasVLX()) {
15014       if (WidenedMask[0] < 2 && WidenedMask[1] >= 2) {
15015         unsigned PermMask = ((WidenedMask[0] % 2) << 0) |
15016                             ((WidenedMask[1] % 2) << 1);
15017         return DAG.getNode(X86ISD::SHUF128, DL, VT, V1, V2,
15018                            DAG.getTargetConstant(PermMask, DL, MVT::i8));
15019       }
15020     }
15021   }
15022 
15023   // Otherwise form a 128-bit permutation. After accounting for undefs,
15024   // convert the 64-bit shuffle mask selection values into 128-bit
15025   // selection bits by dividing the indexes by 2 and shifting into positions
15026   // defined by a vperm2*128 instruction's immediate control byte.
15027 
15028   // The immediate permute control byte looks like this:
15029   //    [1:0] - select 128 bits from sources for low half of destination
15030   //    [2]   - ignore
15031   //    [3]   - zero low half of destination
15032   //    [5:4] - select 128 bits from sources for high half of destination
15033   //    [6]   - ignore
15034   //    [7]   - zero high half of destination
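  // For example, taking the upper half of V1 for the low half of the result
  // and the lower half of V2 for the high half is WidenedMask <1,2>, which
  // encodes as an immediate of 0x21.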
15035 
15036   assert((WidenedMask[0] >= 0 || IsLowZero) &&
15037          (WidenedMask[1] >= 0 || IsHighZero) && "Undef half?");
15038 
15039   unsigned PermMask = 0;
15040   PermMask |= IsLowZero  ? 0x08 : (WidenedMask[0] << 0);
15041   PermMask |= IsHighZero ? 0x80 : (WidenedMask[1] << 4);
15042 
15043   // Check the immediate mask and replace unused sources with undef.
15044   if ((PermMask & 0x0a) != 0x00 && (PermMask & 0xa0) != 0x00)
15045     V1 = DAG.getUNDEF(VT);
15046   if ((PermMask & 0x0a) != 0x02 && (PermMask & 0xa0) != 0x20)
15047     V2 = DAG.getUNDEF(VT);
15048 
15049   return DAG.getNode(X86ISD::VPERM2X128, DL, VT, V1, V2,
15050                      DAG.getTargetConstant(PermMask, DL, MVT::i8));
15051 }
15052 
15053 /// Lower a vector shuffle by first fixing the 128-bit lanes and then
15054 /// shuffling each lane.
15055 ///
15056 /// This attempts to create a repeated lane shuffle where each lane uses one
15057 /// or two of the lanes of the inputs. The lanes of the input vectors are
15058 /// shuffled in one or two independent shuffles to get the lanes into the
15059 /// position needed by the final shuffle.
15060 static SDValue lowerShuffleAsLanePermuteAndRepeatedMask(
15061     const SDLoc &DL, MVT VT, SDValue V1, SDValue V2, ArrayRef<int> Mask,
15062     const X86Subtarget &Subtarget, SelectionDAG &DAG) {
15063   assert(!V2.isUndef() && "This is only useful with multiple inputs.");
15064 
15065   if (is128BitLaneRepeatedShuffleMask(VT, Mask))
15066     return SDValue();
15067 
15068   int NumElts = Mask.size();
15069   int NumLanes = VT.getSizeInBits() / 128;
15070   int NumLaneElts = 128 / VT.getScalarSizeInBits();
15071   SmallVector<int, 16> RepeatMask(NumLaneElts, -1);
15072   SmallVector<std::array<int, 2>, 2> LaneSrcs(NumLanes, {{-1, -1}});
15073 
15074   // First pass will try to fill in the RepeatMask from lanes that need two
15075   // sources.
15076   for (int Lane = 0; Lane != NumLanes; ++Lane) {
15077     int Srcs[2] = {-1, -1};
15078     SmallVector<int, 16> InLaneMask(NumLaneElts, -1);
15079     for (int i = 0; i != NumLaneElts; ++i) {
15080       int M = Mask[(Lane * NumLaneElts) + i];
15081       if (M < 0)
15082         continue;
15083       // Determine which of the possible input lanes (NumLanes from each source)
15084       // this element comes from. Assign that as one of the sources for this
15085       // lane. We can assign up to 2 sources for this lane. If we run out
15086       // of sources we can't do anything.
15087       int LaneSrc = M / NumLaneElts;
15088       int Src;
15089       if (Srcs[0] < 0 || Srcs[0] == LaneSrc)
15090         Src = 0;
15091       else if (Srcs[1] < 0 || Srcs[1] == LaneSrc)
15092         Src = 1;
15093       else
15094         return SDValue();
15095 
15096       Srcs[Src] = LaneSrc;
15097       InLaneMask[i] = (M % NumLaneElts) + Src * NumElts;
15098     }
15099 
15100     // If this lane has two sources, see if it fits with the repeat mask so far.
15101     if (Srcs[1] < 0)
15102       continue;
15103 
15104     LaneSrcs[Lane][0] = Srcs[0];
15105     LaneSrcs[Lane][1] = Srcs[1];
15106 
15107     auto MatchMasks = [](ArrayRef<int> M1, ArrayRef<int> M2) {
15108       assert(M1.size() == M2.size() && "Unexpected mask size");
15109       for (int i = 0, e = M1.size(); i != e; ++i)
15110         if (M1[i] >= 0 && M2[i] >= 0 && M1[i] != M2[i])
15111           return false;
15112       return true;
15113     };
15114 
15115     auto MergeMasks = [](ArrayRef<int> Mask, MutableArrayRef<int> MergedMask) {
15116       assert(Mask.size() == MergedMask.size() && "Unexpected mask size");
15117       for (int i = 0, e = MergedMask.size(); i != e; ++i) {
15118         int M = Mask[i];
15119         if (M < 0)
15120           continue;
15121         assert((MergedMask[i] < 0 || MergedMask[i] == M) &&
15122                "Unexpected mask element");
15123         MergedMask[i] = M;
15124       }
15125     };
15126 
15127     if (MatchMasks(InLaneMask, RepeatMask)) {
15128       // Merge this lane mask into the final repeat mask.
15129       MergeMasks(InLaneMask, RepeatMask);
15130       continue;
15131     }
15132 
15133     // Didn't find a match. Swap the operands and try again.
15134     std::swap(LaneSrcs[Lane][0], LaneSrcs[Lane][1]);
15135     ShuffleVectorSDNode::commuteMask(InLaneMask);
15136 
15137     if (MatchMasks(InLaneMask, RepeatMask)) {
15138       // Merge this lane mask into the final repeat mask.
15139       MergeMasks(InLaneMask, RepeatMask);
15140       continue;
15141     }
15142 
15143     // Couldn't find a match with the operands in either order.
15144     return SDValue();
15145   }
15146 
15147   // Now handle any lanes with only one source.
15148   for (int Lane = 0; Lane != NumLanes; ++Lane) {
15149     // If this lane has already been processed, skip it.
15150     if (LaneSrcs[Lane][0] >= 0)
15151       continue;
15152 
15153     for (int i = 0; i != NumLaneElts; ++i) {
15154       int M = Mask[(Lane * NumLaneElts) + i];
15155       if (M < 0)
15156         continue;
15157 
15158       // If RepeatMask isn't defined yet we can define it ourself.
15159       // If RepeatMask isn't defined yet we can define it ourselves.
15160         RepeatMask[i] = M % NumLaneElts;
15161 
15162       if (RepeatMask[i] < NumElts) {
15163         if (RepeatMask[i] != M % NumLaneElts)
15164           return SDValue();
15165         LaneSrcs[Lane][0] = M / NumLaneElts;
15166       } else {
15167         if (RepeatMask[i] != ((M % NumLaneElts) + NumElts))
15168           return SDValue();
15169         LaneSrcs[Lane][1] = M / NumLaneElts;
15170       }
15171     }
15172 
15173     if (LaneSrcs[Lane][0] < 0 && LaneSrcs[Lane][1] < 0)
15174       return SDValue();
15175   }
15176 
15177   SmallVector<int, 16> NewMask(NumElts, -1);
15178   for (int Lane = 0; Lane != NumLanes; ++Lane) {
15179     int Src = LaneSrcs[Lane][0];
15180     for (int i = 0; i != NumLaneElts; ++i) {
15181       int M = -1;
15182       if (Src >= 0)
15183         M = Src * NumLaneElts + i;
15184       NewMask[Lane * NumLaneElts + i] = M;
15185     }
15186   }
15187   SDValue NewV1 = DAG.getVectorShuffle(VT, DL, V1, V2, NewMask);
15188   // Ensure we didn't get back the shuffle we started with.
15189   // FIXME: This is a hack to make up for some splat handling code in
15190   // getVectorShuffle.
15191   if (isa<ShuffleVectorSDNode>(NewV1) &&
15192       cast<ShuffleVectorSDNode>(NewV1)->getMask() == Mask)
15193     return SDValue();
15194 
15195   for (int Lane = 0; Lane != NumLanes; ++Lane) {
15196     int Src = LaneSrcs[Lane][1];
15197     for (int i = 0; i != NumLaneElts; ++i) {
15198       int M = -1;
15199       if (Src >= 0)
15200         M = Src * NumLaneElts + i;
15201       NewMask[Lane * NumLaneElts + i] = M;
15202     }
15203   }
15204   SDValue NewV2 = DAG.getVectorShuffle(VT, DL, V1, V2, NewMask);
15205   // Ensure we didn't get back the shuffle we started with.
15206   // FIXME: This is a hack to make up for some splat handling code in
15207   // getVectorShuffle.
15208   if (isa<ShuffleVectorSDNode>(NewV2) &&
15209       cast<ShuffleVectorSDNode>(NewV2)->getMask() == Mask)
15210     return SDValue();
15211 
15212   for (int i = 0; i != NumElts; ++i) {
15213     if (Mask[i] < 0) {
15214       NewMask[i] = -1;
15215       continue;
15216     }
15217     NewMask[i] = RepeatMask[i % NumLaneElts];
15218     if (NewMask[i] < 0)
15219       continue;
15220 
15221     NewMask[i] += (i / NumLaneElts) * NumLaneElts;
15222   }
15223   return DAG.getVectorShuffle(VT, DL, NewV1, NewV2, NewMask);
15224 }
15225 
15226 /// If the input shuffle mask results in a vector that is undefined in all upper
15227 /// or lower half elements and that mask accesses only 2 halves of the
15228 /// shuffle's operands, return true. A mask of half the width with mask indexes
15229 /// adjusted to access the extracted halves of the original shuffle operands is
15230 /// returned in HalfMask. HalfIdx1 and HalfIdx2 return whether the upper or
15231 /// lower half of each input operand is accessed.
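/// For illustration (a single concrete case derived from the logic below):
/// given the v8 mask <12, 5, 14, 7, u, u, u, u>, the upper half of the result
/// is undef, so this returns true with HalfMask = <0, 5, 2, 7>, HalfIdx1 = 3
/// (upper half of V2) and HalfIdx2 = 1 (upper half of V1).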
15232 static bool
15233 getHalfShuffleMask(ArrayRef<int> Mask, MutableArrayRef<int> HalfMask,
15234                    int &HalfIdx1, int &HalfIdx2) {
15235   assert((Mask.size() == HalfMask.size() * 2) &&
15236          "Expected input mask to be twice as long as output");
15237 
15238   // Exactly one half of the result must be undef to allow narrowing.
15239   bool UndefLower = isUndefLowerHalf(Mask);
15240   bool UndefUpper = isUndefUpperHalf(Mask);
15241   if (UndefLower == UndefUpper)
15242     return false;
15243 
15244   unsigned HalfNumElts = HalfMask.size();
15245   unsigned MaskIndexOffset = UndefLower ? HalfNumElts : 0;
15246   HalfIdx1 = -1;
15247   HalfIdx2 = -1;
15248   for (unsigned i = 0; i != HalfNumElts; ++i) {
15249     int M = Mask[i + MaskIndexOffset];
15250     if (M < 0) {
15251       HalfMask[i] = M;
15252       continue;
15253     }
15254 
15255     // Determine which of the 4 half vectors this element is from.
15256     // i.e. 0 = Lower V1, 1 = Upper V1, 2 = Lower V2, 3 = Upper V2.
15257     int HalfIdx = M / HalfNumElts;
15258 
15259     // Determine the element index into its half vector source.
15260     int HalfElt = M % HalfNumElts;
15261 
15262     // We can shuffle with up to 2 half vectors, set the new 'half'
15263     // shuffle mask accordingly.
15264     if (HalfIdx1 < 0 || HalfIdx1 == HalfIdx) {
15265       HalfMask[i] = HalfElt;
15266       HalfIdx1 = HalfIdx;
15267       continue;
15268     }
15269     if (HalfIdx2 < 0 || HalfIdx2 == HalfIdx) {
15270       HalfMask[i] = HalfElt + HalfNumElts;
15271       HalfIdx2 = HalfIdx;
15272       continue;
15273     }
15274 
15275     // Too many half vectors referenced.
15276     return false;
15277   }
15278 
15279   return true;
15280 }
15281 
15282 /// Given the output values from getHalfShuffleMask(), create a half width
15283 /// shuffle of extracted vectors followed by an insert back to full width.
15284 static SDValue getShuffleHalfVectors(const SDLoc &DL, SDValue V1, SDValue V2,
15285                                      ArrayRef<int> HalfMask, int HalfIdx1,
15286                                      int HalfIdx2, bool UndefLower,
15287                                      SelectionDAG &DAG, bool UseConcat = false) {
15288   assert(V1.getValueType() == V2.getValueType() && "Different sized vectors?");
15289   assert(V1.getValueType().isSimple() && "Expecting only simple types");
15290 
15291   MVT VT = V1.getSimpleValueType();
15292   MVT HalfVT = VT.getHalfNumVectorElementsVT();
15293   unsigned HalfNumElts = HalfVT.getVectorNumElements();
15294 
15295   auto getHalfVector = [&](int HalfIdx) {
15296     if (HalfIdx < 0)
15297       return DAG.getUNDEF(HalfVT);
15298     SDValue V = (HalfIdx < 2 ? V1 : V2);
15299     HalfIdx = (HalfIdx % 2) * HalfNumElts;
15300     return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HalfVT, V,
15301                        DAG.getIntPtrConstant(HalfIdx, DL));
15302   };
15303 
15304   // ins undef, (shuf (ext V1, HalfIdx1), (ext V2, HalfIdx2), HalfMask), Offset
15305   SDValue Half1 = getHalfVector(HalfIdx1);
15306   SDValue Half2 = getHalfVector(HalfIdx2);
15307   SDValue V = DAG.getVectorShuffle(HalfVT, DL, Half1, Half2, HalfMask);
15308   if (UseConcat) {
15309     SDValue Op0 = V;
15310     SDValue Op1 = DAG.getUNDEF(HalfVT);
15311     if (UndefLower)
15312       std::swap(Op0, Op1);
15313     return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Op0, Op1);
15314   }
15315 
15316   unsigned Offset = UndefLower ? HalfNumElts : 0;
15317   return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, DAG.getUNDEF(VT), V,
15318                      DAG.getIntPtrConstant(Offset, DL));
15319 }
15320 
15321 /// Lower shuffles where an entire half of a 256 or 512-bit vector is UNDEF.
15322 /// This allows for fast cases such as subvector extraction/insertion
15323 /// or shuffling smaller vector types which can lower more efficiently.
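/// For example (illustrative, subject to the subtarget checks below): a v8f32
/// shuffle <u, u, u, u, 0, 2, 1, 3> can be lowered by extracting the low v4f32
/// half of V1, shuffling it with <0, 2, 1, 3>, and inserting the result into
/// the upper half of an otherwise undef vector.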
15324 static SDValue lowerShuffleWithUndefHalf(const SDLoc &DL, MVT VT, SDValue V1,
15325                                          SDValue V2, ArrayRef<int> Mask,
15326                                          const X86Subtarget &Subtarget,
15327                                          SelectionDAG &DAG) {
15328   assert((VT.is256BitVector() || VT.is512BitVector()) &&
15329          "Expected 256-bit or 512-bit vector");
15330 
15331   bool UndefLower = isUndefLowerHalf(Mask);
15332   if (!UndefLower && !isUndefUpperHalf(Mask))
15333     return SDValue();
15334 
15335   assert((!UndefLower || !isUndefUpperHalf(Mask)) &&
15336          "Completely undef shuffle mask should have been simplified already");
15337 
15338   // Upper half is undef and lower half is whole upper subvector.
15339   // e.g. vector_shuffle <4, 5, 6, 7, u, u, u, u> or <2, 3, u, u>
15340   MVT HalfVT = VT.getHalfNumVectorElementsVT();
15341   unsigned HalfNumElts = HalfVT.getVectorNumElements();
15342   if (!UndefLower &&
15343       isSequentialOrUndefInRange(Mask, 0, HalfNumElts, HalfNumElts)) {
15344     SDValue Hi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HalfVT, V1,
15345                              DAG.getIntPtrConstant(HalfNumElts, DL));
15346     return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, DAG.getUNDEF(VT), Hi,
15347                        DAG.getIntPtrConstant(0, DL));
15348   }
15349 
15350   // Lower half is undef and upper half is whole lower subvector.
15351   // e.g. vector_shuffle <u, u, u, u, 0, 1, 2, 3> or <u, u, 0, 1>
15352   if (UndefLower &&
15353       isSequentialOrUndefInRange(Mask, HalfNumElts, HalfNumElts, 0)) {
15354     SDValue Hi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HalfVT, V1,
15355                              DAG.getIntPtrConstant(0, DL));
15356     return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, DAG.getUNDEF(VT), Hi,
15357                        DAG.getIntPtrConstant(HalfNumElts, DL));
15358   }
15359 
15360   int HalfIdx1, HalfIdx2;
15361   SmallVector<int, 8> HalfMask(HalfNumElts);
15362   if (!getHalfShuffleMask(Mask, HalfMask, HalfIdx1, HalfIdx2))
15363     return SDValue();
15364 
15365   assert(HalfMask.size() == HalfNumElts && "Unexpected shuffle mask length");
15366 
15367   // Only shuffle the halves of the inputs when useful.
15368   unsigned NumLowerHalves =
15369       (HalfIdx1 == 0 || HalfIdx1 == 2) + (HalfIdx2 == 0 || HalfIdx2 == 2);
15370   unsigned NumUpperHalves =
15371       (HalfIdx1 == 1 || HalfIdx1 == 3) + (HalfIdx2 == 1 || HalfIdx2 == 3);
15372   assert(NumLowerHalves + NumUpperHalves <= 2 && "Only 1 or 2 halves allowed");
15373 
15374   // Determine the larger pattern of undef/halves, then decide if it's worth
15375   // splitting the shuffle based on subtarget capabilities and types.
15376   unsigned EltWidth = VT.getVectorElementType().getSizeInBits();
15377   if (!UndefLower) {
15378     // XXXXuuuu: no insert is needed.
15379     // Always extract lowers when setting lower - these are all free subreg ops.
15380     if (NumUpperHalves == 0)
15381       return getShuffleHalfVectors(DL, V1, V2, HalfMask, HalfIdx1, HalfIdx2,
15382                                    UndefLower, DAG);
15383 
15384     if (NumUpperHalves == 1) {
15385       // AVX2 has efficient 32/64-bit element cross-lane shuffles.
15386       if (Subtarget.hasAVX2()) {
15387         // extract128 + vunpckhps/vshufps is better than vblend + vpermps.
15388         if (EltWidth == 32 && NumLowerHalves && HalfVT.is128BitVector() &&
15389             !is128BitUnpackShuffleMask(HalfMask, DAG) &&
15390             (!isSingleSHUFPSMask(HalfMask) ||
15391              Subtarget.hasFastVariableCrossLaneShuffle()))
15392           return SDValue();
15393         // If this is a unary shuffle (assume that the 2nd operand is
15394         // canonicalized to undef), then we can use vpermpd. Otherwise, we
15395         // are better off extracting the upper half of 1 operand and using a
15396         // narrow shuffle.
15397         if (EltWidth == 64 && V2.isUndef())
15398           return SDValue();
15399       }
15400       // AVX512 has efficient cross-lane shuffles for all legal 512-bit types.
15401       if (Subtarget.hasAVX512() && VT.is512BitVector())
15402         return SDValue();
15403       // Extract + narrow shuffle is better than the wide alternative.
15404       return getShuffleHalfVectors(DL, V1, V2, HalfMask, HalfIdx1, HalfIdx2,
15405                                    UndefLower, DAG);
15406     }
15407 
15408     // Don't extract both uppers, instead shuffle and then extract.
15409     assert(NumUpperHalves == 2 && "Half vector count went wrong");
15410     return SDValue();
15411   }
15412 
15413   // UndefLower - uuuuXXXX: an insert to high half is required if we split this.
15414   if (NumUpperHalves == 0) {
15415     // AVX2 has efficient 64-bit element cross-lane shuffles.
15416     // TODO: Refine to account for unary shuffle, splat, and other masks?
15417     if (Subtarget.hasAVX2() && EltWidth == 64)
15418       return SDValue();
15419     // AVX512 has efficient cross-lane shuffles for all legal 512-bit types.
15420     if (Subtarget.hasAVX512() && VT.is512BitVector())
15421       return SDValue();
15422     // Narrow shuffle + insert is better than the wide alternative.
15423     return getShuffleHalfVectors(DL, V1, V2, HalfMask, HalfIdx1, HalfIdx2,
15424                                  UndefLower, DAG);
15425   }
15426 
15427   // NumUpperHalves != 0: don't bother with extract, shuffle, and then insert.
15428   return SDValue();
15429 }
15430 
15431 /// Handle case where shuffle sources are coming from the same 128-bit lane and
15432 /// every lane can be represented as the same repeating mask - allowing us to
15433 /// shuffle the sources with the repeating shuffle and then permute the result
15434 /// to the destination lanes.
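/// For example (single-input, illustrative): the v8f32 shuffle
/// <6, 7, 4, 5, 2, 3, 0, 1> repeats the in-lane pattern <2, 3, 0, 1>, so it can
/// be lowered as the repeating shuffle <2, 3, 0, 1, 6, 7, 4, 5> followed by the
/// lane permute <4, 5, 6, 7, 0, 1, 2, 3>.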
15435 static SDValue lowerShuffleAsRepeatedMaskAndLanePermute(
15436     const SDLoc &DL, MVT VT, SDValue V1, SDValue V2, ArrayRef<int> Mask,
15437     const X86Subtarget &Subtarget, SelectionDAG &DAG) {
15438   int NumElts = VT.getVectorNumElements();
15439   int NumLanes = VT.getSizeInBits() / 128;
15440   int NumLaneElts = NumElts / NumLanes;
15441 
15442   // On AVX2 we may be able to just shuffle the lowest elements and then
15443   // broadcast the result.
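  // For example (illustrative): the v8i32 mask <1, 0, 1, 0, 1, 0, 1, 0> only
  // references the lowest 128-bit lane, so it can be lowered as the shuffle
  // <1, 0, u, u, u, u, u, u> followed by the broadcast <0, 1, 0, 1, 0, 1, 0, 1>.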
15444   if (Subtarget.hasAVX2()) {
15445     for (unsigned BroadcastSize : {16, 32, 64}) {
15446       if (BroadcastSize <= VT.getScalarSizeInBits())
15447         continue;
15448       int NumBroadcastElts = BroadcastSize / VT.getScalarSizeInBits();
15449 
15450       // Attempt to match a repeating pattern every NumBroadcastElts,
15451       // accounting for UNDEFs, but only referencing the lowest 128-bit
15452       // lane of the inputs.
15453       auto FindRepeatingBroadcastMask = [&](SmallVectorImpl<int> &RepeatMask) {
15454         for (int i = 0; i != NumElts; i += NumBroadcastElts)
15455           for (int j = 0; j != NumBroadcastElts; ++j) {
15456             int M = Mask[i + j];
15457             if (M < 0)
15458               continue;
15459             int &R = RepeatMask[j];
15460             if (0 != ((M % NumElts) / NumLaneElts))
15461               return false;
15462             if (0 <= R && R != M)
15463               return false;
15464             R = M;
15465           }
15466         return true;
15467       };
15468 
15469       SmallVector<int, 8> RepeatMask((unsigned)NumElts, -1);
15470       if (!FindRepeatingBroadcastMask(RepeatMask))
15471         continue;
15472 
15473       // Shuffle the (lowest) repeated elements in place for broadcast.
15474       SDValue RepeatShuf = DAG.getVectorShuffle(VT, DL, V1, V2, RepeatMask);
15475 
15476       // Shuffle the actual broadcast.
15477       SmallVector<int, 8> BroadcastMask((unsigned)NumElts, -1);
15478       for (int i = 0; i != NumElts; i += NumBroadcastElts)
15479         for (int j = 0; j != NumBroadcastElts; ++j)
15480           BroadcastMask[i + j] = j;
15481 
15482       // Avoid returning the same shuffle operation. For example,
15483       // v8i32 = vector_shuffle<0,1,0,1,0,1,0,1> t5, undef:v8i32
15484       if (BroadcastMask == Mask)
15485         return SDValue();
15486 
15487       return DAG.getVectorShuffle(VT, DL, RepeatShuf, DAG.getUNDEF(VT),
15488                                   BroadcastMask);
15489     }
15490   }
15491 
15492   // Bail if the shuffle mask doesn't cross 128-bit lanes.
15493   if (!is128BitLaneCrossingShuffleMask(VT, Mask))
15494     return SDValue();
15495 
15496   // Bail if we already have a repeated lane shuffle mask.
15497   if (is128BitLaneRepeatedShuffleMask(VT, Mask))
15498     return SDValue();
15499 
15500   // Helper to look for a repeated mask in each split sublane, and to check
15501   // that those sublanes can then be permuted into place.
15502   auto ShuffleSubLanes = [&](int SubLaneScale) {
15503     int NumSubLanes = NumLanes * SubLaneScale;
15504     int NumSubLaneElts = NumLaneElts / SubLaneScale;
15505 
15506     // Check that all the sources are coming from the same lane and see if we
15507     // can form a repeating shuffle mask (local to each sub-lane). At the same
15508     // time, determine the source sub-lane for each destination sub-lane.
15509     int TopSrcSubLane = -1;
15510     SmallVector<int, 8> Dst2SrcSubLanes((unsigned)NumSubLanes, -1);
15511     SmallVector<SmallVector<int, 8>> RepeatedSubLaneMasks(
15512         SubLaneScale,
15513         SmallVector<int, 8>((unsigned)NumSubLaneElts, SM_SentinelUndef));
15514 
15515     for (int DstSubLane = 0; DstSubLane != NumSubLanes; ++DstSubLane) {
15516       // Extract the sub-lane mask, check that it all comes from the same lane
15517       // and normalize the mask entries to come from the first lane.
15518       int SrcLane = -1;
15519       SmallVector<int, 8> SubLaneMask((unsigned)NumSubLaneElts, -1);
15520       for (int Elt = 0; Elt != NumSubLaneElts; ++Elt) {
15521         int M = Mask[(DstSubLane * NumSubLaneElts) + Elt];
15522         if (M < 0)
15523           continue;
15524         int Lane = (M % NumElts) / NumLaneElts;
15525         if ((0 <= SrcLane) && (SrcLane != Lane))
15526           return SDValue();
15527         SrcLane = Lane;
15528         int LocalM = (M % NumLaneElts) + (M < NumElts ? 0 : NumElts);
15529         SubLaneMask[Elt] = LocalM;
15530       }
15531 
15532       // Whole sub-lane is UNDEF.
15533       if (SrcLane < 0)
15534         continue;
15535 
15536       // Attempt to match against the candidate repeated sub-lane masks.
15537       for (int SubLane = 0; SubLane != SubLaneScale; ++SubLane) {
15538         auto MatchMasks = [NumSubLaneElts](ArrayRef<int> M1, ArrayRef<int> M2) {
15539           for (int i = 0; i != NumSubLaneElts; ++i) {
15540             if (M1[i] < 0 || M2[i] < 0)
15541               continue;
15542             if (M1[i] != M2[i])
15543               return false;
15544           }
15545           return true;
15546         };
15547 
15548         auto &RepeatedSubLaneMask = RepeatedSubLaneMasks[SubLane];
15549         if (!MatchMasks(SubLaneMask, RepeatedSubLaneMask))
15550           continue;
15551 
15552         // Merge the sub-lane mask into the matching repeated sub-lane mask.
15553         for (int i = 0; i != NumSubLaneElts; ++i) {
15554           int M = SubLaneMask[i];
15555           if (M < 0)
15556             continue;
15557           assert((RepeatedSubLaneMask[i] < 0 || RepeatedSubLaneMask[i] == M) &&
15558                  "Unexpected mask element");
15559           RepeatedSubLaneMask[i] = M;
15560         }
15561 
15562         // Track the top most source sub-lane - by setting the remaining to
15563         // UNDEF we can greatly simplify shuffle matching.
15564         int SrcSubLane = (SrcLane * SubLaneScale) + SubLane;
15565         TopSrcSubLane = std::max(TopSrcSubLane, SrcSubLane);
15566         Dst2SrcSubLanes[DstSubLane] = SrcSubLane;
15567         break;
15568       }
15569 
15570       // Bail if we failed to find a matching repeated sub-lane mask.
15571       if (Dst2SrcSubLanes[DstSubLane] < 0)
15572         return SDValue();
15573     }
15574     assert(0 <= TopSrcSubLane && TopSrcSubLane < NumSubLanes &&
15575            "Unexpected source lane");
15576 
15577     // Create a repeating shuffle mask for the entire vector.
15578     SmallVector<int, 8> RepeatedMask((unsigned)NumElts, -1);
15579     for (int SubLane = 0; SubLane <= TopSrcSubLane; ++SubLane) {
15580       int Lane = SubLane / SubLaneScale;
15581       auto &RepeatedSubLaneMask = RepeatedSubLaneMasks[SubLane % SubLaneScale];
15582       for (int Elt = 0; Elt != NumSubLaneElts; ++Elt) {
15583         int M = RepeatedSubLaneMask[Elt];
15584         if (M < 0)
15585           continue;
15586         int Idx = (SubLane * NumSubLaneElts) + Elt;
15587         RepeatedMask[Idx] = M + (Lane * NumLaneElts);
15588       }
15589     }
15590 
15591     // Shuffle each source sub-lane to its destination.
15592     SmallVector<int, 8> SubLaneMask((unsigned)NumElts, -1);
15593     for (int i = 0; i != NumElts; i += NumSubLaneElts) {
15594       int SrcSubLane = Dst2SrcSubLanes[i / NumSubLaneElts];
15595       if (SrcSubLane < 0)
15596         continue;
15597       for (int j = 0; j != NumSubLaneElts; ++j)
15598         SubLaneMask[i + j] = j + (SrcSubLane * NumSubLaneElts);
15599     }
15600 
15601     // Avoid returning the same shuffle operation.
15602     // v8i32 = vector_shuffle<0,1,4,5,2,3,6,7> t5, undef:v8i32
15603     if (RepeatedMask == Mask || SubLaneMask == Mask)
15604       return SDValue();
15605 
15606     SDValue RepeatedShuffle =
15607         DAG.getVectorShuffle(VT, DL, V1, V2, RepeatedMask);
15608 
15609     return DAG.getVectorShuffle(VT, DL, RepeatedShuffle, DAG.getUNDEF(VT),
15610                                 SubLaneMask);
15611   };
15612 
15613   // On AVX2 targets we can permute 256-bit vectors as 64-bit sub-lanes
15614   // (with PERMQ/PERMPD). On AVX2/AVX512BW targets, permuting 32-bit sub-lanes,
15615   // even with a variable shuffle, can be worth it for v32i8/v64i8 vectors.
15616   // Otherwise we can only permute whole 128-bit lanes.
15617   int MinSubLaneScale = 1, MaxSubLaneScale = 1;
15618   if (Subtarget.hasAVX2() && VT.is256BitVector()) {
15619     bool OnlyLowestElts = isUndefOrInRange(Mask, 0, NumLaneElts);
15620     MinSubLaneScale = 2;
15621     MaxSubLaneScale =
15622         (!OnlyLowestElts && V2.isUndef() && VT == MVT::v32i8) ? 4 : 2;
15623   }
15624   if (Subtarget.hasBWI() && VT == MVT::v64i8)
15625     MinSubLaneScale = MaxSubLaneScale = 4;
15626 
15627   for (int Scale = MinSubLaneScale; Scale <= MaxSubLaneScale; Scale *= 2)
15628     if (SDValue Shuffle = ShuffleSubLanes(Scale))
15629       return Shuffle;
15630 
15631   return SDValue();
15632 }
15633 
15634 static bool matchShuffleWithSHUFPD(MVT VT, SDValue &V1, SDValue &V2,
15635                                    bool &ForceV1Zero, bool &ForceV2Zero,
15636                                    unsigned &ShuffleImm, ArrayRef<int> Mask,
15637                                    const APInt &Zeroable) {
15638   int NumElts = VT.getVectorNumElements();
15639   assert(VT.getScalarSizeInBits() == 64 &&
15640          (NumElts == 2 || NumElts == 4 || NumElts == 8) &&
15641          "Unexpected data type for VSHUFPD");
15642   assert(isUndefOrZeroOrInRange(Mask, 0, 2 * NumElts) &&
15643          "Illegal shuffle mask");
15644 
15645   bool ZeroLane[2] = { true, true };
15646   for (int i = 0; i < NumElts; ++i)
15647     ZeroLane[i & 1] &= Zeroable[i];
15648 
15649   // Mask for V8F64: 0/1,  8/9,  2/3,  10/11, 4/5, ..
15650   // Mask for V4F64: 0/1,  4/5,  2/3,  6/7..
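  // e.g. the v4f64 mask <0, 5, 2, 7> fits this pattern and yields
  // ShuffleImm = 0b1010 (bit i is Mask[i] % 2).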
15651   ShuffleImm = 0;
15652   bool ShufpdMask = true;
15653   bool CommutableMask = true;
15654   for (int i = 0; i < NumElts; ++i) {
15655     if (Mask[i] == SM_SentinelUndef || ZeroLane[i & 1])
15656       continue;
15657     if (Mask[i] < 0)
15658       return false;
15659     int Val = (i & 6) + NumElts * (i & 1);
15660     int CommutVal = (i & 0xe) + NumElts * ((i & 1) ^ 1);
15661     if (Mask[i] < Val || Mask[i] > Val + 1)
15662       ShufpdMask = false;
15663     if (Mask[i] < CommutVal || Mask[i] > CommutVal + 1)
15664       CommutableMask = false;
15665     ShuffleImm |= (Mask[i] % 2) << i;
15666   }
15667 
15668   if (!ShufpdMask && !CommutableMask)
15669     return false;
15670 
15671   if (!ShufpdMask && CommutableMask)
15672     std::swap(V1, V2);
15673 
15674   ForceV1Zero = ZeroLane[0];
15675   ForceV2Zero = ZeroLane[1];
15676   return true;
15677 }
15678 
15679 static SDValue lowerShuffleWithSHUFPD(const SDLoc &DL, MVT VT, SDValue V1,
15680                                       SDValue V2, ArrayRef<int> Mask,
15681                                       const APInt &Zeroable,
15682                                       const X86Subtarget &Subtarget,
15683                                       SelectionDAG &DAG) {
15684   assert((VT == MVT::v2f64 || VT == MVT::v4f64 || VT == MVT::v8f64) &&
15685          "Unexpected data type for VSHUFPD");
15686 
15687   unsigned Immediate = 0;
15688   bool ForceV1Zero = false, ForceV2Zero = false;
15689   if (!matchShuffleWithSHUFPD(VT, V1, V2, ForceV1Zero, ForceV2Zero, Immediate,
15690                               Mask, Zeroable))
15691     return SDValue();
15692 
15693   // Create a REAL zero vector - ISD::isBuildVectorAllZeros allows UNDEFs.
15694   if (ForceV1Zero)
15695     V1 = getZeroVector(VT, Subtarget, DAG, DL);
15696   if (ForceV2Zero)
15697     V2 = getZeroVector(VT, Subtarget, DAG, DL);
15698 
15699   return DAG.getNode(X86ISD::SHUFP, DL, VT, V1, V2,
15700                      DAG.getTargetConstant(Immediate, DL, MVT::i8));
15701 }
15702 
15703 // Look for {0, 8, 16, 24, 32, 40, 48, 56} in the first 8 elements, followed
15704 // by zeroable elements in the remaining 24 elements. Turn this into two
15705 // vmovqb instructions shuffled together.
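// e.g. with that mask the low 8 result bytes are
// { V1[0], V1[8], V1[16], V1[24], V2[0], V2[8], V2[16], V2[24] }, i.e. the low
// byte of each 64-bit element of both inputs.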
15706 static SDValue lowerShuffleAsVTRUNCAndUnpack(const SDLoc &DL, MVT VT,
15707                                              SDValue V1, SDValue V2,
15708                                              ArrayRef<int> Mask,
15709                                              const APInt &Zeroable,
15710                                              SelectionDAG &DAG) {
15711   assert(VT == MVT::v32i8 && "Unexpected type!");
15712 
15713   // The first 8 indices should be every 8th element.
15714   if (!isSequentialOrUndefInRange(Mask, 0, 8, 0, 8))
15715     return SDValue();
15716 
15717   // Remaining elements need to be zeroable.
15718   if (Zeroable.countl_one() < (Mask.size() - 8))
15719     return SDValue();
15720 
15721   V1 = DAG.getBitcast(MVT::v4i64, V1);
15722   V2 = DAG.getBitcast(MVT::v4i64, V2);
15723 
15724   V1 = DAG.getNode(X86ISD::VTRUNC, DL, MVT::v16i8, V1);
15725   V2 = DAG.getNode(X86ISD::VTRUNC, DL, MVT::v16i8, V2);
15726 
15727   // The VTRUNCs will put 0s in the upper 12 bytes. Use them to put zeroes in
15728   // the upper bits of the result using an unpckldq.
15729   SDValue Unpack = DAG.getVectorShuffle(MVT::v16i8, DL, V1, V2,
15730                                         { 0, 1, 2, 3, 16, 17, 18, 19,
15731                                           4, 5, 6, 7, 20, 21, 22, 23 });
15732   // Insert the unpckldq into a zero vector to widen to v32i8.
15733   return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, MVT::v32i8,
15734                      DAG.getConstant(0, DL, MVT::v32i8), Unpack,
15735                      DAG.getIntPtrConstant(0, DL));
15736 }
15737 
15738 // a = shuffle v1, v2, mask1    ; interleaving lower lanes of v1 and v2
15739 // b = shuffle v1, v2, mask2    ; interleaving higher lanes of v1 and v2
15740 //     =>
15741 // ul = unpckl v1, v2
15742 // uh = unpckh v1, v2
15743 // a = vperm ul, uh
15744 // b = vperm ul, uh
15745 //
15746 // Pattern-match interleave(256b v1, 256b v2) -> 512b v3 and lower it into unpck
15747 // and permute. We cannot directly match v3 because it is split into two
15748 // 256-bit vectors in earlier isel stages. Therefore, this function matches a
15749 // pair of 256-bit shuffles and makes sure the masks are consecutive.
15750 //
15751 // Once unpck and permute nodes are created, the permute corresponding to this
15752 // shuffle is returned, while the other permute replaces the other half of the
15753 // shuffle in the selection dag.
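// e.g. for v8f32 the matched pair of interleaving masks is
//   a = shuffle v1, v2, <0, 8, 1, 9, 2, 10, 3, 11>
//   b = shuffle v1, v2, <4, 12, 5, 13, 6, 14, 7, 15>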
15754 static SDValue lowerShufflePairAsUNPCKAndPermute(const SDLoc &DL, MVT VT,
15755                                                  SDValue V1, SDValue V2,
15756                                                  ArrayRef<int> Mask,
15757                                                  SelectionDAG &DAG) {
15758   if (VT != MVT::v8f32 && VT != MVT::v8i32 && VT != MVT::v16i16 &&
15759       VT != MVT::v32i8)
15760     return SDValue();
15761   // <B0, B1, B0+1, B1+1, ..., >
15762   auto IsInterleavingPattern = [&](ArrayRef<int> Mask, unsigned Begin0,
15763                                    unsigned Begin1) {
15764     size_t Size = Mask.size();
15765     assert(Size % 2 == 0 && "Expected even mask size");
15766     for (unsigned I = 0; I < Size; I += 2) {
15767       if (Mask[I] != (int)(Begin0 + I / 2) ||
15768           Mask[I + 1] != (int)(Begin1 + I / 2))
15769         return false;
15770     }
15771     return true;
15772   };
15773   // Check which half of the interleave this shuffle node is.
15774   int NumElts = VT.getVectorNumElements();
15775   size_t FirstQtr = NumElts / 2;
15776   size_t ThirdQtr = NumElts + NumElts / 2;
15777   bool IsFirstHalf = IsInterleavingPattern(Mask, 0, NumElts);
15778   bool IsSecondHalf = IsInterleavingPattern(Mask, FirstQtr, ThirdQtr);
15779   if (!IsFirstHalf && !IsSecondHalf)
15780     return SDValue();
15781 
15782   // Find the intersection between shuffle users of V1 and V2.
15783   SmallVector<SDNode *, 2> Shuffles;
15784   for (SDNode *User : V1->uses())
15785     if (User->getOpcode() == ISD::VECTOR_SHUFFLE && User->getOperand(0) == V1 &&
15786         User->getOperand(1) == V2)
15787       Shuffles.push_back(User);
15788   // Limit user size to two for now.
15789   if (Shuffles.size() != 2)
15790     return SDValue();
15791   // Find out which half of the 512-bit shuffle each smaller shuffle is
15792   auto *SVN1 = cast<ShuffleVectorSDNode>(Shuffles[0]);
15793   auto *SVN2 = cast<ShuffleVectorSDNode>(Shuffles[1]);
15794   SDNode *FirstHalf;
15795   SDNode *SecondHalf;
15796   if (IsInterleavingPattern(SVN1->getMask(), 0, NumElts) &&
15797       IsInterleavingPattern(SVN2->getMask(), FirstQtr, ThirdQtr)) {
15798     FirstHalf = Shuffles[0];
15799     SecondHalf = Shuffles[1];
15800   } else if (IsInterleavingPattern(SVN1->getMask(), FirstQtr, ThirdQtr) &&
15801              IsInterleavingPattern(SVN2->getMask(), 0, NumElts)) {
15802     FirstHalf = Shuffles[1];
15803     SecondHalf = Shuffles[0];
15804   } else {
15805     return SDValue();
15806   }
15807   // Lower into unpck and perm. Return the perm of this shuffle and replace
15808   // the other.
15809   SDValue Unpckl = DAG.getNode(X86ISD::UNPCKL, DL, VT, V1, V2);
15810   SDValue Unpckh = DAG.getNode(X86ISD::UNPCKH, DL, VT, V1, V2);
15811   SDValue Perm1 = DAG.getNode(X86ISD::VPERM2X128, DL, VT, Unpckl, Unpckh,
15812                               DAG.getTargetConstant(0x20, DL, MVT::i8));
15813   SDValue Perm2 = DAG.getNode(X86ISD::VPERM2X128, DL, VT, Unpckl, Unpckh,
15814                               DAG.getTargetConstant(0x31, DL, MVT::i8));
15815   if (IsFirstHalf) {
15816     DAG.ReplaceAllUsesWith(SecondHalf, &Perm2);
15817     return Perm1;
15818   }
15819   DAG.ReplaceAllUsesWith(FirstHalf, &Perm1);
15820   return Perm2;
15821 }
15822 
15823 /// Handle lowering of 4-lane 64-bit floating point shuffles.
15824 ///
15825 /// Also ends up handling lowering of 4-lane 64-bit integer shuffles when AVX2
15826 /// isn't available.
15827 static SDValue lowerV4F64Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
15828                                  const APInt &Zeroable, SDValue V1, SDValue V2,
15829                                  const X86Subtarget &Subtarget,
15830                                  SelectionDAG &DAG) {
15831   assert(V1.getSimpleValueType() == MVT::v4f64 && "Bad operand type!");
15832   assert(V2.getSimpleValueType() == MVT::v4f64 && "Bad operand type!");
15833   assert(Mask.size() == 4 && "Unexpected mask size for v4 shuffle!");
15834 
15835   if (SDValue V = lowerV2X128Shuffle(DL, MVT::v4f64, V1, V2, Mask, Zeroable,
15836                                      Subtarget, DAG))
15837     return V;
15838 
15839   if (V2.isUndef()) {
15840     // Check for being able to broadcast a single element.
15841     if (SDValue Broadcast = lowerShuffleAsBroadcast(DL, MVT::v4f64, V1, V2,
15842                                                     Mask, Subtarget, DAG))
15843       return Broadcast;
15844 
15845     // Use low duplicate instructions for masks that match their pattern.
15846     if (isShuffleEquivalent(Mask, {0, 0, 2, 2}, V1, V2))
15847       return DAG.getNode(X86ISD::MOVDDUP, DL, MVT::v4f64, V1);
15848 
15849     if (!is128BitLaneCrossingShuffleMask(MVT::v4f64, Mask)) {
15850       // Non-half-crossing single input shuffles can be lowered with an
15851       // interleaved permutation.
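      // e.g. the mask <1, 0, 3, 2> gives VPERMILPMask = 0b0101.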
15852       unsigned VPERMILPMask = (Mask[0] == 1) | ((Mask[1] == 1) << 1) |
15853                               ((Mask[2] == 3) << 2) | ((Mask[3] == 3) << 3);
15854       return DAG.getNode(X86ISD::VPERMILPI, DL, MVT::v4f64, V1,
15855                          DAG.getTargetConstant(VPERMILPMask, DL, MVT::i8));
15856     }
15857 
15858     // With AVX2 we have direct support for this permutation.
15859     if (Subtarget.hasAVX2())
15860       return DAG.getNode(X86ISD::VPERMI, DL, MVT::v4f64, V1,
15861                          getV4X86ShuffleImm8ForMask(Mask, DL, DAG));
15862 
15863     // Try to create an in-lane repeating shuffle mask and then shuffle the
15864     // results into the target lanes.
15865     if (SDValue V = lowerShuffleAsRepeatedMaskAndLanePermute(
15866             DL, MVT::v4f64, V1, V2, Mask, Subtarget, DAG))
15867       return V;
15868 
15869     // Try to permute the lanes and then use a per-lane permute.
15870     if (SDValue V = lowerShuffleAsLanePermuteAndPermute(DL, MVT::v4f64, V1, V2,
15871                                                         Mask, DAG, Subtarget))
15872       return V;
15873 
15874     // Otherwise, fall back.
15875     return lowerShuffleAsLanePermuteAndShuffle(DL, MVT::v4f64, V1, V2, Mask,
15876                                                DAG, Subtarget);
15877   }
15878 
15879   // Use dedicated unpack instructions for masks that match their pattern.
15880   if (SDValue V = lowerShuffleWithUNPCK(DL, MVT::v4f64, Mask, V1, V2, DAG))
15881     return V;
15882 
15883   if (SDValue Blend = lowerShuffleAsBlend(DL, MVT::v4f64, V1, V2, Mask,
15884                                           Zeroable, Subtarget, DAG))
15885     return Blend;
15886 
15887   // Check if the blend happens to exactly fit that of SHUFPD.
15888   if (SDValue Op = lowerShuffleWithSHUFPD(DL, MVT::v4f64, V1, V2, Mask,
15889                                           Zeroable, Subtarget, DAG))
15890     return Op;
15891 
15892   bool V1IsInPlace = isShuffleMaskInputInPlace(0, Mask);
15893   bool V2IsInPlace = isShuffleMaskInputInPlace(1, Mask);
15894 
15895   // If we have lane crossing shuffles AND they don't all come from the lower
15896   // lane elements, lower to SHUFPD(VPERM2F128(V1, V2), VPERM2F128(V1, V2)).
15897   // TODO: Handle BUILD_VECTOR sources which getVectorShuffle currently
15898   // canonicalize to a blend of splat which isn't necessary for this combine.
15899   if (is128BitLaneCrossingShuffleMask(MVT::v4f64, Mask) &&
15900       !all_of(Mask, [](int M) { return M < 2 || (4 <= M && M < 6); }) &&
15901       (V1.getOpcode() != ISD::BUILD_VECTOR) &&
15902       (V2.getOpcode() != ISD::BUILD_VECTOR))
15903     return lowerShuffleAsLanePermuteAndSHUFP(DL, MVT::v4f64, V1, V2, Mask, DAG);
15904 
15905   // If we have one input in place, then we can permute the other input and
15906   // blend the result.
15907   if (V1IsInPlace || V2IsInPlace)
15908     return lowerShuffleAsDecomposedShuffleMerge(DL, MVT::v4f64, V1, V2, Mask,
15909                                                 Subtarget, DAG);
15910 
15911   // Try to create an in-lane repeating shuffle mask and then shuffle the
15912   // results into the target lanes.
15913   if (SDValue V = lowerShuffleAsRepeatedMaskAndLanePermute(
15914           DL, MVT::v4f64, V1, V2, Mask, Subtarget, DAG))
15915     return V;
15916 
15917   // Try to simplify this by merging 128-bit lanes to enable a lane-based
15918   // shuffle. However, if we have AVX2 and either input is already in place,
15919   // we will be able to shuffle the other input even across lanes in a single
15920   // instruction, so skip this pattern.
15921   if (!(Subtarget.hasAVX2() && (V1IsInPlace || V2IsInPlace)))
15922     if (SDValue V = lowerShuffleAsLanePermuteAndRepeatedMask(
15923             DL, MVT::v4f64, V1, V2, Mask, Subtarget, DAG))
15924       return V;
15925 
15926   // If we have VLX support, we can use VEXPAND.
15927   if (Subtarget.hasVLX())
15928     if (SDValue V = lowerShuffleToEXPAND(DL, MVT::v4f64, Zeroable, Mask, V1, V2,
15929                                          DAG, Subtarget))
15930       return V;
15931 
15932   // If we have AVX2 then we always want to lower with a blend because at v4 we
15933   // can fully permute the elements.
15934   if (Subtarget.hasAVX2())
15935     return lowerShuffleAsDecomposedShuffleMerge(DL, MVT::v4f64, V1, V2, Mask,
15936                                                 Subtarget, DAG);
15937 
15938   // Otherwise fall back on generic lowering.
15939   return lowerShuffleAsSplitOrBlend(DL, MVT::v4f64, V1, V2, Mask,
15940                                     Subtarget, DAG);
15941 }
15942 
15943 /// Handle lowering of 4-lane 64-bit integer shuffles.
15944 ///
15945 /// This routine is only called when we have AVX2 and thus a reasonable
15946 /// instruction set for v4i64 shuffling.
15947 static SDValue lowerV4I64Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
15948                                  const APInt &Zeroable, SDValue V1, SDValue V2,
15949                                  const X86Subtarget &Subtarget,
15950                                  SelectionDAG &DAG) {
15951   assert(V1.getSimpleValueType() == MVT::v4i64 && "Bad operand type!");
15952   assert(V2.getSimpleValueType() == MVT::v4i64 && "Bad operand type!");
15953   assert(Mask.size() == 4 && "Unexpected mask size for v4 shuffle!");
15954   assert(Subtarget.hasAVX2() && "We can only lower v4i64 with AVX2!");
15955 
15956   if (SDValue V = lowerV2X128Shuffle(DL, MVT::v4i64, V1, V2, Mask, Zeroable,
15957                                      Subtarget, DAG))
15958     return V;
15959 
15960   if (SDValue Blend = lowerShuffleAsBlend(DL, MVT::v4i64, V1, V2, Mask,
15961                                           Zeroable, Subtarget, DAG))
15962     return Blend;
15963 
15964   // Check for being able to broadcast a single element.
15965   if (SDValue Broadcast = lowerShuffleAsBroadcast(DL, MVT::v4i64, V1, V2, Mask,
15966                                                   Subtarget, DAG))
15967     return Broadcast;
15968 
15969   // Try to use shift instructions if fast.
15970   if (Subtarget.preferLowerShuffleAsShift())
15971     if (SDValue Shift =
15972             lowerShuffleAsShift(DL, MVT::v4i64, V1, V2, Mask, Zeroable,
15973                                 Subtarget, DAG, /*BitwiseOnly*/ true))
15974       return Shift;
15975 
15976   if (V2.isUndef()) {
15977     // When the shuffle is mirrored between the 128-bit lanes of the input, we
15978     // can use lower latency instructions that will operate on both lanes.
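    // e.g. the v4i64 mask <1, 0, 3, 2> repeats <1, 0> in each 128-bit lane and
    // widens to the v8i32 PSHUFD mask <2, 3, 0, 1>.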
15979     SmallVector<int, 2> RepeatedMask;
15980     if (is128BitLaneRepeatedShuffleMask(MVT::v4i64, Mask, RepeatedMask)) {
15981       SmallVector<int, 4> PSHUFDMask;
15982       narrowShuffleMaskElts(2, RepeatedMask, PSHUFDMask);
15983       return DAG.getBitcast(
15984           MVT::v4i64,
15985           DAG.getNode(X86ISD::PSHUFD, DL, MVT::v8i32,
15986                       DAG.getBitcast(MVT::v8i32, V1),
15987                       getV4X86ShuffleImm8ForMask(PSHUFDMask, DL, DAG)));
15988     }
15989 
15990     // AVX2 provides a direct instruction for permuting a single input across
15991     // lanes.
15992     return DAG.getNode(X86ISD::VPERMI, DL, MVT::v4i64, V1,
15993                        getV4X86ShuffleImm8ForMask(Mask, DL, DAG));
15994   }
15995 
15996   // Try to use shift instructions.
15997   if (SDValue Shift =
15998           lowerShuffleAsShift(DL, MVT::v4i64, V1, V2, Mask, Zeroable, Subtarget,
15999                               DAG, /*BitwiseOnly*/ false))
16000     return Shift;
16001 
16002   // If we have VLX support, we can use VALIGN or VEXPAND.
16003   if (Subtarget.hasVLX()) {
16004     if (SDValue Rotate = lowerShuffleAsVALIGN(DL, MVT::v4i64, V1, V2, Mask,
16005                                               Zeroable, Subtarget, DAG))
16006       return Rotate;
16007 
16008     if (SDValue V = lowerShuffleToEXPAND(DL, MVT::v4i64, Zeroable, Mask, V1, V2,
16009                                          DAG, Subtarget))
16010       return V;
16011   }
16012 
16013   // Try to use PALIGNR.
16014   if (SDValue Rotate = lowerShuffleAsByteRotate(DL, MVT::v4i64, V1, V2, Mask,
16015                                                 Subtarget, DAG))
16016     return Rotate;
16017 
16018   // Use dedicated unpack instructions for masks that match their pattern.
16019   if (SDValue V = lowerShuffleWithUNPCK(DL, MVT::v4i64, Mask, V1, V2, DAG))
16020     return V;
16021 
16022   bool V1IsInPlace = isShuffleMaskInputInPlace(0, Mask);
16023   bool V2IsInPlace = isShuffleMaskInputInPlace(1, Mask);
16024 
16025   // If we have one input in place, then we can permute the other input and
16026   // blend the result.
16027   if (V1IsInPlace || V2IsInPlace)
16028     return lowerShuffleAsDecomposedShuffleMerge(DL, MVT::v4i64, V1, V2, Mask,
16029                                                 Subtarget, DAG);
16030 
16031   // Try to create an in-lane repeating shuffle mask and then shuffle the
16032   // results into the target lanes.
16033   if (SDValue V = lowerShuffleAsRepeatedMaskAndLanePermute(
16034           DL, MVT::v4i64, V1, V2, Mask, Subtarget, DAG))
16035     return V;
16036 
16037   // Try to lower to PERMQ(BLENDD(V1,V2)).
16038   if (SDValue V =
16039           lowerShuffleAsBlendAndPermute(DL, MVT::v4i64, V1, V2, Mask, DAG))
16040     return V;
16041 
16042   // Try to simplify this by merging 128-bit lanes to enable a lane-based
16043   // shuffle. However, if we have AVX2 and either input is already in place,
16044   // we will be able to shuffle the other input even across lanes in a single
16045   // instruction, so skip this pattern.
16046   if (!V1IsInPlace && !V2IsInPlace)
16047     if (SDValue Result = lowerShuffleAsLanePermuteAndRepeatedMask(
16048             DL, MVT::v4i64, V1, V2, Mask, Subtarget, DAG))
16049       return Result;
16050 
16051   // Otherwise fall back on generic blend lowering.
16052   return lowerShuffleAsDecomposedShuffleMerge(DL, MVT::v4i64, V1, V2, Mask,
16053                                               Subtarget, DAG);
16054 }
16055 
16056 /// Handle lowering of 8-lane 32-bit floating point shuffles.
16057 ///
16058 /// Also ends up handling lowering of 8-lane 32-bit integer shuffles when AVX2
16059 /// isn't available.
16060 static SDValue lowerV8F32Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
16061                                  const APInt &Zeroable, SDValue V1, SDValue V2,
16062                                  const X86Subtarget &Subtarget,
16063                                  SelectionDAG &DAG) {
16064   assert(V1.getSimpleValueType() == MVT::v8f32 && "Bad operand type!");
16065   assert(V2.getSimpleValueType() == MVT::v8f32 && "Bad operand type!");
16066   assert(Mask.size() == 8 && "Unexpected mask size for v8 shuffle!");
16067 
16068   if (SDValue Blend = lowerShuffleAsBlend(DL, MVT::v8f32, V1, V2, Mask,
16069                                           Zeroable, Subtarget, DAG))
16070     return Blend;
16071 
16072   // Check for being able to broadcast a single element.
16073   if (SDValue Broadcast = lowerShuffleAsBroadcast(DL, MVT::v8f32, V1, V2, Mask,
16074                                                   Subtarget, DAG))
16075     return Broadcast;
16076 
16077   if (!Subtarget.hasAVX2()) {
16078     SmallVector<int> InLaneMask;
16079     computeInLaneShuffleMask(Mask, Mask.size() / 2, InLaneMask);
16080 
16081     if (!is128BitLaneRepeatedShuffleMask(MVT::v8f32, InLaneMask))
16082       if (SDValue R = splitAndLowerShuffle(DL, MVT::v8f32, V1, V2, Mask, DAG,
16083                                            /*SimpleOnly*/ true))
16084         return R;
16085   }
16086   if (SDValue ZExt = lowerShuffleAsZeroOrAnyExtend(DL, MVT::v8i32, V1, V2, Mask,
16087                                                    Zeroable, Subtarget, DAG))
16088     return DAG.getBitcast(MVT::v8f32, ZExt);
16089 
16090   // If the shuffle mask is repeated in each 128-bit lane, we have many more
16091   // options to efficiently lower the shuffle.
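  // e.g. the single-input mask <0, 0, 2, 2, 4, 4, 6, 6> repeats <0, 0, 2, 2> in
  // each 128-bit lane and maps directly to MOVSLDUP below.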
16092   SmallVector<int, 4> RepeatedMask;
16093   if (is128BitLaneRepeatedShuffleMask(MVT::v8f32, Mask, RepeatedMask)) {
16094     assert(RepeatedMask.size() == 4 &&
16095            "Repeated masks must be half the mask width!");
16096 
16097     // Use even/odd duplicate instructions for masks that match their pattern.
16098     if (isShuffleEquivalent(RepeatedMask, {0, 0, 2, 2}, V1, V2))
16099       return DAG.getNode(X86ISD::MOVSLDUP, DL, MVT::v8f32, V1);
16100     if (isShuffleEquivalent(RepeatedMask, {1, 1, 3, 3}, V1, V2))
16101       return DAG.getNode(X86ISD::MOVSHDUP, DL, MVT::v8f32, V1);
16102 
16103     if (V2.isUndef())
16104       return DAG.getNode(X86ISD::VPERMILPI, DL, MVT::v8f32, V1,
16105                          getV4X86ShuffleImm8ForMask(RepeatedMask, DL, DAG));
16106 
16107     // Use dedicated unpack instructions for masks that match their pattern.
16108     if (SDValue V = lowerShuffleWithUNPCK(DL, MVT::v8f32, Mask, V1, V2, DAG))
16109       return V;
16110 
16111     // Otherwise, fall back to a SHUFPS sequence. Here it is important that we
16112     // have already handled any direct blends.
16113     return lowerShuffleWithSHUFPS(DL, MVT::v8f32, RepeatedMask, V1, V2, DAG);
16114   }
16115 
16116   // Try to create an in-lane repeating shuffle mask and then shuffle the
16117   // results into the target lanes.
16118   if (SDValue V = lowerShuffleAsRepeatedMaskAndLanePermute(
16119           DL, MVT::v8f32, V1, V2, Mask, Subtarget, DAG))
16120     return V;
16121 
16122   // If we have a single input shuffle with different shuffle patterns in the
16123   // two 128-bit lanes, use a variable VPERMILPS mask.
16124   if (V2.isUndef()) {
16125     if (!is128BitLaneCrossingShuffleMask(MVT::v8f32, Mask)) {
16126       SDValue VPermMask = getConstVector(Mask, MVT::v8i32, DAG, DL, true);
16127       return DAG.getNode(X86ISD::VPERMILPV, DL, MVT::v8f32, V1, VPermMask);
16128     }
16129     if (Subtarget.hasAVX2()) {
16130       SDValue VPermMask = getConstVector(Mask, MVT::v8i32, DAG, DL, true);
16131       return DAG.getNode(X86ISD::VPERMV, DL, MVT::v8f32, VPermMask, V1);
16132     }
16133     // Otherwise, fall back.
16134     return lowerShuffleAsLanePermuteAndShuffle(DL, MVT::v8f32, V1, V2, Mask,
16135                                                DAG, Subtarget);
16136   }
16137 
16138   // Try to simplify this by merging 128-bit lanes to enable a lane-based
16139   // shuffle.
16140   if (SDValue Result = lowerShuffleAsLanePermuteAndRepeatedMask(
16141           DL, MVT::v8f32, V1, V2, Mask, Subtarget, DAG))
16142     return Result;
16143 
16144   // If we have VLX support, we can use VEXPAND.
16145   if (Subtarget.hasVLX())
16146     if (SDValue V = lowerShuffleToEXPAND(DL, MVT::v8f32, Zeroable, Mask, V1, V2,
16147                                          DAG, Subtarget))
16148       return V;
16149 
16150   // Try to match an interleave of two v8f32s and lower them as unpck and
16151   // permutes using ymms. This needs to go before we try to split the vectors.
16152   //
16153   // TODO: Expand this to AVX1. Currently v8i32 is cast to v8f32 and hits
16154   // this path inadvertently.
16155   if (Subtarget.hasAVX2() && !Subtarget.hasAVX512())
16156     if (SDValue V = lowerShufflePairAsUNPCKAndPermute(DL, MVT::v8f32, V1, V2,
16157                                                       Mask, DAG))
16158       return V;
16159 
16160   // For non-AVX512, if the mask matches an in-lane 16-bit unpack pattern then
16161   // try to split, since after splitting we get more efficient code using
16162   // vpunpcklwd and vpunpckhwd instructions than with vblend.
16163   if (!Subtarget.hasAVX512() && isUnpackWdShuffleMask(Mask, MVT::v8f32, DAG))
16164     return lowerShuffleAsSplitOrBlend(DL, MVT::v8f32, V1, V2, Mask, Subtarget,
16165                                       DAG);
16166 
16167   // If we have AVX2 then we always want to lower with a blend because at v8 we
16168   // can fully permute the elements.
16169   if (Subtarget.hasAVX2())
16170     return lowerShuffleAsDecomposedShuffleMerge(DL, MVT::v8f32, V1, V2, Mask,
16171                                                 Subtarget, DAG);
16172 
16173   // Otherwise fall back on generic lowering.
16174   return lowerShuffleAsSplitOrBlend(DL, MVT::v8f32, V1, V2, Mask,
16175                                     Subtarget, DAG);
16176 }
16177 
16178 /// Handle lowering of 8-lane 32-bit integer shuffles.
16179 ///
16180 /// This routine is only called when we have AVX2 and thus a reasonable
16181 /// instruction set for v8i32 shuffling.
16182 static SDValue lowerV8I32Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
16183                                  const APInt &Zeroable, SDValue V1, SDValue V2,
16184                                  const X86Subtarget &Subtarget,
16185                                  SelectionDAG &DAG) {
16186   assert(V1.getSimpleValueType() == MVT::v8i32 && "Bad operand type!");
16187   assert(V2.getSimpleValueType() == MVT::v8i32 && "Bad operand type!");
16188   assert(Mask.size() == 8 && "Unexpected mask size for v8 shuffle!");
16189   assert(Subtarget.hasAVX2() && "We can only lower v8i32 with AVX2!");
16190 
16191   int NumV2Elements = count_if(Mask, [](int M) { return M >= 8; });
16192 
16193   // Whenever we can lower this as a zext, that instruction is strictly faster
16194   // than any alternative. It also allows us to fold memory operands into the
16195   // shuffle in many cases.
16196   if (SDValue ZExt = lowerShuffleAsZeroOrAnyExtend(DL, MVT::v8i32, V1, V2, Mask,
16197                                                    Zeroable, Subtarget, DAG))
16198     return ZExt;
16199 
16200   // Try to match an interleave of two v8i32s and lower them as unpck and
16201   // permutes using ymms. This needs to go before we try to split the vectors.
16202   if (!Subtarget.hasAVX512())
16203     if (SDValue V = lowerShufflePairAsUNPCKAndPermute(DL, MVT::v8i32, V1, V2,
16204                                                       Mask, DAG))
16205       return V;
16206 
16207   // For non-AVX512, if the mask matches an in-lane 16-bit unpack pattern then
16208   // try to split, since after splitting we get more efficient code than vblend
16209   // by using vpunpcklwd and vpunpckhwd instructions.
16210   if (isUnpackWdShuffleMask(Mask, MVT::v8i32, DAG) && !V2.isUndef() &&
16211       !Subtarget.hasAVX512())
16212     return lowerShuffleAsSplitOrBlend(DL, MVT::v8i32, V1, V2, Mask, Subtarget,
16213                                       DAG);
16214 
16215   if (SDValue Blend = lowerShuffleAsBlend(DL, MVT::v8i32, V1, V2, Mask,
16216                                           Zeroable, Subtarget, DAG))
16217     return Blend;
16218 
16219   // Check for being able to broadcast a single element.
16220   if (SDValue Broadcast = lowerShuffleAsBroadcast(DL, MVT::v8i32, V1, V2, Mask,
16221                                                   Subtarget, DAG))
16222     return Broadcast;
16223 
16224   // Try to use shift instructions if fast.
16225   if (Subtarget.preferLowerShuffleAsShift()) {
16226     if (SDValue Shift =
16227             lowerShuffleAsShift(DL, MVT::v8i32, V1, V2, Mask, Zeroable,
16228                                 Subtarget, DAG, /*BitwiseOnly*/ true))
16229       return Shift;
16230     if (NumV2Elements == 0)
16231       if (SDValue Rotate =
16232               lowerShuffleAsBitRotate(DL, MVT::v8i32, V1, Mask, Subtarget, DAG))
16233         return Rotate;
16234   }
16235 
16236   // If the shuffle mask is repeated in each 128-bit lane we can use more
16237   // efficient instructions that mirror the shuffles across the two 128-bit
16238   // lanes.
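  // e.g. the single-input mask <1, 0, 3, 2, 5, 4, 7, 6> repeats <1, 0, 3, 2> in
  // each lane and can be lowered with a single PSHUFD.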
16239   SmallVector<int, 4> RepeatedMask;
16240   bool Is128BitLaneRepeatedShuffle =
16241       is128BitLaneRepeatedShuffleMask(MVT::v8i32, Mask, RepeatedMask);
16242   if (Is128BitLaneRepeatedShuffle) {
16243     assert(RepeatedMask.size() == 4 && "Unexpected repeated mask size!");
16244     if (V2.isUndef())
16245       return DAG.getNode(X86ISD::PSHUFD, DL, MVT::v8i32, V1,
16246                          getV4X86ShuffleImm8ForMask(RepeatedMask, DL, DAG));
16247 
16248     // Use dedicated unpack instructions for masks that match their pattern.
16249     if (SDValue V = lowerShuffleWithUNPCK(DL, MVT::v8i32, Mask, V1, V2, DAG))
16250       return V;
16251   }
16252 
16253   // Try to use shift instructions.
16254   if (SDValue Shift =
16255           lowerShuffleAsShift(DL, MVT::v8i32, V1, V2, Mask, Zeroable, Subtarget,
16256                               DAG, /*BitwiseOnly*/ false))
16257     return Shift;
16258 
16259   if (!Subtarget.preferLowerShuffleAsShift() && NumV2Elements == 0)
16260     if (SDValue Rotate =
16261             lowerShuffleAsBitRotate(DL, MVT::v8i32, V1, Mask, Subtarget, DAG))
16262       return Rotate;
16263 
16264   // If we have VLX support, we can use VALIGN or EXPAND.
16265   if (Subtarget.hasVLX()) {
16266     if (SDValue Rotate = lowerShuffleAsVALIGN(DL, MVT::v8i32, V1, V2, Mask,
16267                                               Zeroable, Subtarget, DAG))
16268       return Rotate;
16269 
16270     if (SDValue V = lowerShuffleToEXPAND(DL, MVT::v8i32, Zeroable, Mask, V1, V2,
16271                                          DAG, Subtarget))
16272       return V;
16273   }
16274 
16275   // Try to use byte rotation instructions.
16276   if (SDValue Rotate = lowerShuffleAsByteRotate(DL, MVT::v8i32, V1, V2, Mask,
16277                                                 Subtarget, DAG))
16278     return Rotate;
16279 
16280   // Try to create an in-lane repeating shuffle mask and then shuffle the
16281   // results into the target lanes.
16282   if (SDValue V = lowerShuffleAsRepeatedMaskAndLanePermute(
16283           DL, MVT::v8i32, V1, V2, Mask, Subtarget, DAG))
16284     return V;
16285 
16286   if (V2.isUndef()) {
16287     // Try to produce a fixed cross-128-bit lane permute followed by unpack
16288     // because that should be faster than the variable permute alternatives.
16289     if (SDValue V = lowerShuffleWithUNPCK256(DL, MVT::v8i32, Mask, V1, V2, DAG))
16290       return V;
16291 
16292     // If the shuffle patterns aren't repeated but it's a single input, directly
16293     // generate a cross-lane VPERMD instruction.
16294     SDValue VPermMask = getConstVector(Mask, MVT::v8i32, DAG, DL, true);
16295     return DAG.getNode(X86ISD::VPERMV, DL, MVT::v8i32, VPermMask, V1);
16296   }
16297 
16298   // Assume that a single SHUFPS is faster than an alternative sequence of
16299   // multiple instructions (even if the CPU has a domain penalty).
16300   // If some CPU is harmed by the domain switch, we can fix it in a later pass.
16301   if (Is128BitLaneRepeatedShuffle && isSingleSHUFPSMask(RepeatedMask)) {
16302     SDValue CastV1 = DAG.getBitcast(MVT::v8f32, V1);
16303     SDValue CastV2 = DAG.getBitcast(MVT::v8f32, V2);
16304     SDValue ShufPS = lowerShuffleWithSHUFPS(DL, MVT::v8f32, RepeatedMask,
16305                                             CastV1, CastV2, DAG);
16306     return DAG.getBitcast(MVT::v8i32, ShufPS);
16307   }
16308 
16309   // Try to simplify this by merging 128-bit lanes to enable a lane-based
16310   // shuffle.
16311   if (SDValue Result = lowerShuffleAsLanePermuteAndRepeatedMask(
16312           DL, MVT::v8i32, V1, V2, Mask, Subtarget, DAG))
16313     return Result;
16314 
16315   // Otherwise fall back on generic blend lowering.
16316   return lowerShuffleAsDecomposedShuffleMerge(DL, MVT::v8i32, V1, V2, Mask,
16317                                               Subtarget, DAG);
16318 }
16319 
16320 /// Handle lowering of 16-lane 16-bit integer shuffles.
16321 ///
16322 /// This routine is only called when we have AVX2 and thus a reasonable
16323 /// instruction set for v16i16 shuffling.
16324 static SDValue lowerV16I16Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
16325                                   const APInt &Zeroable, SDValue V1, SDValue V2,
16326                                   const X86Subtarget &Subtarget,
16327                                   SelectionDAG &DAG) {
16328   assert(V1.getSimpleValueType() == MVT::v16i16 && "Bad operand type!");
16329   assert(V2.getSimpleValueType() == MVT::v16i16 && "Bad operand type!");
16330   assert(Mask.size() == 16 && "Unexpected mask size for v16 shuffle!");
16331   assert(Subtarget.hasAVX2() && "We can only lower v16i16 with AVX2!");
16332 
16333   // Whenever we can lower this as a zext, that instruction is strictly faster
16334   // than any alternative. It also allows us to fold memory operands into the
16335   // shuffle in many cases.
16336   if (SDValue ZExt = lowerShuffleAsZeroOrAnyExtend(
16337           DL, MVT::v16i16, V1, V2, Mask, Zeroable, Subtarget, DAG))
16338     return ZExt;
16339 
16340   // Check for being able to broadcast a single element.
16341   if (SDValue Broadcast = lowerShuffleAsBroadcast(DL, MVT::v16i16, V1, V2, Mask,
16342                                                   Subtarget, DAG))
16343     return Broadcast;
16344 
16345   if (SDValue Blend = lowerShuffleAsBlend(DL, MVT::v16i16, V1, V2, Mask,
16346                                           Zeroable, Subtarget, DAG))
16347     return Blend;
16348 
16349   // Use dedicated unpack instructions for masks that match their pattern.
16350   if (SDValue V = lowerShuffleWithUNPCK(DL, MVT::v16i16, Mask, V1, V2, DAG))
16351     return V;
16352 
16353   // Use dedicated pack instructions for masks that match their pattern.
16354   if (SDValue V = lowerShuffleWithPACK(DL, MVT::v16i16, Mask, V1, V2, DAG,
16355                                        Subtarget))
16356     return V;
16357 
16358   // Try to lower using a truncation.
16359   if (SDValue V = lowerShuffleAsVTRUNC(DL, MVT::v16i16, V1, V2, Mask, Zeroable,
16360                                        Subtarget, DAG))
16361     return V;
16362 
16363   // Try to use shift instructions.
16364   if (SDValue Shift =
16365           lowerShuffleAsShift(DL, MVT::v16i16, V1, V2, Mask, Zeroable,
16366                               Subtarget, DAG, /*BitwiseOnly*/ false))
16367     return Shift;
16368 
16369   // Try to use byte rotation instructions.
16370   if (SDValue Rotate = lowerShuffleAsByteRotate(DL, MVT::v16i16, V1, V2, Mask,
16371                                                 Subtarget, DAG))
16372     return Rotate;
16373 
16374   // Try to create an in-lane repeating shuffle mask and then shuffle the
16375   // results into the target lanes.
16376   if (SDValue V = lowerShuffleAsRepeatedMaskAndLanePermute(
16377           DL, MVT::v16i16, V1, V2, Mask, Subtarget, DAG))
16378     return V;
16379 
16380   if (V2.isUndef()) {
16381     // Try to use bit rotation instructions.
16382     if (SDValue Rotate =
16383             lowerShuffleAsBitRotate(DL, MVT::v16i16, V1, Mask, Subtarget, DAG))
16384       return Rotate;
16385 
16386     // Try to produce a fixed cross-128-bit lane permute followed by unpack
16387     // because that should be faster than the variable permute alternatives.
16388     if (SDValue V = lowerShuffleWithUNPCK256(DL, MVT::v16i16, Mask, V1, V2, DAG))
16389       return V;
16390 
16391     // There are no generalized cross-lane shuffle operations available on i16
16392     // element types.
16393     if (is128BitLaneCrossingShuffleMask(MVT::v16i16, Mask)) {
16394       if (SDValue V = lowerShuffleAsLanePermuteAndPermute(
16395               DL, MVT::v16i16, V1, V2, Mask, DAG, Subtarget))
16396         return V;
16397 
16398       return lowerShuffleAsLanePermuteAndShuffle(DL, MVT::v16i16, V1, V2, Mask,
16399                                                  DAG, Subtarget);
16400     }
16401 
16402     SmallVector<int, 8> RepeatedMask;
16403     if (is128BitLaneRepeatedShuffleMask(MVT::v16i16, Mask, RepeatedMask)) {
16404       // As this is a single-input shuffle, the repeated mask should be
16405       // a strictly valid v8i16 mask that we can pass through to the v8i16
16406       // lowering to handle even the v16 case.
16407       return lowerV8I16GeneralSingleInputShuffle(
16408           DL, MVT::v16i16, V1, RepeatedMask, Subtarget, DAG);
16409     }
16410   }
16411 
16412   if (SDValue PSHUFB = lowerShuffleWithPSHUFB(DL, MVT::v16i16, Mask, V1, V2,
16413                                               Zeroable, Subtarget, DAG))
16414     return PSHUFB;
16415 
16416   // AVX512BW can lower to VPERMW (non-VLX will pad to v32i16).
16417   if (Subtarget.hasBWI())
16418     return lowerShuffleWithPERMV(DL, MVT::v16i16, Mask, V1, V2, Subtarget, DAG);
16419 
16420   // Try to simplify this by merging 128-bit lanes to enable a lane-based
16421   // shuffle.
16422   if (SDValue Result = lowerShuffleAsLanePermuteAndRepeatedMask(
16423           DL, MVT::v16i16, V1, V2, Mask, Subtarget, DAG))
16424     return Result;
16425 
16426   // Try to permute the lanes and then use a per-lane permute.
16427   if (SDValue V = lowerShuffleAsLanePermuteAndPermute(
16428           DL, MVT::v16i16, V1, V2, Mask, DAG, Subtarget))
16429     return V;
16430 
16431   // Try to match an interleave of two v16i16s and lower them as unpck and
16432   // permutes using ymms.
16433   if (!Subtarget.hasAVX512())
16434     if (SDValue V = lowerShufflePairAsUNPCKAndPermute(DL, MVT::v16i16, V1, V2,
16435                                                       Mask, DAG))
16436       return V;
16437 
16438   // Otherwise fall back on generic lowering.
16439   return lowerShuffleAsSplitOrBlend(DL, MVT::v16i16, V1, V2, Mask,
16440                                     Subtarget, DAG);
16441 }
16442 
16443 /// Handle lowering of 32-lane 8-bit integer shuffles.
16444 ///
16445 /// This routine is only called when we have AVX2 and thus a reasonable
16446 /// instruction set for v32i8 shuffling.
16447 static SDValue lowerV32I8Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
16448                                  const APInt &Zeroable, SDValue V1, SDValue V2,
16449                                  const X86Subtarget &Subtarget,
16450                                  SelectionDAG &DAG) {
16451   assert(V1.getSimpleValueType() == MVT::v32i8 && "Bad operand type!");
16452   assert(V2.getSimpleValueType() == MVT::v32i8 && "Bad operand type!");
16453   assert(Mask.size() == 32 && "Unexpected mask size for v32 shuffle!");
16454   assert(Subtarget.hasAVX2() && "We can only lower v32i8 with AVX2!");
16455 
16456   // Whenever we can lower this as a zext, that instruction is strictly faster
16457   // than any alternative. It also allows us to fold memory operands into the
16458   // shuffle in many cases.
16459   if (SDValue ZExt = lowerShuffleAsZeroOrAnyExtend(DL, MVT::v32i8, V1, V2, Mask,
16460                                                    Zeroable, Subtarget, DAG))
16461     return ZExt;
16462 
16463   // Check for being able to broadcast a single element.
16464   if (SDValue Broadcast = lowerShuffleAsBroadcast(DL, MVT::v32i8, V1, V2, Mask,
16465                                                   Subtarget, DAG))
16466     return Broadcast;
16467 
16468   if (SDValue Blend = lowerShuffleAsBlend(DL, MVT::v32i8, V1, V2, Mask,
16469                                           Zeroable, Subtarget, DAG))
16470     return Blend;
16471 
16472   // Use dedicated unpack instructions for masks that match their pattern.
16473   if (SDValue V = lowerShuffleWithUNPCK(DL, MVT::v32i8, Mask, V1, V2, DAG))
16474     return V;
16475 
16476   // Use dedicated pack instructions for masks that match their pattern.
16477   if (SDValue V = lowerShuffleWithPACK(DL, MVT::v32i8, Mask, V1, V2, DAG,
16478                                        Subtarget))
16479     return V;
16480 
16481   // Try to lower using a truncation.
16482   if (SDValue V = lowerShuffleAsVTRUNC(DL, MVT::v32i8, V1, V2, Mask, Zeroable,
16483                                        Subtarget, DAG))
16484     return V;
16485 
16486   // Try to use shift instructions.
16487   if (SDValue Shift =
16488           lowerShuffleAsShift(DL, MVT::v32i8, V1, V2, Mask, Zeroable, Subtarget,
16489                               DAG, /*BitwiseOnly*/ false))
16490     return Shift;
16491 
16492   // Try to use byte rotation instructions.
16493   if (SDValue Rotate = lowerShuffleAsByteRotate(DL, MVT::v32i8, V1, V2, Mask,
16494                                                 Subtarget, DAG))
16495     return Rotate;
16496 
16497   // Try to use bit rotation instructions.
16498   if (V2.isUndef())
16499     if (SDValue Rotate =
16500             lowerShuffleAsBitRotate(DL, MVT::v32i8, V1, Mask, Subtarget, DAG))
16501       return Rotate;
16502 
16503   // Try to create an in-lane repeating shuffle mask and then shuffle the
16504   // results into the target lanes.
16505   if (SDValue V = lowerShuffleAsRepeatedMaskAndLanePermute(
16506           DL, MVT::v32i8, V1, V2, Mask, Subtarget, DAG))
16507     return V;
16508 
16509   // There are no generalized cross-lane shuffle operations available on i8
16510   // element types.
16511   if (V2.isUndef() && is128BitLaneCrossingShuffleMask(MVT::v32i8, Mask)) {
16512     // Try to produce a fixed cross-128-bit lane permute followed by unpack
16513     // because that should be faster than the variable permute alternatives.
16514     if (SDValue V = lowerShuffleWithUNPCK256(DL, MVT::v32i8, Mask, V1, V2, DAG))
16515       return V;
16516 
16517     if (SDValue V = lowerShuffleAsLanePermuteAndPermute(
16518             DL, MVT::v32i8, V1, V2, Mask, DAG, Subtarget))
16519       return V;
16520 
16521     return lowerShuffleAsLanePermuteAndShuffle(DL, MVT::v32i8, V1, V2, Mask,
16522                                                DAG, Subtarget);
16523   }
16524 
16525   if (SDValue PSHUFB = lowerShuffleWithPSHUFB(DL, MVT::v32i8, Mask, V1, V2,
16526                                               Zeroable, Subtarget, DAG))
16527     return PSHUFB;
16528 
16529   // AVX512VBMI can lower to VPERMB (non-VLX will pad to v64i8).
16530   if (Subtarget.hasVBMI())
16531     return lowerShuffleWithPERMV(DL, MVT::v32i8, Mask, V1, V2, Subtarget, DAG);
16532 
16533   // Try to simplify this by merging 128-bit lanes to enable a lane-based
16534   // shuffle.
16535   if (SDValue Result = lowerShuffleAsLanePermuteAndRepeatedMask(
16536           DL, MVT::v32i8, V1, V2, Mask, Subtarget, DAG))
16537     return Result;
16538 
16539   // Try to permute the lanes and then use a per-lane permute.
16540   if (SDValue V = lowerShuffleAsLanePermuteAndPermute(
16541           DL, MVT::v32i8, V1, V2, Mask, DAG, Subtarget))
16542     return V;
16543 
16544   // Look for {0, 8, 16, 24, 32, 40, 48, 56} in the first 8 elements, followed
16545   // by zeroable elements in the remaining 24 elements. Turn this into two
16546   // vmovqb instructions shuffled together.
16547   if (Subtarget.hasVLX())
16548     if (SDValue V = lowerShuffleAsVTRUNCAndUnpack(DL, MVT::v32i8, V1, V2,
16549                                                   Mask, Zeroable, DAG))
16550       return V;
16551 
16552   // Try to match an interleave of two v32i8s and lower them as unpck and
16553   // permutes using ymms.
16554   if (!Subtarget.hasAVX512())
16555     if (SDValue V = lowerShufflePairAsUNPCKAndPermute(DL, MVT::v32i8, V1, V2,
16556                                                       Mask, DAG))
16557       return V;
16558 
16559   // Otherwise fall back on generic lowering.
16560   return lowerShuffleAsSplitOrBlend(DL, MVT::v32i8, V1, V2, Mask,
16561                                     Subtarget, DAG);
16562 }
16563 
16564 /// High-level routine to lower various 256-bit x86 vector shuffles.
16565 ///
16566 /// This routine either breaks down the specific type of a 256-bit x86 vector
16567 /// shuffle or splits it into two 128-bit shuffles and fuses the results back
16568 /// together based on the available instructions.
16569 static SDValue lower256BitShuffle(const SDLoc &DL, ArrayRef<int> Mask, MVT VT,
16570                                   SDValue V1, SDValue V2, const APInt &Zeroable,
16571                                   const X86Subtarget &Subtarget,
16572                                   SelectionDAG &DAG) {
16573   // If we have a single input to the zero element, insert that into V1 if we
16574   // can do so cheaply.
16575   int NumElts = VT.getVectorNumElements();
16576   int NumV2Elements = count_if(Mask, [NumElts](int M) { return M >= NumElts; });
16577 
16578   if (NumV2Elements == 1 && Mask[0] >= NumElts)
16579     if (SDValue Insertion = lowerShuffleAsElementInsertion(
16580             DL, VT, V1, V2, Mask, Zeroable, Subtarget, DAG))
16581       return Insertion;
16582 
16583   // Handle special cases where the lower or upper half is UNDEF.
16584   if (SDValue V =
16585           lowerShuffleWithUndefHalf(DL, VT, V1, V2, Mask, Subtarget, DAG))
16586     return V;
16587 
16588   // There is a really nice hard cut-over between AVX1 and AVX2 that means we
16589   // can check for those subtargets here and avoid much of the subtarget
16590   // querying in the per-vector-type lowering routines. With AVX1 we have
16591   // essentially *zero* ability to manipulate a 256-bit vector with integer
16592   // types. Since we'll use floating point types there eventually, just
16593   // immediately cast everything to a float and operate entirely in that domain.
16594   if (VT.isInteger() && !Subtarget.hasAVX2()) {
16595     int ElementBits = VT.getScalarSizeInBits();
16596     if (ElementBits < 32) {
16597       // No floating point type available; if we can't use the bit operations
16598       // for masking/blending, then decompose into 128-bit vectors.
16599       if (SDValue V = lowerShuffleAsBitMask(DL, VT, V1, V2, Mask, Zeroable,
16600                                             Subtarget, DAG))
16601         return V;
16602       if (SDValue V = lowerShuffleAsBitBlend(DL, VT, V1, V2, Mask, DAG))
16603         return V;
16604       return splitAndLowerShuffle(DL, VT, V1, V2, Mask, DAG, /*SimpleOnly*/ false);
16605     }
16606 
16607     MVT FpVT = MVT::getVectorVT(MVT::getFloatingPointVT(ElementBits),
16608                                 VT.getVectorNumElements());
16609     V1 = DAG.getBitcast(FpVT, V1);
16610     V2 = DAG.getBitcast(FpVT, V2);
16611     return DAG.getBitcast(VT, DAG.getVectorShuffle(FpVT, DL, V1, V2, Mask));
16612   }
16613 
16614   if (VT == MVT::v16f16 || VT == MVT::v16bf16) {
16615     V1 = DAG.getBitcast(MVT::v16i16, V1);
16616     V2 = DAG.getBitcast(MVT::v16i16, V2);
16617     return DAG.getBitcast(VT,
16618                           DAG.getVectorShuffle(MVT::v16i16, DL, V1, V2, Mask));
16619   }
16620 
16621   switch (VT.SimpleTy) {
16622   case MVT::v4f64:
16623     return lowerV4F64Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
16624   case MVT::v4i64:
16625     return lowerV4I64Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
16626   case MVT::v8f32:
16627     return lowerV8F32Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
16628   case MVT::v8i32:
16629     return lowerV8I32Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
16630   case MVT::v16i16:
16631     return lowerV16I16Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
16632   case MVT::v32i8:
16633     return lowerV32I8Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
16634 
16635   default:
16636     llvm_unreachable("Not a valid 256-bit x86 vector type!");
16637   }
16638 }
16639 
16640 /// Try to lower a vector shuffle as a 128-bit shuffles.
16641 static SDValue lowerV4X128Shuffle(const SDLoc &DL, MVT VT, ArrayRef<int> Mask,
16642                                   const APInt &Zeroable, SDValue V1, SDValue V2,
16643                                   const X86Subtarget &Subtarget,
16644                                   SelectionDAG &DAG) {
16645   assert(VT.getScalarSizeInBits() == 64 &&
16646          "Unexpected element type size for 128bit shuffle.");
16647 
16648   // Handling a 256-bit vector requires VLX, and lowerV2X128VectorShuffle() is
16649   // most probably the better solution for that case.
16650   assert(VT.is512BitVector() && "Unexpected vector size for 512bit shuffle.");
16651 
16652   // TODO - use Zeroable like we do for lowerV2X128VectorShuffle?
16653   SmallVector<int, 4> Widened128Mask;
16654   if (!canWidenShuffleElements(Mask, Widened128Mask))
16655     return SDValue();
16656   assert(Widened128Mask.size() == 4 && "Shuffle widening mismatch");
16657 
16658   // Try to use an insert into a zero vector.
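        // Zeroable has one bit per 64-bit element here, so (Zeroable & 0xf0) == 0xf0
        // means the upper 256 bits are zeroable and (Zeroable & 0x0c) == 0x0c means
        // elements 2 and 3 are too, i.e. the result is the low 128 or 256 bits of V1
        // padded with zeros.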
16659   if (Widened128Mask[0] == 0 && (Zeroable & 0xf0) == 0xf0 &&
16660       (Widened128Mask[1] == 1 || (Zeroable & 0x0c) == 0x0c)) {
16661     unsigned NumElts = ((Zeroable & 0x0c) == 0x0c) ? 2 : 4;
16662     MVT SubVT = MVT::getVectorVT(VT.getVectorElementType(), NumElts);
16663     SDValue LoV = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVT, V1,
16664                               DAG.getIntPtrConstant(0, DL));
16665     return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT,
16666                        getZeroVector(VT, Subtarget, DAG, DL), LoV,
16667                        DAG.getIntPtrConstant(0, DL));
16668   }
16669 
16670   // Check for patterns which can be matched with a single insert of a 256-bit
16671   // subvector.
16672   bool OnlyUsesV1 = isShuffleEquivalent(Mask, {0, 1, 2, 3, 0, 1, 2, 3}, V1, V2);
16673   if (OnlyUsesV1 ||
16674       isShuffleEquivalent(Mask, {0, 1, 2, 3, 8, 9, 10, 11}, V1, V2)) {
16675     MVT SubVT = MVT::getVectorVT(VT.getVectorElementType(), 4);
16676     SDValue SubVec =
16677         DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVT, OnlyUsesV1 ? V1 : V2,
16678                     DAG.getIntPtrConstant(0, DL));
16679     return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, V1, SubVec,
16680                        DAG.getIntPtrConstant(4, DL));
16681   }
16682 
16683   // See if this is an insertion of the lower 128-bits of V2 into V1.
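        // For example, Widened128Mask {0, 4, 2, 3} keeps V1 in place except for its
        // second 128-bit chunk, which is replaced by the low 128 bits of V2
        // (V2Index == 1).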
16684   bool IsInsert = true;
16685   int V2Index = -1;
16686   for (int i = 0; i < 4; ++i) {
16687     assert(Widened128Mask[i] >= -1 && "Illegal shuffle sentinel value");
16688     if (Widened128Mask[i] < 0)
16689       continue;
16690 
16691     // Make sure all V1 subvectors are in place.
16692     if (Widened128Mask[i] < 4) {
16693       if (Widened128Mask[i] != i) {
16694         IsInsert = false;
16695         break;
16696       }
16697     } else {
16698       // Make sure we only have a single V2 index and it's the lowest 128 bits.
16699       if (V2Index >= 0 || Widened128Mask[i] != 4) {
16700         IsInsert = false;
16701         break;
16702       }
16703       V2Index = i;
16704     }
16705   }
16706   if (IsInsert && V2Index >= 0) {
16707     MVT SubVT = MVT::getVectorVT(VT.getVectorElementType(), 2);
16708     SDValue Subvec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVT, V2,
16709                                  DAG.getIntPtrConstant(0, DL));
16710     return insert128BitVector(V1, Subvec, V2Index * 2, DAG, DL);
16711   }
16712 
16713   // See if we can widen to a 256-bit lane shuffle; we're going to lose the
16714   // 128-bit lane UNDEF info by lowering to X86ISD::SHUF128 anyway, so by
16715   // widening where possible we at least ensure the lanes stay sequential to
16716   // help later combines.
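        // For example, {2, -1, -1, -1} widens to {1, -1} and narrows back to
        // {2, 3, -1, -1}, keeping the second lane sequential.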
16717   SmallVector<int, 2> Widened256Mask;
16718   if (canWidenShuffleElements(Widened128Mask, Widened256Mask)) {
16719     Widened128Mask.clear();
16720     narrowShuffleMaskElts(2, Widened256Mask, Widened128Mask);
16721   }
16722 
16723   // Try to lower to vshuf64x2/vshuf32x4.
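        // For example, Widened128Mask {0, 1, 4, 5} picks V1 for the low 256 bits and
        // V2 for the high 256 bits, with PermMask {0, 1, 0, 1}.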
16724   SDValue Ops[2] = {DAG.getUNDEF(VT), DAG.getUNDEF(VT)};
16725   int PermMask[4] = {-1, -1, -1, -1};
16726   // Ensure elements came from the same Op.
16727   for (int i = 0; i < 4; ++i) {
16728     assert(Widened128Mask[i] >= -1 && "Illegal shuffle sentinel value");
16729     if (Widened128Mask[i] < 0)
16730       continue;
16731 
16732     SDValue Op = Widened128Mask[i] >= 4 ? V2 : V1;
16733     unsigned OpIndex = i / 2;
16734     if (Ops[OpIndex].isUndef())
16735       Ops[OpIndex] = Op;
16736     else if (Ops[OpIndex] != Op)
16737       return SDValue();
16738 
16739     PermMask[i] = Widened128Mask[i] % 4;
16740   }
16741 
16742   return DAG.getNode(X86ISD::SHUF128, DL, VT, Ops[0], Ops[1],
16743                      getV4X86ShuffleImm8ForMask(PermMask, DL, DAG));
16744 }
16745 
16746 /// Handle lowering of 8-lane 64-bit floating point shuffles.
16747 static SDValue lowerV8F64Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
16748                                  const APInt &Zeroable, SDValue V1, SDValue V2,
16749                                  const X86Subtarget &Subtarget,
16750                                  SelectionDAG &DAG) {
16751   assert(V1.getSimpleValueType() == MVT::v8f64 && "Bad operand type!");
16752   assert(V2.getSimpleValueType() == MVT::v8f64 && "Bad operand type!");
16753   assert(Mask.size() == 8 && "Unexpected mask size for v8 shuffle!");
16754 
16755   if (V2.isUndef()) {
16756     // Use low duplicate instructions for masks that match their pattern.
16757     if (isShuffleEquivalent(Mask, {0, 0, 2, 2, 4, 4, 6, 6}, V1, V2))
16758       return DAG.getNode(X86ISD::MOVDDUP, DL, MVT::v8f64, V1);
16759 
16760     if (!is128BitLaneCrossingShuffleMask(MVT::v8f64, Mask)) {
16761       // Non-half-crossing single input shuffles can be lowered with an
16762       // interleaved permutation.
16763       unsigned VPERMILPMask = (Mask[0] == 1) | ((Mask[1] == 1) << 1) |
16764                               ((Mask[2] == 3) << 2) | ((Mask[3] == 3) << 3) |
16765                               ((Mask[4] == 5) << 4) | ((Mask[5] == 5) << 5) |
16766                               ((Mask[6] == 7) << 6) | ((Mask[7] == 7) << 7);
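            // For example, the pairwise swap {1, 0, 3, 2, 5, 4, 7, 6} sets bits
            // 0, 2, 4 and 6, giving a VPERMILPD immediate of 0x55.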
16767       return DAG.getNode(X86ISD::VPERMILPI, DL, MVT::v8f64, V1,
16768                          DAG.getTargetConstant(VPERMILPMask, DL, MVT::i8));
16769     }
16770 
16771     SmallVector<int, 4> RepeatedMask;
16772     if (is256BitLaneRepeatedShuffleMask(MVT::v8f64, Mask, RepeatedMask))
16773       return DAG.getNode(X86ISD::VPERMI, DL, MVT::v8f64, V1,
16774                          getV4X86ShuffleImm8ForMask(RepeatedMask, DL, DAG));
16775   }
16776 
16777   if (SDValue Shuf128 = lowerV4X128Shuffle(DL, MVT::v8f64, Mask, Zeroable, V1,
16778                                            V2, Subtarget, DAG))
16779     return Shuf128;
16780 
16781   if (SDValue Unpck = lowerShuffleWithUNPCK(DL, MVT::v8f64, Mask, V1, V2, DAG))
16782     return Unpck;
16783 
16784   // Check if the blend happens to exactly fit that of SHUFPD.
16785   if (SDValue Op = lowerShuffleWithSHUFPD(DL, MVT::v8f64, V1, V2, Mask,
16786                                           Zeroable, Subtarget, DAG))
16787     return Op;
16788 
16789   if (SDValue V = lowerShuffleToEXPAND(DL, MVT::v8f64, Zeroable, Mask, V1, V2,
16790                                        DAG, Subtarget))
16791     return V;
16792 
16793   if (SDValue Blend = lowerShuffleAsBlend(DL, MVT::v8f64, V1, V2, Mask,
16794                                           Zeroable, Subtarget, DAG))
16795     return Blend;
16796 
16797   return lowerShuffleWithPERMV(DL, MVT::v8f64, Mask, V1, V2, Subtarget, DAG);
16798 }
16799 
16800 /// Handle lowering of 16-lane 32-bit floating point shuffles.
16801 static SDValue lowerV16F32Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
16802                                   const APInt &Zeroable, SDValue V1, SDValue V2,
16803                                   const X86Subtarget &Subtarget,
16804                                   SelectionDAG &DAG) {
16805   assert(V1.getSimpleValueType() == MVT::v16f32 && "Bad operand type!");
16806   assert(V2.getSimpleValueType() == MVT::v16f32 && "Bad operand type!");
16807   assert(Mask.size() == 16 && "Unexpected mask size for v16 shuffle!");
16808 
16809   // If the shuffle mask is repeated in each 128-bit lane, we have many more
16810   // options to efficiently lower the shuffle.
16811   SmallVector<int, 4> RepeatedMask;
16812   if (is128BitLaneRepeatedShuffleMask(MVT::v16f32, Mask, RepeatedMask)) {
16813     assert(RepeatedMask.size() == 4 && "Unexpected repeated mask size!");
16814 
16815     // Use even/odd duplicate instructions for masks that match their pattern.
16816     if (isShuffleEquivalent(RepeatedMask, {0, 0, 2, 2}, V1, V2))
16817       return DAG.getNode(X86ISD::MOVSLDUP, DL, MVT::v16f32, V1);
16818     if (isShuffleEquivalent(RepeatedMask, {1, 1, 3, 3}, V1, V2))
16819       return DAG.getNode(X86ISD::MOVSHDUP, DL, MVT::v16f32, V1);
16820 
16821     if (V2.isUndef())
16822       return DAG.getNode(X86ISD::VPERMILPI, DL, MVT::v16f32, V1,
16823                          getV4X86ShuffleImm8ForMask(RepeatedMask, DL, DAG));
16824 
16825     // Use dedicated unpack instructions for masks that match their pattern.
16826     if (SDValue V = lowerShuffleWithUNPCK(DL, MVT::v16f32, Mask, V1, V2, DAG))
16827       return V;
16828 
16829     if (SDValue Blend = lowerShuffleAsBlend(DL, MVT::v16f32, V1, V2, Mask,
16830                                             Zeroable, Subtarget, DAG))
16831       return Blend;
16832 
16833     // Otherwise, fall back to a SHUFPS sequence.
16834     return lowerShuffleWithSHUFPS(DL, MVT::v16f32, RepeatedMask, V1, V2, DAG);
16835   }
16836 
16837   if (SDValue Blend = lowerShuffleAsBlend(DL, MVT::v16f32, V1, V2, Mask,
16838                                           Zeroable, Subtarget, DAG))
16839     return Blend;
16840 
16841   if (SDValue ZExt = lowerShuffleAsZeroOrAnyExtend(
16842           DL, MVT::v16i32, V1, V2, Mask, Zeroable, Subtarget, DAG))
16843     return DAG.getBitcast(MVT::v16f32, ZExt);
16844 
16845   // Try to create an in-lane repeating shuffle mask and then shuffle the
16846   // results into the target lanes.
16847   if (SDValue V = lowerShuffleAsRepeatedMaskAndLanePermute(
16848           DL, MVT::v16f32, V1, V2, Mask, Subtarget, DAG))
16849     return V;
16850 
16851   // If we have a single input shuffle with different shuffle patterns in the
16852   // 128-bit lanes and don't lane cross, use variable mask VPERMILPS.
16853   if (V2.isUndef() &&
16854       !is128BitLaneCrossingShuffleMask(MVT::v16f32, Mask)) {
16855     SDValue VPermMask = getConstVector(Mask, MVT::v16i32, DAG, DL, true);
16856     return DAG.getNode(X86ISD::VPERMILPV, DL, MVT::v16f32, V1, VPermMask);
16857   }
16858 
16859   // If we have AVX512F support, we can use VEXPAND.
16860   if (SDValue V = lowerShuffleToEXPAND(DL, MVT::v16f32, Zeroable, Mask,
16861                                              V1, V2, DAG, Subtarget))
16862     return V;
16863 
16864   return lowerShuffleWithPERMV(DL, MVT::v16f32, Mask, V1, V2, Subtarget, DAG);
16865 }
16866 
16867 /// Handle lowering of 8-lane 64-bit integer shuffles.
16868 static SDValue lowerV8I64Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
16869                                  const APInt &Zeroable, SDValue V1, SDValue V2,
16870                                  const X86Subtarget &Subtarget,
16871                                  SelectionDAG &DAG) {
16872   assert(V1.getSimpleValueType() == MVT::v8i64 && "Bad operand type!");
16873   assert(V2.getSimpleValueType() == MVT::v8i64 && "Bad operand type!");
16874   assert(Mask.size() == 8 && "Unexpected mask size for v8 shuffle!");
16875 
16876   // Try to use shift instructions if fast.
16877   if (Subtarget.preferLowerShuffleAsShift())
16878     if (SDValue Shift =
16879             lowerShuffleAsShift(DL, MVT::v8i64, V1, V2, Mask, Zeroable,
16880                                 Subtarget, DAG, /*BitwiseOnly*/ true))
16881       return Shift;
16882 
16883   if (V2.isUndef()) {
16884     // When the shuffle is mirrored between the 128-bit lanes of the unit, we
16885     // can use lower latency instructions that will operate on all four
16886     // 128-bit lanes.
16887     SmallVector<int, 2> Repeated128Mask;
16888     if (is128BitLaneRepeatedShuffleMask(MVT::v8i64, Mask, Repeated128Mask)) {
16889       SmallVector<int, 4> PSHUFDMask;
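            // Scale the repeated v2i64 mask to a v4i32 PSHUFD mask, e.g. {1, 0}
            // becomes {2, 3, 0, 1}.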
16890       narrowShuffleMaskElts(2, Repeated128Mask, PSHUFDMask);
16891       return DAG.getBitcast(
16892           MVT::v8i64,
16893           DAG.getNode(X86ISD::PSHUFD, DL, MVT::v16i32,
16894                       DAG.getBitcast(MVT::v16i32, V1),
16895                       getV4X86ShuffleImm8ForMask(PSHUFDMask, DL, DAG)));
16896     }
16897 
16898     SmallVector<int, 4> Repeated256Mask;
16899     if (is256BitLaneRepeatedShuffleMask(MVT::v8i64, Mask, Repeated256Mask))
16900       return DAG.getNode(X86ISD::VPERMI, DL, MVT::v8i64, V1,
16901                          getV4X86ShuffleImm8ForMask(Repeated256Mask, DL, DAG));
16902   }
16903 
16904   if (SDValue Shuf128 = lowerV4X128Shuffle(DL, MVT::v8i64, Mask, Zeroable, V1,
16905                                            V2, Subtarget, DAG))
16906     return Shuf128;
16907 
16908   // Try to use shift instructions.
16909   if (SDValue Shift =
16910           lowerShuffleAsShift(DL, MVT::v8i64, V1, V2, Mask, Zeroable, Subtarget,
16911                               DAG, /*BitwiseOnly*/ false))
16912     return Shift;
16913 
16914   // Try to use VALIGN.
16915   if (SDValue Rotate = lowerShuffleAsVALIGN(DL, MVT::v8i64, V1, V2, Mask,
16916                                             Zeroable, Subtarget, DAG))
16917     return Rotate;
16918 
16919   // Try to use PALIGNR.
16920   if (Subtarget.hasBWI())
16921     if (SDValue Rotate = lowerShuffleAsByteRotate(DL, MVT::v8i64, V1, V2, Mask,
16922                                                   Subtarget, DAG))
16923       return Rotate;
16924 
16925   if (SDValue Unpck = lowerShuffleWithUNPCK(DL, MVT::v8i64, Mask, V1, V2, DAG))
16926     return Unpck;
16927 
16928   // If we have AVX512F support, we can use VEXPAND.
16929   if (SDValue V = lowerShuffleToEXPAND(DL, MVT::v8i64, Zeroable, Mask, V1, V2,
16930                                        DAG, Subtarget))
16931     return V;
16932 
16933   if (SDValue Blend = lowerShuffleAsBlend(DL, MVT::v8i64, V1, V2, Mask,
16934                                           Zeroable, Subtarget, DAG))
16935     return Blend;
16936 
16937   return lowerShuffleWithPERMV(DL, MVT::v8i64, Mask, V1, V2, Subtarget, DAG);
16938 }
16939 
16940 /// Handle lowering of 16-lane 32-bit integer shuffles.
16941 static SDValue lowerV16I32Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
16942                                   const APInt &Zeroable, SDValue V1, SDValue V2,
16943                                   const X86Subtarget &Subtarget,
16944                                   SelectionDAG &DAG) {
16945   assert(V1.getSimpleValueType() == MVT::v16i32 && "Bad operand type!");
16946   assert(V2.getSimpleValueType() == MVT::v16i32 && "Bad operand type!");
16947   assert(Mask.size() == 16 && "Unexpected mask size for v16 shuffle!");
16948 
16949   int NumV2Elements = count_if(Mask, [](int M) { return M >= 16; });
16950 
16951   // Whenever we can lower this as a zext, that instruction is strictly faster
16952   // than any alternative. It also allows us to fold memory operands into the
16953   // shuffle in many cases.
16954   if (SDValue ZExt = lowerShuffleAsZeroOrAnyExtend(
16955           DL, MVT::v16i32, V1, V2, Mask, Zeroable, Subtarget, DAG))
16956     return ZExt;
16957 
16958   // Try to use shift instructions if fast.
16959   if (Subtarget.preferLowerShuffleAsShift()) {
16960     if (SDValue Shift =
16961             lowerShuffleAsShift(DL, MVT::v16i32, V1, V2, Mask, Zeroable,
16962                                 Subtarget, DAG, /*BitwiseOnly*/ true))
16963       return Shift;
16964     if (NumV2Elements == 0)
16965       if (SDValue Rotate = lowerShuffleAsBitRotate(DL, MVT::v16i32, V1, Mask,
16966                                                    Subtarget, DAG))
16967         return Rotate;
16968   }
16969 
16970   // If the shuffle mask is repeated in each 128-bit lane we can use more
16971   // efficient instructions that mirror the shuffles across the four 128-bit
16972   // lanes.
16973   SmallVector<int, 4> RepeatedMask;
16974   bool Is128BitLaneRepeatedShuffle =
16975       is128BitLaneRepeatedShuffleMask(MVT::v16i32, Mask, RepeatedMask);
16976   if (Is128BitLaneRepeatedShuffle) {
16977     assert(RepeatedMask.size() == 4 && "Unexpected repeated mask size!");
16978     if (V2.isUndef())
16979       return DAG.getNode(X86ISD::PSHUFD, DL, MVT::v16i32, V1,
16980                          getV4X86ShuffleImm8ForMask(RepeatedMask, DL, DAG));
16981 
16982     // Use dedicated unpack instructions for masks that match their pattern.
16983     if (SDValue V = lowerShuffleWithUNPCK(DL, MVT::v16i32, Mask, V1, V2, DAG))
16984       return V;
16985   }
16986 
16987   // Try to use shift instructions.
16988   if (SDValue Shift =
16989           lowerShuffleAsShift(DL, MVT::v16i32, V1, V2, Mask, Zeroable,
16990                               Subtarget, DAG, /*BitwiseOnly*/ false))
16991     return Shift;
16992 
16993   if (!Subtarget.preferLowerShuffleAsShift() && NumV2Elements != 0)
16994     if (SDValue Rotate =
16995             lowerShuffleAsBitRotate(DL, MVT::v16i32, V1, Mask, Subtarget, DAG))
16996       return Rotate;
16997 
16998   // Try to use VALIGN.
16999   if (SDValue Rotate = lowerShuffleAsVALIGN(DL, MVT::v16i32, V1, V2, Mask,
17000                                             Zeroable, Subtarget, DAG))
17001     return Rotate;
17002 
17003   // Try to use byte rotation instructions.
17004   if (Subtarget.hasBWI())
17005     if (SDValue Rotate = lowerShuffleAsByteRotate(DL, MVT::v16i32, V1, V2, Mask,
17006                                                   Subtarget, DAG))
17007       return Rotate;
17008 
17009   // Assume that a single SHUFPS is faster than using a permv shuffle.
17010   // If some CPU is harmed by the domain switch, we can fix it in a later pass.
17011   if (Is128BitLaneRepeatedShuffle && isSingleSHUFPSMask(RepeatedMask)) {
17012     SDValue CastV1 = DAG.getBitcast(MVT::v16f32, V1);
17013     SDValue CastV2 = DAG.getBitcast(MVT::v16f32, V2);
17014     SDValue ShufPS = lowerShuffleWithSHUFPS(DL, MVT::v16f32, RepeatedMask,
17015                                             CastV1, CastV2, DAG);
17016     return DAG.getBitcast(MVT::v16i32, ShufPS);
17017   }
17018 
17019   // Try to create an in-lane repeating shuffle mask and then shuffle the
17020   // results into the target lanes.
17021   if (SDValue V = lowerShuffleAsRepeatedMaskAndLanePermute(
17022           DL, MVT::v16i32, V1, V2, Mask, Subtarget, DAG))
17023     return V;
17024 
17025   // If we have AVX512F support, we can use VEXPAND.
17026   if (SDValue V = lowerShuffleToEXPAND(DL, MVT::v16i32, Zeroable, Mask, V1, V2,
17027                                        DAG, Subtarget))
17028     return V;
17029 
17030   if (SDValue Blend = lowerShuffleAsBlend(DL, MVT::v16i32, V1, V2, Mask,
17031                                           Zeroable, Subtarget, DAG))
17032     return Blend;
17033 
17034   return lowerShuffleWithPERMV(DL, MVT::v16i32, Mask, V1, V2, Subtarget, DAG);
17035 }
17036 
17037 /// Handle lowering of 32-lane 16-bit integer shuffles.
17038 static SDValue lowerV32I16Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
17039                                   const APInt &Zeroable, SDValue V1, SDValue V2,
17040                                   const X86Subtarget &Subtarget,
17041                                   SelectionDAG &DAG) {
17042   assert(V1.getSimpleValueType() == MVT::v32i16 && "Bad operand type!");
17043   assert(V2.getSimpleValueType() == MVT::v32i16 && "Bad operand type!");
17044   assert(Mask.size() == 32 && "Unexpected mask size for v32 shuffle!");
17045   assert(Subtarget.hasBWI() && "We can only lower v32i16 with AVX-512-BWI!");
17046 
17047   // Whenever we can lower this as a zext, that instruction is strictly faster
17048   // than any alternative. It also allows us to fold memory operands into the
17049   // shuffle in many cases.
17050   if (SDValue ZExt = lowerShuffleAsZeroOrAnyExtend(
17051           DL, MVT::v32i16, V1, V2, Mask, Zeroable, Subtarget, DAG))
17052     return ZExt;
17053 
17054   // Use dedicated unpack instructions for masks that match their pattern.
17055   if (SDValue V = lowerShuffleWithUNPCK(DL, MVT::v32i16, Mask, V1, V2, DAG))
17056     return V;
17057 
17058   // Use dedicated pack instructions for masks that match their pattern.
17059   if (SDValue V =
17060           lowerShuffleWithPACK(DL, MVT::v32i16, Mask, V1, V2, DAG, Subtarget))
17061     return V;
17062 
17063   // Try to use shift instructions.
17064   if (SDValue Shift =
17065           lowerShuffleAsShift(DL, MVT::v32i16, V1, V2, Mask, Zeroable,
17066                               Subtarget, DAG, /*BitwiseOnly*/ false))
17067     return Shift;
17068 
17069   // Try to use byte rotation instructions.
17070   if (SDValue Rotate = lowerShuffleAsByteRotate(DL, MVT::v32i16, V1, V2, Mask,
17071                                                 Subtarget, DAG))
17072     return Rotate;
17073 
17074   if (V2.isUndef()) {
17075     // Try to use bit rotation instructions.
17076     if (SDValue Rotate =
17077             lowerShuffleAsBitRotate(DL, MVT::v32i16, V1, Mask, Subtarget, DAG))
17078       return Rotate;
17079 
17080     SmallVector<int, 8> RepeatedMask;
17081     if (is128BitLaneRepeatedShuffleMask(MVT::v32i16, Mask, RepeatedMask)) {
17082       // As this is a single-input shuffle, the repeated mask should be
17083       // a strictly valid v8i16 mask that we can pass through to the v8i16
17084       // lowering to handle even the v32 case.
17085       return lowerV8I16GeneralSingleInputShuffle(DL, MVT::v32i16, V1,
17086                                                  RepeatedMask, Subtarget, DAG);
17087     }
17088   }
17089 
17090   if (SDValue Blend = lowerShuffleAsBlend(DL, MVT::v32i16, V1, V2, Mask,
17091                                           Zeroable, Subtarget, DAG))
17092     return Blend;
17093 
17094   if (SDValue PSHUFB = lowerShuffleWithPSHUFB(DL, MVT::v32i16, Mask, V1, V2,
17095                                               Zeroable, Subtarget, DAG))
17096     return PSHUFB;
17097 
17098   return lowerShuffleWithPERMV(DL, MVT::v32i16, Mask, V1, V2, Subtarget, DAG);
17099 }
17100 
17101 /// Handle lowering of 64-lane 8-bit integer shuffles.
17102 static SDValue lowerV64I8Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
17103                                  const APInt &Zeroable, SDValue V1, SDValue V2,
17104                                  const X86Subtarget &Subtarget,
17105                                  SelectionDAG &DAG) {
17106   assert(V1.getSimpleValueType() == MVT::v64i8 && "Bad operand type!");
17107   assert(V2.getSimpleValueType() == MVT::v64i8 && "Bad operand type!");
17108   assert(Mask.size() == 64 && "Unexpected mask size for v64 shuffle!");
17109   assert(Subtarget.hasBWI() && "We can only lower v64i8 with AVX-512-BWI!");
17110 
17111   // Whenever we can lower this as a zext, that instruction is strictly faster
17112   // than any alternative. It also allows us to fold memory operands into the
17113   // shuffle in many cases.
17114   if (SDValue ZExt = lowerShuffleAsZeroOrAnyExtend(
17115           DL, MVT::v64i8, V1, V2, Mask, Zeroable, Subtarget, DAG))
17116     return ZExt;
17117 
17118   // Use dedicated unpack instructions for masks that match their pattern.
17119   if (SDValue V = lowerShuffleWithUNPCK(DL, MVT::v64i8, Mask, V1, V2, DAG))
17120     return V;
17121 
17122   // Use dedicated pack instructions for masks that match their pattern.
17123   if (SDValue V = lowerShuffleWithPACK(DL, MVT::v64i8, Mask, V1, V2, DAG,
17124                                        Subtarget))
17125     return V;
17126 
17127   // Try to use shift instructions.
17128   if (SDValue Shift =
17129           lowerShuffleAsShift(DL, MVT::v64i8, V1, V2, Mask, Zeroable, Subtarget,
17130                               DAG, /*BitwiseOnly*/ false))
17131     return Shift;
17132 
17133   // Try to use byte rotation instructions.
17134   if (SDValue Rotate = lowerShuffleAsByteRotate(DL, MVT::v64i8, V1, V2, Mask,
17135                                                 Subtarget, DAG))
17136     return Rotate;
17137 
17138   // Try to use bit rotation instructions.
17139   if (V2.isUndef())
17140     if (SDValue Rotate =
17141             lowerShuffleAsBitRotate(DL, MVT::v64i8, V1, Mask, Subtarget, DAG))
17142       return Rotate;
17143 
17144   // Lower as AND if possible.
17145   if (SDValue Masked = lowerShuffleAsBitMask(DL, MVT::v64i8, V1, V2, Mask,
17146                                              Zeroable, Subtarget, DAG))
17147     return Masked;
17148 
17149   if (SDValue PSHUFB = lowerShuffleWithPSHUFB(DL, MVT::v64i8, Mask, V1, V2,
17150                                               Zeroable, Subtarget, DAG))
17151     return PSHUFB;
17152 
17153   // Try to create an in-lane repeating shuffle mask and then shuffle the
17154   // results into the target lanes.
17155   if (SDValue V = lowerShuffleAsRepeatedMaskAndLanePermute(
17156           DL, MVT::v64i8, V1, V2, Mask, Subtarget, DAG))
17157     return V;
17158 
17159   if (SDValue Result = lowerShuffleAsLanePermuteAndPermute(
17160           DL, MVT::v64i8, V1, V2, Mask, DAG, Subtarget))
17161     return Result;
17162 
17163   if (SDValue Blend = lowerShuffleAsBlend(DL, MVT::v64i8, V1, V2, Mask,
17164                                           Zeroable, Subtarget, DAG))
17165     return Blend;
17166 
17167   if (!is128BitLaneCrossingShuffleMask(MVT::v64i8, Mask)) {
17168     // Use PALIGNR+Permute if possible - permute might become PSHUFB but the
17169     // PALIGNR will be cheaper than the second PSHUFB+OR.
17170     if (SDValue V = lowerShuffleAsByteRotateAndPermute(DL, MVT::v64i8, V1, V2,
17171                                                        Mask, Subtarget, DAG))
17172       return V;
17173 
17174     // If we can't directly blend but can use PSHUFB, that will be better as it
17175     // can both shuffle and set up the inefficient blend.
17176     bool V1InUse, V2InUse;
17177     return lowerShuffleAsBlendOfPSHUFBs(DL, MVT::v64i8, V1, V2, Mask, Zeroable,
17178                                         DAG, V1InUse, V2InUse);
17179   }
17180 
17181   // Try to simplify this by merging 128-bit lanes to enable a lane-based
17182   // shuffle.
17183   if (!V2.isUndef())
17184     if (SDValue Result = lowerShuffleAsLanePermuteAndRepeatedMask(
17185             DL, MVT::v64i8, V1, V2, Mask, Subtarget, DAG))
17186       return Result;
17187 
17188   // VBMI can use VPERMV/VPERMV3 byte shuffles.
17189   if (Subtarget.hasVBMI())
17190     return lowerShuffleWithPERMV(DL, MVT::v64i8, Mask, V1, V2, Subtarget, DAG);
17191 
17192   return splitAndLowerShuffle(DL, MVT::v64i8, V1, V2, Mask, DAG, /*SimpleOnly*/ false);
17193 }
17194 
17195 /// High-level routine to lower various 512-bit x86 vector shuffles.
17196 ///
17197 /// This routine either breaks down the specific type of a 512-bit x86 vector
17198 /// shuffle or splits it into two 256-bit shuffles and fuses the results back
17199 /// together based on the available instructions.
17200 static SDValue lower512BitShuffle(const SDLoc &DL, ArrayRef<int> Mask,
17201                                   MVT VT, SDValue V1, SDValue V2,
17202                                   const APInt &Zeroable,
17203                                   const X86Subtarget &Subtarget,
17204                                   SelectionDAG &DAG) {
17205   assert(Subtarget.hasAVX512() &&
17206          "Cannot lower 512-bit vectors w/ basic ISA!");
17207 
17208   // If we have a single input to the zero element, insert that into V1 if we
17209   // can do so cheaply.
17210   int NumElts = Mask.size();
17211   int NumV2Elements = count_if(Mask, [NumElts](int M) { return M >= NumElts; });
17212 
17213   if (NumV2Elements == 1 && Mask[0] >= NumElts)
17214     if (SDValue Insertion = lowerShuffleAsElementInsertion(
17215             DL, VT, V1, V2, Mask, Zeroable, Subtarget, DAG))
17216       return Insertion;
17217 
17218   // Handle special cases where the lower or upper half is UNDEF.
17219   if (SDValue V =
17220           lowerShuffleWithUndefHalf(DL, VT, V1, V2, Mask, Subtarget, DAG))
17221     return V;
17222 
17223   // Check for being able to broadcast a single element.
17224   if (SDValue Broadcast = lowerShuffleAsBroadcast(DL, VT, V1, V2, Mask,
17225                                                   Subtarget, DAG))
17226     return Broadcast;
17227 
17228   if ((VT == MVT::v32i16 || VT == MVT::v64i8) && !Subtarget.hasBWI()) {
17229     // Try using bit ops for masking and blending before falling back to
17230     // splitting.
17231     if (SDValue V = lowerShuffleAsBitMask(DL, VT, V1, V2, Mask, Zeroable,
17232                                           Subtarget, DAG))
17233       return V;
17234     if (SDValue V = lowerShuffleAsBitBlend(DL, VT, V1, V2, Mask, DAG))
17235       return V;
17236 
17237     return splitAndLowerShuffle(DL, VT, V1, V2, Mask, DAG, /*SimpleOnly*/ false);
17238   }
17239 
17240   if (VT == MVT::v32f16 || VT == MVT::v32bf16) {
17241     if (!Subtarget.hasBWI())
17242       return splitAndLowerShuffle(DL, VT, V1, V2, Mask, DAG,
17243                                   /*SimpleOnly*/ false);
17244 
17245     V1 = DAG.getBitcast(MVT::v32i16, V1);
17246     V2 = DAG.getBitcast(MVT::v32i16, V2);
17247     return DAG.getBitcast(VT,
17248                           DAG.getVectorShuffle(MVT::v32i16, DL, V1, V2, Mask));
17249   }
17250 
17251   // Dispatch to each element type for lowering. If we don't have support for
17252   // specific element type shuffles at 512 bits, immediately split them and
17253   // lower them. Each lowering routine of a given type is allowed to assume that
17254   // the requisite ISA extensions for that element type are available.
17255   switch (VT.SimpleTy) {
17256   case MVT::v8f64:
17257     return lowerV8F64Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
17258   case MVT::v16f32:
17259     return lowerV16F32Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
17260   case MVT::v8i64:
17261     return lowerV8I64Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
17262   case MVT::v16i32:
17263     return lowerV16I32Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
17264   case MVT::v32i16:
17265     return lowerV32I16Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
17266   case MVT::v64i8:
17267     return lowerV64I8Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
17268 
17269   default:
17270     llvm_unreachable("Not a valid 512-bit x86 vector type!");
17271   }
17272 }
17273 
17274 static SDValue lower1BitShuffleAsKSHIFTR(const SDLoc &DL, ArrayRef<int> Mask,
17275                                          MVT VT, SDValue V1, SDValue V2,
17276                                          const X86Subtarget &Subtarget,
17277                                          SelectionDAG &DAG) {
17278   // Shuffle should be unary.
17279   if (!V2.isUndef())
17280     return SDValue();
17281 
17282   int ShiftAmt = -1;
17283   int NumElts = Mask.size();
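        // For example, Mask {3, 4, 5, 6, 7, -1, -1, -1} reads each element from
        // three slots to its right, so ShiftAmt ends up as 3.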
17284   for (int i = 0; i != NumElts; ++i) {
17285     int M = Mask[i];
17286     assert((M == SM_SentinelUndef || (0 <= M && M < NumElts)) &&
17287            "Unexpected mask index.");
17288     if (M < 0)
17289       continue;
17290 
17291     // The first non-undef element determines our shift amount.
17292     if (ShiftAmt < 0) {
17293       ShiftAmt = M - i;
17294       // Need to be shifting right.
17295       if (ShiftAmt <= 0)
17296         return SDValue();
17297     }
17298     // All non-undef elements must shift by the same amount.
17299     if (ShiftAmt != M - i)
17300       return SDValue();
17301   }
17302   assert(ShiftAmt >= 0 && "All undef?");
17303 
17304   // Great, we found a shift right.
17305   SDValue Res = widenMaskVector(V1, false, Subtarget, DAG, DL);
17306   Res = DAG.getNode(X86ISD::KSHIFTR, DL, Res.getValueType(), Res,
17307                     DAG.getTargetConstant(ShiftAmt, DL, MVT::i8));
17308   return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Res,
17309                      DAG.getIntPtrConstant(0, DL));
17310 }
17311 
17312 // Determine if this shuffle can be implemented with a KSHIFT instruction.
17313 // Returns the shift amount if possible or -1 if not. This is a simplified
17314 // version of matchShuffleAsShift.
17315 static int match1BitShuffleAsKSHIFT(unsigned &Opcode, ArrayRef<int> Mask,
17316                                     int MaskOffset, const APInt &Zeroable) {
17317   int Size = Mask.size();
17318 
17319   auto CheckZeros = [&](int Shift, bool Left) {
17320     for (int j = 0; j < Shift; ++j)
17321       if (!Zeroable[j + (Left ? 0 : (Size - Shift))])
17322         return false;
17323 
17324     return true;
17325   };
17326 
17327   auto MatchShift = [&](int Shift, bool Left) {
17328     unsigned Pos = Left ? Shift : 0;
17329     unsigned Low = Left ? 0 : Shift;
17330     unsigned Len = Size - Shift;
17331     return isSequentialOrUndefInRange(Mask, Pos, Len, Low + MaskOffset);
17332   };
17333 
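        // For example, with Size == 8 and MaskOffset == 0, a mask of
        // {-1, -1, 0, 1, 2, 3, 4, 5} whose two low elements are zeroable matches a
        // KSHIFTL by 2.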
17334   for (int Shift = 1; Shift != Size; ++Shift)
17335     for (bool Left : {true, false})
17336       if (CheckZeros(Shift, Left) && MatchShift(Shift, Left)) {
17337         Opcode = Left ? X86ISD::KSHIFTL : X86ISD::KSHIFTR;
17338         return Shift;
17339       }
17340 
17341   return -1;
17342 }
17343 
17344 
17345 // Lower vXi1 vector shuffles.
17346 // There is no dedicated instruction on AVX-512 that shuffles the masks.
17347 // The only way to shuffle bits is to sign-extend the mask vector to a SIMD
17348 // vector, shuffle it and then truncate it back.
17349 static SDValue lower1BitShuffle(const SDLoc &DL, ArrayRef<int> Mask,
17350                                 MVT VT, SDValue V1, SDValue V2,
17351                                 const APInt &Zeroable,
17352                                 const X86Subtarget &Subtarget,
17353                                 SelectionDAG &DAG) {
17354   assert(Subtarget.hasAVX512() &&
17355          "Cannot lower 512-bit vectors w/o basic ISA!");
17356 
17357   int NumElts = Mask.size();
17358   int NumV2Elements = count_if(Mask, [NumElts](int M) { return M >= NumElts; });
17359 
17360   // Try to recognize shuffles that are just padding a subvector with zeros.
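        // For example, a v8i1 mask that takes {0, 1, 2, 3} from one source and only
        // zeroable lanes above that becomes an extract of the low v4i1 inserted into
        // a zero vector.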
17361   int SubvecElts = 0;
17362   int Src = -1;
17363   for (int i = 0; i != NumElts; ++i) {
17364     if (Mask[i] >= 0) {
17365       // Grab the source from the first valid mask. All subsequent elements need
17366       // to use this same source.
17367       if (Src < 0)
17368         Src = Mask[i] / NumElts;
17369       if (Src != (Mask[i] / NumElts) || (Mask[i] % NumElts) != i)
17370         break;
17371     }
17372 
17373     ++SubvecElts;
17374   }
17375   assert(SubvecElts != NumElts && "Identity shuffle?");
17376 
17377   // Clip to a power of 2.
17378   SubvecElts = llvm::bit_floor<uint32_t>(SubvecElts);
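        // e.g. SubvecElts == 6 is clipped down to 4.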
17379 
17380   // Make sure the number of zeroable bits in the top at least covers the bits
17381   // not covered by the subvector.
17382   if ((int)Zeroable.countl_one() >= (NumElts - SubvecElts)) {
17383     assert(Src >= 0 && "Expected a source!");
17384     MVT ExtractVT = MVT::getVectorVT(MVT::i1, SubvecElts);
17385     SDValue Extract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ExtractVT,
17386                                   Src == 0 ? V1 : V2,
17387                                   DAG.getIntPtrConstant(0, DL));
17388     return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT,
17389                        DAG.getConstant(0, DL, VT),
17390                        Extract, DAG.getIntPtrConstant(0, DL));
17391   }
17392 
17393   // Try a simple shift right with undef elements. Later we'll try with zeros.
17394   if (SDValue Shift = lower1BitShuffleAsKSHIFTR(DL, Mask, VT, V1, V2, Subtarget,
17395                                                 DAG))
17396     return Shift;
17397 
17398   // Try to match KSHIFTs.
17399   unsigned Offset = 0;
17400   for (SDValue V : { V1, V2 }) {
17401     unsigned Opcode;
17402     int ShiftAmt = match1BitShuffleAsKSHIFT(Opcode, Mask, Offset, Zeroable);
17403     if (ShiftAmt >= 0) {
17404       SDValue Res = widenMaskVector(V, false, Subtarget, DAG, DL);
17405       MVT WideVT = Res.getSimpleValueType();
17406       // Widened right shifts need two shifts to ensure we shift in zeroes.
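            // For example, if a v8i1 shuffle was widened to v16i1, a right shift by 2
            // becomes KSHIFTL by 8 followed by KSHIFTR by 10, so zeroes (rather than
            // the widened garbage bits) are shifted into the top of the original v8i1.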
17407       if (Opcode == X86ISD::KSHIFTR && WideVT != VT) {
17408         int WideElts = WideVT.getVectorNumElements();
17409         // Shift left to put the original vector in the MSBs of the new size.
17410         Res = DAG.getNode(X86ISD::KSHIFTL, DL, WideVT, Res,
17411                           DAG.getTargetConstant(WideElts - NumElts, DL, MVT::i8));
17412         // Increase the shift amount to account for the left shift.
17413         ShiftAmt += WideElts - NumElts;
17414       }
17415 
17416       Res = DAG.getNode(Opcode, DL, WideVT, Res,
17417                         DAG.getTargetConstant(ShiftAmt, DL, MVT::i8));
17418       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Res,
17419                          DAG.getIntPtrConstant(0, DL));
17420     }
17421     Offset += NumElts; // Increment for next iteration.
17422   }
17423 
17424   // If we're performing a unary shuffle on a SETCC result, try to shuffle the
17425   // ops instead.
17426   // TODO: What other unary shuffles would benefit from this?
17427   if (NumV2Elements == 0 && V1.getOpcode() == ISD::SETCC && V1->hasOneUse()) {
17428     SDValue Op0 = V1.getOperand(0);
17429     SDValue Op1 = V1.getOperand(1);
17430     ISD::CondCode CC = cast<CondCodeSDNode>(V1.getOperand(2))->get();
17431     EVT OpVT = Op0.getValueType();
17432     if (OpVT.getScalarSizeInBits() >= 32 || isBroadcastShuffleMask(Mask))
17433       return DAG.getSetCC(
17434           DL, VT, DAG.getVectorShuffle(OpVT, DL, Op0, DAG.getUNDEF(OpVT), Mask),
17435           DAG.getVectorShuffle(OpVT, DL, Op1, DAG.getUNDEF(OpVT), Mask), CC);
17436   }
17437 
17438   MVT ExtVT;
17439   switch (VT.SimpleTy) {
17440   default:
17441     llvm_unreachable("Expected a vector of i1 elements");
17442   case MVT::v2i1:
17443     ExtVT = MVT::v2i64;
17444     break;
17445   case MVT::v4i1:
17446     ExtVT = MVT::v4i32;
17447     break;
17448   case MVT::v8i1:
17449     // Take a 512-bit type to get more shuffle options on KNL. If we have VLX,
17450     // use a 256-bit shuffle.
17451     ExtVT = Subtarget.hasVLX() ? MVT::v8i32 : MVT::v8i64;
17452     break;
17453   case MVT::v16i1:
17454     // Take 512-bit type, unless we are avoiding 512-bit types and have the
17455     // 256-bit operation available.
17456     ExtVT = Subtarget.canExtendTo512DQ() ? MVT::v16i32 : MVT::v16i16;
17457     break;
17458   case MVT::v32i1:
17459     // Take 512-bit type, unless we are avoiding 512-bit types and have the
17460     // 256-bit operation available.
17461     assert(Subtarget.hasBWI() && "Expected AVX512BW support");
17462     ExtVT = Subtarget.canExtendTo512BW() ? MVT::v32i16 : MVT::v32i8;
17463     break;
17464   case MVT::v64i1:
17465     // Fall back to scalarization. FIXME: We can do better if the shuffle
17466     // can be partitioned cleanly.
17467     if (!Subtarget.useBWIRegs())
17468       return SDValue();
17469     ExtVT = MVT::v64i8;
17470     break;
17471   }
17472 
17473   V1 = DAG.getNode(ISD::SIGN_EXTEND, DL, ExtVT, V1);
17474   V2 = DAG.getNode(ISD::SIGN_EXTEND, DL, ExtVT, V2);
17475 
17476   SDValue Shuffle = DAG.getVectorShuffle(ExtVT, DL, V1, V2, Mask);
17477   // The i1 elements were sign extended, so we can recreate the mask with a
17478   int NumElems = VT.getVectorNumElements();
17479   if ((Subtarget.hasBWI() && (NumElems >= 32)) ||
17480       (Subtarget.hasDQI() && (NumElems < 32)))
17481     return DAG.getSetCC(DL, VT, DAG.getConstant(0, DL, ExtVT),
17482                        Shuffle, ISD::SETGT);
17483 
17484   return DAG.getNode(ISD::TRUNCATE, DL, VT, Shuffle);
17485 }
17486 
17487 /// Helper function that returns true if the shuffle mask should be
17488 /// commuted to improve canonicalization.
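/// For example, with 4 elements the mask <4,5,6,3> takes three elements from
/// V2 and only one from V1, so it is commuted to the equivalent <0,1,2,7>.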
17489 static bool canonicalizeShuffleMaskWithCommute(ArrayRef<int> Mask) {
17490   int NumElements = Mask.size();
17491 
17492   int NumV1Elements = 0, NumV2Elements = 0;
17493   for (int M : Mask)
17494     if (M < 0)
17495       continue;
17496     else if (M < NumElements)
17497       ++NumV1Elements;
17498     else
17499       ++NumV2Elements;
17500 
17501   // Commute the shuffle as needed such that more elements come from V1 than
17502   // V2. This allows us to match the shuffle pattern strictly on how many
17503   // elements come from V1 without handling the symmetric cases.
17504   if (NumV2Elements > NumV1Elements)
17505     return true;
17506 
17507   assert(NumV1Elements > 0 && "No V1 indices");
17508 
17509   if (NumV2Elements == 0)
17510     return false;
17511 
17512   // When the number of V1 and V2 elements is the same, try to minimize the
17513   // number of uses of V2 in the low half of the vector. When that is tied,
17514   // ensure that the sum of indices for V1 is equal to or lower than the sum of
17515   // indices for V2. When those are equal, try to ensure that the number of odd
17516   // indices for V1 is lower than the number of odd indices for V2.
17517   if (NumV1Elements == NumV2Elements) {
17518     int LowV1Elements = 0, LowV2Elements = 0;
17519     for (int M : Mask.slice(0, NumElements / 2))
17520       if (M >= NumElements)
17521         ++LowV2Elements;
17522       else if (M >= 0)
17523         ++LowV1Elements;
17524     if (LowV2Elements > LowV1Elements)
17525       return true;
17526     if (LowV2Elements == LowV1Elements) {
17527       int SumV1Indices = 0, SumV2Indices = 0;
17528       for (int i = 0, Size = Mask.size(); i < Size; ++i)
17529         if (Mask[i] >= NumElements)
17530           SumV2Indices += i;
17531         else if (Mask[i] >= 0)
17532           SumV1Indices += i;
17533       if (SumV2Indices < SumV1Indices)
17534         return true;
17535       if (SumV2Indices == SumV1Indices) {
17536         int NumV1OddIndices = 0, NumV2OddIndices = 0;
17537         for (int i = 0, Size = Mask.size(); i < Size; ++i)
17538           if (Mask[i] >= NumElements)
17539             NumV2OddIndices += i % 2;
17540           else if (Mask[i] >= 0)
17541             NumV1OddIndices += i % 2;
17542         if (NumV2OddIndices < NumV1OddIndices)
17543           return true;
17544       }
17545     }
17546   }
17547 
17548   return false;
17549 }
17550 
17551 static bool canCombineAsMaskOperation(SDValue V,
17552                                       const X86Subtarget &Subtarget) {
17553   if (!Subtarget.hasAVX512())
17554     return false;
17555 
17556   if (!V.getValueType().isSimple())
17557     return false;
17558 
17559   MVT VT = V.getSimpleValueType().getScalarType();
17560   if ((VT == MVT::i16 || VT == MVT::i8) && !Subtarget.hasBWI())
17561     return false;
17562 
17563   // If vec width < 512, widen i8/i16 even with BWI as blendd/blendps/blendpd
17564   // are preferable to blendw/blendvb/masked-mov.
17565   if ((VT == MVT::i16 || VT == MVT::i8) &&
17566       V.getSimpleValueType().getSizeInBits() < 512)
17567     return false;
17568 
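  // In effect: check whether V is a one-use operation that AVX512 can execute
  // as a masked (predicated) instruction; if so, it is likely better to keep
  // the original element width so the shuffle can later fold into the mask.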
17569   auto HasMaskOperation = [&](SDValue V) {
17570     // TODO: Currently we only check a limited set of opcodes. We could
17571     // probably extend this to all binary operations by checking TLI.isBinOp().
17572     switch (V->getOpcode()) {
17573     default:
17574       return false;
17575     case ISD::ADD:
17576     case ISD::SUB:
17577     case ISD::AND:
17578     case ISD::XOR:
17579     case ISD::OR:
17580     case ISD::SMAX:
17581     case ISD::SMIN:
17582     case ISD::UMAX:
17583     case ISD::UMIN:
17584     case ISD::ABS:
17585     case ISD::SHL:
17586     case ISD::SRL:
17587     case ISD::SRA:
17588     case ISD::MUL:
17589       break;
17590     }
17591     if (!V->hasOneUse())
17592       return false;
17593 
17594     return true;
17595   };
17596 
17597   if (HasMaskOperation(V))
17598     return true;
17599 
17600   return false;
17601 }
17602 
17603 // Forward declaration.
17604 static SDValue canonicalizeShuffleMaskWithHorizOp(
17605     MutableArrayRef<SDValue> Ops, MutableArrayRef<int> Mask,
17606     unsigned RootSizeInBits, const SDLoc &DL, SelectionDAG &DAG,
17607     const X86Subtarget &Subtarget);
17608 
17609 /// Top-level lowering for x86 vector shuffles.
17610 ///
17611 /// This handles decomposition, canonicalization, and lowering of all x86
17612 /// vector shuffles. Most of the specific lowering strategies are encapsulated
17613 /// above in helper routines. The canonicalization attempts to widen shuffles
17614 /// to involve fewer lanes of wider elements, consolidate symmetric patterns
17615 /// s.t. only one of the two inputs needs to be tested, etc.
17616 static SDValue lowerVECTOR_SHUFFLE(SDValue Op, const X86Subtarget &Subtarget,
17617                                    SelectionDAG &DAG) {
17618   ShuffleVectorSDNode *SVOp = cast<ShuffleVectorSDNode>(Op);
17619   ArrayRef<int> OrigMask = SVOp->getMask();
17620   SDValue V1 = Op.getOperand(0);
17621   SDValue V2 = Op.getOperand(1);
17622   MVT VT = Op.getSimpleValueType();
17623   int NumElements = VT.getVectorNumElements();
17624   SDLoc DL(Op);
17625   bool Is1BitVector = (VT.getVectorElementType() == MVT::i1);
17626 
17627   assert((VT.getSizeInBits() != 64 || Is1BitVector) &&
17628          "Can't lower MMX shuffles");
17629 
17630   bool V1IsUndef = V1.isUndef();
17631   bool V2IsUndef = V2.isUndef();
17632   if (V1IsUndef && V2IsUndef)
17633     return DAG.getUNDEF(VT);
17634 
17635   // When we create a shuffle node we put the UNDEF node in the second operand,
17636   // but in some cases the first operand may be transformed to UNDEF.
17637   // In that case we should just commute the node.
17638   if (V1IsUndef)
17639     return DAG.getCommutedVectorShuffle(*SVOp);
17640 
17641   // Check for non-undef masks pointing at an undef vector and make the masks
17642   // undef as well. This makes it easier to match the shuffle based solely on
17643   // the mask.
17644   if (V2IsUndef &&
17645       any_of(OrigMask, [NumElements](int M) { return M >= NumElements; })) {
17646     SmallVector<int, 8> NewMask(OrigMask);
17647     for (int &M : NewMask)
17648       if (M >= NumElements)
17649         M = -1;
17650     return DAG.getVectorShuffle(VT, DL, V1, V2, NewMask);
17651   }
17652 
17653   // Check for illegal shuffle mask element index values.
17654   int MaskUpperLimit = OrigMask.size() * (V2IsUndef ? 1 : 2);
17655   (void)MaskUpperLimit;
17656   assert(llvm::all_of(OrigMask,
17657                       [&](int M) { return -1 <= M && M < MaskUpperLimit; }) &&
17658          "Out of bounds shuffle index");
17659 
17660   // We actually see shuffles that are entirely re-arrangements of a set of
17661   // zero inputs. This mostly happens while decomposing complex shuffles into
17662   // simple ones. Directly lower these as a buildvector of zeros.
17663   APInt KnownUndef, KnownZero;
17664   computeZeroableShuffleElements(OrigMask, V1, V2, KnownUndef, KnownZero);
17665 
17666   APInt Zeroable = KnownUndef | KnownZero;
17667   if (Zeroable.isAllOnes())
17668     return getZeroVector(VT, Subtarget, DAG, DL);
17669 
17670   bool V2IsZero = !V2IsUndef && ISD::isBuildVectorAllZeros(V2.getNode());
17671 
17672   // Try to collapse shuffles into using a vector type with fewer elements but
17673   // wider element types. We cap this to not form integers or floating point
17674   // elements wider than 64 bits. It does not seem beneficial to form i128
17675   // integers to handle flipping the low and high halves of AVX 256-bit vectors.
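  // For example, the v4i32 mask <0,1,4,5> widens to the v2i64 mask <0,2>:
  // each pair of adjacent narrow elements becomes one wide element.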
17676   SmallVector<int, 16> WidenedMask;
17677   if (VT.getScalarSizeInBits() < 64 && !Is1BitVector &&
17678       !canCombineAsMaskOperation(V1, Subtarget) &&
17679       !canCombineAsMaskOperation(V2, Subtarget) &&
17680       canWidenShuffleElements(OrigMask, Zeroable, V2IsZero, WidenedMask)) {
17681     // Shuffle mask widening should not interfere with a broadcast opportunity
17682     // by obfuscating the operands with bitcasts.
17683     // TODO: Avoid lowering directly from this top-level function: make this
17684     // a query (canLowerAsBroadcast) and defer lowering to the type-based calls.
17685     if (SDValue Broadcast = lowerShuffleAsBroadcast(DL, VT, V1, V2, OrigMask,
17686                                                     Subtarget, DAG))
17687       return Broadcast;
17688 
17689     MVT NewEltVT = VT.isFloatingPoint()
17690                        ? MVT::getFloatingPointVT(VT.getScalarSizeInBits() * 2)
17691                        : MVT::getIntegerVT(VT.getScalarSizeInBits() * 2);
17692     int NewNumElts = NumElements / 2;
17693     MVT NewVT = MVT::getVectorVT(NewEltVT, NewNumElts);
17694     // Make sure that the new vector type is legal. For example, v2f64 isn't
17695     // legal on SSE1.
17696     if (DAG.getTargetLoweringInfo().isTypeLegal(NewVT)) {
17697       if (V2IsZero) {
17698         // Modify the new Mask to take all zeros from the all-zero vector.
17699         // Choose indices that are blend-friendly.
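        // For example, a widened mask <0,Z> becomes <0,3> so that element 1
        // is taken from the all-zeros vector V2.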
17700         bool UsedZeroVector = false;
17701         assert(is_contained(WidenedMask, SM_SentinelZero) &&
17702                "V2's non-undef elements are used?!");
17703         for (int i = 0; i != NewNumElts; ++i)
17704           if (WidenedMask[i] == SM_SentinelZero) {
17705             WidenedMask[i] = i + NewNumElts;
17706             UsedZeroVector = true;
17707           }
17708         // Ensure all elements of V2 are zero - isBuildVectorAllZeros permits
17709         // some elements to be undef.
17710         if (UsedZeroVector)
17711           V2 = getZeroVector(NewVT, Subtarget, DAG, DL);
17712       }
17713       V1 = DAG.getBitcast(NewVT, V1);
17714       V2 = DAG.getBitcast(NewVT, V2);
17715       return DAG.getBitcast(
17716           VT, DAG.getVectorShuffle(NewVT, DL, V1, V2, WidenedMask));
17717     }
17718   }
17719 
17720   SmallVector<SDValue> Ops = {V1, V2};
17721   SmallVector<int> Mask(OrigMask);
17722 
17723   // Canonicalize the shuffle with any horizontal ops inputs.
17724   // NOTE: This may update Ops and Mask.
17725   if (SDValue HOp = canonicalizeShuffleMaskWithHorizOp(
17726           Ops, Mask, VT.getSizeInBits(), DL, DAG, Subtarget))
17727     return DAG.getBitcast(VT, HOp);
17728 
17729   V1 = DAG.getBitcast(VT, Ops[0]);
17730   V2 = DAG.getBitcast(VT, Ops[1]);
17731   assert(NumElements == (int)Mask.size() &&
17732          "canonicalizeShuffleMaskWithHorizOp "
17733          "shouldn't alter the shuffle mask size");
17734 
17735   // Commute the shuffle if it will improve canonicalization.
17736   if (canonicalizeShuffleMaskWithCommute(Mask)) {
17737     ShuffleVectorSDNode::commuteMask(Mask);
17738     std::swap(V1, V2);
17739   }
17740 
17741   // For each vector width, delegate to a specialized lowering routine.
17742   if (VT.is128BitVector())
17743     return lower128BitShuffle(DL, Mask, VT, V1, V2, Zeroable, Subtarget, DAG);
17744 
17745   if (VT.is256BitVector())
17746     return lower256BitShuffle(DL, Mask, VT, V1, V2, Zeroable, Subtarget, DAG);
17747 
17748   if (VT.is512BitVector())
17749     return lower512BitShuffle(DL, Mask, VT, V1, V2, Zeroable, Subtarget, DAG);
17750 
17751   if (Is1BitVector)
17752     return lower1BitShuffle(DL, Mask, VT, V1, V2, Zeroable, Subtarget, DAG);
17753 
17754   llvm_unreachable("Unimplemented!");
17755 }
17756 
17757 /// Try to lower a VSELECT instruction to a vector shuffle.
17758 static SDValue lowerVSELECTtoVectorShuffle(SDValue Op,
17759                                            const X86Subtarget &Subtarget,
17760                                            SelectionDAG &DAG) {
17761   SDValue Cond = Op.getOperand(0);
17762   SDValue LHS = Op.getOperand(1);
17763   SDValue RHS = Op.getOperand(2);
17764   MVT VT = Op.getSimpleValueType();
17765 
17766   // Only non-legal VSELECTs reach this lowering; convert those into generic
17767   // shuffles and reuse the shuffle lowering path for blends.
17768   if (ISD::isBuildVectorOfConstantSDNodes(Cond.getNode())) {
17769     SmallVector<int, 32> Mask;
17770     if (createShuffleMaskFromVSELECT(Mask, Cond))
17771       return DAG.getVectorShuffle(VT, SDLoc(Op), LHS, RHS, Mask);
17772   }
17773 
17774   return SDValue();
17775 }
17776 
17777 SDValue X86TargetLowering::LowerVSELECT(SDValue Op, SelectionDAG &DAG) const {
17778   SDValue Cond = Op.getOperand(0);
17779   SDValue LHS = Op.getOperand(1);
17780   SDValue RHS = Op.getOperand(2);
17781 
17782   SDLoc dl(Op);
17783   MVT VT = Op.getSimpleValueType();
17784   if (isSoftF16(VT, Subtarget)) {
17785     MVT NVT = VT.changeVectorElementTypeToInteger();
17786     return DAG.getBitcast(VT, DAG.getNode(ISD::VSELECT, dl, NVT, Cond,
17787                                           DAG.getBitcast(NVT, LHS),
17788                                           DAG.getBitcast(NVT, RHS)));
17789   }
17790 
17791   // A vselect where all conditions and data are constants can be optimized into
17792   // a single vector load by SelectionDAGLegalize::ExpandBUILD_VECTOR().
17793   if (ISD::isBuildVectorOfConstantSDNodes(Cond.getNode()) &&
17794       ISD::isBuildVectorOfConstantSDNodes(LHS.getNode()) &&
17795       ISD::isBuildVectorOfConstantSDNodes(RHS.getNode()))
17796     return SDValue();
17797 
17798   // Try to lower this to a blend-style vector shuffle. This can handle all
17799   // constant condition cases.
17800   if (SDValue BlendOp = lowerVSELECTtoVectorShuffle(Op, Subtarget, DAG))
17801     return BlendOp;
17802 
17803   // If this VSELECT has a vector of i1 as a mask, it will be directly matched
17804   // by patterns on the mask registers on AVX-512.
17805   MVT CondVT = Cond.getSimpleValueType();
17806   unsigned CondEltSize = Cond.getScalarValueSizeInBits();
17807   if (CondEltSize == 1)
17808     return Op;
17809 
17810   // Variable blends are only legal from SSE4.1 onward.
17811   if (!Subtarget.hasSSE41())
17812     return SDValue();
17813 
17814   unsigned EltSize = VT.getScalarSizeInBits();
17815   unsigned NumElts = VT.getVectorNumElements();
17816 
17817   // Expand v32i16/v64i8 without BWI.
17818   if ((VT == MVT::v32i16 || VT == MVT::v64i8) && !Subtarget.hasBWI())
17819     return SDValue();
17820 
17821   // If the VSELECT is on a 512-bit type, we have to convert a non-i1 condition
17822   // into an i1 condition so that we can use the mask-based 512-bit blend
17823   // instructions.
17824   if (VT.getSizeInBits() == 512) {
17825     // Build a mask by testing the condition against zero.
17826     MVT MaskVT = MVT::getVectorVT(MVT::i1, NumElts);
17827     SDValue Mask = DAG.getSetCC(dl, MaskVT, Cond,
17828                                 DAG.getConstant(0, dl, CondVT),
17829                                 ISD::SETNE);
17830     // Now return a new VSELECT using the mask.
17831     return DAG.getSelect(dl, VT, Mask, LHS, RHS);
17832   }
17833 
17834   // SEXT/TRUNC cases where the mask doesn't match the destination size.
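  // For example, a v4i32 condition driving a v4i64 select is sign extended to
  // v4i64 so that the condition and data element sizes match.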
17835   if (CondEltSize != EltSize) {
17836     // If we don't have a sign splat, rely on the expansion.
17837     if (CondEltSize != DAG.ComputeNumSignBits(Cond))
17838       return SDValue();
17839 
17840     MVT NewCondSVT = MVT::getIntegerVT(EltSize);
17841     MVT NewCondVT = MVT::getVectorVT(NewCondSVT, NumElts);
17842     Cond = DAG.getSExtOrTrunc(Cond, dl, NewCondVT);
17843     return DAG.getNode(ISD::VSELECT, dl, VT, Cond, LHS, RHS);
17844   }
17845 
17846   // For v16i16/v32i8 selects without AVX2, if the condition and another
17847   // operand are free to split, then it is better to split before expanding the
17848   // select. Don't bother with XOP as it has the fast VPCMOV instruction.
17849   // TODO: This is very similar to narrowVectorSelect.
17850   // TODO: Add Load splitting to isFreeToSplitVector ?
17851   if (EltSize < 32 && VT.is256BitVector() && !Subtarget.hasAVX2() &&
17852       !Subtarget.hasXOP()) {
17853     bool FreeCond = isFreeToSplitVector(Cond.getNode(), DAG);
17854     bool FreeLHS = isFreeToSplitVector(LHS.getNode(), DAG) ||
17855                    (ISD::isNormalLoad(LHS.getNode()) && LHS.hasOneUse());
17856     bool FreeRHS = isFreeToSplitVector(RHS.getNode(), DAG) ||
17857                    (ISD::isNormalLoad(RHS.getNode()) && RHS.hasOneUse());
17858     if (FreeCond && (FreeLHS || FreeRHS))
17859       return splitVectorOp(Op, DAG, dl);
17860   }
17861 
17862   // Only some types will be legal on some subtargets. If we can emit a legal
17863   // VSELECT-matching blend, return Op; but if we need to expand, return
17864   // a null value.
17865   switch (VT.SimpleTy) {
17866   default:
17867     // Most of the vector types have blends past SSE4.1.
17868     return Op;
17869 
17870   case MVT::v32i8:
17871     // The byte blends for AVX vectors were introduced only in AVX2.
17872     if (Subtarget.hasAVX2())
17873       return Op;
17874 
17875     return SDValue();
17876 
17877   case MVT::v8i16:
17878   case MVT::v16i16: {
17879     // Bitcast everything to the vXi8 type and use a vXi8 vselect.
17880     MVT CastVT = MVT::getVectorVT(MVT::i8, NumElts * 2);
17881     Cond = DAG.getBitcast(CastVT, Cond);
17882     LHS = DAG.getBitcast(CastVT, LHS);
17883     RHS = DAG.getBitcast(CastVT, RHS);
17884     SDValue Select = DAG.getNode(ISD::VSELECT, dl, CastVT, Cond, LHS, RHS);
17885     return DAG.getBitcast(VT, Select);
17886   }
17887   }
17888 }
17889 
17890 static SDValue LowerEXTRACT_VECTOR_ELT_SSE4(SDValue Op, SelectionDAG &DAG) {
17891   MVT VT = Op.getSimpleValueType();
17892   SDValue Vec = Op.getOperand(0);
17893   SDValue Idx = Op.getOperand(1);
17894   assert(isa<ConstantSDNode>(Idx) && "Constant index expected");
17895   SDLoc dl(Op);
17896 
17897   if (!Vec.getSimpleValueType().is128BitVector())
17898     return SDValue();
17899 
17900   if (VT.getSizeInBits() == 8) {
17901     // If IdxVal is 0, it's cheaper to do a move instead of a pextrb, unless
17902     // we're going to zero extend the register or fold the store.
17903     if (llvm::isNullConstant(Idx) && !X86::mayFoldIntoZeroExtend(Op) &&
17904         !X86::mayFoldIntoStore(Op))
17905       return DAG.getNode(ISD::TRUNCATE, dl, MVT::i8,
17906                          DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i32,
17907                                      DAG.getBitcast(MVT::v4i32, Vec), Idx));
17908 
17909     unsigned IdxVal = Idx->getAsZExtVal();
17910     SDValue Extract = DAG.getNode(X86ISD::PEXTRB, dl, MVT::i32, Vec,
17911                                   DAG.getTargetConstant(IdxVal, dl, MVT::i8));
17912     return DAG.getNode(ISD::TRUNCATE, dl, VT, Extract);
17913   }
17914 
17915   if (VT == MVT::f32) {
17916     // EXTRACTPS outputs to a GPR32 register which will require a movd to copy
17917     // the result back to FR32 register. It's only worth matching if the
17918     // result has a single use which is a store or a bitcast to i32.  And in
17919     // the case of a store, it's not worth it if the index is a constant 0,
17920     // because a MOVSSmr can be used instead, which is smaller and faster.
17921     if (!Op.hasOneUse())
17922       return SDValue();
17923     SDNode *User = *Op.getNode()->use_begin();
17924     if ((User->getOpcode() != ISD::STORE || isNullConstant(Idx)) &&
17925         (User->getOpcode() != ISD::BITCAST ||
17926          User->getValueType(0) != MVT::i32))
17927       return SDValue();
17928     SDValue Extract = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i32,
17929                                   DAG.getBitcast(MVT::v4i32, Vec), Idx);
17930     return DAG.getBitcast(MVT::f32, Extract);
17931   }
17932 
17933   if (VT == MVT::i32 || VT == MVT::i64)
17934       return Op;
17935 
17936   return SDValue();
17937 }
17938 
17939 /// Extract one bit from mask vector, like v16i1 or v8i1.
17940 /// AVX-512 feature.
17941 static SDValue ExtractBitFromMaskVector(SDValue Op, SelectionDAG &DAG,
17942                                         const X86Subtarget &Subtarget) {
17943   SDValue Vec = Op.getOperand(0);
17944   SDLoc dl(Vec);
17945   MVT VecVT = Vec.getSimpleValueType();
17946   SDValue Idx = Op.getOperand(1);
17947   auto* IdxC = dyn_cast<ConstantSDNode>(Idx);
17948   MVT EltVT = Op.getSimpleValueType();
17949 
17950   assert((VecVT.getVectorNumElements() <= 16 || Subtarget.hasBWI()) &&
17951          "Unexpected vector type in ExtractBitFromMaskVector");
17952 
17953   // A variable index can't be handled in mask registers,
17954   // so extend the vector to VR512/VR128.
17955   if (!IdxC) {
17956     unsigned NumElts = VecVT.getVectorNumElements();
17957     // Extending v8i1/v16i1 to 512-bit gets better performance on KNL
17958     // than extending to 128/256-bit.
17959     if (NumElts == 1) {
17960       Vec = widenMaskVector(Vec, false, Subtarget, DAG, dl);
17961       MVT IntVT = MVT::getIntegerVT(Vec.getValueType().getVectorNumElements());
17962       return DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, DAG.getBitcast(IntVT, Vec));
17963     }
17964     MVT ExtEltVT = (NumElts <= 8) ? MVT::getIntegerVT(128 / NumElts) : MVT::i8;
17965     MVT ExtVecVT = MVT::getVectorVT(ExtEltVT, NumElts);
17966     SDValue Ext = DAG.getNode(ISD::SIGN_EXTEND, dl, ExtVecVT, Vec);
17967     SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, ExtEltVT, Ext, Idx);
17968     return DAG.getNode(ISD::TRUNCATE, dl, EltVT, Elt);
17969   }
17970 
17971   unsigned IdxVal = IdxC->getZExtValue();
17972   if (IdxVal == 0) // the operation is legal
17973     return Op;
17974 
17975   // Extend to natively supported kshift.
17976   Vec = widenMaskVector(Vec, false, Subtarget, DAG, dl);
17977 
17978   // Use kshiftr instruction to move to the lower element.
17979   Vec = DAG.getNode(X86ISD::KSHIFTR, dl, Vec.getSimpleValueType(), Vec,
17980                     DAG.getTargetConstant(IdxVal, dl, MVT::i8));
17981 
17982   return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, Op.getValueType(), Vec,
17983                      DAG.getIntPtrConstant(0, dl));
17984 }
17985 
17986 // Helper to find all the extracted elements from a vector.
17987 static APInt getExtractedDemandedElts(SDNode *N) {
17988   MVT VT = N->getSimpleValueType(0);
17989   unsigned NumElts = VT.getVectorNumElements();
17990   APInt DemandedElts = APInt::getZero(NumElts);
17991   for (SDNode *User : N->uses()) {
17992     switch (User->getOpcode()) {
17993     case X86ISD::PEXTRB:
17994     case X86ISD::PEXTRW:
17995     case ISD::EXTRACT_VECTOR_ELT:
17996       if (!isa<ConstantSDNode>(User->getOperand(1))) {
17997         DemandedElts.setAllBits();
17998         return DemandedElts;
17999       }
18000       DemandedElts.setBit(User->getConstantOperandVal(1));
18001       break;
18002     case ISD::BITCAST: {
18003       if (!User->getValueType(0).isSimple() ||
18004           !User->getValueType(0).isVector()) {
18005         DemandedElts.setAllBits();
18006         return DemandedElts;
18007       }
18008       APInt DemandedSrcElts = getExtractedDemandedElts(User);
18009       DemandedElts |= APIntOps::ScaleBitMask(DemandedSrcElts, NumElts);
18010       break;
18011     }
18012     default:
18013       DemandedElts.setAllBits();
18014       return DemandedElts;
18015     }
18016   }
18017   return DemandedElts;
18018 }
18019 
18020 SDValue
18021 X86TargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
18022                                            SelectionDAG &DAG) const {
18023   SDLoc dl(Op);
18024   SDValue Vec = Op.getOperand(0);
18025   MVT VecVT = Vec.getSimpleValueType();
18026   SDValue Idx = Op.getOperand(1);
18027   auto* IdxC = dyn_cast<ConstantSDNode>(Idx);
18028 
18029   if (VecVT.getVectorElementType() == MVT::i1)
18030     return ExtractBitFromMaskVector(Op, DAG, Subtarget);
18031 
18032   if (!IdxC) {
18033     // It's more profitable to go through memory (1 cycle throughput)
18034     // than using VMOVD + VPERMV/PSHUFB sequence (2/3 cycles throughput)
18035     // IACA tool was used to get performance estimation
18036     // (https://software.intel.com/en-us/articles/intel-architecture-code-analyzer)
18037     //
18038     // example : extractelement <16 x i8> %a, i32 %i
18039     //
18040     // Block Throughput: 3.00 Cycles
18041     // Throughput Bottleneck: Port5
18042     //
18043     // | Num Of |   Ports pressure in cycles  |    |
18044     // |  Uops  |  0  - DV  |  5  |  6  |  7  |    |
18045     // ---------------------------------------------
18046     // |   1    |           | 1.0 |     |     | CP | vmovd xmm1, edi
18047     // |   1    |           | 1.0 |     |     | CP | vpshufb xmm0, xmm0, xmm1
18048     // |   2    | 1.0       | 1.0 |     |     | CP | vpextrb eax, xmm0, 0x0
18049     // Total Num Of Uops: 4
18050     //
18051     //
18052     // Block Throughput: 1.00 Cycles
18053     // Throughput Bottleneck: PORT2_AGU, PORT3_AGU, Port4
18054     //
18055     // |    |  Ports pressure in cycles   |  |
18056     // |Uops| 1 | 2 - D  |3 -  D  | 4 | 5 |  |
18057     // ---------------------------------------------------------
18058     // |2^  |   | 0.5    | 0.5    |1.0|   |CP| vmovaps xmmword ptr [rsp-0x18], xmm0
18059     // |1   |0.5|        |        |   |0.5|  | lea rax, ptr [rsp-0x18]
18060     // |1   |   |0.5, 0.5|0.5, 0.5|   |   |CP| mov al, byte ptr [rdi+rax*1]
18061     // Total Num Of Uops: 4
18062 
18063     return SDValue();
18064   }
18065 
18066   unsigned IdxVal = IdxC->getZExtValue();
18067 
18068   // If this is a 256-bit vector result, first extract the 128-bit vector and
18069   // then extract the element from the 128-bit vector.
18070   if (VecVT.is256BitVector() || VecVT.is512BitVector()) {
18071     // Get the 128-bit vector.
18072     Vec = extract128BitVector(Vec, IdxVal, DAG, dl);
18073     MVT EltVT = VecVT.getVectorElementType();
18074 
18075     unsigned ElemsPerChunk = 128 / EltVT.getSizeInBits();
18076     assert(isPowerOf2_32(ElemsPerChunk) && "Elements per chunk not power of 2");
18077 
18078     // Find IdxVal modulo ElemsPerChunk. Since ElemsPerChunk is a power of 2
18079     // this can be done with a mask.
18080     IdxVal &= ElemsPerChunk - 1;
18081     return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, Op.getValueType(), Vec,
18082                        DAG.getIntPtrConstant(IdxVal, dl));
18083   }
18084 
18085   assert(VecVT.is128BitVector() && "Unexpected vector length");
18086 
18087   MVT VT = Op.getSimpleValueType();
18088 
18089   if (VT == MVT::i16) {
18090     // If IdxVal is 0, it's cheaper to do a move instead of a pextrw, unless
18091     // we're going to zero extend the register or fold the store (SSE41 only).
18092     if (IdxVal == 0 && !X86::mayFoldIntoZeroExtend(Op) &&
18093         !(Subtarget.hasSSE41() && X86::mayFoldIntoStore(Op))) {
18094       if (Subtarget.hasFP16())
18095         return Op;
18096 
18097       return DAG.getNode(ISD::TRUNCATE, dl, MVT::i16,
18098                          DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i32,
18099                                      DAG.getBitcast(MVT::v4i32, Vec), Idx));
18100     }
18101 
18102     SDValue Extract = DAG.getNode(X86ISD::PEXTRW, dl, MVT::i32, Vec,
18103                                   DAG.getTargetConstant(IdxVal, dl, MVT::i8));
18104     return DAG.getNode(ISD::TRUNCATE, dl, VT, Extract);
18105   }
18106 
18107   if (Subtarget.hasSSE41())
18108     if (SDValue Res = LowerEXTRACT_VECTOR_ELT_SSE4(Op, DAG))
18109       return Res;
18110 
18111   // Only extract a single element from a v16i8 source - determine the common
18112   // DWORD/WORD that all extractions share, and extract the sub-byte.
18113   // TODO: Add QWORD MOVQ extraction?
18114   if (VT == MVT::i8) {
18115     APInt DemandedElts = getExtractedDemandedElts(Vec.getNode());
18116     assert(DemandedElts.getBitWidth() == 16 && "Vector width mismatch");
18117 
18118     // Extract either the lowest i32 or any i16, and extract the sub-byte.
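    // For example, if byte 7 is the only element extracted from the vector,
    // this becomes an extract of word 3 of the v8i16 bitcast, a right shift
    // by 8, and a truncate back to i8.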
18119     int DWordIdx = IdxVal / 4;
18120     if (DWordIdx == 0 && DemandedElts == (DemandedElts & 15)) {
18121       SDValue Res = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i32,
18122                                 DAG.getBitcast(MVT::v4i32, Vec),
18123                                 DAG.getIntPtrConstant(DWordIdx, dl));
18124       int ShiftVal = (IdxVal % 4) * 8;
18125       if (ShiftVal != 0)
18126         Res = DAG.getNode(ISD::SRL, dl, MVT::i32, Res,
18127                           DAG.getConstant(ShiftVal, dl, MVT::i8));
18128       return DAG.getNode(ISD::TRUNCATE, dl, VT, Res);
18129     }
18130 
18131     int WordIdx = IdxVal / 2;
18132     if (DemandedElts == (DemandedElts & (3 << (WordIdx * 2)))) {
18133       SDValue Res = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i16,
18134                                 DAG.getBitcast(MVT::v8i16, Vec),
18135                                 DAG.getIntPtrConstant(WordIdx, dl));
18136       int ShiftVal = (IdxVal % 2) * 8;
18137       if (ShiftVal != 0)
18138         Res = DAG.getNode(ISD::SRL, dl, MVT::i16, Res,
18139                           DAG.getConstant(ShiftVal, dl, MVT::i8));
18140       return DAG.getNode(ISD::TRUNCATE, dl, VT, Res);
18141     }
18142   }
18143 
18144   if (VT == MVT::f16 || VT.getSizeInBits() == 32) {
18145     if (IdxVal == 0)
18146       return Op;
18147 
18148     // Shuffle the element to the lowest element, then movss or movsh.
18149     SmallVector<int, 8> Mask(VecVT.getVectorNumElements(), -1);
18150     Mask[0] = static_cast<int>(IdxVal);
18151     Vec = DAG.getVectorShuffle(VecVT, dl, Vec, DAG.getUNDEF(VecVT), Mask);
18152     return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT, Vec,
18153                        DAG.getIntPtrConstant(0, dl));
18154   }
18155 
18156   if (VT.getSizeInBits() == 64) {
18157     // FIXME: .td only matches this for <2 x f64>, not <2 x i64> on 32b
18158     // FIXME: seems like this should be unnecessary if mov{h,l}pd were taught
18159     //        to match extract_elt for f64.
18160     if (IdxVal == 0)
18161       return Op;
18162 
18163     // UNPCKHPD the element to the lowest double word, then movsd.
18164     // Note if the lower 64 bits of the result of the UNPCKHPD is then stored
18165     // to a f64mem, the whole operation is folded into a single MOVHPDmr.
18166     int Mask[2] = { 1, -1 };
18167     Vec = DAG.getVectorShuffle(VecVT, dl, Vec, DAG.getUNDEF(VecVT), Mask);
18168     return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT, Vec,
18169                        DAG.getIntPtrConstant(0, dl));
18170   }
18171 
18172   return SDValue();
18173 }
18174 
18175 /// Insert one bit to mask vector, like v16i1 or v8i1.
18176 /// AVX-512 feature.
18177 static SDValue InsertBitToMaskVector(SDValue Op, SelectionDAG &DAG,
18178                                      const X86Subtarget &Subtarget) {
18179   SDLoc dl(Op);
18180   SDValue Vec = Op.getOperand(0);
18181   SDValue Elt = Op.getOperand(1);
18182   SDValue Idx = Op.getOperand(2);
18183   MVT VecVT = Vec.getSimpleValueType();
18184 
18185   if (!isa<ConstantSDNode>(Idx)) {
18186     // Non-constant index. Extend the source and destination,
18187     // insert the element, and then truncate the result.
18188     unsigned NumElts = VecVT.getVectorNumElements();
18189     MVT ExtEltVT = (NumElts <= 8) ? MVT::getIntegerVT(128 / NumElts) : MVT::i8;
18190     MVT ExtVecVT = MVT::getVectorVT(ExtEltVT, NumElts);
18191     SDValue ExtOp = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, ExtVecVT,
18192       DAG.getNode(ISD::SIGN_EXTEND, dl, ExtVecVT, Vec),
18193       DAG.getNode(ISD::SIGN_EXTEND, dl, ExtEltVT, Elt), Idx);
18194     return DAG.getNode(ISD::TRUNCATE, dl, VecVT, ExtOp);
18195   }
18196 
18197   // Copy into a k-register, extract to v1i1 and insert_subvector.
18198   SDValue EltInVec = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, MVT::v1i1, Elt);
18199   return DAG.getNode(ISD::INSERT_SUBVECTOR, dl, VecVT, Vec, EltInVec, Idx);
18200 }
18201 
18202 SDValue X86TargetLowering::LowerINSERT_VECTOR_ELT(SDValue Op,
18203                                                   SelectionDAG &DAG) const {
18204   MVT VT = Op.getSimpleValueType();
18205   MVT EltVT = VT.getVectorElementType();
18206   unsigned NumElts = VT.getVectorNumElements();
18207   unsigned EltSizeInBits = EltVT.getScalarSizeInBits();
18208 
18209   if (EltVT == MVT::i1)
18210     return InsertBitToMaskVector(Op, DAG, Subtarget);
18211 
18212   SDLoc dl(Op);
18213   SDValue N0 = Op.getOperand(0);
18214   SDValue N1 = Op.getOperand(1);
18215   SDValue N2 = Op.getOperand(2);
18216   auto *N2C = dyn_cast<ConstantSDNode>(N2);
18217 
18218   if (EltVT == MVT::bf16) {
18219     MVT IVT = VT.changeVectorElementTypeToInteger();
18220     SDValue Res = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, IVT,
18221                               DAG.getBitcast(IVT, N0),
18222                               DAG.getBitcast(MVT::i16, N1), N2);
18223     return DAG.getBitcast(VT, Res);
18224   }
18225 
18226   if (!N2C) {
18227     // For variable insertion indices we're usually better off spilling to stack,
18228     // but AVX512 can use a variable compare+select by comparing against all
18229     // possible vector indices, and FP insertion has less gpr->simd traffic.
18230     if (!(Subtarget.hasBWI() ||
18231           (Subtarget.hasAVX512() && EltSizeInBits >= 32) ||
18232           (Subtarget.hasSSE41() && (EltVT == MVT::f32 || EltVT == MVT::f64))))
18233       return SDValue();
18234 
18235     MVT IdxSVT = MVT::getIntegerVT(EltSizeInBits);
18236     MVT IdxVT = MVT::getVectorVT(IdxSVT, NumElts);
18237     if (!isTypeLegal(IdxSVT) || !isTypeLegal(IdxVT))
18238       return SDValue();
18239 
18240     SDValue IdxExt = DAG.getZExtOrTrunc(N2, dl, IdxSVT);
18241     SDValue IdxSplat = DAG.getSplatBuildVector(IdxVT, dl, IdxExt);
18242     SDValue EltSplat = DAG.getSplatBuildVector(VT, dl, N1);
18243 
18244     SmallVector<SDValue, 16> RawIndices;
18245     for (unsigned I = 0; I != NumElts; ++I)
18246       RawIndices.push_back(DAG.getConstant(I, dl, IdxSVT));
18247     SDValue Indices = DAG.getBuildVector(IdxVT, dl, RawIndices);
18248 
18249     // inselt N0, N1, N2 --> select (SplatN2 == {0,1,2...}) ? SplatN1 : N0.
18250     return DAG.getSelectCC(dl, IdxSplat, Indices, EltSplat, N0,
18251                            ISD::CondCode::SETEQ);
18252   }
18253 
18254   if (N2C->getAPIntValue().uge(NumElts))
18255     return SDValue();
18256   uint64_t IdxVal = N2C->getZExtValue();
18257 
18258   bool IsZeroElt = X86::isZeroNode(N1);
18259   bool IsAllOnesElt = VT.isInteger() && llvm::isAllOnesConstant(N1);
18260 
18261   if (IsZeroElt || IsAllOnesElt) {
18262     // Lower insertion of v16i8/v32i8/v16i16 -1 elts as an 'OR' blend.
18263     // We don't deal with i8 0 since it appears to be handled elsewhere.
18264     if (IsAllOnesElt &&
18265         ((VT == MVT::v16i8 && !Subtarget.hasSSE41()) ||
18266          ((VT == MVT::v32i8 || VT == MVT::v16i16) && !Subtarget.hasInt256()))) {
18267       SDValue ZeroCst = DAG.getConstant(0, dl, VT.getScalarType());
18268       SDValue OnesCst = DAG.getAllOnesConstant(dl, VT.getScalarType());
18269       SmallVector<SDValue, 8> CstVectorElts(NumElts, ZeroCst);
18270       CstVectorElts[IdxVal] = OnesCst;
18271       SDValue CstVector = DAG.getBuildVector(VT, dl, CstVectorElts);
18272       return DAG.getNode(ISD::OR, dl, VT, N0, CstVector);
18273     }
18274     // See if we can do this more efficiently with a blend shuffle with a
18275     // rematerializable vector.
18276     if (Subtarget.hasSSE41() &&
18277         (EltSizeInBits >= 16 || (IsZeroElt && !VT.is128BitVector()))) {
18278       SmallVector<int, 8> BlendMask;
18279       for (unsigned i = 0; i != NumElts; ++i)
18280         BlendMask.push_back(i == IdxVal ? i + NumElts : i);
18281       SDValue CstVector = IsZeroElt ? getZeroVector(VT, Subtarget, DAG, dl)
18282                                     : getOnesVector(VT, DAG, dl);
18283       return DAG.getVectorShuffle(VT, dl, N0, CstVector, BlendMask);
18284     }
18285   }
18286 
18287   // If the vector is wider than 128 bits, extract the 128-bit subvector, insert
18288   // into that, and then insert the subvector back into the result.
18289   if (VT.is256BitVector() || VT.is512BitVector()) {
18290     // With a 256-bit vector, we can insert into the zero element efficiently
18291     // using a blend if we have AVX or AVX2 and the right data type.
18292     if (VT.is256BitVector() && IdxVal == 0) {
18293       // TODO: It is worthwhile to cast integer to floating point and back
18294       // and incur a domain crossing penalty if that's what we'll end up
18295       // doing anyway after extracting to a 128-bit vector.
18296       if ((Subtarget.hasAVX() && (EltVT == MVT::f64 || EltVT == MVT::f32)) ||
18297           (Subtarget.hasAVX2() && (EltVT == MVT::i32 || EltVT == MVT::i64))) {
18298         SDValue N1Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, VT, N1);
18299         return DAG.getNode(X86ISD::BLENDI, dl, VT, N0, N1Vec,
18300                            DAG.getTargetConstant(1, dl, MVT::i8));
18301       }
18302     }
18303 
18304     unsigned NumEltsIn128 = 128 / EltSizeInBits;
18305     assert(isPowerOf2_32(NumEltsIn128) &&
18306            "Vectors will always have power-of-two number of elements.");
18307 
18308     // If we are not inserting into the low 128-bit vector chunk,
18309     // then prefer the broadcast+blend sequence.
18310     // FIXME: relax the profitability check iff all N1 uses are insertions.
18311     if (IdxVal >= NumEltsIn128 &&
18312         ((Subtarget.hasAVX2() && EltSizeInBits != 8) ||
18313          (Subtarget.hasAVX() && (EltSizeInBits >= 32) &&
18314           X86::mayFoldLoad(N1, Subtarget)))) {
18315       SDValue N1SplatVec = DAG.getSplatBuildVector(VT, dl, N1);
18316       SmallVector<int, 8> BlendMask;
18317       for (unsigned i = 0; i != NumElts; ++i)
18318         BlendMask.push_back(i == IdxVal ? i + NumElts : i);
18319       return DAG.getVectorShuffle(VT, dl, N0, N1SplatVec, BlendMask);
18320     }
18321 
18322     // Get the desired 128-bit vector chunk.
18323     SDValue V = extract128BitVector(N0, IdxVal, DAG, dl);
18324 
18325     // Insert the element into the desired chunk.
18326     // Since NumEltsIn128 is a power of 2 we can use mask instead of modulo.
18327     unsigned IdxIn128 = IdxVal & (NumEltsIn128 - 1);
18328 
18329     V = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, V.getValueType(), V, N1,
18330                     DAG.getIntPtrConstant(IdxIn128, dl));
18331 
18332     // Insert the changed part back into the bigger vector
18333     return insert128BitVector(N0, V, IdxVal, DAG, dl);
18334   }
18335   assert(VT.is128BitVector() && "Only 128-bit vector types should be left!");
18336 
18337   // This will be just movw/movd/movq/movsh/movss/movsd.
18338   if (IdxVal == 0 && ISD::isBuildVectorAllZeros(N0.getNode())) {
18339     if (EltVT == MVT::i32 || EltVT == MVT::f32 || EltVT == MVT::f64 ||
18340         EltVT == MVT::f16 || EltVT == MVT::i64) {
18341       N1 = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, VT, N1);
18342       return getShuffleVectorZeroOrUndef(N1, 0, true, Subtarget, DAG);
18343     }
18344 
18345     // We can't directly insert an i8 or i16 into a vector, so zero extend
18346     // it to i32 first.
18347     if (EltVT == MVT::i16 || EltVT == MVT::i8) {
18348       N1 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i32, N1);
18349       MVT ShufVT = MVT::getVectorVT(MVT::i32, VT.getSizeInBits() / 32);
18350       N1 = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, ShufVT, N1);
18351       N1 = getShuffleVectorZeroOrUndef(N1, 0, true, Subtarget, DAG);
18352       return DAG.getBitcast(VT, N1);
18353     }
18354   }
18355 
18356   // Transform it so it matches pinsr{b,w}, which expects a GR32 as its second
18357   // argument. SSE41 is required for pinsrb.
18358   if (VT == MVT::v8i16 || (VT == MVT::v16i8 && Subtarget.hasSSE41())) {
18359     unsigned Opc;
18360     if (VT == MVT::v8i16) {
18361       assert(Subtarget.hasSSE2() && "SSE2 required for PINSRW");
18362       Opc = X86ISD::PINSRW;
18363     } else {
18364       assert(VT == MVT::v16i8 && "PINSRB requires v16i8 vector");
18365       assert(Subtarget.hasSSE41() && "SSE41 required for PINSRB");
18366       Opc = X86ISD::PINSRB;
18367     }
18368 
18369     assert(N1.getValueType() != MVT::i32 && "Unexpected VT");
18370     N1 = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i32, N1);
18371     N2 = DAG.getTargetConstant(IdxVal, dl, MVT::i8);
18372     return DAG.getNode(Opc, dl, VT, N0, N1, N2);
18373   }
18374 
18375   if (Subtarget.hasSSE41()) {
18376     if (EltVT == MVT::f32) {
18377       // Bits [7:6] of the constant are the source select. This will always be
18378       //   zero here. The DAG Combiner may combine an extract_elt index into
18379       //   these bits. For example (insert (extract, 3), 2) could be matched by
18380       //   putting the '3' into bits [7:6] of X86ISD::INSERTPS.
18381       // Bits [5:4] of the constant are the destination select. This is the
18382       //   value of the incoming immediate.
18383       // Bits [3:0] of the constant are the zero mask. The DAG Combiner may
18384       //   combine either bitwise AND or insert of float 0.0 to set these bits.
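      //   For example, immediate 0x30 inserts the source scalar into element
      //   3 of the destination with no zeroing, which is what the
      //   'IdxVal << 4' below produces for IdxVal == 3.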
18385 
18386       bool MinSize = DAG.getMachineFunction().getFunction().hasMinSize();
18387       if (IdxVal == 0 && (!MinSize || !X86::mayFoldLoad(N1, Subtarget))) {
18388         // If this is an insertion of 32-bits into the low 32-bits of
18389         // a vector, we prefer to generate a blend with immediate rather
18390         // than an insertps. Blends are simpler operations in hardware and so
18391         // will always have equal or better performance than insertps.
18392         // But if optimizing for size and there's a load folding opportunity,
18393         // generate insertps because blendps does not have a 32-bit memory
18394         // operand form.
18395         N1 = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, MVT::v4f32, N1);
18396         return DAG.getNode(X86ISD::BLENDI, dl, VT, N0, N1,
18397                            DAG.getTargetConstant(1, dl, MVT::i8));
18398       }
18399       // Create this as a scalar-to-vector.
18400       N1 = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, MVT::v4f32, N1);
18401       return DAG.getNode(X86ISD::INSERTPS, dl, VT, N0, N1,
18402                          DAG.getTargetConstant(IdxVal << 4, dl, MVT::i8));
18403     }
18404 
18405     // PINSR* works with constant index.
18406     if (EltVT == MVT::i32 || EltVT == MVT::i64)
18407       return Op;
18408   }
18409 
18410   return SDValue();
18411 }
18412 
18413 static SDValue LowerSCALAR_TO_VECTOR(SDValue Op, const X86Subtarget &Subtarget,
18414                                      SelectionDAG &DAG) {
18415   SDLoc dl(Op);
18416   MVT OpVT = Op.getSimpleValueType();
18417 
18418   // It's always cheaper to replace a xor+movd with xorps, and it simplifies
18419   // further combines.
18420   if (X86::isZeroNode(Op.getOperand(0)))
18421     return getZeroVector(OpVT, Subtarget, DAG, dl);
18422 
18423   // If this is a 256-bit vector result, first insert into a 128-bit
18424   // vector and then insert into the 256-bit vector.
18425   if (!OpVT.is128BitVector()) {
18426     // Insert into a 128-bit vector.
18427     unsigned SizeFactor = OpVT.getSizeInBits() / 128;
18428     MVT VT128 = MVT::getVectorVT(OpVT.getVectorElementType(),
18429                                  OpVT.getVectorNumElements() / SizeFactor);
18430 
18431     Op = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, VT128, Op.getOperand(0));
18432 
18433     // Insert the 128-bit vector.
18434     return insert128BitVector(DAG.getUNDEF(OpVT), Op, 0, DAG, dl);
18435   }
18436   assert(OpVT.is128BitVector() && OpVT.isInteger() && OpVT != MVT::v2i64 &&
18437          "Expected an SSE type!");
18438 
18439   // Pass through a v4i32 or v8i16 SCALAR_TO_VECTOR as that's what we use in
18440   // tblgen.
18441   if (OpVT == MVT::v4i32 || (OpVT == MVT::v8i16 && Subtarget.hasFP16()))
18442     return Op;
18443 
18444   SDValue AnyExt = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i32, Op.getOperand(0));
18445   return DAG.getBitcast(
18446       OpVT, DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, MVT::v4i32, AnyExt));
18447 }
18448 
18449 // Lower a node with an INSERT_SUBVECTOR opcode.  This may result in a
18450 // simple superregister reference or explicit instructions to insert
18451 // the upper bits of a vector.
18452 static SDValue LowerINSERT_SUBVECTOR(SDValue Op, const X86Subtarget &Subtarget,
18453                                      SelectionDAG &DAG) {
18454   assert(Op.getSimpleValueType().getVectorElementType() == MVT::i1);
18455 
18456   return insert1BitVector(Op, DAG, Subtarget);
18457 }
18458 
18459 static SDValue LowerEXTRACT_SUBVECTOR(SDValue Op, const X86Subtarget &Subtarget,
18460                                       SelectionDAG &DAG) {
18461   assert(Op.getSimpleValueType().getVectorElementType() == MVT::i1 &&
18462          "Only vXi1 extract_subvectors need custom lowering");
18463 
18464   SDLoc dl(Op);
18465   SDValue Vec = Op.getOperand(0);
18466   uint64_t IdxVal = Op.getConstantOperandVal(1);
18467 
18468   if (IdxVal == 0) // the operation is legal
18469     return Op;
18470 
18471   // Extend to natively supported kshift.
18472   Vec = widenMaskVector(Vec, false, Subtarget, DAG, dl);
18473 
18474   // Shift to the LSB.
18475   Vec = DAG.getNode(X86ISD::KSHIFTR, dl, Vec.getSimpleValueType(), Vec,
18476                     DAG.getTargetConstant(IdxVal, dl, MVT::i8));
18477 
18478   return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, Op.getValueType(), Vec,
18479                      DAG.getIntPtrConstant(0, dl));
18480 }
18481 
18482 // Returns the appropriate wrapper opcode for a global reference.
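// Roughly: X86ISD::WrapperRIP is selected to RIP-relative addressing (e.g.
// "lea sym(%rip), %reg"), while X86ISD::Wrapper yields an absolute or
// base-register-relative address.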
18483 unsigned X86TargetLowering::getGlobalWrapperKind(
18484     const GlobalValue *GV, const unsigned char OpFlags) const {
18485   // References to absolute symbols are never PC-relative.
18486   if (GV && GV->isAbsoluteSymbolRef())
18487     return X86ISD::Wrapper;
18488 
18489   // The following OpFlags under RIP-rel PIC use RIP.
18490   if (Subtarget.isPICStyleRIPRel() &&
18491       (OpFlags == X86II::MO_NO_FLAG || OpFlags == X86II::MO_COFFSTUB ||
18492        OpFlags == X86II::MO_DLLIMPORT))
18493     return X86ISD::WrapperRIP;
18494 
18495   // GOTPCREL references must always use RIP.
18496   if (OpFlags == X86II::MO_GOTPCREL || OpFlags == X86II::MO_GOTPCREL_NORELAX)
18497     return X86ISD::WrapperRIP;
18498 
18499   return X86ISD::Wrapper;
18500 }
18501 
18502 // ConstantPool, JumpTable, GlobalAddress, and ExternalSymbol are lowered as
18503 // their target counterparts wrapped in the X86ISD::Wrapper node. Suppose N is
18504 // one of the above-mentioned nodes. It has to be wrapped because otherwise
18505 // Select(N) returns N. So the raw TargetGlobalAddress nodes, etc. can only
18506 // be used to form an addressing mode. These wrapped nodes will be selected
18507 // into MOV32ri.
18508 SDValue
18509 X86TargetLowering::LowerConstantPool(SDValue Op, SelectionDAG &DAG) const {
18510   ConstantPoolSDNode *CP = cast<ConstantPoolSDNode>(Op);
18511 
18512   // In PIC mode (unless we're in RIPRel PIC mode) we add an offset to the
18513   // global base reg.
18514   unsigned char OpFlag = Subtarget.classifyLocalReference(nullptr);
18515 
18516   auto PtrVT = getPointerTy(DAG.getDataLayout());
18517   SDValue Result = DAG.getTargetConstantPool(
18518       CP->getConstVal(), PtrVT, CP->getAlign(), CP->getOffset(), OpFlag);
18519   SDLoc DL(CP);
18520   Result =
18521       DAG.getNode(getGlobalWrapperKind(nullptr, OpFlag), DL, PtrVT, Result);
18522   // With PIC, the address is actually $g + Offset.
18523   if (OpFlag) {
18524     Result =
18525         DAG.getNode(ISD::ADD, DL, PtrVT,
18526                     DAG.getNode(X86ISD::GlobalBaseReg, SDLoc(), PtrVT), Result);
18527   }
18528 
18529   return Result;
18530 }
18531 
18532 SDValue X86TargetLowering::LowerJumpTable(SDValue Op, SelectionDAG &DAG) const {
18533   JumpTableSDNode *JT = cast<JumpTableSDNode>(Op);
18534 
18535   // In PIC mode (unless we're in RIPRel PIC mode) we add an offset to the
18536   // global base reg.
18537   unsigned char OpFlag = Subtarget.classifyLocalReference(nullptr);
18538 
18539   auto PtrVT = getPointerTy(DAG.getDataLayout());
18540   SDValue Result = DAG.getTargetJumpTable(JT->getIndex(), PtrVT, OpFlag);
18541   SDLoc DL(JT);
18542   Result =
18543       DAG.getNode(getGlobalWrapperKind(nullptr, OpFlag), DL, PtrVT, Result);
18544 
18545   // With PIC, the address is actually $g + Offset.
18546   if (OpFlag)
18547     Result =
18548         DAG.getNode(ISD::ADD, DL, PtrVT,
18549                     DAG.getNode(X86ISD::GlobalBaseReg, SDLoc(), PtrVT), Result);
18550 
18551   return Result;
18552 }
18553 
18554 SDValue X86TargetLowering::LowerExternalSymbol(SDValue Op,
18555                                                SelectionDAG &DAG) const {
18556   return LowerGlobalOrExternal(Op, DAG, /*ForCall=*/false);
18557 }
18558 
18559 SDValue
18560 X86TargetLowering::LowerBlockAddress(SDValue Op, SelectionDAG &DAG) const {
18561   // Create the TargetBlockAddressAddress node.
18562   unsigned char OpFlags =
18563     Subtarget.classifyBlockAddressReference();
18564   const BlockAddress *BA = cast<BlockAddressSDNode>(Op)->getBlockAddress();
18565   int64_t Offset = cast<BlockAddressSDNode>(Op)->getOffset();
18566   SDLoc dl(Op);
18567   auto PtrVT = getPointerTy(DAG.getDataLayout());
18568   SDValue Result = DAG.getTargetBlockAddress(BA, PtrVT, Offset, OpFlags);
18569   Result =
18570       DAG.getNode(getGlobalWrapperKind(nullptr, OpFlags), dl, PtrVT, Result);
18571 
18572   // With PIC, the address is actually $g + Offset.
18573   if (isGlobalRelativeToPICBase(OpFlags)) {
18574     Result = DAG.getNode(ISD::ADD, dl, PtrVT,
18575                          DAG.getNode(X86ISD::GlobalBaseReg, dl, PtrVT), Result);
18576   }
18577 
18578   return Result;
18579 }
18580 
18581 /// Creates target global address or external symbol nodes for calls or
18582 /// other uses.
18583 SDValue X86TargetLowering::LowerGlobalOrExternal(SDValue Op, SelectionDAG &DAG,
18584                                                  bool ForCall) const {
18585   // Unpack the global address or external symbol.
18586   SDLoc dl(Op);
18587   const GlobalValue *GV = nullptr;
18588   int64_t Offset = 0;
18589   const char *ExternalSym = nullptr;
18590   if (const auto *G = dyn_cast<GlobalAddressSDNode>(Op)) {
18591     GV = G->getGlobal();
18592     Offset = G->getOffset();
18593   } else {
18594     const auto *ES = cast<ExternalSymbolSDNode>(Op);
18595     ExternalSym = ES->getSymbol();
18596   }
18597 
18598   // Calculate some flags for address lowering.
18599   const Module &Mod = *DAG.getMachineFunction().getFunction().getParent();
18600   unsigned char OpFlags;
18601   if (ForCall)
18602     OpFlags = Subtarget.classifyGlobalFunctionReference(GV, Mod);
18603   else
18604     OpFlags = Subtarget.classifyGlobalReference(GV, Mod);
18605   bool HasPICReg = isGlobalRelativeToPICBase(OpFlags);
18606   bool NeedsLoad = isGlobalStubReference(OpFlags);
18607 
18608   CodeModel::Model M = DAG.getTarget().getCodeModel();
18609   auto PtrVT = getPointerTy(DAG.getDataLayout());
18610   SDValue Result;
18611 
18612   if (GV) {
18613     // Create a target global address if this is a global. If possible, fold the
18614     // offset into the global address reference. Otherwise, ADD it on later.
18615     // Suppress the folding if Offset is negative: movl foo-1, %eax is not
18616     // allowed because if the address of foo is 0, the ELF R_X86_64_32
18617     // relocation will compute to a negative value, which is invalid.
18618     int64_t GlobalOffset = 0;
18619     if (OpFlags == X86II::MO_NO_FLAG && Offset >= 0 &&
18620         X86::isOffsetSuitableForCodeModel(Offset, M, true)) {
18621       std::swap(GlobalOffset, Offset);
18622     }
18623     Result = DAG.getTargetGlobalAddress(GV, dl, PtrVT, GlobalOffset, OpFlags);
18624   } else {
18625     // If this is not a global address, this must be an external symbol.
18626     Result = DAG.getTargetExternalSymbol(ExternalSym, PtrVT, OpFlags);
18627   }
18628 
18629   // If this is a direct call, avoid the wrapper if we don't need to do any
18630   // loads or adds. This allows SDAG ISel to match direct calls.
18631   if (ForCall && !NeedsLoad && !HasPICReg && Offset == 0)
18632     return Result;
18633 
18634   Result = DAG.getNode(getGlobalWrapperKind(GV, OpFlags), dl, PtrVT, Result);
18635 
18636   // With PIC, the address is actually $g + Offset.
18637   if (HasPICReg) {
18638     Result = DAG.getNode(ISD::ADD, dl, PtrVT,
18639                          DAG.getNode(X86ISD::GlobalBaseReg, dl, PtrVT), Result);
18640   }
18641 
18642   // For globals that require a load from a stub to get the address, emit the
18643   // load.
18644   if (NeedsLoad)
18645     Result = DAG.getLoad(PtrVT, dl, DAG.getEntryNode(), Result,
18646                          MachinePointerInfo::getGOT(DAG.getMachineFunction()));
18647 
18648   // If there was a non-zero offset that we didn't fold, create an explicit
18649   // addition for it.
18650   if (Offset != 0)
18651     Result = DAG.getNode(ISD::ADD, dl, PtrVT, Result,
18652                          DAG.getConstant(Offset, dl, PtrVT));
18653 
18654   return Result;
18655 }
18656 
18657 SDValue
18658 X86TargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
18659   return LowerGlobalOrExternal(Op, DAG, /*ForCall=*/false);
18660 }
18661 
18662 static SDValue
18663 GetTLSADDR(SelectionDAG &DAG, SDValue Chain, GlobalAddressSDNode *GA,
18664            SDValue *InGlue, const EVT PtrVT, unsigned ReturnReg,
18665            unsigned char OperandFlags, bool LocalDynamic = false) {
18666   MachineFrameInfo &MFI = DAG.getMachineFunction().getFrameInfo();
18667   SDVTList NodeTys = DAG.getVTList(MVT::Other, MVT::Glue);
18668   SDLoc dl(GA);
18669   SDValue TGA;
18670   bool UseTLSDESC = DAG.getTarget().useTLSDESC();
18671   if (LocalDynamic && UseTLSDESC) {
18672     TGA = DAG.getTargetExternalSymbol("_TLS_MODULE_BASE_", PtrVT, OperandFlags);
18673     auto UI = TGA->use_begin();
18674     // Reuse existing GetTLSADDR node if we can find it.
18675     if (UI != TGA->use_end())
18676       return SDValue(*UI->use_begin()->use_begin(), 0);
18677   } else {
18678     TGA = DAG.getTargetGlobalAddress(GA->getGlobal(), dl, GA->getValueType(0),
18679                                      GA->getOffset(), OperandFlags);
18680   }
18681 
18682   X86ISD::NodeType CallType = UseTLSDESC     ? X86ISD::TLSDESC
18683                               : LocalDynamic ? X86ISD::TLSBASEADDR
18684                                              : X86ISD::TLSADDR;
18685 
18686   if (InGlue) {
18687     SDValue Ops[] = { Chain,  TGA, *InGlue };
18688     Chain = DAG.getNode(CallType, dl, NodeTys, Ops);
18689   } else {
18690     SDValue Ops[]  = { Chain, TGA };
18691     Chain = DAG.getNode(CallType, dl, NodeTys, Ops);
18692   }
18693 
18694   // TLSADDR will be codegen'ed as call. Inform MFI that function has calls.
18695   MFI.setAdjustsStack(true);
18696   MFI.setHasCalls(true);
18697 
18698   SDValue Glue = Chain.getValue(1);
18699   SDValue Ret = DAG.getCopyFromReg(Chain, dl, ReturnReg, PtrVT, Glue);
18700 
18701   if (!UseTLSDESC)
18702     return Ret;
18703 
18704   const X86Subtarget &Subtarget = DAG.getSubtarget<X86Subtarget>();
18705   unsigned Seg = Subtarget.is64Bit() ? X86AS::FS : X86AS::GS;
18706 
18707   Value *Ptr = Constant::getNullValue(PointerType::get(*DAG.getContext(), Seg));
18708   SDValue Offset =
18709       DAG.getLoad(PtrVT, dl, DAG.getEntryNode(), DAG.getIntPtrConstant(0, dl),
18710                   MachinePointerInfo(Ptr));
18711   return DAG.getNode(ISD::ADD, dl, PtrVT, Ret, Offset);
18712 }
18713 
18714 // Lower ISD::GlobalTLSAddress using the "general dynamic" model, 32 bit
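// Illustrative only (subject to later selection and linker relaxation): the
// X86ISD::TLSADDR node built below typically becomes the canonical i386
// general-dynamic sequence, roughly
//   leal x@tlsgd(,%ebx,1), %eax
//   call ___tls_get_addr@PLT
// with the variable's address returned in %eax.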
18715 static SDValue
18716 LowerToTLSGeneralDynamicModel32(GlobalAddressSDNode *GA, SelectionDAG &DAG,
18717                                 const EVT PtrVT) {
18718   SDValue InGlue;
18719   SDLoc dl(GA);  // ? function entry point might be better
18720   SDValue Chain = DAG.getCopyToReg(DAG.getEntryNode(), dl, X86::EBX,
18721                                    DAG.getNode(X86ISD::GlobalBaseReg,
18722                                                SDLoc(), PtrVT), InGlue);
18723   InGlue = Chain.getValue(1);
18724 
18725   return GetTLSADDR(DAG, Chain, GA, &InGlue, PtrVT, X86::EAX, X86II::MO_TLSGD);
18726 }
18727 
18728 // Lower ISD::GlobalTLSAddress using the "general dynamic" model, 64 bit LP64
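// Illustrative only: on ELF/x86-64 the X86ISD::TLSADDR node typically becomes
// the canonical general-dynamic sequence, roughly
//   leaq x@tlsgd(%rip), %rdi
//   call __tls_get_addr@PLT
// (plus the prescribed padding prefixes), with the result returned in %rax.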
18729 static SDValue
18730 LowerToTLSGeneralDynamicModel64(GlobalAddressSDNode *GA, SelectionDAG &DAG,
18731                                 const EVT PtrVT) {
18732   return GetTLSADDR(DAG, DAG.getEntryNode(), GA, nullptr, PtrVT,
18733                     X86::RAX, X86II::MO_TLSGD);
18734 }
18735 
18736 // Lower ISD::GlobalTLSAddress using the "general dynamic" model, 64 bit ILP32
18737 static SDValue
18738 LowerToTLSGeneralDynamicModelX32(GlobalAddressSDNode *GA, SelectionDAG &DAG,
18739                                  const EVT PtrVT) {
18740   return GetTLSADDR(DAG, DAG.getEntryNode(), GA, nullptr, PtrVT,
18741                     X86::EAX, X86II::MO_TLSGD);
18742 }
18743 
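// Lower ISD::GlobalTLSAddress using the "local dynamic" model. Illustrative
// only: on ELF/x86-64 the code below typically ends up as roughly
//   leaq x@tlsld(%rip), %rdi
//   call __tls_get_addr@PLT      # module TLS base in %rax
//   leaq x@dtpoff(%rax), %rax    # add the per-variable offset
// so several TLS variables in one module can share a single call.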
18744 static SDValue LowerToTLSLocalDynamicModel(GlobalAddressSDNode *GA,
18745                                            SelectionDAG &DAG, const EVT PtrVT,
18746                                            bool Is64Bit, bool Is64BitLP64) {
18747   SDLoc dl(GA);
18748 
18749   // Get the start address of the TLS block for this module.
18750   X86MachineFunctionInfo *MFI = DAG.getMachineFunction()
18751       .getInfo<X86MachineFunctionInfo>();
18752   MFI->incNumLocalDynamicTLSAccesses();
18753 
18754   SDValue Base;
18755   if (Is64Bit) {
18756     unsigned ReturnReg = Is64BitLP64 ? X86::RAX : X86::EAX;
18757     Base = GetTLSADDR(DAG, DAG.getEntryNode(), GA, nullptr, PtrVT, ReturnReg,
18758                       X86II::MO_TLSLD, /*LocalDynamic=*/true);
18759   } else {
18760     SDValue InGlue;
18761     SDValue Chain = DAG.getCopyToReg(DAG.getEntryNode(), dl, X86::EBX,
18762         DAG.getNode(X86ISD::GlobalBaseReg, SDLoc(), PtrVT), InGlue);
18763     InGlue = Chain.getValue(1);
18764     Base = GetTLSADDR(DAG, Chain, GA, &InGlue, PtrVT, X86::EAX,
18765                       X86II::MO_TLSLDM, /*LocalDynamic=*/true);
18766   }
18767 
18768   // Note: the CleanupLocalDynamicTLSPass will remove redundant computations
18769   // of Base.
18770 
18771   // Build x@dtpoff.
18772   unsigned char OperandFlags = X86II::MO_DTPOFF;
18773   unsigned WrapperKind = X86ISD::Wrapper;
18774   SDValue TGA = DAG.getTargetGlobalAddress(GA->getGlobal(), dl,
18775                                            GA->getValueType(0),
18776                                            GA->getOffset(), OperandFlags);
18777   SDValue Offset = DAG.getNode(WrapperKind, dl, PtrVT, TGA);
18778 
18779   // Add x@dtpoff with the base.
18780   return DAG.getNode(ISD::ADD, dl, PtrVT, Offset, Base);
18781 }
18782 
18783 // Lower ISD::GlobalTLSAddress using the "initial exec" or "local exec" model.
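// For reference (illustrative 64-bit forms; the comments in the body show the
// 32-bit forms):
//   movq %fs:0, %rax
//   addq x@gottpoff(%rip), %rax   # initial exec
//   leaq x@tpoff(%rax), %rax      # local exec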
18784 static SDValue LowerToTLSExecModel(GlobalAddressSDNode *GA, SelectionDAG &DAG,
18785                                    const EVT PtrVT, TLSModel::Model model,
18786                                    bool is64Bit, bool isPIC) {
18787   SDLoc dl(GA);
18788 
18789   // Get the Thread Pointer, which is %gs:0 (32-bit) or %fs:0 (64-bit).
18790   Value *Ptr = Constant::getNullValue(
18791       PointerType::get(*DAG.getContext(), is64Bit ? 257 : 256));
18792 
18793   SDValue ThreadPointer =
18794       DAG.getLoad(PtrVT, dl, DAG.getEntryNode(), DAG.getIntPtrConstant(0, dl),
18795                   MachinePointerInfo(Ptr));
18796 
18797   unsigned char OperandFlags = 0;
18798   // Most TLS accesses are not RIP relative, even on x86-64.  One exception is
18799   // initialexec.
18800   unsigned WrapperKind = X86ISD::Wrapper;
18801   if (model == TLSModel::LocalExec) {
18802     OperandFlags = is64Bit ? X86II::MO_TPOFF : X86II::MO_NTPOFF;
18803   } else if (model == TLSModel::InitialExec) {
18804     if (is64Bit) {
18805       OperandFlags = X86II::MO_GOTTPOFF;
18806       WrapperKind = X86ISD::WrapperRIP;
18807     } else {
18808       OperandFlags = isPIC ? X86II::MO_GOTNTPOFF : X86II::MO_INDNTPOFF;
18809     }
18810   } else {
18811     llvm_unreachable("Unexpected model");
18812   }
18813 
18814   // emit "addl x@ntpoff,%eax" (local exec)
18815   // or "addl x@indntpoff,%eax" (initial exec)
18816   // or "addl x@gotntpoff(%ebx),%eax" (initial exec, 32-bit pic)
18817   SDValue TGA =
18818       DAG.getTargetGlobalAddress(GA->getGlobal(), dl, GA->getValueType(0),
18819                                  GA->getOffset(), OperandFlags);
18820   SDValue Offset = DAG.getNode(WrapperKind, dl, PtrVT, TGA);
18821 
18822   if (model == TLSModel::InitialExec) {
18823     if (isPIC && !is64Bit) {
18824       Offset = DAG.getNode(ISD::ADD, dl, PtrVT,
18825                            DAG.getNode(X86ISD::GlobalBaseReg, SDLoc(), PtrVT),
18826                            Offset);
18827     }
18828 
18829     Offset = DAG.getLoad(PtrVT, dl, DAG.getEntryNode(), Offset,
18830                          MachinePointerInfo::getGOT(DAG.getMachineFunction()));
18831   }
18832 
18833   // The address of the thread local variable is the add of the thread
18834   // pointer with the offset of the variable.
18835   return DAG.getNode(ISD::ADD, dl, PtrVT, ThreadPointer, Offset);
18836 }
18837 
18838 SDValue
18839 X86TargetLowering::LowerGlobalTLSAddress(SDValue Op, SelectionDAG &DAG) const {
18840 
18841   GlobalAddressSDNode *GA = cast<GlobalAddressSDNode>(Op);
18842 
18843   if (DAG.getTarget().useEmulatedTLS())
18844     return LowerToTLSEmulatedModel(GA, DAG);
18845 
18846   const GlobalValue *GV = GA->getGlobal();
18847   auto PtrVT = getPointerTy(DAG.getDataLayout());
18848   bool PositionIndependent = isPositionIndependent();
18849 
18850   if (Subtarget.isTargetELF()) {
18851     TLSModel::Model model = DAG.getTarget().getTLSModel(GV);
18852     switch (model) {
18853       case TLSModel::GeneralDynamic:
18854         if (Subtarget.is64Bit()) {
18855           if (Subtarget.isTarget64BitLP64())
18856             return LowerToTLSGeneralDynamicModel64(GA, DAG, PtrVT);
18857           return LowerToTLSGeneralDynamicModelX32(GA, DAG, PtrVT);
18858         }
18859         return LowerToTLSGeneralDynamicModel32(GA, DAG, PtrVT);
18860       case TLSModel::LocalDynamic:
18861         return LowerToTLSLocalDynamicModel(GA, DAG, PtrVT, Subtarget.is64Bit(),
18862                                            Subtarget.isTarget64BitLP64());
18863       case TLSModel::InitialExec:
18864       case TLSModel::LocalExec:
18865         return LowerToTLSExecModel(GA, DAG, PtrVT, model, Subtarget.is64Bit(),
18866                                    PositionIndependent);
18867     }
18868     llvm_unreachable("Unknown TLS model.");
18869   }
18870 
18871   if (Subtarget.isTargetDarwin()) {
18872     // Darwin only has one model of TLS.  Lower to that.
18873     unsigned char OpFlag = 0;
18874     unsigned WrapperKind = 0;
18875 
18876     // In PIC mode (unless we're in RIPRel PIC mode) we add an offset to the
18877     // global base reg.
18878     bool PIC32 = PositionIndependent && !Subtarget.is64Bit();
18879     if (PIC32) {
18880       OpFlag = X86II::MO_TLVP_PIC_BASE;
18881       WrapperKind = X86ISD::Wrapper;
18882     } else {
18883       OpFlag = X86II::MO_TLVP;
18884       WrapperKind = X86ISD::WrapperRIP;
18885     }
18886     SDLoc DL(Op);
18887     SDValue Result = DAG.getTargetGlobalAddress(GA->getGlobal(), DL,
18888                                                 GA->getValueType(0),
18889                                                 GA->getOffset(), OpFlag);
18890     SDValue Offset = DAG.getNode(WrapperKind, DL, PtrVT, Result);
18891 
18892     // With PIC32, the address is actually $g + Offset.
18893     if (PIC32)
18894       Offset = DAG.getNode(ISD::ADD, DL, PtrVT,
18895                            DAG.getNode(X86ISD::GlobalBaseReg, SDLoc(), PtrVT),
18896                            Offset);
18897 
18898     // Lowering the machine ISD will make sure everything is in the right
18899     // location.
18900     SDValue Chain = DAG.getEntryNode();
18901     SDVTList NodeTys = DAG.getVTList(MVT::Other, MVT::Glue);
18902     Chain = DAG.getCALLSEQ_START(Chain, 0, 0, DL);
18903     SDValue Args[] = { Chain, Offset };
18904     Chain = DAG.getNode(X86ISD::TLSCALL, DL, NodeTys, Args);
18905     Chain = DAG.getCALLSEQ_END(Chain, 0, 0, Chain.getValue(1), DL);
18906 
18907     // TLSCALL will be codegen'ed as call. Inform MFI that function has calls.
18908     MachineFrameInfo &MFI = DAG.getMachineFunction().getFrameInfo();
18909     MFI.setAdjustsStack(true);
18910 
18911     // And our return value (tls address) is in the standard call return value
18912     // location.
18913     unsigned Reg = Subtarget.is64Bit() ? X86::RAX : X86::EAX;
18914     return DAG.getCopyFromReg(Chain, DL, Reg, PtrVT, Chain.getValue(1));
18915   }
18916 
18917   if (Subtarget.isOSWindows()) {
18918     // Just use the implicit TLS architecture
18919     // Need to generate something similar to:
18920     //   mov     rdx, qword [gs:abs 58H]; Load pointer to ThreadLocalStorage
18921     //                                  ; from TEB
18922     //   mov     ecx, dword [rel _tls_index]: Load index (from C runtime)
18923     //   mov     rcx, qword [rdx+rcx*8]
18924     //   mov     eax, .tls$:tlsvar
18925     //   [rax+rcx] contains the address
18926     // Windows 64bit: gs:0x58
18927     // Windows 32bit: fs:__tls_array
18928 
18929     SDLoc dl(GA);
18930     SDValue Chain = DAG.getEntryNode();
18931 
18932     // Get the Thread Pointer, which is %fs:__tls_array (32-bit) or
18933     // %gs:0x58 (64-bit). On MinGW, __tls_array is not available, so directly
18934     // use its literal value of 0x2C.
18935     Value *Ptr = Constant::getNullValue(
18936         Subtarget.is64Bit() ? PointerType::get(*DAG.getContext(), 256)
18937                             : PointerType::get(*DAG.getContext(), 257));
18938 
18939     SDValue TlsArray = Subtarget.is64Bit()
18940                            ? DAG.getIntPtrConstant(0x58, dl)
18941                            : (Subtarget.isTargetWindowsGNU()
18942                                   ? DAG.getIntPtrConstant(0x2C, dl)
18943                                   : DAG.getExternalSymbol("_tls_array", PtrVT));
18944 
18945     SDValue ThreadPointer =
18946         DAG.getLoad(PtrVT, dl, Chain, TlsArray, MachinePointerInfo(Ptr));
18947 
18948     SDValue res;
18949     if (GV->getThreadLocalMode() == GlobalVariable::LocalExecTLSModel) {
18950       res = ThreadPointer;
18951     } else {
18952       // Load the _tls_index variable
18953       SDValue IDX = DAG.getExternalSymbol("_tls_index", PtrVT);
18954       if (Subtarget.is64Bit())
18955         IDX = DAG.getExtLoad(ISD::ZEXTLOAD, dl, PtrVT, Chain, IDX,
18956                              MachinePointerInfo(), MVT::i32);
18957       else
18958         IDX = DAG.getLoad(PtrVT, dl, Chain, IDX, MachinePointerInfo());
18959 
18960       const DataLayout &DL = DAG.getDataLayout();
18961       SDValue Scale =
18962           DAG.getConstant(Log2_64_Ceil(DL.getPointerSize()), dl, MVT::i8);
18963       IDX = DAG.getNode(ISD::SHL, dl, PtrVT, IDX, Scale);
18964 
18965       res = DAG.getNode(ISD::ADD, dl, PtrVT, ThreadPointer, IDX);
18966     }
18967 
18968     res = DAG.getLoad(PtrVT, dl, Chain, res, MachinePointerInfo());
18969 
18970     // Get the offset of start of .tls section
18971     SDValue TGA = DAG.getTargetGlobalAddress(GA->getGlobal(), dl,
18972                                              GA->getValueType(0),
18973                                              GA->getOffset(), X86II::MO_SECREL);
18974     SDValue Offset = DAG.getNode(X86ISD::Wrapper, dl, PtrVT, TGA);
18975 
18976     // The address of the thread local variable is the add of the thread
18977     // pointer with the offset of the variable.
18978     return DAG.getNode(ISD::ADD, dl, PtrVT, res, Offset);
18979   }
18980 
18981   llvm_unreachable("TLS not implemented for this target.");
18982 }
18983 
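// Returns whether a TLS access to GV can be folded into an addressing mode.
// Illustrative only: for initial-exec/local-exec on x86-64 a load such as
//   movl %fs:x@tpoff, %eax
// folds the thread pointer (%fs) and the offset into one memory operand.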
18984 bool X86TargetLowering::addressingModeSupportsTLS(const GlobalValue &GV) const {
18985   if (Subtarget.is64Bit() && Subtarget.isTargetELF()) {
18986     const TargetMachine &TM = getTargetMachine();
18987     TLSModel::Model Model = TM.getTLSModel(&GV);
18988     switch (Model) {
18989     case TLSModel::LocalExec:
18990     case TLSModel::InitialExec:
18991       // We can include the %fs segment register in addressing modes.
18992       return true;
18993     case TLSModel::LocalDynamic:
18994     case TLSModel::GeneralDynamic:
18995       // These models do not result in %fs relative addresses unless
18996       // TLS descriptors are used.
18997       //
18998       // Even in the case of TLS descriptors we currently have no way to model
18999       // the difference between the %fs access and the computations needed for
19000       // the offset, and returning `true` for TLS-desc currently duplicates
19001       // both, which is detrimental :-/
19002       return false;
19003     }
19004   }
19005   return false;
19006 }
19007 
19008 /// Lower SRA_PARTS and friends, which return two i32 values
19009 /// and take a 2 x i32 value to shift plus a shift amount.
19010 /// TODO: Can this be moved to general expansion code?
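/// For example (illustrative), an i64 'sra' on a 32-bit target reaches here as
///   SRA_PARTS lo, hi, amt
/// and expandShiftParts produces the generic two-result expansion (selected
/// later to the usual shld/shrd-style sequence).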
19011 static SDValue LowerShiftParts(SDValue Op, SelectionDAG &DAG) {
19012   SDValue Lo, Hi;
19013   DAG.getTargetLoweringInfo().expandShiftParts(Op.getNode(), Lo, Hi, DAG);
19014   return DAG.getMergeValues({Lo, Hi}, SDLoc(Op));
19015 }
19016 
19017 // Try to use a packed vector operation to handle i64 on 32-bit targets when
19018 // AVX512DQ is enabled.
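// A minimal sketch of the transform (non-strict case, no VLX):
//   (f64 (sint_to_fp i64 %x))
//     -> (extract_elt (v8f64 (sint_to_fp (v8i64 (scalar_to_vector %x)))), 0)
// which lets the conversion use the AVX512DQ vector instructions (e.g.
// VCVTQQ2PD) instead of a libcall or x87.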
19019 static SDValue LowerI64IntToFP_AVX512DQ(SDValue Op, const SDLoc &dl,
19020                                         SelectionDAG &DAG,
19021                                         const X86Subtarget &Subtarget) {
19022   assert((Op.getOpcode() == ISD::SINT_TO_FP ||
19023           Op.getOpcode() == ISD::STRICT_SINT_TO_FP ||
19024           Op.getOpcode() == ISD::STRICT_UINT_TO_FP ||
19025           Op.getOpcode() == ISD::UINT_TO_FP) &&
19026          "Unexpected opcode!");
19027   bool IsStrict = Op->isStrictFPOpcode();
19028   unsigned OpNo = IsStrict ? 1 : 0;
19029   SDValue Src = Op.getOperand(OpNo);
19030   MVT SrcVT = Src.getSimpleValueType();
19031   MVT VT = Op.getSimpleValueType();
19032 
19033   if (!Subtarget.hasDQI() || SrcVT != MVT::i64 || Subtarget.is64Bit() ||
19034       (VT != MVT::f32 && VT != MVT::f64))
19035     return SDValue();
19036 
19037   // Pack the i64 into a vector, do the operation and extract.
19038 
19039   // Using 256-bit to ensure result is 128-bits for f32 case.
19040   // Use a 256-bit vector to ensure the result is 128 bits for the f32 case.
19041   MVT VecInVT = MVT::getVectorVT(MVT::i64, NumElts);
19042   MVT VecVT = MVT::getVectorVT(VT, NumElts);
19043 
19044   SDValue InVec = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, VecInVT, Src);
19045   if (IsStrict) {
19046     SDValue CvtVec = DAG.getNode(Op.getOpcode(), dl, {VecVT, MVT::Other},
19047                                  {Op.getOperand(0), InVec});
19048     SDValue Chain = CvtVec.getValue(1);
19049     SDValue Value = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT, CvtVec,
19050                                 DAG.getIntPtrConstant(0, dl));
19051     return DAG.getMergeValues({Value, Chain}, dl);
19052   }
19053 
19054   SDValue CvtVec = DAG.getNode(Op.getOpcode(), dl, VecVT, InVec);
19055 
19056   return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT, CvtVec,
19057                      DAG.getIntPtrConstant(0, dl));
19058 }
19059 
19060 // Try to use a packed vector operation to handle i64 on 32-bit targets.
19061 static SDValue LowerI64IntToFP16(SDValue Op, const SDLoc &dl, SelectionDAG &DAG,
19062                                  const X86Subtarget &Subtarget) {
19063   assert((Op.getOpcode() == ISD::SINT_TO_FP ||
19064           Op.getOpcode() == ISD::STRICT_SINT_TO_FP ||
19065           Op.getOpcode() == ISD::STRICT_UINT_TO_FP ||
19066           Op.getOpcode() == ISD::UINT_TO_FP) &&
19067          "Unexpected opcode!");
19068   bool IsStrict = Op->isStrictFPOpcode();
19069   SDValue Src = Op.getOperand(IsStrict ? 1 : 0);
19070   MVT SrcVT = Src.getSimpleValueType();
19071   MVT VT = Op.getSimpleValueType();
19072 
19073   if (SrcVT != MVT::i64 || Subtarget.is64Bit() || VT != MVT::f16)
19074     return SDValue();
19075 
19076   // Pack the i64 into a vector, do the operation and extract.
19077 
19078   assert(Subtarget.hasFP16() && "Expected FP16");
19079 
19080   SDValue InVec = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, MVT::v2i64, Src);
19081   if (IsStrict) {
19082     SDValue CvtVec = DAG.getNode(Op.getOpcode(), dl, {MVT::v2f16, MVT::Other},
19083                                  {Op.getOperand(0), InVec});
19084     SDValue Chain = CvtVec.getValue(1);
19085     SDValue Value = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT, CvtVec,
19086                                 DAG.getIntPtrConstant(0, dl));
19087     return DAG.getMergeValues({Value, Chain}, dl);
19088   }
19089 
19090   SDValue CvtVec = DAG.getNode(Op.getOpcode(), dl, MVT::v2f16, InVec);
19091 
19092   return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT, CvtVec,
19093                      DAG.getIntPtrConstant(0, dl));
19094 }
19095 
19096 static bool useVectorCast(unsigned Opcode, MVT FromVT, MVT ToVT,
19097                           const X86Subtarget &Subtarget) {
19098   switch (Opcode) {
19099     case ISD::SINT_TO_FP:
19100       // TODO: Handle wider types with AVX/AVX512.
19101       if (!Subtarget.hasSSE2() || FromVT != MVT::v4i32)
19102         return false;
19103       // CVTDQ2PS or (V)CVTDQ2PD
19104       return ToVT == MVT::v4f32 || (Subtarget.hasAVX() && ToVT == MVT::v4f64);
19105 
19106     case ISD::UINT_TO_FP:
19107       // TODO: Handle wider types and i64 elements.
19108       if (!Subtarget.hasAVX512() || FromVT != MVT::v4i32)
19109         return false;
19110       // VCVTUDQ2PS or VCVTUDQ2PD
19111       return ToVT == MVT::v4f32 || ToVT == MVT::v4f64;
19112 
19113     default:
19114       return false;
19115   }
19116 }
19117 
19118 /// Given a scalar cast operation that is extracted from a vector, try to
19119 /// vectorize the cast op followed by extraction. This will avoid an expensive
19120 /// round-trip between XMM and GPR.
19121 static SDValue vectorizeExtractedCast(SDValue Cast, const SDLoc &DL,
19122                                       SelectionDAG &DAG,
19123                                       const X86Subtarget &Subtarget) {
19124   // TODO: This could be enhanced to handle smaller integer types by peeking
19125   // through an extend.
19126   SDValue Extract = Cast.getOperand(0);
19127   MVT DestVT = Cast.getSimpleValueType();
19128   if (Extract.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
19129       !isa<ConstantSDNode>(Extract.getOperand(1)))
19130     return SDValue();
19131 
19132   // See if we have a 128-bit vector cast op for this type of cast.
19133   SDValue VecOp = Extract.getOperand(0);
19134   MVT FromVT = VecOp.getSimpleValueType();
19135   unsigned NumEltsInXMM = 128 / FromVT.getScalarSizeInBits();
19136   MVT Vec128VT = MVT::getVectorVT(FromVT.getScalarType(), NumEltsInXMM);
19137   MVT ToVT = MVT::getVectorVT(DestVT, NumEltsInXMM);
19138   if (!useVectorCast(Cast.getOpcode(), Vec128VT, ToVT, Subtarget))
19139     return SDValue();
19140 
19141   // If we are extracting from a non-zero element, first shuffle the source
19142   // vector to allow extracting from element zero.
19143   if (!isNullConstant(Extract.getOperand(1))) {
19144     SmallVector<int, 16> Mask(FromVT.getVectorNumElements(), -1);
19145     Mask[0] = Extract.getConstantOperandVal(1);
19146     VecOp = DAG.getVectorShuffle(FromVT, DL, VecOp, DAG.getUNDEF(FromVT), Mask);
19147   }
19148   // If the source vector is wider than 128-bits, extract the low part. Do not
19149   // create an unnecessarily wide vector cast op.
19150   if (FromVT != Vec128VT)
19151     VecOp = extract128BitVector(VecOp, 0, DAG, DL);
19152 
19153   // cast (extelt V, 0) --> extelt (cast (extract_subv V)), 0
19154   // cast (extelt V, C) --> extelt (cast (extract_subv (shuffle V, [C...]))), 0
19155   SDValue VCast = DAG.getNode(Cast.getOpcode(), DL, ToVT, VecOp);
19156   return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, DestVT, VCast,
19157                      DAG.getIntPtrConstant(0, DL));
19158 }
19159 
19160 /// Given a scalar cast to FP with a cast to integer operand (almost an ftrunc),
19161 /// try to vectorize the cast ops. This will avoid an expensive round-trip
19162 /// between XMM and GPR.
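/// A sketch of the intended codegen (f32 case, SSE2): for
///   %i = fptosi float %x to i32
///   %r = sitofp i32 %i to float
/// we emit roughly
///   cvttps2dq %xmm0, %xmm0
///   cvtdq2ps  %xmm0, %xmm0
/// keeping the value in XMM instead of bouncing through a GPR.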
19163 static SDValue lowerFPToIntToFP(SDValue CastToFP, const SDLoc &DL,
19164                                 SelectionDAG &DAG,
19165                                 const X86Subtarget &Subtarget) {
19166   // TODO: Allow FP_TO_UINT.
19167   SDValue CastToInt = CastToFP.getOperand(0);
19168   MVT VT = CastToFP.getSimpleValueType();
19169   if (CastToInt.getOpcode() != ISD::FP_TO_SINT || VT.isVector())
19170     return SDValue();
19171 
19172   MVT IntVT = CastToInt.getSimpleValueType();
19173   SDValue X = CastToInt.getOperand(0);
19174   MVT SrcVT = X.getSimpleValueType();
19175   if (SrcVT != MVT::f32 && SrcVT != MVT::f64)
19176     return SDValue();
19177 
19178   // See if we have 128-bit vector cast instructions for this type of cast.
19179   // We need cvttps2dq/cvttpd2dq and cvtdq2ps/cvtdq2pd.
19180   if (!Subtarget.hasSSE2() || (VT != MVT::f32 && VT != MVT::f64) ||
19181       IntVT != MVT::i32)
19182     return SDValue();
19183 
19184   unsigned SrcSize = SrcVT.getSizeInBits();
19185   unsigned IntSize = IntVT.getSizeInBits();
19186   unsigned VTSize = VT.getSizeInBits();
19187   MVT VecSrcVT = MVT::getVectorVT(SrcVT, 128 / SrcSize);
19188   MVT VecIntVT = MVT::getVectorVT(IntVT, 128 / IntSize);
19189   MVT VecVT = MVT::getVectorVT(VT, 128 / VTSize);
19190 
19191   // We need target-specific opcodes if this is v2f64 -> v4i32 -> v2f64.
19192   unsigned ToIntOpcode =
19193       SrcSize != IntSize ? X86ISD::CVTTP2SI : (unsigned)ISD::FP_TO_SINT;
19194   unsigned ToFPOpcode =
19195       IntSize != VTSize ? X86ISD::CVTSI2P : (unsigned)ISD::SINT_TO_FP;
19196 
19197   // sint_to_fp (fp_to_sint X) --> extelt (sint_to_fp (fp_to_sint (s2v X))), 0
19198   //
19199   // We are not defining the high elements (for example, by zeroing them) because
19200   // that could nullify any performance advantage that we hoped to gain from
19201   // this vector op hack. We do not expect any adverse effects (like denorm
19202   // penalties) with cast ops.
19203   SDValue ZeroIdx = DAG.getIntPtrConstant(0, DL);
19204   SDValue VecX = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VecSrcVT, X);
19205   SDValue VCastToInt = DAG.getNode(ToIntOpcode, DL, VecIntVT, VecX);
19206   SDValue VCastToFP = DAG.getNode(ToFPOpcode, DL, VecVT, VCastToInt);
19207   return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, VCastToFP, ZeroIdx);
19208 }
19209 
19210 static SDValue lowerINT_TO_FP_vXi64(SDValue Op, const SDLoc &DL,
19211                                     SelectionDAG &DAG,
19212                                     const X86Subtarget &Subtarget) {
19213   bool IsStrict = Op->isStrictFPOpcode();
19214   MVT VT = Op->getSimpleValueType(0);
19215   SDValue Src = Op->getOperand(IsStrict ? 1 : 0);
19216 
19217   if (Subtarget.hasDQI()) {
19218     assert(!Subtarget.hasVLX() && "Unexpected features");
19219 
19220     assert((Src.getSimpleValueType() == MVT::v2i64 ||
19221             Src.getSimpleValueType() == MVT::v4i64) &&
19222            "Unsupported custom type");
19223 
19224     // With AVX512DQ, but not VLX we need to widen to get a 512-bit result type.
19225     assert((VT == MVT::v4f32 || VT == MVT::v2f64 || VT == MVT::v4f64) &&
19226            "Unexpected VT!");
19227     MVT WideVT = VT == MVT::v4f32 ? MVT::v8f32 : MVT::v8f64;
19228 
19229     // Need to concat with zero vector for strict fp to avoid spurious
19230     // exceptions.
19231     SDValue Tmp = IsStrict ? DAG.getConstant(0, DL, MVT::v8i64)
19232                            : DAG.getUNDEF(MVT::v8i64);
19233     Src = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, MVT::v8i64, Tmp, Src,
19234                       DAG.getIntPtrConstant(0, DL));
19235     SDValue Res, Chain;
19236     if (IsStrict) {
19237       Res = DAG.getNode(Op.getOpcode(), DL, {WideVT, MVT::Other},
19238                         {Op->getOperand(0), Src});
19239       Chain = Res.getValue(1);
19240     } else {
19241       Res = DAG.getNode(Op.getOpcode(), DL, WideVT, Src);
19242     }
19243 
19244     Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Res,
19245                       DAG.getIntPtrConstant(0, DL));
19246 
19247     if (IsStrict)
19248       return DAG.getMergeValues({Res, Chain}, DL);
19249     return Res;
19250   }
19251 
19252   bool IsSigned = Op->getOpcode() == ISD::SINT_TO_FP ||
19253                   Op->getOpcode() == ISD::STRICT_SINT_TO_FP;
19254   if (VT != MVT::v4f32 || IsSigned)
19255     return SDValue();
19256 
19257   SDValue Zero = DAG.getConstant(0, DL, MVT::v4i64);
19258   SDValue One  = DAG.getConstant(1, DL, MVT::v4i64);
19259   SDValue Sign = DAG.getNode(ISD::OR, DL, MVT::v4i64,
19260                              DAG.getNode(ISD::SRL, DL, MVT::v4i64, Src, One),
19261                              DAG.getNode(ISD::AND, DL, MVT::v4i64, Src, One));
19262   SDValue IsNeg = DAG.getSetCC(DL, MVT::v4i64, Src, Zero, ISD::SETLT);
19263   SDValue SignSrc = DAG.getSelect(DL, MVT::v4i64, IsNeg, Sign, Src);
19264   SmallVector<SDValue, 4> SignCvts(4);
19265   SmallVector<SDValue, 4> Chains(4);
19266   for (int i = 0; i != 4; ++i) {
19267     SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i64, SignSrc,
19268                               DAG.getIntPtrConstant(i, DL));
19269     if (IsStrict) {
19270       SignCvts[i] =
19271           DAG.getNode(ISD::STRICT_SINT_TO_FP, DL, {MVT::f32, MVT::Other},
19272                       {Op.getOperand(0), Elt});
19273       Chains[i] = SignCvts[i].getValue(1);
19274     } else {
19275       SignCvts[i] = DAG.getNode(ISD::SINT_TO_FP, DL, MVT::f32, Elt);
19276     }
19277   }
19278   SDValue SignCvt = DAG.getBuildVector(VT, DL, SignCvts);
19279 
19280   SDValue Slow, Chain;
19281   if (IsStrict) {
19282     Chain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Chains);
19283     Slow = DAG.getNode(ISD::STRICT_FADD, DL, {MVT::v4f32, MVT::Other},
19284                        {Chain, SignCvt, SignCvt});
19285     Chain = Slow.getValue(1);
19286   } else {
19287     Slow = DAG.getNode(ISD::FADD, DL, MVT::v4f32, SignCvt, SignCvt);
19288   }
19289 
19290   IsNeg = DAG.getNode(ISD::TRUNCATE, DL, MVT::v4i32, IsNeg);
19291   SDValue Cvt = DAG.getSelect(DL, MVT::v4f32, IsNeg, Slow, SignCvt);
19292 
19293   if (IsStrict)
19294     return DAG.getMergeValues({Cvt, Chain}, DL);
19295 
19296   return Cvt;
19297 }
19298 
19299 static SDValue promoteXINT_TO_FP(SDValue Op, const SDLoc &dl,
19300                                  SelectionDAG &DAG) {
19301   bool IsStrict = Op->isStrictFPOpcode();
19302   SDValue Src = Op.getOperand(IsStrict ? 1 : 0);
19303   SDValue Chain = IsStrict ? Op->getOperand(0) : DAG.getEntryNode();
19304   MVT VT = Op.getSimpleValueType();
19305   MVT NVT = VT.isVector() ? VT.changeVectorElementType(MVT::f32) : MVT::f32;
19306 
19307   SDValue Rnd = DAG.getIntPtrConstant(0, dl);
19308   if (IsStrict)
19309     return DAG.getNode(
19310         ISD::STRICT_FP_ROUND, dl, {VT, MVT::Other},
19311         {Chain,
19312          DAG.getNode(Op.getOpcode(), dl, {NVT, MVT::Other}, {Chain, Src}),
19313          Rnd});
19314   return DAG.getNode(ISD::FP_ROUND, dl, VT,
19315                      DAG.getNode(Op.getOpcode(), dl, NVT, Src), Rnd);
19316 }
19317 
19318 static bool isLegalConversion(MVT VT, bool IsSigned,
19319                               const X86Subtarget &Subtarget) {
19320   if (VT == MVT::v4i32 && Subtarget.hasSSE2() && IsSigned)
19321     return true;
19322   if (VT == MVT::v8i32 && Subtarget.hasAVX() && IsSigned)
19323     return true;
19324   if (Subtarget.hasVLX() && (VT == MVT::v4i32 || VT == MVT::v8i32))
19325     return true;
19326   if (Subtarget.useAVX512Regs()) {
19327     if (VT == MVT::v16i32)
19328       return true;
19329     if (VT == MVT::v8i64 && Subtarget.hasDQI())
19330       return true;
19331   }
19332   if (Subtarget.hasDQI() && Subtarget.hasVLX() &&
19333       (VT == MVT::v2i64 || VT == MVT::v4i64))
19334     return true;
19335   return false;
19336 }
19337 
19338 SDValue X86TargetLowering::LowerSINT_TO_FP(SDValue Op,
19339                                            SelectionDAG &DAG) const {
19340   bool IsStrict = Op->isStrictFPOpcode();
19341   unsigned OpNo = IsStrict ? 1 : 0;
19342   SDValue Src = Op.getOperand(OpNo);
19343   SDValue Chain = IsStrict ? Op->getOperand(0) : DAG.getEntryNode();
19344   MVT SrcVT = Src.getSimpleValueType();
19345   MVT VT = Op.getSimpleValueType();
19346   SDLoc dl(Op);
19347 
19348   if (isSoftF16(VT, Subtarget))
19349     return promoteXINT_TO_FP(Op, dl, DAG);
19350   else if (isLegalConversion(SrcVT, true, Subtarget))
19351     return Op;
19352 
19353   if (Subtarget.isTargetWin64() && SrcVT == MVT::i128)
19354     return LowerWin64_INT128_TO_FP(Op, DAG);
19355 
19356   if (SDValue Extract = vectorizeExtractedCast(Op, dl, DAG, Subtarget))
19357     return Extract;
19358 
19359   if (SDValue R = lowerFPToIntToFP(Op, dl, DAG, Subtarget))
19360     return R;
19361 
19362   if (SrcVT.isVector()) {
19363     if (SrcVT == MVT::v2i32 && VT == MVT::v2f64) {
19364       // Note: Since v2f64 is a legal type, we don't need to zero extend the
19365       // source for strict FP.
19366       if (IsStrict)
19367         return DAG.getNode(
19368             X86ISD::STRICT_CVTSI2P, dl, {VT, MVT::Other},
19369             {Chain, DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4i32, Src,
19370                                 DAG.getUNDEF(SrcVT))});
19371       return DAG.getNode(X86ISD::CVTSI2P, dl, VT,
19372                          DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4i32, Src,
19373                                      DAG.getUNDEF(SrcVT)));
19374     }
19375     if (SrcVT == MVT::v2i64 || SrcVT == MVT::v4i64)
19376       return lowerINT_TO_FP_vXi64(Op, dl, DAG, Subtarget);
19377 
19378     return SDValue();
19379   }
19380 
19381   assert(SrcVT <= MVT::i64 && SrcVT >= MVT::i16 &&
19382          "Unknown SINT_TO_FP to lower!");
19383 
19384   bool UseSSEReg = isScalarFPTypeInSSEReg(VT);
19385 
19386   // These are really Legal; return the operand so the caller accepts it as
19387   // Legal.
19388   if (SrcVT == MVT::i32 && UseSSEReg)
19389     return Op;
19390   if (SrcVT == MVT::i64 && UseSSEReg && Subtarget.is64Bit())
19391     return Op;
19392 
19393   if (SDValue V = LowerI64IntToFP_AVX512DQ(Op, dl, DAG, Subtarget))
19394     return V;
19395   if (SDValue V = LowerI64IntToFP16(Op, dl, DAG, Subtarget))
19396     return V;
19397 
19398   // SSE doesn't have an i16 conversion so we need to promote.
19399   if (SrcVT == MVT::i16 && (UseSSEReg || VT == MVT::f128)) {
19400     SDValue Ext = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::i32, Src);
19401     if (IsStrict)
19402       return DAG.getNode(ISD::STRICT_SINT_TO_FP, dl, {VT, MVT::Other},
19403                          {Chain, Ext});
19404 
19405     return DAG.getNode(ISD::SINT_TO_FP, dl, VT, Ext);
19406   }
19407 
19408   if (VT == MVT::f128 || !Subtarget.hasX87())
19409     return SDValue();
19410 
19411   SDValue ValueToStore = Src;
19412   if (SrcVT == MVT::i64 && Subtarget.hasSSE2() && !Subtarget.is64Bit())
19413     // Bitcasting to f64 here allows us to do a single 64-bit store from
19414     // an SSE register, avoiding the store forwarding penalty that would come
19415     // with two 32-bit stores.
19416     ValueToStore = DAG.getBitcast(MVT::f64, ValueToStore);
19417 
19418   unsigned Size = SrcVT.getStoreSize();
19419   Align Alignment(Size);
19420   MachineFunction &MF = DAG.getMachineFunction();
19421   auto PtrVT = getPointerTy(MF.getDataLayout());
19422   int SSFI = MF.getFrameInfo().CreateStackObject(Size, Alignment, false);
19423   MachinePointerInfo MPI =
19424       MachinePointerInfo::getFixedStack(DAG.getMachineFunction(), SSFI);
19425   SDValue StackSlot = DAG.getFrameIndex(SSFI, PtrVT);
19426   Chain = DAG.getStore(Chain, dl, ValueToStore, StackSlot, MPI, Alignment);
19427   std::pair<SDValue, SDValue> Tmp =
19428       BuildFILD(VT, SrcVT, dl, Chain, StackSlot, MPI, Alignment, DAG);
19429 
19430   if (IsStrict)
19431     return DAG.getMergeValues({Tmp.first, Tmp.second}, dl);
19432 
19433   return Tmp.first;
19434 }
19435 
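// Builds an x87 FILD of SrcVT from memory. Illustrative only (stack offsets
// are made up): when the destination must end up in an SSE register the
// sequence below is roughly
//   fildll (%esp)           # x87 integer load
//   fstpl  8(%esp)          # spill as f64 through the stack
//   movsd  8(%esp), %xmm0   # reload into SSE
// since there is no direct x87 -> XMM move.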
19436 std::pair<SDValue, SDValue> X86TargetLowering::BuildFILD(
19437     EVT DstVT, EVT SrcVT, const SDLoc &DL, SDValue Chain, SDValue Pointer,
19438     MachinePointerInfo PtrInfo, Align Alignment, SelectionDAG &DAG) const {
19439   // Build the FILD
19440   SDVTList Tys;
19441   bool useSSE = isScalarFPTypeInSSEReg(DstVT);
19442   if (useSSE)
19443     Tys = DAG.getVTList(MVT::f80, MVT::Other);
19444   else
19445     Tys = DAG.getVTList(DstVT, MVT::Other);
19446 
19447   SDValue FILDOps[] = {Chain, Pointer};
19448   SDValue Result =
19449       DAG.getMemIntrinsicNode(X86ISD::FILD, DL, Tys, FILDOps, SrcVT, PtrInfo,
19450                               Alignment, MachineMemOperand::MOLoad);
19451   Chain = Result.getValue(1);
19452 
19453   if (useSSE) {
19454     MachineFunction &MF = DAG.getMachineFunction();
19455     unsigned SSFISize = DstVT.getStoreSize();
19456     int SSFI =
19457         MF.getFrameInfo().CreateStackObject(SSFISize, Align(SSFISize), false);
19458     auto PtrVT = getPointerTy(MF.getDataLayout());
19459     SDValue StackSlot = DAG.getFrameIndex(SSFI, PtrVT);
19460     Tys = DAG.getVTList(MVT::Other);
19461     SDValue FSTOps[] = {Chain, Result, StackSlot};
19462     MachineMemOperand *StoreMMO = DAG.getMachineFunction().getMachineMemOperand(
19463         MachinePointerInfo::getFixedStack(DAG.getMachineFunction(), SSFI),
19464         MachineMemOperand::MOStore, SSFISize, Align(SSFISize));
19465 
19466     Chain =
19467         DAG.getMemIntrinsicNode(X86ISD::FST, DL, Tys, FSTOps, DstVT, StoreMMO);
19468     Result = DAG.getLoad(
19469         DstVT, DL, Chain, StackSlot,
19470         MachinePointerInfo::getFixedStack(DAG.getMachineFunction(), SSFI));
19471     Chain = Result.getValue(1);
19472   }
19473 
19474   return { Result, Chain };
19475 }
19476 
19477 /// Horizontal vector math instructions may be slower than normal math with
19478 /// shuffles. Limit horizontal op codegen based on size/speed trade-offs, uarch
19479 /// implementation, and likely shuffle complexity of the alternate sequence.
19480 static bool shouldUseHorizontalOp(bool IsSingleSource, SelectionDAG &DAG,
19481                                   const X86Subtarget &Subtarget) {
19482   bool IsOptimizingSize = DAG.shouldOptForSize();
19483   bool HasFastHOps = Subtarget.hasFastHorizontalOps();
19484   return !IsSingleSource || IsOptimizingSize || HasFastHOps;
19485 }
19486 
19487 /// 64-bit unsigned integer to double expansion.
19488 static SDValue LowerUINT_TO_FP_i64(SDValue Op, const SDLoc &dl,
19489                                    SelectionDAG &DAG,
19490                                    const X86Subtarget &Subtarget) {
19491   // We can't use this algorithm for strict fp. It produces -0.0 instead of +0.0
19492   // when converting 0 when rounding toward negative infinity. Caller will
19493   // when converting 0 while rounding toward negative infinity. The caller will
19494   // fall back to Expand when i64 is legal, or use FILD in 32-bit mode.
19495   // This algorithm is not obvious. Here it is what we're trying to output:
19496   // This algorithm is not obvious. Here is what we're trying to output:
19497      movq       %rax,  %xmm0
19498      punpckldq  (c0),  %xmm0  // c0: (uint4){ 0x43300000U, 0x45300000U, 0U, 0U }
19499      subpd      (c1),  %xmm0  // c1: (double2){ 0x1.0p52, 0x1.0p52 * 0x1.0p32 }
19500      #ifdef __SSE3__
19501        haddpd   %xmm0, %xmm0
19502      #else
19503        pshufd   $0x4e, %xmm0, %xmm1
19504        addpd    %xmm1, %xmm0
19505      #endif
19506   */
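  // Why this works (sketch): with x = lo + hi * 2^32, the punpckldq builds the
  // doubles (2^52 + lo) and (2^84 + hi * 2^32) purely by bit manipulation
  // (0x433 and 0x453 are the biased exponents of 2^52 and 2^84). Subtracting
  // the constant pair {2^52, 2^84} leaves {lo, hi * 2^32} exactly, and the
  // final horizontal add recombines them into (double)x.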
19507 
19508   LLVMContext *Context = DAG.getContext();
19509 
19510   // Build some magic constants.
19511   static const uint32_t CV0[] = { 0x43300000, 0x45300000, 0, 0 };
19512   Constant *C0 = ConstantDataVector::get(*Context, CV0);
19513   auto PtrVT = DAG.getTargetLoweringInfo().getPointerTy(DAG.getDataLayout());
19514   SDValue CPIdx0 = DAG.getConstantPool(C0, PtrVT, Align(16));
19515 
19516   SmallVector<Constant*,2> CV1;
19517   CV1.push_back(
19518     ConstantFP::get(*Context, APFloat(APFloat::IEEEdouble(),
19519                                       APInt(64, 0x4330000000000000ULL))));
19520   CV1.push_back(
19521     ConstantFP::get(*Context, APFloat(APFloat::IEEEdouble(),
19522                                       APInt(64, 0x4530000000000000ULL))));
19523   Constant *C1 = ConstantVector::get(CV1);
19524   SDValue CPIdx1 = DAG.getConstantPool(C1, PtrVT, Align(16));
19525 
19526   // Load the 64-bit value into an XMM register.
19527   SDValue XR1 =
19528       DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, MVT::v2i64, Op.getOperand(0));
19529   SDValue CLod0 = DAG.getLoad(
19530       MVT::v4i32, dl, DAG.getEntryNode(), CPIdx0,
19531       MachinePointerInfo::getConstantPool(DAG.getMachineFunction()), Align(16));
19532   SDValue Unpck1 =
19533       getUnpackl(DAG, dl, MVT::v4i32, DAG.getBitcast(MVT::v4i32, XR1), CLod0);
19534 
19535   SDValue CLod1 = DAG.getLoad(
19536       MVT::v2f64, dl, CLod0.getValue(1), CPIdx1,
19537       MachinePointerInfo::getConstantPool(DAG.getMachineFunction()), Align(16));
19538   SDValue XR2F = DAG.getBitcast(MVT::v2f64, Unpck1);
19539   // TODO: Are there any fast-math-flags to propagate here?
19540   SDValue Sub = DAG.getNode(ISD::FSUB, dl, MVT::v2f64, XR2F, CLod1);
19541   SDValue Result;
19542 
19543   if (Subtarget.hasSSE3() &&
19544       shouldUseHorizontalOp(true, DAG, Subtarget)) {
19545     Result = DAG.getNode(X86ISD::FHADD, dl, MVT::v2f64, Sub, Sub);
19546   } else {
19547     SDValue Shuffle = DAG.getVectorShuffle(MVT::v2f64, dl, Sub, Sub, {1,-1});
19548     Result = DAG.getNode(ISD::FADD, dl, MVT::v2f64, Shuffle, Sub);
19549   }
19550   Result = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::f64, Result,
19551                        DAG.getIntPtrConstant(0, dl));
19552   return Result;
19553 }
19554 
19555 /// 32-bit unsigned integer to float expansion.
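/// A sketch of the trick used below: OR-ing the zero-extended u32 into the low
/// mantissa bits of the double 0x1.0p52 yields the value (2^52 + x) exactly,
/// so subtracting the same 0x1.0p52 bias recovers (double)x without an
/// integer-to-FP instruction.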
19556 static SDValue LowerUINT_TO_FP_i32(SDValue Op, const SDLoc &dl,
19557                                    SelectionDAG &DAG,
19558                                    const X86Subtarget &Subtarget) {
19559   unsigned OpNo = Op.getNode()->isStrictFPOpcode() ? 1 : 0;
19560   // FP constant to bias correct the final result.
19561   SDValue Bias = DAG.getConstantFP(
19562       llvm::bit_cast<double>(0x4330000000000000ULL), dl, MVT::f64);
19563 
19564   // Load the 32-bit value into an XMM register.
19565   SDValue Load =
19566       DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, MVT::v4i32, Op.getOperand(OpNo));
19567 
19568   // Zero out the upper parts of the register.
19569   Load = getShuffleVectorZeroOrUndef(Load, 0, true, Subtarget, DAG);
19570 
19571   // Or the load with the bias.
19572   SDValue Or = DAG.getNode(
19573       ISD::OR, dl, MVT::v2i64,
19574       DAG.getBitcast(MVT::v2i64, Load),
19575       DAG.getBitcast(MVT::v2i64,
19576                      DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, MVT::v2f64, Bias)));
19577   Or =
19578       DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::f64,
19579                   DAG.getBitcast(MVT::v2f64, Or), DAG.getIntPtrConstant(0, dl));
19580 
19581   if (Op.getNode()->isStrictFPOpcode()) {
19582     // Subtract the bias.
19583     // TODO: Are there any fast-math-flags to propagate here?
19584     SDValue Chain = Op.getOperand(0);
19585     SDValue Sub = DAG.getNode(ISD::STRICT_FSUB, dl, {MVT::f64, MVT::Other},
19586                               {Chain, Or, Bias});
19587 
19588     if (Op.getValueType() == Sub.getValueType())
19589       return Sub;
19590 
19591     // Handle final rounding.
19592     std::pair<SDValue, SDValue> ResultPair = DAG.getStrictFPExtendOrRound(
19593         Sub, Sub.getValue(1), dl, Op.getSimpleValueType());
19594 
19595     return DAG.getMergeValues({ResultPair.first, ResultPair.second}, dl);
19596   }
19597 
19598   // Subtract the bias.
19599   // TODO: Are there any fast-math-flags to propagate here?
19600   SDValue Sub = DAG.getNode(ISD::FSUB, dl, MVT::f64, Or, Bias);
19601 
19602   // Handle final rounding.
19603   return DAG.getFPExtendOrRound(Sub, dl, Op.getSimpleValueType());
19604 }
19605 
19606 static SDValue lowerUINT_TO_FP_v2i32(SDValue Op, const SDLoc &DL,
19607                                      SelectionDAG &DAG,
19608                                      const X86Subtarget &Subtarget) {
19609   if (Op.getSimpleValueType() != MVT::v2f64)
19610     return SDValue();
19611 
19612   bool IsStrict = Op->isStrictFPOpcode();
19613 
19614   SDValue N0 = Op.getOperand(IsStrict ? 1 : 0);
19615   assert(N0.getSimpleValueType() == MVT::v2i32 && "Unexpected input type");
19616 
19617   if (Subtarget.hasAVX512()) {
19618     if (!Subtarget.hasVLX()) {
19619       // Let generic type legalization widen this.
19620       if (!IsStrict)
19621         return SDValue();
19622       // Otherwise pad the integer input with 0s and widen the operation.
19623       N0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v4i32, N0,
19624                        DAG.getConstant(0, DL, MVT::v2i32));
19625       SDValue Res = DAG.getNode(Op->getOpcode(), DL, {MVT::v4f64, MVT::Other},
19626                                 {Op.getOperand(0), N0});
19627       SDValue Chain = Res.getValue(1);
19628       Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v2f64, Res,
19629                         DAG.getIntPtrConstant(0, DL));
19630       return DAG.getMergeValues({Res, Chain}, DL);
19631     }
19632 
19633     // Legalize to v4i32 type.
19634     N0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v4i32, N0,
19635                      DAG.getUNDEF(MVT::v2i32));
19636     if (IsStrict)
19637       return DAG.getNode(X86ISD::STRICT_CVTUI2P, DL, {MVT::v2f64, MVT::Other},
19638                          {Op.getOperand(0), N0});
19639     return DAG.getNode(X86ISD::CVTUI2P, DL, MVT::v2f64, N0);
19640   }
19641 
19642   // Zero extend to 2i64, OR with the floating point representation of 2^52.
19643   // This gives us the floating point equivalent of 2^52 + the i32 integer
19644   // since double has 52-bits of mantissa. Then subtract 2^52 in floating
19645   // point leaving just our i32 integers in double format.
19646   SDValue ZExtIn = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::v2i64, N0);
19647   SDValue VBias = DAG.getConstantFP(
19648       llvm::bit_cast<double>(0x4330000000000000ULL), DL, MVT::v2f64);
19649   SDValue Or = DAG.getNode(ISD::OR, DL, MVT::v2i64, ZExtIn,
19650                            DAG.getBitcast(MVT::v2i64, VBias));
19651   Or = DAG.getBitcast(MVT::v2f64, Or);
19652 
19653   if (IsStrict)
19654     return DAG.getNode(ISD::STRICT_FSUB, DL, {MVT::v2f64, MVT::Other},
19655                        {Op.getOperand(0), Or, VBias});
19656   return DAG.getNode(ISD::FSUB, DL, MVT::v2f64, Or, VBias);
19657 }
19658 
19659 static SDValue lowerUINT_TO_FP_vXi32(SDValue Op, const SDLoc &DL,
19660                                      SelectionDAG &DAG,
19661                                      const X86Subtarget &Subtarget) {
19662   bool IsStrict = Op->isStrictFPOpcode();
19663   SDValue V = Op->getOperand(IsStrict ? 1 : 0);
19664   MVT VecIntVT = V.getSimpleValueType();
19665   assert((VecIntVT == MVT::v4i32 || VecIntVT == MVT::v8i32) &&
19666          "Unsupported custom type");
19667 
19668   if (Subtarget.hasAVX512()) {
19669     // With AVX512, but not VLX we need to widen to get a 512-bit result type.
19670     assert(!Subtarget.hasVLX() && "Unexpected features");
19671     MVT VT = Op->getSimpleValueType(0);
19672 
19673     // v8i32->v8f64 is legal with AVX512 so just return it.
19674     if (VT == MVT::v8f64)
19675       return Op;
19676 
19677     assert((VT == MVT::v4f32 || VT == MVT::v8f32 || VT == MVT::v4f64) &&
19678            "Unexpected VT!");
19679     MVT WideVT = VT == MVT::v4f64 ? MVT::v8f64 : MVT::v16f32;
19680     MVT WideIntVT = VT == MVT::v4f64 ? MVT::v8i32 : MVT::v16i32;
19681     // Need to concat with zero vector for strict fp to avoid spurious
19682     // exceptions.
19683     SDValue Tmp =
19684         IsStrict ? DAG.getConstant(0, DL, WideIntVT) : DAG.getUNDEF(WideIntVT);
19685     V = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, WideIntVT, Tmp, V,
19686                     DAG.getIntPtrConstant(0, DL));
19687     SDValue Res, Chain;
19688     if (IsStrict) {
19689       Res = DAG.getNode(ISD::STRICT_UINT_TO_FP, DL, {WideVT, MVT::Other},
19690                         {Op->getOperand(0), V});
19691       Chain = Res.getValue(1);
19692     } else {
19693       Res = DAG.getNode(ISD::UINT_TO_FP, DL, WideVT, V);
19694     }
19695 
19696     Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Res,
19697                       DAG.getIntPtrConstant(0, DL));
19698 
19699     if (IsStrict)
19700       return DAG.getMergeValues({Res, Chain}, DL);
19701     return Res;
19702   }
19703 
19704   if (Subtarget.hasAVX() && VecIntVT == MVT::v4i32 &&
19705       Op->getSimpleValueType(0) == MVT::v4f64) {
19706     SDValue ZExtIn = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::v4i64, V);
19707     Constant *Bias = ConstantFP::get(
19708         *DAG.getContext(),
19709         APFloat(APFloat::IEEEdouble(), APInt(64, 0x4330000000000000ULL)));
19710     auto PtrVT = DAG.getTargetLoweringInfo().getPointerTy(DAG.getDataLayout());
19711     SDValue CPIdx = DAG.getConstantPool(Bias, PtrVT, Align(8));
19712     SDVTList Tys = DAG.getVTList(MVT::v4f64, MVT::Other);
19713     SDValue Ops[] = {DAG.getEntryNode(), CPIdx};
19714     SDValue VBias = DAG.getMemIntrinsicNode(
19715         X86ISD::VBROADCAST_LOAD, DL, Tys, Ops, MVT::f64,
19716         MachinePointerInfo::getConstantPool(DAG.getMachineFunction()), Align(8),
19717         MachineMemOperand::MOLoad);
19718 
19719     SDValue Or = DAG.getNode(ISD::OR, DL, MVT::v4i64, ZExtIn,
19720                              DAG.getBitcast(MVT::v4i64, VBias));
19721     Or = DAG.getBitcast(MVT::v4f64, Or);
19722 
19723     if (IsStrict)
19724       return DAG.getNode(ISD::STRICT_FSUB, DL, {MVT::v4f64, MVT::Other},
19725                          {Op.getOperand(0), Or, VBias});
19726     return DAG.getNode(ISD::FSUB, DL, MVT::v4f64, Or, VBias);
19727   }
19728 
19729   // The algorithm is the following:
19730   // #ifdef __SSE4_1__
19731   //     uint4 lo = _mm_blend_epi16( v, (uint4) 0x4b000000, 0xaa);
19732   //     uint4 hi = _mm_blend_epi16( _mm_srli_epi32(v,16),
19733   //                                 (uint4) 0x53000000, 0xaa);
19734   // #else
19735   //     uint4 lo = (v & (uint4) 0xffff) | (uint4) 0x4b000000;
19736   //     uint4 hi = (v >> 16) | (uint4) 0x53000000;
19737   // #endif
19738   //     float4 fhi = (float4) hi - (0x1.0p39f + 0x1.0p23f);
19739   //     return (float4) lo + fhi;
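  //
  // Why the constants work (sketch): 0x4b000000 is 2^23 as a float and
  // 0x53000000 is 2^39, so 'lo' holds the low 16 bits as (2^23 + lo16) and
  // 'hi' holds the high 16 bits as (2^39 + hi16 * 2^16). Subtracting
  // (0x1.0p39f + 0x1.0p23f) from 'hi' and then adding 'lo' cancels both
  // biases and reconstructs the original value.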
19740 
19741   bool Is128 = VecIntVT == MVT::v4i32;
19742   MVT VecFloatVT = Is128 ? MVT::v4f32 : MVT::v8f32;
19743   // If we convert to something other than the supported type, e.g., to v4f64,
19744   // abort early.
19745   if (VecFloatVT != Op->getSimpleValueType(0))
19746     return SDValue();
19747 
19748   // In the #ifdef/#else code, we have in common:
19749   // - The vector of constants:
19750   // -- 0x4b000000
19751   // -- 0x53000000
19752   // - A shift:
19753   // -- v >> 16
19754 
19755   // Create the splat vector for 0x4b000000.
19756   SDValue VecCstLow = DAG.getConstant(0x4b000000, DL, VecIntVT);
19757   // Create the splat vector for 0x53000000.
19758   SDValue VecCstHigh = DAG.getConstant(0x53000000, DL, VecIntVT);
19759 
19760   // Create the right shift.
19761   SDValue VecCstShift = DAG.getConstant(16, DL, VecIntVT);
19762   SDValue HighShift = DAG.getNode(ISD::SRL, DL, VecIntVT, V, VecCstShift);
19763 
19764   SDValue Low, High;
19765   if (Subtarget.hasSSE41()) {
19766     MVT VecI16VT = Is128 ? MVT::v8i16 : MVT::v16i16;
19767     //     uint4 lo = _mm_blend_epi16( v, (uint4) 0x4b000000, 0xaa);
19768     SDValue VecCstLowBitcast = DAG.getBitcast(VecI16VT, VecCstLow);
19769     SDValue VecBitcast = DAG.getBitcast(VecI16VT, V);
19770     // Low will be bitcasted right away, so do not bother bitcasting back to its
19771     // original type.
19772     Low = DAG.getNode(X86ISD::BLENDI, DL, VecI16VT, VecBitcast,
19773                       VecCstLowBitcast, DAG.getTargetConstant(0xaa, DL, MVT::i8));
19774     //     uint4 hi = _mm_blend_epi16( _mm_srli_epi32(v,16),
19775     //                                 (uint4) 0x53000000, 0xaa);
19776     SDValue VecCstHighBitcast = DAG.getBitcast(VecI16VT, VecCstHigh);
19777     SDValue VecShiftBitcast = DAG.getBitcast(VecI16VT, HighShift);
19778     // High will be bitcasted right away, so do not bother bitcasting back to
19779     // its original type.
19780     High = DAG.getNode(X86ISD::BLENDI, DL, VecI16VT, VecShiftBitcast,
19781                        VecCstHighBitcast, DAG.getTargetConstant(0xaa, DL, MVT::i8));
19782   } else {
19783     SDValue VecCstMask = DAG.getConstant(0xffff, DL, VecIntVT);
19784     //     uint4 lo = (v & (uint4) 0xffff) | (uint4) 0x4b000000;
19785     SDValue LowAnd = DAG.getNode(ISD::AND, DL, VecIntVT, V, VecCstMask);
19786     Low = DAG.getNode(ISD::OR, DL, VecIntVT, LowAnd, VecCstLow);
19787 
19788     //     uint4 hi = (v >> 16) | (uint4) 0x53000000;
19789     High = DAG.getNode(ISD::OR, DL, VecIntVT, HighShift, VecCstHigh);
19790   }
19791 
19792   // Create the vector constant for (0x1.0p39f + 0x1.0p23f).
19793   SDValue VecCstFSub = DAG.getConstantFP(
19794       APFloat(APFloat::IEEEsingle(), APInt(32, 0x53000080)), DL, VecFloatVT);
19795 
19796   //     float4 fhi = (float4) hi - (0x1.0p39f + 0x1.0p23f);
19797   // NOTE: By using fsub of a positive constant instead of fadd of a negative
19798   // constant, we avoid reassociation in MachineCombiner when unsafe-fp-math is
19799   // enabled. See PR24512.
19800   SDValue HighBitcast = DAG.getBitcast(VecFloatVT, High);
19801   // TODO: Are there any fast-math-flags to propagate here?
19802   //     (float4) lo;
19803   SDValue LowBitcast = DAG.getBitcast(VecFloatVT, Low);
19804   //     return (float4) lo + fhi;
19805   if (IsStrict) {
19806     SDValue FHigh = DAG.getNode(ISD::STRICT_FSUB, DL, {VecFloatVT, MVT::Other},
19807                                 {Op.getOperand(0), HighBitcast, VecCstFSub});
19808     return DAG.getNode(ISD::STRICT_FADD, DL, {VecFloatVT, MVT::Other},
19809                        {FHigh.getValue(1), LowBitcast, FHigh});
19810   }
19811 
19812   SDValue FHigh =
19813       DAG.getNode(ISD::FSUB, DL, VecFloatVT, HighBitcast, VecCstFSub);
19814   return DAG.getNode(ISD::FADD, DL, VecFloatVT, LowBitcast, FHigh);
19815 }
19816 
19817 static SDValue lowerUINT_TO_FP_vec(SDValue Op, const SDLoc &dl, SelectionDAG &DAG,
19818                                    const X86Subtarget &Subtarget) {
19819   unsigned OpNo = Op.getNode()->isStrictFPOpcode() ? 1 : 0;
19820   SDValue N0 = Op.getOperand(OpNo);
19821   MVT SrcVT = N0.getSimpleValueType();
19822 
19823   switch (SrcVT.SimpleTy) {
19824   default:
19825     llvm_unreachable("Custom UINT_TO_FP is not supported!");
19826   case MVT::v2i32:
19827     return lowerUINT_TO_FP_v2i32(Op, dl, DAG, Subtarget);
19828   case MVT::v4i32:
19829   case MVT::v8i32:
19830     return lowerUINT_TO_FP_vXi32(Op, dl, DAG, Subtarget);
19831   case MVT::v2i64:
19832   case MVT::v4i64:
19833     return lowerINT_TO_FP_vXi64(Op, dl, DAG, Subtarget);
19834   }
19835 }
19836 
19837 SDValue X86TargetLowering::LowerUINT_TO_FP(SDValue Op,
19838                                            SelectionDAG &DAG) const {
19839   bool IsStrict = Op->isStrictFPOpcode();
19840   unsigned OpNo = IsStrict ? 1 : 0;
19841   SDValue Src = Op.getOperand(OpNo);
19842   SDLoc dl(Op);
19843   auto PtrVT = getPointerTy(DAG.getDataLayout());
19844   MVT SrcVT = Src.getSimpleValueType();
19845   MVT DstVT = Op->getSimpleValueType(0);
19846   SDValue Chain = IsStrict ? Op.getOperand(0) : DAG.getEntryNode();
19847 
19848   // Bail out when we don't have native conversion instructions.
19849   if (DstVT == MVT::f128)
19850     return SDValue();
19851 
19852   if (isSoftF16(DstVT, Subtarget))
19853     return promoteXINT_TO_FP(Op, dl, DAG);
19854   else if (isLegalConversion(SrcVT, false, Subtarget))
19855     return Op;
19856 
19857   if (DstVT.isVector())
19858     return lowerUINT_TO_FP_vec(Op, dl, DAG, Subtarget);
19859 
19860   if (Subtarget.isTargetWin64() && SrcVT == MVT::i128)
19861     return LowerWin64_INT128_TO_FP(Op, DAG);
19862 
19863   if (SDValue Extract = vectorizeExtractedCast(Op, dl, DAG, Subtarget))
19864     return Extract;
19865 
19866   if (Subtarget.hasAVX512() && isScalarFPTypeInSSEReg(DstVT) &&
19867       (SrcVT == MVT::i32 || (SrcVT == MVT::i64 && Subtarget.is64Bit()))) {
19868     // Conversions from unsigned i32 to f32/f64 are legal,
19869     // using VCVTUSI2SS/SD.  Same for i64 in 64-bit mode.
19870     return Op;
19871   }
19872 
19873   // Promote i32 to i64 and use a signed conversion on 64-bit targets.
19874   if (SrcVT == MVT::i32 && Subtarget.is64Bit()) {
19875     Src = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i64, Src);
19876     if (IsStrict)
19877       return DAG.getNode(ISD::STRICT_SINT_TO_FP, dl, {DstVT, MVT::Other},
19878                          {Chain, Src});
19879     return DAG.getNode(ISD::SINT_TO_FP, dl, DstVT, Src);
19880   }
19881 
19882   if (SDValue V = LowerI64IntToFP_AVX512DQ(Op, dl, DAG, Subtarget))
19883     return V;
19884   if (SDValue V = LowerI64IntToFP16(Op, dl, DAG, Subtarget))
19885     return V;
19886 
19887   // The transform for i64->f64 isn't correct for 0 when rounding to negative
19888   // infinity. It produces -0.0, so disable under strictfp.
19889   if (SrcVT == MVT::i64 && DstVT == MVT::f64 && Subtarget.hasSSE2() &&
19890       !IsStrict)
19891     return LowerUINT_TO_FP_i64(Op, dl, DAG, Subtarget);
19892   // The transform for i32->f64/f32 isn't correct for 0 when rounding to
19893   // negative infinity, so disable it under strictfp and use FILD instead.
19894   if (SrcVT == MVT::i32 && Subtarget.hasSSE2() && DstVT != MVT::f80 &&
19895       !IsStrict)
19896     return LowerUINT_TO_FP_i32(Op, dl, DAG, Subtarget);
19897   if (Subtarget.is64Bit() && SrcVT == MVT::i64 &&
19898       (DstVT == MVT::f32 || DstVT == MVT::f64))
19899     return SDValue();
19900 
19901   // Make a 64-bit buffer, and use it to build an FILD.
19902   SDValue StackSlot = DAG.CreateStackTemporary(MVT::i64, 8);
19903   int SSFI = cast<FrameIndexSDNode>(StackSlot)->getIndex();
19904   Align SlotAlign(8);
19905   MachinePointerInfo MPI =
19906       MachinePointerInfo::getFixedStack(DAG.getMachineFunction(), SSFI);
19907   if (SrcVT == MVT::i32) {
19908     SDValue OffsetSlot =
19909         DAG.getMemBasePlusOffset(StackSlot, TypeSize::getFixed(4), dl);
19910     SDValue Store1 = DAG.getStore(Chain, dl, Src, StackSlot, MPI, SlotAlign);
19911     SDValue Store2 = DAG.getStore(Store1, dl, DAG.getConstant(0, dl, MVT::i32),
19912                                   OffsetSlot, MPI.getWithOffset(4), SlotAlign);
19913     std::pair<SDValue, SDValue> Tmp =
19914         BuildFILD(DstVT, MVT::i64, dl, Store2, StackSlot, MPI, SlotAlign, DAG);
19915     if (IsStrict)
19916       return DAG.getMergeValues({Tmp.first, Tmp.second}, dl);
19917 
19918     return Tmp.first;
19919   }
19920 
19921   assert(SrcVT == MVT::i64 && "Unexpected type in UINT_TO_FP");
19922   SDValue ValueToStore = Src;
19923   if (isScalarFPTypeInSSEReg(Op.getValueType()) && !Subtarget.is64Bit()) {
19924     // Bitcasting to f64 here allows us to do a single 64-bit store from
19925     // an SSE register, avoiding the store forwarding penalty that would come
19926     // with two 32-bit stores.
19927     ValueToStore = DAG.getBitcast(MVT::f64, ValueToStore);
19928   }
19929   SDValue Store =
19930       DAG.getStore(Chain, dl, ValueToStore, StackSlot, MPI, SlotAlign);
19931   // For i64 source, we need to add the appropriate power of 2 if the input
19932   // was negative. We must be careful to do the computation in x87 extended
19933   // precision, not in SSE.
19934   SDVTList Tys = DAG.getVTList(MVT::f80, MVT::Other);
19935   SDValue Ops[] = {Store, StackSlot};
19936   SDValue Fild =
19937       DAG.getMemIntrinsicNode(X86ISD::FILD, dl, Tys, Ops, MVT::i64, MPI,
19938                               SlotAlign, MachineMemOperand::MOLoad);
19939   Chain = Fild.getValue(1);
19940 
19941   // Check whether the sign bit is set.
19942   SDValue SignSet = DAG.getSetCC(
19943       dl, getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), MVT::i64),
19944       Op.getOperand(OpNo), DAG.getConstant(0, dl, MVT::i64), ISD::SETLT);
19945 
19946   // Build a 64 bit pair (FF, 0) in the constant pool, with FF in the hi bits.
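  // (Sketch, for exposition only: 0x5F800000 is 2^64 as an IEEE single, so the
  // select below picks either 0.0 or 2^64 to add to the signed FILD result.)
  //   long double r = (long double)(int64_t)v;  // FILD interprets v as signed
  //   if ((int64_t)v < 0)
  //     r += 0x1.0p64L;                         // fudge: undo the sign wrap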
19947   APInt FF(64, 0x5F80000000000000ULL);
19948   SDValue FudgePtr =
19949       DAG.getConstantPool(ConstantInt::get(*DAG.getContext(), FF), PtrVT);
19950   Align CPAlignment = cast<ConstantPoolSDNode>(FudgePtr)->getAlign();
19951 
19952   // Get a pointer to FF if the sign bit was set, or to 0 otherwise.
19953   SDValue Zero = DAG.getIntPtrConstant(0, dl);
19954   SDValue Four = DAG.getIntPtrConstant(4, dl);
19955   SDValue Offset = DAG.getSelect(dl, Zero.getValueType(), SignSet, Four, Zero);
19956   FudgePtr = DAG.getNode(ISD::ADD, dl, PtrVT, FudgePtr, Offset);
19957 
19958   // Load the value out, extending it from f32 to f80.
19959   SDValue Fudge = DAG.getExtLoad(
19960       ISD::EXTLOAD, dl, MVT::f80, Chain, FudgePtr,
19961       MachinePointerInfo::getConstantPool(DAG.getMachineFunction()), MVT::f32,
19962       CPAlignment);
19963   Chain = Fudge.getValue(1);
19964   // Extend everything to 80 bits to force it to be done on x87.
19965   // TODO: Are there any fast-math-flags to propagate here?
19966   if (IsStrict) {
19967     unsigned Opc = ISD::STRICT_FADD;
19968     // Windows needs the precision control changed to 80 bits around this add.
19969     if (Subtarget.isOSWindows() && DstVT == MVT::f32)
19970       Opc = X86ISD::STRICT_FP80_ADD;
19971 
19972     SDValue Add =
19973         DAG.getNode(Opc, dl, {MVT::f80, MVT::Other}, {Chain, Fild, Fudge});
19974     // STRICT_FP_ROUND can't handle equal types.
19975     if (DstVT == MVT::f80)
19976       return Add;
19977     return DAG.getNode(ISD::STRICT_FP_ROUND, dl, {DstVT, MVT::Other},
19978                        {Add.getValue(1), Add, DAG.getIntPtrConstant(0, dl)});
19979   }
19980   unsigned Opc = ISD::FADD;
19981   // Windows needs the precision control changed to 80 bits around this add.
19982   if (Subtarget.isOSWindows() && DstVT == MVT::f32)
19983     Opc = X86ISD::FP80_ADD;
19984 
19985   SDValue Add = DAG.getNode(Opc, dl, MVT::f80, Fild, Fudge);
19986   return DAG.getNode(ISD::FP_ROUND, dl, DstVT, Add,
19987                      DAG.getIntPtrConstant(0, dl, /*isTarget=*/true));
19988 }
19989 
19990 // If the given FP_TO_SINT (IsSigned) or FP_TO_UINT (!IsSigned) operation
19991 // is legal, or has an fp128 or f16 source (which needs to be promoted to f32),
19992 // just return an SDValue().
19993 // Otherwise it is assumed to be a conversion from one of f32, f64 or f80
19994 // to i16, i32 or i64, and we lower it to a legal sequence and return the
19995 // result.
19996 SDValue X86TargetLowering::FP_TO_INTHelper(SDValue Op, SelectionDAG &DAG,
19997                                            bool IsSigned,
19998                                            SDValue &Chain) const {
19999   bool IsStrict = Op->isStrictFPOpcode();
20000   SDLoc DL(Op);
20001 
20002   EVT DstTy = Op.getValueType();
20003   SDValue Value = Op.getOperand(IsStrict ? 1 : 0);
20004   EVT TheVT = Value.getValueType();
20005   auto PtrVT = getPointerTy(DAG.getDataLayout());
20006 
20007   if (TheVT != MVT::f32 && TheVT != MVT::f64 && TheVT != MVT::f80) {
20008     // f16 must be promoted before using the lowering in this routine.
20009     // fp128 does not use this lowering.
20010     return SDValue();
20011   }
20012 
20013   // If using FIST to compute an unsigned i64, we'll need some fixup
20014   // to handle values above the maximum signed i64.  A FIST is always
20015   // used for the 32-bit subtarget, but also for f80 on a 64-bit target.
20016   bool UnsignedFixup = !IsSigned && DstTy == MVT::i64;
20017 
20018   // FIXME: This does not generate an invalid exception if the input does not
20019   // fit in i32. PR44019
20020   if (!IsSigned && DstTy != MVT::i64) {
20021     // Replace the fp-to-uint32 operation with an fp-to-sint64 FIST.
20022     // The low 32 bits of the fist result will have the correct uint32 result.
20023     assert(DstTy == MVT::i32 && "Unexpected FP_TO_UINT");
20024     DstTy = MVT::i64;
20025   }
20026 
20027   assert(DstTy.getSimpleVT() <= MVT::i64 &&
20028          DstTy.getSimpleVT() >= MVT::i16 &&
20029          "Unknown FP_TO_INT to lower!");
20030 
20031   // We lower FP->int64 into FISTP64 followed by a load from a temporary
20032   // stack slot.
20033   MachineFunction &MF = DAG.getMachineFunction();
20034   unsigned MemSize = DstTy.getStoreSize();
20035   int SSFI =
20036       MF.getFrameInfo().CreateStackObject(MemSize, Align(MemSize), false);
20037   SDValue StackSlot = DAG.getFrameIndex(SSFI, PtrVT);
20038 
20039   Chain = IsStrict ? Op.getOperand(0) : DAG.getEntryNode();
20040 
20041   SDValue Adjust; // 0x0 or 0x80000000, for result sign bit adjustment.
20042 
20043   if (UnsignedFixup) {
20044     //
20045     // Conversion to unsigned i64 is implemented with a select,
20046     // depending on whether the source value fits in the range
20047     // of a signed i64.  Let Thresh be the FP equivalent of
20048     // 0x8000000000000000ULL.
20049     //
20050     //  Adjust = (Value >= Thresh) ? 0x80000000 : 0;
20051     //  FltOfs = (Value >= Thresh) ? Thresh : 0;
20052     //  FistSrc = (Value - FltOfs);
20053     //  Fist-to-mem64 FistSrc
20054     //  Add 0 or 0x800...0ULL to the 64-bit result, which is equivalent
20055     //  to XOR'ing the high 32 bits with Adjust.
20056     //
20057     // Being a power of 2, Thresh is exactly representable in all FP formats.
20058     // For X87 we'd like to use the smallest FP type for this constant, but
20059     // for DAG type consistency we have to match the FP operand type.
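    //
    // Illustrative scalar equivalent (exposition only):
    //   uint64_t fp_to_u64(double v) {
    //     double Thresh = 0x1p63;
    //     uint64_t Adjust = (v >= Thresh) ? 0x8000000000000000ULL : 0;
    //     double FltOfs   = (v >= Thresh) ? Thresh : 0.0;
    //     return (uint64_t)(int64_t)(v - FltOfs) ^ Adjust;  // FIST + XOR fixup
    //   }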
20060 
20061     APFloat Thresh(APFloat::IEEEsingle(), APInt(32, 0x5f000000));
20062     LLVM_ATTRIBUTE_UNUSED APFloat::opStatus Status = APFloat::opOK;
20063     bool LosesInfo = false;
20064     if (TheVT == MVT::f64)
20065       // The rounding mode is irrelevant as the conversion should be exact.
20066       Status = Thresh.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
20067                               &LosesInfo);
20068     else if (TheVT == MVT::f80)
20069       Status = Thresh.convert(APFloat::x87DoubleExtended(),
20070                               APFloat::rmNearestTiesToEven, &LosesInfo);
20071 
20072     assert(Status == APFloat::opOK && !LosesInfo &&
20073            "FP conversion should have been exact");
20074 
20075     SDValue ThreshVal = DAG.getConstantFP(Thresh, DL, TheVT);
20076 
20077     EVT ResVT = getSetCCResultType(DAG.getDataLayout(),
20078                                    *DAG.getContext(), TheVT);
20079     SDValue Cmp;
20080     if (IsStrict) {
20081       Cmp = DAG.getSetCC(DL, ResVT, Value, ThreshVal, ISD::SETGE, Chain,
20082                          /*IsSignaling*/ true);
20083       Chain = Cmp.getValue(1);
20084     } else {
20085       Cmp = DAG.getSetCC(DL, ResVT, Value, ThreshVal, ISD::SETGE);
20086     }
20087 
20088     // Our preferred lowering of
20089     //
20090     // (Value >= Thresh) ? 0x8000000000000000ULL : 0
20091     //
20092     // is
20093     //
20094     // (Value >= Thresh) << 63
20095     //
20096     // but since we can get here after LegalOperations, DAGCombine might do the
20097     // wrong thing if we create a select. So, directly create the preferred
20098     // version.
20099     SDValue Zext = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, Cmp);
20100     SDValue Const63 = DAG.getConstant(63, DL, MVT::i8);
20101     Adjust = DAG.getNode(ISD::SHL, DL, MVT::i64, Zext, Const63);
20102 
20103     SDValue FltOfs = DAG.getSelect(DL, TheVT, Cmp, ThreshVal,
20104                                    DAG.getConstantFP(0.0, DL, TheVT));
20105 
20106     if (IsStrict) {
20107       Value = DAG.getNode(ISD::STRICT_FSUB, DL, { TheVT, MVT::Other},
20108                           { Chain, Value, FltOfs });
20109       Chain = Value.getValue(1);
20110     } else
20111       Value = DAG.getNode(ISD::FSUB, DL, TheVT, Value, FltOfs);
20112   }
20113 
20114   MachinePointerInfo MPI = MachinePointerInfo::getFixedStack(MF, SSFI);
20115 
20116   // FIXME: This causes a redundant load/store if the SSE-class value is already
20117   // in memory, such as if it is on the call stack.
20118   if (isScalarFPTypeInSSEReg(TheVT)) {
20119     assert(DstTy == MVT::i64 && "Invalid FP_TO_SINT to lower!");
20120     Chain = DAG.getStore(Chain, DL, Value, StackSlot, MPI);
20121     SDVTList Tys = DAG.getVTList(MVT::f80, MVT::Other);
20122     SDValue Ops[] = { Chain, StackSlot };
20123 
20124     unsigned FLDSize = TheVT.getStoreSize();
20125     assert(FLDSize <= MemSize && "Stack slot not big enough");
20126     MachineMemOperand *MMO = MF.getMachineMemOperand(
20127         MPI, MachineMemOperand::MOLoad, FLDSize, Align(FLDSize));
20128     Value = DAG.getMemIntrinsicNode(X86ISD::FLD, DL, Tys, Ops, TheVT, MMO);
20129     Chain = Value.getValue(1);
20130   }
20131 
20132   // Build the FP_TO_INT*_IN_MEM
20133   MachineMemOperand *MMO = MF.getMachineMemOperand(
20134       MPI, MachineMemOperand::MOStore, MemSize, Align(MemSize));
20135   SDValue Ops[] = { Chain, Value, StackSlot };
20136   SDValue FIST = DAG.getMemIntrinsicNode(X86ISD::FP_TO_INT_IN_MEM, DL,
20137                                          DAG.getVTList(MVT::Other),
20138                                          Ops, DstTy, MMO);
20139 
20140   SDValue Res = DAG.getLoad(Op.getValueType(), DL, FIST, StackSlot, MPI);
20141   Chain = Res.getValue(1);
20142 
20143   // If we need an unsigned fixup, XOR the result with adjust.
20144   if (UnsignedFixup)
20145     Res = DAG.getNode(ISD::XOR, DL, MVT::i64, Res, Adjust);
20146 
20147   return Res;
20148 }
20149 
20150 static SDValue LowerAVXExtend(SDValue Op, const SDLoc &dl, SelectionDAG &DAG,
20151                               const X86Subtarget &Subtarget) {
20152   MVT VT = Op.getSimpleValueType();
20153   SDValue In = Op.getOperand(0);
20154   MVT InVT = In.getSimpleValueType();
20155   unsigned Opc = Op.getOpcode();
20156 
20157   assert(VT.isVector() && InVT.isVector() && "Expected vector type");
20158   assert((Opc == ISD::ANY_EXTEND || Opc == ISD::ZERO_EXTEND) &&
20159          "Unexpected extension opcode");
20160   assert(VT.getVectorNumElements() == InVT.getVectorNumElements() &&
20161          "Expected same number of elements");
20162   assert((VT.getVectorElementType() == MVT::i16 ||
20163           VT.getVectorElementType() == MVT::i32 ||
20164           VT.getVectorElementType() == MVT::i64) &&
20165          "Unexpected element type");
20166   assert((InVT.getVectorElementType() == MVT::i8 ||
20167           InVT.getVectorElementType() == MVT::i16 ||
20168           InVT.getVectorElementType() == MVT::i32) &&
20169          "Unexpected element type");
20170 
20171   unsigned ExtendInVecOpc = DAG.getOpcode_EXTEND_VECTOR_INREG(Opc);
20172 
20173   if (VT == MVT::v32i16 && !Subtarget.hasBWI()) {
20174     assert(InVT == MVT::v32i8 && "Unexpected VT!");
20175     return splitVectorIntUnary(Op, DAG, dl);
20176   }
20177 
20178   if (Subtarget.hasInt256())
20179     return Op;
20180 
20181   // Optimize vectors in AVX mode:
20182   //
20183   //   v8i16 -> v8i32
20184   //   Use vpmovzwd for 4 lower elements  v8i16 -> v4i32.
20185   //   Use vpunpckhwd for 4 upper elements  v8i16 -> v4i32.
20186   //   Concat upper and lower parts.
20187   //
20188   //   v4i32 -> v4i64
20189   //   Use vpmovzdq for 4 lower elements  v4i32 -> v2i64.
20190   //   Use vpunpckhdq for 4 upper elements  v4i32 -> v2i64.
20191   //   Concat upper and lower parts.
20192   //
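  // e.g. zero_extend v8i16 X -> v8i32 (a sketch of the nodes built below):
  //   OpLo = zero_extend_vector_inreg(X)   ->  low 4 lanes as v4i32
  //   OpHi = bitcast(unpckh(X, zero))      -> high 4 lanes as v4i32
  //   concat_vectors(OpLo, OpHi)           -> v8i32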
20193   MVT HalfVT = VT.getHalfNumVectorElementsVT();
20194   SDValue OpLo = DAG.getNode(ExtendInVecOpc, dl, HalfVT, In);
20195 
20196   // Short-circuit if we can determine that each 128-bit half is the same value.
20197   // Otherwise, this is difficult to match and optimize.
20198   if (auto *Shuf = dyn_cast<ShuffleVectorSDNode>(In))
20199     if (hasIdenticalHalvesShuffleMask(Shuf->getMask()))
20200       return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, OpLo, OpLo);
20201 
20202   SDValue ZeroVec = DAG.getConstant(0, dl, InVT);
20203   SDValue Undef = DAG.getUNDEF(InVT);
20204   bool NeedZero = Opc == ISD::ZERO_EXTEND;
20205   SDValue OpHi = getUnpackh(DAG, dl, InVT, In, NeedZero ? ZeroVec : Undef);
20206   OpHi = DAG.getBitcast(HalfVT, OpHi);
20207 
20208   return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, OpLo, OpHi);
20209 }
20210 
20211 // Helper to split and extend a v16i1 mask to v16i8 or v16i16.
20212 static SDValue SplitAndExtendv16i1(unsigned ExtOpc, MVT VT, SDValue In,
20213                                    const SDLoc &dl, SelectionDAG &DAG) {
20214   assert((VT == MVT::v16i8 || VT == MVT::v16i16) && "Unexpected VT.");
20215   SDValue Lo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v8i1, In,
20216                            DAG.getIntPtrConstant(0, dl));
20217   SDValue Hi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v8i1, In,
20218                            DAG.getIntPtrConstant(8, dl));
20219   Lo = DAG.getNode(ExtOpc, dl, MVT::v8i16, Lo);
20220   Hi = DAG.getNode(ExtOpc, dl, MVT::v8i16, Hi);
20221   SDValue Res = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v16i16, Lo, Hi);
20222   return DAG.getNode(ISD::TRUNCATE, dl, VT, Res);
20223 }
20224 
20225 static SDValue LowerZERO_EXTEND_Mask(SDValue Op, const SDLoc &DL,
20226                                      const X86Subtarget &Subtarget,
20227                                      SelectionDAG &DAG) {
20228   MVT VT = Op->getSimpleValueType(0);
20229   SDValue In = Op->getOperand(0);
20230   MVT InVT = In.getSimpleValueType();
20231   assert(InVT.getVectorElementType() == MVT::i1 && "Unexpected input type!");
20232   unsigned NumElts = VT.getVectorNumElements();
20233 
20234   // For all vectors but vXi8, we can just emit a sign_extend and a shift. This
20235   // avoids a constant pool load.
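  // e.g. for vXi32: zext(vXi1 X) == srl(sext(vXi1 X), 31), turning an all-ones
  // lane into 1 and an all-zeros lane into 0.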
20236   if (VT.getVectorElementType() != MVT::i8) {
20237     SDValue Extend = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, In);
20238     return DAG.getNode(ISD::SRL, DL, VT, Extend,
20239                        DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, VT));
20240   }
20241 
20242   // Extend VT if BWI is not supported.
20243   MVT ExtVT = VT;
20244   if (!Subtarget.hasBWI()) {
20245     // If v16i32 is to be avoided, we'll need to split and concatenate.
20246     if (NumElts == 16 && !Subtarget.canExtendTo512DQ())
20247       return SplitAndExtendv16i1(ISD::ZERO_EXTEND, VT, In, DL, DAG);
20248 
20249     ExtVT = MVT::getVectorVT(MVT::i32, NumElts);
20250   }
20251 
20252   // Widen to 512-bits if VLX is not supported.
20253   MVT WideVT = ExtVT;
20254   if (!ExtVT.is512BitVector() && !Subtarget.hasVLX()) {
20255     NumElts *= 512 / ExtVT.getSizeInBits();
20256     InVT = MVT::getVectorVT(MVT::i1, NumElts);
20257     In = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, InVT, DAG.getUNDEF(InVT),
20258                      In, DAG.getIntPtrConstant(0, DL));
20259     WideVT = MVT::getVectorVT(ExtVT.getVectorElementType(),
20260                               NumElts);
20261   }
20262 
20263   SDValue One = DAG.getConstant(1, DL, WideVT);
20264   SDValue Zero = DAG.getConstant(0, DL, WideVT);
20265 
20266   SDValue SelectedVal = DAG.getSelect(DL, WideVT, In, One, Zero);
20267 
20268   // Truncate if we had to extend above.
20269   if (VT != ExtVT) {
20270     WideVT = MVT::getVectorVT(MVT::i8, NumElts);
20271     SelectedVal = DAG.getNode(ISD::TRUNCATE, DL, WideVT, SelectedVal);
20272   }
20273 
20274   // Extract back to 128/256-bit if we widened.
20275   if (WideVT != VT)
20276     SelectedVal = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, SelectedVal,
20277                               DAG.getIntPtrConstant(0, DL));
20278 
20279   return SelectedVal;
20280 }
20281 
20282 static SDValue LowerZERO_EXTEND(SDValue Op, const X86Subtarget &Subtarget,
20283                                 SelectionDAG &DAG) {
20284   SDValue In = Op.getOperand(0);
20285   MVT SVT = In.getSimpleValueType();
20286   SDLoc DL(Op);
20287 
20288   if (SVT.getVectorElementType() == MVT::i1)
20289     return LowerZERO_EXTEND_Mask(Op, DL, Subtarget, DAG);
20290 
20291   assert(Subtarget.hasAVX() && "Expected AVX support");
20292   return LowerAVXExtend(Op, DL, DAG, Subtarget);
20293 }
20294 
20295 /// Helper to recursively truncate vector elements in half with PACKSS/PACKUS.
20296 /// It makes use of the fact that vectors with enough leading sign/zero bits
20297 /// prevent the PACKSS/PACKUS from saturating the results.
20298 /// AVX2 (Int256) sub-targets require extra shuffling as the PACK*S operates
20299 /// within each 128-bit lane.
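/// e.g. a v8i32 -> v8i8 truncation (given sufficient sign/zero bits) is done in
/// two stages: v8i32 --PACK*SDW--> v8i16 --PACK*SWB--> v8i8.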
20300 static SDValue truncateVectorWithPACK(unsigned Opcode, EVT DstVT, SDValue In,
20301                                       const SDLoc &DL, SelectionDAG &DAG,
20302                                       const X86Subtarget &Subtarget) {
20303   assert((Opcode == X86ISD::PACKSS || Opcode == X86ISD::PACKUS) &&
20304          "Unexpected PACK opcode");
20305   assert(DstVT.isVector() && "VT not a vector?");
20306 
20307   // Requires SSE2 for PACKSS (SSE41 PACKUSDW is handled below).
20308   if (!Subtarget.hasSSE2())
20309     return SDValue();
20310 
20311   EVT SrcVT = In.getValueType();
20312 
20313   // No truncation required, we might get here due to recursive calls.
20314   if (SrcVT == DstVT)
20315     return In;
20316 
20317   unsigned NumElems = SrcVT.getVectorNumElements();
20318   if (NumElems < 2 || !isPowerOf2_32(NumElems))
20319     return SDValue();
20320 
20321   unsigned DstSizeInBits = DstVT.getSizeInBits();
20322   unsigned SrcSizeInBits = SrcVT.getSizeInBits();
20323   assert(DstVT.getVectorNumElements() == NumElems && "Illegal truncation");
20324   assert(SrcSizeInBits > DstSizeInBits && "Illegal truncation");
20325 
20326   LLVMContext &Ctx = *DAG.getContext();
20327   EVT PackedSVT = EVT::getIntegerVT(Ctx, SrcVT.getScalarSizeInBits() / 2);
20328   EVT PackedVT = EVT::getVectorVT(Ctx, PackedSVT, NumElems);
20329 
20330   // Pack to the largest type possible:
20331   // vXi64/vXi32 -> PACK*SDW and vXi16 -> PACK*SWB.
20332   EVT InVT = MVT::i16, OutVT = MVT::i8;
20333   if (SrcVT.getScalarSizeInBits() > 16 &&
20334       (Opcode == X86ISD::PACKSS || Subtarget.hasSSE41())) {
20335     InVT = MVT::i32;
20336     OutVT = MVT::i16;
20337   }
20338 
20339   // Sub-128-bit truncation - widen to 128-bit src and pack in the lower half.
20340   // On pre-AVX512, pack the src in both halves to help value tracking.
20341   if (SrcSizeInBits <= 128) {
20342     InVT = EVT::getVectorVT(Ctx, InVT, 128 / InVT.getSizeInBits());
20343     OutVT = EVT::getVectorVT(Ctx, OutVT, 128 / OutVT.getSizeInBits());
20344     In = widenSubVector(In, false, Subtarget, DAG, DL, 128);
20345     SDValue LHS = DAG.getBitcast(InVT, In);
20346     SDValue RHS = Subtarget.hasAVX512() ? DAG.getUNDEF(InVT) : LHS;
20347     SDValue Res = DAG.getNode(Opcode, DL, OutVT, LHS, RHS);
20348     Res = extractSubVector(Res, 0, DAG, DL, SrcSizeInBits / 2);
20349     Res = DAG.getBitcast(PackedVT, Res);
20350     return truncateVectorWithPACK(Opcode, DstVT, Res, DL, DAG, Subtarget);
20351   }
20352 
20353   // Split lower/upper subvectors.
20354   SDValue Lo, Hi;
20355   std::tie(Lo, Hi) = splitVector(In, DAG, DL);
20356 
20357   // If Hi is undef, then don't bother packing it and widen the result instead.
20358   if (Hi.isUndef()) {
20359     EVT DstHalfVT = DstVT.getHalfNumVectorElementsVT(Ctx);
20360     if (SDValue Res =
20361             truncateVectorWithPACK(Opcode, DstHalfVT, Lo, DL, DAG, Subtarget))
20362       return widenSubVector(Res, false, Subtarget, DAG, DL, DstSizeInBits);
20363   }
20364 
20365   unsigned SubSizeInBits = SrcSizeInBits / 2;
20366   InVT = EVT::getVectorVT(Ctx, InVT, SubSizeInBits / InVT.getSizeInBits());
20367   OutVT = EVT::getVectorVT(Ctx, OutVT, SubSizeInBits / OutVT.getSizeInBits());
20368 
20369   // 256bit -> 128bit truncate - PACK lower/upper 128-bit subvectors.
20370   if (SrcVT.is256BitVector() && DstVT.is128BitVector()) {
20371     Lo = DAG.getBitcast(InVT, Lo);
20372     Hi = DAG.getBitcast(InVT, Hi);
20373     SDValue Res = DAG.getNode(Opcode, DL, OutVT, Lo, Hi);
20374     return DAG.getBitcast(DstVT, Res);
20375   }
20376 
20377   // AVX2: 512bit -> 256bit truncate - PACK lower/upper 256-bit subvectors.
20378   // AVX2: 512bit -> 128bit truncate - PACK(PACK, PACK).
20379   if (SrcVT.is512BitVector() && Subtarget.hasInt256()) {
20380     Lo = DAG.getBitcast(InVT, Lo);
20381     Hi = DAG.getBitcast(InVT, Hi);
20382     SDValue Res = DAG.getNode(Opcode, DL, OutVT, Lo, Hi);
20383 
20384     // 256-bit PACK(ARG0, ARG1) leaves us with ((LO0,LO1),(HI0,HI1)),
20385     // so we need to shuffle to get ((LO0,HI0),(LO1,HI1)).
20386     // Scale shuffle mask to avoid bitcasts and help ComputeNumSignBits.
20387     SmallVector<int, 64> Mask;
20388     int Scale = 64 / OutVT.getScalarSizeInBits();
20389     narrowShuffleMaskElts(Scale, { 0, 2, 1, 3 }, Mask);
20390     Res = DAG.getVectorShuffle(OutVT, DL, Res, Res, Mask);
20391 
20392     if (DstVT.is256BitVector())
20393       return DAG.getBitcast(DstVT, Res);
20394 
20395     // If 512bit -> 128bit truncate another stage.
20396     // If 512-bit -> 128-bit, truncate another stage.
20397     return truncateVectorWithPACK(Opcode, DstVT, Res, DL, DAG, Subtarget);
20398   }
20399 
20400   // Recursively pack lower/upper subvectors, concat result and pack again.
20401   assert(SrcSizeInBits >= 256 && "Expected 256-bit vector or greater");
20402 
20403   if (PackedVT.is128BitVector()) {
20404     // Avoid CONCAT_VECTORS on sub-128bit nodes as these can fail after
20405     // type legalization.
20406     SDValue Res =
20407         truncateVectorWithPACK(Opcode, PackedVT, In, DL, DAG, Subtarget);
20408     return truncateVectorWithPACK(Opcode, DstVT, Res, DL, DAG, Subtarget);
20409   }
20410 
20411   EVT HalfPackedVT = EVT::getVectorVT(Ctx, PackedSVT, NumElems / 2);
20412   Lo = truncateVectorWithPACK(Opcode, HalfPackedVT, Lo, DL, DAG, Subtarget);
20413   Hi = truncateVectorWithPACK(Opcode, HalfPackedVT, Hi, DL, DAG, Subtarget);
20414   SDValue Res = DAG.getNode(ISD::CONCAT_VECTORS, DL, PackedVT, Lo, Hi);
20415   return truncateVectorWithPACK(Opcode, DstVT, Res, DL, DAG, Subtarget);
20416 }
20417 
20418 /// Truncate using inreg zero extension (AND mask) and X86ISD::PACKUS.
20419 /// e.g. trunc <8 x i32> X to <8 x i16> -->
20420 /// MaskX = X & 0xffff (clear high bits to prevent saturation)
20421 /// packus (extract_subv MaskX, 0), (extract_subv MaskX, 1)
20422 static SDValue truncateVectorWithPACKUS(EVT DstVT, SDValue In, const SDLoc &DL,
20423                                         const X86Subtarget &Subtarget,
20424                                         SelectionDAG &DAG) {
20425   In = DAG.getZeroExtendInReg(In, DL, DstVT);
20426   return truncateVectorWithPACK(X86ISD::PACKUS, DstVT, In, DL, DAG, Subtarget);
20427 }
20428 
20429 /// Truncate using inreg sign extension and X86ISD::PACKSS.
20430 static SDValue truncateVectorWithPACKSS(EVT DstVT, SDValue In, const SDLoc &DL,
20431                                         const X86Subtarget &Subtarget,
20432                                         SelectionDAG &DAG) {
20433   EVT SrcVT = In.getValueType();
20434   In = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, SrcVT, In,
20435                    DAG.getValueType(DstVT));
20436   return truncateVectorWithPACK(X86ISD::PACKSS, DstVT, In, DL, DAG, Subtarget);
20437 }
20438 
20439 /// Helper to determine if \p In truncated to \p DstVT has the necessary
20440 /// signbits / leading zero bits to be truncated with PACKSS / PACKUS,
20441 /// possibly by converting a SRL node to SRA for sign extension.
20442 static SDValue matchTruncateWithPACK(unsigned &PackOpcode, EVT DstVT,
20443                                      SDValue In, const SDLoc &DL,
20444                                      SelectionDAG &DAG,
20445                                      const X86Subtarget &Subtarget) {
20446   // Requires SSE2.
20447   if (!Subtarget.hasSSE2())
20448     return SDValue();
20449 
20450   EVT SrcVT = In.getValueType();
20451   EVT DstSVT = DstVT.getVectorElementType();
20452   EVT SrcSVT = SrcVT.getVectorElementType();
20453   unsigned NumDstEltBits = DstSVT.getSizeInBits();
20454   unsigned NumSrcEltBits = SrcSVT.getSizeInBits();
20455 
20456   // Check we have a truncation suited for PACKSS/PACKUS.
20457   if (!((SrcSVT == MVT::i16 || SrcSVT == MVT::i32 || SrcSVT == MVT::i64) &&
20458         (DstSVT == MVT::i8 || DstSVT == MVT::i16 || DstSVT == MVT::i32)))
20459     return SDValue();
20460 
20461   assert(NumSrcEltBits > NumDstEltBits && "Bad truncation");
20462   unsigned NumStages = Log2_32(NumSrcEltBits / NumDstEltBits);
20463 
20464   // Truncation from 128-bit to vXi32 can be better handled with PSHUFD.
20465   // Truncation to sub-64-bit vXi16 can be better handled with PSHUFD/PSHUFLW.
20466   // Truncation from v2i64 to v2i8 can be better handled with PSHUFB.
20467   if ((DstSVT == MVT::i32 && SrcVT.getSizeInBits() <= 128) ||
20468       (DstSVT == MVT::i16 && SrcVT.getSizeInBits() <= (64 * NumStages)) ||
20469       (DstVT == MVT::v2i8 && SrcVT == MVT::v2i64 && Subtarget.hasSSSE3()))
20470     return SDValue();
20471 
20472   // Prefer to lower v4i64 -> v4i32 as a shuffle unless we can cheaply
20473   // split this for packing.
20474   if (SrcVT == MVT::v4i64 && DstVT == MVT::v4i32 &&
20475       !isFreeToSplitVector(In.getNode(), DAG) &&
20476       (!Subtarget.hasAVX() || DAG.ComputeNumSignBits(In) != 64))
20477     return SDValue();
20478 
20479   // Don't lower the truncation on AVX512 targets as multiple PACK node stages.
20480   if (Subtarget.hasAVX512() && NumStages > 1)
20481     return SDValue();
20482 
20483   unsigned NumPackedSignBits = std::min<unsigned>(NumDstEltBits, 16);
20484   unsigned NumPackedZeroBits = Subtarget.hasSSE41() ? NumPackedSignBits : 8;
20485 
20486   // Truncate with PACKUS if we are truncating a vector with leading zero
20487   // bits that extend all the way to the packed/truncated value.
20488   // e.g. Masks, zext_in_reg, etc.
20489   // Pre-SSE41 we can only use PACKUSWB.
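  // e.g. for a v8i32 -> v8i8 truncate of (and X, 255), each element has at
  // least 24 known leading zero bits, so (with SSE41) a PACKUSDW+PACKUSWB chain
  // reproduces the low 8 bits exactly.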
20490   KnownBits Known = DAG.computeKnownBits(In);
20491   if ((NumSrcEltBits - NumPackedZeroBits) <= Known.countMinLeadingZeros()) {
20492     PackOpcode = X86ISD::PACKUS;
20493     return In;
20494   }
20495 
20496   // Truncate with PACKSS if we are truncating a vector with sign-bits
20497   // that extend all the way to the packed/truncated value.
20498   // e.g. Comparison result, sext_in_reg, etc.
20499   unsigned NumSignBits = DAG.ComputeNumSignBits(In);
20500 
20501   // Don't use PACKSS for vXi64 -> vXi32 truncations unless we're dealing with
20502   // a sign splat (or AVX512 VPSRAQ support). ComputeNumSignBits struggles to
20503   // see through BITCASTs later on and combines/simplifications can't then use
20504   // it.
20505   if (DstSVT == MVT::i32 && NumSignBits != NumSrcEltBits &&
20506       !Subtarget.hasAVX512())
20507     return SDValue();
20508 
20509   unsigned MinSignBits = NumSrcEltBits - NumPackedSignBits;
20510   if (MinSignBits < NumSignBits) {
20511     PackOpcode = X86ISD::PACKSS;
20512     return In;
20513   }
20514 
20515   // If we have a srl that only generates signbits that we will discard in
20516   // the truncation then we can use PACKSS by converting the srl to a sra.
20517   // SimplifyDemandedBits often relaxes sra to srl so we need to reverse it.
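  // e.g. for i32 -> i16: trunc(srl X, 16) and trunc(sra X, 16) have identical
  // low 16 bits, and (sra X, 16) has at least 17 sign bits, so PACKSSDW can
  // perform the truncation exactly.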
20518   if (In.getOpcode() == ISD::SRL && In->hasOneUse())
20519     if (std::optional<uint64_t> ShAmt = DAG.getValidShiftAmount(In)) {
20520       if (*ShAmt == MinSignBits) {
20521         PackOpcode = X86ISD::PACKSS;
20522         return DAG.getNode(ISD::SRA, DL, SrcVT, In->ops());
20523       }
20524     }
20525 
20526   return SDValue();
20527 }
20528 
20529 /// This function lowers a vector truncation of 'extended sign-bits' or
20530 /// 'extended zero-bits' values.
20531 /// vXi16/vXi32/vXi64 to vXi8/vXi16/vXi32 into X86ISD::PACKSS/PACKUS operations.
20532 static SDValue LowerTruncateVecPackWithSignBits(MVT DstVT, SDValue In,
20533                                                 const SDLoc &DL,
20534                                                 const X86Subtarget &Subtarget,
20535                                                 SelectionDAG &DAG) {
20536   MVT SrcVT = In.getSimpleValueType();
20537   MVT DstSVT = DstVT.getVectorElementType();
20538   MVT SrcSVT = SrcVT.getVectorElementType();
20539   if (!((SrcSVT == MVT::i16 || SrcSVT == MVT::i32 || SrcSVT == MVT::i64) &&
20540         (DstSVT == MVT::i8 || DstSVT == MVT::i16 || DstSVT == MVT::i32)))
20541     return SDValue();
20542 
20543   // If the upper half of the source is undef, then attempt to split and
20544   // only truncate the lower half.
20545   if (DstVT.getSizeInBits() >= 128) {
20546     SmallVector<SDValue> LowerOps;
20547     if (SDValue Lo = isUpperSubvectorUndef(In, DL, DAG)) {
20548       MVT DstHalfVT = DstVT.getHalfNumVectorElementsVT();
20549       if (SDValue Res = LowerTruncateVecPackWithSignBits(DstHalfVT, Lo, DL,
20550                                                          Subtarget, DAG))
20551         return widenSubVector(Res, false, Subtarget, DAG, DL,
20552                               DstVT.getSizeInBits());
20553     }
20554   }
20555 
20556   unsigned PackOpcode;
20557   if (SDValue Src =
20558           matchTruncateWithPACK(PackOpcode, DstVT, In, DL, DAG, Subtarget))
20559     return truncateVectorWithPACK(PackOpcode, DstVT, Src, DL, DAG, Subtarget);
20560 
20561   return SDValue();
20562 }
20563 
20564 /// This function lowers a vector truncation from vXi32/vXi64 to vXi8/vXi16 into
20565 /// X86ISD::PACKUS/X86ISD::PACKSS operations.
20566 static SDValue LowerTruncateVecPack(MVT DstVT, SDValue In, const SDLoc &DL,
20567                                     const X86Subtarget &Subtarget,
20568                                     SelectionDAG &DAG) {
20569   MVT SrcVT = In.getSimpleValueType();
20570   MVT DstSVT = DstVT.getVectorElementType();
20571   MVT SrcSVT = SrcVT.getVectorElementType();
20572   unsigned NumElems = DstVT.getVectorNumElements();
20573   if (!((SrcSVT == MVT::i16 || SrcSVT == MVT::i32 || SrcSVT == MVT::i64) &&
20574         (DstSVT == MVT::i8 || DstSVT == MVT::i16) && isPowerOf2_32(NumElems) &&
20575         NumElems >= 8))
20576     return SDValue();
20577 
20578   // SSSE3's pshufb results in fewer instructions in the cases below.
20579   if (Subtarget.hasSSSE3() && NumElems == 8) {
20580     if (SrcSVT == MVT::i16)
20581       return SDValue();
20582     if (SrcSVT == MVT::i32 && (DstSVT == MVT::i8 || !Subtarget.hasSSE41()))
20583       return SDValue();
20584   }
20585 
20586   // If the upper half of the source is undef, then attempt to split and
20587   // only truncate the lower half.
20588   if (DstVT.getSizeInBits() >= 128) {
20589     SmallVector<SDValue> LowerOps;
20590     if (SDValue Lo = isUpperSubvectorUndef(In, DL, DAG)) {
20591       MVT DstHalfVT = DstVT.getHalfNumVectorElementsVT();
20592       if (SDValue Res = LowerTruncateVecPack(DstHalfVT, Lo, DL, Subtarget, DAG))
20593         return widenSubVector(Res, false, Subtarget, DAG, DL,
20594                               DstVT.getSizeInBits());
20595     }
20596   }
20597 
20598   // SSE2 provides PACKUS for only 2 x v8i16 -> v16i8 and SSE4.1 provides PACKUS
20599   // for 2 x v4i32 -> v8i16. For SSSE3 and below, we need to use PACKSS to
20600   // truncate 2 x v4i32 to v8i16.
20601   if (Subtarget.hasSSE41() || DstSVT == MVT::i8)
20602     return truncateVectorWithPACKUS(DstVT, In, DL, Subtarget, DAG);
20603 
20604   if (SrcSVT == MVT::i16 || SrcSVT == MVT::i32)
20605     return truncateVectorWithPACKSS(DstVT, In, DL, Subtarget, DAG);
20606 
20607   // Special case vXi64 -> vXi16, shuffle to vXi32 and then use PACKSS.
20608   if (DstSVT == MVT::i16 && SrcSVT == MVT::i64) {
20609     MVT TruncVT = MVT::getVectorVT(MVT::i32, NumElems);
20610     SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, In);
20611     return truncateVectorWithPACKSS(DstVT, Trunc, DL, Subtarget, DAG);
20612   }
20613 
20614   return SDValue();
20615 }
20616 
20617 static SDValue LowerTruncateVecI1(SDValue Op, const SDLoc &DL,
20618                                   SelectionDAG &DAG,
20619                                   const X86Subtarget &Subtarget) {
20620   MVT VT = Op.getSimpleValueType();
20621   SDValue In = Op.getOperand(0);
20622   MVT InVT = In.getSimpleValueType();
20623   assert(VT.getVectorElementType() == MVT::i1 && "Unexpected vector type.");
20624 
20625   // Shift LSB to MSB and use VPMOVB/W2M or TESTD/Q.
20626   unsigned ShiftInx = InVT.getScalarSizeInBits() - 1;
20627   if (InVT.getScalarSizeInBits() <= 16) {
20628     if (Subtarget.hasBWI()) {
20629       // legal, will go to VPMOVB2M, VPMOVW2M
20630       if (DAG.ComputeNumSignBits(In) < InVT.getScalarSizeInBits()) {
20631         // We need to shift to get the lsb into sign position.
20632         // Shifting packed bytes is not supported natively, so bitcast to words.
20633         MVT ExtVT = MVT::getVectorVT(MVT::i16, InVT.getSizeInBits()/16);
20634         In = DAG.getNode(ISD::SHL, DL, ExtVT,
20635                          DAG.getBitcast(ExtVT, In),
20636                          DAG.getConstant(ShiftInx, DL, ExtVT));
20637         In = DAG.getBitcast(InVT, In);
20638       }
20639       return DAG.getSetCC(DL, VT, DAG.getConstant(0, DL, InVT),
20640                           In, ISD::SETGT);
20641     }
20642     // Use TESTD/Q, extended vector to packed dword/qword.
20643     assert((InVT.is256BitVector() || InVT.is128BitVector()) &&
20644            "Unexpected vector type.");
20645     unsigned NumElts = InVT.getVectorNumElements();
20646     assert((NumElts == 8 || NumElts == 16) && "Unexpected number of elements");
20647     // We need to change to a wider element type that we have support for.
20648     // For 8 element vectors this is easy, we either extend to v8i32 or v8i64.
20649     // For 16 element vectors we extend to v16i32 unless we are explicitly
20650     // trying to avoid 512-bit vectors. If we are avoiding 512-bit vectors
20651     // we need to split into two 8 element vectors which we can extend to v8i32,
20652     // truncate and concat the results. There's an additional complication if
20653     // the original type is v16i8. In that case we can't split the v16i8
20654     // directly, so we need to shuffle high elements to low and use
20655     // sign_extend_vector_inreg.
20656     if (NumElts == 16 && !Subtarget.canExtendTo512DQ()) {
20657       SDValue Lo, Hi;
20658       if (InVT == MVT::v16i8) {
20659         Lo = DAG.getNode(ISD::SIGN_EXTEND_VECTOR_INREG, DL, MVT::v8i32, In);
20660         Hi = DAG.getVectorShuffle(
20661             InVT, DL, In, In,
20662             {8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1});
20663         Hi = DAG.getNode(ISD::SIGN_EXTEND_VECTOR_INREG, DL, MVT::v8i32, Hi);
20664       } else {
20665         assert(InVT == MVT::v16i16 && "Unexpected VT!");
20666         Lo = extract128BitVector(In, 0, DAG, DL);
20667         Hi = extract128BitVector(In, 8, DAG, DL);
20668       }
20669       // We're split now, just emit two truncates and a concat. The two
20670       // truncates will trigger legalization to come back to this function.
20671       Lo = DAG.getNode(ISD::TRUNCATE, DL, MVT::v8i1, Lo);
20672       Hi = DAG.getNode(ISD::TRUNCATE, DL, MVT::v8i1, Hi);
20673       return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Lo, Hi);
20674     }
20675     // We either have 8 elements or we're allowed to use 512-bit vectors.
20676     // If we have VLX, we want to use the narrowest vector that can get the
20677     // job done so we use vXi32.
20678     MVT EltVT = Subtarget.hasVLX() ? MVT::i32 : MVT::getIntegerVT(512/NumElts);
20679     MVT ExtVT = MVT::getVectorVT(EltVT, NumElts);
20680     In = DAG.getNode(ISD::SIGN_EXTEND, DL, ExtVT, In);
20681     InVT = ExtVT;
20682     ShiftInx = InVT.getScalarSizeInBits() - 1;
20683   }
20684 
20685   if (DAG.ComputeNumSignBits(In) < InVT.getScalarSizeInBits()) {
20686     // We need to shift to get the lsb into sign position.
20687     In = DAG.getNode(ISD::SHL, DL, InVT, In,
20688                      DAG.getConstant(ShiftInx, DL, InVT));
20689   }
20690   // If we have DQI, emit a pattern that will be iseled as vpmovq2m/vpmovd2m.
20691   if (Subtarget.hasDQI())
20692     return DAG.getSetCC(DL, VT, DAG.getConstant(0, DL, InVT), In, ISD::SETGT);
20693   return DAG.getSetCC(DL, VT, In, DAG.getConstant(0, DL, InVT), ISD::SETNE);
20694 }
20695 
20696 SDValue X86TargetLowering::LowerTRUNCATE(SDValue Op, SelectionDAG &DAG) const {
20697   SDLoc DL(Op);
20698   MVT VT = Op.getSimpleValueType();
20699   SDValue In = Op.getOperand(0);
20700   MVT InVT = In.getSimpleValueType();
20701   assert(VT.getVectorNumElements() == InVT.getVectorNumElements() &&
20702          "Invalid TRUNCATE operation");
20703 
20704   // If we're called by the type legalizer, handle a few cases.
20705   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
20706   if (!TLI.isTypeLegal(VT) || !TLI.isTypeLegal(InVT)) {
20707     if ((InVT == MVT::v8i64 || InVT == MVT::v16i32 || InVT == MVT::v16i64) &&
20708         VT.is128BitVector() && Subtarget.hasAVX512()) {
20709       assert((InVT == MVT::v16i64 || Subtarget.hasVLX()) &&
20710              "Unexpected subtarget!");
20711       // The default behavior is to truncate one step, concatenate, and then
20712       // truncate the remainder. We'd rather produce two 64-bit results and
20713       // concatenate those.
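      // e.g. v8i64 -> v8i16: split to 2 x v4i64, truncate each to v4i16
      // (64-bit), then concat_vectors back to v8i16.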
20714       SDValue Lo, Hi;
20715       std::tie(Lo, Hi) = DAG.SplitVector(In, DL);
20716 
20717       EVT LoVT, HiVT;
20718       std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(VT);
20719 
20720       Lo = DAG.getNode(ISD::TRUNCATE, DL, LoVT, Lo);
20721       Hi = DAG.getNode(ISD::TRUNCATE, DL, HiVT, Hi);
20722       return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Lo, Hi);
20723     }
20724 
20725     // Pre-AVX512 (or prefer-256bit) see if we can make use of PACKSS/PACKUS.
20726     if (!Subtarget.hasAVX512() ||
20727         (InVT.is512BitVector() && VT.is256BitVector()))
20728       if (SDValue SignPack =
20729               LowerTruncateVecPackWithSignBits(VT, In, DL, Subtarget, DAG))
20730         return SignPack;
20731 
20732     // Pre-AVX512 see if we can make use of PACKSS/PACKUS.
20733     if (!Subtarget.hasAVX512())
20734       return LowerTruncateVecPack(VT, In, DL, Subtarget, DAG);
20735 
20736     // Otherwise let default legalization handle it.
20737     return SDValue();
20738   }
20739 
20740   if (VT.getVectorElementType() == MVT::i1)
20741     return LowerTruncateVecI1(Op, DL, DAG, Subtarget);
20742 
20743   // Attempt to truncate with PACKUS/PACKSS even on AVX512 if we'd have to
20744   // concat from subvectors to use VPTRUNC etc.
20745   if (!Subtarget.hasAVX512() || isFreeToSplitVector(In.getNode(), DAG))
20746     if (SDValue SignPack =
20747             LowerTruncateVecPackWithSignBits(VT, In, DL, Subtarget, DAG))
20748       return SignPack;
20749 
20750   // vpmovqb/w/d, vpmovdb/w, vpmovwb
20751   if (Subtarget.hasAVX512()) {
20752     if (InVT == MVT::v32i16 && !Subtarget.hasBWI()) {
20753       assert(VT == MVT::v32i8 && "Unexpected VT!");
20754       return splitVectorIntUnary(Op, DAG, DL);
20755     }
20756 
20757     // word to byte only under BWI. Otherwise we have to promote to v16i32
20758     // and then truncate that. But we should only do that if we haven't been
20759     // asked to avoid 512-bit vectors. The actual promotion to v16i32 will be
20760     // handled by isel patterns.
20761     if (InVT != MVT::v16i16 || Subtarget.hasBWI() ||
20762         Subtarget.canExtendTo512DQ())
20763       return Op;
20764   }
20765 
20766   // Handle truncation of V256 to V128 using shuffles.
20767   assert(VT.is128BitVector() && InVT.is256BitVector() && "Unexpected types!");
20768 
20769   if ((VT == MVT::v4i32) && (InVT == MVT::v4i64)) {
20770     // On AVX2, v4i64 -> v4i32 becomes VPERMD.
20771     if (Subtarget.hasInt256()) {
20772       static const int ShufMask[] = {0, 2, 4, 6, -1, -1, -1, -1};
20773       In = DAG.getBitcast(MVT::v8i32, In);
20774       In = DAG.getVectorShuffle(MVT::v8i32, DL, In, In, ShufMask);
20775       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, In,
20776                          DAG.getIntPtrConstant(0, DL));
20777     }
20778 
20779     SDValue OpLo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v2i64, In,
20780                                DAG.getIntPtrConstant(0, DL));
20781     SDValue OpHi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v2i64, In,
20782                                DAG.getIntPtrConstant(2, DL));
20783     static const int ShufMask[] = {0, 2, 4, 6};
20784     return DAG.getVectorShuffle(VT, DL, DAG.getBitcast(MVT::v4i32, OpLo),
20785                                 DAG.getBitcast(MVT::v4i32, OpHi), ShufMask);
20786   }
20787 
20788   if ((VT == MVT::v8i16) && (InVT == MVT::v8i32)) {
20789     // On AVX2, v8i32 -> v8i16 becomes PSHUFB.
20790     if (Subtarget.hasInt256()) {
20791       // The PSHUFB mask:
20792       static const int ShufMask1[] = { 0,  1,  4,  5,  8,  9, 12, 13,
20793                                       -1, -1, -1, -1, -1, -1, -1, -1,
20794                                       16, 17, 20, 21, 24, 25, 28, 29,
20795                                       -1, -1, -1, -1, -1, -1, -1, -1 };
20796       In = DAG.getBitcast(MVT::v32i8, In);
20797       In = DAG.getVectorShuffle(MVT::v32i8, DL, In, In, ShufMask1);
20798       In = DAG.getBitcast(MVT::v4i64, In);
20799 
20800       static const int ShufMask2[] = {0, 2, -1, -1};
20801       In = DAG.getVectorShuffle(MVT::v4i64, DL, In, In, ShufMask2);
20802       In = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v2i64, In,
20803                        DAG.getIntPtrConstant(0, DL));
20804       return DAG.getBitcast(MVT::v8i16, In);
20805     }
20806 
20807     return Subtarget.hasSSE41()
20808                ? truncateVectorWithPACKUS(VT, In, DL, Subtarget, DAG)
20809                : truncateVectorWithPACKSS(VT, In, DL, Subtarget, DAG);
20810   }
20811 
20812   if (VT == MVT::v16i8 && InVT == MVT::v16i16)
20813     return truncateVectorWithPACKUS(VT, In, DL, Subtarget, DAG);
20814 
20815   llvm_unreachable("All 256->128 cases should have been handled above!");
20816 }
20817 
20818 // We can leverage the specific way the "cvttps2dq/cvttpd2dq" instruction
20819 // behaves on out of range inputs to generate optimized conversions.
20820 static SDValue expandFP_TO_UINT_SSE(MVT VT, SDValue Src, const SDLoc &dl,
20821                                     SelectionDAG &DAG,
20822                                     const X86Subtarget &Subtarget) {
20823   MVT SrcVT = Src.getSimpleValueType();
20824   unsigned DstBits = VT.getScalarSizeInBits();
20825   assert(DstBits == 32 && "expandFP_TO_UINT_SSE - only vXi32 supported");
20826 
20827   // Calculate the converted result for values in the range 0 to
20828   // 2^31-1 ("Small") and from 2^31 to 2^32-1 ("Big").
20829   SDValue Small = DAG.getNode(X86ISD::CVTTP2SI, dl, VT, Src);
20830   SDValue Big =
20831       DAG.getNode(X86ISD::CVTTP2SI, dl, VT,
20832                   DAG.getNode(ISD::FSUB, dl, SrcVT, Src,
20833                               DAG.getConstantFP(2147483648.0f, dl, SrcVT)));
20834 
20835   // The "CVTTP2SI" instruction conveniently sets the sign bit if
20836   // and only if the value was out of range. So we can use that
20837   // as our indicator that we should rather use "Big" instead of "Small".
20838   //
20839   // Use "Small" if "IsOverflown" has all bits cleared
20840   // and "0x80000000 | Big" if all bits in "IsOverflown" are set.
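  //
  // Illustrative scalar equivalent, assuming x86 CVTTSS2SI semantics (returns
  // 0x80000000 on out-of-range inputs); exposition only:
  //   uint32_t fp_to_u32(float v) {
  //     int32_t Small = cvttss2si(v);                  // 0x80000000 if v >= 2^31
  //     int32_t Big   = cvttss2si(v - 2147483648.0f);
  //     return (Small < 0) ? ((uint32_t)Small | (uint32_t)Big) : (uint32_t)Small;
  //   }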
20841 
20842   // AVX1 can't use the signsplat masking for 256-bit vectors - we have to
20843   // use the slightly slower blendv select instead.
20844   if (VT == MVT::v8i32 && !Subtarget.hasAVX2()) {
20845     SDValue Overflow = DAG.getNode(ISD::OR, dl, VT, Small, Big);
20846     return DAG.getNode(X86ISD::BLENDV, dl, VT, Small, Overflow, Small);
20847   }
20848 
20849   SDValue IsOverflown =
20850       DAG.getNode(X86ISD::VSRAI, dl, VT, Small,
20851                   DAG.getTargetConstant(DstBits - 1, dl, MVT::i8));
20852   return DAG.getNode(ISD::OR, dl, VT, Small,
20853                      DAG.getNode(ISD::AND, dl, VT, Big, IsOverflown));
20854 }
20855 
20856 SDValue X86TargetLowering::LowerFP_TO_INT(SDValue Op, SelectionDAG &DAG) const {
20857   bool IsStrict = Op->isStrictFPOpcode();
20858   bool IsSigned = Op.getOpcode() == ISD::FP_TO_SINT ||
20859                   Op.getOpcode() == ISD::STRICT_FP_TO_SINT;
20860   MVT VT = Op->getSimpleValueType(0);
20861   SDValue Src = Op.getOperand(IsStrict ? 1 : 0);
20862   SDValue Chain = IsStrict ? Op->getOperand(0) : SDValue();
20863   MVT SrcVT = Src.getSimpleValueType();
20864   SDLoc dl(Op);
20865 
20866   SDValue Res;
20867   if (isSoftF16(SrcVT, Subtarget)) {
20868     MVT NVT = VT.isVector() ? VT.changeVectorElementType(MVT::f32) : MVT::f32;
20869     if (IsStrict)
20870       return DAG.getNode(Op.getOpcode(), dl, {VT, MVT::Other},
20871                          {Chain, DAG.getNode(ISD::STRICT_FP_EXTEND, dl,
20872                                              {NVT, MVT::Other}, {Chain, Src})});
20873     return DAG.getNode(Op.getOpcode(), dl, VT,
20874                        DAG.getNode(ISD::FP_EXTEND, dl, NVT, Src));
20875   } else if (isTypeLegal(SrcVT) && isLegalConversion(VT, IsSigned, Subtarget)) {
20876     return Op;
20877   }
20878 
20879   if (VT.isVector()) {
20880     if (VT == MVT::v2i1 && SrcVT == MVT::v2f64) {
20881       MVT ResVT = MVT::v4i32;
20882       MVT TruncVT = MVT::v4i1;
20883       unsigned Opc;
20884       if (IsStrict)
20885         Opc = IsSigned ? X86ISD::STRICT_CVTTP2SI : X86ISD::STRICT_CVTTP2UI;
20886       else
20887         Opc = IsSigned ? X86ISD::CVTTP2SI : X86ISD::CVTTP2UI;
20888 
20889       if (!IsSigned && !Subtarget.hasVLX()) {
20890         assert(Subtarget.useAVX512Regs() && "Unexpected features!");
20891         // Widen to 512-bits.
20892         ResVT = MVT::v8i32;
20893         TruncVT = MVT::v8i1;
20894         Opc = Op.getOpcode();
20895         // Need to concat with zero vector for strict fp to avoid spurious
20896         // exceptions.
20897         // TODO: Should we just do this for non-strict as well?
20898         SDValue Tmp = IsStrict ? DAG.getConstantFP(0.0, dl, MVT::v8f64)
20899                                : DAG.getUNDEF(MVT::v8f64);
20900         Src = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, MVT::v8f64, Tmp, Src,
20901                           DAG.getIntPtrConstant(0, dl));
20902       }
20903       if (IsStrict) {
20904         Res = DAG.getNode(Opc, dl, {ResVT, MVT::Other}, {Chain, Src});
20905         Chain = Res.getValue(1);
20906       } else {
20907         Res = DAG.getNode(Opc, dl, ResVT, Src);
20908       }
20909 
20910       Res = DAG.getNode(ISD::TRUNCATE, dl, TruncVT, Res);
20911       Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v2i1, Res,
20912                         DAG.getIntPtrConstant(0, dl));
20913       if (IsStrict)
20914         return DAG.getMergeValues({Res, Chain}, dl);
20915       return Res;
20916     }
20917 
20918     if (Subtarget.hasFP16() && SrcVT.getVectorElementType() == MVT::f16) {
20919       if (VT == MVT::v8i16 || VT == MVT::v16i16 || VT == MVT::v32i16)
20920         return Op;
20921 
20922       MVT ResVT = VT;
20923       MVT EleVT = VT.getVectorElementType();
20924       if (EleVT != MVT::i64)
20925         ResVT = EleVT == MVT::i32 ? MVT::v4i32 : MVT::v8i16;
20926 
20927       if (SrcVT != MVT::v8f16) {
20928         SDValue Tmp =
20929             IsStrict ? DAG.getConstantFP(0.0, dl, SrcVT) : DAG.getUNDEF(SrcVT);
20930         SmallVector<SDValue, 4> Ops(SrcVT == MVT::v2f16 ? 4 : 2, Tmp);
20931         Ops[0] = Src;
20932         Src = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v8f16, Ops);
20933       }
20934 
20935       if (IsStrict) {
20936         Res = DAG.getNode(IsSigned ? X86ISD::STRICT_CVTTP2SI
20937                                    : X86ISD::STRICT_CVTTP2UI,
20938                           dl, {ResVT, MVT::Other}, {Chain, Src});
20939         Chain = Res.getValue(1);
20940       } else {
20941         Res = DAG.getNode(IsSigned ? X86ISD::CVTTP2SI : X86ISD::CVTTP2UI, dl,
20942                           ResVT, Src);
20943       }
20944 
20945       // TODO: Need to add exception check code for strict FP.
20946       if (EleVT.getSizeInBits() < 16) {
20947         ResVT = MVT::getVectorVT(EleVT, 8);
20948         Res = DAG.getNode(ISD::TRUNCATE, dl, ResVT, Res);
20949       }
20950 
20951       if (ResVT != VT)
20952         Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT, Res,
20953                           DAG.getIntPtrConstant(0, dl));
20954 
20955       if (IsStrict)
20956         return DAG.getMergeValues({Res, Chain}, dl);
20957       return Res;
20958     }
20959 
20960     // v8f32/v16f32/v8f64->v8i16/v16i16 need to widen first.
20961     if (VT.getVectorElementType() == MVT::i16) {
20962       assert((SrcVT.getVectorElementType() == MVT::f32 ||
20963               SrcVT.getVectorElementType() == MVT::f64) &&
20964              "Expected f32/f64 vector!");
20965       MVT NVT = VT.changeVectorElementType(MVT::i32);
20966       if (IsStrict) {
20967         Res = DAG.getNode(IsSigned ? ISD::STRICT_FP_TO_SINT
20968                                    : ISD::STRICT_FP_TO_UINT,
20969                           dl, {NVT, MVT::Other}, {Chain, Src});
20970         Chain = Res.getValue(1);
20971       } else {
20972         Res = DAG.getNode(IsSigned ? ISD::FP_TO_SINT : ISD::FP_TO_UINT, dl,
20973                           NVT, Src);
20974       }
20975 
20976       // TODO: Need to add exception check code for strict FP.
20977       Res = DAG.getNode(ISD::TRUNCATE, dl, VT, Res);
20978 
20979       if (IsStrict)
20980         return DAG.getMergeValues({Res, Chain}, dl);
20981       return Res;
20982     }
20983 
20984     // v8f64->v8i32 is legal, but we need v8i32 to be custom for v8f32.
20985     if (VT == MVT::v8i32 && SrcVT == MVT::v8f64) {
20986       assert(!IsSigned && "Expected unsigned conversion!");
20987       assert(Subtarget.useAVX512Regs() && "Requires avx512f");
20988       return Op;
20989     }
20990 
20991     // Widen vXi32 fp_to_uint with avx512f to 512-bit source.
20992     if ((VT == MVT::v4i32 || VT == MVT::v8i32) &&
20993         (SrcVT == MVT::v4f64 || SrcVT == MVT::v4f32 || SrcVT == MVT::v8f32) &&
20994         Subtarget.useAVX512Regs()) {
20995       assert(!IsSigned && "Expected unsigned conversion!");
20996       assert(!Subtarget.hasVLX() && "Unexpected features!");
20997       MVT WideVT = SrcVT == MVT::v4f64 ? MVT::v8f64 : MVT::v16f32;
20998       MVT ResVT = SrcVT == MVT::v4f64 ? MVT::v8i32 : MVT::v16i32;
20999       // Need to concat with zero vector for strict fp to avoid spurious
21000       // exceptions.
21001       // TODO: Should we just do this for non-strict as well?
21002       SDValue Tmp =
21003           IsStrict ? DAG.getConstantFP(0.0, dl, WideVT) : DAG.getUNDEF(WideVT);
21004       Src = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, WideVT, Tmp, Src,
21005                         DAG.getIntPtrConstant(0, dl));
21006 
21007       if (IsStrict) {
21008         Res = DAG.getNode(ISD::STRICT_FP_TO_UINT, dl, {ResVT, MVT::Other},
21009                           {Chain, Src});
21010         Chain = Res.getValue(1);
21011       } else {
21012         Res = DAG.getNode(ISD::FP_TO_UINT, dl, ResVT, Src);
21013       }
21014 
21015       Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT, Res,
21016                         DAG.getIntPtrConstant(0, dl));
21017 
21018       if (IsStrict)
21019         return DAG.getMergeValues({Res, Chain}, dl);
21020       return Res;
21021     }
21022 
21023     // Widen vXi64 fp_to_uint/fp_to_sint with avx512dq to 512-bit source.
21024     if ((VT == MVT::v2i64 || VT == MVT::v4i64) &&
21025         (SrcVT == MVT::v2f64 || SrcVT == MVT::v4f64 || SrcVT == MVT::v4f32) &&
21026         Subtarget.useAVX512Regs() && Subtarget.hasDQI()) {
21027       assert(!Subtarget.hasVLX() && "Unexpected features!");
21028       MVT WideVT = SrcVT == MVT::v4f32 ? MVT::v8f32 : MVT::v8f64;
21029       // Need to concat with zero vector for strict fp to avoid spurious
21030       // exceptions.
21031       // TODO: Should we just do this for non-strict as well?
21032       SDValue Tmp =
21033           IsStrict ? DAG.getConstantFP(0.0, dl, WideVT) : DAG.getUNDEF(WideVT);
21034       Src = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, WideVT, Tmp, Src,
21035                         DAG.getIntPtrConstant(0, dl));
21036 
21037       if (IsStrict) {
21038         Res = DAG.getNode(Op.getOpcode(), dl, {MVT::v8i64, MVT::Other},
21039                           {Chain, Src});
21040         Chain = Res.getValue(1);
21041       } else {
21042         Res = DAG.getNode(Op.getOpcode(), dl, MVT::v8i64, Src);
21043       }
21044 
21045       Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT, Res,
21046                         DAG.getIntPtrConstant(0, dl));
21047 
21048       if (IsStrict)
21049         return DAG.getMergeValues({Res, Chain}, dl);
21050       return Res;
21051     }
21052 
21053     if (VT == MVT::v2i64 && SrcVT == MVT::v2f32) {
21054       if (!Subtarget.hasVLX()) {
21055         // Non-strict nodes without VLX can be widened to v4f32->v4i64 by the
21056         // type legalizer and then widened again by vector op legalization.
21057         if (!IsStrict)
21058           return SDValue();
21059 
21060         SDValue Zero = DAG.getConstantFP(0.0, dl, MVT::v2f32);
21061         SDValue Tmp = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v8f32,
21062                                   {Src, Zero, Zero, Zero});
21063         Tmp = DAG.getNode(Op.getOpcode(), dl, {MVT::v8i64, MVT::Other},
21064                           {Chain, Tmp});
21065         SDValue Chain = Tmp.getValue(1);
21066         Tmp = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v2i64, Tmp,
21067                           DAG.getIntPtrConstant(0, dl));
21068         return DAG.getMergeValues({Tmp, Chain}, dl);
21069       }
21070 
21071       assert(Subtarget.hasDQI() && Subtarget.hasVLX() && "Requires AVX512DQVL");
21072       SDValue Tmp = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4f32, Src,
21073                                 DAG.getUNDEF(MVT::v2f32));
21074       if (IsStrict) {
21075         unsigned Opc = IsSigned ? X86ISD::STRICT_CVTTP2SI
21076                                 : X86ISD::STRICT_CVTTP2UI;
21077         return DAG.getNode(Opc, dl, {VT, MVT::Other}, {Op->getOperand(0), Tmp});
21078       }
21079       unsigned Opc = IsSigned ? X86ISD::CVTTP2SI : X86ISD::CVTTP2UI;
21080       return DAG.getNode(Opc, dl, VT, Tmp);
21081     }
21082 
21083     // Generate optimized instructions for pre-AVX512 unsigned conversions from
21084     // vXf32/vXf64 to vXi32.
21085     if ((VT == MVT::v4i32 && SrcVT == MVT::v4f32) ||
21086         (VT == MVT::v4i32 && SrcVT == MVT::v4f64) ||
21087         (VT == MVT::v8i32 && SrcVT == MVT::v8f32)) {
21088       assert(!IsSigned && "Expected unsigned conversion!");
21089       return expandFP_TO_UINT_SSE(VT, Src, dl, DAG, Subtarget);
21090     }
21091 
21092     return SDValue();
21093   }
21094 
21095   assert(!VT.isVector());
21096 
21097   bool UseSSEReg = isScalarFPTypeInSSEReg(SrcVT);
21098 
21099   if (!IsSigned && UseSSEReg) {
21100     // Conversions from f32/f64 with AVX512 should be legal.
21101     if (Subtarget.hasAVX512())
21102       return Op;
21103 
21104     // We can leverage the specific way the "cvttss2si/cvttsd2si" instructions
21105     // behave on out-of-range inputs to generate optimized conversions.
21106     if (!IsStrict && ((VT == MVT::i32 && !Subtarget.is64Bit()) ||
21107                       (VT == MVT::i64 && Subtarget.is64Bit()))) {
21108       unsigned DstBits = VT.getScalarSizeInBits();
21109       APInt UIntLimit = APInt::getSignMask(DstBits);
21110       SDValue FloatOffset = DAG.getNode(ISD::UINT_TO_FP, dl, SrcVT,
21111                                         DAG.getConstant(UIntLimit, dl, VT));
21112       MVT SrcVecVT = MVT::getVectorVT(SrcVT, 128 / SrcVT.getScalarSizeInBits());
21113 
21114       // Calculate the converted result for values in the range:
21115       // (i32) 0 to 2^31-1 ("Small") and from 2^31 to 2^32-1 ("Big").
21116       // (i64) 0 to 2^63-1 ("Small") and from 2^63 to 2^64-1 ("Big").
21117       SDValue Small =
21118           DAG.getNode(X86ISD::CVTTS2SI, dl, VT,
21119                       DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, SrcVecVT, Src));
21120       SDValue Big = DAG.getNode(
21121           X86ISD::CVTTS2SI, dl, VT,
21122           DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, SrcVecVT,
21123                       DAG.getNode(ISD::FSUB, dl, SrcVT, Src, FloatOffset)));
21124 
21125       // The "CVTTS2SI" instruction conveniently sets the sign bit if
21126       // and only if the value was out of range. So we can use that
21127       // as our indicator that we should use "Big" instead of "Small".
21128       //
21129       // Use "Small" if "IsOverflown" has all bits cleared
21130       // and "0x80000000 | Big" if all bits in "IsOverflown" are set.
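            // For example (i32): Src = 3e9, which is exactly representable in f32:
            //   Small = cvttss2si(3e9)        = 0x80000000 (out of range, so the
            //                                   integer-indefinite value is used)
            //   Big   = cvttss2si(3e9 - 2^31) = 0x32D05E00
            //   IsOverflown = Small >>s 31    = 0xFFFFFFFF
            //   Small | (Big & IsOverflown)   = 0xB2D05E00 = 3000000000
            // For in-range inputs IsOverflown is 0 and the result is just "Small".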
21131       SDValue IsOverflown = DAG.getNode(
21132           ISD::SRA, dl, VT, Small, DAG.getConstant(DstBits - 1, dl, MVT::i8));
21133       return DAG.getNode(ISD::OR, dl, VT, Small,
21134                          DAG.getNode(ISD::AND, dl, VT, Big, IsOverflown));
21135     }
21136 
21137     // Use default expansion for i64.
21138     if (VT == MVT::i64)
21139       return SDValue();
21140 
21141     assert(VT == MVT::i32 && "Unexpected VT!");
21142 
21143     // Promote i32 to i64 and use a signed operation on 64-bit targets.
21144     // FIXME: This does not generate an invalid exception if the input does not
21145     // fit in i32. PR44019
21146     if (Subtarget.is64Bit()) {
21147       if (IsStrict) {
21148         Res = DAG.getNode(ISD::STRICT_FP_TO_SINT, dl, {MVT::i64, MVT::Other},
21149                           {Chain, Src});
21150         Chain = Res.getValue(1);
21151       } else
21152         Res = DAG.getNode(ISD::FP_TO_SINT, dl, MVT::i64, Src);
21153 
21154       Res = DAG.getNode(ISD::TRUNCATE, dl, VT, Res);
21155       if (IsStrict)
21156         return DAG.getMergeValues({Res, Chain}, dl);
21157       return Res;
21158     }
21159 
21160     // Use default expansion for SSE1/2 targets without SSE3. With SSE3 we can
21161     // use fisttp which will be handled later.
21162     if (!Subtarget.hasSSE3())
21163       return SDValue();
21164   }
21165 
21166   // Promote i16 to i32 if we can use a SSE operation or the type is f128.
21167   // FIXME: This does not generate an invalid exception if the input does not
21168   // fit in i16. PR44019
21169   if (VT == MVT::i16 && (UseSSEReg || SrcVT == MVT::f128)) {
21170     assert(IsSigned && "Expected i16 FP_TO_UINT to have been promoted!");
21171     if (IsStrict) {
21172       Res = DAG.getNode(ISD::STRICT_FP_TO_SINT, dl, {MVT::i32, MVT::Other},
21173                         {Chain, Src});
21174       Chain = Res.getValue(1);
21175     } else
21176       Res = DAG.getNode(ISD::FP_TO_SINT, dl, MVT::i32, Src);
21177 
21178     Res = DAG.getNode(ISD::TRUNCATE, dl, VT, Res);
21179     if (IsStrict)
21180       return DAG.getMergeValues({Res, Chain}, dl);
21181     return Res;
21182   }
21183 
21184   // If this is a FP_TO_SINT using SSEReg we're done.
21185   if (UseSSEReg && IsSigned)
21186     return Op;
21187 
21188   // fp128 needs to use a libcall.
21189   if (SrcVT == MVT::f128) {
21190     RTLIB::Libcall LC;
21191     if (IsSigned)
21192       LC = RTLIB::getFPTOSINT(SrcVT, VT);
21193     else
21194       LC = RTLIB::getFPTOUINT(SrcVT, VT);
21195 
21196     MakeLibCallOptions CallOptions;
21197     std::pair<SDValue, SDValue> Tmp =
21198         makeLibCall(DAG, LC, VT, Src, CallOptions, dl, Chain);
21199 
21200     if (IsStrict)
21201       return DAG.getMergeValues({ Tmp.first, Tmp.second }, dl);
21202 
21203     return Tmp.first;
21204   }
21205 
21206   // Fall back to X87.
21207   if (SDValue V = FP_TO_INTHelper(Op, DAG, IsSigned, Chain)) {
21208     if (IsStrict)
21209       return DAG.getMergeValues({V, Chain}, dl);
21210     return V;
21211   }
21212 
21213   llvm_unreachable("Expected FP_TO_INTHelper to handle all remaining cases.");
21214 }
21215 
21216 SDValue X86TargetLowering::LowerLRINT_LLRINT(SDValue Op,
21217                                              SelectionDAG &DAG) const {
21218   SDValue Src = Op.getOperand(0);
21219   EVT DstVT = Op.getSimpleValueType();
21220   MVT SrcVT = Src.getSimpleValueType();
21221 
21222   if (SrcVT.isVector())
21223     return DstVT.getScalarType() == MVT::i32 ? Op : SDValue();
21224 
21225   if (SrcVT == MVT::f16)
21226     return SDValue();
21227 
21228   // If the source is in an SSE register, the node is Legal.
21229   if (isScalarFPTypeInSSEReg(SrcVT))
21230     return Op;
21231 
21232   return LRINT_LLRINTHelper(Op.getNode(), DAG);
21233 }
21234 
21235 SDValue X86TargetLowering::LRINT_LLRINTHelper(SDNode *N,
21236                                               SelectionDAG &DAG) const {
21237   EVT DstVT = N->getValueType(0);
21238   SDValue Src = N->getOperand(0);
21239   EVT SrcVT = Src.getValueType();
21240 
21241   if (SrcVT != MVT::f32 && SrcVT != MVT::f64 && SrcVT != MVT::f80) {
21242     // f16 must be promoted before using the lowering in this routine.
21243     // fp128 does not use this lowering.
21244     return SDValue();
21245   }
21246 
21247   SDLoc DL(N);
21248   SDValue Chain = DAG.getEntryNode();
21249 
21250   bool UseSSE = isScalarFPTypeInSSEReg(SrcVT);
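        // Both lrint/llrint and the x87 FIST family round using the current
        // (dynamic) rounding mode, so no explicit rounding-mode handling is needed
        // here. For an SSE source we spill it to the stack and reload it onto the
        // x87 stack (as f80) so that FIST can consume it.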
21251 
21252   // If we're converting from SSE, the stack slot needs to hold both types.
21253   // Otherwise it only needs to hold the DstVT.
21254   EVT OtherVT = UseSSE ? SrcVT : DstVT;
21255   SDValue StackPtr = DAG.CreateStackTemporary(DstVT, OtherVT);
21256   int SPFI = cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex();
21257   MachinePointerInfo MPI =
21258       MachinePointerInfo::getFixedStack(DAG.getMachineFunction(), SPFI);
21259 
21260   if (UseSSE) {
21261     assert(DstVT == MVT::i64 && "Invalid LRINT/LLRINT to lower!");
21262     Chain = DAG.getStore(Chain, DL, Src, StackPtr, MPI);
21263     SDVTList Tys = DAG.getVTList(MVT::f80, MVT::Other);
21264     SDValue Ops[] = { Chain, StackPtr };
21265 
21266     Src = DAG.getMemIntrinsicNode(X86ISD::FLD, DL, Tys, Ops, SrcVT, MPI,
21267                                   /*Align*/ std::nullopt,
21268                                   MachineMemOperand::MOLoad);
21269     Chain = Src.getValue(1);
21270   }
21271 
21272   SDValue StoreOps[] = { Chain, Src, StackPtr };
21273   Chain = DAG.getMemIntrinsicNode(X86ISD::FIST, DL, DAG.getVTList(MVT::Other),
21274                                   StoreOps, DstVT, MPI, /*Align*/ std::nullopt,
21275                                   MachineMemOperand::MOStore);
21276 
21277   return DAG.getLoad(DstVT, DL, Chain, StackPtr, MPI);
21278 }
21279 
21280 SDValue
21281 X86TargetLowering::LowerFP_TO_INT_SAT(SDValue Op, SelectionDAG &DAG) const {
21282   // This is based on the TargetLowering::expandFP_TO_INT_SAT implementation,
21283   // but making use of X86 specifics to produce better instruction sequences.
21284   SDNode *Node = Op.getNode();
21285   bool IsSigned = Node->getOpcode() == ISD::FP_TO_SINT_SAT;
21286   unsigned FpToIntOpcode = IsSigned ? ISD::FP_TO_SINT : ISD::FP_TO_UINT;
21287   SDLoc dl(SDValue(Node, 0));
21288   SDValue Src = Node->getOperand(0);
21289 
21290   // There are three types involved here: SrcVT is the source floating point
21291   // type, DstVT is the type of the result, and TmpVT is the result of the
21292   // intermediate FP_TO_*INT operation we'll use (which may be a promotion of
21293   // DstVT).
21294   EVT SrcVT = Src.getValueType();
21295   EVT DstVT = Node->getValueType(0);
21296   EVT TmpVT = DstVT;
21297 
21298   // This code is only for floats and doubles. Fall back to generic code for
21299   // anything else.
21300   if (!isScalarFPTypeInSSEReg(SrcVT) || isSoftF16(SrcVT, Subtarget))
21301     return SDValue();
21302 
21303   EVT SatVT = cast<VTSDNode>(Node->getOperand(1))->getVT();
21304   unsigned SatWidth = SatVT.getScalarSizeInBits();
21305   unsigned DstWidth = DstVT.getScalarSizeInBits();
21306   unsigned TmpWidth = TmpVT.getScalarSizeInBits();
21307   assert(SatWidth <= DstWidth && SatWidth <= TmpWidth &&
21308          "Expected saturation width smaller than result width");
21309 
21310   // Promote result of FP_TO_*INT to at least 32 bits.
21311   if (TmpWidth < 32) {
21312     TmpVT = MVT::i32;
21313     TmpWidth = 32;
21314   }
21315 
21316   // Promote unsigned 32-bit conversions to 64-bit, because that allows us to
21317   // use a native signed conversion instead.
21318   if (SatWidth == 32 && !IsSigned && Subtarget.is64Bit()) {
21319     TmpVT = MVT::i64;
21320     TmpWidth = 64;
21321   }
21322 
21323   // If the saturation width is smaller than the size of the temporary result,
21324   // we can always use signed conversion, which is native.
21325   if (SatWidth < TmpWidth)
21326     FpToIntOpcode = ISD::FP_TO_SINT;
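        // For example, @llvm.fptosi.sat.i8.f32 ends up with TmpVT = i32: the input
        // is clamped to [-128.0, 127.0], converted with a native signed cvttss2si,
        // and the i32 result is truncated to i8 (a NaN input yields the
        // integer-indefinite value, whose truncation is 0).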
21327 
21328   // Determine minimum and maximum integer values and their corresponding
21329   // floating-point values.
21330   APInt MinInt, MaxInt;
21331   if (IsSigned) {
21332     MinInt = APInt::getSignedMinValue(SatWidth).sext(DstWidth);
21333     MaxInt = APInt::getSignedMaxValue(SatWidth).sext(DstWidth);
21334   } else {
21335     MinInt = APInt::getMinValue(SatWidth).zext(DstWidth);
21336     MaxInt = APInt::getMaxValue(SatWidth).zext(DstWidth);
21337   }
21338 
21339   APFloat MinFloat(DAG.EVTToAPFloatSemantics(SrcVT));
21340   APFloat MaxFloat(DAG.EVTToAPFloatSemantics(SrcVT));
21341 
21342   APFloat::opStatus MinStatus = MinFloat.convertFromAPInt(
21343     MinInt, IsSigned, APFloat::rmTowardZero);
21344   APFloat::opStatus MaxStatus = MaxFloat.convertFromAPInt(
21345     MaxInt, IsSigned, APFloat::rmTowardZero);
21346   bool AreExactFloatBounds = !(MinStatus & APFloat::opStatus::opInexact)
21347                           && !(MaxStatus & APFloat::opStatus::opInexact);
21348 
21349   SDValue MinFloatNode = DAG.getConstantFP(MinFloat, dl, SrcVT);
21350   SDValue MaxFloatNode = DAG.getConstantFP(MaxFloat, dl, SrcVT);
21351 
21352   // If the integer bounds are exactly representable as floats, emit a
21353   // min+max+fptoi sequence. Otherwise use comparisons and selects.
21354   if (AreExactFloatBounds) {
21355     if (DstVT != TmpVT) {
21356       // Clamp by MinFloat from below. If Src is NaN, propagate NaN.
21357       SDValue MinClamped = DAG.getNode(
21358         X86ISD::FMAX, dl, SrcVT, MinFloatNode, Src);
21359       // Clamp by MaxFloat from above. If Src is NaN, propagate NaN.
21360       SDValue BothClamped = DAG.getNode(
21361         X86ISD::FMIN, dl, SrcVT, MaxFloatNode, MinClamped);
21362       // Convert clamped value to integer.
21363       SDValue FpToInt = DAG.getNode(FpToIntOpcode, dl, TmpVT, BothClamped);
21364 
21365       // NaN will become INDVAL, with the top bit set and the rest zero.
21366       // Truncation will discard the top bit, resulting in zero.
21367       return DAG.getNode(ISD::TRUNCATE, dl, DstVT, FpToInt);
21368     }
21369 
21370     // Clamp by MinFloat from below. If Src is NaN, the result is MinFloat.
21371     SDValue MinClamped = DAG.getNode(
21372       X86ISD::FMAX, dl, SrcVT, Src, MinFloatNode);
21373     // Clamp by MaxFloat from above. NaN cannot occur.
21374     SDValue BothClamped = DAG.getNode(
21375       X86ISD::FMINC, dl, SrcVT, MinClamped, MaxFloatNode);
21376     // Convert clamped value to integer.
21377     SDValue FpToInt = DAG.getNode(FpToIntOpcode, dl, DstVT, BothClamped);
21378 
21379     if (!IsSigned) {
21380       // In the unsigned case we're done, because we mapped NaN to MinFloat,
21381       // which is zero.
21382       return FpToInt;
21383     }
21384 
21385     // Otherwise, select zero if Src is NaN.
21386     SDValue ZeroInt = DAG.getConstant(0, dl, DstVT);
21387     return DAG.getSelectCC(
21388       dl, Src, Src, ZeroInt, FpToInt, ISD::CondCode::SETUO);
21389   }
21390 
21391   SDValue MinIntNode = DAG.getConstant(MinInt, dl, DstVT);
21392   SDValue MaxIntNode = DAG.getConstant(MaxInt, dl, DstVT);
21393 
21394   // Result of direct conversion, which may be selected away.
21395   SDValue FpToInt = DAG.getNode(FpToIntOpcode, dl, TmpVT, Src);
21396 
21397   if (DstVT != TmpVT) {
21398     // NaN will become INDVAL, with the top bit set and the rest zero.
21399     // Truncation will discard the top bit, resulting in zero.
21400     FpToInt = DAG.getNode(ISD::TRUNCATE, dl, DstVT, FpToInt);
21401   }
21402 
21403   SDValue Select = FpToInt;
21404   // For signed conversions where we saturate to the same size as the
21405   // result type of the fptoi instructions, INDVAL coincides with integer
21406   // minimum, so we don't need to explicitly check it.
21407   if (!IsSigned || SatWidth != TmpVT.getScalarSizeInBits()) {
21408     // If Src ULT MinFloat, select MinInt. In particular, this also selects
21409     // MinInt if Src is NaN.
21410     Select = DAG.getSelectCC(
21411       dl, Src, MinFloatNode, MinIntNode, Select, ISD::CondCode::SETULT);
21412   }
21413 
21414   // If Src OGT MaxFloat, select MaxInt.
21415   Select = DAG.getSelectCC(
21416     dl, Src, MaxFloatNode, MaxIntNode, Select, ISD::CondCode::SETOGT);
21417 
21418   // In the unsigned case we are done, because we mapped NaN to MinInt, which
21419   // is already zero. The promoted case was already handled above.
21420   if (!IsSigned || DstVT != TmpVT) {
21421     return Select;
21422   }
21423 
21424   // Otherwise, select 0 if Src is NaN.
21425   SDValue ZeroInt = DAG.getConstant(0, dl, DstVT);
21426   return DAG.getSelectCC(
21427     dl, Src, Src, ZeroInt, Select, ISD::CondCode::SETUO);
21428 }
21429 
21430 SDValue X86TargetLowering::LowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) const {
21431   bool IsStrict = Op->isStrictFPOpcode();
21432 
21433   SDLoc DL(Op);
21434   MVT VT = Op.getSimpleValueType();
21435   SDValue Chain = IsStrict ? Op.getOperand(0) : SDValue();
21436   SDValue In = Op.getOperand(IsStrict ? 1 : 0);
21437   MVT SVT = In.getSimpleValueType();
21438 
21439   // Let f16->f80 get lowered to a libcall, except for Darwin, where we should
21440   // lower it to an fp_extend via f32 (as only f16<->f32 libcalls are available).
21441   if (VT == MVT::f128 || (SVT == MVT::f16 && VT == MVT::f80 &&
21442                           !Subtarget.getTargetTriple().isOSDarwin()))
21443     return SDValue();
21444 
21445   if ((SVT == MVT::v8f16 && Subtarget.hasF16C()) ||
21446       (SVT == MVT::v16f16 && Subtarget.useAVX512Regs()))
21447     return Op;
21448 
21449   if (SVT == MVT::f16) {
21450     if (Subtarget.hasFP16())
21451       return Op;
21452 
21453     if (VT != MVT::f32) {
21454       if (IsStrict)
21455         return DAG.getNode(
21456             ISD::STRICT_FP_EXTEND, DL, {VT, MVT::Other},
21457             {Chain, DAG.getNode(ISD::STRICT_FP_EXTEND, DL,
21458                                 {MVT::f32, MVT::Other}, {Chain, In})});
21459 
21460       return DAG.getNode(ISD::FP_EXTEND, DL, VT,
21461                          DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, In));
21462     }
21463 
21464     if (!Subtarget.hasF16C()) {
21465       if (!Subtarget.getTargetTriple().isOSDarwin())
21466         return SDValue();
21467 
21468       assert(VT == MVT::f32 && SVT == MVT::f16 && "unexpected extend libcall");
21469 
21470       // Need a libcall, but the ABI for f16 is soft-float on macOS.
21471       TargetLowering::CallLoweringInfo CLI(DAG);
21472       Chain = IsStrict ? Op.getOperand(0) : DAG.getEntryNode();
21473 
21474       In = DAG.getBitcast(MVT::i16, In);
21475       TargetLowering::ArgListTy Args;
21476       TargetLowering::ArgListEntry Entry;
21477       Entry.Node = In;
21478       Entry.Ty = EVT(MVT::i16).getTypeForEVT(*DAG.getContext());
21479       Entry.IsSExt = false;
21480       Entry.IsZExt = true;
21481       Args.push_back(Entry);
21482 
21483       SDValue Callee = DAG.getExternalSymbol(
21484           getLibcallName(RTLIB::FPEXT_F16_F32),
21485           getPointerTy(DAG.getDataLayout()));
21486       CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
21487           CallingConv::C, EVT(VT).getTypeForEVT(*DAG.getContext()), Callee,
21488           std::move(Args));
21489 
21490       SDValue Res;
21491       std::tie(Res,Chain) = LowerCallTo(CLI);
21492       if (IsStrict)
21493         Res = DAG.getMergeValues({Res, Chain}, DL);
21494 
21495       return Res;
21496     }
21497 
21498     In = DAG.getBitcast(MVT::i16, In);
21499     In = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, MVT::v8i16,
21500                      getZeroVector(MVT::v8i16, Subtarget, DAG, DL), In,
21501                      DAG.getIntPtrConstant(0, DL));
21502     SDValue Res;
21503     if (IsStrict) {
21504       Res = DAG.getNode(X86ISD::STRICT_CVTPH2PS, DL, {MVT::v4f32, MVT::Other},
21505                         {Chain, In});
21506       Chain = Res.getValue(1);
21507     } else {
21508       Res = DAG.getNode(X86ISD::CVTPH2PS, DL, MVT::v4f32, In,
21509                         DAG.getTargetConstant(4, DL, MVT::i32));
21510     }
21511     Res = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32, Res,
21512                       DAG.getIntPtrConstant(0, DL));
21513     if (IsStrict)
21514       return DAG.getMergeValues({Res, Chain}, DL);
21515     return Res;
21516   }
21517 
21518   if (!SVT.isVector() || SVT.getVectorElementType() == MVT::bf16)
21519     return Op;
21520 
21521   if (SVT.getVectorElementType() == MVT::f16) {
21522     if (Subtarget.hasFP16() && isTypeLegal(SVT))
21523       return Op;
21524     assert(Subtarget.hasF16C() && "Unexpected features!");
21525     if (SVT == MVT::v2f16)
21526       In = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v4f16, In,
21527                        DAG.getUNDEF(MVT::v2f16));
21528     SDValue Res = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v8f16, In,
21529                               DAG.getUNDEF(MVT::v4f16));
21530     if (IsStrict)
21531       return DAG.getNode(X86ISD::STRICT_VFPEXT, DL, {VT, MVT::Other},
21532                          {Op->getOperand(0), Res});
21533     return DAG.getNode(X86ISD::VFPEXT, DL, VT, Res);
21534   } else if (VT == MVT::v4f64 || VT == MVT::v8f64) {
21535     return Op;
21536   }
21537 
21538   assert(SVT == MVT::v2f32 && "Only customize MVT::v2f32 type legalization!");
21539 
21540   SDValue Res =
21541       DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v4f32, In, DAG.getUNDEF(SVT));
21542   if (IsStrict)
21543     return DAG.getNode(X86ISD::STRICT_VFPEXT, DL, {VT, MVT::Other},
21544                        {Op->getOperand(0), Res});
21545   return DAG.getNode(X86ISD::VFPEXT, DL, VT, Res);
21546 }
21547 
21548 SDValue X86TargetLowering::LowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const {
21549   bool IsStrict = Op->isStrictFPOpcode();
21550 
21551   SDLoc DL(Op);
21552   SDValue Chain = IsStrict ? Op.getOperand(0) : SDValue();
21553   SDValue In = Op.getOperand(IsStrict ? 1 : 0);
21554   MVT VT = Op.getSimpleValueType();
21555   MVT SVT = In.getSimpleValueType();
21556 
21557   if (SVT == MVT::f128 || (VT == MVT::f16 && SVT == MVT::f80))
21558     return SDValue();
21559 
21560   if (VT == MVT::f16 && (SVT == MVT::f64 || SVT == MVT::f32) &&
21561       !Subtarget.hasFP16() && (SVT == MVT::f64 || !Subtarget.hasF16C())) {
21562     if (!Subtarget.getTargetTriple().isOSDarwin())
21563       return SDValue();
21564 
21565     // We need a libcall, but the ABI for f16 libcalls on macOS is soft-float.
21566     TargetLowering::CallLoweringInfo CLI(DAG);
21567     Chain = IsStrict ? Op.getOperand(0) : DAG.getEntryNode();
21568 
21569     TargetLowering::ArgListTy Args;
21570     TargetLowering::ArgListEntry Entry;
21571     Entry.Node = In;
21572     Entry.Ty = EVT(SVT).getTypeForEVT(*DAG.getContext());
21573     Entry.IsSExt = false;
21574     Entry.IsZExt = true;
21575     Args.push_back(Entry);
21576 
21577     SDValue Callee = DAG.getExternalSymbol(
21578         getLibcallName(SVT == MVT::f64 ? RTLIB::FPROUND_F64_F16
21579                                        : RTLIB::FPROUND_F32_F16),
21580         getPointerTy(DAG.getDataLayout()));
21581     CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
21582         CallingConv::C, EVT(MVT::i16).getTypeForEVT(*DAG.getContext()), Callee,
21583         std::move(Args));
21584 
21585     SDValue Res;
21586     std::tie(Res, Chain) = LowerCallTo(CLI);
21587 
21588     Res = DAG.getBitcast(MVT::f16, Res);
21589 
21590     if (IsStrict)
21591       Res = DAG.getMergeValues({Res, Chain}, DL);
21592 
21593     return Res;
21594   }
21595 
21596   if (VT.getScalarType() == MVT::bf16) {
21597     if (SVT.getScalarType() == MVT::f32 &&
21598         ((Subtarget.hasBF16() && Subtarget.hasVLX()) ||
21599          Subtarget.hasAVXNECONVERT()))
21600       return Op;
21601     return SDValue();
21602   }
21603 
21604   if (VT.getScalarType() == MVT::f16 && !Subtarget.hasFP16()) {
21605     if (!Subtarget.hasF16C() || SVT.getScalarType() != MVT::f32)
21606       return SDValue();
21607 
21608     if (VT.isVector())
21609       return Op;
21610 
21611     SDValue Res;
21612     SDValue Rnd = DAG.getTargetConstant(X86::STATIC_ROUNDING::CUR_DIRECTION, DL,
21613                                         MVT::i32);
21614     if (IsStrict) {
21615       Res = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, MVT::v4f32,
21616                         DAG.getConstantFP(0, DL, MVT::v4f32), In,
21617                         DAG.getIntPtrConstant(0, DL));
21618       Res = DAG.getNode(X86ISD::STRICT_CVTPS2PH, DL, {MVT::v8i16, MVT::Other},
21619                         {Chain, Res, Rnd});
21620       Chain = Res.getValue(1);
21621     } else {
21622       // FIXME: Should we use zeros for upper elements for non-strict?
21623       Res = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v4f32, In);
21624       Res = DAG.getNode(X86ISD::CVTPS2PH, DL, MVT::v8i16, Res, Rnd);
21625     }
21626 
21627     Res = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i16, Res,
21628                       DAG.getIntPtrConstant(0, DL));
21629     Res = DAG.getBitcast(MVT::f16, Res);
21630 
21631     if (IsStrict)
21632       return DAG.getMergeValues({Res, Chain}, DL);
21633 
21634     return Res;
21635   }
21636 
21637   return Op;
21638 }
21639 
21640 static SDValue LowerFP16_TO_FP(SDValue Op, SelectionDAG &DAG) {
21641   bool IsStrict = Op->isStrictFPOpcode();
21642   SDValue Src = Op.getOperand(IsStrict ? 1 : 0);
21643   assert(Src.getValueType() == MVT::i16 && Op.getValueType() == MVT::f32 &&
21644          "Unexpected VT!");
21645 
21646   SDLoc dl(Op);
21647   SDValue Res = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, MVT::v8i16,
21648                             DAG.getConstant(0, dl, MVT::v8i16), Src,
21649                             DAG.getIntPtrConstant(0, dl));
21650 
21651   SDValue Chain;
21652   if (IsStrict) {
21653     Res = DAG.getNode(X86ISD::STRICT_CVTPH2PS, dl, {MVT::v4f32, MVT::Other},
21654                       {Op.getOperand(0), Res});
21655     Chain = Res.getValue(1);
21656   } else {
21657     Res = DAG.getNode(X86ISD::CVTPH2PS, dl, MVT::v4f32, Res);
21658   }
21659 
21660   Res = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::f32, Res,
21661                     DAG.getIntPtrConstant(0, dl));
21662 
21663   if (IsStrict)
21664     return DAG.getMergeValues({Res, Chain}, dl);
21665 
21666   return Res;
21667 }
21668 
21669 static SDValue LowerFP_TO_FP16(SDValue Op, SelectionDAG &DAG) {
21670   bool IsStrict = Op->isStrictFPOpcode();
21671   SDValue Src = Op.getOperand(IsStrict ? 1 : 0);
21672   assert(Src.getValueType() == MVT::f32 && Op.getValueType() == MVT::i16 &&
21673          "Unexpected VT!");
21674 
21675   SDLoc dl(Op);
21676   SDValue Res, Chain;
21677   if (IsStrict) {
21678     Res = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, MVT::v4f32,
21679                       DAG.getConstantFP(0, dl, MVT::v4f32), Src,
21680                       DAG.getIntPtrConstant(0, dl));
21681     Res = DAG.getNode(
21682         X86ISD::STRICT_CVTPS2PH, dl, {MVT::v8i16, MVT::Other},
21683         {Op.getOperand(0), Res, DAG.getTargetConstant(4, dl, MVT::i32)});
21684     Chain = Res.getValue(1);
21685   } else {
21686     // FIXME: Should we use zeros for upper elements for non-strict?
21687     Res = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, MVT::v4f32, Src);
21688     Res = DAG.getNode(X86ISD::CVTPS2PH, dl, MVT::v8i16, Res,
21689                       DAG.getTargetConstant(4, dl, MVT::i32));
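          // Immediate 4 is X86::STATIC_ROUNDING::CUR_DIRECTION: CVTPS2PH rounds
          // using the MXCSR rounding mode instead of an immediate-encoded one.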
21690   }
21691 
21692   Res = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i16, Res,
21693                     DAG.getIntPtrConstant(0, dl));
21694 
21695   if (IsStrict)
21696     return DAG.getMergeValues({Res, Chain}, dl);
21697 
21698   return Res;
21699 }
21700 
21701 SDValue X86TargetLowering::LowerFP_TO_BF16(SDValue Op,
21702                                            SelectionDAG &DAG) const {
21703   SDLoc DL(Op);
21704 
21705   MVT SVT = Op.getOperand(0).getSimpleValueType();
21706   if (SVT == MVT::f32 && ((Subtarget.hasBF16() && Subtarget.hasVLX()) ||
21707                           Subtarget.hasAVXNECONVERT())) {
21708     SDValue Res;
21709     Res = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v4f32, Op.getOperand(0));
21710     Res = DAG.getNode(X86ISD::CVTNEPS2BF16, DL, MVT::v8bf16, Res);
21711     Res = DAG.getBitcast(MVT::v8i16, Res);
21712     return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i16, Res,
21713                        DAG.getIntPtrConstant(0, DL));
21714   }
21715 
21716   MakeLibCallOptions CallOptions;
21717   RTLIB::Libcall LC = RTLIB::getFPROUND(SVT, MVT::bf16);
21718   SDValue Res =
21719       makeLibCall(DAG, LC, MVT::f16, Op.getOperand(0), CallOptions, DL).first;
21720   return DAG.getBitcast(MVT::i16, Res);
21721 }
21722 
21723 /// Depending on uarch and/or optimizing for size, we might prefer to use a
21724 /// vector operation in place of the typical scalar operation.
21725 static SDValue lowerAddSubToHorizontalOp(SDValue Op, const SDLoc &DL,
21726                                          SelectionDAG &DAG,
21727                                          const X86Subtarget &Subtarget) {
21728   // If both operands have other uses, this is probably not profitable.
21729   SDValue LHS = Op.getOperand(0);
21730   SDValue RHS = Op.getOperand(1);
21731   if (!LHS.hasOneUse() && !RHS.hasOneUse())
21732     return Op;
21733 
21734   // FP horizontal add/sub were added with SSE3. Integer with SSSE3.
21735   bool IsFP = Op.getSimpleValueType().isFloatingPoint();
21736   if (IsFP && !Subtarget.hasSSE3())
21737     return Op;
21738   if (!IsFP && !Subtarget.hasSSSE3())
21739     return Op;
21740 
21741   // Extract from a common vector.
21742   if (LHS.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
21743       RHS.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
21744       LHS.getOperand(0) != RHS.getOperand(0) ||
21745       !isa<ConstantSDNode>(LHS.getOperand(1)) ||
21746       !isa<ConstantSDNode>(RHS.getOperand(1)) ||
21747       !shouldUseHorizontalOp(true, DAG, Subtarget))
21748     return Op;
21749 
21750   // Allow commuted 'hadd' ops.
21751   // TODO: Allow commuted (f)sub by negating the result of (F)HSUB?
21752   unsigned HOpcode;
21753   switch (Op.getOpcode()) {
21754   // clang-format off
21755   case ISD::ADD: HOpcode = X86ISD::HADD; break;
21756   case ISD::SUB: HOpcode = X86ISD::HSUB; break;
21757   case ISD::FADD: HOpcode = X86ISD::FHADD; break;
21758   case ISD::FSUB: HOpcode = X86ISD::FHSUB; break;
21759   default:
21760     llvm_unreachable("Trying to lower unsupported opcode to horizontal op");
21761   // clang-format on
21762   }
21763   unsigned LExtIndex = LHS.getConstantOperandVal(1);
21764   unsigned RExtIndex = RHS.getConstantOperandVal(1);
21765   if ((LExtIndex & 1) == 1 && (RExtIndex & 1) == 0 &&
21766       (HOpcode == X86ISD::HADD || HOpcode == X86ISD::FHADD))
21767     std::swap(LExtIndex, RExtIndex);
21768 
21769   if ((LExtIndex & 1) != 0 || RExtIndex != (LExtIndex + 1))
21770     return Op;
21771 
21772   SDValue X = LHS.getOperand(0);
21773   EVT VecVT = X.getValueType();
21774   unsigned BitWidth = VecVT.getSizeInBits();
21775   unsigned NumLanes = BitWidth / 128;
21776   unsigned NumEltsPerLane = VecVT.getVectorNumElements() / NumLanes;
21777   assert((BitWidth == 128 || BitWidth == 256 || BitWidth == 512) &&
21778          "Not expecting illegal vector widths here");
21779 
21780   // Creating a 256-bit horizontal op would be wasteful, and there is no 512-bit
21781   // equivalent, so extract the 256/512-bit source op to 128-bit if we can.
21782   if (BitWidth == 256 || BitWidth == 512) {
21783     unsigned LaneIdx = LExtIndex / NumEltsPerLane;
21784     X = extract128BitVector(X, LaneIdx * NumEltsPerLane, DAG, DL);
21785     LExtIndex %= NumEltsPerLane;
21786   }
21787 
21788   // add (extractelt (X, 0), extractelt (X, 1)) --> extractelt (hadd X, X), 0
21789   // add (extractelt (X, 1), extractelt (X, 0)) --> extractelt (hadd X, X), 0
21790   // add (extractelt (X, 2), extractelt (X, 3)) --> extractelt (hadd X, X), 1
21791   // sub (extractelt (X, 0), extractelt (X, 1)) --> extractelt (hsub X, X), 0
21792   SDValue HOp = DAG.getNode(HOpcode, DL, X.getValueType(), X, X);
21793   return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, Op.getSimpleValueType(), HOp,
21794                      DAG.getIntPtrConstant(LExtIndex / 2, DL));
21795 }
21796 
21797 /// Depending on uarch and/or optimizing for size, we might prefer to use a
21798 /// vector operation in place of the typical scalar operation.
21799 SDValue X86TargetLowering::lowerFaddFsub(SDValue Op, SelectionDAG &DAG) const {
21800   assert((Op.getValueType() == MVT::f32 || Op.getValueType() == MVT::f64) &&
21801          "Only expecting float/double");
21802   return lowerAddSubToHorizontalOp(Op, SDLoc(Op), DAG, Subtarget);
21803 }
21804 
21805 /// ISD::FROUND is defined to round to nearest with ties rounding away from 0.
21806 /// This mode isn't supported in hardware on X86. But as long as we aren't
21807 /// compiling with trapping math, we can emulate this with
21808 /// trunc(X + copysign(nextafter(0.5, 0.0), X)).
21809 static SDValue LowerFROUND(SDValue Op, SelectionDAG &DAG) {
21810   SDValue N0 = Op.getOperand(0);
21811   SDLoc dl(Op);
21812   MVT VT = Op.getSimpleValueType();
21813 
21814   // N0 += copysign(nextafter(0.5, 0.0), N0)
21815   const fltSemantics &Sem = SelectionDAG::EVTToAPFloatSemantics(VT);
21816   bool Ignored;
21817   APFloat Point5Pred = APFloat(0.5f);
21818   Point5Pred.convert(Sem, APFloat::rmNearestTiesToEven, &Ignored);
21819   Point5Pred.next(/*nextDown*/true);
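        // Using the value just below 0.5 avoids double rounding: for the largest
        // float smaller than 0.5, adding exactly 0.5 would round up to 1.0, and
        // the trunc would then produce 1 instead of the correct 0.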
21820 
21821   SDValue Adder = DAG.getNode(ISD::FCOPYSIGN, dl, VT,
21822                               DAG.getConstantFP(Point5Pred, dl, VT), N0);
21823   N0 = DAG.getNode(ISD::FADD, dl, VT, N0, Adder);
21824 
21825   // Truncate the result to remove fraction.
21826   return DAG.getNode(ISD::FTRUNC, dl, VT, N0);
21827 }
21828 
21829 /// The only differences between FABS and FNEG are the mask and the logic op.
21830 /// FNEG also has a folding opportunity for FNEG(FABS(x)).
21831 static SDValue LowerFABSorFNEG(SDValue Op, SelectionDAG &DAG) {
21832   assert((Op.getOpcode() == ISD::FABS || Op.getOpcode() == ISD::FNEG) &&
21833          "Wrong opcode for lowering FABS or FNEG.");
21834 
21835   bool IsFABS = (Op.getOpcode() == ISD::FABS);
21836 
21837   // If this is a FABS and it has an FNEG user, bail out to fold the combination
21838   // into an FNABS. We'll lower the FABS after that if it is still in use.
21839   if (IsFABS)
21840     for (SDNode *User : Op->uses())
21841       if (User->getOpcode() == ISD::FNEG)
21842         return Op;
21843 
21844   SDLoc dl(Op);
21845   MVT VT = Op.getSimpleValueType();
21846 
21847   bool IsF128 = (VT == MVT::f128);
21848   assert(VT.isFloatingPoint() && VT != MVT::f80 &&
21849          DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
21850          "Unexpected type in LowerFABSorFNEG");
21851 
21852   // FIXME: Use function attribute "OptimizeForSize" and/or CodeGenOptLevel to
21853   // decide if we should generate a 16-byte constant mask when we only need 4 or
21854   // 8 bytes for the scalar case.
21855 
21856   // There are no scalar bitwise logical SSE/AVX instructions, so we
21857   // generate a 16-byte vector constant and logic op even for the scalar case.
21858   // Using a 16-byte mask allows folding the load of the mask with
21859   // the logic op, which can save ~4 bytes of code size.
21860   bool IsFakeVector = !VT.isVector() && !IsF128;
21861   MVT LogicVT = VT;
21862   if (IsFakeVector)
21863     LogicVT = (VT == MVT::f64)   ? MVT::v2f64
21864               : (VT == MVT::f32) ? MVT::v4f32
21865                                  : MVT::v8f16;
21866 
21867   unsigned EltBits = VT.getScalarSizeInBits();
21868   // For FABS, mask is 0x7f...; for FNEG, mask is 0x80...
21869   APInt MaskElt = IsFABS ? APInt::getSignedMaxValue(EltBits) :
21870                            APInt::getSignMask(EltBits);
21871   const fltSemantics &Sem = SelectionDAG::EVTToAPFloatSemantics(VT);
21872   SDValue Mask = DAG.getConstantFP(APFloat(Sem, MaskElt), dl, LogicVT);
21873 
21874   SDValue Op0 = Op.getOperand(0);
21875   bool IsFNABS = !IsFABS && (Op0.getOpcode() == ISD::FABS);
21876   unsigned LogicOp = IsFABS  ? X86ISD::FAND :
21877                      IsFNABS ? X86ISD::FOR  :
21878                                X86ISD::FXOR;
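        // These typically select to andps/andpd (FABS), orps/orpd (FNABS) and
        // xorps/xorpd (FNEG) with a constant-pool mask operand.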
21879   SDValue Operand = IsFNABS ? Op0.getOperand(0) : Op0;
21880 
21881   if (VT.isVector() || IsF128)
21882     return DAG.getNode(LogicOp, dl, LogicVT, Operand, Mask);
21883 
21884   // For the scalar case extend to a 128-bit vector, perform the logic op,
21885   // and extract the scalar result back out.
21886   Operand = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, LogicVT, Operand);
21887   SDValue LogicNode = DAG.getNode(LogicOp, dl, LogicVT, Operand, Mask);
21888   return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT, LogicNode,
21889                      DAG.getIntPtrConstant(0, dl));
21890 }
21891 
21892 static SDValue LowerFCOPYSIGN(SDValue Op, SelectionDAG &DAG) {
21893   SDValue Mag = Op.getOperand(0);
21894   SDValue Sign = Op.getOperand(1);
21895   SDLoc dl(Op);
21896 
21897   // If the sign operand is smaller, extend it first.
21898   MVT VT = Op.getSimpleValueType();
21899   if (Sign.getSimpleValueType().bitsLT(VT))
21900     Sign = DAG.getNode(ISD::FP_EXTEND, dl, VT, Sign);
21901 
21902   // And if it is bigger, shrink it first.
21903   if (Sign.getSimpleValueType().bitsGT(VT))
21904     Sign = DAG.getNode(ISD::FP_ROUND, dl, VT, Sign,
21905                        DAG.getIntPtrConstant(0, dl, /*isTarget=*/true));
21906 
21907   // At this point the operands and the result should have the same
21908   // type, and that won't be f80 since that is not custom lowered.
21909   bool IsF128 = (VT == MVT::f128);
21910   assert(VT.isFloatingPoint() && VT != MVT::f80 &&
21911          DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
21912          "Unexpected type in LowerFCOPYSIGN");
21913 
21914   const fltSemantics &Sem = SelectionDAG::EVTToAPFloatSemantics(VT);
21915 
21916   // Perform all scalar logic operations as 16-byte vectors because there are no
21917   // scalar FP logic instructions in SSE.
21918   // TODO: This isn't necessary. If we used scalar types, we might avoid some
21919   // unnecessary splats, but we might miss load folding opportunities. Should
21920   // this decision be based on OptimizeForSize?
21921   bool IsFakeVector = !VT.isVector() && !IsF128;
21922   MVT LogicVT = VT;
21923   if (IsFakeVector)
21924     LogicVT = (VT == MVT::f64)   ? MVT::v2f64
21925               : (VT == MVT::f32) ? MVT::v4f32
21926                                  : MVT::v8f16;
21927 
21928   // The mask constants are automatically splatted for vector types.
21929   unsigned EltSizeInBits = VT.getScalarSizeInBits();
21930   SDValue SignMask = DAG.getConstantFP(
21931       APFloat(Sem, APInt::getSignMask(EltSizeInBits)), dl, LogicVT);
21932   SDValue MagMask = DAG.getConstantFP(
21933       APFloat(Sem, APInt::getSignedMaxValue(EltSizeInBits)), dl, LogicVT);
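        // In bit terms this computes, e.g. for f32:
        //   copysign(Mag, Sign) = (Mag & 0x7fffffff) | (Sign & 0x80000000)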
21934 
21935   // First, clear all bits but the sign bit from the second operand (sign).
21936   if (IsFakeVector)
21937     Sign = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, LogicVT, Sign);
21938   SDValue SignBit = DAG.getNode(X86ISD::FAND, dl, LogicVT, Sign, SignMask);
21939 
21940   // Next, clear the sign bit from the first operand (magnitude).
21941   // TODO: If we had general constant folding for FP logic ops, this check
21942   // wouldn't be necessary.
21943   SDValue MagBits;
21944   if (ConstantFPSDNode *Op0CN = isConstOrConstSplatFP(Mag)) {
21945     APFloat APF = Op0CN->getValueAPF();
21946     APF.clearSign();
21947     MagBits = DAG.getConstantFP(APF, dl, LogicVT);
21948   } else {
21949     // If the magnitude operand wasn't a constant, we need to AND out the sign.
21950     if (IsFakeVector)
21951       Mag = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, LogicVT, Mag);
21952     MagBits = DAG.getNode(X86ISD::FAND, dl, LogicVT, Mag, MagMask);
21953   }
21954 
21955   // OR the magnitude value with the sign bit.
21956   SDValue Or = DAG.getNode(X86ISD::FOR, dl, LogicVT, MagBits, SignBit);
21957   return !IsFakeVector ? Or : DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT, Or,
21958                                           DAG.getIntPtrConstant(0, dl));
21959 }
21960 
21961 static SDValue LowerFGETSIGN(SDValue Op, SelectionDAG &DAG) {
21962   SDValue N0 = Op.getOperand(0);
21963   SDLoc dl(Op);
21964   MVT VT = Op.getSimpleValueType();
21965 
21966   MVT OpVT = N0.getSimpleValueType();
21967   assert((OpVT == MVT::f32 || OpVT == MVT::f64) &&
21968          "Unexpected type for FGETSIGN");
21969 
21970   // Lower ISD::FGETSIGN to (AND (X86ISD::MOVMSK ...) 1).
21971   MVT VecVT = (OpVT == MVT::f32 ? MVT::v4f32 : MVT::v2f64);
21972   SDValue Res = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, VecVT, N0);
21973   Res = DAG.getNode(X86ISD::MOVMSK, dl, MVT::i32, Res);
21974   Res = DAG.getZExtOrTrunc(Res, dl, VT);
21975   Res = DAG.getNode(ISD::AND, dl, VT, Res, DAG.getConstant(1, dl, VT));
21976   return Res;
21977 }
21978 
21979 /// Helper for attempting to create a X86ISD::BT node.
21980 static SDValue getBT(SDValue Src, SDValue BitNo, const SDLoc &DL, SelectionDAG &DAG) {
21981   // If Src is i8, promote it to i32 with any_extend.  There is no i8 BT
21982   // instruction.  Since the shift amount is in-range-or-undefined, we know
21983   // that doing a bittest on the i32 value is ok.  We extend to i32 because
21984   // the encoding for the i16 version is larger than the i32 version.
21985   // Also promote i16 to i32 for performance / code size reasons.
21986   if (Src.getValueType().getScalarSizeInBits() < 32)
21987     Src = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Src);
21988 
21989   // No legal type found, give up.
21990   if (!DAG.getTargetLoweringInfo().isTypeLegal(Src.getValueType()))
21991     return SDValue();
21992 
21993   // See if we can use the 32-bit instruction instead of the 64-bit one for a
21994   // shorter encoding. Since the former takes BitNo modulo 32 and the latter
21995   // takes it modulo 64, this is only valid if bit 5 of BitNo is known to be
21996   // zero.
21997   if (Src.getValueType() == MVT::i64 &&
21998       DAG.MaskedValueIsZero(BitNo, APInt(BitNo.getValueSizeInBits(), 32)))
21999     Src = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Src);
22000 
22001   // If the operand types disagree, extend the shift amount to match.  Since
22002   // BT ignores high bits (like shifts) we can use anyextend.
22003   if (Src.getValueType() != BitNo.getValueType()) {
22004     // Peek through a mask/modulo operation.
22005     // TODO: DAGCombine fails to do this as it just checks isTruncateFree, but
22006     // we probably need a better IsDesirableToPromoteOp to handle this as well.
22007     if (BitNo.getOpcode() == ISD::AND && BitNo->hasOneUse())
22008       BitNo = DAG.getNode(ISD::AND, DL, Src.getValueType(),
22009                           DAG.getNode(ISD::ANY_EXTEND, DL, Src.getValueType(),
22010                                       BitNo.getOperand(0)),
22011                           DAG.getNode(ISD::ANY_EXTEND, DL, Src.getValueType(),
22012                                       BitNo.getOperand(1)));
22013     else
22014       BitNo = DAG.getNode(ISD::ANY_EXTEND, DL, Src.getValueType(), BitNo);
22015   }
22016 
22017   return DAG.getNode(X86ISD::BT, DL, MVT::i32, Src, BitNo);
22018 }
22019 
22020 /// Helper for creating a X86ISD::SETCC node.
22021 static SDValue getSETCC(X86::CondCode Cond, SDValue EFLAGS, const SDLoc &dl,
22022                         SelectionDAG &DAG) {
22023   return DAG.getNode(X86ISD::SETCC, dl, MVT::i8,
22024                      DAG.getTargetConstant(Cond, dl, MVT::i8), EFLAGS);
22025 }
22026 
22027 /// Recursive helper for combineVectorSizedSetCCEquality() to see if we have a
22028 /// recognizable memcmp expansion.
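      /// e.g. (or (or (xor A, B), (xor C, D)), (xor E, F)) matches, but a bare
      /// (xor A, B) at the root does not.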
22029 static bool isOrXorXorTree(SDValue X, bool Root = true) {
22030   if (X.getOpcode() == ISD::OR)
22031     return isOrXorXorTree(X.getOperand(0), false) &&
22032            isOrXorXorTree(X.getOperand(1), false);
22033   if (Root)
22034     return false;
22035   return X.getOpcode() == ISD::XOR;
22036 }
22037 
22038 /// Recursive helper for combineVectorSizedSetCCEquality() to emit the memcmp
22039 /// expansion.
22040 template <typename F>
22041 static SDValue emitOrXorXorTree(SDValue X, const SDLoc &DL, SelectionDAG &DAG,
22042                                 EVT VecVT, EVT CmpVT, bool HasPT, F SToV) {
22043   SDValue Op0 = X.getOperand(0);
22044   SDValue Op1 = X.getOperand(1);
22045   if (X.getOpcode() == ISD::OR) {
22046     SDValue A = emitOrXorXorTree(Op0, DL, DAG, VecVT, CmpVT, HasPT, SToV);
22047     SDValue B = emitOrXorXorTree(Op1, DL, DAG, VecVT, CmpVT, HasPT, SToV);
22048     if (VecVT != CmpVT)
22049       return DAG.getNode(ISD::OR, DL, CmpVT, A, B);
22050     if (HasPT)
22051       return DAG.getNode(ISD::OR, DL, VecVT, A, B);
22052     return DAG.getNode(ISD::AND, DL, CmpVT, A, B);
22053   }
22054   if (X.getOpcode() == ISD::XOR) {
22055     SDValue A = SToV(Op0);
22056     SDValue B = SToV(Op1);
22057     if (VecVT != CmpVT)
22058       return DAG.getSetCC(DL, CmpVT, A, B, ISD::SETNE);
22059     if (HasPT)
22060       return DAG.getNode(ISD::XOR, DL, VecVT, A, B);
22061     return DAG.getSetCC(DL, CmpVT, A, B, ISD::SETEQ);
22062   }
22063   llvm_unreachable("Impossible");
22064 }
22065 
22066 /// Try to map a 128-bit or larger integer comparison to vector instructions
22067 /// before type legalization splits it up into chunks.
22068 static SDValue combineVectorSizedSetCCEquality(EVT VT, SDValue X, SDValue Y,
22069                                                ISD::CondCode CC,
22070                                                const SDLoc &DL,
22071                                                SelectionDAG &DAG,
22072                                                const X86Subtarget &Subtarget) {
22073   assert((CC == ISD::SETNE || CC == ISD::SETEQ) && "Bad comparison predicate");
22074 
22075   // We're looking for an oversized integer equality comparison.
22076   EVT OpVT = X.getValueType();
22077   unsigned OpSize = OpVT.getSizeInBits();
22078   if (!OpVT.isScalarInteger() || OpSize < 128)
22079     return SDValue();
22080 
22081   // Ignore a comparison with zero because that gets special treatment in
22082   // EmitTest(). But make an exception for the special case of a pair of
22083   // logically-combined vector-sized operands compared to zero. This pattern may
22084   // be generated by the memcmp expansion pass with oversized integer compares
22085   // (see PR33325).
22086   bool IsOrXorXorTreeCCZero = isNullConstant(Y) && isOrXorXorTree(X);
22087   if (isNullConstant(Y) && !IsOrXorXorTreeCCZero)
22088     return SDValue();
22089 
22090   // Don't perform this combine if constructing the vector will be expensive.
22091   auto IsVectorBitCastCheap = [](SDValue X) {
22092     X = peekThroughBitcasts(X);
22093     return isa<ConstantSDNode>(X) || X.getValueType().isVector() ||
22094            X.getOpcode() == ISD::LOAD;
22095   };
22096   if ((!IsVectorBitCastCheap(X) || !IsVectorBitCastCheap(Y)) &&
22097       !IsOrXorXorTreeCCZero)
22098     return SDValue();
22099 
22100   // Use XOR (plus OR) and PTEST after SSE4.1 for 128/256-bit operands.
22101   // Use PCMPNEQ (plus OR) and KORTEST for 512-bit operands.
22102   // Otherwise use PCMPEQ (plus AND) and mask testing.
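        // e.g. a 256-bit equality compare on AVX2 becomes roughly:
        //   vpxor %ymm1, %ymm0, %ymm0 ; vptest %ymm0, %ymm0 ; sete/setne %al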
22103   bool NoImplicitFloatOps =
22104       DAG.getMachineFunction().getFunction().hasFnAttribute(
22105           Attribute::NoImplicitFloat);
22106   if (!Subtarget.useSoftFloat() && !NoImplicitFloatOps &&
22107       ((OpSize == 128 && Subtarget.hasSSE2()) ||
22108        (OpSize == 256 && Subtarget.hasAVX()) ||
22109        (OpSize == 512 && Subtarget.useAVX512Regs()))) {
22110     bool HasPT = Subtarget.hasSSE41();
22111 
22112     // PTEST and MOVMSK are slow on Knights Landing and Knights Mill, and widened
22113     // vector registers are essentially free. (Technically, widening registers
22114     // prevents load folding, but the tradeoff is worth it.)
22115     bool PreferKOT = Subtarget.preferMaskRegisters();
22116     bool NeedZExt = PreferKOT && !Subtarget.hasVLX() && OpSize != 512;
22117 
22118     EVT VecVT = MVT::v16i8;
22119     EVT CmpVT = PreferKOT ? MVT::v16i1 : VecVT;
22120     if (OpSize == 256) {
22121       VecVT = MVT::v32i8;
22122       CmpVT = PreferKOT ? MVT::v32i1 : VecVT;
22123     }
22124     EVT CastVT = VecVT;
22125     bool NeedsAVX512FCast = false;
22126     if (OpSize == 512 || NeedZExt) {
22127       if (Subtarget.hasBWI()) {
22128         VecVT = MVT::v64i8;
22129         CmpVT = MVT::v64i1;
22130         if (OpSize == 512)
22131           CastVT = VecVT;
22132       } else {
22133         VecVT = MVT::v16i32;
22134         CmpVT = MVT::v16i1;
22135         CastVT = OpSize == 512   ? VecVT
22136                  : OpSize == 256 ? MVT::v8i32
22137                                  : MVT::v4i32;
22138         NeedsAVX512FCast = true;
22139       }
22140     }
22141 
22142     auto ScalarToVector = [&](SDValue X) -> SDValue {
22143       bool TmpZext = false;
22144       EVT TmpCastVT = CastVT;
22145       if (X.getOpcode() == ISD::ZERO_EXTEND) {
22146         SDValue OrigX = X.getOperand(0);
22147         unsigned OrigSize = OrigX.getScalarValueSizeInBits();
22148         if (OrigSize < OpSize) {
22149           if (OrigSize == 128) {
22150             TmpCastVT = NeedsAVX512FCast ? MVT::v4i32 : MVT::v16i8;
22151             X = OrigX;
22152             TmpZext = true;
22153           } else if (OrigSize == 256) {
22154             TmpCastVT = NeedsAVX512FCast ? MVT::v8i32 : MVT::v32i8;
22155             X = OrigX;
22156             TmpZext = true;
22157           }
22158         }
22159       }
22160       X = DAG.getBitcast(TmpCastVT, X);
22161       if (!NeedZExt && !TmpZext)
22162         return X;
22163       return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VecVT,
22164                          DAG.getConstant(0, DL, VecVT), X,
22165                          DAG.getVectorIdxConstant(0, DL));
22166     };
22167 
22168     SDValue Cmp;
22169     if (IsOrXorXorTreeCCZero) {
22170       // This is a bitwise-combined equality comparison of 2 pairs of vectors:
22171       // setcc i128 (or (xor A, B), (xor C, D)), 0, eq|ne
22172       // Use 2 vector equality compares and 'and' the results before doing a
22173       // MOVMSK.
22174       Cmp = emitOrXorXorTree(X, DL, DAG, VecVT, CmpVT, HasPT, ScalarToVector);
22175     } else {
22176       SDValue VecX = ScalarToVector(X);
22177       SDValue VecY = ScalarToVector(Y);
22178       if (VecVT != CmpVT) {
22179         Cmp = DAG.getSetCC(DL, CmpVT, VecX, VecY, ISD::SETNE);
22180       } else if (HasPT) {
22181         Cmp = DAG.getNode(ISD::XOR, DL, VecVT, VecX, VecY);
22182       } else {
22183         Cmp = DAG.getSetCC(DL, CmpVT, VecX, VecY, ISD::SETEQ);
22184       }
22185     }
22186     // AVX512 should emit a setcc that will lower to kortest.
22187     if (VecVT != CmpVT) {
22188       EVT KRegVT = CmpVT == MVT::v64i1   ? MVT::i64
22189                    : CmpVT == MVT::v32i1 ? MVT::i32
22190                                          : MVT::i16;
22191       return DAG.getSetCC(DL, VT, DAG.getBitcast(KRegVT, Cmp),
22192                           DAG.getConstant(0, DL, KRegVT), CC);
22193     }
22194     if (HasPT) {
22195       SDValue BCCmp =
22196           DAG.getBitcast(OpSize == 256 ? MVT::v4i64 : MVT::v2i64, Cmp);
22197       SDValue PT = DAG.getNode(X86ISD::PTEST, DL, MVT::i32, BCCmp, BCCmp);
22198       X86::CondCode X86CC = CC == ISD::SETEQ ? X86::COND_E : X86::COND_NE;
22199       SDValue X86SetCC = getSETCC(X86CC, PT, DL, DAG);
22200       return DAG.getNode(ISD::TRUNCATE, DL, VT, X86SetCC.getValue(0));
22201     }
22202     // If all bytes match (bitmask is 0x(FFFF)FFFF), that's equality.
22203     // setcc i128 X, Y, eq --> setcc (pmovmskb (pcmpeqb X, Y)), 0xFFFF, eq
22204     // setcc i128 X, Y, ne --> setcc (pmovmskb (pcmpeqb X, Y)), 0xFFFF, ne
22205     assert(Cmp.getValueType() == MVT::v16i8 &&
22206            "Non 128-bit vector on pre-SSE41 target");
22207     SDValue MovMsk = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Cmp);
22208     SDValue FFFFs = DAG.getConstant(0xFFFF, DL, MVT::i32);
22209     return DAG.getSetCC(DL, VT, MovMsk, FFFFs, CC);
22210   }
22211 
22212   return SDValue();
22213 }
22214 
22215 /// Helper for matching BINOP(EXTRACTELT(X,0),BINOP(EXTRACTELT(X,1),...))
22216 /// style scalarized (associative) reduction patterns. Partial reductions
22217 /// are supported when the pointer SrcMask is non-null.
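      /// e.g. with BinOp == ISD::OR,
      ///   (or (extractelt X, 0), (or (extractelt X, 1),
      ///       (or (extractelt X, 2), (extractelt X, 3))))
      /// with X : v4i32 yields SrcOps = { X } with all elements used.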
22218 /// TODO - move this to SelectionDAG?
22219 static bool matchScalarReduction(SDValue Op, ISD::NodeType BinOp,
22220                                  SmallVectorImpl<SDValue> &SrcOps,
22221                                  SmallVectorImpl<APInt> *SrcMask = nullptr) {
22222   SmallVector<SDValue, 8> Opnds;
22223   DenseMap<SDValue, APInt> SrcOpMap;
22224   EVT VT = MVT::Other;
22225 
22226   // Recognize a special case where a vector is cast into a wide integer to
22227   // test all 0s.
22228   assert(Op.getOpcode() == unsigned(BinOp) &&
22229          "Unexpected bit reduction opcode");
22230   Opnds.push_back(Op.getOperand(0));
22231   Opnds.push_back(Op.getOperand(1));
22232 
22233   for (unsigned Slot = 0, e = Opnds.size(); Slot < e; ++Slot) {
22234     SmallVectorImpl<SDValue>::const_iterator I = Opnds.begin() + Slot;
22235     // BFS traverse all BinOp operands.
22236     if (I->getOpcode() == unsigned(BinOp)) {
22237       Opnds.push_back(I->getOperand(0));
22238       Opnds.push_back(I->getOperand(1));
22239       // Re-evaluate the number of nodes to be traversed.
22240       e += 2; // 2 more nodes (LHS and RHS) are pushed.
22241       continue;
22242     }
22243 
22244     // Quit if not an EXTRACT_VECTOR_ELT.
22245     if (I->getOpcode() != ISD::EXTRACT_VECTOR_ELT)
22246       return false;
22247 
22248     // Quit if the index is not a constant.
22249     auto *Idx = dyn_cast<ConstantSDNode>(I->getOperand(1));
22250     if (!Idx)
22251       return false;
22252 
22253     SDValue Src = I->getOperand(0);
22254     DenseMap<SDValue, APInt>::iterator M = SrcOpMap.find(Src);
22255     if (M == SrcOpMap.end()) {
22256       VT = Src.getValueType();
22257       // Quit if not the same type.
22258       if (!SrcOpMap.empty() && VT != SrcOpMap.begin()->first.getValueType())
22259         return false;
22260       unsigned NumElts = VT.getVectorNumElements();
22261       APInt EltCount = APInt::getZero(NumElts);
22262       M = SrcOpMap.insert(std::make_pair(Src, EltCount)).first;
22263       SrcOps.push_back(Src);
22264     }
22265 
22266     // Quit if element already used.
22267     unsigned CIdx = Idx->getZExtValue();
22268     if (M->second[CIdx])
22269       return false;
22270     M->second.setBit(CIdx);
22271   }
22272 
22273   if (SrcMask) {
22274     // Collect the source partial masks.
22275     for (SDValue &SrcOp : SrcOps)
22276       SrcMask->push_back(SrcOpMap[SrcOp]);
22277   } else {
22278     // Quit if not all elements are used.
22279     for (const auto &I : SrcOpMap)
22280       if (!I.second.isAllOnes())
22281         return false;
22282   }
22283 
22284   return true;
22285 }
22286 
22287 // Helper function for comparing all bits of two vectors.
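// Rough shape of the lowerings produced below (illustrative): without SSE4.1
// the final fallback is
//   alleq(X, Y) --> cmp (movmsk (not (pcmpeq X, Y))), 0
// with SSE4.1 a PTEST of XOR(X, Y) is used, and with AVX512 a vXi1 setcc
// feeding KORTEST.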
22288 static SDValue LowerVectorAllEqual(const SDLoc &DL, SDValue LHS, SDValue RHS,
22289                                    ISD::CondCode CC, const APInt &OriginalMask,
22290                                    const X86Subtarget &Subtarget,
22291                                    SelectionDAG &DAG, X86::CondCode &X86CC) {
22292   EVT VT = LHS.getValueType();
22293   unsigned ScalarSize = VT.getScalarSizeInBits();
22294   if (OriginalMask.getBitWidth() != ScalarSize) {
22295     assert(ScalarSize == 1 && "Element Mask vs Vector bitwidth mismatch");
22296     return SDValue();
22297   }
22298 
22299   // Quit if not convertible to a legal scalar or 128/256-bit vector.
22300   if (!llvm::has_single_bit<uint32_t>(VT.getSizeInBits()))
22301     return SDValue();
22302 
22303   // FCMP may use ISD::SETNE when nnan - early out if we manage to get here.
22304   if (VT.isFloatingPoint())
22305     return SDValue();
22306 
22307   assert((CC == ISD::SETEQ || CC == ISD::SETNE) && "Unsupported ISD::CondCode");
22308   X86CC = (CC == ISD::SETEQ ? X86::COND_E : X86::COND_NE);
22309 
22310   APInt Mask = OriginalMask;
22311 
22312   auto MaskBits = [&](SDValue Src) {
22313     if (Mask.isAllOnes())
22314       return Src;
22315     EVT SrcVT = Src.getValueType();
22316     SDValue MaskValue = DAG.getConstant(Mask, DL, SrcVT);
22317     return DAG.getNode(ISD::AND, DL, SrcVT, Src, MaskValue);
22318   };
22319 
22320   // For sub-128-bit vector, cast to (legal) integer and compare with zero.
22321   if (VT.getSizeInBits() < 128) {
22322     EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), VT.getSizeInBits());
22323     if (!DAG.getTargetLoweringInfo().isTypeLegal(IntVT)) {
22324       if (IntVT != MVT::i64)
22325         return SDValue();
22326       auto SplitLHS = DAG.SplitScalar(DAG.getBitcast(IntVT, MaskBits(LHS)), DL,
22327                                       MVT::i32, MVT::i32);
22328       auto SplitRHS = DAG.SplitScalar(DAG.getBitcast(IntVT, MaskBits(RHS)), DL,
22329                                       MVT::i32, MVT::i32);
22330       SDValue Lo =
22331           DAG.getNode(ISD::XOR, DL, MVT::i32, SplitLHS.first, SplitRHS.first);
22332       SDValue Hi =
22333           DAG.getNode(ISD::XOR, DL, MVT::i32, SplitLHS.second, SplitRHS.second);
22334       return DAG.getNode(X86ISD::CMP, DL, MVT::i32,
22335                          DAG.getNode(ISD::OR, DL, MVT::i32, Lo, Hi),
22336                          DAG.getConstant(0, DL, MVT::i32));
22337     }
22338     return DAG.getNode(X86ISD::CMP, DL, MVT::i32,
22339                        DAG.getBitcast(IntVT, MaskBits(LHS)),
22340                        DAG.getBitcast(IntVT, MaskBits(RHS)));
22341   }
22342 
22343   // Without PTEST, a masked v2i64 or-reduction is not faster than
22344   // scalarization.
22345   bool UseKORTEST = Subtarget.useAVX512Regs();
22346   bool UsePTEST = Subtarget.hasSSE41();
22347   if (!UsePTEST && !Mask.isAllOnes() && ScalarSize > 32)
22348     return SDValue();
22349 
22350   // Split down to 128/256/512-bit vector.
22351   unsigned TestSize = UseKORTEST ? 512 : (Subtarget.hasAVX() ? 256 : 128);
22352 
22353   // If the input vector has vector elements wider than the target test size,
22354   // then cast to <X x i64> so it will safely split.
22355   if (ScalarSize > TestSize) {
22356     if (!Mask.isAllOnes())
22357       return SDValue();
22358     VT = EVT::getVectorVT(*DAG.getContext(), MVT::i64, VT.getSizeInBits() / 64);
22359     LHS = DAG.getBitcast(VT, LHS);
22360     RHS = DAG.getBitcast(VT, RHS);
22361     Mask = APInt::getAllOnes(64);
22362   }
22363 
22364   if (VT.getSizeInBits() > TestSize) {
22365     KnownBits KnownRHS = DAG.computeKnownBits(RHS);
22366     if (KnownRHS.isConstant() && KnownRHS.getConstant() == Mask) {
22367       // If ICMP(AND(LHS,MASK),MASK) - reduce using AND splits.
22368       while (VT.getSizeInBits() > TestSize) {
22369         auto Split = DAG.SplitVector(LHS, DL);
22370         VT = Split.first.getValueType();
22371         LHS = DAG.getNode(ISD::AND, DL, VT, Split.first, Split.second);
22372       }
22373       RHS = DAG.getAllOnesConstant(DL, VT);
22374     } else if (!UsePTEST && !KnownRHS.isZero()) {
22375       // MOVMSK Special Case:
22376       // ALLOF(CMPEQ(X,Y)) -> AND(CMPEQ(X[0],Y[0]),CMPEQ(X[1],Y[1]),....)
22377       MVT SVT = ScalarSize >= 32 ? MVT::i32 : MVT::i8;
22378       VT = MVT::getVectorVT(SVT, VT.getSizeInBits() / SVT.getSizeInBits());
22379       LHS = DAG.getBitcast(VT, MaskBits(LHS));
22380       RHS = DAG.getBitcast(VT, MaskBits(RHS));
22381       EVT BoolVT = VT.changeVectorElementType(MVT::i1);
22382       SDValue V = DAG.getSetCC(DL, BoolVT, LHS, RHS, ISD::SETEQ);
22383       V = DAG.getSExtOrTrunc(V, DL, VT);
22384       while (VT.getSizeInBits() > TestSize) {
22385         auto Split = DAG.SplitVector(V, DL);
22386         VT = Split.first.getValueType();
22387         V = DAG.getNode(ISD::AND, DL, VT, Split.first, Split.second);
22388       }
22389       V = DAG.getNOT(DL, V, VT);
22390       V = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, V);
22391       return DAG.getNode(X86ISD::CMP, DL, MVT::i32, V,
22392                          DAG.getConstant(0, DL, MVT::i32));
22393     } else {
22394       // Convert to a ICMP_EQ(XOR(LHS,RHS),0) pattern.
22395       SDValue V = DAG.getNode(ISD::XOR, DL, VT, LHS, RHS);
22396       while (VT.getSizeInBits() > TestSize) {
22397         auto Split = DAG.SplitVector(V, DL);
22398         VT = Split.first.getValueType();
22399         V = DAG.getNode(ISD::OR, DL, VT, Split.first, Split.second);
22400       }
22401       LHS = V;
22402       RHS = DAG.getConstant(0, DL, VT);
22403     }
22404   }
22405 
22406   if (UseKORTEST && VT.is512BitVector()) {
22407     MVT TestVT = MVT::getVectorVT(MVT::i32, VT.getSizeInBits() / 32);
22408     MVT BoolVT = TestVT.changeVectorElementType(MVT::i1);
22409     LHS = DAG.getBitcast(TestVT, MaskBits(LHS));
22410     RHS = DAG.getBitcast(TestVT, MaskBits(RHS));
22411     SDValue V = DAG.getSetCC(DL, BoolVT, LHS, RHS, ISD::SETNE);
22412     return DAG.getNode(X86ISD::KORTEST, DL, MVT::i32, V, V);
22413   }
22414 
22415   if (UsePTEST) {
22416     MVT TestVT = MVT::getVectorVT(MVT::i64, VT.getSizeInBits() / 64);
22417     LHS = DAG.getBitcast(TestVT, MaskBits(LHS));
22418     RHS = DAG.getBitcast(TestVT, MaskBits(RHS));
22419     SDValue V = DAG.getNode(ISD::XOR, DL, TestVT, LHS, RHS);
22420     return DAG.getNode(X86ISD::PTEST, DL, MVT::i32, V, V);
22421   }
22422 
22423   assert(VT.getSizeInBits() == 128 && "Failure to split to 128-bits");
22424   MVT MaskVT = ScalarSize >= 32 ? MVT::v4i32 : MVT::v16i8;
22425   LHS = DAG.getBitcast(MaskVT, MaskBits(LHS));
22426   RHS = DAG.getBitcast(MaskVT, MaskBits(RHS));
22427   SDValue V = DAG.getNode(X86ISD::PCMPEQ, DL, MaskVT, LHS, RHS);
22428   V = DAG.getNOT(DL, V, MaskVT);
22429   V = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, V);
22430   return DAG.getNode(X86ISD::CMP, DL, MVT::i32, V,
22431                      DAG.getConstant(0, DL, MVT::i32));
22432 }
22433 
22434 // Check whether an AND/OR'd reduction tree is PTEST-able, or if we can fall back
22435 // to CMP(MOVMSK(PCMPEQB(X,Y))).
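// Example of the anyof form handled here (illustrative):
//   icmp eq (or (extractelt X, 0), (extractelt X, 1)), 0
// is recognized as an OR-reduction of X and forwarded to LowerVectorAllEqual
// with an all-zeros RHS; the allof form uses AND and an all-ones RHS.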
22436 static SDValue MatchVectorAllEqualTest(SDValue LHS, SDValue RHS,
22437                                        ISD::CondCode CC, const SDLoc &DL,
22438                                        const X86Subtarget &Subtarget,
22439                                        SelectionDAG &DAG,
22440                                        X86::CondCode &X86CC) {
22441   assert((CC == ISD::SETEQ || CC == ISD::SETNE) && "Unsupported ISD::CondCode");
22442 
22443   bool CmpNull = isNullConstant(RHS);
22444   bool CmpAllOnes = isAllOnesConstant(RHS);
22445   if (!CmpNull && !CmpAllOnes)
22446     return SDValue();
22447 
22448   SDValue Op = LHS;
22449   if (!Subtarget.hasSSE2() || !Op->hasOneUse())
22450     return SDValue();
22451 
22452   // Check whether we're masking/truncating an OR-reduction result, in which
22453   // case track the masked bits.
22454   // TODO: Add CmpAllOnes support.
22455   APInt Mask = APInt::getAllOnes(Op.getScalarValueSizeInBits());
22456   if (CmpNull) {
22457     switch (Op.getOpcode()) {
22458     case ISD::TRUNCATE: {
22459       SDValue Src = Op.getOperand(0);
22460       Mask = APInt::getLowBitsSet(Src.getScalarValueSizeInBits(),
22461                                   Op.getScalarValueSizeInBits());
22462       Op = Src;
22463       break;
22464     }
22465     case ISD::AND: {
22466       if (auto *Cst = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
22467         Mask = Cst->getAPIntValue();
22468         Op = Op.getOperand(0);
22469       }
22470       break;
22471     }
22472     }
22473   }
22474 
22475   ISD::NodeType LogicOp = CmpNull ? ISD::OR : ISD::AND;
22476 
22477   // Match icmp(or(extract(X,0),extract(X,1)),0) anyof reduction patterns.
22478   // Match icmp(and(extract(X,0),extract(X,1)),-1) allof reduction patterns.
22479   SmallVector<SDValue, 8> VecIns;
22480   if (Op.getOpcode() == LogicOp && matchScalarReduction(Op, LogicOp, VecIns)) {
22481     EVT VT = VecIns[0].getValueType();
22482     assert(llvm::all_of(VecIns,
22483                         [VT](SDValue V) { return VT == V.getValueType(); }) &&
22484            "Reduction source vector mismatch");
22485 
22486     // Quit if not splittable to scalar/128/256/512-bit vector.
22487     if (!llvm::has_single_bit<uint32_t>(VT.getSizeInBits()))
22488       return SDValue();
22489 
22490     // If more than one full vector is evaluated, AND/OR them first before
22491     // PTEST.
22492     for (unsigned Slot = 0, e = VecIns.size(); e - Slot > 1;
22493          Slot += 2, e += 1) {
22494       // Each iteration will AND/OR 2 nodes and append the result until there is
22495       // only 1 node left, i.e. the final value of all vectors.
22496       SDValue LHS = VecIns[Slot];
22497       SDValue RHS = VecIns[Slot + 1];
22498       VecIns.push_back(DAG.getNode(LogicOp, DL, VT, LHS, RHS));
22499     }
22500 
22501     return LowerVectorAllEqual(DL, VecIns.back(),
22502                                CmpNull ? DAG.getConstant(0, DL, VT)
22503                                        : DAG.getAllOnesConstant(DL, VT),
22504                                CC, Mask, Subtarget, DAG, X86CC);
22505   }
22506 
22507   // Match icmp(reduce_or(X),0) anyof reduction patterns.
22508   // Match icmp(reduce_and(X),-1) allof reduction patterns.
22509   if (Op.getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
22510     ISD::NodeType BinOp;
22511     if (SDValue Match =
22512             DAG.matchBinOpReduction(Op.getNode(), BinOp, {LogicOp})) {
22513       EVT MatchVT = Match.getValueType();
22514       return LowerVectorAllEqual(DL, Match,
22515                                  CmpNull ? DAG.getConstant(0, DL, MatchVT)
22516                                          : DAG.getAllOnesConstant(DL, MatchVT),
22517                                  CC, Mask, Subtarget, DAG, X86CC);
22518     }
22519   }
22520 
22521   if (Mask.isAllOnes()) {
22522     assert(!Op.getValueType().isVector() &&
22523            "Illegal vector type for reduction pattern");
22524     SDValue Src = peekThroughBitcasts(Op);
22525     if (Src.getValueType().isFixedLengthVector() &&
22526         Src.getValueType().getScalarType() == MVT::i1) {
22527       // Match icmp(bitcast(icmp_ne(X,Y)),0) reduction patterns.
22528       // Match icmp(bitcast(icmp_eq(X,Y)),-1) reduction patterns.
22529       if (Src.getOpcode() == ISD::SETCC) {
22530         SDValue LHS = Src.getOperand(0);
22531         SDValue RHS = Src.getOperand(1);
22532         EVT LHSVT = LHS.getValueType();
22533         ISD::CondCode SrcCC = cast<CondCodeSDNode>(Src.getOperand(2))->get();
22534         if (SrcCC == (CmpNull ? ISD::SETNE : ISD::SETEQ) &&
22535             llvm::has_single_bit<uint32_t>(LHSVT.getSizeInBits())) {
22536           APInt SrcMask = APInt::getAllOnes(LHSVT.getScalarSizeInBits());
22537           return LowerVectorAllEqual(DL, LHS, RHS, CC, SrcMask, Subtarget, DAG,
22538                                      X86CC);
22539         }
22540       }
22541       // Match icmp(bitcast(vXi1 trunc(Y)),0) reduction patterns.
22542       // Match icmp(bitcast(vXi1 trunc(Y)),-1) reduction patterns.
22543       // Peek through truncation, mask the LSB and compare against zero/LSB.
22544       if (Src.getOpcode() == ISD::TRUNCATE) {
22545         SDValue Inner = Src.getOperand(0);
22546         EVT InnerVT = Inner.getValueType();
22547         if (llvm::has_single_bit<uint32_t>(InnerVT.getSizeInBits())) {
22548           unsigned BW = InnerVT.getScalarSizeInBits();
22549           APInt SrcMask = APInt(BW, 1);
22550           APInt Cmp = CmpNull ? APInt::getZero(BW) : SrcMask;
22551           return LowerVectorAllEqual(DL, Inner,
22552                                      DAG.getConstant(Cmp, DL, InnerVT), CC,
22553                                      SrcMask, Subtarget, DAG, X86CC);
22554         }
22555       }
22556     }
22557   }
22558 
22559   return SDValue();
22560 }
22561 
22562 /// Return true if \c Op has a use that doesn't just read flags.
22563 static bool hasNonFlagsUse(SDValue Op) {
22564   for (SDNode::use_iterator UI = Op->use_begin(), UE = Op->use_end(); UI != UE;
22565        ++UI) {
22566     SDNode *User = *UI;
22567     unsigned UOpNo = UI.getOperandNo();
22568     if (User->getOpcode() == ISD::TRUNCATE && User->hasOneUse()) {
22569       // Look past the truncate.
22570       UOpNo = User->use_begin().getOperandNo();
22571       User = *User->use_begin();
22572     }
22573 
22574     if (User->getOpcode() != ISD::BRCOND && User->getOpcode() != ISD::SETCC &&
22575         !(User->getOpcode() == ISD::SELECT && UOpNo == 0))
22576       return true;
22577   }
22578   return false;
22579 }
22580 
22581 // Transform to an x86-specific ALU node with flags if there is a chance of
22582 // using an RMW op or only the flags are used. Otherwise, leave
22583 // the node alone and emit a 'cmp' or 'test' instruction.
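// Illustrative sketch of the intent (not an exhaustive description): for a
// pattern like `store (add (load p), 1), p`, keeping the add as an
// EFLAGS-producing X86ISD::ADD lets isel fold it into a read-modify-write
// `add [p], 1`, while a lone SETCC/CopyToReg user only consumes the flags.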
22584 static bool isProfitableToUseFlagOp(SDValue Op) {
22585   for (SDNode *U : Op->uses())
22586     if (U->getOpcode() != ISD::CopyToReg &&
22587         U->getOpcode() != ISD::SETCC &&
22588         U->getOpcode() != ISD::STORE)
22589       return false;
22590 
22591   return true;
22592 }
22593 
22594 /// Emit nodes that will be selected as "test Op0,Op0", or something
22595 /// equivalent.
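// Example (illustrative): for (setcc (and X, Y), 0, ne) where the AND value
// also has non-flag uses, the code below reuses the flags of an X86ISD::AND
// node rather than emitting a separate TEST; when only the flags are needed,
// a plain TEST of the AND result is preferred instead.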
22596 static SDValue EmitTest(SDValue Op, unsigned X86CC, const SDLoc &dl,
22597                         SelectionDAG &DAG, const X86Subtarget &Subtarget) {
22598   // CF and OF aren't always set the way we want. Determine which
22599   // of these we need.
22600   bool NeedCF = false;
22601   bool NeedOF = false;
22602   switch (X86CC) {
22603   default: break;
22604   case X86::COND_A: case X86::COND_AE:
22605   case X86::COND_B: case X86::COND_BE:
22606     NeedCF = true;
22607     break;
22608   case X86::COND_G: case X86::COND_GE:
22609   case X86::COND_L: case X86::COND_LE:
22610   case X86::COND_O: case X86::COND_NO: {
22611     // Check if we really need to set the overflow flag. If the
22612     // NoSignedWrap flag is present on the arithmetic node, the
22613     // overflow flag is not actually needed.
22614     switch (Op->getOpcode()) {
22615     case ISD::ADD:
22616     case ISD::SUB:
22617     case ISD::MUL:
22618     case ISD::SHL:
22619       if (Op.getNode()->getFlags().hasNoSignedWrap())
22620         break;
22621       [[fallthrough]];
22622     default:
22623       NeedOF = true;
22624       break;
22625     }
22626     break;
22627   }
22628   }
22629   // See if we can use the EFLAGS value from the operand instead of
22630   // doing a separate TEST. TEST always sets OF and CF to 0, so unless
22631   // we prove that the arithmetic won't overflow, we can't use OF or CF.
22632   if (Op.getResNo() != 0 || NeedOF || NeedCF) {
22633     // Emit a CMP with 0, which is the TEST pattern.
22634     return DAG.getNode(X86ISD::CMP, dl, MVT::i32, Op,
22635                        DAG.getConstant(0, dl, Op.getValueType()));
22636   }
22637   unsigned Opcode = 0;
22638   unsigned NumOperands = 0;
22639 
22640   SDValue ArithOp = Op;
22641 
22642   // NOTICE: In the code below we use ArithOp to hold the arithmetic operation
22643   // which may be the result of a CAST.  We use the variable 'Op', which is the
22644   // non-casted variable when we check for possible users.
22645   switch (ArithOp.getOpcode()) {
22646   case ISD::AND:
22647     // If the primary 'and' result isn't used, don't bother using X86ISD::AND,
22648     // because a TEST instruction will be better.
22649     if (!hasNonFlagsUse(Op))
22650       break;
22651 
22652     [[fallthrough]];
22653   case ISD::ADD:
22654   case ISD::SUB:
22655   case ISD::OR:
22656   case ISD::XOR:
22657     if (!isProfitableToUseFlagOp(Op))
22658       break;
22659 
22660     // Otherwise use a regular EFLAGS-setting instruction.
22661     switch (ArithOp.getOpcode()) {
22662     // clang-format off
22663     default: llvm_unreachable("unexpected operator!");
22664     case ISD::ADD: Opcode = X86ISD::ADD; break;
22665     case ISD::SUB: Opcode = X86ISD::SUB; break;
22666     case ISD::XOR: Opcode = X86ISD::XOR; break;
22667     case ISD::AND: Opcode = X86ISD::AND; break;
22668     case ISD::OR:  Opcode = X86ISD::OR;  break;
22669     // clang-format on
22670     }
22671 
22672     NumOperands = 2;
22673     break;
22674   case X86ISD::ADD:
22675   case X86ISD::SUB:
22676   case X86ISD::OR:
22677   case X86ISD::XOR:
22678   case X86ISD::AND:
22679     return SDValue(Op.getNode(), 1);
22680   case ISD::SSUBO:
22681   case ISD::USUBO: {
22682     // USUBO/SSUBO will become an X86ISD::SUB and we can use its Z flag.
22683     SDVTList VTs = DAG.getVTList(Op.getValueType(), MVT::i32);
22684     return DAG.getNode(X86ISD::SUB, dl, VTs, Op->getOperand(0),
22685                        Op->getOperand(1)).getValue(1);
22686   }
22687   default:
22688     break;
22689   }
22690 
22691   if (Opcode == 0) {
22692     // Emit a CMP with 0, which is the TEST pattern.
22693     return DAG.getNode(X86ISD::CMP, dl, MVT::i32, Op,
22694                        DAG.getConstant(0, dl, Op.getValueType()));
22695   }
22696   SDVTList VTs = DAG.getVTList(Op.getValueType(), MVT::i32);
22697   SmallVector<SDValue, 4> Ops(Op->op_begin(), Op->op_begin() + NumOperands);
22698 
22699   SDValue New = DAG.getNode(Opcode, dl, VTs, Ops);
22700   DAG.ReplaceAllUsesOfValueWith(SDValue(Op.getNode(), 0), New);
22701   return SDValue(New.getNode(), 1);
22702 }
22703 
22704 /// Emit nodes that will be selected as "cmp Op0,Op1", or something
22705 /// equivalent.
22706 static SDValue EmitCmp(SDValue Op0, SDValue Op1, unsigned X86CC,
22707                        const SDLoc &dl, SelectionDAG &DAG,
22708                        const X86Subtarget &Subtarget) {
22709   if (isNullConstant(Op1))
22710     return EmitTest(Op0, X86CC, dl, DAG, Subtarget);
22711 
22712   EVT CmpVT = Op0.getValueType();
22713 
22714   assert((CmpVT == MVT::i8 || CmpVT == MVT::i16 ||
22715           CmpVT == MVT::i32 || CmpVT == MVT::i64) && "Unexpected VT!");
22716 
22717   // Only promote the compare up to i32 if it is a 16-bit operation with an
22718   // immediate. 16-bit immediates are to be avoided unless the target isn't
22719   // slowed down by length-changing prefixes, we're optimizing for code size,
22720   // or the comparison is with a folded load.
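  // For instance (illustrative), a 16-bit compare against an immediate such
  // as 0x1234 is widened here to a 32-bit compare, avoiding the
  // length-changing 0x66 operand-size prefix on the immediate form.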
22721   if (CmpVT == MVT::i16 && !Subtarget.hasFastImm16() &&
22722       !X86::mayFoldLoad(Op0, Subtarget) && !X86::mayFoldLoad(Op1, Subtarget) &&
22723       !DAG.getMachineFunction().getFunction().hasMinSize()) {
22724     auto *COp0 = dyn_cast<ConstantSDNode>(Op0);
22725     auto *COp1 = dyn_cast<ConstantSDNode>(Op1);
22726     // Don't do this if the immediate can fit in 8-bits.
22727     if ((COp0 && !COp0->getAPIntValue().isSignedIntN(8)) ||
22728         (COp1 && !COp1->getAPIntValue().isSignedIntN(8))) {
22729       unsigned ExtendOp =
22730           isX86CCSigned(X86CC) ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
22731       if (X86CC == X86::COND_E || X86CC == X86::COND_NE) {
22732         // For equality comparisons try to use SIGN_EXTEND if the input was
22733         // truncate from something with enough sign bits.
22734         if (Op0.getOpcode() == ISD::TRUNCATE) {
22735           if (DAG.ComputeMaxSignificantBits(Op0.getOperand(0)) <= 16)
22736             ExtendOp = ISD::SIGN_EXTEND;
22737         } else if (Op1.getOpcode() == ISD::TRUNCATE) {
22738           if (DAG.ComputeMaxSignificantBits(Op1.getOperand(0)) <= 16)
22739             ExtendOp = ISD::SIGN_EXTEND;
22740         }
22741       }
22742 
22743       CmpVT = MVT::i32;
22744       Op0 = DAG.getNode(ExtendOp, dl, CmpVT, Op0);
22745       Op1 = DAG.getNode(ExtendOp, dl, CmpVT, Op1);
22746     }
22747   }
22748 
22749   // Try to shrink i64 compares if the input has enough zero bits.
22750   // TODO: Add sign-bits equivalent for isX86CCSigned(X86CC)?
22751   if (CmpVT == MVT::i64 && !isX86CCSigned(X86CC) &&
22752       Op0.hasOneUse() && // Hacky way to not break CSE opportunities with sub.
22753       DAG.MaskedValueIsZero(Op1, APInt::getHighBitsSet(64, 32)) &&
22754       DAG.MaskedValueIsZero(Op0, APInt::getHighBitsSet(64, 32))) {
22755     CmpVT = MVT::i32;
22756     Op0 = DAG.getNode(ISD::TRUNCATE, dl, CmpVT, Op0);
22757     Op1 = DAG.getNode(ISD::TRUNCATE, dl, CmpVT, Op1);
22758   }
22759 
22760   // 0-x == y --> x+y == 0
22761   // 0-x != y --> x+y != 0
22762   if (Op0.getOpcode() == ISD::SUB && isNullConstant(Op0.getOperand(0)) &&
22763       Op0.hasOneUse() && (X86CC == X86::COND_E || X86CC == X86::COND_NE)) {
22764     SDVTList VTs = DAG.getVTList(CmpVT, MVT::i32);
22765     SDValue Add = DAG.getNode(X86ISD::ADD, dl, VTs, Op0.getOperand(1), Op1);
22766     return Add.getValue(1);
22767   }
22768 
22769   // x == 0-y --> x+y == 0
22770   // x != 0-y --> x+y != 0
22771   if (Op1.getOpcode() == ISD::SUB && isNullConstant(Op1.getOperand(0)) &&
22772       Op1.hasOneUse() && (X86CC == X86::COND_E || X86CC == X86::COND_NE)) {
22773     SDVTList VTs = DAG.getVTList(CmpVT, MVT::i32);
22774     SDValue Add = DAG.getNode(X86ISD::ADD, dl, VTs, Op0, Op1.getOperand(1));
22775     return Add.getValue(1);
22776   }
22777 
22778   // Use SUB instead of CMP to enable CSE between SUB and CMP.
22779   SDVTList VTs = DAG.getVTList(CmpVT, MVT::i32);
22780   SDValue Sub = DAG.getNode(X86ISD::SUB, dl, VTs, Op0, Op1);
22781   return Sub.getValue(1);
22782 }
22783 
22784 bool X86TargetLowering::isXAndYEqZeroPreferableToXAndYEqY(ISD::CondCode Cond,
22785                                                           EVT VT) const {
22786   return !VT.isVector() || Cond != ISD::CondCode::SETEQ;
22787 }
22788 
22789 bool X86TargetLowering::optimizeFMulOrFDivAsShiftAddBitcast(
22790     SDNode *N, SDValue, SDValue IntPow2) const {
22791   if (N->getOpcode() == ISD::FDIV)
22792     return true;
22793 
22794   EVT FPVT = N->getValueType(0);
22795   EVT IntVT = IntPow2.getValueType();
22796 
22797   // This indicates a non-free bitcast.
22798   // TODO: This is probably overly conservative as we will need to scale the
22799   // integer vector anyways for the int->fp cast.
22800   if (FPVT.isVector() &&
22801       FPVT.getScalarSizeInBits() != IntVT.getScalarSizeInBits())
22802     return false;
22803 
22804   return true;
22805 }
22806 
22807 /// Check if replacement of SQRT with RSQRT should be disabled.
22808 bool X86TargetLowering::isFsqrtCheap(SDValue Op, SelectionDAG &DAG) const {
22809   EVT VT = Op.getValueType();
22810 
22811   // We don't need to replace SQRT with RSQRT for half type.
22812   if (VT.getScalarType() == MVT::f16)
22813     return true;
22814 
22815   // We never want to use both SQRT and RSQRT instructions for the same input.
22816   if (DAG.doesNodeExist(X86ISD::FRSQRT, DAG.getVTList(VT), Op))
22817     return false;
22818 
22819   if (VT.isVector())
22820     return Subtarget.hasFastVectorFSQRT();
22821   return Subtarget.hasFastScalarFSQRT();
22822 }
22823 
22824 /// The minimum architected relative accuracy is 2^-12. We need one
22825 /// Newton-Raphson step to have a good float result (24 bits of precision).
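// One Newton-Raphson refinement of an rsqrt estimate E for input A is
// conventionally E' = E * (1.5 - 0.5 * A * E * E); the refinement steps are
// built by the generic DAG combiner around the estimate returned here.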
22826 SDValue X86TargetLowering::getSqrtEstimate(SDValue Op,
22827                                            SelectionDAG &DAG, int Enabled,
22828                                            int &RefinementSteps,
22829                                            bool &UseOneConstNR,
22830                                            bool Reciprocal) const {
22831   SDLoc DL(Op);
22832   EVT VT = Op.getValueType();
22833 
22834   // SSE1 has rsqrtss and rsqrtps. AVX adds a 256-bit variant for rsqrtps.
22835   // It is likely not profitable to do this for f64 because a double-precision
22836   // rsqrt estimate with refinement on x86 prior to FMA requires at least 16
22837   // instructions: convert to single, rsqrtss, convert back to double, refine
22838   // (3 steps = at least 13 insts). If an 'rsqrtsd' variant was added to the ISA
22839   // along with FMA, this could be a throughput win.
22840   // TODO: SQRT requires SSE2 to prevent the introduction of an illegal v4i32
22841   // after legalize types.
22842   if ((VT == MVT::f32 && Subtarget.hasSSE1()) ||
22843       (VT == MVT::v4f32 && Subtarget.hasSSE1() && Reciprocal) ||
22844       (VT == MVT::v4f32 && Subtarget.hasSSE2() && !Reciprocal) ||
22845       (VT == MVT::v8f32 && Subtarget.hasAVX()) ||
22846       (VT == MVT::v16f32 && Subtarget.useAVX512Regs())) {
22847     if (RefinementSteps == ReciprocalEstimate::Unspecified)
22848       RefinementSteps = 1;
22849 
22850     UseOneConstNR = false;
22851     // There is no FSQRT for 512-bits, but there is RSQRT14.
22852     // There is no FRSQRT for 512-bits, but there is RSQRT14.
22853     SDValue Estimate = DAG.getNode(Opcode, DL, VT, Op);
22854     if (RefinementSteps == 0 && !Reciprocal)
22855       Estimate = DAG.getNode(ISD::FMUL, DL, VT, Op, Estimate);
22856     return Estimate;
22857   }
22858 
22859   if (VT.getScalarType() == MVT::f16 && isTypeLegal(VT) &&
22860       Subtarget.hasFP16()) {
22861     assert(Reciprocal && "Don't replace SQRT with RSQRT for half type");
22862     if (RefinementSteps == ReciprocalEstimate::Unspecified)
22863       RefinementSteps = 0;
22864 
22865     if (VT == MVT::f16) {
22866       SDValue Zero = DAG.getIntPtrConstant(0, DL);
22867       SDValue Undef = DAG.getUNDEF(MVT::v8f16);
22868       Op = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v8f16, Op);
22869       Op = DAG.getNode(X86ISD::RSQRT14S, DL, MVT::v8f16, Undef, Op);
22870       return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f16, Op, Zero);
22871     }
22872 
22873     return DAG.getNode(X86ISD::RSQRT14, DL, VT, Op);
22874   }
22875   return SDValue();
22876 }
22877 
22878 /// The minimum architected relative accuracy is 2^-12. We need one
22879 /// Newton-Raphson step to have a good float result (24 bits of precision).
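// One Newton-Raphson refinement of a reciprocal estimate E for input A is
// conventionally E' = E * (2.0 - A * E); as with the rsqrt case, the
// refinement steps are built around the estimate returned here.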
22880 SDValue X86TargetLowering::getRecipEstimate(SDValue Op, SelectionDAG &DAG,
22881                                             int Enabled,
22882                                             int &RefinementSteps) const {
22883   SDLoc DL(Op);
22884   EVT VT = Op.getValueType();
22885 
22886   // SSE1 has rcpss and rcpps. AVX adds a 256-bit variant for rcpps.
22887   // It is likely not profitable to do this for f64 because a double-precision
22888   // reciprocal estimate with refinement on x86 prior to FMA requires
22889   // 15 instructions: convert to single, rcpss, convert back to double, refine
22890   // (3 steps = 12 insts). If an 'rcpsd' variant was added to the ISA
22891   // along with FMA, this could be a throughput win.
22892 
22893   if ((VT == MVT::f32 && Subtarget.hasSSE1()) ||
22894       (VT == MVT::v4f32 && Subtarget.hasSSE1()) ||
22895       (VT == MVT::v8f32 && Subtarget.hasAVX()) ||
22896       (VT == MVT::v16f32 && Subtarget.useAVX512Regs())) {
22897     // Enable estimate codegen with 1 refinement step for vector division.
22898     // Scalar division estimates are disabled because they break too much
22899     // real-world code. These defaults are intended to match GCC behavior.
22900     if (VT == MVT::f32 && Enabled == ReciprocalEstimate::Unspecified)
22901       return SDValue();
22902 
22903     if (RefinementSteps == ReciprocalEstimate::Unspecified)
22904       RefinementSteps = 1;
22905 
22906     // There is no FSQRT for 512-bits, but there is RCP14.
22907     // There is no FRCP for 512-bits, but there is RCP14.
22908     return DAG.getNode(Opcode, DL, VT, Op);
22909   }
22910 
22911   if (VT.getScalarType() == MVT::f16 && isTypeLegal(VT) &&
22912       Subtarget.hasFP16()) {
22913     if (RefinementSteps == ReciprocalEstimate::Unspecified)
22914       RefinementSteps = 0;
22915 
22916     if (VT == MVT::f16) {
22917       SDValue Zero = DAG.getIntPtrConstant(0, DL);
22918       SDValue Undef = DAG.getUNDEF(MVT::v8f16);
22919       Op = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v8f16, Op);
22920       Op = DAG.getNode(X86ISD::RCP14S, DL, MVT::v8f16, Undef, Op);
22921       return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f16, Op, Zero);
22922     }
22923 
22924     return DAG.getNode(X86ISD::RCP14, DL, VT, Op);
22925   }
22926   return SDValue();
22927 }
22928 
22929 /// If we have at least two divisions that use the same divisor, convert to
22930 /// multiplication by a reciprocal. This may need to be adjusted for a given
22931 /// CPU if a division's cost is not at least twice the cost of a multiplication.
22932 /// This is because we still need one division to calculate the reciprocal and
22933 /// then we need two multiplies by that reciprocal as replacements for the
22934 /// original divisions.
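// Example (illustrative): given a/d and b/d with a shared divisor d, the
// combine forms r = 1.0/d and rewrites the divisions as a*r and b*r, so the
// threshold of 2 below asks for at least two divisions by the same divisor.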
22935 unsigned X86TargetLowering::combineRepeatedFPDivisors() const {
22936   return 2;
22937 }
22938 
22939 SDValue
22940 X86TargetLowering::BuildSDIVPow2(SDNode *N, const APInt &Divisor,
22941                                  SelectionDAG &DAG,
22942                                  SmallVectorImpl<SDNode *> &Created) const {
22943   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
22944   if (isIntDivCheap(N->getValueType(0), Attr))
22945     return SDValue(N,0); // Lower SDIV as SDIV
22946 
22947   assert((Divisor.isPowerOf2() || Divisor.isNegatedPowerOf2()) &&
22948          "Unexpected divisor!");
22949 
22950   // Only perform this transform if CMOV is supported; otherwise the select
22951   // below will become a branch.
22952   if (!Subtarget.canUseCMOV())
22953     return SDValue();
22954 
22955   // fold (sdiv X, pow2)
22956   EVT VT = N->getValueType(0);
22957   // FIXME: Support i8.
22958   if (VT != MVT::i16 && VT != MVT::i32 &&
22959       !(Subtarget.is64Bit() && VT == MVT::i64))
22960     return SDValue();
22961 
22962   // If the divisor is 2 or -2, the default expansion is better.
22963   if (Divisor == 2 ||
22964       Divisor == APInt(Divisor.getBitWidth(), -2, /*isSigned*/ true))
22965     return SDValue();
22966 
22967   return TargetLowering::buildSDIVPow2WithCMov(N, Divisor, DAG, Created);
22968 }
22969 
22970 /// Result of 'and' is compared against zero. Change to a BT node if possible.
22971 /// Returns the BT node and the condition code needed to use it.
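// Example (illustrative):
//   (and X, (shl 1, N)) == 0  -->  BT X, N  with X86::COND_AE (CF == 0)
//   (and X, (shl 1, N)) != 0  -->  BT X, N  with X86::COND_B  (CF == 1)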
22972 static SDValue LowerAndToBT(SDValue And, ISD::CondCode CC, const SDLoc &dl,
22973                             SelectionDAG &DAG, X86::CondCode &X86CC) {
22974   assert(And.getOpcode() == ISD::AND && "Expected AND node!");
22975   SDValue Op0 = And.getOperand(0);
22976   SDValue Op1 = And.getOperand(1);
22977   if (Op0.getOpcode() == ISD::TRUNCATE)
22978     Op0 = Op0.getOperand(0);
22979   if (Op1.getOpcode() == ISD::TRUNCATE)
22980     Op1 = Op1.getOperand(0);
22981 
22982   SDValue Src, BitNo;
22983   if (Op1.getOpcode() == ISD::SHL)
22984     std::swap(Op0, Op1);
22985   if (Op0.getOpcode() == ISD::SHL) {
22986     if (isOneConstant(Op0.getOperand(0))) {
22987       // If we looked past a truncate, check that it's only truncating away
22988       // known zeros.
22989       unsigned BitWidth = Op0.getValueSizeInBits();
22990       unsigned AndBitWidth = And.getValueSizeInBits();
22991       if (BitWidth > AndBitWidth) {
22992         KnownBits Known = DAG.computeKnownBits(Op0);
22993         if (Known.countMinLeadingZeros() < BitWidth - AndBitWidth)
22994           return SDValue();
22995       }
22996       Src = Op1;
22997       BitNo = Op0.getOperand(1);
22998     }
22999   } else if (Op1.getOpcode() == ISD::Constant) {
23000     ConstantSDNode *AndRHS = cast<ConstantSDNode>(Op1);
23001     uint64_t AndRHSVal = AndRHS->getZExtValue();
23002     SDValue AndLHS = Op0;
23003 
23004     if (AndRHSVal == 1 && AndLHS.getOpcode() == ISD::SRL) {
23005       Src = AndLHS.getOperand(0);
23006       BitNo = AndLHS.getOperand(1);
23007     } else {
23008       // Use BT if the immediate can't be encoded in a TEST instruction or we
23009       // are optimizing for size and the immediate won't fit in a byte.
23010       bool OptForSize = DAG.shouldOptForSize();
23011       if ((!isUInt<32>(AndRHSVal) || (OptForSize && !isUInt<8>(AndRHSVal))) &&
23012           isPowerOf2_64(AndRHSVal)) {
23013         Src = AndLHS;
23014         BitNo = DAG.getConstant(Log2_64_Ceil(AndRHSVal), dl,
23015                                 Src.getValueType());
23016       }
23017     }
23018   }
23019 
23020   // No patterns found, give up.
23021   if (!Src.getNode())
23022     return SDValue();
23023 
23024   // Remove any bit flip.
23025   if (isBitwiseNot(Src)) {
23026     Src = Src.getOperand(0);
23027     CC = CC == ISD::SETEQ ? ISD::SETNE : ISD::SETEQ;
23028   }
23029 
23030   // Attempt to create the X86ISD::BT node.
23031   if (SDValue BT = getBT(Src, BitNo, dl, DAG)) {
23032     X86CC = CC == ISD::SETEQ ? X86::COND_AE : X86::COND_B;
23033     return BT;
23034   }
23035 
23036   return SDValue();
23037 }
23038 
23039 // Check if pre-AVX condcode can be performed by a single FCMP op.
23040 static bool cheapX86FSETCC_SSE(ISD::CondCode SetCCOpcode) {
23041   return (SetCCOpcode != ISD::SETONE) && (SetCCOpcode != ISD::SETUEQ);
23042 }
23043 
23044 /// Turns an ISD::CondCode into a value suitable for SSE floating-point mask
23045 /// CMPs.
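// Example (illustrative): ISD::SETGT swaps the operands and maps to immediate
// 1 (LT), since SSE encodes only the "less-than" family directly; values of 8
// and above (e.g. 8 for SETUEQ, 12 for SETONE) require the AVX VCMP encoding.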
23046 static unsigned translateX86FSETCC(ISD::CondCode SetCCOpcode, SDValue &Op0,
23047                                    SDValue &Op1, bool &IsAlwaysSignaling) {
23048   unsigned SSECC;
23049   bool Swap = false;
23050 
23051   // SSE Condition code mapping:
23052   //  0 - EQ
23053   //  1 - LT
23054   //  2 - LE
23055   //  3 - UNORD
23056   //  4 - NEQ
23057   //  5 - NLT
23058   //  6 - NLE
23059   //  7 - ORD
23060   switch (SetCCOpcode) {
23061   // clang-format off
23062   default: llvm_unreachable("Unexpected SETCC condition");
23063   case ISD::SETOEQ:
23064   case ISD::SETEQ:  SSECC = 0; break;
23065   case ISD::SETOGT:
23066   case ISD::SETGT:  Swap = true; [[fallthrough]];
23067   case ISD::SETLT:
23068   case ISD::SETOLT: SSECC = 1; break;
23069   case ISD::SETOGE:
23070   case ISD::SETGE:  Swap = true; [[fallthrough]];
23071   case ISD::SETLE:
23072   case ISD::SETOLE: SSECC = 2; break;
23073   case ISD::SETUO:  SSECC = 3; break;
23074   case ISD::SETUNE:
23075   case ISD::SETNE:  SSECC = 4; break;
23076   case ISD::SETULE: Swap = true; [[fallthrough]];
23077   case ISD::SETUGE: SSECC = 5; break;
23078   case ISD::SETULT: Swap = true; [[fallthrough]];
23079   case ISD::SETUGT: SSECC = 6; break;
23080   case ISD::SETO:   SSECC = 7; break;
23081   case ISD::SETUEQ: SSECC = 8; break;
23082   case ISD::SETONE: SSECC = 12; break;
23083   // clang-format on
23084   }
23085   if (Swap)
23086     std::swap(Op0, Op1);
23087 
23088   switch (SetCCOpcode) {
23089   default:
23090     IsAlwaysSignaling = true;
23091     break;
23092   case ISD::SETEQ:
23093   case ISD::SETOEQ:
23094   case ISD::SETUEQ:
23095   case ISD::SETNE:
23096   case ISD::SETONE:
23097   case ISD::SETUNE:
23098   case ISD::SETO:
23099   case ISD::SETUO:
23100     IsAlwaysSignaling = false;
23101     break;
23102   }
23103 
23104   return SSECC;
23105 }
23106 
23107 /// Break a 256-bit integer VSETCC into two new 128-bit ones and then
23108 /// concatenate the results back.
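// Example (illustrative): a v8i32 SETGT without AVX2 becomes
//   concat_vectors (setcc v4i32 LHS.lo, RHS.lo, setgt),
//                  (setcc v4i32 LHS.hi, RHS.hi, setgt)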
23109 static SDValue splitIntVSETCC(EVT VT, SDValue LHS, SDValue RHS,
23110                               ISD::CondCode Cond, SelectionDAG &DAG,
23111                               const SDLoc &dl) {
23112   assert(VT.isInteger() && VT == LHS.getValueType() &&
23113          VT == RHS.getValueType() && "Unsupported VTs!");
23114 
23115   SDValue CC = DAG.getCondCode(Cond);
23116 
23117   // Extract the LHS Lo/Hi vectors
23118   SDValue LHS1, LHS2;
23119   std::tie(LHS1, LHS2) = splitVector(LHS, DAG, dl);
23120 
23121   // Extract the RHS Lo/Hi vectors
23122   SDValue RHS1, RHS2;
23123   std::tie(RHS1, RHS2) = splitVector(RHS, DAG, dl);
23124 
23125   // Issue the operation on the smaller types and concatenate the result back
23126   EVT LoVT, HiVT;
23127   std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(VT);
23128   return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT,
23129                      DAG.getNode(ISD::SETCC, dl, LoVT, LHS1, RHS1, CC),
23130                      DAG.getNode(ISD::SETCC, dl, HiVT, LHS2, RHS2, CC));
23131 }
23132 
23133 static SDValue LowerIntVSETCC_AVX512(SDValue Op, const SDLoc &dl,
23134                                      SelectionDAG &DAG) {
23135   SDValue Op0 = Op.getOperand(0);
23136   SDValue Op1 = Op.getOperand(1);
23137   SDValue CC = Op.getOperand(2);
23138   MVT VT = Op.getSimpleValueType();
23139   assert(VT.getVectorElementType() == MVT::i1 &&
23140          "Cannot set masked compare for this operation");
23141 
23142   ISD::CondCode SetCCOpcode = cast<CondCodeSDNode>(CC)->get();
23143 
23144   // Prefer SETGT over SETLT.
23145   if (SetCCOpcode == ISD::SETLT) {
23146     SetCCOpcode = ISD::getSetCCSwappedOperands(SetCCOpcode);
23147     std::swap(Op0, Op1);
23148   }
23149 
23150   return DAG.getSetCC(dl, VT, Op0, Op1, SetCCOpcode);
23151 }
23152 
23153 /// Given a buildvector constant, return a new vector constant with each element
23154 /// incremented or decremented. If incrementing or decrementing would result in
23155 /// unsigned overflow or underflow or this is not a simple vector constant,
23156 /// return an empty value.
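// Example (illustrative): incrementing <4 x i32> <1, 2, 3, 4> yields
// <2, 3, 4, 5>, while incrementing any vector containing UINT32_MAX (or,
// with NSW set, INT32_MAX) returns an empty SDValue instead.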
23157 static SDValue incDecVectorConstant(SDValue V, SelectionDAG &DAG, bool IsInc,
23158                                     bool NSW) {
23159   auto *BV = dyn_cast<BuildVectorSDNode>(V.getNode());
23160   if (!BV || !V.getValueType().isSimple())
23161     return SDValue();
23162 
23163   MVT VT = V.getSimpleValueType();
23164   MVT EltVT = VT.getVectorElementType();
23165   unsigned NumElts = VT.getVectorNumElements();
23166   SmallVector<SDValue, 8> NewVecC;
23167   SDLoc DL(V);
23168   for (unsigned i = 0; i < NumElts; ++i) {
23169     auto *Elt = dyn_cast<ConstantSDNode>(BV->getOperand(i));
23170     if (!Elt || Elt->isOpaque() || Elt->getSimpleValueType(0) != EltVT)
23171       return SDValue();
23172 
23173     // Avoid overflow/underflow.
23174     const APInt &EltC = Elt->getAPIntValue();
23175     if ((IsInc && EltC.isMaxValue()) || (!IsInc && EltC.isZero()))
23176       return SDValue();
23177     if (NSW && ((IsInc && EltC.isMaxSignedValue()) ||
23178                 (!IsInc && EltC.isMinSignedValue())))
23179       return SDValue();
23180 
23181     NewVecC.push_back(DAG.getConstant(EltC + (IsInc ? 1 : -1), DL, EltVT));
23182   }
23183 
23184   return DAG.getBuildVector(VT, DL, NewVecC);
23185 }
23186 
23187 /// As another special case, use PSUBUS[BW] when it's profitable. E.g. for
23188 /// Op0 u<= Op1:
23189 ///   t = psubus Op0, Op1
23190 ///   pcmpeq t, <0..0>
23191 static SDValue LowerVSETCCWithSUBUS(SDValue Op0, SDValue Op1, MVT VT,
23192                                     ISD::CondCode Cond, const SDLoc &dl,
23193                                     const X86Subtarget &Subtarget,
23194                                     SelectionDAG &DAG) {
23195   if (!Subtarget.hasSSE2())
23196     return SDValue();
23197 
23198   MVT VET = VT.getVectorElementType();
23199   if (VET != MVT::i8 && VET != MVT::i16)
23200     return SDValue();
23201 
23202   switch (Cond) {
23203   default:
23204     return SDValue();
23205   case ISD::SETULT: {
23206     // If the comparison is against a constant we can turn this into a
23207     // setule.  With psubus, setule does not require a swap.  This is
23208     // beneficial because the constant in the register is no longer
23209     // clobbered as the destination, so it can be hoisted out of a loop.
23210     // Only do this pre-AVX since vpcmp* is no longer destructive.
23211     if (Subtarget.hasAVX())
23212       return SDValue();
23213     SDValue ULEOp1 =
23214         incDecVectorConstant(Op1, DAG, /*IsInc*/ false, /*NSW*/ false);
23215     if (!ULEOp1)
23216       return SDValue();
23217     Op1 = ULEOp1;
23218     break;
23219   }
23220   case ISD::SETUGT: {
23221     // If the comparison is against a constant, we can turn this into a setuge.
23222     // This is beneficial because materializing a constant 0 for the PCMPEQ is
23223     // probably cheaper than XOR+PCMPGT using 2 different vector constants:
23224     // cmpgt (xor X, SignMaskC) CmpC --> cmpeq (usubsat (CmpC+1), X), 0
23225     SDValue UGEOp1 =
23226         incDecVectorConstant(Op1, DAG, /*IsInc*/ true, /*NSW*/ false);
23227     if (!UGEOp1)
23228       return SDValue();
23229     Op1 = Op0;
23230     Op0 = UGEOp1;
23231     break;
23232   }
23233   // Psubus is better than flip-sign because it requires no inversion.
23234   case ISD::SETUGE:
23235     std::swap(Op0, Op1);
23236     break;
23237   case ISD::SETULE:
23238     break;
23239   }
23240 
23241   SDValue Result = DAG.getNode(ISD::USUBSAT, dl, VT, Op0, Op1);
23242   return DAG.getNode(X86ISD::PCMPEQ, dl, VT, Result,
23243                      DAG.getConstant(0, dl, VT));
23244 }
23245 
23246 static SDValue LowerVSETCC(SDValue Op, const X86Subtarget &Subtarget,
23247                            SelectionDAG &DAG) {
23248   bool IsStrict = Op.getOpcode() == ISD::STRICT_FSETCC ||
23249                   Op.getOpcode() == ISD::STRICT_FSETCCS;
23250   SDValue Op0 = Op.getOperand(IsStrict ? 1 : 0);
23251   SDValue Op1 = Op.getOperand(IsStrict ? 2 : 1);
23252   SDValue CC = Op.getOperand(IsStrict ? 3 : 2);
23253   MVT VT = Op->getSimpleValueType(0);
23254   ISD::CondCode Cond = cast<CondCodeSDNode>(CC)->get();
23255   bool isFP = Op1.getSimpleValueType().isFloatingPoint();
23256   SDLoc dl(Op);
23257 
23258   if (isFP) {
23259     MVT EltVT = Op0.getSimpleValueType().getVectorElementType();
23260     assert(EltVT == MVT::f16 || EltVT == MVT::f32 || EltVT == MVT::f64);
23261     if (isSoftF16(EltVT, Subtarget))
23262       return SDValue();
23263 
23264     bool IsSignaling = Op.getOpcode() == ISD::STRICT_FSETCCS;
23265     SDValue Chain = IsStrict ? Op.getOperand(0) : SDValue();
23266 
23267     // If we have a strict compare with a vXi1 result and the input is 128/256
23268     // bits we can't use a masked compare unless we have VLX. If we use a wider
23269     // compare like we do for non-strict, we might trigger spurious exceptions
23270     // from the upper elements. Instead emit an AVX compare and convert to mask.
23271     unsigned Opc;
23272     if (Subtarget.hasAVX512() && VT.getVectorElementType() == MVT::i1 &&
23273         (!IsStrict || Subtarget.hasVLX() ||
23274          Op0.getSimpleValueType().is512BitVector())) {
23275 #ifndef NDEBUG
23276       unsigned Num = VT.getVectorNumElements();
23277       assert(Num <= 16 || (Num == 32 && EltVT == MVT::f16));
23278 #endif
23279       Opc = IsStrict ? X86ISD::STRICT_CMPM : X86ISD::CMPM;
23280     } else {
23281       Opc = IsStrict ? X86ISD::STRICT_CMPP : X86ISD::CMPP;
23282       // The SSE/AVX packed FP comparison nodes are defined with a
23283       // floating-point vector result that matches the operand type. This allows
23284       // them to work with an SSE1 target (integer vector types are not legal).
23285       VT = Op0.getSimpleValueType();
23286     }
23287 
23288     SDValue Cmp;
23289     bool IsAlwaysSignaling;
23290     unsigned SSECC = translateX86FSETCC(Cond, Op0, Op1, IsAlwaysSignaling);
23291     if (!Subtarget.hasAVX()) {
23292       // TODO: We could use following steps to handle a quiet compare with
23293       // signaling encodings.
23294       // 1. Get ordered masks from a quiet ISD::SETO
23295       // 2. Use the masks to mask potential unordered elements in operand A, B
23296       // 3. Get the compare results of masked A, B
23297       // 4. Calculate the final result using the mask and the result from 3
23298       // But currently, we just fall back to scalar operations.
23299       if (IsStrict && IsAlwaysSignaling && !IsSignaling)
23300         return SDValue();
23301 
23302       // Insert an extra signaling instruction to raise exception.
23303       if (IsStrict && !IsAlwaysSignaling && IsSignaling) {
23304         SDValue SignalCmp = DAG.getNode(
23305             Opc, dl, {VT, MVT::Other},
23306             {Chain, Op0, Op1, DAG.getTargetConstant(1, dl, MVT::i8)}); // LT_OS
23307         // FIXME: It seems we need to update the flags of all new strict nodes.
23308         // Otherwise, mayRaiseFPException in MI will return false due to
23309         // NoFPExcept = false by default. However, I didn't find it in other
23310         // patches.
23311         SignalCmp->setFlags(Op->getFlags());
23312         Chain = SignalCmp.getValue(1);
23313       }
23314 
23315       // In the two cases not handled by SSE compare predicates (SETUEQ/SETONE),
23316       // emit two comparisons and a logic op to tie them together.
23317       if (!cheapX86FSETCC_SSE(Cond)) {
23318         // LLVM predicate is SETUEQ or SETONE.
23319         unsigned CC0, CC1;
23320         unsigned CombineOpc;
23321         if (Cond == ISD::SETUEQ) {
23322           CC0 = 3; // UNORD
23323           CC1 = 0; // EQ
23324           CombineOpc = X86ISD::FOR;
23325         } else {
23326           assert(Cond == ISD::SETONE);
23327           CC0 = 7; // ORD
23328           CC1 = 4; // NEQ
23329           CombineOpc = X86ISD::FAND;
23330         }
23331 
23332         SDValue Cmp0, Cmp1;
23333         if (IsStrict) {
23334           Cmp0 = DAG.getNode(
23335               Opc, dl, {VT, MVT::Other},
23336               {Chain, Op0, Op1, DAG.getTargetConstant(CC0, dl, MVT::i8)});
23337           Cmp1 = DAG.getNode(
23338               Opc, dl, {VT, MVT::Other},
23339               {Chain, Op0, Op1, DAG.getTargetConstant(CC1, dl, MVT::i8)});
23340           Chain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, Cmp0.getValue(1),
23341                               Cmp1.getValue(1));
23342         } else {
23343           Cmp0 = DAG.getNode(
23344               Opc, dl, VT, Op0, Op1, DAG.getTargetConstant(CC0, dl, MVT::i8));
23345           Cmp1 = DAG.getNode(
23346               Opc, dl, VT, Op0, Op1, DAG.getTargetConstant(CC1, dl, MVT::i8));
23347         }
23348         Cmp = DAG.getNode(CombineOpc, dl, VT, Cmp0, Cmp1);
23349       } else {
23350         if (IsStrict) {
23351           Cmp = DAG.getNode(
23352               Opc, dl, {VT, MVT::Other},
23353               {Chain, Op0, Op1, DAG.getTargetConstant(SSECC, dl, MVT::i8)});
23354           Chain = Cmp.getValue(1);
23355         } else
23356           Cmp = DAG.getNode(
23357               Opc, dl, VT, Op0, Op1, DAG.getTargetConstant(SSECC, dl, MVT::i8));
23358       }
23359     } else {
23360       // Handle all other FP comparisons here.
23361       if (IsStrict) {
23362         // Make a flip on already signaling CCs before setting bit 4 of AVX CC.
23363         SSECC |= (IsAlwaysSignaling ^ IsSignaling) << 4;
23364         Cmp = DAG.getNode(
23365             Opc, dl, {VT, MVT::Other},
23366             {Chain, Op0, Op1, DAG.getTargetConstant(SSECC, dl, MVT::i8)});
23367         Chain = Cmp.getValue(1);
23368       } else
23369         Cmp = DAG.getNode(
23370             Opc, dl, VT, Op0, Op1, DAG.getTargetConstant(SSECC, dl, MVT::i8));
23371     }
23372 
23373     if (VT.getFixedSizeInBits() >
23374         Op.getSimpleValueType().getFixedSizeInBits()) {
23375       // We emitted a compare with an XMM/YMM result. Finish converting to a
23376       // mask register using a vptestm.
23377       EVT CastVT = EVT(VT).changeVectorElementTypeToInteger();
23378       Cmp = DAG.getBitcast(CastVT, Cmp);
23379       Cmp = DAG.getSetCC(dl, Op.getSimpleValueType(), Cmp,
23380                          DAG.getConstant(0, dl, CastVT), ISD::SETNE);
23381     } else {
23382       // If this is SSE/AVX CMPP, bitcast the result back to integer to match
23383       // the result type of SETCC. The bitcast is expected to be optimized
23384       // away during combining/isel.
23385       Cmp = DAG.getBitcast(Op.getSimpleValueType(), Cmp);
23386     }
23387 
23388     if (IsStrict)
23389       return DAG.getMergeValues({Cmp, Chain}, dl);
23390 
23391     return Cmp;
23392   }
23393 
23394   assert(!IsStrict && "Strict SETCC only handles FP operands.");
23395 
23396   [[maybe_unused]] MVT VTOp0 = Op0.getSimpleValueType();
23397   assert(VTOp0 == Op1.getSimpleValueType() &&
23398          "Expected operands with same type!");
23399   assert(VT.getVectorNumElements() == VTOp0.getVectorNumElements() &&
23400          "Invalid number of packed elements for source and destination!");
23401 
23402   // The non-AVX512 code below works under the assumption that source and
23403   // destination types are the same.
23404   assert((Subtarget.hasAVX512() || (VT == VTOp0)) &&
23405          "Value types for source and destination must be the same!");
23406 
23407   // The result is boolean, but operands are int/float
23408   if (VT.getVectorElementType() == MVT::i1) {
23409     // In the AVX-512 architecture, setcc returns a mask with i1 elements,
23410     // but there is no compare instruction for i8 and i16 elements in KNL.
23411     assert((VTOp0.getScalarSizeInBits() >= 32 || Subtarget.hasBWI()) &&
23412            "Unexpected operand type");
23413     return LowerIntVSETCC_AVX512(Op, dl, DAG);
23414   }
23415 
23416   // Lower using XOP integer comparisons.
23417   if (VT.is128BitVector() && Subtarget.hasXOP()) {
23418     // Translate compare code to XOP PCOM compare mode.
23419     unsigned CmpMode = 0;
23420     switch (Cond) {
23421     // clang-format off
23422     default: llvm_unreachable("Unexpected SETCC condition");
23423     case ISD::SETULT:
23424     case ISD::SETLT: CmpMode = 0x00; break;
23425     case ISD::SETULE:
23426     case ISD::SETLE: CmpMode = 0x01; break;
23427     case ISD::SETUGT:
23428     case ISD::SETGT: CmpMode = 0x02; break;
23429     case ISD::SETUGE:
23430     case ISD::SETGE: CmpMode = 0x03; break;
23431     case ISD::SETEQ: CmpMode = 0x04; break;
23432     case ISD::SETNE: CmpMode = 0x05; break;
23433     // clang-format on
23434     }
23435 
23436     // Are we comparing unsigned or signed integers?
23437     unsigned Opc =
23438         ISD::isUnsignedIntSetCC(Cond) ? X86ISD::VPCOMU : X86ISD::VPCOM;
23439 
23440     return DAG.getNode(Opc, dl, VT, Op0, Op1,
23441                        DAG.getTargetConstant(CmpMode, dl, MVT::i8));
23442   }
23443 
23444   // (X & Y) != 0 --> (X & Y) == Y iff Y is power-of-2.
23445   // Revert part of the simplifySetCCWithAnd combine, to avoid an invert.
23446   if (Cond == ISD::SETNE && ISD::isBuildVectorAllZeros(Op1.getNode())) {
23447     SDValue BC0 = peekThroughBitcasts(Op0);
23448     if (BC0.getOpcode() == ISD::AND) {
23449       APInt UndefElts;
23450       SmallVector<APInt, 64> EltBits;
23451       if (getTargetConstantBitsFromNode(
23452               BC0.getOperand(1), VT.getScalarSizeInBits(), UndefElts, EltBits,
23453               /*AllowWholeUndefs*/ false, /*AllowPartialUndefs*/ false)) {
23454         if (llvm::all_of(EltBits, [](APInt &V) { return V.isPowerOf2(); })) {
23455           Cond = ISD::SETEQ;
23456           Op1 = DAG.getBitcast(VT, BC0.getOperand(1));
23457         }
23458       }
23459     }
23460   }
23461 
23462   // ICMP_EQ(AND(X,C),C) -> SRA(SHL(X,LOG2(C)),BW-1) iff C is power-of-2.
23463   if (Cond == ISD::SETEQ && Op0.getOpcode() == ISD::AND &&
23464       Op0.getOperand(1) == Op1 && Op0.hasOneUse()) {
23465     ConstantSDNode *C1 = isConstOrConstSplat(Op1);
23466     if (C1 && C1->getAPIntValue().isPowerOf2()) {
23467       unsigned BitWidth = VT.getScalarSizeInBits();
23468       unsigned ShiftAmt = BitWidth - C1->getAPIntValue().logBase2() - 1;
23469 
23470       SDValue Result = Op0.getOperand(0);
23471       Result = DAG.getNode(ISD::SHL, dl, VT, Result,
23472                            DAG.getConstant(ShiftAmt, dl, VT));
23473       Result = DAG.getNode(ISD::SRA, dl, VT, Result,
23474                            DAG.getConstant(BitWidth - 1, dl, VT));
23475       return Result;
23476     }
23477   }
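        // For example, with VT == v4i32 and C == <4,4,4,4> (bit 2 set), ShiftAmt is
        // 32 - 2 - 1 == 29, so (X << 29) >>s 31 is all-ones exactly when bit 2 of X
        // is set, i.e. when (X & 4) == 4.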
23478 
23479   // Break 256-bit integer vector compare into smaller ones.
23480   if (VT.is256BitVector() && !Subtarget.hasInt256())
23481     return splitIntVSETCC(VT, Op0, Op1, Cond, DAG, dl);
23482 
23483   // Break 512-bit integer vector compare into smaller ones.
23484   // TODO: Try harder to use VPCMPx + VPMOV2x?
23485   if (VT.is512BitVector())
23486     return splitIntVSETCC(VT, Op0, Op1, Cond, DAG, dl);
23487 
23488   // If we have a limit constant, try to form PCMPGT (signed cmp) to avoid
23489   // not-of-PCMPEQ:
23490   // X != INT_MIN --> X >s INT_MIN
23491   // X != INT_MAX --> X <s INT_MAX --> INT_MAX >s X
23492   // +X != 0 --> +X >s 0
23493   APInt ConstValue;
23494   if (Cond == ISD::SETNE &&
23495       ISD::isConstantSplatVector(Op1.getNode(), ConstValue)) {
23496     if (ConstValue.isMinSignedValue())
23497       Cond = ISD::SETGT;
23498     else if (ConstValue.isMaxSignedValue())
23499       Cond = ISD::SETLT;
23500     else if (ConstValue.isZero() && DAG.SignBitIsZero(Op0))
23501       Cond = ISD::SETGT;
23502   }
23503 
23504   // If both operands are known non-negative, then an unsigned compare is the
23505   // same as a signed compare and there's no need to flip signbits.
23506   // TODO: We could check for more general simplifications here since we're
23507   // computing known bits.
23508   bool FlipSigns = ISD::isUnsignedIntSetCC(Cond) &&
23509                    !(DAG.SignBitIsZero(Op0) && DAG.SignBitIsZero(Op1));
23510 
23511   // Special case: Use min/max operations for unsigned compares.
23512   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
23513   if (ISD::isUnsignedIntSetCC(Cond) &&
23514       (FlipSigns || ISD::isTrueWhenEqual(Cond)) &&
23515       TLI.isOperationLegal(ISD::UMIN, VT)) {
23516     // If we have a constant operand, increment/decrement it and change the
23517     // condition to avoid an invert.
23518     if (Cond == ISD::SETUGT) {
23519       // X > C --> X >= (C+1) --> X == umax(X, C+1)
23520       if (SDValue UGTOp1 =
23521               incDecVectorConstant(Op1, DAG, /*IsInc*/ true, /*NSW*/ false)) {
23522         Op1 = UGTOp1;
23523         Cond = ISD::SETUGE;
23524       }
23525     }
23526     if (Cond == ISD::SETULT) {
23527       // X < C --> X <= (C-1) --> X == umin(X, C-1)
23528       if (SDValue ULTOp1 =
23529               incDecVectorConstant(Op1, DAG, /*IsInc*/ false, /*NSW*/ false)) {
23530         Op1 = ULTOp1;
23531         Cond = ISD::SETULE;
23532       }
23533     }
23534     bool Invert = false;
23535     unsigned Opc;
23536     switch (Cond) {
23537     // clang-format off
23538     default: llvm_unreachable("Unexpected condition code");
23539     case ISD::SETUGT: Invert = true; [[fallthrough]];
23540     case ISD::SETULE: Opc = ISD::UMIN; break;
23541     case ISD::SETULT: Invert = true; [[fallthrough]];
23542     case ISD::SETUGE: Opc = ISD::UMAX; break;
23543     // clang-format on
23544     }
23545 
23546     SDValue Result = DAG.getNode(Opc, dl, VT, Op0, Op1);
23547     Result = DAG.getNode(X86ISD::PCMPEQ, dl, VT, Op0, Result);
23548 
23549     // If the logical-not of the result is required, perform that now.
23550     if (Invert)
23551       Result = DAG.getNOT(dl, Result, VT);
23552 
23553     return Result;
23554   }
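        // For example, on SSE4.1 a v4i32 (setugt X, 7) can be rewritten as
        // (setuge X, 8) and emitted as PCMPEQD(X, PMAXUD(X, 8)), avoiding both the
        // sign-bit flip and a final NOT.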
23555 
23556   // Try to use SUBUS and PCMPEQ.
23557   if (FlipSigns)
23558     if (SDValue V =
23559             LowerVSETCCWithSUBUS(Op0, Op1, VT, Cond, dl, Subtarget, DAG))
23560       return V;
23561 
23562   // We are handling one of the integer comparisons here. Since SSE only has
23563   // GT and EQ comparisons for integers, swapping operands and multiple
23564   // operations may be required for some comparisons.
23565   unsigned Opc = (Cond == ISD::SETEQ || Cond == ISD::SETNE) ? X86ISD::PCMPEQ
23566                                                             : X86ISD::PCMPGT;
23567   bool Swap = Cond == ISD::SETLT || Cond == ISD::SETULT ||
23568               Cond == ISD::SETGE || Cond == ISD::SETUGE;
23569   bool Invert = Cond == ISD::SETNE ||
23570                 (Cond != ISD::SETEQ && ISD::isTrueWhenEqual(Cond));
23571 
23572   if (Swap)
23573     std::swap(Op0, Op1);
23574 
23575   // Check that the operation in question is available (most are plain SSE2,
23576   // but PCMPGTQ and PCMPEQQ have different requirements).
23577   if (VT == MVT::v2i64) {
23578     if (Opc == X86ISD::PCMPGT && !Subtarget.hasSSE42()) {
23579       assert(Subtarget.hasSSE2() && "Don't know how to lower!");
23580 
23581       // Special case for sign bit test. We can use a v4i32 PCMPGT and shuffle
23582       // the odd elements over the even elements.
23583       if (!FlipSigns && !Invert && ISD::isBuildVectorAllZeros(Op0.getNode())) {
23584         Op0 = DAG.getConstant(0, dl, MVT::v4i32);
23585         Op1 = DAG.getBitcast(MVT::v4i32, Op1);
23586 
23587         SDValue GT = DAG.getNode(X86ISD::PCMPGT, dl, MVT::v4i32, Op0, Op1);
23588         static const int MaskHi[] = { 1, 1, 3, 3 };
23589         SDValue Result = DAG.getVectorShuffle(MVT::v4i32, dl, GT, GT, MaskHi);
23590 
23591         return DAG.getBitcast(VT, Result);
23592       }
23593 
23594       if (!FlipSigns && !Invert && ISD::isBuildVectorAllOnes(Op1.getNode())) {
23595         Op0 = DAG.getBitcast(MVT::v4i32, Op0);
23596         Op1 = DAG.getConstant(-1, dl, MVT::v4i32);
23597 
23598         SDValue GT = DAG.getNode(X86ISD::PCMPGT, dl, MVT::v4i32, Op0, Op1);
23599         static const int MaskHi[] = { 1, 1, 3, 3 };
23600         SDValue Result = DAG.getVectorShuffle(MVT::v4i32, dl, GT, GT, MaskHi);
23601 
23602         return DAG.getBitcast(VT, Result);
23603       }
23604 
23605       // If the i64 elements are sign-extended enough to be representable as i32
23606       // then we can compare the lower i32 bits and splat.
23607       if (!FlipSigns && !Invert && DAG.ComputeNumSignBits(Op0) > 32 &&
23608           DAG.ComputeNumSignBits(Op1) > 32) {
23609         Op0 = DAG.getBitcast(MVT::v4i32, Op0);
23610         Op1 = DAG.getBitcast(MVT::v4i32, Op1);
23611 
23612         SDValue GT = DAG.getNode(X86ISD::PCMPGT, dl, MVT::v4i32, Op0, Op1);
23613         static const int MaskLo[] = {0, 0, 2, 2};
23614         SDValue Result = DAG.getVectorShuffle(MVT::v4i32, dl, GT, GT, MaskLo);
23615 
23616         return DAG.getBitcast(VT, Result);
23617       }
23618 
23619       // Since SSE has no unsigned integer comparisons, we need to flip the sign
23620       // bits of the inputs before performing those operations. The lower
23621       // compare is always unsigned.
23622       SDValue SB = DAG.getConstant(FlipSigns ? 0x8000000080000000ULL
23623                                              : 0x0000000080000000ULL,
23624                                    dl, MVT::v2i64);
23625 
23626       Op0 = DAG.getNode(ISD::XOR, dl, MVT::v2i64, Op0, SB);
23627       Op1 = DAG.getNode(ISD::XOR, dl, MVT::v2i64, Op1, SB);
23628 
23629       // Cast everything to the right type.
23630       Op0 = DAG.getBitcast(MVT::v4i32, Op0);
23631       Op1 = DAG.getBitcast(MVT::v4i32, Op1);
23632 
23633       // Emulate PCMPGTQ with (hi1 > hi2) | ((hi1 == hi2) & (lo1 > lo2))
23634       SDValue GT = DAG.getNode(X86ISD::PCMPGT, dl, MVT::v4i32, Op0, Op1);
23635       SDValue EQ = DAG.getNode(X86ISD::PCMPEQ, dl, MVT::v4i32, Op0, Op1);
23636 
23637       // Create masks for only the low parts/high parts of the 64 bit integers.
23638       static const int MaskHi[] = { 1, 1, 3, 3 };
23639       static const int MaskLo[] = { 0, 0, 2, 2 };
23640       SDValue EQHi = DAG.getVectorShuffle(MVT::v4i32, dl, EQ, EQ, MaskHi);
23641       SDValue GTLo = DAG.getVectorShuffle(MVT::v4i32, dl, GT, GT, MaskLo);
23642       SDValue GTHi = DAG.getVectorShuffle(MVT::v4i32, dl, GT, GT, MaskHi);
23643 
23644       SDValue Result = DAG.getNode(ISD::AND, dl, MVT::v4i32, EQHi, GTLo);
23645       Result = DAG.getNode(ISD::OR, dl, MVT::v4i32, Result, GTHi);
23646 
23647       if (Invert)
23648         Result = DAG.getNOT(dl, Result, MVT::v4i32);
23649 
23650       return DAG.getBitcast(VT, Result);
23651     }
23652 
23653     if (Opc == X86ISD::PCMPEQ && !Subtarget.hasSSE41()) {
23654       // If pcmpeqq is missing but pcmpeqd is available synthesize pcmpeqq with
23655       // pcmpeqd + pshufd + pand.
23656       assert(Subtarget.hasSSE2() && !FlipSigns && "Don't know how to lower!");
23657 
23658       // First cast everything to the right type.
23659       Op0 = DAG.getBitcast(MVT::v4i32, Op0);
23660       Op1 = DAG.getBitcast(MVT::v4i32, Op1);
23661 
23662       // Do the compare.
23663       SDValue Result = DAG.getNode(Opc, dl, MVT::v4i32, Op0, Op1);
23664 
23665       // Make sure the lower and upper halves are both all-ones.
23666       static const int Mask[] = { 1, 0, 3, 2 };
23667       SDValue Shuf = DAG.getVectorShuffle(MVT::v4i32, dl, Result, Result, Mask);
23668       Result = DAG.getNode(ISD::AND, dl, MVT::v4i32, Result, Shuf);
23669 
23670       if (Invert)
23671         Result = DAG.getNOT(dl, Result, MVT::v4i32);
23672 
23673       return DAG.getBitcast(VT, Result);
23674     }
23675   }
23676 
23677   // Since SSE has no unsigned integer comparisons, we need to flip the sign
23678   // bits of the inputs before performing those operations.
23679   if (FlipSigns) {
23680     MVT EltVT = VT.getVectorElementType();
23681     SDValue SM = DAG.getConstant(APInt::getSignMask(EltVT.getSizeInBits()), dl,
23682                                  VT);
23683     Op0 = DAG.getNode(ISD::XOR, dl, VT, Op0, SM);
23684     Op1 = DAG.getNode(ISD::XOR, dl, VT, Op1, SM);
23685   }
23686 
23687   SDValue Result = DAG.getNode(Opc, dl, VT, Op0, Op1);
23688 
23689   // If the logical-not of the result is required, perform that now.
23690   if (Invert)
23691     Result = DAG.getNOT(dl, Result, VT);
23692 
23693   return Result;
23694 }
23695 
23696 // Try to select this as a KORTEST+SETCC or KTEST+SETCC if possible.
23697 static SDValue EmitAVX512Test(SDValue Op0, SDValue Op1, ISD::CondCode CC,
23698                               const SDLoc &dl, SelectionDAG &DAG,
23699                               const X86Subtarget &Subtarget,
23700                               SDValue &X86CC) {
23701   assert((CC == ISD::SETEQ || CC == ISD::SETNE) && "Unsupported ISD::CondCode");
23702 
23703   // Must be a bitcast from vXi1.
23704   if (Op0.getOpcode() != ISD::BITCAST)
23705     return SDValue();
23706 
23707   Op0 = Op0.getOperand(0);
23708   MVT VT = Op0.getSimpleValueType();
23709   if (!(Subtarget.hasAVX512() && VT == MVT::v16i1) &&
23710       !(Subtarget.hasDQI() && VT == MVT::v8i1) &&
23711       !(Subtarget.hasBWI() && (VT == MVT::v32i1 || VT == MVT::v64i1)))
23712     return SDValue();
23713 
23714   X86::CondCode X86Cond;
23715   if (isNullConstant(Op1)) {
23716     X86Cond = CC == ISD::SETEQ ? X86::COND_E : X86::COND_NE;
23717   } else if (isAllOnesConstant(Op1)) {
23718     // KORTEST sets the carry flag when the result is all ones.
23719     X86Cond = CC == ISD::SETEQ ? X86::COND_B : X86::COND_AE;
23720   } else
23721     return SDValue();
23722 
23723   // If the input is an AND, we can combine its operands into the KTEST.
23724   bool KTestable = false;
23725   if (Subtarget.hasDQI() && (VT == MVT::v8i1 || VT == MVT::v16i1))
23726     KTestable = true;
23727   if (Subtarget.hasBWI() && (VT == MVT::v32i1 || VT == MVT::v64i1))
23728     KTestable = true;
23729   if (!isNullConstant(Op1))
23730     KTestable = false;
23731   if (KTestable && Op0.getOpcode() == ISD::AND && Op0.hasOneUse()) {
23732     SDValue LHS = Op0.getOperand(0);
23733     SDValue RHS = Op0.getOperand(1);
23734     X86CC = DAG.getTargetConstant(X86Cond, dl, MVT::i8);
23735     return DAG.getNode(X86ISD::KTEST, dl, MVT::i32, LHS, RHS);
23736   }
23737 
23738   // If the input is an OR, we can combine its operands into the KORTEST.
23739   SDValue LHS = Op0;
23740   SDValue RHS = Op0;
23741   if (Op0.getOpcode() == ISD::OR && Op0.hasOneUse()) {
23742     LHS = Op0.getOperand(0);
23743     RHS = Op0.getOperand(1);
23744   }
23745 
23746   X86CC = DAG.getTargetConstant(X86Cond, dl, MVT::i8);
23747   return DAG.getNode(X86ISD::KORTEST, dl, MVT::i32, LHS, RHS);
23748 }
23749 
23750 /// Emit flags for the given setcc condition and operands. Also returns the
23751 /// corresponding X86 condition code constant in X86CC.
23752 SDValue X86TargetLowering::emitFlagsForSetcc(SDValue Op0, SDValue Op1,
23753                                              ISD::CondCode CC, const SDLoc &dl,
23754                                              SelectionDAG &DAG,
23755                                              SDValue &X86CC) const {
23756   // Equality Combines.
23757   if (CC == ISD::SETEQ || CC == ISD::SETNE) {
23758     X86::CondCode X86CondCode;
23759 
23760     // Optimize to BT if possible.
23761     // Lower (X & (1 << N)) == 0 to BT(X, N).
23762     // Lower ((X >>u N) & 1) != 0 to BT(X, N).
23763     // Lower ((X >>s N) & 1) != 0 to BT(X, N).
23764     if (Op0.getOpcode() == ISD::AND && Op0.hasOneUse() && isNullConstant(Op1)) {
23765       if (SDValue BT = LowerAndToBT(Op0, CC, dl, DAG, X86CondCode)) {
23766         X86CC = DAG.getTargetConstant(X86CondCode, dl, MVT::i8);
23767         return BT;
23768       }
23769     }
23770 
23771     // Try to use PTEST/PMOVMSKB for a tree of ANDs/ORs equality-compared with -1/0.
23772     if (SDValue CmpZ = MatchVectorAllEqualTest(Op0, Op1, CC, dl, Subtarget, DAG,
23773                                                X86CondCode)) {
23774       X86CC = DAG.getTargetConstant(X86CondCode, dl, MVT::i8);
23775       return CmpZ;
23776     }
23777 
23778     // Try to lower using KORTEST or KTEST.
23779     if (SDValue Test = EmitAVX512Test(Op0, Op1, CC, dl, DAG, Subtarget, X86CC))
23780       return Test;
23781 
23782     // Look for X == 0, X == 1, X != 0, or X != 1.  We can simplify some forms
23783     // of these.
23784     if (isOneConstant(Op1) || isNullConstant(Op1)) {
23785       // If the input is a setcc, then reuse the input setcc or use a new one
23786       // with the inverted condition.
23787       if (Op0.getOpcode() == X86ISD::SETCC) {
23788         bool Invert = (CC == ISD::SETNE) ^ isNullConstant(Op1);
23789 
23790         X86CC = Op0.getOperand(0);
23791         if (Invert) {
23792           X86CondCode = (X86::CondCode)Op0.getConstantOperandVal(0);
23793           X86CondCode = X86::GetOppositeBranchCondition(X86CondCode);
23794           X86CC = DAG.getTargetConstant(X86CondCode, dl, MVT::i8);
23795         }
23796 
23797         return Op0.getOperand(1);
23798       }
23799     }
23800 
23801     // Look for X == INT_MIN or X != INT_MIN. We can use NEG and test for
23802     // overflow.
23803     if (isMinSignedConstant(Op1)) {
23804       EVT VT = Op0.getValueType();
23805       if (VT == MVT::i32 || VT == MVT::i64 || Op0->hasOneUse()) {
23806         SDVTList CmpVTs = DAG.getVTList(VT, MVT::i32);
23807         X86::CondCode CondCode = CC == ISD::SETEQ ? X86::COND_O : X86::COND_NO;
23808         X86CC = DAG.getTargetConstant(CondCode, dl, MVT::i8);
23809         SDValue Neg = DAG.getNode(X86ISD::SUB, dl, CmpVTs,
23810                                   DAG.getConstant(0, dl, VT), Op0);
23811         return SDValue(Neg.getNode(), 1);
23812       }
23813     }
23814 
23815     // Try to use the carry flag from the add in place of a separate CMP for:
23816     // (seteq (add X, -1), -1). Similar for setne.
23817     if (isAllOnesConstant(Op1) && Op0.getOpcode() == ISD::ADD &&
23818         Op0.getOperand(1) == Op1) {
23819       if (isProfitableToUseFlagOp(Op0)) {
23820         SDVTList VTs = DAG.getVTList(Op0.getValueType(), MVT::i32);
23821 
23822         SDValue New = DAG.getNode(X86ISD::ADD, dl, VTs, Op0.getOperand(0),
23823                                   Op0.getOperand(1));
23824         DAG.ReplaceAllUsesOfValueWith(SDValue(Op0.getNode(), 0), New);
23825         X86CondCode = CC == ISD::SETEQ ? X86::COND_AE : X86::COND_B;
23826         X86CC = DAG.getTargetConstant(X86CondCode, dl, MVT::i8);
23827         return SDValue(New.getNode(), 1);
23828       }
23829     }
23830   }
23831 
23832   X86::CondCode CondCode =
23833       TranslateX86CC(CC, dl, /*IsFP*/ false, Op0, Op1, DAG);
23834   assert(CondCode != X86::COND_INVALID && "Unexpected condition code!");
23835 
23836   SDValue EFLAGS = EmitCmp(Op0, Op1, CondCode, dl, DAG, Subtarget);
23837   X86CC = DAG.getTargetConstant(CondCode, dl, MVT::i8);
23838   return EFLAGS;
23839 }
23840 
23841 SDValue X86TargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const {
23842 
23843   bool IsStrict = Op.getOpcode() == ISD::STRICT_FSETCC ||
23844                   Op.getOpcode() == ISD::STRICT_FSETCCS;
23845   MVT VT = Op->getSimpleValueType(0);
23846 
23847   if (VT.isVector()) return LowerVSETCC(Op, Subtarget, DAG);
23848 
23849   assert(VT == MVT::i8 && "SetCC type must be 8-bit integer");
23850   SDValue Chain = IsStrict ? Op.getOperand(0) : SDValue();
23851   SDValue Op0 = Op.getOperand(IsStrict ? 1 : 0);
23852   SDValue Op1 = Op.getOperand(IsStrict ? 2 : 1);
23853   SDLoc dl(Op);
23854   ISD::CondCode CC =
23855       cast<CondCodeSDNode>(Op.getOperand(IsStrict ? 3 : 2))->get();
23856 
23857   if (isSoftF16(Op0.getValueType(), Subtarget))
23858     return SDValue();
23859 
23860   // Handle f128 first, since one possible outcome is a normal integer
23861   // comparison which gets handled by emitFlagsForSetcc.
23862   if (Op0.getValueType() == MVT::f128) {
23863     softenSetCCOperands(DAG, MVT::f128, Op0, Op1, CC, dl, Op0, Op1, Chain,
23864                         Op.getOpcode() == ISD::STRICT_FSETCCS);
23865 
23866     // If softenSetCCOperands returned a scalar, use it.
23867     if (!Op1.getNode()) {
23868       assert(Op0.getValueType() == Op.getValueType() &&
23869              "Unexpected setcc expansion!");
23870       if (IsStrict)
23871         return DAG.getMergeValues({Op0, Chain}, dl);
23872       return Op0;
23873     }
23874   }
23875 
23876   if (Op0.getSimpleValueType().isInteger()) {
23877     // Attempt to canonicalize SGT/UGT -> SGE/UGE compares with a constant,
23878     // which reduces the number of EFLAGS bits read (the GE conditions don't
23879     // read ZF); this may translate to fewer uops depending on the uarch
23880     // implementation. The equivalent for SLE/ULE -> SLT/ULT isn't likely to
23881     // happen as we already canonicalize to that CondCode.
23882     // NOTE: Only do this if incrementing the constant doesn't increase the bit
23883     // encoding size - so it must either already be an i8 or i32 immediate, or
23884     // it shrinks down to that. We don't do this for any i64's to avoid
23885     // additional constant materializations.
23886     // TODO: Can we move this to TranslateX86CC to handle jumps/branches too?
23887     if (auto *Op1C = dyn_cast<ConstantSDNode>(Op1)) {
23888       const APInt &Op1Val = Op1C->getAPIntValue();
23889       if (!Op1Val.isZero()) {
23890         // Ensure the constant+1 doesn't overflow.
23891         if ((CC == ISD::CondCode::SETGT && !Op1Val.isMaxSignedValue()) ||
23892             (CC == ISD::CondCode::SETUGT && !Op1Val.isMaxValue())) {
23893           APInt Op1ValPlusOne = Op1Val + 1;
23894           if (Op1ValPlusOne.isSignedIntN(32) &&
23895               (!Op1Val.isSignedIntN(8) || Op1ValPlusOne.isSignedIntN(8))) {
23896             Op1 = DAG.getConstant(Op1ValPlusOne, dl, Op0.getValueType());
23897             CC = CC == ISD::CondCode::SETGT ? ISD::CondCode::SETGE
23898                                             : ISD::CondCode::SETUGE;
23899           }
23900         }
23901       }
23902     }
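          // For example, (setgt X, 9) becomes (setge X, 10): both constants fit in
          // an i8 immediate and the GE condition avoids reading ZF.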
23903 
23904     SDValue X86CC;
23905     SDValue EFLAGS = emitFlagsForSetcc(Op0, Op1, CC, dl, DAG, X86CC);
23906     SDValue Res = DAG.getNode(X86ISD::SETCC, dl, MVT::i8, X86CC, EFLAGS);
23907     return IsStrict ? DAG.getMergeValues({Res, Chain}, dl) : Res;
23908   }
23909 
23910   // Handle floating point.
23911   X86::CondCode CondCode = TranslateX86CC(CC, dl, /*IsFP*/ true, Op0, Op1, DAG);
23912   if (CondCode == X86::COND_INVALID)
23913     return SDValue();
23914 
23915   SDValue EFLAGS;
23916   if (IsStrict) {
23917     bool IsSignaling = Op.getOpcode() == ISD::STRICT_FSETCCS;
23918     EFLAGS =
23919         DAG.getNode(IsSignaling ? X86ISD::STRICT_FCMPS : X86ISD::STRICT_FCMP,
23920                     dl, {MVT::i32, MVT::Other}, {Chain, Op0, Op1});
23921     Chain = EFLAGS.getValue(1);
23922   } else {
23923     EFLAGS = DAG.getNode(X86ISD::FCMP, dl, MVT::i32, Op0, Op1);
23924   }
23925 
23926   SDValue X86CC = DAG.getTargetConstant(CondCode, dl, MVT::i8);
23927   SDValue Res = DAG.getNode(X86ISD::SETCC, dl, MVT::i8, X86CC, EFLAGS);
23928   return IsStrict ? DAG.getMergeValues({Res, Chain}, dl) : Res;
23929 }
23930 
23931 SDValue X86TargetLowering::LowerSETCCCARRY(SDValue Op, SelectionDAG &DAG) const {
23932   SDValue LHS = Op.getOperand(0);
23933   SDValue RHS = Op.getOperand(1);
23934   SDValue Carry = Op.getOperand(2);
23935   SDValue Cond = Op.getOperand(3);
23936   SDLoc DL(Op);
23937 
23938   assert(LHS.getSimpleValueType().isInteger() && "SETCCCARRY is integer only.");
23939   X86::CondCode CC = TranslateIntegerX86CC(cast<CondCodeSDNode>(Cond)->get());
23940 
23941   // Recreate the carry if needed.
23942   EVT CarryVT = Carry.getValueType();
23943   Carry = DAG.getNode(X86ISD::ADD, DL, DAG.getVTList(CarryVT, MVT::i32),
23944                       Carry, DAG.getAllOnesConstant(DL, CarryVT));
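        // Adding all-ones to Carry (expected to be 0 or 1) produces a carry-out
        // exactly when Carry was 1, so the SBB below sees the original carry-in.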
23945 
23946   SDVTList VTs = DAG.getVTList(LHS.getValueType(), MVT::i32);
23947   SDValue Cmp = DAG.getNode(X86ISD::SBB, DL, VTs, LHS, RHS, Carry.getValue(1));
23948   return getSETCC(CC, Cmp.getValue(1), DL, DAG);
23949 }
23950 
23951 // This function returns three things: the arithmetic computation itself
23952 // (Value), an EFLAGS result (Overflow), and a condition code (Cond).  The
23953 // flag and the condition code define the case in which the arithmetic
23954 // computation overflows.
23955 static std::pair<SDValue, SDValue>
23956 getX86XALUOOp(X86::CondCode &Cond, SDValue Op, SelectionDAG &DAG) {
23957   assert(Op.getResNo() == 0 && "Unexpected result number!");
23958   SDValue Value, Overflow;
23959   SDValue LHS = Op.getOperand(0);
23960   SDValue RHS = Op.getOperand(1);
23961   unsigned BaseOp = 0;
23962   SDLoc DL(Op);
23963   switch (Op.getOpcode()) {
23964   default: llvm_unreachable("Unknown ovf instruction!");
23965   case ISD::SADDO:
23966     BaseOp = X86ISD::ADD;
23967     Cond = X86::COND_O;
23968     break;
23969   case ISD::UADDO:
23970     BaseOp = X86ISD::ADD;
23971     Cond = isOneConstant(RHS) ? X86::COND_E : X86::COND_B;
23972     break;
23973   case ISD::SSUBO:
23974     BaseOp = X86ISD::SUB;
23975     Cond = X86::COND_O;
23976     break;
23977   case ISD::USUBO:
23978     BaseOp = X86ISD::SUB;
23979     Cond = X86::COND_B;
23980     break;
23981   case ISD::SMULO:
23982     BaseOp = X86ISD::SMUL;
23983     Cond = X86::COND_O;
23984     break;
23985   case ISD::UMULO:
23986     BaseOp = X86ISD::UMUL;
23987     Cond = X86::COND_O;
23988     break;
23989   }
23990 
23991   if (BaseOp) {
23992     // Also sets EFLAGS.
23993     SDVTList VTs = DAG.getVTList(Op.getValueType(), MVT::i32);
23994     Value = DAG.getNode(BaseOp, DL, VTs, LHS, RHS);
23995     Overflow = Value.getValue(1);
23996   }
23997 
23998   return std::make_pair(Value, Overflow);
23999 }
24000 
24001 static SDValue LowerXALUO(SDValue Op, SelectionDAG &DAG) {
24002   // Lower the "add/sub/mul with overflow" instruction into a regular
24003   // instruction plus a "setcc" instruction that checks the overflow flag.
24004   // The "brcond" lowering looks for this combo and may remove the "setcc"
24005   // instruction if the "setcc" has only one use.
24006   SDLoc DL(Op);
24007   X86::CondCode Cond;
24008   SDValue Value, Overflow;
24009   std::tie(Value, Overflow) = getX86XALUOOp(Cond, Op, DAG);
24010 
24011   SDValue SetCC = getSETCC(Cond, Overflow, DL, DAG);
24012   assert(Op->getValueType(1) == MVT::i8 && "Unexpected VT!");
24013   return DAG.getNode(ISD::MERGE_VALUES, DL, Op->getVTList(), Value, SetCC);
24014 }
24015 
24016 /// Return true if opcode is a X86 logical comparison.
24017 static bool isX86LogicalCmp(SDValue Op) {
24018   unsigned Opc = Op.getOpcode();
24019   if (Opc == X86ISD::CMP || Opc == X86ISD::COMI || Opc == X86ISD::UCOMI ||
24020       Opc == X86ISD::FCMP)
24021     return true;
24022   if (Op.getResNo() == 1 &&
24023       (Opc == X86ISD::ADD || Opc == X86ISD::SUB || Opc == X86ISD::ADC ||
24024        Opc == X86ISD::SBB || Opc == X86ISD::SMUL || Opc == X86ISD::UMUL ||
24025        Opc == X86ISD::OR || Opc == X86ISD::XOR || Opc == X86ISD::AND))
24026     return true;
24027 
24028   return false;
24029 }
24030 
24031 static bool isTruncWithZeroHighBitsInput(SDValue V, SelectionDAG &DAG) {
24032   if (V.getOpcode() != ISD::TRUNCATE)
24033     return false;
24034 
24035   SDValue VOp0 = V.getOperand(0);
24036   unsigned InBits = VOp0.getValueSizeInBits();
24037   unsigned Bits = V.getValueSizeInBits();
24038   return DAG.MaskedValueIsZero(VOp0, APInt::getHighBitsSet(InBits,InBits-Bits));
24039 }
24040 
24041 SDValue X86TargetLowering::LowerSELECT(SDValue Op, SelectionDAG &DAG) const {
24042   bool AddTest = true;
24043   SDValue Cond  = Op.getOperand(0);
24044   SDValue Op1 = Op.getOperand(1);
24045   SDValue Op2 = Op.getOperand(2);
24046   SDLoc DL(Op);
24047   MVT VT = Op1.getSimpleValueType();
24048   SDValue CC;
24049 
24050   if (isSoftF16(VT, Subtarget)) {
24051     MVT NVT = VT.changeTypeToInteger();
24052     return DAG.getBitcast(VT, DAG.getNode(ISD::SELECT, DL, NVT, Cond,
24053                                           DAG.getBitcast(NVT, Op1),
24054                                           DAG.getBitcast(NVT, Op2)));
24055   }
24056 
24057   // Lower FP selects into a CMP/AND/ANDN/OR sequence when the necessary SSE ops
24058   // are available or VBLENDV if AVX is available.
24059   // Otherwise FP cmovs get lowered into a less efficient branch sequence later.
24060   if (Cond.getOpcode() == ISD::SETCC && isScalarFPTypeInSSEReg(VT) &&
24061       VT == Cond.getOperand(0).getSimpleValueType() && Cond->hasOneUse()) {
24062     SDValue CondOp0 = Cond.getOperand(0), CondOp1 = Cond.getOperand(1);
24063     bool IsAlwaysSignaling;
24064     unsigned SSECC =
24065         translateX86FSETCC(cast<CondCodeSDNode>(Cond.getOperand(2))->get(),
24066                            CondOp0, CondOp1, IsAlwaysSignaling);
24067 
24068     if (Subtarget.hasAVX512()) {
24069       SDValue Cmp =
24070           DAG.getNode(X86ISD::FSETCCM, DL, MVT::v1i1, CondOp0, CondOp1,
24071                       DAG.getTargetConstant(SSECC, DL, MVT::i8));
24072       assert(!VT.isVector() && "Not a scalar type?");
24073       return DAG.getNode(X86ISD::SELECTS, DL, VT, Cmp, Op1, Op2);
24074     }
24075 
24076     if (SSECC < 8 || Subtarget.hasAVX()) {
24077       SDValue Cmp = DAG.getNode(X86ISD::FSETCC, DL, VT, CondOp0, CondOp1,
24078                                 DAG.getTargetConstant(SSECC, DL, MVT::i8));
24079 
24080       // If we have AVX, we can use a variable vector select (VBLENDV) instead
24081       // of 3 logic instructions for size savings and potentially speed.
24082       // Unfortunately, there is no scalar form of VBLENDV.
24083 
24084       // If either operand is a +0.0 constant, don't try this. We can expect to
24085       // optimize away at least one of the logic instructions later in that
24086       // case, so that sequence would be faster than a variable blend.
24087 
24088       // BLENDV was introduced with SSE 4.1, but the 2 register form implicitly
24089       // uses XMM0 as the selection register. That may need just as many
24090       // instructions as the AND/ANDN/OR sequence due to register moves, so
24091       // don't bother.
24092       if (Subtarget.hasAVX() && !isNullFPConstant(Op1) &&
24093           !isNullFPConstant(Op2)) {
24094         // Convert to vectors, do a VSELECT, and convert back to scalar.
24095         // All of the conversions should be optimized away.
24096         MVT VecVT = VT == MVT::f32 ? MVT::v4f32 : MVT::v2f64;
24097         SDValue VOp1 = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VecVT, Op1);
24098         SDValue VOp2 = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VecVT, Op2);
24099         SDValue VCmp = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VecVT, Cmp);
24100 
24101         MVT VCmpVT = VT == MVT::f32 ? MVT::v4i32 : MVT::v2i64;
24102         VCmp = DAG.getBitcast(VCmpVT, VCmp);
24103 
24104         SDValue VSel = DAG.getSelect(DL, VecVT, VCmp, VOp1, VOp2);
24105 
24106         return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT,
24107                            VSel, DAG.getIntPtrConstant(0, DL));
24108       }
24109       SDValue AndN = DAG.getNode(X86ISD::FANDN, DL, VT, Cmp, Op2);
24110       SDValue And = DAG.getNode(X86ISD::FAND, DL, VT, Cmp, Op1);
24111       return DAG.getNode(X86ISD::FOR, DL, VT, AndN, And);
24112     }
24113   }
24114 
24115   // AVX512 fallback is to lower selects of scalar floats to masked moves.
24116   if (isScalarFPTypeInSSEReg(VT) && Subtarget.hasAVX512()) {
24117     SDValue Cmp = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v1i1, Cond);
24118     return DAG.getNode(X86ISD::SELECTS, DL, VT, Cmp, Op1, Op2);
24119   }
24120 
24121   if (Cond.getOpcode() == ISD::SETCC &&
24122       !isSoftF16(Cond.getOperand(0).getSimpleValueType(), Subtarget)) {
24123     if (SDValue NewCond = LowerSETCC(Cond, DAG)) {
24124       Cond = NewCond;
24125       // If the condition was updated, it's possible that the operands of the
24126       // select were also updated (for example, EmitTest has a RAUW). Refresh
24127       // the local references to the select operands in case they got stale.
24128       Op1 = Op.getOperand(1);
24129       Op2 = Op.getOperand(2);
24130     }
24131   }
24132 
24133   // (select (x == 0), -1, y) -> (sign_bit (x - 1)) | y
24134   // (select (x == 0), y, -1) -> ~(sign_bit (x - 1)) | y
24135   // (select (x != 0), y, -1) -> (sign_bit (x - 1)) | y
24136   // (select (x != 0), -1, y) -> ~(sign_bit (x - 1)) | y
24137   // (select (and (x , 0x1) == 0), y, (z ^ y) ) -> (-(and (x , 0x1)) & z ) ^ y
24138   // (select (and (x , 0x1) == 0), y, (z | y) ) -> (-(and (x , 0x1)) & z ) | y
24139   // (select (x > 0), x, 0) -> (~(x >> (size_in_bits(x)-1))) & x
24140   // (select (x < 0), x, 0) -> ((x >> (size_in_bits(x)-1))) & x
24141   if (Cond.getOpcode() == X86ISD::SETCC &&
24142       Cond.getOperand(1).getOpcode() == X86ISD::CMP &&
24143       isNullConstant(Cond.getOperand(1).getOperand(1))) {
24144     SDValue Cmp = Cond.getOperand(1);
24145     SDValue CmpOp0 = Cmp.getOperand(0);
24146     unsigned CondCode = Cond.getConstantOperandVal(0);
24147 
24148     // Special handling for the __builtin_ffs(X) - 1 pattern, which looks like
24149     // (select (seteq X, 0), -1, (cttz_zero_undef X)). Skip the transforms below
24150     // to keep the CMP with 0; it should later be removed by optimizeCompareInst
24151     // using the flags from the BSR/TZCNT emitted for the cttz_zero_undef.
24153     auto MatchFFSMinus1 = [&](SDValue Op1, SDValue Op2) {
24154       return (Op1.getOpcode() == ISD::CTTZ_ZERO_UNDEF && Op1.hasOneUse() &&
24155               Op1.getOperand(0) == CmpOp0 && isAllOnesConstant(Op2));
24156     };
24157     if (Subtarget.canUseCMOV() && (VT == MVT::i32 || VT == MVT::i64) &&
24158         ((CondCode == X86::COND_NE && MatchFFSMinus1(Op1, Op2)) ||
24159          (CondCode == X86::COND_E && MatchFFSMinus1(Op2, Op1)))) {
24160       // Keep Cmp.
24161     } else if ((isAllOnesConstant(Op1) || isAllOnesConstant(Op2)) &&
24162         (CondCode == X86::COND_E || CondCode == X86::COND_NE)) {
24163       SDValue Y = isAllOnesConstant(Op2) ? Op1 : Op2;
24164       SDVTList CmpVTs = DAG.getVTList(CmpOp0.getValueType(), MVT::i32);
24165 
24166       // 'X - 1' sets the carry flag if X == 0.
24167       // '0 - X' sets the carry flag if X != 0.
24168       // Convert the carry flag to a -1/0 mask with sbb:
24169       // select (X != 0), -1, Y --> 0 - X; or (sbb), Y
24170       // select (X == 0), Y, -1 --> 0 - X; or (sbb), Y
24171       // select (X != 0), Y, -1 --> X - 1; or (sbb), Y
24172       // select (X == 0), -1, Y --> X - 1; or (sbb), Y
24173       SDValue Sub;
24174       if (isAllOnesConstant(Op1) == (CondCode == X86::COND_NE)) {
24175         SDValue Zero = DAG.getConstant(0, DL, CmpOp0.getValueType());
24176         Sub = DAG.getNode(X86ISD::SUB, DL, CmpVTs, Zero, CmpOp0);
24177       } else {
24178         SDValue One = DAG.getConstant(1, DL, CmpOp0.getValueType());
24179         Sub = DAG.getNode(X86ISD::SUB, DL, CmpVTs, CmpOp0, One);
24180       }
24181       SDValue SBB = DAG.getNode(X86ISD::SETCC_CARRY, DL, VT,
24182                                 DAG.getTargetConstant(X86::COND_B, DL, MVT::i8),
24183                                 Sub.getValue(1));
24184       return DAG.getNode(ISD::OR, DL, VT, SBB, Y);
24185     } else if (!Subtarget.canUseCMOV() && CondCode == X86::COND_E &&
24186                CmpOp0.getOpcode() == ISD::AND &&
24187                isOneConstant(CmpOp0.getOperand(1))) {
24188       SDValue Src1, Src2;
24189       // Returns true if Op2 is an XOR or OR operator and one of its operands
24190       // equals Op1, i.e. the pattern is (a, a op b) or (b, a op b).
24192       auto isOrXorPattern = [&]() {
24193         if ((Op2.getOpcode() == ISD::XOR || Op2.getOpcode() == ISD::OR) &&
24194             (Op2.getOperand(0) == Op1 || Op2.getOperand(1) == Op1)) {
24195           Src1 =
24196               Op2.getOperand(0) == Op1 ? Op2.getOperand(1) : Op2.getOperand(0);
24197           Src2 = Op1;
24198           return true;
24199         }
24200         return false;
24201       };
24202 
24203       if (isOrXorPattern()) {
24204         SDValue Neg;
24205         unsigned int CmpSz = CmpOp0.getSimpleValueType().getSizeInBits();
24206         // We need a mask of all zeros or all ones with the same size as the
24207         // other operands.
24208         if (CmpSz > VT.getSizeInBits())
24209           Neg = DAG.getNode(ISD::TRUNCATE, DL, VT, CmpOp0);
24210         else if (CmpSz < VT.getSizeInBits())
24211           Neg = DAG.getNode(ISD::AND, DL, VT,
24212               DAG.getNode(ISD::ANY_EXTEND, DL, VT, CmpOp0.getOperand(0)),
24213               DAG.getConstant(1, DL, VT));
24214         else
24215           Neg = CmpOp0;
24216         SDValue Mask = DAG.getNegative(Neg, DL, VT); // -(and (x, 0x1))
24217         SDValue And = DAG.getNode(ISD::AND, DL, VT, Mask, Src1); // Mask & z
24218         return DAG.getNode(Op2.getOpcode(), DL, VT, And, Src2);  // And Op y
24219       }
24220     } else if ((VT == MVT::i32 || VT == MVT::i64) && isNullConstant(Op2) &&
24221                Cmp.getNode()->hasOneUse() && (CmpOp0 == Op1) &&
24222                ((CondCode == X86::COND_S) ||                    // smin(x, 0)
24223                 (CondCode == X86::COND_G && hasAndNot(Op1)))) { // smax(x, 0)
24224       // (select (x < 0), x, 0) -> ((x >> (size_in_bits(x)-1))) & x
24225       //
24226       // If the comparison is testing for a positive value, we have to invert
24227       // the sign bit mask, so only do that transform if the target has a
24228       // bitwise 'and not' instruction (the invert is free).
24229       // (select (x > 0), x, 0) -> (~(x >> (size_in_bits(x)-1))) & x
24230       unsigned ShCt = VT.getSizeInBits() - 1;
24231       SDValue ShiftAmt = DAG.getConstant(ShCt, DL, VT);
24232       SDValue Shift = DAG.getNode(ISD::SRA, DL, VT, Op1, ShiftAmt);
24233       if (CondCode == X86::COND_G)
24234         Shift = DAG.getNOT(DL, Shift, VT);
24235       return DAG.getNode(ISD::AND, DL, VT, Shift, Op1);
24236     }
24237   }
24238 
24239   // Look past (and (setcc_carry (cmp ...)), 1).
24240   if (Cond.getOpcode() == ISD::AND &&
24241       Cond.getOperand(0).getOpcode() == X86ISD::SETCC_CARRY &&
24242       isOneConstant(Cond.getOperand(1)))
24243     Cond = Cond.getOperand(0);
24244 
24245   // If condition flag is set by a X86ISD::CMP, then use it as the condition
24246   // setting operand in place of the X86ISD::SETCC.
24247   unsigned CondOpcode = Cond.getOpcode();
24248   if (CondOpcode == X86ISD::SETCC ||
24249       CondOpcode == X86ISD::SETCC_CARRY) {
24250     CC = Cond.getOperand(0);
24251 
24252     SDValue Cmp = Cond.getOperand(1);
24253     bool IllegalFPCMov = false;
24254     if (VT.isFloatingPoint() && !VT.isVector() &&
24255         !isScalarFPTypeInSSEReg(VT) && Subtarget.canUseCMOV())  // FPStack?
24256       IllegalFPCMov = !hasFPCMov(cast<ConstantSDNode>(CC)->getSExtValue());
24257 
24258     if ((isX86LogicalCmp(Cmp) && !IllegalFPCMov) ||
24259         Cmp.getOpcode() == X86ISD::BT) { // FIXME
24260       Cond = Cmp;
24261       AddTest = false;
24262     }
24263   } else if (CondOpcode == ISD::USUBO || CondOpcode == ISD::SSUBO ||
24264              CondOpcode == ISD::UADDO || CondOpcode == ISD::SADDO ||
24265              CondOpcode == ISD::UMULO || CondOpcode == ISD::SMULO) {
24266     SDValue Value;
24267     X86::CondCode X86Cond;
24268     std::tie(Value, Cond) = getX86XALUOOp(X86Cond, Cond.getValue(0), DAG);
24269 
24270     CC = DAG.getTargetConstant(X86Cond, DL, MVT::i8);
24271     AddTest = false;
24272   }
24273 
24274   if (AddTest) {
24275     // Look past the truncate if the high bits are known zero.
24276     if (isTruncWithZeroHighBitsInput(Cond, DAG))
24277       Cond = Cond.getOperand(0);
24278 
24279     // We know the result of AND is compared against zero. Try to match
24280     // it to BT.
24281     if (Cond.getOpcode() == ISD::AND && Cond.hasOneUse()) {
24282       X86::CondCode X86CondCode;
24283       if (SDValue BT = LowerAndToBT(Cond, ISD::SETNE, DL, DAG, X86CondCode)) {
24284         CC = DAG.getTargetConstant(X86CondCode, DL, MVT::i8);
24285         Cond = BT;
24286         AddTest = false;
24287       }
24288     }
24289   }
24290 
24291   if (AddTest) {
24292     CC = DAG.getTargetConstant(X86::COND_NE, DL, MVT::i8);
24293     Cond = EmitTest(Cond, X86::COND_NE, DL, DAG, Subtarget);
24294   }
24295 
24296   // a <  b ? -1 :  0 -> RES = ~setcc_carry
24297   // a <  b ?  0 : -1 -> RES = setcc_carry
24298   // a >= b ? -1 :  0 -> RES = setcc_carry
24299   // a >= b ?  0 : -1 -> RES = ~setcc_carry
24300   if (Cond.getOpcode() == X86ISD::SUB) {
24301     unsigned CondCode = CC->getAsZExtVal();
24302 
24303     if ((CondCode == X86::COND_AE || CondCode == X86::COND_B) &&
24304         (isAllOnesConstant(Op1) || isAllOnesConstant(Op2)) &&
24305         (isNullConstant(Op1) || isNullConstant(Op2))) {
24306       SDValue Res =
24307           DAG.getNode(X86ISD::SETCC_CARRY, DL, Op.getValueType(),
24308                       DAG.getTargetConstant(X86::COND_B, DL, MVT::i8), Cond);
24309       if (isAllOnesConstant(Op1) != (CondCode == X86::COND_B))
24310         return DAG.getNOT(DL, Res, Res.getValueType());
24311       return Res;
24312     }
24313   }
24314 
24315   // X86 doesn't have an i8 cmov. If both operands are the result of a truncate
24316   // widen the cmov and push the truncate through. This avoids introducing a new
24317   // branch during isel and doesn't add any extensions.
24318   if (Op.getValueType() == MVT::i8 &&
24319       Op1.getOpcode() == ISD::TRUNCATE && Op2.getOpcode() == ISD::TRUNCATE) {
24320     SDValue T1 = Op1.getOperand(0), T2 = Op2.getOperand(0);
24321     if (T1.getValueType() == T2.getValueType() &&
24322         // Exclude CopyFromReg to avoid partial register stalls.
24323         T1.getOpcode() != ISD::CopyFromReg && T2.getOpcode()!=ISD::CopyFromReg){
24324       SDValue Cmov = DAG.getNode(X86ISD::CMOV, DL, T1.getValueType(), T2, T1,
24325                                  CC, Cond);
24326       return DAG.getNode(ISD::TRUNCATE, DL, Op.getValueType(), Cmov);
24327     }
24328   }
24329 
24330   // Or finally, promote i8 cmovs if we have CMOV,
24331   //                 or i16 cmovs if it won't prevent folding a load.
24332   // FIXME: we should not limit promotion of i8 case to only when the CMOV is
24333   //        legal, but EmitLoweredSelect() cannot deal with these extensions
24334   //        being inserted between two CMOV's. (in i16 case too TBN)
24335   //        https://bugs.llvm.org/show_bug.cgi?id=40974
24336   if ((Op.getValueType() == MVT::i8 && Subtarget.canUseCMOV()) ||
24337       (Op.getValueType() == MVT::i16 && !X86::mayFoldLoad(Op1, Subtarget) &&
24338        !X86::mayFoldLoad(Op2, Subtarget))) {
24339     Op1 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Op1);
24340     Op2 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Op2);
24341     SDValue Ops[] = { Op2, Op1, CC, Cond };
24342     SDValue Cmov = DAG.getNode(X86ISD::CMOV, DL, MVT::i32, Ops);
24343     return DAG.getNode(ISD::TRUNCATE, DL, Op.getValueType(), Cmov);
24344   }
24345 
24346   // X86ISD::CMOV means set the result (which is operand 1) to the RHS if
24347   // condition is true.
24348   SDValue Ops[] = { Op2, Op1, CC, Cond };
24349   return DAG.getNode(X86ISD::CMOV, DL, Op.getValueType(), Ops, Op->getFlags());
24350 }
24351 
24352 static SDValue LowerSIGN_EXTEND_Mask(SDValue Op, const SDLoc &dl,
24353                                      const X86Subtarget &Subtarget,
24354                                      SelectionDAG &DAG) {
24355   MVT VT = Op->getSimpleValueType(0);
24356   SDValue In = Op->getOperand(0);
24357   MVT InVT = In.getSimpleValueType();
24358   assert(InVT.getVectorElementType() == MVT::i1 && "Unexpected input type!");
24359   MVT VTElt = VT.getVectorElementType();
24360   unsigned NumElts = VT.getVectorNumElements();
24361 
24362   // Extend VT if the scalar type is i8/i16 and BWI is not supported.
24363   MVT ExtVT = VT;
24364   if (!Subtarget.hasBWI() && VTElt.getSizeInBits() <= 16) {
24365     // If v16i32 is to be avoided, we'll need to split and concatenate.
24366     if (NumElts == 16 && !Subtarget.canExtendTo512DQ())
24367       return SplitAndExtendv16i1(Op.getOpcode(), VT, In, dl, DAG);
24368 
24369     ExtVT = MVT::getVectorVT(MVT::i32, NumElts);
24370   }
24371 
24372   // Widen to 512-bits if VLX is not supported.
24373   MVT WideVT = ExtVT;
24374   if (!ExtVT.is512BitVector() && !Subtarget.hasVLX()) {
24375     NumElts *= 512 / ExtVT.getSizeInBits();
24376     InVT = MVT::getVectorVT(MVT::i1, NumElts);
24377     In = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, InVT, DAG.getUNDEF(InVT),
24378                      In, DAG.getIntPtrConstant(0, dl));
24379     WideVT = MVT::getVectorVT(ExtVT.getVectorElementType(), NumElts);
24380   }
24381 
24382   SDValue V;
24383   MVT WideEltVT = WideVT.getVectorElementType();
24384   if ((Subtarget.hasDQI() && WideEltVT.getSizeInBits() >= 32) ||
24385       (Subtarget.hasBWI() && WideEltVT.getSizeInBits() <= 16)) {
24386     V = DAG.getNode(Op.getOpcode(), dl, WideVT, In);
24387   } else {
24388     SDValue NegOne = DAG.getConstant(-1, dl, WideVT);
24389     SDValue Zero = DAG.getConstant(0, dl, WideVT);
24390     V = DAG.getSelect(dl, WideVT, In, NegOne, Zero);
24391   }
24392 
24393   // Truncate if we had to extend i16/i8 above.
24394   if (VT != ExtVT) {
24395     WideVT = MVT::getVectorVT(VTElt, NumElts);
24396     V = DAG.getNode(ISD::TRUNCATE, dl, WideVT, V);
24397   }
24398 
24399   // Extract back to 128/256-bit if we widened.
24400   if (WideVT != VT)
24401     V = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT, V,
24402                     DAG.getIntPtrConstant(0, dl));
24403 
24404   return V;
24405 }
24406 
24407 static SDValue LowerANY_EXTEND(SDValue Op, const X86Subtarget &Subtarget,
24408                                SelectionDAG &DAG) {
24409   SDValue In = Op->getOperand(0);
24410   MVT InVT = In.getSimpleValueType();
24411   SDLoc DL(Op);
24412 
24413   if (InVT.getVectorElementType() == MVT::i1)
24414     return LowerSIGN_EXTEND_Mask(Op, DL, Subtarget, DAG);
24415 
24416   assert(Subtarget.hasAVX() && "Expected AVX support");
24417   return LowerAVXExtend(Op, DL, DAG, Subtarget);
24418 }
24419 
24420 // Lowering for SIGN_EXTEND_VECTOR_INREG and ZERO_EXTEND_VECTOR_INREG.
24421 // For sign extend this needs to handle all vector sizes and SSE4.1 and
24422 // non-SSE4.1 targets. For zero extend this should only handle inputs of
24423 // MVT::v64i8 when BWI is not supported, but AVX512 is.
24424 static SDValue LowerEXTEND_VECTOR_INREG(SDValue Op,
24425                                         const X86Subtarget &Subtarget,
24426                                         SelectionDAG &DAG) {
24427   SDValue In = Op->getOperand(0);
24428   MVT VT = Op->getSimpleValueType(0);
24429   MVT InVT = In.getSimpleValueType();
24430 
24431   MVT SVT = VT.getVectorElementType();
24432   MVT InSVT = InVT.getVectorElementType();
24433   assert(SVT.getFixedSizeInBits() > InSVT.getFixedSizeInBits());
24434 
24435   if (SVT != MVT::i64 && SVT != MVT::i32 && SVT != MVT::i16)
24436     return SDValue();
24437   if (InSVT != MVT::i32 && InSVT != MVT::i16 && InSVT != MVT::i8)
24438     return SDValue();
24439   if (!(VT.is128BitVector() && Subtarget.hasSSE2()) &&
24440       !(VT.is256BitVector() && Subtarget.hasAVX()) &&
24441       !(VT.is512BitVector() && Subtarget.hasAVX512()))
24442     return SDValue();
24443 
24444   SDLoc dl(Op);
24445   unsigned Opc = Op.getOpcode();
24446   unsigned NumElts = VT.getVectorNumElements();
24447 
24448   // For 256-bit vectors, we only need the lower (128-bit) half of the input.
24449   // For 512-bit vectors, we need 128-bits or 256-bits.
24450   if (InVT.getSizeInBits() > 128) {
24451     // Input needs to be at least the same number of elements as output, and
24452     // at least 128-bits.
24453     int InSize = InSVT.getSizeInBits() * NumElts;
24454     In = extractSubVector(In, 0, DAG, dl, std::max(InSize, 128));
24455     InVT = In.getSimpleValueType();
24456   }
24457 
24458   // SSE41 targets can use the pmov[sz]x* instructions directly for 128-bit
24459   // results, so those cases are legal and shouldn't occur here. AVX2/AVX512
24460   // pmovsx* instructions still need to be handled here for 256/512-bit results.
24461   if (Subtarget.hasInt256()) {
24462     assert(VT.getSizeInBits() > 128 && "Unexpected 128-bit vector extension");
24463 
24464     if (InVT.getVectorNumElements() != NumElts)
24465       return DAG.getNode(Op.getOpcode(), dl, VT, In);
24466 
24467     // FIXME: Apparently we create inreg operations that could be regular
24468     // extends.
24469     unsigned ExtOpc =
24470         Opc == ISD::SIGN_EXTEND_VECTOR_INREG ? ISD::SIGN_EXTEND
24471                                              : ISD::ZERO_EXTEND;
24472     return DAG.getNode(ExtOpc, dl, VT, In);
24473   }
24474 
24475   // pre-AVX2 256-bit extensions need to be split into 128-bit instructions.
24476   if (Subtarget.hasAVX()) {
24477     assert(VT.is256BitVector() && "256-bit vector expected");
24478     MVT HalfVT = VT.getHalfNumVectorElementsVT();
24479     int HalfNumElts = HalfVT.getVectorNumElements();
24480 
24481     unsigned NumSrcElts = InVT.getVectorNumElements();
24482     SmallVector<int, 16> HiMask(NumSrcElts, SM_SentinelUndef);
24483     for (int i = 0; i != HalfNumElts; ++i)
24484       HiMask[i] = HalfNumElts + i;
24485 
24486     SDValue Lo = DAG.getNode(Opc, dl, HalfVT, In);
24487     SDValue Hi = DAG.getVectorShuffle(InVT, dl, In, DAG.getUNDEF(InVT), HiMask);
24488     Hi = DAG.getNode(Opc, dl, HalfVT, Hi);
24489     return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, Lo, Hi);
24490   }
24491 
24492   // We should only get here for sign extend.
24493   assert(Opc == ISD::SIGN_EXTEND_VECTOR_INREG && "Unexpected opcode!");
24494   assert(VT.is128BitVector() && InVT.is128BitVector() && "Unexpected VTs");
24495   unsigned InNumElts = InVT.getVectorNumElements();
24496 
24497   // If the source elements are already all-signbits, we don't need to extend,
24498   // just splat the elements.
24499   APInt DemandedElts = APInt::getLowBitsSet(InNumElts, NumElts);
24500   if (DAG.ComputeNumSignBits(In, DemandedElts) == InVT.getScalarSizeInBits()) {
24501     unsigned Scale = InNumElts / NumElts;
24502     SmallVector<int, 16> ShuffleMask;
24503     for (unsigned I = 0; I != NumElts; ++I)
24504       ShuffleMask.append(Scale, I);
24505     return DAG.getBitcast(VT,
24506                           DAG.getVectorShuffle(InVT, dl, In, In, ShuffleMask));
24507   }
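        // For example, for v16i8 -> v4i32 with all-signbit inputs, Scale is 4 and
        // the mask is {0,0,0,0, 1,1,1,1, 2,2,2,2, 3,3,3,3}: splatting each byte
        // four times reproduces its sign-extension, as every byte is 0x00 or 0xFF.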
24508 
24509   // pre-SSE41 targets unpack lower lanes and then sign-extend using SRAI.
24510   SDValue Curr = In;
24511   SDValue SignExt = Curr;
24512 
24513   // As SRAI is only available on i16/i32 types, we expand only up to i32
24514   // and handle i64 separately.
24515   if (InVT != MVT::v4i32) {
24516     MVT DestVT = VT == MVT::v2i64 ? MVT::v4i32 : VT;
24517 
24518     unsigned DestWidth = DestVT.getScalarSizeInBits();
24519     unsigned Scale = DestWidth / InSVT.getSizeInBits();
24520     unsigned DestElts = DestVT.getVectorNumElements();
24521 
24522     // Build a shuffle mask that takes each input element and places it in the
24523     // MSBs of the new element size.
24524     SmallVector<int, 16> Mask(InNumElts, SM_SentinelUndef);
24525     for (unsigned i = 0; i != DestElts; ++i)
24526       Mask[i * Scale + (Scale - 1)] = i;
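          // For example, for v16i8 -> v4i32 the mask places byte i at position
          // 4*i+3 (the MSB of dword i): {-1,-1,-1,0, -1,-1,-1,1, -1,-1,-1,2,
          // -1,-1,-1,3}; the VSRAI below then shifts right by 24 to sign-fill the
          // low bytes.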
24527 
24528     Curr = DAG.getVectorShuffle(InVT, dl, In, In, Mask);
24529     Curr = DAG.getBitcast(DestVT, Curr);
24530 
24531     unsigned SignExtShift = DestWidth - InSVT.getSizeInBits();
24532     SignExt = DAG.getNode(X86ISD::VSRAI, dl, DestVT, Curr,
24533                           DAG.getTargetConstant(SignExtShift, dl, MVT::i8));
24534   }
24535 
24536   if (VT == MVT::v2i64) {
24537     assert(Curr.getValueType() == MVT::v4i32 && "Unexpected input VT");
24538     SDValue Zero = DAG.getConstant(0, dl, MVT::v4i32);
24539     SDValue Sign = DAG.getSetCC(dl, MVT::v4i32, Zero, Curr, ISD::SETGT);
24540     SignExt = DAG.getVectorShuffle(MVT::v4i32, dl, SignExt, Sign, {0, 4, 1, 5});
24541     SignExt = DAG.getBitcast(VT, SignExt);
24542   }
24543 
24544   return SignExt;
24545 }
24546 
24547 static SDValue LowerSIGN_EXTEND(SDValue Op, const X86Subtarget &Subtarget,
24548                                 SelectionDAG &DAG) {
24549   MVT VT = Op->getSimpleValueType(0);
24550   SDValue In = Op->getOperand(0);
24551   MVT InVT = In.getSimpleValueType();
24552   SDLoc dl(Op);
24553 
24554   if (InVT.getVectorElementType() == MVT::i1)
24555     return LowerSIGN_EXTEND_Mask(Op, dl, Subtarget, DAG);
24556 
24557   assert(VT.isVector() && InVT.isVector() && "Expected vector type");
24558   assert(VT.getVectorNumElements() == InVT.getVectorNumElements() &&
24559          "Expected same number of elements");
24560   assert((VT.getVectorElementType() == MVT::i16 ||
24561           VT.getVectorElementType() == MVT::i32 ||
24562           VT.getVectorElementType() == MVT::i64) &&
24563          "Unexpected element type");
24564   assert((InVT.getVectorElementType() == MVT::i8 ||
24565           InVT.getVectorElementType() == MVT::i16 ||
24566           InVT.getVectorElementType() == MVT::i32) &&
24567          "Unexpected element type");
24568 
24569   if (VT == MVT::v32i16 && !Subtarget.hasBWI()) {
24570     assert(InVT == MVT::v32i8 && "Unexpected VT!");
24571     return splitVectorIntUnary(Op, DAG, dl);
24572   }
24573 
24574   if (Subtarget.hasInt256())
24575     return Op;
24576 
24577   // Optimize vectors in AVX mode
24578   // Sign extend  v8i16 to v8i32 and
24579   //              v4i32 to v4i64
24580   //
24581   // Divide input vector into two parts
24582   // for v4i32 the high shuffle mask will be {2, 3, -1, -1}
24583   // use vpmovsx instruction to extend v4i32 -> v2i64; v8i16 -> v4i32
24584   // concat the vectors to original VT
24585   MVT HalfVT = VT.getHalfNumVectorElementsVT();
24586   SDValue OpLo = DAG.getNode(ISD::SIGN_EXTEND_VECTOR_INREG, dl, HalfVT, In);
24587 
24588   unsigned NumElems = InVT.getVectorNumElements();
24589   SmallVector<int,8> ShufMask(NumElems, -1);
24590   for (unsigned i = 0; i != NumElems/2; ++i)
24591     ShufMask[i] = i + NumElems/2;
24592 
24593   SDValue OpHi = DAG.getVectorShuffle(InVT, dl, In, In, ShufMask);
24594   OpHi = DAG.getNode(ISD::SIGN_EXTEND_VECTOR_INREG, dl, HalfVT, OpHi);
24595 
24596   return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, OpLo, OpHi);
24597 }
24598 
24599 /// Change a vector store into a pair of half-size vector stores.
24600 static SDValue splitVectorStore(StoreSDNode *Store, SelectionDAG &DAG) {
24601   SDValue StoredVal = Store->getValue();
24602   assert((StoredVal.getValueType().is256BitVector() ||
24603           StoredVal.getValueType().is512BitVector()) &&
24604          "Expecting 256/512-bit op");
24605 
24606   // Splitting volatile memory ops is not allowed unless the operation was not
24607   // legal to begin with. Assume the input store is legal (this transform is
24608   // only used for targets with AVX). Note: It is possible that we have an
24609   // illegal type like v2i128, and so we could allow splitting a volatile store
24610   // in that case if that is important.
24611   if (!Store->isSimple())
24612     return SDValue();
24613 
24614   SDLoc DL(Store);
24615   SDValue Value0, Value1;
24616   std::tie(Value0, Value1) = splitVector(StoredVal, DAG, DL);
24617   unsigned HalfOffset = Value0.getValueType().getStoreSize();
24618   SDValue Ptr0 = Store->getBasePtr();
24619   SDValue Ptr1 =
24620       DAG.getMemBasePlusOffset(Ptr0, TypeSize::getFixed(HalfOffset), DL);
24621   SDValue Ch0 =
24622       DAG.getStore(Store->getChain(), DL, Value0, Ptr0, Store->getPointerInfo(),
24623                    Store->getOriginalAlign(),
24624                    Store->getMemOperand()->getFlags());
24625   SDValue Ch1 = DAG.getStore(Store->getChain(), DL, Value1, Ptr1,
24626                              Store->getPointerInfo().getWithOffset(HalfOffset),
24627                              Store->getOriginalAlign(),
24628                              Store->getMemOperand()->getFlags());
24629   return DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Ch0, Ch1);
24630 }
24631 
24632 /// Scalarize a vector store, bitcasting to TargetVT to determine the scalar
24633 /// type.
24634 static SDValue scalarizeVectorStore(StoreSDNode *Store, MVT StoreVT,
24635                                     SelectionDAG &DAG) {
24636   SDValue StoredVal = Store->getValue();
24637   assert(StoreVT.is128BitVector() &&
24638          StoredVal.getValueType().is128BitVector() && "Expecting 128-bit op");
24639   StoredVal = DAG.getBitcast(StoreVT, StoredVal);
24640 
24641   // Splitting volatile memory ops is not allowed unless the operation was not
24642   // legal to begin with. We are assuming the input op is legal (this transform
24643   // is only used for targets with AVX).
24644   if (!Store->isSimple())
24645     return SDValue();
24646 
24647   MVT StoreSVT = StoreVT.getScalarType();
24648   unsigned NumElems = StoreVT.getVectorNumElements();
24649   unsigned ScalarSize = StoreSVT.getStoreSize();
24650 
24651   SDLoc DL(Store);
24652   SmallVector<SDValue, 4> Stores;
24653   for (unsigned i = 0; i != NumElems; ++i) {
24654     unsigned Offset = i * ScalarSize;
24655     SDValue Ptr = DAG.getMemBasePlusOffset(Store->getBasePtr(),
24656                                            TypeSize::getFixed(Offset), DL);
24657     SDValue Scl = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, StoreSVT, StoredVal,
24658                               DAG.getIntPtrConstant(i, DL));
24659     SDValue Ch = DAG.getStore(Store->getChain(), DL, Scl, Ptr,
24660                               Store->getPointerInfo().getWithOffset(Offset),
24661                               Store->getOriginalAlign(),
24662                               Store->getMemOperand()->getFlags());
24663     Stores.push_back(Ch);
24664   }
24665   return DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Stores);
24666 }
24667 
24668 static SDValue LowerStore(SDValue Op, const X86Subtarget &Subtarget,
24669                           SelectionDAG &DAG) {
24670   StoreSDNode *St = cast<StoreSDNode>(Op.getNode());
24671   SDLoc dl(St);
24672   SDValue StoredVal = St->getValue();
24673 
24674   // Without AVX512DQ, we need to use a scalar type for v2i1/v4i1/v8i1 stores.
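  // Illustrative example: a v4i1 store is widened to v16i1, bitcast to i16,
  // truncated to i8, and the unused upper four bits are cleared with a
  // zero-extend-in-reg from i4 before a single i8 store is emitted.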
24675   if (StoredVal.getValueType().isVector() &&
24676       StoredVal.getValueType().getVectorElementType() == MVT::i1) {
24677     unsigned NumElts = StoredVal.getValueType().getVectorNumElements();
24678     assert(NumElts <= 8 && "Unexpected VT");
24679     assert(!St->isTruncatingStore() && "Expected non-truncating store");
24680     assert(Subtarget.hasAVX512() && !Subtarget.hasDQI() &&
24681            "Expected AVX512F without AVX512DQI");
24682 
24683     // We must pad with zeros to ensure we store zeroes to any unused bits.
24684     StoredVal = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, MVT::v16i1,
24685                             DAG.getUNDEF(MVT::v16i1), StoredVal,
24686                             DAG.getIntPtrConstant(0, dl));
24687     StoredVal = DAG.getBitcast(MVT::i16, StoredVal);
24688     StoredVal = DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, StoredVal);
24689     // Make sure we store zeros in the extra bits.
24690     if (NumElts < 8)
24691       StoredVal = DAG.getZeroExtendInReg(
24692           StoredVal, dl, EVT::getIntegerVT(*DAG.getContext(), NumElts));
24693 
24694     return DAG.getStore(St->getChain(), dl, StoredVal, St->getBasePtr(),
24695                         St->getPointerInfo(), St->getOriginalAlign(),
24696                         St->getMemOperand()->getFlags());
24697   }
24698 
24699   if (St->isTruncatingStore())
24700     return SDValue();
24701 
24702   // If this is a 256-bit store of concatenated ops, we are better off splitting
24703   // that store into two 128-bit stores. This avoids spurious use of 256-bit ops
24704   // and each half can execute independently. Some cores would split the op into
24705   // halves anyway, so the concat (vinsertf128) is purely an extra op.
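  // Illustrative example: a v8f32 store whose value is a concat of two v4f32
  // halves (with no other uses) becomes two 128-bit stores, so the
  // vinsertf128 that would build the 256-bit value is never needed.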
24706   MVT StoreVT = StoredVal.getSimpleValueType();
24707   if (StoreVT.is256BitVector() ||
24708       ((StoreVT == MVT::v32i16 || StoreVT == MVT::v64i8) &&
24709        !Subtarget.hasBWI())) {
24710     if (StoredVal.hasOneUse() && isFreeToSplitVector(StoredVal.getNode(), DAG))
24711       return splitVectorStore(St, DAG);
24712     return SDValue();
24713   }
24714 
24715   if (StoreVT.is32BitVector())
24716     return SDValue();
24717 
24718   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
24719   assert(StoreVT.is64BitVector() && "Unexpected VT");
24720   assert(TLI.getTypeAction(*DAG.getContext(), StoreVT) ==
24721              TargetLowering::TypeWidenVector &&
24722          "Unexpected type action!");
24723 
24724   EVT WideVT = TLI.getTypeToTransformTo(*DAG.getContext(), StoreVT);
24725   StoredVal = DAG.getNode(ISD::CONCAT_VECTORS, dl, WideVT, StoredVal,
24726                           DAG.getUNDEF(StoreVT));
24727 
24728   if (Subtarget.hasSSE2()) {
24729     // Widen the vector, cast to a v2x64 type, extract the single 64-bit element
24730     // and store it.
24731     MVT StVT = Subtarget.is64Bit() && StoreVT.isInteger() ? MVT::i64 : MVT::f64;
24732     MVT CastVT = MVT::getVectorVT(StVT, 2);
24733     StoredVal = DAG.getBitcast(CastVT, StoredVal);
24734     StoredVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, StVT, StoredVal,
24735                             DAG.getIntPtrConstant(0, dl));
24736 
24737     return DAG.getStore(St->getChain(), dl, StoredVal, St->getBasePtr(),
24738                         St->getPointerInfo(), St->getOriginalAlign(),
24739                         St->getMemOperand()->getFlags());
24740   }
24741   assert(Subtarget.hasSSE1() && "Expected SSE");
24742   SDVTList Tys = DAG.getVTList(MVT::Other);
24743   SDValue Ops[] = {St->getChain(), StoredVal, St->getBasePtr()};
24744   return DAG.getMemIntrinsicNode(X86ISD::VEXTRACT_STORE, dl, Tys, Ops, MVT::i64,
24745                                  St->getMemOperand());
24746 }
24747 
24748 // Lower vector extended loads using a shuffle. If SSSE3 is not available we
24749 // may emit an illegal shuffle but the expansion is still better than scalar
24750 // code. We generate sext/sext_invec for SEXTLOADs if it's available, otherwise
24751 // we'll emit a shuffle and a arithmetic shift.
24752 // FIXME: Is the expansion actually better than scalar code? It doesn't seem so.
24753 // TODO: It is possible to support ZExt by zeroing the undef values during
24754 // the shuffle phase or after the shuffle.
24755 static SDValue LowerLoad(SDValue Op, const X86Subtarget &Subtarget,
24756                                  SelectionDAG &DAG) {
24757   MVT RegVT = Op.getSimpleValueType();
24758   assert(RegVT.isVector() && "We only custom lower vector loads.");
24759   assert(RegVT.isInteger() &&
24760          "We only custom lower integer vector loads.");
24761 
24762   LoadSDNode *Ld = cast<LoadSDNode>(Op.getNode());
24763   SDLoc dl(Ld);
24764 
24765   // Without AVX512DQ, we need to use a scalar type for v2i1/v4i1/v8i1 loads.
24766   if (RegVT.getVectorElementType() == MVT::i1) {
24767     assert(EVT(RegVT) == Ld->getMemoryVT() && "Expected non-extending load");
24768     assert(RegVT.getVectorNumElements() <= 8 && "Unexpected VT");
24769     assert(Subtarget.hasAVX512() && !Subtarget.hasDQI() &&
24770            "Expected AVX512F without AVX512DQI");
24771 
24772     SDValue NewLd = DAG.getLoad(MVT::i8, dl, Ld->getChain(), Ld->getBasePtr(),
24773                                 Ld->getPointerInfo(), Ld->getOriginalAlign(),
24774                                 Ld->getMemOperand()->getFlags());
24775 
24776     // Replace chain users with the new chain.
24777     assert(NewLd->getNumValues() == 2 && "Loads must carry a chain!");
24778 
24779     SDValue Val = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, NewLd);
24780     Val = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, RegVT,
24781                       DAG.getBitcast(MVT::v16i1, Val),
24782                       DAG.getIntPtrConstant(0, dl));
24783     return DAG.getMergeValues({Val, NewLd.getValue(1)}, dl);
24784   }
24785 
24786   return SDValue();
24787 }
24788 
24789 /// Return true if node is an ISD::AND or ISD::OR of two X86ISD::SETCC nodes
24790 /// each of which has no other use apart from the AND / OR.
24791 static bool isAndOrOfSetCCs(SDValue Op, unsigned &Opc) {
24792   Opc = Op.getOpcode();
24793   if (Opc != ISD::OR && Opc != ISD::AND)
24794     return false;
24795   return (Op.getOperand(0).getOpcode() == X86ISD::SETCC &&
24796           Op.getOperand(0).hasOneUse() &&
24797           Op.getOperand(1).getOpcode() == X86ISD::SETCC &&
24798           Op.getOperand(1).hasOneUse());
24799 }
24800 
24801 SDValue X86TargetLowering::LowerBRCOND(SDValue Op, SelectionDAG &DAG) const {
24802   SDValue Chain = Op.getOperand(0);
24803   SDValue Cond  = Op.getOperand(1);
24804   SDValue Dest  = Op.getOperand(2);
24805   SDLoc dl(Op);
24806 
24807   // Bail out when we don't have native compare instructions.
24808   if (Cond.getOpcode() == ISD::SETCC &&
24809       Cond.getOperand(0).getValueType() != MVT::f128 &&
24810       !isSoftF16(Cond.getOperand(0).getValueType(), Subtarget)) {
24811     SDValue LHS = Cond.getOperand(0);
24812     SDValue RHS = Cond.getOperand(1);
24813     ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
24814 
24815     // Special case for
24816     // setcc([su]{add,sub,mul}o == 0)
24817     // setcc([su]{add,sub,mul}o != 1)
24818     if (ISD::isOverflowIntrOpRes(LHS) &&
24819         (CC == ISD::SETEQ || CC == ISD::SETNE) &&
24820         (isNullConstant(RHS) || isOneConstant(RHS))) {
24821       SDValue Value, Overflow;
24822       X86::CondCode X86Cond;
24823       std::tie(Value, Overflow) = getX86XALUOOp(X86Cond, LHS.getValue(0), DAG);
24824 
24825       if ((CC == ISD::SETEQ) == isNullConstant(RHS))
24826         X86Cond = X86::GetOppositeBranchCondition(X86Cond);
24827 
24828       SDValue CCVal = DAG.getTargetConstant(X86Cond, dl, MVT::i8);
24829       return DAG.getNode(X86ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal,
24830                          Overflow);
24831     }
24832 
24833     if (LHS.getSimpleValueType().isInteger()) {
24834       SDValue CCVal;
24835       SDValue EFLAGS = emitFlagsForSetcc(LHS, RHS, CC, SDLoc(Cond), DAG, CCVal);
24836       return DAG.getNode(X86ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal,
24837                          EFLAGS);
24838     }
24839 
24840     if (CC == ISD::SETOEQ) {
24841       // For FCMP_OEQ, we can emit
24842       // two branches instead of an explicit AND instruction with a
24843       // separate test. However, we only do this if this block doesn't
24844       // have a fall-through edge, because this requires an explicit
24845       // jmp when the condition is false.
24846       if (Op.getNode()->hasOneUse()) {
24847         SDNode *User = *Op.getNode()->use_begin();
24848         // Look for an unconditional branch following this conditional branch.
24849         // We need this because we need to reverse the successors in order
24850         // to implement FCMP_OEQ.
24851         if (User->getOpcode() == ISD::BR) {
24852           SDValue FalseBB = User->getOperand(1);
24853           SDNode *NewBR =
24854             DAG.UpdateNodeOperands(User, User->getOperand(0), Dest);
24855           assert(NewBR == User);
24856           (void)NewBR;
24857           Dest = FalseBB;
24858 
24859           SDValue Cmp =
24860               DAG.getNode(X86ISD::FCMP, SDLoc(Cond), MVT::i32, LHS, RHS);
24861           SDValue CCVal = DAG.getTargetConstant(X86::COND_NE, dl, MVT::i8);
24862           Chain = DAG.getNode(X86ISD::BRCOND, dl, MVT::Other, Chain, Dest,
24863                               CCVal, Cmp);
24864           CCVal = DAG.getTargetConstant(X86::COND_P, dl, MVT::i8);
24865           return DAG.getNode(X86ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal,
24866                              Cmp);
24867         }
24868       }
24869     } else if (CC == ISD::SETUNE) {
24870       // For FCMP_UNE, we can emit
24871       // two branches instead of an explicit OR instruction with a
24872       // separate test.
24873       SDValue Cmp = DAG.getNode(X86ISD::FCMP, SDLoc(Cond), MVT::i32, LHS, RHS);
24874       SDValue CCVal = DAG.getTargetConstant(X86::COND_NE, dl, MVT::i8);
24875       Chain =
24876           DAG.getNode(X86ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal, Cmp);
24877       CCVal = DAG.getTargetConstant(X86::COND_P, dl, MVT::i8);
24878       return DAG.getNode(X86ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal,
24879                          Cmp);
24880     } else {
24881       X86::CondCode X86Cond =
24882           TranslateX86CC(CC, dl, /*IsFP*/ true, LHS, RHS, DAG);
24883       SDValue Cmp = DAG.getNode(X86ISD::FCMP, SDLoc(Cond), MVT::i32, LHS, RHS);
24884       SDValue CCVal = DAG.getTargetConstant(X86Cond, dl, MVT::i8);
24885       return DAG.getNode(X86ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal,
24886                          Cmp);
24887     }
24888   }
24889 
24890   if (ISD::isOverflowIntrOpRes(Cond)) {
24891     SDValue Value, Overflow;
24892     X86::CondCode X86Cond;
24893     std::tie(Value, Overflow) = getX86XALUOOp(X86Cond, Cond.getValue(0), DAG);
24894 
24895     SDValue CCVal = DAG.getTargetConstant(X86Cond, dl, MVT::i8);
24896     return DAG.getNode(X86ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal,
24897                        Overflow);
24898   }
24899 
24900   // Look past the truncate if the high bits are known zero.
24901   if (isTruncWithZeroHighBitsInput(Cond, DAG))
24902     Cond = Cond.getOperand(0);
24903 
24904   EVT CondVT = Cond.getValueType();
24905 
24906   // Add an AND with 1 if we don't already have one.
24907   if (!(Cond.getOpcode() == ISD::AND && isOneConstant(Cond.getOperand(1))))
24908     Cond =
24909         DAG.getNode(ISD::AND, dl, CondVT, Cond, DAG.getConstant(1, dl, CondVT));
24910 
24911   SDValue LHS = Cond;
24912   SDValue RHS = DAG.getConstant(0, dl, CondVT);
24913 
24914   SDValue CCVal;
24915   SDValue EFLAGS = emitFlagsForSetcc(LHS, RHS, ISD::SETNE, dl, DAG, CCVal);
24916   return DAG.getNode(X86ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal,
24917                      EFLAGS);
24918 }
24919 
24920 // Lower dynamic stack allocation to _alloca call for Cygwin/Mingw targets.
24921 // Calls to _alloca are needed to probe the stack when allocating more than 4k
24922 // bytes in one go. Touching the stack at 4K increments is necessary to ensure
24923 // that the guard pages used by the OS virtual memory manager are allocated in
24924 // correct sequence.
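// Roughly, the code below picks one of four strategies: a plain stack-pointer
// subtraction (with optional re-alignment) when no probing is required, an
// inline probe loop (X86ISD::PROBED_ALLOCA), a segmented-stack allocation
// (X86ISD::SEG_ALLOCA), or a probe-call based allocation (X86ISD::DYN_ALLOCA),
// all bracketed by CALLSEQ_START/CALLSEQ_END.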
24925 SDValue
24926 X86TargetLowering::LowerDYNAMIC_STACKALLOC(SDValue Op,
24927                                            SelectionDAG &DAG) const {
24928   MachineFunction &MF = DAG.getMachineFunction();
24929   bool SplitStack = MF.shouldSplitStack();
24930   bool EmitStackProbeCall = hasStackProbeSymbol(MF);
24931   bool Lower = (Subtarget.isOSWindows() && !Subtarget.isTargetMachO()) ||
24932                SplitStack || EmitStackProbeCall;
24933   SDLoc dl(Op);
24934 
24935   // Get the inputs.
24936   SDNode *Node = Op.getNode();
24937   SDValue Chain = Op.getOperand(0);
24938   SDValue Size  = Op.getOperand(1);
24939   MaybeAlign Alignment(Op.getConstantOperandVal(2));
24940   EVT VT = Node->getValueType(0);
24941 
24942   // Chain the dynamic stack allocation so that it doesn't modify the stack
24943   // pointer when other instructions are using the stack.
24944   Chain = DAG.getCALLSEQ_START(Chain, 0, 0, dl);
24945 
24946   bool Is64Bit = Subtarget.is64Bit();
24947   MVT SPTy = getPointerTy(DAG.getDataLayout());
24948 
24949   SDValue Result;
24950   if (!Lower) {
24951     const TargetLowering &TLI = DAG.getTargetLoweringInfo();
24952     Register SPReg = TLI.getStackPointerRegisterToSaveRestore();
24953     assert(SPReg && "Target cannot require DYNAMIC_STACKALLOC expansion and"
24954                     " not tell us which reg is the stack pointer!");
24955 
24956     const TargetFrameLowering &TFI = *Subtarget.getFrameLowering();
24957     const Align StackAlign = TFI.getStackAlign();
24958     if (hasInlineStackProbe(MF)) {
24959       MachineRegisterInfo &MRI = MF.getRegInfo();
24960 
24961       const TargetRegisterClass *AddrRegClass = getRegClassFor(SPTy);
24962       Register Vreg = MRI.createVirtualRegister(AddrRegClass);
24963       Chain = DAG.getCopyToReg(Chain, dl, Vreg, Size);
24964       Result = DAG.getNode(X86ISD::PROBED_ALLOCA, dl, SPTy, Chain,
24965                            DAG.getRegister(Vreg, SPTy));
24966     } else {
24967       SDValue SP = DAG.getCopyFromReg(Chain, dl, SPReg, VT);
24968       Chain = SP.getValue(1);
24969       Result = DAG.getNode(ISD::SUB, dl, VT, SP, Size); // Value
24970     }
24971     if (Alignment && *Alignment > StackAlign)
24972       Result =
24973           DAG.getNode(ISD::AND, dl, VT, Result,
24974                       DAG.getConstant(~(Alignment->value() - 1ULL), dl, VT));
24975     Chain = DAG.getCopyToReg(Chain, dl, SPReg, Result); // Output chain
24976   } else if (SplitStack) {
24977     MachineRegisterInfo &MRI = MF.getRegInfo();
24978 
24979     if (Is64Bit) {
24980       // The 64-bit implementation of segmented stacks needs to clobber both r10
24981       // and r11. This makes it impossible to use it along with nested parameters.
24982       const Function &F = MF.getFunction();
24983       for (const auto &A : F.args()) {
24984         if (A.hasNestAttr())
24985           report_fatal_error("Cannot use segmented stacks with functions that "
24986                              "have nested arguments.");
24987       }
24988     }
24989 
24990     const TargetRegisterClass *AddrRegClass = getRegClassFor(SPTy);
24991     Register Vreg = MRI.createVirtualRegister(AddrRegClass);
24992     Chain = DAG.getCopyToReg(Chain, dl, Vreg, Size);
24993     Result = DAG.getNode(X86ISD::SEG_ALLOCA, dl, SPTy, Chain,
24994                                 DAG.getRegister(Vreg, SPTy));
24995   } else {
24996     SDVTList NodeTys = DAG.getVTList(MVT::Other, MVT::Glue);
24997     Chain = DAG.getNode(X86ISD::DYN_ALLOCA, dl, NodeTys, Chain, Size);
24998     MF.getInfo<X86MachineFunctionInfo>()->setHasDynAlloca(true);
24999 
25000     const X86RegisterInfo *RegInfo = Subtarget.getRegisterInfo();
25001     Register SPReg = RegInfo->getStackRegister();
25002     SDValue SP = DAG.getCopyFromReg(Chain, dl, SPReg, SPTy);
25003     Chain = SP.getValue(1);
25004 
25005     if (Alignment) {
25006       SP = DAG.getNode(ISD::AND, dl, VT, SP.getValue(0),
25007                        DAG.getConstant(~(Alignment->value() - 1ULL), dl, VT));
25008       Chain = DAG.getCopyToReg(Chain, dl, SPReg, SP);
25009     }
25010 
25011     Result = SP;
25012   }
25013 
25014   Chain = DAG.getCALLSEQ_END(Chain, 0, 0, SDValue(), dl);
25015 
25016   SDValue Ops[2] = {Result, Chain};
25017   return DAG.getMergeValues(Ops, dl);
25018 }
25019 
25020 SDValue X86TargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const {
25021   MachineFunction &MF = DAG.getMachineFunction();
25022   auto PtrVT = getPointerTy(MF.getDataLayout());
25023   X86MachineFunctionInfo *FuncInfo = MF.getInfo<X86MachineFunctionInfo>();
25024 
25025   const Value *SV = cast<SrcValueSDNode>(Op.getOperand(2))->getValue();
25026   SDLoc DL(Op);
25027 
25028   if (!Subtarget.is64Bit() ||
25029       Subtarget.isCallingConvWin64(MF.getFunction().getCallingConv())) {
25030     // vastart just stores the address of the VarArgsFrameIndex slot into the
25031     // memory location argument.
25032     SDValue FR = DAG.getFrameIndex(FuncInfo->getVarArgsFrameIndex(), PtrVT);
25033     return DAG.getStore(Op.getOperand(0), DL, FR, Op.getOperand(1),
25034                         MachinePointerInfo(SV));
25035   }
25036 
25037   // __va_list_tag:
25038   //   gp_offset         (0 - 6 * 8)
25039   //   fp_offset         (48 - 48 + 8 * 16)
25040   //   overflow_arg_area (points to parameters coming in memory).
25041   //   reg_save_area
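  // Illustrative layout of the stores emitted below: gp_offset (i32) at byte
  // 0, fp_offset (i32) at byte 4, overflow_arg_area (pointer) at byte 8, and
  // reg_save_area (pointer) at byte 16 on LP64 (byte 12 on X32).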
25042   SmallVector<SDValue, 8> MemOps;
25043   SDValue FIN = Op.getOperand(1);
25044   // Store gp_offset
25045   SDValue Store = DAG.getStore(
25046       Op.getOperand(0), DL,
25047       DAG.getConstant(FuncInfo->getVarArgsGPOffset(), DL, MVT::i32), FIN,
25048       MachinePointerInfo(SV));
25049   MemOps.push_back(Store);
25050 
25051   // Store fp_offset
25052   FIN = DAG.getMemBasePlusOffset(FIN, TypeSize::getFixed(4), DL);
25053   Store = DAG.getStore(
25054       Op.getOperand(0), DL,
25055       DAG.getConstant(FuncInfo->getVarArgsFPOffset(), DL, MVT::i32), FIN,
25056       MachinePointerInfo(SV, 4));
25057   MemOps.push_back(Store);
25058 
25059   // Store ptr to overflow_arg_area
25060   FIN = DAG.getNode(ISD::ADD, DL, PtrVT, FIN, DAG.getIntPtrConstant(4, DL));
25061   SDValue OVFIN = DAG.getFrameIndex(FuncInfo->getVarArgsFrameIndex(), PtrVT);
25062   Store =
25063       DAG.getStore(Op.getOperand(0), DL, OVFIN, FIN, MachinePointerInfo(SV, 8));
25064   MemOps.push_back(Store);
25065 
25066   // Store ptr to reg_save_area.
25067   FIN = DAG.getNode(ISD::ADD, DL, PtrVT, FIN, DAG.getIntPtrConstant(
25068       Subtarget.isTarget64BitLP64() ? 8 : 4, DL));
25069   SDValue RSFIN = DAG.getFrameIndex(FuncInfo->getRegSaveFrameIndex(), PtrVT);
25070   Store = DAG.getStore(
25071       Op.getOperand(0), DL, RSFIN, FIN,
25072       MachinePointerInfo(SV, Subtarget.isTarget64BitLP64() ? 16 : 12));
25073   MemOps.push_back(Store);
25074   return DAG.getNode(ISD::TokenFactor, DL, MVT::Other, MemOps);
25075 }
25076 
25077 SDValue X86TargetLowering::LowerVAARG(SDValue Op, SelectionDAG &DAG) const {
25078   assert(Subtarget.is64Bit() &&
25079          "LowerVAARG only handles 64-bit va_arg!");
25080   assert(Op.getNumOperands() == 4);
25081 
25082   MachineFunction &MF = DAG.getMachineFunction();
25083   if (Subtarget.isCallingConvWin64(MF.getFunction().getCallingConv()))
25084     // The Win64 ABI uses char* instead of a structure.
25085     return DAG.expandVAArg(Op.getNode());
25086 
25087   SDValue Chain = Op.getOperand(0);
25088   SDValue SrcPtr = Op.getOperand(1);
25089   const Value *SV = cast<SrcValueSDNode>(Op.getOperand(2))->getValue();
25090   unsigned Align = Op.getConstantOperandVal(3);
25091   SDLoc dl(Op);
25092 
25093   EVT ArgVT = Op.getNode()->getValueType(0);
25094   Type *ArgTy = ArgVT.getTypeForEVT(*DAG.getContext());
25095   uint32_t ArgSize = DAG.getDataLayout().getTypeAllocSize(ArgTy);
25096   uint8_t ArgMode;
25097 
25098   // Decide which area this value should be read from.
25099   // TODO: Implement the AMD64 ABI in its entirety. This simple
25100   // selection mechanism works only for the basic types.
25101   assert(ArgVT != MVT::f80 && "va_arg for f80 not yet implemented");
25102   if (ArgVT.isFloatingPoint() && ArgSize <= 16 /*bytes*/) {
25103     ArgMode = 2;  // Argument passed in XMM register. Use fp_offset.
25104   } else {
25105     assert(ArgVT.isInteger() && ArgSize <= 32 /*bytes*/ &&
25106            "Unhandled argument type in LowerVAARG");
25107     ArgMode = 1;  // Argument passed in GPR64 register(s). Use gp_offset.
25108   }
25109 
25110   if (ArgMode == 2) {
25111     // Make sure using fp_offset makes sense.
25112     assert(!Subtarget.useSoftFloat() &&
25113            !(MF.getFunction().hasFnAttribute(Attribute::NoImplicitFloat)) &&
25114            Subtarget.hasSSE1());
25115   }
25116 
25117   // Insert VAARG node into the DAG
25118   // VAARG returns two values: Variable Argument Address, Chain
25119   SDValue InstOps[] = {Chain, SrcPtr,
25120                        DAG.getTargetConstant(ArgSize, dl, MVT::i32),
25121                        DAG.getTargetConstant(ArgMode, dl, MVT::i8),
25122                        DAG.getTargetConstant(Align, dl, MVT::i32)};
25123   SDVTList VTs = DAG.getVTList(getPointerTy(DAG.getDataLayout()), MVT::Other);
25124   SDValue VAARG = DAG.getMemIntrinsicNode(
25125       Subtarget.isTarget64BitLP64() ? X86ISD::VAARG_64 : X86ISD::VAARG_X32, dl,
25126       VTs, InstOps, MVT::i64, MachinePointerInfo(SV),
25127       /*Alignment=*/std::nullopt,
25128       MachineMemOperand::MOLoad | MachineMemOperand::MOStore);
25129   Chain = VAARG.getValue(1);
25130 
25131   // Load the next argument and return it
25132   return DAG.getLoad(ArgVT, dl, Chain, VAARG, MachinePointerInfo());
25133 }
25134 
25135 static SDValue LowerVACOPY(SDValue Op, const X86Subtarget &Subtarget,
25136                            SelectionDAG &DAG) {
25137   // X86-64 va_list is a struct { i32, i32, i8*, i8* }, except on Windows,
25138   // where a va_list is still an i8*.
25139   assert(Subtarget.is64Bit() && "This code only handles 64-bit va_copy!");
25140   if (Subtarget.isCallingConvWin64(
25141         DAG.getMachineFunction().getFunction().getCallingConv()))
25142     // Probably a Win64 va_copy.
25143     return DAG.expandVACopy(Op.getNode());
25144 
25145   SDValue Chain = Op.getOperand(0);
25146   SDValue DstPtr = Op.getOperand(1);
25147   SDValue SrcPtr = Op.getOperand(2);
25148   const Value *DstSV = cast<SrcValueSDNode>(Op.getOperand(3))->getValue();
25149   const Value *SrcSV = cast<SrcValueSDNode>(Op.getOperand(4))->getValue();
25150   SDLoc DL(Op);
25151 
25152   return DAG.getMemcpy(
25153       Chain, DL, DstPtr, SrcPtr,
25154       DAG.getIntPtrConstant(Subtarget.isTarget64BitLP64() ? 24 : 16, DL),
25155       Align(Subtarget.isTarget64BitLP64() ? 8 : 4), /*isVolatile*/ false, false,
25156       /*CI=*/nullptr, std::nullopt, MachinePointerInfo(DstSV),
25157       MachinePointerInfo(SrcSV));
25158 }
25159 
25160 // Helper to get immediate/variable SSE shift opcode from other shift opcodes.
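// For example, ISD::SHL maps to X86ISD::VSHLI for an immediate amount and to
// X86ISD::VSHL for a variable (splat) amount; SRL and SRA map to VSRLI/VSRL
// and VSRAI/VSRA respectively.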
25161 static unsigned getTargetVShiftUniformOpcode(unsigned Opc, bool IsVariable) {
25162   switch (Opc) {
25163   case ISD::SHL:
25164   case X86ISD::VSHL:
25165   case X86ISD::VSHLI:
25166     return IsVariable ? X86ISD::VSHL : X86ISD::VSHLI;
25167   case ISD::SRL:
25168   case X86ISD::VSRL:
25169   case X86ISD::VSRLI:
25170     return IsVariable ? X86ISD::VSRL : X86ISD::VSRLI;
25171   case ISD::SRA:
25172   case X86ISD::VSRA:
25173   case X86ISD::VSRAI:
25174     return IsVariable ? X86ISD::VSRA : X86ISD::VSRAI;
25175   }
25176   llvm_unreachable("Unknown target vector shift node");
25177 }
25178 
25179 /// Handle vector element shifts where the shift amount is a constant.
25180 /// Takes immediate version of shift as input.
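/// Illustrative behaviour: a shift by 0 returns SrcOp unchanged; a logical
/// shift by an amount >= the element width folds to zero, while an arithmetic
/// right shift is clamped to (element width - 1); constant source vectors are
/// folded directly via FoldConstantArithmetic.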
25181 static SDValue getTargetVShiftByConstNode(unsigned Opc, const SDLoc &dl, MVT VT,
25182                                           SDValue SrcOp, uint64_t ShiftAmt,
25183                                           SelectionDAG &DAG) {
25184   MVT ElementType = VT.getVectorElementType();
25185 
25186   // Bitcast the source vector to the output type, this is mainly necessary for
25187   // vXi8/vXi64 shifts.
25188   if (VT != SrcOp.getSimpleValueType())
25189     SrcOp = DAG.getBitcast(VT, SrcOp);
25190 
25191   // Fold this packed shift into its first operand if ShiftAmt is 0.
25192   if (ShiftAmt == 0)
25193     return SrcOp;
25194 
25195   // Check for ShiftAmt >= element width
25196   if (ShiftAmt >= ElementType.getSizeInBits()) {
25197     if (Opc == X86ISD::VSRAI)
25198       ShiftAmt = ElementType.getSizeInBits() - 1;
25199     else
25200       return DAG.getConstant(0, dl, VT);
25201   }
25202 
25203   assert((Opc == X86ISD::VSHLI || Opc == X86ISD::VSRLI || Opc == X86ISD::VSRAI)
25204          && "Unknown target vector shift-by-constant node");
25205 
25206   // Fold this packed vector shift into a build vector if SrcOp is a
25207   // vector of Constants or UNDEFs.
25208   if (ISD::isBuildVectorOfConstantSDNodes(SrcOp.getNode())) {
25209     unsigned ShiftOpc;
25210     switch (Opc) {
25211     default: llvm_unreachable("Unknown opcode!");
25212     case X86ISD::VSHLI:
25213       ShiftOpc = ISD::SHL;
25214       break;
25215     case X86ISD::VSRLI:
25216       ShiftOpc = ISD::SRL;
25217       break;
25218     case X86ISD::VSRAI:
25219       ShiftOpc = ISD::SRA;
25220       break;
25221     }
25222 
25223     SDValue Amt = DAG.getConstant(ShiftAmt, dl, VT);
25224     if (SDValue C = DAG.FoldConstantArithmetic(ShiftOpc, dl, VT, {SrcOp, Amt}))
25225       return C;
25226   }
25227 
25228   return DAG.getNode(Opc, dl, VT, SrcOp,
25229                      DAG.getTargetConstant(ShiftAmt, dl, MVT::i8));
25230 }
25231 
25232 /// Handle vector element shifts by a splat shift amount
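/// Roughly: the splat lane is shuffled down to element 0, the amount vector is
/// reduced to 128 bits if it is wider, the bottom element is zero-extended to
/// a 64-bit lane (via zext, VZEXT_MOVL or byte shifts), and the non-immediate
/// shift opcode is emitted with that uniform amount.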
25233 static SDValue getTargetVShiftNode(unsigned Opc, const SDLoc &dl, MVT VT,
25234                                    SDValue SrcOp, SDValue ShAmt, int ShAmtIdx,
25235                                    const X86Subtarget &Subtarget,
25236                                    SelectionDAG &DAG) {
25237   MVT AmtVT = ShAmt.getSimpleValueType();
25238   assert(AmtVT.isVector() && "Vector shift type mismatch");
25239   assert(0 <= ShAmtIdx && ShAmtIdx < (int)AmtVT.getVectorNumElements() &&
25240          "Illegal vector splat index");
25241 
25242   // Move the splat element to the bottom element.
25243   if (ShAmtIdx != 0) {
25244     SmallVector<int> Mask(AmtVT.getVectorNumElements(), -1);
25245     Mask[0] = ShAmtIdx;
25246     ShAmt = DAG.getVectorShuffle(AmtVT, dl, ShAmt, DAG.getUNDEF(AmtVT), Mask);
25247   }
25248 
25249   // Peek through any zext node if we can get back to a 128-bit source.
25250   if (AmtVT.getScalarSizeInBits() == 64 &&
25251       (ShAmt.getOpcode() == ISD::ZERO_EXTEND ||
25252        ShAmt.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG) &&
25253       ShAmt.getOperand(0).getValueType().isSimple() &&
25254       ShAmt.getOperand(0).getValueType().is128BitVector()) {
25255     ShAmt = ShAmt.getOperand(0);
25256     AmtVT = ShAmt.getSimpleValueType();
25257   }
25258 
25259   // See if we can mask off the upper elements using the existing source node.
25260   // The shift uses the entire lower 64-bits of the amount vector, so no need to
25261   // do this for vXi64 types.
25262   bool IsMasked = false;
25263   if (AmtVT.getScalarSizeInBits() < 64) {
25264     if (ShAmt.getOpcode() == ISD::BUILD_VECTOR ||
25265         ShAmt.getOpcode() == ISD::SCALAR_TO_VECTOR) {
25266       // If the shift amount has come from a scalar, then zero-extend the scalar
25267       // before moving to the vector.
25268       ShAmt = DAG.getZExtOrTrunc(ShAmt.getOperand(0), dl, MVT::i32);
25269       ShAmt = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, MVT::v4i32, ShAmt);
25270       ShAmt = DAG.getNode(X86ISD::VZEXT_MOVL, dl, MVT::v4i32, ShAmt);
25271       AmtVT = MVT::v4i32;
25272       IsMasked = true;
25273     } else if (ShAmt.getOpcode() == ISD::AND) {
25274       // See if the shift amount is already masked (e.g. for rotation modulo),
25275       // then we can zero-extend it by setting all the other mask elements to
25276       // zero.
25277       SmallVector<SDValue> MaskElts(
25278           AmtVT.getVectorNumElements(),
25279           DAG.getConstant(0, dl, AmtVT.getScalarType()));
25280       MaskElts[0] = DAG.getAllOnesConstant(dl, AmtVT.getScalarType());
25281       SDValue Mask = DAG.getBuildVector(AmtVT, dl, MaskElts);
25282       if ((Mask = DAG.FoldConstantArithmetic(ISD::AND, dl, AmtVT,
25283                                              {ShAmt.getOperand(1), Mask}))) {
25284         ShAmt = DAG.getNode(ISD::AND, dl, AmtVT, ShAmt.getOperand(0), Mask);
25285         IsMasked = true;
25286       }
25287     }
25288   }
25289 
25290   // Extract if the shift amount vector is larger than 128-bits.
25291   if (AmtVT.getSizeInBits() > 128) {
25292     ShAmt = extract128BitVector(ShAmt, 0, DAG, dl);
25293     AmtVT = ShAmt.getSimpleValueType();
25294   }
25295 
25296   // Zero-extend bottom element to v2i64 vector type, either by extension or
25297   // shuffle masking.
25298   if (!IsMasked && AmtVT.getScalarSizeInBits() < 64) {
25299     if (AmtVT == MVT::v4i32 && (ShAmt.getOpcode() == X86ISD::VBROADCAST ||
25300                                 ShAmt.getOpcode() == X86ISD::VBROADCAST_LOAD)) {
25301       ShAmt = DAG.getNode(X86ISD::VZEXT_MOVL, SDLoc(ShAmt), MVT::v4i32, ShAmt);
25302     } else if (Subtarget.hasSSE41()) {
25303       ShAmt = DAG.getNode(ISD::ZERO_EXTEND_VECTOR_INREG, SDLoc(ShAmt),
25304                           MVT::v2i64, ShAmt);
25305     } else {
25306       SDValue ByteShift = DAG.getTargetConstant(
25307           (128 - AmtVT.getScalarSizeInBits()) / 8, SDLoc(ShAmt), MVT::i8);
25308       ShAmt = DAG.getBitcast(MVT::v16i8, ShAmt);
25309       ShAmt = DAG.getNode(X86ISD::VSHLDQ, SDLoc(ShAmt), MVT::v16i8, ShAmt,
25310                           ByteShift);
25311       ShAmt = DAG.getNode(X86ISD::VSRLDQ, SDLoc(ShAmt), MVT::v16i8, ShAmt,
25312                           ByteShift);
25313     }
25314   }
25315 
25316   // Change opcode to non-immediate version.
25317   Opc = getTargetVShiftUniformOpcode(Opc, true);
25318 
25319   // The return type has to be a 128-bit type with the same element
25320   // type as the input type.
25321   MVT EltVT = VT.getVectorElementType();
25322   MVT ShVT = MVT::getVectorVT(EltVT, 128 / EltVT.getSizeInBits());
25323 
25324   ShAmt = DAG.getBitcast(ShVT, ShAmt);
25325   return DAG.getNode(Opc, dl, VT, SrcOp, ShAmt);
25326 }
25327 
25328 /// Return Mask with the necessary casting or extending
25329 /// for \p Mask according to \p MaskVT when lowering masking intrinsics
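/// For example, an i8 mask used with a v4i1 MaskVT is bitcast to v8i1 and the
/// low 4 lanes extracted; an all-ones scalar mask folds to an all-true
/// constant and a zero mask to all-false. On 32-bit targets an i64 mask for
/// v64i1 is split into two i32 halves, bitcast to v32i1 each, and concatenated.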
25330 static SDValue getMaskNode(SDValue Mask, MVT MaskVT,
25331                            const X86Subtarget &Subtarget, SelectionDAG &DAG,
25332                            const SDLoc &dl) {
25333 
25334   if (isAllOnesConstant(Mask))
25335     return DAG.getConstant(1, dl, MaskVT);
25336   if (X86::isZeroNode(Mask))
25337     return DAG.getConstant(0, dl, MaskVT);
25338 
25339   assert(MaskVT.bitsLE(Mask.getSimpleValueType()) && "Unexpected mask size!");
25340 
25341   if (Mask.getSimpleValueType() == MVT::i64 && Subtarget.is32Bit()) {
25342     assert(MaskVT == MVT::v64i1 && "Expected v64i1 mask!");
25343     assert(Subtarget.hasBWI() && "Expected AVX512BW target!");
25344     // In 32-bit mode, a bitcast of i64 is illegal; split it into two i32 halves.
25345     SDValue Lo, Hi;
25346     std::tie(Lo, Hi) = DAG.SplitScalar(Mask, dl, MVT::i32, MVT::i32);
25347     Lo = DAG.getBitcast(MVT::v32i1, Lo);
25348     Hi = DAG.getBitcast(MVT::v32i1, Hi);
25349     return DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v64i1, Lo, Hi);
25350   } else {
25351     MVT BitcastVT = MVT::getVectorVT(MVT::i1,
25352                                      Mask.getSimpleValueType().getSizeInBits());
25353     // When MaskVT equals v2i1 or v4i1, the low 2 or 4 elements
25354     // are extracted by EXTRACT_SUBVECTOR.
25355     return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MaskVT,
25356                        DAG.getBitcast(BitcastVT, Mask),
25357                        DAG.getIntPtrConstant(0, dl));
25358   }
25359 }
25360 
25361 /// Return (and \p Op, \p Mask) for compare instructions or
25362 /// (vselect \p Mask, \p Op, \p PreservedSrc) for others along with the
25363 /// necessary casting or extending for \p Mask when lowering masking intrinsics
25364 static SDValue getVectorMaskingNode(SDValue Op, SDValue Mask,
25365                                     SDValue PreservedSrc,
25366                                     const X86Subtarget &Subtarget,
25367                                     SelectionDAG &DAG) {
25368   MVT VT = Op.getSimpleValueType();
25369   MVT MaskVT = MVT::getVectorVT(MVT::i1, VT.getVectorNumElements());
25370   unsigned OpcodeSelect = ISD::VSELECT;
25371   SDLoc dl(Op);
25372 
25373   if (isAllOnesConstant(Mask))
25374     return Op;
25375 
25376   SDValue VMask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl);
25377 
25378   if (PreservedSrc.isUndef())
25379     PreservedSrc = getZeroVector(VT, Subtarget, DAG, dl);
25380   return DAG.getNode(OpcodeSelect, dl, VT, VMask, Op, PreservedSrc);
25381 }
25382 
25383 /// Creates an SDNode for a predicated scalar operation.
25384 /// \returns (X86vselect \p Mask, \p Op, \p PreservedSrc).
25385 /// The mask is coming as MVT::i8 and it should be transformed
25386 /// to MVT::v1i1 while lowering masking intrinsics.
25387 /// The main difference between ScalarMaskingNode and VectorMaskingNode is using
25388 /// "X86select" instead of "vselect". We just can't create the "vselect" node
25389 /// for a scalar instruction.
25390 static SDValue getScalarMaskingNode(SDValue Op, SDValue Mask,
25391                                     SDValue PreservedSrc,
25392                                     const X86Subtarget &Subtarget,
25393                                     SelectionDAG &DAG) {
25394 
25395   if (auto *MaskConst = dyn_cast<ConstantSDNode>(Mask))
25396     if (MaskConst->getZExtValue() & 0x1)
25397       return Op;
25398 
25399   MVT VT = Op.getSimpleValueType();
25400   SDLoc dl(Op);
25401 
25402   assert(Mask.getValueType() == MVT::i8 && "Unexpected type");
25403   SDValue IMask = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v1i1,
25404                               DAG.getBitcast(MVT::v8i1, Mask),
25405                               DAG.getIntPtrConstant(0, dl));
25406   if (Op.getOpcode() == X86ISD::FSETCCM ||
25407       Op.getOpcode() == X86ISD::FSETCCM_SAE ||
25408       Op.getOpcode() == X86ISD::VFPCLASSS)
25409     return DAG.getNode(ISD::AND, dl, VT, Op, IMask);
25410 
25411   if (PreservedSrc.isUndef())
25412     PreservedSrc = getZeroVector(VT, Subtarget, DAG, dl);
25413   return DAG.getNode(X86ISD::SELECTS, dl, VT, IMask, Op, PreservedSrc);
25414 }
25415 
25416 static int getSEHRegistrationNodeSize(const Function *Fn) {
25417   if (!Fn->hasPersonalityFn())
25418     report_fatal_error(
25419         "querying registration node size for function without personality");
25420   // The RegNodeSize is 6 32-bit words for SEH and 4 for C++ EH. See
25421   // WinEHStatePass for the full struct definition.
25422   switch (classifyEHPersonality(Fn->getPersonalityFn())) {
25423   case EHPersonality::MSVC_X86SEH: return 24;
25424   case EHPersonality::MSVC_CXX: return 16;
25425   default: break;
25426   }
25427   report_fatal_error(
25428       "can only recover FP for 32-bit MSVC EH personality functions");
25429 }
25430 
25431 /// When the MSVC runtime transfers control to us, either to an outlined
25432 /// function or when returning to a parent frame after catching an exception, we
25433 /// recover the parent frame pointer by doing arithmetic on the incoming EBP.
25434 /// Here's the math:
25435 ///   RegNodeBase = EntryEBP - RegNodeSize
25436 ///   ParentFP = RegNodeBase - ParentFrameOffset
25437 /// Subtracting RegNodeSize takes us to the offset of the registration node, and
25438 /// subtracting the offset (negative on x86) takes us back to the parent FP.
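/// Purely illustrative numbers: for a 32-bit MSVC C++ EH personality
/// (RegNodeSize = 16), with EntryEBP = 0x1000 and ParentFrameOffset = -24,
/// RegNodeBase = 0x0ff0 and ParentFP = 0x0ff0 - (-24) = 0x1008. On x64 the
/// function instead returns EntryEBP + ParentFrameOffset directly.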
25439 static SDValue recoverFramePointer(SelectionDAG &DAG, const Function *Fn,
25440                                    SDValue EntryEBP) {
25441   MachineFunction &MF = DAG.getMachineFunction();
25442   SDLoc dl;
25443 
25444   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
25445   MVT PtrVT = TLI.getPointerTy(DAG.getDataLayout());
25446 
25447   // It's possible that the parent function no longer has a personality function
25448   // if the exceptional code was optimized away, in which case we just return
25449   // the incoming EBP.
25450   if (!Fn->hasPersonalityFn())
25451     return EntryEBP;
25452 
25453   // Get an MCSymbol that will ultimately resolve to the frame offset of the EH
25454   // registration, or the .set_setframe offset.
25455   MCSymbol *OffsetSym = MF.getContext().getOrCreateParentFrameOffsetSymbol(
25456       GlobalValue::dropLLVMManglingEscape(Fn->getName()));
25457   SDValue OffsetSymVal = DAG.getMCSymbol(OffsetSym, PtrVT);
25458   SDValue ParentFrameOffset =
25459       DAG.getNode(ISD::LOCAL_RECOVER, dl, PtrVT, OffsetSymVal);
25460 
25461   // Return EntryEBP + ParentFrameOffset for x64. This adjusts from RSP after
25462   // prologue to RBP in the parent function.
25463   const X86Subtarget &Subtarget = DAG.getSubtarget<X86Subtarget>();
25464   if (Subtarget.is64Bit())
25465     return DAG.getNode(ISD::ADD, dl, PtrVT, EntryEBP, ParentFrameOffset);
25466 
25467   int RegNodeSize = getSEHRegistrationNodeSize(Fn);
25468   // RegNodeBase = EntryEBP - RegNodeSize
25469   // ParentFP = RegNodeBase - ParentFrameOffset
25470   SDValue RegNodeBase = DAG.getNode(ISD::SUB, dl, PtrVT, EntryEBP,
25471                                     DAG.getConstant(RegNodeSize, dl, PtrVT));
25472   return DAG.getNode(ISD::SUB, dl, PtrVT, RegNodeBase, ParentFrameOffset);
25473 }
25474 
25475 SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
25476                                                    SelectionDAG &DAG) const {
25477   // Helper to detect if the operand is CUR_DIRECTION rounding mode.
25478   auto isRoundModeCurDirection = [](SDValue Rnd) {
25479     if (auto *C = dyn_cast<ConstantSDNode>(Rnd))
25480       return C->getAPIntValue() == X86::STATIC_ROUNDING::CUR_DIRECTION;
25481 
25482     return false;
25483   };
25484   auto isRoundModeSAE = [](SDValue Rnd) {
25485     if (auto *C = dyn_cast<ConstantSDNode>(Rnd)) {
25486       unsigned RC = C->getZExtValue();
25487       if (RC & X86::STATIC_ROUNDING::NO_EXC) {
25488         // Clear the NO_EXC bit and check remaining bits.
25489         RC ^= X86::STATIC_ROUNDING::NO_EXC;
25490         // As a convenience we allow no other bits or explicitly
25491         // current direction.
25492         return RC == 0 || RC == X86::STATIC_ROUNDING::CUR_DIRECTION;
25493       }
25494     }
25495 
25496     return false;
25497   };
25498   auto isRoundModeSAEToX = [](SDValue Rnd, unsigned &RC) {
25499     if (auto *C = dyn_cast<ConstantSDNode>(Rnd)) {
25500       RC = C->getZExtValue();
25501       if (RC & X86::STATIC_ROUNDING::NO_EXC) {
25502         // Clear the NO_EXC bit and check remaining bits.
25503         RC ^= X86::STATIC_ROUNDING::NO_EXC;
25504         return RC == X86::STATIC_ROUNDING::TO_NEAREST_INT ||
25505                RC == X86::STATIC_ROUNDING::TO_NEG_INF ||
25506                RC == X86::STATIC_ROUNDING::TO_POS_INF ||
25507                RC == X86::STATIC_ROUNDING::TO_ZERO;
25508       }
25509     }
25510 
25511     return false;
25512   };
25513 
25514   SDLoc dl(Op);
25515   unsigned IntNo = Op.getConstantOperandVal(0);
25516   MVT VT = Op.getSimpleValueType();
25517   const IntrinsicData* IntrData = getIntrinsicWithoutChain(IntNo);
25518 
25519   // Propagate flags from original node to transformed node(s).
25520   SelectionDAG::FlagInserter FlagsInserter(DAG, Op->getFlags());
25521 
25522   if (IntrData) {
25523     switch(IntrData->Type) {
25524     case INTR_TYPE_1OP: {
25525       // We specify 2 possible opcodes for intrinsics with rounding modes.
25526       // First, we check if the intrinsic may have non-default rounding mode,
25527       // (IntrData->Opc1 != 0), then we check the rounding mode operand.
25528       unsigned IntrWithRoundingModeOpcode = IntrData->Opc1;
25529       if (IntrWithRoundingModeOpcode != 0) {
25530         SDValue Rnd = Op.getOperand(2);
25531         unsigned RC = 0;
25532         if (isRoundModeSAEToX(Rnd, RC))
25533           return DAG.getNode(IntrWithRoundingModeOpcode, dl, Op.getValueType(),
25534                              Op.getOperand(1),
25535                              DAG.getTargetConstant(RC, dl, MVT::i32));
25536         if (!isRoundModeCurDirection(Rnd))
25537           return SDValue();
25538       }
25539       return DAG.getNode(IntrData->Opc0, dl, Op.getValueType(),
25540                          Op.getOperand(1));
25541     }
25542     case INTR_TYPE_1OP_SAE: {
25543       SDValue Sae = Op.getOperand(2);
25544 
25545       unsigned Opc;
25546       if (isRoundModeCurDirection(Sae))
25547         Opc = IntrData->Opc0;
25548       else if (isRoundModeSAE(Sae))
25549         Opc = IntrData->Opc1;
25550       else
25551         return SDValue();
25552 
25553       return DAG.getNode(Opc, dl, Op.getValueType(), Op.getOperand(1));
25554     }
25555     case INTR_TYPE_2OP: {
25556       SDValue Src2 = Op.getOperand(2);
25557 
25558       // We specify 2 possible opcodes for intrinsics with rounding modes.
25559       // First, we check if the intrinsic may have non-default rounding mode,
25560       // (IntrData->Opc1 != 0), then we check the rounding mode operand.
25561       unsigned IntrWithRoundingModeOpcode = IntrData->Opc1;
25562       if (IntrWithRoundingModeOpcode != 0) {
25563         SDValue Rnd = Op.getOperand(3);
25564         unsigned RC = 0;
25565         if (isRoundModeSAEToX(Rnd, RC))
25566           return DAG.getNode(IntrWithRoundingModeOpcode, dl, Op.getValueType(),
25567                              Op.getOperand(1), Src2,
25568                              DAG.getTargetConstant(RC, dl, MVT::i32));
25569         if (!isRoundModeCurDirection(Rnd))
25570           return SDValue();
25571       }
25572 
25573       return DAG.getNode(IntrData->Opc0, dl, Op.getValueType(),
25574                          Op.getOperand(1), Src2);
25575     }
25576     case INTR_TYPE_2OP_SAE: {
25577       SDValue Sae = Op.getOperand(3);
25578 
25579       unsigned Opc;
25580       if (isRoundModeCurDirection(Sae))
25581         Opc = IntrData->Opc0;
25582       else if (isRoundModeSAE(Sae))
25583         Opc = IntrData->Opc1;
25584       else
25585         return SDValue();
25586 
25587       return DAG.getNode(Opc, dl, Op.getValueType(), Op.getOperand(1),
25588                          Op.getOperand(2));
25589     }
25590     case INTR_TYPE_3OP:
25591     case INTR_TYPE_3OP_IMM8: {
25592       SDValue Src1 = Op.getOperand(1);
25593       SDValue Src2 = Op.getOperand(2);
25594       SDValue Src3 = Op.getOperand(3);
25595 
25596       if (IntrData->Type == INTR_TYPE_3OP_IMM8 &&
25597           Src3.getValueType() != MVT::i8) {
25598         Src3 = DAG.getTargetConstant(Src3->getAsZExtVal() & 0xff, dl, MVT::i8);
25599       }
25600 
25601       // We specify 2 possible opcodes for intrinsics with rounding modes.
25602       // First, we check if the intrinsic may have non-default rounding mode,
25603       // (IntrData->Opc1 != 0), then we check the rounding mode operand.
25604       unsigned IntrWithRoundingModeOpcode = IntrData->Opc1;
25605       if (IntrWithRoundingModeOpcode != 0) {
25606         SDValue Rnd = Op.getOperand(4);
25607         unsigned RC = 0;
25608         if (isRoundModeSAEToX(Rnd, RC))
25609           return DAG.getNode(IntrWithRoundingModeOpcode, dl, Op.getValueType(),
25610                              Src1, Src2, Src3,
25611                              DAG.getTargetConstant(RC, dl, MVT::i32));
25612         if (!isRoundModeCurDirection(Rnd))
25613           return SDValue();
25614       }
25615 
25616       return DAG.getNode(IntrData->Opc0, dl, Op.getValueType(),
25617                          {Src1, Src2, Src3});
25618     }
25619     case INTR_TYPE_4OP_IMM8: {
25620       assert(Op.getOperand(4)->getOpcode() == ISD::TargetConstant);
25621       SDValue Src4 = Op.getOperand(4);
25622       if (Src4.getValueType() != MVT::i8) {
25623         Src4 = DAG.getTargetConstant(Src4->getAsZExtVal() & 0xff, dl, MVT::i8);
25624       }
25625 
25626       return DAG.getNode(IntrData->Opc0, dl, Op.getValueType(),
25627                          Op.getOperand(1), Op.getOperand(2), Op.getOperand(3),
25628                          Src4);
25629     }
25630     case INTR_TYPE_1OP_MASK: {
25631       SDValue Src = Op.getOperand(1);
25632       SDValue PassThru = Op.getOperand(2);
25633       SDValue Mask = Op.getOperand(3);
25634       // We add rounding mode to the Node when
25635       //   - RC Opcode is specified and
25636       //   - RC is not "current direction".
25637       unsigned IntrWithRoundingModeOpcode = IntrData->Opc1;
25638       if (IntrWithRoundingModeOpcode != 0) {
25639         SDValue Rnd = Op.getOperand(4);
25640         unsigned RC = 0;
25641         if (isRoundModeSAEToX(Rnd, RC))
25642           return getVectorMaskingNode(
25643               DAG.getNode(IntrWithRoundingModeOpcode, dl, Op.getValueType(),
25644                           Src, DAG.getTargetConstant(RC, dl, MVT::i32)),
25645               Mask, PassThru, Subtarget, DAG);
25646         if (!isRoundModeCurDirection(Rnd))
25647           return SDValue();
25648       }
25649       return getVectorMaskingNode(
25650           DAG.getNode(IntrData->Opc0, dl, VT, Src), Mask, PassThru,
25651           Subtarget, DAG);
25652     }
25653     case INTR_TYPE_1OP_MASK_SAE: {
25654       SDValue Src = Op.getOperand(1);
25655       SDValue PassThru = Op.getOperand(2);
25656       SDValue Mask = Op.getOperand(3);
25657       SDValue Rnd = Op.getOperand(4);
25658 
25659       unsigned Opc;
25660       if (isRoundModeCurDirection(Rnd))
25661         Opc = IntrData->Opc0;
25662       else if (isRoundModeSAE(Rnd))
25663         Opc = IntrData->Opc1;
25664       else
25665         return SDValue();
25666 
25667       return getVectorMaskingNode(DAG.getNode(Opc, dl, VT, Src), Mask, PassThru,
25668                                   Subtarget, DAG);
25669     }
25670     case INTR_TYPE_SCALAR_MASK: {
25671       SDValue Src1 = Op.getOperand(1);
25672       SDValue Src2 = Op.getOperand(2);
25673       SDValue passThru = Op.getOperand(3);
25674       SDValue Mask = Op.getOperand(4);
25675       unsigned IntrWithRoundingModeOpcode = IntrData->Opc1;
25676       // There are 2 kinds of intrinsics in this group:
25677       // (1) With suppress-all-exceptions (sae) or rounding mode - 6 operands
25678       // (2) With rounding mode and sae - 7 operands.
25679       bool HasRounding = IntrWithRoundingModeOpcode != 0;
25680       if (Op.getNumOperands() == (5U + HasRounding)) {
25681         if (HasRounding) {
25682           SDValue Rnd = Op.getOperand(5);
25683           unsigned RC = 0;
25684           if (isRoundModeSAEToX(Rnd, RC))
25685             return getScalarMaskingNode(
25686                 DAG.getNode(IntrWithRoundingModeOpcode, dl, VT, Src1, Src2,
25687                             DAG.getTargetConstant(RC, dl, MVT::i32)),
25688                 Mask, passThru, Subtarget, DAG);
25689           if (!isRoundModeCurDirection(Rnd))
25690             return SDValue();
25691         }
25692         return getScalarMaskingNode(DAG.getNode(IntrData->Opc0, dl, VT, Src1,
25693                                                 Src2),
25694                                     Mask, passThru, Subtarget, DAG);
25695       }
25696 
25697       assert(Op.getNumOperands() == (6U + HasRounding) &&
25698              "Unexpected intrinsic form");
25699       SDValue RoundingMode = Op.getOperand(5);
25700       unsigned Opc = IntrData->Opc0;
25701       if (HasRounding) {
25702         SDValue Sae = Op.getOperand(6);
25703         if (isRoundModeSAE(Sae))
25704           Opc = IntrWithRoundingModeOpcode;
25705         else if (!isRoundModeCurDirection(Sae))
25706           return SDValue();
25707       }
25708       return getScalarMaskingNode(DAG.getNode(Opc, dl, VT, Src1,
25709                                               Src2, RoundingMode),
25710                                   Mask, passThru, Subtarget, DAG);
25711     }
25712     case INTR_TYPE_SCALAR_MASK_RND: {
25713       SDValue Src1 = Op.getOperand(1);
25714       SDValue Src2 = Op.getOperand(2);
25715       SDValue passThru = Op.getOperand(3);
25716       SDValue Mask = Op.getOperand(4);
25717       SDValue Rnd = Op.getOperand(5);
25718 
25719       SDValue NewOp;
25720       unsigned RC = 0;
25721       if (isRoundModeCurDirection(Rnd))
25722         NewOp = DAG.getNode(IntrData->Opc0, dl, VT, Src1, Src2);
25723       else if (isRoundModeSAEToX(Rnd, RC))
25724         NewOp = DAG.getNode(IntrData->Opc1, dl, VT, Src1, Src2,
25725                             DAG.getTargetConstant(RC, dl, MVT::i32));
25726       else
25727         return SDValue();
25728 
25729       return getScalarMaskingNode(NewOp, Mask, passThru, Subtarget, DAG);
25730     }
25731     case INTR_TYPE_SCALAR_MASK_SAE: {
25732       SDValue Src1 = Op.getOperand(1);
25733       SDValue Src2 = Op.getOperand(2);
25734       SDValue passThru = Op.getOperand(3);
25735       SDValue Mask = Op.getOperand(4);
25736       SDValue Sae = Op.getOperand(5);
25737       unsigned Opc;
25738       if (isRoundModeCurDirection(Sae))
25739         Opc = IntrData->Opc0;
25740       else if (isRoundModeSAE(Sae))
25741         Opc = IntrData->Opc1;
25742       else
25743         return SDValue();
25744 
25745       return getScalarMaskingNode(DAG.getNode(Opc, dl, VT, Src1, Src2),
25746                                   Mask, passThru, Subtarget, DAG);
25747     }
25748     case INTR_TYPE_2OP_MASK: {
25749       SDValue Src1 = Op.getOperand(1);
25750       SDValue Src2 = Op.getOperand(2);
25751       SDValue PassThru = Op.getOperand(3);
25752       SDValue Mask = Op.getOperand(4);
25753       SDValue NewOp;
25754       if (IntrData->Opc1 != 0) {
25755         SDValue Rnd = Op.getOperand(5);
25756         unsigned RC = 0;
25757         if (isRoundModeSAEToX(Rnd, RC))
25758           NewOp = DAG.getNode(IntrData->Opc1, dl, VT, Src1, Src2,
25759                               DAG.getTargetConstant(RC, dl, MVT::i32));
25760         else if (!isRoundModeCurDirection(Rnd))
25761           return SDValue();
25762       }
25763       if (!NewOp)
25764         NewOp = DAG.getNode(IntrData->Opc0, dl, VT, Src1, Src2);
25765       return getVectorMaskingNode(NewOp, Mask, PassThru, Subtarget, DAG);
25766     }
25767     case INTR_TYPE_2OP_MASK_SAE: {
25768       SDValue Src1 = Op.getOperand(1);
25769       SDValue Src2 = Op.getOperand(2);
25770       SDValue PassThru = Op.getOperand(3);
25771       SDValue Mask = Op.getOperand(4);
25772 
25773       unsigned Opc = IntrData->Opc0;
25774       if (IntrData->Opc1 != 0) {
25775         SDValue Sae = Op.getOperand(5);
25776         if (isRoundModeSAE(Sae))
25777           Opc = IntrData->Opc1;
25778         else if (!isRoundModeCurDirection(Sae))
25779           return SDValue();
25780       }
25781 
25782       return getVectorMaskingNode(DAG.getNode(Opc, dl, VT, Src1, Src2),
25783                                   Mask, PassThru, Subtarget, DAG);
25784     }
25785     case INTR_TYPE_3OP_SCALAR_MASK_SAE: {
25786       SDValue Src1 = Op.getOperand(1);
25787       SDValue Src2 = Op.getOperand(2);
25788       SDValue Src3 = Op.getOperand(3);
25789       SDValue PassThru = Op.getOperand(4);
25790       SDValue Mask = Op.getOperand(5);
25791       SDValue Sae = Op.getOperand(6);
25792       unsigned Opc;
25793       if (isRoundModeCurDirection(Sae))
25794         Opc = IntrData->Opc0;
25795       else if (isRoundModeSAE(Sae))
25796         Opc = IntrData->Opc1;
25797       else
25798         return SDValue();
25799 
25800       return getScalarMaskingNode(DAG.getNode(Opc, dl, VT, Src1, Src2, Src3),
25801                                   Mask, PassThru, Subtarget, DAG);
25802     }
25803     case INTR_TYPE_3OP_MASK_SAE: {
25804       SDValue Src1 = Op.getOperand(1);
25805       SDValue Src2 = Op.getOperand(2);
25806       SDValue Src3 = Op.getOperand(3);
25807       SDValue PassThru = Op.getOperand(4);
25808       SDValue Mask = Op.getOperand(5);
25809 
25810       unsigned Opc = IntrData->Opc0;
25811       if (IntrData->Opc1 != 0) {
25812         SDValue Sae = Op.getOperand(6);
25813         if (isRoundModeSAE(Sae))
25814           Opc = IntrData->Opc1;
25815         else if (!isRoundModeCurDirection(Sae))
25816           return SDValue();
25817       }
25818       return getVectorMaskingNode(DAG.getNode(Opc, dl, VT, Src1, Src2, Src3),
25819                                   Mask, PassThru, Subtarget, DAG);
25820     }
25821     case BLENDV: {
25822       SDValue Src1 = Op.getOperand(1);
25823       SDValue Src2 = Op.getOperand(2);
25824       SDValue Src3 = Op.getOperand(3);
25825 
25826       EVT MaskVT = Src3.getValueType().changeVectorElementTypeToInteger();
25827       Src3 = DAG.getBitcast(MaskVT, Src3);
25828 
25829       // Reverse the operands to match VSELECT order.
25830       return DAG.getNode(IntrData->Opc0, dl, VT, Src3, Src2, Src1);
25831     }
25832     case VPERM_2OP : {
25833       SDValue Src1 = Op.getOperand(1);
25834       SDValue Src2 = Op.getOperand(2);
25835 
25836       // Swap Src1 and Src2 in the node creation
25837       return DAG.getNode(IntrData->Opc0, dl, VT,Src2, Src1);
25838     }
25839     case CFMA_OP_MASKZ:
25840     case CFMA_OP_MASK: {
25841       SDValue Src1 = Op.getOperand(1);
25842       SDValue Src2 = Op.getOperand(2);
25843       SDValue Src3 = Op.getOperand(3);
25844       SDValue Mask = Op.getOperand(4);
25845       MVT VT = Op.getSimpleValueType();
25846 
25847       SDValue PassThru = Src3;
25848       if (IntrData->Type == CFMA_OP_MASKZ)
25849         PassThru = getZeroVector(VT, Subtarget, DAG, dl);
25850 
25851       // We add rounding mode to the Node when
25852       //   - RC Opcode is specified and
25853       //   - RC is not "current direction".
25854       SDValue NewOp;
25855       if (IntrData->Opc1 != 0) {
25856         SDValue Rnd = Op.getOperand(5);
25857         unsigned RC = 0;
25858         if (isRoundModeSAEToX(Rnd, RC))
25859           NewOp = DAG.getNode(IntrData->Opc1, dl, VT, Src1, Src2, Src3,
25860                               DAG.getTargetConstant(RC, dl, MVT::i32));
25861         else if (!isRoundModeCurDirection(Rnd))
25862           return SDValue();
25863       }
25864       if (!NewOp)
25865         NewOp = DAG.getNode(IntrData->Opc0, dl, VT, Src1, Src2, Src3);
25866       return getVectorMaskingNode(NewOp, Mask, PassThru, Subtarget, DAG);
25867     }
25868     case IFMA_OP:
25869       // NOTE: We need to swizzle the operands to pass the multiply operands
25870       // first.
25871       return DAG.getNode(IntrData->Opc0, dl, Op.getValueType(),
25872                          Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
25873     case FPCLASSS: {
25874       SDValue Src1 = Op.getOperand(1);
25875       SDValue Imm = Op.getOperand(2);
25876       SDValue Mask = Op.getOperand(3);
25877       SDValue FPclass = DAG.getNode(IntrData->Opc0, dl, MVT::v1i1, Src1, Imm);
25878       SDValue FPclassMask = getScalarMaskingNode(FPclass, Mask, SDValue(),
25879                                                  Subtarget, DAG);
25880       // Need to fill with zeros to ensure the bitcast will produce zeroes
25881       // for the upper bits. An EXTRACT_ELEMENT here wouldn't guarantee that.
25882       SDValue Ins = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, MVT::v8i1,
25883                                 DAG.getConstant(0, dl, MVT::v8i1),
25884                                 FPclassMask, DAG.getIntPtrConstant(0, dl));
25885       return DAG.getBitcast(MVT::i8, Ins);
25886     }
25887 
25888     case CMP_MASK_CC: {
25889       MVT MaskVT = Op.getSimpleValueType();
25890       SDValue CC = Op.getOperand(3);
25891       SDValue Mask = Op.getOperand(4);
25892       // We specify two possible opcodes for intrinsics with rounding modes.
25893       // First, we check whether the intrinsic may have a non-default rounding
25894       // mode (IntrData->Opc1 != 0), then we check the rounding mode operand.
25895       if (IntrData->Opc1 != 0) {
25896         SDValue Sae = Op.getOperand(5);
25897         if (isRoundModeSAE(Sae))
25898           return DAG.getNode(IntrData->Opc1, dl, MaskVT, Op.getOperand(1),
25899                              Op.getOperand(2), CC, Mask, Sae);
25900         if (!isRoundModeCurDirection(Sae))
25901           return SDValue();
25902       }
25903       // Default rounding mode.
25904       return DAG.getNode(IntrData->Opc0, dl, MaskVT,
25905                          {Op.getOperand(1), Op.getOperand(2), CC, Mask});
25906     }
25907     case CMP_MASK_SCALAR_CC: {
25908       SDValue Src1 = Op.getOperand(1);
25909       SDValue Src2 = Op.getOperand(2);
25910       SDValue CC = Op.getOperand(3);
25911       SDValue Mask = Op.getOperand(4);
25912 
25913       SDValue Cmp;
25914       if (IntrData->Opc1 != 0) {
25915         SDValue Sae = Op.getOperand(5);
25916         if (isRoundModeSAE(Sae))
25917           Cmp = DAG.getNode(IntrData->Opc1, dl, MVT::v1i1, Src1, Src2, CC, Sae);
25918         else if (!isRoundModeCurDirection(Sae))
25919           return SDValue();
25920       }
25921       // Default rounding mode.
25922       if (!Cmp.getNode())
25923         Cmp = DAG.getNode(IntrData->Opc0, dl, MVT::v1i1, Src1, Src2, CC);
25924 
25925       SDValue CmpMask = getScalarMaskingNode(Cmp, Mask, SDValue(),
25926                                              Subtarget, DAG);
25927       // Need to fill with zeros to ensure the bitcast will produce zeroes
25928       // for the upper bits. An EXTRACT_ELEMENT here wouldn't guarantee that.
25929       SDValue Ins = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, MVT::v8i1,
25930                                 DAG.getConstant(0, dl, MVT::v8i1),
25931                                 CmpMask, DAG.getIntPtrConstant(0, dl));
25932       return DAG.getBitcast(MVT::i8, Ins);
25933     }
25934     case COMI: { // Comparison intrinsics
25935       ISD::CondCode CC = (ISD::CondCode)IntrData->Opc1;
25936       SDValue LHS = Op.getOperand(1);
25937       SDValue RHS = Op.getOperand(2);
25938       // Some conditions require the operands to be swapped.
25939       if (CC == ISD::SETLT || CC == ISD::SETLE)
25940         std::swap(LHS, RHS);
25941 
25942       SDValue Comi = DAG.getNode(IntrData->Opc0, dl, MVT::i32, LHS, RHS);
25943       SDValue SetCC;
25944       switch (CC) {
25945       case ISD::SETEQ: { // (ZF = 0 and PF = 0)
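        // COMIS*/UCOMIS* report an unordered compare as ZF=PF=CF=1, so testing
        // ZF alone would treat NaN operands as equal; requiring PF=0 rejects
        // the unordered case.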
25946         SetCC = getSETCC(X86::COND_E, Comi, dl, DAG);
25947         SDValue SetNP = getSETCC(X86::COND_NP, Comi, dl, DAG);
25948         SetCC = DAG.getNode(ISD::AND, dl, MVT::i8, SetCC, SetNP);
25949         break;
25950       }
25951       case ISD::SETNE: { // (ZF = 1 or PF = 1)
25952         SetCC = getSETCC(X86::COND_NE, Comi, dl, DAG);
25953         SDValue SetP = getSETCC(X86::COND_P, Comi, dl, DAG);
25954         SetCC = DAG.getNode(ISD::OR, dl, MVT::i8, SetCC, SetP);
25955         break;
25956       }
25957       case ISD::SETGT: // (CF = 0 and ZF = 0)
25958       case ISD::SETLT: { // Condition opposite to GT. Operands swapped above.
25959         SetCC = getSETCC(X86::COND_A, Comi, dl, DAG);
25960         break;
25961       }
25962       case ISD::SETGE: // CF = 0
25963       case ISD::SETLE: // Condition opposite to GE. Operands swapped above.
25964         SetCC = getSETCC(X86::COND_AE, Comi, dl, DAG);
25965         break;
25966       default:
25967         llvm_unreachable("Unexpected illegal condition!");
25968       }
25969       return DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i32, SetCC);
25970     }
25971     case COMI_RM: { // Comparison intrinsics with Sae
25972       SDValue LHS = Op.getOperand(1);
25973       SDValue RHS = Op.getOperand(2);
25974       unsigned CondVal = Op.getConstantOperandVal(3);
25975       SDValue Sae = Op.getOperand(4);
25976 
25977       SDValue FCmp;
25978       if (isRoundModeCurDirection(Sae))
25979         FCmp = DAG.getNode(X86ISD::FSETCCM, dl, MVT::v1i1, LHS, RHS,
25980                            DAG.getTargetConstant(CondVal, dl, MVT::i8));
25981       else if (isRoundModeSAE(Sae))
25982         FCmp = DAG.getNode(X86ISD::FSETCCM_SAE, dl, MVT::v1i1, LHS, RHS,
25983                            DAG.getTargetConstant(CondVal, dl, MVT::i8), Sae);
25984       else
25985         return SDValue();
25986       // Need to fill with zeros to ensure the bitcast will produce zeroes
25987       // for the upper bits. An EXTRACT_ELEMENT here wouldn't guarantee that.
25988       SDValue Ins = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, MVT::v16i1,
25989                                 DAG.getConstant(0, dl, MVT::v16i1),
25990                                 FCmp, DAG.getIntPtrConstant(0, dl));
25991       return DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i32,
25992                          DAG.getBitcast(MVT::i16, Ins));
25993     }
25994     case VSHIFT: {
25995       SDValue SrcOp = Op.getOperand(1);
25996       SDValue ShAmt = Op.getOperand(2);
25997       assert(ShAmt.getValueType() == MVT::i32 &&
25998              "Unexpected VSHIFT amount type");
25999 
26000       // Catch shift-by-constant.
26001       if (auto *CShAmt = dyn_cast<ConstantSDNode>(ShAmt))
26002         return getTargetVShiftByConstNode(IntrData->Opc0, dl,
26003                                           Op.getSimpleValueType(), SrcOp,
26004                                           CShAmt->getZExtValue(), DAG);
26005 
26006       ShAmt = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, MVT::v4i32, ShAmt);
26007       return getTargetVShiftNode(IntrData->Opc0, dl, Op.getSimpleValueType(),
26008                                  SrcOp, ShAmt, 0, Subtarget, DAG);
26009     }
26010     case COMPRESS_EXPAND_IN_REG: {
26011       SDValue Mask = Op.getOperand(3);
26012       SDValue DataToCompress = Op.getOperand(1);
26013       SDValue PassThru = Op.getOperand(2);
26014       if (ISD::isBuildVectorAllOnes(Mask.getNode())) // return data as is
26015         return Op.getOperand(1);
26016 
26017       // Avoid false dependency.
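      // (With an undef pass-through the instruction would still merge into
      // whatever happens to be in the destination register, creating a false
      // dependence on its last writer; a zero vector breaks that dependence.)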
26018       if (PassThru.isUndef())
26019         PassThru = getZeroVector(VT, Subtarget, DAG, dl);
26020 
26021       return DAG.getNode(IntrData->Opc0, dl, VT, DataToCompress, PassThru,
26022                          Mask);
26023     }
26024     case FIXUPIMM:
26025     case FIXUPIMM_MASKZ: {
26026       SDValue Src1 = Op.getOperand(1);
26027       SDValue Src2 = Op.getOperand(2);
26028       SDValue Src3 = Op.getOperand(3);
26029       SDValue Imm = Op.getOperand(4);
26030       SDValue Mask = Op.getOperand(5);
26031       SDValue Passthru = (IntrData->Type == FIXUPIMM)
26032                              ? Src1
26033                              : getZeroVector(VT, Subtarget, DAG, dl);
26034 
26035       unsigned Opc = IntrData->Opc0;
26036       if (IntrData->Opc1 != 0) {
26037         SDValue Sae = Op.getOperand(6);
26038         if (isRoundModeSAE(Sae))
26039           Opc = IntrData->Opc1;
26040         else if (!isRoundModeCurDirection(Sae))
26041           return SDValue();
26042       }
26043 
26044       SDValue FixupImm = DAG.getNode(Opc, dl, VT, Src1, Src2, Src3, Imm);
26045 
26046       if (Opc == X86ISD::VFIXUPIMM || Opc == X86ISD::VFIXUPIMM_SAE)
26047         return getVectorMaskingNode(FixupImm, Mask, Passthru, Subtarget, DAG);
26048 
26049       return getScalarMaskingNode(FixupImm, Mask, Passthru, Subtarget, DAG);
26050     }
26051     case ROUNDP: {
26052       assert(IntrData->Opc0 == X86ISD::VRNDSCALE && "Unexpected opcode");
26053       // Clear the upper bits of the rounding immediate so that the legacy
26054       // intrinsic can't trigger the scaling behavior of VRNDSCALE.
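      // (Bits [7:4] of the VRNDSCALE immediate encode the scale; the legacy
      // ROUND immediate only defines bits [3:0], hence the mask with 0xf.)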
26055       uint64_t Round = Op.getConstantOperandVal(2);
26056       SDValue RoundingMode = DAG.getTargetConstant(Round & 0xf, dl, MVT::i32);
26057       return DAG.getNode(IntrData->Opc0, dl, Op.getValueType(),
26058                          Op.getOperand(1), RoundingMode);
26059     }
26060     case ROUNDS: {
26061       assert(IntrData->Opc0 == X86ISD::VRNDSCALES && "Unexpected opcode");
26062       // Clear the upper bits of the rounding immediate so that the legacy
26063       // intrinsic can't trigger the scaling behavior of VRNDSCALE.
26064       uint64_t Round = Op.getConstantOperandVal(3);
26065       SDValue RoundingMode = DAG.getTargetConstant(Round & 0xf, dl, MVT::i32);
26066       return DAG.getNode(IntrData->Opc0, dl, Op.getValueType(),
26067                          Op.getOperand(1), Op.getOperand(2), RoundingMode);
26068     }
26069     case BEXTRI: {
26070       assert(IntrData->Opc0 == X86ISD::BEXTRI && "Unexpected opcode");
26071 
26072       uint64_t Imm = Op.getConstantOperandVal(2);
26073       SDValue Control = DAG.getTargetConstant(Imm & 0xffff, dl,
26074                                               Op.getValueType());
26075       return DAG.getNode(IntrData->Opc0, dl, Op.getValueType(),
26076                          Op.getOperand(1), Control);
26077     }
26078     // ADC/SBB
26079     case ADX: {
26080       SDVTList CFVTs = DAG.getVTList(Op->getValueType(0), MVT::i32);
26081       SDVTList VTs = DAG.getVTList(Op.getOperand(2).getValueType(), MVT::i32);
26082 
26083       SDValue Res;
26084       // If the carry in is zero, then we should just use ADD/SUB instead of
26085       // ADC/SBB.
26086       if (isNullConstant(Op.getOperand(1))) {
26087         Res = DAG.getNode(IntrData->Opc1, dl, VTs, Op.getOperand(2),
26088                           Op.getOperand(3));
26089       } else {
26090         SDValue GenCF = DAG.getNode(X86ISD::ADD, dl, CFVTs, Op.getOperand(1),
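        // Recreate CF from the carry-in byte: adding 0xFF sets the carry flag
        // exactly when the carry-in is non-zero, and that flag feeds the
        // ADC/SBB below.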
26091                                     DAG.getConstant(-1, dl, MVT::i8));
26092         Res = DAG.getNode(IntrData->Opc0, dl, VTs, Op.getOperand(2),
26093                           Op.getOperand(3), GenCF.getValue(1));
26094       }
26095       SDValue SetCC = getSETCC(X86::COND_B, Res.getValue(1), dl, DAG);
26096       SDValue Results[] = { SetCC, Res };
26097       return DAG.getMergeValues(Results, dl);
26098     }
26099     case CVTPD2PS_MASK:
26100     case CVTPD2DQ_MASK:
26101     case CVTQQ2PS_MASK:
26102     case TRUNCATE_TO_REG: {
26103       SDValue Src = Op.getOperand(1);
26104       SDValue PassThru = Op.getOperand(2);
26105       SDValue Mask = Op.getOperand(3);
26106 
26107       if (isAllOnesConstant(Mask))
26108         return DAG.getNode(IntrData->Opc0, dl, Op.getValueType(), Src);
26109 
26110       MVT SrcVT = Src.getSimpleValueType();
26111       MVT MaskVT = MVT::getVectorVT(MVT::i1, SrcVT.getVectorNumElements());
26112       Mask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl);
26113       return DAG.getNode(IntrData->Opc1, dl, Op.getValueType(),
26114                          {Src, PassThru, Mask});
26115     }
26116     case CVTPS2PH_MASK: {
26117       SDValue Src = Op.getOperand(1);
26118       SDValue Rnd = Op.getOperand(2);
26119       SDValue PassThru = Op.getOperand(3);
26120       SDValue Mask = Op.getOperand(4);
26121 
26122       unsigned RC = 0;
26123       unsigned Opc = IntrData->Opc0;
26124       bool SAE = Src.getValueType().is512BitVector() &&
26125                  (isRoundModeSAEToX(Rnd, RC) || isRoundModeSAE(Rnd));
26126       if (SAE) {
26127         Opc = X86ISD::CVTPS2PH_SAE;
26128         Rnd = DAG.getTargetConstant(RC, dl, MVT::i32);
26129       }
26130 
26131       if (isAllOnesConstant(Mask))
26132         return DAG.getNode(Opc, dl, Op.getValueType(), Src, Rnd);
26133 
26134       if (SAE)
26135         Opc = X86ISD::MCVTPS2PH_SAE;
26136       else
26137         Opc = IntrData->Opc1;
26138       MVT SrcVT = Src.getSimpleValueType();
26139       MVT MaskVT = MVT::getVectorVT(MVT::i1, SrcVT.getVectorNumElements());
26140       Mask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl);
26141       return DAG.getNode(Opc, dl, Op.getValueType(), Src, Rnd, PassThru, Mask);
26142     }
26143     case CVTNEPS2BF16_MASK: {
26144       SDValue Src = Op.getOperand(1);
26145       SDValue PassThru = Op.getOperand(2);
26146       SDValue Mask = Op.getOperand(3);
26147 
26148       if (ISD::isBuildVectorAllOnes(Mask.getNode()))
26149         return DAG.getNode(IntrData->Opc0, dl, Op.getValueType(), Src);
26150 
26151       // Break false dependency.
26152       if (PassThru.isUndef())
26153         PassThru = DAG.getConstant(0, dl, PassThru.getValueType());
26154 
26155       return DAG.getNode(IntrData->Opc1, dl, Op.getValueType(), Src, PassThru,
26156                          Mask);
26157     }
26158     default:
26159       break;
26160     }
26161   }
26162 
26163   switch (IntNo) {
26164   default: return SDValue();    // Don't custom lower most intrinsics.
26165 
26166   // ptest and testp intrinsics. The intrinsics these come from are designed to
26167   // return an integer value, not just an instruction, so lower them to the
26168   // ptest or testp pattern and a setcc for the result.
26169   case Intrinsic::x86_avx512_ktestc_b:
26170   case Intrinsic::x86_avx512_ktestc_w:
26171   case Intrinsic::x86_avx512_ktestc_d:
26172   case Intrinsic::x86_avx512_ktestc_q:
26173   case Intrinsic::x86_avx512_ktestz_b:
26174   case Intrinsic::x86_avx512_ktestz_w:
26175   case Intrinsic::x86_avx512_ktestz_d:
26176   case Intrinsic::x86_avx512_ktestz_q:
26177   case Intrinsic::x86_sse41_ptestz:
26178   case Intrinsic::x86_sse41_ptestc:
26179   case Intrinsic::x86_sse41_ptestnzc:
26180   case Intrinsic::x86_avx_ptestz_256:
26181   case Intrinsic::x86_avx_ptestc_256:
26182   case Intrinsic::x86_avx_ptestnzc_256:
26183   case Intrinsic::x86_avx_vtestz_ps:
26184   case Intrinsic::x86_avx_vtestc_ps:
26185   case Intrinsic::x86_avx_vtestnzc_ps:
26186   case Intrinsic::x86_avx_vtestz_pd:
26187   case Intrinsic::x86_avx_vtestc_pd:
26188   case Intrinsic::x86_avx_vtestnzc_pd:
26189   case Intrinsic::x86_avx_vtestz_ps_256:
26190   case Intrinsic::x86_avx_vtestc_ps_256:
26191   case Intrinsic::x86_avx_vtestnzc_ps_256:
26192   case Intrinsic::x86_avx_vtestz_pd_256:
26193   case Intrinsic::x86_avx_vtestc_pd_256:
26194   case Intrinsic::x86_avx_vtestnzc_pd_256: {
26195     unsigned TestOpc = X86ISD::PTEST;
26196     X86::CondCode X86CC;
26197     switch (IntNo) {
26198     default: llvm_unreachable("Bad fallthrough in Intrinsic lowering.");
26199     case Intrinsic::x86_avx512_ktestc_b:
26200     case Intrinsic::x86_avx512_ktestc_w:
26201     case Intrinsic::x86_avx512_ktestc_d:
26202     case Intrinsic::x86_avx512_ktestc_q:
26203       // CF = 1
26204       TestOpc = X86ISD::KTEST;
26205       X86CC = X86::COND_B;
26206       break;
26207     case Intrinsic::x86_avx512_ktestz_b:
26208     case Intrinsic::x86_avx512_ktestz_w:
26209     case Intrinsic::x86_avx512_ktestz_d:
26210     case Intrinsic::x86_avx512_ktestz_q:
26211       TestOpc = X86ISD::KTEST;
26212       X86CC = X86::COND_E;
26213       break;
26214     case Intrinsic::x86_avx_vtestz_ps:
26215     case Intrinsic::x86_avx_vtestz_pd:
26216     case Intrinsic::x86_avx_vtestz_ps_256:
26217     case Intrinsic::x86_avx_vtestz_pd_256:
26218       TestOpc = X86ISD::TESTP;
26219       [[fallthrough]];
26220     case Intrinsic::x86_sse41_ptestz:
26221     case Intrinsic::x86_avx_ptestz_256:
26222       // ZF = 1
26223       X86CC = X86::COND_E;
26224       break;
26225     case Intrinsic::x86_avx_vtestc_ps:
26226     case Intrinsic::x86_avx_vtestc_pd:
26227     case Intrinsic::x86_avx_vtestc_ps_256:
26228     case Intrinsic::x86_avx_vtestc_pd_256:
26229       TestOpc = X86ISD::TESTP;
26230       [[fallthrough]];
26231     case Intrinsic::x86_sse41_ptestc:
26232     case Intrinsic::x86_avx_ptestc_256:
26233       // CF = 1
26234       X86CC = X86::COND_B;
26235       break;
26236     case Intrinsic::x86_avx_vtestnzc_ps:
26237     case Intrinsic::x86_avx_vtestnzc_pd:
26238     case Intrinsic::x86_avx_vtestnzc_ps_256:
26239     case Intrinsic::x86_avx_vtestnzc_pd_256:
26240       TestOpc = X86ISD::TESTP;
26241       [[fallthrough]];
26242     case Intrinsic::x86_sse41_ptestnzc:
26243     case Intrinsic::x86_avx_ptestnzc_256:
26244       // ZF and CF = 0
26245       X86CC = X86::COND_A;
26246       break;
26247     }
26248 
26249     SDValue LHS = Op.getOperand(1);
26250     SDValue RHS = Op.getOperand(2);
26251     SDValue Test = DAG.getNode(TestOpc, dl, MVT::i32, LHS, RHS);
26252     SDValue SetCC = getSETCC(X86CC, Test, dl, DAG);
26253     return DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i32, SetCC);
26254   }
26255 
26256   case Intrinsic::x86_sse42_pcmpistria128:
26257   case Intrinsic::x86_sse42_pcmpestria128:
26258   case Intrinsic::x86_sse42_pcmpistric128:
26259   case Intrinsic::x86_sse42_pcmpestric128:
26260   case Intrinsic::x86_sse42_pcmpistrio128:
26261   case Intrinsic::x86_sse42_pcmpestrio128:
26262   case Intrinsic::x86_sse42_pcmpistris128:
26263   case Intrinsic::x86_sse42_pcmpestris128:
26264   case Intrinsic::x86_sse42_pcmpistriz128:
26265   case Intrinsic::x86_sse42_pcmpestriz128: {
26266     unsigned Opcode;
26267     X86::CondCode X86CC;
26268     switch (IntNo) {
26269     default: llvm_unreachable("Impossible intrinsic");  // Can't reach here.
26270     case Intrinsic::x86_sse42_pcmpistria128:
26271       Opcode = X86ISD::PCMPISTR;
26272       X86CC = X86::COND_A;
26273       break;
26274     case Intrinsic::x86_sse42_pcmpestria128:
26275       Opcode = X86ISD::PCMPESTR;
26276       X86CC = X86::COND_A;
26277       break;
26278     case Intrinsic::x86_sse42_pcmpistric128:
26279       Opcode = X86ISD::PCMPISTR;
26280       X86CC = X86::COND_B;
26281       break;
26282     case Intrinsic::x86_sse42_pcmpestric128:
26283       Opcode = X86ISD::PCMPESTR;
26284       X86CC = X86::COND_B;
26285       break;
26286     case Intrinsic::x86_sse42_pcmpistrio128:
26287       Opcode = X86ISD::PCMPISTR;
26288       X86CC = X86::COND_O;
26289       break;
26290     case Intrinsic::x86_sse42_pcmpestrio128:
26291       Opcode = X86ISD::PCMPESTR;
26292       X86CC = X86::COND_O;
26293       break;
26294     case Intrinsic::x86_sse42_pcmpistris128:
26295       Opcode = X86ISD::PCMPISTR;
26296       X86CC = X86::COND_S;
26297       break;
26298     case Intrinsic::x86_sse42_pcmpestris128:
26299       Opcode = X86ISD::PCMPESTR;
26300       X86CC = X86::COND_S;
26301       break;
26302     case Intrinsic::x86_sse42_pcmpistriz128:
26303       Opcode = X86ISD::PCMPISTR;
26304       X86CC = X86::COND_E;
26305       break;
26306     case Intrinsic::x86_sse42_pcmpestriz128:
26307       Opcode = X86ISD::PCMPESTR;
26308       X86CC = X86::COND_E;
26309       break;
26310     }
26311     SmallVector<SDValue, 5> NewOps(llvm::drop_begin(Op->ops()));
26312     SDVTList VTs = DAG.getVTList(MVT::i32, MVT::v16i8, MVT::i32);
26313     SDValue PCMP = DAG.getNode(Opcode, dl, VTs, NewOps).getValue(2);
26314     SDValue SetCC = getSETCC(X86CC, PCMP, dl, DAG);
26315     return DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i32, SetCC);
26316   }
26317 
26318   case Intrinsic::x86_sse42_pcmpistri128:
26319   case Intrinsic::x86_sse42_pcmpestri128: {
26320     unsigned Opcode;
26321     if (IntNo == Intrinsic::x86_sse42_pcmpistri128)
26322       Opcode = X86ISD::PCMPISTR;
26323     else
26324       Opcode = X86ISD::PCMPESTR;
26325 
26326     SmallVector<SDValue, 5> NewOps(llvm::drop_begin(Op->ops()));
26327     SDVTList VTs = DAG.getVTList(MVT::i32, MVT::v16i8, MVT::i32);
26328     return DAG.getNode(Opcode, dl, VTs, NewOps);
26329   }
26330 
26331   case Intrinsic::x86_sse42_pcmpistrm128:
26332   case Intrinsic::x86_sse42_pcmpestrm128: {
26333     unsigned Opcode;
26334     if (IntNo == Intrinsic::x86_sse42_pcmpistrm128)
26335       Opcode = X86ISD::PCMPISTR;
26336     else
26337       Opcode = X86ISD::PCMPESTR;
26338 
26339     SmallVector<SDValue, 5> NewOps(llvm::drop_begin(Op->ops()));
26340     SDVTList VTs = DAG.getVTList(MVT::i32, MVT::v16i8, MVT::i32);
26341     return DAG.getNode(Opcode, dl, VTs, NewOps).getValue(1);
26342   }
26343 
26344   case Intrinsic::eh_sjlj_lsda: {
26345     MachineFunction &MF = DAG.getMachineFunction();
26346     const TargetLowering &TLI = DAG.getTargetLoweringInfo();
26347     MVT PtrVT = TLI.getPointerTy(DAG.getDataLayout());
26348     auto &Context = MF.getContext();
26349     MCSymbol *S = Context.getOrCreateSymbol(Twine("GCC_except_table") +
26350                                             Twine(MF.getFunctionNumber()));
26351     return DAG.getNode(getGlobalWrapperKind(nullptr, /*OpFlags=*/0), dl, VT,
26352                        DAG.getMCSymbol(S, PtrVT));
26353   }
26354 
26355   case Intrinsic::x86_seh_lsda: {
26356     // Compute the symbol for the LSDA. We know it'll get emitted later.
26357     MachineFunction &MF = DAG.getMachineFunction();
26358     SDValue Op1 = Op.getOperand(1);
26359     auto *Fn = cast<Function>(cast<GlobalAddressSDNode>(Op1)->getGlobal());
26360     MCSymbol *LSDASym = MF.getContext().getOrCreateLSDASymbol(
26361         GlobalValue::dropLLVMManglingEscape(Fn->getName()));
26362 
26363     // Generate a simple absolute symbol reference. This intrinsic is only
26364     // supported on 32-bit Windows, which isn't PIC.
26365     SDValue Result = DAG.getMCSymbol(LSDASym, VT);
26366     return DAG.getNode(X86ISD::Wrapper, dl, VT, Result);
26367   }
26368 
26369   case Intrinsic::eh_recoverfp: {
26370     SDValue FnOp = Op.getOperand(1);
26371     SDValue IncomingFPOp = Op.getOperand(2);
26372     GlobalAddressSDNode *GSD = dyn_cast<GlobalAddressSDNode>(FnOp);
26373     auto *Fn = dyn_cast_or_null<Function>(GSD ? GSD->getGlobal() : nullptr);
26374     if (!Fn)
26375       report_fatal_error(
26376           "llvm.eh.recoverfp must take a function as the first argument");
26377     return recoverFramePointer(DAG, Fn, IncomingFPOp);
26378   }
26379 
26380   case Intrinsic::localaddress: {
26381     // Returns one of the stack, base, or frame pointer registers, depending on
26382     // which is used to reference local variables.
26383     MachineFunction &MF = DAG.getMachineFunction();
26384     const X86RegisterInfo *RegInfo = Subtarget.getRegisterInfo();
26385     unsigned Reg;
26386     if (RegInfo->hasBasePointer(MF))
26387       Reg = RegInfo->getBaseRegister();
26388     else { // Handles the SP or FP case.
26389       bool CantUseFP = RegInfo->hasStackRealignment(MF);
26390       if (CantUseFP)
26391         Reg = RegInfo->getPtrSizedStackRegister(MF);
26392       else
26393         Reg = RegInfo->getPtrSizedFrameRegister(MF);
26394     }
26395     return DAG.getCopyFromReg(DAG.getEntryNode(), dl, Reg, VT);
26396   }
26397   case Intrinsic::x86_avx512_vp2intersect_q_512:
26398   case Intrinsic::x86_avx512_vp2intersect_q_256:
26399   case Intrinsic::x86_avx512_vp2intersect_q_128:
26400   case Intrinsic::x86_avx512_vp2intersect_d_512:
26401   case Intrinsic::x86_avx512_vp2intersect_d_256:
26402   case Intrinsic::x86_avx512_vp2intersect_d_128: {
26403     MVT MaskVT = Op.getSimpleValueType();
26404 
26405     SDVTList VTs = DAG.getVTList(MVT::Untyped, MVT::Other);
26406     SDLoc DL(Op);
26407 
26408     SDValue Operation =
26409         DAG.getNode(X86ISD::VP2INTERSECT, DL, VTs,
26410                     Op->getOperand(1), Op->getOperand(2));
26411 
26412     SDValue Result0 = DAG.getTargetExtractSubreg(X86::sub_mask_0, DL,
26413                                                  MaskVT, Operation);
26414     SDValue Result1 = DAG.getTargetExtractSubreg(X86::sub_mask_1, DL,
26415                                                  MaskVT, Operation);
26416     return DAG.getMergeValues({Result0, Result1}, DL);
26417   }
26418   case Intrinsic::x86_mmx_pslli_w:
26419   case Intrinsic::x86_mmx_pslli_d:
26420   case Intrinsic::x86_mmx_pslli_q:
26421   case Intrinsic::x86_mmx_psrli_w:
26422   case Intrinsic::x86_mmx_psrli_d:
26423   case Intrinsic::x86_mmx_psrli_q:
26424   case Intrinsic::x86_mmx_psrai_w:
26425   case Intrinsic::x86_mmx_psrai_d: {
26426     SDLoc DL(Op);
26427     SDValue ShAmt = Op.getOperand(2);
26428     // If the argument is a constant, convert it to a target constant.
26429     if (auto *C = dyn_cast<ConstantSDNode>(ShAmt)) {
26430       // Clamp out-of-bounds shift amounts, since they would otherwise be masked
26431       // to 8 bits, which may make them no longer out of bounds.
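      // (For example, a shift amount of 256 would be masked to 0 and act as no
      // shift at all; clamping to 255 keeps it out of range.)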
26432       unsigned ShiftAmount = C->getAPIntValue().getLimitedValue(255);
26433       if (ShiftAmount == 0)
26434         return Op.getOperand(1);
26435 
26436       return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, Op.getValueType(),
26437                          Op.getOperand(0), Op.getOperand(1),
26438                          DAG.getTargetConstant(ShiftAmount, DL, MVT::i32));
26439     }
26440 
26441     unsigned NewIntrinsic;
26442     switch (IntNo) {
26443     default: llvm_unreachable("Impossible intrinsic");  // Can't reach here.
26444     case Intrinsic::x86_mmx_pslli_w:
26445       NewIntrinsic = Intrinsic::x86_mmx_psll_w;
26446       break;
26447     case Intrinsic::x86_mmx_pslli_d:
26448       NewIntrinsic = Intrinsic::x86_mmx_psll_d;
26449       break;
26450     case Intrinsic::x86_mmx_pslli_q:
26451       NewIntrinsic = Intrinsic::x86_mmx_psll_q;
26452       break;
26453     case Intrinsic::x86_mmx_psrli_w:
26454       NewIntrinsic = Intrinsic::x86_mmx_psrl_w;
26455       break;
26456     case Intrinsic::x86_mmx_psrli_d:
26457       NewIntrinsic = Intrinsic::x86_mmx_psrl_d;
26458       break;
26459     case Intrinsic::x86_mmx_psrli_q:
26460       NewIntrinsic = Intrinsic::x86_mmx_psrl_q;
26461       break;
26462     case Intrinsic::x86_mmx_psrai_w:
26463       NewIntrinsic = Intrinsic::x86_mmx_psra_w;
26464       break;
26465     case Intrinsic::x86_mmx_psrai_d:
26466       NewIntrinsic = Intrinsic::x86_mmx_psra_d;
26467       break;
26468     }
26469 
26470     // The vector shift intrinsics with scalars use 32-bit shift amounts, but
26471     // the SSE2/MMX shift instructions read 64 bits. Copy the 32 bits to an
26472     // MMX register.
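    // (MMX_MOVW2D places the 32-bit amount in the low half of an MMX register,
    // so the 64-bit shift instruction reads the intended count.)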
26473     ShAmt = DAG.getNode(X86ISD::MMX_MOVW2D, DL, MVT::x86mmx, ShAmt);
26474     return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, Op.getValueType(),
26475                        DAG.getTargetConstant(NewIntrinsic, DL,
26476                                              getPointerTy(DAG.getDataLayout())),
26477                        Op.getOperand(1), ShAmt);
26478   }
26479   case Intrinsic::thread_pointer: {
26480     if (Subtarget.isTargetELF()) {
26481       SDLoc dl(Op);
26482       EVT PtrVT = getPointerTy(DAG.getDataLayout());
26483       // Get the Thread Pointer, which is %gs:0 (32-bit) or %fs:0 (64-bit).
26484       Value *Ptr = Constant::getNullValue(PointerType::get(
26485           *DAG.getContext(), Subtarget.is64Bit() ? X86AS::FS : X86AS::GS));
26486       return DAG.getLoad(PtrVT, dl, DAG.getEntryNode(),
26487                          DAG.getIntPtrConstant(0, dl), MachinePointerInfo(Ptr));
26488     }
26489     report_fatal_error(
26490         "Target OS doesn't support __builtin_thread_pointer() yet.");
26491   }
26492   }
26493 }
26494 
26495 static SDValue getAVX2GatherNode(unsigned Opc, SDValue Op, SelectionDAG &DAG,
26496                                  SDValue Src, SDValue Mask, SDValue Base,
26497                                  SDValue Index, SDValue ScaleOp, SDValue Chain,
26498                                  const X86Subtarget &Subtarget) {
26499   SDLoc dl(Op);
26500   auto *C = dyn_cast<ConstantSDNode>(ScaleOp);
26501   // Scale must be constant.
26502   if (!C)
26503     return SDValue();
26504   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
26505   SDValue Scale = DAG.getTargetConstant(C->getZExtValue(), dl,
26506                                         TLI.getPointerTy(DAG.getDataLayout()));
26507   EVT MaskVT = Mask.getValueType().changeVectorElementTypeToInteger();
26508   SDVTList VTs = DAG.getVTList(Op.getValueType(), MVT::Other);
26509   // If source is undef or we know it won't be used, use a zero vector
26510   // to break register dependency.
26511   // TODO: use undef instead and let BreakFalseDeps deal with it?
26512   if (Src.isUndef() || ISD::isBuildVectorAllOnes(Mask.getNode()))
26513     Src = getZeroVector(Op.getSimpleValueType(), Subtarget, DAG, dl);
26514 
26515   // Cast mask to an integer type.
26516   Mask = DAG.getBitcast(MaskVT, Mask);
26517 
26518   MemIntrinsicSDNode *MemIntr = cast<MemIntrinsicSDNode>(Op);
26519 
26520   SDValue Ops[] = {Chain, Src, Mask, Base, Index, Scale };
26521   SDValue Res =
26522       DAG.getMemIntrinsicNode(X86ISD::MGATHER, dl, VTs, Ops,
26523                               MemIntr->getMemoryVT(), MemIntr->getMemOperand());
26524   return DAG.getMergeValues({Res, Res.getValue(1)}, dl);
26525 }
26526 
26527 static SDValue getGatherNode(SDValue Op, SelectionDAG &DAG,
26528                              SDValue Src, SDValue Mask, SDValue Base,
26529                              SDValue Index, SDValue ScaleOp, SDValue Chain,
26530                              const X86Subtarget &Subtarget) {
26531   MVT VT = Op.getSimpleValueType();
26532   SDLoc dl(Op);
26533   auto *C = dyn_cast<ConstantSDNode>(ScaleOp);
26534   // Scale must be constant.
26535   if (!C)
26536     return SDValue();
26537   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
26538   SDValue Scale = DAG.getTargetConstant(C->getZExtValue(), dl,
26539                                         TLI.getPointerTy(DAG.getDataLayout()));
26540   unsigned MinElts = std::min(Index.getSimpleValueType().getVectorNumElements(),
26541                               VT.getVectorNumElements());
26542   MVT MaskVT = MVT::getVectorVT(MVT::i1, MinElts);
26543 
26544   // We support two versions of the gather intrinsics. One with scalar mask and
26545   // one with vXi1 mask. Convert scalar to vXi1 if necessary.
26546   if (Mask.getValueType() != MaskVT)
26547     Mask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl);
26548 
26549   SDVTList VTs = DAG.getVTList(Op.getValueType(), MVT::Other);
26550   // If source is undef or we know it won't be used, use a zero vector
26551   // to break register dependency.
26552   // TODO: use undef instead and let BreakFalseDeps deal with it?
26553   if (Src.isUndef() || ISD::isBuildVectorAllOnes(Mask.getNode()))
26554     Src = getZeroVector(Op.getSimpleValueType(), Subtarget, DAG, dl);
26555 
26556   MemIntrinsicSDNode *MemIntr = cast<MemIntrinsicSDNode>(Op);
26557 
26558   SDValue Ops[] = {Chain, Src, Mask, Base, Index, Scale };
26559   SDValue Res =
26560       DAG.getMemIntrinsicNode(X86ISD::MGATHER, dl, VTs, Ops,
26561                               MemIntr->getMemoryVT(), MemIntr->getMemOperand());
26562   return DAG.getMergeValues({Res, Res.getValue(1)}, dl);
26563 }
26564 
26565 static SDValue getScatterNode(unsigned Opc, SDValue Op, SelectionDAG &DAG,
26566                                SDValue Src, SDValue Mask, SDValue Base,
26567                                SDValue Index, SDValue ScaleOp, SDValue Chain,
26568                                const X86Subtarget &Subtarget) {
26569   SDLoc dl(Op);
26570   auto *C = dyn_cast<ConstantSDNode>(ScaleOp);
26571   // Scale must be constant.
26572   if (!C)
26573     return SDValue();
26574   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
26575   SDValue Scale = DAG.getTargetConstant(C->getZExtValue(), dl,
26576                                         TLI.getPointerTy(DAG.getDataLayout()));
26577   unsigned MinElts = std::min(Index.getSimpleValueType().getVectorNumElements(),
26578                               Src.getSimpleValueType().getVectorNumElements());
26579   MVT MaskVT = MVT::getVectorVT(MVT::i1, MinElts);
26580 
26581   // We support two versions of the scatter intrinsics. One with scalar mask and
26582   // one with vXi1 mask. Convert scalar to vXi1 if necessary.
26583   if (Mask.getValueType() != MaskVT)
26584     Mask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl);
26585 
26586   MemIntrinsicSDNode *MemIntr = cast<MemIntrinsicSDNode>(Op);
26587 
26588   SDVTList VTs = DAG.getVTList(MVT::Other);
26589   SDValue Ops[] = {Chain, Src, Mask, Base, Index, Scale};
26590   SDValue Res =
26591       DAG.getMemIntrinsicNode(X86ISD::MSCATTER, dl, VTs, Ops,
26592                               MemIntr->getMemoryVT(), MemIntr->getMemOperand());
26593   return Res;
26594 }
26595 
26596 static SDValue getPrefetchNode(unsigned Opc, SDValue Op, SelectionDAG &DAG,
26597                                SDValue Mask, SDValue Base, SDValue Index,
26598                                SDValue ScaleOp, SDValue Chain,
26599                                const X86Subtarget &Subtarget) {
26600   SDLoc dl(Op);
26601   auto *C = dyn_cast<ConstantSDNode>(ScaleOp);
26602   // Scale must be constant.
26603   if (!C)
26604     return SDValue();
26605   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
26606   SDValue Scale = DAG.getTargetConstant(C->getZExtValue(), dl,
26607                                         TLI.getPointerTy(DAG.getDataLayout()));
26608   SDValue Disp = DAG.getTargetConstant(0, dl, MVT::i32);
26609   SDValue Segment = DAG.getRegister(0, MVT::i32);
26610   MVT MaskVT =
26611     MVT::getVectorVT(MVT::i1, Index.getSimpleValueType().getVectorNumElements());
26612   SDValue VMask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl);
26613   SDValue Ops[] = {VMask, Base, Scale, Index, Disp, Segment, Chain};
26614   SDNode *Res = DAG.getMachineNode(Opc, dl, MVT::Other, Ops);
26615   return SDValue(Res, 0);
26616 }
26617 
26618 /// Handles the lowering of builtin intrinsics with chain that return their
26619 /// value in registers EDX:EAX.
26620 /// If operand SrcReg is a valid register identifier, then operand 2 of N is
26621 /// copied to SrcReg. The assumption is that SrcReg is an implicit input to
26622 /// TargetOpcode.
26623 /// Returns a Glue value which can be used to add an extra copy-from-reg if the
26624 /// expanded intrinsic implicitly defines extra registers (i.e. not just
26625 /// EDX:EAX).
26626 static SDValue expandIntrinsicWChainHelper(SDNode *N, const SDLoc &DL,
26627                                         SelectionDAG &DAG,
26628                                         unsigned TargetOpcode,
26629                                         unsigned SrcReg,
26630                                         const X86Subtarget &Subtarget,
26631                                         SmallVectorImpl<SDValue> &Results) {
26632   SDValue Chain = N->getOperand(0);
26633   SDValue Glue;
26634 
26635   if (SrcReg) {
26636     assert(N->getNumOperands() == 3 && "Unexpected number of operands!");
26637     Chain = DAG.getCopyToReg(Chain, DL, SrcReg, N->getOperand(2), Glue);
26638     Glue = Chain.getValue(1);
26639   }
26640 
26641   SDVTList Tys = DAG.getVTList(MVT::Other, MVT::Glue);
26642   SDValue N1Ops[] = {Chain, Glue};
26643   SDNode *N1 = DAG.getMachineNode(
26644       TargetOpcode, DL, Tys, ArrayRef<SDValue>(N1Ops, Glue.getNode() ? 2 : 1));
26645   Chain = SDValue(N1, 0);
26646 
26647   // The expanded instruction leaves its 64-bit result in registers EDX:EAX.
26648   SDValue LO, HI;
26649   if (Subtarget.is64Bit()) {
26650     LO = DAG.getCopyFromReg(Chain, DL, X86::RAX, MVT::i64, SDValue(N1, 1));
26651     HI = DAG.getCopyFromReg(LO.getValue(1), DL, X86::RDX, MVT::i64,
26652                             LO.getValue(2));
26653   } else {
26654     LO = DAG.getCopyFromReg(Chain, DL, X86::EAX, MVT::i32, SDValue(N1, 1));
26655     HI = DAG.getCopyFromReg(LO.getValue(1), DL, X86::EDX, MVT::i32,
26656                             LO.getValue(2));
26657   }
26658   Chain = HI.getValue(1);
26659   Glue = HI.getValue(2);
26660 
26661   if (Subtarget.is64Bit()) {
26662     // Merge the two 32-bit values into a 64-bit one.
26663     SDValue Tmp = DAG.getNode(ISD::SHL, DL, MVT::i64, HI,
26664                               DAG.getConstant(32, DL, MVT::i8));
26665     Results.push_back(DAG.getNode(ISD::OR, DL, MVT::i64, LO, Tmp));
26666     Results.push_back(Chain);
26667     return Glue;
26668   }
26669 
26670   // Use a buildpair to merge the two 32-bit values into a 64-bit one.
26671   SDValue Ops[] = { LO, HI };
26672   SDValue Pair = DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i64, Ops);
26673   Results.push_back(Pair);
26674   Results.push_back(Chain);
26675   return Glue;
26676 }
26677 
26678 /// Handles the lowering of builtin intrinsics that read the time stamp counter
26679 /// (x86_rdtsc and x86_rdtscp). This function is also used to custom lower
26680 /// READCYCLECOUNTER nodes.
26681 static void getReadTimeStampCounter(SDNode *N, const SDLoc &DL, unsigned Opcode,
26682                                     SelectionDAG &DAG,
26683                                     const X86Subtarget &Subtarget,
26684                                     SmallVectorImpl<SDValue> &Results) {
26685   // The processor's time-stamp counter (a 64-bit MSR) is stored into the
26686   // EDX:EAX registers. EDX is loaded with the high-order 32 bits of the MSR
26687   // and the EAX register is loaded with the low-order 32 bits.
26688   SDValue Glue = expandIntrinsicWChainHelper(N, DL, DAG, Opcode,
26689                                              /* NoRegister */0, Subtarget,
26690                                              Results);
26691   if (Opcode != X86::RDTSCP)
26692     return;
26693 
26694   SDValue Chain = Results[1];
26695   // Instruction RDTSCP loads the IA32_TSC_AUX MSR (address C000_0103H) into
26696   // the ECX register. Add 'ecx' explicitly to the chain.
26697   SDValue ecx = DAG.getCopyFromReg(Chain, DL, X86::ECX, MVT::i32, Glue);
26698   Results[1] = ecx;
26699   Results.push_back(ecx.getValue(1));
26700 }
26701 
26702 static SDValue LowerREADCYCLECOUNTER(SDValue Op, const X86Subtarget &Subtarget,
26703                                      SelectionDAG &DAG) {
26704   SmallVector<SDValue, 3> Results;
26705   SDLoc DL(Op);
26706   getReadTimeStampCounter(Op.getNode(), DL, X86::RDTSC, DAG, Subtarget,
26707                           Results);
26708   return DAG.getMergeValues(Results, DL);
26709 }
26710 
26711 static SDValue MarkEHRegistrationNode(SDValue Op, SelectionDAG &DAG) {
26712   MachineFunction &MF = DAG.getMachineFunction();
26713   SDValue Chain = Op.getOperand(0);
26714   SDValue RegNode = Op.getOperand(2);
26715   WinEHFuncInfo *EHInfo = MF.getWinEHFuncInfo();
26716   if (!EHInfo)
26717     report_fatal_error("EH registrations only live in functions using WinEH");
26718 
26719   // Cast the operand to an alloca, and remember the frame index.
26720   auto *FINode = dyn_cast<FrameIndexSDNode>(RegNode);
26721   if (!FINode)
26722     report_fatal_error("llvm.x86.seh.ehregnode expects a static alloca");
26723   EHInfo->EHRegNodeFrameIndex = FINode->getIndex();
26724 
26725   // Return the chain operand without making any DAG nodes.
26726   return Chain;
26727 }
26728 
26729 static SDValue MarkEHGuard(SDValue Op, SelectionDAG &DAG) {
26730   MachineFunction &MF = DAG.getMachineFunction();
26731   SDValue Chain = Op.getOperand(0);
26732   SDValue EHGuard = Op.getOperand(2);
26733   WinEHFuncInfo *EHInfo = MF.getWinEHFuncInfo();
26734   if (!EHInfo)
26735     report_fatal_error("EHGuard only lives in functions using WinEH");
26736 
26737   // Cast the operand to an alloca, and remember the frame index.
26738   auto *FINode = dyn_cast<FrameIndexSDNode>(EHGuard);
26739   if (!FINode)
26740     report_fatal_error("llvm.x86.seh.ehguard expects a static alloca");
26741   EHInfo->EHGuardFrameIndex = FINode->getIndex();
26742 
26743   // Return the chain operand without making any DAG nodes.
26744   return Chain;
26745 }
26746 
26747 /// Emit Truncating Store with signed or unsigned saturation.
26748 static SDValue
26749 EmitTruncSStore(bool SignedSat, SDValue Chain, const SDLoc &DL, SDValue Val,
26750                 SDValue Ptr, EVT MemVT, MachineMemOperand *MMO,
26751                 SelectionDAG &DAG) {
26752   SDVTList VTs = DAG.getVTList(MVT::Other);
26753   SDValue Undef = DAG.getUNDEF(Ptr.getValueType());
26754   SDValue Ops[] = { Chain, Val, Ptr, Undef };
26755   unsigned Opc = SignedSat ? X86ISD::VTRUNCSTORES : X86ISD::VTRUNCSTOREUS;
26756   return DAG.getMemIntrinsicNode(Opc, DL, VTs, Ops, MemVT, MMO);
26757 }
26758 
26759 /// Emit Masked Truncating Store with signed or unsigned saturation.
26760 static SDValue EmitMaskedTruncSStore(bool SignedSat, SDValue Chain,
26761                                      const SDLoc &DL, SDValue Val,
26762                                      SDValue Ptr, SDValue Mask, EVT MemVT,
26763                                      MachineMemOperand *MMO, SelectionDAG &DAG) {
26764   SDVTList VTs = DAG.getVTList(MVT::Other);
26765   SDValue Ops[] = { Chain, Val, Ptr, Mask };
26766   unsigned Opc = SignedSat ? X86ISD::VMTRUNCSTORES : X86ISD::VMTRUNCSTOREUS;
26767   return DAG.getMemIntrinsicNode(Opc, DL, VTs, Ops, MemVT, MMO);
26768 }
26769 
26770 bool X86::isExtendedSwiftAsyncFrameSupported(const X86Subtarget &Subtarget,
26771                                              const MachineFunction &MF) {
26772   if (!Subtarget.is64Bit())
26773     return false;
26774   // 64-bit targets support extended Swift async frame setup,
26775   // except for targets that use the Windows 64 prologue.
26776   return !MF.getTarget().getMCAsmInfo()->usesWindowsCFI();
26777 }
26778 
26779 static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget &Subtarget,
26780                                       SelectionDAG &DAG) {
26781   unsigned IntNo = Op.getConstantOperandVal(1);
26782   const IntrinsicData *IntrData = getIntrinsicWithChain(IntNo);
26783   if (!IntrData) {
26784     switch (IntNo) {
26785 
26786     case Intrinsic::swift_async_context_addr: {
26787       SDLoc dl(Op);
26788       auto &MF = DAG.getMachineFunction();
26789       auto *X86FI = MF.getInfo<X86MachineFunctionInfo>();
26790       if (X86::isExtendedSwiftAsyncFrameSupported(Subtarget, MF)) {
26791         MF.getFrameInfo().setFrameAddressIsTaken(true);
26792         X86FI->setHasSwiftAsyncContext(true);
26793         SDValue Chain = Op->getOperand(0);
26794         SDValue CopyRBP = DAG.getCopyFromReg(Chain, dl, X86::RBP, MVT::i64);
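        // With the extended frame, the Swift async context slot lives at
        // RBP - 8.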
26795         SDValue Result =
26796             SDValue(DAG.getMachineNode(X86::SUB64ri32, dl, MVT::i64, CopyRBP,
26797                                        DAG.getTargetConstant(8, dl, MVT::i32)),
26798                     0);
26799         // Return { result, chain }.
26800         return DAG.getNode(ISD::MERGE_VALUES, dl, Op->getVTList(), Result,
26801                            CopyRBP.getValue(1));
26802       } else {
26803         // No special extended frame; create or reuse an existing stack slot.
26804         int PtrSize = Subtarget.is64Bit() ? 8 : 4;
26805         if (!X86FI->getSwiftAsyncContextFrameIdx())
26806           X86FI->setSwiftAsyncContextFrameIdx(
26807               MF.getFrameInfo().CreateStackObject(PtrSize, Align(PtrSize),
26808                                                   false));
26809         SDValue Result =
26810             DAG.getFrameIndex(*X86FI->getSwiftAsyncContextFrameIdx(),
26811                               PtrSize == 8 ? MVT::i64 : MVT::i32);
26812         // Return { result, chain }.
26813         return DAG.getNode(ISD::MERGE_VALUES, dl, Op->getVTList(), Result,
26814                            Op->getOperand(0));
26815       }
26816     }
26817 
26818     case llvm::Intrinsic::x86_seh_ehregnode:
26819       return MarkEHRegistrationNode(Op, DAG);
26820     case llvm::Intrinsic::x86_seh_ehguard:
26821       return MarkEHGuard(Op, DAG);
26822     case llvm::Intrinsic::x86_rdpkru: {
26823       SDLoc dl(Op);
26824       SDVTList VTs = DAG.getVTList(MVT::i32, MVT::Other);
26825       // Create a RDPKRU node and pass 0 to the ECX parameter.
26826       return DAG.getNode(X86ISD::RDPKRU, dl, VTs, Op.getOperand(0),
26827                          DAG.getConstant(0, dl, MVT::i32));
26828     }
26829     case llvm::Intrinsic::x86_wrpkru: {
26830       SDLoc dl(Op);
26831       // Create a WRPKRU node, pass the input to the EAX parameter, and pass 0
26832       // to the EDX and ECX parameters.
26833       return DAG.getNode(X86ISD::WRPKRU, dl, MVT::Other,
26834                          Op.getOperand(0), Op.getOperand(2),
26835                          DAG.getConstant(0, dl, MVT::i32),
26836                          DAG.getConstant(0, dl, MVT::i32));
26837     }
26838     case llvm::Intrinsic::asan_check_memaccess: {
26839       // Mark this as adjustsStack because it will be lowered to a call.
26840       DAG.getMachineFunction().getFrameInfo().setAdjustsStack(true);
26841       // Don't do anything here, we will expand these intrinsics out later.
26842       return Op;
26843     }
26844     case llvm::Intrinsic::x86_flags_read_u32:
26845     case llvm::Intrinsic::x86_flags_read_u64:
26846     case llvm::Intrinsic::x86_flags_write_u32:
26847     case llvm::Intrinsic::x86_flags_write_u64: {
26848       // We need a frame pointer because this will get lowered to a PUSH/POP
26849       // sequence.
26850       MachineFrameInfo &MFI = DAG.getMachineFunction().getFrameInfo();
26851       MFI.setHasCopyImplyingStackAdjustment(true);
26852       // Don't do anything here, we will expand these intrinsics out later
26853       // during FinalizeISel in EmitInstrWithCustomInserter.
26854       return Op;
26855     }
26856     case Intrinsic::x86_lwpins32:
26857     case Intrinsic::x86_lwpins64:
26858     case Intrinsic::x86_umwait:
26859     case Intrinsic::x86_tpause: {
26860       SDLoc dl(Op);
26861       SDValue Chain = Op->getOperand(0);
26862       SDVTList VTs = DAG.getVTList(MVT::i32, MVT::Other);
26863       unsigned Opcode;
26864 
26865       switch (IntNo) {
26866       default: llvm_unreachable("Impossible intrinsic");
26867       case Intrinsic::x86_umwait:
26868         Opcode = X86ISD::UMWAIT;
26869         break;
26870       case Intrinsic::x86_tpause:
26871         Opcode = X86ISD::TPAUSE;
26872         break;
26873       case Intrinsic::x86_lwpins32:
26874       case Intrinsic::x86_lwpins64:
26875         Opcode = X86ISD::LWPINS;
26876         break;
26877       }
26878 
26879       SDValue Operation =
26880           DAG.getNode(Opcode, dl, VTs, Chain, Op->getOperand(2),
26881                       Op->getOperand(3), Op->getOperand(4));
26882       SDValue SetCC = getSETCC(X86::COND_B, Operation.getValue(0), dl, DAG);
26883       return DAG.getNode(ISD::MERGE_VALUES, dl, Op->getVTList(), SetCC,
26884                          Operation.getValue(1));
26885     }
26886     case Intrinsic::x86_enqcmd:
26887     case Intrinsic::x86_enqcmds: {
26888       SDLoc dl(Op);
26889       SDValue Chain = Op.getOperand(0);
26890       SDVTList VTs = DAG.getVTList(MVT::i32, MVT::Other);
26891       unsigned Opcode;
26892       switch (IntNo) {
26893       default: llvm_unreachable("Impossible intrinsic!");
26894       case Intrinsic::x86_enqcmd:
26895         Opcode = X86ISD::ENQCMD;
26896         break;
26897       case Intrinsic::x86_enqcmds:
26898         Opcode = X86ISD::ENQCMDS;
26899         break;
26900       }
26901       SDValue Operation = DAG.getNode(Opcode, dl, VTs, Chain, Op.getOperand(2),
26902                                       Op.getOperand(3));
26903       SDValue SetCC = getSETCC(X86::COND_E, Operation.getValue(0), dl, DAG);
26904       return DAG.getNode(ISD::MERGE_VALUES, dl, Op->getVTList(), SetCC,
26905                          Operation.getValue(1));
26906     }
26907     case Intrinsic::x86_aesenc128kl:
26908     case Intrinsic::x86_aesdec128kl:
26909     case Intrinsic::x86_aesenc256kl:
26910     case Intrinsic::x86_aesdec256kl: {
26911       SDLoc DL(Op);
26912       SDVTList VTs = DAG.getVTList(MVT::v2i64, MVT::i32, MVT::Other);
26913       SDValue Chain = Op.getOperand(0);
26914       unsigned Opcode;
26915 
26916       switch (IntNo) {
26917       default: llvm_unreachable("Impossible intrinsic");
26918       case Intrinsic::x86_aesenc128kl:
26919         Opcode = X86ISD::AESENC128KL;
26920         break;
26921       case Intrinsic::x86_aesdec128kl:
26922         Opcode = X86ISD::AESDEC128KL;
26923         break;
26924       case Intrinsic::x86_aesenc256kl:
26925         Opcode = X86ISD::AESENC256KL;
26926         break;
26927       case Intrinsic::x86_aesdec256kl:
26928         Opcode = X86ISD::AESDEC256KL;
26929         break;
26930       }
26931 
26932       MemIntrinsicSDNode *MemIntr = cast<MemIntrinsicSDNode>(Op);
26933       MachineMemOperand *MMO = MemIntr->getMemOperand();
26934       EVT MemVT = MemIntr->getMemoryVT();
26935       SDValue Operation = DAG.getMemIntrinsicNode(
26936           Opcode, DL, VTs, {Chain, Op.getOperand(2), Op.getOperand(3)}, MemVT,
26937           MMO);
26938       SDValue ZF = getSETCC(X86::COND_E, Operation.getValue(1), DL, DAG);
26939 
26940       return DAG.getNode(ISD::MERGE_VALUES, DL, Op->getVTList(),
26941                          {ZF, Operation.getValue(0), Operation.getValue(2)});
26942     }
26943     case Intrinsic::x86_aesencwide128kl:
26944     case Intrinsic::x86_aesdecwide128kl:
26945     case Intrinsic::x86_aesencwide256kl:
26946     case Intrinsic::x86_aesdecwide256kl: {
26947       SDLoc DL(Op);
26948       SDVTList VTs = DAG.getVTList(
26949           {MVT::i32, MVT::v2i64, MVT::v2i64, MVT::v2i64, MVT::v2i64, MVT::v2i64,
26950            MVT::v2i64, MVT::v2i64, MVT::v2i64, MVT::Other});
26951       SDValue Chain = Op.getOperand(0);
26952       unsigned Opcode;
26953 
26954       switch (IntNo) {
26955       default: llvm_unreachable("Impossible intrinsic");
26956       case Intrinsic::x86_aesencwide128kl:
26957         Opcode = X86ISD::AESENCWIDE128KL;
26958         break;
26959       case Intrinsic::x86_aesdecwide128kl:
26960         Opcode = X86ISD::AESDECWIDE128KL;
26961         break;
26962       case Intrinsic::x86_aesencwide256kl:
26963         Opcode = X86ISD::AESENCWIDE256KL;
26964         break;
26965       case Intrinsic::x86_aesdecwide256kl:
26966         Opcode = X86ISD::AESDECWIDE256KL;
26967         break;
26968       }
26969 
26970       MemIntrinsicSDNode *MemIntr = cast<MemIntrinsicSDNode>(Op);
26971       MachineMemOperand *MMO = MemIntr->getMemOperand();
26972       EVT MemVT = MemIntr->getMemoryVT();
26973       SDValue Operation = DAG.getMemIntrinsicNode(
26974           Opcode, DL, VTs,
26975           {Chain, Op.getOperand(2), Op.getOperand(3), Op.getOperand(4),
26976            Op.getOperand(5), Op.getOperand(6), Op.getOperand(7),
26977            Op.getOperand(8), Op.getOperand(9), Op.getOperand(10)},
26978           MemVT, MMO);
26979       SDValue ZF = getSETCC(X86::COND_E, Operation.getValue(0), DL, DAG);
26980 
26981       return DAG.getNode(ISD::MERGE_VALUES, DL, Op->getVTList(),
26982                          {ZF, Operation.getValue(1), Operation.getValue(2),
26983                           Operation.getValue(3), Operation.getValue(4),
26984                           Operation.getValue(5), Operation.getValue(6),
26985                           Operation.getValue(7), Operation.getValue(8),
26986                           Operation.getValue(9)});
26987     }
26988     case Intrinsic::x86_testui: {
26989       SDLoc dl(Op);
26990       SDValue Chain = Op.getOperand(0);
26991       SDVTList VTs = DAG.getVTList(MVT::i32, MVT::Other);
26992       SDValue Operation = DAG.getNode(X86ISD::TESTUI, dl, VTs, Chain);
26993       SDValue SetCC = getSETCC(X86::COND_B, Operation.getValue(0), dl, DAG);
26994       return DAG.getNode(ISD::MERGE_VALUES, dl, Op->getVTList(), SetCC,
26995                          Operation.getValue(1));
26996     }
26997     case Intrinsic::x86_atomic_bts_rm:
26998     case Intrinsic::x86_atomic_btc_rm:
26999     case Intrinsic::x86_atomic_btr_rm: {
27000       SDLoc DL(Op);
27001       MVT VT = Op.getSimpleValueType();
27002       SDValue Chain = Op.getOperand(0);
27003       SDValue Op1 = Op.getOperand(2);
27004       SDValue Op2 = Op.getOperand(3);
27005       unsigned Opc = IntNo == Intrinsic::x86_atomic_bts_rm   ? X86ISD::LBTS_RM
27006                      : IntNo == Intrinsic::x86_atomic_btc_rm ? X86ISD::LBTC_RM
27007                                                              : X86ISD::LBTR_RM;
27008       MachineMemOperand *MMO = cast<MemIntrinsicSDNode>(Op)->getMemOperand();
27009       SDValue Res =
27010           DAG.getMemIntrinsicNode(Opc, DL, DAG.getVTList(MVT::i32, MVT::Other),
27011                                   {Chain, Op1, Op2}, VT, MMO);
27012       Chain = Res.getValue(1);
27013       Res = DAG.getZExtOrTrunc(getSETCC(X86::COND_B, Res, DL, DAG), DL, VT);
27014       return DAG.getNode(ISD::MERGE_VALUES, DL, Op->getVTList(), Res, Chain);
27015     }
27016     case Intrinsic::x86_atomic_bts:
27017     case Intrinsic::x86_atomic_btc:
27018     case Intrinsic::x86_atomic_btr: {
27019       SDLoc DL(Op);
27020       MVT VT = Op.getSimpleValueType();
27021       SDValue Chain = Op.getOperand(0);
27022       SDValue Op1 = Op.getOperand(2);
27023       SDValue Op2 = Op.getOperand(3);
27024       unsigned Opc = IntNo == Intrinsic::x86_atomic_bts   ? X86ISD::LBTS
27025                      : IntNo == Intrinsic::x86_atomic_btc ? X86ISD::LBTC
27026                                                           : X86ISD::LBTR;
27027       SDValue Size = DAG.getConstant(VT.getScalarSizeInBits(), DL, MVT::i32);
27028       MachineMemOperand *MMO = cast<MemIntrinsicSDNode>(Op)->getMemOperand();
27029       SDValue Res =
27030           DAG.getMemIntrinsicNode(Opc, DL, DAG.getVTList(MVT::i32, MVT::Other),
27031                                   {Chain, Op1, Op2, Size}, VT, MMO);
27032       Chain = Res.getValue(1);
27033       Res = DAG.getZExtOrTrunc(getSETCC(X86::COND_B, Res, DL, DAG), DL, VT);
27034       unsigned Imm = Op2->getAsZExtVal();
27035       if (Imm)
27036         Res = DAG.getNode(ISD::SHL, DL, VT, Res,
27037                           DAG.getShiftAmountConstant(Imm, VT, DL));
27038       return DAG.getNode(ISD::MERGE_VALUES, DL, Op->getVTList(), Res, Chain);
27039     }
27040     case Intrinsic::x86_cmpccxadd32:
27041     case Intrinsic::x86_cmpccxadd64: {
27042       SDLoc DL(Op);
27043       SDValue Chain = Op.getOperand(0);
27044       SDValue Addr = Op.getOperand(2);
27045       SDValue Src1 = Op.getOperand(3);
27046       SDValue Src2 = Op.getOperand(4);
27047       SDValue CC = Op.getOperand(5);
27048       MachineMemOperand *MMO = cast<MemIntrinsicSDNode>(Op)->getMemOperand();
27049       SDValue Operation = DAG.getMemIntrinsicNode(
27050           X86ISD::CMPCCXADD, DL, Op->getVTList(), {Chain, Addr, Src1, Src2, CC},
27051           MVT::i32, MMO);
27052       return Operation;
27053     }
27054     case Intrinsic::x86_aadd32:
27055     case Intrinsic::x86_aadd64:
27056     case Intrinsic::x86_aand32:
27057     case Intrinsic::x86_aand64:
27058     case Intrinsic::x86_aor32:
27059     case Intrinsic::x86_aor64:
27060     case Intrinsic::x86_axor32:
27061     case Intrinsic::x86_axor64: {
27062       SDLoc DL(Op);
27063       SDValue Chain = Op.getOperand(0);
27064       SDValue Op1 = Op.getOperand(2);
27065       SDValue Op2 = Op.getOperand(3);
27066       MVT VT = Op2.getSimpleValueType();
27067       unsigned Opc = 0;
27068       switch (IntNo) {
27069       default:
27070         llvm_unreachable("Unknown Intrinsic");
27071       case Intrinsic::x86_aadd32:
27072       case Intrinsic::x86_aadd64:
27073         Opc = X86ISD::AADD;
27074         break;
27075       case Intrinsic::x86_aand32:
27076       case Intrinsic::x86_aand64:
27077         Opc = X86ISD::AAND;
27078         break;
27079       case Intrinsic::x86_aor32:
27080       case Intrinsic::x86_aor64:
27081         Opc = X86ISD::AOR;
27082         break;
27083       case Intrinsic::x86_axor32:
27084       case Intrinsic::x86_axor64:
27085         Opc = X86ISD::AXOR;
27086         break;
27087       }
27088       MachineMemOperand *MMO = cast<MemSDNode>(Op)->getMemOperand();
27089       return DAG.getMemIntrinsicNode(Opc, DL, Op->getVTList(),
27090                                      {Chain, Op1, Op2}, VT, MMO);
27091     }
27092     case Intrinsic::x86_atomic_add_cc:
27093     case Intrinsic::x86_atomic_sub_cc:
27094     case Intrinsic::x86_atomic_or_cc:
27095     case Intrinsic::x86_atomic_and_cc:
27096     case Intrinsic::x86_atomic_xor_cc: {
27097       SDLoc DL(Op);
27098       SDValue Chain = Op.getOperand(0);
27099       SDValue Op1 = Op.getOperand(2);
27100       SDValue Op2 = Op.getOperand(3);
27101       X86::CondCode CC = (X86::CondCode)Op.getConstantOperandVal(4);
27102       MVT VT = Op2.getSimpleValueType();
27103       unsigned Opc = 0;
27104       switch (IntNo) {
27105       default:
27106         llvm_unreachable("Unknown Intrinsic");
27107       case Intrinsic::x86_atomic_add_cc:
27108         Opc = X86ISD::LADD;
27109         break;
27110       case Intrinsic::x86_atomic_sub_cc:
27111         Opc = X86ISD::LSUB;
27112         break;
27113       case Intrinsic::x86_atomic_or_cc:
27114         Opc = X86ISD::LOR;
27115         break;
27116       case Intrinsic::x86_atomic_and_cc:
27117         Opc = X86ISD::LAND;
27118         break;
27119       case Intrinsic::x86_atomic_xor_cc:
27120         Opc = X86ISD::LXOR;
27121         break;
27122       }
27123       MachineMemOperand *MMO = cast<MemIntrinsicSDNode>(Op)->getMemOperand();
27124       SDValue LockArith =
27125           DAG.getMemIntrinsicNode(Opc, DL, DAG.getVTList(MVT::i32, MVT::Other),
27126                                   {Chain, Op1, Op2}, VT, MMO);
27127       Chain = LockArith.getValue(1);
27128       return DAG.getMergeValues({getSETCC(CC, LockArith, DL, DAG), Chain}, DL);
27129     }
27130     }
27131     return SDValue();
27132   }
27133 
27134   SDLoc dl(Op);
27135   switch(IntrData->Type) {
27136   default: llvm_unreachable("Unknown Intrinsic Type");
27137   case RDSEED:
27138   case RDRAND: {
27139     // Emit the node with the right value type.
27140     SDVTList VTs = DAG.getVTList(Op->getValueType(0), MVT::i32, MVT::Other);
27141     SDValue Result = DAG.getNode(IntrData->Opc0, dl, VTs, Op.getOperand(0));
27142 
27143     // If the value returned by RDRAND/RDSEED was valid (CF=1), return 1.
27144     // Otherwise return the value from Rand, which is always 0, cast to i32.
27145     SDValue Ops[] = {DAG.getZExtOrTrunc(Result, dl, Op->getValueType(1)),
27146                      DAG.getConstant(1, dl, Op->getValueType(1)),
27147                      DAG.getTargetConstant(X86::COND_B, dl, MVT::i8),
27148                      SDValue(Result.getNode(), 1)};
27149     SDValue isValid = DAG.getNode(X86ISD::CMOV, dl, Op->getValueType(1), Ops);
27150 
27151     // Return { result, isValid, chain }.
27152     return DAG.getNode(ISD::MERGE_VALUES, dl, Op->getVTList(), Result, isValid,
27153                        SDValue(Result.getNode(), 2));
27154   }
27155   case GATHER_AVX2: {
27156     SDValue Chain = Op.getOperand(0);
27157     SDValue Src   = Op.getOperand(2);
27158     SDValue Base  = Op.getOperand(3);
27159     SDValue Index = Op.getOperand(4);
27160     SDValue Mask  = Op.getOperand(5);
27161     SDValue Scale = Op.getOperand(6);
27162     return getAVX2GatherNode(IntrData->Opc0, Op, DAG, Src, Mask, Base, Index,
27163                              Scale, Chain, Subtarget);
27164   }
27165   case GATHER: {
27166   // gather(v1, mask, index, base, scale);
27167     SDValue Chain = Op.getOperand(0);
27168     SDValue Src   = Op.getOperand(2);
27169     SDValue Base  = Op.getOperand(3);
27170     SDValue Index = Op.getOperand(4);
27171     SDValue Mask  = Op.getOperand(5);
27172     SDValue Scale = Op.getOperand(6);
27173     return getGatherNode(Op, DAG, Src, Mask, Base, Index, Scale,
27174                          Chain, Subtarget);
27175   }
27176   case SCATTER: {
27177   // scatter(base, mask, index, v1, scale);
27178     SDValue Chain = Op.getOperand(0);
27179     SDValue Base  = Op.getOperand(2);
27180     SDValue Mask  = Op.getOperand(3);
27181     SDValue Index = Op.getOperand(4);
27182     SDValue Src   = Op.getOperand(5);
27183     SDValue Scale = Op.getOperand(6);
27184     return getScatterNode(IntrData->Opc0, Op, DAG, Src, Mask, Base, Index,
27185                           Scale, Chain, Subtarget);
27186   }
27187   case PREFETCH: {
27188     const APInt &HintVal = Op.getConstantOperandAPInt(6);
27189     assert((HintVal == 2 || HintVal == 3) &&
27190            "Wrong prefetch hint in intrinsic: should be 2 or 3");
27191     unsigned Opcode = (HintVal == 2 ? IntrData->Opc1 : IntrData->Opc0);
27192     SDValue Chain = Op.getOperand(0);
27193     SDValue Mask  = Op.getOperand(2);
27194     SDValue Index = Op.getOperand(3);
27195     SDValue Base  = Op.getOperand(4);
27196     SDValue Scale = Op.getOperand(5);
27197     return getPrefetchNode(Opcode, Op, DAG, Mask, Base, Index, Scale, Chain,
27198                            Subtarget);
27199   }
27200   // Read Time Stamp Counter (RDTSC) and Processor ID (RDTSCP).
27201   case RDTSC: {
27202     SmallVector<SDValue, 2> Results;
27203     getReadTimeStampCounter(Op.getNode(), dl, IntrData->Opc0, DAG, Subtarget,
27204                             Results);
27205     return DAG.getMergeValues(Results, dl);
27206   }
27207   // Read Performance Monitoring Counters.
27208   case RDPMC:
27209   // Read Processor Register.
27210   case RDPRU:
27211   // Get Extended Control Register.
27212   case XGETBV: {
27213     SmallVector<SDValue, 2> Results;
27214 
27215     // RDPMC uses ECX to select the index of the performance counter to read.
27216     // RDPRU uses ECX to select the processor register to read.
27217     // XGETBV uses ECX to select the index of the XCR register to return.
27218     // The result is stored into registers EDX:EAX.
27219     expandIntrinsicWChainHelper(Op.getNode(), dl, DAG, IntrData->Opc0, X86::ECX,
27220                                 Subtarget, Results);
27221     return DAG.getMergeValues(Results, dl);
27222   }
27223   // XTEST intrinsics.
27224   case XTEST: {
27225     SDVTList VTs = DAG.getVTList(Op->getValueType(0), MVT::Other);
27226     SDValue InTrans = DAG.getNode(IntrData->Opc0, dl, VTs, Op.getOperand(0));
27227 
27228     SDValue SetCC = getSETCC(X86::COND_NE, InTrans, dl, DAG);
27229     SDValue Ret = DAG.getNode(ISD::ZERO_EXTEND, dl, Op->getValueType(0), SetCC);
27230     return DAG.getNode(ISD::MERGE_VALUES, dl, Op->getVTList(),
27231                        Ret, SDValue(InTrans.getNode(), 1));
27232   }
27233   case TRUNCATE_TO_MEM_VI8:
27234   case TRUNCATE_TO_MEM_VI16:
27235   case TRUNCATE_TO_MEM_VI32: {
27236     SDValue Mask = Op.getOperand(4);
27237     SDValue DataToTruncate = Op.getOperand(3);
27238     SDValue Addr = Op.getOperand(2);
27239     SDValue Chain = Op.getOperand(0);
27240 
27241     MemIntrinsicSDNode *MemIntr = dyn_cast<MemIntrinsicSDNode>(Op);
27242     assert(MemIntr && "Expected MemIntrinsicSDNode!");
27243 
27244     EVT MemVT  = MemIntr->getMemoryVT();
27245 
27246     uint16_t TruncationOp = IntrData->Opc0;
27247     switch (TruncationOp) {
27248     case X86ISD::VTRUNC: {
27249       if (isAllOnesConstant(Mask)) // return just a truncate store
27250         return DAG.getTruncStore(Chain, dl, DataToTruncate, Addr, MemVT,
27251                                  MemIntr->getMemOperand());
27252 
27253       MVT MaskVT = MVT::getVectorVT(MVT::i1, MemVT.getVectorNumElements());
27254       SDValue VMask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl);
27255       SDValue Offset = DAG.getUNDEF(VMask.getValueType());
27256 
27257       return DAG.getMaskedStore(Chain, dl, DataToTruncate, Addr, Offset, VMask,
27258                                 MemVT, MemIntr->getMemOperand(), ISD::UNINDEXED,
27259                                 true /* truncating */);
27260     }
27261     case X86ISD::VTRUNCUS:
27262     case X86ISD::VTRUNCS: {
27263       bool IsSigned = (TruncationOp == X86ISD::VTRUNCS);
27264       if (isAllOnesConstant(Mask))
27265         return EmitTruncSStore(IsSigned, Chain, dl, DataToTruncate, Addr, MemVT,
27266                                MemIntr->getMemOperand(), DAG);
27267 
27268       MVT MaskVT = MVT::getVectorVT(MVT::i1, MemVT.getVectorNumElements());
27269       SDValue VMask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl);
27270 
27271       return EmitMaskedTruncSStore(IsSigned, Chain, dl, DataToTruncate, Addr,
27272                                    VMask, MemVT, MemIntr->getMemOperand(), DAG);
27273     }
27274     default:
27275       llvm_unreachable("Unsupported truncstore intrinsic");
27276     }
27277   }
27278   }
27279 }
27280 
27281 SDValue X86TargetLowering::LowerRETURNADDR(SDValue Op,
27282                                            SelectionDAG &DAG) const {
27283   MachineFrameInfo &MFI = DAG.getMachineFunction().getFrameInfo();
27284   MFI.setReturnAddressIsTaken(true);
27285 
27286   if (verifyReturnAddressArgumentIsConstant(Op, DAG))
27287     return SDValue();
27288 
27289   unsigned Depth = Op.getConstantOperandVal(0);
27290   SDLoc dl(Op);
27291   EVT PtrVT = getPointerTy(DAG.getDataLayout());
27292 
27293   if (Depth > 0) {
27294     SDValue FrameAddr = LowerFRAMEADDR(Op, DAG);
27295     const X86RegisterInfo *RegInfo = Subtarget.getRegisterInfo();
27296     SDValue Offset = DAG.getConstant(RegInfo->getSlotSize(), dl, PtrVT);
27297     return DAG.getLoad(PtrVT, dl, DAG.getEntryNode(),
27298                        DAG.getNode(ISD::ADD, dl, PtrVT, FrameAddr, Offset),
27299                        MachinePointerInfo());
27300   }
27301 
27302   // Just load the return address.
27303   SDValue RetAddrFI = getReturnAddressFrameIndex(DAG);
27304   return DAG.getLoad(PtrVT, dl, DAG.getEntryNode(), RetAddrFI,
27305                      MachinePointerInfo());
27306 }
27307 
27308 SDValue X86TargetLowering::LowerADDROFRETURNADDR(SDValue Op,
27309                                                  SelectionDAG &DAG) const {
27310   DAG.getMachineFunction().getFrameInfo().setReturnAddressIsTaken(true);
27311   return getReturnAddressFrameIndex(DAG);
27312 }
27313 
27314 SDValue X86TargetLowering::LowerFRAMEADDR(SDValue Op, SelectionDAG &DAG) const {
27315   MachineFunction &MF = DAG.getMachineFunction();
27316   MachineFrameInfo &MFI = MF.getFrameInfo();
27317   X86MachineFunctionInfo *FuncInfo = MF.getInfo<X86MachineFunctionInfo>();
27318   const X86RegisterInfo *RegInfo = Subtarget.getRegisterInfo();
27319   EVT VT = Op.getValueType();
27320 
27321   MFI.setFrameAddressIsTaken(true);
27322 
27323   if (MF.getTarget().getMCAsmInfo()->usesWindowsCFI()) {
27324     // Depth > 0 makes no sense on targets which use Windows unwind codes.  It
27325     // is not possible to crawl up the stack without looking at the unwind codes
27326     // simultaneously.
27327     int FrameAddrIndex = FuncInfo->getFAIndex();
27328     if (!FrameAddrIndex) {
27329       // Set up a frame object for the return address.
27330       unsigned SlotSize = RegInfo->getSlotSize();
27331       FrameAddrIndex = MF.getFrameInfo().CreateFixedObject(
27332           SlotSize, /*SPOffset=*/0, /*IsImmutable=*/false);
27333       FuncInfo->setFAIndex(FrameAddrIndex);
27334     }
27335     return DAG.getFrameIndex(FrameAddrIndex, VT);
27336   }
27337 
27338   unsigned FrameReg =
27339       RegInfo->getPtrSizedFrameRegister(DAG.getMachineFunction());
27340   SDLoc dl(Op);  // FIXME probably not meaningful
27341   unsigned Depth = Op.getConstantOperandVal(0);
27342   assert(((FrameReg == X86::RBP && VT == MVT::i64) ||
27343           (FrameReg == X86::EBP && VT == MVT::i32)) &&
27344          "Invalid Frame Register!");
27345   SDValue FrameAddr = DAG.getCopyFromReg(DAG.getEntryNode(), dl, FrameReg, VT);
27346   while (Depth--)
27347     FrameAddr = DAG.getLoad(VT, dl, DAG.getEntryNode(), FrameAddr,
27348                             MachinePointerInfo());
27349   return FrameAddr;
27350 }
27351 
27352 // FIXME? Maybe this could be a TableGen attribute on some registers and
27353 // this table could be generated automatically from RegInfo.
27354 Register X86TargetLowering::getRegisterByName(const char* RegName, LLT VT,
27355                                               const MachineFunction &MF) const {
27356   const TargetFrameLowering &TFI = *Subtarget.getFrameLowering();
27357 
27358   Register Reg = StringSwitch<unsigned>(RegName)
27359                      .Case("esp", X86::ESP)
27360                      .Case("rsp", X86::RSP)
27361                      .Case("ebp", X86::EBP)
27362                      .Case("rbp", X86::RBP)
27363                      .Case("r14", X86::R14)
27364                      .Case("r15", X86::R15)
27365                      .Default(0);
27366 
27367   if (Reg == X86::EBP || Reg == X86::RBP) {
27368     if (!TFI.hasFP(MF))
27369       report_fatal_error("register " + StringRef(RegName) +
27370                          " is allocatable: function has no frame pointer");
27371 #ifndef NDEBUG
27372     else {
27373       const X86RegisterInfo *RegInfo = Subtarget.getRegisterInfo();
27374       Register FrameReg = RegInfo->getPtrSizedFrameRegister(MF);
27375       assert((FrameReg == X86::EBP || FrameReg == X86::RBP) &&
27376              "Invalid Frame Register!");
27377     }
27378 #endif
27379   }
27380 
27381   if (Reg)
27382     return Reg;
27383 
27384   report_fatal_error("Invalid register name global variable");
27385 }
27386 
27387 SDValue X86TargetLowering::LowerFRAME_TO_ARGS_OFFSET(SDValue Op,
27388                                                      SelectionDAG &DAG) const {
27389   const X86RegisterInfo *RegInfo = Subtarget.getRegisterInfo();
27390   return DAG.getIntPtrConstant(2 * RegInfo->getSlotSize(), SDLoc(Op));
27391 }
27392 
27393 Register X86TargetLowering::getExceptionPointerRegister(
27394     const Constant *PersonalityFn) const {
27395   if (classifyEHPersonality(PersonalityFn) == EHPersonality::CoreCLR)
27396     return Subtarget.isTarget64BitLP64() ? X86::RDX : X86::EDX;
27397 
27398   return Subtarget.isTarget64BitLP64() ? X86::RAX : X86::EAX;
27399 }
27400 
27401 Register X86TargetLowering::getExceptionSelectorRegister(
27402     const Constant *PersonalityFn) const {
27403   // Funclet personalities don't use selectors (the runtime does the selection).
27404   if (isFuncletEHPersonality(classifyEHPersonality(PersonalityFn)))
27405     return X86::NoRegister;
27406   return Subtarget.isTarget64BitLP64() ? X86::RDX : X86::EDX;
27407 }
27408 
27409 bool X86TargetLowering::needsFixedCatchObjects() const {
27410   return Subtarget.isTargetWin64();
27411 }
27412 
27413 SDValue X86TargetLowering::LowerEH_RETURN(SDValue Op, SelectionDAG &DAG) const {
27414   SDValue Chain     = Op.getOperand(0);
27415   SDValue Offset    = Op.getOperand(1);
27416   SDValue Handler   = Op.getOperand(2);
27417   SDLoc dl      (Op);
27418 
27419   EVT PtrVT = getPointerTy(DAG.getDataLayout());
27420   const X86RegisterInfo *RegInfo = Subtarget.getRegisterInfo();
27421   Register FrameReg = RegInfo->getFrameRegister(DAG.getMachineFunction());
27422   assert(((FrameReg == X86::RBP && PtrVT == MVT::i64) ||
27423           (FrameReg == X86::EBP && PtrVT == MVT::i32)) &&
27424          "Invalid Frame Register!");
27425   SDValue Frame = DAG.getCopyFromReg(DAG.getEntryNode(), dl, FrameReg, PtrVT);
27426   Register StoreAddrReg = (PtrVT == MVT::i64) ? X86::RCX : X86::ECX;
27427 
27428   SDValue StoreAddr = DAG.getNode(ISD::ADD, dl, PtrVT, Frame,
27429                                  DAG.getIntPtrConstant(RegInfo->getSlotSize(),
27430                                                        dl));
27431   StoreAddr = DAG.getNode(ISD::ADD, dl, PtrVT, StoreAddr, Offset);
27432   Chain = DAG.getStore(Chain, dl, Handler, StoreAddr, MachinePointerInfo());
27433   Chain = DAG.getCopyToReg(Chain, dl, StoreAddrReg, StoreAddr);
27434 
27435   return DAG.getNode(X86ISD::EH_RETURN, dl, MVT::Other, Chain,
27436                      DAG.getRegister(StoreAddrReg, PtrVT));
27437 }
27438 
27439 SDValue X86TargetLowering::lowerEH_SJLJ_SETJMP(SDValue Op,
27440                                                SelectionDAG &DAG) const {
27441   SDLoc DL(Op);
27442   // If the subtarget is not 64-bit, we may need the global base register
27443   // after the isel pseudo expansion, i.e., after the CGBR pass has run.
27444   // Therefore, ask for the GlobalBaseReg now, so that the pass
27445   // inserts the code for us in case we need it.
27446   // Otherwise, we would end up referencing a virtual register
27447   // that is never defined!
27448   if (!Subtarget.is64Bit()) {
27449     const X86InstrInfo *TII = Subtarget.getInstrInfo();
27450     (void)TII->getGlobalBaseReg(&DAG.getMachineFunction());
27451   }
27452   return DAG.getNode(X86ISD::EH_SJLJ_SETJMP, DL,
27453                      DAG.getVTList(MVT::i32, MVT::Other),
27454                      Op.getOperand(0), Op.getOperand(1));
27455 }
27456 
27457 SDValue X86TargetLowering::lowerEH_SJLJ_LONGJMP(SDValue Op,
27458                                                 SelectionDAG &DAG) const {
27459   SDLoc DL(Op);
27460   return DAG.getNode(X86ISD::EH_SJLJ_LONGJMP, DL, MVT::Other,
27461                      Op.getOperand(0), Op.getOperand(1));
27462 }
27463 
27464 SDValue X86TargetLowering::lowerEH_SJLJ_SETUP_DISPATCH(SDValue Op,
27465                                                        SelectionDAG &DAG) const {
27466   SDLoc DL(Op);
27467   return DAG.getNode(X86ISD::EH_SJLJ_SETUP_DISPATCH, DL, MVT::Other,
27468                      Op.getOperand(0));
27469 }
27470 
27471 static SDValue LowerADJUST_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) {
27472   return Op.getOperand(0);
27473 }
27474 
27475 SDValue X86TargetLowering::LowerINIT_TRAMPOLINE(SDValue Op,
27476                                                 SelectionDAG &DAG) const {
27477   SDValue Root = Op.getOperand(0);
27478   SDValue Trmp = Op.getOperand(1); // trampoline
27479   SDValue FPtr = Op.getOperand(2); // nested function
27480   SDValue Nest = Op.getOperand(3); // 'nest' parameter value
27481   SDLoc dl (Op);
27482 
27483   const Value *TrmpAddr = cast<SrcValueSDNode>(Op.getOperand(4))->getValue();
27484   const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo();
27485 
27486   if (Subtarget.is64Bit()) {
27487     SDValue OutChains[6];
27488 
27489     // Large code-model.
27490     const unsigned char JMP64r  = 0xFF; // 64-bit jmp through register opcode.
27491     const unsigned char MOV64ri = 0xB8; // X86::MOV64ri opcode.
27492 
27493     const unsigned char N86R10 = TRI->getEncodingValue(X86::R10) & 0x7;
27494     const unsigned char N86R11 = TRI->getEncodingValue(X86::R11) & 0x7;
27495 
27496     const unsigned char REX_WB = 0x40 | 0x08 | 0x01; // REX prefix
27497 
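          // Roughly, the 23-byte trampoline emitted by the stores below has
          // this layout (an informal sketch; offsets match the stores):
          //   +0:  49 BB <FPtr imm64>   movabsq $FPtr, %r11
          //   +10: 49 BA <Nest imm64>   movabsq $Nest, %r10
          //   +20: 49 FF E3             jmpq    *%r11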
27498     // Load the pointer to the nested function into R11.
27499     unsigned OpCode = ((MOV64ri | N86R11) << 8) | REX_WB; // movabsq r11
27500     SDValue Addr = Trmp;
27501     OutChains[0] = DAG.getStore(Root, dl, DAG.getConstant(OpCode, dl, MVT::i16),
27502                                 Addr, MachinePointerInfo(TrmpAddr));
27503 
27504     Addr = DAG.getNode(ISD::ADD, dl, MVT::i64, Trmp,
27505                        DAG.getConstant(2, dl, MVT::i64));
27506     OutChains[1] = DAG.getStore(Root, dl, FPtr, Addr,
27507                                 MachinePointerInfo(TrmpAddr, 2), Align(2));
27508 
27509     // Load the 'nest' parameter value into R10.
27510     // R10 is specified in X86CallingConv.td
27511     OpCode = ((MOV64ri | N86R10) << 8) | REX_WB; // movabsq r10
27512     Addr = DAG.getNode(ISD::ADD, dl, MVT::i64, Trmp,
27513                        DAG.getConstant(10, dl, MVT::i64));
27514     OutChains[2] = DAG.getStore(Root, dl, DAG.getConstant(OpCode, dl, MVT::i16),
27515                                 Addr, MachinePointerInfo(TrmpAddr, 10));
27516 
27517     Addr = DAG.getNode(ISD::ADD, dl, MVT::i64, Trmp,
27518                        DAG.getConstant(12, dl, MVT::i64));
27519     OutChains[3] = DAG.getStore(Root, dl, Nest, Addr,
27520                                 MachinePointerInfo(TrmpAddr, 12), Align(2));
27521 
27522     // Jump to the nested function.
27523     OpCode = (JMP64r << 8) | REX_WB; // jmpq *...
27524     Addr = DAG.getNode(ISD::ADD, dl, MVT::i64, Trmp,
27525                        DAG.getConstant(20, dl, MVT::i64));
27526     OutChains[4] = DAG.getStore(Root, dl, DAG.getConstant(OpCode, dl, MVT::i16),
27527                                 Addr, MachinePointerInfo(TrmpAddr, 20));
27528 
27529     unsigned char ModRM = N86R11 | (4 << 3) | (3 << 6); // ...r11
27530     Addr = DAG.getNode(ISD::ADD, dl, MVT::i64, Trmp,
27531                        DAG.getConstant(22, dl, MVT::i64));
27532     OutChains[5] = DAG.getStore(Root, dl, DAG.getConstant(ModRM, dl, MVT::i8),
27533                                 Addr, MachinePointerInfo(TrmpAddr, 22));
27534 
27535     return DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChains);
27536   } else {
27537     const Function *Func =
27538       cast<Function>(cast<SrcValueSDNode>(Op.getOperand(5))->getValue());
27539     CallingConv::ID CC = Func->getCallingConv();
27540     unsigned NestReg;
27541 
27542     switch (CC) {
27543     default:
27544       llvm_unreachable("Unsupported calling convention");
27545     case CallingConv::C:
27546     case CallingConv::X86_StdCall: {
27547       // Pass 'nest' parameter in ECX.
27548       // Must be kept in sync with X86CallingConv.td
27549       NestReg = X86::ECX;
27550 
27551       // Check that ECX wasn't needed by an 'inreg' parameter.
27552       FunctionType *FTy = Func->getFunctionType();
27553       const AttributeList &Attrs = Func->getAttributes();
27554 
27555       if (!Attrs.isEmpty() && !Func->isVarArg()) {
27556         unsigned InRegCount = 0;
27557         unsigned Idx = 0;
27558 
27559         for (FunctionType::param_iterator I = FTy->param_begin(),
27560              E = FTy->param_end(); I != E; ++I, ++Idx)
27561           if (Attrs.hasParamAttr(Idx, Attribute::InReg)) {
27562             const DataLayout &DL = DAG.getDataLayout();
27563             // FIXME: should only count parameters that are lowered to integers.
27564             InRegCount += (DL.getTypeSizeInBits(*I) + 31) / 32;
27565           }
27566 
27567         if (InRegCount > 2) {
27568           report_fatal_error("Nest register in use - reduce number of inreg"
27569                              " parameters!");
27570         }
27571       }
27572       break;
27573     }
27574     case CallingConv::X86_FastCall:
27575     case CallingConv::X86_ThisCall:
27576     case CallingConv::Fast:
27577     case CallingConv::Tail:
27578     case CallingConv::SwiftTail:
27579       // Pass 'nest' parameter in EAX.
27580       // Must be kept in sync with X86CallingConv.td
27581       NestReg = X86::EAX;
27582       break;
27583     }
27584 
27585     SDValue OutChains[4];
27586     SDValue Addr, Disp;
27587 
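          // Roughly, the 10-byte trampoline emitted by the stores below has
          // this layout (an informal sketch; NestReg is ECX or EAX):
          //   +0: B8+reg <Nest imm32>   movl $Nest, %NestReg
          //   +5: E9     <Disp rel32>   jmp  FPtr   (Disp = FPtr - (Trmp + 10))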
27588     Addr = DAG.getNode(ISD::ADD, dl, MVT::i32, Trmp,
27589                        DAG.getConstant(10, dl, MVT::i32));
27590     Disp = DAG.getNode(ISD::SUB, dl, MVT::i32, FPtr, Addr);
27591 
27592     // This is storing the opcode for MOV32ri.
27593     const unsigned char MOV32ri = 0xB8; // X86::MOV32ri's opcode byte.
27594     const unsigned char N86Reg = TRI->getEncodingValue(NestReg) & 0x7;
27595     OutChains[0] =
27596         DAG.getStore(Root, dl, DAG.getConstant(MOV32ri | N86Reg, dl, MVT::i8),
27597                      Trmp, MachinePointerInfo(TrmpAddr));
27598 
27599     Addr = DAG.getNode(ISD::ADD, dl, MVT::i32, Trmp,
27600                        DAG.getConstant(1, dl, MVT::i32));
27601     OutChains[1] = DAG.getStore(Root, dl, Nest, Addr,
27602                                 MachinePointerInfo(TrmpAddr, 1), Align(1));
27603 
27604     const unsigned char JMP = 0xE9; // jmp <32bit dst> opcode.
27605     Addr = DAG.getNode(ISD::ADD, dl, MVT::i32, Trmp,
27606                        DAG.getConstant(5, dl, MVT::i32));
27607     OutChains[2] =
27608         DAG.getStore(Root, dl, DAG.getConstant(JMP, dl, MVT::i8), Addr,
27609                      MachinePointerInfo(TrmpAddr, 5), Align(1));
27610 
27611     Addr = DAG.getNode(ISD::ADD, dl, MVT::i32, Trmp,
27612                        DAG.getConstant(6, dl, MVT::i32));
27613     OutChains[3] = DAG.getStore(Root, dl, Disp, Addr,
27614                                 MachinePointerInfo(TrmpAddr, 6), Align(1));
27615 
27616     return DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChains);
27617   }
27618 }
27619 
27620 SDValue X86TargetLowering::LowerGET_ROUNDING(SDValue Op,
27621                                              SelectionDAG &DAG) const {
27622   /*
27623    The rounding mode is in bits 11:10 of the FP Control Word (FPCW), and
27624    has the following settings:
27625      00 Round to nearest
27626      01 Round to -inf
27627      10 Round to +inf
27628      11 Round to 0
27629 
27630   GET_ROUNDING, on the other hand, expects the following:
27631     -1 Undefined
27632      0 Round to 0
27633      1 Round to nearest
27634      2 Round to +inf
27635      3 Round to -inf
27636 
27637   To perform the conversion, we use a packed lookup table of the four 2-bit
27638   values that we can index by FPCW[11:10]
27639     0x2d --> (0b00,10,11,01) --> (0,2,3,1) >> FPCW[11:10]
27640 
27641     (0x2d >> ((FPCW & 0xc00) >> 9)) & 3
27642   */
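        // For example (an informal check of the table above): FPCW[11:10] = 01
        // (round toward -inf) gives (0x2d >> ((0x400 & 0xc00) >> 9)) & 3
        // = (0x2d >> 2) & 3 = 3, the GET_ROUNDING value for -inf, while
        // FPCW[11:10] = 00 (round to nearest) gives (0x2d >> 0) & 3 = 1.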
27643 
27644   MachineFunction &MF = DAG.getMachineFunction();
27645   MVT VT = Op.getSimpleValueType();
27646   SDLoc DL(Op);
27647 
27648   // Save FP Control Word to stack slot
27649   int SSFI = MF.getFrameInfo().CreateStackObject(2, Align(2), false);
27650   SDValue StackSlot =
27651       DAG.getFrameIndex(SSFI, getPointerTy(DAG.getDataLayout()));
27652 
27653   MachinePointerInfo MPI = MachinePointerInfo::getFixedStack(MF, SSFI);
27654 
27655   SDValue Chain = Op.getOperand(0);
27656   SDValue Ops[] = {Chain, StackSlot};
27657   Chain = DAG.getMemIntrinsicNode(X86ISD::FNSTCW16m, DL,
27658                                   DAG.getVTList(MVT::Other), Ops, MVT::i16, MPI,
27659                                   Align(2), MachineMemOperand::MOStore);
27660 
27661   // Load FP Control Word from stack slot
27662   SDValue CWD = DAG.getLoad(MVT::i16, DL, Chain, StackSlot, MPI, Align(2));
27663   Chain = CWD.getValue(1);
27664 
27665   // Mask and turn the control bits into a shift for the lookup table.
27666   SDValue Shift =
27667     DAG.getNode(ISD::SRL, DL, MVT::i16,
27668                 DAG.getNode(ISD::AND, DL, MVT::i16,
27669                             CWD, DAG.getConstant(0xc00, DL, MVT::i16)),
27670                 DAG.getConstant(9, DL, MVT::i8));
27671   Shift = DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, Shift);
27672 
27673   SDValue LUT = DAG.getConstant(0x2d, DL, MVT::i32);
27674   SDValue RetVal =
27675     DAG.getNode(ISD::AND, DL, MVT::i32,
27676                 DAG.getNode(ISD::SRL, DL, MVT::i32, LUT, Shift),
27677                 DAG.getConstant(3, DL, MVT::i32));
27678 
27679   RetVal = DAG.getZExtOrTrunc(RetVal, DL, VT);
27680 
27681   return DAG.getMergeValues({RetVal, Chain}, DL);
27682 }
27683 
27684 SDValue X86TargetLowering::LowerSET_ROUNDING(SDValue Op,
27685                                              SelectionDAG &DAG) const {
27686   MachineFunction &MF = DAG.getMachineFunction();
27687   SDLoc DL(Op);
27688   SDValue Chain = Op.getNode()->getOperand(0);
27689 
27690   // The FP control word may be set only from data in memory, so we need to
27691   // allocate stack space to save/load it.
27692   int OldCWFrameIdx = MF.getFrameInfo().CreateStackObject(4, Align(4), false);
27693   SDValue StackSlot =
27694       DAG.getFrameIndex(OldCWFrameIdx, getPointerTy(DAG.getDataLayout()));
27695   MachinePointerInfo MPI = MachinePointerInfo::getFixedStack(MF, OldCWFrameIdx);
27696   MachineMemOperand *MMO =
27697       MF.getMachineMemOperand(MPI, MachineMemOperand::MOStore, 2, Align(2));
27698 
27699   // Store FP control word into memory.
27700   SDValue Ops[] = {Chain, StackSlot};
27701   Chain = DAG.getMemIntrinsicNode(
27702       X86ISD::FNSTCW16m, DL, DAG.getVTList(MVT::Other), Ops, MVT::i16, MMO);
27703 
27704   // Load FP Control Word from stack slot and clear RM field (bits 11:10).
27705   SDValue CWD = DAG.getLoad(MVT::i16, DL, Chain, StackSlot, MPI);
27706   Chain = CWD.getValue(1);
27707   CWD = DAG.getNode(ISD::AND, DL, MVT::i16, CWD.getValue(0),
27708                     DAG.getConstant(0xf3ff, DL, MVT::i16));
27709 
27710   // Calculate new rounding mode.
27711   SDValue NewRM = Op.getNode()->getOperand(1);
27712   SDValue RMBits;
27713   if (auto *CVal = dyn_cast<ConstantSDNode>(NewRM)) {
27714     uint64_t RM = CVal->getZExtValue();
27715     int FieldVal;
27716     switch (static_cast<RoundingMode>(RM)) {
27717     // clang-format off
27718     case RoundingMode::NearestTiesToEven: FieldVal = X86::rmToNearest; break;
27719     case RoundingMode::TowardNegative:    FieldVal = X86::rmDownward; break;
27720     case RoundingMode::TowardPositive:    FieldVal = X86::rmUpward; break;
27721     case RoundingMode::TowardZero:        FieldVal = X86::rmTowardZero; break;
27722     default:
27723       llvm_unreachable("rounding mode is not supported by X86 hardware");
27724     // clang-format on
27725     }
27726     RMBits = DAG.getConstant(FieldVal, DL, MVT::i16);
27727   } else {
27728     // Need to convert argument into bits of control word:
27729     //    0 Round to 0       -> 11
27730     //    1 Round to nearest -> 00
27731     //    2 Round to +inf    -> 10
27732     //    3 Round to -inf    -> 01
27733     // The 2-bit value then needs to be shifted so that it occupies bits 11:10.
27734     // To make the conversion, pack all these values into the constant 0xc9 and
27735     // shift it left depending on the rounding mode:
27736     //    (0xc9 << 4) & 0xc00 = X86::rmTowardZero
27737     //    (0xc9 << 6) & 0xc00 = X86::rmToNearest
27738     //    ...
27739     // (0xc9 << (2 * NewRM + 4)) & 0xc00
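          // For example (informal): NewRM = 3 (round to -inf) gives a shift of
          // 2 * 3 + 4 = 10, and (0xc9 << 10) & 0xc00 = 0x400, i.e. RM bits
          // 11:10 = 01 as required; NewRM = 0 (round to zero) gives
          // (0xc9 << 4) & 0xc00 = 0xc00, i.e. 11.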
27740     SDValue ShiftValue =
27741         DAG.getNode(ISD::TRUNCATE, DL, MVT::i8,
27742                     DAG.getNode(ISD::ADD, DL, MVT::i32,
27743                                 DAG.getNode(ISD::SHL, DL, MVT::i32, NewRM,
27744                                             DAG.getConstant(1, DL, MVT::i8)),
27745                                 DAG.getConstant(4, DL, MVT::i32)));
27746     SDValue Shifted =
27747         DAG.getNode(ISD::SHL, DL, MVT::i16, DAG.getConstant(0xc9, DL, MVT::i16),
27748                     ShiftValue);
27749     RMBits = DAG.getNode(ISD::AND, DL, MVT::i16, Shifted,
27750                          DAG.getConstant(0xc00, DL, MVT::i16));
27751   }
27752 
27753   // Update rounding mode bits and store the new FP Control Word into stack.
27754   CWD = DAG.getNode(ISD::OR, DL, MVT::i16, CWD, RMBits);
27755   Chain = DAG.getStore(Chain, DL, CWD, StackSlot, MPI, Align(2));
27756 
27757   // Load FP control word from the slot.
27758   SDValue OpsLD[] = {Chain, StackSlot};
27759   MachineMemOperand *MMOL =
27760       MF.getMachineMemOperand(MPI, MachineMemOperand::MOLoad, 2, Align(2));
27761   Chain = DAG.getMemIntrinsicNode(
27762       X86ISD::FLDCW16m, DL, DAG.getVTList(MVT::Other), OpsLD, MVT::i16, MMOL);
27763 
27764   // If target supports SSE, set MXCSR as well. Rounding mode is encoded in the
27765   // same way but in bits 14:13.
27766   if (Subtarget.hasSSE1()) {
27767     // Store MXCSR into memory.
27768     Chain = DAG.getNode(
27769         ISD::INTRINSIC_VOID, DL, DAG.getVTList(MVT::Other), Chain,
27770         DAG.getTargetConstant(Intrinsic::x86_sse_stmxcsr, DL, MVT::i32),
27771         StackSlot);
27772 
27773     // Load MXCSR from stack slot and clear RM field (bits 14:13).
27774     SDValue CWD = DAG.getLoad(MVT::i32, DL, Chain, StackSlot, MPI);
27775     Chain = CWD.getValue(1);
27776     CWD = DAG.getNode(ISD::AND, DL, MVT::i32, CWD.getValue(0),
27777                       DAG.getConstant(0xffff9fff, DL, MVT::i32));
27778 
27779     // Shift X87 RM bits from 11:10 to 14:13.
27780     RMBits = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, RMBits);
27781     RMBits = DAG.getNode(ISD::SHL, DL, MVT::i32, RMBits,
27782                          DAG.getConstant(3, DL, MVT::i8));
27783 
27784     // Update rounding mode bits and store the new FP Control Word into stack.
27785     CWD = DAG.getNode(ISD::OR, DL, MVT::i32, CWD, RMBits);
27786     Chain = DAG.getStore(Chain, DL, CWD, StackSlot, MPI, Align(4));
27787 
27788     // Load MXCSR from the slot.
27789     Chain = DAG.getNode(
27790         ISD::INTRINSIC_VOID, DL, DAG.getVTList(MVT::Other), Chain,
27791         DAG.getTargetConstant(Intrinsic::x86_sse_ldmxcsr, DL, MVT::i32),
27792         StackSlot);
27793   }
27794 
27795   return Chain;
27796 }
27797 
27798 const unsigned X87StateSize = 28;
27799 const unsigned FPStateSize = 32;
27800 [[maybe_unused]] const unsigned FPStateSizeInBits = FPStateSize * 8;
27801 
27802 SDValue X86TargetLowering::LowerGET_FPENV_MEM(SDValue Op,
27803                                               SelectionDAG &DAG) const {
27804   MachineFunction &MF = DAG.getMachineFunction();
27805   SDLoc DL(Op);
27806   SDValue Chain = Op->getOperand(0);
27807   SDValue Ptr = Op->getOperand(1);
27808   auto *Node = cast<FPStateAccessSDNode>(Op);
27809   EVT MemVT = Node->getMemoryVT();
27810   assert(MemVT.getSizeInBits() == FPStateSizeInBits);
27811   MachineMemOperand *MMO = cast<FPStateAccessSDNode>(Op)->getMemOperand();
27812 
27813   // Get the x87 state, if it is present.
27814   if (Subtarget.hasX87()) {
27815     Chain =
27816         DAG.getMemIntrinsicNode(X86ISD::FNSTENVm, DL, DAG.getVTList(MVT::Other),
27817                                 {Chain, Ptr}, MemVT, MMO);
27818 
27819     // FNSTENV changes the exception mask, so load back the stored environment.
27820     MachineMemOperand::Flags NewFlags =
27821         MachineMemOperand::MOLoad |
27822         (MMO->getFlags() & ~MachineMemOperand::MOStore);
27823     MMO = MF.getMachineMemOperand(MMO, NewFlags);
27824     Chain =
27825         DAG.getMemIntrinsicNode(X86ISD::FLDENVm, DL, DAG.getVTList(MVT::Other),
27826                                 {Chain, Ptr}, MemVT, MMO);
27827   }
27828 
27829   // If target supports SSE, get MXCSR as well.
27830   if (Subtarget.hasSSE1()) {
27831     // Get pointer to the MXCSR location in memory.
27832     MVT PtrVT = DAG.getTargetLoweringInfo().getPointerTy(DAG.getDataLayout());
27833     SDValue MXCSRAddr = DAG.getNode(ISD::ADD, DL, PtrVT, Ptr,
27834                                     DAG.getConstant(X87StateSize, DL, PtrVT));
27835     // Store MXCSR into memory.
27836     Chain = DAG.getNode(
27837         ISD::INTRINSIC_VOID, DL, DAG.getVTList(MVT::Other), Chain,
27838         DAG.getTargetConstant(Intrinsic::x86_sse_stmxcsr, DL, MVT::i32),
27839         MXCSRAddr);
27840   }
27841 
27842   return Chain;
27843 }
27844 
27845 static SDValue createSetFPEnvNodes(SDValue Ptr, SDValue Chain, const SDLoc &DL,
27846                                    EVT MemVT, MachineMemOperand *MMO,
27847                                    SelectionDAG &DAG,
27848                                    const X86Subtarget &Subtarget) {
27849   // Set the x87 state, if it is present.
27850   if (Subtarget.hasX87())
27851     Chain =
27852         DAG.getMemIntrinsicNode(X86ISD::FLDENVm, DL, DAG.getVTList(MVT::Other),
27853                                 {Chain, Ptr}, MemVT, MMO);
27854   // If target supports SSE, set MXCSR as well.
27855   if (Subtarget.hasSSE1()) {
27856     // Get pointer to the MXCSR location in memory.
27857     MVT PtrVT = DAG.getTargetLoweringInfo().getPointerTy(DAG.getDataLayout());
27858     SDValue MXCSRAddr = DAG.getNode(ISD::ADD, DL, PtrVT, Ptr,
27859                                     DAG.getConstant(X87StateSize, DL, PtrVT));
27860     // Load MXCSR from memory.
27861     Chain = DAG.getNode(
27862         ISD::INTRINSIC_VOID, DL, DAG.getVTList(MVT::Other), Chain,
27863         DAG.getTargetConstant(Intrinsic::x86_sse_ldmxcsr, DL, MVT::i32),
27864         MXCSRAddr);
27865   }
27866   return Chain;
27867 }
27868 
27869 SDValue X86TargetLowering::LowerSET_FPENV_MEM(SDValue Op,
27870                                               SelectionDAG &DAG) const {
27871   SDLoc DL(Op);
27872   SDValue Chain = Op->getOperand(0);
27873   SDValue Ptr = Op->getOperand(1);
27874   auto *Node = cast<FPStateAccessSDNode>(Op);
27875   EVT MemVT = Node->getMemoryVT();
27876   assert(MemVT.getSizeInBits() == FPStateSizeInBits);
27877   MachineMemOperand *MMO = cast<FPStateAccessSDNode>(Op)->getMemOperand();
27878   return createSetFPEnvNodes(Ptr, Chain, DL, MemVT, MMO, DAG, Subtarget);
27879 }
27880 
27881 SDValue X86TargetLowering::LowerRESET_FPENV(SDValue Op,
27882                                             SelectionDAG &DAG) const {
27883   MachineFunction &MF = DAG.getMachineFunction();
27884   SDLoc DL(Op);
27885   SDValue Chain = Op.getNode()->getOperand(0);
27886 
27887   IntegerType *ItemTy = Type::getInt32Ty(*DAG.getContext());
27888   ArrayType *FPEnvTy = ArrayType::get(ItemTy, 8);
27889   SmallVector<Constant *, 8> FPEnvVals;
27890 
27891   // x87 FPU Control Word: masks all floating-point exceptions and sets rounding
27892   // to nearest. FPU precision is set to 53 bits on Windows and to 64 bits
27893   // otherwise, for compatibility with glibc.
27894   unsigned X87CW = Subtarget.isTargetWindowsMSVC() ? 0x27F : 0x37F;
27895   FPEnvVals.push_back(ConstantInt::get(ItemTy, X87CW));
27896   Constant *Zero = ConstantInt::get(ItemTy, 0);
27897   for (unsigned I = 0; I < 6; ++I)
27898     FPEnvVals.push_back(Zero);
27899 
27900   // MXCSR: masks all floating-point exceptions, sets rounding to nearest, clears
27901   // all exception flags, and sets DAZ and FTZ to 0.
27902   FPEnvVals.push_back(ConstantInt::get(ItemTy, 0x1F80));
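        // Informally: 0x037F / 0x027F both have all six x87 exception-mask bits
        // set and RC = 00 (nearest); they differ only in the precision-control
        // field (64-bit vs. 53-bit significand). 0x1F80 has all six MXCSR
        // exception-mask bits set, RC = 00, and DAZ/FTZ clear.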
27903   Constant *FPEnvBits = ConstantArray::get(FPEnvTy, FPEnvVals);
27904   MVT PtrVT = DAG.getTargetLoweringInfo().getPointerTy(DAG.getDataLayout());
27905   SDValue Env = DAG.getConstantPool(FPEnvBits, PtrVT);
27906   MachinePointerInfo MPI =
27907       MachinePointerInfo::getConstantPool(DAG.getMachineFunction());
27908   MachineMemOperand *MMO = MF.getMachineMemOperand(
27909       MPI, MachineMemOperand::MOStore, X87StateSize, Align(4));
27910 
27911   return createSetFPEnvNodes(Env, Chain, DL, MVT::i32, MMO, DAG, Subtarget);
27912 }
27913 
27914 /// Lower a vector CTLZ using the natively supported vector CTLZ instruction.
27915 //
27916 // i8/i16 vectors are implemented using the dword LZCNT vector instruction
27917 // ( sub(trunc(lzcnt(zext32(x)))) ). In case zext32(x) is illegal,
27918 // split the vector, perform the operation on its Lo and Hi parts, and
27919 // concatenate the results.
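      // For example (informal, i8 element x = 0x1a): zext32(x) = 0x0000001a,
      // lzcnt(0x0000001a) = 27, and 27 - (32 - 8) = 3 = ctlz(0x1a).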
27920 static SDValue LowerVectorCTLZ_AVX512CDI(SDValue Op, SelectionDAG &DAG,
27921                                          const X86Subtarget &Subtarget) {
27922   assert(Op.getOpcode() == ISD::CTLZ);
27923   SDLoc dl(Op);
27924   MVT VT = Op.getSimpleValueType();
27925   MVT EltVT = VT.getVectorElementType();
27926   unsigned NumElems = VT.getVectorNumElements();
27927 
27928   assert((EltVT == MVT::i8 || EltVT == MVT::i16) &&
27929           "Unsupported element type");
27930 
27931   // Split the vector; its Lo and Hi parts will be handled in the next iteration.
27932   if (NumElems > 16 ||
27933       (NumElems == 16 && !Subtarget.canExtendTo512DQ()))
27934     return splitVectorIntUnary(Op, DAG, dl);
27935 
27936   MVT NewVT = MVT::getVectorVT(MVT::i32, NumElems);
27937   assert((NewVT.is256BitVector() || NewVT.is512BitVector()) &&
27938           "Unsupported value type for operation");
27939 
27940   // Use the natively supported vector instruction vplzcntd.
27941   Op = DAG.getNode(ISD::ZERO_EXTEND, dl, NewVT, Op.getOperand(0));
27942   SDValue CtlzNode = DAG.getNode(ISD::CTLZ, dl, NewVT, Op);
27943   SDValue TruncNode = DAG.getNode(ISD::TRUNCATE, dl, VT, CtlzNode);
27944   SDValue Delta = DAG.getConstant(32 - EltVT.getSizeInBits(), dl, VT);
27945 
27946   return DAG.getNode(ISD::SUB, dl, VT, TruncNode, Delta);
27947 }
27948 
27949 // Lower CTLZ using a PSHUFB lookup table implementation.
27950 static SDValue LowerVectorCTLZInRegLUT(SDValue Op, const SDLoc &DL,
27951                                        const X86Subtarget &Subtarget,
27952                                        SelectionDAG &DAG) {
27953   MVT VT = Op.getSimpleValueType();
27954   int NumElts = VT.getVectorNumElements();
27955   int NumBytes = NumElts * (VT.getScalarSizeInBits() / 8);
27956   MVT CurrVT = MVT::getVectorVT(MVT::i8, NumBytes);
27957 
27958   // Per-nibble leading zero PSHUFB lookup table.
27959   const int LUT[16] = {/* 0 */ 4, /* 1 */ 3, /* 2 */ 2, /* 3 */ 2,
27960                        /* 4 */ 1, /* 5 */ 1, /* 6 */ 1, /* 7 */ 1,
27961                        /* 8 */ 0, /* 9 */ 0, /* a */ 0, /* b */ 0,
27962                        /* c */ 0, /* d */ 0, /* e */ 0, /* f */ 0};
27963 
27964   SmallVector<SDValue, 64> LUTVec;
27965   for (int i = 0; i < NumBytes; ++i)
27966     LUTVec.push_back(DAG.getConstant(LUT[i % 16], DL, MVT::i8));
27967   SDValue InRegLUT = DAG.getBuildVector(CurrVT, DL, LUTVec);
27968 
27969   // Begin by bitcasting the input to byte vector, then split those bytes
27970   // into lo/hi nibbles and use the PSHUFB LUT to perform CTLZ on each of them.
27971   // If the hi input nibble is zero then we add both results together, otherwise
27972   // we just take the hi result (by masking the lo result to zero before the
27973   // add).
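        // For example (informal, byte 0x1a): the hi nibble is 0x1, LUT[0x1] = 3;
        // the hi nibble is non-zero, so the lo result is masked to 0 and the
        // answer is 3 = ctlz8(0x1a). For byte 0x0a: the hi nibble is 0, so
        // LUT[0x0] = 4 and LUT[0xa] = 0 are added, giving 4 = ctlz8(0x0a).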
27974   SDValue Op0 = DAG.getBitcast(CurrVT, Op.getOperand(0));
27975   SDValue Zero = DAG.getConstant(0, DL, CurrVT);
27976 
27977   SDValue NibbleShift = DAG.getConstant(0x4, DL, CurrVT);
27978   SDValue Lo = Op0;
27979   SDValue Hi = DAG.getNode(ISD::SRL, DL, CurrVT, Op0, NibbleShift);
27980   SDValue HiZ;
27981   if (CurrVT.is512BitVector()) {
27982     MVT MaskVT = MVT::getVectorVT(MVT::i1, CurrVT.getVectorNumElements());
27983     HiZ = DAG.getSetCC(DL, MaskVT, Hi, Zero, ISD::SETEQ);
27984     HiZ = DAG.getNode(ISD::SIGN_EXTEND, DL, CurrVT, HiZ);
27985   } else {
27986     HiZ = DAG.getSetCC(DL, CurrVT, Hi, Zero, ISD::SETEQ);
27987   }
27988 
27989   Lo = DAG.getNode(X86ISD::PSHUFB, DL, CurrVT, InRegLUT, Lo);
27990   Hi = DAG.getNode(X86ISD::PSHUFB, DL, CurrVT, InRegLUT, Hi);
27991   Lo = DAG.getNode(ISD::AND, DL, CurrVT, Lo, HiZ);
27992   SDValue Res = DAG.getNode(ISD::ADD, DL, CurrVT, Lo, Hi);
27993 
27994   // Merge the result from vXi8 back to VT, working on the lo/hi halves
27995   // of the current vector width in the same way we did for the nibbles.
27996   // If the upper half of the input element is zero then add the halves'
27997   // leading zero counts together, otherwise just use the upper half's.
27998   // Double the width of the result until we reach the target width.
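        // For example (informal, i16 element 0x001a made of bytes 0x00/0x1a):
        // the upper byte is zero, so its count (8) is added to the lower byte's
        // count (3), giving 11 = ctlz16(0x001a); for 0x1a00 the upper byte is
        // non-zero, so only its count (3) is kept.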
27999   while (CurrVT != VT) {
28000     int CurrScalarSizeInBits = CurrVT.getScalarSizeInBits();
28001     int CurrNumElts = CurrVT.getVectorNumElements();
28002     MVT NextSVT = MVT::getIntegerVT(CurrScalarSizeInBits * 2);
28003     MVT NextVT = MVT::getVectorVT(NextSVT, CurrNumElts / 2);
28004     SDValue Shift = DAG.getConstant(CurrScalarSizeInBits, DL, NextVT);
28005 
28006     // Check if the upper half of the input element is zero.
28007     if (CurrVT.is512BitVector()) {
28008       MVT MaskVT = MVT::getVectorVT(MVT::i1, CurrVT.getVectorNumElements());
28009       HiZ = DAG.getSetCC(DL, MaskVT, DAG.getBitcast(CurrVT, Op0),
28010                          DAG.getBitcast(CurrVT, Zero), ISD::SETEQ);
28011       HiZ = DAG.getNode(ISD::SIGN_EXTEND, DL, CurrVT, HiZ);
28012     } else {
28013       HiZ = DAG.getSetCC(DL, CurrVT, DAG.getBitcast(CurrVT, Op0),
28014                          DAG.getBitcast(CurrVT, Zero), ISD::SETEQ);
28015     }
28016     HiZ = DAG.getBitcast(NextVT, HiZ);
28017 
28018     // Move the upper/lower halves to the lower bits as we'll be extending to
28019     // NextVT. Mask the lower result to zero if HiZ is true and add the results
28020     // together.
28021     SDValue ResNext = Res = DAG.getBitcast(NextVT, Res);
28022     SDValue R0 = DAG.getNode(ISD::SRL, DL, NextVT, ResNext, Shift);
28023     SDValue R1 = DAG.getNode(ISD::SRL, DL, NextVT, HiZ, Shift);
28024     R1 = DAG.getNode(ISD::AND, DL, NextVT, ResNext, R1);
28025     Res = DAG.getNode(ISD::ADD, DL, NextVT, R0, R1);
28026     CurrVT = NextVT;
28027   }
28028 
28029   return Res;
28030 }
28031 
28032 static SDValue LowerVectorCTLZ(SDValue Op, const SDLoc &DL,
28033                                const X86Subtarget &Subtarget,
28034                                SelectionDAG &DAG) {
28035   MVT VT = Op.getSimpleValueType();
28036 
28037   if (Subtarget.hasCDI() &&
28038       // vXi8 vectors need to be promoted to 512-bits for vXi32.
28039       (Subtarget.canExtendTo512DQ() || VT.getVectorElementType() != MVT::i8))
28040     return LowerVectorCTLZ_AVX512CDI(Op, DAG, Subtarget);
28041 
28042   // Decompose 256-bit ops into smaller 128-bit ops.
28043   if (VT.is256BitVector() && !Subtarget.hasInt256())
28044     return splitVectorIntUnary(Op, DAG, DL);
28045 
28046   // Decompose 512-bit ops into smaller 256-bit ops.
28047   if (VT.is512BitVector() && !Subtarget.hasBWI())
28048     return splitVectorIntUnary(Op, DAG, DL);
28049 
28050   assert(Subtarget.hasSSSE3() && "Expected SSSE3 support for PSHUFB");
28051   return LowerVectorCTLZInRegLUT(Op, DL, Subtarget, DAG);
28052 }
28053 
28054 static SDValue LowerCTLZ(SDValue Op, const X86Subtarget &Subtarget,
28055                          SelectionDAG &DAG) {
28056   MVT VT = Op.getSimpleValueType();
28057   MVT OpVT = VT;
28058   unsigned NumBits = VT.getSizeInBits();
28059   SDLoc dl(Op);
28060   unsigned Opc = Op.getOpcode();
28061 
28062   if (VT.isVector())
28063     return LowerVectorCTLZ(Op, dl, Subtarget, DAG);
28064 
28065   Op = Op.getOperand(0);
28066   if (VT == MVT::i8) {
28067     // Zero extend to i32 since there is no i8 bsr.
28068     OpVT = MVT::i32;
28069     Op = DAG.getNode(ISD::ZERO_EXTEND, dl, OpVT, Op);
28070   }
28071 
28072   // Issue a bsr (scan bits in reverse) which also sets EFLAGS.
28073   SDVTList VTs = DAG.getVTList(OpVT, MVT::i32);
28074   Op = DAG.getNode(X86ISD::BSR, dl, VTs, Op);
28075 
28076   if (Opc == ISD::CTLZ) {
28077     // If src is zero (i.e. bsr sets ZF), returns NumBits.
28078     SDValue Ops[] = {Op, DAG.getConstant(NumBits + NumBits - 1, dl, OpVT),
28079                      DAG.getTargetConstant(X86::COND_E, dl, MVT::i8),
28080                      Op.getValue(1)};
28081     Op = DAG.getNode(X86ISD::CMOV, dl, OpVT, Ops);
28082   }
28083 
28084   // Finally xor with NumBits-1.
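        // For example (informal, i32): x = 0x10 gives BSR = 4 and 4 ^ 31 = 27 =
        // ctlz(x); x = 0 makes the CMOV above pick 2*32-1 = 63, and 63 ^ 31 = 32,
        // the expected result for a zero input.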
28085   Op = DAG.getNode(ISD::XOR, dl, OpVT, Op,
28086                    DAG.getConstant(NumBits - 1, dl, OpVT));
28087 
28088   if (VT == MVT::i8)
28089     Op = DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, Op);
28090   return Op;
28091 }
28092 
28093 static SDValue LowerCTTZ(SDValue Op, const X86Subtarget &Subtarget,
28094                          SelectionDAG &DAG) {
28095   MVT VT = Op.getSimpleValueType();
28096   unsigned NumBits = VT.getScalarSizeInBits();
28097   SDValue N0 = Op.getOperand(0);
28098   SDLoc dl(Op);
28099 
28100   assert(!VT.isVector() && Op.getOpcode() == ISD::CTTZ &&
28101          "Only scalar CTTZ requires custom lowering");
28102 
28103   // Issue a bsf (scan bits forward) which also sets EFLAGS.
28104   SDVTList VTs = DAG.getVTList(VT, MVT::i32);
28105   Op = DAG.getNode(X86ISD::BSF, dl, VTs, N0);
28106 
28107   // If src is known never zero we can skip the CMOV.
28108   if (DAG.isKnownNeverZero(N0))
28109     return Op;
28110 
28111   // If src is zero (i.e. bsf sets ZF), returns NumBits.
28112   SDValue Ops[] = {Op, DAG.getConstant(NumBits, dl, VT),
28113                    DAG.getTargetConstant(X86::COND_E, dl, MVT::i8),
28114                    Op.getValue(1)};
28115   return DAG.getNode(X86ISD::CMOV, dl, VT, Ops);
28116 }
28117 
28118 static SDValue lowerAddSub(SDValue Op, SelectionDAG &DAG,
28119                            const X86Subtarget &Subtarget) {
28120   MVT VT = Op.getSimpleValueType();
28121   SDLoc DL(Op);
28122 
28123   if (VT == MVT::i16 || VT == MVT::i32)
28124     return lowerAddSubToHorizontalOp(Op, DL, DAG, Subtarget);
28125 
28126   if (VT == MVT::v32i16 || VT == MVT::v64i8)
28127     return splitVectorIntBinary(Op, DAG, DL);
28128 
28129   assert(Op.getSimpleValueType().is256BitVector() &&
28130          Op.getSimpleValueType().isInteger() &&
28131          "Only handle AVX 256-bit vector integer operation");
28132   return splitVectorIntBinary(Op, DAG, DL);
28133 }
28134 
28135 static SDValue LowerADDSAT_SUBSAT(SDValue Op, SelectionDAG &DAG,
28136                                   const X86Subtarget &Subtarget) {
28137   MVT VT = Op.getSimpleValueType();
28138   SDValue X = Op.getOperand(0), Y = Op.getOperand(1);
28139   unsigned Opcode = Op.getOpcode();
28140   SDLoc DL(Op);
28141 
28142   if (VT == MVT::v32i16 || VT == MVT::v64i8 ||
28143       (VT.is256BitVector() && !Subtarget.hasInt256())) {
28144     assert(Op.getSimpleValueType().isInteger() &&
28145            "Only handle AVX vector integer operation");
28146     return splitVectorIntBinary(Op, DAG, DL);
28147   }
28148 
28149   // Avoid the generic expansion with min/max if we don't have pminu*/pmaxu*.
28150   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
28151   EVT SetCCResultType =
28152       TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
28153 
28154   unsigned BitWidth = VT.getScalarSizeInBits();
28155   if (Opcode == ISD::USUBSAT) {
28156     if (!TLI.isOperationLegal(ISD::UMAX, VT) || useVPTERNLOG(Subtarget, VT)) {
28157       // Handle a special case with a bit-hack instead of cmp+select:
28158       // usubsat X, SMIN --> (X ^ SMIN) & (X s>> BW-1)
28159       // If the target can use VPTERNLOG, DAGToDAG will match this as
28160       // "vpsra + vpternlog" which is better than "vpmax + vpsub" with a
28161       // "broadcast" constant load.
28162       ConstantSDNode *C = isConstOrConstSplat(Y, true);
28163       if (C && C->getAPIntValue().isSignMask()) {
28164         SDValue SignMask = DAG.getConstant(C->getAPIntValue(), DL, VT);
28165         SDValue ShiftAmt = DAG.getConstant(BitWidth - 1, DL, VT);
28166         SDValue Xor = DAG.getNode(ISD::XOR, DL, VT, X, SignMask);
28167         SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, X, ShiftAmt);
28168         return DAG.getNode(ISD::AND, DL, VT, Xor, Sra);
28169       }
28170     }
28171     if (!TLI.isOperationLegal(ISD::UMAX, VT)) {
28172       // usubsat X, Y --> (X >u Y) ? X - Y : 0
28173       SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, X, Y);
28174       SDValue Cmp = DAG.getSetCC(DL, SetCCResultType, X, Y, ISD::SETUGT);
28175       // TODO: Move this to DAGCombiner?
28176       if (SetCCResultType == VT &&
28177           DAG.ComputeNumSignBits(Cmp) == VT.getScalarSizeInBits())
28178         return DAG.getNode(ISD::AND, DL, VT, Cmp, Sub);
28179       return DAG.getSelect(DL, VT, Cmp, Sub, DAG.getConstant(0, DL, VT));
28180     }
28181   }
28182 
28183   if ((Opcode == ISD::SADDSAT || Opcode == ISD::SSUBSAT) &&
28184       (!VT.isVector() || VT == MVT::v2i64)) {
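          // A sketch of the logic below (i8 saddsat(100, 50)): SADDO wraps to
          // -106 and reports overflow; the wrapped sum is negative, so SatMax
          // (127) is selected. When no overflow is reported, the raw SumDiff is
          // returned unchanged.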
28185     APInt MinVal = APInt::getSignedMinValue(BitWidth);
28186     APInt MaxVal = APInt::getSignedMaxValue(BitWidth);
28187     SDValue Zero = DAG.getConstant(0, DL, VT);
28188     SDValue Result =
28189         DAG.getNode(Opcode == ISD::SADDSAT ? ISD::SADDO : ISD::SSUBO, DL,
28190                     DAG.getVTList(VT, SetCCResultType), X, Y);
28191     SDValue SumDiff = Result.getValue(0);
28192     SDValue Overflow = Result.getValue(1);
28193     SDValue SatMin = DAG.getConstant(MinVal, DL, VT);
28194     SDValue SatMax = DAG.getConstant(MaxVal, DL, VT);
28195     SDValue SumNeg =
28196         DAG.getSetCC(DL, SetCCResultType, SumDiff, Zero, ISD::SETLT);
28197     Result = DAG.getSelect(DL, VT, SumNeg, SatMax, SatMin);
28198     return DAG.getSelect(DL, VT, Overflow, Result, SumDiff);
28199   }
28200 
28201   // Use default expansion.
28202   return SDValue();
28203 }
28204 
28205 static SDValue LowerABS(SDValue Op, const X86Subtarget &Subtarget,
28206                         SelectionDAG &DAG) {
28207   MVT VT = Op.getSimpleValueType();
28208   SDLoc DL(Op);
28209 
28210   if (VT == MVT::i16 || VT == MVT::i32 || VT == MVT::i64) {
28211     // Since X86 does not have CMOV for 8-bit integers, we don't convert
28212     // 8-bit integer abs to NEG and CMOV.
28213     SDValue N0 = Op.getOperand(0);
28214     SDValue Neg = DAG.getNode(X86ISD::SUB, DL, DAG.getVTList(VT, MVT::i32),
28215                               DAG.getConstant(0, DL, VT), N0);
28216     SDValue Ops[] = {N0, Neg, DAG.getTargetConstant(X86::COND_NS, DL, MVT::i8),
28217                      SDValue(Neg.getNode(), 1)};
28218     return DAG.getNode(X86ISD::CMOV, DL, VT, Ops);
28219   }
28220 
28221   // ABS(vXi64 X) --> VPBLENDVPD(X, 0-X, X).
28222   if ((VT == MVT::v2i64 || VT == MVT::v4i64) && Subtarget.hasSSE41()) {
28223     SDValue Src = Op.getOperand(0);
28224     SDValue Neg = DAG.getNegative(Src, DL, VT);
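          // BLENDV keys only off the sign bit of each condition element (here
          // X itself), so negative lanes pick 0-X and the rest keep X.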
28225     return DAG.getNode(X86ISD::BLENDV, DL, VT, Src, Neg, Src);
28226   }
28227 
28228   if (VT.is256BitVector() && !Subtarget.hasInt256()) {
28229     assert(VT.isInteger() &&
28230            "Only handle AVX 256-bit vector integer operation");
28231     return splitVectorIntUnary(Op, DAG, DL);
28232   }
28233 
28234   if ((VT == MVT::v32i16 || VT == MVT::v64i8) && !Subtarget.hasBWI())
28235     return splitVectorIntUnary(Op, DAG, DL);
28236 
28237   // Default to expand.
28238   return SDValue();
28239 }
28240 
28241 static SDValue LowerAVG(SDValue Op, const X86Subtarget &Subtarget,
28242                         SelectionDAG &DAG) {
28243   MVT VT = Op.getSimpleValueType();
28244   SDLoc DL(Op);
28245 
28246   // For AVX1 cases, split to use legal ops.
28247   if (VT.is256BitVector() && !Subtarget.hasInt256())
28248     return splitVectorIntBinary(Op, DAG, DL);
28249 
28250   if (VT == MVT::v32i16 || VT == MVT::v64i8)
28251     return splitVectorIntBinary(Op, DAG, DL);
28252 
28253   // Default to expand.
28254   return SDValue();
28255 }
28256 
28257 static SDValue LowerMINMAX(SDValue Op, const X86Subtarget &Subtarget,
28258                            SelectionDAG &DAG) {
28259   MVT VT = Op.getSimpleValueType();
28260   SDLoc DL(Op);
28261 
28262   // For AVX1 cases, split to use legal ops.
28263   if (VT.is256BitVector() && !Subtarget.hasInt256())
28264     return splitVectorIntBinary(Op, DAG, DL);
28265 
28266   if (VT == MVT::v32i16 || VT == MVT::v64i8)
28267     return splitVectorIntBinary(Op, DAG, DL);
28268 
28269   // Default to expand.
28270   return SDValue();
28271 }
28272 
28273 static SDValue LowerFMINIMUM_FMAXIMUM(SDValue Op, const X86Subtarget &Subtarget,
28274                                       SelectionDAG &DAG) {
28275   assert((Op.getOpcode() == ISD::FMAXIMUM || Op.getOpcode() == ISD::FMINIMUM) &&
28276          "Expected FMAXIMUM or FMINIMUM opcode");
28277   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
28278   EVT VT = Op.getValueType();
28279   SDValue X = Op.getOperand(0);
28280   SDValue Y = Op.getOperand(1);
28281   SDLoc DL(Op);
28282   uint64_t SizeInBits = VT.getScalarSizeInBits();
28283   APInt PreferredZero = APInt::getZero(SizeInBits);
28284   APInt OppositeZero = PreferredZero;
28285   EVT IVT = VT.changeTypeToInteger();
28286   X86ISD::NodeType MinMaxOp;
28287   if (Op.getOpcode() == ISD::FMAXIMUM) {
28288     MinMaxOp = X86ISD::FMAX;
28289     OppositeZero.setSignBit();
28290   } else {
28291     PreferredZero.setSignBit();
28292     MinMaxOp = X86ISD::FMIN;
28293   }
28294   EVT SetCCType =
28295       TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
28296 
28297   // The tables below show the expected result of Max in cases of NaN and
28298   // signed zeros.
28299   //
28300   //                 Y                       Y
28301   //             Num   xNaN              +0     -0
28302   //          ---------------         ---------------
28303   //     Num  |  Max |   Y  |     +0  |  +0  |  +0  |
28304   // X        ---------------  X      ---------------
28305   //    xNaN  |   X  |  X/Y |     -0  |  +0  |  -0  |
28306   //          ---------------         ---------------
28307   //
28308   // It is achieved by means of FMAX/FMIN with preliminary checks and operand
28309   // reordering.
28310   //
28311   // We check if either operand is NaN and return NaN. Then we check if either
28312   // operand is zero or negative zero (for fmaximum and fminimum respectively)
28313   // to ensure the correct zero is returned.
28314   auto MatchesZero = [](SDValue Op, APInt Zero) {
28315     Op = peekThroughBitcasts(Op);
28316     if (auto *CstOp = dyn_cast<ConstantFPSDNode>(Op))
28317       return CstOp->getValueAPF().bitcastToAPInt() == Zero;
28318     if (auto *CstOp = dyn_cast<ConstantSDNode>(Op))
28319       return CstOp->getAPIntValue() == Zero;
28320     if (Op->getOpcode() == ISD::BUILD_VECTOR ||
28321         Op->getOpcode() == ISD::SPLAT_VECTOR) {
28322       for (const SDValue &OpVal : Op->op_values()) {
28323         if (OpVal.isUndef())
28324           continue;
28325         auto *CstOp = dyn_cast<ConstantFPSDNode>(OpVal);
28326         if (!CstOp)
28327           return false;
28328         if (!CstOp->getValueAPF().isZero())
28329           continue;
28330         if (CstOp->getValueAPF().bitcastToAPInt() != Zero)
28331           return false;
28332       }
28333       return true;
28334     }
28335     return false;
28336   };
28337 
28338   bool IsXNeverNaN = DAG.isKnownNeverNaN(X);
28339   bool IsYNeverNaN = DAG.isKnownNeverNaN(Y);
28340   bool IgnoreSignedZero = DAG.getTarget().Options.NoSignedZerosFPMath ||
28341                           Op->getFlags().hasNoSignedZeros() ||
28342                           DAG.isKnownNeverZeroFloat(X) ||
28343                           DAG.isKnownNeverZeroFloat(Y);
28344   SDValue NewX, NewY;
28345   if (IgnoreSignedZero || MatchesZero(Y, PreferredZero) ||
28346       MatchesZero(X, OppositeZero)) {
28347     // Operands are already in right order or order does not matter.
28348     NewX = X;
28349     NewY = Y;
28350   } else if (MatchesZero(X, PreferredZero) || MatchesZero(Y, OppositeZero)) {
28351     NewX = Y;
28352     NewY = X;
28353   } else if (!VT.isVector() && (VT == MVT::f16 || Subtarget.hasDQI()) &&
28354              (Op->getFlags().hasNoNaNs() || IsXNeverNaN || IsYNeverNaN)) {
28355     if (IsXNeverNaN)
28356       std::swap(X, Y);
28357     // VFPCLASSS consumes a vector type, so provide a minimal one that
28358     // corresponds to an xmm register.
28359     MVT VectorType = MVT::getVectorVT(VT.getSimpleVT(), 128 / SizeInBits);
28360     SDValue VX = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VectorType, X);
28361     // Bits of classes:
28362     // Bits  Imm8[0] Imm8[1] Imm8[2] Imm8[3] Imm8[4]  Imm8[5]  Imm8[6] Imm8[7]
28363     // Class    QNAN PosZero NegZero  PosINF  NegINF Denormal Negative    SNAN
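          // So 0b011 selects {QNaN, PosZero} for fmaximum and 0b101 selects
          // {QNaN, NegZero} for fminimum.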
28364     SDValue Imm = DAG.getTargetConstant(MinMaxOp == X86ISD::FMAX ? 0b11 : 0b101,
28365                                         DL, MVT::i32);
28366     SDValue IsNanZero = DAG.getNode(X86ISD::VFPCLASSS, DL, MVT::v1i1, VX, Imm);
28367     SDValue Ins = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, MVT::v8i1,
28368                               DAG.getConstant(0, DL, MVT::v8i1), IsNanZero,
28369                               DAG.getIntPtrConstant(0, DL));
28370     SDValue NeedSwap = DAG.getBitcast(MVT::i8, Ins);
28371     NewX = DAG.getSelect(DL, VT, NeedSwap, Y, X);
28372     NewY = DAG.getSelect(DL, VT, NeedSwap, X, Y);
28373     return DAG.getNode(MinMaxOp, DL, VT, NewX, NewY, Op->getFlags());
28374   } else {
28375     SDValue IsXSigned;
28376     if (Subtarget.is64Bit() || VT != MVT::f64) {
28377       SDValue XInt = DAG.getNode(ISD::BITCAST, DL, IVT, X);
28378       SDValue ZeroCst = DAG.getConstant(0, DL, IVT);
28379       IsXSigned = DAG.getSetCC(DL, SetCCType, XInt, ZeroCst, ISD::SETLT);
28380     } else {
28381       assert(VT == MVT::f64);
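            // Without 64-bit GPRs we can't cheaply bitcast f64 to i64, so test
            // the sign via the upper 32 bits: view X as v4f32 and extract
            // element 1, which holds the high half of the f64.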
28382       SDValue Ins = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, MVT::v2f64,
28383                                 DAG.getConstantFP(0, DL, MVT::v2f64), X,
28384                                 DAG.getIntPtrConstant(0, DL));
28385       SDValue VX = DAG.getNode(ISD::BITCAST, DL, MVT::v4f32, Ins);
28386       SDValue Hi = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32, VX,
28387                                DAG.getIntPtrConstant(1, DL));
28388       Hi = DAG.getBitcast(MVT::i32, Hi);
28389       SDValue ZeroCst = DAG.getConstant(0, DL, MVT::i32);
28390       EVT SetCCType = TLI.getSetCCResultType(DAG.getDataLayout(),
28391                                              *DAG.getContext(), MVT::i32);
28392       IsXSigned = DAG.getSetCC(DL, SetCCType, Hi, ZeroCst, ISD::SETLT);
28393     }
28394     if (MinMaxOp == X86ISD::FMAX) {
28395       NewX = DAG.getSelect(DL, VT, IsXSigned, X, Y);
28396       NewY = DAG.getSelect(DL, VT, IsXSigned, Y, X);
28397     } else {
28398       NewX = DAG.getSelect(DL, VT, IsXSigned, Y, X);
28399       NewY = DAG.getSelect(DL, VT, IsXSigned, X, Y);
28400     }
28401   }
28402 
28403   bool IgnoreNaN = DAG.getTarget().Options.NoNaNsFPMath ||
28404                    Op->getFlags().hasNoNaNs() || (IsXNeverNaN && IsYNeverNaN);
28405 
28406   // If we did not reorder the operands for signed-zero handling but we still
28407   // need to handle NaN, and we know the second operand is not NaN, then put it
28408   // in the first operand so we do not need to fix up NaN after the max/min.
28409   if (IgnoreSignedZero && !IgnoreNaN && DAG.isKnownNeverNaN(NewY))
28410     std::swap(NewX, NewY);
28411 
28412   SDValue MinMax = DAG.getNode(MinMaxOp, DL, VT, NewX, NewY, Op->getFlags());
28413 
28414   if (IgnoreNaN || DAG.isKnownNeverNaN(NewX))
28415     return MinMax;
28416 
28417   SDValue IsNaN = DAG.getSetCC(DL, SetCCType, NewX, NewX, ISD::SETUO);
28418   return DAG.getSelect(DL, VT, IsNaN, NewX, MinMax);
28419 }
28420 
28421 static SDValue LowerABD(SDValue Op, const X86Subtarget &Subtarget,
28422                         SelectionDAG &DAG) {
28423   MVT VT = Op.getSimpleValueType();
28424   SDLoc dl(Op);
28425 
28426   // For AVX1 cases, split to use legal ops.
28427   if (VT.is256BitVector() && !Subtarget.hasInt256())
28428     return splitVectorIntBinary(Op, DAG, dl);
28429 
28430   if ((VT == MVT::v32i16 || VT == MVT::v64i8) && !Subtarget.useBWIRegs())
28431     return splitVectorIntBinary(Op, DAG, dl);
28432 
28433   bool IsSigned = Op.getOpcode() == ISD::ABDS;
28434   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
28435 
28436   // TODO: Move to TargetLowering expandABD() once we have ABD promotion.
28437   if (VT.isScalarInteger()) {
28438     unsigned WideBits = std::max<unsigned>(2 * VT.getScalarSizeInBits(), 32u);
28439     MVT WideVT = MVT::getIntegerVT(WideBits);
28440     if (TLI.isTypeLegal(WideVT)) {
28441       // abds(lhs, rhs) -> trunc(abs(sub(sext(lhs), sext(rhs))))
28442       // abdu(lhs, rhs) -> trunc(abs(sub(zext(lhs), zext(rhs))))
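            // The widened subtraction cannot overflow, so the wide ABS followed
            // by the truncation yields the exact absolute difference.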
28443       unsigned ExtOpc = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
28444       SDValue LHS = DAG.getNode(ExtOpc, dl, WideVT, Op.getOperand(0));
28445       SDValue RHS = DAG.getNode(ExtOpc, dl, WideVT, Op.getOperand(1));
28446       SDValue Diff = DAG.getNode(ISD::SUB, dl, WideVT, LHS, RHS);
28447       SDValue AbsDiff = DAG.getNode(ISD::ABS, dl, WideVT, Diff);
28448       return DAG.getNode(ISD::TRUNCATE, dl, VT, AbsDiff);
28449     }
28450   }
28451 
28452   // Default to expand.
28453   return SDValue();
28454 }
28455 
28456 static SDValue LowerMUL(SDValue Op, const X86Subtarget &Subtarget,
28457                         SelectionDAG &DAG) {
28458   SDLoc dl(Op);
28459   MVT VT = Op.getSimpleValueType();
28460 
28461   // Decompose 256-bit ops into 128-bit ops.
28462   if (VT.is256BitVector() && !Subtarget.hasInt256())
28463     return splitVectorIntBinary(Op, DAG, dl);
28464 
28465   if ((VT == MVT::v32i16 || VT == MVT::v64i8) && !Subtarget.hasBWI())
28466     return splitVectorIntBinary(Op, DAG, dl);
28467 
28468   SDValue A = Op.getOperand(0);
28469   SDValue B = Op.getOperand(1);
28470 
28471   // Lower v16i8/v32i8/v64i8 mul as sign-extension to v8i16/v16i16/v32i16
28472   // vector pairs, multiply and truncate.
28473   if (VT == MVT::v16i8 || VT == MVT::v32i8 || VT == MVT::v64i8) {
28474     unsigned NumElts = VT.getVectorNumElements();
28475     unsigned NumLanes = VT.getSizeInBits() / 128;
28476     unsigned NumEltsPerLane = NumElts / NumLanes;
28477 
28478     if ((VT == MVT::v16i8 && Subtarget.hasInt256()) ||
28479         (VT == MVT::v32i8 && Subtarget.canExtendTo512BW())) {
28480       MVT ExVT = MVT::getVectorVT(MVT::i16, VT.getVectorNumElements());
28481       return DAG.getNode(
28482           ISD::TRUNCATE, dl, VT,
28483           DAG.getNode(ISD::MUL, dl, ExVT,
28484                       DAG.getNode(ISD::ANY_EXTEND, dl, ExVT, A),
28485                       DAG.getNode(ISD::ANY_EXTEND, dl, ExVT, B)));
28486     }
28487 
28488     MVT ExVT = MVT::getVectorVT(MVT::i16, NumElts / 2);
28489 
28490     // For vXi8 mul, try PMADDUBSW to avoid the need for extension.
28491     // Don't do this if we only need to unpack one half.
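          // PMADDUBSW multiplies u8 x s8 byte pairs and adds adjacent products
          // into each i16 lane; by zeroing the odd (resp. even) bytes of B,
          // every lane holds a single byte product whose low 8 bits match the
          // i8 multiply regardless of signedness.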
28492     if (Subtarget.hasSSSE3()) {
28493       bool BIsBuildVector = isa<BuildVectorSDNode>(B);
28494       bool IsLoLaneAllZeroOrUndef = BIsBuildVector;
28495       bool IsHiLaneAllZeroOrUndef = BIsBuildVector;
28496       if (BIsBuildVector) {
28497         for (auto [Idx, Val] : enumerate(B->ops())) {
28498           if ((Idx % NumEltsPerLane) >= (NumEltsPerLane / 2))
28499             IsHiLaneAllZeroOrUndef &= isNullConstantOrUndef(Val);
28500           else
28501             IsLoLaneAllZeroOrUndef &= isNullConstantOrUndef(Val);
28502         }
28503       }
28504       if (!(IsLoLaneAllZeroOrUndef || IsHiLaneAllZeroOrUndef)) {
28505         SDValue Mask = DAG.getBitcast(VT, DAG.getConstant(0x00FF, dl, ExVT));
28506         SDValue BLo = DAG.getNode(ISD::AND, dl, VT, Mask, B);
28507         SDValue BHi = DAG.getNode(X86ISD::ANDNP, dl, VT, Mask, B);
28508         SDValue RLo = DAG.getNode(X86ISD::VPMADDUBSW, dl, ExVT, A, BLo);
28509         SDValue RHi = DAG.getNode(X86ISD::VPMADDUBSW, dl, ExVT, A, BHi);
28510         RLo = DAG.getNode(ISD::AND, dl, VT, DAG.getBitcast(VT, RLo), Mask);
28511         RHi = DAG.getNode(X86ISD::VSHLI, dl, ExVT, RHi,
28512                           DAG.getTargetConstant(8, dl, MVT::i8));
28513         return DAG.getNode(ISD::OR, dl, VT, RLo, DAG.getBitcast(VT, RHi));
28514       }
28515     }
28516 
28517     // Extract the lo/hi parts and any-extend them to i16.
28518     // We're going to mask off the low byte of each result element of the
28519     // pmullw, so it doesn't matter what's in the high byte of each 16-bit
28520     // element.
28521     SDValue Undef = DAG.getUNDEF(VT);
28522     SDValue ALo = DAG.getBitcast(ExVT, getUnpackl(DAG, dl, VT, A, Undef));
28523     SDValue AHi = DAG.getBitcast(ExVT, getUnpackh(DAG, dl, VT, A, Undef));
28524 
28525     SDValue BLo, BHi;
28526     if (ISD::isBuildVectorOfConstantSDNodes(B.getNode())) {
28527       // If the RHS is a constant, manually unpackl/unpackh.
28528       SmallVector<SDValue, 16> LoOps, HiOps;
28529       for (unsigned i = 0; i != NumElts; i += 16) {
28530         for (unsigned j = 0; j != 8; ++j) {
28531           LoOps.push_back(DAG.getAnyExtOrTrunc(B.getOperand(i + j), dl,
28532                                                MVT::i16));
28533           HiOps.push_back(DAG.getAnyExtOrTrunc(B.getOperand(i + j + 8), dl,
28534                                                MVT::i16));
28535         }
28536       }
28537 
28538       BLo = DAG.getBuildVector(ExVT, dl, LoOps);
28539       BHi = DAG.getBuildVector(ExVT, dl, HiOps);
28540     } else {
28541       BLo = DAG.getBitcast(ExVT, getUnpackl(DAG, dl, VT, B, Undef));
28542       BHi = DAG.getBitcast(ExVT, getUnpackh(DAG, dl, VT, B, Undef));
28543     }
28544 
28545     // Multiply, mask the lower 8 bits of the lo/hi results and pack.
28546     SDValue RLo = DAG.getNode(ISD::MUL, dl, ExVT, ALo, BLo);
28547     SDValue RHi = DAG.getNode(ISD::MUL, dl, ExVT, AHi, BHi);
28548     return getPack(DAG, Subtarget, dl, VT, RLo, RHi);
28549   }
28550 
28551   // Lower v4i32 mul as 2x shuffle, 2x pmuludq, 2x shuffle.
28552   if (VT == MVT::v4i32) {
28553     assert(Subtarget.hasSSE2() && !Subtarget.hasSSE41() &&
28554            "Should not custom lower when pmulld is available!");
28555 
28556     // Extract the odd parts.
28557     static const int UnpackMask[] = { 1, -1, 3, -1 };
28558     SDValue Aodds = DAG.getVectorShuffle(VT, dl, A, A, UnpackMask);
28559     SDValue Bodds = DAG.getVectorShuffle(VT, dl, B, B, UnpackMask);
28560 
28561     // Multiply the even parts.
28562     SDValue Evens = DAG.getNode(X86ISD::PMULUDQ, dl, MVT::v2i64,
28563                                 DAG.getBitcast(MVT::v2i64, A),
28564                                 DAG.getBitcast(MVT::v2i64, B));
28565     // Now multiply odd parts.
28566     SDValue Odds = DAG.getNode(X86ISD::PMULUDQ, dl, MVT::v2i64,
28567                                DAG.getBitcast(MVT::v2i64, Aodds),
28568                                DAG.getBitcast(MVT::v2i64, Bodds));
28569 
28570     Evens = DAG.getBitcast(VT, Evens);
28571     Odds = DAG.getBitcast(VT, Odds);
28572 
28573     // Merge the two vectors back together with a shuffle. This expands into 2
28574     // shuffles.
28575     static const int ShufMask[] = { 0, 4, 2, 6 };
28576     return DAG.getVectorShuffle(VT, dl, Evens, Odds, ShufMask);
28577   }
28578 
28579   assert((VT == MVT::v2i64 || VT == MVT::v4i64 || VT == MVT::v8i64) &&
28580          "Only know how to lower V2I64/V4I64/V8I64 multiply");
28581   assert(!Subtarget.hasDQI() && "DQI should use MULLQ");
28582 
28583   //  Ahi = psrlqi(a, 32);
28584   //  Bhi = psrlqi(b, 32);
28585   //
28586   //  AloBlo = pmuludq(a, b);
28587   //  AloBhi = pmuludq(a, Bhi);
28588   //  AhiBlo = pmuludq(Ahi, b);
28589   //
28590   //  Hi = psllqi(AloBhi + AhiBlo, 32);
28591   //  return AloBlo + Hi;
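        // This is the schoolbook 32x32 partial-product expansion of a
        // 64x64->64 multiply; the Ahi*Bhi term is dropped because it only
        // contributes to bits 64 and above.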
28592   KnownBits AKnown = DAG.computeKnownBits(A);
28593   KnownBits BKnown = DAG.computeKnownBits(B);
28594 
28595   APInt LowerBitsMask = APInt::getLowBitsSet(64, 32);
28596   bool ALoIsZero = LowerBitsMask.isSubsetOf(AKnown.Zero);
28597   bool BLoIsZero = LowerBitsMask.isSubsetOf(BKnown.Zero);
28598 
28599   APInt UpperBitsMask = APInt::getHighBitsSet(64, 32);
28600   bool AHiIsZero = UpperBitsMask.isSubsetOf(AKnown.Zero);
28601   bool BHiIsZero = UpperBitsMask.isSubsetOf(BKnown.Zero);
28602 
28603   SDValue Zero = DAG.getConstant(0, dl, VT);
28604 
28605   // Only multiply lo/hi halves that aren't known to be zero.
28606   SDValue AloBlo = Zero;
28607   if (!ALoIsZero && !BLoIsZero)
28608     AloBlo = DAG.getNode(X86ISD::PMULUDQ, dl, VT, A, B);
28609 
28610   SDValue AloBhi = Zero;
28611   if (!ALoIsZero && !BHiIsZero) {
28612     SDValue Bhi = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, B, 32, DAG);
28613     AloBhi = DAG.getNode(X86ISD::PMULUDQ, dl, VT, A, Bhi);
28614   }
28615 
28616   SDValue AhiBlo = Zero;
28617   if (!AHiIsZero && !BLoIsZero) {
28618     SDValue Ahi = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, A, 32, DAG);
28619     AhiBlo = DAG.getNode(X86ISD::PMULUDQ, dl, VT, Ahi, B);
28620   }
28621 
28622   SDValue Hi = DAG.getNode(ISD::ADD, dl, VT, AloBhi, AhiBlo);
28623   Hi = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, Hi, 32, DAG);
28624 
28625   return DAG.getNode(ISD::ADD, dl, VT, AloBlo, Hi);
28626 }
28627 
28628 static SDValue LowervXi8MulWithUNPCK(SDValue A, SDValue B, const SDLoc &dl,
28629                                      MVT VT, bool IsSigned,
28630                                      const X86Subtarget &Subtarget,
28631                                      SelectionDAG &DAG,
28632                                      SDValue *Low = nullptr) {
28633   unsigned NumElts = VT.getVectorNumElements();
28634 
28635   // For vXi8 we will unpack the low and high half of each 128 bit lane to widen
28636   // to a vXi16 type. Do the multiplies, shift the results and pack the half
28637   // lane results back together.
28638 
28639   // We'll take different approaches for signed and unsigned.
28640   // For unsigned we'll use punpcklbw/punpckhbw to zero extend the bytes
28641   // and use pmullw to calculate the full 16-bit product.
28642   // For signed we'll use punpcklbw/punpckhbw to extend the bytes to words and
28643   // shift them left into the upper byte of each word. This allows us to use
28644   // pmulhw to calculate the full 16-bit product. This trick means we don't
28645   // need to sign extend the bytes to use pmullw.
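        // (Each byte x placed in the high half of a word has the i16 value
        // x*256, and pmulhw returns ((a*256)*(b*256)) >> 16 == a*b, the full
        // signed 16-bit product.)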
28646 
28647   MVT ExVT = MVT::getVectorVT(MVT::i16, NumElts / 2);
28648   SDValue Zero = DAG.getConstant(0, dl, VT);
28649 
28650   SDValue ALo, AHi;
28651   if (IsSigned) {
28652     ALo = DAG.getBitcast(ExVT, getUnpackl(DAG, dl, VT, Zero, A));
28653     AHi = DAG.getBitcast(ExVT, getUnpackh(DAG, dl, VT, Zero, A));
28654   } else {
28655     ALo = DAG.getBitcast(ExVT, getUnpackl(DAG, dl, VT, A, Zero));
28656     AHi = DAG.getBitcast(ExVT, getUnpackh(DAG, dl, VT, A, Zero));
28657   }
28658 
28659   SDValue BLo, BHi;
28660   if (ISD::isBuildVectorOfConstantSDNodes(B.getNode())) {
28661     // If the RHS is a constant, manually unpackl/unpackh and extend.
28662     SmallVector<SDValue, 16> LoOps, HiOps;
28663     for (unsigned i = 0; i != NumElts; i += 16) {
28664       for (unsigned j = 0; j != 8; ++j) {
28665         SDValue LoOp = B.getOperand(i + j);
28666         SDValue HiOp = B.getOperand(i + j + 8);
28667 
28668         if (IsSigned) {
28669           LoOp = DAG.getAnyExtOrTrunc(LoOp, dl, MVT::i16);
28670           HiOp = DAG.getAnyExtOrTrunc(HiOp, dl, MVT::i16);
28671           LoOp = DAG.getNode(ISD::SHL, dl, MVT::i16, LoOp,
28672                              DAG.getConstant(8, dl, MVT::i16));
28673           HiOp = DAG.getNode(ISD::SHL, dl, MVT::i16, HiOp,
28674                              DAG.getConstant(8, dl, MVT::i16));
28675         } else {
28676           LoOp = DAG.getZExtOrTrunc(LoOp, dl, MVT::i16);
28677           HiOp = DAG.getZExtOrTrunc(HiOp, dl, MVT::i16);
28678         }
28679 
28680         LoOps.push_back(LoOp);
28681         HiOps.push_back(HiOp);
28682       }
28683     }
28684 
28685     BLo = DAG.getBuildVector(ExVT, dl, LoOps);
28686     BHi = DAG.getBuildVector(ExVT, dl, HiOps);
28687   } else if (IsSigned) {
28688     BLo = DAG.getBitcast(ExVT, getUnpackl(DAG, dl, VT, Zero, B));
28689     BHi = DAG.getBitcast(ExVT, getUnpackh(DAG, dl, VT, Zero, B));
28690   } else {
28691     BLo = DAG.getBitcast(ExVT, getUnpackl(DAG, dl, VT, B, Zero));
28692     BHi = DAG.getBitcast(ExVT, getUnpackh(DAG, dl, VT, B, Zero));
28693   }
28694 
28695   // Multiply, lshr the upper 8 bits to the lower 8 bits of the lo/hi results and
28696   // pack back to vXi8.
28697   unsigned MulOpc = IsSigned ? ISD::MULHS : ISD::MUL;
28698   SDValue RLo = DAG.getNode(MulOpc, dl, ExVT, ALo, BLo);
28699   SDValue RHi = DAG.getNode(MulOpc, dl, ExVT, AHi, BHi);
28700 
28701   if (Low)
28702     *Low = getPack(DAG, Subtarget, dl, VT, RLo, RHi);
28703 
28704   return getPack(DAG, Subtarget, dl, VT, RLo, RHi, /*PackHiHalf*/ true);
28705 }
28706 
28707 static SDValue LowerMULH(SDValue Op, const X86Subtarget &Subtarget,
28708                          SelectionDAG &DAG) {
28709   SDLoc dl(Op);
28710   MVT VT = Op.getSimpleValueType();
28711   bool IsSigned = Op->getOpcode() == ISD::MULHS;
28712   unsigned NumElts = VT.getVectorNumElements();
28713   SDValue A = Op.getOperand(0);
28714   SDValue B = Op.getOperand(1);
28715 
28716   // Decompose 256-bit ops into 128-bit ops.
28717   if (VT.is256BitVector() && !Subtarget.hasInt256())
28718     return splitVectorIntBinary(Op, DAG, dl);
28719 
28720   if ((VT == MVT::v32i16 || VT == MVT::v64i8) && !Subtarget.hasBWI())
28721     return splitVectorIntBinary(Op, DAG, dl);
28722 
28723   if (VT == MVT::v4i32 || VT == MVT::v8i32 || VT == MVT::v16i32) {
28724     assert((VT == MVT::v4i32 && Subtarget.hasSSE2()) ||
28725            (VT == MVT::v8i32 && Subtarget.hasInt256()) ||
28726            (VT == MVT::v16i32 && Subtarget.hasAVX512()));
28727 
28728     // PMULxD operations multiply each even value (starting at 0) of LHS with
28729     // the related value of RHS and produce a widen result.
28730     // E.g., PMULUDQ <4 x i32> <a|b|c|d>, <4 x i32> <e|f|g|h>
28731     // => <2 x i64> <ae|cg>
28732     //
28733     // In other word, to have all the results, we need to perform two PMULxD:
28734     // 1. one with the even values.
28735     // 2. one with the odd values.
28736     // To achieve #2, with need to place the odd values at an even position.
28737     //
28738     // Place the odd value at an even position (basically, shift all values 1
28739     // step to the left):
28740     const int Mask[] = {1, -1,  3, -1,  5, -1,  7, -1,
28741                         9, -1, 11, -1, 13, -1, 15, -1};
28742     // <a|b|c|d> => <b|undef|d|undef>
28743     SDValue Odd0 =
28744         DAG.getVectorShuffle(VT, dl, A, A, ArrayRef(&Mask[0], NumElts));
28745     // <e|f|g|h> => <f|undef|h|undef>
28746     SDValue Odd1 =
28747         DAG.getVectorShuffle(VT, dl, B, B, ArrayRef(&Mask[0], NumElts));
28748 
28749     // Emit two multiplies, one for the lower 2 ints and one for the higher 2
28750     // ints.
28751     MVT MulVT = MVT::getVectorVT(MVT::i64, NumElts / 2);
28752     unsigned Opcode =
28753         (IsSigned && Subtarget.hasSSE41()) ? X86ISD::PMULDQ : X86ISD::PMULUDQ;
28754     // PMULUDQ <4 x i32> <a|b|c|d>, <4 x i32> <e|f|g|h>
28755     // => <2 x i64> <ae|cg>
28756     SDValue Mul1 = DAG.getBitcast(VT, DAG.getNode(Opcode, dl, MulVT,
28757                                                   DAG.getBitcast(MulVT, A),
28758                                                   DAG.getBitcast(MulVT, B)));
28759     // PMULUDQ <4 x i32> <b|undef|d|undef>, <4 x i32> <f|undef|h|undef>
28760     // => <2 x i64> <bf|dh>
28761     SDValue Mul2 = DAG.getBitcast(VT, DAG.getNode(Opcode, dl, MulVT,
28762                                                   DAG.getBitcast(MulVT, Odd0),
28763                                                   DAG.getBitcast(MulVT, Odd1)));
28764 
28765     // Shuffle it back into the right order.
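          // After the bitcast to vXi32 the high 32 bits of each product sit in
          // the odd elements of Mul1/Mul2, so the mask picks element 1 of every
          // 64-bit product, alternating between the even-lane and odd-lane
          // results.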
28766     SmallVector<int, 16> ShufMask(NumElts);
28767     for (int i = 0; i != (int)NumElts; ++i)
28768       ShufMask[i] = (i / 2) * 2 + ((i % 2) * NumElts) + 1;
28769 
28770     SDValue Res = DAG.getVectorShuffle(VT, dl, Mul1, Mul2, ShufMask);
28771 
28772     // If we have a signed multiply but no PMULDQ fix up the result of an
28773     // unsigned multiply.
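          // mulhs(a, b) == mulhu(a, b) - (a < 0 ? b : 0) - (b < 0 ? a : 0),
          // since treating a negative operand as unsigned adds 2^32 times the
          // other operand to the full 64-bit product.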
28774     if (IsSigned && !Subtarget.hasSSE41()) {
28775       SDValue Zero = DAG.getConstant(0, dl, VT);
28776       SDValue T1 = DAG.getNode(ISD::AND, dl, VT,
28777                                DAG.getSetCC(dl, VT, Zero, A, ISD::SETGT), B);
28778       SDValue T2 = DAG.getNode(ISD::AND, dl, VT,
28779                                DAG.getSetCC(dl, VT, Zero, B, ISD::SETGT), A);
28780 
28781       SDValue Fixup = DAG.getNode(ISD::ADD, dl, VT, T1, T2);
28782       Res = DAG.getNode(ISD::SUB, dl, VT, Res, Fixup);
28783     }
28784 
28785     return Res;
28786   }
28787 
28788   // Only i8 vectors should need custom lowering after this.
28789   assert((VT == MVT::v16i8 || (VT == MVT::v32i8 && Subtarget.hasInt256()) ||
28790          (VT == MVT::v64i8 && Subtarget.hasBWI())) &&
28791          "Unsupported vector type");
28792 
28793   // Lower v16i8/v32i8 as extension to v8i16/v16i16 vector pairs, multiply,
28794   // logical shift down the upper half and pack back to i8.
28795 
28796   // With SSE41 we can use sign/zero extend, but for pre-SSE41 we unpack
28797   // and then ashr/lshr the upper bits down to the lower bits before multiply.
28798 
28799   if ((VT == MVT::v16i8 && Subtarget.hasInt256()) ||
28800       (VT == MVT::v32i8 && Subtarget.canExtendTo512BW())) {
28801     MVT ExVT = MVT::getVectorVT(MVT::i16, NumElts);
28802     unsigned ExAVX = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
28803     SDValue ExA = DAG.getNode(ExAVX, dl, ExVT, A);
28804     SDValue ExB = DAG.getNode(ExAVX, dl, ExVT, B);
28805     SDValue Mul = DAG.getNode(ISD::MUL, dl, ExVT, ExA, ExB);
28806     Mul = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, ExVT, Mul, 8, DAG);
28807     return DAG.getNode(ISD::TRUNCATE, dl, VT, Mul);
28808   }
28809 
28810   return LowervXi8MulWithUNPCK(A, B, dl, VT, IsSigned, Subtarget, DAG);
28811 }
28812 
28813 // Custom lowering for SMULO/UMULO.
28814 static SDValue LowerMULO(SDValue Op, const X86Subtarget &Subtarget,
28815                          SelectionDAG &DAG) {
28816   MVT VT = Op.getSimpleValueType();
28817 
28818   // Scalars defer to LowerXALUO.
28819   if (!VT.isVector())
28820     return LowerXALUO(Op, DAG);
28821 
28822   SDLoc dl(Op);
28823   bool IsSigned = Op->getOpcode() == ISD::SMULO;
28824   SDValue A = Op.getOperand(0);
28825   SDValue B = Op.getOperand(1);
28826   EVT OvfVT = Op->getValueType(1);
28827 
28828   if ((VT == MVT::v32i8 && !Subtarget.hasInt256()) ||
28829       (VT == MVT::v64i8 && !Subtarget.hasBWI())) {
28830     // Extract the LHS Lo/Hi vectors
28831     SDValue LHSLo, LHSHi;
28832     std::tie(LHSLo, LHSHi) = splitVector(A, DAG, dl);
28833 
28834     // Extract the RHS Lo/Hi vectors
28835     SDValue RHSLo, RHSHi;
28836     std::tie(RHSLo, RHSHi) = splitVector(B, DAG, dl);
28837 
28838     EVT LoOvfVT, HiOvfVT;
28839     std::tie(LoOvfVT, HiOvfVT) = DAG.GetSplitDestVTs(OvfVT);
28840     SDVTList LoVTs = DAG.getVTList(LHSLo.getValueType(), LoOvfVT);
28841     SDVTList HiVTs = DAG.getVTList(LHSHi.getValueType(), HiOvfVT);
28842 
28843     // Issue the split operations.
28844     SDValue Lo = DAG.getNode(Op.getOpcode(), dl, LoVTs, LHSLo, RHSLo);
28845     SDValue Hi = DAG.getNode(Op.getOpcode(), dl, HiVTs, LHSHi, RHSHi);
28846 
28847     // Join the separate data results and the overflow results.
28848     SDValue Res = DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, Lo, Hi);
28849     SDValue Ovf = DAG.getNode(ISD::CONCAT_VECTORS, dl, OvfVT, Lo.getValue(1),
28850                               Hi.getValue(1));
28851 
28852     return DAG.getMergeValues({Res, Ovf}, dl);
28853   }
28854 
28855   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
28856   EVT SetccVT =
28857       TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
28858 
28859   if ((VT == MVT::v16i8 && Subtarget.hasInt256()) ||
28860       (VT == MVT::v32i8 && Subtarget.canExtendTo512BW())) {
28861     unsigned NumElts = VT.getVectorNumElements();
28862     MVT ExVT = MVT::getVectorVT(MVT::i16, NumElts);
28863     unsigned ExAVX = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
28864     SDValue ExA = DAG.getNode(ExAVX, dl, ExVT, A);
28865     SDValue ExB = DAG.getNode(ExAVX, dl, ExVT, B);
28866     SDValue Mul = DAG.getNode(ISD::MUL, dl, ExVT, ExA, ExB);
28867 
28868     SDValue Low = DAG.getNode(ISD::TRUNCATE, dl, VT, Mul);
28869 
28870     SDValue Ovf;
28871     if (IsSigned) {
28872       SDValue High, LowSign;
28873       if (OvfVT.getVectorElementType() == MVT::i1 &&
28874           (Subtarget.hasBWI() || Subtarget.canExtendTo512DQ())) {
28875         // Rather than truncating, try to do the compare on vXi16 or vXi32.
28876         // Shift the high down filling with sign bits.
28877         High = getTargetVShiftByConstNode(X86ISD::VSRAI, dl, ExVT, Mul, 8, DAG);
28878         // Fill all 16 bits with the sign bit from the low.
28879         LowSign =
28880             getTargetVShiftByConstNode(X86ISD::VSHLI, dl, ExVT, Mul, 8, DAG);
28881         LowSign = getTargetVShiftByConstNode(X86ISD::VSRAI, dl, ExVT, LowSign,
28882                                              15, DAG);
28883         SetccVT = OvfVT;
28884         if (!Subtarget.hasBWI()) {
28885           // We can't do a vXi16 compare so sign extend to v16i32.
28886           High = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v16i32, High);
28887           LowSign = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v16i32, LowSign);
28888         }
28889       } else {
28890         // Otherwise do the compare at vXi8.
28891         High = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, ExVT, Mul, 8, DAG);
28892         High = DAG.getNode(ISD::TRUNCATE, dl, VT, High);
28893         LowSign =
28894             DAG.getNode(ISD::SRA, dl, VT, Low, DAG.getConstant(7, dl, VT));
28895       }
28896 
28897       Ovf = DAG.getSetCC(dl, SetccVT, LowSign, High, ISD::SETNE);
28898     } else {
28899       SDValue High =
28900           getTargetVShiftByConstNode(X86ISD::VSRLI, dl, ExVT, Mul, 8, DAG);
28901       if (OvfVT.getVectorElementType() == MVT::i1 &&
28902           (Subtarget.hasBWI() || Subtarget.canExtendTo512DQ())) {
28903         // Rather than truncating, try to do the compare on vXi16 or vXi32.
28904         SetccVT = OvfVT;
28905         if (!Subtarget.hasBWI()) {
28906           // We can't do a vXi16 compare so sign extend to v16i32.
28907           High = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::v16i32, High);
28908         }
28909       } else {
28910         // Otherwise do the compare at vXi8.
28911         High = DAG.getNode(ISD::TRUNCATE, dl, VT, High);
28912       }
28913 
28914       Ovf =
28915           DAG.getSetCC(dl, SetccVT, High,
28916                        DAG.getConstant(0, dl, High.getValueType()), ISD::SETNE);
28917     }
28918 
28919     Ovf = DAG.getSExtOrTrunc(Ovf, dl, OvfVT);
28920 
28921     return DAG.getMergeValues({Low, Ovf}, dl);
28922   }
28923 
28924   SDValue Low;
28925   SDValue High =
28926       LowervXi8MulWithUNPCK(A, B, dl, VT, IsSigned, Subtarget, DAG, &Low);
28927 
28928   SDValue Ovf;
28929   if (IsSigned) {
28930     // SMULO overflows if the high bits don't match the sign of the low.
28931     SDValue LowSign =
28932         DAG.getNode(ISD::SRA, dl, VT, Low, DAG.getConstant(7, dl, VT));
28933     Ovf = DAG.getSetCC(dl, SetccVT, LowSign, High, ISD::SETNE);
28934   } else {
28935     // UMULO overflows if the high bits are non-zero.
28936     Ovf =
28937         DAG.getSetCC(dl, SetccVT, High, DAG.getConstant(0, dl, VT), ISD::SETNE);
28938   }
28939 
28940   Ovf = DAG.getSExtOrTrunc(Ovf, dl, OvfVT);
28941 
28942   return DAG.getMergeValues({Low, Ovf}, dl);
28943 }
28944 
28945 SDValue X86TargetLowering::LowerWin64_i128OP(SDValue Op, SelectionDAG &DAG) const {
28946   assert(Subtarget.isTargetWin64() && "Unexpected target");
28947   EVT VT = Op.getValueType();
28948   assert(VT.isInteger() && VT.getSizeInBits() == 128 &&
28949          "Unexpected return type for lowering");
28950 
28951   if (isa<ConstantSDNode>(Op->getOperand(1))) {
28952     SmallVector<SDValue> Result;
28953     if (expandDIVREMByConstant(Op.getNode(), Result, MVT::i64, DAG))
28954       return DAG.getNode(ISD::BUILD_PAIR, SDLoc(Op), VT, Result[0], Result[1]);
28955   }
28956 
28957   RTLIB::Libcall LC;
28958   bool isSigned;
28959   switch (Op->getOpcode()) {
28960   // clang-format off
28961   default: llvm_unreachable("Unexpected request for libcall!");
28962   case ISD::SDIV:      isSigned = true;  LC = RTLIB::SDIV_I128;    break;
28963   case ISD::UDIV:      isSigned = false; LC = RTLIB::UDIV_I128;    break;
28964   case ISD::SREM:      isSigned = true;  LC = RTLIB::SREM_I128;    break;
28965   case ISD::UREM:      isSigned = false; LC = RTLIB::UREM_I128;    break;
28966   // clang-format on
28967   }
28968 
28969   SDLoc dl(Op);
28970   SDValue InChain = DAG.getEntryNode();
28971 
28972   TargetLowering::ArgListTy Args;
28973   TargetLowering::ArgListEntry Entry;
28974   for (unsigned i = 0, e = Op->getNumOperands(); i != e; ++i) {
28975     EVT ArgVT = Op->getOperand(i).getValueType();
28976     assert(ArgVT.isInteger() && ArgVT.getSizeInBits() == 128 &&
28977            "Unexpected argument type for lowering");
28978     SDValue StackPtr = DAG.CreateStackTemporary(ArgVT, 16);
28979     int SPFI = cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex();
28980     MachinePointerInfo MPI =
28981         MachinePointerInfo::getFixedStack(DAG.getMachineFunction(), SPFI);
28982     Entry.Node = StackPtr;
28983     InChain =
28984         DAG.getStore(InChain, dl, Op->getOperand(i), StackPtr, MPI, Align(16));
28985     Type *ArgTy = ArgVT.getTypeForEVT(*DAG.getContext());
28986     Entry.Ty = PointerType::get(ArgTy, 0);
28987     Entry.IsSExt = false;
28988     Entry.IsZExt = false;
28989     Args.push_back(Entry);
28990   }
28991 
28992   SDValue Callee = DAG.getExternalSymbol(getLibcallName(LC),
28993                                          getPointerTy(DAG.getDataLayout()));
28994 
28995   TargetLowering::CallLoweringInfo CLI(DAG);
28996   CLI.setDebugLoc(dl)
28997       .setChain(InChain)
28998       .setLibCallee(
28999           getLibcallCallingConv(LC),
29000           static_cast<EVT>(MVT::v2i64).getTypeForEVT(*DAG.getContext()), Callee,
29001           std::move(Args))
29002       .setInRegister()
29003       .setSExtResult(isSigned)
29004       .setZExtResult(!isSigned);
29005 
29006   std::pair<SDValue, SDValue> CallInfo = LowerCallTo(CLI);
29007   return DAG.getBitcast(VT, CallInfo.first);
29008 }
29009 
29010 SDValue X86TargetLowering::LowerWin64_FP_TO_INT128(SDValue Op,
29011                                                    SelectionDAG &DAG,
29012                                                    SDValue &Chain) const {
29013   assert(Subtarget.isTargetWin64() && "Unexpected target");
29014   EVT VT = Op.getValueType();
29015   bool IsStrict = Op->isStrictFPOpcode();
29016 
29017   SDValue Arg = Op.getOperand(IsStrict ? 1 : 0);
29018   EVT ArgVT = Arg.getValueType();
29019 
29020   assert(VT.isInteger() && VT.getSizeInBits() == 128 &&
29021          "Unexpected return type for lowering");
29022 
29023   RTLIB::Libcall LC;
29024   if (Op->getOpcode() == ISD::FP_TO_SINT ||
29025       Op->getOpcode() == ISD::STRICT_FP_TO_SINT)
29026     LC = RTLIB::getFPTOSINT(ArgVT, VT);
29027   else
29028     LC = RTLIB::getFPTOUINT(ArgVT, VT);
29029   assert(LC != RTLIB::UNKNOWN_LIBCALL && "Unexpected request for libcall!");
29030 
29031   SDLoc dl(Op);
29032   MakeLibCallOptions CallOptions;
29033   Chain = IsStrict ? Op.getOperand(0) : DAG.getEntryNode();
29034 
29035   SDValue Result;
29036   // Expect the i128 result to be returned as a v2i64 in xmm0, then cast it
29037   // back to the expected VT (i128).
29038   std::tie(Result, Chain) =
29039       makeLibCall(DAG, LC, MVT::v2i64, Arg, CallOptions, dl, Chain);
29040   Result = DAG.getBitcast(VT, Result);
29041   return Result;
29042 }
29043 
29044 SDValue X86TargetLowering::LowerWin64_INT128_TO_FP(SDValue Op,
29045                                                    SelectionDAG &DAG) const {
29046   assert(Subtarget.isTargetWin64() && "Unexpected target");
29047   EVT VT = Op.getValueType();
29048   bool IsStrict = Op->isStrictFPOpcode();
29049 
29050   SDValue Arg = Op.getOperand(IsStrict ? 1 : 0);
29051   EVT ArgVT = Arg.getValueType();
29052 
29053   assert(ArgVT.isInteger() && ArgVT.getSizeInBits() == 128 &&
29054          "Unexpected argument type for lowering");
29055 
29056   RTLIB::Libcall LC;
29057   if (Op->getOpcode() == ISD::SINT_TO_FP ||
29058       Op->getOpcode() == ISD::STRICT_SINT_TO_FP)
29059     LC = RTLIB::getSINTTOFP(ArgVT, VT);
29060   else
29061     LC = RTLIB::getUINTTOFP(ArgVT, VT);
29062   assert(LC != RTLIB::UNKNOWN_LIBCALL && "Unexpected request for libcall!");
29063 
29064   SDLoc dl(Op);
29065   MakeLibCallOptions CallOptions;
29066   SDValue Chain = IsStrict ? Op.getOperand(0) : DAG.getEntryNode();
29067 
29068   // Pass the i128 argument as an indirect argument on the stack.
29069   SDValue StackPtr = DAG.CreateStackTemporary(ArgVT, 16);
29070   int SPFI = cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex();
29071   MachinePointerInfo MPI =
29072       MachinePointerInfo::getFixedStack(DAG.getMachineFunction(), SPFI);
29073   Chain = DAG.getStore(Chain, dl, Arg, StackPtr, MPI, Align(16));
29074 
29075   SDValue Result;
29076   std::tie(Result, Chain) =
29077       makeLibCall(DAG, LC, VT, StackPtr, CallOptions, dl, Chain);
29078   return IsStrict ? DAG.getMergeValues({Result, Chain}, dl) : Result;
29079 }
29080 
29081 // Generate a GFNI gf2p8affine bitmask for vXi8 bitreverse/shift/rotate.
29082 uint64_t getGFNICtrlImm(unsigned Opcode, unsigned Amt = 0) {
29083   assert((Amt < 8) && "Shift/Rotation amount out of range");
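        // GF2P8AFFINEQB treats each 64-bit element of the control mask as an
        // 8x8 bit matrix (one row per byte) and multiplies every source byte by
        // it over GF(2); the masks below encode the matrices for bit reversal
        // and per-bit shifts/rotates.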
29084   switch (Opcode) {
29085   case ISD::BITREVERSE:
29086     return 0x8040201008040201ULL;
29087   case ISD::SHL:
29088     return ((0x0102040810204080ULL >> (Amt)) &
29089             (0x0101010101010101ULL * (0xFF >> (Amt))));
29090   case ISD::SRL:
29091     return ((0x0102040810204080ULL << (Amt)) &
29092             (0x0101010101010101ULL * ((0xFF << (Amt)) & 0xFF)));
29093   case ISD::SRA:
29094     return (getGFNICtrlImm(ISD::SRL, Amt) |
29095             (0x8080808080808080ULL >> (64 - (8 * Amt))));
29096   case ISD::ROTL:
29097     return getGFNICtrlImm(ISD::SRL, 8 - Amt) | getGFNICtrlImm(ISD::SHL, Amt);
29098   case ISD::ROTR:
29099     return getGFNICtrlImm(ISD::SHL, 8 - Amt) | getGFNICtrlImm(ISD::SRL, Amt);
29100   }
29101   llvm_unreachable("Unsupported GFNI opcode");
29102 }
29103 
29104 // Generate a GFNI gf2p8affine bitmask for vXi8 bitreverse/shift/rotate.
29105 SDValue getGFNICtrlMask(unsigned Opcode, SelectionDAG &DAG, const SDLoc &DL, MVT VT,
29106                         unsigned Amt = 0) {
29107   assert(VT.getVectorElementType() == MVT::i8 &&
29108          (VT.getSizeInBits() % 64) == 0 && "Illegal GFNI control type");
29109   uint64_t Imm = getGFNICtrlImm(Opcode, Amt);
29110   SmallVector<SDValue> MaskBits;
29111   for (unsigned I = 0, E = VT.getSizeInBits(); I != E; I += 8) {
29112     uint64_t Bits = (Imm >> (I % 64)) & 255;
29113     MaskBits.push_back(DAG.getConstant(Bits, DL, MVT::i8));
29114   }
29115   return DAG.getBuildVector(VT, DL, MaskBits);
29116 }
29117 
29118 // Return true if the required (according to Opcode) shift-imm form is natively
29119 // supported by the Subtarget
29120 static bool supportedVectorShiftWithImm(EVT VT, const X86Subtarget &Subtarget,
29121                                         unsigned Opcode) {
29122   assert((Opcode == ISD::SHL || Opcode == ISD::SRA || Opcode == ISD::SRL) &&
29123          "Unexpected shift opcode");
29124 
29125   if (!VT.isSimple())
29126     return false;
29127 
29128   if (!(VT.is128BitVector() || VT.is256BitVector() || VT.is512BitVector()))
29129     return false;
29130 
29131   if (VT.getScalarSizeInBits() < 16)
29132     return false;
29133 
29134   if (VT.is512BitVector() && Subtarget.useAVX512Regs() &&
29135       (VT.getScalarSizeInBits() > 16 || Subtarget.hasBWI()))
29136     return true;
29137 
29138   bool LShift = (VT.is128BitVector() && Subtarget.hasSSE2()) ||
29139                 (VT.is256BitVector() && Subtarget.hasInt256());
29140 
29141   bool AShift = LShift && (Subtarget.hasAVX512() ||
29142                            (VT != MVT::v2i64 && VT != MVT::v4i64));
29143   return (Opcode == ISD::SRA) ? AShift : LShift;
29144 }
29145 
29146 // The shift amount is a variable, but it is the same for all vector lanes.
29147 // These instructions are defined together with shift-immediate.
29148 static
29149 bool supportedVectorShiftWithBaseAmnt(EVT VT, const X86Subtarget &Subtarget,
29150                                       unsigned Opcode) {
29151   return supportedVectorShiftWithImm(VT, Subtarget, Opcode);
29152 }
29153 
29154 // Return true if the required (according to Opcode) variable-shift form is
29155 // natively supported by the Subtarget
29156 static bool supportedVectorVarShift(EVT VT, const X86Subtarget &Subtarget,
29157                                     unsigned Opcode) {
29158   assert((Opcode == ISD::SHL || Opcode == ISD::SRA || Opcode == ISD::SRL) &&
29159          "Unexpected shift opcode");
29160 
29161   if (!VT.isSimple())
29162     return false;
29163 
29164   if (!(VT.is128BitVector() || VT.is256BitVector() || VT.is512BitVector()))
29165     return false;
29166 
29167   if (!Subtarget.hasInt256() || VT.getScalarSizeInBits() < 16)
29168     return false;
29169 
29170   // vXi16 supported only on AVX-512, BWI
29171   if (VT.getScalarSizeInBits() == 16 && !Subtarget.hasBWI())
29172     return false;
29173 
29174   if (Subtarget.hasAVX512() &&
29175       (Subtarget.useAVX512Regs() || !VT.is512BitVector()))
29176     return true;
29177 
29178   bool LShift = VT.is128BitVector() || VT.is256BitVector();
29179   bool AShift = LShift &&  VT != MVT::v2i64 && VT != MVT::v4i64;
29180   return (Opcode == ISD::SRA) ? AShift : LShift;
29181 }
29182 
29183 static SDValue LowerShiftByScalarImmediate(SDValue Op, SelectionDAG &DAG,
29184                                            const X86Subtarget &Subtarget) {
29185   MVT VT = Op.getSimpleValueType();
29186   SDLoc dl(Op);
29187   SDValue R = Op.getOperand(0);
29188   SDValue Amt = Op.getOperand(1);
29189   unsigned X86Opc = getTargetVShiftUniformOpcode(Op.getOpcode(), false);
29190   unsigned EltSizeInBits = VT.getScalarSizeInBits();
29191 
29192   auto ArithmeticShiftRight64 = [&](uint64_t ShiftAmt) {
29193     assert((VT == MVT::v2i64 || VT == MVT::v4i64) && "Unexpected SRA type");
29194     MVT ExVT = MVT::getVectorVT(MVT::i32, VT.getVectorNumElements() * 2);
29195     SDValue Ex = DAG.getBitcast(ExVT, R);
29196 
29197     // ashr(R, 63) === cmp_slt(R, 0)
29198     if (ShiftAmt == 63 && Subtarget.hasSSE42()) {
29199       assert((VT != MVT::v4i64 || Subtarget.hasInt256()) &&
29200              "Unsupported PCMPGT op");
29201       return DAG.getNode(X86ISD::PCMPGT, dl, VT, DAG.getConstant(0, dl, VT), R);
29202     }
29203 
29204     if (ShiftAmt >= 32) {
29205       // Splat sign to upper i32 dst, and SRA upper i32 src to lower i32.
29206       SDValue Upper =
29207           getTargetVShiftByConstNode(X86ISD::VSRAI, dl, ExVT, Ex, 31, DAG);
29208       SDValue Lower = getTargetVShiftByConstNode(X86ISD::VSRAI, dl, ExVT, Ex,
29209                                                  ShiftAmt - 32, DAG);
29210       if (VT == MVT::v2i64)
29211         Ex = DAG.getVectorShuffle(ExVT, dl, Upper, Lower, {5, 1, 7, 3});
29212       if (VT == MVT::v4i64)
29213         Ex = DAG.getVectorShuffle(ExVT, dl, Upper, Lower,
29214                                   {9, 1, 11, 3, 13, 5, 15, 7});
29215     } else {
29216       // SRA upper i32, SRL whole i64 and select lower i32.
29217       SDValue Upper = getTargetVShiftByConstNode(X86ISD::VSRAI, dl, ExVT, Ex,
29218                                                  ShiftAmt, DAG);
29219       SDValue Lower =
29220           getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, R, ShiftAmt, DAG);
29221       Lower = DAG.getBitcast(ExVT, Lower);
29222       if (VT == MVT::v2i64)
29223         Ex = DAG.getVectorShuffle(ExVT, dl, Upper, Lower, {4, 1, 6, 3});
29224       if (VT == MVT::v4i64)
29225         Ex = DAG.getVectorShuffle(ExVT, dl, Upper, Lower,
29226                                   {8, 1, 10, 3, 12, 5, 14, 7});
29227     }
29228     return DAG.getBitcast(VT, Ex);
29229   };
29230 
29231   // Optimize shl/srl/sra with constant shift amount.
29232   APInt APIntShiftAmt;
29233   if (!X86::isConstantSplat(Amt, APIntShiftAmt))
29234     return SDValue();
29235 
29236   // If the shift amount is out of range, return undef.
29237   if (APIntShiftAmt.uge(EltSizeInBits))
29238     return DAG.getUNDEF(VT);
29239 
29240   uint64_t ShiftAmt = APIntShiftAmt.getZExtValue();
29241 
29242   if (supportedVectorShiftWithImm(VT, Subtarget, Op.getOpcode())) {
29243     // Hardware support for vector shifts is sparse, which makes us scalarize the
29244     // vector operations in many cases. Also, on Sandy Bridge ADD is faster than
29245     // shl: (shl V, 1) -> (add (freeze V), (freeze V))
29246     if (Op.getOpcode() == ISD::SHL && ShiftAmt == 1) {
29247       // R may be undef at run-time, but (shl R, 1) must be an even number (LSB
29248       // must be 0). (add undef, undef) however can be any value. To make this
29249       // safe, we must freeze R to ensure that register allocation uses the same
29250       // register for an undefined value. This ensures that the result will
29251       // still be even and preserves the original semantics.
29252       R = DAG.getFreeze(R);
29253       return DAG.getNode(ISD::ADD, dl, VT, R, R);
29254     }
29255 
29256     return getTargetVShiftByConstNode(X86Opc, dl, VT, R, ShiftAmt, DAG);
29257   }
29258 
29259   // i64 SRA needs to be performed as partial shifts.
29260   if (((!Subtarget.hasXOP() && VT == MVT::v2i64) ||
29261        (Subtarget.hasInt256() && VT == MVT::v4i64)) &&
29262       Op.getOpcode() == ISD::SRA)
29263     return ArithmeticShiftRight64(ShiftAmt);
29264 
29265   // If we're logical shifting an all-signbits value then we can just perform it
29266   // as a mask.
29267   if ((Op.getOpcode() == ISD::SHL || Op.getOpcode() == ISD::SRL) &&
29268       DAG.ComputeNumSignBits(R) == EltSizeInBits) {
29269     SDValue Mask = DAG.getAllOnesConstant(dl, VT);
29270     Mask = DAG.getNode(Op.getOpcode(), dl, VT, Mask, Amt);
29271     return DAG.getNode(ISD::AND, dl, VT, R, Mask);
29272   }
29273 
29274   if (VT == MVT::v16i8 || (Subtarget.hasInt256() && VT == MVT::v32i8) ||
29275       (Subtarget.hasBWI() && VT == MVT::v64i8)) {
29276     unsigned NumElts = VT.getVectorNumElements();
29277     MVT ShiftVT = MVT::getVectorVT(MVT::i16, NumElts / 2);
29278 
29279     // Simple i8 add case
29280     if (Op.getOpcode() == ISD::SHL && ShiftAmt == 1) {
29281       // R may be undef at run-time, but (shl R, 1) must be an even number (LSB
29282       // must be 0). (add undef, undef) however can be any value. To make this
29283       // safe, we must freeze R to ensure that register allocation uses the same
29284       // register for an undefined value. This ensures that the result will
29285       // still be even and preserves the original semantics.
29286       R = DAG.getFreeze(R);
29287       return DAG.getNode(ISD::ADD, dl, VT, R, R);
29288     }
29289 
29290     // ashr(R, 7)  === cmp_slt(R, 0)
29291     if (Op.getOpcode() == ISD::SRA && ShiftAmt == 7) {
29292       SDValue Zeros = DAG.getConstant(0, dl, VT);
29293       if (VT.is512BitVector()) {
29294         assert(VT == MVT::v64i8 && "Unexpected element type!");
29295         SDValue CMP = DAG.getSetCC(dl, MVT::v64i1, Zeros, R, ISD::SETGT);
29296         return DAG.getNode(ISD::SIGN_EXTEND, dl, VT, CMP);
29297       }
29298       return DAG.getNode(X86ISD::PCMPGT, dl, VT, Zeros, R);
29299     }
29300 
29301     // XOP can shift v16i8 directly instead of as shift v8i16 + mask.
29302     if (VT == MVT::v16i8 && Subtarget.hasXOP())
29303       return SDValue();
29304 
29305     if (Subtarget.hasGFNI()) {
29306       SDValue Mask = getGFNICtrlMask(Op.getOpcode(), DAG, dl, VT, ShiftAmt);
29307       return DAG.getNode(X86ISD::GF2P8AFFINEQB, dl, VT, R, Mask,
29308                          DAG.getTargetConstant(0, dl, MVT::i8));
29309     }
29310 
29311     if (Op.getOpcode() == ISD::SHL) {
29312       // Make a large shift.
29313       SDValue SHL = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, ShiftVT, R,
29314                                                ShiftAmt, DAG);
29315       SHL = DAG.getBitcast(VT, SHL);
29316       // Zero out the rightmost bits.
29317       APInt Mask = APInt::getHighBitsSet(8, 8 - ShiftAmt);
29318       return DAG.getNode(ISD::AND, dl, VT, SHL, DAG.getConstant(Mask, dl, VT));
29319     }
29320     if (Op.getOpcode() == ISD::SRL) {
29321       // Make a large shift.
29322       SDValue SRL = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, ShiftVT, R,
29323                                                ShiftAmt, DAG);
29324       SRL = DAG.getBitcast(VT, SRL);
29325       // Zero out the leftmost bits.
29326       APInt Mask = APInt::getLowBitsSet(8, 8 - ShiftAmt);
29327       return DAG.getNode(ISD::AND, dl, VT, SRL, DAG.getConstant(Mask, dl, VT));
29328     }
29329     if (Op.getOpcode() == ISD::SRA) {
29330       // ashr(R, Amt) === sub(xor(lshr(R, Amt), Mask), Mask)
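            // lshr zero-fills the top bits; Mask = 128 >> Amt marks where the
            // sign bit landed, and (v ^ Mask) - Mask sign-extends from there.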
29331       SDValue Res = DAG.getNode(ISD::SRL, dl, VT, R, Amt);
29332 
29333       SDValue Mask = DAG.getConstant(128 >> ShiftAmt, dl, VT);
29334       Res = DAG.getNode(ISD::XOR, dl, VT, Res, Mask);
29335       Res = DAG.getNode(ISD::SUB, dl, VT, Res, Mask);
29336       return Res;
29337     }
29338     llvm_unreachable("Unknown shift opcode.");
29339   }
29340 
29341   return SDValue();
29342 }
29343 
29344 static SDValue LowerShiftByScalarVariable(SDValue Op, SelectionDAG &DAG,
29345                                           const X86Subtarget &Subtarget) {
29346   MVT VT = Op.getSimpleValueType();
29347   SDLoc dl(Op);
29348   SDValue R = Op.getOperand(0);
29349   SDValue Amt = Op.getOperand(1);
29350   unsigned Opcode = Op.getOpcode();
29351   unsigned X86OpcI = getTargetVShiftUniformOpcode(Opcode, false);
29352 
29353   int BaseShAmtIdx = -1;
29354   if (SDValue BaseShAmt = DAG.getSplatSourceVector(Amt, BaseShAmtIdx)) {
29355     if (supportedVectorShiftWithBaseAmnt(VT, Subtarget, Opcode))
29356       return getTargetVShiftNode(X86OpcI, dl, VT, R, BaseShAmt, BaseShAmtIdx,
29357                                  Subtarget, DAG);
29358 
29359     // vXi8 shifts - shift as v8i16 + mask result.
29360     if (((VT == MVT::v16i8 && !Subtarget.canExtendTo512DQ()) ||
29361          (VT == MVT::v32i8 && !Subtarget.canExtendTo512BW()) ||
29362          VT == MVT::v64i8) &&
29363         !Subtarget.hasXOP()) {
29364       unsigned NumElts = VT.getVectorNumElements();
29365       MVT ExtVT = MVT::getVectorVT(MVT::i16, NumElts / 2);
29366       if (supportedVectorShiftWithBaseAmnt(ExtVT, Subtarget, Opcode)) {
29367         unsigned LogicalOp = (Opcode == ISD::SHL ? ISD::SHL : ISD::SRL);
29368         unsigned LogicalX86Op = getTargetVShiftUniformOpcode(LogicalOp, false);
29369 
29370         // Create the mask using vXi16 shifts. For shift-rights we need to move
29371         // the upper byte down before splatting the vXi8 mask.
29372         SDValue BitMask = DAG.getConstant(-1, dl, ExtVT);
29373         BitMask = getTargetVShiftNode(LogicalX86Op, dl, ExtVT, BitMask,
29374                                       BaseShAmt, BaseShAmtIdx, Subtarget, DAG);
29375         if (Opcode != ISD::SHL)
29376           BitMask = getTargetVShiftByConstNode(LogicalX86Op, dl, ExtVT, BitMask,
29377                                                8, DAG);
29378         BitMask = DAG.getBitcast(VT, BitMask);
29379         BitMask = DAG.getVectorShuffle(VT, dl, BitMask, BitMask,
29380                                        SmallVector<int, 64>(NumElts, 0));
29381 
29382         SDValue Res = getTargetVShiftNode(LogicalX86Op, dl, ExtVT,
29383                                           DAG.getBitcast(ExtVT, R), BaseShAmt,
29384                                           BaseShAmtIdx, Subtarget, DAG);
29385         Res = DAG.getBitcast(VT, Res);
29386         Res = DAG.getNode(ISD::AND, dl, VT, Res, BitMask);
29387 
29388         if (Opcode == ISD::SRA) {
29389           // ashr(R, Amt) === sub(xor(lshr(R, Amt), SignMask), SignMask)
29390           // SignMask = lshr(SignBit, Amt) - safe to do this with PSRLW.
29391           SDValue SignMask = DAG.getConstant(0x8080, dl, ExtVT);
29392           SignMask =
29393               getTargetVShiftNode(LogicalX86Op, dl, ExtVT, SignMask, BaseShAmt,
29394                                   BaseShAmtIdx, Subtarget, DAG);
29395           SignMask = DAG.getBitcast(VT, SignMask);
29396           Res = DAG.getNode(ISD::XOR, dl, VT, Res, SignMask);
29397           Res = DAG.getNode(ISD::SUB, dl, VT, Res, SignMask);
29398         }
29399         return Res;
29400       }
29401     }
29402   }
29403 
29404   return SDValue();
29405 }
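
// Illustration only: a scalar sketch of the "shift as v8i16 + mask" idea used
// above for uniform vXi8 shifts. Two bytes packed into one 16-bit lane are
// shifted together with a single wider shift, then the bits that crossed a
// byte boundary are masked off. This models the logical-left case only; the
// helper name is hypothetical and the function is unused.
[[maybe_unused]] static uint16_t scalarShlBytePairViaI16(uint16_t Pair,
                                                         unsigned Amt) {
  uint16_t Shifted = uint16_t(Pair << Amt);             // one v8i16-style shift
  uint16_t ByteMask = uint16_t((0xFFu << Amt) & 0xFFu); // bits kept per byte
  ByteMask = uint16_t(ByteMask | (ByteMask << 8)); // splat mask to both bytes
  return uint16_t(Shifted & ByteMask);             // drop the cross-byte bits
}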
29406 
29407 // Convert a shift/rotate left amount to a multiplication scale factor.
29408 static SDValue convertShiftLeftToScale(SDValue Amt, const SDLoc &dl,
29409                                        const X86Subtarget &Subtarget,
29410                                        SelectionDAG &DAG) {
29411   MVT VT = Amt.getSimpleValueType();
29412   if (!(VT == MVT::v8i16 || VT == MVT::v4i32 ||
29413         (Subtarget.hasInt256() && VT == MVT::v16i16) ||
29414         (Subtarget.hasAVX512() && VT == MVT::v32i16) ||
29415         (!Subtarget.hasAVX512() && VT == MVT::v16i8) ||
29416         (Subtarget.hasInt256() && VT == MVT::v32i8) ||
29417         (Subtarget.hasBWI() && VT == MVT::v64i8)))
29418     return SDValue();
29419 
29420   MVT SVT = VT.getVectorElementType();
29421   unsigned SVTBits = SVT.getSizeInBits();
29422   unsigned NumElems = VT.getVectorNumElements();
29423 
29424   APInt UndefElts;
29425   SmallVector<APInt> EltBits;
29426   if (getTargetConstantBitsFromNode(Amt, SVTBits, UndefElts, EltBits)) {
29427     APInt One(SVTBits, 1);
29428     SmallVector<SDValue> Elts(NumElems, DAG.getUNDEF(SVT));
29429     for (unsigned I = 0; I != NumElems; ++I) {
29430       if (UndefElts[I] || EltBits[I].uge(SVTBits))
29431         continue;
29432       uint64_t ShAmt = EltBits[I].getZExtValue();
29433       Elts[I] = DAG.getConstant(One.shl(ShAmt), dl, SVT);
29434     }
29435     return DAG.getBuildVector(VT, dl, Elts);
29436   }
29437 
29438   // If the target doesn't support variable shifts, use either FP conversion
29439   // or integer multiplication to avoid shifting each element individually.
29440   if (VT == MVT::v4i32) {
29441     Amt = DAG.getNode(ISD::SHL, dl, VT, Amt, DAG.getConstant(23, dl, VT));
29442     Amt = DAG.getNode(ISD::ADD, dl, VT, Amt,
29443                       DAG.getConstant(0x3f800000U, dl, VT));
29444     Amt = DAG.getBitcast(MVT::v4f32, Amt);
29445     return DAG.getNode(ISD::FP_TO_SINT, dl, VT, Amt);
29446   }
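
  // Illustration (comment only): the FP trick above constructs 2^Amt per lane
  // directly in IEEE-754 single-precision form: (Amt << 23) lands Amt in the
  // exponent field and adding 0x3f800000 (1.0f) applies the bias, so e.g.
  // Amt = 5 yields 0x42000000 == 32.0f and FP_TO_SINT gives 32 == 1 << 5.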
29447 
29448   // AVX2 can more effectively perform this as a zext/trunc to/from v8i32.
29449   if (VT == MVT::v8i16 && !Subtarget.hasAVX2()) {
29450     SDValue Z = DAG.getConstant(0, dl, VT);
29451     SDValue Lo = DAG.getBitcast(MVT::v4i32, getUnpackl(DAG, dl, VT, Amt, Z));
29452     SDValue Hi = DAG.getBitcast(MVT::v4i32, getUnpackh(DAG, dl, VT, Amt, Z));
29453     Lo = convertShiftLeftToScale(Lo, dl, Subtarget, DAG);
29454     Hi = convertShiftLeftToScale(Hi, dl, Subtarget, DAG);
29455     if (Subtarget.hasSSE41())
29456       return DAG.getNode(X86ISD::PACKUS, dl, VT, Lo, Hi);
29457     return getPack(DAG, Subtarget, dl, VT, Lo, Hi);
29458   }
29459 
29460   return SDValue();
29461 }
29462 
29463 static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
29464                           SelectionDAG &DAG) {
29465   MVT VT = Op.getSimpleValueType();
29466   SDLoc dl(Op);
29467   SDValue R = Op.getOperand(0);
29468   SDValue Amt = Op.getOperand(1);
29469   unsigned EltSizeInBits = VT.getScalarSizeInBits();
29470   bool ConstantAmt = ISD::isBuildVectorOfConstantSDNodes(Amt.getNode());
29471 
29472   unsigned Opc = Op.getOpcode();
29473   unsigned X86OpcV = getTargetVShiftUniformOpcode(Opc, true);
29474   unsigned X86OpcI = getTargetVShiftUniformOpcode(Opc, false);
29475 
29476   assert(VT.isVector() && "Custom lowering only for vector shifts!");
29477   assert(Subtarget.hasSSE2() && "Only custom lower when we have SSE2!");
29478 
29479   if (SDValue V = LowerShiftByScalarImmediate(Op, DAG, Subtarget))
29480     return V;
29481 
29482   if (SDValue V = LowerShiftByScalarVariable(Op, DAG, Subtarget))
29483     return V;
29484 
29485   if (supportedVectorVarShift(VT, Subtarget, Opc))
29486     return Op;
29487 
29488   // i64 vector arithmetic shift can be emulated with the transform:
29489   // M = lshr(SIGN_MASK, Amt)
29490   // ashr(R, Amt) === sub(xor(lshr(R, Amt), M), M)
29491   if (((VT == MVT::v2i64 && !Subtarget.hasXOP()) ||
29492        (VT == MVT::v4i64 && Subtarget.hasInt256())) &&
29493       Opc == ISD::SRA) {
29494     SDValue S = DAG.getConstant(APInt::getSignMask(64), dl, VT);
29495     SDValue M = DAG.getNode(ISD::SRL, dl, VT, S, Amt);
29496     R = DAG.getNode(ISD::SRL, dl, VT, R, Amt);
29497     R = DAG.getNode(ISD::XOR, dl, VT, R, M);
29498     R = DAG.getNode(ISD::SUB, dl, VT, R, M);
29499     return R;
29500   }
29501 
29502   // XOP has 128-bit variable logical/arithmetic shifts.
29503   // +ve/-ve Amt = shift left/right.
29504   if (Subtarget.hasXOP() && (VT == MVT::v2i64 || VT == MVT::v4i32 ||
29505                              VT == MVT::v8i16 || VT == MVT::v16i8)) {
29506     if (Opc == ISD::SRL || Opc == ISD::SRA)
29507       Amt = DAG.getNegative(Amt, dl, VT);
29508     if (Opc == ISD::SHL || Opc == ISD::SRL)
29509       return DAG.getNode(X86ISD::VPSHL, dl, VT, R, Amt);
29510     if (Opc == ISD::SRA)
29511       return DAG.getNode(X86ISD::VPSHA, dl, VT, R, Amt);
29512   }
29513 
29514   // v2i64 vector logical shifts can efficiently avoid scalarization - do the
29515   // shifts per-lane and then shuffle the partial results back together.
29516   if (VT == MVT::v2i64 && Opc != ISD::SRA) {
29517     // Splat the shift amounts so the scalar shifts above will catch it.
29518     SDValue Amt0 = DAG.getVectorShuffle(VT, dl, Amt, Amt, {0, 0});
29519     SDValue Amt1 = DAG.getVectorShuffle(VT, dl, Amt, Amt, {1, 1});
29520     SDValue R0 = DAG.getNode(Opc, dl, VT, R, Amt0);
29521     SDValue R1 = DAG.getNode(Opc, dl, VT, R, Amt1);
29522     return DAG.getVectorShuffle(VT, dl, R0, R1, {0, 3});
29523   }
29524 
29525   // If possible, lower this shift as a sequence of two shifts by
29526   // constant plus a BLENDing shuffle instead of scalarizing it.
29527   // Example:
29528   //   (v4i32 (srl A, (build_vector < X, Y, Y, Y>)))
29529   //
29530   // Could be rewritten as:
29531   //   (v4i32 (MOVSS (srl A, <Y,Y,Y,Y>), (srl A, <X,X,X,X>)))
29532   //
29533   // The advantage is that the two shifts from the example would be
29534   // lowered as X86ISD::VSRLI nodes in parallel before blending.
29535   if (ConstantAmt && (VT == MVT::v8i16 || VT == MVT::v4i32 ||
29536                       (VT == MVT::v16i16 && Subtarget.hasInt256()))) {
29537     SDValue Amt1, Amt2;
29538     unsigned NumElts = VT.getVectorNumElements();
29539     SmallVector<int, 8> ShuffleMask;
29540     for (unsigned i = 0; i != NumElts; ++i) {
29541       SDValue A = Amt->getOperand(i);
29542       if (A.isUndef()) {
29543         ShuffleMask.push_back(SM_SentinelUndef);
29544         continue;
29545       }
29546       if (!Amt1 || Amt1 == A) {
29547         ShuffleMask.push_back(i);
29548         Amt1 = A;
29549         continue;
29550       }
29551       if (!Amt2 || Amt2 == A) {
29552         ShuffleMask.push_back(i + NumElts);
29553         Amt2 = A;
29554         continue;
29555       }
29556       break;
29557     }
29558 
29559     // Only perform this blend if we can perform it without loading a mask.
29560     if (ShuffleMask.size() == NumElts && Amt1 && Amt2 &&
29561         (VT != MVT::v16i16 ||
29562          is128BitLaneRepeatedShuffleMask(VT, ShuffleMask)) &&
29563         (VT == MVT::v4i32 || Subtarget.hasSSE41() || Opc != ISD::SHL ||
29564          canWidenShuffleElements(ShuffleMask))) {
29565       auto *Cst1 = dyn_cast<ConstantSDNode>(Amt1);
29566       auto *Cst2 = dyn_cast<ConstantSDNode>(Amt2);
29567       if (Cst1 && Cst2 && Cst1->getAPIntValue().ult(EltSizeInBits) &&
29568           Cst2->getAPIntValue().ult(EltSizeInBits)) {
29569         SDValue Shift1 = getTargetVShiftByConstNode(X86OpcI, dl, VT, R,
29570                                                     Cst1->getZExtValue(), DAG);
29571         SDValue Shift2 = getTargetVShiftByConstNode(X86OpcI, dl, VT, R,
29572                                                     Cst2->getZExtValue(), DAG);
29573         return DAG.getVectorShuffle(VT, dl, Shift1, Shift2, ShuffleMask);
29574       }
29575     }
29576   }
29577 
29578   // If possible, lower this packed shift into a vector multiply instead of
29579   // expanding it into a sequence of scalar shifts.
29580   // For v32i8 cases, it might be quicker to split/extend to vXi16 shifts.
29581   if (Opc == ISD::SHL && !(VT == MVT::v32i8 && (Subtarget.hasXOP() ||
29582                                                 Subtarget.canExtendTo512BW())))
29583     if (SDValue Scale = convertShiftLeftToScale(Amt, dl, Subtarget, DAG))
29584       return DAG.getNode(ISD::MUL, dl, VT, R, Scale);
29585 
29586   // Constant ISD::SRL can be performed efficiently on vXi16 vectors as we
29587   // can replace with ISD::MULHU, creating scale factor from (NumEltBits - Amt).
29588   if (Opc == ISD::SRL && ConstantAmt &&
29589       (VT == MVT::v8i16 || (VT == MVT::v16i16 && Subtarget.hasInt256()))) {
29590     SDValue EltBits = DAG.getConstant(EltSizeInBits, dl, VT);
29591     SDValue RAmt = DAG.getNode(ISD::SUB, dl, VT, EltBits, Amt);
29592     if (SDValue Scale = convertShiftLeftToScale(RAmt, dl, Subtarget, DAG)) {
29593       SDValue Zero = DAG.getConstant(0, dl, VT);
29594       SDValue ZAmt = DAG.getSetCC(dl, VT, Amt, Zero, ISD::SETEQ);
29595       SDValue Res = DAG.getNode(ISD::MULHU, dl, VT, R, Scale);
29596       return DAG.getSelect(dl, VT, ZAmt, R, Res);
29597     }
29598   }
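
  // Illustration (comment only): for vXi16, srl(x, A) == mulhu(x, 1 << (16 - A))
  // since MULHU returns (x * 2^(16 - A)) >> 16 == x >> A. E.g. A = 3 uses a
  // scale of 0x2000. A == 0 is the one exception (2^16 does not fit in an i16
  // scale element), which is why the select above returns the unshifted value
  // for zero amounts.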
29599 
29600   // Constant ISD::SRA can be performed efficiently on vXi16 vectors as we
29601   // can replace with ISD::MULHS, creating scale factor from (NumEltBits - Amt).
29602   // TODO: Special case handling for shift by 0/1, really we can afford either
29603   // of these cases in pre-SSE41/XOP/AVX512 but not both.
29604   if (Opc == ISD::SRA && ConstantAmt &&
29605       (VT == MVT::v8i16 || (VT == MVT::v16i16 && Subtarget.hasInt256())) &&
29606       ((Subtarget.hasSSE41() && !Subtarget.hasXOP() &&
29607         !Subtarget.hasAVX512()) ||
29608        DAG.isKnownNeverZero(Amt))) {
29609     SDValue EltBits = DAG.getConstant(EltSizeInBits, dl, VT);
29610     SDValue RAmt = DAG.getNode(ISD::SUB, dl, VT, EltBits, Amt);
29611     if (SDValue Scale = convertShiftLeftToScale(RAmt, dl, Subtarget, DAG)) {
29612       SDValue Amt0 =
29613           DAG.getSetCC(dl, VT, Amt, DAG.getConstant(0, dl, VT), ISD::SETEQ);
29614       SDValue Amt1 =
29615           DAG.getSetCC(dl, VT, Amt, DAG.getConstant(1, dl, VT), ISD::SETEQ);
29616       SDValue Sra1 =
29617           getTargetVShiftByConstNode(X86ISD::VSRAI, dl, VT, R, 1, DAG);
29618       SDValue Res = DAG.getNode(ISD::MULHS, dl, VT, R, Scale);
29619       Res = DAG.getSelect(dl, VT, Amt0, R, Res);
29620       return DAG.getSelect(dl, VT, Amt1, Sra1, Res);
29621     }
29622   }
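
  // Illustration (comment only): the MULHS form mirrors the MULHU one above:
  // sra(x, A) == mulhs(x, 1 << (16 - A)) for 2 <= A <= 15. A == 0 would need
  // 2^16 and A == 1 would need 2^15, which the signed multiply sees as -32768,
  // so both are patched up by the selects above (returning x and sra(x, 1)).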
29623 
29624   // v4i32 Non Uniform Shifts.
29625   // If the shift amount is constant we can shift each lane using the SSE2
29626   // immediate shifts, else we need to zero-extend each lane to the lower i64
29627   // and shift using the SSE2 variable shifts.
29628   // The separate results can then be blended together.
29629   if (VT == MVT::v4i32) {
29630     SDValue Amt0, Amt1, Amt2, Amt3;
29631     if (ConstantAmt) {
29632       Amt0 = DAG.getVectorShuffle(VT, dl, Amt, DAG.getUNDEF(VT), {0, 0, 0, 0});
29633       Amt1 = DAG.getVectorShuffle(VT, dl, Amt, DAG.getUNDEF(VT), {1, 1, 1, 1});
29634       Amt2 = DAG.getVectorShuffle(VT, dl, Amt, DAG.getUNDEF(VT), {2, 2, 2, 2});
29635       Amt3 = DAG.getVectorShuffle(VT, dl, Amt, DAG.getUNDEF(VT), {3, 3, 3, 3});
29636     } else {
29637       // The SSE2 shifts use the lower i64 as the same shift amount for
29638       // all lanes and the upper i64 is ignored. On AVX we're better off
29639       // just zero-extending, but for SSE just duplicating the top 16-bits is
29640       // cheaper and has the same effect for out of range values.
29641       if (Subtarget.hasAVX()) {
29642         SDValue Z = DAG.getConstant(0, dl, VT);
29643         Amt0 = DAG.getVectorShuffle(VT, dl, Amt, Z, {0, 4, -1, -1});
29644         Amt1 = DAG.getVectorShuffle(VT, dl, Amt, Z, {1, 5, -1, -1});
29645         Amt2 = DAG.getVectorShuffle(VT, dl, Amt, Z, {2, 6, -1, -1});
29646         Amt3 = DAG.getVectorShuffle(VT, dl, Amt, Z, {3, 7, -1, -1});
29647       } else {
29648         SDValue Amt01 = DAG.getBitcast(MVT::v8i16, Amt);
29649         SDValue Amt23 = DAG.getVectorShuffle(MVT::v8i16, dl, Amt01, Amt01,
29650                                              {4, 5, 6, 7, -1, -1, -1, -1});
29651         SDValue Msk02 = getV4X86ShuffleImm8ForMask({0, 1, 1, 1}, dl, DAG);
29652         SDValue Msk13 = getV4X86ShuffleImm8ForMask({2, 3, 3, 3}, dl, DAG);
29653         Amt0 = DAG.getNode(X86ISD::PSHUFLW, dl, MVT::v8i16, Amt01, Msk02);
29654         Amt1 = DAG.getNode(X86ISD::PSHUFLW, dl, MVT::v8i16, Amt01, Msk13);
29655         Amt2 = DAG.getNode(X86ISD::PSHUFLW, dl, MVT::v8i16, Amt23, Msk02);
29656         Amt3 = DAG.getNode(X86ISD::PSHUFLW, dl, MVT::v8i16, Amt23, Msk13);
29657       }
29658     }
29659 
29660     unsigned ShOpc = ConstantAmt ? Opc : X86OpcV;
29661     SDValue R0 = DAG.getNode(ShOpc, dl, VT, R, DAG.getBitcast(VT, Amt0));
29662     SDValue R1 = DAG.getNode(ShOpc, dl, VT, R, DAG.getBitcast(VT, Amt1));
29663     SDValue R2 = DAG.getNode(ShOpc, dl, VT, R, DAG.getBitcast(VT, Amt2));
29664     SDValue R3 = DAG.getNode(ShOpc, dl, VT, R, DAG.getBitcast(VT, Amt3));
29665 
29666     // Merge the shifted lane results optimally with/without PBLENDW.
29667     // TODO - ideally shuffle combining would handle this.
29668     if (Subtarget.hasSSE41()) {
29669       SDValue R02 = DAG.getVectorShuffle(VT, dl, R0, R2, {0, -1, 6, -1});
29670       SDValue R13 = DAG.getVectorShuffle(VT, dl, R1, R3, {-1, 1, -1, 7});
29671       return DAG.getVectorShuffle(VT, dl, R02, R13, {0, 5, 2, 7});
29672     }
29673     SDValue R01 = DAG.getVectorShuffle(VT, dl, R0, R1, {0, -1, -1, 5});
29674     SDValue R23 = DAG.getVectorShuffle(VT, dl, R2, R3, {2, -1, -1, 7});
29675     return DAG.getVectorShuffle(VT, dl, R01, R23, {0, 3, 4, 7});
29676   }
29677 
29678   // It's worth extending once and using the vXi16/vXi32 shifts for smaller
29679   // types, but without AVX512 the extra overheads to get from vXi8 to vXi32
29680   // make the existing SSE solution better.
29681   // NOTE: We honor prefered vector width before promoting to 512-bits.
29682   if ((Subtarget.hasInt256() && VT == MVT::v8i16) ||
29683       (Subtarget.canExtendTo512DQ() && VT == MVT::v16i16) ||
29684       (Subtarget.canExtendTo512DQ() && VT == MVT::v16i8) ||
29685       (Subtarget.canExtendTo512BW() && VT == MVT::v32i8) ||
29686       (Subtarget.hasBWI() && Subtarget.hasVLX() && VT == MVT::v16i8)) {
29687     assert((!Subtarget.hasBWI() || VT == MVT::v32i8 || VT == MVT::v16i8) &&
29688            "Unexpected vector type");
29689     MVT EvtSVT = Subtarget.hasBWI() ? MVT::i16 : MVT::i32;
29690     MVT ExtVT = MVT::getVectorVT(EvtSVT, VT.getVectorNumElements());
29691     unsigned ExtOpc = Opc == ISD::SRA ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
29692     R = DAG.getNode(ExtOpc, dl, ExtVT, R);
29693     Amt = DAG.getNode(ISD::ZERO_EXTEND, dl, ExtVT, Amt);
29694     return DAG.getNode(ISD::TRUNCATE, dl, VT,
29695                        DAG.getNode(Opc, dl, ExtVT, R, Amt));
29696   }
29697 
29698   // Constant ISD::SRA/SRL can be performed efficiently on vXi8 vectors as we
29699   // extend to vXi16 to perform a MUL scale effectively as a MUL_LOHI.
29700   if (ConstantAmt && (Opc == ISD::SRA || Opc == ISD::SRL) &&
29701       (VT == MVT::v16i8 || (VT == MVT::v32i8 && Subtarget.hasInt256()) ||
29702        (VT == MVT::v64i8 && Subtarget.hasBWI())) &&
29703       !Subtarget.hasXOP()) {
29704     int NumElts = VT.getVectorNumElements();
29705     SDValue Cst8 = DAG.getTargetConstant(8, dl, MVT::i8);
29706 
29707     // Extend constant shift amount to vXi16 (it doesn't matter if the type
29708     // isn't legal).
29709     MVT ExVT = MVT::getVectorVT(MVT::i16, NumElts);
29710     Amt = DAG.getZExtOrTrunc(Amt, dl, ExVT);
29711     Amt = DAG.getNode(ISD::SUB, dl, ExVT, DAG.getConstant(8, dl, ExVT), Amt);
29712     Amt = DAG.getNode(ISD::SHL, dl, ExVT, DAG.getConstant(1, dl, ExVT), Amt);
29713     assert(ISD::isBuildVectorOfConstantSDNodes(Amt.getNode()) &&
29714            "Constant build vector expected");
29715 
29716     if (VT == MVT::v16i8 && Subtarget.hasInt256()) {
29717       bool IsSigned = Opc == ISD::SRA;
29718       R = DAG.getExtOrTrunc(IsSigned, R, dl, ExVT);
29719       R = DAG.getNode(ISD::MUL, dl, ExVT, R, Amt);
29720       R = DAG.getNode(X86ISD::VSRLI, dl, ExVT, R, Cst8);
29721       return DAG.getZExtOrTrunc(R, dl, VT);
29722     }
29723 
29724     SmallVector<SDValue, 16> LoAmt, HiAmt;
29725     for (int i = 0; i != NumElts; i += 16) {
29726       for (int j = 0; j != 8; ++j) {
29727         LoAmt.push_back(Amt.getOperand(i + j));
29728         HiAmt.push_back(Amt.getOperand(i + j + 8));
29729       }
29730     }
29731 
29732     MVT VT16 = MVT::getVectorVT(MVT::i16, NumElts / 2);
29733     SDValue LoA = DAG.getBuildVector(VT16, dl, LoAmt);
29734     SDValue HiA = DAG.getBuildVector(VT16, dl, HiAmt);
29735 
29736     SDValue LoR = DAG.getBitcast(VT16, getUnpackl(DAG, dl, VT, R, R));
29737     SDValue HiR = DAG.getBitcast(VT16, getUnpackh(DAG, dl, VT, R, R));
29738     LoR = DAG.getNode(X86OpcI, dl, VT16, LoR, Cst8);
29739     HiR = DAG.getNode(X86OpcI, dl, VT16, HiR, Cst8);
29740     LoR = DAG.getNode(ISD::MUL, dl, VT16, LoR, LoA);
29741     HiR = DAG.getNode(ISD::MUL, dl, VT16, HiR, HiA);
29742     LoR = DAG.getNode(X86ISD::VSRLI, dl, VT16, LoR, Cst8);
29743     HiR = DAG.getNode(X86ISD::VSRLI, dl, VT16, HiR, Cst8);
29744     return DAG.getNode(X86ISD::PACKUS, dl, VT, LoR, HiR);
29745   }
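
  // Illustration (comment only): each byte above effectively computes
  // (ext16(x) * (1 << (8 - A))) >> 8, which equals ext16(x) >> A, i.e. the
  // requested per-byte shift. E.g. for SRA with x = 0x90 (-112) and A = 4:
  // sext gives 0xFF90, times 16 gives 0xF900, and the upper byte 0xF9 is -7.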
29746 
29747   if (VT == MVT::v16i8 ||
29748       (VT == MVT::v32i8 && Subtarget.hasInt256() && !Subtarget.hasXOP()) ||
29749       (VT == MVT::v64i8 && Subtarget.hasBWI())) {
29750     MVT ExtVT = MVT::getVectorVT(MVT::i16, VT.getVectorNumElements() / 2);
29751 
29752     auto SignBitSelect = [&](MVT SelVT, SDValue Sel, SDValue V0, SDValue V1) {
29753       if (VT.is512BitVector()) {
29754         // On AVX512BW targets we make use of the fact that VSELECT lowers
29755         // to a masked blend which selects bytes based just on the sign bit
29756         // extracted to a mask.
29757         MVT MaskVT = MVT::getVectorVT(MVT::i1, VT.getVectorNumElements());
29758         V0 = DAG.getBitcast(VT, V0);
29759         V1 = DAG.getBitcast(VT, V1);
29760         Sel = DAG.getBitcast(VT, Sel);
29761         Sel = DAG.getSetCC(dl, MaskVT, DAG.getConstant(0, dl, VT), Sel,
29762                            ISD::SETGT);
29763         return DAG.getBitcast(SelVT, DAG.getSelect(dl, VT, Sel, V0, V1));
29764       } else if (Subtarget.hasSSE41()) {
29765         // On SSE41 targets we can use PBLENDVB which selects bytes based just
29766         // on the sign bit.
29767         V0 = DAG.getBitcast(VT, V0);
29768         V1 = DAG.getBitcast(VT, V1);
29769         Sel = DAG.getBitcast(VT, Sel);
29770         return DAG.getBitcast(SelVT,
29771                               DAG.getNode(X86ISD::BLENDV, dl, VT, Sel, V0, V1));
29772       }
29773       // On pre-SSE41 targets we test for the sign bit by comparing to
29774       // zero - a negative value will set all bits of the lanes to true
29775       // and VSELECT uses that in its OR(AND(V0,C),AND(V1,~C)) lowering.
29776       SDValue Z = DAG.getConstant(0, dl, SelVT);
29777       SDValue C = DAG.getNode(X86ISD::PCMPGT, dl, SelVT, Z, Sel);
29778       return DAG.getSelect(dl, SelVT, C, V0, V1);
29779     };
29780 
29781     // Turn 'a' into a mask suitable for VSELECT: a = a << 5;
29782     // We can safely do this using i16 shifts as we're only interested in
29783     // the 3 lower bits of each byte.
29784     Amt = DAG.getBitcast(ExtVT, Amt);
29785     Amt = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, ExtVT, Amt, 5, DAG);
29786     Amt = DAG.getBitcast(VT, Amt);
29787 
29788     if (Opc == ISD::SHL || Opc == ISD::SRL) {
29789       // r = VSELECT(r, shift(r, 4), a);
29790       SDValue M = DAG.getNode(Opc, dl, VT, R, DAG.getConstant(4, dl, VT));
29791       R = SignBitSelect(VT, Amt, M, R);
29792 
29793       // a += a
29794       Amt = DAG.getNode(ISD::ADD, dl, VT, Amt, Amt);
29795 
29796       // r = VSELECT(r, shift(r, 2), a);
29797       M = DAG.getNode(Opc, dl, VT, R, DAG.getConstant(2, dl, VT));
29798       R = SignBitSelect(VT, Amt, M, R);
29799 
29800       // a += a
29801       Amt = DAG.getNode(ISD::ADD, dl, VT, Amt, Amt);
29802 
29803       // return VSELECT(r, shift(r, 1), a);
29804       M = DAG.getNode(Opc, dl, VT, R, DAG.getConstant(1, dl, VT));
29805       R = SignBitSelect(VT, Amt, M, R);
29806       return R;
29807     }
29808 
29809     if (Opc == ISD::SRA) {
29810       // For SRA we need to unpack each byte to the higher byte of a i16 vector
29811       // so we can correctly sign extend. We don't care what happens to the
29812       // lower byte.
29813       SDValue ALo = getUnpackl(DAG, dl, VT, DAG.getUNDEF(VT), Amt);
29814       SDValue AHi = getUnpackh(DAG, dl, VT, DAG.getUNDEF(VT), Amt);
29815       SDValue RLo = getUnpackl(DAG, dl, VT, DAG.getUNDEF(VT), R);
29816       SDValue RHi = getUnpackh(DAG, dl, VT, DAG.getUNDEF(VT), R);
29817       ALo = DAG.getBitcast(ExtVT, ALo);
29818       AHi = DAG.getBitcast(ExtVT, AHi);
29819       RLo = DAG.getBitcast(ExtVT, RLo);
29820       RHi = DAG.getBitcast(ExtVT, RHi);
29821 
29822       // r = VSELECT(r, shift(r, 4), a);
29823       SDValue MLo = getTargetVShiftByConstNode(X86OpcI, dl, ExtVT, RLo, 4, DAG);
29824       SDValue MHi = getTargetVShiftByConstNode(X86OpcI, dl, ExtVT, RHi, 4, DAG);
29825       RLo = SignBitSelect(ExtVT, ALo, MLo, RLo);
29826       RHi = SignBitSelect(ExtVT, AHi, MHi, RHi);
29827 
29828       // a += a
29829       ALo = DAG.getNode(ISD::ADD, dl, ExtVT, ALo, ALo);
29830       AHi = DAG.getNode(ISD::ADD, dl, ExtVT, AHi, AHi);
29831 
29832       // r = VSELECT(r, shift(r, 2), a);
29833       MLo = getTargetVShiftByConstNode(X86OpcI, dl, ExtVT, RLo, 2, DAG);
29834       MHi = getTargetVShiftByConstNode(X86OpcI, dl, ExtVT, RHi, 2, DAG);
29835       RLo = SignBitSelect(ExtVT, ALo, MLo, RLo);
29836       RHi = SignBitSelect(ExtVT, AHi, MHi, RHi);
29837 
29838       // a += a
29839       ALo = DAG.getNode(ISD::ADD, dl, ExtVT, ALo, ALo);
29840       AHi = DAG.getNode(ISD::ADD, dl, ExtVT, AHi, AHi);
29841 
29842       // r = VSELECT(r, shift(r, 1), a);
29843       MLo = getTargetVShiftByConstNode(X86OpcI, dl, ExtVT, RLo, 1, DAG);
29844       MHi = getTargetVShiftByConstNode(X86OpcI, dl, ExtVT, RHi, 1, DAG);
29845       RLo = SignBitSelect(ExtVT, ALo, MLo, RLo);
29846       RHi = SignBitSelect(ExtVT, AHi, MHi, RHi);
29847 
29848       // Logical shift the result back to the lower byte, leaving a zero upper
29849       // byte meaning that we can safely pack with PACKUSWB.
29850       RLo = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, ExtVT, RLo, 8, DAG);
29851       RHi = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, ExtVT, RHi, 8, DAG);
29852       return DAG.getNode(X86ISD::PACKUS, dl, VT, RLo, RHi);
29853     }
29854   }
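
  // Illustration (comment only): the ladder above is the branchless form of a
  // per-byte loop such as
  //   if (a & 4) r = shift(r, 4);
  //   if (a & 2) r = shift(r, 2);
  //   if (a & 1) r = shift(r, 1);
  // 'a << 5' moves bit 2 of the 3-bit amount into each byte's sign bit so
  // PBLENDVB/PCMPGT can select, and every 'a += a' exposes the next lower bit.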
29855 
29856   if (Subtarget.hasInt256() && !Subtarget.hasXOP() && VT == MVT::v16i16) {
29857     MVT ExtVT = MVT::v8i32;
29858     SDValue Z = DAG.getConstant(0, dl, VT);
29859     SDValue ALo = getUnpackl(DAG, dl, VT, Amt, Z);
29860     SDValue AHi = getUnpackh(DAG, dl, VT, Amt, Z);
29861     SDValue RLo = getUnpackl(DAG, dl, VT, Z, R);
29862     SDValue RHi = getUnpackh(DAG, dl, VT, Z, R);
29863     ALo = DAG.getBitcast(ExtVT, ALo);
29864     AHi = DAG.getBitcast(ExtVT, AHi);
29865     RLo = DAG.getBitcast(ExtVT, RLo);
29866     RHi = DAG.getBitcast(ExtVT, RHi);
29867     SDValue Lo = DAG.getNode(Opc, dl, ExtVT, RLo, ALo);
29868     SDValue Hi = DAG.getNode(Opc, dl, ExtVT, RHi, AHi);
29869     Lo = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, ExtVT, Lo, 16, DAG);
29870     Hi = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, ExtVT, Hi, 16, DAG);
29871     return DAG.getNode(X86ISD::PACKUS, dl, VT, Lo, Hi);
29872   }
29873 
29874   if (VT == MVT::v8i16) {
29875     // If we have a constant shift amount, the non-SSE41 path is best as
29876     // avoiding bitcasts makes it easier to constant fold and reduce to PBLENDW.
29877     bool UseSSE41 = Subtarget.hasSSE41() &&
29878                     !ISD::isBuildVectorOfConstantSDNodes(Amt.getNode());
29879 
29880     auto SignBitSelect = [&](SDValue Sel, SDValue V0, SDValue V1) {
29881       // On SSE41 targets we can use PBLENDVB which selects bytes based just on
29882       // the sign bit.
29883       if (UseSSE41) {
29884         MVT ExtVT = MVT::getVectorVT(MVT::i8, VT.getVectorNumElements() * 2);
29885         V0 = DAG.getBitcast(ExtVT, V0);
29886         V1 = DAG.getBitcast(ExtVT, V1);
29887         Sel = DAG.getBitcast(ExtVT, Sel);
29888         return DAG.getBitcast(
29889             VT, DAG.getNode(X86ISD::BLENDV, dl, ExtVT, Sel, V0, V1));
29890       }
29891       // On pre-SSE41 targets we splat the sign bit - a negative value will
29892       // set all bits of the lanes to true and VSELECT uses that in
29893       // its OR(AND(V0,C),AND(V1,~C)) lowering.
29894       SDValue C =
29895           getTargetVShiftByConstNode(X86ISD::VSRAI, dl, VT, Sel, 15, DAG);
29896       return DAG.getSelect(dl, VT, C, V0, V1);
29897     };
29898 
29899     // Turn 'a' into a mask suitable for VSELECT: a = a << 12;
29900     if (UseSSE41) {
29901       // On SSE41 targets we need to replicate the shift mask in both
29902       // bytes for PBLENDVB.
29903       Amt = DAG.getNode(
29904           ISD::OR, dl, VT,
29905           getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, Amt, 4, DAG),
29906           getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, Amt, 12, DAG));
29907     } else {
29908       Amt = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, Amt, 12, DAG);
29909     }
29910 
29911     // r = VSELECT(r, shift(r, 8), a);
29912     SDValue M = getTargetVShiftByConstNode(X86OpcI, dl, VT, R, 8, DAG);
29913     R = SignBitSelect(Amt, M, R);
29914 
29915     // a += a
29916     Amt = DAG.getNode(ISD::ADD, dl, VT, Amt, Amt);
29917 
29918     // r = VSELECT(r, shift(r, 4), a);
29919     M = getTargetVShiftByConstNode(X86OpcI, dl, VT, R, 4, DAG);
29920     R = SignBitSelect(Amt, M, R);
29921 
29922     // a += a
29923     Amt = DAG.getNode(ISD::ADD, dl, VT, Amt, Amt);
29924 
29925     // r = VSELECT(r, shift(r, 2), a);
29926     M = getTargetVShiftByConstNode(X86OpcI, dl, VT, R, 2, DAG);
29927     R = SignBitSelect(Amt, M, R);
29928 
29929     // a += a
29930     Amt = DAG.getNode(ISD::ADD, dl, VT, Amt, Amt);
29931 
29932     // return VSELECT(r, shift(r, 1), a);
29933     M = getTargetVShiftByConstNode(X86OpcI, dl, VT, R, 1, DAG);
29934     R = SignBitSelect(Amt, M, R);
29935     return R;
29936   }
29937 
29938   // Decompose 256-bit shifts into 128-bit shifts.
29939   if (VT.is256BitVector())
29940     return splitVectorIntBinary(Op, DAG, dl);
29941 
29942   if (VT == MVT::v32i16 || VT == MVT::v64i8)
29943     return splitVectorIntBinary(Op, DAG, dl);
29944 
29945   return SDValue();
29946 }
29947 
29948 static SDValue LowerFunnelShift(SDValue Op, const X86Subtarget &Subtarget,
29949                                 SelectionDAG &DAG) {
29950   MVT VT = Op.getSimpleValueType();
29951   assert((Op.getOpcode() == ISD::FSHL || Op.getOpcode() == ISD::FSHR) &&
29952          "Unexpected funnel shift opcode!");
29953 
29954   SDLoc DL(Op);
29955   SDValue Op0 = Op.getOperand(0);
29956   SDValue Op1 = Op.getOperand(1);
29957   SDValue Amt = Op.getOperand(2);
29958   unsigned EltSizeInBits = VT.getScalarSizeInBits();
29959   bool IsFSHR = Op.getOpcode() == ISD::FSHR;
29960 
29961   if (VT.isVector()) {
29962     APInt APIntShiftAmt;
29963     bool IsCstSplat = X86::isConstantSplat(Amt, APIntShiftAmt);
29964     unsigned NumElts = VT.getVectorNumElements();
29965 
29966     if (Subtarget.hasVBMI2() && EltSizeInBits > 8) {
29967       if (IsFSHR)
29968         std::swap(Op0, Op1);
29969 
29970       if (IsCstSplat) {
29971         uint64_t ShiftAmt = APIntShiftAmt.urem(EltSizeInBits);
29972         SDValue Imm = DAG.getTargetConstant(ShiftAmt, DL, MVT::i8);
29973         return getAVX512Node(IsFSHR ? X86ISD::VSHRD : X86ISD::VSHLD, DL, VT,
29974                              {Op0, Op1, Imm}, DAG, Subtarget);
29975       }
29976       return getAVX512Node(IsFSHR ? X86ISD::VSHRDV : X86ISD::VSHLDV, DL, VT,
29977                            {Op0, Op1, Amt}, DAG, Subtarget);
29978     }
29979     assert((VT == MVT::v16i8 || VT == MVT::v32i8 || VT == MVT::v64i8 ||
29980             VT == MVT::v8i16 || VT == MVT::v16i16 || VT == MVT::v32i16 ||
29981             VT == MVT::v4i32 || VT == MVT::v8i32 || VT == MVT::v16i32) &&
29982            "Unexpected funnel shift type!");
29983 
29984     // fshl(x,y,z) -> unpack(y,x) << (z & (bw-1))) >> bw.
29985     // fshr(x,y,z) -> unpack(y,x) >> (z & (bw-1))).
29986     if (IsCstSplat) {
29987       // TODO: Can't use generic expansion as UNDEF amt elements can be
29988       // converted to other values when folded to shift amounts, losing the
29989       // splat.
29990       uint64_t ShiftAmt = APIntShiftAmt.urem(EltSizeInBits);
29991       uint64_t ShXAmt = IsFSHR ? (EltSizeInBits - ShiftAmt) : ShiftAmt;
29992       uint64_t ShYAmt = IsFSHR ? ShiftAmt : (EltSizeInBits - ShiftAmt);
29993       assert((ShXAmt + ShYAmt) == EltSizeInBits && "Illegal funnel shift");
29994       MVT WideVT = MVT::getVectorVT(MVT::i16, NumElts / 2);
29995 
29996       if (EltSizeInBits == 8 &&
29997           (Subtarget.hasXOP() ||
29998            (useVPTERNLOG(Subtarget, VT) &&
29999             supportedVectorShiftWithImm(WideVT, Subtarget, ISD::SHL)))) {
30000         // For vXi8 cases on Subtargets that can perform VPCMOV/VPTERNLOG
30001         // bit-select - lower using vXi16 shifts and then perform the bitmask at
30002         // the original vector width to handle cases where we split.
30003         APInt MaskX = APInt::getHighBitsSet(8, 8 - ShXAmt);
30004         APInt MaskY = APInt::getLowBitsSet(8, 8 - ShYAmt);
30005         SDValue ShX =
30006             DAG.getNode(ISD::SHL, DL, WideVT, DAG.getBitcast(WideVT, Op0),
30007                         DAG.getShiftAmountConstant(ShXAmt, WideVT, DL));
30008         SDValue ShY =
30009             DAG.getNode(ISD::SRL, DL, WideVT, DAG.getBitcast(WideVT, Op1),
30010                         DAG.getShiftAmountConstant(ShYAmt, WideVT, DL));
30011         ShX = DAG.getNode(ISD::AND, DL, VT, DAG.getBitcast(VT, ShX),
30012                           DAG.getConstant(MaskX, DL, VT));
30013         ShY = DAG.getNode(ISD::AND, DL, VT, DAG.getBitcast(VT, ShY),
30014                           DAG.getConstant(MaskY, DL, VT));
30015         return DAG.getNode(ISD::OR, DL, VT, ShX, ShY);
30016       }
30017 
30018       SDValue ShX = DAG.getNode(ISD::SHL, DL, VT, Op0,
30019                                 DAG.getShiftAmountConstant(ShXAmt, VT, DL));
30020       SDValue ShY = DAG.getNode(ISD::SRL, DL, VT, Op1,
30021                                 DAG.getShiftAmountConstant(ShYAmt, VT, DL));
30022       return DAG.getNode(ISD::OR, DL, VT, ShX, ShY);
30023     }
30024 
30025     SDValue AmtMask = DAG.getConstant(EltSizeInBits - 1, DL, VT);
30026     SDValue AmtMod = DAG.getNode(ISD::AND, DL, VT, Amt, AmtMask);
30027     bool IsCst = ISD::isBuildVectorOfConstantSDNodes(AmtMod.getNode());
30028 
30029     // Constant vXi16 funnel shifts can be efficiently handled by default.
30030     if (IsCst && EltSizeInBits == 16)
30031       return SDValue();
30032 
30033     unsigned ShiftOpc = IsFSHR ? ISD::SRL : ISD::SHL;
30034     MVT ExtSVT = MVT::getIntegerVT(2 * EltSizeInBits);
30035     MVT ExtVT = MVT::getVectorVT(ExtSVT, NumElts / 2);
30036 
30037     // Split 256-bit integers on XOP/pre-AVX2 targets.
30038     // Split 512-bit integers on non 512-bit BWI targets.
30039     if ((VT.is256BitVector() && ((Subtarget.hasXOP() && EltSizeInBits < 16) ||
30040                                  !Subtarget.hasAVX2())) ||
30041         (VT.is512BitVector() && !Subtarget.useBWIRegs() &&
30042          EltSizeInBits < 32)) {
30043       // Pre-mask the amount modulo using the wider vector.
30044       Op = DAG.getNode(Op.getOpcode(), DL, VT, Op0, Op1, AmtMod);
30045       return splitVectorOp(Op, DAG, DL);
30046     }
30047 
30048     // Attempt to fold scalar shift as unpack(y,x) << zext(splat(z))
30049     if (supportedVectorShiftWithBaseAmnt(ExtVT, Subtarget, ShiftOpc)) {
30050       int ScalarAmtIdx = -1;
30051       if (SDValue ScalarAmt = DAG.getSplatSourceVector(AmtMod, ScalarAmtIdx)) {
30052         // Uniform vXi16 funnel shifts can be efficiently handled by default.
30053         if (EltSizeInBits == 16)
30054           return SDValue();
30055 
30056         SDValue Lo = DAG.getBitcast(ExtVT, getUnpackl(DAG, DL, VT, Op1, Op0));
30057         SDValue Hi = DAG.getBitcast(ExtVT, getUnpackh(DAG, DL, VT, Op1, Op0));
30058         Lo = getTargetVShiftNode(ShiftOpc, DL, ExtVT, Lo, ScalarAmt,
30059                                  ScalarAmtIdx, Subtarget, DAG);
30060         Hi = getTargetVShiftNode(ShiftOpc, DL, ExtVT, Hi, ScalarAmt,
30061                                  ScalarAmtIdx, Subtarget, DAG);
30062         return getPack(DAG, Subtarget, DL, VT, Lo, Hi, !IsFSHR);
30063       }
30064     }
30065 
30066     MVT WideSVT = MVT::getIntegerVT(
30067         std::min<unsigned>(EltSizeInBits * 2, Subtarget.hasBWI() ? 16 : 32));
30068     MVT WideVT = MVT::getVectorVT(WideSVT, NumElts);
30069 
30070     // If per-element shifts are legal, fallback to generic expansion.
30071     if (supportedVectorVarShift(VT, Subtarget, ShiftOpc) || Subtarget.hasXOP())
30072       return SDValue();
30073 
30074     // Attempt to fold as:
30075     // fshl(x,y,z) -> (((aext(x) << bw) | zext(y)) << (z & (bw-1))) >> bw.
30076     // fshr(x,y,z) -> (((aext(x) << bw) | zext(y)) >> (z & (bw-1))).
30077     if (supportedVectorVarShift(WideVT, Subtarget, ShiftOpc) &&
30078         supportedVectorShiftWithImm(WideVT, Subtarget, ShiftOpc)) {
30079       Op0 = DAG.getNode(ISD::ANY_EXTEND, DL, WideVT, Op0);
30080       Op1 = DAG.getNode(ISD::ZERO_EXTEND, DL, WideVT, Op1);
30081       AmtMod = DAG.getNode(ISD::ZERO_EXTEND, DL, WideVT, AmtMod);
30082       Op0 = getTargetVShiftByConstNode(X86ISD::VSHLI, DL, WideVT, Op0,
30083                                        EltSizeInBits, DAG);
30084       SDValue Res = DAG.getNode(ISD::OR, DL, WideVT, Op0, Op1);
30085       Res = DAG.getNode(ShiftOpc, DL, WideVT, Res, AmtMod);
30086       if (!IsFSHR)
30087         Res = getTargetVShiftByConstNode(X86ISD::VSRLI, DL, WideVT, Res,
30088                                          EltSizeInBits, DAG);
30089       return DAG.getNode(ISD::TRUNCATE, DL, VT, Res);
30090     }
30091 
30092     // Attempt to fold per-element (ExtVT) shift as unpack(y,x) << zext(z)
30093     if (((IsCst || !Subtarget.hasAVX512()) && !IsFSHR && EltSizeInBits <= 16) ||
30094         supportedVectorVarShift(ExtVT, Subtarget, ShiftOpc)) {
30095       SDValue Z = DAG.getConstant(0, DL, VT);
30096       SDValue RLo = DAG.getBitcast(ExtVT, getUnpackl(DAG, DL, VT, Op1, Op0));
30097       SDValue RHi = DAG.getBitcast(ExtVT, getUnpackh(DAG, DL, VT, Op1, Op0));
30098       SDValue ALo = DAG.getBitcast(ExtVT, getUnpackl(DAG, DL, VT, AmtMod, Z));
30099       SDValue AHi = DAG.getBitcast(ExtVT, getUnpackh(DAG, DL, VT, AmtMod, Z));
30100       SDValue Lo = DAG.getNode(ShiftOpc, DL, ExtVT, RLo, ALo);
30101       SDValue Hi = DAG.getNode(ShiftOpc, DL, ExtVT, RHi, AHi);
30102       return getPack(DAG, Subtarget, DL, VT, Lo, Hi, !IsFSHR);
30103     }
30104 
30105     // Fallback to generic expansion.
30106     return SDValue();
30107   }
30108   assert(
30109       (VT == MVT::i8 || VT == MVT::i16 || VT == MVT::i32 || VT == MVT::i64) &&
30110       "Unexpected funnel shift type!");
30111 
30112   // Expand slow SHLD/SHRD cases if we are not optimizing for size.
30113   bool OptForSize = DAG.shouldOptForSize();
30114   bool ExpandFunnel = !OptForSize && Subtarget.isSHLDSlow();
30115 
30116   // fshl(x,y,z) -> (((aext(x) << bw) | zext(y)) << (z & (bw-1))) >> bw.
30117   // fshr(x,y,z) -> (((aext(x) << bw) | zext(y)) >> (z & (bw-1))).
30118   if ((VT == MVT::i8 || (ExpandFunnel && VT == MVT::i16)) &&
30119       !isa<ConstantSDNode>(Amt)) {
30120     SDValue Mask = DAG.getConstant(EltSizeInBits - 1, DL, Amt.getValueType());
30121     SDValue HiShift = DAG.getConstant(EltSizeInBits, DL, Amt.getValueType());
30122     Op0 = DAG.getAnyExtOrTrunc(Op0, DL, MVT::i32);
30123     Op1 = DAG.getZExtOrTrunc(Op1, DL, MVT::i32);
30124     Amt = DAG.getNode(ISD::AND, DL, Amt.getValueType(), Amt, Mask);
30125     SDValue Res = DAG.getNode(ISD::SHL, DL, MVT::i32, Op0, HiShift);
30126     Res = DAG.getNode(ISD::OR, DL, MVT::i32, Res, Op1);
30127     if (IsFSHR) {
30128       Res = DAG.getNode(ISD::SRL, DL, MVT::i32, Res, Amt);
30129     } else {
30130       Res = DAG.getNode(ISD::SHL, DL, MVT::i32, Res, Amt);
30131       Res = DAG.getNode(ISD::SRL, DL, MVT::i32, Res, HiShift);
30132     }
30133     return DAG.getZExtOrTrunc(Res, DL, VT);
30134   }
30135 
30136   if (VT == MVT::i8 || ExpandFunnel)
30137     return SDValue();
30138 
30139   // i16 needs to modulo the shift amount, but i32/i64 have implicit modulo.
30140   if (VT == MVT::i16) {
30141     Amt = DAG.getNode(ISD::AND, DL, Amt.getValueType(), Amt,
30142                       DAG.getConstant(15, DL, Amt.getValueType()));
30143     unsigned FSHOp = (IsFSHR ? X86ISD::FSHR : X86ISD::FSHL);
30144     return DAG.getNode(FSHOp, DL, VT, Op0, Op1, Amt);
30145   }
30146 
30147   return Op;
30148 }
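
// Illustration only: a scalar model of the widened funnel-shift fold used in
// LowerFunnelShift above, fshl(x, y, z) == ((concat(x, y) << (z % bw)) >> bw),
// here for i8 operands widened to i16 (zext is used where the DAG code may use
// an any-extend). The helper name is hypothetical and the function is unused.
[[maybe_unused]] static uint8_t scalarFshl8ViaI16(uint8_t X, uint8_t Y,
                                                  unsigned Z) {
  unsigned Wide = (unsigned(X) << 8) | Y; // concat(x, y) in a wider type
  Wide <<= (Z & 7);                       // shift by the masked amount
  return uint8_t(Wide >> 8);              // extract the upper (x) half
}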
30149 
30150 static SDValue LowerRotate(SDValue Op, const X86Subtarget &Subtarget,
30151                            SelectionDAG &DAG) {
30152   MVT VT = Op.getSimpleValueType();
30153   assert(VT.isVector() && "Custom lowering only for vector rotates!");
30154 
30155   SDLoc DL(Op);
30156   SDValue R = Op.getOperand(0);
30157   SDValue Amt = Op.getOperand(1);
30158   unsigned Opcode = Op.getOpcode();
30159   unsigned EltSizeInBits = VT.getScalarSizeInBits();
30160   int NumElts = VT.getVectorNumElements();
30161   bool IsROTL = Opcode == ISD::ROTL;
30162 
30163   // Check for constant splat rotation amount.
30164   APInt CstSplatValue;
30165   bool IsCstSplat = X86::isConstantSplat(Amt, CstSplatValue);
30166 
30167   // Check for splat rotate by zero.
30168   if (IsCstSplat && CstSplatValue.urem(EltSizeInBits) == 0)
30169     return R;
30170 
30171   // AVX512 implicitly uses modulo rotation amounts.
30172   if ((Subtarget.hasVLX() ||
30173        (Subtarget.hasAVX512() && Subtarget.hasEVEX512())) &&
30174       32 <= EltSizeInBits) {
30175     // Attempt to rotate by immediate.
30176     if (IsCstSplat) {
30177       unsigned RotOpc = IsROTL ? X86ISD::VROTLI : X86ISD::VROTRI;
30178       uint64_t RotAmt = CstSplatValue.urem(EltSizeInBits);
30179       return DAG.getNode(RotOpc, DL, VT, R,
30180                          DAG.getTargetConstant(RotAmt, DL, MVT::i8));
30181     }
30182 
30183     // Else, fall-back on VPROLV/VPRORV.
30184     return Op;
30185   }
30186 
30187   // AVX512 VBMI2 vXi16 - lower to funnel shifts.
30188   if (Subtarget.hasVBMI2() && 16 == EltSizeInBits) {
30189     unsigned FunnelOpc = IsROTL ? ISD::FSHL : ISD::FSHR;
30190     return DAG.getNode(FunnelOpc, DL, VT, R, R, Amt);
30191   }
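
  // (rotl(x, y) is exactly fshl(x, x, y) and rotr(x, y) is fshr(x, x, y), so
  // feeding R as both funnel operands above performs the rotate directly.)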
30192 
30193   SDValue Z = DAG.getConstant(0, DL, VT);
30194 
30195   if (!IsROTL) {
30196     // If the ISD::ROTR amount is constant, we're always better converting to
30197     // ISD::ROTL.
30198     if (SDValue NegAmt = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {Z, Amt}))
30199       return DAG.getNode(ISD::ROTL, DL, VT, R, NegAmt);
30200 
30201     // XOP targets always prefer ISD::ROTL.
30202     if (Subtarget.hasXOP())
30203       return DAG.getNode(ISD::ROTL, DL, VT, R,
30204                          DAG.getNode(ISD::SUB, DL, VT, Z, Amt));
30205   }
30206 
30207   // Attempt to use GFNI gf2p8affine to rotate vXi8 by a uniform constant.
30208   if (IsCstSplat && Subtarget.hasGFNI() && VT.getScalarType() == MVT::i8 &&
30209       DAG.getTargetLoweringInfo().isTypeLegal(VT)) {
30210     uint64_t RotAmt = CstSplatValue.urem(EltSizeInBits);
30211     SDValue Mask = getGFNICtrlMask(Opcode, DAG, DL, VT, RotAmt);
30212     return DAG.getNode(X86ISD::GF2P8AFFINEQB, DL, VT, R, Mask,
30213                        DAG.getTargetConstant(0, DL, MVT::i8));
30214   }
30215 
30216   // Split 256-bit integers on XOP/pre-AVX2 targets.
30217   if (VT.is256BitVector() && (Subtarget.hasXOP() || !Subtarget.hasAVX2()))
30218     return splitVectorIntBinary(Op, DAG, DL);
30219 
30220   // XOP has 128-bit vector variable + immediate rotates.
30221   // +ve/-ve Amt = rotate left/right - just need to handle ISD::ROTL.
30222   // XOP implicitly uses modulo rotation amounts.
30223   if (Subtarget.hasXOP()) {
30224     assert(IsROTL && "Only ROTL expected");
30225     assert(VT.is128BitVector() && "Only rotate 128-bit vectors!");
30226 
30227     // Attempt to rotate by immediate.
30228     if (IsCstSplat) {
30229       uint64_t RotAmt = CstSplatValue.urem(EltSizeInBits);
30230       return DAG.getNode(X86ISD::VROTLI, DL, VT, R,
30231                          DAG.getTargetConstant(RotAmt, DL, MVT::i8));
30232     }
30233 
30234     // Use general rotate by variable (per-element).
30235     return Op;
30236   }
30237 
30238   // Rotate by a uniform constant - expand back to shifts.
30239   // TODO: Can't use generic expansion as UNDEF amt elements can be converted
30240   // to other values when folded to shift amounts, losing the splat.
30241   if (IsCstSplat) {
30242     uint64_t RotAmt = CstSplatValue.urem(EltSizeInBits);
30243     uint64_t ShlAmt = IsROTL ? RotAmt : (EltSizeInBits - RotAmt);
30244     uint64_t SrlAmt = IsROTL ? (EltSizeInBits - RotAmt) : RotAmt;
30245     SDValue Shl = DAG.getNode(ISD::SHL, DL, VT, R,
30246                               DAG.getShiftAmountConstant(ShlAmt, VT, DL));
30247     SDValue Srl = DAG.getNode(ISD::SRL, DL, VT, R,
30248                               DAG.getShiftAmountConstant(SrlAmt, VT, DL));
30249     return DAG.getNode(ISD::OR, DL, VT, Shl, Srl);
30250   }
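
  // Illustration (comment only): e.g. a uniform i8 rotl by 3 expands to
  // (x << 3) | (x >> 5), and rotr by 3 simply swaps the two shift amounts.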
30251 
30252   // Split 512-bit integers on non 512-bit BWI targets.
30253   if (VT.is512BitVector() && !Subtarget.useBWIRegs())
30254     return splitVectorIntBinary(Op, DAG, DL);
30255 
30256   assert(
30257       (VT == MVT::v4i32 || VT == MVT::v8i16 || VT == MVT::v16i8 ||
30258        ((VT == MVT::v8i32 || VT == MVT::v16i16 || VT == MVT::v32i8) &&
30259         Subtarget.hasAVX2()) ||
30260        ((VT == MVT::v32i16 || VT == MVT::v64i8) && Subtarget.useBWIRegs())) &&
30261       "Only vXi32/vXi16/vXi8 vector rotates supported");
30262 
30263   MVT ExtSVT = MVT::getIntegerVT(2 * EltSizeInBits);
30264   MVT ExtVT = MVT::getVectorVT(ExtSVT, NumElts / 2);
30265 
30266   SDValue AmtMask = DAG.getConstant(EltSizeInBits - 1, DL, VT);
30267   SDValue AmtMod = DAG.getNode(ISD::AND, DL, VT, Amt, AmtMask);
30268 
30269   // Attempt to fold as unpack(x,x) << zext(splat(y)):
30270   // rotl(x,y) -> (unpack(x,x) << (y & (bw-1))) >> bw.
30271   // rotr(x,y) -> (unpack(x,x) >> (y & (bw-1))).
30272   if (EltSizeInBits == 8 || EltSizeInBits == 16 || EltSizeInBits == 32) {
30273     int BaseRotAmtIdx = -1;
30274     if (SDValue BaseRotAmt = DAG.getSplatSourceVector(AmtMod, BaseRotAmtIdx)) {
30275       if (EltSizeInBits == 16 && Subtarget.hasSSE41()) {
30276         unsigned FunnelOpc = IsROTL ? ISD::FSHL : ISD::FSHR;
30277         return DAG.getNode(FunnelOpc, DL, VT, R, R, Amt);
30278       }
30279       unsigned ShiftX86Opc = IsROTL ? X86ISD::VSHLI : X86ISD::VSRLI;
30280       SDValue Lo = DAG.getBitcast(ExtVT, getUnpackl(DAG, DL, VT, R, R));
30281       SDValue Hi = DAG.getBitcast(ExtVT, getUnpackh(DAG, DL, VT, R, R));
30282       Lo = getTargetVShiftNode(ShiftX86Opc, DL, ExtVT, Lo, BaseRotAmt,
30283                                BaseRotAmtIdx, Subtarget, DAG);
30284       Hi = getTargetVShiftNode(ShiftX86Opc, DL, ExtVT, Hi, BaseRotAmt,
30285                                BaseRotAmtIdx, Subtarget, DAG);
30286       return getPack(DAG, Subtarget, DL, VT, Lo, Hi, IsROTL);
30287     }
30288   }
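
  // Illustration (comment only): with x duplicated into both halves of a lane
  // twice as wide, rotl(x, y) is ((x:x) << (y & (bw - 1))) >> bw, i.e. one wide
  // shift followed by taking the upper half; rotr instead takes the lower half
  // of a wide right shift.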
30289 
30290   bool ConstantAmt = ISD::isBuildVectorOfConstantSDNodes(Amt.getNode());
30291   unsigned ShiftOpc = IsROTL ? ISD::SHL : ISD::SRL;
30292 
30293   // Attempt to fold as unpack(x,x) << zext(y):
30294   // rotl(x,y) -> (unpack(x,x) << (y & (bw-1))) >> bw.
30295   // rotr(x,y) -> (unpack(x,x) >> (y & (bw-1))).
30296   // Const vXi16/vXi32 are excluded in favor of MUL-based lowering.
30297   if (!(ConstantAmt && EltSizeInBits != 8) &&
30298       !supportedVectorVarShift(VT, Subtarget, ShiftOpc) &&
30299       (ConstantAmt || supportedVectorVarShift(ExtVT, Subtarget, ShiftOpc))) {
30300     SDValue RLo = DAG.getBitcast(ExtVT, getUnpackl(DAG, DL, VT, R, R));
30301     SDValue RHi = DAG.getBitcast(ExtVT, getUnpackh(DAG, DL, VT, R, R));
30302     SDValue ALo = DAG.getBitcast(ExtVT, getUnpackl(DAG, DL, VT, AmtMod, Z));
30303     SDValue AHi = DAG.getBitcast(ExtVT, getUnpackh(DAG, DL, VT, AmtMod, Z));
30304     SDValue Lo = DAG.getNode(ShiftOpc, DL, ExtVT, RLo, ALo);
30305     SDValue Hi = DAG.getNode(ShiftOpc, DL, ExtVT, RHi, AHi);
30306     return getPack(DAG, Subtarget, DL, VT, Lo, Hi, IsROTL);
30307   }
30308 
30309   // v16i8/v32i8/v64i8: Split rotation into rot4/rot2/rot1 stages and select by
30310   // the amount bit.
30311   // TODO: We're doing nothing here that we couldn't do for funnel shifts.
30312   if (EltSizeInBits == 8) {
30313     MVT WideVT =
30314         MVT::getVectorVT(Subtarget.hasBWI() ? MVT::i16 : MVT::i32, NumElts);
30315 
30316     // Attempt to fold as:
30317     // rotl(x,y) -> (((aext(x) << bw) | zext(x)) << (y & (bw-1))) >> bw.
30318     // rotr(x,y) -> (((aext(x) << bw) | zext(x)) >> (y & (bw-1))).
30319     if (supportedVectorVarShift(WideVT, Subtarget, ShiftOpc) &&
30320         supportedVectorShiftWithImm(WideVT, Subtarget, ShiftOpc)) {
30321       // If we're rotating by constant, just use default promotion.
30322       if (ConstantAmt)
30323         return SDValue();
30324       // See if we can perform this by widening to vXi16 or vXi32.
30325       R = DAG.getNode(ISD::ZERO_EXTEND, DL, WideVT, R);
30326       R = DAG.getNode(
30327           ISD::OR, DL, WideVT, R,
30328           getTargetVShiftByConstNode(X86ISD::VSHLI, DL, WideVT, R, 8, DAG));
30329       Amt = DAG.getNode(ISD::ZERO_EXTEND, DL, WideVT, AmtMod);
30330       R = DAG.getNode(ShiftOpc, DL, WideVT, R, Amt);
30331       if (IsROTL)
30332         R = getTargetVShiftByConstNode(X86ISD::VSRLI, DL, WideVT, R, 8, DAG);
30333       return DAG.getNode(ISD::TRUNCATE, DL, VT, R);
30334     }
30335 
30336     // We don't need ModuloAmt here as we just peek at individual bits.
30337     auto SignBitSelect = [&](MVT SelVT, SDValue Sel, SDValue V0, SDValue V1) {
30338       if (Subtarget.hasSSE41()) {
30339         // On SSE41 targets we can use PBLENDVB which selects bytes based just
30340         // on the sign bit.
30341         V0 = DAG.getBitcast(VT, V0);
30342         V1 = DAG.getBitcast(VT, V1);
30343         Sel = DAG.getBitcast(VT, Sel);
30344         return DAG.getBitcast(SelVT,
30345                               DAG.getNode(X86ISD::BLENDV, DL, VT, Sel, V0, V1));
30346       }
30347       // On pre-SSE41 targets we test for the sign bit by comparing to
30348       // zero - a negative value will set all bits of the lanes to true
30349       // and VSELECT uses that in its OR(AND(V0,C),AND(V1,~C)) lowering.
30350       SDValue Z = DAG.getConstant(0, DL, SelVT);
30351       SDValue C = DAG.getNode(X86ISD::PCMPGT, DL, SelVT, Z, Sel);
30352       return DAG.getSelect(DL, SelVT, C, V0, V1);
30353     };
30354 
30355     // ISD::ROTR is currently only profitable on AVX512 targets with VPTERNLOG.
30356     if (!IsROTL && !useVPTERNLOG(Subtarget, VT)) {
30357       Amt = DAG.getNode(ISD::SUB, DL, VT, Z, Amt);
30358       IsROTL = true;
30359     }
30360 
30361     unsigned ShiftLHS = IsROTL ? ISD::SHL : ISD::SRL;
30362     unsigned ShiftRHS = IsROTL ? ISD::SRL : ISD::SHL;
30363 
30364     // Turn 'a' into a mask suitable for VSELECT: a = a << 5;
30365     // We can safely do this using i16 shifts as we're only interested in
30366     // the 3 lower bits of each byte.
30367     Amt = DAG.getBitcast(ExtVT, Amt);
30368     Amt = DAG.getNode(ISD::SHL, DL, ExtVT, Amt, DAG.getConstant(5, DL, ExtVT));
30369     Amt = DAG.getBitcast(VT, Amt);
30370 
30371     // r = VSELECT(r, rot(r, 4), a);
30372     SDValue M;
30373     M = DAG.getNode(
30374         ISD::OR, DL, VT,
30375         DAG.getNode(ShiftLHS, DL, VT, R, DAG.getConstant(4, DL, VT)),
30376         DAG.getNode(ShiftRHS, DL, VT, R, DAG.getConstant(4, DL, VT)));
30377     R = SignBitSelect(VT, Amt, M, R);
30378 
30379     // a += a
30380     Amt = DAG.getNode(ISD::ADD, DL, VT, Amt, Amt);
30381 
30382     // r = VSELECT(r, rot(r, 2), a);
30383     M = DAG.getNode(
30384         ISD::OR, DL, VT,
30385         DAG.getNode(ShiftLHS, DL, VT, R, DAG.getConstant(2, DL, VT)),
30386         DAG.getNode(ShiftRHS, DL, VT, R, DAG.getConstant(6, DL, VT)));
30387     R = SignBitSelect(VT, Amt, M, R);
30388 
30389     // a += a
30390     Amt = DAG.getNode(ISD::ADD, DL, VT, Amt, Amt);
30391 
30392     // return VSELECT(r, rot(r, 1), a);
30393     M = DAG.getNode(
30394         ISD::OR, DL, VT,
30395         DAG.getNode(ShiftLHS, DL, VT, R, DAG.getConstant(1, DL, VT)),
30396         DAG.getNode(ShiftRHS, DL, VT, R, DAG.getConstant(7, DL, VT)));
30397     return SignBitSelect(VT, Amt, M, R);
30398   }
30399 
30400   bool IsSplatAmt = DAG.isSplatValue(Amt);
30401   bool LegalVarShifts = supportedVectorVarShift(VT, Subtarget, ISD::SHL) &&
30402                         supportedVectorVarShift(VT, Subtarget, ISD::SRL);
30403 
30404   // Fallback for splats + all supported variable shifts.
30405   // Fallback for non-constant amounts on AVX2 vXi16 as well.
30406   if (IsSplatAmt || LegalVarShifts || (Subtarget.hasAVX2() && !ConstantAmt)) {
30407     Amt = DAG.getNode(ISD::AND, DL, VT, Amt, AmtMask);
30408     SDValue AmtR = DAG.getConstant(EltSizeInBits, DL, VT);
30409     AmtR = DAG.getNode(ISD::SUB, DL, VT, AmtR, Amt);
30410     SDValue SHL = DAG.getNode(IsROTL ? ISD::SHL : ISD::SRL, DL, VT, R, Amt);
30411     SDValue SRL = DAG.getNode(IsROTL ? ISD::SRL : ISD::SHL, DL, VT, R, AmtR);
30412     return DAG.getNode(ISD::OR, DL, VT, SHL, SRL);
30413   }
30414 
30415   // Everything below assumes ISD::ROTL.
30416   if (!IsROTL) {
30417     Amt = DAG.getNode(ISD::SUB, DL, VT, Z, Amt);
30418     IsROTL = true;
30419   }
30420 
30421   // ISD::ROT* uses modulo rotate amounts.
30422   Amt = DAG.getNode(ISD::AND, DL, VT, Amt, AmtMask);
30423 
30424   assert(IsROTL && "Only ROTL supported");
30425 
30426   // As with shifts, attempt to convert the rotation amount to a multiplication
30427   // factor, fallback to general expansion.
30428   SDValue Scale = convertShiftLeftToScale(Amt, DL, Subtarget, DAG);
30429   if (!Scale)
30430     return SDValue();
30431 
30432   // v8i16/v16i16: perform unsigned multiply hi/lo and OR the results.
30433   if (EltSizeInBits == 16) {
30434     SDValue Lo = DAG.getNode(ISD::MUL, DL, VT, R, Scale);
30435     SDValue Hi = DAG.getNode(ISD::MULHU, DL, VT, R, Scale);
30436     return DAG.getNode(ISD::OR, DL, VT, Lo, Hi);
30437   }
30438 
30439   // v4i32: make use of the PMULUDQ instruction to multiply 2 lanes of v4i32
30440   // to v2i64 results at a time. The upper 32-bits contain the wrapped bits
30441   // that can then be OR'd with the lower 32-bits.
30442   assert(VT == MVT::v4i32 && "Only v4i32 vector rotate expected");
30443   static const int OddMask[] = {1, -1, 3, -1};
30444   SDValue R13 = DAG.getVectorShuffle(VT, DL, R, R, OddMask);
30445   SDValue Scale13 = DAG.getVectorShuffle(VT, DL, Scale, Scale, OddMask);
30446 
30447   SDValue Res02 = DAG.getNode(X86ISD::PMULUDQ, DL, MVT::v2i64,
30448                               DAG.getBitcast(MVT::v2i64, R),
30449                               DAG.getBitcast(MVT::v2i64, Scale));
30450   SDValue Res13 = DAG.getNode(X86ISD::PMULUDQ, DL, MVT::v2i64,
30451                               DAG.getBitcast(MVT::v2i64, R13),
30452                               DAG.getBitcast(MVT::v2i64, Scale13));
30453   Res02 = DAG.getBitcast(VT, Res02);
30454   Res13 = DAG.getBitcast(VT, Res13);
30455 
30456   return DAG.getNode(ISD::OR, DL, VT,
30457                      DAG.getVectorShuffle(VT, DL, Res02, Res13, {0, 4, 2, 6}),
30458                      DAG.getVectorShuffle(VT, DL, Res02, Res13, {1, 5, 3, 7}));
30459 }
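
// Illustration only: a scalar model of the PMULUDQ-based v4i32 rotate above.
// Multiplying a 32-bit lane by 2^Amt as a full 64-bit product leaves the
// rotated-out bits in the upper half, so OR'ing the two halves back together
// yields the rotate. The helper name is hypothetical and the function is
// unused.
[[maybe_unused]] static uint32_t scalarRotlViaWideMul(uint32_t X, unsigned Amt) {
  uint64_t Product = uint64_t(X) * (uint64_t(1) << (Amt & 31)); // PMULUDQ-style
  return uint32_t(Product) | uint32_t(Product >> 32);           // lo | hi
}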
30460 
30461 /// Returns true if the operand type is exactly twice the native width, and
30462 /// the corresponding cmpxchg8b or cmpxchg16b instruction is available.
30463 /// Used to know whether to use cmpxchg8/16b when expanding atomic operations
30464 /// (otherwise we leave them alone to become __sync_fetch_and_... calls).
30465 bool X86TargetLowering::needsCmpXchgNb(Type *MemType) const {
30466   unsigned OpWidth = MemType->getPrimitiveSizeInBits();
30467 
30468   if (OpWidth == 64)
30469     return Subtarget.canUseCMPXCHG8B() && !Subtarget.is64Bit();
30470   if (OpWidth == 128)
30471     return Subtarget.canUseCMPXCHG16B();
30472 
30473   return false;
30474 }
30475 
30476 TargetLoweringBase::AtomicExpansionKind
30477 X86TargetLowering::shouldExpandAtomicStoreInIR(StoreInst *SI) const {
30478   Type *MemType = SI->getValueOperand()->getType();
30479 
30480   if (!SI->getFunction()->hasFnAttribute(Attribute::NoImplicitFloat) &&
30481       !Subtarget.useSoftFloat()) {
30482     if (MemType->getPrimitiveSizeInBits() == 64 && !Subtarget.is64Bit() &&
30483         (Subtarget.hasSSE1() || Subtarget.hasX87()))
30484       return AtomicExpansionKind::None;
30485 
30486     if (MemType->getPrimitiveSizeInBits() == 128 && Subtarget.is64Bit() &&
30487         Subtarget.hasAVX())
30488       return AtomicExpansionKind::None;
30489   }
30490 
30491   return needsCmpXchgNb(MemType) ? AtomicExpansionKind::Expand
30492                                  : AtomicExpansionKind::None;
30493 }
30494 
30495 // Note: this turns large loads into lock cmpxchg8b/16b.
30496 TargetLowering::AtomicExpansionKind
30497 X86TargetLowering::shouldExpandAtomicLoadInIR(LoadInst *LI) const {
30498   Type *MemType = LI->getType();
30499 
30500   if (!LI->getFunction()->hasFnAttribute(Attribute::NoImplicitFloat) &&
30501       !Subtarget.useSoftFloat()) {
30502     // If this is a 64-bit atomic load on a 32-bit target and SSE2 is enabled, we
30503     // can use movq to do the load. If we have X87 we can load into an 80-bit
30504     // X87 register and store it to a stack temporary.
30505     if (MemType->getPrimitiveSizeInBits() == 64 && !Subtarget.is64Bit() &&
30506         (Subtarget.hasSSE1() || Subtarget.hasX87()))
30507       return AtomicExpansionKind::None;
30508 
30509     // If this is a 128-bit load with AVX, 128-bit SSE loads/stores are atomic.
30510     if (MemType->getPrimitiveSizeInBits() == 128 && Subtarget.is64Bit() &&
30511         Subtarget.hasAVX())
30512       return AtomicExpansionKind::None;
30513   }
30514 
30515   return needsCmpXchgNb(MemType) ? AtomicExpansionKind::CmpXChg
30516                                  : AtomicExpansionKind::None;
30517 }
30518 
30519 enum BitTestKind : unsigned {
30520   UndefBit,
30521   ConstantBit,
30522   NotConstantBit,
30523   ShiftBit,
30524   NotShiftBit
30525 };
30526 
30527 static std::pair<Value *, BitTestKind> FindSingleBitChange(Value *V) {
30528   using namespace llvm::PatternMatch;
30529   BitTestKind BTK = UndefBit;
30530   auto *C = dyn_cast<ConstantInt>(V);
30531   if (C) {
30532     // Check if V is a power of 2 or NOT power of 2.
30533     if (isPowerOf2_64(C->getZExtValue()))
30534       BTK = ConstantBit;
30535     else if (isPowerOf2_64((~C->getValue()).getZExtValue()))
30536       BTK = NotConstantBit;
30537     return {V, BTK};
30538   }
30539 
30540   // Check if V is some power of 2 pattern known to be non-zero
30541   auto *I = dyn_cast<Instruction>(V);
30542   if (I) {
30543     bool Not = false;
30544     // Check if we have a NOT
30545     Value *PeekI;
30546     if (match(I, m_Not(m_Value(PeekI))) ||
30547         match(I, m_Sub(m_AllOnes(), m_Value(PeekI)))) {
30548       Not = true;
30549       I = dyn_cast<Instruction>(PeekI);
30550 
30551       // If I is constant, it will fold and we can evaluate later. If it's an
30552       // argument or something of that nature, we can't analyze.
30553       if (I == nullptr)
30554         return {nullptr, UndefBit};
30555     }
30556     // We can only use 1 << X without more sophisticated analysis. C << X where
30557     // C is a power of 2 but not 1 can result in zero which cannot be translated
30558     // to bittest. Likewise any C >> X (either arith or logical) can be zero.
30559     if (I->getOpcode() == Instruction::Shl) {
30560       // Todo(1): The cmpxchg case is pretty costly so matching `BLSI(X)`, `X &
30561       // -X` and some other provable power of 2 patterns that we can use CTZ on
30562       // may be profitable.
30563       // Todo(2): It may be possible in some cases to prove that Shl(C, X) is
30564       // non-zero even where C != 1. Likewise LShr(C, X) and AShr(C, X) may also
30565       // be provably a non-zero power of 2.
30566       // Todo(3): ROTL and ROTR patterns on a power of 2 C should also be
30567       // transformable to bittest.
30568       auto *ShiftVal = dyn_cast<ConstantInt>(I->getOperand(0));
30569       if (!ShiftVal)
30570         return {nullptr, UndefBit};
30571       if (ShiftVal->equalsInt(1))
30572         BTK = Not ? NotShiftBit : ShiftBit;
30573 
30574       if (BTK == UndefBit)
30575         return {nullptr, UndefBit};
30576 
30577       Value *BitV = I->getOperand(1);
30578 
30579       Value *AndOp;
30580       const APInt *AndC;
30581       if (match(BitV, m_c_And(m_Value(AndOp), m_APInt(AndC)))) {
30582         // Read past a shiftmask instruction to find count
30583         if (*AndC == (I->getType()->getPrimitiveSizeInBits() - 1))
30584           BitV = AndOp;
30585       }
30586       return {BitV, BTK};
30587     }
30588   }
30589   return {nullptr, UndefBit};
30590 }
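// Examples of values FindSingleBitChange classifies (illustrative):
//   8                      -> {8, ConstantBit}
//   ~8                     -> {~8, NotConstantBit}
//   1 << %n                -> {%n, ShiftBit}
//   ~(1 << %n)             -> {%n, NotShiftBit}
//   1 << (%n & 31) (i32)   -> {%n, ShiftBit}      (the shift mask is looked through)
//   3 << %n                -> {nullptr, UndefBit} (may be zero, not a single bit)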
30591 
30592 TargetLowering::AtomicExpansionKind
30593 X86TargetLowering::shouldExpandLogicAtomicRMWInIR(AtomicRMWInst *AI) const {
30594   using namespace llvm::PatternMatch;
30595   // If the atomicrmw's result isn't actually used, we can just add a "lock"
30596   // prefix to a normal instruction for these operations.
30597   if (AI->use_empty())
30598     return AtomicExpansionKind::None;
30599 
30600   if (AI->getOperation() == AtomicRMWInst::Xor) {
30601     // A ^ SignBit -> A + SignBit. This allows us to use `xadd` which is
30602     // preferable to both `cmpxchg` and `btc`.
30603     if (match(AI->getOperand(1), m_SignMask()))
30604       return AtomicExpansionKind::None;
30605   }
30606 
30607   // If the atomicrmw's result is used by a single bit AND, we may use
30608   // bts/btr/btc instructions for these operations.
30609   // Note: InstCombinePass can cause a de-optimization here. It replaces the
30610   // SETCC(And(AtomicRMW(P, power_of_2), power_of_2)) with LShr and Xor
30611   // (depending on CC). This pattern can only use bts/btr/btc but we don't
30612   // detect it.
30613   Instruction *I = AI->user_back();
30614   auto BitChange = FindSingleBitChange(AI->getValOperand());
30615   if (BitChange.second == UndefBit || !AI->hasOneUse() ||
30616       I->getOpcode() != Instruction::And ||
30617       AI->getType()->getPrimitiveSizeInBits() == 8 ||
30618       AI->getParent() != I->getParent())
30619     return AtomicExpansionKind::CmpXChg;
30620 
30621   unsigned OtherIdx = I->getOperand(0) == AI ? 1 : 0;
30622 
30623   // This is a redundant AND, it should get cleaned up elsewhere.
30624   if (AI == I->getOperand(OtherIdx))
30625     return AtomicExpansionKind::CmpXChg;
30626 
30627   // The following instruction must be an AND with a single bit.
30628   if (BitChange.second == ConstantBit || BitChange.second == NotConstantBit) {
30629     auto *C1 = cast<ConstantInt>(AI->getValOperand());
30630     auto *C2 = dyn_cast<ConstantInt>(I->getOperand(OtherIdx));
30631     if (!C2 || !isPowerOf2_64(C2->getZExtValue())) {
30632       return AtomicExpansionKind::CmpXChg;
30633     }
30634     if (AI->getOperation() == AtomicRMWInst::And) {
30635       return ~C1->getValue() == C2->getValue()
30636                  ? AtomicExpansionKind::BitTestIntrinsic
30637                  : AtomicExpansionKind::CmpXChg;
30638     }
30639     return C1 == C2 ? AtomicExpansionKind::BitTestIntrinsic
30640                     : AtomicExpansionKind::CmpXChg;
30641   }
30642 
30643   assert(BitChange.second == ShiftBit || BitChange.second == NotShiftBit);
30644 
30645   auto BitTested = FindSingleBitChange(I->getOperand(OtherIdx));
30646   if (BitTested.second != ShiftBit && BitTested.second != NotShiftBit)
30647     return AtomicExpansionKind::CmpXChg;
30648 
30649   assert(BitChange.first != nullptr && BitTested.first != nullptr);
30650 
30651   // If shift amounts are not the same we can't use BitTestIntrinsic.
30652   if (BitChange.first != BitTested.first)
30653     return AtomicExpansionKind::CmpXChg;
30654 
30655   // For atomic AND, the mask must cover all but one bit and the AND must test
30656   // the one bit that is unset in the mask.
30657   if (AI->getOperation() == AtomicRMWInst::And)
30658     return (BitChange.second == NotShiftBit && BitTested.second == ShiftBit)
30659                ? AtomicExpansionKind::BitTestIntrinsic
30660                : AtomicExpansionKind::CmpXChg;
30661 
30662   // For atomic XOR/OR, we need to be setting and testing the same bit.
30663   return (BitChange.second == ShiftBit && BitTested.second == ShiftBit)
30664              ? AtomicExpansionKind::BitTestIntrinsic
30665              : AtomicExpansionKind::CmpXChg;
30666 }
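// Example of a pattern that takes the BitTestIntrinsic path above (illustrative):
//   %mask = shl i32 1, %n
//   %old  = atomicrmw or ptr %p, i32 %mask seq_cst
//   %bit  = and i32 %old, %mask
// Setting and testing the same single bit can be selected as "lock bts" plus a
// flag read; if the tested bit differs, the result has other uses, or the type
// is i8, we fall back to the cmpxchg expansion.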
30667 
30668 void X86TargetLowering::emitBitTestAtomicRMWIntrinsic(AtomicRMWInst *AI) const {
30669   IRBuilder<> Builder(AI);
30670   Builder.CollectMetadataToCopy(AI, {LLVMContext::MD_pcsections});
30671   Intrinsic::ID IID_C = Intrinsic::not_intrinsic;
30672   Intrinsic::ID IID_I = Intrinsic::not_intrinsic;
30673   switch (AI->getOperation()) {
30674   default:
30675     llvm_unreachable("Unknown atomic operation");
30676   case AtomicRMWInst::Or:
30677     IID_C = Intrinsic::x86_atomic_bts;
30678     IID_I = Intrinsic::x86_atomic_bts_rm;
30679     break;
30680   case AtomicRMWInst::Xor:
30681     IID_C = Intrinsic::x86_atomic_btc;
30682     IID_I = Intrinsic::x86_atomic_btc_rm;
30683     break;
30684   case AtomicRMWInst::And:
30685     IID_C = Intrinsic::x86_atomic_btr;
30686     IID_I = Intrinsic::x86_atomic_btr_rm;
30687     break;
30688   }
30689   Instruction *I = AI->user_back();
30690   LLVMContext &Ctx = AI->getContext();
30691   Value *Addr = Builder.CreatePointerCast(AI->getPointerOperand(),
30692                                           PointerType::getUnqual(Ctx));
30693   Function *BitTest = nullptr;
30694   Value *Result = nullptr;
30695   auto BitTested = FindSingleBitChange(AI->getValOperand());
30696   assert(BitTested.first != nullptr);
30697 
30698   if (BitTested.second == ConstantBit || BitTested.second == NotConstantBit) {
30699     auto *C = cast<ConstantInt>(I->getOperand(I->getOperand(0) == AI ? 1 : 0));
30700 
30701     BitTest = Intrinsic::getDeclaration(AI->getModule(), IID_C, AI->getType());
30702 
30703     unsigned Imm = llvm::countr_zero(C->getZExtValue());
30704     Result = Builder.CreateCall(BitTest, {Addr, Builder.getInt8(Imm)});
30705   } else {
30706     BitTest = Intrinsic::getDeclaration(AI->getModule(), IID_I, AI->getType());
30707 
30708     assert(BitTested.second == ShiftBit || BitTested.second == NotShiftBit);
30709 
30710     Value *SI = BitTested.first;
30711     assert(SI != nullptr);
30712 
30713     // BT{S|R|C} on a memory operand doesn't modulo the bit position, so we
30714     // need to mask it.
30715     unsigned ShiftBits = SI->getType()->getPrimitiveSizeInBits();
30716     Value *BitPos =
30717         Builder.CreateAnd(SI, Builder.getIntN(ShiftBits, ShiftBits - 1));
30718     // Todo(1): In many cases it may be provable that SI is less than
30719     // ShiftBits in which case this mask is unnecessary
30720     // Todo(2): In the fairly idiomatic case of P[X / sizeof_bits(X)] OP 1
30721     // << (X % sizeof_bits(X)) we can drop the shift mask and AGEN in
30722     // favor of just a raw BT{S|R|C}.
30723 
30724     Result = Builder.CreateCall(BitTest, {Addr, BitPos});
30725     Result = Builder.CreateZExtOrTrunc(Result, AI->getType());
30726 
30727     // If the result is only used for zero/non-zero status then we don't need
30728     // to shift the value back. Otherwise do so.
30729     for (auto It = I->user_begin(); It != I->user_end(); ++It) {
30730       if (auto *ICmp = dyn_cast<ICmpInst>(*It)) {
30731         if (ICmp->isEquality()) {
30732           auto *C0 = dyn_cast<ConstantInt>(ICmp->getOperand(0));
30733           auto *C1 = dyn_cast<ConstantInt>(ICmp->getOperand(1));
30734           if (C0 || C1) {
30735             assert(C0 == nullptr || C1 == nullptr);
30736             if ((C0 ? C0 : C1)->isZero())
30737               continue;
30738           }
30739         }
30740       }
30741       Result = Builder.CreateShl(Result, BitPos);
30742       break;
30743     }
30744   }
30745 
30746   I->replaceAllUsesWith(Result);
30747   I->eraseFromParent();
30748   AI->eraseFromParent();
30749 }
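// Rough sketch of the rewrite performed above for the constant case
// (illustrative; the intrinsic spelling follows the Intrinsic:: IDs used above):
//   %old = atomicrmw or ptr %p, i32 8 seq_cst
//   %res = and i32 %old, 8
// becomes approximately:
//   %res = call i32 @llvm.x86.atomic.bts.i32(ptr %p, i8 3)   ; 3 == countr_zero(8)
// For the "1 << %n" case, %n is masked to the operand width and passed to the
// *_rm intrinsic variant instead of an immediate bit index.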
30750 
30751 static bool shouldExpandCmpArithRMWInIR(AtomicRMWInst *AI) {
30752   using namespace llvm::PatternMatch;
30753   if (!AI->hasOneUse())
30754     return false;
30755 
30756   Value *Op = AI->getOperand(1);
30757   ICmpInst::Predicate Pred;
30758   Instruction *I = AI->user_back();
30759   AtomicRMWInst::BinOp Opc = AI->getOperation();
30760   if (Opc == AtomicRMWInst::Add) {
30761     if (match(I, m_c_ICmp(Pred, m_Sub(m_ZeroInt(), m_Specific(Op)), m_Value())))
30762       return Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_NE;
30763     if (match(I, m_OneUse(m_c_Add(m_Specific(Op), m_Value())))) {
30764       if (match(I->user_back(), m_ICmp(Pred, m_Value(), m_ZeroInt())))
30765         return Pred == CmpInst::ICMP_SLT;
30766       if (match(I->user_back(), m_ICmp(Pred, m_Value(), m_AllOnes())))
30767         return Pred == CmpInst::ICMP_SGT;
30768     }
30769     return false;
30770   }
30771   if (Opc == AtomicRMWInst::Sub) {
30772     if (match(I, m_c_ICmp(Pred, m_Specific(Op), m_Value())))
30773       return Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_NE;
30774     if (match(I, m_OneUse(m_Sub(m_Value(), m_Specific(Op))))) {
30775       if (match(I->user_back(), m_ICmp(Pred, m_Value(), m_ZeroInt())))
30776         return Pred == CmpInst::ICMP_SLT;
30777       if (match(I->user_back(), m_ICmp(Pred, m_Value(), m_AllOnes())))
30778         return Pred == CmpInst::ICMP_SGT;
30779     }
30780     return false;
30781   }
30782   if ((Opc == AtomicRMWInst::Or &&
30783        match(I, m_OneUse(m_c_Or(m_Specific(Op), m_Value())))) ||
30784       (Opc == AtomicRMWInst::And &&
30785        match(I, m_OneUse(m_c_And(m_Specific(Op), m_Value()))))) {
30786     if (match(I->user_back(), m_ICmp(Pred, m_Value(), m_ZeroInt())))
30787       return Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_NE ||
30788              Pred == CmpInst::ICMP_SLT;
30789     if (match(I->user_back(), m_ICmp(Pred, m_Value(), m_AllOnes())))
30790       return Pred == CmpInst::ICMP_SGT;
30791     return false;
30792   }
30793   if (Opc == AtomicRMWInst::Xor) {
30794     if (match(I, m_c_ICmp(Pred, m_Specific(Op), m_Value())))
30795       return Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_NE;
30796     if (match(I, m_OneUse(m_c_Xor(m_Specific(Op), m_Value())))) {
30797       if (match(I->user_back(), m_ICmp(Pred, m_Value(), m_ZeroInt())))
30798         return Pred == CmpInst::ICMP_SLT;
30799       if (match(I->user_back(), m_ICmp(Pred, m_Value(), m_AllOnes())))
30800         return Pred == CmpInst::ICMP_SGT;
30801     }
30802     return false;
30803   }
30804 
30805   return false;
30806 }
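// Example of a pattern this predicate accepts (illustrative):
//   %old = atomicrmw sub ptr %p, i32 %v seq_cst
//   %cmp = icmp eq i32 %old, %v
// The comparison only needs the flags produced by a "lock sub", so the RMW can
// become a CmpArithIntrinsic instead of a cmpxchg loop.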
30807 
30808 void X86TargetLowering::emitCmpArithAtomicRMWIntrinsic(
30809     AtomicRMWInst *AI) const {
30810   IRBuilder<> Builder(AI);
30811   Builder.CollectMetadataToCopy(AI, {LLVMContext::MD_pcsections});
30812   Instruction *TempI = nullptr;
30813   LLVMContext &Ctx = AI->getContext();
30814   ICmpInst *ICI = dyn_cast<ICmpInst>(AI->user_back());
30815   if (!ICI) {
30816     TempI = AI->user_back();
30817     assert(TempI->hasOneUse() && "Must have one use");
30818     ICI = cast<ICmpInst>(TempI->user_back());
30819   }
30820   X86::CondCode CC = X86::COND_INVALID;
30821   ICmpInst::Predicate Pred = ICI->getPredicate();
30822   switch (Pred) {
30823   default:
30824     llvm_unreachable("Not supported Pred");
30825   case CmpInst::ICMP_EQ:
30826     CC = X86::COND_E;
30827     break;
30828   case CmpInst::ICMP_NE:
30829     CC = X86::COND_NE;
30830     break;
30831   case CmpInst::ICMP_SLT:
30832     CC = X86::COND_S;
30833     break;
30834   case CmpInst::ICMP_SGT:
30835     CC = X86::COND_NS;
30836     break;
30837   }
30838   Intrinsic::ID IID = Intrinsic::not_intrinsic;
30839   switch (AI->getOperation()) {
30840   default:
30841     llvm_unreachable("Unknown atomic operation");
30842   case AtomicRMWInst::Add:
30843     IID = Intrinsic::x86_atomic_add_cc;
30844     break;
30845   case AtomicRMWInst::Sub:
30846     IID = Intrinsic::x86_atomic_sub_cc;
30847     break;
30848   case AtomicRMWInst::Or:
30849     IID = Intrinsic::x86_atomic_or_cc;
30850     break;
30851   case AtomicRMWInst::And:
30852     IID = Intrinsic::x86_atomic_and_cc;
30853     break;
30854   case AtomicRMWInst::Xor:
30855     IID = Intrinsic::x86_atomic_xor_cc;
30856     break;
30857   }
30858   Function *CmpArith =
30859       Intrinsic::getDeclaration(AI->getModule(), IID, AI->getType());
30860   Value *Addr = Builder.CreatePointerCast(AI->getPointerOperand(),
30861                                           PointerType::getUnqual(Ctx));
30862   Value *Call = Builder.CreateCall(
30863       CmpArith, {Addr, AI->getValOperand(), Builder.getInt32((unsigned)CC)});
30864   Value *Result = Builder.CreateTrunc(Call, Type::getInt1Ty(Ctx));
30865   ICI->replaceAllUsesWith(Result);
30866   ICI->eraseFromParent();
30867   if (TempI)
30868     TempI->eraseFromParent();
30869   AI->eraseFromParent();
30870 }
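// Rough sketch of the rewrite performed above (illustrative; the call's return
// type and the numeric condition-code operand come from the intrinsic
// definition and X86::CondCode, so treat the exact spelling as an assumption):
//   %old = atomicrmw or ptr %p, i32 %v seq_cst
//   %t   = or i32 %old, %v
//   %cmp = icmp eq i32 %t, 0
// becomes approximately:
//   %f   = call i8 @llvm.x86.atomic.or.cc.i32(ptr %p, i32 %v, i32 <X86::COND_E>)
//   %cmp = trunc i8 %f to i1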
30871 
30872 TargetLowering::AtomicExpansionKind
30873 X86TargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
30874   unsigned NativeWidth = Subtarget.is64Bit() ? 64 : 32;
30875   Type *MemType = AI->getType();
30876 
30877   // If the operand is too big, we must see if cmpxchg8/16b is available
30878   // and default to library calls otherwise.
30879   if (MemType->getPrimitiveSizeInBits() > NativeWidth) {
30880     return needsCmpXchgNb(MemType) ? AtomicExpansionKind::CmpXChg
30881                                    : AtomicExpansionKind::None;
30882   }
30883 
30884   AtomicRMWInst::BinOp Op = AI->getOperation();
30885   switch (Op) {
30886   case AtomicRMWInst::Xchg:
30887     return AtomicExpansionKind::None;
30888   case AtomicRMWInst::Add:
30889   case AtomicRMWInst::Sub:
30890     if (shouldExpandCmpArithRMWInIR(AI))
30891       return AtomicExpansionKind::CmpArithIntrinsic;
30892     // It's better to use xadd, xsub or xchg for these in other cases.
30893     return AtomicExpansionKind::None;
30894   case AtomicRMWInst::Or:
30895   case AtomicRMWInst::And:
30896   case AtomicRMWInst::Xor:
30897     if (shouldExpandCmpArithRMWInIR(AI))
30898       return AtomicExpansionKind::CmpArithIntrinsic;
30899     return shouldExpandLogicAtomicRMWInIR(AI);
30900   case AtomicRMWInst::Nand:
30901   case AtomicRMWInst::Max:
30902   case AtomicRMWInst::Min:
30903   case AtomicRMWInst::UMax:
30904   case AtomicRMWInst::UMin:
30905   case AtomicRMWInst::FAdd:
30906   case AtomicRMWInst::FSub:
30907   case AtomicRMWInst::FMax:
30908   case AtomicRMWInst::FMin:
30909   case AtomicRMWInst::UIncWrap:
30910   case AtomicRMWInst::UDecWrap:
30911   default:
30912     // These always require a non-trivial set of data operations on x86. We must
30913     // use a cmpxchg loop.
30914     return AtomicExpansionKind::CmpXChg;
30915   }
30916 }
30917 
30918 LoadInst *
30919 X86TargetLowering::lowerIdempotentRMWIntoFencedLoad(AtomicRMWInst *AI) const {
30920   unsigned NativeWidth = Subtarget.is64Bit() ? 64 : 32;
30921   Type *MemType = AI->getType();
30922   // Accesses larger than the native width are turned into cmpxchg/libcalls, so
30923   // there is no benefit in turning such RMWs into loads, and it is actually
30924   // harmful as it introduces an mfence.
30925   if (MemType->getPrimitiveSizeInBits() > NativeWidth)
30926     return nullptr;
30927 
30928   // If this is a canonical idempotent atomicrmw w/no uses, we have a better
30929   // lowering available in lowerAtomicArith.
30930   // TODO: push more cases through this path.
30931   if (auto *C = dyn_cast<ConstantInt>(AI->getValOperand()))
30932     if (AI->getOperation() == AtomicRMWInst::Or && C->isZero() &&
30933         AI->use_empty())
30934       return nullptr;
30935 
30936   IRBuilder<> Builder(AI);
30937   Builder.CollectMetadataToCopy(AI, {LLVMContext::MD_pcsections});
30938   Module *M = Builder.GetInsertBlock()->getParent()->getParent();
30939   auto SSID = AI->getSyncScopeID();
30940   // We must restrict the ordering to avoid generating loads with Release or
30941   // ReleaseAcquire orderings.
30942   auto Order = AtomicCmpXchgInst::getStrongestFailureOrdering(AI->getOrdering());
30943 
30944   // Before the load we need a fence. Here is an example lifted from
30945   // http://www.hpl.hp.com/techreports/2012/HPL-2012-68.pdf showing why a fence
30946   // is required:
30947   // Thread 0:
30948   //   x.store(1, relaxed);
30949   //   r1 = y.fetch_add(0, release);
30950   // Thread 1:
30951   //   y.fetch_add(42, acquire);
30952   //   r2 = x.load(relaxed);
30953   // r1 = r2 = 0 is impossible, but becomes possible if the idempotent rmw is
30954   // lowered to just a load without a fence. A mfence flushes the store buffer,
30955   // making the optimization clearly correct.
30956   // FIXME: it is required if isReleaseOrStronger(Order) but it is not clear
30957   // otherwise, we might be able to be more aggressive on relaxed idempotent
30958   // rmw. In practice, they do not look useful, so we don't try to be
30959   // especially clever.
30960   if (SSID == SyncScope::SingleThread)
30961     // FIXME: we could just insert an ISD::MEMBARRIER here, except we are at
30962     // the IR level, so we must wrap it in an intrinsic.
30963     return nullptr;
30964 
30965   if (!Subtarget.hasMFence())
30966     // FIXME: it might make sense to use a locked operation here but on a
30967     // different cache-line to prevent cache-line bouncing. In practice it
30968     // is probably a small win, and x86 processors without mfence are rare
30969     // enough that we do not bother.
30970     return nullptr;
30971 
30972   Function *MFence =
30973       llvm::Intrinsic::getDeclaration(M, Intrinsic::x86_sse2_mfence);
30974   Builder.CreateCall(MFence, {});
30975 
30976   // Finally we can emit the atomic load.
30977   LoadInst *Loaded = Builder.CreateAlignedLoad(
30978       AI->getType(), AI->getPointerOperand(), AI->getAlign());
30979   Loaded->setAtomic(Order, SSID);
30980   AI->replaceAllUsesWith(Loaded);
30981   AI->eraseFromParent();
30982   return Loaded;
30983 }
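// Example of the transformation above (illustrative):
//   %old = atomicrmw add ptr %p, i32 0 seq_cst    ; idempotent, result used
// becomes:
//   call void @llvm.x86.sse2.mfence()
//   %old = load atomic i32, ptr %p seq_cst, align 4
// The store half of the RMW is dropped; only the ordering effects (mfence plus
// an atomic load) are kept.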
30984 
30985 /// Emit a locked operation on a stack location which does not change any
30986 /// memory location, but does involve a lock prefix.  Location is chosen to be
30987 /// a) very likely accessed only by a single thread to minimize cache traffic,
30988 /// and b) definitely dereferenceable.  Returns the new Chain result.
30989 static SDValue emitLockedStackOp(SelectionDAG &DAG,
30990                                  const X86Subtarget &Subtarget, SDValue Chain,
30991                                  const SDLoc &DL) {
30992   // Implementation notes:
30993   // 1) LOCK prefix creates a full read/write reordering barrier for memory
30994   // operations issued by the current processor.  As such, the location
30995   // referenced is not relevant for the ordering properties of the instruction.
30996   // See: Intel® 64 and IA-32 Architectures Software Developer’s Manual,
30997   // 8.2.3.9  Loads and Stores Are Not Reordered with Locked Instructions
30998   // 2) Using an immediate operand appears to be the best encoding choice
30999   // here since it doesn't require an extra register.
31000   // 3) OR appears to be very slightly faster than ADD. (Though, the difference
31001   // is small enough it might just be measurement noise.)
31002   // 4) When choosing offsets, there are several contributing factors:
31003   //   a) If there's no redzone, we default to TOS.  (We could allocate a cache
31004   //      line aligned stack object to improve this case.)
31005   //   b) To minimize our chances of introducing a false dependence, we prefer
31006   //      to offset the stack usage from TOS slightly.
31007   //   c) To minimize concerns about cross thread stack usage - in particular,
31008   //      the idiomatic MyThreadPool.run([&StackVars]() {...}) pattern which
31009   //      captures state in the TOS frame and accesses it from many threads -
31010   //      we want to use an offset such that the offset is in a distinct cache
31011   //      line from the TOS frame.
31012   //
31013   // For a general discussion of the tradeoffs and benchmark results, see:
31014   // https://shipilev.net/blog/2014/on-the-fence-with-dependencies/
31015 
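  // For example (illustrative), with a 128-byte red zone on x86-64 this
  // selects to:
  //   lock orl $0, -64(%rsp)
  // and on 32-bit targets to:
  //   lock orl $0, (%esp)
  // acting as a full barrier without changing any memory value.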
31016   auto &MF = DAG.getMachineFunction();
31017   auto &TFL = *Subtarget.getFrameLowering();
31018   const unsigned SPOffset = TFL.has128ByteRedZone(MF) ? -64 : 0;
31019 
31020   if (Subtarget.is64Bit()) {
31021     SDValue Zero = DAG.getTargetConstant(0, DL, MVT::i32);
31022     SDValue Ops[] = {
31023       DAG.getRegister(X86::RSP, MVT::i64),                  // Base
31024       DAG.getTargetConstant(1, DL, MVT::i8),                // Scale
31025       DAG.getRegister(0, MVT::i64),                         // Index
31026       DAG.getTargetConstant(SPOffset, DL, MVT::i32),        // Disp
31027       DAG.getRegister(0, MVT::i16),                         // Segment.
31028       Zero,
31029       Chain};
31030     SDNode *Res = DAG.getMachineNode(X86::OR32mi8Locked, DL, MVT::i32,
31031                                      MVT::Other, Ops);
31032     return SDValue(Res, 1);
31033   }
31034 
31035   SDValue Zero = DAG.getTargetConstant(0, DL, MVT::i32);
31036   SDValue Ops[] = {
31037     DAG.getRegister(X86::ESP, MVT::i32),            // Base
31038     DAG.getTargetConstant(1, DL, MVT::i8),          // Scale
31039     DAG.getRegister(0, MVT::i32),                   // Index
31040     DAG.getTargetConstant(SPOffset, DL, MVT::i32),  // Disp
31041     DAG.getRegister(0, MVT::i16),                   // Segment.
31042     Zero,
31043     Chain
31044   };
31045   SDNode *Res = DAG.getMachineNode(X86::OR32mi8Locked, DL, MVT::i32,
31046                                    MVT::Other, Ops);
31047   return SDValue(Res, 1);
31048 }
31049 
31050 static SDValue LowerATOMIC_FENCE(SDValue Op, const X86Subtarget &Subtarget,
31051                                  SelectionDAG &DAG) {
31052   SDLoc dl(Op);
31053   AtomicOrdering FenceOrdering =
31054       static_cast<AtomicOrdering>(Op.getConstantOperandVal(1));
31055   SyncScope::ID FenceSSID =
31056       static_cast<SyncScope::ID>(Op.getConstantOperandVal(2));
31057 
31058   // The only fence that needs an instruction is a sequentially-consistent
31059   // cross-thread fence.
31060   if (FenceOrdering == AtomicOrdering::SequentiallyConsistent &&
31061       FenceSSID == SyncScope::System) {
31062     if (Subtarget.hasMFence())
31063       return DAG.getNode(X86ISD::MFENCE, dl, MVT::Other, Op.getOperand(0));
31064 
31065     SDValue Chain = Op.getOperand(0);
31066     return emitLockedStackOp(DAG, Subtarget, Chain, dl);
31067   }
31068 
31069   // MEMBARRIER is a compiler barrier; it codegens to a no-op.
31070   return DAG.getNode(ISD::MEMBARRIER, dl, MVT::Other, Op.getOperand(0));
31071 }
31072 
31073 static SDValue LowerCMP_SWAP(SDValue Op, const X86Subtarget &Subtarget,
31074                              SelectionDAG &DAG) {
31075   MVT T = Op.getSimpleValueType();
31076   SDLoc DL(Op);
31077   unsigned Reg = 0;
31078   unsigned size = 0;
31079   switch(T.SimpleTy) {
31080   default: llvm_unreachable("Invalid value type!");
31081   case MVT::i8:  Reg = X86::AL;  size = 1; break;
31082   case MVT::i16: Reg = X86::AX;  size = 2; break;
31083   case MVT::i32: Reg = X86::EAX; size = 4; break;
31084   case MVT::i64:
31085     assert(Subtarget.is64Bit() && "Node not type legal!");
31086     Reg = X86::RAX; size = 8;
31087     break;
31088   }
31089   SDValue cpIn = DAG.getCopyToReg(Op.getOperand(0), DL, Reg,
31090                                   Op.getOperand(2), SDValue());
31091   SDValue Ops[] = { cpIn.getValue(0),
31092                     Op.getOperand(1),
31093                     Op.getOperand(3),
31094                     DAG.getTargetConstant(size, DL, MVT::i8),
31095                     cpIn.getValue(1) };
31096   SDVTList Tys = DAG.getVTList(MVT::Other, MVT::Glue);
31097   MachineMemOperand *MMO = cast<AtomicSDNode>(Op)->getMemOperand();
31098   SDValue Result = DAG.getMemIntrinsicNode(X86ISD::LCMPXCHG_DAG, DL, Tys,
31099                                            Ops, T, MMO);
31100 
31101   SDValue cpOut =
31102     DAG.getCopyFromReg(Result.getValue(0), DL, Reg, T, Result.getValue(1));
31103   SDValue EFLAGS = DAG.getCopyFromReg(cpOut.getValue(1), DL, X86::EFLAGS,
31104                                       MVT::i32, cpOut.getValue(2));
31105   SDValue Success = getSETCC(X86::COND_E, EFLAGS, DL, DAG);
31106 
31107   return DAG.getNode(ISD::MERGE_VALUES, DL, Op->getVTList(),
31108                      cpOut, Success, EFLAGS.getValue(1));
31109 }
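// Sketch of what the lowering above produces for an i32 cmpxchg (illustrative):
//   mov  <expected>, %eax         ; CopyToReg of operand 2
//   lock cmpxchgl %<new>, (%ptr)  ; X86ISD::LCMPXCHG_DAG
//   mov  %eax, <old>              ; CopyFromReg of the result
//   sete <success>                ; getSETCC(X86::COND_E) on EFLAGS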
31110 
31111 // Create MOVMSKB, taking into account whether we need to split for AVX1.
31112 static SDValue getPMOVMSKB(const SDLoc &DL, SDValue V, SelectionDAG &DAG,
31113                            const X86Subtarget &Subtarget) {
31114   MVT InVT = V.getSimpleValueType();
31115 
31116   if (InVT == MVT::v64i8) {
31117     SDValue Lo, Hi;
31118     std::tie(Lo, Hi) = DAG.SplitVector(V, DL);
31119     Lo = getPMOVMSKB(DL, Lo, DAG, Subtarget);
31120     Hi = getPMOVMSKB(DL, Hi, DAG, Subtarget);
31121     Lo = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, Lo);
31122     Hi = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, Hi);
31123     Hi = DAG.getNode(ISD::SHL, DL, MVT::i64, Hi,
31124                      DAG.getConstant(32, DL, MVT::i8));
31125     return DAG.getNode(ISD::OR, DL, MVT::i64, Lo, Hi);
31126   }
31127   if (InVT == MVT::v32i8 && !Subtarget.hasInt256()) {
31128     SDValue Lo, Hi;
31129     std::tie(Lo, Hi) = DAG.SplitVector(V, DL);
31130     Lo = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Lo);
31131     Hi = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Hi);
31132     Hi = DAG.getNode(ISD::SHL, DL, MVT::i32, Hi,
31133                      DAG.getConstant(16, DL, MVT::i8));
31134     return DAG.getNode(ISD::OR, DL, MVT::i32, Lo, Hi);
31135   }
31136 
31137   return DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, V);
31138 }
31139 
31140 static SDValue LowerBITCAST(SDValue Op, const X86Subtarget &Subtarget,
31141                             SelectionDAG &DAG) {
31142   SDValue Src = Op.getOperand(0);
31143   MVT SrcVT = Src.getSimpleValueType();
31144   MVT DstVT = Op.getSimpleValueType();
31145 
31146   // Legalize (v64i1 (bitcast i64 (X))) by splitting the i64, bitcasting each
31147   // half to v32i1 and concatenating the result.
31148   if (SrcVT == MVT::i64 && DstVT == MVT::v64i1) {
31149     assert(!Subtarget.is64Bit() && "Expected 32-bit mode");
31150     assert(Subtarget.hasBWI() && "Expected BWI target");
31151     SDLoc dl(Op);
31152     SDValue Lo, Hi;
31153     std::tie(Lo, Hi) = DAG.SplitScalar(Src, dl, MVT::i32, MVT::i32);
31154     Lo = DAG.getBitcast(MVT::v32i1, Lo);
31155     Hi = DAG.getBitcast(MVT::v32i1, Hi);
31156     return DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v64i1, Lo, Hi);
31157   }
31158 
31159   // Use MOVMSK for vector to scalar conversion to prevent scalarization.
31160   if ((SrcVT == MVT::v16i1 || SrcVT == MVT::v32i1) && DstVT.isScalarInteger()) {
31161     assert(!Subtarget.hasAVX512() && "Should use K-registers with AVX512");
31162     MVT SExtVT = SrcVT == MVT::v16i1 ? MVT::v16i8 : MVT::v32i8;
31163     SDLoc DL(Op);
31164     SDValue V = DAG.getSExtOrTrunc(Src, DL, SExtVT);
31165     V = getPMOVMSKB(DL, V, DAG, Subtarget);
31166     return DAG.getZExtOrTrunc(V, DL, DstVT);
31167   }
31168 
31169   assert((SrcVT == MVT::v2i32 || SrcVT == MVT::v4i16 || SrcVT == MVT::v8i8 ||
31170           SrcVT == MVT::i64) && "Unexpected VT!");
31171 
31172   assert(Subtarget.hasSSE2() && "Requires at least SSE2!");
31173   if (!(DstVT == MVT::f64 && SrcVT == MVT::i64) &&
31174       !(DstVT == MVT::x86mmx && SrcVT.isVector()))
31175     // This conversion needs to be expanded.
31176     return SDValue();
31177 
31178   SDLoc dl(Op);
31179   if (SrcVT.isVector()) {
31180     // Widen the input vector in the case of MVT::v2i32.
31181     // Example: from MVT::v2i32 to MVT::v4i32.
31182     MVT NewVT = MVT::getVectorVT(SrcVT.getVectorElementType(),
31183                                  SrcVT.getVectorNumElements() * 2);
31184     Src = DAG.getNode(ISD::CONCAT_VECTORS, dl, NewVT, Src,
31185                       DAG.getUNDEF(SrcVT));
31186   } else {
31187     assert(SrcVT == MVT::i64 && !Subtarget.is64Bit() &&
31188            "Unexpected source type in LowerBITCAST");
31189     Src = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, MVT::v2i64, Src);
31190   }
31191 
31192   MVT V2X64VT = DstVT == MVT::f64 ? MVT::v2f64 : MVT::v2i64;
31193   Src = DAG.getNode(ISD::BITCAST, dl, V2X64VT, Src);
31194 
31195   if (DstVT == MVT::x86mmx)
31196     return DAG.getNode(X86ISD::MOVDQ2Q, dl, DstVT, Src);
31197 
31198   return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, DstVT, Src,
31199                      DAG.getIntPtrConstant(0, dl));
31200 }
31201 
31202 /// Compute the horizontal sum of bytes in V for the elements of VT.
31203 ///
31204 /// Requires V to be a byte vector and VT to be an integer vector type with
31205 /// wider elements than V's type. The width of the elements of VT determines
31206 /// how many bytes of V are summed horizontally to produce each element of the
31207 /// result.
31208 static SDValue LowerHorizontalByteSum(SDValue V, MVT VT,
31209                                       const X86Subtarget &Subtarget,
31210                                       SelectionDAG &DAG) {
31211   SDLoc DL(V);
31212   MVT ByteVecVT = V.getSimpleValueType();
31213   MVT EltVT = VT.getVectorElementType();
31214   assert(ByteVecVT.getVectorElementType() == MVT::i8 &&
31215          "Expected value to have byte element type.");
31216   assert(EltVT != MVT::i8 &&
31217          "Horizontal byte sum only makes sense for wider elements!");
31218   unsigned VecSize = VT.getSizeInBits();
31219   assert(ByteVecVT.getSizeInBits() == VecSize && "Cannot change vector size!");
31220 
31221   // The PSADBW instruction horizontally adds all bytes and leaves the result
31222   // in i64 chunks, thus directly computing the pop count for v2i64 and v4i64.
31223   if (EltVT == MVT::i64) {
31224     SDValue Zeros = DAG.getConstant(0, DL, ByteVecVT);
31225     MVT SadVecVT = MVT::getVectorVT(MVT::i64, VecSize / 64);
31226     V = DAG.getNode(X86ISD::PSADBW, DL, SadVecVT, V, Zeros);
31227     return DAG.getBitcast(VT, V);
31228   }
31229 
31230   if (EltVT == MVT::i32) {
31231     // We unpack the low half and high half into i32s interleaved with zeros so
31232     // that we can use PSADBW to horizontally sum them. The most useful part of
31233     // this is that it lines up the results of two PSADBW instructions to be
31234     // two v2i64 vectors which concatenated are the 4 population counts. We can
31235     // then use PACKUSWB to shrink and concatenate them into a v4i32 again.
31236     SDValue Zeros = DAG.getConstant(0, DL, VT);
31237     SDValue V32 = DAG.getBitcast(VT, V);
31238     SDValue Low = getUnpackl(DAG, DL, VT, V32, Zeros);
31239     SDValue High = getUnpackh(DAG, DL, VT, V32, Zeros);
31240 
31241     // Do the horizontal sums into two v2i64s.
31242     Zeros = DAG.getConstant(0, DL, ByteVecVT);
31243     MVT SadVecVT = MVT::getVectorVT(MVT::i64, VecSize / 64);
31244     Low = DAG.getNode(X86ISD::PSADBW, DL, SadVecVT,
31245                       DAG.getBitcast(ByteVecVT, Low), Zeros);
31246     High = DAG.getNode(X86ISD::PSADBW, DL, SadVecVT,
31247                        DAG.getBitcast(ByteVecVT, High), Zeros);
31248 
31249     // Merge them together.
31250     MVT ShortVecVT = MVT::getVectorVT(MVT::i16, VecSize / 16);
31251     V = DAG.getNode(X86ISD::PACKUS, DL, ByteVecVT,
31252                     DAG.getBitcast(ShortVecVT, Low),
31253                     DAG.getBitcast(ShortVecVT, High));
31254 
31255     return DAG.getBitcast(VT, V);
31256   }
31257 
31258   // The only element type left is i16.
31259   assert(EltVT == MVT::i16 && "Unknown how to handle type");
31260 
31261   // To obtain pop count for each i16 element starting from the pop count for
31262   // i8 elements, shift the i16s left by 8, sum as i8s, and then shift as i16s
31263   // right by 8. It is important to shift as i16s as i8 vector shift isn't
31264   // directly supported.
31265   SDValue ShifterV = DAG.getConstant(8, DL, VT);
31266   SDValue Shl = DAG.getNode(ISD::SHL, DL, VT, DAG.getBitcast(VT, V), ShifterV);
31267   V = DAG.getNode(ISD::ADD, DL, ByteVecVT, DAG.getBitcast(ByteVecVT, Shl),
31268                   DAG.getBitcast(ByteVecVT, V));
31269   return DAG.getNode(ISD::SRL, DL, VT, DAG.getBitcast(VT, V), ShifterV);
31270 }
31271 
31272 static SDValue LowerVectorCTPOPInRegLUT(SDValue Op, const SDLoc &DL,
31273                                         const X86Subtarget &Subtarget,
31274                                         SelectionDAG &DAG) {
31275   MVT VT = Op.getSimpleValueType();
31276   MVT EltVT = VT.getVectorElementType();
31277   int NumElts = VT.getVectorNumElements();
31278   (void)EltVT;
31279   assert(EltVT == MVT::i8 && "Only vXi8 vector CTPOP lowering supported.");
31280 
31281   // Implement a lookup table in register by using an algorithm based on:
31282   // http://wm.ite.pl/articles/sse-popcount.html
31283   //
31284   // The general idea is that each nibble of every byte in the input vector is
31285   // an index into an in-register pre-computed pop count table. We split the
31286   // input vector into two new ones: (1) a vector with only the shifted-right
31287   // higher nibbles of each byte and (2) a vector with the lower nibbles (and
31288   // the higher ones masked out) of each byte. PSHUFB is used separately with
31289   // both to index the in-register table. Next, both are added and the result is
31290   // an i8 vector where each element contains the pop count for its input byte.
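  // Worked example for a single input byte, 0xA7 = 0b10100111 (illustrative):
  //   high nibble 0xA -> LUT[0xA] = 2, low nibble 0x7 -> LUT[0x7] = 3, and
  //   2 + 3 = 5 = popcount(0xA7).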
31291   const int LUT[16] = {/* 0 */ 0, /* 1 */ 1, /* 2 */ 1, /* 3 */ 2,
31292                        /* 4 */ 1, /* 5 */ 2, /* 6 */ 2, /* 7 */ 3,
31293                        /* 8 */ 1, /* 9 */ 2, /* a */ 2, /* b */ 3,
31294                        /* c */ 2, /* d */ 3, /* e */ 3, /* f */ 4};
31295 
31296   SmallVector<SDValue, 64> LUTVec;
31297   for (int i = 0; i < NumElts; ++i)
31298     LUTVec.push_back(DAG.getConstant(LUT[i % 16], DL, MVT::i8));
31299   SDValue InRegLUT = DAG.getBuildVector(VT, DL, LUTVec);
31300   SDValue M0F = DAG.getConstant(0x0F, DL, VT);
31301 
31302   // High nibbles
31303   SDValue FourV = DAG.getConstant(4, DL, VT);
31304   SDValue HiNibbles = DAG.getNode(ISD::SRL, DL, VT, Op, FourV);
31305 
31306   // Low nibbles
31307   SDValue LoNibbles = DAG.getNode(ISD::AND, DL, VT, Op, M0F);
31308 
31309   // The input vector is used as the shuffle mask that indexes elements into the
31310   // LUT. After counting low and high nibbles, add the two vectors to obtain the
31311   // final pop count per i8 element.
31312   SDValue HiPopCnt = DAG.getNode(X86ISD::PSHUFB, DL, VT, InRegLUT, HiNibbles);
31313   SDValue LoPopCnt = DAG.getNode(X86ISD::PSHUFB, DL, VT, InRegLUT, LoNibbles);
31314   return DAG.getNode(ISD::ADD, DL, VT, HiPopCnt, LoPopCnt);
31315 }
31316 
31317 // Please ensure that any codegen change from LowerVectorCTPOP is reflected in
31318 // updated cost models in X86TTIImpl::getIntrinsicInstrCost.
31319 static SDValue LowerVectorCTPOP(SDValue Op, const SDLoc &DL,
31320                                 const X86Subtarget &Subtarget,
31321                                 SelectionDAG &DAG) {
31322   MVT VT = Op.getSimpleValueType();
31323   assert((VT.is512BitVector() || VT.is256BitVector() || VT.is128BitVector()) &&
31324          "Unknown CTPOP type to handle");
31325   SDValue Op0 = Op.getOperand(0);
31326 
31327   // TRUNC(CTPOP(ZEXT(X))) to make use of vXi32/vXi64 VPOPCNT instructions.
31328   if (Subtarget.hasVPOPCNTDQ()) {
31329     unsigned NumElems = VT.getVectorNumElements();
31330     assert((VT.getVectorElementType() == MVT::i8 ||
31331             VT.getVectorElementType() == MVT::i16) && "Unexpected type");
31332     if (NumElems < 16 || (NumElems == 16 && Subtarget.canExtendTo512DQ())) {
31333       MVT NewVT = MVT::getVectorVT(MVT::i32, NumElems);
31334       Op = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, Op0);
31335       Op = DAG.getNode(ISD::CTPOP, DL, NewVT, Op);
31336       return DAG.getNode(ISD::TRUNCATE, DL, VT, Op);
31337     }
31338   }
31339 
31340   // Decompose 256-bit ops into smaller 128-bit ops.
31341   if (VT.is256BitVector() && !Subtarget.hasInt256())
31342     return splitVectorIntUnary(Op, DAG, DL);
31343 
31344   // Decompose 512-bit ops into smaller 256-bit ops.
31345   if (VT.is512BitVector() && !Subtarget.hasBWI())
31346     return splitVectorIntUnary(Op, DAG, DL);
31347 
31348   // For element types greater than i8, do vXi8 pop counts and a bytesum.
31349   if (VT.getScalarType() != MVT::i8) {
31350     MVT ByteVT = MVT::getVectorVT(MVT::i8, VT.getSizeInBits() / 8);
31351     SDValue ByteOp = DAG.getBitcast(ByteVT, Op0);
31352     SDValue PopCnt8 = DAG.getNode(ISD::CTPOP, DL, ByteVT, ByteOp);
31353     return LowerHorizontalByteSum(PopCnt8, VT, Subtarget, DAG);
31354   }
31355 
31356   // We can't use the fast LUT approach, so fall back on LegalizeDAG.
31357   if (!Subtarget.hasSSSE3())
31358     return SDValue();
31359 
31360   return LowerVectorCTPOPInRegLUT(Op0, DL, Subtarget, DAG);
31361 }
31362 
31363 static SDValue LowerCTPOP(SDValue N, const X86Subtarget &Subtarget,
31364                           SelectionDAG &DAG) {
31365   MVT VT = N.getSimpleValueType();
31366   SDValue Op = N.getOperand(0);
31367   SDLoc DL(N);
31368 
31369   if (VT.isScalarInteger()) {
31370     // Compute the lower/upper bounds of the active bits of the value,
31371     // allowing us to shift the active bits down if necessary to fit into the
31372     // special cases below.
31373     KnownBits Known = DAG.computeKnownBits(Op);
31374     unsigned LZ = Known.countMinLeadingZeros();
31375     unsigned TZ = Known.countMinTrailingZeros();
31376     assert((LZ + TZ) < Known.getBitWidth() && "Illegal shifted mask");
31377     unsigned ActiveBits = Known.getBitWidth() - LZ;
31378     unsigned ShiftedActiveBits = Known.getBitWidth() - (LZ + TZ);
31379 
31380     // i2 CTPOP - "ctpop(x) --> sub(x, (x >> 1))".
31381     if (ShiftedActiveBits <= 2) {
31382       if (ActiveBits > 2)
31383         Op = DAG.getNode(ISD::SRL, DL, VT, Op,
31384                          DAG.getShiftAmountConstant(TZ, VT, DL));
31385       Op = DAG.getZExtOrTrunc(Op, DL, MVT::i32);
31386       Op = DAG.getNode(ISD::SUB, DL, MVT::i32, Op,
31387                        DAG.getNode(ISD::SRL, DL, MVT::i32, Op,
31388                                    DAG.getShiftAmountConstant(1, VT, DL)));
31389       return DAG.getZExtOrTrunc(Op, DL, VT);
31390     }
31391 
31392     // i3 CTPOP - perform LUT into i32 integer.
31393     if (ShiftedActiveBits <= 3) {
31394       if (ActiveBits > 3)
31395         Op = DAG.getNode(ISD::SRL, DL, VT, Op,
31396                          DAG.getShiftAmountConstant(TZ, VT, DL));
31397       Op = DAG.getZExtOrTrunc(Op, DL, MVT::i32);
31398       Op = DAG.getNode(ISD::SHL, DL, MVT::i32, Op,
31399                        DAG.getShiftAmountConstant(1, VT, DL));
31400       Op = DAG.getNode(ISD::SRL, DL, MVT::i32,
31401                        DAG.getConstant(0b1110100110010100U, DL, MVT::i32), Op);
31402       Op = DAG.getNode(ISD::AND, DL, MVT::i32, Op,
31403                        DAG.getConstant(0x3, DL, MVT::i32));
31404       return DAG.getZExtOrTrunc(Op, DL, VT);
31405     }
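    // Worked example for the i3 LUT above (illustrative): the 16-bit constant
    // packs popcount(0..7) into 2-bit fields. For x = 5:
    //   0b1110100110010100 >> (5 << 1) ends in ...10, and masking with 0x3
    //   gives 2 == popcount(5).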
31406 
31407     // i4 CTPOP - perform LUT into i64 integer.
31408     if (ShiftedActiveBits <= 4 &&
31409         DAG.getTargetLoweringInfo().isTypeLegal(MVT::i64)) {
31410       SDValue LUT = DAG.getConstant(0x4332322132212110ULL, DL, MVT::i64);
31411       if (ActiveBits > 4)
31412         Op = DAG.getNode(ISD::SRL, DL, VT, Op,
31413                          DAG.getShiftAmountConstant(TZ, VT, DL));
31414       Op = DAG.getZExtOrTrunc(Op, DL, MVT::i32);
31415       Op = DAG.getNode(ISD::MUL, DL, MVT::i32, Op,
31416                        DAG.getConstant(4, DL, MVT::i32));
31417       Op = DAG.getNode(ISD::SRL, DL, MVT::i64, LUT,
31418                        DAG.getShiftAmountOperand(MVT::i64, Op));
31419       Op = DAG.getNode(ISD::AND, DL, MVT::i64, Op,
31420                        DAG.getConstant(0x7, DL, MVT::i64));
31421       return DAG.getZExtOrTrunc(Op, DL, VT);
31422     }
31423 
31424     // i8 CTPOP - with efficient i32 MUL, then attempt multiply-mask-multiply.
31425     if (ShiftedActiveBits <= 8) {
31426       SDValue Mask11 = DAG.getConstant(0x11111111U, DL, MVT::i32);
31427       if (ActiveBits > 8)
31428         Op = DAG.getNode(ISD::SRL, DL, VT, Op,
31429                          DAG.getShiftAmountConstant(TZ, VT, DL));
31430       Op = DAG.getZExtOrTrunc(Op, DL, MVT::i32);
31431       Op = DAG.getNode(ISD::MUL, DL, MVT::i32, Op,
31432                        DAG.getConstant(0x08040201U, DL, MVT::i32));
31433       Op = DAG.getNode(ISD::SRL, DL, MVT::i32, Op,
31434                        DAG.getShiftAmountConstant(3, MVT::i32, DL));
31435       Op = DAG.getNode(ISD::AND, DL, MVT::i32, Op, Mask11);
31436       Op = DAG.getNode(ISD::MUL, DL, MVT::i32, Op, Mask11);
31437       Op = DAG.getNode(ISD::SRL, DL, MVT::i32, Op,
31438                        DAG.getShiftAmountConstant(28, MVT::i32, DL));
31439       return DAG.getZExtOrTrunc(Op, DL, VT);
31440     }
31441 
31442     return SDValue(); // fallback to generic expansion.
31443   }
31444 
31445   assert(VT.isVector() &&
31446          "We only do custom lowering for vector population count.");
31447   return LowerVectorCTPOP(N, DL, Subtarget, DAG);
31448 }
31449 
31450 static SDValue LowerBITREVERSE_XOP(SDValue Op, SelectionDAG &DAG) {
31451   MVT VT = Op.getSimpleValueType();
31452   SDValue In = Op.getOperand(0);
31453   SDLoc DL(Op);
31454 
31455   // For scalars, it's still beneficial to transfer to/from the SIMD unit to
31456   // perform the BITREVERSE.
31457   if (!VT.isVector()) {
31458     MVT VecVT = MVT::getVectorVT(VT, 128 / VT.getSizeInBits());
31459     SDValue Res = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VecVT, In);
31460     Res = DAG.getNode(ISD::BITREVERSE, DL, VecVT, Res);
31461     return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Res,
31462                        DAG.getIntPtrConstant(0, DL));
31463   }
31464 
31465   int NumElts = VT.getVectorNumElements();
31466   int ScalarSizeInBytes = VT.getScalarSizeInBits() / 8;
31467 
31468   // Decompose 256-bit ops into smaller 128-bit ops.
31469   if (VT.is256BitVector())
31470     return splitVectorIntUnary(Op, DAG, DL);
31471 
31472   assert(VT.is128BitVector() &&
31473          "Only 128-bit vector bitreverse lowering supported.");
31474 
31475   // VPPERM reverses the bits of a byte with the permute Op (2 << 5), and we
31476   // perform the BSWAP in the shuffle.
31477   // It's best to shuffle using the second operand as this will implicitly allow
31478   // memory folding for multiple vectors.
31479   SmallVector<SDValue, 16> MaskElts;
31480   for (int i = 0; i != NumElts; ++i) {
31481     for (int j = ScalarSizeInBytes - 1; j >= 0; --j) {
31482       int SourceByte = 16 + (i * ScalarSizeInBytes) + j;
31483       int PermuteByte = SourceByte | (2 << 5);
31484       MaskElts.push_back(DAG.getConstant(PermuteByte, DL, MVT::i8));
31485     }
31486   }
31487 
31488   SDValue Mask = DAG.getBuildVector(MVT::v16i8, DL, MaskElts);
31489   SDValue Res = DAG.getBitcast(MVT::v16i8, In);
31490   Res = DAG.getNode(X86ISD::VPPERM, DL, MVT::v16i8, DAG.getUNDEF(MVT::v16i8),
31491                     Res, Mask);
31492   return DAG.getBitcast(VT, Res);
31493 }
31494 
31495 static SDValue LowerBITREVERSE(SDValue Op, const X86Subtarget &Subtarget,
31496                                SelectionDAG &DAG) {
31497   MVT VT = Op.getSimpleValueType();
31498 
31499   if (Subtarget.hasXOP() && !VT.is512BitVector())
31500     return LowerBITREVERSE_XOP(Op, DAG);
31501 
31502   assert(Subtarget.hasSSSE3() && "SSSE3 required for BITREVERSE");
31503 
31504   SDValue In = Op.getOperand(0);
31505   SDLoc DL(Op);
31506 
31507   // Split 512-bit ops without BWI so that we can still use the PSHUFB lowering.
31508   if (VT.is512BitVector() && !Subtarget.hasBWI())
31509     return splitVectorIntUnary(Op, DAG, DL);
31510 
31511   // Decompose 256-bit ops into smaller 128-bit ops on pre-AVX2.
31512   if (VT.is256BitVector() && !Subtarget.hasInt256())
31513     return splitVectorIntUnary(Op, DAG, DL);
31514 
31515   // Lower i8/i16/i32/i64 as vXi8 BITREVERSE + BSWAP
31516   if (!VT.isVector()) {
31517     assert(
31518         (VT == MVT::i32 || VT == MVT::i64 || VT == MVT::i16 || VT == MVT::i8) &&
31519         "Only tested for i8/i16/i32/i64");
31520     MVT VecVT = MVT::getVectorVT(VT, 128 / VT.getSizeInBits());
31521     SDValue Res = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VecVT, In);
31522     Res = DAG.getNode(ISD::BITREVERSE, DL, MVT::v16i8,
31523                       DAG.getBitcast(MVT::v16i8, Res));
31524     Res = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT,
31525                       DAG.getBitcast(VecVT, Res), DAG.getIntPtrConstant(0, DL));
31526     return (VT == MVT::i8) ? Res : DAG.getNode(ISD::BSWAP, DL, VT, Res);
31527   }
31528 
31529   assert(VT.isVector() && VT.getSizeInBits() >= 128);
31530 
31531   // Lower vXi16/vXi32/vXi64 as BSWAP + vXi8 BITREVERSE.
31532   if (VT.getScalarType() != MVT::i8) {
31533     MVT ByteVT = MVT::getVectorVT(MVT::i8, VT.getSizeInBits() / 8);
31534     SDValue Res = DAG.getNode(ISD::BSWAP, DL, VT, In);
31535     Res = DAG.getBitcast(ByteVT, Res);
31536     Res = DAG.getNode(ISD::BITREVERSE, DL, ByteVT, Res);
31537     return DAG.getBitcast(VT, Res);
31538   }
31539   assert(VT.isVector() && VT.getScalarType() == MVT::i8 &&
31540          "Only byte vector BITREVERSE supported");
31541 
31542   unsigned NumElts = VT.getVectorNumElements();
31543 
31544   // If we have GFNI, we can use GF2P8AFFINEQB to reverse the bits.
31545   if (Subtarget.hasGFNI()) {
31546     SDValue Matrix = getGFNICtrlMask(ISD::BITREVERSE, DAG, DL, VT);
31547     return DAG.getNode(X86ISD::GF2P8AFFINEQB, DL, VT, In, Matrix,
31548                        DAG.getTargetConstant(0, DL, MVT::i8));
31549   }
31550 
31551   // Perform BITREVERSE using PSHUFB lookups. Each byte is split into its two
31552   // nibbles, and a PSHUFB lookup finds the bit-reverse of each 0-15 value
31553   // (moved to the other nibble).
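  // Worked example for one byte, 0xB1 = 0b10110001 (illustrative):
  //   low nibble 0x1 -> LoLUT[0x1] = 0x80, high nibble 0xB -> HiLUT[0xB] = 0x0D,
  //   0x80 | 0x0D = 0x8D = 0b10001101, the bit-reversed byte.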
31554   SDValue NibbleMask = DAG.getConstant(0xF, DL, VT);
31555   SDValue Lo = DAG.getNode(ISD::AND, DL, VT, In, NibbleMask);
31556   SDValue Hi = DAG.getNode(ISD::SRL, DL, VT, In, DAG.getConstant(4, DL, VT));
31557 
31558   const int LoLUT[16] = {
31559       /* 0 */ 0x00, /* 1 */ 0x80, /* 2 */ 0x40, /* 3 */ 0xC0,
31560       /* 4 */ 0x20, /* 5 */ 0xA0, /* 6 */ 0x60, /* 7 */ 0xE0,
31561       /* 8 */ 0x10, /* 9 */ 0x90, /* a */ 0x50, /* b */ 0xD0,
31562       /* c */ 0x30, /* d */ 0xB0, /* e */ 0x70, /* f */ 0xF0};
31563   const int HiLUT[16] = {
31564       /* 0 */ 0x00, /* 1 */ 0x08, /* 2 */ 0x04, /* 3 */ 0x0C,
31565       /* 4 */ 0x02, /* 5 */ 0x0A, /* 6 */ 0x06, /* 7 */ 0x0E,
31566       /* 8 */ 0x01, /* 9 */ 0x09, /* a */ 0x05, /* b */ 0x0D,
31567       /* c */ 0x03, /* d */ 0x0B, /* e */ 0x07, /* f */ 0x0F};
31568 
31569   SmallVector<SDValue, 16> LoMaskElts, HiMaskElts;
31570   for (unsigned i = 0; i < NumElts; ++i) {
31571     LoMaskElts.push_back(DAG.getConstant(LoLUT[i % 16], DL, MVT::i8));
31572     HiMaskElts.push_back(DAG.getConstant(HiLUT[i % 16], DL, MVT::i8));
31573   }
31574 
31575   SDValue LoMask = DAG.getBuildVector(VT, DL, LoMaskElts);
31576   SDValue HiMask = DAG.getBuildVector(VT, DL, HiMaskElts);
31577   Lo = DAG.getNode(X86ISD::PSHUFB, DL, VT, LoMask, Lo);
31578   Hi = DAG.getNode(X86ISD::PSHUFB, DL, VT, HiMask, Hi);
31579   return DAG.getNode(ISD::OR, DL, VT, Lo, Hi);
31580 }
31581 
31582 static SDValue LowerPARITY(SDValue Op, const X86Subtarget &Subtarget,
31583                            SelectionDAG &DAG) {
31584   SDLoc DL(Op);
31585   SDValue X = Op.getOperand(0);
31586   MVT VT = Op.getSimpleValueType();
31587 
31588   // Special case. If the input fits in 8-bits we can use a single 8-bit TEST.
31589   if (VT == MVT::i8 ||
31590       DAG.MaskedValueIsZero(X, APInt::getBitsSetFrom(VT.getSizeInBits(), 8))) {
31591     X = DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, X);
31592     SDValue Flags = DAG.getNode(X86ISD::CMP, DL, MVT::i32, X,
31593                                 DAG.getConstant(0, DL, MVT::i8));
31594     // Copy the inverse of the parity flag into a register with setcc.
31595     SDValue Setnp = getSETCC(X86::COND_NP, Flags, DL, DAG);
31596     // Extend to the original type.
31597     return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Setnp);
31598   }
31599 
31600   // If we have POPCNT, use the default expansion.
31601   if (Subtarget.hasPOPCNT())
31602     return SDValue();
31603 
31604   if (VT == MVT::i64) {
31605     // Xor the high and low 32-bit halves together using a 32-bit operation.
31606     SDValue Hi = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32,
31607                              DAG.getNode(ISD::SRL, DL, MVT::i64, X,
31608                                          DAG.getConstant(32, DL, MVT::i8)));
31609     SDValue Lo = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, X);
31610     X = DAG.getNode(ISD::XOR, DL, MVT::i32, Lo, Hi);
31611   }
31612 
31613   if (VT != MVT::i16) {
31614     // Xor the high and low 16-bits together using a 32-bit operation.
31615     SDValue Hi16 = DAG.getNode(ISD::SRL, DL, MVT::i32, X,
31616                                DAG.getConstant(16, DL, MVT::i8));
31617     X = DAG.getNode(ISD::XOR, DL, MVT::i32, X, Hi16);
31618   } else {
31619     // If the input is 16-bits, we need to extend to use an i32 shift below.
31620     X = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, X);
31621   }
31622 
31623   // Finally xor the low 2 bytes together and use an 8-bit flag-setting xor.
31624   // This should allow an h-reg to be used to save a shift.
31625   SDValue Hi = DAG.getNode(
31626       ISD::TRUNCATE, DL, MVT::i8,
31627       DAG.getNode(ISD::SRL, DL, MVT::i32, X, DAG.getConstant(8, DL, MVT::i8)));
31628   SDValue Lo = DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, X);
31629   SDVTList VTs = DAG.getVTList(MVT::i8, MVT::i32);
31630   SDValue Flags = DAG.getNode(X86ISD::XOR, DL, VTs, Lo, Hi).getValue(1);
31631 
31632   // Copy the inverse of the parity flag into a register with setcc.
31633   SDValue Setnp = getSETCC(X86::COND_NP, Flags, DL, DAG);
31634   // Extend to the original type.
31635   return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Setnp);
31636 }
31637 
31638 static SDValue lowerAtomicArithWithLOCK(SDValue N, SelectionDAG &DAG,
31639                                         const X86Subtarget &Subtarget) {
31640   unsigned NewOpc = 0;
31641   switch (N->getOpcode()) {
31642   case ISD::ATOMIC_LOAD_ADD:
31643     NewOpc = X86ISD::LADD;
31644     break;
31645   case ISD::ATOMIC_LOAD_SUB:
31646     NewOpc = X86ISD::LSUB;
31647     break;
31648   case ISD::ATOMIC_LOAD_OR:
31649     NewOpc = X86ISD::LOR;
31650     break;
31651   case ISD::ATOMIC_LOAD_XOR:
31652     NewOpc = X86ISD::LXOR;
31653     break;
31654   case ISD::ATOMIC_LOAD_AND:
31655     NewOpc = X86ISD::LAND;
31656     break;
31657   default:
31658     llvm_unreachable("Unknown ATOMIC_LOAD_ opcode");
31659   }
31660 
31661   MachineMemOperand *MMO = cast<MemSDNode>(N)->getMemOperand();
31662 
31663   return DAG.getMemIntrinsicNode(
31664       NewOpc, SDLoc(N), DAG.getVTList(MVT::i32, MVT::Other),
31665       {N->getOperand(0), N->getOperand(1), N->getOperand(2)},
31666       /*MemVT=*/N->getSimpleValueType(0), MMO);
31667 }
31668 
31669 /// Lower atomic_load_ops into LOCK-prefixed operations.
31670 static SDValue lowerAtomicArith(SDValue N, SelectionDAG &DAG,
31671                                 const X86Subtarget &Subtarget) {
31672   AtomicSDNode *AN = cast<AtomicSDNode>(N.getNode());
31673   SDValue Chain = N->getOperand(0);
31674   SDValue LHS = N->getOperand(1);
31675   SDValue RHS = N->getOperand(2);
31676   unsigned Opc = N->getOpcode();
31677   MVT VT = N->getSimpleValueType(0);
31678   SDLoc DL(N);
31679 
31680   // We can lower atomic_load_add into LXADD. However, any other atomicrmw op
31681   // can only be lowered when the result is unused.  They should have already
31682   // been transformed into a cmpxchg loop in AtomicExpand.
31683   if (N->hasAnyUseOfValue(0)) {
31684     // Handle (atomic_load_sub p, v) as (atomic_load_add p, -v), to be able to
31685     // select LXADD if LOCK_SUB can't be selected.
31686     // Handle (atomic_load_xor p, SignBit) as (atomic_load_add p, SignBit) so we
31687     // can use LXADD as opposed to cmpxchg.
31688     if (Opc == ISD::ATOMIC_LOAD_SUB ||
31689         (Opc == ISD::ATOMIC_LOAD_XOR && isMinSignedConstant(RHS)))
31690       return DAG.getAtomic(ISD::ATOMIC_LOAD_ADD, DL, VT, Chain, LHS,
31691                            DAG.getNegative(RHS, DL, VT), AN->getMemOperand());
31692 
31693     assert(Opc == ISD::ATOMIC_LOAD_ADD &&
31694            "Used AtomicRMW ops other than Add should have been expanded!");
31695     return N;
31696   }
31697 
31698   // Specialized lowering for the canonical form of an idempotent atomicrmw.
31699   // The core idea here is that since the memory location isn't actually
31700   // changing, all we need is a lowering for the *ordering* impacts of the
31701   // atomicrmw.  As such, we can chose a different operation and memory
31702   // location to minimize impact on other code.
31703   // The above holds unless the node is marked volatile in which
31704   // case it needs to be preserved according to the langref.
31705   if (Opc == ISD::ATOMIC_LOAD_OR && isNullConstant(RHS) && !AN->isVolatile()) {
31706     // On X86, the only ordering which actually requires an instruction is
31707     // seq_cst that isn't SingleThread; everything else just needs to be
31708     // preserved during codegen and then dropped. Note that we expect (but don't
31709     // assume) that orderings other than seq_cst and acq_rel have been
31710     // canonicalized to a store or load.
31711     if (AN->getSuccessOrdering() == AtomicOrdering::SequentiallyConsistent &&
31712         AN->getSyncScopeID() == SyncScope::System) {
31713       // Prefer a locked operation against a stack location to minimize cache
31714       // traffic.  This assumes that stack locations are very likely to be
31715       // accessed only by the owning thread.
31716       SDValue NewChain = emitLockedStackOp(DAG, Subtarget, Chain, DL);
31717       assert(!N->hasAnyUseOfValue(0));
31718       // NOTE: The getUNDEF is needed to give something for the unused result 0.
31719       return DAG.getNode(ISD::MERGE_VALUES, DL, N->getVTList(),
31720                          DAG.getUNDEF(VT), NewChain);
31721     }
31722     // MEMBARRIER is a compiler barrier; it codegens to a no-op.
31723     SDValue NewChain = DAG.getNode(ISD::MEMBARRIER, DL, MVT::Other, Chain);
31724     assert(!N->hasAnyUseOfValue(0));
31725     // NOTE: The getUNDEF is needed to give something for the unused result 0.
31726     return DAG.getNode(ISD::MERGE_VALUES, DL, N->getVTList(),
31727                        DAG.getUNDEF(VT), NewChain);
31728   }
31729 
31730   SDValue LockOp = lowerAtomicArithWithLOCK(N, DAG, Subtarget);
31731   // RAUW the chain, but don't worry about the result, as it's unused.
31732   assert(!N->hasAnyUseOfValue(0));
31733   // NOTE: The getUNDEF is needed to give something for the unused result 0.
31734   return DAG.getNode(ISD::MERGE_VALUES, DL, N->getVTList(),
31735                      DAG.getUNDEF(VT), LockOp.getValue(1));
31736 }
31737 
31738 static SDValue LowerATOMIC_STORE(SDValue Op, SelectionDAG &DAG,
31739                                  const X86Subtarget &Subtarget) {
31740   auto *Node = cast<AtomicSDNode>(Op.getNode());
31741   SDLoc dl(Node);
31742   EVT VT = Node->getMemoryVT();
31743 
31744   bool IsSeqCst =
31745       Node->getSuccessOrdering() == AtomicOrdering::SequentiallyConsistent;
31746   bool IsTypeLegal = DAG.getTargetLoweringInfo().isTypeLegal(VT);
31747 
31748   // If this store is not sequentially consistent and the type is legal
31749   // we can just keep it.
31750   if (!IsSeqCst && IsTypeLegal)
31751     return Op;
31752 
31753   if (!IsTypeLegal && !Subtarget.useSoftFloat() &&
31754       !DAG.getMachineFunction().getFunction().hasFnAttribute(
31755           Attribute::NoImplicitFloat)) {
31756     SDValue Chain;
31757     // For illegal i128 atomic_store, when AVX is enabled, we can simply emit a
31758     // vector store.
31759     if (VT == MVT::i128 && Subtarget.is64Bit() && Subtarget.hasAVX()) {
31760       SDValue VecVal = DAG.getBitcast(MVT::v2i64, Node->getVal());
31761       Chain = DAG.getStore(Node->getChain(), dl, VecVal, Node->getBasePtr(),
31762                            Node->getMemOperand());
31763     }
31764 
31765     // For illegal i64 atomic_stores, we can try to use MOVQ or MOVLPS if SSE
31766     // is enabled.
31767     if (VT == MVT::i64) {
31768       if (Subtarget.hasSSE1()) {
31769         SDValue SclToVec =
31770             DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, MVT::v2i64, Node->getVal());
31771         MVT StVT = Subtarget.hasSSE2() ? MVT::v2i64 : MVT::v4f32;
31772         SclToVec = DAG.getBitcast(StVT, SclToVec);
31773         SDVTList Tys = DAG.getVTList(MVT::Other);
31774         SDValue Ops[] = {Node->getChain(), SclToVec, Node->getBasePtr()};
31775         Chain = DAG.getMemIntrinsicNode(X86ISD::VEXTRACT_STORE, dl, Tys, Ops,
31776                                         MVT::i64, Node->getMemOperand());
31777       } else if (Subtarget.hasX87()) {
31778         // First load this into an 80-bit X87 register using a stack temporary.
31779         // This will put the whole integer into the significand.
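        // Illustrative note: this is the classic fild/fistp trick for 32-bit
        // targets without SSE; both x87 instructions touch the full 64 bits in
        // a single memory operation, which is what lets the i64 store be
        // performed as one access.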
31780         SDValue StackPtr = DAG.CreateStackTemporary(MVT::i64);
31781         int SPFI = cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex();
31782         MachinePointerInfo MPI =
31783             MachinePointerInfo::getFixedStack(DAG.getMachineFunction(), SPFI);
31784         Chain = DAG.getStore(Node->getChain(), dl, Node->getVal(), StackPtr,
31785                              MPI, MaybeAlign(), MachineMemOperand::MOStore);
31786         SDVTList Tys = DAG.getVTList(MVT::f80, MVT::Other);
31787         SDValue LdOps[] = {Chain, StackPtr};
31788         SDValue Value = DAG.getMemIntrinsicNode(
31789             X86ISD::FILD, dl, Tys, LdOps, MVT::i64, MPI,
31790             /*Align*/ std::nullopt, MachineMemOperand::MOLoad);
31791         Chain = Value.getValue(1);
31792 
31793         // Now use an FIST to do the atomic store.
31794         SDValue StoreOps[] = {Chain, Value, Node->getBasePtr()};
31795         Chain =
31796             DAG.getMemIntrinsicNode(X86ISD::FIST, dl, DAG.getVTList(MVT::Other),
31797                                     StoreOps, MVT::i64, Node->getMemOperand());
31798       }
31799     }
31800 
31801     if (Chain) {
31802       // If this is a sequentially consistent store, also emit an appropriate
31803       // barrier.
31804       if (IsSeqCst)
31805         Chain = emitLockedStackOp(DAG, Subtarget, Chain, dl);
31806 
31807       return Chain;
31808     }
31809   }
31810 
31811   // Convert seq_cst store -> xchg
31812   // Convert wide store -> swap (-> cmpxchg8b/cmpxchg16b)
31813   // FIXME: 16-byte ATOMIC_SWAP isn't actually hooked up at the moment.
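  // Illustrative sketch: a legal-typed `store atomic i32 %v, ptr %p seq_cst`
  // takes this path and becomes an `xchg` with the memory operand; XCHG is
  // implicitly locked, so it also provides the required seq_cst fence.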
31814   SDValue Swap = DAG.getAtomic(ISD::ATOMIC_SWAP, dl, Node->getMemoryVT(),
31815                                Node->getOperand(0), Node->getOperand(2),
31816                                Node->getOperand(1), Node->getMemOperand());
31817   return Swap.getValue(1);
31818 }
31819 
31820 static SDValue LowerADDSUBO_CARRY(SDValue Op, SelectionDAG &DAG) {
31821   SDNode *N = Op.getNode();
31822   MVT VT = N->getSimpleValueType(0);
31823   unsigned Opc = Op.getOpcode();
31824 
31825   // Let legalize expand this if it isn't a legal type yet.
31826   if (!DAG.getTargetLoweringInfo().isTypeLegal(VT))
31827     return SDValue();
31828 
31829   SDVTList VTs = DAG.getVTList(VT, MVT::i32);
31830   SDLoc DL(N);
31831 
31832   // Set the carry flag.
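  // (Adding all-ones to the incoming carry value produces a carry-out exactly
  // when that value is nonzero, which re-materializes the boolean carry into
  // EFLAGS.CF for the ADC/SBB emitted below.)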
31833   SDValue Carry = Op.getOperand(2);
31834   EVT CarryVT = Carry.getValueType();
31835   Carry = DAG.getNode(X86ISD::ADD, DL, DAG.getVTList(CarryVT, MVT::i32),
31836                       Carry, DAG.getAllOnesConstant(DL, CarryVT));
31837 
31838   bool IsAdd = Opc == ISD::UADDO_CARRY || Opc == ISD::SADDO_CARRY;
31839   SDValue Sum = DAG.getNode(IsAdd ? X86ISD::ADC : X86ISD::SBB, DL, VTs,
31840                             Op.getOperand(0), Op.getOperand(1),
31841                             Carry.getValue(1));
31842 
31843   bool IsSigned = Opc == ISD::SADDO_CARRY || Opc == ISD::SSUBO_CARRY;
31844   SDValue SetCC = getSETCC(IsSigned ? X86::COND_O : X86::COND_B,
31845                            Sum.getValue(1), DL, DAG);
31846   if (N->getValueType(1) == MVT::i1)
31847     SetCC = DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, SetCC);
31848 
31849   return DAG.getNode(ISD::MERGE_VALUES, DL, N->getVTList(), Sum, SetCC);
31850 }
31851 
31852 static SDValue LowerFSINCOS(SDValue Op, const X86Subtarget &Subtarget,
31853                             SelectionDAG &DAG) {
31854   assert(Subtarget.isTargetDarwin() && Subtarget.is64Bit());
31855 
31856   // For MacOSX, we want to call an alternative entry point: __sincos_stret,
31857   // which returns the values as { float, float } (in XMM0) or
31858   // { double, double } (which is returned in XMM0, XMM1).
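  // Illustrative sketch of the f32 case: the libcall returns the pair packed
  // in xmm0 with sin(x) in lane 0 and cos(x) in lane 1, and the extracts at
  // the end of this function pull those two lanes back out as scalars.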
31859   SDLoc dl(Op);
31860   SDValue Arg = Op.getOperand(0);
31861   EVT ArgVT = Arg.getValueType();
31862   Type *ArgTy = ArgVT.getTypeForEVT(*DAG.getContext());
31863 
31864   TargetLowering::ArgListTy Args;
31865   TargetLowering::ArgListEntry Entry;
31866 
31867   Entry.Node = Arg;
31868   Entry.Ty = ArgTy;
31869   Entry.IsSExt = false;
31870   Entry.IsZExt = false;
31871   Args.push_back(Entry);
31872 
31873   bool isF64 = ArgVT == MVT::f64;
31874   // Only optimize x86_64 for now. i386 is a bit messy. For f32,
31875   // the small struct {f32, f32} is returned in (eax, edx). For f64,
31876   // the results are returned via SRet in memory.
31877   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
31878   RTLIB::Libcall LC = isF64 ? RTLIB::SINCOS_STRET_F64 : RTLIB::SINCOS_STRET_F32;
31879   const char *LibcallName = TLI.getLibcallName(LC);
31880   SDValue Callee =
31881       DAG.getExternalSymbol(LibcallName, TLI.getPointerTy(DAG.getDataLayout()));
31882 
31883   Type *RetTy = isF64 ? (Type *)StructType::get(ArgTy, ArgTy)
31884                       : (Type *)FixedVectorType::get(ArgTy, 4);
31885 
31886   TargetLowering::CallLoweringInfo CLI(DAG);
31887   CLI.setDebugLoc(dl)
31888       .setChain(DAG.getEntryNode())
31889       .setLibCallee(CallingConv::C, RetTy, Callee, std::move(Args));
31890 
31891   std::pair<SDValue, SDValue> CallResult = TLI.LowerCallTo(CLI);
31892 
31893   if (isF64)
31894     // Returned in xmm0 and xmm1.
31895     return CallResult.first;
31896 
31897   // Returned in bits 0:31 and 32:63 of xmm0.
31898   SDValue SinVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, ArgVT,
31899                                CallResult.first, DAG.getIntPtrConstant(0, dl));
31900   SDValue CosVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, ArgVT,
31901                                CallResult.first, DAG.getIntPtrConstant(1, dl));
31902   SDVTList Tys = DAG.getVTList(ArgVT, ArgVT);
31903   return DAG.getNode(ISD::MERGE_VALUES, dl, Tys, SinVal, CosVal);
31904 }
31905 
31906 /// Widen a vector input to a vector of NVT.  The
31907 /// input vector must have the same element type as NVT.
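/// Illustrative example: widening a v2i32 input to v8i32 inserts it at index 0
/// of an otherwise undef (or zero, when FillWithZeroes is set) v8i32 vector.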
31908 static SDValue ExtendToType(SDValue InOp, MVT NVT, SelectionDAG &DAG,
31909                             bool FillWithZeroes = false) {
31910   // Check if InOp already has the right width.
31911   MVT InVT = InOp.getSimpleValueType();
31912   if (InVT == NVT)
31913     return InOp;
31914 
31915   if (InOp.isUndef())
31916     return DAG.getUNDEF(NVT);
31917 
31918   assert(InVT.getVectorElementType() == NVT.getVectorElementType() &&
31919          "input and widen element type must match");
31920 
31921   unsigned InNumElts = InVT.getVectorNumElements();
31922   unsigned WidenNumElts = NVT.getVectorNumElements();
31923   assert(WidenNumElts > InNumElts && WidenNumElts % InNumElts == 0 &&
31924          "Unexpected request for vector widening");
31925 
31926   SDLoc dl(InOp);
31927   if (InOp.getOpcode() == ISD::CONCAT_VECTORS &&
31928       InOp.getNumOperands() == 2) {
31929     SDValue N1 = InOp.getOperand(1);
31930     if ((ISD::isBuildVectorAllZeros(N1.getNode()) && FillWithZeroes) ||
31931         N1.isUndef()) {
31932       InOp = InOp.getOperand(0);
31933       InVT = InOp.getSimpleValueType();
31934       InNumElts = InVT.getVectorNumElements();
31935     }
31936   }
31937   if (ISD::isBuildVectorOfConstantSDNodes(InOp.getNode()) ||
31938       ISD::isBuildVectorOfConstantFPSDNodes(InOp.getNode())) {
31939     SmallVector<SDValue, 16> Ops;
31940     for (unsigned i = 0; i < InNumElts; ++i)
31941       Ops.push_back(InOp.getOperand(i));
31942 
31943     EVT EltVT = InOp.getOperand(0).getValueType();
31944 
31945     SDValue FillVal = FillWithZeroes ? DAG.getConstant(0, dl, EltVT) :
31946       DAG.getUNDEF(EltVT);
31947     for (unsigned i = 0; i < WidenNumElts - InNumElts; ++i)
31948       Ops.push_back(FillVal);
31949     return DAG.getBuildVector(NVT, dl, Ops);
31950   }
31951   SDValue FillVal = FillWithZeroes ? DAG.getConstant(0, dl, NVT) :
31952     DAG.getUNDEF(NVT);
31953   return DAG.getNode(ISD::INSERT_SUBVECTOR, dl, NVT, FillVal,
31954                      InOp, DAG.getIntPtrConstant(0, dl));
31955 }
31956 
31957 static SDValue LowerMSCATTER(SDValue Op, const X86Subtarget &Subtarget,
31958                              SelectionDAG &DAG) {
31959   assert(Subtarget.hasAVX512() &&
31960          "MGATHER/MSCATTER are supported on AVX-512 arch only");
31961 
31962   MaskedScatterSDNode *N = cast<MaskedScatterSDNode>(Op.getNode());
31963   SDValue Src = N->getValue();
31964   MVT VT = Src.getSimpleValueType();
31965   assert(VT.getScalarSizeInBits() >= 32 && "Unsupported scatter op");
31966   SDLoc dl(Op);
31967 
31968   SDValue Scale = N->getScale();
31969   SDValue Index = N->getIndex();
31970   SDValue Mask = N->getMask();
31971   SDValue Chain = N->getChain();
31972   SDValue BasePtr = N->getBasePtr();
31973 
31974   if (VT == MVT::v2f32 || VT == MVT::v2i32) {
31975     assert(Mask.getValueType() == MVT::v2i1 && "Unexpected mask type");
31976     // If the index is v2i64 and we have VLX we can use xmm for data and index.
31977     if (Index.getValueType() == MVT::v2i64 && Subtarget.hasVLX()) {
31978       const TargetLowering &TLI = DAG.getTargetLoweringInfo();
31979       EVT WideVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
31980       Src = DAG.getNode(ISD::CONCAT_VECTORS, dl, WideVT, Src, DAG.getUNDEF(VT));
31981       SDVTList VTs = DAG.getVTList(MVT::Other);
31982       SDValue Ops[] = {Chain, Src, Mask, BasePtr, Index, Scale};
31983       return DAG.getMemIntrinsicNode(X86ISD::MSCATTER, dl, VTs, Ops,
31984                                      N->getMemoryVT(), N->getMemOperand());
31985     }
31986     return SDValue();
31987   }
31988 
31989   MVT IndexVT = Index.getSimpleValueType();
31990 
31991   // If the index is v2i32, we're being called by type legalization and we
31992   // should just let the default handling take care of it.
31993   if (IndexVT == MVT::v2i32)
31994     return SDValue();
31995 
31996   // If we don't have VLX and neither the data nor the index is 512 bits, we
31997   // need to widen until one is.
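  // Illustrative example: a v4i32-data / v4i64-index scatter widens by
  // Factor = min(512/128, 512/256) = 2 to v8i32 data with a v8i64 index and a
  // v8i1 mask whose extra lanes are zeroed.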
31998   if (!Subtarget.hasVLX() && !VT.is512BitVector() &&
31999       !Index.getSimpleValueType().is512BitVector()) {
32000     // Determine how much we need to widen by to get a 512-bit type.
32001     unsigned Factor = std::min(512/VT.getSizeInBits(),
32002                                512/IndexVT.getSizeInBits());
32003     unsigned NumElts = VT.getVectorNumElements() * Factor;
32004 
32005     VT = MVT::getVectorVT(VT.getVectorElementType(), NumElts);
32006     IndexVT = MVT::getVectorVT(IndexVT.getVectorElementType(), NumElts);
32007     MVT MaskVT = MVT::getVectorVT(MVT::i1, NumElts);
32008 
32009     Src = ExtendToType(Src, VT, DAG);
32010     Index = ExtendToType(Index, IndexVT, DAG);
32011     Mask = ExtendToType(Mask, MaskVT, DAG, true);
32012   }
32013 
32014   SDVTList VTs = DAG.getVTList(MVT::Other);
32015   SDValue Ops[] = {Chain, Src, Mask, BasePtr, Index, Scale};
32016   return DAG.getMemIntrinsicNode(X86ISD::MSCATTER, dl, VTs, Ops,
32017                                  N->getMemoryVT(), N->getMemOperand());
32018 }
32019 
32020 static SDValue LowerMLOAD(SDValue Op, const X86Subtarget &Subtarget,
32021                           SelectionDAG &DAG) {
32022 
32023   MaskedLoadSDNode *N = cast<MaskedLoadSDNode>(Op.getNode());
32024   MVT VT = Op.getSimpleValueType();
32025   MVT ScalarVT = VT.getScalarType();
32026   SDValue Mask = N->getMask();
32027   MVT MaskVT = Mask.getSimpleValueType();
32028   SDValue PassThru = N->getPassThru();
32029   SDLoc dl(Op);
32030 
32031   // Handle AVX masked loads which don't support passthru other than 0.
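  // Illustrative sketch: AVX/AVX2 vmaskmov-style loads zero the disabled
  // lanes, so a masked load with a non-zero passthru is rebuilt below as a
  // zero-passthru masked load followed by a vector select against the
  // original passthru.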
32032   if (MaskVT.getVectorElementType() != MVT::i1) {
32033     // We also allow undef in the isel pattern.
32034     if (PassThru.isUndef() || ISD::isBuildVectorAllZeros(PassThru.getNode()))
32035       return Op;
32036 
32037     SDValue NewLoad = DAG.getMaskedLoad(
32038         VT, dl, N->getChain(), N->getBasePtr(), N->getOffset(), Mask,
32039         getZeroVector(VT, Subtarget, DAG, dl), N->getMemoryVT(),
32040         N->getMemOperand(), N->getAddressingMode(), N->getExtensionType(),
32041         N->isExpandingLoad());
32042     // Emit a blend.
32043     SDValue Select = DAG.getNode(ISD::VSELECT, dl, VT, Mask, NewLoad, PassThru);
32044     return DAG.getMergeValues({ Select, NewLoad.getValue(1) }, dl);
32045   }
32046 
32047   assert((!N->isExpandingLoad() || Subtarget.hasAVX512()) &&
32048          "Expanding masked load is supported on AVX-512 target only!");
32049 
32050   assert((!N->isExpandingLoad() || ScalarVT.getSizeInBits() >= 32) &&
32051          "Expanding masked load is supported for 32 and 64-bit types only!");
32052 
32053   assert(Subtarget.hasAVX512() && !Subtarget.hasVLX() && !VT.is512BitVector() &&
32054          "Cannot lower masked load op.");
32055 
32056   assert((ScalarVT.getSizeInBits() >= 32 ||
32057           (Subtarget.hasBWI() &&
32058               (ScalarVT == MVT::i8 || ScalarVT == MVT::i16))) &&
32059          "Unsupported masked load op.");
32060 
32061   // This operation is legal for targets with VLX, but without
32062   // VLX the vector should be widened to 512 bits.
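  // Illustrative example: a v8i32 masked load on AVX-512F without VLX becomes
  // a v16i32 load under a v16i1 mask whose upper eight bits are zero, with the
  // low v8i32 half extracted afterwards.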
32063   unsigned NumEltsInWideVec = 512 / VT.getScalarSizeInBits();
32064   MVT WideDataVT = MVT::getVectorVT(ScalarVT, NumEltsInWideVec);
32065   PassThru = ExtendToType(PassThru, WideDataVT, DAG);
32066 
32067   // Mask element has to be i1.
32068   assert(Mask.getSimpleValueType().getScalarType() == MVT::i1 &&
32069          "Unexpected mask type");
32070 
32071   MVT WideMaskVT = MVT::getVectorVT(MVT::i1, NumEltsInWideVec);
32072 
32073   Mask = ExtendToType(Mask, WideMaskVT, DAG, true);
32074   SDValue NewLoad = DAG.getMaskedLoad(
32075       WideDataVT, dl, N->getChain(), N->getBasePtr(), N->getOffset(), Mask,
32076       PassThru, N->getMemoryVT(), N->getMemOperand(), N->getAddressingMode(),
32077       N->getExtensionType(), N->isExpandingLoad());
32078 
32079   SDValue Extract =
32080       DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT, NewLoad.getValue(0),
32081                   DAG.getIntPtrConstant(0, dl));
32082   SDValue RetOps[] = {Extract, NewLoad.getValue(1)};
32083   return DAG.getMergeValues(RetOps, dl);
32084 }
32085 
32086 static SDValue LowerMSTORE(SDValue Op, const X86Subtarget &Subtarget,
32087                            SelectionDAG &DAG) {
32088   MaskedStoreSDNode *N = cast<MaskedStoreSDNode>(Op.getNode());
32089   SDValue DataToStore = N->getValue();
32090   MVT VT = DataToStore.getSimpleValueType();
32091   MVT ScalarVT = VT.getScalarType();
32092   SDValue Mask = N->getMask();
32093   SDLoc dl(Op);
32094 
32095   assert((!N->isCompressingStore() || Subtarget.hasAVX512()) &&
32096          "Compressing masked store is supported on AVX-512 targets only!");
32097 
32098   assert((!N->isCompressingStore() || ScalarVT.getSizeInBits() >= 32) &&
32099          "Compressing masked store is supported for 32 and 64-bit types only!");
32100 
32101   assert(Subtarget.hasAVX512() && !Subtarget.hasVLX() && !VT.is512BitVector() &&
32102          "Cannot lower masked store op.");
32103 
32104   assert((ScalarVT.getSizeInBits() >= 32 ||
32105           (Subtarget.hasBWI() &&
32106               (ScalarVT == MVT::i8 || ScalarVT == MVT::i16))) &&
32107           "Unsupported masked store op.");
32108 
32109   // This operation is legal for targets with VLX, but without
32110   // VLX the vector should be widened to 512 bits.
32111   unsigned NumEltsInWideVec = 512/VT.getScalarSizeInBits();
32112   MVT WideDataVT = MVT::getVectorVT(ScalarVT, NumEltsInWideVec);
32113 
32114   // Mask element has to be i1.
32115   assert(Mask.getSimpleValueType().getScalarType() == MVT::i1 &&
32116          "Unexpected mask type");
32117 
32118   MVT WideMaskVT = MVT::getVectorVT(MVT::i1, NumEltsInWideVec);
32119 
32120   DataToStore = ExtendToType(DataToStore, WideDataVT, DAG);
32121   Mask = ExtendToType(Mask, WideMaskVT, DAG, true);
32122   return DAG.getMaskedStore(N->getChain(), dl, DataToStore, N->getBasePtr(),
32123                             N->getOffset(), Mask, N->getMemoryVT(),
32124                             N->getMemOperand(), N->getAddressingMode(),
32125                             N->isTruncatingStore(), N->isCompressingStore());
32126 }
32127 
32128 static SDValue LowerMGATHER(SDValue Op, const X86Subtarget &Subtarget,
32129                             SelectionDAG &DAG) {
32130   assert(Subtarget.hasAVX2() &&
32131          "MGATHER/MSCATTER are supported on AVX-512/AVX-2 arch only");
32132 
32133   MaskedGatherSDNode *N = cast<MaskedGatherSDNode>(Op.getNode());
32134   SDLoc dl(Op);
32135   MVT VT = Op.getSimpleValueType();
32136   SDValue Index = N->getIndex();
32137   SDValue Mask = N->getMask();
32138   SDValue PassThru = N->getPassThru();
32139   MVT IndexVT = Index.getSimpleValueType();
32140 
32141   assert(VT.getScalarSizeInBits() >= 32 && "Unsupported gather op");
32142 
32143   // If the index is v2i32, we're being called by type legalization.
32144   if (IndexVT == MVT::v2i32)
32145     return SDValue();
32146 
32147   // If we don't have VLX and neither the passthru nor the index is 512 bits, we
32148   // need to widen until one is.
32149   MVT OrigVT = VT;
32150   if (Subtarget.hasAVX512() && !Subtarget.hasVLX() && !VT.is512BitVector() &&
32151       !IndexVT.is512BitVector()) {
32152     // Determine how much we need to widen by to get a 512-bit type.
32153     unsigned Factor = std::min(512/VT.getSizeInBits(),
32154                                512/IndexVT.getSizeInBits());
32155 
32156     unsigned NumElts = VT.getVectorNumElements() * Factor;
32157 
32158     VT = MVT::getVectorVT(VT.getVectorElementType(), NumElts);
32159     IndexVT = MVT::getVectorVT(IndexVT.getVectorElementType(), NumElts);
32160     MVT MaskVT = MVT::getVectorVT(MVT::i1, NumElts);
32161 
32162     PassThru = ExtendToType(PassThru, VT, DAG);
32163     Index = ExtendToType(Index, IndexVT, DAG);
32164     Mask = ExtendToType(Mask, MaskVT, DAG, true);
32165   }
32166 
32167   // Break dependency on the data register.
32168   if (PassThru.isUndef())
32169     PassThru = getZeroVector(VT, Subtarget, DAG, dl);
32170 
32171   SDValue Ops[] = { N->getChain(), PassThru, Mask, N->getBasePtr(), Index,
32172                     N->getScale() };
32173   SDValue NewGather = DAG.getMemIntrinsicNode(
32174       X86ISD::MGATHER, dl, DAG.getVTList(VT, MVT::Other), Ops, N->getMemoryVT(),
32175       N->getMemOperand());
32176   SDValue Extract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, OrigVT,
32177                                 NewGather, DAG.getIntPtrConstant(0, dl));
32178   return DAG.getMergeValues({Extract, NewGather.getValue(1)}, dl);
32179 }
32180 
32181 static SDValue LowerADDRSPACECAST(SDValue Op, SelectionDAG &DAG) {
32182   SDLoc dl(Op);
32183   SDValue Src = Op.getOperand(0);
32184   MVT DstVT = Op.getSimpleValueType();
32185 
32186   AddrSpaceCastSDNode *N = cast<AddrSpaceCastSDNode>(Op.getNode());
32187   unsigned SrcAS = N->getSrcAddressSpace();
32188 
32189   assert(SrcAS != N->getDestAddressSpace() &&
32190          "addrspacecast must be between different address spaces");
32191 
32192   if (SrcAS == X86AS::PTR32_UPTR && DstVT == MVT::i64) {
32193     Op = DAG.getNode(ISD::ZERO_EXTEND, dl, DstVT, Src);
32194   } else if (DstVT == MVT::i64) {
32195     Op = DAG.getNode(ISD::SIGN_EXTEND, dl, DstVT, Src);
32196   } else if (DstVT == MVT::i32) {
32197     Op = DAG.getNode(ISD::TRUNCATE, dl, DstVT, Src);
32198   } else {
32199     report_fatal_error("Bad address space in addrspacecast");
32200   }
32201   return Op;
32202 }
32203 
32204 SDValue X86TargetLowering::LowerGC_TRANSITION(SDValue Op,
32205                                               SelectionDAG &DAG) const {
32206   // TODO: Eventually, the lowering of these nodes should be informed by or
32207   // deferred to the GC strategy for the function in which they appear. For
32208   // now, however, they must be lowered to something. Since they are logically
32209   // no-ops in the case of a null GC strategy (or a GC strategy which does not
32210   // require special handling for these nodes), lower them as literal NOOPs for
32211   // the time being.
32212   SmallVector<SDValue, 2> Ops;
32213   Ops.push_back(Op.getOperand(0));
32214   if (Op->getGluedNode())
32215     Ops.push_back(Op->getOperand(Op->getNumOperands() - 1));
32216 
32217   SDVTList VTs = DAG.getVTList(MVT::Other, MVT::Glue);
32218   return SDValue(DAG.getMachineNode(X86::NOOP, SDLoc(Op), VTs, Ops), 0);
32219 }
32220 
32221 // Custom split CVTPS2PH with wide types.
32222 static SDValue LowerCVTPS2PH(SDValue Op, SelectionDAG &DAG) {
32223   SDLoc dl(Op);
32224   EVT VT = Op.getValueType();
32225   SDValue Lo, Hi;
32226   std::tie(Lo, Hi) = DAG.SplitVectorOperand(Op.getNode(), 0);
32227   EVT LoVT, HiVT;
32228   std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(VT);
32229   SDValue RC = Op.getOperand(1);
32230   Lo = DAG.getNode(X86ISD::CVTPS2PH, dl, LoVT, Lo, RC);
32231   Hi = DAG.getNode(X86ISD::CVTPS2PH, dl, HiVT, Hi, RC);
32232   return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, Lo, Hi);
32233 }
32234 
32235 static SDValue LowerPREFETCH(SDValue Op, const X86Subtarget &Subtarget,
32236                              SelectionDAG &DAG) {
32237   unsigned IsData = Op.getConstantOperandVal(4);
32238 
32239   // We don't support non-data prefetch without PREFETCHI.
32240   // Just preserve the chain.
32241   if (!IsData && !Subtarget.hasPREFETCHI())
32242     return Op.getOperand(0);
32243 
32244   return Op;
32245 }
32246 
32247 static StringRef getInstrStrFromOpNo(const SmallVectorImpl<StringRef> &AsmStrs,
32248                                      unsigned OpNo) {
32249   const APInt Operand(32, OpNo);
32250   std::string OpNoStr = llvm::toString(Operand, 10, false);
32251   std::string Str(" $");
32252 
32253   std::string OpNoStr1(Str + OpNoStr);             // e.g. " $1" (OpNo=1)
32254   std::string OpNoStr2(Str + "{" + OpNoStr + ":"); // With modifier, e.g. ${1:P}
32255 
32256   auto I = StringRef::npos;
32257   for (auto &AsmStr : AsmStrs) {
32258     // Match the OpNo string. We must match exactly to avoid matching a
32259     // sub-string, e.g. "$12" contains "$1".
32260     if (AsmStr.ends_with(OpNoStr1))
32261       I = AsmStr.size() - OpNoStr1.size();
32262 
32263     // Get the index of operand in AsmStr.
32264     if (I == StringRef::npos)
32265       I = AsmStr.find(OpNoStr1 + ",");
32266     if (I == StringRef::npos)
32267       I = AsmStr.find(OpNoStr2);
32268 
32269     if (I == StringRef::npos)
32270       continue;
32271 
32272     assert(I > 0 && "Unexpected inline asm string!");
32273     // Remove the operand string and label (if it exists).
32274     // For example:
32275     // ".L__MSASMLABEL_.${:uid}__l:call dword ptr ${0:P}"
32276     // ==>
32277     // ".L__MSASMLABEL_.${:uid}__l:call dword ptr "
32278     // ==>
32279     // "call dword ptr "
32280     auto TmpStr = AsmStr.substr(0, I);
32281     I = TmpStr.rfind(':');
32282     if (I != StringRef::npos)
32283       TmpStr = TmpStr.substr(I + 1);
32284     return TmpStr.take_while(llvm::isAlpha);
32285   }
32286 
32287   return StringRef();
32288 }
32289 
32290 bool X86TargetLowering::isInlineAsmTargetBranch(
32291     const SmallVectorImpl<StringRef> &AsmStrs, unsigned OpNo) const {
32292   // In a __asm block, an `__asm inst foo` where inst is CALL or JMP should
32293   // have its operand changed from indirect TargetLowering::C_Memory to
32294   // direct TargetLowering::C_Address.
32295   // We don't need to special case LOOP* and Jcc, which cannot target a memory
32296   // location.
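  // Illustrative (hypothetical) example: for MS-style `__asm { call fn_ptr }`
  // the emitted asm string contains "call ${0:P}", so getInstrStrFromOpNo
  // returns "call" and the operand is treated as a direct branch target rather
  // than a memory reference.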
32297   StringRef Inst = getInstrStrFromOpNo(AsmStrs, OpNo);
32298   return Inst.equals_insensitive("call") || Inst.equals_insensitive("jmp");
32299 }
32300 
32301 static SDValue getFlagsOfCmpZeroFori1(SelectionDAG &DAG, const SDLoc &DL,
32302                                       SDValue Mask) {
32303   EVT Ty = MVT::i8;
32304   auto V = DAG.getBitcast(MVT::i1, Mask);
32305   auto VE = DAG.getZExtOrTrunc(V, DL, Ty);
32306   auto Zero = DAG.getConstant(0, DL, Ty);
32307   SDVTList X86SubVTs = DAG.getVTList(Ty, MVT::i32);
32308   auto CmpZero = DAG.getNode(X86ISD::SUB, DL, X86SubVTs, Zero, VE);
32309   return SDValue(CmpZero.getNode(), 1);
32310 }
32311 
32312 SDValue X86TargetLowering::visitMaskedLoad(
32313     SelectionDAG &DAG, const SDLoc &DL, SDValue Chain, MachineMemOperand *MMO,
32314     SDValue &NewLoad, SDValue Ptr, SDValue PassThru, SDValue Mask) const {
32315   // @llvm.masked.load.v1*(ptr, alignment, mask, passthru)
32316   // ->
32317   // _, flags = SUB 0, mask
32318   // res, chain = CLOAD inchain, ptr, (bit_cast_to_scalar passthru), cond, flags
32319   // bit_cast_to_vector<res>
32320   EVT VTy = PassThru.getValueType();
32321   EVT Ty = VTy.getVectorElementType();
32322   SDVTList Tys = DAG.getVTList(Ty, MVT::Other);
32323   auto ScalarPassThru = PassThru.isUndef() ? DAG.getConstant(0, DL, Ty)
32324                                            : DAG.getBitcast(Ty, PassThru);
32325   auto Flags = getFlagsOfCmpZeroFori1(DAG, DL, Mask);
32326   auto COND_NE = DAG.getTargetConstant(X86::COND_NE, DL, MVT::i8);
32327   SDValue Ops[] = {Chain, Ptr, ScalarPassThru, COND_NE, Flags};
32328   NewLoad = DAG.getMemIntrinsicNode(X86ISD::CLOAD, DL, Tys, Ops, Ty, MMO);
32329   return DAG.getBitcast(VTy, NewLoad);
32330 }
32331 
32332 SDValue X86TargetLowering::visitMaskedStore(SelectionDAG &DAG, const SDLoc &DL,
32333                                             SDValue Chain,
32334                                             MachineMemOperand *MMO, SDValue Ptr,
32335                                             SDValue Val, SDValue Mask) const {
32336   // llvm.masked.store.v1*(Src0, Ptr, alignment, Mask)
32337   // ->
32338   // _, flags = SUB 0, mask
32339   // chain = CSTORE inchain, (bit_cast_to_scalar val), ptr, cond, flags
32340   EVT Ty = Val.getValueType().getVectorElementType();
32341   SDVTList Tys = DAG.getVTList(MVT::Other);
32342   auto ScalarVal = DAG.getBitcast(Ty, Val);
32343   auto Flags = getFlagsOfCmpZeroFori1(DAG, DL, Mask);
32344   auto COND_NE = DAG.getTargetConstant(X86::COND_NE, DL, MVT::i8);
32345   SDValue Ops[] = {Chain, ScalarVal, Ptr, COND_NE, Flags};
32346   return DAG.getMemIntrinsicNode(X86ISD::CSTORE, DL, Tys, Ops, Ty, MMO);
32347 }
32348 
32349 /// Provide custom lowering hooks for some operations.
32350 SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
32351   switch (Op.getOpcode()) {
32352   // clang-format off
32353   default: llvm_unreachable("Should not custom lower this!");
32354   case ISD::ATOMIC_FENCE:       return LowerATOMIC_FENCE(Op, Subtarget, DAG);
32355   case ISD::ATOMIC_CMP_SWAP_WITH_SUCCESS:
32356     return LowerCMP_SWAP(Op, Subtarget, DAG);
32357   case ISD::CTPOP:              return LowerCTPOP(Op, Subtarget, DAG);
32358   case ISD::ATOMIC_LOAD_ADD:
32359   case ISD::ATOMIC_LOAD_SUB:
32360   case ISD::ATOMIC_LOAD_OR:
32361   case ISD::ATOMIC_LOAD_XOR:
32362   case ISD::ATOMIC_LOAD_AND:    return lowerAtomicArith(Op, DAG, Subtarget);
32363   case ISD::ATOMIC_STORE:       return LowerATOMIC_STORE(Op, DAG, Subtarget);
32364   case ISD::BITREVERSE:         return LowerBITREVERSE(Op, Subtarget, DAG);
32365   case ISD::PARITY:             return LowerPARITY(Op, Subtarget, DAG);
32366   case ISD::BUILD_VECTOR:       return LowerBUILD_VECTOR(Op, DAG);
32367   case ISD::CONCAT_VECTORS:     return LowerCONCAT_VECTORS(Op, Subtarget, DAG);
32368   case ISD::VECTOR_SHUFFLE:     return lowerVECTOR_SHUFFLE(Op, Subtarget, DAG);
32369   case ISD::VSELECT:            return LowerVSELECT(Op, DAG);
32370   case ISD::EXTRACT_VECTOR_ELT: return LowerEXTRACT_VECTOR_ELT(Op, DAG);
32371   case ISD::INSERT_VECTOR_ELT:  return LowerINSERT_VECTOR_ELT(Op, DAG);
32372   case ISD::INSERT_SUBVECTOR:   return LowerINSERT_SUBVECTOR(Op, Subtarget,DAG);
32373   case ISD::EXTRACT_SUBVECTOR:  return LowerEXTRACT_SUBVECTOR(Op,Subtarget,DAG);
32374   case ISD::SCALAR_TO_VECTOR:   return LowerSCALAR_TO_VECTOR(Op, Subtarget,DAG);
32375   case ISD::ConstantPool:       return LowerConstantPool(Op, DAG);
32376   case ISD::GlobalAddress:      return LowerGlobalAddress(Op, DAG);
32377   case ISD::GlobalTLSAddress:   return LowerGlobalTLSAddress(Op, DAG);
32378   case ISD::ExternalSymbol:     return LowerExternalSymbol(Op, DAG);
32379   case ISD::BlockAddress:       return LowerBlockAddress(Op, DAG);
32380   case ISD::SHL_PARTS:
32381   case ISD::SRA_PARTS:
32382   case ISD::SRL_PARTS:          return LowerShiftParts(Op, DAG);
32383   case ISD::FSHL:
32384   case ISD::FSHR:               return LowerFunnelShift(Op, Subtarget, DAG);
32385   case ISD::STRICT_SINT_TO_FP:
32386   case ISD::SINT_TO_FP:         return LowerSINT_TO_FP(Op, DAG);
32387   case ISD::STRICT_UINT_TO_FP:
32388   case ISD::UINT_TO_FP:         return LowerUINT_TO_FP(Op, DAG);
32389   case ISD::TRUNCATE:           return LowerTRUNCATE(Op, DAG);
32390   case ISD::ZERO_EXTEND:        return LowerZERO_EXTEND(Op, Subtarget, DAG);
32391   case ISD::SIGN_EXTEND:        return LowerSIGN_EXTEND(Op, Subtarget, DAG);
32392   case ISD::ANY_EXTEND:         return LowerANY_EXTEND(Op, Subtarget, DAG);
32393   case ISD::ZERO_EXTEND_VECTOR_INREG:
32394   case ISD::SIGN_EXTEND_VECTOR_INREG:
32395     return LowerEXTEND_VECTOR_INREG(Op, Subtarget, DAG);
32396   case ISD::FP_TO_SINT:
32397   case ISD::STRICT_FP_TO_SINT:
32398   case ISD::FP_TO_UINT:
32399   case ISD::STRICT_FP_TO_UINT:  return LowerFP_TO_INT(Op, DAG);
32400   case ISD::FP_TO_SINT_SAT:
32401   case ISD::FP_TO_UINT_SAT:     return LowerFP_TO_INT_SAT(Op, DAG);
32402   case ISD::FP_EXTEND:
32403   case ISD::STRICT_FP_EXTEND:   return LowerFP_EXTEND(Op, DAG);
32404   case ISD::FP_ROUND:
32405   case ISD::STRICT_FP_ROUND:    return LowerFP_ROUND(Op, DAG);
32406   case ISD::FP16_TO_FP:
32407   case ISD::STRICT_FP16_TO_FP:  return LowerFP16_TO_FP(Op, DAG);
32408   case ISD::FP_TO_FP16:
32409   case ISD::STRICT_FP_TO_FP16:  return LowerFP_TO_FP16(Op, DAG);
32410   case ISD::FP_TO_BF16:         return LowerFP_TO_BF16(Op, DAG);
32411   case ISD::LOAD:               return LowerLoad(Op, Subtarget, DAG);
32412   case ISD::STORE:              return LowerStore(Op, Subtarget, DAG);
32413   case ISD::FADD:
32414   case ISD::FSUB:               return lowerFaddFsub(Op, DAG);
32415   case ISD::FROUND:             return LowerFROUND(Op, DAG);
32416   case ISD::FABS:
32417   case ISD::FNEG:               return LowerFABSorFNEG(Op, DAG);
32418   case ISD::FCOPYSIGN:          return LowerFCOPYSIGN(Op, DAG);
32419   case ISD::FGETSIGN:           return LowerFGETSIGN(Op, DAG);
32420   case ISD::LRINT:
32421   case ISD::LLRINT:             return LowerLRINT_LLRINT(Op, DAG);
32422   case ISD::SETCC:
32423   case ISD::STRICT_FSETCC:
32424   case ISD::STRICT_FSETCCS:     return LowerSETCC(Op, DAG);
32425   case ISD::SETCCCARRY:         return LowerSETCCCARRY(Op, DAG);
32426   case ISD::SELECT:             return LowerSELECT(Op, DAG);
32427   case ISD::BRCOND:             return LowerBRCOND(Op, DAG);
32428   case ISD::JumpTable:          return LowerJumpTable(Op, DAG);
32429   case ISD::VASTART:            return LowerVASTART(Op, DAG);
32430   case ISD::VAARG:              return LowerVAARG(Op, DAG);
32431   case ISD::VACOPY:             return LowerVACOPY(Op, Subtarget, DAG);
32432   case ISD::INTRINSIC_WO_CHAIN: return LowerINTRINSIC_WO_CHAIN(Op, DAG);
32433   case ISD::INTRINSIC_VOID:
32434   case ISD::INTRINSIC_W_CHAIN:  return LowerINTRINSIC_W_CHAIN(Op, Subtarget, DAG);
32435   case ISD::RETURNADDR:         return LowerRETURNADDR(Op, DAG);
32436   case ISD::ADDROFRETURNADDR:   return LowerADDROFRETURNADDR(Op, DAG);
32437   case ISD::FRAMEADDR:          return LowerFRAMEADDR(Op, DAG);
32438   case ISD::FRAME_TO_ARGS_OFFSET:
32439                                 return LowerFRAME_TO_ARGS_OFFSET(Op, DAG);
32440   case ISD::DYNAMIC_STACKALLOC: return LowerDYNAMIC_STACKALLOC(Op, DAG);
32441   case ISD::EH_RETURN:          return LowerEH_RETURN(Op, DAG);
32442   case ISD::EH_SJLJ_SETJMP:     return lowerEH_SJLJ_SETJMP(Op, DAG);
32443   case ISD::EH_SJLJ_LONGJMP:    return lowerEH_SJLJ_LONGJMP(Op, DAG);
32444   case ISD::EH_SJLJ_SETUP_DISPATCH:
32445     return lowerEH_SJLJ_SETUP_DISPATCH(Op, DAG);
32446   case ISD::INIT_TRAMPOLINE:    return LowerINIT_TRAMPOLINE(Op, DAG);
32447   case ISD::ADJUST_TRAMPOLINE:  return LowerADJUST_TRAMPOLINE(Op, DAG);
32448   case ISD::GET_ROUNDING:       return LowerGET_ROUNDING(Op, DAG);
32449   case ISD::SET_ROUNDING:       return LowerSET_ROUNDING(Op, DAG);
32450   case ISD::GET_FPENV_MEM:      return LowerGET_FPENV_MEM(Op, DAG);
32451   case ISD::SET_FPENV_MEM:      return LowerSET_FPENV_MEM(Op, DAG);
32452   case ISD::RESET_FPENV:        return LowerRESET_FPENV(Op, DAG);
32453   case ISD::CTLZ:
32454   case ISD::CTLZ_ZERO_UNDEF:    return LowerCTLZ(Op, Subtarget, DAG);
32455   case ISD::CTTZ:
32456   case ISD::CTTZ_ZERO_UNDEF:    return LowerCTTZ(Op, Subtarget, DAG);
32457   case ISD::MUL:                return LowerMUL(Op, Subtarget, DAG);
32458   case ISD::MULHS:
32459   case ISD::MULHU:              return LowerMULH(Op, Subtarget, DAG);
32460   case ISD::ROTL:
32461   case ISD::ROTR:               return LowerRotate(Op, Subtarget, DAG);
32462   case ISD::SRA:
32463   case ISD::SRL:
32464   case ISD::SHL:                return LowerShift(Op, Subtarget, DAG);
32465   case ISD::SADDO:
32466   case ISD::UADDO:
32467   case ISD::SSUBO:
32468   case ISD::USUBO:              return LowerXALUO(Op, DAG);
32469   case ISD::SMULO:
32470   case ISD::UMULO:              return LowerMULO(Op, Subtarget, DAG);
32471   case ISD::READCYCLECOUNTER:   return LowerREADCYCLECOUNTER(Op, Subtarget,DAG);
32472   case ISD::BITCAST:            return LowerBITCAST(Op, Subtarget, DAG);
32473   case ISD::SADDO_CARRY:
32474   case ISD::SSUBO_CARRY:
32475   case ISD::UADDO_CARRY:
32476   case ISD::USUBO_CARRY:        return LowerADDSUBO_CARRY(Op, DAG);
32477   case ISD::ADD:
32478   case ISD::SUB:                return lowerAddSub(Op, DAG, Subtarget);
32479   case ISD::UADDSAT:
32480   case ISD::SADDSAT:
32481   case ISD::USUBSAT:
32482   case ISD::SSUBSAT:            return LowerADDSAT_SUBSAT(Op, DAG, Subtarget);
32483   case ISD::SMAX:
32484   case ISD::SMIN:
32485   case ISD::UMAX:
32486   case ISD::UMIN:               return LowerMINMAX(Op, Subtarget, DAG);
32487   case ISD::FMINIMUM:
32488   case ISD::FMAXIMUM:
32489     return LowerFMINIMUM_FMAXIMUM(Op, Subtarget, DAG);
32490   case ISD::ABS:                return LowerABS(Op, Subtarget, DAG);
32491   case ISD::ABDS:
32492   case ISD::ABDU:               return LowerABD(Op, Subtarget, DAG);
32493   case ISD::AVGCEILU:           return LowerAVG(Op, Subtarget, DAG);
32494   case ISD::FSINCOS:            return LowerFSINCOS(Op, Subtarget, DAG);
32495   case ISD::MLOAD:              return LowerMLOAD(Op, Subtarget, DAG);
32496   case ISD::MSTORE:             return LowerMSTORE(Op, Subtarget, DAG);
32497   case ISD::MGATHER:            return LowerMGATHER(Op, Subtarget, DAG);
32498   case ISD::MSCATTER:           return LowerMSCATTER(Op, Subtarget, DAG);
32499   case ISD::GC_TRANSITION_START:
32500   case ISD::GC_TRANSITION_END:  return LowerGC_TRANSITION(Op, DAG);
32501   case ISD::ADDRSPACECAST:      return LowerADDRSPACECAST(Op, DAG);
32502   case X86ISD::CVTPS2PH:        return LowerCVTPS2PH(Op, DAG);
32503   case ISD::PREFETCH:           return LowerPREFETCH(Op, Subtarget, DAG);
32504   // clang-format on
32505   }
32506 }
32507 
32508 /// Replace a node with an illegal result type with a new node built out of
32509 /// custom code.
32510 void X86TargetLowering::ReplaceNodeResults(SDNode *N,
32511                                            SmallVectorImpl<SDValue>&Results,
32512                                            SelectionDAG &DAG) const {
32513   SDLoc dl(N);
32514   switch (N->getOpcode()) {
32515   default:
32516 #ifndef NDEBUG
32517     dbgs() << "ReplaceNodeResults: ";
32518     N->dump(&DAG);
32519 #endif
32520     llvm_unreachable("Do not know how to custom type legalize this operation!");
32521   case X86ISD::CVTPH2PS: {
32522     EVT VT = N->getValueType(0);
32523     SDValue Lo, Hi;
32524     std::tie(Lo, Hi) = DAG.SplitVectorOperand(N, 0);
32525     EVT LoVT, HiVT;
32526     std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(VT);
32527     Lo = DAG.getNode(X86ISD::CVTPH2PS, dl, LoVT, Lo);
32528     Hi = DAG.getNode(X86ISD::CVTPH2PS, dl, HiVT, Hi);
32529     SDValue Res = DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, Lo, Hi);
32530     Results.push_back(Res);
32531     return;
32532   }
32533   case X86ISD::STRICT_CVTPH2PS: {
32534     EVT VT = N->getValueType(0);
32535     SDValue Lo, Hi;
32536     std::tie(Lo, Hi) = DAG.SplitVectorOperand(N, 1);
32537     EVT LoVT, HiVT;
32538     std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(VT);
32539     Lo = DAG.getNode(X86ISD::STRICT_CVTPH2PS, dl, {LoVT, MVT::Other},
32540                      {N->getOperand(0), Lo});
32541     Hi = DAG.getNode(X86ISD::STRICT_CVTPH2PS, dl, {HiVT, MVT::Other},
32542                      {N->getOperand(0), Hi});
32543     SDValue Chain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other,
32544                                 Lo.getValue(1), Hi.getValue(1));
32545     SDValue Res = DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, Lo, Hi);
32546     Results.push_back(Res);
32547     Results.push_back(Chain);
32548     return;
32549   }
32550   case X86ISD::CVTPS2PH:
32551     Results.push_back(LowerCVTPS2PH(SDValue(N, 0), DAG));
32552     return;
32553   case ISD::CTPOP: {
32554     assert(N->getValueType(0) == MVT::i64 && "Unexpected VT!");
32555     // If we have at most 32 active bits, then perform as i32 CTPOP.
32556     // TODO: Perform this in generic legalizer?
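    // Illustrative example: an i64 known to have 16 leading and 16 trailing
    // zero bits has at most 32 active bits; shift them down by TZ, popcount as
    // i32, and zero-extend the count back to i64.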
32557     KnownBits Known = DAG.computeKnownBits(N->getOperand(0));
32558     unsigned LZ = Known.countMinLeadingZeros();
32559     unsigned TZ = Known.countMinTrailingZeros();
32560     if ((LZ + TZ) >= 32) {
32561       SDValue Op = DAG.getNode(ISD::SRL, dl, MVT::i64, N->getOperand(0),
32562                                DAG.getShiftAmountConstant(TZ, MVT::i64, dl));
32563       Op = DAG.getNode(ISD::TRUNCATE, dl, MVT::i32, Op);
32564       Op = DAG.getNode(ISD::CTPOP, dl, MVT::i32, Op);
32565       Op = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i64, Op);
32566       Results.push_back(Op);
32567       return;
32568     }
32569     // Use a v2i64 if possible.
32570     bool NoImplicitFloatOps =
32571         DAG.getMachineFunction().getFunction().hasFnAttribute(
32572             Attribute::NoImplicitFloat);
32573     if (isTypeLegal(MVT::v2i64) && !NoImplicitFloatOps) {
32574       SDValue Wide =
32575           DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, MVT::v2i64, N->getOperand(0));
32576       Wide = DAG.getNode(ISD::CTPOP, dl, MVT::v2i64, Wide);
32577       // The bit count should fit in 32 bits, so extract it as i32 and then
32578       // zero-extend to i64. Otherwise we end up extracting bits 63:32 separately.
32579       Wide = DAG.getNode(ISD::BITCAST, dl, MVT::v4i32, Wide);
32580       Wide = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i32, Wide,
32581                          DAG.getIntPtrConstant(0, dl));
32582       Wide = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i64, Wide);
32583       Results.push_back(Wide);
32584     }
32585     return;
32586   }
32587   case ISD::MUL: {
32588     EVT VT = N->getValueType(0);
32589     assert(getTypeAction(*DAG.getContext(), VT) == TypeWidenVector &&
32590            VT.getVectorElementType() == MVT::i8 && "Unexpected VT!");
32591     // Pre-promote these to vXi16 to avoid op legalization thinking all 16
32592     // elements are needed.
32593     MVT MulVT = MVT::getVectorVT(MVT::i16, VT.getVectorNumElements());
32594     SDValue Op0 = DAG.getNode(ISD::ANY_EXTEND, dl, MulVT, N->getOperand(0));
32595     SDValue Op1 = DAG.getNode(ISD::ANY_EXTEND, dl, MulVT, N->getOperand(1));
32596     SDValue Res = DAG.getNode(ISD::MUL, dl, MulVT, Op0, Op1);
32597     Res = DAG.getNode(ISD::TRUNCATE, dl, VT, Res);
32598     unsigned NumConcats = 16 / VT.getVectorNumElements();
32599     SmallVector<SDValue, 8> ConcatOps(NumConcats, DAG.getUNDEF(VT));
32600     ConcatOps[0] = Res;
32601     Res = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v16i8, ConcatOps);
32602     Results.push_back(Res);
32603     return;
32604   }
32605   case ISD::SMULO:
32606   case ISD::UMULO: {
32607     EVT VT = N->getValueType(0);
32608     assert(getTypeAction(*DAG.getContext(), VT) == TypeWidenVector &&
32609            VT == MVT::v2i32 && "Unexpected VT!");
32610     bool IsSigned = N->getOpcode() == ISD::SMULO;
32611     unsigned ExtOpc = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
32612     SDValue Op0 = DAG.getNode(ExtOpc, dl, MVT::v2i64, N->getOperand(0));
32613     SDValue Op1 = DAG.getNode(ExtOpc, dl, MVT::v2i64, N->getOperand(1));
32614     SDValue Res = DAG.getNode(ISD::MUL, dl, MVT::v2i64, Op0, Op1);
32615     // Extract the high 32 bits from each result using PSHUFD.
32616     // TODO: Could use SRL+TRUNCATE but that doesn't become a PSHUFD.
32617     SDValue Hi = DAG.getBitcast(MVT::v4i32, Res);
32618     Hi = DAG.getVectorShuffle(MVT::v4i32, dl, Hi, Hi, {1, 3, -1, -1});
32619     Hi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT, Hi,
32620                      DAG.getIntPtrConstant(0, dl));
32621 
32622     // Truncate the low bits of the result. This will become PSHUFD.
32623     Res = DAG.getNode(ISD::TRUNCATE, dl, VT, Res);
32624 
32625     SDValue HiCmp;
32626     if (IsSigned) {
32627       // SMULO overflows if the high bits don't match the sign of the low.
32628       HiCmp = DAG.getNode(ISD::SRA, dl, VT, Res, DAG.getConstant(31, dl, VT));
32629     } else {
32630       // UMULO overflows if the high bits are non-zero.
32631       HiCmp = DAG.getConstant(0, dl, VT);
32632     }
32633     SDValue Ovf = DAG.getSetCC(dl, N->getValueType(1), Hi, HiCmp, ISD::SETNE);
32634 
32635     // Widen the result by padding with undef.
32636     Res = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4i32, Res,
32637                       DAG.getUNDEF(VT));
32638     Results.push_back(Res);
32639     Results.push_back(Ovf);
32640     return;
32641   }
32642   case X86ISD::VPMADDWD: {
32643     // Legalize types for X86ISD::VPMADDWD by widening.
32644     assert(Subtarget.hasSSE2() && "Requires at least SSE2!");
32645 
32646     EVT VT = N->getValueType(0);
32647     EVT InVT = N->getOperand(0).getValueType();
32648     assert(VT.getSizeInBits() < 128 && 128 % VT.getSizeInBits() == 0 &&
32649            "Expected a VT that divides into 128 bits.");
32650     assert(getTypeAction(*DAG.getContext(), VT) == TypeWidenVector &&
32651            "Unexpected type action!");
32652     unsigned NumConcat = 128 / InVT.getSizeInBits();
32653 
32654     EVT InWideVT = EVT::getVectorVT(*DAG.getContext(),
32655                                     InVT.getVectorElementType(),
32656                                     NumConcat * InVT.getVectorNumElements());
32657     EVT WideVT = EVT::getVectorVT(*DAG.getContext(),
32658                                   VT.getVectorElementType(),
32659                                   NumConcat * VT.getVectorNumElements());
32660 
32661     SmallVector<SDValue, 16> Ops(NumConcat, DAG.getUNDEF(InVT));
32662     Ops[0] = N->getOperand(0);
32663     SDValue InVec0 = DAG.getNode(ISD::CONCAT_VECTORS, dl, InWideVT, Ops);
32664     Ops[0] = N->getOperand(1);
32665     SDValue InVec1 = DAG.getNode(ISD::CONCAT_VECTORS, dl, InWideVT, Ops);
32666 
32667     SDValue Res = DAG.getNode(N->getOpcode(), dl, WideVT, InVec0, InVec1);
32668     Results.push_back(Res);
32669     return;
32670   }
32671   // We might have generated v2f32 FMIN/FMAX operations. Widen them to v4f32.
32672   case X86ISD::FMINC:
32673   case X86ISD::FMIN:
32674   case X86ISD::FMAXC:
32675   case X86ISD::FMAX: {
32676     EVT VT = N->getValueType(0);
32677     assert(VT == MVT::v2f32 && "Unexpected type (!= v2f32) on FMIN/FMAX.");
32678     SDValue UNDEF = DAG.getUNDEF(VT);
32679     SDValue LHS = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4f32,
32680                               N->getOperand(0), UNDEF);
32681     SDValue RHS = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4f32,
32682                               N->getOperand(1), UNDEF);
32683     Results.push_back(DAG.getNode(N->getOpcode(), dl, MVT::v4f32, LHS, RHS));
32684     return;
32685   }
32686   case ISD::SDIV:
32687   case ISD::UDIV:
32688   case ISD::SREM:
32689   case ISD::UREM: {
32690     EVT VT = N->getValueType(0);
32691     if (VT.isVector()) {
32692       assert(getTypeAction(*DAG.getContext(), VT) == TypeWidenVector &&
32693              "Unexpected type action!");
32694       // If the RHS is a constant splat vector, we can widen this and let
32695       // division/remainder by constant optimize it.
32696       // TODO: Can we do something for non-splat?
32697       APInt SplatVal;
32698       if (ISD::isConstantSplatVector(N->getOperand(1).getNode(), SplatVal)) {
32699         unsigned NumConcats = 128 / VT.getSizeInBits();
32700         SmallVector<SDValue, 8> Ops0(NumConcats, DAG.getUNDEF(VT));
32701         Ops0[0] = N->getOperand(0);
32702         EVT ResVT = getTypeToTransformTo(*DAG.getContext(), VT);
32703         SDValue N0 = DAG.getNode(ISD::CONCAT_VECTORS, dl, ResVT, Ops0);
32704         SDValue N1 = DAG.getConstant(SplatVal, dl, ResVT);
32705         SDValue Res = DAG.getNode(N->getOpcode(), dl, ResVT, N0, N1);
32706         Results.push_back(Res);
32707       }
32708       return;
32709     }
32710 
32711     SDValue V = LowerWin64_i128OP(SDValue(N,0), DAG);
32712     Results.push_back(V);
32713     return;
32714   }
32715   case ISD::TRUNCATE: {
32716     MVT VT = N->getSimpleValueType(0);
32717     if (getTypeAction(*DAG.getContext(), VT) != TypeWidenVector)
32718       return;
32719 
32720     // The generic legalizer will try to widen the input type to the same
32721     // number of elements as the widened result type. But this isn't always
32722     // the best thing, so do some custom legalization to avoid some cases.
32723     MVT WidenVT = getTypeToTransformTo(*DAG.getContext(), VT).getSimpleVT();
32724     SDValue In = N->getOperand(0);
32725     EVT InVT = In.getValueType();
32726     EVT InEltVT = InVT.getVectorElementType();
32727     EVT EltVT = VT.getVectorElementType();
32728     unsigned MinElts = VT.getVectorNumElements();
32729     unsigned WidenNumElts = WidenVT.getVectorNumElements();
32730     unsigned InBits = InVT.getSizeInBits();
32731 
32732     // See if there are sufficient leading bits to perform a PACKUS/PACKSS.
32733     unsigned PackOpcode;
32734     if (SDValue Src =
32735             matchTruncateWithPACK(PackOpcode, VT, In, dl, DAG, Subtarget)) {
32736       if (SDValue Res = truncateVectorWithPACK(PackOpcode, VT, Src,
32737                                                dl, DAG, Subtarget)) {
32738         Res = widenSubVector(WidenVT, Res, false, Subtarget, DAG, dl);
32739         Results.push_back(Res);
32740         return;
32741       }
32742     }
32743 
32744     if ((128 % InBits) == 0 && WidenVT.is128BitVector()) {
32745       // 128-bit and smaller inputs should avoid the truncate altogether and
32746       // use a shuffle.
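      // Illustrative example: truncating v4i32 -> v4i8 (widened to v16i8)
      // bitcasts the 128-bit input to v16i8 and shuffles out bytes
      // {0, 4, 8, 12}, i.e. the low byte of each i32 lane on little-endian x86.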
32747       if ((InEltVT.getSizeInBits() % EltVT.getSizeInBits()) == 0) {
32748         int Scale = InEltVT.getSizeInBits() / EltVT.getSizeInBits();
32749         SmallVector<int, 16> TruncMask(WidenNumElts, -1);
32750         for (unsigned I = 0; I < MinElts; ++I)
32751           TruncMask[I] = Scale * I;
32752         SDValue WidenIn = widenSubVector(In, false, Subtarget, DAG, dl, 128);
32753         assert(isTypeLegal(WidenVT) && isTypeLegal(WidenIn.getValueType()) &&
32754                "Illegal vector type in truncation");
32755         WidenIn = DAG.getBitcast(WidenVT, WidenIn);
32756         Results.push_back(
32757             DAG.getVectorShuffle(WidenVT, dl, WidenIn, WidenIn, TruncMask));
32758         return;
32759       }
32760     }
32761 
32762     // With AVX512 there are some cases that can use a target specific
32763     // truncate node to go from 256/512 to less than 128 with zeros in the
32764     // upper elements of the 128 bit result.
32765     if (Subtarget.hasAVX512() && isTypeLegal(InVT)) {
32766       // We can use VTRUNC directly for 256 bits with VLX, or for any 512-bit input.
32767       if ((InBits == 256 && Subtarget.hasVLX()) || InBits == 512) {
32768         Results.push_back(DAG.getNode(X86ISD::VTRUNC, dl, WidenVT, In));
32769         return;
32770       }
32771       // There's one case we can widen to 512 bits and use VTRUNC.
32772       if (InVT == MVT::v4i64 && VT == MVT::v4i8 && isTypeLegal(MVT::v8i64)) {
32773         In = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v8i64, In,
32774                          DAG.getUNDEF(MVT::v4i64));
32775         Results.push_back(DAG.getNode(X86ISD::VTRUNC, dl, WidenVT, In));
32776         return;
32777       }
32778     }
32779     if (Subtarget.hasVLX() && InVT == MVT::v8i64 && VT == MVT::v8i8 &&
32780         getTypeAction(*DAG.getContext(), InVT) == TypeSplitVector &&
32781         isTypeLegal(MVT::v4i64)) {
32782       // Input needs to be split and output needs to widened. Let's use two
32783       // Input needs to be split and output needs to be widened. Let's use two
32784       SDValue Lo, Hi;
32785       std::tie(Lo, Hi) = DAG.SplitVector(In, dl);
32786 
32787       Lo = DAG.getNode(X86ISD::VTRUNC, dl, MVT::v16i8, Lo);
32788       Hi = DAG.getNode(X86ISD::VTRUNC, dl, MVT::v16i8, Hi);
32789       SDValue Res = DAG.getVectorShuffle(MVT::v16i8, dl, Lo, Hi,
32790                                          { 0,  1,  2,  3, 16, 17, 18, 19,
32791                                           -1, -1, -1, -1, -1, -1, -1, -1 });
32792       Results.push_back(Res);
32793       return;
32794     }
32795 
32796     // Attempt to widen the truncation input vector to let LowerTRUNCATE handle
32797     // this via type legalization.
32798     if ((InEltVT == MVT::i16 || InEltVT == MVT::i32 || InEltVT == MVT::i64) &&
32799         (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32) &&
32800         (!Subtarget.hasSSSE3() ||
32801          (!isTypeLegal(InVT) &&
32802           !(MinElts <= 4 && InEltVT == MVT::i64 && EltVT == MVT::i8)))) {
32803       SDValue WidenIn = widenSubVector(In, false, Subtarget, DAG, dl,
32804                                        InEltVT.getSizeInBits() * WidenNumElts);
32805       Results.push_back(DAG.getNode(ISD::TRUNCATE, dl, WidenVT, WidenIn));
32806       return;
32807     }
32808 
32809     return;
32810   }
32811   case ISD::ANY_EXTEND:
32812     // Right now, only MVT::v8i8 has Custom action for an illegal type.
32813     // It's intended to custom handle the input type.
32814     assert(N->getValueType(0) == MVT::v8i8 &&
32815            "Do not know how to legalize this Node");
32816     return;
32817   case ISD::SIGN_EXTEND:
32818   case ISD::ZERO_EXTEND: {
32819     EVT VT = N->getValueType(0);
32820     SDValue In = N->getOperand(0);
32821     EVT InVT = In.getValueType();
32822     if (!Subtarget.hasSSE41() && VT == MVT::v4i64 &&
32823         (InVT == MVT::v4i16 || InVT == MVT::v4i8)){
32824       assert(getTypeAction(*DAG.getContext(), InVT) == TypeWidenVector &&
32825              "Unexpected type action!");
32826       assert(N->getOpcode() == ISD::SIGN_EXTEND && "Unexpected opcode");
32827       // Custom split this so we can extend i8/i16->i32 invec. This is better
32828       // since sign_extend_inreg i8/i16->i64 requires an extend to i32 using sra,
32829       // then an extend from i32 to i64 using pcmpgt. By custom splitting we
32830       // allow the sra from the extend to i32 to be shared by the split.
32831       In = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v4i32, In);
32832 
32833       // Fill a vector with sign bits for each element.
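      // Note that (0 > x) produces an all-ones lane exactly where x is
      // negative, i.e. each lane becomes a copy of its sign bit.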
32834       SDValue Zero = DAG.getConstant(0, dl, MVT::v4i32);
32835       SDValue SignBits = DAG.getSetCC(dl, MVT::v4i32, Zero, In, ISD::SETGT);
32836 
32837       // Create an unpackl and unpackh to interleave the sign bits then bitcast
32838       // to v2i64.
32839       SDValue Lo = DAG.getVectorShuffle(MVT::v4i32, dl, In, SignBits,
32840                                         {0, 4, 1, 5});
32841       Lo = DAG.getNode(ISD::BITCAST, dl, MVT::v2i64, Lo);
32842       SDValue Hi = DAG.getVectorShuffle(MVT::v4i32, dl, In, SignBits,
32843                                         {2, 6, 3, 7});
32844       Hi = DAG.getNode(ISD::BITCAST, dl, MVT::v2i64, Hi);
32845 
32846       SDValue Res = DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, Lo, Hi);
32847       Results.push_back(Res);
32848       return;
32849     }
32850 
32851     if (VT == MVT::v16i32 || VT == MVT::v8i64) {
32852       if (!InVT.is128BitVector()) {
32853         // Not a 128 bit vector, but maybe type legalization will promote
32854         // it to 128 bits.
32855         if (getTypeAction(*DAG.getContext(), InVT) != TypePromoteInteger)
32856           return;
32857         InVT = getTypeToTransformTo(*DAG.getContext(), InVT);
32858         if (!InVT.is128BitVector())
32859           return;
32860 
32861         // Promote the input to 128 bits. Type legalization will turn this into
32862         // zext_inreg/sext_inreg.
32863         In = DAG.getNode(N->getOpcode(), dl, InVT, In);
32864       }
32865 
32866       // Perform custom splitting instead of the two stage extend we would get
32867       // by default.
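      // For example, with a v16i8 source and a v16i32 result, the low half
      // extends elements 0-7 in place via an *_EXTEND_VECTOR_INREG node, and
      // the shuffle below moves elements 8-15 into the low half so the high
      // half can be extended the same way.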
32868       EVT LoVT, HiVT;
32869       std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(N->getValueType(0));
32870       assert(isTypeLegal(LoVT) && "Split VT not legal?");
32871 
32872       SDValue Lo = getEXTEND_VECTOR_INREG(N->getOpcode(), dl, LoVT, In, DAG);
32873 
32874       // We need to shift the input over by half the number of elements.
32875       unsigned NumElts = InVT.getVectorNumElements();
32876       unsigned HalfNumElts = NumElts / 2;
32877       SmallVector<int, 16> ShufMask(NumElts, SM_SentinelUndef);
32878       for (unsigned i = 0; i != HalfNumElts; ++i)
32879         ShufMask[i] = i + HalfNumElts;
32880 
32881       SDValue Hi = DAG.getVectorShuffle(InVT, dl, In, In, ShufMask);
32882       Hi = getEXTEND_VECTOR_INREG(N->getOpcode(), dl, HiVT, Hi, DAG);
32883 
32884       SDValue Res = DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, Lo, Hi);
32885       Results.push_back(Res);
32886     }
32887     return;
32888   }
32889   case ISD::FP_TO_SINT:
32890   case ISD::STRICT_FP_TO_SINT:
32891   case ISD::FP_TO_UINT:
32892   case ISD::STRICT_FP_TO_UINT: {
32893     bool IsStrict = N->isStrictFPOpcode();
32894     bool IsSigned = N->getOpcode() == ISD::FP_TO_SINT ||
32895                     N->getOpcode() == ISD::STRICT_FP_TO_SINT;
32896     EVT VT = N->getValueType(0);
32897     SDValue Src = N->getOperand(IsStrict ? 1 : 0);
32898     SDValue Chain = IsStrict ? N->getOperand(0) : SDValue();
32899     EVT SrcVT = Src.getValueType();
32900 
32901     SDValue Res;
32902     if (isSoftF16(SrcVT, Subtarget)) {
32903       EVT NVT = VT.isVector() ? VT.changeVectorElementType(MVT::f32) : MVT::f32;
32904       if (IsStrict) {
32905         Res =
32906             DAG.getNode(N->getOpcode(), dl, {VT, MVT::Other},
32907                         {Chain, DAG.getNode(ISD::STRICT_FP_EXTEND, dl,
32908                                             {NVT, MVT::Other}, {Chain, Src})});
32909         Chain = Res.getValue(1);
32910       } else {
32911         Res = DAG.getNode(N->getOpcode(), dl, VT,
32912                           DAG.getNode(ISD::FP_EXTEND, dl, NVT, Src));
32913       }
32914       Results.push_back(Res);
32915       if (IsStrict)
32916         Results.push_back(Chain);
32917 
32918       return;
32919     }
32920 
32921     if (VT.isVector() && Subtarget.hasFP16() &&
32922         SrcVT.getVectorElementType() == MVT::f16) {
32923       EVT EleVT = VT.getVectorElementType();
32924       EVT ResVT = EleVT == MVT::i32 ? MVT::v4i32 : MVT::v8i16;
32925 
32926       if (SrcVT != MVT::v8f16) {
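        // Widen the source to v8f16. For strict nodes the padding lanes are
        // zero rather than undef so they cannot raise spurious FP exceptions.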
32927         SDValue Tmp =
32928             IsStrict ? DAG.getConstantFP(0.0, dl, SrcVT) : DAG.getUNDEF(SrcVT);
32929         SmallVector<SDValue, 4> Ops(SrcVT == MVT::v2f16 ? 4 : 2, Tmp);
32930         Ops[0] = Src;
32931         Src = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v8f16, Ops);
32932       }
32933 
32934       if (IsStrict) {
32935         unsigned Opc =
32936             IsSigned ? X86ISD::STRICT_CVTTP2SI : X86ISD::STRICT_CVTTP2UI;
32937         Res =
32938             DAG.getNode(Opc, dl, {ResVT, MVT::Other}, {N->getOperand(0), Src});
32939         Chain = Res.getValue(1);
32940       } else {
32941         unsigned Opc = IsSigned ? X86ISD::CVTTP2SI : X86ISD::CVTTP2UI;
32942         Res = DAG.getNode(Opc, dl, ResVT, Src);
32943       }
32944 
32945       // TODO: Need to add exception check code for strict FP.
32946       if (EleVT.getSizeInBits() < 16) {
32947         MVT TmpVT = MVT::getVectorVT(EleVT.getSimpleVT(), 8);
32948         Res = DAG.getNode(ISD::TRUNCATE, dl, TmpVT, Res);
32949 
32950         // Now widen to 128 bits.
32951         unsigned NumConcats = 128 / TmpVT.getSizeInBits();
32952         MVT ConcatVT = MVT::getVectorVT(EleVT.getSimpleVT(), 8 * NumConcats);
32953         SmallVector<SDValue, 8> ConcatOps(NumConcats, DAG.getUNDEF(TmpVT));
32954         ConcatOps[0] = Res;
32955         Res = DAG.getNode(ISD::CONCAT_VECTORS, dl, ConcatVT, ConcatOps);
32956       }
32957 
32958       Results.push_back(Res);
32959       if (IsStrict)
32960         Results.push_back(Chain);
32961 
32962       return;
32963     }
32964 
32965     if (VT.isVector() && VT.getScalarSizeInBits() < 32) {
32966       assert(getTypeAction(*DAG.getContext(), VT) == TypeWidenVector &&
32967              "Unexpected type action!");
32968 
32969       // Try to create a 128 bit vector, but don't exceed a 32 bit element.
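      // For example, a v2i16 result is converted as v2i32, tagged with an
      // AssertSext/AssertZext recording that it fits in i16, truncated back
      // to v2i16, and finally widened to the legal v8i16.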
32970       unsigned NewEltWidth = std::min(128 / VT.getVectorNumElements(), 32U);
32971       MVT PromoteVT = MVT::getVectorVT(MVT::getIntegerVT(NewEltWidth),
32972                                        VT.getVectorNumElements());
32973       SDValue Res;
32974       SDValue Chain;
32975       if (IsStrict) {
32976         Res = DAG.getNode(ISD::STRICT_FP_TO_SINT, dl, {PromoteVT, MVT::Other},
32977                           {N->getOperand(0), Src});
32978         Chain = Res.getValue(1);
32979       } else
32980         Res = DAG.getNode(ISD::FP_TO_SINT, dl, PromoteVT, Src);
32981 
32982       // Preserve what we know about the size of the original result. If the
32983       // result is v2i32, we have to manually widen the assert.
32984       if (PromoteVT == MVT::v2i32)
32985         Res = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4i32, Res,
32986                           DAG.getUNDEF(MVT::v2i32));
32987 
32988       Res = DAG.getNode(!IsSigned ? ISD::AssertZext : ISD::AssertSext, dl,
32989                         Res.getValueType(), Res,
32990                         DAG.getValueType(VT.getVectorElementType()));
32991 
32992       if (PromoteVT == MVT::v2i32)
32993         Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v2i32, Res,
32994                           DAG.getIntPtrConstant(0, dl));
32995 
32996       // Truncate back to the original width.
32997       Res = DAG.getNode(ISD::TRUNCATE, dl, VT, Res);
32998 
32999       // Now widen to 128 bits.
33000       unsigned NumConcats = 128 / VT.getSizeInBits();
33001       MVT ConcatVT = MVT::getVectorVT(VT.getSimpleVT().getVectorElementType(),
33002                                       VT.getVectorNumElements() * NumConcats);
33003       SmallVector<SDValue, 8> ConcatOps(NumConcats, DAG.getUNDEF(VT));
33004       ConcatOps[0] = Res;
33005       Res = DAG.getNode(ISD::CONCAT_VECTORS, dl, ConcatVT, ConcatOps);
33006       Results.push_back(Res);
33007       if (IsStrict)
33008         Results.push_back(Chain);
33009       return;
33010     }
33011 
33012 
33013     if (VT == MVT::v2i32) {
33014       assert((!IsStrict || IsSigned || Subtarget.hasAVX512()) &&
33015              "Strict unsigned conversion requires AVX512");
33016       assert(Subtarget.hasSSE2() && "Requires at least SSE2!");
33017       assert(getTypeAction(*DAG.getContext(), VT) == TypeWidenVector &&
33018              "Unexpected type action!");
33019       if (Src.getValueType() == MVT::v2f64) {
33020         if (!IsSigned && !Subtarget.hasAVX512()) {
33021           SDValue Res =
33022               expandFP_TO_UINT_SSE(MVT::v4i32, Src, dl, DAG, Subtarget);
33023           Results.push_back(Res);
33024           return;
33025         }
33026 
33027         unsigned Opc;
33028         if (IsStrict)
33029           Opc = IsSigned ? X86ISD::STRICT_CVTTP2SI : X86ISD::STRICT_CVTTP2UI;
33030         else
33031           Opc = IsSigned ? X86ISD::CVTTP2SI : X86ISD::CVTTP2UI;
33032 
33033         // If we have VLX we can emit a target specific FP_TO_UINT node.
33034         if (!IsSigned && !Subtarget.hasVLX()) {
33035           // Otherwise we can defer to the generic legalizer which will widen
33036           // the input as well. This will be further widened during op
33037           // legalization to v8i32<-v8f64.
33038           // For strict nodes we'll need to widen ourselves.
33039           // FIXME: Fix the type legalizer to safely widen strict nodes?
33040           if (!IsStrict)
33041             return;
33042           Src = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4f64, Src,
33043                             DAG.getConstantFP(0.0, dl, MVT::v2f64));
33044           Opc = N->getOpcode();
33045         }
33046         SDValue Res;
33047         SDValue Chain;
33048         if (IsStrict) {
33049           Res = DAG.getNode(Opc, dl, {MVT::v4i32, MVT::Other},
33050                             {N->getOperand(0), Src});
33051           Chain = Res.getValue(1);
33052         } else {
33053           Res = DAG.getNode(Opc, dl, MVT::v4i32, Src);
33054         }
33055         Results.push_back(Res);
33056         if (IsStrict)
33057           Results.push_back(Chain);
33058         return;
33059       }
33060 
33061       // Custom widen strict v2f32->v2i32 by padding with zeros.
33062       // FIXME: Should generic type legalizer do this?
33063       if (Src.getValueType() == MVT::v2f32 && IsStrict) {
33064         Src = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4f32, Src,
33065                           DAG.getConstantFP(0.0, dl, MVT::v2f32));
33066         SDValue Res = DAG.getNode(N->getOpcode(), dl, {MVT::v4i32, MVT::Other},
33067                                   {N->getOperand(0), Src});
33068         Results.push_back(Res);
33069         Results.push_back(Res.getValue(1));
33070         return;
33071       }
33072 
33073       // The FP_TO_INTHelper below only handles f32/f64/f80 scalar inputs,
33074       // so early out here.
33075       return;
33076     }
33077 
33078     assert(!VT.isVector() && "Vectors should have been handled above!");
33079 
33080     if ((Subtarget.hasDQI() && VT == MVT::i64 &&
33081          (SrcVT == MVT::f32 || SrcVT == MVT::f64)) ||
33082         (Subtarget.hasFP16() && SrcVT == MVT::f16)) {
33083       assert(!Subtarget.is64Bit() && "i64 should be legal");
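      // Without VLX the 64-bit-element conversions are only available on
      // 512-bit vectors, so use a v8i64 result; with VLX a 128-bit v2i64
      // result is sufficient.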
33084       unsigned NumElts = Subtarget.hasVLX() ? 2 : 8;
33085       // If we use a 128-bit result we might need to use a target specific node.
33086       unsigned SrcElts =
33087           std::max(NumElts, 128U / (unsigned)SrcVT.getSizeInBits());
33088       MVT VecVT = MVT::getVectorVT(MVT::i64, NumElts);
33089       MVT VecInVT = MVT::getVectorVT(SrcVT.getSimpleVT(), SrcElts);
33090       unsigned Opc = N->getOpcode();
33091       if (NumElts != SrcElts) {
33092         if (IsStrict)
33093           Opc = IsSigned ? X86ISD::STRICT_CVTTP2SI : X86ISD::STRICT_CVTTP2UI;
33094         else
33095           Opc = IsSigned ? X86ISD::CVTTP2SI : X86ISD::CVTTP2UI;
33096       }
33097 
33098       SDValue ZeroIdx = DAG.getIntPtrConstant(0, dl);
33099       SDValue Res = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, VecInVT,
33100                                 DAG.getConstantFP(0.0, dl, VecInVT), Src,
33101                                 ZeroIdx);
33102       SDValue Chain;
33103       if (IsStrict) {
33104         SDVTList Tys = DAG.getVTList(VecVT, MVT::Other);
33105         Res = DAG.getNode(Opc, SDLoc(N), Tys, N->getOperand(0), Res);
33106         Chain = Res.getValue(1);
33107       } else
33108         Res = DAG.getNode(Opc, SDLoc(N), VecVT, Res);
33109       Res = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT, Res, ZeroIdx);
33110       Results.push_back(Res);
33111       if (IsStrict)
33112         Results.push_back(Chain);
33113       return;
33114     }
33115 
33116     if (VT == MVT::i128 && Subtarget.isTargetWin64()) {
33117       SDValue Chain;
33118       SDValue V = LowerWin64_FP_TO_INT128(SDValue(N, 0), DAG, Chain);
33119       Results.push_back(V);
33120       if (IsStrict)
33121         Results.push_back(Chain);
33122       return;
33123     }
33124 
33125     if (SDValue V = FP_TO_INTHelper(SDValue(N, 0), DAG, IsSigned, Chain)) {
33126       Results.push_back(V);
33127       if (IsStrict)
33128         Results.push_back(Chain);
33129     }
33130     return;
33131   }
33132   case ISD::LRINT:
33133   case ISD::LLRINT: {
33134     if (SDValue V = LRINT_LLRINTHelper(N, DAG))
33135       Results.push_back(V);
33136     return;
33137   }
33138 
33139   case ISD::SINT_TO_FP:
33140   case ISD::STRICT_SINT_TO_FP:
33141   case ISD::UINT_TO_FP:
33142   case ISD::STRICT_UINT_TO_FP: {
33143     bool IsStrict = N->isStrictFPOpcode();
33144     bool IsSigned = N->getOpcode() == ISD::SINT_TO_FP ||
33145                     N->getOpcode() == ISD::STRICT_SINT_TO_FP;
33146     EVT VT = N->getValueType(0);
33147     SDValue Src = N->getOperand(IsStrict ? 1 : 0);
33148     if (VT.getVectorElementType() == MVT::f16 && Subtarget.hasFP16() &&
33149         Subtarget.hasVLX()) {
33150       if (Src.getValueType().getVectorElementType() == MVT::i16)
33151         return;
33152 
33153       if (VT == MVT::v2f16 && Src.getValueType() == MVT::v2i32)
33154         Src = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4i32, Src,
33155                           IsStrict ? DAG.getConstant(0, dl, MVT::v2i32)
33156                                    : DAG.getUNDEF(MVT::v2i32));
33157       if (IsStrict) {
33158         unsigned Opc =
33159             IsSigned ? X86ISD::STRICT_CVTSI2P : X86ISD::STRICT_CVTUI2P;
33160         SDValue Res = DAG.getNode(Opc, dl, {MVT::v8f16, MVT::Other},
33161                                   {N->getOperand(0), Src});
33162         Results.push_back(Res);
33163         Results.push_back(Res.getValue(1));
33164       } else {
33165         unsigned Opc = IsSigned ? X86ISD::CVTSI2P : X86ISD::CVTUI2P;
33166         Results.push_back(DAG.getNode(Opc, dl, MVT::v8f16, Src));
33167       }
33168       return;
33169     }
33170     if (VT != MVT::v2f32)
33171       return;
33172     EVT SrcVT = Src.getValueType();
33173     if (Subtarget.hasDQI() && Subtarget.hasVLX() && SrcVT == MVT::v2i64) {
33174       if (IsStrict) {
33175         unsigned Opc = IsSigned ? X86ISD::STRICT_CVTSI2P
33176                                 : X86ISD::STRICT_CVTUI2P;
33177         SDValue Res = DAG.getNode(Opc, dl, {MVT::v4f32, MVT::Other},
33178                                   {N->getOperand(0), Src});
33179         Results.push_back(Res);
33180         Results.push_back(Res.getValue(1));
33181       } else {
33182         unsigned Opc = IsSigned ? X86ISD::CVTSI2P : X86ISD::CVTUI2P;
33183         Results.push_back(DAG.getNode(Opc, dl, MVT::v4f32, Src));
33184       }
33185       return;
33186     }
33187     if (SrcVT == MVT::v2i64 && !IsSigned && Subtarget.is64Bit() &&
33188         Subtarget.hasSSE41() && !Subtarget.hasAVX512()) {
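      // Elements with the top bit set (>= 2^63) are halved with round-to-odd
      // ((x >> 1) | (x & 1)), converted as signed, and doubled with an FADD;
      // elements that fit in the signed range use the plain signed
      // conversion. IsNeg selects between the two results below.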
33189       SDValue Zero = DAG.getConstant(0, dl, SrcVT);
33190       SDValue One  = DAG.getConstant(1, dl, SrcVT);
33191       SDValue Sign = DAG.getNode(ISD::OR, dl, SrcVT,
33192                                  DAG.getNode(ISD::SRL, dl, SrcVT, Src, One),
33193                                  DAG.getNode(ISD::AND, dl, SrcVT, Src, One));
33194       SDValue IsNeg = DAG.getSetCC(dl, MVT::v2i64, Src, Zero, ISD::SETLT);
33195       SDValue SignSrc = DAG.getSelect(dl, SrcVT, IsNeg, Sign, Src);
33196       SmallVector<SDValue, 4> SignCvts(4, DAG.getConstantFP(0.0, dl, MVT::f32));
33197       for (int i = 0; i != 2; ++i) {
33198         SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i64,
33199                                   SignSrc, DAG.getIntPtrConstant(i, dl));
33200         if (IsStrict)
33201           SignCvts[i] =
33202               DAG.getNode(ISD::STRICT_SINT_TO_FP, dl, {MVT::f32, MVT::Other},
33203                           {N->getOperand(0), Elt});
33204         else
33205           SignCvts[i] = DAG.getNode(ISD::SINT_TO_FP, dl, MVT::f32, Elt);
33206       }
33207       SDValue SignCvt = DAG.getBuildVector(MVT::v4f32, dl, SignCvts);
33208       SDValue Slow, Chain;
33209       if (IsStrict) {
33210         Chain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other,
33211                             SignCvts[0].getValue(1), SignCvts[1].getValue(1));
33212         Slow = DAG.getNode(ISD::STRICT_FADD, dl, {MVT::v4f32, MVT::Other},
33213                            {Chain, SignCvt, SignCvt});
33214         Chain = Slow.getValue(1);
33215       } else {
33216         Slow = DAG.getNode(ISD::FADD, dl, MVT::v4f32, SignCvt, SignCvt);
33217       }
33218       IsNeg = DAG.getBitcast(MVT::v4i32, IsNeg);
33219       IsNeg =
33220           DAG.getVectorShuffle(MVT::v4i32, dl, IsNeg, IsNeg, {1, 3, -1, -1});
33221       SDValue Cvt = DAG.getSelect(dl, MVT::v4f32, IsNeg, Slow, SignCvt);
33222       Results.push_back(Cvt);
33223       if (IsStrict)
33224         Results.push_back(Chain);
33225       return;
33226     }
33227 
33228     if (SrcVT != MVT::v2i32)
33229       return;
33230 
33231     if (IsSigned || Subtarget.hasAVX512()) {
33232       if (!IsStrict)
33233         return;
33234 
33235       // Custom widen strict v2i32->v2f32 to avoid scalarization.
33236       // FIXME: Should generic type legalizer do this?
33237       Src = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4i32, Src,
33238                         DAG.getConstant(0, dl, MVT::v2i32));
33239       SDValue Res = DAG.getNode(N->getOpcode(), dl, {MVT::v4f32, MVT::Other},
33240                                 {N->getOperand(0), Src});
33241       Results.push_back(Res);
33242       Results.push_back(Res.getValue(1));
33243       return;
33244     }
33245 
33246     assert(Subtarget.hasSSE2() && "Requires at least SSE2!");
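    // Lower the remaining v2i32 unsigned case with the classic bias trick:
    // 0x4330000000000000 is the double 2^52, so OR-ing the zero-extended
    // 32-bit value into its low mantissa bits produces exactly 2^52 + x, and
    // subtracting 2^52 recovers x as an f64 before rounding it to f32.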
33247     SDValue ZExtIn = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::v2i64, Src);
33248     SDValue VBias = DAG.getConstantFP(
33249         llvm::bit_cast<double>(0x4330000000000000ULL), dl, MVT::v2f64);
33250     SDValue Or = DAG.getNode(ISD::OR, dl, MVT::v2i64, ZExtIn,
33251                              DAG.getBitcast(MVT::v2i64, VBias));
33252     Or = DAG.getBitcast(MVT::v2f64, Or);
33253     if (IsStrict) {
33254       SDValue Sub = DAG.getNode(ISD::STRICT_FSUB, dl, {MVT::v2f64, MVT::Other},
33255                                 {N->getOperand(0), Or, VBias});
33256       SDValue Res = DAG.getNode(X86ISD::STRICT_VFPROUND, dl,
33257                                 {MVT::v4f32, MVT::Other},
33258                                 {Sub.getValue(1), Sub});
33259       Results.push_back(Res);
33260       Results.push_back(Res.getValue(1));
33261     } else {
33262       // TODO: Are there any fast-math-flags to propagate here?
33263       SDValue Sub = DAG.getNode(ISD::FSUB, dl, MVT::v2f64, Or, VBias);
33264       Results.push_back(DAG.getNode(X86ISD::VFPROUND, dl, MVT::v4f32, Sub));
33265     }
33266     return;
33267   }
33268   case ISD::STRICT_FP_ROUND:
33269   case ISD::FP_ROUND: {
33270     bool IsStrict = N->isStrictFPOpcode();
33271     SDValue Chain = IsStrict ? N->getOperand(0) : SDValue();
33272     SDValue Src = N->getOperand(IsStrict ? 1 : 0);
33273     SDValue Rnd = N->getOperand(IsStrict ? 2 : 1);
33274     EVT SrcVT = Src.getValueType();
33275     EVT VT = N->getValueType(0);
33276     SDValue V;
33277     if (VT == MVT::v2f16 && Src.getValueType() == MVT::v2f32) {
33278       SDValue Ext = IsStrict ? DAG.getConstantFP(0.0, dl, MVT::v2f32)
33279                              : DAG.getUNDEF(MVT::v2f32);
33280       Src = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4f32, Src, Ext);
33281     }
33282     if (!Subtarget.hasFP16() && VT.getVectorElementType() == MVT::f16) {
33283       assert(Subtarget.hasF16C() && "Cannot widen f16 without F16C");
33284       if (SrcVT.getVectorElementType() != MVT::f32)
33285         return;
33286 
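      // Use F16C's CVTPS2PH, which produces the packed half results as a
      // v8i16 that is bitcast back to v8f16 below.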
33287       if (IsStrict)
33288         V = DAG.getNode(X86ISD::STRICT_CVTPS2PH, dl, {MVT::v8i16, MVT::Other},
33289                         {Chain, Src, Rnd});
33290       else
33291         V = DAG.getNode(X86ISD::CVTPS2PH, dl, MVT::v8i16, Src, Rnd);
33292 
33293       Results.push_back(DAG.getBitcast(MVT::v8f16, V));
33294       if (IsStrict)
33295         Results.push_back(V.getValue(1));
33296       return;
33297     }
33298     if (!isTypeLegal(Src.getValueType()))
33299       return;
33300     EVT NewVT = VT.getVectorElementType() == MVT::f16 ? MVT::v8f16 : MVT::v4f32;
33301     if (IsStrict)
33302       V = DAG.getNode(X86ISD::STRICT_VFPROUND, dl, {NewVT, MVT::Other},
33303                       {Chain, Src});
33304     else
33305       V = DAG.getNode(X86ISD::VFPROUND, dl, NewVT, Src);
33306     Results.push_back(V);
33307     if (IsStrict)
33308       Results.push_back(V.getValue(1));
33309     return;
33310   }
33311   case ISD::FP_EXTEND:
33312   case ISD::STRICT_FP_EXTEND: {
33313     // Right now, only MVT::v2f32 has OperationAction for FP_EXTEND.
33314     // No other ValueType for FP_EXTEND should reach this point.
33315     assert(N->getValueType(0) == MVT::v2f32 &&
33316            "Do not know how to legalize this Node");
33317     if (!Subtarget.hasFP16() || !Subtarget.hasVLX())
33318       return;
33319     bool IsStrict = N->isStrictFPOpcode();
33320     SDValue Src = N->getOperand(IsStrict ? 1 : 0);
33321     if (Src.getValueType().getVectorElementType() != MVT::f16)
33322       return;
33323     SDValue Ext = IsStrict ? DAG.getConstantFP(0.0, dl, MVT::v2f16)
33324                            : DAG.getUNDEF(MVT::v2f16);
33325     SDValue V = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4f16, Src, Ext);
33326     if (IsStrict)
33327       V = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {MVT::v4f32, MVT::Other},
33328                       {N->getOperand(0), V});
33329     else
33330       V = DAG.getNode(ISD::FP_EXTEND, dl, MVT::v4f32, V);
33331     Results.push_back(V);
33332     if (IsStrict)
33333       Results.push_back(V.getValue(1));
33334     return;
33335   }
33336   case ISD::INTRINSIC_W_CHAIN: {
33337     unsigned IntNo = N->getConstantOperandVal(1);
33338     switch (IntNo) {
33339     default : llvm_unreachable("Do not know how to custom type "
33340                                "legalize this intrinsic operation!");
33341     case Intrinsic::x86_rdtsc:
33342       return getReadTimeStampCounter(N, dl, X86::RDTSC, DAG, Subtarget,
33343                                      Results);
33344     case Intrinsic::x86_rdtscp:
33345       return getReadTimeStampCounter(N, dl, X86::RDTSCP, DAG, Subtarget,
33346                                      Results);
33347     case Intrinsic::x86_rdpmc:
33348       expandIntrinsicWChainHelper(N, dl, DAG, X86::RDPMC, X86::ECX, Subtarget,
33349                                   Results);
33350       return;
33351     case Intrinsic::x86_rdpru:
33352       expandIntrinsicWChainHelper(N, dl, DAG, X86::RDPRU, X86::ECX, Subtarget,
33353                                   Results);
33354       return;
33355     case Intrinsic::x86_xgetbv:
33356       expandIntrinsicWChainHelper(N, dl, DAG, X86::XGETBV, X86::ECX, Subtarget,
33357                                   Results);
33358       return;
33359     }
33360   }
33361   case ISD::READCYCLECOUNTER: {
33362     return getReadTimeStampCounter(N, dl, X86::RDTSC, DAG, Subtarget, Results);
33363   }
33364   case ISD::ATOMIC_CMP_SWAP_WITH_SUCCESS: {
33365     EVT T = N->getValueType(0);
33366     assert((T == MVT::i64 || T == MVT::i128) && "can only expand cmpxchg pair");
33367     bool Regs64bit = T == MVT::i128;
33368     assert((!Regs64bit || Subtarget.canUseCMPXCHG16B()) &&
33369            "64-bit ATOMIC_CMP_SWAP_WITH_SUCCESS requires CMPXCHG16B");
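    // CMPXCHG8B/CMPXCHG16B expect the compare value in EDX:EAX (RDX:RAX) and
    // the replacement value in ECX:EBX (RCX:RBX); they return the old value
    // in EDX:EAX (RDX:RAX) and set ZF on success, which is what the
    // CopyToReg/CopyFromReg sequence below models.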
33370     MVT HalfT = Regs64bit ? MVT::i64 : MVT::i32;
33371     SDValue cpInL, cpInH;
33372     std::tie(cpInL, cpInH) =
33373         DAG.SplitScalar(N->getOperand(2), dl, HalfT, HalfT);
33374     cpInL = DAG.getCopyToReg(N->getOperand(0), dl,
33375                              Regs64bit ? X86::RAX : X86::EAX, cpInL, SDValue());
33376     cpInH =
33377         DAG.getCopyToReg(cpInL.getValue(0), dl, Regs64bit ? X86::RDX : X86::EDX,
33378                          cpInH, cpInL.getValue(1));
33379     SDValue swapInL, swapInH;
33380     std::tie(swapInL, swapInH) =
33381         DAG.SplitScalar(N->getOperand(3), dl, HalfT, HalfT);
33382     swapInH =
33383         DAG.getCopyToReg(cpInH.getValue(0), dl, Regs64bit ? X86::RCX : X86::ECX,
33384                          swapInH, cpInH.getValue(1));
33385 
33386     // In 64-bit mode we might need the base pointer in RBX, but we can't know
33387     // until later. So we keep the RBX input in a vreg and use a custom
33388     // inserter.
33389     // Since RBX will be a reserved register, the register allocator will not
33390     // make sure its value is properly saved and restored around this
33391     // live-range.
33392     SDValue Result;
33393     SDVTList Tys = DAG.getVTList(MVT::Other, MVT::Glue);
33394     MachineMemOperand *MMO = cast<AtomicSDNode>(N)->getMemOperand();
33395     if (Regs64bit) {
33396       SDValue Ops[] = {swapInH.getValue(0), N->getOperand(1), swapInL,
33397                        swapInH.getValue(1)};
33398       Result =
33399           DAG.getMemIntrinsicNode(X86ISD::LCMPXCHG16_DAG, dl, Tys, Ops, T, MMO);
33400     } else {
33401       swapInL = DAG.getCopyToReg(swapInH.getValue(0), dl, X86::EBX, swapInL,
33402                                  swapInH.getValue(1));
33403       SDValue Ops[] = {swapInL.getValue(0), N->getOperand(1),
33404                        swapInL.getValue(1)};
33405       Result =
33406           DAG.getMemIntrinsicNode(X86ISD::LCMPXCHG8_DAG, dl, Tys, Ops, T, MMO);
33407     }
33408 
33409     SDValue cpOutL = DAG.getCopyFromReg(Result.getValue(0), dl,
33410                                         Regs64bit ? X86::RAX : X86::EAX,
33411                                         HalfT, Result.getValue(1));
33412     SDValue cpOutH = DAG.getCopyFromReg(cpOutL.getValue(1), dl,
33413                                         Regs64bit ? X86::RDX : X86::EDX,
33414                                         HalfT, cpOutL.getValue(2));
33415     SDValue OpsF[] = { cpOutL.getValue(0), cpOutH.getValue(0)};
33416 
33417     SDValue EFLAGS = DAG.getCopyFromReg(cpOutH.getValue(1), dl, X86::EFLAGS,
33418                                         MVT::i32, cpOutH.getValue(2));
33419     SDValue Success = getSETCC(X86::COND_E, EFLAGS, dl, DAG);
33420     Success = DAG.getZExtOrTrunc(Success, dl, N->getValueType(1));
33421 
33422     Results.push_back(DAG.getNode(ISD::BUILD_PAIR, dl, T, OpsF));
33423     Results.push_back(Success);
33424     Results.push_back(EFLAGS.getValue(1));
33425     return;
33426   }
33427   case ISD::ATOMIC_LOAD: {
33428     assert(
33429         (N->getValueType(0) == MVT::i64 || N->getValueType(0) == MVT::i128) &&
33430         "Unexpected VT!");
33431     bool NoImplicitFloatOps =
33432         DAG.getMachineFunction().getFunction().hasFnAttribute(
33433             Attribute::NoImplicitFloat);
33434     if (!Subtarget.useSoftFloat() && !NoImplicitFloatOps) {
33435       auto *Node = cast<AtomicSDNode>(N);
33436 
33437       if (N->getValueType(0) == MVT::i128) {
33438         if (Subtarget.is64Bit() && Subtarget.hasAVX()) {
33439           SDValue Ld = DAG.getLoad(MVT::v2i64, dl, Node->getChain(),
33440                                    Node->getBasePtr(), Node->getMemOperand());
33441           SDValue ResL = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i64, Ld,
33442                                      DAG.getIntPtrConstant(0, dl));
33443           SDValue ResH = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i64, Ld,
33444                                      DAG.getIntPtrConstant(1, dl));
33445           Results.push_back(DAG.getNode(ISD::BUILD_PAIR, dl, N->getValueType(0),
33446                                         {ResL, ResH}));
33447           Results.push_back(Ld.getValue(1));
33448           return;
33449         }
33450         break;
33451       }
33452       if (Subtarget.hasSSE1()) {
33453         // Use a VZEXT_LOAD which will be selected as MOVQ or XORPS+MOVLPS.
33454         // Then extract the lower 64-bits.
33455         MVT LdVT = Subtarget.hasSSE2() ? MVT::v2i64 : MVT::v4f32;
33456         SDVTList Tys = DAG.getVTList(LdVT, MVT::Other);
33457         SDValue Ops[] = { Node->getChain(), Node->getBasePtr() };
33458         SDValue Ld = DAG.getMemIntrinsicNode(X86ISD::VZEXT_LOAD, dl, Tys, Ops,
33459                                              MVT::i64, Node->getMemOperand());
33460         if (Subtarget.hasSSE2()) {
33461           SDValue Res = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i64, Ld,
33462                                     DAG.getIntPtrConstant(0, dl));
33463           Results.push_back(Res);
33464           Results.push_back(Ld.getValue(1));
33465           return;
33466         }
33467         // We use an alternative sequence for SSE1 that extracts as v2f32 and
33468         // then casts to i64. This avoids a 128-bit stack temporary being
33469         // created by type legalization if we were to cast v4f32->v2i64.
33470         SDValue Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v2f32, Ld,
33471                                   DAG.getIntPtrConstant(0, dl));
33472         Res = DAG.getBitcast(MVT::i64, Res);
33473         Results.push_back(Res);
33474         Results.push_back(Ld.getValue(1));
33475         return;
33476       }
33477       if (Subtarget.hasX87()) {
33478         // First load this into an 80-bit X87 register. This will put the whole
33479         // integer into the significand.
33480         SDVTList Tys = DAG.getVTList(MVT::f80, MVT::Other);
33481         SDValue Ops[] = { Node->getChain(), Node->getBasePtr() };
33482         SDValue Result = DAG.getMemIntrinsicNode(X86ISD::FILD,
33483                                                  dl, Tys, Ops, MVT::i64,
33484                                                  Node->getMemOperand());
33485         SDValue Chain = Result.getValue(1);
33486 
33487         // Now store the X87 register to a stack temporary and convert to i64.
33488         // This store is not atomic and doesn't need to be.
33489         // FIXME: We don't need a stack temporary if the result of the load
33490         // is already being stored. We could just directly store there.
33491         SDValue StackPtr = DAG.CreateStackTemporary(MVT::i64);
33492         int SPFI = cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex();
33493         MachinePointerInfo MPI =
33494             MachinePointerInfo::getFixedStack(DAG.getMachineFunction(), SPFI);
33495         SDValue StoreOps[] = { Chain, Result, StackPtr };
33496         Chain = DAG.getMemIntrinsicNode(
33497             X86ISD::FIST, dl, DAG.getVTList(MVT::Other), StoreOps, MVT::i64,
33498             MPI, std::nullopt /*Align*/, MachineMemOperand::MOStore);
33499 
33500         // Finally load the value back from the stack temporary and return it.
33501         // This load is not atomic and doesn't need to be.
33502         // This load will be further type legalized.
33503         Result = DAG.getLoad(MVT::i64, dl, Chain, StackPtr, MPI);
33504         Results.push_back(Result);
33505         Results.push_back(Result.getValue(1));
33506         return;
33507       }
33508     }
33509     // TODO: Use MOVLPS when SSE1 is available?
33510     // Delegate to generic TypeLegalization. Situations we can really handle
33511     // should have already been dealt with by AtomicExpandPass.cpp.
33512     break;
33513   }
33514   case ISD::ATOMIC_SWAP:
33515   case ISD::ATOMIC_LOAD_ADD:
33516   case ISD::ATOMIC_LOAD_SUB:
33517   case ISD::ATOMIC_LOAD_AND:
33518   case ISD::ATOMIC_LOAD_OR:
33519   case ISD::ATOMIC_LOAD_XOR:
33520   case ISD::ATOMIC_LOAD_NAND:
33521   case ISD::ATOMIC_LOAD_MIN:
33522   case ISD::ATOMIC_LOAD_MAX:
33523   case ISD::ATOMIC_LOAD_UMIN:
33524   case ISD::ATOMIC_LOAD_UMAX:
33525     // Delegate to generic TypeLegalization. Situations we can really handle
33526     // should have already been dealt with by AtomicExpandPass.cpp.
33527     break;
33528 
33529   case ISD::BITCAST: {
33530     assert(Subtarget.hasSSE2() && "Requires at least SSE2!");
33531     EVT DstVT = N->getValueType(0);
33532     EVT SrcVT = N->getOperand(0).getValueType();
33533 
33534     // If this is a bitcast from a v64i1 k-register to a i64 on a 32-bit target
33535     // If this is a bitcast from a v64i1 k-register to an i64 on a 32-bit target
33536     if (SrcVT == MVT::v64i1 && DstVT == MVT::i64 && Subtarget.hasBWI()) {
33537       assert(!Subtarget.is64Bit() && "Expected 32-bit mode");
33538       SDValue Lo, Hi;
33539       std::tie(Lo, Hi) = DAG.SplitVectorOperand(N, 0);
33540       Lo = DAG.getBitcast(MVT::i32, Lo);
33541       Hi = DAG.getBitcast(MVT::i32, Hi);
33542       SDValue Res = DAG.getNode(ISD::BUILD_PAIR, dl, MVT::i64, Lo, Hi);
33543       Results.push_back(Res);
33544       return;
33545     }
33546 
33547     if (DstVT.isVector() && SrcVT == MVT::x86mmx) {
33548       // FIXME: Use v4f32 for SSE1?
33549       assert(Subtarget.hasSSE2() && "Requires SSE2");
33550       assert(getTypeAction(*DAG.getContext(), DstVT) == TypeWidenVector &&
33551              "Unexpected type action!");
33552       EVT WideVT = getTypeToTransformTo(*DAG.getContext(), DstVT);
33553       SDValue Res = DAG.getNode(X86ISD::MOVQ2DQ, dl, MVT::v2i64,
33554                                 N->getOperand(0));
33555       Res = DAG.getBitcast(WideVT, Res);
33556       Results.push_back(Res);
33557       return;
33558     }
33559 
33560     return;
33561   }
33562   case ISD::MGATHER: {
33563     EVT VT = N->getValueType(0);
33564     if ((VT == MVT::v2f32 || VT == MVT::v2i32) &&
33565         (Subtarget.hasVLX() || !Subtarget.hasAVX512())) {
33566       auto *Gather = cast<MaskedGatherSDNode>(N);
33567       SDValue Index = Gather->getIndex();
33568       if (Index.getValueType() != MVT::v2i64)
33569         return;
33570       assert(getTypeAction(*DAG.getContext(), VT) == TypeWidenVector &&
33571              "Unexpected type action!");
33572       EVT WideVT = getTypeToTransformTo(*DAG.getContext(), VT);
33573       SDValue Mask = Gather->getMask();
33574       assert(Mask.getValueType() == MVT::v2i1 && "Unexpected mask type");
33575       SDValue PassThru = DAG.getNode(ISD::CONCAT_VECTORS, dl, WideVT,
33576                                      Gather->getPassThru(),
33577                                      DAG.getUNDEF(VT));
33578       if (!Subtarget.hasVLX()) {
33579         // We need to widen the mask, but the instruction will only use 2
33580         // of its elements. So we can use undef.
33581         Mask = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4i1, Mask,
33582                            DAG.getUNDEF(MVT::v2i1));
33583         Mask = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v4i32, Mask);
33584       }
33585       SDValue Ops[] = { Gather->getChain(), PassThru, Mask,
33586                         Gather->getBasePtr(), Index, Gather->getScale() };
33587       SDValue Res = DAG.getMemIntrinsicNode(
33588           X86ISD::MGATHER, dl, DAG.getVTList(WideVT, MVT::Other), Ops,
33589           Gather->getMemoryVT(), Gather->getMemOperand());
33590       Results.push_back(Res);
33591       Results.push_back(Res.getValue(1));
33592       return;
33593     }
33594     return;
33595   }
33596   case ISD::LOAD: {
33597     // Use an f64/i64 load and a scalar_to_vector for v2f32/v2i32 loads. This
33598     // avoids scalarizing in 32-bit mode. In 64-bit mode this avoids a int->fp
33599     // avoids scalarizing in 32-bit mode. In 64-bit mode this avoids an int->fp
33600     MVT VT = N->getSimpleValueType(0);
33601     assert(VT.isVector() && VT.getSizeInBits() == 64 && "Unexpected VT");
33602     assert(getTypeAction(*DAG.getContext(), VT) == TypeWidenVector &&
33603            "Unexpected type action!");
33604     if (!ISD::isNON_EXTLoad(N))
33605       return;
33606     auto *Ld = cast<LoadSDNode>(N);
33607     if (Subtarget.hasSSE2()) {
33608       MVT LdVT = Subtarget.is64Bit() && VT.isInteger() ? MVT::i64 : MVT::f64;
33609       SDValue Res = DAG.getLoad(LdVT, dl, Ld->getChain(), Ld->getBasePtr(),
33610                                 Ld->getPointerInfo(), Ld->getOriginalAlign(),
33611                                 Ld->getMemOperand()->getFlags());
33612       SDValue Chain = Res.getValue(1);
33613       MVT VecVT = MVT::getVectorVT(LdVT, 2);
33614       Res = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, VecVT, Res);
33615       EVT WideVT = getTypeToTransformTo(*DAG.getContext(), VT);
33616       Res = DAG.getBitcast(WideVT, Res);
33617       Results.push_back(Res);
33618       Results.push_back(Chain);
33619       return;
33620     }
33621     assert(Subtarget.hasSSE1() && "Expected SSE");
33622     SDVTList Tys = DAG.getVTList(MVT::v4f32, MVT::Other);
33623     SDValue Ops[] = {Ld->getChain(), Ld->getBasePtr()};
33624     SDValue Res = DAG.getMemIntrinsicNode(X86ISD::VZEXT_LOAD, dl, Tys, Ops,
33625                                           MVT::i64, Ld->getMemOperand());
33626     Results.push_back(Res);
33627     Results.push_back(Res.getValue(1));
33628     return;
33629   }
33630   case ISD::ADDRSPACECAST: {
33631     SDValue V = LowerADDRSPACECAST(SDValue(N,0), DAG);
33632     Results.push_back(V);
33633     return;
33634   }
33635   case ISD::BITREVERSE: {
33636     assert(N->getValueType(0) == MVT::i64 && "Unexpected VT!");
33637     assert(Subtarget.hasXOP() && "Expected XOP");
33638     // We can use VPPERM by copying to a vector register and back. We'll need
33639     // to move the scalar in two i32 pieces.
33640     Results.push_back(LowerBITREVERSE(SDValue(N, 0), Subtarget, DAG));
33641     return;
33642   }
33643   case ISD::EXTRACT_VECTOR_ELT: {
33644     // f16 = extract vXf16 %vec, i64 %idx
33645     assert(N->getSimpleValueType(0) == MVT::f16 &&
33646            "Unexpected Value type of EXTRACT_VECTOR_ELT!");
33647     assert(Subtarget.hasFP16() && "Expected FP16");
33648     SDValue VecOp = N->getOperand(0);
33649     EVT ExtVT = VecOp.getValueType().changeVectorElementTypeToInteger();
33650     SDValue Split = DAG.getBitcast(ExtVT, N->getOperand(0));
33651     Split = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i16, Split,
33652                         N->getOperand(1));
33653     Split = DAG.getBitcast(MVT::f16, Split);
33654     Results.push_back(Split);
33655     return;
33656   }
33657   }
33658 }
33659 
33660 const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const {
33661   switch ((X86ISD::NodeType)Opcode) {
33662   case X86ISD::FIRST_NUMBER:       break;
33663 #define NODE_NAME_CASE(NODE) case X86ISD::NODE: return "X86ISD::" #NODE;
33664   NODE_NAME_CASE(BSF)
33665   NODE_NAME_CASE(BSR)
33666   NODE_NAME_CASE(FSHL)
33667   NODE_NAME_CASE(FSHR)
33668   NODE_NAME_CASE(FAND)
33669   NODE_NAME_CASE(FANDN)
33670   NODE_NAME_CASE(FOR)
33671   NODE_NAME_CASE(FXOR)
33672   NODE_NAME_CASE(FILD)
33673   NODE_NAME_CASE(FIST)
33674   NODE_NAME_CASE(FP_TO_INT_IN_MEM)
33675   NODE_NAME_CASE(FLD)
33676   NODE_NAME_CASE(FST)
33677   NODE_NAME_CASE(CALL)
33678   NODE_NAME_CASE(CALL_RVMARKER)
33679   NODE_NAME_CASE(BT)
33680   NODE_NAME_CASE(CMP)
33681   NODE_NAME_CASE(FCMP)
33682   NODE_NAME_CASE(STRICT_FCMP)
33683   NODE_NAME_CASE(STRICT_FCMPS)
33684   NODE_NAME_CASE(COMI)
33685   NODE_NAME_CASE(UCOMI)
33686   NODE_NAME_CASE(CMPM)
33687   NODE_NAME_CASE(CMPMM)
33688   NODE_NAME_CASE(STRICT_CMPM)
33689   NODE_NAME_CASE(CMPMM_SAE)
33690   NODE_NAME_CASE(SETCC)
33691   NODE_NAME_CASE(SETCC_CARRY)
33692   NODE_NAME_CASE(FSETCC)
33693   NODE_NAME_CASE(FSETCCM)
33694   NODE_NAME_CASE(FSETCCM_SAE)
33695   NODE_NAME_CASE(CMOV)
33696   NODE_NAME_CASE(BRCOND)
33697   NODE_NAME_CASE(RET_GLUE)
33698   NODE_NAME_CASE(IRET)
33699   NODE_NAME_CASE(REP_STOS)
33700   NODE_NAME_CASE(REP_MOVS)
33701   NODE_NAME_CASE(GlobalBaseReg)
33702   NODE_NAME_CASE(Wrapper)
33703   NODE_NAME_CASE(WrapperRIP)
33704   NODE_NAME_CASE(MOVQ2DQ)
33705   NODE_NAME_CASE(MOVDQ2Q)
33706   NODE_NAME_CASE(MMX_MOVD2W)
33707   NODE_NAME_CASE(MMX_MOVW2D)
33708   NODE_NAME_CASE(PEXTRB)
33709   NODE_NAME_CASE(PEXTRW)
33710   NODE_NAME_CASE(INSERTPS)
33711   NODE_NAME_CASE(PINSRB)
33712   NODE_NAME_CASE(PINSRW)
33713   NODE_NAME_CASE(PSHUFB)
33714   NODE_NAME_CASE(ANDNP)
33715   NODE_NAME_CASE(BLENDI)
33716   NODE_NAME_CASE(BLENDV)
33717   NODE_NAME_CASE(HADD)
33718   NODE_NAME_CASE(HSUB)
33719   NODE_NAME_CASE(FHADD)
33720   NODE_NAME_CASE(FHSUB)
33721   NODE_NAME_CASE(CONFLICT)
33722   NODE_NAME_CASE(FMAX)
33723   NODE_NAME_CASE(FMAXS)
33724   NODE_NAME_CASE(FMAX_SAE)
33725   NODE_NAME_CASE(FMAXS_SAE)
33726   NODE_NAME_CASE(FMIN)
33727   NODE_NAME_CASE(FMINS)
33728   NODE_NAME_CASE(FMIN_SAE)
33729   NODE_NAME_CASE(FMINS_SAE)
33730   NODE_NAME_CASE(FMAXC)
33731   NODE_NAME_CASE(FMINC)
33732   NODE_NAME_CASE(FRSQRT)
33733   NODE_NAME_CASE(FRCP)
33734   NODE_NAME_CASE(EXTRQI)
33735   NODE_NAME_CASE(INSERTQI)
33736   NODE_NAME_CASE(TLSADDR)
33737   NODE_NAME_CASE(TLSBASEADDR)
33738   NODE_NAME_CASE(TLSCALL)
33739   NODE_NAME_CASE(TLSDESC)
33740   NODE_NAME_CASE(EH_SJLJ_SETJMP)
33741   NODE_NAME_CASE(EH_SJLJ_LONGJMP)
33742   NODE_NAME_CASE(EH_SJLJ_SETUP_DISPATCH)
33743   NODE_NAME_CASE(EH_RETURN)
33744   NODE_NAME_CASE(TC_RETURN)
33745   NODE_NAME_CASE(FNSTCW16m)
33746   NODE_NAME_CASE(FLDCW16m)
33747   NODE_NAME_CASE(FNSTENVm)
33748   NODE_NAME_CASE(FLDENVm)
33749   NODE_NAME_CASE(LCMPXCHG_DAG)
33750   NODE_NAME_CASE(LCMPXCHG8_DAG)
33751   NODE_NAME_CASE(LCMPXCHG16_DAG)
33752   NODE_NAME_CASE(LCMPXCHG16_SAVE_RBX_DAG)
33753   NODE_NAME_CASE(LADD)
33754   NODE_NAME_CASE(LSUB)
33755   NODE_NAME_CASE(LOR)
33756   NODE_NAME_CASE(LXOR)
33757   NODE_NAME_CASE(LAND)
33758   NODE_NAME_CASE(LBTS)
33759   NODE_NAME_CASE(LBTC)
33760   NODE_NAME_CASE(LBTR)
33761   NODE_NAME_CASE(LBTS_RM)
33762   NODE_NAME_CASE(LBTC_RM)
33763   NODE_NAME_CASE(LBTR_RM)
33764   NODE_NAME_CASE(AADD)
33765   NODE_NAME_CASE(AOR)
33766   NODE_NAME_CASE(AXOR)
33767   NODE_NAME_CASE(AAND)
33768   NODE_NAME_CASE(VZEXT_MOVL)
33769   NODE_NAME_CASE(VZEXT_LOAD)
33770   NODE_NAME_CASE(VEXTRACT_STORE)
33771   NODE_NAME_CASE(VTRUNC)
33772   NODE_NAME_CASE(VTRUNCS)
33773   NODE_NAME_CASE(VTRUNCUS)
33774   NODE_NAME_CASE(VMTRUNC)
33775   NODE_NAME_CASE(VMTRUNCS)
33776   NODE_NAME_CASE(VMTRUNCUS)
33777   NODE_NAME_CASE(VTRUNCSTORES)
33778   NODE_NAME_CASE(VTRUNCSTOREUS)
33779   NODE_NAME_CASE(VMTRUNCSTORES)
33780   NODE_NAME_CASE(VMTRUNCSTOREUS)
33781   NODE_NAME_CASE(VFPEXT)
33782   NODE_NAME_CASE(STRICT_VFPEXT)
33783   NODE_NAME_CASE(VFPEXT_SAE)
33784   NODE_NAME_CASE(VFPEXTS)
33785   NODE_NAME_CASE(VFPEXTS_SAE)
33786   NODE_NAME_CASE(VFPROUND)
33787   NODE_NAME_CASE(STRICT_VFPROUND)
33788   NODE_NAME_CASE(VMFPROUND)
33789   NODE_NAME_CASE(VFPROUND_RND)
33790   NODE_NAME_CASE(VFPROUNDS)
33791   NODE_NAME_CASE(VFPROUNDS_RND)
33792   NODE_NAME_CASE(VSHLDQ)
33793   NODE_NAME_CASE(VSRLDQ)
33794   NODE_NAME_CASE(VSHL)
33795   NODE_NAME_CASE(VSRL)
33796   NODE_NAME_CASE(VSRA)
33797   NODE_NAME_CASE(VSHLI)
33798   NODE_NAME_CASE(VSRLI)
33799   NODE_NAME_CASE(VSRAI)
33800   NODE_NAME_CASE(VSHLV)
33801   NODE_NAME_CASE(VSRLV)
33802   NODE_NAME_CASE(VSRAV)
33803   NODE_NAME_CASE(VROTLI)
33804   NODE_NAME_CASE(VROTRI)
33805   NODE_NAME_CASE(VPPERM)
33806   NODE_NAME_CASE(CMPP)
33807   NODE_NAME_CASE(STRICT_CMPP)
33808   NODE_NAME_CASE(PCMPEQ)
33809   NODE_NAME_CASE(PCMPGT)
33810   NODE_NAME_CASE(PHMINPOS)
33811   NODE_NAME_CASE(ADD)
33812   NODE_NAME_CASE(SUB)
33813   NODE_NAME_CASE(ADC)
33814   NODE_NAME_CASE(SBB)
33815   NODE_NAME_CASE(SMUL)
33816   NODE_NAME_CASE(UMUL)
33817   NODE_NAME_CASE(OR)
33818   NODE_NAME_CASE(XOR)
33819   NODE_NAME_CASE(AND)
33820   NODE_NAME_CASE(BEXTR)
33821   NODE_NAME_CASE(BEXTRI)
33822   NODE_NAME_CASE(BZHI)
33823   NODE_NAME_CASE(PDEP)
33824   NODE_NAME_CASE(PEXT)
33825   NODE_NAME_CASE(MUL_IMM)
33826   NODE_NAME_CASE(MOVMSK)
33827   NODE_NAME_CASE(PTEST)
33828   NODE_NAME_CASE(TESTP)
33829   NODE_NAME_CASE(KORTEST)
33830   NODE_NAME_CASE(KTEST)
33831   NODE_NAME_CASE(KADD)
33832   NODE_NAME_CASE(KSHIFTL)
33833   NODE_NAME_CASE(KSHIFTR)
33834   NODE_NAME_CASE(PACKSS)
33835   NODE_NAME_CASE(PACKUS)
33836   NODE_NAME_CASE(PALIGNR)
33837   NODE_NAME_CASE(VALIGN)
33838   NODE_NAME_CASE(VSHLD)
33839   NODE_NAME_CASE(VSHRD)
33840   NODE_NAME_CASE(VSHLDV)
33841   NODE_NAME_CASE(VSHRDV)
33842   NODE_NAME_CASE(PSHUFD)
33843   NODE_NAME_CASE(PSHUFHW)
33844   NODE_NAME_CASE(PSHUFLW)
33845   NODE_NAME_CASE(SHUFP)
33846   NODE_NAME_CASE(SHUF128)
33847   NODE_NAME_CASE(MOVLHPS)
33848   NODE_NAME_CASE(MOVHLPS)
33849   NODE_NAME_CASE(MOVDDUP)
33850   NODE_NAME_CASE(MOVSHDUP)
33851   NODE_NAME_CASE(MOVSLDUP)
33852   NODE_NAME_CASE(MOVSD)
33853   NODE_NAME_CASE(MOVSS)
33854   NODE_NAME_CASE(MOVSH)
33855   NODE_NAME_CASE(UNPCKL)
33856   NODE_NAME_CASE(UNPCKH)
33857   NODE_NAME_CASE(VBROADCAST)
33858   NODE_NAME_CASE(VBROADCAST_LOAD)
33859   NODE_NAME_CASE(VBROADCASTM)
33860   NODE_NAME_CASE(SUBV_BROADCAST_LOAD)
33861   NODE_NAME_CASE(VPERMILPV)
33862   NODE_NAME_CASE(VPERMILPI)
33863   NODE_NAME_CASE(VPERM2X128)
33864   NODE_NAME_CASE(VPERMV)
33865   NODE_NAME_CASE(VPERMV3)
33866   NODE_NAME_CASE(VPERMI)
33867   NODE_NAME_CASE(VPTERNLOG)
33868   NODE_NAME_CASE(VFIXUPIMM)
33869   NODE_NAME_CASE(VFIXUPIMM_SAE)
33870   NODE_NAME_CASE(VFIXUPIMMS)
33871   NODE_NAME_CASE(VFIXUPIMMS_SAE)
33872   NODE_NAME_CASE(VRANGE)
33873   NODE_NAME_CASE(VRANGE_SAE)
33874   NODE_NAME_CASE(VRANGES)
33875   NODE_NAME_CASE(VRANGES_SAE)
33876   NODE_NAME_CASE(PMULUDQ)
33877   NODE_NAME_CASE(PMULDQ)
33878   NODE_NAME_CASE(PSADBW)
33879   NODE_NAME_CASE(DBPSADBW)
33880   NODE_NAME_CASE(VASTART_SAVE_XMM_REGS)
33881   NODE_NAME_CASE(VAARG_64)
33882   NODE_NAME_CASE(VAARG_X32)
33883   NODE_NAME_CASE(DYN_ALLOCA)
33884   NODE_NAME_CASE(MFENCE)
33885   NODE_NAME_CASE(SEG_ALLOCA)
33886   NODE_NAME_CASE(PROBED_ALLOCA)
33887   NODE_NAME_CASE(RDRAND)
33888   NODE_NAME_CASE(RDSEED)
33889   NODE_NAME_CASE(RDPKRU)
33890   NODE_NAME_CASE(WRPKRU)
33891   NODE_NAME_CASE(VPMADDUBSW)
33892   NODE_NAME_CASE(VPMADDWD)
33893   NODE_NAME_CASE(VPSHA)
33894   NODE_NAME_CASE(VPSHL)
33895   NODE_NAME_CASE(VPCOM)
33896   NODE_NAME_CASE(VPCOMU)
33897   NODE_NAME_CASE(VPERMIL2)
33898   NODE_NAME_CASE(FMSUB)
33899   NODE_NAME_CASE(STRICT_FMSUB)
33900   NODE_NAME_CASE(FNMADD)
33901   NODE_NAME_CASE(STRICT_FNMADD)
33902   NODE_NAME_CASE(FNMSUB)
33903   NODE_NAME_CASE(STRICT_FNMSUB)
33904   NODE_NAME_CASE(FMADDSUB)
33905   NODE_NAME_CASE(FMSUBADD)
33906   NODE_NAME_CASE(FMADD_RND)
33907   NODE_NAME_CASE(FNMADD_RND)
33908   NODE_NAME_CASE(FMSUB_RND)
33909   NODE_NAME_CASE(FNMSUB_RND)
33910   NODE_NAME_CASE(FMADDSUB_RND)
33911   NODE_NAME_CASE(FMSUBADD_RND)
33912   NODE_NAME_CASE(VFMADDC)
33913   NODE_NAME_CASE(VFMADDC_RND)
33914   NODE_NAME_CASE(VFCMADDC)
33915   NODE_NAME_CASE(VFCMADDC_RND)
33916   NODE_NAME_CASE(VFMULC)
33917   NODE_NAME_CASE(VFMULC_RND)
33918   NODE_NAME_CASE(VFCMULC)
33919   NODE_NAME_CASE(VFCMULC_RND)
33920   NODE_NAME_CASE(VFMULCSH)
33921   NODE_NAME_CASE(VFMULCSH_RND)
33922   NODE_NAME_CASE(VFCMULCSH)
33923   NODE_NAME_CASE(VFCMULCSH_RND)
33924   NODE_NAME_CASE(VFMADDCSH)
33925   NODE_NAME_CASE(VFMADDCSH_RND)
33926   NODE_NAME_CASE(VFCMADDCSH)
33927   NODE_NAME_CASE(VFCMADDCSH_RND)
33928   NODE_NAME_CASE(VPMADD52H)
33929   NODE_NAME_CASE(VPMADD52L)
33930   NODE_NAME_CASE(VRNDSCALE)
33931   NODE_NAME_CASE(STRICT_VRNDSCALE)
33932   NODE_NAME_CASE(VRNDSCALE_SAE)
33933   NODE_NAME_CASE(VRNDSCALES)
33934   NODE_NAME_CASE(VRNDSCALES_SAE)
33935   NODE_NAME_CASE(VREDUCE)
33936   NODE_NAME_CASE(VREDUCE_SAE)
33937   NODE_NAME_CASE(VREDUCES)
33938   NODE_NAME_CASE(VREDUCES_SAE)
33939   NODE_NAME_CASE(VGETMANT)
33940   NODE_NAME_CASE(VGETMANT_SAE)
33941   NODE_NAME_CASE(VGETMANTS)
33942   NODE_NAME_CASE(VGETMANTS_SAE)
33943   NODE_NAME_CASE(PCMPESTR)
33944   NODE_NAME_CASE(PCMPISTR)
33945   NODE_NAME_CASE(XTEST)
33946   NODE_NAME_CASE(COMPRESS)
33947   NODE_NAME_CASE(EXPAND)
33948   NODE_NAME_CASE(SELECTS)
33949   NODE_NAME_CASE(ADDSUB)
33950   NODE_NAME_CASE(RCP14)
33951   NODE_NAME_CASE(RCP14S)
33952   NODE_NAME_CASE(RSQRT14)
33953   NODE_NAME_CASE(RSQRT14S)
33954   NODE_NAME_CASE(FADD_RND)
33955   NODE_NAME_CASE(FADDS)
33956   NODE_NAME_CASE(FADDS_RND)
33957   NODE_NAME_CASE(FSUB_RND)
33958   NODE_NAME_CASE(FSUBS)
33959   NODE_NAME_CASE(FSUBS_RND)
33960   NODE_NAME_CASE(FMUL_RND)
33961   NODE_NAME_CASE(FMULS)
33962   NODE_NAME_CASE(FMULS_RND)
33963   NODE_NAME_CASE(FDIV_RND)
33964   NODE_NAME_CASE(FDIVS)
33965   NODE_NAME_CASE(FDIVS_RND)
33966   NODE_NAME_CASE(FSQRT_RND)
33967   NODE_NAME_CASE(FSQRTS)
33968   NODE_NAME_CASE(FSQRTS_RND)
33969   NODE_NAME_CASE(FGETEXP)
33970   NODE_NAME_CASE(FGETEXP_SAE)
33971   NODE_NAME_CASE(FGETEXPS)
33972   NODE_NAME_CASE(FGETEXPS_SAE)
33973   NODE_NAME_CASE(SCALEF)
33974   NODE_NAME_CASE(SCALEF_RND)
33975   NODE_NAME_CASE(SCALEFS)
33976   NODE_NAME_CASE(SCALEFS_RND)
33977   NODE_NAME_CASE(MULHRS)
33978   NODE_NAME_CASE(SINT_TO_FP_RND)
33979   NODE_NAME_CASE(UINT_TO_FP_RND)
33980   NODE_NAME_CASE(CVTTP2SI)
33981   NODE_NAME_CASE(CVTTP2UI)
33982   NODE_NAME_CASE(STRICT_CVTTP2SI)
33983   NODE_NAME_CASE(STRICT_CVTTP2UI)
33984   NODE_NAME_CASE(MCVTTP2SI)
33985   NODE_NAME_CASE(MCVTTP2UI)
33986   NODE_NAME_CASE(CVTTP2SI_SAE)
33987   NODE_NAME_CASE(CVTTP2UI_SAE)
33988   NODE_NAME_CASE(CVTTS2SI)
33989   NODE_NAME_CASE(CVTTS2UI)
33990   NODE_NAME_CASE(CVTTS2SI_SAE)
33991   NODE_NAME_CASE(CVTTS2UI_SAE)
33992   NODE_NAME_CASE(CVTSI2P)
33993   NODE_NAME_CASE(CVTUI2P)
33994   NODE_NAME_CASE(STRICT_CVTSI2P)
33995   NODE_NAME_CASE(STRICT_CVTUI2P)
33996   NODE_NAME_CASE(MCVTSI2P)
33997   NODE_NAME_CASE(MCVTUI2P)
33998   NODE_NAME_CASE(VFPCLASS)
33999   NODE_NAME_CASE(VFPCLASSS)
34000   NODE_NAME_CASE(MULTISHIFT)
34001   NODE_NAME_CASE(SCALAR_SINT_TO_FP)
34002   NODE_NAME_CASE(SCALAR_SINT_TO_FP_RND)
34003   NODE_NAME_CASE(SCALAR_UINT_TO_FP)
34004   NODE_NAME_CASE(SCALAR_UINT_TO_FP_RND)
34005   NODE_NAME_CASE(CVTPS2PH)
34006   NODE_NAME_CASE(STRICT_CVTPS2PH)
34007   NODE_NAME_CASE(CVTPS2PH_SAE)
34008   NODE_NAME_CASE(MCVTPS2PH)
34009   NODE_NAME_CASE(MCVTPS2PH_SAE)
34010   NODE_NAME_CASE(CVTPH2PS)
34011   NODE_NAME_CASE(STRICT_CVTPH2PS)
34012   NODE_NAME_CASE(CVTPH2PS_SAE)
34013   NODE_NAME_CASE(CVTP2SI)
34014   NODE_NAME_CASE(CVTP2UI)
34015   NODE_NAME_CASE(MCVTP2SI)
34016   NODE_NAME_CASE(MCVTP2UI)
34017   NODE_NAME_CASE(CVTP2SI_RND)
34018   NODE_NAME_CASE(CVTP2UI_RND)
34019   NODE_NAME_CASE(CVTS2SI)
34020   NODE_NAME_CASE(CVTS2UI)
34021   NODE_NAME_CASE(CVTS2SI_RND)
34022   NODE_NAME_CASE(CVTS2UI_RND)
34023   NODE_NAME_CASE(CVTNE2PS2BF16)
34024   NODE_NAME_CASE(CVTNEPS2BF16)
34025   NODE_NAME_CASE(MCVTNEPS2BF16)
34026   NODE_NAME_CASE(DPBF16PS)
34027   NODE_NAME_CASE(LWPINS)
34028   NODE_NAME_CASE(MGATHER)
34029   NODE_NAME_CASE(MSCATTER)
34030   NODE_NAME_CASE(VPDPBUSD)
34031   NODE_NAME_CASE(VPDPBUSDS)
34032   NODE_NAME_CASE(VPDPWSSD)
34033   NODE_NAME_CASE(VPDPWSSDS)
34034   NODE_NAME_CASE(VPSHUFBITQMB)
34035   NODE_NAME_CASE(GF2P8MULB)
34036   NODE_NAME_CASE(GF2P8AFFINEQB)
34037   NODE_NAME_CASE(GF2P8AFFINEINVQB)
34038   NODE_NAME_CASE(NT_CALL)
34039   NODE_NAME_CASE(NT_BRIND)
34040   NODE_NAME_CASE(UMWAIT)
34041   NODE_NAME_CASE(TPAUSE)
34042   NODE_NAME_CASE(ENQCMD)
34043   NODE_NAME_CASE(ENQCMDS)
34044   NODE_NAME_CASE(VP2INTERSECT)
34045   NODE_NAME_CASE(VPDPBSUD)
34046   NODE_NAME_CASE(VPDPBSUDS)
34047   NODE_NAME_CASE(VPDPBUUD)
34048   NODE_NAME_CASE(VPDPBUUDS)
34049   NODE_NAME_CASE(VPDPBSSD)
34050   NODE_NAME_CASE(VPDPBSSDS)
34051   NODE_NAME_CASE(AESENC128KL)
34052   NODE_NAME_CASE(AESDEC128KL)
34053   NODE_NAME_CASE(AESENC256KL)
34054   NODE_NAME_CASE(AESDEC256KL)
34055   NODE_NAME_CASE(AESENCWIDE128KL)
34056   NODE_NAME_CASE(AESDECWIDE128KL)
34057   NODE_NAME_CASE(AESENCWIDE256KL)
34058   NODE_NAME_CASE(AESDECWIDE256KL)
34059   NODE_NAME_CASE(CMPCCXADD)
34060   NODE_NAME_CASE(TESTUI)
34061   NODE_NAME_CASE(FP80_ADD)
34062   NODE_NAME_CASE(STRICT_FP80_ADD)
34063   NODE_NAME_CASE(CCMP)
34064   NODE_NAME_CASE(CTEST)
34065   NODE_NAME_CASE(CLOAD)
34066   NODE_NAME_CASE(CSTORE)
34067   }
34068   return nullptr;
34069 #undef NODE_NAME_CASE
34070 }
34071 
34072 /// Return true if the addressing mode represented by AM is legal for this
34073 /// target, for a load/store of the specified type.
34074 bool X86TargetLowering::isLegalAddressingMode(const DataLayout &DL,
34075                                               const AddrMode &AM, Type *Ty,
34076                                               unsigned AS,
34077                                               Instruction *I) const {
34078   // X86 supports extremely general addressing modes.
34079   CodeModel::Model M = getTargetMachine().getCodeModel();
34080 
34081   // X86 allows a sign-extended 32-bit immediate field as a displacement.
34082   if (!X86::isOffsetSuitableForCodeModel(AM.BaseOffs, M, AM.BaseGV != nullptr))
34083     return false;
34084 
34085   if (AM.BaseGV) {
34086     unsigned GVFlags = Subtarget.classifyGlobalReference(AM.BaseGV);
34087 
34088     // If a reference to this global requires an extra load, we can't fold it.
34089     if (isGlobalStubReference(GVFlags))
34090       return false;
34091 
34092     // If BaseGV requires a register for the PIC base, we cannot also have a
34093     // BaseReg specified.
34094     if (AM.HasBaseReg && isGlobalRelativeToPICBase(GVFlags))
34095       return false;
34096 
34097     // If lower 4G is not available, then we must use rip-relative addressing.
34098     if ((M != CodeModel::Small || isPositionIndependent()) &&
34099         Subtarget.is64Bit() && (AM.BaseOffs || AM.Scale > 1))
34100       return false;
34101   }
34102 
34103   switch (AM.Scale) {
34104   case 0:
34105   case 1:
34106   case 2:
34107   case 4:
34108   case 8:
34109     // These scales always work.
34110     break;
34111   case 3:
34112   case 5:
34113   case 9:
34114     // These scales are formed with basereg+scalereg.  Only accept if there is
34115     // no basereg yet.
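    // (For example, a scale of 3 is matched as base + index*2 with the same
    // register used for base and index, so the base register slot must still
    // be free.)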
34116     if (AM.HasBaseReg)
34117       return false;
34118     break;
34119   default:  // Other stuff never works.
34120     return false;
34121   }
34122 
34123   return true;
34124 }
34125 
34126 bool X86TargetLowering::isVectorShiftByScalarCheap(Type *Ty) const {
34127   unsigned Bits = Ty->getScalarSizeInBits();
34128 
34129   // XOP has v16i8/v8i16/v4i32/v2i64 variable vector shifts.
34130   // Splitting for v32i8/v16i16 on XOP+AVX2 targets is still preferred.
34131   if (Subtarget.hasXOP() &&
34132       (Bits == 8 || Bits == 16 || Bits == 32 || Bits == 64))
34133     return false;
34134 
34135   // AVX2 has vpsllv[dq] instructions (and other shifts) that make variable
34136   // shifts just as cheap as scalar ones.
34137   if (Subtarget.hasAVX2() && (Bits == 32 || Bits == 64))
34138     return false;
34139 
34140   // AVX512BW has shifts such as vpsllvw.
34141   if (Subtarget.hasBWI() && Bits == 16)
34142     return false;
34143 
34144   // Otherwise, it's significantly cheaper to shift by a scalar amount than by a
34145   // fully general vector.
34146   return true;
34147 }
34148 
34149 bool X86TargetLowering::isBinOp(unsigned Opcode) const {
34150   switch (Opcode) {
34151   // These are non-commutative binops.
34152   // TODO: Add more X86ISD opcodes once we have test coverage.
34153   case X86ISD::ANDNP:
34154   case X86ISD::PCMPGT:
34155   case X86ISD::FMAX:
34156   case X86ISD::FMIN:
34157   case X86ISD::FANDN:
34158   case X86ISD::VPSHA:
34159   case X86ISD::VPSHL:
34160   case X86ISD::VSHLV:
34161   case X86ISD::VSRLV:
34162   case X86ISD::VSRAV:
34163     return true;
34164   }
34165 
34166   return TargetLoweringBase::isBinOp(Opcode);
34167 }
34168 
34169 bool X86TargetLowering::isCommutativeBinOp(unsigned Opcode) const {
34170   switch (Opcode) {
34171   // TODO: Add more X86ISD opcodes once we have test coverage.
34172   case X86ISD::PCMPEQ:
34173   case X86ISD::PMULDQ:
34174   case X86ISD::PMULUDQ:
34175   case X86ISD::FMAXC:
34176   case X86ISD::FMINC:
34177   case X86ISD::FAND:
34178   case X86ISD::FOR:
34179   case X86ISD::FXOR:
34180     return true;
34181   }
34182 
34183   return TargetLoweringBase::isCommutativeBinOp(Opcode);
34184 }
34185 
34186 bool X86TargetLowering::isTruncateFree(Type *Ty1, Type *Ty2) const {
34187   if (!Ty1->isIntegerTy() || !Ty2->isIntegerTy())
34188     return false;
34189   unsigned NumBits1 = Ty1->getPrimitiveSizeInBits();
34190   unsigned NumBits2 = Ty2->getPrimitiveSizeInBits();
34191   return NumBits1 > NumBits2;
34192 }
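
// For example, truncating i64 to i32 is free on x86-64: the result is just
// the low 32-bit subregister (%eax instead of %rax), so this returns true
// for (i64, i32) and false for the widening direction (i32, i64).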
34193 
34194 bool X86TargetLowering::allowTruncateForTailCall(Type *Ty1, Type *Ty2) const {
34195   if (!Ty1->isIntegerTy() || !Ty2->isIntegerTy())
34196     return false;
34197 
34198   if (!isTypeLegal(EVT::getEVT(Ty1)))
34199     return false;
34200 
34201   assert(Ty1->getPrimitiveSizeInBits() <= 64 && "i128 is probably not a noop");
34202 
34203   // Assuming the caller doesn't have a zeroext or signext return parameter,
34204   // truncation all the way down to i1 is valid.
34205   return true;
34206 }
34207 
34208 bool X86TargetLowering::isLegalICmpImmediate(int64_t Imm) const {
34209   return isInt<32>(Imm);
34210 }
34211 
34212 bool X86TargetLowering::isLegalAddImmediate(int64_t Imm) const {
34213   // Can also use sub to handle negated immediates.
34214   return isInt<32>(Imm);
34215 }
34216 
34217 bool X86TargetLowering::isLegalStoreImmediate(int64_t Imm) const {
34218   return isInt<32>(Imm);
34219 }
34220 
34221 bool X86TargetLowering::isTruncateFree(EVT VT1, EVT VT2) const {
34222   if (!VT1.isScalarInteger() || !VT2.isScalarInteger())
34223     return false;
34224   unsigned NumBits1 = VT1.getSizeInBits();
34225   unsigned NumBits2 = VT2.getSizeInBits();
34226   return NumBits1 > NumBits2;
34227 }
34228 
34229 bool X86TargetLowering::isZExtFree(Type *Ty1, Type *Ty2) const {
34230   // x86-64 implicitly zero-extends 32-bit results in 64-bit registers.
34231   return Ty1->isIntegerTy(32) && Ty2->isIntegerTy(64) && Subtarget.is64Bit();
34232 }
34233 
34234 bool X86TargetLowering::isZExtFree(EVT VT1, EVT VT2) const {
34235   // x86-64 implicitly zero-extends 32-bit results in 64-bit registers.
34236   return VT1 == MVT::i32 && VT2 == MVT::i64 && Subtarget.is64Bit();
34237 }
34238 
34239 bool X86TargetLowering::isZExtFree(SDValue Val, EVT VT2) const {
34240   EVT VT1 = Val.getValueType();
34241   if (isZExtFree(VT1, VT2))
34242     return true;
34243 
34244   if (Val.getOpcode() != ISD::LOAD)
34245     return false;
34246 
34247   if (!VT1.isSimple() || !VT1.isInteger() ||
34248       !VT2.isSimple() || !VT2.isInteger())
34249     return false;
34250 
34251   switch (VT1.getSimpleVT().SimpleTy) {
34252   default: break;
34253   case MVT::i8:
34254   case MVT::i16:
34255   case MVT::i32:
34256     // X86 has 8, 16, and 32-bit zero-extending loads.
34257     return true;
34258   }
34259 
34260   return false;
34261 }
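
// For example, (zext (load i16)) is free because it selects to movzwl, and a
// plain 32-bit load already zero-extends into the full 64-bit register, so
// widening a loaded i8/i16/i32 value costs no extra instruction.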
34262 
34263 bool X86TargetLowering::shouldSinkOperands(Instruction *I,
34264                                            SmallVectorImpl<Use *> &Ops) const {
34265   using namespace llvm::PatternMatch;
34266 
34267   FixedVectorType *VTy = dyn_cast<FixedVectorType>(I->getType());
34268   if (!VTy)
34269     return false;
34270 
34271   if (I->getOpcode() == Instruction::Mul &&
34272       VTy->getElementType()->isIntegerTy(64)) {
34273     for (auto &Op : I->operands()) {
34274       // Make sure we are not already sinking this operand
34275       if (any_of(Ops, [&](Use *U) { return U->get() == Op; }))
34276         continue;
34277 
34278       // Look for PMULDQ pattern where the input is a sext_inreg from vXi32 or
34279       // the PMULUDQ pattern where the input is a zext_inreg from vXi32.
34280       if (Subtarget.hasSSE41() &&
34281           match(Op.get(), m_AShr(m_Shl(m_Value(), m_SpecificInt(32)),
34282                                  m_SpecificInt(32)))) {
34283         Ops.push_back(&cast<Instruction>(Op)->getOperandUse(0));
34284         Ops.push_back(&Op);
34285       } else if (Subtarget.hasSSE2() &&
34286                  match(Op.get(),
34287                        m_And(m_Value(), m_SpecificInt(UINT64_C(0xffffffff))))) {
34288         Ops.push_back(&Op);
34289       }
34290     }
34291 
34292     return !Ops.empty();
34293   }
34294 
34295   // A uniform shift amount in a vector shift or funnel shift may be much
34296   // cheaper than a generic variable vector shift, so make that pattern visible
34297   // to SDAG by sinking the shuffle instruction next to the shift.
34298   int ShiftAmountOpNum = -1;
34299   if (I->isShift())
34300     ShiftAmountOpNum = 1;
34301   else if (auto *II = dyn_cast<IntrinsicInst>(I)) {
34302     if (II->getIntrinsicID() == Intrinsic::fshl ||
34303         II->getIntrinsicID() == Intrinsic::fshr)
34304       ShiftAmountOpNum = 2;
34305   }
34306 
34307   if (ShiftAmountOpNum == -1)
34308     return false;
34309 
34310   auto *Shuf = dyn_cast<ShuffleVectorInst>(I->getOperand(ShiftAmountOpNum));
34311   if (Shuf && getSplatIndex(Shuf->getShuffleMask()) >= 0 &&
34312       isVectorShiftByScalarCheap(I->getType())) {
34313     Ops.push_back(&I->getOperandUse(ShiftAmountOpNum));
34314     return true;
34315   }
34316 
34317   return false;
34318 }
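
// Hypothetical IR for the multiply case above: with SSE2, both masked
// operands are sunk into the block of the multiply so SelectionDAG can
// select PMULUDQ there:
//
//   %ma = and <2 x i64> %a, <i64 4294967295, i64 4294967295>
//   %mb = and <2 x i64> %b, <i64 4294967295, i64 4294967295>
//   %r  = mul <2 x i64> %ma, %mb
//
// Similarly, a shufflevector splat feeding a shift amount is sunk next to the
// shift so it can be selected as a shift-by-scalar (see
// isVectorShiftByScalarCheap above).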
34319 
34320 bool X86TargetLowering::shouldConvertPhiType(Type *From, Type *To) const {
34321   if (!Subtarget.is64Bit())
34322     return false;
34323   return TargetLowering::shouldConvertPhiType(From, To);
34324 }
34325 
34326 bool X86TargetLowering::isVectorLoadExtDesirable(SDValue ExtVal) const {
34327   if (isa<MaskedLoadSDNode>(ExtVal.getOperand(0)))
34328     return false;
34329 
34330   EVT SrcVT = ExtVal.getOperand(0).getValueType();
34331 
34332   // There is no extending load for vXi1.
34333   if (SrcVT.getScalarType() == MVT::i1)
34334     return false;
34335 
34336   return true;
34337 }
34338 
34339 bool X86TargetLowering::isFMAFasterThanFMulAndFAdd(const MachineFunction &MF,
34340                                                    EVT VT) const {
34341   if (!Subtarget.hasAnyFMA())
34342     return false;
34343 
34344   VT = VT.getScalarType();
34345 
34346   if (!VT.isSimple())
34347     return false;
34348 
34349   switch (VT.getSimpleVT().SimpleTy) {
34350   case MVT::f16:
34351     return Subtarget.hasFP16();
34352   case MVT::f32:
34353   case MVT::f64:
34354     return true;
34355   default:
34356     break;
34357   }
34358 
34359   return false;
34360 }
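
// For example, with FMA (or AVX512-FP16 for half precision) a pattern such as
// (fadd (fmul %a, %b), %c) on f32/f64 can be selected as a single fused
// multiply-add (e.g. vfmadd213ss/sd), which is why those scalar types report
// true here.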
34361 
34362 bool X86TargetLowering::isNarrowingProfitable(EVT SrcVT, EVT DestVT) const {
34363   // i16 instructions are longer (0x66 prefix) and potentially slower.
34364   return !(SrcVT == MVT::i32 && DestVT == MVT::i16);
34365 }
34366 
34367 bool X86TargetLowering::shouldFoldSelectWithIdentityConstant(unsigned Opcode,
34368                                                              EVT VT) const {
34369   // TODO: This is too general. There are cases where pre-AVX512 codegen would
34370   //       benefit. The transform may also be profitable for scalar code.
34371   if (!Subtarget.hasAVX512())
34372     return false;
34373   if (!Subtarget.hasVLX() && !VT.is512BitVector())
34374     return false;
34375   if (!VT.isVector() || VT.getScalarType() == MVT::i1)
34376     return false;
34377 
34378   return true;
34379 }
34380 
34381 /// Targets can use this to indicate that they only support *some*
34382 /// VECTOR_SHUFFLE operations, those with specific masks.
34383 /// By default, if a target supports the VECTOR_SHUFFLE node, all mask values
34384 /// are assumed to be legal.
34385 bool X86TargetLowering::isShuffleMaskLegal(ArrayRef<int> Mask, EVT VT) const {
34386   if (!VT.isSimple())
34387     return false;
34388 
34389   // Not for i1 vectors
34390   if (VT.getSimpleVT().getScalarType() == MVT::i1)
34391     return false;
34392 
34393   // Very little shuffling can be done for 64-bit vectors right now.
34394   if (VT.getSimpleVT().getSizeInBits() == 64)
34395     return false;
34396 
34397   // We only care that the types being shuffled are legal. The lowering can
34398   // handle any possible shuffle mask that results.
34399   return isTypeLegal(VT.getSimpleVT());
34400 }
34401 
34402 bool X86TargetLowering::isVectorClearMaskLegal(ArrayRef<int> Mask,
34403                                                EVT VT) const {
34404   // Don't convert an 'and' into a shuffle that we don't directly support.
34405   // vpblendw and vpshufb for 256-bit vectors are not available on AVX1.
34406   if (!Subtarget.hasAVX2())
34407     if (VT == MVT::v32i8 || VT == MVT::v16i16)
34408       return false;
34409 
34410   // Just delegate to the generic legality, clear masks aren't special.
34411   return isShuffleMaskLegal(Mask, VT);
34412 }
34413 
34414 bool X86TargetLowering::areJTsAllowed(const Function *Fn) const {
34415   // If the subtarget is using thunks, we must not generate jump tables.
34416   if (Subtarget.useIndirectThunkBranches())
34417     return false;
34418 
34419   // Otherwise, fall back on the generic logic.
34420   return TargetLowering::areJTsAllowed(Fn);
34421 }
34422 
34423 MVT X86TargetLowering::getPreferredSwitchConditionType(LLVMContext &Context,
34424                                                        EVT ConditionVT) const {
34425   // Avoid 8- and 16-bit types because they increase the chance of unnecessary
34426   // zero-extensions.
34427   if (ConditionVT.getSizeInBits() < 32)
34428     return MVT::i32;
34429   return TargetLoweringBase::getPreferredSwitchConditionType(Context,
34430                                                              ConditionVT);
34431 }
34432 
34433 //===----------------------------------------------------------------------===//
34434 //                           X86 Scheduler Hooks
34435 //===----------------------------------------------------------------------===//
34436 
34437 // Returns true if EFLAGS is consumed after this iterator in the rest of the
34438 // basic block or any successors of the basic block.
34439 static bool isEFLAGSLiveAfter(MachineBasicBlock::iterator Itr,
34440                               MachineBasicBlock *BB) {
34441   // Scan forward through BB for a use/def of EFLAGS.
34442   for (const MachineInstr &mi : llvm::make_range(std::next(Itr), BB->end())) {
34443     if (mi.readsRegister(X86::EFLAGS, /*TRI=*/nullptr))
34444       return true;
34445     // If we found a def, we can stop searching.
34446     if (mi.definesRegister(X86::EFLAGS, /*TRI=*/nullptr))
34447       return false;
34448   }
34449 
34450   // If we hit the end of the block, check whether EFLAGS is live into a
34451   // successor.
34452   for (MachineBasicBlock *Succ : BB->successors())
34453     if (Succ->isLiveIn(X86::EFLAGS))
34454       return true;
34455 
34456   return false;
34457 }
34458 
34459 /// Utility function to emit xbegin specifying the start of an RTM region.
34460 static MachineBasicBlock *emitXBegin(MachineInstr &MI, MachineBasicBlock *MBB,
34461                                      const TargetInstrInfo *TII) {
34462   const MIMetadata MIMD(MI);
34463 
34464   const BasicBlock *BB = MBB->getBasicBlock();
34465   MachineFunction::iterator I = ++MBB->getIterator();
34466 
34467   // For the v = xbegin(), we generate
34468   //
34469   // thisMBB:
34470   //  xbegin sinkMBB
34471   //
34472   // mainMBB:
34473   //  s0 = -1
34474   //
34475   // fallBB:
34476   //  eax = # XABORT_DEF
34477   //  s1 = eax
34478   //
34479   // sinkMBB:
34480   //  v = phi(s0/mainBB, s1/fallBB)
34481 
34482   MachineBasicBlock *thisMBB = MBB;
34483   MachineFunction *MF = MBB->getParent();
34484   MachineBasicBlock *mainMBB = MF->CreateMachineBasicBlock(BB);
34485   MachineBasicBlock *fallMBB = MF->CreateMachineBasicBlock(BB);
34486   MachineBasicBlock *sinkMBB = MF->CreateMachineBasicBlock(BB);
34487   MF->insert(I, mainMBB);
34488   MF->insert(I, fallMBB);
34489   MF->insert(I, sinkMBB);
34490 
34491   if (isEFLAGSLiveAfter(MI, MBB)) {
34492     mainMBB->addLiveIn(X86::EFLAGS);
34493     fallMBB->addLiveIn(X86::EFLAGS);
34494     sinkMBB->addLiveIn(X86::EFLAGS);
34495   }
34496 
34497   // Transfer the remainder of BB and its successor edges to sinkMBB.
34498   sinkMBB->splice(sinkMBB->begin(), MBB,
34499                   std::next(MachineBasicBlock::iterator(MI)), MBB->end());
34500   sinkMBB->transferSuccessorsAndUpdatePHIs(MBB);
34501 
34502   MachineRegisterInfo &MRI = MF->getRegInfo();
34503   Register DstReg = MI.getOperand(0).getReg();
34504   const TargetRegisterClass *RC = MRI.getRegClass(DstReg);
34505   Register mainDstReg = MRI.createVirtualRegister(RC);
34506   Register fallDstReg = MRI.createVirtualRegister(RC);
34507 
34508   // thisMBB:
34509   //  xbegin fallMBB
34510   //  # fallthrough to mainMBB
34511   //  # abort branches to fallMBB
34512   BuildMI(thisMBB, MIMD, TII->get(X86::XBEGIN_4)).addMBB(fallMBB);
34513   thisMBB->addSuccessor(mainMBB);
34514   thisMBB->addSuccessor(fallMBB);
34515 
34516   // mainMBB:
34517   //  mainDstReg := -1
34518   BuildMI(mainMBB, MIMD, TII->get(X86::MOV32ri), mainDstReg).addImm(-1);
34519   BuildMI(mainMBB, MIMD, TII->get(X86::JMP_1)).addMBB(sinkMBB);
34520   mainMBB->addSuccessor(sinkMBB);
34521 
34522   // fallMBB:
34523   //  ; pseudo instruction to model hardware's definition from XABORT
34524   //  EAX := XABORT_DEF
34525   //  fallDstReg := EAX
34526   BuildMI(fallMBB, MIMD, TII->get(X86::XABORT_DEF));
34527   BuildMI(fallMBB, MIMD, TII->get(TargetOpcode::COPY), fallDstReg)
34528       .addReg(X86::EAX);
34529   fallMBB->addSuccessor(sinkMBB);
34530 
34531   // sinkMBB:
34532   //  DstReg := phi(mainDstReg/mainBB, fallDstReg/fallBB)
34533   BuildMI(*sinkMBB, sinkMBB->begin(), MIMD, TII->get(X86::PHI), DstReg)
34534       .addReg(mainDstReg).addMBB(mainMBB)
34535       .addReg(fallDstReg).addMBB(fallMBB);
34536 
34537   MI.eraseFromParent();
34538   return sinkMBB;
34539 }
34540 
34541 MachineBasicBlock *
34542 X86TargetLowering::EmitVAARGWithCustomInserter(MachineInstr &MI,
34543                                                MachineBasicBlock *MBB) const {
34544   // Emit va_arg instruction on X86-64.
34545 
34546   // Operands to this pseudo-instruction:
34547   // 0  ) Output        : destination address (reg)
34548   // 1-5) Input         : va_list address (addr, i64mem)
34549   // 6  ) ArgSize       : Size (in bytes) of vararg type
34550   // 7  ) ArgMode       : 0=overflow only, 1=use gp_offset, 2=use fp_offset
34551   // 8  ) Align         : Alignment of type
34552   // 9  ) EFLAGS (implicit-def)
34553 
34554   assert(MI.getNumOperands() == 10 && "VAARG should have 10 operands!");
34555   static_assert(X86::AddrNumOperands == 5, "VAARG assumes 5 address operands");
34556 
34557   Register DestReg = MI.getOperand(0).getReg();
34558   MachineOperand &Base = MI.getOperand(1);
34559   MachineOperand &Scale = MI.getOperand(2);
34560   MachineOperand &Index = MI.getOperand(3);
34561   MachineOperand &Disp = MI.getOperand(4);
34562   MachineOperand &Segment = MI.getOperand(5);
34563   unsigned ArgSize = MI.getOperand(6).getImm();
34564   unsigned ArgMode = MI.getOperand(7).getImm();
34565   Align Alignment = Align(MI.getOperand(8).getImm());
34566 
34567   MachineFunction *MF = MBB->getParent();
34568 
34569   // Memory Reference
34570   assert(MI.hasOneMemOperand() && "Expected VAARG to have one memoperand");
34571 
34572   MachineMemOperand *OldMMO = MI.memoperands().front();
34573 
34574   // Clone the MMO into two separate MMOs for loading and storing
34575   MachineMemOperand *LoadOnlyMMO = MF->getMachineMemOperand(
34576       OldMMO, OldMMO->getFlags() & ~MachineMemOperand::MOStore);
34577   MachineMemOperand *StoreOnlyMMO = MF->getMachineMemOperand(
34578       OldMMO, OldMMO->getFlags() & ~MachineMemOperand::MOLoad);
34579 
34580   // Machine Information
34581   const TargetInstrInfo *TII = Subtarget.getInstrInfo();
34582   MachineRegisterInfo &MRI = MBB->getParent()->getRegInfo();
34583   const TargetRegisterClass *AddrRegClass =
34584       getRegClassFor(getPointerTy(MBB->getParent()->getDataLayout()));
34585   const TargetRegisterClass *OffsetRegClass = getRegClassFor(MVT::i32);
34586   const MIMetadata MIMD(MI);
34587 
34588   // struct va_list {
34589   //   i32   gp_offset
34590   //   i32   fp_offset
34591   //   i64   overflow_area (address)
34592   //   i64   reg_save_area (address)
34593   // }
34594   // sizeof(va_list) = 24
34595   // alignment(va_list) = 8
34596 
34597   unsigned TotalNumIntRegs = 6;
34598   unsigned TotalNumXMMRegs = 8;
34599   bool UseGPOffset = (ArgMode == 1);
34600   bool UseFPOffset = (ArgMode == 2);
34601   unsigned MaxOffset = TotalNumIntRegs * 8 +
34602                        (UseFPOffset ? TotalNumXMMRegs * 16 : 0);
34603 
34604   // Align ArgSize to a multiple of 8.
34605   unsigned ArgSizeA8 = (ArgSize + 7) & ~7;
34606   bool NeedsAlign = (Alignment > 8);
34607 
34608   MachineBasicBlock *thisMBB = MBB;
34609   MachineBasicBlock *overflowMBB;
34610   MachineBasicBlock *offsetMBB;
34611   MachineBasicBlock *endMBB;
34612 
34613   unsigned OffsetDestReg = 0;    // Argument address computed by offsetMBB
34614   unsigned OverflowDestReg = 0;  // Argument address computed by overflowMBB
34615   unsigned OffsetReg = 0;
34616 
34617   if (!UseGPOffset && !UseFPOffset) {
34618     // If we only pull from the overflow region, we don't create a branch.
34619     // We don't need to alter control flow.
34620     OffsetDestReg = 0; // unused
34621     OverflowDestReg = DestReg;
34622 
34623     offsetMBB = nullptr;
34624     overflowMBB = thisMBB;
34625     endMBB = thisMBB;
34626   } else {
34627     // First emit code to check if gp_offset (or fp_offset) is below the bound.
34628     // If so, pull the argument from reg_save_area. (branch to offsetMBB)
34629     // If not, pull from overflow_area. (branch to overflowMBB)
34630     //
34631     //       thisMBB
34632     //         |     .
34633     //         |        .
34634     //     offsetMBB   overflowMBB
34635     //         |        .
34636     //         |     .
34637     //        endMBB
34638 
34639     // Registers for the PHI in endMBB
34640     OffsetDestReg = MRI.createVirtualRegister(AddrRegClass);
34641     OverflowDestReg = MRI.createVirtualRegister(AddrRegClass);
34642 
34643     const BasicBlock *LLVM_BB = MBB->getBasicBlock();
34644     overflowMBB = MF->CreateMachineBasicBlock(LLVM_BB);
34645     offsetMBB = MF->CreateMachineBasicBlock(LLVM_BB);
34646     endMBB = MF->CreateMachineBasicBlock(LLVM_BB);
34647 
34648     MachineFunction::iterator MBBIter = ++MBB->getIterator();
34649 
34650     // Insert the new basic blocks
34651     MF->insert(MBBIter, offsetMBB);
34652     MF->insert(MBBIter, overflowMBB);
34653     MF->insert(MBBIter, endMBB);
34654 
34655     // Transfer the remainder of MBB and its successor edges to endMBB.
34656     endMBB->splice(endMBB->begin(), thisMBB,
34657                    std::next(MachineBasicBlock::iterator(MI)), thisMBB->end());
34658     endMBB->transferSuccessorsAndUpdatePHIs(thisMBB);
34659 
34660     // Make offsetMBB and overflowMBB successors of thisMBB
34661     thisMBB->addSuccessor(offsetMBB);
34662     thisMBB->addSuccessor(overflowMBB);
34663 
34664     // endMBB is a successor of both offsetMBB and overflowMBB
34665     offsetMBB->addSuccessor(endMBB);
34666     overflowMBB->addSuccessor(endMBB);
34667 
34668     // Load the offset value into a register
34669     OffsetReg = MRI.createVirtualRegister(OffsetRegClass);
34670     BuildMI(thisMBB, MIMD, TII->get(X86::MOV32rm), OffsetReg)
34671         .add(Base)
34672         .add(Scale)
34673         .add(Index)
34674         .addDisp(Disp, UseFPOffset ? 4 : 0)
34675         .add(Segment)
34676         .setMemRefs(LoadOnlyMMO);
34677 
34678     // Check if there is enough room left to pull this argument.
34679     BuildMI(thisMBB, MIMD, TII->get(X86::CMP32ri))
34680       .addReg(OffsetReg)
34681       .addImm(MaxOffset + 8 - ArgSizeA8);
34682 
34683     // Branch to "overflowMBB" if offset >= max
34684     // Fall through to "offsetMBB" otherwise
34685     BuildMI(thisMBB, MIMD, TII->get(X86::JCC_1))
34686       .addMBB(overflowMBB).addImm(X86::COND_AE);
34687   }
34688 
34689   // In offsetMBB, emit code to use the reg_save_area.
34690   if (offsetMBB) {
34691     assert(OffsetReg != 0);
34692 
34693     // Read the reg_save_area address.
34694     Register RegSaveReg = MRI.createVirtualRegister(AddrRegClass);
34695     BuildMI(
34696         offsetMBB, MIMD,
34697         TII->get(Subtarget.isTarget64BitLP64() ? X86::MOV64rm : X86::MOV32rm),
34698         RegSaveReg)
34699         .add(Base)
34700         .add(Scale)
34701         .add(Index)
34702         .addDisp(Disp, Subtarget.isTarget64BitLP64() ? 16 : 12)
34703         .add(Segment)
34704         .setMemRefs(LoadOnlyMMO);
34705 
34706     if (Subtarget.isTarget64BitLP64()) {
34707       // Zero-extend the offset
34708       Register OffsetReg64 = MRI.createVirtualRegister(AddrRegClass);
34709       BuildMI(offsetMBB, MIMD, TII->get(X86::SUBREG_TO_REG), OffsetReg64)
34710           .addImm(0)
34711           .addReg(OffsetReg)
34712           .addImm(X86::sub_32bit);
34713 
34714       // Add the offset to the reg_save_area to get the final address.
34715       BuildMI(offsetMBB, MIMD, TII->get(X86::ADD64rr), OffsetDestReg)
34716           .addReg(OffsetReg64)
34717           .addReg(RegSaveReg);
34718     } else {
34719       // Add the offset to the reg_save_area to get the final address.
34720       BuildMI(offsetMBB, MIMD, TII->get(X86::ADD32rr), OffsetDestReg)
34721           .addReg(OffsetReg)
34722           .addReg(RegSaveReg);
34723     }
34724 
34725     // Compute the offset for the next argument
34726     Register NextOffsetReg = MRI.createVirtualRegister(OffsetRegClass);
34727     BuildMI(offsetMBB, MIMD, TII->get(X86::ADD32ri), NextOffsetReg)
34728       .addReg(OffsetReg)
34729       .addImm(UseFPOffset ? 16 : 8);
34730 
34731     // Store it back into the va_list.
34732     BuildMI(offsetMBB, MIMD, TII->get(X86::MOV32mr))
34733         .add(Base)
34734         .add(Scale)
34735         .add(Index)
34736         .addDisp(Disp, UseFPOffset ? 4 : 0)
34737         .add(Segment)
34738         .addReg(NextOffsetReg)
34739         .setMemRefs(StoreOnlyMMO);
34740 
34741     // Jump to endMBB
34742     BuildMI(offsetMBB, MIMD, TII->get(X86::JMP_1))
34743       .addMBB(endMBB);
34744   }
34745 
34746   //
34747   // Emit code to use overflow area
34748   //
34749 
34750   // Load the overflow_area address into a register.
34751   Register OverflowAddrReg = MRI.createVirtualRegister(AddrRegClass);
34752   BuildMI(overflowMBB, MIMD,
34753           TII->get(Subtarget.isTarget64BitLP64() ? X86::MOV64rm : X86::MOV32rm),
34754           OverflowAddrReg)
34755       .add(Base)
34756       .add(Scale)
34757       .add(Index)
34758       .addDisp(Disp, 8)
34759       .add(Segment)
34760       .setMemRefs(LoadOnlyMMO);
34761 
34762   // If we need to align it, do so. Otherwise, just copy the address
34763   // to OverflowDestReg.
34764   if (NeedsAlign) {
34765     // Align the overflow address
34766     Register TmpReg = MRI.createVirtualRegister(AddrRegClass);
34767 
34768     // aligned_addr = (addr + (align-1)) & ~(align-1)
34769     BuildMI(
34770         overflowMBB, MIMD,
34771         TII->get(Subtarget.isTarget64BitLP64() ? X86::ADD64ri32 : X86::ADD32ri),
34772         TmpReg)
34773         .addReg(OverflowAddrReg)
34774         .addImm(Alignment.value() - 1);
34775 
34776     BuildMI(
34777         overflowMBB, MIMD,
34778         TII->get(Subtarget.isTarget64BitLP64() ? X86::AND64ri32 : X86::AND32ri),
34779         OverflowDestReg)
34780         .addReg(TmpReg)
34781         .addImm(~(uint64_t)(Alignment.value() - 1));
34782   } else {
34783     BuildMI(overflowMBB, MIMD, TII->get(TargetOpcode::COPY), OverflowDestReg)
34784       .addReg(OverflowAddrReg);
34785   }
34786 
34787   // Compute the next overflow address after this argument.
34788   // (the overflow address should be kept 8-byte aligned)
34789   Register NextAddrReg = MRI.createVirtualRegister(AddrRegClass);
34790   BuildMI(
34791       overflowMBB, MIMD,
34792       TII->get(Subtarget.isTarget64BitLP64() ? X86::ADD64ri32 : X86::ADD32ri),
34793       NextAddrReg)
34794       .addReg(OverflowDestReg)
34795       .addImm(ArgSizeA8);
34796 
34797   // Store the new overflow address.
34798   BuildMI(overflowMBB, MIMD,
34799           TII->get(Subtarget.isTarget64BitLP64() ? X86::MOV64mr : X86::MOV32mr))
34800       .add(Base)
34801       .add(Scale)
34802       .add(Index)
34803       .addDisp(Disp, 8)
34804       .add(Segment)
34805       .addReg(NextAddrReg)
34806       .setMemRefs(StoreOnlyMMO);
34807 
34808   // If we branched, emit the PHI to the front of endMBB.
34809   if (offsetMBB) {
34810     BuildMI(*endMBB, endMBB->begin(), MIMD,
34811             TII->get(X86::PHI), DestReg)
34812       .addReg(OffsetDestReg).addMBB(offsetMBB)
34813       .addReg(OverflowDestReg).addMBB(overflowMBB);
34814   }
34815 
34816   // Erase the pseudo instruction
34817   MI.eraseFromParent();
34818 
34819   return endMBB;
34820 }
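
// Worked example (illustration): for va_arg(ap, int), ArgMode is 1
// (gp_offset), ArgSize is 4 so ArgSizeA8 is 8, and MaxOffset is 6*8 = 48.
// The compare emitted above is gp_offset vs. 48 + 8 - 8 = 48: if
// gp_offset >= 48 all six integer argument registers have been consumed and
// the value comes from overflow_area; otherwise it is read from
// reg_save_area + gp_offset and gp_offset is bumped by 8.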
34821 
34822 // The EFLAGS operand of SelectItr might be missing a kill marker
34823 // because there were multiple uses of EFLAGS, and ISel didn't know
34824 // which to mark. Figure out whether SelectItr should have had a
34825 // kill marker, and set it if it should. Returns the correct kill
34826 // marker value.
34827 static bool checkAndUpdateEFLAGSKill(MachineBasicBlock::iterator SelectItr,
34828                                      MachineBasicBlock* BB,
34829                                      const TargetRegisterInfo* TRI) {
34830   if (isEFLAGSLiveAfter(SelectItr, BB))
34831     return false;
34832 
34833   // We found a def, or hit the end of the basic block and EFLAGS wasn't live
34834   // out. SelectMI should have a kill flag on EFLAGS.
34835   SelectItr->addRegisterKilled(X86::EFLAGS, TRI);
34836   return true;
34837 }
34838 
34839 // Return true if it is OK for this CMOV pseudo-opcode to be cascaded
34840 // together with other CMOV pseudo-opcodes into a single basic-block with
34841 // conditional jump around it.
34842 static bool isCMOVPseudo(MachineInstr &MI) {
34843   switch (MI.getOpcode()) {
34844   case X86::CMOV_FR16:
34845   case X86::CMOV_FR16X:
34846   case X86::CMOV_FR32:
34847   case X86::CMOV_FR32X:
34848   case X86::CMOV_FR64:
34849   case X86::CMOV_FR64X:
34850   case X86::CMOV_GR8:
34851   case X86::CMOV_GR16:
34852   case X86::CMOV_GR32:
34853   case X86::CMOV_RFP32:
34854   case X86::CMOV_RFP64:
34855   case X86::CMOV_RFP80:
34856   case X86::CMOV_VR64:
34857   case X86::CMOV_VR128:
34858   case X86::CMOV_VR128X:
34859   case X86::CMOV_VR256:
34860   case X86::CMOV_VR256X:
34861   case X86::CMOV_VR512:
34862   case X86::CMOV_VK1:
34863   case X86::CMOV_VK2:
34864   case X86::CMOV_VK4:
34865   case X86::CMOV_VK8:
34866   case X86::CMOV_VK16:
34867   case X86::CMOV_VK32:
34868   case X86::CMOV_VK64:
34869     return true;
34870 
34871   default:
34872     return false;
34873   }
34874 }
34875 
34876 // Helper function that inserts PHI functions into SinkMBB:
34877 //   %Result(i) = phi [ %FalseValue(i), FalseMBB ], [ %TrueValue(i), TrueMBB ],
34878 // where %FalseValue(i) and %TrueValue(i) are taken from the consecutive CMOVs
34879 // in the [MIItBegin, MIItEnd) range. It returns the MachineInstrBuilder for
34880 // the last PHI function inserted.
34881 static MachineInstrBuilder createPHIsForCMOVsInSinkBB(
34882     MachineBasicBlock::iterator MIItBegin, MachineBasicBlock::iterator MIItEnd,
34883     MachineBasicBlock *TrueMBB, MachineBasicBlock *FalseMBB,
34884     MachineBasicBlock *SinkMBB) {
34885   MachineFunction *MF = TrueMBB->getParent();
34886   const TargetInstrInfo *TII = MF->getSubtarget().getInstrInfo();
34887   const MIMetadata MIMD(*MIItBegin);
34888 
34889   X86::CondCode CC = X86::CondCode(MIItBegin->getOperand(3).getImm());
34890   X86::CondCode OppCC = X86::GetOppositeBranchCondition(CC);
34891 
34892   MachineBasicBlock::iterator SinkInsertionPoint = SinkMBB->begin();
34893 
34894   // As we are creating the PHIs, we have to be careful if there is more than
34895   // one.  Later CMOVs may reference the results of earlier CMOVs, but later
34896   // PHIs have to reference the individual true/false inputs from earlier PHIs.
34897   // That also means that PHI construction must work forward from earlier to
34898   // later, and that the code must maintain a mapping from each earlier PHI's
34899   // destination register to the registers that went into that PHI.
34900   DenseMap<unsigned, std::pair<unsigned, unsigned>> RegRewriteTable;
34901   MachineInstrBuilder MIB;
34902 
34903   for (MachineBasicBlock::iterator MIIt = MIItBegin; MIIt != MIItEnd; ++MIIt) {
34904     Register DestReg = MIIt->getOperand(0).getReg();
34905     Register Op1Reg = MIIt->getOperand(1).getReg();
34906     Register Op2Reg = MIIt->getOperand(2).getReg();
34907 
34908     // If this CMOV we are generating is the opposite condition from
34909     // the jump we generated, then we have to swap the operands for the
34910     // PHI that is going to be generated.
34911     if (MIIt->getOperand(3).getImm() == OppCC)
34912       std::swap(Op1Reg, Op2Reg);
34913 
34914     if (RegRewriteTable.contains(Op1Reg))
34915       Op1Reg = RegRewriteTable[Op1Reg].first;
34916 
34917     if (RegRewriteTable.contains(Op2Reg))
34918       Op2Reg = RegRewriteTable[Op2Reg].second;
34919 
34920     MIB =
34921         BuildMI(*SinkMBB, SinkInsertionPoint, MIMD, TII->get(X86::PHI), DestReg)
34922             .addReg(Op1Reg)
34923             .addMBB(FalseMBB)
34924             .addReg(Op2Reg)
34925             .addMBB(TrueMBB);
34926 
34927     // Add this PHI to the rewrite table.
34928     RegRewriteTable[DestReg] = std::make_pair(Op1Reg, Op2Reg);
34929   }
34930 
34931   return MIB;
34932 }
34933 
34934 // Lower cascaded selects of the form (SecondCMOV (FirstCMOV F, T, cc1), T, cc2).
34935 MachineBasicBlock *
34936 X86TargetLowering::EmitLoweredCascadedSelect(MachineInstr &FirstCMOV,
34937                                              MachineInstr &SecondCascadedCMOV,
34938                                              MachineBasicBlock *ThisMBB) const {
34939   const TargetInstrInfo *TII = Subtarget.getInstrInfo();
34940   const MIMetadata MIMD(FirstCMOV);
34941 
34942   // We lower cascaded CMOVs such as
34943   //
34944   //   (SecondCascadedCMOV (FirstCMOV F, T, cc1), T, cc2)
34945   //
34946   // to two successive branches.
34947   //
34948   // Without this, we would add a PHI between the two jumps, which ends up
34949   // creating a few copies all around. For instance, for
34950   //
34951   //    (sitofp (zext (fcmp une)))
34952   //
34953   // we would generate:
34954   //
34955   //         ucomiss %xmm1, %xmm0
34956   //         movss  <1.0f>, %xmm0
34957   //         movaps  %xmm0, %xmm1
34958   //         jne     .LBB5_2
34959   //         xorps   %xmm1, %xmm1
34960   // .LBB5_2:
34961   //         jp      .LBB5_4
34962   //         movaps  %xmm1, %xmm0
34963   // .LBB5_4:
34964   //         retq
34965   //
34966   // because this custom-inserter would have generated:
34967   //
34968   //   A
34969   //   | \
34970   //   |  B
34971   //   | /
34972   //   C
34973   //   | \
34974   //   |  D
34975   //   | /
34976   //   E
34977   //
34978   // A: X = ...; Y = ...
34979   // B: empty
34980   // C: Z = PHI [X, A], [Y, B]
34981   // D: empty
34982   // E: PHI [X, C], [Z, D]
34983   //
34984   // If we lower both CMOVs in a single step, we can instead generate:
34985   //
34986   //   A
34987   //   | \
34988   //   |  C
34989   //   | /|
34990   //   |/ |
34991   //   |  |
34992   //   |  D
34993   //   | /
34994   //   E
34995   //
34996   // A: X = ...; Y = ...
34997   // D: empty
34998   // E: PHI [X, A], [X, C], [Y, D]
34999   //
35000   // Which, in our sitofp/fcmp example, gives us something like:
35001   //
35002   //         ucomiss %xmm1, %xmm0
35003   //         movss  <1.0f>, %xmm0
35004   //         jne     .LBB5_4
35005   //         jp      .LBB5_4
35006   //         xorps   %xmm0, %xmm0
35007   // .LBB5_4:
35008   //         retq
35009   //
35010 
35011   // We lower cascaded CMOV into two successive branches to the same block.
35012   // EFLAGS is used by both, so mark it as live in the second.
35013   const BasicBlock *LLVM_BB = ThisMBB->getBasicBlock();
35014   MachineFunction *F = ThisMBB->getParent();
35015   MachineBasicBlock *FirstInsertedMBB = F->CreateMachineBasicBlock(LLVM_BB);
35016   MachineBasicBlock *SecondInsertedMBB = F->CreateMachineBasicBlock(LLVM_BB);
35017   MachineBasicBlock *SinkMBB = F->CreateMachineBasicBlock(LLVM_BB);
35018 
35019   MachineFunction::iterator It = ++ThisMBB->getIterator();
35020   F->insert(It, FirstInsertedMBB);
35021   F->insert(It, SecondInsertedMBB);
35022   F->insert(It, SinkMBB);
35023 
35024   // For a cascaded CMOV, we lower it to two successive branches to
35025   // the same block (SinkMBB).  EFLAGS is used by both, so mark it as live in
35026   // the FirstInsertedMBB.
35027   FirstInsertedMBB->addLiveIn(X86::EFLAGS);
35028 
35029   // If the EFLAGS register isn't dead in the terminator, then claim that it's
35030   // live into the sink and copy blocks.
35031   const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo();
35032   if (!SecondCascadedCMOV.killsRegister(X86::EFLAGS, /*TRI=*/nullptr) &&
35033       !checkAndUpdateEFLAGSKill(SecondCascadedCMOV, ThisMBB, TRI)) {
35034     SecondInsertedMBB->addLiveIn(X86::EFLAGS);
35035     SinkMBB->addLiveIn(X86::EFLAGS);
35036   }
35037 
35038   // Transfer the remainder of ThisMBB and its successor edges to SinkMBB.
35039   SinkMBB->splice(SinkMBB->begin(), ThisMBB,
35040                   std::next(MachineBasicBlock::iterator(FirstCMOV)),
35041                   ThisMBB->end());
35042   SinkMBB->transferSuccessorsAndUpdatePHIs(ThisMBB);
35043 
35044   // Fallthrough block for ThisMBB.
35045   ThisMBB->addSuccessor(FirstInsertedMBB);
35046   // The true block target of the first branch is always SinkMBB.
35047   ThisMBB->addSuccessor(SinkMBB);
35048   // Fallthrough block for FirstInsertedMBB.
35049   FirstInsertedMBB->addSuccessor(SecondInsertedMBB);
35050   // The true block for the branch of FirstInsertedMBB.
35051   FirstInsertedMBB->addSuccessor(SinkMBB);
35052   // This is fallthrough.
35053   SecondInsertedMBB->addSuccessor(SinkMBB);
35054 
35055   // Create the conditional branch instructions.
35056   X86::CondCode FirstCC = X86::CondCode(FirstCMOV.getOperand(3).getImm());
35057   BuildMI(ThisMBB, MIMD, TII->get(X86::JCC_1)).addMBB(SinkMBB).addImm(FirstCC);
35058 
35059   X86::CondCode SecondCC =
35060       X86::CondCode(SecondCascadedCMOV.getOperand(3).getImm());
35061   BuildMI(FirstInsertedMBB, MIMD, TII->get(X86::JCC_1))
35062       .addMBB(SinkMBB)
35063       .addImm(SecondCC);
35064 
35065   //  SinkMBB:
35066   //   %Result = phi [ %FalseValue, SecondInsertedMBB ], [ %TrueValue, ThisMBB ]
35067   Register DestReg = SecondCascadedCMOV.getOperand(0).getReg();
35068   Register Op1Reg = FirstCMOV.getOperand(1).getReg();
35069   Register Op2Reg = FirstCMOV.getOperand(2).getReg();
35070   MachineInstrBuilder MIB =
35071       BuildMI(*SinkMBB, SinkMBB->begin(), MIMD, TII->get(X86::PHI), DestReg)
35072           .addReg(Op1Reg)
35073           .addMBB(SecondInsertedMBB)
35074           .addReg(Op2Reg)
35075           .addMBB(ThisMBB);
35076 
35077   // SecondInsertedMBB provides the same incoming value as FirstInsertedMBB
35078   // (the True operand of the SELECT_CC/CMOV nodes).
35079   MIB.addReg(FirstCMOV.getOperand(2).getReg()).addMBB(FirstInsertedMBB);
35080 
35081   // Now remove the CMOVs.
35082   FirstCMOV.eraseFromParent();
35083   SecondCascadedCMOV.eraseFromParent();
35084 
35085   return SinkMBB;
35086 }
35087 
35088 MachineBasicBlock *
35089 X86TargetLowering::EmitLoweredSelect(MachineInstr &MI,
35090                                      MachineBasicBlock *ThisMBB) const {
35091   const TargetInstrInfo *TII = Subtarget.getInstrInfo();
35092   const MIMetadata MIMD(MI);
35093 
35094   // To "insert" a SELECT_CC instruction, we actually have to insert the
35095   // diamond control-flow pattern.  The incoming instruction knows the
35096   // destination vreg to set, the condition code register to branch on, the
35097   // true/false values to select between and a branch opcode to use.
35098 
35099   //  ThisMBB:
35100   //  ...
35101   //   TrueVal = ...
35102   //   cmpTY ccX, r1, r2
35103   //   bCC copy1MBB
35104   //   fallthrough --> FalseMBB
35105 
35106   // This code lowers all pseudo-CMOV instructions. Generally it lowers these
35107   // as described above, by inserting a BB, and then making a PHI at the join
35108   // point to select the true and false operands of the CMOV in the PHI.
35109   //
35110   // The code also handles two different cases of multiple CMOV opcodes
35111   // in a row.
35112   //
35113   // Case 1:
35114   // In this case, there are multiple CMOVs in a row, all of which are based on
35115   // the same condition setting (or the exact opposite condition setting).
35116   // In this case we can lower all the CMOVs using a single inserted BB, and
35117   // then make a number of PHIs at the join point to model the CMOVs. The only
35118   // trickiness here, is that in a case like:
35119   //
35120   // t2 = CMOV cond1 t1, f1
35121   // t3 = CMOV cond1 t2, f2
35122   //
35123   // when rewriting this into PHIs, we have to perform some renaming on the
35124   // temps since you cannot have a PHI operand refer to a PHI result earlier
35125   // in the same block.  The "simple" but wrong lowering would be:
35126   //
35127   // t2 = PHI t1(BB1), f1(BB2)
35128   // t3 = PHI t2(BB1), f2(BB2)
35129   //
35130   // but clearly t2 is not defined in BB1, so that is incorrect. The proper
35131   // renaming is to note that on the path through BB1, t2 is really just a
35132   // copy of t1, and do that renaming, properly generating:
35133   //
35134   // t2 = PHI t1(BB1), f1(BB2)
35135   // t3 = PHI t1(BB1), f2(BB2)
35136   //
35137   // Case 2:
35138   // CMOV ((CMOV F, T, cc1), T, cc2) is checked here and handled by a separate
35139   // function - EmitLoweredCascadedSelect.
35140 
35141   X86::CondCode CC = X86::CondCode(MI.getOperand(3).getImm());
35142   X86::CondCode OppCC = X86::GetOppositeBranchCondition(CC);
35143   MachineInstr *LastCMOV = &MI;
35144   MachineBasicBlock::iterator NextMIIt = MachineBasicBlock::iterator(MI);
35145 
35146   // Check for case 1, where there are multiple CMOVs with the same condition
35147   // first.  Of the two cases of multiple CMOV lowerings, case 1 reduces the
35148   // number of jumps the most.
35149 
35150   if (isCMOVPseudo(MI)) {
35151     // See if we have a string of CMOVs with the same condition. Skip over
35152     // intervening debug insts.
35153     while (NextMIIt != ThisMBB->end() && isCMOVPseudo(*NextMIIt) &&
35154            (NextMIIt->getOperand(3).getImm() == CC ||
35155             NextMIIt->getOperand(3).getImm() == OppCC)) {
35156       LastCMOV = &*NextMIIt;
35157       NextMIIt = next_nodbg(NextMIIt, ThisMBB->end());
35158     }
35159   }
35160 
35161   // Check for case 2, but only if we didn't already find case 1, as
35162   // indicated by LastCMOV still pointing at MI.
35163   if (LastCMOV == &MI && NextMIIt != ThisMBB->end() &&
35164       NextMIIt->getOpcode() == MI.getOpcode() &&
35165       NextMIIt->getOperand(2).getReg() == MI.getOperand(2).getReg() &&
35166       NextMIIt->getOperand(1).getReg() == MI.getOperand(0).getReg() &&
35167       NextMIIt->getOperand(1).isKill()) {
35168     return EmitLoweredCascadedSelect(MI, *NextMIIt, ThisMBB);
35169   }
35170 
35171   const BasicBlock *LLVM_BB = ThisMBB->getBasicBlock();
35172   MachineFunction *F = ThisMBB->getParent();
35173   MachineBasicBlock *FalseMBB = F->CreateMachineBasicBlock(LLVM_BB);
35174   MachineBasicBlock *SinkMBB = F->CreateMachineBasicBlock(LLVM_BB);
35175 
35176   MachineFunction::iterator It = ++ThisMBB->getIterator();
35177   F->insert(It, FalseMBB);
35178   F->insert(It, SinkMBB);
35179 
35180   // Set the call frame size on entry to the new basic blocks.
35181   unsigned CallFrameSize = TII->getCallFrameSizeAt(MI);
35182   FalseMBB->setCallFrameSize(CallFrameSize);
35183   SinkMBB->setCallFrameSize(CallFrameSize);
35184 
35185   // If the EFLAGS register isn't dead in the terminator, then claim that it's
35186   // live into the sink and copy blocks.
35187   const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo();
35188   if (!LastCMOV->killsRegister(X86::EFLAGS, /*TRI=*/nullptr) &&
35189       !checkAndUpdateEFLAGSKill(LastCMOV, ThisMBB, TRI)) {
35190     FalseMBB->addLiveIn(X86::EFLAGS);
35191     SinkMBB->addLiveIn(X86::EFLAGS);
35192   }
35193 
35194   // Transfer any debug instructions inside the CMOV sequence to the sunk block.
35195   auto DbgRange = llvm::make_range(MachineBasicBlock::iterator(MI),
35196                                    MachineBasicBlock::iterator(LastCMOV));
35197   for (MachineInstr &MI : llvm::make_early_inc_range(DbgRange))
35198     if (MI.isDebugInstr())
35199       SinkMBB->push_back(MI.removeFromParent());
35200 
35201   // Transfer the remainder of ThisMBB and its successor edges to SinkMBB.
35202   SinkMBB->splice(SinkMBB->end(), ThisMBB,
35203                   std::next(MachineBasicBlock::iterator(LastCMOV)),
35204                   ThisMBB->end());
35205   SinkMBB->transferSuccessorsAndUpdatePHIs(ThisMBB);
35206 
35207   // Fallthrough block for ThisMBB.
35208   ThisMBB->addSuccessor(FalseMBB);
35209   // The true block target of the first (or only) branch is always SinkMBB.
35210   ThisMBB->addSuccessor(SinkMBB);
35211   // Fallthrough block for FalseMBB.
35212   FalseMBB->addSuccessor(SinkMBB);
35213 
35214   // Create the conditional branch instruction.
35215   BuildMI(ThisMBB, MIMD, TII->get(X86::JCC_1)).addMBB(SinkMBB).addImm(CC);
35216 
35217   //  SinkMBB:
35218   //   %Result = phi [ %FalseValue, FalseMBB ], [ %TrueValue, ThisMBB ]
35219   //  ...
35220   MachineBasicBlock::iterator MIItBegin = MachineBasicBlock::iterator(MI);
35221   MachineBasicBlock::iterator MIItEnd =
35222       std::next(MachineBasicBlock::iterator(LastCMOV));
35223   createPHIsForCMOVsInSinkBB(MIItBegin, MIItEnd, ThisMBB, FalseMBB, SinkMBB);
35224 
35225   // Now remove the CMOV(s).
35226   ThisMBB->erase(MIItBegin, MIItEnd);
35227 
35228   return SinkMBB;
35229 }
35230 
35231 static unsigned getSUBriOpcode(bool IsLP64) {
35232   if (IsLP64)
35233     return X86::SUB64ri32;
35234   else
35235     return X86::SUB32ri;
35236 }
35237 
35238 MachineBasicBlock *
35239 X86TargetLowering::EmitLoweredProbedAlloca(MachineInstr &MI,
35240                                            MachineBasicBlock *MBB) const {
35241   MachineFunction *MF = MBB->getParent();
35242   const TargetInstrInfo *TII = Subtarget.getInstrInfo();
35243   const X86FrameLowering &TFI = *Subtarget.getFrameLowering();
35244   const MIMetadata MIMD(MI);
35245   const BasicBlock *LLVM_BB = MBB->getBasicBlock();
35246 
35247   const unsigned ProbeSize = getStackProbeSize(*MF);
35248 
35249   MachineRegisterInfo &MRI = MF->getRegInfo();
35250   MachineBasicBlock *testMBB = MF->CreateMachineBasicBlock(LLVM_BB);
35251   MachineBasicBlock *tailMBB = MF->CreateMachineBasicBlock(LLVM_BB);
35252   MachineBasicBlock *blockMBB = MF->CreateMachineBasicBlock(LLVM_BB);
35253 
35254   MachineFunction::iterator MBBIter = ++MBB->getIterator();
35255   MF->insert(MBBIter, testMBB);
35256   MF->insert(MBBIter, blockMBB);
35257   MF->insert(MBBIter, tailMBB);
35258 
35259   Register sizeVReg = MI.getOperand(1).getReg();
35260 
35261   Register physSPReg = TFI.Uses64BitFramePtr ? X86::RSP : X86::ESP;
35262 
35263   Register TmpStackPtr = MRI.createVirtualRegister(
35264       TFI.Uses64BitFramePtr ? &X86::GR64RegClass : &X86::GR32RegClass);
35265   Register FinalStackPtr = MRI.createVirtualRegister(
35266       TFI.Uses64BitFramePtr ? &X86::GR64RegClass : &X86::GR32RegClass);
35267 
35268   BuildMI(*MBB, {MI}, MIMD, TII->get(TargetOpcode::COPY), TmpStackPtr)
35269       .addReg(physSPReg);
35270   {
35271     const unsigned Opc = TFI.Uses64BitFramePtr ? X86::SUB64rr : X86::SUB32rr;
35272     BuildMI(*MBB, {MI}, MIMD, TII->get(Opc), FinalStackPtr)
35273         .addReg(TmpStackPtr)
35274         .addReg(sizeVReg);
35275   }
35276 
35277   // Check whether the stack pointer has reached its final value.
35278 
35279   BuildMI(testMBB, MIMD,
35280           TII->get(TFI.Uses64BitFramePtr ? X86::CMP64rr : X86::CMP32rr))
35281       .addReg(FinalStackPtr)
35282       .addReg(physSPReg);
35283 
35284   BuildMI(testMBB, MIMD, TII->get(X86::JCC_1))
35285       .addMBB(tailMBB)
35286       .addImm(X86::COND_GE);
35287   testMBB->addSuccessor(blockMBB);
35288   testMBB->addSuccessor(tailMBB);
35289 
35290   // Touch the block, then extend it. This is the opposite order from a static
35291   // probe, where we allocate then touch; it avoids the need to probe the tail
35292   // of the static alloca. Possible scenarios are:
35293   //
35294   //       + ---- <- ------------ <- ------------- <- ------------ +
35295   //       |                                                       |
35296   // [free probe] -> [page alloc] -> [alloc probe] -> [tail alloc] + -> [dyn probe] -> [page alloc] -> [dyn probe] -> [tail alloc] +
35297   //                                                               |                                                               |
35298   //                                                               + <- ----------- <- ------------ <- ----------- <- ------------ +
35299   //
35300   // The property we want to enforce is to never have more than [page alloc] between two probes.
35301 
35302   const unsigned XORMIOpc =
35303       TFI.Uses64BitFramePtr ? X86::XOR64mi32 : X86::XOR32mi;
35304   addRegOffset(BuildMI(blockMBB, MIMD, TII->get(XORMIOpc)), physSPReg, false, 0)
35305       .addImm(0);
35306 
35307   BuildMI(blockMBB, MIMD, TII->get(getSUBriOpcode(TFI.Uses64BitFramePtr)),
35308           physSPReg)
35309       .addReg(physSPReg)
35310       .addImm(ProbeSize);
35311 
35312   BuildMI(blockMBB, MIMD, TII->get(X86::JMP_1)).addMBB(testMBB);
35313   blockMBB->addSuccessor(testMBB);
35314 
35315   // Replace the original instruction's result with the expected stack pointer.
35316   BuildMI(tailMBB, MIMD, TII->get(TargetOpcode::COPY),
35317           MI.getOperand(0).getReg())
35318       .addReg(FinalStackPtr);
35319 
35320   tailMBB->splice(tailMBB->end(), MBB,
35321                   std::next(MachineBasicBlock::iterator(MI)), MBB->end());
35322   tailMBB->transferSuccessorsAndUpdatePHIs(MBB);
35323   MBB->addSuccessor(testMBB);
35324 
35325   // Delete the original pseudo instruction.
35326   MI.eraseFromParent();
35327 
35328   // And we're done.
35329   return tailMBB;
35330 }
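
// Worked example (assuming the common 4096-byte probe size): for a dynamic
// allocation of 10000 bytes, FinalStackPtr = RSP - 10000. blockMBB runs three
// times, touching the current top of stack and dropping RSP by 4096 each time
// (RSP-4096, RSP-8192, RSP-12288); after the third probe FinalStackPtr >= RSP,
// so testMBB branches to tailMBB, which copies FinalStackPtr into the result
// register. No two probes are ever more than one page apart.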
35331 
35332 MachineBasicBlock *
35333 X86TargetLowering::EmitLoweredSegAlloca(MachineInstr &MI,
35334                                         MachineBasicBlock *BB) const {
35335   MachineFunction *MF = BB->getParent();
35336   const TargetInstrInfo *TII = Subtarget.getInstrInfo();
35337   const MIMetadata MIMD(MI);
35338   const BasicBlock *LLVM_BB = BB->getBasicBlock();
35339 
35340   assert(MF->shouldSplitStack());
35341 
35342   const bool Is64Bit = Subtarget.is64Bit();
35343   const bool IsLP64 = Subtarget.isTarget64BitLP64();
35344 
35345   const unsigned TlsReg = Is64Bit ? X86::FS : X86::GS;
35346   const unsigned TlsOffset = IsLP64 ? 0x70 : Is64Bit ? 0x40 : 0x30;
35347 
35348   // BB:
35349   //  ... [Till the alloca]
35350   // If stacklet is not large enough, jump to mallocMBB
35351   //
35352   // bumpMBB:
35353   //  Allocate by subtracting from RSP
35354   //  Jump to continueMBB
35355   //
35356   // mallocMBB:
35357   //  Allocate by call to runtime
35358   //
35359   // continueMBB:
35360   //  ...
35361   //  [rest of original BB]
35362   //
35363 
35364   MachineBasicBlock *mallocMBB = MF->CreateMachineBasicBlock(LLVM_BB);
35365   MachineBasicBlock *bumpMBB = MF->CreateMachineBasicBlock(LLVM_BB);
35366   MachineBasicBlock *continueMBB = MF->CreateMachineBasicBlock(LLVM_BB);
35367 
35368   MachineRegisterInfo &MRI = MF->getRegInfo();
35369   const TargetRegisterClass *AddrRegClass =
35370       getRegClassFor(getPointerTy(MF->getDataLayout()));
35371 
35372   Register mallocPtrVReg = MRI.createVirtualRegister(AddrRegClass),
35373            bumpSPPtrVReg = MRI.createVirtualRegister(AddrRegClass),
35374            tmpSPVReg = MRI.createVirtualRegister(AddrRegClass),
35375            SPLimitVReg = MRI.createVirtualRegister(AddrRegClass),
35376            sizeVReg = MI.getOperand(1).getReg(),
35377            physSPReg =
35378                IsLP64 || Subtarget.isTargetNaCl64() ? X86::RSP : X86::ESP;
35379 
35380   MachineFunction::iterator MBBIter = ++BB->getIterator();
35381 
35382   MF->insert(MBBIter, bumpMBB);
35383   MF->insert(MBBIter, mallocMBB);
35384   MF->insert(MBBIter, continueMBB);
35385 
35386   continueMBB->splice(continueMBB->begin(), BB,
35387                       std::next(MachineBasicBlock::iterator(MI)), BB->end());
35388   continueMBB->transferSuccessorsAndUpdatePHIs(BB);
35389 
35390   // Add code to the main basic block to check if the stack limit has been hit,
35391   // and if so, jump to mallocMBB otherwise to bumpMBB.
35392   BuildMI(BB, MIMD, TII->get(TargetOpcode::COPY), tmpSPVReg).addReg(physSPReg);
35393   BuildMI(BB, MIMD, TII->get(IsLP64 ? X86::SUB64rr:X86::SUB32rr), SPLimitVReg)
35394     .addReg(tmpSPVReg).addReg(sizeVReg);
35395   BuildMI(BB, MIMD, TII->get(IsLP64 ? X86::CMP64mr:X86::CMP32mr))
35396     .addReg(0).addImm(1).addReg(0).addImm(TlsOffset).addReg(TlsReg)
35397     .addReg(SPLimitVReg);
35398   BuildMI(BB, MIMD, TII->get(X86::JCC_1)).addMBB(mallocMBB).addImm(X86::COND_G);
35399 
35400   // bumpMBB simply decreases the stack pointer, since we know the current
35401   // stacklet has enough space.
35402   BuildMI(bumpMBB, MIMD, TII->get(TargetOpcode::COPY), physSPReg)
35403     .addReg(SPLimitVReg);
35404   BuildMI(bumpMBB, MIMD, TII->get(TargetOpcode::COPY), bumpSPPtrVReg)
35405     .addReg(SPLimitVReg);
35406   BuildMI(bumpMBB, MIMD, TII->get(X86::JMP_1)).addMBB(continueMBB);
35407 
35408   // Calls into a routine in libgcc to allocate more space from the heap.
35409   const uint32_t *RegMask =
35410       Subtarget.getRegisterInfo()->getCallPreservedMask(*MF, CallingConv::C);
35411   if (IsLP64) {
35412     BuildMI(mallocMBB, MIMD, TII->get(X86::MOV64rr), X86::RDI)
35413       .addReg(sizeVReg);
35414     BuildMI(mallocMBB, MIMD, TII->get(X86::CALL64pcrel32))
35415       .addExternalSymbol("__morestack_allocate_stack_space")
35416       .addRegMask(RegMask)
35417       .addReg(X86::RDI, RegState::Implicit)
35418       .addReg(X86::RAX, RegState::ImplicitDefine);
35419   } else if (Is64Bit) {
35420     BuildMI(mallocMBB, MIMD, TII->get(X86::MOV32rr), X86::EDI)
35421       .addReg(sizeVReg);
35422     BuildMI(mallocMBB, MIMD, TII->get(X86::CALL64pcrel32))
35423       .addExternalSymbol("__morestack_allocate_stack_space")
35424       .addRegMask(RegMask)
35425       .addReg(X86::EDI, RegState::Implicit)
35426       .addReg(X86::EAX, RegState::ImplicitDefine);
35427   } else {
35428     BuildMI(mallocMBB, MIMD, TII->get(X86::SUB32ri), physSPReg).addReg(physSPReg)
35429       .addImm(12);
35430     BuildMI(mallocMBB, MIMD, TII->get(X86::PUSH32r)).addReg(sizeVReg);
35431     BuildMI(mallocMBB, MIMD, TII->get(X86::CALLpcrel32))
35432       .addExternalSymbol("__morestack_allocate_stack_space")
35433       .addRegMask(RegMask)
35434       .addReg(X86::EAX, RegState::ImplicitDefine);
35435   }
35436 
35437   if (!Is64Bit)
35438     BuildMI(mallocMBB, MIMD, TII->get(X86::ADD32ri), physSPReg).addReg(physSPReg)
35439       .addImm(16);
35440 
35441   BuildMI(mallocMBB, MIMD, TII->get(TargetOpcode::COPY), mallocPtrVReg)
35442     .addReg(IsLP64 ? X86::RAX : X86::EAX);
35443   BuildMI(mallocMBB, MIMD, TII->get(X86::JMP_1)).addMBB(continueMBB);
35444 
35445   // Set up the CFG correctly.
35446   BB->addSuccessor(bumpMBB);
35447   BB->addSuccessor(mallocMBB);
35448   mallocMBB->addSuccessor(continueMBB);
35449   bumpMBB->addSuccessor(continueMBB);
35450 
35451   // Take care of the PHI nodes.
35452   BuildMI(*continueMBB, continueMBB->begin(), MIMD, TII->get(X86::PHI),
35453           MI.getOperand(0).getReg())
35454       .addReg(mallocPtrVReg)
35455       .addMBB(mallocMBB)
35456       .addReg(bumpSPPtrVReg)
35457       .addMBB(bumpMBB);
35458 
35459   // Delete the original pseudo instruction.
35460   MI.eraseFromParent();
35461 
35462   // And we're done.
35463   return continueMBB;
35464 }
35465 
35466 MachineBasicBlock *
35467 X86TargetLowering::EmitLoweredCatchRet(MachineInstr &MI,
35468                                        MachineBasicBlock *BB) const {
35469   MachineFunction *MF = BB->getParent();
35470   const TargetInstrInfo &TII = *Subtarget.getInstrInfo();
35471   MachineBasicBlock *TargetMBB = MI.getOperand(0).getMBB();
35472   const MIMetadata MIMD(MI);
35473 
35474   assert(!isAsynchronousEHPersonality(
35475              classifyEHPersonality(MF->getFunction().getPersonalityFn())) &&
35476          "SEH does not use catchret!");
35477 
35478   // Only 32-bit EH needs to worry about manually restoring stack pointers.
35479   if (!Subtarget.is32Bit())
35480     return BB;
35481 
35482   // C++ EH creates a new target block to hold the restore code, and wires up
35483   // the new block to the return destination with a normal JMP_4.
35484   MachineBasicBlock *RestoreMBB =
35485       MF->CreateMachineBasicBlock(BB->getBasicBlock());
35486   assert(BB->succ_size() == 1);
35487   MF->insert(std::next(BB->getIterator()), RestoreMBB);
35488   RestoreMBB->transferSuccessorsAndUpdatePHIs(BB);
35489   BB->addSuccessor(RestoreMBB);
35490   MI.getOperand(0).setMBB(RestoreMBB);
35491 
35492   // Marking this as an EH pad but not a funclet entry block causes PEI to
35493   // restore stack pointers in the block.
35494   RestoreMBB->setIsEHPad(true);
35495 
35496   auto RestoreMBBI = RestoreMBB->begin();
35497   BuildMI(*RestoreMBB, RestoreMBBI, MIMD, TII.get(X86::JMP_4)).addMBB(TargetMBB);
35498   return BB;
35499 }
35500 
35501 MachineBasicBlock *
35502 X86TargetLowering::EmitLoweredTLSAddr(MachineInstr &MI,
35503                                       MachineBasicBlock *BB) const {
35504   // So, here we replace TLSADDR with the sequence:
35505   // adjust_stackdown -> TLSADDR -> adjust_stackup.
35506   // We need this because TLSADDR is lowered into a call
35507   // inside MC; without the two markers, shrink-wrapping
35508   // may push the prologue/epilogue past them.
35509   const TargetInstrInfo &TII = *Subtarget.getInstrInfo();
35510   const MIMetadata MIMD(MI);
35511   MachineFunction &MF = *BB->getParent();
35512 
35513   // Emit CALLSEQ_START right before the instruction.
35514   MF.getFrameInfo().setAdjustsStack(true);
35515   unsigned AdjStackDown = TII.getCallFrameSetupOpcode();
35516   MachineInstrBuilder CallseqStart =
35517       BuildMI(MF, MIMD, TII.get(AdjStackDown)).addImm(0).addImm(0).addImm(0);
35518   BB->insert(MachineBasicBlock::iterator(MI), CallseqStart);
35519 
35520   // Emit CALLSEQ_END right after the instruction.
35521   // We don't call erase from parent because we want to keep the
35522   // original instruction around.
35523   unsigned AdjStackUp = TII.getCallFrameDestroyOpcode();
35524   MachineInstrBuilder CallseqEnd =
35525       BuildMI(MF, MIMD, TII.get(AdjStackUp)).addImm(0).addImm(0);
35526   BB->insertAfter(MachineBasicBlock::iterator(MI), CallseqEnd);
35527 
35528   return BB;
35529 }
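// Illustrative sketch (assumptions: 64-bit target, general-dynamic TLS; the
// symbol name is a placeholder). After this hook runs, the block contains
// roughly:
//
//   ADJCALLSTACKDOWN64 0, 0, 0        ; CALLSEQ_START marker inserted above
//   TLS_addr64 ... @x ...             ; original pseudo, becomes a call in MC
//   ADJCALLSTACKUP64 0, 0             ; CALLSEQ_END marker inserted above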
35530 
35531 MachineBasicBlock *
35532 X86TargetLowering::EmitLoweredTLSCall(MachineInstr &MI,
35533                                       MachineBasicBlock *BB) const {
35534   // This is pretty easy.  We're taking the value that we received from
35535   // our load from the relocation, sticking it in either RDI (x86-64)
35536   // or EAX and doing an indirect call.  The return value will then
35537   // be in the normal return register.
35538   MachineFunction *F = BB->getParent();
35539   const X86InstrInfo *TII = Subtarget.getInstrInfo();
35540   const MIMetadata MIMD(MI);
35541 
35542   assert(Subtarget.isTargetDarwin() && "Darwin only instr emitted?");
35543   assert(MI.getOperand(3).isGlobal() && "This should be a global");
35544 
35545   // Get a register mask for the lowered call.
35546   // FIXME: The 32-bit calls have non-standard calling conventions. Use a
35547   // proper register mask.
35548   const uint32_t *RegMask =
35549       Subtarget.is64Bit() ?
35550       Subtarget.getRegisterInfo()->getDarwinTLSCallPreservedMask() :
35551       Subtarget.getRegisterInfo()->getCallPreservedMask(*F, CallingConv::C);
35552   if (Subtarget.is64Bit()) {
35553     MachineInstrBuilder MIB =
35554         BuildMI(*BB, MI, MIMD, TII->get(X86::MOV64rm), X86::RDI)
35555             .addReg(X86::RIP)
35556             .addImm(0)
35557             .addReg(0)
35558             .addGlobalAddress(MI.getOperand(3).getGlobal(), 0,
35559                               MI.getOperand(3).getTargetFlags())
35560             .addReg(0);
35561     MIB = BuildMI(*BB, MI, MIMD, TII->get(X86::CALL64m));
35562     addDirectMem(MIB, X86::RDI);
35563     MIB.addReg(X86::RAX, RegState::ImplicitDefine).addRegMask(RegMask);
35564   } else if (!isPositionIndependent()) {
35565     MachineInstrBuilder MIB =
35566         BuildMI(*BB, MI, MIMD, TII->get(X86::MOV32rm), X86::EAX)
35567             .addReg(0)
35568             .addImm(0)
35569             .addReg(0)
35570             .addGlobalAddress(MI.getOperand(3).getGlobal(), 0,
35571                               MI.getOperand(3).getTargetFlags())
35572             .addReg(0);
35573     MIB = BuildMI(*BB, MI, MIMD, TII->get(X86::CALL32m));
35574     addDirectMem(MIB, X86::EAX);
35575     MIB.addReg(X86::EAX, RegState::ImplicitDefine).addRegMask(RegMask);
35576   } else {
35577     MachineInstrBuilder MIB =
35578         BuildMI(*BB, MI, MIMD, TII->get(X86::MOV32rm), X86::EAX)
35579             .addReg(TII->getGlobalBaseReg(F))
35580             .addImm(0)
35581             .addReg(0)
35582             .addGlobalAddress(MI.getOperand(3).getGlobal(), 0,
35583                               MI.getOperand(3).getTargetFlags())
35584             .addReg(0);
35585     MIB = BuildMI(*BB, MI, MIMD, TII->get(X86::CALL32m));
35586     addDirectMem(MIB, X86::EAX);
35587     MIB.addReg(X86::EAX, RegState::ImplicitDefine).addRegMask(RegMask);
35588   }
35589 
35590   MI.eraseFromParent(); // The pseudo instruction is gone now.
35591   return BB;
35592 }
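// Illustrative sketch of the 64-bit Darwin sequence built above (the symbol
// name is a placeholder); this is the usual TLV access pattern:
//
//   movq  _var@TLVP(%rip), %rdi     ; MOV64rm of the TLVP descriptor
//   callq *(%rdi)                   ; CALL64m through the descriptor's getter
//                                   ; result returned in %rax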
35593 
35594 static unsigned getOpcodeForIndirectThunk(unsigned RPOpc) {
35595   switch (RPOpc) {
35596   case X86::INDIRECT_THUNK_CALL32:
35597     return X86::CALLpcrel32;
35598   case X86::INDIRECT_THUNK_CALL64:
35599     return X86::CALL64pcrel32;
35600   case X86::INDIRECT_THUNK_TCRETURN32:
35601     return X86::TCRETURNdi;
35602   case X86::INDIRECT_THUNK_TCRETURN64:
35603     return X86::TCRETURNdi64;
35604   }
35605   llvm_unreachable("not indirect thunk opcode");
35606 }
35607 
35608 static const char *getIndirectThunkSymbol(const X86Subtarget &Subtarget,
35609                                           unsigned Reg) {
35610   if (Subtarget.useRetpolineExternalThunk()) {
35611     // When using an external thunk for retpolines, we pick names that match the
35612     // names GCC happens to use as well. This helps simplify the implementation
35613     // of the thunks for kernels where they have no easy ability to create
35614     // aliases and are doing non-trivial configuration of the thunk's body. For
35615     // example, the Linux kernel will do boot-time hot patching of the thunk
35616     // bodies and cannot easily export aliases of these to loaded modules.
35617     //
35618     // Note that at any point in the future, we may need to change the semantics
35619     // of how we implement retpolines and at that time will likely change the
35620     // name of the called thunk. Essentially, there is no hard guarantee that
35621     // LLVM will generate calls to specific thunks; we merely make a best-effort
35622     // attempt to help out kernels and other systems where duplicating the
35623     // thunks is costly.
35624     switch (Reg) {
35625     case X86::EAX:
35626       assert(!Subtarget.is64Bit() && "Should not be using a 32-bit thunk!");
35627       return "__x86_indirect_thunk_eax";
35628     case X86::ECX:
35629       assert(!Subtarget.is64Bit() && "Should not be using a 32-bit thunk!");
35630       return "__x86_indirect_thunk_ecx";
35631     case X86::EDX:
35632       assert(!Subtarget.is64Bit() && "Should not be using a 32-bit thunk!");
35633       return "__x86_indirect_thunk_edx";
35634     case X86::EDI:
35635       assert(!Subtarget.is64Bit() && "Should not be using a 32-bit thunk!");
35636       return "__x86_indirect_thunk_edi";
35637     case X86::R11:
35638       assert(Subtarget.is64Bit() && "Should not be using a 64-bit thunk!");
35639       return "__x86_indirect_thunk_r11";
35640     }
35641     llvm_unreachable("unexpected reg for external indirect thunk");
35642   }
35643 
35644   if (Subtarget.useRetpolineIndirectCalls() ||
35645       Subtarget.useRetpolineIndirectBranches()) {
35646     // When targeting an internal COMDAT thunk use an LLVM-specific name.
35647     switch (Reg) {
35648     case X86::EAX:
35649       assert(!Subtarget.is64Bit() && "Should not be using a 32-bit thunk!");
35650       return "__llvm_retpoline_eax";
35651     case X86::ECX:
35652       assert(!Subtarget.is64Bit() && "Should not be using a 32-bit thunk!");
35653       return "__llvm_retpoline_ecx";
35654     case X86::EDX:
35655       assert(!Subtarget.is64Bit() && "Should not be using a 32-bit thunk!");
35656       return "__llvm_retpoline_edx";
35657     case X86::EDI:
35658       assert(!Subtarget.is64Bit() && "Should not be using a 32-bit thunk!");
35659       return "__llvm_retpoline_edi";
35660     case X86::R11:
35661       assert(Subtarget.is64Bit() && "Should not be using a 64-bit thunk!");
35662       return "__llvm_retpoline_r11";
35663     }
35664     llvm_unreachable("unexpected reg for retpoline");
35665   }
35666 
35667   if (Subtarget.useLVIControlFlowIntegrity()) {
35668     assert(Subtarget.is64Bit() && "Should not be using a 64-bit thunk!");
35669     return "__llvm_lvi_thunk_r11";
35670   }
35671   llvm_unreachable("getIndirectThunkSymbol() invoked without thunk feature");
35672 }
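// For reference only (this function merely picks the symbol name): a
// retpoline thunk such as __llvm_retpoline_r11 conventionally has a body of
// the following shape; label names here are illustrative, and the actual
// thunk is emitted elsewhere (internally as a COMDAT, or externally, e.g. by
// the kernel):
//
//   __llvm_retpoline_r11:
//     callq .Lcall_target
//   .Lcapture_spec:
//     pause
//     lfence
//     jmp .Lcapture_spec
//   .Lcall_target:
//     movq %r11, (%rsp)
//     retq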
35673 
35674 MachineBasicBlock *
35675 X86TargetLowering::EmitLoweredIndirectThunk(MachineInstr &MI,
35676                                             MachineBasicBlock *BB) const {
35677   // Copy the virtual register into the R11 physical register and
35678   // call the retpoline thunk.
35679   const MIMetadata MIMD(MI);
35680   const X86InstrInfo *TII = Subtarget.getInstrInfo();
35681   Register CalleeVReg = MI.getOperand(0).getReg();
35682   unsigned Opc = getOpcodeForIndirectThunk(MI.getOpcode());
35683 
35684   // Find an available scratch register to hold the callee. On 64-bit, we can
35685   // just use R11, but we scan for uses anyway to ensure we don't generate
35686   // incorrect code. On 32-bit, we use one of EAX, ECX, or EDX that isn't
35687   // already a register use operand to the call to hold the callee. If none
35688   // are available, use EDI instead. EDI is chosen because EBX is the PIC base
35689   // register and ESI is the base pointer to realigned stack frames with VLAs.
35690   SmallVector<unsigned, 3> AvailableRegs;
35691   if (Subtarget.is64Bit())
35692     AvailableRegs.push_back(X86::R11);
35693   else
35694     AvailableRegs.append({X86::EAX, X86::ECX, X86::EDX, X86::EDI});
35695 
35696   // Zero out any registers that are already used.
35697   for (const auto &MO : MI.operands()) {
35698     if (MO.isReg() && MO.isUse())
35699       for (unsigned &Reg : AvailableRegs)
35700         if (Reg == MO.getReg())
35701           Reg = 0;
35702   }
35703 
35704   // Choose the first remaining non-zero available register.
35705   unsigned AvailableReg = 0;
35706   for (unsigned MaybeReg : AvailableRegs) {
35707     if (MaybeReg) {
35708       AvailableReg = MaybeReg;
35709       break;
35710     }
35711   }
35712   if (!AvailableReg)
35713     report_fatal_error("calling convention incompatible with retpoline, no "
35714                        "available registers");
35715 
35716   const char *Symbol = getIndirectThunkSymbol(Subtarget, AvailableReg);
35717 
35718   BuildMI(*BB, MI, MIMD, TII->get(TargetOpcode::COPY), AvailableReg)
35719       .addReg(CalleeVReg);
35720   MI.getOperand(0).ChangeToES(Symbol);
35721   MI.setDesc(TII->get(Opc));
35722   MachineInstrBuilder(*BB->getParent(), &MI)
35723       .addReg(AvailableReg, RegState::Implicit | RegState::Kill);
35724   return BB;
35725 }
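// Illustrative sketch of the rewrite performed above for the 64-bit call
// case (vreg and thunk names are examples):
//
//   before:  INDIRECT_THUNK_CALL64 %callee
//   after:   $r11 = COPY %callee
//            CALL64pcrel32 &__llvm_retpoline_r11, implicit killed $r11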
35726 
35727 /// SetJmp implies future control flow change upon calling the corresponding
35728 /// LongJmp.
35729 /// Instead of using the 'return' instruction, the long jump fixes the stack and
35730 /// performs an indirect branch. To do so it uses the registers that were stored
35731 /// in the jump buffer (when calling SetJmp).
35732 /// In case the shadow stack is enabled we need to fix it as well, because some
35733 /// return addresses will be skipped.
35734 /// The function will save the SSP for future fixing in the function
35735 /// emitLongJmpShadowStackFix.
35736 /// \sa emitLongJmpShadowStackFix
35737 /// \param [in] MI The temporary Machine Instruction for the builtin.
35738 /// \param [in] MBB The Machine Basic Block that will be modified.
35739 void X86TargetLowering::emitSetJmpShadowStackFix(MachineInstr &MI,
35740                                                  MachineBasicBlock *MBB) const {
35741   const MIMetadata MIMD(MI);
35742   MachineFunction *MF = MBB->getParent();
35743   const TargetInstrInfo *TII = Subtarget.getInstrInfo();
35744   MachineRegisterInfo &MRI = MF->getRegInfo();
35745   MachineInstrBuilder MIB;
35746 
35747   // Memory Reference.
35748   SmallVector<MachineMemOperand *, 2> MMOs(MI.memoperands_begin(),
35749                                            MI.memoperands_end());
35750 
35751   // Initialize a register with zero.
35752   MVT PVT = getPointerTy(MF->getDataLayout());
35753   const TargetRegisterClass *PtrRC = getRegClassFor(PVT);
35754   Register ZReg = MRI.createVirtualRegister(PtrRC);
35755   unsigned XorRROpc = (PVT == MVT::i64) ? X86::XOR64rr : X86::XOR32rr;
35756   BuildMI(*MBB, MI, MIMD, TII->get(XorRROpc))
35757       .addDef(ZReg)
35758       .addReg(ZReg, RegState::Undef)
35759       .addReg(ZReg, RegState::Undef);
35760 
35761   // Read the current SSP Register value to the zeroed register.
35762   Register SSPCopyReg = MRI.createVirtualRegister(PtrRC);
35763   unsigned RdsspOpc = (PVT == MVT::i64) ? X86::RDSSPQ : X86::RDSSPD;
35764   BuildMI(*MBB, MI, MIMD, TII->get(RdsspOpc), SSPCopyReg).addReg(ZReg);
35765 
35766   // Write the SSP register value to offset 3 in input memory buffer.
35767   unsigned PtrStoreOpc = (PVT == MVT::i64) ? X86::MOV64mr : X86::MOV32mr;
35768   MIB = BuildMI(*MBB, MI, MIMD, TII->get(PtrStoreOpc));
35769   const int64_t SSPOffset = 3 * PVT.getStoreSize();
35770   const unsigned MemOpndSlot = 1;
35771   for (unsigned i = 0; i < X86::AddrNumOperands; ++i) {
35772     if (i == X86::AddrDisp)
35773       MIB.addDisp(MI.getOperand(MemOpndSlot + i), SSPOffset);
35774     else
35775       MIB.add(MI.getOperand(MemOpndSlot + i));
35776   }
35777   MIB.addReg(SSPCopyReg);
35778   MIB.setMemRefs(MMOs);
35779 }
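// For reference, the jump buffer slots used by this lowering are
// pointer-sized; byte offsets below assume the 64-bit case (PVT == i64):
//
//   slot 0 (+0)   frame pointer
//   slot 1 (+8)   resume address (label of restoreMBB)
//   slot 2 (+16)  stack pointer
//   slot 3 (+24)  shadow stack pointer (written here when CET is enabled)
//
// emitEHSjLjSetJmp, emitEHSjLjLongJmp and emitLongJmpShadowStackFix all use
// the same layout.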
35780 
35781 MachineBasicBlock *
35782 X86TargetLowering::emitEHSjLjSetJmp(MachineInstr &MI,
35783                                     MachineBasicBlock *MBB) const {
35784   const MIMetadata MIMD(MI);
35785   MachineFunction *MF = MBB->getParent();
35786   const TargetInstrInfo *TII = Subtarget.getInstrInfo();
35787   const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo();
35788   MachineRegisterInfo &MRI = MF->getRegInfo();
35789 
35790   const BasicBlock *BB = MBB->getBasicBlock();
35791   MachineFunction::iterator I = ++MBB->getIterator();
35792 
35793   // Memory Reference
35794   SmallVector<MachineMemOperand *, 2> MMOs(MI.memoperands_begin(),
35795                                            MI.memoperands_end());
35796 
35797   unsigned DstReg;
35798   unsigned MemOpndSlot = 0;
35799 
35800   unsigned CurOp = 0;
35801 
35802   DstReg = MI.getOperand(CurOp++).getReg();
35803   const TargetRegisterClass *RC = MRI.getRegClass(DstReg);
35804   assert(TRI->isTypeLegalForClass(*RC, MVT::i32) && "Invalid destination!");
35805   (void)TRI;
35806   Register mainDstReg = MRI.createVirtualRegister(RC);
35807   Register restoreDstReg = MRI.createVirtualRegister(RC);
35808 
35809   MemOpndSlot = CurOp;
35810 
35811   MVT PVT = getPointerTy(MF->getDataLayout());
35812   assert((PVT == MVT::i64 || PVT == MVT::i32) &&
35813          "Invalid Pointer Size!");
35814 
35815   // For v = setjmp(buf), we generate
35816   //
35817   // thisMBB:
35818   //  buf[LabelOffset] = restoreMBB <-- takes address of restoreMBB
35819   //  SjLjSetup restoreMBB
35820   //
35821   // mainMBB:
35822   //  v_main = 0
35823   //
35824   // sinkMBB:
35825   //  v = phi(main, restore)
35826   //
35827   // restoreMBB:
35828   //  if base pointer being used, load it from frame
35829   //  v_restore = 1
35830 
35831   MachineBasicBlock *thisMBB = MBB;
35832   MachineBasicBlock *mainMBB = MF->CreateMachineBasicBlock(BB);
35833   MachineBasicBlock *sinkMBB = MF->CreateMachineBasicBlock(BB);
35834   MachineBasicBlock *restoreMBB = MF->CreateMachineBasicBlock(BB);
35835   MF->insert(I, mainMBB);
35836   MF->insert(I, sinkMBB);
35837   MF->push_back(restoreMBB);
35838   restoreMBB->setMachineBlockAddressTaken();
35839 
35840   MachineInstrBuilder MIB;
35841 
35842   // Transfer the remainder of BB and its successor edges to sinkMBB.
35843   sinkMBB->splice(sinkMBB->begin(), MBB,
35844                   std::next(MachineBasicBlock::iterator(MI)), MBB->end());
35845   sinkMBB->transferSuccessorsAndUpdatePHIs(MBB);
35846 
35847   // thisMBB:
35848   unsigned PtrStoreOpc = 0;
35849   unsigned LabelReg = 0;
35850   const int64_t LabelOffset = 1 * PVT.getStoreSize();
35851   bool UseImmLabel = (MF->getTarget().getCodeModel() == CodeModel::Small) &&
35852                      !isPositionIndependent();
35853 
35854   // Prepare IP either in reg or imm.
35855   if (!UseImmLabel) {
35856     PtrStoreOpc = (PVT == MVT::i64) ? X86::MOV64mr : X86::MOV32mr;
35857     const TargetRegisterClass *PtrRC = getRegClassFor(PVT);
35858     LabelReg = MRI.createVirtualRegister(PtrRC);
35859     if (Subtarget.is64Bit()) {
35860       MIB = BuildMI(*thisMBB, MI, MIMD, TII->get(X86::LEA64r), LabelReg)
35861               .addReg(X86::RIP)
35862               .addImm(0)
35863               .addReg(0)
35864               .addMBB(restoreMBB)
35865               .addReg(0);
35866     } else {
35867       const X86InstrInfo *XII = static_cast<const X86InstrInfo*>(TII);
35868       MIB = BuildMI(*thisMBB, MI, MIMD, TII->get(X86::LEA32r), LabelReg)
35869               .addReg(XII->getGlobalBaseReg(MF))
35870               .addImm(0)
35871               .addReg(0)
35872               .addMBB(restoreMBB, Subtarget.classifyBlockAddressReference())
35873               .addReg(0);
35874     }
35875   } else
35876     PtrStoreOpc = (PVT == MVT::i64) ? X86::MOV64mi32 : X86::MOV32mi;
35877   // Store IP
35878   MIB = BuildMI(*thisMBB, MI, MIMD, TII->get(PtrStoreOpc));
35879   for (unsigned i = 0; i < X86::AddrNumOperands; ++i) {
35880     if (i == X86::AddrDisp)
35881       MIB.addDisp(MI.getOperand(MemOpndSlot + i), LabelOffset);
35882     else
35883       MIB.add(MI.getOperand(MemOpndSlot + i));
35884   }
35885   if (!UseImmLabel)
35886     MIB.addReg(LabelReg);
35887   else
35888     MIB.addMBB(restoreMBB);
35889   MIB.setMemRefs(MMOs);
35890 
35891   if (MF->getFunction().getParent()->getModuleFlag("cf-protection-return")) {
35892     emitSetJmpShadowStackFix(MI, thisMBB);
35893   }
35894 
35895   // Setup
35896   MIB = BuildMI(*thisMBB, MI, MIMD, TII->get(X86::EH_SjLj_Setup))
35897           .addMBB(restoreMBB);
35898 
35899   const X86RegisterInfo *RegInfo = Subtarget.getRegisterInfo();
35900   MIB.addRegMask(RegInfo->getNoPreservedMask());
35901   thisMBB->addSuccessor(mainMBB);
35902   thisMBB->addSuccessor(restoreMBB);
35903 
35904   // mainMBB:
35905   //  EAX = 0
35906   BuildMI(mainMBB, MIMD, TII->get(X86::MOV32r0), mainDstReg);
35907   mainMBB->addSuccessor(sinkMBB);
35908 
35909   // sinkMBB:
35910   BuildMI(*sinkMBB, sinkMBB->begin(), MIMD, TII->get(X86::PHI), DstReg)
35911       .addReg(mainDstReg)
35912       .addMBB(mainMBB)
35913       .addReg(restoreDstReg)
35914       .addMBB(restoreMBB);
35915 
35916   // restoreMBB:
35917   if (RegInfo->hasBasePointer(*MF)) {
35918     const bool Uses64BitFramePtr =
35919         Subtarget.isTarget64BitLP64() || Subtarget.isTargetNaCl64();
35920     X86MachineFunctionInfo *X86FI = MF->getInfo<X86MachineFunctionInfo>();
35921     X86FI->setRestoreBasePointer(MF);
35922     Register FramePtr = RegInfo->getFrameRegister(*MF);
35923     Register BasePtr = RegInfo->getBaseRegister();
35924     unsigned Opm = Uses64BitFramePtr ? X86::MOV64rm : X86::MOV32rm;
35925     addRegOffset(BuildMI(restoreMBB, MIMD, TII->get(Opm), BasePtr),
35926                  FramePtr, true, X86FI->getRestoreBasePointerOffset())
35927       .setMIFlag(MachineInstr::FrameSetup);
35928   }
35929   BuildMI(restoreMBB, MIMD, TII->get(X86::MOV32ri), restoreDstReg).addImm(1);
35930   BuildMI(restoreMBB, MIMD, TII->get(X86::JMP_1)).addMBB(sinkMBB);
35931   restoreMBB->addSuccessor(sinkMBB);
35932 
35933   MI.eraseFromParent();
35934   return sinkMBB;
35935 }
35936 
35937 /// Fix the shadow stack using the previously saved SSP pointer.
35938 /// \sa emitSetJmpShadowStackFix
35939 /// \param [in] MI The temporary Machine Instruction for the builtin.
35940 /// \param [in] MBB The Machine Basic Block that will be modified.
35941 /// \return The sink MBB that will perform the future indirect branch.
35942 MachineBasicBlock *
35943 X86TargetLowering::emitLongJmpShadowStackFix(MachineInstr &MI,
35944                                              MachineBasicBlock *MBB) const {
35945   const MIMetadata MIMD(MI);
35946   MachineFunction *MF = MBB->getParent();
35947   const TargetInstrInfo *TII = Subtarget.getInstrInfo();
35948   MachineRegisterInfo &MRI = MF->getRegInfo();
35949 
35950   // Memory Reference
35951   SmallVector<MachineMemOperand *, 2> MMOs(MI.memoperands_begin(),
35952                                            MI.memoperands_end());
35953 
35954   MVT PVT = getPointerTy(MF->getDataLayout());
35955   const TargetRegisterClass *PtrRC = getRegClassFor(PVT);
35956 
35957   // checkSspMBB:
35958   //         xor vreg1, vreg1
35959   //         rdssp vreg1
35960   //         test vreg1, vreg1
35961   //         je sinkMBB   # Jump if Shadow Stack is not supported
35962   // fallMBB:
35963   //         mov buf+24/12(%rip), vreg2
35964   //         sub vreg1, vreg2
35965   //         jbe sinkMBB  # No need to fix the Shadow Stack
35966   // fixShadowMBB:
35967   //         shr 3/2, vreg2
35968   //         incssp vreg2  # fix the SSP according to the lower 8 bits
35969   //         shr 8, vreg2
35970   //         je sinkMBB
35971   // fixShadowLoopPrepareMBB:
35972   //         shl vreg2
35973   //         mov 128, vreg3
35974   // fixShadowLoopMBB:
35975   //         incssp vreg3
35976   //         dec vreg2
35977   //         jne fixShadowLoopMBB # Iterate until you finish fixing
35978   //                              # the Shadow Stack
35979   // sinkMBB:
35980 
35981   MachineFunction::iterator I = ++MBB->getIterator();
35982   const BasicBlock *BB = MBB->getBasicBlock();
35983 
35984   MachineBasicBlock *checkSspMBB = MF->CreateMachineBasicBlock(BB);
35985   MachineBasicBlock *fallMBB = MF->CreateMachineBasicBlock(BB);
35986   MachineBasicBlock *fixShadowMBB = MF->CreateMachineBasicBlock(BB);
35987   MachineBasicBlock *fixShadowLoopPrepareMBB = MF->CreateMachineBasicBlock(BB);
35988   MachineBasicBlock *fixShadowLoopMBB = MF->CreateMachineBasicBlock(BB);
35989   MachineBasicBlock *sinkMBB = MF->CreateMachineBasicBlock(BB);
35990   MF->insert(I, checkSspMBB);
35991   MF->insert(I, fallMBB);
35992   MF->insert(I, fixShadowMBB);
35993   MF->insert(I, fixShadowLoopPrepareMBB);
35994   MF->insert(I, fixShadowLoopMBB);
35995   MF->insert(I, sinkMBB);
35996 
35997   // Transfer the remainder of BB and its successor edges to sinkMBB.
35998   sinkMBB->splice(sinkMBB->begin(), MBB, MachineBasicBlock::iterator(MI),
35999                   MBB->end());
36000   sinkMBB->transferSuccessorsAndUpdatePHIs(MBB);
36001 
36002   MBB->addSuccessor(checkSspMBB);
36003 
36004   // Initialize a register with zero.
36005   Register ZReg = MRI.createVirtualRegister(&X86::GR32RegClass);
36006   BuildMI(checkSspMBB, MIMD, TII->get(X86::MOV32r0), ZReg);
36007 
36008   if (PVT == MVT::i64) {
36009     Register TmpZReg = MRI.createVirtualRegister(PtrRC);
36010     BuildMI(checkSspMBB, MIMD, TII->get(X86::SUBREG_TO_REG), TmpZReg)
36011       .addImm(0)
36012       .addReg(ZReg)
36013       .addImm(X86::sub_32bit);
36014     ZReg = TmpZReg;
36015   }
36016 
36017   // Read the current SSP Register value to the zeroed register.
36018   Register SSPCopyReg = MRI.createVirtualRegister(PtrRC);
36019   unsigned RdsspOpc = (PVT == MVT::i64) ? X86::RDSSPQ : X86::RDSSPD;
36020   BuildMI(checkSspMBB, MIMD, TII->get(RdsspOpc), SSPCopyReg).addReg(ZReg);
36021 
36022   // Check whether the SSP register value read above is zero and, if so,
36023   // jump directly to the sink.
36024   unsigned TestRROpc = (PVT == MVT::i64) ? X86::TEST64rr : X86::TEST32rr;
36025   BuildMI(checkSspMBB, MIMD, TII->get(TestRROpc))
36026       .addReg(SSPCopyReg)
36027       .addReg(SSPCopyReg);
36028   BuildMI(checkSspMBB, MIMD, TII->get(X86::JCC_1))
36029       .addMBB(sinkMBB)
36030       .addImm(X86::COND_E);
36031   checkSspMBB->addSuccessor(sinkMBB);
36032   checkSspMBB->addSuccessor(fallMBB);
36033 
36034   // Reload the previously saved SSP register value.
36035   Register PrevSSPReg = MRI.createVirtualRegister(PtrRC);
36036   unsigned PtrLoadOpc = (PVT == MVT::i64) ? X86::MOV64rm : X86::MOV32rm;
36037   const int64_t SPPOffset = 3 * PVT.getStoreSize();
36038   MachineInstrBuilder MIB =
36039       BuildMI(fallMBB, MIMD, TII->get(PtrLoadOpc), PrevSSPReg);
36040   for (unsigned i = 0; i < X86::AddrNumOperands; ++i) {
36041     const MachineOperand &MO = MI.getOperand(i);
36042     if (i == X86::AddrDisp)
36043       MIB.addDisp(MO, SPPOffset);
36044     else if (MO.isReg()) // Don't add the whole operand, we don't want to
36045                          // preserve kill flags.
36046       MIB.addReg(MO.getReg());
36047     else
36048       MIB.add(MO);
36049   }
36050   MIB.setMemRefs(MMOs);
36051 
36052   // Subtract the current SSP from the previous SSP.
36053   Register SspSubReg = MRI.createVirtualRegister(PtrRC);
36054   unsigned SubRROpc = (PVT == MVT::i64) ? X86::SUB64rr : X86::SUB32rr;
36055   BuildMI(fallMBB, MIMD, TII->get(SubRROpc), SspSubReg)
36056       .addReg(PrevSSPReg)
36057       .addReg(SSPCopyReg);
36058 
36059   // Jump to sink in case PrevSSPReg <= SSPCopyReg.
36060   BuildMI(fallMBB, MIMD, TII->get(X86::JCC_1))
36061       .addMBB(sinkMBB)
36062       .addImm(X86::COND_BE);
36063   fallMBB->addSuccessor(sinkMBB);
36064   fallMBB->addSuccessor(fixShadowMBB);
36065 
36066   // Shift right by 2/3 for 32/64 because incssp multiplies the argument by 4/8.
36067   unsigned ShrRIOpc = (PVT == MVT::i64) ? X86::SHR64ri : X86::SHR32ri;
36068   unsigned Offset = (PVT == MVT::i64) ? 3 : 2;
36069   Register SspFirstShrReg = MRI.createVirtualRegister(PtrRC);
36070   BuildMI(fixShadowMBB, MIMD, TII->get(ShrRIOpc), SspFirstShrReg)
36071       .addReg(SspSubReg)
36072       .addImm(Offset);
36073 
36074   // Increase the SSP, looking only at the lower 8 bits of the delta.
36075   unsigned IncsspOpc = (PVT == MVT::i64) ? X86::INCSSPQ : X86::INCSSPD;
36076   BuildMI(fixShadowMBB, MIMD, TII->get(IncsspOpc)).addReg(SspFirstShrReg);
36077 
36078   // Reset the lower 8 bits.
36079   Register SspSecondShrReg = MRI.createVirtualRegister(PtrRC);
36080   BuildMI(fixShadowMBB, MIMD, TII->get(ShrRIOpc), SspSecondShrReg)
36081       .addReg(SspFirstShrReg)
36082       .addImm(8);
36083 
36084   // Jump if the result of the shift is zero.
36085   BuildMI(fixShadowMBB, MIMD, TII->get(X86::JCC_1))
36086       .addMBB(sinkMBB)
36087       .addImm(X86::COND_E);
36088   fixShadowMBB->addSuccessor(sinkMBB);
36089   fixShadowMBB->addSuccessor(fixShadowLoopPrepareMBB);
36090 
36091   // Do a single shift left.
36092   unsigned ShlR1Opc = (PVT == MVT::i64) ? X86::SHL64ri : X86::SHL32ri;
36093   Register SspAfterShlReg = MRI.createVirtualRegister(PtrRC);
36094   BuildMI(fixShadowLoopPrepareMBB, MIMD, TII->get(ShlR1Opc), SspAfterShlReg)
36095       .addReg(SspSecondShrReg)
36096       .addImm(1);
36097 
36098   // Save the value 128 to a register (will be used next with incssp).
36099   Register Value128InReg = MRI.createVirtualRegister(PtrRC);
36100   unsigned MovRIOpc = (PVT == MVT::i64) ? X86::MOV64ri32 : X86::MOV32ri;
36101   BuildMI(fixShadowLoopPrepareMBB, MIMD, TII->get(MovRIOpc), Value128InReg)
36102       .addImm(128);
36103   fixShadowLoopPrepareMBB->addSuccessor(fixShadowLoopMBB);
36104 
36105   // Since incssp only looks at the lower 8 bits, we might need to do several
36106   // iterations of incssp until we finish fixing the shadow stack.
36107   Register DecReg = MRI.createVirtualRegister(PtrRC);
36108   Register CounterReg = MRI.createVirtualRegister(PtrRC);
36109   BuildMI(fixShadowLoopMBB, MIMD, TII->get(X86::PHI), CounterReg)
36110       .addReg(SspAfterShlReg)
36111       .addMBB(fixShadowLoopPrepareMBB)
36112       .addReg(DecReg)
36113       .addMBB(fixShadowLoopMBB);
36114 
36115   // Every iteration we increase the SSP by 128.
36116   BuildMI(fixShadowLoopMBB, MIMD, TII->get(IncsspOpc)).addReg(Value128InReg);
36117 
36118   // Every iteration we decrement the counter by 1.
36119   unsigned DecROpc = (PVT == MVT::i64) ? X86::DEC64r : X86::DEC32r;
36120   BuildMI(fixShadowLoopMBB, MIMD, TII->get(DecROpc), DecReg).addReg(CounterReg);
36121 
36122   // Jump if the counter is not zero yet.
36123   BuildMI(fixShadowLoopMBB, MIMD, TII->get(X86::JCC_1))
36124       .addMBB(fixShadowLoopMBB)
36125       .addImm(X86::COND_NE);
36126   fixShadowLoopMBB->addSuccessor(sinkMBB);
36127   fixShadowLoopMBB->addSuccessor(fixShadowLoopMBB);
36128 
36129   return sinkMBB;
36130 }
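// Worked example of the arithmetic above (64-bit, purely illustrative):
// if PrevSSP - CurSSP = 0x1230 bytes, then after "shr 3" the delta is 0x246
// eight-byte slots.  The first incssp consumes the low 8 bits (0x46 slots),
// "shr 8" leaves 0x2, the "shl 1" turns that into 4 loop iterations, and
// each iteration does "incssp 128", i.e. another 0x200 slots, for 0x246
// slots in total.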
36131 
36132 MachineBasicBlock *
36133 X86TargetLowering::emitEHSjLjLongJmp(MachineInstr &MI,
36134                                      MachineBasicBlock *MBB) const {
36135   const MIMetadata MIMD(MI);
36136   MachineFunction *MF = MBB->getParent();
36137   const TargetInstrInfo *TII = Subtarget.getInstrInfo();
36138   MachineRegisterInfo &MRI = MF->getRegInfo();
36139 
36140   // Memory Reference
36141   SmallVector<MachineMemOperand *, 2> MMOs(MI.memoperands_begin(),
36142                                            MI.memoperands_end());
36143 
36144   MVT PVT = getPointerTy(MF->getDataLayout());
36145   assert((PVT == MVT::i64 || PVT == MVT::i32) &&
36146          "Invalid Pointer Size!");
36147 
36148   const TargetRegisterClass *RC =
36149     (PVT == MVT::i64) ? &X86::GR64RegClass : &X86::GR32RegClass;
36150   Register Tmp = MRI.createVirtualRegister(RC);
36151   // Since FP is only updated here but NOT referenced, it's treated as GPR.
36152   const X86RegisterInfo *RegInfo = Subtarget.getRegisterInfo();
36153   Register FP = (PVT == MVT::i64) ? X86::RBP : X86::EBP;
36154   Register SP = RegInfo->getStackRegister();
36155 
36156   MachineInstrBuilder MIB;
36157 
36158   const int64_t LabelOffset = 1 * PVT.getStoreSize();
36159   const int64_t SPOffset = 2 * PVT.getStoreSize();
36160 
36161   unsigned PtrLoadOpc = (PVT == MVT::i64) ? X86::MOV64rm : X86::MOV32rm;
36162   unsigned IJmpOpc = (PVT == MVT::i64) ? X86::JMP64r : X86::JMP32r;
36163 
36164   MachineBasicBlock *thisMBB = MBB;
36165 
36166   // When CET and the shadow stack are enabled, we need to fix the Shadow Stack.
36167   if (MF->getFunction().getParent()->getModuleFlag("cf-protection-return")) {
36168     thisMBB = emitLongJmpShadowStackFix(MI, thisMBB);
36169   }
36170 
36171   // Reload FP
36172   MIB = BuildMI(*thisMBB, MI, MIMD, TII->get(PtrLoadOpc), FP);
36173   for (unsigned i = 0; i < X86::AddrNumOperands; ++i) {
36174     const MachineOperand &MO = MI.getOperand(i);
36175     if (MO.isReg()) // Don't add the whole operand, we don't want to
36176                     // preserve kill flags.
36177       MIB.addReg(MO.getReg());
36178     else
36179       MIB.add(MO);
36180   }
36181   MIB.setMemRefs(MMOs);
36182 
36183   // Reload IP
36184   MIB = BuildMI(*thisMBB, MI, MIMD, TII->get(PtrLoadOpc), Tmp);
36185   for (unsigned i = 0; i < X86::AddrNumOperands; ++i) {
36186     const MachineOperand &MO = MI.getOperand(i);
36187     if (i == X86::AddrDisp)
36188       MIB.addDisp(MO, LabelOffset);
36189     else if (MO.isReg()) // Don't add the whole operand, we don't want to
36190                          // preserve kill flags.
36191       MIB.addReg(MO.getReg());
36192     else
36193       MIB.add(MO);
36194   }
36195   MIB.setMemRefs(MMOs);
36196 
36197   // Reload SP
36198   MIB = BuildMI(*thisMBB, MI, MIMD, TII->get(PtrLoadOpc), SP);
36199   for (unsigned i = 0; i < X86::AddrNumOperands; ++i) {
36200     if (i == X86::AddrDisp)
36201       MIB.addDisp(MI.getOperand(i), SPOffset);
36202     else
36203       MIB.add(MI.getOperand(i)); // We can preserve the kill flags here, it's
36204                                  // the last instruction of the expansion.
36205   }
36206   MIB.setMemRefs(MMOs);
36207 
36208   // Jump
36209   BuildMI(*thisMBB, MI, MIMD, TII->get(IJmpOpc)).addReg(Tmp);
36210 
36211   MI.eraseFromParent();
36212   return thisMBB;
36213 }
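// Illustrative sketch of the expansion above for the 64-bit case ("buf" is
// the jump buffer address taken from the pseudo's memory operands, %tmp is
// a virtual register):
//
//   movq  0(buf), %rbp      ; reload FP
//   movq  8(buf), %tmp      ; reload IP  (LabelOffset = 1 * 8)
//   movq 16(buf), %rsp      ; reload SP  (SPOffset   = 2 * 8)
//   jmpq *%tmp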
36214 
36215 void X86TargetLowering::SetupEntryBlockForSjLj(MachineInstr &MI,
36216                                                MachineBasicBlock *MBB,
36217                                                MachineBasicBlock *DispatchBB,
36218                                                int FI) const {
36219   const MIMetadata MIMD(MI);
36220   MachineFunction *MF = MBB->getParent();
36221   MachineRegisterInfo *MRI = &MF->getRegInfo();
36222   const X86InstrInfo *TII = Subtarget.getInstrInfo();
36223 
36224   MVT PVT = getPointerTy(MF->getDataLayout());
36225   assert((PVT == MVT::i64 || PVT == MVT::i32) && "Invalid Pointer Size!");
36226 
36227   unsigned Op = 0;
36228   unsigned VR = 0;
36229 
36230   bool UseImmLabel = (MF->getTarget().getCodeModel() == CodeModel::Small) &&
36231                      !isPositionIndependent();
36232 
36233   if (UseImmLabel) {
36234     Op = (PVT == MVT::i64) ? X86::MOV64mi32 : X86::MOV32mi;
36235   } else {
36236     const TargetRegisterClass *TRC =
36237         (PVT == MVT::i64) ? &X86::GR64RegClass : &X86::GR32RegClass;
36238     VR = MRI->createVirtualRegister(TRC);
36239     Op = (PVT == MVT::i64) ? X86::MOV64mr : X86::MOV32mr;
36240 
36241     if (Subtarget.is64Bit())
36242       BuildMI(*MBB, MI, MIMD, TII->get(X86::LEA64r), VR)
36243           .addReg(X86::RIP)
36244           .addImm(1)
36245           .addReg(0)
36246           .addMBB(DispatchBB)
36247           .addReg(0);
36248     else
36249       BuildMI(*MBB, MI, MIMD, TII->get(X86::LEA32r), VR)
36250           .addReg(0) /* TII->getGlobalBaseReg(MF) */
36251           .addImm(1)
36252           .addReg(0)
36253           .addMBB(DispatchBB, Subtarget.classifyBlockAddressReference())
36254           .addReg(0);
36255   }
36256 
36257   MachineInstrBuilder MIB = BuildMI(*MBB, MI, MIMD, TII->get(Op));
36258   addFrameReference(MIB, FI, Subtarget.is64Bit() ? 56 : 36);
36259   if (UseImmLabel)
36260     MIB.addMBB(DispatchBB);
36261   else
36262     MIB.addReg(VR);
36263 }
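// Illustrative sketch (64-bit, position-independent; operand names are
// examples): the entry block computes the dispatch block's address and
// stores it into the SjLj function context frame object at offset 56
// (36 on 32-bit targets):
//
//   %vr = LEA64r $rip, 1, $noreg, %dispatch_bb, $noreg
//   MOV64mr %function-context-FI, 1, $noreg, 56, $noreg, %vr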
36264 
36265 MachineBasicBlock *
36266 X86TargetLowering::EmitSjLjDispatchBlock(MachineInstr &MI,
36267                                          MachineBasicBlock *BB) const {
36268   const MIMetadata MIMD(MI);
36269   MachineFunction *MF = BB->getParent();
36270   MachineRegisterInfo *MRI = &MF->getRegInfo();
36271   const X86InstrInfo *TII = Subtarget.getInstrInfo();
36272   int FI = MF->getFrameInfo().getFunctionContextIndex();
36273 
36274   // Get a mapping of the call site numbers to all of the landing pads they're
36275   // associated with.
36276   DenseMap<unsigned, SmallVector<MachineBasicBlock *, 2>> CallSiteNumToLPad;
36277   unsigned MaxCSNum = 0;
36278   for (auto &MBB : *MF) {
36279     if (!MBB.isEHPad())
36280       continue;
36281 
36282     MCSymbol *Sym = nullptr;
36283     for (const auto &MI : MBB) {
36284       if (MI.isDebugInstr())
36285         continue;
36286 
36287       assert(MI.isEHLabel() && "expected EH_LABEL");
36288       Sym = MI.getOperand(0).getMCSymbol();
36289       break;
36290     }
36291 
36292     if (!MF->hasCallSiteLandingPad(Sym))
36293       continue;
36294 
36295     for (unsigned CSI : MF->getCallSiteLandingPad(Sym)) {
36296       CallSiteNumToLPad[CSI].push_back(&MBB);
36297       MaxCSNum = std::max(MaxCSNum, CSI);
36298     }
36299   }
36300 
36301   // Get an ordered list of the machine basic blocks for the jump table.
36302   std::vector<MachineBasicBlock *> LPadList;
36303   SmallPtrSet<MachineBasicBlock *, 32> InvokeBBs;
36304   LPadList.reserve(CallSiteNumToLPad.size());
36305 
36306   for (unsigned CSI = 1; CSI <= MaxCSNum; ++CSI) {
36307     for (auto &LP : CallSiteNumToLPad[CSI]) {
36308       LPadList.push_back(LP);
36309       InvokeBBs.insert(LP->pred_begin(), LP->pred_end());
36310     }
36311   }
36312 
36313   assert(!LPadList.empty() &&
36314          "No landing pad destinations for the dispatch jump table!");
36315 
36316   // Create the MBBs for the dispatch code.
36317 
36318   // Shove the dispatch's address into the return slot in the function context.
36319   MachineBasicBlock *DispatchBB = MF->CreateMachineBasicBlock();
36320   DispatchBB->setIsEHPad(true);
36321 
36322   MachineBasicBlock *TrapBB = MF->CreateMachineBasicBlock();
36323   BuildMI(TrapBB, MIMD, TII->get(X86::TRAP));
36324   DispatchBB->addSuccessor(TrapBB);
36325 
36326   MachineBasicBlock *DispContBB = MF->CreateMachineBasicBlock();
36327   DispatchBB->addSuccessor(DispContBB);
36328 
36329   // Insert MBBs.
36330   MF->push_back(DispatchBB);
36331   MF->push_back(DispContBB);
36332   MF->push_back(TrapBB);
36333 
36334   // Insert code into the entry block that creates and registers the function
36335   // context.
36336   SetupEntryBlockForSjLj(MI, BB, DispatchBB, FI);
36337 
36338   // Create the jump table and associated information
36339   unsigned JTE = getJumpTableEncoding();
36340   MachineJumpTableInfo *JTI = MF->getOrCreateJumpTableInfo(JTE);
36341   unsigned MJTI = JTI->createJumpTableIndex(LPadList);
36342 
36343   const X86RegisterInfo &RI = TII->getRegisterInfo();
36344   // Add a register mask with no preserved registers.  This results in all
36345   // registers being marked as clobbered.
36346   if (RI.hasBasePointer(*MF)) {
36347     const bool FPIs64Bit =
36348         Subtarget.isTarget64BitLP64() || Subtarget.isTargetNaCl64();
36349     X86MachineFunctionInfo *MFI = MF->getInfo<X86MachineFunctionInfo>();
36350     MFI->setRestoreBasePointer(MF);
36351 
36352     Register FP = RI.getFrameRegister(*MF);
36353     Register BP = RI.getBaseRegister();
36354     unsigned Op = FPIs64Bit ? X86::MOV64rm : X86::MOV32rm;
36355     addRegOffset(BuildMI(DispatchBB, MIMD, TII->get(Op), BP), FP, true,
36356                  MFI->getRestoreBasePointerOffset())
36357         .addRegMask(RI.getNoPreservedMask());
36358   } else {
36359     BuildMI(DispatchBB, MIMD, TII->get(X86::NOOP))
36360         .addRegMask(RI.getNoPreservedMask());
36361   }
36362 
36363   // IReg is used as an index in a memory operand and therefore can't be SP
36364   Register IReg = MRI->createVirtualRegister(&X86::GR32_NOSPRegClass);
36365   addFrameReference(BuildMI(DispatchBB, MIMD, TII->get(X86::MOV32rm), IReg), FI,
36366                     Subtarget.is64Bit() ? 8 : 4);
36367   BuildMI(DispatchBB, MIMD, TII->get(X86::CMP32ri))
36368       .addReg(IReg)
36369       .addImm(LPadList.size());
36370   BuildMI(DispatchBB, MIMD, TII->get(X86::JCC_1))
36371       .addMBB(TrapBB)
36372       .addImm(X86::COND_AE);
36373 
36374   if (Subtarget.is64Bit()) {
36375     Register BReg = MRI->createVirtualRegister(&X86::GR64RegClass);
36376     Register IReg64 = MRI->createVirtualRegister(&X86::GR64_NOSPRegClass);
36377 
36378     // leaq .LJTI0_0(%rip), BReg
36379     BuildMI(DispContBB, MIMD, TII->get(X86::LEA64r), BReg)
36380         .addReg(X86::RIP)
36381         .addImm(1)
36382         .addReg(0)
36383         .addJumpTableIndex(MJTI)
36384         .addReg(0);
36385     // movzx IReg64, IReg
36386     BuildMI(DispContBB, MIMD, TII->get(TargetOpcode::SUBREG_TO_REG), IReg64)
36387         .addImm(0)
36388         .addReg(IReg)
36389         .addImm(X86::sub_32bit);
36390 
36391     switch (JTE) {
36392     case MachineJumpTableInfo::EK_BlockAddress:
36393       // jmpq *(BReg,IReg64,8)
36394       BuildMI(DispContBB, MIMD, TII->get(X86::JMP64m))
36395           .addReg(BReg)
36396           .addImm(8)
36397           .addReg(IReg64)
36398           .addImm(0)
36399           .addReg(0);
36400       break;
36401     case MachineJumpTableInfo::EK_LabelDifference32: {
36402       Register OReg = MRI->createVirtualRegister(&X86::GR32RegClass);
36403       Register OReg64 = MRI->createVirtualRegister(&X86::GR64RegClass);
36404       Register TReg = MRI->createVirtualRegister(&X86::GR64RegClass);
36405 
36406       // movl (BReg,IReg64,4), OReg
36407       BuildMI(DispContBB, MIMD, TII->get(X86::MOV32rm), OReg)
36408           .addReg(BReg)
36409           .addImm(4)
36410           .addReg(IReg64)
36411           .addImm(0)
36412           .addReg(0);
36413       // movsx OReg64, OReg
36414       BuildMI(DispContBB, MIMD, TII->get(X86::MOVSX64rr32), OReg64)
36415           .addReg(OReg);
36416       // addq BReg, OReg64, TReg
36417       BuildMI(DispContBB, MIMD, TII->get(X86::ADD64rr), TReg)
36418           .addReg(OReg64)
36419           .addReg(BReg);
36420       // jmpq *TReg
36421       BuildMI(DispContBB, MIMD, TII->get(X86::JMP64r)).addReg(TReg);
36422       break;
36423     }
36424     default:
36425       llvm_unreachable("Unexpected jump table encoding");
36426     }
36427   } else {
36428     // jmpl *.LJTI0_0(,IReg,4)
36429     BuildMI(DispContBB, MIMD, TII->get(X86::JMP32m))
36430         .addReg(0)
36431         .addImm(4)
36432         .addReg(IReg)
36433         .addJumpTableIndex(MJTI)
36434         .addReg(0);
36435   }
36436 
36437   // Add the jump table entries as successors to the MBB.
36438   SmallPtrSet<MachineBasicBlock *, 8> SeenMBBs;
36439   for (auto &LP : LPadList)
36440     if (SeenMBBs.insert(LP).second)
36441       DispContBB->addSuccessor(LP);
36442 
36443   // N.B. the order the invoke BBs are processed in doesn't matter here.
36444   SmallVector<MachineBasicBlock *, 64> MBBLPads;
36445   const MCPhysReg *SavedRegs = MF->getRegInfo().getCalleeSavedRegs();
36446   for (MachineBasicBlock *MBB : InvokeBBs) {
36447     // Remove the landing pad successor from the invoke block and replace it
36448     // with the new dispatch block.
36449     // Keep a copy of Successors since it's modified inside the loop.
36450     SmallVector<MachineBasicBlock *, 8> Successors(MBB->succ_rbegin(),
36451                                                    MBB->succ_rend());
36452     // FIXME: Avoid quadratic complexity.
36453     for (auto *MBBS : Successors) {
36454       if (MBBS->isEHPad()) {
36455         MBB->removeSuccessor(MBBS);
36456         MBBLPads.push_back(MBBS);
36457       }
36458     }
36459 
36460     MBB->addSuccessor(DispatchBB);
36461 
36462     // Find the invoke call and mark all of the callee-saved registers as
36463     // 'implicitly defined' so that they're spilled.  This prevents later
36464     // passes from moving instructions to before the EH block, where they
36465     // would never be executed.
36466     for (auto &II : reverse(*MBB)) {
36467       if (!II.isCall())
36468         continue;
36469 
36470       DenseMap<unsigned, bool> DefRegs;
36471       for (auto &MOp : II.operands())
36472         if (MOp.isReg())
36473           DefRegs[MOp.getReg()] = true;
36474 
36475       MachineInstrBuilder MIB(*MF, &II);
36476       for (unsigned RegIdx = 0; SavedRegs[RegIdx]; ++RegIdx) {
36477         unsigned Reg = SavedRegs[RegIdx];
36478         if (!DefRegs[Reg])
36479           MIB.addReg(Reg, RegState::ImplicitDefine | RegState::Dead);
36480       }
36481 
36482       break;
36483     }
36484   }
36485 
36486   // Mark all former landing pads as non-landing pads.  The dispatch is the only
36487   // landing pad now.
36488   for (auto &LP : MBBLPads)
36489     LP->setIsEHPad(false);
36490 
36491   // The instruction is gone now.
36492   MI.eraseFromParent();
36493   return BB;
36494 }
36495 
36496 MachineBasicBlock *
36497 X86TargetLowering::emitPatchableEventCall(MachineInstr &MI,
36498                                           MachineBasicBlock *BB) const {
36499   // Wrap patchable event calls in CALLSEQ_START/CALLSEQ_END, as tracing
36500   // calls may require proper stack alignment.
36501   const TargetInstrInfo &TII = *Subtarget.getInstrInfo();
36502   const MIMetadata MIMD(MI);
36503   MachineFunction &MF = *BB->getParent();
36504 
36505   // Emit CALLSEQ_START right before the instruction.
36506   MF.getFrameInfo().setAdjustsStack(true);
36507   unsigned AdjStackDown = TII.getCallFrameSetupOpcode();
36508   MachineInstrBuilder CallseqStart =
36509       BuildMI(MF, MIMD, TII.get(AdjStackDown)).addImm(0).addImm(0).addImm(0);
36510   BB->insert(MachineBasicBlock::iterator(MI), CallseqStart);
36511 
36512   // Emit CALLSEQ_END right after the instruction.
36513   unsigned AdjStackUp = TII.getCallFrameDestroyOpcode();
36514   MachineInstrBuilder CallseqEnd =
36515       BuildMI(MF, MIMD, TII.get(AdjStackUp)).addImm(0).addImm(0);
36516   BB->insertAfter(MachineBasicBlock::iterator(MI), CallseqEnd);
36517 
36518   return BB;
36519 }
36520 
36521 MachineBasicBlock *
36522 X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
36523                                                MachineBasicBlock *BB) const {
36524   MachineFunction *MF = BB->getParent();
36525   const TargetInstrInfo *TII = Subtarget.getInstrInfo();
36526   const MIMetadata MIMD(MI);
36527 
36528   auto TMMImmToTMMReg = [](unsigned Imm) {
36529     assert (Imm < 8 && "Illegal tmm index");
36530     return X86::TMM0 + Imm;
36531   };
36532   switch (MI.getOpcode()) {
36533   default: llvm_unreachable("Unexpected instr type to insert");
36534   case X86::TLS_addr32:
36535   case X86::TLS_addr64:
36536   case X86::TLS_addrX32:
36537   case X86::TLS_base_addr32:
36538   case X86::TLS_base_addr64:
36539   case X86::TLS_base_addrX32:
36540   case X86::TLS_desc32:
36541   case X86::TLS_desc64:
36542     return EmitLoweredTLSAddr(MI, BB);
36543   case X86::INDIRECT_THUNK_CALL32:
36544   case X86::INDIRECT_THUNK_CALL64:
36545   case X86::INDIRECT_THUNK_TCRETURN32:
36546   case X86::INDIRECT_THUNK_TCRETURN64:
36547     return EmitLoweredIndirectThunk(MI, BB);
36548   case X86::CATCHRET:
36549     return EmitLoweredCatchRet(MI, BB);
36550   case X86::SEG_ALLOCA_32:
36551   case X86::SEG_ALLOCA_64:
36552     return EmitLoweredSegAlloca(MI, BB);
36553   case X86::PROBED_ALLOCA_32:
36554   case X86::PROBED_ALLOCA_64:
36555     return EmitLoweredProbedAlloca(MI, BB);
36556   case X86::TLSCall_32:
36557   case X86::TLSCall_64:
36558     return EmitLoweredTLSCall(MI, BB);
36559   case X86::CMOV_FR16:
36560   case X86::CMOV_FR16X:
36561   case X86::CMOV_FR32:
36562   case X86::CMOV_FR32X:
36563   case X86::CMOV_FR64:
36564   case X86::CMOV_FR64X:
36565   case X86::CMOV_GR8:
36566   case X86::CMOV_GR16:
36567   case X86::CMOV_GR32:
36568   case X86::CMOV_RFP32:
36569   case X86::CMOV_RFP64:
36570   case X86::CMOV_RFP80:
36571   case X86::CMOV_VR64:
36572   case X86::CMOV_VR128:
36573   case X86::CMOV_VR128X:
36574   case X86::CMOV_VR256:
36575   case X86::CMOV_VR256X:
36576   case X86::CMOV_VR512:
36577   case X86::CMOV_VK1:
36578   case X86::CMOV_VK2:
36579   case X86::CMOV_VK4:
36580   case X86::CMOV_VK8:
36581   case X86::CMOV_VK16:
36582   case X86::CMOV_VK32:
36583   case X86::CMOV_VK64:
36584     return EmitLoweredSelect(MI, BB);
36585 
36586   case X86::FP80_ADDr:
36587   case X86::FP80_ADDm32: {
36588     // Change the floating point control register to use double extended
36589     // precision when performing the addition.
36590     int OrigCWFrameIdx =
36591         MF->getFrameInfo().CreateStackObject(2, Align(2), false);
36592     addFrameReference(BuildMI(*BB, MI, MIMD, TII->get(X86::FNSTCW16m)),
36593                       OrigCWFrameIdx);
36594 
36595     // Load the old value of the control word...
36596     Register OldCW = MF->getRegInfo().createVirtualRegister(&X86::GR32RegClass);
36597     addFrameReference(BuildMI(*BB, MI, MIMD, TII->get(X86::MOVZX32rm16), OldCW),
36598                       OrigCWFrameIdx);
36599 
36600     // OR 0b11 into bits 8 and 9. 0b11 is the encoding for double-extended
36601     // precision.
36602     Register NewCW = MF->getRegInfo().createVirtualRegister(&X86::GR32RegClass);
36603     BuildMI(*BB, MI, MIMD, TII->get(X86::OR32ri), NewCW)
36604         .addReg(OldCW, RegState::Kill)
36605         .addImm(0x300);
36606 
36607     // Extract to 16 bits.
36608     Register NewCW16 =
36609         MF->getRegInfo().createVirtualRegister(&X86::GR16RegClass);
36610     BuildMI(*BB, MI, MIMD, TII->get(TargetOpcode::COPY), NewCW16)
36611         .addReg(NewCW, RegState::Kill, X86::sub_16bit);
36612 
36613     // Prepare memory for FLDCW.
36614     int NewCWFrameIdx =
36615         MF->getFrameInfo().CreateStackObject(2, Align(2), false);
36616     addFrameReference(BuildMI(*BB, MI, MIMD, TII->get(X86::MOV16mr)),
36617                       NewCWFrameIdx)
36618         .addReg(NewCW16, RegState::Kill);
36619 
36620     // Reload the modified control word now...
36621     addFrameReference(BuildMI(*BB, MI, MIMD, TII->get(X86::FLDCW16m)),
36622                       NewCWFrameIdx);
36623 
36624     // Do the addition.
36625     if (MI.getOpcode() == X86::FP80_ADDr) {
36626       BuildMI(*BB, MI, MIMD, TII->get(X86::ADD_Fp80))
36627           .add(MI.getOperand(0))
36628           .add(MI.getOperand(1))
36629           .add(MI.getOperand(2));
36630     } else {
36631       BuildMI(*BB, MI, MIMD, TII->get(X86::ADD_Fp80m32))
36632           .add(MI.getOperand(0))
36633           .add(MI.getOperand(1))
36634           .add(MI.getOperand(2))
36635           .add(MI.getOperand(3))
36636           .add(MI.getOperand(4))
36637           .add(MI.getOperand(5))
36638           .add(MI.getOperand(6));
36639     }
36640 
36641     // Reload the original control word now.
36642     addFrameReference(BuildMI(*BB, MI, MIMD, TII->get(X86::FLDCW16m)),
36643                       OrigCWFrameIdx);
36644 
36645     MI.eraseFromParent(); // The pseudo instruction is gone now.
36646     return BB;
36647   }
36648 
36649   case X86::FP32_TO_INT16_IN_MEM:
36650   case X86::FP32_TO_INT32_IN_MEM:
36651   case X86::FP32_TO_INT64_IN_MEM:
36652   case X86::FP64_TO_INT16_IN_MEM:
36653   case X86::FP64_TO_INT32_IN_MEM:
36654   case X86::FP64_TO_INT64_IN_MEM:
36655   case X86::FP80_TO_INT16_IN_MEM:
36656   case X86::FP80_TO_INT32_IN_MEM:
36657   case X86::FP80_TO_INT64_IN_MEM: {
36658     // Change the floating point control register to use "round towards zero"
36659     // mode when truncating to an integer value.
36660     int OrigCWFrameIdx =
36661         MF->getFrameInfo().CreateStackObject(2, Align(2), false);
36662     addFrameReference(BuildMI(*BB, MI, MIMD, TII->get(X86::FNSTCW16m)),
36663                       OrigCWFrameIdx);
36664 
36665     // Load the old value of the control word...
36666     Register OldCW = MF->getRegInfo().createVirtualRegister(&X86::GR32RegClass);
36667     addFrameReference(BuildMI(*BB, MI, MIMD, TII->get(X86::MOVZX32rm16), OldCW),
36668                       OrigCWFrameIdx);
36669 
36670     // OR 0b11 into bits 10 and 11. 0b11 is the encoding for round toward zero.
36671     Register NewCW = MF->getRegInfo().createVirtualRegister(&X86::GR32RegClass);
36672     BuildMI(*BB, MI, MIMD, TII->get(X86::OR32ri), NewCW)
36673       .addReg(OldCW, RegState::Kill).addImm(0xC00);
36674 
36675     // Extract to 16 bits.
36676     Register NewCW16 =
36677         MF->getRegInfo().createVirtualRegister(&X86::GR16RegClass);
36678     BuildMI(*BB, MI, MIMD, TII->get(TargetOpcode::COPY), NewCW16)
36679       .addReg(NewCW, RegState::Kill, X86::sub_16bit);
36680 
36681     // Prepare memory for FLDCW.
36682     int NewCWFrameIdx =
36683         MF->getFrameInfo().CreateStackObject(2, Align(2), false);
36684     addFrameReference(BuildMI(*BB, MI, MIMD, TII->get(X86::MOV16mr)),
36685                       NewCWFrameIdx)
36686       .addReg(NewCW16, RegState::Kill);
36687 
36688     // Reload the modified control word now...
36689     addFrameReference(BuildMI(*BB, MI, MIMD,
36690                               TII->get(X86::FLDCW16m)), NewCWFrameIdx);
36691 
36692     // Get the X86 opcode to use.
36693     unsigned Opc;
36694     switch (MI.getOpcode()) {
36695     // clang-format off
36696     default: llvm_unreachable("illegal opcode!");
36697     case X86::FP32_TO_INT16_IN_MEM: Opc = X86::IST_Fp16m32; break;
36698     case X86::FP32_TO_INT32_IN_MEM: Opc = X86::IST_Fp32m32; break;
36699     case X86::FP32_TO_INT64_IN_MEM: Opc = X86::IST_Fp64m32; break;
36700     case X86::FP64_TO_INT16_IN_MEM: Opc = X86::IST_Fp16m64; break;
36701     case X86::FP64_TO_INT32_IN_MEM: Opc = X86::IST_Fp32m64; break;
36702     case X86::FP64_TO_INT64_IN_MEM: Opc = X86::IST_Fp64m64; break;
36703     case X86::FP80_TO_INT16_IN_MEM: Opc = X86::IST_Fp16m80; break;
36704     case X86::FP80_TO_INT32_IN_MEM: Opc = X86::IST_Fp32m80; break;
36705     case X86::FP80_TO_INT64_IN_MEM: Opc = X86::IST_Fp64m80; break;
36706     // clang-format on
36707     }
36708 
36709     X86AddressMode AM = getAddressFromInstr(&MI, 0);
36710     addFullAddress(BuildMI(*BB, MI, MIMD, TII->get(Opc)), AM)
36711         .addReg(MI.getOperand(X86::AddrNumOperands).getReg());
36712 
36713     // Reload the original control word now.
36714     addFrameReference(BuildMI(*BB, MI, MIMD, TII->get(X86::FLDCW16m)),
36715                       OrigCWFrameIdx);
36716 
36717     MI.eraseFromParent(); // The pseudo instruction is gone now.
36718     return BB;
36719   }
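  // For reference, the two control-word tricks above touch standard x87 FPCW
  // fields (FNSTCW/FLDCW save and restore the full 16-bit word around the
  // operation):
  //
  //   bits 8-9   PC, precision control:  OR 0x300 selects 0b11 = double
  //              extended precision (FP80_ADD* above)
  //   bits 10-11 RC, rounding control:   OR 0xC00 selects 0b11 = round
  //              toward zero (FP*_TO_INT*_IN_MEM above)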
36720 
36721   // xbegin
36722   case X86::XBEGIN:
36723     return emitXBegin(MI, BB, Subtarget.getInstrInfo());
36724 
36725   case X86::VAARG_64:
36726   case X86::VAARG_X32:
36727     return EmitVAARGWithCustomInserter(MI, BB);
36728 
36729   case X86::EH_SjLj_SetJmp32:
36730   case X86::EH_SjLj_SetJmp64:
36731     return emitEHSjLjSetJmp(MI, BB);
36732 
36733   case X86::EH_SjLj_LongJmp32:
36734   case X86::EH_SjLj_LongJmp64:
36735     return emitEHSjLjLongJmp(MI, BB);
36736 
36737   case X86::Int_eh_sjlj_setup_dispatch:
36738     return EmitSjLjDispatchBlock(MI, BB);
36739 
36740   case TargetOpcode::STATEPOINT:
36741     // As an implementation detail, STATEPOINT shares the STACKMAP format at
36742     // this point in the process.  We diverge later.
36743     return emitPatchPoint(MI, BB);
36744 
36745   case TargetOpcode::STACKMAP:
36746   case TargetOpcode::PATCHPOINT:
36747     return emitPatchPoint(MI, BB);
36748 
36749   case TargetOpcode::PATCHABLE_EVENT_CALL:
36750   case TargetOpcode::PATCHABLE_TYPED_EVENT_CALL:
36751     return emitPatchableEventCall(MI, BB);
36752 
36753   case X86::LCMPXCHG8B: {
36754     const X86RegisterInfo *TRI = Subtarget.getRegisterInfo();
36755     // In addition to the four E[ABCD] registers implied by the encoding,
36756     // CMPXCHG8B requires a memory operand. If the current architecture is
36757     // i686 and the current function needs a base pointer - which is ESI on
36758     // i686 - the register allocator would not be able to allocate registers
36759     // for an address of the form X(%reg, %reg, Y): there would never be
36760     // enough unreserved registers during regalloc (without the base pointer
36761     // the only option would be X(%edi, %esi, Y)).
36762     // We give the register allocator a hand by precomputing the address in
36763     // a new vreg using LEA.
36764 
36765     // If this is not i686 or there is no base pointer, there is nothing to do.
36766     if (!Subtarget.is32Bit() || !TRI->hasBasePointer(*MF))
36767       return BB;
36768 
36769     // Even though this code does not strictly need the base pointer to
36770     // be ESI, we check for that. The reason: if this assert fails, the
36771     // compiler's base pointer handling has changed in a way that most
36772     // probably has to be addressed here as well.
36773     assert(TRI->getBaseRegister() == X86::ESI &&
36774            "LCMPXCHG8B custom insertion for i686 is written with X86::ESI as a "
36775            "base pointer in mind");
36776 
36777     MachineRegisterInfo &MRI = MF->getRegInfo();
36778     MVT SPTy = getPointerTy(MF->getDataLayout());
36779     const TargetRegisterClass *AddrRegClass = getRegClassFor(SPTy);
36780     Register computedAddrVReg = MRI.createVirtualRegister(AddrRegClass);
36781 
36782     X86AddressMode AM = getAddressFromInstr(&MI, 0);
36783     // Regalloc does not need any help when the memory operand of CMPXCHG8B
36784     // does not use an index register.
36785     if (AM.IndexReg == X86::NoRegister)
36786       return BB;
36787 
36788     // After X86TargetLowering::ReplaceNodeResults CMPXCHG8B is glued to its
36789     // four operand definitions that are E[ABCD] registers. We skip them and
36790     // then insert the LEA.
36791     MachineBasicBlock::reverse_iterator RMBBI(MI.getReverseIterator());
36792     while (RMBBI != BB->rend() &&
36793            (RMBBI->definesRegister(X86::EAX, /*TRI=*/nullptr) ||
36794             RMBBI->definesRegister(X86::EBX, /*TRI=*/nullptr) ||
36795             RMBBI->definesRegister(X86::ECX, /*TRI=*/nullptr) ||
36796             RMBBI->definesRegister(X86::EDX, /*TRI=*/nullptr))) {
36797       ++RMBBI;
36798     }
36799     MachineBasicBlock::iterator MBBI(RMBBI);
36800     addFullAddress(
36801         BuildMI(*BB, *MBBI, MIMD, TII->get(X86::LEA32r), computedAddrVReg), AM);
36802 
36803     setDirectAddressInInstr(&MI, 0, computedAddrVReg);
36804 
36805     return BB;
36806   }
36807   case X86::LCMPXCHG16B_NO_RBX: {
36808     const X86RegisterInfo *TRI = Subtarget.getRegisterInfo();
36809     Register BasePtr = TRI->getBaseRegister();
36810     if (TRI->hasBasePointer(*MF) &&
36811         (BasePtr == X86::RBX || BasePtr == X86::EBX)) {
36812       if (!BB->isLiveIn(BasePtr))
36813         BB->addLiveIn(BasePtr);
36814       // Save RBX into a virtual register.
36815       Register SaveRBX =
36816           MF->getRegInfo().createVirtualRegister(&X86::GR64RegClass);
36817       BuildMI(*BB, MI, MIMD, TII->get(TargetOpcode::COPY), SaveRBX)
36818           .addReg(X86::RBX);
36819       Register Dst = MF->getRegInfo().createVirtualRegister(&X86::GR64RegClass);
36820       MachineInstrBuilder MIB =
36821           BuildMI(*BB, MI, MIMD, TII->get(X86::LCMPXCHG16B_SAVE_RBX), Dst);
36822       for (unsigned Idx = 0; Idx < X86::AddrNumOperands; ++Idx)
36823         MIB.add(MI.getOperand(Idx));
36824       MIB.add(MI.getOperand(X86::AddrNumOperands));
36825       MIB.addReg(SaveRBX);
36826     } else {
36827       // Simple case, just copy the virtual register to RBX.
36828       BuildMI(*BB, MI, MIMD, TII->get(TargetOpcode::COPY), X86::RBX)
36829           .add(MI.getOperand(X86::AddrNumOperands));
36830       MachineInstrBuilder MIB =
36831           BuildMI(*BB, MI, MIMD, TII->get(X86::LCMPXCHG16B));
36832       for (unsigned Idx = 0; Idx < X86::AddrNumOperands; ++Idx)
36833         MIB.add(MI.getOperand(Idx));
36834     }
36835     MI.eraseFromParent();
36836     return BB;
36837   }
36838   case X86::MWAITX: {
36839     const X86RegisterInfo *TRI = Subtarget.getRegisterInfo();
36840     Register BasePtr = TRI->getBaseRegister();
36841     bool IsRBX = (BasePtr == X86::RBX || BasePtr == X86::EBX);
36842     // If there is no need to save the base pointer, we generate MWAITXrrr;
36843     // otherwise we generate the MWAITX_SAVE_RBX pseudo.
36844     if (!IsRBX || !TRI->hasBasePointer(*MF)) {
36845       BuildMI(*BB, MI, MIMD, TII->get(TargetOpcode::COPY), X86::ECX)
36846           .addReg(MI.getOperand(0).getReg());
36847       BuildMI(*BB, MI, MIMD, TII->get(TargetOpcode::COPY), X86::EAX)
36848           .addReg(MI.getOperand(1).getReg());
36849       BuildMI(*BB, MI, MIMD, TII->get(TargetOpcode::COPY), X86::EBX)
36850           .addReg(MI.getOperand(2).getReg());
36851       BuildMI(*BB, MI, MIMD, TII->get(X86::MWAITXrrr));
36852       MI.eraseFromParent();
36853     } else {
36854       if (!BB->isLiveIn(BasePtr)) {
36855         BB->addLiveIn(BasePtr);
36856       }
36857       // Parameters can be copied into ECX and EAX but not EBX yet.
36858       BuildMI(*BB, MI, MIMD, TII->get(TargetOpcode::COPY), X86::ECX)
36859           .addReg(MI.getOperand(0).getReg());
36860       BuildMI(*BB, MI, MIMD, TII->get(TargetOpcode::COPY), X86::EAX)
36861           .addReg(MI.getOperand(1).getReg());
36862       assert(Subtarget.is64Bit() && "Expected 64-bit mode!");
36863       // Save RBX into a virtual register.
36864       Register SaveRBX =
36865           MF->getRegInfo().createVirtualRegister(&X86::GR64RegClass);
36866       BuildMI(*BB, MI, MIMD, TII->get(TargetOpcode::COPY), SaveRBX)
36867           .addReg(X86::RBX);
36868       // Generate mwaitx pseudo.
36869       Register Dst = MF->getRegInfo().createVirtualRegister(&X86::GR64RegClass);
36870       BuildMI(*BB, MI, MIMD, TII->get(X86::MWAITX_SAVE_RBX))
36871           .addDef(Dst) // Destination tied in with SaveRBX.
36872           .addReg(MI.getOperand(2).getReg()) // input value of EBX.
36873           .addUse(SaveRBX);                  // Save of base pointer.
36874       MI.eraseFromParent();
36875     }
36876     return BB;
36877   }
36878   case TargetOpcode::PREALLOCATED_SETUP: {
36879     assert(Subtarget.is32Bit() && "preallocated only used in 32-bit");
36880     auto *MFI = MF->getInfo<X86MachineFunctionInfo>();
36881     MFI->setHasPreallocatedCall(true);
36882     int64_t PreallocatedId = MI.getOperand(0).getImm();
36883     size_t StackAdjustment = MFI->getPreallocatedStackSize(PreallocatedId);
36884     assert(StackAdjustment != 0 && "0 stack adjustment");
36885     LLVM_DEBUG(dbgs() << "PREALLOCATED_SETUP stack adjustment "
36886                       << StackAdjustment << "\n");
36887     BuildMI(*BB, MI, MIMD, TII->get(X86::SUB32ri), X86::ESP)
36888         .addReg(X86::ESP)
36889         .addImm(StackAdjustment);
36890     MI.eraseFromParent();
36891     return BB;
36892   }
36893   case TargetOpcode::PREALLOCATED_ARG: {
36894     assert(Subtarget.is32Bit() && "preallocated calls only used in 32-bit");
36895     int64_t PreallocatedId = MI.getOperand(1).getImm();
36896     int64_t ArgIdx = MI.getOperand(2).getImm();
36897     auto *MFI = MF->getInfo<X86MachineFunctionInfo>();
36898     size_t ArgOffset = MFI->getPreallocatedArgOffsets(PreallocatedId)[ArgIdx];
36899     LLVM_DEBUG(dbgs() << "PREALLOCATED_ARG arg index " << ArgIdx
36900                       << ", arg offset " << ArgOffset << "\n");
36901     // stack pointer + offset
36902     addRegOffset(BuildMI(*BB, MI, MIMD, TII->get(X86::LEA32r),
36903                          MI.getOperand(0).getReg()),
36904                  X86::ESP, false, ArgOffset);
36905     MI.eraseFromParent();
36906     return BB;
36907   }
36908   case X86::PTDPBSSD:
36909   case X86::PTDPBSUD:
36910   case X86::PTDPBUSD:
36911   case X86::PTDPBUUD:
36912   case X86::PTDPBF16PS:
36913   case X86::PTDPFP16PS: {
36914     unsigned Opc;
36915     switch (MI.getOpcode()) {
36916     // clang-format off
36917     default: llvm_unreachable("illegal opcode!");
36918     case X86::PTDPBSSD: Opc = X86::TDPBSSD; break;
36919     case X86::PTDPBSUD: Opc = X86::TDPBSUD; break;
36920     case X86::PTDPBUSD: Opc = X86::TDPBUSD; break;
36921     case X86::PTDPBUUD: Opc = X86::TDPBUUD; break;
36922     case X86::PTDPBF16PS: Opc = X86::TDPBF16PS; break;
36923     case X86::PTDPFP16PS: Opc = X86::TDPFP16PS; break;
36924     // clang-format on
36925     }
36926 
36927     MachineInstrBuilder MIB = BuildMI(*BB, MI, MIMD, TII->get(Opc));
36928     MIB.addReg(TMMImmToTMMReg(MI.getOperand(0).getImm()), RegState::Define);
36929     MIB.addReg(TMMImmToTMMReg(MI.getOperand(0).getImm()), RegState::Undef);
36930     MIB.addReg(TMMImmToTMMReg(MI.getOperand(1).getImm()), RegState::Undef);
36931     MIB.addReg(TMMImmToTMMReg(MI.getOperand(2).getImm()), RegState::Undef);
36932 
36933     MI.eraseFromParent(); // The pseudo is gone now.
36934     return BB;
36935   }
36936   case X86::PTILEZERO: {
36937     unsigned Imm = MI.getOperand(0).getImm();
36938     BuildMI(*BB, MI, MIMD, TII->get(X86::TILEZERO), TMMImmToTMMReg(Imm));
36939     MI.eraseFromParent(); // The pseudo is gone now.
36940     auto *MFI = MF->getInfo<X86MachineFunctionInfo>();
36941     MFI->setAMXProgModel(AMXProgModelEnum::DirectReg);
36942     return BB;
36943   }
36944   case X86::PTILEZEROV: {
36945     auto *MFI = MF->getInfo<X86MachineFunctionInfo>();
36946     MFI->setAMXProgModel(AMXProgModelEnum::ManagedRA);
36947     return BB;
36948   }
36949   case X86::PTILELOADD:
36950   case X86::PTILELOADDT1:
36951   case X86::PTILESTORED: {
36952     unsigned Opc;
36953     switch (MI.getOpcode()) {
36954     default: llvm_unreachable("illegal opcode!");
36955 #define GET_EGPR_IF_ENABLED(OPC) (Subtarget.hasEGPR() ? OPC##_EVEX : OPC)
36956     case X86::PTILELOADD:
36957       Opc = GET_EGPR_IF_ENABLED(X86::TILELOADD);
36958       break;
36959     case X86::PTILELOADDT1:
36960       Opc = GET_EGPR_IF_ENABLED(X86::TILELOADDT1);
36961       break;
36962     case X86::PTILESTORED:
36963       Opc = GET_EGPR_IF_ENABLED(X86::TILESTORED);
36964       break;
36965 #undef GET_EGPR_IF_ENABLED
36966     }
36967 
36968     MachineInstrBuilder MIB = BuildMI(*BB, MI, MIMD, TII->get(Opc));
36969     unsigned CurOp = 0;
36970     if (Opc != X86::TILESTORED && Opc != X86::TILESTORED_EVEX)
36971       MIB.addReg(TMMImmToTMMReg(MI.getOperand(CurOp++).getImm()),
36972                  RegState::Define);
36973 
36974     MIB.add(MI.getOperand(CurOp++)); // base
36975     MIB.add(MI.getOperand(CurOp++)); // scale
36976     MIB.add(MI.getOperand(CurOp++)); // index -- stride
36977     MIB.add(MI.getOperand(CurOp++)); // displacement
36978     MIB.add(MI.getOperand(CurOp++)); // segment
36979 
36980     if (Opc == X86::TILESTORED || Opc == X86::TILESTORED_EVEX)
36981       MIB.addReg(TMMImmToTMMReg(MI.getOperand(CurOp++).getImm()),
36982                  RegState::Undef);
36983 
36984     MI.eraseFromParent(); // The pseudo is gone now.
36985     return BB;
36986   }
36987   case X86::PTCMMIMFP16PS:
36988   case X86::PTCMMRLFP16PS: {
36989     const MIMetadata MIMD(MI);
36990     unsigned Opc;
36991     switch (MI.getOpcode()) {
36992     // clang-format off
36993     default: llvm_unreachable("Unexpected instruction!");
36994     case X86::PTCMMIMFP16PS:     Opc = X86::TCMMIMFP16PS;     break;
36995     case X86::PTCMMRLFP16PS:     Opc = X86::TCMMRLFP16PS;     break;
36996     // clang-format on
36997     }
36998     MachineInstrBuilder MIB = BuildMI(*BB, MI, MIMD, TII->get(Opc));
36999     MIB.addReg(TMMImmToTMMReg(MI.getOperand(0).getImm()), RegState::Define);
37000     MIB.addReg(TMMImmToTMMReg(MI.getOperand(0).getImm()), RegState::Undef);
37001     MIB.addReg(TMMImmToTMMReg(MI.getOperand(1).getImm()), RegState::Undef);
37002     MIB.addReg(TMMImmToTMMReg(MI.getOperand(2).getImm()), RegState::Undef);
37003     MI.eraseFromParent(); // The pseudo is gone now.
37004     return BB;
37005   }
37006   }
37007 }
37008 
37009 //===----------------------------------------------------------------------===//
37010 //                           X86 Optimization Hooks
37011 //===----------------------------------------------------------------------===//
37012 
37013 bool
37014 X86TargetLowering::targetShrinkDemandedConstant(SDValue Op,
37015                                                 const APInt &DemandedBits,
37016                                                 const APInt &DemandedElts,
37017                                                 TargetLoweringOpt &TLO) const {
37018   EVT VT = Op.getValueType();
37019   unsigned Opcode = Op.getOpcode();
37020   unsigned EltSize = VT.getScalarSizeInBits();
37021 
37022   if (VT.isVector()) {
37023     // If the constant is all sign bits in the active bits, then we should
37024     // sign-extend it across the entire constant to allow it to act as a
37025     // boolean constant vector.
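    // For example (hypothetical values): a v4i32 OR whose constant lanes are
    // 0x0000FFFF, with only the low 16 bits demanded, is rewritten as an OR
    // with SIGN_EXTEND_INREG'd lanes of 0xFFFFFFFF, i.e. all-ones boolean
    // lanes.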
37026     auto NeedsSignExtension = [&](SDValue V, unsigned ActiveBits) {
37027       if (!ISD::isBuildVectorOfConstantSDNodes(V.getNode()))
37028         return false;
37029       for (unsigned i = 0, e = V.getNumOperands(); i != e; ++i) {
37030         if (!DemandedElts[i] || V.getOperand(i).isUndef())
37031           continue;
37032         const APInt &Val = V.getConstantOperandAPInt(i);
37033         if (Val.getBitWidth() > Val.getNumSignBits() &&
37034             Val.trunc(ActiveBits).getNumSignBits() == ActiveBits)
37035           return true;
37036       }
37037       return false;
37038     };
37039     // For vectors - if we have a constant, then try to sign extend.
37040     // TODO: Handle AND cases.
37041     unsigned ActiveBits = DemandedBits.getActiveBits();
37042     if (EltSize > ActiveBits && EltSize > 1 && isTypeLegal(VT) &&
37043         (Opcode == ISD::OR || Opcode == ISD::XOR || Opcode == X86ISD::ANDNP) &&
37044         NeedsSignExtension(Op.getOperand(1), ActiveBits)) {
37045       EVT ExtSVT = EVT::getIntegerVT(*TLO.DAG.getContext(), ActiveBits);
37046       EVT ExtVT = EVT::getVectorVT(*TLO.DAG.getContext(), ExtSVT,
37047                                    VT.getVectorNumElements());
37048       SDValue NewC =
37049           TLO.DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(Op), VT,
37050                           Op.getOperand(1), TLO.DAG.getValueType(ExtVT));
37051       SDValue NewOp =
37052           TLO.DAG.getNode(Opcode, SDLoc(Op), VT, Op.getOperand(0), NewC);
37053       return TLO.CombineTo(Op, NewOp);
37054     }
37055     return false;
37056   }
37057 
37058   // Only optimize Ands to prevent shrinking a constant that could be
37059   // matched by movzx.
37060   if (Opcode != ISD::AND)
37061     return false;
37062 
37063   // Make sure the RHS really is a constant.
37064   ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op.getOperand(1));
37065   if (!C)
37066     return false;
37067 
37068   const APInt &Mask = C->getAPIntValue();
37069 
37070   // Clear all non-demanded bits initially.
37071   APInt ShrunkMask = Mask & DemandedBits;
37072 
37073   // Find the width of the shrunk mask.
37074   unsigned Width = ShrunkMask.getActiveBits();
37075 
37076   // If the mask is all 0s there's nothing to do here.
37077   if (Width == 0)
37078     return false;
37079 
37080   // Find the next power of 2 width, rounding up to a byte.
37081   Width = llvm::bit_ceil(std::max(Width, 8U));
37082   // Clamp the width to the element size to handle illegal types.
37083   Width = std::min(Width, EltSize);
37084 
37085   // Calculate a possible zero extend mask for this constant.
37086   APInt ZeroExtendMask = APInt::getLowBitsSet(EltSize, Width);
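  // For example (hypothetical values): with EltSize == 32 and a shrunk mask
  // of 9 active bits, Width rounds up to 16, so ZeroExtendMask is 0xFFFF and
  // an AND with it can be matched by movzwl.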
37087 
37088   // If we aren't changing the mask, just return true to keep it and prevent
37089   // the caller from optimizing.
37090   if (ZeroExtendMask == Mask)
37091     return true;
37092 
37093   // Make sure the new mask can be represented by a combination of mask bits
37094   // and non-demanded bits.
37095   if (!ZeroExtendMask.isSubsetOf(Mask | ~DemandedBits))
37096     return false;
37097 
37098   // Replace the constant with the zero extend mask.
37099   SDLoc DL(Op);
37100   SDValue NewC = TLO.DAG.getConstant(ZeroExtendMask, DL, VT);
37101   SDValue NewOp = TLO.DAG.getNode(ISD::AND, DL, VT, Op.getOperand(0), NewC);
37102   return TLO.CombineTo(Op, NewOp);
37103 }
37104 
37105 static void computeKnownBitsForPSADBW(SDValue LHS, SDValue RHS,
37106                                       KnownBits &Known,
37107                                       const APInt &DemandedElts,
37108                                       const SelectionDAG &DAG, unsigned Depth) {
37109   KnownBits Known2;
37110   unsigned NumSrcElts = LHS.getValueType().getVectorNumElements();
37111   APInt DemandedSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
37112   Known = DAG.computeKnownBits(RHS, DemandedSrcElts, Depth + 1);
37113   Known2 = DAG.computeKnownBits(LHS, DemandedSrcElts, Depth + 1);
37114   Known = KnownBits::abdu(Known, Known2).zext(16);
37115   // Known = (((D0 + D1) + (D2 + D3)) + ((D4 + D5) + (D6 + D7)))
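  // Each computeForAddSub(Known, Known) below doubles how many absolute
  // differences are accounted for: 2, then 4, then all 8 i8 lanes feeding one
  // i64 result element.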
37116   Known = KnownBits::computeForAddSub(/*Add=*/true, /*NSW=*/true, /*NUW=*/true,
37117                                       Known, Known);
37118   Known = KnownBits::computeForAddSub(/*Add=*/true, /*NSW=*/true, /*NUW=*/true,
37119                                       Known, Known);
37120   Known = KnownBits::computeForAddSub(/*Add=*/true, /*NSW=*/true, /*NUW=*/true,
37121                                       Known, Known);
37122   Known = Known.zext(64);
37123 }
37124 
37125 static void computeKnownBitsForPMADDWD(SDValue LHS, SDValue RHS,
37126                                        KnownBits &Known,
37127                                        const APInt &DemandedElts,
37128                                        const SelectionDAG &DAG,
37129                                        unsigned Depth) {
37130   unsigned NumSrcElts = LHS.getValueType().getVectorNumElements();
37131 
37132   // Multiply signed i16 elements to create i32 values and add Lo/Hi pairs.
37133   APInt DemandedSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
37134   APInt DemandedLoElts =
37135       DemandedSrcElts & APInt::getSplat(NumSrcElts, APInt(2, 0b01));
37136   APInt DemandedHiElts =
37137       DemandedSrcElts & APInt::getSplat(NumSrcElts, APInt(2, 0b10));
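  // The splat masks pick alternating source lanes: 0b01 keeps the even (low)
  // i16 of each pair and 0b10 keeps the odd (high) i16, matching how PMADDWD
  // pairs source lanes 2*i and 2*i+1 into result element i.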
37138   KnownBits LHSLo = DAG.computeKnownBits(LHS, DemandedLoElts, Depth + 1);
37139   KnownBits LHSHi = DAG.computeKnownBits(LHS, DemandedHiElts, Depth + 1);
37140   KnownBits RHSLo = DAG.computeKnownBits(RHS, DemandedLoElts, Depth + 1);
37141   KnownBits RHSHi = DAG.computeKnownBits(RHS, DemandedHiElts, Depth + 1);
37142   KnownBits Lo = KnownBits::mul(LHSLo.sext(32), RHSLo.sext(32));
37143   KnownBits Hi = KnownBits::mul(LHSHi.sext(32), RHSHi.sext(32));
37144   Known = KnownBits::computeForAddSub(/*Add=*/true, /*NSW=*/false,
37145                                       /*NUW=*/false, Lo, Hi);
37146 }
37147 
37148 static void computeKnownBitsForPMADDUBSW(SDValue LHS, SDValue RHS,
37149                                          KnownBits &Known,
37150                                          const APInt &DemandedElts,
37151                                          const SelectionDAG &DAG,
37152                                          unsigned Depth) {
37153   unsigned NumSrcElts = LHS.getValueType().getVectorNumElements();
37154 
37155   // Multiply unsigned/signed i8 elements to create i16 values and add_sat Lo/Hi
37156   // pairs.
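  // Note the operand asymmetry: the LHS i8 lanes are unsigned and the RHS i8
  // lanes are signed, which is why the Lo/Hi products below zero-extend the
  // LHS but sign-extend the RHS before the saturating add.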
37157   APInt DemandedSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
37158   APInt DemandedLoElts =
37159       DemandedSrcElts & APInt::getSplat(NumSrcElts, APInt(2, 0b01));
37160   APInt DemandedHiElts =
37161       DemandedSrcElts & APInt::getSplat(NumSrcElts, APInt(2, 0b10));
37162   KnownBits LHSLo = DAG.computeKnownBits(LHS, DemandedLoElts, Depth + 1);
37163   KnownBits LHSHi = DAG.computeKnownBits(LHS, DemandedHiElts, Depth + 1);
37164   KnownBits RHSLo = DAG.computeKnownBits(RHS, DemandedLoElts, Depth + 1);
37165   KnownBits RHSHi = DAG.computeKnownBits(RHS, DemandedHiElts, Depth + 1);
37166   KnownBits Lo = KnownBits::mul(LHSLo.zext(16), RHSLo.sext(16));
37167   KnownBits Hi = KnownBits::mul(LHSHi.zext(16), RHSHi.sext(16));
37168   Known = KnownBits::sadd_sat(Lo, Hi);
37169 }
37170 
37171 static KnownBits computeKnownBitsForHorizontalOperation(
37172     const SDValue Op, const APInt &DemandedElts, unsigned Depth,
37173     const SelectionDAG &DAG,
37174     const function_ref<KnownBits(const KnownBits &, const KnownBits &)>
37175         KnownBitsFunc) {
37176   APInt DemandedEltsLHS, DemandedEltsRHS;
37177   getHorizDemandedEltsForFirstOperand(Op.getValueType().getSizeInBits(),
37178                                       DemandedElts, DemandedEltsLHS,
37179                                       DemandedEltsRHS);
37180 
37181   const auto ComputeForSingleOpFunc =
37182       [&DAG, Depth, KnownBitsFunc](SDValue Op, APInt &DemandedEltsOp) {
37183         return KnownBitsFunc(
37184             DAG.computeKnownBits(Op, DemandedEltsOp, Depth + 1),
37185             DAG.computeKnownBits(Op, DemandedEltsOp << 1, Depth + 1));
37186       };
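  // Each horizontal result lane is KnownBitsFunc applied to a pair of
  // adjacent source lanes; the second computeKnownBits call uses the mask
  // shifted by one (DemandedEltsOp << 1) to cover the other lane of each
  // pair.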
37187 
37188   if (DemandedEltsRHS.isZero())
37189     return ComputeForSingleOpFunc(Op.getOperand(0), DemandedEltsLHS);
37190   if (DemandedEltsLHS.isZero())
37191     return ComputeForSingleOpFunc(Op.getOperand(1), DemandedEltsRHS);
37192 
37193   return ComputeForSingleOpFunc(Op.getOperand(0), DemandedEltsLHS)
37194       .intersectWith(ComputeForSingleOpFunc(Op.getOperand(1), DemandedEltsRHS));
37195 }
37196 
37197 void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
37198                                                       KnownBits &Known,
37199                                                       const APInt &DemandedElts,
37200                                                       const SelectionDAG &DAG,
37201                                                       unsigned Depth) const {
37202   unsigned BitWidth = Known.getBitWidth();
37203   unsigned NumElts = DemandedElts.getBitWidth();
37204   unsigned Opc = Op.getOpcode();
37205   EVT VT = Op.getValueType();
37206   assert((Opc >= ISD::BUILTIN_OP_END ||
37207           Opc == ISD::INTRINSIC_WO_CHAIN ||
37208           Opc == ISD::INTRINSIC_W_CHAIN ||
37209           Opc == ISD::INTRINSIC_VOID) &&
37210          "Should use MaskedValueIsZero if you don't know whether Op"
37211          " is a target node!");
37212 
37213   Known.resetAll();
37214   switch (Opc) {
37215   default: break;
37216   case X86ISD::MUL_IMM: {
37217     KnownBits Known2;
37218     Known = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
37219     Known2 = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
37220     Known = KnownBits::mul(Known, Known2);
37221     break;
37222   }
37223   case X86ISD::SETCC:
37224     Known.Zero.setBitsFrom(1);
37225     break;
37226   case X86ISD::MOVMSK: {
37227     unsigned NumLoBits = Op.getOperand(0).getValueType().getVectorNumElements();
37228     Known.Zero.setBitsFrom(NumLoBits);
37229     break;
37230   }
37231   case X86ISD::PEXTRB:
37232   case X86ISD::PEXTRW: {
37233     SDValue Src = Op.getOperand(0);
37234     EVT SrcVT = Src.getValueType();
37235     APInt DemandedElt = APInt::getOneBitSet(SrcVT.getVectorNumElements(),
37236                                             Op.getConstantOperandVal(1));
37237     Known = DAG.computeKnownBits(Src, DemandedElt, Depth + 1);
37238     Known = Known.anyextOrTrunc(BitWidth);
37239     Known.Zero.setBitsFrom(SrcVT.getScalarSizeInBits());
37240     break;
37241   }
37242   case X86ISD::VSRAI:
37243   case X86ISD::VSHLI:
37244   case X86ISD::VSRLI: {
37245     unsigned ShAmt = Op.getConstantOperandVal(1);
37246     if (ShAmt >= VT.getScalarSizeInBits()) {
37247       // Out of range logical bit shifts are guaranteed to be zero.
37248       // Out of range arithmetic bit shifts splat the sign bit.
37249       if (Opc != X86ISD::VSRAI) {
37250         Known.setAllZero();
37251         break;
37252       }
37253 
37254       ShAmt = VT.getScalarSizeInBits() - 1;
37255     }
37256 
37257     Known = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
37258     if (Opc == X86ISD::VSHLI) {
37259       Known.Zero <<= ShAmt;
37260       Known.One <<= ShAmt;
37261       // Low bits are known zero.
37262       Known.Zero.setLowBits(ShAmt);
37263     } else if (Opc == X86ISD::VSRLI) {
37264       Known.Zero.lshrInPlace(ShAmt);
37265       Known.One.lshrInPlace(ShAmt);
37266       // High bits are known zero.
37267       Known.Zero.setHighBits(ShAmt);
37268     } else {
37269       Known.Zero.ashrInPlace(ShAmt);
37270       Known.One.ashrInPlace(ShAmt);
37271     }
37272     break;
37273   }
37274   case X86ISD::PACKUS: {
37275     // PACKUS is just a truncation if the upper half is zero.
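    // E.g. PACKUSWB of v8i16 inputs whose upper 8 bits are known zero behaves
    // exactly like a truncation to v16i8; if any upper bit may be set, the
    // unsigned saturation defeats this analysis, hence the resetAll() below.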
37276     APInt DemandedLHS, DemandedRHS;
37277     getPackDemandedElts(VT, DemandedElts, DemandedLHS, DemandedRHS);
37278 
37279     Known.One = APInt::getAllOnes(BitWidth * 2);
37280     Known.Zero = APInt::getAllOnes(BitWidth * 2);
37281 
37282     KnownBits Known2;
37283     if (!!DemandedLHS) {
37284       Known2 = DAG.computeKnownBits(Op.getOperand(0), DemandedLHS, Depth + 1);
37285       Known = Known.intersectWith(Known2);
37286     }
37287     if (!!DemandedRHS) {
37288       Known2 = DAG.computeKnownBits(Op.getOperand(1), DemandedRHS, Depth + 1);
37289       Known = Known.intersectWith(Known2);
37290     }
37291 
37292     if (Known.countMinLeadingZeros() < BitWidth)
37293       Known.resetAll();
37294     Known = Known.trunc(BitWidth);
37295     break;
37296   }
37297   case X86ISD::PSHUFB: {
37298     SDValue Src = Op.getOperand(0);
37299     SDValue Idx = Op.getOperand(1);
37300 
37301     // If the index vector is never negative (MSB is zero), then all elements
37302     // come from the source vector. This is useful for cases where
37303     // PSHUFB is being used as a LUT (ctpop etc.) - the target shuffle handling
37304     // below will handle the more common constant shuffle mask case.
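    // E.g. the vector ctpop lowering indexes a constant 16-entry nibble LUT
    // with PSHUFB; the nibble indices always have their MSB clear, so no lane
    // is zeroed and the known bits of the LUT source apply directly.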
37305     KnownBits KnownIdx = DAG.computeKnownBits(Idx, DemandedElts, Depth + 1);
37306     if (KnownIdx.isNonNegative())
37307       Known = DAG.computeKnownBits(Src, Depth + 1);
37308     break;
37309   }
37310   case X86ISD::VBROADCAST: {
37311     SDValue Src = Op.getOperand(0);
37312     if (!Src.getSimpleValueType().isVector()) {
37313       Known = DAG.computeKnownBits(Src, Depth + 1);
37314       return;
37315     }
37316     break;
37317   }
37318   case X86ISD::AND: {
37319     if (Op.getResNo() == 0) {
37320       KnownBits Known2;
37321       Known = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
37322       Known2 = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
37323       Known &= Known2;
37324     }
37325     break;
37326   }
37327   case X86ISD::ANDNP: {
37328     KnownBits Known2;
37329     Known = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
37330     Known2 = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
37331 
37332     // ANDNP = (~X & Y);
37333     Known.One &= Known2.Zero;
37334     Known.Zero |= Known2.One;
37335     break;
37336   }
37337   case X86ISD::FOR: {
37338     KnownBits Known2;
37339     Known = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
37340     Known2 = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
37341 
37342     Known |= Known2;
37343     break;
37344   }
37345   case X86ISD::PSADBW: {
37346     SDValue LHS = Op.getOperand(0);
37347     SDValue RHS = Op.getOperand(1);
37348     assert(VT.getScalarType() == MVT::i64 &&
37349            LHS.getValueType() == RHS.getValueType() &&
37350            LHS.getValueType().getScalarType() == MVT::i8 &&
37351            "Unexpected PSADBW types");
37352     computeKnownBitsForPSADBW(LHS, RHS, Known, DemandedElts, DAG, Depth);
37353     break;
37354   }
37355   case X86ISD::PCMPGT:
37356   case X86ISD::PCMPEQ: {
37357     KnownBits KnownLhs =
37358         DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
37359     KnownBits KnownRhs =
37360         DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
37361     std::optional<bool> Res = Opc == X86ISD::PCMPEQ
37362                                   ? KnownBits::eq(KnownLhs, KnownRhs)
37363                                   : KnownBits::sgt(KnownLhs, KnownRhs);
37364     if (Res) {
37365       if (*Res)
37366         Known.setAllOnes();
37367       else
37368         Known.setAllZero();
37369     }
37370     break;
37371   }
37372   case X86ISD::VPMADDWD: {
37373     SDValue LHS = Op.getOperand(0);
37374     SDValue RHS = Op.getOperand(1);
37375     assert(VT.getVectorElementType() == MVT::i32 &&
37376            LHS.getValueType() == RHS.getValueType() &&
37377            LHS.getValueType().getVectorElementType() == MVT::i16 &&
37378            "Unexpected PMADDWD types");
37379     computeKnownBitsForPMADDWD(LHS, RHS, Known, DemandedElts, DAG, Depth);
37380     break;
37381   }
37382   case X86ISD::VPMADDUBSW: {
37383     SDValue LHS = Op.getOperand(0);
37384     SDValue RHS = Op.getOperand(1);
37385     assert(VT.getVectorElementType() == MVT::i16 &&
37386            LHS.getValueType() == RHS.getValueType() &&
37387            LHS.getValueType().getVectorElementType() == MVT::i8 &&
37388            "Unexpected PMADDUBSW types");
37389     computeKnownBitsForPMADDUBSW(LHS, RHS, Known, DemandedElts, DAG, Depth);
37390     break;
37391   }
37392   case X86ISD::PMULUDQ: {
37393     KnownBits Known2;
37394     Known = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
37395     Known2 = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
37396 
37397     Known = Known.trunc(BitWidth / 2).zext(BitWidth);
37398     Known2 = Known2.trunc(BitWidth / 2).zext(BitWidth);
37399     Known = KnownBits::mul(Known, Known2);
37400     break;
37401   }
37402   case X86ISD::CMOV: {
37403     Known = DAG.computeKnownBits(Op.getOperand(1), Depth + 1);
37404     // If we don't know any bits, early out.
37405     if (Known.isUnknown())
37406       break;
37407     KnownBits Known2 = DAG.computeKnownBits(Op.getOperand(0), Depth + 1);
37408 
37409     // Only known if known in both the LHS and RHS.
37410     Known = Known.intersectWith(Known2);
37411     break;
37412   }
37413   case X86ISD::BEXTR:
37414   case X86ISD::BEXTRI: {
37415     SDValue Op0 = Op.getOperand(0);
37416     SDValue Op1 = Op.getOperand(1);
37417 
37418     if (auto* Cst1 = dyn_cast<ConstantSDNode>(Op1)) {
37419       unsigned Shift = Cst1->getAPIntValue().extractBitsAsZExtValue(8, 0);
37420       unsigned Length = Cst1->getAPIntValue().extractBitsAsZExtValue(8, 8);
37421 
37422       // If the length is 0, the result is 0.
37423       if (Length == 0) {
37424         Known.setAllZero();
37425         break;
37426       }
37427 
37428       if ((Shift + Length) <= BitWidth) {
37429         Known = DAG.computeKnownBits(Op0, Depth + 1);
37430         Known = Known.extractBits(Length, Shift);
37431         Known = Known.zextOrTrunc(BitWidth);
37432       }
37433     }
37434     break;
37435   }
37436   case X86ISD::PDEP: {
37437     KnownBits Known2;
37438     Known = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
37439     Known2 = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
37440     // Zeros are retained from the mask operand. But not ones.
37441     Known.One.clearAllBits();
37442     // The result will have at least as many trailing zeros as the non-mask
37443     // operand since bits can only map to the same or higher bit position.
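    // E.g. if the non-mask operand has 3 trailing zero bits, the first three
    // deposited bits are zero and every deposited bit lands at a position no
    // lower than its source index, so the result keeps at least 3 trailing
    // zeros.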
37444     Known.Zero.setLowBits(Known2.countMinTrailingZeros());
37445     break;
37446   }
37447   case X86ISD::PEXT: {
37448     Known = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
37449     // The result has at least as many leading zeros as the mask has known zero bits.
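    // E.g. with a 32-bit mask that has 20 known zero bits, at most 12 bits
    // are extracted into the low end, so the top 20 result bits are known
    // zero.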
37450     unsigned Count = Known.Zero.popcount();
37451     Known.Zero = APInt::getHighBitsSet(BitWidth, Count);
37452     Known.One.clearAllBits();
37453     break;
37454   }
37455   case X86ISD::VTRUNC:
37456   case X86ISD::VTRUNCS:
37457   case X86ISD::VTRUNCUS:
37458   case X86ISD::CVTSI2P:
37459   case X86ISD::CVTUI2P:
37460   case X86ISD::CVTP2SI:
37461   case X86ISD::CVTP2UI:
37462   case X86ISD::MCVTP2SI:
37463   case X86ISD::MCVTP2UI:
37464   case X86ISD::CVTTP2SI:
37465   case X86ISD::CVTTP2UI:
37466   case X86ISD::MCVTTP2SI:
37467   case X86ISD::MCVTTP2UI:
37468   case X86ISD::MCVTSI2P:
37469   case X86ISD::MCVTUI2P:
37470   case X86ISD::VFPROUND:
37471   case X86ISD::VMFPROUND:
37472   case X86ISD::CVTPS2PH:
37473   case X86ISD::MCVTPS2PH: {
37474     // Truncations/Conversions - upper elements are known zero.
37475     EVT SrcVT = Op.getOperand(0).getValueType();
37476     if (SrcVT.isVector()) {
37477       unsigned NumSrcElts = SrcVT.getVectorNumElements();
37478       if (NumElts > NumSrcElts && DemandedElts.countr_zero() >= NumSrcElts)
37479         Known.setAllZero();
37480     }
37481     break;
37482   }
37483   case X86ISD::STRICT_CVTTP2SI:
37484   case X86ISD::STRICT_CVTTP2UI:
37485   case X86ISD::STRICT_CVTSI2P:
37486   case X86ISD::STRICT_CVTUI2P:
37487   case X86ISD::STRICT_VFPROUND:
37488   case X86ISD::STRICT_CVTPS2PH: {
37489     // Strict Conversions - upper elements are known zero.
37490     EVT SrcVT = Op.getOperand(1).getValueType();
37491     if (SrcVT.isVector()) {
37492       unsigned NumSrcElts = SrcVT.getVectorNumElements();
37493       if (NumElts > NumSrcElts && DemandedElts.countr_zero() >= NumSrcElts)
37494         Known.setAllZero();
37495     }
37496     break;
37497   }
37498   case X86ISD::MOVQ2DQ: {
37499     // Move from MMX to XMM. Upper half of XMM should be 0.
37500     if (DemandedElts.countr_zero() >= (NumElts / 2))
37501       Known.setAllZero();
37502     break;
37503   }
37504   case X86ISD::VBROADCAST_LOAD: {
37505     APInt UndefElts;
37506     SmallVector<APInt, 16> EltBits;
37507     if (getTargetConstantBitsFromNode(Op, BitWidth, UndefElts, EltBits,
37508                                       /*AllowWholeUndefs*/ false,
37509                                       /*AllowPartialUndefs*/ false)) {
37510       Known.Zero.setAllBits();
37511       Known.One.setAllBits();
37512       for (unsigned I = 0; I != NumElts; ++I) {
37513         if (!DemandedElts[I])
37514           continue;
37515         if (UndefElts[I]) {
37516           Known.resetAll();
37517           break;
37518         }
37519         KnownBits Known2 = KnownBits::makeConstant(EltBits[I]);
37520         Known = Known.intersectWith(Known2);
37521       }
37522       return;
37523     }
37524     break;
37525   }
37526   case X86ISD::HADD:
37527   case X86ISD::HSUB: {
37528     Known = computeKnownBitsForHorizontalOperation(
37529         Op, DemandedElts, Depth, DAG,
37530         [Opc](const KnownBits &KnownLHS, const KnownBits &KnownRHS) {
37531           return KnownBits::computeForAddSub(
37532               /*Add=*/Opc == X86ISD::HADD, /*NSW=*/false, /*NUW=*/false,
37533               KnownLHS, KnownRHS);
37534         });
37535     break;
37536   }
37537   case ISD::INTRINSIC_WO_CHAIN: {
37538     switch (Op->getConstantOperandVal(0)) {
37539     case Intrinsic::x86_sse2_pmadd_wd:
37540     case Intrinsic::x86_avx2_pmadd_wd:
37541     case Intrinsic::x86_avx512_pmaddw_d_512: {
37542       SDValue LHS = Op.getOperand(1);
37543       SDValue RHS = Op.getOperand(2);
37544       assert(VT.getScalarType() == MVT::i32 &&
37545              LHS.getValueType() == RHS.getValueType() &&
37546              LHS.getValueType().getScalarType() == MVT::i16 &&
37547              "Unexpected PMADDWD types");
37548       computeKnownBitsForPMADDWD(LHS, RHS, Known, DemandedElts, DAG, Depth);
37549       break;
37550     }
37551     case Intrinsic::x86_ssse3_pmadd_ub_sw_128:
37552     case Intrinsic::x86_avx2_pmadd_ub_sw:
37553     case Intrinsic::x86_avx512_pmaddubs_w_512: {
37554       SDValue LHS = Op.getOperand(1);
37555       SDValue RHS = Op.getOperand(2);
37556       assert(VT.getScalarType() == MVT::i16 &&
37557              LHS.getValueType() == RHS.getValueType() &&
37558              LHS.getValueType().getScalarType() == MVT::i8 &&
37559              "Unexpected PMADDUBSW types");
37560       computeKnownBitsForPMADDUBSW(LHS, RHS, Known, DemandedElts, DAG, Depth);
37561       break;
37562     }
37563     case Intrinsic::x86_sse2_psad_bw:
37564     case Intrinsic::x86_avx2_psad_bw:
37565     case Intrinsic::x86_avx512_psad_bw_512: {
37566       SDValue LHS = Op.getOperand(1);
37567       SDValue RHS = Op.getOperand(2);
37568       assert(VT.getScalarType() == MVT::i64 &&
37569              LHS.getValueType() == RHS.getValueType() &&
37570              LHS.getValueType().getScalarType() == MVT::i8 &&
37571              "Unexpected PSADBW types");
37572       computeKnownBitsForPSADBW(LHS, RHS, Known, DemandedElts, DAG, Depth);
37573       break;
37574     }
37575     }
37576     break;
37577   }
37578   }
37579 
37580   // Handle target shuffles.
37581   // TODO - use resolveTargetShuffleInputs once we can limit recursive depth.
37582   if (isTargetShuffle(Opc)) {
37583     SmallVector<int, 64> Mask;
37584     SmallVector<SDValue, 2> Ops;
37585     if (getTargetShuffleMask(Op, true, Ops, Mask)) {
37586       unsigned NumOps = Ops.size();
37587       unsigned NumElts = VT.getVectorNumElements();
37588       if (Mask.size() == NumElts) {
37589         SmallVector<APInt, 2> DemandedOps(NumOps, APInt(NumElts, 0));
37590         Known.Zero.setAllBits(); Known.One.setAllBits();
37591         for (unsigned i = 0; i != NumElts; ++i) {
37592           if (!DemandedElts[i])
37593             continue;
37594           int M = Mask[i];
37595           if (M == SM_SentinelUndef) {
37596             // For UNDEF elements, we don't know anything about the common state
37597             // of the shuffle result.
37598             Known.resetAll();
37599             break;
37600           }
37601           if (M == SM_SentinelZero) {
37602             Known.One.clearAllBits();
37603             continue;
37604           }
37605           assert(0 <= M && (unsigned)M < (NumOps * NumElts) &&
37606                  "Shuffle index out of range");
37607 
37608           unsigned OpIdx = (unsigned)M / NumElts;
37609           unsigned EltIdx = (unsigned)M % NumElts;
37610           if (Ops[OpIdx].getValueType() != VT) {
37611             // TODO - handle target shuffle ops with different value types.
37612             Known.resetAll();
37613             break;
37614           }
37615           DemandedOps[OpIdx].setBit(EltIdx);
37616         }
37617         // Known bits are the values that are shared by every demanded element.
37618         for (unsigned i = 0; i != NumOps && !Known.isUnknown(); ++i) {
37619           if (!DemandedOps[i])
37620             continue;
37621           KnownBits Known2 =
37622               DAG.computeKnownBits(Ops[i], DemandedOps[i], Depth + 1);
37623           Known = Known.intersectWith(Known2);
37624         }
37625       }
37626     }
37627   }
37628 }
37629 
37630 unsigned X86TargetLowering::ComputeNumSignBitsForTargetNode(
37631     SDValue Op, const APInt &DemandedElts, const SelectionDAG &DAG,
37632     unsigned Depth) const {
37633   EVT VT = Op.getValueType();
37634   unsigned VTBits = VT.getScalarSizeInBits();
37635   unsigned Opcode = Op.getOpcode();
37636   switch (Opcode) {
37637   case X86ISD::SETCC_CARRY:
37638     // SETCC_CARRY sets the dest to ~0 for true or 0 for false.
37639     return VTBits;
37640 
37641   case X86ISD::VTRUNC: {
37642     SDValue Src = Op.getOperand(0);
37643     MVT SrcVT = Src.getSimpleValueType();
37644     unsigned NumSrcBits = SrcVT.getScalarSizeInBits();
37645     assert(VTBits < NumSrcBits && "Illegal truncation input type");
37646     APInt DemandedSrc = DemandedElts.zextOrTrunc(SrcVT.getVectorNumElements());
37647     unsigned Tmp = DAG.ComputeNumSignBits(Src, DemandedSrc, Depth + 1);
37648     if (Tmp > (NumSrcBits - VTBits))
37649       return Tmp - (NumSrcBits - VTBits);
37650     return 1;
37651   }
37652 
37653   case X86ISD::PACKSS: {
37654     // PACKSS is just a truncation if the sign bits extend to the packed size.
37655     APInt DemandedLHS, DemandedRHS;
37656     getPackDemandedElts(Op.getValueType(), DemandedElts, DemandedLHS,
37657                         DemandedRHS);
37658 
37659     // Helper to detect PACKSSDW(BITCAST(PACKSSDW(X)),BITCAST(PACKSSDW(Y)))
37660     // patterns often used to compact vXi64 allsignbit patterns.
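    // When both nested pack inputs are all-sign-bit i64 vectors, every packed
    // i16 is 0 or -1, so each 32-bit lane of V is also all-zeros or all-ones
    // and 32 sign bits can safely be reported.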
37661     auto NumSignBitsPACKSS = [&](SDValue V, const APInt &Elts) -> unsigned {
37662       SDValue BC = peekThroughBitcasts(V);
37663       if (BC.getOpcode() == X86ISD::PACKSS &&
37664           BC.getScalarValueSizeInBits() == 16 &&
37665           V.getScalarValueSizeInBits() == 32) {
37666         SDValue BC0 = peekThroughBitcasts(BC.getOperand(0));
37667         SDValue BC1 = peekThroughBitcasts(BC.getOperand(1));
37668         if (BC0.getScalarValueSizeInBits() == 64 &&
37669             BC1.getScalarValueSizeInBits() == 64 &&
37670             DAG.ComputeNumSignBits(BC0, Depth + 1) == 64 &&
37671             DAG.ComputeNumSignBits(BC1, Depth + 1) == 64)
37672           return 32;
37673       }
37674       return DAG.ComputeNumSignBits(V, Elts, Depth + 1);
37675     };
37676 
37677     unsigned SrcBits = Op.getOperand(0).getScalarValueSizeInBits();
37678     unsigned Tmp0 = SrcBits, Tmp1 = SrcBits;
37679     if (!!DemandedLHS)
37680       Tmp0 = NumSignBitsPACKSS(Op.getOperand(0), DemandedLHS);
37681     if (!!DemandedRHS)
37682       Tmp1 = NumSignBitsPACKSS(Op.getOperand(1), DemandedRHS);
37683     unsigned Tmp = std::min(Tmp0, Tmp1);
37684     if (Tmp > (SrcBits - VTBits))
37685       return Tmp - (SrcBits - VTBits);
37686     return 1;
37687   }
37688 
37689   case X86ISD::VBROADCAST: {
37690     SDValue Src = Op.getOperand(0);
37691     if (!Src.getSimpleValueType().isVector())
37692       return DAG.ComputeNumSignBits(Src, Depth + 1);
37693     break;
37694   }
37695 
37696   case X86ISD::VSHLI: {
37697     SDValue Src = Op.getOperand(0);
37698     const APInt &ShiftVal = Op.getConstantOperandAPInt(1);
37699     if (ShiftVal.uge(VTBits))
37700       return VTBits; // Shifted all bits out --> zero.
37701     unsigned Tmp = DAG.ComputeNumSignBits(Src, DemandedElts, Depth + 1);
37702     if (ShiftVal.uge(Tmp))
37703       return 1; // Shifted all sign bits out --> unknown.
37704     return Tmp - ShiftVal.getZExtValue();
37705   }
37706 
37707   case X86ISD::VSRAI: {
37708     SDValue Src = Op.getOperand(0);
37709     APInt ShiftVal = Op.getConstantOperandAPInt(1);
37710     if (ShiftVal.uge(VTBits - 1))
37711       return VTBits; // Sign splat.
37712     unsigned Tmp = DAG.ComputeNumSignBits(Src, DemandedElts, Depth + 1);
37713     ShiftVal += Tmp;
37714     return ShiftVal.uge(VTBits) ? VTBits : ShiftVal.getZExtValue();
37715   }
37716 
37717   case X86ISD::FSETCC:
37718     // cmpss/cmpsd return zero/all-bits result values in the bottom element.
37719     if (VT == MVT::f32 || VT == MVT::f64 ||
37720         ((VT == MVT::v4f32 || VT == MVT::v2f64) && DemandedElts == 1))
37721       return VTBits;
37722     break;
37723 
37724   case X86ISD::PCMPGT:
37725   case X86ISD::PCMPEQ:
37726   case X86ISD::CMPP:
37727   case X86ISD::VPCOM:
37728   case X86ISD::VPCOMU:
37729     // Vector compares return zero/all-bits result values.
37730     return VTBits;
37731 
37732   case X86ISD::ANDNP: {
37733     unsigned Tmp0 =
37734         DAG.ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
37735     if (Tmp0 == 1) return 1; // Early out.
37736     unsigned Tmp1 =
37737         DAG.ComputeNumSignBits(Op.getOperand(1), DemandedElts, Depth + 1);
37738     return std::min(Tmp0, Tmp1);
37739   }
37740 
37741   case X86ISD::CMOV: {
37742     unsigned Tmp0 = DAG.ComputeNumSignBits(Op.getOperand(0), Depth+1);
37743     if (Tmp0 == 1) return 1;  // Early out.
37744     unsigned Tmp1 = DAG.ComputeNumSignBits(Op.getOperand(1), Depth+1);
37745     return std::min(Tmp0, Tmp1);
37746   }
37747   }
37748 
37749   // Handle target shuffles.
37750   // TODO - use resolveTargetShuffleInputs once we can limit recursive depth.
37751   if (isTargetShuffle(Opcode)) {
37752     SmallVector<int, 64> Mask;
37753     SmallVector<SDValue, 2> Ops;
37754     if (getTargetShuffleMask(Op, true, Ops, Mask)) {
37755       unsigned NumOps = Ops.size();
37756       unsigned NumElts = VT.getVectorNumElements();
37757       if (Mask.size() == NumElts) {
37758         SmallVector<APInt, 2> DemandedOps(NumOps, APInt(NumElts, 0));
37759         for (unsigned i = 0; i != NumElts; ++i) {
37760           if (!DemandedElts[i])
37761             continue;
37762           int M = Mask[i];
37763           if (M == SM_SentinelUndef) {
37764             // For UNDEF elements, we don't know anything about the common state
37765             // of the shuffle result.
37766             return 1;
37767           } else if (M == SM_SentinelZero) {
37768             // Zero = all sign bits.
37769             continue;
37770           }
37771           assert(0 <= M && (unsigned)M < (NumOps * NumElts) &&
37772                  "Shuffle index out of range");
37773 
37774           unsigned OpIdx = (unsigned)M / NumElts;
37775           unsigned EltIdx = (unsigned)M % NumElts;
37776           if (Ops[OpIdx].getValueType() != VT) {
37777             // TODO - handle target shuffle ops with different value types.
37778             return 1;
37779           }
37780           DemandedOps[OpIdx].setBit(EltIdx);
37781         }
37782         unsigned Tmp0 = VTBits;
37783         for (unsigned i = 0; i != NumOps && Tmp0 > 1; ++i) {
37784           if (!DemandedOps[i])
37785             continue;
37786           unsigned Tmp1 =
37787               DAG.ComputeNumSignBits(Ops[i], DemandedOps[i], Depth + 1);
37788           Tmp0 = std::min(Tmp0, Tmp1);
37789         }
37790         return Tmp0;
37791       }
37792     }
37793   }
37794 
37795   // Fallback case.
37796   return 1;
37797 }
37798 
37799 SDValue X86TargetLowering::unwrapAddress(SDValue N) const {
37800   if (N->getOpcode() == X86ISD::Wrapper || N->getOpcode() == X86ISD::WrapperRIP)
37801     return N->getOperand(0);
37802   return N;
37803 }
37804 
37805 // Helper to look for a normal load that can be narrowed into a vzload with the
37806 // specified VT and memory VT. Returns SDValue() on failure.
37807 static SDValue narrowLoadToVZLoad(LoadSDNode *LN, MVT MemVT, MVT VT,
37808                                   SelectionDAG &DAG) {
37809   // Can't if the load is volatile or atomic.
37810   if (!LN->isSimple())
37811     return SDValue();
37812 
37813   SDVTList Tys = DAG.getVTList(VT, MVT::Other);
37814   SDValue Ops[] = {LN->getChain(), LN->getBasePtr()};
37815   return DAG.getMemIntrinsicNode(X86ISD::VZEXT_LOAD, SDLoc(LN), Tys, Ops, MemVT,
37816                                  LN->getPointerInfo(), LN->getOriginalAlign(),
37817                                  LN->getMemOperand()->getFlags());
37818 }
37819 
37820 // Attempt to match a combined shuffle mask against supported unary shuffle
37821 // instructions.
37822 // TODO: Investigate sharing more of this with shuffle lowering.
37823 static bool matchUnaryShuffle(MVT MaskVT, ArrayRef<int> Mask,
37824                               bool AllowFloatDomain, bool AllowIntDomain,
37825                               SDValue V1, const SelectionDAG &DAG,
37826                               const X86Subtarget &Subtarget, unsigned &Shuffle,
37827                               MVT &SrcVT, MVT &DstVT) {
37828   unsigned NumMaskElts = Mask.size();
37829   unsigned MaskEltSize = MaskVT.getScalarSizeInBits();
37830 
37831   // Match against a VZEXT_MOVL vXi32 and vXi16 zero-extending instruction.
37832   if (Mask[0] == 0 &&
37833       (MaskEltSize == 32 || (MaskEltSize == 16 && Subtarget.hasFP16()))) {
37834     if ((isUndefOrZero(Mask[1]) && isUndefInRange(Mask, 2, NumMaskElts - 2)) ||
37835         (V1.getOpcode() == ISD::SCALAR_TO_VECTOR &&
37836          isUndefOrZeroInRange(Mask, 1, NumMaskElts - 1))) {
37837       Shuffle = X86ISD::VZEXT_MOVL;
37838       if (MaskEltSize == 16)
37839         SrcVT = DstVT = MaskVT.changeVectorElementType(MVT::f16);
37840       else
37841         SrcVT = DstVT = !Subtarget.hasSSE2() ? MVT::v4f32 : MaskVT;
37842       return true;
37843     }
37844   }
37845 
37846   // Match against an ANY/SIGN/ZERO_EXTEND_VECTOR_INREG instruction.
37847   // TODO: Add 512-bit vector support (split AVX512F and AVX512BW).
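  // E.g. (hypothetical masks) a v4i32 mask {0, Z, 1, Z}, with Z zero/undef,
  // matches Scale == 2 as ZERO_EXTEND_VECTOR_INREG v4i32 -> v2i64, while
  // {0, 0, 1, 1} with an all-sign-bit source matches the SIGN_EXTEND form.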
37848   if (AllowIntDomain && ((MaskVT.is128BitVector() && Subtarget.hasSSE41()) ||
37849                          (MaskVT.is256BitVector() && Subtarget.hasInt256()))) {
37850     unsigned MaxScale = 64 / MaskEltSize;
37851     bool UseSign = V1.getScalarValueSizeInBits() == MaskEltSize &&
37852                    DAG.ComputeNumSignBits(V1) == MaskEltSize;
37853     for (unsigned Scale = 2; Scale <= MaxScale; Scale *= 2) {
37854       bool MatchAny = true;
37855       bool MatchZero = true;
37856       bool MatchSign = UseSign;
37857       unsigned NumDstElts = NumMaskElts / Scale;
37858       for (unsigned i = 0;
37859            i != NumDstElts && (MatchAny || MatchSign || MatchZero); ++i) {
37860         if (!isUndefOrEqual(Mask[i * Scale], (int)i)) {
37861           MatchAny = MatchSign = MatchZero = false;
37862           break;
37863         }
37864         unsigned Pos = (i * Scale) + 1;
37865         unsigned Len = Scale - 1;
37866         MatchAny &= isUndefInRange(Mask, Pos, Len);
37867         MatchZero &= isUndefOrZeroInRange(Mask, Pos, Len);
37868         MatchSign &= isUndefOrEqualInRange(Mask, (int)i, Pos, Len);
37869       }
37870       if (MatchAny || MatchSign || MatchZero) {
37871         assert((MatchSign || MatchZero) &&
37872                "Failed to match sext/zext but matched aext?");
37873         unsigned SrcSize = std::max(128u, NumDstElts * MaskEltSize);
37874         MVT ScalarTy = MaskVT.isInteger() ? MaskVT.getScalarType()
37875                                           : MVT::getIntegerVT(MaskEltSize);
37876         SrcVT = MVT::getVectorVT(ScalarTy, SrcSize / MaskEltSize);
37877 
37878         Shuffle = unsigned(
37879             MatchAny ? ISD::ANY_EXTEND
37880                      : (MatchSign ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND));
37881         if (SrcVT.getVectorNumElements() != NumDstElts)
37882           Shuffle = DAG.getOpcode_EXTEND_VECTOR_INREG(Shuffle);
37883 
37884         DstVT = MVT::getIntegerVT(Scale * MaskEltSize);
37885         DstVT = MVT::getVectorVT(DstVT, NumDstElts);
37886         return true;
37887       }
37888     }
37889   }
37890 
37891   // Match against a VZEXT_MOVL instruction; SSE1 only supports 32 bits (MOVSS).
37892   if (((MaskEltSize == 32) || (MaskEltSize == 64 && Subtarget.hasSSE2()) ||
37893        (MaskEltSize == 16 && Subtarget.hasFP16())) &&
37894       isUndefOrEqual(Mask[0], 0) &&
37895       isUndefOrZeroInRange(Mask, 1, NumMaskElts - 1)) {
37896     Shuffle = X86ISD::VZEXT_MOVL;
37897     if (MaskEltSize == 16)
37898       SrcVT = DstVT = MaskVT.changeVectorElementType(MVT::f16);
37899     else
37900       SrcVT = DstVT = !Subtarget.hasSSE2() ? MVT::v4f32 : MaskVT;
37901     return true;
37902   }
37903 
37904   // Check if we have SSE3, which will let us use MOVDDUP etc. These
37905   // instructions are no slower than UNPCKLPD but have the option to
37906   // fold the input operand into even an unaligned memory load.
37907   if (MaskVT.is128BitVector() && Subtarget.hasSSE3() && AllowFloatDomain) {
37908     if (isTargetShuffleEquivalent(MaskVT, Mask, {0, 0}, DAG, V1)) {
37909       Shuffle = X86ISD::MOVDDUP;
37910       SrcVT = DstVT = MVT::v2f64;
37911       return true;
37912     }
37913     if (isTargetShuffleEquivalent(MaskVT, Mask, {0, 0, 2, 2}, DAG, V1)) {
37914       Shuffle = X86ISD::MOVSLDUP;
37915       SrcVT = DstVT = MVT::v4f32;
37916       return true;
37917     }
37918     if (isTargetShuffleEquivalent(MaskVT, Mask, {1, 1, 3, 3}, DAG, V1)) {
37919       Shuffle = X86ISD::MOVSHDUP;
37920       SrcVT = DstVT = MVT::v4f32;
37921       return true;
37922     }
37923   }
37924 
37925   if (MaskVT.is256BitVector() && AllowFloatDomain) {
37926     assert(Subtarget.hasAVX() && "AVX required for 256-bit vector shuffles");
37927     if (isTargetShuffleEquivalent(MaskVT, Mask, {0, 0, 2, 2}, DAG, V1)) {
37928       Shuffle = X86ISD::MOVDDUP;
37929       SrcVT = DstVT = MVT::v4f64;
37930       return true;
37931     }
37932     if (isTargetShuffleEquivalent(MaskVT, Mask, {0, 0, 2, 2, 4, 4, 6, 6}, DAG,
37933                                   V1)) {
37934       Shuffle = X86ISD::MOVSLDUP;
37935       SrcVT = DstVT = MVT::v8f32;
37936       return true;
37937     }
37938     if (isTargetShuffleEquivalent(MaskVT, Mask, {1, 1, 3, 3, 5, 5, 7, 7}, DAG,
37939                                   V1)) {
37940       Shuffle = X86ISD::MOVSHDUP;
37941       SrcVT = DstVT = MVT::v8f32;
37942       return true;
37943     }
37944   }
37945 
37946   if (MaskVT.is512BitVector() && AllowFloatDomain) {
37947     assert(Subtarget.hasAVX512() &&
37948            "AVX512 required for 512-bit vector shuffles");
37949     if (isTargetShuffleEquivalent(MaskVT, Mask, {0, 0, 2, 2, 4, 4, 6, 6}, DAG,
37950                                   V1)) {
37951       Shuffle = X86ISD::MOVDDUP;
37952       SrcVT = DstVT = MVT::v8f64;
37953       return true;
37954     }
37955     if (isTargetShuffleEquivalent(
37956             MaskVT, Mask,
37957             {0, 0, 2, 2, 4, 4, 6, 6, 8, 8, 10, 10, 12, 12, 14, 14}, DAG, V1)) {
37958       Shuffle = X86ISD::MOVSLDUP;
37959       SrcVT = DstVT = MVT::v16f32;
37960       return true;
37961     }
37962     if (isTargetShuffleEquivalent(
37963             MaskVT, Mask,
37964             {1, 1, 3, 3, 5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15}, DAG, V1)) {
37965       Shuffle = X86ISD::MOVSHDUP;
37966       SrcVT = DstVT = MVT::v16f32;
37967       return true;
37968     }
37969   }
37970 
37971   return false;
37972 }
37973 
37974 // Attempt to match a combined shuffle mask against supported unary immediate
37975 // permute instructions.
37976 // TODO: Investigate sharing more of this with shuffle lowering.
37977 static bool matchUnaryPermuteShuffle(MVT MaskVT, ArrayRef<int> Mask,
37978                                      const APInt &Zeroable,
37979                                      bool AllowFloatDomain, bool AllowIntDomain,
37980                                      const SelectionDAG &DAG,
37981                                      const X86Subtarget &Subtarget,
37982                                      unsigned &Shuffle, MVT &ShuffleVT,
37983                                      unsigned &PermuteImm) {
37984   unsigned NumMaskElts = Mask.size();
37985   unsigned InputSizeInBits = MaskVT.getSizeInBits();
37986   unsigned MaskScalarSizeInBits = InputSizeInBits / NumMaskElts;
37987   MVT MaskEltVT = MVT::getIntegerVT(MaskScalarSizeInBits);
37988   bool ContainsZeros = isAnyZero(Mask);
37989 
37990   // Handle VPERMI/VPERMILPD vXi64/vXf64 patterns.
37991   if (!ContainsZeros && MaskScalarSizeInBits == 64) {
37992     // Check for lane crossing permutes.
37993     if (is128BitLaneCrossingShuffleMask(MaskEltVT, Mask)) {
37994       // PERMPD/PERMQ permutes within a 256-bit vector (AVX2+).
37995       if (Subtarget.hasAVX2() && MaskVT.is256BitVector()) {
37996         Shuffle = X86ISD::VPERMI;
37997         ShuffleVT = (AllowFloatDomain ? MVT::v4f64 : MVT::v4i64);
37998         PermuteImm = getV4X86ShuffleImm(Mask);
37999         return true;
38000       }
38001       if (Subtarget.hasAVX512() && MaskVT.is512BitVector()) {
38002         SmallVector<int, 4> RepeatedMask;
38003         if (is256BitLaneRepeatedShuffleMask(MVT::v8f64, Mask, RepeatedMask)) {
38004           Shuffle = X86ISD::VPERMI;
38005           ShuffleVT = (AllowFloatDomain ? MVT::v8f64 : MVT::v8i64);
38006           PermuteImm = getV4X86ShuffleImm(RepeatedMask);
38007           return true;
38008         }
38009       }
38010     } else if (AllowFloatDomain && Subtarget.hasAVX()) {
38011       // VPERMILPD can permute with a non-repeating shuffle.
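      // E.g. (hypothetical) a v4f64 mask {1, 0, 3, 2} stays within each
      // 128-bit half and yields PermuteImm 0b0101; bit i is simply Mask[i] & 1.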
38012       Shuffle = X86ISD::VPERMILPI;
38013       ShuffleVT = MVT::getVectorVT(MVT::f64, Mask.size());
38014       PermuteImm = 0;
38015       for (int i = 0, e = Mask.size(); i != e; ++i) {
38016         int M = Mask[i];
38017         if (M == SM_SentinelUndef)
38018           continue;
38019         assert(((M / 2) == (i / 2)) && "Out of range shuffle mask index");
38020         PermuteImm |= (M & 1) << i;
38021       }
38022       return true;
38023     }
38024   }
38025 
38026   // We check for both a shuffle match and a shift match. Loop twice so we
38027   // can order which we try to match first depending on target preference.
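        // On the first pass (Order == 0) the target-preferred form is tried:
        // the rotate/shift matching when preferLowerShuffleAsShift() is set,
        // otherwise the immediate permutes (PSHUFD/PSHUFLW/PSHUFHW).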
38028   for (unsigned Order = 0; Order < 2; ++Order) {
38029     if (Subtarget.preferLowerShuffleAsShift() ? (Order == 1) : (Order == 0)) {
38030       // Handle PSHUFD/VPERMILPI vXi32/vXf32 repeated patterns.
38031       // AVX introduced the VPERMILPD/VPERMILPS float permutes; before then we
38032       // had to use 2-input SHUFPD/SHUFPS shuffles (not handled here).
38033       if ((MaskScalarSizeInBits == 64 || MaskScalarSizeInBits == 32) &&
38034           !ContainsZeros && (AllowIntDomain || Subtarget.hasAVX())) {
38035         SmallVector<int, 4> RepeatedMask;
38036         if (is128BitLaneRepeatedShuffleMask(MaskEltVT, Mask, RepeatedMask)) {
38037           // Narrow the repeated mask to create 32-bit element permutes.
38038           SmallVector<int, 4> WordMask = RepeatedMask;
38039           if (MaskScalarSizeInBits == 64)
38040             narrowShuffleMaskElts(2, RepeatedMask, WordMask);
38041 
38042           Shuffle = (AllowIntDomain ? X86ISD::PSHUFD : X86ISD::VPERMILPI);
38043           ShuffleVT = (AllowIntDomain ? MVT::i32 : MVT::f32);
38044           ShuffleVT = MVT::getVectorVT(ShuffleVT, InputSizeInBits / 32);
38045           PermuteImm = getV4X86ShuffleImm(WordMask);
38046           return true;
38047         }
38048       }
38049 
38050       // Handle PSHUFLW/PSHUFHW vXi16 repeated patterns.
38051       if (!ContainsZeros && AllowIntDomain && MaskScalarSizeInBits == 16 &&
38052           ((MaskVT.is128BitVector() && Subtarget.hasSSE2()) ||
38053            (MaskVT.is256BitVector() && Subtarget.hasAVX2()) ||
38054            (MaskVT.is512BitVector() && Subtarget.hasBWI()))) {
38055         SmallVector<int, 4> RepeatedMask;
38056         if (is128BitLaneRepeatedShuffleMask(MaskEltVT, Mask, RepeatedMask)) {
38057           ArrayRef<int> LoMask(RepeatedMask.data() + 0, 4);
38058           ArrayRef<int> HiMask(RepeatedMask.data() + 4, 4);
38059 
38060           // PSHUFLW: permute lower 4 elements only.
38061           if (isUndefOrInRange(LoMask, 0, 4) &&
38062               isSequentialOrUndefInRange(HiMask, 0, 4, 4)) {
38063             Shuffle = X86ISD::PSHUFLW;
38064             ShuffleVT = MVT::getVectorVT(MVT::i16, InputSizeInBits / 16);
38065             PermuteImm = getV4X86ShuffleImm(LoMask);
38066             return true;
38067           }
38068 
38069           // PSHUFHW: permute upper 4 elements only.
38070           if (isUndefOrInRange(HiMask, 4, 8) &&
38071               isSequentialOrUndefInRange(LoMask, 0, 4, 0)) {
38072             // Offset the HiMask so that we can create the shuffle immediate.
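                  // For example (illustrative), HiMask <5,4,7,6> becomes
                  // OffsetHiMask <1,0,3,2>, which encodes as imm 0b10110001.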
38073             int OffsetHiMask[4];
38074             for (int i = 0; i != 4; ++i)
38075               OffsetHiMask[i] = (HiMask[i] < 0 ? HiMask[i] : HiMask[i] - 4);
38076 
38077             Shuffle = X86ISD::PSHUFHW;
38078             ShuffleVT = MVT::getVectorVT(MVT::i16, InputSizeInBits / 16);
38079             PermuteImm = getV4X86ShuffleImm(OffsetHiMask);
38080             return true;
38081           }
38082         }
38083       }
38084     } else {
38085       // Attempt to match against bit rotates.
38086       if (!ContainsZeros && AllowIntDomain && MaskScalarSizeInBits < 64 &&
38087           ((MaskVT.is128BitVector() && Subtarget.hasXOP()) ||
38088            Subtarget.hasAVX512())) {
38089         int RotateAmt = matchShuffleAsBitRotate(ShuffleVT, MaskScalarSizeInBits,
38090                                                 Subtarget, Mask);
38091         if (0 < RotateAmt) {
38092           Shuffle = X86ISD::VROTLI;
38093           PermuteImm = (unsigned)RotateAmt;
38094           return true;
38095         }
38096       }
38097     }
38098     // Attempt to match against byte/bit shifts.
38099     if (AllowIntDomain &&
38100         ((MaskVT.is128BitVector() && Subtarget.hasSSE2()) ||
38101          (MaskVT.is256BitVector() && Subtarget.hasAVX2()) ||
38102          (MaskVT.is512BitVector() && Subtarget.hasAVX512()))) {
38103       int ShiftAmt =
38104           matchShuffleAsShift(ShuffleVT, Shuffle, MaskScalarSizeInBits, Mask, 0,
38105                               Zeroable, Subtarget);
38106       if (0 < ShiftAmt && (!ShuffleVT.is512BitVector() || Subtarget.hasBWI() ||
38107                            32 <= ShuffleVT.getScalarSizeInBits())) {
38108         // Byte shifts can be slower so only match them on second attempt.
38109         if (Order == 0 &&
38110             (Shuffle == X86ISD::VSHLDQ || Shuffle == X86ISD::VSRLDQ))
38111           continue;
38112 
38113         PermuteImm = (unsigned)ShiftAmt;
38114         return true;
38115       }
38116 
38117     }
38118   }
38119 
38120   return false;
38121 }
38122 
38123 // Attempt to match a combined unary shuffle mask against supported binary
38124 // shuffle instructions.
38125 // TODO: Investigate sharing more of this with shuffle lowering.
38126 static bool matchBinaryShuffle(MVT MaskVT, ArrayRef<int> Mask,
38127                                bool AllowFloatDomain, bool AllowIntDomain,
38128                                SDValue &V1, SDValue &V2, const SDLoc &DL,
38129                                SelectionDAG &DAG, const X86Subtarget &Subtarget,
38130                                unsigned &Shuffle, MVT &SrcVT, MVT &DstVT,
38131                                bool IsUnary) {
38132   unsigned NumMaskElts = Mask.size();
38133   unsigned EltSizeInBits = MaskVT.getScalarSizeInBits();
38134   unsigned SizeInBits = MaskVT.getSizeInBits();
38135 
38136   if (MaskVT.is128BitVector()) {
38137     if (isTargetShuffleEquivalent(MaskVT, Mask, {0, 0}, DAG) &&
38138         AllowFloatDomain) {
38139       V2 = V1;
38140       V1 = (SM_SentinelUndef == Mask[0] ? DAG.getUNDEF(MVT::v4f32) : V1);
38141       Shuffle = Subtarget.hasSSE2() ? X86ISD::UNPCKL : X86ISD::MOVLHPS;
38142       SrcVT = DstVT = Subtarget.hasSSE2() ? MVT::v2f64 : MVT::v4f32;
38143       return true;
38144     }
38145     if (isTargetShuffleEquivalent(MaskVT, Mask, {1, 1}, DAG) &&
38146         AllowFloatDomain) {
38147       V2 = V1;
38148       Shuffle = Subtarget.hasSSE2() ? X86ISD::UNPCKH : X86ISD::MOVHLPS;
38149       SrcVT = DstVT = Subtarget.hasSSE2() ? MVT::v2f64 : MVT::v4f32;
38150       return true;
38151     }
38152     if (isTargetShuffleEquivalent(MaskVT, Mask, {0, 3}, DAG) &&
38153         Subtarget.hasSSE2() && (AllowFloatDomain || !Subtarget.hasSSE41())) {
38154       std::swap(V1, V2);
38155       Shuffle = X86ISD::MOVSD;
38156       SrcVT = DstVT = MVT::v2f64;
38157       return true;
38158     }
38159     if (isTargetShuffleEquivalent(MaskVT, Mask, {4, 1, 2, 3}, DAG) &&
38160         (AllowFloatDomain || !Subtarget.hasSSE41())) {
38161       Shuffle = X86ISD::MOVSS;
38162       SrcVT = DstVT = MVT::v4f32;
38163       return true;
38164     }
38165     if (isTargetShuffleEquivalent(MaskVT, Mask, {8, 1, 2, 3, 4, 5, 6, 7},
38166                                   DAG) &&
38167         Subtarget.hasFP16()) {
38168       Shuffle = X86ISD::MOVSH;
38169       SrcVT = DstVT = MVT::v8f16;
38170       return true;
38171     }
38172   }
38173 
38174   // Attempt to match against either a unary or binary PACKSS/PACKUS shuffle.
38175   if (((MaskVT == MVT::v8i16 || MaskVT == MVT::v16i8) && Subtarget.hasSSE2()) ||
38176       ((MaskVT == MVT::v16i16 || MaskVT == MVT::v32i8) && Subtarget.hasInt256()) ||
38177       ((MaskVT == MVT::v32i16 || MaskVT == MVT::v64i8) && Subtarget.hasBWI())) {
38178     if (matchShuffleWithPACK(MaskVT, SrcVT, V1, V2, Shuffle, Mask, DAG,
38179                              Subtarget)) {
38180       DstVT = MaskVT;
38181       return true;
38182     }
38183   }
38184   // TODO: Can we handle this inside matchShuffleWithPACK?
38185   if (MaskVT == MVT::v4i32 && Subtarget.hasSSE2() &&
38186       isTargetShuffleEquivalent(MaskVT, Mask, {0, 2, 4, 6}, DAG) &&
38187       V1.getScalarValueSizeInBits() == 64 &&
38188       V2.getScalarValueSizeInBits() == 64) {
38189     // Use (SSE41) PACKUSDW if the leading zero bits reach the lowest 16 bits.
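          // 48 leading zero bits on a 64-bit element mean the value already
          // fits in 16 bits, so the unsigned saturation in PACKUSDW is a no-op.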
38190     unsigned MinLZV1 = DAG.computeKnownBits(V1).countMinLeadingZeros();
38191     unsigned MinLZV2 = DAG.computeKnownBits(V2).countMinLeadingZeros();
38192     if (Subtarget.hasSSE41() && MinLZV1 >= 48 && MinLZV2 >= 48) {
38193       SrcVT = MVT::v4i32;
38194       DstVT = MVT::v8i16;
38195       Shuffle = X86ISD::PACKUS;
38196       return true;
38197     }
38198     // Use PACKUSWB if the leading zero bits reach the lowest 8 bits.
38199     if (MinLZV1 >= 56 && MinLZV2 >= 56) {
38200       SrcVT = MVT::v8i16;
38201       DstVT = MVT::v16i8;
38202       Shuffle = X86ISD::PACKUS;
38203       return true;
38204     }
38205     // Use PACKSSDW if the sign bits extend down to the lowest 16 bits.
38206     if (DAG.ComputeNumSignBits(V1) > 48 && DAG.ComputeNumSignBits(V2) > 48) {
38207       SrcVT = MVT::v4i32;
38208       DstVT = MVT::v8i16;
38209       Shuffle = X86ISD::PACKSS;
38210       return true;
38211     }
38212   }
38213 
38214   // Attempt to match against either a unary or binary UNPCKL/UNPCKH shuffle.
38215   if ((MaskVT == MVT::v4f32 && Subtarget.hasSSE1()) ||
38216       (MaskVT.is128BitVector() && Subtarget.hasSSE2()) ||
38217       (MaskVT.is256BitVector() && 32 <= EltSizeInBits && Subtarget.hasAVX()) ||
38218       (MaskVT.is256BitVector() && Subtarget.hasAVX2()) ||
38219       (MaskVT.is512BitVector() && Subtarget.hasAVX512() &&
38220        (32 <= EltSizeInBits || Subtarget.hasBWI()))) {
38221     if (matchShuffleWithUNPCK(MaskVT, V1, V2, Shuffle, IsUnary, Mask, DL, DAG,
38222                               Subtarget)) {
38223       SrcVT = DstVT = MaskVT;
38224       if (MaskVT.is256BitVector() && !Subtarget.hasAVX2())
38225         SrcVT = DstVT = (32 == EltSizeInBits ? MVT::v8f32 : MVT::v4f64);
38226       return true;
38227     }
38228   }
38229 
38230   // Attempt to match against an OR if we're performing a blend shuffle and the
38231   // non-blended source element is zero in each case.
38232   // TODO: Handle cases where the V1/V2 sizes don't match SizeInBits.
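        // Illustrative example: a v4i32 mask <0,5,2,7> takes elements 0,2 from
        // V1 and 1,3 from V2; if V1 is known zero in elements 1,3 and V2 in
        // elements 0,2, the blend is just a bitwise OR of the two sources.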
38233   if (SizeInBits == V1.getValueSizeInBits() &&
38234       SizeInBits == V2.getValueSizeInBits() &&
38235       (EltSizeInBits % V1.getScalarValueSizeInBits()) == 0 &&
38236       (EltSizeInBits % V2.getScalarValueSizeInBits()) == 0) {
38237     bool IsBlend = true;
38238     unsigned NumV1Elts = V1.getValueType().getVectorNumElements();
38239     unsigned NumV2Elts = V2.getValueType().getVectorNumElements();
38240     unsigned Scale1 = NumV1Elts / NumMaskElts;
38241     unsigned Scale2 = NumV2Elts / NumMaskElts;
38242     APInt DemandedZeroV1 = APInt::getZero(NumV1Elts);
38243     APInt DemandedZeroV2 = APInt::getZero(NumV2Elts);
38244     for (unsigned i = 0; i != NumMaskElts; ++i) {
38245       int M = Mask[i];
38246       if (M == SM_SentinelUndef)
38247         continue;
38248       if (M == SM_SentinelZero) {
38249         DemandedZeroV1.setBits(i * Scale1, (i + 1) * Scale1);
38250         DemandedZeroV2.setBits(i * Scale2, (i + 1) * Scale2);
38251         continue;
38252       }
38253       if (M == (int)i) {
38254         DemandedZeroV2.setBits(i * Scale2, (i + 1) * Scale2);
38255         continue;
38256       }
38257       if (M == (int)(i + NumMaskElts)) {
38258         DemandedZeroV1.setBits(i * Scale1, (i + 1) * Scale1);
38259         continue;
38260       }
38261       IsBlend = false;
38262       break;
38263     }
38264     if (IsBlend) {
38265       if (DAG.MaskedVectorIsZero(V1, DemandedZeroV1) &&
38266           DAG.MaskedVectorIsZero(V2, DemandedZeroV2)) {
38267         Shuffle = ISD::OR;
38268         SrcVT = DstVT = MaskVT.changeTypeToInteger();
38269         return true;
38270       }
38271       if (NumV1Elts == NumV2Elts && NumV1Elts == NumMaskElts) {
38272         // FIXME: handle mismatched sizes?
38273         // TODO: investigate if `ISD::OR` handling in
38274         // `TargetLowering::SimplifyDemandedVectorElts` can be improved instead.
38275         auto computeKnownBitsElementWise = [&DAG](SDValue V) {
38276           unsigned NumElts = V.getValueType().getVectorNumElements();
38277           KnownBits Known(NumElts);
38278           for (unsigned EltIdx = 0; EltIdx != NumElts; ++EltIdx) {
38279             APInt Mask = APInt::getOneBitSet(NumElts, EltIdx);
38280             KnownBits PeepholeKnown = DAG.computeKnownBits(V, Mask);
38281             if (PeepholeKnown.isZero())
38282               Known.Zero.setBit(EltIdx);
38283             if (PeepholeKnown.isAllOnes())
38284               Known.One.setBit(EltIdx);
38285           }
38286           return Known;
38287         };
38288 
38289         KnownBits V1Known = computeKnownBitsElementWise(V1);
38290         KnownBits V2Known = computeKnownBitsElementWise(V2);
38291 
38292         for (unsigned i = 0; i != NumMaskElts && IsBlend; ++i) {
38293           int M = Mask[i];
38294           if (M == SM_SentinelUndef)
38295             continue;
38296           if (M == SM_SentinelZero) {
38297             IsBlend &= V1Known.Zero[i] && V2Known.Zero[i];
38298             continue;
38299           }
38300           if (M == (int)i) {
38301             IsBlend &= V2Known.Zero[i] || V1Known.One[i];
38302             continue;
38303           }
38304           if (M == (int)(i + NumMaskElts)) {
38305             IsBlend &= V1Known.Zero[i] || V2Known.One[i];
38306             continue;
38307           }
38308           llvm_unreachable("will not get here.");
38309         }
38310         if (IsBlend) {
38311           Shuffle = ISD::OR;
38312           SrcVT = DstVT = MaskVT.changeTypeToInteger();
38313           return true;
38314         }
38315       }
38316     }
38317   }
38318 
38319   return false;
38320 }
38321 
38322 static bool matchBinaryPermuteShuffle(
38323     MVT MaskVT, ArrayRef<int> Mask, const APInt &Zeroable,
38324     bool AllowFloatDomain, bool AllowIntDomain, SDValue &V1, SDValue &V2,
38325     const SDLoc &DL, SelectionDAG &DAG, const X86Subtarget &Subtarget,
38326     unsigned &Shuffle, MVT &ShuffleVT, unsigned &PermuteImm) {
38327   unsigned NumMaskElts = Mask.size();
38328   unsigned EltSizeInBits = MaskVT.getScalarSizeInBits();
38329 
38330   // Attempt to match against VALIGND/VALIGNQ rotate.
38331   if (AllowIntDomain && (EltSizeInBits == 64 || EltSizeInBits == 32) &&
38332       ((MaskVT.is128BitVector() && Subtarget.hasVLX()) ||
38333        (MaskVT.is256BitVector() && Subtarget.hasVLX()) ||
38334        (MaskVT.is512BitVector() && Subtarget.hasAVX512()))) {
38335     if (!isAnyZero(Mask)) {
38336       int Rotation = matchShuffleAsElementRotate(V1, V2, Mask);
38337       if (0 < Rotation) {
38338         Shuffle = X86ISD::VALIGN;
38339         if (EltSizeInBits == 64)
38340           ShuffleVT = MVT::getVectorVT(MVT::i64, MaskVT.getSizeInBits() / 64);
38341         else
38342           ShuffleVT = MVT::getVectorVT(MVT::i32, MaskVT.getSizeInBits() / 32);
38343         PermuteImm = Rotation;
38344         return true;
38345       }
38346     }
38347   }
38348 
38349   // Attempt to match against PALIGNR byte rotate.
38350   if (AllowIntDomain && ((MaskVT.is128BitVector() && Subtarget.hasSSSE3()) ||
38351                          (MaskVT.is256BitVector() && Subtarget.hasAVX2()) ||
38352                          (MaskVT.is512BitVector() && Subtarget.hasBWI()))) {
38353     int ByteRotation = matchShuffleAsByteRotate(MaskVT, V1, V2, Mask);
38354     if (0 < ByteRotation) {
38355       Shuffle = X86ISD::PALIGNR;
38356       ShuffleVT = MVT::getVectorVT(MVT::i8, MaskVT.getSizeInBits() / 8);
38357       PermuteImm = ByteRotation;
38358       return true;
38359     }
38360   }
38361 
38362   // Attempt to combine to X86ISD::BLENDI.
38363   if ((NumMaskElts <= 8 && ((Subtarget.hasSSE41() && MaskVT.is128BitVector()) ||
38364                             (Subtarget.hasAVX() && MaskVT.is256BitVector()))) ||
38365       (MaskVT == MVT::v16i16 && Subtarget.hasAVX2())) {
38366     uint64_t BlendMask = 0;
38367     bool ForceV1Zero = false, ForceV2Zero = false;
38368     SmallVector<int, 8> TargetMask(Mask);
38369     if (matchShuffleAsBlend(MaskVT, V1, V2, TargetMask, Zeroable, ForceV1Zero,
38370                             ForceV2Zero, BlendMask)) {
38371       if (MaskVT == MVT::v16i16) {
38372         // We can only use v16i16 PBLENDW if the lanes are repeated.
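              // PBLENDW only has an 8-bit immediate that is applied to each
              // 128-bit lane, so both lanes must want the same per-word blend.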
38373         SmallVector<int, 8> RepeatedMask;
38374         if (isRepeatedTargetShuffleMask(128, MaskVT, TargetMask,
38375                                         RepeatedMask)) {
38376           assert(RepeatedMask.size() == 8 &&
38377                  "Repeated mask size doesn't match!");
38378           PermuteImm = 0;
38379           for (int i = 0; i < 8; ++i)
38380             if (RepeatedMask[i] >= 8)
38381               PermuteImm |= 1 << i;
38382           V1 = ForceV1Zero ? getZeroVector(MaskVT, Subtarget, DAG, DL) : V1;
38383           V2 = ForceV2Zero ? getZeroVector(MaskVT, Subtarget, DAG, DL) : V2;
38384           Shuffle = X86ISD::BLENDI;
38385           ShuffleVT = MaskVT;
38386           return true;
38387         }
38388       } else {
38389         V1 = ForceV1Zero ? getZeroVector(MaskVT, Subtarget, DAG, DL) : V1;
38390         V2 = ForceV2Zero ? getZeroVector(MaskVT, Subtarget, DAG, DL) : V2;
38391         PermuteImm = (unsigned)BlendMask;
38392         Shuffle = X86ISD::BLENDI;
38393         ShuffleVT = MaskVT;
38394         return true;
38395       }
38396     }
38397   }
38398 
38399   // Attempt to combine to INSERTPS, but only if it has elements that need to
38400   // be set to zero.
38401   if (AllowFloatDomain && EltSizeInBits == 32 && Subtarget.hasSSE41() &&
38402       MaskVT.is128BitVector() && isAnyZero(Mask) &&
38403       matchShuffleAsInsertPS(V1, V2, PermuteImm, Zeroable, Mask, DAG)) {
38404     Shuffle = X86ISD::INSERTPS;
38405     ShuffleVT = MVT::v4f32;
38406     return true;
38407   }
38408 
38409   // Attempt to combine to SHUFPD.
38410   if (AllowFloatDomain && EltSizeInBits == 64 &&
38411       ((MaskVT.is128BitVector() && Subtarget.hasSSE2()) ||
38412        (MaskVT.is256BitVector() && Subtarget.hasAVX()) ||
38413        (MaskVT.is512BitVector() && Subtarget.hasAVX512()))) {
38414     bool ForceV1Zero = false, ForceV2Zero = false;
38415     if (matchShuffleWithSHUFPD(MaskVT, V1, V2, ForceV1Zero, ForceV2Zero,
38416                                PermuteImm, Mask, Zeroable)) {
38417       V1 = ForceV1Zero ? getZeroVector(MaskVT, Subtarget, DAG, DL) : V1;
38418       V2 = ForceV2Zero ? getZeroVector(MaskVT, Subtarget, DAG, DL) : V2;
38419       Shuffle = X86ISD::SHUFP;
38420       ShuffleVT = MVT::getVectorVT(MVT::f64, MaskVT.getSizeInBits() / 64);
38421       return true;
38422     }
38423   }
38424 
38425   // Attempt to combine to SHUFPS.
38426   if (AllowFloatDomain && EltSizeInBits == 32 &&
38427       ((MaskVT.is128BitVector() && Subtarget.hasSSE1()) ||
38428        (MaskVT.is256BitVector() && Subtarget.hasAVX()) ||
38429        (MaskVT.is512BitVector() && Subtarget.hasAVX512()))) {
38430     SmallVector<int, 4> RepeatedMask;
38431     if (isRepeatedTargetShuffleMask(128, MaskVT, Mask, RepeatedMask)) {
38432       // Match each half of the repeated mask to determine whether it just
38433       // references one of the vectors, is zeroable, or is entirely undef.
38434       auto MatchHalf = [&](unsigned Offset, int &S0, int &S1) {
38435         int M0 = RepeatedMask[Offset];
38436         int M1 = RepeatedMask[Offset + 1];
38437 
38438         if (isUndefInRange(RepeatedMask, Offset, 2)) {
38439           return DAG.getUNDEF(MaskVT);
38440         } else if (isUndefOrZeroInRange(RepeatedMask, Offset, 2)) {
38441           S0 = (SM_SentinelUndef == M0 ? -1 : 0);
38442           S1 = (SM_SentinelUndef == M1 ? -1 : 1);
38443           return getZeroVector(MaskVT, Subtarget, DAG, DL);
38444         } else if (isUndefOrInRange(M0, 0, 4) && isUndefOrInRange(M1, 0, 4)) {
38445           S0 = (SM_SentinelUndef == M0 ? -1 : M0 & 3);
38446           S1 = (SM_SentinelUndef == M1 ? -1 : M1 & 3);
38447           return V1;
38448         } else if (isUndefOrInRange(M0, 4, 8) && isUndefOrInRange(M1, 4, 8)) {
38449           S0 = (SM_SentinelUndef == M0 ? -1 : M0 & 3);
38450           S1 = (SM_SentinelUndef == M1 ? -1 : M1 & 3);
38451           return V2;
38452         }
38453 
38454         return SDValue();
38455       };
38456 
38457       int ShufMask[4] = {-1, -1, -1, -1};
38458       SDValue Lo = MatchHalf(0, ShufMask[0], ShufMask[1]);
38459       SDValue Hi = MatchHalf(2, ShufMask[2], ShufMask[3]);
38460 
38461       if (Lo && Hi) {
38462         V1 = Lo;
38463         V2 = Hi;
38464         Shuffle = X86ISD::SHUFP;
38465         ShuffleVT = MVT::getVectorVT(MVT::f32, MaskVT.getSizeInBits() / 32);
38466         PermuteImm = getV4X86ShuffleImm(ShufMask);
38467         return true;
38468       }
38469     }
38470   }
38471 
38472   // Attempt to combine to INSERTPS more generally if X86ISD::SHUFP failed.
38473   if (AllowFloatDomain && EltSizeInBits == 32 && Subtarget.hasSSE41() &&
38474       MaskVT.is128BitVector() &&
38475       matchShuffleAsInsertPS(V1, V2, PermuteImm, Zeroable, Mask, DAG)) {
38476     Shuffle = X86ISD::INSERTPS;
38477     ShuffleVT = MVT::v4f32;
38478     return true;
38479   }
38480 
38481   return false;
38482 }
38483 
38484 static SDValue combineX86ShuffleChainWithExtract(
38485     ArrayRef<SDValue> Inputs, SDValue Root, ArrayRef<int> BaseMask, int Depth,
38486     bool HasVariableMask, bool AllowVariableCrossLaneMask,
38487     bool AllowVariablePerLaneMask, SelectionDAG &DAG,
38488     const X86Subtarget &Subtarget);
38489 
38490 /// Combine an arbitrary chain of shuffles into a single instruction if
38491 /// possible.
38492 ///
38493 /// This is the leaf of the recursive combine below. When we have found some
38494 /// chain of single-use x86 shuffle instructions and accumulated the combined
38495 /// shuffle mask represented by them, this will try to pattern match that mask
38496 /// into either a single instruction if there is a special purpose instruction
38497 /// for this operation, or into a PSHUFB instruction which is a fully general
38498 /// instruction but should only be used to replace chains over a certain depth.
38499 static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
38500                                       ArrayRef<int> BaseMask, int Depth,
38501                                       bool HasVariableMask,
38502                                       bool AllowVariableCrossLaneMask,
38503                                       bool AllowVariablePerLaneMask,
38504                                       SelectionDAG &DAG,
38505                                       const X86Subtarget &Subtarget) {
38506   assert(!BaseMask.empty() && "Cannot combine an empty shuffle mask!");
38507   assert((Inputs.size() == 1 || Inputs.size() == 2) &&
38508          "Unexpected number of shuffle inputs!");
38509 
38510   SDLoc DL(Root);
38511   MVT RootVT = Root.getSimpleValueType();
38512   unsigned RootSizeInBits = RootVT.getSizeInBits();
38513   unsigned NumRootElts = RootVT.getVectorNumElements();
38514 
38515   // Canonicalize shuffle input op to the requested type.
38516   auto CanonicalizeShuffleInput = [&](MVT VT, SDValue Op) {
38517     if (VT.getSizeInBits() > Op.getValueSizeInBits())
38518       Op = widenSubVector(Op, false, Subtarget, DAG, DL, VT.getSizeInBits());
38519     else if (VT.getSizeInBits() < Op.getValueSizeInBits())
38520       Op = extractSubVector(Op, 0, DAG, DL, VT.getSizeInBits());
38521     return DAG.getBitcast(VT, Op);
38522   };
38523 
38524   // Find the inputs that enter the chain. Note that multiple uses are OK
38525   // here; we're not going to remove the operands we find.
38526   bool UnaryShuffle = (Inputs.size() == 1);
38527   SDValue V1 = peekThroughBitcasts(Inputs[0]);
38528   SDValue V2 = (UnaryShuffle ? DAG.getUNDEF(V1.getValueType())
38529                              : peekThroughBitcasts(Inputs[1]));
38530 
38531   MVT VT1 = V1.getSimpleValueType();
38532   MVT VT2 = V2.getSimpleValueType();
38533   assert((RootSizeInBits % VT1.getSizeInBits()) == 0 &&
38534          (RootSizeInBits % VT2.getSizeInBits()) == 0 && "Vector size mismatch");
38535 
38536   SDValue Res;
38537 
38538   unsigned NumBaseMaskElts = BaseMask.size();
38539   if (NumBaseMaskElts == 1) {
38540     assert(BaseMask[0] == 0 && "Invalid shuffle index found!");
38541     return CanonicalizeShuffleInput(RootVT, V1);
38542   }
38543 
38544   bool OptForSize = DAG.shouldOptForSize();
38545   unsigned BaseMaskEltSizeInBits = RootSizeInBits / NumBaseMaskElts;
38546   bool FloatDomain = VT1.isFloatingPoint() || VT2.isFloatingPoint() ||
38547                      (RootVT.isFloatingPoint() && Depth >= 1) ||
38548                      (RootVT.is256BitVector() && !Subtarget.hasAVX2());
38549 
38550   // Don't combine if we are an AVX512/EVEX target and the mask element size
38551   // is different from the root element size - this would prevent writemasks
38552   // from being reused.
38553   bool IsMaskedShuffle = false;
38554   if (RootSizeInBits == 512 || (Subtarget.hasVLX() && RootSizeInBits >= 128)) {
38555     if (Root.hasOneUse() && Root->use_begin()->getOpcode() == ISD::VSELECT &&
38556         Root->use_begin()->getOperand(0).getScalarValueSizeInBits() == 1) {
38557       IsMaskedShuffle = true;
38558     }
38559   }
38560 
38561   // If we are shuffling a splat (and not introducing zeros) then we can just
38562   // use it directly. This also works for smaller elements, since they already
38563   // repeat across each mask element.
38564   if (UnaryShuffle && !isAnyZero(BaseMask) &&
38565       V1.getValueSizeInBits() >= RootSizeInBits &&
38566       (BaseMaskEltSizeInBits % V1.getScalarValueSizeInBits()) == 0 &&
38567       DAG.isSplatValue(V1, /*AllowUndefs*/ false)) {
38568     return CanonicalizeShuffleInput(RootVT, V1);
38569   }
38570 
38571   SmallVector<int, 64> Mask(BaseMask);
38572 
38573   // See if the shuffle is a hidden identity shuffle - repeated args in HOPs
38574   // etc. can be simplified.
38575   if (VT1 == VT2 && VT1.getSizeInBits() == RootSizeInBits && VT1.isVector()) {
38576     SmallVector<int> ScaledMask, IdentityMask;
38577     unsigned NumElts = VT1.getVectorNumElements();
38578     if (Mask.size() <= NumElts &&
38579         scaleShuffleElements(Mask, NumElts, ScaledMask)) {
38580       for (unsigned i = 0; i != NumElts; ++i)
38581         IdentityMask.push_back(i);
38582       if (isTargetShuffleEquivalent(RootVT, ScaledMask, IdentityMask, DAG, V1,
38583                                     V2))
38584         return CanonicalizeShuffleInput(RootVT, V1);
38585     }
38586   }
38587 
38588   // Handle 128/256-bit lane shuffles of 512-bit vectors.
38589   if (RootVT.is512BitVector() &&
38590       (NumBaseMaskElts == 2 || NumBaseMaskElts == 4)) {
38591     // If the upper subvectors are zeroable, then an extract+insert is more
38592     // optimal than using X86ISD::SHUF128. The insertion is free, even if it has
38593     // to zero the upper subvectors.
38594     if (isUndefOrZeroInRange(Mask, 1, NumBaseMaskElts - 1)) {
38595       if (Depth == 0 && Root.getOpcode() == ISD::INSERT_SUBVECTOR)
38596         return SDValue(); // Nothing to do!
38597       assert(isInRange(Mask[0], 0, NumBaseMaskElts) &&
38598              "Unexpected lane shuffle");
38599       Res = CanonicalizeShuffleInput(RootVT, V1);
38600       unsigned SubIdx = Mask[0] * (NumRootElts / NumBaseMaskElts);
38601       bool UseZero = isAnyZero(Mask);
38602       Res = extractSubVector(Res, SubIdx, DAG, DL, BaseMaskEltSizeInBits);
38603       return widenSubVector(Res, UseZero, Subtarget, DAG, DL, RootSizeInBits);
38604     }
38605 
38606     // Narrow shuffle mask to v4x128.
38607     SmallVector<int, 4> ScaledMask;
38608     assert((BaseMaskEltSizeInBits % 128) == 0 && "Illegal mask size");
38609     narrowShuffleMaskElts(BaseMaskEltSizeInBits / 128, Mask, ScaledMask);
38610 
38611     // Try to lower to vshuf64x2/vshuf32x4.
38612     auto MatchSHUF128 = [&](MVT ShuffleVT, const SDLoc &DL,
38613                             ArrayRef<int> ScaledMask, SDValue V1, SDValue V2,
38614                             SelectionDAG &DAG) {
38615       int PermMask[4] = {-1, -1, -1, -1};
38616       // Ensure elements came from the same Op.
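            // SHUF128 takes its two low result lanes from the first operand and
            // its two high result lanes from the second, so lanes 0,1 must map
            // to Ops[0] and lanes 2,3 to Ops[1]; the imm8 holds ScaledMask[i] % 4.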
38617       SDValue Ops[2] = {DAG.getUNDEF(ShuffleVT), DAG.getUNDEF(ShuffleVT)};
38618       for (int i = 0; i < 4; ++i) {
38619         assert(ScaledMask[i] >= -1 && "Illegal shuffle sentinel value");
38620         if (ScaledMask[i] < 0)
38621           continue;
38622 
38623         SDValue Op = ScaledMask[i] >= 4 ? V2 : V1;
38624         unsigned OpIndex = i / 2;
38625         if (Ops[OpIndex].isUndef())
38626           Ops[OpIndex] = Op;
38627         else if (Ops[OpIndex] != Op)
38628           return SDValue();
38629 
38630         PermMask[i] = ScaledMask[i] % 4;
38631       }
38632 
38633       return DAG.getNode(X86ISD::SHUF128, DL, ShuffleVT,
38634                          CanonicalizeShuffleInput(ShuffleVT, Ops[0]),
38635                          CanonicalizeShuffleInput(ShuffleVT, Ops[1]),
38636                          getV4X86ShuffleImm8ForMask(PermMask, DL, DAG));
38637     };
38638 
38639     // FIXME: Is there a better way to do this? is256BitLaneRepeatedShuffleMask
38640     // doesn't work because our mask is for 128 bits and we don't have an MVT
38641     // to match that.
38642     bool PreferPERMQ = UnaryShuffle && isUndefOrInRange(ScaledMask[0], 0, 2) &&
38643                        isUndefOrInRange(ScaledMask[1], 0, 2) &&
38644                        isUndefOrInRange(ScaledMask[2], 2, 4) &&
38645                        isUndefOrInRange(ScaledMask[3], 2, 4) &&
38646                        (ScaledMask[0] < 0 || ScaledMask[2] < 0 ||
38647                         ScaledMask[0] == (ScaledMask[2] % 2)) &&
38648                        (ScaledMask[1] < 0 || ScaledMask[3] < 0 ||
38649                         ScaledMask[1] == (ScaledMask[3] % 2));
38650 
38651     if (!isAnyZero(ScaledMask) && !PreferPERMQ) {
38652       if (Depth == 0 && Root.getOpcode() == X86ISD::SHUF128)
38653         return SDValue(); // Nothing to do!
38654       MVT ShuffleVT = (FloatDomain ? MVT::v8f64 : MVT::v8i64);
38655       if (SDValue V = MatchSHUF128(ShuffleVT, DL, ScaledMask, V1, V2, DAG))
38656         return DAG.getBitcast(RootVT, V);
38657     }
38658   }
38659 
38660   // Handle 128-bit lane shuffles of 256-bit vectors.
38661   if (RootVT.is256BitVector() && NumBaseMaskElts == 2) {
38662     // If the upper half is zeroable, then an extract+insert is more optimal
38663     // than using X86ISD::VPERM2X128. The insertion is free, even if it has to
38664     // zero the upper half.
38665     if (isUndefOrZero(Mask[1])) {
38666       if (Depth == 0 && Root.getOpcode() == ISD::INSERT_SUBVECTOR)
38667         return SDValue(); // Nothing to do!
38668       assert(isInRange(Mask[0], 0, 2) && "Unexpected lane shuffle");
38669       Res = CanonicalizeShuffleInput(RootVT, V1);
38670       Res = extract128BitVector(Res, Mask[0] * (NumRootElts / 2), DAG, DL);
38671       return widenSubVector(Res, Mask[1] == SM_SentinelZero, Subtarget, DAG, DL,
38672                             256);
38673     }
38674 
38675     // If we're inserting the low subvector, an insert-subvector 'concat'
38676     // pattern is quicker than VPERM2X128.
38677     // TODO: Add AVX2 support instead of VPERMQ/VPERMPD.
38678     if (BaseMask[0] == 0 && (BaseMask[1] == 0 || BaseMask[1] == 2) &&
38679         !Subtarget.hasAVX2()) {
38680       if (Depth == 0 && Root.getOpcode() == ISD::INSERT_SUBVECTOR)
38681         return SDValue(); // Nothing to do!
38682       SDValue Lo = CanonicalizeShuffleInput(RootVT, V1);
38683       SDValue Hi = CanonicalizeShuffleInput(RootVT, BaseMask[1] == 0 ? V1 : V2);
38684       Hi = extractSubVector(Hi, 0, DAG, DL, 128);
38685       return insertSubVector(Lo, Hi, NumRootElts / 2, DAG, DL, 128);
38686     }
38687 
38688     if (Depth == 0 && Root.getOpcode() == X86ISD::VPERM2X128)
38689       return SDValue(); // Nothing to do!
38690 
38691     // If we have AVX2, prefer to use VPERMQ/VPERMPD for unary shuffles unless
38692     // we need to use the zeroing feature.
38693     // Prefer blends for sequential shuffles unless we are optimizing for size.
38694     if (UnaryShuffle &&
38695         !(Subtarget.hasAVX2() && isUndefOrInRange(Mask, 0, 2)) &&
38696         (OptForSize || !isSequentialOrUndefOrZeroInRange(Mask, 0, 2, 0))) {
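            // VPERM2X128 imm8: bits[1:0] pick the 128-bit source lane for the
            // low half and bits[5:4] for the high half; setting bit 3 or bit 7
            // zeroes that half instead (used here for undef/zero mask elements).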
38697       unsigned PermMask = 0;
38698       PermMask |= ((Mask[0] < 0 ? 0x8 : (Mask[0] & 1)) << 0);
38699       PermMask |= ((Mask[1] < 0 ? 0x8 : (Mask[1] & 1)) << 4);
38700       return DAG.getNode(
38701           X86ISD::VPERM2X128, DL, RootVT, CanonicalizeShuffleInput(RootVT, V1),
38702           DAG.getUNDEF(RootVT), DAG.getTargetConstant(PermMask, DL, MVT::i8));
38703     }
38704 
38705     if (Depth == 0 && Root.getOpcode() == X86ISD::SHUF128)
38706       return SDValue(); // Nothing to do!
38707 
38708     // TODO - handle AVX512VL cases with X86ISD::SHUF128.
38709     if (!UnaryShuffle && !IsMaskedShuffle) {
38710       assert(llvm::all_of(Mask, [](int M) { return 0 <= M && M < 4; }) &&
38711              "Unexpected shuffle sentinel value");
38712       // Prefer blends to X86ISD::VPERM2X128.
38713       if (!((Mask[0] == 0 && Mask[1] == 3) || (Mask[0] == 2 && Mask[1] == 1))) {
38714         unsigned PermMask = 0;
38715         PermMask |= ((Mask[0] & 3) << 0);
38716         PermMask |= ((Mask[1] & 3) << 4);
38717         SDValue LHS = isInRange(Mask[0], 0, 2) ? V1 : V2;
38718         SDValue RHS = isInRange(Mask[1], 0, 2) ? V1 : V2;
38719         return DAG.getNode(X86ISD::VPERM2X128, DL, RootVT,
38720                           CanonicalizeShuffleInput(RootVT, LHS),
38721                           CanonicalizeShuffleInput(RootVT, RHS),
38722                           DAG.getTargetConstant(PermMask, DL, MVT::i8));
38723       }
38724     }
38725   }
38726 
38727   // For masks that have been widened to 128-bit elements or more,
38728   // narrow back down to 64-bit elements.
38729   if (BaseMaskEltSizeInBits > 64) {
38730     assert((BaseMaskEltSizeInBits % 64) == 0 && "Illegal mask size");
38731     int MaskScale = BaseMaskEltSizeInBits / 64;
38732     SmallVector<int, 64> ScaledMask;
38733     narrowShuffleMaskElts(MaskScale, Mask, ScaledMask);
38734     Mask = std::move(ScaledMask);
38735   }
38736 
38737   // For masked shuffles, we're trying to match the root width for better
38738   // writemask folding, so attempt to scale the mask accordingly.
38739   // TODO - variable shuffles might need this to be widened again.
38740   if (IsMaskedShuffle && NumRootElts > Mask.size()) {
38741     assert((NumRootElts % Mask.size()) == 0 && "Illegal mask size");
38742     int MaskScale = NumRootElts / Mask.size();
38743     SmallVector<int, 64> ScaledMask;
38744     narrowShuffleMaskElts(MaskScale, Mask, ScaledMask);
38745     Mask = std::move(ScaledMask);
38746   }
38747 
38748   unsigned NumMaskElts = Mask.size();
38749   unsigned MaskEltSizeInBits = RootSizeInBits / NumMaskElts;
38750   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
38751 
38752   // Determine the effective mask value type.
38753   FloatDomain &= (32 <= MaskEltSizeInBits);
38754   MVT MaskVT = FloatDomain ? MVT::getFloatingPointVT(MaskEltSizeInBits)
38755                            : MVT::getIntegerVT(MaskEltSizeInBits);
38756   MaskVT = MVT::getVectorVT(MaskVT, NumMaskElts);
38757 
38758   // Only allow legal mask types.
38759   if (!TLI.isTypeLegal(MaskVT))
38760     return SDValue();
38761 
38762   // Attempt to match the mask against known shuffle patterns.
38763   MVT ShuffleSrcVT, ShuffleVT;
38764   unsigned Shuffle, PermuteImm;
38765 
38766   // Which shuffle domains are permitted?
38767   // Permit domain crossing at higher combine depths.
38768   // TODO: Should we indicate which domain is preferred if both are allowed?
38769   bool AllowFloatDomain = FloatDomain || (Depth >= 3);
38770   bool AllowIntDomain = (!FloatDomain || (Depth >= 3)) && Subtarget.hasSSE2() &&
38771                         (!MaskVT.is256BitVector() || Subtarget.hasAVX2());
38772 
38773   // Determine zeroable mask elements.
38774   APInt KnownUndef, KnownZero;
38775   resolveZeroablesFromTargetShuffle(Mask, KnownUndef, KnownZero);
38776   APInt Zeroable = KnownUndef | KnownZero;
38777 
38778   if (UnaryShuffle) {
38779     // Attempt to match against broadcast-from-vector.
38780     // Limit AVX1 to cases where we're loading+broadcasting a scalar element.
38781     if ((Subtarget.hasAVX2() ||
38782          (Subtarget.hasAVX() && 32 <= MaskEltSizeInBits)) &&
38783         (!IsMaskedShuffle || NumRootElts == NumMaskElts)) {
38784       if (isUndefOrEqual(Mask, 0)) {
38785         if (V1.getValueType() == MaskVT &&
38786             V1.getOpcode() == ISD::SCALAR_TO_VECTOR &&
38787             X86::mayFoldLoad(V1.getOperand(0), Subtarget)) {
38788           if (Depth == 0 && Root.getOpcode() == X86ISD::VBROADCAST)
38789             return SDValue(); // Nothing to do!
38790           Res = V1.getOperand(0);
38791           Res = DAG.getNode(X86ISD::VBROADCAST, DL, MaskVT, Res);
38792           return DAG.getBitcast(RootVT, Res);
38793         }
38794         if (Subtarget.hasAVX2()) {
38795           if (Depth == 0 && Root.getOpcode() == X86ISD::VBROADCAST)
38796             return SDValue(); // Nothing to do!
38797           Res = CanonicalizeShuffleInput(MaskVT, V1);
38798           Res = DAG.getNode(X86ISD::VBROADCAST, DL, MaskVT, Res);
38799           return DAG.getBitcast(RootVT, Res);
38800         }
38801       }
38802     }
38803 
38804     if (matchUnaryShuffle(MaskVT, Mask, AllowFloatDomain, AllowIntDomain, V1,
38805                           DAG, Subtarget, Shuffle, ShuffleSrcVT, ShuffleVT) &&
38806         (!IsMaskedShuffle ||
38807          (NumRootElts == ShuffleVT.getVectorNumElements()))) {
38808       if (Depth == 0 && Root.getOpcode() == Shuffle)
38809         return SDValue(); // Nothing to do!
38810       Res = CanonicalizeShuffleInput(ShuffleSrcVT, V1);
38811       Res = DAG.getNode(Shuffle, DL, ShuffleVT, Res);
38812       return DAG.getBitcast(RootVT, Res);
38813     }
38814 
38815     if (matchUnaryPermuteShuffle(MaskVT, Mask, Zeroable, AllowFloatDomain,
38816                                  AllowIntDomain, DAG, Subtarget, Shuffle, ShuffleVT,
38817                                  PermuteImm) &&
38818         (!IsMaskedShuffle ||
38819          (NumRootElts == ShuffleVT.getVectorNumElements()))) {
38820       if (Depth == 0 && Root.getOpcode() == Shuffle)
38821         return SDValue(); // Nothing to do!
38822       Res = CanonicalizeShuffleInput(ShuffleVT, V1);
38823       Res = DAG.getNode(Shuffle, DL, ShuffleVT, Res,
38824                         DAG.getTargetConstant(PermuteImm, DL, MVT::i8));
38825       return DAG.getBitcast(RootVT, Res);
38826     }
38827   }
38828 
38829   // Attempt to combine to INSERTPS, but only if the inserted element has come
38830   // from a scalar.
38831   // TODO: Handle other insertions here as well?
38832   if (!UnaryShuffle && AllowFloatDomain && RootSizeInBits == 128 &&
38833       Subtarget.hasSSE41() &&
38834       !isTargetShuffleEquivalent(MaskVT, Mask, {4, 1, 2, 3}, DAG)) {
38835     if (MaskEltSizeInBits == 32) {
38836       SDValue SrcV1 = V1, SrcV2 = V2;
38837       if (matchShuffleAsInsertPS(SrcV1, SrcV2, PermuteImm, Zeroable, Mask,
38838                                  DAG) &&
38839           SrcV2.getOpcode() == ISD::SCALAR_TO_VECTOR) {
38840         if (Depth == 0 && Root.getOpcode() == X86ISD::INSERTPS)
38841           return SDValue(); // Nothing to do!
38842         Res = DAG.getNode(X86ISD::INSERTPS, DL, MVT::v4f32,
38843                           CanonicalizeShuffleInput(MVT::v4f32, SrcV1),
38844                           CanonicalizeShuffleInput(MVT::v4f32, SrcV2),
38845                           DAG.getTargetConstant(PermuteImm, DL, MVT::i8));
38846         return DAG.getBitcast(RootVT, Res);
38847       }
38848     }
38849     if (MaskEltSizeInBits == 64 &&
38850         isTargetShuffleEquivalent(MaskVT, Mask, {0, 2}, DAG) &&
38851         V2.getOpcode() == ISD::SCALAR_TO_VECTOR &&
38852         V2.getScalarValueSizeInBits() <= 32) {
38853       if (Depth == 0 && Root.getOpcode() == X86ISD::INSERTPS)
38854         return SDValue(); // Nothing to do!
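            // INSERTPS imm8: bits[7:6] select the source element, bits[5:4] the
            // destination element, bits[3:0] zero result elements; SrcIdx is 0
            // here so only the DstIdx field needs to be encoded.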
38855       PermuteImm = (/*DstIdx*/ 2 << 4) | (/*SrcIdx*/ 0 << 0);
38856       Res = DAG.getNode(X86ISD::INSERTPS, DL, MVT::v4f32,
38857                         CanonicalizeShuffleInput(MVT::v4f32, V1),
38858                         CanonicalizeShuffleInput(MVT::v4f32, V2),
38859                         DAG.getTargetConstant(PermuteImm, DL, MVT::i8));
38860       return DAG.getBitcast(RootVT, Res);
38861     }
38862   }
38863 
38864   SDValue NewV1 = V1; // Save operands in case early exit happens.
38865   SDValue NewV2 = V2;
38866   if (matchBinaryShuffle(MaskVT, Mask, AllowFloatDomain, AllowIntDomain, NewV1,
38867                          NewV2, DL, DAG, Subtarget, Shuffle, ShuffleSrcVT,
38868                          ShuffleVT, UnaryShuffle) &&
38869       (!IsMaskedShuffle || (NumRootElts == ShuffleVT.getVectorNumElements()))) {
38870     if (Depth == 0 && Root.getOpcode() == Shuffle)
38871       return SDValue(); // Nothing to do!
38872     NewV1 = CanonicalizeShuffleInput(ShuffleSrcVT, NewV1);
38873     NewV2 = CanonicalizeShuffleInput(ShuffleSrcVT, NewV2);
38874     Res = DAG.getNode(Shuffle, DL, ShuffleVT, NewV1, NewV2);
38875     return DAG.getBitcast(RootVT, Res);
38876   }
38877 
38878   NewV1 = V1; // Save operands in case early exit happens.
38879   NewV2 = V2;
38880   if (matchBinaryPermuteShuffle(MaskVT, Mask, Zeroable, AllowFloatDomain,
38881                                 AllowIntDomain, NewV1, NewV2, DL, DAG,
38882                                 Subtarget, Shuffle, ShuffleVT, PermuteImm) &&
38883       (!IsMaskedShuffle || (NumRootElts == ShuffleVT.getVectorNumElements()))) {
38884     if (Depth == 0 && Root.getOpcode() == Shuffle)
38885       return SDValue(); // Nothing to do!
38886     NewV1 = CanonicalizeShuffleInput(ShuffleVT, NewV1);
38887     NewV2 = CanonicalizeShuffleInput(ShuffleVT, NewV2);
38888     Res = DAG.getNode(Shuffle, DL, ShuffleVT, NewV1, NewV2,
38889                       DAG.getTargetConstant(PermuteImm, DL, MVT::i8));
38890     return DAG.getBitcast(RootVT, Res);
38891   }
38892 
38893   // Typically from here on, we need an integer version of MaskVT.
38894   MVT IntMaskVT = MVT::getIntegerVT(MaskEltSizeInBits);
38895   IntMaskVT = MVT::getVectorVT(IntMaskVT, NumMaskElts);
38896 
38897   // Annoyingly, SSE4A instructions don't map into the above match helpers.
38898   if (Subtarget.hasSSE4A() && AllowIntDomain && RootSizeInBits == 128) {
38899     uint64_t BitLen, BitIdx;
38900     if (matchShuffleAsEXTRQ(IntMaskVT, V1, V2, Mask, BitLen, BitIdx,
38901                             Zeroable)) {
38902       if (Depth == 0 && Root.getOpcode() == X86ISD::EXTRQI)
38903         return SDValue(); // Nothing to do!
38904       V1 = CanonicalizeShuffleInput(IntMaskVT, V1);
38905       Res = DAG.getNode(X86ISD::EXTRQI, DL, IntMaskVT, V1,
38906                         DAG.getTargetConstant(BitLen, DL, MVT::i8),
38907                         DAG.getTargetConstant(BitIdx, DL, MVT::i8));
38908       return DAG.getBitcast(RootVT, Res);
38909     }
38910 
38911     if (matchShuffleAsINSERTQ(IntMaskVT, V1, V2, Mask, BitLen, BitIdx)) {
38912       if (Depth == 0 && Root.getOpcode() == X86ISD::INSERTQI)
38913         return SDValue(); // Nothing to do!
38914       V1 = CanonicalizeShuffleInput(IntMaskVT, V1);
38915       V2 = CanonicalizeShuffleInput(IntMaskVT, V2);
38916       Res = DAG.getNode(X86ISD::INSERTQI, DL, IntMaskVT, V1, V2,
38917                         DAG.getTargetConstant(BitLen, DL, MVT::i8),
38918                         DAG.getTargetConstant(BitIdx, DL, MVT::i8));
38919       return DAG.getBitcast(RootVT, Res);
38920     }
38921   }
38922 
38923   // Match shuffle against TRUNCATE patterns.
38924   if (AllowIntDomain && MaskEltSizeInBits < 64 && Subtarget.hasAVX512()) {
38925     // Match against a VTRUNC instruction, accounting for src/dst sizes.
38926     if (matchShuffleAsVTRUNC(ShuffleSrcVT, ShuffleVT, IntMaskVT, Mask, Zeroable,
38927                              Subtarget)) {
38928       bool IsTRUNCATE = ShuffleVT.getVectorNumElements() ==
38929                         ShuffleSrcVT.getVectorNumElements();
38930       unsigned Opc =
38931           IsTRUNCATE ? (unsigned)ISD::TRUNCATE : (unsigned)X86ISD::VTRUNC;
38932       if (Depth == 0 && Root.getOpcode() == Opc)
38933         return SDValue(); // Nothing to do!
38934       V1 = CanonicalizeShuffleInput(ShuffleSrcVT, V1);
38935       Res = DAG.getNode(Opc, DL, ShuffleVT, V1);
38936       if (ShuffleVT.getSizeInBits() < RootSizeInBits)
38937         Res = widenSubVector(Res, true, Subtarget, DAG, DL, RootSizeInBits);
38938       return DAG.getBitcast(RootVT, Res);
38939     }
38940 
38941     // Do we need a more general binary truncation pattern?
38942     if (RootSizeInBits < 512 &&
38943         ((RootVT.is256BitVector() && Subtarget.useAVX512Regs()) ||
38944          (RootVT.is128BitVector() && Subtarget.hasVLX())) &&
38945         (MaskEltSizeInBits > 8 || Subtarget.hasBWI()) &&
38946         isSequentialOrUndefInRange(Mask, 0, NumMaskElts, 0, 2)) {
38947       // Bail if this was already a truncation or PACK node.
38948       // We sometimes fail to match PACK if we demand known undef elements.
38949       if (Depth == 0 && (Root.getOpcode() == ISD::TRUNCATE ||
38950                          Root.getOpcode() == X86ISD::PACKSS ||
38951                          Root.getOpcode() == X86ISD::PACKUS))
38952         return SDValue(); // Nothing to do!
38953       ShuffleSrcVT = MVT::getIntegerVT(MaskEltSizeInBits * 2);
38954       ShuffleSrcVT = MVT::getVectorVT(ShuffleSrcVT, NumMaskElts / 2);
38955       V1 = CanonicalizeShuffleInput(ShuffleSrcVT, V1);
38956       V2 = CanonicalizeShuffleInput(ShuffleSrcVT, V2);
38957       ShuffleSrcVT = MVT::getIntegerVT(MaskEltSizeInBits * 2);
38958       ShuffleSrcVT = MVT::getVectorVT(ShuffleSrcVT, NumMaskElts);
38959       Res = DAG.getNode(ISD::CONCAT_VECTORS, DL, ShuffleSrcVT, V1, V2);
38960       Res = DAG.getNode(ISD::TRUNCATE, DL, IntMaskVT, Res);
38961       return DAG.getBitcast(RootVT, Res);
38962     }
38963   }
38964 
38965   // Don't try to re-form single instruction chains under any circumstances now
38966   // that we've done encoding canonicalization for them.
38967   if (Depth < 1)
38968     return SDValue();
38969 
38970   // Depth threshold above which we can efficiently use variable mask shuffles.
38971   int VariableCrossLaneShuffleDepth =
38972       Subtarget.hasFastVariableCrossLaneShuffle() ? 1 : 2;
38973   int VariablePerLaneShuffleDepth =
38974       Subtarget.hasFastVariablePerLaneShuffle() ? 1 : 2;
38975   AllowVariableCrossLaneMask &=
38976       (Depth >= VariableCrossLaneShuffleDepth) || HasVariableMask;
38977   AllowVariablePerLaneMask &=
38978       (Depth >= VariablePerLaneShuffleDepth) || HasVariableMask;
38979   // VPERMI2W/VPERMI2B are 3 uops on Skylake and Icelake so we require a
38980   // higher depth before combining them.
38981   bool AllowBWIVPERMV3 =
38982       (Depth >= (VariableCrossLaneShuffleDepth + 2) || HasVariableMask);
38983 
38984   bool MaskContainsZeros = isAnyZero(Mask);
38985 
38986   if (is128BitLaneCrossingShuffleMask(MaskVT, Mask)) {
38987     // If we have a single input lane-crossing shuffle then lower to VPERMV.
38988     if (UnaryShuffle && AllowVariableCrossLaneMask && !MaskContainsZeros) {
38989       if (Subtarget.hasAVX2() &&
38990           (MaskVT == MVT::v8f32 || MaskVT == MVT::v8i32)) {
38991         SDValue VPermMask = getConstVector(Mask, IntMaskVT, DAG, DL, true);
38992         Res = CanonicalizeShuffleInput(MaskVT, V1);
38993         Res = DAG.getNode(X86ISD::VPERMV, DL, MaskVT, VPermMask, Res);
38994         return DAG.getBitcast(RootVT, Res);
38995       }
38996       // AVX512 variants (non-VLX will pad to 512-bit shuffles).
38997       if ((Subtarget.hasAVX512() &&
38998            (MaskVT == MVT::v8f64 || MaskVT == MVT::v8i64 ||
38999             MaskVT == MVT::v16f32 || MaskVT == MVT::v16i32)) ||
39000           (Subtarget.hasBWI() &&
39001            (MaskVT == MVT::v16i16 || MaskVT == MVT::v32i16)) ||
39002           (Subtarget.hasVBMI() &&
39003            (MaskVT == MVT::v32i8 || MaskVT == MVT::v64i8))) {
39004         V1 = CanonicalizeShuffleInput(MaskVT, V1);
39005         V2 = DAG.getUNDEF(MaskVT);
39006         Res = lowerShuffleWithPERMV(DL, MaskVT, Mask, V1, V2, Subtarget, DAG);
39007         return DAG.getBitcast(RootVT, Res);
39008       }
39009     }
39010 
39011     // Lower a unary+zero lane-crossing shuffle as VPERMV3 with a zero
39012     // vector as the second source (non-VLX will pad to 512-bit shuffles).
39013     if (UnaryShuffle && AllowVariableCrossLaneMask &&
39014         ((Subtarget.hasAVX512() &&
39015           (MaskVT == MVT::v8f64 || MaskVT == MVT::v8i64 ||
39016            MaskVT == MVT::v4f64 || MaskVT == MVT::v4i64 ||
39017            MaskVT == MVT::v8f32 || MaskVT == MVT::v8i32 ||
39018            MaskVT == MVT::v16f32 || MaskVT == MVT::v16i32)) ||
39019          (Subtarget.hasBWI() && AllowBWIVPERMV3 &&
39020           (MaskVT == MVT::v16i16 || MaskVT == MVT::v32i16)) ||
39021          (Subtarget.hasVBMI() && AllowBWIVPERMV3 &&
39022           (MaskVT == MVT::v32i8 || MaskVT == MVT::v64i8)))) {
39023       // Adjust shuffle mask - replace SM_SentinelZero with second source index.
39024       for (unsigned i = 0; i != NumMaskElts; ++i)
39025         if (Mask[i] == SM_SentinelZero)
39026           Mask[i] = NumMaskElts + i;
39027       V1 = CanonicalizeShuffleInput(MaskVT, V1);
39028       V2 = getZeroVector(MaskVT, Subtarget, DAG, DL);
39029       Res = lowerShuffleWithPERMV(DL, MaskVT, Mask, V1, V2, Subtarget, DAG);
39030       return DAG.getBitcast(RootVT, Res);
39031     }
39032 
39033     // If that failed and either input is extracted then try to combine as a
39034     // shuffle with the larger type.
39035     if (SDValue WideShuffle = combineX86ShuffleChainWithExtract(
39036             Inputs, Root, BaseMask, Depth, HasVariableMask,
39037             AllowVariableCrossLaneMask, AllowVariablePerLaneMask, DAG,
39038             Subtarget))
39039       return WideShuffle;
39040 
39041     // If we have a dual input lane-crossing shuffle then lower to VPERMV3,
39042     // (non-VLX will pad to 512-bit shuffles).
39043     if (AllowVariableCrossLaneMask && !MaskContainsZeros &&
39044         ((Subtarget.hasAVX512() &&
39045           (MaskVT == MVT::v8f64 || MaskVT == MVT::v8i64 ||
39046            MaskVT == MVT::v4f64 || MaskVT == MVT::v4i64 ||
39047            MaskVT == MVT::v16f32 || MaskVT == MVT::v16i32 ||
39048            MaskVT == MVT::v8f32 || MaskVT == MVT::v8i32)) ||
39049          (Subtarget.hasBWI() && AllowBWIVPERMV3 &&
39050           (MaskVT == MVT::v16i16 || MaskVT == MVT::v32i16)) ||
39051          (Subtarget.hasVBMI() && AllowBWIVPERMV3 &&
39052           (MaskVT == MVT::v32i8 || MaskVT == MVT::v64i8)))) {
39053       V1 = CanonicalizeShuffleInput(MaskVT, V1);
39054       V2 = CanonicalizeShuffleInput(MaskVT, V2);
39055       Res = lowerShuffleWithPERMV(DL, MaskVT, Mask, V1, V2, Subtarget, DAG);
39056       return DAG.getBitcast(RootVT, Res);
39057     }
39058     return SDValue();
39059   }
39060 
39061   // See if we can combine a single input shuffle with zeros to a bit-mask,
39062   // which is much simpler than any shuffle.
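        // Illustrative example: a v4i32 mask <0, SM_SentinelZero, 2,
        // SM_SentinelZero> becomes an AND with the constant <-1, 0, -1, 0>.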
39063   if (UnaryShuffle && MaskContainsZeros && AllowVariablePerLaneMask &&
39064       isSequentialOrUndefOrZeroInRange(Mask, 0, NumMaskElts, 0) &&
39065       TLI.isTypeLegal(MaskVT)) {
39066     APInt Zero = APInt::getZero(MaskEltSizeInBits);
39067     APInt AllOnes = APInt::getAllOnes(MaskEltSizeInBits);
39068     APInt UndefElts(NumMaskElts, 0);
39069     SmallVector<APInt, 64> EltBits(NumMaskElts, Zero);
39070     for (unsigned i = 0; i != NumMaskElts; ++i) {
39071       int M = Mask[i];
39072       if (M == SM_SentinelUndef) {
39073         UndefElts.setBit(i);
39074         continue;
39075       }
39076       if (M == SM_SentinelZero)
39077         continue;
39078       EltBits[i] = AllOnes;
39079     }
39080     SDValue BitMask = getConstVector(EltBits, UndefElts, MaskVT, DAG, DL);
39081     Res = CanonicalizeShuffleInput(MaskVT, V1);
39082     unsigned AndOpcode =
39083         MaskVT.isFloatingPoint() ? unsigned(X86ISD::FAND) : unsigned(ISD::AND);
39084     Res = DAG.getNode(AndOpcode, DL, MaskVT, Res, BitMask);
39085     return DAG.getBitcast(RootVT, Res);
39086   }
39087 
39088   // If we have a single input shuffle with different shuffle patterns in the
39089   // 128-bit lanes, use a variable mask with VPERMILPS.
39090   // TODO: Combine other mask types at higher depths.
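        // The variable VPERMILPS form only reads each index modulo 4 within its
        // own 128-bit lane, hence the M % 4 below.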
39091   if (UnaryShuffle && AllowVariablePerLaneMask && !MaskContainsZeros &&
39092       ((MaskVT == MVT::v8f32 && Subtarget.hasAVX()) ||
39093        (MaskVT == MVT::v16f32 && Subtarget.hasAVX512()))) {
39094     SmallVector<SDValue, 16> VPermIdx;
39095     for (int M : Mask) {
39096       SDValue Idx =
39097           M < 0 ? DAG.getUNDEF(MVT::i32) : DAG.getConstant(M % 4, DL, MVT::i32);
39098       VPermIdx.push_back(Idx);
39099     }
39100     SDValue VPermMask = DAG.getBuildVector(IntMaskVT, DL, VPermIdx);
39101     Res = CanonicalizeShuffleInput(MaskVT, V1);
39102     Res = DAG.getNode(X86ISD::VPERMILPV, DL, MaskVT, Res, VPermMask);
39103     return DAG.getBitcast(RootVT, Res);
39104   }
39105 
39106   // With XOP, binary shuffles of 128/256-bit floating point vectors can combine
39107   // to VPERMIL2PD/VPERMIL2PS.
39108   if (AllowVariablePerLaneMask && Subtarget.hasXOP() &&
39109       (MaskVT == MVT::v2f64 || MaskVT == MVT::v4f64 || MaskVT == MVT::v4f32 ||
39110        MaskVT == MVT::v8f32)) {
39111     // VPERMIL2 Operation.
39112     // Bits[3] - Match Bit.
39113     // Bits[2:1] - (Per Lane) PD Shuffle Mask.
39114     // Bits[2:0] - (Per Lane) PS Shuffle Mask.
39115     unsigned NumLanes = MaskVT.getSizeInBits() / 128;
39116     unsigned NumEltsPerLane = NumMaskElts / NumLanes;
39117     SmallVector<int, 8> VPerm2Idx;
39118     unsigned M2ZImm = 0;
39119     for (int M : Mask) {
39120       if (M == SM_SentinelUndef) {
39121         VPerm2Idx.push_back(-1);
39122         continue;
39123       }
39124       if (M == SM_SentinelZero) {
39125         M2ZImm = 2;
39126         VPerm2Idx.push_back(8);
39127         continue;
39128       }
39129       int Index = (M % NumEltsPerLane) + ((M / NumMaskElts) * NumEltsPerLane);
39130       Index = (MaskVT.getScalarSizeInBits() == 64 ? Index << 1 : Index);
39131       VPerm2Idx.push_back(Index);
39132     }
39133     V1 = CanonicalizeShuffleInput(MaskVT, V1);
39134     V2 = CanonicalizeShuffleInput(MaskVT, V2);
39135     SDValue VPerm2MaskOp = getConstVector(VPerm2Idx, IntMaskVT, DAG, DL, true);
39136     Res = DAG.getNode(X86ISD::VPERMIL2, DL, MaskVT, V1, V2, VPerm2MaskOp,
39137                       DAG.getTargetConstant(M2ZImm, DL, MVT::i8));
39138     return DAG.getBitcast(RootVT, Res);
39139   }
39140 
39141   // If we have 3 or more shuffle instructions or a chain involving a variable
39142   // mask, we can replace them with a single PSHUFB instruction profitably.
39143   // Intel's manuals suggest only using PSHUFB if doing so replaces 5
39144   // instructions, but in practice PSHUFB tends to be *very* fast so we're
39145   // more aggressive.
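        // PSHUFB mask bytes: bits[3:0] select a source byte within the same
        // 16-byte lane, and setting bit 7 (the 0x80 below) zeroes the result
        // byte, which is how SM_SentinelZero elements are handled.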
39146   if (UnaryShuffle && AllowVariablePerLaneMask &&
39147       ((RootVT.is128BitVector() && Subtarget.hasSSSE3()) ||
39148        (RootVT.is256BitVector() && Subtarget.hasAVX2()) ||
39149        (RootVT.is512BitVector() && Subtarget.hasBWI()))) {
39150     SmallVector<SDValue, 16> PSHUFBMask;
39151     int NumBytes = RootVT.getSizeInBits() / 8;
39152     int Ratio = NumBytes / NumMaskElts;
39153     for (int i = 0; i < NumBytes; ++i) {
39154       int M = Mask[i / Ratio];
39155       if (M == SM_SentinelUndef) {
39156         PSHUFBMask.push_back(DAG.getUNDEF(MVT::i8));
39157         continue;
39158       }
39159       if (M == SM_SentinelZero) {
39160         PSHUFBMask.push_back(DAG.getConstant(0x80, DL, MVT::i8));
39161         continue;
39162       }
39163       M = Ratio * M + i % Ratio;
39164       assert((M / 16) == (i / 16) && "Lane crossing detected");
39165       PSHUFBMask.push_back(DAG.getConstant(M, DL, MVT::i8));
39166     }
39167     MVT ByteVT = MVT::getVectorVT(MVT::i8, NumBytes);
39168     Res = CanonicalizeShuffleInput(ByteVT, V1);
39169     SDValue PSHUFBMaskOp = DAG.getBuildVector(ByteVT, DL, PSHUFBMask);
39170     Res = DAG.getNode(X86ISD::PSHUFB, DL, ByteVT, Res, PSHUFBMaskOp);
39171     return DAG.getBitcast(RootVT, Res);
39172   }
39173 
39174   // With XOP, if we have a 128-bit binary input shuffle we can always combine
39175   // to VPPERM. We match the depth requirement of PSHUFB - VPPERM is never
39176   // slower than PSHUFB on targets that support both.
39177   if (AllowVariablePerLaneMask && RootVT.is128BitVector() &&
39178       Subtarget.hasXOP()) {
39179     // VPPERM Mask Operation
39180     // Bits[4:0] - Byte Index (0 - 31)
39181     // Bits[7:5] - Permute Operation (0 - Source byte, 4 - ZERO)
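    // e.g. control byte 0x12 selects byte 2 of the second source (V2 bytes take
    // indices 16-31), while 0x80 uses the ZERO operation to produce a zero byte.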
39182     SmallVector<SDValue, 16> VPPERMMask;
39183     int NumBytes = 16;
39184     int Ratio = NumBytes / NumMaskElts;
39185     for (int i = 0; i < NumBytes; ++i) {
39186       int M = Mask[i / Ratio];
39187       if (M == SM_SentinelUndef) {
39188         VPPERMMask.push_back(DAG.getUNDEF(MVT::i8));
39189         continue;
39190       }
39191       if (M == SM_SentinelZero) {
39192         VPPERMMask.push_back(DAG.getConstant(0x80, DL, MVT::i8));
39193         continue;
39194       }
39195       M = Ratio * M + i % Ratio;
39196       VPPERMMask.push_back(DAG.getConstant(M, DL, MVT::i8));
39197     }
39198     MVT ByteVT = MVT::v16i8;
39199     V1 = CanonicalizeShuffleInput(ByteVT, V1);
39200     V2 = CanonicalizeShuffleInput(ByteVT, V2);
39201     SDValue VPPERMMaskOp = DAG.getBuildVector(ByteVT, DL, VPPERMMask);
39202     Res = DAG.getNode(X86ISD::VPPERM, DL, ByteVT, V1, V2, VPPERMMaskOp);
39203     return DAG.getBitcast(RootVT, Res);
39204   }
39205 
39206   // If that failed and either input is extracted then try to combine as a
39207   // shuffle with the larger type.
39208   if (SDValue WideShuffle = combineX86ShuffleChainWithExtract(
39209           Inputs, Root, BaseMask, Depth, HasVariableMask,
39210           AllowVariableCrossLaneMask, AllowVariablePerLaneMask, DAG, Subtarget))
39211     return WideShuffle;
39212 
39213   // If we have a dual input shuffle then lower to VPERMV3
39214   // (non-VLX will pad to 512-bit shuffles).
39215   if (!UnaryShuffle && AllowVariablePerLaneMask && !MaskContainsZeros &&
39216       ((Subtarget.hasAVX512() &&
39217         (MaskVT == MVT::v2f64 || MaskVT == MVT::v4f64 || MaskVT == MVT::v8f64 ||
39218          MaskVT == MVT::v2i64 || MaskVT == MVT::v4i64 || MaskVT == MVT::v8i64 ||
39219          MaskVT == MVT::v4f32 || MaskVT == MVT::v4i32 || MaskVT == MVT::v8f32 ||
39220          MaskVT == MVT::v8i32 || MaskVT == MVT::v16f32 ||
39221          MaskVT == MVT::v16i32)) ||
39222        (Subtarget.hasBWI() && AllowBWIVPERMV3 &&
39223         (MaskVT == MVT::v8i16 || MaskVT == MVT::v16i16 ||
39224          MaskVT == MVT::v32i16)) ||
39225        (Subtarget.hasVBMI() && AllowBWIVPERMV3 &&
39226         (MaskVT == MVT::v16i8 || MaskVT == MVT::v32i8 ||
39227          MaskVT == MVT::v64i8)))) {
39228     V1 = CanonicalizeShuffleInput(MaskVT, V1);
39229     V2 = CanonicalizeShuffleInput(MaskVT, V2);
39230     Res = lowerShuffleWithPERMV(DL, MaskVT, Mask, V1, V2, Subtarget, DAG);
39231     return DAG.getBitcast(RootVT, Res);
39232   }
39233 
39234   // Failed to find any combines.
39235   return SDValue();
39236 }
39237 
39238 // Combine an arbitrary chain of shuffles + extract_subvectors into a single
39239 // instruction if possible.
39240 //
39241 // Wrapper for combineX86ShuffleChain that extends the shuffle mask to a larger
39242 // type size to attempt to combine:
39243 // shuffle(extract_subvector(x,c1),extract_subvector(y,c2),m1)
39244 // -->
39245 // extract_subvector(shuffle(x,y,m2),0)
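// e.g. with x,y : v8f32 and a v4f32 root,
// shuffle(extract_subvector(x,4),extract_subvector(y,0),<0,1,4,5>) becomes
// extract_subvector(shuffle(x,y,<4,5,8,9,u,u,u,u>),0), letting the combine see
// the original v8f32 sources.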
39246 static SDValue combineX86ShuffleChainWithExtract(
39247     ArrayRef<SDValue> Inputs, SDValue Root, ArrayRef<int> BaseMask, int Depth,
39248     bool HasVariableMask, bool AllowVariableCrossLaneMask,
39249     bool AllowVariablePerLaneMask, SelectionDAG &DAG,
39250     const X86Subtarget &Subtarget) {
39251   unsigned NumMaskElts = BaseMask.size();
39252   unsigned NumInputs = Inputs.size();
39253   if (NumInputs == 0)
39254     return SDValue();
39255 
39256   EVT RootVT = Root.getValueType();
39257   unsigned RootSizeInBits = RootVT.getSizeInBits();
39258   unsigned RootEltSizeInBits = RootSizeInBits / NumMaskElts;
39259   assert((RootSizeInBits % NumMaskElts) == 0 && "Unexpected root shuffle mask");
39260 
39261   // Peek through extract_subvector to find widest legal vector.
39262   // TODO: Handle ISD::TRUNCATE
39263   unsigned WideSizeInBits = RootSizeInBits;
39264   for (unsigned I = 0; I != NumInputs; ++I) {
39265     SDValue Input = peekThroughBitcasts(Inputs[I]);
39266     while (Input.getOpcode() == ISD::EXTRACT_SUBVECTOR)
39267       Input = peekThroughBitcasts(Input.getOperand(0));
39268     if (DAG.getTargetLoweringInfo().isTypeLegal(Input.getValueType()) &&
39269         WideSizeInBits < Input.getValueSizeInBits())
39270       WideSizeInBits = Input.getValueSizeInBits();
39271   }
39272 
39273   // Bail if we fail to find a source larger than the existing root.
39274   unsigned Scale = WideSizeInBits / RootSizeInBits;
39275   if (WideSizeInBits <= RootSizeInBits ||
39276       (WideSizeInBits % RootSizeInBits) != 0)
39277     return SDValue();
39278 
39279   // Create new mask for larger type.
39280   SmallVector<int, 64> WideMask(BaseMask);
39281   for (int &M : WideMask) {
39282     if (M < 0)
39283       continue;
39284     M = (M % NumMaskElts) + ((M / NumMaskElts) * Scale * NumMaskElts);
39285   }
39286   WideMask.append((Scale - 1) * NumMaskElts, SM_SentinelUndef);
39287 
39288   // Attempt to peek through inputs and adjust mask when we extract from an
39289   // upper subvector.
39290   int AdjustedMasks = 0;
39291   SmallVector<SDValue, 4> WideInputs(Inputs.begin(), Inputs.end());
39292   for (unsigned I = 0; I != NumInputs; ++I) {
39293     SDValue &Input = WideInputs[I];
39294     Input = peekThroughBitcasts(Input);
39295     while (Input.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
39296            Input.getOperand(0).getValueSizeInBits() <= WideSizeInBits) {
39297       uint64_t Idx = Input.getConstantOperandVal(1);
39298       if (Idx != 0) {
39299         ++AdjustedMasks;
39300         unsigned InputEltSizeInBits = Input.getScalarValueSizeInBits();
39301         Idx = (Idx * InputEltSizeInBits) / RootEltSizeInBits;
39302 
39303         int lo = I * WideMask.size();
39304         int hi = (I + 1) * WideMask.size();
39305         for (int &M : WideMask)
39306           if (lo <= M && M < hi)
39307             M += Idx;
39308       }
39309       Input = peekThroughBitcasts(Input.getOperand(0));
39310     }
39311   }
39312 
39313   // Remove unused/repeated shuffle source ops.
39314   resolveTargetShuffleInputsAndMask(WideInputs, WideMask);
39315   assert(!WideInputs.empty() && "Shuffle with no inputs detected");
39316 
39317   // Bail if we're always extracting from the lowest subvectors (in which case
39318   // combineX86ShuffleChain should match this for the current width), or if the
39319   // shuffle still references too many inputs.
39320   if (AdjustedMasks == 0 || WideInputs.size() > 2)
39321     return SDValue();
39322 
39323   // Minor canonicalization of the accumulated shuffle mask to make it easier
39324   // to match below. All this does is detect masks with sequential pairs of
39325   // elements, and shrink them to the half-width mask. It does this in a loop
39326   // so it will reduce the size of the mask to the minimal width mask which
39327   // performs an equivalent shuffle.
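  // e.g. <0,1,2,3,8,9,10,11> widens to <0,1,4,5> and then to <0,2>.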
39328   while (WideMask.size() > 1) {
39329     SmallVector<int, 64> WidenedMask;
39330     if (!canWidenShuffleElements(WideMask, WidenedMask))
39331       break;
39332     WideMask = std::move(WidenedMask);
39333   }
39334 
39335   // Canonicalization of binary shuffle masks to improve pattern matching by
39336   // commuting the inputs.
39337   if (WideInputs.size() == 2 && canonicalizeShuffleMaskWithCommute(WideMask)) {
39338     ShuffleVectorSDNode::commuteMask(WideMask);
39339     std::swap(WideInputs[0], WideInputs[1]);
39340   }
39341 
39342   // Increase depth for every upper subvector we've peeked through.
39343   Depth += AdjustedMasks;
39344 
39345   // Attempt to combine wider chain.
39346   // TODO: Can we use a better Root?
39347   SDValue WideRoot = WideInputs.front().getValueSizeInBits() >
39348                              WideInputs.back().getValueSizeInBits()
39349                          ? WideInputs.front()
39350                          : WideInputs.back();
39351   assert(WideRoot.getValueSizeInBits() == WideSizeInBits &&
39352          "WideRootSize mismatch");
39353 
39354   if (SDValue WideShuffle =
39355           combineX86ShuffleChain(WideInputs, WideRoot, WideMask, Depth,
39356                                  HasVariableMask, AllowVariableCrossLaneMask,
39357                                  AllowVariablePerLaneMask, DAG, Subtarget)) {
39358     WideShuffle =
39359         extractSubVector(WideShuffle, 0, DAG, SDLoc(Root), RootSizeInBits);
39360     return DAG.getBitcast(RootVT, WideShuffle);
39361   }
39362 
39363   return SDValue();
39364 }
39365 
39366 // Canonicalize the combined shuffle mask chain with horizontal ops.
39367 // NOTE: This may update the Ops and Mask.
39368 static SDValue canonicalizeShuffleMaskWithHorizOp(
39369     MutableArrayRef<SDValue> Ops, MutableArrayRef<int> Mask,
39370     unsigned RootSizeInBits, const SDLoc &DL, SelectionDAG &DAG,
39371     const X86Subtarget &Subtarget) {
39372   if (Mask.empty() || Ops.empty())
39373     return SDValue();
39374 
39375   SmallVector<SDValue> BC;
39376   for (SDValue Op : Ops)
39377     BC.push_back(peekThroughBitcasts(Op));
39378 
39379   // All ops must be the same horizop + type.
39380   SDValue BC0 = BC[0];
39381   EVT VT0 = BC0.getValueType();
39382   unsigned Opcode0 = BC0.getOpcode();
39383   if (VT0.getSizeInBits() != RootSizeInBits || llvm::any_of(BC, [&](SDValue V) {
39384         return V.getOpcode() != Opcode0 || V.getValueType() != VT0;
39385       }))
39386     return SDValue();
39387 
39388   bool isHoriz = (Opcode0 == X86ISD::FHADD || Opcode0 == X86ISD::HADD ||
39389                   Opcode0 == X86ISD::FHSUB || Opcode0 == X86ISD::HSUB);
39390   bool isPack = (Opcode0 == X86ISD::PACKSS || Opcode0 == X86ISD::PACKUS);
39391   if (!isHoriz && !isPack)
39392     return SDValue();
39393 
39394   // Do all ops have a single use?
39395   bool OneUseOps = llvm::all_of(Ops, [](SDValue Op) {
39396     return Op.hasOneUse() &&
39397            peekThroughBitcasts(Op) == peekThroughOneUseBitcasts(Op);
39398   });
39399 
39400   int NumElts = VT0.getVectorNumElements();
39401   int NumLanes = VT0.getSizeInBits() / 128;
39402   int NumEltsPerLane = NumElts / NumLanes;
39403   int NumHalfEltsPerLane = NumEltsPerLane / 2;
39404   MVT SrcVT = BC0.getOperand(0).getSimpleValueType();
39405   unsigned EltSizeInBits = RootSizeInBits / Mask.size();
39406 
39407   if (NumEltsPerLane >= 4 &&
39408       (isPack || shouldUseHorizontalOp(Ops.size() == 1, DAG, Subtarget))) {
39409     SmallVector<int> LaneMask, ScaledMask;
39410     if (isRepeatedTargetShuffleMask(128, EltSizeInBits, Mask, LaneMask) &&
39411         scaleShuffleElements(LaneMask, 4, ScaledMask)) {
39412       // See if we can remove the shuffle by re-sorting the HOP chain so that
39413       // the HOP args are pre-shuffled.
39414       // TODO: Generalize to any sized/depth chain.
39415       // TODO: Add support for PACKSS/PACKUS.
39416       if (isHoriz) {
39417         // Attempt to find a HOP(HOP(X,Y),HOP(Z,W)) source operand.
39418         auto GetHOpSrc = [&](int M) {
39419           if (M == SM_SentinelUndef)
39420             return DAG.getUNDEF(VT0);
39421           if (M == SM_SentinelZero)
39422             return getZeroVector(VT0.getSimpleVT(), Subtarget, DAG, DL);
39423           SDValue Src0 = BC[M / 4];
39424           SDValue Src1 = Src0.getOperand((M % 4) >= 2);
39425           if (Src1.getOpcode() == Opcode0 && Src0->isOnlyUserOf(Src1.getNode()))
39426             return Src1.getOperand(M % 2);
39427           return SDValue();
39428         };
39429         SDValue M0 = GetHOpSrc(ScaledMask[0]);
39430         SDValue M1 = GetHOpSrc(ScaledMask[1]);
39431         SDValue M2 = GetHOpSrc(ScaledMask[2]);
39432         SDValue M3 = GetHOpSrc(ScaledMask[3]);
39433         if (M0 && M1 && M2 && M3) {
39434           SDValue LHS = DAG.getNode(Opcode0, DL, SrcVT, M0, M1);
39435           SDValue RHS = DAG.getNode(Opcode0, DL, SrcVT, M2, M3);
39436           return DAG.getNode(Opcode0, DL, VT0, LHS, RHS);
39437         }
39438       }
39439       // shuffle(hop(x,y),hop(z,w)) -> permute(hop(x,z)) etc.
39440       if (Ops.size() >= 2) {
39441         SDValue LHS, RHS;
39442         auto GetHOpSrc = [&](int M, int &OutM) {
39443           // TODO: Support SM_SentinelZero
39444           if (M < 0)
39445             return M == SM_SentinelUndef;
39446           SDValue Src = BC[M / 4].getOperand((M % 4) >= 2);
39447           if (!LHS || LHS == Src) {
39448             LHS = Src;
39449             OutM = (M % 2);
39450             return true;
39451           }
39452           if (!RHS || RHS == Src) {
39453             RHS = Src;
39454             OutM = (M % 2) + 2;
39455             return true;
39456           }
39457           return false;
39458         };
39459         int PostMask[4] = {-1, -1, -1, -1};
39460         if (GetHOpSrc(ScaledMask[0], PostMask[0]) &&
39461             GetHOpSrc(ScaledMask[1], PostMask[1]) &&
39462             GetHOpSrc(ScaledMask[2], PostMask[2]) &&
39463             GetHOpSrc(ScaledMask[3], PostMask[3])) {
39464           LHS = DAG.getBitcast(SrcVT, LHS);
39465           RHS = DAG.getBitcast(SrcVT, RHS ? RHS : LHS);
39466           SDValue Res = DAG.getNode(Opcode0, DL, VT0, LHS, RHS);
39467           // Use SHUFPS for the permute so this will work on SSE2 targets;
39468           // shuffle combining and domain handling will simplify this later on.
39469           MVT ShuffleVT = MVT::getVectorVT(MVT::f32, RootSizeInBits / 32);
39470           Res = DAG.getBitcast(ShuffleVT, Res);
39471           return DAG.getNode(X86ISD::SHUFP, DL, ShuffleVT, Res, Res,
39472                              getV4X86ShuffleImm8ForMask(PostMask, DL, DAG));
39473         }
39474       }
39475     }
39476   }
39477 
39478   if (2 < Ops.size())
39479     return SDValue();
39480 
39481   SDValue BC1 = BC[BC.size() - 1];
39482   if (Mask.size() == VT0.getVectorNumElements()) {
39483     // Canonicalize binary shuffles of horizontal ops that use the
39484     // same sources to a unary shuffle.
39485     // TODO: Try to perform this fold even if the shuffle remains.
39486     if (Ops.size() == 2) {
39487       auto ContainsOps = [](SDValue HOp, SDValue Op) {
39488         return Op == HOp.getOperand(0) || Op == HOp.getOperand(1);
39489       };
39490       // Commute if all BC0's ops are contained in BC1.
39491       if (ContainsOps(BC1, BC0.getOperand(0)) &&
39492           ContainsOps(BC1, BC0.getOperand(1))) {
39493         ShuffleVectorSDNode::commuteMask(Mask);
39494         std::swap(Ops[0], Ops[1]);
39495         std::swap(BC0, BC1);
39496       }
39497 
39498       // If BC1 can be represented by BC0, then convert to unary shuffle.
39499       if (ContainsOps(BC0, BC1.getOperand(0)) &&
39500           ContainsOps(BC0, BC1.getOperand(1))) {
39501         for (int &M : Mask) {
39502           if (M < NumElts) // BC0 element or UNDEF/Zero sentinel.
39503             continue;
39504           int SubLane = ((M % NumEltsPerLane) >= NumHalfEltsPerLane) ? 1 : 0;
39505           M -= NumElts + (SubLane * NumHalfEltsPerLane);
39506           if (BC1.getOperand(SubLane) != BC0.getOperand(0))
39507             M += NumHalfEltsPerLane;
39508         }
39509       }
39510     }
39511 
39512     // Canonicalize unary horizontal ops to only refer to lower halves.
39513     for (int i = 0; i != NumElts; ++i) {
39514       int &M = Mask[i];
39515       if (isUndefOrZero(M))
39516         continue;
39517       if (M < NumElts && BC0.getOperand(0) == BC0.getOperand(1) &&
39518           (M % NumEltsPerLane) >= NumHalfEltsPerLane)
39519         M -= NumHalfEltsPerLane;
39520       if (NumElts <= M && BC1.getOperand(0) == BC1.getOperand(1) &&
39521           (M % NumEltsPerLane) >= NumHalfEltsPerLane)
39522         M -= NumHalfEltsPerLane;
39523     }
39524   }
39525 
39526   // Combine binary shuffle of 2 similar 'Horizontal' instructions into a
39527   // single instruction. Attempt to match a v2X64 repeating shuffle pattern that
39528   // represents the LHS/RHS inputs for the lower/upper halves.
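  // e.g. shuffle(hadd(x,y),hadd(z,w),<0,2>) takes the 'x' half of the first hop
  // and the 'z' half of the second, so it folds to the single hadd(x,z).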
39529   SmallVector<int, 16> TargetMask128, WideMask128;
39530   if (isRepeatedTargetShuffleMask(128, EltSizeInBits, Mask, TargetMask128) &&
39531       scaleShuffleElements(TargetMask128, 2, WideMask128)) {
39532     assert(isUndefOrZeroOrInRange(WideMask128, 0, 4) && "Illegal shuffle");
39533     bool SingleOp = (Ops.size() == 1);
39534     if (isPack || OneUseOps ||
39535         shouldUseHorizontalOp(SingleOp, DAG, Subtarget)) {
39536       SDValue Lo = isInRange(WideMask128[0], 0, 2) ? BC0 : BC1;
39537       SDValue Hi = isInRange(WideMask128[1], 0, 2) ? BC0 : BC1;
39538       Lo = Lo.getOperand(WideMask128[0] & 1);
39539       Hi = Hi.getOperand(WideMask128[1] & 1);
39540       if (SingleOp) {
39541         SDValue Undef = DAG.getUNDEF(SrcVT);
39542         SDValue Zero = getZeroVector(SrcVT, Subtarget, DAG, DL);
39543         Lo = (WideMask128[0] == SM_SentinelZero ? Zero : Lo);
39544         Hi = (WideMask128[1] == SM_SentinelZero ? Zero : Hi);
39545         Lo = (WideMask128[0] == SM_SentinelUndef ? Undef : Lo);
39546         Hi = (WideMask128[1] == SM_SentinelUndef ? Undef : Hi);
39547       }
39548       return DAG.getNode(Opcode0, DL, VT0, Lo, Hi);
39549     }
39550   }
39551 
39552   // If we are post-shuffling a 256-bit hop and not requiring the upper
39553   // elements, then try to narrow to a 128-bit hop directly.
39554   SmallVector<int, 16> WideMask64;
39555   if (Ops.size() == 1 && NumLanes == 2 &&
39556       scaleShuffleElements(Mask, 4, WideMask64) &&
39557       isUndefInRange(WideMask64, 2, 2)) {
39558     int M0 = WideMask64[0];
39559     int M1 = WideMask64[1];
39560     if (isInRange(M0, 0, 4) && isInRange(M1, 0, 4)) {
39561       MVT HalfVT = VT0.getSimpleVT().getHalfNumVectorElementsVT();
39562       unsigned Idx0 = (M0 & 2) ? (SrcVT.getVectorNumElements() / 2) : 0;
39563       unsigned Idx1 = (M1 & 2) ? (SrcVT.getVectorNumElements() / 2) : 0;
39564       SDValue V0 = extract128BitVector(BC[0].getOperand(M0 & 1), Idx0, DAG, DL);
39565       SDValue V1 = extract128BitVector(BC[0].getOperand(M1 & 1), Idx1, DAG, DL);
39566       SDValue Res = DAG.getNode(Opcode0, DL, HalfVT, V0, V1);
39567       return widenSubVector(Res, false, Subtarget, DAG, DL, 256);
39568     }
39569   }
39570 
39571   return SDValue();
39572 }
39573 
39574 // Attempt to constant fold all of the constant source ops.
39575 // Returns the folded constant if the entire shuffle folds to a constant.
39576 // TODO: Extend this to merge multiple constant Ops and update the mask.
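// e.g. a shuffle of two constant build vectors with mask <0,5,2,7> folds to a
// single new constant vector; undef/zero mask elements become undef/zero lanes.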
39577 static SDValue combineX86ShufflesConstants(MVT VT, ArrayRef<SDValue> Ops,
39578                                            ArrayRef<int> Mask,
39579                                            bool HasVariableMask,
39580                                            SelectionDAG &DAG, const SDLoc &DL,
39581                                            const X86Subtarget &Subtarget) {
39582   unsigned SizeInBits = VT.getSizeInBits();
39583   unsigned NumMaskElts = Mask.size();
39584   unsigned MaskSizeInBits = SizeInBits / NumMaskElts;
39585   unsigned NumOps = Ops.size();
39586 
39587   // Extract constant bits from each source op.
39588   SmallVector<APInt, 16> UndefEltsOps(NumOps);
39589   SmallVector<SmallVector<APInt, 16>, 16> RawBitsOps(NumOps);
39590   for (unsigned I = 0; I != NumOps; ++I)
39591     if (!getTargetConstantBitsFromNode(Ops[I], MaskSizeInBits, UndefEltsOps[I],
39592                                        RawBitsOps[I],
39593                                        /*AllowWholeUndefs*/ true,
39594                                        /*AllowPartialUndefs*/ true))
39595       return SDValue();
39596 
39597   // If we're optimizing for size, only fold if at least one of the constants is
39598   // only used once or the combined shuffle has included a variable mask
39599   // shuffle; this is to avoid constant pool bloat.
39600   bool IsOptimizingSize = DAG.shouldOptForSize();
39601   if (IsOptimizingSize && !HasVariableMask &&
39602       llvm::none_of(Ops, [](SDValue SrcOp) { return SrcOp->hasOneUse(); }))
39603     return SDValue();
39604 
39605   // Shuffle the constant bits according to the mask.
39606   APInt UndefElts(NumMaskElts, 0);
39607   APInt ZeroElts(NumMaskElts, 0);
39608   APInt ConstantElts(NumMaskElts, 0);
39609   SmallVector<APInt, 8> ConstantBitData(NumMaskElts,
39610                                         APInt::getZero(MaskSizeInBits));
39611   for (unsigned i = 0; i != NumMaskElts; ++i) {
39612     int M = Mask[i];
39613     if (M == SM_SentinelUndef) {
39614       UndefElts.setBit(i);
39615       continue;
39616     } else if (M == SM_SentinelZero) {
39617       ZeroElts.setBit(i);
39618       continue;
39619     }
39620     assert(0 <= M && M < (int)(NumMaskElts * NumOps));
39621 
39622     unsigned SrcOpIdx = (unsigned)M / NumMaskElts;
39623     unsigned SrcMaskIdx = (unsigned)M % NumMaskElts;
39624 
39625     auto &SrcUndefElts = UndefEltsOps[SrcOpIdx];
39626     if (SrcUndefElts[SrcMaskIdx]) {
39627       UndefElts.setBit(i);
39628       continue;
39629     }
39630 
39631     auto &SrcEltBits = RawBitsOps[SrcOpIdx];
39632     APInt &Bits = SrcEltBits[SrcMaskIdx];
39633     if (!Bits) {
39634       ZeroElts.setBit(i);
39635       continue;
39636     }
39637 
39638     ConstantElts.setBit(i);
39639     ConstantBitData[i] = Bits;
39640   }
39641   assert((UndefElts | ZeroElts | ConstantElts).isAllOnes());
39642 
39643   // Attempt to create a zero vector.
39644   if ((UndefElts | ZeroElts).isAllOnes())
39645     return getZeroVector(VT, Subtarget, DAG, DL);
39646 
39647   // Create the constant data.
39648   MVT MaskSVT;
39649   if (VT.isFloatingPoint() && (MaskSizeInBits == 32 || MaskSizeInBits == 64))
39650     MaskSVT = MVT::getFloatingPointVT(MaskSizeInBits);
39651   else
39652     MaskSVT = MVT::getIntegerVT(MaskSizeInBits);
39653 
39654   MVT MaskVT = MVT::getVectorVT(MaskSVT, NumMaskElts);
39655   if (!DAG.getTargetLoweringInfo().isTypeLegal(MaskVT))
39656     return SDValue();
39657 
39658   SDValue CstOp = getConstVector(ConstantBitData, UndefElts, MaskVT, DAG, DL);
39659   return DAG.getBitcast(VT, CstOp);
39660 }
39661 
39662 namespace llvm {
39663   namespace X86 {
39664     enum {
39665       MaxShuffleCombineDepth = 8
39666     };
39667   } // namespace X86
39668 } // namespace llvm
39669 
39670 /// Fully generic combining of x86 shuffle instructions.
39671 ///
39672 /// This should be the last combine run over the x86 shuffle instructions. Once
39673 /// they have been fully optimized, this will recursively consider all chains
39674 /// of single-use shuffle instructions, build a generic model of the cumulative
39675 /// shuffle operation, and check for simpler instructions which implement this
39676 /// operation. We use this primarily for two purposes:
39677 ///
39678 /// 1) Collapse generic shuffles to specialized single instructions when
39679 ///    equivalent. In most cases, this is just an encoding size win, but
39680 ///    sometimes we will collapse multiple generic shuffles into a single
39681 ///    special-purpose shuffle.
39682 /// 2) Look for sequences of shuffle instructions with 3 or more total
39683 ///    instructions, and replace them with the slightly more expensive SSSE3
39684 ///    PSHUFB instruction if available. We do this as the last combining step
39685 ///    to ensure we avoid using PSHUFB if we can implement the shuffle with
39686 ///    a suitable short sequence of other instructions. The PSHUFB will either
39687 ///    use a register or have to read from memory and so is slightly (but only
39688 ///    slightly) more expensive than the other shuffle instructions.
39689 ///
39690 /// Because this is inherently a quadratic operation (for each shuffle in
39691 /// a chain, we recurse up the chain), the depth is limited to 8 instructions.
39692 /// This should never be an issue in practice as the shuffle lowering doesn't
39693 /// produce sequences of more than 8 instructions.
39694 ///
39695 /// FIXME: We will currently miss some cases where the redundant shuffling
39696 /// would simplify under the threshold for PSHUFB formation because of
39697 /// combine-ordering. To fix this, we should do the redundant instruction
39698 /// combining in this recursive walk.
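/// Illustrative example: pshufd(pshufd(x,<2,3,0,1>),<1,0,3,2>) accumulates to
/// the single mask <3,2,1,0> over x, which is then re-lowered as one shuffle.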
39699 static SDValue combineX86ShufflesRecursively(
39700     ArrayRef<SDValue> SrcOps, int SrcOpIndex, SDValue Root,
39701     ArrayRef<int> RootMask, ArrayRef<const SDNode *> SrcNodes, unsigned Depth,
39702     unsigned MaxDepth, bool HasVariableMask, bool AllowVariableCrossLaneMask,
39703     bool AllowVariablePerLaneMask, SelectionDAG &DAG,
39704     const X86Subtarget &Subtarget) {
39705   assert(!RootMask.empty() &&
39706          (RootMask.size() > 1 || (RootMask[0] == 0 && SrcOpIndex == 0)) &&
39707          "Illegal shuffle root mask");
39708   MVT RootVT = Root.getSimpleValueType();
39709   assert(RootVT.isVector() && "Shuffles operate on vector types!");
39710   unsigned RootSizeInBits = RootVT.getSizeInBits();
39711   SDLoc DL(Root);
39712 
39713   // Bound the depth of our recursive combine because this is ultimately
39714   // quadratic in nature.
39715   if (Depth >= MaxDepth)
39716     return SDValue();
39717 
39718   // Directly rip through bitcasts to find the underlying operand.
39719   SDValue Op = SrcOps[SrcOpIndex];
39720   Op = peekThroughOneUseBitcasts(Op);
39721 
39722   EVT VT = Op.getValueType();
39723   if (!VT.isVector() || !VT.isSimple())
39724     return SDValue(); // Bail if we hit a non-simple non-vector.
39725 
39726   // FIXME: Just bail on f16 for now.
39727   if (VT.getVectorElementType() == MVT::f16)
39728     return SDValue();
39729 
39730   assert((RootSizeInBits % VT.getSizeInBits()) == 0 &&
39731          "Can only combine shuffles upto size of the root op.");
39732 
39733   // Create a demanded elts mask from the referenced elements of Op.
39734   APInt OpDemandedElts = APInt::getZero(RootMask.size());
39735   for (int M : RootMask) {
39736     int BaseIdx = RootMask.size() * SrcOpIndex;
39737     if (isInRange(M, BaseIdx, BaseIdx + RootMask.size()))
39738       OpDemandedElts.setBit(M - BaseIdx);
39739   }
39740   if (RootSizeInBits != VT.getSizeInBits()) {
39741     // Op is smaller than Root - extract the demanded elts for the subvector.
39742     unsigned Scale = RootSizeInBits / VT.getSizeInBits();
39743     unsigned NumOpMaskElts = RootMask.size() / Scale;
39744     assert((RootMask.size() % Scale) == 0 && "Root mask size mismatch");
39745     assert(OpDemandedElts
39746                .extractBits(RootMask.size() - NumOpMaskElts, NumOpMaskElts)
39747                .isZero() &&
39748            "Out of range elements referenced in root mask");
39749     OpDemandedElts = OpDemandedElts.extractBits(NumOpMaskElts, 0);
39750   }
39751   OpDemandedElts =
39752       APIntOps::ScaleBitMask(OpDemandedElts, VT.getVectorNumElements());
39753 
39754   // Extract target shuffle mask and resolve sentinels and inputs.
39755   SmallVector<int, 64> OpMask;
39756   SmallVector<SDValue, 2> OpInputs;
39757   APInt OpUndef, OpZero;
39758   bool IsOpVariableMask = isTargetShuffleVariableMask(Op.getOpcode());
39759   if (getTargetShuffleInputs(Op, OpDemandedElts, OpInputs, OpMask, OpUndef,
39760                              OpZero, DAG, Depth, false)) {
39761     // Shuffle inputs must not be larger than the shuffle result.
39762     // TODO: Relax this for single input faux shuffles (e.g. trunc).
39763     if (llvm::any_of(OpInputs, [VT](SDValue OpInput) {
39764           return OpInput.getValueSizeInBits() > VT.getSizeInBits();
39765         }))
39766       return SDValue();
39767   } else if (Op.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
39768              (RootSizeInBits % Op.getOperand(0).getValueSizeInBits()) == 0 &&
39769              !isNullConstant(Op.getOperand(1))) {
39770     SDValue SrcVec = Op.getOperand(0);
39771     int ExtractIdx = Op.getConstantOperandVal(1);
39772     unsigned NumElts = VT.getVectorNumElements();
39773     OpInputs.assign({SrcVec});
39774     OpMask.assign(NumElts, SM_SentinelUndef);
39775     std::iota(OpMask.begin(), OpMask.end(), ExtractIdx);
39776     OpZero = OpUndef = APInt::getZero(NumElts);
39777   } else {
39778     return SDValue();
39779   }
39780 
39781   // If the shuffle result was smaller than the root, we need to adjust the
39782   // mask indices and pad the mask with undefs.
39783   if (RootSizeInBits > VT.getSizeInBits()) {
39784     unsigned NumSubVecs = RootSizeInBits / VT.getSizeInBits();
39785     unsigned OpMaskSize = OpMask.size();
39786     if (OpInputs.size() > 1) {
39787       unsigned PaddedMaskSize = NumSubVecs * OpMaskSize;
39788       for (int &M : OpMask) {
39789         if (M < 0)
39790           continue;
39791         int EltIdx = M % OpMaskSize;
39792         int OpIdx = M / OpMaskSize;
39793         M = (PaddedMaskSize * OpIdx) + EltIdx;
39794       }
39795     }
39796     OpZero = OpZero.zext(NumSubVecs * OpMaskSize);
39797     OpUndef = OpUndef.zext(NumSubVecs * OpMaskSize);
39798     OpMask.append((NumSubVecs - 1) * OpMaskSize, SM_SentinelUndef);
39799   }
39800 
39801   SmallVector<int, 64> Mask;
39802   SmallVector<SDValue, 16> Ops;
39803 
39804   // We don't need to merge masks if the root is empty.
39805   bool EmptyRoot = (Depth == 0) && (RootMask.size() == 1);
39806   if (EmptyRoot) {
39807     // Only resolve zeros if it will remove an input; otherwise we might end
39808     // up in an infinite loop.
39809     bool ResolveKnownZeros = true;
39810     if (!OpZero.isZero()) {
39811       APInt UsedInputs = APInt::getZero(OpInputs.size());
39812       for (int i = 0, e = OpMask.size(); i != e; ++i) {
39813         int M = OpMask[i];
39814         if (OpUndef[i] || OpZero[i] || isUndefOrZero(M))
39815           continue;
39816         UsedInputs.setBit(M / OpMask.size());
39817         if (UsedInputs.isAllOnes()) {
39818           ResolveKnownZeros = false;
39819           break;
39820         }
39821       }
39822     }
39823     resolveTargetShuffleFromZeroables(OpMask, OpUndef, OpZero,
39824                                       ResolveKnownZeros);
39825 
39826     Mask = OpMask;
39827     Ops.append(OpInputs.begin(), OpInputs.end());
39828   } else {
39829     resolveTargetShuffleFromZeroables(OpMask, OpUndef, OpZero);
39830 
39831     // Add the inputs to the Ops list, avoiding duplicates.
39832     Ops.append(SrcOps.begin(), SrcOps.end());
39833 
39834     auto AddOp = [&Ops](SDValue Input, int InsertionPoint) -> int {
39835       // Attempt to find an existing match.
39836       SDValue InputBC = peekThroughBitcasts(Input);
39837       for (int i = 0, e = Ops.size(); i < e; ++i)
39838         if (InputBC == peekThroughBitcasts(Ops[i]))
39839           return i;
39840       // Match failed - should we replace an existing Op?
39841       if (InsertionPoint >= 0) {
39842         Ops[InsertionPoint] = Input;
39843         return InsertionPoint;
39844       }
39845       // Add to the end of the Ops list.
39846       Ops.push_back(Input);
39847       return Ops.size() - 1;
39848     };
39849 
39850     SmallVector<int, 2> OpInputIdx;
39851     for (SDValue OpInput : OpInputs)
39852       OpInputIdx.push_back(
39853           AddOp(OpInput, OpInputIdx.empty() ? SrcOpIndex : -1));
39854 
39855     assert(((RootMask.size() > OpMask.size() &&
39856              RootMask.size() % OpMask.size() == 0) ||
39857             (OpMask.size() > RootMask.size() &&
39858              OpMask.size() % RootMask.size() == 0) ||
39859             OpMask.size() == RootMask.size()) &&
39860            "The smaller number of elements must divide the larger.");
39861 
39862     // This function can be performance-critical, so we rely on the power-of-2
39863     // knowledge that we have about the mask sizes to replace div/rem ops with
39864     // bit-masks and shifts.
39865     assert(llvm::has_single_bit<uint32_t>(RootMask.size()) &&
39866            "Non-power-of-2 shuffle mask sizes");
39867     assert(llvm::has_single_bit<uint32_t>(OpMask.size()) &&
39868            "Non-power-of-2 shuffle mask sizes");
39869     unsigned RootMaskSizeLog2 = llvm::countr_zero(RootMask.size());
39870     unsigned OpMaskSizeLog2 = llvm::countr_zero(OpMask.size());
39871 
39872     unsigned MaskWidth = std::max<unsigned>(OpMask.size(), RootMask.size());
39873     unsigned RootRatio =
39874         std::max<unsigned>(1, OpMask.size() >> RootMaskSizeLog2);
39875     unsigned OpRatio = std::max<unsigned>(1, RootMask.size() >> OpMaskSizeLog2);
39876     assert((RootRatio == 1 || OpRatio == 1) &&
39877            "Must not have a ratio for both incoming and op masks!");
39878 
39879     assert(isPowerOf2_32(MaskWidth) && "Non-power-of-2 shuffle mask sizes");
39880     assert(isPowerOf2_32(RootRatio) && "Non-power-of-2 shuffle mask sizes");
39881     assert(isPowerOf2_32(OpRatio) && "Non-power-of-2 shuffle mask sizes");
39882     unsigned RootRatioLog2 = llvm::countr_zero(RootRatio);
39883     unsigned OpRatioLog2 = llvm::countr_zero(OpRatio);
39884 
39885     Mask.resize(MaskWidth, SM_SentinelUndef);
39886 
39887     // Merge this shuffle operation's mask into our accumulated mask. Note that
39888     // this shuffle's mask will be the first applied to the input, followed by
39889     // the root mask to get us all the way to the root value arrangement. The
39890     // reason for this order is that we are recursing up the operation chain.
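    // e.g. if the root mask has 4 elements and this op's mask has 8, RootRatio
    // is 2 and each root mask value R is scaled to op-mask indices 2*R and
    // 2*R+1 before being looked up in OpMask.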
39891     for (unsigned i = 0; i < MaskWidth; ++i) {
39892       unsigned RootIdx = i >> RootRatioLog2;
39893       if (RootMask[RootIdx] < 0) {
39894         // This is a zero or undef lane, we're done.
39895         Mask[i] = RootMask[RootIdx];
39896         continue;
39897       }
39898 
39899       unsigned RootMaskedIdx =
39900           RootRatio == 1
39901               ? RootMask[RootIdx]
39902               : (RootMask[RootIdx] << RootRatioLog2) + (i & (RootRatio - 1));
39903 
39904       // Just insert the scaled root mask value if it references an input other
39905       // than the SrcOp we're currently inserting.
39906       if ((RootMaskedIdx < (SrcOpIndex * MaskWidth)) ||
39907           (((SrcOpIndex + 1) * MaskWidth) <= RootMaskedIdx)) {
39908         Mask[i] = RootMaskedIdx;
39909         continue;
39910       }
39911 
39912       RootMaskedIdx = RootMaskedIdx & (MaskWidth - 1);
39913       unsigned OpIdx = RootMaskedIdx >> OpRatioLog2;
39914       if (OpMask[OpIdx] < 0) {
39915         // The incoming lanes are zero or undef; it doesn't matter which ones we
39916         // are using.
39917         Mask[i] = OpMask[OpIdx];
39918         continue;
39919       }
39920 
39921       // Ok, we have non-zero lanes, map them through to one of the Op's inputs.
39922       unsigned OpMaskedIdx = OpRatio == 1 ? OpMask[OpIdx]
39923                                           : (OpMask[OpIdx] << OpRatioLog2) +
39924                                                 (RootMaskedIdx & (OpRatio - 1));
39925 
39926       OpMaskedIdx = OpMaskedIdx & (MaskWidth - 1);
39927       int InputIdx = OpMask[OpIdx] / (int)OpMask.size();
39928       assert(0 <= OpInputIdx[InputIdx] && "Unknown target shuffle input");
39929       OpMaskedIdx += OpInputIdx[InputIdx] * MaskWidth;
39930 
39931       Mask[i] = OpMaskedIdx;
39932     }
39933   }
39934 
39935   // Peek through vector widenings and set out of bounds mask indices to undef.
39936   // TODO: Can resolveTargetShuffleInputsAndMask do some of this?
39937   for (unsigned I = 0, E = Ops.size(); I != E; ++I) {
39938     SDValue &Op = Ops[I];
39939     if (Op.getOpcode() == ISD::INSERT_SUBVECTOR && Op.getOperand(0).isUndef() &&
39940         isNullConstant(Op.getOperand(2))) {
39941       Op = Op.getOperand(1);
39942       unsigned Scale = RootSizeInBits / Op.getValueSizeInBits();
39943       int Lo = I * Mask.size();
39944       int Hi = (I + 1) * Mask.size();
39945       int NewHi = Lo + (Mask.size() / Scale);
39946       for (int &M : Mask) {
39947         if (Lo <= M && NewHi <= M && M < Hi)
39948           M = SM_SentinelUndef;
39949       }
39950     }
39951   }
39952 
39953   // Peek through any free extract_subvector nodes back to root size.
39954   for (SDValue &Op : Ops)
39955     while (Op.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
39956            (RootSizeInBits % Op.getOperand(0).getValueSizeInBits()) == 0 &&
39957            isNullConstant(Op.getOperand(1)))
39958       Op = Op.getOperand(0);
39959 
39960   // Remove unused/repeated shuffle source ops.
39961   resolveTargetShuffleInputsAndMask(Ops, Mask);
39962 
39963   // Handle the all undef/zero/ones cases early.
39964   if (all_of(Mask, [](int Idx) { return Idx == SM_SentinelUndef; }))
39965     return DAG.getUNDEF(RootVT);
39966   if (all_of(Mask, [](int Idx) { return Idx < 0; }))
39967     return getZeroVector(RootVT, Subtarget, DAG, DL);
39968   if (Ops.size() == 1 && ISD::isBuildVectorAllOnes(Ops[0].getNode()) &&
39969       !llvm::is_contained(Mask, SM_SentinelZero))
39970     return getOnesVector(RootVT, DAG, DL);
39971 
39972   assert(!Ops.empty() && "Shuffle with no inputs detected");
39973   HasVariableMask |= IsOpVariableMask;
39974 
39975   // Update the list of shuffle nodes that have been combined so far.
39976   SmallVector<const SDNode *, 16> CombinedNodes(SrcNodes.begin(),
39977                                                 SrcNodes.end());
39978   CombinedNodes.push_back(Op.getNode());
39979 
39980   // See if we can recurse into each shuffle source op (if it's a target
39981   // shuffle). The source op should only be generally combined if it either has
39982   // a single use (i.e. current Op) or all its users have already been combined;
39983   // if not, we can still combine but should prevent generation of variable
39984   // shuffles to avoid constant pool bloat.
39985   // Don't recurse if we already have more source ops than we can combine in
39986   // the remaining recursion depth.
39987   if (Ops.size() < (MaxDepth - Depth)) {
39988     for (int i = 0, e = Ops.size(); i < e; ++i) {
39989       // For empty roots, we need to resolve zeroable elements before combining
39990       // them with other shuffles.
39991       SmallVector<int, 64> ResolvedMask = Mask;
39992       if (EmptyRoot)
39993         resolveTargetShuffleFromZeroables(ResolvedMask, OpUndef, OpZero);
39994       bool AllowCrossLaneVar = false;
39995       bool AllowPerLaneVar = false;
39996       if (Ops[i].getNode()->hasOneUse() ||
39997           SDNode::areOnlyUsersOf(CombinedNodes, Ops[i].getNode())) {
39998         AllowCrossLaneVar = AllowVariableCrossLaneMask;
39999         AllowPerLaneVar = AllowVariablePerLaneMask;
40000       }
40001       if (SDValue Res = combineX86ShufflesRecursively(
40002               Ops, i, Root, ResolvedMask, CombinedNodes, Depth + 1, MaxDepth,
40003               HasVariableMask, AllowCrossLaneVar, AllowPerLaneVar, DAG,
40004               Subtarget))
40005         return Res;
40006     }
40007   }
40008 
40009   // Attempt to constant fold all of the constant source ops.
40010   if (SDValue Cst = combineX86ShufflesConstants(
40011           RootVT, Ops, Mask, HasVariableMask, DAG, DL, Subtarget))
40012     return Cst;
40013 
40014   // If constant fold failed and we only have constants - then we have
40015   // multiple uses by a single non-variable shuffle - just bail.
40016   if (Depth == 0 && llvm::all_of(Ops, [&](SDValue Op) {
40017         APInt UndefElts;
40018         SmallVector<APInt> RawBits;
40019         unsigned EltSizeInBits = RootSizeInBits / Mask.size();
40020         return getTargetConstantBitsFromNode(Op, EltSizeInBits, UndefElts,
40021                                              RawBits,
40022                                              /*AllowWholeUndefs*/ true,
40023                                              /*AllowPartialUndefs*/ true);
40024       })) {
40025     return SDValue();
40026   }
40027 
40028   // Canonicalize the combined shuffle mask chain with horizontal ops.
40029   // NOTE: This will update the Ops and Mask.
40030   if (SDValue HOp = canonicalizeShuffleMaskWithHorizOp(
40031           Ops, Mask, RootSizeInBits, DL, DAG, Subtarget))
40032     return DAG.getBitcast(RootVT, HOp);
40033 
40034   // Try to refine our inputs given our knowledge of target shuffle mask.
40035   for (auto I : enumerate(Ops)) {
40036     int OpIdx = I.index();
40037     SDValue &Op = I.value();
40038 
40039     // What range of shuffle mask element values results in picking from Op?
40040     int Lo = OpIdx * Mask.size();
40041     int Hi = Lo + Mask.size();
40042 
40043     // Which elements of Op do we demand, given the mask's granularity?
40044     APInt OpDemandedElts(Mask.size(), 0);
40045     for (int MaskElt : Mask) {
40046       if (isInRange(MaskElt, Lo, Hi)) { // Picks from Op?
40047         int OpEltIdx = MaskElt - Lo;
40048         OpDemandedElts.setBit(OpEltIdx);
40049       }
40050     }
40051 
40052     // Is the shuffle result smaller than the root?
40053     if (Op.getValueSizeInBits() < RootSizeInBits) {
40054       // We padded the mask with undefs. But we now need to undo that.
40055       unsigned NumExpectedVectorElts = Mask.size();
40056       unsigned EltSizeInBits = RootSizeInBits / NumExpectedVectorElts;
40057       unsigned NumOpVectorElts = Op.getValueSizeInBits() / EltSizeInBits;
40058       assert(!OpDemandedElts.extractBits(
40059                  NumExpectedVectorElts - NumOpVectorElts, NumOpVectorElts) &&
40060              "Demanding the virtual undef widening padding?");
40061       OpDemandedElts = OpDemandedElts.trunc(NumOpVectorElts); // NUW
40062     }
40063 
40064     // The Op itself may be of different VT, so we need to scale the mask.
40065     unsigned NumOpElts = Op.getValueType().getVectorNumElements();
40066     APInt OpScaledDemandedElts = APIntOps::ScaleBitMask(OpDemandedElts, NumOpElts);
40067 
40068     // Can this operand be simplified any further, given its demanded elements?
40069     if (SDValue NewOp =
40070             DAG.getTargetLoweringInfo().SimplifyMultipleUseDemandedVectorElts(
40071                 Op, OpScaledDemandedElts, DAG))
40072       Op = NewOp;
40073   }
40074   // FIXME: should we rerun resolveTargetShuffleInputsAndMask() now?
40075 
40076   // Widen any subvector shuffle inputs we've collected.
40077   // TODO: Remove this to avoid generating temporary nodes; we should only
40078   // widen once combineX86ShuffleChain has found a match.
40079   if (any_of(Ops, [RootSizeInBits](SDValue Op) {
40080         return Op.getValueSizeInBits() < RootSizeInBits;
40081       })) {
40082     for (SDValue &Op : Ops)
40083       if (Op.getValueSizeInBits() < RootSizeInBits)
40084         Op = widenSubVector(Op, false, Subtarget, DAG, SDLoc(Op),
40085                             RootSizeInBits);
40086     // Reresolve - we might have repeated subvector sources.
40087     resolveTargetShuffleInputsAndMask(Ops, Mask);
40088   }
40089 
40090   // We can only combine unary and binary shuffle mask cases.
40091   if (Ops.size() <= 2) {
40092     // Minor canonicalization of the accumulated shuffle mask to make it easier
40093     // to match below. All this does is detect masks with sequential pairs of
40094     // elements, and shrink them to the half-width mask. It does this in a loop
40095     // so it will reduce the size of the mask to the minimal width mask which
40096     // performs an equivalent shuffle.
40097     while (Mask.size() > 1) {
40098       SmallVector<int, 64> WidenedMask;
40099       if (!canWidenShuffleElements(Mask, WidenedMask))
40100         break;
40101       Mask = std::move(WidenedMask);
40102     }
40103 
40104     // Canonicalization of binary shuffle masks to improve pattern matching by
40105     // commuting the inputs.
40106     if (Ops.size() == 2 && canonicalizeShuffleMaskWithCommute(Mask)) {
40107       ShuffleVectorSDNode::commuteMask(Mask);
40108       std::swap(Ops[0], Ops[1]);
40109     }
40110 
40111     // Try to combine into a single shuffle instruction.
40112     if (SDValue Shuffle = combineX86ShuffleChain(
40113             Ops, Root, Mask, Depth, HasVariableMask, AllowVariableCrossLaneMask,
40114             AllowVariablePerLaneMask, DAG, Subtarget))
40115       return Shuffle;
40116 
40117     // If all the operands come from the same larger vector, fall through and try
40118     // to use combineX86ShuffleChainWithExtract.
40119     SDValue LHS = peekThroughBitcasts(Ops.front());
40120     SDValue RHS = peekThroughBitcasts(Ops.back());
40121     if (Ops.size() != 2 || !Subtarget.hasAVX2() || RootSizeInBits != 128 ||
40122         (RootSizeInBits / Mask.size()) != 64 ||
40123         LHS.getOpcode() != ISD::EXTRACT_SUBVECTOR ||
40124         RHS.getOpcode() != ISD::EXTRACT_SUBVECTOR ||
40125         LHS.getOperand(0) != RHS.getOperand(0))
40126       return SDValue();
40127   }
40128 
40129   // If that failed and any input is extracted then try to combine as a
40130   // shuffle with the larger type.
40131   return combineX86ShuffleChainWithExtract(
40132       Ops, Root, Mask, Depth, HasVariableMask, AllowVariableCrossLaneMask,
40133       AllowVariablePerLaneMask, DAG, Subtarget);
40134 }
40135 
40136 /// Helper entry wrapper to combineX86ShufflesRecursively.
40137 static SDValue combineX86ShufflesRecursively(SDValue Op, SelectionDAG &DAG,
40138                                              const X86Subtarget &Subtarget) {
40139   return combineX86ShufflesRecursively(
40140       {Op}, 0, Op, {0}, {}, /*Depth*/ 0, X86::MaxShuffleCombineDepth,
40141       /*HasVarMask*/ false,
40142       /*AllowCrossLaneVarMask*/ true, /*AllowPerLaneVarMask*/ true, DAG,
40143       Subtarget);
40144 }
40145 
40146 /// Get the PSHUF-style mask from PSHUF node.
40147 ///
40148 /// This is a very minor wrapper around getTargetShuffleMask to ease forming v4
40149 /// PSHUF-style masks that can be reused with such instructions.
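/// e.g. for a PSHUFHW node only the four high-word indices are returned,
/// rebased to the 0..3 range.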
40150 static SmallVector<int, 4> getPSHUFShuffleMask(SDValue N) {
40151   MVT VT = N.getSimpleValueType();
40152   SmallVector<int, 4> Mask;
40153   SmallVector<SDValue, 2> Ops;
40154   bool HaveMask = getTargetShuffleMask(N, false, Ops, Mask);
40155   (void)HaveMask;
40156   assert(HaveMask);
40157 
40158   // If we have more than 128-bits, only the low 128-bits of shuffle mask
40159   // matter. Check that the upper masks are repeats and remove them.
40160   if (VT.getSizeInBits() > 128) {
40161     int LaneElts = 128 / VT.getScalarSizeInBits();
40162 #ifndef NDEBUG
40163     for (int i = 1, NumLanes = VT.getSizeInBits() / 128; i < NumLanes; ++i)
40164       for (int j = 0; j < LaneElts; ++j)
40165         assert(Mask[j] == Mask[i * LaneElts + j] - (LaneElts * i) &&
40166                "Mask doesn't repeat in high 128-bit lanes!");
40167 #endif
40168     Mask.resize(LaneElts);
40169   }
40170 
40171   switch (N.getOpcode()) {
40172   case X86ISD::PSHUFD:
40173     return Mask;
40174   case X86ISD::PSHUFLW:
40175     Mask.resize(4);
40176     return Mask;
40177   case X86ISD::PSHUFHW:
40178     Mask.erase(Mask.begin(), Mask.begin() + 4);
40179     for (int &M : Mask)
40180       M -= 4;
40181     return Mask;
40182   default:
40183     llvm_unreachable("No valid shuffle instruction found!");
40184   }
40185 }
40186 
40187 /// Search for a combinable shuffle across a chain ending in pshufd.
40188 ///
40189 /// We walk up the chain and look for a combinable shuffle, skipping over
40190 /// shuffles that we could hoist this shuffle's transformation past without
40191 /// altering anything.
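/// Illustrative example: pshufd(pshufd(x,<1,0,3,2>),<2,3,0,1>) merges into the
/// single pshufd(x,<3,2,1,0>), and any skipped word shuffles are rebuilt on top.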
40192 static SDValue combineRedundantDWordShuffle(SDValue N,
40193                                             MutableArrayRef<int> Mask,
40194                                             const SDLoc &DL,
40195                                             SelectionDAG &DAG) {
40196   assert(N.getOpcode() == X86ISD::PSHUFD &&
40197          "Called with something other than an x86 128-bit half shuffle!");
40198 
40199   // Walk up a single-use chain looking for a combinable shuffle. Keep a stack
40200   // of the shuffles in the chain so that we can form a fresh chain to replace
40201   // this one.
40202   SmallVector<SDValue, 8> Chain;
40203   SDValue V = N.getOperand(0);
40204   for (; V.hasOneUse(); V = V.getOperand(0)) {
40205     switch (V.getOpcode()) {
40206     default:
40207       return SDValue(); // Nothing combined!
40208 
40209     case ISD::BITCAST:
40210       // Skip bitcasts as we always know the type for the target specific
40211       // instructions.
40212       continue;
40213 
40214     case X86ISD::PSHUFD:
40215       // Found another dword shuffle.
40216       break;
40217 
40218     case X86ISD::PSHUFLW:
40219       // Check that the low words (being shuffled) are the identity in the
40220       // dword shuffle, and the high words are self-contained.
40221       if (Mask[0] != 0 || Mask[1] != 1 ||
40222           !(Mask[2] >= 2 && Mask[2] < 4 && Mask[3] >= 2 && Mask[3] < 4))
40223         return SDValue();
40224 
40225       Chain.push_back(V);
40226       continue;
40227 
40228     case X86ISD::PSHUFHW:
40229       // Check that the high words (being shuffled) are the identity in the
40230       // dword shuffle, and the low words are self-contained.
40231       if (Mask[2] != 2 || Mask[3] != 3 ||
40232           !(Mask[0] >= 0 && Mask[0] < 2 && Mask[1] >= 0 && Mask[1] < 2))
40233         return SDValue();
40234 
40235       Chain.push_back(V);
40236       continue;
40237 
40238     case X86ISD::UNPCKL:
40239     case X86ISD::UNPCKH:
40240       // For either i8 -> i16 or i16 -> i32 unpacks, we can combine a dword
40241       // shuffle into a preceding word shuffle.
40242       if (V.getSimpleValueType().getVectorElementType() != MVT::i8 &&
40243           V.getSimpleValueType().getVectorElementType() != MVT::i16)
40244         return SDValue();
40245 
40246       // Search for a half-shuffle which we can combine with.
40247       unsigned CombineOp =
40248           V.getOpcode() == X86ISD::UNPCKL ? X86ISD::PSHUFLW : X86ISD::PSHUFHW;
40249       if (V.getOperand(0) != V.getOperand(1) ||
40250           !V->isOnlyUserOf(V.getOperand(0).getNode()))
40251         return SDValue();
40252       Chain.push_back(V);
40253       V = V.getOperand(0);
40254       do {
40255         switch (V.getOpcode()) {
40256         default:
40257           return SDValue(); // Nothing to combine.
40258 
40259         case X86ISD::PSHUFLW:
40260         case X86ISD::PSHUFHW:
40261           if (V.getOpcode() == CombineOp)
40262             break;
40263 
40264           Chain.push_back(V);
40265 
40266           [[fallthrough]];
40267         case ISD::BITCAST:
40268           V = V.getOperand(0);
40269           continue;
40270         }
40271         break;
40272       } while (V.hasOneUse());
40273       break;
40274     }
40275     // Break out of the loop if we break out of the switch.
40276     break;
40277   }
40278 
40279   if (!V.hasOneUse())
40280     // We fell out of the loop without finding a viable combining instruction.
40281     return SDValue();
40282 
40283   // Merge this node's mask and our incoming mask.
40284   SmallVector<int, 4> VMask = getPSHUFShuffleMask(V);
40285   for (int &M : Mask)
40286     M = VMask[M];
40287   V = DAG.getNode(V.getOpcode(), DL, V.getValueType(), V.getOperand(0),
40288                   getV4X86ShuffleImm8ForMask(Mask, DL, DAG));
40289 
40290   // Rebuild the chain around this new shuffle.
40291   while (!Chain.empty()) {
40292     SDValue W = Chain.pop_back_val();
40293 
40294     if (V.getValueType() != W.getOperand(0).getValueType())
40295       V = DAG.getBitcast(W.getOperand(0).getValueType(), V);
40296 
40297     switch (W.getOpcode()) {
40298     default:
40299       llvm_unreachable("Only PSHUF and UNPCK instructions get here!");
40300 
40301     case X86ISD::UNPCKL:
40302     case X86ISD::UNPCKH:
40303       V = DAG.getNode(W.getOpcode(), DL, W.getValueType(), V, V);
40304       break;
40305 
40306     case X86ISD::PSHUFD:
40307     case X86ISD::PSHUFLW:
40308     case X86ISD::PSHUFHW:
40309       V = DAG.getNode(W.getOpcode(), DL, W.getValueType(), V, W.getOperand(1));
40310       break;
40311     }
40312   }
40313   if (V.getValueType() != N.getValueType())
40314     V = DAG.getBitcast(N.getValueType(), V);
40315 
40316   // Return the new chain to replace N.
40317   return V;
40318 }
40319 
40320 // Attempt to commute shufps LHS loads:
40321 // permilps(shufps(load(),x)) --> permilps(shufps(x,load()))
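// Commuting the SHUFP operands swaps the two nibbles of its immediate, and the
// outer node's immediate is XOR'd (0xAA / 0x0A / 0xA0) so it still selects the
// same elements from the commuted result.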
40322 static SDValue combineCommutableSHUFP(SDValue N, MVT VT, const SDLoc &DL,
40323                                       SelectionDAG &DAG) {
40324   // TODO: Add vXf64 support.
40325   if (VT != MVT::v4f32 && VT != MVT::v8f32 && VT != MVT::v16f32)
40326     return SDValue();
40327 
40328   // SHUFP(LHS, RHS) -> SHUFP(RHS, LHS) iff LHS is foldable + RHS is not.
40329   auto commuteSHUFP = [&VT, &DL, &DAG](SDValue Parent, SDValue V) {
40330     if (V.getOpcode() != X86ISD::SHUFP || !Parent->isOnlyUserOf(V.getNode()))
40331       return SDValue();
40332     SDValue N0 = V.getOperand(0);
40333     SDValue N1 = V.getOperand(1);
40334     unsigned Imm = V.getConstantOperandVal(2);
40335     const X86Subtarget &Subtarget = DAG.getSubtarget<X86Subtarget>();
40336     if (!X86::mayFoldLoad(peekThroughOneUseBitcasts(N0), Subtarget) ||
40337         X86::mayFoldLoad(peekThroughOneUseBitcasts(N1), Subtarget))
40338       return SDValue();
40339     Imm = ((Imm & 0x0F) << 4) | ((Imm & 0xF0) >> 4);
40340     return DAG.getNode(X86ISD::SHUFP, DL, VT, N1, N0,
40341                        DAG.getTargetConstant(Imm, DL, MVT::i8));
40342   };
40343 
40344   switch (N.getOpcode()) {
40345   case X86ISD::VPERMILPI:
40346     if (SDValue NewSHUFP = commuteSHUFP(N, N.getOperand(0))) {
40347       unsigned Imm = N.getConstantOperandVal(1);
40348       return DAG.getNode(X86ISD::VPERMILPI, DL, VT, NewSHUFP,
40349                          DAG.getTargetConstant(Imm ^ 0xAA, DL, MVT::i8));
40350     }
40351     break;
40352   case X86ISD::SHUFP: {
40353     SDValue N0 = N.getOperand(0);
40354     SDValue N1 = N.getOperand(1);
40355     unsigned Imm = N.getConstantOperandVal(2);
40356     if (N0 == N1) {
40357       if (SDValue NewSHUFP = commuteSHUFP(N, N0))
40358         return DAG.getNode(X86ISD::SHUFP, DL, VT, NewSHUFP, NewSHUFP,
40359                            DAG.getTargetConstant(Imm ^ 0xAA, DL, MVT::i8));
40360     } else if (SDValue NewSHUFP = commuteSHUFP(N, N0)) {
40361       return DAG.getNode(X86ISD::SHUFP, DL, VT, NewSHUFP, N1,
40362                          DAG.getTargetConstant(Imm ^ 0x0A, DL, MVT::i8));
40363     } else if (SDValue NewSHUFP = commuteSHUFP(N, N1)) {
40364       return DAG.getNode(X86ISD::SHUFP, DL, VT, N0, NewSHUFP,
40365                          DAG.getTargetConstant(Imm ^ 0xA0, DL, MVT::i8));
40366     }
40367     break;
40368   }
40369   }
40370 
40371   return SDValue();
40372 }
40373 
40374 // Attempt to fold BLEND(PERMUTE(X),PERMUTE(Y)) -> PERMUTE(BLEND(X,Y))
40375 // iff we don't demand the same element index for both X and Y.
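//
// A small illustrative case (v4i32, both inputs single-operand permutes):
// with X' = permute(X,<1,0,3,2>), Y' = permute(Y,<1,0,3,2>) and blend mask
// <0,5,2,7>, the result only needs X[1],X[3] and Y[0],Y[2]. The demanded
// source indices don't collide, so this can be rebuilt as
// permute(blend(X,Y,<4,1,6,3>), <1,0,3,2>).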
40376 static SDValue
40377 combineBlendOfPermutes(MVT VT, SDValue N0, SDValue N1, ArrayRef<int> BlendMask,
40378                        const APInt &DemandedElts, SelectionDAG &DAG,
40379                        const X86Subtarget &Subtarget, const SDLoc &DL) {
40380   assert(isBlendOrUndef(BlendMask) && "Blend shuffle expected");
40381   if (!N0.hasOneUse() || !N1.hasOneUse())
40382     return SDValue();
40383 
40384   unsigned NumElts = VT.getVectorNumElements();
40385   SDValue BC0 = peekThroughOneUseBitcasts(N0);
40386   SDValue BC1 = peekThroughOneUseBitcasts(N1);
40387 
40388   // See if both operands are shuffles, and whether we can scale the shuffle masks
40389   // to the same width as the blend mask.
40390   // TODO: Support SM_SentinelZero?
40391   SmallVector<SDValue, 2> Ops0, Ops1;
40392   SmallVector<int, 32> Mask0, Mask1, ScaledMask0, ScaledMask1;
40393   if (!getTargetShuffleMask(BC0, /*AllowSentinelZero=*/false, Ops0, Mask0) ||
40394       !getTargetShuffleMask(BC1, /*AllowSentinelZero=*/false, Ops1, Mask1) ||
40395       !scaleShuffleElements(Mask0, NumElts, ScaledMask0) ||
40396       !scaleShuffleElements(Mask1, NumElts, ScaledMask1))
40397     return SDValue();
40398 
40399   // Determine the demanded elts from both permutes.
40400   APInt Demanded0, DemandedLHS0, DemandedRHS0;
40401   APInt Demanded1, DemandedLHS1, DemandedRHS1;
40402   if (!getShuffleDemandedElts(NumElts, BlendMask, DemandedElts, Demanded0,
40403                               Demanded1,
40404                               /*AllowUndefElts=*/true) ||
40405       !getShuffleDemandedElts(NumElts, ScaledMask0, Demanded0, DemandedLHS0,
40406                               DemandedRHS0, /*AllowUndefElts=*/true) ||
40407       !getShuffleDemandedElts(NumElts, ScaledMask1, Demanded1, DemandedLHS1,
40408                               DemandedRHS1, /*AllowUndefElts=*/true))
40409     return SDValue();
40410 
40411   // Confirm that we only use a single operand from both permutes and that we
40412   // don't demand the same index from both.
40413   if (!DemandedRHS0.isZero() || !DemandedRHS1.isZero() ||
40414       DemandedLHS0.intersects(DemandedLHS1))
40415     return SDValue();
40416 
40417   // Use the permute demanded elts masks as the new blend mask.
40418   // Create the new permute mask as a blend of the 2 original permute masks.
40419   SmallVector<int, 32> NewBlendMask(NumElts, SM_SentinelUndef);
40420   SmallVector<int, 32> NewPermuteMask(NumElts, SM_SentinelUndef);
40421   for (unsigned I = 0; I != NumElts; ++I) {
40422     if (Demanded0[I]) {
40423       int M = ScaledMask0[I];
40424       if (0 <= M) {
40425         assert(isUndefOrEqual(NewBlendMask[M], M) &&
40426                "BlendMask demands LHS AND RHS");
40427         NewBlendMask[M] = M;
40428         NewPermuteMask[I] = M;
40429       }
40430     } else if (Demanded1[I]) {
40431       int M = ScaledMask1[I];
40432       if (0 <= M) {
40433         assert(isUndefOrEqual(NewBlendMask[M], M + NumElts) &&
40434                "BlendMask demands LHS AND RHS");
40435         NewBlendMask[M] = M + NumElts;
40436         NewPermuteMask[I] = M;
40437       }
40438     }
40439   }
40440   assert(isBlendOrUndef(NewBlendMask) && "Bad blend");
40441   assert(isUndefOrInRange(NewPermuteMask, 0, NumElts) && "Bad permute");
40442 
40443   // v16i16 shuffles can explode in complexity very easily, so only accept them
40444   // if the blend mask is the same in the 128-bit subvectors (or can widen to
40445   // v8i32) and the permute can be widened as well.
40446   if (VT == MVT::v16i16) {
40447     if (!is128BitLaneRepeatedShuffleMask(VT, NewBlendMask) &&
40448         !canWidenShuffleElements(NewBlendMask))
40449       return SDValue();
40450     if (!canWidenShuffleElements(NewPermuteMask))
40451       return SDValue();
40452   }
40453 
40454   // Don't introduce lane-crossing permutes without AVX2, unless it can be
40455   // widened to a lane permute (vperm2f128).
40456   if (VT.is256BitVector() && !Subtarget.hasAVX2() &&
40457       isLaneCrossingShuffleMask(128, VT.getScalarSizeInBits(),
40458                                 NewPermuteMask) &&
40459       !canScaleShuffleElements(NewPermuteMask, 2))
40460     return SDValue();
40461 
40462   SDValue NewBlend =
40463       DAG.getVectorShuffle(VT, DL, DAG.getBitcast(VT, Ops0[0]),
40464                            DAG.getBitcast(VT, Ops1[0]), NewBlendMask);
40465   return DAG.getVectorShuffle(VT, DL, NewBlend, DAG.getUNDEF(VT),
40466                               NewPermuteMask);
40467 }
40468 
40469 // TODO - move this to TLI like isBinOp?
40470 static bool isUnaryOp(unsigned Opcode) {
40471   switch (Opcode) {
40472   case ISD::CTLZ:
40473   case ISD::CTTZ:
40474   case ISD::CTPOP:
40475     return true;
40476   }
40477   return false;
40478 }
40479 
40480 // Canonicalize SHUFFLE(UNARYOP(X)) -> UNARYOP(SHUFFLE(X)).
40481 // Canonicalize SHUFFLE(BINOP(X,Y)) -> BINOP(SHUFFLE(X),SHUFFLE(Y)).
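// For example, PSHUFD(AND(X, Constant)) can become AND(PSHUFD(X),
// PSHUFD(Constant)); the shuffled constant folds away and the remaining
// shuffle may combine further with whatever produces X.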
40482 static SDValue canonicalizeShuffleWithOp(SDValue N, SelectionDAG &DAG,
40483                                          const SDLoc &DL) {
40484   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
40485   EVT ShuffleVT = N.getValueType();
40486   unsigned Opc = N.getOpcode();
40487 
40488   auto IsMergeableWithShuffle = [Opc, &DAG](SDValue Op, bool FoldShuf = true,
40489                                             bool FoldLoad = false) {
40490     // AllZeros/AllOnes constants are freely shuffled and will peek through
40491     // bitcasts. Other constant build vectors do not peek through bitcasts. Only
40492     // merge with target shuffles if it has one use so shuffle combining is
40493     // likely to kick in. Shuffles of splats are expected to be removed.
40494     return ISD::isBuildVectorAllOnes(Op.getNode()) ||
40495            ISD::isBuildVectorAllZeros(Op.getNode()) ||
40496            ISD::isBuildVectorOfConstantSDNodes(Op.getNode()) ||
40497            ISD::isBuildVectorOfConstantFPSDNodes(Op.getNode()) ||
40498            getTargetConstantFromNode(dyn_cast<LoadSDNode>(Op)) ||
40499            (Op.getOpcode() == Opc && Op->hasOneUse()) ||
40500            (Op.getOpcode() == ISD::INSERT_SUBVECTOR && Op->hasOneUse()) ||
40501            (FoldShuf && isTargetShuffle(Op.getOpcode()) && Op->hasOneUse()) ||
40502            (FoldLoad && isShuffleFoldableLoad(Op)) ||
40503            DAG.isSplatValue(Op, /*AllowUndefs*/ false);
40504   };
40505   auto IsSafeToMoveShuffle = [ShuffleVT](SDValue Op, unsigned BinOp) {
40506     // Ensure we only shuffle whole vector src elements, unless its a logical
40507     // binops where we can more aggressively move shuffles from dst to src.
40508     return isLogicOp(BinOp) ||
40509            (Op.getScalarValueSizeInBits() <= ShuffleVT.getScalarSizeInBits());
40510   };
40511 
40512   switch (Opc) {
40513   // Unary and Unary+Permute Shuffles.
40514   case X86ISD::PSHUFB: {
40515     // Don't merge PSHUFB if it contains zeroed elements.
40516     SmallVector<int> Mask;
40517     SmallVector<SDValue> Ops;
40518     if (!getTargetShuffleMask(N, false, Ops, Mask))
40519       break;
40520     [[fallthrough]];
40521   }
40522   case X86ISD::VBROADCAST:
40523   case X86ISD::MOVDDUP:
40524   case X86ISD::PSHUFD:
40525   case X86ISD::PSHUFHW:
40526   case X86ISD::PSHUFLW:
40527   case X86ISD::VPERMI:
40528   case X86ISD::VPERMILPI: {
40529     if (N.getOperand(0).getValueType() == ShuffleVT &&
40530         N->isOnlyUserOf(N.getOperand(0).getNode())) {
40531       SDValue N0 = peekThroughOneUseBitcasts(N.getOperand(0));
40532       unsigned SrcOpcode = N0.getOpcode();
40533       if (TLI.isBinOp(SrcOpcode) && IsSafeToMoveShuffle(N0, SrcOpcode)) {
40534         SDValue Op00 = peekThroughOneUseBitcasts(N0.getOperand(0));
40535         SDValue Op01 = peekThroughOneUseBitcasts(N0.getOperand(1));
40536         if (IsMergeableWithShuffle(Op00, Opc != X86ISD::VPERMI,
40537                                    Opc != X86ISD::PSHUFB) ||
40538             IsMergeableWithShuffle(Op01, Opc != X86ISD::VPERMI,
40539                                    Opc != X86ISD::PSHUFB)) {
40540           SDValue LHS, RHS;
40541           Op00 = DAG.getBitcast(ShuffleVT, Op00);
40542           Op01 = DAG.getBitcast(ShuffleVT, Op01);
40543           if (N.getNumOperands() == 2) {
40544             LHS = DAG.getNode(Opc, DL, ShuffleVT, Op00, N.getOperand(1));
40545             RHS = DAG.getNode(Opc, DL, ShuffleVT, Op01, N.getOperand(1));
40546           } else {
40547             LHS = DAG.getNode(Opc, DL, ShuffleVT, Op00);
40548             RHS = DAG.getNode(Opc, DL, ShuffleVT, Op01);
40549           }
40550           EVT OpVT = N0.getValueType();
40551           return DAG.getBitcast(ShuffleVT,
40552                                 DAG.getNode(SrcOpcode, DL, OpVT,
40553                                             DAG.getBitcast(OpVT, LHS),
40554                                             DAG.getBitcast(OpVT, RHS)));
40555         }
40556       }
40557     }
40558     break;
40559   }
40560   // Binary and Binary+Permute Shuffles.
40561   case X86ISD::INSERTPS: {
40562     // Don't merge INSERTPS if it contains zeroed elements.
40563     unsigned InsertPSMask = N.getConstantOperandVal(2);
40564     unsigned ZeroMask = InsertPSMask & 0xF;
40565     if (ZeroMask != 0)
40566       break;
40567     [[fallthrough]];
40568   }
40569   case X86ISD::MOVSD:
40570   case X86ISD::MOVSS:
40571   case X86ISD::BLENDI:
40572   case X86ISD::SHUFP:
40573   case X86ISD::UNPCKH:
40574   case X86ISD::UNPCKL: {
40575     if (N->isOnlyUserOf(N.getOperand(0).getNode()) &&
40576         N->isOnlyUserOf(N.getOperand(1).getNode())) {
40577       SDValue N0 = peekThroughOneUseBitcasts(N.getOperand(0));
40578       SDValue N1 = peekThroughOneUseBitcasts(N.getOperand(1));
40579       unsigned SrcOpcode = N0.getOpcode();
40580       if (TLI.isBinOp(SrcOpcode) && N1.getOpcode() == SrcOpcode &&
40581           N0.getValueType() == N1.getValueType() &&
40582           IsSafeToMoveShuffle(N0, SrcOpcode) &&
40583           IsSafeToMoveShuffle(N1, SrcOpcode)) {
40584         SDValue Op00 = peekThroughOneUseBitcasts(N0.getOperand(0));
40585         SDValue Op10 = peekThroughOneUseBitcasts(N1.getOperand(0));
40586         SDValue Op01 = peekThroughOneUseBitcasts(N0.getOperand(1));
40587         SDValue Op11 = peekThroughOneUseBitcasts(N1.getOperand(1));
40588         // Ensure the total number of shuffles doesn't increase by folding this
40589         // shuffle through to the source ops.
40590         if (((IsMergeableWithShuffle(Op00) && IsMergeableWithShuffle(Op10)) ||
40591              (IsMergeableWithShuffle(Op01) && IsMergeableWithShuffle(Op11))) ||
40592             ((IsMergeableWithShuffle(Op00) || IsMergeableWithShuffle(Op10)) &&
40593              (IsMergeableWithShuffle(Op01) || IsMergeableWithShuffle(Op11)))) {
40594           SDValue LHS, RHS;
40595           Op00 = DAG.getBitcast(ShuffleVT, Op00);
40596           Op10 = DAG.getBitcast(ShuffleVT, Op10);
40597           Op01 = DAG.getBitcast(ShuffleVT, Op01);
40598           Op11 = DAG.getBitcast(ShuffleVT, Op11);
40599           if (N.getNumOperands() == 3) {
40600             LHS = DAG.getNode(Opc, DL, ShuffleVT, Op00, Op10, N.getOperand(2));
40601             RHS = DAG.getNode(Opc, DL, ShuffleVT, Op01, Op11, N.getOperand(2));
40602           } else {
40603             LHS = DAG.getNode(Opc, DL, ShuffleVT, Op00, Op10);
40604             RHS = DAG.getNode(Opc, DL, ShuffleVT, Op01, Op11);
40605           }
40606           EVT OpVT = N0.getValueType();
40607           return DAG.getBitcast(ShuffleVT,
40608                                 DAG.getNode(SrcOpcode, DL, OpVT,
40609                                             DAG.getBitcast(OpVT, LHS),
40610                                             DAG.getBitcast(OpVT, RHS)));
40611         }
40612       }
40613       if (isUnaryOp(SrcOpcode) && N1.getOpcode() == SrcOpcode &&
40614           N0.getValueType() == N1.getValueType() &&
40615           IsSafeToMoveShuffle(N0, SrcOpcode) &&
40616           IsSafeToMoveShuffle(N1, SrcOpcode)) {
40617         SDValue Op00 = peekThroughOneUseBitcasts(N0.getOperand(0));
40618         SDValue Op10 = peekThroughOneUseBitcasts(N1.getOperand(0));
40619         SDValue Res;
40620         Op00 = DAG.getBitcast(ShuffleVT, Op00);
40621         Op10 = DAG.getBitcast(ShuffleVT, Op10);
40622         if (N.getNumOperands() == 3) {
40623           Res = DAG.getNode(Opc, DL, ShuffleVT, Op00, Op10, N.getOperand(2));
40624         } else {
40625           Res = DAG.getNode(Opc, DL, ShuffleVT, Op00, Op10);
40626         }
40627         EVT OpVT = N0.getValueType();
40628         return DAG.getBitcast(
40629             ShuffleVT,
40630             DAG.getNode(SrcOpcode, DL, OpVT, DAG.getBitcast(OpVT, Res)));
40631       }
40632     }
40633     break;
40634   }
40635   }
40636   return SDValue();
40637 }
40638 
40639 /// Attempt to fold vpermf128(op(),op()) -> op(vpermf128(),vpermf128()).
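/// e.g. VPERM2X128(MOVDDUP(X), MOVDDUP(Y), imm)
///        --> MOVDDUP(VPERM2X128(X, Y, imm)).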
40640 static SDValue canonicalizeLaneShuffleWithRepeatedOps(SDValue V,
40641                                                       SelectionDAG &DAG,
40642                                                       const SDLoc &DL) {
40643   assert(V.getOpcode() == X86ISD::VPERM2X128 && "Unknown lane shuffle");
40644 
40645   MVT VT = V.getSimpleValueType();
40646   SDValue Src0 = peekThroughBitcasts(V.getOperand(0));
40647   SDValue Src1 = peekThroughBitcasts(V.getOperand(1));
40648   unsigned SrcOpc0 = Src0.getOpcode();
40649   unsigned SrcOpc1 = Src1.getOpcode();
40650   EVT SrcVT0 = Src0.getValueType();
40651   EVT SrcVT1 = Src1.getValueType();
40652 
40653   if (!Src1.isUndef() && (SrcVT0 != SrcVT1 || SrcOpc0 != SrcOpc1))
40654     return SDValue();
40655 
40656   switch (SrcOpc0) {
40657   case X86ISD::MOVDDUP: {
40658     SDValue LHS = Src0.getOperand(0);
40659     SDValue RHS = Src1.isUndef() ? Src1 : Src1.getOperand(0);
40660     SDValue Res =
40661         DAG.getNode(X86ISD::VPERM2X128, DL, SrcVT0, LHS, RHS, V.getOperand(2));
40662     Res = DAG.getNode(SrcOpc0, DL, SrcVT0, Res);
40663     return DAG.getBitcast(VT, Res);
40664   }
40665   case X86ISD::VPERMILPI:
40666     // TODO: Handle v4f64 permutes with different low/high lane masks.
40667     if (SrcVT0 == MVT::v4f64) {
40668       uint64_t Mask = Src0.getConstantOperandVal(1);
40669       if ((Mask & 0x3) != ((Mask >> 2) & 0x3))
40670         break;
40671     }
40672     [[fallthrough]];
40673   case X86ISD::VSHLI:
40674   case X86ISD::VSRLI:
40675   case X86ISD::VSRAI:
40676   case X86ISD::PSHUFD:
40677     if (Src1.isUndef() || Src0.getOperand(1) == Src1.getOperand(1)) {
40678       SDValue LHS = Src0.getOperand(0);
40679       SDValue RHS = Src1.isUndef() ? Src1 : Src1.getOperand(0);
40680       SDValue Res = DAG.getNode(X86ISD::VPERM2X128, DL, SrcVT0, LHS, RHS,
40681                                 V.getOperand(2));
40682       Res = DAG.getNode(SrcOpc0, DL, SrcVT0, Res, Src0.getOperand(1));
40683       return DAG.getBitcast(VT, Res);
40684     }
40685     break;
40686   }
40687 
40688   return SDValue();
40689 }
40690 
40691 /// Try to combine x86 target specific shuffles.
40692 static SDValue combineTargetShuffle(SDValue N, const SDLoc &DL,
40693                                     SelectionDAG &DAG,
40694                                     TargetLowering::DAGCombinerInfo &DCI,
40695                                     const X86Subtarget &Subtarget) {
40696   MVT VT = N.getSimpleValueType();
40697   SmallVector<int, 4> Mask;
40698   unsigned Opcode = N.getOpcode();
40699   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
40700 
40701   if (SDValue R = combineCommutableSHUFP(N, VT, DL, DAG))
40702     return R;
40703 
40704   // Handle specific target shuffles.
40705   switch (Opcode) {
40706   case X86ISD::MOVDDUP: {
40707     SDValue Src = N.getOperand(0);
40708     // Turn a 128-bit MOVDDUP of a full vector load into movddup+vzload.
40709     if (VT == MVT::v2f64 && Src.hasOneUse() &&
40710         ISD::isNormalLoad(Src.getNode())) {
40711       LoadSDNode *LN = cast<LoadSDNode>(Src);
40712       if (SDValue VZLoad = narrowLoadToVZLoad(LN, MVT::f64, MVT::v2f64, DAG)) {
40713         SDValue Movddup = DAG.getNode(X86ISD::MOVDDUP, DL, MVT::v2f64, VZLoad);
40714         DCI.CombineTo(N.getNode(), Movddup);
40715         DAG.ReplaceAllUsesOfValueWith(SDValue(LN, 1), VZLoad.getValue(1));
40716         DCI.recursivelyDeleteUnusedNodes(LN);
40717         return N; // Return N so it doesn't get rechecked!
40718       }
40719     }
40720 
40721     return SDValue();
40722   }
40723   case X86ISD::VBROADCAST: {
40724     SDValue Src = N.getOperand(0);
40725     SDValue BC = peekThroughBitcasts(Src);
40726     EVT SrcVT = Src.getValueType();
40727     EVT BCVT = BC.getValueType();
40728 
40729     // If broadcasting from another shuffle, attempt to simplify it.
40730     // TODO - we really need a general SimplifyDemandedVectorElts mechanism.
40731     if (isTargetShuffle(BC.getOpcode()) &&
40732         VT.getScalarSizeInBits() % BCVT.getScalarSizeInBits() == 0) {
40733       unsigned Scale = VT.getScalarSizeInBits() / BCVT.getScalarSizeInBits();
40734       SmallVector<int, 16> DemandedMask(BCVT.getVectorNumElements(),
40735                                         SM_SentinelUndef);
40736       for (unsigned i = 0; i != Scale; ++i)
40737         DemandedMask[i] = i;
40738       if (SDValue Res = combineX86ShufflesRecursively(
40739               {BC}, 0, BC, DemandedMask, {}, /*Depth*/ 0,
40740               X86::MaxShuffleCombineDepth,
40741               /*HasVarMask*/ false, /*AllowCrossLaneVarMask*/ true,
40742               /*AllowPerLaneVarMask*/ true, DAG, Subtarget))
40743         return DAG.getNode(X86ISD::VBROADCAST, DL, VT,
40744                            DAG.getBitcast(SrcVT, Res));
40745     }
40746 
40747     // broadcast(bitcast(src)) -> bitcast(broadcast(src))
40748     // 32-bit targets have to bitcast i64 to f64, so better to bitcast upward.
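    // For instance, a v2i64 broadcast of (i64 bitcast(f64 X)) can broadcast X
    // as v2f64 and bitcast the vector result instead.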
40749     if (Src.getOpcode() == ISD::BITCAST &&
40750         SrcVT.getScalarSizeInBits() == BCVT.getScalarSizeInBits() &&
40751         TLI.isTypeLegal(BCVT) &&
40752         FixedVectorType::isValidElementType(
40753             BCVT.getScalarType().getTypeForEVT(*DAG.getContext()))) {
40754       EVT NewVT = EVT::getVectorVT(*DAG.getContext(), BCVT.getScalarType(),
40755                                    VT.getVectorNumElements());
40756       return DAG.getBitcast(VT, DAG.getNode(X86ISD::VBROADCAST, DL, NewVT, BC));
40757     }
40758 
40759     // vbroadcast(bitcast(vbroadcast(src))) -> bitcast(vbroadcast(src))
40760     // If we're re-broadcasting a smaller type then broadcast with that type and
40761     // bitcast.
40762     // TODO: Do this for any splat?
40763     if (Src.getOpcode() == ISD::BITCAST &&
40764         (BC.getOpcode() == X86ISD::VBROADCAST ||
40765          BC.getOpcode() == X86ISD::VBROADCAST_LOAD) &&
40766         (VT.getScalarSizeInBits() % BCVT.getScalarSizeInBits()) == 0 &&
40767         (VT.getSizeInBits() % BCVT.getSizeInBits()) == 0) {
40768       MVT NewVT =
40769           MVT::getVectorVT(BCVT.getSimpleVT().getScalarType(),
40770                            VT.getSizeInBits() / BCVT.getScalarSizeInBits());
40771       return DAG.getBitcast(VT, DAG.getNode(X86ISD::VBROADCAST, DL, NewVT, BC));
40772     }
40773 
40774     // Reduce broadcast source vector to lowest 128-bits.
40775     if (SrcVT.getSizeInBits() > 128)
40776       return DAG.getNode(X86ISD::VBROADCAST, DL, VT,
40777                          extract128BitVector(Src, 0, DAG, DL));
40778 
40779     // broadcast(scalar_to_vector(x)) -> broadcast(x).
40780     if (Src.getOpcode() == ISD::SCALAR_TO_VECTOR &&
40781         Src.getValueType().getScalarType() == Src.getOperand(0).getValueType())
40782       return DAG.getNode(X86ISD::VBROADCAST, DL, VT, Src.getOperand(0));
40783 
40784     // broadcast(extract_vector_elt(x, 0)) -> broadcast(x).
40785     if (Src.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
40786         isNullConstant(Src.getOperand(1)) &&
40787         Src.getValueType() ==
40788             Src.getOperand(0).getValueType().getScalarType() &&
40789         TLI.isTypeLegal(Src.getOperand(0).getValueType()))
40790       return DAG.getNode(X86ISD::VBROADCAST, DL, VT, Src.getOperand(0));
40791 
40792     // Share broadcast with the longest vector and extract low subvector (free).
40793     // Ensure the same SDValue from the SDNode use is being used.
40794     for (SDNode *User : Src->uses())
40795       if (User != N.getNode() && User->getOpcode() == X86ISD::VBROADCAST &&
40796           Src == User->getOperand(0) &&
40797           User->getValueSizeInBits(0).getFixedValue() >
40798               VT.getFixedSizeInBits()) {
40799         return extractSubVector(SDValue(User, 0), 0, DAG, DL,
40800                                 VT.getSizeInBits());
40801       }
40802 
40803     // vbroadcast(scalarload X) -> vbroadcast_load X
40804     // For float loads, extract other uses of the scalar from the broadcast.
40805     if (!SrcVT.isVector() && (Src.hasOneUse() || VT.isFloatingPoint()) &&
40806         ISD::isNormalLoad(Src.getNode())) {
40807       LoadSDNode *LN = cast<LoadSDNode>(Src);
40808       SDVTList Tys = DAG.getVTList(VT, MVT::Other);
40809       SDValue Ops[] = { LN->getChain(), LN->getBasePtr() };
40810       SDValue BcastLd =
40811           DAG.getMemIntrinsicNode(X86ISD::VBROADCAST_LOAD, DL, Tys, Ops,
40812                                   LN->getMemoryVT(), LN->getMemOperand());
40813       // If the load value is used only by N, replace it via CombineTo N.
40814       bool NoReplaceExtract = Src.hasOneUse();
40815       DCI.CombineTo(N.getNode(), BcastLd);
40816       if (NoReplaceExtract) {
40817         DAG.ReplaceAllUsesOfValueWith(SDValue(LN, 1), BcastLd.getValue(1));
40818         DCI.recursivelyDeleteUnusedNodes(LN);
40819       } else {
40820         SDValue Scl = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, SrcVT, BcastLd,
40821                                   DAG.getIntPtrConstant(0, DL));
40822         DCI.CombineTo(LN, Scl, BcastLd.getValue(1));
40823       }
40824       return N; // Return N so it doesn't get rechecked!
40825     }
40826 
40827     // Due to isTypeDesirableForOp, we won't always shrink a load truncated to
40828     // i16. So shrink it ourselves if we can make a broadcast_load.
40829     if (SrcVT == MVT::i16 && Src.getOpcode() == ISD::TRUNCATE &&
40830         Src.hasOneUse() && Src.getOperand(0).hasOneUse()) {
40831       assert(Subtarget.hasAVX2() && "Expected AVX2");
40832       SDValue TruncIn = Src.getOperand(0);
40833 
40834       // If this is a truncate of a non-extending load, we can just narrow it
40835       // to use a broadcast_load.
40836       if (ISD::isNormalLoad(TruncIn.getNode())) {
40837         LoadSDNode *LN = cast<LoadSDNode>(TruncIn);
40838         // Unless it's volatile or atomic.
40839         if (LN->isSimple()) {
40840           SDVTList Tys = DAG.getVTList(VT, MVT::Other);
40841           SDValue Ops[] = { LN->getChain(), LN->getBasePtr() };
40842           SDValue BcastLd = DAG.getMemIntrinsicNode(
40843               X86ISD::VBROADCAST_LOAD, DL, Tys, Ops, MVT::i16,
40844               LN->getPointerInfo(), LN->getOriginalAlign(),
40845               LN->getMemOperand()->getFlags());
40846           DCI.CombineTo(N.getNode(), BcastLd);
40847           DAG.ReplaceAllUsesOfValueWith(SDValue(LN, 1), BcastLd.getValue(1));
40848           DCI.recursivelyDeleteUnusedNodes(Src.getNode());
40849           return N; // Return N so it doesn't get rechecked!
40850         }
40851       }
40852 
40853       // If this is a truncate of an i16 extload, we can directly replace it.
40854       if (ISD::isUNINDEXEDLoad(Src.getOperand(0).getNode()) &&
40855           ISD::isEXTLoad(Src.getOperand(0).getNode())) {
40856         LoadSDNode *LN = cast<LoadSDNode>(Src.getOperand(0));
40857         if (LN->getMemoryVT().getSizeInBits() == 16) {
40858           SDVTList Tys = DAG.getVTList(VT, MVT::Other);
40859           SDValue Ops[] = { LN->getChain(), LN->getBasePtr() };
40860           SDValue BcastLd =
40861               DAG.getMemIntrinsicNode(X86ISD::VBROADCAST_LOAD, DL, Tys, Ops,
40862                                       LN->getMemoryVT(), LN->getMemOperand());
40863           DCI.CombineTo(N.getNode(), BcastLd);
40864           DAG.ReplaceAllUsesOfValueWith(SDValue(LN, 1), BcastLd.getValue(1));
40865           DCI.recursivelyDeleteUnusedNodes(Src.getNode());
40866           return N; // Return N so it doesn't get rechecked!
40867         }
40868       }
40869 
40870       // If this is a truncate of a load that has been shifted right, we can
40871       // offset the pointer and use a narrower load.
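      // e.g. (i16 (trunc (srl (i32 load [p]), 16))) can instead broadcast a
      // 16-bit load from [p + 2] (the byte offset is ShiftAmt / 8).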
40872       if (TruncIn.getOpcode() == ISD::SRL &&
40873           TruncIn.getOperand(0).hasOneUse() &&
40874           isa<ConstantSDNode>(TruncIn.getOperand(1)) &&
40875           ISD::isNormalLoad(TruncIn.getOperand(0).getNode())) {
40876         LoadSDNode *LN = cast<LoadSDNode>(TruncIn.getOperand(0));
40877         unsigned ShiftAmt = TruncIn.getConstantOperandVal(1);
40878         // Make sure the shift amount and the load size are divisible by 16.
40879         // Don't do this if the load is volatile or atomic.
40880         if (ShiftAmt % 16 == 0 && TruncIn.getValueSizeInBits() % 16 == 0 &&
40881             LN->isSimple()) {
40882           unsigned Offset = ShiftAmt / 8;
40883           SDVTList Tys = DAG.getVTList(VT, MVT::Other);
40884           SDValue Ptr = DAG.getMemBasePlusOffset(
40885               LN->getBasePtr(), TypeSize::getFixed(Offset), DL);
40886           SDValue Ops[] = { LN->getChain(), Ptr };
40887           SDValue BcastLd = DAG.getMemIntrinsicNode(
40888               X86ISD::VBROADCAST_LOAD, DL, Tys, Ops, MVT::i16,
40889               LN->getPointerInfo().getWithOffset(Offset),
40890               LN->getOriginalAlign(),
40891               LN->getMemOperand()->getFlags());
40892           DCI.CombineTo(N.getNode(), BcastLd);
40893           DAG.ReplaceAllUsesOfValueWith(SDValue(LN, 1), BcastLd.getValue(1));
40894           DCI.recursivelyDeleteUnusedNodes(Src.getNode());
40895           return N; // Return N so it doesn't get rechecked!
40896         }
40897       }
40898     }
40899 
40900     // vbroadcast(vzload X) -> vbroadcast_load X
40901     if (Src.getOpcode() == X86ISD::VZEXT_LOAD && Src.hasOneUse()) {
40902       MemSDNode *LN = cast<MemIntrinsicSDNode>(Src);
40903       if (LN->getMemoryVT().getSizeInBits() == VT.getScalarSizeInBits()) {
40904         SDVTList Tys = DAG.getVTList(VT, MVT::Other);
40905         SDValue Ops[] = { LN->getChain(), LN->getBasePtr() };
40906         SDValue BcastLd =
40907             DAG.getMemIntrinsicNode(X86ISD::VBROADCAST_LOAD, DL, Tys, Ops,
40908                                     LN->getMemoryVT(), LN->getMemOperand());
40909         DCI.CombineTo(N.getNode(), BcastLd);
40910         DAG.ReplaceAllUsesOfValueWith(SDValue(LN, 1), BcastLd.getValue(1));
40911         DCI.recursivelyDeleteUnusedNodes(LN);
40912         return N; // Return N so it doesn't get rechecked!
40913       }
40914     }
40915 
40916     // vbroadcast(vector load X) -> vbroadcast_load
40917     if ((SrcVT == MVT::v2f64 || SrcVT == MVT::v4f32 || SrcVT == MVT::v2i64 ||
40918          SrcVT == MVT::v4i32) &&
40919         Src.hasOneUse() && ISD::isNormalLoad(Src.getNode())) {
40920       LoadSDNode *LN = cast<LoadSDNode>(Src);
40921       // Unless the load is volatile or atomic.
40922       if (LN->isSimple()) {
40923         SDVTList Tys = DAG.getVTList(VT, MVT::Other);
40924         SDValue Ops[] = {LN->getChain(), LN->getBasePtr()};
40925         SDValue BcastLd = DAG.getMemIntrinsicNode(
40926             X86ISD::VBROADCAST_LOAD, DL, Tys, Ops, SrcVT.getScalarType(),
40927             LN->getPointerInfo(), LN->getOriginalAlign(),
40928             LN->getMemOperand()->getFlags());
40929         DCI.CombineTo(N.getNode(), BcastLd);
40930         DAG.ReplaceAllUsesOfValueWith(SDValue(LN, 1), BcastLd.getValue(1));
40931         DCI.recursivelyDeleteUnusedNodes(LN);
40932         return N; // Return N so it doesn't get rechecked!
40933       }
40934     }
40935 
40936     return SDValue();
40937   }
40938   case X86ISD::VZEXT_MOVL: {
40939     SDValue N0 = N.getOperand(0);
40940 
40941     // If this is a vzmovl of a full vector load, replace it with a vzload,
40942     // unless the load is volatile.
40943     if (N0.hasOneUse() && ISD::isNormalLoad(N0.getNode())) {
40944       auto *LN = cast<LoadSDNode>(N0);
40945       if (SDValue VZLoad =
40946               narrowLoadToVZLoad(LN, VT.getVectorElementType(), VT, DAG)) {
40947         DCI.CombineTo(N.getNode(), VZLoad);
40948         DAG.ReplaceAllUsesOfValueWith(SDValue(LN, 1), VZLoad.getValue(1));
40949         DCI.recursivelyDeleteUnusedNodes(LN);
40950         return N;
40951       }
40952     }
40953 
40954     // If this is a VZEXT_MOVL of a VBROADCAST_LOAD, we don't need the
40955     // broadcast and can just use a VZEXT_LOAD.
40956     // FIXME: Is there some way to do this with SimplifyDemandedVectorElts?
40957     if (N0.hasOneUse() && N0.getOpcode() == X86ISD::VBROADCAST_LOAD) {
40958       auto *LN = cast<MemSDNode>(N0);
40959       if (VT.getScalarSizeInBits() == LN->getMemoryVT().getSizeInBits()) {
40960         SDVTList Tys = DAG.getVTList(VT, MVT::Other);
40961         SDValue Ops[] = {LN->getChain(), LN->getBasePtr()};
40962         SDValue VZLoad =
40963             DAG.getMemIntrinsicNode(X86ISD::VZEXT_LOAD, DL, Tys, Ops,
40964                                     LN->getMemoryVT(), LN->getMemOperand());
40965         DCI.CombineTo(N.getNode(), VZLoad);
40966         DAG.ReplaceAllUsesOfValueWith(SDValue(LN, 1), VZLoad.getValue(1));
40967         DCI.recursivelyDeleteUnusedNodes(LN);
40968         return N;
40969       }
40970     }
40971 
40972     // Turn (v2i64 (vzext_movl (scalar_to_vector (i64 X)))) into
40973     // (v2i64 (bitcast (v4i32 (vzext_movl (scalar_to_vector (i32 (trunc X)))))))
40974     // if the upper bits of the i64 are zero.
40975     if (N0.hasOneUse() && N0.getOpcode() == ISD::SCALAR_TO_VECTOR &&
40976         N0.getOperand(0).hasOneUse() &&
40977         N0.getOperand(0).getValueType() == MVT::i64) {
40978       SDValue In = N0.getOperand(0);
40979       APInt Mask = APInt::getHighBitsSet(64, 32);
40980       if (DAG.MaskedValueIsZero(In, Mask)) {
40981         SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, In);
40982         MVT VecVT = MVT::getVectorVT(MVT::i32, VT.getVectorNumElements() * 2);
40983         SDValue SclVec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VecVT, Trunc);
40984         SDValue Movl = DAG.getNode(X86ISD::VZEXT_MOVL, DL, VecVT, SclVec);
40985         return DAG.getBitcast(VT, Movl);
40986       }
40987     }
40988 
40989     // Load a scalar integer constant directly to XMM instead of transferring an
40990     // immediate value from GPR.
40991     // vzext_movl (scalar_to_vector C) --> load [C,0...]
40992     if (N0.getOpcode() == ISD::SCALAR_TO_VECTOR) {
40993       if (auto *C = dyn_cast<ConstantSDNode>(N0.getOperand(0))) {
40994         // Create a vector constant - scalar constant followed by zeros.
40995         EVT ScalarVT = N0.getOperand(0).getValueType();
40996         Type *ScalarTy = ScalarVT.getTypeForEVT(*DAG.getContext());
40997         unsigned NumElts = VT.getVectorNumElements();
40998         Constant *Zero = ConstantInt::getNullValue(ScalarTy);
40999         SmallVector<Constant *, 32> ConstantVec(NumElts, Zero);
41000         ConstantVec[0] = const_cast<ConstantInt *>(C->getConstantIntValue());
41001 
41002         // Load the vector constant from constant pool.
41003         MVT PVT = TLI.getPointerTy(DAG.getDataLayout());
41004         SDValue CP = DAG.getConstantPool(ConstantVector::get(ConstantVec), PVT);
41005         MachinePointerInfo MPI =
41006             MachinePointerInfo::getConstantPool(DAG.getMachineFunction());
41007         Align Alignment = cast<ConstantPoolSDNode>(CP)->getAlign();
41008         return DAG.getLoad(VT, DL, DAG.getEntryNode(), CP, MPI, Alignment,
41009                            MachineMemOperand::MOLoad);
41010       }
41011     }
41012 
41013     // Pull subvector inserts into undef through VZEXT_MOVL by making it an
41014     // insert into a zero vector. This helps get VZEXT_MOVL closer to
41015     // scalar_to_vectors where 256/512 are canonicalized to an insert and a
41016     // 128-bit scalar_to_vector. This reduces the number of isel patterns.
41017     if (!DCI.isBeforeLegalizeOps() && N0.hasOneUse()) {
41018       SDValue V = peekThroughOneUseBitcasts(N0);
41019 
41020       if (V.getOpcode() == ISD::INSERT_SUBVECTOR && V.getOperand(0).isUndef() &&
41021           isNullConstant(V.getOperand(2))) {
41022         SDValue In = V.getOperand(1);
41023         MVT SubVT = MVT::getVectorVT(VT.getVectorElementType(),
41024                                      In.getValueSizeInBits() /
41025                                          VT.getScalarSizeInBits());
41026         In = DAG.getBitcast(SubVT, In);
41027         SDValue Movl = DAG.getNode(X86ISD::VZEXT_MOVL, DL, SubVT, In);
41028         return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT,
41029                            getZeroVector(VT, Subtarget, DAG, DL), Movl,
41030                            V.getOperand(2));
41031       }
41032     }
41033 
41034     return SDValue();
41035   }
41036   case X86ISD::BLENDI: {
41037     SDValue N0 = N.getOperand(0);
41038     SDValue N1 = N.getOperand(1);
41039     unsigned EltBits = VT.getScalarSizeInBits();
41040 
41041     if (N0.getOpcode() == ISD::BITCAST && N1.getOpcode() == ISD::BITCAST) {
41042       // blend(bitcast(x),bitcast(y)) -> bitcast(blend(x,y)) to narrower types.
41043       // TODO: Handle MVT::v16i16 repeated blend mask.
41044       if (N0.getOperand(0).getValueType() == N1.getOperand(0).getValueType()) {
41045         MVT SrcVT = N0.getOperand(0).getSimpleValueType();
41046         unsigned SrcBits = SrcVT.getScalarSizeInBits();
41047         if ((EltBits % SrcBits) == 0 && SrcBits >= 32) {
41048           unsigned Size = VT.getVectorNumElements();
41049           unsigned NewSize = SrcVT.getVectorNumElements();
41050           APInt BlendMask = N.getConstantOperandAPInt(2).zextOrTrunc(Size);
41051           APInt NewBlendMask = APIntOps::ScaleBitMask(BlendMask, NewSize);
41052           return DAG.getBitcast(
41053               VT, DAG.getNode(X86ISD::BLENDI, DL, SrcVT, N0.getOperand(0),
41054                               N1.getOperand(0),
41055                               DAG.getTargetConstant(NewBlendMask.getZExtValue(),
41056                                                     DL, MVT::i8)));
41057         }
41058       }
41059       // Share PSHUFB masks:
41060       // blend(pshufb(x,m1),pshufb(y,m2))
41061       // --> m3 = blend(m1,m2)
41062       //     blend(pshufb(x,m3),pshufb(y,m3))
41063       if (N0.hasOneUse() && N1.hasOneUse()) {
41064         SmallVector<int> Mask, ByteMask;
41065         SmallVector<SDValue> Ops;
41066         SDValue LHS = peekThroughOneUseBitcasts(N0);
41067         SDValue RHS = peekThroughOneUseBitcasts(N1);
41068         if (LHS.getOpcode() == X86ISD::PSHUFB &&
41069             RHS.getOpcode() == X86ISD::PSHUFB &&
41070             LHS.getOperand(1) != RHS.getOperand(1) &&
41071             LHS.getOperand(1).hasOneUse() && RHS.getOperand(1).hasOneUse() &&
41072             getTargetShuffleMask(N, /*AllowSentinelZero=*/false, Ops, Mask)) {
41073           assert(Ops.size() == 2 && LHS == peekThroughOneUseBitcasts(Ops[0]) &&
41074                  RHS == peekThroughOneUseBitcasts(Ops[1]) &&
41075                  "BLENDI decode mismatch");
41076           MVT ShufVT = LHS.getSimpleValueType();
41077           SDValue MaskLHS = LHS.getOperand(1);
41078           SDValue MaskRHS = RHS.getOperand(1);
41079           llvm::narrowShuffleMaskElts(EltBits / 8, Mask, ByteMask);
41080           if (SDValue NewMask = combineX86ShufflesConstants(
41081                   ShufVT, {MaskLHS, MaskRHS}, ByteMask,
41082                   /*HasVariableMask=*/true, DAG, DL, Subtarget)) {
41083             SDValue NewLHS = DAG.getNode(X86ISD::PSHUFB, DL, ShufVT,
41084                                          LHS.getOperand(0), NewMask);
41085             SDValue NewRHS = DAG.getNode(X86ISD::PSHUFB, DL, ShufVT,
41086                                          RHS.getOperand(0), NewMask);
41087             return DAG.getNode(X86ISD::BLENDI, DL, VT,
41088                                DAG.getBitcast(VT, NewLHS),
41089                                DAG.getBitcast(VT, NewRHS), N.getOperand(2));
41090           }
41091         }
41092       }
41093     }
41094     return SDValue();
41095   }
41096   case X86ISD::SHUFP: {
41097     // Fold shufps(shuffle(x),shuffle(y)) -> shufps(x,y).
41098     // This is a more relaxed shuffle combiner that can ignore oneuse limits.
41099     // TODO: Support types other than v4f32.
41100     if (VT == MVT::v4f32) {
41101       bool Updated = false;
41102       SmallVector<int> Mask;
41103       SmallVector<SDValue> Ops;
41104       if (getTargetShuffleMask(N, false, Ops, Mask) && Ops.size() == 2) {
41105         for (int i = 0; i != 2; ++i) {
41106           SmallVector<SDValue> SubOps;
41107           SmallVector<int> SubMask, SubScaledMask;
41108           SDValue Sub = peekThroughBitcasts(Ops[i]);
41109           // TODO: Scaling might be easier if we specify the demanded elts.
41110           if (getTargetShuffleInputs(Sub, SubOps, SubMask, DAG, 0, false) &&
41111               scaleShuffleElements(SubMask, 4, SubScaledMask) &&
41112               SubOps.size() == 1 && isUndefOrInRange(SubScaledMask, 0, 4)) {
41113             int Ofs = i * 2;
41114             Mask[Ofs + 0] = SubScaledMask[Mask[Ofs + 0] % 4] + (i * 4);
41115             Mask[Ofs + 1] = SubScaledMask[Mask[Ofs + 1] % 4] + (i * 4);
41116             Ops[i] = DAG.getBitcast(VT, SubOps[0]);
41117             Updated = true;
41118           }
41119         }
41120       }
41121       if (Updated) {
41122         for (int &M : Mask)
41123           M %= 4;
41124         Ops.push_back(getV4X86ShuffleImm8ForMask(Mask, DL, DAG));
41125         return DAG.getNode(X86ISD::SHUFP, DL, VT, Ops);
41126       }
41127     }
41128     return SDValue();
41129   }
41130   case X86ISD::VPERMI: {
41131     // vpermi(bitcast(x)) -> bitcast(vpermi(x)) for same number of elements.
41132     // TODO: Remove when we have preferred domains in combineX86ShuffleChain.
41133     SDValue N0 = N.getOperand(0);
41134     SDValue N1 = N.getOperand(1);
41135     unsigned EltSizeInBits = VT.getScalarSizeInBits();
41136     if (N0.getOpcode() == ISD::BITCAST &&
41137         N0.getOperand(0).getScalarValueSizeInBits() == EltSizeInBits) {
41138       SDValue Src = N0.getOperand(0);
41139       EVT SrcVT = Src.getValueType();
41140       SDValue Res = DAG.getNode(X86ISD::VPERMI, DL, SrcVT, Src, N1);
41141       return DAG.getBitcast(VT, Res);
41142     }
41143     return SDValue();
41144   }
41145   case X86ISD::SHUF128: {
41146     // If we're permuting the upper 256-bit subvectors of a concatenation, then
41147     // see if we can peek through and access the subvector directly.
41148     if (VT.is512BitVector()) {
41149       // 512-bit mask uses 4 x i2 indices - if the msb is always set then only the
41150       // upper subvector is used.
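      // e.g. if LHS == concat(A,B) and both LHS index fields have their msb
      // set ((Mask & 0x0A) == 0x0A), only B is referenced, so B can be widened
      // to 512 bits and those mask bits cleared.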
41151       SDValue LHS = N->getOperand(0);
41152       SDValue RHS = N->getOperand(1);
41153       uint64_t Mask = N->getConstantOperandVal(2);
41154       SmallVector<SDValue> LHSOps, RHSOps;
41155       SDValue NewLHS, NewRHS;
41156       if ((Mask & 0x0A) == 0x0A &&
41157           collectConcatOps(LHS.getNode(), LHSOps, DAG) && LHSOps.size() == 2) {
41158         NewLHS = widenSubVector(LHSOps[1], false, Subtarget, DAG, DL, 512);
41159         Mask &= ~0x0A;
41160       }
41161       if ((Mask & 0xA0) == 0xA0 &&
41162           collectConcatOps(RHS.getNode(), RHSOps, DAG) && RHSOps.size() == 2) {
41163         NewRHS = widenSubVector(RHSOps[1], false, Subtarget, DAG, DL, 512);
41164         Mask &= ~0xA0;
41165       }
41166       if (NewLHS || NewRHS)
41167         return DAG.getNode(X86ISD::SHUF128, DL, VT, NewLHS ? NewLHS : LHS,
41168                            NewRHS ? NewRHS : RHS,
41169                            DAG.getTargetConstant(Mask, DL, MVT::i8));
41170     }
41171     return SDValue();
41172   }
41173   case X86ISD::VPERM2X128: {
41174     // Fold vperm2x128(bitcast(x),bitcast(y),c) -> bitcast(vperm2x128(x,y,c)).
41175     SDValue LHS = N->getOperand(0);
41176     SDValue RHS = N->getOperand(1);
41177     if (LHS.getOpcode() == ISD::BITCAST &&
41178         (RHS.getOpcode() == ISD::BITCAST || RHS.isUndef())) {
41179       EVT SrcVT = LHS.getOperand(0).getValueType();
41180       if (RHS.isUndef() || SrcVT == RHS.getOperand(0).getValueType()) {
41181         return DAG.getBitcast(VT, DAG.getNode(X86ISD::VPERM2X128, DL, SrcVT,
41182                                               DAG.getBitcast(SrcVT, LHS),
41183                                               DAG.getBitcast(SrcVT, RHS),
41184                                               N->getOperand(2)));
41185       }
41186     }
41187 
41188     // Fold vperm2x128(op(),op()) -> op(vperm2x128(),vperm2x128()).
41189     if (SDValue Res = canonicalizeLaneShuffleWithRepeatedOps(N, DAG, DL))
41190       return Res;
41191 
41192     // Fold vperm2x128 subvector shuffle with an inner concat pattern.
41193     // vperm2x128(concat(X,Y),concat(Z,W)) --> concat X,Y etc.
41194     auto FindSubVector128 = [&](unsigned Idx) {
41195       if (Idx > 3)
41196         return SDValue();
41197       SDValue Src = peekThroughBitcasts(N.getOperand(Idx < 2 ? 0 : 1));
41198       SmallVector<SDValue> SubOps;
41199       if (collectConcatOps(Src.getNode(), SubOps, DAG) && SubOps.size() == 2)
41200         return SubOps[Idx & 1];
41201       unsigned NumElts = Src.getValueType().getVectorNumElements();
41202       if ((Idx & 1) == 1 && Src.getOpcode() == ISD::INSERT_SUBVECTOR &&
41203           Src.getOperand(1).getValueSizeInBits() == 128 &&
41204           Src.getConstantOperandAPInt(2) == (NumElts / 2)) {
41205         return Src.getOperand(1);
41206       }
41207       return SDValue();
41208     };
41209     unsigned Imm = N.getConstantOperandVal(2);
41210     if (SDValue SubLo = FindSubVector128(Imm & 0x0F)) {
41211       if (SDValue SubHi = FindSubVector128((Imm & 0xF0) >> 4)) {
41212         MVT SubVT = VT.getHalfNumVectorElementsVT();
41213         SubLo = DAG.getBitcast(SubVT, SubLo);
41214         SubHi = DAG.getBitcast(SubVT, SubHi);
41215         return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, SubLo, SubHi);
41216       }
41217     }
41218     return SDValue();
41219   }
41220   case X86ISD::PSHUFD:
41221   case X86ISD::PSHUFLW:
41222   case X86ISD::PSHUFHW: {
41223     SDValue N0 = N.getOperand(0);
41224     SDValue N1 = N.getOperand(1);
41225     if (N0->hasOneUse()) {
41226       SDValue V = peekThroughOneUseBitcasts(N0);
41227       switch (V.getOpcode()) {
41228       case X86ISD::VSHL:
41229       case X86ISD::VSRL:
41230       case X86ISD::VSRA:
41231       case X86ISD::VSHLI:
41232       case X86ISD::VSRLI:
41233       case X86ISD::VSRAI:
41234       case X86ISD::VROTLI:
41235       case X86ISD::VROTRI: {
41236         MVT InnerVT = V.getSimpleValueType();
41237         if (InnerVT.getScalarSizeInBits() <= VT.getScalarSizeInBits()) {
41238           SDValue Res = DAG.getNode(Opcode, DL, VT,
41239                                     DAG.getBitcast(VT, V.getOperand(0)), N1);
41240           Res = DAG.getBitcast(InnerVT, Res);
41241           Res = DAG.getNode(V.getOpcode(), DL, InnerVT, Res, V.getOperand(1));
41242           return DAG.getBitcast(VT, Res);
41243         }
41244         break;
41245       }
41246       }
41247     }
41248 
41249     Mask = getPSHUFShuffleMask(N);
41250     assert(Mask.size() == 4);
41251     break;
41252   }
41253   case X86ISD::MOVSD:
41254   case X86ISD::MOVSH:
41255   case X86ISD::MOVSS: {
41256     SDValue N0 = N.getOperand(0);
41257     SDValue N1 = N.getOperand(1);
41258 
41259     // Canonicalize scalar FPOps:
41260     // MOVS*(N0, OP(N0, N1)) --> MOVS*(N0, SCALAR_TO_VECTOR(OP(N0[0], N1[0])))
41261     // If commutable, allow OP(N1[0], N0[0]).
41262     unsigned Opcode1 = N1.getOpcode();
41263     if (Opcode1 == ISD::FADD || Opcode1 == ISD::FMUL || Opcode1 == ISD::FSUB ||
41264         Opcode1 == ISD::FDIV) {
41265       SDValue N10 = N1.getOperand(0);
41266       SDValue N11 = N1.getOperand(1);
41267       if (N10 == N0 ||
41268           (N11 == N0 && (Opcode1 == ISD::FADD || Opcode1 == ISD::FMUL))) {
41269         if (N10 != N0)
41270           std::swap(N10, N11);
41271         MVT SVT = VT.getVectorElementType();
41272         SDValue ZeroIdx = DAG.getIntPtrConstant(0, DL);
41273         N10 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, SVT, N10, ZeroIdx);
41274         N11 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, SVT, N11, ZeroIdx);
41275         SDValue Scl = DAG.getNode(Opcode1, DL, SVT, N10, N11);
41276         SDValue SclVec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VT, Scl);
41277         return DAG.getNode(Opcode, DL, VT, N0, SclVec);
41278       }
41279     }
41280 
41281     return SDValue();
41282   }
41283   case X86ISD::INSERTPS: {
41284     assert(VT == MVT::v4f32 && "INSERTPS ValueType must be MVT::v4f32");
41285     SDValue Op0 = N.getOperand(0);
41286     SDValue Op1 = N.getOperand(1);
41287     unsigned InsertPSMask = N.getConstantOperandVal(2);
41288     unsigned SrcIdx = (InsertPSMask >> 6) & 0x3;
41289     unsigned DstIdx = (InsertPSMask >> 4) & 0x3;
41290     unsigned ZeroMask = InsertPSMask & 0xF;
41291 
41292     // If we zero out all elements from Op0 then we don't need to reference it.
41293     if (((ZeroMask | (1u << DstIdx)) == 0xF) && !Op0.isUndef())
41294       return DAG.getNode(X86ISD::INSERTPS, DL, VT, DAG.getUNDEF(VT), Op1,
41295                          DAG.getTargetConstant(InsertPSMask, DL, MVT::i8));
41296 
41297     // If we zero out the element from Op1 then we don't need to reference it.
41298     if ((ZeroMask & (1u << DstIdx)) && !Op1.isUndef())
41299       return DAG.getNode(X86ISD::INSERTPS, DL, VT, Op0, DAG.getUNDEF(VT),
41300                          DAG.getTargetConstant(InsertPSMask, DL, MVT::i8));
41301 
41302     // Attempt to merge insertps Op1 with an inner target shuffle node.
41303     SmallVector<int, 8> TargetMask1;
41304     SmallVector<SDValue, 2> Ops1;
41305     APInt KnownUndef1, KnownZero1;
41306     if (getTargetShuffleAndZeroables(Op1, TargetMask1, Ops1, KnownUndef1,
41307                                      KnownZero1)) {
41308       if (KnownUndef1[SrcIdx] || KnownZero1[SrcIdx]) {
41309         // Zero/UNDEF insertion - zero out element and remove dependency.
41310         InsertPSMask |= (1u << DstIdx);
41311         return DAG.getNode(X86ISD::INSERTPS, DL, VT, Op0, DAG.getUNDEF(VT),
41312                            DAG.getTargetConstant(InsertPSMask, DL, MVT::i8));
41313       }
41314       // Update insertps mask srcidx and reference the source input directly.
41315       int M = TargetMask1[SrcIdx];
41316       assert(0 <= M && M < 8 && "Shuffle index out of range");
41317       InsertPSMask = (InsertPSMask & 0x3f) | ((M & 0x3) << 6);
41318       Op1 = Ops1[M < 4 ? 0 : 1];
41319       return DAG.getNode(X86ISD::INSERTPS, DL, VT, Op0, Op1,
41320                          DAG.getTargetConstant(InsertPSMask, DL, MVT::i8));
41321     }
41322 
41323     // Attempt to merge insertps Op0 with an inner target shuffle node.
41324     SmallVector<int, 8> TargetMask0;
41325     SmallVector<SDValue, 2> Ops0;
41326     APInt KnownUndef0, KnownZero0;
41327     if (getTargetShuffleAndZeroables(Op0, TargetMask0, Ops0, KnownUndef0,
41328                                      KnownZero0)) {
41329       bool Updated = false;
41330       bool UseInput00 = false;
41331       bool UseInput01 = false;
41332       for (int i = 0; i != 4; ++i) {
41333         if ((InsertPSMask & (1u << i)) || (i == (int)DstIdx)) {
41334           // No change if element is already zero or the inserted element.
41335           continue;
41336         }
41337 
41338         if (KnownUndef0[i] || KnownZero0[i]) {
41339           // If the target mask is undef/zero then we must zero the element.
41340           InsertPSMask |= (1u << i);
41341           Updated = true;
41342           continue;
41343         }
41344 
41345         // The input vector element must be inline.
41346         int M = TargetMask0[i];
41347         if (M != i && M != (i + 4))
41348           return SDValue();
41349 
41350         // Determine which inputs of the target shuffle we're using.
41351         UseInput00 |= (0 <= M && M < 4);
41352         UseInput01 |= (4 <= M);
41353       }
41354 
41355       // If we're not using both inputs of the target shuffle then use the
41356       // referenced input directly.
41357       if (UseInput00 && !UseInput01) {
41358         Updated = true;
41359         Op0 = Ops0[0];
41360       } else if (!UseInput00 && UseInput01) {
41361         Updated = true;
41362         Op0 = Ops0[1];
41363       }
41364 
41365       if (Updated)
41366         return DAG.getNode(X86ISD::INSERTPS, DL, VT, Op0, Op1,
41367                            DAG.getTargetConstant(InsertPSMask, DL, MVT::i8));
41368     }
41369 
41370     // If we're inserting an element from a vbroadcast load, fold the
41371     // load into the X86insertps instruction. We need to convert the scalar
41372     // load to a vector and clear the source lane of the INSERTPS control.
41373     if (Op1.getOpcode() == X86ISD::VBROADCAST_LOAD && Op1.hasOneUse()) {
41374       auto *MemIntr = cast<MemIntrinsicSDNode>(Op1);
41375       if (MemIntr->getMemoryVT().getScalarSizeInBits() == 32) {
41376         SDValue Load = DAG.getLoad(MVT::f32, DL, MemIntr->getChain(),
41377                                    MemIntr->getBasePtr(),
41378                                    MemIntr->getMemOperand());
41379         SDValue Insert = DAG.getNode(X86ISD::INSERTPS, DL, VT, Op0,
41380                            DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VT,
41381                                        Load),
41382                            DAG.getTargetConstant(InsertPSMask & 0x3f, DL, MVT::i8));
41383         DAG.ReplaceAllUsesOfValueWith(SDValue(MemIntr, 1), Load.getValue(1));
41384         return Insert;
41385       }
41386     }
41387 
41388     return SDValue();
41389   }
41390   case X86ISD::VPERMV3: {
41391     // Combine VPERMV3 to widened VPERMV if the two source operands are split
41392     // from the same vector.
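    // e.g. VPERMV3(extract_subvector(X,0), M, extract_subvector(X,NumElts))
    // can become extract_subvector(VPERMV(WideM, X), 0), where WideM is M
    // widened to the width of X.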
41393     SDValue V1 = peekThroughBitcasts(N.getOperand(0));
41394     SDValue V2 = peekThroughBitcasts(N.getOperand(2));
41395     MVT SVT = V1.getSimpleValueType();
41396     if (V1.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
41397         V1.getConstantOperandVal(1) == 0 &&
41398         V2.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
41399         V2.getConstantOperandVal(1) == SVT.getVectorNumElements() &&
41400         V1.getOperand(0) == V2.getOperand(0)) {
41401       EVT NVT = V1.getOperand(0).getValueType();
41402       if (NVT.is256BitVector() ||
41403           (NVT.is512BitVector() && Subtarget.hasEVEX512())) {
41404         MVT WideVT = MVT::getVectorVT(
41405             VT.getScalarType(), NVT.getSizeInBits() / VT.getScalarSizeInBits());
41406         SDValue Mask = widenSubVector(N.getOperand(1), false, Subtarget, DAG,
41407                                       DL, WideVT.getSizeInBits());
41408         SDValue Perm = DAG.getNode(X86ISD::VPERMV, DL, WideVT, Mask,
41409                                    DAG.getBitcast(WideVT, V1.getOperand(0)));
41410         return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Perm,
41411                            DAG.getIntPtrConstant(0, DL));
41412       }
41413     }
41414     return SDValue();
41415   }
41416   default:
41417     return SDValue();
41418   }
41419 
41420   // Nuke no-op shuffles that show up after combining.
41421   if (isNoopShuffleMask(Mask))
41422     return N.getOperand(0);
41423 
41424   // Look for simplifications involving one or two shuffle instructions.
41425   SDValue V = N.getOperand(0);
41426   switch (N.getOpcode()) {
41427   default:
41428     break;
41429   case X86ISD::PSHUFLW:
41430   case X86ISD::PSHUFHW:
41431     assert(VT.getVectorElementType() == MVT::i16 && "Bad word shuffle type!");
41432 
41433     // See if this reduces to a PSHUFD which is no more expensive and can
41434     // combine with more operations. Note that it has to at least flip the
41435     // dwords as otherwise it would have been removed as a no-op.
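    // e.g. a PSHUFLW with word mask <2,3,0,1> swaps the two dwords of the low
    // half, which is exactly a PSHUFD with dword mask <1,0,2,3>.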
41436     if (ArrayRef<int>(Mask).equals({2, 3, 0, 1})) {
41437       int DMask[] = {0, 1, 2, 3};
41438       int DOffset = N.getOpcode() == X86ISD::PSHUFLW ? 0 : 2;
41439       DMask[DOffset + 0] = DOffset + 1;
41440       DMask[DOffset + 1] = DOffset + 0;
41441       MVT DVT = MVT::getVectorVT(MVT::i32, VT.getVectorNumElements() / 2);
41442       V = DAG.getBitcast(DVT, V);
41443       V = DAG.getNode(X86ISD::PSHUFD, DL, DVT, V,
41444                       getV4X86ShuffleImm8ForMask(DMask, DL, DAG));
41445       return DAG.getBitcast(VT, V);
41446     }
41447 
41448     // Look for shuffle patterns which can be implemented as a single unpack.
41449     // FIXME: This doesn't handle the location of the PSHUFD generically, and
41450     // only works when we have a PSHUFD followed by two half-shuffles.
41451     if (Mask[0] == Mask[1] && Mask[2] == Mask[3] &&
41452         (V.getOpcode() == X86ISD::PSHUFLW ||
41453          V.getOpcode() == X86ISD::PSHUFHW) &&
41454         V.getOpcode() != N.getOpcode() &&
41455         V.hasOneUse() && V.getOperand(0).hasOneUse()) {
41456       SDValue D = peekThroughOneUseBitcasts(V.getOperand(0));
41457       if (D.getOpcode() == X86ISD::PSHUFD) {
41458         SmallVector<int, 4> VMask = getPSHUFShuffleMask(V);
41459         SmallVector<int, 4> DMask = getPSHUFShuffleMask(D);
41460         int NOffset = N.getOpcode() == X86ISD::PSHUFLW ? 0 : 4;
41461         int VOffset = V.getOpcode() == X86ISD::PSHUFLW ? 0 : 4;
41462         int WordMask[8];
41463         for (int i = 0; i < 4; ++i) {
41464           WordMask[i + NOffset] = Mask[i] + NOffset;
41465           WordMask[i + VOffset] = VMask[i] + VOffset;
41466         }
41467         // Map the word mask through the DWord mask.
41468         int MappedMask[8];
41469         for (int i = 0; i < 8; ++i)
41470           MappedMask[i] = 2 * DMask[WordMask[i] / 2] + WordMask[i] % 2;
41471         if (ArrayRef<int>(MappedMask).equals({0, 0, 1, 1, 2, 2, 3, 3}) ||
41472             ArrayRef<int>(MappedMask).equals({4, 4, 5, 5, 6, 6, 7, 7})) {
41473           // We can replace all three shuffles with an unpack.
41474           V = DAG.getBitcast(VT, D.getOperand(0));
41475           return DAG.getNode(MappedMask[0] == 0 ? X86ISD::UNPCKL
41476                                                 : X86ISD::UNPCKH,
41477                              DL, VT, V, V);
41478         }
41479       }
41480     }
41481 
41482     break;
41483 
41484   case X86ISD::PSHUFD:
41485     if (SDValue NewN = combineRedundantDWordShuffle(N, Mask, DL, DAG))
41486       return NewN;
41487 
41488     break;
41489   }
41490 
41491   return SDValue();
41492 }
41493 
41494 /// Checks if the shuffle mask takes subsequent elements
41495 /// alternately from two vectors.
41496 /// For example <0, 5, 2, 7> or <8, 1, 10, 3, 12, 5, 14, 7> are both correct.
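/// In the first case the even result lanes come from operand 0 and the odd
/// lanes from operand 1, so \p Op0Even is set to true; in the second case the
/// parity is swapped and \p Op0Even is set to false.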
41497 static bool isAddSubOrSubAddMask(ArrayRef<int> Mask, bool &Op0Even) {
41498 
41499   int ParitySrc[2] = {-1, -1};
41500   unsigned Size = Mask.size();
41501   for (unsigned i = 0; i != Size; ++i) {
41502     int M = Mask[i];
41503     if (M < 0)
41504       continue;
41505 
41506     // Make sure we are using the matching element from the input.
41507     if ((M % Size) != i)
41508       return false;
41509 
41510     // Make sure we use the same input for all elements of the same parity.
41511     int Src = M / Size;
41512     if (ParitySrc[i % 2] >= 0 && ParitySrc[i % 2] != Src)
41513       return false;
41514     ParitySrc[i % 2] = Src;
41515   }
41516 
41517   // Make sure each input is used.
41518   if (ParitySrc[0] < 0 || ParitySrc[1] < 0 || ParitySrc[0] == ParitySrc[1])
41519     return false;
41520 
41521   Op0Even = ParitySrc[0] == 0;
41522   return true;
41523 }
41524 
41525 /// Returns true iff the shuffle node \p N can be replaced with ADDSUB(SUBADD)
41526 /// operation. If true is returned then the operands of ADDSUB(SUBADD) operation
41527 /// are written to the parameters \p Opnd0 and \p Opnd1.
41528 ///
41529 /// We combine shuffles to ADDSUB(SUBADD) directly on the abstract vector
41530 /// shuffle nodes so they are easier to match generically. We also insert
41531 /// dummy vector shuffle nodes for the operands which explicitly discard the
41532 /// lanes that are unused by this operation, to propagate the fact that they
41533 /// are unused through the rest of the combiner.
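/// E.g. (shuffle (fsub A, B), (fadd A, B), <0, 5, 2, 7>) is matched with
/// \p IsSubAdd == false, \p Opnd0 == A and \p Opnd1 == B, which the caller can
/// turn into (ADDSUB A, B).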
41534 static bool isAddSubOrSubAdd(SDNode *N, const X86Subtarget &Subtarget,
41535                              SelectionDAG &DAG, SDValue &Opnd0, SDValue &Opnd1,
41536                              bool &IsSubAdd) {
41537 
41538   EVT VT = N->getValueType(0);
41539   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
41540   if (!Subtarget.hasSSE3() || !TLI.isTypeLegal(VT) ||
41541       !VT.getSimpleVT().isFloatingPoint())
41542     return false;
41543 
41544   // We only handle target-independent shuffles.
41545   // FIXME: It would be easy and harmless to use the target shuffle mask
41546   // extraction tool to support more.
41547   if (N->getOpcode() != ISD::VECTOR_SHUFFLE)
41548     return false;
41549 
41550   SDValue V1 = N->getOperand(0);
41551   SDValue V2 = N->getOperand(1);
41552 
41553   // Make sure we have an FADD and an FSUB.
41554   if ((V1.getOpcode() != ISD::FADD && V1.getOpcode() != ISD::FSUB) ||
41555       (V2.getOpcode() != ISD::FADD && V2.getOpcode() != ISD::FSUB) ||
41556       V1.getOpcode() == V2.getOpcode())
41557     return false;
41558 
41559   // If there are other uses of these operations we can't fold them.
41560   if (!V1->hasOneUse() || !V2->hasOneUse())
41561     return false;
41562 
41563   // Ensure that both operations have the same operands. Note that we can
41564   // commute the FADD operands.
41565   SDValue LHS, RHS;
41566   if (V1.getOpcode() == ISD::FSUB) {
41567     LHS = V1->getOperand(0); RHS = V1->getOperand(1);
41568     if ((V2->getOperand(0) != LHS || V2->getOperand(1) != RHS) &&
41569         (V2->getOperand(0) != RHS || V2->getOperand(1) != LHS))
41570       return false;
41571   } else {
41572     assert(V2.getOpcode() == ISD::FSUB && "Unexpected opcode");
41573     LHS = V2->getOperand(0); RHS = V2->getOperand(1);
41574     if ((V1->getOperand(0) != LHS || V1->getOperand(1) != RHS) &&
41575         (V1->getOperand(0) != RHS || V1->getOperand(1) != LHS))
41576       return false;
41577   }
41578 
41579   ArrayRef<int> Mask = cast<ShuffleVectorSDNode>(N)->getMask();
41580   bool Op0Even;
41581   if (!isAddSubOrSubAddMask(Mask, Op0Even))
41582     return false;
41583 
41584   // It's a subadd if the vector in the even parity is an FADD.
41585   IsSubAdd = Op0Even ? V1->getOpcode() == ISD::FADD
41586                      : V2->getOpcode() == ISD::FADD;
41587 
41588   Opnd0 = LHS;
41589   Opnd1 = RHS;
41590   return true;
41591 }
41592 
41593 /// Combine shuffle of two fma nodes into FMAddSub or FMSubAdd.
41594 static SDValue combineShuffleToFMAddSub(SDNode *N, const SDLoc &DL,
41595                                         const X86Subtarget &Subtarget,
41596                                         SelectionDAG &DAG) {
41597   // We only handle target-independent shuffles.
41598   // FIXME: It would be easy and harmless to use the target shuffle mask
41599   // extraction tool to support more.
41600   if (N->getOpcode() != ISD::VECTOR_SHUFFLE)
41601     return SDValue();
41602 
41603   MVT VT = N->getSimpleValueType(0);
41604   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
41605   if (!Subtarget.hasAnyFMA() || !TLI.isTypeLegal(VT))
41606     return SDValue();
41607 
41608   // We're trying to match (shuffle fma(a, b, c), X86Fmsub(a, b, c)).
41609   SDValue Op0 = N->getOperand(0);
41610   SDValue Op1 = N->getOperand(1);
41611   SDValue FMAdd = Op0, FMSub = Op1;
41612   if (FMSub.getOpcode() != X86ISD::FMSUB)
41613     std::swap(FMAdd, FMSub);
41614 
41615   if (FMAdd.getOpcode() != ISD::FMA || FMSub.getOpcode() != X86ISD::FMSUB ||
41616       FMAdd.getOperand(0) != FMSub.getOperand(0) || !FMAdd.hasOneUse() ||
41617       FMAdd.getOperand(1) != FMSub.getOperand(1) || !FMSub.hasOneUse() ||
41618       FMAdd.getOperand(2) != FMSub.getOperand(2))
41619     return SDValue();
41620 
41621   // Check for correct shuffle mask.
41622   ArrayRef<int> Mask = cast<ShuffleVectorSDNode>(N)->getMask();
41623   bool Op0Even;
41624   if (!isAddSubOrSubAddMask(Mask, Op0Even))
41625     return SDValue();
41626 
41627   // FMAddSub takes zeroth operand from FMSub node.
41628   bool IsSubAdd = Op0Even ? Op0 == FMAdd : Op1 == FMAdd;
41629   unsigned Opcode = IsSubAdd ? X86ISD::FMSUBADD : X86ISD::FMADDSUB;
41630   return DAG.getNode(Opcode, DL, VT, FMAdd.getOperand(0), FMAdd.getOperand(1),
41631                      FMAdd.getOperand(2));
41632 }
41633 
41634 /// Try to combine a shuffle into a target-specific add-sub or
41635 /// mul-add-sub node.
41636 static SDValue combineShuffleToAddSubOrFMAddSub(SDNode *N, const SDLoc &DL,
41637                                                 const X86Subtarget &Subtarget,
41638                                                 SelectionDAG &DAG) {
41639   if (SDValue V = combineShuffleToFMAddSub(N, DL, Subtarget, DAG))
41640     return V;
41641 
41642   SDValue Opnd0, Opnd1;
41643   bool IsSubAdd;
41644   if (!isAddSubOrSubAdd(N, Subtarget, DAG, Opnd0, Opnd1, IsSubAdd))
41645     return SDValue();
41646 
41647   MVT VT = N->getSimpleValueType(0);
41648 
41649   // Try to generate X86ISD::FMADDSUB node here.
41650   SDValue Opnd2;
41651   if (isFMAddSubOrFMSubAdd(Subtarget, DAG, Opnd0, Opnd1, Opnd2, 2)) {
41652     unsigned Opc = IsSubAdd ? X86ISD::FMSUBADD : X86ISD::FMADDSUB;
41653     return DAG.getNode(Opc, DL, VT, Opnd0, Opnd1, Opnd2);
41654   }
41655 
41656   if (IsSubAdd)
41657     return SDValue();
41658 
41659   // Do not generate X86ISD::ADDSUB node for 512-bit types even though
41660   // the ADDSUB idiom has been successfully recognized. There are no known
41661   // X86 targets with 512-bit ADDSUB instructions!
41662   if (VT.is512BitVector())
41663     return SDValue();
41664 
41665   // Do not generate X86ISD::ADDSUB node for FP16's vector types even though
41666   // the ADDSUB idiom has been successfully recognized. There are no known
41667   // X86 targets with FP16 ADDSUB instructions!
41668   if (VT.getVectorElementType() == MVT::f16)
41669     return SDValue();
41670 
41671   return DAG.getNode(X86ISD::ADDSUB, DL, VT, Opnd0, Opnd1);
41672 }
41673 
41674 // We are looking for a shuffle where both sources are concatenated with undef
41675 // and have a width that is half of the output's width. AVX2 has VPERMD/Q, so
41676 // if we can express this as a single-source shuffle, that's preferable.
41677 static SDValue combineShuffleOfConcatUndef(SDNode *N, const SDLoc &DL,
41678                                            SelectionDAG &DAG,
41679                                            const X86Subtarget &Subtarget) {
41680   if (!Subtarget.hasAVX2() || !isa<ShuffleVectorSDNode>(N))
41681     return SDValue();
41682 
41683   EVT VT = N->getValueType(0);
41684 
41685   // We only care about shuffles of 128/256-bit vectors of 32/64-bit values.
41686   if (!VT.is128BitVector() && !VT.is256BitVector())
41687     return SDValue();
41688 
41689   if (VT.getVectorElementType() != MVT::i32 &&
41690       VT.getVectorElementType() != MVT::i64 &&
41691       VT.getVectorElementType() != MVT::f32 &&
41692       VT.getVectorElementType() != MVT::f64)
41693     return SDValue();
41694 
41695   SDValue N0 = N->getOperand(0);
41696   SDValue N1 = N->getOperand(1);
41697 
41698   // Check that both sources are concats with undef.
41699   if (N0.getOpcode() != ISD::CONCAT_VECTORS ||
41700       N1.getOpcode() != ISD::CONCAT_VECTORS || N0.getNumOperands() != 2 ||
41701       N1.getNumOperands() != 2 || !N0.getOperand(1).isUndef() ||
41702       !N1.getOperand(1).isUndef())
41703     return SDValue();
41704 
41705   // Construct the new shuffle mask. Elements from the first source retain their
41706   // index, but elements from the second source no longer need to skip an undef.
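  // E.g. for v8i32, mask element 8 (lane 0 of the second concat) becomes
  // element 4 of (concat_vectors t1, t2), i.e. Elt - NumElts / 2.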
41707   SmallVector<int, 8> Mask;
41708   int NumElts = VT.getVectorNumElements();
41709 
41710   auto *SVOp = cast<ShuffleVectorSDNode>(N);
41711   for (int Elt : SVOp->getMask())
41712     Mask.push_back(Elt < NumElts ? Elt : (Elt - NumElts / 2));
41713 
41714   SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, N0.getOperand(0),
41715                                N1.getOperand(0));
41716   return DAG.getVectorShuffle(VT, DL, Concat, DAG.getUNDEF(VT), Mask);
41717 }
41718 
41719 /// If we have a shuffle of AVX/AVX512 (256/512 bit) vectors that only uses the
41720 /// low half of each source vector and does not set any high half elements in
41721 /// the destination vector, narrow the shuffle to half its original size.
41722 static SDValue narrowShuffle(ShuffleVectorSDNode *Shuf, SelectionDAG &DAG) {
41723   EVT VT = Shuf->getValueType(0);
41724   if (!DAG.getTargetLoweringInfo().isTypeLegal(Shuf->getValueType(0)))
41725     return SDValue();
41726   if (!VT.is256BitVector() && !VT.is512BitVector())
41727     return SDValue();
41728 
41729   // See if we can ignore all of the high elements of the shuffle.
41730   ArrayRef<int> Mask = Shuf->getMask();
41731   if (!isUndefUpperHalf(Mask))
41732     return SDValue();
41733 
41734   // Check if the shuffle mask accesses only the low half of each input vector
41735   // (half-index output is 0 or 2).
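  // Half indices 0 and 2 refer to the low halves of the first and second
  // source vectors respectively; odd indices would select the high halves.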
41736   int HalfIdx1, HalfIdx2;
41737   SmallVector<int, 8> HalfMask(Mask.size() / 2);
41738   if (!getHalfShuffleMask(Mask, HalfMask, HalfIdx1, HalfIdx2) ||
41739       (HalfIdx1 % 2 == 1) || (HalfIdx2 % 2 == 1))
41740     return SDValue();
41741 
41742   // Create a half-width shuffle to replace the unnecessarily wide shuffle.
41743   // The trick is knowing that all of the insert/extract are actually free
41744   // subregister (zmm<->ymm or ymm<->xmm) ops. That leaves us with a shuffle
41745   // of narrow inputs into a narrow output, and that is always cheaper than
41746   // the wide shuffle that we started with.
41747   return getShuffleHalfVectors(SDLoc(Shuf), Shuf->getOperand(0),
41748                                Shuf->getOperand(1), HalfMask, HalfIdx1,
41749                                HalfIdx2, false, DAG, /*UseConcat*/ true);
41750 }
41751 
41752 static SDValue combineShuffle(SDNode *N, SelectionDAG &DAG,
41753                               TargetLowering::DAGCombinerInfo &DCI,
41754                               const X86Subtarget &Subtarget) {
41755   if (auto *Shuf = dyn_cast<ShuffleVectorSDNode>(N))
41756     if (SDValue V = narrowShuffle(Shuf, DAG))
41757       return V;
41758 
41759   // If we have legalized the vector types, look for blends of FADD and FSUB
41760   // nodes that we can fuse into an ADDSUB, FMADDSUB, or FMSUBADD node.
41761   SDLoc dl(N);
41762   EVT VT = N->getValueType(0);
41763   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
41764   if (TLI.isTypeLegal(VT) && !isSoftF16(VT, Subtarget))
41765     if (SDValue AddSub =
41766             combineShuffleToAddSubOrFMAddSub(N, dl, Subtarget, DAG))
41767       return AddSub;
41768 
41769   // Attempt to combine into a vector load/broadcast.
41770   if (SDValue LD = combineToConsecutiveLoads(
41771           VT, SDValue(N, 0), dl, DAG, Subtarget, /*IsAfterLegalize*/ true))
41772     return LD;
41773 
41774   // For AVX2, we sometimes want to combine
41775   // (vector_shuffle <mask> (concat_vectors t1, undef)
41776   //                        (concat_vectors t2, undef))
41777   // Into:
41778   // (vector_shuffle <mask> (concat_vectors t1, t2), undef)
41779   // Since the latter can be efficiently lowered with VPERMD/VPERMQ
41780   if (SDValue ShufConcat = combineShuffleOfConcatUndef(N, dl, DAG, Subtarget))
41781     return ShufConcat;
41782 
41783   if (isTargetShuffle(N->getOpcode())) {
41784     SDValue Op(N, 0);
41785     if (SDValue Shuffle = combineTargetShuffle(Op, dl, DAG, DCI, Subtarget))
41786       return Shuffle;
41787 
41788     // Try recursively combining arbitrary sequences of x86 shuffle
41789     // instructions into higher-order shuffles. We do this after combining
41790     // specific PSHUF instruction sequences into their minimal form so that we
41791     // can evaluate how many specialized shuffle instructions are involved in
41792     // a particular chain.
41793     if (SDValue Res = combineX86ShufflesRecursively(Op, DAG, Subtarget))
41794       return Res;
41795 
41796     // Simplify source operands based on shuffle mask.
41797     // TODO - merge this into combineX86ShufflesRecursively.
41798     APInt DemandedElts = APInt::getAllOnes(VT.getVectorNumElements());
41799     if (TLI.SimplifyDemandedVectorElts(Op, DemandedElts, DCI))
41800       return SDValue(N, 0);
41801 
41802     // Canonicalize SHUFFLE(UNARYOP(X)) -> UNARYOP(SHUFFLE(X)).
41803     // Canonicalize SHUFFLE(BINOP(X,Y)) -> BINOP(SHUFFLE(X),SHUFFLE(Y)).
41804     // Perform this after other shuffle combines to allow inner shuffles to be
41805     // combined away first.
41806     if (SDValue BinOp = canonicalizeShuffleWithOp(Op, DAG, dl))
41807       return BinOp;
41808   }
41809 
41810   return SDValue();
41811 }
41812 
41813 // Simplify variable target shuffle masks based on the demanded elements.
41814 // TODO: Handle DemandedBits in mask indices as well?
41815 bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetShuffle(
41816     SDValue Op, const APInt &DemandedElts, unsigned MaskIndex,
41817     TargetLowering::TargetLoweringOpt &TLO, unsigned Depth) const {
41818   // If we're demanding all elements, don't bother trying to simplify the mask.
41819   unsigned NumElts = DemandedElts.getBitWidth();
41820   if (DemandedElts.isAllOnes())
41821     return false;
41822 
41823   SDValue Mask = Op.getOperand(MaskIndex);
41824   if (!Mask.hasOneUse())
41825     return false;
41826 
41827   // Attempt to generically simplify the variable shuffle mask.
41828   APInt MaskUndef, MaskZero;
41829   if (SimplifyDemandedVectorElts(Mask, DemandedElts, MaskUndef, MaskZero, TLO,
41830                                  Depth + 1))
41831     return true;
41832 
41833   // Attempt to extract+simplify a (constant pool load) shuffle mask.
41834   // TODO: Support other types from getTargetShuffleMaskIndices?
41835   SDValue BC = peekThroughOneUseBitcasts(Mask);
41836   EVT BCVT = BC.getValueType();
41837   auto *Load = dyn_cast<LoadSDNode>(BC);
41838   if (!Load || !Load->getBasePtr().hasOneUse())
41839     return false;
41840 
41841   const Constant *C = getTargetConstantFromNode(Load);
41842   if (!C)
41843     return false;
41844 
41845   Type *CTy = C->getType();
41846   if (!CTy->isVectorTy() ||
41847       CTy->getPrimitiveSizeInBits() != Mask.getValueSizeInBits())
41848     return false;
41849 
41850   // Handle scaling for i64 elements on 32-bit targets.
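  // E.g. a vXi64 mask constant may be stored as twice as many i32 elements, in
  // which case Scale == 2 and each demanded element covers two constant
  // elements.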
41851   unsigned NumCstElts = cast<FixedVectorType>(CTy)->getNumElements();
41852   if (NumCstElts != NumElts && NumCstElts != (NumElts * 2))
41853     return false;
41854   unsigned Scale = NumCstElts / NumElts;
41855 
41856   // Simplify mask if we have an undemanded element that is not undef.
41857   bool Simplified = false;
41858   SmallVector<Constant *, 32> ConstVecOps;
41859   for (unsigned i = 0; i != NumCstElts; ++i) {
41860     Constant *Elt = C->getAggregateElement(i);
41861     if (!DemandedElts[i / Scale] && !isa<UndefValue>(Elt)) {
41862       ConstVecOps.push_back(UndefValue::get(Elt->getType()));
41863       Simplified = true;
41864       continue;
41865     }
41866     ConstVecOps.push_back(Elt);
41867   }
41868   if (!Simplified)
41869     return false;
41870 
41871   // Generate new constant pool entry + legalize immediately for the load.
41872   SDLoc DL(Op);
41873   SDValue CV = TLO.DAG.getConstantPool(ConstantVector::get(ConstVecOps), BCVT);
41874   SDValue LegalCV = LowerConstantPool(CV, TLO.DAG);
41875   SDValue NewMask = TLO.DAG.getLoad(
41876       BCVT, DL, TLO.DAG.getEntryNode(), LegalCV,
41877       MachinePointerInfo::getConstantPool(TLO.DAG.getMachineFunction()),
41878       Load->getAlign());
41879   return TLO.CombineTo(Mask, TLO.DAG.getBitcast(Mask.getValueType(), NewMask));
41880 }
41881 
41882 bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetNode(
41883     SDValue Op, const APInt &DemandedElts, APInt &KnownUndef, APInt &KnownZero,
41884     TargetLoweringOpt &TLO, unsigned Depth) const {
41885   int NumElts = DemandedElts.getBitWidth();
41886   unsigned Opc = Op.getOpcode();
41887   EVT VT = Op.getValueType();
41888 
41889   // Handle special case opcodes.
41890   switch (Opc) {
41891   case X86ISD::PMULDQ:
41892   case X86ISD::PMULUDQ: {
41893     APInt LHSUndef, LHSZero;
41894     APInt RHSUndef, RHSZero;
41895     SDValue LHS = Op.getOperand(0);
41896     SDValue RHS = Op.getOperand(1);
41897     if (SimplifyDemandedVectorElts(LHS, DemandedElts, LHSUndef, LHSZero, TLO,
41898                                    Depth + 1))
41899       return true;
41900     if (SimplifyDemandedVectorElts(RHS, DemandedElts, RHSUndef, RHSZero, TLO,
41901                                    Depth + 1))
41902       return true;
41903     // Multiply by zero.
41904     KnownZero = LHSZero | RHSZero;
41905     break;
41906   }
41907   case X86ISD::VPMADDUBSW:
41908   case X86ISD::VPMADDWD: {
41909     APInt LHSUndef, LHSZero;
41910     APInt RHSUndef, RHSZero;
41911     SDValue LHS = Op.getOperand(0);
41912     SDValue RHS = Op.getOperand(1);
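    // Each result element is computed from a pair of adjacent source elements,
    // so widen the demanded-element mask to cover both elements of each pair.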
41913     APInt DemandedSrcElts = APIntOps::ScaleBitMask(DemandedElts, 2 * NumElts);
41914 
41915     if (SimplifyDemandedVectorElts(LHS, DemandedSrcElts, LHSUndef, LHSZero, TLO,
41916                                    Depth + 1))
41917       return true;
41918     if (SimplifyDemandedVectorElts(RHS, DemandedSrcElts, RHSUndef, RHSZero, TLO,
41919                                    Depth + 1))
41920       return true;
41921 
41922     // TODO: Multiply by zero.
41923 
41924     // If RHS/LHS elements are known zero then we don't need the LHS/RHS equivalent.
41925     APInt DemandedLHSElts = DemandedSrcElts & ~RHSZero;
41926     if (SimplifyDemandedVectorElts(LHS, DemandedLHSElts, LHSUndef, LHSZero, TLO,
41927                                    Depth + 1))
41928       return true;
41929     APInt DemandedRHSElts = DemandedSrcElts & ~LHSZero;
41930     if (SimplifyDemandedVectorElts(RHS, DemandedRHSElts, RHSUndef, RHSZero, TLO,
41931                                    Depth + 1))
41932       return true;
41933     break;
41934   }
41935   case X86ISD::PSADBW: {
41936     SDValue LHS = Op.getOperand(0);
41937     SDValue RHS = Op.getOperand(1);
41938     assert(VT.getScalarType() == MVT::i64 &&
41939            LHS.getValueType() == RHS.getValueType() &&
41940            LHS.getValueType().getScalarType() == MVT::i8 &&
41941            "Unexpected PSADBW types");
41942 
41943     // Aggressively peek through ops to get at the demanded elts.
41944     if (!DemandedElts.isAllOnes()) {
41945       unsigned NumSrcElts = LHS.getValueType().getVectorNumElements();
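      // Each i64 result element accumulates absolute differences over eight
      // adjacent byte elements, so scale the demanded mask up to byte elements.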
41946       APInt DemandedSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
41947       SDValue NewLHS = SimplifyMultipleUseDemandedVectorElts(
41948           LHS, DemandedSrcElts, TLO.DAG, Depth + 1);
41949       SDValue NewRHS = SimplifyMultipleUseDemandedVectorElts(
41950           RHS, DemandedSrcElts, TLO.DAG, Depth + 1);
41951       if (NewLHS || NewRHS) {
41952         NewLHS = NewLHS ? NewLHS : LHS;
41953         NewRHS = NewRHS ? NewRHS : RHS;
41954         return TLO.CombineTo(
41955             Op, TLO.DAG.getNode(Opc, SDLoc(Op), VT, NewLHS, NewRHS));
41956       }
41957     }
41958     break;
41959   }
41960   case X86ISD::VSHL:
41961   case X86ISD::VSRL:
41962   case X86ISD::VSRA: {
41963     // We only need the bottom 64-bits of the (128-bit) shift amount.
41964     SDValue Amt = Op.getOperand(1);
41965     MVT AmtVT = Amt.getSimpleValueType();
41966     assert(AmtVT.is128BitVector() && "Unexpected value type");
41967 
41968     // If the shift amount is only ever reused as an SSE shift amount then we
41969     // know that only the bottom 64-bits are ever used.
41970     bool AssumeSingleUse = llvm::all_of(Amt->uses(), [&Amt](SDNode *Use) {
41971       unsigned UseOpc = Use->getOpcode();
41972       return (UseOpc == X86ISD::VSHL || UseOpc == X86ISD::VSRL ||
41973               UseOpc == X86ISD::VSRA) &&
41974              Use->getOperand(0) != Amt;
41975     });
41976 
41977     APInt AmtUndef, AmtZero;
41978     unsigned NumAmtElts = AmtVT.getVectorNumElements();
41979     APInt AmtElts = APInt::getLowBitsSet(NumAmtElts, NumAmtElts / 2);
41980     if (SimplifyDemandedVectorElts(Amt, AmtElts, AmtUndef, AmtZero, TLO,
41981                                    Depth + 1, AssumeSingleUse))
41982       return true;
41983     [[fallthrough]];
41984   }
41985   case X86ISD::VSHLI:
41986   case X86ISD::VSRLI:
41987   case X86ISD::VSRAI: {
41988     SDValue Src = Op.getOperand(0);
41989     APInt SrcUndef;
41990     if (SimplifyDemandedVectorElts(Src, DemandedElts, SrcUndef, KnownZero, TLO,
41991                                    Depth + 1))
41992       return true;
41993 
41994     // Fold shift(0,x) -> 0
41995     if (DemandedElts.isSubsetOf(KnownZero))
41996       return TLO.CombineTo(
41997           Op, getZeroVector(VT.getSimpleVT(), Subtarget, TLO.DAG, SDLoc(Op)));
41998 
41999     // Aggressively peek through ops to get at the demanded elts.
42000     if (!DemandedElts.isAllOnes())
42001       if (SDValue NewSrc = SimplifyMultipleUseDemandedVectorElts(
42002               Src, DemandedElts, TLO.DAG, Depth + 1))
42003         return TLO.CombineTo(
42004             Op, TLO.DAG.getNode(Opc, SDLoc(Op), VT, NewSrc, Op.getOperand(1)));
42005     break;
42006   }
42007   case X86ISD::VPSHA:
42008   case X86ISD::VPSHL:
42009   case X86ISD::VSHLV:
42010   case X86ISD::VSRLV:
42011   case X86ISD::VSRAV: {
42012     APInt LHSUndef, LHSZero;
42013     APInt RHSUndef, RHSZero;
42014     SDValue LHS = Op.getOperand(0);
42015     SDValue RHS = Op.getOperand(1);
42016     if (SimplifyDemandedVectorElts(LHS, DemandedElts, LHSUndef, LHSZero, TLO,
42017                                    Depth + 1))
42018       return true;
42019 
42020     // Fold shift(0,x) -> 0
42021     if (DemandedElts.isSubsetOf(LHSZero))
42022       return TLO.CombineTo(
42023           Op, getZeroVector(VT.getSimpleVT(), Subtarget, TLO.DAG, SDLoc(Op)));
42024 
42025     if (SimplifyDemandedVectorElts(RHS, DemandedElts, RHSUndef, RHSZero, TLO,
42026                                    Depth + 1))
42027       return true;
42028 
42029     KnownZero = LHSZero;
42030     break;
42031   }
42032   case X86ISD::PCMPEQ:
42033   case X86ISD::PCMPGT: {
42034     APInt LHSUndef, LHSZero;
42035     APInt RHSUndef, RHSZero;
42036     SDValue LHS = Op.getOperand(0);
42037     SDValue RHS = Op.getOperand(1);
42038     if (SimplifyDemandedVectorElts(LHS, DemandedElts, LHSUndef, LHSZero, TLO,
42039                                    Depth + 1))
42040       return true;
42041     if (SimplifyDemandedVectorElts(RHS, DemandedElts, RHSUndef, RHSZero, TLO,
42042                                    Depth + 1))
42043       return true;
42044     break;
42045   }
42046   case X86ISD::KSHIFTL: {
42047     SDValue Src = Op.getOperand(0);
42048     auto *Amt = cast<ConstantSDNode>(Op.getOperand(1));
42049     assert(Amt->getAPIntValue().ult(NumElts) && "Out of range shift amount");
42050     unsigned ShiftAmt = Amt->getZExtValue();
42051 
42052     if (ShiftAmt == 0)
42053       return TLO.CombineTo(Op, Src);
42054 
42055     // If this is ((X >>u C1) << ShAmt), see if we can simplify this into a
42056     // single shift.  We can do this if the bottom bits (which are shifted
42057     // out) are never demanded.
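    // E.g. kshiftl(kshiftr(X, 2), 5) becomes kshiftl(X, 3) when the low five
    // result elements are not demanded.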
42058     if (Src.getOpcode() == X86ISD::KSHIFTR) {
42059       if (!DemandedElts.intersects(APInt::getLowBitsSet(NumElts, ShiftAmt))) {
42060         unsigned C1 = Src.getConstantOperandVal(1);
42061         unsigned NewOpc = X86ISD::KSHIFTL;
42062         int Diff = ShiftAmt - C1;
42063         if (Diff < 0) {
42064           Diff = -Diff;
42065           NewOpc = X86ISD::KSHIFTR;
42066         }
42067 
42068         SDLoc dl(Op);
42069         SDValue NewSA = TLO.DAG.getTargetConstant(Diff, dl, MVT::i8);
42070         return TLO.CombineTo(
42071             Op, TLO.DAG.getNode(NewOpc, dl, VT, Src.getOperand(0), NewSA));
42072       }
42073     }
42074 
42075     APInt DemandedSrc = DemandedElts.lshr(ShiftAmt);
42076     if (SimplifyDemandedVectorElts(Src, DemandedSrc, KnownUndef, KnownZero, TLO,
42077                                    Depth + 1))
42078       return true;
42079 
42080     KnownUndef <<= ShiftAmt;
42081     KnownZero <<= ShiftAmt;
42082     KnownZero.setLowBits(ShiftAmt);
42083     break;
42084   }
42085   case X86ISD::KSHIFTR: {
42086     SDValue Src = Op.getOperand(0);
42087     auto *Amt = cast<ConstantSDNode>(Op.getOperand(1));
42088     assert(Amt->getAPIntValue().ult(NumElts) && "Out of range shift amount");
42089     unsigned ShiftAmt = Amt->getZExtValue();
42090 
42091     if (ShiftAmt == 0)
42092       return TLO.CombineTo(Op, Src);
42093 
42094     // If this is ((X << C1) >>u ShAmt), see if we can simplify this into a
42095     // single shift.  We can do this if the top bits (which are shifted
42096     // out) are never demanded.
42097     if (Src.getOpcode() == X86ISD::KSHIFTL) {
42098       if (!DemandedElts.intersects(APInt::getHighBitsSet(NumElts, ShiftAmt))) {
42099         unsigned C1 = Src.getConstantOperandVal(1);
42100         unsigned NewOpc = X86ISD::KSHIFTR;
42101         int Diff = ShiftAmt - C1;
42102         if (Diff < 0) {
42103           Diff = -Diff;
42104           NewOpc = X86ISD::KSHIFTL;
42105         }
42106 
42107         SDLoc dl(Op);
42108         SDValue NewSA = TLO.DAG.getTargetConstant(Diff, dl, MVT::i8);
42109         return TLO.CombineTo(
42110             Op, TLO.DAG.getNode(NewOpc, dl, VT, Src.getOperand(0), NewSA));
42111       }
42112     }
42113 
42114     APInt DemandedSrc = DemandedElts.shl(ShiftAmt);
42115     if (SimplifyDemandedVectorElts(Src, DemandedSrc, KnownUndef, KnownZero, TLO,
42116                                    Depth + 1))
42117       return true;
42118 
42119     KnownUndef.lshrInPlace(ShiftAmt);
42120     KnownZero.lshrInPlace(ShiftAmt);
42121     KnownZero.setHighBits(ShiftAmt);
42122     break;
42123   }
42124   case X86ISD::ANDNP: {
42125     // ANDNP = (~LHS & RHS);
42126     SDValue LHS = Op.getOperand(0);
42127     SDValue RHS = Op.getOperand(1);
42128 
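    // A constant operand limits what the other operand can contribute: zero
    // elements of RHS make the matching ~LHS lanes irrelevant, and all-ones
    // elements of LHS (i.e. ~LHS == 0) make the matching RHS lanes irrelevant.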
42129     auto GetDemandedMasks = [&](SDValue Op, bool Invert = false) {
42130       APInt UndefElts;
42131       SmallVector<APInt> EltBits;
42132       int NumElts = VT.getVectorNumElements();
42133       int EltSizeInBits = VT.getScalarSizeInBits();
42134       APInt OpBits = APInt::getAllOnes(EltSizeInBits);
42135       APInt OpElts = DemandedElts;
42136       if (getTargetConstantBitsFromNode(Op, EltSizeInBits, UndefElts,
42137                                         EltBits)) {
42138         OpBits.clearAllBits();
42139         OpElts.clearAllBits();
42140         for (int I = 0; I != NumElts; ++I) {
42141           if (!DemandedElts[I])
42142             continue;
42143           if (UndefElts[I]) {
42144             // We can't assume an undef src element gives an undef dst - the
42145             // other src might be zero.
42146             OpBits.setAllBits();
42147             OpElts.setBit(I);
42148           } else if ((Invert && !EltBits[I].isAllOnes()) ||
42149                      (!Invert && !EltBits[I].isZero())) {
42150             OpBits |= Invert ? ~EltBits[I] : EltBits[I];
42151             OpElts.setBit(I);
42152           }
42153         }
42154       }
42155       return std::make_pair(OpBits, OpElts);
42156     };
42157     APInt BitsLHS, EltsLHS;
42158     APInt BitsRHS, EltsRHS;
42159     std::tie(BitsLHS, EltsLHS) = GetDemandedMasks(RHS);
42160     std::tie(BitsRHS, EltsRHS) = GetDemandedMasks(LHS, true);
42161 
42162     APInt LHSUndef, LHSZero;
42163     APInt RHSUndef, RHSZero;
42164     if (SimplifyDemandedVectorElts(LHS, EltsLHS, LHSUndef, LHSZero, TLO,
42165                                    Depth + 1))
42166       return true;
42167     if (SimplifyDemandedVectorElts(RHS, EltsRHS, RHSUndef, RHSZero, TLO,
42168                                    Depth + 1))
42169       return true;
42170 
42171     if (!DemandedElts.isAllOnes()) {
42172       SDValue NewLHS = SimplifyMultipleUseDemandedBits(LHS, BitsLHS, EltsLHS,
42173                                                        TLO.DAG, Depth + 1);
42174       SDValue NewRHS = SimplifyMultipleUseDemandedBits(RHS, BitsRHS, EltsRHS,
42175                                                        TLO.DAG, Depth + 1);
42176       if (NewLHS || NewRHS) {
42177         NewLHS = NewLHS ? NewLHS : LHS;
42178         NewRHS = NewRHS ? NewRHS : RHS;
42179         return TLO.CombineTo(
42180             Op, TLO.DAG.getNode(Opc, SDLoc(Op), VT, NewLHS, NewRHS));
42181       }
42182     }
42183     break;
42184   }
42185   case X86ISD::CVTSI2P:
42186   case X86ISD::CVTUI2P:
42187   case X86ISD::CVTPH2PS:
42188   case X86ISD::CVTPS2PH: {
42189     SDValue Src = Op.getOperand(0);
42190     EVT SrcVT = Src.getValueType();
42191     APInt SrcUndef, SrcZero;
42192     APInt SrcElts = DemandedElts.zextOrTrunc(SrcVT.getVectorNumElements());
42193     if (SimplifyDemandedVectorElts(Src, SrcElts, SrcUndef, SrcZero, TLO,
42194                                    Depth + 1))
42195       return true;
42196     break;
42197   }
42198   case X86ISD::PACKSS:
42199   case X86ISD::PACKUS: {
42200     SDValue N0 = Op.getOperand(0);
42201     SDValue N1 = Op.getOperand(1);
42202 
42203     APInt DemandedLHS, DemandedRHS;
42204     getPackDemandedElts(VT, DemandedElts, DemandedLHS, DemandedRHS);
42205 
42206     APInt LHSUndef, LHSZero;
42207     if (SimplifyDemandedVectorElts(N0, DemandedLHS, LHSUndef, LHSZero, TLO,
42208                                    Depth + 1))
42209       return true;
42210     APInt RHSUndef, RHSZero;
42211     if (SimplifyDemandedVectorElts(N1, DemandedRHS, RHSUndef, RHSZero, TLO,
42212                                    Depth + 1))
42213       return true;
42214 
42215     // TODO - pass on known zero/undef.
42216 
42217     // Aggressively peek through ops to get at the demanded elts.
42218     // TODO - we should do this for all target/faux shuffles ops.
42219     if (!DemandedElts.isAllOnes()) {
42220       SDValue NewN0 = SimplifyMultipleUseDemandedVectorElts(N0, DemandedLHS,
42221                                                             TLO.DAG, Depth + 1);
42222       SDValue NewN1 = SimplifyMultipleUseDemandedVectorElts(N1, DemandedRHS,
42223                                                             TLO.DAG, Depth + 1);
42224       if (NewN0 || NewN1) {
42225         NewN0 = NewN0 ? NewN0 : N0;
42226         NewN1 = NewN1 ? NewN1 : N1;
42227         return TLO.CombineTo(Op,
42228                              TLO.DAG.getNode(Opc, SDLoc(Op), VT, NewN0, NewN1));
42229       }
42230     }
42231     break;
42232   }
42233   case X86ISD::HADD:
42234   case X86ISD::HSUB:
42235   case X86ISD::FHADD:
42236   case X86ISD::FHSUB: {
42237     SDValue N0 = Op.getOperand(0);
42238     SDValue N1 = Op.getOperand(1);
42239 
42240     APInt DemandedLHS, DemandedRHS;
42241     getHorizDemandedElts(VT, DemandedElts, DemandedLHS, DemandedRHS);
42242 
42243     APInt LHSUndef, LHSZero;
42244     if (SimplifyDemandedVectorElts(N0, DemandedLHS, LHSUndef, LHSZero, TLO,
42245                                    Depth + 1))
42246       return true;
42247     APInt RHSUndef, RHSZero;
42248     if (SimplifyDemandedVectorElts(N1, DemandedRHS, RHSUndef, RHSZero, TLO,
42249                                    Depth + 1))
42250       return true;
42251 
42252     // TODO - pass on known zero/undef.
42253 
42254     // Aggressively peek through ops to get at the demanded elts.
42255     // TODO: Handle repeated operands.
42256     if (N0 != N1 && !DemandedElts.isAllOnes()) {
42257       SDValue NewN0 = SimplifyMultipleUseDemandedVectorElts(N0, DemandedLHS,
42258                                                             TLO.DAG, Depth + 1);
42259       SDValue NewN1 = SimplifyMultipleUseDemandedVectorElts(N1, DemandedRHS,
42260                                                             TLO.DAG, Depth + 1);
42261       if (NewN0 || NewN1) {
42262         NewN0 = NewN0 ? NewN0 : N0;
42263         NewN1 = NewN1 ? NewN1 : N1;
42264         return TLO.CombineTo(Op,
42265                              TLO.DAG.getNode(Opc, SDLoc(Op), VT, NewN0, NewN1));
42266       }
42267     }
42268     break;
42269   }
42270   case X86ISD::VTRUNC:
42271   case X86ISD::VTRUNCS:
42272   case X86ISD::VTRUNCUS: {
42273     SDValue Src = Op.getOperand(0);
42274     MVT SrcVT = Src.getSimpleValueType();
42275     APInt DemandedSrc = DemandedElts.zextOrTrunc(SrcVT.getVectorNumElements());
42276     APInt SrcUndef, SrcZero;
42277     if (SimplifyDemandedVectorElts(Src, DemandedSrc, SrcUndef, SrcZero, TLO,
42278                                    Depth + 1))
42279       return true;
42280     KnownZero = SrcZero.zextOrTrunc(NumElts);
42281     KnownUndef = SrcUndef.zextOrTrunc(NumElts);
42282     break;
42283   }
42284   case X86ISD::BLENDI: {
42285     SmallVector<int, 16> BlendMask;
42286     DecodeBLENDMask(NumElts, Op.getConstantOperandVal(2), BlendMask);
42287     if (SDValue R = combineBlendOfPermutes(
42288             VT.getSimpleVT(), Op.getOperand(0), Op.getOperand(1), BlendMask,
42289             DemandedElts, TLO.DAG, Subtarget, SDLoc(Op)))
42290       return TLO.CombineTo(Op, R);
42291     break;
42292   }
42293   case X86ISD::BLENDV: {
42294     APInt SelUndef, SelZero;
42295     if (SimplifyDemandedVectorElts(Op.getOperand(0), DemandedElts, SelUndef,
42296                                    SelZero, TLO, Depth + 1))
42297       return true;
42298 
42299     // TODO: Use SelZero to adjust LHS/RHS DemandedElts.
42300     APInt LHSUndef, LHSZero;
42301     if (SimplifyDemandedVectorElts(Op.getOperand(1), DemandedElts, LHSUndef,
42302                                    LHSZero, TLO, Depth + 1))
42303       return true;
42304 
42305     APInt RHSUndef, RHSZero;
42306     if (SimplifyDemandedVectorElts(Op.getOperand(2), DemandedElts, RHSUndef,
42307                                    RHSZero, TLO, Depth + 1))
42308       return true;
42309 
42310     KnownZero = LHSZero & RHSZero;
42311     KnownUndef = LHSUndef & RHSUndef;
42312     break;
42313   }
42314   case X86ISD::VZEXT_MOVL: {
42315     // If upper demanded elements are already zero then we have nothing to do.
42316     SDValue Src = Op.getOperand(0);
42317     APInt DemandedUpperElts = DemandedElts;
42318     DemandedUpperElts.clearLowBits(1);
42319     if (TLO.DAG.MaskedVectorIsZero(Src, DemandedUpperElts, Depth + 1))
42320       return TLO.CombineTo(Op, Src);
42321     break;
42322   }
42323   case X86ISD::VZEXT_LOAD: {
42324     // If upper demanded elements are not demanded then simplify to a
42325     // scalar_to_vector(load()).
42326     MVT SVT = VT.getSimpleVT().getVectorElementType();
42327     if (DemandedElts == 1 && Op.getValue(1).use_empty() && isTypeLegal(SVT)) {
42328       SDLoc DL(Op);
42329       auto *Mem = cast<MemSDNode>(Op);
42330       SDValue Elt = TLO.DAG.getLoad(SVT, DL, Mem->getChain(), Mem->getBasePtr(),
42331                                     Mem->getMemOperand());
42332       SDValue Vec = TLO.DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VT, Elt);
42333       return TLO.CombineTo(Op, TLO.DAG.getBitcast(VT, Vec));
42334     }
42335     break;
42336   }
42337   case X86ISD::VBROADCAST: {
42338     SDValue Src = Op.getOperand(0);
42339     MVT SrcVT = Src.getSimpleValueType();
42340     if (!SrcVT.isVector())
42341       break;
42342     // Don't bother broadcasting if we just need the 0'th element.
42343     if (DemandedElts == 1) {
42344       if (Src.getValueType() != VT)
42345         Src = widenSubVector(VT.getSimpleVT(), Src, false, Subtarget, TLO.DAG,
42346                              SDLoc(Op));
42347       return TLO.CombineTo(Op, Src);
42348     }
42349     APInt SrcUndef, SrcZero;
42350     APInt SrcElts = APInt::getOneBitSet(SrcVT.getVectorNumElements(), 0);
42351     if (SimplifyDemandedVectorElts(Src, SrcElts, SrcUndef, SrcZero, TLO,
42352                                    Depth + 1))
42353       return true;
42354     // Aggressively peek through src to get at the demanded elt.
42355     // TODO - we should do this for all target/faux shuffles ops.
42356     if (SDValue NewSrc = SimplifyMultipleUseDemandedVectorElts(
42357             Src, SrcElts, TLO.DAG, Depth + 1))
42358       return TLO.CombineTo(Op, TLO.DAG.getNode(Opc, SDLoc(Op), VT, NewSrc));
42359     break;
42360   }
42361   case X86ISD::VPERMV:
42362     if (SimplifyDemandedVectorEltsForTargetShuffle(Op, DemandedElts, 0, TLO,
42363                                                    Depth))
42364       return true;
42365     break;
42366   case X86ISD::PSHUFB:
42367   case X86ISD::VPERMV3:
42368   case X86ISD::VPERMILPV:
42369     if (SimplifyDemandedVectorEltsForTargetShuffle(Op, DemandedElts, 1, TLO,
42370                                                    Depth))
42371       return true;
42372     break;
42373   case X86ISD::VPPERM:
42374   case X86ISD::VPERMIL2:
42375     if (SimplifyDemandedVectorEltsForTargetShuffle(Op, DemandedElts, 2, TLO,
42376                                                    Depth))
42377       return true;
42378     break;
42379   }
42380 
42381   // For 256/512-bit ops that are 128/256-bit ops glued together, if we do not
42382   // demand any of the high elements, then narrow the op to 128/256-bits: e.g.
42383   // (op ymm0, ymm1) --> insert undef, (op xmm0, xmm1), 0
42384   if ((VT.is256BitVector() || VT.is512BitVector()) &&
42385       DemandedElts.lshr(NumElts / 2) == 0) {
42386     unsigned SizeInBits = VT.getSizeInBits();
42387     unsigned ExtSizeInBits = SizeInBits / 2;
42388 
42389     // See if 512-bit ops only use the bottom 128-bits.
42390     if (VT.is512BitVector() && DemandedElts.lshr(NumElts / 4) == 0)
42391       ExtSizeInBits = SizeInBits / 4;
42392 
42393     switch (Opc) {
42394       // Scalar broadcast.
42395     case X86ISD::VBROADCAST: {
42396       SDLoc DL(Op);
42397       SDValue Src = Op.getOperand(0);
42398       if (Src.getValueSizeInBits() > ExtSizeInBits)
42399         Src = extractSubVector(Src, 0, TLO.DAG, DL, ExtSizeInBits);
42400       EVT BcstVT = EVT::getVectorVT(*TLO.DAG.getContext(), VT.getScalarType(),
42401                                     ExtSizeInBits / VT.getScalarSizeInBits());
42402       SDValue Bcst = TLO.DAG.getNode(X86ISD::VBROADCAST, DL, BcstVT, Src);
42403       return TLO.CombineTo(Op, insertSubVector(TLO.DAG.getUNDEF(VT), Bcst, 0,
42404                                                TLO.DAG, DL, ExtSizeInBits));
42405     }
42406     case X86ISD::VBROADCAST_LOAD: {
42407       SDLoc DL(Op);
42408       auto *MemIntr = cast<MemIntrinsicSDNode>(Op);
42409       EVT BcstVT = EVT::getVectorVT(*TLO.DAG.getContext(), VT.getScalarType(),
42410                                     ExtSizeInBits / VT.getScalarSizeInBits());
42411       SDVTList Tys = TLO.DAG.getVTList(BcstVT, MVT::Other);
42412       SDValue Ops[] = {MemIntr->getOperand(0), MemIntr->getOperand(1)};
42413       SDValue Bcst = TLO.DAG.getMemIntrinsicNode(
42414           X86ISD::VBROADCAST_LOAD, DL, Tys, Ops, MemIntr->getMemoryVT(),
42415           MemIntr->getMemOperand());
42416       TLO.DAG.makeEquivalentMemoryOrdering(SDValue(MemIntr, 1),
42417                                            Bcst.getValue(1));
42418       return TLO.CombineTo(Op, insertSubVector(TLO.DAG.getUNDEF(VT), Bcst, 0,
42419                                                TLO.DAG, DL, ExtSizeInBits));
42420     }
42421       // Subvector broadcast.
42422     case X86ISD::SUBV_BROADCAST_LOAD: {
42423       auto *MemIntr = cast<MemIntrinsicSDNode>(Op);
42424       EVT MemVT = MemIntr->getMemoryVT();
42425       if (ExtSizeInBits == MemVT.getStoreSizeInBits()) {
42426         SDLoc DL(Op);
42427         SDValue Ld =
42428             TLO.DAG.getLoad(MemVT, DL, MemIntr->getChain(),
42429                             MemIntr->getBasePtr(), MemIntr->getMemOperand());
42430         TLO.DAG.makeEquivalentMemoryOrdering(SDValue(MemIntr, 1),
42431                                              Ld.getValue(1));
42432         return TLO.CombineTo(Op, insertSubVector(TLO.DAG.getUNDEF(VT), Ld, 0,
42433                                                  TLO.DAG, DL, ExtSizeInBits));
42434       } else if ((ExtSizeInBits % MemVT.getStoreSizeInBits()) == 0) {
42435         SDLoc DL(Op);
42436         EVT BcstVT = EVT::getVectorVT(*TLO.DAG.getContext(), VT.getScalarType(),
42437                                       ExtSizeInBits / VT.getScalarSizeInBits());
42438         if (SDValue BcstLd =
42439                 getBROADCAST_LOAD(Opc, DL, BcstVT, MemVT, MemIntr, 0, TLO.DAG))
42440           return TLO.CombineTo(Op,
42441                                insertSubVector(TLO.DAG.getUNDEF(VT), BcstLd, 0,
42442                                                TLO.DAG, DL, ExtSizeInBits));
42443       }
42444       break;
42445     }
42446       // Byte shifts by immediate.
42447     case X86ISD::VSHLDQ:
42448     case X86ISD::VSRLDQ:
42449       // Shift by uniform.
42450     case X86ISD::VSHL:
42451     case X86ISD::VSRL:
42452     case X86ISD::VSRA:
42453       // Shift by immediate.
42454     case X86ISD::VSHLI:
42455     case X86ISD::VSRLI:
42456     case X86ISD::VSRAI: {
42457       SDLoc DL(Op);
42458       SDValue Ext0 =
42459           extractSubVector(Op.getOperand(0), 0, TLO.DAG, DL, ExtSizeInBits);
42460       SDValue ExtOp =
42461           TLO.DAG.getNode(Opc, DL, Ext0.getValueType(), Ext0, Op.getOperand(1));
42462       SDValue UndefVec = TLO.DAG.getUNDEF(VT);
42463       SDValue Insert =
42464           insertSubVector(UndefVec, ExtOp, 0, TLO.DAG, DL, ExtSizeInBits);
42465       return TLO.CombineTo(Op, Insert);
42466     }
42467     case X86ISD::VPERMI: {
42468       // Simplify PERMPD/PERMQ to extract_subvector.
42469       // TODO: This should be done in shuffle combining.
42470       if (VT == MVT::v4f64 || VT == MVT::v4i64) {
42471         SmallVector<int, 4> Mask;
42472         DecodeVPERMMask(NumElts, Op.getConstantOperandVal(1), Mask);
42473         if (isUndefOrEqual(Mask[0], 2) && isUndefOrEqual(Mask[1], 3)) {
42474           SDLoc DL(Op);
42475           SDValue Ext = extractSubVector(Op.getOperand(0), 2, TLO.DAG, DL, 128);
42476           SDValue UndefVec = TLO.DAG.getUNDEF(VT);
42477           SDValue Insert = insertSubVector(UndefVec, Ext, 0, TLO.DAG, DL, 128);
42478           return TLO.CombineTo(Op, Insert);
42479         }
42480       }
42481       break;
42482     }
42483     case X86ISD::VPERM2X128: {
42484       // Simplify VPERM2F128/VPERM2I128 to extract_subvector.
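      // In the low nibble of the immediate, bit 3 zeroes the lane, bit 1
      // selects the source operand and bit 0 selects which 128-bit half of
      // that source is used.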
42485       SDLoc DL(Op);
42486       unsigned LoMask = Op.getConstantOperandVal(2) & 0xF;
42487       if (LoMask & 0x8)
42488         return TLO.CombineTo(
42489             Op, getZeroVector(VT.getSimpleVT(), Subtarget, TLO.DAG, DL));
42490       unsigned EltIdx = (LoMask & 0x1) * (NumElts / 2);
42491       unsigned SrcIdx = (LoMask & 0x2) >> 1;
42492       SDValue ExtOp =
42493           extractSubVector(Op.getOperand(SrcIdx), EltIdx, TLO.DAG, DL, 128);
42494       SDValue UndefVec = TLO.DAG.getUNDEF(VT);
42495       SDValue Insert =
42496           insertSubVector(UndefVec, ExtOp, 0, TLO.DAG, DL, ExtSizeInBits);
42497       return TLO.CombineTo(Op, Insert);
42498     }
42499       // Zero upper elements.
42500     case X86ISD::VZEXT_MOVL:
42501       // Target unary shuffles by immediate:
42502     case X86ISD::PSHUFD:
42503     case X86ISD::PSHUFLW:
42504     case X86ISD::PSHUFHW:
42505     case X86ISD::VPERMILPI:
42506       // (Non-Lane Crossing) Target Shuffles.
42507     case X86ISD::VPERMILPV:
42508     case X86ISD::VPERMIL2:
42509     case X86ISD::PSHUFB:
42510     case X86ISD::UNPCKL:
42511     case X86ISD::UNPCKH:
42512     case X86ISD::BLENDI:
42513       // Integer ops.
42514     case X86ISD::PACKSS:
42515     case X86ISD::PACKUS:
42516     case X86ISD::PCMPEQ:
42517     case X86ISD::PCMPGT:
42518     case X86ISD::PMULUDQ:
42519     case X86ISD::PMULDQ:
42520     case X86ISD::VSHLV:
42521     case X86ISD::VSRLV:
42522     case X86ISD::VSRAV:
42523       // Float ops.
42524     case X86ISD::FMAX:
42525     case X86ISD::FMIN:
42526     case X86ISD::FMAXC:
42527     case X86ISD::FMINC:
42528       // Horizontal Ops.
42529     case X86ISD::HADD:
42530     case X86ISD::HSUB:
42531     case X86ISD::FHADD:
42532     case X86ISD::FHSUB: {
42533       SDLoc DL(Op);
42534       SmallVector<SDValue, 4> Ops;
42535       for (unsigned i = 0, e = Op.getNumOperands(); i != e; ++i) {
42536         SDValue SrcOp = Op.getOperand(i);
42537         EVT SrcVT = SrcOp.getValueType();
42538         assert((!SrcVT.isVector() || SrcVT.getSizeInBits() == SizeInBits) &&
42539                "Unsupported vector size");
42540         Ops.push_back(SrcVT.isVector() ? extractSubVector(SrcOp, 0, TLO.DAG, DL,
42541                                                           ExtSizeInBits)
42542                                        : SrcOp);
42543       }
42544       MVT ExtVT = VT.getSimpleVT();
42545       ExtVT = MVT::getVectorVT(ExtVT.getScalarType(),
42546                                ExtSizeInBits / ExtVT.getScalarSizeInBits());
42547       SDValue ExtOp = TLO.DAG.getNode(Opc, DL, ExtVT, Ops);
42548       SDValue UndefVec = TLO.DAG.getUNDEF(VT);
42549       SDValue Insert =
42550           insertSubVector(UndefVec, ExtOp, 0, TLO.DAG, DL, ExtSizeInBits);
42551       return TLO.CombineTo(Op, Insert);
42552     }
42553     }
42554   }
42555 
42556   // For splats, unless we *only* demand the 0'th element, stop attempts at
42557   // simplification here; we aren't going to improve things and this is
42558   // better than any potential shuffle.
42559   if (!DemandedElts.isOne() && TLO.DAG.isSplatValue(Op, /*AllowUndefs*/false))
42560     return false;
42561 
42562   // Get target/faux shuffle mask.
42563   APInt OpUndef, OpZero;
42564   SmallVector<int, 64> OpMask;
42565   SmallVector<SDValue, 2> OpInputs;
42566   if (!getTargetShuffleInputs(Op, DemandedElts, OpInputs, OpMask, OpUndef,
42567                               OpZero, TLO.DAG, Depth, false))
42568     return false;
42569 
42570   // Shuffle inputs must be the same size as the result.
42571   if (OpMask.size() != (unsigned)NumElts ||
42572       llvm::any_of(OpInputs, [VT](SDValue V) {
42573         return VT.getSizeInBits() != V.getValueSizeInBits() ||
42574                !V.getValueType().isVector();
42575       }))
42576     return false;
42577 
42578   KnownZero = OpZero;
42579   KnownUndef = OpUndef;
42580 
42581   // Check if shuffle mask can be simplified to undef/zero/identity.
42582   int NumSrcs = OpInputs.size();
42583   for (int i = 0; i != NumElts; ++i)
42584     if (!DemandedElts[i])
42585       OpMask[i] = SM_SentinelUndef;
42586 
42587   if (isUndefInRange(OpMask, 0, NumElts)) {
42588     KnownUndef.setAllBits();
42589     return TLO.CombineTo(Op, TLO.DAG.getUNDEF(VT));
42590   }
42591   if (isUndefOrZeroInRange(OpMask, 0, NumElts)) {
42592     KnownZero.setAllBits();
42593     return TLO.CombineTo(
42594         Op, getZeroVector(VT.getSimpleVT(), Subtarget, TLO.DAG, SDLoc(Op)));
42595   }
42596   for (int Src = 0; Src != NumSrcs; ++Src)
42597     if (isSequentialOrUndefInRange(OpMask, 0, NumElts, Src * NumElts))
42598       return TLO.CombineTo(Op, TLO.DAG.getBitcast(VT, OpInputs[Src]));
42599 
42600   // Attempt to simplify inputs.
42601   for (int Src = 0; Src != NumSrcs; ++Src) {
42602     // TODO: Support inputs of different types.
42603     if (OpInputs[Src].getValueType() != VT)
42604       continue;
42605 
42606     int Lo = Src * NumElts;
42607     APInt SrcElts = APInt::getZero(NumElts);
42608     for (int i = 0; i != NumElts; ++i)
42609       if (DemandedElts[i]) {
42610         int M = OpMask[i] - Lo;
42611         if (0 <= M && M < NumElts)
42612           SrcElts.setBit(M);
42613       }
42614 
42615     // TODO - Propagate input undef/zero elts.
42616     APInt SrcUndef, SrcZero;
42617     if (SimplifyDemandedVectorElts(OpInputs[Src], SrcElts, SrcUndef, SrcZero,
42618                                    TLO, Depth + 1))
42619       return true;
42620   }
42621 
42622   // If we don't demand all elements, then attempt to combine to a simpler
42623   // shuffle.
42624   // We need to convert the depth to something combineX86ShufflesRecursively
42625   // can handle - so pretend its Depth == 0 again, and reduce the max depth
42626   // to match. This prevents combineX86ShuffleChain from returning a
42627   // combined shuffle that's the same as the original root, causing an
42628   // infinite loop.
42629   if (!DemandedElts.isAllOnes()) {
42630     assert(Depth < X86::MaxShuffleCombineDepth && "Depth out of range");
42631 
42632     SmallVector<int, 64> DemandedMask(NumElts, SM_SentinelUndef);
42633     for (int i = 0; i != NumElts; ++i)
42634       if (DemandedElts[i])
42635         DemandedMask[i] = i;
42636 
42637     SDValue NewShuffle = combineX86ShufflesRecursively(
42638         {Op}, 0, Op, DemandedMask, {}, 0, X86::MaxShuffleCombineDepth - Depth,
42639         /*HasVarMask*/ false,
42640         /*AllowCrossLaneVarMask*/ true, /*AllowPerLaneVarMask*/ true, TLO.DAG,
42641         Subtarget);
42642     if (NewShuffle)
42643       return TLO.CombineTo(Op, NewShuffle);
42644   }
42645 
42646   return false;
42647 }
42648 
42649 bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
42650     SDValue Op, const APInt &OriginalDemandedBits,
42651     const APInt &OriginalDemandedElts, KnownBits &Known, TargetLoweringOpt &TLO,
42652     unsigned Depth) const {
42653   EVT VT = Op.getValueType();
42654   unsigned BitWidth = OriginalDemandedBits.getBitWidth();
42655   unsigned Opc = Op.getOpcode();
42656   switch(Opc) {
42657   case X86ISD::VTRUNC: {
42658     KnownBits KnownOp;
42659     SDValue Src = Op.getOperand(0);
42660     MVT SrcVT = Src.getSimpleValueType();
42661 
42662     // Simplify the input, using demanded bit information.
42663     APInt TruncMask = OriginalDemandedBits.zext(SrcVT.getScalarSizeInBits());
42664     APInt DemandedElts = OriginalDemandedElts.trunc(SrcVT.getVectorNumElements());
42665     if (SimplifyDemandedBits(Src, TruncMask, DemandedElts, KnownOp, TLO, Depth + 1))
42666       return true;
42667     break;
42668   }
42669   case X86ISD::PMULDQ:
42670   case X86ISD::PMULUDQ: {
42671     // PMULDQ/PMULUDQ only uses lower 32 bits from each vector element.
42672     KnownBits KnownLHS, KnownRHS;
42673     SDValue LHS = Op.getOperand(0);
42674     SDValue RHS = Op.getOperand(1);
42675 
42676     // Don't mask bits on 32-bit AVX512 targets which might lose a broadcast.
42677     // FIXME: Can we bound this better?
42678     APInt DemandedMask = APInt::getLowBitsSet(64, 32);
42679     APInt DemandedMaskLHS = APInt::getAllOnes(64);
42680     APInt DemandedMaskRHS = APInt::getAllOnes(64);
42681 
42682     bool Is32BitAVX512 = !Subtarget.is64Bit() && Subtarget.hasAVX512();
42683     if (!Is32BitAVX512 || !TLO.DAG.isSplatValue(LHS))
42684       DemandedMaskLHS = DemandedMask;
42685     if (!Is32BitAVX512 || !TLO.DAG.isSplatValue(RHS))
42686       DemandedMaskRHS = DemandedMask;
42687 
42688     if (SimplifyDemandedBits(LHS, DemandedMaskLHS, OriginalDemandedElts,
42689                              KnownLHS, TLO, Depth + 1))
42690       return true;
42691     if (SimplifyDemandedBits(RHS, DemandedMaskRHS, OriginalDemandedElts,
42692                              KnownRHS, TLO, Depth + 1))
42693       return true;
42694 
42695     // PMULUDQ(X,1) -> AND(X,(1<<32)-1) 'getZeroExtendInReg'.
42696     KnownRHS = KnownRHS.trunc(32);
42697     if (Opc == X86ISD::PMULUDQ && KnownRHS.isConstant() &&
42698         KnownRHS.getConstant().isOne()) {
42699       SDLoc DL(Op);
42700       SDValue Mask = TLO.DAG.getConstant(DemandedMask, DL, VT);
42701       return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::AND, DL, VT, LHS, Mask));
42702     }
42703 
42704     // Aggressively peek through ops to get at the demanded low bits.
42705     SDValue DemandedLHS = SimplifyMultipleUseDemandedBits(
42706         LHS, DemandedMaskLHS, OriginalDemandedElts, TLO.DAG, Depth + 1);
42707     SDValue DemandedRHS = SimplifyMultipleUseDemandedBits(
42708         RHS, DemandedMaskRHS, OriginalDemandedElts, TLO.DAG, Depth + 1);
42709     if (DemandedLHS || DemandedRHS) {
42710       DemandedLHS = DemandedLHS ? DemandedLHS : LHS;
42711       DemandedRHS = DemandedRHS ? DemandedRHS : RHS;
42712       return TLO.CombineTo(
42713           Op, TLO.DAG.getNode(Opc, SDLoc(Op), VT, DemandedLHS, DemandedRHS));
42714     }
42715     break;
42716   }
42717   case X86ISD::ANDNP: {
42718     KnownBits Known2;
42719     SDValue Op0 = Op.getOperand(0);
42720     SDValue Op1 = Op.getOperand(1);
42721 
42722     if (SimplifyDemandedBits(Op1, OriginalDemandedBits, OriginalDemandedElts,
42723                              Known, TLO, Depth + 1))
42724       return true;
42725 
42726     if (SimplifyDemandedBits(Op0, ~Known.Zero & OriginalDemandedBits,
42727                              OriginalDemandedElts, Known2, TLO, Depth + 1))
42728       return true;
42729 
42730     // If the RHS is a constant, see if we can simplify it.
42731     if (ShrinkDemandedConstant(Op, ~Known2.One & OriginalDemandedBits,
42732                                OriginalDemandedElts, TLO))
42733       return true;
42734 
42735     // ANDNP = (~Op0 & Op1);
42736     Known.One &= Known2.Zero;
42737     Known.Zero |= Known2.One;
42738     break;
42739   }
42740   case X86ISD::VSHLI: {
42741     SDValue Op0 = Op.getOperand(0);
42742 
42743     unsigned ShAmt = Op.getConstantOperandVal(1);
42744     if (ShAmt >= BitWidth)
42745       break;
42746 
42747     APInt DemandedMask = OriginalDemandedBits.lshr(ShAmt);
42748 
42749     // If this is ((X >>u C1) << ShAmt), see if we can simplify this into a
42750     // single shift.  We can do this if the bottom bits (which are shifted
42751     // out) are never demanded.
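    // E.g. ((X >>u 2) << 5) becomes (X << 3) when the low five bits of the
    // result are not demanded.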
42752     if (Op0.getOpcode() == X86ISD::VSRLI &&
42753         OriginalDemandedBits.countr_zero() >= ShAmt) {
42754       unsigned Shift2Amt = Op0.getConstantOperandVal(1);
42755       if (Shift2Amt < BitWidth) {
42756         int Diff = ShAmt - Shift2Amt;
42757         if (Diff == 0)
42758           return TLO.CombineTo(Op, Op0.getOperand(0));
42759 
42760         unsigned NewOpc = Diff < 0 ? X86ISD::VSRLI : X86ISD::VSHLI;
42761         SDValue NewShift = TLO.DAG.getNode(
42762             NewOpc, SDLoc(Op), VT, Op0.getOperand(0),
42763             TLO.DAG.getTargetConstant(std::abs(Diff), SDLoc(Op), MVT::i8));
42764         return TLO.CombineTo(Op, NewShift);
42765       }
42766     }
42767 
42768     // If we are only demanding sign bits then we can use the shift source directly.
42769     unsigned NumSignBits =
42770         TLO.DAG.ComputeNumSignBits(Op0, OriginalDemandedElts, Depth + 1);
42771     unsigned UpperDemandedBits = BitWidth - OriginalDemandedBits.countr_zero();
42772     if (NumSignBits > ShAmt && (NumSignBits - ShAmt) >= UpperDemandedBits)
42773       return TLO.CombineTo(Op, Op0);
42774 
42775     if (SimplifyDemandedBits(Op0, DemandedMask, OriginalDemandedElts, Known,
42776                              TLO, Depth + 1))
42777       return true;
42778 
42779     Known.Zero <<= ShAmt;
42780     Known.One <<= ShAmt;
42781 
42782     // Low bits known zero.
42783     Known.Zero.setLowBits(ShAmt);
42784     return false;
42785   }
42786   case X86ISD::VSRLI: {
42787     unsigned ShAmt = Op.getConstantOperandVal(1);
42788     if (ShAmt >= BitWidth)
42789       break;
42790 
42791     APInt DemandedMask = OriginalDemandedBits << ShAmt;
42792 
42793     if (SimplifyDemandedBits(Op.getOperand(0), DemandedMask,
42794                              OriginalDemandedElts, Known, TLO, Depth + 1))
42795       return true;
42796 
42797     Known.Zero.lshrInPlace(ShAmt);
42798     Known.One.lshrInPlace(ShAmt);
42799 
42800     // High bits known zero.
42801     Known.Zero.setHighBits(ShAmt);
42802     return false;
42803   }
42804   case X86ISD::VSRAI: {
42805     SDValue Op0 = Op.getOperand(0);
42806     SDValue Op1 = Op.getOperand(1);
42807 
42808     unsigned ShAmt = Op1->getAsZExtVal();
42809     if (ShAmt >= BitWidth)
42810       break;
42811 
42812     APInt DemandedMask = OriginalDemandedBits << ShAmt;
42813 
42814     // If we just want the sign bit then we don't need to shift it.
42815     if (OriginalDemandedBits.isSignMask())
42816       return TLO.CombineTo(Op, Op0);
42817 
42818     // fold (VSRAI (VSHLI X, C1), C1) --> X iff NumSignBits(X) > C1
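          // e.g. for v8i16 the common sign_extend_inreg pattern
          // (VSRAI (VSHLI X, 8), 8) reduces to X whenever X already has more than
          // 8 sign bits (illustrative shift amount).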
42819     if (Op0.getOpcode() == X86ISD::VSHLI &&
42820         Op.getOperand(1) == Op0.getOperand(1)) {
42821       SDValue Op00 = Op0.getOperand(0);
42822       unsigned NumSignBits =
42823           TLO.DAG.ComputeNumSignBits(Op00, OriginalDemandedElts);
42824       if (ShAmt < NumSignBits)
42825         return TLO.CombineTo(Op, Op00);
42826     }
42827 
42828     // If any of the demanded bits are produced by the sign extension, we also
42829     // demand the input sign bit.
42830     if (OriginalDemandedBits.countl_zero() < ShAmt)
42831       DemandedMask.setSignBit();
42832 
42833     if (SimplifyDemandedBits(Op0, DemandedMask, OriginalDemandedElts, Known,
42834                              TLO, Depth + 1))
42835       return true;
42836 
42837     Known.Zero.lshrInPlace(ShAmt);
42838     Known.One.lshrInPlace(ShAmt);
42839 
42840     // If the input sign bit is known to be zero, or if none of the top bits
42841     // are demanded, turn this into an unsigned shift right.
42842     if (Known.Zero[BitWidth - ShAmt - 1] ||
42843         OriginalDemandedBits.countl_zero() >= ShAmt)
42844       return TLO.CombineTo(
42845           Op, TLO.DAG.getNode(X86ISD::VSRLI, SDLoc(Op), VT, Op0, Op1));
42846 
42847     // High bits are known one.
42848     if (Known.One[BitWidth - ShAmt - 1])
42849       Known.One.setHighBits(ShAmt);
42850     return false;
42851   }
42852   case X86ISD::BLENDV: {
42853     SDValue Sel = Op.getOperand(0);
42854     SDValue LHS = Op.getOperand(1);
42855     SDValue RHS = Op.getOperand(2);
42856 
42857     APInt SignMask = APInt::getSignMask(BitWidth);
42858     SDValue NewSel = SimplifyMultipleUseDemandedBits(
42859         Sel, SignMask, OriginalDemandedElts, TLO.DAG, Depth + 1);
42860     SDValue NewLHS = SimplifyMultipleUseDemandedBits(
42861         LHS, OriginalDemandedBits, OriginalDemandedElts, TLO.DAG, Depth + 1);
42862     SDValue NewRHS = SimplifyMultipleUseDemandedBits(
42863         RHS, OriginalDemandedBits, OriginalDemandedElts, TLO.DAG, Depth + 1);
42864 
42865     if (NewSel || NewLHS || NewRHS) {
42866       NewSel = NewSel ? NewSel : Sel;
42867       NewLHS = NewLHS ? NewLHS : LHS;
42868       NewRHS = NewRHS ? NewRHS : RHS;
42869       return TLO.CombineTo(Op, TLO.DAG.getNode(X86ISD::BLENDV, SDLoc(Op), VT,
42870                                                NewSel, NewLHS, NewRHS));
42871     }
42872     break;
42873   }
42874   case X86ISD::PEXTRB:
42875   case X86ISD::PEXTRW: {
42876     SDValue Vec = Op.getOperand(0);
42877     auto *CIdx = dyn_cast<ConstantSDNode>(Op.getOperand(1));
42878     MVT VecVT = Vec.getSimpleValueType();
42879     unsigned NumVecElts = VecVT.getVectorNumElements();
42880 
42881     if (CIdx && CIdx->getAPIntValue().ult(NumVecElts)) {
42882       unsigned Idx = CIdx->getZExtValue();
42883       unsigned VecBitWidth = VecVT.getScalarSizeInBits();
42884 
42885       // If we demand no bits from the vector then we must have demanded
42886       // bits from the implicit zext - simplify to zero.
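            // e.g. PEXTRB zero-extends the extracted byte into the wider scalar
            // result, so if only bits above bit 7 are demanded the value is zero.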
42887       APInt DemandedVecBits = OriginalDemandedBits.trunc(VecBitWidth);
42888       if (DemandedVecBits == 0)
42889         return TLO.CombineTo(Op, TLO.DAG.getConstant(0, SDLoc(Op), VT));
42890 
42891       APInt KnownUndef, KnownZero;
42892       APInt DemandedVecElts = APInt::getOneBitSet(NumVecElts, Idx);
42893       if (SimplifyDemandedVectorElts(Vec, DemandedVecElts, KnownUndef,
42894                                      KnownZero, TLO, Depth + 1))
42895         return true;
42896 
42897       KnownBits KnownVec;
42898       if (SimplifyDemandedBits(Vec, DemandedVecBits, DemandedVecElts,
42899                                KnownVec, TLO, Depth + 1))
42900         return true;
42901 
42902       if (SDValue V = SimplifyMultipleUseDemandedBits(
42903               Vec, DemandedVecBits, DemandedVecElts, TLO.DAG, Depth + 1))
42904         return TLO.CombineTo(
42905             Op, TLO.DAG.getNode(Opc, SDLoc(Op), VT, V, Op.getOperand(1)));
42906 
42907       Known = KnownVec.zext(BitWidth);
42908       return false;
42909     }
42910     break;
42911   }
42912   case X86ISD::PINSRB:
42913   case X86ISD::PINSRW: {
42914     SDValue Vec = Op.getOperand(0);
42915     SDValue Scl = Op.getOperand(1);
42916     auto *CIdx = dyn_cast<ConstantSDNode>(Op.getOperand(2));
42917     MVT VecVT = Vec.getSimpleValueType();
42918 
42919     if (CIdx && CIdx->getAPIntValue().ult(VecVT.getVectorNumElements())) {
42920       unsigned Idx = CIdx->getZExtValue();
42921       if (!OriginalDemandedElts[Idx])
42922         return TLO.CombineTo(Op, Vec);
42923 
42924       KnownBits KnownVec;
42925       APInt DemandedVecElts(OriginalDemandedElts);
42926       DemandedVecElts.clearBit(Idx);
42927       if (SimplifyDemandedBits(Vec, OriginalDemandedBits, DemandedVecElts,
42928                                KnownVec, TLO, Depth + 1))
42929         return true;
42930 
42931       KnownBits KnownScl;
42932       unsigned NumSclBits = Scl.getScalarValueSizeInBits();
42933       APInt DemandedSclBits = OriginalDemandedBits.zext(NumSclBits);
42934       if (SimplifyDemandedBits(Scl, DemandedSclBits, KnownScl, TLO, Depth + 1))
42935         return true;
42936 
42937       KnownScl = KnownScl.trunc(VecVT.getScalarSizeInBits());
42938       Known = KnownVec.intersectWith(KnownScl);
42939       return false;
42940     }
42941     break;
42942   }
42943   case X86ISD::PACKSS:
42944     // PACKSS saturates to MIN/MAX integer values. So if we just want the
42945     // sign bit then we can just ask for the source operands' sign bits.
42946     // TODO - add known bits handling.
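          // e.g. when packing v8i16 sources to v16i8, signed saturation preserves
          // the sign, so each i8 result's sign bit equals the sign bit of its i16
          // source element.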
42947     if (OriginalDemandedBits.isSignMask()) {
42948       APInt DemandedLHS, DemandedRHS;
42949       getPackDemandedElts(VT, OriginalDemandedElts, DemandedLHS, DemandedRHS);
42950 
42951       KnownBits KnownLHS, KnownRHS;
42952       APInt SignMask = APInt::getSignMask(BitWidth * 2);
42953       if (SimplifyDemandedBits(Op.getOperand(0), SignMask, DemandedLHS,
42954                                KnownLHS, TLO, Depth + 1))
42955         return true;
42956       if (SimplifyDemandedBits(Op.getOperand(1), SignMask, DemandedRHS,
42957                                KnownRHS, TLO, Depth + 1))
42958         return true;
42959 
42960       // Attempt to avoid multi-use ops if we don't need anything from them.
42961       SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
42962           Op.getOperand(0), SignMask, DemandedLHS, TLO.DAG, Depth + 1);
42963       SDValue DemandedOp1 = SimplifyMultipleUseDemandedBits(
42964           Op.getOperand(1), SignMask, DemandedRHS, TLO.DAG, Depth + 1);
42965       if (DemandedOp0 || DemandedOp1) {
42966         SDValue Op0 = DemandedOp0 ? DemandedOp0 : Op.getOperand(0);
42967         SDValue Op1 = DemandedOp1 ? DemandedOp1 : Op.getOperand(1);
42968         return TLO.CombineTo(Op, TLO.DAG.getNode(Opc, SDLoc(Op), VT, Op0, Op1));
42969       }
42970     }
42971     // TODO - add general PACKSS/PACKUS SimplifyDemandedBits support.
42972     break;
42973   case X86ISD::VBROADCAST: {
42974     SDValue Src = Op.getOperand(0);
42975     MVT SrcVT = Src.getSimpleValueType();
42976     APInt DemandedElts = APInt::getOneBitSet(
42977         SrcVT.isVector() ? SrcVT.getVectorNumElements() : 1, 0);
42978     if (SimplifyDemandedBits(Src, OriginalDemandedBits, DemandedElts, Known,
42979                              TLO, Depth + 1))
42980       return true;
42981     // If we don't need the upper bits, attempt to narrow the broadcast source.
42982     // Don't attempt this on AVX512 as it might affect broadcast folding.
42983     // TODO: Should we attempt this for i32/i16 splats? They tend to be slower.
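          // e.g. a v2i64 broadcast of a scalar i64 whose upper 32 bits are not
          // demanded can be rebuilt as a v4i32 broadcast of the truncated i32 and
          // bitcast back to v2i64.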
42984     if ((BitWidth == 64) && SrcVT.isScalarInteger() && !Subtarget.hasAVX512() &&
42985         OriginalDemandedBits.countl_zero() >= (BitWidth / 2) &&
42986         Src->hasOneUse()) {
42987       MVT NewSrcVT = MVT::getIntegerVT(BitWidth / 2);
42988       SDValue NewSrc =
42989           TLO.DAG.getNode(ISD::TRUNCATE, SDLoc(Src), NewSrcVT, Src);
42990       MVT NewVT = MVT::getVectorVT(NewSrcVT, VT.getVectorNumElements() * 2);
42991       SDValue NewBcst =
42992           TLO.DAG.getNode(X86ISD::VBROADCAST, SDLoc(Op), NewVT, NewSrc);
42993       return TLO.CombineTo(Op, TLO.DAG.getBitcast(VT, NewBcst));
42994     }
42995     break;
42996   }
42997   case X86ISD::PCMPGT:
42998     // icmp sgt(0, R) == ashr(R, BitWidth-1).
42999     // iff we only need the sign bit then we can use R directly.
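          // e.g. (PCMPGT 0, X) is all-ones exactly when X is negative, so its sign
          // bit always matches the sign bit of X.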
43000     if (OriginalDemandedBits.isSignMask() &&
43001         ISD::isBuildVectorAllZeros(Op.getOperand(0).getNode()))
43002       return TLO.CombineTo(Op, Op.getOperand(1));
43003     break;
43004   case X86ISD::MOVMSK: {
43005     SDValue Src = Op.getOperand(0);
43006     MVT SrcVT = Src.getSimpleValueType();
43007     unsigned SrcBits = SrcVT.getScalarSizeInBits();
43008     unsigned NumElts = SrcVT.getVectorNumElements();
43009 
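          // MOVMSK packs the sign bit of each source element into the low NumElts
          // bits of the scalar result and zeroes the remaining high bits (e.g. a
          // v4f32 source yields a 4-bit mask in bits [3:0]).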
43010     // If we don't need the sign bits at all just return zero.
43011     if (OriginalDemandedBits.countr_zero() >= NumElts)
43012       return TLO.CombineTo(Op, TLO.DAG.getConstant(0, SDLoc(Op), VT));
43013 
43014     // See if we only demand bits from the lower 128-bit vector.
43015     if (SrcVT.is256BitVector() &&
43016         OriginalDemandedBits.getActiveBits() <= (NumElts / 2)) {
43017       SDValue NewSrc = extract128BitVector(Src, 0, TLO.DAG, SDLoc(Src));
43018       return TLO.CombineTo(Op, TLO.DAG.getNode(Opc, SDLoc(Op), VT, NewSrc));
43019     }
43020 
43021     // Only demand the vector elements of the sign bits we need.
43022     APInt KnownUndef, KnownZero;
43023     APInt DemandedElts = OriginalDemandedBits.zextOrTrunc(NumElts);
43024     if (SimplifyDemandedVectorElts(Src, DemandedElts, KnownUndef, KnownZero,
43025                                    TLO, Depth + 1))
43026       return true;
43027 
43028     Known.Zero = KnownZero.zext(BitWidth);
43029     Known.Zero.setHighBits(BitWidth - NumElts);
43030 
43031     // MOVMSK only uses the MSB from each vector element.
43032     KnownBits KnownSrc;
43033     APInt DemandedSrcBits = APInt::getSignMask(SrcBits);
43034     if (SimplifyDemandedBits(Src, DemandedSrcBits, DemandedElts, KnownSrc, TLO,
43035                              Depth + 1))
43036       return true;
43037 
43038     if (KnownSrc.One[SrcBits - 1])
43039       Known.One.setLowBits(NumElts);
43040     else if (KnownSrc.Zero[SrcBits - 1])
43041       Known.Zero.setLowBits(NumElts);
43042 
43043     // Attempt to avoid multi-use ops if we don't need anything from them.
43044     if (SDValue NewSrc = SimplifyMultipleUseDemandedBits(
43045             Src, DemandedSrcBits, DemandedElts, TLO.DAG, Depth + 1))
43046       return TLO.CombineTo(Op, TLO.DAG.getNode(Opc, SDLoc(Op), VT, NewSrc));
43047     return false;
43048   }
43049   case X86ISD::TESTP: {
43050     SDValue Op0 = Op.getOperand(0);
43051     SDValue Op1 = Op.getOperand(1);
43052     MVT OpVT = Op0.getSimpleValueType();
43053     assert((OpVT.getVectorElementType() == MVT::f32 ||
43054             OpVT.getVectorElementType() == MVT::f64) &&
43055            "Illegal vector type for X86ISD::TESTP");
43056 
43057     // TESTPS/TESTPD only demands the sign bits of ALL the elements.
43058     KnownBits KnownSrc;
43059     APInt SignMask = APInt::getSignMask(OpVT.getScalarSizeInBits());
43060     bool AssumeSingleUse = (Op0 == Op1) && Op->isOnlyUserOf(Op0.getNode());
43061     return SimplifyDemandedBits(Op0, SignMask, KnownSrc, TLO, Depth + 1,
43062                                 AssumeSingleUse) ||
43063            SimplifyDemandedBits(Op1, SignMask, KnownSrc, TLO, Depth + 1,
43064                                 AssumeSingleUse);
43065   }
43066   case X86ISD::CMOV: {
43067     KnownBits Known2;
43068     if (SimplifyDemandedBits(Op.getOperand(1), OriginalDemandedBits,
43069                              OriginalDemandedElts, Known2, TLO, Depth + 1))
43070       return true;
43071     if (SimplifyDemandedBits(Op.getOperand(0), OriginalDemandedBits,
43072                              OriginalDemandedElts, Known, TLO, Depth + 1))
43073       return true;
43074 
43075     // Only known if known in both the LHS and RHS.
43076     Known = Known.intersectWith(Known2);
43077     break;
43078   }
43079   case X86ISD::BEXTR:
43080   case X86ISD::BEXTRI: {
43081     SDValue Op0 = Op.getOperand(0);
43082     SDValue Op1 = Op.getOperand(1);
43083 
43084     // Only bottom 16-bits of the control bits are required.
43085     if (auto *Cst1 = dyn_cast<ConstantSDNode>(Op1)) {
43086       // NOTE: SimplifyDemandedBits won't do this for constants.
43087       uint64_t Val1 = Cst1->getZExtValue();
43088       uint64_t MaskedVal1 = Val1 & 0xFFFF;
43089       if (Opc == X86ISD::BEXTR && MaskedVal1 != Val1) {
43090         SDLoc DL(Op);
43091         return TLO.CombineTo(
43092             Op, TLO.DAG.getNode(X86ISD::BEXTR, DL, VT, Op0,
43093                                 TLO.DAG.getConstant(MaskedVal1, DL, VT)));
43094       }
43095 
43096       unsigned Shift = Cst1->getAPIntValue().extractBitsAsZExtValue(8, 0);
43097       unsigned Length = Cst1->getAPIntValue().extractBitsAsZExtValue(8, 8);
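            // e.g. (illustrative) a control value of 0x0408 gives Shift == 8 and
            // Length == 4, extracting bits [11:8] of Op0 into the low bits of the
            // zero-extended result.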
43098 
43099       // If the length is 0, the result is 0.
43100       if (Length == 0) {
43101         Known.setAllZero();
43102         return false;
43103       }
43104 
43105       if ((Shift + Length) <= BitWidth) {
43106         APInt DemandedMask = APInt::getBitsSet(BitWidth, Shift, Shift + Length);
43107         if (SimplifyDemandedBits(Op0, DemandedMask, Known, TLO, Depth + 1))
43108           return true;
43109 
43110         Known = Known.extractBits(Length, Shift);
43111         Known = Known.zextOrTrunc(BitWidth);
43112         return false;
43113       }
43114     } else {
43115       assert(Opc == X86ISD::BEXTR && "Unexpected opcode!");
43116       KnownBits Known1;
43117       APInt DemandedMask(APInt::getLowBitsSet(BitWidth, 16));
43118       if (SimplifyDemandedBits(Op1, DemandedMask, Known1, TLO, Depth + 1))
43119         return true;
43120 
43121       // If the length is 0, replace with 0.
43122       KnownBits LengthBits = Known1.extractBits(8, 8);
43123       if (LengthBits.isZero())
43124         return TLO.CombineTo(Op, TLO.DAG.getConstant(0, SDLoc(Op), VT));
43125     }
43126 
43127     break;
43128   }
43129   case X86ISD::PDEP: {
43130     SDValue Op0 = Op.getOperand(0);
43131     SDValue Op1 = Op.getOperand(1);
43132 
43133     unsigned DemandedBitsLZ = OriginalDemandedBits.countl_zero();
43134     APInt LoMask = APInt::getLowBitsSet(BitWidth, BitWidth - DemandedBitsLZ);
43135 
43136     // If the demanded bits have leading zeroes, we don't demand those from the
43137     // mask.
43138     if (SimplifyDemandedBits(Op1, LoMask, Known, TLO, Depth + 1))
43139       return true;
43140 
43141     // The number of possible 1s in the mask determines the number of LSBs of
43142     // operand 0 used. Undemanded bits from the mask don't matter so filter
43143     // them before counting.
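          // e.g. (illustrative) with a mask of 0b1010, PDEP deposits the two lowest
          // bits of Op0 into bit positions 1 and 3, so only the low 2 bits of Op0
          // are demanded.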
43144     KnownBits Known2;
43145     uint64_t Count = (~Known.Zero & LoMask).popcount();
43146     APInt DemandedMask(APInt::getLowBitsSet(BitWidth, Count));
43147     if (SimplifyDemandedBits(Op0, DemandedMask, Known2, TLO, Depth + 1))
43148       return true;
43149 
43150     // Zeroes are retained from the mask, but not ones.
43151     Known.One.clearAllBits();
43152     // The result will have at least as many trailing zeros as the non-mask
43153     // operand since bits can only map to the same or higher bit position.
43154     Known.Zero.setLowBits(Known2.countMinTrailingZeros());
43155     return false;
43156   }
43157   }
43158 
43159   return TargetLowering::SimplifyDemandedBitsForTargetNode(
43160       Op, OriginalDemandedBits, OriginalDemandedElts, Known, TLO, Depth);
43161 }
43162 
43163 SDValue X86TargetLowering::SimplifyMultipleUseDemandedBitsForTargetNode(
43164     SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
43165     SelectionDAG &DAG, unsigned Depth) const {
43166   int NumElts = DemandedElts.getBitWidth();
43167   unsigned Opc = Op.getOpcode();
43168   EVT VT = Op.getValueType();
43169 
43170   switch (Opc) {
43171   case X86ISD::PINSRB:
43172   case X86ISD::PINSRW: {
43173     // If we don't demand the inserted element, return the base vector.
43174     SDValue Vec = Op.getOperand(0);
43175     auto *CIdx = dyn_cast<ConstantSDNode>(Op.getOperand(2));
43176     MVT VecVT = Vec.getSimpleValueType();
43177     if (CIdx && CIdx->getAPIntValue().ult(VecVT.getVectorNumElements()) &&
43178         !DemandedElts[CIdx->getZExtValue()])
43179       return Vec;
43180     break;
43181   }
43182   case X86ISD::VSHLI: {
43183     // If we are only demanding sign bits then we can use the shift source
43184     // directly.
43185     SDValue Op0 = Op.getOperand(0);
43186     unsigned ShAmt = Op.getConstantOperandVal(1);
43187     unsigned BitWidth = DemandedBits.getBitWidth();
43188     unsigned NumSignBits = DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
43189     unsigned UpperDemandedBits = BitWidth - DemandedBits.countr_zero();
43190     if (NumSignBits > ShAmt && (NumSignBits - ShAmt) >= UpperDemandedBits)
43191       return Op0;
43192     break;
43193   }
43194   case X86ISD::VSRAI:
43195     // iff we only need the sign bit then we can use the source directly.
43196     // TODO: generalize where we only demand extended signbits.
43197     if (DemandedBits.isSignMask())
43198       return Op.getOperand(0);
43199     break;
43200   case X86ISD::PCMPGT:
43201     // icmp sgt(0, R) == ashr(R, BitWidth-1).
43202     // iff we only need the sign bit then we can use R directly.
43203     if (DemandedBits.isSignMask() &&
43204         ISD::isBuildVectorAllZeros(Op.getOperand(0).getNode()))
43205       return Op.getOperand(1);
43206     break;
43207   case X86ISD::BLENDV: {
43208     // BLENDV: Cond (MSB) ? LHS : RHS
43209     SDValue Cond = Op.getOperand(0);
43210     SDValue LHS = Op.getOperand(1);
43211     SDValue RHS = Op.getOperand(2);
43212 
43213     KnownBits CondKnown = DAG.computeKnownBits(Cond, DemandedElts, Depth + 1);
43214     if (CondKnown.isNegative())
43215       return LHS;
43216     if (CondKnown.isNonNegative())
43217       return RHS;
43218     break;
43219   }
43220   case X86ISD::ANDNP: {
43221     // ANDNP = (~LHS & RHS);
43222     SDValue LHS = Op.getOperand(0);
43223     SDValue RHS = Op.getOperand(1);
43224 
43225     KnownBits LHSKnown = DAG.computeKnownBits(LHS, DemandedElts, Depth + 1);
43226     KnownBits RHSKnown = DAG.computeKnownBits(RHS, DemandedElts, Depth + 1);
43227 
43228     // If all of the demanded bits are known 0 on LHS and known 0 on RHS, then
43229     // the (inverted) LHS bits cannot contribute to the result of the 'andn' in
43230     // this context, so return RHS.
43231     if (DemandedBits.isSubsetOf(RHSKnown.Zero | LHSKnown.Zero))
43232       return RHS;
43233     break;
43234   }
43235   }
43236 
43237   APInt ShuffleUndef, ShuffleZero;
43238   SmallVector<int, 16> ShuffleMask;
43239   SmallVector<SDValue, 2> ShuffleOps;
43240   if (getTargetShuffleInputs(Op, DemandedElts, ShuffleOps, ShuffleMask,
43241                              ShuffleUndef, ShuffleZero, DAG, Depth, false)) {
43242     // If all the demanded elts are from one operand and are inline,
43243     // then we can use the operand directly.
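          // e.g. a v4i32 shuffle with mask <0,1,u,u>, where only the first two
          // elements are demanded, reads operand 0 in place and can be replaced by
          // operand 0 directly.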
43244     int NumOps = ShuffleOps.size();
43245     if (ShuffleMask.size() == (unsigned)NumElts &&
43246         llvm::all_of(ShuffleOps, [VT](SDValue V) {
43247           return VT.getSizeInBits() == V.getValueSizeInBits();
43248         })) {
43249 
43250       if (DemandedElts.isSubsetOf(ShuffleUndef))
43251         return DAG.getUNDEF(VT);
43252       if (DemandedElts.isSubsetOf(ShuffleUndef | ShuffleZero))
43253         return getZeroVector(VT.getSimpleVT(), Subtarget, DAG, SDLoc(Op));
43254 
43255       // Bitmask that indicates which ops have only been accessed 'inline'.
43256       APInt IdentityOp = APInt::getAllOnes(NumOps);
43257       for (int i = 0; i != NumElts; ++i) {
43258         int M = ShuffleMask[i];
43259         if (!DemandedElts[i] || ShuffleUndef[i])
43260           continue;
43261         int OpIdx = M / NumElts;
43262         int EltIdx = M % NumElts;
43263         if (M < 0 || EltIdx != i) {
43264           IdentityOp.clearAllBits();
43265           break;
43266         }
43267         IdentityOp &= APInt::getOneBitSet(NumOps, OpIdx);
43268         if (IdentityOp == 0)
43269           break;
43270       }
43271       assert((IdentityOp == 0 || IdentityOp.popcount() == 1) &&
43272              "Multiple identity shuffles detected");
43273 
43274       if (IdentityOp != 0)
43275         return DAG.getBitcast(VT, ShuffleOps[IdentityOp.countr_zero()]);
43276     }
43277   }
43278 
43279   return TargetLowering::SimplifyMultipleUseDemandedBitsForTargetNode(
43280       Op, DemandedBits, DemandedElts, DAG, Depth);
43281 }
43282 
43283 bool X86TargetLowering::isGuaranteedNotToBeUndefOrPoisonForTargetNode(
43284     SDValue Op, const APInt &DemandedElts, const SelectionDAG &DAG,
43285     bool PoisonOnly, unsigned Depth) const {
43286   unsigned NumElts = DemandedElts.getBitWidth();
43287 
43288   switch (Op.getOpcode()) {
43289   case X86ISD::PSHUFD:
43290   case X86ISD::VPERMILPI: {
43291     SmallVector<int, 8> Mask;
43292     SmallVector<SDValue, 2> Ops;
43293     if (getTargetShuffleMask(Op, true, Ops, Mask)) {
43294       SmallVector<APInt, 2> DemandedSrcElts(Ops.size(),
43295                                             APInt::getZero(NumElts));
43296       for (auto M : enumerate(Mask)) {
43297         if (!DemandedElts[M.index()] || M.value() == SM_SentinelZero)
43298           continue;
43299         if (M.value() == SM_SentinelUndef)
43300           return false;
43301         assert(0 <= M.value() && M.value() < (int)(Ops.size() * NumElts) &&
43302                "Shuffle mask index out of range");
43303         DemandedSrcElts[M.value() / NumElts].setBit(M.value() % NumElts);
43304       }
43305       for (auto Op : enumerate(Ops))
43306         if (!DemandedSrcElts[Op.index()].isZero() &&
43307             !DAG.isGuaranteedNotToBeUndefOrPoison(
43308                 Op.value(), DemandedSrcElts[Op.index()], PoisonOnly, Depth + 1))
43309           return false;
43310       return true;
43311     }
43312     break;
43313   }
43314   }
43315   return TargetLowering::isGuaranteedNotToBeUndefOrPoisonForTargetNode(
43316       Op, DemandedElts, DAG, PoisonOnly, Depth);
43317 }
43318 
43319 bool X86TargetLowering::canCreateUndefOrPoisonForTargetNode(
43320     SDValue Op, const APInt &DemandedElts, const SelectionDAG &DAG,
43321     bool PoisonOnly, bool ConsiderFlags, unsigned Depth) const {
43322 
43323   switch (Op.getOpcode()) {
43324   // SSE vector multiplies are either inbounds or saturate.
43325   case X86ISD::VPMADDUBSW:
43326   case X86ISD::VPMADDWD:
43327   // SSE vector shifts handle out of bounds shift amounts.
43328   case X86ISD::VSHLI:
43329   case X86ISD::VSRLI:
43330   case X86ISD::VSRAI:
43331     return false;
43332   case X86ISD::PSHUFD:
43333   case X86ISD::VPERMILPI:
43334   case X86ISD::UNPCKH:
43335   case X86ISD::UNPCKL:
43336     return false;
43337   }
43338   return TargetLowering::canCreateUndefOrPoisonForTargetNode(
43339       Op, DemandedElts, DAG, PoisonOnly, ConsiderFlags, Depth);
43340 }
43341 
43342 bool X86TargetLowering::isSplatValueForTargetNode(SDValue Op,
43343                                                   const APInt &DemandedElts,
43344                                                   APInt &UndefElts,
43345                                                   const SelectionDAG &DAG,
43346                                                   unsigned Depth) const {
43347   unsigned NumElts = DemandedElts.getBitWidth();
43348   unsigned Opc = Op.getOpcode();
43349 
43350   switch (Opc) {
43351   case X86ISD::VBROADCAST:
43352   case X86ISD::VBROADCAST_LOAD:
43353     UndefElts = APInt::getZero(NumElts);
43354     return true;
43355   }
43356 
43357   return TargetLowering::isSplatValueForTargetNode(Op, DemandedElts, UndefElts,
43358                                                    DAG, Depth);
43359 }
43360 
43361 // Helper to peek through bitops/trunc/setcc to determine size of source vector.
43362 // Allows combineBitcastvxi1 to determine what size vector generated a <X x i1>.
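      // e.g. a v8i1 source of the form (trunc (v8i32 X)) reports a 256-bit source
      // vector when truncation is allowed.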
43363 static bool checkBitcastSrcVectorSize(SDValue Src, unsigned Size,
43364                                       bool AllowTruncate) {
43365   switch (Src.getOpcode()) {
43366   case ISD::TRUNCATE:
43367     if (!AllowTruncate)
43368       return false;
43369     [[fallthrough]];
43370   case ISD::SETCC:
43371     return Src.getOperand(0).getValueSizeInBits() == Size;
43372   case ISD::FREEZE:
43373     return checkBitcastSrcVectorSize(Src.getOperand(0), Size, AllowTruncate);
43374   case ISD::AND:
43375   case ISD::XOR:
43376   case ISD::OR:
43377     return checkBitcastSrcVectorSize(Src.getOperand(0), Size, AllowTruncate) &&
43378            checkBitcastSrcVectorSize(Src.getOperand(1), Size, AllowTruncate);
43379   case ISD::SELECT:
43380   case ISD::VSELECT:
43381     return Src.getOperand(0).getScalarValueSizeInBits() == 1 &&
43382            checkBitcastSrcVectorSize(Src.getOperand(1), Size, AllowTruncate) &&
43383            checkBitcastSrcVectorSize(Src.getOperand(2), Size, AllowTruncate);
43384   case ISD::BUILD_VECTOR:
43385     return ISD::isBuildVectorAllZeros(Src.getNode()) ||
43386            ISD::isBuildVectorAllOnes(Src.getNode());
43387   }
43388   return false;
43389 }
43390 
43391 // Helper to flip between AND/OR/XOR opcodes and their X86ISD FP equivalents.
43392 static unsigned getAltBitOpcode(unsigned Opcode) {
43393   switch(Opcode) {
43394   // clang-format off
43395   case ISD::AND: return X86ISD::FAND;
43396   case ISD::OR: return X86ISD::FOR;
43397   case ISD::XOR: return X86ISD::FXOR;
43398   case X86ISD::ANDNP: return X86ISD::FANDN;
43399   // clang-format on
43400   }
43401   llvm_unreachable("Unknown bitwise opcode");
43402 }
43403 
43404 // Helper to adjust v4i32 MOVMSK expansion to work with SSE1-only targets.
43405 static SDValue adjustBitcastSrcVectorSSE1(SelectionDAG &DAG, SDValue Src,
43406                                           const SDLoc &DL) {
43407   EVT SrcVT = Src.getValueType();
43408   if (SrcVT != MVT::v4i1)
43409     return SDValue();
43410 
43411   switch (Src.getOpcode()) {
43412   case ISD::SETCC:
43413     if (Src.getOperand(0).getValueType() == MVT::v4i32 &&
43414         ISD::isBuildVectorAllZeros(Src.getOperand(1).getNode()) &&
43415         cast<CondCodeSDNode>(Src.getOperand(2))->get() == ISD::SETLT) {
43416       SDValue Op0 = Src.getOperand(0);
43417       if (ISD::isNormalLoad(Op0.getNode()))
43418         return DAG.getBitcast(MVT::v4f32, Op0);
43419       if (Op0.getOpcode() == ISD::BITCAST &&
43420           Op0.getOperand(0).getValueType() == MVT::v4f32)
43421         return Op0.getOperand(0);
43422     }
43423     break;
43424   case ISD::AND:
43425   case ISD::XOR:
43426   case ISD::OR: {
43427     SDValue Op0 = adjustBitcastSrcVectorSSE1(DAG, Src.getOperand(0), DL);
43428     SDValue Op1 = adjustBitcastSrcVectorSSE1(DAG, Src.getOperand(1), DL);
43429     if (Op0 && Op1)
43430       return DAG.getNode(getAltBitOpcode(Src.getOpcode()), DL, MVT::v4f32, Op0,
43431                          Op1);
43432     break;
43433   }
43434   }
43435   return SDValue();
43436 }
43437 
43438 // Helper to push sign extension of vXi1 SETCC result through bitops.
43439 static SDValue signExtendBitcastSrcVector(SelectionDAG &DAG, EVT SExtVT,
43440                                           SDValue Src, const SDLoc &DL) {
43441   switch (Src.getOpcode()) {
43442   case ISD::SETCC:
43443   case ISD::FREEZE:
43444   case ISD::TRUNCATE:
43445   case ISD::BUILD_VECTOR:
43446     return DAG.getNode(ISD::SIGN_EXTEND, DL, SExtVT, Src);
43447   case ISD::AND:
43448   case ISD::XOR:
43449   case ISD::OR:
43450     return DAG.getNode(
43451         Src.getOpcode(), DL, SExtVT,
43452         signExtendBitcastSrcVector(DAG, SExtVT, Src.getOperand(0), DL),
43453         signExtendBitcastSrcVector(DAG, SExtVT, Src.getOperand(1), DL));
43454   case ISD::SELECT:
43455   case ISD::VSELECT:
43456     return DAG.getSelect(
43457         DL, SExtVT, Src.getOperand(0),
43458         signExtendBitcastSrcVector(DAG, SExtVT, Src.getOperand(1), DL),
43459         signExtendBitcastSrcVector(DAG, SExtVT, Src.getOperand(2), DL));
43460   }
43461   llvm_unreachable("Unexpected node type for vXi1 sign extension");
43462 }
43463 
43464 // Try to match patterns such as
43465 // (i16 bitcast (v16i1 x))
43466 // ->
43467 // (i16 movmsk (v16i8 sext (v16i1 x)))
43468 // before the illegal vector is scalarized on subtargets that don't have legal
43469 // vxi1 types.
43470 static SDValue combineBitcastvxi1(SelectionDAG &DAG, EVT VT, SDValue Src,
43471                                   const SDLoc &DL,
43472                                   const X86Subtarget &Subtarget) {
43473   EVT SrcVT = Src.getValueType();
43474   if (!SrcVT.isSimple() || SrcVT.getScalarType() != MVT::i1)
43475     return SDValue();
43476 
43477   // Recognize the IR pattern for the movmsk intrinsic under SSE1 before type
43478   // legalization destroys the v4i32 type.
43479   if (Subtarget.hasSSE1() && !Subtarget.hasSSE2()) {
43480     if (SDValue V = adjustBitcastSrcVectorSSE1(DAG, Src, DL)) {
43481       V = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32,
43482                       DAG.getBitcast(MVT::v4f32, V));
43483       return DAG.getZExtOrTrunc(V, DL, VT);
43484     }
43485   }
43486 
43487   // If the input is a truncate from v16i8 or v32i8 go ahead and use a
43488   // movmskb even with avx512. This will be better than truncating to vXi1 and
43489   // using a kmov. This can especially help KNL if the input is a v16i8/v32i8
43490   // vpcmpeqb/vpcmpgtb.
43491   bool PreferMovMsk = Src.getOpcode() == ISD::TRUNCATE && Src.hasOneUse() &&
43492                       (Src.getOperand(0).getValueType() == MVT::v16i8 ||
43493                        Src.getOperand(0).getValueType() == MVT::v32i8 ||
43494                        Src.getOperand(0).getValueType() == MVT::v64i8);
43495 
43496   // Prefer movmsk for AVX512 for (bitcast (setlt X, 0)) which can be handled
43497   // directly with vpmovmskb/vmovmskps/vmovmskpd.
43498   if (Src.getOpcode() == ISD::SETCC && Src.hasOneUse() &&
43499       cast<CondCodeSDNode>(Src.getOperand(2))->get() == ISD::SETLT &&
43500       ISD::isBuildVectorAllZeros(Src.getOperand(1).getNode())) {
43501     EVT CmpVT = Src.getOperand(0).getValueType();
43502     EVT EltVT = CmpVT.getVectorElementType();
43503     if (CmpVT.getSizeInBits() <= 256 &&
43504         (EltVT == MVT::i8 || EltVT == MVT::i32 || EltVT == MVT::i64))
43505       PreferMovMsk = true;
43506   }
43507 
43508   // With AVX512 vxi1 types are legal and we prefer using k-regs.
43509   // MOVMSK is supported in SSE2 or later.
43510   if (!Subtarget.hasSSE2() || (Subtarget.hasAVX512() && !PreferMovMsk))
43511     return SDValue();
43512 
43513   // If the upper ops of a concatenation are undef, then try to bitcast the
43514   // lower op and extend.
43515   SmallVector<SDValue, 4> SubSrcOps;
43516   if (collectConcatOps(Src.getNode(), SubSrcOps, DAG) &&
43517       SubSrcOps.size() >= 2) {
43518     SDValue LowerOp = SubSrcOps[0];
43519     ArrayRef<SDValue> UpperOps(std::next(SubSrcOps.begin()), SubSrcOps.end());
43520     if (LowerOp.getOpcode() == ISD::SETCC &&
43521         all_of(UpperOps, [](SDValue Op) { return Op.isUndef(); })) {
43522       EVT SubVT = VT.getIntegerVT(
43523           *DAG.getContext(), LowerOp.getValueType().getVectorMinNumElements());
43524       if (SDValue V = combineBitcastvxi1(DAG, SubVT, LowerOp, DL, Subtarget)) {
43525         EVT IntVT = VT.getIntegerVT(*DAG.getContext(), VT.getSizeInBits());
43526         return DAG.getBitcast(VT, DAG.getNode(ISD::ANY_EXTEND, DL, IntVT, V));
43527       }
43528     }
43529   }
43530 
43531   // There are MOVMSK flavors for types v16i8, v32i8, v4f32, v8f32, v4f64 and
43532   // v8f64. So all legal 128-bit and 256-bit vectors are covered except for
43533   // v8i16 and v16i16.
43534   // For these two cases, we can shuffle the upper element bytes to a
43535   // consecutive sequence at the start of the vector and treat the results as
43536   // v16i8 or v32i8, and for v16i8 this is the preferable solution. However,
43537   // for v16i16 this is not the case, because the shuffle is expensive, so we
43538   // avoid sign-extending to this type entirely.
43539   // For example, t0 := (v8i16 sext(v8i1 x)) needs to be shuffled as:
43540   // (v16i8 shuffle <0,2,4,6,8,10,12,14,u,u,...,u> (v16i8 bitcast t0), undef)
43541   MVT SExtVT;
43542   bool PropagateSExt = false;
43543   switch (SrcVT.getSimpleVT().SimpleTy) {
43544   default:
43545     return SDValue();
43546   case MVT::v2i1:
43547     SExtVT = MVT::v2i64;
43548     break;
43549   case MVT::v4i1:
43550     SExtVT = MVT::v4i32;
43551     // For cases such as (i4 bitcast (v4i1 setcc v4i64 v1, v2))
43552     // sign-extend to a 256-bit operation to avoid truncation.
43553     if (Subtarget.hasAVX() &&
43554         checkBitcastSrcVectorSize(Src, 256, Subtarget.hasAVX2())) {
43555       SExtVT = MVT::v4i64;
43556       PropagateSExt = true;
43557     }
43558     break;
43559   case MVT::v8i1:
43560     SExtVT = MVT::v8i16;
43561     // For cases such as (i8 bitcast (v8i1 setcc v8i32 v1, v2)),
43562     // sign-extend to a 256-bit operation to match the compare.
43563     // If the setcc operand is 128-bit, prefer sign-extending to 128-bit over
43564     // 256-bit because the shuffle is cheaper than sign extending the result of
43565     // the compare.
43566     if (Subtarget.hasAVX() && (checkBitcastSrcVectorSize(Src, 256, true) ||
43567                                checkBitcastSrcVectorSize(Src, 512, true))) {
43568       SExtVT = MVT::v8i32;
43569       PropagateSExt = true;
43570     }
43571     break;
43572   case MVT::v16i1:
43573     SExtVT = MVT::v16i8;
43574     // For the case (i16 bitcast (v16i1 setcc v16i16 v1, v2)),
43575     // it is not profitable to sign-extend to 256-bit because this will
43576     // require an extra cross-lane shuffle which is more expensive than
43577     // truncating the result of the compare to 128-bits.
43578     break;
43579   case MVT::v32i1:
43580     SExtVT = MVT::v32i8;
43581     break;
43582   case MVT::v64i1:
43583     // If we have AVX512F but not AVX512BW, and the input is a truncate from
43584     // v64i8 (checked earlier), then split the input and make two pmovmskbs.
43585     if (Subtarget.hasAVX512()) {
43586       if (Subtarget.hasBWI())
43587         return SDValue();
43588       SExtVT = MVT::v64i8;
43589       break;
43590     }
43591     // Split if this is a <64 x i8> comparison result.
43592     if (checkBitcastSrcVectorSize(Src, 512, false)) {
43593       SExtVT = MVT::v64i8;
43594       break;
43595     }
43596     return SDValue();
43597   };
43598 
43599   SDValue V = PropagateSExt ? signExtendBitcastSrcVector(DAG, SExtVT, Src, DL)
43600                             : DAG.getNode(ISD::SIGN_EXTEND, DL, SExtVT, Src);
43601 
43602   if (SExtVT == MVT::v16i8 || SExtVT == MVT::v32i8 || SExtVT == MVT::v64i8) {
43603     V = getPMOVMSKB(DL, V, DAG, Subtarget);
43604   } else {
43605     if (SExtVT == MVT::v8i16) {
43606       V = widenSubVector(V, false, Subtarget, DAG, DL, 256);
43607       V = DAG.getNode(ISD::TRUNCATE, DL, MVT::v16i8, V);
43608     }
43609     V = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, V);
43610   }
43611 
43612   EVT IntVT =
43613       EVT::getIntegerVT(*DAG.getContext(), SrcVT.getVectorNumElements());
43614   V = DAG.getZExtOrTrunc(V, DL, IntVT);
43615   return DAG.getBitcast(VT, V);
43616 }
43617 
43618 // Convert a vXi1 constant build vector to the same width scalar integer.
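      // e.g. the v4i1 constant <1,0,1,1> becomes the i4 constant 0b1101
      // (element 0 maps to bit 0 of the scalar).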
43619 static SDValue combinevXi1ConstantToInteger(SDValue Op, SelectionDAG &DAG) {
43620   EVT SrcVT = Op.getValueType();
43621   assert(SrcVT.getVectorElementType() == MVT::i1 &&
43622          "Expected a vXi1 vector");
43623   assert(ISD::isBuildVectorOfConstantSDNodes(Op.getNode()) &&
43624          "Expected a constant build vector");
43625 
43626   APInt Imm(SrcVT.getVectorNumElements(), 0);
43627   for (unsigned Idx = 0, e = Op.getNumOperands(); Idx < e; ++Idx) {
43628     SDValue In = Op.getOperand(Idx);
43629     if (!In.isUndef() && (In->getAsZExtVal() & 0x1))
43630       Imm.setBit(Idx);
43631   }
43632   EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), Imm.getBitWidth());
43633   return DAG.getConstant(Imm, SDLoc(Op), IntVT);
43634 }
43635 
43636 static SDValue combineCastedMaskArithmetic(SDNode *N, SelectionDAG &DAG,
43637                                            TargetLowering::DAGCombinerInfo &DCI,
43638                                            const X86Subtarget &Subtarget) {
43639   assert(N->getOpcode() == ISD::BITCAST && "Expected a bitcast");
43640 
43641   if (!DCI.isBeforeLegalizeOps())
43642     return SDValue();
43643 
43644   // Only do this if we have k-registers.
43645   if (!Subtarget.hasAVX512())
43646     return SDValue();
43647 
43648   EVT DstVT = N->getValueType(0);
43649   SDValue Op = N->getOperand(0);
43650   EVT SrcVT = Op.getValueType();
43651 
43652   if (!Op.hasOneUse())
43653     return SDValue();
43654 
43655   // Look for logic ops.
43656   if (Op.getOpcode() != ISD::AND &&
43657       Op.getOpcode() != ISD::OR &&
43658       Op.getOpcode() != ISD::XOR)
43659     return SDValue();
43660 
43661   // Make sure we have a bitcast between mask registers and a scalar type.
43662   if (!(SrcVT.isVector() && SrcVT.getVectorElementType() == MVT::i1 &&
43663         DstVT.isScalarInteger()) &&
43664       !(DstVT.isVector() && DstVT.getVectorElementType() == MVT::i1 &&
43665         SrcVT.isScalarInteger()))
43666     return SDValue();
43667 
43668   SDValue LHS = Op.getOperand(0);
43669   SDValue RHS = Op.getOperand(1);
43670 
43671   if (LHS.hasOneUse() && LHS.getOpcode() == ISD::BITCAST &&
43672       LHS.getOperand(0).getValueType() == DstVT)
43673     return DAG.getNode(Op.getOpcode(), SDLoc(N), DstVT, LHS.getOperand(0),
43674                        DAG.getBitcast(DstVT, RHS));
43675 
43676   if (RHS.hasOneUse() && RHS.getOpcode() == ISD::BITCAST &&
43677       RHS.getOperand(0).getValueType() == DstVT)
43678     return DAG.getNode(Op.getOpcode(), SDLoc(N), DstVT,
43679                        DAG.getBitcast(DstVT, LHS), RHS.getOperand(0));
43680 
43681   // If the RHS is a vXi1 build vector, this is a good reason to flip too.
43682   // Most of these have to move a constant from the scalar domain anyway.
43683   if (ISD::isBuildVectorOfConstantSDNodes(RHS.getNode())) {
43684     RHS = combinevXi1ConstantToInteger(RHS, DAG);
43685     return DAG.getNode(Op.getOpcode(), SDLoc(N), DstVT,
43686                        DAG.getBitcast(DstVT, LHS), RHS);
43687   }
43688 
43689   return SDValue();
43690 }
43691 
43692 static SDValue createMMXBuildVector(BuildVectorSDNode *BV, SelectionDAG &DAG,
43693                                     const X86Subtarget &Subtarget) {
43694   SDLoc DL(BV);
43695   unsigned NumElts = BV->getNumOperands();
43696   SDValue Splat = BV->getSplatValue();
43697 
43698   // Build MMX element from integer GPR or SSE float values.
43699   auto CreateMMXElement = [&](SDValue V) {
43700     if (V.isUndef())
43701       return DAG.getUNDEF(MVT::x86mmx);
43702     if (V.getValueType().isFloatingPoint()) {
43703       if (Subtarget.hasSSE1() && !isa<ConstantFPSDNode>(V)) {
43704         V = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v4f32, V);
43705         V = DAG.getBitcast(MVT::v2i64, V);
43706         return DAG.getNode(X86ISD::MOVDQ2Q, DL, MVT::x86mmx, V);
43707       }
43708       V = DAG.getBitcast(MVT::i32, V);
43709     } else {
43710       V = DAG.getAnyExtOrTrunc(V, DL, MVT::i32);
43711     }
43712     return DAG.getNode(X86ISD::MMX_MOVW2D, DL, MVT::x86mmx, V);
43713   };
43714 
43715   // Convert build vector ops to MMX data in the bottom elements.
43716   SmallVector<SDValue, 8> Ops;
43717 
43718   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
43719 
43720   // Broadcast - use (PUNPCKL+)PSHUFW to broadcast single element.
43721   if (Splat) {
43722     if (Splat.isUndef())
43723       return DAG.getUNDEF(MVT::x86mmx);
43724 
43725     Splat = CreateMMXElement(Splat);
43726 
43727     if (Subtarget.hasSSE1()) {
43728       // Unpack v8i8 to splat i8 elements to lowest 16-bits.
43729       if (NumElts == 8)
43730         Splat = DAG.getNode(
43731             ISD::INTRINSIC_WO_CHAIN, DL, MVT::x86mmx,
43732             DAG.getTargetConstant(Intrinsic::x86_mmx_punpcklbw, DL,
43733                                   TLI.getPointerTy(DAG.getDataLayout())),
43734             Splat, Splat);
43735 
43736       // Use PSHUFW to repeat 16-bit elements.
43737       unsigned ShufMask = (NumElts > 2 ? 0 : 0x44);
43738       return DAG.getNode(
43739           ISD::INTRINSIC_WO_CHAIN, DL, MVT::x86mmx,
43740           DAG.getTargetConstant(Intrinsic::x86_sse_pshuf_w, DL,
43741                                 TLI.getPointerTy(DAG.getDataLayout())),
43742           Splat, DAG.getTargetConstant(ShufMask, DL, MVT::i8));
43743     }
43744     Ops.append(NumElts, Splat);
43745   } else {
43746     for (unsigned i = 0; i != NumElts; ++i)
43747       Ops.push_back(CreateMMXElement(BV->getOperand(i)));
43748   }
43749 
43750   // Use tree of PUNPCKLs to build up general MMX vector.
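        // e.g. 8 partial results are combined with a round of punpcklbw, then
        // punpcklwd, then punpckldq, halving the number of partial results each
        // round.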
43751   while (Ops.size() > 1) {
43752     unsigned NumOps = Ops.size();
43753     unsigned IntrinOp =
43754         (NumOps == 2 ? Intrinsic::x86_mmx_punpckldq
43755                      : (NumOps == 4 ? Intrinsic::x86_mmx_punpcklwd
43756                                     : Intrinsic::x86_mmx_punpcklbw));
43757     SDValue Intrin = DAG.getTargetConstant(
43758         IntrinOp, DL, TLI.getPointerTy(DAG.getDataLayout()));
43759     for (unsigned i = 0; i != NumOps; i += 2)
43760       Ops[i / 2] = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::x86mmx, Intrin,
43761                                Ops[i], Ops[i + 1]);
43762     Ops.resize(NumOps / 2);
43763   }
43764 
43765   return Ops[0];
43766 }
43767 
43768 // Recursive function that attempts to determine whether a scalar integer value
43769 // originally came from a bool vector (or float/double vector) that was
43770 // truncated/extended/bitcast to/from a scalar integer. If so, replace the
43771 // scalar ops with bool vector equivalents back down the chain.
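      // e.g. an i16 that was produced by (bitcast (v16i1 X)) and then passed
      // through trunc/extend/shift/logic ops may be rebuilt as v16i1 operations so
      // that the final bitcast folds away.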
43772 static SDValue combineBitcastToBoolVector(EVT VT, SDValue V, const SDLoc &DL,
43773                                           SelectionDAG &DAG,
43774                                           const X86Subtarget &Subtarget,
43775                                           unsigned Depth = 0) {
43776   if (Depth >= SelectionDAG::MaxRecursionDepth)
43777     return SDValue(); // Limit search depth.
43778 
43779   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
43780   unsigned Opc = V.getOpcode();
43781   switch (Opc) {
43782   case ISD::BITCAST: {
43783     // Bitcast from a vector/float/double, we can cheaply bitcast to VT.
43784     SDValue Src = V.getOperand(0);
43785     EVT SrcVT = Src.getValueType();
43786     if (SrcVT.isVector() || SrcVT.isFloatingPoint())
43787       return DAG.getBitcast(VT, Src);
43788     break;
43789   }
43790   case ISD::Constant: {
43791     auto *C = cast<ConstantSDNode>(V);
43792     if (C->isZero())
43793       return DAG.getConstant(0, DL, VT);
43794     if (C->isAllOnes())
43795       return DAG.getAllOnesConstant(DL, VT);
43796     break;
43797   }
43798   case ISD::TRUNCATE: {
43799     // If we find a suitable source, a truncated scalar becomes a subvector.
43800     SDValue Src = V.getOperand(0);
43801     EVT NewSrcVT =
43802         EVT::getVectorVT(*DAG.getContext(), MVT::i1, Src.getValueSizeInBits());
43803     if (TLI.isTypeLegal(NewSrcVT))
43804       if (SDValue N0 = combineBitcastToBoolVector(NewSrcVT, Src, DL, DAG,
43805                                                   Subtarget, Depth + 1))
43806         return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, N0,
43807                            DAG.getIntPtrConstant(0, DL));
43808     break;
43809   }
43810   case ISD::ANY_EXTEND:
43811   case ISD::ZERO_EXTEND: {
43812     // If we find a suitable source, an extended scalar becomes a subvector.
43813     SDValue Src = V.getOperand(0);
43814     EVT NewSrcVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
43815                                     Src.getScalarValueSizeInBits());
43816     if (TLI.isTypeLegal(NewSrcVT))
43817       if (SDValue N0 = combineBitcastToBoolVector(NewSrcVT, Src, DL, DAG,
43818                                                   Subtarget, Depth + 1))
43819         return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT,
43820                            Opc == ISD::ANY_EXTEND ? DAG.getUNDEF(VT)
43821                                                   : DAG.getConstant(0, DL, VT),
43822                            N0, DAG.getIntPtrConstant(0, DL));
43823     break;
43824   }
43825   case ISD::OR:
43826   case ISD::XOR: {
43827     // If we find suitable sources, we can just move the op to the vector
43828     // domain.
43829     if (SDValue N0 = combineBitcastToBoolVector(VT, V.getOperand(0), DL, DAG,
43830                                                 Subtarget, Depth + 1))
43831       if (SDValue N1 = combineBitcastToBoolVector(VT, V.getOperand(1), DL, DAG,
43832                                                   Subtarget, Depth + 1))
43833         return DAG.getNode(Opc, DL, VT, N0, N1);
43834     break;
43835   }
43836   case ISD::SHL: {
43837     // If we find a suitable source, a SHL becomes a KSHIFTL.
43838     SDValue Src0 = V.getOperand(0);
43839     if ((VT == MVT::v8i1 && !Subtarget.hasDQI()) ||
43840         ((VT == MVT::v32i1 || VT == MVT::v64i1) && !Subtarget.hasBWI()))
43841       break;
43842 
43843     if (auto *Amt = dyn_cast<ConstantSDNode>(V.getOperand(1)))
43844       if (SDValue N0 = combineBitcastToBoolVector(VT, Src0, DL, DAG, Subtarget,
43845                                                   Depth + 1))
43846         return DAG.getNode(
43847             X86ISD::KSHIFTL, DL, VT, N0,
43848             DAG.getTargetConstant(Amt->getZExtValue(), DL, MVT::i8));
43849     break;
43850   }
43851   }
43852 
43853   // Does the inner bitcast already exist?
43854   if (Depth > 0)
43855     if (SDNode *Alt = DAG.getNodeIfExists(ISD::BITCAST, DAG.getVTList(VT), {V}))
43856       return SDValue(Alt, 0);
43857 
43858   return SDValue();
43859 }
43860 
43861 static SDValue combineBitcast(SDNode *N, SelectionDAG &DAG,
43862                               TargetLowering::DAGCombinerInfo &DCI,
43863                               const X86Subtarget &Subtarget) {
43864   SDValue N0 = N->getOperand(0);
43865   EVT VT = N->getValueType(0);
43866   EVT SrcVT = N0.getValueType();
43867   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
43868 
43869   // Try to match patterns such as
43870   // (i16 bitcast (v16i1 x))
43871   // ->
43872   // (i16 movmsk (v16i8 sext (v16i1 x)))
43873   // before the setcc result is scalarized on subtargets that don't have legal
43874   // vxi1 types.
43875   if (DCI.isBeforeLegalize()) {
43876     SDLoc dl(N);
43877     if (SDValue V = combineBitcastvxi1(DAG, VT, N0, dl, Subtarget))
43878       return V;
43879 
43880     // If this is a bitcast between a MVT::v4i1/v2i1 and an illegal integer
43881     // type, widen both sides to avoid a trip through memory.
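          // e.g. (v4i1 (bitcast (i4 X))) is rebuilt as
          // (extract_subvector (v8i1 (bitcast (i8 (any_extend X)))), 0).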
43882     if ((VT == MVT::v4i1 || VT == MVT::v2i1) && SrcVT.isScalarInteger() &&
43883         Subtarget.hasAVX512()) {
43884       N0 = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i8, N0);
43885       N0 = DAG.getBitcast(MVT::v8i1, N0);
43886       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT, N0,
43887                          DAG.getIntPtrConstant(0, dl));
43888     }
43889 
43890     // If this is a bitcast between a MVT::v4i1/v2i1 and an illegal integer
43891     // type, widen both sides to avoid a trip through memory.
43892     if ((SrcVT == MVT::v4i1 || SrcVT == MVT::v2i1) && VT.isScalarInteger() &&
43893         Subtarget.hasAVX512()) {
43894       // Use zeros for the widening if we already have some zeroes. This can
43895       // allow SimplifyDemandedBits to remove scalar ANDs that may be down
43896       // stream of this.
43897       // FIXME: It might make sense to detect a concat_vectors with a mix of
43898       // zeroes and undef and turn it into insert_subvector for i1 vectors as
43899       // a separate combine. What we can't do is canonicalize the operands of
43900       // such a concat or we'll get into a loop with SimplifyDemandedBits.
43901       if (N0.getOpcode() == ISD::CONCAT_VECTORS) {
43902         SDValue LastOp = N0.getOperand(N0.getNumOperands() - 1);
43903         if (ISD::isBuildVectorAllZeros(LastOp.getNode())) {
43904           SrcVT = LastOp.getValueType();
43905           unsigned NumConcats = 8 / SrcVT.getVectorNumElements();
43906           SmallVector<SDValue, 4> Ops(N0->op_begin(), N0->op_end());
43907           Ops.resize(NumConcats, DAG.getConstant(0, dl, SrcVT));
43908           N0 = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v8i1, Ops);
43909           N0 = DAG.getBitcast(MVT::i8, N0);
43910           return DAG.getNode(ISD::TRUNCATE, dl, VT, N0);
43911         }
43912       }
43913 
43914       unsigned NumConcats = 8 / SrcVT.getVectorNumElements();
43915       SmallVector<SDValue, 4> Ops(NumConcats, DAG.getUNDEF(SrcVT));
43916       Ops[0] = N0;
43917       N0 = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v8i1, Ops);
43918       N0 = DAG.getBitcast(MVT::i8, N0);
43919       return DAG.getNode(ISD::TRUNCATE, dl, VT, N0);
43920     }
43921   } else {
43922     // If we're bitcasting from iX to vXi1, see if the integer originally
43923     // began as a vXi1 and whether we can remove the bitcast entirely.
43924     if (VT.isVector() && VT.getScalarType() == MVT::i1 &&
43925         SrcVT.isScalarInteger() && TLI.isTypeLegal(VT)) {
43926       if (SDValue V =
43927               combineBitcastToBoolVector(VT, N0, SDLoc(N), DAG, Subtarget))
43928         return V;
43929     }
43930   }
43931 
43932   // Look for (i8 (bitcast (v8i1 (extract_subvector (v16i1 X), 0)))) and
43933   // replace with (i8 (trunc (i16 (bitcast (v16i1 X))))). This can occur
43934   // due to insert_subvector legalization on KNL. By promoting the copy to i16
43935   // we can help with known bits propagation from the vXi1 domain to the
43936   // scalar domain.
43937   if (VT == MVT::i8 && SrcVT == MVT::v8i1 && Subtarget.hasAVX512() &&
43938       !Subtarget.hasDQI() && N0.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
43939       N0.getOperand(0).getValueType() == MVT::v16i1 &&
43940       isNullConstant(N0.getOperand(1)))
43941     return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT,
43942                        DAG.getBitcast(MVT::i16, N0.getOperand(0)));
43943 
43944   // Canonicalize (bitcast (vbroadcast_load)) so that the output of the bitcast
43945   // and the vbroadcast_load are both integer or both fp. In some cases this
43946   // will remove the bitcast entirely.
43947   if (N0.getOpcode() == X86ISD::VBROADCAST_LOAD && N0.hasOneUse() &&
43948        VT.isFloatingPoint() != SrcVT.isFloatingPoint() && VT.isVector()) {
43949     auto *BCast = cast<MemIntrinsicSDNode>(N0);
43950     unsigned SrcVTSize = SrcVT.getScalarSizeInBits();
43951     unsigned MemSize = BCast->getMemoryVT().getScalarSizeInBits();
43952     // Don't swap i8/i16 since we don't have fp types of that size.
43953     if (MemSize >= 32) {
43954       MVT MemVT = VT.isFloatingPoint() ? MVT::getFloatingPointVT(MemSize)
43955                                        : MVT::getIntegerVT(MemSize);
43956       MVT LoadVT = VT.isFloatingPoint() ? MVT::getFloatingPointVT(SrcVTSize)
43957                                         : MVT::getIntegerVT(SrcVTSize);
43958       LoadVT = MVT::getVectorVT(LoadVT, SrcVT.getVectorNumElements());
43959 
43960       SDVTList Tys = DAG.getVTList(LoadVT, MVT::Other);
43961       SDValue Ops[] = { BCast->getChain(), BCast->getBasePtr() };
43962       SDValue ResNode =
43963           DAG.getMemIntrinsicNode(X86ISD::VBROADCAST_LOAD, SDLoc(N), Tys, Ops,
43964                                   MemVT, BCast->getMemOperand());
43965       DAG.ReplaceAllUsesOfValueWith(SDValue(BCast, 1), ResNode.getValue(1));
43966       return DAG.getBitcast(VT, ResNode);
43967     }
43968   }
43969 
43970   // Since MMX types are special and don't usually play with other vector types,
43971   // it's better to handle them early to be sure we emit efficient code by
43972   // avoiding store-load conversions.
43973   if (VT == MVT::x86mmx) {
43974     // Detect MMX constant vectors.
43975     APInt UndefElts;
43976     SmallVector<APInt, 1> EltBits;
43977     if (getTargetConstantBitsFromNode(N0, 64, UndefElts, EltBits,
43978                                       /*AllowWholeUndefs*/ true,
43979                                       /*AllowPartialUndefs*/ true)) {
43980       SDLoc DL(N0);
43981       // Handle zero-extension of i32 with MOVD.
43982       if (EltBits[0].countl_zero() >= 32)
43983         return DAG.getNode(X86ISD::MMX_MOVW2D, DL, VT,
43984                            DAG.getConstant(EltBits[0].trunc(32), DL, MVT::i32));
43985       // Else, bitcast to a double.
43986       // TODO - investigate supporting sext 32-bit immediates on x86_64.
43987       APFloat F64(APFloat::IEEEdouble(), EltBits[0]);
43988       return DAG.getBitcast(VT, DAG.getConstantFP(F64, DL, MVT::f64));
43989     }
43990 
43991     // Detect bitcasts to x86mmx low word.
43992     if (N0.getOpcode() == ISD::BUILD_VECTOR &&
43993         (SrcVT == MVT::v2i32 || SrcVT == MVT::v4i16 || SrcVT == MVT::v8i8) &&
43994         N0.getOperand(0).getValueType() == SrcVT.getScalarType()) {
43995       bool LowUndef = true, AllUndefOrZero = true;
43996       for (unsigned i = 1, e = SrcVT.getVectorNumElements(); i != e; ++i) {
43997         SDValue Op = N0.getOperand(i);
43998         LowUndef &= Op.isUndef() || (i >= e/2);
43999         AllUndefOrZero &= isNullConstantOrUndef(Op);
44000       }
44001       if (AllUndefOrZero) {
44002         SDValue N00 = N0.getOperand(0);
44003         SDLoc dl(N00);
44004         N00 = LowUndef ? DAG.getAnyExtOrTrunc(N00, dl, MVT::i32)
44005                        : DAG.getZExtOrTrunc(N00, dl, MVT::i32);
44006         return DAG.getNode(X86ISD::MMX_MOVW2D, dl, VT, N00);
44007       }
44008     }
44009 
44010     // Detect bitcasts of 64-bit build vectors and convert to a
44011     // MMX UNPCK/PSHUFW which takes MMX type inputs with the value in the
44012     // lowest element.
44013     if (N0.getOpcode() == ISD::BUILD_VECTOR &&
44014         (SrcVT == MVT::v2f32 || SrcVT == MVT::v2i32 || SrcVT == MVT::v4i16 ||
44015          SrcVT == MVT::v8i8))
44016       return createMMXBuildVector(cast<BuildVectorSDNode>(N0), DAG, Subtarget);
44017 
44018     // Detect bitcasts from element or subvector extractions to x86mmx.
44019     if ((N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT ||
44020          N0.getOpcode() == ISD::EXTRACT_SUBVECTOR) &&
44021         isNullConstant(N0.getOperand(1))) {
44022       SDValue N00 = N0.getOperand(0);
44023       if (N00.getValueType().is128BitVector())
44024         return DAG.getNode(X86ISD::MOVDQ2Q, SDLoc(N00), VT,
44025                            DAG.getBitcast(MVT::v2i64, N00));
44026     }
44027 
44028     // Detect bitcasts from FP_TO_SINT to x86mmx.
44029     if (SrcVT == MVT::v2i32 && N0.getOpcode() == ISD::FP_TO_SINT) {
44030       SDLoc DL(N0);
44031       SDValue Res = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v4i32, N0,
44032                                 DAG.getUNDEF(MVT::v2i32));
44033       return DAG.getNode(X86ISD::MOVDQ2Q, DL, VT,
44034                          DAG.getBitcast(MVT::v2i64, Res));
44035     }
44036   }
44037 
44038   // Try to remove a bitcast of constant vXi1 vector. We have to legalize
44039   // most of these to scalar anyway.
44040   if (Subtarget.hasAVX512() && VT.isScalarInteger() &&
44041       SrcVT.isVector() && SrcVT.getVectorElementType() == MVT::i1 &&
44042       ISD::isBuildVectorOfConstantSDNodes(N0.getNode())) {
44043     return combinevXi1ConstantToInteger(N0, DAG);
44044   }
44045 
44046   if (Subtarget.hasAVX512() && SrcVT.isScalarInteger() && VT.isVector() &&
44047       VT.getVectorElementType() == MVT::i1) {
44048     if (auto *C = dyn_cast<ConstantSDNode>(N0)) {
44049       if (C->isAllOnes())
44050         return DAG.getConstant(1, SDLoc(N0), VT);
44051       if (C->isZero())
44052         return DAG.getConstant(0, SDLoc(N0), VT);
44053     }
44054   }
44055 
44056   // Look for MOVMSK that is maybe truncated and then bitcasted to vXi1.
44057   // Turn it into a sign bit compare that produces a k-register. This avoids
44058   // a trip through a GPR.
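  // e.g. (v16i1 (bitcast (i16 (trunc (i32 (movmsk (v16i8 X))))))) becomes
  // (v16i1 (setcc (v16i8 X), 0, setlt)), i.e. a sign-bit test that yields a
  // k-register directly (the i8-element case requires AVX512BW).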
44059   if (Subtarget.hasAVX512() && SrcVT.isScalarInteger() &&
44060       VT.isVector() && VT.getVectorElementType() == MVT::i1 &&
44061       isPowerOf2_32(VT.getVectorNumElements())) {
44062     unsigned NumElts = VT.getVectorNumElements();
44063     SDValue Src = N0;
44064 
44065     // Peek through truncate.
44066     if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse())
44067       Src = N0.getOperand(0);
44068 
44069     if (Src.getOpcode() == X86ISD::MOVMSK && Src.hasOneUse()) {
44070       SDValue MovmskIn = Src.getOperand(0);
44071       MVT MovmskVT = MovmskIn.getSimpleValueType();
44072       unsigned MovMskElts = MovmskVT.getVectorNumElements();
44073 
44074       // We allow extra bits of the movmsk to be used since they are known zero.
44075       // We can't convert a VPMOVMSKB without avx512bw.
44076       if (MovMskElts <= NumElts &&
44077           (Subtarget.hasBWI() || MovmskVT.getVectorElementType() != MVT::i8)) {
44078         EVT IntVT = EVT(MovmskVT).changeVectorElementTypeToInteger();
44079         MovmskIn = DAG.getBitcast(IntVT, MovmskIn);
44080         SDLoc dl(N);
44081         MVT CmpVT = MVT::getVectorVT(MVT::i1, MovMskElts);
44082         SDValue Cmp = DAG.getSetCC(dl, CmpVT, MovmskIn,
44083                                    DAG.getConstant(0, dl, IntVT), ISD::SETLT);
44084         if (EVT(CmpVT) == VT)
44085           return Cmp;
44086 
44087         // Pad with zeroes up to original VT to replace the zeroes that were
44088         // being used from the MOVMSK.
44089         unsigned NumConcats = NumElts / MovMskElts;
44090         SmallVector<SDValue, 4> Ops(NumConcats, DAG.getConstant(0, dl, CmpVT));
44091         Ops[0] = Cmp;
44092         return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, Ops);
44093       }
44094     }
44095   }
44096 
44097   // Try to remove bitcasts from input and output of mask arithmetic to
44098   // remove GPR<->K-register crossings.
44099   if (SDValue V = combineCastedMaskArithmetic(N, DAG, DCI, Subtarget))
44100     return V;
44101 
44102   // Convert a bitcasted integer logic operation that has one bitcasted
44103   // floating-point operand into a floating-point logic operation. This may
44104   // create a load of a constant, but that is cheaper than materializing the
44105   // constant in an integer register and transferring it to an SSE register or
44106   // transferring the SSE operand to integer register and back.
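  // e.g. with SSE1: (f32 (bitcast (and (i32 (bitcast X:f32)), (i32 C)))) can
  // become (X86ISD::FAND X, (f32 (bitcast C))), keeping the value in an SSE
  // register instead of bouncing through a GPR.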
44107   unsigned FPOpcode;
44108   switch (N0.getOpcode()) {
44109   // clang-format off
44110   case ISD::AND: FPOpcode = X86ISD::FAND; break;
44111   case ISD::OR:  FPOpcode = X86ISD::FOR;  break;
44112   case ISD::XOR: FPOpcode = X86ISD::FXOR; break;
44113   default: return SDValue();
44114   // clang-format on
44115   }
44116 
44117   // Check if we have a bitcast from another integer type as well.
44118   if (!((Subtarget.hasSSE1() && VT == MVT::f32) ||
44119         (Subtarget.hasSSE2() && VT == MVT::f64) ||
44120         (Subtarget.hasFP16() && VT == MVT::f16) ||
44121         (Subtarget.hasSSE2() && VT.isInteger() && VT.isVector() &&
44122          TLI.isTypeLegal(VT))))
44123     return SDValue();
44124 
44125   SDValue LogicOp0 = N0.getOperand(0);
44126   SDValue LogicOp1 = N0.getOperand(1);
44127   SDLoc DL0(N0);
44128 
44129   // bitcast(logic(bitcast(X), Y)) --> logic'(X, bitcast(Y))
44130   if (N0.hasOneUse() && LogicOp0.getOpcode() == ISD::BITCAST &&
44131       LogicOp0.hasOneUse() && LogicOp0.getOperand(0).hasOneUse() &&
44132       LogicOp0.getOperand(0).getValueType() == VT &&
44133       !isa<ConstantSDNode>(LogicOp0.getOperand(0))) {
44134     SDValue CastedOp1 = DAG.getBitcast(VT, LogicOp1);
44135     unsigned Opcode = VT.isFloatingPoint() ? FPOpcode : N0.getOpcode();
44136     return DAG.getNode(Opcode, DL0, VT, LogicOp0.getOperand(0), CastedOp1);
44137   }
44138   // bitcast(logic(X, bitcast(Y))) --> logic'(bitcast(X), Y)
44139   if (N0.hasOneUse() && LogicOp1.getOpcode() == ISD::BITCAST &&
44140       LogicOp1.hasOneUse() && LogicOp1.getOperand(0).hasOneUse() &&
44141       LogicOp1.getOperand(0).getValueType() == VT &&
44142       !isa<ConstantSDNode>(LogicOp1.getOperand(0))) {
44143     SDValue CastedOp0 = DAG.getBitcast(VT, LogicOp0);
44144     unsigned Opcode = VT.isFloatingPoint() ? FPOpcode : N0.getOpcode();
44145     return DAG.getNode(Opcode, DL0, VT, LogicOp1.getOperand(0), CastedOp0);
44146   }
44147 
44148   return SDValue();
44149 }
44150 
44151 // (mul (zext a), (sext b))
44152 static bool detectExtMul(SelectionDAG &DAG, const SDValue &Mul, SDValue &Op0,
44153                          SDValue &Op1) {
44154   Op0 = Mul.getOperand(0);
44155   Op1 = Mul.getOperand(1);
44156 
44157   // Operand 1 should be the sign-extended operand; swap if necessary.
44158   if (Op0.getOpcode() == ISD::SIGN_EXTEND)
44159     std::swap(Op0, Op1);
44160 
44161   auto IsFreeTruncation = [](SDValue &Op) -> bool {
44162     if ((Op.getOpcode() == ISD::ZERO_EXTEND ||
44163          Op.getOpcode() == ISD::SIGN_EXTEND) &&
44164         Op.getOperand(0).getScalarValueSizeInBits() <= 8)
44165       return true;
44166 
44167     auto *BV = dyn_cast<BuildVectorSDNode>(Op);
44168     return (BV && BV->isConstant());
44169   };
44170 
44171   // (dpbusd (zext a), (sext b)). The first operand must be an unsigned
44172   // value, so check that Op0 is zero-extended. Op1 must be a signed value,
44173   // so checking its significant (sign) bits is sufficient.
44174   if ((IsFreeTruncation(Op0) &&
44175        DAG.computeKnownBits(Op0).countMaxActiveBits() <= 8) &&
44176       (IsFreeTruncation(Op1) && DAG.ComputeMaxSignificantBits(Op1) <= 8))
44177     return true;
44178 
44179   return false;
44180 }
44181 
44182 // Given a ABS node, detect the following pattern:
44183 // (ABS (SUB (ZERO_EXTEND a), (ZERO_EXTEND b))).
44184 // This is useful as it is the input into a SAD pattern.
44185 static bool detectZextAbsDiff(const SDValue &Abs, SDValue &Op0, SDValue &Op1) {
44186   SDValue AbsOp1 = Abs->getOperand(0);
44187   if (AbsOp1.getOpcode() != ISD::SUB)
44188     return false;
44189 
44190   Op0 = AbsOp1.getOperand(0);
44191   Op1 = AbsOp1.getOperand(1);
44192 
44193   // Check if the operands of the sub are zero-extended from vectors of i8.
44194   if (Op0.getOpcode() != ISD::ZERO_EXTEND ||
44195       Op0.getOperand(0).getValueType().getVectorElementType() != MVT::i8 ||
44196       Op1.getOpcode() != ISD::ZERO_EXTEND ||
44197       Op1.getOperand(0).getValueType().getVectorElementType() != MVT::i8)
44198     return false;
44199 
44200   return true;
44201 }
44202 
44203 static SDValue createVPDPBUSD(SelectionDAG &DAG, SDValue LHS, SDValue RHS,
44204                               unsigned &LogBias, const SDLoc &DL,
44205                               const X86Subtarget &Subtarget) {
44206   // Extend or truncate to MVT::i8 first.
44207   MVT Vi8VT =
44208       MVT::getVectorVT(MVT::i8, LHS.getValueType().getVectorElementCount());
44209   LHS = DAG.getZExtOrTrunc(LHS, DL, Vi8VT);
44210   RHS = DAG.getSExtOrTrunc(RHS, DL, Vi8VT);
44211 
44212   // VPDPBUSD(<16 x i32>C, <16 x i8>A, <16 x i8>B). For each dst element
44213   // C[0] = C[0] + A[0]B[0] + A[1]B[1] + A[2]B[2] + A[3]B[3].
44214   // The src A, B element type is i8, but the dst C element type is i32.
44215   // When computing the number of reduction stages we use the vXi8 source
44216   // type, so a log-bias of 2 is needed to avoid 2 redundant stages.
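  // For example, reducing a v16i32 product vector would normally take
  // Log2(16) = 4 shuffle+add stages; VPDPBUSD already sums each group of 4
  // byte products into an i32 lane, covering the first 2 stages, so only 2
  // remain.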
44217   LogBias = 2;
44218 
44219   unsigned RegSize = std::max(128u, (unsigned)Vi8VT.getSizeInBits());
44220   if (Subtarget.hasVNNI() && !Subtarget.hasVLX())
44221     RegSize = std::max(512u, RegSize);
44222 
44223   // "Zero-extend" the i8 vectors. This is not a per-element zext, rather we
44224   // fill in the missing vector elements with 0.
44225   unsigned NumConcat = RegSize / Vi8VT.getSizeInBits();
44226   SmallVector<SDValue, 16> Ops(NumConcat, DAG.getConstant(0, DL, Vi8VT));
44227   Ops[0] = LHS;
44228   MVT ExtendedVT = MVT::getVectorVT(MVT::i8, RegSize / 8);
44229   SDValue DpOp0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, ExtendedVT, Ops);
44230   Ops[0] = RHS;
44231   SDValue DpOp1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, ExtendedVT, Ops);
44232 
44233   // Actually build the DotProduct, split as 256/512 bits for
44234   // AVXVNNI/AVX512VNNI.
44235   auto DpBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
44236                        ArrayRef<SDValue> Ops) {
44237     MVT VT = MVT::getVectorVT(MVT::i32, Ops[0].getValueSizeInBits() / 32);
44238     return DAG.getNode(X86ISD::VPDPBUSD, DL, VT, Ops);
44239   };
44240   MVT DpVT = MVT::getVectorVT(MVT::i32, RegSize / 32);
44241   SDValue Zero = DAG.getConstant(0, DL, DpVT);
44242 
44243   return SplitOpsAndApply(DAG, Subtarget, DL, DpVT, {Zero, DpOp0, DpOp1},
44244                           DpBuilder, false);
44245 }
44246 
44247 // Given two zexts of <k x i8> to <k x i32>, create a PSADBW of the inputs
44248 // to these zexts.
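// PSADBW computes, per 64-bit lane, the sum of absolute differences of 8
// unsigned byte pairs, so it performs an 8-to-1 horizontal reduction of the
// i8 inputs in a single instruction.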
44249 static SDValue createPSADBW(SelectionDAG &DAG, const SDValue &Zext0,
44250                             const SDValue &Zext1, const SDLoc &DL,
44251                             const X86Subtarget &Subtarget) {
44252   // Find the appropriate width for the PSADBW.
44253   EVT InVT = Zext0.getOperand(0).getValueType();
44254   unsigned RegSize = std::max(128u, (unsigned)InVT.getSizeInBits());
44255 
44256   // "Zero-extend" the i8 vectors. This is not a per-element zext, rather we
44257   // fill in the missing vector elements with 0.
44258   unsigned NumConcat = RegSize / InVT.getSizeInBits();
44259   SmallVector<SDValue, 16> Ops(NumConcat, DAG.getConstant(0, DL, InVT));
44260   Ops[0] = Zext0.getOperand(0);
44261   MVT ExtendedVT = MVT::getVectorVT(MVT::i8, RegSize / 8);
44262   SDValue SadOp0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, ExtendedVT, Ops);
44263   Ops[0] = Zext1.getOperand(0);
44264   SDValue SadOp1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, ExtendedVT, Ops);
44265 
44266   // Actually build the SAD, split as 128/256/512 bits for SSE/AVX2/AVX512BW.
44267   auto PSADBWBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
44268                           ArrayRef<SDValue> Ops) {
44269     MVT VT = MVT::getVectorVT(MVT::i64, Ops[0].getValueSizeInBits() / 64);
44270     return DAG.getNode(X86ISD::PSADBW, DL, VT, Ops);
44271   };
44272   MVT SadVT = MVT::getVectorVT(MVT::i64, RegSize / 64);
44273   return SplitOpsAndApply(DAG, Subtarget, DL, SadVT, { SadOp0, SadOp1 },
44274                           PSADBWBuilder);
44275 }
44276 
44277 // Attempt to replace a min/max v8i16/v16i8 horizontal reduction with
44278 // PHMINPOSUW.
44279 static SDValue combineMinMaxReduction(SDNode *Extract, SelectionDAG &DAG,
44280                                       const X86Subtarget &Subtarget) {
44281   // Bail without SSE41.
44282   if (!Subtarget.hasSSE41())
44283     return SDValue();
44284 
44285   EVT ExtractVT = Extract->getValueType(0);
44286   if (ExtractVT != MVT::i16 && ExtractVT != MVT::i8)
44287     return SDValue();
44288 
44289   // Check for SMAX/SMIN/UMAX/UMIN horizontal reduction patterns.
44290   ISD::NodeType BinOp;
44291   SDValue Src = DAG.matchBinOpReduction(
44292       Extract, BinOp, {ISD::SMAX, ISD::SMIN, ISD::UMAX, ISD::UMIN}, true);
44293   if (!Src)
44294     return SDValue();
44295 
44296   EVT SrcVT = Src.getValueType();
44297   EVT SrcSVT = SrcVT.getScalarType();
44298   if (SrcSVT != ExtractVT || (SrcVT.getSizeInBits() % 128) != 0)
44299     return SDValue();
44300 
44301   SDLoc DL(Extract);
44302   SDValue MinPos = Src;
44303 
44304   // First, reduce the source down to 128-bit, applying BinOp to lo/hi.
44305   while (SrcVT.getSizeInBits() > 128) {
44306     SDValue Lo, Hi;
44307     std::tie(Lo, Hi) = splitVector(MinPos, DAG, DL);
44308     SrcVT = Lo.getValueType();
44309     MinPos = DAG.getNode(BinOp, DL, SrcVT, Lo, Hi);
44310   }
44311   assert(((SrcVT == MVT::v8i16 && ExtractVT == MVT::i16) ||
44312           (SrcVT == MVT::v16i8 && ExtractVT == MVT::i8)) &&
44313          "Unexpected value type");
44314 
44315   // PHMINPOSUW applies to UMIN(v8i16), for SMIN/SMAX/UMAX we must apply a mask
44316   // to flip the value accordingly.
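  // e.g. a UMAX reduction is computed as NOT(UMIN(NOT(X))): XOR with all-ones
  // before and after the PHMINPOSUW; SMIN/SMAX use the signed-min/signed-max
  // constants to remap the signed ordering onto the unsigned UMIN.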
44317   SDValue Mask;
44318   unsigned MaskEltsBits = ExtractVT.getSizeInBits();
44319   if (BinOp == ISD::SMAX)
44320     Mask = DAG.getConstant(APInt::getSignedMaxValue(MaskEltsBits), DL, SrcVT);
44321   else if (BinOp == ISD::SMIN)
44322     Mask = DAG.getConstant(APInt::getSignedMinValue(MaskEltsBits), DL, SrcVT);
44323   else if (BinOp == ISD::UMAX)
44324     Mask = DAG.getAllOnesConstant(DL, SrcVT);
44325 
44326   if (Mask)
44327     MinPos = DAG.getNode(ISD::XOR, DL, SrcVT, Mask, MinPos);
44328 
44329   // For v16i8 cases we need to perform UMIN on pairs of byte elements,
44330   // shuffling each upper element down and inserting zeros. This means that
44331   // the v16i8 UMIN will leave each upper element as zero, performing the
44332   // zero-extension needed for the PHMINPOS.
44333   if (ExtractVT == MVT::i8) {
44334     SDValue Upper = DAG.getVectorShuffle(
44335         SrcVT, DL, MinPos, DAG.getConstant(0, DL, MVT::v16i8),
44336         {1, 16, 3, 16, 5, 16, 7, 16, 9, 16, 11, 16, 13, 16, 15, 16});
44337     MinPos = DAG.getNode(ISD::UMIN, DL, SrcVT, MinPos, Upper);
44338   }
44339 
44340   // Perform the PHMINPOS on a v8i16 vector.
44341   MinPos = DAG.getBitcast(MVT::v8i16, MinPos);
44342   MinPos = DAG.getNode(X86ISD::PHMINPOS, DL, MVT::v8i16, MinPos);
44343   MinPos = DAG.getBitcast(SrcVT, MinPos);
44344 
44345   if (Mask)
44346     MinPos = DAG.getNode(ISD::XOR, DL, SrcVT, Mask, MinPos);
44347 
44348   return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ExtractVT, MinPos,
44349                      DAG.getIntPtrConstant(0, DL));
44350 }
44351 
44352 // Attempt to replace an all_of/any_of/parity style horizontal reduction with a MOVMSK.
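// e.g. an all_of reduction of a v4i32 compare mask becomes
// (movmskps mask) == 0xF, and the matching any_of becomes (movmskps mask) != 0,
// replacing the log2(N) shuffle+and/or reduction tree with a single test.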
44353 static SDValue combinePredicateReduction(SDNode *Extract, SelectionDAG &DAG,
44354                                          const X86Subtarget &Subtarget) {
44355   // Bail without SSE2.
44356   if (!Subtarget.hasSSE2())
44357     return SDValue();
44358 
44359   EVT ExtractVT = Extract->getValueType(0);
44360   unsigned BitWidth = ExtractVT.getSizeInBits();
44361   if (ExtractVT != MVT::i64 && ExtractVT != MVT::i32 && ExtractVT != MVT::i16 &&
44362       ExtractVT != MVT::i8 && ExtractVT != MVT::i1)
44363     return SDValue();
44364 
44365   // Check for OR(any_of)/AND(all_of)/XOR(parity) horizontal reduction patterns.
44366   ISD::NodeType BinOp;
44367   SDValue Match = DAG.matchBinOpReduction(Extract, BinOp, {ISD::OR, ISD::AND});
44368   if (!Match && ExtractVT == MVT::i1)
44369     Match = DAG.matchBinOpReduction(Extract, BinOp, {ISD::XOR});
44370   if (!Match)
44371     return SDValue();
44372 
44373   // EXTRACT_VECTOR_ELT can require implicit extension of the vector element
44374   // which we can't support here for now.
44375   if (Match.getScalarValueSizeInBits() != BitWidth)
44376     return SDValue();
44377 
44378   SDValue Movmsk;
44379   SDLoc DL(Extract);
44380   EVT MatchVT = Match.getValueType();
44381   unsigned NumElts = MatchVT.getVectorNumElements();
44382   unsigned MaxElts = Subtarget.hasInt256() ? 32 : 16;
44383   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
44384   LLVMContext &Ctx = *DAG.getContext();
44385 
44386   if (ExtractVT == MVT::i1) {
44387     // Special case for (pre-legalization) vXi1 reductions.
44388     if (NumElts > 64 || !isPowerOf2_32(NumElts))
44389       return SDValue();
44390     if (Match.getOpcode() == ISD::SETCC) {
44391       ISD::CondCode CC = cast<CondCodeSDNode>(Match.getOperand(2))->get();
44392       if ((BinOp == ISD::AND && CC == ISD::CondCode::SETEQ) ||
44393           (BinOp == ISD::OR && CC == ISD::CondCode::SETNE)) {
44394         // For all_of(setcc(x,y,eq)) - use (iX)x == (iX)y.
44395         // For any_of(setcc(x,y,ne)) - use (iX)x != (iX)y.
44396         X86::CondCode X86CC;
44397         SDValue LHS = DAG.getFreeze(Match.getOperand(0));
44398         SDValue RHS = DAG.getFreeze(Match.getOperand(1));
44399         APInt Mask = APInt::getAllOnes(LHS.getScalarValueSizeInBits());
44400         if (SDValue V = LowerVectorAllEqual(DL, LHS, RHS, CC, Mask, Subtarget,
44401                                             DAG, X86CC))
44402           return DAG.getNode(ISD::TRUNCATE, DL, ExtractVT,
44403                              getSETCC(X86CC, V, DL, DAG));
44404       }
44405     }
44406     if (TLI.isTypeLegal(MatchVT)) {
44407       // If this is a legal AVX512 predicate type then we can just bitcast.
44408       EVT MovmskVT = EVT::getIntegerVT(Ctx, NumElts);
44409       Movmsk = DAG.getBitcast(MovmskVT, Match);
44410     } else {
44411       // Use combineBitcastvxi1 to create the MOVMSK.
44412       while (NumElts > MaxElts) {
44413         SDValue Lo, Hi;
44414         std::tie(Lo, Hi) = DAG.SplitVector(Match, DL);
44415         Match = DAG.getNode(BinOp, DL, Lo.getValueType(), Lo, Hi);
44416         NumElts /= 2;
44417       }
44418       EVT MovmskVT = EVT::getIntegerVT(Ctx, NumElts);
44419       Movmsk = combineBitcastvxi1(DAG, MovmskVT, Match, DL, Subtarget);
44420     }
44421     if (!Movmsk)
44422       return SDValue();
44423     Movmsk = DAG.getZExtOrTrunc(Movmsk, DL, NumElts > 32 ? MVT::i64 : MVT::i32);
44424   } else {
44425     // FIXME: Better handling of k-registers or 512-bit vectors?
44426     unsigned MatchSizeInBits = Match.getValueSizeInBits();
44427     if (!(MatchSizeInBits == 128 ||
44428           (MatchSizeInBits == 256 && Subtarget.hasAVX())))
44429       return SDValue();
44430 
44431     // Make sure this isn't a vector of 1 element. The perf win from using
44432     // MOVMSK diminishes with fewer elements in the reduction, but it is
44433     // generally better to get the comparison over to the GPRs as soon as
44434     // possible to reduce the number of vector ops.
44435     if (Match.getValueType().getVectorNumElements() < 2)
44436       return SDValue();
44437 
44438     // Check that we are extracting a reduction of all sign bits.
44439     if (DAG.ComputeNumSignBits(Match) != BitWidth)
44440       return SDValue();
44441 
44442     if (MatchSizeInBits == 256 && BitWidth < 32 && !Subtarget.hasInt256()) {
44443       SDValue Lo, Hi;
44444       std::tie(Lo, Hi) = DAG.SplitVector(Match, DL);
44445       Match = DAG.getNode(BinOp, DL, Lo.getValueType(), Lo, Hi);
44446       MatchSizeInBits = Match.getValueSizeInBits();
44447     }
44448 
44449     // For 32/64 bit comparisons use MOVMSKPS/MOVMSKPD, else PMOVMSKB.
44450     MVT MaskSrcVT;
44451     if (64 == BitWidth || 32 == BitWidth)
44452       MaskSrcVT = MVT::getVectorVT(MVT::getFloatingPointVT(BitWidth),
44453                                    MatchSizeInBits / BitWidth);
44454     else
44455       MaskSrcVT = MVT::getVectorVT(MVT::i8, MatchSizeInBits / 8);
44456 
44457     SDValue BitcastLogicOp = DAG.getBitcast(MaskSrcVT, Match);
44458     Movmsk = getPMOVMSKB(DL, BitcastLogicOp, DAG, Subtarget);
44459     NumElts = MaskSrcVT.getVectorNumElements();
44460   }
44461   assert((NumElts <= 32 || NumElts == 64) &&
44462          "Not expecting more than 64 elements");
44463 
44464   MVT CmpVT = NumElts == 64 ? MVT::i64 : MVT::i32;
44465   if (BinOp == ISD::XOR) {
44466     // parity -> (PARITY(MOVMSK X))
44467     SDValue Result = DAG.getNode(ISD::PARITY, DL, CmpVT, Movmsk);
44468     return DAG.getZExtOrTrunc(Result, DL, ExtractVT);
44469   }
44470 
44471   SDValue CmpC;
44472   ISD::CondCode CondCode;
44473   if (BinOp == ISD::OR) {
44474     // any_of -> MOVMSK != 0
44475     CmpC = DAG.getConstant(0, DL, CmpVT);
44476     CondCode = ISD::CondCode::SETNE;
44477   } else {
44478     // all_of -> MOVMSK == ((1 << NumElts) - 1)
44479     CmpC = DAG.getConstant(APInt::getLowBitsSet(CmpVT.getSizeInBits(), NumElts),
44480                            DL, CmpVT);
44481     CondCode = ISD::CondCode::SETEQ;
44482   }
44483 
44484   // The setcc produces an i8 of 0/1, so extend that to the result width and
44485   // negate to get the final 0/-1 mask value.
44486   EVT SetccVT = TLI.getSetCCResultType(DAG.getDataLayout(), Ctx, CmpVT);
44487   SDValue Setcc = DAG.getSetCC(DL, SetccVT, Movmsk, CmpC, CondCode);
44488   SDValue Zext = DAG.getZExtOrTrunc(Setcc, DL, ExtractVT);
44489   return DAG.getNegative(Zext, DL, ExtractVT);
44490 }
44491 
44492 static SDValue combineVPDPBUSDPattern(SDNode *Extract, SelectionDAG &DAG,
44493                                       const X86Subtarget &Subtarget) {
44494   if (!Subtarget.hasVNNI() && !Subtarget.hasAVXVNNI())
44495     return SDValue();
44496 
44497   EVT ExtractVT = Extract->getValueType(0);
44498   // Verify the type we're extracting is i32, as the output element type of
44499   // vpdpbusd is i32.
44500   if (ExtractVT != MVT::i32)
44501     return SDValue();
44502 
44503   EVT VT = Extract->getOperand(0).getValueType();
44504   if (!isPowerOf2_32(VT.getVectorNumElements()))
44505     return SDValue();
44506 
44507   // Match shuffle + add pyramid.
44508   ISD::NodeType BinOp;
44509   SDValue Root = DAG.matchBinOpReduction(Extract, BinOp, {ISD::ADD});
44510 
44511   // We can't combine to vpdpbusd for zext, because each of the 4 multiplies
44512   // done by vpdpbusd computes a signed 16-bit product that will be sign-extended
44513   // before adding into the accumulator.
44514   // TODO:
44515   // We also need to verify that the multiply has at least 2x the number of bits
44516   // of the input. We shouldn't match
44517   // (sign_extend (mul (vXi9 (zext (vXi8 X))), (vXi9 (zext (vXi8 Y)))).
44518   // if (Root && (Root.getOpcode() == ISD::SIGN_EXTEND))
44519   //   Root = Root.getOperand(0);
44520 
44521   // If there was a match, we want Root to be a mul.
44522   if (!Root || Root.getOpcode() != ISD::MUL)
44523     return SDValue();
44524 
44525   // Check whether we have an extend and mul pattern
44526   SDValue LHS, RHS;
44527   if (!detectExtMul(DAG, Root, LHS, RHS))
44528     return SDValue();
44529 
44530   // Create the dot product instruction.
44531   SDLoc DL(Extract);
44532   unsigned StageBias;
44533   SDValue DP = createVPDPBUSD(DAG, LHS, RHS, StageBias, DL, Subtarget);
44534 
44535   // If the original vector was wider than 4 elements, sum over the results
44536   // in the DP vector.
44537   unsigned Stages = Log2_32(VT.getVectorNumElements());
44538   EVT DpVT = DP.getValueType();
44539 
44540   if (Stages > StageBias) {
44541     unsigned DpElems = DpVT.getVectorNumElements();
44542 
44543     for (unsigned i = Stages - StageBias; i > 0; --i) {
44544       SmallVector<int, 16> Mask(DpElems, -1);
44545       for (unsigned j = 0, MaskEnd = 1 << (i - 1); j < MaskEnd; ++j)
44546         Mask[j] = MaskEnd + j;
44547 
44548       SDValue Shuffle =
44549           DAG.getVectorShuffle(DpVT, DL, DP, DAG.getUNDEF(DpVT), Mask);
44550       DP = DAG.getNode(ISD::ADD, DL, DpVT, DP, Shuffle);
44551     }
44552   }
44553 
44554   // Return the lowest ExtractSizeInBits bits.
44555   EVT ResVT =
44556       EVT::getVectorVT(*DAG.getContext(), ExtractVT,
44557                        DpVT.getSizeInBits() / ExtractVT.getSizeInBits());
44558   DP = DAG.getBitcast(ResVT, DP);
44559   return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ExtractVT, DP,
44560                      Extract->getOperand(1));
44561 }
44562 
44563 static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
44564                                       const X86Subtarget &Subtarget) {
44565   // PSADBW is only supported on SSE2 and up.
44566   if (!Subtarget.hasSSE2())
44567     return SDValue();
44568 
44569   EVT ExtractVT = Extract->getValueType(0);
44570   // Verify the type we're extracting is either i32 or i64.
44571   // FIXME: Could support other types, but this is what we have coverage for.
44572   if (ExtractVT != MVT::i32 && ExtractVT != MVT::i64)
44573     return SDValue();
44574 
44575   EVT VT = Extract->getOperand(0).getValueType();
44576   if (!isPowerOf2_32(VT.getVectorNumElements()))
44577     return SDValue();
44578 
44579   // Match shuffle + add pyramid.
44580   ISD::NodeType BinOp;
44581   SDValue Root = DAG.matchBinOpReduction(Extract, BinOp, {ISD::ADD});
44582 
44583   // The operand is expected to be zero-extended from i8
44584   // (verified in detectZextAbsDiff).
44585   // To convert to i64 and above, an additional any/zero/sign
44586   // extend is expected.
44587   // A zero extend from 32 bits has no mathematical effect on the result.
44588   // The sign extend is also effectively a zero extend
44589   // (it extends the sign bit, which is zero).
44590   // So it is correct to skip the sign/zero extend instruction.
44591   if (Root && (Root.getOpcode() == ISD::SIGN_EXTEND ||
44592                Root.getOpcode() == ISD::ZERO_EXTEND ||
44593                Root.getOpcode() == ISD::ANY_EXTEND))
44594     Root = Root.getOperand(0);
44595 
44596   // If there was a match, we want Root to be the ABS node at the root of an
44597   // abs-diff pattern.
44598   if (!Root || Root.getOpcode() != ISD::ABS)
44599     return SDValue();
44600 
44601   // Check whether we have an abs-diff pattern feeding into the select.
44602   SDValue Zext0, Zext1;
44603   if (!detectZextAbsDiff(Root, Zext0, Zext1))
44604     return SDValue();
44605 
44606   // Create the SAD instruction.
44607   SDLoc DL(Extract);
44608   SDValue SAD = createPSADBW(DAG, Zext0, Zext1, DL, Subtarget);
44609 
44610   // If the original vector was wider than 8 elements, sum over the results
44611   // in the SAD vector.
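  // e.g. for a v32i32 absdiff source the add pyramid has Log2(32) = 5 stages,
  // but PSADBW already sums 8 bytes into each i64 lane (3 stages), so only
  // 5 - 3 = 2 shuffle+add steps over the SAD vector remain.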
44612   unsigned Stages = Log2_32(VT.getVectorNumElements());
44613   EVT SadVT = SAD.getValueType();
44614   if (Stages > 3) {
44615     unsigned SadElems = SadVT.getVectorNumElements();
44616 
44617     for (unsigned i = Stages - 3; i > 0; --i) {
44618       SmallVector<int, 16> Mask(SadElems, -1);
44619       for (unsigned j = 0, MaskEnd = 1 << (i - 1); j < MaskEnd; ++j)
44620         Mask[j] = MaskEnd + j;
44621 
44622       SDValue Shuffle =
44623           DAG.getVectorShuffle(SadVT, DL, SAD, DAG.getUNDEF(SadVT), Mask);
44624       SAD = DAG.getNode(ISD::ADD, DL, SadVT, SAD, Shuffle);
44625     }
44626   }
44627 
44628   unsigned ExtractSizeInBits = ExtractVT.getSizeInBits();
44629   // Return the lowest ExtractSizeInBits bits.
44630   EVT ResVT = EVT::getVectorVT(*DAG.getContext(), ExtractVT,
44631                                SadVT.getSizeInBits() / ExtractSizeInBits);
44632   SAD = DAG.getBitcast(ResVT, SAD);
44633   return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ExtractVT, SAD,
44634                      Extract->getOperand(1));
44635 }
44636 
44637 // If this extract is from a loaded vector value and will be used as an
44638 // integer, that requires a potentially expensive XMM -> GPR transfer.
44639 // Additionally, if we can convert to a scalar integer load, that will likely
44640 // be folded into a subsequent integer op.
44641 // Note: SrcVec might not have a VecVT type, but it must be the same size.
44642 // Note: Unlike the related fold for this in DAGCombiner, this is not limited
44643 //       to a single-use of the loaded vector. For the reasons above, we
44644 //       expect this to be profitable even if it creates an extra load.
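// e.g. (i32 (extract_elt (load <4 x i32>* %p), 2)) can be rebuilt as a scalar
// (i32 (load (%p + 8))), trading the vector load + XMM->GPR move for a cheap
// scalar load that later integer ops may fold.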
44645 static SDValue
44646 combineExtractFromVectorLoad(SDNode *N, EVT VecVT, SDValue SrcVec, uint64_t Idx,
44647                              const SDLoc &dl, SelectionDAG &DAG,
44648                              TargetLowering::DAGCombinerInfo &DCI) {
44649   assert(N->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
44650          "Only EXTRACT_VECTOR_ELT supported so far");
44651 
44652   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
44653   EVT VT = N->getValueType(0);
44654 
44655   bool LikelyUsedAsVector = any_of(N->uses(), [](SDNode *Use) {
44656     return Use->getOpcode() == ISD::STORE ||
44657            Use->getOpcode() == ISD::INSERT_VECTOR_ELT ||
44658            Use->getOpcode() == ISD::SCALAR_TO_VECTOR;
44659   });
44660 
44661   auto *LoadVec = dyn_cast<LoadSDNode>(SrcVec);
44662   if (LoadVec && ISD::isNormalLoad(LoadVec) && VT.isInteger() &&
44663       VecVT.getVectorElementType() == VT &&
44664       VecVT.getSizeInBits() == SrcVec.getValueSizeInBits() &&
44665       DCI.isAfterLegalizeDAG() && !LikelyUsedAsVector && LoadVec->isSimple()) {
44666     SDValue NewPtr = TLI.getVectorElementPointer(
44667         DAG, LoadVec->getBasePtr(), VecVT, DAG.getVectorIdxConstant(Idx, dl));
44668     unsigned PtrOff = VT.getSizeInBits() * Idx / 8;
44669     MachinePointerInfo MPI = LoadVec->getPointerInfo().getWithOffset(PtrOff);
44670     Align Alignment = commonAlignment(LoadVec->getAlign(), PtrOff);
44671     SDValue Load =
44672         DAG.getLoad(VT, dl, LoadVec->getChain(), NewPtr, MPI, Alignment,
44673                     LoadVec->getMemOperand()->getFlags(), LoadVec->getAAInfo());
44674     DAG.makeEquivalentMemoryOrdering(LoadVec, Load);
44675     return Load;
44676   }
44677 
44678   return SDValue();
44679 }
44680 
44681 // Attempt to peek through a target shuffle and extract the scalar from the
44682 // source.
44683 static SDValue combineExtractWithShuffle(SDNode *N, SelectionDAG &DAG,
44684                                          TargetLowering::DAGCombinerInfo &DCI,
44685                                          const X86Subtarget &Subtarget) {
44686   if (DCI.isBeforeLegalizeOps())
44687     return SDValue();
44688 
44689   SDLoc dl(N);
44690   SDValue Src = N->getOperand(0);
44691   SDValue Idx = N->getOperand(1);
44692 
44693   EVT VT = N->getValueType(0);
44694   EVT SrcVT = Src.getValueType();
44695   EVT SrcSVT = SrcVT.getVectorElementType();
44696   unsigned SrcEltBits = SrcSVT.getSizeInBits();
44697   unsigned NumSrcElts = SrcVT.getVectorNumElements();
44698 
44699   // Don't attempt this for boolean mask vectors or unknown extraction indices.
44700   if (SrcSVT == MVT::i1 || !isa<ConstantSDNode>(Idx))
44701     return SDValue();
44702 
44703   const APInt &IdxC = N->getConstantOperandAPInt(1);
44704   if (IdxC.uge(NumSrcElts))
44705     return SDValue();
44706 
44707   SDValue SrcBC = peekThroughBitcasts(Src);
44708 
44709   // Handle extract(bitcast(broadcast(scalar_value))).
44710   if (X86ISD::VBROADCAST == SrcBC.getOpcode()) {
44711     SDValue SrcOp = SrcBC.getOperand(0);
44712     EVT SrcOpVT = SrcOp.getValueType();
44713     if (SrcOpVT.isScalarInteger() && VT.isInteger() &&
44714         (SrcOpVT.getSizeInBits() % SrcEltBits) == 0) {
44715       unsigned Scale = SrcOpVT.getSizeInBits() / SrcEltBits;
44716       unsigned Offset = IdxC.urem(Scale) * SrcEltBits;
44717       // TODO support non-zero offsets.
44718       if (Offset == 0) {
44719         SrcOp = DAG.getZExtOrTrunc(SrcOp, dl, SrcVT.getScalarType());
44720         SrcOp = DAG.getZExtOrTrunc(SrcOp, dl, VT);
44721         return SrcOp;
44722       }
44723     }
44724   }
44725 
44726   // If we're extracting a single element from a broadcast load and there are
44727   // no other users, just create a single load.
44728   if (SrcBC.getOpcode() == X86ISD::VBROADCAST_LOAD && SrcBC.hasOneUse()) {
44729     auto *MemIntr = cast<MemIntrinsicSDNode>(SrcBC);
44730     unsigned SrcBCWidth = SrcBC.getScalarValueSizeInBits();
44731     if (MemIntr->getMemoryVT().getSizeInBits() == SrcBCWidth &&
44732         VT.getSizeInBits() == SrcBCWidth && SrcEltBits == SrcBCWidth) {
44733       SDValue Load = DAG.getLoad(VT, dl, MemIntr->getChain(),
44734                                  MemIntr->getBasePtr(),
44735                                  MemIntr->getPointerInfo(),
44736                                  MemIntr->getOriginalAlign(),
44737                                  MemIntr->getMemOperand()->getFlags());
44738       DAG.ReplaceAllUsesOfValueWith(SDValue(MemIntr, 1), Load.getValue(1));
44739       return Load;
44740     }
44741   }
44742 
44743   // Handle extract(bitcast(scalar_to_vector(scalar_value))) for integers.
44744   // TODO: Move to DAGCombine?
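  // e.g. (i8 (extract_elt (v16i8 (bitcast (v4i32 (scalar_to_vector X:i32)))), 2))
  // becomes (i8 (trunc (srl X, 16))): shift the scalar right by the bit offset
  // of the element and truncate, avoiding any vector op.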
44745   if (SrcBC.getOpcode() == ISD::SCALAR_TO_VECTOR && VT.isInteger() &&
44746       SrcBC.getValueType().isInteger() &&
44747       (SrcBC.getScalarValueSizeInBits() % SrcEltBits) == 0 &&
44748       SrcBC.getScalarValueSizeInBits() ==
44749           SrcBC.getOperand(0).getValueSizeInBits()) {
44750     unsigned Scale = SrcBC.getScalarValueSizeInBits() / SrcEltBits;
44751     if (IdxC.ult(Scale)) {
44752       unsigned Offset = IdxC.getZExtValue() * SrcVT.getScalarSizeInBits();
44753       SDValue Scl = SrcBC.getOperand(0);
44754       EVT SclVT = Scl.getValueType();
44755       if (Offset) {
44756         Scl = DAG.getNode(ISD::SRL, dl, SclVT, Scl,
44757                           DAG.getShiftAmountConstant(Offset, SclVT, dl));
44758       }
44759       Scl = DAG.getZExtOrTrunc(Scl, dl, SrcVT.getScalarType());
44760       Scl = DAG.getZExtOrTrunc(Scl, dl, VT);
44761       return Scl;
44762     }
44763   }
44764 
44765   // Handle extract(truncate(x)) for 0'th index.
44766   // TODO: Treat this as a faux shuffle?
44767   // TODO: When can we use this for general indices?
44768   if (ISD::TRUNCATE == Src.getOpcode() && IdxC == 0 &&
44769       (SrcVT.getSizeInBits() % 128) == 0) {
44770     Src = extract128BitVector(Src.getOperand(0), 0, DAG, dl);
44771     MVT ExtractVT = MVT::getVectorVT(SrcSVT.getSimpleVT(), 128 / SrcEltBits);
44772     return DAG.getNode(N->getOpcode(), dl, VT, DAG.getBitcast(ExtractVT, Src),
44773                        Idx);
44774   }
44775 
44776   // We can only legally extract other elements from 128-bit vectors and in
44777   // certain circumstances, depending on SSE-level.
44778   // TODO: Investigate float/double extraction if it will be just stored.
44779   auto GetLegalExtract = [&Subtarget, &DAG, &dl](SDValue Vec, EVT VecVT,
44780                                                  unsigned Idx) {
44781     EVT VecSVT = VecVT.getScalarType();
44782     if ((VecVT.is256BitVector() || VecVT.is512BitVector()) &&
44783         (VecSVT == MVT::i8 || VecSVT == MVT::i16 || VecSVT == MVT::i32 ||
44784          VecSVT == MVT::i64)) {
44785       unsigned EltSizeInBits = VecSVT.getSizeInBits();
44786       unsigned NumEltsPerLane = 128 / EltSizeInBits;
44787       unsigned LaneOffset = (Idx & ~(NumEltsPerLane - 1)) * EltSizeInBits;
44788       unsigned LaneIdx = LaneOffset / Vec.getScalarValueSizeInBits();
44789       VecVT = EVT::getVectorVT(*DAG.getContext(), VecSVT, NumEltsPerLane);
44790       Vec = extract128BitVector(Vec, LaneIdx, DAG, dl);
44791       Idx &= (NumEltsPerLane - 1);
44792     }
44793     if ((VecVT == MVT::v4i32 || VecVT == MVT::v2i64) &&
44794         ((Idx == 0 && Subtarget.hasSSE2()) || Subtarget.hasSSE41())) {
44795       return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VecVT.getScalarType(),
44796                          DAG.getBitcast(VecVT, Vec),
44797                          DAG.getIntPtrConstant(Idx, dl));
44798     }
44799     if ((VecVT == MVT::v8i16 && Subtarget.hasSSE2()) ||
44800         (VecVT == MVT::v16i8 && Subtarget.hasSSE41())) {
44801       unsigned OpCode = (VecVT == MVT::v8i16 ? X86ISD::PEXTRW : X86ISD::PEXTRB);
44802       return DAG.getNode(OpCode, dl, MVT::i32, DAG.getBitcast(VecVT, Vec),
44803                          DAG.getTargetConstant(Idx, dl, MVT::i8));
44804     }
44805     return SDValue();
44806   };
44807 
44808   // Resolve the target shuffle inputs and mask.
44809   SmallVector<int, 16> Mask;
44810   SmallVector<SDValue, 2> Ops;
44811   if (!getTargetShuffleInputs(SrcBC, Ops, Mask, DAG))
44812     return SDValue();
44813 
44814   // Shuffle inputs must be the same size as the result.
44815   if (llvm::any_of(Ops, [SrcVT](SDValue Op) {
44816         return SrcVT.getSizeInBits() != Op.getValueSizeInBits();
44817       }))
44818     return SDValue();
44819 
44820   // Attempt to narrow/widen the shuffle mask to the correct size.
44821   if (Mask.size() != NumSrcElts) {
44822     if ((NumSrcElts % Mask.size()) == 0) {
44823       SmallVector<int, 16> ScaledMask;
44824       int Scale = NumSrcElts / Mask.size();
44825       narrowShuffleMaskElts(Scale, Mask, ScaledMask);
44826       Mask = std::move(ScaledMask);
44827     } else if ((Mask.size() % NumSrcElts) == 0) {
44828       // Simplify Mask based on demanded element.
44829       int ExtractIdx = (int)IdxC.getZExtValue();
44830       int Scale = Mask.size() / NumSrcElts;
44831       int Lo = Scale * ExtractIdx;
44832       int Hi = Scale * (ExtractIdx + 1);
44833       for (int i = 0, e = (int)Mask.size(); i != e; ++i)
44834         if (i < Lo || Hi <= i)
44835           Mask[i] = SM_SentinelUndef;
44836 
44837       SmallVector<int, 16> WidenedMask;
44838       while (Mask.size() > NumSrcElts &&
44839              canWidenShuffleElements(Mask, WidenedMask))
44840         Mask = std::move(WidenedMask);
44841     }
44842   }
44843 
44844   // If narrowing/widening failed, see if we can extract+zero-extend.
44845   int ExtractIdx;
44846   EVT ExtractVT;
44847   if (Mask.size() == NumSrcElts) {
44848     ExtractIdx = Mask[IdxC.getZExtValue()];
44849     ExtractVT = SrcVT;
44850   } else {
44851     unsigned Scale = Mask.size() / NumSrcElts;
44852     if ((Mask.size() % NumSrcElts) != 0 || SrcVT.isFloatingPoint())
44853       return SDValue();
44854     unsigned ScaledIdx = Scale * IdxC.getZExtValue();
44855     if (!isUndefOrZeroInRange(Mask, ScaledIdx + 1, Scale - 1))
44856       return SDValue();
44857     ExtractIdx = Mask[ScaledIdx];
44858     EVT ExtractSVT = EVT::getIntegerVT(*DAG.getContext(), SrcEltBits / Scale);
44859     ExtractVT = EVT::getVectorVT(*DAG.getContext(), ExtractSVT, Mask.size());
44860     assert(SrcVT.getSizeInBits() == ExtractVT.getSizeInBits() &&
44861            "Failed to widen vector type");
44862   }
44863 
44864   // If the shuffle source element is undef/zero then we can just accept it.
44865   if (ExtractIdx == SM_SentinelUndef)
44866     return DAG.getUNDEF(VT);
44867 
44868   if (ExtractIdx == SM_SentinelZero)
44869     return VT.isFloatingPoint() ? DAG.getConstantFP(0.0, dl, VT)
44870                                 : DAG.getConstant(0, dl, VT);
44871 
44872   SDValue SrcOp = Ops[ExtractIdx / Mask.size()];
44873   ExtractIdx = ExtractIdx % Mask.size();
44874   if (SDValue V = GetLegalExtract(SrcOp, ExtractVT, ExtractIdx))
44875     return DAG.getZExtOrTrunc(V, dl, VT);
44876 
44877   if (N->getOpcode() == ISD::EXTRACT_VECTOR_ELT && ExtractVT == SrcVT)
44878     if (SDValue V = combineExtractFromVectorLoad(
44879             N, SrcVT, peekThroughBitcasts(SrcOp), ExtractIdx, dl, DAG, DCI))
44880       return V;
44881 
44882   return SDValue();
44883 }
44884 
44885 /// Extracting a scalar FP value from vector element 0 is free, so extract each
44886 /// operand first, then perform the math as a scalar op.
44887 static SDValue scalarizeExtEltFP(SDNode *ExtElt, SelectionDAG &DAG,
44888                                  const X86Subtarget &Subtarget) {
44889   assert(ExtElt->getOpcode() == ISD::EXTRACT_VECTOR_ELT && "Expected extract");
44890   SDValue Vec = ExtElt->getOperand(0);
44891   SDValue Index = ExtElt->getOperand(1);
44892   EVT VT = ExtElt->getValueType(0);
44893   EVT VecVT = Vec.getValueType();
44894 
44895   // TODO: If this is a unary/expensive/expand op, allow extraction from a
44896   // non-zero element because the shuffle+scalar op will be cheaper?
44897   if (!Vec.hasOneUse() || !isNullConstant(Index) || VecVT.getScalarType() != VT)
44898     return SDValue();
44899 
44900   // Vector FP compares don't fit the pattern of FP math ops (propagate, not
44901   // extract, the condition code), so deal with those as a special-case.
44902   if (Vec.getOpcode() == ISD::SETCC && VT == MVT::i1) {
44903     EVT OpVT = Vec.getOperand(0).getValueType().getScalarType();
44904     if (OpVT != MVT::f32 && OpVT != MVT::f64)
44905       return SDValue();
44906 
44907     // extract (setcc X, Y, CC), 0 --> setcc (extract X, 0), (extract Y, 0), CC
44908     SDLoc DL(ExtElt);
44909     SDValue Ext0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, OpVT,
44910                                Vec.getOperand(0), Index);
44911     SDValue Ext1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, OpVT,
44912                                Vec.getOperand(1), Index);
44913     return DAG.getNode(Vec.getOpcode(), DL, VT, Ext0, Ext1, Vec.getOperand(2));
44914   }
44915 
44916   if (!(VT == MVT::f16 && Subtarget.hasFP16()) && VT != MVT::f32 &&
44917       VT != MVT::f64)
44918     return SDValue();
44919 
44920   // Vector FP selects don't fit the pattern of FP math ops (because the
44921   // condition has a different type and we have to change the opcode), so deal
44922   // with those here.
44923   // FIXME: This is restricted to pre type legalization by ensuring the setcc
44924   // has i1 elements. If we loosen this we need to convert vector bool to a
44925   // scalar bool.
44926   if (Vec.getOpcode() == ISD::VSELECT &&
44927       Vec.getOperand(0).getOpcode() == ISD::SETCC &&
44928       Vec.getOperand(0).getValueType().getScalarType() == MVT::i1 &&
44929       Vec.getOperand(0).getOperand(0).getValueType() == VecVT) {
44930     // ext (sel Cond, X, Y), 0 --> sel (ext Cond, 0), (ext X, 0), (ext Y, 0)
44931     SDLoc DL(ExtElt);
44932     SDValue Ext0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL,
44933                                Vec.getOperand(0).getValueType().getScalarType(),
44934                                Vec.getOperand(0), Index);
44935     SDValue Ext1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT,
44936                                Vec.getOperand(1), Index);
44937     SDValue Ext2 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT,
44938                                Vec.getOperand(2), Index);
44939     return DAG.getNode(ISD::SELECT, DL, VT, Ext0, Ext1, Ext2);
44940   }
44941 
44942   // TODO: This switch could include FNEG and the x86-specific FP logic ops
44943   // (FAND, FANDN, FOR, FXOR). But that may require enhancements to avoid
44944   // missed load folding and fma+fneg combining.
44945   switch (Vec.getOpcode()) {
44946   case ISD::FMA: // Begin 3 operands
44947   case ISD::FMAD:
44948   case ISD::FADD: // Begin 2 operands
44949   case ISD::FSUB:
44950   case ISD::FMUL:
44951   case ISD::FDIV:
44952   case ISD::FREM:
44953   case ISD::FCOPYSIGN:
44954   case ISD::FMINNUM:
44955   case ISD::FMAXNUM:
44956   case ISD::FMINNUM_IEEE:
44957   case ISD::FMAXNUM_IEEE:
44958   case ISD::FMAXIMUM:
44959   case ISD::FMINIMUM:
44960   case X86ISD::FMAX:
44961   case X86ISD::FMIN:
44962   case ISD::FABS: // Begin 1 operand
44963   case ISD::FSQRT:
44964   case ISD::FRINT:
44965   case ISD::FCEIL:
44966   case ISD::FTRUNC:
44967   case ISD::FNEARBYINT:
44968   case ISD::FROUNDEVEN:
44969   case ISD::FROUND:
44970   case ISD::FFLOOR:
44971   case X86ISD::FRCP:
44972   case X86ISD::FRSQRT: {
44973     // extract (fp X, Y, ...), 0 --> fp (extract X, 0), (extract Y, 0), ...
44974     SDLoc DL(ExtElt);
44975     SmallVector<SDValue, 4> ExtOps;
44976     for (SDValue Op : Vec->ops())
44977       ExtOps.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op, Index));
44978     return DAG.getNode(Vec.getOpcode(), DL, VT, ExtOps);
44979   }
44980   default:
44981     return SDValue();
44982   }
44983   llvm_unreachable("All opcodes should return within switch");
44984 }
44985 
44986 /// Try to convert a vector reduction sequence composed of binops and shuffles
44987 /// into horizontal ops.
44988 static SDValue combineArithReduction(SDNode *ExtElt, SelectionDAG &DAG,
44989                                      const X86Subtarget &Subtarget) {
44990   assert(ExtElt->getOpcode() == ISD::EXTRACT_VECTOR_ELT && "Unexpected caller");
44991 
44992   // We need at least SSE2 to do anything here.
44993   if (!Subtarget.hasSSE2())
44994     return SDValue();
44995 
44996   ISD::NodeType Opc;
44997   SDValue Rdx = DAG.matchBinOpReduction(ExtElt, Opc,
44998                                         {ISD::ADD, ISD::MUL, ISD::FADD}, true);
44999   if (!Rdx)
45000     return SDValue();
45001 
45002   SDValue Index = ExtElt->getOperand(1);
45003   assert(isNullConstant(Index) &&
45004          "Reduction doesn't end in an extract from index 0");
45005 
45006   EVT VT = ExtElt->getValueType(0);
45007   EVT VecVT = Rdx.getValueType();
45008   if (VecVT.getScalarType() != VT)
45009     return SDValue();
45010 
45011   SDLoc DL(ExtElt);
45012   unsigned NumElts = VecVT.getVectorNumElements();
45013   unsigned EltSizeInBits = VecVT.getScalarSizeInBits();
45014 
45015   // Extend v4i8/v8i8 vector to v16i8, with undef upper 64-bits.
45016   auto WidenToV16I8 = [&](SDValue V, bool ZeroExtend) {
45017     if (V.getValueType() == MVT::v4i8) {
45018       if (ZeroExtend && Subtarget.hasSSE41()) {
45019         V = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, MVT::v4i32,
45020                         DAG.getConstant(0, DL, MVT::v4i32),
45021                         DAG.getBitcast(MVT::i32, V),
45022                         DAG.getIntPtrConstant(0, DL));
45023         return DAG.getBitcast(MVT::v16i8, V);
45024       }
45025       V = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v8i8, V,
45026                       ZeroExtend ? DAG.getConstant(0, DL, MVT::v4i8)
45027                                  : DAG.getUNDEF(MVT::v4i8));
45028     }
45029     return DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v16i8, V,
45030                        DAG.getUNDEF(MVT::v8i8));
45031   };
45032 
45033   // vXi8 mul reduction - promote to vXi16 mul reduction.
45034   if (Opc == ISD::MUL) {
45035     if (VT != MVT::i8 || NumElts < 4 || !isPowerOf2_32(NumElts))
45036       return SDValue();
45037     if (VecVT.getSizeInBits() >= 128) {
45038       EVT WideVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16, NumElts / 2);
45039       SDValue Lo = getUnpackl(DAG, DL, VecVT, Rdx, DAG.getUNDEF(VecVT));
45040       SDValue Hi = getUnpackh(DAG, DL, VecVT, Rdx, DAG.getUNDEF(VecVT));
45041       Lo = DAG.getBitcast(WideVT, Lo);
45042       Hi = DAG.getBitcast(WideVT, Hi);
45043       Rdx = DAG.getNode(Opc, DL, WideVT, Lo, Hi);
45044       while (Rdx.getValueSizeInBits() > 128) {
45045         std::tie(Lo, Hi) = splitVector(Rdx, DAG, DL);
45046         Rdx = DAG.getNode(Opc, DL, Lo.getValueType(), Lo, Hi);
45047       }
45048     } else {
45049       Rdx = WidenToV16I8(Rdx, false);
45050       Rdx = getUnpackl(DAG, DL, MVT::v16i8, Rdx, DAG.getUNDEF(MVT::v16i8));
45051       Rdx = DAG.getBitcast(MVT::v8i16, Rdx);
45052     }
45053     if (NumElts >= 8)
45054       Rdx = DAG.getNode(Opc, DL, MVT::v8i16, Rdx,
45055                         DAG.getVectorShuffle(MVT::v8i16, DL, Rdx, Rdx,
45056                                              {4, 5, 6, 7, -1, -1, -1, -1}));
45057     Rdx = DAG.getNode(Opc, DL, MVT::v8i16, Rdx,
45058                       DAG.getVectorShuffle(MVT::v8i16, DL, Rdx, Rdx,
45059                                            {2, 3, -1, -1, -1, -1, -1, -1}));
45060     Rdx = DAG.getNode(Opc, DL, MVT::v8i16, Rdx,
45061                       DAG.getVectorShuffle(MVT::v8i16, DL, Rdx, Rdx,
45062                                            {1, -1, -1, -1, -1, -1, -1, -1}));
45063     Rdx = DAG.getBitcast(MVT::v16i8, Rdx);
45064     return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Rdx, Index);
45065   }
45066 
45067   // vXi8 add reduction - sub-128-bit vector.
45068   if (VecVT == MVT::v4i8 || VecVT == MVT::v8i8) {
45069     Rdx = WidenToV16I8(Rdx, true);
45070     Rdx = DAG.getNode(X86ISD::PSADBW, DL, MVT::v2i64, Rdx,
45071                       DAG.getConstant(0, DL, MVT::v16i8));
45072     Rdx = DAG.getBitcast(MVT::v16i8, Rdx);
45073     return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Rdx, Index);
45074   }
45075 
45076   // Must be a >=128-bit vector with pow2 elements.
45077   if ((VecVT.getSizeInBits() % 128) != 0 || !isPowerOf2_32(NumElts))
45078     return SDValue();
45079 
45080   // vXi8 add reduction - sum lo/hi halves then use PSADBW.
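  // PSADBW against an all-zeros vector sums, per 64-bit lane, the 8 (unsigned)
  // byte values, i.e. a full horizontal byte add of each lane.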
45081   if (VT == MVT::i8) {
45082     while (Rdx.getValueSizeInBits() > 128) {
45083       SDValue Lo, Hi;
45084       std::tie(Lo, Hi) = splitVector(Rdx, DAG, DL);
45085       VecVT = Lo.getValueType();
45086       Rdx = DAG.getNode(ISD::ADD, DL, VecVT, Lo, Hi);
45087     }
45088     assert(VecVT == MVT::v16i8 && "v16i8 reduction expected");
45089 
45090     SDValue Hi = DAG.getVectorShuffle(
45091         MVT::v16i8, DL, Rdx, Rdx,
45092         {8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1});
45093     Rdx = DAG.getNode(ISD::ADD, DL, MVT::v16i8, Rdx, Hi);
45094     Rdx = DAG.getNode(X86ISD::PSADBW, DL, MVT::v2i64, Rdx,
45095                       getZeroVector(MVT::v16i8, Subtarget, DAG, DL));
45096     Rdx = DAG.getBitcast(MVT::v16i8, Rdx);
45097     return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Rdx, Index);
45098   }
45099 
45100   // See if we can use vXi8 PSADBW add reduction for larger zext types.
45101   // If the source vector values are 0-255, then we can use PSADBW to
45102   // sum+zext v8i8 subvectors to vXi64, then perform the reduction.
45103   // TODO: See if it's worth avoiding vXi16/i32 truncations?
45104   if (Opc == ISD::ADD && NumElts >= 4 && EltSizeInBits >= 16 &&
45105       DAG.computeKnownBits(Rdx).getMaxValue().ule(255) &&
45106       (EltSizeInBits == 16 || Rdx.getOpcode() == ISD::ZERO_EXTEND ||
45107        Subtarget.hasAVX512())) {
45108     if (Rdx.getValueType() == MVT::v8i16) {
45109       Rdx = DAG.getNode(X86ISD::PACKUS, DL, MVT::v16i8, Rdx,
45110                         DAG.getUNDEF(MVT::v8i16));
45111     } else {
45112       EVT ByteVT = VecVT.changeVectorElementType(MVT::i8);
45113       Rdx = DAG.getNode(ISD::TRUNCATE, DL, ByteVT, Rdx);
45114       if (ByteVT.getSizeInBits() < 128)
45115         Rdx = WidenToV16I8(Rdx, true);
45116     }
45117 
45118     // Build the PSADBW, split as 128/256/512 bits for SSE/AVX2/AVX512BW.
45119     auto PSADBWBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
45120                             ArrayRef<SDValue> Ops) {
45121       MVT VT = MVT::getVectorVT(MVT::i64, Ops[0].getValueSizeInBits() / 64);
45122       SDValue Zero = DAG.getConstant(0, DL, Ops[0].getValueType());
45123       return DAG.getNode(X86ISD::PSADBW, DL, VT, Ops[0], Zero);
45124     };
45125     MVT SadVT = MVT::getVectorVT(MVT::i64, Rdx.getValueSizeInBits() / 64);
45126     Rdx = SplitOpsAndApply(DAG, Subtarget, DL, SadVT, {Rdx}, PSADBWBuilder);
45127 
45128     // TODO: We could truncate to vXi16/vXi32 before performing the reduction.
45129     while (Rdx.getValueSizeInBits() > 128) {
45130       SDValue Lo, Hi;
45131       std::tie(Lo, Hi) = splitVector(Rdx, DAG, DL);
45132       VecVT = Lo.getValueType();
45133       Rdx = DAG.getNode(ISD::ADD, DL, VecVT, Lo, Hi);
45134     }
45135     assert(Rdx.getValueType() == MVT::v2i64 && "v2i64 reduction expected");
45136 
45137     if (NumElts > 8) {
45138       SDValue RdxHi = DAG.getVectorShuffle(MVT::v2i64, DL, Rdx, Rdx, {1, -1});
45139       Rdx = DAG.getNode(ISD::ADD, DL, MVT::v2i64, Rdx, RdxHi);
45140     }
45141 
45142     VecVT = MVT::getVectorVT(VT.getSimpleVT(), 128 / VT.getSizeInBits());
45143     Rdx = DAG.getBitcast(VecVT, Rdx);
45144     return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Rdx, Index);
45145   }
45146 
45147   // Only use (F)HADD opcodes if they aren't microcoded or we're minimizing code size.
45148   if (!shouldUseHorizontalOp(true, DAG, Subtarget))
45149     return SDValue();
45150 
45151   unsigned HorizOpcode = Opc == ISD::ADD ? X86ISD::HADD : X86ISD::FHADD;
45152 
45153   // 256-bit horizontal instructions operate on 128-bit chunks rather than
45154   // across the whole vector, so we need an extract + hop preliminary stage.
45155   // This is the only step where the operands of the hop are not the same value.
45156   // TODO: We could extend this to handle 512-bit or even longer vectors.
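  // A minimal sketch of the v8i32 case (hypothetical values): with
  // Rdx = <a,b,c,d,e,f,g,h>, Hi = <e,f,g,h> and Lo = <a,b,c,d>, the first hop
  // yields <e+f, g+h, a+b, c+d>; the Log2(4) = 2 in-place hops further below
  // then leave the full sum a+..+h in every lane, including lane 0.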
45157   if (((VecVT == MVT::v16i16 || VecVT == MVT::v8i32) && Subtarget.hasSSSE3()) ||
45158       ((VecVT == MVT::v8f32 || VecVT == MVT::v4f64) && Subtarget.hasSSE3())) {
45159     unsigned NumElts = VecVT.getVectorNumElements();
45160     SDValue Hi = extract128BitVector(Rdx, NumElts / 2, DAG, DL);
45161     SDValue Lo = extract128BitVector(Rdx, 0, DAG, DL);
45162     Rdx = DAG.getNode(HorizOpcode, DL, Lo.getValueType(), Hi, Lo);
45163     VecVT = Rdx.getValueType();
45164   }
45165   if (!((VecVT == MVT::v8i16 || VecVT == MVT::v4i32) && Subtarget.hasSSSE3()) &&
45166       !((VecVT == MVT::v4f32 || VecVT == MVT::v2f64) && Subtarget.hasSSE3()))
45167     return SDValue();
45168 
45169   // extract (add (shuf X), X), 0 --> extract (hadd X, X), 0
45170   unsigned ReductionSteps = Log2_32(VecVT.getVectorNumElements());
45171   for (unsigned i = 0; i != ReductionSteps; ++i)
45172     Rdx = DAG.getNode(HorizOpcode, DL, VecVT, Rdx, Rdx);
45173 
45174   return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Rdx, Index);
45175 }
45176 
45177 /// Detect vector gather/scatter index generation and convert it from being a
45178 /// bunch of shuffles and extracts into a somewhat faster sequence.
45179 /// For i686, the best sequence is apparently storing the value and loading
45180 /// scalars back, while for x64 we should use 64-bit extracts and shifts.
45181 static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG,
45182                                        TargetLowering::DAGCombinerInfo &DCI,
45183                                        const X86Subtarget &Subtarget) {
45184   if (SDValue NewOp = combineExtractWithShuffle(N, DAG, DCI, Subtarget))
45185     return NewOp;
45186 
45187   SDValue InputVector = N->getOperand(0);
45188   SDValue EltIdx = N->getOperand(1);
45189   auto *CIdx = dyn_cast<ConstantSDNode>(EltIdx);
45190 
45191   EVT SrcVT = InputVector.getValueType();
45192   EVT VT = N->getValueType(0);
45193   SDLoc dl(InputVector);
45194   bool IsPextr = N->getOpcode() != ISD::EXTRACT_VECTOR_ELT;
45195   unsigned NumSrcElts = SrcVT.getVectorNumElements();
45196   unsigned NumEltBits = VT.getScalarSizeInBits();
45197   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
45198 
45199   if (CIdx && CIdx->getAPIntValue().uge(NumSrcElts))
45200     return IsPextr ? DAG.getConstant(0, dl, VT) : DAG.getUNDEF(VT);
45201 
45202   // Integer Constant Folding.
45203   if (CIdx && VT.isInteger()) {
45204     APInt UndefVecElts;
45205     SmallVector<APInt, 16> EltBits;
45206     unsigned VecEltBitWidth = SrcVT.getScalarSizeInBits();
45207     if (getTargetConstantBitsFromNode(InputVector, VecEltBitWidth, UndefVecElts,
45208                                       EltBits, /*AllowWholeUndefs*/ true,
45209                                       /*AllowPartialUndefs*/ false)) {
45210       uint64_t Idx = CIdx->getZExtValue();
45211       if (UndefVecElts[Idx])
45212         return IsPextr ? DAG.getConstant(0, dl, VT) : DAG.getUNDEF(VT);
45213       return DAG.getConstant(EltBits[Idx].zext(NumEltBits), dl, VT);
45214     }
45215 
45216     // Convert extract_element(bitcast(<X x i1>)) -> bitcast(extract_subvector()).
45217     // Improves lowering of bool masks in Rust, which splits them into a byte array.
45218     if (InputVector.getOpcode() == ISD::BITCAST && (NumEltBits % 8) == 0) {
45219       SDValue Src = peekThroughBitcasts(InputVector);
45220       if (Src.getValueType().getScalarType() == MVT::i1 &&
45221           TLI.isTypeLegal(Src.getValueType())) {
45222         MVT SubVT = MVT::getVectorVT(MVT::i1, NumEltBits);
45223         SDValue Sub = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, SubVT, Src,
45224             DAG.getIntPtrConstant(CIdx->getZExtValue() * NumEltBits, dl));
45225         return DAG.getBitcast(VT, Sub);
45226       }
45227     }
45228   }
45229 
45230   if (IsPextr) {
45231     if (TLI.SimplifyDemandedBits(SDValue(N, 0), APInt::getAllOnes(NumEltBits),
45232                                  DCI))
45233       return SDValue(N, 0);
45234 
45235     // PEXTR*(PINSR*(v, s, c), c) -> s (with implicit zext handling).
45236     if ((InputVector.getOpcode() == X86ISD::PINSRB ||
45237          InputVector.getOpcode() == X86ISD::PINSRW) &&
45238         InputVector.getOperand(2) == EltIdx) {
45239       assert(SrcVT == InputVector.getOperand(0).getValueType() &&
45240              "Vector type mismatch");
45241       SDValue Scl = InputVector.getOperand(1);
45242       Scl = DAG.getNode(ISD::TRUNCATE, dl, SrcVT.getScalarType(), Scl);
45243       return DAG.getZExtOrTrunc(Scl, dl, VT);
45244     }
45245 
45246     // TODO - Remove this once we can handle the implicit zero-extension of
45247     // X86ISD::PEXTRW/X86ISD::PEXTRB in combinePredicateReduction and
45248     // combineBasicSADPattern.
45249     return SDValue();
45250   }
45251 
45252   // Detect mmx extraction of all bits as an i64. It works better as a bitcast.
45253   if (VT == MVT::i64 && SrcVT == MVT::v1i64 &&
45254       InputVector.getOpcode() == ISD::BITCAST &&
45255       InputVector.getOperand(0).getValueType() == MVT::x86mmx &&
45256       isNullConstant(EltIdx) && InputVector.hasOneUse())
45257     return DAG.getBitcast(VT, InputVector);
45258 
45259   // Detect mmx to i32 conversion through a v2i32 elt extract.
45260   if (VT == MVT::i32 && SrcVT == MVT::v2i32 &&
45261       InputVector.getOpcode() == ISD::BITCAST &&
45262       InputVector.getOperand(0).getValueType() == MVT::x86mmx &&
45263       isNullConstant(EltIdx) && InputVector.hasOneUse())
45264     return DAG.getNode(X86ISD::MMX_MOVD2W, dl, MVT::i32,
45265                        InputVector.getOperand(0));
45266 
45267   // Check whether this extract is the root of a sum of absolute differences
45268   // pattern. This has to be done here because we really want it to happen
45269   // pre-legalization,
45270   // pre-legalization.
45271     return SAD;
45272 
45273   if (SDValue VPDPBUSD = combineVPDPBUSDPattern(N, DAG, Subtarget))
45274     return VPDPBUSD;
45275 
45276   // Attempt to replace an all_of/any_of horizontal reduction with a MOVMSK.
45277   if (SDValue Cmp = combinePredicateReduction(N, DAG, Subtarget))
45278     return Cmp;
45279 
45280   // Attempt to replace min/max v8i16/v16i8 reductions with PHMINPOSUW.
45281   if (SDValue MinMax = combineMinMaxReduction(N, DAG, Subtarget))
45282     return MinMax;
45283 
45284   // Attempt to optimize ADD/FADD/MUL reductions with HADD, promotion etc..
45285   if (SDValue V = combineArithReduction(N, DAG, Subtarget))
45286     return V;
45287 
45288   if (SDValue V = scalarizeExtEltFP(N, DAG, Subtarget))
45289     return V;
45290 
45291   if (CIdx)
45292     if (SDValue V = combineExtractFromVectorLoad(
45293             N, InputVector.getValueType(), InputVector, CIdx->getZExtValue(),
45294             dl, DAG, DCI))
45295       return V;
45296 
45297   // Attempt to extract a i1 element by using MOVMSK to extract the signbits
45298   // and then testing the relevant element.
45299   //
45300   // Note that we only combine extracts on the *same* result number, i.e.
45301   //   t0 = merge_values a0, a1, a2, a3
45302   //   i1 = extract_vector_elt t0, Constant:i64<2>
45303   //   i1 = extract_vector_elt t0, Constant:i64<3>
45304   // but not
45305   //   i1 = extract_vector_elt t0:1, Constant:i64<2>
45306   // since the latter would need its own MOVMSK.
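  // Illustrative sketch (hypothetical mask value): for v4i1 X with
  // movmsk X == 0b0101, extracting element 2 becomes
  //   (0b0101 & (1 << 2)) == (1 << 2)  -->  true
  // so a single MOVMSK feeds every extract of X.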
45307   if (SrcVT.getScalarType() == MVT::i1) {
45308     bool IsVar = !CIdx;
45309     SmallVector<SDNode *, 16> BoolExtracts;
45310     unsigned ResNo = InputVector.getResNo();
45311     auto IsBoolExtract = [&BoolExtracts, &ResNo, &IsVar](SDNode *Use) {
45312       if (Use->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
45313           Use->getOperand(0).getResNo() == ResNo &&
45314           Use->getValueType(0) == MVT::i1) {
45315         BoolExtracts.push_back(Use);
45316         IsVar |= !isa<ConstantSDNode>(Use->getOperand(1));
45317         return true;
45318       }
45319       return false;
45320     };
45321     // TODO: Can we drop the oneuse check for constant extracts?
45322     if (all_of(InputVector->uses(), IsBoolExtract) &&
45323         (IsVar || BoolExtracts.size() > 1)) {
45324       EVT BCVT = EVT::getIntegerVT(*DAG.getContext(), NumSrcElts);
45325       if (SDValue BC =
45326               combineBitcastvxi1(DAG, BCVT, InputVector, dl, Subtarget)) {
45327         for (SDNode *Use : BoolExtracts) {
45328           // extractelement vXi1 X, MaskIdx --> ((movmsk X) & Mask) == Mask
45329           // Mask = 1 << MaskIdx
45330           SDValue MaskIdx = DAG.getZExtOrTrunc(Use->getOperand(1), dl, MVT::i8);
45331           SDValue MaskBit = DAG.getConstant(1, dl, BCVT);
45332           SDValue Mask = DAG.getNode(ISD::SHL, dl, BCVT, MaskBit, MaskIdx);
45333           SDValue Res = DAG.getNode(ISD::AND, dl, BCVT, BC, Mask);
45334           Res = DAG.getSetCC(dl, MVT::i1, Res, Mask, ISD::SETEQ);
45335           DCI.CombineTo(Use, Res);
45336         }
45337         return SDValue(N, 0);
45338       }
45339     }
45340   }
45341 
45342   // Attempt to fold extract(trunc(x),c) -> trunc(extract(x,c)).
45343   if (CIdx && InputVector.getOpcode() == ISD::TRUNCATE) {
45344     SDValue TruncSrc = InputVector.getOperand(0);
45345     EVT TruncSVT = TruncSrc.getValueType().getScalarType();
45346     if (DCI.isBeforeLegalize() && TLI.isTypeLegal(TruncSVT)) {
45347       SDValue NewExt =
45348           DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, TruncSVT, TruncSrc, EltIdx);
45349       return DAG.getAnyExtOrTrunc(NewExt, dl, VT);
45350     }
45351   }
45352 
45353   return SDValue();
45354 }
45355 
45356 // Convert (vXiY *ext(vXi1 bitcast(iX))) to extend_in_reg(broadcast(iX)).
45357 // This is more or less the reverse of combineBitcastvxi1.
45358 static SDValue combineToExtendBoolVectorInReg(
45359     unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N0, SelectionDAG &DAG,
45360     TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) {
45361   if (Opcode != ISD::SIGN_EXTEND && Opcode != ISD::ZERO_EXTEND &&
45362       Opcode != ISD::ANY_EXTEND)
45363     return SDValue();
45364   if (!DCI.isBeforeLegalizeOps())
45365     return SDValue();
45366   if (!Subtarget.hasSSE2() || Subtarget.hasAVX512())
45367     return SDValue();
45368 
45369   EVT SVT = VT.getScalarType();
45370   EVT InSVT = N0.getValueType().getScalarType();
45371   unsigned EltSizeInBits = SVT.getSizeInBits();
45372 
45373   // Input type must be extending a bool vector (bit-casted from a scalar
45374   // integer) to legal integer types.
45375   if (!VT.isVector())
45376     return SDValue();
45377   if (SVT != MVT::i64 && SVT != MVT::i32 && SVT != MVT::i16 && SVT != MVT::i8)
45378     return SDValue();
45379   if (InSVT != MVT::i1 || N0.getOpcode() != ISD::BITCAST)
45380     return SDValue();
45381 
45382   SDValue N00 = N0.getOperand(0);
45383   EVT SclVT = N00.getValueType();
45384   if (!SclVT.isScalarInteger())
45385     return SDValue();
45386 
45387   SDValue Vec;
45388   SmallVector<int> ShuffleMask;
45389   unsigned NumElts = VT.getVectorNumElements();
45390   assert(NumElts == SclVT.getSizeInBits() && "Unexpected bool vector size");
45391 
45392   // Broadcast the scalar integer to the vector elements.
45393   if (NumElts > EltSizeInBits) {
45394     // If the scalar integer is greater than the vector element size, then we
45395     // must split it down into sub-sections for broadcasting. For example:
45396     //   i16 -> v16i8 (i16 -> v8i16 -> v16i8) with 2 sub-sections.
45397     //   i32 -> v32i8 (i32 -> v8i32 -> v32i8) with 4 sub-sections.
45398     assert((NumElts % EltSizeInBits) == 0 && "Unexpected integer scale");
45399     unsigned Scale = NumElts / EltSizeInBits;
45400     EVT BroadcastVT = EVT::getVectorVT(*DAG.getContext(), SclVT, EltSizeInBits);
45401     Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, BroadcastVT, N00);
45402     Vec = DAG.getBitcast(VT, Vec);
45403 
45404     for (unsigned i = 0; i != Scale; ++i)
45405       ShuffleMask.append(EltSizeInBits, i);
45406     Vec = DAG.getVectorShuffle(VT, DL, Vec, Vec, ShuffleMask);
45407   } else if (Subtarget.hasAVX2() && NumElts < EltSizeInBits &&
45408              (SclVT == MVT::i8 || SclVT == MVT::i16 || SclVT == MVT::i32)) {
45409     // If we have register broadcast instructions, use the scalar size as the
45410     // element type for the shuffle. Then cast to the wider element type. The
45411     // widened bits won't be used, and this might allow the use of a broadcast
45412     // load.
45413     assert((EltSizeInBits % NumElts) == 0 && "Unexpected integer scale");
45414     unsigned Scale = EltSizeInBits / NumElts;
45415     EVT BroadcastVT =
45416         EVT::getVectorVT(*DAG.getContext(), SclVT, NumElts * Scale);
45417     Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, BroadcastVT, N00);
45418     ShuffleMask.append(NumElts * Scale, 0);
45419     Vec = DAG.getVectorShuffle(BroadcastVT, DL, Vec, Vec, ShuffleMask);
45420     Vec = DAG.getBitcast(VT, Vec);
45421   } else {
45422     // For smaller scalar integers, we can simply any-extend it to the vector
45423     // element size (we don't care about the upper bits) and broadcast it to all
45424     // elements.
45425     SDValue Scl = DAG.getAnyExtOrTrunc(N00, DL, SVT);
45426     Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VT, Scl);
45427     ShuffleMask.append(NumElts, 0);
45428     Vec = DAG.getVectorShuffle(VT, DL, Vec, Vec, ShuffleMask);
45429   }
45430 
45431   // Now, mask the relevant bit in each element.
45432   SmallVector<SDValue, 32> Bits;
45433   for (unsigned i = 0; i != NumElts; ++i) {
45434     int BitIdx = (i % EltSizeInBits);
45435     APInt Bit = APInt::getBitsSet(EltSizeInBits, BitIdx, BitIdx + 1);
45436     Bits.push_back(DAG.getConstant(Bit, DL, SVT));
45437   }
45438   SDValue BitMask = DAG.getBuildVector(VT, DL, Bits);
45439   Vec = DAG.getNode(ISD::AND, DL, VT, Vec, BitMask);
45440 
45441   // Compare against the bitmask and extend the result.
45442   EVT CCVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1, NumElts);
45443   Vec = DAG.getSetCC(DL, CCVT, Vec, BitMask, ISD::SETEQ);
45444   Vec = DAG.getSExtOrTrunc(Vec, DL, VT);
45445 
45446   // For SEXT, this is now done, otherwise shift the result down for
45447   // zero-extension.
45448   if (Opcode == ISD::SIGN_EXTEND)
45449     return Vec;
45450   return DAG.getNode(ISD::SRL, DL, VT, Vec,
45451                      DAG.getConstant(EltSizeInBits - 1, DL, VT));
45452 }
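// A small sketch of the scheme above (hypothetical input): extending
// v8i1 bitcast(i8 0b00000101) via SIGN_EXTEND to v8i16 broadcasts the i8,
// ANDs lane i with (1 << i) and compares for equality, so lanes 0 and 2
// become 0xFFFF and the rest 0; for ZERO_EXTEND the final SRL by 15 turns
// those into 1/0.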
45453 
45454 /// If a vector select has an operand that is -1 or 0, try to simplify the
45455 /// select to a bitwise logic operation.
45456 /// TODO: Move to DAGCombiner, possibly using TargetLowering::hasAndNot()?
45457 static SDValue
45458 combineVSelectWithAllOnesOrZeros(SDNode *N, SelectionDAG &DAG, const SDLoc &DL,
45459                                  TargetLowering::DAGCombinerInfo &DCI,
45460                                  const X86Subtarget &Subtarget) {
45461   SDValue Cond = N->getOperand(0);
45462   SDValue LHS = N->getOperand(1);
45463   SDValue RHS = N->getOperand(2);
45464   EVT VT = LHS.getValueType();
45465   EVT CondVT = Cond.getValueType();
45466   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
45467 
45468   if (N->getOpcode() != ISD::VSELECT)
45469     return SDValue();
45470 
45471   assert(CondVT.isVector() && "Vector select expects a vector selector!");
45472 
45473   // TODO: Use isNullOrNullSplat() to distinguish constants with undefs?
45474   // TODO: Can we assert that both operands are not zeros (because that should
45475   //       get simplified at node creation time)?
45476   bool TValIsAllZeros = ISD::isBuildVectorAllZeros(LHS.getNode());
45477   bool FValIsAllZeros = ISD::isBuildVectorAllZeros(RHS.getNode());
45478 
45479   // If both inputs are 0/undef, create a complete zero vector.
45480   // FIXME: As noted above this should be handled by DAGCombiner/getNode.
45481   if (TValIsAllZeros && FValIsAllZeros) {
45482     if (VT.isFloatingPoint())
45483       return DAG.getConstantFP(0.0, DL, VT);
45484     return DAG.getConstant(0, DL, VT);
45485   }
45486 
45487   // To use the condition operand as a bitwise mask, it must have elements that
45488   // are the same size as the select elements. Ie, the condition operand must
45489   // are the same size as the select elements. I.e., the condition operand must
45490   // Don't check if the types themselves are equal because that excludes
45491   // vector floating-point selects.
45492   if (CondVT.getScalarSizeInBits() != VT.getScalarSizeInBits())
45493     return SDValue();
45494 
45495   // Try to invert the condition if true value is not all 1s and false value is
45496   // not all 0s. Only do this if the condition has one use.
45497   bool TValIsAllOnes = ISD::isBuildVectorAllOnes(LHS.getNode());
45498   if (!TValIsAllOnes && !FValIsAllZeros && Cond.hasOneUse() &&
45499       // Check if the selector will be produced by CMPP*/PCMP*.
45500       Cond.getOpcode() == ISD::SETCC &&
45501       // Check if SETCC has already been promoted.
45502       TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT) ==
45503           CondVT) {
45504     bool FValIsAllOnes = ISD::isBuildVectorAllOnes(RHS.getNode());
45505 
45506     if (TValIsAllZeros || FValIsAllOnes) {
45507       SDValue CC = Cond.getOperand(2);
45508       ISD::CondCode NewCC = ISD::getSetCCInverse(
45509           cast<CondCodeSDNode>(CC)->get(), Cond.getOperand(0).getValueType());
45510       Cond = DAG.getSetCC(DL, CondVT, Cond.getOperand(0), Cond.getOperand(1),
45511                           NewCC);
45512       std::swap(LHS, RHS);
45513       TValIsAllOnes = FValIsAllOnes;
45514       FValIsAllZeros = TValIsAllZeros;
45515     }
45516   }
45517 
45518   // Cond value must be 'sign splat' to be converted to a logical op.
45519   if (DAG.ComputeNumSignBits(Cond) != CondVT.getScalarSizeInBits())
45520     return SDValue();
45521 
45522   // vselect Cond, 111..., 000... -> Cond
45523   if (TValIsAllOnes && FValIsAllZeros)
45524     return DAG.getBitcast(VT, Cond);
45525 
45526   if (!TLI.isTypeLegal(CondVT))
45527     return SDValue();
45528 
45529   // vselect Cond, 111..., X -> or Cond, X
45530   if (TValIsAllOnes) {
45531     SDValue CastRHS = DAG.getBitcast(CondVT, RHS);
45532     SDValue Or = DAG.getNode(ISD::OR, DL, CondVT, Cond, CastRHS);
45533     return DAG.getBitcast(VT, Or);
45534   }
45535 
45536   // vselect Cond, X, 000... -> and Cond, X
45537   if (FValIsAllZeros) {
45538     SDValue CastLHS = DAG.getBitcast(CondVT, LHS);
45539     SDValue And = DAG.getNode(ISD::AND, DL, CondVT, Cond, CastLHS);
45540     return DAG.getBitcast(VT, And);
45541   }
45542 
45543   // vselect Cond, 000..., X -> andn Cond, X
45544   if (TValIsAllZeros) {
45545     SDValue CastRHS = DAG.getBitcast(CondVT, RHS);
45546     SDValue AndN;
45547     // The canonical form differs for i1 vectors - x86andnp is not used
45548     if (CondVT.getScalarType() == MVT::i1)
45549       AndN = DAG.getNode(ISD::AND, DL, CondVT, DAG.getNOT(DL, Cond, CondVT),
45550                          CastRHS);
45551     else
45552       AndN = DAG.getNode(X86ISD::ANDNP, DL, CondVT, Cond, CastRHS);
45553     return DAG.getBitcast(VT, AndN);
45554   }
45555 
45556   return SDValue();
45557 }
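// A one-lane sketch (hypothetical i8 lanes) of the folds above: with a
// sign-splat condition lane C in {0x00, 0xFF},
//   vselect C, 0xFF, X  ==  C | X   and   vselect C, X, 0x00  ==  C & X
// which is why the code insists that Cond is all-zeros/all-ones per element.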
45558 
45559 /// If both arms of a vector select are concatenated vectors, split the select,
45560 /// and concatenate the result to eliminate a wide (256-bit) vector instruction:
45561 ///   vselect Cond, (concat T0, T1), (concat F0, F1) -->
45562 ///   concat (vselect (split Cond), T0, F0), (vselect (split Cond), T1, F1)
45563 static SDValue narrowVectorSelect(SDNode *N, SelectionDAG &DAG, const SDLoc &DL,
45564                                   const X86Subtarget &Subtarget) {
45565   unsigned Opcode = N->getOpcode();
45566   if (Opcode != X86ISD::BLENDV && Opcode != ISD::VSELECT)
45567     return SDValue();
45568 
45569   // TODO: Split 512-bit vectors too?
45570   EVT VT = N->getValueType(0);
45571   if (!VT.is256BitVector())
45572     return SDValue();
45573 
45574   // TODO: Split as long as any 2 of the 3 operands are concatenated?
45575   SDValue Cond = N->getOperand(0);
45576   SDValue TVal = N->getOperand(1);
45577   SDValue FVal = N->getOperand(2);
45578   if (!TVal.hasOneUse() || !FVal.hasOneUse() ||
45579       !isFreeToSplitVector(TVal.getNode(), DAG) ||
45580       !isFreeToSplitVector(FVal.getNode(), DAG))
45581     return SDValue();
45582 
45583   auto makeBlend = [Opcode](SelectionDAG &DAG, const SDLoc &DL,
45584                             ArrayRef<SDValue> Ops) {
45585     return DAG.getNode(Opcode, DL, Ops[1].getValueType(), Ops);
45586   };
45587   return SplitOpsAndApply(DAG, Subtarget, DL, VT, {Cond, TVal, FVal}, makeBlend,
45588                           /*CheckBWI*/ false);
45589 }
45590 
45591 static SDValue combineSelectOfTwoConstants(SDNode *N, SelectionDAG &DAG,
45592                                            const SDLoc &DL) {
45593   SDValue Cond = N->getOperand(0);
45594   SDValue LHS = N->getOperand(1);
45595   SDValue RHS = N->getOperand(2);
45596 
45597   auto *TrueC = dyn_cast<ConstantSDNode>(LHS);
45598   auto *FalseC = dyn_cast<ConstantSDNode>(RHS);
45599   if (!TrueC || !FalseC)
45600     return SDValue();
45601 
45602   // Don't do this for crazy integer types.
45603   EVT VT = N->getValueType(0);
45604   if (!DAG.getTargetLoweringInfo().isTypeLegal(VT))
45605     return SDValue();
45606 
45607   // We're going to use the condition bit in math or logic ops. We could allow
45608   // this with a wider condition value (post-legalization it becomes an i8),
45609   // but if nothing is creating selects that late, it doesn't matter.
45610   if (Cond.getValueType() != MVT::i1)
45611     return SDValue();
45612 
45613   // A power-of-2 multiply is just a shift. LEA also cheaply handles multiply by
45614   // 3, 5, or 9 with i32/i64, so those get transformed too.
45615   // TODO: For constants that overflow or do not differ by power-of-2 or small
45616   // multiplier, convert to 'and' + 'add'.
45617   const APInt &TrueVal = TrueC->getAPIntValue();
45618   const APInt &FalseVal = FalseC->getAPIntValue();
45619 
45620   // We have a more efficient lowering for "(X == 0) ? Y : -1" using SBB.
45621   if ((TrueVal.isAllOnes() || FalseVal.isAllOnes()) &&
45622       Cond.getOpcode() == ISD::SETCC && isNullConstant(Cond.getOperand(1))) {
45623     ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
45624     if (CC == ISD::SETEQ || CC == ISD::SETNE)
45625       return SDValue();
45626   }
45627 
45628   bool OV;
45629   APInt Diff = TrueVal.ssub_ov(FalseVal, OV);
45630   if (OV)
45631     return SDValue();
45632 
45633   APInt AbsDiff = Diff.abs();
45634   if (AbsDiff.isPowerOf2() ||
45635       ((VT == MVT::i32 || VT == MVT::i64) &&
45636        (AbsDiff == 3 || AbsDiff == 5 || AbsDiff == 9))) {
45637 
45638     // We need a positive multiplier constant for shift/LEA codegen. The 'not'
45639     // of the condition can usually be folded into a compare predicate, but even
45640     // without that, the sequence should be cheaper than a CMOV alternative.
45641     if (TrueVal.slt(FalseVal)) {
45642       Cond = DAG.getNOT(DL, Cond, MVT::i1);
45643       std::swap(TrueC, FalseC);
45644     }
45645 
45646     // select Cond, TC, FC --> (zext(Cond) * (TC - FC)) + FC
45647     SDValue R = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Cond);
45648 
45649     // Multiply condition by the difference if non-one.
45650     if (!AbsDiff.isOne())
45651       R = DAG.getNode(ISD::MUL, DL, VT, R, DAG.getConstant(AbsDiff, DL, VT));
45652 
45653     // Add the base if non-zero.
45654     if (!FalseC->isZero())
45655       R = DAG.getNode(ISD::ADD, DL, VT, R, SDValue(FalseC, 0));
45656 
45657     return R;
45658   }
45659 
45660   return SDValue();
45661 }
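// A minimal sketch of the fold above with hypothetical constants: for
// select Cond, 9, 5 the difference is 4 (a power of 2), so the result is
// zext(Cond) * 4 + 5; for select Cond, 2, 10 the condition is inverted first
// and the result is zext(!Cond) * 8 + 2.
static_assert(1 * 4 + 5 == 9 && 0 * 4 + 5 == 5, "select Cond, 9, 5 sketch");
static_assert(1 * 8 + 2 == 10 && 0 * 8 + 2 == 2, "select Cond, 2, 10 sketch");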
45662 
45663 /// If this is a *dynamic* select (non-constant condition) and we can match
45664 /// this node with one of the variable blend instructions, restructure the
45665 /// condition so that blends can use the high (sign) bit of each element.
45666 /// This function will also call SimplifyDemandedBits on already created
45667 /// BLENDV to perform additional simplifications.
45668 static SDValue combineVSelectToBLENDV(SDNode *N, SelectionDAG &DAG,
45669                                       const SDLoc &DL,
45670                                       TargetLowering::DAGCombinerInfo &DCI,
45671                                       const X86Subtarget &Subtarget) {
45672   SDValue Cond = N->getOperand(0);
45673   if ((N->getOpcode() != ISD::VSELECT &&
45674        N->getOpcode() != X86ISD::BLENDV) ||
45675       ISD::isBuildVectorOfConstantSDNodes(Cond.getNode()))
45676     return SDValue();
45677 
45678   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
45679   unsigned BitWidth = Cond.getScalarValueSizeInBits();
45680   EVT VT = N->getValueType(0);
45681 
45682   // We can only handle the cases where VSELECT is directly legal on the
45683   // subtarget. We custom lower VSELECT nodes with constant conditions and
45684   // this makes it hard to see whether a dynamic VSELECT will correctly
45685   // lower, so we both check the operation's status and explicitly handle the
45686   // cases where a *dynamic* blend will fail even though a constant-condition
45687   // blend could be custom lowered.
45688   // FIXME: We should find a better way to handle this class of problems.
45689   // Potentially, we should combine constant-condition vselect nodes
45690   // pre-legalization into shuffles and not mark as many types as custom
45691   // lowered.
45692   if (!TLI.isOperationLegalOrCustom(ISD::VSELECT, VT))
45693     return SDValue();
45694   // FIXME: We don't support i16-element blends currently. We could and
45695   // should support them by making *all* the bits in the condition be set
45696   // rather than just the high bit and using an i8-element blend.
45697   if (VT.getVectorElementType() == MVT::i16)
45698     return SDValue();
45699   // Dynamic blending was only available from SSE4.1 onward.
45700   if (VT.is128BitVector() && !Subtarget.hasSSE41())
45701     return SDValue();
45702   // Byte blends are only available in AVX2
45703   if (VT == MVT::v32i8 && !Subtarget.hasAVX2())
45704     return SDValue();
45705   // There are no 512-bit blend instructions that use sign bits.
45706   if (VT.is512BitVector())
45707     return SDValue();
45708 
45709   // Don't optimize before the condition has been transformed to a legal type
45710   // and don't ever optimize vector selects that map to AVX512 mask-registers.
45711   if (BitWidth < 8 || BitWidth > 64)
45712     return SDValue();
45713 
45714   auto OnlyUsedAsSelectCond = [](SDValue Cond) {
45715     for (SDNode::use_iterator UI = Cond->use_begin(), UE = Cond->use_end();
45716          UI != UE; ++UI)
45717       if ((UI->getOpcode() != ISD::VSELECT &&
45718            UI->getOpcode() != X86ISD::BLENDV) ||
45719           UI.getOperandNo() != 0)
45720         return false;
45721 
45722     return true;
45723   };
45724 
45725   APInt DemandedBits(APInt::getSignMask(BitWidth));
45726 
45727   if (OnlyUsedAsSelectCond(Cond)) {
45728     KnownBits Known;
45729     TargetLowering::TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(),
45730                                           !DCI.isBeforeLegalizeOps());
45731     if (!TLI.SimplifyDemandedBits(Cond, DemandedBits, Known, TLO, 0, true))
45732       return SDValue();
45733 
45734     // If we changed the computation somewhere in the DAG, this change will
45735     // affect all users of Cond. Update all the nodes so that we do not use
45736     // the generic VSELECT anymore. Otherwise, we may perform wrong
45737     // optimizations as we messed with the actual expectation for the vector
45738     // boolean values.
45739     for (SDNode *U : Cond->uses()) {
45740       if (U->getOpcode() == X86ISD::BLENDV)
45741         continue;
45742 
45743       SDValue SB = DAG.getNode(X86ISD::BLENDV, SDLoc(U), U->getValueType(0),
45744                                Cond, U->getOperand(1), U->getOperand(2));
45745       DAG.ReplaceAllUsesOfValueWith(SDValue(U, 0), SB);
45746       DCI.AddToWorklist(U);
45747     }
45748     DCI.CommitTargetLoweringOpt(TLO);
45749     return SDValue(N, 0);
45750   }
45751 
45752   // Otherwise we can still at least try to simplify multiple use bits.
45753   if (SDValue V = TLI.SimplifyMultipleUseDemandedBits(Cond, DemandedBits, DAG))
45754     return DAG.getNode(X86ISD::BLENDV, DL, N->getValueType(0), V,
45755                        N->getOperand(1), N->getOperand(2));
45756 
45757   return SDValue();
45758 }
45759 
45760 // Try to match:
45761 //   (or (and (M, (sub 0, X)), (pandn M, X)))
45762 // which is a special case of:
45763 //   (select M, (sub 0, X), X)
45764 // Per:
45765 // http://graphics.stanford.edu/~seander/bithacks.html#ConditionalNegate
45766 // We know that, if fNegate is 0 or 1:
45767 //   (fNegate ? -v : v) == ((v ^ -fNegate) + fNegate)
45768 //
45769 // Here, we have a mask, M (all 1s or 0), and, similarly, we know that:
45770 //   ((M & 1) ? -X : X) == ((X ^ -(M & 1)) + (M & 1))
45771 //   ( M      ? -X : X) == ((X ^   M     ) + (M & 1))
45772 // This lets us transform our vselect to:
45773 //   (add (xor X, M), (and M, 1))
45774 // And further to:
45775 //   (sub (xor X, M), M)
45776 static SDValue combineLogicBlendIntoConditionalNegate(
45777     EVT VT, SDValue Mask, SDValue X, SDValue Y, const SDLoc &DL,
45778     SelectionDAG &DAG, const X86Subtarget &Subtarget) {
45779   EVT MaskVT = Mask.getValueType();
45780   assert(MaskVT.isInteger() &&
45781          DAG.ComputeNumSignBits(Mask) == MaskVT.getScalarSizeInBits() &&
45782          "Mask must be zero/all-bits");
45783 
45784   if (X.getValueType() != MaskVT || Y.getValueType() != MaskVT)
45785     return SDValue();
45786   if (!DAG.getTargetLoweringInfo().isOperationLegal(ISD::SUB, MaskVT))
45787     return SDValue();
45788 
45789   auto IsNegV = [](SDNode *N, SDValue V) {
45790     return N->getOpcode() == ISD::SUB && N->getOperand(1) == V &&
45791            ISD::isBuildVectorAllZeros(N->getOperand(0).getNode());
45792   };
45793 
45794   SDValue V;
45795   if (IsNegV(Y.getNode(), X))
45796     V = X;
45797   else if (IsNegV(X.getNode(), Y))
45798     V = Y;
45799   else
45800     return SDValue();
45801 
45802   SDValue SubOp1 = DAG.getNode(ISD::XOR, DL, MaskVT, V, Mask);
45803   SDValue SubOp2 = Mask;
45804 
45805   // If the negate was on the false side of the select, then
45806   // the operands of the SUB need to be swapped. PR 27251.
45807   // This is because the pattern being matched above is
45808   // (vselect M, (sub 0, X), X) -> (sub (xor X, M), M)
45809   // but if the pattern matched was
45810   // (vselect M, X, (sub 0, X)), that is really the negation of the pattern
45811   // above, -(vselect M, (sub 0, X), X), and therefore the replacement
45812   // pattern also needs to be a negation of the replacement pattern above.
45813   // And -(sub X, Y) is just sub (Y, X), so swapping the operands of the
45814   // sub accomplishes the negation of the replacement pattern.
45815   if (V == Y)
45816     std::swap(SubOp1, SubOp2);
45817 
45818   SDValue Res = DAG.getNode(ISD::SUB, DL, MaskVT, SubOp1, SubOp2);
45819   return DAG.getBitcast(VT, Res);
45820 }
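// A compile-time sketch of the identity above for a scalar X and a sign-splat
// mask M in {0, -1} (hypothetical values):
//   (M ? -X : X) == (X ^ M) - M        and, for the swapped-operand case,
//   (M ? X : -X) == M - (X ^ M)
static_assert((7 ^ -1) - -1 == -7 && (7 ^ 0) - 0 == 7, "negate-on-true sketch");
static_assert(-1 - (7 ^ -1) == 7 && 0 - (7 ^ 0) == -7, "negate-on-false sketch");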
45821 
45822 static SDValue commuteSelect(SDNode *N, SelectionDAG &DAG, const SDLoc &DL,
45823                              const X86Subtarget &Subtarget) {
45824   if (!Subtarget.hasAVX512())
45825     return SDValue();
45826   if (N->getOpcode() != ISD::VSELECT)
45827     return SDValue();
45828 
45829   SDValue Cond = N->getOperand(0);
45830   SDValue LHS = N->getOperand(1);
45831   SDValue RHS = N->getOperand(2);
45832 
45833   if (canCombineAsMaskOperation(LHS, Subtarget))
45834     return SDValue();
45835 
45836   if (!canCombineAsMaskOperation(RHS, Subtarget))
45837     return SDValue();
45838 
45839   if (Cond.getOpcode() != ISD::SETCC || !Cond.hasOneUse())
45840     return SDValue();
45841 
45842   // Commute LHS and RHS to create opportunity to select mask instruction.
45843   // (vselect M, L, R) -> (vselect ~M, R, L)
45844   ISD::CondCode NewCC =
45845       ISD::getSetCCInverse(cast<CondCodeSDNode>(Cond.getOperand(2))->get(),
45846                            Cond.getOperand(0).getValueType());
45847   Cond = DAG.getSetCC(SDLoc(Cond), Cond.getValueType(), Cond.getOperand(0),
45848                       Cond.getOperand(1), NewCC);
45849   return DAG.getSelect(DL, LHS.getValueType(), Cond, RHS, LHS);
45850 }
45851 
45852 /// Do target-specific dag combines on SELECT and VSELECT nodes.
45853 static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
45854                              TargetLowering::DAGCombinerInfo &DCI,
45855                              const X86Subtarget &Subtarget) {
45856   SDLoc DL(N);
45857   SDValue Cond = N->getOperand(0);
45858   SDValue LHS = N->getOperand(1);
45859   SDValue RHS = N->getOperand(2);
45860 
45861   // Try simplification again because we use this function to optimize
45862   // BLENDV nodes that are not handled by the generic combiner.
45863   if (SDValue V = DAG.simplifySelect(Cond, LHS, RHS))
45864     return V;
45865 
45866   // When avx512 is available the lhs operand of select instruction can be
45867   // When AVX512 is available, the LHS operand of a select instruction can be
45868   // folded with a mask instruction, while the RHS operand can't. Commute the
45869   // LHS and RHS of the select instruction to create the opportunity for
45870   // folding.
45871     return V;
45872 
45873   EVT VT = LHS.getValueType();
45874   EVT CondVT = Cond.getValueType();
45875   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
45876   bool CondConstantVector = ISD::isBuildVectorOfConstantSDNodes(Cond.getNode());
45877 
45878   // Attempt to combine (select M, (sub 0, X), X) -> (sub (xor X, M), M).
45879   // Limit this to cases of non-constant masks that createShuffleMaskFromVSELECT
45880   // can't catch, plus vXi8 cases where we'd likely end up with BLENDV.
45881   if (CondVT.isVector() && CondVT.isInteger() &&
45882       CondVT.getScalarSizeInBits() == VT.getScalarSizeInBits() &&
45883       (!CondConstantVector || CondVT.getScalarType() == MVT::i8) &&
45884       DAG.ComputeNumSignBits(Cond) == CondVT.getScalarSizeInBits())
45885     if (SDValue V = combineLogicBlendIntoConditionalNegate(VT, Cond, RHS, LHS,
45886                                                            DL, DAG, Subtarget))
45887       return V;
45888 
45889   // Convert vselects with constant condition into shuffles.
45890   if (CondConstantVector && DCI.isBeforeLegalizeOps() &&
45891       (N->getOpcode() == ISD::VSELECT || N->getOpcode() == X86ISD::BLENDV)) {
45892     SmallVector<int, 64> Mask;
45893     if (createShuffleMaskFromVSELECT(Mask, Cond,
45894                                      N->getOpcode() == X86ISD::BLENDV))
45895       return DAG.getVectorShuffle(VT, DL, LHS, RHS, Mask);
45896   }
45897 
45898   // fold vselect(cond, pshufb(x), pshufb(y)) -> or (pshufb(x), pshufb(y))
45899   // by forcing the unselected elements to zero.
45900   // TODO: Can we handle more shuffles with this?
45901   if (N->getOpcode() == ISD::VSELECT && CondVT.isVector() &&
45902       LHS.getOpcode() == X86ISD::PSHUFB && RHS.getOpcode() == X86ISD::PSHUFB &&
45903       LHS.hasOneUse() && RHS.hasOneUse()) {
45904     MVT SimpleVT = VT.getSimpleVT();
45905     SmallVector<SDValue, 1> LHSOps, RHSOps;
45906     SmallVector<int, 64> LHSMask, RHSMask, CondMask;
45907     if (createShuffleMaskFromVSELECT(CondMask, Cond) &&
45908         getTargetShuffleMask(LHS, true, LHSOps, LHSMask) &&
45909         getTargetShuffleMask(RHS, true, RHSOps, RHSMask)) {
45910       int NumElts = VT.getVectorNumElements();
45911       for (int i = 0; i != NumElts; ++i) {
45912         // getConstVector sets negative shuffle mask values as undef, so ensure
45913         // we hardcode SM_SentinelZero values to zero (0x80).
45914         if (CondMask[i] < NumElts) {
45915           LHSMask[i] = isUndefOrZero(LHSMask[i]) ? 0x80 : LHSMask[i];
45916           RHSMask[i] = 0x80;
45917         } else {
45918           LHSMask[i] = 0x80;
45919           RHSMask[i] = isUndefOrZero(RHSMask[i]) ? 0x80 : RHSMask[i];
45920         }
45921       }
45922       LHS = DAG.getNode(X86ISD::PSHUFB, DL, VT, LHS.getOperand(0),
45923                         getConstVector(LHSMask, SimpleVT, DAG, DL, true));
45924       RHS = DAG.getNode(X86ISD::PSHUFB, DL, VT, RHS.getOperand(0),
45925                         getConstVector(RHSMask, SimpleVT, DAG, DL, true));
45926       return DAG.getNode(ISD::OR, DL, VT, LHS, RHS);
45927     }
45928   }
45929 
45930   // If we have SSE[12] support, try to form min/max nodes. SSE min/max
45931   // instructions match the semantics of the common C idiom x<y?x:y but not
45932   // x<=y?x:y, because of how they handle negative zero (which can be
45933   // ignored in unsafe-math mode).
45934   // We also try to create v2f32 min/max nodes, which we later widen to v4f32.
45935   if (Cond.getOpcode() == ISD::SETCC && VT.isFloatingPoint() &&
45936       VT != MVT::f80 && VT != MVT::f128 && !isSoftF16(VT, Subtarget) &&
45937       (TLI.isTypeLegal(VT) || VT == MVT::v2f32) &&
45938       (Subtarget.hasSSE2() ||
45939        (Subtarget.hasSSE1() && VT.getScalarType() == MVT::f32))) {
45940     ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
45941 
45942     unsigned Opcode = 0;
45943     // Check for x CC y ? x : y.
45944     if (DAG.isEqualTo(LHS, Cond.getOperand(0)) &&
45945         DAG.isEqualTo(RHS, Cond.getOperand(1))) {
45946       switch (CC) {
45947       default: break;
45948       case ISD::SETULT:
45949         // Converting this to a min would handle NaNs incorrectly, and swapping
45950         // the operands would cause it to handle comparisons between positive
45951         // and negative zero incorrectly.
45952         if (!DAG.isKnownNeverNaN(LHS) || !DAG.isKnownNeverNaN(RHS)) {
45953           if (!DAG.getTarget().Options.NoSignedZerosFPMath &&
45954               !(DAG.isKnownNeverZeroFloat(LHS) ||
45955                 DAG.isKnownNeverZeroFloat(RHS)))
45956             break;
45957           std::swap(LHS, RHS);
45958         }
45959         Opcode = X86ISD::FMIN;
45960         break;
45961       case ISD::SETOLE:
45962         // Converting this to a min would handle comparisons between positive
45963         // and negative zero incorrectly.
45964         if (!DAG.getTarget().Options.NoSignedZerosFPMath &&
45965             !DAG.isKnownNeverZeroFloat(LHS) && !DAG.isKnownNeverZeroFloat(RHS))
45966           break;
45967         Opcode = X86ISD::FMIN;
45968         break;
45969       case ISD::SETULE:
45970         // Converting this to a min would handle both negative zeros and NaNs
45971         // incorrectly, but we can swap the operands to fix both.
45972         std::swap(LHS, RHS);
45973         [[fallthrough]];
45974       case ISD::SETOLT:
45975       case ISD::SETLT:
45976       case ISD::SETLE:
45977         Opcode = X86ISD::FMIN;
45978         break;
45979 
45980       case ISD::SETOGE:
45981         // Converting this to a max would handle comparisons between positive
45982         // and negative zero incorrectly.
45983         if (!DAG.getTarget().Options.NoSignedZerosFPMath &&
45984             !DAG.isKnownNeverZeroFloat(LHS) && !DAG.isKnownNeverZeroFloat(RHS))
45985           break;
45986         Opcode = X86ISD::FMAX;
45987         break;
45988       case ISD::SETUGT:
45989         // Converting this to a max would handle NaNs incorrectly, and swapping
45990         // the operands would cause it to handle comparisons between positive
45991         // and negative zero incorrectly.
45992         if (!DAG.isKnownNeverNaN(LHS) || !DAG.isKnownNeverNaN(RHS)) {
45993           if (!DAG.getTarget().Options.NoSignedZerosFPMath &&
45994               !(DAG.isKnownNeverZeroFloat(LHS) ||
45995                 DAG.isKnownNeverZeroFloat(RHS)))
45996             break;
45997           std::swap(LHS, RHS);
45998         }
45999         Opcode = X86ISD::FMAX;
46000         break;
46001       case ISD::SETUGE:
46002         // Converting this to a max would handle both negative zeros and NaNs
46003         // incorrectly, but we can swap the operands to fix both.
46004         std::swap(LHS, RHS);
46005         [[fallthrough]];
46006       case ISD::SETOGT:
46007       case ISD::SETGT:
46008       case ISD::SETGE:
46009         Opcode = X86ISD::FMAX;
46010         break;
46011       }
46012     // Check for x CC y ? y : x -- a min/max with reversed arms.
46013     } else if (DAG.isEqualTo(LHS, Cond.getOperand(1)) &&
46014                DAG.isEqualTo(RHS, Cond.getOperand(0))) {
46015       switch (CC) {
46016       default: break;
46017       case ISD::SETOGE:
46018         // Converting this to a min would handle comparisons between positive
46019         // and negative zero incorrectly, and swapping the operands would
46020         // cause it to handle NaNs incorrectly.
46021         if (!DAG.getTarget().Options.NoSignedZerosFPMath &&
46022             !(DAG.isKnownNeverZeroFloat(LHS) ||
46023               DAG.isKnownNeverZeroFloat(RHS))) {
46024           if (!DAG.isKnownNeverNaN(LHS) || !DAG.isKnownNeverNaN(RHS))
46025             break;
46026           std::swap(LHS, RHS);
46027         }
46028         Opcode = X86ISD::FMIN;
46029         break;
46030       case ISD::SETUGT:
46031         // Converting this to a min would handle NaNs incorrectly.
46032         if (!DAG.isKnownNeverNaN(LHS) || !DAG.isKnownNeverNaN(RHS))
46033           break;
46034         Opcode = X86ISD::FMIN;
46035         break;
46036       case ISD::SETUGE:
46037         // Converting this to a min would handle both negative zeros and NaNs
46038         // incorrectly, but we can swap the operands to fix both.
46039         std::swap(LHS, RHS);
46040         [[fallthrough]];
46041       case ISD::SETOGT:
46042       case ISD::SETGT:
46043       case ISD::SETGE:
46044         Opcode = X86ISD::FMIN;
46045         break;
46046 
46047       case ISD::SETULT:
46048         // Converting this to a max would handle NaNs incorrectly.
46049         if (!DAG.isKnownNeverNaN(LHS) || !DAG.isKnownNeverNaN(RHS))
46050           break;
46051         Opcode = X86ISD::FMAX;
46052         break;
46053       case ISD::SETOLE:
46054         // Converting this to a max would handle comparisons between positive
46055         // and negative zero incorrectly, and swapping the operands would
46056         // cause it to handle NaNs incorrectly.
46057         if (!DAG.getTarget().Options.NoSignedZerosFPMath &&
46058             !DAG.isKnownNeverZeroFloat(LHS) &&
46059             !DAG.isKnownNeverZeroFloat(RHS)) {
46060           if (!DAG.isKnownNeverNaN(LHS) || !DAG.isKnownNeverNaN(RHS))
46061             break;
46062           std::swap(LHS, RHS);
46063         }
46064         Opcode = X86ISD::FMAX;
46065         break;
46066       case ISD::SETULE:
46067         // Converting this to a max would handle both negative zeros and NaNs
46068         // incorrectly, but we can swap the operands to fix both.
46069         std::swap(LHS, RHS);
46070         [[fallthrough]];
46071       case ISD::SETOLT:
46072       case ISD::SETLT:
46073       case ISD::SETLE:
46074         Opcode = X86ISD::FMAX;
46075         break;
46076       }
46077     }
46078 
46079     if (Opcode)
46080       return DAG.getNode(Opcode, DL, N->getValueType(0), LHS, RHS);
46081   }
46082 
46083   // Some mask scalar intrinsics rely on checking if only one bit is set
46084   // and implement it in C code like this:
46085   // A[0] = (U & 1) ? A[0] : W[0];
46086   // This creates some redundant instructions that break pattern matching.
46087   // fold (select (setcc (and (X, 1), 0, seteq), Y, Z)) -> select(and(X, 1),Z,Y)
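  // For instance (hypothetical X): with X = 2, (X & 1) == 0 makes the setcc
  // true and the original select returns Y; after the fold the condition
  // and(X, 1) is 0 (false) and the swapped operands again return Y.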
46088   if (Subtarget.hasAVX512() && N->getOpcode() == ISD::SELECT &&
46089       Cond.getOpcode() == ISD::SETCC && (VT == MVT::f32 || VT == MVT::f64)) {
46090     ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
46091     SDValue AndNode = Cond.getOperand(0);
46092     if (AndNode.getOpcode() == ISD::AND && CC == ISD::SETEQ &&
46093         isNullConstant(Cond.getOperand(1)) &&
46094         isOneConstant(AndNode.getOperand(1))) {
46095       // LHS and RHS swapped due to
46096       // setcc outputting 1 when AND resulted in 0 and vice versa.
46097       AndNode = DAG.getZExtOrTrunc(AndNode, DL, MVT::i8);
46098       return DAG.getNode(ISD::SELECT, DL, VT, AndNode, RHS, LHS);
46099     }
46100   }
46101 
46102   // v16i8 (select v16i1, v16i8, v16i8) does not have a proper
46103   // lowering on KNL. In this case we convert it to
46104   // v16i8 (select v16i8, v16i8, v16i8) and use AVX instruction.
46105   // v16i8 (select v16i8, v16i8, v16i8) and use an AVX instruction.
46106   // The same applies to all vectors of i8 and i16 without BWI.
46107   // split wide vectors.
46108   // Since SKX these selects have a proper lowering.
46109   if (Subtarget.hasAVX512() && !Subtarget.hasBWI() && CondVT.isVector() &&
46110       CondVT.getVectorElementType() == MVT::i1 &&
46111       (VT.getVectorElementType() == MVT::i8 ||
46112        VT.getVectorElementType() == MVT::i16)) {
46113     Cond = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Cond);
46114     return DAG.getNode(N->getOpcode(), DL, VT, Cond, LHS, RHS);
46115   }
46116 
46117   // AVX512 - Extend select with zero to merge with target shuffle.
46118   // select(mask, extract_subvector(shuffle(x)), zero) -->
46119   // extract_subvector(select(insert_subvector(mask), shuffle(x), zero))
46120   // TODO - support non target shuffles as well.
46121   if (Subtarget.hasAVX512() && CondVT.isVector() &&
46122       CondVT.getVectorElementType() == MVT::i1) {
46123     auto SelectableOp = [&TLI](SDValue Op) {
46124       return Op.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
46125              isTargetShuffle(Op.getOperand(0).getOpcode()) &&
46126              isNullConstant(Op.getOperand(1)) &&
46127              TLI.isTypeLegal(Op.getOperand(0).getValueType()) &&
46128              Op.hasOneUse() && Op.getOperand(0).hasOneUse();
46129     };
46130 
46131     bool SelectableLHS = SelectableOp(LHS);
46132     bool SelectableRHS = SelectableOp(RHS);
46133     bool ZeroLHS = ISD::isBuildVectorAllZeros(LHS.getNode());
46134     bool ZeroRHS = ISD::isBuildVectorAllZeros(RHS.getNode());
46135 
46136     if ((SelectableLHS && ZeroRHS) || (SelectableRHS && ZeroLHS)) {
46137       EVT SrcVT = SelectableLHS ? LHS.getOperand(0).getValueType()
46138                                 : RHS.getOperand(0).getValueType();
46139       EVT SrcCondVT = SrcVT.changeVectorElementType(MVT::i1);
46140       LHS = insertSubVector(DAG.getUNDEF(SrcVT), LHS, 0, DAG, DL,
46141                             VT.getSizeInBits());
46142       RHS = insertSubVector(DAG.getUNDEF(SrcVT), RHS, 0, DAG, DL,
46143                             VT.getSizeInBits());
46144       Cond = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, SrcCondVT,
46145                          DAG.getUNDEF(SrcCondVT), Cond,
46146                          DAG.getIntPtrConstant(0, DL));
46147       SDValue Res = DAG.getSelect(DL, SrcVT, Cond, LHS, RHS);
46148       return extractSubVector(Res, 0, DAG, DL, VT.getSizeInBits());
46149     }
46150   }
46151 
46152   if (SDValue V = combineSelectOfTwoConstants(N, DAG, DL))
46153     return V;
46154 
46155   if (N->getOpcode() == ISD::SELECT && Cond.getOpcode() == ISD::SETCC &&
46156       Cond.hasOneUse()) {
46157     EVT CondVT = Cond.getValueType();
46158     SDValue Cond0 = Cond.getOperand(0);
46159     SDValue Cond1 = Cond.getOperand(1);
46160     ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
46161 
46162     // Canonicalize min/max:
46163     // (x > 0) ? x : 0 -> (x >= 0) ? x : 0
46164     // (x < -1) ? x : -1 -> (x <= -1) ? x : -1
46165     // This allows use of COND_S / COND_NS (see TranslateX86CC) which eliminates
46166     // the need for an extra compare against zero. e.g.
46167     // (a - b) > 0 ? (a - b) : 0 -> (a - b) >= 0 ? (a - b) : 0
46168     // subl   %esi, %edi
46169     // testl  %edi, %edi
46170     // movl   $0, %eax
46171     // cmovgl %edi, %eax
46172     // =>
46173     // xorl   %eax, %eax
46174     // subl   %esi, $edi
46175     // subl   %esi, %edi
46176     //
46177     // We can also canonicalize
46178     //  (x s> 1) ? x : 1 -> (x s>= 1) ? x : 1 -> (x s> 0) ? x : 1
46179     //  (x u> 1) ? x : 1 -> (x u>= 1) ? x : 1 -> (x != 0) ? x : 1
46180     // This allows the use of a test instruction for the compare.
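    // For example (hypothetical x): in "(x s> 0) ? x : 0" the two predicates
    // only differ at x == 0, where both arms produce 0, so rewriting to
    // "(x s>= 0) ? x : 0" is safe and enables the COND_S / COND_NS form above.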
46181     if (LHS == Cond0 && RHS == Cond1) {
46182       if ((CC == ISD::SETGT && (isNullConstant(RHS) || isOneConstant(RHS))) ||
46183           (CC == ISD::SETLT && isAllOnesConstant(RHS))) {
46184         ISD::CondCode NewCC = CC == ISD::SETGT ? ISD::SETGE : ISD::SETLE;
46185         Cond = DAG.getSetCC(SDLoc(Cond), CondVT, Cond0, Cond1, NewCC);
46186         return DAG.getSelect(DL, VT, Cond, LHS, RHS);
46187       }
46188       if (CC == ISD::SETUGT && isOneConstant(RHS)) {
46189         ISD::CondCode NewCC = ISD::SETUGE;
46190         Cond = DAG.getSetCC(SDLoc(Cond), CondVT, Cond0, Cond1, NewCC);
46191         return DAG.getSelect(DL, VT, Cond, LHS, RHS);
46192       }
46193     }
46194 
46195     // Similar to DAGCombine's select(or(CC0,CC1),X,Y) fold but for legal types.
46196     // fold eq + gt/lt nested selects into ge/le selects
46197     // select (cmpeq Cond0, Cond1), LHS, (select (cmpugt Cond0, Cond1), LHS, Y)
46198     // --> (select (cmpuge Cond0, Cond1), LHS, Y)
46199     // select (cmpslt Cond0, Cond1), LHS, (select (cmpeq Cond0, Cond1), LHS, Y)
46200     // --> (select (cmpsle Cond0, Cond1), LHS, Y)
46201     // .. etc ..
46202     if (RHS.getOpcode() == ISD::SELECT && RHS.getOperand(1) == LHS &&
46203         RHS.getOperand(0).getOpcode() == ISD::SETCC) {
46204       SDValue InnerSetCC = RHS.getOperand(0);
46205       ISD::CondCode InnerCC =
46206           cast<CondCodeSDNode>(InnerSetCC.getOperand(2))->get();
46207       if ((CC == ISD::SETEQ || InnerCC == ISD::SETEQ) &&
46208           Cond0 == InnerSetCC.getOperand(0) &&
46209           Cond1 == InnerSetCC.getOperand(1)) {
46210         ISD::CondCode NewCC;
46211         switch (CC == ISD::SETEQ ? InnerCC : CC) {
46212         // clang-format off
46213         case ISD::SETGT:  NewCC = ISD::SETGE; break;
46214         case ISD::SETLT:  NewCC = ISD::SETLE; break;
46215         case ISD::SETUGT: NewCC = ISD::SETUGE; break;
46216         case ISD::SETULT: NewCC = ISD::SETULE; break;
46217         default: NewCC = ISD::SETCC_INVALID; break;
46218         // clang-format on
46219         }
46220         if (NewCC != ISD::SETCC_INVALID) {
46221           Cond = DAG.getSetCC(DL, CondVT, Cond0, Cond1, NewCC);
46222           return DAG.getSelect(DL, VT, Cond, LHS, RHS.getOperand(2));
46223         }
46224       }
46225     }
46226   }
46227 
46228   // Check if the first operand is all zeros and Cond type is vXi1.
46229   // If this an avx512 target we can improve the use of zero masking by
46230   // If this is an AVX512 target, we can improve the use of zero masking by
46231   if (N->getOpcode() == ISD::VSELECT && Cond.hasOneUse() &&
46232       Subtarget.hasAVX512() && CondVT.getVectorElementType() == MVT::i1 &&
46233       ISD::isBuildVectorAllZeros(LHS.getNode()) &&
46234       !ISD::isBuildVectorAllZeros(RHS.getNode())) {
46235     // Invert the cond to not(cond) : xor(op,allones)=not(op)
46236     SDValue CondNew = DAG.getNOT(DL, Cond, CondVT);
46237     // Vselect cond, op1, op2 = Vselect not(cond), op2, op1
46238     return DAG.getSelect(DL, VT, CondNew, RHS, LHS);
46239   }
46240 
46241   // Attempt to convert a (vXi1 bitcast(iX Cond)) selection mask before it might
46242   // get split by legalization.
46243   if (N->getOpcode() == ISD::VSELECT && Cond.getOpcode() == ISD::BITCAST &&
46244       CondVT.getVectorElementType() == MVT::i1 &&
46245       TLI.isTypeLegal(VT.getScalarType())) {
46246     EVT ExtCondVT = VT.changeVectorElementTypeToInteger();
46247     if (SDValue ExtCond = combineToExtendBoolVectorInReg(
46248             ISD::SIGN_EXTEND, DL, ExtCondVT, Cond, DAG, DCI, Subtarget)) {
46249       ExtCond = DAG.getNode(ISD::TRUNCATE, DL, CondVT, ExtCond);
46250       return DAG.getSelect(DL, VT, ExtCond, LHS, RHS);
46251     }
46252   }
46253 
46254   // Exploits AVX2 VSHLV/VSRLV instructions for efficient unsigned vector shifts
46255   // with out-of-bounds clamping.
46256 
46257   // Unlike general shift instructions (SHL/SRL), AVX2's VSHLV/VSRLV handle
46258   // shift amounts that exceed the element bitwidth. For these unsigned
46259   // shifts, any lane whose amount is bitwidth or more is simply written
46260   // with zero, which is exactly the value the out-of-bounds arm of the
46261   // selects below provides.
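  // A small sketch with hypothetical v4i32 operands: for x = <1,1,1,1> and
  // amt = <3,31,32,40>, VPSLLVD yields <8, 1u<<31, 0, 0>; lanes whose amount
  // is >= 32 are already zero, matching the select's zero operand, so the
  // compare and select can be dropped.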
46262   if (N->getOpcode() == ISD::VSELECT) {
46263     using namespace llvm::SDPatternMatch;
46264     // fold select(icmp_ult(amt,BW),shl(x,amt),0) -> avx2 psllv(x,amt)
46265     // fold select(icmp_ult(amt,BW),srl(x,amt),0) -> avx2 psrlv(x,amt)
46266     if ((LHS.getOpcode() == ISD::SRL || LHS.getOpcode() == ISD::SHL) &&
46267         supportedVectorVarShift(VT, Subtarget, LHS.getOpcode()) &&
46268         ISD::isConstantSplatVectorAllZeros(RHS.getNode()) &&
46269         sd_match(Cond, m_SetCC(m_Specific(LHS.getOperand(1)),
46270                                m_SpecificInt(VT.getScalarSizeInBits()),
46271                                m_SpecificCondCode(ISD::SETULT)))) {
46272       return DAG.getNode(LHS.getOpcode() == ISD::SRL ? X86ISD::VSRLV
46273                                                      : X86ISD::VSHLV,
46274                          DL, VT, LHS.getOperand(0), LHS.getOperand(1));
46275     }
46276     // fold select(icmp_uge(amt,BW),0,shl(x,amt)) -> avx2 psllv(x,amt)
46277     // fold select(icmp_uge(amt,BW),0,srl(x,amt)) -> avx2 psrlv(x,amt)
46278     if ((RHS.getOpcode() == ISD::SRL || RHS.getOpcode() == ISD::SHL) &&
46279         supportedVectorVarShift(VT, Subtarget, RHS.getOpcode()) &&
46280         ISD::isConstantSplatVectorAllZeros(LHS.getNode()) &&
46281         sd_match(Cond, m_SetCC(m_Specific(RHS.getOperand(1)),
46282                                m_SpecificInt(VT.getScalarSizeInBits()),
46283                                m_SpecificCondCode(ISD::SETUGE)))) {
46284       return DAG.getNode(RHS.getOpcode() == ISD::SRL ? X86ISD::VSRLV
46285                                                      : X86ISD::VSHLV,
46286                          DL, VT, RHS.getOperand(0), RHS.getOperand(1));
46287     }
46288   }
46289 
46290   // Early exit check
46291   if (!TLI.isTypeLegal(VT) || isSoftF16(VT, Subtarget))
46292     return SDValue();
46293 
46294   if (SDValue V = combineVSelectWithAllOnesOrZeros(N, DAG, DL, DCI, Subtarget))
46295     return V;
46296 
46297   if (SDValue V = combineVSelectToBLENDV(N, DAG, DL, DCI, Subtarget))
46298     return V;
46299 
46300   if (SDValue V = narrowVectorSelect(N, DAG, DL, Subtarget))
46301     return V;
46302 
46303   // select(~Cond, X, Y) -> select(Cond, Y, X)
46304   if (CondVT.getScalarType() != MVT::i1) {
46305     if (SDValue CondNot = IsNOT(Cond, DAG))
46306       return DAG.getNode(N->getOpcode(), DL, VT,
46307                          DAG.getBitcast(CondVT, CondNot), RHS, LHS);
46308 
46309     // pcmpgt(X, -1) -> pcmpgt(0, X) to help select/blendv just use the
46310     // signbit.
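    // (X > -1) is the logical inverse of (0 > X), so the select operands are
    // swapped below; (0 > X) is just the sign bit of X, which BLENDV reads
    // directly.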
46311     if (Cond.getOpcode() == X86ISD::PCMPGT &&
46312         ISD::isBuildVectorAllOnes(Cond.getOperand(1).getNode()) &&
46313         Cond.hasOneUse()) {
46314       Cond = DAG.getNode(X86ISD::PCMPGT, DL, CondVT,
46315                          DAG.getConstant(0, DL, CondVT), Cond.getOperand(0));
46316       return DAG.getNode(N->getOpcode(), DL, VT, Cond, RHS, LHS);
46317     }
46318   }
46319 
46320   // Try to optimize vXi1 selects if both operands are either all constants or
46321   // bitcasts from scalar integer type. In that case we can convert the operands
46322   // to integer and use an integer select which will be converted to a CMOV.
46323   // We need to take a little bit of care to avoid creating an i64 type after
46324   // type legalization.
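  // E.g. a select between a constant <8 x i1> and (bitcast i8 %x to <8 x i1>)
  // can be rewritten as a bitcast of (select i1 %cond, i8 <imm>, i8 %x), which
  // lowers to a scalar CMOV.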
46325   if (N->getOpcode() == ISD::SELECT && VT.isVector() &&
46326       VT.getVectorElementType() == MVT::i1 &&
46327       (DCI.isBeforeLegalize() || (VT != MVT::v64i1 || Subtarget.is64Bit()))) {
46328     EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), VT.getVectorNumElements());
46329     if (DCI.isBeforeLegalize() || TLI.isTypeLegal(IntVT)) {
46330       bool LHSIsConst = ISD::isBuildVectorOfConstantSDNodes(LHS.getNode());
46331       bool RHSIsConst = ISD::isBuildVectorOfConstantSDNodes(RHS.getNode());
46332 
46333       if ((LHSIsConst || (LHS.getOpcode() == ISD::BITCAST &&
46334                           LHS.getOperand(0).getValueType() == IntVT)) &&
46335           (RHSIsConst || (RHS.getOpcode() == ISD::BITCAST &&
46336                           RHS.getOperand(0).getValueType() == IntVT))) {
46337         if (LHSIsConst)
46338           LHS = combinevXi1ConstantToInteger(LHS, DAG);
46339         else
46340           LHS = LHS.getOperand(0);
46341 
46342         if (RHSIsConst)
46343           RHS = combinevXi1ConstantToInteger(RHS, DAG);
46344         else
46345           RHS = RHS.getOperand(0);
46346 
46347         SDValue Select = DAG.getSelect(DL, IntVT, Cond, LHS, RHS);
46348         return DAG.getBitcast(VT, Select);
46349       }
46350     }
46351   }
46352 
46353   // If this is "((X & C) == 0) ? Y : Z" and C is a constant mask vector of
46354   // single bits, then invert the predicate and swap the select operands.
46355   // This can lower using a vector shift bit-hack rather than mask and compare.
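  // E.g. for a splat mask of 8, ((X & 8) == 0) ? Y : Z becomes
  // ((X & 8) != 0) ? Z : Y; the non-splat power-of-2 case below instead shifts
  // each mask bit up to the sign bit so a BLENDV can test it directly.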
46356   if (DCI.isBeforeLegalize() && !Subtarget.hasAVX512() &&
46357       N->getOpcode() == ISD::VSELECT && Cond.getOpcode() == ISD::SETCC &&
46358       Cond.hasOneUse() && CondVT.getVectorElementType() == MVT::i1 &&
46359       Cond.getOperand(0).getOpcode() == ISD::AND &&
46360       isNullOrNullSplat(Cond.getOperand(1)) &&
46361       cast<CondCodeSDNode>(Cond.getOperand(2))->get() == ISD::SETEQ &&
46362       Cond.getOperand(0).getValueType() == VT) {
46363     // The 'and' mask must be composed of power-of-2 constants.
46364     SDValue And = Cond.getOperand(0);
46365     auto *C = isConstOrConstSplat(And.getOperand(1));
46366     if (C && C->getAPIntValue().isPowerOf2()) {
46367       // vselect (X & C == 0), LHS, RHS --> vselect (X & C != 0), RHS, LHS
46368       SDValue NotCond =
46369           DAG.getSetCC(DL, CondVT, And, Cond.getOperand(1), ISD::SETNE);
46370       return DAG.getSelect(DL, VT, NotCond, RHS, LHS);
46371     }
46372 
46373     // If we have a non-splat but still powers-of-2 mask, AVX1 can use pmulld
46374     // and AVX2 can use vpsllv{dq}. 8-bit lacks a proper shift or multiply.
46375     // 16-bit lacks a proper blendv.
46376     unsigned EltBitWidth = VT.getScalarSizeInBits();
46377     bool CanShiftBlend =
46378         TLI.isTypeLegal(VT) && ((Subtarget.hasAVX() && EltBitWidth == 32) ||
46379                                 (Subtarget.hasAVX2() && EltBitWidth == 64) ||
46380                                 (Subtarget.hasXOP()));
46381     if (CanShiftBlend &&
46382         ISD::matchUnaryPredicate(And.getOperand(1), [](ConstantSDNode *C) {
46383           return C->getAPIntValue().isPowerOf2();
46384         })) {
46385       // Create a left-shift constant to get the mask bits over to the sign-bit.
46386       SDValue Mask = And.getOperand(1);
46387       SmallVector<int, 32> ShlVals;
46388       for (unsigned i = 0, e = VT.getVectorNumElements(); i != e; ++i) {
46389         auto *MaskVal = cast<ConstantSDNode>(Mask.getOperand(i));
46390         ShlVals.push_back(EltBitWidth - 1 -
46391                           MaskVal->getAPIntValue().exactLogBase2());
46392       }
46393       // vsel ((X & C) == 0), LHS, RHS --> vsel ((shl X, C') < 0), RHS, LHS
46394       SDValue ShlAmt = getConstVector(ShlVals, VT.getSimpleVT(), DAG, DL);
46395       SDValue Shl = DAG.getNode(ISD::SHL, DL, VT, And.getOperand(0), ShlAmt);
46396       SDValue NewCond =
46397           DAG.getSetCC(DL, CondVT, Shl, Cond.getOperand(1), ISD::SETLT);
46398       return DAG.getSelect(DL, VT, NewCond, RHS, LHS);
46399     }
46400   }
46401 
46402   return SDValue();
46403 }
46404 
46405 /// Combine:
46406 ///   (brcond/cmov/setcc .., (cmp (atomic_load_add x, 1), 0), COND_S)
46407 /// to:
46408 ///   (brcond/cmov/setcc .., (LADD x, 1), COND_LE)
46409 /// i.e., reusing the EFLAGS produced by the LOCKed instruction.
46410 /// Note that this is only legal for some op/cc combinations.
46411 static SDValue combineSetCCAtomicArith(SDValue Cmp, X86::CondCode &CC,
46412                                        SelectionDAG &DAG,
46413                                        const X86Subtarget &Subtarget) {
46414   // This combine only operates on CMP-like nodes.
46415   if (!(Cmp.getOpcode() == X86ISD::CMP ||
46416         (Cmp.getOpcode() == X86ISD::SUB && !Cmp->hasAnyUseOfValue(0))))
46417     return SDValue();
46418 
46419   // Can't replace the cmp if it has more uses than the one we're looking at.
46420   // FIXME: We would like to be able to handle this, but would need to make sure
46421   // all uses were updated.
46422   if (!Cmp.hasOneUse())
46423     return SDValue();
46424 
46425   // This only applies to variations of the common case:
46426   //   (icmp slt x, 0) -> (icmp sle (add x, 1), 0)
46427   //   (icmp sge x, 0) -> (icmp sgt (add x, 1), 0)
46428   //   (icmp sle x, 0) -> (icmp slt (sub x, 1), 0)
46429   //   (icmp sgt x, 0) -> (icmp sge (sub x, 1), 0)
46430   // Using the proper condcodes (see below), overflow is checked for.
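  // E.g. "if (atomic_fetch_add(p, 1) < 0)" can test the flags of the locked add
  // itself: with overflow accounted for by the condcode, (x + 1) <= 0 iff x < 0,
  // so "lock add $1, (p); jle" can replace "lock xadd; test; js".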
46431 
46432   // FIXME: We can generalize both constraints:
46433   // - XOR/OR/AND (if they were made to survive AtomicExpand)
46434   // - LHS != 1
46435   // if the result is compared.
46436 
46437   SDValue CmpLHS = Cmp.getOperand(0);
46438   SDValue CmpRHS = Cmp.getOperand(1);
46439   EVT CmpVT = CmpLHS.getValueType();
46440 
46441   if (!CmpLHS.hasOneUse())
46442     return SDValue();
46443 
46444   unsigned Opc = CmpLHS.getOpcode();
46445   if (Opc != ISD::ATOMIC_LOAD_ADD && Opc != ISD::ATOMIC_LOAD_SUB)
46446     return SDValue();
46447 
46448   SDValue OpRHS = CmpLHS.getOperand(2);
46449   auto *OpRHSC = dyn_cast<ConstantSDNode>(OpRHS);
46450   if (!OpRHSC)
46451     return SDValue();
46452 
46453   APInt Addend = OpRHSC->getAPIntValue();
46454   if (Opc == ISD::ATOMIC_LOAD_SUB)
46455     Addend = -Addend;
46456 
46457   auto *CmpRHSC = dyn_cast<ConstantSDNode>(CmpRHS);
46458   if (!CmpRHSC)
46459     return SDValue();
46460 
46461   APInt Comparison = CmpRHSC->getAPIntValue();
46462   APInt NegAddend = -Addend;
46463 
46464   // See if we can adjust the CC to make the comparison match the negated
46465   // addend.
46466   if (Comparison != NegAddend) {
46467     APInt IncComparison = Comparison + 1;
46468     if (IncComparison == NegAddend) {
46469       if (CC == X86::COND_A && !Comparison.isMaxValue()) {
46470         Comparison = IncComparison;
46471         CC = X86::COND_AE;
46472       } else if (CC == X86::COND_LE && !Comparison.isMaxSignedValue()) {
46473         Comparison = IncComparison;
46474         CC = X86::COND_L;
46475       }
46476     }
46477     APInt DecComparison = Comparison - 1;
46478     if (DecComparison == NegAddend) {
46479       if (CC == X86::COND_AE && !Comparison.isMinValue()) {
46480         Comparison = DecComparison;
46481         CC = X86::COND_A;
46482       } else if (CC == X86::COND_L && !Comparison.isMinSignedValue()) {
46483         Comparison = DecComparison;
46484         CC = X86::COND_LE;
46485       }
46486     }
46487   }
46488 
46489   // If the addend is the negation of the comparison value, then we can do
46490   // a full comparison by emitting the atomic arithmetic as a locked sub.
46491   if (Comparison == NegAddend) {
46492     // The CC is fine, but we need to rewrite the LHS of the comparison as an
46493     // atomic sub.
46494     auto *AN = cast<AtomicSDNode>(CmpLHS.getNode());
46495     auto AtomicSub = DAG.getAtomic(
46496         ISD::ATOMIC_LOAD_SUB, SDLoc(CmpLHS), CmpVT,
46497         /*Chain*/ CmpLHS.getOperand(0), /*LHS*/ CmpLHS.getOperand(1),
46498         /*RHS*/ DAG.getConstant(NegAddend, SDLoc(CmpRHS), CmpVT),
46499         AN->getMemOperand());
46500     auto LockOp = lowerAtomicArithWithLOCK(AtomicSub, DAG, Subtarget);
46501     DAG.ReplaceAllUsesOfValueWith(CmpLHS.getValue(0), DAG.getUNDEF(CmpVT));
46502     DAG.ReplaceAllUsesOfValueWith(CmpLHS.getValue(1), LockOp.getValue(1));
46503     return LockOp;
46504   }
46505 
46506   // We can handle comparisons with zero in a number of cases by manipulating
46507   // the CC used.
46508   if (!Comparison.isZero())
46509     return SDValue();
46510 
46511   if (CC == X86::COND_S && Addend == 1)
46512     CC = X86::COND_LE;
46513   else if (CC == X86::COND_NS && Addend == 1)
46514     CC = X86::COND_G;
46515   else if (CC == X86::COND_G && Addend == -1)
46516     CC = X86::COND_GE;
46517   else if (CC == X86::COND_LE && Addend == -1)
46518     CC = X86::COND_L;
46519   else
46520     return SDValue();
46521 
46522   SDValue LockOp = lowerAtomicArithWithLOCK(CmpLHS, DAG, Subtarget);
46523   DAG.ReplaceAllUsesOfValueWith(CmpLHS.getValue(0), DAG.getUNDEF(CmpVT));
46524   DAG.ReplaceAllUsesOfValueWith(CmpLHS.getValue(1), LockOp.getValue(1));
46525   return LockOp;
46526 }
46527 
46528 // Check whether we're just testing the signbit, and whether we can simplify
46529 // this by tracking where the signbit came from.
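// E.g. CMP(SRA(X, 31), 0) with COND_S only depends on the sign bit of X, so it
// becomes CMP(AND(X, 0x80000000), 0) with COND_NE (a TEST of the MSB).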
46530 static SDValue checkSignTestSetCCCombine(SDValue Cmp, X86::CondCode &CC,
46531                                          SelectionDAG &DAG) {
46532   if (CC != X86::COND_S && CC != X86::COND_NS)
46533     return SDValue();
46534 
46535   if (!Cmp.hasOneUse())
46536     return SDValue();
46537 
46538   SDValue Src;
46539   if (Cmp.getOpcode() == X86ISD::CMP) {
46540     // CMP(X,0) -> signbit test
46541     if (!isNullConstant(Cmp.getOperand(1)))
46542       return SDValue();
46543     Src = Cmp.getOperand(0);
46544     // Peek through a SRA node as we just need the signbit.
46545     // TODO: Remove one use limit once sdiv-fix regressions are fixed.
46546     // TODO: Use SimplifyDemandedBits instead of just SRA?
46547     if (Src.getOpcode() != ISD::SRA || !Src.hasOneUse())
46548       return SDValue();
46549     Src = Src.getOperand(0);
46550   } else if (Cmp.getOpcode() == X86ISD::OR) {
46551     // OR(X,Y) -> see if only one operand contributes to the signbit.
46552     // TODO: XOR(X,Y) -> see if only one operand contributes to the signbit.
46553     if (DAG.SignBitIsZero(Cmp.getOperand(0)))
46554       Src = Cmp.getOperand(1);
46555     else if (DAG.SignBitIsZero(Cmp.getOperand(1)))
46556       Src = Cmp.getOperand(0);
46557     else
46558       return SDValue();
46559   } else {
46560     return SDValue();
46561   }
46562 
46563   // Replace with a TEST on the MSB.
46564   SDLoc DL(Cmp);
46565   MVT SrcVT = Src.getSimpleValueType();
46566   APInt BitMask = APInt::getSignMask(SrcVT.getScalarSizeInBits());
46567 
46568   // If Src came from a SHL (probably from an expanded SIGN_EXTEND_INREG), then
46569   // peek through and adjust the TEST bit.
46570   if (Src.getOpcode() == ISD::SHL) {
46571     if (std::optional<uint64_t> ShiftAmt = DAG.getValidShiftAmount(Src)) {
46572       Src = Src.getOperand(0);
46573       BitMask.lshrInPlace(*ShiftAmt);
46574     }
46575   }
46576 
46577   SDValue Mask = DAG.getNode(ISD::AND, DL, SrcVT, Src,
46578                              DAG.getConstant(BitMask, DL, SrcVT));
46579   CC = CC == X86::COND_S ? X86::COND_NE : X86::COND_E;
46580   return DAG.getNode(X86ISD::CMP, DL, MVT::i32, Mask,
46581                      DAG.getConstant(0, DL, SrcVT));
46582 }
46583 
46584 // Check whether a boolean test is testing a boolean value generated by
46585 // X86ISD::SETCC. If so, return the operand of that SETCC and proper condition
46586 // code.
46587 //
46588 // Simplify the following patterns:
46589 // (Op (CMP (SETCC Cond EFLAGS) 1) EQ) or
46590 // (Op (CMP (SETCC Cond EFLAGS) 0) NEQ)
46591 // to (Op EFLAGS Cond)
46592 //
46593 // (Op (CMP (SETCC Cond EFLAGS) 0) EQ) or
46594 // (Op (CMP (SETCC Cond EFLAGS) 1) NEQ)
46595 // to (Op EFLAGS !Cond)
46596 //
46597 // where Op could be BRCOND or CMOV.
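// E.g. at the instruction level, "setne %al; test %al, %al; je target"
// collapses into a single "je target" on the EFLAGS that fed the SETNE.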
46598 //
46599 static SDValue checkBoolTestSetCCCombine(SDValue Cmp, X86::CondCode &CC) {
46600   // This combine only operates on CMP-like nodes.
46601   if (!(Cmp.getOpcode() == X86ISD::CMP ||
46602         (Cmp.getOpcode() == X86ISD::SUB && !Cmp->hasAnyUseOfValue(0))))
46603     return SDValue();
46604 
46605   // Quit if not used as a boolean value.
46606   if (CC != X86::COND_E && CC != X86::COND_NE)
46607     return SDValue();
46608 
46609   // Check CMP operands. One of them should be 0 or 1 and the other should be
46610   // a SetCC or a value extended from one.
46611   SDValue Op1 = Cmp.getOperand(0);
46612   SDValue Op2 = Cmp.getOperand(1);
46613 
46614   SDValue SetCC;
46615   const ConstantSDNode* C = nullptr;
46616   bool needOppositeCond = (CC == X86::COND_E);
46617   bool checkAgainstTrue = false; // Is it a comparison against 1?
46618 
46619   if ((C = dyn_cast<ConstantSDNode>(Op1)))
46620     SetCC = Op2;
46621   else if ((C = dyn_cast<ConstantSDNode>(Op2)))
46622     SetCC = Op1;
46623   else // Quit if neither operand is a constant.
46624     return SDValue();
46625 
46626   if (C->getZExtValue() == 1) {
46627     needOppositeCond = !needOppositeCond;
46628     checkAgainstTrue = true;
46629   } else if (C->getZExtValue() != 0)
46630     // Quit if the constant is neither 0 nor 1.
46631     return SDValue();
46632 
46633   bool truncatedToBoolWithAnd = false;
46634   // Skip (zext $x), (trunc $x), or (and $x, 1) node.
46635   while (SetCC.getOpcode() == ISD::ZERO_EXTEND ||
46636          SetCC.getOpcode() == ISD::TRUNCATE ||
46637          SetCC.getOpcode() == ISD::AND) {
46638     if (SetCC.getOpcode() == ISD::AND) {
46639       int OpIdx = -1;
46640       if (isOneConstant(SetCC.getOperand(0)))
46641         OpIdx = 1;
46642       if (isOneConstant(SetCC.getOperand(1)))
46643         OpIdx = 0;
46644       if (OpIdx < 0)
46645         break;
46646       SetCC = SetCC.getOperand(OpIdx);
46647       truncatedToBoolWithAnd = true;
46648     } else
46649       SetCC = SetCC.getOperand(0);
46650   }
46651 
46652   switch (SetCC.getOpcode()) {
46653   case X86ISD::SETCC_CARRY:
46654     // Since SETCC_CARRY gives output based on R = CF ? ~0 : 0, it's unsafe to
46655     // simplify it if the result of SETCC_CARRY is not canonicalized to 0 or 1,
46656     // i.e. it's a comparison against true but the result of SETCC_CARRY is not
46657     // truncated to i1 using 'and'.
46658     if (checkAgainstTrue && !truncatedToBoolWithAnd)
46659       break;
46660     assert(X86::CondCode(SetCC.getConstantOperandVal(0)) == X86::COND_B &&
46661            "Invalid use of SETCC_CARRY!");
46662     [[fallthrough]];
46663   case X86ISD::SETCC:
46664     // Set the condition code or opposite one if necessary.
46665     CC = X86::CondCode(SetCC.getConstantOperandVal(0));
46666     if (needOppositeCond)
46667       CC = X86::GetOppositeBranchCondition(CC);
46668     return SetCC.getOperand(1);
46669   case X86ISD::CMOV: {
46670     // Check whether the false/true values are canonical, i.e. 0 or 1.
46671     ConstantSDNode *FVal = dyn_cast<ConstantSDNode>(SetCC.getOperand(0));
46672     ConstantSDNode *TVal = dyn_cast<ConstantSDNode>(SetCC.getOperand(1));
46673     // Quit if true value is not a constant.
46674     if (!TVal)
46675       return SDValue();
46676     // Quit if false value is not a constant.
46677     if (!FVal) {
46678       SDValue Op = SetCC.getOperand(0);
46679       // Skip 'zext' or 'trunc' node.
46680       if (Op.getOpcode() == ISD::ZERO_EXTEND ||
46681           Op.getOpcode() == ISD::TRUNCATE)
46682         Op = Op.getOperand(0);
46683       // A special case for rdrand/rdseed, where 0 is set if false cond is
46684       // found.
46685       if ((Op.getOpcode() != X86ISD::RDRAND &&
46686            Op.getOpcode() != X86ISD::RDSEED) || Op.getResNo() != 0)
46687         return SDValue();
46688     }
46689     // Quit if false value is not the constant 0 or 1.
46690     bool FValIsFalse = true;
46691     if (FVal && FVal->getZExtValue() != 0) {
46692       if (FVal->getZExtValue() != 1)
46693         return SDValue();
46694       // If FVal is 1, opposite cond is needed.
46695       needOppositeCond = !needOppositeCond;
46696       FValIsFalse = false;
46697     }
46698     // Quit if TVal is not the constant opposite of FVal.
46699     if (FValIsFalse && TVal->getZExtValue() != 1)
46700       return SDValue();
46701     if (!FValIsFalse && TVal->getZExtValue() != 0)
46702       return SDValue();
46703     CC = X86::CondCode(SetCC.getConstantOperandVal(2));
46704     if (needOppositeCond)
46705       CC = X86::GetOppositeBranchCondition(CC);
46706     return SetCC.getOperand(3);
46707   }
46708   }
46709 
46710   return SDValue();
46711 }
46712 
46713 /// Check whether Cond is an AND/OR of SETCCs off of the same EFLAGS.
46714 /// Match:
46715 ///   (X86or (X86setcc) (X86setcc))
46716 ///   (X86cmp (and (X86setcc) (X86setcc)), 0)
46717 static bool checkBoolTestAndOrSetCCCombine(SDValue Cond, X86::CondCode &CC0,
46718                                            X86::CondCode &CC1, SDValue &Flags,
46719                                            bool &isAnd) {
46720   if (Cond->getOpcode() == X86ISD::CMP) {
46721     if (!isNullConstant(Cond->getOperand(1)))
46722       return false;
46723 
46724     Cond = Cond->getOperand(0);
46725   }
46726 
46727   isAnd = false;
46728 
46729   SDValue SetCC0, SetCC1;
46730   switch (Cond->getOpcode()) {
46731   default: return false;
46732   case ISD::AND:
46733   case X86ISD::AND:
46734     isAnd = true;
46735     [[fallthrough]];
46736   case ISD::OR:
46737   case X86ISD::OR:
46738     SetCC0 = Cond->getOperand(0);
46739     SetCC1 = Cond->getOperand(1);
46740     break;
46741   };
46742 
46743   // Make sure we have SETCC nodes, using the same flags value.
46744   if (SetCC0.getOpcode() != X86ISD::SETCC ||
46745       SetCC1.getOpcode() != X86ISD::SETCC ||
46746       SetCC0->getOperand(1) != SetCC1->getOperand(1))
46747     return false;
46748 
46749   CC0 = (X86::CondCode)SetCC0->getConstantOperandVal(0);
46750   CC1 = (X86::CondCode)SetCC1->getConstantOperandVal(0);
46751   Flags = SetCC0->getOperand(1);
46752   return true;
46753 }
46754 
46755 // When legalizing carry, we create carries via add X, -1
46756 // If that comes from an actual carry, via setcc, we use the
46757 // carry directly.
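// E.g. for EFLAGS = (X86add (setcc COND_B, Flags), -1), the add's carry-out is
// 1 exactly when the setcc produced 1, so Flags can be consumed directly.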
46758 static SDValue combineCarryThroughADD(SDValue EFLAGS, SelectionDAG &DAG) {
46759   if (EFLAGS.getOpcode() == X86ISD::ADD) {
46760     if (isAllOnesConstant(EFLAGS.getOperand(1))) {
46761       bool FoundAndLSB = false;
46762       SDValue Carry = EFLAGS.getOperand(0);
46763       while (Carry.getOpcode() == ISD::TRUNCATE ||
46764              Carry.getOpcode() == ISD::ZERO_EXTEND ||
46765              (Carry.getOpcode() == ISD::AND &&
46766               isOneConstant(Carry.getOperand(1)))) {
46767         FoundAndLSB |= Carry.getOpcode() == ISD::AND;
46768         Carry = Carry.getOperand(0);
46769       }
46770       if (Carry.getOpcode() == X86ISD::SETCC ||
46771           Carry.getOpcode() == X86ISD::SETCC_CARRY) {
46772         // TODO: Merge this code with equivalent in combineAddOrSubToADCOrSBB?
46773         uint64_t CarryCC = Carry.getConstantOperandVal(0);
46774         SDValue CarryOp1 = Carry.getOperand(1);
46775         if (CarryCC == X86::COND_B)
46776           return CarryOp1;
46777         if (CarryCC == X86::COND_A) {
46778           // Try to convert COND_A into COND_B in an attempt to facilitate
46779           // materializing "setb reg".
46780           //
46781           // Do not flip "e > c", where "c" is a constant, because the Cmp
46782           // instruction cannot take an immediate as its first operand.
46783           //
46784           if (CarryOp1.getOpcode() == X86ISD::SUB &&
46785               CarryOp1.getNode()->hasOneUse() &&
46786               CarryOp1.getValueType().isInteger() &&
46787               !isa<ConstantSDNode>(CarryOp1.getOperand(1))) {
46788             SDValue SubCommute =
46789                 DAG.getNode(X86ISD::SUB, SDLoc(CarryOp1), CarryOp1->getVTList(),
46790                             CarryOp1.getOperand(1), CarryOp1.getOperand(0));
46791             return SDValue(SubCommute.getNode(), CarryOp1.getResNo());
46792           }
46793         }
46794         // If this is a check of the z flag of an add with 1, switch to the
46795         // C flag.
46796         if (CarryCC == X86::COND_E &&
46797             CarryOp1.getOpcode() == X86ISD::ADD &&
46798             isOneConstant(CarryOp1.getOperand(1)))
46799           return CarryOp1;
46800       } else if (FoundAndLSB) {
46801         SDLoc DL(Carry);
46802         SDValue BitNo = DAG.getConstant(0, DL, Carry.getValueType());
46803         if (Carry.getOpcode() == ISD::SRL) {
46804           BitNo = Carry.getOperand(1);
46805           Carry = Carry.getOperand(0);
46806         }
46807         return getBT(Carry, BitNo, DL, DAG);
46808       }
46809     }
46810   }
46811 
46812   return SDValue();
46813 }
46814 
46815 /// If we are inverting an PTEST/TESTP operand, attempt to adjust the CC
46816 /// to avoid the inversion.
46817 static SDValue combinePTESTCC(SDValue EFLAGS, X86::CondCode &CC,
46818                               SelectionDAG &DAG,
46819                               const X86Subtarget &Subtarget) {
46820   // TODO: Handle X86ISD::KTEST/X86ISD::KORTEST.
46821   if (EFLAGS.getOpcode() != X86ISD::PTEST &&
46822       EFLAGS.getOpcode() != X86ISD::TESTP)
46823     return SDValue();
46824 
46825   // PTEST/TESTP sets EFLAGS as:
46826   // TESTZ: ZF = (Op0 & Op1) == 0
46827   // TESTC: CF = (~Op0 & Op1) == 0
46828   // TESTNZC: ZF == 0 && CF == 0
46829   MVT VT = EFLAGS.getSimpleValueType();
46830   SDValue Op0 = EFLAGS.getOperand(0);
46831   SDValue Op1 = EFLAGS.getOperand(1);
46832   MVT OpVT = Op0.getSimpleValueType();
46833   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
46834 
46835   // TEST*(~X,Y) == TEST*(X,Y)
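  // Substituting ~Op0 for Op0 swaps the roles of ZF and CF above, which is why
  // the mapping below is COND_E <-> COND_B and COND_NE <-> COND_AE.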
46836   if (SDValue NotOp0 = IsNOT(Op0, DAG)) {
46837     X86::CondCode InvCC;
46838     switch (CC) {
46839     case X86::COND_B:
46840       // testc -> testz.
46841       InvCC = X86::COND_E;
46842       break;
46843     case X86::COND_AE:
46844       // !testc -> !testz.
46845       InvCC = X86::COND_NE;
46846       break;
46847     case X86::COND_E:
46848       // testz -> testc.
46849       InvCC = X86::COND_B;
46850       break;
46851     case X86::COND_NE:
46852       // !testz -> !testc.
46853       InvCC = X86::COND_AE;
46854       break;
46855     case X86::COND_A:
46856     case X86::COND_BE:
46857       // testnzc -> testnzc (no change).
46858       InvCC = CC;
46859       break;
46860     default:
46861       InvCC = X86::COND_INVALID;
46862       break;
46863     }
46864 
46865     if (InvCC != X86::COND_INVALID) {
46866       CC = InvCC;
46867       return DAG.getNode(EFLAGS.getOpcode(), SDLoc(EFLAGS), VT,
46868                          DAG.getBitcast(OpVT, NotOp0), Op1);
46869     }
46870   }
46871 
46872   if (CC == X86::COND_B || CC == X86::COND_AE) {
46873     // TESTC(X,~X) == TESTC(X,-1)
46874     if (SDValue NotOp1 = IsNOT(Op1, DAG)) {
46875       if (peekThroughBitcasts(NotOp1) == peekThroughBitcasts(Op0)) {
46876         SDLoc DL(EFLAGS);
46877         return DAG.getNode(
46878             EFLAGS.getOpcode(), DL, VT, DAG.getBitcast(OpVT, NotOp1),
46879             DAG.getBitcast(OpVT,
46880                            DAG.getAllOnesConstant(DL, NotOp1.getValueType())));
46881       }
46882     }
46883   }
46884 
46885   if (CC == X86::COND_E || CC == X86::COND_NE) {
46886     // TESTZ(X,~Y) == TESTC(Y,X)
46887     if (SDValue NotOp1 = IsNOT(Op1, DAG)) {
46888       CC = (CC == X86::COND_E ? X86::COND_B : X86::COND_AE);
46889       return DAG.getNode(EFLAGS.getOpcode(), SDLoc(EFLAGS), VT,
46890                          DAG.getBitcast(OpVT, NotOp1), Op0);
46891     }
46892 
46893     if (Op0 == Op1) {
46894       SDValue BC = peekThroughBitcasts(Op0);
46895       EVT BCVT = BC.getValueType();
46896 
46897       // TESTZ(AND(X,Y),AND(X,Y)) == TESTZ(X,Y)
46898       if (BC.getOpcode() == ISD::AND || BC.getOpcode() == X86ISD::FAND) {
46899         return DAG.getNode(EFLAGS.getOpcode(), SDLoc(EFLAGS), VT,
46900                            DAG.getBitcast(OpVT, BC.getOperand(0)),
46901                            DAG.getBitcast(OpVT, BC.getOperand(1)));
46902       }
46903 
46904       // TESTZ(AND(~X,Y),AND(~X,Y)) == TESTC(X,Y)
46905       if (BC.getOpcode() == X86ISD::ANDNP || BC.getOpcode() == X86ISD::FANDN) {
46906         CC = (CC == X86::COND_E ? X86::COND_B : X86::COND_AE);
46907         return DAG.getNode(EFLAGS.getOpcode(), SDLoc(EFLAGS), VT,
46908                            DAG.getBitcast(OpVT, BC.getOperand(0)),
46909                            DAG.getBitcast(OpVT, BC.getOperand(1)));
46910       }
46911 
46912       // If every element is an all-sign value, see if we can use TESTP/MOVMSK
46913       // to more efficiently extract the sign bits and compare that.
46914       // TODO: Handle TESTC with comparison inversion.
46915       // TODO: Can we remove SimplifyMultipleUseDemandedBits and rely on
46916       // TESTP/MOVMSK combines to make sure its never worse than PTEST?
46917       if (BCVT.isVector() && TLI.isTypeLegal(BCVT)) {
46918         unsigned EltBits = BCVT.getScalarSizeInBits();
46919         if (DAG.ComputeNumSignBits(BC) == EltBits) {
46920           assert(VT == MVT::i32 && "Expected i32 EFLAGS comparison result");
46921           APInt SignMask = APInt::getSignMask(EltBits);
46922           if (SDValue Res =
46923                   TLI.SimplifyMultipleUseDemandedBits(BC, SignMask, DAG)) {
46924             // For vXi16 cases we need to use pmovmskb and extract every other
46925             // sign bit.
46926             SDLoc DL(EFLAGS);
46927             if ((EltBits == 32 || EltBits == 64) && Subtarget.hasAVX()) {
46928               MVT FloatSVT = MVT::getFloatingPointVT(EltBits);
46929               MVT FloatVT =
46930                   MVT::getVectorVT(FloatSVT, OpVT.getSizeInBits() / EltBits);
46931               Res = DAG.getBitcast(FloatVT, Res);
46932               return DAG.getNode(X86ISD::TESTP, SDLoc(EFLAGS), VT, Res, Res);
46933             } else if (EltBits == 16) {
46934               MVT MovmskVT = BCVT.is128BitVector() ? MVT::v16i8 : MVT::v32i8;
46935               Res = DAG.getBitcast(MovmskVT, Res);
46936               Res = getPMOVMSKB(DL, Res, DAG, Subtarget);
46937               Res = DAG.getNode(ISD::AND, DL, MVT::i32, Res,
46938                                 DAG.getConstant(0xAAAAAAAA, DL, MVT::i32));
46939             } else {
46940               Res = getPMOVMSKB(DL, Res, DAG, Subtarget);
46941             }
46942             return DAG.getNode(X86ISD::CMP, DL, MVT::i32, Res,
46943                                DAG.getConstant(0, DL, MVT::i32));
46944           }
46945         }
46946       }
46947     }
46948 
46949     // TESTZ(-1,X) == TESTZ(X,X)
46950     if (ISD::isBuildVectorAllOnes(Op0.getNode()))
46951       return DAG.getNode(EFLAGS.getOpcode(), SDLoc(EFLAGS), VT, Op1, Op1);
46952 
46953     // TESTZ(X,-1) == TESTZ(X,X)
46954     if (ISD::isBuildVectorAllOnes(Op1.getNode()))
46955       return DAG.getNode(EFLAGS.getOpcode(), SDLoc(EFLAGS), VT, Op0, Op0);
46956 
46957     // TESTZ(OR(LO(X),HI(X)),OR(LO(Y),HI(Y))) -> TESTZ(X,Y)
46958     // TODO: Add COND_NE handling?
46959     if (CC == X86::COND_E && OpVT.is128BitVector() && Subtarget.hasAVX()) {
46960       SDValue Src0 = peekThroughBitcasts(Op0);
46961       SDValue Src1 = peekThroughBitcasts(Op1);
46962       if (Src0.getOpcode() == ISD::OR && Src1.getOpcode() == ISD::OR) {
46963         Src0 = getSplitVectorSrc(peekThroughBitcasts(Src0.getOperand(0)),
46964                                  peekThroughBitcasts(Src0.getOperand(1)), true);
46965         Src1 = getSplitVectorSrc(peekThroughBitcasts(Src1.getOperand(0)),
46966                                  peekThroughBitcasts(Src1.getOperand(1)), true);
46967         if (Src0 && Src1) {
46968           MVT OpVT2 = OpVT.getDoubleNumVectorElementsVT();
46969           return DAG.getNode(EFLAGS.getOpcode(), SDLoc(EFLAGS), VT,
46970                              DAG.getBitcast(OpVT2, Src0),
46971                              DAG.getBitcast(OpVT2, Src1));
46972         }
46973       }
46974     }
46975   }
46976 
46977   return SDValue();
46978 }
46979 
46980 // Attempt to simplify the MOVMSK input based on the comparison type.
46981 static SDValue combineSetCCMOVMSK(SDValue EFLAGS, X86::CondCode &CC,
46982                                   SelectionDAG &DAG,
46983                                   const X86Subtarget &Subtarget) {
46984   // Handle eq/ne against zero (any_of).
46985   // Handle eq/ne against -1 (all_of).
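  // E.g. MOVMSK(PCMPEQB(X, 0)) != 0 is an any_of(X[i] == 0) reduction, while
  // MOVMSK(PCMPEQB(X, 0)) == 0xFFFF (all mask bits set) is an all_of reduction.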
46986   if (!(CC == X86::COND_E || CC == X86::COND_NE))
46987     return SDValue();
46988   if (EFLAGS.getValueType() != MVT::i32)
46989     return SDValue();
46990   unsigned CmpOpcode = EFLAGS.getOpcode();
46991   if (CmpOpcode != X86ISD::CMP && CmpOpcode != X86ISD::SUB)
46992     return SDValue();
46993   auto *CmpConstant = dyn_cast<ConstantSDNode>(EFLAGS.getOperand(1));
46994   if (!CmpConstant)
46995     return SDValue();
46996   const APInt &CmpVal = CmpConstant->getAPIntValue();
46997 
46998   SDValue CmpOp = EFLAGS.getOperand(0);
46999   unsigned CmpBits = CmpOp.getValueSizeInBits();
47000   assert(CmpBits == CmpVal.getBitWidth() && "Value size mismatch");
47001 
47002   // Peek through any truncate.
47003   if (CmpOp.getOpcode() == ISD::TRUNCATE)
47004     CmpOp = CmpOp.getOperand(0);
47005 
47006   // Bail if we don't find a MOVMSK.
47007   if (CmpOp.getOpcode() != X86ISD::MOVMSK)
47008     return SDValue();
47009 
47010   SDValue Vec = CmpOp.getOperand(0);
47011   MVT VecVT = Vec.getSimpleValueType();
47012   assert((VecVT.is128BitVector() || VecVT.is256BitVector()) &&
47013          "Unexpected MOVMSK operand");
47014   unsigned NumElts = VecVT.getVectorNumElements();
47015   unsigned NumEltBits = VecVT.getScalarSizeInBits();
47016 
47017   bool IsAnyOf = CmpOpcode == X86ISD::CMP && CmpVal.isZero();
47018   bool IsAllOf = (CmpOpcode == X86ISD::SUB || CmpOpcode == X86ISD::CMP) &&
47019                  NumElts <= CmpBits && CmpVal.isMask(NumElts);
47020   if (!IsAnyOf && !IsAllOf)
47021     return SDValue();
47022 
47023   // TODO: Check whether more combining cases should be handled here.
47024   // Here we check the cmp use count to decide whether to perform the combine.
47025   // Currently only the "MOVMSK(CONCAT(..))" and "MOVMSK(PCMPEQ(..))" combines
47026   // are known to benefit from this one-use constraint.
47027   bool IsOneUse = CmpOp.getNode()->hasOneUse();
47028 
47029   // See if we can peek through to a vector with a wider element type, if the
47030   // signbits extend down to all the sub-elements as well.
47031   // Calling MOVMSK with the wider type, avoiding the bitcast, helps expose
47032   // potential SimplifyDemandedBits/Elts cases.
47033   // If we looked through a truncate that discarded bits, we can't do this
47034   // transform.
47035   // FIXME: We could do this transform for truncates that discarded bits by
47036   // inserting an AND mask between the new MOVMSK and the CMP.
47037   if (Vec.getOpcode() == ISD::BITCAST && NumElts <= CmpBits) {
47038     SDValue BC = peekThroughBitcasts(Vec);
47039     MVT BCVT = BC.getSimpleValueType();
47040     unsigned BCNumElts = BCVT.getVectorNumElements();
47041     unsigned BCNumEltBits = BCVT.getScalarSizeInBits();
47042     if ((BCNumEltBits == 32 || BCNumEltBits == 64) &&
47043         BCNumEltBits > NumEltBits &&
47044         DAG.ComputeNumSignBits(BC) > (BCNumEltBits - NumEltBits)) {
47045       SDLoc DL(EFLAGS);
47046       APInt CmpMask = APInt::getLowBitsSet(32, IsAnyOf ? 0 : BCNumElts);
47047       return DAG.getNode(X86ISD::CMP, DL, MVT::i32,
47048                          DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, BC),
47049                          DAG.getConstant(CmpMask, DL, MVT::i32));
47050     }
47051   }
47052 
47053   // MOVMSK(CONCAT(X,Y)) == 0 ->  MOVMSK(OR(X,Y)).
47054   // MOVMSK(CONCAT(X,Y)) != 0 ->  MOVMSK(OR(X,Y)).
47055   // MOVMSK(CONCAT(X,Y)) == -1 ->  MOVMSK(AND(X,Y)).
47056   // MOVMSK(CONCAT(X,Y)) != -1 ->  MOVMSK(AND(X,Y)).
47057   if (VecVT.is256BitVector() && NumElts <= CmpBits && IsOneUse) {
47058     SmallVector<SDValue> Ops;
47059     if (collectConcatOps(peekThroughBitcasts(Vec).getNode(), Ops, DAG) &&
47060         Ops.size() == 2) {
47061       SDLoc DL(EFLAGS);
47062       EVT SubVT = Ops[0].getValueType().changeTypeToInteger();
47063       APInt CmpMask = APInt::getLowBitsSet(32, IsAnyOf ? 0 : NumElts / 2);
47064       SDValue V = DAG.getNode(IsAnyOf ? ISD::OR : ISD::AND, DL, SubVT,
47065                               DAG.getBitcast(SubVT, Ops[0]),
47066                               DAG.getBitcast(SubVT, Ops[1]));
47067       V = DAG.getBitcast(VecVT.getHalfNumVectorElementsVT(), V);
47068       return DAG.getNode(X86ISD::CMP, DL, MVT::i32,
47069                          DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, V),
47070                          DAG.getConstant(CmpMask, DL, MVT::i32));
47071     }
47072   }
47073 
47074   // MOVMSK(PCMPEQ(X,0)) == -1 -> PTESTZ(X,X).
47075   // MOVMSK(PCMPEQ(X,0)) != -1 -> !PTESTZ(X,X).
47076   // MOVMSK(PCMPEQ(X,Y)) == -1 -> PTESTZ(XOR(X,Y),XOR(X,Y)).
47077   // MOVMSK(PCMPEQ(X,Y)) != -1 -> !PTESTZ(XOR(X,Y),XOR(X,Y)).
47078   if (IsAllOf && Subtarget.hasSSE41() && IsOneUse) {
47079     MVT TestVT = VecVT.is128BitVector() ? MVT::v2i64 : MVT::v4i64;
47080     SDValue BC = peekThroughBitcasts(Vec);
47081     // Ensure MOVMSK was testing every signbit of BC.
47082     if (BC.getValueType().getVectorNumElements() <= NumElts) {
47083       if (BC.getOpcode() == X86ISD::PCMPEQ) {
47084         SDValue V = DAG.getNode(ISD::XOR, SDLoc(BC), BC.getValueType(),
47085                                 BC.getOperand(0), BC.getOperand(1));
47086         V = DAG.getBitcast(TestVT, V);
47087         return DAG.getNode(X86ISD::PTEST, SDLoc(EFLAGS), MVT::i32, V, V);
47088       }
47089       // Check for 256-bit split vector cases.
47090       if (BC.getOpcode() == ISD::AND &&
47091           BC.getOperand(0).getOpcode() == X86ISD::PCMPEQ &&
47092           BC.getOperand(1).getOpcode() == X86ISD::PCMPEQ) {
47093         SDValue LHS = BC.getOperand(0);
47094         SDValue RHS = BC.getOperand(1);
47095         LHS = DAG.getNode(ISD::XOR, SDLoc(LHS), LHS.getValueType(),
47096                           LHS.getOperand(0), LHS.getOperand(1));
47097         RHS = DAG.getNode(ISD::XOR, SDLoc(RHS), RHS.getValueType(),
47098                           RHS.getOperand(0), RHS.getOperand(1));
47099         LHS = DAG.getBitcast(TestVT, LHS);
47100         RHS = DAG.getBitcast(TestVT, RHS);
47101         SDValue V = DAG.getNode(ISD::OR, SDLoc(EFLAGS), TestVT, LHS, RHS);
47102         return DAG.getNode(X86ISD::PTEST, SDLoc(EFLAGS), MVT::i32, V, V);
47103       }
47104     }
47105   }
47106 
47107   // See if we can avoid a PACKSS by calling MOVMSK on the sources.
47108   // For vXi16 cases we can use a v2Xi8 PMOVMSKB. We must mask out
47109   // sign bits prior to the comparison with zero unless we know that
47110   // the vXi16 splats the sign bit down to the lower i8 half.
47111   // TODO: Handle all_of patterns.
47112   if (Vec.getOpcode() == X86ISD::PACKSS && VecVT == MVT::v16i8) {
47113     SDValue VecOp0 = Vec.getOperand(0);
47114     SDValue VecOp1 = Vec.getOperand(1);
47115     bool SignExt0 = DAG.ComputeNumSignBits(VecOp0) > 8;
47116     bool SignExt1 = DAG.ComputeNumSignBits(VecOp1) > 8;
47117     // PMOVMSKB(PACKSSBW(X, undef)) -> PMOVMSKB(BITCAST_v16i8(X)) & 0xAAAA.
47118     if (IsAnyOf && CmpBits == 8 && VecOp1.isUndef()) {
47119       SDLoc DL(EFLAGS);
47120       SDValue Result = DAG.getBitcast(MVT::v16i8, VecOp0);
47121       Result = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Result);
47122       Result = DAG.getZExtOrTrunc(Result, DL, MVT::i16);
47123       if (!SignExt0) {
47124         Result = DAG.getNode(ISD::AND, DL, MVT::i16, Result,
47125                              DAG.getConstant(0xAAAA, DL, MVT::i16));
47126       }
47127       return DAG.getNode(X86ISD::CMP, DL, MVT::i32, Result,
47128                          DAG.getConstant(0, DL, MVT::i16));
47129     }
47130     // PMOVMSKB(PACKSSBW(LO(X), HI(X)))
47131     // -> PMOVMSKB(BITCAST_v32i8(X)) & 0xAAAAAAAA.
47132     if (CmpBits >= 16 && Subtarget.hasInt256() &&
47133         (IsAnyOf || (SignExt0 && SignExt1))) {
47134       if (SDValue Src = getSplitVectorSrc(VecOp0, VecOp1, true)) {
47135         SDLoc DL(EFLAGS);
47136         SDValue Result = peekThroughBitcasts(Src);
47137         if (IsAllOf && Result.getOpcode() == X86ISD::PCMPEQ &&
47138             Result.getValueType().getVectorNumElements() <= NumElts) {
47139           SDValue V = DAG.getNode(ISD::XOR, DL, Result.getValueType(),
47140                                   Result.getOperand(0), Result.getOperand(1));
47141           V = DAG.getBitcast(MVT::v4i64, V);
47142           return DAG.getNode(X86ISD::PTEST, SDLoc(EFLAGS), MVT::i32, V, V);
47143         }
47144         Result = DAG.getBitcast(MVT::v32i8, Result);
47145         Result = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Result);
47146         unsigned CmpMask = IsAnyOf ? 0 : 0xFFFFFFFF;
47147         if (!SignExt0 || !SignExt1) {
47148           assert(IsAnyOf &&
47149                  "Only perform v16i16 signmasks for any_of patterns");
47150           Result = DAG.getNode(ISD::AND, DL, MVT::i32, Result,
47151                                DAG.getConstant(0xAAAAAAAA, DL, MVT::i32));
47152         }
47153         return DAG.getNode(X86ISD::CMP, DL, MVT::i32, Result,
47154                            DAG.getConstant(CmpMask, DL, MVT::i32));
47155       }
47156     }
47157   }
47158 
47159   // MOVMSK(SHUFFLE(X,u)) -> MOVMSK(X) iff every element is referenced.
47160   // Since we peek through a bitcast, we need to be careful if the base vector
47161   // type has smaller elements than the MOVMSK type.  In that case, even if
47162   // all the elements are demanded by the shuffle mask, only the "high"
47163   // elements which have highbits that align with highbits in the MOVMSK vec
47164   // elements are actually demanded. A simplification of spurious operations
47165   // on the "low" elements takes place during other simplifications.
47166   //
47167   // For example, in
47168   // MOVMSK64(BITCAST(SHUF32 X, (1,0,3,2))), even though all the elements are
47169   // demanded, the result can change because we are swapping elements around.
47170   //
47171   // To address this, we check that we can scale the shuffle mask to the MOVMSK
47172   // element width (this ensures the "high" elements match). It's slightly overly
47173   // conservative, but fine for an edge case fold.
47174   SmallVector<int, 32> ShuffleMask;
47175   SmallVector<SDValue, 2> ShuffleInputs;
47176   if (NumElts <= CmpBits &&
47177       getTargetShuffleInputs(peekThroughBitcasts(Vec), ShuffleInputs,
47178                              ShuffleMask, DAG) &&
47179       ShuffleInputs.size() == 1 && isCompletePermute(ShuffleMask) &&
47180       ShuffleInputs[0].getValueSizeInBits() == VecVT.getSizeInBits() &&
47181       canScaleShuffleElements(ShuffleMask, NumElts)) {
47182     SDLoc DL(EFLAGS);
47183     SDValue Result = DAG.getBitcast(VecVT, ShuffleInputs[0]);
47184     Result = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Result);
47185     Result =
47186         DAG.getZExtOrTrunc(Result, DL, EFLAGS.getOperand(0).getValueType());
47187     return DAG.getNode(X86ISD::CMP, DL, MVT::i32, Result, EFLAGS.getOperand(1));
47188   }
47189 
47190   // MOVMSKPS(V) !=/== 0 -> TESTPS(V,V)
47191   // MOVMSKPD(V) !=/== 0 -> TESTPD(V,V)
47192   // MOVMSKPS(V) !=/== -1 -> TESTPS(V,V)
47193   // MOVMSKPD(V) !=/== -1 -> TESTPD(V,V)
47194   // iff every element is referenced.
47195   if (NumElts <= CmpBits && Subtarget.hasAVX() &&
47196       !Subtarget.preferMovmskOverVTest() && IsOneUse &&
47197       (NumEltBits == 32 || NumEltBits == 64)) {
47198     SDLoc DL(EFLAGS);
47199     MVT FloatSVT = MVT::getFloatingPointVT(NumEltBits);
47200     MVT FloatVT = MVT::getVectorVT(FloatSVT, NumElts);
47201     MVT IntVT = FloatVT.changeVectorElementTypeToInteger();
47202     SDValue LHS = Vec;
47203     SDValue RHS = IsAnyOf ? Vec : DAG.getAllOnesConstant(DL, IntVT);
47204     CC = IsAnyOf ? CC : (CC == X86::COND_E ? X86::COND_B : X86::COND_AE);
47205     return DAG.getNode(X86ISD::TESTP, DL, MVT::i32,
47206                        DAG.getBitcast(FloatVT, LHS),
47207                        DAG.getBitcast(FloatVT, RHS));
47208   }
47209 
47210   return SDValue();
47211 }
47212 
47213 /// Optimize an EFLAGS definition used according to the condition code \p CC
47214 /// into a simpler EFLAGS value, potentially returning a new \p CC and replacing
47215 /// uses of chain values.
47216 static SDValue combineSetCCEFLAGS(SDValue EFLAGS, X86::CondCode &CC,
47217                                   SelectionDAG &DAG,
47218                                   const X86Subtarget &Subtarget) {
47219   if (CC == X86::COND_B)
47220     if (SDValue Flags = combineCarryThroughADD(EFLAGS, DAG))
47221       return Flags;
47222 
47223   if (SDValue R = checkSignTestSetCCCombine(EFLAGS, CC, DAG))
47224     return R;
47225 
47226   if (SDValue R = checkBoolTestSetCCCombine(EFLAGS, CC))
47227     return R;
47228 
47229   if (SDValue R = combinePTESTCC(EFLAGS, CC, DAG, Subtarget))
47230     return R;
47231 
47232   if (SDValue R = combineSetCCMOVMSK(EFLAGS, CC, DAG, Subtarget))
47233     return R;
47234 
47235   return combineSetCCAtomicArith(EFLAGS, CC, DAG, Subtarget);
47236 }
47237 
47238 /// Optimize X86ISD::CMOV [LHS, RHS, CONDCODE (e.g. X86::COND_NE), CONDVAL]
47239 static SDValue combineCMov(SDNode *N, SelectionDAG &DAG,
47240                            TargetLowering::DAGCombinerInfo &DCI,
47241                            const X86Subtarget &Subtarget) {
47242   SDLoc DL(N);
47243 
47244   SDValue FalseOp = N->getOperand(0);
47245   SDValue TrueOp = N->getOperand(1);
47246   X86::CondCode CC = (X86::CondCode)N->getConstantOperandVal(2);
47247   SDValue Cond = N->getOperand(3);
47248 
47249   // cmov X, X, ?, ? --> X
47250   if (TrueOp == FalseOp)
47251     return TrueOp;
47252 
47253   // Try to simplify the EFLAGS and condition code operands.
47254   // We can't always do this as FCMOV only supports a subset of X86 cond.
47255   if (SDValue Flags = combineSetCCEFLAGS(Cond, CC, DAG, Subtarget)) {
47256     if (!(FalseOp.getValueType() == MVT::f80 ||
47257           (FalseOp.getValueType() == MVT::f64 && !Subtarget.hasSSE2()) ||
47258           (FalseOp.getValueType() == MVT::f32 && !Subtarget.hasSSE1())) ||
47259         !Subtarget.canUseCMOV() || hasFPCMov(CC)) {
47260       SDValue Ops[] = {FalseOp, TrueOp, DAG.getTargetConstant(CC, DL, MVT::i8),
47261                        Flags};
47262       return DAG.getNode(X86ISD::CMOV, DL, N->getValueType(0), Ops);
47263     }
47264   }
47265 
47266   // If this is a select between two integer constants, try to do some
47267   // optimizations.  Note that the operands are ordered the opposite of SELECT
47268   // operands.
47269   if (ConstantSDNode *TrueC = dyn_cast<ConstantSDNode>(TrueOp)) {
47270     if (ConstantSDNode *FalseC = dyn_cast<ConstantSDNode>(FalseOp)) {
47271       // Canonicalize the TrueC/FalseC values so that TrueC (the true value) is
47272       // larger than FalseC (the false value).
47273       if (TrueC->getAPIntValue().ult(FalseC->getAPIntValue())) {
47274         CC = X86::GetOppositeBranchCondition(CC);
47275         std::swap(TrueC, FalseC);
47276         std::swap(TrueOp, FalseOp);
47277       }
47278 
47279       // Optimize C ? 8 : 0 -> zext(setcc(C)) << 3.  Likewise for any pow2/0.
47280       // This is efficient for any integer data type (including i8/i16) and
47281       // shift amount.
47282       if (FalseC->getAPIntValue() == 0 && TrueC->getAPIntValue().isPowerOf2()) {
47283         Cond = getSETCC(CC, Cond, DL, DAG);
47284 
47285         // Zero extend the condition if needed.
47286         Cond = DAG.getNode(ISD::ZERO_EXTEND, DL, TrueC->getValueType(0), Cond);
47287 
47288         unsigned ShAmt = TrueC->getAPIntValue().logBase2();
47289         Cond = DAG.getNode(ISD::SHL, DL, Cond.getValueType(), Cond,
47290                            DAG.getConstant(ShAmt, DL, MVT::i8));
47291         return Cond;
47292       }
47293 
47294       // Optimize Cond ? cst+1 : cst -> zext(setcc(C))+cst.  This is efficient
47295       // for any integer data type, including i8/i16.
47296       if (FalseC->getAPIntValue()+1 == TrueC->getAPIntValue()) {
47297         Cond = getSETCC(CC, Cond, DL, DAG);
47298 
47299         // Zero extend the condition if needed.
47300         Cond = DAG.getNode(ISD::ZERO_EXTEND, DL,
47301                            FalseC->getValueType(0), Cond);
47302         Cond = DAG.getNode(ISD::ADD, DL, Cond.getValueType(), Cond,
47303                            SDValue(FalseC, 0));
47304         return Cond;
47305       }
47306 
47307       // Optimize cases that will turn into an LEA instruction.  This requires
47308       // an i32 or i64 and an efficient multiplier (1, 2, 3, 4, 5, 8, 9).
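      // E.g. "C ? 8 : 3" has Diff = 5, so it becomes zext(setcc(C)) scaled by 5
      // plus 3, which matches the "lea base(cond, cond*4)" case below.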
47309       if (N->getValueType(0) == MVT::i32 || N->getValueType(0) == MVT::i64) {
47310         APInt Diff = TrueC->getAPIntValue() - FalseC->getAPIntValue();
47311         assert(Diff.getBitWidth() == N->getValueType(0).getSizeInBits() &&
47312                "Implicit constant truncation");
47313 
47314         bool isFastMultiplier = false;
47315         if (Diff.ult(10)) {
47316           switch (Diff.getZExtValue()) {
47317           default: break;
47318           case 1:  // result = add base, cond
47319           case 2:  // result = lea base(    , cond*2)
47320           case 3:  // result = lea base(cond, cond*2)
47321           case 4:  // result = lea base(    , cond*4)
47322           case 5:  // result = lea base(cond, cond*4)
47323           case 8:  // result = lea base(    , cond*8)
47324           case 9:  // result = lea base(cond, cond*8)
47325             isFastMultiplier = true;
47326             break;
47327           }
47328         }
47329 
47330         if (isFastMultiplier) {
47331           Cond = getSETCC(CC, Cond, DL ,DAG);
47332           // Zero extend the condition if needed.
47333           Cond = DAG.getNode(ISD::ZERO_EXTEND, DL, FalseC->getValueType(0),
47334                              Cond);
47335           // Scale the condition by the difference.
47336           if (Diff != 1)
47337             Cond = DAG.getNode(ISD::MUL, DL, Cond.getValueType(), Cond,
47338                                DAG.getConstant(Diff, DL, Cond.getValueType()));
47339 
47340           // Add the base if non-zero.
47341           if (FalseC->getAPIntValue() != 0)
47342             Cond = DAG.getNode(ISD::ADD, DL, Cond.getValueType(), Cond,
47343                                SDValue(FalseC, 0));
47344           return Cond;
47345         }
47346       }
47347     }
47348   }
47349 
47350   // Handle these cases:
47351   //   (select (x != c), e, c) -> (select (x != c), e, x),
47352   //   (select (x == c), c, e) -> (select (x == c), x, e)
47353   // where the c is an integer constant, and the "select" is the combination
47354   // of CMOV and CMP.
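  // E.g. for "y = (x == 7) ? 7 : z", the CMOV can copy x itself (already in a
  // register) on the equal path instead of materializing the constant 7.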
47355   //
47356   // The rationale for this change is that the conditional-move from a constant
47357   // needs two instructions, however, conditional-move from a register needs
47358   // only one instruction.
47359   //
47360   // CAVEAT: By replacing a constant with a symbolic value, it may obscure
47361   //  some instruction-combining opportunities. This opt needs to be
47362   //  postponed as late as possible.
47363   //
47364   if (!DCI.isBeforeLegalize() && !DCI.isBeforeLegalizeOps()) {
47365     // the DCI.xxxx conditions are provided to postpone the optimization as
47366     // late as possible.
47367 
47368     ConstantSDNode *CmpAgainst = nullptr;
47369     if ((Cond.getOpcode() == X86ISD::CMP || Cond.getOpcode() == X86ISD::SUB) &&
47370         (CmpAgainst = dyn_cast<ConstantSDNode>(Cond.getOperand(1))) &&
47371         !isa<ConstantSDNode>(Cond.getOperand(0))) {
47372 
47373       if (CC == X86::COND_NE &&
47374           CmpAgainst == dyn_cast<ConstantSDNode>(FalseOp)) {
47375         CC = X86::GetOppositeBranchCondition(CC);
47376         std::swap(TrueOp, FalseOp);
47377       }
47378 
47379       if (CC == X86::COND_E &&
47380           CmpAgainst == dyn_cast<ConstantSDNode>(TrueOp)) {
47381         SDValue Ops[] = {FalseOp, Cond.getOperand(0),
47382                          DAG.getTargetConstant(CC, DL, MVT::i8), Cond};
47383         return DAG.getNode(X86ISD::CMOV, DL, N->getValueType(0), Ops);
47384       }
47385     }
47386   }
47387 
47388   // Transform:
47389   //
47390   //   (cmov 1 T (uge T 2))
47391   //
47392   // to:
47393   //
47394   //   (adc T 0 (sub T 1))
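  // This is valid because "T >= 2 ? T : 1" is umax(T, 1) = T + (T == 0), and
  // "sub T, 1" sets the carry flag exactly when T == 0.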
47395   if (CC == X86::COND_AE && isOneConstant(FalseOp) &&
47396       Cond.getOpcode() == X86ISD::SUB && Cond->hasOneUse()) {
47397     SDValue Cond0 = Cond.getOperand(0);
47398     if (Cond0.getOpcode() == ISD::TRUNCATE)
47399       Cond0 = Cond0.getOperand(0);
47400     auto *Sub1C = dyn_cast<ConstantSDNode>(Cond.getOperand(1));
47401     if (Cond0 == TrueOp && Sub1C && Sub1C->getZExtValue() == 2) {
47402       EVT CondVT = Cond->getValueType(0);
47403       EVT OuterVT = N->getValueType(0);
47404       // Subtract 1 and generate a carry.
47405       SDValue NewSub =
47406           DAG.getNode(X86ISD::SUB, DL, Cond->getVTList(), Cond.getOperand(0),
47407                       DAG.getConstant(1, DL, CondVT));
47408       SDValue EFLAGS(NewSub.getNode(), 1);
47409       return DAG.getNode(X86ISD::ADC, DL, DAG.getVTList(OuterVT, MVT::i32),
47410                          TrueOp, DAG.getConstant(0, DL, OuterVT), EFLAGS);
47411     }
47412   }
47413 
47414   // Fold and/or of setcc's to double CMOV:
47415   //   (CMOV F, T, ((cc1 | cc2) != 0)) -> (CMOV (CMOV F, T, cc1), T, cc2)
47416   //   (CMOV F, T, ((cc1 & cc2) != 0)) -> (CMOV (CMOV T, F, !cc1), F, !cc2)
47417   //
47418   // This combine lets us generate:
47419   //   cmovcc1 (jcc1 if we don't have CMOV)
47420   //   cmovcc2 (same)
47421   // instead of:
47422   //   setcc1
47423   //   setcc2
47424   //   and/or
47425   //   cmovne (jne if we don't have CMOV)
47426   // When we can't use the CMOV instruction, it might increase branch
47427   // mispredicts.
47428   // When we can use CMOV, or when there is no mispredict, this improves
47429   // throughput and reduces register pressure.
47430   //
47431   if (CC == X86::COND_NE) {
47432     SDValue Flags;
47433     X86::CondCode CC0, CC1;
47434     bool isAndSetCC;
47435     if (checkBoolTestAndOrSetCCCombine(Cond, CC0, CC1, Flags, isAndSetCC)) {
47436       if (isAndSetCC) {
47437         std::swap(FalseOp, TrueOp);
47438         CC0 = X86::GetOppositeBranchCondition(CC0);
47439         CC1 = X86::GetOppositeBranchCondition(CC1);
47440       }
47441 
47442       SDValue LOps[] = {FalseOp, TrueOp,
47443                         DAG.getTargetConstant(CC0, DL, MVT::i8), Flags};
47444       SDValue LCMOV = DAG.getNode(X86ISD::CMOV, DL, N->getValueType(0), LOps);
47445       SDValue Ops[] = {LCMOV, TrueOp, DAG.getTargetConstant(CC1, DL, MVT::i8),
47446                        Flags};
47447       SDValue CMOV = DAG.getNode(X86ISD::CMOV, DL, N->getValueType(0), Ops);
47448       return CMOV;
47449     }
47450   }
47451 
47452   // Fold (CMOV C1, (ADD (CTTZ X), C2), (X != 0)) ->
47453   //      (ADD (CMOV C1-C2, (CTTZ X), (X != 0)), C2)
47454   // Or (CMOV (ADD (CTTZ X), C2), C1, (X == 0)) ->
47455   //    (ADD (CMOV (CTTZ X), C1-C2, (X == 0)), C2)
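  // E.g. the ffs idiom "X ? cttz(X) + 1 : 0" becomes cmov(-1, cttz(X)) + 1,
  // letting the BSF/TZCNT result feed the CMOV directly.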
47456   if ((CC == X86::COND_NE || CC == X86::COND_E) &&
47457       Cond.getOpcode() == X86ISD::CMP && isNullConstant(Cond.getOperand(1))) {
47458     SDValue Add = TrueOp;
47459     SDValue Const = FalseOp;
47460     // Canonicalize the condition code for easier matching and output.
47461     if (CC == X86::COND_E)
47462       std::swap(Add, Const);
47463 
47464     // We might have replaced the constant in the cmov with the LHS of the
47465     // compare. If so change it to the RHS of the compare.
47466     if (Const == Cond.getOperand(0))
47467       Const = Cond.getOperand(1);
47468 
47469     // Ok, now make sure that Add is (add (cttz X), C2) and Const is a constant.
47470     if (isa<ConstantSDNode>(Const) && Add.getOpcode() == ISD::ADD &&
47471         Add.hasOneUse() && isa<ConstantSDNode>(Add.getOperand(1)) &&
47472         (Add.getOperand(0).getOpcode() == ISD::CTTZ_ZERO_UNDEF ||
47473          Add.getOperand(0).getOpcode() == ISD::CTTZ) &&
47474         Add.getOperand(0).getOperand(0) == Cond.getOperand(0)) {
47475       EVT VT = N->getValueType(0);
47476       // This should constant fold.
47477       SDValue Diff = DAG.getNode(ISD::SUB, DL, VT, Const, Add.getOperand(1));
47478       SDValue CMov =
47479           DAG.getNode(X86ISD::CMOV, DL, VT, Diff, Add.getOperand(0),
47480                       DAG.getTargetConstant(X86::COND_NE, DL, MVT::i8), Cond);
47481       return DAG.getNode(ISD::ADD, DL, VT, CMov, Add.getOperand(1));
47482     }
47483   }
47484 
47485   return SDValue();
47486 }
47487 
47488 /// Different mul shrinking modes.
47489 enum class ShrinkMode { MULS8, MULU8, MULS16, MULU16 };
47490 
47491 static bool canReduceVMulWidth(SDNode *N, SelectionDAG &DAG, ShrinkMode &Mode) {
47492   EVT VT = N->getOperand(0).getValueType();
47493   if (VT.getScalarSizeInBits() != 32)
47494     return false;
47495 
47496   assert(N->getNumOperands() == 2 && "NumOperands of Mul are 2");
47497   unsigned SignBits[2] = {1, 1};
47498   bool IsPositive[2] = {false, false};
47499   for (unsigned i = 0; i < 2; i++) {
47500     SDValue Opd = N->getOperand(i);
47501 
47502     SignBits[i] = DAG.ComputeNumSignBits(Opd);
47503     IsPositive[i] = DAG.SignBitIsZero(Opd);
47504   }
47505 
47506   bool AllPositive = IsPositive[0] && IsPositive[1];
47507   unsigned MinSignBits = std::min(SignBits[0], SignBits[1]);
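  // N sign bits in an i32 value means it fits in (33 - N) signed bits: 25 sign
  // bits give the i8 range and 17 the i16 range, while the unsigned modes need
  // one sign bit fewer plus a known-zero sign bit.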
47508   // When ranges are from -128 ~ 127, use MULS8 mode.
47509   if (MinSignBits >= 25)
47510     Mode = ShrinkMode::MULS8;
47511   // When ranges are from 0 ~ 255, use MULU8 mode.
47512   else if (AllPositive && MinSignBits >= 24)
47513     Mode = ShrinkMode::MULU8;
47514   // When ranges are from -32768 ~ 32767, use MULS16 mode.
47515   else if (MinSignBits >= 17)
47516     Mode = ShrinkMode::MULS16;
47517   // When ranges are from 0 ~ 65535, use MULU16 mode.
47518   else if (AllPositive && MinSignBits >= 16)
47519     Mode = ShrinkMode::MULU16;
47520   else
47521     return false;
47522   return true;
47523 }
47524 
47525 /// When the operands of a vector mul are extended from smaller-sized values,
47526 /// like i8 and i16, the type of the mul may be shrunk to generate more
47527 /// efficient code. Two typical patterns are handled:
47528 /// Pattern1:
47529 ///     %2 = sext/zext <N x i8> %1 to <N x i32>
47530 ///     %4 = sext/zext <N x i8> %3 to <N x i32>
47531 ///  or %4 = build_vector <N x i32> %C1, ..., %CN (%C1..%CN are constants)
47532 ///     %5 = mul <N x i32> %2, %4
47533 ///
47534 /// Pattern2:
47535 ///     %2 = zext/sext <N x i16> %1 to <N x i32>
47536 ///     %4 = zext/sext <N x i16> %3 to <N x i32>
47537 ///  or %4 = build_vector <N x i32> %C1, ..., %CN (%C1..%CN are constants)
47538 ///     %5 = mul <N x i32> %2, %4
47539 ///
47540 /// There are four mul shrinking modes:
47541 /// If %2 == sext32(trunc8(%2)), i.e., the scalar value range of %2 is
47542 /// -128 to 127, and the scalar value range of %4 is also -128 to 127,
47543 /// generate pmullw+sext32 for it (MULS8 mode).
47544 /// If %2 == zext32(trunc8(%2)), i.e., the scalar value range of %2 is
47545 /// 0 to 255, and the scalar value range of %4 is also 0 to 255,
47546 /// generate pmullw+zext32 for it (MULU8 mode).
47547 /// If %2 == sext32(trunc16(%2)), i.e., the scalar value range of %2 is
47548 /// -32768 to 32767, and the scalar value range of %4 is also -32768 to 32767,
47549 /// generate pmullw+pmulhw for it (MULS16 mode).
47550 /// If %2 == zext32(trunc16(%2)), i.e., the scalar value range of %2 is
47551 /// 0 to 65535, and the scalar value range of %4 is also 0 to 65535,
47552 /// generate pmullw+pmulhuw for it (MULU16 mode).
47553 static SDValue reduceVMULWidth(SDNode *N, const SDLoc &DL, SelectionDAG &DAG,
47554                                const X86Subtarget &Subtarget) {
47555   // Check for legality
47556   // pmullw/pmulhw are not supported by SSE.
47557   if (!Subtarget.hasSSE2())
47558     return SDValue();
47559 
47560   // Check for profitability
47561   // pmulld is supported since SSE41. It is better to use pmulld
47562   // instead of pmullw+pmulhw, except for subtargets where pmulld is slower than
47563   // the expansion.
47564   bool OptForMinSize = DAG.getMachineFunction().getFunction().hasMinSize();
47565   if (Subtarget.hasSSE41() && (OptForMinSize || !Subtarget.isPMULLDSlow()))
47566     return SDValue();
47567 
47568   ShrinkMode Mode;
47569   if (!canReduceVMulWidth(N, DAG, Mode))
47570     return SDValue();
47571 
47572   SDValue N0 = N->getOperand(0);
47573   SDValue N1 = N->getOperand(1);
47574   EVT VT = N->getOperand(0).getValueType();
47575   unsigned NumElts = VT.getVectorNumElements();
47576   if ((NumElts % 2) != 0)
47577     return SDValue();
47578 
47579   EVT ReducedVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16, NumElts);
47580 
47581   // Shrink the operands of mul.
47582   SDValue NewN0 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, N0);
47583   SDValue NewN1 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, N1);
47584 
47585   // Generate the lower part of mul: pmullw. For MULU8/MULS8, only the
47586   // lower part is needed.
47587   SDValue MulLo = DAG.getNode(ISD::MUL, DL, ReducedVT, NewN0, NewN1);
47588   if (Mode == ShrinkMode::MULU8 || Mode == ShrinkMode::MULS8)
47589     return DAG.getNode((Mode == ShrinkMode::MULU8) ? ISD::ZERO_EXTEND
47590                                                    : ISD::SIGN_EXTEND,
47591                        DL, VT, MulLo);
47592 
47593   EVT ResVT = EVT::getVectorVT(*DAG.getContext(), MVT::i32, NumElts / 2);
47594   // Generate the higher part of mul: pmulhw/pmulhuw. For MULU16/MULS16,
47595   // the higher part is also needed.
47596   SDValue MulHi =
47597       DAG.getNode(Mode == ShrinkMode::MULS16 ? ISD::MULHS : ISD::MULHU, DL,
47598                   ReducedVT, NewN0, NewN1);
47599 
47600   // Repack the lower part and higher part result of mul into a wider
47601   // result.
47602   // Generate shuffle functioning as punpcklwd.
47603   SmallVector<int, 16> ShuffleMask(NumElts);
47604   for (unsigned i = 0, e = NumElts / 2; i < e; i++) {
47605     ShuffleMask[2 * i] = i;
47606     ShuffleMask[2 * i + 1] = i + NumElts;
47607   }
47608   SDValue ResLo =
47609       DAG.getVectorShuffle(ReducedVT, DL, MulLo, MulHi, ShuffleMask);
47610   ResLo = DAG.getBitcast(ResVT, ResLo);
47611   // Generate shuffle functioning as punpckhwd.
47612   for (unsigned i = 0, e = NumElts / 2; i < e; i++) {
47613     ShuffleMask[2 * i] = i + NumElts / 2;
47614     ShuffleMask[2 * i + 1] = i + NumElts * 3 / 2;
47615   }
47616   SDValue ResHi =
47617       DAG.getVectorShuffle(ReducedVT, DL, MulLo, MulHi, ShuffleMask);
47618   ResHi = DAG.getBitcast(ResVT, ResHi);
47619   return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ResLo, ResHi);
47620 }
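// Illustrative sketch of the MULU16/MULS16 repack above for a v4i32 multiply,
// in IR-like pseudocode (%a16/%b16 stand for the truncated operands):
//   %lo = mul   <4 x i16> %a16, %b16   ; pmullw -> low 16 bits of each product
//   %hi = mulh* <4 x i16> %a16, %b16   ; pmulhw/pmulhuw -> high 16 bits
//   punpcklwd mask {0,4,1,5}: (lo0,hi0,lo1,hi1) -> bitcast v2i32 = products 0,1
//   punpckhwd mask {2,6,3,7}: (lo2,hi2,lo3,hi3) -> bitcast v2i32 = products 2,3
//   concat_vectors -> <4 x i32> of full 32-bit products
// On little-endian x86 the (lo,hi) i16 pair bitcasts to the i32 value
// hi * 65536 + lo, which is exactly the widened product.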
47621 
47622 static SDValue combineMulSpecial(uint64_t MulAmt, SDNode *N, SelectionDAG &DAG,
47623                                  EVT VT, const SDLoc &DL) {
47624 
47625   auto combineMulShlAddOrSub = [&](int Mult, int Shift, bool isAdd) {
47626     SDValue Result = DAG.getNode(X86ISD::MUL_IMM, DL, VT, N->getOperand(0),
47627                                  DAG.getConstant(Mult, DL, VT));
47628     Result = DAG.getNode(ISD::SHL, DL, VT, Result,
47629                          DAG.getConstant(Shift, DL, MVT::i8));
47630     Result = DAG.getNode(isAdd ? ISD::ADD : ISD::SUB, DL, VT, Result,
47631                          N->getOperand(0));
47632     return Result;
47633   };
47634 
47635   auto combineMulMulAddOrSub = [&](int Mul1, int Mul2, bool isAdd) {
47636     SDValue Result = DAG.getNode(X86ISD::MUL_IMM, DL, VT, N->getOperand(0),
47637                                  DAG.getConstant(Mul1, DL, VT));
47638     Result = DAG.getNode(X86ISD::MUL_IMM, DL, VT, Result,
47639                          DAG.getConstant(Mul2, DL, VT));
47640     Result = DAG.getNode(isAdd ? ISD::ADD : ISD::SUB, DL, VT, Result,
47641                          N->getOperand(0));
47642     return Result;
47643   };
47644 
47645   switch (MulAmt) {
47646   default:
47647     break;
47648   case 11:
47649     // mul x, 11 => add ((shl (mul x, 5), 1), x)
47650     return combineMulShlAddOrSub(5, 1, /*isAdd*/ true);
47651   case 21:
47652     // mul x, 21 => add ((shl (mul x, 5), 2), x)
47653     return combineMulShlAddOrSub(5, 2, /*isAdd*/ true);
47654   case 41:
47655     // mul x, 41 => add ((shl (mul x, 5), 3), x)
47656     return combineMulShlAddOrSub(5, 3, /*isAdd*/ true);
47657   case 22:
47658     // mul x, 22 => add (add ((shl (mul x, 5), 2), x), x)
47659     return DAG.getNode(ISD::ADD, DL, VT, N->getOperand(0),
47660                        combineMulShlAddOrSub(5, 2, /*isAdd*/ true));
47661   case 19:
47662     // mul x, 19 => add ((shl (mul x, 9), 1), x)
47663     return combineMulShlAddOrSub(9, 1, /*isAdd*/ true);
47664   case 37:
47665     // mul x, 37 => add ((shl (mul x, 9), 2), x)
47666     return combineMulShlAddOrSub(9, 2, /*isAdd*/ true);
47667   case 73:
47668     // mul x, 73 => add ((shl (mul x, 9), 3), x)
47669     return combineMulShlAddOrSub(9, 3, /*isAdd*/ true);
47670   case 13:
47671     // mul x, 13 => add ((shl (mul x, 3), 2), x)
47672     return combineMulShlAddOrSub(3, 2, /*isAdd*/ true);
47673   case 23:
47674     // mul x, 23 => sub ((shl (mul x, 3), 3), x)
47675     return combineMulShlAddOrSub(3, 3, /*isAdd*/ false);
47676   case 26:
47677     // mul x, 26 => add ((mul (mul x, 5), 5), x)
47678     return combineMulMulAddOrSub(5, 5, /*isAdd*/ true);
47679   case 28:
47680     // mul x, 28 => add ((mul (mul x, 9), 3), x)
47681     return combineMulMulAddOrSub(9, 3, /*isAdd*/ true);
47682   case 29:
47683     // mul x, 29 => add (add ((mul (mul x, 9), 3), x), x)
47684     return DAG.getNode(ISD::ADD, DL, VT, N->getOperand(0),
47685                        combineMulMulAddOrSub(9, 3, /*isAdd*/ true));
47686   }
47687 
47688   // Another trick. If this is a power of 2 + 2/4/8, we can use a shift followed
47689   // by a single LEA.
47690   // First check if this is a sum of two powers of 2 because that's easy. Then
47691   // count the trailing zeros up to the first set bit.
47692   // TODO: We can do this even without LEA at a cost of two shifts and an add.
47693   if (isPowerOf2_64(MulAmt & (MulAmt - 1))) {
47694     unsigned ScaleShift = llvm::countr_zero(MulAmt);
47695     if (ScaleShift >= 1 && ScaleShift < 4) {
47696       unsigned ShiftAmt = Log2_64((MulAmt & (MulAmt - 1)));
47697       SDValue Shift1 = DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0),
47698                                    DAG.getConstant(ShiftAmt, DL, MVT::i8));
47699       SDValue Shift2 = DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0),
47700                                    DAG.getConstant(ScaleShift, DL, MVT::i8));
47701       return DAG.getNode(ISD::ADD, DL, VT, Shift1, Shift2);
47702     }
47703   }
47704 
47705   return SDValue();
47706 }
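// A few of the decompositions above, written out as plain arithmetic for
// illustration:
//   mul x, 11 -> ((x * 5) << 1) + x = 10x + x = 11x
//   mul x, 23 -> ((x * 3) << 3) - x = 24x - x = 23x
//   mul x, 26 -> ((x * 5) * 5) + x  = 25x + x = 26x
// The "power of 2 + 2/4/8" trick, e.g. MulAmt = 20 = 16 + 4:
//   (x << 4) + (x << 2), where the (x << 2) term can be folded into an LEA as
//   an index scaled by 4, leaving one shift plus one LEA.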
47707 
47708 // If the upper 17 bits of one operand are zero and the upper 17 bits of the
47709 // other operand are all sign/zero bits, then we can use PMADDWD, which is
47710 // always at least as quick as PMULLD, except on KNL.
47711 static SDValue combineMulToPMADDWD(SDNode *N, const SDLoc &DL,
47712                                    SelectionDAG &DAG,
47713                                    const X86Subtarget &Subtarget) {
47714   if (!Subtarget.hasSSE2())
47715     return SDValue();
47716 
47717   if (Subtarget.isPMADDWDSlow())
47718     return SDValue();
47719 
47720   EVT VT = N->getValueType(0);
47721 
47722   // Only support vXi32 vectors.
47723   if (!VT.isVector() || VT.getVectorElementType() != MVT::i32)
47724     return SDValue();
47725 
47726   // Make sure the type is legal or can split/widen to a legal type.
47728   unsigned NumElts = VT.getVectorNumElements();
47729   if (NumElts == 1 || !isPowerOf2_32(NumElts))
47730     return SDValue();
47731 
47732   // With AVX512 but without BWI, we would need to split v32i16.
47733   if (32 <= (2 * NumElts) && Subtarget.hasAVX512() && !Subtarget.hasBWI())
47734     return SDValue();
47735 
47736   SDValue N0 = N->getOperand(0);
47737   SDValue N1 = N->getOperand(1);
47738 
47739   // If we are zero/sign extending two steps without SSE4.1, it's better to
47740   // reduce the vmul width instead.
47741   if (!Subtarget.hasSSE41() &&
47742       (((N0.getOpcode() == ISD::ZERO_EXTEND &&
47743          N0.getOperand(0).getScalarValueSizeInBits() <= 8) &&
47744         (N1.getOpcode() == ISD::ZERO_EXTEND &&
47745          N1.getOperand(0).getScalarValueSizeInBits() <= 8)) ||
47746        ((N0.getOpcode() == ISD::SIGN_EXTEND &&
47747          N0.getOperand(0).getScalarValueSizeInBits() <= 8) &&
47748         (N1.getOpcode() == ISD::SIGN_EXTEND &&
47749          N1.getOperand(0).getScalarValueSizeInBits() <= 8))))
47750     return SDValue();
47751 
47752   // If we are sign extending a wide vector without SSE4.1, it's better to reduce
47753   // the vmul width instead.
47754   if (!Subtarget.hasSSE41() &&
47755       (N0.getOpcode() == ISD::SIGN_EXTEND &&
47756        N0.getOperand(0).getValueSizeInBits() > 128) &&
47757       (N1.getOpcode() == ISD::SIGN_EXTEND &&
47758        N1.getOperand(0).getValueSizeInBits() > 128))
47759     return SDValue();
47760 
47761   // Sign bits must extend down to the lowest i16.
47762   if (DAG.ComputeMaxSignificantBits(N1) > 16 ||
47763       DAG.ComputeMaxSignificantBits(N0) > 16)
47764     return SDValue();
47765 
47766   // At least one of the elements must be zero in the upper 17 bits, or can be
47767   // safely made zero without altering the final result.
47768   auto GetZeroableOp = [&](SDValue Op) {
47769     APInt Mask17 = APInt::getHighBitsSet(32, 17);
47770     if (DAG.MaskedValueIsZero(Op, Mask17))
47771       return Op;
47772     // Mask off upper 16-bits of sign-extended constants.
47773     if (ISD::isBuildVectorOfConstantSDNodes(Op.getNode()))
47774       return DAG.getNode(ISD::AND, DL, VT, Op, DAG.getConstant(0xFFFF, DL, VT));
47775     if (Op.getOpcode() == ISD::SIGN_EXTEND && N->isOnlyUserOf(Op.getNode())) {
47776       SDValue Src = Op.getOperand(0);
47777       // Convert sext(vXi16) to zext(vXi16).
47778       if (Src.getScalarValueSizeInBits() == 16 && VT.getSizeInBits() <= 128)
47779         return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Src);
47780       // Convert sext(vXi8) to zext(vXi16 sext(vXi8)) on pre-SSE41 targets
47781       // which will expand the extension.
47782       if (Src.getScalarValueSizeInBits() < 16 && !Subtarget.hasSSE41()) {
47783         EVT ExtVT = VT.changeVectorElementType(MVT::i16);
47784         Src = DAG.getNode(ISD::SIGN_EXTEND, DL, ExtVT, Src);
47785         return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Src);
47786       }
47787     }
47788     // Convert SIGN_EXTEND_VECTOR_INREG to ZERO_EXTEND_VECTOR_INREG.
47789     if (Op.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG &&
47790         N->isOnlyUserOf(Op.getNode())) {
47791       SDValue Src = Op.getOperand(0);
47792       if (Src.getScalarValueSizeInBits() == 16)
47793         return DAG.getNode(ISD::ZERO_EXTEND_VECTOR_INREG, DL, VT, Src);
47794     }
47795     // Convert VSRAI(Op, 16) to VSRLI(Op, 16).
47796     if (Op.getOpcode() == X86ISD::VSRAI && Op.getConstantOperandVal(1) == 16 &&
47797         N->isOnlyUserOf(Op.getNode())) {
47798       return DAG.getNode(X86ISD::VSRLI, DL, VT, Op.getOperand(0),
47799                          Op.getOperand(1));
47800     }
47801     return SDValue();
47802   };
47803   SDValue ZeroN0 = GetZeroableOp(N0);
47804   SDValue ZeroN1 = GetZeroableOp(N1);
47805   if (!ZeroN0 && !ZeroN1)
47806     return SDValue();
47807   N0 = ZeroN0 ? ZeroN0 : N0;
47808   N1 = ZeroN1 ? ZeroN1 : N1;
47809 
47810   // Use SplitOpsAndApply to handle AVX splitting.
47811   auto PMADDWDBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
47812                            ArrayRef<SDValue> Ops) {
47813     MVT ResVT = MVT::getVectorVT(MVT::i32, Ops[0].getValueSizeInBits() / 32);
47814     MVT OpVT = MVT::getVectorVT(MVT::i16, Ops[0].getValueSizeInBits() / 16);
47815     return DAG.getNode(X86ISD::VPMADDWD, DL, ResVT,
47816                        DAG.getBitcast(OpVT, Ops[0]),
47817                        DAG.getBitcast(OpVT, Ops[1]));
47818   };
47819   return SplitOpsAndApply(DAG, Subtarget, DL, VT, {N0, N1}, PMADDWDBuilder);
47820 }
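// Worked example (values chosen for illustration) of why PMADDWD can stand in
// for the 32-bit multiply above. Per i32 lane, VPMADDWD computes
//   lo16(a)*lo16(b) + hi16(a)*hi16(b), with all halves treated as signed i16.
// Take a = 7 (upper 17 bits zero) and b = -3 (sign-extended, so hi16(b) = -1):
//   7 * (-3) + 0 * (-1) = -21 = a * b
// i.e. once one operand has its upper 17 bits clear the cross term vanishes
// and the pairwise multiply-add collapses to the desired product.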
47821 
47822 static SDValue combineMulToPMULDQ(SDNode *N, const SDLoc &DL, SelectionDAG &DAG,
47823                                   const X86Subtarget &Subtarget) {
47824   if (!Subtarget.hasSSE2())
47825     return SDValue();
47826 
47827   EVT VT = N->getValueType(0);
47828 
47829   // Only support vXi64 vectors.
47830   if (!VT.isVector() || VT.getVectorElementType() != MVT::i64 ||
47831       VT.getVectorNumElements() < 2 ||
47832       !isPowerOf2_32(VT.getVectorNumElements()))
47833     return SDValue();
47834 
47835   SDValue N0 = N->getOperand(0);
47836   SDValue N1 = N->getOperand(1);
47837 
47838   // PMULDQ returns the 64-bit result of the signed multiplication of the lower
47839   // 32 bits. We can lower with this if the sign bits stretch that far.
47840   if (Subtarget.hasSSE41() && DAG.ComputeNumSignBits(N0) > 32 &&
47841       DAG.ComputeNumSignBits(N1) > 32) {
47842     auto PMULDQBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
47843                             ArrayRef<SDValue> Ops) {
47844       return DAG.getNode(X86ISD::PMULDQ, DL, Ops[0].getValueType(), Ops);
47845     };
47846     return SplitOpsAndApply(DAG, Subtarget, DL, VT, {N0, N1}, PMULDQBuilder,
47847                             /*CheckBWI*/ false);
47848   }
47849 
47850   // If the upper bits are zero we can use a single pmuludq.
47851   APInt Mask = APInt::getHighBitsSet(64, 32);
47852   if (DAG.MaskedValueIsZero(N0, Mask) && DAG.MaskedValueIsZero(N1, Mask)) {
47853     auto PMULUDQBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
47854                              ArrayRef<SDValue> Ops) {
47855       return DAG.getNode(X86ISD::PMULUDQ, DL, Ops[0].getValueType(), Ops);
47856     };
47857     return SplitOpsAndApply(DAG, Subtarget, DL, VT, {N0, N1}, PMULUDQBuilder,
47858                             /*CheckBWI*/ false);
47859   }
47860 
47861   return SDValue();
47862 }
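// Small illustrative cases for the two folds above on a v2i64 multiply:
//   PMULDQ  (signed; needs > 32 sign bits per operand):
//     (-5) * 7 -> -35, computed from the low 32 bits of each lane but yielding
//     the full 64-bit product.
//   PMULUDQ (unsigned; needs the upper 32 bits to be zero):
//     0x00000000FFFFFFFF * 2 -> 0x00000001FFFFFFFE.
// Either form replaces a general packed 64-bit multiply, for which SSE/AVX2
// has no single instruction, with one 32x32->64 multiply per lane.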
47863 
47864 static SDValue combineMul(SDNode *N, SelectionDAG &DAG,
47865                           TargetLowering::DAGCombinerInfo &DCI,
47866                           const X86Subtarget &Subtarget) {
47867   EVT VT = N->getValueType(0);
47868   SDLoc DL(N);
47869 
47870   if (SDValue V = combineMulToPMADDWD(N, DL, DAG, Subtarget))
47871     return V;
47872 
47873   if (SDValue V = combineMulToPMULDQ(N, DL, DAG, Subtarget))
47874     return V;
47875 
47876   if (DCI.isBeforeLegalize() && VT.isVector())
47877     return reduceVMULWidth(N, DL, DAG, Subtarget);
47878 
47879   // Optimize a single multiply with constant into two operations in order to
47880   // implement it with two cheaper instructions, e.g. LEA + SHL, LEA + LEA.
47881   if (!MulConstantOptimization)
47882     return SDValue();
47883 
47884   // An imul is usually smaller than the alternative sequence.
47885   if (DAG.getMachineFunction().getFunction().hasMinSize())
47886     return SDValue();
47887 
47888   if (DCI.isBeforeLegalize() || DCI.isCalledByLegalizer())
47889     return SDValue();
47890 
47891   if (VT != MVT::i64 && VT != MVT::i32 &&
47892       (!VT.isVector() || !VT.isSimple() || !VT.isInteger()))
47893     return SDValue();
47894 
47895   ConstantSDNode *CNode = isConstOrConstSplat(
47896       N->getOperand(1), /*AllowUndefs*/ true, /*AllowTrunc*/ false);
47897   const APInt *C = nullptr;
47898   if (!CNode) {
47899     if (VT.isVector())
47900       if (auto *RawC = getTargetConstantFromNode(N->getOperand(1)))
47901         if (auto *SplatC = RawC->getSplatValue())
47902           if (auto *SplatCI = dyn_cast<ConstantInt>(SplatC))
47903             C = &(SplatCI->getValue());
47904 
47905     if (!C || C->getBitWidth() != VT.getScalarSizeInBits())
47906       return SDValue();
47907   } else {
47908     C = &(CNode->getAPIntValue());
47909   }
47910 
47911   if (isPowerOf2_64(C->getZExtValue()))
47912     return SDValue();
47913 
47914   int64_t SignMulAmt = C->getSExtValue();
47915   assert(SignMulAmt != INT64_MIN && "Int min should have been handled!");
47916   uint64_t AbsMulAmt = SignMulAmt < 0 ? -SignMulAmt : SignMulAmt;
47917 
47918   SDValue NewMul = SDValue();
47919   if (VT == MVT::i64 || VT == MVT::i32) {
47920     if (AbsMulAmt == 3 || AbsMulAmt == 5 || AbsMulAmt == 9) {
47921       NewMul = DAG.getNode(X86ISD::MUL_IMM, DL, VT, N->getOperand(0),
47922                            DAG.getConstant(AbsMulAmt, DL, VT));
47923       if (SignMulAmt < 0)
47924         NewMul = DAG.getNegative(NewMul, DL, VT);
47925 
47926       return NewMul;
47927     }
47928 
47929     uint64_t MulAmt1 = 0;
47930     uint64_t MulAmt2 = 0;
47931     if ((AbsMulAmt % 9) == 0) {
47932       MulAmt1 = 9;
47933       MulAmt2 = AbsMulAmt / 9;
47934     } else if ((AbsMulAmt % 5) == 0) {
47935       MulAmt1 = 5;
47936       MulAmt2 = AbsMulAmt / 5;
47937     } else if ((AbsMulAmt % 3) == 0) {
47938       MulAmt1 = 3;
47939       MulAmt2 = AbsMulAmt / 3;
47940     }
47941 
47942     // For negative multiply amounts, only allow MulAmt2 to be a power of 2.
47943     if (MulAmt2 &&
47944         (isPowerOf2_64(MulAmt2) ||
47945          (SignMulAmt >= 0 && (MulAmt2 == 3 || MulAmt2 == 5 || MulAmt2 == 9)))) {
47946 
47947       if (isPowerOf2_64(MulAmt2) && !(SignMulAmt >= 0 && N->hasOneUse() &&
47948                                       N->use_begin()->getOpcode() == ISD::ADD))
47949         // If the second multiplier is pow2, issue it first. We want the multiply
47950         // by 3, 5, or 9 to be folded into the addressing mode unless the lone
47951         // use is an add. Only do this for positive multiply amounts since the
47952         // negate would prevent it from being used as an address mode anyway.
47953         std::swap(MulAmt1, MulAmt2);
47954 
47955       if (isPowerOf2_64(MulAmt1))
47956         NewMul = DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0),
47957                              DAG.getConstant(Log2_64(MulAmt1), DL, MVT::i8));
47958       else
47959         NewMul = DAG.getNode(X86ISD::MUL_IMM, DL, VT, N->getOperand(0),
47960                              DAG.getConstant(MulAmt1, DL, VT));
47961 
47962       if (isPowerOf2_64(MulAmt2))
47963         NewMul = DAG.getNode(ISD::SHL, DL, VT, NewMul,
47964                              DAG.getConstant(Log2_64(MulAmt2), DL, MVT::i8));
47965       else
47966         NewMul = DAG.getNode(X86ISD::MUL_IMM, DL, VT, NewMul,
47967                              DAG.getConstant(MulAmt2, DL, VT));
47968 
47969       // Negate the result.
47970       if (SignMulAmt < 0)
47971         NewMul = DAG.getNegative(NewMul, DL, VT);
47972     } else if (!Subtarget.slowLEA())
47973       NewMul = combineMulSpecial(C->getZExtValue(), N, DAG, VT, DL);
47974   }
47975   if (!NewMul) {
47976     EVT ShiftVT = VT.isVector() ? VT : MVT::i8;
47977     assert(C->getZExtValue() != 0 &&
47978            C->getZExtValue() != maxUIntN(VT.getScalarSizeInBits()) &&
47979            "Both cases that could cause potential overflows should have "
47980            "already been handled.");
47981     if (isPowerOf2_64(AbsMulAmt - 1)) {
47982       // (mul x, 2^N + 1) => (add (shl x, N), x)
47983       NewMul = DAG.getNode(
47984           ISD::ADD, DL, VT, N->getOperand(0),
47985           DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0),
47986                       DAG.getConstant(Log2_64(AbsMulAmt - 1), DL, ShiftVT)));
47987       if (SignMulAmt < 0)
47988         NewMul = DAG.getNegative(NewMul, DL, VT);
47989     } else if (isPowerOf2_64(AbsMulAmt + 1)) {
47990       // (mul x, 2^N - 1) => (sub (shl x, N), x)
47991       NewMul =
47992           DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0),
47993                       DAG.getConstant(Log2_64(AbsMulAmt + 1), DL, ShiftVT));
47994       // To negate, reverse the operands of the subtract.
47995       if (SignMulAmt < 0)
47996         NewMul = DAG.getNode(ISD::SUB, DL, VT, N->getOperand(0), NewMul);
47997       else
47998         NewMul = DAG.getNode(ISD::SUB, DL, VT, NewMul, N->getOperand(0));
47999     } else if (SignMulAmt >= 0 && isPowerOf2_64(AbsMulAmt - 2) &&
48000                (!VT.isVector() || Subtarget.fastImmVectorShift())) {
48001       // (mul x, 2^N + 2) => (add (shl x, N), (add x, x))
48002       NewMul =
48003           DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0),
48004                       DAG.getConstant(Log2_64(AbsMulAmt - 2), DL, ShiftVT));
48005       NewMul = DAG.getNode(
48006           ISD::ADD, DL, VT, NewMul,
48007           DAG.getNode(ISD::ADD, DL, VT, N->getOperand(0), N->getOperand(0)));
48008     } else if (SignMulAmt >= 0 && isPowerOf2_64(AbsMulAmt + 2) &&
48009                (!VT.isVector() || Subtarget.fastImmVectorShift())) {
48010       // (mul x, 2^N - 2) => (sub (shl x, N), (add x, x))
48011       NewMul =
48012           DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0),
48013                       DAG.getConstant(Log2_64(AbsMulAmt + 2), DL, ShiftVT));
48014       NewMul = DAG.getNode(
48015           ISD::SUB, DL, VT, NewMul,
48016           DAG.getNode(ISD::ADD, DL, VT, N->getOperand(0), N->getOperand(0)));
48017     } else if (SignMulAmt >= 0 && VT.isVector() &&
48018                Subtarget.fastImmVectorShift()) {
48019       uint64_t AbsMulAmtLowBit = AbsMulAmt & (-AbsMulAmt);
48020       uint64_t ShiftAmt1;
48021       std::optional<unsigned> Opc;
48022       if (isPowerOf2_64(AbsMulAmt - AbsMulAmtLowBit)) {
48023         ShiftAmt1 = AbsMulAmt - AbsMulAmtLowBit;
48024         Opc = ISD::ADD;
48025       } else if (isPowerOf2_64(AbsMulAmt + AbsMulAmtLowBit)) {
48026         ShiftAmt1 = AbsMulAmt + AbsMulAmtLowBit;
48027         Opc = ISD::SUB;
48028       }
48029 
48030       if (Opc) {
48031         SDValue Shift1 =
48032             DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0),
48033                         DAG.getConstant(Log2_64(ShiftAmt1), DL, ShiftVT));
48034         SDValue Shift2 =
48035             DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0),
48036                         DAG.getConstant(Log2_64(AbsMulAmtLowBit), DL, ShiftVT));
48037         NewMul = DAG.getNode(*Opc, DL, VT, Shift1, Shift2);
48038       }
48039     }
48040   }
48041 
48042   return NewMul;
48043 }
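// Illustrative decompositions produced by combineMul above for scalar i32/i64
// constants (the exact choice also depends on the subtarget checks in the code):
//   mul x, 45 -> 45 = 9 * 5   -> two MUL_IMM (LEA-style) nodes
//   mul x, 40 -> 40 = 8 * 5   -> shl x, 3 then MUL_IMM 5 (the shift is issued
//                                first unless the lone use is an add)
//   mul x, 17 -> 17 = 16 + 1  -> (add (shl x, 4), x)
//   mul x, 30 -> 30 = 32 - 2  -> (sub (shl x, 5), (add x, x))
//   mul x, -9 -> MUL_IMM 9, then the result is negated.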
48044 
48045 // Try to form a MULHU or MULHS node by looking for
48046 // (srl (mul ext, ext), 16)
48047 // TODO: This is X86 specific because we want to be able to handle wide types
48048 // before type legalization. But we can only do it if the vector will be
48049 // legalized via widening/splitting. Type legalization can't handle promotion
48050 // of a MULHU/MULHS. There isn't a way to convey this to the generic DAG
48051 // combiner.
48052 static SDValue combineShiftToPMULH(SDNode *N, SelectionDAG &DAG,
48053                                    const SDLoc &DL,
48054                                    const X86Subtarget &Subtarget) {
48055   assert((N->getOpcode() == ISD::SRL || N->getOpcode() == ISD::SRA) &&
48056            "SRL or SRA node is required here!");
48057 
48058   if (!Subtarget.hasSSE2())
48059     return SDValue();
48060 
48061   // The operation feeding into the shift must be a multiply.
48062   SDValue ShiftOperand = N->getOperand(0);
48063   if (ShiftOperand.getOpcode() != ISD::MUL || !ShiftOperand.hasOneUse())
48064     return SDValue();
48065 
48066   // Input type should be at least vXi32.
48067   EVT VT = N->getValueType(0);
48068   if (!VT.isVector() || VT.getVectorElementType().getSizeInBits() < 32)
48069     return SDValue();
48070 
48071   // Need a shift by 16.
48072   APInt ShiftAmt;
48073   if (!ISD::isConstantSplatVector(N->getOperand(1).getNode(), ShiftAmt) ||
48074       ShiftAmt != 16)
48075     return SDValue();
48076 
48077   SDValue LHS = ShiftOperand.getOperand(0);
48078   SDValue RHS = ShiftOperand.getOperand(1);
48079 
48080   unsigned ExtOpc = LHS.getOpcode();
48081   if ((ExtOpc != ISD::SIGN_EXTEND && ExtOpc != ISD::ZERO_EXTEND) ||
48082       RHS.getOpcode() != ExtOpc)
48083     return SDValue();
48084 
48085   // Peek through the extends.
48086   LHS = LHS.getOperand(0);
48087   RHS = RHS.getOperand(0);
48088 
48089   // Ensure the input types match.
48090   EVT MulVT = LHS.getValueType();
48091   if (MulVT.getVectorElementType() != MVT::i16 || RHS.getValueType() != MulVT)
48092     return SDValue();
48093 
48094   unsigned Opc = ExtOpc == ISD::SIGN_EXTEND ? ISD::MULHS : ISD::MULHU;
48095   SDValue Mulh = DAG.getNode(Opc, DL, MulVT, LHS, RHS);
48096 
48097   ExtOpc = N->getOpcode() == ISD::SRA ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
48098   return DAG.getNode(ExtOpc, DL, VT, Mulh);
48099 }
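// Worked example (illustrative values) of the PMULH fold above for one vector
// element:
//   (srl (mul (zext i16 65535 to i32), (zext i16 2 to i32)), 16)
//     = (65535 * 2) >> 16 = 131070 >> 16 = 1
//   mulhu i16 65535, 2 = high 16 bits of 131070 = 1, then zext -> 1.
// The signed variant pairs sign_extend inputs and an SRA by 16 with MULHS.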
48100 
48101 static SDValue combineShiftLeft(SDNode *N, SelectionDAG &DAG,
48102                                 const X86Subtarget &Subtarget) {
48103   using namespace llvm::SDPatternMatch;
48104   SDValue N0 = N->getOperand(0);
48105   SDValue N1 = N->getOperand(1);
48106   ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
48107   EVT VT = N0.getValueType();
48108   unsigned EltSizeInBits = VT.getScalarSizeInBits();
48109   SDLoc DL(N);
48110 
48111   // Exploits AVX2 VSHLV/VSRLV instructions for efficient unsigned vector shifts
48112   // with out-of-bounds clamping.
48113   if (N0.getOpcode() == ISD::VSELECT &&
48114       supportedVectorVarShift(VT, Subtarget, ISD::SHL)) {
48115     SDValue Cond = N0.getOperand(0);
48116     SDValue N00 = N0.getOperand(1);
48117     SDValue N01 = N0.getOperand(2);
48118     // fold shl(select(icmp_ult(amt,BW),x,0),amt) -> avx2 psllv(x,amt)
48119     if (ISD::isConstantSplatVectorAllZeros(N01.getNode()) &&
48120         sd_match(Cond, m_SetCC(m_Specific(N1), m_SpecificInt(EltSizeInBits),
48121                                m_SpecificCondCode(ISD::SETULT)))) {
48122       return DAG.getNode(X86ISD::VSHLV, DL, VT, N00, N1);
48123     }
48124     // fold shl(select(icmp_uge(amt,BW),0,x),amt) -> avx2 psllv(x,amt)
48125     if (ISD::isConstantSplatVectorAllZeros(N00.getNode()) &&
48126         sd_match(Cond, m_SetCC(m_Specific(N1), m_SpecificInt(EltSizeInBits),
48127                                m_SpecificCondCode(ISD::SETUGE)))) {
48128       return DAG.getNode(X86ISD::VSHLV, DL, VT, N01, N1);
48129     }
48130   }
48131 
48132   // fold (shl (and (setcc_c), c1), c2) -> (and setcc_c, (c1 << c2))
48133   // since the result of setcc_c is all zero's or all ones.
48134   if (VT.isInteger() && !VT.isVector() &&
48135       N1C && N0.getOpcode() == ISD::AND &&
48136       N0.getOperand(1).getOpcode() == ISD::Constant) {
48137     SDValue N00 = N0.getOperand(0);
48138     APInt Mask = N0.getConstantOperandAPInt(1);
48139     Mask <<= N1C->getAPIntValue();
48140     bool MaskOK = false;
48141     // We can handle cases concerning bit-widening nodes containing setcc_c if
48142     // we carefully interrogate the mask to make sure we are semantics
48143     // preserving.
48144     // The transform is not safe if the result of C1 << C2 exceeds the bitwidth
48145     // of the underlying setcc_c operation if the setcc_c was zero extended.
48146     // Consider the following example:
48147     //   zext(setcc_c)                 -> i32 0x0000FFFF
48148     //   c1                            -> i32 0x0000FFFF
48149     //   c2                            -> i32 0x00000001
48150     //   (shl (and (setcc_c), c1), c2) -> i32 0x0001FFFE
48151     //   (and setcc_c, (c1 << c2))     -> i32 0x0000FFFE
48152     if (N00.getOpcode() == X86ISD::SETCC_CARRY) {
48153       MaskOK = true;
48154     } else if (N00.getOpcode() == ISD::SIGN_EXTEND &&
48155                N00.getOperand(0).getOpcode() == X86ISD::SETCC_CARRY) {
48156       MaskOK = true;
48157     } else if ((N00.getOpcode() == ISD::ZERO_EXTEND ||
48158                 N00.getOpcode() == ISD::ANY_EXTEND) &&
48159                N00.getOperand(0).getOpcode() == X86ISD::SETCC_CARRY) {
48160       MaskOK = Mask.isIntN(N00.getOperand(0).getValueSizeInBits());
48161     }
48162     if (MaskOK && Mask != 0)
48163       return DAG.getNode(ISD::AND, DL, VT, N00, DAG.getConstant(Mask, DL, VT));
48164   }
48165 
48166   return SDValue();
48167 }
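// Sketch of the VSHLV fold above for v8i32 on AVX2:
//   shl (vselect (icmp ult amt, 32), x, 0), amt  -->  vpsllvd x, amt
// For lanes with amt < 32 both sides compute x << amt; for lanes with
// amt >= 32 the vselect had already produced 0 (the unclamped shift would be
// out of range), and VPSLLVD likewise writes 0 for out-of-range counts, so the
// explicit clamp becomes redundant.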
48168 
48169 static SDValue combineShiftRightArithmetic(SDNode *N, SelectionDAG &DAG,
48170                                            const X86Subtarget &Subtarget) {
48171   using namespace llvm::SDPatternMatch;
48172   SDValue N0 = N->getOperand(0);
48173   SDValue N1 = N->getOperand(1);
48174   EVT VT = N0.getValueType();
48175   unsigned Size = VT.getSizeInBits();
48176   SDLoc DL(N);
48177 
48178   if (SDValue V = combineShiftToPMULH(N, DAG, DL, Subtarget))
48179     return V;
48180 
48181   // fold sra(x,umin(amt,bw-1)) -> avx2 psrav(x,amt)
48182   if (supportedVectorVarShift(VT, Subtarget, ISD::SRA)) {
48183     SDValue ShrAmtVal;
48184     if (sd_match(N1, m_UMin(m_Value(ShrAmtVal),
48185                             m_SpecificInt(VT.getScalarSizeInBits() - 1))))
48186       return DAG.getNode(X86ISD::VSRAV, DL, VT, N0, ShrAmtVal);
48187   }
48188 
48189   // fold (SRA (SHL X, ShlConst), SraConst)
48190   // into (SHL (sext_in_reg X), ShlConst - SraConst)
48191   //   or (sext_in_reg X)
48192   //   or (SRA (sext_in_reg X), SraConst - ShlConst)
48193   // depending on relation between SraConst and ShlConst.
48194   // We only do this if (Size - ShlConst) is equal to 8, 16 or 32. That allows
48195   // us to do the sext_in_reg from the corresponding bit.
48196 
48197   // sexts on X86 are MOVs (movsx). The MOVs have the same code size
48198   // as the SHIFTs above (only a SHIFT by 1 has a smaller encoding).
48199   // However the MOVs have 2 advantages over a SHIFT:
48200   // 1. MOVs can write to a register that differs from the source.
48201   // 2. MOVs accept memory operands.
48202 
48203   if (VT.isVector() || N1.getOpcode() != ISD::Constant ||
48204       N0.getOpcode() != ISD::SHL || !N0.hasOneUse() ||
48205       N0.getOperand(1).getOpcode() != ISD::Constant)
48206     return SDValue();
48207 
48208   SDValue N00 = N0.getOperand(0);
48209   SDValue N01 = N0.getOperand(1);
48210   APInt ShlConst = N01->getAsAPIntVal();
48211   APInt SraConst = N1->getAsAPIntVal();
48212   EVT CVT = N1.getValueType();
48213 
48214   if (CVT != N01.getValueType())
48215     return SDValue();
48216   if (SraConst.isNegative())
48217     return SDValue();
48218 
48219   for (MVT SVT : { MVT::i8, MVT::i16, MVT::i32 }) {
48220     unsigned ShiftSize = SVT.getSizeInBits();
48221     // Only deal with (Size - ShlConst) being equal to 8, 16 or 32.
48222     if (ShiftSize >= Size || ShlConst != Size - ShiftSize)
48223       continue;
48224     SDValue NN =
48225         DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, N00, DAG.getValueType(SVT));
48226     if (SraConst.eq(ShlConst))
48227       return NN;
48228     if (SraConst.ult(ShlConst))
48229       return DAG.getNode(ISD::SHL, DL, VT, NN,
48230                          DAG.getConstant(ShlConst - SraConst, DL, CVT));
48231     return DAG.getNode(ISD::SRA, DL, VT, NN,
48232                        DAG.getConstant(SraConst - ShlConst, DL, CVT));
48233   }
48234   return SDValue();
48235 }
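// Worked example, for illustration, of the SRA(SHL) fold above on i32 with
// ShlConst = 24, so that Size - ShlConst = 8 and SVT = i8:
//   (sra (shl x, 24), 24) -> (sext_in_reg x, i8)
//   (sra (shl x, 24), 26) -> (sra (sext_in_reg x, i8), 2)
//   (sra (shl x, 24), 22) -> (shl (sext_in_reg x, i8), 2)
// In every case the shl/sra pair only ever observes the low 8 bits of x, which
// is exactly what the movsx-style sext_in_reg provides.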
48236 
48237 static SDValue combineShiftRightLogical(SDNode *N, SelectionDAG &DAG,
48238                                         TargetLowering::DAGCombinerInfo &DCI,
48239                                         const X86Subtarget &Subtarget) {
48240   using namespace llvm::SDPatternMatch;
48241   SDValue N0 = N->getOperand(0);
48242   SDValue N1 = N->getOperand(1);
48243   EVT VT = N0.getValueType();
48244   unsigned EltSizeInBits = VT.getScalarSizeInBits();
48245   SDLoc DL(N);
48246 
48247   if (SDValue V = combineShiftToPMULH(N, DAG, DL, Subtarget))
48248     return V;
48249 
48250   // Exploits AVX2 VSHLV/VSRLV instructions for efficient unsigned vector shifts
48251   // with out-of-bounds clamping.
48252   if (N0.getOpcode() == ISD::VSELECT &&
48253       supportedVectorVarShift(VT, Subtarget, ISD::SRL)) {
48254     SDValue Cond = N0.getOperand(0);
48255     SDValue N00 = N0.getOperand(1);
48256     SDValue N01 = N0.getOperand(2);
48257     // fold srl(select(icmp_ult(amt,BW),x,0),amt) -> avx2 psrlv(x,amt)
48258     if (ISD::isConstantSplatVectorAllZeros(N01.getNode()) &&
48259         sd_match(Cond, m_SetCC(m_Specific(N1), m_SpecificInt(EltSizeInBits),
48260                                m_SpecificCondCode(ISD::SETULT)))) {
48261       return DAG.getNode(X86ISD::VSRLV, DL, VT, N00, N1);
48262     }
48263     // fold srl(select(icmp_uge(amt,BW),0,x),amt) -> avx2 psrlv(x,amt)
48264     if (ISD::isConstantSplatVectorAllZeros(N00.getNode()) &&
48265         sd_match(Cond, m_SetCC(m_Specific(N1), m_SpecificInt(EltSizeInBits),
48266                                m_SpecificCondCode(ISD::SETUGE)))) {
48267       return DAG.getNode(X86ISD::VSRLV, DL, VT, N01, N1);
48268     }
48269   }
48270 
48271   // Only do this on the last DAG combine as it can interfere with other
48272   // combines.
48273   if (!DCI.isAfterLegalizeDAG())
48274     return SDValue();
48275 
48276   // Try to improve a sequence of srl (and X, C1), C2 by inverting the order.
48277   // TODO: This is a generic DAG combine that became an x86-only combine to
48278   // avoid shortcomings in other folds such as bswap, bit-test ('bt'), and
48279   // and-not ('andn').
48280   if (N0.getOpcode() != ISD::AND || !N0.hasOneUse())
48281     return SDValue();
48282 
48283   auto *ShiftC = dyn_cast<ConstantSDNode>(N1);
48284   auto *AndC = dyn_cast<ConstantSDNode>(N0.getOperand(1));
48285   if (!ShiftC || !AndC)
48286     return SDValue();
48287 
48288   // If we can shrink the constant mask below 8 bits or 32 bits, then this
48289   // transform should reduce code size. It may also enable secondary transforms
48290   // from improved known-bits analysis or instruction selection.
48291   APInt MaskVal = AndC->getAPIntValue();
48292 
48293   // If this can be matched by a zero extend, don't optimize.
48294   if (MaskVal.isMask()) {
48295     unsigned TO = MaskVal.countr_one();
48296     if (TO >= 8 && isPowerOf2_32(TO))
48297       return SDValue();
48298   }
48299 
48300   APInt NewMaskVal = MaskVal.lshr(ShiftC->getAPIntValue());
48301   unsigned OldMaskSize = MaskVal.getSignificantBits();
48302   unsigned NewMaskSize = NewMaskVal.getSignificantBits();
48303   if ((OldMaskSize > 8 && NewMaskSize <= 8) ||
48304       (OldMaskSize > 32 && NewMaskSize <= 32)) {
48305     // srl (and X, AndC), ShiftC --> and (srl X, ShiftC), (AndC >> ShiftC)
48306     SDValue NewMask = DAG.getConstant(NewMaskVal, DL, VT);
48307     SDValue NewShift = DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), N1);
48308     return DAG.getNode(ISD::AND, DL, VT, NewShift, NewMask);
48309   }
48310   return SDValue();
48311 }
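// Worked example (illustrative constants) of the srl(and) reordering above,
// on i32:
//   srl (and X, 0x3FC00), 10  -->  and (srl X, 10), 0xFF
// The original mask needs a 32-bit immediate; after commuting, the mask fits
// in 8 bits, giving a smaller encoding and simpler known-bits. Plain low masks
// such as 0xFFFF are skipped because they already match a zero extend (movzx).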
48312 
48313 static SDValue combineHorizOpWithShuffle(SDNode *N, SelectionDAG &DAG,
48314                                          const X86Subtarget &Subtarget) {
48315   unsigned Opcode = N->getOpcode();
48316   assert(isHorizOp(Opcode) && "Unexpected hadd/hsub/pack opcode");
48317 
48318   SDLoc DL(N);
48319   EVT VT = N->getValueType(0);
48320   SDValue N0 = N->getOperand(0);
48321   SDValue N1 = N->getOperand(1);
48322   EVT SrcVT = N0.getValueType();
48323 
48324   SDValue BC0 =
48325       N->isOnlyUserOf(N0.getNode()) ? peekThroughOneUseBitcasts(N0) : N0;
48326   SDValue BC1 =
48327       N->isOnlyUserOf(N1.getNode()) ? peekThroughOneUseBitcasts(N1) : N1;
48328 
48329   // Attempt to fold HOP(LOSUBVECTOR(SHUFFLE(X)),HISUBVECTOR(SHUFFLE(X)))
48330   // to SHUFFLE(HOP(LOSUBVECTOR(X),HISUBVECTOR(X))), this is mainly for
48331   // truncation trees that help us avoid lane crossing shuffles.
48332   // TODO: There's a lot more we can do for PACK/HADD style shuffle combines.
48333   // TODO: We don't handle vXf64 shuffles yet.
48334   if (VT.is128BitVector() && SrcVT.getScalarSizeInBits() <= 32) {
48335     if (SDValue BCSrc = getSplitVectorSrc(BC0, BC1, false)) {
48336       SmallVector<SDValue> ShuffleOps;
48337       SmallVector<int> ShuffleMask, ScaledMask;
48338       SDValue Vec = peekThroughBitcasts(BCSrc);
48339       if (getTargetShuffleInputs(Vec, ShuffleOps, ShuffleMask, DAG)) {
48340         resolveTargetShuffleInputsAndMask(ShuffleOps, ShuffleMask);
48341         // To keep the HOP LHS/RHS coherency, we must be able to scale the unary
48342         // shuffle to a v4X64 width - we can probably relax this in the future.
48343         if (!isAnyZero(ShuffleMask) && ShuffleOps.size() == 1 &&
48344             ShuffleOps[0].getValueType().is256BitVector() &&
48345             scaleShuffleElements(ShuffleMask, 4, ScaledMask)) {
48346           SDValue Lo, Hi;
48347           MVT ShufVT = VT.isFloatingPoint() ? MVT::v4f32 : MVT::v4i32;
48348           std::tie(Lo, Hi) = DAG.SplitVector(ShuffleOps[0], DL);
48349           Lo = DAG.getBitcast(SrcVT, Lo);
48350           Hi = DAG.getBitcast(SrcVT, Hi);
48351           SDValue Res = DAG.getNode(Opcode, DL, VT, Lo, Hi);
48352           Res = DAG.getBitcast(ShufVT, Res);
48353           Res = DAG.getVectorShuffle(ShufVT, DL, Res, Res, ScaledMask);
48354           return DAG.getBitcast(VT, Res);
48355         }
48356       }
48357     }
48358   }
48359 
48360   // Attempt to fold HOP(SHUFFLE(X,Y),SHUFFLE(Z,W)) -> SHUFFLE(HOP()).
48361   if (VT.is128BitVector() && SrcVT.getScalarSizeInBits() <= 32) {
48362     // If either/both ops are a shuffle that can scale to v2x64,
48363     // then see if we can perform this as a v4x32 post shuffle.
48364     SmallVector<SDValue> Ops0, Ops1;
48365     SmallVector<int> Mask0, Mask1, ScaledMask0, ScaledMask1;
48366     bool IsShuf0 =
48367         getTargetShuffleInputs(BC0, Ops0, Mask0, DAG) && !isAnyZero(Mask0) &&
48368         scaleShuffleElements(Mask0, 2, ScaledMask0) &&
48369         all_of(Ops0, [](SDValue Op) { return Op.getValueSizeInBits() == 128; });
48370     bool IsShuf1 =
48371         getTargetShuffleInputs(BC1, Ops1, Mask1, DAG) && !isAnyZero(Mask1) &&
48372         scaleShuffleElements(Mask1, 2, ScaledMask1) &&
48373         all_of(Ops1, [](SDValue Op) { return Op.getValueSizeInBits() == 128; });
48374     if (IsShuf0 || IsShuf1) {
48375       if (!IsShuf0) {
48376         Ops0.assign({BC0});
48377         ScaledMask0.assign({0, 1});
48378       }
48379       if (!IsShuf1) {
48380         Ops1.assign({BC1});
48381         ScaledMask1.assign({0, 1});
48382       }
48383 
48384       SDValue LHS, RHS;
48385       int PostShuffle[4] = {-1, -1, -1, -1};
48386       auto FindShuffleOpAndIdx = [&](int M, int &Idx, ArrayRef<SDValue> Ops) {
48387         if (M < 0)
48388           return true;
48389         Idx = M % 2;
48390         SDValue Src = Ops[M / 2];
48391         if (!LHS || LHS == Src) {
48392           LHS = Src;
48393           return true;
48394         }
48395         if (!RHS || RHS == Src) {
48396           Idx += 2;
48397           RHS = Src;
48398           return true;
48399         }
48400         return false;
48401       };
48402       if (FindShuffleOpAndIdx(ScaledMask0[0], PostShuffle[0], Ops0) &&
48403           FindShuffleOpAndIdx(ScaledMask0[1], PostShuffle[1], Ops0) &&
48404           FindShuffleOpAndIdx(ScaledMask1[0], PostShuffle[2], Ops1) &&
48405           FindShuffleOpAndIdx(ScaledMask1[1], PostShuffle[3], Ops1)) {
48406         LHS = DAG.getBitcast(SrcVT, LHS);
48407         RHS = DAG.getBitcast(SrcVT, RHS ? RHS : LHS);
48408         MVT ShufVT = VT.isFloatingPoint() ? MVT::v4f32 : MVT::v4i32;
48409         SDValue Res = DAG.getNode(Opcode, DL, VT, LHS, RHS);
48410         Res = DAG.getBitcast(ShufVT, Res);
48411         Res = DAG.getVectorShuffle(ShufVT, DL, Res, Res, PostShuffle);
48412         return DAG.getBitcast(VT, Res);
48413       }
48414     }
48415   }
48416 
48417   // Attempt to fold HOP(SHUFFLE(X,Y),SHUFFLE(X,Y)) -> SHUFFLE(HOP(X,Y)).
48418   if (VT.is256BitVector() && Subtarget.hasInt256()) {
48419     SmallVector<int> Mask0, Mask1;
48420     SmallVector<SDValue> Ops0, Ops1;
48421     SmallVector<int, 2> ScaledMask0, ScaledMask1;
48422     if (getTargetShuffleInputs(BC0, Ops0, Mask0, DAG) && !isAnyZero(Mask0) &&
48423         getTargetShuffleInputs(BC1, Ops1, Mask1, DAG) && !isAnyZero(Mask1) &&
48424         !Ops0.empty() && !Ops1.empty() &&
48425         all_of(Ops0,
48426                [](SDValue Op) { return Op.getValueType().is256BitVector(); }) &&
48427         all_of(Ops1,
48428                [](SDValue Op) { return Op.getValueType().is256BitVector(); }) &&
48429         scaleShuffleElements(Mask0, 2, ScaledMask0) &&
48430         scaleShuffleElements(Mask1, 2, ScaledMask1)) {
48431       SDValue Op00 = peekThroughBitcasts(Ops0.front());
48432       SDValue Op10 = peekThroughBitcasts(Ops1.front());
48433       SDValue Op01 = peekThroughBitcasts(Ops0.back());
48434       SDValue Op11 = peekThroughBitcasts(Ops1.back());
48435       if ((Op00 == Op11) && (Op01 == Op10)) {
48436         std::swap(Op10, Op11);
48437         ShuffleVectorSDNode::commuteMask(ScaledMask1);
48438       }
48439       if ((Op00 == Op10) && (Op01 == Op11)) {
48440         const int Map[4] = {0, 2, 1, 3};
48441         SmallVector<int, 4> ShuffleMask(
48442             {Map[ScaledMask0[0]], Map[ScaledMask1[0]], Map[ScaledMask0[1]],
48443              Map[ScaledMask1[1]]});
48444         MVT ShufVT = VT.isFloatingPoint() ? MVT::v4f64 : MVT::v4i64;
48445         SDValue Res = DAG.getNode(Opcode, DL, VT, DAG.getBitcast(SrcVT, Op00),
48446                                   DAG.getBitcast(SrcVT, Op01));
48447         Res = DAG.getBitcast(ShufVT, Res);
48448         Res = DAG.getVectorShuffle(ShufVT, DL, Res, Res, ShuffleMask);
48449         return DAG.getBitcast(VT, Res);
48450       }
48451     }
48452   }
48453 
48454   return SDValue();
48455 }
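// The common theme of the folds above: a 128-bit horizontal op only combines
// adjacent elements within each operand, so a shuffle feeding both operands
// can usually be replayed as a single post-shuffle of the HOP result. For
// example, with A = [a0,a1,a2,a3] and B = [b0,b1,b2,b3]:
//   HADD(shuffle(A,B,{4,5,0,1}), shuffle(A,B,{6,7,2,3}))
//     = [b0+b1, a0+a1, b2+b3, a2+a3]
//     = shuffle(HADD(A,B), {2,0,3,1})
// which is the kind of rewrite the v2x64-scaled mask logic reconstructs
// (possibly with the HOP operands commuted and the mask adjusted to match).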
48456 
48457 static SDValue combineVectorPack(SDNode *N, SelectionDAG &DAG,
48458                                  TargetLowering::DAGCombinerInfo &DCI,
48459                                  const X86Subtarget &Subtarget) {
48460   unsigned Opcode = N->getOpcode();
48461   assert((X86ISD::PACKSS == Opcode || X86ISD::PACKUS == Opcode) &&
48462          "Unexpected pack opcode");
48463 
48464   EVT VT = N->getValueType(0);
48465   SDValue N0 = N->getOperand(0);
48466   SDValue N1 = N->getOperand(1);
48467   unsigned NumDstElts = VT.getVectorNumElements();
48468   unsigned DstBitsPerElt = VT.getScalarSizeInBits();
48469   unsigned SrcBitsPerElt = 2 * DstBitsPerElt;
48470   assert(N0.getScalarValueSizeInBits() == SrcBitsPerElt &&
48471          N1.getScalarValueSizeInBits() == SrcBitsPerElt &&
48472          "Unexpected PACKSS/PACKUS input type");
48473 
48474   bool IsSigned = (X86ISD::PACKSS == Opcode);
48475 
48476   // Constant Folding.
48477   APInt UndefElts0, UndefElts1;
48478   SmallVector<APInt, 32> EltBits0, EltBits1;
48479   if ((N0.isUndef() || N->isOnlyUserOf(N0.getNode())) &&
48480       (N1.isUndef() || N->isOnlyUserOf(N1.getNode())) &&
48481       getTargetConstantBitsFromNode(N0, SrcBitsPerElt, UndefElts0, EltBits0,
48482                                     /*AllowWholeUndefs*/ true,
48483                                     /*AllowPartialUndefs*/ true) &&
48484       getTargetConstantBitsFromNode(N1, SrcBitsPerElt, UndefElts1, EltBits1,
48485                                     /*AllowWholeUndefs*/ true,
48486                                     /*AllowPartialUndefs*/ true)) {
48487     unsigned NumLanes = VT.getSizeInBits() / 128;
48488     unsigned NumSrcElts = NumDstElts / 2;
48489     unsigned NumDstEltsPerLane = NumDstElts / NumLanes;
48490     unsigned NumSrcEltsPerLane = NumSrcElts / NumLanes;
48491 
48492     APInt Undefs(NumDstElts, 0);
48493     SmallVector<APInt, 32> Bits(NumDstElts, APInt::getZero(DstBitsPerElt));
48494     for (unsigned Lane = 0; Lane != NumLanes; ++Lane) {
48495       for (unsigned Elt = 0; Elt != NumDstEltsPerLane; ++Elt) {
48496         unsigned SrcIdx = Lane * NumSrcEltsPerLane + Elt % NumSrcEltsPerLane;
48497         auto &UndefElts = (Elt >= NumSrcEltsPerLane ? UndefElts1 : UndefElts0);
48498         auto &EltBits = (Elt >= NumSrcEltsPerLane ? EltBits1 : EltBits0);
48499 
48500         if (UndefElts[SrcIdx]) {
48501           Undefs.setBit(Lane * NumDstEltsPerLane + Elt);
48502           continue;
48503         }
48504 
48505         APInt &Val = EltBits[SrcIdx];
48506         if (IsSigned) {
48507           // PACKSS: Truncate signed value with signed saturation.
48508           // Source values less than dst minint are saturated to minint.
48509           // Source values greater than dst maxint are saturated to maxint.
48510           Val = Val.truncSSat(DstBitsPerElt);
48511         } else {
48512           // PACKUS: Truncate signed value with unsigned saturation.
48513           // Source values less than zero are saturated to zero.
48514           // Source values greater than dst maxuint are saturated to maxuint.
48515           // NOTE: This is different from APInt::truncUSat.
48516           if (Val.isIntN(DstBitsPerElt))
48517             Val = Val.trunc(DstBitsPerElt);
48518           else if (Val.isNegative())
48519             Val = APInt::getZero(DstBitsPerElt);
48520           else
48521             Val = APInt::getAllOnes(DstBitsPerElt);
48522         }
48523         Bits[Lane * NumDstEltsPerLane + Elt] = Val;
48524       }
48525     }
48526 
48527     return getConstVector(Bits, Undefs, VT.getSimpleVT(), DAG, SDLoc(N));
48528   }
48529 
48530   // Try to fold PACK(SHUFFLE(),SHUFFLE()) -> SHUFFLE(PACK()).
48531   if (SDValue V = combineHorizOpWithShuffle(N, DAG, Subtarget))
48532     return V;
48533 
48534   // Try to fold PACKSS(NOT(X),NOT(Y)) -> NOT(PACKSS(X,Y)).
48535   // Currently limit this to allsignbits cases only.
48536   if (IsSigned &&
48537       (N0.isUndef() || DAG.ComputeNumSignBits(N0) == SrcBitsPerElt) &&
48538       (N1.isUndef() || DAG.ComputeNumSignBits(N1) == SrcBitsPerElt)) {
48539     SDValue Not0 = N0.isUndef() ? N0 : IsNOT(N0, DAG);
48540     SDValue Not1 = N1.isUndef() ? N1 : IsNOT(N1, DAG);
48541     if (Not0 && Not1) {
48542       SDLoc DL(N);
48543       MVT SrcVT = N0.getSimpleValueType();
48544       SDValue Pack =
48545           DAG.getNode(X86ISD::PACKSS, DL, VT, DAG.getBitcast(SrcVT, Not0),
48546                       DAG.getBitcast(SrcVT, Not1));
48547       return DAG.getNOT(DL, Pack, VT);
48548     }
48549   }
48550 
48551   // Try to combine a PACKUSWB/PACKSSWB implemented truncate with a regular
48552   // truncate to create a larger truncate.
48553   if (Subtarget.hasAVX512() &&
48554       N0.getOpcode() == ISD::TRUNCATE && N1.isUndef() && VT == MVT::v16i8 &&
48555       N0.getOperand(0).getValueType() == MVT::v8i32) {
48556     if ((IsSigned && DAG.ComputeNumSignBits(N0) > 8) ||
48557         (!IsSigned &&
48558          DAG.MaskedValueIsZero(N0, APInt::getHighBitsSet(16, 8)))) {
48559       if (Subtarget.hasVLX())
48560         return DAG.getNode(X86ISD::VTRUNC, SDLoc(N), VT, N0.getOperand(0));
48561 
48562       // Widen input to v16i32 so we can truncate that.
48563       SDLoc dl(N);
48564       SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v16i32,
48565                                    N0.getOperand(0), DAG.getUNDEF(MVT::v8i32));
48566       return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, Concat);
48567     }
48568   }
48569 
48570   // Try to fold PACK(EXTEND(X),EXTEND(Y)) -> CONCAT(X,Y) subvectors.
48571   if (VT.is128BitVector()) {
48572     unsigned ExtOpc = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
48573     SDValue Src0, Src1;
48574     if (N0.getOpcode() == ExtOpc &&
48575         N0.getOperand(0).getValueType().is64BitVector() &&
48576         N0.getOperand(0).getScalarValueSizeInBits() == DstBitsPerElt) {
48577       Src0 = N0.getOperand(0);
48578     }
48579     if (N1.getOpcode() == ExtOpc &&
48580         N1.getOperand(0).getValueType().is64BitVector() &&
48581         N1.getOperand(0).getScalarValueSizeInBits() == DstBitsPerElt) {
48582       Src1 = N1.getOperand(0);
48583     }
48584     if ((Src0 || N0.isUndef()) && (Src1 || N1.isUndef())) {
48585       assert((Src0 || Src1) && "Found PACK(UNDEF,UNDEF)");
48586       Src0 = Src0 ? Src0 : DAG.getUNDEF(Src1.getValueType());
48587       Src1 = Src1 ? Src1 : DAG.getUNDEF(Src0.getValueType());
48588       return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Src0, Src1);
48589     }
48590 
48591     // Try again with pack(*_extend_vector_inreg, undef).
48592     unsigned VecInRegOpc = IsSigned ? ISD::SIGN_EXTEND_VECTOR_INREG
48593                                     : ISD::ZERO_EXTEND_VECTOR_INREG;
48594     if (N0.getOpcode() == VecInRegOpc && N1.isUndef() &&
48595         N0.getOperand(0).getScalarValueSizeInBits() < DstBitsPerElt)
48596       return getEXTEND_VECTOR_INREG(ExtOpc, SDLoc(N), VT, N0.getOperand(0),
48597                                     DAG);
48598   }
48599 
48600   // Attempt to combine as shuffle.
48601   SDValue Op(N, 0);
48602   if (SDValue Res = combineX86ShufflesRecursively(Op, DAG, Subtarget))
48603     return Res;
48604 
48605   return SDValue();
48606 }
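// Worked example (illustrative values) of the PACKSS/PACKUS constant folding
// above, for i16 -> i8 elements:
//   PACKSSWB:  300 -> 127,  -200 -> -128,  42 -> 42   (signed saturation)
//   PACKUSWB:  300 -> 255,    -5 -> 0,     42 -> 42   (unsigned saturation of
//              a signed source, hence the special handling instead of
//              APInt::truncUSat)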
48607 
48608 static SDValue combineVectorHADDSUB(SDNode *N, SelectionDAG &DAG,
48609                                     TargetLowering::DAGCombinerInfo &DCI,
48610                                     const X86Subtarget &Subtarget) {
48611   assert((X86ISD::HADD == N->getOpcode() || X86ISD::FHADD == N->getOpcode() ||
48612           X86ISD::HSUB == N->getOpcode() || X86ISD::FHSUB == N->getOpcode()) &&
48613          "Unexpected horizontal add/sub opcode");
48614 
48615   if (!shouldUseHorizontalOp(true, DAG, Subtarget)) {
48616     MVT VT = N->getSimpleValueType(0);
48617     SDValue LHS = N->getOperand(0);
48618     SDValue RHS = N->getOperand(1);
48619 
48620     // HOP(HOP'(X,X),HOP'(Y,Y)) -> HOP(PERMUTE(HOP'(X,Y)),PERMUTE(HOP'(X,Y)).
48621     if (LHS != RHS && LHS.getOpcode() == N->getOpcode() &&
48622         LHS.getOpcode() == RHS.getOpcode() &&
48623         LHS.getValueType() == RHS.getValueType() &&
48624         N->isOnlyUserOf(LHS.getNode()) && N->isOnlyUserOf(RHS.getNode())) {
48625       SDValue LHS0 = LHS.getOperand(0);
48626       SDValue LHS1 = LHS.getOperand(1);
48627       SDValue RHS0 = RHS.getOperand(0);
48628       SDValue RHS1 = RHS.getOperand(1);
48629       if ((LHS0 == LHS1 || LHS0.isUndef() || LHS1.isUndef()) &&
48630           (RHS0 == RHS1 || RHS0.isUndef() || RHS1.isUndef())) {
48631         SDLoc DL(N);
48632         SDValue Res = DAG.getNode(LHS.getOpcode(), DL, LHS.getValueType(),
48633                                   LHS0.isUndef() ? LHS1 : LHS0,
48634                                   RHS0.isUndef() ? RHS1 : RHS0);
48635         MVT ShufVT = MVT::getVectorVT(MVT::i32, VT.getSizeInBits() / 32);
48636         Res = DAG.getBitcast(ShufVT, Res);
48637         SDValue NewLHS =
48638             DAG.getNode(X86ISD::PSHUFD, DL, ShufVT, Res,
48639                         getV4X86ShuffleImm8ForMask({0, 1, 0, 1}, DL, DAG));
48640         SDValue NewRHS =
48641             DAG.getNode(X86ISD::PSHUFD, DL, ShufVT, Res,
48642                         getV4X86ShuffleImm8ForMask({2, 3, 2, 3}, DL, DAG));
48643         return DAG.getNode(N->getOpcode(), DL, VT, DAG.getBitcast(VT, NewLHS),
48644                            DAG.getBitcast(VT, NewRHS));
48645       }
48646     }
48647   }
48648 
48649   // Try to fold HOP(SHUFFLE(),SHUFFLE()) -> SHUFFLE(HOP()).
48650   if (SDValue V = combineHorizOpWithShuffle(N, DAG, Subtarget))
48651     return V;
48652 
48653   return SDValue();
48654 }
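// Worked example of the HOP(HOP'(X,X),HOP'(Y,Y)) rewrite above, using v4f32
// HADD for illustration:
//   before: HADD(HADD(X,X), HADD(Y,Y))
//           = [(x0+x1)+(x2+x3), (x0+x1)+(x2+x3),
//              (y0+y1)+(y2+y3), (y0+y1)+(y2+y3)]
//   after:  R = HADD(X,Y) = [x0+x1, x2+x3, y0+y1, y2+y3]
//           HADD(PSHUFD(R,{0,1,0,1}), PSHUFD(R,{2,3,2,3}))
// which yields the same four lanes while issuing two horizontal adds instead
// of three, which is the point on targets where horizontal ops are slow.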
48655 
48656 static SDValue combineVectorShiftVar(SDNode *N, SelectionDAG &DAG,
48657                                      TargetLowering::DAGCombinerInfo &DCI,
48658                                      const X86Subtarget &Subtarget) {
48659   assert((X86ISD::VSHL == N->getOpcode() || X86ISD::VSRA == N->getOpcode() ||
48660           X86ISD::VSRL == N->getOpcode()) &&
48661          "Unexpected shift opcode");
48662   EVT VT = N->getValueType(0);
48663   SDValue N0 = N->getOperand(0);
48664   SDValue N1 = N->getOperand(1);
48665 
48666   // Shift zero -> zero.
48667   if (ISD::isBuildVectorAllZeros(N0.getNode()))
48668     return DAG.getConstant(0, SDLoc(N), VT);
48669 
48670   // Detect constant shift amounts.
48671   APInt UndefElts;
48672   SmallVector<APInt, 32> EltBits;
48673   if (getTargetConstantBitsFromNode(N1, 64, UndefElts, EltBits,
48674                                     /*AllowWholeUndefs*/ true,
48675                                     /*AllowPartialUndefs*/ false)) {
48676     unsigned X86Opc = getTargetVShiftUniformOpcode(N->getOpcode(), false);
48677     return getTargetVShiftByConstNode(X86Opc, SDLoc(N), VT.getSimpleVT(), N0,
48678                                       EltBits[0].getZExtValue(), DAG);
48679   }
48680 
48681   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
48682   APInt DemandedElts = APInt::getAllOnes(VT.getVectorNumElements());
48683   if (TLI.SimplifyDemandedVectorElts(SDValue(N, 0), DemandedElts, DCI))
48684     return SDValue(N, 0);
48685 
48686   return SDValue();
48687 }
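// Illustrative case for the constant-amount fold above: the X86 variable
// shifts (e.g. PSRLD with an XMM count) take their count from the low 64 bits
// of the second operand, so a constant build_vector amount folds to the
// immediate form:
//   VSRL v4i32 X, <splat 5>  -->  VSRLI v4i32 X, 5   (i.e. psrld by an
//   immediate 5)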
48688 
48689 static SDValue combineVectorShiftImm(SDNode *N, SelectionDAG &DAG,
48690                                      TargetLowering::DAGCombinerInfo &DCI,
48691                                      const X86Subtarget &Subtarget) {
48692   unsigned Opcode = N->getOpcode();
48693   assert((X86ISD::VSHLI == Opcode || X86ISD::VSRAI == Opcode ||
48694           X86ISD::VSRLI == Opcode) &&
48695          "Unexpected shift opcode");
48696   bool LogicalShift = X86ISD::VSHLI == Opcode || X86ISD::VSRLI == Opcode;
48697   EVT VT = N->getValueType(0);
48698   SDValue N0 = N->getOperand(0);
48699   SDValue N1 = N->getOperand(1);
48700   unsigned NumBitsPerElt = VT.getScalarSizeInBits();
48701   assert(VT == N0.getValueType() && (NumBitsPerElt % 8) == 0 &&
48702          "Unexpected value type");
48703   assert(N1.getValueType() == MVT::i8 && "Unexpected shift amount type");
48704 
48705   // (shift undef, X) -> 0
48706   if (N0.isUndef())
48707     return DAG.getConstant(0, SDLoc(N), VT);
48708 
48709   // Out of range logical bit shifts are guaranteed to be zero.
48710   // Out of range arithmetic bit shifts splat the sign bit.
48711   unsigned ShiftVal = N->getConstantOperandVal(1);
48712   if (ShiftVal >= NumBitsPerElt) {
48713     if (LogicalShift)
48714       return DAG.getConstant(0, SDLoc(N), VT);
48715     ShiftVal = NumBitsPerElt - 1;
48716   }
48717 
48718   // (shift X, 0) -> X
48719   if (!ShiftVal)
48720     return N0;
48721 
48722   // (shift 0, C) -> 0
48723   if (ISD::isBuildVectorAllZeros(N0.getNode()))
48724     // N0 is all zeros or undef. We guarantee that the bits shifted into the
48725     // result are all zeros, not undef.
48726     return DAG.getConstant(0, SDLoc(N), VT);
48727 
48728   // (VSRAI -1, C) -> -1
48729   if (!LogicalShift && ISD::isBuildVectorAllOnes(N0.getNode()))
48730     // N0 is all ones or undef. We guarantee that the bits shifted into the
48731     // result are all ones, not undef.
48732     return DAG.getConstant(-1, SDLoc(N), VT);
48733 
48734   auto MergeShifts = [&](SDValue X, uint64_t Amt0, uint64_t Amt1) {
48735     unsigned NewShiftVal = Amt0 + Amt1;
48736     if (NewShiftVal >= NumBitsPerElt) {
48737       // Out of range logical bit shifts are guaranteed to be zero.
48738       // Out of range arithmetic bit shifts splat the sign bit.
48739       if (LogicalShift)
48740         return DAG.getConstant(0, SDLoc(N), VT);
48741       NewShiftVal = NumBitsPerElt - 1;
48742     }
48743     return DAG.getNode(Opcode, SDLoc(N), VT, N0.getOperand(0),
48744                        DAG.getTargetConstant(NewShiftVal, SDLoc(N), MVT::i8));
48745   };
48746 
48747   // (shift (shift X, C2), C1) -> (shift X, (C1 + C2))
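  // e.g. VSRLI(VSRLI(X, 3), 4) becomes VSRLI(X, 7); if the summed amount
  // reaches the element width, MergeShifts falls back to the out-of-range
  // rules above (zero for logical shifts, sign-splat for arithmetic).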
48748   if (Opcode == N0.getOpcode())
48749     return MergeShifts(N0.getOperand(0), ShiftVal, N0.getConstantOperandVal(1));
48750 
48751   // (shl (add X, X), C) -> (shl X, (C + 1))
48752   if (Opcode == X86ISD::VSHLI && N0.getOpcode() == ISD::ADD &&
48753       N0.getOperand(0) == N0.getOperand(1))
48754     return MergeShifts(N0.getOperand(0), ShiftVal, 1);
48755 
48756   // We can decode 'whole byte' logical bit shifts as shuffles.
48757   if (LogicalShift && (ShiftVal % 8) == 0) {
48758     SDValue Op(N, 0);
48759     if (SDValue Res = combineX86ShufflesRecursively(Op, DAG, Subtarget))
48760       return Res;
48761   }
48762 
48763   // Attempt to detect an expanded vXi64 SIGN_EXTEND_INREG vXi1 pattern, and
48764   // convert to a splatted v2Xi32 SIGN_EXTEND_INREG pattern:
48765   // psrad(pshufd(psllq(X,63),1,1,3,3),31) ->
48766   // pshufd(psrad(pslld(X,31),31),0,0,2,2).
48767   if (Opcode == X86ISD::VSRAI && NumBitsPerElt == 32 && ShiftVal == 31 &&
48768       N0.getOpcode() == X86ISD::PSHUFD &&
48769       N0.getConstantOperandVal(1) == getV4X86ShuffleImm({1, 1, 3, 3}) &&
48770       N0->hasOneUse()) {
48771     SDValue BC = peekThroughOneUseBitcasts(N0.getOperand(0));
48772     if (BC.getOpcode() == X86ISD::VSHLI &&
48773         BC.getScalarValueSizeInBits() == 64 &&
48774         BC.getConstantOperandVal(1) == 63) {
48775       SDLoc DL(N);
48776       SDValue Src = BC.getOperand(0);
48777       Src = DAG.getBitcast(VT, Src);
48778       Src = DAG.getNode(X86ISD::PSHUFD, DL, VT, Src,
48779                         getV4X86ShuffleImm8ForMask({0, 0, 2, 2}, DL, DAG));
48780       Src = DAG.getNode(X86ISD::VSHLI, DL, VT, Src, N1);
48781       Src = DAG.getNode(X86ISD::VSRAI, DL, VT, Src, N1);
48782       return Src;
48783     }
48784   }
48785 
48786   auto TryConstantFold = [&](SDValue V) {
48787     APInt UndefElts;
48788     SmallVector<APInt, 32> EltBits;
48789     if (!getTargetConstantBitsFromNode(V, NumBitsPerElt, UndefElts, EltBits,
48790                                        /*AllowWholeUndefs*/ true,
48791                                        /*AllowPartialUndefs*/ true))
48792       return SDValue();
48793     assert(EltBits.size() == VT.getVectorNumElements() &&
48794            "Unexpected shift value type");
48795     // Undef elements need to fold to 0. It's possible SimplifyDemandedBits
48796     // created an undef input due to no input bits being demanded, but user
48797     // still expects 0 in other bits.
48798     for (unsigned i = 0, e = EltBits.size(); i != e; ++i) {
48799       APInt &Elt = EltBits[i];
48800       if (UndefElts[i])
48801         Elt = 0;
48802       else if (X86ISD::VSHLI == Opcode)
48803         Elt <<= ShiftVal;
48804       else if (X86ISD::VSRAI == Opcode)
48805         Elt.ashrInPlace(ShiftVal);
48806       else
48807         Elt.lshrInPlace(ShiftVal);
48808     }
48809     // Reset undef elements since they were zeroed above.
48810     UndefElts = 0;
48811     return getConstVector(EltBits, UndefElts, VT.getSimpleVT(), DAG, SDLoc(N));
48812   };
48813 
48814   // Constant Folding.
48815   if (N->isOnlyUserOf(N0.getNode())) {
48816     if (SDValue C = TryConstantFold(N0))
48817       return C;
48818 
48819     // Fold (shift (logic X, C2), C1) -> (logic (shift X, C1), (shift C2, C1))
48820     // Don't break NOT patterns.
48821     SDValue BC = peekThroughOneUseBitcasts(N0);
48822     if (ISD::isBitwiseLogicOp(BC.getOpcode()) &&
48823         BC->isOnlyUserOf(BC.getOperand(1).getNode()) &&
48824         !ISD::isBuildVectorAllOnes(BC.getOperand(1).getNode())) {
48825       if (SDValue RHS = TryConstantFold(BC.getOperand(1))) {
48826         SDLoc DL(N);
48827         SDValue LHS = DAG.getNode(Opcode, DL, VT,
48828                                   DAG.getBitcast(VT, BC.getOperand(0)), N1);
48829         return DAG.getNode(BC.getOpcode(), DL, VT, LHS, RHS);
48830       }
48831     }
48832   }
48833 
48834   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
48835   if (TLI.SimplifyDemandedBits(SDValue(N, 0), APInt::getAllOnes(NumBitsPerElt),
48836                                DCI))
48837     return SDValue(N, 0);
48838 
48839   return SDValue();
48840 }
48841 
48842 static SDValue combineVectorInsert(SDNode *N, SelectionDAG &DAG,
48843                                    TargetLowering::DAGCombinerInfo &DCI,
48844                                    const X86Subtarget &Subtarget) {
48845   EVT VT = N->getValueType(0);
48846   unsigned Opcode = N->getOpcode();
48847   assert(((Opcode == X86ISD::PINSRB && VT == MVT::v16i8) ||
48848           (Opcode == X86ISD::PINSRW && VT == MVT::v8i16) ||
48849           Opcode == ISD::INSERT_VECTOR_ELT) &&
48850          "Unexpected vector insertion");
48851 
48852   SDValue Vec = N->getOperand(0);
48853   SDValue Scl = N->getOperand(1);
48854   SDValue Idx = N->getOperand(2);
48855 
48856   // Fold insert_vector_elt(undef, elt, 0) --> scalar_to_vector(elt).
48857   if (Opcode == ISD::INSERT_VECTOR_ELT && Vec.isUndef() && isNullConstant(Idx))
48858     return DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), VT, Scl);
48859 
48860   if (Opcode == X86ISD::PINSRB || Opcode == X86ISD::PINSRW) {
48861     unsigned NumBitsPerElt = VT.getScalarSizeInBits();
48862     const TargetLowering &TLI = DAG.getTargetLoweringInfo();
48863     if (TLI.SimplifyDemandedBits(SDValue(N, 0),
48864                                  APInt::getAllOnes(NumBitsPerElt), DCI))
48865       return SDValue(N, 0);
48866   }
48867 
48868   // Attempt to combine insertion patterns to a shuffle.
48869   if (VT.isSimple() && DCI.isAfterLegalizeDAG()) {
48870     SDValue Op(N, 0);
48871     if (SDValue Res = combineX86ShufflesRecursively(Op, DAG, Subtarget))
48872       return Res;
48873   }
48874 
48875   return SDValue();
48876 }
48877 
48878 /// Recognize the distinctive (AND (setcc ...) (setcc ..)) where both setccs
48879 /// reference the same FP CMP, and rewrite for CMPEQSS and friends. Likewise for
48880 /// OR -> CMPNEQSS.
48881 static SDValue combineCompareEqual(SDNode *N, SelectionDAG &DAG,
48882                                    TargetLowering::DAGCombinerInfo &DCI,
48883                                    const X86Subtarget &Subtarget) {
48884   unsigned opcode;
48885 
48886   // SSE1 supports CMP{eq|ne}SS, and SSE2 added CMP{eq|ne}SD, but
48887   // we're requiring SSE2 for both.
48888   if (Subtarget.hasSSE2() && isAndOrOfSetCCs(SDValue(N, 0U), opcode)) {
48889     SDValue N0 = N->getOperand(0);
48890     SDValue N1 = N->getOperand(1);
48891     SDValue CMP0 = N0.getOperand(1);
48892     SDValue CMP1 = N1.getOperand(1);
48893     SDLoc DL(N);
48894 
48895     // The SETCCs should both refer to the same CMP.
48896     if (CMP0.getOpcode() != X86ISD::FCMP || CMP0 != CMP1)
48897       return SDValue();
48898 
48899     SDValue CMP00 = CMP0->getOperand(0);
48900     SDValue CMP01 = CMP0->getOperand(1);
48901     EVT     VT    = CMP00.getValueType();
48902 
48903     if (VT == MVT::f32 || VT == MVT::f64 ||
48904         (VT == MVT::f16 && Subtarget.hasFP16())) {
48905       bool ExpectingFlags = false;
48906       // Check for any users that want flags:
48907       for (const SDNode *U : N->uses()) {
48908         if (ExpectingFlags)
48909           break;
48910 
48911         switch (U->getOpcode()) {
48912         default:
48913         case ISD::BR_CC:
48914         case ISD::BRCOND:
48915         case ISD::SELECT:
48916           ExpectingFlags = true;
48917           break;
48918         case ISD::CopyToReg:
48919         case ISD::SIGN_EXTEND:
48920         case ISD::ZERO_EXTEND:
48921         case ISD::ANY_EXTEND:
48922           break;
48923         }
48924       }
48925 
48926       if (!ExpectingFlags) {
48927         enum X86::CondCode cc0 = (enum X86::CondCode)N0.getConstantOperandVal(0);
48928         enum X86::CondCode cc1 = (enum X86::CondCode)N1.getConstantOperandVal(0);
48929 
48930         if (cc1 == X86::COND_E || cc1 == X86::COND_NE) {
48931           X86::CondCode tmp = cc0;
48932           cc0 = cc1;
48933           cc1 = tmp;
48934         }
48935 
48936         if ((cc0 == X86::COND_E  && cc1 == X86::COND_NP) ||
48937             (cc0 == X86::COND_NE && cc1 == X86::COND_P)) {
48938           // FIXME: need symbolic constants for these magic numbers.
48939           // See X86ATTInstPrinter.cpp:printSSECC().
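          // Immediate 0 selects the ordered equal predicate (CMPEQSS/SD) and
          // 4 selects the unordered not-equal predicate (CMPNEQSS/SD).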
48940           unsigned x86cc = (cc0 == X86::COND_E) ? 0 : 4;
48941           if (Subtarget.hasAVX512()) {
48942             SDValue FSetCC =
48943                 DAG.getNode(X86ISD::FSETCCM, DL, MVT::v1i1, CMP00, CMP01,
48944                             DAG.getTargetConstant(x86cc, DL, MVT::i8));
48945             // Need to fill with zeros to ensure the bitcast will produce zeroes
48946             // for the upper bits. An EXTRACT_ELEMENT here wouldn't guarantee that.
48947             SDValue Ins = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, MVT::v16i1,
48948                                       DAG.getConstant(0, DL, MVT::v16i1),
48949                                       FSetCC, DAG.getIntPtrConstant(0, DL));
48950             return DAG.getZExtOrTrunc(DAG.getBitcast(MVT::i16, Ins), DL,
48951                                       N->getSimpleValueType(0));
48952           }
48953           SDValue OnesOrZeroesF =
48954               DAG.getNode(X86ISD::FSETCC, DL, CMP00.getValueType(), CMP00,
48955                           CMP01, DAG.getTargetConstant(x86cc, DL, MVT::i8));
48956 
48957           bool is64BitFP = (CMP00.getValueType() == MVT::f64);
48958           MVT IntVT = is64BitFP ? MVT::i64 : MVT::i32;
48959 
48960           if (is64BitFP && !Subtarget.is64Bit()) {
48961             // On a 32-bit target, we cannot bitcast the 64-bit float to a
48962             // 64-bit integer, since that's not a legal type. Since
48963             // OnesOrZeroesF is all ones or all zeroes, we don't need all the
48964             // bits, but can do this little dance to extract the lowest 32 bits
48965             // and work with those going forward.
48966             SDValue Vector64 = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v2f64,
48967                                            OnesOrZeroesF);
48968             SDValue Vector32 = DAG.getBitcast(MVT::v4f32, Vector64);
48969             OnesOrZeroesF = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32,
48970                                         Vector32, DAG.getIntPtrConstant(0, DL));
48971             IntVT = MVT::i32;
48972           }
48973 
48974           SDValue OnesOrZeroesI = DAG.getBitcast(IntVT, OnesOrZeroesF);
48975           SDValue ANDed = DAG.getNode(ISD::AND, DL, IntVT, OnesOrZeroesI,
48976                                       DAG.getConstant(1, DL, IntVT));
48977           SDValue OneBitOfTruth = DAG.getNode(ISD::TRUNCATE, DL, MVT::i8,
48978                                               ANDed);
48979           return OneBitOfTruth;
48980         }
48981       }
48982     }
48983   }
48984   return SDValue();
48985 }
48986 
48987 /// Try to fold: (and (xor X, -1), Y) -> (andnp X, Y).
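/// X86ISD::ANDNP computes ~X & Y in a single node, so e.g. an IR sequence
///   %not = xor <4 x i32> %x, <i32 -1, i32 -1, i32 -1, i32 -1>
///   %r   = and <4 x i32> %not, %y
/// can select to one (V)PANDN instead of a separate PXOR and PAND.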
48988 static SDValue combineAndNotIntoANDNP(SDNode *N, SelectionDAG &DAG) {
48989   assert(N->getOpcode() == ISD::AND && "Unexpected opcode combine into ANDNP");
48990 
48991   MVT VT = N->getSimpleValueType(0);
48992   if (!VT.is128BitVector() && !VT.is256BitVector() && !VT.is512BitVector())
48993     return SDValue();
48994 
48995   SDValue X, Y;
48996   SDValue N0 = N->getOperand(0);
48997   SDValue N1 = N->getOperand(1);
48998 
48999   if (SDValue Not = IsNOT(N0, DAG)) {
49000     X = Not;
49001     Y = N1;
49002   } else if (SDValue Not = IsNOT(N1, DAG)) {
49003     X = Not;
49004     Y = N0;
49005   } else
49006     return SDValue();
49007 
49008   X = DAG.getBitcast(VT, X);
49009   Y = DAG.getBitcast(VT, Y);
49010   return DAG.getNode(X86ISD::ANDNP, SDLoc(N), VT, X, Y);
49011 }
49012 
49013 /// Try to fold:
49014 ///   and (vector_shuffle<Z,...,Z>
49015 ///            (insert_vector_elt undef, (xor X, -1), Z), undef), Y
49016 ///   ->
49017 ///   andnp (vector_shuffle<Z,...,Z>
49018 ///              (insert_vector_elt undef, X, Z), undef), Y
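///
/// Dropping the scalar NOT and using X86ISD::ANDNP (which complements its
/// first operand) means a broadcast of ~X AND'ed with Y becomes a single
/// (V)PANDN of the broadcast X and Y, with no separate XOR.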
49019 static SDValue combineAndShuffleNot(SDNode *N, SelectionDAG &DAG,
49020                                     const X86Subtarget &Subtarget) {
49021   assert(N->getOpcode() == ISD::AND && "Unexpected opcode combine into ANDNP");
49022 
49023   EVT VT = N->getValueType(0);
49024   // Do not split 256- and 512-bit vectors with SSE2 as they overwrite the
49025   // original value and require extra moves.
49026   if (!((VT.is128BitVector() && Subtarget.hasSSE2()) ||
49027         ((VT.is256BitVector() || VT.is512BitVector()) && Subtarget.hasAVX())))
49028     return SDValue();
49029 
49030   auto GetNot = [&DAG](SDValue V) {
49031     auto *SVN = dyn_cast<ShuffleVectorSDNode>(peekThroughOneUseBitcasts(V));
49032     // TODO: SVN->hasOneUse() is a strong condition. It can be relaxed if all
49033     // end-users are ISD::AND including cases
49034     // (and(extract_vector_element(SVN), Y)).
49035     if (!SVN || !SVN->hasOneUse() || !SVN->isSplat() ||
49036         !SVN->getOperand(1).isUndef()) {
49037       return SDValue();
49038     }
49039     SDValue IVEN = SVN->getOperand(0);
49040     if (IVEN.getOpcode() != ISD::INSERT_VECTOR_ELT ||
49041         !IVEN.getOperand(0).isUndef() || !IVEN.hasOneUse())
49042       return SDValue();
49043     if (!isa<ConstantSDNode>(IVEN.getOperand(2)) ||
49044         IVEN.getConstantOperandAPInt(2) != SVN->getSplatIndex())
49045       return SDValue();
49046     SDValue Src = IVEN.getOperand(1);
49047     if (SDValue Not = IsNOT(Src, DAG)) {
49048       SDValue NotSrc = DAG.getBitcast(Src.getValueType(), Not);
49049       SDValue NotIVEN =
49050           DAG.getNode(ISD::INSERT_VECTOR_ELT, SDLoc(IVEN), IVEN.getValueType(),
49051                       IVEN.getOperand(0), NotSrc, IVEN.getOperand(2));
49052       return DAG.getVectorShuffle(SVN->getValueType(0), SDLoc(SVN), NotIVEN,
49053                                   SVN->getOperand(1), SVN->getMask());
49054     }
49055     return SDValue();
49056   };
49057 
49058   SDValue X, Y;
49059   SDValue N0 = N->getOperand(0);
49060   SDValue N1 = N->getOperand(1);
49061   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
49062 
49063   if (SDValue Not = GetNot(N0)) {
49064     X = Not;
49065     Y = N1;
49066   } else if (SDValue Not = GetNot(N1)) {
49067     X = Not;
49068     Y = N0;
49069   } else
49070     return SDValue();
49071 
49072   X = DAG.getBitcast(VT, X);
49073   Y = DAG.getBitcast(VT, Y);
49074   SDLoc DL(N);
49075 
49076   // We do not split for SSE at all, but we need to split vectors for AVX1 and
49077   // AVX2.
49078   if (!Subtarget.useAVX512Regs() && VT.is512BitVector() &&
49079       TLI.isTypeLegal(VT.getHalfNumVectorElementsVT(*DAG.getContext()))) {
49080     SDValue LoX, HiX;
49081     std::tie(LoX, HiX) = splitVector(X, DAG, DL);
49082     SDValue LoY, HiY;
49083     std::tie(LoY, HiY) = splitVector(Y, DAG, DL);
49084     EVT SplitVT = LoX.getValueType();
49085     SDValue LoV = DAG.getNode(X86ISD::ANDNP, DL, SplitVT, {LoX, LoY});
49086     SDValue HiV = DAG.getNode(X86ISD::ANDNP, DL, SplitVT, {HiX, HiY});
49087     return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, {LoV, HiV});
49088   }
49089 
49090   if (TLI.isTypeLegal(VT))
49091     return DAG.getNode(X86ISD::ANDNP, DL, VT, {X, Y});
49092 
49093   return SDValue();
49094 }
49095 
49096 // Try to widen AND, OR and XOR nodes to VT in order to remove casts around
49097 // logical operations, like in the example below.
49098 //   or (and (truncate x), (truncate y)),
49099 //      (xor (truncate z), (build_vector constants))
49100 // Given a target type \p VT, we generate
49101 //   or (and x, y), (xor z, (zext (build_vector constants)))
49102 // given x, y and z are of type \p VT. We can do so if each operand is either
49103 // a truncate from VT, can itself be recursively promoted, or (for the second
49104 // operand only) is a constant vector that folds into a zero-extend.
49105 static SDValue PromoteMaskArithmetic(SDValue N, const SDLoc &DL, EVT VT,
49106                                      SelectionDAG &DAG, unsigned Depth) {
49107   // Limit recursion to avoid excessive compile times.
49108   if (Depth >= SelectionDAG::MaxRecursionDepth)
49109     return SDValue();
49110 
49111   if (!ISD::isBitwiseLogicOp(N.getOpcode()))
49112     return SDValue();
49113 
49114   SDValue N0 = N.getOperand(0);
49115   SDValue N1 = N.getOperand(1);
49116 
49117   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
49118   if (!TLI.isOperationLegalOrPromote(N.getOpcode(), VT))
49119     return SDValue();
49120 
49121   if (SDValue NN0 = PromoteMaskArithmetic(N0, DL, VT, DAG, Depth + 1))
49122     N0 = NN0;
49123   else {
49124     // The left side has to be a trunc.
49125     if (N0.getOpcode() != ISD::TRUNCATE)
49126       return SDValue();
49127 
49128     // The type of the truncated inputs.
49129     if (N0.getOperand(0).getValueType() != VT)
49130       return SDValue();
49131 
49132     N0 = N0.getOperand(0);
49133   }
49134 
49135   if (SDValue NN1 = PromoteMaskArithmetic(N1, DL, VT, DAG, Depth + 1))
49136     N1 = NN1;
49137   else {
49138     // The right side has to be a 'trunc' or a (foldable) constant.
49139     bool RHSTrunc = N1.getOpcode() == ISD::TRUNCATE &&
49140                     N1.getOperand(0).getValueType() == VT;
49141     if (RHSTrunc)
49142       N1 = N1.getOperand(0);
49143     else if (SDValue Cst =
49144                  DAG.FoldConstantArithmetic(ISD::ZERO_EXTEND, DL, VT, {N1}))
49145       N1 = Cst;
49146     else
49147       return SDValue();
49148   }
49149 
49150   return DAG.getNode(N.getOpcode(), DL, VT, N0, N1);
49151 }
49152 
49153 // On AVX/AVX2 the type v8i1 is legalized to v8i16, which is an XMM sized
49154 // register. In most cases we actually compare or select YMM-sized registers
49155 // and mixing the two types creates horrible code. This method optimizes
49156 // some of the transition sequences.
49157 // Even with AVX-512 this is still useful for removing casts around logical
49158 // operations on vXi1 mask types.
49159 static SDValue PromoteMaskArithmetic(SDValue N, const SDLoc &DL,
49160                                      SelectionDAG &DAG,
49161                                      const X86Subtarget &Subtarget) {
49162   EVT VT = N.getValueType();
49163   assert(VT.isVector() && "Expected vector type");
49164   assert((N.getOpcode() == ISD::ANY_EXTEND ||
49165           N.getOpcode() == ISD::ZERO_EXTEND ||
49166           N.getOpcode() == ISD::SIGN_EXTEND) && "Invalid Node");
49167 
49168   SDValue Narrow = N.getOperand(0);
49169   EVT NarrowVT = Narrow.getValueType();
49170 
49171   // Generate the wide operation.
49172   SDValue Op = PromoteMaskArithmetic(Narrow, DL, VT, DAG, 0);
49173   if (!Op)
49174     return SDValue();
49175   switch (N.getOpcode()) {
49176   default: llvm_unreachable("Unexpected opcode");
49177   case ISD::ANY_EXTEND:
49178     return Op;
49179   case ISD::ZERO_EXTEND:
49180     return DAG.getZeroExtendInReg(Op, DL, NarrowVT);
49181   case ISD::SIGN_EXTEND:
49182     return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT,
49183                        Op, DAG.getValueType(NarrowVT));
49184   }
49185 }
49186 
49187 static unsigned convertIntLogicToFPLogicOpcode(unsigned Opcode) {
49188   unsigned FPOpcode;
49189   switch (Opcode) {
49190   // clang-format off
49191   default: llvm_unreachable("Unexpected input node for FP logic conversion");
49192   case ISD::AND: FPOpcode = X86ISD::FAND; break;
49193   case ISD::OR:  FPOpcode = X86ISD::FOR;  break;
49194   case ISD::XOR: FPOpcode = X86ISD::FXOR; break;
49195   // clang-format on
49196   }
49197   return FPOpcode;
49198 }
49199 
49200 /// If both input operands of a logic op are being cast from floating-point
49201 /// types or FP compares, try to convert this into a floating-point logic node
49202 /// to avoid unnecessary moves from SSE to integer registers.
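/// e.g. (i32 (and (bitcast f32 A), (bitcast f32 B))) becomes
/// (bitcast (X86ISD::FAND A, B)), keeping both values in XMM registers
/// instead of round-tripping them through MOVD.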
49203 static SDValue convertIntLogicToFPLogic(SDNode *N, SelectionDAG &DAG,
49204                                         TargetLowering::DAGCombinerInfo &DCI,
49205                                         const X86Subtarget &Subtarget) {
49206   EVT VT = N->getValueType(0);
49207   SDValue N0 = N->getOperand(0);
49208   SDValue N1 = N->getOperand(1);
49209   SDLoc DL(N);
49210 
49211   if (!((N0.getOpcode() == ISD::BITCAST && N1.getOpcode() == ISD::BITCAST) ||
49212         (N0.getOpcode() == ISD::SETCC && N1.getOpcode() == ISD::SETCC)))
49213     return SDValue();
49214 
49215   SDValue N00 = N0.getOperand(0);
49216   SDValue N10 = N1.getOperand(0);
49217   EVT N00Type = N00.getValueType();
49218   EVT N10Type = N10.getValueType();
49219 
49220   // Ensure that both types are the same and are legal scalar fp types.
49221   if (N00Type != N10Type || !((Subtarget.hasSSE1() && N00Type == MVT::f32) ||
49222                               (Subtarget.hasSSE2() && N00Type == MVT::f64) ||
49223                               (Subtarget.hasFP16() && N00Type == MVT::f16)))
49224     return SDValue();
49225 
49226   if (N0.getOpcode() == ISD::BITCAST && !DCI.isBeforeLegalizeOps()) {
49227     unsigned FPOpcode = convertIntLogicToFPLogicOpcode(N->getOpcode());
49228     SDValue FPLogic = DAG.getNode(FPOpcode, DL, N00Type, N00, N10);
49229     return DAG.getBitcast(VT, FPLogic);
49230   }
49231 
49232   if (VT != MVT::i1 || N0.getOpcode() != ISD::SETCC || !N0.hasOneUse() ||
49233       !N1.hasOneUse())
49234     return SDValue();
49235 
49236   ISD::CondCode CC0 = cast<CondCodeSDNode>(N0.getOperand(2))->get();
49237   ISD::CondCode CC1 = cast<CondCodeSDNode>(N1.getOperand(2))->get();
49238 
49239   // The vector ISA for FP predicates is incomplete before AVX, so converting
49240   // COMIS* to CMPS* may not be a win before AVX.
49241   if (!Subtarget.hasAVX() &&
49242       !(cheapX86FSETCC_SSE(CC0) && cheapX86FSETCC_SSE(CC1)))
49243     return SDValue();
49244 
49245   // Convert scalar FP compares and logic to vector compares (COMIS* to CMPS*)
49246   // and vector logic:
49247   // logic (setcc N00, N01), (setcc N10, N11) -->
49248   // extelt (logic (setcc (s2v N00), (s2v N01)), setcc (s2v N10), (s2v N11))), 0
49249   unsigned NumElts = 128 / N00Type.getSizeInBits();
49250   EVT VecVT = EVT::getVectorVT(*DAG.getContext(), N00Type, NumElts);
49251   EVT BoolVecVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1, NumElts);
49252   SDValue ZeroIndex = DAG.getVectorIdxConstant(0, DL);
49253   SDValue N01 = N0.getOperand(1);
49254   SDValue N11 = N1.getOperand(1);
49255   SDValue Vec00 = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VecVT, N00);
49256   SDValue Vec01 = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VecVT, N01);
49257   SDValue Vec10 = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VecVT, N10);
49258   SDValue Vec11 = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VecVT, N11);
49259   SDValue Setcc0 = DAG.getSetCC(DL, BoolVecVT, Vec00, Vec01, CC0);
49260   SDValue Setcc1 = DAG.getSetCC(DL, BoolVecVT, Vec10, Vec11, CC1);
49261   SDValue Logic = DAG.getNode(N->getOpcode(), DL, BoolVecVT, Setcc0, Setcc1);
49262   return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Logic, ZeroIndex);
49263 }
49264 
49265 // Attempt to fold BITOP(MOVMSK(X),MOVMSK(Y)) -> MOVMSK(BITOP(X,Y))
49266 // to reduce XMM->GPR traffic.
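// e.g. (and (movmsk V0), (movmsk V1)) becomes (movmsk (and V0, V1)): one
// vector AND plus a single MOVMSK instead of two MOVMSKs and a scalar AND.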
49267 static SDValue combineBitOpWithMOVMSK(SDNode *N, SelectionDAG &DAG) {
49268   unsigned Opc = N->getOpcode();
49269   assert((Opc == ISD::OR || Opc == ISD::AND || Opc == ISD::XOR) &&
49270          "Unexpected bit opcode");
49271 
49272   SDValue N0 = N->getOperand(0);
49273   SDValue N1 = N->getOperand(1);
49274 
49275   // Both operands must be single use MOVMSK.
49276   if (N0.getOpcode() != X86ISD::MOVMSK || !N0.hasOneUse() ||
49277       N1.getOpcode() != X86ISD::MOVMSK || !N1.hasOneUse())
49278     return SDValue();
49279 
49280   SDValue Vec0 = N0.getOperand(0);
49281   SDValue Vec1 = N1.getOperand(0);
49282   EVT VecVT0 = Vec0.getValueType();
49283   EVT VecVT1 = Vec1.getValueType();
49284 
49285   // Both MOVMSK operands must be from vectors of the same size and same element
49286   // size, but it's OK for an fp/int difference.
49287   if (VecVT0.getSizeInBits() != VecVT1.getSizeInBits() ||
49288       VecVT0.getScalarSizeInBits() != VecVT1.getScalarSizeInBits())
49289     return SDValue();
49290 
49291   SDLoc DL(N);
49292   unsigned VecOpc =
49293       VecVT0.isFloatingPoint() ? convertIntLogicToFPLogicOpcode(Opc) : Opc;
49294   SDValue Result =
49295       DAG.getNode(VecOpc, DL, VecVT0, Vec0, DAG.getBitcast(VecVT0, Vec1));
49296   return DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Result);
49297 }
49298 
49299 // Attempt to fold BITOP(SHIFT(X,Z),SHIFT(Y,Z)) -> SHIFT(BITOP(X,Y),Z).
49300 // NOTE: This is a very limited case of what SimplifyUsingDistributiveLaws
49301 // handles in InstCombine.
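// e.g. (xor (vsrli X, 4), (vsrli Y, 4)) -> (vsrli (xor X, Y), 4); this is
// only done when both sides use the identical shift amount.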
49302 static SDValue combineBitOpWithShift(SDNode *N, SelectionDAG &DAG) {
49303   unsigned Opc = N->getOpcode();
49304   assert((Opc == ISD::OR || Opc == ISD::AND || Opc == ISD::XOR) &&
49305          "Unexpected bit opcode");
49306 
49307   SDValue N0 = N->getOperand(0);
49308   SDValue N1 = N->getOperand(1);
49309   EVT VT = N->getValueType(0);
49310 
49311   // Both operands must be single use.
49312   if (!N0.hasOneUse() || !N1.hasOneUse())
49313     return SDValue();
49314 
49315   // Search for matching shifts.
49316   SDValue BC0 = peekThroughOneUseBitcasts(N0);
49317   SDValue BC1 = peekThroughOneUseBitcasts(N1);
49318 
49319   unsigned BCOpc = BC0.getOpcode();
49320   EVT BCVT = BC0.getValueType();
49321   if (BCOpc != BC1->getOpcode() || BCVT != BC1.getValueType())
49322     return SDValue();
49323 
49324   switch (BCOpc) {
49325   case X86ISD::VSHLI:
49326   case X86ISD::VSRLI:
49327   case X86ISD::VSRAI: {
49328     if (BC0.getOperand(1) != BC1.getOperand(1))
49329       return SDValue();
49330 
49331     SDLoc DL(N);
49332     SDValue BitOp =
49333         DAG.getNode(Opc, DL, BCVT, BC0.getOperand(0), BC1.getOperand(0));
49334     SDValue Shift = DAG.getNode(BCOpc, DL, BCVT, BitOp, BC0.getOperand(1));
49335     return DAG.getBitcast(VT, Shift);
49336   }
49337   }
49338 
49339   return SDValue();
49340 }
49341 
49342 // Attempt to fold:
49343 // BITOP(PACKSS(X,Z),PACKSS(Y,W)) --> PACKSS(BITOP(X,Y),BITOP(Z,W)).
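// This is only done for all-signbits inputs (see below), where PACKSS acts as
// a plain per-element truncation of 0/-1 masks, so e.g.
// (and (packss A, B), (packss C, D)) == (packss (and A, C), (and B, D)).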
49344 // TODO: Add PACKUS handling.
49345 static SDValue combineBitOpWithPACK(SDNode *N, SelectionDAG &DAG) {
49346   unsigned Opc = N->getOpcode();
49347   assert((Opc == ISD::OR || Opc == ISD::AND || Opc == ISD::XOR) &&
49348          "Unexpected bit opcode");
49349 
49350   SDValue N0 = N->getOperand(0);
49351   SDValue N1 = N->getOperand(1);
49352   EVT VT = N->getValueType(0);
49353 
49354   // Both operands must be single use.
49355   if (!N0.hasOneUse() || !N1.hasOneUse())
49356     return SDValue();
49357 
49358   // Search for matching packs.
49359   N0 = peekThroughOneUseBitcasts(N0);
49360   N1 = peekThroughOneUseBitcasts(N1);
49361 
49362   if (N0.getOpcode() != X86ISD::PACKSS || N1.getOpcode() != X86ISD::PACKSS)
49363     return SDValue();
49364 
49365   MVT DstVT = N0.getSimpleValueType();
49366   if (DstVT != N1.getSimpleValueType())
49367     return SDValue();
49368 
49369   MVT SrcVT = N0.getOperand(0).getSimpleValueType();
49370   unsigned NumSrcBits = SrcVT.getScalarSizeInBits();
49371 
49372   // Limit to allsignbits packing.
49373   if (DAG.ComputeNumSignBits(N0.getOperand(0)) != NumSrcBits ||
49374       DAG.ComputeNumSignBits(N0.getOperand(1)) != NumSrcBits ||
49375       DAG.ComputeNumSignBits(N1.getOperand(0)) != NumSrcBits ||
49376       DAG.ComputeNumSignBits(N1.getOperand(1)) != NumSrcBits)
49377     return SDValue();
49378 
49379   SDLoc DL(N);
49380   SDValue LHS = DAG.getNode(Opc, DL, SrcVT, N0.getOperand(0), N1.getOperand(0));
49381   SDValue RHS = DAG.getNode(Opc, DL, SrcVT, N0.getOperand(1), N1.getOperand(1));
49382   return DAG.getBitcast(VT, DAG.getNode(X86ISD::PACKSS, DL, DstVT, LHS, RHS));
49383 }
49384 
49385 /// If this is a zero/all-bits result that is bitwise-anded with a low bits
49386 /// mask (Mask == 1 for the x86 lowering of a SETCC + ZEXT), replace the 'and'
49387 /// with a shift-right to eliminate loading the vector constant mask value.
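/// e.g. if every v4i32 element of X is known to be 0 or -1, then
/// (and X, <1,1,1,1>) can be rebuilt as (psrld X, 31), which needs no
/// constant-pool load for the mask.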
49388 static SDValue combineAndMaskToShift(SDNode *N, SelectionDAG &DAG,
49389                                      const X86Subtarget &Subtarget) {
49390   SDValue Op0 = peekThroughBitcasts(N->getOperand(0));
49391   SDValue Op1 = peekThroughBitcasts(N->getOperand(1));
49392   EVT VT = Op0.getValueType();
49393   if (VT != Op1.getValueType() || !VT.isSimple() || !VT.isInteger())
49394     return SDValue();
49395 
49396   // Try to convert an "is positive" signbit masking operation into arithmetic
49397   // shift and "andn". This saves a materialization of a -1 vector constant.
49398   // The "is negative" variant should be handled more generally because it only
49399   // requires "and" rather than "andn":
49400   // and (pcmpgt X, -1), Y --> pandn (vsrai X, BitWidth - 1), Y
49401   //
49402   // This is limited to the original type to avoid producing even more bitcasts.
49403   // If the bitcasts can't be eliminated, then it is unlikely that this fold
49404   // will be profitable.
49405   if (N->getValueType(0) == VT &&
49406       supportedVectorShiftWithImm(VT, Subtarget, ISD::SRA)) {
49407     SDValue X, Y;
49408     if (Op1.getOpcode() == X86ISD::PCMPGT &&
49409         isAllOnesOrAllOnesSplat(Op1.getOperand(1)) && Op1.hasOneUse()) {
49410       X = Op1.getOperand(0);
49411       Y = Op0;
49412     } else if (Op0.getOpcode() == X86ISD::PCMPGT &&
49413                isAllOnesOrAllOnesSplat(Op0.getOperand(1)) && Op0.hasOneUse()) {
49414       X = Op0.getOperand(0);
49415       Y = Op1;
49416     }
49417     if (X && Y) {
49418       SDLoc DL(N);
49419       SDValue Sra =
49420           getTargetVShiftByConstNode(X86ISD::VSRAI, DL, VT.getSimpleVT(), X,
49421                                      VT.getScalarSizeInBits() - 1, DAG);
49422       return DAG.getNode(X86ISD::ANDNP, DL, VT, Sra, Y);
49423     }
49424   }
49425 
49426   APInt SplatVal;
49427   if (!X86::isConstantSplat(Op1, SplatVal, false) || !SplatVal.isMask())
49428     return SDValue();
49429 
49430   // Don't prevent creation of ANDN.
49431   if (isBitwiseNot(Op0))
49432     return SDValue();
49433 
49434   if (!supportedVectorShiftWithImm(VT, Subtarget, ISD::SRL))
49435     return SDValue();
49436 
49437   unsigned EltBitWidth = VT.getScalarSizeInBits();
49438   if (EltBitWidth != DAG.ComputeNumSignBits(Op0))
49439     return SDValue();
49440 
49441   SDLoc DL(N);
49442   unsigned ShiftVal = SplatVal.countr_one();
49443   SDValue ShAmt = DAG.getTargetConstant(EltBitWidth - ShiftVal, DL, MVT::i8);
49444   SDValue Shift = DAG.getNode(X86ISD::VSRLI, DL, VT, Op0, ShAmt);
49445   return DAG.getBitcast(N->getValueType(0), Shift);
49446 }
49447 
49448 // Get the index node from the lowered DAG of a GEP IR instruction with one
49449 // indexing dimension.
49450 static SDValue getIndexFromUnindexedLoad(LoadSDNode *Ld) {
49451   if (Ld->isIndexed())
49452     return SDValue();
49453 
49454   SDValue Base = Ld->getBasePtr();
49455 
49456   if (Base.getOpcode() != ISD::ADD)
49457     return SDValue();
49458 
49459   SDValue ShiftedIndex = Base.getOperand(0);
49460 
49461   if (ShiftedIndex.getOpcode() != ISD::SHL)
49462     return SDValue();
49463 
49464   return ShiftedIndex.getOperand(0);
49465 
49466 }
49467 
49468 static bool hasBZHI(const X86Subtarget &Subtarget, MVT VT) {
49469   return Subtarget.hasBMI2() &&
49470          (VT == MVT::i32 || (VT == MVT::i64 && Subtarget.is64Bit()));
49471 }
49472 
49473 // This function recognizes cases where the X86 BZHI instruction can replace an
49474 // 'and-load' sequence.
49475 // In the case of loading an integer value from an array of constants defined
49476 // as follows:
49477 //
49478 //   int array[SIZE] = {0x0, 0x1, 0x3, 0x7, 0xF ..., 2^(SIZE-1) - 1}
49479 //
49480 // then applying a bitwise and on the result with another input.
49481 // It's equivalent to performing bzhi (zero high bits) on the input, with the
49482 // same index of the load.
49483 static SDValue combineAndLoadToBZHI(SDNode *Node, SelectionDAG &DAG,
49484                                     const X86Subtarget &Subtarget) {
49485   MVT VT = Node->getSimpleValueType(0);
49486   SDLoc dl(Node);
49487 
49488   // Check if subtarget has BZHI instruction for the node's type
49489   if (!hasBZHI(Subtarget, VT))
49490     return SDValue();
49491 
49492   // Try matching the pattern for both operands.
49493   for (unsigned i = 0; i < 2; i++) {
49494     SDValue N = Node->getOperand(i);
49495     LoadSDNode *Ld = dyn_cast<LoadSDNode>(N.getNode());
49496 
49497     // Bail out if the operand is not a load instruction.
49498     if (!Ld)
49499       return SDValue();
49500 
49501     const Value *MemOp = Ld->getMemOperand()->getValue();
49502 
49503     if (!MemOp)
49504       return SDValue();
49505 
49506     if (const GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(MemOp)) {
49507       if (GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0))) {
49508         if (GV->isConstant() && GV->hasDefinitiveInitializer()) {
49509 
49510           Constant *Init = GV->getInitializer();
49511           Type *Ty = Init->getType();
49512           if (!isa<ConstantDataArray>(Init) ||
49513               !Ty->getArrayElementType()->isIntegerTy() ||
49514               Ty->getArrayElementType()->getScalarSizeInBits() !=
49515                   VT.getSizeInBits() ||
49516               Ty->getArrayNumElements() >
49517                   Ty->getArrayElementType()->getScalarSizeInBits())
49518             continue;
49519 
49520           // Check if the array's constant elements are suitable for our case.
49521           uint64_t ArrayElementCount = Init->getType()->getArrayNumElements();
49522           bool ConstantsMatch = true;
49523           for (uint64_t j = 0; j < ArrayElementCount; j++) {
49524             auto *Elem = cast<ConstantInt>(Init->getAggregateElement(j));
49525             if (Elem->getZExtValue() != (((uint64_t)1 << j) - 1)) {
49526               ConstantsMatch = false;
49527               break;
49528             }
49529           }
49530           if (!ConstantsMatch)
49531             continue;
49532 
49533           // Do the transformation (For 32-bit type):
49534           // -> (and (load arr[idx]), inp)
49535           // <- (and inp, (srl 0xFFFFFFFF, (sub 32, idx)))
49536           //    that will be replaced with one bzhi instruction.
49537           SDValue Inp = (i == 0) ? Node->getOperand(1) : Node->getOperand(0);
49538           SDValue SizeC = DAG.getConstant(VT.getSizeInBits(), dl, MVT::i32);
49539 
49540           // Get the Node which indexes into the array.
49541           SDValue Index = getIndexFromUnindexedLoad(Ld);
49542           if (!Index)
49543             return SDValue();
49544           Index = DAG.getZExtOrTrunc(Index, dl, MVT::i32);
49545 
49546           SDValue Sub = DAG.getNode(ISD::SUB, dl, MVT::i32, SizeC, Index);
49547           Sub = DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, Sub);
49548 
49549           SDValue AllOnes = DAG.getAllOnesConstant(dl, VT);
49550           SDValue LShr = DAG.getNode(ISD::SRL, dl, VT, AllOnes, Sub);
49551 
49552           return DAG.getNode(ISD::AND, dl, VT, Inp, LShr);
49553         }
49554       }
49555     }
49556   }
49557   return SDValue();
49558 }
49559 
49560 // Look for (and (bitcast (vXi1 (concat_vectors (vYi1 setcc), undef,))), C)
49561 // Where C is a mask containing the same number of bits as the setcc and
49562 // where the setcc will freely zero the upper bits of the k-register. We can
49563 // replace the undef in the concat with 0s and remove the AND. This mainly
49564 // helps with v2i1/v4i1 setcc being cast to scalar.
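// e.g. (and (i8 (bitcast (v8i1 (concat (v2i1 setcc), undef, undef, undef)))), 3)
// no longer needs the AND once the undef subvectors are rebuilt as zeroes.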
49565 static SDValue combineScalarAndWithMaskSetcc(SDNode *N, SelectionDAG &DAG,
49566                                              const X86Subtarget &Subtarget) {
49567   assert(N->getOpcode() == ISD::AND && "Unexpected opcode!");
49568 
49569   EVT VT = N->getValueType(0);
49570 
49571   // Make sure this is an AND with a constant. We will check the value of the
49572   // constant later.
49573   auto *C1 = dyn_cast<ConstantSDNode>(N->getOperand(1));
49574   if (!C1)
49575     return SDValue();
49576 
49577   // This is implied by the ConstantSDNode.
49578   assert(!VT.isVector() && "Expected scalar VT!");
49579 
49580   SDValue Src = N->getOperand(0);
49581   if (!Src.hasOneUse())
49582     return SDValue();
49583 
49584   // (Optionally) peek through any_extend().
49585   if (Src.getOpcode() == ISD::ANY_EXTEND) {
49586     if (!Src.getOperand(0).hasOneUse())
49587       return SDValue();
49588     Src = Src.getOperand(0);
49589   }
49590 
49591   if (Src.getOpcode() != ISD::BITCAST || !Src.getOperand(0).hasOneUse())
49592     return SDValue();
49593 
49594   Src = Src.getOperand(0);
49595   EVT SrcVT = Src.getValueType();
49596 
49597   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
49598   if (!SrcVT.isVector() || SrcVT.getVectorElementType() != MVT::i1 ||
49599       !TLI.isTypeLegal(SrcVT))
49600     return SDValue();
49601 
49602   if (Src.getOpcode() != ISD::CONCAT_VECTORS)
49603     return SDValue();
49604 
49605   // We only care about the first subvector of the concat; we expect the
49606   // other subvectors to be ignored due to the AND if we make the change.
49607   SDValue SubVec = Src.getOperand(0);
49608   EVT SubVecVT = SubVec.getValueType();
49609 
49610   // The RHS of the AND should be a mask with as many bits as SubVec.
49611   if (!TLI.isTypeLegal(SubVecVT) ||
49612       !C1->getAPIntValue().isMask(SubVecVT.getVectorNumElements()))
49613     return SDValue();
49614 
49615   // First subvector should be a setcc with a legal result type or an
49616   // AND containing at least one setcc with a legal result type.
49617   auto IsLegalSetCC = [&](SDValue V) {
49618     if (V.getOpcode() != ISD::SETCC)
49619       return false;
49620     EVT SetccVT = V.getOperand(0).getValueType();
49621     if (!TLI.isTypeLegal(SetccVT) ||
49622         !(Subtarget.hasVLX() || SetccVT.is512BitVector()))
49623       return false;
49624     if (!(Subtarget.hasBWI() || SetccVT.getScalarSizeInBits() >= 32))
49625       return false;
49626     return true;
49627   };
49628   if (!(IsLegalSetCC(SubVec) || (SubVec.getOpcode() == ISD::AND &&
49629                                  (IsLegalSetCC(SubVec.getOperand(0)) ||
49630                                   IsLegalSetCC(SubVec.getOperand(1))))))
49631     return SDValue();
49632 
49633   // We passed all the checks. Rebuild the concat_vectors with zeroes
49634   // and cast it back to VT.
49635   SDLoc dl(N);
49636   SmallVector<SDValue, 4> Ops(Src.getNumOperands(),
49637                               DAG.getConstant(0, dl, SubVecVT));
49638   Ops[0] = SubVec;
49639   SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, dl, SrcVT,
49640                                Ops);
49641   EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), SrcVT.getSizeInBits());
49642   return DAG.getZExtOrTrunc(DAG.getBitcast(IntVT, Concat), dl, VT);
49643 }
49644 
49645 static SDValue getBMIMatchingOp(unsigned Opc, SelectionDAG &DAG,
49646                                 SDValue OpMustEq, SDValue Op, unsigned Depth) {
49647   // We don't want to go crazy with the recursion here. This isn't a super
49648   // important optimization.
49649   static constexpr unsigned kMaxDepth = 2;
49650 
49651   // Only do this re-ordering if op has one use.
49652   if (!Op.hasOneUse())
49653     return SDValue();
49654 
49655   SDLoc DL(Op);
49656   // If we hit another associative op, recurse further.
49657   if (Op.getOpcode() == Opc) {
49658     // Done recursing.
49659     if (Depth++ >= kMaxDepth)
49660       return SDValue();
49661 
49662     for (unsigned OpIdx = 0; OpIdx < 2; ++OpIdx)
49663       if (SDValue R =
49664               getBMIMatchingOp(Opc, DAG, OpMustEq, Op.getOperand(OpIdx), Depth))
49665         return DAG.getNode(Op.getOpcode(), DL, Op.getValueType(), R,
49666                            Op.getOperand(1 - OpIdx));
49667 
49668   } else if (Op.getOpcode() == ISD::SUB) {
49669     if (Opc == ISD::AND) {
49670       // BLSI: (and x, (sub 0, x))
49671       if (isNullConstant(Op.getOperand(0)) && Op.getOperand(1) == OpMustEq)
49672         return DAG.getNode(Opc, DL, Op.getValueType(), OpMustEq, Op);
49673     }
49674     // Opc must be ISD::AND or ISD::XOR
49675     // BLSR: (and x, (sub x, 1))
49676     // BLSMSK: (xor x, (sub x, 1))
49677     if (isOneConstant(Op.getOperand(1)) && Op.getOperand(0) == OpMustEq)
49678       return DAG.getNode(Opc, DL, Op.getValueType(), OpMustEq, Op);
49679 
49680   } else if (Op.getOpcode() == ISD::ADD) {
49681     // Opc must be ISD::AND or ISD::XOR
49682     // BLSR: (and x, (add x, -1))
49683     // BLSMSK: (xor x, (add x, -1))
49684     if (isAllOnesConstant(Op.getOperand(1)) && Op.getOperand(0) == OpMustEq)
49685       return DAG.getNode(Opc, DL, Op.getValueType(), OpMustEq, Op);
49686   }
49687   return SDValue();
49688 }
49689 
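// Try to rewrite AND/XOR trees into the forms matched by the BMI1 BLSI
// ((and x, (sub 0, x))), BLSR ((and x, (add x, -1))) and BLSMSK
// ((xor x, (add x, -1))) instructions, e.g. for x = 0b10110 these produce
// 0b00010, 0b10100 and 0b00011 respectively.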
49690 static SDValue combineBMILogicOp(SDNode *N, SelectionDAG &DAG,
49691                                  const X86Subtarget &Subtarget) {
49692   EVT VT = N->getValueType(0);
49693   // Make sure this node is a candidate for BMI instructions.
49694   if (!Subtarget.hasBMI() || !VT.isScalarInteger() ||
49695       (VT != MVT::i32 && VT != MVT::i64))
49696     return SDValue();
49697 
49698   assert(N->getOpcode() == ISD::AND || N->getOpcode() == ISD::XOR);
49699 
49700   // Try and match LHS and RHS.
49701   for (unsigned OpIdx = 0; OpIdx < 2; ++OpIdx)
49702     if (SDValue OpMatch =
49703             getBMIMatchingOp(N->getOpcode(), DAG, N->getOperand(OpIdx),
49704                              N->getOperand(1 - OpIdx), 0))
49705       return OpMatch;
49706   return SDValue();
49707 }
49708 
49709 static SDValue combineX86SubCmpForFlags(SDNode *N, SDValue Flag,
49710                                         SelectionDAG &DAG,
49711                                         TargetLowering::DAGCombinerInfo &DCI,
49712                                         const X86Subtarget &ST) {
49713   // cmp(setcc(cc, X), 0)
49714   // brcond ne
49715   //  ->
49716   //    X
49717   //    brcond cc
49718 
49719   // sub(setcc(cc, X), 1)
49720   // brcond ne
49721   //  ->
49722   //    X
49723   //    brcond ~cc
49724   //
49725   // if only flag has users
49726 
49727   SDValue SetCC = N->getOperand(0);
49728 
49729   if (SetCC.getOpcode() != X86ISD::SETCC || !Flag.hasOneUse())
49730     return SDValue();
49731 
49732   // Check that the only user of the flag is `brcond ne`.
49733   SDNode *BrCond = *Flag->uses().begin();
49734   if (BrCond->getOpcode() != X86ISD::BRCOND)
49735     return SDValue();
49736   unsigned CondNo = 2;
49737   if (static_cast<X86::CondCode>(BrCond->getConstantOperandVal(CondNo)) !=
49738       X86::COND_NE)
49739     return SDValue();
49740 
49741   SDValue X = SetCC.getOperand(1);
49742   // sub has two results while X only has one. DAG combine assumes the value
49743   // type matches.
49744   if (N->getOpcode() == X86ISD::SUB)
49745     X = DAG.getMergeValues({N->getOperand(0), X}, SDLoc(N));
49746 
49747   SDValue CCN = SetCC.getOperand(0);
49748   X86::CondCode CC =
49749       static_cast<X86::CondCode>(CCN->getAsAPIntVal().getSExtValue());
49750   X86::CondCode OppositeCC = X86::GetOppositeBranchCondition(CC);
49751   // Update CC for the consumer of the flag.
49752   // The old CC is `ne`. Hence, when comparing the result with 0, we are
49753   // checking if the second condition evaluates to true. When comparing the
49754   // result with 1, we are checking if the second condition evaluates to false.
49755   SmallVector<SDValue> Ops(BrCond->op_values());
49756   if (isNullConstant(N->getOperand(1)))
49757     Ops[CondNo] = CCN;
49758   else if (isOneConstant(N->getOperand(1)))
49759     Ops[CondNo] = DAG.getTargetConstant(OppositeCC, SDLoc(BrCond), MVT::i8);
49760   else
49761     llvm_unreachable("expect constant 0 or 1");
49762 
49763   SDValue NewBrCond =
49764       DAG.getNode(X86ISD::BRCOND, SDLoc(BrCond), BrCond->getValueType(0), Ops);
49765   // Avoid self-assign error b/c CC1 can be `e/ne`.
49766   if (BrCond != NewBrCond.getNode())
49767     DCI.CombineTo(BrCond, NewBrCond);
49768   return X;
49769 }
49770 
49771 static SDValue combineAndOrForCcmpCtest(SDNode *N, SelectionDAG &DAG,
49772                                         TargetLowering::DAGCombinerInfo &DCI,
49773                                         const X86Subtarget &ST) {
49774   // and/or(setcc(cc0, flag0), setcc(cc1, sub (X, Y)))
49775   //  ->
49776   //    setcc(cc1, ccmp(X, Y, ~cflags/cflags, cc0/~cc0, flag0))
49777 
49778   // and/or(setcc(cc0, flag0), setcc(cc1, cmp (X, 0)))
49779   //  ->
49780   //    setcc(cc1, ctest(X, X, ~cflags/cflags, cc0/~cc0, flag0))
49781   //
49782   // where cflags is determined by cc1.
49783 
49784   if (!ST.hasCCMP())
49785     return SDValue();
49786 
49787   SDValue SetCC0 = N->getOperand(0);
49788   SDValue SetCC1 = N->getOperand(1);
49789   if (SetCC0.getOpcode() != X86ISD::SETCC ||
49790       SetCC1.getOpcode() != X86ISD::SETCC)
49791     return SDValue();
49792 
49793   auto GetCombineToOpc = [&](SDValue V) -> unsigned {
49794     SDValue Op = V.getOperand(1);
49795     unsigned Opc = Op.getOpcode();
49796     if (Opc == X86ISD::SUB)
49797       return X86ISD::CCMP;
49798     if (Opc == X86ISD::CMP && isNullConstant(Op.getOperand(1)))
49799       return X86ISD::CTEST;
49800     return 0U;
49801   };
49802 
49803   unsigned NewOpc = 0;
49804 
49805   // AND/OR is commutable. Canonicalize the operands to make SETCC with SUB/CMP
49806   // appear on the right.
49807   if (!(NewOpc = GetCombineToOpc(SetCC1))) {
49808     std::swap(SetCC0, SetCC1);
49809     if (!(NewOpc = GetCombineToOpc(SetCC1)))
49810       return SDValue();
49811   }
49812 
49813   X86::CondCode CC0 =
49814       static_cast<X86::CondCode>(SetCC0.getConstantOperandVal(0));
49815   // CCMP/CTEST is not conditional when the source condition is COND_P/COND_NP.
49816   if (CC0 == X86::COND_P || CC0 == X86::COND_NP)
49817     return SDValue();
49818 
49819   bool IsOR = N->getOpcode() == ISD::OR;
49820 
49821   // CMP/TEST is executed and updates the EFLAGS normally only when SrcCC
49822   // evaluates to true. So we need to use the inverse of CC0 as SrcCC when
49823   // the logic operator is OR. Similar for CC1.
49824   SDValue SrcCC =
49825       IsOR ? DAG.getTargetConstant(X86::GetOppositeBranchCondition(CC0),
49826                                    SDLoc(SetCC0.getOperand(0)), MVT::i8)
49827            : SetCC0.getOperand(0);
49828   SDValue CC1N = SetCC1.getOperand(0);
49829   X86::CondCode CC1 =
49830       static_cast<X86::CondCode>(CC1N->getAsAPIntVal().getSExtValue());
49831   X86::CondCode OppositeCC1 = X86::GetOppositeBranchCondition(CC1);
49832   X86::CondCode CFlagsCC = IsOR ? CC1 : OppositeCC1;
49833   SDLoc DL(N);
49834   SDValue CFlags = DAG.getTargetConstant(
49835       X86::getCCMPCondFlagsFromCondCode(CFlagsCC), DL, MVT::i8);
49836   SDValue Sub = SetCC1.getOperand(1);
49837 
49838   // Replace any uses of the old flag produced by SUB/CMP with the new one
49839   // produced by CCMP/CTEST.
49840   SDValue CCMP = (NewOpc == X86ISD::CCMP)
49841                      ? DAG.getNode(X86ISD::CCMP, DL, MVT::i32,
49842                                    {Sub.getOperand(0), Sub.getOperand(1),
49843                                     CFlags, SrcCC, SetCC0.getOperand(1)})
49844                      : DAG.getNode(X86ISD::CTEST, DL, MVT::i32,
49845                                    {Sub.getOperand(0), Sub.getOperand(0),
49846                                     CFlags, SrcCC, SetCC0.getOperand(1)});
49847 
49848   return DAG.getNode(X86ISD::SETCC, DL, MVT::i8, {CC1N, CCMP});
49849 }
49850 
49851 static SDValue combineAnd(SDNode *N, SelectionDAG &DAG,
49852                           TargetLowering::DAGCombinerInfo &DCI,
49853                           const X86Subtarget &Subtarget) {
49854   SDValue N0 = N->getOperand(0);
49855   SDValue N1 = N->getOperand(1);
49856   EVT VT = N->getValueType(0);
49857   SDLoc dl(N);
49858   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
49859 
49860   // If this is SSE1 only convert to FAND to avoid scalarization.
49861   if (Subtarget.hasSSE1() && !Subtarget.hasSSE2() && VT == MVT::v4i32) {
49862     return DAG.getBitcast(MVT::v4i32,
49863                           DAG.getNode(X86ISD::FAND, dl, MVT::v4f32,
49864                                       DAG.getBitcast(MVT::v4f32, N0),
49865                                       DAG.getBitcast(MVT::v4f32, N1)));
49866   }
49867 
49868   // Use a 32-bit and+zext if upper bits known zero.
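  // On x86-64 a 32-bit AND implicitly zero-extends its result to 64 bits, so
  // the i32 op makes the outer zero_extend free and usually saves a REX.W
  // prefix on the encoding.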
49869   if (VT == MVT::i64 && Subtarget.is64Bit() && !isa<ConstantSDNode>(N1)) {
49870     APInt HiMask = APInt::getHighBitsSet(64, 32);
49871     if (DAG.MaskedValueIsZero(N1, HiMask) ||
49872         DAG.MaskedValueIsZero(N0, HiMask)) {
49873       SDValue LHS = DAG.getNode(ISD::TRUNCATE, dl, MVT::i32, N0);
49874       SDValue RHS = DAG.getNode(ISD::TRUNCATE, dl, MVT::i32, N1);
49875       return DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i64,
49876                          DAG.getNode(ISD::AND, dl, MVT::i32, LHS, RHS));
49877     }
49878   }
49879 
49880   // Match all-of bool scalar reductions into a bitcast/movmsk + cmp.
49881   // TODO: Support multiple SrcOps.
49882   if (VT == MVT::i1) {
49883     SmallVector<SDValue, 2> SrcOps;
49884     SmallVector<APInt, 2> SrcPartials;
49885     if (matchScalarReduction(SDValue(N, 0), ISD::AND, SrcOps, &SrcPartials) &&
49886         SrcOps.size() == 1) {
49887       unsigned NumElts = SrcOps[0].getValueType().getVectorNumElements();
49888       EVT MaskVT = EVT::getIntegerVT(*DAG.getContext(), NumElts);
49889       SDValue Mask = combineBitcastvxi1(DAG, MaskVT, SrcOps[0], dl, Subtarget);
49890       if (!Mask && TLI.isTypeLegal(SrcOps[0].getValueType()))
49891         Mask = DAG.getBitcast(MaskVT, SrcOps[0]);
49892       if (Mask) {
49893         assert(SrcPartials[0].getBitWidth() == NumElts &&
49894                "Unexpected partial reduction mask");
49895         SDValue PartialBits = DAG.getConstant(SrcPartials[0], dl, MaskVT);
49896         Mask = DAG.getNode(ISD::AND, dl, MaskVT, Mask, PartialBits);
49897         return DAG.getSetCC(dl, MVT::i1, Mask, PartialBits, ISD::SETEQ);
49898       }
49899     }
49900   }
49901 
49902   // InstCombine converts:
49903   //    `(-x << C0) & C1`
49904   // to
49905   //    `(x * (Pow2_Ceil(C1) - (1 << C0))) & C1`
49906   // This saves an IR instruction but on x86 the neg/shift version is preferable
49907   // so undo the transform.
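  // e.g. with C0 = 2 and C1 = 0x3C, InstCombine emits (x * 60) & 0x3C; we
  // rebuild ((0 - x) << 2) & 0x3C here, since 60 == -4 (mod 64).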
49908 
49909   if (N0.getOpcode() == ISD::MUL && N0.hasOneUse()) {
49910     // TODO: We don't actually need a splat for this, we just need the checks to
49911     // hold for each element.
49912     ConstantSDNode *N1C = isConstOrConstSplat(N1, /*AllowUndefs*/ true,
49913                                               /*AllowTruncation*/ false);
49914     ConstantSDNode *N01C =
49915         isConstOrConstSplat(N0.getOperand(1), /*AllowUndefs*/ true,
49916                             /*AllowTruncation*/ false);
49917     if (N1C && N01C) {
49918       const APInt &MulC = N01C->getAPIntValue();
49919       const APInt &AndC = N1C->getAPIntValue();
49920       APInt MulCLowBit = MulC & (-MulC);
49921       if (MulC.uge(AndC) && !MulC.isPowerOf2() &&
49922           (MulCLowBit + MulC).isPowerOf2()) {
49923         SDValue Neg = DAG.getNegative(N0.getOperand(0), dl, VT);
49924         int32_t MulCLowBitLog = MulCLowBit.exactLogBase2();
49925         assert(MulCLowBitLog != -1 &&
49926                "Isolated lowbit is somehow not a power of 2!");
49927         SDValue Shift = DAG.getNode(ISD::SHL, dl, VT, Neg,
49928                                     DAG.getConstant(MulCLowBitLog, dl, VT));
49929         return DAG.getNode(ISD::AND, dl, VT, Shift, N1);
49930       }
49931     }
49932   }
49933 
49934   if (SDValue SetCC = combineAndOrForCcmpCtest(N, DAG, DCI, Subtarget))
49935     return SetCC;
49936 
49937   if (SDValue V = combineScalarAndWithMaskSetcc(N, DAG, Subtarget))
49938     return V;
49939 
49940   if (SDValue R = combineBitOpWithMOVMSK(N, DAG))
49941     return R;
49942 
49943   if (SDValue R = combineBitOpWithShift(N, DAG))
49944     return R;
49945 
49946   if (SDValue R = combineBitOpWithPACK(N, DAG))
49947     return R;
49948 
49949   if (SDValue FPLogic = convertIntLogicToFPLogic(N, DAG, DCI, Subtarget))
49950     return FPLogic;
49951 
49952   if (SDValue R = combineAndShuffleNot(N, DAG, Subtarget))
49953     return R;
49954 
49955   if (DCI.isBeforeLegalizeOps())
49956     return SDValue();
49957 
49958   if (SDValue R = combineCompareEqual(N, DAG, DCI, Subtarget))
49959     return R;
49960 
49961   if (SDValue R = combineAndNotIntoANDNP(N, DAG))
49962     return R;
49963 
49964   if (SDValue ShiftRight = combineAndMaskToShift(N, DAG, Subtarget))
49965     return ShiftRight;
49966 
49967   if (SDValue R = combineAndLoadToBZHI(N, DAG, Subtarget))
49968     return R;
49969 
49970   // fold (and (mul x, c1), c2) -> (mul x, (and c1, c2))
49971   // iff c2 is an all-bits/no-bits mask - i.e. a select-with-zero mask.
49972   // TODO: Handle PMULDQ/PMULUDQ/VPMADDWD/VPMADDUBSW?
49973   if (VT.isVector() && getTargetConstantFromNode(N1)) {
49974     unsigned Opc0 = N0.getOpcode();
49975     if ((Opc0 == ISD::MUL || Opc0 == ISD::MULHU || Opc0 == ISD::MULHS) &&
49976         getTargetConstantFromNode(N0.getOperand(1)) &&
49977         DAG.ComputeNumSignBits(N1) == VT.getScalarSizeInBits() &&
49978         N0->hasOneUse() && N0.getOperand(1)->hasOneUse()) {
49979       SDValue MaskMul = DAG.getNode(ISD::AND, dl, VT, N0.getOperand(1), N1);
49980       return DAG.getNode(Opc0, dl, VT, N0.getOperand(0), MaskMul);
49981     }
49982   }
49983 
49984   // Fold AND(SRL(X,Y),1) -> SETCC(BT(X,Y), COND_B) iff Y is not a constant; this
49985   // avoids a slow variable shift (moving the shift amount to ECX etc.).
49986   if (isOneConstant(N1) && N0->hasOneUse()) {
49987     SDValue Src = N0;
49988     while ((Src.getOpcode() == ISD::ZERO_EXTEND ||
49989             Src.getOpcode() == ISD::TRUNCATE) &&
49990            Src.getOperand(0)->hasOneUse())
49991       Src = Src.getOperand(0);
49992     bool ContainsNOT = false;
49993     X86::CondCode X86CC = X86::COND_B;
49994     // Peek through AND(NOT(SRL(X,Y)),1).
49995     if (isBitwiseNot(Src)) {
49996       Src = Src.getOperand(0);
49997       X86CC = X86::COND_AE;
49998       ContainsNOT = true;
49999     }
50000     if (Src.getOpcode() == ISD::SRL &&
50001         !isa<ConstantSDNode>(Src.getOperand(1))) {
50002       SDValue BitNo = Src.getOperand(1);
50003       Src = Src.getOperand(0);
50004       // Peek through AND(SRL(NOT(X),Y),1).
50005       if (isBitwiseNot(Src)) {
50006         Src = Src.getOperand(0);
50007         X86CC = X86CC == X86::COND_AE ? X86::COND_B : X86::COND_AE;
50008         ContainsNOT = true;
50009       }
50010       // If we have BMI2 then SHRX should be faster for i32/i64 cases.
50011       if (!(Subtarget.hasBMI2() && !ContainsNOT && VT.getSizeInBits() >= 32))
50012         if (SDValue BT = getBT(Src, BitNo, dl, DAG))
50013           return DAG.getZExtOrTrunc(getSETCC(X86CC, BT, dl, DAG), dl, VT);
50014     }
50015   }
50016 
50017   if (VT.isVector() && (VT.getScalarSizeInBits() % 8) == 0) {
50018     // Attempt to recursively combine a bitmask AND with shuffles.
50019     SDValue Op(N, 0);
50020     if (SDValue Res = combineX86ShufflesRecursively(Op, DAG, Subtarget))
50021       return Res;
50022 
50023     // If either operand is a constant mask, then only the elements that aren't
50024     // zero are actually demanded by the other operand.
50025     auto GetDemandedMasks = [&](SDValue Op) {
50026       APInt UndefElts;
50027       SmallVector<APInt> EltBits;
50028       int NumElts = VT.getVectorNumElements();
50029       int EltSizeInBits = VT.getScalarSizeInBits();
50030       APInt DemandedBits = APInt::getAllOnes(EltSizeInBits);
50031       APInt DemandedElts = APInt::getAllOnes(NumElts);
50032       if (getTargetConstantBitsFromNode(Op, EltSizeInBits, UndefElts,
50033                                         EltBits)) {
50034         DemandedBits.clearAllBits();
50035         DemandedElts.clearAllBits();
50036         for (int I = 0; I != NumElts; ++I) {
50037           if (UndefElts[I]) {
50038             // We can't assume an undef src element gives an undef dst - the
50039             // other src might be zero.
50040             DemandedBits.setAllBits();
50041             DemandedElts.setBit(I);
50042           } else if (!EltBits[I].isZero()) {
50043             DemandedBits |= EltBits[I];
50044             DemandedElts.setBit(I);
50045           }
50046         }
50047       }
50048       return std::make_pair(DemandedBits, DemandedElts);
50049     };
50050     APInt Bits0, Elts0;
50051     APInt Bits1, Elts1;
50052     std::tie(Bits0, Elts0) = GetDemandedMasks(N1);
50053     std::tie(Bits1, Elts1) = GetDemandedMasks(N0);
50054 
50055     if (TLI.SimplifyDemandedVectorElts(N0, Elts0, DCI) ||
50056         TLI.SimplifyDemandedVectorElts(N1, Elts1, DCI) ||
50057         TLI.SimplifyDemandedBits(N0, Bits0, Elts0, DCI) ||
50058         TLI.SimplifyDemandedBits(N1, Bits1, Elts1, DCI)) {
50059       if (N->getOpcode() != ISD::DELETED_NODE)
50060         DCI.AddToWorklist(N);
50061       return SDValue(N, 0);
50062     }
50063 
50064     SDValue NewN0 = TLI.SimplifyMultipleUseDemandedBits(N0, Bits0, Elts0, DAG);
50065     SDValue NewN1 = TLI.SimplifyMultipleUseDemandedBits(N1, Bits1, Elts1, DAG);
50066     if (NewN0 || NewN1)
50067       return DAG.getNode(ISD::AND, dl, VT, NewN0 ? NewN0 : N0,
50068                          NewN1 ? NewN1 : N1);
50069   }
50070 
50071   // Attempt to combine a scalar bitmask AND with an extracted shuffle.
50072   if ((VT.getScalarSizeInBits() % 8) == 0 &&
50073       N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
50074       isa<ConstantSDNode>(N0.getOperand(1)) && N0->hasOneUse()) {
50075     SDValue BitMask = N1;
50076     SDValue SrcVec = N0.getOperand(0);
50077     EVT SrcVecVT = SrcVec.getValueType();
50078 
50079     // Check that the constant bitmask masks whole bytes.
50080     APInt UndefElts;
50081     SmallVector<APInt, 64> EltBits;
50082     if (VT == SrcVecVT.getScalarType() && N0->isOnlyUserOf(SrcVec.getNode()) &&
50083         getTargetConstantBitsFromNode(BitMask, 8, UndefElts, EltBits) &&
50084         llvm::all_of(EltBits, [](const APInt &M) {
50085           return M.isZero() || M.isAllOnes();
50086         })) {
50087       unsigned NumElts = SrcVecVT.getVectorNumElements();
50088       unsigned Scale = SrcVecVT.getScalarSizeInBits() / 8;
50089       unsigned Idx = N0.getConstantOperandVal(1);
50090 
50091       // Create a root shuffle mask from the byte mask and the extracted index.
50092       SmallVector<int, 16> ShuffleMask(NumElts * Scale, SM_SentinelUndef);
50093       for (unsigned i = 0; i != Scale; ++i) {
50094         if (UndefElts[i])
50095           continue;
50096         int VecIdx = Scale * Idx + i;
50097         ShuffleMask[VecIdx] = EltBits[i].isZero() ? SM_SentinelZero : VecIdx;
50098       }
50099 
50100       if (SDValue Shuffle = combineX86ShufflesRecursively(
50101               {SrcVec}, 0, SrcVec, ShuffleMask, {}, /*Depth*/ 1,
50102               X86::MaxShuffleCombineDepth,
50103               /*HasVarMask*/ false, /*AllowVarCrossLaneMask*/ true,
50104               /*AllowVarPerLaneMask*/ true, DAG, Subtarget))
50105         return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT, Shuffle,
50106                            N0.getOperand(1));
50107     }
50108   }
50109 
50110   if (SDValue R = combineBMILogicOp(N, DAG, Subtarget))
50111     return R;
50112 
50113   return SDValue();
50114 }
50115 
50116 // Canonicalize OR(AND(X,C),AND(Y,~C)) -> OR(AND(X,C),ANDNP(C,Y))
50117 static SDValue canonicalizeBitSelect(SDNode *N, SelectionDAG &DAG,
50118                                      const X86Subtarget &Subtarget) {
50119   assert(N->getOpcode() == ISD::OR && "Unexpected Opcode");
50120 
50121   MVT VT = N->getSimpleValueType(0);
50122   unsigned EltSizeInBits = VT.getScalarSizeInBits();
50123   if (!VT.isVector() || (EltSizeInBits % 8) != 0)
50124     return SDValue();
50125 
50126   SDValue N0 = peekThroughBitcasts(N->getOperand(0));
50127   SDValue N1 = peekThroughBitcasts(N->getOperand(1));
50128   if (N0.getOpcode() != ISD::AND || N1.getOpcode() != ISD::AND)
50129     return SDValue();
50130 
50131   // On XOP we'll lower to PCMOV so accept one use. With AVX512, we can use
50132   // VPTERNLOG. Otherwise only do this if either mask has multiple uses already.
50133   if (!(Subtarget.hasXOP() || useVPTERNLOG(Subtarget, VT) ||
50134         !N0.getOperand(1).hasOneUse() || !N1.getOperand(1).hasOneUse()))
50135     return SDValue();
50136 
50137   // Attempt to extract constant byte masks.
50138   APInt UndefElts0, UndefElts1;
50139   SmallVector<APInt, 32> EltBits0, EltBits1;
50140   if (!getTargetConstantBitsFromNode(N0.getOperand(1), 8, UndefElts0, EltBits0,
50141                                      /*AllowWholeUndefs*/ false,
50142                                      /*AllowPartialUndefs*/ false))
50143     return SDValue();
50144   if (!getTargetConstantBitsFromNode(N1.getOperand(1), 8, UndefElts1, EltBits1,
50145                                      /*AllowWholeUndefs*/ false,
50146                                      /*AllowPartialUndefs*/ false))
50147     return SDValue();
50148 
50149   for (unsigned i = 0, e = EltBits0.size(); i != e; ++i) {
50150     // TODO - add UNDEF elts support.
50151     if (UndefElts0[i] || UndefElts1[i])
50152       return SDValue();
50153     if (EltBits0[i] != ~EltBits1[i])
50154       return SDValue();
50155   }
50156 
50157   SDLoc DL(N);
50158 
50159   if (useVPTERNLOG(Subtarget, VT)) {
50160     // Emit a VPTERNLOG node directly - 0xCA is the imm code for A?B:C.
50161     // VPTERNLOG is only available for vXi32/vXi64 element types.
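    // Sketch of how 0xCA encodes A?B:C: bit (4*A + 2*B + C) of the immediate
    // is the result. For A=1 the result follows B (bits 7..4 = 1100b) and for
    // A=0 it follows C (bits 3..0 = 1010b), giving 0b11001010 == 0xCA.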
50162     MVT OpSVT = EltSizeInBits <= 32 ? MVT::i32 : MVT::i64;
50163     MVT OpVT =
50164         MVT::getVectorVT(OpSVT, VT.getSizeInBits() / OpSVT.getSizeInBits());
50165     SDValue A = DAG.getBitcast(OpVT, N0.getOperand(1));
50166     SDValue B = DAG.getBitcast(OpVT, N0.getOperand(0));
50167     SDValue C = DAG.getBitcast(OpVT, N1.getOperand(0));
50168     SDValue Imm = DAG.getTargetConstant(0xCA, DL, MVT::i8);
50169     SDValue Res = getAVX512Node(X86ISD::VPTERNLOG, DL, OpVT, {A, B, C, Imm},
50170                                 DAG, Subtarget);
50171     return DAG.getBitcast(VT, Res);
50172   }
50173 
50174   SDValue X = N->getOperand(0);
50175   SDValue Y =
50176       DAG.getNode(X86ISD::ANDNP, DL, VT, DAG.getBitcast(VT, N0.getOperand(1)),
50177                   DAG.getBitcast(VT, N1.getOperand(0)));
50178   return DAG.getNode(ISD::OR, DL, VT, X, Y);
50179 }
50180 
50181 // Try to match OR(AND(~MASK,X),AND(MASK,Y)) logic pattern.
50182 static bool matchLogicBlend(SDNode *N, SDValue &X, SDValue &Y, SDValue &Mask) {
50183   if (N->getOpcode() != ISD::OR)
50184     return false;
50185 
50186   SDValue N0 = N->getOperand(0);
50187   SDValue N1 = N->getOperand(1);
50188 
50189   // Canonicalize AND to LHS.
50190   if (N1.getOpcode() == ISD::AND)
50191     std::swap(N0, N1);
50192 
50193   // Attempt to match OR(AND(M,Y),ANDNP(M,X)).
50194   if (N0.getOpcode() != ISD::AND || N1.getOpcode() != X86ISD::ANDNP)
50195     return false;
50196 
50197   Mask = N1.getOperand(0);
50198   X = N1.getOperand(1);
50199 
50200   // Check to see if the mask appeared in both the AND and ANDNP.
50201   if (N0.getOperand(0) == Mask)
50202     Y = N0.getOperand(1);
50203   else if (N0.getOperand(1) == Mask)
50204     Y = N0.getOperand(0);
50205   else
50206     return false;
50207 
50208   // TODO: Attempt to match against AND(XOR(-1,M),Y) as well; waiting for the
50209   // ANDNP combine allows other combines to happen that then prevent matching.
50210   return true;
50211 }
50212 
50213 // Try to fold:
50214 //   (or (and (m, y), (pandn m, x)))
50215 // into:
50216 //   (vselect m, x, y)
50217 // As a special case, try to fold:
50218 //   (or (and (m, (sub 0, x)), (pandn m, x)))
50219 // into:
50220 //   (sub (xor X, M), M)
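// A quick check of the special case per element (the mask is required to be
// all-ones or all-zeros below): if M == -1 then (X ^ -1) - (-1) == ~X + 1 == -X,
// and if M == 0 then (X ^ 0) - 0 == X, i.e. the blend of (sub 0, x) and x.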
50221 static SDValue combineLogicBlendIntoPBLENDV(SDNode *N, SelectionDAG &DAG,
50222                                             const X86Subtarget &Subtarget) {
50223   assert(N->getOpcode() == ISD::OR && "Unexpected Opcode");
50224 
50225   EVT VT = N->getValueType(0);
50226   if (!((VT.is128BitVector() && Subtarget.hasSSE2()) ||
50227         (VT.is256BitVector() && Subtarget.hasInt256())))
50228     return SDValue();
50229 
50230   SDValue X, Y, Mask;
50231   if (!matchLogicBlend(N, X, Y, Mask))
50232     return SDValue();
50233 
50234   // Validate that X, Y, and Mask are bitcasts, and see through them.
50235   Mask = peekThroughBitcasts(Mask);
50236   X = peekThroughBitcasts(X);
50237   Y = peekThroughBitcasts(Y);
50238 
50239   EVT MaskVT = Mask.getValueType();
50240   unsigned EltBits = MaskVT.getScalarSizeInBits();
50241 
50242   // TODO: Attempt to handle floating point cases as well?
50243   if (!MaskVT.isInteger() || DAG.ComputeNumSignBits(Mask) != EltBits)
50244     return SDValue();
50245 
50246   SDLoc DL(N);
50247 
50248   // Attempt to combine to conditional negate: (sub (xor X, M), M)
50249   if (SDValue Res = combineLogicBlendIntoConditionalNegate(VT, Mask, X, Y, DL,
50250                                                            DAG, Subtarget))
50251     return Res;
50252 
50253   // PBLENDVB is only available on SSE 4.1.
50254   if (!Subtarget.hasSSE41())
50255     return SDValue();
50256 
50257   // If we have VPTERNLOG we should prefer that since PBLENDVB is multiple uops.
50258   if (Subtarget.hasVLX())
50259     return SDValue();
50260 
50261   MVT BlendVT = VT.is256BitVector() ? MVT::v32i8 : MVT::v16i8;
50262 
50263   X = DAG.getBitcast(BlendVT, X);
50264   Y = DAG.getBitcast(BlendVT, Y);
50265   Mask = DAG.getBitcast(BlendVT, Mask);
50266   Mask = DAG.getSelect(DL, BlendVT, Mask, Y, X);
50267   return DAG.getBitcast(VT, Mask);
50268 }
50269 
50270 // Helper function for combineOrCmpEqZeroToCtlzSrl
50271 // Transforms:
50272 //   seteq(cmp x, 0)
50273 //   into:
50274 //   srl(ctlz(x), log2(bitsize(x)))
50275 // Input pattern is checked by caller.
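// E.g. for an i32 input: ctlz(0) == 32 and 32 >> 5 == 1, while for any nonzero
// x, ctlz(x) <= 31 and ctlz(x) >> 5 == 0, so the shift reproduces (x == 0).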
50276 static SDValue lowerX86CmpEqZeroToCtlzSrl(SDValue Op, SelectionDAG &DAG) {
50277   SDValue Cmp = Op.getOperand(1);
50278   EVT VT = Cmp.getOperand(0).getValueType();
50279   unsigned Log2b = Log2_32(VT.getSizeInBits());
50280   SDLoc dl(Op);
50281   SDValue Clz = DAG.getNode(ISD::CTLZ, dl, VT, Cmp->getOperand(0));
50282   // The result of the shift is true or false, and on X86, the 32-bit
50283   // encoding of shr and lzcnt is more desirable.
50284   SDValue Trunc = DAG.getZExtOrTrunc(Clz, dl, MVT::i32);
50285   SDValue Scc = DAG.getNode(ISD::SRL, dl, MVT::i32, Trunc,
50286                             DAG.getConstant(Log2b, dl, MVT::i8));
50287   return Scc;
50288 }
50289 
50290 // Try to transform:
50291 //   zext(or(setcc(eq, (cmp x, 0)), setcc(eq, (cmp y, 0))))
50292 //   into:
50293 //   srl(or(ctlz(x), ctlz(y)), log2(bitsize(x)))
50294 // Will also attempt to match more generic cases, e.g.:
50295 //   zext(or(or(setcc(eq, cmp 0), setcc(eq, cmp 0)), setcc(eq, cmp 0)))
50296 // Only applies if the target supports the FastLZCNT feature.
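// E.g. for two i32 values this computes (x == 0) || (y == 0) as
// (ctlz(x) | ctlz(y)) >> 5: ctlz only reaches 32 (bit 5 set) when its input is
// zero, so the OR has bit 5 set iff either input is zero.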
50297 static SDValue combineOrCmpEqZeroToCtlzSrl(SDNode *N, SelectionDAG &DAG,
50298                                            TargetLowering::DAGCombinerInfo &DCI,
50299                                            const X86Subtarget &Subtarget) {
50300   if (DCI.isBeforeLegalize() || !Subtarget.getTargetLowering()->isCtlzFast())
50301     return SDValue();
50302 
50303   auto isORCandidate = [](SDValue N) {
50304     return (N->getOpcode() == ISD::OR && N->hasOneUse());
50305   };
50306 
50307   // Check that the zero extend is extending to 32 bits or more. The code generated by
50308   // srl(ctlz) for 16-bit or less variants of the pattern would require extra
50309   // instructions to clear the upper bits.
50310   if (!N->hasOneUse() || !N->getSimpleValueType(0).bitsGE(MVT::i32) ||
50311       !isORCandidate(N->getOperand(0)))
50312     return SDValue();
50313 
50314   // Check the node matches: setcc(eq, cmp 0)
50315   auto isSetCCCandidate = [](SDValue N) {
50316     return N->getOpcode() == X86ISD::SETCC && N->hasOneUse() &&
50317            X86::CondCode(N->getConstantOperandVal(0)) == X86::COND_E &&
50318            N->getOperand(1).getOpcode() == X86ISD::CMP &&
50319            isNullConstant(N->getOperand(1).getOperand(1)) &&
50320            N->getOperand(1).getValueType().bitsGE(MVT::i32);
50321   };
50322 
50323   SDNode *OR = N->getOperand(0).getNode();
50324   SDValue LHS = OR->getOperand(0);
50325   SDValue RHS = OR->getOperand(1);
50326 
50327   // Save nodes matching or(or, setcc(eq, cmp 0)).
50328   SmallVector<SDNode *, 2> ORNodes;
50329   while (((isORCandidate(LHS) && isSetCCCandidate(RHS)) ||
50330           (isORCandidate(RHS) && isSetCCCandidate(LHS)))) {
50331     ORNodes.push_back(OR);
50332     OR = (LHS->getOpcode() == ISD::OR) ? LHS.getNode() : RHS.getNode();
50333     LHS = OR->getOperand(0);
50334     RHS = OR->getOperand(1);
50335   }
50336 
50337   // The last OR node should match or(setcc(eq, cmp 0), setcc(eq, cmp 0)).
50338   if (!(isSetCCCandidate(LHS) && isSetCCCandidate(RHS)) ||
50339       !isORCandidate(SDValue(OR, 0)))
50340     return SDValue();
50341 
50342   // We have an or(setcc(eq, cmp 0), setcc(eq, cmp 0)) pattern, try to lower it
50343   // to
50344   // or(srl(ctlz),srl(ctlz)).
50345   // The dag combiner can then fold it into:
50346   // srl(or(ctlz, ctlz)).
50347   SDValue NewLHS = lowerX86CmpEqZeroToCtlzSrl(LHS, DAG);
50348   SDValue Ret, NewRHS;
50349   if (NewLHS && (NewRHS = lowerX86CmpEqZeroToCtlzSrl(RHS, DAG)))
50350     Ret = DAG.getNode(ISD::OR, SDLoc(OR), MVT::i32, NewLHS, NewRHS);
50351 
50352   if (!Ret)
50353     return SDValue();
50354 
50355   // Try to lower nodes matching the or(or, setcc(eq, cmp 0)) pattern.
50356   while (!ORNodes.empty()) {
50357     OR = ORNodes.pop_back_val();
50358     LHS = OR->getOperand(0);
50359     RHS = OR->getOperand(1);
50360     // Swap rhs with lhs to match or(setcc(eq, cmp, 0), or).
50361     if (RHS->getOpcode() == ISD::OR)
50362       std::swap(LHS, RHS);
50363     NewRHS = lowerX86CmpEqZeroToCtlzSrl(RHS, DAG);
50364     if (!NewRHS)
50365       return SDValue();
50366     Ret = DAG.getNode(ISD::OR, SDLoc(OR), MVT::i32, Ret, NewRHS);
50367   }
50368 
50369   return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N->getValueType(0), Ret);
50370 }
50371 
50372 static SDValue foldMaskedMergeImpl(SDValue And0_L, SDValue And0_R,
50373                                    SDValue And1_L, SDValue And1_R,
50374                                    const SDLoc &DL, SelectionDAG &DAG) {
50375   if (!isBitwiseNot(And0_L, true) || !And0_L->hasOneUse())
50376     return SDValue();
50377   SDValue NotOp = And0_L->getOperand(0);
50378   if (NotOp == And1_R)
50379     std::swap(And1_R, And1_L);
50380   if (NotOp != And1_L)
50381     return SDValue();
50382 
50383   // (~(NotOp) & And0_R) | (NotOp & And1_R)
50384   // --> ((And0_R ^ And1_R) & NotOp) ^ And0_R
50385   EVT VT = And1_L->getValueType(0);
50386   SDValue Freeze_And0_R = DAG.getNode(ISD::FREEZE, SDLoc(), VT, And0_R);
50387   SDValue Xor0 = DAG.getNode(ISD::XOR, DL, VT, And1_R, Freeze_And0_R);
50388   SDValue And = DAG.getNode(ISD::AND, DL, VT, Xor0, NotOp);
50389   SDValue Xor1 = DAG.getNode(ISD::XOR, DL, VT, And, Freeze_And0_R);
50390   return Xor1;
50391 }
50392 
50393 /// Fold "masked merge" expressions like `(m & x) | (~m & y)` into the
50394 /// equivalent `((x ^ y) & m) ^ y` pattern.
50395 /// This is typically a better representation for targets without a fused
50396 /// "and-not" operation. This function is intended to be called from a
50397 /// `TargetLowering::PerformDAGCombine` callback on `ISD::OR` nodes.
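/// A quick check of the identity: where a bit of m is 1,
/// ((x ^ y) & m) ^ y == (x ^ y) ^ y == x; where it is 0, the result is
/// 0 ^ y == y. Both forms therefore compute the bitwise select of x and y by m.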
50398 static SDValue foldMaskedMerge(SDNode *Node, SelectionDAG &DAG) {
50399   // Note that masked-merge variants using XOR or ADD expressions are
50400   // normalized to OR by InstCombine so we only check for OR.
50401   assert(Node->getOpcode() == ISD::OR && "Must be called with ISD::OR node");
50402   SDValue N0 = Node->getOperand(0);
50403   if (N0->getOpcode() != ISD::AND || !N0->hasOneUse())
50404     return SDValue();
50405   SDValue N1 = Node->getOperand(1);
50406   if (N1->getOpcode() != ISD::AND || !N1->hasOneUse())
50407     return SDValue();
50408 
50409   SDLoc DL(Node);
50410   SDValue N00 = N0->getOperand(0);
50411   SDValue N01 = N0->getOperand(1);
50412   SDValue N10 = N1->getOperand(0);
50413   SDValue N11 = N1->getOperand(1);
50414   if (SDValue Result = foldMaskedMergeImpl(N00, N01, N10, N11, DL, DAG))
50415     return Result;
50416   if (SDValue Result = foldMaskedMergeImpl(N01, N00, N10, N11, DL, DAG))
50417     return Result;
50418   if (SDValue Result = foldMaskedMergeImpl(N10, N11, N00, N01, DL, DAG))
50419     return Result;
50420   if (SDValue Result = foldMaskedMergeImpl(N11, N10, N00, N01, DL, DAG))
50421     return Result;
50422   return SDValue();
50423 }
50424 
50425 /// If this is an add or subtract where one operand is produced by a cmp+setcc,
50426 /// then try to convert it to an ADC or SBB. This replaces TEST+SET+{ADD/SUB}
50427 /// with CMP+{ADC, SBB}.
50428 /// Also try (ADD/SUB)+(AND(SRL,1)) bit extraction pattern with BT+{ADC, SBB}.
50429 static SDValue combineAddOrSubToADCOrSBB(bool IsSub, const SDLoc &DL, EVT VT,
50430                                          SDValue X, SDValue Y,
50431                                          SelectionDAG &DAG,
50432                                          bool ZeroSecondOpOnly = false) {
50433   if (!DAG.getTargetLoweringInfo().isTypeLegal(VT))
50434     return SDValue();
50435 
50436   // Look through a one-use zext.
50437   if (Y.getOpcode() == ISD::ZERO_EXTEND && Y.hasOneUse())
50438     Y = Y.getOperand(0);
50439 
50440   X86::CondCode CC;
50441   SDValue EFLAGS;
50442   if (Y.getOpcode() == X86ISD::SETCC && Y.hasOneUse()) {
50443     CC = (X86::CondCode)Y.getConstantOperandVal(0);
50444     EFLAGS = Y.getOperand(1);
50445   } else if (Y.getOpcode() == ISD::AND && isOneConstant(Y.getOperand(1)) &&
50446              Y.hasOneUse()) {
50447     EFLAGS = LowerAndToBT(Y, ISD::SETNE, DL, DAG, CC);
50448   }
50449 
50450   if (!EFLAGS)
50451     return SDValue();
50452 
50453   // If X is -1 or 0, then we have an opportunity to avoid constants required in
50454   // the general case below.
50455   auto *ConstantX = dyn_cast<ConstantSDNode>(X);
50456   if (ConstantX && !ZeroSecondOpOnly) {
50457     if ((!IsSub && CC == X86::COND_AE && ConstantX->isAllOnes()) ||
50458         (IsSub && CC == X86::COND_B && ConstantX->isZero())) {
50459       // This is a complicated way to get -1 or 0 from the carry flag:
50460       // -1 + SETAE --> -1 + (!CF) --> CF ? -1 : 0 --> SBB %eax, %eax
50461       //  0 - SETB  -->  0 -  (CF) --> CF ? -1 : 0 --> SBB %eax, %eax
50462       return DAG.getNode(X86ISD::SETCC_CARRY, DL, VT,
50463                          DAG.getTargetConstant(X86::COND_B, DL, MVT::i8),
50464                          EFLAGS);
50465     }
50466 
50467     if ((!IsSub && CC == X86::COND_BE && ConstantX->isAllOnes()) ||
50468         (IsSub && CC == X86::COND_A && ConstantX->isZero())) {
50469       if (EFLAGS.getOpcode() == X86ISD::SUB && EFLAGS.hasOneUse() &&
50470           EFLAGS.getValueType().isInteger() &&
50471           !isa<ConstantSDNode>(EFLAGS.getOperand(1))) {
50472         // Swap the operands of a SUB, and we have the same pattern as above.
50473         // -1 + SETBE (SUB A, B) --> -1 + SETAE (SUB B, A) --> SUB + SBB
50474         //  0 - SETA  (SUB A, B) -->  0 - SETB  (SUB B, A) --> SUB + SBB
50475         SDValue NewSub = DAG.getNode(
50476             X86ISD::SUB, SDLoc(EFLAGS), EFLAGS.getNode()->getVTList(),
50477             EFLAGS.getOperand(1), EFLAGS.getOperand(0));
50478         SDValue NewEFLAGS = SDValue(NewSub.getNode(), EFLAGS.getResNo());
50479         return DAG.getNode(X86ISD::SETCC_CARRY, DL, VT,
50480                            DAG.getTargetConstant(X86::COND_B, DL, MVT::i8),
50481                            NewEFLAGS);
50482       }
50483     }
50484   }
50485 
50486   if (CC == X86::COND_B) {
50487     // X + SETB Z --> adc X, 0
50488     // X - SETB Z --> sbb X, 0
50489     return DAG.getNode(IsSub ? X86ISD::SBB : X86ISD::ADC, DL,
50490                        DAG.getVTList(VT, MVT::i32), X,
50491                        DAG.getConstant(0, DL, VT), EFLAGS);
50492   }
50493 
50494   if (ZeroSecondOpOnly)
50495     return SDValue();
50496 
50497   if (CC == X86::COND_A) {
50498     // Try to convert COND_A into COND_B in an attempt to facilitate
50499     // materializing "setb reg".
50500     //
50501     // Do not flip "e > c", where "c" is a constant, because the Cmp instruction
50502     // cannot take an immediate as its first operand.
50503     //
50504     if (EFLAGS.getOpcode() == X86ISD::SUB && EFLAGS.getNode()->hasOneUse() &&
50505         EFLAGS.getValueType().isInteger() &&
50506         !isa<ConstantSDNode>(EFLAGS.getOperand(1))) {
50507       SDValue NewSub =
50508           DAG.getNode(X86ISD::SUB, SDLoc(EFLAGS), EFLAGS.getNode()->getVTList(),
50509                       EFLAGS.getOperand(1), EFLAGS.getOperand(0));
50510       SDValue NewEFLAGS = NewSub.getValue(EFLAGS.getResNo());
50511       return DAG.getNode(IsSub ? X86ISD::SBB : X86ISD::ADC, DL,
50512                          DAG.getVTList(VT, MVT::i32), X,
50513                          DAG.getConstant(0, DL, VT), NewEFLAGS);
50514     }
50515   }
50516 
50517   if (CC == X86::COND_AE) {
50518     // X + SETAE --> sbb X, -1
50519     // X - SETAE --> adc X, -1
50520     return DAG.getNode(IsSub ? X86ISD::ADC : X86ISD::SBB, DL,
50521                        DAG.getVTList(VT, MVT::i32), X,
50522                        DAG.getConstant(-1, DL, VT), EFLAGS);
50523   }
50524 
50525   if (CC == X86::COND_BE) {
50526     // X + SETBE --> sbb X, -1
50527     // X - SETBE --> adc X, -1
50528     // Try to convert COND_BE into COND_AE in an attempt to facilitate
50529     // materializing "setae reg".
50530     //
50531     // Do not flip "e <= c", where "c" is a constant, because the Cmp instruction
50532     // cannot take an immediate as its first operand.
50533     //
50534     if (EFLAGS.getOpcode() == X86ISD::SUB && EFLAGS.getNode()->hasOneUse() &&
50535         EFLAGS.getValueType().isInteger() &&
50536         !isa<ConstantSDNode>(EFLAGS.getOperand(1))) {
50537       SDValue NewSub =
50538           DAG.getNode(X86ISD::SUB, SDLoc(EFLAGS), EFLAGS.getNode()->getVTList(),
50539                       EFLAGS.getOperand(1), EFLAGS.getOperand(0));
50540       SDValue NewEFLAGS = NewSub.getValue(EFLAGS.getResNo());
50541       return DAG.getNode(IsSub ? X86ISD::ADC : X86ISD::SBB, DL,
50542                          DAG.getVTList(VT, MVT::i32), X,
50543                          DAG.getConstant(-1, DL, VT), NewEFLAGS);
50544     }
50545   }
50546 
50547   if (CC != X86::COND_E && CC != X86::COND_NE)
50548     return SDValue();
50549 
50550   if (EFLAGS.getOpcode() != X86ISD::CMP || !EFLAGS.hasOneUse() ||
50551       !X86::isZeroNode(EFLAGS.getOperand(1)) ||
50552       !EFLAGS.getOperand(0).getValueType().isInteger())
50553     return SDValue();
50554 
50555   SDValue Z = EFLAGS.getOperand(0);
50556   EVT ZVT = Z.getValueType();
50557 
50558   // If X is -1 or 0, then we have an opportunity to avoid constants required in
50559   // the general case below.
50560   if (ConstantX) {
50561     // 'neg' sets the carry flag when Z != 0, so create 0 or -1 using 'sbb' with
50562     // fake operands:
50563     //  0 - (Z != 0) --> sbb %eax, %eax, (neg Z)
50564     // -1 + (Z == 0) --> sbb %eax, %eax, (neg Z)
50565     if ((IsSub && CC == X86::COND_NE && ConstantX->isZero()) ||
50566         (!IsSub && CC == X86::COND_E && ConstantX->isAllOnes())) {
50567       SDValue Zero = DAG.getConstant(0, DL, ZVT);
50568       SDVTList X86SubVTs = DAG.getVTList(ZVT, MVT::i32);
50569       SDValue Neg = DAG.getNode(X86ISD::SUB, DL, X86SubVTs, Zero, Z);
50570       return DAG.getNode(X86ISD::SETCC_CARRY, DL, VT,
50571                          DAG.getTargetConstant(X86::COND_B, DL, MVT::i8),
50572                          SDValue(Neg.getNode(), 1));
50573     }
50574 
50575     // cmp with 1 sets the carry flag when Z == 0, so create 0 or -1 using 'sbb'
50576     // with fake operands:
50577     //  0 - (Z == 0) --> sbb %eax, %eax, (cmp Z, 1)
50578     // -1 + (Z != 0) --> sbb %eax, %eax, (cmp Z, 1)
50579     if ((IsSub && CC == X86::COND_E && ConstantX->isZero()) ||
50580         (!IsSub && CC == X86::COND_NE && ConstantX->isAllOnes())) {
50581       SDValue One = DAG.getConstant(1, DL, ZVT);
50582       SDVTList X86SubVTs = DAG.getVTList(ZVT, MVT::i32);
50583       SDValue Cmp1 = DAG.getNode(X86ISD::SUB, DL, X86SubVTs, Z, One);
50584       return DAG.getNode(X86ISD::SETCC_CARRY, DL, VT,
50585                          DAG.getTargetConstant(X86::COND_B, DL, MVT::i8),
50586                          Cmp1.getValue(1));
50587     }
50588   }
50589 
50590   // (cmp Z, 1) sets the carry flag if Z is 0.
50591   SDValue One = DAG.getConstant(1, DL, ZVT);
50592   SDVTList X86SubVTs = DAG.getVTList(ZVT, MVT::i32);
50593   SDValue Cmp1 = DAG.getNode(X86ISD::SUB, DL, X86SubVTs, Z, One);
50594 
50595   // Add the flags type for ADC/SBB nodes.
50596   SDVTList VTs = DAG.getVTList(VT, MVT::i32);
50597 
50598   // X - (Z != 0) --> sub X, (zext(setne Z, 0)) --> adc X, -1, (cmp Z, 1)
50599   // X + (Z != 0) --> add X, (zext(setne Z, 0)) --> sbb X, -1, (cmp Z, 1)
50600   if (CC == X86::COND_NE)
50601     return DAG.getNode(IsSub ? X86ISD::ADC : X86ISD::SBB, DL, VTs, X,
50602                        DAG.getConstant(-1ULL, DL, VT), Cmp1.getValue(1));
50603 
50604   // X - (Z == 0) --> sub X, (zext(sete  Z, 0)) --> sbb X, 0, (cmp Z, 1)
50605   // X + (Z == 0) --> add X, (zext(sete  Z, 0)) --> adc X, 0, (cmp Z, 1)
50606   return DAG.getNode(IsSub ? X86ISD::SBB : X86ISD::ADC, DL, VTs, X,
50607                      DAG.getConstant(0, DL, VT), Cmp1.getValue(1));
50608 }
50609 
50610 /// If this is an add or subtract where one operand is produced by a cmp+setcc,
50611 /// then try to convert it to an ADC or SBB. This replaces TEST+SET+{ADD/SUB}
50612 /// with CMP+{ADC, SBB}.
50613 static SDValue combineAddOrSubToADCOrSBB(SDNode *N, const SDLoc &DL,
50614                                          SelectionDAG &DAG) {
50615   bool IsSub = N->getOpcode() == ISD::SUB;
50616   SDValue X = N->getOperand(0);
50617   SDValue Y = N->getOperand(1);
50618   EVT VT = N->getValueType(0);
50619 
50620   if (SDValue ADCOrSBB = combineAddOrSubToADCOrSBB(IsSub, DL, VT, X, Y, DAG))
50621     return ADCOrSBB;
50622 
50623   // Commute and try again (negate the result for subtracts).
50624   if (SDValue ADCOrSBB = combineAddOrSubToADCOrSBB(IsSub, DL, VT, Y, X, DAG)) {
50625     if (IsSub)
50626       ADCOrSBB = DAG.getNegative(ADCOrSBB, DL, VT);
50627     return ADCOrSBB;
50628   }
50629 
50630   return SDValue();
50631 }
50632 
50633 static SDValue combineOrXorWithSETCC(SDNode *N, SDValue N0, SDValue N1,
50634                                      SelectionDAG &DAG) {
50635   assert((N->getOpcode() == ISD::XOR || N->getOpcode() == ISD::OR) &&
50636          "Unexpected opcode");
50637 
50638   // Delegate to combineAddOrSubToADCOrSBB if we have:
50639   //
50640   //   (xor/or (zero_extend (setcc)) imm)
50641   //
50642   // where imm is odd if and only if we have xor, in which case the XOR/OR are
50643   // equivalent to a SUB/ADD, respectively.
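  // Sketch of why the parity condition works, with b = zext(setcc) in {0,1}
  // and illustrative immediates: (or b, 4) == 4 + b because bit 0 of an even
  // immediate is free, and (xor b, 5) == 5 - b because only bit 0 changes.
  // Those are exactly the ADD/SUB forms combineAddOrSubToADCOrSBB handles.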
50644   if (N0.getOpcode() == ISD::ZERO_EXTEND &&
50645       N0.getOperand(0).getOpcode() == X86ISD::SETCC && N0.hasOneUse()) {
50646     if (auto *N1C = dyn_cast<ConstantSDNode>(N1)) {
50647       bool IsSub = N->getOpcode() == ISD::XOR;
50648       bool N1COdd = N1C->getZExtValue() & 1;
50649       if (IsSub ? N1COdd : !N1COdd) {
50650         SDLoc DL(N);
50651         EVT VT = N->getValueType(0);
50652         if (SDValue R = combineAddOrSubToADCOrSBB(IsSub, DL, VT, N1, N0, DAG))
50653           return R;
50654       }
50655     }
50656   }
50657 
50658   // not(pcmpeq(and(X,CstPow2),0)) -> pcmpeq(and(X,CstPow2),CstPow2)
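  // This relies on (X & P) only ever being 0 or P when P is a power of 2, so
  // testing "!= 0" is the same as testing "== P".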
50659   if (N->getOpcode() == ISD::XOR && N0.getOpcode() == X86ISD::PCMPEQ &&
50660       N0.getOperand(0).getOpcode() == ISD::AND &&
50661       ISD::isBuildVectorAllZeros(N0.getOperand(1).getNode()) &&
50662       ISD::isBuildVectorAllOnes(N1.getNode())) {
50663     MVT VT = N->getSimpleValueType(0);
50664     APInt UndefElts;
50665     SmallVector<APInt> EltBits;
50666     if (getTargetConstantBitsFromNode(N0.getOperand(0).getOperand(1),
50667                                       VT.getScalarSizeInBits(), UndefElts,
50668                                       EltBits)) {
50669       bool IsPow2OrUndef = true;
50670       for (unsigned I = 0, E = EltBits.size(); I != E; ++I)
50671         IsPow2OrUndef &= UndefElts[I] || EltBits[I].isPowerOf2();
50672 
50673       if (IsPow2OrUndef)
50674         return DAG.getNode(X86ISD::PCMPEQ, SDLoc(N), VT, N0.getOperand(0),
50675                            N0.getOperand(0).getOperand(1));
50676     }
50677   }
50678 
50679   return SDValue();
50680 }
50681 
50682 static SDValue combineOr(SDNode *N, SelectionDAG &DAG,
50683                          TargetLowering::DAGCombinerInfo &DCI,
50684                          const X86Subtarget &Subtarget) {
50685   SDValue N0 = N->getOperand(0);
50686   SDValue N1 = N->getOperand(1);
50687   EVT VT = N->getValueType(0);
50688   SDLoc dl(N);
50689   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
50690 
50691   // If this is SSE1 only convert to FOR to avoid scalarization.
50692   if (Subtarget.hasSSE1() && !Subtarget.hasSSE2() && VT == MVT::v4i32) {
50693     return DAG.getBitcast(MVT::v4i32,
50694                           DAG.getNode(X86ISD::FOR, dl, MVT::v4f32,
50695                                       DAG.getBitcast(MVT::v4f32, N0),
50696                                       DAG.getBitcast(MVT::v4f32, N1)));
50697   }
50698 
50699   // Match any-of bool scalar reductions into a bitcast/movmsk + cmp.
50700   // TODO: Support multiple SrcOps.
50701   if (VT == MVT::i1) {
50702     SmallVector<SDValue, 2> SrcOps;
50703     SmallVector<APInt, 2> SrcPartials;
50704     if (matchScalarReduction(SDValue(N, 0), ISD::OR, SrcOps, &SrcPartials) &&
50705         SrcOps.size() == 1) {
50706       unsigned NumElts = SrcOps[0].getValueType().getVectorNumElements();
50707       EVT MaskVT = EVT::getIntegerVT(*DAG.getContext(), NumElts);
50708       SDValue Mask = combineBitcastvxi1(DAG, MaskVT, SrcOps[0], dl, Subtarget);
50709       if (!Mask && TLI.isTypeLegal(SrcOps[0].getValueType()))
50710         Mask = DAG.getBitcast(MaskVT, SrcOps[0]);
50711       if (Mask) {
50712         assert(SrcPartials[0].getBitWidth() == NumElts &&
50713                "Unexpected partial reduction mask");
50714         SDValue ZeroBits = DAG.getConstant(0, dl, MaskVT);
50715         SDValue PartialBits = DAG.getConstant(SrcPartials[0], dl, MaskVT);
50716         Mask = DAG.getNode(ISD::AND, dl, MaskVT, Mask, PartialBits);
50717         return DAG.getSetCC(dl, MVT::i1, Mask, ZeroBits, ISD::SETNE);
50718       }
50719     }
50720   }
50721 
50722   if (SDValue SetCC = combineAndOrForCcmpCtest(N, DAG, DCI, Subtarget))
50723     return SetCC;
50724 
50725   if (SDValue R = combineBitOpWithMOVMSK(N, DAG))
50726     return R;
50727 
50728   if (SDValue R = combineBitOpWithShift(N, DAG))
50729     return R;
50730 
50731   if (SDValue R = combineBitOpWithPACK(N, DAG))
50732     return R;
50733 
50734   if (SDValue FPLogic = convertIntLogicToFPLogic(N, DAG, DCI, Subtarget))
50735     return FPLogic;
50736 
50737   if (DCI.isBeforeLegalizeOps())
50738     return SDValue();
50739 
50740   if (SDValue R = combineCompareEqual(N, DAG, DCI, Subtarget))
50741     return R;
50742 
50743   if (SDValue R = canonicalizeBitSelect(N, DAG, Subtarget))
50744     return R;
50745 
50746   if (SDValue R = combineLogicBlendIntoPBLENDV(N, DAG, Subtarget))
50747     return R;
50748 
50749   // (0 - SetCC) | C -> (zext (not SetCC)) * (C + 1) - 1 if we can get a LEA out of it.
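  // Quick check with SetCC in {0,1}: if SetCC == 1 both sides are -1; if
  // SetCC == 0 both sides are C. The C values accepted below make C + 1 one of
  // {2,3,4,5,8,9}, i.e. multiplications that a single LEA can form.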
50750   if ((VT == MVT::i32 || VT == MVT::i64) &&
50751       N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
50752       isNullConstant(N0.getOperand(0))) {
50753     SDValue Cond = N0.getOperand(1);
50754     if (Cond.getOpcode() == ISD::ZERO_EXTEND && Cond.hasOneUse())
50755       Cond = Cond.getOperand(0);
50756 
50757     if (Cond.getOpcode() == X86ISD::SETCC && Cond.hasOneUse()) {
50758       if (auto *CN = dyn_cast<ConstantSDNode>(N1)) {
50759         uint64_t Val = CN->getZExtValue();
50760         if (Val == 1 || Val == 2 || Val == 3 || Val == 4 || Val == 7 || Val == 8) {
50761           X86::CondCode CCode = (X86::CondCode)Cond.getConstantOperandVal(0);
50762           CCode = X86::GetOppositeBranchCondition(CCode);
50763           SDValue NotCond = getSETCC(CCode, Cond.getOperand(1), SDLoc(Cond), DAG);
50764 
50765           SDValue R = DAG.getZExtOrTrunc(NotCond, dl, VT);
50766           R = DAG.getNode(ISD::MUL, dl, VT, R, DAG.getConstant(Val + 1, dl, VT));
50767           R = DAG.getNode(ISD::SUB, dl, VT, R, DAG.getConstant(1, dl, VT));
50768           return R;
50769         }
50770       }
50771     }
50772   }
50773 
50774   // Combine OR(X,KSHIFTL(Y,Elts/2)) -> CONCAT_VECTORS(X,Y) == KUNPCK(X,Y).
50775   // Combine OR(KSHIFTL(X,Elts/2),Y) -> CONCAT_VECTORS(Y,X) == KUNPCK(Y,X).
50776   // iff the upper elements of the non-shifted arg are zero.
50777   // KUNPCK requires 16+ bool vector elements.
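  // E.g. for v32i1 (a sketch): OR(X, KSHIFTL(Y, 16)) with the top 16 elements
  // of X known zero becomes CONCAT_VECTORS(X[0..15], Y[0..15]), i.e. a single
  // KUNPCKWD of the two halves.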
50778   if (N0.getOpcode() == X86ISD::KSHIFTL || N1.getOpcode() == X86ISD::KSHIFTL) {
50779     unsigned NumElts = VT.getVectorNumElements();
50780     unsigned HalfElts = NumElts / 2;
50781     APInt UpperElts = APInt::getHighBitsSet(NumElts, HalfElts);
50782     if (NumElts >= 16 && N1.getOpcode() == X86ISD::KSHIFTL &&
50783         N1.getConstantOperandAPInt(1) == HalfElts &&
50784         DAG.MaskedVectorIsZero(N0, UpperElts)) {
50785       return DAG.getNode(
50786           ISD::CONCAT_VECTORS, dl, VT,
50787           extractSubVector(N0, 0, DAG, dl, HalfElts),
50788           extractSubVector(N1.getOperand(0), 0, DAG, dl, HalfElts));
50789     }
50790     if (NumElts >= 16 && N0.getOpcode() == X86ISD::KSHIFTL &&
50791         N0.getConstantOperandAPInt(1) == HalfElts &&
50792         DAG.MaskedVectorIsZero(N1, UpperElts)) {
50793       return DAG.getNode(
50794           ISD::CONCAT_VECTORS, dl, VT,
50795           extractSubVector(N1, 0, DAG, dl, HalfElts),
50796           extractSubVector(N0.getOperand(0), 0, DAG, dl, HalfElts));
50797     }
50798   }
50799 
50800   if (VT.isVector() && (VT.getScalarSizeInBits() % 8) == 0) {
50801     // Attempt to recursively combine an OR of shuffles.
50802     SDValue Op(N, 0);
50803     if (SDValue Res = combineX86ShufflesRecursively(Op, DAG, Subtarget))
50804       return Res;
50805 
50806     // If either operand is a constant mask, then only the elements that aren't
50807     // allones are actually demanded by the other operand.
50808     auto SimplifyUndemandedElts = [&](SDValue Op, SDValue OtherOp) {
50809       APInt UndefElts;
50810       SmallVector<APInt> EltBits;
50811       int NumElts = VT.getVectorNumElements();
50812       int EltSizeInBits = VT.getScalarSizeInBits();
50813       if (!getTargetConstantBitsFromNode(Op, EltSizeInBits, UndefElts, EltBits))
50814         return false;
50815 
50816       APInt DemandedElts = APInt::getZero(NumElts);
50817       for (int I = 0; I != NumElts; ++I)
50818         if (!EltBits[I].isAllOnes())
50819           DemandedElts.setBit(I);
50820 
50821       return TLI.SimplifyDemandedVectorElts(OtherOp, DemandedElts, DCI);
50822     };
50823     if (SimplifyUndemandedElts(N0, N1) || SimplifyUndemandedElts(N1, N0)) {
50824       if (N->getOpcode() != ISD::DELETED_NODE)
50825         DCI.AddToWorklist(N);
50826       return SDValue(N, 0);
50827     }
50828   }
50829 
50830   // We should fold "masked merge" patterns when `andn` is not available.
50831   if (!Subtarget.hasBMI() && VT.isScalarInteger() && VT != MVT::i1)
50832     if (SDValue R = foldMaskedMerge(N, DAG))
50833       return R;
50834 
50835   if (SDValue R = combineOrXorWithSETCC(N, N0, N1, DAG))
50836     return R;
50837 
50838   return SDValue();
50839 }
50840 
50841 /// Try to turn tests against the signbit in the form of:
50842 ///   XOR(TRUNCATE(SRL(X, size(X)-1)), 1)
50843 /// into:
50844 ///   SETGT(X, -1)
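/// For example, with an i32 X: SRL(X, 31) extracts the sign bit, and XORing it
/// with 1 yields 1 exactly when X is non-negative, which is the same predicate
/// as the signed comparison X > -1.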
50845 static SDValue foldXorTruncShiftIntoCmp(SDNode *N, SelectionDAG &DAG) {
50846   // This is only worth doing if the output type is i8 or i1.
50847   EVT ResultType = N->getValueType(0);
50848   if (ResultType != MVT::i8 && ResultType != MVT::i1)
50849     return SDValue();
50850 
50851   SDValue N0 = N->getOperand(0);
50852   SDValue N1 = N->getOperand(1);
50853 
50854   // We should be performing an xor against a truncated shift.
50855   if (N0.getOpcode() != ISD::TRUNCATE || !N0.hasOneUse())
50856     return SDValue();
50857 
50858   // Make sure we are performing an xor against one.
50859   if (!isOneConstant(N1))
50860     return SDValue();
50861 
50862   // SetCC on x86 zero extends so only act on this if it's a logical shift.
50863   SDValue Shift = N0.getOperand(0);
50864   if (Shift.getOpcode() != ISD::SRL || !Shift.hasOneUse())
50865     return SDValue();
50866 
50867   // Make sure we are truncating from one of i16, i32 or i64.
50868   EVT ShiftTy = Shift.getValueType();
50869   if (ShiftTy != MVT::i16 && ShiftTy != MVT::i32 && ShiftTy != MVT::i64)
50870     return SDValue();
50871 
50872   // Make sure the shift amount extracts the sign bit.
50873   if (!isa<ConstantSDNode>(Shift.getOperand(1)) ||
50874       Shift.getConstantOperandAPInt(1) != (ShiftTy.getSizeInBits() - 1))
50875     return SDValue();
50876 
50877   // Create a greater-than comparison against -1.
50878   // N.B. Using SETGE against 0 works but we want a canonical looking
50879   // comparison, using SETGT matches up with what TranslateX86CC.
50880   SDLoc DL(N);
50881   SDValue ShiftOp = Shift.getOperand(0);
50882   EVT ShiftOpTy = ShiftOp.getValueType();
50883   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
50884   EVT SetCCResultType = TLI.getSetCCResultType(DAG.getDataLayout(),
50885                                                *DAG.getContext(), ResultType);
50886   SDValue Cond = DAG.getSetCC(DL, SetCCResultType, ShiftOp,
50887                               DAG.getConstant(-1, DL, ShiftOpTy), ISD::SETGT);
50888   if (SetCCResultType != ResultType)
50889     Cond = DAG.getNode(ISD::ZERO_EXTEND, DL, ResultType, Cond);
50890   return Cond;
50891 }
50892 
50893 /// Turn vector tests of the signbit in the form of:
50894 ///   xor (sra X, elt_size(X)-1), -1
50895 /// into:
50896 ///   pcmpgt X, -1
50897 ///
50898 /// This should be called before type legalization because the pattern may not
50899 /// persist after that.
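/// The sra smears the sign bit, so each element becomes 0 (non-negative) or -1
/// (negative); XORing with all-ones then selects the non-negative lanes, which
/// is exactly what pcmpgt X, -1 computes.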
50900 static SDValue foldVectorXorShiftIntoCmp(SDNode *N, SelectionDAG &DAG,
50901                                          const X86Subtarget &Subtarget) {
50902   EVT VT = N->getValueType(0);
50903   if (!VT.isSimple())
50904     return SDValue();
50905 
50906   switch (VT.getSimpleVT().SimpleTy) {
50907   // clang-format off
50908   default: return SDValue();
50909   case MVT::v16i8:
50910   case MVT::v8i16:
50911   case MVT::v4i32:
50912   case MVT::v2i64: if (!Subtarget.hasSSE2()) return SDValue(); break;
50913   case MVT::v32i8:
50914   case MVT::v16i16:
50915   case MVT::v8i32:
50916   case MVT::v4i64: if (!Subtarget.hasAVX2()) return SDValue(); break;
50917     // clang-format on
50918   }
50919 
50920   // There must be a shift right algebraic before the xor, and the xor must be a
50921   // 'not' operation.
50922   SDValue Shift = N->getOperand(0);
50923   SDValue Ones = N->getOperand(1);
50924   if (Shift.getOpcode() != ISD::SRA || !Shift.hasOneUse() ||
50925       !ISD::isBuildVectorAllOnes(Ones.getNode()))
50926     return SDValue();
50927 
50928   // The shift should be smearing the sign bit across each vector element.
50929   auto *ShiftAmt =
50930       isConstOrConstSplat(Shift.getOperand(1), /*AllowUndefs*/ true);
50931   if (!ShiftAmt ||
50932       ShiftAmt->getAPIntValue() != (Shift.getScalarValueSizeInBits() - 1))
50933     return SDValue();
50934 
50935   // Create a greater-than comparison against -1. We don't use the more obvious
50936   // greater-than-or-equal-to-zero because SSE/AVX don't have that instruction.
50937   return DAG.getSetCC(SDLoc(N), VT, Shift.getOperand(0), Ones, ISD::SETGT);
50938 }
50939 
50940 /// Detect patterns of truncation with unsigned saturation:
50941 ///
50942 /// 1. (truncate (umin (x, unsigned_max_of_dest_type)) to dest_type).
50943 ///   Return the source value x to be truncated or SDValue() if the pattern was
50944 ///   not matched.
50945 ///
50946 /// 2. (truncate (smin (smax (x, C1), C2)) to dest_type),
50947 ///   where C1 >= 0 and C2 is unsigned max of destination type.
50948 ///
50949 ///    (truncate (smax (smin (x, C2), C1)) to dest_type)
50950 ///   where C1 >= 0, C2 is unsigned max of destination type and C1 <= C2.
50951 ///
50952 ///   These two patterns are equivalent to:
50953 ///   (truncate (umin (smax(x, C1), unsigned_max_of_dest_type)) to dest_type)
50954 ///   So return the smax(x, C1) value to be truncated or SDValue() if the
50955 ///   pattern was not matched.
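/// For instance, truncating i32 elements to i16 with unsigned saturation
/// matches pattern 1 as (truncate (umin x, 65535)), since 65535 is the i16
/// unsigned max (a 16-bit mask).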
50956 static SDValue detectUSatPattern(SDValue In, EVT VT, SelectionDAG &DAG,
50957                                  const SDLoc &DL) {
50958   EVT InVT = In.getValueType();
50959 
50960   // Saturation with truncation. We truncate from InVT to VT.
50961   assert(InVT.getScalarSizeInBits() > VT.getScalarSizeInBits() &&
50962          "Unexpected types for truncate operation");
50963 
50964   // Match min/max and return limit value as a parameter.
50965   auto MatchMinMax = [](SDValue V, unsigned Opcode, APInt &Limit) -> SDValue {
50966     if (V.getOpcode() == Opcode &&
50967         ISD::isConstantSplatVector(V.getOperand(1).getNode(), Limit))
50968       return V.getOperand(0);
50969     return SDValue();
50970   };
50971 
50972   APInt C1, C2;
50973   if (SDValue UMin = MatchMinMax(In, ISD::UMIN, C2))
50974     // C2 should be equal to UINT32_MAX / UINT16_MAX / UINT8_MAX according to
50975     // the element size of the destination type.
50976     if (C2.isMask(VT.getScalarSizeInBits()))
50977       return UMin;
50978 
50979   if (SDValue SMin = MatchMinMax(In, ISD::SMIN, C2))
50980     if (MatchMinMax(SMin, ISD::SMAX, C1))
50981       if (C1.isNonNegative() && C2.isMask(VT.getScalarSizeInBits()))
50982         return SMin;
50983 
50984   if (SDValue SMax = MatchMinMax(In, ISD::SMAX, C1))
50985     if (SDValue SMin = MatchMinMax(SMax, ISD::SMIN, C2))
50986       if (C1.isNonNegative() && C2.isMask(VT.getScalarSizeInBits()) &&
50987           C2.uge(C1)) {
50988         return DAG.getNode(ISD::SMAX, DL, InVT, SMin, In.getOperand(1));
50989       }
50990 
50991   return SDValue();
50992 }
50993 
50994 /// Detect patterns of truncation with signed saturation:
50995 /// (truncate (smin ((smax (x, signed_min_of_dest_type)),
50996 ///                  signed_max_of_dest_type)) to dest_type)
50997 /// or:
50998 /// (truncate (smax ((smin (x, signed_max_of_dest_type)),
50999 ///                  signed_min_of_dest_type)) to dest_type).
51000 /// With MatchPackUS, the smax/smin range is [0, unsigned_max_of_dest_type].
51001 /// Return the source value to be truncated or SDValue() if the pattern was not
51002 /// matched.
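/// For instance, an i32 -> i8 signed saturate matches as
/// (smin (smax x, -128), 127) (or the commuted smax/smin form), while with
/// MatchPackUS the clamp range is [0, 255] instead.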
51003 static SDValue detectSSatPattern(SDValue In, EVT VT, bool MatchPackUS = false) {
51004   unsigned NumDstBits = VT.getScalarSizeInBits();
51005   unsigned NumSrcBits = In.getScalarValueSizeInBits();
51006   assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
51007 
51008   auto MatchMinMax = [](SDValue V, unsigned Opcode,
51009                         const APInt &Limit) -> SDValue {
51010     APInt C;
51011     if (V.getOpcode() == Opcode &&
51012         ISD::isConstantSplatVector(V.getOperand(1).getNode(), C) && C == Limit)
51013       return V.getOperand(0);
51014     return SDValue();
51015   };
51016 
51017   APInt SignedMax, SignedMin;
51018   if (MatchPackUS) {
51019     SignedMax = APInt::getAllOnes(NumDstBits).zext(NumSrcBits);
51020     SignedMin = APInt(NumSrcBits, 0);
51021   } else {
51022     SignedMax = APInt::getSignedMaxValue(NumDstBits).sext(NumSrcBits);
51023     SignedMin = APInt::getSignedMinValue(NumDstBits).sext(NumSrcBits);
51024   }
51025 
51026   if (SDValue SMin = MatchMinMax(In, ISD::SMIN, SignedMax))
51027     if (SDValue SMax = MatchMinMax(SMin, ISD::SMAX, SignedMin))
51028       return SMax;
51029 
51030   if (SDValue SMax = MatchMinMax(In, ISD::SMAX, SignedMin))
51031     if (SDValue SMin = MatchMinMax(SMax, ISD::SMIN, SignedMax))
51032       return SMin;
51033 
51034   return SDValue();
51035 }
51036 
51037 static SDValue combineTruncateWithSat(SDValue In, EVT VT, const SDLoc &DL,
51038                                       SelectionDAG &DAG,
51039                                       const X86Subtarget &Subtarget) {
51040   if (!Subtarget.hasSSE2() || !VT.isVector())
51041     return SDValue();
51042 
51043   EVT SVT = VT.getVectorElementType();
51044   EVT InVT = In.getValueType();
51045   EVT InSVT = InVT.getVectorElementType();
51046 
51047   // If we're clamping a signed 32-bit vector to 0-255 and the 32-bit vector is
51048   // split across two registers, we can use a packusdw+perm to clamp to 0-65535
51049   // and concatenate at the same time. Then we can use a final vpmovuswb to
51050   // clip to 0-255.
51051   if (Subtarget.hasBWI() && !Subtarget.useAVX512Regs() &&
51052       InVT == MVT::v16i32 && VT == MVT::v16i8) {
51053     if (SDValue USatVal = detectSSatPattern(In, VT, true)) {
51054       // Emit a VPACKUSDW+VPERMQ followed by a VPMOVUSWB.
51055       SDValue Mid = truncateVectorWithPACK(X86ISD::PACKUS, MVT::v16i16, USatVal,
51056                                            DL, DAG, Subtarget);
51057       assert(Mid && "Failed to pack!");
51058       return DAG.getNode(X86ISD::VTRUNCUS, DL, VT, Mid);
51059     }
51060   }
51061 
51062   // vXi32 truncate instructions are available with AVX512F.
51063   // vXi16 truncate instructions are only available with AVX512BW.
51064   // For 256-bit or smaller vectors, we require VLX.
51065   // FIXME: We could widen truncates to 512 to remove the VLX restriction.
51066   // If the result type is 256 bits or larger and we have disabled 512-bit
51067   // registers, we should go ahead and use the pack instructions if possible.
51068   bool PreferAVX512 = ((Subtarget.hasAVX512() && InSVT == MVT::i32) ||
51069                        (Subtarget.hasBWI() && InSVT == MVT::i16)) &&
51070                       (InVT.getSizeInBits() > 128) &&
51071                       (Subtarget.hasVLX() || InVT.getSizeInBits() > 256) &&
51072                       !(!Subtarget.useAVX512Regs() && VT.getSizeInBits() >= 256);
51073 
51074   if (!PreferAVX512 && VT.getVectorNumElements() > 1 &&
51075       isPowerOf2_32(VT.getVectorNumElements()) &&
51076       (SVT == MVT::i8 || SVT == MVT::i16) &&
51077       (InSVT == MVT::i16 || InSVT == MVT::i32)) {
51078     if (SDValue USatVal = detectSSatPattern(In, VT, true)) {
51079       // vXi32 -> vXi8 must be performed as PACKUSWB(PACKSSDW,PACKSSDW).
51080       if (SVT == MVT::i8 && InSVT == MVT::i32) {
51081         EVT MidVT = VT.changeVectorElementType(MVT::i16);
51082         SDValue Mid = truncateVectorWithPACK(X86ISD::PACKSS, MidVT, USatVal, DL,
51083                                              DAG, Subtarget);
51084         assert(Mid && "Failed to pack!");
51085         SDValue V = truncateVectorWithPACK(X86ISD::PACKUS, VT, Mid, DL, DAG,
51086                                            Subtarget);
51087         assert(V && "Failed to pack!");
51088         return V;
51089       } else if (SVT == MVT::i8 || Subtarget.hasSSE41())
51090         return truncateVectorWithPACK(X86ISD::PACKUS, VT, USatVal, DL, DAG,
51091                                       Subtarget);
51092     }
51093     if (SDValue SSatVal = detectSSatPattern(In, VT))
51094       return truncateVectorWithPACK(X86ISD::PACKSS, VT, SSatVal, DL, DAG,
51095                                     Subtarget);
51096   }
51097 
51098   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
51099   if (TLI.isTypeLegal(InVT) && InVT.isVector() && SVT != MVT::i1 &&
51100       Subtarget.hasAVX512() && (InSVT != MVT::i16 || Subtarget.hasBWI()) &&
51101       (SVT == MVT::i32 || SVT == MVT::i16 || SVT == MVT::i8)) {
51102     unsigned TruncOpc = 0;
51103     SDValue SatVal;
51104     if (SDValue SSatVal = detectSSatPattern(In, VT)) {
51105       SatVal = SSatVal;
51106       TruncOpc = X86ISD::VTRUNCS;
51107     } else if (SDValue USatVal = detectUSatPattern(In, VT, DAG, DL)) {
51108       SatVal = USatVal;
51109       TruncOpc = X86ISD::VTRUNCUS;
51110     }
51111     if (SatVal) {
51112       unsigned ResElts = VT.getVectorNumElements();
51113       // If the input type is less than 512 bits and we don't have VLX, we need
51114       // to widen to 512 bits.
51115       if (!Subtarget.hasVLX() && !InVT.is512BitVector()) {
51116         unsigned NumConcats = 512 / InVT.getSizeInBits();
51117         ResElts *= NumConcats;
51118         SmallVector<SDValue, 4> ConcatOps(NumConcats, DAG.getUNDEF(InVT));
51119         ConcatOps[0] = SatVal;
51120         InVT = EVT::getVectorVT(*DAG.getContext(), InSVT,
51121                                 NumConcats * InVT.getVectorNumElements());
51122         SatVal = DAG.getNode(ISD::CONCAT_VECTORS, DL, InVT, ConcatOps);
51123       }
51124       // Widen the result if its narrower than 128 bits.
51125       if (ResElts * SVT.getSizeInBits() < 128)
51126         ResElts = 128 / SVT.getSizeInBits();
51127       EVT TruncVT = EVT::getVectorVT(*DAG.getContext(), SVT, ResElts);
51128       SDValue Res = DAG.getNode(TruncOpc, DL, TruncVT, SatVal);
51129       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Res,
51130                          DAG.getIntPtrConstant(0, DL));
51131     }
51132   }
51133 
51134   return SDValue();
51135 }
51136 
51137 static SDValue combineConstantPoolLoads(SDNode *N, const SDLoc &dl,
51138                                         SelectionDAG &DAG,
51139                                         TargetLowering::DAGCombinerInfo &DCI,
51140                                         const X86Subtarget &Subtarget) {
51141   auto *Ld = cast<LoadSDNode>(N);
51142   EVT RegVT = Ld->getValueType(0);
51143   SDValue Ptr = Ld->getBasePtr();
51144   SDValue Chain = Ld->getChain();
51145   ISD::LoadExtType Ext = Ld->getExtensionType();
51146 
51147   if (Ext != ISD::NON_EXTLOAD || !Subtarget.hasAVX() || !Ld->isSimple())
51148     return SDValue();
51149 
51150   if (!(RegVT.is128BitVector() || RegVT.is256BitVector()))
51151     return SDValue();
51152 
51153   const Constant *LdC = getTargetConstantFromBasePtr(Ptr);
51154   if (!LdC)
51155     return SDValue();
51156 
51157   auto MatchingBits = [](const APInt &Undefs, const APInt &UserUndefs,
51158                          ArrayRef<APInt> Bits, ArrayRef<APInt> UserBits) {
51159     for (unsigned I = 0, E = Undefs.getBitWidth(); I != E; ++I) {
51160       if (Undefs[I])
51161         continue;
51162       if (UserUndefs[I] || Bits[I] != UserBits[I])
51163         return false;
51164     }
51165     return true;
51166   };
51167 
51168   // Look through all other loads/broadcasts in the chain for another constant
51169   // pool entry.
51170   for (SDNode *User : Chain->uses()) {
51171     auto *UserLd = dyn_cast<MemSDNode>(User);
51172     if (User != N && UserLd &&
51173         (User->getOpcode() == X86ISD::SUBV_BROADCAST_LOAD ||
51174          User->getOpcode() == X86ISD::VBROADCAST_LOAD ||
51175          ISD::isNormalLoad(User)) &&
51176         UserLd->getChain() == Chain && !User->hasAnyUseOfValue(1) &&
51177         User->getValueSizeInBits(0).getFixedValue() >
51178             RegVT.getFixedSizeInBits()) {
51179       EVT UserVT = User->getValueType(0);
51180       SDValue UserPtr = UserLd->getBasePtr();
51181       const Constant *UserC = getTargetConstantFromBasePtr(UserPtr);
51182 
51183       // See if we are loading a constant that matches in the lower
51184       // bits of a longer constant (but from a different constant pool ptr).
51185       if (UserC && UserPtr != Ptr) {
51186         unsigned LdSize = LdC->getType()->getPrimitiveSizeInBits();
51187         unsigned UserSize = UserC->getType()->getPrimitiveSizeInBits();
51188         if (LdSize < UserSize || !ISD::isNormalLoad(User)) {
51189           APInt Undefs, UserUndefs;
51190           SmallVector<APInt> Bits, UserBits;
51191           unsigned NumBits = std::min(RegVT.getScalarSizeInBits(),
51192                                       UserVT.getScalarSizeInBits());
51193           if (getTargetConstantBitsFromNode(SDValue(N, 0), NumBits, Undefs,
51194                                             Bits) &&
51195               getTargetConstantBitsFromNode(SDValue(User, 0), NumBits,
51196                                             UserUndefs, UserBits)) {
51197             if (MatchingBits(Undefs, UserUndefs, Bits, UserBits)) {
51198               SDValue Extract = extractSubVector(
51199                   SDValue(User, 0), 0, DAG, SDLoc(N), RegVT.getSizeInBits());
51200               Extract = DAG.getBitcast(RegVT, Extract);
51201               return DCI.CombineTo(N, Extract, SDValue(User, 1));
51202             }
51203           }
51204         }
51205       }
51206     }
51207   }
51208 
51209   return SDValue();
51210 }
51211 
51212 static SDValue combineLoad(SDNode *N, SelectionDAG &DAG,
51213                            TargetLowering::DAGCombinerInfo &DCI,
51214                            const X86Subtarget &Subtarget) {
51215   auto *Ld = cast<LoadSDNode>(N);
51216   EVT RegVT = Ld->getValueType(0);
51217   EVT MemVT = Ld->getMemoryVT();
51218   SDLoc dl(Ld);
51219   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
51220 
51221   // For chips with slow 32-byte unaligned loads, break the 32-byte operation
51222   // into two 16-byte operations. Also split non-temporal aligned loads on
51223   // pre-AVX2 targets as 32-byte loads will lower to regular temporal loads.
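  // Sketch of the intended split (illustrative only):
  //   t1: v8f32 = load t0, ptr, align 1
  // becomes
  //   lo: v4f32 = load t0, ptr
  //   hi: v4f32 = load t0, ptr + 16
  //   t1': v8f32 = concat_vectors lo, hi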
51224   ISD::LoadExtType Ext = Ld->getExtensionType();
51225   unsigned Fast;
51226   if (RegVT.is256BitVector() && !DCI.isBeforeLegalizeOps() &&
51227       Ext == ISD::NON_EXTLOAD &&
51228       ((Ld->isNonTemporal() && !Subtarget.hasInt256() &&
51229         Ld->getAlign() >= Align(16)) ||
51230        (TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), RegVT,
51231                                *Ld->getMemOperand(), &Fast) &&
51232         !Fast))) {
51233     unsigned NumElems = RegVT.getVectorNumElements();
51234     if (NumElems < 2)
51235       return SDValue();
51236 
51237     unsigned HalfOffset = 16;
51238     SDValue Ptr1 = Ld->getBasePtr();
51239     SDValue Ptr2 =
51240         DAG.getMemBasePlusOffset(Ptr1, TypeSize::getFixed(HalfOffset), dl);
51241     EVT HalfVT = EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(),
51242                                   NumElems / 2);
51243     SDValue Load1 =
51244         DAG.getLoad(HalfVT, dl, Ld->getChain(), Ptr1, Ld->getPointerInfo(),
51245                     Ld->getOriginalAlign(),
51246                     Ld->getMemOperand()->getFlags());
51247     SDValue Load2 = DAG.getLoad(HalfVT, dl, Ld->getChain(), Ptr2,
51248                                 Ld->getPointerInfo().getWithOffset(HalfOffset),
51249                                 Ld->getOriginalAlign(),
51250                                 Ld->getMemOperand()->getFlags());
51251     SDValue TF = DAG.getNode(ISD::TokenFactor, dl, MVT::Other,
51252                              Load1.getValue(1), Load2.getValue(1));
51253 
51254     SDValue NewVec = DAG.getNode(ISD::CONCAT_VECTORS, dl, RegVT, Load1, Load2);
51255     return DCI.CombineTo(N, NewVec, TF, true);
51256   }
51257 
51258   // Bool vector load - attempt to cast to an integer, as we have good
51259   // (vXiY *ext(vXi1 bitcast(iX))) handling.
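  // For example (a sketch): a v16i1 load is rewritten as
  //   t1: i16 = load t0, ptr
  //   t2: v16i1 = bitcast t1
  // so later extension combines can operate on the integer form.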
51260   if (Ext == ISD::NON_EXTLOAD && !Subtarget.hasAVX512() && RegVT.isVector() &&
51261       RegVT.getScalarType() == MVT::i1 && DCI.isBeforeLegalize()) {
51262     unsigned NumElts = RegVT.getVectorNumElements();
51263     EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), NumElts);
51264     if (TLI.isTypeLegal(IntVT)) {
51265       SDValue IntLoad = DAG.getLoad(IntVT, dl, Ld->getChain(), Ld->getBasePtr(),
51266                                     Ld->getPointerInfo(),
51267                                     Ld->getOriginalAlign(),
51268                                     Ld->getMemOperand()->getFlags());
51269       SDValue BoolVec = DAG.getBitcast(RegVT, IntLoad);
51270       return DCI.CombineTo(N, BoolVec, IntLoad.getValue(1), true);
51271     }
51272   }
51273 
51274   // If we also broadcast this vector to a wider type, then just extract the
51275   // lowest subvector.
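  // For example (illustrative): if the same pointer also feeds
  //   t2: v8f32 = X86ISD::SUBV_BROADCAST_LOAD t0, ptr
  // then a plain v4f32 load of ptr can reuse t2 as extract_subvector t2, 0.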
51276   if (Ext == ISD::NON_EXTLOAD && Subtarget.hasAVX() && Ld->isSimple() &&
51277       (RegVT.is128BitVector() || RegVT.is256BitVector())) {
51278     SDValue Ptr = Ld->getBasePtr();
51279     SDValue Chain = Ld->getChain();
51280     for (SDNode *User : Chain->uses()) {
51281       auto *UserLd = dyn_cast<MemSDNode>(User);
51282       if (User != N && UserLd &&
51283           User->getOpcode() == X86ISD::SUBV_BROADCAST_LOAD &&
51284           UserLd->getChain() == Chain && UserLd->getBasePtr() == Ptr &&
51285           UserLd->getMemoryVT().getSizeInBits() == MemVT.getSizeInBits() &&
51286           !User->hasAnyUseOfValue(1) &&
51287           User->getValueSizeInBits(0).getFixedValue() >
51288               RegVT.getFixedSizeInBits()) {
51289         SDValue Extract = extractSubVector(SDValue(User, 0), 0, DAG, dl,
51290                                            RegVT.getSizeInBits());
51291         Extract = DAG.getBitcast(RegVT, Extract);
51292         return DCI.CombineTo(N, Extract, SDValue(User, 1));
51293       }
51294     }
51295   }
51296 
51297   if (SDValue V = combineConstantPoolLoads(Ld, dl, DAG, DCI, Subtarget))
51298     return V;
51299 
51300   // Cast ptr32 and ptr64 pointers to the default address space before a load.
51301   unsigned AddrSpace = Ld->getAddressSpace();
51302   if (AddrSpace == X86AS::PTR64 || AddrSpace == X86AS::PTR32_SPTR ||
51303       AddrSpace == X86AS::PTR32_UPTR) {
51304     MVT PtrVT = TLI.getPointerTy(DAG.getDataLayout());
51305     if (PtrVT != Ld->getBasePtr().getSimpleValueType()) {
51306       SDValue Cast =
51307           DAG.getAddrSpaceCast(dl, PtrVT, Ld->getBasePtr(), AddrSpace, 0);
51308       return DAG.getExtLoad(Ext, dl, RegVT, Ld->getChain(), Cast,
51309                             Ld->getPointerInfo(), MemVT, Ld->getOriginalAlign(),
51310                             Ld->getMemOperand()->getFlags());
51311     }
51312   }
51313 
51314   return SDValue();
51315 }
51316 
51317 /// If V is a build vector of boolean constants and exactly one of those
51318 /// constants is true, return the operand index of that true element.
51319 /// Otherwise, return -1.
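/// For example (illustrative): <i1 0, i1 0, i1 1, i1 0> returns 2, while
/// <i1 1, i1 0, i1 1, i1 0> (two true elements) and any non-constant vector
/// return -1.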
51320 static int getOneTrueElt(SDValue V) {
51321   // This needs to be a build vector of booleans.
51322   // TODO: Checking for the i1 type matches the IR definition for the mask,
51323   // but the mask check could be loosened to i8 or other types. That might
51324   // also require checking more than 'allOnesValue'; eg, the x86 HW
51325   // instructions only require that the MSB is set for each mask element.
51326   // The ISD::MSTORE comments/definition do not specify how the mask operand
51327   // is formatted.
51328   auto *BV = dyn_cast<BuildVectorSDNode>(V);
51329   if (!BV || BV->getValueType(0).getVectorElementType() != MVT::i1)
51330     return -1;
51331 
51332   int TrueIndex = -1;
51333   unsigned NumElts = BV->getValueType(0).getVectorNumElements();
51334   for (unsigned i = 0; i < NumElts; ++i) {
51335     const SDValue &Op = BV->getOperand(i);
51336     if (Op.isUndef())
51337       continue;
51338     auto *ConstNode = dyn_cast<ConstantSDNode>(Op);
51339     if (!ConstNode)
51340       return -1;
51341     if (ConstNode->getAPIntValue().countr_one() >= 1) {
51342       // If we already found a one, this is too many.
51343       if (TrueIndex >= 0)
51344         return -1;
51345       TrueIndex = i;
51346     }
51347   }
51348   return TrueIndex;
51349 }
51350 
51351 /// Given a masked memory load/store operation, return true if it has one mask
51352 /// bit set. If it has one mask bit set, then also return the memory address of
51353 /// the scalar element to load/store, the vector index to insert/extract that
51354 /// scalar element, and the alignment for the scalar memory access.
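/// For example (a sketch): for a masked <4 x float> access with mask
/// <0, 0, 1, 0> and base pointer %p, this returns Addr = %p + 8, Index = 2,
/// Offset = 8, and the alignment common to the original access and 4 bytes.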
51355 static bool getParamsForOneTrueMaskedElt(MaskedLoadStoreSDNode *MaskedOp,
51356                                          SelectionDAG &DAG, SDValue &Addr,
51357                                          SDValue &Index, Align &Alignment,
51358                                          unsigned &Offset) {
51359   int TrueMaskElt = getOneTrueElt(MaskedOp->getMask());
51360   if (TrueMaskElt < 0)
51361     return false;
51362 
51363   // Get the address of the one scalar element that is specified by the mask
51364   // using the appropriate offset from the base pointer.
51365   EVT EltVT = MaskedOp->getMemoryVT().getVectorElementType();
51366   Offset = 0;
51367   Addr = MaskedOp->getBasePtr();
51368   if (TrueMaskElt != 0) {
51369     Offset = TrueMaskElt * EltVT.getStoreSize();
51370     Addr = DAG.getMemBasePlusOffset(Addr, TypeSize::getFixed(Offset),
51371                                     SDLoc(MaskedOp));
51372   }
51373 
51374   Index = DAG.getIntPtrConstant(TrueMaskElt, SDLoc(MaskedOp));
51375   Alignment = commonAlignment(MaskedOp->getOriginalAlign(),
51376                               EltVT.getStoreSize());
51377   return true;
51378 }
51379 
51380 /// If exactly one element of the mask is set for a non-extending masked load,
51381 /// it is a scalar load and vector insert.
51382 /// Note: It is expected that the degenerate cases of an all-zeros or all-ones
51383 /// mask have already been optimized in IR, so we don't bother with those here.
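/// For example (illustrative): a masked load of <4 x i32> with mask
/// <0, 0, 1, 0> becomes a scalar i32 load at offset 8 followed by an
/// insert_vector_elt of that scalar into the pass-through value at index 2.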
51384 static SDValue
51385 reduceMaskedLoadToScalarLoad(MaskedLoadSDNode *ML, SelectionDAG &DAG,
51386                              TargetLowering::DAGCombinerInfo &DCI,
51387                              const X86Subtarget &Subtarget) {
51388   assert(ML->isUnindexed() && "Unexpected indexed masked load!");
51389   // TODO: This is not x86-specific, so it could be lifted to DAGCombiner.
51390   // However, some target hooks may need to be added to know when the transform
51391   // is profitable. Endianness would also have to be considered.
51392 
51393   SDValue Addr, VecIndex;
51394   Align Alignment;
51395   unsigned Offset;
51396   if (!getParamsForOneTrueMaskedElt(ML, DAG, Addr, VecIndex, Alignment, Offset))
51397     return SDValue();
51398 
51399   // Load the one scalar element that is specified by the mask using the
51400   // appropriate offset from the base pointer.
51401   SDLoc DL(ML);
51402   EVT VT = ML->getValueType(0);
51403   EVT EltVT = VT.getVectorElementType();
51404 
51405   EVT CastVT = VT;
51406   if (EltVT == MVT::i64 && !Subtarget.is64Bit()) {
51407     EltVT = MVT::f64;
51408     CastVT = VT.changeVectorElementType(EltVT);
51409   }
51410 
51411   SDValue Load =
51412       DAG.getLoad(EltVT, DL, ML->getChain(), Addr,
51413                   ML->getPointerInfo().getWithOffset(Offset),
51414                   Alignment, ML->getMemOperand()->getFlags());
51415 
51416   SDValue PassThru = DAG.getBitcast(CastVT, ML->getPassThru());
51417 
51418   // Insert the loaded element into the appropriate place in the vector.
51419   SDValue Insert =
51420       DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, CastVT, PassThru, Load, VecIndex);
51421   Insert = DAG.getBitcast(VT, Insert);
51422   return DCI.CombineTo(ML, Insert, Load.getValue(1), true);
51423 }
51424 
51425 static SDValue
51426 combineMaskedLoadConstantMask(MaskedLoadSDNode *ML, SelectionDAG &DAG,
51427                               TargetLowering::DAGCombinerInfo &DCI) {
51428   assert(ML->isUnindexed() && "Unexpected indexed masked load!");
51429   if (!ISD::isBuildVectorOfConstantSDNodes(ML->getMask().getNode()))
51430     return SDValue();
51431 
51432   SDLoc DL(ML);
51433   EVT VT = ML->getValueType(0);
51434 
51435   // If we are loading the first and last elements of a vector, it is safe and
51436   // always faster to load the whole vector. Replace the masked load with a
51437   // vector load and select.
51438   unsigned NumElts = VT.getVectorNumElements();
51439   BuildVectorSDNode *MaskBV = cast<BuildVectorSDNode>(ML->getMask());
51440   bool LoadFirstElt = !isNullConstant(MaskBV->getOperand(0));
51441   bool LoadLastElt = !isNullConstant(MaskBV->getOperand(NumElts - 1));
51442   if (LoadFirstElt && LoadLastElt) {
51443     SDValue VecLd = DAG.getLoad(VT, DL, ML->getChain(), ML->getBasePtr(),
51444                                 ML->getMemOperand());
51445     SDValue Blend = DAG.getSelect(DL, VT, ML->getMask(), VecLd,
51446                                   ML->getPassThru());
51447     return DCI.CombineTo(ML, Blend, VecLd.getValue(1), true);
51448   }
51449 
51450   // Convert a masked load with a constant mask into a masked load and a select.
51451   // This allows the select operation to use a faster kind of select instruction
51452   // (for example, vblendvps -> vblendps).
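  // Sketch of the rewrite (illustrative only):
  //   masked_load(ptr, mask, passthru)
  //     -> select(mask, masked_load(ptr, mask, undef), passthru)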
51453 
51454   // Don't try this if the pass-through operand is already undefined. That would
51455   // cause an infinite loop because that's what we're about to create.
51456   if (ML->getPassThru().isUndef())
51457     return SDValue();
51458 
51459   if (ISD::isBuildVectorAllZeros(ML->getPassThru().getNode()))
51460     return SDValue();
51461 
51462   // The new masked load has an undef pass-through operand. The select uses the
51463   // original pass-through operand.
51464   SDValue NewML = DAG.getMaskedLoad(
51465       VT, DL, ML->getChain(), ML->getBasePtr(), ML->getOffset(), ML->getMask(),
51466       DAG.getUNDEF(VT), ML->getMemoryVT(), ML->getMemOperand(),
51467       ML->getAddressingMode(), ML->getExtensionType());
51468   SDValue Blend = DAG.getSelect(DL, VT, ML->getMask(), NewML,
51469                                 ML->getPassThru());
51470 
51471   return DCI.CombineTo(ML, Blend, NewML.getValue(1), true);
51472 }
51473 
51474 static SDValue combineMaskedLoad(SDNode *N, SelectionDAG &DAG,
51475                                  TargetLowering::DAGCombinerInfo &DCI,
51476                                  const X86Subtarget &Subtarget) {
51477   auto *Mld = cast<MaskedLoadSDNode>(N);
51478 
51479   // TODO: Expanding load with constant mask may be optimized as well.
51480   if (Mld->isExpandingLoad())
51481     return SDValue();
51482 
51483   if (Mld->getExtensionType() == ISD::NON_EXTLOAD) {
51484     if (SDValue ScalarLoad =
51485             reduceMaskedLoadToScalarLoad(Mld, DAG, DCI, Subtarget))
51486       return ScalarLoad;
51487 
51488     // TODO: Do some AVX512 subsets benefit from this transform?
51489     if (!Subtarget.hasAVX512())
51490       if (SDValue Blend = combineMaskedLoadConstantMask(Mld, DAG, DCI))
51491         return Blend;
51492   }
51493 
51494   // If the mask value has been legalized to a non-boolean vector, try to
51495   // simplify ops leading up to it. We only demand the MSB of each lane.
51496   SDValue Mask = Mld->getMask();
51497   if (Mask.getScalarValueSizeInBits() != 1) {
51498     EVT VT = Mld->getValueType(0);
51499     const TargetLowering &TLI = DAG.getTargetLoweringInfo();
51500     APInt DemandedBits(APInt::getSignMask(VT.getScalarSizeInBits()));
51501     if (TLI.SimplifyDemandedBits(Mask, DemandedBits, DCI)) {
51502       if (N->getOpcode() != ISD::DELETED_NODE)
51503         DCI.AddToWorklist(N);
51504       return SDValue(N, 0);
51505     }
51506     if (SDValue NewMask =
51507             TLI.SimplifyMultipleUseDemandedBits(Mask, DemandedBits, DAG))
51508       return DAG.getMaskedLoad(
51509           VT, SDLoc(N), Mld->getChain(), Mld->getBasePtr(), Mld->getOffset(),
51510           NewMask, Mld->getPassThru(), Mld->getMemoryVT(), Mld->getMemOperand(),
51511           Mld->getAddressingMode(), Mld->getExtensionType());
51512   }
51513 
51514   return SDValue();
51515 }
51516 
51517 /// If exactly one element of the mask is set for a non-truncating masked store,
51518 /// it is a vector extract and scalar store.
51519 /// Note: It is expected that the degenerate cases of an all-zeros or all-ones
51520 /// mask have already been optimized in IR, so we don't bother with those here.
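/// For example (illustrative): a masked store of <4 x float> %v with mask
/// <0, 1, 0, 0> becomes extract_vector_elt %v, 1 stored to base + 4.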
51521 static SDValue reduceMaskedStoreToScalarStore(MaskedStoreSDNode *MS,
51522                                               SelectionDAG &DAG,
51523                                               const X86Subtarget &Subtarget) {
51524   // TODO: This is not x86-specific, so it could be lifted to DAGCombiner.
51525   // However, some target hooks may need to be added to know when the transform
51526   // is profitable. Endianness would also have to be considered.
51527 
51528   SDValue Addr, VecIndex;
51529   Align Alignment;
51530   unsigned Offset;
51531   if (!getParamsForOneTrueMaskedElt(MS, DAG, Addr, VecIndex, Alignment, Offset))
51532     return SDValue();
51533 
51534   // Extract the one scalar element that is actually being stored.
51535   SDLoc DL(MS);
51536   SDValue Value = MS->getValue();
51537   EVT VT = Value.getValueType();
51538   EVT EltVT = VT.getVectorElementType();
51539   if (EltVT == MVT::i64 && !Subtarget.is64Bit()) {
51540     EltVT = MVT::f64;
51541     EVT CastVT = VT.changeVectorElementType(EltVT);
51542     Value = DAG.getBitcast(CastVT, Value);
51543   }
51544   SDValue Extract =
51545       DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Value, VecIndex);
51546 
51547   // Store that element at the appropriate offset from the base pointer.
51548   return DAG.getStore(MS->getChain(), DL, Extract, Addr,
51549                       MS->getPointerInfo().getWithOffset(Offset),
51550                       Alignment, MS->getMemOperand()->getFlags());
51551 }
51552 
51553 static SDValue combineMaskedStore(SDNode *N, SelectionDAG &DAG,
51554                                   TargetLowering::DAGCombinerInfo &DCI,
51555                                   const X86Subtarget &Subtarget) {
51556   MaskedStoreSDNode *Mst = cast<MaskedStoreSDNode>(N);
51557   if (Mst->isCompressingStore())
51558     return SDValue();
51559 
51560   EVT VT = Mst->getValue().getValueType();
51561   SDLoc dl(Mst);
51562   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
51563 
51564   if (Mst->isTruncatingStore())
51565     return SDValue();
51566 
51567   if (SDValue ScalarStore = reduceMaskedStoreToScalarStore(Mst, DAG, Subtarget))
51568     return ScalarStore;
51569 
51570   // If the mask value has been legalized to a non-boolean vector, try to
51571   // simplify ops leading up to it. We only demand the MSB of each lane.
51572   SDValue Mask = Mst->getMask();
51573   if (Mask.getScalarValueSizeInBits() != 1) {
51574     APInt DemandedBits(APInt::getSignMask(VT.getScalarSizeInBits()));
51575     if (TLI.SimplifyDemandedBits(Mask, DemandedBits, DCI)) {
51576       if (N->getOpcode() != ISD::DELETED_NODE)
51577         DCI.AddToWorklist(N);
51578       return SDValue(N, 0);
51579     }
51580     if (SDValue NewMask =
51581             TLI.SimplifyMultipleUseDemandedBits(Mask, DemandedBits, DAG))
51582       return DAG.getMaskedStore(Mst->getChain(), SDLoc(N), Mst->getValue(),
51583                                 Mst->getBasePtr(), Mst->getOffset(), NewMask,
51584                                 Mst->getMemoryVT(), Mst->getMemOperand(),
51585                                 Mst->getAddressingMode());
51586   }
51587 
51588   SDValue Value = Mst->getValue();
51589   if (Value.getOpcode() == ISD::TRUNCATE && Value.getNode()->hasOneUse() &&
51590       TLI.isTruncStoreLegal(Value.getOperand(0).getValueType(),
51591                             Mst->getMemoryVT())) {
51592     return DAG.getMaskedStore(Mst->getChain(), SDLoc(N), Value.getOperand(0),
51593                               Mst->getBasePtr(), Mst->getOffset(), Mask,
51594                               Mst->getMemoryVT(), Mst->getMemOperand(),
51595                               Mst->getAddressingMode(), true);
51596   }
51597 
51598   return SDValue();
51599 }
51600 
51601 static SDValue combineStore(SDNode *N, SelectionDAG &DAG,
51602                             TargetLowering::DAGCombinerInfo &DCI,
51603                             const X86Subtarget &Subtarget) {
51604   StoreSDNode *St = cast<StoreSDNode>(N);
51605   EVT StVT = St->getMemoryVT();
51606   SDLoc dl(St);
51607   SDValue StoredVal = St->getValue();
51608   EVT VT = StoredVal.getValueType();
51609   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
51610 
51611   // Convert a store of vXi1 into a store of iX and a bitcast.
51612   if (!Subtarget.hasAVX512() && VT == StVT && VT.isVector() &&
51613       VT.getVectorElementType() == MVT::i1) {
51614 
51615     EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), VT.getVectorNumElements());
51616     StoredVal = DAG.getBitcast(NewVT, StoredVal);
51617 
51618     return DAG.getStore(St->getChain(), dl, StoredVal, St->getBasePtr(),
51619                         St->getPointerInfo(), St->getOriginalAlign(),
51620                         St->getMemOperand()->getFlags());
51621   }
51622 
51623   // If this is a store of a scalar_to_vector to v1i1, just use a scalar store.
51624   // This will avoid a copy to k-register.
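  // For example (a sketch): store (v1i1 scalar_to_vector (i8 %x)) becomes a
  // plain i8 store of (%x & 1), avoiding a GPR -> k-register round trip.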
51625   if (VT == MVT::v1i1 && VT == StVT && Subtarget.hasAVX512() &&
51626       StoredVal.getOpcode() == ISD::SCALAR_TO_VECTOR &&
51627       StoredVal.getOperand(0).getValueType() == MVT::i8) {
51628     SDValue Val = StoredVal.getOperand(0);
51629     // We must store zeros to the unused bits.
51630     Val = DAG.getZeroExtendInReg(Val, dl, MVT::i1);
51631     return DAG.getStore(St->getChain(), dl, Val,
51632                         St->getBasePtr(), St->getPointerInfo(),
51633                         St->getOriginalAlign(),
51634                         St->getMemOperand()->getFlags());
51635   }
51636 
51637   // Widen v2i1/v4i1 stores to v8i1.
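  // For example (illustrative): a v2i1 store is widened to a v8i1 store of
  // concat_vectors(%v, zero, zero, zero), i.e. effectively a single byte with
  // the unused bits cleared.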
51638   if ((VT == MVT::v1i1 || VT == MVT::v2i1 || VT == MVT::v4i1) && VT == StVT &&
51639       Subtarget.hasAVX512()) {
51640     unsigned NumConcats = 8 / VT.getVectorNumElements();
51641     // We must store zeros to the unused bits.
51642     SmallVector<SDValue, 4> Ops(NumConcats, DAG.getConstant(0, dl, VT));
51643     Ops[0] = StoredVal;
51644     StoredVal = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v8i1, Ops);
51645     return DAG.getStore(St->getChain(), dl, StoredVal, St->getBasePtr(),
51646                         St->getPointerInfo(), St->getOriginalAlign(),
51647                         St->getMemOperand()->getFlags());
51648   }
51649 
51650   // Turn vXi1 stores of constants into a scalar store.
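  // For example (a sketch): storing a constant v8i1 mask becomes a plain i8
  // store of the packed bitmask constant.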
51651   if ((VT == MVT::v8i1 || VT == MVT::v16i1 || VT == MVT::v32i1 ||
51652        VT == MVT::v64i1) && VT == StVT && TLI.isTypeLegal(VT) &&
51653       ISD::isBuildVectorOfConstantSDNodes(StoredVal.getNode())) {
51654     // If it's a v64i1 store without 64-bit support, we need two stores.
51655     if (!DCI.isBeforeLegalize() && VT == MVT::v64i1 && !Subtarget.is64Bit()) {
51656       SDValue Lo = DAG.getBuildVector(MVT::v32i1, dl,
51657                                       StoredVal->ops().slice(0, 32));
51658       Lo = combinevXi1ConstantToInteger(Lo, DAG);
51659       SDValue Hi = DAG.getBuildVector(MVT::v32i1, dl,
51660                                       StoredVal->ops().slice(32, 32));
51661       Hi = combinevXi1ConstantToInteger(Hi, DAG);
51662 
51663       SDValue Ptr0 = St->getBasePtr();
51664       SDValue Ptr1 = DAG.getMemBasePlusOffset(Ptr0, TypeSize::getFixed(4), dl);
51665 
51666       SDValue Ch0 =
51667           DAG.getStore(St->getChain(), dl, Lo, Ptr0, St->getPointerInfo(),
51668                        St->getOriginalAlign(),
51669                        St->getMemOperand()->getFlags());
51670       SDValue Ch1 =
51671           DAG.getStore(St->getChain(), dl, Hi, Ptr1,
51672                        St->getPointerInfo().getWithOffset(4),
51673                        St->getOriginalAlign(),
51674                        St->getMemOperand()->getFlags());
51675       return DAG.getNode(ISD::TokenFactor, dl, MVT::Other, Ch0, Ch1);
51676     }
51677 
51678     StoredVal = combinevXi1ConstantToInteger(StoredVal, DAG);
51679     return DAG.getStore(St->getChain(), dl, StoredVal, St->getBasePtr(),
51680                         St->getPointerInfo(), St->getOriginalAlign(),
51681                         St->getMemOperand()->getFlags());
51682   }
51683 
51684   // If we are saving a 32-byte vector and 32-byte stores are slow, such as on
51685   // Sandy Bridge, perform two 16-byte stores.
51686   unsigned Fast;
51687   if (VT.is256BitVector() && StVT == VT &&
51688       TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT,
51689                              *St->getMemOperand(), &Fast) &&
51690       !Fast) {
51691     unsigned NumElems = VT.getVectorNumElements();
51692     if (NumElems < 2)
51693       return SDValue();
51694 
51695     return splitVectorStore(St, DAG);
51696   }
51697 
51698   // Split under-aligned vector non-temporal stores.
51699   if (St->isNonTemporal() && StVT == VT &&
51700       St->getAlign().value() < VT.getStoreSize()) {
51701     // ZMM/YMM nt-stores - either it can be stored as a series of shorter
51702     // vectors or the legalizer can scalarize it to use MOVNTI.
51703     if (VT.is256BitVector() || VT.is512BitVector()) {
51704       unsigned NumElems = VT.getVectorNumElements();
51705       if (NumElems < 2)
51706         return SDValue();
51707       return splitVectorStore(St, DAG);
51708     }
51709 
51710     // XMM nt-stores - scalarize this to f64 nt-stores on SSE4A, else i32/i64
51711     // to use MOVNTI.
51712     if (VT.is128BitVector() && Subtarget.hasSSE2()) {
51713       MVT NTVT = Subtarget.hasSSE4A()
51714                      ? MVT::v2f64
51715                      : (TLI.isTypeLegal(MVT::i64) ? MVT::v2i64 : MVT::v4i32);
51716       return scalarizeVectorStore(St, NTVT, DAG);
51717     }
51718   }
51719 
51720   // Try to optimize v16i16->v16i8 truncating stores when BWI is not
51721   // supported, but AVX512F is, by extending to v16i32 and truncating.
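  // Sketch (illustrative):
  //   store (v16i8 trunc (v16i16 %x))
  //     -> truncstore<v16i8> (v16i32 any_extend %x)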
51722   if (!St->isTruncatingStore() && VT == MVT::v16i8 && !Subtarget.hasBWI() &&
51723       St->getValue().getOpcode() == ISD::TRUNCATE &&
51724       St->getValue().getOperand(0).getValueType() == MVT::v16i16 &&
51725       TLI.isTruncStoreLegal(MVT::v16i32, MVT::v16i8) &&
51726       St->getValue().hasOneUse() && !DCI.isBeforeLegalizeOps()) {
51727     SDValue Ext = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::v16i32,
51728                               St->getValue().getOperand(0));
51729     return DAG.getTruncStore(St->getChain(), dl, Ext, St->getBasePtr(),
51730                              MVT::v16i8, St->getMemOperand());
51731   }
51732 
51733   // Try to fold a VTRUNCUS or VTRUNCS into a truncating store.
51734   if (!St->isTruncatingStore() &&
51735       (StoredVal.getOpcode() == X86ISD::VTRUNCUS ||
51736        StoredVal.getOpcode() == X86ISD::VTRUNCS) &&
51737       StoredVal.hasOneUse() &&
51738       TLI.isTruncStoreLegal(StoredVal.getOperand(0).getValueType(), VT)) {
51739     bool IsSigned = StoredVal.getOpcode() == X86ISD::VTRUNCS;
51740     return EmitTruncSStore(IsSigned, St->getChain(),
51741                            dl, StoredVal.getOperand(0), St->getBasePtr(),
51742                            VT, St->getMemOperand(), DAG);
51743   }
51744 
51745   // Try to fold an extract_element(VTRUNC) pattern into a truncating store.
51746   if (!St->isTruncatingStore()) {
51747     auto IsExtractedElement = [](SDValue V) {
51748       if (V.getOpcode() == ISD::TRUNCATE && V.hasOneUse())
51749         V = V.getOperand(0);
51750       unsigned Opc = V.getOpcode();
51751       if ((Opc == ISD::EXTRACT_VECTOR_ELT || Opc == X86ISD::PEXTRW) &&
51752           isNullConstant(V.getOperand(1)) && V.hasOneUse() &&
51753           V.getOperand(0).hasOneUse())
51754         return V.getOperand(0);
51755       return SDValue();
51756     };
51757     if (SDValue Extract = IsExtractedElement(StoredVal)) {
51758       SDValue Trunc = peekThroughOneUseBitcasts(Extract);
51759       if (Trunc.getOpcode() == X86ISD::VTRUNC) {
51760         SDValue Src = Trunc.getOperand(0);
51761         MVT DstVT = Trunc.getSimpleValueType();
51762         MVT SrcVT = Src.getSimpleValueType();
51763         unsigned NumSrcElts = SrcVT.getVectorNumElements();
51764         unsigned NumTruncBits = DstVT.getScalarSizeInBits() * NumSrcElts;
51765         MVT TruncVT = MVT::getVectorVT(DstVT.getScalarType(), NumSrcElts);
51766         if (NumTruncBits == VT.getSizeInBits() &&
51767             TLI.isTruncStoreLegal(SrcVT, TruncVT)) {
51768           return DAG.getTruncStore(St->getChain(), dl, Src, St->getBasePtr(),
51769                                    TruncVT, St->getMemOperand());
51770         }
51771       }
51772     }
51773   }
51774 
51775   // Optimize trunc store (of multiple scalars) to shuffle and store.
51776   // First, pack all of the elements in one place. Next, store to memory
51777   // in fewer chunks.
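  // For example (a sketch): a store of a signed-saturated v8i16 -> v8i8
  // truncation is emitted as a single saturating truncating store via
  // EmitTruncSStore when the ssat/usat pattern is detected below.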
51778   if (St->isTruncatingStore() && VT.isVector()) {
51779     if (TLI.isTruncStoreLegal(VT, StVT)) {
51780       if (SDValue Val = detectSSatPattern(St->getValue(), St->getMemoryVT()))
51781         return EmitTruncSStore(true /* Signed saturation */, St->getChain(),
51782                                dl, Val, St->getBasePtr(),
51783                                St->getMemoryVT(), St->getMemOperand(), DAG);
51784       if (SDValue Val = detectUSatPattern(St->getValue(), St->getMemoryVT(),
51785                                           DAG, dl))
51786         return EmitTruncSStore(false /* Unsigned saturation */, St->getChain(),
51787                                dl, Val, St->getBasePtr(),
51788                                St->getMemoryVT(), St->getMemOperand(), DAG);
51789     }
51790 
51791     return SDValue();
51792   }
51793 
51794   // Cast ptr32 and ptr64 pointers to the default address space before a store.
51795   unsigned AddrSpace = St->getAddressSpace();
51796   if (AddrSpace == X86AS::PTR64 || AddrSpace == X86AS::PTR32_SPTR ||
51797       AddrSpace == X86AS::PTR32_UPTR) {
51798     MVT PtrVT = TLI.getPointerTy(DAG.getDataLayout());
51799     if (PtrVT != St->getBasePtr().getSimpleValueType()) {
51800       SDValue Cast =
51801           DAG.getAddrSpaceCast(dl, PtrVT, St->getBasePtr(), AddrSpace, 0);
51802       return DAG.getTruncStore(
51803           St->getChain(), dl, StoredVal, Cast, St->getPointerInfo(), StVT,
51804           St->getOriginalAlign(), St->getMemOperand()->getFlags(),
51805           St->getAAInfo());
51806     }
51807   }
51808 
51809   // Turn load->store of MMX types into GPR load/stores.  This avoids clobbering
51810   // the FP state in cases where an emms may be missing.
51811   // A preferable solution to the general problem is to figure out the right
51812   // places to insert EMMS.  This qualifies as a quick hack.
51813 
51814   // Similarly, turn load->store of i64 into double load/stores in 32-bit mode.
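  // For example (illustrative, 32-bit target with SSE2):
  //   (store (i64 (load p)), q) -> (store (f64 (load p)), q)
  // which lowers to a single movq load/store pair instead of two 32-bit moves.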
51815   if (VT.getSizeInBits() != 64)
51816     return SDValue();
51817 
51818   const Function &F = DAG.getMachineFunction().getFunction();
51819   bool NoImplicitFloatOps = F.hasFnAttribute(Attribute::NoImplicitFloat);
51820   bool F64IsLegal =
51821       !Subtarget.useSoftFloat() && !NoImplicitFloatOps && Subtarget.hasSSE2();
51822 
51823   if (!F64IsLegal || Subtarget.is64Bit())
51824     return SDValue();
51825 
51826   if (VT == MVT::i64 && isa<LoadSDNode>(St->getValue()) &&
51827       cast<LoadSDNode>(St->getValue())->isSimple() &&
51828       St->getChain().hasOneUse() && St->isSimple()) {
51829     auto *Ld = cast<LoadSDNode>(St->getValue());
51830 
51831     if (!ISD::isNormalLoad(Ld))
51832       return SDValue();
51833 
51834     // Avoid the transformation if there are multiple uses of the loaded value.
51835     if (!Ld->hasNUsesOfValue(1, 0))
51836       return SDValue();
51837 
51838     SDLoc LdDL(Ld);
51839     SDLoc StDL(N);
51840     // Lower to a single movq load/store pair.
51841     SDValue NewLd = DAG.getLoad(MVT::f64, LdDL, Ld->getChain(),
51842                                 Ld->getBasePtr(), Ld->getMemOperand());
51843 
51844     // Make sure new load is placed in same chain order.
51845     DAG.makeEquivalentMemoryOrdering(Ld, NewLd);
51846     return DAG.getStore(St->getChain(), StDL, NewLd, St->getBasePtr(),
51847                         St->getMemOperand());
51848   }
51849 
51850   // This is similar to the above case, but here we handle a scalar 64-bit
51851   // integer store that is extracted from a vector on a 32-bit target.
51852   // If we have SSE2, then we can treat it like a floating-point double
51853   // to get past legalization. The execution dependencies fixup pass will
51854   // choose the optimal machine instruction for the store if this really is
51855   // an integer or v2f32 rather than an f64.
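  // Sketch (illustrative):
  //   (store (i64 extract_vector_elt (v2i64 %v), i), p)
  //     -> (store (f64 extract_vector_elt (v2f64 bitcast %v), i), p)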
51856   if (VT == MVT::i64 &&
51857       St->getOperand(1).getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
51858     SDValue OldExtract = St->getOperand(1);
51859     SDValue ExtOp0 = OldExtract.getOperand(0);
51860     unsigned VecSize = ExtOp0.getValueSizeInBits();
51861     EVT VecVT = EVT::getVectorVT(*DAG.getContext(), MVT::f64, VecSize / 64);
51862     SDValue BitCast = DAG.getBitcast(VecVT, ExtOp0);
51863     SDValue NewExtract = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::f64,
51864                                      BitCast, OldExtract.getOperand(1));
51865     return DAG.getStore(St->getChain(), dl, NewExtract, St->getBasePtr(),
51866                         St->getPointerInfo(), St->getOriginalAlign(),
51867                         St->getMemOperand()->getFlags());
51868   }
51869 
51870   return SDValue();
51871 }
51872 
51873 static SDValue combineVEXTRACT_STORE(SDNode *N, SelectionDAG &DAG,
51874                                      TargetLowering::DAGCombinerInfo &DCI,
51875                                      const X86Subtarget &Subtarget) {
51876   auto *St = cast<MemIntrinsicSDNode>(N);
51877 
51878   SDValue StoredVal = N->getOperand(1);
51879   MVT VT = StoredVal.getSimpleValueType();
51880   EVT MemVT = St->getMemoryVT();
51881 
51882   // Figure out which elements we demand.
51883   unsigned StElts = MemVT.getSizeInBits() / VT.getScalarSizeInBits();
51884   APInt DemandedElts = APInt::getLowBitsSet(VT.getVectorNumElements(), StElts);
51885 
51886   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
51887   if (TLI.SimplifyDemandedVectorElts(StoredVal, DemandedElts, DCI)) {
51888     if (N->getOpcode() != ISD::DELETED_NODE)
51889       DCI.AddToWorklist(N);
51890     return SDValue(N, 0);
51891   }
51892 
51893   return SDValue();
51894 }
51895 
51896 /// Return 'true' if this vector operation is "horizontal"
51897 /// and return the operands for the horizontal operation in LHS and RHS.  A
51898 /// horizontal operation performs the binary operation on successive elements
51899 /// of its first operand, then on successive elements of its second operand,
51900 /// returning the resulting values in a vector.  For example, if
51901 ///   A = < float a0, float a1, float a2, float a3 >
51902 /// and
51903 ///   B = < float b0, float b1, float b2, float b3 >
51904 /// then the result of doing a horizontal operation on A and B is
51905 ///   A horizontal-op B = < a0 op a1, a2 op a3, b0 op b1, b2 op b3 >.
51906 /// In short, LHS and RHS are inspected to see if LHS op RHS is of the form
51907 /// A horizontal-op B, for some already available A and B, and if so then LHS is
51908 /// set to A, RHS to B, and the routine returns 'true'.
51909 static bool isHorizontalBinOp(unsigned HOpcode, SDValue &LHS, SDValue &RHS,
51910                               SelectionDAG &DAG, const X86Subtarget &Subtarget,
51911                               bool IsCommutative,
51912                               SmallVectorImpl<int> &PostShuffleMask,
51913                               bool ForceHorizOp) {
51914   // If either operand is undef, bail out. The binop should be simplified.
51915   if (LHS.isUndef() || RHS.isUndef())
51916     return false;
51917 
51918   // Look for the following pattern:
51919   //   A = < float a0, float a1, float a2, float a3 >
51920   //   B = < float b0, float b1, float b2, float b3 >
51921   // and
51922   //   LHS = VECTOR_SHUFFLE A, B, <0, 2, 4, 6>
51923   //   RHS = VECTOR_SHUFFLE A, B, <1, 3, 5, 7>
51924   // then LHS op RHS = < a0 op a1, a2 op a3, b0 op b1, b2 op b3 >
51925   // which is A horizontal-op B.
51926 
51927   MVT VT = LHS.getSimpleValueType();
51928   assert((VT.is128BitVector() || VT.is256BitVector()) &&
51929          "Unsupported vector type for horizontal add/sub");
51930   unsigned NumElts = VT.getVectorNumElements();
51931 
51932   auto GetShuffle = [&](SDValue Op, SDValue &N0, SDValue &N1,
51933                         SmallVectorImpl<int> &ShuffleMask) {
51934     bool UseSubVector = false;
51935     if (Op.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
51936         Op.getOperand(0).getValueType().is256BitVector() &&
51937         llvm::isNullConstant(Op.getOperand(1))) {
51938       Op = Op.getOperand(0);
51939       UseSubVector = true;
51940     }
51941     SmallVector<SDValue, 2> SrcOps;
51942     SmallVector<int, 16> SrcMask, ScaledMask;
51943     SDValue BC = peekThroughBitcasts(Op);
51944     if (getTargetShuffleInputs(BC, SrcOps, SrcMask, DAG) &&
51945         !isAnyZero(SrcMask) && all_of(SrcOps, [BC](SDValue Op) {
51946           return Op.getValueSizeInBits() == BC.getValueSizeInBits();
51947         })) {
51948       resolveTargetShuffleInputsAndMask(SrcOps, SrcMask);
51949       if (!UseSubVector && SrcOps.size() <= 2 &&
51950           scaleShuffleElements(SrcMask, NumElts, ScaledMask)) {
51951         N0 = !SrcOps.empty() ? SrcOps[0] : SDValue();
51952         N1 = SrcOps.size() > 1 ? SrcOps[1] : SDValue();
51953         ShuffleMask.assign(ScaledMask.begin(), ScaledMask.end());
51954       }
51955       if (UseSubVector && SrcOps.size() == 1 &&
51956           scaleShuffleElements(SrcMask, 2 * NumElts, ScaledMask)) {
51957         std::tie(N0, N1) = DAG.SplitVector(SrcOps[0], SDLoc(Op));
51958         ArrayRef<int> Mask = ArrayRef<int>(ScaledMask).slice(0, NumElts);
51959         ShuffleMask.assign(Mask.begin(), Mask.end());
51960       }
51961     }
51962   };
51963 
51964   // View LHS in the form
51965   //   LHS = VECTOR_SHUFFLE A, B, LMask
51966   // If LHS is not a shuffle, then pretend it is the identity shuffle:
51967   //   LHS = VECTOR_SHUFFLE LHS, undef, <0, 1, ..., N-1>
51968   // NOTE: A default initialized SDValue represents an UNDEF of type VT.
51969   SDValue A, B;
51970   SmallVector<int, 16> LMask;
51971   GetShuffle(LHS, A, B, LMask);
51972 
51973   // Likewise, view RHS in the form
51974   //   RHS = VECTOR_SHUFFLE C, D, RMask
51975   SDValue C, D;
51976   SmallVector<int, 16> RMask;
51977   GetShuffle(RHS, C, D, RMask);
51978 
51979   // At least one of the operands should be a vector shuffle.
51980   unsigned NumShuffles = (LMask.empty() ? 0 : 1) + (RMask.empty() ? 0 : 1);
51981   if (NumShuffles == 0)
51982     return false;
51983 
51984   if (LMask.empty()) {
51985     A = LHS;
51986     for (unsigned i = 0; i != NumElts; ++i)
51987       LMask.push_back(i);
51988   }
51989 
51990   if (RMask.empty()) {
51991     C = RHS;
51992     for (unsigned i = 0; i != NumElts; ++i)
51993       RMask.push_back(i);
51994   }
51995 
51996   // If we have a unary mask, ensure the other op is set to null.
51997   if (isUndefOrInRange(LMask, 0, NumElts))
51998     B = SDValue();
51999   else if (isUndefOrInRange(LMask, NumElts, NumElts * 2))
52000     A = SDValue();
52001 
52002   if (isUndefOrInRange(RMask, 0, NumElts))
52003     D = SDValue();
52004   else if (isUndefOrInRange(RMask, NumElts, NumElts * 2))
52005     C = SDValue();
52006 
52007   // If A and B occur in reverse order in RHS, then canonicalize by commuting
52008   // RHS operands and shuffle mask.
52009   if (A != C) {
52010     std::swap(C, D);
52011     ShuffleVectorSDNode::commuteMask(RMask);
52012   }
52013   // Check that the shuffles are both shuffling the same vectors.
52014   if (!(A == C && B == D))
52015     return false;
52016 
52017   PostShuffleMask.clear();
52018   PostShuffleMask.append(NumElts, SM_SentinelUndef);
52019 
52020   // LHS and RHS are now:
52021   //   LHS = shuffle A, B, LMask
52022   //   RHS = shuffle A, B, RMask
52023   // Check that the masks correspond to performing a horizontal operation.
52024   // AVX defines horizontal add/sub to operate independently on 128-bit lanes,
52025   // so we just repeat the inner loop if this is a 256-bit op.
52026   unsigned Num128BitChunks = VT.getSizeInBits() / 128;
52027   unsigned NumEltsPer128BitChunk = NumElts / Num128BitChunks;
52028   unsigned NumEltsPer64BitChunk = NumEltsPer128BitChunk / 2;
52029   assert((NumEltsPer128BitChunk % 2 == 0) &&
52030          "Vector type should have an even number of elements in each lane");
52031   for (unsigned j = 0; j != NumElts; j += NumEltsPer128BitChunk) {
52032     for (unsigned i = 0; i != NumEltsPer128BitChunk; ++i) {
52033       // Ignore undefined components.
52034       int LIdx = LMask[i + j], RIdx = RMask[i + j];
52035       if (LIdx < 0 || RIdx < 0 ||
52036           (!A.getNode() && (LIdx < (int)NumElts || RIdx < (int)NumElts)) ||
52037           (!B.getNode() && (LIdx >= (int)NumElts || RIdx >= (int)NumElts)))
52038         continue;
52039 
52040       // Check that successive odd/even elements are being operated on. If not,
52041       // this is not a horizontal operation.
52042       if (!((RIdx & 1) == 1 && (LIdx + 1) == RIdx) &&
52043           !((LIdx & 1) == 1 && (RIdx + 1) == LIdx && IsCommutative))
52044         return false;
52045 
52046       // Compute the post-shuffle mask index based on where the element
52047       // is stored in the HOP result, and where it needs to be moved to.
52048       int Base = LIdx & ~1u;
52049       int Index = ((Base % NumEltsPer128BitChunk) / 2) +
52050                   ((Base % NumElts) & ~(NumEltsPer128BitChunk - 1));
52051 
52052       // The  low half of the 128-bit result must choose from A.
52053       // The high half of the 128-bit result must choose from B,
52054       // unless B is undef. In that case, we are always choosing from A.
52055       if ((B && Base >= (int)NumElts) || (!B && i >= NumEltsPer64BitChunk))
52056         Index += NumEltsPer64BitChunk;
52057       PostShuffleMask[i + j] = Index;
52058     }
52059   }
52060 
52061   SDValue NewLHS = A.getNode() ? A : B; // If A is 'UNDEF', use B for it.
52062   SDValue NewRHS = B.getNode() ? B : A; // If B is 'UNDEF', use A for it.
52063 
52064   bool IsIdentityPostShuffle =
52065       isSequentialOrUndefInRange(PostShuffleMask, 0, NumElts, 0);
52066   if (IsIdentityPostShuffle)
52067     PostShuffleMask.clear();
52068 
52069   // Avoid 128-bit multi lane shuffles if pre-AVX2 and FP (integer will split).
52070   if (!IsIdentityPostShuffle && !Subtarget.hasAVX2() && VT.isFloatingPoint() &&
52071       isMultiLaneShuffleMask(128, VT.getScalarSizeInBits(), PostShuffleMask))
52072     return false;
52073 
52074   // If the source nodes are already used in HorizOps then always accept this.
52075   // Shuffle folding should merge these back together.
52076   auto FoundHorizUser = [&](SDNode *User) {
52077     return User->getOpcode() == HOpcode && User->getValueType(0) == VT;
52078   };
52079   ForceHorizOp =
52080       ForceHorizOp || (llvm::any_of(NewLHS->uses(), FoundHorizUser) &&
52081                        llvm::any_of(NewRHS->uses(), FoundHorizUser));
52082 
52083   // Assume a SingleSource HOP if we only shuffle one input and don't need to
52084   // shuffle the result.
52085   if (!ForceHorizOp &&
52086       !shouldUseHorizontalOp(NewLHS == NewRHS &&
52087                                  (NumShuffles < 2 || !IsIdentityPostShuffle),
52088                              DAG, Subtarget))
52089     return false;
52090 
52091   LHS = DAG.getBitcast(VT, NewLHS);
52092   RHS = DAG.getBitcast(VT, NewRHS);
52093   return true;
52094 }
52095 
52096 // Try to synthesize horizontal (f)hadd/hsub from (f)adds/subs of shuffles.
52097 static SDValue combineToHorizontalAddSub(SDNode *N, SelectionDAG &DAG,
52098                                          const X86Subtarget &Subtarget) {
52099   EVT VT = N->getValueType(0);
52100   unsigned Opcode = N->getOpcode();
52101   bool IsAdd = (Opcode == ISD::FADD) || (Opcode == ISD::ADD);
52102   SmallVector<int, 8> PostShuffleMask;
52103 
52104   auto MergableHorizOp = [N](unsigned HorizOpcode) {
52105     return N->hasOneUse() &&
52106            N->use_begin()->getOpcode() == ISD::VECTOR_SHUFFLE &&
52107            (N->use_begin()->getOperand(0).getOpcode() == HorizOpcode ||
52108             N->use_begin()->getOperand(1).getOpcode() == HorizOpcode);
52109   };
52110 
52111   switch (Opcode) {
52112   case ISD::FADD:
52113   case ISD::FSUB:
52114     if ((Subtarget.hasSSE3() && (VT == MVT::v4f32 || VT == MVT::v2f64)) ||
52115         (Subtarget.hasAVX() && (VT == MVT::v8f32 || VT == MVT::v4f64))) {
52116       SDValue LHS = N->getOperand(0);
52117       SDValue RHS = N->getOperand(1);
52118       auto HorizOpcode = IsAdd ? X86ISD::FHADD : X86ISD::FHSUB;
52119       if (isHorizontalBinOp(HorizOpcode, LHS, RHS, DAG, Subtarget, IsAdd,
52120                             PostShuffleMask, MergableHorizOp(HorizOpcode))) {
52121         SDValue HorizBinOp = DAG.getNode(HorizOpcode, SDLoc(N), VT, LHS, RHS);
52122         if (!PostShuffleMask.empty())
52123           HorizBinOp = DAG.getVectorShuffle(VT, SDLoc(HorizBinOp), HorizBinOp,
52124                                             DAG.getUNDEF(VT), PostShuffleMask);
52125         return HorizBinOp;
52126       }
52127     }
52128     break;
52129   case ISD::ADD:
52130   case ISD::SUB:
52131     if (Subtarget.hasSSSE3() && (VT == MVT::v8i16 || VT == MVT::v4i32 ||
52132                                  VT == MVT::v16i16 || VT == MVT::v8i32)) {
52133       SDValue LHS = N->getOperand(0);
52134       SDValue RHS = N->getOperand(1);
52135       auto HorizOpcode = IsAdd ? X86ISD::HADD : X86ISD::HSUB;
52136       if (isHorizontalBinOp(HorizOpcode, LHS, RHS, DAG, Subtarget, IsAdd,
52137                             PostShuffleMask, MergableHorizOp(HorizOpcode))) {
52138         auto HOpBuilder = [HorizOpcode](SelectionDAG &DAG, const SDLoc &DL,
52139                                         ArrayRef<SDValue> Ops) {
52140           return DAG.getNode(HorizOpcode, DL, Ops[0].getValueType(), Ops);
52141         };
52142         SDValue HorizBinOp = SplitOpsAndApply(DAG, Subtarget, SDLoc(N), VT,
52143                                               {LHS, RHS}, HOpBuilder);
52144         if (!PostShuffleMask.empty())
52145           HorizBinOp = DAG.getVectorShuffle(VT, SDLoc(HorizBinOp), HorizBinOp,
52146                                             DAG.getUNDEF(VT), PostShuffleMask);
52147         return HorizBinOp;
52148       }
52149     }
52150     break;
52151   }
52152 
52153   return SDValue();
52154 }
52155 
52156 //  Try to combine the following nodes
52157 //  t29: i64 = X86ISD::Wrapper TargetConstantPool:i64
52158 //    <i32 -2147483648[float -0.000000e+00]> 0
52159 //  t27: v16i32[v16f32],ch = X86ISD::VBROADCAST_LOAD
52160 //    <(load 4 from constant-pool)> t0, t29
52161 //  [t30: v16i32 = bitcast t27]
52162 //  t6: v16i32 = xor t7, t27[t30]
52163 //  t11: v16f32 = bitcast t6
52164 //  t21: v16f32 = X86ISD::VFMULC[X86ISD::VCFMULC] t11, t8
52165 //  into X86ISD::VFCMULC[X86ISD::VFMULC] if possible:
52166 //  t22: v16f32 = bitcast t7
52167 //  t23: v16f32 = X86ISD::VFCMULC[X86ISD::VFMULC] t8, t22
52168 //  t24: v32f16 = bitcast t23
52169 static SDValue combineFMulcFCMulc(SDNode *N, SelectionDAG &DAG,
52170                                   const X86Subtarget &Subtarget) {
52171   EVT VT = N->getValueType(0);
52172   SDValue LHS = N->getOperand(0);
52173   SDValue RHS = N->getOperand(1);
52174   int CombineOpcode =
52175       N->getOpcode() == X86ISD::VFCMULC ? X86ISD::VFMULC : X86ISD::VFCMULC;
52176   auto combineConjugation = [&](SDValue &r) {
52177     if (LHS->getOpcode() == ISD::BITCAST && RHS.hasOneUse()) {
52178       SDValue XOR = LHS.getOperand(0);
52179       if (XOR->getOpcode() == ISD::XOR && XOR.hasOneUse()) {
52180         KnownBits XORRHS = DAG.computeKnownBits(XOR.getOperand(1));
52181         if (XORRHS.isConstant()) {
52182           APInt ConjugationInt32 = APInt(32, 0x80000000, true);
52183           APInt ConjugationInt64 = APInt(64, 0x8000000080000000ULL, true);
52184           if ((XORRHS.getBitWidth() == 32 &&
52185                XORRHS.getConstant() == ConjugationInt32) ||
52186               (XORRHS.getBitWidth() == 64 &&
52187                XORRHS.getConstant() == ConjugationInt64)) {
52188             SelectionDAG::FlagInserter FlagsInserter(DAG, N);
52189             SDValue I2F = DAG.getBitcast(VT, LHS.getOperand(0).getOperand(0));
52190             SDValue FCMulC = DAG.getNode(CombineOpcode, SDLoc(N), VT, RHS, I2F);
52191             r = DAG.getBitcast(VT, FCMulC);
52192             return true;
52193           }
52194         }
52195       }
52196     }
52197     return false;
52198   };
52199   SDValue Res;
52200   if (combineConjugation(Res))
52201     return Res;
52202   std::swap(LHS, RHS);
52203   if (combineConjugation(Res))
52204     return Res;
52205   return Res;
52206 }
52207 
52208 //  Try to combine the following nodes:
52209 //  FADD(A, FMA(B, C, 0)) and FADD(A, FMUL(B, C)) to FMA(B, C, A)
52210 static SDValue combineFaddCFmul(SDNode *N, SelectionDAG &DAG,
52211                                 const X86Subtarget &Subtarget) {
52212   auto AllowContract = [&DAG](const SDNodeFlags &Flags) {
52213     return DAG.getTarget().Options.AllowFPOpFusion == FPOpFusion::Fast ||
52214            Flags.hasAllowContract();
52215   };
52216 
52217   auto HasNoSignedZero = [&DAG](const SDNodeFlags &Flags) {
52218     return DAG.getTarget().Options.NoSignedZerosFPMath ||
52219            Flags.hasNoSignedZeros();
52220   };
52221   auto IsVectorAllNegativeZero = [&DAG](SDValue Op) {
52222     APInt AI = APInt(32, 0x80008000, true);
52223     KnownBits Bits = DAG.computeKnownBits(Op);
52224     return Bits.getBitWidth() == 32 && Bits.isConstant() &&
52225            Bits.getConstant() == AI;
52226   };
52227 
52228   if (N->getOpcode() != ISD::FADD || !Subtarget.hasFP16() ||
52229       !AllowContract(N->getFlags()))
52230     return SDValue();
52231 
52232   EVT VT = N->getValueType(0);
52233   if (VT != MVT::v8f16 && VT != MVT::v16f16 && VT != MVT::v32f16)
52234     return SDValue();
52235 
52236   SDValue LHS = N->getOperand(0);
52237   SDValue RHS = N->getOperand(1);
52238   bool IsConj;
52239   SDValue FAddOp1, MulOp0, MulOp1;
52240   auto GetCFmulFrom = [&MulOp0, &MulOp1, &IsConj, &AllowContract,
52241                        &IsVectorAllNegativeZero,
52242                        &HasNoSignedZero](SDValue N) -> bool {
52243     if (!N.hasOneUse() || N.getOpcode() != ISD::BITCAST)
52244       return false;
52245     SDValue Op0 = N.getOperand(0);
52246     unsigned Opcode = Op0.getOpcode();
52247     if (Op0.hasOneUse() && AllowContract(Op0->getFlags())) {
52248       if ((Opcode == X86ISD::VFMULC || Opcode == X86ISD::VFCMULC)) {
52249         MulOp0 = Op0.getOperand(0);
52250         MulOp1 = Op0.getOperand(1);
52251         IsConj = Opcode == X86ISD::VFCMULC;
52252         return true;
52253       }
52254       if ((Opcode == X86ISD::VFMADDC || Opcode == X86ISD::VFCMADDC) &&
52255           ((ISD::isBuildVectorAllZeros(Op0->getOperand(2).getNode()) &&
52256             HasNoSignedZero(Op0->getFlags())) ||
52257            IsVectorAllNegativeZero(Op0->getOperand(2)))) {
52258         MulOp0 = Op0.getOperand(0);
52259         MulOp1 = Op0.getOperand(1);
52260         IsConj = Opcode == X86ISD::VFCMADDC;
52261         return true;
52262       }
52263     }
52264     return false;
52265   };
52266 
52267   if (GetCFmulFrom(LHS))
52268     FAddOp1 = RHS;
52269   else if (GetCFmulFrom(RHS))
52270     FAddOp1 = LHS;
52271   else
52272     return SDValue();
52273 
52274   MVT CVT = MVT::getVectorVT(MVT::f32, VT.getVectorNumElements() / 2);
52275   FAddOp1 = DAG.getBitcast(CVT, FAddOp1);
52276   unsigned NewOp = IsConj ? X86ISD::VFCMADDC : X86ISD::VFMADDC;
52277   // FIXME: How do we handle when fast math flags of FADD are different from
52278   // CFMUL's?
52279   SDValue CFmul =
52280       DAG.getNode(NewOp, SDLoc(N), CVT, MulOp0, MulOp1, FAddOp1, N->getFlags());
52281   return DAG.getBitcast(VT, CFmul);
52282 }
52283 
52284 /// Do target-specific dag combines on floating-point adds/subs.
52285 static SDValue combineFaddFsub(SDNode *N, SelectionDAG &DAG,
52286                                const X86Subtarget &Subtarget) {
52287   if (SDValue HOp = combineToHorizontalAddSub(N, DAG, Subtarget))
52288     return HOp;
52289 
52290   if (SDValue COp = combineFaddCFmul(N, DAG, Subtarget))
52291     return COp;
52292 
52293   return SDValue();
52294 }
52295 
52296 static SDValue combineLRINT_LLRINT(SDNode *N, SelectionDAG &DAG,
52297                                    const X86Subtarget &Subtarget) {
52298   EVT VT = N->getValueType(0);
52299   SDValue Src = N->getOperand(0);
52300   EVT SrcVT = Src.getValueType();
52301   SDLoc DL(N);
52302 
52303   if (!Subtarget.hasDQI() || !Subtarget.hasVLX() || VT != MVT::v2i64 ||
52304       SrcVT != MVT::v2f32)
52305     return SDValue();
52306 
52307   return DAG.getNode(X86ISD::CVTP2SI, DL, VT,
52308                      DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v4f32, Src,
52309                                  DAG.getUNDEF(SrcVT)));
52310 }
52311 
52312 /// Attempt to pre-truncate inputs to arithmetic ops if it will simplify
52313 /// the codegen.
52314 /// e.g. TRUNC( BINOP( X, Y ) ) --> BINOP( TRUNC( X ), TRUNC( Y ) )
52315 /// TODO: This overlaps with the generic combiner's visitTRUNCATE. Remove
52316 ///       anything that is guaranteed to be transformed by DAGCombiner.
52317 static SDValue combineTruncatedArithmetic(SDNode *N, SelectionDAG &DAG,
52318                                           const X86Subtarget &Subtarget,
52319                                           const SDLoc &DL) {
52320   assert(N->getOpcode() == ISD::TRUNCATE && "Wrong opcode");
52321   SDValue Src = N->getOperand(0);
52322   unsigned SrcOpcode = Src.getOpcode();
52323   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
52324 
52325   EVT VT = N->getValueType(0);
52326   EVT SrcVT = Src.getValueType();
52327 
52328   auto IsFreeTruncation = [VT](SDValue Op) {
52329     unsigned TruncSizeInBits = VT.getScalarSizeInBits();
52330 
52331     // See if this has been extended from a smaller/equal size to
52332     // the truncation size, allowing a truncation to combine with the extend.
52333     unsigned Opcode = Op.getOpcode();
52334     if ((Opcode == ISD::ANY_EXTEND || Opcode == ISD::SIGN_EXTEND ||
52335          Opcode == ISD::ZERO_EXTEND) &&
52336         Op.getOperand(0).getScalarValueSizeInBits() <= TruncSizeInBits)
52337       return true;
52338 
52339     // See if this is a single use constant which can be constant folded.
52340     // NOTE: We don't peek through bitcasts here because there is currently
52341     // no support for constant folding truncate+bitcast+vector_of_constants, so
52342     // we'd just end up with a truncate on both operands, which would
52343     // get turned back into (truncate (binop)), causing an infinite loop.
52344     return ISD::isBuildVectorOfConstantSDNodes(Op.getNode());
52345   };
52346 
52347   auto TruncateArithmetic = [&](SDValue N0, SDValue N1) {
52348     SDValue Trunc0 = DAG.getNode(ISD::TRUNCATE, DL, VT, N0);
52349     SDValue Trunc1 = DAG.getNode(ISD::TRUNCATE, DL, VT, N1);
52350     return DAG.getNode(SrcOpcode, DL, VT, Trunc0, Trunc1);
52351   };
52352 
52353   // Don't combine if the operation has other uses.
52354   if (!Src.hasOneUse())
52355     return SDValue();
52356 
52357   // Only support vector truncation for now.
52358   // TODO: i64 scalar math would benefit as well.
52359   if (!VT.isVector())
52360     return SDValue();
52361 
52362   // In most cases it's only worth pre-truncating if we're only facing the cost
52363   // of one truncation.
52364   // i.e. if one of the inputs will constant fold or the input is repeated.
52365   switch (SrcOpcode) {
52366   case ISD::MUL:
52367     // X86 is rubbish at scalar and vector i64 multiplies (until AVX512DQ) - it's
52368     // better to truncate if we have the chance.
52369     if (SrcVT.getScalarType() == MVT::i64 &&
52370         TLI.isOperationLegal(SrcOpcode, VT) &&
52371         !TLI.isOperationLegal(SrcOpcode, SrcVT))
52372       return TruncateArithmetic(Src.getOperand(0), Src.getOperand(1));
52373     [[fallthrough]];
52374   case ISD::AND:
52375   case ISD::XOR:
52376   case ISD::OR:
52377   case ISD::ADD:
52378   case ISD::SUB: {
52379     SDValue Op0 = Src.getOperand(0);
52380     SDValue Op1 = Src.getOperand(1);
52381     if (TLI.isOperationLegal(SrcOpcode, VT) &&
52382         (Op0 == Op1 || IsFreeTruncation(Op0) || IsFreeTruncation(Op1)))
52383       return TruncateArithmetic(Op0, Op1);
52384     break;
52385   }
52386   }
52387 
52388   return SDValue();
52389 }
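// Editor's note: illustrative example of the pre-truncation above, where one
// operand is a constant build_vector so only a single real truncate remains:
//   (v4i32 trunc (add (v4i64 X), (v4i64 build_vector C0, C1, C2, C3)))
//     --> (add (v4i32 trunc X), (v4i32 trunc (build_vector ...)))
// and the truncate of the constant operand constant-folds away.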
52390 
52391 // Try to form a MULHU or MULHS node by looking for
52392 // (trunc (srl (mul ext, ext), 16))
52393 // TODO: This is X86 specific because we want to be able to handle wide types
52394 // before type legalization. But we can only do it if the vector will be
52395 // legalized via widening/splitting. Type legalization can't handle promotion
52396 // of a MULHU/MULHS. There isn't a way to convey this to the generic DAG
52397 // combiner.
52398 static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL,
52399                             SelectionDAG &DAG, const X86Subtarget &Subtarget) {
52400   // First instruction should be a right shift of a multiply.
52401   if (Src.getOpcode() != ISD::SRL ||
52402       Src.getOperand(0).getOpcode() != ISD::MUL)
52403     return SDValue();
52404 
52405   if (!Subtarget.hasSSE2())
52406     return SDValue();
52407 
52408   // Only handle vXi16 types that are at least 128-bits unless they will be
52409   // widened.
52410   if (!VT.isVector() || VT.getVectorElementType() != MVT::i16)
52411     return SDValue();
52412 
52413   // Input type should be at least vXi32.
52414   EVT InVT = Src.getValueType();
52415   if (InVT.getVectorElementType().getSizeInBits() < 32)
52416     return SDValue();
52417 
52418   // Need a shift by 16.
52419   APInt ShiftAmt;
52420   if (!ISD::isConstantSplatVector(Src.getOperand(1).getNode(), ShiftAmt) ||
52421       ShiftAmt != 16)
52422     return SDValue();
52423 
52424   SDValue LHS = Src.getOperand(0).getOperand(0);
52425   SDValue RHS = Src.getOperand(0).getOperand(1);
52426 
52427   // Count leading sign/zero bits on both inputs - if there are enough then
52428   // truncation back to vXi16 will be cheap - either as a pack/shuffle
52429   // sequence or using AVX512 truncations. If the inputs are sext/zext then the
52430   // truncations may actually be free by peeking through to the ext source.
52431   auto IsSext = [&DAG](SDValue V) {
52432     return DAG.ComputeMaxSignificantBits(V) <= 16;
52433   };
52434   auto IsZext = [&DAG](SDValue V) {
52435     return DAG.computeKnownBits(V).countMaxActiveBits() <= 16;
52436   };
52437 
52438   bool IsSigned = IsSext(LHS) && IsSext(RHS);
52439   bool IsUnsigned = IsZext(LHS) && IsZext(RHS);
52440   if (!IsSigned && !IsUnsigned)
52441     return SDValue();
52442 
52443   // Check if both inputs are extensions, which will be removed by truncation.
52444   bool IsTruncateFree = (LHS.getOpcode() == ISD::SIGN_EXTEND ||
52445                          LHS.getOpcode() == ISD::ZERO_EXTEND) &&
52446                         (RHS.getOpcode() == ISD::SIGN_EXTEND ||
52447                          RHS.getOpcode() == ISD::ZERO_EXTEND) &&
52448                         LHS.getOperand(0).getScalarValueSizeInBits() <= 16 &&
52449                         RHS.getOperand(0).getScalarValueSizeInBits() <= 16;
52450 
52451   // For AVX2+ targets, with the upper bits known zero, we can perform MULHU on
52452   // the (bitcasted) inputs directly, and then cheaply pack/truncate the result
52453   // (upper elts will be zero). Don't attempt this with just AVX512F as MULHU
52454   // will have to split anyway.
52455   unsigned InSizeInBits = InVT.getSizeInBits();
52456   if (IsUnsigned && !IsTruncateFree && Subtarget.hasInt256() &&
52457       !(Subtarget.hasAVX512() && !Subtarget.hasBWI() && VT.is256BitVector()) &&
52458       (InSizeInBits % 16) == 0) {
52459     EVT BCVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16,
52460                                 InVT.getSizeInBits() / 16);
52461     SDValue Res = DAG.getNode(ISD::MULHU, DL, BCVT, DAG.getBitcast(BCVT, LHS),
52462                               DAG.getBitcast(BCVT, RHS));
52463     return DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getBitcast(InVT, Res));
52464   }
52465 
52466   // Truncate back to source type.
52467   LHS = DAG.getNode(ISD::TRUNCATE, DL, VT, LHS);
52468   RHS = DAG.getNode(ISD::TRUNCATE, DL, VT, RHS);
52469 
52470   unsigned Opc = IsSigned ? ISD::MULHS : ISD::MULHU;
52471   return DAG.getNode(Opc, DL, VT, LHS, RHS);
52472 }
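// Editor's note: illustrative example of the basic pattern handled above:
//   (v8i16 trunc (srl (mul (zext v8i16 A to v8i32),
//                          (zext v8i16 B to v8i32)), 16))
//     --> (v8i16 mulhu A, B)                      // PMULHUW
// and the equivalent sign-extended form becomes mulhs (PMULHW).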
52473 
52474 // Attempt to match PMADDUBSW, which multiplies corresponding unsigned bytes
52475 // from one vector with signed bytes from another vector, adds together
52476 // adjacent pairs of 16-bit products, and saturates the result before
52477 // truncating to 16-bits.
52478 //
52479 // Which looks something like this:
52480 // (i16 (ssat (add (mul (zext (even elts (i8 A))), (sext (even elts (i8 B)))),
52481 //                 (mul (zext (odd elts (i8 A)), (sext (odd elts (i8 B))))))))
52482 static SDValue detectPMADDUBSW(SDValue In, EVT VT, SelectionDAG &DAG,
52483                                const X86Subtarget &Subtarget,
52484                                const SDLoc &DL) {
52485   if (!VT.isVector() || !Subtarget.hasSSSE3())
52486     return SDValue();
52487 
52488   unsigned NumElems = VT.getVectorNumElements();
52489   EVT ScalarVT = VT.getVectorElementType();
52490   if (ScalarVT != MVT::i16 || NumElems < 8 || !isPowerOf2_32(NumElems))
52491     return SDValue();
52492 
52493   SDValue SSatVal = detectSSatPattern(In, VT);
52494   if (!SSatVal || SSatVal.getOpcode() != ISD::ADD)
52495     return SDValue();
52496 
52497   // Ok this is a signed saturation of an ADD. See if this ADD is adding pairs
52498   // of multiplies from even/odd elements.
52499   SDValue N0 = SSatVal.getOperand(0);
52500   SDValue N1 = SSatVal.getOperand(1);
52501 
52502   if (N0.getOpcode() != ISD::MUL || N1.getOpcode() != ISD::MUL)
52503     return SDValue();
52504 
52505   SDValue N00 = N0.getOperand(0);
52506   SDValue N01 = N0.getOperand(1);
52507   SDValue N10 = N1.getOperand(0);
52508   SDValue N11 = N1.getOperand(1);
52509 
52510   // TODO: Handle constant vectors and use knownbits/computenumsignbits?
52511   // Canonicalize zero_extend to LHS.
52512   if (N01.getOpcode() == ISD::ZERO_EXTEND)
52513     std::swap(N00, N01);
52514   if (N11.getOpcode() == ISD::ZERO_EXTEND)
52515     std::swap(N10, N11);
52516 
52517   // Ensure we have a zero_extend and a sign_extend.
52518   if (N00.getOpcode() != ISD::ZERO_EXTEND ||
52519       N01.getOpcode() != ISD::SIGN_EXTEND ||
52520       N10.getOpcode() != ISD::ZERO_EXTEND ||
52521       N11.getOpcode() != ISD::SIGN_EXTEND)
52522     return SDValue();
52523 
52524   // Peek through the extends.
52525   N00 = N00.getOperand(0);
52526   N01 = N01.getOperand(0);
52527   N10 = N10.getOperand(0);
52528   N11 = N11.getOperand(0);
52529 
52530   // Ensure the extend is from vXi8.
52531   if (N00.getValueType().getVectorElementType() != MVT::i8 ||
52532       N01.getValueType().getVectorElementType() != MVT::i8 ||
52533       N10.getValueType().getVectorElementType() != MVT::i8 ||
52534       N11.getValueType().getVectorElementType() != MVT::i8)
52535     return SDValue();
52536 
52537   // All inputs should be build_vectors.
52538   if (N00.getOpcode() != ISD::BUILD_VECTOR ||
52539       N01.getOpcode() != ISD::BUILD_VECTOR ||
52540       N10.getOpcode() != ISD::BUILD_VECTOR ||
52541       N11.getOpcode() != ISD::BUILD_VECTOR)
52542     return SDValue();
52543 
52544   // N00/N10 are zero extended. N01/N11 are sign extended.
52545 
52546   // For each result element, we need the even element from one vector
52547   // multiplied by the even element from the other vector, added to the
52548   // product of the corresponding odd elements from the same two vectors.
52549   // So for each element i, we need to make sure this operation
52550   // is being performed:
52551   //  A[2 * i] * B[2 * i] + A[2 * i + 1] * B[2 * i + 1]
52552   SDValue ZExtIn, SExtIn;
52553   for (unsigned i = 0; i != NumElems; ++i) {
52554     SDValue N00Elt = N00.getOperand(i);
52555     SDValue N01Elt = N01.getOperand(i);
52556     SDValue N10Elt = N10.getOperand(i);
52557     SDValue N11Elt = N11.getOperand(i);
52558     // TODO: Be more tolerant to undefs.
52559     if (N00Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
52560         N01Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
52561         N10Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
52562         N11Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
52563       return SDValue();
52564     auto *ConstN00Elt = dyn_cast<ConstantSDNode>(N00Elt.getOperand(1));
52565     auto *ConstN01Elt = dyn_cast<ConstantSDNode>(N01Elt.getOperand(1));
52566     auto *ConstN10Elt = dyn_cast<ConstantSDNode>(N10Elt.getOperand(1));
52567     auto *ConstN11Elt = dyn_cast<ConstantSDNode>(N11Elt.getOperand(1));
52568     if (!ConstN00Elt || !ConstN01Elt || !ConstN10Elt || !ConstN11Elt)
52569       return SDValue();
52570     unsigned IdxN00 = ConstN00Elt->getZExtValue();
52571     unsigned IdxN01 = ConstN01Elt->getZExtValue();
52572     unsigned IdxN10 = ConstN10Elt->getZExtValue();
52573     unsigned IdxN11 = ConstN11Elt->getZExtValue();
52574     // Add is commutative so indices can be reordered.
52575     if (IdxN00 > IdxN10) {
52576       std::swap(IdxN00, IdxN10);
52577       std::swap(IdxN01, IdxN11);
52578     }
52579     // N0 indices must be the even element. N1 indices must be the next odd element.
52580     if (IdxN00 != 2 * i || IdxN10 != 2 * i + 1 ||
52581         IdxN01 != 2 * i || IdxN11 != 2 * i + 1)
52582       return SDValue();
52583     SDValue N00In = N00Elt.getOperand(0);
52584     SDValue N01In = N01Elt.getOperand(0);
52585     SDValue N10In = N10Elt.getOperand(0);
52586     SDValue N11In = N11Elt.getOperand(0);
52587     // First time we find an input capture it.
52588     if (!ZExtIn) {
52589       ZExtIn = N00In;
52590       SExtIn = N01In;
52591     }
52592     if (ZExtIn != N00In || SExtIn != N01In ||
52593         ZExtIn != N10In || SExtIn != N11In)
52594       return SDValue();
52595   }
52596 
52597   auto ExtractVec = [&DAG, &DL, NumElems](SDValue &Ext) {
52598     EVT ExtVT = Ext.getValueType();
52599     if (ExtVT.getVectorNumElements() != NumElems * 2) {
52600       MVT NVT = MVT::getVectorVT(MVT::i8, NumElems * 2);
52601       Ext = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NVT, Ext,
52602                         DAG.getIntPtrConstant(0, DL));
52603     }
52604   };
52605   ExtractVec(ZExtIn);
52606   ExtractVec(SExtIn);
52607 
52608   auto PMADDBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
52609                          ArrayRef<SDValue> Ops) {
52610     // Shrink by adding truncate nodes and let DAGCombine fold with the
52611     // sources.
52612     EVT InVT = Ops[0].getValueType();
52613     assert(InVT.getScalarType() == MVT::i8 &&
52614            "Unexpected scalar element type");
52615     assert(InVT == Ops[1].getValueType() && "Operands' types mismatch");
52616     EVT ResVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16,
52617                                  InVT.getVectorNumElements() / 2);
52618     return DAG.getNode(X86ISD::VPMADDUBSW, DL, ResVT, Ops[0], Ops[1]);
52619   };
52620   return SplitOpsAndApply(DAG, Subtarget, DL, VT, { ZExtIn, SExtIn },
52621                           PMADDBuilder);
52622 }
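// Editor's note: illustrative sketch of the match above for a v8i16 result.
// Given v16i8 inputs A (used zero-extended) and B (used sign-extended), each
// result element i is
//   ssat.i16( zext(A[2*i])   * sext(B[2*i]) +
//             zext(A[2*i+1]) * sext(B[2*i+1]) )
// which is exactly what X86ISD::VPMADDUBSW A, B computes.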
52623 
52624 static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG,
52625                                const X86Subtarget &Subtarget) {
52626   EVT VT = N->getValueType(0);
52627   SDValue Src = N->getOperand(0);
52628   SDLoc DL(N);
52629 
52630   // Attempt to pre-truncate inputs to arithmetic ops instead.
52631   if (SDValue V = combineTruncatedArithmetic(N, DAG, Subtarget, DL))
52632     return V;
52633 
52634   // Try to detect PMADD
52635   if (SDValue PMAdd = detectPMADDUBSW(Src, VT, DAG, Subtarget, DL))
52636     return PMAdd;
52637 
52638   // Try to combine truncation with signed/unsigned saturation.
52639   if (SDValue Val = combineTruncateWithSat(Src, VT, DL, DAG, Subtarget))
52640     return Val;
52641 
52642   // Try to combine PMULHUW/PMULHW for vXi16.
52643   if (SDValue V = combinePMULH(Src, VT, DL, DAG, Subtarget))
52644     return V;
52645 
52646   // The bitcast source is a direct mmx result.
52647   // Detect a truncation to i32 of a bitcast from x86mmx.
52648   if (Src.getOpcode() == ISD::BITCAST && VT == MVT::i32) {
52649     SDValue BCSrc = Src.getOperand(0);
52650     if (BCSrc.getValueType() == MVT::x86mmx)
52651       return DAG.getNode(X86ISD::MMX_MOVD2W, DL, MVT::i32, BCSrc);
52652   }
52653 
52654   // Try to combine (trunc (vNi64 (lrint x))) to (vNi32 (lrint x)).
52655   if (Src.getOpcode() == ISD::LRINT && VT.getScalarType() == MVT::i32 &&
52656       Src.hasOneUse())
52657     return DAG.getNode(ISD::LRINT, DL, VT, Src.getOperand(0));
52658 
52659   return SDValue();
52660 }
52661 
52662 static SDValue combineVTRUNC(SDNode *N, SelectionDAG &DAG,
52663                              TargetLowering::DAGCombinerInfo &DCI) {
52664   EVT VT = N->getValueType(0);
52665   SDValue In = N->getOperand(0);
52666   SDLoc DL(N);
52667 
52668   if (SDValue SSatVal = detectSSatPattern(In, VT))
52669     return DAG.getNode(X86ISD::VTRUNCS, DL, VT, SSatVal);
52670   if (SDValue USatVal = detectUSatPattern(In, VT, DAG, DL))
52671     return DAG.getNode(X86ISD::VTRUNCUS, DL, VT, USatVal);
52672 
52673   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
52674   APInt DemandedMask(APInt::getAllOnes(VT.getScalarSizeInBits()));
52675   if (TLI.SimplifyDemandedBits(SDValue(N, 0), DemandedMask, DCI))
52676     return SDValue(N, 0);
52677 
52678   return SDValue();
52679 }
52680 
52681 /// Returns the negated value if the node \p N flips sign of FP value.
52682 ///
52683 /// FP-negation node may have different forms: FNEG(x), FXOR (x, 0x80000000)
52684 /// or FSUB(0, x)
52685 /// AVX512F does not have FXOR, so FNEG is lowered as
52686 /// (bitcast (xor (bitcast x), (bitcast ConstantFP(0x80000000)))).
52687 /// In this case we go through all bitcasts.
52688 /// This also recognizes splat of a negated value and returns the splat of that
52689 /// value.
52690 static SDValue isFNEG(SelectionDAG &DAG, SDNode *N, unsigned Depth = 0) {
52691   if (N->getOpcode() == ISD::FNEG)
52692     return N->getOperand(0);
52693 
52694   // Don't recurse exponentially.
52695   if (Depth > SelectionDAG::MaxRecursionDepth)
52696     return SDValue();
52697 
52698   unsigned ScalarSize = N->getValueType(0).getScalarSizeInBits();
52699 
52700   SDValue Op = peekThroughBitcasts(SDValue(N, 0));
52701   EVT VT = Op->getValueType(0);
52702 
52703   // Make sure the element size doesn't change.
52704   if (VT.getScalarSizeInBits() != ScalarSize)
52705     return SDValue();
52706 
52707   unsigned Opc = Op.getOpcode();
52708   switch (Opc) {
52709   case ISD::VECTOR_SHUFFLE: {
52710     // For a VECTOR_SHUFFLE(VEC1, VEC2), if the VEC2 is undef, then the negate
52711     // of this is VECTOR_SHUFFLE(-VEC1, UNDEF).  The mask can be anything here.
52712     if (!Op.getOperand(1).isUndef())
52713       return SDValue();
52714     if (SDValue NegOp0 = isFNEG(DAG, Op.getOperand(0).getNode(), Depth + 1))
52715       if (NegOp0.getValueType() == VT) // FIXME: Can we do better?
52716         return DAG.getVectorShuffle(VT, SDLoc(Op), NegOp0, DAG.getUNDEF(VT),
52717                                     cast<ShuffleVectorSDNode>(Op)->getMask());
52718     break;
52719   }
52720   case ISD::INSERT_VECTOR_ELT: {
52721     // Negate of INSERT_VECTOR_ELT(UNDEF, V, INDEX) is INSERT_VECTOR_ELT(UNDEF,
52722     // -V, INDEX).
52723     SDValue InsVector = Op.getOperand(0);
52724     SDValue InsVal = Op.getOperand(1);
52725     if (!InsVector.isUndef())
52726       return SDValue();
52727     if (SDValue NegInsVal = isFNEG(DAG, InsVal.getNode(), Depth + 1))
52728       if (NegInsVal.getValueType() == VT.getVectorElementType()) // FIXME
52729         return DAG.getNode(ISD::INSERT_VECTOR_ELT, SDLoc(Op), VT, InsVector,
52730                            NegInsVal, Op.getOperand(2));
52731     break;
52732   }
52733   case ISD::FSUB:
52734   case ISD::XOR:
52735   case X86ISD::FXOR: {
52736     SDValue Op1 = Op.getOperand(1);
52737     SDValue Op0 = Op.getOperand(0);
52738 
52739     // For XOR and FXOR, we want to check if constant
52740     // bits of Op1 are sign bit masks. For FSUB, we
52741     // have to check if constant bits of Op0 are sign
52742     // bit masks and hence we swap the operands.
52743     if (Opc == ISD::FSUB)
52744       std::swap(Op0, Op1);
52745 
52746     APInt UndefElts;
52747     SmallVector<APInt, 16> EltBits;
52748     // Extract constant bits and see if they are all
52749     // sign bit masks. Ignore the undef elements.
52750     if (getTargetConstantBitsFromNode(Op1, ScalarSize, UndefElts, EltBits,
52751                                       /* AllowWholeUndefs */ true,
52752                                       /* AllowPartialUndefs */ false)) {
52753       for (unsigned I = 0, E = EltBits.size(); I < E; I++)
52754         if (!UndefElts[I] && !EltBits[I].isSignMask())
52755           return SDValue();
52756 
52757       // Only allow bitcast from correctly-sized constant.
52758       Op0 = peekThroughBitcasts(Op0);
52759       if (Op0.getScalarValueSizeInBits() == ScalarSize)
52760         return Op0;
52761     }
52762     break;
52763   } // case
52764   } // switch
52765 
52766   return SDValue();
52767 }
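// Editor's note: illustrative summary of the forms recognized above, each of
// which yields X as the value being negated:
//   (fneg X)
//   (fsub -0.0, X)
//   (xor (bitcast X), (splat 0x80000000))        // f32 sign-bit mask
//   (X86ISD::FXOR X, (sign-mask constant))
// plus vector shuffles and element inserts of such values, which are rebuilt
// around the negated operand.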
52768 
52769 static unsigned negateFMAOpcode(unsigned Opcode, bool NegMul, bool NegAcc,
52770                                 bool NegRes) {
52771   if (NegMul) {
52772     switch (Opcode) {
52773     // clang-format off
52774     default: llvm_unreachable("Unexpected opcode");
52775     case ISD::FMA:              Opcode = X86ISD::FNMADD;        break;
52776     case ISD::STRICT_FMA:       Opcode = X86ISD::STRICT_FNMADD; break;
52777     case X86ISD::FMADD_RND:     Opcode = X86ISD::FNMADD_RND;    break;
52778     case X86ISD::FMSUB:         Opcode = X86ISD::FNMSUB;        break;
52779     case X86ISD::STRICT_FMSUB:  Opcode = X86ISD::STRICT_FNMSUB; break;
52780     case X86ISD::FMSUB_RND:     Opcode = X86ISD::FNMSUB_RND;    break;
52781     case X86ISD::FNMADD:        Opcode = ISD::FMA;              break;
52782     case X86ISD::STRICT_FNMADD: Opcode = ISD::STRICT_FMA;       break;
52783     case X86ISD::FNMADD_RND:    Opcode = X86ISD::FMADD_RND;     break;
52784     case X86ISD::FNMSUB:        Opcode = X86ISD::FMSUB;         break;
52785     case X86ISD::STRICT_FNMSUB: Opcode = X86ISD::STRICT_FMSUB;  break;
52786     case X86ISD::FNMSUB_RND:    Opcode = X86ISD::FMSUB_RND;     break;
52787     // clang-format on
52788     }
52789   }
52790 
52791   if (NegAcc) {
52792     switch (Opcode) {
52793     // clang-format off
52794     default: llvm_unreachable("Unexpected opcode");
52795     case ISD::FMA:              Opcode = X86ISD::FMSUB;         break;
52796     case ISD::STRICT_FMA:       Opcode = X86ISD::STRICT_FMSUB;  break;
52797     case X86ISD::FMADD_RND:     Opcode = X86ISD::FMSUB_RND;     break;
52798     case X86ISD::FMSUB:         Opcode = ISD::FMA;              break;
52799     case X86ISD::STRICT_FMSUB:  Opcode = ISD::STRICT_FMA;       break;
52800     case X86ISD::FMSUB_RND:     Opcode = X86ISD::FMADD_RND;     break;
52801     case X86ISD::FNMADD:        Opcode = X86ISD::FNMSUB;        break;
52802     case X86ISD::STRICT_FNMADD: Opcode = X86ISD::STRICT_FNMSUB; break;
52803     case X86ISD::FNMADD_RND:    Opcode = X86ISD::FNMSUB_RND;    break;
52804     case X86ISD::FNMSUB:        Opcode = X86ISD::FNMADD;        break;
52805     case X86ISD::STRICT_FNMSUB: Opcode = X86ISD::STRICT_FNMADD; break;
52806     case X86ISD::FNMSUB_RND:    Opcode = X86ISD::FNMADD_RND;    break;
52807     case X86ISD::FMADDSUB:      Opcode = X86ISD::FMSUBADD;      break;
52808     case X86ISD::FMADDSUB_RND:  Opcode = X86ISD::FMSUBADD_RND;  break;
52809     case X86ISD::FMSUBADD:      Opcode = X86ISD::FMADDSUB;      break;
52810     case X86ISD::FMSUBADD_RND:  Opcode = X86ISD::FMADDSUB_RND;  break;
52811     // clang-format on
52812     }
52813   }
52814 
52815   if (NegRes) {
52816     switch (Opcode) {
52817     // For accuracy reasons, we never combine fneg and fma under strict FP.
52818     // clang-format off
52819     default: llvm_unreachable("Unexpected opcode");
52820     case ISD::FMA:             Opcode = X86ISD::FNMSUB;       break;
52821     case X86ISD::FMADD_RND:    Opcode = X86ISD::FNMSUB_RND;   break;
52822     case X86ISD::FMSUB:        Opcode = X86ISD::FNMADD;       break;
52823     case X86ISD::FMSUB_RND:    Opcode = X86ISD::FNMADD_RND;   break;
52824     case X86ISD::FNMADD:       Opcode = X86ISD::FMSUB;        break;
52825     case X86ISD::FNMADD_RND:   Opcode = X86ISD::FMSUB_RND;    break;
52826     case X86ISD::FNMSUB:       Opcode = ISD::FMA;             break;
52827     case X86ISD::FNMSUB_RND:   Opcode = X86ISD::FMADD_RND;    break;
52828     // clang-format on
52829     }
52830   }
52831 
52832   return Opcode;
52833 }
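// Editor's note: a few worked examples of the opcode mapping above:
//   negateFMAOpcode(ISD::FMA, /*NegMul=*/true,  /*NegAcc=*/false, /*NegRes=*/false)
//       == X86ISD::FNMADD    //  (-a*b) + c
//   negateFMAOpcode(ISD::FMA, /*NegMul=*/false, /*NegAcc=*/true,  /*NegRes=*/false)
//       == X86ISD::FMSUB     //   (a*b) - c
//   negateFMAOpcode(ISD::FMA, /*NegMul=*/false, /*NegAcc=*/false, /*NegRes=*/true)
//       == X86ISD::FNMSUB    // -((a*b) + c) == (-a*b) - c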
52834 
52835 /// Do target-specific dag combines on floating point negations.
52836 static SDValue combineFneg(SDNode *N, SelectionDAG &DAG,
52837                            TargetLowering::DAGCombinerInfo &DCI,
52838                            const X86Subtarget &Subtarget) {
52839   EVT OrigVT = N->getValueType(0);
52840   SDValue Arg = isFNEG(DAG, N);
52841   if (!Arg)
52842     return SDValue();
52843 
52844   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
52845   EVT VT = Arg.getValueType();
52846   EVT SVT = VT.getScalarType();
52847   SDLoc DL(N);
52848 
52849   // Let legalize expand this if it isn't a legal type yet.
52850   if (!TLI.isTypeLegal(VT))
52851     return SDValue();
52852 
52853   // If we're negating a FMUL node on a target with FMA, then we can avoid the
52854   // use of a constant by performing (-0 - A*B) instead.
52855   // FIXME: Check rounding control flags as well once it becomes available.
52856   if (Arg.getOpcode() == ISD::FMUL && (SVT == MVT::f32 || SVT == MVT::f64) &&
52857       Arg->getFlags().hasNoSignedZeros() && Subtarget.hasAnyFMA()) {
52858     SDValue Zero = DAG.getConstantFP(0.0, DL, VT);
52859     SDValue NewNode = DAG.getNode(X86ISD::FNMSUB, DL, VT, Arg.getOperand(0),
52860                                   Arg.getOperand(1), Zero);
52861     return DAG.getBitcast(OrigVT, NewNode);
52862   }
52863 
52864   bool CodeSize = DAG.getMachineFunction().getFunction().hasOptSize();
52865   bool LegalOperations = !DCI.isBeforeLegalizeOps();
52866   if (SDValue NegArg =
52867           TLI.getNegatedExpression(Arg, DAG, LegalOperations, CodeSize))
52868     return DAG.getBitcast(OrigVT, NegArg);
52869 
52870   return SDValue();
52871 }
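// Editor's note: illustrative example of the FMA special case above. With
// no-signed-zeros and FMA available,
//   (fneg (fmul A, B))  -->  (X86ISD::FNMSUB A, B, (fp 0.0))   // -(A*B) - 0.0
// which avoids materializing a sign-bit mask constant for an XOR-based fneg.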
52872 
52873 SDValue X86TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG,
52874                                                 bool LegalOperations,
52875                                                 bool ForCodeSize,
52876                                                 NegatibleCost &Cost,
52877                                                 unsigned Depth) const {
52878   // fneg patterns are removable even if they have multiple uses.
52879   if (SDValue Arg = isFNEG(DAG, Op.getNode(), Depth)) {
52880     Cost = NegatibleCost::Cheaper;
52881     return DAG.getBitcast(Op.getValueType(), Arg);
52882   }
52883 
52884   EVT VT = Op.getValueType();
52885   EVT SVT = VT.getScalarType();
52886   unsigned Opc = Op.getOpcode();
52887   SDNodeFlags Flags = Op.getNode()->getFlags();
52888   switch (Opc) {
52889   case ISD::FMA:
52890   case X86ISD::FMSUB:
52891   case X86ISD::FNMADD:
52892   case X86ISD::FNMSUB:
52893   case X86ISD::FMADD_RND:
52894   case X86ISD::FMSUB_RND:
52895   case X86ISD::FNMADD_RND:
52896   case X86ISD::FNMSUB_RND: {
52897     if (!Op.hasOneUse() || !Subtarget.hasAnyFMA() || !isTypeLegal(VT) ||
52898         !(SVT == MVT::f32 || SVT == MVT::f64) ||
52899         !isOperationLegal(ISD::FMA, VT))
52900       break;
52901 
52902     // Don't fold (fneg (fma (fneg x), y, (fneg z))) to (fma x, y, z)
52903     // if it may have signed zeros.
52904     if (!Flags.hasNoSignedZeros())
52905       break;
52906 
52907     // This is always negatible for free but we might be able to remove some
52908     // extra operand negations as well.
52909     SmallVector<SDValue, 4> NewOps(Op.getNumOperands(), SDValue());
52910     for (int i = 0; i != 3; ++i)
52911       NewOps[i] = getCheaperNegatedExpression(
52912           Op.getOperand(i), DAG, LegalOperations, ForCodeSize, Depth + 1);
52913 
52914     bool NegA = !!NewOps[0];
52915     bool NegB = !!NewOps[1];
52916     bool NegC = !!NewOps[2];
52917     unsigned NewOpc = negateFMAOpcode(Opc, NegA != NegB, NegC, true);
52918 
52919     Cost = (NegA || NegB || NegC) ? NegatibleCost::Cheaper
52920                                   : NegatibleCost::Neutral;
52921 
52922     // Fill in the non-negated ops with the original values.
52923     for (int i = 0, e = Op.getNumOperands(); i != e; ++i)
52924       if (!NewOps[i])
52925         NewOps[i] = Op.getOperand(i);
52926     return DAG.getNode(NewOpc, SDLoc(Op), VT, NewOps);
52927   }
52928   case X86ISD::FRCP:
52929     if (SDValue NegOp0 =
52930             getNegatedExpression(Op.getOperand(0), DAG, LegalOperations,
52931                                  ForCodeSize, Cost, Depth + 1))
52932       return DAG.getNode(Opc, SDLoc(Op), VT, NegOp0);
52933     break;
52934   }
52935 
52936   return TargetLowering::getNegatedExpression(Op, DAG, LegalOperations,
52937                                               ForCodeSize, Cost, Depth);
52938 }
52939 
52940 static SDValue lowerX86FPLogicOp(SDNode *N, SelectionDAG &DAG,
52941                                  const X86Subtarget &Subtarget) {
52942   MVT VT = N->getSimpleValueType(0);
52943   // If we have integer vector types available, use the integer opcodes.
52944   if (!VT.isVector() || !Subtarget.hasSSE2())
52945     return SDValue();
52946 
52947   SDLoc dl(N);
52948 
52949   unsigned IntBits = VT.getScalarSizeInBits();
52950   MVT IntSVT = MVT::getIntegerVT(IntBits);
52951   MVT IntVT = MVT::getVectorVT(IntSVT, VT.getSizeInBits() / IntBits);
52952 
52953   SDValue Op0 = DAG.getBitcast(IntVT, N->getOperand(0));
52954   SDValue Op1 = DAG.getBitcast(IntVT, N->getOperand(1));
52955   unsigned IntOpcode;
52956   switch (N->getOpcode()) {
52957   // clang-format off
52958   default: llvm_unreachable("Unexpected FP logic op");
52959   case X86ISD::FOR:   IntOpcode = ISD::OR; break;
52960   case X86ISD::FXOR:  IntOpcode = ISD::XOR; break;
52961   case X86ISD::FAND:  IntOpcode = ISD::AND; break;
52962   case X86ISD::FANDN: IntOpcode = X86ISD::ANDNP; break;
52963   // clang-format on
52964   }
52965   SDValue IntOp = DAG.getNode(IntOpcode, dl, IntVT, Op0, Op1);
52966   return DAG.getBitcast(VT, IntOp);
52967 }
52968 
52969 
52970 /// Fold a xor(setcc cond, val), 1 --> setcc (inverted(cond), val)
52971 static SDValue foldXor1SetCC(SDNode *N, SelectionDAG &DAG) {
52972   if (N->getOpcode() != ISD::XOR)
52973     return SDValue();
52974 
52975   SDValue LHS = N->getOperand(0);
52976   if (!isOneConstant(N->getOperand(1)) || LHS->getOpcode() != X86ISD::SETCC)
52977     return SDValue();
52978 
52979   X86::CondCode NewCC = X86::GetOppositeBranchCondition(
52980       X86::CondCode(LHS->getConstantOperandVal(0)));
52981   SDLoc DL(N);
52982   return getSETCC(NewCC, LHS->getOperand(1), DL, DAG);
52983 }
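// Editor's note: illustrative example of the fold above:
//   (xor (X86ISD::SETCC COND_E, EFLAGS), 1)
//     --> (X86ISD::SETCC COND_NE, EFLAGS)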
52984 
52985 static SDValue combineXorSubCTLZ(SDNode *N, const SDLoc &DL, SelectionDAG &DAG,
52986                                  const X86Subtarget &Subtarget) {
52987   assert((N->getOpcode() == ISD::XOR || N->getOpcode() == ISD::SUB) &&
52988          "Invalid opcode for combing with CTLZ");
52989   if (Subtarget.hasFastLZCNT())
52990     return SDValue();
52991 
52992   EVT VT = N->getValueType(0);
52993   if (VT != MVT::i8 && VT != MVT::i16 && VT != MVT::i32 &&
52994       (VT != MVT::i64 || !Subtarget.is64Bit()))
52995     return SDValue();
52996 
52997   SDValue N0 = N->getOperand(0);
52998   SDValue N1 = N->getOperand(1);
52999 
53000   if (N0.getOpcode() != ISD::CTLZ_ZERO_UNDEF &&
53001       N1.getOpcode() != ISD::CTLZ_ZERO_UNDEF)
53002     return SDValue();
53003 
53004   SDValue OpCTLZ;
53005   SDValue OpSizeTM1;
53006 
53007   if (N1.getOpcode() == ISD::CTLZ_ZERO_UNDEF) {
53008     OpCTLZ = N1;
53009     OpSizeTM1 = N0;
53010   } else if (N->getOpcode() == ISD::SUB) {
53011     return SDValue();
53012   } else {
53013     OpCTLZ = N0;
53014     OpSizeTM1 = N1;
53015   }
53016 
53017   if (!OpCTLZ.hasOneUse())
53018     return SDValue();
53019   auto *C = dyn_cast<ConstantSDNode>(OpSizeTM1);
53020   if (!C)
53021     return SDValue();
53022 
53023   if (C->getZExtValue() != uint64_t(OpCTLZ.getValueSizeInBits() - 1))
53024     return SDValue();
53025   EVT OpVT = VT;
53026   SDValue Op = OpCTLZ.getOperand(0);
53027   if (VT == MVT::i8) {
53028     // Zero extend to i32 since there is no i8 bsr.
53029     OpVT = MVT::i32;
53030     Op = DAG.getNode(ISD::ZERO_EXTEND, DL, OpVT, Op);
53031   }
53032 
53033   SDVTList VTs = DAG.getVTList(OpVT, MVT::i32);
53034   Op = DAG.getNode(X86ISD::BSR, DL, VTs, Op);
53035   if (VT == MVT::i8)
53036     Op = DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, Op);
53037 
53038   return Op;
53039 }
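// Editor's note: the reasoning behind the combine above, sketched for i32.
// For non-zero X, ctlz_zero_undef(X) == 31 - bsr(X) and lies in [0, 31], so
// XORing it with 31 is the same as subtracting it from 31. Hence both
//   (xor (ctlz_zero_undef X), 31)   and   (sub 31, (ctlz_zero_undef X))
// collapse to a single (X86ISD::BSR X).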
53040 
53041 static SDValue combineXor(SDNode *N, SelectionDAG &DAG,
53042                           TargetLowering::DAGCombinerInfo &DCI,
53043                           const X86Subtarget &Subtarget) {
53044   SDValue N0 = N->getOperand(0);
53045   SDValue N1 = N->getOperand(1);
53046   EVT VT = N->getValueType(0);
53047   SDLoc DL(N);
53048 
53049   // If this is SSE1 only convert to FXOR to avoid scalarization.
53050   if (Subtarget.hasSSE1() && !Subtarget.hasSSE2() && VT == MVT::v4i32) {
53051     return DAG.getBitcast(MVT::v4i32,
53052                           DAG.getNode(X86ISD::FXOR, DL, MVT::v4f32,
53053                                       DAG.getBitcast(MVT::v4f32, N0),
53054                                       DAG.getBitcast(MVT::v4f32, N1)));
53055   }
53056 
53057   if (SDValue Cmp = foldVectorXorShiftIntoCmp(N, DAG, Subtarget))
53058     return Cmp;
53059 
53060   if (SDValue R = combineBitOpWithMOVMSK(N, DAG))
53061     return R;
53062 
53063   if (SDValue R = combineBitOpWithShift(N, DAG))
53064     return R;
53065 
53066   if (SDValue R = combineBitOpWithPACK(N, DAG))
53067     return R;
53068 
53069   if (SDValue FPLogic = convertIntLogicToFPLogic(N, DAG, DCI, Subtarget))
53070     return FPLogic;
53071 
53072   if (SDValue R = combineXorSubCTLZ(N, DL, DAG, Subtarget))
53073     return R;
53074 
53075   if (DCI.isBeforeLegalizeOps())
53076     return SDValue();
53077 
53078   if (SDValue SetCC = foldXor1SetCC(N, DAG))
53079     return SetCC;
53080 
53081   if (SDValue R = combineOrXorWithSETCC(N, N0, N1, DAG))
53082     return R;
53083 
53084   if (SDValue RV = foldXorTruncShiftIntoCmp(N, DAG))
53085     return RV;
53086 
53087   // Fold not(iX bitcast(vXi1)) -> (iX bitcast(not(vec))) for legal boolvecs.
53088   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
53089   if (llvm::isAllOnesConstant(N1) && N0.getOpcode() == ISD::BITCAST &&
53090       N0.getOperand(0).getValueType().isVector() &&
53091       N0.getOperand(0).getValueType().getVectorElementType() == MVT::i1 &&
53092       TLI.isTypeLegal(N0.getOperand(0).getValueType()) && N0.hasOneUse()) {
53093     return DAG.getBitcast(
53094         VT, DAG.getNOT(DL, N0.getOperand(0), N0.getOperand(0).getValueType()));
53095   }
53096 
53097   // Handle AVX512 mask widening.
53098   // Fold not(insert_subvector(undef,sub)) -> insert_subvector(undef,not(sub))
53099   if (ISD::isBuildVectorAllOnes(N1.getNode()) && VT.isVector() &&
53100       VT.getVectorElementType() == MVT::i1 &&
53101       N0.getOpcode() == ISD::INSERT_SUBVECTOR && N0.getOperand(0).isUndef() &&
53102       TLI.isTypeLegal(N0.getOperand(1).getValueType())) {
53103     return DAG.getNode(
53104         ISD::INSERT_SUBVECTOR, DL, VT, N0.getOperand(0),
53105         DAG.getNOT(DL, N0.getOperand(1), N0.getOperand(1).getValueType()),
53106         N0.getOperand(2));
53107   }
53108 
53109   // Fold xor(zext(xor(x,c1)),c2) -> xor(zext(x),xor(zext(c1),c2))
53110   // Fold xor(truncate(xor(x,c1)),c2) -> xor(truncate(x),xor(truncate(c1),c2))
53111   // TODO: Under what circumstances could this be performed in DAGCombine?
53112   if ((N0.getOpcode() == ISD::TRUNCATE || N0.getOpcode() == ISD::ZERO_EXTEND) &&
53113       N0.getOperand(0).getOpcode() == N->getOpcode()) {
53114     SDValue TruncExtSrc = N0.getOperand(0);
53115     auto *N1C = dyn_cast<ConstantSDNode>(N1);
53116     auto *N001C = dyn_cast<ConstantSDNode>(TruncExtSrc.getOperand(1));
53117     if (N1C && !N1C->isOpaque() && N001C && !N001C->isOpaque()) {
53118       SDValue LHS = DAG.getZExtOrTrunc(TruncExtSrc.getOperand(0), DL, VT);
53119       SDValue RHS = DAG.getZExtOrTrunc(TruncExtSrc.getOperand(1), DL, VT);
53120       return DAG.getNode(ISD::XOR, DL, VT, LHS,
53121                          DAG.getNode(ISD::XOR, DL, VT, RHS, N1));
53122     }
53123   }
53124 
53125   if (SDValue R = combineBMILogicOp(N, DAG, Subtarget))
53126     return R;
53127 
53128   return combineFneg(N, DAG, DCI, Subtarget);
53129 }
53130 
53131 static SDValue combineBITREVERSE(SDNode *N, SelectionDAG &DAG,
53132                                  TargetLowering::DAGCombinerInfo &DCI,
53133                                  const X86Subtarget &Subtarget) {
53134   SDValue N0 = N->getOperand(0);
53135   EVT VT = N->getValueType(0);
53136 
53137   // Convert a (iX bitreverse(bitcast(vXi1 X))) -> (iX bitcast(shuffle(X)))
53138   if (VT.isInteger() && N0.getOpcode() == ISD::BITCAST && N0.hasOneUse()) {
53139     SDValue Src = N0.getOperand(0);
53140     EVT SrcVT = Src.getValueType();
53141     if (SrcVT.isVector() && SrcVT.getScalarType() == MVT::i1 &&
53142         (DCI.isBeforeLegalize() ||
53143          DAG.getTargetLoweringInfo().isTypeLegal(SrcVT)) &&
53144         Subtarget.hasSSSE3()) {
53145       unsigned NumElts = SrcVT.getVectorNumElements();
53146       SmallVector<int, 32> ReverseMask(NumElts);
53147       for (unsigned I = 0; I != NumElts; ++I)
53148         ReverseMask[I] = (NumElts - 1) - I;
53149       SDValue Rev =
53150           DAG.getVectorShuffle(SrcVT, SDLoc(N), Src, Src, ReverseMask);
53151       return DAG.getBitcast(VT, Rev);
53152     }
53153   }
53154 
53155   return SDValue();
53156 }
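// Editor's note: illustrative example of the combine above. For
//   (i16 bitreverse (bitcast (v16i1 X)))
// reversing the 16 bits of the integer is the same as reversing the order of
// the 16 mask elements, so this becomes
//   (bitcast (vector_shuffle X, X, <15,14,...,1,0>))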
53157 
53158 // Various combines to try to convert to avgceilu.
53159 static SDValue combineAVG(SDNode *N, SelectionDAG &DAG,
53160                           TargetLowering::DAGCombinerInfo &DCI,
53161                           const X86Subtarget &Subtarget) {
53162   unsigned Opcode = N->getOpcode();
53163   SDValue N0 = N->getOperand(0);
53164   SDValue N1 = N->getOperand(1);
53165   EVT VT = N->getValueType(0);
53166   EVT SVT = VT.getScalarType();
53167   SDLoc DL(N);
53168 
53169   // avgceils(x,y) -> flipsign(avgceilu(flipsign(x),flipsign(y)))
53170   // Only useful on vXi8 which doesn't have good SRA handling.
53171   if (Opcode == ISD::AVGCEILS && VT.isVector() && SVT == MVT::i8) {
53172     APInt SignBit = APInt::getSignMask(VT.getScalarSizeInBits());
53173     SDValue SignMask = DAG.getConstant(SignBit, DL, VT);
53174     N0 = DAG.getNode(ISD::XOR, DL, VT, N0, SignMask);
53175     N1 = DAG.getNode(ISD::XOR, DL, VT, N1, SignMask);
53176     return DAG.getNode(ISD::XOR, DL, VT,
53177                        DAG.getNode(ISD::AVGCEILU, DL, VT, N0, N1), SignMask);
53178   }
53179 
53180   return SDValue();
53181 }
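// Editor's note: a worked example of the sign-flip trick above on i8 lanes.
// XORing with 0x80 maps signed values onto unsigned ones order-preservingly,
// so avgceils(x, y) == avgceilu(x ^ 0x80, y ^ 0x80) ^ 0x80. For x = -1, y = 2:
//   avgceilu(0x7F, 0x82) == 0x81, and 0x81 ^ 0x80 == 0x01 == avgceils(-1, 2).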
53182 
53183 static SDValue combineBEXTR(SDNode *N, SelectionDAG &DAG,
53184                             TargetLowering::DAGCombinerInfo &DCI,
53185                             const X86Subtarget &Subtarget) {
53186   EVT VT = N->getValueType(0);
53187   unsigned NumBits = VT.getSizeInBits();
53188 
53189   // TODO - Constant Folding.
53190 
53191   // Simplify the inputs.
53192   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
53193   APInt DemandedMask(APInt::getAllOnes(NumBits));
53194   if (TLI.SimplifyDemandedBits(SDValue(N, 0), DemandedMask, DCI))
53195     return SDValue(N, 0);
53196 
53197   return SDValue();
53198 }
53199 
53200 static bool isNullFPScalarOrVectorConst(SDValue V) {
53201   return isNullFPConstant(V) || ISD::isBuildVectorAllZeros(V.getNode());
53202 }
53203 
53204 /// If a value is a scalar FP zero or a vector FP zero (potentially including
53205 /// undefined elements), return a zero constant that may be used to fold away
53206 /// that value. In the case of a vector, the returned constant will not contain
53207 /// undefined elements even if the input parameter does. This makes it suitable
53208 /// to be used as a replacement operand with operations (eg, bitwise-and) where
53209 /// an undef should not propagate.
53210 static SDValue getNullFPConstForNullVal(SDValue V, SelectionDAG &DAG,
53211                                         const X86Subtarget &Subtarget) {
53212   if (!isNullFPScalarOrVectorConst(V))
53213     return SDValue();
53214 
53215   if (V.getValueType().isVector())
53216     return getZeroVector(V.getSimpleValueType(), Subtarget, DAG, SDLoc(V));
53217 
53218   return V;
53219 }
53220 
53221 static SDValue combineFAndFNotToFAndn(SDNode *N, SelectionDAG &DAG,
53222                                       const X86Subtarget &Subtarget) {
53223   SDValue N0 = N->getOperand(0);
53224   SDValue N1 = N->getOperand(1);
53225   EVT VT = N->getValueType(0);
53226   SDLoc DL(N);
53227 
53228   // Vector types are handled in combineANDXORWithAllOnesIntoANDNP().
53229   if (!((VT == MVT::f32 && Subtarget.hasSSE1()) ||
53230         (VT == MVT::f64 && Subtarget.hasSSE2()) ||
53231         (VT == MVT::v4f32 && Subtarget.hasSSE1() && !Subtarget.hasSSE2())))
53232     return SDValue();
53233 
53234   auto isAllOnesConstantFP = [](SDValue V) {
53235     if (V.getSimpleValueType().isVector())
53236       return ISD::isBuildVectorAllOnes(V.getNode());
53237     auto *C = dyn_cast<ConstantFPSDNode>(V);
53238     return C && C->getConstantFPValue()->isAllOnesValue();
53239   };
53240 
53241   // fand (fxor X, -1), Y --> fandn X, Y
53242   if (N0.getOpcode() == X86ISD::FXOR && isAllOnesConstantFP(N0.getOperand(1)))
53243     return DAG.getNode(X86ISD::FANDN, DL, VT, N0.getOperand(0), N1);
53244 
53245   // fand X, (fxor Y, -1) --> fandn Y, X
53246   if (N1.getOpcode() == X86ISD::FXOR && isAllOnesConstantFP(N1.getOperand(1)))
53247     return DAG.getNode(X86ISD::FANDN, DL, VT, N1.getOperand(0), N0);
53248 
53249   return SDValue();
53250 }
53251 
53252 /// Do target-specific dag combines on X86ISD::FAND nodes.
53253 static SDValue combineFAnd(SDNode *N, SelectionDAG &DAG,
53254                            const X86Subtarget &Subtarget) {
53255   // FAND(0.0, x) -> 0.0
53256   if (SDValue V = getNullFPConstForNullVal(N->getOperand(0), DAG, Subtarget))
53257     return V;
53258 
53259   // FAND(x, 0.0) -> 0.0
53260   if (SDValue V = getNullFPConstForNullVal(N->getOperand(1), DAG, Subtarget))
53261     return V;
53262 
53263   if (SDValue V = combineFAndFNotToFAndn(N, DAG, Subtarget))
53264     return V;
53265 
53266   return lowerX86FPLogicOp(N, DAG, Subtarget);
53267 }
53268 
53269 /// Do target-specific dag combines on X86ISD::FANDN nodes.
53270 static SDValue combineFAndn(SDNode *N, SelectionDAG &DAG,
53271                             const X86Subtarget &Subtarget) {
53272   // FANDN(0.0, x) -> x
53273   if (isNullFPScalarOrVectorConst(N->getOperand(0)))
53274     return N->getOperand(1);
53275 
53276   // FANDN(x, 0.0) -> 0.0
53277   if (SDValue V = getNullFPConstForNullVal(N->getOperand(1), DAG, Subtarget))
53278     return V;
53279 
53280   return lowerX86FPLogicOp(N, DAG, Subtarget);
53281 }
53282 
53283 /// Do target-specific dag combines on X86ISD::FOR and X86ISD::FXOR nodes.
53284 static SDValue combineFOr(SDNode *N, SelectionDAG &DAG,
53285                           TargetLowering::DAGCombinerInfo &DCI,
53286                           const X86Subtarget &Subtarget) {
53287   assert(N->getOpcode() == X86ISD::FOR || N->getOpcode() == X86ISD::FXOR);
53288 
53289   // F[X]OR(0.0, x) -> x
53290   if (isNullFPScalarOrVectorConst(N->getOperand(0)))
53291     return N->getOperand(1);
53292 
53293   // F[X]OR(x, 0.0) -> x
53294   if (isNullFPScalarOrVectorConst(N->getOperand(1)))
53295     return N->getOperand(0);
53296 
53297   if (SDValue NewVal = combineFneg(N, DAG, DCI, Subtarget))
53298     return NewVal;
53299 
53300   return lowerX86FPLogicOp(N, DAG, Subtarget);
53301 }
53302 
53303 /// Do target-specific dag combines on X86ISD::FMIN and X86ISD::FMAX nodes.
53304 static SDValue combineFMinFMax(SDNode *N, SelectionDAG &DAG) {
53305   assert(N->getOpcode() == X86ISD::FMIN || N->getOpcode() == X86ISD::FMAX);
53306 
53307   // FMIN/FMAX are commutative if no NaNs and no negative zeros are allowed.
53308   if (!DAG.getTarget().Options.NoNaNsFPMath ||
53309       !DAG.getTarget().Options.NoSignedZerosFPMath)
53310     return SDValue();
53311 
53312   // If we run in unsafe-math mode, then convert the FMIN and FMAX nodes
53313   // into FMINC and FMAXC, which are commutative operations.
53314   unsigned NewOp = 0;
53315   switch (N->getOpcode()) {
53316     default: llvm_unreachable("unknown opcode");
53317     case X86ISD::FMIN:  NewOp = X86ISD::FMINC; break;
53318     case X86ISD::FMAX:  NewOp = X86ISD::FMAXC; break;
53319   }
53320 
53321   return DAG.getNode(NewOp, SDLoc(N), N->getValueType(0),
53322                      N->getOperand(0), N->getOperand(1));
53323 }
53324 
53325 static SDValue combineFMinNumFMaxNum(SDNode *N, SelectionDAG &DAG,
53326                                      const X86Subtarget &Subtarget) {
53327   EVT VT = N->getValueType(0);
53328   if (Subtarget.useSoftFloat() || isSoftF16(VT, Subtarget))
53329     return SDValue();
53330 
53331   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
53332 
53333   if (!((Subtarget.hasSSE1() && VT == MVT::f32) ||
53334         (Subtarget.hasSSE2() && VT == MVT::f64) ||
53335         (Subtarget.hasFP16() && VT == MVT::f16) ||
53336         (VT.isVector() && TLI.isTypeLegal(VT))))
53337     return SDValue();
53338 
53339   SDValue Op0 = N->getOperand(0);
53340   SDValue Op1 = N->getOperand(1);
53341   SDLoc DL(N);
53342   auto MinMaxOp = N->getOpcode() == ISD::FMAXNUM ? X86ISD::FMAX : X86ISD::FMIN;
53343 
53344   // If we don't have to respect NaN inputs, this is a direct translation to x86
53345   // min/max instructions.
53346   if (DAG.getTarget().Options.NoNaNsFPMath || N->getFlags().hasNoNaNs())
53347     return DAG.getNode(MinMaxOp, DL, VT, Op0, Op1, N->getFlags());
53348 
53349   // If one of the operands is known non-NaN use the native min/max instructions
53350   // with the non-NaN input as second operand.
53351   if (DAG.isKnownNeverNaN(Op1))
53352     return DAG.getNode(MinMaxOp, DL, VT, Op0, Op1, N->getFlags());
53353   if (DAG.isKnownNeverNaN(Op0))
53354     return DAG.getNode(MinMaxOp, DL, VT, Op1, Op0, N->getFlags());
53355 
53356   // If we have to respect NaN inputs, this takes at least 3 instructions.
53357   // Favor a library call when operating on a scalar and minimizing code size.
53358   if (!VT.isVector() && DAG.getMachineFunction().getFunction().hasMinSize())
53359     return SDValue();
53360 
53361   EVT SetCCType = TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(),
53362                                          VT);
53363 
53364   // There are 4 possibilities involving NaN inputs, and these are the required
53365   // outputs:
53366   //                   Op1
53367   //               Num     NaN
53368   //            ----------------
53369   //       Num  |  Max  |  Op0 |
53370   // Op0        ----------------
53371   //       NaN  |  Op1  |  NaN |
53372   //            ----------------
53373   //
53374   // The SSE FP max/min instructions were not designed for this case, but rather
53375   // to implement:
53376   //   Min = Op1 < Op0 ? Op1 : Op0
53377   //   Max = Op1 > Op0 ? Op1 : Op0
53378   //
53379   // So they always return Op0 if either input is a NaN. However, we can still
53380   // use those instructions for fmaxnum by selecting away a NaN input.
53381 
53382   // If either operand is NaN, the 2nd source operand (Op0) is passed through.
53383   SDValue MinOrMax = DAG.getNode(MinMaxOp, DL, VT, Op1, Op0);
53384   SDValue IsOp0Nan = DAG.getSetCC(DL, SetCCType, Op0, Op0, ISD::SETUO);
53385 
53386   // If Op0 is a NaN, select Op1. Otherwise, select the max. If both operands
53387   // are NaN, the NaN value of Op1 is the result.
53388   return DAG.getSelect(DL, VT, IsOp0Nan, Op1, MinOrMax);
53389 }
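// Editor's note: illustrative sketch of the NaN-aware sequence built above for
// fmaxnum(X, Y):
//   Max    = X86ISD::FMAX Y, X        // passes X through if either input is NaN
//   IsNaN  = setcc X, X, setuo        // true iff X is NaN
//   Result = select IsNaN, Y, Max
// so a NaN X yields Y, matching the table above.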
53390 
53391 static SDValue combineX86INT_TO_FP(SDNode *N, SelectionDAG &DAG,
53392                                    TargetLowering::DAGCombinerInfo &DCI) {
53393   EVT VT = N->getValueType(0);
53394   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
53395 
53396   APInt DemandedElts = APInt::getAllOnes(VT.getVectorNumElements());
53397   if (TLI.SimplifyDemandedVectorElts(SDValue(N, 0), DemandedElts, DCI))
53398     return SDValue(N, 0);
53399 
53400   // Convert a full vector load into vzload when not all bits are needed.
53401   SDValue In = N->getOperand(0);
53402   MVT InVT = In.getSimpleValueType();
53403   if (VT.getVectorNumElements() < InVT.getVectorNumElements() &&
53404       ISD::isNormalLoad(In.getNode()) && In.hasOneUse()) {
53405     assert(InVT.is128BitVector() && "Expected 128-bit input vector");
53406     LoadSDNode *LN = cast<LoadSDNode>(N->getOperand(0));
53407     unsigned NumBits = InVT.getScalarSizeInBits() * VT.getVectorNumElements();
53408     MVT MemVT = MVT::getIntegerVT(NumBits);
53409     MVT LoadVT = MVT::getVectorVT(MemVT, 128 / NumBits);
53410     if (SDValue VZLoad = narrowLoadToVZLoad(LN, MemVT, LoadVT, DAG)) {
53411       SDLoc dl(N);
53412       SDValue Convert = DAG.getNode(N->getOpcode(), dl, VT,
53413                                     DAG.getBitcast(InVT, VZLoad));
53414       DCI.CombineTo(N, Convert);
53415       DAG.ReplaceAllUsesOfValueWith(SDValue(LN, 1), VZLoad.getValue(1));
53416       DCI.recursivelyDeleteUnusedNodes(LN);
53417       return SDValue(N, 0);
53418     }
53419   }
53420 
53421   return SDValue();
53422 }
53423 
53424 static SDValue combineCVTP2I_CVTTP2I(SDNode *N, SelectionDAG &DAG,
53425                                      TargetLowering::DAGCombinerInfo &DCI) {
53426   bool IsStrict = N->isTargetStrictFPOpcode();
53427   EVT VT = N->getValueType(0);
53428 
53429   // Convert a full vector load into vzload when not all bits are needed.
53430   SDValue In = N->getOperand(IsStrict ? 1 : 0);
53431   MVT InVT = In.getSimpleValueType();
53432   if (VT.getVectorNumElements() < InVT.getVectorNumElements() &&
53433       ISD::isNormalLoad(In.getNode()) && In.hasOneUse()) {
53434     assert(InVT.is128BitVector() && "Expected 128-bit input vector");
53435     LoadSDNode *LN = cast<LoadSDNode>(In);
53436     unsigned NumBits = InVT.getScalarSizeInBits() * VT.getVectorNumElements();
53437     MVT MemVT = MVT::getFloatingPointVT(NumBits);
53438     MVT LoadVT = MVT::getVectorVT(MemVT, 128 / NumBits);
53439     if (SDValue VZLoad = narrowLoadToVZLoad(LN, MemVT, LoadVT, DAG)) {
53440       SDLoc dl(N);
53441       if (IsStrict) {
53442         SDValue Convert =
53443             DAG.getNode(N->getOpcode(), dl, {VT, MVT::Other},
53444                         {N->getOperand(0), DAG.getBitcast(InVT, VZLoad)});
53445         DCI.CombineTo(N, Convert, Convert.getValue(1));
53446       } else {
53447         SDValue Convert =
53448             DAG.getNode(N->getOpcode(), dl, VT, DAG.getBitcast(InVT, VZLoad));
53449         DCI.CombineTo(N, Convert);
53450       }
53451       DAG.ReplaceAllUsesOfValueWith(SDValue(LN, 1), VZLoad.getValue(1));
53452       DCI.recursivelyDeleteUnusedNodes(LN);
53453       return SDValue(N, 0);
53454     }
53455   }
53456 
53457   return SDValue();
53458 }
53459 
53460 /// Do target-specific dag combines on X86ISD::ANDNP nodes.
53461 static SDValue combineAndnp(SDNode *N, SelectionDAG &DAG,
53462                             TargetLowering::DAGCombinerInfo &DCI,
53463                             const X86Subtarget &Subtarget) {
53464   SDValue N0 = N->getOperand(0);
53465   SDValue N1 = N->getOperand(1);
53466   MVT VT = N->getSimpleValueType(0);
53467   int NumElts = VT.getVectorNumElements();
53468   unsigned EltSizeInBits = VT.getScalarSizeInBits();
53469   SDLoc DL(N);
53470 
53471   // ANDNP(undef, x) -> 0
53472   // ANDNP(x, undef) -> 0
53473   if (N0.isUndef() || N1.isUndef())
53474     return DAG.getConstant(0, DL, VT);
53475 
53476   // ANDNP(0, x) -> x
53477   if (ISD::isBuildVectorAllZeros(N0.getNode()))
53478     return N1;
53479 
53480   // ANDNP(x, 0) -> 0
53481   if (ISD::isBuildVectorAllZeros(N1.getNode()))
53482     return DAG.getConstant(0, DL, VT);
53483 
53484   // ANDNP(x, -1) -> NOT(x) -> XOR(x, -1)
53485   if (ISD::isBuildVectorAllOnes(N1.getNode()))
53486     return DAG.getNOT(DL, N0, VT);
53487 
53488   // Turn ANDNP back to AND if input is inverted.
53489   if (SDValue Not = IsNOT(N0, DAG))
53490     return DAG.getNode(ISD::AND, DL, VT, DAG.getBitcast(VT, Not), N1);
53491 
53492   // Fold for better commutativity:
53493   // ANDNP(x,NOT(y)) -> AND(NOT(x),NOT(y)) -> NOT(OR(X,Y)).
53494   if (N1->hasOneUse())
53495     if (SDValue Not = IsNOT(N1, DAG))
53496       return DAG.getNOT(
53497           DL, DAG.getNode(ISD::OR, DL, VT, N0, DAG.getBitcast(VT, Not)), VT);
53498 
53499   // Constant Folding
53500   APInt Undefs0, Undefs1;
53501   SmallVector<APInt> EltBits0, EltBits1;
53502   if (getTargetConstantBitsFromNode(N0, EltSizeInBits, Undefs0, EltBits0,
53503                                     /*AllowWholeUndefs*/ true,
53504                                     /*AllowPartialUndefs*/ true)) {
53505     if (getTargetConstantBitsFromNode(N1, EltSizeInBits, Undefs1, EltBits1,
53506                                       /*AllowWholeUndefs*/ true,
53507                                       /*AllowPartialUndefs*/ true)) {
53508       SmallVector<APInt> ResultBits;
53509       for (int I = 0; I != NumElts; ++I)
53510         ResultBits.push_back(~EltBits0[I] & EltBits1[I]);
53511       return getConstVector(ResultBits, VT, DAG, DL);
53512     }
53513 
53514     // Constant fold NOT(N0) to allow us to use AND.
53515     // Ensure this is only performed if we can confirm that the bitcasted source
53516     // has oneuse to prevent an infinite loop with canonicalizeBitSelect.
53517     if (N0->hasOneUse()) {
53518       SDValue BC0 = peekThroughOneUseBitcasts(N0);
53519       if (BC0.getOpcode() != ISD::BITCAST) {
53520         for (APInt &Elt : EltBits0)
53521           Elt = ~Elt;
53522         SDValue Not = getConstVector(EltBits0, VT, DAG, DL);
53523         return DAG.getNode(ISD::AND, DL, VT, Not, N1);
53524       }
53525     }
53526   }
53527 
53528   // Attempt to recursively combine a bitmask ANDNP with shuffles.
53529   if (VT.isVector() && (VT.getScalarSizeInBits() % 8) == 0) {
53530     SDValue Op(N, 0);
53531     if (SDValue Res = combineX86ShufflesRecursively(Op, DAG, Subtarget))
53532       return Res;
53533 
53534     // If either operand is a constant mask, then only the elements that aren't
53535     // zero are actually demanded by the other operand.
53536     auto GetDemandedMasks = [&](SDValue Op, bool Invert = false) {
53537       APInt UndefElts;
53538       SmallVector<APInt> EltBits;
53539       APInt DemandedBits = APInt::getAllOnes(EltSizeInBits);
53540       APInt DemandedElts = APInt::getAllOnes(NumElts);
53541       if (getTargetConstantBitsFromNode(Op, EltSizeInBits, UndefElts,
53542                                         EltBits)) {
53543         DemandedBits.clearAllBits();
53544         DemandedElts.clearAllBits();
53545         for (int I = 0; I != NumElts; ++I) {
53546           if (UndefElts[I]) {
53547             // We can't assume an undef src element gives an undef dst - the
53548             // other src might be zero.
53549             DemandedBits.setAllBits();
53550             DemandedElts.setBit(I);
53551           } else if ((Invert && !EltBits[I].isAllOnes()) ||
53552                      (!Invert && !EltBits[I].isZero())) {
53553             DemandedBits |= Invert ? ~EltBits[I] : EltBits[I];
53554             DemandedElts.setBit(I);
53555           }
53556         }
53557       }
53558       return std::make_pair(DemandedBits, DemandedElts);
53559     };
53560     APInt Bits0, Elts0;
53561     APInt Bits1, Elts1;
53562     std::tie(Bits0, Elts0) = GetDemandedMasks(N1);
53563     std::tie(Bits1, Elts1) = GetDemandedMasks(N0, true);
53564 
53565     const TargetLowering &TLI = DAG.getTargetLoweringInfo();
53566     if (TLI.SimplifyDemandedVectorElts(N0, Elts0, DCI) ||
53567         TLI.SimplifyDemandedVectorElts(N1, Elts1, DCI) ||
53568         TLI.SimplifyDemandedBits(N0, Bits0, Elts0, DCI) ||
53569         TLI.SimplifyDemandedBits(N1, Bits1, Elts1, DCI)) {
53570       if (N->getOpcode() != ISD::DELETED_NODE)
53571         DCI.AddToWorklist(N);
53572       return SDValue(N, 0);
53573     }
53574   }
53575 
53576   return SDValue();
53577 }
53578 
53579 static SDValue combineBT(SDNode *N, SelectionDAG &DAG,
53580                          TargetLowering::DAGCombinerInfo &DCI) {
53581   SDValue N1 = N->getOperand(1);
53582 
53583   // BT ignores high bits in the bit index operand.
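        // e.g. for an i32 bit-index operand, only the low 5 bits are demanded.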
53584   unsigned BitWidth = N1.getValueSizeInBits();
53585   APInt DemandedMask = APInt::getLowBitsSet(BitWidth, Log2_32(BitWidth));
53586   if (DAG.getTargetLoweringInfo().SimplifyDemandedBits(N1, DemandedMask, DCI)) {
53587     if (N->getOpcode() != ISD::DELETED_NODE)
53588       DCI.AddToWorklist(N);
53589     return SDValue(N, 0);
53590   }
53591 
53592   return SDValue();
53593 }
53594 
53595 static SDValue combineCVTPH2PS(SDNode *N, SelectionDAG &DAG,
53596                                TargetLowering::DAGCombinerInfo &DCI) {
53597   bool IsStrict = N->getOpcode() == X86ISD::STRICT_CVTPH2PS;
53598   SDValue Src = N->getOperand(IsStrict ? 1 : 0);
53599 
53600   if (N->getValueType(0) == MVT::v4f32 && Src.getValueType() == MVT::v8i16) {
53601     const TargetLowering &TLI = DAG.getTargetLoweringInfo();
53602     APInt DemandedElts = APInt::getLowBitsSet(8, 4);
53603     if (TLI.SimplifyDemandedVectorElts(Src, DemandedElts, DCI)) {
53604       if (N->getOpcode() != ISD::DELETED_NODE)
53605         DCI.AddToWorklist(N);
53606       return SDValue(N, 0);
53607     }
53608 
53609     // Convert a full vector load into vzload when not all bits are needed.
53610     if (ISD::isNormalLoad(Src.getNode()) && Src.hasOneUse()) {
53611       LoadSDNode *LN = cast<LoadSDNode>(N->getOperand(IsStrict ? 1 : 0));
53612       if (SDValue VZLoad = narrowLoadToVZLoad(LN, MVT::i64, MVT::v2i64, DAG)) {
53613         SDLoc dl(N);
53614         if (IsStrict) {
53615           SDValue Convert = DAG.getNode(
53616               N->getOpcode(), dl, {MVT::v4f32, MVT::Other},
53617               {N->getOperand(0), DAG.getBitcast(MVT::v8i16, VZLoad)});
53618           DCI.CombineTo(N, Convert, Convert.getValue(1));
53619         } else {
53620           SDValue Convert = DAG.getNode(N->getOpcode(), dl, MVT::v4f32,
53621                                         DAG.getBitcast(MVT::v8i16, VZLoad));
53622           DCI.CombineTo(N, Convert);
53623         }
53624 
53625         DAG.ReplaceAllUsesOfValueWith(SDValue(LN, 1), VZLoad.getValue(1));
53626         DCI.recursivelyDeleteUnusedNodes(LN);
53627         return SDValue(N, 0);
53628       }
53629     }
53630   }
53631 
53632   return SDValue();
53633 }
53634 
53635 // Try to combine sext_in_reg of a cmov of constants by extending the constants.
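      // e.g. sext_in_reg(cmov(C0, C1), i8) --> cmov(sext_in_reg(C0, i8), sext_in_reg(C1, i8))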
53636 static SDValue combineSextInRegCmov(SDNode *N, SelectionDAG &DAG) {
53637   assert(N->getOpcode() == ISD::SIGN_EXTEND_INREG);
53638 
53639   EVT DstVT = N->getValueType(0);
53640 
53641   SDValue N0 = N->getOperand(0);
53642   SDValue N1 = N->getOperand(1);
53643   EVT ExtraVT = cast<VTSDNode>(N1)->getVT();
53644 
53645   if (ExtraVT != MVT::i8 && ExtraVT != MVT::i16)
53646     return SDValue();
53647 
53648   // Look through single use any_extends / truncs.
53649   SDValue IntermediateBitwidthOp;
53650   if ((N0.getOpcode() == ISD::ANY_EXTEND || N0.getOpcode() == ISD::TRUNCATE) &&
53651       N0.hasOneUse()) {
53652     IntermediateBitwidthOp = N0;
53653     N0 = N0.getOperand(0);
53654   }
53655 
53656   // See if we have a single use cmov.
53657   if (N0.getOpcode() != X86ISD::CMOV || !N0.hasOneUse())
53658     return SDValue();
53659 
53660   SDValue CMovOp0 = N0.getOperand(0);
53661   SDValue CMovOp1 = N0.getOperand(1);
53662 
53663   // Make sure both operands are constants.
53664   if (!isa<ConstantSDNode>(CMovOp0.getNode()) ||
53665       !isa<ConstantSDNode>(CMovOp1.getNode()))
53666     return SDValue();
53667 
53668   SDLoc DL(N);
53669 
53670   // If we looked through an any_extend/trunc above, apply that same op to the constants.
53671   if (IntermediateBitwidthOp) {
53672     unsigned IntermediateOpc = IntermediateBitwidthOp.getOpcode();
53673     CMovOp0 = DAG.getNode(IntermediateOpc, DL, DstVT, CMovOp0);
53674     CMovOp1 = DAG.getNode(IntermediateOpc, DL, DstVT, CMovOp1);
53675   }
53676 
53677   CMovOp0 = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, DstVT, CMovOp0, N1);
53678   CMovOp1 = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, DstVT, CMovOp1, N1);
53679 
53680   EVT CMovVT = DstVT;
53681   // We do not want i16 CMOV's. Promote to i32 and truncate afterwards.
53682   if (DstVT == MVT::i16) {
53683     CMovVT = MVT::i32;
53684     CMovOp0 = DAG.getNode(ISD::ZERO_EXTEND, DL, CMovVT, CMovOp0);
53685     CMovOp1 = DAG.getNode(ISD::ZERO_EXTEND, DL, CMovVT, CMovOp1);
53686   }
53687 
53688   SDValue CMov = DAG.getNode(X86ISD::CMOV, DL, CMovVT, CMovOp0, CMovOp1,
53689                              N0.getOperand(2), N0.getOperand(3));
53690 
53691   if (CMovVT != DstVT)
53692     CMov = DAG.getNode(ISD::TRUNCATE, DL, DstVT, CMov);
53693 
53694   return CMov;
53695 }
53696 
53697 static SDValue combineSignExtendInReg(SDNode *N, SelectionDAG &DAG,
53698                                       const X86Subtarget &Subtarget) {
53699   assert(N->getOpcode() == ISD::SIGN_EXTEND_INREG);
53700 
53701   if (SDValue V = combineSextInRegCmov(N, DAG))
53702     return V;
53703 
53704   EVT VT = N->getValueType(0);
53705   SDValue N0 = N->getOperand(0);
53706   SDValue N1 = N->getOperand(1);
53707   EVT ExtraVT = cast<VTSDNode>(N1)->getVT();
53708   SDLoc dl(N);
53709 
53710   // The SIGN_EXTEND_INREG to v4i64 is an expensive operation on both
53711   // SSE and AVX2 since there is no sign-extended shift right
53712   // operation on a vector with 64-bit elements.
53713   // (sext_in_reg (v4i64 anyext (v4i32 x)), ExtraVT) ->
53714   // (v4i64 sext (v4i32 sext_in_reg (v4i32 x, ExtraVT)))
53715   if (VT == MVT::v4i64 && (N0.getOpcode() == ISD::ANY_EXTEND ||
53716                            N0.getOpcode() == ISD::SIGN_EXTEND)) {
53717     SDValue N00 = N0.getOperand(0);
53718 
53719     // EXTLOAD has a better solution on AVX2,
53720     // it may be replaced with X86ISD::VSEXT node.
53721     if (N00.getOpcode() == ISD::LOAD && Subtarget.hasInt256())
53722       if (!ISD::isNormalLoad(N00.getNode()))
53723         return SDValue();
53724 
53725     // Attempt to promote any comparison mask ops before the SIGN_EXTEND_INREG
53726     // gets in the way.
53727     if (SDValue Promote = PromoteMaskArithmetic(N0, dl, DAG, Subtarget))
53728       return DAG.getNode(ISD::SIGN_EXTEND_INREG, dl, VT, Promote, N1);
53729 
53730     if (N00.getValueType() == MVT::v4i32 && ExtraVT.getSizeInBits() < 128) {
53731       SDValue Tmp =
53732           DAG.getNode(ISD::SIGN_EXTEND_INREG, dl, MVT::v4i32, N00, N1);
53733       return DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v4i64, Tmp);
53734     }
53735   }
53736   return SDValue();
53737 }
53738 
53739 /// sext(add_nsw(x, C)) --> add(sext(x), C_sext)
53740 /// zext(add_nuw(x, C)) --> add(zext(x), C_zext)
53741 /// Promoting a sign/zero extension ahead of a no overflow 'add' exposes
53742 /// opportunities to combine math ops, use an LEA, or use a complex addressing
53743 /// mode. This can eliminate extend, add, and shift instructions.
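      /// For example, (i64 zext (add nuw i32 %x, 40)) --> (add nuw i64 (zext i32 %x), 40).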
53744 static SDValue promoteExtBeforeAdd(SDNode *Ext, SelectionDAG &DAG,
53745                                    const X86Subtarget &Subtarget) {
53746   if (Ext->getOpcode() != ISD::SIGN_EXTEND &&
53747       Ext->getOpcode() != ISD::ZERO_EXTEND)
53748     return SDValue();
53749 
53750   // TODO: This should be valid for other integer types.
53751   EVT VT = Ext->getValueType(0);
53752   if (VT != MVT::i64)
53753     return SDValue();
53754 
53755   SDValue Add = Ext->getOperand(0);
53756   if (Add.getOpcode() != ISD::ADD)
53757     return SDValue();
53758 
53759   SDValue AddOp0 = Add.getOperand(0);
53760   SDValue AddOp1 = Add.getOperand(1);
53761   bool Sext = Ext->getOpcode() == ISD::SIGN_EXTEND;
53762   bool NSW = Add->getFlags().hasNoSignedWrap();
53763   bool NUW = Add->getFlags().hasNoUnsignedWrap();
53764   NSW = NSW || (Sext && DAG.willNotOverflowAdd(true, AddOp0, AddOp1));
53765   NUW = NUW || (!Sext && DAG.willNotOverflowAdd(false, AddOp0, AddOp1));
53766 
53767   // We need an 'add nsw' feeding into the 'sext' or 'add nuw' feeding
53768   // into the 'zext'
53769   if ((Sext && !NSW) || (!Sext && !NUW))
53770     return SDValue();
53771 
53772   // Having a constant operand to the 'add' ensures that we are not increasing
53773   // the instruction count because the constant is extended for free below.
53774   // A constant operand can also become the displacement field of an LEA.
53775   auto *AddOp1C = dyn_cast<ConstantSDNode>(AddOp1);
53776   if (!AddOp1C)
53777     return SDValue();
53778 
53779   // Don't make the 'add' bigger if there's no hope of combining it with some
53780   // other 'add' or 'shl' instruction.
53781   // TODO: It may be profitable to generate simpler LEA instructions in place
53782   // of single 'add' instructions, but the cost model for selecting an LEA
53783   // currently has a high threshold.
53784   bool HasLEAPotential = false;
53785   for (auto *User : Ext->uses()) {
53786     if (User->getOpcode() == ISD::ADD || User->getOpcode() == ISD::SHL) {
53787       HasLEAPotential = true;
53788       break;
53789     }
53790   }
53791   if (!HasLEAPotential)
53792     return SDValue();
53793 
53794   // Everything looks good, so pull the '{s|z}ext' ahead of the 'add'.
53795   int64_t AddC = Sext ? AddOp1C->getSExtValue() : AddOp1C->getZExtValue();
53796   SDValue NewExt = DAG.getNode(Ext->getOpcode(), SDLoc(Ext), VT, AddOp0);
53797   SDValue NewConstant = DAG.getConstant(AddC, SDLoc(Add), VT);
53798 
53799   // The wider add is guaranteed to not wrap because both operands are
53800   // sign-extended.
53801   SDNodeFlags Flags;
53802   Flags.setNoSignedWrap(NSW);
53803   Flags.setNoUnsignedWrap(NUW);
53804   return DAG.getNode(ISD::ADD, SDLoc(Add), VT, NewExt, NewConstant, Flags);
53805 }
53806 
53807 // If we face {ANY,SIGN,ZERO}_EXTEND that is applied to a CMOV with constant
53808 // operands and the result of CMOV is not used anywhere else - promote CMOV
53809 // itself instead of promoting its result. This could be beneficial, because:
53810 //     1) X86TargetLowering::EmitLoweredSelect later can do merging of two
53811 //        (or more) pseudo-CMOVs only when they go one-after-another and
53812 //        getting rid of result extension code after CMOV will help that.
53813 //     2) Promotion of constant CMOV arguments is free, hence the
53814 //        {ANY,SIGN,ZERO}_EXTEND will just be deleted.
53815 //     3) 16-bit CMOV encoding is 4 bytes, 32-bit CMOV is 3 bytes, so this
53816 //        promotion is also good in terms of code-size.
53817 //        (64-bit CMOV is 4 bytes, that's why we don't do 32-bit => 64-bit
53818 //         promotion).
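      // e.g. (i32 zext (i16 cmov C0, C1)) --> (i32 cmov (i32 zext C0), (i32 zext C1))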
53819 static SDValue combineToExtendCMOV(SDNode *Extend, SelectionDAG &DAG) {
53820   SDValue CMovN = Extend->getOperand(0);
53821   if (CMovN.getOpcode() != X86ISD::CMOV || !CMovN.hasOneUse())
53822     return SDValue();
53823 
53824   EVT TargetVT = Extend->getValueType(0);
53825   unsigned ExtendOpcode = Extend->getOpcode();
53826   SDLoc DL(Extend);
53827 
53828   EVT VT = CMovN.getValueType();
53829   SDValue CMovOp0 = CMovN.getOperand(0);
53830   SDValue CMovOp1 = CMovN.getOperand(1);
53831 
53832   if (!isa<ConstantSDNode>(CMovOp0.getNode()) ||
53833       !isa<ConstantSDNode>(CMovOp1.getNode()))
53834     return SDValue();
53835 
53836   // Only extend to i32 or i64.
53837   if (TargetVT != MVT::i32 && TargetVT != MVT::i64)
53838     return SDValue();
53839 
53840   // Only extend from i16 unless its a sign_extend from i32. Zext/aext from i32
53841   // Only extend from i16 unless it's a sign_extend from i32. Zext/aext from i32
53842   if (VT != MVT::i16 && !(ExtendOpcode == ISD::SIGN_EXTEND && VT == MVT::i32))
53843     return SDValue();
53844 
53845   // If this is a zero extend to i64, we should only extend to i32 and use a free
53846   // zero extend to finish.
53847   EVT ExtendVT = TargetVT;
53848   if (TargetVT == MVT::i64 && ExtendOpcode != ISD::SIGN_EXTEND)
53849     ExtendVT = MVT::i32;
53850 
53851   CMovOp0 = DAG.getNode(ExtendOpcode, DL, ExtendVT, CMovOp0);
53852   CMovOp1 = DAG.getNode(ExtendOpcode, DL, ExtendVT, CMovOp1);
53853 
53854   SDValue Res = DAG.getNode(X86ISD::CMOV, DL, ExtendVT, CMovOp0, CMovOp1,
53855                             CMovN.getOperand(2), CMovN.getOperand(3));
53856 
53857   // Finish extending if needed.
53858   if (ExtendVT != TargetVT)
53859     Res = DAG.getNode(ExtendOpcode, DL, TargetVT, Res);
53860 
53861   return Res;
53862 }
53863 
53864 // Attempt to combine a (sext/zext (setcc)) to a setcc with a xmm/ymm/zmm
53865 // result type.
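      // e.g. with AVX512: (v4i32 sext (v4i1 setcc v4i32 a, b, cc)) --> (v4i32 setcc a, b, cc)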
53866 static SDValue combineExtSetcc(SDNode *N, SelectionDAG &DAG,
53867                                const X86Subtarget &Subtarget) {
53868   SDValue N0 = N->getOperand(0);
53869   EVT VT = N->getValueType(0);
53870   SDLoc dl(N);
53871 
53872   // Only do this combine with AVX512 for vector extends.
53873   if (!Subtarget.hasAVX512() || !VT.isVector() || N0.getOpcode() != ISD::SETCC)
53874     return SDValue();
53875 
53876   // Only combine legal element types.
53877   EVT SVT = VT.getVectorElementType();
53878   if (SVT != MVT::i8 && SVT != MVT::i16 && SVT != MVT::i32 &&
53879       SVT != MVT::i64 && SVT != MVT::f32 && SVT != MVT::f64)
53880     return SDValue();
53881 
53882   // We don't have a CMPP instruction for vXf16.
53883   if (N0.getOperand(0).getValueType().getVectorElementType() == MVT::f16)
53884     return SDValue();
53885   // We can only do this if the vector size is 256 bits or less.
53886   unsigned Size = VT.getSizeInBits();
53887   if (Size > 256 && Subtarget.useAVX512Regs())
53888     return SDValue();
53889 
53890   // Don't fold if the condition code can't be handled by PCMPEQ/PCMPGT since
53891   // those are the only integer compares we have.
53892   ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
53893   if (ISD::isUnsignedIntSetCC(CC))
53894     return SDValue();
53895 
53896   // Only do this combine if the extension will be fully consumed by the setcc.
53897   EVT N00VT = N0.getOperand(0).getValueType();
53898   EVT MatchingVecType = N00VT.changeVectorElementTypeToInteger();
53899   if (Size != MatchingVecType.getSizeInBits())
53900     return SDValue();
53901 
53902   SDValue Res = DAG.getSetCC(dl, VT, N0.getOperand(0), N0.getOperand(1), CC);
53903 
53904   if (N->getOpcode() == ISD::ZERO_EXTEND)
53905     Res = DAG.getZeroExtendInReg(Res, dl, N0.getValueType());
53906 
53907   return Res;
53908 }
53909 
53910 static SDValue combineSext(SDNode *N, SelectionDAG &DAG,
53911                            TargetLowering::DAGCombinerInfo &DCI,
53912                            const X86Subtarget &Subtarget) {
53913   SDValue N0 = N->getOperand(0);
53914   EVT VT = N->getValueType(0);
53915   SDLoc DL(N);
53916 
53917   // (i32 (sext (i8 (x86isd::setcc_carry)))) -> (i32 (x86isd::setcc_carry))
53918   if (!DCI.isBeforeLegalizeOps() &&
53919       N0.getOpcode() == X86ISD::SETCC_CARRY) {
53920     SDValue Setcc = DAG.getNode(X86ISD::SETCC_CARRY, DL, VT, N0->getOperand(0),
53921                                  N0->getOperand(1));
53922     bool ReplaceOtherUses = !N0.hasOneUse();
53923     DCI.CombineTo(N, Setcc);
53924     // Replace other uses with a truncate of the widened setcc_carry.
53925     if (ReplaceOtherUses) {
53926       SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(N0),
53927                                   N0.getValueType(), Setcc);
53928       DCI.CombineTo(N0.getNode(), Trunc);
53929     }
53930 
53931     return SDValue(N, 0);
53932   }
53933 
53934   if (SDValue NewCMov = combineToExtendCMOV(N, DAG))
53935     return NewCMov;
53936 
53937   if (!DCI.isBeforeLegalizeOps())
53938     return SDValue();
53939 
53940   if (SDValue V = combineExtSetcc(N, DAG, Subtarget))
53941     return V;
53942 
53943   if (SDValue V = combineToExtendBoolVectorInReg(N->getOpcode(), DL, VT, N0,
53944                                                  DAG, DCI, Subtarget))
53945     return V;
53946 
53947   if (VT.isVector()) {
53948     if (SDValue R = PromoteMaskArithmetic(SDValue(N, 0), DL, DAG, Subtarget))
53949       return R;
53950 
53951     if (N0.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG)
53952       return DAG.getNode(N0.getOpcode(), DL, VT, N0.getOperand(0));
53953   }
53954 
53955   if (SDValue NewAdd = promoteExtBeforeAdd(N, DAG, Subtarget))
53956     return NewAdd;
53957 
53958   return SDValue();
53959 }
53960 
53961 // Inverting a constant vector is profitable if it can be eliminated and the
53962 // inverted vector is already present in DAG. Otherwise, it will be loaded
53963 // anyway.
53964 //
53965 // We determine which of the values can be completely eliminated and invert it.
53966 // If both are eliminable, select a vector with the first negative element.
53967 static SDValue getInvertedVectorForFMA(SDValue V, SelectionDAG &DAG) {
53968   assert(ISD::isBuildVectorOfConstantFPSDNodes(V.getNode()) &&
53969          "ConstantFP build vector expected");
53970   // Check if we can eliminate V. We assume that if a value is only used in
53971   // FMAs, we can eliminate it, since this function is invoked for each FMA
53972   // that uses this vector.
53973   auto IsNotFMA = [](SDNode *Use) {
53974     return Use->getOpcode() != ISD::FMA && Use->getOpcode() != ISD::STRICT_FMA;
53975   };
53976   if (llvm::any_of(V->uses(), IsNotFMA))
53977     return SDValue();
53978 
53979   SmallVector<SDValue, 8> Ops;
53980   EVT VT = V.getValueType();
53981   EVT EltVT = VT.getVectorElementType();
53982   for (const SDValue &Op : V->op_values()) {
53983     if (auto *Cst = dyn_cast<ConstantFPSDNode>(Op)) {
53984       Ops.push_back(DAG.getConstantFP(-Cst->getValueAPF(), SDLoc(Op), EltVT));
53985     } else {
53986       assert(Op.isUndef());
53987       Ops.push_back(DAG.getUNDEF(EltVT));
53988     }
53989   }
53990 
53991   SDNode *NV = DAG.getNodeIfExists(ISD::BUILD_VECTOR, DAG.getVTList(VT), Ops);
53992   if (!NV)
53993     return SDValue();
53994 
53995   // If an inverted version cannot be eliminated, choose it instead of the
53996   // original version.
53997   if (llvm::any_of(NV->uses(), IsNotFMA))
53998     return SDValue(NV, 0);
53999 
54000   // If the inverted version can also be eliminated, we have to consistently
54001   // prefer one of the values. We keep the constant whose first (non-undef)
54002   // element is negative.
54003   // N.B. We need to skip undefs that may precede that element.
54004   for (const SDValue &Op : V->op_values()) {
54005     if (auto *Cst = dyn_cast<ConstantFPSDNode>(Op)) {
54006       if (Cst->isNegative())
54007         return SDValue();
54008       break;
54009     }
54010   }
54011   return SDValue(NV, 0);
54012 }
54013 
54014 static SDValue combineFMA(SDNode *N, SelectionDAG &DAG,
54015                           TargetLowering::DAGCombinerInfo &DCI,
54016                           const X86Subtarget &Subtarget) {
54017   SDLoc dl(N);
54018   EVT VT = N->getValueType(0);
54019   bool IsStrict = N->isStrictFPOpcode() || N->isTargetStrictFPOpcode();
54020 
54021   // Let legalize expand this if it isn't a legal type yet.
54022   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
54023   if (!TLI.isTypeLegal(VT))
54024     return SDValue();
54025 
54026   SDValue A = N->getOperand(IsStrict ? 1 : 0);
54027   SDValue B = N->getOperand(IsStrict ? 2 : 1);
54028   SDValue C = N->getOperand(IsStrict ? 3 : 2);
54029 
54030   // If the operation allows fast-math and the target does not support FMA,
54031   // split this into mul+add to avoid libcall(s).
54032   SDNodeFlags Flags = N->getFlags();
54033   if (!IsStrict && Flags.hasAllowReassociation() &&
54034       TLI.isOperationExpand(ISD::FMA, VT)) {
54035     SDValue Fmul = DAG.getNode(ISD::FMUL, dl, VT, A, B, Flags);
54036     return DAG.getNode(ISD::FADD, dl, VT, Fmul, C, Flags);
54037   }
54038 
54039   EVT ScalarVT = VT.getScalarType();
54040   if (((ScalarVT != MVT::f32 && ScalarVT != MVT::f64) ||
54041        !Subtarget.hasAnyFMA()) &&
54042       !(ScalarVT == MVT::f16 && Subtarget.hasFP16()))
54043     return SDValue();
54044 
54045   auto invertIfNegative = [&DAG, &TLI, &DCI](SDValue &V) {
54046     bool CodeSize = DAG.getMachineFunction().getFunction().hasOptSize();
54047     bool LegalOperations = !DCI.isBeforeLegalizeOps();
54048     if (SDValue NegV = TLI.getCheaperNegatedExpression(V, DAG, LegalOperations,
54049                                                        CodeSize)) {
54050       V = NegV;
54051       return true;
54052     }
54053     // Look through extract_vector_elts. If it comes from an FNEG, create a
54054     // new extract from the FNEG input.
54055     if (V.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
54056         isNullConstant(V.getOperand(1))) {
54057       SDValue Vec = V.getOperand(0);
54058       if (SDValue NegV = TLI.getCheaperNegatedExpression(
54059               Vec, DAG, LegalOperations, CodeSize)) {
54060         V = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(V), V.getValueType(),
54061                         NegV, V.getOperand(1));
54062         return true;
54063       }
54064     }
54065     // Lookup if there is an inverted version of constant vector V in DAG.
54066     if (ISD::isBuildVectorOfConstantFPSDNodes(V.getNode())) {
54067       if (SDValue NegV = getInvertedVectorForFMA(V, DAG)) {
54068         V = NegV;
54069         return true;
54070       }
54071     }
54072     return false;
54073   };
54074 
54075   // Do not convert the passthru input of scalar intrinsics.
54076   // FIXME: We could allow negations of the lower element only.
54077   bool NegA = invertIfNegative(A);
54078   bool NegB = invertIfNegative(B);
54079   bool NegC = invertIfNegative(C);
54080 
54081   if (!NegA && !NegB && !NegC)
54082     return SDValue();
54083 
54084   unsigned NewOpcode =
54085       negateFMAOpcode(N->getOpcode(), NegA != NegB, NegC, false);
54086 
54087   // Propagate fast-math-flags to new FMA node.
54088   SelectionDAG::FlagInserter FlagsInserter(DAG, Flags);
54089   if (IsStrict) {
54090     assert(N->getNumOperands() == 4 && "Shouldn't be greater than 4");
54091     return DAG.getNode(NewOpcode, dl, {VT, MVT::Other},
54092                        {N->getOperand(0), A, B, C});
54093   } else {
54094     if (N->getNumOperands() == 4)
54095       return DAG.getNode(NewOpcode, dl, VT, A, B, C, N->getOperand(3));
54096     return DAG.getNode(NewOpcode, dl, VT, A, B, C);
54097   }
54098 }
54099 
54100 // Combine FMADDSUB(A, B, FNEG(C)) -> FMSUBADD(A, B, C)
54101 // Combine FMSUBADD(A, B, FNEG(C)) -> FMADDSUB(A, B, C)
54102 static SDValue combineFMADDSUB(SDNode *N, SelectionDAG &DAG,
54103                                TargetLowering::DAGCombinerInfo &DCI) {
54104   SDLoc dl(N);
54105   EVT VT = N->getValueType(0);
54106   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
54107   bool CodeSize = DAG.getMachineFunction().getFunction().hasOptSize();
54108   bool LegalOperations = !DCI.isBeforeLegalizeOps();
54109 
54110   SDValue N2 = N->getOperand(2);
54111 
54112   SDValue NegN2 =
54113       TLI.getCheaperNegatedExpression(N2, DAG, LegalOperations, CodeSize);
54114   if (!NegN2)
54115     return SDValue();
54116   unsigned NewOpcode = negateFMAOpcode(N->getOpcode(), false, true, false);
54117 
54118   if (N->getNumOperands() == 4)
54119     return DAG.getNode(NewOpcode, dl, VT, N->getOperand(0), N->getOperand(1),
54120                        NegN2, N->getOperand(3));
54121   return DAG.getNode(NewOpcode, dl, VT, N->getOperand(0), N->getOperand(1),
54122                      NegN2);
54123 }
54124 
54125 static SDValue combineZext(SDNode *N, SelectionDAG &DAG,
54126                            TargetLowering::DAGCombinerInfo &DCI,
54127                            const X86Subtarget &Subtarget) {
54128   SDLoc dl(N);
54129   SDValue N0 = N->getOperand(0);
54130   EVT VT = N->getValueType(0);
54131 
54132   // (i32 (aext (i8 (x86isd::setcc_carry)))) -> (i32 (x86isd::setcc_carry))
54133   // FIXME: Is this needed? We don't seem to have any tests for it.
54134   if (!DCI.isBeforeLegalizeOps() && N->getOpcode() == ISD::ANY_EXTEND &&
54135       N0.getOpcode() == X86ISD::SETCC_CARRY) {
54136     SDValue Setcc = DAG.getNode(X86ISD::SETCC_CARRY, dl, VT, N0->getOperand(0),
54137                                  N0->getOperand(1));
54138     bool ReplaceOtherUses = !N0.hasOneUse();
54139     DCI.CombineTo(N, Setcc);
54140     // Replace other uses with a truncate of the widened setcc_carry.
54141     if (ReplaceOtherUses) {
54142       SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(N0),
54143                                   N0.getValueType(), Setcc);
54144       DCI.CombineTo(N0.getNode(), Trunc);
54145     }
54146 
54147     return SDValue(N, 0);
54148   }
54149 
54150   if (SDValue NewCMov = combineToExtendCMOV(N, DAG))
54151     return NewCMov;
54152 
54153   if (DCI.isBeforeLegalizeOps())
54154     if (SDValue V = combineExtSetcc(N, DAG, Subtarget))
54155       return V;
54156 
54157   if (SDValue V = combineToExtendBoolVectorInReg(N->getOpcode(), dl, VT, N0,
54158                                                  DAG, DCI, Subtarget))
54159     return V;
54160 
54161   if (VT.isVector())
54162     if (SDValue R = PromoteMaskArithmetic(SDValue(N, 0), dl, DAG, Subtarget))
54163       return R;
54164 
54165   if (SDValue NewAdd = promoteExtBeforeAdd(N, DAG, Subtarget))
54166     return NewAdd;
54167 
54168   if (SDValue R = combineOrCmpEqZeroToCtlzSrl(N, DAG, DCI, Subtarget))
54169     return R;
54170 
54171   // TODO: Combine with any target/faux shuffle.
54172   if (N0.getOpcode() == X86ISD::PACKUS && N0.getValueSizeInBits() == 128 &&
54173       VT.getScalarSizeInBits() == N0.getOperand(0).getScalarValueSizeInBits()) {
54174     SDValue N00 = N0.getOperand(0);
54175     SDValue N01 = N0.getOperand(1);
54176     unsigned NumSrcEltBits = N00.getScalarValueSizeInBits();
54177     APInt ZeroMask = APInt::getHighBitsSet(NumSrcEltBits, NumSrcEltBits / 2);
54178     if ((N00.isUndef() || DAG.MaskedValueIsZero(N00, ZeroMask)) &&
54179         (N01.isUndef() || DAG.MaskedValueIsZero(N01, ZeroMask))) {
54180       return concatSubVectors(N00, N01, DAG, dl);
54181     }
54182   }
54183 
54184   return SDValue();
54185 }
54186 
54187 /// If we have AVX512 but not BWI, and this is a vXi16/vXi8 setcc, just
54188 /// pre-promote its result type since vXi1 vectors don't get promoted
54189 /// during type legalization.
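      /// e.g. without BWI: (v16i1 setcc v16i8 a, b, cc) --> (v16i1 trunc (v16i8 setcc a, b, cc))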
54190 static SDValue truncateAVX512SetCCNoBWI(EVT VT, EVT OpVT, SDValue LHS,
54191                                         SDValue RHS, ISD::CondCode CC,
54192                                         const SDLoc &DL, SelectionDAG &DAG,
54193                                         const X86Subtarget &Subtarget) {
54194   if (Subtarget.hasAVX512() && !Subtarget.hasBWI() && VT.isVector() &&
54195       VT.getVectorElementType() == MVT::i1 &&
54196       (OpVT.getVectorElementType() == MVT::i8 ||
54197        OpVT.getVectorElementType() == MVT::i16)) {
54198     SDValue Setcc = DAG.getSetCC(DL, OpVT, LHS, RHS, CC);
54199     return DAG.getNode(ISD::TRUNCATE, DL, VT, Setcc);
54200   }
54201   return SDValue();
54202 }
54203 
54204 static SDValue combineSetCC(SDNode *N, SelectionDAG &DAG,
54205                             TargetLowering::DAGCombinerInfo &DCI,
54206                             const X86Subtarget &Subtarget) {
54207   const ISD::CondCode CC = cast<CondCodeSDNode>(N->getOperand(2))->get();
54208   const SDValue LHS = N->getOperand(0);
54209   const SDValue RHS = N->getOperand(1);
54210   EVT VT = N->getValueType(0);
54211   EVT OpVT = LHS.getValueType();
54212   SDLoc DL(N);
54213 
54214   if (CC == ISD::SETNE || CC == ISD::SETEQ) {
54215     if (SDValue V = combineVectorSizedSetCCEquality(VT, LHS, RHS, CC, DL, DAG,
54216                                                     Subtarget))
54217       return V;
54218 
54219     if (VT == MVT::i1) {
54220       X86::CondCode X86CC;
54221       if (SDValue V =
54222               MatchVectorAllEqualTest(LHS, RHS, CC, DL, Subtarget, DAG, X86CC))
54223         return DAG.getNode(ISD::TRUNCATE, DL, VT, getSETCC(X86CC, V, DL, DAG));
54224     }
54225 
54226     if (OpVT.isScalarInteger()) {
54227       // cmpeq(or(X,Y),X) --> cmpeq(and(~X,Y),0)
54228       // cmpne(or(X,Y),X) --> cmpne(and(~X,Y),0)
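             // or(X,Y) equals X exactly when the bits of Y outside X, i.e. and(~X,Y), are all zero.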
54229       auto MatchOrCmpEq = [&](SDValue N0, SDValue N1) {
54230         if (N0.getOpcode() == ISD::OR && N0->hasOneUse()) {
54231           if (N0.getOperand(0) == N1)
54232             return DAG.getNode(ISD::AND, DL, OpVT, DAG.getNOT(DL, N1, OpVT),
54233                                N0.getOperand(1));
54234           if (N0.getOperand(1) == N1)
54235             return DAG.getNode(ISD::AND, DL, OpVT, DAG.getNOT(DL, N1, OpVT),
54236                                N0.getOperand(0));
54237         }
54238         return SDValue();
54239       };
54240       if (SDValue AndN = MatchOrCmpEq(LHS, RHS))
54241         return DAG.getSetCC(DL, VT, AndN, DAG.getConstant(0, DL, OpVT), CC);
54242       if (SDValue AndN = MatchOrCmpEq(RHS, LHS))
54243         return DAG.getSetCC(DL, VT, AndN, DAG.getConstant(0, DL, OpVT), CC);
54244 
54245       // cmpeq(and(X,Y),Y) --> cmpeq(and(~X,Y),0)
54246       // cmpne(and(X,Y),Y) --> cmpne(and(~X,Y),0)
54247       auto MatchAndCmpEq = [&](SDValue N0, SDValue N1) {
54248         if (N0.getOpcode() == ISD::AND && N0->hasOneUse()) {
54249           if (N0.getOperand(0) == N1)
54250             return DAG.getNode(ISD::AND, DL, OpVT, N1,
54251                                DAG.getNOT(DL, N0.getOperand(1), OpVT));
54252           if (N0.getOperand(1) == N1)
54253             return DAG.getNode(ISD::AND, DL, OpVT, N1,
54254                                DAG.getNOT(DL, N0.getOperand(0), OpVT));
54255         }
54256         return SDValue();
54257       };
54258       if (SDValue AndN = MatchAndCmpEq(LHS, RHS))
54259         return DAG.getSetCC(DL, VT, AndN, DAG.getConstant(0, DL, OpVT), CC);
54260       if (SDValue AndN = MatchAndCmpEq(RHS, LHS))
54261         return DAG.getSetCC(DL, VT, AndN, DAG.getConstant(0, DL, OpVT), CC);
54262 
54263       // cmpeq(trunc(x),C) --> cmpeq(x,C)
54264       // cmpne(trunc(x),C) --> cmpne(x,C)
54265       // iff x upper bits are zero.
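             // e.g. (seteq (i32 trunc i64 %x), 7) --> (seteq i64 %x, 7) if the top 32 bits of %x are zero.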
54266       if (LHS.getOpcode() == ISD::TRUNCATE &&
54267           LHS.getOperand(0).getScalarValueSizeInBits() >= 32 &&
54268           isa<ConstantSDNode>(RHS) && !DCI.isBeforeLegalize()) {
54269         EVT SrcVT = LHS.getOperand(0).getValueType();
54270         APInt UpperBits = APInt::getBitsSetFrom(SrcVT.getScalarSizeInBits(),
54271                                                 OpVT.getScalarSizeInBits());
54272         const TargetLowering &TLI = DAG.getTargetLoweringInfo();
54273         if (DAG.MaskedValueIsZero(LHS.getOperand(0), UpperBits) &&
54274             TLI.isTypeLegal(LHS.getOperand(0).getValueType()))
54275           return DAG.getSetCC(DL, VT, LHS.getOperand(0),
54276                               DAG.getZExtOrTrunc(RHS, DL, SrcVT), CC);
54277       }
54278 
54279       // With C as a power of 2 and C != 0 and C != INT_MIN:
54280       //    icmp eq Abs(X) C ->
54281       //        (icmp eq A, C) | (icmp eq A, -C)
54282       //    icmp ne Abs(X) C ->
54283       //        (icmp ne A, C) & (icmp ne A, -C)
54284       // Both of these patterns can be better optimized in
54285       // DAGCombiner::foldAndOrOfSETCC. Note this only applies for scalar
54286       // integers which is checked above.
54287       if (LHS.getOpcode() == ISD::ABS && LHS.hasOneUse()) {
54288         if (auto *C = dyn_cast<ConstantSDNode>(RHS)) {
54289           const APInt &CInt = C->getAPIntValue();
54290           // We can better optimize this case in DAGCombiner::foldAndOrOfSETCC.
54291           if (CInt.isPowerOf2() && !CInt.isMinSignedValue()) {
54292             SDValue BaseOp = LHS.getOperand(0);
54293             SDValue SETCC0 = DAG.getSetCC(DL, VT, BaseOp, RHS, CC);
54294             SDValue SETCC1 = DAG.getSetCC(
54295                 DL, VT, BaseOp, DAG.getConstant(-CInt, DL, OpVT), CC);
54296             return DAG.getNode(CC == ISD::SETEQ ? ISD::OR : ISD::AND, DL, VT,
54297                                SETCC0, SETCC1);
54298           }
54299         }
54300       }
54301     }
54302   }
54303 
54304   if (VT.isVector() && VT.getVectorElementType() == MVT::i1 &&
54305       (CC == ISD::SETNE || CC == ISD::SETEQ || ISD::isSignedIntSetCC(CC))) {
54306     // Using temporaries to avoid messing up operand ordering for later
54307     // transformations if this doesn't work.
54308     SDValue Op0 = LHS;
54309     SDValue Op1 = RHS;
54310     ISD::CondCode TmpCC = CC;
54311     // Put build_vector on the right.
54312     if (Op0.getOpcode() == ISD::BUILD_VECTOR) {
54313       std::swap(Op0, Op1);
54314       TmpCC = ISD::getSetCCSwappedOperands(TmpCC);
54315     }
54316 
54317     bool IsSEXT0 =
54318         (Op0.getOpcode() == ISD::SIGN_EXTEND) &&
54319         (Op0.getOperand(0).getValueType().getVectorElementType() == MVT::i1);
54320     bool IsVZero1 = ISD::isBuildVectorAllZeros(Op1.getNode());
54321 
54322     if (IsSEXT0 && IsVZero1) {
54323       assert(VT == Op0.getOperand(0).getValueType() &&
54324              "Unexpected operand type");
54325       if (TmpCC == ISD::SETGT)
54326         return DAG.getConstant(0, DL, VT);
54327       if (TmpCC == ISD::SETLE)
54328         return DAG.getConstant(1, DL, VT);
54329       if (TmpCC == ISD::SETEQ || TmpCC == ISD::SETGE)
54330         return DAG.getNOT(DL, Op0.getOperand(0), VT);
54331 
54332       assert((TmpCC == ISD::SETNE || TmpCC == ISD::SETLT) &&
54333              "Unexpected condition code!");
54334       return Op0.getOperand(0);
54335     }
54336   }
54337 
54338   // Try to make an unsigned vector comparison signed. On pre-AVX512 targets only
54339   // signed comparisons (`PCMPGT`) are available, and on AVX512 it's often better
54340   // to use `PCMPGT` if the result is meant to stay in a vector (if it's going to
54341   // a mask, AVX512 also has unsigned comparisons).
54342   if (VT.isVector() && OpVT.isVector() && OpVT.isInteger()) {
54343     bool CanMakeSigned = false;
54344     if (ISD::isUnsignedIntSetCC(CC)) {
54345       KnownBits CmpKnown =
54346           DAG.computeKnownBits(LHS).intersectWith(DAG.computeKnownBits(RHS));
54347       // If we know LHS/RHS share the same sign bit at each element we can
54348       // make this signed.
54349       // NOTE: `computeKnownBits` on a vector type aggregates common bits
54350       // across all lanes. So a pattern where the sign varies from lane to
54351       // lane, but at each lane Sign(LHS) is known to equal Sign(RHS), will be
54352       // missed. We could get around this by demanding each lane
54353       // independently, but this isn't the most important optimization and
54354       // that may eat into compile time.
54355       CanMakeSigned =
54356           CmpKnown.Zero.isSignBitSet() || CmpKnown.One.isSignBitSet();
54357     }
54358     if (CanMakeSigned || ISD::isSignedIntSetCC(CC)) {
54359       SDValue LHSOut = LHS;
54360       SDValue RHSOut = RHS;
54361       ISD::CondCode NewCC = CC;
54362       switch (CC) {
54363       case ISD::SETGE:
54364       case ISD::SETUGE:
54365         if (SDValue NewLHS = incDecVectorConstant(LHS, DAG, /*IsInc*/ true,
54366                                                   /*NSW*/ true))
54367           LHSOut = NewLHS;
54368         else if (SDValue NewRHS = incDecVectorConstant(
54369                      RHS, DAG, /*IsInc*/ false, /*NSW*/ true))
54370           RHSOut = NewRHS;
54371         else
54372           break;
54373 
54374         [[fallthrough]];
54375       case ISD::SETUGT:
54376         NewCC = ISD::SETGT;
54377         break;
54378 
54379       case ISD::SETLE:
54380       case ISD::SETULE:
54381         if (SDValue NewLHS = incDecVectorConstant(LHS, DAG, /*IsInc*/ false,
54382                                                   /*NSW*/ true))
54383           LHSOut = NewLHS;
54384         else if (SDValue NewRHS = incDecVectorConstant(RHS, DAG, /*IsInc*/ true,
54385                                                        /*NSW*/ true))
54386           RHSOut = NewRHS;
54387         else
54388           break;
54389 
54390         [[fallthrough]];
54391       case ISD::SETULT:
54392         // Will be swapped to SETGT in LowerVSETCC*.
54393         NewCC = ISD::SETLT;
54394         break;
54395       default:
54396         break;
54397       }
54398       if (NewCC != CC) {
54399         if (SDValue R = truncateAVX512SetCCNoBWI(VT, OpVT, LHSOut, RHSOut,
54400                                                  NewCC, DL, DAG, Subtarget))
54401           return R;
54402         return DAG.getSetCC(DL, VT, LHSOut, RHSOut, NewCC);
54403       }
54404     }
54405   }
54406 
54407   if (SDValue R =
54408           truncateAVX512SetCCNoBWI(VT, OpVT, LHS, RHS, CC, DL, DAG, Subtarget))
54409     return R;
54410 
54411   // The middle end transforms:
54412   //    `(or (icmp eq X, C), (icmp eq X, C+1))`
54413   //        -> `(icmp ult (add x, -C), 2)`
54414   // Likewise inverted cases with `ugt`.
54415   //
54416   // Since x86, pre avx512, doesn't have unsigned vector compares, this results
54417   // in worse codegen. So, undo the middle-end transform and go back to `(or
54418   // (icmp eq), (icmp eq))` form.
54419   // Also skip AVX1 with ymm vectors, as the umin approach combines better than
54420   // the xmm approach.
54421   //
54422   // NB: We don't handle the similar simplification of `(and (icmp ne), (icmp
54423   // ne))` as it doesn't end up being a net win in instructions.
54424   // TODO: We might want to do this for avx512 as well if we `sext` the result.
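        // e.g. (setult (add X, -C), 2) --> (or (seteq X, C), (seteq X, C+1))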
54425   if (VT.isVector() && OpVT.isVector() && OpVT.isInteger() &&
54426       ISD::isUnsignedIntSetCC(CC) && LHS.getOpcode() == ISD::ADD &&
54427       !Subtarget.hasAVX512() &&
54428       (OpVT.getSizeInBits() <= 128 || !Subtarget.hasAVX() ||
54429        Subtarget.hasAVX2()) &&
54430       LHS.hasOneUse()) {
54431 
54432     APInt CmpC;
54433     SDValue AddC = LHS.getOperand(1);
54434     if (ISD::isConstantSplatVector(RHS.getNode(), CmpC) &&
54435         DAG.isConstantIntBuildVectorOrConstantInt(AddC)) {
54436       // See which form we have depending on the constant/condition.
54437       SDValue C0 = SDValue();
54438       SDValue C1 = SDValue();
54439 
54440       // If we had `(add x, -1)` and can lower with `umin`, don't transform as
54441       // we will end up generating an additional constant. Keeping it in the
54442       // current form has a slight latency cost, but it's probably worth saving
54443       // a constant.
54444       if (ISD::isConstantSplatVectorAllOnes(AddC.getNode()) &&
54445           DAG.getTargetLoweringInfo().isOperationLegal(ISD::UMIN, OpVT)) {
54446         // Pass
54447       }
54448       // Normal Cases
54449       else if ((CC == ISD::SETULT && CmpC == 2) ||
54450                (CC == ISD::SETULE && CmpC == 1)) {
54451         // These will constant fold.
54452         C0 = DAG.getNegative(AddC, DL, OpVT);
54453         C1 = DAG.getNode(ISD::SUB, DL, OpVT, C0,
54454                          DAG.getAllOnesConstant(DL, OpVT));
54455       }
54456       // Inverted Cases
54457       else if ((CC == ISD::SETUGT && (-CmpC) == 3) ||
54458                (CC == ISD::SETUGE && (-CmpC) == 2)) {
54459         // These will constant fold.
54460         C0 = DAG.getNOT(DL, AddC, OpVT);
54461         C1 = DAG.getNode(ISD::ADD, DL, OpVT, C0,
54462                          DAG.getAllOnesConstant(DL, OpVT));
54463       }
54464       if (C0 && C1) {
54465         SDValue NewLHS =
54466             DAG.getSetCC(DL, VT, LHS.getOperand(0), C0, ISD::SETEQ);
54467         SDValue NewRHS =
54468             DAG.getSetCC(DL, VT, LHS.getOperand(0), C1, ISD::SETEQ);
54469         return DAG.getNode(ISD::OR, DL, VT, NewLHS, NewRHS);
54470       }
54471     }
54472   }
54473 
54474   // For an SSE1-only target, lower a comparison of v4f32 to X86ISD::CMPP early
54475   // to avoid scalarization via legalization because v4i32 is not a legal type.
54476   if (Subtarget.hasSSE1() && !Subtarget.hasSSE2() && VT == MVT::v4i32 &&
54477       LHS.getValueType() == MVT::v4f32)
54478     return LowerVSETCC(SDValue(N, 0), Subtarget, DAG);
54479 
54480   // X pred 0.0 --> X pred -X
54481   // If the negation of X already exists, use it in the comparison. This removes
54482   // the need to materialize 0.0 and allows matching to SSE's MIN/MAX
54483   // instructions in patterns with a 'select' node.
54484   if (isNullFPScalarOrVectorConst(RHS)) {
54485     SDVTList FNegVT = DAG.getVTList(OpVT);
54486     if (SDNode *FNeg = DAG.getNodeIfExists(ISD::FNEG, FNegVT, {LHS}))
54487       return DAG.getSetCC(DL, VT, LHS, SDValue(FNeg, 0), CC);
54488   }
54489 
54490   return SDValue();
54491 }
54492 
54493 static SDValue combineMOVMSK(SDNode *N, SelectionDAG &DAG,
54494                              TargetLowering::DAGCombinerInfo &DCI,
54495                              const X86Subtarget &Subtarget) {
54496   SDValue Src = N->getOperand(0);
54497   MVT SrcVT = Src.getSimpleValueType();
54498   MVT VT = N->getSimpleValueType(0);
54499   unsigned NumBits = VT.getScalarSizeInBits();
54500   unsigned NumElts = SrcVT.getVectorNumElements();
54501   unsigned NumBitsPerElt = SrcVT.getScalarSizeInBits();
54502   assert(VT == MVT::i32 && NumElts <= NumBits && "Unexpected MOVMSK types");
54503 
54504   // Perform constant folding.
54505   APInt UndefElts;
54506   SmallVector<APInt, 32> EltBits;
54507   if (getTargetConstantBitsFromNode(Src, NumBitsPerElt, UndefElts, EltBits,
54508                                     /*AllowWholeUndefs*/ true,
54509                                     /*AllowPartialUndefs*/ true)) {
54510     APInt Imm(32, 0);
54511     for (unsigned Idx = 0; Idx != NumElts; ++Idx)
54512       if (!UndefElts[Idx] && EltBits[Idx].isNegative())
54513         Imm.setBit(Idx);
54514 
54515     return DAG.getConstant(Imm, SDLoc(N), VT);
54516   }
54517 
54518   // Look through int->fp bitcasts that don't change the element width.
54519   unsigned EltWidth = SrcVT.getScalarSizeInBits();
54520   if (Subtarget.hasSSE2() && Src.getOpcode() == ISD::BITCAST &&
54521       Src.getOperand(0).getScalarValueSizeInBits() == EltWidth)
54522     return DAG.getNode(X86ISD::MOVMSK, SDLoc(N), VT, Src.getOperand(0));
54523 
54524   // Fold movmsk(not(x)) -> not(movmsk(x)) to improve folding of movmsk results
54525   // with scalar comparisons.
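        // e.g. for v4i32: movmsk(xor(X, -1)) --> xor(movmsk(X), 0xF)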
54526   if (SDValue NotSrc = IsNOT(Src, DAG)) {
54527     SDLoc DL(N);
54528     APInt NotMask = APInt::getLowBitsSet(NumBits, NumElts);
54529     NotSrc = DAG.getBitcast(SrcVT, NotSrc);
54530     return DAG.getNode(ISD::XOR, DL, VT,
54531                        DAG.getNode(X86ISD::MOVMSK, DL, VT, NotSrc),
54532                        DAG.getConstant(NotMask, DL, VT));
54533   }
54534 
54535   // Fold movmsk(icmp_sgt(x,-1)) -> not(movmsk(x)) to improve folding of movmsk
54536   // results with scalar comparisons.
54537   if (Src.getOpcode() == X86ISD::PCMPGT &&
54538       ISD::isBuildVectorAllOnes(Src.getOperand(1).getNode())) {
54539     SDLoc DL(N);
54540     APInt NotMask = APInt::getLowBitsSet(NumBits, NumElts);
54541     return DAG.getNode(ISD::XOR, DL, VT,
54542                        DAG.getNode(X86ISD::MOVMSK, DL, VT, Src.getOperand(0)),
54543                        DAG.getConstant(NotMask, DL, VT));
54544   }
54545 
54546   // Fold movmsk(icmp_eq(and(x,c1),c1)) -> movmsk(shl(x,c2))
54547   // Fold movmsk(icmp_eq(and(x,c1),0)) -> movmsk(not(shl(x,c2)))
54548   // iff pow2splat(c1).
54549   // Use KnownBits to determine if only a single bit is non-zero
54550   // in each element (pow2 or zero), and shift that bit to the msb.
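        // e.g. with v4i32 and c1 = splat(8), the shift amount c2 = 28 moves bit 3 to the sign bit.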
54551   if (Src.getOpcode() == X86ISD::PCMPEQ) {
54552     KnownBits KnownLHS = DAG.computeKnownBits(Src.getOperand(0));
54553     KnownBits KnownRHS = DAG.computeKnownBits(Src.getOperand(1));
54554     unsigned ShiftAmt = KnownLHS.countMinLeadingZeros();
54555     if (KnownLHS.countMaxPopulation() == 1 &&
54556         (KnownRHS.isZero() || (KnownRHS.countMaxPopulation() == 1 &&
54557                                ShiftAmt == KnownRHS.countMinLeadingZeros()))) {
54558       SDLoc DL(N);
54559       MVT ShiftVT = SrcVT;
54560       SDValue ShiftLHS = Src.getOperand(0);
54561       SDValue ShiftRHS = Src.getOperand(1);
54562       if (ShiftVT.getScalarType() == MVT::i8) {
54563         // vXi8 shifts - we only care about the signbit so can use PSLLW.
54564         ShiftVT = MVT::getVectorVT(MVT::i16, NumElts / 2);
54565         ShiftLHS = DAG.getBitcast(ShiftVT, ShiftLHS);
54566         ShiftRHS = DAG.getBitcast(ShiftVT, ShiftRHS);
54567       }
54568       ShiftLHS = getTargetVShiftByConstNode(X86ISD::VSHLI, DL, ShiftVT,
54569                                             ShiftLHS, ShiftAmt, DAG);
54570       ShiftRHS = getTargetVShiftByConstNode(X86ISD::VSHLI, DL, ShiftVT,
54571                                             ShiftRHS, ShiftAmt, DAG);
54572       ShiftLHS = DAG.getBitcast(SrcVT, ShiftLHS);
54573       ShiftRHS = DAG.getBitcast(SrcVT, ShiftRHS);
54574       SDValue Res = DAG.getNode(ISD::XOR, DL, SrcVT, ShiftLHS, ShiftRHS);
54575       return DAG.getNode(X86ISD::MOVMSK, DL, VT, DAG.getNOT(DL, Res, SrcVT));
54576     }
54577   }
54578 
54579   // Fold movmsk(logic(X,C)) -> logic(movmsk(X),C)
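        // e.g. movmsk(and(X, C)) --> and(movmsk(X), M), where bit i of M is the sign bit of element i of C.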
54580   if (N->isOnlyUserOf(Src.getNode())) {
54581     SDValue SrcBC = peekThroughOneUseBitcasts(Src);
54582     if (ISD::isBitwiseLogicOp(SrcBC.getOpcode())) {
54583       APInt UndefElts;
54584       SmallVector<APInt, 32> EltBits;
54585       if (getTargetConstantBitsFromNode(SrcBC.getOperand(1), NumBitsPerElt,
54586                                         UndefElts, EltBits)) {
54587         APInt Mask = APInt::getZero(NumBits);
54588         for (unsigned Idx = 0; Idx != NumElts; ++Idx) {
54589           if (!UndefElts[Idx] && EltBits[Idx].isNegative())
54590             Mask.setBit(Idx);
54591         }
54592         SDLoc DL(N);
54593         SDValue NewSrc = DAG.getBitcast(SrcVT, SrcBC.getOperand(0));
54594         SDValue NewMovMsk = DAG.getNode(X86ISD::MOVMSK, DL, VT, NewSrc);
54595         return DAG.getNode(SrcBC.getOpcode(), DL, VT, NewMovMsk,
54596                            DAG.getConstant(Mask, DL, VT));
54597       }
54598     }
54599   }
54600 
54601   // Simplify the inputs.
54602   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
54603   APInt DemandedMask(APInt::getAllOnes(NumBits));
54604   if (TLI.SimplifyDemandedBits(SDValue(N, 0), DemandedMask, DCI))
54605     return SDValue(N, 0);
54606 
54607   return SDValue();
54608 }
54609 
54610 static SDValue combineTESTP(SDNode *N, SelectionDAG &DAG,
54611                             TargetLowering::DAGCombinerInfo &DCI,
54612                             const X86Subtarget &Subtarget) {
54613   MVT VT = N->getSimpleValueType(0);
54614   unsigned NumBits = VT.getScalarSizeInBits();
54615 
54616   // Simplify the inputs.
54617   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
54618   APInt DemandedMask(APInt::getAllOnes(NumBits));
54619   if (TLI.SimplifyDemandedBits(SDValue(N, 0), DemandedMask, DCI))
54620     return SDValue(N, 0);
54621 
54622   return SDValue();
54623 }
54624 
54625 static SDValue combineX86GatherScatter(SDNode *N, SelectionDAG &DAG,
54626                                        TargetLowering::DAGCombinerInfo &DCI) {
54627   auto *MemOp = cast<X86MaskedGatherScatterSDNode>(N);
54628   SDValue Mask = MemOp->getMask();
54629 
54630   // With vector masks we only demand the upper bit of the mask.
54631   if (Mask.getScalarValueSizeInBits() != 1) {
54632     const TargetLowering &TLI = DAG.getTargetLoweringInfo();
54633     APInt DemandedMask(APInt::getSignMask(Mask.getScalarValueSizeInBits()));
54634     if (TLI.SimplifyDemandedBits(Mask, DemandedMask, DCI)) {
54635       if (N->getOpcode() != ISD::DELETED_NODE)
54636         DCI.AddToWorklist(N);
54637       return SDValue(N, 0);
54638     }
54639   }
54640 
54641   return SDValue();
54642 }
54643 
54644 static SDValue rebuildGatherScatter(MaskedGatherScatterSDNode *GorS,
54645                                     SDValue Index, SDValue Base, SDValue Scale,
54646                                     SelectionDAG &DAG) {
54647   SDLoc DL(GorS);
54648 
54649   if (auto *Gather = dyn_cast<MaskedGatherSDNode>(GorS)) {
54650     SDValue Ops[] = { Gather->getChain(), Gather->getPassThru(),
54651                       Gather->getMask(), Base, Index, Scale } ;
54652     return DAG.getMaskedGather(Gather->getVTList(),
54653                                Gather->getMemoryVT(), DL, Ops,
54654                                Gather->getMemOperand(),
54655                                Gather->getIndexType(),
54656                                Gather->getExtensionType());
54657   }
54658   auto *Scatter = cast<MaskedScatterSDNode>(GorS);
54659   SDValue Ops[] = { Scatter->getChain(), Scatter->getValue(),
54660                     Scatter->getMask(), Base, Index, Scale };
54661   return DAG.getMaskedScatter(Scatter->getVTList(),
54662                               Scatter->getMemoryVT(), DL,
54663                               Ops, Scatter->getMemOperand(),
54664                               Scatter->getIndexType(),
54665                               Scatter->isTruncatingStore());
54666 }
54667 
54668 static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
54669                                     TargetLowering::DAGCombinerInfo &DCI) {
54670   SDLoc DL(N);
54671   auto *GorS = cast<MaskedGatherScatterSDNode>(N);
54672   SDValue Index = GorS->getIndex();
54673   SDValue Base = GorS->getBasePtr();
54674   SDValue Scale = GorS->getScale();
54675   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
54676 
54677   if (DCI.isBeforeLegalize()) {
54678     unsigned IndexWidth = Index.getScalarValueSizeInBits();
54679 
54680     // Shrink constant indices if they are larger than 32-bits.
54681     // Only do this before legalize types since v2i64 could become v2i32.
54682     // FIXME: We could check that the type is legal if we're after legalize
54683     // types, but then we would need to construct test cases where that happens.
54684     // FIXME: We could support more than just constant vectors, but we need to be
54685     // careful with costing. A truncate that can be optimized out would be fine.
54686     // Otherwise we might only want to create a truncate if it avoids a split.
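          // e.g. a constant v2i64 index <16, 24> can be truncated to a v2i32 index <16, 24>.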
54687     if (auto *BV = dyn_cast<BuildVectorSDNode>(Index)) {
54688       if (BV->isConstant() && IndexWidth > 32 &&
54689           DAG.ComputeNumSignBits(Index) > (IndexWidth - 32)) {
54690         EVT NewVT = Index.getValueType().changeVectorElementType(MVT::i32);
54691         Index = DAG.getNode(ISD::TRUNCATE, DL, NewVT, Index);
54692         return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
54693       }
54694     }
54695 
54696     // Shrink any sign/zero extends from a type of 32 bits or smaller to one
54697     // larger than 32 bits if there are sufficient sign bits. Only do this before
54698     // legalize types to avoid creating illegal types in truncate.
54699     if ((Index.getOpcode() == ISD::SIGN_EXTEND ||
54700          Index.getOpcode() == ISD::ZERO_EXTEND) &&
54701         IndexWidth > 32 &&
54702         Index.getOperand(0).getScalarValueSizeInBits() <= 32 &&
54703         DAG.ComputeNumSignBits(Index) > (IndexWidth - 32)) {
54704       EVT NewVT = Index.getValueType().changeVectorElementType(MVT::i32);
54705       Index = DAG.getNode(ISD::TRUNCATE, DL, NewVT, Index);
54706       return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
54707     }
54708   }
54709 
54710   EVT PtrVT = TLI.getPointerTy(DAG.getDataLayout());
54711   // Try to move splat constant adders from the index operand to the base
54712   // pointer operand, taking care to multiply by the scale. We can only do
54713   // this when the index element type is the same as the pointer type;
54714   // otherwise we need to be sure the math doesn't wrap before the scale.
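        // e.g. gather(Base, Index + splat(4), Scale 8) --> gather(Base + 32, Index, Scale 8)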
54715   if (Index.getOpcode() == ISD::ADD &&
54716       Index.getValueType().getVectorElementType() == PtrVT &&
54717       isa<ConstantSDNode>(Scale)) {
54718     uint64_t ScaleAmt = Scale->getAsZExtVal();
54719     if (auto *BV = dyn_cast<BuildVectorSDNode>(Index.getOperand(1))) {
54720       BitVector UndefElts;
54721       if (ConstantSDNode *C = BV->getConstantSplatNode(&UndefElts)) {
54722         // FIXME: Allow non-constant?
54723         if (UndefElts.none()) {
54724           // Apply the scale.
54725           APInt Adder = C->getAPIntValue() * ScaleAmt;
54726           // Add it to the existing base.
54727           Base = DAG.getNode(ISD::ADD, DL, PtrVT, Base,
54728                              DAG.getConstant(Adder, DL, PtrVT));
54729           Index = Index.getOperand(0);
54730           return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
54731         }
54732       }
54733 
54734       // It's also possible base is just a constant. In that case, just
54735       // replace it with 0 and move the displacement into the index.
54736       if (BV->isConstant() && isa<ConstantSDNode>(Base) &&
54737           isOneConstant(Scale)) {
54738         SDValue Splat = DAG.getSplatBuildVector(Index.getValueType(), DL, Base);
54739         // Combine the constant build_vector and the constant base.
54740         Splat = DAG.getNode(ISD::ADD, DL, Index.getValueType(),
54741                             Index.getOperand(1), Splat);
54742         // Add to the LHS of the original Index add.
54743         Index = DAG.getNode(ISD::ADD, DL, Index.getValueType(),
54744                             Index.getOperand(0), Splat);
54745         Base = DAG.getConstant(0, DL, Base.getValueType());
54746         return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
54747       }
54748     }
54749   }
54750 
54751   if (DCI.isBeforeLegalizeOps()) {
54752     unsigned IndexWidth = Index.getScalarValueSizeInBits();
54753 
54754     // Make sure the index is either i32 or i64.
54755     if (IndexWidth != 32 && IndexWidth != 64) {
54756       MVT EltVT = IndexWidth > 32 ? MVT::i64 : MVT::i32;
54757       EVT IndexVT = Index.getValueType().changeVectorElementType(EltVT);
54758       Index = DAG.getSExtOrTrunc(Index, DL, IndexVT);
54759       return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
54760     }
54761   }
54762 
54763   // With vector masks we only demand the upper bit of the mask.
54764   SDValue Mask = GorS->getMask();
54765   if (Mask.getScalarValueSizeInBits() != 1) {
54766     APInt DemandedMask(APInt::getSignMask(Mask.getScalarValueSizeInBits()));
54767     if (TLI.SimplifyDemandedBits(Mask, DemandedMask, DCI)) {
54768       if (N->getOpcode() != ISD::DELETED_NODE)
54769         DCI.AddToWorklist(N);
54770       return SDValue(N, 0);
54771     }
54772   }
54773 
54774   return SDValue();
54775 }
54776 
54777 // Optimize RES = X86ISD::SETCC CONDCODE, EFLAG_INPUT
54778 static SDValue combineX86SetCC(SDNode *N, SelectionDAG &DAG,
54779                                const X86Subtarget &Subtarget) {
54780   SDLoc DL(N);
54781   X86::CondCode CC = X86::CondCode(N->getConstantOperandVal(0));
54782   SDValue EFLAGS = N->getOperand(1);
54783 
54784   // Try to simplify the EFLAGS and condition code operands.
54785   if (SDValue Flags = combineSetCCEFLAGS(EFLAGS, CC, DAG, Subtarget))
54786     return getSETCC(CC, Flags, DL, DAG);
54787 
54788   return SDValue();
54789 }
54790 
54791 /// Optimize branch condition evaluation.
54792 static SDValue combineBrCond(SDNode *N, SelectionDAG &DAG,
54793                              const X86Subtarget &Subtarget) {
54794   SDLoc DL(N);
54795   SDValue EFLAGS = N->getOperand(3);
54796   X86::CondCode CC = X86::CondCode(N->getConstantOperandVal(2));
54797 
54798   // Try to simplify the EFLAGS and condition code operands.
54799   // Make sure to not keep references to operands, as combineSetCCEFLAGS can
54800   // RAUW them under us.
54801   if (SDValue Flags = combineSetCCEFLAGS(EFLAGS, CC, DAG, Subtarget)) {
54802     SDValue Cond = DAG.getTargetConstant(CC, DL, MVT::i8);
54803     return DAG.getNode(X86ISD::BRCOND, DL, N->getVTList(), N->getOperand(0),
54804                        N->getOperand(1), Cond, Flags);
54805   }
54806 
54807   return SDValue();
54808 }
54809 
54810 // TODO: Could we move this to DAGCombine?
54811 static SDValue combineVectorCompareAndMaskUnaryOp(SDNode *N,
54812                                                   SelectionDAG &DAG) {
54813   // Take advantage of vector comparisons (etc.) producing 0 or -1 in each lane
54814   // to optimize away the operation when it's fed from a constant.
54815   //
54816   // The general transformation is:
54817   //    UNARYOP(AND(VECTOR_CMP(x,y), constant)) -->
54818   //       AND(VECTOR_CMP(x,y), constant2)
54819   //    constant2 = UNARYOP(constant)
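  // For example, (v4f32 sint_to_fp (and (vector_cmp X, Y), splat(1))) selects
  // between 0 and 1 per lane, so it can instead AND the compare mask with the
  // bitcast of splat(1.0f) and bitcast the result back; the conversion itself
  // is folded into the constant.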
54820 
54821   // Early exit if this isn't a vector operation, the operand of the
54822   // unary operation isn't a bitwise AND, or if the sizes of the operations
54823   // aren't the same.
54824   EVT VT = N->getValueType(0);
54825   bool IsStrict = N->isStrictFPOpcode();
54826   unsigned NumEltBits = VT.getScalarSizeInBits();
54827   SDValue Op0 = N->getOperand(IsStrict ? 1 : 0);
54828   if (!VT.isVector() || Op0.getOpcode() != ISD::AND ||
54829       DAG.ComputeNumSignBits(Op0.getOperand(0)) != NumEltBits ||
54830       VT.getSizeInBits() != Op0.getValueSizeInBits())
54831     return SDValue();
54832 
54833   // Now check that the other operand of the AND is a constant. We could
54834   // make the transformation for non-constant splats as well, but it's unclear
54835   // that would be a benefit as it would not eliminate any operations, just
54836   // perform one more step in scalar code before moving to the vector unit.
54837   if (auto *BV = dyn_cast<BuildVectorSDNode>(Op0.getOperand(1))) {
54838     // Bail out if the vector isn't a constant.
54839     if (!BV->isConstant())
54840       return SDValue();
54841 
54842     // Everything checks out. Build up the new and improved node.
54843     SDLoc DL(N);
54844     EVT IntVT = BV->getValueType(0);
54845     // Create a new constant of the appropriate type for the transformed
54846     // DAG.
54847     SDValue SourceConst;
54848     if (IsStrict)
54849       SourceConst = DAG.getNode(N->getOpcode(), DL, {VT, MVT::Other},
54850                                 {N->getOperand(0), SDValue(BV, 0)});
54851     else
54852       SourceConst = DAG.getNode(N->getOpcode(), DL, VT, SDValue(BV, 0));
54853     // The AND node needs bitcasts to/from an integer vector type around it.
54854     SDValue MaskConst = DAG.getBitcast(IntVT, SourceConst);
54855     SDValue NewAnd = DAG.getNode(ISD::AND, DL, IntVT, Op0->getOperand(0),
54856                                  MaskConst);
54857     SDValue Res = DAG.getBitcast(VT, NewAnd);
54858     if (IsStrict)
54859       return DAG.getMergeValues({Res, SourceConst.getValue(1)}, DL);
54860     return Res;
54861   }
54862 
54863   return SDValue();
54864 }
54865 
54866 /// If we are converting a value to floating-point, try to replace scalar
54867 /// truncate of an extracted vector element with a bitcast. This tries to keep
54868 /// the sequence on XMM registers rather than moving between vector and GPRs.
54869 static SDValue combineToFPTruncExtElt(SDNode *N, SelectionDAG &DAG) {
54870   // TODO: This is currently only used by combineSIntToFP, but it is generalized
54871   //       to allow being called by any similar cast opcode.
54872   // TODO: Consider merging this into lowering: vectorizeExtractedCast().
54873   SDValue Trunc = N->getOperand(0);
54874   if (!Trunc.hasOneUse() || Trunc.getOpcode() != ISD::TRUNCATE)
54875     return SDValue();
54876 
54877   SDValue ExtElt = Trunc.getOperand(0);
54878   if (!ExtElt.hasOneUse() || ExtElt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
54879       !isNullConstant(ExtElt.getOperand(1)))
54880     return SDValue();
54881 
54882   EVT TruncVT = Trunc.getValueType();
54883   EVT SrcVT = ExtElt.getValueType();
54884   unsigned DestWidth = TruncVT.getSizeInBits();
54885   unsigned SrcWidth = SrcVT.getSizeInBits();
54886   if (SrcWidth % DestWidth != 0)
54887     return SDValue();
54888 
54889   // inttofp (trunc (extelt X, 0)) --> inttofp (extelt (bitcast X), 0)
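  // For example, (sint_to_fp (i32 trunc (i64 extractelt (v2i64 X), 0))) can
  // become (sint_to_fp (i32 extractelt (v4i32 bitcast X), 0)); on
  // little-endian x86 element 0 of the bitcast is exactly the low 32 bits of
  // the original element 0.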
54890   EVT SrcVecVT = ExtElt.getOperand(0).getValueType();
54891   unsigned VecWidth = SrcVecVT.getSizeInBits();
54892   unsigned NumElts = VecWidth / DestWidth;
54893   EVT BitcastVT = EVT::getVectorVT(*DAG.getContext(), TruncVT, NumElts);
54894   SDValue BitcastVec = DAG.getBitcast(BitcastVT, ExtElt.getOperand(0));
54895   SDLoc DL(N);
54896   SDValue NewExtElt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, TruncVT,
54897                                   BitcastVec, ExtElt.getOperand(1));
54898   return DAG.getNode(N->getOpcode(), DL, N->getValueType(0), NewExtElt);
54899 }
54900 
54901 static SDValue combineUIntToFP(SDNode *N, SelectionDAG &DAG,
54902                                const X86Subtarget &Subtarget) {
54903   bool IsStrict = N->isStrictFPOpcode();
54904   SDValue Op0 = N->getOperand(IsStrict ? 1 : 0);
54905   EVT VT = N->getValueType(0);
54906   EVT InVT = Op0.getValueType();
54907 
54908   // Using i16 as an intermediate type is a bad idea, unless we have HW support
54909   // for it. Therefore, for type sizes of 32 bits or smaller, just go with i32.
54910   // if hasFP16 support:
54911   //   UINT_TO_FP(vXi1~15)  -> SINT_TO_FP(ZEXT(vXi1~15  to vXi16))
54912   //   UINT_TO_FP(vXi17~31) -> SINT_TO_FP(ZEXT(vXi17~31 to vXi32))
54913   // else
54914   //   UINT_TO_FP(vXi1~31) -> SINT_TO_FP(ZEXT(vXi1~31 to vXi32))
54915   // UINT_TO_FP(vXi33~63) -> SINT_TO_FP(ZEXT(vXi33~63 to vXi64))
54916   if (InVT.isVector() && VT.getVectorElementType() == MVT::f16) {
54917     unsigned ScalarSize = InVT.getScalarSizeInBits();
54918     if ((ScalarSize == 16 && Subtarget.hasFP16()) || ScalarSize == 32 ||
54919         ScalarSize >= 64)
54920       return SDValue();
54921     SDLoc dl(N);
54922     EVT DstVT =
54923         EVT::getVectorVT(*DAG.getContext(),
54924                          (Subtarget.hasFP16() && ScalarSize < 16) ? MVT::i16
54925                          : ScalarSize < 32                        ? MVT::i32
54926                                                                   : MVT::i64,
54927                          InVT.getVectorNumElements());
54928     SDValue P = DAG.getNode(ISD::ZERO_EXTEND, dl, DstVT, Op0);
54929     if (IsStrict)
54930       return DAG.getNode(ISD::STRICT_SINT_TO_FP, dl, {VT, MVT::Other},
54931                          {N->getOperand(0), P});
54932     return DAG.getNode(ISD::SINT_TO_FP, dl, VT, P);
54933   }
54934 
54935   // UINT_TO_FP(vXi1) -> SINT_TO_FP(ZEXT(vXi1 to vXi32))
54936   // UINT_TO_FP(vXi8) -> SINT_TO_FP(ZEXT(vXi8 to vXi32))
54937   // UINT_TO_FP(vXi16) -> SINT_TO_FP(ZEXT(vXi16 to vXi32))
54938   if (InVT.isVector() && InVT.getScalarSizeInBits() < 32 &&
54939       VT.getScalarType() != MVT::f16) {
54940     SDLoc dl(N);
54941     EVT DstVT = InVT.changeVectorElementType(MVT::i32);
54942     SDValue P = DAG.getNode(ISD::ZERO_EXTEND, dl, DstVT, Op0);
54943 
54944     // UINT_TO_FP isn't legal without AVX512 so use SINT_TO_FP.
54945     if (IsStrict)
54946       return DAG.getNode(ISD::STRICT_SINT_TO_FP, dl, {VT, MVT::Other},
54947                          {N->getOperand(0), P});
54948     return DAG.getNode(ISD::SINT_TO_FP, dl, VT, P);
54949   }
54950 
54951   // Since UINT_TO_FP is legal (it's marked custom), the DAG combiner won't
54952   // optimize it to a SINT_TO_FP when the sign bit is known zero. Perform
54953   // the optimization here.
54954   SDNodeFlags Flags = N->getFlags();
54955   if (Flags.hasNonNeg() || DAG.SignBitIsZero(Op0)) {
54956     if (IsStrict)
54957       return DAG.getNode(ISD::STRICT_SINT_TO_FP, SDLoc(N), {VT, MVT::Other},
54958                          {N->getOperand(0), Op0});
54959     return DAG.getNode(ISD::SINT_TO_FP, SDLoc(N), VT, Op0);
54960   }
54961 
54962   return SDValue();
54963 }
54964 
54965 static SDValue combineSIntToFP(SDNode *N, SelectionDAG &DAG,
54966                                TargetLowering::DAGCombinerInfo &DCI,
54967                                const X86Subtarget &Subtarget) {
54968   // First try to optimize away the conversion entirely when the input is a
54969   // constant masked by a vector compare. Vectors only.
54970   bool IsStrict = N->isStrictFPOpcode();
54971   if (SDValue Res = combineVectorCompareAndMaskUnaryOp(N, DAG))
54972     return Res;
54973 
54974   // Now move on to more general possibilities.
54975   SDValue Op0 = N->getOperand(IsStrict ? 1 : 0);
54976   EVT VT = N->getValueType(0);
54977   EVT InVT = Op0.getValueType();
54978 
54979   // Using i16 as an intermediate type is a bad idea, unless we have HW support
54980   // for it. Therefore, for type sizes of 32 bits or smaller, just go with i32.
54981   // if hasFP16 support:
54982   //   SINT_TO_FP(vXi1~15)  -> SINT_TO_FP(SEXT(vXi1~15  to vXi16))
54983   //   SINT_TO_FP(vXi17~31) -> SINT_TO_FP(SEXT(vXi17~31 to vXi32))
54984   // else
54985   //   SINT_TO_FP(vXi1~31) -> SINT_TO_FP(ZEXT(vXi1~31 to vXi32))
54986   // SINT_TO_FP(vXi33~63) -> SINT_TO_FP(SEXT(vXi33~63 to vXi64))
54987   if (InVT.isVector() && VT.getVectorElementType() == MVT::f16) {
54988     unsigned ScalarSize = InVT.getScalarSizeInBits();
54989     if ((ScalarSize == 16 && Subtarget.hasFP16()) || ScalarSize == 32 ||
54990         ScalarSize >= 64)
54991       return SDValue();
54992     SDLoc dl(N);
54993     EVT DstVT =
54994         EVT::getVectorVT(*DAG.getContext(),
54995                          (Subtarget.hasFP16() && ScalarSize < 16) ? MVT::i16
54996                          : ScalarSize < 32                        ? MVT::i32
54997                                                                   : MVT::i64,
54998                          InVT.getVectorNumElements());
54999     SDValue P = DAG.getNode(ISD::SIGN_EXTEND, dl, DstVT, Op0);
55000     if (IsStrict)
55001       return DAG.getNode(ISD::STRICT_SINT_TO_FP, dl, {VT, MVT::Other},
55002                          {N->getOperand(0), P});
55003     return DAG.getNode(ISD::SINT_TO_FP, dl, VT, P);
55004   }
55005 
55006   // SINT_TO_FP(vXi1) -> SINT_TO_FP(SEXT(vXi1 to vXi32))
55007   // SINT_TO_FP(vXi8) -> SINT_TO_FP(SEXT(vXi8 to vXi32))
55008   // SINT_TO_FP(vXi16) -> SINT_TO_FP(SEXT(vXi16 to vXi32))
55009   if (InVT.isVector() && InVT.getScalarSizeInBits() < 32 &&
55010       VT.getScalarType() != MVT::f16) {
55011     SDLoc dl(N);
55012     EVT DstVT = InVT.changeVectorElementType(MVT::i32);
55013     SDValue P = DAG.getNode(ISD::SIGN_EXTEND, dl, DstVT, Op0);
55014     if (IsStrict)
55015       return DAG.getNode(ISD::STRICT_SINT_TO_FP, dl, {VT, MVT::Other},
55016                          {N->getOperand(0), P});
55017     return DAG.getNode(ISD::SINT_TO_FP, dl, VT, P);
55018   }
55019 
55020   // Without AVX512DQ we only support i64 to float scalar conversion. For both
55021   // vectors and scalars, see if we know that the upper bits are all the sign
55022   // bit, in which case we can truncate the input to i32 and convert from that.
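  // For example, if Op0 is (i64 sign_extend (i32 Y)), the upper 33 bits are
  // all sign bits, so converting from (i32 trunc Op0) gives the same result.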
55023   if (InVT.getScalarSizeInBits() > 32 && !Subtarget.hasDQI()) {
55024     unsigned BitWidth = InVT.getScalarSizeInBits();
55025     unsigned NumSignBits = DAG.ComputeNumSignBits(Op0);
55026     if (NumSignBits >= (BitWidth - 31)) {
55027       EVT TruncVT = MVT::i32;
55028       if (InVT.isVector())
55029         TruncVT = InVT.changeVectorElementType(TruncVT);
55030       SDLoc dl(N);
55031       if (DCI.isBeforeLegalize() || TruncVT != MVT::v2i32) {
55032         SDValue Trunc = DAG.getNode(ISD::TRUNCATE, dl, TruncVT, Op0);
55033         if (IsStrict)
55034           return DAG.getNode(ISD::STRICT_SINT_TO_FP, dl, {VT, MVT::Other},
55035                              {N->getOperand(0), Trunc});
55036         return DAG.getNode(ISD::SINT_TO_FP, dl, VT, Trunc);
55037       }
55038       // If we're after legalize and the type is v2i32 we need to shuffle and
55039       // use CVTSI2P.
55040       assert(InVT == MVT::v2i64 && "Unexpected VT!");
55041       SDValue Cast = DAG.getBitcast(MVT::v4i32, Op0);
55042       SDValue Shuf = DAG.getVectorShuffle(MVT::v4i32, dl, Cast, Cast,
55043                                           { 0, 2, -1, -1 });
55044       if (IsStrict)
55045         return DAG.getNode(X86ISD::STRICT_CVTSI2P, dl, {VT, MVT::Other},
55046                            {N->getOperand(0), Shuf});
55047       return DAG.getNode(X86ISD::CVTSI2P, dl, VT, Shuf);
55048     }
55049   }
55050 
55051   // Transform (SINT_TO_FP (i64 ...)) into an x87 operation if we have
55052   // a 32-bit target where SSE doesn't support i64->FP operations.
55053   if (!Subtarget.useSoftFloat() && Subtarget.hasX87() &&
55054       Op0.getOpcode() == ISD::LOAD) {
55055     LoadSDNode *Ld = cast<LoadSDNode>(Op0.getNode());
55056 
55057     // This transformation is not supported if the result type is f16 or f128.
55058     if (VT == MVT::f16 || VT == MVT::f128)
55059       return SDValue();
55060 
55061     // If we have AVX512DQ we can use packed conversion instructions unless
55062     // the VT is f80.
55063     if (Subtarget.hasDQI() && VT != MVT::f80)
55064       return SDValue();
55065 
55066     if (Ld->isSimple() && !VT.isVector() && ISD::isNormalLoad(Op0.getNode()) &&
55067         Op0.hasOneUse() && !Subtarget.is64Bit() && InVT == MVT::i64) {
55068       std::pair<SDValue, SDValue> Tmp =
55069           Subtarget.getTargetLowering()->BuildFILD(
55070               VT, InVT, SDLoc(N), Ld->getChain(), Ld->getBasePtr(),
55071               Ld->getPointerInfo(), Ld->getOriginalAlign(), DAG);
55072       DAG.ReplaceAllUsesOfValueWith(Op0.getValue(1), Tmp.second);
55073       return Tmp.first;
55074     }
55075   }
55076 
55077   if (IsStrict)
55078     return SDValue();
55079 
55080   if (SDValue V = combineToFPTruncExtElt(N, DAG))
55081     return V;
55082 
55083   return SDValue();
55084 }
55085 
55086 static bool needCarryOrOverflowFlag(SDValue Flags) {
55087   assert(Flags.getValueType() == MVT::i32 && "Unexpected VT!");
55088 
55089   for (const SDNode *User : Flags->uses()) {
55090     X86::CondCode CC;
55091     switch (User->getOpcode()) {
55092     default:
55093       // Be conservative.
55094       return true;
55095     case X86ISD::SETCC:
55096     case X86ISD::SETCC_CARRY:
55097       CC = (X86::CondCode)User->getConstantOperandVal(0);
55098       break;
55099     case X86ISD::BRCOND:
55100     case X86ISD::CMOV:
55101       CC = (X86::CondCode)User->getConstantOperandVal(2);
55102       break;
55103     }
55104 
55105     switch (CC) {
55106     // clang-format off
55107     default: break;
55108     case X86::COND_A: case X86::COND_AE:
55109     case X86::COND_B: case X86::COND_BE:
55110     case X86::COND_O: case X86::COND_NO:
55111     case X86::COND_G: case X86::COND_GE:
55112     case X86::COND_L: case X86::COND_LE:
55113       return true;
55114     // clang-format on
55115     }
55116   }
55117 
55118   return false;
55119 }
55120 
55121 static bool onlyZeroFlagUsed(SDValue Flags) {
55122   assert(Flags.getValueType() == MVT::i32 && "Unexpected VT!");
55123 
55124   for (const SDNode *User : Flags->uses()) {
55125     unsigned CCOpNo;
55126     switch (User->getOpcode()) {
55127     default:
55128       // Be conservative.
55129       return false;
55130     case X86ISD::SETCC:
55131     case X86ISD::SETCC_CARRY:
55132       CCOpNo = 0;
55133       break;
55134     case X86ISD::BRCOND:
55135     case X86ISD::CMOV:
55136       CCOpNo = 2;
55137       break;
55138     }
55139 
55140     X86::CondCode CC = (X86::CondCode)User->getConstantOperandVal(CCOpNo);
55141     if (CC != X86::COND_E && CC != X86::COND_NE)
55142       return false;
55143   }
55144 
55145   return true;
55146 }
55147 
55148 static SDValue combineCMP(SDNode *N, SelectionDAG &DAG,
55149                           TargetLowering::DAGCombinerInfo &DCI,
55150                           const X86Subtarget &Subtarget) {
55151   // Only handle test patterns.
55152   if (!isNullConstant(N->getOperand(1)))
55153     return SDValue();
55154 
55155   // If we have a CMP of a truncated binop, see if we can make a smaller binop
55156   // and use its flags directly.
55157   // TODO: Maybe we should try promoting compares that only use the zero flag
55158   // first if we can prove the upper bits with computeKnownBits?
55159   SDLoc dl(N);
55160   SDValue Op = N->getOperand(0);
55161   EVT VT = Op.getValueType();
55162   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
55163 
55164   if (SDValue CMP =
55165           combineX86SubCmpForFlags(N, SDValue(N, 0), DAG, DCI, Subtarget))
55166     return CMP;
55167 
55168   // If we have a constant logical shift that's only used in a comparison
55169   // against zero, turn it into an equivalent AND. This allows turning it into
55170   // a TEST instruction later.
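  // For example, (cmp (srl X, 8), 0) on i32 with only EQ/NE users can become
  // (cmp (and X, 0xFFFFFF00), 0), which later selects to a TEST instruction.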
55171   if ((Op.getOpcode() == ISD::SRL || Op.getOpcode() == ISD::SHL) &&
55172       Op.hasOneUse() && isa<ConstantSDNode>(Op.getOperand(1)) &&
55173       onlyZeroFlagUsed(SDValue(N, 0))) {
55174     unsigned BitWidth = VT.getSizeInBits();
55175     const APInt &ShAmt = Op.getConstantOperandAPInt(1);
55176     if (ShAmt.ult(BitWidth)) { // Avoid undefined shifts.
55177       unsigned MaskBits = BitWidth - ShAmt.getZExtValue();
55178       APInt Mask = Op.getOpcode() == ISD::SRL
55179                        ? APInt::getHighBitsSet(BitWidth, MaskBits)
55180                        : APInt::getLowBitsSet(BitWidth, MaskBits);
55181       if (Mask.isSignedIntN(32)) {
55182         Op = DAG.getNode(ISD::AND, dl, VT, Op.getOperand(0),
55183                          DAG.getConstant(Mask, dl, VT));
55184         return DAG.getNode(X86ISD::CMP, dl, MVT::i32, Op,
55185                            DAG.getConstant(0, dl, VT));
55186       }
55187     }
55188   }
55189 
55190   // If we're extracting from an AVX512 bool vector and comparing against zero,
55191   // then try to just bitcast the vector to an integer to use TEST/BT directly.
55192   // (and (extract_elt (kshiftr vXi1, C), 0), 1) -> (and (bc vXi1), 1<<C)
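  // For example, testing bit 5 of a v16i1 mask K this way can become
  // (cmp (and (i16 bitcast K), 0x20), 0) once the types involved are legal.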
55193   if (Op.getOpcode() == ISD::AND && isOneConstant(Op.getOperand(1)) &&
55194       Op.hasOneUse() && onlyZeroFlagUsed(SDValue(N, 0))) {
55195     SDValue Src = Op.getOperand(0);
55196     if (Src.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
55197         isNullConstant(Src.getOperand(1)) &&
55198         Src.getOperand(0).getValueType().getScalarType() == MVT::i1) {
55199       SDValue BoolVec = Src.getOperand(0);
55200       unsigned ShAmt = 0;
55201       if (BoolVec.getOpcode() == X86ISD::KSHIFTR) {
55202         ShAmt = BoolVec.getConstantOperandVal(1);
55203         BoolVec = BoolVec.getOperand(0);
55204       }
55205       BoolVec = widenMaskVector(BoolVec, false, Subtarget, DAG, dl);
55206       EVT VecVT = BoolVec.getValueType();
55207       unsigned BitWidth = VecVT.getVectorNumElements();
55208       EVT BCVT = EVT::getIntegerVT(*DAG.getContext(), BitWidth);
55209       if (TLI.isTypeLegal(VecVT) && TLI.isTypeLegal(BCVT)) {
55210         APInt Mask = APInt::getOneBitSet(BitWidth, ShAmt);
55211         Op = DAG.getBitcast(BCVT, BoolVec);
55212         Op = DAG.getNode(ISD::AND, dl, BCVT, Op,
55213                          DAG.getConstant(Mask, dl, BCVT));
55214         return DAG.getNode(X86ISD::CMP, dl, MVT::i32, Op,
55215                            DAG.getConstant(0, dl, BCVT));
55216       }
55217     }
55218   }
55219 
55220   // Peek through any zero-extend if we're only testing for a zero result.
55221   if (Op.getOpcode() == ISD::ZERO_EXTEND && onlyZeroFlagUsed(SDValue(N, 0))) {
55222     SDValue Src = Op.getOperand(0);
55223     EVT SrcVT = Src.getValueType();
55224     if (SrcVT.getScalarSizeInBits() >= 8 && TLI.isTypeLegal(SrcVT))
55225       return DAG.getNode(X86ISD::CMP, dl, MVT::i32, Src,
55226                          DAG.getConstant(0, dl, SrcVT));
55227   }
55228 
55229   // Look for a truncate.
55230   if (Op.getOpcode() != ISD::TRUNCATE)
55231     return SDValue();
55232 
55233   SDValue Trunc = Op;
55234   Op = Op.getOperand(0);
55235 
55236   // See if we can compare with zero against the truncation source,
55237   // which should help using the Z flag from many ops. Only do this for
55238   // i32 truncated op to prevent partial-reg compares of promoted ops.
55239   EVT OpVT = Op.getValueType();
55240   APInt UpperBits =
55241       APInt::getBitsSetFrom(OpVT.getSizeInBits(), VT.getSizeInBits());
55242   if (OpVT == MVT::i32 && DAG.MaskedValueIsZero(Op, UpperBits) &&
55243       onlyZeroFlagUsed(SDValue(N, 0))) {
55244     return DAG.getNode(X86ISD::CMP, dl, MVT::i32, Op,
55245                        DAG.getConstant(0, dl, OpVT));
55246   }
55247 
55248   // After this the truncate and arithmetic op must have a single use.
55249   if (!Trunc.hasOneUse() || !Op.hasOneUse())
55250     return SDValue();
55251 
55252   unsigned NewOpc;
55253   switch (Op.getOpcode()) {
55254   default: return SDValue();
55255   case ISD::AND:
55256     // Skip AND with a constant. We have special handling for AND with an
55257     // immediate during isel to generate TEST instructions.
55258     if (isa<ConstantSDNode>(Op.getOperand(1)))
55259       return SDValue();
55260     NewOpc = X86ISD::AND;
55261     break;
55262   case ISD::OR:  NewOpc = X86ISD::OR;  break;
55263   case ISD::XOR: NewOpc = X86ISD::XOR; break;
55264   case ISD::ADD:
55265     // If the carry or overflow flag is used, we can't truncate.
55266     if (needCarryOrOverflowFlag(SDValue(N, 0)))
55267       return SDValue();
55268     NewOpc = X86ISD::ADD;
55269     break;
55270   case ISD::SUB:
55271     // If the carry or overflow flag is used, we can't truncate.
55272     if (needCarryOrOverflowFlag(SDValue(N, 0)))
55273       return SDValue();
55274     NewOpc = X86ISD::SUB;
55275     break;
55276   }
55277 
55278   // We found an op we can narrow. Truncate its inputs.
55279   SDValue Op0 = DAG.getNode(ISD::TRUNCATE, dl, VT, Op.getOperand(0));
55280   SDValue Op1 = DAG.getNode(ISD::TRUNCATE, dl, VT, Op.getOperand(1));
55281 
55282   // Use an X86-specific opcode to avoid DAG combine messing with it.
55283   SDVTList VTs = DAG.getVTList(VT, MVT::i32);
55284   Op = DAG.getNode(NewOpc, dl, VTs, Op0, Op1);
55285 
55286   // For AND, keep a CMP so that we can match the test pattern.
55287   if (NewOpc == X86ISD::AND)
55288     return DAG.getNode(X86ISD::CMP, dl, MVT::i32, Op,
55289                        DAG.getConstant(0, dl, VT));
55290 
55291   // Return the flags.
55292   return Op.getValue(1);
55293 }
55294 
55295 static SDValue combineX86AddSub(SDNode *N, SelectionDAG &DAG,
55296                                 TargetLowering::DAGCombinerInfo &DCI,
55297                                 const X86Subtarget &ST) {
55298   assert((X86ISD::ADD == N->getOpcode() || X86ISD::SUB == N->getOpcode()) &&
55299          "Expected X86ISD::ADD or X86ISD::SUB");
55300 
55301   SDLoc DL(N);
55302   SDValue LHS = N->getOperand(0);
55303   SDValue RHS = N->getOperand(1);
55304   MVT VT = LHS.getSimpleValueType();
55305   bool IsSub = X86ISD::SUB == N->getOpcode();
55306   unsigned GenericOpc = IsSub ? ISD::SUB : ISD::ADD;
55307 
55308   if (IsSub && isOneConstant(N->getOperand(1)) && !N->hasAnyUseOfValue(0))
55309     if (SDValue CMP = combineX86SubCmpForFlags(N, SDValue(N, 1), DAG, DCI, ST))
55310       return CMP;
55311 
55312   // If we don't use the flag result, simplify back to a generic ADD/SUB.
55313   if (!N->hasAnyUseOfValue(1)) {
55314     SDValue Res = DAG.getNode(GenericOpc, DL, VT, LHS, RHS);
55315     return DAG.getMergeValues({Res, DAG.getConstant(0, DL, MVT::i32)}, DL);
55316   }
55317 
55318   // Fold any similar generic ADD/SUB opcodes to reuse this node.
55319   auto MatchGeneric = [&](SDValue N0, SDValue N1, bool Negate) {
55320     SDValue Ops[] = {N0, N1};
55321     SDVTList VTs = DAG.getVTList(N->getValueType(0));
55322     if (SDNode *GenericAddSub = DAG.getNodeIfExists(GenericOpc, VTs, Ops)) {
55323       SDValue Op(N, 0);
55324       if (Negate)
55325         Op = DAG.getNegative(Op, DL, VT);
55326       DCI.CombineTo(GenericAddSub, Op);
55327     }
55328   };
55329   MatchGeneric(LHS, RHS, false);
55330   MatchGeneric(RHS, LHS, X86ISD::SUB == N->getOpcode());
55331 
55332   // TODO: Can we drop the ZeroSecondOpOnly limit? This is to guarantee that the
55333   // EFLAGS result doesn't change.
55334   return combineAddOrSubToADCOrSBB(IsSub, DL, VT, LHS, RHS, DAG,
55335                                    /*ZeroSecondOpOnly*/ true);
55336 }
55337 
55338 static SDValue combineSBB(SDNode *N, SelectionDAG &DAG) {
55339   SDValue LHS = N->getOperand(0);
55340   SDValue RHS = N->getOperand(1);
55341   SDValue BorrowIn = N->getOperand(2);
55342 
55343   if (SDValue Flags = combineCarryThroughADD(BorrowIn, DAG)) {
55344     MVT VT = N->getSimpleValueType(0);
55345     SDVTList VTs = DAG.getVTList(VT, MVT::i32);
55346     return DAG.getNode(X86ISD::SBB, SDLoc(N), VTs, LHS, RHS, Flags);
55347   }
55348 
55349   // Fold SBB(SUB(X,Y),0,Carry) -> SBB(X,Y,Carry)
55350   // iff the flag result is dead.
55351   if (LHS.getOpcode() == ISD::SUB && isNullConstant(RHS) &&
55352       !N->hasAnyUseOfValue(1))
55353     return DAG.getNode(X86ISD::SBB, SDLoc(N), N->getVTList(), LHS.getOperand(0),
55354                        LHS.getOperand(1), BorrowIn);
55355 
55356   return SDValue();
55357 }
55358 
55359 // Optimize RES, EFLAGS = X86ISD::ADC LHS, RHS, EFLAGS
55360 static SDValue combineADC(SDNode *N, SelectionDAG &DAG,
55361                           TargetLowering::DAGCombinerInfo &DCI) {
55362   SDValue LHS = N->getOperand(0);
55363   SDValue RHS = N->getOperand(1);
55364   SDValue CarryIn = N->getOperand(2);
55365   auto *LHSC = dyn_cast<ConstantSDNode>(LHS);
55366   auto *RHSC = dyn_cast<ConstantSDNode>(RHS);
55367 
55368   // Canonicalize constant to RHS.
55369   if (LHSC && !RHSC)
55370     return DAG.getNode(X86ISD::ADC, SDLoc(N), N->getVTList(), RHS, LHS,
55371                        CarryIn);
55372 
55373   // If the LHS and RHS of the ADC node are zero, then it can't overflow and
55374   // the result is either zero or one (depending on the input carry bit).
55375   // Strength reduce this down to a "set on carry" aka SETCC_CARRY&1.
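  // For example, (adc 0, 0, EFLAGS) just materializes the carry flag as 0 or
  // 1, which is exactly what (and (setcc_carry COND_B, EFLAGS), 1) computes.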
55376   if (LHSC && RHSC && LHSC->isZero() && RHSC->isZero() &&
55377       // We don't have a good way to replace an EFLAGS use, so only do this when
55378       // dead right now.
55379       SDValue(N, 1).use_empty()) {
55380     SDLoc DL(N);
55381     EVT VT = N->getValueType(0);
55382     SDValue CarryOut = DAG.getConstant(0, DL, N->getValueType(1));
55383     SDValue Res1 = DAG.getNode(
55384         ISD::AND, DL, VT,
55385         DAG.getNode(X86ISD::SETCC_CARRY, DL, VT,
55386                     DAG.getTargetConstant(X86::COND_B, DL, MVT::i8), CarryIn),
55387         DAG.getConstant(1, DL, VT));
55388     return DCI.CombineTo(N, Res1, CarryOut);
55389   }
55390 
55391   // Fold ADC(C1,C2,Carry) -> ADC(0,C1+C2,Carry)
55392   // iff the flag result is dead.
55393   // TODO: Allow flag result if C1+C2 doesn't signed/unsigned overflow.
55394   if (LHSC && RHSC && !LHSC->isZero() && !N->hasAnyUseOfValue(1)) {
55395     SDLoc DL(N);
55396     APInt Sum = LHSC->getAPIntValue() + RHSC->getAPIntValue();
55397     return DAG.getNode(X86ISD::ADC, DL, N->getVTList(),
55398                        DAG.getConstant(0, DL, LHS.getValueType()),
55399                        DAG.getConstant(Sum, DL, LHS.getValueType()), CarryIn);
55400   }
55401 
55402   if (SDValue Flags = combineCarryThroughADD(CarryIn, DAG)) {
55403     MVT VT = N->getSimpleValueType(0);
55404     SDVTList VTs = DAG.getVTList(VT, MVT::i32);
55405     return DAG.getNode(X86ISD::ADC, SDLoc(N), VTs, LHS, RHS, Flags);
55406   }
55407 
55408   // Fold ADC(ADD(X,Y),0,Carry) -> ADC(X,Y,Carry)
55409   // iff the flag result is dead.
55410   if (LHS.getOpcode() == ISD::ADD && RHSC && RHSC->isZero() &&
55411       !N->hasAnyUseOfValue(1))
55412     return DAG.getNode(X86ISD::ADC, SDLoc(N), N->getVTList(), LHS.getOperand(0),
55413                        LHS.getOperand(1), CarryIn);
55414 
55415   return SDValue();
55416 }
55417 
55418 static SDValue matchPMADDWD(SelectionDAG &DAG, SDValue Op0, SDValue Op1,
55419                             const SDLoc &DL, EVT VT,
55420                             const X86Subtarget &Subtarget) {
55421   // Example of pattern we try to detect:
55422   // t := (v8i32 mul (sext (v8i16 x0), (sext (v8i16 x1))))
55423   //(add (build_vector (extract_elt t, 0),
55424   //                   (extract_elt t, 2),
55425   //                   (extract_elt t, 4),
55426   //                   (extract_elt t, 6)),
55427   //     (build_vector (extract_elt t, 1),
55428   //                   (extract_elt t, 3),
55429   //                   (extract_elt t, 5),
55430   //                   (extract_elt t, 7)))
55431 
55432   if (!Subtarget.hasSSE2())
55433     return SDValue();
55434 
55435   if (Op0.getOpcode() != ISD::BUILD_VECTOR ||
55436       Op1.getOpcode() != ISD::BUILD_VECTOR)
55437     return SDValue();
55438 
55439   if (!VT.isVector() || VT.getVectorElementType() != MVT::i32 ||
55440       VT.getVectorNumElements() < 4 ||
55441       !isPowerOf2_32(VT.getVectorNumElements()))
55442     return SDValue();
55443 
55444   // Check if one of Op0,Op1 is of the form:
55445   // (build_vector (extract_elt Mul, 0),
55446   //               (extract_elt Mul, 2),
55447   //               (extract_elt Mul, 4),
55448   //                   ...
55449   // the other is of the form:
55450   // (build_vector (extract_elt Mul, 1),
55451   //               (extract_elt Mul, 3),
55452   //               (extract_elt Mul, 5),
55453   //                   ...
55454   // and identify Mul.
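  // This matches VPMADDWD, which computes a[2i]*b[2i] + a[2i+1]*b[2i+1] for
  // each i32 result lane from i16 inputs, provided the multiply operands can
  // be shrunk back to i16 (checked below).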
55455   SDValue Mul;
55456   for (unsigned i = 0, e = VT.getVectorNumElements(); i != e; i += 2) {
55457     SDValue Op0L = Op0->getOperand(i), Op1L = Op1->getOperand(i),
55458             Op0H = Op0->getOperand(i + 1), Op1H = Op1->getOperand(i + 1);
55459     // TODO: Be more tolerant to undefs.
55460     if (Op0L.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
55461         Op1L.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
55462         Op0H.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
55463         Op1H.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
55464       return SDValue();
55465     auto *Const0L = dyn_cast<ConstantSDNode>(Op0L->getOperand(1));
55466     auto *Const1L = dyn_cast<ConstantSDNode>(Op1L->getOperand(1));
55467     auto *Const0H = dyn_cast<ConstantSDNode>(Op0H->getOperand(1));
55468     auto *Const1H = dyn_cast<ConstantSDNode>(Op1H->getOperand(1));
55469     if (!Const0L || !Const1L || !Const0H || !Const1H)
55470       return SDValue();
55471     unsigned Idx0L = Const0L->getZExtValue(), Idx1L = Const1L->getZExtValue(),
55472              Idx0H = Const0H->getZExtValue(), Idx1H = Const1H->getZExtValue();
55473     // Commutativity of mul allows factors of a product to reorder.
55474     if (Idx0L > Idx1L)
55475       std::swap(Idx0L, Idx1L);
55476     if (Idx0H > Idx1H)
55477       std::swap(Idx0H, Idx1H);
55478     // Commutativity of add allows pairs of factors to reorder.
55479     if (Idx0L > Idx0H) {
55480       std::swap(Idx0L, Idx0H);
55481       std::swap(Idx1L, Idx1H);
55482     }
55483     if (Idx0L != 2 * i || Idx1L != 2 * i + 1 || Idx0H != 2 * i + 2 ||
55484         Idx1H != 2 * i + 3)
55485       return SDValue();
55486     if (!Mul) {
55487       // First time an extract_elt's source vector is visited. Must be a MUL
55488       // with twice as many vector elements as the BUILD_VECTOR.
55489       // Both extracts must be from the same MUL.
55490       Mul = Op0L->getOperand(0);
55491       if (Mul->getOpcode() != ISD::MUL ||
55492           Mul.getValueType().getVectorNumElements() != 2 * e)
55493         return SDValue();
55494     }
55495     // Check that the extract is from the same MUL previously seen.
55496     if (Mul != Op0L->getOperand(0) || Mul != Op1L->getOperand(0) ||
55497         Mul != Op0H->getOperand(0) || Mul != Op1H->getOperand(0))
55498       return SDValue();
55499   }
55500 
55501   // Check if the Mul source can be safely shrunk.
55502   ShrinkMode Mode;
55503   if (!canReduceVMulWidth(Mul.getNode(), DAG, Mode) ||
55504       Mode == ShrinkMode::MULU16)
55505     return SDValue();
55506 
55507   EVT TruncVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16,
55508                                  VT.getVectorNumElements() * 2);
55509   SDValue N0 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, Mul.getOperand(0));
55510   SDValue N1 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, Mul.getOperand(1));
55511 
55512   auto PMADDBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
55513                          ArrayRef<SDValue> Ops) {
55514     EVT InVT = Ops[0].getValueType();
55515     assert(InVT == Ops[1].getValueType() && "Operands' types mismatch");
55516     EVT ResVT = EVT::getVectorVT(*DAG.getContext(), MVT::i32,
55517                                  InVT.getVectorNumElements() / 2);
55518     return DAG.getNode(X86ISD::VPMADDWD, DL, ResVT, Ops[0], Ops[1]);
55519   };
55520   return SplitOpsAndApply(DAG, Subtarget, DL, VT, { N0, N1 }, PMADDBuilder);
55521 }
55522 
55523 // Attempt to turn this pattern into PMADDWD.
55524 // (add (mul (sext (build_vector)), (sext (build_vector))),
55525 //      (mul (sext (build_vector)), (sext (build_vector)))
55526 static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1,
55527                               const SDLoc &DL, EVT VT,
55528                               const X86Subtarget &Subtarget) {
55529   if (!Subtarget.hasSSE2())
55530     return SDValue();
55531 
55532   if (N0.getOpcode() != ISD::MUL || N1.getOpcode() != ISD::MUL)
55533     return SDValue();
55534 
55535   if (!VT.isVector() || VT.getVectorElementType() != MVT::i32 ||
55536       VT.getVectorNumElements() < 4 ||
55537       !isPowerOf2_32(VT.getVectorNumElements()))
55538     return SDValue();
55539 
55540   SDValue N00 = N0.getOperand(0);
55541   SDValue N01 = N0.getOperand(1);
55542   SDValue N10 = N1.getOperand(0);
55543   SDValue N11 = N1.getOperand(1);
55544 
55545   // All inputs need to be sign extends.
55546   // TODO: Support ZERO_EXTEND from known positive?
55547   if (N00.getOpcode() != ISD::SIGN_EXTEND ||
55548       N01.getOpcode() != ISD::SIGN_EXTEND ||
55549       N10.getOpcode() != ISD::SIGN_EXTEND ||
55550       N11.getOpcode() != ISD::SIGN_EXTEND)
55551     return SDValue();
55552 
55553   // Peek through the extends.
55554   N00 = N00.getOperand(0);
55555   N01 = N01.getOperand(0);
55556   N10 = N10.getOperand(0);
55557   N11 = N11.getOperand(0);
55558 
55559   // Must be extending from vXi16.
55560   EVT InVT = N00.getValueType();
55561   if (InVT.getVectorElementType() != MVT::i16 || N01.getValueType() != InVT ||
55562       N10.getValueType() != InVT || N11.getValueType() != InVT)
55563     return SDValue();
55564 
55565   // All inputs should be build_vectors.
55566   if (N00.getOpcode() != ISD::BUILD_VECTOR ||
55567       N01.getOpcode() != ISD::BUILD_VECTOR ||
55568       N10.getOpcode() != ISD::BUILD_VECTOR ||
55569       N11.getOpcode() != ISD::BUILD_VECTOR)
55570     return SDValue();
55571 
55572   // For each result element, the even element of one vector must be multiplied
55573   // by the even element of the other vector, and the odd element of the first
55574   // vector must be multiplied by the odd element of the other vector, with the
55575   // two products added together. So for each element i, this computation must
55576   // be performed:
55577   //  A[2 * i] * B[2 * i] + A[2 * i + 1] * B[2 * i + 1]
55578   SDValue In0, In1;
55579   for (unsigned i = 0; i != N00.getNumOperands(); ++i) {
55580     SDValue N00Elt = N00.getOperand(i);
55581     SDValue N01Elt = N01.getOperand(i);
55582     SDValue N10Elt = N10.getOperand(i);
55583     SDValue N11Elt = N11.getOperand(i);
55584     // TODO: Be more tolerant to undefs.
55585     if (N00Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
55586         N01Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
55587         N10Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
55588         N11Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
55589       return SDValue();
55590     auto *ConstN00Elt = dyn_cast<ConstantSDNode>(N00Elt.getOperand(1));
55591     auto *ConstN01Elt = dyn_cast<ConstantSDNode>(N01Elt.getOperand(1));
55592     auto *ConstN10Elt = dyn_cast<ConstantSDNode>(N10Elt.getOperand(1));
55593     auto *ConstN11Elt = dyn_cast<ConstantSDNode>(N11Elt.getOperand(1));
55594     if (!ConstN00Elt || !ConstN01Elt || !ConstN10Elt || !ConstN11Elt)
55595       return SDValue();
55596     unsigned IdxN00 = ConstN00Elt->getZExtValue();
55597     unsigned IdxN01 = ConstN01Elt->getZExtValue();
55598     unsigned IdxN10 = ConstN10Elt->getZExtValue();
55599     unsigned IdxN11 = ConstN11Elt->getZExtValue();
55600     // Add is commutative so indices can be reordered.
55601     if (IdxN00 > IdxN10) {
55602       std::swap(IdxN00, IdxN10);
55603       std::swap(IdxN01, IdxN11);
55604     }
55605     // N0 indices must be the even elements. N1 indices must be the next odd elements.
55606     if (IdxN00 != 2 * i || IdxN10 != 2 * i + 1 ||
55607         IdxN01 != 2 * i || IdxN11 != 2 * i + 1)
55608       return SDValue();
55609     SDValue N00In = N00Elt.getOperand(0);
55610     SDValue N01In = N01Elt.getOperand(0);
55611     SDValue N10In = N10Elt.getOperand(0);
55612     SDValue N11In = N11Elt.getOperand(0);
55613 
55614     // First time we find an input capture it.
55615     if (!In0) {
55616       In0 = N00In;
55617       In1 = N01In;
55618 
55619       // The input vectors must be at least as wide as the output.
55620       // If they are larger than the output, we extract subvector below.
55621       if (In0.getValueSizeInBits() < VT.getSizeInBits() ||
55622           In1.getValueSizeInBits() < VT.getSizeInBits())
55623         return SDValue();
55624     }
55625     // Mul is commutative so the input vectors can be in any order.
55626     // Canonicalize to make the compares easier.
55627     if (In0 != N00In)
55628       std::swap(N00In, N01In);
55629     if (In0 != N10In)
55630       std::swap(N10In, N11In);
55631     if (In0 != N00In || In1 != N01In || In0 != N10In || In1 != N11In)
55632       return SDValue();
55633   }
55634 
55635   auto PMADDBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
55636                          ArrayRef<SDValue> Ops) {
55637     EVT OpVT = Ops[0].getValueType();
55638     assert(OpVT.getScalarType() == MVT::i16 &&
55639            "Unexpected scalar element type");
55640     assert(OpVT == Ops[1].getValueType() && "Operands' types mismatch");
55641     EVT ResVT = EVT::getVectorVT(*DAG.getContext(), MVT::i32,
55642                                  OpVT.getVectorNumElements() / 2);
55643     return DAG.getNode(X86ISD::VPMADDWD, DL, ResVT, Ops[0], Ops[1]);
55644   };
55645 
55646   // If the output is narrower than an input, extract the low part of the input
55647   // vector.
55648   EVT OutVT16 = EVT::getVectorVT(*DAG.getContext(), MVT::i16,
55649                                VT.getVectorNumElements() * 2);
55650   if (OutVT16.bitsLT(In0.getValueType())) {
55651     In0 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, OutVT16, In0,
55652                       DAG.getIntPtrConstant(0, DL));
55653   }
55654   if (OutVT16.bitsLT(In1.getValueType())) {
55655     In1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, OutVT16, In1,
55656                       DAG.getIntPtrConstant(0, DL));
55657   }
55658   return SplitOpsAndApply(DAG, Subtarget, DL, VT, { In0, In1 },
55659                           PMADDBuilder);
55660 }
55661 
55662 // ADD(VPMADDWD(X,Y),VPMADDWD(Z,W)) -> VPMADDWD(SHUFFLE(X,Z), SHUFFLE(Y,W))
55663 // If the upper element in each pair of both VPMADDWD operands is zero then we
55664 // can merge the operand elements and use the implicit add of VPMADDWD.
55665 // TODO: Add support for VPMADDUBSW (which isn't commutable).
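// For example, with v4i32 results, if the odd i16 elements of X and Z are
// known zero, the merged node computes x[2i]*y[2i] + z[2i]*w[2i] per lane,
// which equals the original VPMADDWD(X,Y) + VPMADDWD(Z,W).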
55666 static SDValue combineAddOfPMADDWD(SelectionDAG &DAG, SDValue N0, SDValue N1,
55667                                    const SDLoc &DL, EVT VT) {
55668   if (N0.getOpcode() != N1.getOpcode() || N0.getOpcode() != X86ISD::VPMADDWD)
55669     return SDValue();
55670 
55671   // TODO: Add 256/512-bit support once VPMADDWD combines with shuffles.
55672   if (VT.getSizeInBits() > 128)
55673     return SDValue();
55674 
55675   unsigned NumElts = VT.getVectorNumElements();
55676   MVT OpVT = N0.getOperand(0).getSimpleValueType();
55677   APInt DemandedBits = APInt::getAllOnes(OpVT.getScalarSizeInBits());
55678   APInt DemandedHiElts = APInt::getSplat(2 * NumElts, APInt(2, 2));
55679 
55680   bool Op0HiZero =
55681       DAG.MaskedValueIsZero(N0.getOperand(0), DemandedBits, DemandedHiElts) ||
55682       DAG.MaskedValueIsZero(N0.getOperand(1), DemandedBits, DemandedHiElts);
55683   bool Op1HiZero =
55684       DAG.MaskedValueIsZero(N1.getOperand(0), DemandedBits, DemandedHiElts) ||
55685       DAG.MaskedValueIsZero(N1.getOperand(1), DemandedBits, DemandedHiElts);
55686 
55687   // TODO: Check for zero lower elements once we have actual codegen that
55688   // creates them.
55689   if (!Op0HiZero || !Op1HiZero)
55690     return SDValue();
55691 
55692   // Create a shuffle mask packing the lower elements from each VPMADDWD.
55693   SmallVector<int> Mask;
55694   for (int i = 0; i != (int)NumElts; ++i) {
55695     Mask.push_back(2 * i);
55696     Mask.push_back(2 * (i + NumElts));
55697   }
55698 
55699   SDValue LHS =
55700       DAG.getVectorShuffle(OpVT, DL, N0.getOperand(0), N1.getOperand(0), Mask);
55701   SDValue RHS =
55702       DAG.getVectorShuffle(OpVT, DL, N0.getOperand(1), N1.getOperand(1), Mask);
55703   return DAG.getNode(X86ISD::VPMADDWD, DL, VT, LHS, RHS);
55704 }
55705 
55706 /// CMOV of constants requires materializing constant operands in registers.
55707 /// Try to fold those constants into an 'add' instruction to reduce instruction
55708 /// count. We do this with CMOV rather than the generic 'select' because there are
55709 /// earlier folds that may be used to turn select-of-constants into logic hacks.
55710 static SDValue pushAddIntoCmovOfConsts(SDNode *N, const SDLoc &DL,
55711                                        SelectionDAG &DAG,
55712                                        const X86Subtarget &Subtarget) {
55713   // If an operand is zero, add-of-0 gets simplified away, so that's clearly
55714   // better because we eliminate 1-2 instructions. This transform is still
55715   // an improvement without zero operands because we trade 2 constant moves and
55716   // 1 add for 2 adds (LEAs) as long as the constants can be represented as
55717   // immediate asm operands (fit in 32-bits).
55718   auto isSuitableCmov = [](SDValue V) {
55719     if (V.getOpcode() != X86ISD::CMOV || !V.hasOneUse())
55720       return false;
55721     if (!isa<ConstantSDNode>(V.getOperand(0)) ||
55722         !isa<ConstantSDNode>(V.getOperand(1)))
55723       return false;
55724     return isNullConstant(V.getOperand(0)) || isNullConstant(V.getOperand(1)) ||
55725            (V.getConstantOperandAPInt(0).isSignedIntN(32) &&
55726             V.getConstantOperandAPInt(1).isSignedIntN(32));
55727   };
55728 
55729   // Match an appropriate CMOV as the first operand of the add.
55730   SDValue Cmov = N->getOperand(0);
55731   SDValue OtherOp = N->getOperand(1);
55732   if (!isSuitableCmov(Cmov))
55733     std::swap(Cmov, OtherOp);
55734   if (!isSuitableCmov(Cmov))
55735     return SDValue();
55736 
55737   // Don't remove a load folding opportunity for the add. That would neutralize
55738   // any improvements from removing constant materializations.
55739   if (X86::mayFoldLoad(OtherOp, Subtarget))
55740     return SDValue();
55741 
55742   EVT VT = N->getValueType(0);
55743   SDValue FalseOp = Cmov.getOperand(0);
55744   SDValue TrueOp = Cmov.getOperand(1);
55745 
55746   // We will push the add through the select, but we can potentially do better
55747   // if we know there is another add in the sequence and this is pointer math.
55748   // In that case, we can absorb an add into the trailing memory op and avoid
55749   // a 3-operand LEA which is likely slower than a 2-operand LEA.
55750   // TODO: If target has "slow3OpsLEA", do this even without the trailing memop?
55751   if (OtherOp.getOpcode() == ISD::ADD && OtherOp.hasOneUse() &&
55752       !isa<ConstantSDNode>(OtherOp.getOperand(0)) &&
55753       all_of(N->uses(), [&](SDNode *Use) {
55754         auto *MemNode = dyn_cast<MemSDNode>(Use);
55755         return MemNode && MemNode->getBasePtr().getNode() == N;
55756       })) {
55757     // add (cmov C1, C2), add (X, Y) --> add (cmov (add X, C1), (add X, C2)), Y
55758     // TODO: We are arbitrarily choosing op0 as the 1st piece of the sum, but
55759     //       it is possible that choosing op1 might be better.
55760     SDValue X = OtherOp.getOperand(0), Y = OtherOp.getOperand(1);
55761     FalseOp = DAG.getNode(ISD::ADD, DL, VT, X, FalseOp);
55762     TrueOp = DAG.getNode(ISD::ADD, DL, VT, X, TrueOp);
55763     Cmov = DAG.getNode(X86ISD::CMOV, DL, VT, FalseOp, TrueOp,
55764                        Cmov.getOperand(2), Cmov.getOperand(3));
55765     return DAG.getNode(ISD::ADD, DL, VT, Cmov, Y);
55766   }
55767 
55768   // add (cmov C1, C2), OtherOp --> cmov (add OtherOp, C1), (add OtherOp, C2)
55769   FalseOp = DAG.getNode(ISD::ADD, DL, VT, OtherOp, FalseOp);
55770   TrueOp = DAG.getNode(ISD::ADD, DL, VT, OtherOp, TrueOp);
55771   return DAG.getNode(X86ISD::CMOV, DL, VT, FalseOp, TrueOp, Cmov.getOperand(2),
55772                      Cmov.getOperand(3));
55773 }
55774 
55775 static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
55776                           TargetLowering::DAGCombinerInfo &DCI,
55777                           const X86Subtarget &Subtarget) {
55778   EVT VT = N->getValueType(0);
55779   SDValue Op0 = N->getOperand(0);
55780   SDValue Op1 = N->getOperand(1);
55781   SDLoc DL(N);
55782 
55783   if (SDValue Select = pushAddIntoCmovOfConsts(N, DL, DAG, Subtarget))
55784     return Select;
55785 
55786   if (SDValue MAdd = matchPMADDWD(DAG, Op0, Op1, DL, VT, Subtarget))
55787     return MAdd;
55788   if (SDValue MAdd = matchPMADDWD_2(DAG, Op0, Op1, DL, VT, Subtarget))
55789     return MAdd;
55790   if (SDValue MAdd = combineAddOfPMADDWD(DAG, Op0, Op1, DL, VT))
55791     return MAdd;
55792 
55793   // Try to synthesize horizontal adds from adds of shuffles.
55794   if (SDValue V = combineToHorizontalAddSub(N, DAG, Subtarget))
55795     return V;
55796 
55797   // add(psadbw(X,0),psadbw(Y,0)) -> psadbw(add(X,Y),0)
55798   // iff X and Y won't overflow.
55799   if (Op0.getOpcode() == X86ISD::PSADBW && Op1.getOpcode() == X86ISD::PSADBW &&
55800       ISD::isBuildVectorAllZeros(Op0.getOperand(1).getNode()) &&
55801       ISD::isBuildVectorAllZeros(Op1.getOperand(1).getNode())) {
55802     if (DAG.willNotOverflowAdd(false, Op0.getOperand(0), Op1.getOperand(0))) {
55803       MVT OpVT = Op0.getOperand(1).getSimpleValueType();
55804       SDValue Sum =
55805           DAG.getNode(ISD::ADD, DL, OpVT, Op0.getOperand(0), Op1.getOperand(0));
55806       return DAG.getNode(X86ISD::PSADBW, DL, VT, Sum,
55807                          getZeroVector(OpVT, Subtarget, DAG, DL));
55808     }
55809   }
55810 
55811   // If vectors of i1 are legal, turn (add (zext (vXi1 X)), Y) into
55812   // (sub Y, (sext (vXi1 X))).
55813   // FIXME: We have the (sub Y, (zext (vXi1 X))) -> (add (sext (vXi1 X)), Y) in
55814   // generic DAG combine without a legal type check, but adding this there
55815   // caused regressions.
55816   if (VT.isVector()) {
55817     const TargetLowering &TLI = DAG.getTargetLoweringInfo();
55818     if (Op0.getOpcode() == ISD::ZERO_EXTEND &&
55819         Op0.getOperand(0).getValueType().getVectorElementType() == MVT::i1 &&
55820         TLI.isTypeLegal(Op0.getOperand(0).getValueType())) {
55821       SDValue SExt = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Op0.getOperand(0));
55822       return DAG.getNode(ISD::SUB, DL, VT, Op1, SExt);
55823     }
55824 
55825     if (Op1.getOpcode() == ISD::ZERO_EXTEND &&
55826         Op1.getOperand(0).getValueType().getVectorElementType() == MVT::i1 &&
55827         TLI.isTypeLegal(Op1.getOperand(0).getValueType())) {
55828       SDValue SExt = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Op1.getOperand(0));
55829       return DAG.getNode(ISD::SUB, DL, VT, Op0, SExt);
55830     }
55831   }
55832 
55833   // Fold ADD(ADC(Y,0,W),X) -> ADC(X,Y,W)
55834   if (Op0.getOpcode() == X86ISD::ADC && Op0->hasOneUse() &&
55835       X86::isZeroNode(Op0.getOperand(1))) {
55836     assert(!Op0->hasAnyUseOfValue(1) && "Overflow bit in use");
55837     return DAG.getNode(X86ISD::ADC, SDLoc(Op0), Op0->getVTList(), Op1,
55838                        Op0.getOperand(0), Op0.getOperand(2));
55839   }
55840 
55841   return combineAddOrSubToADCOrSBB(N, DL, DAG);
55842 }
55843 
55844 // Try to fold (sub Y, cmovns X, -X) -> (add Y, cmovns -X, X) if the cmov
55845 // condition comes from the subtract node that produced -X. This matches the
55846 // cmov expansion for absolute value. By swapping the operands we convert abs
55847 // to nabs.
55848 static SDValue combineSubABS(SDNode *N, SelectionDAG &DAG) {
55849   SDValue N0 = N->getOperand(0);
55850   SDValue N1 = N->getOperand(1);
55851 
55852   if (N1.getOpcode() != X86ISD::CMOV || !N1.hasOneUse())
55853     return SDValue();
55854 
55855   X86::CondCode CC = (X86::CondCode)N1.getConstantOperandVal(2);
55856   if (CC != X86::COND_S && CC != X86::COND_NS)
55857     return SDValue();
55858 
55859   // Condition should come from a negate operation.
55860   SDValue Cond = N1.getOperand(3);
55861   if (Cond.getOpcode() != X86ISD::SUB || !isNullConstant(Cond.getOperand(0)))
55862     return SDValue();
55863   assert(Cond.getResNo() == 1 && "Unexpected result number");
55864 
55865   // Get the X and -X from the negate.
55866   SDValue NegX = Cond.getValue(0);
55867   SDValue X = Cond.getOperand(1);
55868 
55869   SDValue FalseOp = N1.getOperand(0);
55870   SDValue TrueOp = N1.getOperand(1);
55871 
55872   // Cmov operands should be X and NegX. Order doesn't matter.
55873   if (!(TrueOp == X && FalseOp == NegX) && !(TrueOp == NegX && FalseOp == X))
55874     return SDValue();
55875 
55876   // Build a new CMOV with the operands swapped.
55877   SDLoc DL(N);
55878   MVT VT = N->getSimpleValueType(0);
55879   SDValue Cmov = DAG.getNode(X86ISD::CMOV, DL, VT, TrueOp, FalseOp,
55880                              N1.getOperand(2), Cond);
55881   // Convert sub to add.
55882   return DAG.getNode(ISD::ADD, DL, VT, N0, Cmov);
55883 }
55884 
55885 static SDValue combineSubSetcc(SDNode *N, SelectionDAG &DAG) {
55886   SDValue Op0 = N->getOperand(0);
55887   SDValue Op1 = N->getOperand(1);
55888 
55889   // (sub C (zero_extend (setcc)))
55890   // =>
55891   // (add (zero_extend (setcc inverted)), (C-1))   if C is a nonzero immediate
55892   // Don't disturb (sub 0 setcc), which is easily done with neg.
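  // For example, (sub 5, (zext (setcc e, ...))) can become
  // (add (zext (setcc ne, ...)), 4): both give 4 when the original setcc is 1
  // and 5 when it is 0.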
55893   EVT VT = N->getValueType(0);
55894   auto *Op0C = dyn_cast<ConstantSDNode>(Op0);
55895   if (Op1.getOpcode() == ISD::ZERO_EXTEND && Op1.hasOneUse() && Op0C &&
55896       !Op0C->isZero() && Op1.getOperand(0).getOpcode() == X86ISD::SETCC &&
55897       Op1.getOperand(0).hasOneUse()) {
55898     SDValue SetCC = Op1.getOperand(0);
55899     X86::CondCode CC = (X86::CondCode)SetCC.getConstantOperandVal(0);
55900     X86::CondCode NewCC = X86::GetOppositeBranchCondition(CC);
55901     APInt NewImm = Op0C->getAPIntValue() - 1;
55902     SDLoc DL(Op1);
55903     SDValue NewSetCC = getSETCC(NewCC, SetCC.getOperand(1), DL, DAG);
55904     NewSetCC = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, NewSetCC);
55905     return DAG.getNode(X86ISD::ADD, DL, DAG.getVTList(VT, VT), NewSetCC,
55906                        DAG.getConstant(NewImm, DL, VT));
55907   }
55908 
55909   return SDValue();
55910 }
55911 
55912 static SDValue combineX86CloadCstore(SDNode *N, SelectionDAG &DAG) {
55913   // res, flags2 = sub 0, (setcc cc, flag)
55914   // cload/cstore ..., cond_ne, flags2
55915   // ->
55916   // cload/cstore cc, flag
55917   if (N->getConstantOperandVal(3) != X86::COND_NE)
55918     return SDValue();
55919 
55920   SDValue Sub = N->getOperand(4);
55921   if (Sub.getOpcode() != X86ISD::SUB)
55922     return SDValue();
55923 
55924   SDValue SetCC = Sub.getOperand(1);
55925 
55926   if (!X86::isZeroNode(Sub.getOperand(0)) || SetCC.getOpcode() != X86ISD::SETCC)
55927     return SDValue();
55928 
55929   SmallVector<SDValue, 5> Ops(N->op_values());
55930   Ops[3] = SetCC.getOperand(0);
55931   Ops[4] = SetCC.getOperand(1);
55932 
55933   return DAG.getMemIntrinsicNode(N->getOpcode(), SDLoc(N), N->getVTList(), Ops,
55934                                  cast<MemSDNode>(N)->getMemoryVT(),
55935                                  cast<MemSDNode>(N)->getMemOperand());
55936 }
55937 
55938 static SDValue combineSub(SDNode *N, SelectionDAG &DAG,
55939                           TargetLowering::DAGCombinerInfo &DCI,
55940                           const X86Subtarget &Subtarget) {
55941   SDValue Op0 = N->getOperand(0);
55942   SDValue Op1 = N->getOperand(1);
55943   SDLoc DL(N);
55944 
55945   // TODO: Add NoOpaque handling to isConstantIntBuildVectorOrConstantInt.
55946   auto IsNonOpaqueConstant = [&](SDValue Op) {
55947     if (SDNode *C = DAG.isConstantIntBuildVectorOrConstantInt(Op)) {
55948       if (auto *Cst = dyn_cast<ConstantSDNode>(C))
55949         return !Cst->isOpaque();
55950       return true;
55951     }
55952     return false;
55953   };
55954 
55955   // X86 can't encode an immediate LHS of a sub. See if we can push the
55956   // negation into a preceding instruction. If the RHS of the sub is a XOR with
55957   // one use and a constant, invert the immediate, saving one register.
55958   // However, ignore cases where C1 is 0, as those will become a NEG.
55959   // sub(C1, xor(X, C2)) -> add(xor(X, ~C2), C1+1)
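        // e.g. (illustrative i8 values): sub(10, xor(X, 3)) -> add(xor(X, 0xFC), 11);
        // for X = 5 both sides evaluate to 4.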
55960   if (Op1.getOpcode() == ISD::XOR && IsNonOpaqueConstant(Op0) &&
55961       !isNullConstant(Op0) && IsNonOpaqueConstant(Op1.getOperand(1)) &&
55962       Op1->hasOneUse()) {
55963     EVT VT = Op0.getValueType();
55964     SDValue NewXor = DAG.getNode(ISD::XOR, SDLoc(Op1), VT, Op1.getOperand(0),
55965                                  DAG.getNOT(SDLoc(Op1), Op1.getOperand(1), VT));
55966     SDValue NewAdd =
55967         DAG.getNode(ISD::ADD, DL, VT, Op0, DAG.getConstant(1, DL, VT));
55968     return DAG.getNode(ISD::ADD, DL, VT, NewXor, NewAdd);
55969   }
55970 
55971   if (SDValue V = combineSubABS(N, DAG))
55972     return V;
55973 
55974   // Try to synthesize horizontal subs from subs of shuffles.
55975   if (SDValue V = combineToHorizontalAddSub(N, DAG, Subtarget))
55976     return V;
55977 
55978   // Fold SUB(X,ADC(Y,0,W)) -> SBB(X,Y,W)
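        // ADC(Y,0,W) computes Y + carry(W), so X - (Y + carry) is exactly SBB(X,Y,W).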
55979   if (Op1.getOpcode() == X86ISD::ADC && Op1->hasOneUse() &&
55980       X86::isZeroNode(Op1.getOperand(1))) {
55981     assert(!Op1->hasAnyUseOfValue(1) && "Overflow bit in use");
55982     return DAG.getNode(X86ISD::SBB, SDLoc(Op1), Op1->getVTList(), Op0,
55983                        Op1.getOperand(0), Op1.getOperand(2));
55984   }
55985 
55986   // Fold SUB(X,SBB(Y,Z,W)) -> SUB(ADC(X,Z,W),Y)
55987   // Don't fold to ADC(0,0,W)/SETCC_CARRY pattern which will prevent more folds.
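        // SBB(Y,Z,W) computes Y - Z - carry(W), so X - (Y - Z - carry) rewrites to
        // (X + Z + carry) - Y, i.e. SUB(ADC(X,Z,W), Y).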
55988   if (Op1.getOpcode() == X86ISD::SBB && Op1->hasOneUse() &&
55989       !(X86::isZeroNode(Op0) && X86::isZeroNode(Op1.getOperand(1)))) {
55990     assert(!Op1->hasAnyUseOfValue(1) && "Overflow bit in use");
55991     SDValue ADC = DAG.getNode(X86ISD::ADC, SDLoc(Op1), Op1->getVTList(), Op0,
55992                               Op1.getOperand(1), Op1.getOperand(2));
55993     return DAG.getNode(ISD::SUB, DL, Op0.getValueType(), ADC.getValue(0),
55994                        Op1.getOperand(0));
55995   }
55996 
55997   if (SDValue V = combineXorSubCTLZ(N, DL, DAG, Subtarget))
55998     return V;
55999 
56000   if (SDValue V = combineAddOrSubToADCOrSBB(N, DL, DAG))
56001     return V;
56002 
56003   return combineSubSetcc(N, DAG);
56004 }
56005 
56006 static SDValue combineVectorCompare(SDNode *N, SelectionDAG &DAG,
56007                                     const X86Subtarget &Subtarget) {
56008   unsigned Opcode = N->getOpcode();
56009   assert((Opcode == X86ISD::PCMPEQ || Opcode == X86ISD::PCMPGT) &&
56010          "Unknown PCMP opcode");
56011 
56012   SDValue LHS = N->getOperand(0);
56013   SDValue RHS = N->getOperand(1);
56014   MVT VT = N->getSimpleValueType(0);
56015   unsigned EltBits = VT.getScalarSizeInBits();
56016   unsigned NumElts = VT.getVectorNumElements();
56017   SDLoc DL(N);
56018 
56019   if (LHS == RHS)
56020     return (Opcode == X86ISD::PCMPEQ) ? DAG.getAllOnesConstant(DL, VT)
56021                                       : DAG.getConstant(0, DL, VT);
56022 
56023   // Constant Folding.
56024   // PCMPEQ(X,UNDEF) -> UNDEF
56025   // PCMPGT(X,UNDEF) -> 0
56026   // PCMPGT(UNDEF,X) -> 0
56027   APInt LHSUndefs, RHSUndefs;
56028   SmallVector<APInt> LHSBits, RHSBits;
56029   if (getTargetConstantBitsFromNode(LHS, EltBits, LHSUndefs, LHSBits) &&
56030       getTargetConstantBitsFromNode(RHS, EltBits, RHSUndefs, RHSBits)) {
56031     APInt Ones = APInt::getAllOnes(EltBits);
56032     APInt Zero = APInt::getZero(EltBits);
56033     SmallVector<APInt> Results(NumElts);
56034     for (unsigned I = 0; I != NumElts; ++I) {
56035       if (Opcode == X86ISD::PCMPEQ) {
56036         Results[I] = (LHSBits[I] == RHSBits[I]) ? Ones : Zero;
56037       } else {
56038         bool AnyUndef = LHSUndefs[I] || RHSUndefs[I];
56039         Results[I] = (!AnyUndef && LHSBits[I].sgt(RHSBits[I])) ? Ones : Zero;
56040       }
56041     }
56042     if (Opcode == X86ISD::PCMPEQ)
56043       return getConstVector(Results, LHSUndefs | RHSUndefs, VT, DAG, DL);
56044     return getConstVector(Results, VT, DAG, DL);
56045   }
56046 
56047   return SDValue();
56048 }
56049 
56050 // Helper to determine if we can convert an integer comparison to a float
56051 // comparison by casting the operands.
56052 static std::optional<unsigned>
56053 CastIntSETCCtoFP(MVT VT, ISD::CondCode CC, unsigned NumSignificantBitsLHS,
56054                  unsigned NumSignificantBitsRHS) {
56055   MVT SVT = VT.getScalarType();
56056   assert(SVT == MVT::f32 && "Only tested for float so far");
56057   const fltSemantics &Sem = SelectionDAG::EVTToAPFloatSemantics(SVT);
56058   assert((CC == ISD::SETEQ || CC == ISD::SETGT) &&
56059          "Only PCMPEQ/PCMPGT currently supported");
56060 
56061   // TODO: Handle bitcastable integers.
56062 
56063   // For cvt + signed compare we need lhs and rhs to be exactly representable as
56064   // a fp value.
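        // e.g. for f32 (24-bit significand) any integer with at most 24 significant
        // bits converts exactly via SINT_TO_FP.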
56065   unsigned FPPrec = APFloat::semanticsPrecision(Sem);
56066   if (FPPrec >= NumSignificantBitsLHS && FPPrec >= NumSignificantBitsRHS)
56067     return ISD::SINT_TO_FP;
56068 
56069   return std::nullopt;
56070 }
56071 
56072 /// Helper that combines an array of subvector ops as if they were the operands
56073 /// of an ISD::CONCAT_VECTORS node, but may have come from another source (e.g.
56074 /// ISD::INSERT_SUBVECTOR). The ops are assumed to be of the same type.
56075 static SDValue combineConcatVectorOps(const SDLoc &DL, MVT VT,
56076                                       ArrayRef<SDValue> Ops, SelectionDAG &DAG,
56077                                       TargetLowering::DAGCombinerInfo &DCI,
56078                                       const X86Subtarget &Subtarget) {
56079   assert(Subtarget.hasAVX() && "AVX assumed for concat_vectors");
56080   unsigned EltSizeInBits = VT.getScalarSizeInBits();
56081 
56082   if (llvm::all_of(Ops, [](SDValue Op) { return Op.isUndef(); }))
56083     return DAG.getUNDEF(VT);
56084 
56085   if (llvm::all_of(Ops, [](SDValue Op) {
56086         return ISD::isBuildVectorAllZeros(Op.getNode());
56087       }))
56088     return getZeroVector(VT, Subtarget, DAG, DL);
56089 
56090   SDValue Op0 = Ops[0];
56091   bool IsSplat = llvm::all_equal(Ops);
56092   unsigned NumOps = Ops.size();
56093   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
56094   LLVMContext &Ctx = *DAG.getContext();
56095 
56096   // Repeated subvectors.
56097   if (IsSplat &&
56098       (VT.is256BitVector() || (VT.is512BitVector() && Subtarget.hasAVX512()))) {
56099     // If this broadcast is inserted into both halves, use a larger broadcast.
56100     if (Op0.getOpcode() == X86ISD::VBROADCAST)
56101       return DAG.getNode(Op0.getOpcode(), DL, VT, Op0.getOperand(0));
56102 
56103     // concat_vectors(movddup(x),movddup(x)) -> broadcast(x)
56104     if (Op0.getOpcode() == X86ISD::MOVDDUP && VT == MVT::v4f64 &&
56105         (Subtarget.hasAVX2() ||
56106          X86::mayFoldLoadIntoBroadcastFromMem(Op0.getOperand(0),
56107                                               VT.getScalarType(), Subtarget)))
56108       return DAG.getNode(X86ISD::VBROADCAST, DL, VT,
56109                          DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f64,
56110                                      Op0.getOperand(0),
56111                                      DAG.getIntPtrConstant(0, DL)));
56112 
56113     // concat_vectors(scalar_to_vector(x),scalar_to_vector(x)) -> broadcast(x)
56114     if (Op0.getOpcode() == ISD::SCALAR_TO_VECTOR &&
56115         (Subtarget.hasAVX2() ||
56116          (EltSizeInBits >= 32 &&
56117           X86::mayFoldLoad(Op0.getOperand(0), Subtarget))) &&
56118         Op0.getOperand(0).getValueType() == VT.getScalarType())
56119       return DAG.getNode(X86ISD::VBROADCAST, DL, VT, Op0.getOperand(0));
56120 
56121     // concat_vectors(extract_subvector(broadcast(x)),
56122     //                extract_subvector(broadcast(x))) -> broadcast(x)
56123     // concat_vectors(extract_subvector(subv_broadcast(x)),
56124     //                extract_subvector(subv_broadcast(x))) -> subv_broadcast(x)
56125     if (Op0.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
56126         Op0.getOperand(0).getValueType() == VT) {
56127       SDValue SrcVec = Op0.getOperand(0);
56128       if (SrcVec.getOpcode() == X86ISD::VBROADCAST ||
56129           SrcVec.getOpcode() == X86ISD::VBROADCAST_LOAD)
56130         return Op0.getOperand(0);
56131       if (SrcVec.getOpcode() == X86ISD::SUBV_BROADCAST_LOAD &&
56132           Op0.getValueType() == cast<MemSDNode>(SrcVec)->getMemoryVT())
56133         return Op0.getOperand(0);
56134     }
56135 
56136     // concat_vectors(permq(x),permq(x)) -> permq(concat_vectors(x,x))
56137     if (Op0.getOpcode() == X86ISD::VPERMI && Subtarget.useAVX512Regs() &&
56138         !X86::mayFoldLoad(Op0.getOperand(0), Subtarget))
56139       return DAG.getNode(Op0.getOpcode(), DL, VT,
56140                          DAG.getNode(ISD::CONCAT_VECTORS, DL, VT,
56141                                      Op0.getOperand(0), Op0.getOperand(0)),
56142                          Op0.getOperand(1));
56143   }
56144 
56145   // concat(extract_subvector(v0,c0), extract_subvector(v1,c1)) -> vperm2x128.
56146   // Only handle concats of subvector high halves, which vperm2x128 is best at.
56147   // TODO: This should go in combineX86ShufflesRecursively eventually.
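        // With immediate 0x31 the result's low 128-bit lane takes the high half of
        // the first source and its high lane takes the high half of the second
        // source, i.e. exactly concat(hi(v0), hi(v1)).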
56148   if (VT.is256BitVector() && NumOps == 2) {
56149     SDValue Src0 = peekThroughBitcasts(Ops[0]);
56150     SDValue Src1 = peekThroughBitcasts(Ops[1]);
56151     if (Src0.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
56152         Src1.getOpcode() == ISD::EXTRACT_SUBVECTOR) {
56153       EVT SrcVT0 = Src0.getOperand(0).getValueType();
56154       EVT SrcVT1 = Src1.getOperand(0).getValueType();
56155       unsigned NumSrcElts0 = SrcVT0.getVectorNumElements();
56156       unsigned NumSrcElts1 = SrcVT1.getVectorNumElements();
56157       if (SrcVT0.is256BitVector() && SrcVT1.is256BitVector() &&
56158           Src0.getConstantOperandAPInt(1) == (NumSrcElts0 / 2) &&
56159           Src1.getConstantOperandAPInt(1) == (NumSrcElts1 / 2)) {
56160         return DAG.getNode(X86ISD::VPERM2X128, DL, VT,
56161                            DAG.getBitcast(VT, Src0.getOperand(0)),
56162                            DAG.getBitcast(VT, Src1.getOperand(0)),
56163                            DAG.getTargetConstant(0x31, DL, MVT::i8));
56164       }
56165     }
56166   }
56167 
56168   // Repeated opcode.
56169   // TODO - combineX86ShufflesRecursively should handle shuffle concatenation
56170   // but it currently struggles with different vector widths.
56171   if (llvm::all_of(Ops, [Op0](SDValue Op) {
56172         return Op.getOpcode() == Op0.getOpcode() && Op.hasOneUse();
56173       })) {
56174     auto ConcatSubOperand = [&](EVT VT, ArrayRef<SDValue> SubOps, unsigned I) {
56175       SmallVector<SDValue> Subs;
56176       for (SDValue SubOp : SubOps)
56177         Subs.push_back(SubOp.getOperand(I));
56178       // Attempt to peek through bitcasts and concat the original subvectors.
56179       EVT SubVT = peekThroughBitcasts(Subs[0]).getValueType();
56180       if (SubVT.isSimple() && SubVT.isVector()) {
56181         EVT ConcatVT =
56182             EVT::getVectorVT(*DAG.getContext(), SubVT.getScalarType(),
56183                              SubVT.getVectorElementCount() * Subs.size());
56184         for (SDValue &Sub : Subs)
56185           Sub = DAG.getBitcast(SubVT, Sub);
56186         return DAG.getBitcast(
56187             VT, DAG.getNode(ISD::CONCAT_VECTORS, DL, ConcatVT, Subs));
56188       }
56189       return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Subs);
56190     };
56191     auto IsConcatFree = [](MVT VT, ArrayRef<SDValue> SubOps, unsigned Op) {
56192       bool AllConstants = true;
56193       bool AllSubVectors = true;
56194       for (unsigned I = 0, E = SubOps.size(); I != E; ++I) {
56195         SDValue Sub = SubOps[I].getOperand(Op);
56196         unsigned NumSubElts = Sub.getValueType().getVectorNumElements();
56197         SDValue BC = peekThroughBitcasts(Sub);
56198         AllConstants &= ISD::isBuildVectorOfConstantSDNodes(BC.getNode()) ||
56199                         ISD::isBuildVectorOfConstantFPSDNodes(BC.getNode());
56200         AllSubVectors &= Sub.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
56201                          Sub.getOperand(0).getValueType() == VT &&
56202                          Sub.getConstantOperandAPInt(1) == (I * NumSubElts);
56203       }
56204       return AllConstants || AllSubVectors;
56205     };
56206 
56207     switch (Op0.getOpcode()) {
56208     case X86ISD::VBROADCAST: {
56209       if (!IsSplat && llvm::all_of(Ops, [](SDValue Op) {
56210             return Op.getOperand(0).getValueType().is128BitVector();
56211           })) {
56212         if (VT == MVT::v4f64 || VT == MVT::v4i64)
56213           return DAG.getNode(X86ISD::UNPCKL, DL, VT,
56214                              ConcatSubOperand(VT, Ops, 0),
56215                              ConcatSubOperand(VT, Ops, 0));
56216         // TODO: Add pseudo v8i32 PSHUFD handling to AVX1Only targets.
56217         if (VT == MVT::v8f32 || (VT == MVT::v8i32 && Subtarget.hasInt256()))
56218           return DAG.getNode(VT == MVT::v8f32 ? X86ISD::VPERMILPI
56219                                               : X86ISD::PSHUFD,
56220                              DL, VT, ConcatSubOperand(VT, Ops, 0),
56221                              getV4X86ShuffleImm8ForMask({0, 0, 0, 0}, DL, DAG));
56222       }
56223       break;
56224     }
56225     case X86ISD::MOVDDUP:
56226     case X86ISD::MOVSHDUP:
56227     case X86ISD::MOVSLDUP: {
56228       if (!IsSplat)
56229         return DAG.getNode(Op0.getOpcode(), DL, VT,
56230                            ConcatSubOperand(VT, Ops, 0));
56231       break;
56232     }
56233     case X86ISD::SHUFP: {
56234       // Add SHUFPD support if/when necessary.
56235       if (!IsSplat && VT.getScalarType() == MVT::f32 &&
56236           llvm::all_of(Ops, [Op0](SDValue Op) {
56237             return Op.getOperand(2) == Op0.getOperand(2);
56238           })) {
56239         return DAG.getNode(Op0.getOpcode(), DL, VT,
56240                            ConcatSubOperand(VT, Ops, 0),
56241                            ConcatSubOperand(VT, Ops, 1), Op0.getOperand(2));
56242       }
56243       break;
56244     }
56245     case X86ISD::UNPCKH:
56246     case X86ISD::UNPCKL: {
56247       // Don't concatenate build_vector patterns.
56248       if (!IsSplat && EltSizeInBits >= 32 &&
56249           ((VT.is256BitVector() && Subtarget.hasInt256()) ||
56250            (VT.is512BitVector() && Subtarget.useAVX512Regs())) &&
56251           none_of(Ops, [](SDValue Op) {
56252             return peekThroughBitcasts(Op.getOperand(0)).getOpcode() ==
56253                        ISD::SCALAR_TO_VECTOR ||
56254                    peekThroughBitcasts(Op.getOperand(1)).getOpcode() ==
56255                        ISD::SCALAR_TO_VECTOR;
56256           })) {
56257         return DAG.getNode(Op0.getOpcode(), DL, VT,
56258                            ConcatSubOperand(VT, Ops, 0),
56259                            ConcatSubOperand(VT, Ops, 1));
56260       }
56261       break;
56262     }
56263     case X86ISD::PSHUFHW:
56264     case X86ISD::PSHUFLW:
56265     case X86ISD::PSHUFD:
56266       if (!IsSplat && NumOps == 2 && VT.is256BitVector() &&
56267           Subtarget.hasInt256() && Op0.getOperand(1) == Ops[1].getOperand(1)) {
56268         return DAG.getNode(Op0.getOpcode(), DL, VT,
56269                            ConcatSubOperand(VT, Ops, 0), Op0.getOperand(1));
56270       }
56271       [[fallthrough]];
56272     case X86ISD::VPERMILPI:
56273       if (!IsSplat && EltSizeInBits == 32 &&
56274           (VT.is256BitVector() ||
56275            (VT.is512BitVector() && Subtarget.useAVX512Regs())) &&
56276           all_of(Ops, [&Op0](SDValue Op) {
56277             return Op0.getOperand(1) == Op.getOperand(1);
56278           })) {
56279         MVT FloatVT = VT.changeVectorElementType(MVT::f32);
56280         SDValue Res = DAG.getBitcast(FloatVT, ConcatSubOperand(VT, Ops, 0));
56281         Res =
56282             DAG.getNode(X86ISD::VPERMILPI, DL, FloatVT, Res, Op0.getOperand(1));
56283         return DAG.getBitcast(VT, Res);
56284       }
56285       if (!IsSplat && NumOps == 2 && VT == MVT::v4f64) {
56286         uint64_t Idx0 = Ops[0].getConstantOperandVal(1);
56287         uint64_t Idx1 = Ops[1].getConstantOperandVal(1);
56288         uint64_t Idx = ((Idx1 & 3) << 2) | (Idx0 & 3);
56289         return DAG.getNode(Op0.getOpcode(), DL, VT,
56290                            ConcatSubOperand(VT, Ops, 0),
56291                            DAG.getTargetConstant(Idx, DL, MVT::i8));
56292       }
56293       break;
56294     case X86ISD::PSHUFB:
56295     case X86ISD::PSADBW:
56296     case X86ISD::VPMADDUBSW:
56297     case X86ISD::VPMADDWD:
56298       if (!IsSplat && ((VT.is256BitVector() && Subtarget.hasInt256()) ||
56299                        (VT.is512BitVector() && Subtarget.useBWIRegs()))) {
56300         MVT SrcVT = Op0.getOperand(0).getSimpleValueType();
56301         SrcVT = MVT::getVectorVT(SrcVT.getScalarType(),
56302                                  NumOps * SrcVT.getVectorNumElements());
56303         return DAG.getNode(Op0.getOpcode(), DL, VT,
56304                            ConcatSubOperand(SrcVT, Ops, 0),
56305                            ConcatSubOperand(SrcVT, Ops, 1));
56306       }
56307       break;
56308     case X86ISD::VPERMV:
56309       if (!IsSplat && NumOps == 2 &&
56310           (VT.is512BitVector() && Subtarget.useAVX512Regs())) {
56311         MVT OpVT = Op0.getSimpleValueType();
56312         int NumSrcElts = OpVT.getVectorNumElements();
56313         SmallVector<int, 64> ConcatMask;
56314         for (unsigned i = 0; i != NumOps; ++i) {
56315           SmallVector<int, 64> SubMask;
56316           SmallVector<SDValue, 2> SubOps;
56317           if (!getTargetShuffleMask(Ops[i], false, SubOps, SubMask))
56318             break;
56319           for (int M : SubMask) {
56320             if (0 <= M)
56321               M += i * NumSrcElts;
56322             ConcatMask.push_back(M);
56323           }
56324         }
56325         if (ConcatMask.size() == (NumOps * NumSrcElts)) {
56326           SDValue Src = concatSubVectors(Ops[0].getOperand(1),
56327                                          Ops[1].getOperand(1), DAG, DL);
56328           MVT IntMaskSVT = MVT::getIntegerVT(EltSizeInBits);
56329           MVT IntMaskVT = MVT::getVectorVT(IntMaskSVT, NumOps * NumSrcElts);
56330           SDValue Mask = getConstVector(ConcatMask, IntMaskVT, DAG, DL, true);
56331           return DAG.getNode(X86ISD::VPERMV, DL, VT, Mask, Src);
56332         }
56333       }
56334       break;
56335     case X86ISD::VPERMV3:
56336       if (!IsSplat && NumOps == 2 && VT.is512BitVector()) {
56337         MVT OpVT = Op0.getSimpleValueType();
56338         int NumSrcElts = OpVT.getVectorNumElements();
56339         SmallVector<int, 64> ConcatMask;
56340         for (unsigned i = 0; i != NumOps; ++i) {
56341           SmallVector<int, 64> SubMask;
56342           SmallVector<SDValue, 2> SubOps;
56343           if (!getTargetShuffleMask(Ops[i], false, SubOps, SubMask))
56344             break;
56345           for (int M : SubMask) {
56346             if (0 <= M) {
56347               M += M < NumSrcElts ? 0 : NumSrcElts;
56348               M += i * NumSrcElts;
56349             }
56350             ConcatMask.push_back(M);
56351           }
56352         }
56353         if (ConcatMask.size() == (NumOps * NumSrcElts)) {
56354           SDValue Src0 = concatSubVectors(Ops[0].getOperand(0),
56355                                           Ops[1].getOperand(0), DAG, DL);
56356           SDValue Src1 = concatSubVectors(Ops[0].getOperand(2),
56357                                           Ops[1].getOperand(2), DAG, DL);
56358           MVT IntMaskSVT = MVT::getIntegerVT(EltSizeInBits);
56359           MVT IntMaskVT = MVT::getVectorVT(IntMaskSVT, NumOps * NumSrcElts);
56360           SDValue Mask = getConstVector(ConcatMask, IntMaskVT, DAG, DL, true);
56361           return DAG.getNode(X86ISD::VPERMV3, DL, VT, Src0, Mask, Src1);
56362         }
56363       }
56364       break;
56365     case X86ISD::VPERM2X128: {
56366       if (!IsSplat && VT.is512BitVector() && Subtarget.useAVX512Regs()) {
56367         assert(NumOps == 2 && "Bad concat_vectors operands");
56368         unsigned Imm0 = Ops[0].getConstantOperandVal(2);
56369         unsigned Imm1 = Ops[1].getConstantOperandVal(2);
56370         // TODO: Handle zero'd subvectors.
56371         if ((Imm0 & 0x88) == 0 && (Imm1 & 0x88) == 0) {
56372           int Mask[4] = {(int)(Imm0 & 0x03), (int)((Imm0 >> 4) & 0x3),
56373                          (int)(Imm1 & 0x03), (int)((Imm1 >> 4) & 0x3)};
56374           MVT ShuffleVT = VT.isFloatingPoint() ? MVT::v8f64 : MVT::v8i64;
56375           SDValue LHS = concatSubVectors(Ops[0].getOperand(0),
56376                                          Ops[0].getOperand(1), DAG, DL);
56377           SDValue RHS = concatSubVectors(Ops[1].getOperand(0),
56378                                          Ops[1].getOperand(1), DAG, DL);
56379           SDValue Res = DAG.getNode(X86ISD::SHUF128, DL, ShuffleVT,
56380                                     DAG.getBitcast(ShuffleVT, LHS),
56381                                     DAG.getBitcast(ShuffleVT, RHS),
56382                                     getV4X86ShuffleImm8ForMask(Mask, DL, DAG));
56383           return DAG.getBitcast(VT, Res);
56384         }
56385       }
56386       break;
56387     }
56388     case X86ISD::SHUF128: {
56389       if (!IsSplat && NumOps == 2 && VT.is512BitVector()) {
56390         unsigned Imm0 = Ops[0].getConstantOperandVal(2);
56391         unsigned Imm1 = Ops[1].getConstantOperandVal(2);
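              // Each 256-bit SHUF128 picks one 128-bit lane per source operand; after
              // concatenation those operands occupy lanes {0,1} and {2,3} of the wide
              // sources, so each 1-bit selector widens to 2 bits, with the second
              // result lane of each half biased by 2 (the 0x08 / 0x80 terms).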
56392         unsigned Imm = ((Imm0 & 1) << 0) | ((Imm0 & 2) << 1) | 0x08 |
56393                        ((Imm1 & 1) << 4) | ((Imm1 & 2) << 5) | 0x80;
56394         SDValue LHS = concatSubVectors(Ops[0].getOperand(0),
56395                                        Ops[0].getOperand(1), DAG, DL);
56396         SDValue RHS = concatSubVectors(Ops[1].getOperand(0),
56397                                        Ops[1].getOperand(1), DAG, DL);
56398         return DAG.getNode(X86ISD::SHUF128, DL, VT, LHS, RHS,
56399                            DAG.getTargetConstant(Imm, DL, MVT::i8));
56400       }
56401       break;
56402     }
56403     case ISD::TRUNCATE:
56404       if (!IsSplat && NumOps == 2 && VT.is256BitVector()) {
56405         EVT SrcVT = Ops[0].getOperand(0).getValueType();
56406         if (SrcVT.is256BitVector() && SrcVT.isSimple() &&
56407             SrcVT == Ops[1].getOperand(0).getValueType() &&
56408             Subtarget.useAVX512Regs() &&
56409             Subtarget.getPreferVectorWidth() >= 512 &&
56410             (SrcVT.getScalarSizeInBits() > 16 || Subtarget.useBWIRegs())) {
56411           EVT NewSrcVT = SrcVT.getDoubleNumVectorElementsVT(Ctx);
56412           return DAG.getNode(ISD::TRUNCATE, DL, VT,
56413                              ConcatSubOperand(NewSrcVT, Ops, 0));
56414         }
56415       }
56416       break;
56417     case X86ISD::VSHLI:
56418     case X86ISD::VSRLI:
56419       // Special case: SHL/SRL AVX1 V4i64 by 32-bits can lower as a shuffle.
56420       // TODO: Move this to LowerShiftByScalarImmediate?
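            // e.g. a left shift by 32 makes each i64 element [low32(x):0] and a
            // logical right shift makes it [0:high32(x)] (high:low dwords), which the
            // interleave-with-zero shuffle masks below reproduce on the v8i32 bitcast.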
56421       if (VT == MVT::v4i64 && !Subtarget.hasInt256() &&
56422           llvm::all_of(Ops, [](SDValue Op) {
56423             return Op.getConstantOperandAPInt(1) == 32;
56424           })) {
56425         SDValue Res = DAG.getBitcast(MVT::v8i32, ConcatSubOperand(VT, Ops, 0));
56426         SDValue Zero = getZeroVector(MVT::v8i32, Subtarget, DAG, DL);
56427         if (Op0.getOpcode() == X86ISD::VSHLI) {
56428           Res = DAG.getVectorShuffle(MVT::v8i32, DL, Res, Zero,
56429                                      {8, 0, 8, 2, 8, 4, 8, 6});
56430         } else {
56431           Res = DAG.getVectorShuffle(MVT::v8i32, DL, Res, Zero,
56432                                      {1, 8, 3, 8, 5, 8, 7, 8});
56433         }
56434         return DAG.getBitcast(VT, Res);
56435       }
56436       [[fallthrough]];
56437     case X86ISD::VSRAI:
56438     case X86ISD::VSHL:
56439     case X86ISD::VSRL:
56440     case X86ISD::VSRA:
56441       if (((VT.is256BitVector() && Subtarget.hasInt256()) ||
56442            (VT.is512BitVector() && Subtarget.useAVX512Regs() &&
56443             (EltSizeInBits >= 32 || Subtarget.useBWIRegs()))) &&
56444           llvm::all_of(Ops, [Op0](SDValue Op) {
56445             return Op0.getOperand(1) == Op.getOperand(1);
56446           })) {
56447         return DAG.getNode(Op0.getOpcode(), DL, VT,
56448                            ConcatSubOperand(VT, Ops, 0), Op0.getOperand(1));
56449       }
56450       break;
56451     case X86ISD::VPERMI:
56452     case X86ISD::VROTLI:
56453     case X86ISD::VROTRI:
56454       if (VT.is512BitVector() && Subtarget.useAVX512Regs() &&
56455           llvm::all_of(Ops, [Op0](SDValue Op) {
56456             return Op0.getOperand(1) == Op.getOperand(1);
56457           })) {
56458         return DAG.getNode(Op0.getOpcode(), DL, VT,
56459                            ConcatSubOperand(VT, Ops, 0), Op0.getOperand(1));
56460       }
56461       break;
56462     case ISD::AND:
56463     case ISD::OR:
56464     case ISD::XOR:
56465     case X86ISD::ANDNP:
56466       if (!IsSplat && ((VT.is256BitVector() && Subtarget.hasInt256()) ||
56467                        (VT.is512BitVector() && Subtarget.useAVX512Regs()))) {
56468         return DAG.getNode(Op0.getOpcode(), DL, VT,
56469                            ConcatSubOperand(VT, Ops, 0),
56470                            ConcatSubOperand(VT, Ops, 1));
56471       }
56472       break;
56473     case X86ISD::PCMPEQ:
56474     case X86ISD::PCMPGT:
56475       if (!IsSplat && VT.is256BitVector() &&
56476           (Subtarget.hasInt256() || VT == MVT::v8i32) &&
56477           (IsConcatFree(VT, Ops, 0) || IsConcatFree(VT, Ops, 1))) {
56478         if (Subtarget.hasInt256())
56479           return DAG.getNode(Op0.getOpcode(), DL, VT,
56480                              ConcatSubOperand(VT, Ops, 0),
56481                              ConcatSubOperand(VT, Ops, 1));
56482 
56483         // Without AVX2, see if we can cast the values to v8f32 and use fcmp.
56484         // TODO: Handle v4f64 as well?
56485         unsigned MaxSigBitsLHS = 0, MaxSigBitsRHS = 0;
56486         for (unsigned I = 0; I != NumOps; ++I) {
56487           MaxSigBitsLHS =
56488               std::max(MaxSigBitsLHS,
56489                        DAG.ComputeMaxSignificantBits(Ops[I].getOperand(0)));
56490           MaxSigBitsRHS =
56491               std::max(MaxSigBitsRHS,
56492                        DAG.ComputeMaxSignificantBits(Ops[I].getOperand(1)));
56493           if (MaxSigBitsLHS == EltSizeInBits && MaxSigBitsRHS == EltSizeInBits)
56494             break;
56495         }
56496 
56497         ISD::CondCode ICC =
56498             Op0.getOpcode() == X86ISD::PCMPEQ ? ISD::SETEQ : ISD::SETGT;
56499         ISD::CondCode FCC =
56500             Op0.getOpcode() == X86ISD::PCMPEQ ? ISD::SETOEQ : ISD::SETOGT;
56501 
56502         MVT FpSVT = MVT::getFloatingPointVT(EltSizeInBits);
56503         MVT FpVT = VT.changeVectorElementType(FpSVT);
56504 
56505         if (std::optional<unsigned> CastOpc =
56506                 CastIntSETCCtoFP(FpVT, ICC, MaxSigBitsLHS, MaxSigBitsRHS)) {
56507           SDValue LHS = ConcatSubOperand(VT, Ops, 0);
56508           SDValue RHS = ConcatSubOperand(VT, Ops, 1);
56509           LHS = DAG.getNode(*CastOpc, DL, FpVT, LHS);
56510           RHS = DAG.getNode(*CastOpc, DL, FpVT, RHS);
56511 
56512           bool IsAlwaysSignaling;
56513           unsigned FSETCC =
56514               translateX86FSETCC(FCC, LHS, RHS, IsAlwaysSignaling);
56515           return DAG.getBitcast(
56516               VT, DAG.getNode(X86ISD::CMPP, DL, FpVT, LHS, RHS,
56517                               DAG.getTargetConstant(FSETCC, DL, MVT::i8)));
56518         }
56519       }
56520       break;
56521     case ISD::CTPOP:
56522     case ISD::CTTZ:
56523     case ISD::CTLZ:
56524     case ISD::CTTZ_ZERO_UNDEF:
56525     case ISD::CTLZ_ZERO_UNDEF:
56526       if (!IsSplat && ((VT.is256BitVector() && Subtarget.hasInt256()) ||
56527                        (VT.is512BitVector() && Subtarget.useBWIRegs()))) {
56528         return DAG.getNode(Op0.getOpcode(), DL, VT,
56529                            ConcatSubOperand(VT, Ops, 0));
56530       }
56531       break;
56532     case X86ISD::GF2P8AFFINEQB:
56533       if (!IsSplat &&
56534           (VT.is256BitVector() ||
56535            (VT.is512BitVector() && Subtarget.useAVX512Regs())) &&
56536           llvm::all_of(Ops, [Op0](SDValue Op) {
56537             return Op0.getOperand(2) == Op.getOperand(2);
56538           })) {
56539         return DAG.getNode(Op0.getOpcode(), DL, VT,
56540                            ConcatSubOperand(VT, Ops, 0),
56541                            ConcatSubOperand(VT, Ops, 1), Op0.getOperand(2));
56542       }
56543       break;
56544     case ISD::ADD:
56545     case ISD::SUB:
56546     case ISD::MUL:
56547       if (!IsSplat && ((VT.is256BitVector() && Subtarget.hasInt256()) ||
56548                        (VT.is512BitVector() && Subtarget.useAVX512Regs() &&
56549                         (EltSizeInBits >= 32 || Subtarget.useBWIRegs())))) {
56550         return DAG.getNode(Op0.getOpcode(), DL, VT,
56551                            ConcatSubOperand(VT, Ops, 0),
56552                            ConcatSubOperand(VT, Ops, 1));
56553       }
56554       break;
56555     // Because VADD, VSUB and VMUL can execute on more ports than VINSERT and
56556     // their latencies are short, we don't replace them here unless doing so
56557     // won't introduce extra VINSERTs.
56558     case ISD::FADD:
56559     case ISD::FSUB:
56560     case ISD::FMUL:
56561       if (!IsSplat && (IsConcatFree(VT, Ops, 0) || IsConcatFree(VT, Ops, 1)) &&
56562           (VT.is256BitVector() ||
56563            (VT.is512BitVector() && Subtarget.useAVX512Regs()))) {
56564         return DAG.getNode(Op0.getOpcode(), DL, VT,
56565                            ConcatSubOperand(VT, Ops, 0),
56566                            ConcatSubOperand(VT, Ops, 1));
56567       }
56568       break;
56569     case ISD::FDIV:
56570       if (!IsSplat && (VT.is256BitVector() ||
56571                        (VT.is512BitVector() && Subtarget.useAVX512Regs()))) {
56572         return DAG.getNode(Op0.getOpcode(), DL, VT,
56573                            ConcatSubOperand(VT, Ops, 0),
56574                            ConcatSubOperand(VT, Ops, 1));
56575       }
56576       break;
56577     case X86ISD::HADD:
56578     case X86ISD::HSUB:
56579     case X86ISD::FHADD:
56580     case X86ISD::FHSUB:
56581       if (!IsSplat && VT.is256BitVector() &&
56582           (VT.isFloatingPoint() || Subtarget.hasInt256())) {
56583         return DAG.getNode(Op0.getOpcode(), DL, VT,
56584                            ConcatSubOperand(VT, Ops, 0),
56585                            ConcatSubOperand(VT, Ops, 1));
56586       }
56587       break;
56588     case X86ISD::PACKSS:
56589     case X86ISD::PACKUS:
56590       if (!IsSplat && ((VT.is256BitVector() && Subtarget.hasInt256()) ||
56591                        (VT.is512BitVector() && Subtarget.useBWIRegs()))) {
56592         MVT SrcVT = Op0.getOperand(0).getSimpleValueType();
56593         SrcVT = MVT::getVectorVT(SrcVT.getScalarType(),
56594                                  NumOps * SrcVT.getVectorNumElements());
56595         return DAG.getNode(Op0.getOpcode(), DL, VT,
56596                            ConcatSubOperand(SrcVT, Ops, 0),
56597                            ConcatSubOperand(SrcVT, Ops, 1));
56598       }
56599       break;
56600     case X86ISD::PALIGNR:
56601       if (!IsSplat &&
56602           ((VT.is256BitVector() && Subtarget.hasInt256()) ||
56603            (VT.is512BitVector() && Subtarget.useBWIRegs())) &&
56604           llvm::all_of(Ops, [Op0](SDValue Op) {
56605             return Op0.getOperand(2) == Op.getOperand(2);
56606           })) {
56607         return DAG.getNode(Op0.getOpcode(), DL, VT,
56608                            ConcatSubOperand(VT, Ops, 0),
56609                            ConcatSubOperand(VT, Ops, 1), Op0.getOperand(2));
56610       }
56611       break;
56612     case X86ISD::BLENDI:
56613       if (NumOps == 2 && VT.is512BitVector() && Subtarget.useBWIRegs()) {
56614         uint64_t Mask0 = Ops[0].getConstantOperandVal(2);
56615         uint64_t Mask1 = Ops[1].getConstantOperandVal(2);
56616         // MVT::v16i16 uses a blend mask that repeats per 128-bit lane.
56617         if (Op0.getSimpleValueType() == MVT::v16i16) {
56618           Mask0 = (Mask0 << 8) | Mask0;
56619           Mask1 = (Mask1 << 8) | Mask1;
56620         }
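              // e.g. (illustrative) two v8i32 blends with immediates 0xA5 and 0x3C
              // merge into the 16-bit mask 0x3CA5 that feeds a single v16i1 select.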
56621         uint64_t Mask = (Mask1 << (VT.getVectorNumElements() / 2)) | Mask0;
56622         MVT MaskSVT = MVT::getIntegerVT(VT.getVectorNumElements());
56623         MVT MaskVT = MVT::getVectorVT(MVT::i1, VT.getVectorNumElements());
56624         SDValue Sel =
56625             DAG.getBitcast(MaskVT, DAG.getConstant(Mask, DL, MaskSVT));
56626         return DAG.getSelect(DL, VT, Sel, ConcatSubOperand(VT, Ops, 1),
56627                              ConcatSubOperand(VT, Ops, 0));
56628       }
56629       break;
56630     case ISD::VSELECT:
56631       if (!IsSplat && Subtarget.hasAVX512() &&
56632           (VT.is256BitVector() ||
56633            (VT.is512BitVector() && Subtarget.useAVX512Regs())) &&
56634           (EltSizeInBits >= 32 || Subtarget.hasBWI())) {
56635         EVT SelVT = Ops[0].getOperand(0).getValueType();
56636         if (SelVT.getVectorElementType() == MVT::i1) {
56637           SelVT = EVT::getVectorVT(Ctx, MVT::i1,
56638                                    NumOps * SelVT.getVectorNumElements());
56639           if (TLI.isTypeLegal(SelVT))
56640             return DAG.getNode(Op0.getOpcode(), DL, VT,
56641                                ConcatSubOperand(SelVT.getSimpleVT(), Ops, 0),
56642                                ConcatSubOperand(VT, Ops, 1),
56643                                ConcatSubOperand(VT, Ops, 2));
56644         }
56645       }
56646       [[fallthrough]];
56647     case X86ISD::BLENDV:
56648       if (!IsSplat && VT.is256BitVector() && NumOps == 2 &&
56649           (EltSizeInBits >= 32 || Subtarget.hasInt256()) &&
56650           IsConcatFree(VT, Ops, 1) && IsConcatFree(VT, Ops, 2)) {
56651         EVT SelVT = Ops[0].getOperand(0).getValueType();
56652         SelVT = SelVT.getDoubleNumVectorElementsVT(Ctx);
56653         if (TLI.isTypeLegal(SelVT))
56654           return DAG.getNode(Op0.getOpcode(), DL, VT,
56655                              ConcatSubOperand(SelVT.getSimpleVT(), Ops, 0),
56656                              ConcatSubOperand(VT, Ops, 1),
56657                              ConcatSubOperand(VT, Ops, 2));
56658       }
56659       break;
56660     }
56661   }
56662 
56663   // Fold subvector loads into one.
56664   // If needed, look through bitcasts to get to the load.
56665   if (auto *FirstLd = dyn_cast<LoadSDNode>(peekThroughBitcasts(Op0))) {
56666     unsigned Fast;
56667     const X86TargetLowering *TLI = Subtarget.getTargetLowering();
56668     if (TLI->allowsMemoryAccess(Ctx, DAG.getDataLayout(), VT,
56669                                 *FirstLd->getMemOperand(), &Fast) &&
56670         Fast) {
56671       if (SDValue Ld =
56672               EltsFromConsecutiveLoads(VT, Ops, DL, DAG, Subtarget, false))
56673         return Ld;
56674     }
56675   }
56676 
56677   // Attempt to fold target constant loads.
56678   if (all_of(Ops, [](SDValue Op) { return getTargetConstantFromNode(Op); })) {
56679     SmallVector<APInt> EltBits;
56680     APInt UndefElts = APInt::getZero(VT.getVectorNumElements());
56681     for (unsigned I = 0; I != NumOps; ++I) {
56682       APInt OpUndefElts;
56683       SmallVector<APInt> OpEltBits;
56684       if (!getTargetConstantBitsFromNode(Ops[I], EltSizeInBits, OpUndefElts,
56685                                          OpEltBits, /*AllowWholeUndefs*/ true,
56686                                          /*AllowPartialUndefs*/ false))
56687         break;
56688       EltBits.append(OpEltBits);
56689       UndefElts.insertBits(OpUndefElts, I * OpUndefElts.getBitWidth());
56690     }
56691     if (EltBits.size() == VT.getVectorNumElements()) {
56692       Constant *C = getConstantVector(VT, EltBits, UndefElts, Ctx);
56693       MVT PVT = TLI.getPointerTy(DAG.getDataLayout());
56694       SDValue CV = DAG.getConstantPool(C, PVT);
56695       MachineFunction &MF = DAG.getMachineFunction();
56696       MachinePointerInfo MPI = MachinePointerInfo::getConstantPool(MF);
56697       SDValue Ld = DAG.getLoad(VT, DL, DAG.getEntryNode(), CV, MPI);
56698       SDValue Sub = extractSubVector(Ld, 0, DAG, DL, Op0.getValueSizeInBits());
56699       DAG.ReplaceAllUsesOfValueWith(Op0, Sub);
56700       return Ld;
56701     }
56702   }
56703 
56704   // If this simple subvector or scalar/subvector broadcast_load is inserted
56705   // into both halves, use a larger broadcast_load. Update other uses to use
56706   // an extracted subvector.
56707   if (IsSplat &&
56708       (VT.is256BitVector() || (VT.is512BitVector() && Subtarget.hasAVX512()))) {
56709     if (ISD::isNormalLoad(Op0.getNode()) ||
56710         Op0.getOpcode() == X86ISD::VBROADCAST_LOAD ||
56711         Op0.getOpcode() == X86ISD::SUBV_BROADCAST_LOAD) {
56712       auto *Mem = cast<MemSDNode>(Op0);
56713       unsigned Opc = Op0.getOpcode() == X86ISD::VBROADCAST_LOAD
56714                          ? X86ISD::VBROADCAST_LOAD
56715                          : X86ISD::SUBV_BROADCAST_LOAD;
56716       if (SDValue BcastLd =
56717               getBROADCAST_LOAD(Opc, DL, VT, Mem->getMemoryVT(), Mem, 0, DAG)) {
56718         SDValue BcastSrc =
56719             extractSubVector(BcastLd, 0, DAG, DL, Op0.getValueSizeInBits());
56720         DAG.ReplaceAllUsesOfValueWith(Op0, BcastSrc);
56721         return BcastLd;
56722       }
56723     }
56724   }
56725 
56726   // If we're splatting a 128-bit subvector to 512-bits, use SHUF128 directly.
56727   if (IsSplat && NumOps == 4 && VT.is512BitVector() &&
56728       Subtarget.useAVX512Regs()) {
56729     MVT ShuffleVT = VT.isFloatingPoint() ? MVT::v8f64 : MVT::v8i64;
56730     SDValue Res = widenSubVector(Op0, false, Subtarget, DAG, DL, 512);
56731     Res = DAG.getBitcast(ShuffleVT, Res);
56732     Res = DAG.getNode(X86ISD::SHUF128, DL, ShuffleVT, Res, Res,
56733                       getV4X86ShuffleImm8ForMask({0, 0, 0, 0}, DL, DAG));
56734     return DAG.getBitcast(VT, Res);
56735   }
56736 
56737   return SDValue();
56738 }
56739 
56740 static SDValue combineCONCAT_VECTORS(SDNode *N, SelectionDAG &DAG,
56741                                      TargetLowering::DAGCombinerInfo &DCI,
56742                                      const X86Subtarget &Subtarget) {
56743   EVT VT = N->getValueType(0);
56744   EVT SrcVT = N->getOperand(0).getValueType();
56745   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
56746   SmallVector<SDValue, 4> Ops(N->op_begin(), N->op_end());
56747 
56748   if (VT.getVectorElementType() == MVT::i1) {
56749     // Attempt to constant fold.
56750     unsigned SubSizeInBits = SrcVT.getSizeInBits();
56751     APInt Constant = APInt::getZero(VT.getSizeInBits());
56752     for (unsigned I = 0, E = Ops.size(); I != E; ++I) {
56753       auto *C = dyn_cast<ConstantSDNode>(peekThroughBitcasts(Ops[I]));
56754       if (!C) break;
56755       Constant.insertBits(C->getAPIntValue(), I * SubSizeInBits);
56756       if (I == (E - 1)) {
56757         EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), VT.getSizeInBits());
56758         if (TLI.isTypeLegal(IntVT))
56759           return DAG.getBitcast(VT, DAG.getConstant(Constant, SDLoc(N), IntVT));
56760       }
56761     }
56762 
56763     // Don't do anything else for i1 vectors.
56764     return SDValue();
56765   }
56766 
56767   if (Subtarget.hasAVX() && TLI.isTypeLegal(VT) && TLI.isTypeLegal(SrcVT)) {
56768     if (SDValue R = combineConcatVectorOps(SDLoc(N), VT.getSimpleVT(), Ops, DAG,
56769                                            DCI, Subtarget))
56770       return R;
56771   }
56772 
56773   return SDValue();
56774 }
56775 
56776 static SDValue combineINSERT_SUBVECTOR(SDNode *N, SelectionDAG &DAG,
56777                                        TargetLowering::DAGCombinerInfo &DCI,
56778                                        const X86Subtarget &Subtarget) {
56779   if (DCI.isBeforeLegalizeOps())
56780     return SDValue();
56781 
56782   MVT OpVT = N->getSimpleValueType(0);
56783 
56784   bool IsI1Vector = OpVT.getVectorElementType() == MVT::i1;
56785 
56786   SDLoc dl(N);
56787   SDValue Vec = N->getOperand(0);
56788   SDValue SubVec = N->getOperand(1);
56789 
56790   uint64_t IdxVal = N->getConstantOperandVal(2);
56791   MVT SubVecVT = SubVec.getSimpleValueType();
56792 
56793   if (Vec.isUndef() && SubVec.isUndef())
56794     return DAG.getUNDEF(OpVT);
56795 
56796   // Inserting undefs/zeros into zeros/undefs is a zero vector.
56797   if ((Vec.isUndef() || ISD::isBuildVectorAllZeros(Vec.getNode())) &&
56798       (SubVec.isUndef() || ISD::isBuildVectorAllZeros(SubVec.getNode())))
56799     return getZeroVector(OpVT, Subtarget, DAG, dl);
56800 
56801   if (ISD::isBuildVectorAllZeros(Vec.getNode())) {
56802     // If we're inserting into a zero vector and then into a larger zero vector,
56803     // just insert into the larger zero vector directly.
56804     if (SubVec.getOpcode() == ISD::INSERT_SUBVECTOR &&
56805         ISD::isBuildVectorAllZeros(SubVec.getOperand(0).getNode())) {
56806       uint64_t Idx2Val = SubVec.getConstantOperandVal(2);
56807       return DAG.getNode(ISD::INSERT_SUBVECTOR, dl, OpVT,
56808                          getZeroVector(OpVT, Subtarget, DAG, dl),
56809                          SubVec.getOperand(1),
56810                          DAG.getIntPtrConstant(IdxVal + Idx2Val, dl));
56811     }
56812 
56813     // If we're inserting into a zero vector and our input was extracted from an
56814     // insert into a zero vector of the same type, and the extraction was at
56815     // least as large as the original insertion, just insert the original
56816     // subvector into a zero vector.
56817     if (SubVec.getOpcode() == ISD::EXTRACT_SUBVECTOR && IdxVal == 0 &&
56818         isNullConstant(SubVec.getOperand(1)) &&
56819         SubVec.getOperand(0).getOpcode() == ISD::INSERT_SUBVECTOR) {
56820       SDValue Ins = SubVec.getOperand(0);
56821       if (isNullConstant(Ins.getOperand(2)) &&
56822           ISD::isBuildVectorAllZeros(Ins.getOperand(0).getNode()) &&
56823           Ins.getOperand(1).getValueSizeInBits().getFixedValue() <=
56824               SubVecVT.getFixedSizeInBits())
56825           return DAG.getNode(ISD::INSERT_SUBVECTOR, dl, OpVT,
56826                              getZeroVector(OpVT, Subtarget, DAG, dl),
56827                              Ins.getOperand(1), N->getOperand(2));
56828     }
56829   }
56830 
56831   // Stop here if this is an i1 vector.
56832   if (IsI1Vector)
56833     return SDValue();
56834 
56835   // Eliminate an intermediate vector widening:
56836   // insert_subvector X, (insert_subvector undef, Y, 0), Idx -->
56837   // insert_subvector X, Y, Idx
56838   // TODO: This is a more general version of a DAGCombiner fold, can we move it
56839   // there?
56840   if (SubVec.getOpcode() == ISD::INSERT_SUBVECTOR &&
56841       SubVec.getOperand(0).isUndef() && isNullConstant(SubVec.getOperand(2)))
56842     return DAG.getNode(ISD::INSERT_SUBVECTOR, dl, OpVT, Vec,
56843                        SubVec.getOperand(1), N->getOperand(2));
56844 
56845   // If this is an insert of an extract, combine to a shuffle. Don't do this
56846   // if the insert or extract can be represented with a subregister operation.
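        // e.g. (illustrative) inserting extract_subvector(W, 4) into v8i32 V at
        // index 0 becomes shuffle(V, W) with mask {12,13,14,15,4,5,6,7}.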
56847   if (SubVec.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
56848       SubVec.getOperand(0).getSimpleValueType() == OpVT &&
56849       (IdxVal != 0 ||
56850        !(Vec.isUndef() || ISD::isBuildVectorAllZeros(Vec.getNode())))) {
56851     int ExtIdxVal = SubVec.getConstantOperandVal(1);
56852     if (ExtIdxVal != 0) {
56853       int VecNumElts = OpVT.getVectorNumElements();
56854       int SubVecNumElts = SubVecVT.getVectorNumElements();
56855       SmallVector<int, 64> Mask(VecNumElts);
56856       // First create an identity shuffle mask.
56857       for (int i = 0; i != VecNumElts; ++i)
56858         Mask[i] = i;
56859       // Now insert the extracted portion.
56860       for (int i = 0; i != SubVecNumElts; ++i)
56861         Mask[i + IdxVal] = i + ExtIdxVal + VecNumElts;
56862 
56863       return DAG.getVectorShuffle(OpVT, dl, Vec, SubVec.getOperand(0), Mask);
56864     }
56865   }
56866 
56867   // Match concat_vector style patterns.
56868   SmallVector<SDValue, 2> SubVectorOps;
56869   if (collectConcatOps(N, SubVectorOps, DAG)) {
56870     if (SDValue Fold =
56871             combineConcatVectorOps(dl, OpVT, SubVectorOps, DAG, DCI, Subtarget))
56872       return Fold;
56873 
56874     // If we're inserting all zeros into the upper half, change this to
56875     // a concat with zero. We will match this to a move
56876     // with implicit upper bit zeroing during isel.
56877     // We do this here because we don't want combineConcatVectorOps to
56878     // create INSERT_SUBVECTOR from CONCAT_VECTORS.
56879     if (SubVectorOps.size() == 2 &&
56880         ISD::isBuildVectorAllZeros(SubVectorOps[1].getNode()))
56881       return DAG.getNode(ISD::INSERT_SUBVECTOR, dl, OpVT,
56882                          getZeroVector(OpVT, Subtarget, DAG, dl),
56883                          SubVectorOps[0], DAG.getIntPtrConstant(0, dl));
56884 
56885     // Attempt to recursively combine to a shuffle.
56886     if (all_of(SubVectorOps, [](SDValue SubOp) {
56887           return isTargetShuffle(SubOp.getOpcode());
56888         })) {
56889       SDValue Op(N, 0);
56890       if (SDValue Res = combineX86ShufflesRecursively(Op, DAG, Subtarget))
56891         return Res;
56892     }
56893   }
56894 
56895   // If this is a broadcast insert into an upper undef, use a larger broadcast.
56896   if (Vec.isUndef() && IdxVal != 0 && SubVec.getOpcode() == X86ISD::VBROADCAST)
56897     return DAG.getNode(X86ISD::VBROADCAST, dl, OpVT, SubVec.getOperand(0));
56898 
56899   // If this is a broadcast load inserted into an upper undef, use a larger
56900   // broadcast load.
56901   if (Vec.isUndef() && IdxVal != 0 && SubVec.hasOneUse() &&
56902       SubVec.getOpcode() == X86ISD::VBROADCAST_LOAD) {
56903     auto *MemIntr = cast<MemIntrinsicSDNode>(SubVec);
56904     SDVTList Tys = DAG.getVTList(OpVT, MVT::Other);
56905     SDValue Ops[] = { MemIntr->getChain(), MemIntr->getBasePtr() };
56906     SDValue BcastLd =
56907         DAG.getMemIntrinsicNode(X86ISD::VBROADCAST_LOAD, dl, Tys, Ops,
56908                                 MemIntr->getMemoryVT(),
56909                                 MemIntr->getMemOperand());
56910     DAG.ReplaceAllUsesOfValueWith(SDValue(MemIntr, 1), BcastLd.getValue(1));
56911     return BcastLd;
56912   }
56913 
56914   // If we're splatting the lower half subvector of a full vector load into the
56915   // upper half, attempt to create a subvector broadcast.
56916   if (IdxVal == (OpVT.getVectorNumElements() / 2) && SubVec.hasOneUse() &&
56917       Vec.getValueSizeInBits() == (2 * SubVec.getValueSizeInBits())) {
56918     auto *VecLd = dyn_cast<LoadSDNode>(Vec);
56919     auto *SubLd = dyn_cast<LoadSDNode>(SubVec);
56920     if (VecLd && SubLd &&
56921         DAG.areNonVolatileConsecutiveLoads(SubLd, VecLd,
56922                                            SubVec.getValueSizeInBits() / 8, 0))
56923       return getBROADCAST_LOAD(X86ISD::SUBV_BROADCAST_LOAD, dl, OpVT, SubVecVT,
56924                                SubLd, 0, DAG);
56925   }
56926 
56927   return SDValue();
56928 }
56929 
56930 /// If we are extracting a subvector of a vector select and the select condition
56931 /// is composed of concatenated vectors, try to narrow the select width. This
56932 /// is a common pattern for AVX1 integer code because 256-bit selects may be
56933 /// legal, but there is almost no integer math/logic available for 256-bit.
56934 /// This function should only be called with legal types (otherwise, the calls
56935 /// to get simple value types will assert).
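      /// e.g. (illustrative) extract_subvector(vselect(concat(C0,C1), T, F), 4) can
      /// narrow to vselect(C1, extract(T,4), extract(F,4)) on 128-bit types.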
56936 static SDValue narrowExtractedVectorSelect(SDNode *Ext, const SDLoc &DL,
56937                                            SelectionDAG &DAG) {
56938   SDValue Sel = Ext->getOperand(0);
56939   if (Sel.getOpcode() != ISD::VSELECT ||
56940       !isFreeToSplitVector(Sel.getOperand(0).getNode(), DAG))
56941     return SDValue();
56942 
56943   // Note: We assume simple value types because this should only be called with
56944   //       legal operations/types.
56945   // TODO: This can be extended to handle extraction to 256-bits.
56946   MVT VT = Ext->getSimpleValueType(0);
56947   if (!VT.is128BitVector())
56948     return SDValue();
56949 
56950   MVT SelCondVT = Sel.getOperand(0).getSimpleValueType();
56951   if (!SelCondVT.is256BitVector() && !SelCondVT.is512BitVector())
56952     return SDValue();
56953 
56954   MVT WideVT = Ext->getOperand(0).getSimpleValueType();
56955   MVT SelVT = Sel.getSimpleValueType();
56956   assert((SelVT.is256BitVector() || SelVT.is512BitVector()) &&
56957          "Unexpected vector type with legal operations");
56958 
56959   unsigned SelElts = SelVT.getVectorNumElements();
56960   unsigned CastedElts = WideVT.getVectorNumElements();
56961   unsigned ExtIdx = Ext->getConstantOperandVal(1);
56962   if (SelElts % CastedElts == 0) {
56963     // The select has the same or more (narrower) elements than the extract
56964     // operand. The extraction index gets scaled by that factor.
56965     ExtIdx *= (SelElts / CastedElts);
56966   } else if (CastedElts % SelElts == 0) {
56967     // The select has less (wider) elements than the extract operand. Make sure
56968     // that the extraction index can be divided evenly.
56969     unsigned IndexDivisor = CastedElts / SelElts;
56970     if (ExtIdx % IndexDivisor != 0)
56971       return SDValue();
56972     ExtIdx /= IndexDivisor;
56973   } else {
56974     llvm_unreachable("Element count of simple vector types are not divisible?");
56975   }
56976 
56977   unsigned NarrowingFactor = WideVT.getSizeInBits() / VT.getSizeInBits();
56978   unsigned NarrowElts = SelElts / NarrowingFactor;
56979   MVT NarrowSelVT = MVT::getVectorVT(SelVT.getVectorElementType(), NarrowElts);
56980   SDValue ExtCond = extract128BitVector(Sel.getOperand(0), ExtIdx, DAG, DL);
56981   SDValue ExtT = extract128BitVector(Sel.getOperand(1), ExtIdx, DAG, DL);
56982   SDValue ExtF = extract128BitVector(Sel.getOperand(2), ExtIdx, DAG, DL);
56983   SDValue NarrowSel = DAG.getSelect(DL, NarrowSelVT, ExtCond, ExtT, ExtF);
56984   return DAG.getBitcast(VT, NarrowSel);
56985 }
56986 
56987 static SDValue combineEXTRACT_SUBVECTOR(SDNode *N, SelectionDAG &DAG,
56988                                         TargetLowering::DAGCombinerInfo &DCI,
56989                                         const X86Subtarget &Subtarget) {
56990   // For AVX1 only, if we are extracting from a 256-bit and+not (which will
56991   // eventually get combined/lowered into ANDNP) with a concatenated operand,
56992   // split the 'and' into 128-bit ops to avoid the concatenate and extract.
56993   // We let generic combining take over from there to simplify the
56994   // insert/extract and 'not'.
56995   // This pattern emerges during AVX1 legalization. We handle it before lowering
56996   // to avoid complications like splitting constant vector loads.
56997 
56998   // Capture the original wide type in the likely case that we need to bitcast
56999   // back to this type.
57000   if (!N->getValueType(0).isSimple())
57001     return SDValue();
57002 
57003   MVT VT = N->getSimpleValueType(0);
57004   SDValue InVec = N->getOperand(0);
57005   unsigned IdxVal = N->getConstantOperandVal(1);
57006   SDValue InVecBC = peekThroughBitcasts(InVec);
57007   EVT InVecVT = InVec.getValueType();
57008   unsigned SizeInBits = VT.getSizeInBits();
57009   unsigned InSizeInBits = InVecVT.getSizeInBits();
57010   unsigned NumSubElts = VT.getVectorNumElements();
57011   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
57012   SDLoc DL(N);
57013 
57014   if (Subtarget.hasAVX() && !Subtarget.hasAVX2() &&
57015       TLI.isTypeLegal(InVecVT) &&
57016       InSizeInBits == 256 && InVecBC.getOpcode() == ISD::AND) {
57017     auto isConcatenatedNot = [](SDValue V) {
57018       V = peekThroughBitcasts(V);
57019       if (!isBitwiseNot(V))
57020         return false;
57021       SDValue NotOp = V->getOperand(0);
57022       return peekThroughBitcasts(NotOp).getOpcode() == ISD::CONCAT_VECTORS;
57023     };
57024     if (isConcatenatedNot(InVecBC.getOperand(0)) ||
57025         isConcatenatedNot(InVecBC.getOperand(1))) {
57026       // extract (and v4i64 X, (not (concat Y1, Y2))), n -> andnp v2i64 X(n), Y1
57027       SDValue Concat = splitVectorIntBinary(InVecBC, DAG, SDLoc(InVecBC));
57028       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT,
57029                          DAG.getBitcast(InVecVT, Concat), N->getOperand(1));
57030     }
57031   }
57032 
57033   if (DCI.isBeforeLegalizeOps())
57034     return SDValue();
57035 
57036   if (SDValue V = narrowExtractedVectorSelect(N, DL, DAG))
57037     return V;
57038 
57039   if (ISD::isBuildVectorAllZeros(InVec.getNode()))
57040     return getZeroVector(VT, Subtarget, DAG, DL);
57041 
57042   if (ISD::isBuildVectorAllOnes(InVec.getNode())) {
57043     if (VT.getScalarType() == MVT::i1)
57044       return DAG.getConstant(1, DL, VT);
57045     return getOnesVector(VT, DAG, DL);
57046   }
57047 
57048   if (InVec.getOpcode() == ISD::BUILD_VECTOR)
57049     return DAG.getBuildVector(VT, DL, InVec->ops().slice(IdxVal, NumSubElts));
57050 
57051   // If we are extracting from an insert into a larger vector, replace with a
57052   // smaller insert if we don't access less than the original subvector. Don't
57053   // do this for i1 vectors.
57054   // TODO: Relax the matching indices requirement?
57055   if (VT.getVectorElementType() != MVT::i1 &&
57056       InVec.getOpcode() == ISD::INSERT_SUBVECTOR && InVec.hasOneUse() &&
57057       IdxVal == InVec.getConstantOperandVal(2) &&
57058       InVec.getOperand(1).getValueSizeInBits() <= SizeInBits) {
57059     SDValue NewExt = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT,
57060                                  InVec.getOperand(0), N->getOperand(1));
57061     unsigned NewIdxVal = InVec.getConstantOperandVal(2) - IdxVal;
57062     return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, NewExt,
57063                        InVec.getOperand(1),
57064                        DAG.getVectorIdxConstant(NewIdxVal, DL));
57065   }
57066 
57067   // If we're extracting an upper subvector from a broadcast, we should just
57068   // extract the lowest subvector instead, which should allow
57069   // SimplifyDemandedVectorElts to do more simplifications.
57070   if (IdxVal != 0 && (InVec.getOpcode() == X86ISD::VBROADCAST ||
57071                       InVec.getOpcode() == X86ISD::VBROADCAST_LOAD ||
57072                       DAG.isSplatValue(InVec, /*AllowUndefs*/ false)))
57073     return extractSubVector(InVec, 0, DAG, DL, SizeInBits);
57074 
57075   // If we're extracting a broadcasted subvector, just use the lowest subvector.
57076   if (IdxVal != 0 && InVec.getOpcode() == X86ISD::SUBV_BROADCAST_LOAD &&
57077       cast<MemIntrinsicSDNode>(InVec)->getMemoryVT() == VT)
57078     return extractSubVector(InVec, 0, DAG, DL, SizeInBits);
57079 
57080   // Attempt to extract from the source of a shuffle vector.
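  // For example (illustrative): when extracting the upper v4i32 half
  // (IdxVal == 4) of a 256-bit shuffle whose mask, scaled to 128-bit
  // subvectors, is <1, 2>, the requested half is exactly the low half of the
  // second shuffle input, so we extract that subvector directly instead of
  // going through the wide shuffle.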
57081   if ((InSizeInBits % SizeInBits) == 0 && (IdxVal % NumSubElts) == 0) {
57082     SmallVector<int, 32> ShuffleMask;
57083     SmallVector<int, 32> ScaledMask;
57084     SmallVector<SDValue, 2> ShuffleInputs;
57085     unsigned NumSubVecs = InSizeInBits / SizeInBits;
57086     // Decode the shuffle mask and scale it so it's shuffling subvectors.
57087     if (getTargetShuffleInputs(InVecBC, ShuffleInputs, ShuffleMask, DAG) &&
57088         scaleShuffleElements(ShuffleMask, NumSubVecs, ScaledMask)) {
57089       unsigned SubVecIdx = IdxVal / NumSubElts;
57090       if (ScaledMask[SubVecIdx] == SM_SentinelUndef)
57091         return DAG.getUNDEF(VT);
57092       if (ScaledMask[SubVecIdx] == SM_SentinelZero)
57093         return getZeroVector(VT, Subtarget, DAG, DL);
57094       SDValue Src = ShuffleInputs[ScaledMask[SubVecIdx] / NumSubVecs];
57095       if (Src.getValueSizeInBits() == InSizeInBits) {
57096         unsigned SrcSubVecIdx = ScaledMask[SubVecIdx] % NumSubVecs;
57097         unsigned SrcEltIdx = SrcSubVecIdx * NumSubElts;
57098         return extractSubVector(DAG.getBitcast(InVecVT, Src), SrcEltIdx, DAG,
57099                                 DL, SizeInBits);
57100       }
57101     }
57102   }
57103 
57104   auto IsExtractFree = [](SDValue V) {
57105     V = peekThroughBitcasts(V);
57106     if (ISD::isBuildVectorOfConstantSDNodes(V.getNode()))
57107       return true;
57108     if (ISD::isBuildVectorOfConstantFPSDNodes(V.getNode()))
57109       return true;
57110     return V.isUndef();
57111   };
57112 
57113   // If we're extracting the lowest subvector and we're the only user,
57114   // we may be able to perform this with a smaller vector width.
57115   unsigned InOpcode = InVec.getOpcode();
57116   if (InVec.hasOneUse()) {
57117     if (IdxVal == 0 && VT == MVT::v2f64 && InVecVT == MVT::v4f64) {
57118       // v2f64 CVTDQ2PD(v4i32).
57119       if (InOpcode == ISD::SINT_TO_FP &&
57120           InVec.getOperand(0).getValueType() == MVT::v4i32) {
57121         return DAG.getNode(X86ISD::CVTSI2P, DL, VT, InVec.getOperand(0));
57122       }
57123       // v2f64 CVTUDQ2PD(v4i32).
57124       if (InOpcode == ISD::UINT_TO_FP && Subtarget.hasVLX() &&
57125           InVec.getOperand(0).getValueType() == MVT::v4i32) {
57126         return DAG.getNode(X86ISD::CVTUI2P, DL, VT, InVec.getOperand(0));
57127       }
57128       // v2f64 CVTPS2PD(v4f32).
57129       if (InOpcode == ISD::FP_EXTEND &&
57130           InVec.getOperand(0).getValueType() == MVT::v4f32) {
57131         return DAG.getNode(X86ISD::VFPEXT, DL, VT, InVec.getOperand(0));
57132       }
57133     }
57134     // v4i32 CVTPS2DQ(v4f32).
57135     if (InOpcode == ISD::FP_TO_SINT && VT == MVT::v4i32) {
57136       SDValue Src = InVec.getOperand(0);
57137       if (Src.getValueType().getScalarType() == MVT::f32)
57138         return DAG.getNode(InOpcode, DL, VT,
57139                            extractSubVector(Src, IdxVal, DAG, DL, SizeInBits));
57140     }
57141     if (IdxVal == 0 &&
57142         (ISD::isExtOpcode(InOpcode) || ISD::isExtVecInRegOpcode(InOpcode)) &&
57143         (SizeInBits == 128 || SizeInBits == 256) &&
57144         InVec.getOperand(0).getValueSizeInBits() >= SizeInBits) {
57145       SDValue Ext = InVec.getOperand(0);
57146       if (Ext.getValueSizeInBits() > SizeInBits)
57147         Ext = extractSubVector(Ext, 0, DAG, DL, SizeInBits);
57148       unsigned ExtOp = DAG.getOpcode_EXTEND_VECTOR_INREG(InOpcode);
57149       return DAG.getNode(ExtOp, DL, VT, Ext);
57150     }
57151     if (IdxVal == 0 && InOpcode == ISD::VSELECT &&
57152         InVec.getOperand(0).getValueType().is256BitVector() &&
57153         InVec.getOperand(1).getValueType().is256BitVector() &&
57154         InVec.getOperand(2).getValueType().is256BitVector()) {
57155       SDValue Ext0 = extractSubVector(InVec.getOperand(0), 0, DAG, DL, 128);
57156       SDValue Ext1 = extractSubVector(InVec.getOperand(1), 0, DAG, DL, 128);
57157       SDValue Ext2 = extractSubVector(InVec.getOperand(2), 0, DAG, DL, 128);
57158       return DAG.getNode(InOpcode, DL, VT, Ext0, Ext1, Ext2);
57159     }
57160     if (IdxVal == 0 && InOpcode == ISD::TRUNCATE && Subtarget.hasVLX() &&
57161         (SizeInBits == 128 || SizeInBits == 256)) {
57162       SDValue InVecSrc = InVec.getOperand(0);
57163       unsigned Scale = InVecSrc.getValueSizeInBits() / InSizeInBits;
57164       SDValue Ext = extractSubVector(InVecSrc, 0, DAG, DL, Scale * SizeInBits);
57165       return DAG.getNode(InOpcode, DL, VT, Ext);
57166     }
57167     if ((InOpcode == X86ISD::CMPP || InOpcode == X86ISD::PCMPEQ ||
57168          InOpcode == X86ISD::PCMPGT) &&
57169         (IsExtractFree(InVec.getOperand(0)) ||
57170          IsExtractFree(InVec.getOperand(1))) &&
57171         SizeInBits == 128) {
57172       SDValue Ext0 =
57173           extractSubVector(InVec.getOperand(0), IdxVal, DAG, DL, SizeInBits);
57174       SDValue Ext1 =
57175           extractSubVector(InVec.getOperand(1), IdxVal, DAG, DL, SizeInBits);
57176       if (InOpcode == X86ISD::CMPP)
57177         return DAG.getNode(InOpcode, DL, VT, Ext0, Ext1, InVec.getOperand(2));
57178       return DAG.getNode(InOpcode, DL, VT, Ext0, Ext1);
57179     }
57180     if (InOpcode == X86ISD::MOVDDUP &&
57181         (SizeInBits == 128 || SizeInBits == 256)) {
57182       SDValue Ext0 =
57183           extractSubVector(InVec.getOperand(0), IdxVal, DAG, DL, SizeInBits);
57184       return DAG.getNode(InOpcode, DL, VT, Ext0);
57185     }
57186   }
57187 
57188   // Always split vXi64 logical shifts where we're extracting the upper 32 bits,
57189   // as this is very likely to fold into a shuffle/truncation.
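  // For example (illustrative): extracting the upper v2i64 half of
  // (v4i64 VSRLI X, 32) becomes (v2i64 VSRLI (extract_subvector X, 2), 32),
  // which a later truncation/shuffle of the 32-bit halves can then absorb.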
57190   if ((InOpcode == X86ISD::VSHLI || InOpcode == X86ISD::VSRLI) &&
57191       InVecVT.getScalarSizeInBits() == 64 &&
57192       InVec.getConstantOperandAPInt(1) == 32) {
57193     SDValue Ext =
57194         extractSubVector(InVec.getOperand(0), IdxVal, DAG, DL, SizeInBits);
57195     return DAG.getNode(InOpcode, DL, VT, Ext, InVec.getOperand(1));
57196   }
57197 
57198   return SDValue();
57199 }
57200 
57201 static SDValue combineScalarToVector(SDNode *N, SelectionDAG &DAG) {
57202   EVT VT = N->getValueType(0);
57203   SDValue Src = N->getOperand(0);
57204   SDLoc DL(N);
57205 
57206   // If this is a scalar_to_vector to v1i1 from an AND with 1, bypass the AND.
57207   // This occurs frequently in our masked scalar intrinsic code and our
57208   // floating point select lowering with AVX512.
57209   // TODO: SimplifyDemandedBits instead?
57210   if (VT == MVT::v1i1 && Src.getOpcode() == ISD::AND && Src.hasOneUse() &&
57211       isOneConstant(Src.getOperand(1)))
57212     return DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v1i1, Src.getOperand(0));
57213 
57214   // Combine scalar_to_vector of an extract_vector_elt into an extract_subvec.
57215   if (VT == MVT::v1i1 && Src.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
57216       Src.hasOneUse() && Src.getOperand(0).getValueType().isVector() &&
57217       Src.getOperand(0).getValueType().getVectorElementType() == MVT::i1 &&
57218       isNullConstant(Src.getOperand(1)))
57219     return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Src.getOperand(0),
57220                        Src.getOperand(1));
57221 
57222   // Reduce v2i64 to v4i32 if we don't need the upper bits or they're known zero.
57223   // TODO: Move to DAGCombine/SimplifyDemandedBits?
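  // For instance (illustrative): (v2i64 scalar_to_vector (i64 zext X:i32))
  // becomes a v4i32 scalar_to_vector of X wrapped in VZEXT_MOVL and bitcast
  // back to v2i64, so only a 32-bit value has to be moved into the vector
  // register.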
57224   if ((VT == MVT::v2i64 || VT == MVT::v2f64) && Src.hasOneUse()) {
57225     auto IsExt64 = [&DAG](SDValue Op, bool IsZeroExt) {
57226       if (Op.getValueType() != MVT::i64)
57227         return SDValue();
57228       unsigned Opc = IsZeroExt ? ISD::ZERO_EXTEND : ISD::ANY_EXTEND;
57229       if (Op.getOpcode() == Opc &&
57230           Op.getOperand(0).getScalarValueSizeInBits() <= 32)
57231         return Op.getOperand(0);
57232       unsigned Ext = IsZeroExt ? ISD::ZEXTLOAD : ISD::EXTLOAD;
57233       if (auto *Ld = dyn_cast<LoadSDNode>(Op))
57234         if (Ld->getExtensionType() == Ext &&
57235             Ld->getMemoryVT().getScalarSizeInBits() <= 32)
57236           return Op;
57237       if (IsZeroExt) {
57238         KnownBits Known = DAG.computeKnownBits(Op);
57239         if (!Known.isConstant() && Known.countMinLeadingZeros() >= 32)
57240           return Op;
57241       }
57242       return SDValue();
57243     };
57244 
57245     if (SDValue AnyExt = IsExt64(peekThroughOneUseBitcasts(Src), false))
57246       return DAG.getBitcast(
57247           VT, DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v4i32,
57248                           DAG.getAnyExtOrTrunc(AnyExt, DL, MVT::i32)));
57249 
57250     if (SDValue ZeroExt = IsExt64(peekThroughOneUseBitcasts(Src), true))
57251       return DAG.getBitcast(
57252           VT,
57253           DAG.getNode(X86ISD::VZEXT_MOVL, DL, MVT::v4i32,
57254                       DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v4i32,
57255                                   DAG.getZExtOrTrunc(ZeroExt, DL, MVT::i32))));
57256   }
57257 
57258   // Combine (v2i64 (scalar_to_vector (i64 (bitconvert (mmx))))) to MOVQ2DQ.
57259   if (VT == MVT::v2i64 && Src.getOpcode() == ISD::BITCAST &&
57260       Src.getOperand(0).getValueType() == MVT::x86mmx)
57261     return DAG.getNode(X86ISD::MOVQ2DQ, DL, VT, Src.getOperand(0));
57262 
57263   // See if we're broadcasting the scalar value, in which case just reuse that.
57264   // Ensure the broadcast is using this exact SDValue from the source node.
57265   if (VT.getScalarType() == Src.getValueType())
57266     for (SDNode *User : Src->uses())
57267       if (User->getOpcode() == X86ISD::VBROADCAST &&
57268           Src == User->getOperand(0)) {
57269         unsigned SizeInBits = VT.getFixedSizeInBits();
57270         unsigned BroadcastSizeInBits =
57271             User->getValueSizeInBits(0).getFixedValue();
57272         if (BroadcastSizeInBits == SizeInBits)
57273           return SDValue(User, 0);
57274         if (BroadcastSizeInBits > SizeInBits)
57275           return extractSubVector(SDValue(User, 0), 0, DAG, DL, SizeInBits);
57276         // TODO: Handle BroadcastSizeInBits < SizeInBits when we have test
57277         // coverage.
57278       }
57279 
57280   return SDValue();
57281 }
57282 
57283 // Simplify PMULDQ and PMULUDQ operations.
57284 static SDValue combinePMULDQ(SDNode *N, SelectionDAG &DAG,
57285                              TargetLowering::DAGCombinerInfo &DCI,
57286                              const X86Subtarget &Subtarget) {
57287   SDValue LHS = N->getOperand(0);
57288   SDValue RHS = N->getOperand(1);
57289 
57290   // Canonicalize constant to RHS.
57291   if (DAG.isConstantIntBuildVectorOrConstantInt(LHS) &&
57292       !DAG.isConstantIntBuildVectorOrConstantInt(RHS))
57293     return DAG.getNode(N->getOpcode(), SDLoc(N), N->getValueType(0), RHS, LHS);
57294 
57295   // Multiply by zero.
57296   // Don't return RHS as it may contain UNDEFs.
57297   if (ISD::isBuildVectorAllZeros(RHS.getNode()))
57298     return DAG.getConstant(0, SDLoc(N), N->getValueType(0));
57299 
57300   // PMULDQ/PMULUDQ only use the lower 32 bits from each vector element.
57301   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
57302   if (TLI.SimplifyDemandedBits(SDValue(N, 0), APInt::getAllOnes(64), DCI))
57303     return SDValue(N, 0);
57304 
57305   // If the input is an extend_invec and the SimplifyDemandedBits call didn't
57306   // convert it to any_extend_invec (due to the LegalOperations check), do the
57307   // conversion to a vector shuffle manually here. This exposes combine
57308   // opportunities missed by combineEXTEND_VECTOR_INREG not calling
57309   // combineX86ShufflesRecursively on SSE4.1 targets.
57310   // FIXME: This is basically a hack around several other issues related to
57311   // ANY_EXTEND_VECTOR_INREG.
57312   if (N->getValueType(0) == MVT::v2i64 && LHS.hasOneUse() &&
57313       (LHS.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG ||
57314        LHS.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG) &&
57315       LHS.getOperand(0).getValueType() == MVT::v4i32) {
57316     SDLoc dl(N);
57317     LHS = DAG.getVectorShuffle(MVT::v4i32, dl, LHS.getOperand(0),
57318                                LHS.getOperand(0), { 0, -1, 1, -1 });
57319     LHS = DAG.getBitcast(MVT::v2i64, LHS);
57320     return DAG.getNode(N->getOpcode(), dl, MVT::v2i64, LHS, RHS);
57321   }
57322   if (N->getValueType(0) == MVT::v2i64 && RHS.hasOneUse() &&
57323       (RHS.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG ||
57324        RHS.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG) &&
57325       RHS.getOperand(0).getValueType() == MVT::v4i32) {
57326     SDLoc dl(N);
57327     RHS = DAG.getVectorShuffle(MVT::v4i32, dl, RHS.getOperand(0),
57328                                RHS.getOperand(0), { 0, -1, 1, -1 });
57329     RHS = DAG.getBitcast(MVT::v2i64, RHS);
57330     return DAG.getNode(N->getOpcode(), dl, MVT::v2i64, LHS, RHS);
57331   }
57332 
57333   return SDValue();
57334 }
57335 
57336 // Simplify VPMADDUBSW/VPMADDWD operations.
57337 static SDValue combineVPMADD(SDNode *N, SelectionDAG &DAG,
57338                              TargetLowering::DAGCombinerInfo &DCI) {
57339   MVT VT = N->getSimpleValueType(0);
57340   SDValue LHS = N->getOperand(0);
57341   SDValue RHS = N->getOperand(1);
57342   unsigned Opc = N->getOpcode();
57343   bool IsPMADDWD = Opc == X86ISD::VPMADDWD;
57344   assert((Opc == X86ISD::VPMADDWD || Opc == X86ISD::VPMADDUBSW) &&
57345          "Unexpected PMADD opcode");
57346 
57347   // Multiply by zero.
57348   // Don't return LHS/RHS as it may contain UNDEFs.
57349   if (ISD::isBuildVectorAllZeros(LHS.getNode()) ||
57350       ISD::isBuildVectorAllZeros(RHS.getNode()))
57351     return DAG.getConstant(0, SDLoc(N), VT);
57352 
57353   // Constant folding.
57354   APInt LHSUndefs, RHSUndefs;
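  // Worked example (illustrative): for VPMADDWD with i16 lanes
  // LHS = <1, 2, 3, 4> and RHS = <10, 20, 30, 40>, each i32 result lane is the
  // sum of two sign-extended products: <1*10 + 2*20, 3*30 + 4*40> = <50, 250>.
  // VPMADDUBSW instead zero-extends the LHS bytes, sign-extends the RHS bytes
  // and combines each pair with a saturating signed add.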
57355   SmallVector<APInt> LHSBits, RHSBits;
57356   unsigned SrcEltBits = LHS.getScalarValueSizeInBits();
57357   unsigned DstEltBits = VT.getScalarSizeInBits();
57358   if (getTargetConstantBitsFromNode(LHS, SrcEltBits, LHSUndefs, LHSBits) &&
57359       getTargetConstantBitsFromNode(RHS, SrcEltBits, RHSUndefs, RHSBits)) {
57360     SmallVector<APInt> Result;
57361     for (unsigned I = 0, E = LHSBits.size(); I != E; I += 2) {
57362       APInt LHSLo = LHSBits[I + 0], LHSHi = LHSBits[I + 1];
57363       APInt RHSLo = RHSBits[I + 0], RHSHi = RHSBits[I + 1];
57364       LHSLo = IsPMADDWD ? LHSLo.sext(DstEltBits) : LHSLo.zext(DstEltBits);
57365       LHSHi = IsPMADDWD ? LHSHi.sext(DstEltBits) : LHSHi.zext(DstEltBits);
57366       APInt Lo = LHSLo * RHSLo.sext(DstEltBits);
57367       APInt Hi = LHSHi * RHSHi.sext(DstEltBits);
57368       APInt Res = IsPMADDWD ? (Lo + Hi) : Lo.sadd_sat(Hi);
57369       Result.push_back(Res);
57370     }
57371     return getConstVector(Result, VT, DAG, SDLoc(N));
57372   }
57373 
57374   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
57375   APInt DemandedElts = APInt::getAllOnes(VT.getVectorNumElements());
57376   if (TLI.SimplifyDemandedVectorElts(SDValue(N, 0), DemandedElts, DCI))
57377     return SDValue(N, 0);
57378 
57379   return SDValue();
57380 }
57381 
57382 static SDValue combineEXTEND_VECTOR_INREG(SDNode *N, SelectionDAG &DAG,
57383                                           TargetLowering::DAGCombinerInfo &DCI,
57384                                           const X86Subtarget &Subtarget) {
57385   EVT VT = N->getValueType(0);
57386   SDValue In = N->getOperand(0);
57387   unsigned Opcode = N->getOpcode();
57388   unsigned InOpcode = In.getOpcode();
57389   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
57390   SDLoc DL(N);
57391 
57392   // Try to merge vector loads and extend_inreg to an extload.
57393   if (!DCI.isBeforeLegalizeOps() && ISD::isNormalLoad(In.getNode()) &&
57394       In.hasOneUse()) {
57395     auto *Ld = cast<LoadSDNode>(In);
57396     if (Ld->isSimple()) {
57397       MVT SVT = In.getSimpleValueType().getVectorElementType();
57398       ISD::LoadExtType Ext = Opcode == ISD::SIGN_EXTEND_VECTOR_INREG
57399                                  ? ISD::SEXTLOAD
57400                                  : ISD::ZEXTLOAD;
57401       EVT MemVT = VT.changeVectorElementType(SVT);
57402       if (TLI.isLoadExtLegal(Ext, VT, MemVT)) {
57403         SDValue Load = DAG.getExtLoad(
57404             Ext, DL, VT, Ld->getChain(), Ld->getBasePtr(), Ld->getPointerInfo(),
57405             MemVT, Ld->getOriginalAlign(), Ld->getMemOperand()->getFlags());
57406         DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), Load.getValue(1));
57407         return Load;
57408       }
57409     }
57410   }
57411 
57412   // Fold EXTEND_VECTOR_INREG(EXTEND_VECTOR_INREG(X)) -> EXTEND_VECTOR_INREG(X).
57413   if (Opcode == InOpcode)
57414     return DAG.getNode(Opcode, DL, VT, In.getOperand(0));
57415 
57416   // Fold EXTEND_VECTOR_INREG(EXTRACT_SUBVECTOR(EXTEND(X),0))
57417   // -> EXTEND_VECTOR_INREG(X).
57418   // TODO: Handle non-zero subvector indices.
57419   if (InOpcode == ISD::EXTRACT_SUBVECTOR && In.getConstantOperandVal(1) == 0 &&
57420       In.getOperand(0).getOpcode() == DAG.getOpcode_EXTEND(Opcode) &&
57421       In.getOperand(0).getOperand(0).getValueSizeInBits() ==
57422           In.getValueSizeInBits())
57423     return DAG.getNode(Opcode, DL, VT, In.getOperand(0).getOperand(0));
57424 
57425   // Fold EXTEND_VECTOR_INREG(BUILD_VECTOR(X,Y,?,?)) -> BUILD_VECTOR(X,0,Y,0).
57426   // TODO: Move to DAGCombine?
57427   if (!DCI.isBeforeLegalizeOps() && Opcode == ISD::ZERO_EXTEND_VECTOR_INREG &&
57428       In.getOpcode() == ISD::BUILD_VECTOR && In.hasOneUse() &&
57429       In.getValueSizeInBits() == VT.getSizeInBits()) {
57430     unsigned NumElts = VT.getVectorNumElements();
57431     unsigned Scale = VT.getScalarSizeInBits() / In.getScalarValueSizeInBits();
57432     EVT EltVT = In.getOperand(0).getValueType();
57433     SmallVector<SDValue> Elts(Scale * NumElts, DAG.getConstant(0, DL, EltVT));
57434     for (unsigned I = 0; I != NumElts; ++I)
57435       Elts[I * Scale] = In.getOperand(I);
57436     return DAG.getBitcast(VT, DAG.getBuildVector(In.getValueType(), DL, Elts));
57437   }
57438 
57439   // Attempt to combine as a shuffle on SSE41+ targets.
57440   if (Subtarget.hasSSE41()) {
57441     SDValue Op(N, 0);
57442     if (TLI.isTypeLegal(VT) && TLI.isTypeLegal(In.getValueType()))
57443       if (SDValue Res = combineX86ShufflesRecursively(Op, DAG, Subtarget))
57444         return Res;
57445   }
57446 
57447   return SDValue();
57448 }
57449 
57450 static SDValue combineKSHIFT(SDNode *N, SelectionDAG &DAG,
57451                              TargetLowering::DAGCombinerInfo &DCI) {
57452   EVT VT = N->getValueType(0);
57453 
57454   if (ISD::isBuildVectorAllZeros(N->getOperand(0).getNode()))
57455     return DAG.getConstant(0, SDLoc(N), VT);
57456 
57457   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
57458   APInt DemandedElts = APInt::getAllOnes(VT.getVectorNumElements());
57459   if (TLI.SimplifyDemandedVectorElts(SDValue(N, 0), DemandedElts, DCI))
57460     return SDValue(N, 0);
57461 
57462   return SDValue();
57463 }
57464 
57465 // Optimize (fp16_to_fp (fp_to_fp16 X)) to VCVTPS2PH followed by VCVTPH2PS.
57466 // Done as a combine because the lowering for fp16_to_fp and fp_to_fp16 produces
57467 // extra instructions between the conversions due to going to scalar and back.
57468 static SDValue combineFP16_TO_FP(SDNode *N, SelectionDAG &DAG,
57469                                  const X86Subtarget &Subtarget) {
57470   if (Subtarget.useSoftFloat() || !Subtarget.hasF16C())
57471     return SDValue();
57472 
57473   if (N->getOperand(0).getOpcode() != ISD::FP_TO_FP16)
57474     return SDValue();
57475 
57476   if (N->getValueType(0) != MVT::f32 ||
57477       N->getOperand(0).getOperand(0).getValueType() != MVT::f32)
57478     return SDValue();
57479 
57480   SDLoc dl(N);
57481   SDValue Res = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, MVT::v4f32,
57482                             N->getOperand(0).getOperand(0));
57483   Res = DAG.getNode(X86ISD::CVTPS2PH, dl, MVT::v8i16, Res,
57484                     DAG.getTargetConstant(4, dl, MVT::i32));
57485   Res = DAG.getNode(X86ISD::CVTPH2PS, dl, MVT::v4f32, Res);
57486   return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::f32, Res,
57487                      DAG.getIntPtrConstant(0, dl));
57488 }
57489 
57490 static SDValue combineFP_EXTEND(SDNode *N, SelectionDAG &DAG,
57491                                 TargetLowering::DAGCombinerInfo &DCI,
57492                                 const X86Subtarget &Subtarget) {
57493   EVT VT = N->getValueType(0);
57494   bool IsStrict = N->isStrictFPOpcode();
57495   SDValue Src = N->getOperand(IsStrict ? 1 : 0);
57496   EVT SrcVT = Src.getValueType();
57497 
57498   SDLoc dl(N);
57499   if (SrcVT.getScalarType() == MVT::bf16) {
57500     if (DCI.isAfterLegalizeDAG() && Src.getOpcode() == ISD::FP_ROUND &&
57501         !IsStrict && Src.getOperand(0).getValueType() == VT)
57502       return Src.getOperand(0);
57503 
57504     if (!SrcVT.isVector())
57505       return SDValue();
57506 
57507     assert(!IsStrict && "Strict FP doesn't support BF16");
57508     if (VT.getVectorElementType() == MVT::f64) {
57509       EVT TmpVT = VT.changeVectorElementType(MVT::f32);
57510       return DAG.getNode(ISD::FP_EXTEND, dl, VT,
57511                          DAG.getNode(ISD::FP_EXTEND, dl, TmpVT, Src));
57512     }
57513     assert(VT.getVectorElementType() == MVT::f32 && "Unexpected fpext");
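    // Note (illustrative): extending bf16 to f32 just places the 16 stored
    // bits into the upper half of the 32-bit word, so the code below widens
    // the integer bits and shifts left by 16 instead of emitting a real
    // conversion.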
57514     EVT NVT = SrcVT.changeVectorElementType(MVT::i32);
57515     Src = DAG.getBitcast(SrcVT.changeTypeToInteger(), Src);
57516     Src = DAG.getNode(ISD::ZERO_EXTEND, dl, NVT, Src);
57517     Src = DAG.getNode(ISD::SHL, dl, NVT, Src, DAG.getConstant(16, dl, NVT));
57518     return DAG.getBitcast(VT, Src);
57519   }
57520 
57521   if (!Subtarget.hasF16C() || Subtarget.useSoftFloat())
57522     return SDValue();
57523 
57524   if (Subtarget.hasFP16())
57525     return SDValue();
57526 
57527   if (!SrcVT.isVector() || SrcVT.getVectorElementType() != MVT::f16)
57528     return SDValue();
57529 
57530   if (VT.getVectorElementType() != MVT::f32 &&
57531       VT.getVectorElementType() != MVT::f64)
57532     return SDValue();
57533 
57534   unsigned NumElts = VT.getVectorNumElements();
57535   if (NumElts == 1 || !isPowerOf2_32(NumElts))
57536     return SDValue();
57537 
57538   // Convert the input to vXi16.
57539   EVT IntVT = SrcVT.changeVectorElementTypeToInteger();
57540   Src = DAG.getBitcast(IntVT, Src);
57541 
57542   // Widen to at least 8 input elements.
57543   if (NumElts < 8) {
57544     unsigned NumConcats = 8 / NumElts;
57545     SDValue Fill = NumElts == 4 ? DAG.getUNDEF(IntVT)
57546                                 : DAG.getConstant(0, dl, IntVT);
57547     SmallVector<SDValue, 4> Ops(NumConcats, Fill);
57548     Ops[0] = Src;
57549     Src = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v8i16, Ops);
57550   }
57551 
57552   // Destination is vXf32 with at least 4 elements.
57553   EVT CvtVT = EVT::getVectorVT(*DAG.getContext(), MVT::f32,
57554                                std::max(4U, NumElts));
57555   SDValue Cvt, Chain;
57556   if (IsStrict) {
57557     Cvt = DAG.getNode(X86ISD::STRICT_CVTPH2PS, dl, {CvtVT, MVT::Other},
57558                       {N->getOperand(0), Src});
57559     Chain = Cvt.getValue(1);
57560   } else {
57561     Cvt = DAG.getNode(X86ISD::CVTPH2PS, dl, CvtVT, Src);
57562   }
57563 
57564   if (NumElts < 4) {
57565     assert(NumElts == 2 && "Unexpected size");
57566     Cvt = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v2f32, Cvt,
57567                       DAG.getIntPtrConstant(0, dl));
57568   }
57569 
57570   if (IsStrict) {
57571     // Extend to the original VT if necessary.
57572     if (Cvt.getValueType() != VT) {
57573       Cvt = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {VT, MVT::Other},
57574                         {Chain, Cvt});
57575       Chain = Cvt.getValue(1);
57576     }
57577     return DAG.getMergeValues({Cvt, Chain}, dl);
57578   }
57579 
57580   // Extend to the original VT if necessary.
57581   return DAG.getNode(ISD::FP_EXTEND, dl, VT, Cvt);
57582 }
57583 
57584 // Try to find a larger VBROADCAST_LOAD/SUBV_BROADCAST_LOAD that we can extract
57585 // from. Limit this to cases where the loads have the same input chain and the
57586 // output chains are unused. This avoids any memory ordering issues.
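// For example (illustrative): a 128-bit VBROADCAST_LOAD of a pointer can reuse
// an existing 256-bit VBROADCAST_LOAD of the same pointer and chain by
// bitcasting an extract of the low 128 bits of the wider node's result.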
57587 static SDValue combineBROADCAST_LOAD(SDNode *N, SelectionDAG &DAG,
57588                                      TargetLowering::DAGCombinerInfo &DCI) {
57589   assert((N->getOpcode() == X86ISD::VBROADCAST_LOAD ||
57590           N->getOpcode() == X86ISD::SUBV_BROADCAST_LOAD) &&
57591          "Unknown broadcast load type");
57592 
57593   // Only do this if the chain result is unused.
57594   if (N->hasAnyUseOfValue(1))
57595     return SDValue();
57596 
57597   auto *MemIntrin = cast<MemIntrinsicSDNode>(N);
57598 
57599   SDValue Ptr = MemIntrin->getBasePtr();
57600   SDValue Chain = MemIntrin->getChain();
57601   EVT VT = N->getSimpleValueType(0);
57602   EVT MemVT = MemIntrin->getMemoryVT();
57603 
57604   // Look at other users of our base pointer and try to find a wider broadcast.
57605   // The input chain and the size of the memory VT must match.
57606   for (SDNode *User : Ptr->uses())
57607     if (User != N && User->getOpcode() == N->getOpcode() &&
57608         cast<MemIntrinsicSDNode>(User)->getBasePtr() == Ptr &&
57609         cast<MemIntrinsicSDNode>(User)->getChain() == Chain &&
57610         cast<MemIntrinsicSDNode>(User)->getMemoryVT().getSizeInBits() ==
57611             MemVT.getSizeInBits() &&
57612         !User->hasAnyUseOfValue(1) &&
57613         User->getValueSizeInBits(0).getFixedValue() > VT.getFixedSizeInBits()) {
57614       SDValue Extract = extractSubVector(SDValue(User, 0), 0, DAG, SDLoc(N),
57615                                          VT.getSizeInBits());
57616       Extract = DAG.getBitcast(VT, Extract);
57617       return DCI.CombineTo(N, Extract, SDValue(User, 1));
57618     }
57619 
57620   return SDValue();
57621 }
57622 
57623 static SDValue combineFP_ROUND(SDNode *N, SelectionDAG &DAG,
57624                                const X86Subtarget &Subtarget) {
57625   if (!Subtarget.hasF16C() || Subtarget.useSoftFloat())
57626     return SDValue();
57627 
57628   bool IsStrict = N->isStrictFPOpcode();
57629   EVT VT = N->getValueType(0);
57630   SDValue Src = N->getOperand(IsStrict ? 1 : 0);
57631   EVT SrcVT = Src.getValueType();
57632 
57633   if (!VT.isVector() || VT.getVectorElementType() != MVT::f16 ||
57634       SrcVT.getVectorElementType() != MVT::f32)
57635     return SDValue();
57636 
57637   SDLoc dl(N);
57638 
57639   SDValue Cvt, Chain;
57640   unsigned NumElts = VT.getVectorNumElements();
57641   if (Subtarget.hasFP16()) {
57642     // Combine (v8f16 fp_round(concat_vectors(v4f32 (xint_to_fp v4i64),
57643     //                                        v4f32 (xint_to_fp v4i64))))
57644     // into (v8f16 vector_shuffle(v8f16 (CVTXI2P v4i64),
57645     //                            v8f16 (CVTXI2P v4i64)))
57646     if (NumElts == 8 && Src.getOpcode() == ISD::CONCAT_VECTORS &&
57647         Src.getNumOperands() == 2) {
57648       SDValue Cvt0, Cvt1;
57649       SDValue Op0 = Src.getOperand(0);
57650       SDValue Op1 = Src.getOperand(1);
57651       bool IsOp0Strict = Op0->isStrictFPOpcode();
57652       if (Op0.getOpcode() != Op1.getOpcode() ||
57653           Op0.getOperand(IsOp0Strict ? 1 : 0).getValueType() != MVT::v4i64 ||
57654           Op1.getOperand(IsOp0Strict ? 1 : 0).getValueType() != MVT::v4i64) {
57655         return SDValue();
57656       }
57657       int Mask[8] = {0, 1, 2, 3, 8, 9, 10, 11};
57658       if (IsStrict) {
57659         assert(IsOp0Strict && "Op0 must be strict node");
57660         unsigned Opc = Op0.getOpcode() == ISD::STRICT_SINT_TO_FP
57661                            ? X86ISD::STRICT_CVTSI2P
57662                            : X86ISD::STRICT_CVTUI2P;
57663         Cvt0 = DAG.getNode(Opc, dl, {MVT::v8f16, MVT::Other},
57664                            {Op0.getOperand(0), Op0.getOperand(1)});
57665         Cvt1 = DAG.getNode(Opc, dl, {MVT::v8f16, MVT::Other},
57666                            {Op1.getOperand(0), Op1.getOperand(1)});
57667         Cvt = DAG.getVectorShuffle(MVT::v8f16, dl, Cvt0, Cvt1, Mask);
57668         return DAG.getMergeValues({Cvt, Cvt0.getValue(1)}, dl);
57669       }
57670       unsigned Opc = Op0.getOpcode() == ISD::SINT_TO_FP ? X86ISD::CVTSI2P
57671                                                         : X86ISD::CVTUI2P;
57672       Cvt0 = DAG.getNode(Opc, dl, MVT::v8f16, Op0.getOperand(0));
57673       Cvt1 = DAG.getNode(Opc, dl, MVT::v8f16, Op1.getOperand(0));
57674       return Cvt = DAG.getVectorShuffle(MVT::v8f16, dl, Cvt0, Cvt1, Mask);
57675     }
57676     return SDValue();
57677   }
57678 
57679   if (NumElts == 1 || !isPowerOf2_32(NumElts))
57680     return SDValue();
57681 
57682   // Widen to at least 4 input elements.
57683   if (NumElts < 4)
57684     Src = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4f32, Src,
57685                       DAG.getConstantFP(0.0, dl, SrcVT));
57686 
57687   // Destination is v8i16 with at least 8 elements.
57688   EVT CvtVT =
57689       EVT::getVectorVT(*DAG.getContext(), MVT::i16, std::max(8U, NumElts));
57690   SDValue Rnd = DAG.getTargetConstant(4, dl, MVT::i32);
57691   if (IsStrict) {
57692     Cvt = DAG.getNode(X86ISD::STRICT_CVTPS2PH, dl, {CvtVT, MVT::Other},
57693                       {N->getOperand(0), Src, Rnd});
57694     Chain = Cvt.getValue(1);
57695   } else {
57696     Cvt = DAG.getNode(X86ISD::CVTPS2PH, dl, CvtVT, Src, Rnd);
57697   }
57698 
57699   // Extract down to real number of elements.
57700   // Extract down to the real number of elements.
57701     EVT IntVT = VT.changeVectorElementTypeToInteger();
57702     Cvt = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, IntVT, Cvt,
57703                       DAG.getIntPtrConstant(0, dl));
57704   }
57705 
57706   Cvt = DAG.getBitcast(VT, Cvt);
57707 
57708   if (IsStrict)
57709     return DAG.getMergeValues({Cvt, Chain}, dl);
57710 
57711   return Cvt;
57712 }
57713 
57714 static SDValue combineMOVDQ2Q(SDNode *N, SelectionDAG &DAG) {
57715   SDValue Src = N->getOperand(0);
57716 
57717   // Turn MOVDQ2Q+simple_load into an mmx load.
57718   if (ISD::isNormalLoad(Src.getNode()) && Src.hasOneUse()) {
57719     LoadSDNode *LN = cast<LoadSDNode>(Src.getNode());
57720 
57721     if (LN->isSimple()) {
57722       SDValue NewLd = DAG.getLoad(MVT::x86mmx, SDLoc(N), LN->getChain(),
57723                                   LN->getBasePtr(),
57724                                   LN->getPointerInfo(),
57725                                   LN->getOriginalAlign(),
57726                                   LN->getMemOperand()->getFlags());
57727       DAG.ReplaceAllUsesOfValueWith(SDValue(LN, 1), NewLd.getValue(1));
57728       return NewLd;
57729     }
57730   }
57731 
57732   return SDValue();
57733 }
57734 
57735 static SDValue combinePDEP(SDNode *N, SelectionDAG &DAG,
57736                            TargetLowering::DAGCombinerInfo &DCI) {
57737   unsigned NumBits = N->getSimpleValueType(0).getSizeInBits();
57738   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
57739   if (TLI.SimplifyDemandedBits(SDValue(N, 0), APInt::getAllOnes(NumBits), DCI))
57740     return SDValue(N, 0);
57741 
57742   return SDValue();
57743 }
57744 
57745 SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
57746                                              DAGCombinerInfo &DCI) const {
57747   SelectionDAG &DAG = DCI.DAG;
57748   switch (N->getOpcode()) {
57749   // clang-format off
57750   default: break;
57751   case ISD::SCALAR_TO_VECTOR:
57752     return combineScalarToVector(N, DAG);
57753   case ISD::EXTRACT_VECTOR_ELT:
57754   case X86ISD::PEXTRW:
57755   case X86ISD::PEXTRB:
57756     return combineExtractVectorElt(N, DAG, DCI, Subtarget);
57757   case ISD::CONCAT_VECTORS:
57758     return combineCONCAT_VECTORS(N, DAG, DCI, Subtarget);
57759   case ISD::INSERT_SUBVECTOR:
57760     return combineINSERT_SUBVECTOR(N, DAG, DCI, Subtarget);
57761   case ISD::EXTRACT_SUBVECTOR:
57762     return combineEXTRACT_SUBVECTOR(N, DAG, DCI, Subtarget);
57763   case ISD::VSELECT:
57764   case ISD::SELECT:
57765   case X86ISD::BLENDV:      return combineSelect(N, DAG, DCI, Subtarget);
57766   case ISD::BITCAST:        return combineBitcast(N, DAG, DCI, Subtarget);
57767   case X86ISD::CMOV:        return combineCMov(N, DAG, DCI, Subtarget);
57768   case X86ISD::CMP:         return combineCMP(N, DAG, DCI, Subtarget);
57769   case ISD::ADD:            return combineAdd(N, DAG, DCI, Subtarget);
57770   case ISD::SUB:            return combineSub(N, DAG, DCI, Subtarget);
57771   case X86ISD::ADD:
57772   case X86ISD::SUB:         return combineX86AddSub(N, DAG, DCI, Subtarget);
57773   case X86ISD::CLOAD:
57774   case X86ISD::CSTORE:      return combineX86CloadCstore(N, DAG);
57775   case X86ISD::SBB:         return combineSBB(N, DAG);
57776   case X86ISD::ADC:         return combineADC(N, DAG, DCI);
57777   case ISD::MUL:            return combineMul(N, DAG, DCI, Subtarget);
57778   case ISD::SHL:            return combineShiftLeft(N, DAG, Subtarget);
57779   case ISD::SRA:            return combineShiftRightArithmetic(N, DAG, Subtarget);
57780   case ISD::SRL:            return combineShiftRightLogical(N, DAG, DCI, Subtarget);
57781   case ISD::AND:            return combineAnd(N, DAG, DCI, Subtarget);
57782   case ISD::OR:             return combineOr(N, DAG, DCI, Subtarget);
57783   case ISD::XOR:            return combineXor(N, DAG, DCI, Subtarget);
57784   case ISD::BITREVERSE:     return combineBITREVERSE(N, DAG, DCI, Subtarget);
57785   case ISD::AVGCEILS:
57786   case ISD::AVGCEILU:
57787   case ISD::AVGFLOORS:
57788   case ISD::AVGFLOORU:      return combineAVG(N, DAG, DCI, Subtarget);
57789   case X86ISD::BEXTR:
57790   case X86ISD::BEXTRI:      return combineBEXTR(N, DAG, DCI, Subtarget);
57791   case ISD::LOAD:           return combineLoad(N, DAG, DCI, Subtarget);
57792   case ISD::MLOAD:          return combineMaskedLoad(N, DAG, DCI, Subtarget);
57793   case ISD::STORE:          return combineStore(N, DAG, DCI, Subtarget);
57794   case ISD::MSTORE:         return combineMaskedStore(N, DAG, DCI, Subtarget);
57795   case X86ISD::VEXTRACT_STORE:
57796     return combineVEXTRACT_STORE(N, DAG, DCI, Subtarget);
57797   case ISD::SINT_TO_FP:
57798   case ISD::STRICT_SINT_TO_FP:
57799     return combineSIntToFP(N, DAG, DCI, Subtarget);
57800   case ISD::UINT_TO_FP:
57801   case ISD::STRICT_UINT_TO_FP:
57802     return combineUIntToFP(N, DAG, Subtarget);
57803   case ISD::LRINT:
57804   case ISD::LLRINT:         return combineLRINT_LLRINT(N, DAG, Subtarget);
57805   case ISD::FADD:
57806   case ISD::FSUB:           return combineFaddFsub(N, DAG, Subtarget);
57807   case X86ISD::VFCMULC:
57808   case X86ISD::VFMULC:      return combineFMulcFCMulc(N, DAG, Subtarget);
57809   case ISD::FNEG:           return combineFneg(N, DAG, DCI, Subtarget);
57810   case ISD::TRUNCATE:       return combineTruncate(N, DAG, Subtarget);
57811   case X86ISD::VTRUNC:      return combineVTRUNC(N, DAG, DCI);
57812   case X86ISD::ANDNP:       return combineAndnp(N, DAG, DCI, Subtarget);
57813   case X86ISD::FAND:        return combineFAnd(N, DAG, Subtarget);
57814   case X86ISD::FANDN:       return combineFAndn(N, DAG, Subtarget);
57815   case X86ISD::FXOR:
57816   case X86ISD::FOR:         return combineFOr(N, DAG, DCI, Subtarget);
57817   case X86ISD::FMIN:
57818   case X86ISD::FMAX:        return combineFMinFMax(N, DAG);
57819   case ISD::FMINNUM:
57820   case ISD::FMAXNUM:        return combineFMinNumFMaxNum(N, DAG, Subtarget);
57821   case X86ISD::CVTSI2P:
57822   case X86ISD::CVTUI2P:     return combineX86INT_TO_FP(N, DAG, DCI);
57823   case X86ISD::CVTP2SI:
57824   case X86ISD::CVTP2UI:
57825   case X86ISD::STRICT_CVTTP2SI:
57826   case X86ISD::CVTTP2SI:
57827   case X86ISD::STRICT_CVTTP2UI:
57828   case X86ISD::CVTTP2UI:
57829                             return combineCVTP2I_CVTTP2I(N, DAG, DCI);
57830   case X86ISD::STRICT_CVTPH2PS:
57831   case X86ISD::CVTPH2PS:    return combineCVTPH2PS(N, DAG, DCI);
57832   case X86ISD::BT:          return combineBT(N, DAG, DCI);
57833   case ISD::ANY_EXTEND:
57834   case ISD::ZERO_EXTEND:    return combineZext(N, DAG, DCI, Subtarget);
57835   case ISD::SIGN_EXTEND:    return combineSext(N, DAG, DCI, Subtarget);
57836   case ISD::SIGN_EXTEND_INREG: return combineSignExtendInReg(N, DAG, Subtarget);
57837   case ISD::ANY_EXTEND_VECTOR_INREG:
57838   case ISD::SIGN_EXTEND_VECTOR_INREG:
57839   case ISD::ZERO_EXTEND_VECTOR_INREG:
57840     return combineEXTEND_VECTOR_INREG(N, DAG, DCI, Subtarget);
57841   case ISD::SETCC:          return combineSetCC(N, DAG, DCI, Subtarget);
57842   case X86ISD::SETCC:       return combineX86SetCC(N, DAG, Subtarget);
57843   case X86ISD::BRCOND:      return combineBrCond(N, DAG, Subtarget);
57844   case X86ISD::PACKSS:
57845   case X86ISD::PACKUS:      return combineVectorPack(N, DAG, DCI, Subtarget);
57846   case X86ISD::HADD:
57847   case X86ISD::HSUB:
57848   case X86ISD::FHADD:
57849   case X86ISD::FHSUB:       return combineVectorHADDSUB(N, DAG, DCI, Subtarget);
57850   case X86ISD::VSHL:
57851   case X86ISD::VSRA:
57852   case X86ISD::VSRL:
57853     return combineVectorShiftVar(N, DAG, DCI, Subtarget);
57854   case X86ISD::VSHLI:
57855   case X86ISD::VSRAI:
57856   case X86ISD::VSRLI:
57857     return combineVectorShiftImm(N, DAG, DCI, Subtarget);
57858   case ISD::INSERT_VECTOR_ELT:
57859   case X86ISD::PINSRB:
57860   case X86ISD::PINSRW:      return combineVectorInsert(N, DAG, DCI, Subtarget);
57861   case X86ISD::SHUFP:       // Handle all target specific shuffles
57862   case X86ISD::INSERTPS:
57863   case X86ISD::EXTRQI:
57864   case X86ISD::INSERTQI:
57865   case X86ISD::VALIGN:
57866   case X86ISD::PALIGNR:
57867   case X86ISD::VSHLDQ:
57868   case X86ISD::VSRLDQ:
57869   case X86ISD::BLENDI:
57870   case X86ISD::UNPCKH:
57871   case X86ISD::UNPCKL:
57872   case X86ISD::MOVHLPS:
57873   case X86ISD::MOVLHPS:
57874   case X86ISD::PSHUFB:
57875   case X86ISD::PSHUFD:
57876   case X86ISD::PSHUFHW:
57877   case X86ISD::PSHUFLW:
57878   case X86ISD::MOVSHDUP:
57879   case X86ISD::MOVSLDUP:
57880   case X86ISD::MOVDDUP:
57881   case X86ISD::MOVSS:
57882   case X86ISD::MOVSD:
57883   case X86ISD::MOVSH:
57884   case X86ISD::VBROADCAST:
57885   case X86ISD::VPPERM:
57886   case X86ISD::VPERMI:
57887   case X86ISD::VPERMV:
57888   case X86ISD::VPERMV3:
57889   case X86ISD::VPERMIL2:
57890   case X86ISD::VPERMILPI:
57891   case X86ISD::VPERMILPV:
57892   case X86ISD::VPERM2X128:
57893   case X86ISD::SHUF128:
57894   case X86ISD::VZEXT_MOVL:
57895   case ISD::VECTOR_SHUFFLE: return combineShuffle(N, DAG, DCI, Subtarget);
57896   case X86ISD::FMADD_RND:
57897   case X86ISD::FMSUB:
57898   case X86ISD::STRICT_FMSUB:
57899   case X86ISD::FMSUB_RND:
57900   case X86ISD::FNMADD:
57901   case X86ISD::STRICT_FNMADD:
57902   case X86ISD::FNMADD_RND:
57903   case X86ISD::FNMSUB:
57904   case X86ISD::STRICT_FNMSUB:
57905   case X86ISD::FNMSUB_RND:
57906   case ISD::FMA:
57907   case ISD::STRICT_FMA:     return combineFMA(N, DAG, DCI, Subtarget);
57908   case X86ISD::FMADDSUB_RND:
57909   case X86ISD::FMSUBADD_RND:
57910   case X86ISD::FMADDSUB:
57911   case X86ISD::FMSUBADD:    return combineFMADDSUB(N, DAG, DCI);
57912   case X86ISD::MOVMSK:      return combineMOVMSK(N, DAG, DCI, Subtarget);
57913   case X86ISD::TESTP:       return combineTESTP(N, DAG, DCI, Subtarget);
57914   case X86ISD::MGATHER:
57915   case X86ISD::MSCATTER:    return combineX86GatherScatter(N, DAG, DCI);
57916   case ISD::MGATHER:
57917   case ISD::MSCATTER:       return combineGatherScatter(N, DAG, DCI);
57918   case X86ISD::PCMPEQ:
57919   case X86ISD::PCMPGT:      return combineVectorCompare(N, DAG, Subtarget);
57920   case X86ISD::PMULDQ:
57921   case X86ISD::PMULUDQ:     return combinePMULDQ(N, DAG, DCI, Subtarget);
57922   case X86ISD::VPMADDUBSW:
57923   case X86ISD::VPMADDWD:    return combineVPMADD(N, DAG, DCI);
57924   case X86ISD::KSHIFTL:
57925   case X86ISD::KSHIFTR:     return combineKSHIFT(N, DAG, DCI);
57926   case ISD::FP16_TO_FP:     return combineFP16_TO_FP(N, DAG, Subtarget);
57927   case ISD::STRICT_FP_EXTEND:
57928   case ISD::FP_EXTEND:      return combineFP_EXTEND(N, DAG, DCI, Subtarget);
57929   case ISD::STRICT_FP_ROUND:
57930   case ISD::FP_ROUND:       return combineFP_ROUND(N, DAG, Subtarget);
57931   case X86ISD::VBROADCAST_LOAD:
57932   case X86ISD::SUBV_BROADCAST_LOAD: return combineBROADCAST_LOAD(N, DAG, DCI);
57933   case X86ISD::MOVDQ2Q:     return combineMOVDQ2Q(N, DAG);
57934   case X86ISD::PDEP:        return combinePDEP(N, DAG, DCI);
57935   // clang-format on
57936   }
57937 
57938   return SDValue();
57939 }
57940 
57941 bool X86TargetLowering::preferABDSToABSWithNSW(EVT VT) const {
57942   return false;
57943 }
57944 
57945 // Prefer (non-AVX512) vector TRUNCATE(SIGN_EXTEND_INREG(X)) to use of PACKSS.
57946 bool X86TargetLowering::preferSextInRegOfTruncate(EVT TruncVT, EVT VT,
57947                                                   EVT ExtVT) const {
57948   return Subtarget.hasAVX512() || !VT.isVector();
57949 }
57950 
57951 bool X86TargetLowering::isTypeDesirableForOp(unsigned Opc, EVT VT) const {
57952   if (!isTypeLegal(VT))
57953     return false;
57954 
57955   // There are no vXi8 shifts.
57956   if (Opc == ISD::SHL && VT.isVector() && VT.getVectorElementType() == MVT::i8)
57957     return false;
57958 
57959   // TODO: Almost no 8-bit ops are desirable because they have no actual
57960   //       size/speed advantages vs. 32-bit ops, but they do have a major
57961   //       potential disadvantage by causing partial register stalls.
57962   //
57963   // 8-bit multiply/shl is probably not cheaper than 32-bit multiply/shl, and
57964   // we have specializations to turn 32-bit multiply/shl into LEA or other ops.
57965   // Also, see the comment in "IsDesirableToPromoteOp" - where we additionally
57966   // check for a constant operand to the multiply.
57967   if ((Opc == ISD::MUL || Opc == ISD::SHL) && VT == MVT::i8)
57968     return false;
57969 
57970   // i16 instruction encodings are longer and some i16 instructions are slow,
57971   // so those are not desirable.
57972   if (VT == MVT::i16) {
57973     switch (Opc) {
57974     default:
57975       break;
57976     case ISD::LOAD:
57977     case ISD::SIGN_EXTEND:
57978     case ISD::ZERO_EXTEND:
57979     case ISD::ANY_EXTEND:
57980     case ISD::MUL:
57981       return false;
57982     case ISD::SHL:
57983     case ISD::SRA:
57984     case ISD::SRL:
57985     case ISD::SUB:
57986     case ISD::ADD:
57987     case ISD::AND:
57988     case ISD::OR:
57989     case ISD::XOR:
57990       // An NDD instruction never has a "partial register write" issue because
57991       // the destination register's upper bits [63:OSIZE] are zeroed even when
57992       // OSIZE=8/16.
57993       return Subtarget.hasNDD();
57994     }
57995   }
57996 
57997   // Any legal type not explicitly accounted for above here is desirable.
57998   return true;
57999 }
58000 
58001 SDValue X86TargetLowering::expandIndirectJTBranch(const SDLoc &dl,
58002                                                   SDValue Value, SDValue Addr,
58003                                                   int JTI,
58004                                                   SelectionDAG &DAG) const {
58005   const Module *M = DAG.getMachineFunction().getFunction().getParent();
58006   Metadata *IsCFProtectionSupported = M->getModuleFlag("cf-protection-branch");
58007   if (IsCFProtectionSupported) {
58008     // When control-flow branch protection is enabled, we need to add a
58009     // notrack prefix to the indirect branch.
58010     // To do that we create an NT_BRIND SDNode.
58011     // Upon ISEL, the pattern will convert it to a jmp with the NoTrack prefix.
58012     SDValue JTInfo = DAG.getJumpTableDebugInfo(JTI, Value, dl);
58013     return DAG.getNode(X86ISD::NT_BRIND, dl, MVT::Other, JTInfo, Addr);
58014   }
58015 
58016   return TargetLowering::expandIndirectJTBranch(dl, Value, Addr, JTI, DAG);
58017 }
58018 
58019 TargetLowering::AndOrSETCCFoldKind
58020 X86TargetLowering::isDesirableToCombineLogicOpOfSETCC(
58021     const SDNode *LogicOp, const SDNode *SETCC0, const SDNode *SETCC1) const {
58022   using AndOrSETCCFoldKind = TargetLowering::AndOrSETCCFoldKind;
58023   EVT VT = LogicOp->getValueType(0);
58024   EVT OpVT = SETCC0->getOperand(0).getValueType();
58025   if (!VT.isInteger())
58026     return AndOrSETCCFoldKind::None;
58027 
58028   if (VT.isVector())
58029     return AndOrSETCCFoldKind(AndOrSETCCFoldKind::NotAnd |
58030                               (isOperationLegal(ISD::ABS, OpVT)
58031                                    ? AndOrSETCCFoldKind::ABS
58032                                    : AndOrSETCCFoldKind::None));
58033 
58034   // Don't use `NotAnd` as, even though `not` is generally shorter code size than
58035   // `add`, `add` can lower to LEA which can save moves / spills. In any case
58036   // where `NotAnd` applies, `AddAnd` does as well.
58037   // TODO: Currently we lower (icmp eq/ne (and ~X, Y), 0) -> `test (not X), Y`;
58038   // if we change that to `andn Y, X` it may be worth preferring `NotAnd` here.
58039   return AndOrSETCCFoldKind::AddAnd;
58040 }
58041 
58042 bool X86TargetLowering::IsDesirableToPromoteOp(SDValue Op, EVT &PVT) const {
58043   EVT VT = Op.getValueType();
58044   bool Is8BitMulByConstant = VT == MVT::i8 && Op.getOpcode() == ISD::MUL &&
58045                              isa<ConstantSDNode>(Op.getOperand(1));
58046 
58047   // i16 is legal, but undesirable since i16 instruction encodings are longer
58048   // and some i16 instructions are slow.
58049   // 8-bit multiply-by-constant can usually be expanded to something cheaper
58050   // using LEA and/or other ALU ops.
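  // For example (illustrative): an i16 add of two register values is normally
  // promoted to (trunc (i32 add (anyext X), (anyext Y))), trading the
  // 0x66-prefixed 16-bit encoding for plain 32-bit ALU ops, unless that would
  // break the load-folding or RMW patterns checked below.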
58051   if (VT != MVT::i16 && !Is8BitMulByConstant)
58052     return false;
58053 
58054   auto IsFoldableRMW = [](SDValue Load, SDValue Op) {
58055     if (!Op.hasOneUse())
58056       return false;
58057     SDNode *User = *Op->use_begin();
58058     if (!ISD::isNormalStore(User))
58059       return false;
58060     auto *Ld = cast<LoadSDNode>(Load);
58061     auto *St = cast<StoreSDNode>(User);
58062     return Ld->getBasePtr() == St->getBasePtr();
58063   };
58064 
58065   auto IsFoldableAtomicRMW = [](SDValue Load, SDValue Op) {
58066     if (!Load.hasOneUse() || Load.getOpcode() != ISD::ATOMIC_LOAD)
58067       return false;
58068     if (!Op.hasOneUse())
58069       return false;
58070     SDNode *User = *Op->use_begin();
58071     if (User->getOpcode() != ISD::ATOMIC_STORE)
58072       return false;
58073     auto *Ld = cast<AtomicSDNode>(Load);
58074     auto *St = cast<AtomicSDNode>(User);
58075     return Ld->getBasePtr() == St->getBasePtr();
58076   };
58077 
58078   bool Commute = false;
58079   switch (Op.getOpcode()) {
58080   default: return false;
58081   case ISD::SIGN_EXTEND:
58082   case ISD::ZERO_EXTEND:
58083   case ISD::ANY_EXTEND:
58084     break;
58085   case ISD::SHL:
58086   case ISD::SRA:
58087   case ISD::SRL: {
58088     SDValue N0 = Op.getOperand(0);
58089     // Look out for (store (shl (load), x)).
58090     if (X86::mayFoldLoad(N0, Subtarget) && IsFoldableRMW(N0, Op))
58091       return false;
58092     break;
58093   }
58094   case ISD::ADD:
58095   case ISD::MUL:
58096   case ISD::AND:
58097   case ISD::OR:
58098   case ISD::XOR:
58099     Commute = true;
58100     [[fallthrough]];
58101   case ISD::SUB: {
58102     SDValue N0 = Op.getOperand(0);
58103     SDValue N1 = Op.getOperand(1);
58104     // Avoid disabling potential load folding opportunities.
58105     if (X86::mayFoldLoad(N1, Subtarget) &&
58106         (!Commute || !isa<ConstantSDNode>(N0) ||
58107          (Op.getOpcode() != ISD::MUL && IsFoldableRMW(N1, Op))))
58108       return false;
58109     if (X86::mayFoldLoad(N0, Subtarget) &&
58110         ((Commute && !isa<ConstantSDNode>(N1)) ||
58111          (Op.getOpcode() != ISD::MUL && IsFoldableRMW(N0, Op))))
58112       return false;
58113     if (IsFoldableAtomicRMW(N0, Op) ||
58114         (Commute && IsFoldableAtomicRMW(N1, Op)))
58115       return false;
58116   }
58117   }
58118 
58119   PVT = MVT::i32;
58120   return true;
58121 }
58122 
58123 //===----------------------------------------------------------------------===//
58124 //                           X86 Inline Assembly Support
58125 //===----------------------------------------------------------------------===//
58126 
58127 // Helper to match a string separated by whitespace.
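// For example (illustrative):
//   matchAsm(" bswap   $0 ", {"bswap", "$0"})  -> true
//   matchAsm("bswap$0",      {"bswap", "$0"})  -> false (no gap after "bswap")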
58128 static bool matchAsm(StringRef S, ArrayRef<const char *> Pieces) {
58129   S = S.substr(S.find_first_not_of(" \t")); // Skip leading whitespace.
58130 
58131   for (StringRef Piece : Pieces) {
58132     if (!S.starts_with(Piece)) // Check if the piece matches.
58133       return false;
58134 
58135     S = S.substr(Piece.size());
58136     StringRef::size_type Pos = S.find_first_not_of(" \t");
58137     if (Pos == 0) // We matched a prefix.
58138       return false;
58139 
58140     S = S.substr(Pos);
58141   }
58142 
58143   return S.empty();
58144 }
58145 
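// Note (descriptive summary of the checks below): this returns true only when
// the clobber list is exactly "~{cc}", "~{flags}" and "~{fpsr}", optionally
// plus "~{dirflag}", so callers can treat the asm as safe to rewrite as a
// bswap intrinsic.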
58146 static bool clobbersFlagRegisters(const SmallVector<StringRef, 4> &AsmPieces) {
58147 
58148   if (AsmPieces.size() == 3 || AsmPieces.size() == 4) {
58149     if (llvm::is_contained(AsmPieces, "~{cc}") &&
58150         llvm::is_contained(AsmPieces, "~{flags}") &&
58151         llvm::is_contained(AsmPieces, "~{fpsr}")) {
58152 
58153       if (AsmPieces.size() == 3)
58154         return true;
58155       else if (llvm::is_contained(AsmPieces, "~{dirflag}"))
58156         return true;
58157     }
58158   }
58159   return false;
58160 }
58161 
58162 bool X86TargetLowering::ExpandInlineAsm(CallInst *CI) const {
58163   InlineAsm *IA = cast<InlineAsm>(CI->getCalledOperand());
58164 
58165   const std::string &AsmStr = IA->getAsmString();
58166 
58167   IntegerType *Ty = dyn_cast<IntegerType>(CI->getType());
58168   if (!Ty || Ty->getBitWidth() % 16 != 0)
58169     return false;
58170 
58171   // TODO: should remove alternatives from the asmstring: "foo {a|b}" -> "foo a"
58172   SmallVector<StringRef, 4> AsmPieces;
58173   SplitString(AsmStr, AsmPieces, ";\n");
58174 
58175   switch (AsmPieces.size()) {
58176   default: return false;
58177   case 1:
58178     // FIXME: this should verify that we are targeting a 486 or better.  If not,
58179     // we will turn this bswap into something that will be lowered to logical
58180     // ops instead of emitting the bswap asm.  For now, we don't support 486 or
58181     // lower so don't worry about this.
58182     // bswap $0
58183     if (matchAsm(AsmPieces[0], {"bswap", "$0"}) ||
58184         matchAsm(AsmPieces[0], {"bswapl", "$0"}) ||
58185         matchAsm(AsmPieces[0], {"bswapq", "$0"}) ||
58186         matchAsm(AsmPieces[0], {"bswap", "${0:q}"}) ||
58187         matchAsm(AsmPieces[0], {"bswapl", "${0:q}"}) ||
58188         matchAsm(AsmPieces[0], {"bswapq", "${0:q}"})) {
58189       // No need to check constraints, nothing other than the equivalent of
58190       // "=r,0" would be valid here.
58191       return IntrinsicLowering::LowerToByteSwap(CI);
58192     }
58193 
58194     // rorw $$8, ${0:w}  -->  llvm.bswap.i16
58195     if (CI->getType()->isIntegerTy(16) &&
58196         IA->getConstraintString().compare(0, 5, "=r,0,") == 0 &&
58197         (matchAsm(AsmPieces[0], {"rorw", "$$8,", "${0:w}"}) ||
58198          matchAsm(AsmPieces[0], {"rolw", "$$8,", "${0:w}"}))) {
58199       AsmPieces.clear();
58200       StringRef ConstraintsStr = IA->getConstraintString();
58201       SplitString(StringRef(ConstraintsStr).substr(5), AsmPieces, ",");
58202       array_pod_sort(AsmPieces.begin(), AsmPieces.end());
58203       if (clobbersFlagRegisters(AsmPieces))
58204         return IntrinsicLowering::LowerToByteSwap(CI);
58205     }
58206     break;
58207   case 3:
58208     if (CI->getType()->isIntegerTy(32) &&
58209         IA->getConstraintString().compare(0, 5, "=r,0,") == 0 &&
58210         matchAsm(AsmPieces[0], {"rorw", "$$8,", "${0:w}"}) &&
58211         matchAsm(AsmPieces[1], {"rorl", "$$16,", "$0"}) &&
58212         matchAsm(AsmPieces[2], {"rorw", "$$8,", "${0:w}"})) {
58213       AsmPieces.clear();
58214       StringRef ConstraintsStr = IA->getConstraintString();
58215       SplitString(StringRef(ConstraintsStr).substr(5), AsmPieces, ",");
58216       array_pod_sort(AsmPieces.begin(), AsmPieces.end());
58217       if (clobbersFlagRegisters(AsmPieces))
58218         return IntrinsicLowering::LowerToByteSwap(CI);
58219     }
58220 
58221     if (CI->getType()->isIntegerTy(64)) {
58222       InlineAsm::ConstraintInfoVector Constraints = IA->ParseConstraints();
58223       if (Constraints.size() >= 2 &&
58224           Constraints[0].Codes.size() == 1 && Constraints[0].Codes[0] == "A" &&
58225           Constraints[1].Codes.size() == 1 && Constraints[1].Codes[0] == "0") {
58226         // bswap %eax / bswap %edx / xchgl %eax, %edx  -> llvm.bswap.i64
58227         if (matchAsm(AsmPieces[0], {"bswap", "%eax"}) &&
58228             matchAsm(AsmPieces[1], {"bswap", "%edx"}) &&
58229             matchAsm(AsmPieces[2], {"xchgl", "%eax,", "%edx"}))
58230           return IntrinsicLowering::LowerToByteSwap(CI);
58231       }
58232     }
58233     break;
58234   }
58235   return false;
58236 }
58237 
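// Maps an "asm flag output" constraint such as "{@ccz}" or "{@ccnae}" onto the
// corresponding X86 condition code, returning COND_INVALID for anything that
// is not a recognized "{@cc<cond>}" string.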
58238 static X86::CondCode parseConstraintCode(llvm::StringRef Constraint) {
58239   X86::CondCode Cond = StringSwitch<X86::CondCode>(Constraint)
58240                            .Case("{@cca}", X86::COND_A)
58241                            .Case("{@ccae}", X86::COND_AE)
58242                            .Case("{@ccb}", X86::COND_B)
58243                            .Case("{@ccbe}", X86::COND_BE)
58244                            .Case("{@ccc}", X86::COND_B)
58245                            .Case("{@cce}", X86::COND_E)
58246                            .Case("{@ccz}", X86::COND_E)
58247                            .Case("{@ccg}", X86::COND_G)
58248                            .Case("{@ccge}", X86::COND_GE)
58249                            .Case("{@ccl}", X86::COND_L)
58250                            .Case("{@ccle}", X86::COND_LE)
58251                            .Case("{@ccna}", X86::COND_BE)
58252                            .Case("{@ccnae}", X86::COND_B)
58253                            .Case("{@ccnb}", X86::COND_AE)
58254                            .Case("{@ccnbe}", X86::COND_A)
58255                            .Case("{@ccnc}", X86::COND_AE)
58256                            .Case("{@ccne}", X86::COND_NE)
58257                            .Case("{@ccnz}", X86::COND_NE)
58258                            .Case("{@ccng}", X86::COND_LE)
58259                            .Case("{@ccnge}", X86::COND_L)
58260                            .Case("{@ccnl}", X86::COND_GE)
58261                            .Case("{@ccnle}", X86::COND_G)
58262                            .Case("{@ccno}", X86::COND_NO)
58263                            .Case("{@ccnp}", X86::COND_NP)
58264                            .Case("{@ccns}", X86::COND_NS)
58265                            .Case("{@cco}", X86::COND_O)
58266                            .Case("{@ccp}", X86::COND_P)
58267                            .Case("{@ccs}", X86::COND_S)
58268                            .Default(X86::COND_INVALID);
58269   return Cond;
58270 }
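// Example (illustrative): a flag-output constraint written in source as
//   int zf; asm("cmpl %2, %1" : "=@ccz"(zf) : "r"(a), "r"(b));
// reaches this function as the constraint code "{@ccz}" and maps to
// X86::COND_E.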
58271 
58272 /// Given a constraint letter, return the type of constraint for this target.
58273 X86TargetLowering::ConstraintType
58274 X86TargetLowering::getConstraintType(StringRef Constraint) const {
58275   if (Constraint.size() == 1) {
58276     switch (Constraint[0]) {
58277     case 'R':
58278     case 'q':
58279     case 'Q':
58280     case 'f':
58281     case 't':
58282     case 'u':
58283     case 'y':
58284     case 'x':
58285     case 'v':
58286     case 'l':
58287     case 'k': // AVX512 masking registers.
58288       return C_RegisterClass;
58289     case 'a':
58290     case 'b':
58291     case 'c':
58292     case 'd':
58293     case 'S':
58294     case 'D':
58295     case 'A':
58296       return C_Register;
58297     case 'I':
58298     case 'J':
58299     case 'K':
58300     case 'N':
58301     case 'G':
58302     case 'L':
58303     case 'M':
58304       return C_Immediate;
58305     case 'C':
58306     case 'e':
58307     case 'Z':
58308       return C_Other;
58309     default:
58310       break;
58311     }
58312   }
58313   else if (Constraint.size() == 2) {
58314     switch (Constraint[0]) {
58315     default:
58316       break;
58317     case 'W':
58318       if (Constraint[1] != 's')
58319         break;
58320       return C_Other;
58321     case 'Y':
58322       switch (Constraint[1]) {
58323       default:
58324         break;
58325       case 'z':
58326         return C_Register;
58327       case 'i':
58328       case 'm':
58329       case 'k':
58330       case 't':
58331       case '2':
58332         return C_RegisterClass;
58333       }
58334       break;
58335     case 'j':
58336       switch (Constraint[1]) {
58337       default:
58338         break;
58339       case 'r':
58340       case 'R':
58341         return C_RegisterClass;
58342       }
58343     }
58344   } else if (parseConstraintCode(Constraint) != X86::COND_INVALID)
58345     return C_Other;
58346   return TargetLowering::getConstraintType(Constraint);
58347 }
58348 
58349 /// Examine constraint type and operand type and determine a weight value.
58350 /// This object must already have been set up with the operand type
58351 /// and the current alternative constraint selected.
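/// For example (illustrative): an i32 operand weighed against constraint 'a'
/// yields CW_SpecificReg, while a constant 7 weighed against 'I' (0..31)
/// yields CW_Constant.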
58352 TargetLowering::ConstraintWeight
58353 X86TargetLowering::getSingleConstraintMatchWeight(
58354     AsmOperandInfo &Info, const char *Constraint) const {
58355   ConstraintWeight Wt = CW_Invalid;
58356   Value *CallOperandVal = Info.CallOperandVal;
58357   // If we don't have a value, we can't do a match,
58358   // but allow it at the lowest weight.
58359   if (!CallOperandVal)
58360     return CW_Default;
58361   Type *Ty = CallOperandVal->getType();
58362   // Look at the constraint type.
58363   switch (*Constraint) {
58364   default:
58365     Wt = TargetLowering::getSingleConstraintMatchWeight(Info, Constraint);
58366     [[fallthrough]];
58367   case 'R':
58368   case 'q':
58369   case 'Q':
58370   case 'a':
58371   case 'b':
58372   case 'c':
58373   case 'd':
58374   case 'S':
58375   case 'D':
58376   case 'A':
58377     if (CallOperandVal->getType()->isIntegerTy())
58378       Wt = CW_SpecificReg;
58379     break;
58380   case 'f':
58381   case 't':
58382   case 'u':
58383     if (Ty->isFloatingPointTy())
58384       Wt = CW_SpecificReg;
58385     break;
58386   case 'y':
58387     if (Ty->isX86_MMXTy() && Subtarget.hasMMX())
58388       Wt = CW_SpecificReg;
58389     break;
58390   case 'Y':
58391     if (StringRef(Constraint).size() != 2)
58392       break;
58393     switch (Constraint[1]) {
58394     default:
58395       return CW_Invalid;
58396     // XMM0
58397     case 'z':
58398       if (((Ty->getPrimitiveSizeInBits() == 128) && Subtarget.hasSSE1()) ||
58399           ((Ty->getPrimitiveSizeInBits() == 256) && Subtarget.hasAVX()) ||
58400           ((Ty->getPrimitiveSizeInBits() == 512) && Subtarget.hasAVX512()))
58401         return CW_SpecificReg;
58402       return CW_Invalid;
58403     // Conditional OpMask regs (AVX512)
58404     case 'k':
58405       if ((Ty->getPrimitiveSizeInBits() == 64) && Subtarget.hasAVX512())
58406         return CW_Register;
58407       return CW_Invalid;
58408     // Any MMX reg
58409     case 'm':
58410       if (Ty->isX86_MMXTy() && Subtarget.hasMMX())
58411         return Wt;
58412       return CW_Invalid;
58413     // Any SSE reg when ISA >= SSE2, same as 'x'
58414     case 'i':
58415     case 't':
58416     case '2':
58417       if (!Subtarget.hasSSE2())
58418         return CW_Invalid;
58419       break;
58420     }
58421     break;
58422   case 'j':
58423     if (StringRef(Constraint).size() != 2)
58424       break;
58425     switch (Constraint[1]) {
58426     default:
58427       return CW_Invalid;
58428     case 'r':
58429     case 'R':
58430       if (CallOperandVal->getType()->isIntegerTy())
58431         Wt = CW_SpecificReg;
58432       break;
58433     }
58434     break;
58435   case 'v':
58436     if ((Ty->getPrimitiveSizeInBits() == 512) && Subtarget.hasAVX512())
58437       Wt = CW_Register;
58438     [[fallthrough]];
58439   case 'x':
58440     if (((Ty->getPrimitiveSizeInBits() == 128) && Subtarget.hasSSE1()) ||
58441         ((Ty->getPrimitiveSizeInBits() == 256) && Subtarget.hasAVX()))
58442       Wt = CW_Register;
58443     break;
58444   case 'k':
58445     // Enable conditional vector operations using %k<#> registers.
58446     if ((Ty->getPrimitiveSizeInBits() == 64) && Subtarget.hasAVX512())
58447       Wt = CW_Register;
58448     break;
58449   case 'I':
58450     if (auto *C = dyn_cast<ConstantInt>(Info.CallOperandVal))
58451       if (C->getZExtValue() <= 31)
58452         Wt = CW_Constant;
58453     break;
58454   case 'J':
58455     if (auto *C = dyn_cast<ConstantInt>(CallOperandVal))
58456       if (C->getZExtValue() <= 63)
58457         Wt = CW_Constant;
58458     break;
58459   case 'K':
58460     if (auto *C = dyn_cast<ConstantInt>(CallOperandVal))
58461       if ((C->getSExtValue() >= -0x80) && (C->getSExtValue() <= 0x7f))
58462         Wt = CW_Constant;
58463     break;
58464   case 'L':
58465     if (auto *C = dyn_cast<ConstantInt>(CallOperandVal))
58466       if ((C->getZExtValue() == 0xff) || (C->getZExtValue() == 0xffff))
58467         Wt = CW_Constant;
58468     break;
58469   case 'M':
58470     if (auto *C = dyn_cast<ConstantInt>(CallOperandVal))
58471       if (C->getZExtValue() <= 3)
58472         Wt = CW_Constant;
58473     break;
58474   case 'N':
58475     if (auto *C = dyn_cast<ConstantInt>(CallOperandVal))
58476       if (C->getZExtValue() <= 0xff)
58477         Wt = CW_Constant;
58478     break;
58479   case 'G':
58480   case 'C':
58481     if (isa<ConstantFP>(CallOperandVal))
58482       Wt = CW_Constant;
58483     break;
58484   case 'e':
58485     if (auto *C = dyn_cast<ConstantInt>(CallOperandVal))
58486       if ((C->getSExtValue() >= -0x80000000LL) &&
58487           (C->getSExtValue() <= 0x7fffffffLL))
58488         Wt = CW_Constant;
58489     break;
58490   case 'Z':
58491     if (auto *C = dyn_cast<ConstantInt>(CallOperandVal))
58492       if (C->getZExtValue() <= 0xffffffff)
58493         Wt = CW_Constant;
58494     break;
58495   }
58496   return Wt;
58497 }
58498 
58499 /// Try to replace an X constraint, which matches anything, with another that
58500 /// has more specific requirements based on the type of the corresponding
58501 /// operand.
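/// For example (illustrative): a float operand under the "X" constraint is
/// retyped to "x" (an XMM register) when SSE1 is available; otherwise the
/// generic handling applies.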
58502 const char *X86TargetLowering::
58503 LowerXConstraint(EVT ConstraintVT) const {
58504   // FP X constraints get lowered to SSE1/2 registers if available, otherwise
58505   // 'f' like normal targets.
58506   if (ConstraintVT.isFloatingPoint()) {
58507     if (Subtarget.hasSSE1())
58508       return "x";
58509   }
58510 
58511   return TargetLowering::LowerXConstraint(ConstraintVT);
58512 }
58513 
58514 // Lower @cc targets via setcc.
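// For example (illustrative), an output constraint of "=@ccz" becomes a
// CopyFromReg of EFLAGS feeding an X86ISD::SETCC with COND_E, whose result
// is zero-extended to the constraint's integer type.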
58515 SDValue X86TargetLowering::LowerAsmOutputForConstraint(
58516     SDValue &Chain, SDValue &Glue, const SDLoc &DL,
58517     const AsmOperandInfo &OpInfo, SelectionDAG &DAG) const {
58518   X86::CondCode Cond = parseConstraintCode(OpInfo.ConstraintCode);
58519   if (Cond == X86::COND_INVALID)
58520     return SDValue();
58521   // Check that return type is valid.
58522   if (OpInfo.ConstraintVT.isVector() || !OpInfo.ConstraintVT.isInteger() ||
58523       OpInfo.ConstraintVT.getSizeInBits() < 8)
58524     report_fatal_error("Glue output operand is of invalid type");
58525 
58526   // Get EFLAGS register. Only update chain when copyfrom is glued.
58527   if (Glue.getNode()) {
58528     Glue = DAG.getCopyFromReg(Chain, DL, X86::EFLAGS, MVT::i32, Glue);
58529     Chain = Glue.getValue(1);
58530   } else
58531     Glue = DAG.getCopyFromReg(Chain, DL, X86::EFLAGS, MVT::i32);
58532   // Extract CC code.
58533   SDValue CC = getSETCC(Cond, Glue, DL, DAG);
58534   // Zero-extend to the constraint's integer type.
58535   SDValue Result = DAG.getNode(ISD::ZERO_EXTEND, DL, OpInfo.ConstraintVT, CC);
58536 
58537   return Result;
58538 }
58539 
58540 /// Lower the specified operand into the Ops vector.
58541 /// If it is invalid, don't add anything to Ops.
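/// For example (illustrative): under the 'N' constraint (an unsigned 8-bit
/// immediate, e.g. an I/O port used as "Nd" with outb), the constant 0x80 is
/// accepted and emitted as a target constant, while 0x1234 is rejected and
/// no operand is added.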
58542 void X86TargetLowering::LowerAsmOperandForConstraint(SDValue Op,
58543                                                      StringRef Constraint,
58544                                                      std::vector<SDValue> &Ops,
58545                                                      SelectionDAG &DAG) const {
58546   SDValue Result;
58547   char ConstraintLetter = Constraint[0];
58548   switch (ConstraintLetter) {
58549   default: break;
58550   case 'I':
58551     if (auto *C = dyn_cast<ConstantSDNode>(Op)) {
58552       if (C->getZExtValue() <= 31) {
58553         Result = DAG.getTargetConstant(C->getZExtValue(), SDLoc(Op),
58554                                        Op.getValueType());
58555         break;
58556       }
58557     }
58558     return;
58559   case 'J':
58560     if (auto *C = dyn_cast<ConstantSDNode>(Op)) {
58561       if (C->getZExtValue() <= 63) {
58562         Result = DAG.getTargetConstant(C->getZExtValue(), SDLoc(Op),
58563                                        Op.getValueType());
58564         break;
58565       }
58566     }
58567     return;
58568   case 'K':
58569     if (auto *C = dyn_cast<ConstantSDNode>(Op)) {
58570       if (isInt<8>(C->getSExtValue())) {
58571         Result = DAG.getTargetConstant(C->getZExtValue(), SDLoc(Op),
58572                                        Op.getValueType());
58573         break;
58574       }
58575     }
58576     return;
58577   case 'L':
58578     if (auto *C = dyn_cast<ConstantSDNode>(Op)) {
58579       if (C->getZExtValue() == 0xff || C->getZExtValue() == 0xffff ||
58580           (Subtarget.is64Bit() && C->getZExtValue() == 0xffffffff)) {
58581         Result = DAG.getTargetConstant(C->getSExtValue(), SDLoc(Op),
58582                                        Op.getValueType());
58583         break;
58584       }
58585     }
58586     return;
58587   case 'M':
58588     if (auto *C = dyn_cast<ConstantSDNode>(Op)) {
58589       if (C->getZExtValue() <= 3) {
58590         Result = DAG.getTargetConstant(C->getZExtValue(), SDLoc(Op),
58591                                        Op.getValueType());
58592         break;
58593       }
58594     }
58595     return;
58596   case 'N':
58597     if (auto *C = dyn_cast<ConstantSDNode>(Op)) {
58598       if (C->getZExtValue() <= 255) {
58599         Result = DAG.getTargetConstant(C->getZExtValue(), SDLoc(Op),
58600                                        Op.getValueType());
58601         break;
58602       }
58603     }
58604     return;
58605   case 'O':
58606     if (auto *C = dyn_cast<ConstantSDNode>(Op)) {
58607       if (C->getZExtValue() <= 127) {
58608         Result = DAG.getTargetConstant(C->getZExtValue(), SDLoc(Op),
58609                                        Op.getValueType());
58610         break;
58611       }
58612     }
58613     return;
58614   case 'e': {
58615     // 32-bit signed value
58616     if (auto *C = dyn_cast<ConstantSDNode>(Op)) {
58617       if (ConstantInt::isValueValidForType(Type::getInt32Ty(*DAG.getContext()),
58618                                            C->getSExtValue())) {
58619         // Widen to 64 bits here to get it sign extended.
58620         Result = DAG.getTargetConstant(C->getSExtValue(), SDLoc(Op), MVT::i64);
58621         break;
58622       }
58623     // FIXME gcc accepts some relocatable values here too, but only in certain
58624     // memory models; it's complicated.
58625     }
58626     return;
58627   }
58628   case 'W': {
58629     assert(Constraint[1] == 's');
58630     // Op is a BlockAddressSDNode or a GlobalAddressSDNode with an optional
58631     // offset.
58632     if (const auto *BA = dyn_cast<BlockAddressSDNode>(Op)) {
58633       Ops.push_back(DAG.getTargetBlockAddress(BA->getBlockAddress(),
58634                                               BA->getValueType(0)));
58635     } else {
58636       int64_t Offset = 0;
58637       if (Op->getOpcode() == ISD::ADD &&
58638           isa<ConstantSDNode>(Op->getOperand(1))) {
58639         Offset = cast<ConstantSDNode>(Op->getOperand(1))->getSExtValue();
58640         Op = Op->getOperand(0);
58641       }
58642       if (const auto *GA = dyn_cast<GlobalAddressSDNode>(Op))
58643         Ops.push_back(DAG.getTargetGlobalAddress(GA->getGlobal(), SDLoc(Op),
58644                                                  GA->getValueType(0), Offset));
58645     }
58646     return;
58647   }
58648   case 'Z': {
58649     // 32-bit unsigned value
58650     if (auto *C = dyn_cast<ConstantSDNode>(Op)) {
58651       if (ConstantInt::isValueValidForType(Type::getInt32Ty(*DAG.getContext()),
58652                                            C->getZExtValue())) {
58653         Result = DAG.getTargetConstant(C->getZExtValue(), SDLoc(Op),
58654                                        Op.getValueType());
58655         break;
58656       }
58657     }
58658     // FIXME gcc accepts some relocatable values here too, but only in certain
58659     // memory models; it's complicated.
58660     return;
58661   }
58662   case 'i': {
58663     // Literal immediates are always ok.
58664     if (auto *CST = dyn_cast<ConstantSDNode>(Op)) {
58665       bool IsBool = CST->getConstantIntValue()->getBitWidth() == 1;
58666       BooleanContent BCont = getBooleanContents(MVT::i64);
58667       ISD::NodeType ExtOpc = IsBool ? getExtendForContent(BCont)
58668                                     : ISD::SIGN_EXTEND;
58669       int64_t ExtVal = ExtOpc == ISD::ZERO_EXTEND ? CST->getZExtValue()
58670                                                   : CST->getSExtValue();
58671       Result = DAG.getTargetConstant(ExtVal, SDLoc(Op), MVT::i64);
58672       break;
58673     }
58674 
58675     // In any sort of PIC mode addresses need to be computed at runtime by
58676     // adding in a register or some sort of table lookup.  These can't
58677     // be used as immediates. BlockAddresses and BasicBlocks are fine though.
58678     if ((Subtarget.isPICStyleGOT() || Subtarget.isPICStyleStubPIC()) &&
58679         !(isa<BlockAddressSDNode>(Op) || isa<BasicBlockSDNode>(Op)))
58680       return;
58681 
58682     // If we are in non-pic codegen mode, we allow the address of a global (with
58683     // an optional displacement) to be used with 'i'.
58684     if (auto *GA = dyn_cast<GlobalAddressSDNode>(Op))
58685       // If we require an extra load to get this address, as in PIC mode, we
58686       // can't accept it.
58687       if (isGlobalStubReference(
58688               Subtarget.classifyGlobalReference(GA->getGlobal())))
58689         return;
58690     break;
58691   }
58692   }
58693 
58694   if (Result.getNode()) {
58695     Ops.push_back(Result);
58696     return;
58697   }
58698   return TargetLowering::LowerAsmOperandForConstraint(Op, Constraint, Ops, DAG);
58699 }
58700 
58701 /// Check if \p RC is a general purpose register class.
58702 /// I.e., GR* or one of their variant.
58703 static bool isGRClass(const TargetRegisterClass &RC) {
58704   return RC.hasSuperClassEq(&X86::GR8RegClass) ||
58705          RC.hasSuperClassEq(&X86::GR16RegClass) ||
58706          RC.hasSuperClassEq(&X86::GR32RegClass) ||
58707          RC.hasSuperClassEq(&X86::GR64RegClass) ||
58708          RC.hasSuperClassEq(&X86::LOW32_ADDR_ACCESS_RBPRegClass);
58709 }
58710 
58711 /// Check if \p RC is a vector register class.
58712 /// I.e., FR* / VR* or one of their variant.
58713 static bool isFRClass(const TargetRegisterClass &RC) {
58714   return RC.hasSuperClassEq(&X86::FR16XRegClass) ||
58715          RC.hasSuperClassEq(&X86::FR32XRegClass) ||
58716          RC.hasSuperClassEq(&X86::FR64XRegClass) ||
58717          RC.hasSuperClassEq(&X86::VR128XRegClass) ||
58718          RC.hasSuperClassEq(&X86::VR256XRegClass) ||
58719          RC.hasSuperClassEq(&X86::VR512RegClass);
58720 }
58721 
58722 /// Check if \p RC is a mask register class.
58723 /// I.e., VK* or one of their variant.
58724 static bool isVKClass(const TargetRegisterClass &RC) {
58725   return RC.hasSuperClassEq(&X86::VK1RegClass) ||
58726          RC.hasSuperClassEq(&X86::VK2RegClass) ||
58727          RC.hasSuperClassEq(&X86::VK4RegClass) ||
58728          RC.hasSuperClassEq(&X86::VK8RegClass) ||
58729          RC.hasSuperClassEq(&X86::VK16RegClass) ||
58730          RC.hasSuperClassEq(&X86::VK32RegClass) ||
58731          RC.hasSuperClassEq(&X86::VK64RegClass);
58732 }
58733 
58734 static bool useEGPRInlineAsm(const X86Subtarget &Subtarget) {
58735   return Subtarget.hasEGPR() && Subtarget.useInlineAsmGPR32();
58736 }
58737 
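// For example (illustrative): constraint 'r' with MVT::i32 resolves to GR32
// when EGPR use in inline asm is enabled, otherwise to GR32_NOREX2; an
// explicit register name such as "{xmm0}" falls through to the generic
// TargetLowering lookup and is then corrected below to a vector register
// class matching the requested type.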
58738 std::pair<unsigned, const TargetRegisterClass *>
58739 X86TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
58740                                                 StringRef Constraint,
58741                                                 MVT VT) const {
58742   // First, see if this is a constraint that directly corresponds to an LLVM
58743   // register class.
58744   if (Constraint.size() == 1) {
58745     // GCC Constraint Letters
58746     switch (Constraint[0]) {
58747     default: break;
58748     // 'A' means [ER]AX + [ER]DX.
58749     case 'A':
58750       if (Subtarget.is64Bit())
58751         return std::make_pair(X86::RAX, &X86::GR64_ADRegClass);
58752       assert((Subtarget.is32Bit() || Subtarget.is16Bit()) &&
58753              "Expecting 64, 32 or 16 bit subtarget");
58754       return std::make_pair(X86::EAX, &X86::GR32_ADRegClass);
58755 
58756       // TODO: Slight differences here in allocation order and leaving
58757       // RIP in the class. Do they matter any more here than they do
58758       // in the normal allocation?
58759     case 'k':
58760       if (Subtarget.hasAVX512()) {
58761         if (VT == MVT::v1i1 || VT == MVT::i1)
58762           return std::make_pair(0U, &X86::VK1RegClass);
58763         if (VT == MVT::v8i1 || VT == MVT::i8)
58764           return std::make_pair(0U, &X86::VK8RegClass);
58765         if (VT == MVT::v16i1 || VT == MVT::i16)
58766           return std::make_pair(0U, &X86::VK16RegClass);
58767       }
58768       if (Subtarget.hasBWI()) {
58769         if (VT == MVT::v32i1 || VT == MVT::i32)
58770           return std::make_pair(0U, &X86::VK32RegClass);
58771         if (VT == MVT::v64i1 || VT == MVT::i64)
58772           return std::make_pair(0U, &X86::VK64RegClass);
58773       }
58774       break;
58775     case 'q':   // GENERAL_REGS in 64-bit mode, Q_REGS in 32-bit mode.
58776       if (Subtarget.is64Bit()) {
58777         if (VT == MVT::i8 || VT == MVT::i1)
58778           return std::make_pair(0U, useEGPRInlineAsm(Subtarget)
58779                                         ? &X86::GR8RegClass
58780                                         : &X86::GR8_NOREX2RegClass);
58781         if (VT == MVT::i16)
58782           return std::make_pair(0U, useEGPRInlineAsm(Subtarget)
58783                                         ? &X86::GR16RegClass
58784                                         : &X86::GR16_NOREX2RegClass);
58785         if (VT == MVT::i32 || VT == MVT::f32)
58786           return std::make_pair(0U, useEGPRInlineAsm(Subtarget)
58787                                         ? &X86::GR32RegClass
58788                                         : &X86::GR32_NOREX2RegClass);
58789         if (VT != MVT::f80 && !VT.isVector())
58790           return std::make_pair(0U, useEGPRInlineAsm(Subtarget)
58791                                         ? &X86::GR64RegClass
58792                                         : &X86::GR64_NOREX2RegClass);
58793         break;
58794       }
58795       [[fallthrough]];
58796       // 32-bit fallthrough
58797     case 'Q':   // Q_REGS
58798       if (VT == MVT::i8 || VT == MVT::i1)
58799         return std::make_pair(0U, &X86::GR8_ABCD_LRegClass);
58800       if (VT == MVT::i16)
58801         return std::make_pair(0U, &X86::GR16_ABCDRegClass);
58802       if (VT == MVT::i32 || VT == MVT::f32 ||
58803           (!VT.isVector() && !Subtarget.is64Bit()))
58804         return std::make_pair(0U, &X86::GR32_ABCDRegClass);
58805       if (VT != MVT::f80 && !VT.isVector())
58806         return std::make_pair(0U, &X86::GR64_ABCDRegClass);
58807       break;
58808     case 'r':   // GENERAL_REGS
58809     case 'l':   // INDEX_REGS
58810       if (VT == MVT::i8 || VT == MVT::i1)
58811         return std::make_pair(0U, useEGPRInlineAsm(Subtarget)
58812                                       ? &X86::GR8RegClass
58813                                       : &X86::GR8_NOREX2RegClass);
58814       if (VT == MVT::i16)
58815         return std::make_pair(0U, useEGPRInlineAsm(Subtarget)
58816                                       ? &X86::GR16RegClass
58817                                       : &X86::GR16_NOREX2RegClass);
58818       if (VT == MVT::i32 || VT == MVT::f32 ||
58819           (!VT.isVector() && !Subtarget.is64Bit()))
58820         return std::make_pair(0U, useEGPRInlineAsm(Subtarget)
58821                                       ? &X86::GR32RegClass
58822                                       : &X86::GR32_NOREX2RegClass);
58823       if (VT != MVT::f80 && !VT.isVector())
58824         return std::make_pair(0U, useEGPRInlineAsm(Subtarget)
58825                                       ? &X86::GR64RegClass
58826                                       : &X86::GR64_NOREX2RegClass);
58827       break;
58828     case 'R':   // LEGACY_REGS
58829       if (VT == MVT::i8 || VT == MVT::i1)
58830         return std::make_pair(0U, &X86::GR8_NOREXRegClass);
58831       if (VT == MVT::i16)
58832         return std::make_pair(0U, &X86::GR16_NOREXRegClass);
58833       if (VT == MVT::i32 || VT == MVT::f32 ||
58834           (!VT.isVector() && !Subtarget.is64Bit()))
58835         return std::make_pair(0U, &X86::GR32_NOREXRegClass);
58836       if (VT != MVT::f80 && !VT.isVector())
58837         return std::make_pair(0U, &X86::GR64_NOREXRegClass);
58838       break;
58839     case 'f':  // FP Stack registers.
58840       // If SSE is enabled for this VT, use f80 to ensure the isel moves the
58841       // value to the correct fpstack register class.
58842       if (VT == MVT::f32 && !isScalarFPTypeInSSEReg(VT))
58843         return std::make_pair(0U, &X86::RFP32RegClass);
58844       if (VT == MVT::f64 && !isScalarFPTypeInSSEReg(VT))
58845         return std::make_pair(0U, &X86::RFP64RegClass);
58846       if (VT == MVT::f32 || VT == MVT::f64 || VT == MVT::f80)
58847         return std::make_pair(0U, &X86::RFP80RegClass);
58848       break;
58849     case 'y':   // MMX_REGS if MMX allowed.
58850       if (!Subtarget.hasMMX()) break;
58851       return std::make_pair(0U, &X86::VR64RegClass);
58852     case 'v':
58853     case 'x':   // SSE_REGS if SSE1 allowed or AVX_REGS if AVX allowed
58854       if (!Subtarget.hasSSE1()) break;
58855       bool VConstraint = (Constraint[0] == 'v');
58856 
58857       switch (VT.SimpleTy) {
58858       default: break;
58859       // Scalar SSE types.
58860       case MVT::f16:
58861         if (VConstraint && Subtarget.hasFP16())
58862           return std::make_pair(0U, &X86::FR16XRegClass);
58863         break;
58864       case MVT::f32:
58865       case MVT::i32:
58866         if (VConstraint && Subtarget.hasVLX())
58867           return std::make_pair(0U, &X86::FR32XRegClass);
58868         return std::make_pair(0U, &X86::FR32RegClass);
58869       case MVT::f64:
58870       case MVT::i64:
58871         if (VConstraint && Subtarget.hasVLX())
58872           return std::make_pair(0U, &X86::FR64XRegClass);
58873         return std::make_pair(0U, &X86::FR64RegClass);
58874       case MVT::i128:
58875         if (Subtarget.is64Bit()) {
58876           if (VConstraint && Subtarget.hasVLX())
58877             return std::make_pair(0U, &X86::VR128XRegClass);
58878           return std::make_pair(0U, &X86::VR128RegClass);
58879         }
58880         break;
58881       // Vector types and fp128.
58882       case MVT::v8f16:
58883         if (!Subtarget.hasFP16())
58884           break;
58885         if (VConstraint)
58886           return std::make_pair(0U, &X86::VR128XRegClass);
58887         return std::make_pair(0U, &X86::VR128RegClass);
58888       case MVT::v8bf16:
58889         if (!Subtarget.hasBF16() || !Subtarget.hasVLX())
58890           break;
58891         if (VConstraint)
58892           return std::make_pair(0U, &X86::VR128XRegClass);
58893         return std::make_pair(0U, &X86::VR128RegClass);
58894       case MVT::f128:
58895       case MVT::v16i8:
58896       case MVT::v8i16:
58897       case MVT::v4i32:
58898       case MVT::v2i64:
58899       case MVT::v4f32:
58900       case MVT::v2f64:
58901         if (VConstraint && Subtarget.hasVLX())
58902           return std::make_pair(0U, &X86::VR128XRegClass);
58903         return std::make_pair(0U, &X86::VR128RegClass);
58904       // AVX types.
58905       case MVT::v16f16:
58906         if (!Subtarget.hasFP16())
58907           break;
58908         if (VConstraint)
58909           return std::make_pair(0U, &X86::VR256XRegClass);
58910         return std::make_pair(0U, &X86::VR256RegClass);
58911       case MVT::v16bf16:
58912         if (!Subtarget.hasBF16() || !Subtarget.hasVLX())
58913           break;
58914         if (VConstraint)
58915           return std::make_pair(0U, &X86::VR256XRegClass);
58916         return std::make_pair(0U, &X86::VR256RegClass);
58917       case MVT::v32i8:
58918       case MVT::v16i16:
58919       case MVT::v8i32:
58920       case MVT::v4i64:
58921       case MVT::v8f32:
58922       case MVT::v4f64:
58923         if (VConstraint && Subtarget.hasVLX())
58924           return std::make_pair(0U, &X86::VR256XRegClass);
58925         if (Subtarget.hasAVX())
58926           return std::make_pair(0U, &X86::VR256RegClass);
58927         break;
58928       case MVT::v32f16:
58929         if (!Subtarget.hasFP16())
58930           break;
58931         if (VConstraint)
58932           return std::make_pair(0U, &X86::VR512RegClass);
58933         return std::make_pair(0U, &X86::VR512_0_15RegClass);
58934       case MVT::v32bf16:
58935         if (!Subtarget.hasBF16())
58936           break;
58937         if (VConstraint)
58938           return std::make_pair(0U, &X86::VR512RegClass);
58939         return std::make_pair(0U, &X86::VR512_0_15RegClass);
58940       case MVT::v64i8:
58941       case MVT::v32i16:
58942       case MVT::v8f64:
58943       case MVT::v16f32:
58944       case MVT::v16i32:
58945       case MVT::v8i64:
58946         if (!Subtarget.hasAVX512()) break;
58947         if (VConstraint)
58948           return std::make_pair(0U, &X86::VR512RegClass);
58949         return std::make_pair(0U, &X86::VR512_0_15RegClass);
58950       }
58951       break;
58952     }
58953   } else if (Constraint.size() == 2 && Constraint[0] == 'Y') {
58954     switch (Constraint[1]) {
58955     default:
58956       break;
58957     case 'i':
58958     case 't':
58959     case '2':
58960       return getRegForInlineAsmConstraint(TRI, "x", VT);
58961     case 'm':
58962       if (!Subtarget.hasMMX()) break;
58963       return std::make_pair(0U, &X86::VR64RegClass);
58964     case 'z':
58965       if (!Subtarget.hasSSE1()) break;
58966       switch (VT.SimpleTy) {
58967       default: break;
58968       // Scalar SSE types.
58969       case MVT::f16:
58970         if (!Subtarget.hasFP16())
58971           break;
58972         return std::make_pair(X86::XMM0, &X86::FR16XRegClass);
58973       case MVT::f32:
58974       case MVT::i32:
58975         return std::make_pair(X86::XMM0, &X86::FR32RegClass);
58976       case MVT::f64:
58977       case MVT::i64:
58978         return std::make_pair(X86::XMM0, &X86::FR64RegClass);
58979       case MVT::v8f16:
58980         if (!Subtarget.hasFP16())
58981           break;
58982         return std::make_pair(X86::XMM0, &X86::VR128RegClass);
58983       case MVT::v8bf16:
58984         if (!Subtarget.hasBF16() || !Subtarget.hasVLX())
58985           break;
58986         return std::make_pair(X86::XMM0, &X86::VR128RegClass);
58987       case MVT::f128:
58988       case MVT::v16i8:
58989       case MVT::v8i16:
58990       case MVT::v4i32:
58991       case MVT::v2i64:
58992       case MVT::v4f32:
58993       case MVT::v2f64:
58994         return std::make_pair(X86::XMM0, &X86::VR128RegClass);
58995       // AVX types.
58996       case MVT::v16f16:
58997         if (!Subtarget.hasFP16())
58998           break;
58999         return std::make_pair(X86::YMM0, &X86::VR256RegClass);
59000       case MVT::v16bf16:
59001         if (!Subtarget.hasBF16() || !Subtarget.hasVLX())
59002           break;
59003         return std::make_pair(X86::YMM0, &X86::VR256RegClass);
59004       case MVT::v32i8:
59005       case MVT::v16i16:
59006       case MVT::v8i32:
59007       case MVT::v4i64:
59008       case MVT::v8f32:
59009       case MVT::v4f64:
59010         if (Subtarget.hasAVX())
59011           return std::make_pair(X86::YMM0, &X86::VR256RegClass);
59012         break;
59013       case MVT::v32f16:
59014         if (!Subtarget.hasFP16())
59015           break;
59016         return std::make_pair(X86::ZMM0, &X86::VR512_0_15RegClass);
59017       case MVT::v32bf16:
59018         if (!Subtarget.hasBF16())
59019           break;
59020         return std::make_pair(X86::ZMM0, &X86::VR512_0_15RegClass);
59021       case MVT::v64i8:
59022       case MVT::v32i16:
59023       case MVT::v8f64:
59024       case MVT::v16f32:
59025       case MVT::v16i32:
59026       case MVT::v8i64:
59027         if (Subtarget.hasAVX512())
59028           return std::make_pair(X86::ZMM0, &X86::VR512_0_15RegClass);
59029         break;
59030       }
59031       break;
59032     case 'k':
59033       // This register class doesn't allocate k0 for masked vector operations.
59034       if (Subtarget.hasAVX512()) {
59035         if (VT == MVT::v1i1 || VT == MVT::i1)
59036           return std::make_pair(0U, &X86::VK1WMRegClass);
59037         if (VT == MVT::v8i1 || VT == MVT::i8)
59038           return std::make_pair(0U, &X86::VK8WMRegClass);
59039         if (VT == MVT::v16i1 || VT == MVT::i16)
59040           return std::make_pair(0U, &X86::VK16WMRegClass);
59041       }
59042       if (Subtarget.hasBWI()) {
59043         if (VT == MVT::v32i1 || VT == MVT::i32)
59044           return std::make_pair(0U, &X86::VK32WMRegClass);
59045         if (VT == MVT::v64i1 || VT == MVT::i64)
59046           return std::make_pair(0U, &X86::VK64WMRegClass);
59047       }
59048       break;
59049     }
59050   } else if (Constraint.size() == 2 && Constraint[0] == 'j') {
59051     switch (Constraint[1]) {
59052     default:
59053       break;
59054     case 'r':
59055       if (VT == MVT::i8 || VT == MVT::i1)
59056         return std::make_pair(0U, &X86::GR8_NOREX2RegClass);
59057       if (VT == MVT::i16)
59058         return std::make_pair(0U, &X86::GR16_NOREX2RegClass);
59059       if (VT == MVT::i32 || VT == MVT::f32)
59060         return std::make_pair(0U, &X86::GR32_NOREX2RegClass);
59061       if (VT != MVT::f80 && !VT.isVector())
59062         return std::make_pair(0U, &X86::GR64_NOREX2RegClass);
59063       break;
59064     case 'R':
59065       if (VT == MVT::i8 || VT == MVT::i1)
59066         return std::make_pair(0U, &X86::GR8RegClass);
59067       if (VT == MVT::i16)
59068         return std::make_pair(0U, &X86::GR16RegClass);
59069       if (VT == MVT::i32 || VT == MVT::f32)
59070         return std::make_pair(0U, &X86::GR32RegClass);
59071       if (VT != MVT::f80 && !VT.isVector())
59072         return std::make_pair(0U, &X86::GR64RegClass);
59073       break;
59074     }
59075   }
59076 
59077   if (parseConstraintCode(Constraint) != X86::COND_INVALID)
59078     return std::make_pair(0U, &X86::GR32RegClass);
59079 
59080   // Use the default implementation in TargetLowering to convert the register
59081   // constraint into a member of a register class.
59082   std::pair<Register, const TargetRegisterClass*> Res;
59083   Res = TargetLowering::getRegForInlineAsmConstraint(TRI, Constraint, VT);
59084 
59085   // Not found as a standard register?
59086   if (!Res.second) {
59087     // Only match x87 registers if the VT is one SelectionDAGBuilder can convert
59088     // to/from f80.
59089     if (VT == MVT::Other || VT == MVT::f32 || VT == MVT::f64 || VT == MVT::f80) {
59090       // Map st(0) -> st(7) -> ST0
59091       if (Constraint.size() == 7 && Constraint[0] == '{' &&
59092           tolower(Constraint[1]) == 's' && tolower(Constraint[2]) == 't' &&
59093           Constraint[3] == '(' &&
59094           (Constraint[4] >= '0' && Constraint[4] <= '7') &&
59095           Constraint[5] == ')' && Constraint[6] == '}') {
59096         // st(7) is not allocatable and thus not a member of RFP80. Return
59097         // singleton class in cases where we have a reference to it.
59098         if (Constraint[4] == '7')
59099           return std::make_pair(X86::FP7, &X86::RFP80_7RegClass);
59100         return std::make_pair(X86::FP0 + Constraint[4] - '0',
59101                               &X86::RFP80RegClass);
59102       }
59103 
59104       // GCC allows "st(0)" to be called just plain "st".
59105       if (StringRef("{st}").equals_insensitive(Constraint))
59106         return std::make_pair(X86::FP0, &X86::RFP80RegClass);
59107     }
59108 
59109     // flags -> EFLAGS
59110     if (StringRef("{flags}").equals_insensitive(Constraint))
59111       return std::make_pair(X86::EFLAGS, &X86::CCRRegClass);
59112 
59113     // dirflag -> DF
59114     // Only allow for clobber.
59115     if (StringRef("{dirflag}").equals_insensitive(Constraint) &&
59116         VT == MVT::Other)
59117       return std::make_pair(X86::DF, &X86::DFCCRRegClass);
59118 
59119     // fpsr -> FPSW
59120     // Only allow for clobber.
59121     if (StringRef("{fpsr}").equals_insensitive(Constraint) && VT == MVT::Other)
59122       return std::make_pair(X86::FPSW, &X86::FPCCRRegClass);
59123 
59124     return Res;
59125   }
59126 
59127   // Make sure it isn't a register that requires 64-bit mode.
59128   if (!Subtarget.is64Bit() &&
59129       (isFRClass(*Res.second) || isGRClass(*Res.second)) &&
59130       TRI->getEncodingValue(Res.first) >= 8) {
59131     // Register requires REX prefix, but we're in 32-bit mode.
59132     return std::make_pair(0, nullptr);
59133   }
59134 
59135   // Make sure it isn't a register that requires AVX512.
59136   if (!Subtarget.hasAVX512() && isFRClass(*Res.second) &&
59137       TRI->getEncodingValue(Res.first) & 0x10) {
59138     // Register requires EVEX prefix.
59139     return std::make_pair(0, nullptr);
59140   }
59141 
59142   // Otherwise, check to see if this is a register class of the wrong value
59143   // type.  For example, we want to map "{ax},i32" -> {eax}; we don't want it
59144   // to turn into {ax},{dx}.
59145   // MVT::Other is used to specify clobber names.
59146   if (TRI->isTypeLegalForClass(*Res.second, VT) || VT == MVT::Other)
59147     return Res;   // Correct type already, nothing to do.
59148 
59149   // Get a matching integer of the correct size. i.e. "ax" with MVT::i32 should
59150   // return "eax". This should even work for things like getting 64-bit integer
59151   // registers when given an f64 type.
59152   const TargetRegisterClass *Class = Res.second;
59153   // The generic code will match the first register class that contains the
59154   // given register. Thus, based on the ordering of the tablegened file,
59155   // the "plain" GR classes might not come first.
59156   // Therefore, use a helper method.
59157   if (isGRClass(*Class)) {
59158     unsigned Size = VT.getSizeInBits();
59159     if (Size == 1) Size = 8;
59160     if (Size != 8 && Size != 16 && Size != 32 && Size != 64)
59161       return std::make_pair(0, nullptr);
59162     Register DestReg = getX86SubSuperRegister(Res.first, Size);
59163     if (DestReg.isValid()) {
59164       bool is64Bit = Subtarget.is64Bit();
59165       const TargetRegisterClass *RC =
59166           Size == 8 ? (is64Bit ? &X86::GR8RegClass : &X86::GR8_NOREXRegClass)
59167         : Size == 16 ? (is64Bit ? &X86::GR16RegClass : &X86::GR16_NOREXRegClass)
59168         : Size == 32 ? (is64Bit ? &X86::GR32RegClass : &X86::GR32_NOREXRegClass)
59169         : /*Size == 64*/ (is64Bit ? &X86::GR64RegClass : nullptr);
59170       if (Size == 64 && !is64Bit) {
59171         // Model GCC's behavior here and select a fixed pair of 32-bit
59172         // registers.
59173         switch (DestReg) {
59174         case X86::RAX:
59175           return std::make_pair(X86::EAX, &X86::GR32_ADRegClass);
59176         case X86::RDX:
59177           return std::make_pair(X86::EDX, &X86::GR32_DCRegClass);
59178         case X86::RCX:
59179           return std::make_pair(X86::ECX, &X86::GR32_CBRegClass);
59180         case X86::RBX:
59181           return std::make_pair(X86::EBX, &X86::GR32_BSIRegClass);
59182         case X86::RSI:
59183           return std::make_pair(X86::ESI, &X86::GR32_SIDIRegClass);
59184         case X86::RDI:
59185           return std::make_pair(X86::EDI, &X86::GR32_DIBPRegClass);
59186         case X86::RBP:
59187           return std::make_pair(X86::EBP, &X86::GR32_BPSPRegClass);
59188         default:
59189           return std::make_pair(0, nullptr);
59190         }
59191       }
59192       if (RC && RC->contains(DestReg))
59193         return std::make_pair(DestReg, RC);
59194       return Res;
59195     }
59196     // No register found/type mismatch.
59197     return std::make_pair(0, nullptr);
59198   } else if (isFRClass(*Class)) {
59199     // Handle references to XMM physical registers that got mapped into the
59200     // wrong class.  This can happen with constraints like {xmm0} where the
59201     // target independent register mapper will just pick the first match it can
59202     // find, ignoring the required type.
59203 
59204     // TODO: Handle f128 and i128 in FR128RegClass after it is tested well.
59205     if (VT == MVT::f16)
59206       Res.second = &X86::FR16XRegClass;
59207     else if (VT == MVT::f32 || VT == MVT::i32)
59208       Res.second = &X86::FR32XRegClass;
59209     else if (VT == MVT::f64 || VT == MVT::i64)
59210       Res.second = &X86::FR64XRegClass;
59211     else if (TRI->isTypeLegalForClass(X86::VR128XRegClass, VT))
59212       Res.second = &X86::VR128XRegClass;
59213     else if (TRI->isTypeLegalForClass(X86::VR256XRegClass, VT))
59214       Res.second = &X86::VR256XRegClass;
59215     else if (TRI->isTypeLegalForClass(X86::VR512RegClass, VT))
59216       Res.second = &X86::VR512RegClass;
59217     else {
59218       // Type mismatch and not a clobber: Return an error;
59219       Res.first = 0;
59220       Res.second = nullptr;
59221     }
59222   } else if (isVKClass(*Class)) {
59223     if (VT == MVT::v1i1 || VT == MVT::i1)
59224       Res.second = &X86::VK1RegClass;
59225     else if (VT == MVT::v8i1 || VT == MVT::i8)
59226       Res.second = &X86::VK8RegClass;
59227     else if (VT == MVT::v16i1 || VT == MVT::i16)
59228       Res.second = &X86::VK16RegClass;
59229     else if (VT == MVT::v32i1 || VT == MVT::i32)
59230       Res.second = &X86::VK32RegClass;
59231     else if (VT == MVT::v64i1 || VT == MVT::i64)
59232       Res.second = &X86::VK64RegClass;
59233     else {
59234       // Type mismatch and not a clobber: Return an error;
59235       Res.first = 0;
59236       Res.second = nullptr;
59237     }
59238   }
59239 
59240   return Res;
59241 }
59242 
59243 bool X86TargetLowering::isIntDivCheap(EVT VT, AttributeList Attr) const {
59244   // Integer division on x86 is expensive. However, when aggressively optimizing
59245   // for code size, we prefer to use a div instruction, as it is usually smaller
59246   // than the alternative sequence.
59247   // The exception to this is vector division. Since x86 doesn't have vector
59248   // integer division, leaving the division as-is is a loss even in terms of
59249   // size, because it will have to be scalarized, while the alternative code
59250   // sequence can be performed in vector form.
59251   bool OptSize = Attr.hasFnAttr(Attribute::MinSize);
59252   return OptSize && !VT.isVector();
59253 }
59254 
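// Split-CSR support (used e.g. by the CXX_FAST_TLS calling convention):
// instead of spilling callee-saved registers in the prologue/epilogue, the
// two hooks below copy them into virtual registers in the entry block and
// copy them back before the terminators of the exit blocks.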
59255 void X86TargetLowering::initializeSplitCSR(MachineBasicBlock *Entry) const {
59256   if (!Subtarget.is64Bit())
59257     return;
59258 
59259   // Update IsSplitCSR in X86MachineFunctionInfo.
59260   X86MachineFunctionInfo *AFI =
59261       Entry->getParent()->getInfo<X86MachineFunctionInfo>();
59262   AFI->setIsSplitCSR(true);
59263 }
59264 
59265 void X86TargetLowering::insertCopiesSplitCSR(
59266     MachineBasicBlock *Entry,
59267     const SmallVectorImpl<MachineBasicBlock *> &Exits) const {
59268   const X86RegisterInfo *TRI = Subtarget.getRegisterInfo();
59269   const MCPhysReg *IStart = TRI->getCalleeSavedRegsViaCopy(Entry->getParent());
59270   if (!IStart)
59271     return;
59272 
59273   const TargetInstrInfo *TII = Subtarget.getInstrInfo();
59274   MachineRegisterInfo *MRI = &Entry->getParent()->getRegInfo();
59275   MachineBasicBlock::iterator MBBI = Entry->begin();
59276   for (const MCPhysReg *I = IStart; *I; ++I) {
59277     const TargetRegisterClass *RC = nullptr;
59278     if (X86::GR64RegClass.contains(*I))
59279       RC = &X86::GR64RegClass;
59280     else
59281       llvm_unreachable("Unexpected register class in CSRsViaCopy!");
59282 
59283     Register NewVR = MRI->createVirtualRegister(RC);
59284     // Create copy from CSR to a virtual register.
59285     // FIXME: this currently does not emit CFI pseudo-instructions, it works
59286     // fine for CXX_FAST_TLS since the C++-style TLS access functions should be
59287     // nounwind. If we want to generalize this later, we may need to emit
59288     // CFI pseudo-instructions.
59289     assert(
59290         Entry->getParent()->getFunction().hasFnAttribute(Attribute::NoUnwind) &&
59291         "Function should be nounwind in insertCopiesSplitCSR!");
59292     Entry->addLiveIn(*I);
59293     BuildMI(*Entry, MBBI, MIMetadata(), TII->get(TargetOpcode::COPY), NewVR)
59294         .addReg(*I);
59295 
59296     // Insert the copy-back instructions right before the terminator.
59297     for (auto *Exit : Exits)
59298       BuildMI(*Exit, Exit->getFirstTerminator(), MIMetadata(),
59299               TII->get(TargetOpcode::COPY), *I)
59300           .addReg(NewVR);
59301   }
59302 }
59303 
59304 bool X86TargetLowering::supportSwiftError() const {
59305   return Subtarget.is64Bit();
59306 }
59307 
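// Emit the KCFI type check that guards an indirect call: an X86::KCFI_CHECK
// pseudo taking the call-target register and the expected type id, e.g.
// (illustrative MIR) "KCFI_CHECK $r11, <type-id>", inserted immediately
// before the call instruction.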
59308 MachineInstr *
59309 X86TargetLowering::EmitKCFICheck(MachineBasicBlock &MBB,
59310                                  MachineBasicBlock::instr_iterator &MBBI,
59311                                  const TargetInstrInfo *TII) const {
59312   assert(MBBI->isCall() && MBBI->getCFIType() &&
59313          "Invalid call instruction for a KCFI check");
59314 
59315   MachineFunction &MF = *MBB.getParent();
59316   // If the call target is a memory operand, unfold it and use R11 for the
59317   // call, so KCFI_CHECK won't have to recompute the address.
59318   switch (MBBI->getOpcode()) {
59319   case X86::CALL64m:
59320   case X86::CALL64m_NT:
59321   case X86::TAILJMPm64:
59322   case X86::TAILJMPm64_REX: {
59323     MachineBasicBlock::instr_iterator OrigCall = MBBI;
59324     SmallVector<MachineInstr *, 2> NewMIs;
59325     if (!TII->unfoldMemoryOperand(MF, *OrigCall, X86::R11, /*UnfoldLoad=*/true,
59326                                   /*UnfoldStore=*/false, NewMIs))
59327       report_fatal_error("Failed to unfold memory operand for a KCFI check");
59328     for (auto *NewMI : NewMIs)
59329       MBBI = MBB.insert(OrigCall, NewMI);
59330     assert(MBBI->isCall() &&
59331            "Unexpected instruction after memory operand unfolding");
59332     if (OrigCall->shouldUpdateCallSiteInfo())
59333       MF.moveCallSiteInfo(&*OrigCall, &*MBBI);
59334     MBBI->setCFIType(MF, OrigCall->getCFIType());
59335     OrigCall->eraseFromParent();
59336     break;
59337   }
59338   default:
59339     break;
59340   }
59341 
59342   MachineOperand &Target = MBBI->getOperand(0);
59343   Register TargetReg;
59344   switch (MBBI->getOpcode()) {
59345   case X86::CALL64r:
59346   case X86::CALL64r_NT:
59347   case X86::TAILJMPr64:
59348   case X86::TAILJMPr64_REX:
59349     assert(Target.isReg() && "Unexpected target operand for an indirect call");
59350     Target.setIsRenamable(false);
59351     TargetReg = Target.getReg();
59352     break;
59353   case X86::CALL64pcrel32:
59354   case X86::TAILJMPd64:
59355     assert(Target.isSymbol() && "Unexpected target operand for a direct call");
59356     // X86TargetLowering::EmitLoweredIndirectThunk always uses r11 for
59357     // 64-bit indirect thunk calls.
59358     assert(StringRef(Target.getSymbolName()).ends_with("_r11") &&
59359            "Unexpected register for an indirect thunk call");
59360     TargetReg = X86::R11;
59361     break;
59362   default:
59363     llvm_unreachable("Unexpected CFI call opcode");
59364     break;
59365   }
59366 
59367   return BuildMI(MBB, MBBI, MIMetadata(*MBBI), TII->get(X86::KCFI_CHECK))
59368       .addReg(TargetReg)
59369       .addImm(MBBI->getCFIType())
59370       .getInstr();
59371 }
59372 
59373 /// Returns true if stack probing through a function call is requested.
59374 bool X86TargetLowering::hasStackProbeSymbol(const MachineFunction &MF) const {
59375   return !getStackProbeSymbolName(MF).empty();
59376 }
59377 
59378 /// Returns true if stack probing through inline assembly is requested.
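/// For example (illustrative IR): a function attribute of
///   "probe-stack"="inline-asm"
/// selects inline probing here; any other value names a probe function and
/// is returned by getStackProbeSymbolName below instead.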
59379 bool X86TargetLowering::hasInlineStackProbe(const MachineFunction &MF) const {
59380 
59381   // No inline stack probe for Windows, they have their own mechanism.
59382   // No inline stack probes on Windows; it has its own mechanism.
59383       MF.getFunction().hasFnAttribute("no-stack-arg-probe"))
59384     return false;
59385 
59386   // If the function specifically requests inline stack probes, emit them.
59387   if (MF.getFunction().hasFnAttribute("probe-stack"))
59388     return MF.getFunction().getFnAttribute("probe-stack").getValueAsString() ==
59389            "inline-asm";
59390 
59391   return false;
59392 }
59393 
59394 /// Returns the name of the symbol used to emit stack probes or the empty
59395 /// string if not applicable.
59396 StringRef
59397 X86TargetLowering::getStackProbeSymbolName(const MachineFunction &MF) const {
59398   // Inline Stack probes disable stack probe call
59399   // Inline stack probes disable the stack probe call.
59400     return "";
59401 
59402   // If the function specifically requests stack probes, emit them.
59403   if (MF.getFunction().hasFnAttribute("probe-stack"))
59404     return MF.getFunction().getFnAttribute("probe-stack").getValueAsString();
59405 
59406   // Generally, if we aren't on Windows, the platform ABI does not include
59407   // support for stack probes, so don't emit them.
59408   if (!Subtarget.isOSWindows() || Subtarget.isTargetMachO() ||
59409       MF.getFunction().hasFnAttribute("no-stack-arg-probe"))
59410     return "";
59411 
59412   // We need a stack probe to conform to the Windows ABI. Choose the right
59413   // symbol.
59414   if (Subtarget.is64Bit())
59415     return Subtarget.isTargetCygMing() ? "___chkstk_ms" : "__chkstk";
59416   return Subtarget.isTargetCygMing() ? "_alloca" : "_chkstk";
59417 }
59418 
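// For example (illustrative): a function carrying "stack-probe-size"="8192"
// is probed in 8 KiB steps; without the attribute, the 4096-byte default
// below is used.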
59419 unsigned
59420 X86TargetLowering::getStackProbeSize(const MachineFunction &MF) const {
59421   // The default stack probe size is 4096 if the function has no stackprobesize
59422   // attribute.
59423   return MF.getFunction().getFnAttributeAsParsedInteger("stack-probe-size",
59424                                                         4096);
59425 }
59426 
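// For example (illustrative): setting ExperimentalPrefInnermostLoopAlignment
// to 5 on the command line aligns innermost loops to 1 << 5 = 32 bytes;
// other loops keep the default preferred loop alignment.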
59427 Align X86TargetLowering::getPrefLoopAlignment(MachineLoop *ML) const {
59428   if (ML && ML->isInnermost() &&
59429       ExperimentalPrefInnermostLoopAlignment.getNumOccurrences())
59430     return Align(1ULL << ExperimentalPrefInnermostLoopAlignment);
59431   return TargetLowering::getPrefLoopAlignment();
59432 }
59433