1 //===-- TargetLowering.cpp - Implement the TargetLowering class -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This implements the TargetLowering class.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "llvm/CodeGen/TargetLowering.h"
14 #include "llvm/ADT/STLExtras.h"
15 #include "llvm/Analysis/ValueTracking.h"
16 #include "llvm/Analysis/VectorUtils.h"
17 #include "llvm/CodeGen/Analysis.h"
18 #include "llvm/CodeGen/CallingConvLower.h"
19 #include "llvm/CodeGen/CodeGenCommonISel.h"
20 #include "llvm/CodeGen/MachineFrameInfo.h"
21 #include "llvm/CodeGen/MachineFunction.h"
22 #include "llvm/CodeGen/MachineJumpTableInfo.h"
23 #include "llvm/CodeGen/MachineRegisterInfo.h"
24 #include "llvm/CodeGen/SDPatternMatch.h"
25 #include "llvm/CodeGen/SelectionDAG.h"
26 #include "llvm/CodeGen/TargetRegisterInfo.h"
27 #include "llvm/IR/DataLayout.h"
28 #include "llvm/IR/DerivedTypes.h"
29 #include "llvm/IR/GlobalVariable.h"
30 #include "llvm/IR/LLVMContext.h"
31 #include "llvm/MC/MCAsmInfo.h"
32 #include "llvm/MC/MCExpr.h"
33 #include "llvm/Support/DivisionByConstantInfo.h"
34 #include "llvm/Support/ErrorHandling.h"
35 #include "llvm/Support/KnownBits.h"
36 #include "llvm/Support/MathExtras.h"
37 #include "llvm/Target/TargetMachine.h"
38 #include <cctype>
39 #include <deque>
40 using namespace llvm;
41 using namespace llvm::SDPatternMatch;
42
43 /// NOTE: The TargetMachine owns TLOF.
TargetLowering(const TargetMachine & tm)44 TargetLowering::TargetLowering(const TargetMachine &tm)
45 : TargetLoweringBase(tm) {}
46
47 // Define the virtual destructor out-of-line for build efficiency.
48 TargetLowering::~TargetLowering() = default;
49
getTargetNodeName(unsigned Opcode) const50 const char *TargetLowering::getTargetNodeName(unsigned Opcode) const {
51 return nullptr;
52 }
53
isPositionIndependent() const54 bool TargetLowering::isPositionIndependent() const {
55 return getTargetMachine().isPositionIndependent();
56 }
57
58 /// Check whether a given call node is in tail position within its function. If
59 /// so, it sets Chain to the input chain of the tail call.
isInTailCallPosition(SelectionDAG & DAG,SDNode * Node,SDValue & Chain) const60 bool TargetLowering::isInTailCallPosition(SelectionDAG &DAG, SDNode *Node,
61 SDValue &Chain) const {
62 const Function &F = DAG.getMachineFunction().getFunction();
63
64 // First, check if tail calls have been disabled in this function.
65 if (F.getFnAttribute("disable-tail-calls").getValueAsBool())
66 return false;
67
68 // Conservatively require the attributes of the call to match those of
69 // the return. Ignore following attributes because they don't affect the
70 // call sequence.
71 AttrBuilder CallerAttrs(F.getContext(), F.getAttributes().getRetAttrs());
72 for (const auto &Attr : {Attribute::Alignment, Attribute::Dereferenceable,
73 Attribute::DereferenceableOrNull, Attribute::NoAlias,
74 Attribute::NonNull, Attribute::NoUndef,
75 Attribute::Range, Attribute::NoFPClass})
76 CallerAttrs.removeAttribute(Attr);
77
78 if (CallerAttrs.hasAttributes())
79 return false;
80
81 // It's not safe to eliminate the sign / zero extension of the return value.
82 if (CallerAttrs.contains(Attribute::ZExt) ||
83 CallerAttrs.contains(Attribute::SExt))
84 return false;
85
86 // Check if the only use is a function return node.
87 return isUsedByReturnOnly(Node, Chain);
88 }
89
parametersInCSRMatch(const MachineRegisterInfo & MRI,const uint32_t * CallerPreservedMask,const SmallVectorImpl<CCValAssign> & ArgLocs,const SmallVectorImpl<SDValue> & OutVals) const90 bool TargetLowering::parametersInCSRMatch(const MachineRegisterInfo &MRI,
91 const uint32_t *CallerPreservedMask,
92 const SmallVectorImpl<CCValAssign> &ArgLocs,
93 const SmallVectorImpl<SDValue> &OutVals) const {
94 for (unsigned I = 0, E = ArgLocs.size(); I != E; ++I) {
95 const CCValAssign &ArgLoc = ArgLocs[I];
96 if (!ArgLoc.isRegLoc())
97 continue;
98 MCRegister Reg = ArgLoc.getLocReg();
99 // Only look at callee saved registers.
100 if (MachineOperand::clobbersPhysReg(CallerPreservedMask, Reg))
101 continue;
102 // Check that we pass the value used for the caller.
103 // (We look for a CopyFromReg reading a virtual register that is used
104 // for the function live-in value of register Reg)
105 SDValue Value = OutVals[I];
106 if (Value->getOpcode() == ISD::AssertZext)
107 Value = Value.getOperand(0);
108 if (Value->getOpcode() != ISD::CopyFromReg)
109 return false;
110 Register ArgReg = cast<RegisterSDNode>(Value->getOperand(1))->getReg();
111 if (MRI.getLiveInPhysReg(ArgReg) != Reg)
112 return false;
113 }
114 return true;
115 }
116
117 /// Set CallLoweringInfo attribute flags based on a call instruction
118 /// and called function attributes.
setAttributes(const CallBase * Call,unsigned ArgIdx)119 void TargetLoweringBase::ArgListEntry::setAttributes(const CallBase *Call,
120 unsigned ArgIdx) {
121 IsSExt = Call->paramHasAttr(ArgIdx, Attribute::SExt);
122 IsZExt = Call->paramHasAttr(ArgIdx, Attribute::ZExt);
123 IsNoExt = Call->paramHasAttr(ArgIdx, Attribute::NoExt);
124 IsInReg = Call->paramHasAttr(ArgIdx, Attribute::InReg);
125 IsSRet = Call->paramHasAttr(ArgIdx, Attribute::StructRet);
126 IsNest = Call->paramHasAttr(ArgIdx, Attribute::Nest);
127 IsByVal = Call->paramHasAttr(ArgIdx, Attribute::ByVal);
128 IsPreallocated = Call->paramHasAttr(ArgIdx, Attribute::Preallocated);
129 IsInAlloca = Call->paramHasAttr(ArgIdx, Attribute::InAlloca);
130 IsReturned = Call->paramHasAttr(ArgIdx, Attribute::Returned);
131 IsSwiftSelf = Call->paramHasAttr(ArgIdx, Attribute::SwiftSelf);
132 IsSwiftAsync = Call->paramHasAttr(ArgIdx, Attribute::SwiftAsync);
133 IsSwiftError = Call->paramHasAttr(ArgIdx, Attribute::SwiftError);
134 Alignment = Call->getParamStackAlign(ArgIdx);
135 IndirectType = nullptr;
136 assert(IsByVal + IsPreallocated + IsInAlloca + IsSRet <= 1 &&
137 "multiple ABI attributes?");
138 if (IsByVal) {
139 IndirectType = Call->getParamByValType(ArgIdx);
140 if (!Alignment)
141 Alignment = Call->getParamAlign(ArgIdx);
142 }
143 if (IsPreallocated)
144 IndirectType = Call->getParamPreallocatedType(ArgIdx);
145 if (IsInAlloca)
146 IndirectType = Call->getParamInAllocaType(ArgIdx);
147 if (IsSRet)
148 IndirectType = Call->getParamStructRetType(ArgIdx);
149 }
150
151 /// Generate a libcall taking the given operands as arguments and returning a
152 /// result of type RetVT.
153 std::pair<SDValue, SDValue>
makeLibCall(SelectionDAG & DAG,RTLIB::Libcall LC,EVT RetVT,ArrayRef<SDValue> Ops,MakeLibCallOptions CallOptions,const SDLoc & dl,SDValue InChain) const154 TargetLowering::makeLibCall(SelectionDAG &DAG, RTLIB::Libcall LC, EVT RetVT,
155 ArrayRef<SDValue> Ops,
156 MakeLibCallOptions CallOptions,
157 const SDLoc &dl,
158 SDValue InChain) const {
159 if (!InChain)
160 InChain = DAG.getEntryNode();
161
162 TargetLowering::ArgListTy Args;
163 Args.reserve(Ops.size());
164
165 TargetLowering::ArgListEntry Entry;
166 ArrayRef<Type *> OpsTypeOverrides = CallOptions.OpsTypeOverrides;
167 for (unsigned i = 0; i < Ops.size(); ++i) {
168 SDValue NewOp = Ops[i];
169 Entry.Node = NewOp;
170 Entry.Ty = i < OpsTypeOverrides.size() && OpsTypeOverrides[i]
171 ? OpsTypeOverrides[i]
172 : Entry.Node.getValueType().getTypeForEVT(*DAG.getContext());
173 Entry.IsSExt =
174 shouldSignExtendTypeInLibCall(Entry.Ty, CallOptions.IsSigned);
175 Entry.IsZExt = !Entry.IsSExt;
176
177 if (CallOptions.IsSoften &&
178 !shouldExtendTypeInLibCall(CallOptions.OpsVTBeforeSoften[i])) {
179 Entry.IsSExt = Entry.IsZExt = false;
180 }
181 Args.push_back(Entry);
182 }
183
184 const char *LibcallName = getLibcallName(LC);
185 if (LC == RTLIB::UNKNOWN_LIBCALL || !LibcallName)
186 reportFatalInternalError("unsupported library call operation");
187
188 SDValue Callee =
189 DAG.getExternalSymbol(LibcallName, getPointerTy(DAG.getDataLayout()));
190
191 Type *RetTy = RetVT.getTypeForEVT(*DAG.getContext());
192 TargetLowering::CallLoweringInfo CLI(DAG);
193 bool signExtend = shouldSignExtendTypeInLibCall(RetTy, CallOptions.IsSigned);
194 bool zeroExtend = !signExtend;
195
196 if (CallOptions.IsSoften &&
197 !shouldExtendTypeInLibCall(CallOptions.RetVTBeforeSoften)) {
198 signExtend = zeroExtend = false;
199 }
200
201 CLI.setDebugLoc(dl)
202 .setChain(InChain)
203 .setLibCallee(getLibcallCallingConv(LC), RetTy, Callee, std::move(Args))
204 .setNoReturn(CallOptions.DoesNotReturn)
205 .setDiscardResult(!CallOptions.IsReturnValueUsed)
206 .setIsPostTypeLegalization(CallOptions.IsPostTypeLegalization)
207 .setSExtResult(signExtend)
208 .setZExtResult(zeroExtend);
209 return LowerCallTo(CLI);
210 }
211
findOptimalMemOpLowering(LLVMContext & Context,std::vector<EVT> & MemOps,unsigned Limit,const MemOp & Op,unsigned DstAS,unsigned SrcAS,const AttributeList & FuncAttributes) const212 bool TargetLowering::findOptimalMemOpLowering(
213 LLVMContext &Context, std::vector<EVT> &MemOps, unsigned Limit,
214 const MemOp &Op, unsigned DstAS, unsigned SrcAS,
215 const AttributeList &FuncAttributes) const {
216 if (Limit != ~unsigned(0) && Op.isMemcpyWithFixedDstAlign() &&
217 Op.getSrcAlign() < Op.getDstAlign())
218 return false;
219
220 EVT VT = getOptimalMemOpType(Context, Op, FuncAttributes);
221
222 if (VT == MVT::Other) {
223 // Use the largest integer type whose alignment constraints are satisfied.
224 // We only need to check DstAlign here as SrcAlign is always greater or
225 // equal to DstAlign (or zero).
226 VT = MVT::LAST_INTEGER_VALUETYPE;
227 if (Op.isFixedDstAlign())
228 while (Op.getDstAlign() < (VT.getSizeInBits() / 8) &&
229 !allowsMisalignedMemoryAccesses(VT, DstAS, Op.getDstAlign()))
230 VT = (MVT::SimpleValueType)(VT.getSimpleVT().SimpleTy - 1);
231 assert(VT.isInteger());
232
233 // Find the largest legal integer type.
234 MVT LVT = MVT::LAST_INTEGER_VALUETYPE;
235 while (!isTypeLegal(LVT))
236 LVT = (MVT::SimpleValueType)(LVT.SimpleTy - 1);
237 assert(LVT.isInteger());
238
239 // If the type we've chosen is larger than the largest legal integer type
240 // then use that instead.
241 if (VT.bitsGT(LVT))
242 VT = LVT;
243 }
244
245 unsigned NumMemOps = 0;
246 uint64_t Size = Op.size();
247 while (Size) {
248 unsigned VTSize = VT.getSizeInBits() / 8;
249 while (VTSize > Size) {
250 // For now, only use non-vector load / store's for the left-over pieces.
251 EVT NewVT = VT;
252 unsigned NewVTSize;
253
254 bool Found = false;
255 if (VT.isVector() || VT.isFloatingPoint()) {
256 NewVT = (VT.getSizeInBits() > 64) ? MVT::i64 : MVT::i32;
257 if (isOperationLegalOrCustom(ISD::STORE, NewVT) &&
258 isSafeMemOpType(NewVT.getSimpleVT()))
259 Found = true;
260 else if (NewVT == MVT::i64 &&
261 isOperationLegalOrCustom(ISD::STORE, MVT::f64) &&
262 isSafeMemOpType(MVT::f64)) {
263 // i64 is usually not legal on 32-bit targets, but f64 may be.
264 NewVT = MVT::f64;
265 Found = true;
266 }
267 }
268
269 if (!Found) {
270 do {
271 NewVT = (MVT::SimpleValueType)(NewVT.getSimpleVT().SimpleTy - 1);
272 if (NewVT == MVT::i8)
273 break;
274 } while (!isSafeMemOpType(NewVT.getSimpleVT()));
275 }
276 NewVTSize = NewVT.getSizeInBits() / 8;
277
278 // If the new VT cannot cover all of the remaining bits, then consider
279 // issuing a (or a pair of) unaligned and overlapping load / store.
280 unsigned Fast;
281 if (NumMemOps && Op.allowOverlap() && NewVTSize < Size &&
282 allowsMisalignedMemoryAccesses(
283 VT, DstAS, Op.isFixedDstAlign() ? Op.getDstAlign() : Align(1),
284 MachineMemOperand::MONone, &Fast) &&
285 Fast)
286 VTSize = Size;
287 else {
288 VT = NewVT;
289 VTSize = NewVTSize;
290 }
291 }
292
293 if (++NumMemOps > Limit)
294 return false;
295
296 MemOps.push_back(VT);
297 Size -= VTSize;
298 }
299
300 return true;
301 }
302
303 /// Soften the operands of a comparison. This code is shared among BR_CC,
304 /// SELECT_CC, and SETCC handlers.
softenSetCCOperands(SelectionDAG & DAG,EVT VT,SDValue & NewLHS,SDValue & NewRHS,ISD::CondCode & CCCode,const SDLoc & dl,const SDValue OldLHS,const SDValue OldRHS) const305 void TargetLowering::softenSetCCOperands(SelectionDAG &DAG, EVT VT,
306 SDValue &NewLHS, SDValue &NewRHS,
307 ISD::CondCode &CCCode,
308 const SDLoc &dl, const SDValue OldLHS,
309 const SDValue OldRHS) const {
310 SDValue Chain;
311 return softenSetCCOperands(DAG, VT, NewLHS, NewRHS, CCCode, dl, OldLHS,
312 OldRHS, Chain);
313 }
314
softenSetCCOperands(SelectionDAG & DAG,EVT VT,SDValue & NewLHS,SDValue & NewRHS,ISD::CondCode & CCCode,const SDLoc & dl,const SDValue OldLHS,const SDValue OldRHS,SDValue & Chain,bool IsSignaling) const315 void TargetLowering::softenSetCCOperands(SelectionDAG &DAG, EVT VT,
316 SDValue &NewLHS, SDValue &NewRHS,
317 ISD::CondCode &CCCode,
318 const SDLoc &dl, const SDValue OldLHS,
319 const SDValue OldRHS,
320 SDValue &Chain,
321 bool IsSignaling) const {
322 // FIXME: Currently we cannot really respect all IEEE predicates due to libgcc
323 // not supporting it. We can update this code when libgcc provides such
324 // functions.
325
326 assert((VT == MVT::f32 || VT == MVT::f64 || VT == MVT::f128 || VT == MVT::ppcf128)
327 && "Unsupported setcc type!");
328
329 // Expand into one or more soft-fp libcall(s).
330 RTLIB::Libcall LC1 = RTLIB::UNKNOWN_LIBCALL, LC2 = RTLIB::UNKNOWN_LIBCALL;
331 bool ShouldInvertCC = false;
332 switch (CCCode) {
333 case ISD::SETEQ:
334 case ISD::SETOEQ:
335 LC1 = (VT == MVT::f32) ? RTLIB::OEQ_F32 :
336 (VT == MVT::f64) ? RTLIB::OEQ_F64 :
337 (VT == MVT::f128) ? RTLIB::OEQ_F128 : RTLIB::OEQ_PPCF128;
338 break;
339 case ISD::SETNE:
340 case ISD::SETUNE:
341 LC1 = (VT == MVT::f32) ? RTLIB::UNE_F32 :
342 (VT == MVT::f64) ? RTLIB::UNE_F64 :
343 (VT == MVT::f128) ? RTLIB::UNE_F128 : RTLIB::UNE_PPCF128;
344 break;
345 case ISD::SETGE:
346 case ISD::SETOGE:
347 LC1 = (VT == MVT::f32) ? RTLIB::OGE_F32 :
348 (VT == MVT::f64) ? RTLIB::OGE_F64 :
349 (VT == MVT::f128) ? RTLIB::OGE_F128 : RTLIB::OGE_PPCF128;
350 break;
351 case ISD::SETLT:
352 case ISD::SETOLT:
353 LC1 = (VT == MVT::f32) ? RTLIB::OLT_F32 :
354 (VT == MVT::f64) ? RTLIB::OLT_F64 :
355 (VT == MVT::f128) ? RTLIB::OLT_F128 : RTLIB::OLT_PPCF128;
356 break;
357 case ISD::SETLE:
358 case ISD::SETOLE:
359 LC1 = (VT == MVT::f32) ? RTLIB::OLE_F32 :
360 (VT == MVT::f64) ? RTLIB::OLE_F64 :
361 (VT == MVT::f128) ? RTLIB::OLE_F128 : RTLIB::OLE_PPCF128;
362 break;
363 case ISD::SETGT:
364 case ISD::SETOGT:
365 LC1 = (VT == MVT::f32) ? RTLIB::OGT_F32 :
366 (VT == MVT::f64) ? RTLIB::OGT_F64 :
367 (VT == MVT::f128) ? RTLIB::OGT_F128 : RTLIB::OGT_PPCF128;
368 break;
369 case ISD::SETO:
370 ShouldInvertCC = true;
371 [[fallthrough]];
372 case ISD::SETUO:
373 LC1 = (VT == MVT::f32) ? RTLIB::UO_F32 :
374 (VT == MVT::f64) ? RTLIB::UO_F64 :
375 (VT == MVT::f128) ? RTLIB::UO_F128 : RTLIB::UO_PPCF128;
376 break;
377 case ISD::SETONE:
378 // SETONE = O && UNE
379 ShouldInvertCC = true;
380 [[fallthrough]];
381 case ISD::SETUEQ:
382 LC1 = (VT == MVT::f32) ? RTLIB::UO_F32 :
383 (VT == MVT::f64) ? RTLIB::UO_F64 :
384 (VT == MVT::f128) ? RTLIB::UO_F128 : RTLIB::UO_PPCF128;
385 LC2 = (VT == MVT::f32) ? RTLIB::OEQ_F32 :
386 (VT == MVT::f64) ? RTLIB::OEQ_F64 :
387 (VT == MVT::f128) ? RTLIB::OEQ_F128 : RTLIB::OEQ_PPCF128;
388 break;
389 default:
390 // Invert CC for unordered comparisons
391 ShouldInvertCC = true;
392 switch (CCCode) {
393 case ISD::SETULT:
394 LC1 = (VT == MVT::f32) ? RTLIB::OGE_F32 :
395 (VT == MVT::f64) ? RTLIB::OGE_F64 :
396 (VT == MVT::f128) ? RTLIB::OGE_F128 : RTLIB::OGE_PPCF128;
397 break;
398 case ISD::SETULE:
399 LC1 = (VT == MVT::f32) ? RTLIB::OGT_F32 :
400 (VT == MVT::f64) ? RTLIB::OGT_F64 :
401 (VT == MVT::f128) ? RTLIB::OGT_F128 : RTLIB::OGT_PPCF128;
402 break;
403 case ISD::SETUGT:
404 LC1 = (VT == MVT::f32) ? RTLIB::OLE_F32 :
405 (VT == MVT::f64) ? RTLIB::OLE_F64 :
406 (VT == MVT::f128) ? RTLIB::OLE_F128 : RTLIB::OLE_PPCF128;
407 break;
408 case ISD::SETUGE:
409 LC1 = (VT == MVT::f32) ? RTLIB::OLT_F32 :
410 (VT == MVT::f64) ? RTLIB::OLT_F64 :
411 (VT == MVT::f128) ? RTLIB::OLT_F128 : RTLIB::OLT_PPCF128;
412 break;
413 default: llvm_unreachable("Do not know how to soften this setcc!");
414 }
415 }
416
417 // Use the target specific return value for comparison lib calls.
418 EVT RetVT = getCmpLibcallReturnType();
419 SDValue Ops[2] = {NewLHS, NewRHS};
420 TargetLowering::MakeLibCallOptions CallOptions;
421 EVT OpsVT[2] = { OldLHS.getValueType(),
422 OldRHS.getValueType() };
423 CallOptions.setTypeListBeforeSoften(OpsVT, RetVT, true);
424 auto Call = makeLibCall(DAG, LC1, RetVT, Ops, CallOptions, dl, Chain);
425 NewLHS = Call.first;
426 NewRHS = DAG.getConstant(0, dl, RetVT);
427
428 RTLIB::LibcallImpl LC1Impl = getLibcallImpl(LC1);
429 if (LC1Impl == RTLIB::Unsupported) {
430 reportFatalUsageError(
431 "no libcall available to soften floating-point compare");
432 }
433
434 CCCode = getSoftFloatCmpLibcallPredicate(LC1Impl);
435 if (ShouldInvertCC) {
436 assert(RetVT.isInteger());
437 CCCode = getSetCCInverse(CCCode, RetVT);
438 }
439
440 if (LC2 == RTLIB::UNKNOWN_LIBCALL) {
441 // Update Chain.
442 Chain = Call.second;
443 } else {
444 RTLIB::LibcallImpl LC2Impl = getLibcallImpl(LC2);
445 if (LC2Impl == RTLIB::Unsupported) {
446 reportFatalUsageError(
447 "no libcall available to soften floating-point compare");
448 }
449
450 assert(CCCode == (ShouldInvertCC ? ISD::SETEQ : ISD::SETNE) &&
451 "unordered call should be simple boolean");
452
453 EVT SetCCVT =
454 getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), RetVT);
455 if (getBooleanContents(RetVT) == ZeroOrOneBooleanContent) {
456 NewLHS = DAG.getNode(ISD::AssertZext, dl, RetVT, Call.first,
457 DAG.getValueType(MVT::i1));
458 }
459
460 SDValue Tmp = DAG.getSetCC(dl, SetCCVT, NewLHS, NewRHS, CCCode);
461 auto Call2 = makeLibCall(DAG, LC2, RetVT, Ops, CallOptions, dl, Chain);
462 CCCode = getSoftFloatCmpLibcallPredicate(LC2Impl);
463 if (ShouldInvertCC)
464 CCCode = getSetCCInverse(CCCode, RetVT);
465 NewLHS = DAG.getSetCC(dl, SetCCVT, Call2.first, NewRHS, CCCode);
466 if (Chain)
467 Chain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, Call.second,
468 Call2.second);
469 NewLHS = DAG.getNode(ShouldInvertCC ? ISD::AND : ISD::OR, dl,
470 Tmp.getValueType(), Tmp, NewLHS);
471 NewRHS = SDValue();
472 }
473 }
474
475 /// Return the entry encoding for a jump table in the current function. The
476 /// returned value is a member of the MachineJumpTableInfo::JTEntryKind enum.
getJumpTableEncoding() const477 unsigned TargetLowering::getJumpTableEncoding() const {
478 // In non-pic modes, just use the address of a block.
479 if (!isPositionIndependent())
480 return MachineJumpTableInfo::EK_BlockAddress;
481
482 // Otherwise, use a label difference.
483 return MachineJumpTableInfo::EK_LabelDifference32;
484 }
485
getPICJumpTableRelocBase(SDValue Table,SelectionDAG & DAG) const486 SDValue TargetLowering::getPICJumpTableRelocBase(SDValue Table,
487 SelectionDAG &DAG) const {
488 return Table;
489 }
490
491 /// This returns the relocation base for the given PIC jumptable, the same as
492 /// getPICJumpTableRelocBase, but as an MCExpr.
493 const MCExpr *
getPICJumpTableRelocBaseExpr(const MachineFunction * MF,unsigned JTI,MCContext & Ctx) const494 TargetLowering::getPICJumpTableRelocBaseExpr(const MachineFunction *MF,
495 unsigned JTI,MCContext &Ctx) const{
496 // The normal PIC reloc base is the label at the start of the jump table.
497 return MCSymbolRefExpr::create(MF->getJTISymbol(JTI, Ctx), Ctx);
498 }
499
expandIndirectJTBranch(const SDLoc & dl,SDValue Value,SDValue Addr,int JTI,SelectionDAG & DAG) const500 SDValue TargetLowering::expandIndirectJTBranch(const SDLoc &dl, SDValue Value,
501 SDValue Addr, int JTI,
502 SelectionDAG &DAG) const {
503 SDValue Chain = Value;
504 // Jump table debug info is only needed if CodeView is enabled.
505 if (DAG.getTarget().getTargetTriple().isOSBinFormatCOFF()) {
506 Chain = DAG.getJumpTableDebugInfo(JTI, Chain, dl);
507 }
508 return DAG.getNode(ISD::BRIND, dl, MVT::Other, Chain, Addr);
509 }
510
511 bool
isOffsetFoldingLegal(const GlobalAddressSDNode * GA) const512 TargetLowering::isOffsetFoldingLegal(const GlobalAddressSDNode *GA) const {
513 const TargetMachine &TM = getTargetMachine();
514 const GlobalValue *GV = GA->getGlobal();
515
516 // If the address is not even local to this DSO we will have to load it from
517 // a got and then add the offset.
518 if (!TM.shouldAssumeDSOLocal(GV))
519 return false;
520
521 // If the code is position independent we will have to add a base register.
522 if (isPositionIndependent())
523 return false;
524
525 // Otherwise we can do it.
526 return true;
527 }
528
529 //===----------------------------------------------------------------------===//
530 // Optimization Methods
531 //===----------------------------------------------------------------------===//
532
533 /// If the specified instruction has a constant integer operand and there are
534 /// bits set in that constant that are not demanded, then clear those bits and
535 /// return true.
ShrinkDemandedConstant(SDValue Op,const APInt & DemandedBits,const APInt & DemandedElts,TargetLoweringOpt & TLO) const536 bool TargetLowering::ShrinkDemandedConstant(SDValue Op,
537 const APInt &DemandedBits,
538 const APInt &DemandedElts,
539 TargetLoweringOpt &TLO) const {
540 SDLoc DL(Op);
541 unsigned Opcode = Op.getOpcode();
542
543 // Early-out if we've ended up calling an undemanded node, leave this to
544 // constant folding.
545 if (DemandedBits.isZero() || DemandedElts.isZero())
546 return false;
547
548 // Do target-specific constant optimization.
549 if (targetShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO))
550 return TLO.New.getNode();
551
552 // FIXME: ISD::SELECT, ISD::SELECT_CC
553 switch (Opcode) {
554 default:
555 break;
556 case ISD::XOR:
557 case ISD::AND:
558 case ISD::OR: {
559 auto *Op1C = dyn_cast<ConstantSDNode>(Op.getOperand(1));
560 if (!Op1C || Op1C->isOpaque())
561 return false;
562
563 // If this is a 'not' op, don't touch it because that's a canonical form.
564 const APInt &C = Op1C->getAPIntValue();
565 if (Opcode == ISD::XOR && DemandedBits.isSubsetOf(C))
566 return false;
567
568 if (!C.isSubsetOf(DemandedBits)) {
569 EVT VT = Op.getValueType();
570 SDValue NewC = TLO.DAG.getConstant(DemandedBits & C, DL, VT);
571 SDValue NewOp = TLO.DAG.getNode(Opcode, DL, VT, Op.getOperand(0), NewC,
572 Op->getFlags());
573 return TLO.CombineTo(Op, NewOp);
574 }
575
576 break;
577 }
578 }
579
580 return false;
581 }
582
ShrinkDemandedConstant(SDValue Op,const APInt & DemandedBits,TargetLoweringOpt & TLO) const583 bool TargetLowering::ShrinkDemandedConstant(SDValue Op,
584 const APInt &DemandedBits,
585 TargetLoweringOpt &TLO) const {
586 EVT VT = Op.getValueType();
587 APInt DemandedElts = VT.isVector()
588 ? APInt::getAllOnes(VT.getVectorNumElements())
589 : APInt(1, 1);
590 return ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO);
591 }
592
593 /// Convert x+y to (VT)((SmallVT)x+(SmallVT)y) if the casts are free.
594 /// This uses isTruncateFree/isZExtFree and ANY_EXTEND for the widening cast,
595 /// but it could be generalized for targets with other types of implicit
596 /// widening casts.
ShrinkDemandedOp(SDValue Op,unsigned BitWidth,const APInt & DemandedBits,TargetLoweringOpt & TLO) const597 bool TargetLowering::ShrinkDemandedOp(SDValue Op, unsigned BitWidth,
598 const APInt &DemandedBits,
599 TargetLoweringOpt &TLO) const {
600 assert(Op.getNumOperands() == 2 &&
601 "ShrinkDemandedOp only supports binary operators!");
602 assert(Op.getNode()->getNumValues() == 1 &&
603 "ShrinkDemandedOp only supports nodes with one result!");
604
605 EVT VT = Op.getValueType();
606 SelectionDAG &DAG = TLO.DAG;
607 SDLoc dl(Op);
608
609 // Early return, as this function cannot handle vector types.
610 if (VT.isVector())
611 return false;
612
613 assert(Op.getOperand(0).getValueType().getScalarSizeInBits() == BitWidth &&
614 Op.getOperand(1).getValueType().getScalarSizeInBits() == BitWidth &&
615 "ShrinkDemandedOp only supports operands that have the same size!");
616
617 // Don't do this if the node has another user, which may require the
618 // full value.
619 if (!Op.getNode()->hasOneUse())
620 return false;
621
622 // Search for the smallest integer type with free casts to and from
623 // Op's type. For expedience, just check power-of-2 integer types.
624 unsigned DemandedSize = DemandedBits.getActiveBits();
625 for (unsigned SmallVTBits = llvm::bit_ceil(DemandedSize);
626 SmallVTBits < BitWidth; SmallVTBits = NextPowerOf2(SmallVTBits)) {
627 EVT SmallVT = EVT::getIntegerVT(*DAG.getContext(), SmallVTBits);
628 if (isTruncateFree(VT, SmallVT) && isZExtFree(SmallVT, VT)) {
629 // We found a type with free casts.
630
631 // If the operation has the 'disjoint' flag, then the
632 // operands on the new node are also disjoint.
633 SDNodeFlags Flags(Op->getFlags().hasDisjoint() ? SDNodeFlags::Disjoint
634 : SDNodeFlags::None);
635 SDValue X = DAG.getNode(
636 Op.getOpcode(), dl, SmallVT,
637 DAG.getNode(ISD::TRUNCATE, dl, SmallVT, Op.getOperand(0)),
638 DAG.getNode(ISD::TRUNCATE, dl, SmallVT, Op.getOperand(1)), Flags);
639 assert(DemandedSize <= SmallVTBits && "Narrowed below demanded bits?");
640 SDValue Z = DAG.getNode(ISD::ANY_EXTEND, dl, VT, X);
641 return TLO.CombineTo(Op, Z);
642 }
643 }
644 return false;
645 }
646
SimplifyDemandedBits(SDValue Op,const APInt & DemandedBits,DAGCombinerInfo & DCI) const647 bool TargetLowering::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
648 DAGCombinerInfo &DCI) const {
649 SelectionDAG &DAG = DCI.DAG;
650 TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(),
651 !DCI.isBeforeLegalizeOps());
652 KnownBits Known;
653
654 bool Simplified = SimplifyDemandedBits(Op, DemandedBits, Known, TLO);
655 if (Simplified) {
656 DCI.AddToWorklist(Op.getNode());
657 DCI.CommitTargetLoweringOpt(TLO);
658 }
659 return Simplified;
660 }
661
SimplifyDemandedBits(SDValue Op,const APInt & DemandedBits,const APInt & DemandedElts,DAGCombinerInfo & DCI) const662 bool TargetLowering::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
663 const APInt &DemandedElts,
664 DAGCombinerInfo &DCI) const {
665 SelectionDAG &DAG = DCI.DAG;
666 TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(),
667 !DCI.isBeforeLegalizeOps());
668 KnownBits Known;
669
670 bool Simplified =
671 SimplifyDemandedBits(Op, DemandedBits, DemandedElts, Known, TLO);
672 if (Simplified) {
673 DCI.AddToWorklist(Op.getNode());
674 DCI.CommitTargetLoweringOpt(TLO);
675 }
676 return Simplified;
677 }
678
SimplifyDemandedBits(SDValue Op,const APInt & DemandedBits,KnownBits & Known,TargetLoweringOpt & TLO,unsigned Depth,bool AssumeSingleUse) const679 bool TargetLowering::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
680 KnownBits &Known,
681 TargetLoweringOpt &TLO,
682 unsigned Depth,
683 bool AssumeSingleUse) const {
684 EVT VT = Op.getValueType();
685
686 // Since the number of lanes in a scalable vector is unknown at compile time,
687 // we track one bit which is implicitly broadcast to all lanes. This means
688 // that all lanes in a scalable vector are considered demanded.
689 APInt DemandedElts = VT.isFixedLengthVector()
690 ? APInt::getAllOnes(VT.getVectorNumElements())
691 : APInt(1, 1);
692 return SimplifyDemandedBits(Op, DemandedBits, DemandedElts, Known, TLO, Depth,
693 AssumeSingleUse);
694 }
695
696 // TODO: Under what circumstances can we create nodes? Constant folding?
SimplifyMultipleUseDemandedBits(SDValue Op,const APInt & DemandedBits,const APInt & DemandedElts,SelectionDAG & DAG,unsigned Depth) const697 SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
698 SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
699 SelectionDAG &DAG, unsigned Depth) const {
700 EVT VT = Op.getValueType();
701
702 // Limit search depth.
703 if (Depth >= SelectionDAG::MaxRecursionDepth)
704 return SDValue();
705
706 // Ignore UNDEFs.
707 if (Op.isUndef())
708 return SDValue();
709
710 // Not demanding any bits/elts from Op.
711 if (DemandedBits == 0 || DemandedElts == 0)
712 return DAG.getUNDEF(VT);
713
714 bool IsLE = DAG.getDataLayout().isLittleEndian();
715 unsigned NumElts = DemandedElts.getBitWidth();
716 unsigned BitWidth = DemandedBits.getBitWidth();
717 KnownBits LHSKnown, RHSKnown;
718 switch (Op.getOpcode()) {
719 case ISD::BITCAST: {
720 if (VT.isScalableVector())
721 return SDValue();
722
723 SDValue Src = peekThroughBitcasts(Op.getOperand(0));
724 EVT SrcVT = Src.getValueType();
725 EVT DstVT = Op.getValueType();
726 if (SrcVT == DstVT)
727 return Src;
728
729 unsigned NumSrcEltBits = SrcVT.getScalarSizeInBits();
730 unsigned NumDstEltBits = DstVT.getScalarSizeInBits();
731 if (NumSrcEltBits == NumDstEltBits)
732 if (SDValue V = SimplifyMultipleUseDemandedBits(
733 Src, DemandedBits, DemandedElts, DAG, Depth + 1))
734 return DAG.getBitcast(DstVT, V);
735
736 if (SrcVT.isVector() && (NumDstEltBits % NumSrcEltBits) == 0) {
737 unsigned Scale = NumDstEltBits / NumSrcEltBits;
738 unsigned NumSrcElts = SrcVT.getVectorNumElements();
739 APInt DemandedSrcBits = APInt::getZero(NumSrcEltBits);
740 APInt DemandedSrcElts = APInt::getZero(NumSrcElts);
741 for (unsigned i = 0; i != Scale; ++i) {
742 unsigned EltOffset = IsLE ? i : (Scale - 1 - i);
743 unsigned BitOffset = EltOffset * NumSrcEltBits;
744 APInt Sub = DemandedBits.extractBits(NumSrcEltBits, BitOffset);
745 if (!Sub.isZero()) {
746 DemandedSrcBits |= Sub;
747 for (unsigned j = 0; j != NumElts; ++j)
748 if (DemandedElts[j])
749 DemandedSrcElts.setBit((j * Scale) + i);
750 }
751 }
752
753 if (SDValue V = SimplifyMultipleUseDemandedBits(
754 Src, DemandedSrcBits, DemandedSrcElts, DAG, Depth + 1))
755 return DAG.getBitcast(DstVT, V);
756 }
757
758 // TODO - bigendian once we have test coverage.
759 if (IsLE && (NumSrcEltBits % NumDstEltBits) == 0) {
760 unsigned Scale = NumSrcEltBits / NumDstEltBits;
761 unsigned NumSrcElts = SrcVT.isVector() ? SrcVT.getVectorNumElements() : 1;
762 APInt DemandedSrcBits = APInt::getZero(NumSrcEltBits);
763 APInt DemandedSrcElts = APInt::getZero(NumSrcElts);
764 for (unsigned i = 0; i != NumElts; ++i)
765 if (DemandedElts[i]) {
766 unsigned Offset = (i % Scale) * NumDstEltBits;
767 DemandedSrcBits.insertBits(DemandedBits, Offset);
768 DemandedSrcElts.setBit(i / Scale);
769 }
770
771 if (SDValue V = SimplifyMultipleUseDemandedBits(
772 Src, DemandedSrcBits, DemandedSrcElts, DAG, Depth + 1))
773 return DAG.getBitcast(DstVT, V);
774 }
775
776 break;
777 }
778 case ISD::FREEZE: {
779 SDValue N0 = Op.getOperand(0);
780 if (DAG.isGuaranteedNotToBeUndefOrPoison(N0, DemandedElts,
781 /*PoisonOnly=*/false))
782 return N0;
783 break;
784 }
785 case ISD::AND: {
786 LHSKnown = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
787 RHSKnown = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
788
789 // If all of the demanded bits are known 1 on one side, return the other.
790 // These bits cannot contribute to the result of the 'and' in this
791 // context.
792 if (DemandedBits.isSubsetOf(LHSKnown.Zero | RHSKnown.One))
793 return Op.getOperand(0);
794 if (DemandedBits.isSubsetOf(RHSKnown.Zero | LHSKnown.One))
795 return Op.getOperand(1);
796 break;
797 }
798 case ISD::OR: {
799 LHSKnown = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
800 RHSKnown = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
801
802 // If all of the demanded bits are known zero on one side, return the
803 // other. These bits cannot contribute to the result of the 'or' in this
804 // context.
805 if (DemandedBits.isSubsetOf(LHSKnown.One | RHSKnown.Zero))
806 return Op.getOperand(0);
807 if (DemandedBits.isSubsetOf(RHSKnown.One | LHSKnown.Zero))
808 return Op.getOperand(1);
809 break;
810 }
811 case ISD::XOR: {
812 LHSKnown = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
813 RHSKnown = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
814
815 // If all of the demanded bits are known zero on one side, return the
816 // other.
817 if (DemandedBits.isSubsetOf(RHSKnown.Zero))
818 return Op.getOperand(0);
819 if (DemandedBits.isSubsetOf(LHSKnown.Zero))
820 return Op.getOperand(1);
821 break;
822 }
823 case ISD::ADD: {
824 RHSKnown = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
825 if (RHSKnown.isZero())
826 return Op.getOperand(0);
827
828 LHSKnown = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
829 if (LHSKnown.isZero())
830 return Op.getOperand(1);
831 break;
832 }
833 case ISD::SHL: {
834 // If we are only demanding sign bits then we can use the shift source
835 // directly.
836 if (std::optional<uint64_t> MaxSA =
837 DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
838 SDValue Op0 = Op.getOperand(0);
839 unsigned ShAmt = *MaxSA;
840 unsigned NumSignBits =
841 DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
842 unsigned UpperDemandedBits = BitWidth - DemandedBits.countr_zero();
843 if (NumSignBits > ShAmt && (NumSignBits - ShAmt) >= (UpperDemandedBits))
844 return Op0;
845 }
846 break;
847 }
848 case ISD::SRL: {
849 // If we are only demanding sign bits then we can use the shift source
850 // directly.
851 if (std::optional<uint64_t> MaxSA =
852 DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
853 SDValue Op0 = Op.getOperand(0);
854 unsigned ShAmt = *MaxSA;
855 // Must already be signbits in DemandedBits bounds, and can't demand any
856 // shifted in zeroes.
857 if (DemandedBits.countl_zero() >= ShAmt) {
858 unsigned NumSignBits =
859 DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
860 if (DemandedBits.countr_zero() >= (BitWidth - NumSignBits))
861 return Op0;
862 }
863 }
864 break;
865 }
866 case ISD::SETCC: {
867 SDValue Op0 = Op.getOperand(0);
868 SDValue Op1 = Op.getOperand(1);
869 ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(2))->get();
870 // If (1) we only need the sign-bit, (2) the setcc operands are the same
871 // width as the setcc result, and (3) the result of a setcc conforms to 0 or
872 // -1, we may be able to bypass the setcc.
873 if (DemandedBits.isSignMask() &&
874 Op0.getScalarValueSizeInBits() == BitWidth &&
875 getBooleanContents(Op0.getValueType()) ==
876 BooleanContent::ZeroOrNegativeOneBooleanContent) {
877 // If we're testing X < 0, then this compare isn't needed - just use X!
878 // FIXME: We're limiting to integer types here, but this should also work
879 // if we don't care about FP signed-zero. The use of SETLT with FP means
880 // that we don't care about NaNs.
881 if (CC == ISD::SETLT && Op1.getValueType().isInteger() &&
882 (isNullConstant(Op1) || ISD::isBuildVectorAllZeros(Op1.getNode())))
883 return Op0;
884 }
885 break;
886 }
887 case ISD::SIGN_EXTEND_INREG: {
888 // If none of the extended bits are demanded, eliminate the sextinreg.
889 SDValue Op0 = Op.getOperand(0);
890 EVT ExVT = cast<VTSDNode>(Op.getOperand(1))->getVT();
891 unsigned ExBits = ExVT.getScalarSizeInBits();
892 if (DemandedBits.getActiveBits() <= ExBits &&
893 shouldRemoveRedundantExtend(Op))
894 return Op0;
895 // If the input is already sign extended, just drop the extension.
896 unsigned NumSignBits = DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
897 if (NumSignBits >= (BitWidth - ExBits + 1))
898 return Op0;
899 break;
900 }
901 case ISD::ANY_EXTEND_VECTOR_INREG:
902 case ISD::SIGN_EXTEND_VECTOR_INREG:
903 case ISD::ZERO_EXTEND_VECTOR_INREG: {
904 if (VT.isScalableVector())
905 return SDValue();
906
907 // If we only want the lowest element and none of extended bits, then we can
908 // return the bitcasted source vector.
909 SDValue Src = Op.getOperand(0);
910 EVT SrcVT = Src.getValueType();
911 EVT DstVT = Op.getValueType();
912 if (IsLE && DemandedElts == 1 &&
913 DstVT.getSizeInBits() == SrcVT.getSizeInBits() &&
914 DemandedBits.getActiveBits() <= SrcVT.getScalarSizeInBits()) {
915 return DAG.getBitcast(DstVT, Src);
916 }
917 break;
918 }
919 case ISD::INSERT_VECTOR_ELT: {
920 if (VT.isScalableVector())
921 return SDValue();
922
923 // If we don't demand the inserted element, return the base vector.
924 SDValue Vec = Op.getOperand(0);
925 auto *CIdx = dyn_cast<ConstantSDNode>(Op.getOperand(2));
926 EVT VecVT = Vec.getValueType();
927 if (CIdx && CIdx->getAPIntValue().ult(VecVT.getVectorNumElements()) &&
928 !DemandedElts[CIdx->getZExtValue()])
929 return Vec;
930 break;
931 }
932 case ISD::INSERT_SUBVECTOR: {
933 if (VT.isScalableVector())
934 return SDValue();
935
936 SDValue Vec = Op.getOperand(0);
937 SDValue Sub = Op.getOperand(1);
938 uint64_t Idx = Op.getConstantOperandVal(2);
939 unsigned NumSubElts = Sub.getValueType().getVectorNumElements();
940 APInt DemandedSubElts = DemandedElts.extractBits(NumSubElts, Idx);
941 // If we don't demand the inserted subvector, return the base vector.
942 if (DemandedSubElts == 0)
943 return Vec;
944 break;
945 }
946 case ISD::VECTOR_SHUFFLE: {
947 assert(!VT.isScalableVector());
948 ArrayRef<int> ShuffleMask = cast<ShuffleVectorSDNode>(Op)->getMask();
949
950 // If all the demanded elts are from one operand and are inline,
951 // then we can use the operand directly.
952 bool AllUndef = true, IdentityLHS = true, IdentityRHS = true;
953 for (unsigned i = 0; i != NumElts; ++i) {
954 int M = ShuffleMask[i];
955 if (M < 0 || !DemandedElts[i])
956 continue;
957 AllUndef = false;
958 IdentityLHS &= (M == (int)i);
959 IdentityRHS &= ((M - NumElts) == i);
960 }
961
962 if (AllUndef)
963 return DAG.getUNDEF(Op.getValueType());
964 if (IdentityLHS)
965 return Op.getOperand(0);
966 if (IdentityRHS)
967 return Op.getOperand(1);
968 break;
969 }
970 default:
971 // TODO: Probably okay to remove after audit; here to reduce change size
972 // in initial enablement patch for scalable vectors
973 if (VT.isScalableVector())
974 return SDValue();
975
976 if (Op.getOpcode() >= ISD::BUILTIN_OP_END)
977 if (SDValue V = SimplifyMultipleUseDemandedBitsForTargetNode(
978 Op, DemandedBits, DemandedElts, DAG, Depth))
979 return V;
980 break;
981 }
982 return SDValue();
983 }
984
SimplifyMultipleUseDemandedBits(SDValue Op,const APInt & DemandedBits,SelectionDAG & DAG,unsigned Depth) const985 SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
986 SDValue Op, const APInt &DemandedBits, SelectionDAG &DAG,
987 unsigned Depth) const {
988 EVT VT = Op.getValueType();
989 // Since the number of lanes in a scalable vector is unknown at compile time,
990 // we track one bit which is implicitly broadcast to all lanes. This means
991 // that all lanes in a scalable vector are considered demanded.
992 APInt DemandedElts = VT.isFixedLengthVector()
993 ? APInt::getAllOnes(VT.getVectorNumElements())
994 : APInt(1, 1);
995 return SimplifyMultipleUseDemandedBits(Op, DemandedBits, DemandedElts, DAG,
996 Depth);
997 }
998
SimplifyMultipleUseDemandedVectorElts(SDValue Op,const APInt & DemandedElts,SelectionDAG & DAG,unsigned Depth) const999 SDValue TargetLowering::SimplifyMultipleUseDemandedVectorElts(
1000 SDValue Op, const APInt &DemandedElts, SelectionDAG &DAG,
1001 unsigned Depth) const {
1002 APInt DemandedBits = APInt::getAllOnes(Op.getScalarValueSizeInBits());
1003 return SimplifyMultipleUseDemandedBits(Op, DemandedBits, DemandedElts, DAG,
1004 Depth);
1005 }
1006
1007 // Attempt to form ext(avgfloor(A, B)) from shr(add(ext(A), ext(B)), 1).
1008 // or to form ext(avgceil(A, B)) from shr(add(ext(A), ext(B), 1), 1).
combineShiftToAVG(SDValue Op,TargetLowering::TargetLoweringOpt & TLO,const TargetLowering & TLI,const APInt & DemandedBits,const APInt & DemandedElts,unsigned Depth)1009 static SDValue combineShiftToAVG(SDValue Op,
1010 TargetLowering::TargetLoweringOpt &TLO,
1011 const TargetLowering &TLI,
1012 const APInt &DemandedBits,
1013 const APInt &DemandedElts, unsigned Depth) {
1014 assert((Op.getOpcode() == ISD::SRL || Op.getOpcode() == ISD::SRA) &&
1015 "SRL or SRA node is required here!");
1016 // Is the right shift using an immediate value of 1?
1017 ConstantSDNode *N1C = isConstOrConstSplat(Op.getOperand(1), DemandedElts);
1018 if (!N1C || !N1C->isOne())
1019 return SDValue();
1020
1021 // We are looking for an avgfloor
1022 // add(ext, ext)
1023 // or one of these as a avgceil
1024 // add(add(ext, ext), 1)
1025 // add(add(ext, 1), ext)
1026 // add(ext, add(ext, 1))
1027 SDValue Add = Op.getOperand(0);
1028 if (Add.getOpcode() != ISD::ADD)
1029 return SDValue();
1030
1031 SDValue ExtOpA = Add.getOperand(0);
1032 SDValue ExtOpB = Add.getOperand(1);
1033 SDValue Add2;
1034 auto MatchOperands = [&](SDValue Op1, SDValue Op2, SDValue Op3, SDValue A) {
1035 ConstantSDNode *ConstOp;
1036 if ((ConstOp = isConstOrConstSplat(Op2, DemandedElts)) &&
1037 ConstOp->isOne()) {
1038 ExtOpA = Op1;
1039 ExtOpB = Op3;
1040 Add2 = A;
1041 return true;
1042 }
1043 if ((ConstOp = isConstOrConstSplat(Op3, DemandedElts)) &&
1044 ConstOp->isOne()) {
1045 ExtOpA = Op1;
1046 ExtOpB = Op2;
1047 Add2 = A;
1048 return true;
1049 }
1050 return false;
1051 };
1052 bool IsCeil =
1053 (ExtOpA.getOpcode() == ISD::ADD &&
1054 MatchOperands(ExtOpA.getOperand(0), ExtOpA.getOperand(1), ExtOpB, ExtOpA)) ||
1055 (ExtOpB.getOpcode() == ISD::ADD &&
1056 MatchOperands(ExtOpB.getOperand(0), ExtOpB.getOperand(1), ExtOpA, ExtOpB));
1057
1058 // If the shift is signed (sra):
1059 // - Needs >= 2 sign bit for both operands.
1060 // - Needs >= 2 zero bits.
1061 // If the shift is unsigned (srl):
1062 // - Needs >= 1 zero bit for both operands.
1063 // - Needs 1 demanded bit zero and >= 2 sign bits.
1064 SelectionDAG &DAG = TLO.DAG;
1065 unsigned ShiftOpc = Op.getOpcode();
1066 bool IsSigned = false;
1067 unsigned KnownBits;
1068 unsigned NumSignedA = DAG.ComputeNumSignBits(ExtOpA, DemandedElts, Depth);
1069 unsigned NumSignedB = DAG.ComputeNumSignBits(ExtOpB, DemandedElts, Depth);
1070 unsigned NumSigned = std::min(NumSignedA, NumSignedB) - 1;
1071 unsigned NumZeroA =
1072 DAG.computeKnownBits(ExtOpA, DemandedElts, Depth).countMinLeadingZeros();
1073 unsigned NumZeroB =
1074 DAG.computeKnownBits(ExtOpB, DemandedElts, Depth).countMinLeadingZeros();
1075 unsigned NumZero = std::min(NumZeroA, NumZeroB);
1076
1077 switch (ShiftOpc) {
1078 default:
1079 llvm_unreachable("Unexpected ShiftOpc in combineShiftToAVG");
1080 case ISD::SRA: {
1081 if (NumZero >= 2 && NumSigned < NumZero) {
1082 IsSigned = false;
1083 KnownBits = NumZero;
1084 break;
1085 }
1086 if (NumSigned >= 1) {
1087 IsSigned = true;
1088 KnownBits = NumSigned;
1089 break;
1090 }
1091 return SDValue();
1092 }
1093 case ISD::SRL: {
1094 if (NumZero >= 1 && NumSigned < NumZero) {
1095 IsSigned = false;
1096 KnownBits = NumZero;
1097 break;
1098 }
1099 if (NumSigned >= 1 && DemandedBits.isSignBitClear()) {
1100 IsSigned = true;
1101 KnownBits = NumSigned;
1102 break;
1103 }
1104 return SDValue();
1105 }
1106 }
1107
1108 unsigned AVGOpc = IsCeil ? (IsSigned ? ISD::AVGCEILS : ISD::AVGCEILU)
1109 : (IsSigned ? ISD::AVGFLOORS : ISD::AVGFLOORU);
1110
1111 // Find the smallest power-2 type that is legal for this vector size and
1112 // operation, given the original type size and the number of known sign/zero
1113 // bits.
1114 EVT VT = Op.getValueType();
1115 unsigned MinWidth =
1116 std::max<unsigned>(VT.getScalarSizeInBits() - KnownBits, 8);
1117 EVT NVT = EVT::getIntegerVT(*DAG.getContext(), llvm::bit_ceil(MinWidth));
1118 if (NVT.getScalarSizeInBits() > VT.getScalarSizeInBits())
1119 return SDValue();
1120 if (VT.isVector())
1121 NVT = EVT::getVectorVT(*DAG.getContext(), NVT, VT.getVectorElementCount());
1122 if (TLO.LegalTypes() && !TLI.isOperationLegal(AVGOpc, NVT)) {
1123 // If we could not transform, and (both) adds are nuw/nsw, we can use the
1124 // larger type size to do the transform.
1125 if (TLO.LegalOperations() && !TLI.isOperationLegal(AVGOpc, VT))
1126 return SDValue();
1127 if (DAG.willNotOverflowAdd(IsSigned, Add.getOperand(0),
1128 Add.getOperand(1)) &&
1129 (!Add2 || DAG.willNotOverflowAdd(IsSigned, Add2.getOperand(0),
1130 Add2.getOperand(1))))
1131 NVT = VT;
1132 else
1133 return SDValue();
1134 }
1135
1136 // Don't create a AVGFLOOR node with a scalar constant unless its legal as
1137 // this is likely to stop other folds (reassociation, value tracking etc.)
1138 if (!IsCeil && !TLI.isOperationLegal(AVGOpc, NVT) &&
1139 (isa<ConstantSDNode>(ExtOpA) || isa<ConstantSDNode>(ExtOpB)))
1140 return SDValue();
1141
1142 SDLoc DL(Op);
1143 SDValue ResultAVG =
1144 DAG.getNode(AVGOpc, DL, NVT, DAG.getExtOrTrunc(IsSigned, ExtOpA, DL, NVT),
1145 DAG.getExtOrTrunc(IsSigned, ExtOpB, DL, NVT));
1146 return DAG.getExtOrTrunc(IsSigned, ResultAVG, DL, VT);
1147 }
1148
1149 /// Look at Op. At this point, we know that only the OriginalDemandedBits of the
1150 /// result of Op are ever used downstream. If we can use this information to
1151 /// simplify Op, create a new simplified DAG node and return true, returning the
1152 /// original and new nodes in Old and New. Otherwise, analyze the expression and
1153 /// return a mask of Known bits for the expression (used to simplify the
1154 /// caller). The Known bits may only be accurate for those bits in the
1155 /// OriginalDemandedBits and OriginalDemandedElts.
SimplifyDemandedBits(SDValue Op,const APInt & OriginalDemandedBits,const APInt & OriginalDemandedElts,KnownBits & Known,TargetLoweringOpt & TLO,unsigned Depth,bool AssumeSingleUse) const1156 bool TargetLowering::SimplifyDemandedBits(
1157 SDValue Op, const APInt &OriginalDemandedBits,
1158 const APInt &OriginalDemandedElts, KnownBits &Known, TargetLoweringOpt &TLO,
1159 unsigned Depth, bool AssumeSingleUse) const {
1160 unsigned BitWidth = OriginalDemandedBits.getBitWidth();
1161 assert(Op.getScalarValueSizeInBits() == BitWidth &&
1162 "Mask size mismatches value type size!");
1163
1164 // Don't know anything.
1165 Known = KnownBits(BitWidth);
1166
1167 EVT VT = Op.getValueType();
1168 bool IsLE = TLO.DAG.getDataLayout().isLittleEndian();
1169 unsigned NumElts = OriginalDemandedElts.getBitWidth();
1170 assert((!VT.isFixedLengthVector() || NumElts == VT.getVectorNumElements()) &&
1171 "Unexpected vector size");
1172
1173 APInt DemandedBits = OriginalDemandedBits;
1174 APInt DemandedElts = OriginalDemandedElts;
1175 SDLoc dl(Op);
1176
1177 // Undef operand.
1178 if (Op.isUndef())
1179 return false;
1180
1181 // We can't simplify target constants.
1182 if (Op.getOpcode() == ISD::TargetConstant)
1183 return false;
1184
1185 if (Op.getOpcode() == ISD::Constant) {
1186 // We know all of the bits for a constant!
1187 Known = KnownBits::makeConstant(Op->getAsAPIntVal());
1188 return false;
1189 }
1190
1191 if (Op.getOpcode() == ISD::ConstantFP) {
1192 // We know all of the bits for a floating point constant!
1193 Known = KnownBits::makeConstant(
1194 cast<ConstantFPSDNode>(Op)->getValueAPF().bitcastToAPInt());
1195 return false;
1196 }
1197
1198 // Other users may use these bits.
1199 bool HasMultiUse = false;
1200 if (!AssumeSingleUse && !Op.getNode()->hasOneUse()) {
1201 if (Depth >= SelectionDAG::MaxRecursionDepth) {
1202 // Limit search depth.
1203 return false;
1204 }
1205 // Allow multiple uses, just set the DemandedBits/Elts to all bits.
1206 DemandedBits = APInt::getAllOnes(BitWidth);
1207 DemandedElts = APInt::getAllOnes(NumElts);
1208 HasMultiUse = true;
1209 } else if (OriginalDemandedBits == 0 || OriginalDemandedElts == 0) {
1210 // Not demanding any bits/elts from Op.
1211 return TLO.CombineTo(Op, TLO.DAG.getUNDEF(VT));
1212 } else if (Depth >= SelectionDAG::MaxRecursionDepth) {
1213 // Limit search depth.
1214 return false;
1215 }
1216
1217 KnownBits Known2;
1218 switch (Op.getOpcode()) {
1219 case ISD::SCALAR_TO_VECTOR: {
1220 if (VT.isScalableVector())
1221 return false;
1222 if (!DemandedElts[0])
1223 return TLO.CombineTo(Op, TLO.DAG.getUNDEF(VT));
1224
1225 KnownBits SrcKnown;
1226 SDValue Src = Op.getOperand(0);
1227 unsigned SrcBitWidth = Src.getScalarValueSizeInBits();
1228 APInt SrcDemandedBits = DemandedBits.zext(SrcBitWidth);
1229 if (SimplifyDemandedBits(Src, SrcDemandedBits, SrcKnown, TLO, Depth + 1))
1230 return true;
1231
1232 // Upper elements are undef, so only get the knownbits if we just demand
1233 // the bottom element.
1234 if (DemandedElts == 1)
1235 Known = SrcKnown.anyextOrTrunc(BitWidth);
1236 break;
1237 }
1238 case ISD::BUILD_VECTOR:
1239 // Collect the known bits that are shared by every demanded element.
1240 // TODO: Call SimplifyDemandedBits for non-constant demanded elements.
1241 Known = TLO.DAG.computeKnownBits(Op, DemandedElts, Depth);
1242 return false; // Don't fall through, will infinitely loop.
1243 case ISD::SPLAT_VECTOR: {
1244 SDValue Scl = Op.getOperand(0);
1245 APInt DemandedSclBits = DemandedBits.zextOrTrunc(Scl.getValueSizeInBits());
1246 KnownBits KnownScl;
1247 if (SimplifyDemandedBits(Scl, DemandedSclBits, KnownScl, TLO, Depth + 1))
1248 return true;
1249
1250 // Implicitly truncate the bits to match the official semantics of
1251 // SPLAT_VECTOR.
1252 Known = KnownScl.trunc(BitWidth);
1253 break;
1254 }
1255 case ISD::LOAD: {
1256 auto *LD = cast<LoadSDNode>(Op);
1257 if (getTargetConstantFromLoad(LD)) {
1258 Known = TLO.DAG.computeKnownBits(Op, DemandedElts, Depth);
1259 return false; // Don't fall through, will infinitely loop.
1260 }
1261 if (ISD::isZEXTLoad(Op.getNode()) && Op.getResNo() == 0) {
1262 // If this is a ZEXTLoad and we are looking at the loaded value.
1263 EVT MemVT = LD->getMemoryVT();
1264 unsigned MemBits = MemVT.getScalarSizeInBits();
1265 Known.Zero.setBitsFrom(MemBits);
1266 return false; // Don't fall through, will infinitely loop.
1267 }
1268 break;
1269 }
1270 case ISD::INSERT_VECTOR_ELT: {
1271 if (VT.isScalableVector())
1272 return false;
1273 SDValue Vec = Op.getOperand(0);
1274 SDValue Scl = Op.getOperand(1);
1275 auto *CIdx = dyn_cast<ConstantSDNode>(Op.getOperand(2));
1276 EVT VecVT = Vec.getValueType();
1277
1278 // If index isn't constant, assume we need all vector elements AND the
1279 // inserted element.
1280 APInt DemandedVecElts(DemandedElts);
1281 if (CIdx && CIdx->getAPIntValue().ult(VecVT.getVectorNumElements())) {
1282 unsigned Idx = CIdx->getZExtValue();
1283 DemandedVecElts.clearBit(Idx);
1284
1285 // Inserted element is not required.
1286 if (!DemandedElts[Idx])
1287 return TLO.CombineTo(Op, Vec);
1288 }
1289
1290 KnownBits KnownScl;
1291 unsigned NumSclBits = Scl.getScalarValueSizeInBits();
1292 APInt DemandedSclBits = DemandedBits.zextOrTrunc(NumSclBits);
1293 if (SimplifyDemandedBits(Scl, DemandedSclBits, KnownScl, TLO, Depth + 1))
1294 return true;
1295
1296 Known = KnownScl.anyextOrTrunc(BitWidth);
1297
1298 KnownBits KnownVec;
1299 if (SimplifyDemandedBits(Vec, DemandedBits, DemandedVecElts, KnownVec, TLO,
1300 Depth + 1))
1301 return true;
1302
1303 if (!!DemandedVecElts)
1304 Known = Known.intersectWith(KnownVec);
1305
1306 return false;
1307 }
1308 case ISD::INSERT_SUBVECTOR: {
1309 if (VT.isScalableVector())
1310 return false;
1311 // Demand any elements from the subvector and the remainder from the src its
1312 // inserted into.
1313 SDValue Src = Op.getOperand(0);
1314 SDValue Sub = Op.getOperand(1);
1315 uint64_t Idx = Op.getConstantOperandVal(2);
1316 unsigned NumSubElts = Sub.getValueType().getVectorNumElements();
1317 APInt DemandedSubElts = DemandedElts.extractBits(NumSubElts, Idx);
1318 APInt DemandedSrcElts = DemandedElts;
1319 DemandedSrcElts.clearBits(Idx, Idx + NumSubElts);
1320
1321 KnownBits KnownSub, KnownSrc;
1322 if (SimplifyDemandedBits(Sub, DemandedBits, DemandedSubElts, KnownSub, TLO,
1323 Depth + 1))
1324 return true;
1325 if (SimplifyDemandedBits(Src, DemandedBits, DemandedSrcElts, KnownSrc, TLO,
1326 Depth + 1))
1327 return true;
1328
1329 Known.Zero.setAllBits();
1330 Known.One.setAllBits();
1331 if (!!DemandedSubElts)
1332 Known = Known.intersectWith(KnownSub);
1333 if (!!DemandedSrcElts)
1334 Known = Known.intersectWith(KnownSrc);
1335
1336 // Attempt to avoid multi-use src if we don't need anything from it.
1337 if (!DemandedBits.isAllOnes() || !DemandedSubElts.isAllOnes() ||
1338 !DemandedSrcElts.isAllOnes()) {
1339 SDValue NewSub = SimplifyMultipleUseDemandedBits(
1340 Sub, DemandedBits, DemandedSubElts, TLO.DAG, Depth + 1);
1341 SDValue NewSrc = SimplifyMultipleUseDemandedBits(
1342 Src, DemandedBits, DemandedSrcElts, TLO.DAG, Depth + 1);
1343 if (NewSub || NewSrc) {
1344 NewSub = NewSub ? NewSub : Sub;
1345 NewSrc = NewSrc ? NewSrc : Src;
1346 SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), dl, VT, NewSrc, NewSub,
1347 Op.getOperand(2));
1348 return TLO.CombineTo(Op, NewOp);
1349 }
1350 }
1351 break;
1352 }
1353 case ISD::EXTRACT_SUBVECTOR: {
1354 if (VT.isScalableVector())
1355 return false;
1356 // Offset the demanded elts by the subvector index.
1357 SDValue Src = Op.getOperand(0);
1358 if (Src.getValueType().isScalableVector())
1359 break;
1360 uint64_t Idx = Op.getConstantOperandVal(1);
1361 unsigned NumSrcElts = Src.getValueType().getVectorNumElements();
1362 APInt DemandedSrcElts = DemandedElts.zext(NumSrcElts).shl(Idx);
1363
1364 if (SimplifyDemandedBits(Src, DemandedBits, DemandedSrcElts, Known, TLO,
1365 Depth + 1))
1366 return true;
1367
1368 // Attempt to avoid multi-use src if we don't need anything from it.
1369 if (!DemandedBits.isAllOnes() || !DemandedSrcElts.isAllOnes()) {
1370 SDValue DemandedSrc = SimplifyMultipleUseDemandedBits(
1371 Src, DemandedBits, DemandedSrcElts, TLO.DAG, Depth + 1);
1372 if (DemandedSrc) {
1373 SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), dl, VT, DemandedSrc,
1374 Op.getOperand(1));
1375 return TLO.CombineTo(Op, NewOp);
1376 }
1377 }
1378 break;
1379 }
1380 case ISD::CONCAT_VECTORS: {
1381 if (VT.isScalableVector())
1382 return false;
1383 Known.Zero.setAllBits();
1384 Known.One.setAllBits();
1385 EVT SubVT = Op.getOperand(0).getValueType();
1386 unsigned NumSubVecs = Op.getNumOperands();
1387 unsigned NumSubElts = SubVT.getVectorNumElements();
1388 for (unsigned i = 0; i != NumSubVecs; ++i) {
1389 APInt DemandedSubElts =
1390 DemandedElts.extractBits(NumSubElts, i * NumSubElts);
1391 if (SimplifyDemandedBits(Op.getOperand(i), DemandedBits, DemandedSubElts,
1392 Known2, TLO, Depth + 1))
1393 return true;
1394 // Known bits are shared by every demanded subvector element.
1395 if (!!DemandedSubElts)
1396 Known = Known.intersectWith(Known2);
1397 }
1398 break;
1399 }
1400 case ISD::VECTOR_SHUFFLE: {
1401 assert(!VT.isScalableVector());
1402 ArrayRef<int> ShuffleMask = cast<ShuffleVectorSDNode>(Op)->getMask();
1403
1404 // Collect demanded elements from shuffle operands..
1405 APInt DemandedLHS, DemandedRHS;
1406 if (!getShuffleDemandedElts(NumElts, ShuffleMask, DemandedElts, DemandedLHS,
1407 DemandedRHS))
1408 break;
1409
1410 if (!!DemandedLHS || !!DemandedRHS) {
1411 SDValue Op0 = Op.getOperand(0);
1412 SDValue Op1 = Op.getOperand(1);
1413
1414 Known.Zero.setAllBits();
1415 Known.One.setAllBits();
1416 if (!!DemandedLHS) {
1417 if (SimplifyDemandedBits(Op0, DemandedBits, DemandedLHS, Known2, TLO,
1418 Depth + 1))
1419 return true;
1420 Known = Known.intersectWith(Known2);
1421 }
1422 if (!!DemandedRHS) {
1423 if (SimplifyDemandedBits(Op1, DemandedBits, DemandedRHS, Known2, TLO,
1424 Depth + 1))
1425 return true;
1426 Known = Known.intersectWith(Known2);
1427 }
1428
1429 // Attempt to avoid multi-use ops if we don't need anything from them.
1430 SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
1431 Op0, DemandedBits, DemandedLHS, TLO.DAG, Depth + 1);
1432 SDValue DemandedOp1 = SimplifyMultipleUseDemandedBits(
1433 Op1, DemandedBits, DemandedRHS, TLO.DAG, Depth + 1);
1434 if (DemandedOp0 || DemandedOp1) {
1435 Op0 = DemandedOp0 ? DemandedOp0 : Op0;
1436 Op1 = DemandedOp1 ? DemandedOp1 : Op1;
1437 SDValue NewOp = TLO.DAG.getVectorShuffle(VT, dl, Op0, Op1, ShuffleMask);
1438 return TLO.CombineTo(Op, NewOp);
1439 }
1440 }
1441 break;
1442 }
1443 case ISD::AND: {
1444 SDValue Op0 = Op.getOperand(0);
1445 SDValue Op1 = Op.getOperand(1);
1446
1447 // If the RHS is a constant, check to see if the LHS would be zero without
1448 // using the bits from the RHS. Below, we use knowledge about the RHS to
1449 // simplify the LHS, here we're using information from the LHS to simplify
1450 // the RHS.
1451 if (ConstantSDNode *RHSC = isConstOrConstSplat(Op1, DemandedElts)) {
1452 // Do not increment Depth here; that can cause an infinite loop.
1453 KnownBits LHSKnown = TLO.DAG.computeKnownBits(Op0, DemandedElts, Depth);
1454 // If the LHS already has zeros where RHSC does, this 'and' is dead.
1455 if ((LHSKnown.Zero & DemandedBits) ==
1456 (~RHSC->getAPIntValue() & DemandedBits))
1457 return TLO.CombineTo(Op, Op0);
1458
1459 // If any of the set bits in the RHS are known zero on the LHS, shrink
1460 // the constant.
1461 if (ShrinkDemandedConstant(Op, ~LHSKnown.Zero & DemandedBits,
1462 DemandedElts, TLO))
1463 return true;
1464
1465 // Bitwise-not (xor X, -1) is a special case: we don't usually shrink its
1466 // constant, but if this 'and' is only clearing bits that were just set by
1467 // the xor, then this 'and' can be eliminated by shrinking the mask of
1468 // the xor. For example, for a 32-bit X:
1469 // and (xor (srl X, 31), -1), 1 --> xor (srl X, 31), 1
1470 if (isBitwiseNot(Op0) && Op0.hasOneUse() &&
1471 LHSKnown.One == ~RHSC->getAPIntValue()) {
1472 SDValue Xor = TLO.DAG.getNode(ISD::XOR, dl, VT, Op0.getOperand(0), Op1);
1473 return TLO.CombineTo(Op, Xor);
1474 }
1475 }
1476
1477 // AND(INSERT_SUBVECTOR(C,X,I),M) -> INSERT_SUBVECTOR(AND(C,M),X,I)
1478 // iff 'C' is Undef/Constant and AND(X,M) == X (for DemandedBits).
1479 if (Op0.getOpcode() == ISD::INSERT_SUBVECTOR && !VT.isScalableVector() &&
1480 (Op0.getOperand(0).isUndef() ||
1481 ISD::isBuildVectorOfConstantSDNodes(Op0.getOperand(0).getNode())) &&
1482 Op0->hasOneUse()) {
1483 unsigned NumSubElts =
1484 Op0.getOperand(1).getValueType().getVectorNumElements();
1485 unsigned SubIdx = Op0.getConstantOperandVal(2);
1486 APInt DemandedSub =
1487 APInt::getBitsSet(NumElts, SubIdx, SubIdx + NumSubElts);
1488 KnownBits KnownSubMask =
1489 TLO.DAG.computeKnownBits(Op1, DemandedSub & DemandedElts, Depth + 1);
1490 if (DemandedBits.isSubsetOf(KnownSubMask.One)) {
1491 SDValue NewAnd =
1492 TLO.DAG.getNode(ISD::AND, dl, VT, Op0.getOperand(0), Op1);
1493 SDValue NewInsert =
1494 TLO.DAG.getNode(ISD::INSERT_SUBVECTOR, dl, VT, NewAnd,
1495 Op0.getOperand(1), Op0.getOperand(2));
1496 return TLO.CombineTo(Op, NewInsert);
1497 }
1498 }
1499
1500 if (SimplifyDemandedBits(Op1, DemandedBits, DemandedElts, Known, TLO,
1501 Depth + 1))
1502 return true;
1503 if (SimplifyDemandedBits(Op0, ~Known.Zero & DemandedBits, DemandedElts,
1504 Known2, TLO, Depth + 1))
1505 return true;
1506
1507 // If all of the demanded bits are known one on one side, return the other.
1508 // These bits cannot contribute to the result of the 'and'.
1509 if (DemandedBits.isSubsetOf(Known2.Zero | Known.One))
1510 return TLO.CombineTo(Op, Op0);
1511 if (DemandedBits.isSubsetOf(Known.Zero | Known2.One))
1512 return TLO.CombineTo(Op, Op1);
1513 // If all of the demanded bits in the inputs are known zeros, return zero.
1514 if (DemandedBits.isSubsetOf(Known.Zero | Known2.Zero))
1515 return TLO.CombineTo(Op, TLO.DAG.getConstant(0, dl, VT));
1516 // If the RHS is a constant, see if we can simplify it.
1517 if (ShrinkDemandedConstant(Op, ~Known2.Zero & DemandedBits, DemandedElts,
1518 TLO))
1519 return true;
1520 // If the operation can be done in a smaller type, do so.
1521 if (ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO))
1522 return true;
1523
1524 // Attempt to avoid multi-use ops if we don't need anything from them.
1525 if (!DemandedBits.isAllOnes() || !DemandedElts.isAllOnes()) {
1526 SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
1527 Op0, DemandedBits, DemandedElts, TLO.DAG, Depth + 1);
1528 SDValue DemandedOp1 = SimplifyMultipleUseDemandedBits(
1529 Op1, DemandedBits, DemandedElts, TLO.DAG, Depth + 1);
1530 if (DemandedOp0 || DemandedOp1) {
1531 Op0 = DemandedOp0 ? DemandedOp0 : Op0;
1532 Op1 = DemandedOp1 ? DemandedOp1 : Op1;
1533 SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), dl, VT, Op0, Op1);
1534 return TLO.CombineTo(Op, NewOp);
1535 }
1536 }
1537
1538 Known &= Known2;
1539 break;
1540 }
1541 case ISD::OR: {
1542 SDValue Op0 = Op.getOperand(0);
1543 SDValue Op1 = Op.getOperand(1);
1544 if (SimplifyDemandedBits(Op1, DemandedBits, DemandedElts, Known, TLO,
1545 Depth + 1)) {
1546 Op->dropFlags(SDNodeFlags::Disjoint);
1547 return true;
1548 }
1549
1550 if (SimplifyDemandedBits(Op0, ~Known.One & DemandedBits, DemandedElts,
1551 Known2, TLO, Depth + 1)) {
1552 Op->dropFlags(SDNodeFlags::Disjoint);
1553 return true;
1554 }
1555
1556 // If all of the demanded bits are known zero on one side, return the other.
1557 // These bits cannot contribute to the result of the 'or'.
1558 if (DemandedBits.isSubsetOf(Known2.One | Known.Zero))
1559 return TLO.CombineTo(Op, Op0);
1560 if (DemandedBits.isSubsetOf(Known.One | Known2.Zero))
1561 return TLO.CombineTo(Op, Op1);
1562 // If the RHS is a constant, see if we can simplify it.
1563 if (ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO))
1564 return true;
1565 // If the operation can be done in a smaller type, do so.
1566 if (ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO))
1567 return true;
1568
1569 // Attempt to avoid multi-use ops if we don't need anything from them.
1570 if (!DemandedBits.isAllOnes() || !DemandedElts.isAllOnes()) {
1571 SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
1572 Op0, DemandedBits, DemandedElts, TLO.DAG, Depth + 1);
1573 SDValue DemandedOp1 = SimplifyMultipleUseDemandedBits(
1574 Op1, DemandedBits, DemandedElts, TLO.DAG, Depth + 1);
1575 if (DemandedOp0 || DemandedOp1) {
1576 Op0 = DemandedOp0 ? DemandedOp0 : Op0;
1577 Op1 = DemandedOp1 ? DemandedOp1 : Op1;
1578 SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), dl, VT, Op0, Op1);
1579 return TLO.CombineTo(Op, NewOp);
1580 }
1581 }
1582
1583 // (or (and X, C1), (and (or X, Y), C2)) -> (or (and X, C1|C2), (and Y, C2))
1584 // TODO: Use SimplifyMultipleUseDemandedBits to peek through masks.
1585 if (Op0.getOpcode() == ISD::AND && Op1.getOpcode() == ISD::AND &&
1586 Op0->hasOneUse() && Op1->hasOneUse()) {
1587 // Attempt to match all commutations - m_c_Or would've been useful!
1588 for (int I = 0; I != 2; ++I) {
1589 SDValue X = Op.getOperand(I).getOperand(0);
1590 SDValue C1 = Op.getOperand(I).getOperand(1);
1591 SDValue Alt = Op.getOperand(1 - I).getOperand(0);
1592 SDValue C2 = Op.getOperand(1 - I).getOperand(1);
1593 if (Alt.getOpcode() == ISD::OR) {
1594 for (int J = 0; J != 2; ++J) {
1595 if (X == Alt.getOperand(J)) {
1596 SDValue Y = Alt.getOperand(1 - J);
1597 if (SDValue C12 = TLO.DAG.FoldConstantArithmetic(ISD::OR, dl, VT,
1598 {C1, C2})) {
1599 SDValue MaskX = TLO.DAG.getNode(ISD::AND, dl, VT, X, C12);
1600 SDValue MaskY = TLO.DAG.getNode(ISD::AND, dl, VT, Y, C2);
1601 return TLO.CombineTo(
1602 Op, TLO.DAG.getNode(ISD::OR, dl, VT, MaskX, MaskY));
1603 }
1604 }
1605 }
1606 }
1607 }
1608 }
1609
1610 Known |= Known2;
1611 break;
1612 }
1613 case ISD::XOR: {
1614 SDValue Op0 = Op.getOperand(0);
1615 SDValue Op1 = Op.getOperand(1);
1616
1617 if (SimplifyDemandedBits(Op1, DemandedBits, DemandedElts, Known, TLO,
1618 Depth + 1))
1619 return true;
1620 if (SimplifyDemandedBits(Op0, DemandedBits, DemandedElts, Known2, TLO,
1621 Depth + 1))
1622 return true;
1623
1624 // If all of the demanded bits are known zero on one side, return the other.
1625 // These bits cannot contribute to the result of the 'xor'.
1626 if (DemandedBits.isSubsetOf(Known.Zero))
1627 return TLO.CombineTo(Op, Op0);
1628 if (DemandedBits.isSubsetOf(Known2.Zero))
1629 return TLO.CombineTo(Op, Op1);
1630 // If the operation can be done in a smaller type, do so.
1631 if (ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO))
1632 return true;
1633
1634 // If all of the unknown bits are known to be zero on one side or the other,
1635 // turn this into an *inclusive* or.
1636 // e.g. (A & C1)^(B & C2) -> (A & C1)|(B & C2) iff C1&C2 == 0
1637 if (DemandedBits.isSubsetOf(Known.Zero | Known2.Zero))
1638 return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::OR, dl, VT, Op0, Op1));
1639
1640 ConstantSDNode *C = isConstOrConstSplat(Op1, DemandedElts);
1641 if (C) {
1642 // If one side is a constant, and all of the set bits in the constant are
1643 // also known set on the other side, turn this into an AND, as we know
1644 // the bits will be cleared.
1645 // e.g. (X | C1) ^ C2 --> (X | C1) & ~C2 iff (C1&C2) == C2
1646 // NB: it is okay if more bits are known than are requested
1647 if (C->getAPIntValue() == Known2.One) {
1648 SDValue ANDC =
1649 TLO.DAG.getConstant(~C->getAPIntValue() & DemandedBits, dl, VT);
1650 return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::AND, dl, VT, Op0, ANDC));
1651 }
1652
1653 // If the RHS is a constant, see if we can change it. Don't alter a -1
1654 // constant because that's a 'not' op, and that is better for combining
1655 // and codegen.
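// e.g. if only the low 8 bits of an i32 are demanded:
// (xor X, 255) --> (not X)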
1656 if (!C->isAllOnes() && DemandedBits.isSubsetOf(C->getAPIntValue())) {
1657 // We're flipping all demanded bits. Flip the undemanded bits too.
1658 SDValue New = TLO.DAG.getNOT(dl, Op0, VT);
1659 return TLO.CombineTo(Op, New);
1660 }
1661
1662 unsigned Op0Opcode = Op0.getOpcode();
1663 if ((Op0Opcode == ISD::SRL || Op0Opcode == ISD::SHL) && Op0.hasOneUse()) {
1664 if (ConstantSDNode *ShiftC =
1665 isConstOrConstSplat(Op0.getOperand(1), DemandedElts)) {
1666 // Don't crash on an oversized shift. We cannot guarantee that a
1667 // bogus shift has been simplified to undef.
1668 if (ShiftC->getAPIntValue().ult(BitWidth)) {
1669 uint64_t ShiftAmt = ShiftC->getZExtValue();
1670 APInt Ones = APInt::getAllOnes(BitWidth);
1671 Ones = Op0Opcode == ISD::SHL ? Ones.shl(ShiftAmt)
1672 : Ones.lshr(ShiftAmt);
1673 if ((DemandedBits & C->getAPIntValue()) == (DemandedBits & Ones) &&
1674 isDesirableToCommuteXorWithShift(Op.getNode())) {
1675 // If the xor constant is a demanded mask, do a 'not' before the
1676 // shift:
1677 // xor (X << ShiftC), XorC --> (not X) << ShiftC
1678 // xor (X >> ShiftC), XorC --> (not X) >> ShiftC
1679 SDValue Not = TLO.DAG.getNOT(dl, Op0.getOperand(0), VT);
1680 return TLO.CombineTo(Op, TLO.DAG.getNode(Op0Opcode, dl, VT, Not,
1681 Op0.getOperand(1)));
1682 }
1683 }
1684 }
1685 }
1686 }
1687
1688 // If we can't turn this into a 'not', try to shrink the constant.
1689 if (!C || !C->isAllOnes())
1690 if (ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO))
1691 return true;
1692
1693 // Attempt to avoid multi-use ops if we don't need anything from them.
1694 if (!DemandedBits.isAllOnes() || !DemandedElts.isAllOnes()) {
1695 SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
1696 Op0, DemandedBits, DemandedElts, TLO.DAG, Depth + 1);
1697 SDValue DemandedOp1 = SimplifyMultipleUseDemandedBits(
1698 Op1, DemandedBits, DemandedElts, TLO.DAG, Depth + 1);
1699 if (DemandedOp0 || DemandedOp1) {
1700 Op0 = DemandedOp0 ? DemandedOp0 : Op0;
1701 Op1 = DemandedOp1 ? DemandedOp1 : Op1;
1702 SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), dl, VT, Op0, Op1);
1703 return TLO.CombineTo(Op, NewOp);
1704 }
1705 }
1706
1707 Known ^= Known2;
1708 break;
1709 }
1710 case ISD::SELECT:
1711 if (SimplifyDemandedBits(Op.getOperand(2), DemandedBits, DemandedElts,
1712 Known, TLO, Depth + 1))
1713 return true;
1714 if (SimplifyDemandedBits(Op.getOperand(1), DemandedBits, DemandedElts,
1715 Known2, TLO, Depth + 1))
1716 return true;
1717
1718 // If the operands are constants, see if we can simplify them.
1719 if (ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO))
1720 return true;
1721
1722 // Only known if known in both the LHS and RHS.
1723 Known = Known.intersectWith(Known2);
1724 break;
1725 case ISD::VSELECT:
1726 if (SimplifyDemandedBits(Op.getOperand(2), DemandedBits, DemandedElts,
1727 Known, TLO, Depth + 1))
1728 return true;
1729 if (SimplifyDemandedBits(Op.getOperand(1), DemandedBits, DemandedElts,
1730 Known2, TLO, Depth + 1))
1731 return true;
1732
1733 // Only known if known in both the LHS and RHS.
1734 Known = Known.intersectWith(Known2);
1735 break;
1736 case ISD::SELECT_CC:
1737 if (SimplifyDemandedBits(Op.getOperand(3), DemandedBits, DemandedElts,
1738 Known, TLO, Depth + 1))
1739 return true;
1740 if (SimplifyDemandedBits(Op.getOperand(2), DemandedBits, DemandedElts,
1741 Known2, TLO, Depth + 1))
1742 return true;
1743
1744 // If the operands are constants, see if we can simplify them.
1745 if (ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO))
1746 return true;
1747
1748 // Only known if known in both the LHS and RHS.
1749 Known = Known.intersectWith(Known2);
1750 break;
1751 case ISD::SETCC: {
1752 SDValue Op0 = Op.getOperand(0);
1753 SDValue Op1 = Op.getOperand(1);
1754 ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(2))->get();
1755 // If (1) we only need the sign-bit, (2) the setcc operands are the same
1756 // width as the setcc result, and (3) the result of a setcc conforms to 0 or
1757 // -1, we may be able to bypass the setcc.
1758 if (DemandedBits.isSignMask() &&
1759 Op0.getScalarValueSizeInBits() == BitWidth &&
1760 getBooleanContents(Op0.getValueType()) ==
1761 BooleanContent::ZeroOrNegativeOneBooleanContent) {
1762 // If we're testing X < 0, then this compare isn't needed - just use X!
1763 // FIXME: We're limiting to integer types here, but this should also work
1764 // if we don't care about FP signed-zero. The use of SETLT with FP means
1765 // that we don't care about NaNs.
1766 if (CC == ISD::SETLT && Op1.getValueType().isInteger() &&
1767 (isNullConstant(Op1) || ISD::isBuildVectorAllZeros(Op1.getNode())))
1768 return TLO.CombineTo(Op, Op0);
1769
1770 // TODO: Should we check for other forms of sign-bit comparisons?
1771 // Examples: X <= -1, X >= 0
1772 }
1773 if (getBooleanContents(Op0.getValueType()) ==
1774 TargetLowering::ZeroOrOneBooleanContent &&
1775 BitWidth > 1)
1776 Known.Zero.setBitsFrom(1);
1777 break;
1778 }
1779 case ISD::SHL: {
1780 SDValue Op0 = Op.getOperand(0);
1781 SDValue Op1 = Op.getOperand(1);
1782 EVT ShiftVT = Op1.getValueType();
1783
1784 if (std::optional<uint64_t> KnownSA =
1785 TLO.DAG.getValidShiftAmount(Op, DemandedElts, Depth + 1)) {
1786 unsigned ShAmt = *KnownSA;
1787 if (ShAmt == 0)
1788 return TLO.CombineTo(Op, Op0);
1789
1790 // If this is ((X >>u C1) << ShAmt), see if we can simplify this into a
1791 // single shift. We can do this if the bottom bits (which are shifted
1792 // out) are never demanded.
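// e.g. if the low 8 bits are not demanded:
// (shl (srl X, 4), 8) --> (shl X, 4)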
1793 // TODO - support non-uniform vector amounts.
1794 if (Op0.getOpcode() == ISD::SRL) {
1795 if (!DemandedBits.intersects(APInt::getLowBitsSet(BitWidth, ShAmt))) {
1796 if (std::optional<uint64_t> InnerSA =
1797 TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 2)) {
1798 unsigned C1 = *InnerSA;
1799 unsigned Opc = ISD::SHL;
1800 int Diff = ShAmt - C1;
1801 if (Diff < 0) {
1802 Diff = -Diff;
1803 Opc = ISD::SRL;
1804 }
1805 SDValue NewSA = TLO.DAG.getConstant(Diff, dl, ShiftVT);
1806 return TLO.CombineTo(
1807 Op, TLO.DAG.getNode(Opc, dl, VT, Op0.getOperand(0), NewSA));
1808 }
1809 }
1810 }
1811
1812 // Convert (shl (anyext x, c)) to (anyext (shl x, c)) if the high bits
1813 // are not demanded. This will likely allow the anyext to be folded away.
1814 // TODO - support non-uniform vector amounts.
1815 if (Op0.getOpcode() == ISD::ANY_EXTEND) {
1816 SDValue InnerOp = Op0.getOperand(0);
1817 EVT InnerVT = InnerOp.getValueType();
1818 unsigned InnerBits = InnerVT.getScalarSizeInBits();
1819 if (ShAmt < InnerBits && DemandedBits.getActiveBits() <= InnerBits &&
1820 isTypeDesirableForOp(ISD::SHL, InnerVT)) {
1821 SDValue NarrowShl = TLO.DAG.getNode(
1822 ISD::SHL, dl, InnerVT, InnerOp,
1823 TLO.DAG.getShiftAmountConstant(ShAmt, InnerVT, dl));
1824 return TLO.CombineTo(
1825 Op, TLO.DAG.getNode(ISD::ANY_EXTEND, dl, VT, NarrowShl));
1826 }
1827
1828 // Repeat the SHL optimization above in cases where an extension
1829 // intervenes: (shl (anyext (shr x, c1)), c2) to
1830 // (shl (anyext x), c2-c1). This requires that the bottom c1 bits
1831 // aren't demanded (as above) and that the shifted upper c1 bits of
1832 // x aren't demanded.
1833 // TODO - support non-uniform vector amounts.
1834 if (InnerOp.getOpcode() == ISD::SRL && Op0.hasOneUse() &&
1835 InnerOp.hasOneUse()) {
1836 if (std::optional<uint64_t> SA2 = TLO.DAG.getValidShiftAmount(
1837 InnerOp, DemandedElts, Depth + 2)) {
1838 unsigned InnerShAmt = *SA2;
1839 if (InnerShAmt < ShAmt && InnerShAmt < InnerBits &&
1840 DemandedBits.getActiveBits() <=
1841 (InnerBits - InnerShAmt + ShAmt) &&
1842 DemandedBits.countr_zero() >= ShAmt) {
1843 SDValue NewSA =
1844 TLO.DAG.getConstant(ShAmt - InnerShAmt, dl, ShiftVT);
1845 SDValue NewExt = TLO.DAG.getNode(ISD::ANY_EXTEND, dl, VT,
1846 InnerOp.getOperand(0));
1847 return TLO.CombineTo(
1848 Op, TLO.DAG.getNode(ISD::SHL, dl, VT, NewExt, NewSA));
1849 }
1850 }
1851 }
1852 }
1853
1854 APInt InDemandedMask = DemandedBits.lshr(ShAmt);
1855 if (SimplifyDemandedBits(Op0, InDemandedMask, DemandedElts, Known, TLO,
1856 Depth + 1)) {
1857 // Disable the nsw and nuw flags. We can no longer guarantee that we
1858 // won't wrap after simplification.
1859 Op->dropFlags(SDNodeFlags::NoWrap);
1860 return true;
1861 }
1862 Known.Zero <<= ShAmt;
1863 Known.One <<= ShAmt;
1864 // Low bits known zero.
1865 Known.Zero.setLowBits(ShAmt);
1866
1867 // Attempt to avoid multi-use ops if we don't need anything from them.
1868 if (!InDemandedMask.isAllOnes() || !DemandedElts.isAllOnes()) {
1869 SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
1870 Op0, InDemandedMask, DemandedElts, TLO.DAG, Depth + 1);
1871 if (DemandedOp0) {
1872 SDValue NewOp = TLO.DAG.getNode(ISD::SHL, dl, VT, DemandedOp0, Op1);
1873 return TLO.CombineTo(Op, NewOp);
1874 }
1875 }
1876
1877 // TODO: Can we merge this fold with the one below?
1878 // Try shrinking the operation as long as the shift amount will still be
1879 // in range.
1880 if (ShAmt < DemandedBits.getActiveBits() && !VT.isVector() &&
1881 Op.getNode()->hasOneUse()) {
1882 // Search for the smallest integer type with free casts to and from
1883 // Op's type. For expedience, just check power-of-2 integer types.
1884 unsigned DemandedSize = DemandedBits.getActiveBits();
1885 for (unsigned SmallVTBits = llvm::bit_ceil(DemandedSize);
1886 SmallVTBits < BitWidth; SmallVTBits = NextPowerOf2(SmallVTBits)) {
1887 EVT SmallVT = EVT::getIntegerVT(*TLO.DAG.getContext(), SmallVTBits);
1888 if (isNarrowingProfitable(Op.getNode(), VT, SmallVT) &&
1889 isTypeDesirableForOp(ISD::SHL, SmallVT) &&
1890 isTruncateFree(VT, SmallVT) && isZExtFree(SmallVT, VT) &&
1891 (!TLO.LegalOperations() || isOperationLegal(ISD::SHL, SmallVT))) {
1892 assert(DemandedSize <= SmallVTBits &&
1893 "Narrowed below demanded bits?");
1894 // We found a type with free casts.
1895 SDValue NarrowShl = TLO.DAG.getNode(
1896 ISD::SHL, dl, SmallVT,
1897 TLO.DAG.getNode(ISD::TRUNCATE, dl, SmallVT, Op.getOperand(0)),
1898 TLO.DAG.getShiftAmountConstant(ShAmt, SmallVT, dl));
1899 return TLO.CombineTo(
1900 Op, TLO.DAG.getNode(ISD::ANY_EXTEND, dl, VT, NarrowShl));
1901 }
1902 }
1903 }
1904
1905 // Narrow shift to lower half - similar to ShrinkDemandedOp.
1906 // (shl i64:x, K) -> (i64 zero_extend (shl (i32 (trunc i64:x)), K))
1907 // Only do this if we demand the upper half so the knownbits are correct.
1908 unsigned HalfWidth = BitWidth / 2;
1909 if ((BitWidth % 2) == 0 && !VT.isVector() && ShAmt < HalfWidth &&
1910 DemandedBits.countLeadingOnes() >= HalfWidth) {
1911 EVT HalfVT = EVT::getIntegerVT(*TLO.DAG.getContext(), HalfWidth);
1912 if (isNarrowingProfitable(Op.getNode(), VT, HalfVT) &&
1913 isTypeDesirableForOp(ISD::SHL, HalfVT) &&
1914 isTruncateFree(VT, HalfVT) && isZExtFree(HalfVT, VT) &&
1915 (!TLO.LegalOperations() || isOperationLegal(ISD::SHL, HalfVT))) {
1916 // If we're demanding the upper bits at all, we must ensure
1917 // that the upper bits of the shift result are known to be zero,
1918 // which is equivalent to the narrow shift being NUW.
1919 if (bool IsNUW = (Known.countMinLeadingZeros() >= HalfWidth)) {
1920 bool IsNSW = Known.countMinSignBits() > HalfWidth;
1921 SDNodeFlags Flags;
1922 Flags.setNoSignedWrap(IsNSW);
1923 Flags.setNoUnsignedWrap(IsNUW);
1924 SDValue NewOp = TLO.DAG.getNode(ISD::TRUNCATE, dl, HalfVT, Op0);
1925 SDValue NewShiftAmt =
1926 TLO.DAG.getShiftAmountConstant(ShAmt, HalfVT, dl);
1927 SDValue NewShift = TLO.DAG.getNode(ISD::SHL, dl, HalfVT, NewOp,
1928 NewShiftAmt, Flags);
1929 SDValue NewExt =
1930 TLO.DAG.getNode(ISD::ZERO_EXTEND, dl, VT, NewShift);
1931 return TLO.CombineTo(Op, NewExt);
1932 }
1933 }
1934 }
1935 } else {
1936 // This is a variable shift, so we can't shift the demand mask by a known
1937 // amount. But if we are not demanding high bits, then we are not
1938 // demanding those bits from the pre-shifted operand either.
1939 if (unsigned CTLZ = DemandedBits.countl_zero()) {
1940 APInt DemandedFromOp(APInt::getLowBitsSet(BitWidth, BitWidth - CTLZ));
1941 if (SimplifyDemandedBits(Op0, DemandedFromOp, DemandedElts, Known, TLO,
1942 Depth + 1)) {
1943 // Disable the nsw and nuw flags. We can no longer guarantee that we
1944 // won't wrap after simplification.
1945 Op->dropFlags(SDNodeFlags::NoWrap);
1946 return true;
1947 }
1948 Known.resetAll();
1949 }
1950 }
1951
1952 // If we are only demanding sign bits then we can use the shift source
1953 // directly.
1954 if (std::optional<uint64_t> MaxSA =
1955 TLO.DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
1956 unsigned ShAmt = *MaxSA;
1957 unsigned NumSignBits =
1958 TLO.DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
1959 unsigned UpperDemandedBits = BitWidth - DemandedBits.countr_zero();
1960 if (NumSignBits > ShAmt && (NumSignBits - ShAmt) >= (UpperDemandedBits))
1961 return TLO.CombineTo(Op, Op0);
1962 }
1963 break;
1964 }
1965 case ISD::SRL: {
1966 SDValue Op0 = Op.getOperand(0);
1967 SDValue Op1 = Op.getOperand(1);
1968 EVT ShiftVT = Op1.getValueType();
1969
1970 if (std::optional<uint64_t> KnownSA =
1971 TLO.DAG.getValidShiftAmount(Op, DemandedElts, Depth + 1)) {
1972 unsigned ShAmt = *KnownSA;
1973 if (ShAmt == 0)
1974 return TLO.CombineTo(Op, Op0);
1975
1976 // If this is ((X << C1) >>u ShAmt), see if we can simplify this into a
1977 // single shift. We can do this if the top bits (which are shifted out)
1978 // are never demanded.
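// e.g. if the top 4 bits are not demanded:
// (srl (shl X, 8), 4) --> (shl X, 4)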
1979 // TODO - support non-uniform vector amounts.
1980 if (Op0.getOpcode() == ISD::SHL) {
1981 if (!DemandedBits.intersects(APInt::getHighBitsSet(BitWidth, ShAmt))) {
1982 if (std::optional<uint64_t> InnerSA =
1983 TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 2)) {
1984 unsigned C1 = *InnerSA;
1985 unsigned Opc = ISD::SRL;
1986 int Diff = ShAmt - C1;
1987 if (Diff < 0) {
1988 Diff = -Diff;
1989 Opc = ISD::SHL;
1990 }
1991 SDValue NewSA = TLO.DAG.getConstant(Diff, dl, ShiftVT);
1992 return TLO.CombineTo(
1993 Op, TLO.DAG.getNode(Opc, dl, VT, Op0.getOperand(0), NewSA));
1994 }
1995 }
1996 }
1997
1998 // If this is (srl (sra X, C1), ShAmt), see if we can combine this into a
1999 // single sra. We can do this if the top bits are never demanded.
2000 if (Op0.getOpcode() == ISD::SRA && Op0.hasOneUse()) {
2001 if (!DemandedBits.intersects(APInt::getHighBitsSet(BitWidth, ShAmt))) {
2002 if (std::optional<uint64_t> InnerSA =
2003 TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 2)) {
2004 unsigned C1 = *InnerSA;
2005 // Clamp the combined shift amount if it exceeds the bit width.
2006 unsigned Combined = std::min(C1 + ShAmt, BitWidth - 1);
2007 SDValue NewSA = TLO.DAG.getConstant(Combined, dl, ShiftVT);
2008 return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRA, dl, VT,
2009 Op0.getOperand(0), NewSA));
2010 }
2011 }
2012 }
2013
2014 APInt InDemandedMask = (DemandedBits << ShAmt);
2015
2016 // If the shift is exact, then it does demand the low bits (and knows that
2017 // they are zero).
2018 if (Op->getFlags().hasExact())
2019 InDemandedMask.setLowBits(ShAmt);
2020
2021 // Narrow shift to lower half - similar to ShrinkDemandedOp.
2022 // (srl i64:x, K) -> (i64 zero_extend (srl (i32 (trunc i64:x)), K))
2023 if ((BitWidth % 2) == 0 && !VT.isVector()) {
2024 APInt HiBits = APInt::getHighBitsSet(BitWidth, BitWidth / 2);
2025 EVT HalfVT = EVT::getIntegerVT(*TLO.DAG.getContext(), BitWidth / 2);
2026 if (isNarrowingProfitable(Op.getNode(), VT, HalfVT) &&
2027 isTypeDesirableForOp(ISD::SRL, HalfVT) &&
2028 isTruncateFree(VT, HalfVT) && isZExtFree(HalfVT, VT) &&
2029 (!TLO.LegalOperations() || isOperationLegal(ISD::SRL, HalfVT)) &&
2030 ((InDemandedMask.countLeadingZeros() >= (BitWidth / 2)) ||
2031 TLO.DAG.MaskedValueIsZero(Op0, HiBits))) {
2032 SDValue NewOp = TLO.DAG.getNode(ISD::TRUNCATE, dl, HalfVT, Op0);
2033 SDValue NewShiftAmt =
2034 TLO.DAG.getShiftAmountConstant(ShAmt, HalfVT, dl);
2035 SDValue NewShift =
2036 TLO.DAG.getNode(ISD::SRL, dl, HalfVT, NewOp, NewShiftAmt);
2037 return TLO.CombineTo(
2038 Op, TLO.DAG.getNode(ISD::ZERO_EXTEND, dl, VT, NewShift));
2039 }
2040 }
2041
2042 // Compute the new bits that are at the top now.
2043 if (SimplifyDemandedBits(Op0, InDemandedMask, DemandedElts, Known, TLO,
2044 Depth + 1))
2045 return true;
2046 Known.Zero.lshrInPlace(ShAmt);
2047 Known.One.lshrInPlace(ShAmt);
2048 // High bits known zero.
2049 Known.Zero.setHighBits(ShAmt);
2050
2051 // Attempt to avoid multi-use ops if we don't need anything from them.
2052 if (!InDemandedMask.isAllOnes() || !DemandedElts.isAllOnes()) {
2053 SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
2054 Op0, InDemandedMask, DemandedElts, TLO.DAG, Depth + 1);
2055 if (DemandedOp0) {
2056 SDValue NewOp = TLO.DAG.getNode(ISD::SRL, dl, VT, DemandedOp0, Op1);
2057 return TLO.CombineTo(Op, NewOp);
2058 }
2059 }
2060 } else {
2061 // Use generic knownbits computation as it has support for non-uniform
2062 // shift amounts.
2063 Known = TLO.DAG.computeKnownBits(Op, DemandedElts, Depth);
2064 }
2065
2066 // If we are only demanding sign bits then we can use the shift source
2067 // directly.
2068 if (std::optional<uint64_t> MaxSA =
2069 TLO.DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
2070 unsigned ShAmt = *MaxSA;
2071 // Every demanded bit must already be a sign bit of the source, and we must
2072 // not demand any of the zeroes shifted in from the top.
2073 if (DemandedBits.countl_zero() >= ShAmt) {
2074 unsigned NumSignBits =
2075 TLO.DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
2076 if (DemandedBits.countr_zero() >= (BitWidth - NumSignBits))
2077 return TLO.CombineTo(Op, Op0);
2078 }
2079 }
2080
2081 // Try to match AVG patterns (after shift simplification).
2082 if (SDValue AVG = combineShiftToAVG(Op, TLO, *this, DemandedBits,
2083 DemandedElts, Depth + 1))
2084 return TLO.CombineTo(Op, AVG);
2085
2086 break;
2087 }
2088 case ISD::SRA: {
2089 SDValue Op0 = Op.getOperand(0);
2090 SDValue Op1 = Op.getOperand(1);
2091 EVT ShiftVT = Op1.getValueType();
2092
2093 // If we only want bits that already match the signbit then we don't need
2094 // to shift.
2095 unsigned NumHiDemandedBits = BitWidth - DemandedBits.countr_zero();
2096 if (TLO.DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1) >=
2097 NumHiDemandedBits)
2098 return TLO.CombineTo(Op, Op0);
2099
2100 // If this is an arithmetic shift right and only the low bit is demanded,
2101 // we can always convert this into a logical shr, even if the shift amount
2102 // is variable. The low bit of the shift cannot be an input sign bit unless
2103 // the shift amount is >= the size of the datatype, which is undefined.
2104 if (DemandedBits.isOne())
2105 return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl, VT, Op0, Op1));
2106
2107 if (std::optional<uint64_t> KnownSA =
2108 TLO.DAG.getValidShiftAmount(Op, DemandedElts, Depth + 1)) {
2109 unsigned ShAmt = *KnownSA;
2110 if (ShAmt == 0)
2111 return TLO.CombineTo(Op, Op0);
2112
2113 // fold (sra (shl x, c1), c1) -> sext_inreg for some c1 and target
2114 // supports sext_inreg.
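// e.g. for i32: (sra (shl X, 24), 24) --> (sext_inreg X, i8)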
2115 if (Op0.getOpcode() == ISD::SHL) {
2116 if (std::optional<uint64_t> InnerSA =
2117 TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 2)) {
2118 unsigned LowBits = BitWidth - ShAmt;
2119 EVT ExtVT = EVT::getIntegerVT(*TLO.DAG.getContext(), LowBits);
2120 if (VT.isVector())
2121 ExtVT = EVT::getVectorVT(*TLO.DAG.getContext(), ExtVT,
2122 VT.getVectorElementCount());
2123
2124 if (*InnerSA == ShAmt) {
2125 if (!TLO.LegalOperations() ||
2126 getOperationAction(ISD::SIGN_EXTEND_INREG, ExtVT) == Legal)
2127 return TLO.CombineTo(
2128 Op, TLO.DAG.getNode(ISD::SIGN_EXTEND_INREG, dl, VT,
2129 Op0.getOperand(0),
2130 TLO.DAG.getValueType(ExtVT)));
2131
2132 // Even if we can't convert to sext_inreg, we might be able to
2133 // remove this shift pair if the input is already sign extended.
2134 unsigned NumSignBits =
2135 TLO.DAG.ComputeNumSignBits(Op0.getOperand(0), DemandedElts);
2136 if (NumSignBits > ShAmt)
2137 return TLO.CombineTo(Op, Op0.getOperand(0));
2138 }
2139 }
2140 }
2141
2142 APInt InDemandedMask = (DemandedBits << ShAmt);
2143
2144 // If the shift is exact, then it does demand the low bits (and knows that
2145 // they are zero).
2146 if (Op->getFlags().hasExact())
2147 InDemandedMask.setLowBits(ShAmt);
2148
2149 // If any of the demanded bits are produced by the sign extension, we also
2150 // demand the input sign bit.
2151 if (DemandedBits.countl_zero() < ShAmt)
2152 InDemandedMask.setSignBit();
2153
2154 if (SimplifyDemandedBits(Op0, InDemandedMask, DemandedElts, Known, TLO,
2155 Depth + 1))
2156 return true;
2157 Known.Zero.lshrInPlace(ShAmt);
2158 Known.One.lshrInPlace(ShAmt);
2159
2160 // If the input sign bit is known to be zero, or if none of the top bits
2161 // are demanded, turn this into an unsigned shift right.
2162 if (Known.Zero[BitWidth - ShAmt - 1] ||
2163 DemandedBits.countl_zero() >= ShAmt) {
2164 SDNodeFlags Flags;
2165 Flags.setExact(Op->getFlags().hasExact());
2166 return TLO.CombineTo(
2167 Op, TLO.DAG.getNode(ISD::SRL, dl, VT, Op0, Op1, Flags));
2168 }
2169
2170 int Log2 = DemandedBits.exactLogBase2();
2171 if (Log2 >= 0) {
2172 // The bit must come from the sign.
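// e.g. for i32, if only bit 30 is demanded (so ShAmt >= 2 here), that bit is
// a copy of the sign bit: (sra X, ShAmt) --> (srl X, 1)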
2173 SDValue NewSA = TLO.DAG.getConstant(BitWidth - 1 - Log2, dl, ShiftVT);
2174 return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl, VT, Op0, NewSA));
2175 }
2176
2177 if (Known.One[BitWidth - ShAmt - 1])
2178 // New bits are known one.
2179 Known.One.setHighBits(ShAmt);
2180
2181 // Attempt to avoid multi-use ops if we don't need anything from them.
2182 if (!InDemandedMask.isAllOnes() || !DemandedElts.isAllOnes()) {
2183 SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
2184 Op0, InDemandedMask, DemandedElts, TLO.DAG, Depth + 1);
2185 if (DemandedOp0) {
2186 SDValue NewOp = TLO.DAG.getNode(ISD::SRA, dl, VT, DemandedOp0, Op1);
2187 return TLO.CombineTo(Op, NewOp);
2188 }
2189 }
2190 }
2191
2192 // Try to match AVG patterns (after shift simplification).
2193 if (SDValue AVG = combineShiftToAVG(Op, TLO, *this, DemandedBits,
2194 DemandedElts, Depth + 1))
2195 return TLO.CombineTo(Op, AVG);
2196
2197 break;
2198 }
2199 case ISD::FSHL:
2200 case ISD::FSHR: {
2201 SDValue Op0 = Op.getOperand(0);
2202 SDValue Op1 = Op.getOperand(1);
2203 SDValue Op2 = Op.getOperand(2);
2204 bool IsFSHL = (Op.getOpcode() == ISD::FSHL);
2205
2206 if (ConstantSDNode *SA = isConstOrConstSplat(Op2, DemandedElts)) {
2207 unsigned Amt = SA->getAPIntValue().urem(BitWidth);
2208
2209 // For fshl, 0-shift returns the 1st arg.
2210 // For fshr, 0-shift returns the 2nd arg.
2211 if (Amt == 0) {
2212 if (SimplifyDemandedBits(IsFSHL ? Op0 : Op1, DemandedBits, DemandedElts,
2213 Known, TLO, Depth + 1))
2214 return true;
2215 break;
2216 }
2217
2218 // fshl: (Op0 << Amt) | (Op1 >> (BW - Amt))
2219 // fshr: (Op0 << (BW - Amt)) | (Op1 >> Amt)
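// e.g. for i32 fshl with Amt == 8: result bits [8,32) come from Op0 bits
// [0,24) and result bits [0,8) come from Op1 bits [24,32).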
2220 APInt Demanded0 = DemandedBits.lshr(IsFSHL ? Amt : (BitWidth - Amt));
2221 APInt Demanded1 = DemandedBits << (IsFSHL ? (BitWidth - Amt) : Amt);
2222 if (SimplifyDemandedBits(Op0, Demanded0, DemandedElts, Known2, TLO,
2223 Depth + 1))
2224 return true;
2225 if (SimplifyDemandedBits(Op1, Demanded1, DemandedElts, Known, TLO,
2226 Depth + 1))
2227 return true;
2228
2229 Known2.One <<= (IsFSHL ? Amt : (BitWidth - Amt));
2230 Known2.Zero <<= (IsFSHL ? Amt : (BitWidth - Amt));
2231 Known.One.lshrInPlace(IsFSHL ? (BitWidth - Amt) : Amt);
2232 Known.Zero.lshrInPlace(IsFSHL ? (BitWidth - Amt) : Amt);
2233 Known = Known.unionWith(Known2);
2234
2235 // Attempt to avoid multi-use ops if we don't need anything from them.
2236 if (!Demanded0.isAllOnes() || !Demanded1.isAllOnes() ||
2237 !DemandedElts.isAllOnes()) {
2238 SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
2239 Op0, Demanded0, DemandedElts, TLO.DAG, Depth + 1);
2240 SDValue DemandedOp1 = SimplifyMultipleUseDemandedBits(
2241 Op1, Demanded1, DemandedElts, TLO.DAG, Depth + 1);
2242 if (DemandedOp0 || DemandedOp1) {
2243 DemandedOp0 = DemandedOp0 ? DemandedOp0 : Op0;
2244 DemandedOp1 = DemandedOp1 ? DemandedOp1 : Op1;
2245 SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), dl, VT, DemandedOp0,
2246 DemandedOp1, Op2);
2247 return TLO.CombineTo(Op, NewOp);
2248 }
2249 }
2250 }
2251
2252 // For pow-2 bitwidths we only demand the bottom modulo amt bits.
2253 if (isPowerOf2_32(BitWidth)) {
2254 APInt DemandedAmtBits(Op2.getScalarValueSizeInBits(), BitWidth - 1);
2255 if (SimplifyDemandedBits(Op2, DemandedAmtBits, DemandedElts,
2256 Known2, TLO, Depth + 1))
2257 return true;
2258 }
2259 break;
2260 }
2261 case ISD::ROTL:
2262 case ISD::ROTR: {
2263 SDValue Op0 = Op.getOperand(0);
2264 SDValue Op1 = Op.getOperand(1);
2265 bool IsROTL = (Op.getOpcode() == ISD::ROTL);
2266
2267 // If we're rotating a 0/-1 value, then it stays a 0/-1 value.
2268 if (BitWidth == TLO.DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1))
2269 return TLO.CombineTo(Op, Op0);
2270
2271 if (ConstantSDNode *SA = isConstOrConstSplat(Op1, DemandedElts)) {
2272 unsigned Amt = SA->getAPIntValue().urem(BitWidth);
2273 unsigned RevAmt = BitWidth - Amt;
2274
2275 // rotl: (Op0 << Amt) | (Op0 >> (BW - Amt))
2276 // rotr: (Op0 << (BW - Amt)) | (Op0 >> Amt)
2277 APInt Demanded0 = DemandedBits.rotr(IsROTL ? Amt : RevAmt);
2278 if (SimplifyDemandedBits(Op0, Demanded0, DemandedElts, Known2, TLO,
2279 Depth + 1))
2280 return true;
2281
2282 // rot*(x, 0) --> x
2283 if (Amt == 0)
2284 return TLO.CombineTo(Op, Op0);
2285
2286 // See if we don't demand either half of the rotated bits.
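// e.g. for i32, if the low 8 bits are not demanded:
// (rotl X, 8) --> (shl X, 8)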
2287 if ((!TLO.LegalOperations() || isOperationLegal(ISD::SHL, VT)) &&
2288 DemandedBits.countr_zero() >= (IsROTL ? Amt : RevAmt)) {
2289 Op1 = TLO.DAG.getConstant(IsROTL ? Amt : RevAmt, dl, Op1.getValueType());
2290 return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SHL, dl, VT, Op0, Op1));
2291 }
2292 if ((!TLO.LegalOperations() || isOperationLegal(ISD::SRL, VT)) &&
2293 DemandedBits.countl_zero() >= (IsROTL ? RevAmt : Amt)) {
2294 Op1 = TLO.DAG.getConstant(IsROTL ? RevAmt : Amt, dl, Op1.getValueType());
2295 return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl, VT, Op0, Op1));
2296 }
2297 }
2298
2299 // For pow-2 bitwidths we only demand the bottom modulo amt bits.
2300 if (isPowerOf2_32(BitWidth)) {
2301 APInt DemandedAmtBits(Op1.getScalarValueSizeInBits(), BitWidth - 1);
2302 if (SimplifyDemandedBits(Op1, DemandedAmtBits, DemandedElts, Known2, TLO,
2303 Depth + 1))
2304 return true;
2305 }
2306 break;
2307 }
2308 case ISD::SMIN:
2309 case ISD::SMAX:
2310 case ISD::UMIN:
2311 case ISD::UMAX: {
2312 unsigned Opc = Op.getOpcode();
2313 SDValue Op0 = Op.getOperand(0);
2314 SDValue Op1 = Op.getOperand(1);
2315
2316 // If we're only demanding signbits, then we can simplify to OR/AND node.
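// e.g. for operands known to be all-zeros or all-ones (0 or -1):
// smin/umax behave like OR and smax/umin behave like AND.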
2317 unsigned BitOp =
2318 (Opc == ISD::SMIN || Opc == ISD::UMAX) ? ISD::OR : ISD::AND;
2319 unsigned NumSignBits =
2320 std::min(TLO.DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1),
2321 TLO.DAG.ComputeNumSignBits(Op1, DemandedElts, Depth + 1));
2322 unsigned NumDemandedUpperBits = BitWidth - DemandedBits.countr_zero();
2323 if (NumSignBits >= NumDemandedUpperBits)
2324 return TLO.CombineTo(Op, TLO.DAG.getNode(BitOp, SDLoc(Op), VT, Op0, Op1));
2325
2326 // Check if one arg is always less/greater than (or equal) to the other arg.
2327 KnownBits Known0 = TLO.DAG.computeKnownBits(Op0, DemandedElts, Depth + 1);
2328 KnownBits Known1 = TLO.DAG.computeKnownBits(Op1, DemandedElts, Depth + 1);
2329 switch (Opc) {
2330 case ISD::SMIN:
2331 if (std::optional<bool> IsSLE = KnownBits::sle(Known0, Known1))
2332 return TLO.CombineTo(Op, *IsSLE ? Op0 : Op1);
2333 if (std::optional<bool> IsSLT = KnownBits::slt(Known0, Known1))
2334 return TLO.CombineTo(Op, *IsSLT ? Op0 : Op1);
2335 Known = KnownBits::smin(Known0, Known1);
2336 break;
2337 case ISD::SMAX:
2338 if (std::optional<bool> IsSGE = KnownBits::sge(Known0, Known1))
2339 return TLO.CombineTo(Op, *IsSGE ? Op0 : Op1);
2340 if (std::optional<bool> IsSGT = KnownBits::sgt(Known0, Known1))
2341 return TLO.CombineTo(Op, *IsSGT ? Op0 : Op1);
2342 Known = KnownBits::smax(Known0, Known1);
2343 break;
2344 case ISD::UMIN:
2345 if (std::optional<bool> IsULE = KnownBits::ule(Known0, Known1))
2346 return TLO.CombineTo(Op, *IsULE ? Op0 : Op1);
2347 if (std::optional<bool> IsULT = KnownBits::ult(Known0, Known1))
2348 return TLO.CombineTo(Op, *IsULT ? Op0 : Op1);
2349 Known = KnownBits::umin(Known0, Known1);
2350 break;
2351 case ISD::UMAX:
2352 if (std::optional<bool> IsUGE = KnownBits::uge(Known0, Known1))
2353 return TLO.CombineTo(Op, *IsUGE ? Op0 : Op1);
2354 if (std::optional<bool> IsUGT = KnownBits::ugt(Known0, Known1))
2355 return TLO.CombineTo(Op, *IsUGT ? Op0 : Op1);
2356 Known = KnownBits::umax(Known0, Known1);
2357 break;
2358 }
2359 break;
2360 }
2361 case ISD::BITREVERSE: {
2362 SDValue Src = Op.getOperand(0);
2363 APInt DemandedSrcBits = DemandedBits.reverseBits();
2364 if (SimplifyDemandedBits(Src, DemandedSrcBits, DemandedElts, Known2, TLO,
2365 Depth + 1))
2366 return true;
2367 Known.One = Known2.One.reverseBits();
2368 Known.Zero = Known2.Zero.reverseBits();
2369 break;
2370 }
2371 case ISD::BSWAP: {
2372 SDValue Src = Op.getOperand(0);
2373
2374 // If the only bits demanded come from one byte of the bswap result,
2375 // just shift the input byte into position to eliminate the bswap.
2376 unsigned NLZ = DemandedBits.countl_zero();
2377 unsigned NTZ = DemandedBits.countr_zero();
2378
2379 // Round NTZ down to the next byte. If we have 11 trailing zeros, then
2380 // we need all the bits down to bit 8. Likewise, round NLZ. If we
2381 // have 14 leading zeros, round to 8.
2382 NLZ = alignDown(NLZ, 8);
2383 NTZ = alignDown(NTZ, 8);
2384 // If we need exactly one byte, we can do this transformation.
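// e.g. for i32, if only bits 8..15 are demanded: (bswap X) --> (srl X, 8)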
2385 if (BitWidth - NLZ - NTZ == 8) {
2386 // Replace this with either a left or right shift to get the byte into
2387 // the right place.
2388 unsigned ShiftOpcode = NLZ > NTZ ? ISD::SRL : ISD::SHL;
2389 if (!TLO.LegalOperations() || isOperationLegal(ShiftOpcode, VT)) {
2390 unsigned ShiftAmount = NLZ > NTZ ? NLZ - NTZ : NTZ - NLZ;
2391 SDValue ShAmt = TLO.DAG.getShiftAmountConstant(ShiftAmount, VT, dl);
2392 SDValue NewOp = TLO.DAG.getNode(ShiftOpcode, dl, VT, Src, ShAmt);
2393 return TLO.CombineTo(Op, NewOp);
2394 }
2395 }
2396
2397 APInt DemandedSrcBits = DemandedBits.byteSwap();
2398 if (SimplifyDemandedBits(Src, DemandedSrcBits, DemandedElts, Known2, TLO,
2399 Depth + 1))
2400 return true;
2401 Known.One = Known2.One.byteSwap();
2402 Known.Zero = Known2.Zero.byteSwap();
2403 break;
2404 }
2405 case ISD::CTPOP: {
2406 // If only 1 bit is demanded, replace with PARITY as long as we're before
2407 // op legalization.
2408 // FIXME: Limit to scalars for now.
2409 if (DemandedBits.isOne() && !TLO.LegalOps && !VT.isVector())
2410 return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::PARITY, dl, VT,
2411 Op.getOperand(0)));
2412
2413 Known = TLO.DAG.computeKnownBits(Op, DemandedElts, Depth);
2414 break;
2415 }
2416 case ISD::SIGN_EXTEND_INREG: {
2417 SDValue Op0 = Op.getOperand(0);
2418 EVT ExVT = cast<VTSDNode>(Op.getOperand(1))->getVT();
2419 unsigned ExVTBits = ExVT.getScalarSizeInBits();
2420
2421 // If we only care about the highest bit, don't bother shifting right.
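// e.g. for i32 with ExVT == i8, if only the sign bit is demanded (and the
// input is not already sign extended): (sext_inreg X, i8) --> (shl X, 24)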
2422 if (DemandedBits.isSignMask()) {
2423 unsigned MinSignedBits =
2424 TLO.DAG.ComputeMaxSignificantBits(Op0, DemandedElts, Depth + 1);
2425 bool AlreadySignExtended = ExVTBits >= MinSignedBits;
2426 // However if the input is already sign extended we expect the sign
2427 // extension to be dropped altogether later and do not simplify.
2428 if (!AlreadySignExtended) {
2429 // Compute the correct shift amount type, which must be getShiftAmountTy
2430 // for scalar types after legalization.
2431 SDValue ShiftAmt =
2432 TLO.DAG.getShiftAmountConstant(BitWidth - ExVTBits, VT, dl);
2433 return TLO.CombineTo(Op,
2434 TLO.DAG.getNode(ISD::SHL, dl, VT, Op0, ShiftAmt));
2435 }
2436 }
2437
2438 // If none of the extended bits are demanded, eliminate the sextinreg.
2439 if (DemandedBits.getActiveBits() <= ExVTBits)
2440 return TLO.CombineTo(Op, Op0);
2441
2442 APInt InputDemandedBits = DemandedBits.getLoBits(ExVTBits);
2443
2444 // Since the sign extended bits are demanded, we know that the sign
2445 // bit is demanded.
2446 InputDemandedBits.setBit(ExVTBits - 1);
2447
2448 if (SimplifyDemandedBits(Op0, InputDemandedBits, DemandedElts, Known, TLO,
2449 Depth + 1))
2450 return true;
2451
2452 // If the sign bit of the input is known set or clear, then we know the
2453 // top bits of the result.
2454
2455 // If the input sign bit is known zero, convert this into a zero extension.
2456 if (Known.Zero[ExVTBits - 1])
2457 return TLO.CombineTo(Op, TLO.DAG.getZeroExtendInReg(Op0, dl, ExVT));
2458
2459 APInt Mask = APInt::getLowBitsSet(BitWidth, ExVTBits);
2460 if (Known.One[ExVTBits - 1]) { // Input sign bit known set
2461 Known.One.setBitsFrom(ExVTBits);
2462 Known.Zero &= Mask;
2463 } else { // Input sign bit unknown
2464 Known.Zero &= Mask;
2465 Known.One &= Mask;
2466 }
2467 break;
2468 }
2469 case ISD::BUILD_PAIR: {
2470 EVT HalfVT = Op.getOperand(0).getValueType();
2471 unsigned HalfBitWidth = HalfVT.getScalarSizeInBits();
2472
2473 APInt MaskLo = DemandedBits.getLoBits(HalfBitWidth).trunc(HalfBitWidth);
2474 APInt MaskHi = DemandedBits.getHiBits(HalfBitWidth).trunc(HalfBitWidth);
2475
2476 KnownBits KnownLo, KnownHi;
2477
2478 if (SimplifyDemandedBits(Op.getOperand(0), MaskLo, KnownLo, TLO, Depth + 1))
2479 return true;
2480
2481 if (SimplifyDemandedBits(Op.getOperand(1), MaskHi, KnownHi, TLO, Depth + 1))
2482 return true;
2483
2484 Known = KnownHi.concat(KnownLo);
2485 break;
2486 }
2487 case ISD::ZERO_EXTEND_VECTOR_INREG:
2488 if (VT.isScalableVector())
2489 return false;
2490 [[fallthrough]];
2491 case ISD::ZERO_EXTEND: {
2492 SDValue Src = Op.getOperand(0);
2493 EVT SrcVT = Src.getValueType();
2494 unsigned InBits = SrcVT.getScalarSizeInBits();
2495 unsigned InElts = SrcVT.isFixedLengthVector() ? SrcVT.getVectorNumElements() : 1;
2496 bool IsVecInReg = Op.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG;
2497
2498 // If none of the top bits are demanded, convert this into an any_extend.
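// e.g. if only the low 8 bits are demanded:
// (zext i8 X to i32) --> (anyext i8 X to i32)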
2499 if (DemandedBits.getActiveBits() <= InBits) {
2500 // If we only need the non-extended bits of the bottom element
2501 // then we can just bitcast to the result.
2502 if (IsLE && IsVecInReg && DemandedElts == 1 &&
2503 VT.getSizeInBits() == SrcVT.getSizeInBits())
2504 return TLO.CombineTo(Op, TLO.DAG.getBitcast(VT, Src));
2505
2506 unsigned Opc =
2507 IsVecInReg ? ISD::ANY_EXTEND_VECTOR_INREG : ISD::ANY_EXTEND;
2508 if (!TLO.LegalOperations() || isOperationLegal(Opc, VT))
2509 return TLO.CombineTo(Op, TLO.DAG.getNode(Opc, dl, VT, Src));
2510 }
2511
2512 APInt InDemandedBits = DemandedBits.trunc(InBits);
2513 APInt InDemandedElts = DemandedElts.zext(InElts);
2514 if (SimplifyDemandedBits(Src, InDemandedBits, InDemandedElts, Known, TLO,
2515 Depth + 1)) {
2516 Op->dropFlags(SDNodeFlags::NonNeg);
2517 return true;
2518 }
2519 assert(Known.getBitWidth() == InBits && "Src width has changed?");
2520 Known = Known.zext(BitWidth);
2521
2522 // Attempt to avoid multi-use ops if we don't need anything from them.
2523 if (SDValue NewSrc = SimplifyMultipleUseDemandedBits(
2524 Src, InDemandedBits, InDemandedElts, TLO.DAG, Depth + 1))
2525 return TLO.CombineTo(Op, TLO.DAG.getNode(Op.getOpcode(), dl, VT, NewSrc));
2526 break;
2527 }
2528 case ISD::SIGN_EXTEND_VECTOR_INREG:
2529 if (VT.isScalableVector())
2530 return false;
2531 [[fallthrough]];
2532 case ISD::SIGN_EXTEND: {
2533 SDValue Src = Op.getOperand(0);
2534 EVT SrcVT = Src.getValueType();
2535 unsigned InBits = SrcVT.getScalarSizeInBits();
2536 unsigned InElts = SrcVT.isFixedLengthVector() ? SrcVT.getVectorNumElements() : 1;
2537 bool IsVecInReg = Op.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG;
2538
2539 APInt InDemandedElts = DemandedElts.zext(InElts);
2540 APInt InDemandedBits = DemandedBits.trunc(InBits);
2541
2542 // Since some of the sign extended bits are demanded, we know that the sign
2543 // bit is demanded.
2544 InDemandedBits.setBit(InBits - 1);
2545
2546 // If none of the top bits are demanded, convert this into an any_extend.
2547 if (DemandedBits.getActiveBits() <= InBits) {
2548 // If we only need the non-extended bits of the bottom element
2549 // then we can just bitcast to the result.
2550 if (IsLE && IsVecInReg && DemandedElts == 1 &&
2551 VT.getSizeInBits() == SrcVT.getSizeInBits())
2552 return TLO.CombineTo(Op, TLO.DAG.getBitcast(VT, Src));
2553
2554 // Don't lose an all signbits 0/-1 splat on targets with 0/-1 booleans.
2555 if (getBooleanContents(VT) != ZeroOrNegativeOneBooleanContent ||
2556 TLO.DAG.ComputeNumSignBits(Src, InDemandedElts, Depth + 1) !=
2557 InBits) {
2558 unsigned Opc =
2559 IsVecInReg ? ISD::ANY_EXTEND_VECTOR_INREG : ISD::ANY_EXTEND;
2560 if (!TLO.LegalOperations() || isOperationLegal(Opc, VT))
2561 return TLO.CombineTo(Op, TLO.DAG.getNode(Opc, dl, VT, Src));
2562 }
2563 }
2564
2565 if (SimplifyDemandedBits(Src, InDemandedBits, InDemandedElts, Known, TLO,
2566 Depth + 1))
2567 return true;
2568 assert(Known.getBitWidth() == InBits && "Src width has changed?");
2569
2570 // If the sign bit is known one, the top bits match.
2571 Known = Known.sext(BitWidth);
2572
2573 // If the sign bit is known zero, convert this to a zero extend.
2574 if (Known.isNonNegative()) {
2575 unsigned Opc =
2576 IsVecInReg ? ISD::ZERO_EXTEND_VECTOR_INREG : ISD::ZERO_EXTEND;
2577 if (!TLO.LegalOperations() || isOperationLegal(Opc, VT)) {
2578 SDNodeFlags Flags;
2579 if (!IsVecInReg)
2580 Flags |= SDNodeFlags::NonNeg;
2581 return TLO.CombineTo(Op, TLO.DAG.getNode(Opc, dl, VT, Src, Flags));
2582 }
2583 }
2584
2585 // Attempt to avoid multi-use ops if we don't need anything from them.
2586 if (SDValue NewSrc = SimplifyMultipleUseDemandedBits(
2587 Src, InDemandedBits, InDemandedElts, TLO.DAG, Depth + 1))
2588 return TLO.CombineTo(Op, TLO.DAG.getNode(Op.getOpcode(), dl, VT, NewSrc));
2589 break;
2590 }
2591 case ISD::ANY_EXTEND_VECTOR_INREG:
2592 if (VT.isScalableVector())
2593 return false;
2594 [[fallthrough]];
2595 case ISD::ANY_EXTEND: {
2596 SDValue Src = Op.getOperand(0);
2597 EVT SrcVT = Src.getValueType();
2598 unsigned InBits = SrcVT.getScalarSizeInBits();
2599 unsigned InElts = SrcVT.isFixedLengthVector() ? SrcVT.getVectorNumElements() : 1;
2600 bool IsVecInReg = Op.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG;
2601
2602 // If we only need the bottom element then we can just bitcast.
2603 // TODO: Handle ANY_EXTEND?
2604 if (IsLE && IsVecInReg && DemandedElts == 1 &&
2605 VT.getSizeInBits() == SrcVT.getSizeInBits())
2606 return TLO.CombineTo(Op, TLO.DAG.getBitcast(VT, Src));
2607
2608 APInt InDemandedBits = DemandedBits.trunc(InBits);
2609 APInt InDemandedElts = DemandedElts.zext(InElts);
2610 if (SimplifyDemandedBits(Src, InDemandedBits, InDemandedElts, Known, TLO,
2611 Depth + 1))
2612 return true;
2613 assert(Known.getBitWidth() == InBits && "Src width has changed?");
2614 Known = Known.anyext(BitWidth);
2615
2616 // Attempt to avoid multi-use ops if we don't need anything from them.
2617 if (SDValue NewSrc = SimplifyMultipleUseDemandedBits(
2618 Src, InDemandedBits, InDemandedElts, TLO.DAG, Depth + 1))
2619 return TLO.CombineTo(Op, TLO.DAG.getNode(Op.getOpcode(), dl, VT, NewSrc));
2620 break;
2621 }
2622 case ISD::TRUNCATE: {
2623 SDValue Src = Op.getOperand(0);
2624
2625 // Simplify the input, using demanded bit information, and compute the known
2626 // zero/one bits live out.
2627 unsigned OperandBitWidth = Src.getScalarValueSizeInBits();
2628 APInt TruncMask = DemandedBits.zext(OperandBitWidth);
2629 if (SimplifyDemandedBits(Src, TruncMask, DemandedElts, Known, TLO,
2630 Depth + 1)) {
2631 // Disable the nsw and nuw flags. We can no longer guarantee that we
2632 // won't wrap after simplification.
2633 Op->dropFlags(SDNodeFlags::NoWrap);
2634 return true;
2635 }
2636 Known = Known.trunc(BitWidth);
2637
2638 // Attempt to avoid multi-use ops if we don't need anything from them.
2639 if (SDValue NewSrc = SimplifyMultipleUseDemandedBits(
2640 Src, TruncMask, DemandedElts, TLO.DAG, Depth + 1))
2641 return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::TRUNCATE, dl, VT, NewSrc));
2642
2643 // If the input is only used by this truncate, see if we can shrink it based
2644 // on the known demanded bits.
2645 switch (Src.getOpcode()) {
2646 default:
2647 break;
2648 case ISD::SRL:
2649 // Shrink SRL by a constant if none of the high bits shifted in are
2650 // demanded.
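// e.g. if only the low 16 result bits are demanded:
// (i32 trunc (srl i64 X, 16)) --> (srl (i32 trunc X), 16)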
2651 if (TLO.LegalTypes() && !isTypeDesirableForOp(ISD::SRL, VT))
2652 // Do not turn (vt1 truncate (vt2 srl)) into (vt1 srl) if vt1 is
2653 // undesirable.
2654 break;
2655
2656 if (Src.getNode()->hasOneUse()) {
2657 if (isTruncateFree(Src, VT) &&
2658 !isTruncateFree(Src.getValueType(), VT)) {
2659 // If the truncate is only free as trunc(srl), do not turn it into
2660 // srl(trunc). We check this by first verifying that the truncate is
2661 // free at Src's opcode (srl), and then that the truncate is not done
2662 // by referencing a sub-register. In testing, if both trunc(srl) and
2663 // srl(trunc) have a free truncate, srl(trunc) performs better; if only
2664 // trunc(srl)'s truncate is free, trunc(srl) is better.
2665 break;
2666 }
2667
2668 std::optional<uint64_t> ShAmtC =
2669 TLO.DAG.getValidShiftAmount(Src, DemandedElts, Depth + 2);
2670 if (!ShAmtC || *ShAmtC >= BitWidth)
2671 break;
2672 uint64_t ShVal = *ShAmtC;
2673
2674 APInt HighBits =
2675 APInt::getHighBitsSet(OperandBitWidth, OperandBitWidth - BitWidth);
2676 HighBits.lshrInPlace(ShVal);
2677 HighBits = HighBits.trunc(BitWidth);
2678 if (!(HighBits & DemandedBits)) {
2679 // None of the shifted in bits are needed. Add a truncate of the
2680 // shift input, then shift it.
2681 SDValue NewShAmt = TLO.DAG.getShiftAmountConstant(ShVal, VT, dl);
2682 SDValue NewTrunc =
2683 TLO.DAG.getNode(ISD::TRUNCATE, dl, VT, Src.getOperand(0));
2684 return TLO.CombineTo(
2685 Op, TLO.DAG.getNode(ISD::SRL, dl, VT, NewTrunc, NewShAmt));
2686 }
2687 }
2688 break;
2689 }
2690
2691 break;
2692 }
2693 case ISD::AssertZext: {
2694 // AssertZext demands all of the high bits, plus any of the low bits
2695 // demanded by its users.
2696 EVT ZVT = cast<VTSDNode>(Op.getOperand(1))->getVT();
2697 APInt InMask = APInt::getLowBitsSet(BitWidth, ZVT.getSizeInBits());
2698 if (SimplifyDemandedBits(Op.getOperand(0), ~InMask | DemandedBits, Known,
2699 TLO, Depth + 1))
2700 return true;
2701
2702 Known.Zero |= ~InMask;
2703 Known.One &= (~Known.Zero);
2704 break;
2705 }
2706 case ISD::EXTRACT_VECTOR_ELT: {
2707 SDValue Src = Op.getOperand(0);
2708 SDValue Idx = Op.getOperand(1);
2709 ElementCount SrcEltCnt = Src.getValueType().getVectorElementCount();
2710 unsigned EltBitWidth = Src.getScalarValueSizeInBits();
2711
2712 if (SrcEltCnt.isScalable())
2713 return false;
2714
2715 // Demand the bits from every vector element without a constant index.
2716 unsigned NumSrcElts = SrcEltCnt.getFixedValue();
2717 APInt DemandedSrcElts = APInt::getAllOnes(NumSrcElts);
2718 if (auto *CIdx = dyn_cast<ConstantSDNode>(Idx))
2719 if (CIdx->getAPIntValue().ult(NumSrcElts))
2720 DemandedSrcElts = APInt::getOneBitSet(NumSrcElts, CIdx->getZExtValue());
2721
2722 // If BitWidth > EltBitWidth the value is any-extended, so we do not know
2723 // anything about the extended bits.
2724 APInt DemandedSrcBits = DemandedBits;
2725 if (BitWidth > EltBitWidth)
2726 DemandedSrcBits = DemandedSrcBits.trunc(EltBitWidth);
2727
2728 if (SimplifyDemandedBits(Src, DemandedSrcBits, DemandedSrcElts, Known2, TLO,
2729 Depth + 1))
2730 return true;
2731
2732 // Attempt to avoid multi-use ops if we don't need anything from them.
2733 if (!DemandedSrcBits.isAllOnes() || !DemandedSrcElts.isAllOnes()) {
2734 if (SDValue DemandedSrc = SimplifyMultipleUseDemandedBits(
2735 Src, DemandedSrcBits, DemandedSrcElts, TLO.DAG, Depth + 1)) {
2736 SDValue NewOp =
2737 TLO.DAG.getNode(Op.getOpcode(), dl, VT, DemandedSrc, Idx);
2738 return TLO.CombineTo(Op, NewOp);
2739 }
2740 }
2741
2742 Known = Known2;
2743 if (BitWidth > EltBitWidth)
2744 Known = Known.anyext(BitWidth);
2745 break;
2746 }
2747 case ISD::BITCAST: {
2748 if (VT.isScalableVector())
2749 return false;
2750 SDValue Src = Op.getOperand(0);
2751 EVT SrcVT = Src.getValueType();
2752 unsigned NumSrcEltBits = SrcVT.getScalarSizeInBits();
2753
2754 // If this is an FP->Int bitcast and if the sign bit is the only
2755 // thing demanded, turn this into a FGETSIGN.
2756 if (!TLO.LegalOperations() && !VT.isVector() && !SrcVT.isVector() &&
2757 DemandedBits == APInt::getSignMask(Op.getValueSizeInBits()) &&
2758 SrcVT.isFloatingPoint()) {
2759 bool OpVTLegal = isOperationLegalOrCustom(ISD::FGETSIGN, VT);
2760 bool i32Legal = isOperationLegalOrCustom(ISD::FGETSIGN, MVT::i32);
2761 if ((OpVTLegal || i32Legal) && VT.isSimple() && SrcVT != MVT::f16 &&
2762 SrcVT != MVT::f128) {
2763 // Cannot eliminate/lower SHL for f128 yet.
2764 EVT Ty = OpVTLegal ? VT : MVT::i32;
2765 // Make a FGETSIGN + SHL to move the sign bit into the appropriate
2766 // place. We expect the SHL to be eliminated by other optimizations.
2767 SDValue Sign = TLO.DAG.getNode(ISD::FGETSIGN, dl, Ty, Src);
2768 unsigned OpVTSizeInBits = Op.getValueSizeInBits();
2769 if (!OpVTLegal && OpVTSizeInBits > 32)
2770 Sign = TLO.DAG.getNode(ISD::ZERO_EXTEND, dl, VT, Sign);
2771 unsigned ShVal = Op.getValueSizeInBits() - 1;
2772 SDValue ShAmt = TLO.DAG.getConstant(ShVal, dl, VT);
2773 return TLO.CombineTo(Op,
2774 TLO.DAG.getNode(ISD::SHL, dl, VT, Sign, ShAmt));
2775 }
2776 }
2777
2778 // Bitcast from a vector using SimplifyDemandedBits/SimplifyDemandedVectorElts.
2779 // Demand the elt/bit if any of the original elts/bits are demanded.
2780 if (SrcVT.isVector() && (BitWidth % NumSrcEltBits) == 0) {
2781 unsigned Scale = BitWidth / NumSrcEltBits;
2782 unsigned NumSrcElts = SrcVT.getVectorNumElements();
2783 APInt DemandedSrcBits = APInt::getZero(NumSrcEltBits);
2784 APInt DemandedSrcElts = APInt::getZero(NumSrcElts);
2785 for (unsigned i = 0; i != Scale; ++i) {
2786 unsigned EltOffset = IsLE ? i : (Scale - 1 - i);
2787 unsigned BitOffset = EltOffset * NumSrcEltBits;
2788 APInt Sub = DemandedBits.extractBits(NumSrcEltBits, BitOffset);
2789 if (!Sub.isZero()) {
2790 DemandedSrcBits |= Sub;
2791 for (unsigned j = 0; j != NumElts; ++j)
2792 if (DemandedElts[j])
2793 DemandedSrcElts.setBit((j * Scale) + i);
2794 }
2795 }
2796
2797 APInt KnownSrcUndef, KnownSrcZero;
2798 if (SimplifyDemandedVectorElts(Src, DemandedSrcElts, KnownSrcUndef,
2799 KnownSrcZero, TLO, Depth + 1))
2800 return true;
2801
2802 KnownBits KnownSrcBits;
2803 if (SimplifyDemandedBits(Src, DemandedSrcBits, DemandedSrcElts,
2804 KnownSrcBits, TLO, Depth + 1))
2805 return true;
2806 } else if (IsLE && (NumSrcEltBits % BitWidth) == 0) {
2807 // TODO - bigendian once we have test coverage.
2808 unsigned Scale = NumSrcEltBits / BitWidth;
2809 unsigned NumSrcElts = SrcVT.isVector() ? SrcVT.getVectorNumElements() : 1;
2810 APInt DemandedSrcBits = APInt::getZero(NumSrcEltBits);
2811 APInt DemandedSrcElts = APInt::getZero(NumSrcElts);
2812 for (unsigned i = 0; i != NumElts; ++i)
2813 if (DemandedElts[i]) {
2814 unsigned Offset = (i % Scale) * BitWidth;
2815 DemandedSrcBits.insertBits(DemandedBits, Offset);
2816 DemandedSrcElts.setBit(i / Scale);
2817 }
2818
2819 if (SrcVT.isVector()) {
2820 APInt KnownSrcUndef, KnownSrcZero;
2821 if (SimplifyDemandedVectorElts(Src, DemandedSrcElts, KnownSrcUndef,
2822 KnownSrcZero, TLO, Depth + 1))
2823 return true;
2824 }
2825
2826 KnownBits KnownSrcBits;
2827 if (SimplifyDemandedBits(Src, DemandedSrcBits, DemandedSrcElts,
2828 KnownSrcBits, TLO, Depth + 1))
2829 return true;
2830
2831 // Attempt to avoid multi-use ops if we don't need anything from them.
2832 if (!DemandedSrcBits.isAllOnes() || !DemandedSrcElts.isAllOnes()) {
2833 if (SDValue DemandedSrc = SimplifyMultipleUseDemandedBits(
2834 Src, DemandedSrcBits, DemandedSrcElts, TLO.DAG, Depth + 1)) {
2835 SDValue NewOp = TLO.DAG.getBitcast(VT, DemandedSrc);
2836 return TLO.CombineTo(Op, NewOp);
2837 }
2838 }
2839 }
2840
2841 // If this is a bitcast, let computeKnownBits handle it. Only do this on a
2842 // recursive call where Known may be useful to the caller.
2843 if (Depth > 0) {
2844 Known = TLO.DAG.computeKnownBits(Op, DemandedElts, Depth);
2845 return false;
2846 }
2847 break;
2848 }
2849 case ISD::MUL:
2850 if (DemandedBits.isPowerOf2()) {
2851 // The LSB of X*Y is set only if (X & 1) == 1 and (Y & 1) == 1.
2852 // If we demand exactly one bit N and we have "X * (C' << N)" where C' is
2853 // odd (has LSB set), then the left-shifted low bit of X is the answer.
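// e.g. if only bit 3 is demanded: (mul X, 24) --> (shl X, 3), because
// 24 == 3 << 3 has exactly 3 trailing zeros and 3 is odd.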
2854 unsigned CTZ = DemandedBits.countr_zero();
2855 ConstantSDNode *C = isConstOrConstSplat(Op.getOperand(1), DemandedElts);
2856 if (C && C->getAPIntValue().countr_zero() == CTZ) {
2857 SDValue AmtC = TLO.DAG.getShiftAmountConstant(CTZ, VT, dl);
2858 SDValue Shl = TLO.DAG.getNode(ISD::SHL, dl, VT, Op.getOperand(0), AmtC);
2859 return TLO.CombineTo(Op, Shl);
2860 }
2861 }
2862 // For a squared value "X * X", the bottom 2 bits are 0 and X[0] because:
2863 // X * X is odd iff X is odd.
2864 // 'Quadratic Reciprocity': X * X -> 0 for bit[1]
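// For example, X == 3 gives X * X == 9 == 0b1001: bit 0 matches X's bit 0
// and bit 1 is always clear, so "X & 1" covers any demand of bits [1:0].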
2865 if (Op.getOperand(0) == Op.getOperand(1) && DemandedBits.ult(4)) {
2866 SDValue One = TLO.DAG.getConstant(1, dl, VT);
2867 SDValue And1 = TLO.DAG.getNode(ISD::AND, dl, VT, Op.getOperand(0), One);
2868 return TLO.CombineTo(Op, And1);
2869 }
2870 [[fallthrough]];
2871 case ISD::ADD:
2872 case ISD::SUB: {
2873 // Add, Sub, and Mul don't demand any bits in positions beyond that
2874 // of the highest bit demanded of them.
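// (Carries and borrows only propagate towards the MSB, so e.g. when just
// the low 8 result bits are demanded, bits 8 and above of either operand
// cannot influence them.)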
2875 SDValue Op0 = Op.getOperand(0), Op1 = Op.getOperand(1);
2876 SDNodeFlags Flags = Op.getNode()->getFlags();
2877 unsigned DemandedBitsLZ = DemandedBits.countl_zero();
2878 APInt LoMask = APInt::getLowBitsSet(BitWidth, BitWidth - DemandedBitsLZ);
2879 KnownBits KnownOp0, KnownOp1;
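// For multiplies, trailing zeros in the RHS shift the LHS contribution
// upwards, so the same number of high bits can be dropped from the LHS
// demanded mask (e.g. if the RHS is known to be a multiple of 4, the top
// two bits of the LHS never affect any demanded result bit).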
2880 auto GetDemandedBitsLHSMask = [&](APInt Demanded,
2881 const KnownBits &KnownRHS) {
2882 if (Op.getOpcode() == ISD::MUL)
2883 Demanded.clearHighBits(KnownRHS.countMinTrailingZeros());
2884 return Demanded;
2885 };
2886 if (SimplifyDemandedBits(Op1, LoMask, DemandedElts, KnownOp1, TLO,
2887 Depth + 1) ||
2888 SimplifyDemandedBits(Op0, GetDemandedBitsLHSMask(LoMask, KnownOp1),
2889 DemandedElts, KnownOp0, TLO, Depth + 1) ||
2890 // See if the operation should be performed at a smaller bit width.
2891 ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO)) {
2892 // Disable the nsw and nuw flags. We can no longer guarantee that we
2893 // won't wrap after simplification.
2894 Op->dropFlags(SDNodeFlags::NoWrap);
2895 return true;
2896 }
2897
2898 // neg x with only low bit demanded is simply x.
2899 if (Op.getOpcode() == ISD::SUB && DemandedBits.isOne() &&
2900 isNullConstant(Op0))
2901 return TLO.CombineTo(Op, Op1);
2902
2903 // Attempt to avoid multi-use ops if we don't need anything from them.
2904 if (!LoMask.isAllOnes() || !DemandedElts.isAllOnes()) {
2905 SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits(
2906 Op0, LoMask, DemandedElts, TLO.DAG, Depth + 1);
2907 SDValue DemandedOp1 = SimplifyMultipleUseDemandedBits(
2908 Op1, LoMask, DemandedElts, TLO.DAG, Depth + 1);
2909 if (DemandedOp0 || DemandedOp1) {
2910 Op0 = DemandedOp0 ? DemandedOp0 : Op0;
2911 Op1 = DemandedOp1 ? DemandedOp1 : Op1;
2912 SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), dl, VT, Op0, Op1,
2913 Flags & ~SDNodeFlags::NoWrap);
2914 return TLO.CombineTo(Op, NewOp);
2915 }
2916 }
2917
2918 // If we have a constant operand, we may be able to turn it into -1 if we
2919 // do not demand the high bits. This can make the constant smaller to
2920 // encode, allow more general folding, or match specialized instruction
2921 // patterns (e.g., 'blsr' on x86). Don't bother changing 1 to -1 because that
2922 // is probably not useful (and could be detrimental).
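// For example, with only the low 3 bits demanded, "add X, 7" behaves like
// "add X, -1" (i.e. X - 1), since 7 and -1 agree in those bits.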
2923 ConstantSDNode *C = isConstOrConstSplat(Op1);
2924 APInt HighMask = APInt::getHighBitsSet(BitWidth, DemandedBitsLZ);
2925 if (C && !C->isAllOnes() && !C->isOne() &&
2926 (C->getAPIntValue() | HighMask).isAllOnes()) {
2927 SDValue Neg1 = TLO.DAG.getAllOnesConstant(dl, VT);
2928 // Disable the nsw and nuw flags. We can no longer guarantee that we
2929 // won't wrap after simplification.
2930 SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), dl, VT, Op0, Neg1,
2931 Flags & ~SDNodeFlags::NoWrap);
2932 return TLO.CombineTo(Op, NewOp);
2933 }
2934
2935 // Match a multiply with a disguised negated power of 2 and convert it to
2936 // an equivalent shift-left amount.
2937 // Example: (X * MulC) + Op1 --> Op1 - (X << log2(-MulC))
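// For instance, with MulC == -4: (X * -4) + Op1 --> Op1 - (X << 2).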
2938 auto getShiftLeftAmt = [&HighMask](SDValue Mul) -> unsigned {
2939 if (Mul.getOpcode() != ISD::MUL || !Mul.hasOneUse())
2940 return 0;
2941
2942 // Don't touch opaque constants. Also, ignore zero and power-of-2
2943 // multiplies. Those will get folded later.
2944 ConstantSDNode *MulC = isConstOrConstSplat(Mul.getOperand(1));
2945 if (MulC && !MulC->isOpaque() && !MulC->isZero() &&
2946 !MulC->getAPIntValue().isPowerOf2()) {
2947 APInt UnmaskedC = MulC->getAPIntValue() | HighMask;
2948 if (UnmaskedC.isNegatedPowerOf2())
2949 return (-UnmaskedC).logBase2();
2950 }
2951 return 0;
2952 };
2953
2954 auto foldMul = [&](ISD::NodeType NT, SDValue X, SDValue Y,
2955 unsigned ShlAmt) {
2956 SDValue ShlAmtC = TLO.DAG.getShiftAmountConstant(ShlAmt, VT, dl);
2957 SDValue Shl = TLO.DAG.getNode(ISD::SHL, dl, VT, X, ShlAmtC);
2958 SDValue Res = TLO.DAG.getNode(NT, dl, VT, Y, Shl);
2959 return TLO.CombineTo(Op, Res);
2960 };
2961
2962 if (isOperationLegalOrCustom(ISD::SHL, VT)) {
2963 if (Op.getOpcode() == ISD::ADD) {
2964 // (X * MulC) + Op1 --> Op1 - (X << log2(-MulC))
2965 if (unsigned ShAmt = getShiftLeftAmt(Op0))
2966 return foldMul(ISD::SUB, Op0.getOperand(0), Op1, ShAmt);
2967 // Op0 + (X * MulC) --> Op0 - (X << log2(-MulC))
2968 if (unsigned ShAmt = getShiftLeftAmt(Op1))
2969 return foldMul(ISD::SUB, Op1.getOperand(0), Op0, ShAmt);
2970 }
2971 if (Op.getOpcode() == ISD::SUB) {
2972 // Op0 - (X * MulC) --> Op0 + (X << log2(-MulC))
2973 if (unsigned ShAmt = getShiftLeftAmt(Op1))
2974 return foldMul(ISD::ADD, Op1.getOperand(0), Op0, ShAmt);
2975 }
2976 }
2977
2978 if (Op.getOpcode() == ISD::MUL) {
2979 Known = KnownBits::mul(KnownOp0, KnownOp1);
2980 } else { // Op.getOpcode() is either ISD::ADD or ISD::SUB.
2981 Known = KnownBits::computeForAddSub(
2982 Op.getOpcode() == ISD::ADD, Flags.hasNoSignedWrap(),
2983 Flags.hasNoUnsignedWrap(), KnownOp0, KnownOp1);
2984 }
2985 break;
2986 }
2987 case ISD::FABS: {
2988 SDValue Op0 = Op.getOperand(0);
2989 APInt SignMask = APInt::getSignMask(BitWidth);
2990
2991 if (!DemandedBits.intersects(SignMask))
2992 return TLO.CombineTo(Op, Op0);
2993
2994 if (SimplifyDemandedBits(Op0, DemandedBits, DemandedElts, Known, TLO,
2995 Depth + 1))
2996 return true;
2997
2998 if (Known.isNonNegative())
2999 return TLO.CombineTo(Op, Op0);
3000 if (Known.isNegative())
3001 return TLO.CombineTo(
3002 Op, TLO.DAG.getNode(ISD::FNEG, dl, VT, Op0, Op->getFlags()));
3003
3004 Known.Zero |= SignMask;
3005 Known.One &= ~SignMask;
3006
3007 break;
3008 }
3009 case ISD::FCOPYSIGN: {
3010 SDValue Op0 = Op.getOperand(0);
3011 SDValue Op1 = Op.getOperand(1);
3012
3013 unsigned BitWidth0 = Op0.getScalarValueSizeInBits();
3014 unsigned BitWidth1 = Op1.getScalarValueSizeInBits();
3015 APInt SignMask0 = APInt::getSignMask(BitWidth0);
3016 APInt SignMask1 = APInt::getSignMask(BitWidth1);
3017
3018 if (!DemandedBits.intersects(SignMask0))
3019 return TLO.CombineTo(Op, Op0);
3020
3021 if (SimplifyDemandedBits(Op0, ~SignMask0 & DemandedBits, DemandedElts,
3022 Known, TLO, Depth + 1) ||
3023 SimplifyDemandedBits(Op1, SignMask1, DemandedElts, Known2, TLO,
3024 Depth + 1))
3025 return true;
3026
3027 if (Known2.isNonNegative())
3028 return TLO.CombineTo(
3029 Op, TLO.DAG.getNode(ISD::FABS, dl, VT, Op0, Op->getFlags()));
3030
3031 if (Known2.isNegative())
3032 return TLO.CombineTo(
3033 Op, TLO.DAG.getNode(ISD::FNEG, dl, VT,
3034 TLO.DAG.getNode(ISD::FABS, SDLoc(Op0), VT, Op0)));
3035
3036 Known.Zero &= ~SignMask0;
3037 Known.One &= ~SignMask0;
3038 break;
3039 }
3040 case ISD::FNEG: {
3041 SDValue Op0 = Op.getOperand(0);
3042 APInt SignMask = APInt::getSignMask(BitWidth);
3043
3044 if (!DemandedBits.intersects(SignMask))
3045 return TLO.CombineTo(Op, Op0);
3046
3047 if (SimplifyDemandedBits(Op0, DemandedBits, DemandedElts, Known, TLO,
3048 Depth + 1))
3049 return true;
3050
3051 if (!Known.isSignUnknown()) {
3052 Known.Zero ^= SignMask;
3053 Known.One ^= SignMask;
3054 }
3055
3056 break;
3057 }
3058 default:
3059 // We also ask the target about intrinsics (which could be specific to it).
3060 if (Op.getOpcode() >= ISD::BUILTIN_OP_END ||
3061 Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN) {
3062 // TODO: Probably okay to remove after audit; here to reduce change size
3063 // in initial enablement patch for scalable vectors
3064 if (Op.getValueType().isScalableVector())
3065 break;
3066 if (SimplifyDemandedBitsForTargetNode(Op, DemandedBits, DemandedElts,
3067 Known, TLO, Depth))
3068 return true;
3069 break;
3070 }
3071
3072 // Just use computeKnownBits to compute output bits.
3073 Known = TLO.DAG.computeKnownBits(Op, DemandedElts, Depth);
3074 break;
3075 }
3076
3077 // If we know the value of all of the demanded bits, return this as a
3078 // constant.
3079 if (!isTargetCanonicalConstantNode(Op) &&
3080 DemandedBits.isSubsetOf(Known.Zero | Known.One)) {
3081 // Avoid folding to a constant if any OpaqueConstant is involved.
3082 if (llvm::any_of(Op->ops(), [](SDValue V) {
3083 auto *C = dyn_cast<ConstantSDNode>(V);
3084 return C && C->isOpaque();
3085 }))
3086 return false;
3087 if (VT.isInteger())
3088 return TLO.CombineTo(Op, TLO.DAG.getConstant(Known.One, dl, VT));
3089 if (VT.isFloatingPoint())
3090 return TLO.CombineTo(
3091 Op, TLO.DAG.getConstantFP(APFloat(VT.getFltSemantics(), Known.One),
3092 dl, VT));
3093 }
3094
3095 // A multi-use 'all demanded elts' simplification failed to find any known bits.
3096 // Try again using just the original demanded elts.
3097 // Ensure we do this AFTER the constant folding above.
3098 if (HasMultiUse && Known.isUnknown() && !OriginalDemandedElts.isAllOnes())
3099 Known = TLO.DAG.computeKnownBits(Op, OriginalDemandedElts, Depth);
3100
3101 return false;
3102 }
3103
3104 bool TargetLowering::SimplifyDemandedVectorElts(SDValue Op,
3105 const APInt &DemandedElts,
3106 DAGCombinerInfo &DCI) const {
3107 SelectionDAG &DAG = DCI.DAG;
3108 TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(),
3109 !DCI.isBeforeLegalizeOps());
3110
3111 APInt KnownUndef, KnownZero;
3112 bool Simplified =
3113 SimplifyDemandedVectorElts(Op, DemandedElts, KnownUndef, KnownZero, TLO);
3114 if (Simplified) {
3115 DCI.AddToWorklist(Op.getNode());
3116 DCI.CommitTargetLoweringOpt(TLO);
3117 }
3118
3119 return Simplified;
3120 }
3121
3122 /// Given a vector binary operation and known undefined elements for each input
3123 /// operand, compute whether each element of the output is undefined.
3124 static APInt getKnownUndefForVectorBinop(SDValue BO, SelectionDAG &DAG,
3125 const APInt &UndefOp0,
3126 const APInt &UndefOp1) {
3127 EVT VT = BO.getValueType();
3128 assert(DAG.getTargetLoweringInfo().isBinOp(BO.getOpcode()) && VT.isVector() &&
3129 "Vector binop only");
3130
3131 EVT EltVT = VT.getVectorElementType();
3132 unsigned NumElts = VT.isFixedLengthVector() ? VT.getVectorNumElements() : 1;
3133 assert(UndefOp0.getBitWidth() == NumElts &&
3134 UndefOp1.getBitWidth() == NumElts && "Bad type for undef analysis");
3135
3136 auto getUndefOrConstantElt = [&](SDValue V, unsigned Index,
3137 const APInt &UndefVals) {
3138 if (UndefVals[Index])
3139 return DAG.getUNDEF(EltVT);
3140
3141 if (auto *BV = dyn_cast<BuildVectorSDNode>(V)) {
3142 // Try hard to make sure that the getNode() call is not creating temporary
3143 // nodes. Ignore opaque integers because they do not constant fold.
3144 SDValue Elt = BV->getOperand(Index);
3145 auto *C = dyn_cast<ConstantSDNode>(Elt);
3146 if (isa<ConstantFPSDNode>(Elt) || Elt.isUndef() || (C && !C->isOpaque()))
3147 return Elt;
3148 }
3149
3150 return SDValue();
3151 };
3152
3153 APInt KnownUndef = APInt::getZero(NumElts);
3154 for (unsigned i = 0; i != NumElts; ++i) {
3155 // If both inputs for this element are either constant or undef and match
3156 // the element type, compute the constant/undef result for this element of
3157 // the vector.
3158 // TODO: Ideally we would use FoldConstantArithmetic() here, but that does
3159 // not handle FP constants. The code within getNode() should be refactored
3160 // to avoid the danger of creating a bogus temporary node here.
3161 SDValue C0 = getUndefOrConstantElt(BO.getOperand(0), i, UndefOp0);
3162 SDValue C1 = getUndefOrConstantElt(BO.getOperand(1), i, UndefOp1);
3163 if (C0 && C1 && C0.getValueType() == EltVT && C1.getValueType() == EltVT)
3164 if (DAG.getNode(BO.getOpcode(), SDLoc(BO), EltVT, C0, C1).isUndef())
3165 KnownUndef.setBit(i);
3166 }
3167 return KnownUndef;
3168 }
3169
3170 bool TargetLowering::SimplifyDemandedVectorElts(
3171 SDValue Op, const APInt &OriginalDemandedElts, APInt &KnownUndef,
3172 APInt &KnownZero, TargetLoweringOpt &TLO, unsigned Depth,
3173 bool AssumeSingleUse) const {
3174 EVT VT = Op.getValueType();
3175 unsigned Opcode = Op.getOpcode();
3176 APInt DemandedElts = OriginalDemandedElts;
3177 unsigned NumElts = DemandedElts.getBitWidth();
3178 assert(VT.isVector() && "Expected vector op");
3179
3180 KnownUndef = KnownZero = APInt::getZero(NumElts);
3181
3182 if (!shouldSimplifyDemandedVectorElts(Op, TLO))
3183 return false;
3184
3185 // TODO: For now we assume we know nothing about scalable vectors.
3186 if (VT.isScalableVector())
3187 return false;
3188
3189 assert(VT.getVectorNumElements() == NumElts &&
3190 "Mask size mismatches value type element count!");
3191
3192 // Undef operand.
3193 if (Op.isUndef()) {
3194 KnownUndef.setAllBits();
3195 return false;
3196 }
3197
3198 // If Op has other users, assume that all elements are needed.
3199 if (!AssumeSingleUse && !Op.getNode()->hasOneUse())
3200 DemandedElts.setAllBits();
3201
3202 // Not demanding any elements from Op.
3203 if (DemandedElts == 0) {
3204 KnownUndef.setAllBits();
3205 return TLO.CombineTo(Op, TLO.DAG.getUNDEF(VT));
3206 }
3207
3208 // Limit search depth.
3209 if (Depth >= SelectionDAG::MaxRecursionDepth)
3210 return false;
3211
3212 SDLoc DL(Op);
3213 unsigned EltSizeInBits = VT.getScalarSizeInBits();
3214 bool IsLE = TLO.DAG.getDataLayout().isLittleEndian();
3215
3216 // Helper for demanding the specified elements and all the bits of both binary
3217 // operands.
3218 auto SimplifyDemandedVectorEltsBinOp = [&](SDValue Op0, SDValue Op1) {
3219 SDValue NewOp0 = SimplifyMultipleUseDemandedVectorElts(Op0, DemandedElts,
3220 TLO.DAG, Depth + 1);
3221 SDValue NewOp1 = SimplifyMultipleUseDemandedVectorElts(Op1, DemandedElts,
3222 TLO.DAG, Depth + 1);
3223 if (NewOp0 || NewOp1) {
3224 SDValue NewOp =
3225 TLO.DAG.getNode(Opcode, SDLoc(Op), VT, NewOp0 ? NewOp0 : Op0,
3226 NewOp1 ? NewOp1 : Op1, Op->getFlags());
3227 return TLO.CombineTo(Op, NewOp);
3228 }
3229 return false;
3230 };
3231
3232 switch (Opcode) {
3233 case ISD::SCALAR_TO_VECTOR: {
3234 if (!DemandedElts[0]) {
3235 KnownUndef.setAllBits();
3236 return TLO.CombineTo(Op, TLO.DAG.getUNDEF(VT));
3237 }
3238 SDValue ScalarSrc = Op.getOperand(0);
3239 if (ScalarSrc.getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
3240 SDValue Src = ScalarSrc.getOperand(0);
3241 SDValue Idx = ScalarSrc.getOperand(1);
3242 EVT SrcVT = Src.getValueType();
3243
3244 ElementCount SrcEltCnt = SrcVT.getVectorElementCount();
3245
3246 if (SrcEltCnt.isScalable())
3247 return false;
3248
3249 unsigned NumSrcElts = SrcEltCnt.getFixedValue();
3250 if (isNullConstant(Idx)) {
3251 APInt SrcDemandedElts = APInt::getOneBitSet(NumSrcElts, 0);
3252 APInt SrcUndef = KnownUndef.zextOrTrunc(NumSrcElts);
3253 APInt SrcZero = KnownZero.zextOrTrunc(NumSrcElts);
3254 if (SimplifyDemandedVectorElts(Src, SrcDemandedElts, SrcUndef, SrcZero,
3255 TLO, Depth + 1))
3256 return true;
3257 }
3258 }
3259 KnownUndef.setHighBits(NumElts - 1);
3260 break;
3261 }
3262 case ISD::BITCAST: {
3263 SDValue Src = Op.getOperand(0);
3264 EVT SrcVT = Src.getValueType();
3265
3266 if (!SrcVT.isVector()) {
3267 // TODO - bigendian once we have test coverage.
3268 if (IsLE) {
3269 APInt DemandedSrcBits = APInt::getZero(SrcVT.getSizeInBits());
3270 unsigned EltSize = VT.getScalarSizeInBits();
3271 for (unsigned I = 0; I != NumElts; ++I) {
3272 if (DemandedElts[I]) {
3273 unsigned Offset = I * EltSize;
3274 DemandedSrcBits.setBits(Offset, Offset + EltSize);
3275 }
3276 }
3277 KnownBits Known;
3278 if (SimplifyDemandedBits(Src, DemandedSrcBits, Known, TLO, Depth + 1))
3279 return true;
3280 }
3281 break;
3282 }
3283
3284 // Fast handling of 'identity' bitcasts.
3285 unsigned NumSrcElts = SrcVT.getVectorNumElements();
3286 if (NumSrcElts == NumElts)
3287 return SimplifyDemandedVectorElts(Src, DemandedElts, KnownUndef,
3288 KnownZero, TLO, Depth + 1);
3289
3290 APInt SrcDemandedElts, SrcZero, SrcUndef;
3291
3292 // Bitcast from a 'large element' src vector to a 'small element' vector: we
3293 // must demand a source element if any DemandedElt maps to it.
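// For example, for a v2i64 -> v4i32 bitcast, demanding either of result
// elements 2 or 3 demands source element 1.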
3294 if ((NumElts % NumSrcElts) == 0) {
3295 unsigned Scale = NumElts / NumSrcElts;
3296 SrcDemandedElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
3297 if (SimplifyDemandedVectorElts(Src, SrcDemandedElts, SrcUndef, SrcZero,
3298 TLO, Depth + 1))
3299 return true;
3300
3301 // Try calling SimplifyDemandedBits, converting demanded elts to the bits
3302 // of the large element.
3303 // TODO - bigendian once we have test coverage.
3304 if (IsLE) {
3305 unsigned SrcEltSizeInBits = SrcVT.getScalarSizeInBits();
3306 APInt SrcDemandedBits = APInt::getZero(SrcEltSizeInBits);
3307 for (unsigned i = 0; i != NumElts; ++i)
3308 if (DemandedElts[i]) {
3309 unsigned Ofs = (i % Scale) * EltSizeInBits;
3310 SrcDemandedBits.setBits(Ofs, Ofs + EltSizeInBits);
3311 }
3312
3313 KnownBits Known;
3314 if (SimplifyDemandedBits(Src, SrcDemandedBits, SrcDemandedElts, Known,
3315 TLO, Depth + 1))
3316 return true;
3317
3318 // The bitcast has split each wide element into a number of
3319 // narrow subelements. We have just computed the Known bits
3320 // for wide elements. See if element splitting results in
3321 // some subelements being zero. Only for demanded elements!
3322 for (unsigned SubElt = 0; SubElt != Scale; ++SubElt) {
3323 if (!Known.Zero.extractBits(EltSizeInBits, SubElt * EltSizeInBits)
3324 .isAllOnes())
3325 continue;
3326 for (unsigned SrcElt = 0; SrcElt != NumSrcElts; ++SrcElt) {
3327 unsigned Elt = Scale * SrcElt + SubElt;
3328 if (DemandedElts[Elt])
3329 KnownZero.setBit(Elt);
3330 }
3331 }
3332 }
3333
3334 // If a src element is zero/undef then all the output elements carved from it
3335 // will be as well - but only demanded elements are guaranteed to be correct.
3336 for (unsigned i = 0; i != NumSrcElts; ++i) {
3337 if (SrcDemandedElts[i]) {
3338 if (SrcZero[i])
3339 KnownZero.setBits(i * Scale, (i + 1) * Scale);
3340 if (SrcUndef[i])
3341 KnownUndef.setBits(i * Scale, (i + 1) * Scale);
3342 }
3343 }
3344 }
3345
3346 // Bitcast from a 'small element' src vector to a 'large element' vector: we
3347 // demand all the smaller source elements covered by the larger demanded
3348 // element of this vector.
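// For example, for a v4i32 -> v2i64 bitcast, demanding result element 1
// demands source elements 2 and 3.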
3349 if ((NumSrcElts % NumElts) == 0) {
3350 unsigned Scale = NumSrcElts / NumElts;
3351 SrcDemandedElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
3352 if (SimplifyDemandedVectorElts(Src, SrcDemandedElts, SrcUndef, SrcZero,
3353 TLO, Depth + 1))
3354 return true;
3355
3356 // If all the src elements covering an output element are zero/undef, then
3357 // the output element will be as well, assuming it was demanded.
3358 for (unsigned i = 0; i != NumElts; ++i) {
3359 if (DemandedElts[i]) {
3360 if (SrcZero.extractBits(Scale, i * Scale).isAllOnes())
3361 KnownZero.setBit(i);
3362 if (SrcUndef.extractBits(Scale, i * Scale).isAllOnes())
3363 KnownUndef.setBit(i);
3364 }
3365 }
3366 }
3367 break;
3368 }
3369 case ISD::FREEZE: {
3370 SDValue N0 = Op.getOperand(0);
3371 if (TLO.DAG.isGuaranteedNotToBeUndefOrPoison(N0, DemandedElts,
3372 /*PoisonOnly=*/false))
3373 return TLO.CombineTo(Op, N0);
3374
3375 // TODO: Replace this with the general fold from DAGCombiner::visitFREEZE
3376 // freeze(op(x, ...)) -> op(freeze(x), ...).
3377 if (N0.getOpcode() == ISD::SCALAR_TO_VECTOR && DemandedElts == 1)
3378 return TLO.CombineTo(
3379 Op, TLO.DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VT,
3380 TLO.DAG.getFreeze(N0.getOperand(0))));
3381 break;
3382 }
3383 case ISD::BUILD_VECTOR: {
3384 // Check all elements and simplify any unused elements with UNDEF.
3385 if (!DemandedElts.isAllOnes()) {
3386 // Don't simplify BROADCASTS.
3387 if (llvm::any_of(Op->op_values(),
3388 [&](SDValue Elt) { return Op.getOperand(0) != Elt; })) {
3389 SmallVector<SDValue, 32> Ops(Op->ops());
3390 bool Updated = false;
3391 for (unsigned i = 0; i != NumElts; ++i) {
3392 if (!DemandedElts[i] && !Ops[i].isUndef()) {
3393 Ops[i] = TLO.DAG.getUNDEF(Ops[0].getValueType());
3394 KnownUndef.setBit(i);
3395 Updated = true;
3396 }
3397 }
3398 if (Updated)
3399 return TLO.CombineTo(Op, TLO.DAG.getBuildVector(VT, DL, Ops));
3400 }
3401 }
3402 for (unsigned i = 0; i != NumElts; ++i) {
3403 SDValue SrcOp = Op.getOperand(i);
3404 if (SrcOp.isUndef()) {
3405 KnownUndef.setBit(i);
3406 } else if (EltSizeInBits == SrcOp.getScalarValueSizeInBits() &&
3407 (isNullConstant(SrcOp) || isNullFPConstant(SrcOp))) {
3408 KnownZero.setBit(i);
3409 }
3410 }
3411 break;
3412 }
3413 case ISD::CONCAT_VECTORS: {
3414 EVT SubVT = Op.getOperand(0).getValueType();
3415 unsigned NumSubVecs = Op.getNumOperands();
3416 unsigned NumSubElts = SubVT.getVectorNumElements();
3417 for (unsigned i = 0; i != NumSubVecs; ++i) {
3418 SDValue SubOp = Op.getOperand(i);
3419 APInt SubElts = DemandedElts.extractBits(NumSubElts, i * NumSubElts);
3420 APInt SubUndef, SubZero;
3421 if (SimplifyDemandedVectorElts(SubOp, SubElts, SubUndef, SubZero, TLO,
3422 Depth + 1))
3423 return true;
3424 KnownUndef.insertBits(SubUndef, i * NumSubElts);
3425 KnownZero.insertBits(SubZero, i * NumSubElts);
3426 }
3427
3428 // Attempt to avoid multi-use ops if we don't need anything from them.
3429 if (!DemandedElts.isAllOnes()) {
3430 bool FoundNewSub = false;
3431 SmallVector<SDValue, 2> DemandedSubOps;
3432 for (unsigned i = 0; i != NumSubVecs; ++i) {
3433 SDValue SubOp = Op.getOperand(i);
3434 APInt SubElts = DemandedElts.extractBits(NumSubElts, i * NumSubElts);
3435 SDValue NewSubOp = SimplifyMultipleUseDemandedVectorElts(
3436 SubOp, SubElts, TLO.DAG, Depth + 1);
3437 DemandedSubOps.push_back(NewSubOp ? NewSubOp : SubOp);
3438 FoundNewSub = NewSubOp ? true : FoundNewSub;
3439 }
3440 if (FoundNewSub) {
3441 SDValue NewOp =
3442 TLO.DAG.getNode(Op.getOpcode(), SDLoc(Op), VT, DemandedSubOps);
3443 return TLO.CombineTo(Op, NewOp);
3444 }
3445 }
3446 break;
3447 }
3448 case ISD::INSERT_SUBVECTOR: {
3449 // Demand any elements from the subvector, and the remainder from the src it
3450 // is inserted into.
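// For example, inserting a v2 subvector at index 2 into a v8 vector: result
// elements 2 and 3 are demanded from Sub, and the remaining demanded
// elements come from Src.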
3451 SDValue Src = Op.getOperand(0);
3452 SDValue Sub = Op.getOperand(1);
3453 uint64_t Idx = Op.getConstantOperandVal(2);
3454 unsigned NumSubElts = Sub.getValueType().getVectorNumElements();
3455 APInt DemandedSubElts = DemandedElts.extractBits(NumSubElts, Idx);
3456 APInt DemandedSrcElts = DemandedElts;
3457 DemandedSrcElts.clearBits(Idx, Idx + NumSubElts);
3458
3459 APInt SubUndef, SubZero;
3460 if (SimplifyDemandedVectorElts(Sub, DemandedSubElts, SubUndef, SubZero, TLO,
3461 Depth + 1))
3462 return true;
3463
3464 // If none of the src operand elements are demanded, replace it with undef.
3465 if (!DemandedSrcElts && !Src.isUndef())
3466 return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT,
3467 TLO.DAG.getUNDEF(VT), Sub,
3468 Op.getOperand(2)));
3469
3470 if (SimplifyDemandedVectorElts(Src, DemandedSrcElts, KnownUndef, KnownZero,
3471 TLO, Depth + 1))
3472 return true;
3473 KnownUndef.insertBits(SubUndef, Idx);
3474 KnownZero.insertBits(SubZero, Idx);
3475
3476 // Attempt to avoid multi-use ops if we don't need anything from them.
3477 if (!DemandedSrcElts.isAllOnes() || !DemandedSubElts.isAllOnes()) {
3478 SDValue NewSrc = SimplifyMultipleUseDemandedVectorElts(
3479 Src, DemandedSrcElts, TLO.DAG, Depth + 1);
3480 SDValue NewSub = SimplifyMultipleUseDemandedVectorElts(
3481 Sub, DemandedSubElts, TLO.DAG, Depth + 1);
3482 if (NewSrc || NewSub) {
3483 NewSrc = NewSrc ? NewSrc : Src;
3484 NewSub = NewSub ? NewSub : Sub;
3485 SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), SDLoc(Op), VT, NewSrc,
3486 NewSub, Op.getOperand(2));
3487 return TLO.CombineTo(Op, NewOp);
3488 }
3489 }
3490 break;
3491 }
3492 case ISD::EXTRACT_SUBVECTOR: {
3493 // Offset the demanded elts by the subvector index.
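// For example, extracting a v2 subvector at index 2 from a v8 source:
// demanding result element 1 demands source element 3.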
3494 SDValue Src = Op.getOperand(0);
3495 if (Src.getValueType().isScalableVector())
3496 break;
3497 uint64_t Idx = Op.getConstantOperandVal(1);
3498 unsigned NumSrcElts = Src.getValueType().getVectorNumElements();
3499 APInt DemandedSrcElts = DemandedElts.zext(NumSrcElts).shl(Idx);
3500
3501 APInt SrcUndef, SrcZero;
3502 if (SimplifyDemandedVectorElts(Src, DemandedSrcElts, SrcUndef, SrcZero, TLO,
3503 Depth + 1))
3504 return true;
3505 KnownUndef = SrcUndef.extractBits(NumElts, Idx);
3506 KnownZero = SrcZero.extractBits(NumElts, Idx);
3507
3508 // Attempt to avoid multi-use ops if we don't need anything from them.
3509 if (!DemandedElts.isAllOnes()) {
3510 SDValue NewSrc = SimplifyMultipleUseDemandedVectorElts(
3511 Src, DemandedSrcElts, TLO.DAG, Depth + 1);
3512 if (NewSrc) {
3513 SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), SDLoc(Op), VT, NewSrc,
3514 Op.getOperand(1));
3515 return TLO.CombineTo(Op, NewOp);
3516 }
3517 }
3518 break;
3519 }
3520 case ISD::INSERT_VECTOR_ELT: {
3521 SDValue Vec = Op.getOperand(0);
3522 SDValue Scl = Op.getOperand(1);
3523 auto *CIdx = dyn_cast<ConstantSDNode>(Op.getOperand(2));
3524
3525 // For a legal, constant insertion index, if we don't need this insertion
3526 // then strip it, else remove it from the demanded elts.
3527 if (CIdx && CIdx->getAPIntValue().ult(NumElts)) {
3528 unsigned Idx = CIdx->getZExtValue();
3529 if (!DemandedElts[Idx])
3530 return TLO.CombineTo(Op, Vec);
3531
3532 APInt DemandedVecElts(DemandedElts);
3533 DemandedVecElts.clearBit(Idx);
3534 if (SimplifyDemandedVectorElts(Vec, DemandedVecElts, KnownUndef,
3535 KnownZero, TLO, Depth + 1))
3536 return true;
3537
3538 KnownUndef.setBitVal(Idx, Scl.isUndef());
3539
3540 KnownZero.setBitVal(Idx, isNullConstant(Scl) || isNullFPConstant(Scl));
3541 break;
3542 }
3543
3544 APInt VecUndef, VecZero;
3545 if (SimplifyDemandedVectorElts(Vec, DemandedElts, VecUndef, VecZero, TLO,
3546 Depth + 1))
3547 return true;
3548 // Without knowing the insertion index we can't set KnownUndef/KnownZero.
3549 break;
3550 }
3551 case ISD::VSELECT: {
3552 SDValue Sel = Op.getOperand(0);
3553 SDValue LHS = Op.getOperand(1);
3554 SDValue RHS = Op.getOperand(2);
3555
3556 // Try to transform the select condition based on the current demanded
3557 // elements.
3558 APInt UndefSel, ZeroSel;
3559 if (SimplifyDemandedVectorElts(Sel, DemandedElts, UndefSel, ZeroSel, TLO,
3560 Depth + 1))
3561 return true;
3562
3563 // See if we can simplify either vselect operand.
3564 APInt DemandedLHS(DemandedElts);
3565 APInt DemandedRHS(DemandedElts);
3566 APInt UndefLHS, ZeroLHS;
3567 APInt UndefRHS, ZeroRHS;
3568 if (SimplifyDemandedVectorElts(LHS, DemandedLHS, UndefLHS, ZeroLHS, TLO,
3569 Depth + 1))
3570 return true;
3571 if (SimplifyDemandedVectorElts(RHS, DemandedRHS, UndefRHS, ZeroRHS, TLO,
3572 Depth + 1))
3573 return true;
3574
3575 KnownUndef = UndefLHS & UndefRHS;
3576 KnownZero = ZeroLHS & ZeroRHS;
3577
3578 // If we know that the selected element is always zero, we don't need the
3579 // select value element.
3580 APInt DemandedSel = DemandedElts & ~KnownZero;
3581 if (DemandedSel != DemandedElts)
3582 if (SimplifyDemandedVectorElts(Sel, DemandedSel, UndefSel, ZeroSel, TLO,
3583 Depth + 1))
3584 return true;
3585
3586 break;
3587 }
3588 case ISD::VECTOR_SHUFFLE: {
3589 SDValue LHS = Op.getOperand(0);
3590 SDValue RHS = Op.getOperand(1);
3591 ArrayRef<int> ShuffleMask = cast<ShuffleVectorSDNode>(Op)->getMask();
3592
3593 // Collect demanded elements from the shuffle operands.
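// For example, with mask <0,4,1,5> on two v4 operands, demanding result
// elements {0,1} demands LHS element 0 and RHS element 0.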
3594 APInt DemandedLHS(NumElts, 0);
3595 APInt DemandedRHS(NumElts, 0);
3596 for (unsigned i = 0; i != NumElts; ++i) {
3597 int M = ShuffleMask[i];
3598 if (M < 0 || !DemandedElts[i])
3599 continue;
3600 assert(0 <= M && M < (int)(2 * NumElts) && "Shuffle index out of range");
3601 if (M < (int)NumElts)
3602 DemandedLHS.setBit(M);
3603 else
3604 DemandedRHS.setBit(M - NumElts);
3605 }
3606
3607 // If either side isn't demanded, replace it by UNDEF. We handle this
3608 // explicitly here so that it also simplifies in the multiple-use case (in
3609 // contrast to the SimplifyDemandedVectorElts calls below).
3610 bool FoldLHS = !DemandedLHS && !LHS.isUndef();
3611 bool FoldRHS = !DemandedRHS && !RHS.isUndef();
3612 if (FoldLHS || FoldRHS) {
3613 LHS = FoldLHS ? TLO.DAG.getUNDEF(LHS.getValueType()) : LHS;
3614 RHS = FoldRHS ? TLO.DAG.getUNDEF(RHS.getValueType()) : RHS;
3615 SDValue NewOp =
3616 TLO.DAG.getVectorShuffle(VT, SDLoc(Op), LHS, RHS, ShuffleMask);
3617 return TLO.CombineTo(Op, NewOp);
3618 }
3619
3620 // See if we can simplify either shuffle operand.
3621 APInt UndefLHS, ZeroLHS;
3622 APInt UndefRHS, ZeroRHS;
3623 if (SimplifyDemandedVectorElts(LHS, DemandedLHS, UndefLHS, ZeroLHS, TLO,
3624 Depth + 1))
3625 return true;
3626 if (SimplifyDemandedVectorElts(RHS, DemandedRHS, UndefRHS, ZeroRHS, TLO,
3627 Depth + 1))
3628 return true;
3629
3630 // Simplify mask using undef elements from LHS/RHS.
3631 bool Updated = false;
3632 bool IdentityLHS = true, IdentityRHS = true;
3633 SmallVector<int, 32> NewMask(ShuffleMask);
3634 for (unsigned i = 0; i != NumElts; ++i) {
3635 int &M = NewMask[i];
3636 if (M < 0)
3637 continue;
3638 if (!DemandedElts[i] || (M < (int)NumElts && UndefLHS[M]) ||
3639 (M >= (int)NumElts && UndefRHS[M - NumElts])) {
3640 Updated = true;
3641 M = -1;
3642 }
3643 IdentityLHS &= (M < 0) || (M == (int)i);
3644 IdentityRHS &= (M < 0) || ((M - NumElts) == i);
3645 }
3646
3647 // Update legal shuffle masks based on demanded elements if it won't reduce
3648 // to Identity which can cause premature removal of the shuffle mask.
3649 if (Updated && !IdentityLHS && !IdentityRHS && !TLO.LegalOps) {
3650 SDValue LegalShuffle =
3651 buildLegalVectorShuffle(VT, DL, LHS, RHS, NewMask, TLO.DAG);
3652 if (LegalShuffle)
3653 return TLO.CombineTo(Op, LegalShuffle);
3654 }
3655
3656 // Propagate undef/zero elements from LHS/RHS.
3657 for (unsigned i = 0; i != NumElts; ++i) {
3658 int M = ShuffleMask[i];
3659 if (M < 0) {
3660 KnownUndef.setBit(i);
3661 } else if (M < (int)NumElts) {
3662 if (UndefLHS[M])
3663 KnownUndef.setBit(i);
3664 if (ZeroLHS[M])
3665 KnownZero.setBit(i);
3666 } else {
3667 if (UndefRHS[M - NumElts])
3668 KnownUndef.setBit(i);
3669 if (ZeroRHS[M - NumElts])
3670 KnownZero.setBit(i);
3671 }
3672 }
3673 break;
3674 }
3675 case ISD::ANY_EXTEND_VECTOR_INREG:
3676 case ISD::SIGN_EXTEND_VECTOR_INREG:
3677 case ISD::ZERO_EXTEND_VECTOR_INREG: {
3678 APInt SrcUndef, SrcZero;
3679 SDValue Src = Op.getOperand(0);
3680 unsigned NumSrcElts = Src.getValueType().getVectorNumElements();
3681 APInt DemandedSrcElts = DemandedElts.zext(NumSrcElts);
3682 if (SimplifyDemandedVectorElts(Src, DemandedSrcElts, SrcUndef, SrcZero, TLO,
3683 Depth + 1))
3684 return true;
3685 KnownZero = SrcZero.zextOrTrunc(NumElts);
3686 KnownUndef = SrcUndef.zextOrTrunc(NumElts);
3687
3688 if (IsLE && Op.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG &&
3689 Op.getValueSizeInBits() == Src.getValueSizeInBits() &&
3690 DemandedSrcElts == 1) {
3691 // aext - if we just need the bottom element then we can bitcast.
3692 return TLO.CombineTo(Op, TLO.DAG.getBitcast(VT, Src));
3693 }
3694
3695 if (Op.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG) {
3696 // zext(undef) upper bits are guaranteed to be zero.
3697 if (DemandedElts.isSubsetOf(KnownUndef))
3698 return TLO.CombineTo(Op, TLO.DAG.getConstant(0, SDLoc(Op), VT));
3699 KnownUndef.clearAllBits();
3700
3701 // zext - if we just need the bottom element then we can mask:
3702 // zext(and(x,c)) -> and(x,c') iff the zext is the only user of the and.
3703 if (IsLE && DemandedSrcElts == 1 && Src.getOpcode() == ISD::AND &&
3704 Op->isOnlyUserOf(Src.getNode()) &&
3705 Op.getValueSizeInBits() == Src.getValueSizeInBits()) {
3706 SDLoc DL(Op);
3707 EVT SrcVT = Src.getValueType();
3708 EVT SrcSVT = SrcVT.getScalarType();
3709 SmallVector<SDValue> MaskElts;
3710 MaskElts.push_back(TLO.DAG.getAllOnesConstant(DL, SrcSVT));
3711 MaskElts.append(NumSrcElts - 1, TLO.DAG.getConstant(0, DL, SrcSVT));
3712 SDValue Mask = TLO.DAG.getBuildVector(SrcVT, DL, MaskElts);
3713 if (SDValue Fold = TLO.DAG.FoldConstantArithmetic(
3714 ISD::AND, DL, SrcVT, {Src.getOperand(1), Mask})) {
3715 Fold = TLO.DAG.getNode(ISD::AND, DL, SrcVT, Src.getOperand(0), Fold);
3716 return TLO.CombineTo(Op, TLO.DAG.getBitcast(VT, Fold));
3717 }
3718 }
3719 }
3720 break;
3721 }
3722
3723 // TODO: There are more binop opcodes that could be handled here - MIN,
3724 // MAX, saturated math, etc.
3725 case ISD::ADD: {
3726 SDValue Op0 = Op.getOperand(0);
3727 SDValue Op1 = Op.getOperand(1);
3728 if (Op0 == Op1 && Op->isOnlyUserOf(Op0.getNode())) {
3729 APInt UndefLHS, ZeroLHS;
3730 if (SimplifyDemandedVectorElts(Op0, DemandedElts, UndefLHS, ZeroLHS, TLO,
3731 Depth + 1, /*AssumeSingleUse*/ true))
3732 return true;
3733 }
3734 [[fallthrough]];
3735 }
3736 case ISD::AVGCEILS:
3737 case ISD::AVGCEILU:
3738 case ISD::AVGFLOORS:
3739 case ISD::AVGFLOORU:
3740 case ISD::OR:
3741 case ISD::XOR:
3742 case ISD::SUB:
3743 case ISD::FADD:
3744 case ISD::FSUB:
3745 case ISD::FMUL:
3746 case ISD::FDIV:
3747 case ISD::FREM: {
3748 SDValue Op0 = Op.getOperand(0);
3749 SDValue Op1 = Op.getOperand(1);
3750
3751 APInt UndefRHS, ZeroRHS;
3752 if (SimplifyDemandedVectorElts(Op1, DemandedElts, UndefRHS, ZeroRHS, TLO,
3753 Depth + 1))
3754 return true;
3755 APInt UndefLHS, ZeroLHS;
3756 if (SimplifyDemandedVectorElts(Op0, DemandedElts, UndefLHS, ZeroLHS, TLO,
3757 Depth + 1))
3758 return true;
3759
3760 KnownZero = ZeroLHS & ZeroRHS;
3761 KnownUndef = getKnownUndefForVectorBinop(Op, TLO.DAG, UndefLHS, UndefRHS);
3762
3763 // Attempt to avoid multi-use ops if we don't need anything from them.
3764 // TODO - use KnownUndef to relax the demandedelts?
3765 if (!DemandedElts.isAllOnes())
3766 if (SimplifyDemandedVectorEltsBinOp(Op0, Op1))
3767 return true;
3768 break;
3769 }
3770 case ISD::SHL:
3771 case ISD::SRL:
3772 case ISD::SRA:
3773 case ISD::ROTL:
3774 case ISD::ROTR: {
3775 SDValue Op0 = Op.getOperand(0);
3776 SDValue Op1 = Op.getOperand(1);
3777
3778 APInt UndefRHS, ZeroRHS;
3779 if (SimplifyDemandedVectorElts(Op1, DemandedElts, UndefRHS, ZeroRHS, TLO,
3780 Depth + 1))
3781 return true;
3782 APInt UndefLHS, ZeroLHS;
3783 if (SimplifyDemandedVectorElts(Op0, DemandedElts, UndefLHS, ZeroLHS, TLO,
3784 Depth + 1))
3785 return true;
3786
3787 KnownZero = ZeroLHS;
3788 KnownUndef = UndefLHS & UndefRHS; // TODO: use getKnownUndefForVectorBinop?
3789
3790 // Attempt to avoid multi-use ops if we don't need anything from them.
3791 // TODO - use KnownUndef to relax the demandedelts?
3792 if (!DemandedElts.isAllOnes())
3793 if (SimplifyDemandedVectorEltsBinOp(Op0, Op1))
3794 return true;
3795 break;
3796 }
3797 case ISD::MUL:
3798 case ISD::MULHU:
3799 case ISD::MULHS:
3800 case ISD::AND: {
3801 SDValue Op0 = Op.getOperand(0);
3802 SDValue Op1 = Op.getOperand(1);
3803
3804 APInt SrcUndef, SrcZero;
3805 if (SimplifyDemandedVectorElts(Op1, DemandedElts, SrcUndef, SrcZero, TLO,
3806 Depth + 1))
3807 return true;
3808 // If we know that a demanded element was zero in Op1 we don't need to
3809 // demand it in Op0 - it's guaranteed to be zero.
3810 APInt DemandedElts0 = DemandedElts & ~SrcZero;
3811 if (SimplifyDemandedVectorElts(Op0, DemandedElts0, KnownUndef, KnownZero,
3812 TLO, Depth + 1))
3813 return true;
3814
3815 KnownUndef &= DemandedElts0;
3816 KnownZero &= DemandedElts0;
3817
3818 // If every element pair has a zero/undef then just fold to zero.
3819 // fold (and x, undef) -> 0 / (and x, 0) -> 0
3820 // fold (mul x, undef) -> 0 / (mul x, 0) -> 0
3821 if (DemandedElts.isSubsetOf(SrcZero | KnownZero | SrcUndef | KnownUndef))
3822 return TLO.CombineTo(Op, TLO.DAG.getConstant(0, SDLoc(Op), VT));
3823
3824 // If either side has a zero element, then the result element is zero, even
3825 // if the other is an UNDEF.
3826 // TODO: Extend getKnownUndefForVectorBinop to also deal with known zeros
3827 // and then handle 'and' nodes with the rest of the binop opcodes.
3828 KnownZero |= SrcZero;
3829 KnownUndef &= SrcUndef;
3830 KnownUndef &= ~KnownZero;
3831
3832 // Attempt to avoid multi-use ops if we don't need anything from them.
3833 if (!DemandedElts.isAllOnes())
3834 if (SimplifyDemandedVectorEltsBinOp(Op0, Op1))
3835 return true;
3836 break;
3837 }
3838 case ISD::TRUNCATE:
3839 case ISD::SIGN_EXTEND:
3840 case ISD::ZERO_EXTEND:
3841 if (SimplifyDemandedVectorElts(Op.getOperand(0), DemandedElts, KnownUndef,
3842 KnownZero, TLO, Depth + 1))
3843 return true;
3844
3845 if (!DemandedElts.isAllOnes())
3846 if (SDValue NewOp = SimplifyMultipleUseDemandedVectorElts(
3847 Op.getOperand(0), DemandedElts, TLO.DAG, Depth + 1))
3848 return TLO.CombineTo(Op, TLO.DAG.getNode(Opcode, SDLoc(Op), VT, NewOp));
3849
3850 if (Op.getOpcode() == ISD::ZERO_EXTEND) {
3851 // zext(undef) upper bits are guaranteed to be zero.
3852 if (DemandedElts.isSubsetOf(KnownUndef))
3853 return TLO.CombineTo(Op, TLO.DAG.getConstant(0, SDLoc(Op), VT));
3854 KnownUndef.clearAllBits();
3855 }
3856 break;
3857 case ISD::SINT_TO_FP:
3858 case ISD::UINT_TO_FP:
3859 case ISD::FP_TO_SINT:
3860 case ISD::FP_TO_UINT:
3861 if (SimplifyDemandedVectorElts(Op.getOperand(0), DemandedElts, KnownUndef,
3862 KnownZero, TLO, Depth + 1))
3863 return true;
3864 // Don't fall through to generic undef -> undef handling.
3865 return false;
3866 default: {
3867 if (Op.getOpcode() >= ISD::BUILTIN_OP_END) {
3868 if (SimplifyDemandedVectorEltsForTargetNode(Op, DemandedElts, KnownUndef,
3869 KnownZero, TLO, Depth))
3870 return true;
3871 } else {
3872 KnownBits Known;
3873 APInt DemandedBits = APInt::getAllOnes(EltSizeInBits);
3874 if (SimplifyDemandedBits(Op, DemandedBits, OriginalDemandedElts, Known,
3875 TLO, Depth, AssumeSingleUse))
3876 return true;
3877 }
3878 break;
3879 }
3880 }
3881 assert((KnownUndef & KnownZero) == 0 && "Elements flagged as undef AND zero");
3882
3883 // Constant fold all undef cases.
3884 // TODO: Handle zero cases as well.
3885 if (DemandedElts.isSubsetOf(KnownUndef))
3886 return TLO.CombineTo(Op, TLO.DAG.getUNDEF(VT));
3887
3888 return false;
3889 }
3890
3891 /// Determine which of the bits specified in Mask are known to be either zero or
3892 /// one and return them in the Known.
3893 void TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
3894 KnownBits &Known,
3895 const APInt &DemandedElts,
3896 const SelectionDAG &DAG,
3897 unsigned Depth) const {
3898 assert((Op.getOpcode() >= ISD::BUILTIN_OP_END ||
3899 Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
3900 Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
3901 Op.getOpcode() == ISD::INTRINSIC_VOID) &&
3902 "Should use MaskedValueIsZero if you don't know whether Op"
3903 " is a target node!");
3904 Known.resetAll();
3905 }
3906
3907 void TargetLowering::computeKnownBitsForTargetInstr(
3908 GISelValueTracking &Analysis, Register R, KnownBits &Known,
3909 const APInt &DemandedElts, const MachineRegisterInfo &MRI,
3910 unsigned Depth) const {
3911 Known.resetAll();
3912 }
3913
3914 void TargetLowering::computeKnownFPClassForTargetInstr(
3915 GISelValueTracking &Analysis, Register R, KnownFPClass &Known,
3916 const APInt &DemandedElts, const MachineRegisterInfo &MRI,
3917 unsigned Depth) const {
3918 Known.resetAll();
3919 }
3920
3921 void TargetLowering::computeKnownBitsForFrameIndex(
3922 const int FrameIdx, KnownBits &Known, const MachineFunction &MF) const {
3923 // The low bits are known zero if the pointer is aligned.
3924 Known.Zero.setLowBits(Log2(MF.getFrameInfo().getObjectAlign(FrameIdx)));
3925 }
3926
3927 Align TargetLowering::computeKnownAlignForTargetInstr(
3928 GISelValueTracking &Analysis, Register R, const MachineRegisterInfo &MRI,
3929 unsigned Depth) const {
3930 return Align(1);
3931 }
3932
3933 /// This method can be implemented by targets that want to expose additional
3934 /// information about sign bits to the DAG Combiner.
3935 unsigned TargetLowering::ComputeNumSignBitsForTargetNode(SDValue Op,
3936 const APInt &,
3937 const SelectionDAG &,
3938 unsigned Depth) const {
3939 assert((Op.getOpcode() >= ISD::BUILTIN_OP_END ||
3940 Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
3941 Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
3942 Op.getOpcode() == ISD::INTRINSIC_VOID) &&
3943 "Should use ComputeNumSignBits if you don't know whether Op"
3944 " is a target node!");
3945 return 1;
3946 }
3947
3948 unsigned TargetLowering::computeNumSignBitsForTargetInstr(
3949 GISelValueTracking &Analysis, Register R, const APInt &DemandedElts,
3950 const MachineRegisterInfo &MRI, unsigned Depth) const {
3951 return 1;
3952 }
3953
3954 bool TargetLowering::SimplifyDemandedVectorEltsForTargetNode(
3955 SDValue Op, const APInt &DemandedElts, APInt &KnownUndef, APInt &KnownZero,
3956 TargetLoweringOpt &TLO, unsigned Depth) const {
3957 assert((Op.getOpcode() >= ISD::BUILTIN_OP_END ||
3958 Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
3959 Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
3960 Op.getOpcode() == ISD::INTRINSIC_VOID) &&
3961 "Should use SimplifyDemandedVectorElts if you don't know whether Op"
3962 " is a target node!");
3963 return false;
3964 }
3965
3966 bool TargetLowering::SimplifyDemandedBitsForTargetNode(
3967 SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
3968 KnownBits &Known, TargetLoweringOpt &TLO, unsigned Depth) const {
3969 assert((Op.getOpcode() >= ISD::BUILTIN_OP_END ||
3970 Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
3971 Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
3972 Op.getOpcode() == ISD::INTRINSIC_VOID) &&
3973 "Should use SimplifyDemandedBits if you don't know whether Op"
3974 " is a target node!");
3975 computeKnownBitsForTargetNode(Op, Known, DemandedElts, TLO.DAG, Depth);
3976 return false;
3977 }
3978
3979 SDValue TargetLowering::SimplifyMultipleUseDemandedBitsForTargetNode(
3980 SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
3981 SelectionDAG &DAG, unsigned Depth) const {
3982 assert(
3983 (Op.getOpcode() >= ISD::BUILTIN_OP_END ||
3984 Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
3985 Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
3986 Op.getOpcode() == ISD::INTRINSIC_VOID) &&
3987 "Should use SimplifyMultipleUseDemandedBits if you don't know whether Op"
3988 " is a target node!");
3989 return SDValue();
3990 }
3991
3992 SDValue
3993 TargetLowering::buildLegalVectorShuffle(EVT VT, const SDLoc &DL, SDValue N0,
3994 SDValue N1, MutableArrayRef<int> Mask,
3995 SelectionDAG &DAG) const {
3996 bool LegalMask = isShuffleMaskLegal(Mask, VT);
3997 if (!LegalMask) {
3998 std::swap(N0, N1);
3999 ShuffleVectorSDNode::commuteMask(Mask);
4000 LegalMask = isShuffleMaskLegal(Mask, VT);
4001 }
4002
4003 if (!LegalMask)
4004 return SDValue();
4005
4006 return DAG.getVectorShuffle(VT, DL, N0, N1, Mask);
4007 }
4008
4009 const Constant *TargetLowering::getTargetConstantFromLoad(LoadSDNode *) const {
4010 return nullptr;
4011 }
4012
4013 bool TargetLowering::isGuaranteedNotToBeUndefOrPoisonForTargetNode(
4014 SDValue Op, const APInt &DemandedElts, const SelectionDAG &DAG,
4015 bool PoisonOnly, unsigned Depth) const {
4016 assert(
4017 (Op.getOpcode() >= ISD::BUILTIN_OP_END ||
4018 Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
4019 Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
4020 Op.getOpcode() == ISD::INTRINSIC_VOID) &&
4021 "Should use isGuaranteedNotToBeUndefOrPoison if you don't know whether Op"
4022 " is a target node!");
4023
4024 // If Op can't create undef/poison and none of its operands are undef/poison
4025 // then Op is never undef/poison.
4026 return !canCreateUndefOrPoisonForTargetNode(Op, DemandedElts, DAG, PoisonOnly,
4027 /*ConsiderFlags*/ true, Depth) &&
4028 all_of(Op->ops(), [&](SDValue V) {
4029 return DAG.isGuaranteedNotToBeUndefOrPoison(V, PoisonOnly,
4030 Depth + 1);
4031 });
4032 }
4033
4034 bool TargetLowering::canCreateUndefOrPoisonForTargetNode(
4035 SDValue Op, const APInt &DemandedElts, const SelectionDAG &DAG,
4036 bool PoisonOnly, bool ConsiderFlags, unsigned Depth) const {
4037 assert((Op.getOpcode() >= ISD::BUILTIN_OP_END ||
4038 Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
4039 Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
4040 Op.getOpcode() == ISD::INTRINSIC_VOID) &&
4041 "Should use canCreateUndefOrPoison if you don't know whether Op"
4042 " is a target node!");
4043 // Be conservative and return true.
4044 return true;
4045 }
4046
4047 bool TargetLowering::isKnownNeverNaNForTargetNode(SDValue Op,
4048 const APInt &DemandedElts,
4049 const SelectionDAG &DAG,
4050 bool SNaN,
4051 unsigned Depth) const {
4052 assert((Op.getOpcode() >= ISD::BUILTIN_OP_END ||
4053 Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
4054 Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
4055 Op.getOpcode() == ISD::INTRINSIC_VOID) &&
4056 "Should use isKnownNeverNaN if you don't know whether Op"
4057 " is a target node!");
4058 return false;
4059 }
4060
4061 bool TargetLowering::isSplatValueForTargetNode(SDValue Op,
4062 const APInt &DemandedElts,
4063 APInt &UndefElts,
4064 const SelectionDAG &DAG,
4065 unsigned Depth) const {
4066 assert((Op.getOpcode() >= ISD::BUILTIN_OP_END ||
4067 Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
4068 Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
4069 Op.getOpcode() == ISD::INTRINSIC_VOID) &&
4070 "Should use isSplatValue if you don't know whether Op"
4071 " is a target node!");
4072 return false;
4073 }
4074
4075 // FIXME: Ideally, this would use ISD::isConstantSplatVector(), but that must
4076 // work with truncating build vectors and vectors with elements of less than
4077 // 8 bits.
4078 bool TargetLowering::isConstTrueVal(SDValue N) const {
4079 if (!N)
4080 return false;
4081
4082 unsigned EltWidth;
4083 APInt CVal;
4084 if (ConstantSDNode *CN = isConstOrConstSplat(N, /*AllowUndefs=*/false,
4085 /*AllowTruncation=*/true)) {
4086 CVal = CN->getAPIntValue();
4087 EltWidth = N.getValueType().getScalarSizeInBits();
4088 } else
4089 return false;
4090
4091 // If this is a truncating splat, truncate the splat value.
4092 // Otherwise, we may fail to match the expected values below.
4093 if (EltWidth < CVal.getBitWidth())
4094 CVal = CVal.trunc(EltWidth);
4095
4096 switch (getBooleanContents(N.getValueType())) {
4097 case UndefinedBooleanContent:
4098 return CVal[0];
4099 case ZeroOrOneBooleanContent:
4100 return CVal.isOne();
4101 case ZeroOrNegativeOneBooleanContent:
4102 return CVal.isAllOnes();
4103 }
4104
4105 llvm_unreachable("Invalid boolean contents");
4106 }
4107
4108 bool TargetLowering::isConstFalseVal(SDValue N) const {
4109 if (!N)
4110 return false;
4111
4112 const ConstantSDNode *CN = dyn_cast<ConstantSDNode>(N);
4113 if (!CN) {
4114 const BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(N);
4115 if (!BV)
4116 return false;
4117
4118 // We are only interested in constant splats; we don't care about undef
4119 // elements when identifying boolean constants, and getConstantSplatNode
4120 // returns NULL if all ops are undef.
4121 CN = BV->getConstantSplatNode();
4122 if (!CN)
4123 return false;
4124 }
4125
4126 if (getBooleanContents(N->getValueType(0)) == UndefinedBooleanContent)
4127 return !CN->getAPIntValue()[0];
4128
4129 return CN->isZero();
4130 }
4131
4132 bool TargetLowering::isExtendedTrueVal(const ConstantSDNode *N, EVT VT,
4133 bool SExt) const {
4134 if (VT == MVT::i1)
4135 return N->isOne();
4136
4137 TargetLowering::BooleanContent Cnt = getBooleanContents(VT);
4138 switch (Cnt) {
4139 case TargetLowering::ZeroOrOneBooleanContent:
4140 // An extended value of 1 is always true, unless its original type is i1,
4141 // in which case it will be sign extended to -1.
4142 return (N->isOne() && !SExt) || (SExt && (N->getValueType(0) != MVT::i1));
4143 case TargetLowering::UndefinedBooleanContent:
4144 case TargetLowering::ZeroOrNegativeOneBooleanContent:
4145 return N->isAllOnes() && SExt;
4146 }
4147 llvm_unreachable("Unexpected enumeration.");
4148 }
4149
4150 /// This helper function of SimplifySetCC tries to optimize the comparison when
4151 /// either operand of the SetCC node is a bitwise-and instruction.
4152 SDValue TargetLowering::foldSetCCWithAnd(EVT VT, SDValue N0, SDValue N1,
4153 ISD::CondCode Cond, const SDLoc &DL,
4154 DAGCombinerInfo &DCI) const {
4155 if (N1.getOpcode() == ISD::AND && N0.getOpcode() != ISD::AND)
4156 std::swap(N0, N1);
4157
4158 SelectionDAG &DAG = DCI.DAG;
4159 EVT OpVT = N0.getValueType();
4160 if (N0.getOpcode() != ISD::AND || !OpVT.isInteger() ||
4161 (Cond != ISD::SETEQ && Cond != ISD::SETNE))
4162 return SDValue();
4163
4164 // (X & Y) != 0 --> zextOrTrunc(X & Y)
4165 // iff everything but LSB is known zero:
4166 if (Cond == ISD::SETNE && isNullConstant(N1) &&
4167 (getBooleanContents(OpVT) == TargetLowering::UndefinedBooleanContent ||
4168 getBooleanContents(OpVT) == TargetLowering::ZeroOrOneBooleanContent)) {
4169 unsigned NumEltBits = OpVT.getScalarSizeInBits();
4170 APInt UpperBits = APInt::getHighBitsSet(NumEltBits, NumEltBits - 1);
4171 if (DAG.MaskedValueIsZero(N0, UpperBits))
4172 return DAG.getBoolExtOrTrunc(N0, DL, VT, OpVT);
4173 }
4174
4175 // Try to eliminate a power-of-2 mask constant by converting to a signbit
4176 // test in a narrow type that we can truncate to with no cost. Examples:
4177 // (i32 X & 32768) == 0 --> (trunc X to i16) >= 0
4178 // (i32 X & 32768) != 0 --> (trunc X to i16) < 0
4179 // TODO: This conservatively checks for type legality on the source and
4180 // destination types. That may inhibit optimizations, but it also
4181 // allows setcc->shift transforms that may be more beneficial.
4182 auto *AndC = dyn_cast<ConstantSDNode>(N0.getOperand(1));
4183 if (AndC && isNullConstant(N1) && AndC->getAPIntValue().isPowerOf2() &&
4184 isTypeLegal(OpVT) && N0.hasOneUse()) {
4185 EVT NarrowVT = EVT::getIntegerVT(*DAG.getContext(),
4186 AndC->getAPIntValue().getActiveBits());
4187 if (isTruncateFree(OpVT, NarrowVT) && isTypeLegal(NarrowVT)) {
4188 SDValue Trunc = DAG.getZExtOrTrunc(N0.getOperand(0), DL, NarrowVT);
4189 SDValue Zero = DAG.getConstant(0, DL, NarrowVT);
4190 return DAG.getSetCC(DL, VT, Trunc, Zero,
4191 Cond == ISD::SETEQ ? ISD::SETGE : ISD::SETLT);
4192 }
4193 }
4194
4195 // Match these patterns in any of their permutations:
4196 // (X & Y) == Y
4197 // (X & Y) != Y
4198 SDValue X, Y;
4199 if (N0.getOperand(0) == N1) {
4200 X = N0.getOperand(1);
4201 Y = N0.getOperand(0);
4202 } else if (N0.getOperand(1) == N1) {
4203 X = N0.getOperand(0);
4204 Y = N0.getOperand(1);
4205 } else {
4206 return SDValue();
4207 }
4208
4209 // TODO: We should invert (X & Y) eq/ne 0 -> (X & Y) ne/eq Y if
4210 // `isXAndYEqZeroPreferableToXAndYEqY` is false. This is a bit difficult as
4211 // it's liable to create an infinite loop.
4212 SDValue Zero = DAG.getConstant(0, DL, OpVT);
4213 if (isXAndYEqZeroPreferableToXAndYEqY(Cond, OpVT) &&
4214 DAG.isKnownToBeAPowerOfTwo(Y)) {
4215 // Simplify X & Y == Y to X & Y != 0 if Y has exactly one bit set.
4216 // Note that where Y is variable and is known to have at most one bit set
4217 // (for example, if it is Z & 1) we cannot do this; the expressions are not
4218 // equivalent when Y == 0.
4219 assert(OpVT.isInteger());
4220 Cond = ISD::getSetCCInverse(Cond, OpVT);
4221 if (DCI.isBeforeLegalizeOps() ||
4222 isCondCodeLegal(Cond, N0.getSimpleValueType()))
4223 return DAG.getSetCC(DL, VT, N0, Zero, Cond);
4224 } else if (N0.hasOneUse() && hasAndNotCompare(Y)) {
4225 // If the target supports an 'and-not' or 'and-complement' logic operation,
4226 // try to use that to make a comparison operation more efficient.
4227 // But don't do this transform if the mask is a single bit because there are
4228 // more efficient ways to deal with that case (for example, 'bt' on x86 or
4229 // 'rlwinm' on PPC).
4230
4231 // Bail out if the compare operand that we want to turn into a zero is
4232 // already a zero (otherwise, infinite loop).
4233 if (isNullConstant(Y))
4234 return SDValue();
4235
4236 // Transform this into: ~X & Y == 0.
4237 SDValue NotX = DAG.getNOT(SDLoc(X), X, OpVT);
4238 SDValue NewAnd = DAG.getNode(ISD::AND, SDLoc(N0), OpVT, NotX, Y);
4239 return DAG.getSetCC(DL, VT, NewAnd, Zero, Cond);
4240 }
4241
4242 return SDValue();
4243 }
4244
4245 /// This helper function of SimplifySetCC tries to optimize the comparison when
4246 /// either operand of the SetCC node is a bitwise-or instruction.
4247 /// For now, this just transforms (X | Y) ==/!= Y into X & ~Y ==/!= 0.
4248 SDValue TargetLowering::foldSetCCWithOr(EVT VT, SDValue N0, SDValue N1,
4249 ISD::CondCode Cond, const SDLoc &DL,
4250 DAGCombinerInfo &DCI) const {
4251 if (N1.getOpcode() == ISD::OR && N0.getOpcode() != ISD::OR)
4252 std::swap(N0, N1);
4253
4254 SelectionDAG &DAG = DCI.DAG;
4255 EVT OpVT = N0.getValueType();
4256 if (!N0.hasOneUse() || !OpVT.isInteger() ||
4257 (Cond != ISD::SETEQ && Cond != ISD::SETNE))
4258 return SDValue();
4259
4260 // (X | Y) == Y
4261 // (X | Y) != Y
4262 SDValue X;
4263 if (sd_match(N0, m_Or(m_Value(X), m_Specific(N1))) && hasAndNotCompare(X)) {
4264 // If the target supports an 'and-not' or 'and-complement' logic operation,
4265 // try to use that to make a comparison operation more efficient.
4266
4267 // Bail out if the compare operand that we want to turn into a zero is
4268 // already a zero (otherwise, infinite loop).
4269 if (isNullConstant(N1))
4270 return SDValue();
4271
4272 // Transform this into: X & ~Y ==/!= 0.
4273 SDValue NotY = DAG.getNOT(SDLoc(N1), N1, OpVT);
4274 SDValue NewAnd = DAG.getNode(ISD::AND, SDLoc(N0), OpVT, X, NotY);
4275 return DAG.getSetCC(DL, VT, NewAnd, DAG.getConstant(0, DL, OpVT), Cond);
4276 }
4277
4278 return SDValue();
4279 }
4280
4281 /// There are multiple IR patterns that could be checking whether certain
4282 /// truncation of a signed number would be lossy or not. The pattern that is
4283 /// best at the IR level may not lower optimally. Thus, we want to unfold it.
4284 /// We are looking for the following pattern: (KeptBits is a constant)
4285 /// (add %x, (1 << (KeptBits-1))) srccond (1 << KeptBits)
4286 /// KeptBits won't be bitwidth(x); that would have been constant-folded to true/false.
4287 /// KeptBits also can't be 1; that would have been folded to %x dstcond 0.
4288 /// We will unfold it into the natural trunc+sext pattern:
4289 /// ((%x << C) a>> C) dstcond %x
4290 /// Where C = bitwidth(x) - KeptBits and C u< bitwidth(x)
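/// For example, with KeptBits == 8 on i16 %x, (add %x, 128) u< 256 tests
/// whether %x fits in i8, and unfolds to ((%x << 8) a>> 8) == %x.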
4291 SDValue TargetLowering::optimizeSetCCOfSignedTruncationCheck(
4292 EVT SCCVT, SDValue N0, SDValue N1, ISD::CondCode Cond, DAGCombinerInfo &DCI,
4293 const SDLoc &DL) const {
4294 // We must be comparing with a constant.
4295 ConstantSDNode *C1;
4296 if (!(C1 = dyn_cast<ConstantSDNode>(N1)))
4297 return SDValue();
4298
4299 // N0 should be: add %x, (1 << (KeptBits-1))
4300 if (N0->getOpcode() != ISD::ADD)
4301 return SDValue();
4302
4303 // And we must be 'add'ing a constant.
4304 ConstantSDNode *C01;
4305 if (!(C01 = dyn_cast<ConstantSDNode>(N0->getOperand(1))))
4306 return SDValue();
4307
4308 SDValue X = N0->getOperand(0);
4309 EVT XVT = X.getValueType();
4310
4311 // Validate constants ...
4312
4313 APInt I1 = C1->getAPIntValue();
4314
4315 ISD::CondCode NewCond;
4316 if (Cond == ISD::CondCode::SETULT) {
4317 NewCond = ISD::CondCode::SETEQ;
4318 } else if (Cond == ISD::CondCode::SETULE) {
4319 NewCond = ISD::CondCode::SETEQ;
4320 // But need to 'canonicalize' the constant.
4321 I1 += 1;
4322 } else if (Cond == ISD::CondCode::SETUGT) {
4323 NewCond = ISD::CondCode::SETNE;
4324 // But need to 'canonicalize' the constant.
4325 I1 += 1;
4326 } else if (Cond == ISD::CondCode::SETUGE) {
4327 NewCond = ISD::CondCode::SETNE;
4328 } else
4329 return SDValue();
4330
4331 APInt I01 = C01->getAPIntValue();
4332
4333 auto checkConstants = [&I1, &I01]() -> bool {
4334     // Both of them must be powers of two, and the constant from the setcc must be bigger.
4335 return I1.ugt(I01) && I1.isPowerOf2() && I01.isPowerOf2();
4336 };
4337
4338 if (checkConstants()) {
4339 // Great, e.g. got icmp ult i16 (add i16 %x, 128), 256
4340 } else {
4341 // What if we invert constants? (and the target predicate)
4342 I1.negate();
4343 I01.negate();
4344 assert(XVT.isInteger());
4345 NewCond = getSetCCInverse(NewCond, XVT);
4346 if (!checkConstants())
4347 return SDValue();
4348 // Great, e.g. got icmp uge i16 (add i16 %x, -128), -256
4349 }
4350
4351 // They are power-of-two, so which bit is set?
4352 const unsigned KeptBits = I1.logBase2();
4353 const unsigned KeptBitsMinusOne = I01.logBase2();
4354
4355 // Magic!
4356 if (KeptBits != (KeptBitsMinusOne + 1))
4357 return SDValue();
4358 assert(KeptBits > 0 && KeptBits < XVT.getSizeInBits() && "unreachable");
4359
4360 // We don't want to do this in every single case.
4361 SelectionDAG &DAG = DCI.DAG;
4362 if (!shouldTransformSignedTruncationCheck(XVT, KeptBits))
4363 return SDValue();
4364
4365 // Unfold into: sext_inreg(%x) cond %x
4366 // Where 'cond' will be either 'eq' or 'ne'.
4367 SDValue SExtInReg = DAG.getNode(
4368 ISD::SIGN_EXTEND_INREG, DL, XVT, X,
4369 DAG.getValueType(EVT::getIntegerVT(*DAG.getContext(), KeptBits)));
4370 return DAG.getSetCC(DL, SCCVT, SExtInReg, X, NewCond);
4371 }
4372
4373 // (X & (C l>>/<< Y)) ==/!= 0 --> ((X <</l>> Y) & C) ==/!= 0
4374 SDValue TargetLowering::optimizeSetCCByHoistingAndByConstFromLogicalShift(
4375 EVT SCCVT, SDValue N0, SDValue N1C, ISD::CondCode Cond,
4376 DAGCombinerInfo &DCI, const SDLoc &DL) const {
4377 assert(isConstOrConstSplat(N1C) && isConstOrConstSplat(N1C)->isZero() &&
4378 "Should be a comparison with 0.");
4379 assert((Cond == ISD::SETEQ || Cond == ISD::SETNE) &&
4380 "Valid only for [in]equality comparisons.");
4381
4382 unsigned NewShiftOpcode;
4383 SDValue X, C, Y;
4384
4385 SelectionDAG &DAG = DCI.DAG;
4386
4387 // Look for '(C l>>/<< Y)'.
4388 auto Match = [&NewShiftOpcode, &X, &C, &Y, &DAG, this](SDValue V) {
4389 // The shift should be one-use.
4390 if (!V.hasOneUse())
4391 return false;
4392 unsigned OldShiftOpcode = V.getOpcode();
4393 switch (OldShiftOpcode) {
4394 case ISD::SHL:
4395 NewShiftOpcode = ISD::SRL;
4396 break;
4397 case ISD::SRL:
4398 NewShiftOpcode = ISD::SHL;
4399 break;
4400 default:
4401 return false; // must be a logical shift.
4402 }
4403 // We should be shifting a constant.
4404 // FIXME: best to use isConstantOrConstantVector().
4405 C = V.getOperand(0);
4406 ConstantSDNode *CC =
4407 isConstOrConstSplat(C, /*AllowUndefs=*/true, /*AllowTruncation=*/true);
4408 if (!CC)
4409 return false;
4410 Y = V.getOperand(1);
4411
4412 ConstantSDNode *XC =
4413 isConstOrConstSplat(X, /*AllowUndefs=*/true, /*AllowTruncation=*/true);
4414 return shouldProduceAndByConstByHoistingConstFromShiftsLHSOfAnd(
4415 X, XC, CC, Y, OldShiftOpcode, NewShiftOpcode, DAG);
4416 };
4417
4418   // LHS of comparison should be a one-use 'and'.
4419 if (N0.getOpcode() != ISD::AND || !N0.hasOneUse())
4420 return SDValue();
4421
4422 X = N0.getOperand(0);
4423 SDValue Mask = N0.getOperand(1);
4424
4425 // 'and' is commutative!
4426 if (!Match(Mask)) {
4427 std::swap(X, Mask);
4428 if (!Match(Mask))
4429 return SDValue();
4430 }
4431
4432 EVT VT = X.getValueType();
4433
4434 // Produce:
4435 // ((X 'OppositeShiftOpcode' Y) & C) Cond 0
4436 SDValue T0 = DAG.getNode(NewShiftOpcode, DL, VT, X, Y);
4437 SDValue T1 = DAG.getNode(ISD::AND, DL, VT, T0, C);
4438 SDValue T2 = DAG.getSetCC(DL, SCCVT, T1, N1C, Cond);
4439 return T2;
4440 }
4441
4442 /// Try to fold an equality comparison with a {add/sub/xor} binary operation as
4443 /// the 1st operand (N0). Callers are expected to swap the N0/N1 parameters to
4444 /// handle the commuted versions of these patterns.
4445 SDValue TargetLowering::foldSetCCWithBinOp(EVT VT, SDValue N0, SDValue N1,
4446 ISD::CondCode Cond, const SDLoc &DL,
4447 DAGCombinerInfo &DCI) const {
4448 unsigned BOpcode = N0.getOpcode();
4449 assert((BOpcode == ISD::ADD || BOpcode == ISD::SUB || BOpcode == ISD::XOR) &&
4450 "Unexpected binop");
4451 assert((Cond == ISD::SETEQ || Cond == ISD::SETNE) && "Unexpected condcode");
4452
4453 // (X + Y) == X --> Y == 0
4454 // (X - Y) == X --> Y == 0
4455 // (X ^ Y) == X --> Y == 0
4456 SelectionDAG &DAG = DCI.DAG;
4457 EVT OpVT = N0.getValueType();
4458 SDValue X = N0.getOperand(0);
4459 SDValue Y = N0.getOperand(1);
4460 if (X == N1)
4461 return DAG.getSetCC(DL, VT, Y, DAG.getConstant(0, DL, OpVT), Cond);
4462
4463 if (Y != N1)
4464 return SDValue();
4465
4466 // (X + Y) == Y --> X == 0
4467 // (X ^ Y) == Y --> X == 0
4468 if (BOpcode == ISD::ADD || BOpcode == ISD::XOR)
4469 return DAG.getSetCC(DL, VT, X, DAG.getConstant(0, DL, OpVT), Cond);
4470
4471 // The shift would not be valid if the operands are boolean (i1).
4472 if (!N0.hasOneUse() || OpVT.getScalarSizeInBits() == 1)
4473 return SDValue();
4474
4475 // (X - Y) == Y --> X == Y << 1
4476 SDValue One = DAG.getShiftAmountConstant(1, OpVT, DL);
4477 SDValue YShl1 = DAG.getNode(ISD::SHL, DL, N1.getValueType(), Y, One);
4478 if (!DCI.isCalledByLegalizer())
4479 DCI.AddToWorklist(YShl1.getNode());
4480 return DAG.getSetCC(DL, VT, X, YShl1, Cond);
4481 }
4482
4483 static SDValue simplifySetCCWithCTPOP(const TargetLowering &TLI, EVT VT,
4484 SDValue N0, const APInt &C1,
4485 ISD::CondCode Cond, const SDLoc &dl,
4486 SelectionDAG &DAG) {
4487 // Look through truncs that don't change the value of a ctpop.
4488 // FIXME: Add vector support? Need to be careful with setcc result type below.
4489 SDValue CTPOP = N0;
4490 if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() && !VT.isVector() &&
4491 N0.getScalarValueSizeInBits() > Log2_32(N0.getOperand(0).getScalarValueSizeInBits()))
4492 CTPOP = N0.getOperand(0);
4493
4494 if (CTPOP.getOpcode() != ISD::CTPOP || !CTPOP.hasOneUse())
4495 return SDValue();
4496
4497 EVT CTVT = CTPOP.getValueType();
4498 SDValue CTOp = CTPOP.getOperand(0);
4499
4500 // Expand a power-of-2-or-zero comparison based on ctpop:
4501 // (ctpop x) u< 2 -> (x & x-1) == 0
4502 // (ctpop x) u> 1 -> (x & x-1) != 0
4503 if (Cond == ISD::SETULT || Cond == ISD::SETUGT) {
4504 // Keep the CTPOP if it is a cheap vector op.
4505 if (CTVT.isVector() && TLI.isCtpopFast(CTVT))
4506 return SDValue();
4507
4508 unsigned CostLimit = TLI.getCustomCtpopCost(CTVT, Cond);
4509 if (C1.ugt(CostLimit + (Cond == ISD::SETULT)))
4510 return SDValue();
4511 if (C1 == 0 && (Cond == ISD::SETULT))
4512 return SDValue(); // This is handled elsewhere.
4513
4514 unsigned Passes = C1.getLimitedValue() - (Cond == ISD::SETULT);
4515
4516 SDValue NegOne = DAG.getAllOnesConstant(dl, CTVT);
4517 SDValue Result = CTOp;
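    // Each iteration clears the lowest set bit, so after 'Passes' iterations
    // Result is zero iff CTOp had at most 'Passes' bits set.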
4518 for (unsigned i = 0; i < Passes; i++) {
4519 SDValue Add = DAG.getNode(ISD::ADD, dl, CTVT, Result, NegOne);
4520 Result = DAG.getNode(ISD::AND, dl, CTVT, Result, Add);
4521 }
4522 ISD::CondCode CC = Cond == ISD::SETULT ? ISD::SETEQ : ISD::SETNE;
4523 return DAG.getSetCC(dl, VT, Result, DAG.getConstant(0, dl, CTVT), CC);
4524 }
4525
4526 // Expand a power-of-2 comparison based on ctpop
4527 if ((Cond == ISD::SETEQ || Cond == ISD::SETNE) && C1 == 1) {
4528 // Keep the CTPOP if it is cheap.
4529 if (TLI.isCtpopFast(CTVT))
4530 return SDValue();
4531
4532 SDValue Zero = DAG.getConstant(0, dl, CTVT);
4533 SDValue NegOne = DAG.getAllOnesConstant(dl, CTVT);
4534 assert(CTVT.isInteger());
4535 SDValue Add = DAG.getNode(ISD::ADD, dl, CTVT, CTOp, NegOne);
4536
4537     // It's not uncommon for known-never-zero X to exist in (ctpop X) eq/ne 1, so
4538 // check before emitting a potentially unnecessary op.
4539 if (DAG.isKnownNeverZero(CTOp)) {
4540 // (ctpop x) == 1 --> (x & x-1) == 0
4541 // (ctpop x) != 1 --> (x & x-1) != 0
4542 SDValue And = DAG.getNode(ISD::AND, dl, CTVT, CTOp, Add);
4543 SDValue RHS = DAG.getSetCC(dl, VT, And, Zero, Cond);
4544 return RHS;
4545 }
4546
4547 // (ctpop x) == 1 --> (x ^ x-1) > x-1
4548 // (ctpop x) != 1 --> (x ^ x-1) <= x-1
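    // For nonzero x, x ^ (x-1) sets exactly the lowest set bit of x and all
    // bits below it; that value exceeds x-1 only when x is a power of two.
    // For x == 0 both sides are all-ones, so the 'ugt' fails.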
4549 SDValue Xor = DAG.getNode(ISD::XOR, dl, CTVT, CTOp, Add);
4550 ISD::CondCode CmpCond = Cond == ISD::SETEQ ? ISD::SETUGT : ISD::SETULE;
4551 return DAG.getSetCC(dl, VT, Xor, Add, CmpCond);
4552 }
4553
4554 return SDValue();
4555 }
4556
4557 static SDValue foldSetCCWithRotate(EVT VT, SDValue N0, SDValue N1,
4558 ISD::CondCode Cond, const SDLoc &dl,
4559 SelectionDAG &DAG) {
4560 if (Cond != ISD::SETEQ && Cond != ISD::SETNE)
4561 return SDValue();
4562
4563 auto *C1 = isConstOrConstSplat(N1, /* AllowUndefs */ true);
4564 if (!C1 || !(C1->isZero() || C1->isAllOnes()))
4565 return SDValue();
4566
4567 auto getRotateSource = [](SDValue X) {
4568 if (X.getOpcode() == ISD::ROTL || X.getOpcode() == ISD::ROTR)
4569 return X.getOperand(0);
4570 return SDValue();
4571 };
4572
4573 // Peek through a rotated value compared against 0 or -1:
4574 // (rot X, Y) == 0/-1 --> X == 0/-1
4575 // (rot X, Y) != 0/-1 --> X != 0/-1
4576 if (SDValue R = getRotateSource(N0))
4577 return DAG.getSetCC(dl, VT, R, N1, Cond);
4578
4579 // Peek through an 'or' of a rotated value compared against 0:
4580 // or (rot X, Y), Z ==/!= 0 --> (or X, Z) ==/!= 0
4581 // or Z, (rot X, Y) ==/!= 0 --> (or X, Z) ==/!= 0
4582 //
4583 // TODO: Add the 'and' with -1 sibling.
4584 // TODO: Recurse through a series of 'or' ops to find the rotate.
4585 EVT OpVT = N0.getValueType();
4586 if (N0.hasOneUse() && N0.getOpcode() == ISD::OR && C1->isZero()) {
4587 if (SDValue R = getRotateSource(N0.getOperand(0))) {
4588 SDValue NewOr = DAG.getNode(ISD::OR, dl, OpVT, R, N0.getOperand(1));
4589 return DAG.getSetCC(dl, VT, NewOr, N1, Cond);
4590 }
4591 if (SDValue R = getRotateSource(N0.getOperand(1))) {
4592 SDValue NewOr = DAG.getNode(ISD::OR, dl, OpVT, R, N0.getOperand(0));
4593 return DAG.getSetCC(dl, VT, NewOr, N1, Cond);
4594 }
4595 }
4596
4597 return SDValue();
4598 }
4599
4600 static SDValue foldSetCCWithFunnelShift(EVT VT, SDValue N0, SDValue N1,
4601 ISD::CondCode Cond, const SDLoc &dl,
4602 SelectionDAG &DAG) {
4603 // If we are testing for all-bits-clear, we might be able to do that with
4604 // less shifting since bit-order does not matter.
4605 if (Cond != ISD::SETEQ && Cond != ISD::SETNE)
4606 return SDValue();
4607
4608 auto *C1 = isConstOrConstSplat(N1, /* AllowUndefs */ true);
4609 if (!C1 || !C1->isZero())
4610 return SDValue();
4611
4612 if (!N0.hasOneUse() ||
4613 (N0.getOpcode() != ISD::FSHL && N0.getOpcode() != ISD::FSHR))
4614 return SDValue();
4615
4616 unsigned BitWidth = N0.getScalarValueSizeInBits();
4617 auto *ShAmtC = isConstOrConstSplat(N0.getOperand(2));
4618 if (!ShAmtC)
4619 return SDValue();
4620
4621 uint64_t ShAmt = ShAmtC->getAPIntValue().urem(BitWidth);
4622 if (ShAmt == 0)
4623 return SDValue();
4624
4625 // Canonicalize fshr as fshl to reduce pattern-matching.
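  // (fshr X, Y, C) produces the same bits as (fshl X, Y, BitWidth - C), so
  // adjusting the amount here lets the code below handle only the fshl form.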
4626 if (N0.getOpcode() == ISD::FSHR)
4627 ShAmt = BitWidth - ShAmt;
4628
4629 // Match an 'or' with a specific operand 'Other' in either commuted variant.
4630 SDValue X, Y;
4631 auto matchOr = [&X, &Y](SDValue Or, SDValue Other) {
4632 if (Or.getOpcode() != ISD::OR || !Or.hasOneUse())
4633 return false;
4634 if (Or.getOperand(0) == Other) {
4635 X = Or.getOperand(0);
4636 Y = Or.getOperand(1);
4637 return true;
4638 }
4639 if (Or.getOperand(1) == Other) {
4640 X = Or.getOperand(1);
4641 Y = Or.getOperand(0);
4642 return true;
4643 }
4644 return false;
4645 };
4646
4647 EVT OpVT = N0.getValueType();
4648 EVT ShAmtVT = N0.getOperand(2).getValueType();
4649 SDValue F0 = N0.getOperand(0);
4650 SDValue F1 = N0.getOperand(1);
4651 if (matchOr(F0, F1)) {
4652 // fshl (or X, Y), X, C ==/!= 0 --> or (shl Y, C), X ==/!= 0
4653 SDValue NewShAmt = DAG.getConstant(ShAmt, dl, ShAmtVT);
4654 SDValue Shift = DAG.getNode(ISD::SHL, dl, OpVT, Y, NewShAmt);
4655 SDValue NewOr = DAG.getNode(ISD::OR, dl, OpVT, Shift, X);
4656 return DAG.getSetCC(dl, VT, NewOr, N1, Cond);
4657 }
4658 if (matchOr(F1, F0)) {
4659 // fshl X, (or X, Y), C ==/!= 0 --> or (srl Y, BW-C), X ==/!= 0
4660 SDValue NewShAmt = DAG.getConstant(BitWidth - ShAmt, dl, ShAmtVT);
4661 SDValue Shift = DAG.getNode(ISD::SRL, dl, OpVT, Y, NewShAmt);
4662 SDValue NewOr = DAG.getNode(ISD::OR, dl, OpVT, Shift, X);
4663 return DAG.getSetCC(dl, VT, NewOr, N1, Cond);
4664 }
4665
4666 return SDValue();
4667 }
4668
4669 /// Try to simplify a setcc built with the specified operands and cc. If it is
4670 /// unable to simplify it, return a null SDValue.
4671 SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
4672 ISD::CondCode Cond, bool foldBooleans,
4673 DAGCombinerInfo &DCI,
4674 const SDLoc &dl) const {
4675 SelectionDAG &DAG = DCI.DAG;
4676 const DataLayout &Layout = DAG.getDataLayout();
4677 EVT OpVT = N0.getValueType();
4678 AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4679
4680 // Constant fold or commute setcc.
4681 if (SDValue Fold = DAG.FoldSetCC(VT, N0, N1, Cond, dl))
4682 return Fold;
4683
4684 bool N0ConstOrSplat =
4685 isConstOrConstSplat(N0, /*AllowUndefs*/ false, /*AllowTruncate*/ true);
4686 bool N1ConstOrSplat =
4687 isConstOrConstSplat(N1, /*AllowUndefs*/ false, /*AllowTruncate*/ true);
4688
4689 // Canonicalize toward having the constant on the RHS.
4690 // TODO: Handle non-splat vector constants. All undef causes trouble.
4691 // FIXME: We can't yet fold constant scalable vector splats, so avoid an
4692 // infinite loop here when we encounter one.
4693 ISD::CondCode SwappedCC = ISD::getSetCCSwappedOperands(Cond);
4694 if (N0ConstOrSplat && !N1ConstOrSplat &&
4695 (DCI.isBeforeLegalizeOps() ||
4696 isCondCodeLegal(SwappedCC, N0.getSimpleValueType())))
4697 return DAG.getSetCC(dl, VT, N1, N0, SwappedCC);
4698
4699 // If we have a subtract with the same 2 non-constant operands as this setcc
4700 // -- but in reverse order -- then try to commute the operands of this setcc
4701 // to match. A matching pair of setcc (cmp) and sub may be combined into 1
4702 // instruction on some targets.
4703 if (!N0ConstOrSplat && !N1ConstOrSplat &&
4704 (DCI.isBeforeLegalizeOps() ||
4705 isCondCodeLegal(SwappedCC, N0.getSimpleValueType())) &&
4706 DAG.doesNodeExist(ISD::SUB, DAG.getVTList(OpVT), {N1, N0}) &&
4707 !DAG.doesNodeExist(ISD::SUB, DAG.getVTList(OpVT), {N0, N1}))
4708 return DAG.getSetCC(dl, VT, N1, N0, SwappedCC);
4709
4710 if (SDValue V = foldSetCCWithRotate(VT, N0, N1, Cond, dl, DAG))
4711 return V;
4712
4713 if (SDValue V = foldSetCCWithFunnelShift(VT, N0, N1, Cond, dl, DAG))
4714 return V;
4715
4716 if (auto *N1C = isConstOrConstSplat(N1)) {
4717 const APInt &C1 = N1C->getAPIntValue();
4718
4719 // Optimize some CTPOP cases.
4720 if (SDValue V = simplifySetCCWithCTPOP(*this, VT, N0, C1, Cond, dl, DAG))
4721 return V;
4722
4723 // For equality to 0 of a no-wrap multiply, decompose and test each op:
4724 // X * Y == 0 --> (X == 0) || (Y == 0)
4725 // X * Y != 0 --> (X != 0) && (Y != 0)
4726 // TODO: This bails out if minsize is set, but if the target doesn't have a
4727 // single instruction multiply for this type, it would likely be
4728 // smaller to decompose.
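    // The no-wrap requirement matters: without it the product can wrap to
    // zero even when neither operand is zero, e.g. for i8, 16 * 16 == 0.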
4729 if (C1.isZero() && (Cond == ISD::SETEQ || Cond == ISD::SETNE) &&
4730 N0.getOpcode() == ISD::MUL && N0.hasOneUse() &&
4731 (N0->getFlags().hasNoUnsignedWrap() ||
4732 N0->getFlags().hasNoSignedWrap()) &&
4733 !Attr.hasFnAttr(Attribute::MinSize)) {
4734 SDValue IsXZero = DAG.getSetCC(dl, VT, N0.getOperand(0), N1, Cond);
4735 SDValue IsYZero = DAG.getSetCC(dl, VT, N0.getOperand(1), N1, Cond);
4736 unsigned LogicOp = Cond == ISD::SETEQ ? ISD::OR : ISD::AND;
4737 return DAG.getNode(LogicOp, dl, VT, IsXZero, IsYZero);
4738 }
4739
4740 // If the LHS is '(srl (ctlz x), 5)', the RHS is 0/1, and this is an
4741 // equality comparison, then we're just comparing whether X itself is
4742 // zero.
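    // ctlz returns the bit width only for a zero input, and the bit width is
    // the only possible result with the log2(bitwidth) bit set, so the shift
    // yields 1 exactly when X is zero.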
4743 if (N0.getOpcode() == ISD::SRL && (C1.isZero() || C1.isOne()) &&
4744 N0.getOperand(0).getOpcode() == ISD::CTLZ &&
4745 llvm::has_single_bit<uint32_t>(N0.getScalarValueSizeInBits())) {
4746 if (ConstantSDNode *ShAmt = isConstOrConstSplat(N0.getOperand(1))) {
4747 if ((Cond == ISD::SETEQ || Cond == ISD::SETNE) &&
4748 ShAmt->getAPIntValue() == Log2_32(N0.getScalarValueSizeInBits())) {
4749 if ((C1 == 0) == (Cond == ISD::SETEQ)) {
4750 // (srl (ctlz x), 5) == 0 -> X != 0
4751 // (srl (ctlz x), 5) != 1 -> X != 0
4752 Cond = ISD::SETNE;
4753 } else {
4754 // (srl (ctlz x), 5) != 0 -> X == 0
4755 // (srl (ctlz x), 5) == 1 -> X == 0
4756 Cond = ISD::SETEQ;
4757 }
4758 SDValue Zero = DAG.getConstant(0, dl, N0.getValueType());
4759 return DAG.getSetCC(dl, VT, N0.getOperand(0).getOperand(0), Zero,
4760 Cond);
4761 }
4762 }
4763 }
4764 }
4765
4766 // FIXME: Support vectors.
4767 if (auto *N1C = dyn_cast<ConstantSDNode>(N1.getNode())) {
4768 const APInt &C1 = N1C->getAPIntValue();
4769
4770 // (zext x) == C --> x == (trunc C)
4771 // (sext x) == C --> x == (trunc C)
4772 if ((Cond == ISD::SETEQ || Cond == ISD::SETNE) &&
4773 DCI.isBeforeLegalize() && N0->hasOneUse()) {
4774 unsigned MinBits = N0.getValueSizeInBits();
4775 SDValue PreExt;
4776 bool Signed = false;
4777 if (N0->getOpcode() == ISD::ZERO_EXTEND) {
4778 // ZExt
4779 MinBits = N0->getOperand(0).getValueSizeInBits();
4780 PreExt = N0->getOperand(0);
4781 } else if (N0->getOpcode() == ISD::AND) {
4782 // DAGCombine turns costly ZExts into ANDs
4783 if (auto *C = dyn_cast<ConstantSDNode>(N0->getOperand(1)))
4784 if ((C->getAPIntValue()+1).isPowerOf2()) {
4785 MinBits = C->getAPIntValue().countr_one();
4786 PreExt = N0->getOperand(0);
4787 }
4788 } else if (N0->getOpcode() == ISD::SIGN_EXTEND) {
4789 // SExt
4790 MinBits = N0->getOperand(0).getValueSizeInBits();
4791 PreExt = N0->getOperand(0);
4792 Signed = true;
4793 } else if (auto *LN0 = dyn_cast<LoadSDNode>(N0)) {
4794 // ZEXTLOAD / SEXTLOAD
4795 if (LN0->getExtensionType() == ISD::ZEXTLOAD) {
4796 MinBits = LN0->getMemoryVT().getSizeInBits();
4797 PreExt = N0;
4798 } else if (LN0->getExtensionType() == ISD::SEXTLOAD) {
4799 Signed = true;
4800 MinBits = LN0->getMemoryVT().getSizeInBits();
4801 PreExt = N0;
4802 }
4803 }
4804
4805 // Figure out how many bits we need to preserve this constant.
4806 unsigned ReqdBits = Signed ? C1.getSignificantBits() : C1.getActiveBits();
4807
4808 // Make sure we're not losing bits from the constant.
4809 if (MinBits > 0 &&
4810 MinBits < C1.getBitWidth() &&
4811 MinBits >= ReqdBits) {
4812 EVT MinVT = EVT::getIntegerVT(*DAG.getContext(), MinBits);
4813 if (isTypeDesirableForOp(ISD::SETCC, MinVT)) {
4814 // Will get folded away.
4815 SDValue Trunc = DAG.getNode(ISD::TRUNCATE, dl, MinVT, PreExt);
4816 if (MinBits == 1 && C1 == 1)
4817 // Invert the condition.
4818 return DAG.getSetCC(dl, VT, Trunc, DAG.getConstant(0, dl, MVT::i1),
4819 Cond == ISD::SETEQ ? ISD::SETNE : ISD::SETEQ);
4820 SDValue C = DAG.getConstant(C1.trunc(MinBits), dl, MinVT);
4821 return DAG.getSetCC(dl, VT, Trunc, C, Cond);
4822 }
4823
4824 // If truncating the setcc operands is not desirable, we can still
4825 // simplify the expression in some cases:
4826         //   setcc ([sz]ext (setcc x, y, cc)), 0, setne -> setcc (x, y, cc)
4827         //   setcc ([sz]ext (setcc x, y, cc)), 0, seteq -> setcc (x, y, inv(cc))
4828         //   setcc (zext (setcc x, y, cc)), 1, setne -> setcc (x, y, inv(cc))
4829         //   setcc (zext (setcc x, y, cc)), 1, seteq -> setcc (x, y, cc)
4830         //   setcc (sext (setcc x, y, cc)), -1, setne -> setcc (x, y, inv(cc))
4831         //   setcc (sext (setcc x, y, cc)), -1, seteq -> setcc (x, y, cc)
4832 SDValue TopSetCC = N0->getOperand(0);
4833 unsigned N0Opc = N0->getOpcode();
4834 bool SExt = (N0Opc == ISD::SIGN_EXTEND);
4835 if (TopSetCC.getValueType() == MVT::i1 && VT == MVT::i1 &&
4836 TopSetCC.getOpcode() == ISD::SETCC &&
4837 (N0Opc == ISD::ZERO_EXTEND || N0Opc == ISD::SIGN_EXTEND) &&
4838 (isConstFalseVal(N1) ||
4839 isExtendedTrueVal(N1C, N0->getValueType(0), SExt))) {
4840
4841 bool Inverse = (N1C->isZero() && Cond == ISD::SETEQ) ||
4842 (!N1C->isZero() && Cond == ISD::SETNE);
4843
4844 if (!Inverse)
4845 return TopSetCC;
4846
4847 ISD::CondCode InvCond = ISD::getSetCCInverse(
4848 cast<CondCodeSDNode>(TopSetCC.getOperand(2))->get(),
4849 TopSetCC.getOperand(0).getValueType());
4850 return DAG.getSetCC(dl, VT, TopSetCC.getOperand(0),
4851 TopSetCC.getOperand(1),
4852 InvCond);
4853 }
4854 }
4855 }
4856
4857 // If the LHS is '(and load, const)', the RHS is 0, the test is for
4858 // equality or unsigned, and all 1 bits of the const are in the same
4859 // partial word, see if we can shorten the load.
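    // For example, on a little-endian target:
    //   ((i32 load p) & 0xFF00) == 0  -->  ((i8 load p+1) & 0xFF) == 0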
4860 if (DCI.isBeforeLegalize() &&
4861 !ISD::isSignedIntSetCC(Cond) &&
4862 N0.getOpcode() == ISD::AND && C1 == 0 &&
4863 N0.getNode()->hasOneUse() &&
4864 isa<LoadSDNode>(N0.getOperand(0)) &&
4865 N0.getOperand(0).getNode()->hasOneUse() &&
4866 isa<ConstantSDNode>(N0.getOperand(1))) {
4867 auto *Lod = cast<LoadSDNode>(N0.getOperand(0));
4868 APInt bestMask;
4869 unsigned bestWidth = 0, bestOffset = 0;
4870 if (Lod->isSimple() && Lod->isUnindexed() &&
4871 (Lod->getMemoryVT().isByteSized() ||
4872 isPaddedAtMostSignificantBitsWhenStored(Lod->getMemoryVT()))) {
4873 unsigned memWidth = Lod->getMemoryVT().getStoreSizeInBits();
4874 unsigned origWidth = N0.getValueSizeInBits();
4875 unsigned maskWidth = origWidth;
4876           // We can narrow (e.g.) 16-bit extending loads on a 32-bit target to
4877 // 8 bits, but have to be careful...
4878 if (Lod->getExtensionType() != ISD::NON_EXTLOAD)
4879 origWidth = Lod->getMemoryVT().getSizeInBits();
4880 const APInt &Mask = N0.getConstantOperandAPInt(1);
4881         // Only consider power-of-2 widths (and at least one byte) as candidates
4882 // for the narrowed load.
4883 for (unsigned width = 8; width < origWidth; width *= 2) {
4884 EVT newVT = EVT::getIntegerVT(*DAG.getContext(), width);
4885 APInt newMask = APInt::getLowBitsSet(maskWidth, width);
4886 // Avoid accessing any padding here for now (we could use memWidth
4887 // instead of origWidth here otherwise).
4888 unsigned maxOffset = origWidth - width;
4889 for (unsigned offset = 0; offset <= maxOffset; offset += 8) {
4890 if (Mask.isSubsetOf(newMask)) {
4891 unsigned ptrOffset =
4892 Layout.isLittleEndian() ? offset : memWidth - width - offset;
4893 unsigned IsFast = 0;
4894 assert((ptrOffset % 8) == 0 && "Non-Bytealigned pointer offset");
4895 Align NewAlign = commonAlignment(Lod->getAlign(), ptrOffset / 8);
4896 if (shouldReduceLoadWidth(Lod, ISD::NON_EXTLOAD, newVT,
4897 ptrOffset / 8) &&
4898 allowsMemoryAccess(
4899 *DAG.getContext(), Layout, newVT, Lod->getAddressSpace(),
4900 NewAlign, Lod->getMemOperand()->getFlags(), &IsFast) &&
4901 IsFast) {
4902 bestOffset = ptrOffset / 8;
4903 bestMask = Mask.lshr(offset);
4904 bestWidth = width;
4905 break;
4906 }
4907 }
4908 newMask <<= 8;
4909 }
4910 if (bestWidth)
4911 break;
4912 }
4913 }
4914 if (bestWidth) {
4915 EVT newVT = EVT::getIntegerVT(*DAG.getContext(), bestWidth);
4916 SDValue Ptr = Lod->getBasePtr();
4917 if (bestOffset != 0)
4918 Ptr = DAG.getObjectPtrOffset(dl, Ptr, TypeSize::getFixed(bestOffset));
4919 SDValue NewLoad =
4920 DAG.getLoad(newVT, dl, Lod->getChain(), Ptr,
4921 Lod->getPointerInfo().getWithOffset(bestOffset),
4922 Lod->getBaseAlign());
4923 SDValue And =
4924 DAG.getNode(ISD::AND, dl, newVT, NewLoad,
4925 DAG.getConstant(bestMask.trunc(bestWidth), dl, newVT));
4926 return DAG.getSetCC(dl, VT, And, DAG.getConstant(0LL, dl, newVT), Cond);
4927 }
4928 }
4929
4930 // If the LHS is a ZERO_EXTEND, perform the comparison on the input.
4931 if (N0.getOpcode() == ISD::ZERO_EXTEND) {
4932 unsigned InSize = N0.getOperand(0).getValueSizeInBits();
4933
4934 // If the comparison constant has bits in the upper part, the
4935 // zero-extended value could never match.
4936 if (C1.intersects(APInt::getHighBitsSet(C1.getBitWidth(),
4937 C1.getBitWidth() - InSize))) {
4938 switch (Cond) {
4939 case ISD::SETUGT:
4940 case ISD::SETUGE:
4941 case ISD::SETEQ:
4942 return DAG.getConstant(0, dl, VT);
4943 case ISD::SETULT:
4944 case ISD::SETULE:
4945 case ISD::SETNE:
4946 return DAG.getConstant(1, dl, VT);
4947 case ISD::SETGT:
4948 case ISD::SETGE:
4949 // True if the sign bit of C1 is set.
4950 return DAG.getConstant(C1.isNegative(), dl, VT);
4951 case ISD::SETLT:
4952 case ISD::SETLE:
4953 // True if the sign bit of C1 isn't set.
4954 return DAG.getConstant(C1.isNonNegative(), dl, VT);
4955 default:
4956 break;
4957 }
4958 }
4959
4960 // Otherwise, we can perform the comparison with the low bits.
4961 switch (Cond) {
4962 case ISD::SETEQ:
4963 case ISD::SETNE:
4964 case ISD::SETUGT:
4965 case ISD::SETUGE:
4966 case ISD::SETULT:
4967 case ISD::SETULE: {
4968 EVT newVT = N0.getOperand(0).getValueType();
4969 // FIXME: Should use isNarrowingProfitable.
4970 if (DCI.isBeforeLegalizeOps() ||
4971 (isOperationLegal(ISD::SETCC, newVT) &&
4972 isCondCodeLegal(Cond, newVT.getSimpleVT()) &&
4973 isTypeDesirableForOp(ISD::SETCC, newVT))) {
4974 EVT NewSetCCVT = getSetCCResultType(Layout, *DAG.getContext(), newVT);
4975 SDValue NewConst = DAG.getConstant(C1.trunc(InSize), dl, newVT);
4976
4977 SDValue NewSetCC = DAG.getSetCC(dl, NewSetCCVT, N0.getOperand(0),
4978 NewConst, Cond);
4979 return DAG.getBoolExtOrTrunc(NewSetCC, dl, VT, N0.getValueType());
4980 }
4981 break;
4982 }
4983 default:
4984 break; // todo, be more careful with signed comparisons
4985 }
4986 } else if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG &&
4987 (Cond == ISD::SETEQ || Cond == ISD::SETNE) &&
4988 !isSExtCheaperThanZExt(cast<VTSDNode>(N0.getOperand(1))->getVT(),
4989 OpVT)) {
4990 EVT ExtSrcTy = cast<VTSDNode>(N0.getOperand(1))->getVT();
4991 unsigned ExtSrcTyBits = ExtSrcTy.getSizeInBits();
4992 EVT ExtDstTy = N0.getValueType();
4993 unsigned ExtDstTyBits = ExtDstTy.getSizeInBits();
4994
4995 // If the constant doesn't fit into the number of bits for the source of
4996 // the sign extension, it is impossible for both sides to be equal.
4997 if (C1.getSignificantBits() > ExtSrcTyBits)
4998 return DAG.getBoolConstant(Cond == ISD::SETNE, dl, VT, OpVT);
4999
5000 assert(ExtDstTy == N0.getOperand(0).getValueType() &&
5001 ExtDstTy != ExtSrcTy && "Unexpected types!");
5002 APInt Imm = APInt::getLowBitsSet(ExtDstTyBits, ExtSrcTyBits);
5003 SDValue ZextOp = DAG.getNode(ISD::AND, dl, ExtDstTy, N0.getOperand(0),
5004 DAG.getConstant(Imm, dl, ExtDstTy));
5005 if (!DCI.isCalledByLegalizer())
5006 DCI.AddToWorklist(ZextOp.getNode());
5007 // Otherwise, make this a use of a zext.
5008 return DAG.getSetCC(dl, VT, ZextOp,
5009 DAG.getConstant(C1 & Imm, dl, ExtDstTy), Cond);
5010 } else if ((N1C->isZero() || N1C->isOne()) &&
5011 (Cond == ISD::SETEQ || Cond == ISD::SETNE)) {
5012 // SETCC (X), [0|1], [EQ|NE] -> X if X is known 0/1. i1 types are
5013 // excluded as they are handled below whilst checking for foldBooleans.
5014 if ((N0.getOpcode() == ISD::SETCC || VT.getScalarType() != MVT::i1) &&
5015 isTypeLegal(VT) && VT.bitsLE(N0.getValueType()) &&
5016 (N0.getValueType() == MVT::i1 ||
5017 getBooleanContents(N0.getValueType()) == ZeroOrOneBooleanContent) &&
5018 DAG.MaskedValueIsZero(
5019 N0, APInt::getBitsSetFrom(N0.getValueSizeInBits(), 1))) {
5020 bool TrueWhenTrue = (Cond == ISD::SETEQ) ^ (!N1C->isOne());
5021 if (TrueWhenTrue)
5022 return DAG.getNode(ISD::TRUNCATE, dl, VT, N0);
5023 // Invert the condition.
5024 if (N0.getOpcode() == ISD::SETCC) {
5025 ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
5026 CC = ISD::getSetCCInverse(CC, N0.getOperand(0).getValueType());
5027 if (DCI.isBeforeLegalizeOps() ||
5028 isCondCodeLegal(CC, N0.getOperand(0).getSimpleValueType()))
5029 return DAG.getSetCC(dl, VT, N0.getOperand(0), N0.getOperand(1), CC);
5030 }
5031 }
5032
5033 if ((N0.getOpcode() == ISD::XOR ||
5034 (N0.getOpcode() == ISD::AND &&
5035 N0.getOperand(0).getOpcode() == ISD::XOR &&
5036 N0.getOperand(1) == N0.getOperand(0).getOperand(1))) &&
5037 isOneConstant(N0.getOperand(1))) {
5038 // If this is (X^1) == 0/1, swap the RHS and eliminate the xor. We
5039 // can only do this if the top bits are known zero.
5040 unsigned BitWidth = N0.getValueSizeInBits();
5041 if (DAG.MaskedValueIsZero(N0,
5042 APInt::getHighBitsSet(BitWidth,
5043 BitWidth-1))) {
5044 // Okay, get the un-inverted input value.
5045 SDValue Val;
5046 if (N0.getOpcode() == ISD::XOR) {
5047 Val = N0.getOperand(0);
5048 } else {
5049 assert(N0.getOpcode() == ISD::AND &&
5050 N0.getOperand(0).getOpcode() == ISD::XOR);
5051 // ((X^1)&1)^1 -> X & 1
5052 Val = DAG.getNode(ISD::AND, dl, N0.getValueType(),
5053 N0.getOperand(0).getOperand(0),
5054 N0.getOperand(1));
5055 }
5056
5057 return DAG.getSetCC(dl, VT, Val, N1,
5058 Cond == ISD::SETEQ ? ISD::SETNE : ISD::SETEQ);
5059 }
5060 } else if (N1C->isOne()) {
5061 SDValue Op0 = N0;
5062 if (Op0.getOpcode() == ISD::TRUNCATE)
5063 Op0 = Op0.getOperand(0);
5064
5065 if ((Op0.getOpcode() == ISD::XOR) &&
5066 Op0.getOperand(0).getOpcode() == ISD::SETCC &&
5067 Op0.getOperand(1).getOpcode() == ISD::SETCC) {
5068 SDValue XorLHS = Op0.getOperand(0);
5069 SDValue XorRHS = Op0.getOperand(1);
5070 // Ensure that the input setccs return an i1 type or 0/1 value.
5071 if (Op0.getValueType() == MVT::i1 ||
5072 (getBooleanContents(XorLHS.getOperand(0).getValueType()) ==
5073 ZeroOrOneBooleanContent &&
5074 getBooleanContents(XorRHS.getOperand(0).getValueType()) ==
5075 ZeroOrOneBooleanContent)) {
5076 // (xor (setcc), (setcc)) == / != 1 -> (setcc) != / == (setcc)
5077 Cond = (Cond == ISD::SETEQ) ? ISD::SETNE : ISD::SETEQ;
5078 return DAG.getSetCC(dl, VT, XorLHS, XorRHS, Cond);
5079 }
5080 }
5081 if (Op0.getOpcode() == ISD::AND && isOneConstant(Op0.getOperand(1))) {
5082 // If this is (X&1) == / != 1, normalize it to (X&1) != / == 0.
5083 if (Op0.getValueType().bitsGT(VT))
5084 Op0 = DAG.getNode(ISD::AND, dl, VT,
5085 DAG.getNode(ISD::TRUNCATE, dl, VT, Op0.getOperand(0)),
5086 DAG.getConstant(1, dl, VT));
5087 else if (Op0.getValueType().bitsLT(VT))
5088 Op0 = DAG.getNode(ISD::AND, dl, VT,
5089 DAG.getNode(ISD::ANY_EXTEND, dl, VT, Op0.getOperand(0)),
5090 DAG.getConstant(1, dl, VT));
5091
5092 return DAG.getSetCC(dl, VT, Op0,
5093 DAG.getConstant(0, dl, Op0.getValueType()),
5094 Cond == ISD::SETEQ ? ISD::SETNE : ISD::SETEQ);
5095 }
5096 if (Op0.getOpcode() == ISD::AssertZext &&
5097 cast<VTSDNode>(Op0.getOperand(1))->getVT() == MVT::i1)
5098 return DAG.getSetCC(dl, VT, Op0,
5099 DAG.getConstant(0, dl, Op0.getValueType()),
5100 Cond == ISD::SETEQ ? ISD::SETNE : ISD::SETEQ);
5101 }
5102 }
5103
5104 // Given:
5105 // icmp eq/ne (urem %x, %y), 0
5106 // Iff %x has 0 or 1 bits set, and %y has at least 2 bits set, omit 'urem':
5107 // icmp eq/ne %x, 0
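    // This is valid because a value with at most one bit set is zero or a
    // power of two, and a power of two is never divisible by a value with two
    // or more bits set, so the remainder is zero only when %x is zero.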
5108 if (N0.getOpcode() == ISD::UREM && N1C->isZero() &&
5109 (Cond == ISD::SETEQ || Cond == ISD::SETNE)) {
5110 KnownBits XKnown = DAG.computeKnownBits(N0.getOperand(0));
5111 KnownBits YKnown = DAG.computeKnownBits(N0.getOperand(1));
5112 if (XKnown.countMaxPopulation() == 1 && YKnown.countMinPopulation() >= 2)
5113 return DAG.getSetCC(dl, VT, N0.getOperand(0), N1, Cond);
5114 }
5115
5116 // Fold set_cc seteq (ashr X, BW-1), -1 -> set_cc setlt X, 0
5117 // and set_cc setne (ashr X, BW-1), -1 -> set_cc setge X, 0
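    // (ashr X, BW-1) broadcasts the sign bit, so it is all-ones exactly when
    // X is negative.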
5118 if ((Cond == ISD::SETEQ || Cond == ISD::SETNE) &&
5119 N0.getOpcode() == ISD::SRA && isa<ConstantSDNode>(N0.getOperand(1)) &&
5120 N0.getConstantOperandAPInt(1) == OpVT.getScalarSizeInBits() - 1 &&
5121 N1C->isAllOnes()) {
5122 return DAG.getSetCC(dl, VT, N0.getOperand(0),
5123 DAG.getConstant(0, dl, OpVT),
5124 Cond == ISD::SETEQ ? ISD::SETLT : ISD::SETGE);
5125 }
5126
5127 if (SDValue V =
5128 optimizeSetCCOfSignedTruncationCheck(VT, N0, N1, Cond, DCI, dl))
5129 return V;
5130 }
5131
5132 // These simplifications apply to splat vectors as well.
5133 // TODO: Handle more splat vector cases.
5134 if (auto *N1C = isConstOrConstSplat(N1)) {
5135 const APInt &C1 = N1C->getAPIntValue();
5136
5137 APInt MinVal, MaxVal;
5138 unsigned OperandBitSize = N1C->getValueType(0).getScalarSizeInBits();
5139 if (ISD::isSignedIntSetCC(Cond)) {
5140 MinVal = APInt::getSignedMinValue(OperandBitSize);
5141 MaxVal = APInt::getSignedMaxValue(OperandBitSize);
5142 } else {
5143 MinVal = APInt::getMinValue(OperandBitSize);
5144 MaxVal = APInt::getMaxValue(OperandBitSize);
5145 }
5146
5147 // Canonicalize GE/LE comparisons to use GT/LT comparisons.
5148 if (Cond == ISD::SETGE || Cond == ISD::SETUGE) {
5149 // X >= MIN --> true
5150 if (C1 == MinVal)
5151 return DAG.getBoolConstant(true, dl, VT, OpVT);
5152
5153 if (!VT.isVector()) { // TODO: Support this for vectors.
5154 // X >= C0 --> X > (C0 - 1)
5155 APInt C = C1 - 1;
5156 ISD::CondCode NewCC = (Cond == ISD::SETGE) ? ISD::SETGT : ISD::SETUGT;
5157 if ((DCI.isBeforeLegalizeOps() ||
5158 isCondCodeLegal(NewCC, OpVT.getSimpleVT())) &&
5159 (!N1C->isOpaque() || (C.getBitWidth() <= 64 &&
5160 isLegalICmpImmediate(C.getSExtValue())))) {
5161 return DAG.getSetCC(dl, VT, N0,
5162 DAG.getConstant(C, dl, N1.getValueType()),
5163 NewCC);
5164 }
5165 }
5166 }
5167
5168 if (Cond == ISD::SETLE || Cond == ISD::SETULE) {
5169 // X <= MAX --> true
5170 if (C1 == MaxVal)
5171 return DAG.getBoolConstant(true, dl, VT, OpVT);
5172
5173 // X <= C0 --> X < (C0 + 1)
5174 if (!VT.isVector()) { // TODO: Support this for vectors.
5175 APInt C = C1 + 1;
5176 ISD::CondCode NewCC = (Cond == ISD::SETLE) ? ISD::SETLT : ISD::SETULT;
5177 if ((DCI.isBeforeLegalizeOps() ||
5178 isCondCodeLegal(NewCC, OpVT.getSimpleVT())) &&
5179 (!N1C->isOpaque() || (C.getBitWidth() <= 64 &&
5180 isLegalICmpImmediate(C.getSExtValue())))) {
5181 return DAG.getSetCC(dl, VT, N0,
5182 DAG.getConstant(C, dl, N1.getValueType()),
5183 NewCC);
5184 }
5185 }
5186 }
5187
5188 if (Cond == ISD::SETLT || Cond == ISD::SETULT) {
5189 if (C1 == MinVal)
5190 return DAG.getBoolConstant(false, dl, VT, OpVT); // X < MIN --> false
5191
5192 // TODO: Support this for vectors after legalize ops.
5193 if (!VT.isVector() || DCI.isBeforeLegalizeOps()) {
5194 // Canonicalize setlt X, Max --> setne X, Max
5195 if (C1 == MaxVal)
5196 return DAG.getSetCC(dl, VT, N0, N1, ISD::SETNE);
5197
5198 // If we have setult X, 1, turn it into seteq X, 0
5199 if (C1 == MinVal+1)
5200 return DAG.getSetCC(dl, VT, N0,
5201 DAG.getConstant(MinVal, dl, N0.getValueType()),
5202 ISD::SETEQ);
5203 }
5204 }
5205
5206 if (Cond == ISD::SETGT || Cond == ISD::SETUGT) {
5207 if (C1 == MaxVal)
5208 return DAG.getBoolConstant(false, dl, VT, OpVT); // X > MAX --> false
5209
5210 // TODO: Support this for vectors after legalize ops.
5211 if (!VT.isVector() || DCI.isBeforeLegalizeOps()) {
5212 // Canonicalize setgt X, Min --> setne X, Min
5213 if (C1 == MinVal)
5214 return DAG.getSetCC(dl, VT, N0, N1, ISD::SETNE);
5215
5216 // If we have setugt X, Max-1, turn it into seteq X, Max
5217 if (C1 == MaxVal-1)
5218 return DAG.getSetCC(dl, VT, N0,
5219 DAG.getConstant(MaxVal, dl, N0.getValueType()),
5220 ISD::SETEQ);
5221 }
5222 }
5223
5224 if (Cond == ISD::SETEQ || Cond == ISD::SETNE) {
5225 // (X & (C l>>/<< Y)) ==/!= 0 --> ((X <</l>> Y) & C) ==/!= 0
5226 if (C1.isZero())
5227 if (SDValue CC = optimizeSetCCByHoistingAndByConstFromLogicalShift(
5228 VT, N0, N1, Cond, DCI, dl))
5229 return CC;
5230
5231 // For all/any comparisons, replace or(x,shl(y,bw/2)) with and/or(x,y).
5232 // For example, when high 32-bits of i64 X are known clear:
5233 // all bits clear: (X | (Y<<32)) == 0 --> (X | Y) == 0
5234 // all bits set: (X | (Y<<32)) == -1 --> (X & Y) == -1
5235 bool CmpZero = N1C->isZero();
5236 bool CmpNegOne = N1C->isAllOnes();
5237 if ((CmpZero || CmpNegOne) && N0.hasOneUse()) {
5238 // Match or(lo,shl(hi,bw/2)) pattern.
5239 auto IsConcat = [&](SDValue V, SDValue &Lo, SDValue &Hi) {
5240 unsigned EltBits = V.getScalarValueSizeInBits();
5241 if (V.getOpcode() != ISD::OR || (EltBits % 2) != 0)
5242 return false;
5243 SDValue LHS = V.getOperand(0);
5244 SDValue RHS = V.getOperand(1);
5245 APInt HiBits = APInt::getHighBitsSet(EltBits, EltBits / 2);
5246           // Unshifted element must have zero upper bits.
5247 if (RHS.getOpcode() == ISD::SHL &&
5248 isa<ConstantSDNode>(RHS.getOperand(1)) &&
5249 RHS.getConstantOperandAPInt(1) == (EltBits / 2) &&
5250 DAG.MaskedValueIsZero(LHS, HiBits)) {
5251 Lo = LHS;
5252 Hi = RHS.getOperand(0);
5253 return true;
5254 }
5255 if (LHS.getOpcode() == ISD::SHL &&
5256 isa<ConstantSDNode>(LHS.getOperand(1)) &&
5257 LHS.getConstantOperandAPInt(1) == (EltBits / 2) &&
5258 DAG.MaskedValueIsZero(RHS, HiBits)) {
5259 Lo = RHS;
5260 Hi = LHS.getOperand(0);
5261 return true;
5262 }
5263 return false;
5264 };
5265
5266 auto MergeConcat = [&](SDValue Lo, SDValue Hi) {
5267 unsigned EltBits = N0.getScalarValueSizeInBits();
5268 unsigned HalfBits = EltBits / 2;
5269 APInt HiBits = APInt::getHighBitsSet(EltBits, HalfBits);
5270 SDValue LoBits = DAG.getConstant(~HiBits, dl, OpVT);
5271 SDValue HiMask = DAG.getNode(ISD::AND, dl, OpVT, Hi, LoBits);
5272 SDValue NewN0 =
5273 DAG.getNode(CmpZero ? ISD::OR : ISD::AND, dl, OpVT, Lo, HiMask);
5274 SDValue NewN1 = CmpZero ? DAG.getConstant(0, dl, OpVT) : LoBits;
5275 return DAG.getSetCC(dl, VT, NewN0, NewN1, Cond);
5276 };
5277
5278 SDValue Lo, Hi;
5279 if (IsConcat(N0, Lo, Hi))
5280 return MergeConcat(Lo, Hi);
5281
5282 if (N0.getOpcode() == ISD::AND || N0.getOpcode() == ISD::OR) {
5283 SDValue Lo0, Lo1, Hi0, Hi1;
5284 if (IsConcat(N0.getOperand(0), Lo0, Hi0) &&
5285 IsConcat(N0.getOperand(1), Lo1, Hi1)) {
5286 return MergeConcat(DAG.getNode(N0.getOpcode(), dl, OpVT, Lo0, Lo1),
5287 DAG.getNode(N0.getOpcode(), dl, OpVT, Hi0, Hi1));
5288 }
5289 }
5290 }
5291 }
5292
5293 // If we have "setcc X, C0", check to see if we can shrink the immediate
5294 // by changing cc.
5295 // TODO: Support this for vectors after legalize ops.
5296 if (!VT.isVector() || DCI.isBeforeLegalizeOps()) {
5297 // SETUGT X, SINTMAX -> SETLT X, 0
5298 // SETUGE X, SINTMIN -> SETLT X, 0
5299 if ((Cond == ISD::SETUGT && C1.isMaxSignedValue()) ||
5300 (Cond == ISD::SETUGE && C1.isMinSignedValue()))
5301 return DAG.getSetCC(dl, VT, N0,
5302 DAG.getConstant(0, dl, N1.getValueType()),
5303 ISD::SETLT);
5304
5305 // SETULT X, SINTMIN -> SETGT X, -1
5306 // SETULE X, SINTMAX -> SETGT X, -1
5307 if ((Cond == ISD::SETULT && C1.isMinSignedValue()) ||
5308 (Cond == ISD::SETULE && C1.isMaxSignedValue()))
5309 return DAG.getSetCC(dl, VT, N0,
5310 DAG.getAllOnesConstant(dl, N1.getValueType()),
5311 ISD::SETGT);
5312 }
5313 }
5314
5315 // Back to non-vector simplifications.
5316 // TODO: Can we do these for vector splats?
5317 if (auto *N1C = dyn_cast<ConstantSDNode>(N1.getNode())) {
5318 const APInt &C1 = N1C->getAPIntValue();
5319 EVT ShValTy = N0.getValueType();
5320
5321 // Fold bit comparisons when we can. This will result in an
5322 // incorrect value when boolean false is negative one, unless
5323     // the bit size is 1, in which case the false value is the same
5324 // in practice regardless of the representation.
5325 if ((VT.getSizeInBits() == 1 ||
5326 getBooleanContents(N0.getValueType()) == ZeroOrOneBooleanContent) &&
5327 (Cond == ISD::SETEQ || Cond == ISD::SETNE) &&
5328 (VT == ShValTy || (isTypeLegal(VT) && VT.bitsLE(ShValTy))) &&
5329 N0.getOpcode() == ISD::AND) {
5330 if (auto *AndRHS = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
5331 if (Cond == ISD::SETNE && C1 == 0) {// (X & 8) != 0 --> (X & 8) >> 3
5332 // Perform the xform if the AND RHS is a single bit.
5333 unsigned ShCt = AndRHS->getAPIntValue().logBase2();
5334 if (AndRHS->getAPIntValue().isPowerOf2() &&
5335 !shouldAvoidTransformToShift(ShValTy, ShCt)) {
5336 return DAG.getNode(
5337 ISD::TRUNCATE, dl, VT,
5338 DAG.getNode(ISD::SRL, dl, ShValTy, N0,
5339 DAG.getShiftAmountConstant(ShCt, ShValTy, dl)));
5340 }
5341 } else if (Cond == ISD::SETEQ && C1 == AndRHS->getAPIntValue()) {
5342 // (X & 8) == 8 --> (X & 8) >> 3
5343 // Perform the xform if C1 is a single bit.
5344 unsigned ShCt = C1.logBase2();
5345 if (C1.isPowerOf2() && !shouldAvoidTransformToShift(ShValTy, ShCt)) {
5346 return DAG.getNode(
5347 ISD::TRUNCATE, dl, VT,
5348 DAG.getNode(ISD::SRL, dl, ShValTy, N0,
5349 DAG.getShiftAmountConstant(ShCt, ShValTy, dl)));
5350 }
5351 }
5352 }
5353 }
5354
5355 if (C1.getSignificantBits() <= 64 &&
5356 !isLegalICmpImmediate(C1.getSExtValue())) {
5357 // (X & -256) == 256 -> (X >> 8) == 1
5358 if ((Cond == ISD::SETEQ || Cond == ISD::SETNE) &&
5359 N0.getOpcode() == ISD::AND && N0.hasOneUse()) {
5360 if (auto *AndRHS = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
5361 const APInt &AndRHSC = AndRHS->getAPIntValue();
5362 if (AndRHSC.isNegatedPowerOf2() && C1.isSubsetOf(AndRHSC)) {
5363 unsigned ShiftBits = AndRHSC.countr_zero();
5364 if (!shouldAvoidTransformToShift(ShValTy, ShiftBits)) {
5365 SDValue Shift = DAG.getNode(
5366 ISD::SRL, dl, ShValTy, N0.getOperand(0),
5367 DAG.getShiftAmountConstant(ShiftBits, ShValTy, dl));
5368 SDValue CmpRHS = DAG.getConstant(C1.lshr(ShiftBits), dl, ShValTy);
5369 return DAG.getSetCC(dl, VT, Shift, CmpRHS, Cond);
5370 }
5371 }
5372 }
5373 } else if (Cond == ISD::SETULT || Cond == ISD::SETUGE ||
5374 Cond == ISD::SETULE || Cond == ISD::SETUGT) {
5375 bool AdjOne = (Cond == ISD::SETULE || Cond == ISD::SETUGT);
5376 // X < 0x100000000 -> (X >> 32) < 1
5377 // X >= 0x100000000 -> (X >> 32) >= 1
5378 // X <= 0x0ffffffff -> (X >> 32) < 1
5379 // X > 0x0ffffffff -> (X >> 32) >= 1
5380 unsigned ShiftBits;
5381 APInt NewC = C1;
5382 ISD::CondCode NewCond = Cond;
5383 if (AdjOne) {
5384 ShiftBits = C1.countr_one();
5385 NewC = NewC + 1;
5386 NewCond = (Cond == ISD::SETULE) ? ISD::SETULT : ISD::SETUGE;
5387 } else {
5388 ShiftBits = C1.countr_zero();
5389 }
5390 NewC.lshrInPlace(ShiftBits);
5391 if (ShiftBits && NewC.getSignificantBits() <= 64 &&
5392 isLegalICmpImmediate(NewC.getSExtValue()) &&
5393 !shouldAvoidTransformToShift(ShValTy, ShiftBits)) {
5394 SDValue Shift =
5395 DAG.getNode(ISD::SRL, dl, ShValTy, N0,
5396 DAG.getShiftAmountConstant(ShiftBits, ShValTy, dl));
5397 SDValue CmpRHS = DAG.getConstant(NewC, dl, ShValTy);
5398 return DAG.getSetCC(dl, VT, Shift, CmpRHS, NewCond);
5399 }
5400 }
5401 }
5402 }
5403
5404 if (!isa<ConstantFPSDNode>(N0) && isa<ConstantFPSDNode>(N1)) {
5405 auto *CFP = cast<ConstantFPSDNode>(N1);
5406 assert(!CFP->getValueAPF().isNaN() && "Unexpected NaN value");
5407
5408 // Otherwise, we know the RHS is not a NaN. Simplify the node to drop the
5409 // constant if knowing that the operand is non-nan is enough. We prefer to
5410 // have SETO(x,x) instead of SETO(x, 0.0) because this avoids having to
5411 // materialize 0.0.
5412 if (Cond == ISD::SETO || Cond == ISD::SETUO)
5413 return DAG.getSetCC(dl, VT, N0, N0, Cond);
5414
5415 // setcc (fneg x), C -> setcc swap(pred) x, -C
5416 if (N0.getOpcode() == ISD::FNEG) {
5417 ISD::CondCode SwapCond = ISD::getSetCCSwappedOperands(Cond);
5418 if (DCI.isBeforeLegalizeOps() ||
5419 isCondCodeLegal(SwapCond, N0.getSimpleValueType())) {
5420 SDValue NegN1 = DAG.getNode(ISD::FNEG, dl, N0.getValueType(), N1);
5421 return DAG.getSetCC(dl, VT, N0.getOperand(0), NegN1, SwapCond);
5422 }
5423 }
5424
5425 // setueq/setoeq X, (fabs Inf) -> is_fpclass X, fcInf
5426 if (isOperationLegalOrCustom(ISD::IS_FPCLASS, N0.getValueType()) &&
5427 !isFPImmLegal(CFP->getValueAPF(), CFP->getValueType(0))) {
5428 bool IsFabs = N0.getOpcode() == ISD::FABS;
5429 SDValue Op = IsFabs ? N0.getOperand(0) : N0;
5430 if ((Cond == ISD::SETOEQ || Cond == ISD::SETUEQ) && CFP->isInfinity()) {
5431 FPClassTest Flag = CFP->isNegative() ? (IsFabs ? fcNone : fcNegInf)
5432 : (IsFabs ? fcInf : fcPosInf);
5433 if (Cond == ISD::SETUEQ)
5434 Flag |= fcNan;
5435 return DAG.getNode(ISD::IS_FPCLASS, dl, VT, Op,
5436 DAG.getTargetConstant(Flag, dl, MVT::i32));
5437 }
5438 }
5439
5440 // If the condition is not legal, see if we can find an equivalent one
5441 // which is legal.
5442 if (!isCondCodeLegal(Cond, N0.getSimpleValueType())) {
5443 // If the comparison was an awkward floating-point == or != and one of
5444 // the comparison operands is infinity or negative infinity, convert the
5445 // condition to a less-awkward <= or >=.
5446 if (CFP->getValueAPF().isInfinity()) {
5447 bool IsNegInf = CFP->getValueAPF().isNegative();
5448 ISD::CondCode NewCond = ISD::SETCC_INVALID;
5449 switch (Cond) {
5450 case ISD::SETOEQ: NewCond = IsNegInf ? ISD::SETOLE : ISD::SETOGE; break;
5451 case ISD::SETUEQ: NewCond = IsNegInf ? ISD::SETULE : ISD::SETUGE; break;
5452 case ISD::SETUNE: NewCond = IsNegInf ? ISD::SETUGT : ISD::SETULT; break;
5453 case ISD::SETONE: NewCond = IsNegInf ? ISD::SETOGT : ISD::SETOLT; break;
5454 default: break;
5455 }
5456 if (NewCond != ISD::SETCC_INVALID &&
5457 isCondCodeLegal(NewCond, N0.getSimpleValueType()))
5458 return DAG.getSetCC(dl, VT, N0, N1, NewCond);
5459 }
5460 }
5461 }
5462
5463 if (N0 == N1) {
5464 // The sext(setcc()) => setcc() optimization relies on the appropriate
5465 // constant being emitted.
5466 assert(!N0.getValueType().isInteger() &&
5467 "Integer types should be handled by FoldSetCC");
5468
5469 bool EqTrue = ISD::isTrueWhenEqual(Cond);
5470 unsigned UOF = ISD::getUnorderedFlavor(Cond);
5471 if (UOF == 2) // FP operators that are undefined on NaNs.
5472 return DAG.getBoolConstant(EqTrue, dl, VT, OpVT);
5473 if (UOF == unsigned(EqTrue))
5474 return DAG.getBoolConstant(EqTrue, dl, VT, OpVT);
5475 // Otherwise, we can't fold it. However, we can simplify it to SETUO/SETO
5476 // if it is not already.
5477 ISD::CondCode NewCond = UOF == 0 ? ISD::SETO : ISD::SETUO;
5478 if (NewCond != Cond &&
5479 (DCI.isBeforeLegalizeOps() ||
5480 isCondCodeLegal(NewCond, N0.getSimpleValueType())))
5481 return DAG.getSetCC(dl, VT, N0, N1, NewCond);
5482 }
5483
5484 // ~X > ~Y --> Y > X
5485 // ~X < ~Y --> Y < X
5486 // ~X < C --> X > ~C
5487 // ~X > C --> X < ~C
5488 if ((isSignedIntSetCC(Cond) || isUnsignedIntSetCC(Cond)) &&
5489 N0.getValueType().isInteger()) {
5490 if (isBitwiseNot(N0)) {
5491 if (isBitwiseNot(N1))
5492 return DAG.getSetCC(dl, VT, N1.getOperand(0), N0.getOperand(0), Cond);
5493
5494 if (DAG.isConstantIntBuildVectorOrConstantInt(N1) &&
5495 !DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(0))) {
5496 SDValue Not = DAG.getNOT(dl, N1, OpVT);
5497 return DAG.getSetCC(dl, VT, Not, N0.getOperand(0), Cond);
5498 }
5499 }
5500 }
5501
5502 if ((Cond == ISD::SETEQ || Cond == ISD::SETNE) &&
5503 N0.getValueType().isInteger()) {
5504 if (N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::SUB ||
5505 N0.getOpcode() == ISD::XOR) {
5506 // Simplify (X+Y) == (X+Z) --> Y == Z
5507 if (N0.getOpcode() == N1.getOpcode()) {
5508 if (N0.getOperand(0) == N1.getOperand(0))
5509 return DAG.getSetCC(dl, VT, N0.getOperand(1), N1.getOperand(1), Cond);
5510 if (N0.getOperand(1) == N1.getOperand(1))
5511 return DAG.getSetCC(dl, VT, N0.getOperand(0), N1.getOperand(0), Cond);
5512 if (isCommutativeBinOp(N0.getOpcode())) {
5513 // If X op Y == Y op X, try other combinations.
5514 if (N0.getOperand(0) == N1.getOperand(1))
5515 return DAG.getSetCC(dl, VT, N0.getOperand(1), N1.getOperand(0),
5516 Cond);
5517 if (N0.getOperand(1) == N1.getOperand(0))
5518 return DAG.getSetCC(dl, VT, N0.getOperand(0), N1.getOperand(1),
5519 Cond);
5520 }
5521 }
5522
5523 // If RHS is a legal immediate value for a compare instruction, we need
5524 // to be careful about increasing register pressure needlessly.
5525 bool LegalRHSImm = false;
5526
5527 if (auto *RHSC = dyn_cast<ConstantSDNode>(N1)) {
5528 if (auto *LHSR = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
5529 // Turn (X+C1) == C2 --> X == C2-C1
5530 if (N0.getOpcode() == ISD::ADD && N0.getNode()->hasOneUse())
5531 return DAG.getSetCC(
5532 dl, VT, N0.getOperand(0),
5533 DAG.getConstant(RHSC->getAPIntValue() - LHSR->getAPIntValue(),
5534 dl, N0.getValueType()),
5535 Cond);
5536
5537 // Turn (X^C1) == C2 --> X == C1^C2
5538 if (N0.getOpcode() == ISD::XOR && N0.getNode()->hasOneUse())
5539 return DAG.getSetCC(
5540 dl, VT, N0.getOperand(0),
5541 DAG.getConstant(LHSR->getAPIntValue() ^ RHSC->getAPIntValue(),
5542 dl, N0.getValueType()),
5543 Cond);
5544 }
5545
5546 // Turn (C1-X) == C2 --> X == C1-C2
5547 if (auto *SUBC = dyn_cast<ConstantSDNode>(N0.getOperand(0)))
5548 if (N0.getOpcode() == ISD::SUB && N0.getNode()->hasOneUse())
5549 return DAG.getSetCC(
5550 dl, VT, N0.getOperand(1),
5551 DAG.getConstant(SUBC->getAPIntValue() - RHSC->getAPIntValue(),
5552 dl, N0.getValueType()),
5553 Cond);
5554
5555 // Could RHSC fold directly into a compare?
5556 if (RHSC->getValueType(0).getSizeInBits() <= 64)
5557 LegalRHSImm = isLegalICmpImmediate(RHSC->getSExtValue());
5558 }
5559
5560 // (X+Y) == X --> Y == 0 and similar folds.
5561 // Don't do this if X is an immediate that can fold into a cmp
5562 // instruction and X+Y has other uses. It could be an induction variable
5563 // chain, and the transform would increase register pressure.
5564 if (!LegalRHSImm || N0.hasOneUse())
5565 if (SDValue V = foldSetCCWithBinOp(VT, N0, N1, Cond, dl, DCI))
5566 return V;
5567 }
5568
5569 if (N1.getOpcode() == ISD::ADD || N1.getOpcode() == ISD::SUB ||
5570 N1.getOpcode() == ISD::XOR)
5571 if (SDValue V = foldSetCCWithBinOp(VT, N1, N0, Cond, dl, DCI))
5572 return V;
5573
5574 if (SDValue V = foldSetCCWithAnd(VT, N0, N1, Cond, dl, DCI))
5575 return V;
5576
5577 if (SDValue V = foldSetCCWithOr(VT, N0, N1, Cond, dl, DCI))
5578 return V;
5579 }
5580
5581 // Fold remainder of division by a constant.
5582 if ((N0.getOpcode() == ISD::UREM || N0.getOpcode() == ISD::SREM) &&
5583 N0.hasOneUse() && (Cond == ISD::SETEQ || Cond == ISD::SETNE)) {
5584 // When division is cheap or optimizing for minimum size,
5585 // fall through to DIVREM creation by skipping this fold.
5586 if (!isIntDivCheap(VT, Attr) && !Attr.hasFnAttr(Attribute::MinSize)) {
5587 if (N0.getOpcode() == ISD::UREM) {
5588 if (SDValue Folded = buildUREMEqFold(VT, N0, N1, Cond, DCI, dl))
5589 return Folded;
5590 } else if (N0.getOpcode() == ISD::SREM) {
5591 if (SDValue Folded = buildSREMEqFold(VT, N0, N1, Cond, DCI, dl))
5592 return Folded;
5593 }
5594 }
5595 }
5596
5597 // Fold away ALL boolean setcc's.
5598 if (N0.getValueType().getScalarType() == MVT::i1 && foldBooleans) {
5599 SDValue Temp;
5600 switch (Cond) {
5601 default: llvm_unreachable("Unknown integer setcc!");
5602 case ISD::SETEQ: // X == Y -> ~(X^Y)
5603 Temp = DAG.getNode(ISD::XOR, dl, OpVT, N0, N1);
5604 N0 = DAG.getNOT(dl, Temp, OpVT);
5605 if (!DCI.isCalledByLegalizer())
5606 DCI.AddToWorklist(Temp.getNode());
5607 break;
5608 case ISD::SETNE: // X != Y --> (X^Y)
5609 N0 = DAG.getNode(ISD::XOR, dl, OpVT, N0, N1);
5610 break;
5611 case ISD::SETGT: // X >s Y --> X == 0 & Y == 1 --> ~X & Y
5612 case ISD::SETULT: // X <u Y --> X == 0 & Y == 1 --> ~X & Y
5613 Temp = DAG.getNOT(dl, N0, OpVT);
5614 N0 = DAG.getNode(ISD::AND, dl, OpVT, N1, Temp);
5615 if (!DCI.isCalledByLegalizer())
5616 DCI.AddToWorklist(Temp.getNode());
5617 break;
5618 case ISD::SETLT: // X <s Y --> X == 1 & Y == 0 --> ~Y & X
5619 case ISD::SETUGT: // X >u Y --> X == 1 & Y == 0 --> ~Y & X
5620 Temp = DAG.getNOT(dl, N1, OpVT);
5621 N0 = DAG.getNode(ISD::AND, dl, OpVT, N0, Temp);
5622 if (!DCI.isCalledByLegalizer())
5623 DCI.AddToWorklist(Temp.getNode());
5624 break;
5625 case ISD::SETULE: // X <=u Y --> X == 0 | Y == 1 --> ~X | Y
5626 case ISD::SETGE: // X >=s Y --> X == 0 | Y == 1 --> ~X | Y
5627 Temp = DAG.getNOT(dl, N0, OpVT);
5628 N0 = DAG.getNode(ISD::OR, dl, OpVT, N1, Temp);
5629 if (!DCI.isCalledByLegalizer())
5630 DCI.AddToWorklist(Temp.getNode());
5631 break;
5632 case ISD::SETUGE: // X >=u Y --> X == 1 | Y == 0 --> ~Y | X
5633 case ISD::SETLE: // X <=s Y --> X == 1 | Y == 0 --> ~Y | X
5634 Temp = DAG.getNOT(dl, N1, OpVT);
5635 N0 = DAG.getNode(ISD::OR, dl, OpVT, N0, Temp);
5636 break;
5637 }
5638 if (VT.getScalarType() != MVT::i1) {
5639 if (!DCI.isCalledByLegalizer())
5640 DCI.AddToWorklist(N0.getNode());
5641 // FIXME: If running after legalize, we probably can't do this.
5642 ISD::NodeType ExtendCode = getExtendForContent(getBooleanContents(OpVT));
5643 N0 = DAG.getNode(ExtendCode, dl, VT, N0);
5644 }
5645 return N0;
5646 }
5647
5648 // Could not fold it.
5649 return SDValue();
5650 }
5651
5652 /// Returns true (and the GlobalValue and the offset) if the node is a
5653 /// GlobalAddress + offset.
5654 bool TargetLowering::isGAPlusOffset(SDNode *WN, const GlobalValue *&GA,
5655 int64_t &Offset) const {
5656
5657 SDNode *N = unwrapAddress(SDValue(WN, 0)).getNode();
5658
5659 if (auto *GASD = dyn_cast<GlobalAddressSDNode>(N)) {
5660 GA = GASD->getGlobal();
5661 Offset += GASD->getOffset();
5662 return true;
5663 }
5664
5665 if (N->getOpcode() == ISD::ADD) {
5666 SDValue N1 = N->getOperand(0);
5667 SDValue N2 = N->getOperand(1);
5668 if (isGAPlusOffset(N1.getNode(), GA, Offset)) {
5669 if (auto *V = dyn_cast<ConstantSDNode>(N2)) {
5670 Offset += V->getSExtValue();
5671 return true;
5672 }
5673 } else if (isGAPlusOffset(N2.getNode(), GA, Offset)) {
5674 if (auto *V = dyn_cast<ConstantSDNode>(N1)) {
5675 Offset += V->getSExtValue();
5676 return true;
5677 }
5678 }
5679 }
5680
5681 return false;
5682 }
5683
5684 SDValue TargetLowering::PerformDAGCombine(SDNode *N,
5685 DAGCombinerInfo &DCI) const {
5686 // Default implementation: no optimization.
5687 return SDValue();
5688 }
5689
5690 //===----------------------------------------------------------------------===//
5691 // Inline Assembler Implementation Methods
5692 //===----------------------------------------------------------------------===//
5693
5694 TargetLowering::ConstraintType
5695 TargetLowering::getConstraintType(StringRef Constraint) const {
5696 unsigned S = Constraint.size();
5697
5698 if (S == 1) {
5699 switch (Constraint[0]) {
5700 default: break;
5701 case 'r':
5702 return C_RegisterClass;
5703 case 'm': // memory
5704 case 'o': // offsetable
5705 case 'V': // not offsetable
5706 return C_Memory;
5707 case 'p': // Address.
5708 return C_Address;
5709 case 'n': // Simple Integer
5710 case 'E': // Floating Point Constant
5711 case 'F': // Floating Point Constant
5712 return C_Immediate;
5713 case 'i': // Simple Integer or Relocatable Constant
5714 case 's': // Relocatable Constant
5715 case 'X': // Allow ANY value.
5716 case 'I': // Target registers.
5717 case 'J':
5718 case 'K':
5719 case 'L':
5720 case 'M':
5721 case 'N':
5722 case 'O':
5723 case 'P':
5724 case '<':
5725 case '>':
5726 return C_Other;
5727 }
5728 }
5729
5730 if (S > 1 && Constraint[0] == '{' && Constraint[S - 1] == '}') {
5731 if (S == 8 && Constraint.substr(1, 6) == "memory") // "{memory}"
5732 return C_Memory;
5733 return C_Register;
5734 }
5735 return C_Unknown;
5736 }
5737
5738 /// Try to replace an X constraint, which matches anything, with another that
5739 /// has more specific requirements based on the type of the corresponding
5740 /// operand.
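/// For instance, an "X" constraint whose operand is an i32 would typically be
/// rewritten to "r" here, and a floating-point operand to "f" (targets may
/// override this).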
5741 const char *TargetLowering::LowerXConstraint(EVT ConstraintVT) const {
5742 if (ConstraintVT.isInteger())
5743 return "r";
5744 if (ConstraintVT.isFloatingPoint())
5745 return "f"; // works for many targets
5746 return nullptr;
5747 }
5748
5749 SDValue TargetLowering::LowerAsmOutputForConstraint(
5750 SDValue &Chain, SDValue &Glue, const SDLoc &DL,
5751 const AsmOperandInfo &OpInfo, SelectionDAG &DAG) const {
5752 return SDValue();
5753 }
5754
5755 /// Lower the specified operand into the Ops vector.
5756 /// If it is invalid, don't add anything to Ops.
5757 void TargetLowering::LowerAsmOperandForConstraint(SDValue Op,
5758 StringRef Constraint,
5759 std::vector<SDValue> &Ops,
5760 SelectionDAG &DAG) const {
5761
5762 if (Constraint.size() > 1)
5763 return;
5764
5765 char ConstraintLetter = Constraint[0];
5766 switch (ConstraintLetter) {
5767 default: break;
5768 case 'X': // Allows any operand
5769 case 'i': // Simple Integer or Relocatable Constant
5770 case 'n': // Simple Integer
5771 case 's': { // Relocatable Constant
5772
5773 ConstantSDNode *C;
5774 uint64_t Offset = 0;
5775
5776 // Match (GA) or (C) or (GA+C) or (GA-C) or ((GA+C)+C) or (((GA+C)+C)+C),
5777 // etc., since getelementptr is variadic. We can't use
5778 // SelectionDAG::FoldSymbolOffset because it expects the GA to be accessible
5779 // while in this case the GA may be furthest from the root node which is
5780 // likely an ISD::ADD.
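// Illustrative trace (hypothetical global @g): for constraint 'i' and
// Op = (add (add @g, 4), 8), the loop below peels both adds, accumulating
// Offset = 12, and finally emits a TargetGlobalAddress of @g with offset 12.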
5781 while (true) {
5782 if ((C = dyn_cast<ConstantSDNode>(Op)) && ConstraintLetter != 's') {
5783 // gcc prints these as sign extended. Sign extend value to 64 bits
5784 // now; without this it would get ZExt'd later in
5785 // ScheduleDAGSDNodes::EmitNode, which is very generic.
5786 bool IsBool = C->getConstantIntValue()->getBitWidth() == 1;
5787 BooleanContent BCont = getBooleanContents(MVT::i64);
5788 ISD::NodeType ExtOpc =
5789 IsBool ? getExtendForContent(BCont) : ISD::SIGN_EXTEND;
5790 int64_t ExtVal =
5791 ExtOpc == ISD::ZERO_EXTEND ? C->getZExtValue() : C->getSExtValue();
5792 Ops.push_back(
5793 DAG.getTargetConstant(Offset + ExtVal, SDLoc(C), MVT::i64));
5794 return;
5795 }
5796 if (ConstraintLetter != 'n') {
5797 if (const auto *GA = dyn_cast<GlobalAddressSDNode>(Op)) {
5798 Ops.push_back(DAG.getTargetGlobalAddress(GA->getGlobal(), SDLoc(Op),
5799 GA->getValueType(0),
5800 Offset + GA->getOffset()));
5801 return;
5802 }
5803 if (const auto *BA = dyn_cast<BlockAddressSDNode>(Op)) {
5804 Ops.push_back(DAG.getTargetBlockAddress(
5805 BA->getBlockAddress(), BA->getValueType(0),
5806 Offset + BA->getOffset(), BA->getTargetFlags()));
5807 return;
5808 }
5809 if (isa<BasicBlockSDNode>(Op)) {
5810 Ops.push_back(Op);
5811 return;
5812 }
5813 }
5814 const unsigned OpCode = Op.getOpcode();
5815 if (OpCode == ISD::ADD || OpCode == ISD::SUB) {
5816 if ((C = dyn_cast<ConstantSDNode>(Op.getOperand(0))))
5817 Op = Op.getOperand(1);
5818 // Subtraction is not commutative.
5819 else if (OpCode == ISD::ADD &&
5820 (C = dyn_cast<ConstantSDNode>(Op.getOperand(1))))
5821 Op = Op.getOperand(0);
5822 else
5823 return;
5824 Offset += (OpCode == ISD::ADD ? 1 : -1) * C->getSExtValue();
5825 continue;
5826 }
5827 return;
5828 }
5829 break;
5830 }
5831 }
5832 }
5833
5834 void TargetLowering::CollectTargetIntrinsicOperands(
5835 const CallInst &I, SmallVectorImpl<SDValue> &Ops, SelectionDAG &DAG) const {
5836 }
5837
5838 std::pair<unsigned, const TargetRegisterClass *>
5839 TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *RI,
5840 StringRef Constraint,
5841 MVT VT) const {
5842 if (!Constraint.starts_with("{"))
5843 return std::make_pair(0u, static_cast<TargetRegisterClass *>(nullptr));
5844 assert(*(Constraint.end() - 1) == '}' && "Not a brace enclosed constraint?");
5845
5846 // Remove the braces from around the name.
5847 StringRef RegName(Constraint.data() + 1, Constraint.size() - 2);
5848
5849 std::pair<unsigned, const TargetRegisterClass *> R =
5850 std::make_pair(0u, static_cast<const TargetRegisterClass *>(nullptr));
5851
5852 // Figure out which register class contains this reg.
5853 for (const TargetRegisterClass *RC : RI->regclasses()) {
5854 // If none of the value types for this register class are valid, we
5855 // can't use it. For example, 64-bit reg classes on 32-bit targets.
5856 if (!isLegalRC(*RI, *RC))
5857 continue;
5858
5859 for (const MCPhysReg &PR : *RC) {
5860 if (RegName.equals_insensitive(RI->getRegAsmName(PR))) {
5861 std::pair<unsigned, const TargetRegisterClass *> S =
5862 std::make_pair(PR, RC);
5863
5864 // If this register class has the requested value type, return it,
5865 // otherwise keep searching; if no class is found that explicitly
5866 // supports the requested type, return the first matching class found.
5867 if (RI->isTypeLegalForClass(*RC, VT))
5868 return S;
5869 if (!R.second)
5870 R = S;
5871 }
5872 }
5873 }
5874
5875 return R;
5876 }
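// Illustrative, target-dependent example (names are assumptions): on an x86
// target, a "{eax}" constraint with VT = MVT::i32 would resolve to the EAX
// register in a 32-bit general-purpose register class, while an unknown
// register name falls through and returns {0, nullptr}.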
5877
5878 //===----------------------------------------------------------------------===//
5879 // Constraint Selection.
5880
5881 /// Return true if this is an input operand that is a matching constraint like
5882 /// "4".
5883 bool TargetLowering::AsmOperandInfo::isMatchingInputConstraint() const {
5884 assert(!ConstraintCode.empty() && "No known constraint!");
5885 return isdigit(static_cast<unsigned char>(ConstraintCode[0]));
5886 }
5887
5888 /// If this is an input matching constraint, this method returns the output
5889 /// operand it matches.
5890 unsigned TargetLowering::AsmOperandInfo::getMatchedOperand() const {
5891 assert(!ConstraintCode.empty() && "No known constraint!");
5892 return atoi(ConstraintCode.c_str());
5893 }
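// Sketch of how these two helpers combine (illustrative): for an input with
// ConstraintCode "0", isMatchingInputConstraint() is true and
// getMatchedOperand() returns 0, i.e. the input is tied to output operand 0.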
5894
5895 /// Split up the constraint string from the inline assembly value into the
5896 /// specific constraints and their prefixes, and also tie in the associated
5897 /// operand values.
5898 /// If this returns an empty vector, and if the constraint string itself
5899 /// isn't empty, there was an error parsing.
5900 TargetLowering::AsmOperandInfoVector
5901 TargetLowering::ParseConstraints(const DataLayout &DL,
5902 const TargetRegisterInfo *TRI,
5903 const CallBase &Call) const {
5904 /// Information about all of the constraints.
5905 AsmOperandInfoVector ConstraintOperands;
5906 const InlineAsm *IA = cast<InlineAsm>(Call.getCalledOperand());
5907 unsigned maCount = 0; // Largest number of multiple alternative constraints.
5908
5909 // Do a prepass over the constraints, canonicalizing them, and building up the
5910 // ConstraintOperands list.
5911 unsigned ArgNo = 0; // ArgNo - The argument of the CallInst.
5912 unsigned ResNo = 0; // ResNo - The result number of the next output.
5913 unsigned LabelNo = 0; // LabelNo - CallBr indirect dest number.
5914
5915 for (InlineAsm::ConstraintInfo &CI : IA->ParseConstraints()) {
5916 ConstraintOperands.emplace_back(std::move(CI));
5917 AsmOperandInfo &OpInfo = ConstraintOperands.back();
5918
5919 // Update multiple alternative constraint count.
5920 if (OpInfo.multipleAlternatives.size() > maCount)
5921 maCount = OpInfo.multipleAlternatives.size();
5922
5923 OpInfo.ConstraintVT = MVT::Other;
5924
5925 // Compute the value type for each operand.
5926 switch (OpInfo.Type) {
5927 case InlineAsm::isOutput:
5928 // Indirect outputs just consume an argument.
5929 if (OpInfo.isIndirect) {
5930 OpInfo.CallOperandVal = Call.getArgOperand(ArgNo);
5931 break;
5932 }
5933
5934 // The return value of the call is this value. As such, there is no
5935 // corresponding argument.
5936 assert(!Call.getType()->isVoidTy() && "Bad inline asm!");
5937 if (auto *STy = dyn_cast<StructType>(Call.getType())) {
5938 OpInfo.ConstraintVT =
5939 getAsmOperandValueType(DL, STy->getElementType(ResNo))
5940 .getSimpleVT();
5941 } else {
5942 assert(ResNo == 0 && "Asm only has one result!");
5943 OpInfo.ConstraintVT =
5944 getAsmOperandValueType(DL, Call.getType()).getSimpleVT();
5945 }
5946 ++ResNo;
5947 break;
5948 case InlineAsm::isInput:
5949 OpInfo.CallOperandVal = Call.getArgOperand(ArgNo);
5950 break;
5951 case InlineAsm::isLabel:
5952 OpInfo.CallOperandVal = cast<CallBrInst>(&Call)->getIndirectDest(LabelNo);
5953 ++LabelNo;
5954 continue;
5955 case InlineAsm::isClobber:
5956 // Nothing to do.
5957 break;
5958 }
5959
5960 if (OpInfo.CallOperandVal) {
5961 llvm::Type *OpTy = OpInfo.CallOperandVal->getType();
5962 if (OpInfo.isIndirect) {
5963 OpTy = Call.getParamElementType(ArgNo);
5964 assert(OpTy && "Indirect operand must have elementtype attribute");
5965 }
5966
5967 // Look for vector wrapped in a struct. e.g. { <16 x i8> }.
5968 if (StructType *STy = dyn_cast<StructType>(OpTy))
5969 if (STy->getNumElements() == 1)
5970 OpTy = STy->getElementType(0);
5971
5972 // If OpTy is not a single value, it may be a struct/union that we
5973 // can tile with integers.
5974 if (!OpTy->isSingleValueType() && OpTy->isSized()) {
5975 unsigned BitSize = DL.getTypeSizeInBits(OpTy);
5976 switch (BitSize) {
5977 default: break;
5978 case 1:
5979 case 8:
5980 case 16:
5981 case 32:
5982 case 64:
5983 case 128:
5984 OpTy = IntegerType::get(OpTy->getContext(), BitSize);
5985 break;
5986 }
5987 }
5988
5989 EVT VT = getAsmOperandValueType(DL, OpTy, true);
5990 OpInfo.ConstraintVT = VT.isSimple() ? VT.getSimpleVT() : MVT::Other;
5991 ArgNo++;
5992 }
5993 }
5994
5995 // If we have multiple alternative constraints, select the best alternative.
5996 if (!ConstraintOperands.empty()) {
5997 if (maCount) {
5998 unsigned bestMAIndex = 0;
5999 int bestWeight = -1;
6000 // weight: -1 = invalid match, and 0 = so-so match to 5 = good match.
6001 int weight = -1;
6002 unsigned maIndex;
6003 // Compute the sums of the weights for each alternative, keeping track
6004 // of the best (highest weight) one so far.
6005 for (maIndex = 0; maIndex < maCount; ++maIndex) {
6006 int weightSum = 0;
6007 for (unsigned cIndex = 0, eIndex = ConstraintOperands.size();
6008 cIndex != eIndex; ++cIndex) {
6009 AsmOperandInfo &OpInfo = ConstraintOperands[cIndex];
6010 if (OpInfo.Type == InlineAsm::isClobber)
6011 continue;
6012
6013 // If this is an output operand with a matching input operand,
6014 // look up the matching input. If their types mismatch, e.g. one
6015 // is an integer, the other is floating point, or their sizes are
6016 // different, flag it as maCantMatch.
6017 if (OpInfo.hasMatchingInput()) {
6018 AsmOperandInfo &Input = ConstraintOperands[OpInfo.MatchingInput];
6019 if (OpInfo.ConstraintVT != Input.ConstraintVT) {
6020 if ((OpInfo.ConstraintVT.isInteger() !=
6021 Input.ConstraintVT.isInteger()) ||
6022 (OpInfo.ConstraintVT.getSizeInBits() !=
6023 Input.ConstraintVT.getSizeInBits())) {
6024 weightSum = -1; // Can't match.
6025 break;
6026 }
6027 }
6028 }
6029 weight = getMultipleConstraintMatchWeight(OpInfo, maIndex);
6030 if (weight == -1) {
6031 weightSum = -1;
6032 break;
6033 }
6034 weightSum += weight;
6035 }
6036 // Update best.
6037 if (weightSum > bestWeight) {
6038 bestWeight = weightSum;
6039 bestMAIndex = maIndex;
6040 }
6041 }
6042
6043 // Now select the chosen alternative in each constraint.
6044 for (AsmOperandInfo &cInfo : ConstraintOperands)
6045 if (cInfo.Type != InlineAsm::isClobber)
6046 cInfo.selectAlternative(bestMAIndex);
6047 }
6048 }
6049
6050 // Check and hook up tied operands, choose constraint code to use.
6051 for (unsigned cIndex = 0, eIndex = ConstraintOperands.size();
6052 cIndex != eIndex; ++cIndex) {
6053 AsmOperandInfo &OpInfo = ConstraintOperands[cIndex];
6054
6055 // If this is an output operand with a matching input operand, look up the
6056 // matching input. If their types mismatch, e.g. one is an integer, the
6057 // other is floating point, or their sizes are different, flag it as an
6058 // error.
6059 if (OpInfo.hasMatchingInput()) {
6060 AsmOperandInfo &Input = ConstraintOperands[OpInfo.MatchingInput];
6061
6062 if (OpInfo.ConstraintVT != Input.ConstraintVT) {
6063 std::pair<unsigned, const TargetRegisterClass *> MatchRC =
6064 getRegForInlineAsmConstraint(TRI, OpInfo.ConstraintCode,
6065 OpInfo.ConstraintVT);
6066 std::pair<unsigned, const TargetRegisterClass *> InputRC =
6067 getRegForInlineAsmConstraint(TRI, Input.ConstraintCode,
6068 Input.ConstraintVT);
6069 const bool OutOpIsIntOrFP = OpInfo.ConstraintVT.isInteger() ||
6070 OpInfo.ConstraintVT.isFloatingPoint();
6071 const bool InOpIsIntOrFP = Input.ConstraintVT.isInteger() ||
6072 Input.ConstraintVT.isFloatingPoint();
6073 if ((OutOpIsIntOrFP != InOpIsIntOrFP) ||
6074 (MatchRC.second != InputRC.second)) {
6075 report_fatal_error("Unsupported asm: input constraint"
6076 " with a matching output constraint of"
6077 " incompatible type!");
6078 }
6079 }
6080 }
6081 }
6082
6083 return ConstraintOperands;
6084 }
6085
6086 /// Return a number indicating our preference for choosing a type of constraint
6087 /// over another, for the purpose of sorting them. Immediates are almost always
6088 /// preferable (when they can be emitted). A higher return value means a
6089 /// stronger preference for one constraint type relative to another.
6090 /// FIXME: We should prefer registers over memory but doing so may lead to
6091 /// unrecoverable register exhaustion later.
6092 /// https://github.com/llvm/llvm-project/issues/20571
6093 static unsigned getConstraintPriority(TargetLowering::ConstraintType CT) {
6094 switch (CT) {
6095 case TargetLowering::C_Immediate:
6096 case TargetLowering::C_Other:
6097 return 4;
6098 case TargetLowering::C_Memory:
6099 case TargetLowering::C_Address:
6100 return 3;
6101 case TargetLowering::C_RegisterClass:
6102 return 2;
6103 case TargetLowering::C_Register:
6104 return 1;
6105 case TargetLowering::C_Unknown:
6106 return 0;
6107 }
6108 llvm_unreachable("Invalid constraint type");
6109 }
6110
6111 /// Examine constraint type and operand type and determine a weight value.
6112 /// This object must already have been set up with the operand type
6113 /// and the current alternative constraint selected.
6114 TargetLowering::ConstraintWeight
6115 TargetLowering::getMultipleConstraintMatchWeight(
6116 AsmOperandInfo &info, int maIndex) const {
6117 InlineAsm::ConstraintCodeVector *rCodes;
6118 if (maIndex >= (int)info.multipleAlternatives.size())
6119 rCodes = &info.Codes;
6120 else
6121 rCodes = &info.multipleAlternatives[maIndex].Codes;
6122 ConstraintWeight BestWeight = CW_Invalid;
6123
6124 // Loop over the options, keeping track of the most general one.
6125 for (const std::string &rCode : *rCodes) {
6126 ConstraintWeight weight =
6127 getSingleConstraintMatchWeight(info, rCode.c_str());
6128 if (weight > BestWeight)
6129 BestWeight = weight;
6130 }
6131
6132 return BestWeight;
6133 }
6134
6135 /// Examine constraint type and operand type and determine a weight value.
6136 /// This object must already have been set up with the operand type
6137 /// and the current alternative constraint selected.
6138 TargetLowering::ConstraintWeight
6139 TargetLowering::getSingleConstraintMatchWeight(
6140 AsmOperandInfo &info, const char *constraint) const {
6141 ConstraintWeight weight = CW_Invalid;
6142 Value *CallOperandVal = info.CallOperandVal;
6143 // If we don't have a value, we can't do a match,
6144 // but allow it at the lowest weight.
6145 if (!CallOperandVal)
6146 return CW_Default;
6147 // Look at the constraint type.
6148 switch (*constraint) {
6149 case 'i': // immediate integer.
6150 case 'n': // immediate integer with a known value.
6151 if (isa<ConstantInt>(CallOperandVal))
6152 weight = CW_Constant;
6153 break;
6154 case 's': // non-explicit integral immediate.
6155 if (isa<GlobalValue>(CallOperandVal))
6156 weight = CW_Constant;
6157 break;
6158 case 'E': // immediate float if host format.
6159 case 'F': // immediate float.
6160 if (isa<ConstantFP>(CallOperandVal))
6161 weight = CW_Constant;
6162 break;
6163 case '<': // memory operand with autodecrement.
6164 case '>': // memory operand with autoincrement.
6165 case 'm': // memory operand.
6166 case 'o': // offsettable memory operand
6167 case 'V': // non-offsettable memory operand
6168 weight = CW_Memory;
6169 break;
6170 case 'r': // general register.
6171 case 'g': // general register, memory operand or immediate integer.
6172 // note: Clang converts "g" to "imr".
6173 if (CallOperandVal->getType()->isIntegerTy())
6174 weight = CW_Register;
6175 break;
6176 case 'X': // any operand.
6177 default:
6178 weight = CW_Default;
6179 break;
6180 }
6181 return weight;
6182 }
6183
6184 /// If there are multiple different constraints that we could pick for this
6185 /// operand (e.g. "imr") try to pick the 'best' one.
6186 /// This is somewhat tricky: constraints (TargetLowering::ConstraintType) fall
6187 /// into seven classes:
6188 /// Register -> one specific register
6189 /// RegisterClass -> a group of regs
6190 /// Memory -> memory
6191 /// Address -> a symbolic memory reference
6192 /// Immediate -> immediate values
6193 /// Other -> magic values (such as "Flag Output Operands")
6194 /// Unknown -> something we don't recognize yet and can't handle
6195 /// Ideally, we would pick the most specific constraint possible: if we have
6196 /// something that fits into a register, we would pick it. The problem here
6197 /// is that if we have something that could either be in a register or in
6198 /// memory, then choosing the register could cause selection of *other*
6199 /// operands to fail: they might only succeed if we pick memory. Because of
6200 /// this the heuristic we use is:
6201 ///
6202 /// 1) If there is an 'other' constraint, and if the operand is valid for
6203 /// that constraint, use it. This makes us take advantage of 'i'
6204 /// constraints when available.
6205 /// 2) Otherwise, pick the most general constraint present. This prefers
6206 /// 'm' over 'r', for example.
6207 ///
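/// As a concrete illustration (derived from the priorities above): for the
/// multi-letter constraint "imr", the sorted preference order is 'i'
/// (immediate/other), then 'm' (memory), then 'r' (register class).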
6208 TargetLowering::ConstraintGroup TargetLowering::getConstraintPreferences(
6209 TargetLowering::AsmOperandInfo &OpInfo) const {
6210 ConstraintGroup Ret;
6211
6212 Ret.reserve(OpInfo.Codes.size());
6213 for (StringRef Code : OpInfo.Codes) {
6214 TargetLowering::ConstraintType CType = getConstraintType(Code);
6215
6216 // Indirect 'other' or 'immediate' constraints are not allowed.
6217 if (OpInfo.isIndirect && !(CType == TargetLowering::C_Memory ||
6218 CType == TargetLowering::C_Register ||
6219 CType == TargetLowering::C_RegisterClass))
6220 continue;
6221
6222 // Things with matching constraints can only be registers, per gcc
6223 // documentation. This mainly affects "g" constraints.
6224 if (CType == TargetLowering::C_Memory && OpInfo.hasMatchingInput())
6225 continue;
6226
6227 Ret.emplace_back(Code, CType);
6228 }
6229
6230 llvm::stable_sort(Ret, [](ConstraintPair a, ConstraintPair b) {
6231 return getConstraintPriority(a.second) > getConstraintPriority(b.second);
6232 });
6233
6234 return Ret;
6235 }
6236
6237 /// If we have an immediate, see if we can lower it. Return true if we can,
6238 /// false otherwise.
6239 static bool lowerImmediateIfPossible(TargetLowering::ConstraintPair &P,
6240 SDValue Op, SelectionDAG *DAG,
6241 const TargetLowering &TLI) {
6242
6243 assert((P.second == TargetLowering::C_Other ||
6244 P.second == TargetLowering::C_Immediate) &&
6245 "need immediate or other");
6246
6247 if (!Op.getNode())
6248 return false;
6249
6250 std::vector<SDValue> ResultOps;
6251 TLI.LowerAsmOperandForConstraint(Op, P.first, ResultOps, *DAG);
6252 return !ResultOps.empty();
6253 }
6254
6255 /// Determines the constraint code and constraint type to use for the specific
6256 /// AsmOperandInfo, setting OpInfo.ConstraintCode and OpInfo.ConstraintType.
6257 void TargetLowering::ComputeConstraintToUse(AsmOperandInfo &OpInfo,
6258 SDValue Op,
6259 SelectionDAG *DAG) const {
6260 assert(!OpInfo.Codes.empty() && "Must have at least one constraint");
6261
6262 // Single-letter constraints ('r') are very common.
6263 if (OpInfo.Codes.size() == 1) {
6264 OpInfo.ConstraintCode = OpInfo.Codes[0];
6265 OpInfo.ConstraintType = getConstraintType(OpInfo.ConstraintCode);
6266 } else {
6267 ConstraintGroup G = getConstraintPreferences(OpInfo);
6268 if (G.empty())
6269 return;
6270
6271 unsigned BestIdx = 0;
6272 for (const unsigned E = G.size();
6273 BestIdx < E && (G[BestIdx].second == TargetLowering::C_Other ||
6274 G[BestIdx].second == TargetLowering::C_Immediate);
6275 ++BestIdx) {
6276 if (lowerImmediateIfPossible(G[BestIdx], Op, DAG, *this))
6277 break;
6278 // If we're out of constraints, just pick the first one.
6279 if (BestIdx + 1 == E) {
6280 BestIdx = 0;
6281 break;
6282 }
6283 }
6284
6285 OpInfo.ConstraintCode = G[BestIdx].first;
6286 OpInfo.ConstraintType = G[BestIdx].second;
6287 }
6288
6289 // 'X' matches anything.
6290 if (OpInfo.ConstraintCode == "X" && OpInfo.CallOperandVal) {
6291 // Constants are handled elsewhere. For Functions, the type here is the
6292 // type of the result, which is not what we want to look at; leave them
6293 // alone.
6294 Value *v = OpInfo.CallOperandVal;
6295 if (isa<ConstantInt>(v) || isa<Function>(v)) {
6296 return;
6297 }
6298
6299 if (isa<BasicBlock>(v) || isa<BlockAddress>(v)) {
6300 OpInfo.ConstraintCode = "i";
6301 return;
6302 }
6303
6304 // Otherwise, try to resolve it to something we know about by looking at
6305 // the actual operand type.
6306 if (const char *Repl = LowerXConstraint(OpInfo.ConstraintVT)) {
6307 OpInfo.ConstraintCode = Repl;
6308 OpInfo.ConstraintType = getConstraintType(OpInfo.ConstraintCode);
6309 }
6310 }
6311 }
6312
6313 /// Given an exact SDIV by a constant, create a multiplication
6314 /// with the multiplicative inverse of the constant.
6315 /// Ref: "Hacker's Delight" by Henry Warren, 2nd Edition, p. 242
6316 static SDValue BuildExactSDIV(const TargetLowering &TLI, SDNode *N,
6317 const SDLoc &dl, SelectionDAG &DAG,
6318 SmallVectorImpl<SDNode *> &Created) {
6319 SDValue Op0 = N->getOperand(0);
6320 SDValue Op1 = N->getOperand(1);
6321 EVT VT = N->getValueType(0);
6322 EVT SVT = VT.getScalarType();
6323 EVT ShVT = TLI.getShiftAmountTy(VT, DAG.getDataLayout());
6324 EVT ShSVT = ShVT.getScalarType();
6325
6326 bool UseSRA = false;
6327 SmallVector<SDValue, 16> Shifts, Factors;
6328
6329 auto BuildSDIVPattern = [&](ConstantSDNode *C) {
6330 if (C->isZero())
6331 return false;
6332 APInt Divisor = C->getAPIntValue();
6333 unsigned Shift = Divisor.countr_zero();
6334 if (Shift) {
6335 Divisor.ashrInPlace(Shift);
6336 UseSRA = true;
6337 }
6338 APInt Factor = Divisor.multiplicativeInverse();
6339 Shifts.push_back(DAG.getConstant(Shift, dl, ShSVT));
6340 Factors.push_back(DAG.getConstant(Factor, dl, SVT));
6341 return true;
6342 };
6343
6344 // Collect all magic values from the build vector.
6345 if (!ISD::matchUnaryPredicate(Op1, BuildSDIVPattern))
6346 return SDValue();
6347
6348 SDValue Shift, Factor;
6349 if (Op1.getOpcode() == ISD::BUILD_VECTOR) {
6350 Shift = DAG.getBuildVector(ShVT, dl, Shifts);
6351 Factor = DAG.getBuildVector(VT, dl, Factors);
6352 } else if (Op1.getOpcode() == ISD::SPLAT_VECTOR) {
6353 assert(Shifts.size() == 1 && Factors.size() == 1 &&
6354 "Expected matchUnaryPredicate to return one element for scalable "
6355 "vectors");
6356 Shift = DAG.getSplatVector(ShVT, dl, Shifts[0]);
6357 Factor = DAG.getSplatVector(VT, dl, Factors[0]);
6358 } else {
6359 assert(isa<ConstantSDNode>(Op1) && "Expected a constant");
6360 Shift = Shifts[0];
6361 Factor = Factors[0];
6362 }
6363
6364 SDValue Res = Op0;
6365 if (UseSRA) {
6366 Res = DAG.getNode(ISD::SRA, dl, VT, Res, Shift, SDNodeFlags::Exact);
6367 Created.push_back(Res.getNode());
6368 }
6369
6370 return DAG.getNode(ISD::MUL, dl, VT, Res, Factor);
6371 }
6372
6373 /// Given an exact UDIV by a constant, create a multiplication
6374 /// with the multiplicative inverse of the constant.
6375 /// Ref: "Hacker's Delight" by Henry Warren, 2nd Edition, p. 242
6376 static SDValue BuildExactUDIV(const TargetLowering &TLI, SDNode *N,
6377 const SDLoc &dl, SelectionDAG &DAG,
6378 SmallVectorImpl<SDNode *> &Created) {
6379 EVT VT = N->getValueType(0);
6380 EVT SVT = VT.getScalarType();
6381 EVT ShVT = TLI.getShiftAmountTy(VT, DAG.getDataLayout());
6382 EVT ShSVT = ShVT.getScalarType();
6383
6384 bool UseSRL = false;
6385 SmallVector<SDValue, 16> Shifts, Factors;
6386
6387 auto BuildUDIVPattern = [&](ConstantSDNode *C) {
6388 if (C->isZero())
6389 return false;
6390 APInt Divisor = C->getAPIntValue();
6391 unsigned Shift = Divisor.countr_zero();
6392 if (Shift) {
6393 Divisor.lshrInPlace(Shift);
6394 UseSRL = true;
6395 }
6396 // Calculate the multiplicative inverse modulo BW.
6397 APInt Factor = Divisor.multiplicativeInverse();
6398 Shifts.push_back(DAG.getConstant(Shift, dl, ShSVT));
6399 Factors.push_back(DAG.getConstant(Factor, dl, SVT));
6400 return true;
6401 };
6402
6403 SDValue Op1 = N->getOperand(1);
6404
6405 // Collect all magic values from the build vector.
6406 if (!ISD::matchUnaryPredicate(Op1, BuildUDIVPattern))
6407 return SDValue();
6408
6409 SDValue Shift, Factor;
6410 if (Op1.getOpcode() == ISD::BUILD_VECTOR) {
6411 Shift = DAG.getBuildVector(ShVT, dl, Shifts);
6412 Factor = DAG.getBuildVector(VT, dl, Factors);
6413 } else if (Op1.getOpcode() == ISD::SPLAT_VECTOR) {
6414 assert(Shifts.size() == 1 && Factors.size() == 1 &&
6415 "Expected matchUnaryPredicate to return one element for scalable "
6416 "vectors");
6417 Shift = DAG.getSplatVector(ShVT, dl, Shifts[0]);
6418 Factor = DAG.getSplatVector(VT, dl, Factors[0]);
6419 } else {
6420 assert(isa<ConstantSDNode>(Op1) && "Expected a constant");
6421 Shift = Shifts[0];
6422 Factor = Factors[0];
6423 }
6424
6425 SDValue Res = N->getOperand(0);
6426 if (UseSRL) {
6427 Res = DAG.getNode(ISD::SRL, dl, VT, Res, Shift, SDNodeFlags::Exact);
6428 Created.push_back(Res.getNode());
6429 }
6430
6431 return DAG.getNode(ISD::MUL, dl, VT, Res, Factor);
6432 }
6433
6434 SDValue TargetLowering::BuildSDIVPow2(SDNode *N, const APInt &Divisor,
6435 SelectionDAG &DAG,
6436 SmallVectorImpl<SDNode *> &Created) const {
6437 AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
6438 if (isIntDivCheap(N->getValueType(0), Attr))
6439 return SDValue(N, 0); // Lower SDIV as SDIV
6440 return SDValue();
6441 }
6442
6443 SDValue
6444 TargetLowering::BuildSREMPow2(SDNode *N, const APInt &Divisor,
6445 SelectionDAG &DAG,
6446 SmallVectorImpl<SDNode *> &Created) const {
6447 AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
6448 if (isIntDivCheap(N->getValueType(0), Attr))
6449 return SDValue(N, 0); // Lower SREM as SREM
6450 return SDValue();
6451 }
6452
6453 /// Build sdiv by power-of-2 with conditional move instructions
6454 /// Ref: "Hacker's Delight" by Henry Warren 10-1
6455 /// If conditional move/branch is preferred, we lower sdiv x, +/-2**k into:
6456 /// bgez x, label
6457 /// add x, x, 2**k-1
6458 /// label:
6459 /// sra res, x, k
6460 /// neg res, res (when the divisor is negative)
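/// Worked example (illustrative): for "sdiv x, 8" this emits
/// t = (x < 0) ? x + 7 : x; res = t >> 3 (arithmetic); for "sdiv x, -8" the
/// same sequence is followed by res = 0 - res.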
6461 SDValue TargetLowering::buildSDIVPow2WithCMov(
6462 SDNode *N, const APInt &Divisor, SelectionDAG &DAG,
6463 SmallVectorImpl<SDNode *> &Created) const {
6464 unsigned Lg2 = Divisor.countr_zero();
6465 EVT VT = N->getValueType(0);
6466
6467 SDLoc DL(N);
6468 SDValue N0 = N->getOperand(0);
6469 SDValue Zero = DAG.getConstant(0, DL, VT);
6470 APInt Lg2Mask = APInt::getLowBitsSet(VT.getSizeInBits(), Lg2);
6471 SDValue Pow2MinusOne = DAG.getConstant(Lg2Mask, DL, VT);
6472
6473 // If N0 is negative, we need to add (Pow2 - 1) to it before shifting right.
6474 EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
6475 SDValue Cmp = DAG.getSetCC(DL, CCVT, N0, Zero, ISD::SETLT);
6476 SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0, Pow2MinusOne);
6477 SDValue CMov = DAG.getNode(ISD::SELECT, DL, VT, Cmp, Add, N0);
6478
6479 Created.push_back(Cmp.getNode());
6480 Created.push_back(Add.getNode());
6481 Created.push_back(CMov.getNode());
6482
6483 // Divide by pow2.
6484 SDValue SRA =
6485 DAG.getNode(ISD::SRA, DL, VT, CMov, DAG.getConstant(Lg2, DL, VT));
6486
6487 // If we're dividing by a positive value, we're done. Otherwise, we must
6488 // negate the result.
6489 if (Divisor.isNonNegative())
6490 return SRA;
6491
6492 Created.push_back(SRA.getNode());
6493 return DAG.getNode(ISD::SUB, DL, VT, Zero, SRA);
6494 }
6495
6496 /// Given an ISD::SDIV node expressing a divide by constant,
6497 /// return a DAG expression to select that will generate the same value by
6498 /// multiplying by a magic number.
6499 /// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
6500 SDValue TargetLowering::BuildSDIV(SDNode *N, SelectionDAG &DAG,
6501 bool IsAfterLegalization,
6502 bool IsAfterLegalTypes,
6503 SmallVectorImpl<SDNode *> &Created) const {
6504 SDLoc dl(N);
6505 EVT VT = N->getValueType(0);
6506 EVT SVT = VT.getScalarType();
6507 EVT ShVT = getShiftAmountTy(VT, DAG.getDataLayout());
6508 EVT ShSVT = ShVT.getScalarType();
6509 unsigned EltBits = VT.getScalarSizeInBits();
6510 EVT MulVT;
6511
6512 // Check to see if we can do this.
6513 // FIXME: We should be more aggressive here.
6514 if (!isTypeLegal(VT)) {
6515 // Limit this to simple scalars for now.
6516 if (VT.isVector() || !VT.isSimple())
6517 return SDValue();
6518
6519 // If this type will be promoted to a large enough type with a legal
6520 // multiply operation, we can go ahead and do this transform.
6521 if (getTypeAction(VT.getSimpleVT()) != TypePromoteInteger)
6522 return SDValue();
6523
6524 MulVT = getTypeToTransformTo(*DAG.getContext(), VT);
6525 if (MulVT.getSizeInBits() < (2 * EltBits) ||
6526 !isOperationLegal(ISD::MUL, MulVT))
6527 return SDValue();
6528 }
6529
6530 // If the sdiv has an 'exact' bit we can use a simpler lowering.
6531 if (N->getFlags().hasExact())
6532 return BuildExactSDIV(*this, N, dl, DAG, Created);
6533
6534 SmallVector<SDValue, 16> MagicFactors, Factors, Shifts, ShiftMasks;
6535
6536 auto BuildSDIVPattern = [&](ConstantSDNode *C) {
6537 if (C->isZero())
6538 return false;
6539
6540 const APInt &Divisor = C->getAPIntValue();
6541 SignedDivisionByConstantInfo magics = SignedDivisionByConstantInfo::get(Divisor);
6542 int NumeratorFactor = 0;
6543 int ShiftMask = -1;
6544
6545 if (Divisor.isOne() || Divisor.isAllOnes()) {
6546 // If d is +1/-1, we just multiply the numerator by +1/-1.
6547 NumeratorFactor = Divisor.getSExtValue();
6548 magics.Magic = 0;
6549 magics.ShiftAmount = 0;
6550 ShiftMask = 0;
6551 } else if (Divisor.isStrictlyPositive() && magics.Magic.isNegative()) {
6552 // If d > 0 and m < 0, add the numerator.
6553 NumeratorFactor = 1;
6554 } else if (Divisor.isNegative() && magics.Magic.isStrictlyPositive()) {
6555 // If d < 0 and m > 0, subtract the numerator.
6556 NumeratorFactor = -1;
6557 }
6558
6559 MagicFactors.push_back(DAG.getConstant(magics.Magic, dl, SVT));
6560 Factors.push_back(DAG.getSignedConstant(NumeratorFactor, dl, SVT));
6561 Shifts.push_back(DAG.getConstant(magics.ShiftAmount, dl, ShSVT));
6562 ShiftMasks.push_back(DAG.getSignedConstant(ShiftMask, dl, SVT));
6563 return true;
6564 };
6565
6566 SDValue N0 = N->getOperand(0);
6567 SDValue N1 = N->getOperand(1);
6568
6569 // Collect the shifts / magic values from each element.
6570 if (!ISD::matchUnaryPredicate(N1, BuildSDIVPattern))
6571 return SDValue();
6572
6573 SDValue MagicFactor, Factor, Shift, ShiftMask;
6574 if (N1.getOpcode() == ISD::BUILD_VECTOR) {
6575 MagicFactor = DAG.getBuildVector(VT, dl, MagicFactors);
6576 Factor = DAG.getBuildVector(VT, dl, Factors);
6577 Shift = DAG.getBuildVector(ShVT, dl, Shifts);
6578 ShiftMask = DAG.getBuildVector(VT, dl, ShiftMasks);
6579 } else if (N1.getOpcode() == ISD::SPLAT_VECTOR) {
6580 assert(MagicFactors.size() == 1 && Factors.size() == 1 &&
6581 Shifts.size() == 1 && ShiftMasks.size() == 1 &&
6582 "Expected matchUnaryPredicate to return one element for scalable "
6583 "vectors");
6584 MagicFactor = DAG.getSplatVector(VT, dl, MagicFactors[0]);
6585 Factor = DAG.getSplatVector(VT, dl, Factors[0]);
6586 Shift = DAG.getSplatVector(ShVT, dl, Shifts[0]);
6587 ShiftMask = DAG.getSplatVector(VT, dl, ShiftMasks[0]);
6588 } else {
6589 assert(isa<ConstantSDNode>(N1) && "Expected a constant");
6590 MagicFactor = MagicFactors[0];
6591 Factor = Factors[0];
6592 Shift = Shifts[0];
6593 ShiftMask = ShiftMasks[0];
6594 }
6595
6596 // Multiply the numerator (operand 0) by the magic value.
6597 // FIXME: We should support doing a MUL in a wider type.
6598 auto GetMULHS = [&](SDValue X, SDValue Y) {
6599 // If the type isn't legal, use a wider mul of the type calculated
6600 // earlier.
6601 if (!isTypeLegal(VT)) {
6602 X = DAG.getNode(ISD::SIGN_EXTEND, dl, MulVT, X);
6603 Y = DAG.getNode(ISD::SIGN_EXTEND, dl, MulVT, Y);
6604 Y = DAG.getNode(ISD::MUL, dl, MulVT, X, Y);
6605 Y = DAG.getNode(ISD::SRL, dl, MulVT, Y,
6606 DAG.getShiftAmountConstant(EltBits, MulVT, dl));
6607 return DAG.getNode(ISD::TRUNCATE, dl, VT, Y);
6608 }
6609
6610 if (isOperationLegalOrCustom(ISD::MULHS, VT, IsAfterLegalization))
6611 return DAG.getNode(ISD::MULHS, dl, VT, X, Y);
6612 if (isOperationLegalOrCustom(ISD::SMUL_LOHI, VT, IsAfterLegalization)) {
6613 SDValue LoHi =
6614 DAG.getNode(ISD::SMUL_LOHI, dl, DAG.getVTList(VT, VT), X, Y);
6615 return SDValue(LoHi.getNode(), 1);
6616 }
6617 // If a type twice as wide is legal, widen and use a mul plus a shift.
6618 unsigned Size = VT.getScalarSizeInBits();
6619 EVT WideVT = EVT::getIntegerVT(*DAG.getContext(), Size * 2);
6620 if (VT.isVector())
6621 WideVT = EVT::getVectorVT(*DAG.getContext(), WideVT,
6622 VT.getVectorElementCount());
6623 // Some targets like AMDGPU try to go from SDIV to SDIVREM which is then
6624 // custom lowered. This is very expensive so avoid it at all costs for
6625 // constant divisors.
6626 if ((!IsAfterLegalTypes && isOperationExpand(ISD::SDIV, VT) &&
6627 isOperationCustom(ISD::SDIVREM, VT.getScalarType())) ||
6628 isOperationLegalOrCustom(ISD::MUL, WideVT)) {
6629 X = DAG.getNode(ISD::SIGN_EXTEND, dl, WideVT, X);
6630 Y = DAG.getNode(ISD::SIGN_EXTEND, dl, WideVT, Y);
6631 Y = DAG.getNode(ISD::MUL, dl, WideVT, X, Y);
6632 Y = DAG.getNode(ISD::SRL, dl, WideVT, Y,
6633 DAG.getShiftAmountConstant(EltBits, WideVT, dl));
6634 return DAG.getNode(ISD::TRUNCATE, dl, VT, Y);
6635 }
6636 return SDValue();
6637 };
6638
6639 SDValue Q = GetMULHS(N0, MagicFactor);
6640 if (!Q)
6641 return SDValue();
6642
6643 Created.push_back(Q.getNode());
6644
6645 // (Optionally) Add/subtract the numerator using Factor.
6646 Factor = DAG.getNode(ISD::MUL, dl, VT, N0, Factor);
6647 Created.push_back(Factor.getNode());
6648 Q = DAG.getNode(ISD::ADD, dl, VT, Q, Factor);
6649 Created.push_back(Q.getNode());
6650
6651 // Shift right algebraic by shift value.
6652 Q = DAG.getNode(ISD::SRA, dl, VT, Q, Shift);
6653 Created.push_back(Q.getNode());
6654
6655 // Extract the sign bit, mask it and add it to the quotient.
6656 SDValue SignShift = DAG.getConstant(EltBits - 1, dl, ShVT);
6657 SDValue T = DAG.getNode(ISD::SRL, dl, VT, Q, SignShift);
6658 Created.push_back(T.getNode());
6659 T = DAG.getNode(ISD::AND, dl, VT, T, ShiftMask);
6660 Created.push_back(T.getNode());
6661 return DAG.getNode(ISD::ADD, dl, VT, Q, T);
6662 }
6663
6664 /// Given an ISD::UDIV node expressing a divide by constant,
6665 /// return a DAG expression to select that will generate the same value by
6666 /// multiplying by a magic number.
6667 /// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
6668 SDValue TargetLowering::BuildUDIV(SDNode *N, SelectionDAG &DAG,
6669 bool IsAfterLegalization,
6670 bool IsAfterLegalTypes,
6671 SmallVectorImpl<SDNode *> &Created) const {
6672 SDLoc dl(N);
6673 EVT VT = N->getValueType(0);
6674 EVT SVT = VT.getScalarType();
6675 EVT ShVT = getShiftAmountTy(VT, DAG.getDataLayout());
6676 EVT ShSVT = ShVT.getScalarType();
6677 unsigned EltBits = VT.getScalarSizeInBits();
6678 EVT MulVT;
6679
6680 // Check to see if we can do this.
6681 // FIXME: We should be more aggressive here.
6682 if (!isTypeLegal(VT)) {
6683 // Limit this to simple scalars for now.
6684 if (VT.isVector() || !VT.isSimple())
6685 return SDValue();
6686
6687 // If this type will be promoted to a large enough type with a legal
6688 // multiply operation, we can go ahead and do this transform.
6689 if (getTypeAction(VT.getSimpleVT()) != TypePromoteInteger)
6690 return SDValue();
6691
6692 MulVT = getTypeToTransformTo(*DAG.getContext(), VT);
6693 if (MulVT.getSizeInBits() < (2 * EltBits) ||
6694 !isOperationLegal(ISD::MUL, MulVT))
6695 return SDValue();
6696 }
6697
6698 // If the udiv has an 'exact' bit we can use a simpler lowering.
6699 if (N->getFlags().hasExact())
6700 return BuildExactUDIV(*this, N, dl, DAG, Created);
6701
6702 SDValue N0 = N->getOperand(0);
6703 SDValue N1 = N->getOperand(1);
6704
6705 // Try to use leading zeros of the dividend to reduce the multiplier and
6706 // avoid expensive fixups.
6707 unsigned KnownLeadingZeros = DAG.computeKnownBits(N0).countMinLeadingZeros();
6708
6709 bool UseNPQ = false, UsePreShift = false, UsePostShift = false;
6710 SmallVector<SDValue, 16> PreShifts, PostShifts, MagicFactors, NPQFactors;
6711
6712 auto BuildUDIVPattern = [&](ConstantSDNode *C) {
6713 if (C->isZero())
6714 return false;
6715 const APInt& Divisor = C->getAPIntValue();
6716
6717 SDValue PreShift, MagicFactor, NPQFactor, PostShift;
6718
6719 // Magic algorithm doesn't work for division by 1. We need to emit a select
6720 // at the end.
6721 if (Divisor.isOne()) {
6722 PreShift = PostShift = DAG.getUNDEF(ShSVT);
6723 MagicFactor = NPQFactor = DAG.getUNDEF(SVT);
6724 } else {
6725 UnsignedDivisionByConstantInfo magics =
6726 UnsignedDivisionByConstantInfo::get(
6727 Divisor, std::min(KnownLeadingZeros, Divisor.countl_zero()));
6728
6729 MagicFactor = DAG.getConstant(magics.Magic, dl, SVT);
6730
6731 assert(magics.PreShift < Divisor.getBitWidth() &&
6732 "We shouldn't generate an undefined shift!");
6733 assert(magics.PostShift < Divisor.getBitWidth() &&
6734 "We shouldn't generate an undefined shift!");
6735 assert((!magics.IsAdd || magics.PreShift == 0) &&
6736 "Unexpected pre-shift");
6737 PreShift = DAG.getConstant(magics.PreShift, dl, ShSVT);
6738 PostShift = DAG.getConstant(magics.PostShift, dl, ShSVT);
6739 NPQFactor = DAG.getConstant(
6740 magics.IsAdd ? APInt::getOneBitSet(EltBits, EltBits - 1)
6741 : APInt::getZero(EltBits),
6742 dl, SVT);
6743 UseNPQ |= magics.IsAdd;
6744 UsePreShift |= magics.PreShift != 0;
6745 UsePostShift |= magics.PostShift != 0;
6746 }
6747
6748 PreShifts.push_back(PreShift);
6749 MagicFactors.push_back(MagicFactor);
6750 NPQFactors.push_back(NPQFactor);
6751 PostShifts.push_back(PostShift);
6752 return true;
6753 };
6754
6755 // Collect the shifts/magic values from each element.
6756 if (!ISD::matchUnaryPredicate(N1, BuildUDIVPattern))
6757 return SDValue();
6758
6759 SDValue PreShift, PostShift, MagicFactor, NPQFactor;
6760 if (N1.getOpcode() == ISD::BUILD_VECTOR) {
6761 PreShift = DAG.getBuildVector(ShVT, dl, PreShifts);
6762 MagicFactor = DAG.getBuildVector(VT, dl, MagicFactors);
6763 NPQFactor = DAG.getBuildVector(VT, dl, NPQFactors);
6764 PostShift = DAG.getBuildVector(ShVT, dl, PostShifts);
6765 } else if (N1.getOpcode() == ISD::SPLAT_VECTOR) {
6766 assert(PreShifts.size() == 1 && MagicFactors.size() == 1 &&
6767 NPQFactors.size() == 1 && PostShifts.size() == 1 &&
6768 "Expected matchUnaryPredicate to return one for scalable vectors");
6769 PreShift = DAG.getSplatVector(ShVT, dl, PreShifts[0]);
6770 MagicFactor = DAG.getSplatVector(VT, dl, MagicFactors[0]);
6771 NPQFactor = DAG.getSplatVector(VT, dl, NPQFactors[0]);
6772 PostShift = DAG.getSplatVector(ShVT, dl, PostShifts[0]);
6773 } else {
6774 assert(isa<ConstantSDNode>(N1) && "Expected a constant");
6775 PreShift = PreShifts[0];
6776 MagicFactor = MagicFactors[0];
6777 PostShift = PostShifts[0];
6778 }
6779
6780 SDValue Q = N0;
6781 if (UsePreShift) {
6782 Q = DAG.getNode(ISD::SRL, dl, VT, Q, PreShift);
6783 Created.push_back(Q.getNode());
6784 }
6785
6786 // FIXME: We should support doing a MUL in a wider type.
6787 auto GetMULHU = [&](SDValue X, SDValue Y) {
6788 // If the type isn't legal, use a wider mul of the type calculated
6789 // earlier.
6790 if (!isTypeLegal(VT)) {
6791 X = DAG.getNode(ISD::ZERO_EXTEND, dl, MulVT, X);
6792 Y = DAG.getNode(ISD::ZERO_EXTEND, dl, MulVT, Y);
6793 Y = DAG.getNode(ISD::MUL, dl, MulVT, X, Y);
6794 Y = DAG.getNode(ISD::SRL, dl, MulVT, Y,
6795 DAG.getShiftAmountConstant(EltBits, MulVT, dl));
6796 return DAG.getNode(ISD::TRUNCATE, dl, VT, Y);
6797 }
6798
6799 if (isOperationLegalOrCustom(ISD::MULHU, VT, IsAfterLegalization))
6800 return DAG.getNode(ISD::MULHU, dl, VT, X, Y);
6801 if (isOperationLegalOrCustom(ISD::UMUL_LOHI, VT, IsAfterLegalization)) {
6802 SDValue LoHi =
6803 DAG.getNode(ISD::UMUL_LOHI, dl, DAG.getVTList(VT, VT), X, Y);
6804 return SDValue(LoHi.getNode(), 1);
6805 }
6806 // If a type twice as wide is legal, widen and use a mul plus a shift.
6807 unsigned Size = VT.getScalarSizeInBits();
6808 EVT WideVT = EVT::getIntegerVT(*DAG.getContext(), Size * 2);
6809 if (VT.isVector())
6810 WideVT = EVT::getVectorVT(*DAG.getContext(), WideVT,
6811 VT.getVectorElementCount());
6812 // Some targets like AMDGPU try to go from UDIV to UDIVREM which is then
6813 // custom lowered. This is very expensive so avoid it at all costs for
6814 // constant divisors.
6815 if ((!IsAfterLegalTypes && isOperationExpand(ISD::UDIV, VT) &&
6816 isOperationCustom(ISD::UDIVREM, VT.getScalarType())) ||
6817 isOperationLegalOrCustom(ISD::MUL, WideVT)) {
6818 X = DAG.getNode(ISD::ZERO_EXTEND, dl, WideVT, X);
6819 Y = DAG.getNode(ISD::ZERO_EXTEND, dl, WideVT, Y);
6820 Y = DAG.getNode(ISD::MUL, dl, WideVT, X, Y);
6821 Y = DAG.getNode(ISD::SRL, dl, WideVT, Y,
6822 DAG.getShiftAmountConstant(EltBits, WideVT, dl));
6823 return DAG.getNode(ISD::TRUNCATE, dl, VT, Y);
6824 }
6825 return SDValue(); // No mulhu or equivalent
6826 };
6827
6828 // Multiply the numerator (operand 0) by the magic value.
6829 Q = GetMULHU(Q, MagicFactor);
6830 if (!Q)
6831 return SDValue();
6832
6833 Created.push_back(Q.getNode());
6834
6835 if (UseNPQ) {
6836 SDValue NPQ = DAG.getNode(ISD::SUB, dl, VT, N0, Q);
6837 Created.push_back(NPQ.getNode());
6838
6839 // For vectors we might have a mix of non-NPQ/NPQ paths, so use
6840 // MULHU to act as a SRL-by-1 for NPQ, else multiply by zero.
6841 if (VT.isVector())
6842 NPQ = GetMULHU(NPQ, NPQFactor);
6843 else
6844 NPQ = DAG.getNode(ISD::SRL, dl, VT, NPQ, DAG.getConstant(1, dl, ShVT));
6845
6846 Created.push_back(NPQ.getNode());
6847
6848 Q = DAG.getNode(ISD::ADD, dl, VT, NPQ, Q);
6849 Created.push_back(Q.getNode());
6850 }
6851
6852 if (UsePostShift) {
6853 Q = DAG.getNode(ISD::SRL, dl, VT, Q, PostShift);
6854 Created.push_back(Q.getNode());
6855 }
6856
6857 EVT SetCCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
6858
6859 SDValue One = DAG.getConstant(1, dl, VT);
6860 SDValue IsOne = DAG.getSetCC(dl, SetCCVT, N1, One, ISD::SETEQ);
6861 return DAG.getSelect(dl, VT, IsOne, N0, Q);
6862 }
6863
6864 /// If all values in Values that *don't* match the predicate are the same 'splat'
6865 /// value, then replace all values with that splat value.
6866 /// Else, if AlternativeReplacement was provided, then replace all values that
6867 /// do match predicate with AlternativeReplacement value.
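/// For example (illustrative): with Values = {5, 0, 5, 5} and Predicate =
/// isNullConstant, the single non-matching value 5 becomes the splat and the
/// result is {5, 5, 5, 5}; if the non-matching values were mixed, the
/// AlternativeReplacement (when provided) is used for the matching lanes
/// instead.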
6868 static void
6869 turnVectorIntoSplatVector(MutableArrayRef<SDValue> Values,
6870 std::function<bool(SDValue)> Predicate,
6871 SDValue AlternativeReplacement = SDValue()) {
6872 SDValue Replacement;
6873 // Is there a value for which the Predicate does *NOT* match? What is it?
6874 auto SplatValue = llvm::find_if_not(Values, Predicate);
6875 if (SplatValue != Values.end()) {
6876 // Does Values consist only of SplatValue's and values matching Predicate?
6877 if (llvm::all_of(Values, [Predicate, SplatValue](SDValue Value) {
6878 return Value == *SplatValue || Predicate(Value);
6879 })) // Then we shall replace values matching predicate with SplatValue.
6880 Replacement = *SplatValue;
6881 }
6882 if (!Replacement) {
6883 // Oops, we did not find the "baseline" splat value.
6884 if (!AlternativeReplacement)
6885 return; // Nothing to do.
6886 // Let's replace with provided value then.
6887 Replacement = AlternativeReplacement;
6888 }
6889 std::replace_if(Values.begin(), Values.end(), Predicate, Replacement);
6890 }
6891
6892 /// Given an ISD::UREM used only by an ISD::SETEQ or ISD::SETNE
6893 /// where the divisor is constant and the comparison target is zero,
6894 /// return a DAG expression that will generate the same comparison result
6895 /// using only multiplications, additions and shifts/rotations.
6896 /// Ref: "Hacker's Delight" 10-17.
6897 SDValue TargetLowering::buildUREMEqFold(EVT SETCCVT, SDValue REMNode,
6898 SDValue CompTargetNode,
6899 ISD::CondCode Cond,
6900 DAGCombinerInfo &DCI,
6901 const SDLoc &DL) const {
6902 SmallVector<SDNode *, 5> Built;
6903 if (SDValue Folded = prepareUREMEqFold(SETCCVT, REMNode, CompTargetNode, Cond,
6904 DCI, DL, Built)) {
6905 for (SDNode *N : Built)
6906 DCI.AddToWorklist(N);
6907 return Folded;
6908 }
6909
6910 return SDValue();
6911 }
6912
6913 SDValue
6914 TargetLowering::prepareUREMEqFold(EVT SETCCVT, SDValue REMNode,
6915 SDValue CompTargetNode, ISD::CondCode Cond,
6916 DAGCombinerInfo &DCI, const SDLoc &DL,
6917 SmallVectorImpl<SDNode *> &Created) const {
6918 // fold (seteq/ne (urem N, D), 0) -> (setule/ugt (rotr (mul N, P), K), Q)
6919 // - D must be constant, with D = D0 * 2^K where D0 is odd
6920 // - P is the multiplicative inverse of D0 modulo 2^W
6921 // - Q = floor(((2^W) - 1) / D)
6922 // where W is the width of the common type of N and D.
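// Worked example (i32, numbers for illustration): for (seteq (urem x, 6), 0)
// we get D0 = 3, K = 1, P = 0xAAAAAAAB (inverse of 3 mod 2^32) and
// Q = floor((2^32 - 1) / 6) = 0x2AAAAAAA, so the fold produces
// (setule (rotr (mul x, 0xAAAAAAAB), 1), 0x2AAAAAAA).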
6923 assert((Cond == ISD::SETEQ || Cond == ISD::SETNE) &&
6924 "Only applicable for (in)equality comparisons.");
6925
6926 SelectionDAG &DAG = DCI.DAG;
6927
6928 EVT VT = REMNode.getValueType();
6929 EVT SVT = VT.getScalarType();
6930 EVT ShVT = getShiftAmountTy(VT, DAG.getDataLayout());
6931 EVT ShSVT = ShVT.getScalarType();
6932
6933 // If MUL is unavailable, we cannot proceed in any case.
6934 if (!DCI.isBeforeLegalizeOps() && !isOperationLegalOrCustom(ISD::MUL, VT))
6935 return SDValue();
6936
6937 bool ComparingWithAllZeros = true;
6938 bool AllComparisonsWithNonZerosAreTautological = true;
6939 bool HadTautologicalLanes = false;
6940 bool AllLanesAreTautological = true;
6941 bool HadEvenDivisor = false;
6942 bool AllDivisorsArePowerOfTwo = true;
6943 bool HadTautologicalInvertedLanes = false;
6944 SmallVector<SDValue, 16> PAmts, KAmts, QAmts;
6945
6946 auto BuildUREMPattern = [&](ConstantSDNode *CDiv, ConstantSDNode *CCmp) {
6947 // Division by 0 is UB. Leave it to be constant-folded elsewhere.
6948 if (CDiv->isZero())
6949 return false;
6950
6951 const APInt &D = CDiv->getAPIntValue();
6952 const APInt &Cmp = CCmp->getAPIntValue();
6953
6954 ComparingWithAllZeros &= Cmp.isZero();
6955
6956 // `x u% C1` is *always* less than C1. So given `x u% C1 == C2`,
6957 // if C2 is not less than C1, the comparison is always false.
6958 // But we will only be able to produce the comparison that will give the
6959 // opposite tautological answer. So this lane would need to be fixed up.
6960 bool TautologicalInvertedLane = D.ule(Cmp);
6961 HadTautologicalInvertedLanes |= TautologicalInvertedLane;
6962
6963 // If all lanes are tautological (either all divisors are ones, or divisor
6964 // is not greater than the constant we are comparing with),
6965 // we will prefer to avoid the fold.
6966 bool TautologicalLane = D.isOne() || TautologicalInvertedLane;
6967 HadTautologicalLanes |= TautologicalLane;
6968 AllLanesAreTautological &= TautologicalLane;
6969
6970 // If we are comparing with non-zero, we'll need to subtract said
6971 // comparison value from the LHS. But there is no point in doing that if
6972 // every lane where we are comparing with non-zero is tautological.
6973 if (!Cmp.isZero())
6974 AllComparisonsWithNonZerosAreTautological &= TautologicalLane;
6975
6976 // Decompose D into D0 * 2^K
6977 unsigned K = D.countr_zero();
6978 assert((!D.isOne() || (K == 0)) && "For divisor '1' we won't rotate.");
6979 APInt D0 = D.lshr(K);
6980
6981 // D is even if it has trailing zeros.
6982 HadEvenDivisor |= (K != 0);
6983 // D is a power-of-two if D0 is one.
6984 // If all divisors are power-of-two, we will prefer to avoid the fold.
6985 AllDivisorsArePowerOfTwo &= D0.isOne();
6986
6987 // P = inv(D0, 2^W)
6988 // 2^W requires W + 1 bits, so we have to extend and then truncate.
6989 unsigned W = D.getBitWidth();
6990 APInt P = D0.multiplicativeInverse();
6991 assert((D0 * P).isOne() && "Multiplicative inverse basic check failed.");
6992
6993 // Q = floor((2^W - 1) u/ D)
6994 // R = ((2^W - 1) u% D)
6995 APInt Q, R;
6996 APInt::udivrem(APInt::getAllOnes(W), D, Q, R);
6997
6998 // If we are comparing with zero, then that comparison constant is okay,
6999 // else it may need to be one less than that.
7000 if (Cmp.ugt(R))
7001 Q -= 1;
7002
7003 assert(APInt::getAllOnes(ShSVT.getSizeInBits()).ugt(K) &&
7004 "We are expecting that K is always less than all-ones for ShSVT");
7005
7006 // If the lane is tautological the result can be constant-folded.
7007 if (TautologicalLane) {
7008 // Set P and K amount to a bogus values so we can try to splat them.
7009 P = 0;
7010 K = -1;
7011 // And ensure that comparison constant is tautological,
7012 // it will always compare true/false.
7013 Q = -1;
7014 }
7015
7016 PAmts.push_back(DAG.getConstant(P, DL, SVT));
7017 KAmts.push_back(
7018 DAG.getConstant(APInt(ShSVT.getSizeInBits(), K, /*isSigned=*/false,
7019 /*implicitTrunc=*/true),
7020 DL, ShSVT));
7021 QAmts.push_back(DAG.getConstant(Q, DL, SVT));
7022 return true;
7023 };
7024
7025 SDValue N = REMNode.getOperand(0);
7026 SDValue D = REMNode.getOperand(1);
7027
7028 // Collect the values from each element.
7029 if (!ISD::matchBinaryPredicate(D, CompTargetNode, BuildUREMPattern))
7030 return SDValue();
7031
7032 // If all lanes are tautological, the result can be constant-folded.
7033 if (AllLanesAreTautological)
7034 return SDValue();
7035
7036 // If this is a urem by a power-of-two, avoid the fold since it can be
7037 // best implemented as a bit test.
7038 if (AllDivisorsArePowerOfTwo)
7039 return SDValue();
7040
7041 SDValue PVal, KVal, QVal;
7042 if (D.getOpcode() == ISD::BUILD_VECTOR) {
7043 if (HadTautologicalLanes) {
7044 // Try to turn PAmts into a splat, since we don't care about the values
7045 // that are currently '0'. If we can't, just keep the '0's.
7046 turnVectorIntoSplatVector(PAmts, isNullConstant);
7047 // Try to turn KAmts into a splat, since we don't care about the values
7048 // that are currently '-1'. If we can't, change them to '0's.
7049 turnVectorIntoSplatVector(KAmts, isAllOnesConstant,
7050 DAG.getConstant(0, DL, ShSVT));
7051 }
7052
7053 PVal = DAG.getBuildVector(VT, DL, PAmts);
7054 KVal = DAG.getBuildVector(ShVT, DL, KAmts);
7055 QVal = DAG.getBuildVector(VT, DL, QAmts);
7056 } else if (D.getOpcode() == ISD::SPLAT_VECTOR) {
7057 assert(PAmts.size() == 1 && KAmts.size() == 1 && QAmts.size() == 1 &&
7058 "Expected matchBinaryPredicate to return one element for "
7059 "SPLAT_VECTORs");
7060 PVal = DAG.getSplatVector(VT, DL, PAmts[0]);
7061 KVal = DAG.getSplatVector(ShVT, DL, KAmts[0]);
7062 QVal = DAG.getSplatVector(VT, DL, QAmts[0]);
7063 } else {
7064 PVal = PAmts[0];
7065 KVal = KAmts[0];
7066 QVal = QAmts[0];
7067 }
7068
7069 if (!ComparingWithAllZeros && !AllComparisonsWithNonZerosAreTautological) {
7070 if (!DCI.isBeforeLegalizeOps() && !isOperationLegalOrCustom(ISD::SUB, VT))
7071 return SDValue(); // FIXME: Could/should use `ISD::ADD`?
7072 assert(CompTargetNode.getValueType() == N.getValueType() &&
7073 "Expecting that the types on LHS and RHS of comparisons match.");
7074 N = DAG.getNode(ISD::SUB, DL, VT, N, CompTargetNode);
7075 }
7076
7077 // (mul N, P)
7078 SDValue Op0 = DAG.getNode(ISD::MUL, DL, VT, N, PVal);
7079 Created.push_back(Op0.getNode());
7080
7081 // Rotate right only if any divisor was even. We avoid rotates for all-odd
7082 // divisors as a performance improvement, since rotating by 0 is a no-op.
7083 if (HadEvenDivisor) {
7084 // We need ROTR to do this.
7085 if (!DCI.isBeforeLegalizeOps() && !isOperationLegalOrCustom(ISD::ROTR, VT))
7086 return SDValue();
7087 // UREM: (rotr (mul N, P), K)
7088 Op0 = DAG.getNode(ISD::ROTR, DL, VT, Op0, KVal);
7089 Created.push_back(Op0.getNode());
7090 }
7091
7092 // UREM: (setule/setugt (rotr (mul N, P), K), Q)
7093 SDValue NewCC =
7094 DAG.getSetCC(DL, SETCCVT, Op0, QVal,
7095 ((Cond == ISD::SETEQ) ? ISD::SETULE : ISD::SETUGT));
7096 if (!HadTautologicalInvertedLanes)
7097 return NewCC;
7098
7099 // If any lanes previously compared always-false, the NewCC will give
7100 // always-true result for them, so we need to fixup those lanes.
7101 // Or the other way around for inequality predicate.
7102 assert(VT.isVector() && "Can/should only get here for vectors.");
7103 Created.push_back(NewCC.getNode());
7104
7105 // `x u% C1` is *always* less than C1. So given `x u% C1 == C2`,
7106 // if C2 is not less than C1, the comparison is always false.
7107 // But we have produced the comparison that will give the
7108 // opposite tautological answer. So these lanes would need to be fixed up.
7109 SDValue TautologicalInvertedChannels =
7110 DAG.getSetCC(DL, SETCCVT, D, CompTargetNode, ISD::SETULE);
7111 Created.push_back(TautologicalInvertedChannels.getNode());
7112
7113 // NOTE: we avoid letting illegal types through even if we're before legalize
7114 // ops – legalization has a hard time producing good code for this.
7115 if (isOperationLegalOrCustom(ISD::VSELECT, SETCCVT)) {
7116 // If we have a vector select, let's replace the comparison results in the
7117 // affected lanes with the correct tautological result.
7118 SDValue Replacement = DAG.getBoolConstant(Cond == ISD::SETEQ ? false : true,
7119 DL, SETCCVT, SETCCVT);
7120 return DAG.getNode(ISD::VSELECT, DL, SETCCVT, TautologicalInvertedChannels,
7121 Replacement, NewCC);
7122 }
7123
7124 // Else, we can just invert the comparison result in the appropriate lanes.
7125 //
7126 // NOTE: see the note about VSELECT above.
7127 if (isOperationLegalOrCustom(ISD::XOR, SETCCVT))
7128 return DAG.getNode(ISD::XOR, DL, SETCCVT, NewCC,
7129 TautologicalInvertedChannels);
7130
7131 return SDValue(); // Don't know how to lower.
7132 }
7133
7134 /// Given an ISD::SREM used only by an ISD::SETEQ or ISD::SETNE
7135 /// where the divisor is constant and the comparison target is zero,
7136 /// return a DAG expression that will generate the same comparison result
7137 /// using only multiplications, additions and shifts/rotations.
7138 /// Ref: "Hacker's Delight" 10-17.
7139 SDValue TargetLowering::buildSREMEqFold(EVT SETCCVT, SDValue REMNode,
7140 SDValue CompTargetNode,
7141 ISD::CondCode Cond,
7142 DAGCombinerInfo &DCI,
7143 const SDLoc &DL) const {
7144 SmallVector<SDNode *, 7> Built;
7145 if (SDValue Folded = prepareSREMEqFold(SETCCVT, REMNode, CompTargetNode, Cond,
7146 DCI, DL, Built)) {
7147 assert(Built.size() <= 7 && "Max size prediction failed.");
7148 for (SDNode *N : Built)
7149 DCI.AddToWorklist(N);
7150 return Folded;
7151 }
7152
7153 return SDValue();
7154 }
7155
7156 SDValue
7157 TargetLowering::prepareSREMEqFold(EVT SETCCVT, SDValue REMNode,
7158 SDValue CompTargetNode, ISD::CondCode Cond,
7159 DAGCombinerInfo &DCI, const SDLoc &DL,
7160 SmallVectorImpl<SDNode *> &Created) const {
7161 // Derived from Hacker's Delight, 2nd Edition, by Hank Warren. Section 10-17.
7162 // Fold:
7163 // (seteq/ne (srem N, D), 0)
7164 // To:
7165 // (setule/ugt (rotr (add (mul N, P), A), K), Q)
7166 //
7167 // - D must be constant, with D = D0 * 2^K where D0 is odd
7168 // - P is the multiplicative inverse of D0 modulo 2^W
7169 // - A = bitwiseand(floor((2^(W - 1) - 1) / D0), (-(2^k)))
7170 // - Q = floor((2 * A) / (2^K))
7171 // where W is the width of the common type of N and D.
7172 //
7173   // When D is a power of two (and thus D0 is 1), the normal
7174   // formulas for A and Q don't apply, because the derivation
7175   // depends on D not dividing 2^(W-1), and thus theorem ZRS
7176   // does not apply. This specifically fails when N = INT_MIN.
7177 //
7178 // Instead, for power-of-two D, we use:
7179 // - A = 2^(W-1)
7180   //    |-> Order-preserving map from [-2^(W-1), 2^(W-1) - 1] to [0, 2^W - 1]
7181 // - Q = 2^(W-K) - 1
7182 // |-> Test that the top K bits are zero after rotation
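  //
  // For example, with i8 and D = 6 (so D0 = 3, K = 1, W = 8): P = inv(3, 2^8)
  // = 171, A = (floor(127 / 3) & -2) = 42 and Q = floor((2 * 42) / 2) = 42, so
  // `N s% 6 == 0` becomes `(rotr (add (mul N, 171), 42), 1) u<= 42`.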
7183 assert((Cond == ISD::SETEQ || Cond == ISD::SETNE) &&
7184 "Only applicable for (in)equality comparisons.");
7185
7186 SelectionDAG &DAG = DCI.DAG;
7187
7188 EVT VT = REMNode.getValueType();
7189 EVT SVT = VT.getScalarType();
7190 EVT ShVT = getShiftAmountTy(VT, DAG.getDataLayout());
7191 EVT ShSVT = ShVT.getScalarType();
7192
7193   // If we are after ops legalization, and MUL is unavailable, we cannot
7194   // proceed.
7195 if (!DCI.isBeforeLegalizeOps() && !isOperationLegalOrCustom(ISD::MUL, VT))
7196 return SDValue();
7197
7198 // TODO: Could support comparing with non-zero too.
7199 ConstantSDNode *CompTarget = isConstOrConstSplat(CompTargetNode);
7200 if (!CompTarget || !CompTarget->isZero())
7201 return SDValue();
7202
7203 bool HadIntMinDivisor = false;
7204 bool HadOneDivisor = false;
7205 bool AllDivisorsAreOnes = true;
7206 bool HadEvenDivisor = false;
7207 bool NeedToApplyOffset = false;
7208 bool AllDivisorsArePowerOfTwo = true;
7209 SmallVector<SDValue, 16> PAmts, AAmts, KAmts, QAmts;
7210
7211 auto BuildSREMPattern = [&](ConstantSDNode *C) {
7212 // Division by 0 is UB. Leave it to be constant-folded elsewhere.
7213 if (C->isZero())
7214 return false;
7215
7216 // FIXME: we don't fold `rem %X, -C` to `rem %X, C` in DAGCombine.
7217
7218 // WARNING: this fold is only valid for positive divisors!
7219 APInt D = C->getAPIntValue();
7220 if (D.isNegative())
7221 D.negate(); // `rem %X, -C` is equivalent to `rem %X, C`
7222
7223 HadIntMinDivisor |= D.isMinSignedValue();
7224
7225 // If all divisors are ones, we will prefer to avoid the fold.
7226 HadOneDivisor |= D.isOne();
7227 AllDivisorsAreOnes &= D.isOne();
7228
7229 // Decompose D into D0 * 2^K
7230 unsigned K = D.countr_zero();
7231 assert((!D.isOne() || (K == 0)) && "For divisor '1' we won't rotate.");
7232 APInt D0 = D.lshr(K);
7233
7234 if (!D.isMinSignedValue()) {
7235     // D is even if it has trailing zeros; unless it's INT_MIN, in which case
7236     // we don't care about this lane in this fold; we'll special-handle it.
7237 HadEvenDivisor |= (K != 0);
7238 }
7239
7240 // D is a power-of-two if D0 is one. This includes INT_MIN.
7241 // If all divisors are power-of-two, we will prefer to avoid the fold.
7242 AllDivisorsArePowerOfTwo &= D0.isOne();
7243
7244 // P = inv(D0, 2^W)
7245 // 2^W requires W + 1 bits, so we have to extend and then truncate.
7246 unsigned W = D.getBitWidth();
7247 APInt P = D0.multiplicativeInverse();
7248 assert((D0 * P).isOne() && "Multiplicative inverse basic check failed.");
7249
7250 // A = floor((2^(W - 1) - 1) / D0) & -2^K
7251 APInt A = APInt::getSignedMaxValue(W).udiv(D0);
7252 A.clearLowBits(K);
7253
7254 if (!D.isMinSignedValue()) {
7255       // If the divisor is INT_MIN, then we don't care about this lane in this
7256       // fold; we'll special-handle it.
7257 NeedToApplyOffset |= A != 0;
7258 }
7259
7260 // Q = floor((2 * A) / (2^K))
7261 APInt Q = (2 * A).udiv(APInt::getOneBitSet(W, K));
7262
7263 assert(APInt::getAllOnes(SVT.getSizeInBits()).ugt(A) &&
7264 "We are expecting that A is always less than all-ones for SVT");
7265 assert(APInt::getAllOnes(ShSVT.getSizeInBits()).ugt(K) &&
7266 "We are expecting that K is always less than all-ones for ShSVT");
7267
7268 // If D was a power of two, apply the alternate constant derivation.
7269 if (D0.isOne()) {
7270 // A = 2^(W-1)
7271 A = APInt::getSignedMinValue(W);
7272       // Q = 2^(W-K) - 1
7273 Q = APInt::getAllOnes(W - K).zext(W);
7274 }
7275
7276     // If the divisor is 1 the result can be constant-folded. Likewise, we
7277     // don't care about INT_MIN lanes; those can be set to undef if appropriate.
7278     if (D.isOne()) {
7279       // Set P, A and K to bogus values so we can try to splat them.
7280 P = 0;
7281 A = -1;
7282 K = -1;
7283
7284 // x ?% 1 == 0 <--> true <--> x u<= -1
7285 Q = -1;
7286 }
7287
7288 PAmts.push_back(DAG.getConstant(P, DL, SVT));
7289 AAmts.push_back(DAG.getConstant(A, DL, SVT));
7290 KAmts.push_back(
7291 DAG.getConstant(APInt(ShSVT.getSizeInBits(), K, /*isSigned=*/false,
7292 /*implicitTrunc=*/true),
7293 DL, ShSVT));
7294 QAmts.push_back(DAG.getConstant(Q, DL, SVT));
7295 return true;
7296 };
7297
7298 SDValue N = REMNode.getOperand(0);
7299 SDValue D = REMNode.getOperand(1);
7300
7301 // Collect the values from each element.
7302 if (!ISD::matchUnaryPredicate(D, BuildSREMPattern))
7303 return SDValue();
7304
7305   // If this is a srem by one, avoid the fold since it can be constant-folded.
7306 if (AllDivisorsAreOnes)
7307 return SDValue();
7308
7309   // If this is a srem by a power-of-two (including INT_MIN), avoid the fold
7310   // since it can be best implemented as a bit test.
7311 if (AllDivisorsArePowerOfTwo)
7312 return SDValue();
7313
7314 SDValue PVal, AVal, KVal, QVal;
7315 if (D.getOpcode() == ISD::BUILD_VECTOR) {
7316 if (HadOneDivisor) {
7317 // Try to turn PAmts into a splat, since we don't care about the values
7318       // that are currently '0'. If we can't, just keep '0's.
7319 turnVectorIntoSplatVector(PAmts, isNullConstant);
7320 // Try to turn AAmts into a splat, since we don't care about the
7321       // values that are currently '-1'. If we can't, change them to '0's.
7322 turnVectorIntoSplatVector(AAmts, isAllOnesConstant,
7323 DAG.getConstant(0, DL, SVT));
7324 // Try to turn KAmts into a splat, since we don't care about the values
7325       // that are currently '-1'. If we can't, change them to '0's.
7326 turnVectorIntoSplatVector(KAmts, isAllOnesConstant,
7327 DAG.getConstant(0, DL, ShSVT));
7328 }
7329
7330 PVal = DAG.getBuildVector(VT, DL, PAmts);
7331 AVal = DAG.getBuildVector(VT, DL, AAmts);
7332 KVal = DAG.getBuildVector(ShVT, DL, KAmts);
7333 QVal = DAG.getBuildVector(VT, DL, QAmts);
7334 } else if (D.getOpcode() == ISD::SPLAT_VECTOR) {
7335 assert(PAmts.size() == 1 && AAmts.size() == 1 && KAmts.size() == 1 &&
7336 QAmts.size() == 1 &&
7337 "Expected matchUnaryPredicate to return one element for scalable "
7338 "vectors");
7339 PVal = DAG.getSplatVector(VT, DL, PAmts[0]);
7340 AVal = DAG.getSplatVector(VT, DL, AAmts[0]);
7341 KVal = DAG.getSplatVector(ShVT, DL, KAmts[0]);
7342 QVal = DAG.getSplatVector(VT, DL, QAmts[0]);
7343 } else {
7344 assert(isa<ConstantSDNode>(D) && "Expected a constant");
7345 PVal = PAmts[0];
7346 AVal = AAmts[0];
7347 KVal = KAmts[0];
7348 QVal = QAmts[0];
7349 }
7350
7351 // (mul N, P)
7352 SDValue Op0 = DAG.getNode(ISD::MUL, DL, VT, N, PVal);
7353 Created.push_back(Op0.getNode());
7354
7355 if (NeedToApplyOffset) {
7356 // We need ADD to do this.
7357 if (!DCI.isBeforeLegalizeOps() && !isOperationLegalOrCustom(ISD::ADD, VT))
7358 return SDValue();
7359
7360 // (add (mul N, P), A)
7361 Op0 = DAG.getNode(ISD::ADD, DL, VT, Op0, AVal);
7362 Created.push_back(Op0.getNode());
7363 }
7364
7365 // Rotate right only if any divisor was even. We avoid rotates for all-odd
7366 // divisors as a performance improvement, since rotating by 0 is a no-op.
7367 if (HadEvenDivisor) {
7368 // We need ROTR to do this.
7369 if (!DCI.isBeforeLegalizeOps() && !isOperationLegalOrCustom(ISD::ROTR, VT))
7370 return SDValue();
7371 // SREM: (rotr (add (mul N, P), A), K)
7372 Op0 = DAG.getNode(ISD::ROTR, DL, VT, Op0, KVal);
7373 Created.push_back(Op0.getNode());
7374 }
7375
7376 // SREM: (setule/setugt (rotr (add (mul N, P), A), K), Q)
7377 SDValue Fold =
7378 DAG.getSetCC(DL, SETCCVT, Op0, QVal,
7379 ((Cond == ISD::SETEQ) ? ISD::SETULE : ISD::SETUGT));
7380
7381 // If we didn't have lanes with INT_MIN divisor, then we're done.
7382 if (!HadIntMinDivisor)
7383 return Fold;
7384
7385   // That fold is only valid for positive divisors, which effectively means it
7386   // is invalid for INT_MIN divisors. So if we have such a lane,
7387   // we must fix up the results for said lanes.
7388 assert(VT.isVector() && "Can/should only get here for vectors.");
7389
7390 // NOTE: we avoid letting illegal types through even if we're before legalize
7391 // ops – legalization has a hard time producing good code for the code that
7392 // follows.
7393 if (!isOperationLegalOrCustom(ISD::SETCC, SETCCVT) ||
7394 !isOperationLegalOrCustom(ISD::AND, VT) ||
7395 !isCondCodeLegalOrCustom(Cond, VT.getSimpleVT()) ||
7396 !isOperationLegalOrCustom(ISD::VSELECT, SETCCVT))
7397 return SDValue();
7398
7399 Created.push_back(Fold.getNode());
7400
7401 SDValue IntMin = DAG.getConstant(
7402 APInt::getSignedMinValue(SVT.getScalarSizeInBits()), DL, VT);
7403 SDValue IntMax = DAG.getConstant(
7404 APInt::getSignedMaxValue(SVT.getScalarSizeInBits()), DL, VT);
7405 SDValue Zero =
7406 DAG.getConstant(APInt::getZero(SVT.getScalarSizeInBits()), DL, VT);
7407
7408 // Which lanes had INT_MIN divisors? Divisor is constant, so const-folded.
7409 SDValue DivisorIsIntMin = DAG.getSetCC(DL, SETCCVT, D, IntMin, ISD::SETEQ);
7410 Created.push_back(DivisorIsIntMin.getNode());
7411
7412 // (N s% INT_MIN) ==/!= 0 <--> (N & INT_MAX) ==/!= 0
7413 SDValue Masked = DAG.getNode(ISD::AND, DL, VT, N, IntMax);
7414 Created.push_back(Masked.getNode());
7415 SDValue MaskedIsZero = DAG.getSetCC(DL, SETCCVT, Masked, Zero, Cond);
7416 Created.push_back(MaskedIsZero.getNode());
7417
7418   // To produce the final result we need to blend 2 vectors: 'Fold' and
7419   // 'MaskedIsZero'. If the divisor for a channel was *NOT* INT_MIN, we pick
7420   // from 'Fold', else we pick from 'MaskedIsZero'. Since 'DivisorIsIntMin' is
7421   // constant-folded, the select can get lowered to a shuffle with a constant mask.
7422 SDValue Blended = DAG.getNode(ISD::VSELECT, DL, SETCCVT, DivisorIsIntMin,
7423 MaskedIsZero, Fold);
7424
7425 return Blended;
7426 }
7427
7428 SDValue TargetLowering::getSqrtInputTest(SDValue Op, SelectionDAG &DAG,
7429 const DenormalMode &Mode) const {
7430 SDLoc DL(Op);
7431 EVT VT = Op.getValueType();
7432 EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
7433 SDValue FPZero = DAG.getConstantFP(0.0, DL, VT);
7434
7435 // This is specifically a check for the handling of denormal inputs, not the
7436 // result.
7437 if (Mode.Input == DenormalMode::PreserveSign ||
7438 Mode.Input == DenormalMode::PositiveZero) {
7439 // Test = X == 0.0
7440 return DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ);
7441 }
7442
7443   // Test for denormal inputs to avoid producing a wrong estimate.
7444 //
7445 // Test = fabs(X) < SmallestNormal
7446 const fltSemantics &FltSem = VT.getFltSemantics();
7447 APFloat SmallestNorm = APFloat::getSmallestNormalized(FltSem);
7448 SDValue NormC = DAG.getConstantFP(SmallestNorm, DL, VT);
7449 SDValue Fabs = DAG.getNode(ISD::FABS, DL, VT, Op);
7450 return DAG.getSetCC(DL, CCVT, Fabs, NormC, ISD::SETLT);
7451 }
7452
7453 SDValue TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG,
7454 bool LegalOps, bool OptForSize,
7455 NegatibleCost &Cost,
7456 unsigned Depth) const {
7457 // fneg is removable even if it has multiple uses.
7458 if (Op.getOpcode() == ISD::FNEG || Op.getOpcode() == ISD::VP_FNEG) {
7459 Cost = NegatibleCost::Cheaper;
7460 return Op.getOperand(0);
7461 }
7462
7463 // Don't recurse exponentially.
7464 if (Depth > SelectionDAG::MaxRecursionDepth)
7465 return SDValue();
7466
7467 // Pre-increment recursion depth for use in recursive calls.
7468 ++Depth;
7469 const SDNodeFlags Flags = Op->getFlags();
7470 const TargetOptions &Options = DAG.getTarget().Options;
7471 EVT VT = Op.getValueType();
7472 unsigned Opcode = Op.getOpcode();
7473
7474 // Don't allow anything with multiple uses unless we know it is free.
7475 if (!Op.hasOneUse() && Opcode != ISD::ConstantFP) {
7476 bool IsFreeExtend = Opcode == ISD::FP_EXTEND &&
7477 isFPExtFree(VT, Op.getOperand(0).getValueType());
7478 if (!IsFreeExtend)
7479 return SDValue();
7480 }
7481
7482 auto RemoveDeadNode = [&](SDValue N) {
7483 if (N && N.getNode()->use_empty())
7484 DAG.RemoveDeadNode(N.getNode());
7485 };
7486
7487 SDLoc DL(Op);
7488
7489 // Because getNegatedExpression can delete nodes we need a handle to keep
7490 // temporary nodes alive in case the recursion manages to create an identical
7491 // node.
7492 std::list<HandleSDNode> Handles;
7493
7494 switch (Opcode) {
7495 case ISD::ConstantFP: {
7496 // Don't invert constant FP values after legalization unless the target says
7497 // the negated constant is legal.
7498 bool IsOpLegal =
7499 isOperationLegal(ISD::ConstantFP, VT) ||
7500 isFPImmLegal(neg(cast<ConstantFPSDNode>(Op)->getValueAPF()), VT,
7501 OptForSize);
7502
7503 if (LegalOps && !IsOpLegal)
7504 break;
7505
7506 APFloat V = cast<ConstantFPSDNode>(Op)->getValueAPF();
7507 V.changeSign();
7508 SDValue CFP = DAG.getConstantFP(V, DL, VT);
7509
7510     // If we already have the use of the negated floating constant, it is free
7511     // to negate it even if it has multiple uses.
7512 if (!Op.hasOneUse() && CFP.use_empty())
7513 break;
7514 Cost = NegatibleCost::Neutral;
7515 return CFP;
7516 }
7517 case ISD::BUILD_VECTOR: {
7518 // Only permit BUILD_VECTOR of constants.
7519 if (llvm::any_of(Op->op_values(), [&](SDValue N) {
7520 return !N.isUndef() && !isa<ConstantFPSDNode>(N);
7521 }))
7522 break;
7523
7524 bool IsOpLegal =
7525 (isOperationLegal(ISD::ConstantFP, VT) &&
7526 isOperationLegal(ISD::BUILD_VECTOR, VT)) ||
7527 llvm::all_of(Op->op_values(), [&](SDValue N) {
7528 return N.isUndef() ||
7529 isFPImmLegal(neg(cast<ConstantFPSDNode>(N)->getValueAPF()), VT,
7530 OptForSize);
7531 });
7532
7533 if (LegalOps && !IsOpLegal)
7534 break;
7535
7536 SmallVector<SDValue, 4> Ops;
7537 for (SDValue C : Op->op_values()) {
7538 if (C.isUndef()) {
7539 Ops.push_back(C);
7540 continue;
7541 }
7542 APFloat V = cast<ConstantFPSDNode>(C)->getValueAPF();
7543 V.changeSign();
7544 Ops.push_back(DAG.getConstantFP(V, DL, C.getValueType()));
7545 }
7546 Cost = NegatibleCost::Neutral;
7547 return DAG.getBuildVector(VT, DL, Ops);
7548 }
7549 case ISD::FADD: {
7550 if (!Options.NoSignedZerosFPMath && !Flags.hasNoSignedZeros())
7551 break;
7552
7553 // After operation legalization, it might not be legal to create new FSUBs.
7554 if (LegalOps && !isOperationLegalOrCustom(ISD::FSUB, VT))
7555 break;
7556 SDValue X = Op.getOperand(0), Y = Op.getOperand(1);
7557
7558 // fold (fneg (fadd X, Y)) -> (fsub (fneg X), Y)
7559 NegatibleCost CostX = NegatibleCost::Expensive;
7560 SDValue NegX =
7561 getNegatedExpression(X, DAG, LegalOps, OptForSize, CostX, Depth);
7562 // Prevent this node from being deleted by the next call.
7563 if (NegX)
7564 Handles.emplace_back(NegX);
7565
7566 // fold (fneg (fadd X, Y)) -> (fsub (fneg Y), X)
7567 NegatibleCost CostY = NegatibleCost::Expensive;
7568 SDValue NegY =
7569 getNegatedExpression(Y, DAG, LegalOps, OptForSize, CostY, Depth);
7570
7571 // We're done with the handles.
7572 Handles.clear();
7573
7574     // Negate X if its cost is less than or equal to the cost of negating Y.
7575 if (NegX && (CostX <= CostY)) {
7576 Cost = CostX;
7577 SDValue N = DAG.getNode(ISD::FSUB, DL, VT, NegX, Y, Flags);
7578 if (NegY != N)
7579 RemoveDeadNode(NegY);
7580 return N;
7581 }
7582
7583 // Negate the Y if it is not expensive.
7584 if (NegY) {
7585 Cost = CostY;
7586 SDValue N = DAG.getNode(ISD::FSUB, DL, VT, NegY, X, Flags);
7587 if (NegX != N)
7588 RemoveDeadNode(NegX);
7589 return N;
7590 }
7591 break;
7592 }
7593 case ISD::FSUB: {
7594 // We can't turn -(A-B) into B-A when we honor signed zeros.
7595 if (!Options.NoSignedZerosFPMath && !Flags.hasNoSignedZeros())
7596 break;
7597
7598 SDValue X = Op.getOperand(0), Y = Op.getOperand(1);
7599 // fold (fneg (fsub 0, Y)) -> Y
7600 if (ConstantFPSDNode *C = isConstOrConstSplatFP(X, /*AllowUndefs*/ true))
7601 if (C->isZero()) {
7602 Cost = NegatibleCost::Cheaper;
7603 return Y;
7604 }
7605
7606 // fold (fneg (fsub X, Y)) -> (fsub Y, X)
7607 Cost = NegatibleCost::Neutral;
7608 return DAG.getNode(ISD::FSUB, DL, VT, Y, X, Flags);
7609 }
7610 case ISD::FMUL:
7611 case ISD::FDIV: {
7612 SDValue X = Op.getOperand(0), Y = Op.getOperand(1);
7613
7614 // fold (fneg (fmul X, Y)) -> (fmul (fneg X), Y)
7615 NegatibleCost CostX = NegatibleCost::Expensive;
7616 SDValue NegX =
7617 getNegatedExpression(X, DAG, LegalOps, OptForSize, CostX, Depth);
7618 // Prevent this node from being deleted by the next call.
7619 if (NegX)
7620 Handles.emplace_back(NegX);
7621
7622 // fold (fneg (fmul X, Y)) -> (fmul X, (fneg Y))
7623 NegatibleCost CostY = NegatibleCost::Expensive;
7624 SDValue NegY =
7625 getNegatedExpression(Y, DAG, LegalOps, OptForSize, CostY, Depth);
7626
7627 // We're done with the handles.
7628 Handles.clear();
7629
7630     // Negate X if its cost is less than or equal to the cost of negating Y.
7631 if (NegX && (CostX <= CostY)) {
7632 Cost = CostX;
7633 SDValue N = DAG.getNode(Opcode, DL, VT, NegX, Y, Flags);
7634 if (NegY != N)
7635 RemoveDeadNode(NegY);
7636 return N;
7637 }
7638
7639 // Ignore X * 2.0 because that is expected to be canonicalized to X + X.
7640 if (auto *C = isConstOrConstSplatFP(Op.getOperand(1)))
7641 if (C->isExactlyValue(2.0) && Op.getOpcode() == ISD::FMUL)
7642 break;
7643
7644 // Negate the Y if it is not expensive.
7645 if (NegY) {
7646 Cost = CostY;
7647 SDValue N = DAG.getNode(Opcode, DL, VT, X, NegY, Flags);
7648 if (NegX != N)
7649 RemoveDeadNode(NegX);
7650 return N;
7651 }
7652 break;
7653 }
7654 case ISD::FMA:
7655 case ISD::FMAD: {
7656 if (!Options.NoSignedZerosFPMath && !Flags.hasNoSignedZeros())
7657 break;
7658
7659 SDValue X = Op.getOperand(0), Y = Op.getOperand(1), Z = Op.getOperand(2);
7660 NegatibleCost CostZ = NegatibleCost::Expensive;
7661 SDValue NegZ =
7662 getNegatedExpression(Z, DAG, LegalOps, OptForSize, CostZ, Depth);
7663     // Give up if we fail to negate Z.
7664 if (!NegZ)
7665 break;
7666
7667 // Prevent this node from being deleted by the next two calls.
7668 Handles.emplace_back(NegZ);
7669
7670 // fold (fneg (fma X, Y, Z)) -> (fma (fneg X), Y, (fneg Z))
7671 NegatibleCost CostX = NegatibleCost::Expensive;
7672 SDValue NegX =
7673 getNegatedExpression(X, DAG, LegalOps, OptForSize, CostX, Depth);
7674 // Prevent this node from being deleted by the next call.
7675 if (NegX)
7676 Handles.emplace_back(NegX);
7677
7678 // fold (fneg (fma X, Y, Z)) -> (fma X, (fneg Y), (fneg Z))
7679 NegatibleCost CostY = NegatibleCost::Expensive;
7680 SDValue NegY =
7681 getNegatedExpression(Y, DAG, LegalOps, OptForSize, CostY, Depth);
7682
7683 // We're done with the handles.
7684 Handles.clear();
7685
7686     // Negate X if its cost is less than or equal to the cost of negating Y.
7687 if (NegX && (CostX <= CostY)) {
7688 Cost = std::min(CostX, CostZ);
7689 SDValue N = DAG.getNode(Opcode, DL, VT, NegX, Y, NegZ, Flags);
7690 if (NegY != N)
7691 RemoveDeadNode(NegY);
7692 return N;
7693 }
7694
7695 // Negate the Y if it is not expensive.
7696 if (NegY) {
7697 Cost = std::min(CostY, CostZ);
7698 SDValue N = DAG.getNode(Opcode, DL, VT, X, NegY, NegZ, Flags);
7699 if (NegX != N)
7700 RemoveDeadNode(NegX);
7701 return N;
7702 }
7703 break;
7704 }
7705
7706 case ISD::FP_EXTEND:
7707 case ISD::FSIN:
7708 if (SDValue NegV = getNegatedExpression(Op.getOperand(0), DAG, LegalOps,
7709 OptForSize, Cost, Depth))
7710 return DAG.getNode(Opcode, DL, VT, NegV);
7711 break;
7712 case ISD::FP_ROUND:
7713 if (SDValue NegV = getNegatedExpression(Op.getOperand(0), DAG, LegalOps,
7714 OptForSize, Cost, Depth))
7715 return DAG.getNode(ISD::FP_ROUND, DL, VT, NegV, Op.getOperand(1));
7716 break;
7717 case ISD::SELECT:
7718 case ISD::VSELECT: {
7719 // fold (fneg (select C, LHS, RHS)) -> (select C, (fneg LHS), (fneg RHS))
7720 // iff at least one cost is cheaper and the other is neutral/cheaper
7721 SDValue LHS = Op.getOperand(1);
7722 NegatibleCost CostLHS = NegatibleCost::Expensive;
7723 SDValue NegLHS =
7724 getNegatedExpression(LHS, DAG, LegalOps, OptForSize, CostLHS, Depth);
7725 if (!NegLHS || CostLHS > NegatibleCost::Neutral) {
7726 RemoveDeadNode(NegLHS);
7727 break;
7728 }
7729
7730 // Prevent this node from being deleted by the next call.
7731 Handles.emplace_back(NegLHS);
7732
7733 SDValue RHS = Op.getOperand(2);
7734 NegatibleCost CostRHS = NegatibleCost::Expensive;
7735 SDValue NegRHS =
7736 getNegatedExpression(RHS, DAG, LegalOps, OptForSize, CostRHS, Depth);
7737
7738 // We're done with the handles.
7739 Handles.clear();
7740
7741 if (!NegRHS || CostRHS > NegatibleCost::Neutral ||
7742 (CostLHS != NegatibleCost::Cheaper &&
7743 CostRHS != NegatibleCost::Cheaper)) {
7744 RemoveDeadNode(NegLHS);
7745 RemoveDeadNode(NegRHS);
7746 break;
7747 }
7748
7749 Cost = std::min(CostLHS, CostRHS);
7750 return DAG.getSelect(DL, VT, Op.getOperand(0), NegLHS, NegRHS);
7751 }
7752 }
7753
7754 return SDValue();
7755 }
7756
7757 //===----------------------------------------------------------------------===//
7758 // Legalization Utilities
7759 //===----------------------------------------------------------------------===//
7760
7761 bool TargetLowering::expandMUL_LOHI(unsigned Opcode, EVT VT, const SDLoc &dl,
7762 SDValue LHS, SDValue RHS,
7763 SmallVectorImpl<SDValue> &Result,
7764 EVT HiLoVT, SelectionDAG &DAG,
7765 MulExpansionKind Kind, SDValue LL,
7766 SDValue LH, SDValue RL, SDValue RH) const {
7767 assert(Opcode == ISD::MUL || Opcode == ISD::UMUL_LOHI ||
7768 Opcode == ISD::SMUL_LOHI);
7769
7770 bool HasMULHS = (Kind == MulExpansionKind::Always) ||
7771 isOperationLegalOrCustom(ISD::MULHS, HiLoVT);
7772 bool HasMULHU = (Kind == MulExpansionKind::Always) ||
7773 isOperationLegalOrCustom(ISD::MULHU, HiLoVT);
7774 bool HasSMUL_LOHI = (Kind == MulExpansionKind::Always) ||
7775 isOperationLegalOrCustom(ISD::SMUL_LOHI, HiLoVT);
7776 bool HasUMUL_LOHI = (Kind == MulExpansionKind::Always) ||
7777 isOperationLegalOrCustom(ISD::UMUL_LOHI, HiLoVT);
7778
7779 if (!HasMULHU && !HasMULHS && !HasUMUL_LOHI && !HasSMUL_LOHI)
7780 return false;
7781
7782 unsigned OuterBitSize = VT.getScalarSizeInBits();
7783 unsigned InnerBitSize = HiLoVT.getScalarSizeInBits();
7784
7785 // LL, LH, RL, and RH must be either all NULL or all set to a value.
7786 assert((LL.getNode() && LH.getNode() && RL.getNode() && RH.getNode()) ||
7787 (!LL.getNode() && !LH.getNode() && !RL.getNode() && !RH.getNode()));
7788
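  // Writing LHS = LH * 2^n + LL and RHS = RH * 2^n + RL (with n = InnerBitSize),
  // the full product is
  //   LHS * RHS = LL*RL + (LL*RH + LH*RL) * 2^n + LH*RH * 2^(2n),
  // and the code below assembles the requested low/high halves from those
  // partial products (with a final correction for signed inputs).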
7789 SDVTList VTs = DAG.getVTList(HiLoVT, HiLoVT);
7790 auto MakeMUL_LOHI = [&](SDValue L, SDValue R, SDValue &Lo, SDValue &Hi,
7791 bool Signed) -> bool {
7792 if ((Signed && HasSMUL_LOHI) || (!Signed && HasUMUL_LOHI)) {
7793 Lo = DAG.getNode(Signed ? ISD::SMUL_LOHI : ISD::UMUL_LOHI, dl, VTs, L, R);
7794 Hi = SDValue(Lo.getNode(), 1);
7795 return true;
7796 }
7797 if ((Signed && HasMULHS) || (!Signed && HasMULHU)) {
7798 Lo = DAG.getNode(ISD::MUL, dl, HiLoVT, L, R);
7799 Hi = DAG.getNode(Signed ? ISD::MULHS : ISD::MULHU, dl, HiLoVT, L, R);
7800 return true;
7801 }
7802 return false;
7803 };
7804
7805 SDValue Lo, Hi;
7806
7807 if (!LL.getNode() && !RL.getNode() &&
7808 isOperationLegalOrCustom(ISD::TRUNCATE, HiLoVT)) {
7809 LL = DAG.getNode(ISD::TRUNCATE, dl, HiLoVT, LHS);
7810 RL = DAG.getNode(ISD::TRUNCATE, dl, HiLoVT, RHS);
7811 }
7812
7813 if (!LL.getNode())
7814 return false;
7815
7816 APInt HighMask = APInt::getHighBitsSet(OuterBitSize, InnerBitSize);
7817 if (DAG.MaskedValueIsZero(LHS, HighMask) &&
7818 DAG.MaskedValueIsZero(RHS, HighMask)) {
7819 // The inputs are both zero-extended.
7820 if (MakeMUL_LOHI(LL, RL, Lo, Hi, false)) {
7821 Result.push_back(Lo);
7822 Result.push_back(Hi);
7823 if (Opcode != ISD::MUL) {
7824 SDValue Zero = DAG.getConstant(0, dl, HiLoVT);
7825 Result.push_back(Zero);
7826 Result.push_back(Zero);
7827 }
7828 return true;
7829 }
7830 }
7831
7832 if (!VT.isVector() && Opcode == ISD::MUL &&
7833 DAG.ComputeMaxSignificantBits(LHS) <= InnerBitSize &&
7834 DAG.ComputeMaxSignificantBits(RHS) <= InnerBitSize) {
7835 // The input values are both sign-extended.
7836 // TODO non-MUL case?
7837 if (MakeMUL_LOHI(LL, RL, Lo, Hi, true)) {
7838 Result.push_back(Lo);
7839 Result.push_back(Hi);
7840 return true;
7841 }
7842 }
7843
7844 unsigned ShiftAmount = OuterBitSize - InnerBitSize;
7845 SDValue Shift = DAG.getShiftAmountConstant(ShiftAmount, VT, dl);
7846
7847 if (!LH.getNode() && !RH.getNode() &&
7848 isOperationLegalOrCustom(ISD::SRL, VT) &&
7849 isOperationLegalOrCustom(ISD::TRUNCATE, HiLoVT)) {
7850 LH = DAG.getNode(ISD::SRL, dl, VT, LHS, Shift);
7851 LH = DAG.getNode(ISD::TRUNCATE, dl, HiLoVT, LH);
7852 RH = DAG.getNode(ISD::SRL, dl, VT, RHS, Shift);
7853 RH = DAG.getNode(ISD::TRUNCATE, dl, HiLoVT, RH);
7854 }
7855
7856 if (!LH.getNode())
7857 return false;
7858
7859 if (!MakeMUL_LOHI(LL, RL, Lo, Hi, false))
7860 return false;
7861
7862 Result.push_back(Lo);
7863
7864 if (Opcode == ISD::MUL) {
7865 RH = DAG.getNode(ISD::MUL, dl, HiLoVT, LL, RH);
7866 LH = DAG.getNode(ISD::MUL, dl, HiLoVT, LH, RL);
7867 Hi = DAG.getNode(ISD::ADD, dl, HiLoVT, Hi, RH);
7868 Hi = DAG.getNode(ISD::ADD, dl, HiLoVT, Hi, LH);
7869 Result.push_back(Hi);
7870 return true;
7871 }
7872
7873 // Compute the full width result.
7874 auto Merge = [&](SDValue Lo, SDValue Hi) -> SDValue {
7875 Lo = DAG.getNode(ISD::ZERO_EXTEND, dl, VT, Lo);
7876 Hi = DAG.getNode(ISD::ZERO_EXTEND, dl, VT, Hi);
7877 Hi = DAG.getNode(ISD::SHL, dl, VT, Hi, Shift);
7878 return DAG.getNode(ISD::OR, dl, VT, Lo, Hi);
7879 };
7880
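  // Accumulate the three remaining partial products: LL*RH and LH*RL are added
  // at bit position n (tracking the carry into the topmost half), and LH*RH
  // lands in the upper halves.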
7881 SDValue Next = DAG.getNode(ISD::ZERO_EXTEND, dl, VT, Hi);
7882 if (!MakeMUL_LOHI(LL, RH, Lo, Hi, false))
7883 return false;
7884
7885 // This is effectively the add part of a multiply-add of half-sized operands,
7886 // so it cannot overflow.
7887 Next = DAG.getNode(ISD::ADD, dl, VT, Next, Merge(Lo, Hi));
7888
7889 if (!MakeMUL_LOHI(LH, RL, Lo, Hi, false))
7890 return false;
7891
7892 SDValue Zero = DAG.getConstant(0, dl, HiLoVT);
7893 EVT BoolType = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
7894
7895 bool UseGlue = (isOperationLegalOrCustom(ISD::ADDC, VT) &&
7896 isOperationLegalOrCustom(ISD::ADDE, VT));
7897 if (UseGlue)
7898 Next = DAG.getNode(ISD::ADDC, dl, DAG.getVTList(VT, MVT::Glue), Next,
7899 Merge(Lo, Hi));
7900 else
7901 Next = DAG.getNode(ISD::UADDO_CARRY, dl, DAG.getVTList(VT, BoolType), Next,
7902 Merge(Lo, Hi), DAG.getConstant(0, dl, BoolType));
7903
7904 SDValue Carry = Next.getValue(1);
7905 Result.push_back(DAG.getNode(ISD::TRUNCATE, dl, HiLoVT, Next));
7906 Next = DAG.getNode(ISD::SRL, dl, VT, Next, Shift);
7907
7908 if (!MakeMUL_LOHI(LH, RH, Lo, Hi, Opcode == ISD::SMUL_LOHI))
7909 return false;
7910
7911 if (UseGlue)
7912 Hi = DAG.getNode(ISD::ADDE, dl, DAG.getVTList(HiLoVT, MVT::Glue), Hi, Zero,
7913 Carry);
7914 else
7915 Hi = DAG.getNode(ISD::UADDO_CARRY, dl, DAG.getVTList(HiLoVT, BoolType), Hi,
7916 Zero, Carry);
7917
7918 Next = DAG.getNode(ISD::ADD, dl, VT, Next, Merge(Lo, Hi));
7919
7920 if (Opcode == ISD::SMUL_LOHI) {
7921 SDValue NextSub = DAG.getNode(ISD::SUB, dl, VT, Next,
7922 DAG.getNode(ISD::ZERO_EXTEND, dl, VT, RL));
7923 Next = DAG.getSelectCC(dl, LH, Zero, NextSub, Next, ISD::SETLT);
7924
7925 NextSub = DAG.getNode(ISD::SUB, dl, VT, Next,
7926 DAG.getNode(ISD::ZERO_EXTEND, dl, VT, LL));
7927 Next = DAG.getSelectCC(dl, RH, Zero, NextSub, Next, ISD::SETLT);
7928 }
7929
7930 Result.push_back(DAG.getNode(ISD::TRUNCATE, dl, HiLoVT, Next));
7931 Next = DAG.getNode(ISD::SRL, dl, VT, Next, Shift);
7932 Result.push_back(DAG.getNode(ISD::TRUNCATE, dl, HiLoVT, Next));
7933 return true;
7934 }
7935
7936 bool TargetLowering::expandMUL(SDNode *N, SDValue &Lo, SDValue &Hi, EVT HiLoVT,
7937 SelectionDAG &DAG, MulExpansionKind Kind,
7938 SDValue LL, SDValue LH, SDValue RL,
7939 SDValue RH) const {
7940 SmallVector<SDValue, 2> Result;
7941 bool Ok = expandMUL_LOHI(N->getOpcode(), N->getValueType(0), SDLoc(N),
7942 N->getOperand(0), N->getOperand(1), Result, HiLoVT,
7943 DAG, Kind, LL, LH, RL, RH);
7944 if (Ok) {
7945 assert(Result.size() == 2);
7946 Lo = Result[0];
7947 Hi = Result[1];
7948 }
7949 return Ok;
7950 }
7951
7952 // Optimize unsigned division or remainder by constants for types twice as large
7953 // as a legal VT.
7954 //
7955 // If (1 << (BitWidth / 2)) % Constant == 1, then the remainder can be
7956 // computed as:
7958 // Sum += __builtin_uadd_overflow(Lo, High, &Sum);
7959 // Remainder = Sum % Constant
7960 // This is based on "Remainder by Summing Digits" from Hacker's Delight.
7961 //
7962 // For division, we can compute the remainder using the algorithm described
7963 // above, then subtract it from the dividend to get an exact multiple of
7964 // Constant. Then multiply that exact multiple by the multiplicative inverse
7965 // modulo (1 << BitWidth) to get the quotient.
7966
7967 // If Constant is even, we can shift right the dividend and the divisor by the
7968 // number of trailing zeros in Constant before applying the remainder algorithm.
7969 // If we're after the quotient, we can subtract the computed remainder from the
7970 // shifted dividend and multiply by the multiplicative inverse of the shifted
7971 // divisor. If we want the remainder, we shift the computed remainder left by the
7972 // number of trailing zeros and add the bits that were shifted out of the dividend.
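//
// For example, for a 32-bit udiv/urem by 3 with 16-bit halves: (1 << 16) % 3 == 1,
// so Rem = (Lo + Hi + carry) % 3, and the quotient is recovered as
// (x - Rem) * 0xAAAAAAAB, since 0xAAAAAAAB is the multiplicative inverse of 3
// modulo 2^32.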
7973 bool TargetLowering::expandDIVREMByConstant(SDNode *N,
7974 SmallVectorImpl<SDValue> &Result,
7975 EVT HiLoVT, SelectionDAG &DAG,
7976 SDValue LL, SDValue LH) const {
7977 unsigned Opcode = N->getOpcode();
7978 EVT VT = N->getValueType(0);
7979
7980 // TODO: Support signed division/remainder.
7981 if (Opcode == ISD::SREM || Opcode == ISD::SDIV || Opcode == ISD::SDIVREM)
7982 return false;
7983 assert(
7984 (Opcode == ISD::UREM || Opcode == ISD::UDIV || Opcode == ISD::UDIVREM) &&
7985 "Unexpected opcode");
7986
7987 auto *CN = dyn_cast<ConstantSDNode>(N->getOperand(1));
7988 if (!CN)
7989 return false;
7990
7991 APInt Divisor = CN->getAPIntValue();
7992 unsigned BitWidth = Divisor.getBitWidth();
7993 unsigned HBitWidth = BitWidth / 2;
7994 assert(VT.getScalarSizeInBits() == BitWidth &&
7995 HiLoVT.getScalarSizeInBits() == HBitWidth && "Unexpected VTs");
7996
7997   // Divisor needs to be less than (1 << HBitWidth).
7998 APInt HalfMaxPlus1 = APInt::getOneBitSet(BitWidth, HBitWidth);
7999 if (Divisor.uge(HalfMaxPlus1))
8000 return false;
8001
8002   // We depend on the UREM by constant optimization in DAGCombiner that requires
8003   // a high multiply.
8004 if (!isOperationLegalOrCustom(ISD::MULHU, HiLoVT) &&
8005 !isOperationLegalOrCustom(ISD::UMUL_LOHI, HiLoVT))
8006 return false;
8007
8008 // Don't expand if optimizing for size.
8009 if (DAG.shouldOptForSize())
8010 return false;
8011
8012 // Early out for 0 or 1 divisors.
8013 if (Divisor.ule(1))
8014 return false;
8015
8016 // If the divisor is even, shift it until it becomes odd.
8017 unsigned TrailingZeros = 0;
8018 if (!Divisor[0]) {
8019 TrailingZeros = Divisor.countr_zero();
8020 Divisor.lshrInPlace(TrailingZeros);
8021 }
8022
8023 SDLoc dl(N);
8024 SDValue Sum;
8025 SDValue PartialRem;
8026
8027 // If (1 << HBitWidth) % divisor == 1, we can add the two halves together and
8028 // then add in the carry.
8029 // TODO: If we can't split it in half, we might be able to split into 3 or
8030 // more pieces using a smaller bit width.
8031 if (HalfMaxPlus1.urem(Divisor).isOne()) {
8032 assert(!LL == !LH && "Expected both input halves or no input halves!");
8033 if (!LL)
8034 std::tie(LL, LH) = DAG.SplitScalar(N->getOperand(0), dl, HiLoVT, HiLoVT);
8035
8036 // Shift the input by the number of TrailingZeros in the divisor. The
8037 // shifted out bits will be added to the remainder later.
8038 if (TrailingZeros) {
8039 // Save the shifted off bits if we need the remainder.
8040 if (Opcode != ISD::UDIV) {
8041 APInt Mask = APInt::getLowBitsSet(HBitWidth, TrailingZeros);
8042 PartialRem = DAG.getNode(ISD::AND, dl, HiLoVT, LL,
8043 DAG.getConstant(Mask, dl, HiLoVT));
8044 }
8045
8046 LL = DAG.getNode(
8047 ISD::OR, dl, HiLoVT,
8048 DAG.getNode(ISD::SRL, dl, HiLoVT, LL,
8049 DAG.getShiftAmountConstant(TrailingZeros, HiLoVT, dl)),
8050 DAG.getNode(ISD::SHL, dl, HiLoVT, LH,
8051 DAG.getShiftAmountConstant(HBitWidth - TrailingZeros,
8052 HiLoVT, dl)));
8053 LH = DAG.getNode(ISD::SRL, dl, HiLoVT, LH,
8054 DAG.getShiftAmountConstant(TrailingZeros, HiLoVT, dl));
8055 }
8056
8057 // Use uaddo_carry if we can, otherwise use a compare to detect overflow.
8058 EVT SetCCType =
8059 getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), HiLoVT);
8060 if (isOperationLegalOrCustom(ISD::UADDO_CARRY, HiLoVT)) {
8061 SDVTList VTList = DAG.getVTList(HiLoVT, SetCCType);
8062 Sum = DAG.getNode(ISD::UADDO, dl, VTList, LL, LH);
8063 Sum = DAG.getNode(ISD::UADDO_CARRY, dl, VTList, Sum,
8064 DAG.getConstant(0, dl, HiLoVT), Sum.getValue(1));
8065 } else {
8066 Sum = DAG.getNode(ISD::ADD, dl, HiLoVT, LL, LH);
8067 SDValue Carry = DAG.getSetCC(dl, SetCCType, Sum, LL, ISD::SETULT);
8068 // If the boolean for the target is 0 or 1, we can add the setcc result
8069 // directly.
8070 if (getBooleanContents(HiLoVT) ==
8071 TargetLoweringBase::ZeroOrOneBooleanContent)
8072 Carry = DAG.getZExtOrTrunc(Carry, dl, HiLoVT);
8073 else
8074 Carry = DAG.getSelect(dl, HiLoVT, Carry, DAG.getConstant(1, dl, HiLoVT),
8075 DAG.getConstant(0, dl, HiLoVT));
8076 Sum = DAG.getNode(ISD::ADD, dl, HiLoVT, Sum, Carry);
8077 }
8078 }
8079
8080 // If we didn't find a sum, we can't do the expansion.
8081 if (!Sum)
8082 return false;
8083
8084 // Perform a HiLoVT urem on the Sum using truncated divisor.
8085 SDValue RemL =
8086 DAG.getNode(ISD::UREM, dl, HiLoVT, Sum,
8087 DAG.getConstant(Divisor.trunc(HBitWidth), dl, HiLoVT));
8088 SDValue RemH = DAG.getConstant(0, dl, HiLoVT);
8089
8090 if (Opcode != ISD::UREM) {
8091 // Subtract the remainder from the shifted dividend.
8092 SDValue Dividend = DAG.getNode(ISD::BUILD_PAIR, dl, VT, LL, LH);
8093 SDValue Rem = DAG.getNode(ISD::BUILD_PAIR, dl, VT, RemL, RemH);
8094
8095 Dividend = DAG.getNode(ISD::SUB, dl, VT, Dividend, Rem);
8096
8097 // Multiply by the multiplicative inverse of the divisor modulo
8098 // (1 << BitWidth).
8099 APInt MulFactor = Divisor.multiplicativeInverse();
8100
8101 SDValue Quotient = DAG.getNode(ISD::MUL, dl, VT, Dividend,
8102 DAG.getConstant(MulFactor, dl, VT));
8103
8104 // Split the quotient into low and high parts.
8105 SDValue QuotL, QuotH;
8106 std::tie(QuotL, QuotH) = DAG.SplitScalar(Quotient, dl, HiLoVT, HiLoVT);
8107 Result.push_back(QuotL);
8108 Result.push_back(QuotH);
8109 }
8110
8111 if (Opcode != ISD::UDIV) {
8112 // If we shifted the input, shift the remainder left and add the bits we
8113 // shifted off the input.
8114 if (TrailingZeros) {
8115 RemL = DAG.getNode(ISD::SHL, dl, HiLoVT, RemL,
8116 DAG.getShiftAmountConstant(TrailingZeros, HiLoVT, dl));
8117 RemL = DAG.getNode(ISD::ADD, dl, HiLoVT, RemL, PartialRem);
8118 }
8119 Result.push_back(RemL);
8120 Result.push_back(DAG.getConstant(0, dl, HiLoVT));
8121 }
8122
8123 return true;
8124 }
8125
8126 // Check that (every element of) Z is undef or not an exact multiple of BW.
8127 static bool isNonZeroModBitWidthOrUndef(SDValue Z, unsigned BW) {
8128 return ISD::matchUnaryPredicate(
8129 Z,
8130 [=](ConstantSDNode *C) { return !C || C->getAPIntValue().urem(BW) != 0; },
8131 /*AllowUndef=*/true, /*AllowTruncation=*/true);
8132 }
8133
8134 static SDValue expandVPFunnelShift(SDNode *Node, SelectionDAG &DAG) {
8135 EVT VT = Node->getValueType(0);
8136 SDValue ShX, ShY;
8137 SDValue ShAmt, InvShAmt;
8138 SDValue X = Node->getOperand(0);
8139 SDValue Y = Node->getOperand(1);
8140 SDValue Z = Node->getOperand(2);
8141 SDValue Mask = Node->getOperand(3);
8142 SDValue VL = Node->getOperand(4);
8143
8144 unsigned BW = VT.getScalarSizeInBits();
8145 bool IsFSHL = Node->getOpcode() == ISD::VP_FSHL;
8146 SDLoc DL(SDValue(Node, 0));
8147
8148 EVT ShVT = Z.getValueType();
8149 if (isNonZeroModBitWidthOrUndef(Z, BW)) {
8150 // fshl: X << C | Y >> (BW - C)
8151 // fshr: X << (BW - C) | Y >> C
8152 // where C = Z % BW is not zero
8153 SDValue BitWidthC = DAG.getConstant(BW, DL, ShVT);
8154 ShAmt = DAG.getNode(ISD::VP_UREM, DL, ShVT, Z, BitWidthC, Mask, VL);
8155 InvShAmt = DAG.getNode(ISD::VP_SUB, DL, ShVT, BitWidthC, ShAmt, Mask, VL);
8156 ShX = DAG.getNode(ISD::VP_SHL, DL, VT, X, IsFSHL ? ShAmt : InvShAmt, Mask,
8157 VL);
8158 ShY = DAG.getNode(ISD::VP_SRL, DL, VT, Y, IsFSHL ? InvShAmt : ShAmt, Mask,
8159 VL);
8160 } else {
8161 // fshl: X << (Z % BW) | Y >> 1 >> (BW - 1 - (Z % BW))
8162 // fshr: X << 1 << (BW - 1 - (Z % BW)) | Y >> (Z % BW)
8163 SDValue BitMask = DAG.getConstant(BW - 1, DL, ShVT);
8164 if (isPowerOf2_32(BW)) {
8165 // Z % BW -> Z & (BW - 1)
8166 ShAmt = DAG.getNode(ISD::VP_AND, DL, ShVT, Z, BitMask, Mask, VL);
8167 // (BW - 1) - (Z % BW) -> ~Z & (BW - 1)
8168 SDValue NotZ = DAG.getNode(ISD::VP_XOR, DL, ShVT, Z,
8169 DAG.getAllOnesConstant(DL, ShVT), Mask, VL);
8170 InvShAmt = DAG.getNode(ISD::VP_AND, DL, ShVT, NotZ, BitMask, Mask, VL);
8171 } else {
8172 SDValue BitWidthC = DAG.getConstant(BW, DL, ShVT);
8173 ShAmt = DAG.getNode(ISD::VP_UREM, DL, ShVT, Z, BitWidthC, Mask, VL);
8174 InvShAmt = DAG.getNode(ISD::VP_SUB, DL, ShVT, BitMask, ShAmt, Mask, VL);
8175 }
8176
8177 SDValue One = DAG.getConstant(1, DL, ShVT);
8178 if (IsFSHL) {
8179 ShX = DAG.getNode(ISD::VP_SHL, DL, VT, X, ShAmt, Mask, VL);
8180 SDValue ShY1 = DAG.getNode(ISD::VP_SRL, DL, VT, Y, One, Mask, VL);
8181 ShY = DAG.getNode(ISD::VP_SRL, DL, VT, ShY1, InvShAmt, Mask, VL);
8182 } else {
8183 SDValue ShX1 = DAG.getNode(ISD::VP_SHL, DL, VT, X, One, Mask, VL);
8184 ShX = DAG.getNode(ISD::VP_SHL, DL, VT, ShX1, InvShAmt, Mask, VL);
8185 ShY = DAG.getNode(ISD::VP_SRL, DL, VT, Y, ShAmt, Mask, VL);
8186 }
8187 }
8188 return DAG.getNode(ISD::VP_OR, DL, VT, ShX, ShY, Mask, VL);
8189 }
8190
8191 SDValue TargetLowering::expandFunnelShift(SDNode *Node,
8192 SelectionDAG &DAG) const {
8193 if (Node->isVPOpcode())
8194 return expandVPFunnelShift(Node, DAG);
8195
8196 EVT VT = Node->getValueType(0);
8197
8198 if (VT.isVector() && (!isOperationLegalOrCustom(ISD::SHL, VT) ||
8199 !isOperationLegalOrCustom(ISD::SRL, VT) ||
8200 !isOperationLegalOrCustom(ISD::SUB, VT) ||
8201 !isOperationLegalOrCustomOrPromote(ISD::OR, VT)))
8202 return SDValue();
8203
8204 SDValue X = Node->getOperand(0);
8205 SDValue Y = Node->getOperand(1);
8206 SDValue Z = Node->getOperand(2);
8207
8208 unsigned BW = VT.getScalarSizeInBits();
8209 bool IsFSHL = Node->getOpcode() == ISD::FSHL;
8210 SDLoc DL(SDValue(Node, 0));
8211
8212 EVT ShVT = Z.getValueType();
8213
8214   // If a funnel shift in the other direction is better supported, use it.
8215 unsigned RevOpcode = IsFSHL ? ISD::FSHR : ISD::FSHL;
8216 if (!isOperationLegalOrCustom(Node->getOpcode(), VT) &&
8217 isOperationLegalOrCustom(RevOpcode, VT) && isPowerOf2_32(BW)) {
8218 if (isNonZeroModBitWidthOrUndef(Z, BW)) {
8219 // fshl X, Y, Z -> fshr X, Y, -Z
8220 // fshr X, Y, Z -> fshl X, Y, -Z
8221 SDValue Zero = DAG.getConstant(0, DL, ShVT);
8222 Z = DAG.getNode(ISD::SUB, DL, VT, Zero, Z);
8223 } else {
8224 // fshl X, Y, Z -> fshr (srl X, 1), (fshr X, Y, 1), ~Z
8225 // fshr X, Y, Z -> fshl (fshl X, Y, 1), (shl Y, 1), ~Z
8226 SDValue One = DAG.getConstant(1, DL, ShVT);
8227 if (IsFSHL) {
8228 Y = DAG.getNode(RevOpcode, DL, VT, X, Y, One);
8229 X = DAG.getNode(ISD::SRL, DL, VT, X, One);
8230 } else {
8231 X = DAG.getNode(RevOpcode, DL, VT, X, Y, One);
8232 Y = DAG.getNode(ISD::SHL, DL, VT, Y, One);
8233 }
8234 Z = DAG.getNOT(DL, Z, ShVT);
8235 }
8236 return DAG.getNode(RevOpcode, DL, VT, X, Y, Z);
8237 }
8238
8239 SDValue ShX, ShY;
8240 SDValue ShAmt, InvShAmt;
8241 if (isNonZeroModBitWidthOrUndef(Z, BW)) {
8242 // fshl: X << C | Y >> (BW - C)
8243 // fshr: X << (BW - C) | Y >> C
8244 // where C = Z % BW is not zero
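    // For example, for i8: fshl X, Y, 3 --> (X << 3) | (Y >> 5), and
    // fshr X, Y, 3 --> (X << 5) | (Y >> 3).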
8245 SDValue BitWidthC = DAG.getConstant(BW, DL, ShVT);
8246 ShAmt = DAG.getNode(ISD::UREM, DL, ShVT, Z, BitWidthC);
8247 InvShAmt = DAG.getNode(ISD::SUB, DL, ShVT, BitWidthC, ShAmt);
8248 ShX = DAG.getNode(ISD::SHL, DL, VT, X, IsFSHL ? ShAmt : InvShAmt);
8249 ShY = DAG.getNode(ISD::SRL, DL, VT, Y, IsFSHL ? InvShAmt : ShAmt);
8250 } else {
8251 // fshl: X << (Z % BW) | Y >> 1 >> (BW - 1 - (Z % BW))
8252 // fshr: X << 1 << (BW - 1 - (Z % BW)) | Y >> (Z % BW)
8253 SDValue Mask = DAG.getConstant(BW - 1, DL, ShVT);
8254 if (isPowerOf2_32(BW)) {
8255 // Z % BW -> Z & (BW - 1)
8256 ShAmt = DAG.getNode(ISD::AND, DL, ShVT, Z, Mask);
8257 // (BW - 1) - (Z % BW) -> ~Z & (BW - 1)
8258 InvShAmt = DAG.getNode(ISD::AND, DL, ShVT, DAG.getNOT(DL, Z, ShVT), Mask);
8259 } else {
8260 SDValue BitWidthC = DAG.getConstant(BW, DL, ShVT);
8261 ShAmt = DAG.getNode(ISD::UREM, DL, ShVT, Z, BitWidthC);
8262 InvShAmt = DAG.getNode(ISD::SUB, DL, ShVT, Mask, ShAmt);
8263 }
8264
8265 SDValue One = DAG.getConstant(1, DL, ShVT);
8266 if (IsFSHL) {
8267 ShX = DAG.getNode(ISD::SHL, DL, VT, X, ShAmt);
8268 SDValue ShY1 = DAG.getNode(ISD::SRL, DL, VT, Y, One);
8269 ShY = DAG.getNode(ISD::SRL, DL, VT, ShY1, InvShAmt);
8270 } else {
8271 SDValue ShX1 = DAG.getNode(ISD::SHL, DL, VT, X, One);
8272 ShX = DAG.getNode(ISD::SHL, DL, VT, ShX1, InvShAmt);
8273 ShY = DAG.getNode(ISD::SRL, DL, VT, Y, ShAmt);
8274 }
8275 }
8276 return DAG.getNode(ISD::OR, DL, VT, ShX, ShY);
8277 }
8278
8279 // TODO: Merge with expandFunnelShift.
8280 SDValue TargetLowering::expandROT(SDNode *Node, bool AllowVectorOps,
8281 SelectionDAG &DAG) const {
8282 EVT VT = Node->getValueType(0);
8283 unsigned EltSizeInBits = VT.getScalarSizeInBits();
8284 bool IsLeft = Node->getOpcode() == ISD::ROTL;
8285 SDValue Op0 = Node->getOperand(0);
8286 SDValue Op1 = Node->getOperand(1);
8287 SDLoc DL(SDValue(Node, 0));
8288
8289 EVT ShVT = Op1.getValueType();
8290 SDValue Zero = DAG.getConstant(0, DL, ShVT);
8291
8292   // If a rotate in the other direction is better supported, use it.
8293 unsigned RevRot = IsLeft ? ISD::ROTR : ISD::ROTL;
8294 if (!isOperationLegalOrCustom(Node->getOpcode(), VT) &&
8295 isOperationLegalOrCustom(RevRot, VT) && isPowerOf2_32(EltSizeInBits)) {
8296 SDValue Sub = DAG.getNode(ISD::SUB, DL, ShVT, Zero, Op1);
8297 return DAG.getNode(RevRot, DL, VT, Op0, Sub);
8298 }
8299
8300 if (!AllowVectorOps && VT.isVector() &&
8301 (!isOperationLegalOrCustom(ISD::SHL, VT) ||
8302 !isOperationLegalOrCustom(ISD::SRL, VT) ||
8303 !isOperationLegalOrCustom(ISD::SUB, VT) ||
8304 !isOperationLegalOrCustomOrPromote(ISD::OR, VT) ||
8305 !isOperationLegalOrCustomOrPromote(ISD::AND, VT)))
8306 return SDValue();
8307
8308 unsigned ShOpc = IsLeft ? ISD::SHL : ISD::SRL;
8309 unsigned HsOpc = IsLeft ? ISD::SRL : ISD::SHL;
8310 SDValue BitWidthMinusOneC = DAG.getConstant(EltSizeInBits - 1, DL, ShVT);
8311 SDValue ShVal;
8312 SDValue HsVal;
8313 if (isPowerOf2_32(EltSizeInBits)) {
8314 // (rotl x, c) -> x << (c & (w - 1)) | x >> (-c & (w - 1))
8315 // (rotr x, c) -> x >> (c & (w - 1)) | x << (-c & (w - 1))
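    // For example, for i8: rotl x, 3 --> (x << 3) | (x >> 5), since
    // (-3 & 7) == 5.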
8316 SDValue NegOp1 = DAG.getNode(ISD::SUB, DL, ShVT, Zero, Op1);
8317 SDValue ShAmt = DAG.getNode(ISD::AND, DL, ShVT, Op1, BitWidthMinusOneC);
8318 ShVal = DAG.getNode(ShOpc, DL, VT, Op0, ShAmt);
8319 SDValue HsAmt = DAG.getNode(ISD::AND, DL, ShVT, NegOp1, BitWidthMinusOneC);
8320 HsVal = DAG.getNode(HsOpc, DL, VT, Op0, HsAmt);
8321 } else {
8322 // (rotl x, c) -> x << (c % w) | x >> 1 >> (w - 1 - (c % w))
8323 // (rotr x, c) -> x >> (c % w) | x << 1 << (w - 1 - (c % w))
8324 SDValue BitWidthC = DAG.getConstant(EltSizeInBits, DL, ShVT);
8325 SDValue ShAmt = DAG.getNode(ISD::UREM, DL, ShVT, Op1, BitWidthC);
8326 ShVal = DAG.getNode(ShOpc, DL, VT, Op0, ShAmt);
8327 SDValue HsAmt = DAG.getNode(ISD::SUB, DL, ShVT, BitWidthMinusOneC, ShAmt);
8328 SDValue One = DAG.getConstant(1, DL, ShVT);
8329 HsVal =
8330 DAG.getNode(HsOpc, DL, VT, DAG.getNode(HsOpc, DL, VT, Op0, One), HsAmt);
8331 }
8332 return DAG.getNode(ISD::OR, DL, VT, ShVal, HsVal);
8333 }
8334
8335 void TargetLowering::expandShiftParts(SDNode *Node, SDValue &Lo, SDValue &Hi,
8336 SelectionDAG &DAG) const {
8337 assert(Node->getNumOperands() == 3 && "Not a double-shift!");
8338 EVT VT = Node->getValueType(0);
8339 unsigned VTBits = VT.getScalarSizeInBits();
8340 assert(isPowerOf2_32(VTBits) && "Power-of-two integer type expected");
8341
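  // {SHL,SRL,SRA}_PARTS shift a (2 * VTBits)-bit value, given as a lo/hi pair
  // of VT operands, by ShAmt and return the resulting lo/hi pair.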
8342 bool IsSHL = Node->getOpcode() == ISD::SHL_PARTS;
8343 bool IsSRA = Node->getOpcode() == ISD::SRA_PARTS;
8344 SDValue ShOpLo = Node->getOperand(0);
8345 SDValue ShOpHi = Node->getOperand(1);
8346 SDValue ShAmt = Node->getOperand(2);
8347 EVT ShAmtVT = ShAmt.getValueType();
8348 EVT ShAmtCCVT =
8349 getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), ShAmtVT);
8350 SDLoc dl(Node);
8351
8352   // ISD::FSHL and ISD::FSHR have defined overflow behavior, but ISD::SHL and
8353   // ISD::SRA/SRL do not. Insert an AND to be safe; it's usually optimized
8354   // away during isel.
8355 SDValue SafeShAmt = DAG.getNode(ISD::AND, dl, ShAmtVT, ShAmt,
8356 DAG.getConstant(VTBits - 1, dl, ShAmtVT));
8357 SDValue Tmp1 = IsSRA ? DAG.getNode(ISD::SRA, dl, VT, ShOpHi,
8358 DAG.getConstant(VTBits - 1, dl, ShAmtVT))
8359 : DAG.getConstant(0, dl, VT);
8360
8361 SDValue Tmp2, Tmp3;
8362 if (IsSHL) {
8363 Tmp2 = DAG.getNode(ISD::FSHL, dl, VT, ShOpHi, ShOpLo, ShAmt);
8364 Tmp3 = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, SafeShAmt);
8365 } else {
8366 Tmp2 = DAG.getNode(ISD::FSHR, dl, VT, ShOpHi, ShOpLo, ShAmt);
8367 Tmp3 = DAG.getNode(IsSRA ? ISD::SRA : ISD::SRL, dl, VT, ShOpHi, SafeShAmt);
8368 }
8369
8370   // If the shift amount is greater than or equal to the width of a part, we
8371   // don't use the result from the FSHL/FSHR. Insert a test and select the
8372   // appropriate values for large shift amounts.
8373 SDValue AndNode = DAG.getNode(ISD::AND, dl, ShAmtVT, ShAmt,
8374 DAG.getConstant(VTBits, dl, ShAmtVT));
8375 SDValue Cond = DAG.getSetCC(dl, ShAmtCCVT, AndNode,
8376 DAG.getConstant(0, dl, ShAmtVT), ISD::SETNE);
8377
8378 if (IsSHL) {
8379 Hi = DAG.getNode(ISD::SELECT, dl, VT, Cond, Tmp3, Tmp2);
8380 Lo = DAG.getNode(ISD::SELECT, dl, VT, Cond, Tmp1, Tmp3);
8381 } else {
8382 Lo = DAG.getNode(ISD::SELECT, dl, VT, Cond, Tmp3, Tmp2);
8383 Hi = DAG.getNode(ISD::SELECT, dl, VT, Cond, Tmp1, Tmp3);
8384 }
8385 }
8386
8387 bool TargetLowering::expandFP_TO_SINT(SDNode *Node, SDValue &Result,
8388 SelectionDAG &DAG) const {
8389 unsigned OpNo = Node->isStrictFPOpcode() ? 1 : 0;
8390 SDValue Src = Node->getOperand(OpNo);
8391 EVT SrcVT = Src.getValueType();
8392 EVT DstVT = Node->getValueType(0);
8393 SDLoc dl(SDValue(Node, 0));
8394
8395 // FIXME: Only f32 to i64 conversions are supported.
8396 if (SrcVT != MVT::f32 || DstVT != MVT::i64)
8397 return false;
8398
8399 if (Node->isStrictFPOpcode())
8400 // When a NaN is converted to an integer a trap is allowed. We can't
8401 // use this expansion here because it would eliminate that trap. Other
8402 // traps are also allowed and cannot be eliminated. See
8403 // IEEE 754-2008 sec 5.8.
8404 return false;
8405
8406 // Expand f32 -> i64 conversion
8407 // This algorithm comes from compiler-rt's implementation of fixsfdi:
8408 // https://github.com/llvm/llvm-project/blob/main/compiler-rt/lib/builtins/fixsfdi.c
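  // The idea: extract the biased exponent and the mantissa (with the implicit
  // leading 1 made explicit), shift the mantissa left or right depending on
  // whether the unbiased exponent is above or below 23, and then apply the
  // sign. E.g. for 1.0f (bits 0x3F800000): exponent = 127 - 127 = 0 and
  // R = 0x00800000, so the result is 0x00800000 >> 23 == 1.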
8409 unsigned SrcEltBits = SrcVT.getScalarSizeInBits();
8410 EVT IntVT = SrcVT.changeTypeToInteger();
8411 EVT IntShVT = getShiftAmountTy(IntVT, DAG.getDataLayout());
8412
8413 SDValue ExponentMask = DAG.getConstant(0x7F800000, dl, IntVT);
8414 SDValue ExponentLoBit = DAG.getConstant(23, dl, IntVT);
8415 SDValue Bias = DAG.getConstant(127, dl, IntVT);
8416 SDValue SignMask = DAG.getConstant(APInt::getSignMask(SrcEltBits), dl, IntVT);
8417 SDValue SignLowBit = DAG.getConstant(SrcEltBits - 1, dl, IntVT);
8418 SDValue MantissaMask = DAG.getConstant(0x007FFFFF, dl, IntVT);
8419
8420 SDValue Bits = DAG.getNode(ISD::BITCAST, dl, IntVT, Src);
8421
8422 SDValue ExponentBits = DAG.getNode(
8423 ISD::SRL, dl, IntVT, DAG.getNode(ISD::AND, dl, IntVT, Bits, ExponentMask),
8424 DAG.getZExtOrTrunc(ExponentLoBit, dl, IntShVT));
8425 SDValue Exponent = DAG.getNode(ISD::SUB, dl, IntVT, ExponentBits, Bias);
8426
8427 SDValue Sign = DAG.getNode(ISD::SRA, dl, IntVT,
8428 DAG.getNode(ISD::AND, dl, IntVT, Bits, SignMask),
8429 DAG.getZExtOrTrunc(SignLowBit, dl, IntShVT));
8430 Sign = DAG.getSExtOrTrunc(Sign, dl, DstVT);
8431
8432 SDValue R = DAG.getNode(ISD::OR, dl, IntVT,
8433 DAG.getNode(ISD::AND, dl, IntVT, Bits, MantissaMask),
8434 DAG.getConstant(0x00800000, dl, IntVT));
8435
8436 R = DAG.getZExtOrTrunc(R, dl, DstVT);
8437
8438 R = DAG.getSelectCC(
8439 dl, Exponent, ExponentLoBit,
8440 DAG.getNode(ISD::SHL, dl, DstVT, R,
8441 DAG.getZExtOrTrunc(
8442 DAG.getNode(ISD::SUB, dl, IntVT, Exponent, ExponentLoBit),
8443 dl, IntShVT)),
8444 DAG.getNode(ISD::SRL, dl, DstVT, R,
8445 DAG.getZExtOrTrunc(
8446 DAG.getNode(ISD::SUB, dl, IntVT, ExponentLoBit, Exponent),
8447 dl, IntShVT)),
8448 ISD::SETGT);
8449
8450 SDValue Ret = DAG.getNode(ISD::SUB, dl, DstVT,
8451 DAG.getNode(ISD::XOR, dl, DstVT, R, Sign), Sign);
8452
8453 Result = DAG.getSelectCC(dl, Exponent, DAG.getConstant(0, dl, IntVT),
8454 DAG.getConstant(0, dl, DstVT), Ret, ISD::SETLT);
8455 return true;
8456 }
8457
8458 bool TargetLowering::expandFP_TO_UINT(SDNode *Node, SDValue &Result,
8459 SDValue &Chain,
8460 SelectionDAG &DAG) const {
8461 SDLoc dl(SDValue(Node, 0));
8462 unsigned OpNo = Node->isStrictFPOpcode() ? 1 : 0;
8463 SDValue Src = Node->getOperand(OpNo);
8464
8465 EVT SrcVT = Src.getValueType();
8466 EVT DstVT = Node->getValueType(0);
8467 EVT SetCCVT =
8468 getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), SrcVT);
8469 EVT DstSetCCVT =
8470 getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), DstVT);
8471
8472 // Only expand vector types if we have the appropriate vector bit operations.
8473 unsigned SIntOpcode = Node->isStrictFPOpcode() ? ISD::STRICT_FP_TO_SINT :
8474 ISD::FP_TO_SINT;
8475 if (DstVT.isVector() && (!isOperationLegalOrCustom(SIntOpcode, DstVT) ||
8476 !isOperationLegalOrCustomOrPromote(ISD::XOR, SrcVT)))
8477 return false;
8478
8479   // If the maximum float value is smaller than the signed integer range,
8480 // the destination signmask can't be represented by the float, so we can
8481 // just use FP_TO_SINT directly.
8482 const fltSemantics &APFSem = SrcVT.getFltSemantics();
8483 APFloat APF(APFSem, APInt::getZero(SrcVT.getScalarSizeInBits()));
8484 APInt SignMask = APInt::getSignMask(DstVT.getScalarSizeInBits());
8485 if (APFloat::opOverflow &
8486 APF.convertFromAPInt(SignMask, false, APFloat::rmNearestTiesToEven)) {
8487 if (Node->isStrictFPOpcode()) {
8488 Result = DAG.getNode(ISD::STRICT_FP_TO_SINT, dl, { DstVT, MVT::Other },
8489 { Node->getOperand(0), Src });
8490 Chain = Result.getValue(1);
8491 } else
8492 Result = DAG.getNode(ISD::FP_TO_SINT, dl, DstVT, Src);
8493 return true;
8494 }
8495
8496   // Don't expand it if there isn't a cheap fsub instruction.
8497 if (!isOperationLegalOrCustom(
8498 Node->isStrictFPOpcode() ? ISD::STRICT_FSUB : ISD::FSUB, SrcVT))
8499 return false;
8500
8501 SDValue Cst = DAG.getConstantFP(APF, dl, SrcVT);
8502 SDValue Sel;
8503
8504 if (Node->isStrictFPOpcode()) {
8505 Sel = DAG.getSetCC(dl, SetCCVT, Src, Cst, ISD::SETLT,
8506 Node->getOperand(0), /*IsSignaling*/ true);
8507 Chain = Sel.getValue(1);
8508 } else {
8509 Sel = DAG.getSetCC(dl, SetCCVT, Src, Cst, ISD::SETLT);
8510 }
8511
8512 bool Strict = Node->isStrictFPOpcode() ||
8513 shouldUseStrictFP_TO_INT(SrcVT, DstVT, /*IsSigned*/ false);
8514
8515 if (Strict) {
8516     // Expand based on the maximum range of FP_TO_SINT: if the value exceeds the
8517     // signmask, then offset it (the result of which should be fully representable).
8518 // Sel = Src < 0x8000000000000000
8519 // FltOfs = select Sel, 0, 0x8000000000000000
8520 // IntOfs = select Sel, 0, 0x8000000000000000
8521 // Result = fp_to_sint(Src - FltOfs) ^ IntOfs
8522
8523 // TODO: Should any fast-math-flags be set for the FSUB?
8524 SDValue FltOfs = DAG.getSelect(dl, SrcVT, Sel,
8525 DAG.getConstantFP(0.0, dl, SrcVT), Cst);
8526 Sel = DAG.getBoolExtOrTrunc(Sel, dl, DstSetCCVT, DstVT);
8527 SDValue IntOfs = DAG.getSelect(dl, DstVT, Sel,
8528 DAG.getConstant(0, dl, DstVT),
8529 DAG.getConstant(SignMask, dl, DstVT));
8530 SDValue SInt;
8531 if (Node->isStrictFPOpcode()) {
8532 SDValue Val = DAG.getNode(ISD::STRICT_FSUB, dl, { SrcVT, MVT::Other },
8533 { Chain, Src, FltOfs });
8534 SInt = DAG.getNode(ISD::STRICT_FP_TO_SINT, dl, { DstVT, MVT::Other },
8535 { Val.getValue(1), Val });
8536 Chain = SInt.getValue(1);
8537 } else {
8538 SDValue Val = DAG.getNode(ISD::FSUB, dl, SrcVT, Src, FltOfs);
8539 SInt = DAG.getNode(ISD::FP_TO_SINT, dl, DstVT, Val);
8540 }
8541 Result = DAG.getNode(ISD::XOR, dl, DstVT, SInt, IntOfs);
8542 } else {
8543 // Expand based on maximum range of FP_TO_SINT:
8544 // True = fp_to_sint(Src)
8545 // False = 0x8000000000000000 + fp_to_sint(Src - 0x8000000000000000)
8546 // Result = select (Src < 0x8000000000000000), True, False
8547
8548 SDValue True = DAG.getNode(ISD::FP_TO_SINT, dl, DstVT, Src);
8549 // TODO: Should any fast-math-flags be set for the FSUB?
8550 SDValue False = DAG.getNode(ISD::FP_TO_SINT, dl, DstVT,
8551 DAG.getNode(ISD::FSUB, dl, SrcVT, Src, Cst));
8552 False = DAG.getNode(ISD::XOR, dl, DstVT, False,
8553 DAG.getConstant(SignMask, dl, DstVT));
8554 Sel = DAG.getBoolExtOrTrunc(Sel, dl, DstSetCCVT, DstVT);
8555 Result = DAG.getSelect(dl, DstVT, Sel, True, False);
8556 }
8557 return true;
8558 }
8559
8560 bool TargetLowering::expandUINT_TO_FP(SDNode *Node, SDValue &Result,
8561 SDValue &Chain, SelectionDAG &DAG) const {
8562   // This transform is not correct for converting 0 when the rounding mode is set
8563   // to round toward negative infinity, which will produce -0.0. So disable it
8564   // under strictfp.
8565 if (Node->isStrictFPOpcode())
8566 return false;
8567
8568 SDValue Src = Node->getOperand(0);
8569 EVT SrcVT = Src.getValueType();
8570 EVT DstVT = Node->getValueType(0);
8571
8572 // If the input is known to be non-negative and SINT_TO_FP is legal then use
8573 // it.
8574 if (Node->getFlags().hasNonNeg() &&
8575 isOperationLegalOrCustom(ISD::SINT_TO_FP, SrcVT)) {
8576 Result =
8577 DAG.getNode(ISD::SINT_TO_FP, SDLoc(Node), DstVT, Node->getOperand(0));
8578 return true;
8579 }
8580
8581 if (SrcVT.getScalarType() != MVT::i64 || DstVT.getScalarType() != MVT::f64)
8582 return false;
8583
8584 // Only expand vector types if we have the appropriate vector bit
8585 // operations.
8586 if (SrcVT.isVector() && (!isOperationLegalOrCustom(ISD::SRL, SrcVT) ||
8587 !isOperationLegalOrCustom(ISD::FADD, DstVT) ||
8588 !isOperationLegalOrCustom(ISD::FSUB, DstVT) ||
8589 !isOperationLegalOrCustomOrPromote(ISD::OR, SrcVT) ||
8590 !isOperationLegalOrCustomOrPromote(ISD::AND, SrcVT)))
8591 return false;
8592
8593 SDLoc dl(SDValue(Node, 0));
8594
8595 // Implementation of unsigned i64 to f64 following the algorithm in
8596 // __floatundidf in compiler_rt. This implementation performs rounding
8597 // correctly in all rounding modes with the exception of converting 0
8598 // when rounding toward negative infinity. In that case the fsub will
8599 // produce -0.0. This will be added to +0.0 and produce -0.0 which is
8600 // incorrect.
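// For reference only: a scalar C++ sketch of the same bit trick (assuming an
// IEEE double and std::bit_cast; the names here are illustrative, not LLVM
// APIs):
//   double floatundidf(uint64_t X) {
//     uint64_t Lo = X & 0xFFFFFFFFu;  // low 32 bits
//     uint64_t Hi = X >> 32;          // high 32 bits
//     double DLo = std::bit_cast<double>(Lo | 0x4330000000000000); // 2^52 + Lo
//     double DHi = std::bit_cast<double>(Hi | 0x4530000000000000); // 2^84 + Hi*2^32
//     // (2^84 + Hi*2^32) - (2^84 + 2^52) == Hi*2^32 - 2^52, so the final add
//     // yields Hi*2^32 + Lo exactly (modulo the -0.0 caveat noted above).
//     return DLo + (DHi - std::bit_cast<double>(0x4530000000100000));
//   }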
8601 SDValue TwoP52 = DAG.getConstant(UINT64_C(0x4330000000000000), dl, SrcVT);
8602 SDValue TwoP84PlusTwoP52 = DAG.getConstantFP(
8603 llvm::bit_cast<double>(UINT64_C(0x4530000000100000)), dl, DstVT);
8604 SDValue TwoP84 = DAG.getConstant(UINT64_C(0x4530000000000000), dl, SrcVT);
8605 SDValue LoMask = DAG.getConstant(UINT64_C(0x00000000FFFFFFFF), dl, SrcVT);
8606 SDValue HiShift = DAG.getShiftAmountConstant(32, SrcVT, dl);
8607
8608 SDValue Lo = DAG.getNode(ISD::AND, dl, SrcVT, Src, LoMask);
8609 SDValue Hi = DAG.getNode(ISD::SRL, dl, SrcVT, Src, HiShift);
8610 SDValue LoOr = DAG.getNode(ISD::OR, dl, SrcVT, Lo, TwoP52);
8611 SDValue HiOr = DAG.getNode(ISD::OR, dl, SrcVT, Hi, TwoP84);
8612 SDValue LoFlt = DAG.getBitcast(DstVT, LoOr);
8613 SDValue HiFlt = DAG.getBitcast(DstVT, HiOr);
8614 SDValue HiSub = DAG.getNode(ISD::FSUB, dl, DstVT, HiFlt, TwoP84PlusTwoP52);
8615 Result = DAG.getNode(ISD::FADD, dl, DstVT, LoFlt, HiSub);
8616 return true;
8617 }
8618
8619 SDValue
8620 TargetLowering::createSelectForFMINNUM_FMAXNUM(SDNode *Node,
8621 SelectionDAG &DAG) const {
8622 unsigned Opcode = Node->getOpcode();
8623 assert((Opcode == ISD::FMINNUM || Opcode == ISD::FMAXNUM ||
8624 Opcode == ISD::STRICT_FMINNUM || Opcode == ISD::STRICT_FMAXNUM) &&
8625 "Wrong opcode");
8626
8627 if (Node->getFlags().hasNoNaNs()) {
8628 ISD::CondCode Pred = Opcode == ISD::FMINNUM ? ISD::SETLT : ISD::SETGT;
8629 EVT VT = Node->getValueType(0);
8630 if ((!isCondCodeLegal(Pred, VT.getSimpleVT()) ||
8631 !isOperationLegalOrCustom(ISD::VSELECT, VT)) &&
8632 VT.isVector())
8633 return SDValue();
8634 SDValue Op1 = Node->getOperand(0);
8635 SDValue Op2 = Node->getOperand(1);
8636 SDValue SelCC = DAG.getSelectCC(SDLoc(Node), Op1, Op2, Op1, Op2, Pred);
8637 SelCC->setFlags(Node->getFlags());
8638 return SelCC;
8639 }
8640
8641 return SDValue();
8642 }
8643
8644 SDValue TargetLowering::expandFMINNUM_FMAXNUM(SDNode *Node,
8645 SelectionDAG &DAG) const {
8646 if (SDValue Expanded = expandVectorNaryOpBySplitting(Node, DAG))
8647 return Expanded;
8648
8649 EVT VT = Node->getValueType(0);
8650 if (VT.isScalableVector())
8651 report_fatal_error(
8652 "Expanding fminnum/fmaxnum for scalable vectors is undefined.");
8653
8654 SDLoc dl(Node);
8655 unsigned NewOp =
8656 Node->getOpcode() == ISD::FMINNUM ? ISD::FMINNUM_IEEE : ISD::FMAXNUM_IEEE;
8657
8658 if (isOperationLegalOrCustom(NewOp, VT)) {
8659 SDValue Quiet0 = Node->getOperand(0);
8660 SDValue Quiet1 = Node->getOperand(1);
8661
8662 if (!Node->getFlags().hasNoNaNs()) {
8663 // Insert FCANONICALIZE nodes if we might need to quiet a signaling NaN to
8664 // get the correct behavior.
8665 if (!DAG.isKnownNeverSNaN(Quiet0)) {
8666 Quiet0 = DAG.getNode(ISD::FCANONICALIZE, dl, VT, Quiet0,
8667 Node->getFlags());
8668 }
8669 if (!DAG.isKnownNeverSNaN(Quiet1)) {
8670 Quiet1 = DAG.getNode(ISD::FCANONICALIZE, dl, VT, Quiet1,
8671 Node->getFlags());
8672 }
8673 }
8674
8675 return DAG.getNode(NewOp, dl, VT, Quiet0, Quiet1, Node->getFlags());
8676 }
8677
8678 // If the target has FMINIMUM/FMAXIMUM but not FMINNUM/FMAXNUM, use it
8679 // instead when there are no NaNs and there cannot be an incompatible zero
8680 // comparison: at least one operand isn't +/-0.0, or signed zeros don't matter.
8681 if ((Node->getFlags().hasNoNaNs() ||
8682 (DAG.isKnownNeverNaN(Node->getOperand(0)) &&
8683 DAG.isKnownNeverNaN(Node->getOperand(1)))) &&
8684 (Node->getFlags().hasNoSignedZeros() ||
8685 DAG.isKnownNeverZeroFloat(Node->getOperand(0)) ||
8686 DAG.isKnownNeverZeroFloat(Node->getOperand(1)))) {
8687 unsigned IEEE2018Op =
8688 Node->getOpcode() == ISD::FMINNUM ? ISD::FMINIMUM : ISD::FMAXIMUM;
8689 if (isOperationLegalOrCustom(IEEE2018Op, VT))
8690 return DAG.getNode(IEEE2018Op, dl, VT, Node->getOperand(0),
8691 Node->getOperand(1), Node->getFlags());
8692 }
8693
8694 if (SDValue SelCC = createSelectForFMINNUM_FMAXNUM(Node, DAG))
8695 return SelCC;
8696
8697 return SDValue();
8698 }
8699
8700 SDValue TargetLowering::expandFMINIMUM_FMAXIMUM(SDNode *N,
8701 SelectionDAG &DAG) const {
8702 if (SDValue Expanded = expandVectorNaryOpBySplitting(N, DAG))
8703 return Expanded;
8704
8705 SDLoc DL(N);
8706 SDValue LHS = N->getOperand(0);
8707 SDValue RHS = N->getOperand(1);
8708 unsigned Opc = N->getOpcode();
8709 EVT VT = N->getValueType(0);
8710 EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
8711 bool IsMax = Opc == ISD::FMAXIMUM;
8712 SDNodeFlags Flags = N->getFlags();
8713
8714 // First, implement the comparison without NaN propagation. If no native fmin
8715 // or fmax is available, use a plain select with setcc instead.
8716 SDValue MinMax;
8717 unsigned CompOpcIeee = IsMax ? ISD::FMAXNUM_IEEE : ISD::FMINNUM_IEEE;
8718 unsigned CompOpc = IsMax ? ISD::FMAXNUM : ISD::FMINNUM;
8719
8720 // FIXME: We should probably define fminnum/fmaxnum variants with correct
8721 // signed zero behavior.
8722 bool MinMaxMustRespectOrderedZero = false;
8723
8724 if (isOperationLegalOrCustom(CompOpcIeee, VT)) {
8725 MinMax = DAG.getNode(CompOpcIeee, DL, VT, LHS, RHS, Flags);
8726 MinMaxMustRespectOrderedZero = true;
8727 } else if (isOperationLegalOrCustom(CompOpc, VT)) {
8728 MinMax = DAG.getNode(CompOpc, DL, VT, LHS, RHS, Flags);
8729 } else {
8730 if (VT.isVector() && !isOperationLegalOrCustom(ISD::VSELECT, VT))
8731 return DAG.UnrollVectorOp(N);
8732
8733 // Any NaN will be propagated later, so orderedness doesn't matter here.
8734 SDValue Compare =
8735 DAG.getSetCC(DL, CCVT, LHS, RHS, IsMax ? ISD::SETOGT : ISD::SETOLT);
8736 MinMax = DAG.getSelect(DL, VT, Compare, LHS, RHS, Flags);
8737 }
8738
8739 // Propagate a NaN if either operand is one.
8740 if (!N->getFlags().hasNoNaNs() &&
8741 (!DAG.isKnownNeverNaN(RHS) || !DAG.isKnownNeverNaN(LHS))) {
8742 ConstantFP *FPNaN = ConstantFP::get(*DAG.getContext(),
8743 APFloat::getNaN(VT.getFltSemantics()));
8744 MinMax = DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, LHS, RHS, ISD::SETUO),
8745 DAG.getConstantFP(*FPNaN, DL, VT), MinMax, Flags);
8746 }
8747
8748 // fminimum/fmaximum requires -0.0 to compare less than +0.0.
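// For example, fmaximum(+0.0, -0.0) must return +0.0, but the min/max or
// setcc-based select above may have picked either zero; the IS_FPCLASS-based
// selects below overwrite the result with the required zero in that case.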
8749 if (!MinMaxMustRespectOrderedZero && !N->getFlags().hasNoSignedZeros() &&
8750 !DAG.isKnownNeverZeroFloat(RHS) && !DAG.isKnownNeverZeroFloat(LHS)) {
8751 SDValue IsZero = DAG.getSetCC(DL, CCVT, MinMax,
8752 DAG.getConstantFP(0.0, DL, VT), ISD::SETOEQ);
8753 SDValue TestZero =
8754 DAG.getTargetConstant(IsMax ? fcPosZero : fcNegZero, DL, MVT::i32);
8755 SDValue LCmp = DAG.getSelect(
8756 DL, VT, DAG.getNode(ISD::IS_FPCLASS, DL, CCVT, LHS, TestZero), LHS,
8757 MinMax, Flags);
8758 SDValue RCmp = DAG.getSelect(
8759 DL, VT, DAG.getNode(ISD::IS_FPCLASS, DL, CCVT, RHS, TestZero), RHS,
8760 LCmp, Flags);
8761 MinMax = DAG.getSelect(DL, VT, IsZero, RCmp, MinMax, Flags);
8762 }
8763
8764 return MinMax;
8765 }
8766
8767 SDValue TargetLowering::expandFMINIMUMNUM_FMAXIMUMNUM(SDNode *Node,
8768 SelectionDAG &DAG) const {
8769 SDLoc DL(Node);
8770 SDValue LHS = Node->getOperand(0);
8771 SDValue RHS = Node->getOperand(1);
8772 unsigned Opc = Node->getOpcode();
8773 EVT VT = Node->getValueType(0);
8774 EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
8775 bool IsMax = Opc == ISD::FMAXIMUMNUM;
8776 const TargetOptions &Options = DAG.getTarget().Options;
8777 SDNodeFlags Flags = Node->getFlags();
8778
8779 unsigned NewOp =
8780 Opc == ISD::FMINIMUMNUM ? ISD::FMINNUM_IEEE : ISD::FMAXNUM_IEEE;
8781
8782 if (isOperationLegalOrCustom(NewOp, VT)) {
8783 if (!Flags.hasNoNaNs()) {
8784 // Insert FCANONICALIZE nodes if we might need to quiet a signaling NaN to
8785 // get the correct behavior.
8786 if (!DAG.isKnownNeverSNaN(LHS)) {
8787 LHS = DAG.getNode(ISD::FCANONICALIZE, DL, VT, LHS, Flags);
8788 }
8789 if (!DAG.isKnownNeverSNaN(RHS)) {
8790 RHS = DAG.getNode(ISD::FCANONICALIZE, DL, VT, RHS, Flags);
8791 }
8792 }
8793
8794 return DAG.getNode(NewOp, DL, VT, LHS, RHS, Flags);
8795 }
8796
8797 // We can use FMINIMUM/FMAXIMUM if neither operand can be a NaN, since they
8798 // behave the same in all other cases, +0.0 vs -0.0 included.
8799 if (Flags.hasNoNaNs() ||
8800 (DAG.isKnownNeverNaN(LHS) && DAG.isKnownNeverNaN(RHS))) {
8801 unsigned IEEE2019Op =
8802 Opc == ISD::FMINIMUMNUM ? ISD::FMINIMUM : ISD::FMAXIMUM;
8803 if (isOperationLegalOrCustom(IEEE2019Op, VT))
8804 return DAG.getNode(IEEE2019Op, DL, VT, LHS, RHS, Flags);
8805 }
8806
8807 // FMINNUM/FMAXNUM returns a qNaN if either operand is an sNaN, and it may
8808 // return either operand for +0.0 vs -0.0.
8809 if ((Flags.hasNoNaNs() ||
8810 (DAG.isKnownNeverSNaN(LHS) && DAG.isKnownNeverSNaN(RHS))) &&
8811 (Flags.hasNoSignedZeros() || DAG.isKnownNeverZeroFloat(LHS) ||
8812 DAG.isKnownNeverZeroFloat(RHS))) {
8813 unsigned IEEE2008Op = Opc == ISD::FMINIMUMNUM ? ISD::FMINNUM : ISD::FMAXNUM;
8814 if (isOperationLegalOrCustom(IEEE2008Op, VT))
8815 return DAG.getNode(IEEE2008Op, DL, VT, LHS, RHS, Flags);
8816 }
8817
8818 if (VT.isVector() && !isOperationLegalOrCustom(ISD::VSELECT, VT))
8819 return DAG.UnrollVectorOp(Node);
8820
8821 // If only one operand is a NaN, replace it with the other operand.
8822 if (!Flags.hasNoNaNs() && !DAG.isKnownNeverNaN(LHS)) {
8823 LHS = DAG.getSelectCC(DL, LHS, LHS, RHS, LHS, ISD::SETUO);
8824 }
8825 if (!Flags.hasNoNaNs() && !DAG.isKnownNeverNaN(RHS)) {
8826 RHS = DAG.getSelectCC(DL, RHS, RHS, LHS, RHS, ISD::SETUO);
8827 }
8828
8829 SDValue MinMax =
8830 DAG.getSelectCC(DL, LHS, RHS, LHS, RHS, IsMax ? ISD::SETGT : ISD::SETLT);
8831
8832 // TODO: We need quiet sNaN if strictfp.
8833
8834 // Fixup signed zero behavior.
8835 if (Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros() ||
8836 DAG.isKnownNeverZeroFloat(LHS) || DAG.isKnownNeverZeroFloat(RHS)) {
8837 return MinMax;
8838 }
8839 SDValue TestZero =
8840 DAG.getTargetConstant(IsMax ? fcPosZero : fcNegZero, DL, MVT::i32);
8841 SDValue IsZero = DAG.getSetCC(DL, CCVT, MinMax,
8842 DAG.getConstantFP(0.0, DL, VT), ISD::SETEQ);
8843 SDValue LCmp = DAG.getSelect(
8844 DL, VT, DAG.getNode(ISD::IS_FPCLASS, DL, CCVT, LHS, TestZero), LHS,
8845 MinMax, Flags);
8846 SDValue RCmp = DAG.getSelect(
8847 DL, VT, DAG.getNode(ISD::IS_FPCLASS, DL, CCVT, RHS, TestZero), RHS, LCmp,
8848 Flags);
8849 return DAG.getSelect(DL, VT, IsZero, RCmp, MinMax, Flags);
8850 }
8851
8852 /// Returns a true value if this FPClassTest can be performed with an ordered
8853 /// fcmp to 0, and a false value if it can be performed with an unordered fcmp
8854 /// to 0. Returns std::nullopt if it cannot be performed as a compare with 0.
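/// For example, with the default IEEE denormal-input handling, a test for
/// fcZero can be lowered to an ordered compare (x == 0.0), and fcZero|fcNan to
/// the corresponding unordered compare; anything else returns std::nullopt.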
8855 static std::optional<bool> isFCmpEqualZero(FPClassTest Test,
8856 const fltSemantics &Semantics,
8857 const MachineFunction &MF) {
8858 FPClassTest OrderedMask = Test & ~fcNan;
8859 FPClassTest NanTest = Test & fcNan;
8860 bool IsOrdered = NanTest == fcNone;
8861 bool IsUnordered = NanTest == fcNan;
8862
8863 // Skip cases that are testing for only a qnan or snan.
8864 if (!IsOrdered && !IsUnordered)
8865 return std::nullopt;
8866
8867 if (OrderedMask == fcZero &&
8868 MF.getDenormalMode(Semantics).Input == DenormalMode::IEEE)
8869 return IsOrdered;
8870 if (OrderedMask == (fcZero | fcSubnormal) &&
8871 MF.getDenormalMode(Semantics).inputsAreZero())
8872 return IsOrdered;
8873 return std::nullopt;
8874 }
8875
8876 SDValue TargetLowering::expandIS_FPCLASS(EVT ResultVT, SDValue Op,
8877 const FPClassTest OrigTestMask,
8878 SDNodeFlags Flags, const SDLoc &DL,
8879 SelectionDAG &DAG) const {
8880 EVT OperandVT = Op.getValueType();
8881 assert(OperandVT.isFloatingPoint());
8882 FPClassTest Test = OrigTestMask;
8883
8884 // Degenerate cases.
8885 if (Test == fcNone)
8886 return DAG.getBoolConstant(false, DL, ResultVT, OperandVT);
8887 if (Test == fcAllFlags)
8888 return DAG.getBoolConstant(true, DL, ResultVT, OperandVT);
8889
8890 // PPC double-double is a pair of doubles; the high part determines the value
8891 // class.
8892 if (OperandVT == MVT::ppcf128) {
8893 Op = DAG.getNode(ISD::EXTRACT_ELEMENT, DL, MVT::f64, Op,
8894 DAG.getConstant(1, DL, MVT::i32));
8895 OperandVT = MVT::f64;
8896 }
8897
8898 // Floating-point type properties.
8899 EVT ScalarFloatVT = OperandVT.getScalarType();
8900 const Type *FloatTy = ScalarFloatVT.getTypeForEVT(*DAG.getContext());
8901 const llvm::fltSemantics &Semantics = FloatTy->getFltSemantics();
8902 bool IsF80 = (ScalarFloatVT == MVT::f80);
8903
8904 // Some checks can be implemented using float comparisons, if floating point
8905 // exceptions are ignored.
8906 if (Flags.hasNoFPExcept() &&
8907 isOperationLegalOrCustom(ISD::SETCC, OperandVT.getScalarType())) {
8908 FPClassTest FPTestMask = Test;
8909 bool IsInvertedFP = false;
8910
8911 if (FPClassTest InvertedFPCheck =
8912 invertFPClassTestIfSimpler(FPTestMask, true)) {
8913 FPTestMask = InvertedFPCheck;
8914 IsInvertedFP = true;
8915 }
8916
8917 ISD::CondCode OrderedCmpOpcode = IsInvertedFP ? ISD::SETUNE : ISD::SETOEQ;
8918 ISD::CondCode UnorderedCmpOpcode = IsInvertedFP ? ISD::SETONE : ISD::SETUEQ;
8919
8920 // See if we can fold an "| fcNan" into an unordered compare.
8921 FPClassTest OrderedFPTestMask = FPTestMask & ~fcNan;
8922
8923 // Can't fold the ordered check if we're only testing for snan or qnan
8924 // individually.
8925 if ((FPTestMask & fcNan) != fcNan)
8926 OrderedFPTestMask = FPTestMask;
8927
8928 const bool IsOrdered = FPTestMask == OrderedFPTestMask;
8929
8930 if (std::optional<bool> IsCmp0 =
8931 isFCmpEqualZero(FPTestMask, Semantics, DAG.getMachineFunction());
8932 IsCmp0 && (isCondCodeLegalOrCustom(
8933 *IsCmp0 ? OrderedCmpOpcode : UnorderedCmpOpcode,
8934 OperandVT.getScalarType().getSimpleVT()))) {
8935
8936 // If denormals could be implicitly treated as 0, this is not equivalent
8937 // to a compare with 0 since it will also be true for denormals.
8938 return DAG.getSetCC(DL, ResultVT, Op,
8939 DAG.getConstantFP(0.0, DL, OperandVT),
8940 *IsCmp0 ? OrderedCmpOpcode : UnorderedCmpOpcode);
8941 }
8942
8943 if (FPTestMask == fcNan &&
8944 isCondCodeLegalOrCustom(IsInvertedFP ? ISD::SETO : ISD::SETUO,
8945 OperandVT.getScalarType().getSimpleVT()))
8946 return DAG.getSetCC(DL, ResultVT, Op, Op,
8947 IsInvertedFP ? ISD::SETO : ISD::SETUO);
8948
8949 bool IsOrderedInf = FPTestMask == fcInf;
8950 if ((FPTestMask == fcInf || FPTestMask == (fcInf | fcNan)) &&
8951 isCondCodeLegalOrCustom(IsOrderedInf ? OrderedCmpOpcode
8952 : UnorderedCmpOpcode,
8953 OperandVT.getScalarType().getSimpleVT()) &&
8954 isOperationLegalOrCustom(ISD::FABS, OperandVT.getScalarType()) &&
8955 (isOperationLegal(ISD::ConstantFP, OperandVT.getScalarType()) ||
8956 (OperandVT.isVector() &&
8957 isOperationLegalOrCustom(ISD::BUILD_VECTOR, OperandVT)))) {
8958 // isinf(x) --> fabs(x) == inf
8959 SDValue Abs = DAG.getNode(ISD::FABS, DL, OperandVT, Op);
8960 SDValue Inf =
8961 DAG.getConstantFP(APFloat::getInf(Semantics), DL, OperandVT);
8962 return DAG.getSetCC(DL, ResultVT, Abs, Inf,
8963 IsOrderedInf ? OrderedCmpOpcode : UnorderedCmpOpcode);
8964 }
8965
8966 if ((OrderedFPTestMask == fcPosInf || OrderedFPTestMask == fcNegInf) &&
8967 isCondCodeLegalOrCustom(IsOrdered ? OrderedCmpOpcode
8968 : UnorderedCmpOpcode,
8969 OperandVT.getSimpleVT())) {
8970 // isposinf(x) --> x == inf
8971 // isneginf(x) --> x == -inf
8972 // isposinf(x) || nan --> x u== inf
8973 // isneginf(x) || nan --> x u== -inf
8974
8975 SDValue Inf = DAG.getConstantFP(
8976 APFloat::getInf(Semantics, OrderedFPTestMask == fcNegInf), DL,
8977 OperandVT);
8978 return DAG.getSetCC(DL, ResultVT, Op, Inf,
8979 IsOrdered ? OrderedCmpOpcode : UnorderedCmpOpcode);
8980 }
8981
8982 if (OrderedFPTestMask == (fcSubnormal | fcZero) && !IsOrdered) {
8983 // TODO: Could handle ordered case, but it produces worse code for
8984 // x86. Maybe handle ordered if fabs is free?
8985
8986 ISD::CondCode OrderedOp = IsInvertedFP ? ISD::SETUGE : ISD::SETOLT;
8987 ISD::CondCode UnorderedOp = IsInvertedFP ? ISD::SETOGE : ISD::SETULT;
8988
8989 if (isCondCodeLegalOrCustom(IsOrdered ? OrderedOp : UnorderedOp,
8990 OperandVT.getScalarType().getSimpleVT())) {
8991 // (issubnormal(x) || iszero(x)) --> fabs(x) < smallest_normal
8992
8993 // TODO: Maybe only makes sense if fabs is free. Integer test of
8994 // exponent bits seems better for x86.
8995 SDValue Abs = DAG.getNode(ISD::FABS, DL, OperandVT, Op);
8996 SDValue SmallestNormal = DAG.getConstantFP(
8997 APFloat::getSmallestNormalized(Semantics), DL, OperandVT);
8998 return DAG.getSetCC(DL, ResultVT, Abs, SmallestNormal,
8999 IsOrdered ? OrderedOp : UnorderedOp);
9000 }
9001 }
9002
9003 if (FPTestMask == fcNormal) {
9004 // TODO: Handle unordered
9005 ISD::CondCode IsFiniteOp = IsInvertedFP ? ISD::SETUGE : ISD::SETOLT;
9006 ISD::CondCode IsNormalOp = IsInvertedFP ? ISD::SETOLT : ISD::SETUGE;
9007
9008 if (isCondCodeLegalOrCustom(IsFiniteOp,
9009 OperandVT.getScalarType().getSimpleVT()) &&
9010 isCondCodeLegalOrCustom(IsNormalOp,
9011 OperandVT.getScalarType().getSimpleVT()) &&
9012 isFAbsFree(OperandVT)) {
9013 // isnormal(x) --> fabs(x) < infinity && !(fabs(x) < smallest_normal)
9014 SDValue Inf =
9015 DAG.getConstantFP(APFloat::getInf(Semantics), DL, OperandVT);
9016 SDValue SmallestNormal = DAG.getConstantFP(
9017 APFloat::getSmallestNormalized(Semantics), DL, OperandVT);
9018
9019 SDValue Abs = DAG.getNode(ISD::FABS, DL, OperandVT, Op);
9020 SDValue IsFinite = DAG.getSetCC(DL, ResultVT, Abs, Inf, IsFiniteOp);
9021 SDValue IsNormal =
9022 DAG.getSetCC(DL, ResultVT, Abs, SmallestNormal, IsNormalOp);
9023 unsigned LogicOp = IsInvertedFP ? ISD::OR : ISD::AND;
9024 return DAG.getNode(LogicOp, DL, ResultVT, IsFinite, IsNormal);
9025 }
9026 }
9027 }
9028
9029 // Some checks may be represented as the inversion of a simpler check, for
9030 // example "inf|normal|subnormal|zero" => !"nan".
9031 bool IsInverted = false;
9032
9033 if (FPClassTest InvertedCheck = invertFPClassTestIfSimpler(Test, false)) {
9034 Test = InvertedCheck;
9035 IsInverted = true;
9036 }
9037
9038 // In the general case use integer operations.
9039 unsigned BitSize = OperandVT.getScalarSizeInBits();
9040 EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), BitSize);
9041 if (OperandVT.isVector())
9042 IntVT = EVT::getVectorVT(*DAG.getContext(), IntVT,
9043 OperandVT.getVectorElementCount());
9044 SDValue OpAsInt = DAG.getBitcast(IntVT, Op);
9045
9046 // Various masks.
9047 APInt SignBit = APInt::getSignMask(BitSize);
9048 APInt ValueMask = APInt::getSignedMaxValue(BitSize); // All bits but sign.
9049 APInt Inf = APFloat::getInf(Semantics).bitcastToAPInt(); // Exp and int bit.
9050 const unsigned ExplicitIntBitInF80 = 63;
9051 APInt ExpMask = Inf;
9052 if (IsF80)
9053 ExpMask.clearBit(ExplicitIntBitInF80);
9054 APInt AllOneMantissa = APFloat::getLargest(Semantics).bitcastToAPInt() & ~Inf;
9055 APInt QNaNBitMask =
9056 APInt::getOneBitSet(BitSize, AllOneMantissa.getActiveBits() - 1);
9057 APInt InversionMask = APInt::getAllOnes(ResultVT.getScalarSizeInBits());
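// For instance, for f32 these masks work out to:
//   SignBit = 0x80000000, ValueMask = 0x7FFFFFFF, Inf = ExpMask = 0x7F800000,
//   AllOneMantissa = 0x007FFFFF, QNaNBitMask = 0x00400000.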
9058
9059 SDValue ValueMaskV = DAG.getConstant(ValueMask, DL, IntVT);
9060 SDValue SignBitV = DAG.getConstant(SignBit, DL, IntVT);
9061 SDValue ExpMaskV = DAG.getConstant(ExpMask, DL, IntVT);
9062 SDValue ZeroV = DAG.getConstant(0, DL, IntVT);
9063 SDValue InfV = DAG.getConstant(Inf, DL, IntVT);
9064 SDValue ResultInversionMask = DAG.getConstant(InversionMask, DL, ResultVT);
9065
9066 SDValue Res;
9067 const auto appendResult = [&](SDValue PartialRes) {
9068 if (PartialRes) {
9069 if (Res)
9070 Res = DAG.getNode(ISD::OR, DL, ResultVT, Res, PartialRes);
9071 else
9072 Res = PartialRes;
9073 }
9074 };
9075
9076 SDValue IntBitIsSetV; // Explicit integer bit in f80 mantissa is set.
9077 const auto getIntBitIsSet = [&]() -> SDValue {
9078 if (!IntBitIsSetV) {
9079 APInt IntBitMask(BitSize, 0);
9080 IntBitMask.setBit(ExplicitIntBitInF80);
9081 SDValue IntBitMaskV = DAG.getConstant(IntBitMask, DL, IntVT);
9082 SDValue IntBitV = DAG.getNode(ISD::AND, DL, IntVT, OpAsInt, IntBitMaskV);
9083 IntBitIsSetV = DAG.getSetCC(DL, ResultVT, IntBitV, ZeroV, ISD::SETNE);
9084 }
9085 return IntBitIsSetV;
9086 };
9087
9088 // Split the value into sign bit and absolute value.
9089 SDValue AbsV = DAG.getNode(ISD::AND, DL, IntVT, OpAsInt, ValueMaskV);
9090 SDValue SignV = DAG.getSetCC(DL, ResultVT, OpAsInt,
9091 DAG.getConstant(0, DL, IntVT), ISD::SETLT);
9092
9093 // Tests that involve more than one class should be processed first.
9094 SDValue PartialRes;
9095
9096 if (IsF80)
9097 ; // Detect finite numbers of f80 by checking individual classes because
9098 // they have different settings of the explicit integer bit.
9099 else if ((Test & fcFinite) == fcFinite) {
9100 // finite(V) ==> abs(V) < exp_mask
9101 PartialRes = DAG.getSetCC(DL, ResultVT, AbsV, ExpMaskV, ISD::SETLT);
9102 Test &= ~fcFinite;
9103 } else if ((Test & fcFinite) == fcPosFinite) {
9104 // finite(V) && V > 0 ==> V < exp_mask
9105 PartialRes = DAG.getSetCC(DL, ResultVT, OpAsInt, ExpMaskV, ISD::SETULT);
9106 Test &= ~fcPosFinite;
9107 } else if ((Test & fcFinite) == fcNegFinite) {
9108 // finite(V) && V < 0 ==> abs(V) < exp_mask && signbit == 1
9109 PartialRes = DAG.getSetCC(DL, ResultVT, AbsV, ExpMaskV, ISD::SETLT);
9110 PartialRes = DAG.getNode(ISD::AND, DL, ResultVT, PartialRes, SignV);
9111 Test &= ~fcNegFinite;
9112 }
9113 appendResult(PartialRes);
9114
9115 if (FPClassTest PartialCheck = Test & (fcZero | fcSubnormal)) {
9116 // fcZero | fcSubnormal => test all exponent bits are 0
9117 // TODO: Handle sign bit specific cases
9118 if (PartialCheck == (fcZero | fcSubnormal)) {
9119 SDValue ExpBits = DAG.getNode(ISD::AND, DL, IntVT, OpAsInt, ExpMaskV);
9120 SDValue ExpIsZero =
9121 DAG.getSetCC(DL, ResultVT, ExpBits, ZeroV, ISD::SETEQ);
9122 appendResult(ExpIsZero);
9123 Test &= ~PartialCheck & fcAllFlags;
9124 }
9125 }
9126
9127 // Check for individual classes.
9128
9129 if (unsigned PartialCheck = Test & fcZero) {
9130 if (PartialCheck == fcPosZero)
9131 PartialRes = DAG.getSetCC(DL, ResultVT, OpAsInt, ZeroV, ISD::SETEQ);
9132 else if (PartialCheck == fcZero)
9133 PartialRes = DAG.getSetCC(DL, ResultVT, AbsV, ZeroV, ISD::SETEQ);
9134 else // ISD::fcNegZero
9135 PartialRes = DAG.getSetCC(DL, ResultVT, OpAsInt, SignBitV, ISD::SETEQ);
9136 appendResult(PartialRes);
9137 }
9138
9139 if (unsigned PartialCheck = Test & fcSubnormal) {
9140 // issubnormal(V) ==> unsigned(abs(V) - 1) < (all mantissa bits set)
9141 // issubnormal(V) && V>0 ==> unsigned(V - 1) < (all mantissa bits set)
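// E.g. for f32, subnormals are exactly the encodings 1 .. 0x007FFFFF, and the
// single unsigned compare (abs(V) - 1) u< 0x007FFFFF covers that range: an
// abs value of zero wraps around to 0xFFFFFFFF and fails the test.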
9142 SDValue V = (PartialCheck == fcPosSubnormal) ? OpAsInt : AbsV;
9143 SDValue MantissaV = DAG.getConstant(AllOneMantissa, DL, IntVT);
9144 SDValue VMinusOneV =
9145 DAG.getNode(ISD::SUB, DL, IntVT, V, DAG.getConstant(1, DL, IntVT));
9146 PartialRes = DAG.getSetCC(DL, ResultVT, VMinusOneV, MantissaV, ISD::SETULT);
9147 if (PartialCheck == fcNegSubnormal)
9148 PartialRes = DAG.getNode(ISD::AND, DL, ResultVT, PartialRes, SignV);
9149 appendResult(PartialRes);
9150 }
9151
9152 if (unsigned PartialCheck = Test & fcInf) {
9153 if (PartialCheck == fcPosInf)
9154 PartialRes = DAG.getSetCC(DL, ResultVT, OpAsInt, InfV, ISD::SETEQ);
9155 else if (PartialCheck == fcInf)
9156 PartialRes = DAG.getSetCC(DL, ResultVT, AbsV, InfV, ISD::SETEQ);
9157 else { // ISD::fcNegInf
9158 APInt NegInf = APFloat::getInf(Semantics, true).bitcastToAPInt();
9159 SDValue NegInfV = DAG.getConstant(NegInf, DL, IntVT);
9160 PartialRes = DAG.getSetCC(DL, ResultVT, OpAsInt, NegInfV, ISD::SETEQ);
9161 }
9162 appendResult(PartialRes);
9163 }
9164
9165 if (unsigned PartialCheck = Test & fcNan) {
9166 APInt InfWithQnanBit = Inf | QNaNBitMask;
9167 SDValue InfWithQnanBitV = DAG.getConstant(InfWithQnanBit, DL, IntVT);
9168 if (PartialCheck == fcNan) {
9169 // isnan(V) ==> abs(V) > int(inf)
9170 PartialRes = DAG.getSetCC(DL, ResultVT, AbsV, InfV, ISD::SETGT);
9171 if (IsF80) {
9172 // Recognize unsupported values as NaNs for compatibility with glibc.
9173 // For such values, (exp(V) == 0) == int_bit.
9174 SDValue ExpBits = DAG.getNode(ISD::AND, DL, IntVT, AbsV, ExpMaskV);
9175 SDValue ExpIsZero =
9176 DAG.getSetCC(DL, ResultVT, ExpBits, ZeroV, ISD::SETEQ);
9177 SDValue IsPseudo =
9178 DAG.getSetCC(DL, ResultVT, getIntBitIsSet(), ExpIsZero, ISD::SETEQ);
9179 PartialRes = DAG.getNode(ISD::OR, DL, ResultVT, PartialRes, IsPseudo);
9180 }
9181 } else if (PartialCheck == fcQNan) {
9182 // isquiet(V) ==> abs(V) >= (unsigned(Inf) | quiet_bit)
9183 PartialRes =
9184 DAG.getSetCC(DL, ResultVT, AbsV, InfWithQnanBitV, ISD::SETGE);
9185 } else { // ISD::fcSNan
9186 // issignaling(V) ==> abs(V) > unsigned(Inf) &&
9187 // abs(V) < (unsigned(Inf) | quiet_bit)
9188 SDValue IsNan = DAG.getSetCC(DL, ResultVT, AbsV, InfV, ISD::SETGT);
9189 SDValue IsNotQnan =
9190 DAG.getSetCC(DL, ResultVT, AbsV, InfWithQnanBitV, ISD::SETLT);
9191 PartialRes = DAG.getNode(ISD::AND, DL, ResultVT, IsNan, IsNotQnan);
9192 }
9193 appendResult(PartialRes);
9194 }
9195
9196 if (unsigned PartialCheck = Test & fcNormal) {
9197 // isnormal(V) ==> (0 < exp < max_exp) ==> (unsigned(exp-1) < (max_exp-1))
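// E.g. for f32: ExpLSB = 0x00800000, so the check becomes
// (abs(V) - 0x00800000) u< 0x7F000000, which holds exactly for
// 0x00800000 <= abs(V) <= 0x7F7FFFFF, i.e. all normal encodings; zero and
// subnormals wrap around in the subtraction, and infinities/NaNs exceed the
// limit.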
9198 APInt ExpLSB = ExpMask & ~(ExpMask.shl(1));
9199 SDValue ExpLSBV = DAG.getConstant(ExpLSB, DL, IntVT);
9200 SDValue ExpMinus1 = DAG.getNode(ISD::SUB, DL, IntVT, AbsV, ExpLSBV);
9201 APInt ExpLimit = ExpMask - ExpLSB;
9202 SDValue ExpLimitV = DAG.getConstant(ExpLimit, DL, IntVT);
9203 PartialRes = DAG.getSetCC(DL, ResultVT, ExpMinus1, ExpLimitV, ISD::SETULT);
9204 if (PartialCheck == fcNegNormal)
9205 PartialRes = DAG.getNode(ISD::AND, DL, ResultVT, PartialRes, SignV);
9206 else if (PartialCheck == fcPosNormal) {
9207 SDValue PosSignV =
9208 DAG.getNode(ISD::XOR, DL, ResultVT, SignV, ResultInversionMask);
9209 PartialRes = DAG.getNode(ISD::AND, DL, ResultVT, PartialRes, PosSignV);
9210 }
9211 if (IsF80)
9212 PartialRes =
9213 DAG.getNode(ISD::AND, DL, ResultVT, PartialRes, getIntBitIsSet());
9214 appendResult(PartialRes);
9215 }
9216
9217 if (!Res)
9218 return DAG.getConstant(IsInverted, DL, ResultVT);
9219 if (IsInverted)
9220 Res = DAG.getNode(ISD::XOR, DL, ResultVT, Res, ResultInversionMask);
9221 return Res;
9222 }
9223
9224 // Only expand vector types if we have the appropriate vector bit operations.
9225 static bool canExpandVectorCTPOP(const TargetLowering &TLI, EVT VT) {
9226 assert(VT.isVector() && "Expected vector type");
9227 unsigned Len = VT.getScalarSizeInBits();
9228 return TLI.isOperationLegalOrCustom(ISD::ADD, VT) &&
9229 TLI.isOperationLegalOrCustom(ISD::SUB, VT) &&
9230 TLI.isOperationLegalOrCustom(ISD::SRL, VT) &&
9231 (Len == 8 || TLI.isOperationLegalOrCustom(ISD::MUL, VT)) &&
9232 TLI.isOperationLegalOrCustomOrPromote(ISD::AND, VT);
9233 }
9234
9235 SDValue TargetLowering::expandCTPOP(SDNode *Node, SelectionDAG &DAG) const {
9236 SDLoc dl(Node);
9237 EVT VT = Node->getValueType(0);
9238 EVT ShVT = getShiftAmountTy(VT, DAG.getDataLayout());
9239 SDValue Op = Node->getOperand(0);
9240 unsigned Len = VT.getScalarSizeInBits();
9241 assert(VT.isInteger() && "CTPOP not implemented for this type.");
9242
9243 // TODO: Add support for irregular type lengths.
9244 if (!(Len <= 128 && Len % 8 == 0))
9245 return SDValue();
9246
9247 // Only expand vector types if we have the appropriate vector bit operations.
9248 if (VT.isVector() && !canExpandVectorCTPOP(*this, VT))
9249 return SDValue();
9250
9251 // This is the "best" algorithm from
9252 // http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel
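// Scalar reference of the same expansion for a 32-bit value (illustrative
// only; the code below builds the splatted-constant equivalent per element):
//   uint32_t popcount32(uint32_t v) {
//     v = v - ((v >> 1) & 0x55555555);
//     v = (v & 0x33333333) + ((v >> 2) & 0x33333333);
//     v = (v + (v >> 4)) & 0x0F0F0F0F;
//     return (v * 0x01010101) >> 24;
//   }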
9253 SDValue Mask55 =
9254 DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x55)), dl, VT);
9255 SDValue Mask33 =
9256 DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x33)), dl, VT);
9257 SDValue Mask0F =
9258 DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x0F)), dl, VT);
9259
9260 // v = v - ((v >> 1) & 0x55555555...)
9261 Op = DAG.getNode(ISD::SUB, dl, VT, Op,
9262 DAG.getNode(ISD::AND, dl, VT,
9263 DAG.getNode(ISD::SRL, dl, VT, Op,
9264 DAG.getConstant(1, dl, ShVT)),
9265 Mask55));
9266 // v = (v & 0x33333333...) + ((v >> 2) & 0x33333333...)
9267 Op = DAG.getNode(ISD::ADD, dl, VT, DAG.getNode(ISD::AND, dl, VT, Op, Mask33),
9268 DAG.getNode(ISD::AND, dl, VT,
9269 DAG.getNode(ISD::SRL, dl, VT, Op,
9270 DAG.getConstant(2, dl, ShVT)),
9271 Mask33));
9272 // v = (v + (v >> 4)) & 0x0F0F0F0F...
9273 Op = DAG.getNode(ISD::AND, dl, VT,
9274 DAG.getNode(ISD::ADD, dl, VT, Op,
9275 DAG.getNode(ISD::SRL, dl, VT, Op,
9276 DAG.getConstant(4, dl, ShVT))),
9277 Mask0F);
9278
9279 if (Len <= 8)
9280 return Op;
9281
9282 // Avoid the multiply if we only have 2 bytes to add.
9283 // TODO: Only doing this for scalars because vectors weren't as obviously
9284 // improved.
9285 if (Len == 16 && !VT.isVector()) {
9286 // v = (v + (v >> 8)) & 0x00FF;
9287 return DAG.getNode(ISD::AND, dl, VT,
9288 DAG.getNode(ISD::ADD, dl, VT, Op,
9289 DAG.getNode(ISD::SRL, dl, VT, Op,
9290 DAG.getConstant(8, dl, ShVT))),
9291 DAG.getConstant(0xFF, dl, VT));
9292 }
9293
9294 // v = (v * 0x01010101...) >> (Len - 8)
9295 SDValue V;
9296 if (isOperationLegalOrCustomOrPromote(
9297 ISD::MUL, getTypeToTransformTo(*DAG.getContext(), VT))) {
9298 SDValue Mask01 =
9299 DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x01)), dl, VT);
9300 V = DAG.getNode(ISD::MUL, dl, VT, Op, Mask01);
9301 } else {
9302 V = Op;
9303 for (unsigned Shift = 8; Shift < Len; Shift *= 2) {
9304 SDValue ShiftC = DAG.getShiftAmountConstant(Shift, VT, dl);
9305 V = DAG.getNode(ISD::ADD, dl, VT, V,
9306 DAG.getNode(ISD::SHL, dl, VT, V, ShiftC));
9307 }
9308 }
9309 return DAG.getNode(ISD::SRL, dl, VT, V, DAG.getConstant(Len - 8, dl, ShVT));
9310 }
9311
9312 SDValue TargetLowering::expandVPCTPOP(SDNode *Node, SelectionDAG &DAG) const {
9313 SDLoc dl(Node);
9314 EVT VT = Node->getValueType(0);
9315 EVT ShVT = getShiftAmountTy(VT, DAG.getDataLayout());
9316 SDValue Op = Node->getOperand(0);
9317 SDValue Mask = Node->getOperand(1);
9318 SDValue VL = Node->getOperand(2);
9319 unsigned Len = VT.getScalarSizeInBits();
9320 assert(VT.isInteger() && "VP_CTPOP not implemented for this type.");
9321
9322 // TODO: Add support for irregular type lengths.
9323 if (!(Len <= 128 && Len % 8 == 0))
9324 return SDValue();
9325
9326 // This is the same algorithm as in expandCTPOP, from
9327 // http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel
9328 SDValue Mask55 =
9329 DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x55)), dl, VT);
9330 SDValue Mask33 =
9331 DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x33)), dl, VT);
9332 SDValue Mask0F =
9333 DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x0F)), dl, VT);
9334
9335 SDValue Tmp1, Tmp2, Tmp3, Tmp4, Tmp5;
9336
9337 // v = v - ((v >> 1) & 0x55555555...)
9338 Tmp1 = DAG.getNode(ISD::VP_AND, dl, VT,
9339 DAG.getNode(ISD::VP_SRL, dl, VT, Op,
9340 DAG.getConstant(1, dl, ShVT), Mask, VL),
9341 Mask55, Mask, VL);
9342 Op = DAG.getNode(ISD::VP_SUB, dl, VT, Op, Tmp1, Mask, VL);
9343
9344 // v = (v & 0x33333333...) + ((v >> 2) & 0x33333333...)
9345 Tmp2 = DAG.getNode(ISD::VP_AND, dl, VT, Op, Mask33, Mask, VL);
9346 Tmp3 = DAG.getNode(ISD::VP_AND, dl, VT,
9347 DAG.getNode(ISD::VP_SRL, dl, VT, Op,
9348 DAG.getConstant(2, dl, ShVT), Mask, VL),
9349 Mask33, Mask, VL);
9350 Op = DAG.getNode(ISD::VP_ADD, dl, VT, Tmp2, Tmp3, Mask, VL);
9351
9352 // v = (v + (v >> 4)) & 0x0F0F0F0F...
9353 Tmp4 = DAG.getNode(ISD::VP_SRL, dl, VT, Op, DAG.getConstant(4, dl, ShVT),
9354 Mask, VL);
9355 Tmp5 = DAG.getNode(ISD::VP_ADD, dl, VT, Op, Tmp4, Mask, VL);
9356 Op = DAG.getNode(ISD::VP_AND, dl, VT, Tmp5, Mask0F, Mask, VL);
9357
9358 if (Len <= 8)
9359 return Op;
9360
9361 // v = (v * 0x01010101...) >> (Len - 8)
9362 SDValue V;
9363 if (isOperationLegalOrCustomOrPromote(
9364 ISD::VP_MUL, getTypeToTransformTo(*DAG.getContext(), VT))) {
9365 SDValue Mask01 =
9366 DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x01)), dl, VT);
9367 V = DAG.getNode(ISD::VP_MUL, dl, VT, Op, Mask01, Mask, VL);
9368 } else {
9369 V = Op;
9370 for (unsigned Shift = 8; Shift < Len; Shift *= 2) {
9371 SDValue ShiftC = DAG.getShiftAmountConstant(Shift, VT, dl);
9372 V = DAG.getNode(ISD::VP_ADD, dl, VT, V,
9373 DAG.getNode(ISD::VP_SHL, dl, VT, V, ShiftC, Mask, VL),
9374 Mask, VL);
9375 }
9376 }
9377 return DAG.getNode(ISD::VP_SRL, dl, VT, V, DAG.getConstant(Len - 8, dl, ShVT),
9378 Mask, VL);
9379 }
9380
9381 SDValue TargetLowering::expandCTLZ(SDNode *Node, SelectionDAG &DAG) const {
9382 SDLoc dl(Node);
9383 EVT VT = Node->getValueType(0);
9384 EVT ShVT = getShiftAmountTy(VT, DAG.getDataLayout());
9385 SDValue Op = Node->getOperand(0);
9386 unsigned NumBitsPerElt = VT.getScalarSizeInBits();
9387
9388 // If the non-ZERO_UNDEF version is supported we can use that instead.
9389 if (Node->getOpcode() == ISD::CTLZ_ZERO_UNDEF &&
9390 isOperationLegalOrCustom(ISD::CTLZ, VT))
9391 return DAG.getNode(ISD::CTLZ, dl, VT, Op);
9392
9393 // If the ZERO_UNDEF version is supported use that and handle the zero case.
9394 if (isOperationLegalOrCustom(ISD::CTLZ_ZERO_UNDEF, VT)) {
9395 EVT SetCCVT =
9396 getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
9397 SDValue CTLZ = DAG.getNode(ISD::CTLZ_ZERO_UNDEF, dl, VT, Op);
9398 SDValue Zero = DAG.getConstant(0, dl, VT);
9399 SDValue SrcIsZero = DAG.getSetCC(dl, SetCCVT, Op, Zero, ISD::SETEQ);
9400 return DAG.getSelect(dl, VT, SrcIsZero,
9401 DAG.getConstant(NumBitsPerElt, dl, VT), CTLZ);
9402 }
9403
9404 // Only expand vector types if we have the appropriate vector bit operations.
9405 // This includes the operations needed to expand CTPOP if it isn't supported.
9406 if (VT.isVector() && (!isPowerOf2_32(NumBitsPerElt) ||
9407 (!isOperationLegalOrCustom(ISD::CTPOP, VT) &&
9408 !canExpandVectorCTPOP(*this, VT)) ||
9409 !isOperationLegalOrCustom(ISD::SRL, VT) ||
9410 !isOperationLegalOrCustomOrPromote(ISD::OR, VT)))
9411 return SDValue();
9412
9413 // for now, we do this:
9414 // x = x | (x >> 1);
9415 // x = x | (x >> 2);
9416 // ...
9417 // x = x | (x >>16);
9418 // x = x | (x >>32); // for 64-bit input
9419 // return popcount(~x);
9420 //
9421 // Ref: "Hacker's Delight" by Henry Warren
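// E.g. for a 32-bit input x = 0x00400000: the shift-or cascade smears the top
// set bit downwards, giving x = 0x007FFFFF; ~x = 0xFF800000 has exactly the 9
// leading bits set, so popcount(~x) = 9 = ctlz(0x00400000).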
9422 for (unsigned i = 0; (1U << i) < NumBitsPerElt; ++i) {
9423 SDValue Tmp = DAG.getConstant(1ULL << i, dl, ShVT);
9424 Op = DAG.getNode(ISD::OR, dl, VT, Op,
9425 DAG.getNode(ISD::SRL, dl, VT, Op, Tmp));
9426 }
9427 Op = DAG.getNOT(dl, Op, VT);
9428 return DAG.getNode(ISD::CTPOP, dl, VT, Op);
9429 }
9430
9431 SDValue TargetLowering::expandVPCTLZ(SDNode *Node, SelectionDAG &DAG) const {
9432 SDLoc dl(Node);
9433 EVT VT = Node->getValueType(0);
9434 EVT ShVT = getShiftAmountTy(VT, DAG.getDataLayout());
9435 SDValue Op = Node->getOperand(0);
9436 SDValue Mask = Node->getOperand(1);
9437 SDValue VL = Node->getOperand(2);
9438 unsigned NumBitsPerElt = VT.getScalarSizeInBits();
9439
9440 // do this:
9441 // x = x | (x >> 1);
9442 // x = x | (x >> 2);
9443 // ...
9444 // x = x | (x >>16);
9445 // x = x | (x >>32); // for 64-bit input
9446 // return popcount(~x);
9447 for (unsigned i = 0; (1U << i) < NumBitsPerElt; ++i) {
9448 SDValue Tmp = DAG.getConstant(1ULL << i, dl, ShVT);
9449 Op = DAG.getNode(ISD::VP_OR, dl, VT, Op,
9450 DAG.getNode(ISD::VP_SRL, dl, VT, Op, Tmp, Mask, VL), Mask,
9451 VL);
9452 }
9453 Op = DAG.getNode(ISD::VP_XOR, dl, VT, Op, DAG.getAllOnesConstant(dl, VT),
9454 Mask, VL);
9455 return DAG.getNode(ISD::VP_CTPOP, dl, VT, Op, Mask, VL);
9456 }
9457
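// The table lookup below uses the classic de Bruijn trick: (x & -x) isolates
// the lowest set bit, multiplying by the de Bruijn constant places a unique
// log2(BitWidth)-bit pattern in the top bits, and a small table maps that
// pattern back to the trailing-zero count. A 32-bit sketch (illustrative
// only; the table is the one the loop below materializes):
//   unsigned cttz32(uint32_t x) { // x != 0
//     static const uint8_t Tbl[32] = {/* Tbl[(0x077CB531u << i) >> 27] = i */};
//     return Tbl[((x & -x) * 0x077CB531u) >> 27];
//   }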
9458 SDValue TargetLowering::CTTZTableLookup(SDNode *Node, SelectionDAG &DAG,
9459 const SDLoc &DL, EVT VT, SDValue Op,
9460 unsigned BitWidth) const {
9461 if (BitWidth != 32 && BitWidth != 64)
9462 return SDValue();
9463 APInt DeBruijn = BitWidth == 32 ? APInt(32, 0x077CB531U)
9464 : APInt(64, 0x0218A392CD3D5DBFULL);
9465 const DataLayout &TD = DAG.getDataLayout();
9466 MachinePointerInfo PtrInfo =
9467 MachinePointerInfo::getConstantPool(DAG.getMachineFunction());
9468 unsigned ShiftAmt = BitWidth - Log2_32(BitWidth);
9469 SDValue Neg = DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), Op);
9470 SDValue Lookup = DAG.getNode(
9471 ISD::SRL, DL, VT,
9472 DAG.getNode(ISD::MUL, DL, VT, DAG.getNode(ISD::AND, DL, VT, Op, Neg),
9473 DAG.getConstant(DeBruijn, DL, VT)),
9474 DAG.getConstant(ShiftAmt, DL, VT));
9475 Lookup = DAG.getSExtOrTrunc(Lookup, DL, getPointerTy(TD));
9476
9477 SmallVector<uint8_t> Table(BitWidth, 0);
9478 for (unsigned i = 0; i < BitWidth; i++) {
9479 APInt Shl = DeBruijn.shl(i);
9480 APInt Lshr = Shl.lshr(ShiftAmt);
9481 Table[Lshr.getZExtValue()] = i;
9482 }
9483
9484 // Create a ConstantDataArray in the constant pool.
9485 auto *CA = ConstantDataArray::get(*DAG.getContext(), Table);
9486 SDValue CPIdx = DAG.getConstantPool(CA, getPointerTy(TD),
9487 TD.getPrefTypeAlign(CA->getType()));
9488 SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, DL, VT, DAG.getEntryNode(),
9489 DAG.getMemBasePlusOffset(CPIdx, Lookup, DL),
9490 PtrInfo, MVT::i8);
9491 if (Node->getOpcode() == ISD::CTTZ_ZERO_UNDEF)
9492 return ExtLoad;
9493
9494 EVT SetCCVT =
9495 getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
9496 SDValue Zero = DAG.getConstant(0, DL, VT);
9497 SDValue SrcIsZero = DAG.getSetCC(DL, SetCCVT, Op, Zero, ISD::SETEQ);
9498 return DAG.getSelect(DL, VT, SrcIsZero,
9499 DAG.getConstant(BitWidth, DL, VT), ExtLoad);
9500 }
9501
9502 SDValue TargetLowering::expandCTTZ(SDNode *Node, SelectionDAG &DAG) const {
9503 SDLoc dl(Node);
9504 EVT VT = Node->getValueType(0);
9505 SDValue Op = Node->getOperand(0);
9506 unsigned NumBitsPerElt = VT.getScalarSizeInBits();
9507
9508 // If the non-ZERO_UNDEF version is supported we can use that instead.
9509 if (Node->getOpcode() == ISD::CTTZ_ZERO_UNDEF &&
9510 isOperationLegalOrCustom(ISD::CTTZ, VT))
9511 return DAG.getNode(ISD::CTTZ, dl, VT, Op);
9512
9513 // If the ZERO_UNDEF version is supported use that and handle the zero case.
9514 if (isOperationLegalOrCustom(ISD::CTTZ_ZERO_UNDEF, VT)) {
9515 EVT SetCCVT =
9516 getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
9517 SDValue CTTZ = DAG.getNode(ISD::CTTZ_ZERO_UNDEF, dl, VT, Op);
9518 SDValue Zero = DAG.getConstant(0, dl, VT);
9519 SDValue SrcIsZero = DAG.getSetCC(dl, SetCCVT, Op, Zero, ISD::SETEQ);
9520 return DAG.getSelect(dl, VT, SrcIsZero,
9521 DAG.getConstant(NumBitsPerElt, dl, VT), CTTZ);
9522 }
9523
9524 // Only expand vector types if we have the appropriate vector bit operations.
9525 // This includes the operations needed to expand CTPOP if it isn't supported.
9526 if (VT.isVector() && (!isPowerOf2_32(NumBitsPerElt) ||
9527 (!isOperationLegalOrCustom(ISD::CTPOP, VT) &&
9528 !isOperationLegalOrCustom(ISD::CTLZ, VT) &&
9529 !canExpandVectorCTPOP(*this, VT)) ||
9530 !isOperationLegalOrCustom(ISD::SUB, VT) ||
9531 !isOperationLegalOrCustomOrPromote(ISD::AND, VT) ||
9532 !isOperationLegalOrCustomOrPromote(ISD::XOR, VT)))
9533 return SDValue();
9534
9535 // Emit a table lookup if the ISD::CTPOP used in the fallback path below would
9536 // itself be expanded or converted to a libcall.
9537 if (!VT.isVector() && !isOperationLegalOrCustomOrPromote(ISD::CTPOP, VT) &&
9538 !isOperationLegal(ISD::CTLZ, VT))
9539 if (SDValue V = CTTZTableLookup(Node, DAG, dl, VT, Op, NumBitsPerElt))
9540 return V;
9541
9542 // For now, we use: { return popcount(~x & (x - 1)); }
9543 // unless the target has ctlz but not ctpop, in which case we use:
9544 // { return bitwidth - nlz(~x & (x - 1)); }
9545 // Ref: "Hacker's Delight" by Henry Warren
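// E.g. x = 8 (0b1000): x - 1 = 0b0111 and ~x has the low three bits set, so
// ~x & (x - 1) = 0b0111 and popcount gives 3 = cttz(8).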
9546 SDValue Tmp = DAG.getNode(
9547 ISD::AND, dl, VT, DAG.getNOT(dl, Op, VT),
9548 DAG.getNode(ISD::SUB, dl, VT, Op, DAG.getConstant(1, dl, VT)));
9549
9550 // If ISD::CTLZ is legal and CTPOP isn't, then do that instead.
9551 if (isOperationLegal(ISD::CTLZ, VT) && !isOperationLegal(ISD::CTPOP, VT)) {
9552 return DAG.getNode(ISD::SUB, dl, VT, DAG.getConstant(NumBitsPerElt, dl, VT),
9553 DAG.getNode(ISD::CTLZ, dl, VT, Tmp));
9554 }
9555
9556 return DAG.getNode(ISD::CTPOP, dl, VT, Tmp);
9557 }
9558
9559 SDValue TargetLowering::expandVPCTTZ(SDNode *Node, SelectionDAG &DAG) const {
9560 SDValue Op = Node->getOperand(0);
9561 SDValue Mask = Node->getOperand(1);
9562 SDValue VL = Node->getOperand(2);
9563 SDLoc dl(Node);
9564 EVT VT = Node->getValueType(0);
9565
9566 // Same as the vector part of expandCTTZ, use: popcount(~x & (x - 1))
9567 SDValue Not = DAG.getNode(ISD::VP_XOR, dl, VT, Op,
9568 DAG.getAllOnesConstant(dl, VT), Mask, VL);
9569 SDValue MinusOne = DAG.getNode(ISD::VP_SUB, dl, VT, Op,
9570 DAG.getConstant(1, dl, VT), Mask, VL);
9571 SDValue Tmp = DAG.getNode(ISD::VP_AND, dl, VT, Not, MinusOne, Mask, VL);
9572 return DAG.getNode(ISD::VP_CTPOP, dl, VT, Tmp, Mask, VL);
9573 }
9574
9575 SDValue TargetLowering::expandVPCTTZElements(SDNode *N,
9576 SelectionDAG &DAG) const {
9577 // %cond = to_bool_vec %source
9578 // %splat = splat /*val=*/VL
9579 // %tz = step_vector
9580 // %v = vp.select %cond, /*true=*/%tz, /*false=*/%splat
9581 // %r = vp.reduce.umin %v
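// Illustrative example, assuming an all-true mask, EVL = 4 and
// %source = <0, 0, 1, 0>:
//   %splat = <4, 4, 4, 4>, %tz = <0, 1, 2, 3>, %v = <4, 4, 2, 4>,
//   %r = umin(%v) = 2, the index of the first active lane.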
9582 SDLoc DL(N);
9583 SDValue Source = N->getOperand(0);
9584 SDValue Mask = N->getOperand(1);
9585 SDValue EVL = N->getOperand(2);
9586 EVT SrcVT = Source.getValueType();
9587 EVT ResVT = N->getValueType(0);
9588 EVT ResVecVT =
9589 EVT::getVectorVT(*DAG.getContext(), ResVT, SrcVT.getVectorElementCount());
9590
9591 // Convert to boolean vector.
9592 if (SrcVT.getScalarType() != MVT::i1) {
9593 SDValue AllZero = DAG.getConstant(0, DL, SrcVT);
9594 SrcVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
9595 SrcVT.getVectorElementCount());
9596 Source = DAG.getNode(ISD::VP_SETCC, DL, SrcVT, Source, AllZero,
9597 DAG.getCondCode(ISD::SETNE), Mask, EVL);
9598 }
9599
9600 SDValue ExtEVL = DAG.getZExtOrTrunc(EVL, DL, ResVT);
9601 SDValue Splat = DAG.getSplat(ResVecVT, DL, ExtEVL);
9602 SDValue StepVec = DAG.getStepVector(DL, ResVecVT);
9603 SDValue Select =
9604 DAG.getNode(ISD::VP_SELECT, DL, ResVecVT, Source, StepVec, Splat, EVL);
9605 return DAG.getNode(ISD::VP_REDUCE_UMIN, DL, ResVT, ExtEVL, Select, Mask, EVL);
9606 }
9607
9608 SDValue TargetLowering::expandVectorFindLastActive(SDNode *N,
9609 SelectionDAG &DAG) const {
9610 SDLoc DL(N);
9611 SDValue Mask = N->getOperand(0);
9612 EVT MaskVT = Mask.getValueType();
9613 EVT BoolVT = MaskVT.getScalarType();
9614
9615 // Find a suitable type for a stepvector.
9616 ConstantRange VScaleRange(1, /*isFullSet=*/true); // Fixed length default.
9617 if (MaskVT.isScalableVector())
9618 VScaleRange = getVScaleRange(&DAG.getMachineFunction().getFunction(), 64);
9619 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
9620 unsigned EltWidth = TLI.getBitWidthForCttzElements(
9621 BoolVT.getTypeForEVT(*DAG.getContext()), MaskVT.getVectorElementCount(),
9622 /*ZeroIsPoison=*/true, &VScaleRange);
9623 EVT StepVT = MVT::getIntegerVT(EltWidth);
9624 EVT StepVecVT = MaskVT.changeVectorElementType(StepVT);
9625
9626 // If promotion is required to make the type legal, do it here; promotion of
9627 // integers within LegalizeVectorOps looks for types of the same total size
9628 // but with a smaller number of larger elements, not the usual promotion to a
9629 // larger element type with the same number of elements.
9630 if (TLI.getTypeAction(StepVecVT.getSimpleVT()) ==
9631 TargetLowering::TypePromoteInteger) {
9632 StepVecVT = TLI.getTypeToTransformTo(*DAG.getContext(), StepVecVT);
9633 StepVT = StepVecVT.getVectorElementType();
9634 }
9635
9636 // Zero out lanes with inactive elements, then find the highest remaining
9637 // value from the stepvector.
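// E.g. for Mask = <1, 0, 1, 0>: StepVec = <0, 1, 2, 3>,
// ActiveElts = <0, 0, 2, 0>, and VECREDUCE_UMAX yields 2, the index of the
// last active element. (When no lane is active, the reduction simply yields 0.)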
9638 SDValue Zeroes = DAG.getConstant(0, DL, StepVecVT);
9639 SDValue StepVec = DAG.getStepVector(DL, StepVecVT);
9640 SDValue ActiveElts = DAG.getSelect(DL, StepVecVT, Mask, StepVec, Zeroes);
9641 SDValue HighestIdx = DAG.getNode(ISD::VECREDUCE_UMAX, DL, StepVT, ActiveElts);
9642 return DAG.getZExtOrTrunc(HighestIdx, DL, N->getValueType(0));
9643 }
9644
9645 SDValue TargetLowering::expandABS(SDNode *N, SelectionDAG &DAG,
9646 bool IsNegative) const {
9647 SDLoc dl(N);
9648 EVT VT = N->getValueType(0);
9649 SDValue Op = N->getOperand(0);
9650
9651 // abs(x) -> smax(x,sub(0,x))
9652 if (!IsNegative && isOperationLegal(ISD::SUB, VT) &&
9653 isOperationLegal(ISD::SMAX, VT)) {
9654 SDValue Zero = DAG.getConstant(0, dl, VT);
9655 Op = DAG.getFreeze(Op);
9656 return DAG.getNode(ISD::SMAX, dl, VT, Op,
9657 DAG.getNode(ISD::SUB, dl, VT, Zero, Op));
9658 }
9659
9660 // abs(x) -> umin(x,sub(0,x))
9661 if (!IsNegative && isOperationLegal(ISD::SUB, VT) &&
9662 isOperationLegal(ISD::UMIN, VT)) {
9663 SDValue Zero = DAG.getConstant(0, dl, VT);
9664 Op = DAG.getFreeze(Op);
9665 return DAG.getNode(ISD::UMIN, dl, VT, Op,
9666 DAG.getNode(ISD::SUB, dl, VT, Zero, Op));
9667 }
9668
9669 // 0 - abs(x) -> smin(x, sub(0,x))
9670 if (IsNegative && isOperationLegal(ISD::SUB, VT) &&
9671 isOperationLegal(ISD::SMIN, VT)) {
9672 SDValue Zero = DAG.getConstant(0, dl, VT);
9673 Op = DAG.getFreeze(Op);
9674 return DAG.getNode(ISD::SMIN, dl, VT, Op,
9675 DAG.getNode(ISD::SUB, dl, VT, Zero, Op));
9676 }
9677
9678 // Only expand vector types if we have the appropriate vector operations.
9679 if (VT.isVector() &&
9680 (!isOperationLegalOrCustom(ISD::SRA, VT) ||
9681 (!IsNegative && !isOperationLegalOrCustom(ISD::ADD, VT)) ||
9682 (IsNegative && !isOperationLegalOrCustom(ISD::SUB, VT)) ||
9683 !isOperationLegalOrCustomOrPromote(ISD::XOR, VT)))
9684 return SDValue();
9685
9686 Op = DAG.getFreeze(Op);
9687 SDValue Shift = DAG.getNode(
9688 ISD::SRA, dl, VT, Op,
9689 DAG.getShiftAmountConstant(VT.getScalarSizeInBits() - 1, VT, dl));
9690 SDValue Xor = DAG.getNode(ISD::XOR, dl, VT, Op, Shift);
9691
9692 // abs(x) -> Y = sra (X, size(X)-1); sub (xor (X, Y), Y)
9693 if (!IsNegative)
9694 return DAG.getNode(ISD::SUB, dl, VT, Xor, Shift);
9695
9696 // 0 - abs(x) -> Y = sra (X, size(X)-1); sub (Y, xor (X, Y))
9697 return DAG.getNode(ISD::SUB, dl, VT, Shift, Xor);
9698 }
9699
9700 SDValue TargetLowering::expandABD(SDNode *N, SelectionDAG &DAG) const {
9701 SDLoc dl(N);
9702 EVT VT = N->getValueType(0);
9703 SDValue LHS = DAG.getFreeze(N->getOperand(0));
9704 SDValue RHS = DAG.getFreeze(N->getOperand(1));
9705 bool IsSigned = N->getOpcode() == ISD::ABDS;
9706
9707 // abds(lhs, rhs) -> sub(smax(lhs,rhs), smin(lhs,rhs))
9708 // abdu(lhs, rhs) -> sub(umax(lhs,rhs), umin(lhs,rhs))
9709 unsigned MaxOpc = IsSigned ? ISD::SMAX : ISD::UMAX;
9710 unsigned MinOpc = IsSigned ? ISD::SMIN : ISD::UMIN;
9711 if (isOperationLegal(MaxOpc, VT) && isOperationLegal(MinOpc, VT)) {
9712 SDValue Max = DAG.getNode(MaxOpc, dl, VT, LHS, RHS);
9713 SDValue Min = DAG.getNode(MinOpc, dl, VT, LHS, RHS);
9714 return DAG.getNode(ISD::SUB, dl, VT, Max, Min);
9715 }
9716
9717 // abdu(lhs, rhs) -> or(usubsat(lhs,rhs), usubsat(rhs,lhs))
9718 if (!IsSigned && isOperationLegal(ISD::USUBSAT, VT))
9719 return DAG.getNode(ISD::OR, dl, VT,
9720 DAG.getNode(ISD::USUBSAT, dl, VT, LHS, RHS),
9721 DAG.getNode(ISD::USUBSAT, dl, VT, RHS, LHS));
9722
9723 // If the subtract doesn't overflow then just use abs(sub())
9724 // NOTE: don't use frozen operands for value tracking.
9725 bool IsNonNegative = DAG.SignBitIsZero(N->getOperand(1)) &&
9726 DAG.SignBitIsZero(N->getOperand(0));
9727
9728 if (DAG.willNotOverflowSub(IsSigned || IsNonNegative, N->getOperand(0),
9729 N->getOperand(1)))
9730 return DAG.getNode(ISD::ABS, dl, VT,
9731 DAG.getNode(ISD::SUB, dl, VT, LHS, RHS));
9732
9733 if (DAG.willNotOverflowSub(IsSigned || IsNonNegative, N->getOperand(1),
9734 N->getOperand(0)))
9735 return DAG.getNode(ISD::ABS, dl, VT,
9736 DAG.getNode(ISD::SUB, dl, VT, RHS, LHS));
9737
9738 EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
9739 ISD::CondCode CC = IsSigned ? ISD::CondCode::SETGT : ISD::CondCode::SETUGT;
9740 SDValue Cmp = DAG.getSetCC(dl, CCVT, LHS, RHS, CC);
9741
9742 // Branchless expansion iff cmp result is allbits:
9743 // abds(lhs, rhs) -> sub(sgt(lhs, rhs), xor(sgt(lhs, rhs), sub(lhs, rhs)))
9744 // abdu(lhs, rhs) -> sub(ugt(lhs, rhs), xor(ugt(lhs, rhs), sub(lhs, rhs)))
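// E.g. unsigned i32, lhs = 3, rhs = 7: Cmp = ugt(3, 7) = 0, Diff = -4,
// Xor = -4, Result = 0 - (-4) = 4. With the operands swapped: Cmp = -1
// (all ones), Diff = 4, Xor = ~4 = -5, Result = -1 - (-5) = 4.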
9745 if (CCVT == VT && getBooleanContents(VT) == ZeroOrNegativeOneBooleanContent) {
9746 SDValue Diff = DAG.getNode(ISD::SUB, dl, VT, LHS, RHS);
9747 SDValue Xor = DAG.getNode(ISD::XOR, dl, VT, Diff, Cmp);
9748 return DAG.getNode(ISD::SUB, dl, VT, Cmp, Xor);
9749 }
9750
9751 // Similar to the branchless expansion, use the (sign-extended) usubo overflow
9752 // flag if the (scalar) type is illegal as this is more likely to legalize
9753 // cleanly:
9754 // abdu(lhs, rhs) -> sub(xor(sub(lhs, rhs), uof(lhs, rhs)), uof(lhs, rhs))
9755 if (!IsSigned && VT.isScalarInteger() && !isTypeLegal(VT)) {
9756 SDValue USubO =
9757 DAG.getNode(ISD::USUBO, dl, DAG.getVTList(VT, MVT::i1), {LHS, RHS});
9758 SDValue Cmp = DAG.getNode(ISD::SIGN_EXTEND, dl, VT, USubO.getValue(1));
9759 SDValue Xor = DAG.getNode(ISD::XOR, dl, VT, USubO.getValue(0), Cmp);
9760 return DAG.getNode(ISD::SUB, dl, VT, Xor, Cmp);
9761 }
9762
9763 // FIXME: Should really try to split the vector in case it's legal on a
9764 // subvector.
9765 if (VT.isVector() && !isOperationLegalOrCustom(ISD::VSELECT, VT))
9766 return DAG.UnrollVectorOp(N);
9767
9768 // abds(lhs, rhs) -> select(sgt(lhs,rhs), sub(lhs,rhs), sub(rhs,lhs))
9769 // abdu(lhs, rhs) -> select(ugt(lhs,rhs), sub(lhs,rhs), sub(rhs,lhs))
9770 return DAG.getSelect(dl, VT, Cmp, DAG.getNode(ISD::SUB, dl, VT, LHS, RHS),
9771 DAG.getNode(ISD::SUB, dl, VT, RHS, LHS));
9772 }
9773
9774 SDValue TargetLowering::expandAVG(SDNode *N, SelectionDAG &DAG) const {
9775 SDLoc dl(N);
9776 EVT VT = N->getValueType(0);
9777 SDValue LHS = N->getOperand(0);
9778 SDValue RHS = N->getOperand(1);
9779
9780 unsigned Opc = N->getOpcode();
9781 bool IsFloor = Opc == ISD::AVGFLOORS || Opc == ISD::AVGFLOORU;
9782 bool IsSigned = Opc == ISD::AVGCEILS || Opc == ISD::AVGFLOORS;
9783 unsigned SumOpc = IsFloor ? ISD::ADD : ISD::SUB;
9784 unsigned SignOpc = IsFloor ? ISD::AND : ISD::OR;
9785 unsigned ShiftOpc = IsSigned ? ISD::SRA : ISD::SRL;
9786 unsigned ExtOpc = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
9787 assert((Opc == ISD::AVGFLOORS || Opc == ISD::AVGCEILS ||
9788 Opc == ISD::AVGFLOORU || Opc == ISD::AVGCEILU) &&
9789 "Unknown AVG node");
9790
9791 // If the operands are already extended, we can add+shift.
9792 bool IsExt =
9793 (IsSigned && DAG.ComputeNumSignBits(LHS) >= 2 &&
9794 DAG.ComputeNumSignBits(RHS) >= 2) ||
9795 (!IsSigned && DAG.computeKnownBits(LHS).countMinLeadingZeros() >= 1 &&
9796 DAG.computeKnownBits(RHS).countMinLeadingZeros() >= 1);
9797 if (IsExt) {
9798 SDValue Sum = DAG.getNode(ISD::ADD, dl, VT, LHS, RHS);
9799 if (!IsFloor)
9800 Sum = DAG.getNode(ISD::ADD, dl, VT, Sum, DAG.getConstant(1, dl, VT));
9801 return DAG.getNode(ShiftOpc, dl, VT, Sum,
9802 DAG.getShiftAmountConstant(1, VT, dl));
9803 }
9804
9805 // For scalars, see if we can efficiently extend/truncate to use add+shift.
9806 if (VT.isScalarInteger()) {
9807 unsigned BW = VT.getScalarSizeInBits();
9808 EVT ExtVT = VT.getIntegerVT(*DAG.getContext(), 2 * BW);
9809 if (isTypeLegal(ExtVT) && isTruncateFree(ExtVT, VT)) {
9810 LHS = DAG.getNode(ExtOpc, dl, ExtVT, LHS);
9811 RHS = DAG.getNode(ExtOpc, dl, ExtVT, RHS);
9812 SDValue Avg = DAG.getNode(ISD::ADD, dl, ExtVT, LHS, RHS);
9813 if (!IsFloor)
9814 Avg = DAG.getNode(ISD::ADD, dl, ExtVT, Avg,
9815 DAG.getConstant(1, dl, ExtVT));
9816 // Just use SRL as we will be truncating away the extended sign bits.
9817 Avg = DAG.getNode(ISD::SRL, dl, ExtVT, Avg,
9818 DAG.getShiftAmountConstant(1, ExtVT, dl));
9819 return DAG.getNode(ISD::TRUNCATE, dl, VT, Avg);
9820 }
9821 }
9822
9823 // avgflooru(lhs, rhs) -> or(lshr(add(lhs, rhs),1),shl(overflow, typesize-1))
9824 if (Opc == ISD::AVGFLOORU && VT.isScalarInteger() && !isTypeLegal(VT)) {
9825 SDValue UAddWithOverflow =
9826 DAG.getNode(ISD::UADDO, dl, DAG.getVTList(VT, MVT::i1), {RHS, LHS});
9827
9828 SDValue Sum = UAddWithOverflow.getValue(0);
9829 SDValue Overflow = UAddWithOverflow.getValue(1);
9830
9831 // Right shift the sum by 1
9832 SDValue LShrVal = DAG.getNode(ISD::SRL, dl, VT, Sum,
9833 DAG.getShiftAmountConstant(1, VT, dl));
9834
9835 SDValue ZeroExtOverflow = DAG.getNode(ISD::ANY_EXTEND, dl, VT, Overflow);
9836 SDValue OverflowShl = DAG.getNode(
9837 ISD::SHL, dl, VT, ZeroExtOverflow,
9838 DAG.getShiftAmountConstant(VT.getScalarSizeInBits() - 1, VT, dl));
9839
9840 return DAG.getNode(ISD::OR, dl, VT, LShrVal, OverflowShl);
9841 }
9842
9843 // avgceils(lhs, rhs) -> sub(or(lhs,rhs),ashr(xor(lhs,rhs),1))
9844 // avgceilu(lhs, rhs) -> sub(or(lhs,rhs),lshr(xor(lhs,rhs),1))
9845 // avgfloors(lhs, rhs) -> add(and(lhs,rhs),ashr(xor(lhs,rhs),1))
9846 // avgflooru(lhs, rhs) -> add(and(lhs,rhs),lshr(xor(lhs,rhs),1))
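// E.g. avgflooru(5, 8): and = 0,  xor = 13, lshr = 6, result = 0 + 6  = 6.
//      avgceilu(5, 8):  or  = 13, xor = 13, lshr = 6, result = 13 - 6 = 7.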
9847 LHS = DAG.getFreeze(LHS);
9848 RHS = DAG.getFreeze(RHS);
9849 SDValue Sign = DAG.getNode(SignOpc, dl, VT, LHS, RHS);
9850 SDValue Xor = DAG.getNode(ISD::XOR, dl, VT, LHS, RHS);
9851 SDValue Shift =
9852 DAG.getNode(ShiftOpc, dl, VT, Xor, DAG.getShiftAmountConstant(1, VT, dl));
9853 return DAG.getNode(SumOpc, dl, VT, Sign, Shift);
9854 }
9855
9856 SDValue TargetLowering::expandBSWAP(SDNode *N, SelectionDAG &DAG) const {
9857 SDLoc dl(N);
9858 EVT VT = N->getValueType(0);
9859 SDValue Op = N->getOperand(0);
9860
9861 if (!VT.isSimple())
9862 return SDValue();
9863
9864 EVT SHVT = getShiftAmountTy(VT, DAG.getDataLayout());
9865 SDValue Tmp1, Tmp2, Tmp3, Tmp4, Tmp5, Tmp6, Tmp7, Tmp8;
9866 switch (VT.getSimpleVT().getScalarType().SimpleTy) {
9867 default:
9868 return SDValue();
9869 case MVT::i16:
9870 // Use a rotate by 8. This can be further expanded if necessary.
9871 return DAG.getNode(ISD::ROTL, dl, VT, Op, DAG.getConstant(8, dl, SHVT));
9872 case MVT::i32:
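// Byte moves for a 32-bit swap (Tmp4..Tmp1 carry input bytes 0..3 into result
// bytes 3..0): Tmp4 = x << 24, Tmp3 = (x & 0xFF00) << 8,
// Tmp2 = (x >> 8) & 0xFF00, Tmp1 = x >> 24, then OR everything together.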
9873 Tmp4 = DAG.getNode(ISD::SHL, dl, VT, Op, DAG.getConstant(24, dl, SHVT));
9874 Tmp3 = DAG.getNode(ISD::AND, dl, VT, Op,
9875 DAG.getConstant(0xFF00, dl, VT));
9876 Tmp3 = DAG.getNode(ISD::SHL, dl, VT, Tmp3, DAG.getConstant(8, dl, SHVT));
9877 Tmp2 = DAG.getNode(ISD::SRL, dl, VT, Op, DAG.getConstant(8, dl, SHVT));
9878 Tmp2 = DAG.getNode(ISD::AND, dl, VT, Tmp2, DAG.getConstant(0xFF00, dl, VT));
9879 Tmp1 = DAG.getNode(ISD::SRL, dl, VT, Op, DAG.getConstant(24, dl, SHVT));
9880 Tmp4 = DAG.getNode(ISD::OR, dl, VT, Tmp4, Tmp3);
9881 Tmp2 = DAG.getNode(ISD::OR, dl, VT, Tmp2, Tmp1);
9882 return DAG.getNode(ISD::OR, dl, VT, Tmp4, Tmp2);
9883 case MVT::i64:
9884 Tmp8 = DAG.getNode(ISD::SHL, dl, VT, Op, DAG.getConstant(56, dl, SHVT));
9885 Tmp7 = DAG.getNode(ISD::AND, dl, VT, Op,
9886 DAG.getConstant(255ULL<<8, dl, VT));
9887 Tmp7 = DAG.getNode(ISD::SHL, dl, VT, Tmp7, DAG.getConstant(40, dl, SHVT));
9888 Tmp6 = DAG.getNode(ISD::AND, dl, VT, Op,
9889 DAG.getConstant(255ULL<<16, dl, VT));
9890 Tmp6 = DAG.getNode(ISD::SHL, dl, VT, Tmp6, DAG.getConstant(24, dl, SHVT));
9891 Tmp5 = DAG.getNode(ISD::AND, dl, VT, Op,
9892 DAG.getConstant(255ULL<<24, dl, VT));
9893 Tmp5 = DAG.getNode(ISD::SHL, dl, VT, Tmp5, DAG.getConstant(8, dl, SHVT));
9894 Tmp4 = DAG.getNode(ISD::SRL, dl, VT, Op, DAG.getConstant(8, dl, SHVT));
9895 Tmp4 = DAG.getNode(ISD::AND, dl, VT, Tmp4,
9896 DAG.getConstant(255ULL<<24, dl, VT));
9897 Tmp3 = DAG.getNode(ISD::SRL, dl, VT, Op, DAG.getConstant(24, dl, SHVT));
9898 Tmp3 = DAG.getNode(ISD::AND, dl, VT, Tmp3,
9899 DAG.getConstant(255ULL<<16, dl, VT));
9900 Tmp2 = DAG.getNode(ISD::SRL, dl, VT, Op, DAG.getConstant(40, dl, SHVT));
9901 Tmp2 = DAG.getNode(ISD::AND, dl, VT, Tmp2,
9902 DAG.getConstant(255ULL<<8, dl, VT));
9903 Tmp1 = DAG.getNode(ISD::SRL, dl, VT, Op, DAG.getConstant(56, dl, SHVT));
9904 Tmp8 = DAG.getNode(ISD::OR, dl, VT, Tmp8, Tmp7);
9905 Tmp6 = DAG.getNode(ISD::OR, dl, VT, Tmp6, Tmp5);
9906 Tmp4 = DAG.getNode(ISD::OR, dl, VT, Tmp4, Tmp3);
9907 Tmp2 = DAG.getNode(ISD::OR, dl, VT, Tmp2, Tmp1);
9908 Tmp8 = DAG.getNode(ISD::OR, dl, VT, Tmp8, Tmp6);
9909 Tmp4 = DAG.getNode(ISD::OR, dl, VT, Tmp4, Tmp2);
9910 return DAG.getNode(ISD::OR, dl, VT, Tmp8, Tmp4);
9911 }
9912 }
9913
9914 SDValue TargetLowering::expandVPBSWAP(SDNode *N, SelectionDAG &DAG) const {
9915 SDLoc dl(N);
9916 EVT VT = N->getValueType(0);
9917 SDValue Op = N->getOperand(0);
9918 SDValue Mask = N->getOperand(1);
9919 SDValue EVL = N->getOperand(2);
9920
9921 if (!VT.isSimple())
9922 return SDValue();
9923
9924 EVT SHVT = getShiftAmountTy(VT, DAG.getDataLayout());
9925 SDValue Tmp1, Tmp2, Tmp3, Tmp4, Tmp5, Tmp6, Tmp7, Tmp8;
9926 switch (VT.getSimpleVT().getScalarType().SimpleTy) {
9927 default:
9928 return SDValue();
9929 case MVT::i16:
9930 Tmp1 = DAG.getNode(ISD::VP_SHL, dl, VT, Op, DAG.getConstant(8, dl, SHVT),
9931 Mask, EVL);
9932 Tmp2 = DAG.getNode(ISD::VP_SRL, dl, VT, Op, DAG.getConstant(8, dl, SHVT),
9933 Mask, EVL);
9934 return DAG.getNode(ISD::VP_OR, dl, VT, Tmp1, Tmp2, Mask, EVL);
9935 case MVT::i32:
9936 Tmp4 = DAG.getNode(ISD::VP_SHL, dl, VT, Op, DAG.getConstant(24, dl, SHVT),
9937 Mask, EVL);
9938 Tmp3 = DAG.getNode(ISD::VP_AND, dl, VT, Op, DAG.getConstant(0xFF00, dl, VT),
9939 Mask, EVL);
9940 Tmp3 = DAG.getNode(ISD::VP_SHL, dl, VT, Tmp3, DAG.getConstant(8, dl, SHVT),
9941 Mask, EVL);
9942 Tmp2 = DAG.getNode(ISD::VP_SRL, dl, VT, Op, DAG.getConstant(8, dl, SHVT),
9943 Mask, EVL);
9944 Tmp2 = DAG.getNode(ISD::VP_AND, dl, VT, Tmp2,
9945 DAG.getConstant(0xFF00, dl, VT), Mask, EVL);
9946 Tmp1 = DAG.getNode(ISD::VP_SRL, dl, VT, Op, DAG.getConstant(24, dl, SHVT),
9947 Mask, EVL);
9948 Tmp4 = DAG.getNode(ISD::VP_OR, dl, VT, Tmp4, Tmp3, Mask, EVL);
9949 Tmp2 = DAG.getNode(ISD::VP_OR, dl, VT, Tmp2, Tmp1, Mask, EVL);
9950 return DAG.getNode(ISD::VP_OR, dl, VT, Tmp4, Tmp2, Mask, EVL);
9951 case MVT::i64:
9952 Tmp8 = DAG.getNode(ISD::VP_SHL, dl, VT, Op, DAG.getConstant(56, dl, SHVT),
9953 Mask, EVL);
9954 Tmp7 = DAG.getNode(ISD::VP_AND, dl, VT, Op,
9955 DAG.getConstant(255ULL << 8, dl, VT), Mask, EVL);
9956 Tmp7 = DAG.getNode(ISD::VP_SHL, dl, VT, Tmp7, DAG.getConstant(40, dl, SHVT),
9957 Mask, EVL);
9958 Tmp6 = DAG.getNode(ISD::VP_AND, dl, VT, Op,
9959 DAG.getConstant(255ULL << 16, dl, VT), Mask, EVL);
9960 Tmp6 = DAG.getNode(ISD::VP_SHL, dl, VT, Tmp6, DAG.getConstant(24, dl, SHVT),
9961 Mask, EVL);
9962 Tmp5 = DAG.getNode(ISD::VP_AND, dl, VT, Op,
9963 DAG.getConstant(255ULL << 24, dl, VT), Mask, EVL);
9964 Tmp5 = DAG.getNode(ISD::VP_SHL, dl, VT, Tmp5, DAG.getConstant(8, dl, SHVT),
9965 Mask, EVL);
9966 Tmp4 = DAG.getNode(ISD::VP_SRL, dl, VT, Op, DAG.getConstant(8, dl, SHVT),
9967 Mask, EVL);
9968 Tmp4 = DAG.getNode(ISD::VP_AND, dl, VT, Tmp4,
9969 DAG.getConstant(255ULL << 24, dl, VT), Mask, EVL);
9970 Tmp3 = DAG.getNode(ISD::VP_SRL, dl, VT, Op, DAG.getConstant(24, dl, SHVT),
9971 Mask, EVL);
9972 Tmp3 = DAG.getNode(ISD::VP_AND, dl, VT, Tmp3,
9973 DAG.getConstant(255ULL << 16, dl, VT), Mask, EVL);
9974 Tmp2 = DAG.getNode(ISD::VP_SRL, dl, VT, Op, DAG.getConstant(40, dl, SHVT),
9975 Mask, EVL);
9976 Tmp2 = DAG.getNode(ISD::VP_AND, dl, VT, Tmp2,
9977 DAG.getConstant(255ULL << 8, dl, VT), Mask, EVL);
9978 Tmp1 = DAG.getNode(ISD::VP_SRL, dl, VT, Op, DAG.getConstant(56, dl, SHVT),
9979 Mask, EVL);
9980 Tmp8 = DAG.getNode(ISD::VP_OR, dl, VT, Tmp8, Tmp7, Mask, EVL);
9981 Tmp6 = DAG.getNode(ISD::VP_OR, dl, VT, Tmp6, Tmp5, Mask, EVL);
9982 Tmp4 = DAG.getNode(ISD::VP_OR, dl, VT, Tmp4, Tmp3, Mask, EVL);
9983 Tmp2 = DAG.getNode(ISD::VP_OR, dl, VT, Tmp2, Tmp1, Mask, EVL);
9984 Tmp8 = DAG.getNode(ISD::VP_OR, dl, VT, Tmp8, Tmp6, Mask, EVL);
9985 Tmp4 = DAG.getNode(ISD::VP_OR, dl, VT, Tmp4, Tmp2, Mask, EVL);
9986 return DAG.getNode(ISD::VP_OR, dl, VT, Tmp8, Tmp4, Mask, EVL);
9987 }
9988 }
9989
9990 SDValue TargetLowering::expandBITREVERSE(SDNode *N, SelectionDAG &DAG) const {
9991 SDLoc dl(N);
9992 EVT VT = N->getValueType(0);
9993 SDValue Op = N->getOperand(0);
9994 EVT SHVT = getShiftAmountTy(VT, DAG.getDataLayout());
9995 unsigned Sz = VT.getScalarSizeInBits();
9996
9997 SDValue Tmp, Tmp2, Tmp3;
9998
9999 // If we can, perform BSWAP first, then mask+swap the i4 pairs, then the i2
10000 // pairs, and finally the i1 pairs.
10001 // TODO: We can easily support i4/i2 legal types if any target ever does.
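// For example, reversing the i8 value 0b10110001:
//   after the i4 swap: 0b00011011
//   after the i2 swap: 0b01001110
//   after the i1 swap: 0b10001101 (the bit-reversed input)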
10002 if (Sz >= 8 && isPowerOf2_32(Sz)) {
10003 // Create the masks - repeating the pattern every byte.
10004 APInt Mask4 = APInt::getSplat(Sz, APInt(8, 0x0F));
10005 APInt Mask2 = APInt::getSplat(Sz, APInt(8, 0x33));
10006 APInt Mask1 = APInt::getSplat(Sz, APInt(8, 0x55));
10007
10008 // BSWAP if the type is wider than a single byte.
10009 Tmp = (Sz > 8 ? DAG.getNode(ISD::BSWAP, dl, VT, Op) : Op);
10010
10011 // swap i4: ((V >> 4) & 0x0F) | ((V & 0x0F) << 4)
10012 Tmp2 = DAG.getNode(ISD::SRL, dl, VT, Tmp, DAG.getConstant(4, dl, SHVT));
10013 Tmp2 = DAG.getNode(ISD::AND, dl, VT, Tmp2, DAG.getConstant(Mask4, dl, VT));
10014 Tmp3 = DAG.getNode(ISD::AND, dl, VT, Tmp, DAG.getConstant(Mask4, dl, VT));
10015 Tmp3 = DAG.getNode(ISD::SHL, dl, VT, Tmp3, DAG.getConstant(4, dl, SHVT));
10016 Tmp = DAG.getNode(ISD::OR, dl, VT, Tmp2, Tmp3);
10017
10018 // swap i2: ((V >> 2) & 0x33) | ((V & 0x33) << 2)
10019 Tmp2 = DAG.getNode(ISD::SRL, dl, VT, Tmp, DAG.getConstant(2, dl, SHVT));
10020 Tmp2 = DAG.getNode(ISD::AND, dl, VT, Tmp2, DAG.getConstant(Mask2, dl, VT));
10021 Tmp3 = DAG.getNode(ISD::AND, dl, VT, Tmp, DAG.getConstant(Mask2, dl, VT));
10022 Tmp3 = DAG.getNode(ISD::SHL, dl, VT, Tmp3, DAG.getConstant(2, dl, SHVT));
10023 Tmp = DAG.getNode(ISD::OR, dl, VT, Tmp2, Tmp3);
10024
10025 // swap i1: ((V >> 1) & 0x55) | ((V & 0x55) << 1)
10026 Tmp2 = DAG.getNode(ISD::SRL, dl, VT, Tmp, DAG.getConstant(1, dl, SHVT));
10027 Tmp2 = DAG.getNode(ISD::AND, dl, VT, Tmp2, DAG.getConstant(Mask1, dl, VT));
10028 Tmp3 = DAG.getNode(ISD::AND, dl, VT, Tmp, DAG.getConstant(Mask1, dl, VT));
10029 Tmp3 = DAG.getNode(ISD::SHL, dl, VT, Tmp3, DAG.getConstant(1, dl, SHVT));
10030 Tmp = DAG.getNode(ISD::OR, dl, VT, Tmp2, Tmp3);
10031 return Tmp;
10032 }
10033
10034 Tmp = DAG.getConstant(0, dl, VT);
10035 for (unsigned I = 0, J = Sz-1; I < Sz; ++I, --J) {
10036 if (I < J)
10037 Tmp2 =
10038 DAG.getNode(ISD::SHL, dl, VT, Op, DAG.getConstant(J - I, dl, SHVT));
10039 else
10040 Tmp2 =
10041 DAG.getNode(ISD::SRL, dl, VT, Op, DAG.getConstant(I - J, dl, SHVT));
10042
10043 APInt Shift = APInt::getOneBitSet(Sz, J);
10044 Tmp2 = DAG.getNode(ISD::AND, dl, VT, Tmp2, DAG.getConstant(Shift, dl, VT));
10045 Tmp = DAG.getNode(ISD::OR, dl, VT, Tmp, Tmp2);
10046 }
10047
10048 return Tmp;
10049 }
10050
10051 SDValue TargetLowering::expandVPBITREVERSE(SDNode *N, SelectionDAG &DAG) const {
10052 assert(N->getOpcode() == ISD::VP_BITREVERSE);
10053
10054 SDLoc dl(N);
10055 EVT VT = N->getValueType(0);
10056 SDValue Op = N->getOperand(0);
10057 SDValue Mask = N->getOperand(1);
10058 SDValue EVL = N->getOperand(2);
10059 EVT SHVT = getShiftAmountTy(VT, DAG.getDataLayout());
10060 unsigned Sz = VT.getScalarSizeInBits();
10061
10062 SDValue Tmp, Tmp2, Tmp3;
10063
10064 // If we can, perform BSWAP first, then mask+swap the i4 pairs, then the i2
10065 // pairs, and finally the i1 pairs.
10066 // TODO: We can easily support i4/i2 legal types if any target ever does.
10067 if (Sz >= 8 && isPowerOf2_32(Sz)) {
10068 // Create the masks - repeating the pattern every byte.
10069 APInt Mask4 = APInt::getSplat(Sz, APInt(8, 0x0F));
10070 APInt Mask2 = APInt::getSplat(Sz, APInt(8, 0x33));
10071 APInt Mask1 = APInt::getSplat(Sz, APInt(8, 0x55));
10072
10073 // BSWAP if the type is wider than a single byte.
10074 Tmp = (Sz > 8 ? DAG.getNode(ISD::VP_BSWAP, dl, VT, Op, Mask, EVL) : Op);
10075
10076 // swap i4: ((V >> 4) & 0x0F) | ((V & 0x0F) << 4)
10077 Tmp2 = DAG.getNode(ISD::VP_SRL, dl, VT, Tmp, DAG.getConstant(4, dl, SHVT),
10078 Mask, EVL);
10079 Tmp2 = DAG.getNode(ISD::VP_AND, dl, VT, Tmp2,
10080 DAG.getConstant(Mask4, dl, VT), Mask, EVL);
10081 Tmp3 = DAG.getNode(ISD::VP_AND, dl, VT, Tmp, DAG.getConstant(Mask4, dl, VT),
10082 Mask, EVL);
10083 Tmp3 = DAG.getNode(ISD::VP_SHL, dl, VT, Tmp3, DAG.getConstant(4, dl, SHVT),
10084 Mask, EVL);
10085 Tmp = DAG.getNode(ISD::VP_OR, dl, VT, Tmp2, Tmp3, Mask, EVL);
10086
10087 // swap i2: ((V >> 2) & 0x33) | ((V & 0x33) << 2)
10088 Tmp2 = DAG.getNode(ISD::VP_SRL, dl, VT, Tmp, DAG.getConstant(2, dl, SHVT),
10089 Mask, EVL);
10090 Tmp2 = DAG.getNode(ISD::VP_AND, dl, VT, Tmp2,
10091 DAG.getConstant(Mask2, dl, VT), Mask, EVL);
10092 Tmp3 = DAG.getNode(ISD::VP_AND, dl, VT, Tmp, DAG.getConstant(Mask2, dl, VT),
10093 Mask, EVL);
10094 Tmp3 = DAG.getNode(ISD::VP_SHL, dl, VT, Tmp3, DAG.getConstant(2, dl, SHVT),
10095 Mask, EVL);
10096 Tmp = DAG.getNode(ISD::VP_OR, dl, VT, Tmp2, Tmp3, Mask, EVL);
10097
10098 // swap i1: ((V >> 1) & 0x55) | ((V & 0x55) << 1)
10099 Tmp2 = DAG.getNode(ISD::VP_SRL, dl, VT, Tmp, DAG.getConstant(1, dl, SHVT),
10100 Mask, EVL);
10101 Tmp2 = DAG.getNode(ISD::VP_AND, dl, VT, Tmp2,
10102 DAG.getConstant(Mask1, dl, VT), Mask, EVL);
10103 Tmp3 = DAG.getNode(ISD::VP_AND, dl, VT, Tmp, DAG.getConstant(Mask1, dl, VT),
10104 Mask, EVL);
10105 Tmp3 = DAG.getNode(ISD::VP_SHL, dl, VT, Tmp3, DAG.getConstant(1, dl, SHVT),
10106 Mask, EVL);
10107 Tmp = DAG.getNode(ISD::VP_OR, dl, VT, Tmp2, Tmp3, Mask, EVL);
10108 return Tmp;
10109 }
10110 return SDValue();
10111 }
10112
10113 std::pair<SDValue, SDValue>
10114 TargetLowering::scalarizeVectorLoad(LoadSDNode *LD,
10115 SelectionDAG &DAG) const {
10116 SDLoc SL(LD);
10117 SDValue Chain = LD->getChain();
10118 SDValue BasePTR = LD->getBasePtr();
10119 EVT SrcVT = LD->getMemoryVT();
10120 EVT DstVT = LD->getValueType(0);
10121 ISD::LoadExtType ExtType = LD->getExtensionType();
10122
10123 if (SrcVT.isScalableVector())
10124 report_fatal_error("Cannot scalarize scalable vector loads");
10125
10126 unsigned NumElem = SrcVT.getVectorNumElements();
10127
10128 EVT SrcEltVT = SrcVT.getScalarType();
10129 EVT DstEltVT = DstVT.getScalarType();
10130
10131 // A vector must always be stored in memory as-is, i.e. without any padding
10132 // between the elements, since various code depends on it, e.g. in the
10133 // handling of a bitcast of a vector type to int, which may be done with a
10134 // vector store followed by an integer load. A vector that does not have
10135 // elements that are byte-sized must therefore be stored as an integer
10136 // built out of the extracted vector elements.
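// For example, a little-endian load of a v4i4 vector becomes a single i16
// extload; element Idx is then extracted as trunc((Load >> (Idx * 4)) & 0xF).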
10137 if (!SrcEltVT.isByteSized()) {
10138 unsigned NumLoadBits = SrcVT.getStoreSizeInBits();
10139 EVT LoadVT = EVT::getIntegerVT(*DAG.getContext(), NumLoadBits);
10140
10141 unsigned NumSrcBits = SrcVT.getSizeInBits();
10142 EVT SrcIntVT = EVT::getIntegerVT(*DAG.getContext(), NumSrcBits);
10143
10144 unsigned SrcEltBits = SrcEltVT.getSizeInBits();
10145 SDValue SrcEltBitMask = DAG.getConstant(
10146 APInt::getLowBitsSet(NumLoadBits, SrcEltBits), SL, LoadVT);
10147
10148 // Load the whole vector and avoid masking off the top bits as it makes
10149 // the codegen worse.
10150 SDValue Load =
10151 DAG.getExtLoad(ISD::EXTLOAD, SL, LoadVT, Chain, BasePTR,
10152 LD->getPointerInfo(), SrcIntVT, LD->getBaseAlign(),
10153 LD->getMemOperand()->getFlags(), LD->getAAInfo());
10154
10155 SmallVector<SDValue, 8> Vals;
10156 for (unsigned Idx = 0; Idx < NumElem; ++Idx) {
10157 unsigned ShiftIntoIdx =
10158 (DAG.getDataLayout().isBigEndian() ? (NumElem - 1) - Idx : Idx);
10159 SDValue ShiftAmount = DAG.getShiftAmountConstant(
10160 ShiftIntoIdx * SrcEltVT.getSizeInBits(), LoadVT, SL);
10161 SDValue ShiftedElt = DAG.getNode(ISD::SRL, SL, LoadVT, Load, ShiftAmount);
10162 SDValue Elt =
10163 DAG.getNode(ISD::AND, SL, LoadVT, ShiftedElt, SrcEltBitMask);
10164 SDValue Scalar = DAG.getNode(ISD::TRUNCATE, SL, SrcEltVT, Elt);
10165
10166 if (ExtType != ISD::NON_EXTLOAD) {
10167 unsigned ExtendOp = ISD::getExtForLoadExtType(false, ExtType);
10168 Scalar = DAG.getNode(ExtendOp, SL, DstEltVT, Scalar);
10169 }
10170
10171 Vals.push_back(Scalar);
10172 }
10173
10174 SDValue Value = DAG.getBuildVector(DstVT, SL, Vals);
10175 return std::make_pair(Value, Load.getValue(1));
10176 }
10177
10178 unsigned Stride = SrcEltVT.getSizeInBits() / 8;
10179 assert(SrcEltVT.isByteSized());
10180
10181 SmallVector<SDValue, 8> Vals;
10182 SmallVector<SDValue, 8> LoadChains;
10183
10184 for (unsigned Idx = 0; Idx < NumElem; ++Idx) {
10185 SDValue ScalarLoad = DAG.getExtLoad(
10186 ExtType, SL, DstEltVT, Chain, BasePTR,
10187 LD->getPointerInfo().getWithOffset(Idx * Stride), SrcEltVT,
10188 LD->getBaseAlign(), LD->getMemOperand()->getFlags(), LD->getAAInfo());
10189
10190 BasePTR = DAG.getObjectPtrOffset(SL, BasePTR, TypeSize::getFixed(Stride));
10191
10192 Vals.push_back(ScalarLoad.getValue(0));
10193 LoadChains.push_back(ScalarLoad.getValue(1));
10194 }
10195
10196 SDValue NewChain = DAG.getNode(ISD::TokenFactor, SL, MVT::Other, LoadChains);
10197 SDValue Value = DAG.getBuildVector(DstVT, SL, Vals);
10198
10199 return std::make_pair(Value, NewChain);
10200 }
10201
10202 SDValue TargetLowering::scalarizeVectorStore(StoreSDNode *ST,
10203 SelectionDAG &DAG) const {
10204 SDLoc SL(ST);
10205
10206 SDValue Chain = ST->getChain();
10207 SDValue BasePtr = ST->getBasePtr();
10208 SDValue Value = ST->getValue();
10209 EVT StVT = ST->getMemoryVT();
10210
10211 if (StVT.isScalableVector())
10212 report_fatal_error("Cannot scalarize scalable vector stores");
10213
10214 // The type of the data we want to save
10215 EVT RegVT = Value.getValueType();
10216 EVT RegSclVT = RegVT.getScalarType();
10217
10218 // The type of data as saved in memory.
10219 EVT MemSclVT = StVT.getScalarType();
10220
10221 unsigned NumElem = StVT.getVectorNumElements();
10222
10223 // A vector must always be stored in memory as-is, i.e. without any padding
10224 // between the elements, since various code depends on it, e.g. in the
10225 // handling of a bitcast of a vector type to int, which may be done with a
10226 // vector store followed by an integer load. A vector that does not have
10227 // elements that are byte-sized must therefore be stored as an integer
10228 // built out of the extracted vector elements.
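// For example, storing a little-endian v4i4 vector packs each element into a
// single i16: zero-extend, shift left by Idx * 4, OR into the accumulator, and
// emit one i16 store.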
10229 if (!MemSclVT.isByteSized()) {
10230 unsigned NumBits = StVT.getSizeInBits();
10231 EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), NumBits);
10232
10233 SDValue CurrVal = DAG.getConstant(0, SL, IntVT);
10234
10235 for (unsigned Idx = 0; Idx < NumElem; ++Idx) {
10236 SDValue Elt = DAG.getExtractVectorElt(SL, RegSclVT, Value, Idx);
10237 SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SL, MemSclVT, Elt);
10238 SDValue ExtElt = DAG.getNode(ISD::ZERO_EXTEND, SL, IntVT, Trunc);
10239 unsigned ShiftIntoIdx =
10240 (DAG.getDataLayout().isBigEndian() ? (NumElem - 1) - Idx : Idx);
10241 SDValue ShiftAmount =
10242 DAG.getConstant(ShiftIntoIdx * MemSclVT.getSizeInBits(), SL, IntVT);
10243 SDValue ShiftedElt =
10244 DAG.getNode(ISD::SHL, SL, IntVT, ExtElt, ShiftAmount);
10245 CurrVal = DAG.getNode(ISD::OR, SL, IntVT, CurrVal, ShiftedElt);
10246 }
10247
10248 return DAG.getStore(Chain, SL, CurrVal, BasePtr, ST->getPointerInfo(),
10249 ST->getBaseAlign(), ST->getMemOperand()->getFlags(),
10250 ST->getAAInfo());
10251 }
10252
10253 // Store Stride in bytes
10254 unsigned Stride = MemSclVT.getSizeInBits() / 8;
10255 assert(Stride && "Zero stride!");
10256 // Extract each of the elements from the original vector and save them into
10257 // memory individually.
10258 SmallVector<SDValue, 8> Stores;
10259 for (unsigned Idx = 0; Idx < NumElem; ++Idx) {
10260 SDValue Elt = DAG.getExtractVectorElt(SL, RegSclVT, Value, Idx);
10261
10262 SDValue Ptr =
10263 DAG.getObjectPtrOffset(SL, BasePtr, TypeSize::getFixed(Idx * Stride));
10264
10265 // This scalar TruncStore may be illegal, but we legalize it later.
10266 SDValue Store = DAG.getTruncStore(
10267 Chain, SL, Elt, Ptr, ST->getPointerInfo().getWithOffset(Idx * Stride),
10268 MemSclVT, ST->getBaseAlign(), ST->getMemOperand()->getFlags(),
10269 ST->getAAInfo());
10270
10271 Stores.push_back(Store);
10272 }
10273
10274 return DAG.getNode(ISD::TokenFactor, SL, MVT::Other, Stores);
10275 }
10276
10277 std::pair<SDValue, SDValue>
10278 TargetLowering::expandUnalignedLoad(LoadSDNode *LD, SelectionDAG &DAG) const {
10279 assert(LD->getAddressingMode() == ISD::UNINDEXED &&
10280 "unaligned indexed loads not implemented!");
10281 SDValue Chain = LD->getChain();
10282 SDValue Ptr = LD->getBasePtr();
10283 EVT VT = LD->getValueType(0);
10284 EVT LoadedVT = LD->getMemoryVT();
10285 SDLoc dl(LD);
10286 auto &MF = DAG.getMachineFunction();
10287
10288 if (VT.isFloatingPoint() || VT.isVector()) {
10289 EVT intVT = EVT::getIntegerVT(*DAG.getContext(), LoadedVT.getSizeInBits());
10290 if (isTypeLegal(intVT) && isTypeLegal(LoadedVT)) {
10291 if (!isOperationLegalOrCustom(ISD::LOAD, intVT) &&
10292 LoadedVT.isVector()) {
10293 // Scalarize the load and let the individual components be handled.
10294 return scalarizeVectorLoad(LD, DAG);
10295 }
10296
10297 // Expand to a (misaligned) integer load of the same size,
10298 // then bitconvert to floating point or vector.
10299 SDValue newLoad = DAG.getLoad(intVT, dl, Chain, Ptr,
10300 LD->getMemOperand());
10301 SDValue Result = DAG.getNode(ISD::BITCAST, dl, LoadedVT, newLoad);
10302 if (LoadedVT != VT)
10303 Result = DAG.getNode(VT.isFloatingPoint() ? ISD::FP_EXTEND :
10304 ISD::ANY_EXTEND, dl, VT, Result);
10305
10306 return std::make_pair(Result, newLoad.getValue(1));
10307 }
10308
10309 // Copy the value to an (aligned) stack slot using (unaligned) integer
10310 // loads and stores, then do an (aligned) load from the stack slot.
10311 MVT RegVT = getRegisterType(*DAG.getContext(), intVT);
10312 unsigned LoadedBytes = LoadedVT.getStoreSize();
10313 unsigned RegBytes = RegVT.getSizeInBits() / 8;
10314 unsigned NumRegs = (LoadedBytes + RegBytes - 1) / RegBytes;
10315
10316 // Make sure the stack slot is also aligned for the register type.
10317 SDValue StackBase = DAG.CreateStackTemporary(LoadedVT, RegVT);
10318 auto FrameIndex = cast<FrameIndexSDNode>(StackBase.getNode())->getIndex();
10319 SmallVector<SDValue, 8> Stores;
10320 SDValue StackPtr = StackBase;
10321 unsigned Offset = 0;
10322
10323 EVT PtrVT = Ptr.getValueType();
10324 EVT StackPtrVT = StackPtr.getValueType();
10325
10326 SDValue PtrIncrement = DAG.getConstant(RegBytes, dl, PtrVT);
10327 SDValue StackPtrIncrement = DAG.getConstant(RegBytes, dl, StackPtrVT);
10328
10329 // Do all but one of the copies using the full register width.
10330 for (unsigned i = 1; i < NumRegs; i++) {
10331 // Load one integer register's worth from the original location.
10332 SDValue Load = DAG.getLoad(
10333 RegVT, dl, Chain, Ptr, LD->getPointerInfo().getWithOffset(Offset),
10334 LD->getBaseAlign(), LD->getMemOperand()->getFlags(), LD->getAAInfo());
10335 // Follow the load with a store to the stack slot. Remember the store.
10336 Stores.push_back(DAG.getStore(
10337 Load.getValue(1), dl, Load, StackPtr,
10338 MachinePointerInfo::getFixedStack(MF, FrameIndex, Offset)));
10339 // Increment the pointers.
10340 Offset += RegBytes;
10341
10342 Ptr = DAG.getObjectPtrOffset(dl, Ptr, PtrIncrement);
10343 StackPtr = DAG.getObjectPtrOffset(dl, StackPtr, StackPtrIncrement);
10344 }
10345
10346 // The last copy may be partial. Do an extending load.
10347 EVT MemVT = EVT::getIntegerVT(*DAG.getContext(),
10348 8 * (LoadedBytes - Offset));
10349 SDValue Load = DAG.getExtLoad(
10350 ISD::EXTLOAD, dl, RegVT, Chain, Ptr,
10351 LD->getPointerInfo().getWithOffset(Offset), MemVT, LD->getBaseAlign(),
10352 LD->getMemOperand()->getFlags(), LD->getAAInfo());
10353 // Follow the load with a store to the stack slot. Remember the store.
10354 // On big-endian machines this requires a truncating store to ensure
10355 // that the bits end up in the right place.
10356 Stores.push_back(DAG.getTruncStore(
10357 Load.getValue(1), dl, Load, StackPtr,
10358 MachinePointerInfo::getFixedStack(MF, FrameIndex, Offset), MemVT));
10359
10360 // The order of the stores doesn't matter - say it with a TokenFactor.
10361 SDValue TF = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, Stores);
10362
10363 // Finally, perform the original load only redirected to the stack slot.
10364 Load = DAG.getExtLoad(LD->getExtensionType(), dl, VT, TF, StackBase,
10365 MachinePointerInfo::getFixedStack(MF, FrameIndex, 0),
10366 LoadedVT);
10367
10368 // Callers expect a MERGE_VALUES node.
10369 return std::make_pair(Load, TF);
10370 }
10371
10372 assert(LoadedVT.isInteger() && !LoadedVT.isVector() &&
10373 "Unaligned load of unsupported type.");
10374
10375 // Compute the new VT that is half the size of the old one. This is an
10376 // integer MVT.
10377 unsigned NumBits = LoadedVT.getSizeInBits();
10378 EVT NewLoadedVT;
10379 NewLoadedVT = EVT::getIntegerVT(*DAG.getContext(), NumBits/2);
10380 NumBits >>= 1;
10381
10382 Align Alignment = LD->getBaseAlign();
10383 unsigned IncrementSize = NumBits / 8;
10384 ISD::LoadExtType HiExtType = LD->getExtensionType();
10385
10386 // If the original load is NON_EXTLOAD, the hi part load must be ZEXTLOAD.
10387 if (HiExtType == ISD::NON_EXTLOAD)
10388 HiExtType = ISD::ZEXTLOAD;
10389
10390 // Load the value in two parts
10391 SDValue Lo, Hi;
10392 if (DAG.getDataLayout().isLittleEndian()) {
10393 Lo = DAG.getExtLoad(ISD::ZEXTLOAD, dl, VT, Chain, Ptr, LD->getPointerInfo(),
10394 NewLoadedVT, Alignment, LD->getMemOperand()->getFlags(),
10395 LD->getAAInfo());
10396
10397 Ptr = DAG.getObjectPtrOffset(dl, Ptr, TypeSize::getFixed(IncrementSize));
10398 Hi = DAG.getExtLoad(HiExtType, dl, VT, Chain, Ptr,
10399 LD->getPointerInfo().getWithOffset(IncrementSize),
10400 NewLoadedVT, Alignment, LD->getMemOperand()->getFlags(),
10401 LD->getAAInfo());
10402 } else {
10403 Hi = DAG.getExtLoad(HiExtType, dl, VT, Chain, Ptr, LD->getPointerInfo(),
10404 NewLoadedVT, Alignment, LD->getMemOperand()->getFlags(),
10405 LD->getAAInfo());
10406
10407 Ptr = DAG.getObjectPtrOffset(dl, Ptr, TypeSize::getFixed(IncrementSize));
10408 Lo = DAG.getExtLoad(ISD::ZEXTLOAD, dl, VT, Chain, Ptr,
10409 LD->getPointerInfo().getWithOffset(IncrementSize),
10410 NewLoadedVT, Alignment, LD->getMemOperand()->getFlags(),
10411 LD->getAAInfo());
10412 }
10413
10414 // Aggregate the two parts.
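// For example, an unaligned little-endian i32 load becomes Lo = zextload i16
// from Ptr and Hi = extload i16 from Ptr + 2; the value is (Hi << 16) | Lo.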
10415 SDValue ShiftAmount = DAG.getShiftAmountConstant(NumBits, VT, dl);
10416 SDValue Result = DAG.getNode(ISD::SHL, dl, VT, Hi, ShiftAmount);
10417 Result = DAG.getNode(ISD::OR, dl, VT, Result, Lo);
10418
10419 SDValue TF = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, Lo.getValue(1),
10420 Hi.getValue(1));
10421
10422 return std::make_pair(Result, TF);
10423 }
10424
10425 SDValue TargetLowering::expandUnalignedStore(StoreSDNode *ST,
10426 SelectionDAG &DAG) const {
10427 assert(ST->getAddressingMode() == ISD::UNINDEXED &&
10428 "unaligned indexed stores not implemented!");
10429 SDValue Chain = ST->getChain();
10430 SDValue Ptr = ST->getBasePtr();
10431 SDValue Val = ST->getValue();
10432 EVT VT = Val.getValueType();
10433 Align Alignment = ST->getBaseAlign();
10434 auto &MF = DAG.getMachineFunction();
10435 EVT StoreMemVT = ST->getMemoryVT();
10436
10437 SDLoc dl(ST);
10438 if (StoreMemVT.isFloatingPoint() || StoreMemVT.isVector()) {
10439 EVT intVT = EVT::getIntegerVT(*DAG.getContext(), VT.getSizeInBits());
10440 if (isTypeLegal(intVT)) {
10441 if (!isOperationLegalOrCustom(ISD::STORE, intVT) &&
10442 StoreMemVT.isVector()) {
10443 // Scalarize the store and let the individual components be handled.
10444 SDValue Result = scalarizeVectorStore(ST, DAG);
10445 return Result;
10446 }
10447 // Expand to a bitconvert of the value to the integer type of the
10448 // same size, then a (misaligned) int store.
10449 // FIXME: Does not handle truncating floating point stores!
10450 SDValue Result = DAG.getNode(ISD::BITCAST, dl, intVT, Val);
10451 Result = DAG.getStore(Chain, dl, Result, Ptr, ST->getPointerInfo(),
10452 Alignment, ST->getMemOperand()->getFlags());
10453 return Result;
10454 }
10455 // Do an (aligned) store to a stack slot, then copy from the stack slot
10456 // to the final destination using (unaligned) integer loads and stores.
10457 MVT RegVT = getRegisterType(
10458 *DAG.getContext(),
10459 EVT::getIntegerVT(*DAG.getContext(), StoreMemVT.getSizeInBits()));
10460 EVT PtrVT = Ptr.getValueType();
10461 unsigned StoredBytes = StoreMemVT.getStoreSize();
10462 unsigned RegBytes = RegVT.getSizeInBits() / 8;
10463 unsigned NumRegs = (StoredBytes + RegBytes - 1) / RegBytes;
10464
10465 // Make sure the stack slot is also aligned for the register type.
10466 SDValue StackPtr = DAG.CreateStackTemporary(StoreMemVT, RegVT);
10467 auto FrameIndex = cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex();
10468
10469 // Perform the original store, only redirected to the stack slot.
10470 SDValue Store = DAG.getTruncStore(
10471 Chain, dl, Val, StackPtr,
10472 MachinePointerInfo::getFixedStack(MF, FrameIndex, 0), StoreMemVT);
10473
10474 EVT StackPtrVT = StackPtr.getValueType();
10475
10476 SDValue PtrIncrement = DAG.getConstant(RegBytes, dl, PtrVT);
10477 SDValue StackPtrIncrement = DAG.getConstant(RegBytes, dl, StackPtrVT);
10478 SmallVector<SDValue, 8> Stores;
10479 unsigned Offset = 0;
10480
10481 // Do all but one of the copies using the full register width.
10482 for (unsigned i = 1; i < NumRegs; i++) {
10483 // Load one integer register's worth from the stack slot.
10484 SDValue Load = DAG.getLoad(
10485 RegVT, dl, Store, StackPtr,
10486 MachinePointerInfo::getFixedStack(MF, FrameIndex, Offset));
10487 // Store it to the final location. Remember the store.
10488 Stores.push_back(DAG.getStore(Load.getValue(1), dl, Load, Ptr,
10489 ST->getPointerInfo().getWithOffset(Offset),
10490 ST->getBaseAlign(),
10491 ST->getMemOperand()->getFlags()));
10492 // Increment the pointers.
10493 Offset += RegBytes;
10494 StackPtr = DAG.getObjectPtrOffset(dl, StackPtr, StackPtrIncrement);
10495 Ptr = DAG.getObjectPtrOffset(dl, Ptr, PtrIncrement);
10496 }
10497
10498 // The last store may be partial. Do a truncating store. On big-endian
10499 // machines this requires an extending load from the stack slot to ensure
10500 // that the bits are in the right place.
10501 EVT LoadMemVT =
10502 EVT::getIntegerVT(*DAG.getContext(), 8 * (StoredBytes - Offset));
10503
10504 // Load from the stack slot.
10505 SDValue Load = DAG.getExtLoad(
10506 ISD::EXTLOAD, dl, RegVT, Store, StackPtr,
10507 MachinePointerInfo::getFixedStack(MF, FrameIndex, Offset), LoadMemVT);
10508
10509 Stores.push_back(DAG.getTruncStore(
10510 Load.getValue(1), dl, Load, Ptr,
10511 ST->getPointerInfo().getWithOffset(Offset), LoadMemVT,
10512 ST->getBaseAlign(), ST->getMemOperand()->getFlags(), ST->getAAInfo()));
10513 // The order of the stores doesn't matter - say it with a TokenFactor.
10514 SDValue Result = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, Stores);
10515 return Result;
10516 }
10517
10518 assert(StoreMemVT.isInteger() && !StoreMemVT.isVector() &&
10519 "Unaligned store of unknown type.");
10520 // Get the half-size VT
10521 EVT NewStoredVT = StoreMemVT.getHalfSizedIntegerVT(*DAG.getContext());
10522 unsigned NumBits = NewStoredVT.getFixedSizeInBits();
10523 unsigned IncrementSize = NumBits / 8;
10524
10525 // Divide the stored value in two parts.
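// For example, an unaligned little-endian i32 store becomes two i16 truncating
// stores: the low half at Ptr and (Val >> 16) at Ptr + 2.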
10526 SDValue ShiftAmount =
10527 DAG.getShiftAmountConstant(NumBits, Val.getValueType(), dl);
10528 SDValue Lo = Val;
10529 // If Val is a constant, replace the upper bits with 0. The SRL will constant
10530 // fold and not use the upper bits. A smaller constant may be easier to
10531 // materialize.
10532 if (auto *C = dyn_cast<ConstantSDNode>(Lo); C && !C->isOpaque())
10533 Lo = DAG.getNode(
10534 ISD::AND, dl, VT, Lo,
10535 DAG.getConstant(APInt::getLowBitsSet(VT.getSizeInBits(), NumBits), dl,
10536 VT));
10537 SDValue Hi = DAG.getNode(ISD::SRL, dl, VT, Val, ShiftAmount);
10538
10539 // Store the two parts
10540 SDValue Store1, Store2;
10541 Store1 = DAG.getTruncStore(Chain, dl,
10542 DAG.getDataLayout().isLittleEndian() ? Lo : Hi,
10543 Ptr, ST->getPointerInfo(), NewStoredVT, Alignment,
10544 ST->getMemOperand()->getFlags());
10545
10546 Ptr = DAG.getObjectPtrOffset(dl, Ptr, TypeSize::getFixed(IncrementSize));
10547 Store2 = DAG.getTruncStore(
10548 Chain, dl, DAG.getDataLayout().isLittleEndian() ? Hi : Lo, Ptr,
10549 ST->getPointerInfo().getWithOffset(IncrementSize), NewStoredVT, Alignment,
10550 ST->getMemOperand()->getFlags(), ST->getAAInfo());
10551
10552 SDValue Result =
10553 DAG.getNode(ISD::TokenFactor, dl, MVT::Other, Store1, Store2);
10554 return Result;
10555 }
10556
10557 SDValue
10558 TargetLowering::IncrementMemoryAddress(SDValue Addr, SDValue Mask,
10559 const SDLoc &DL, EVT DataVT,
10560 SelectionDAG &DAG,
10561 bool IsCompressedMemory) const {
10562 SDValue Increment;
10563 EVT AddrVT = Addr.getValueType();
10564 EVT MaskVT = Mask.getValueType();
10565 assert(DataVT.getVectorElementCount() == MaskVT.getVectorElementCount() &&
10566 "Incompatible types of Data and Mask");
10567 if (IsCompressedMemory) {
10568 if (DataVT.isScalableVector())
10569 report_fatal_error(
10570 "Cannot currently handle compressed memory with scalable vectors");
10571 // Increment the pointer according to the number of '1's in the mask.
10572 EVT MaskIntVT = EVT::getIntegerVT(*DAG.getContext(), MaskVT.getSizeInBits());
10573 SDValue MaskInIntReg = DAG.getBitcast(MaskIntVT, Mask);
10574 if (MaskIntVT.getSizeInBits() < 32) {
10575 MaskInIntReg = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, MaskInIntReg);
10576 MaskIntVT = MVT::i32;
10577 }
10578
10579 // Count '1's with POPCNT.
10580 Increment = DAG.getNode(ISD::CTPOP, DL, MaskIntVT, MaskInIntReg);
10581 Increment = DAG.getZExtOrTrunc(Increment, DL, AddrVT);
10582 // Scale is an element size in bytes.
10583 SDValue Scale = DAG.getConstant(DataVT.getScalarSizeInBits() / 8, DL,
10584 AddrVT);
10585 Increment = DAG.getNode(ISD::MUL, DL, AddrVT, Increment, Scale);
10586 } else if (DataVT.isScalableVector()) {
10587 Increment = DAG.getVScale(DL, AddrVT,
10588 APInt(AddrVT.getFixedSizeInBits(),
10589 DataVT.getStoreSize().getKnownMinValue()));
10590 } else
10591 Increment = DAG.getConstant(DataVT.getStoreSize(), DL, AddrVT);
10592
10593 return DAG.getNode(ISD::ADD, DL, AddrVT, Addr, Increment);
10594 }
10595
10596 static SDValue clampDynamicVectorIndex(SelectionDAG &DAG, SDValue Idx,
10597 EVT VecVT, const SDLoc &dl,
10598 ElementCount SubEC) {
10599 assert(!(SubEC.isScalable() && VecVT.isFixedLengthVector()) &&
10600 "Cannot index a scalable vector within a fixed-width vector");
10601
10602 unsigned NElts = VecVT.getVectorMinNumElements();
10603 unsigned NumSubElts = SubEC.getKnownMinValue();
10604 EVT IdxVT = Idx.getValueType();
10605
10606 if (VecVT.isScalableVector() && !SubEC.isScalable()) {
10607 // If this is a constant index and we know that the value plus the number of
10608 // elements in the subvector minus one is less than the minimum number of
10609 // elements, then it's safe to return Idx.
10610 if (auto *IdxCst = dyn_cast<ConstantSDNode>(Idx))
10611 if (IdxCst->getZExtValue() + (NumSubElts - 1) < NElts)
10612 return Idx;
10613 SDValue VS =
10614 DAG.getVScale(dl, IdxVT, APInt(IdxVT.getFixedSizeInBits(), NElts));
10615 unsigned SubOpcode = NumSubElts <= NElts ? ISD::SUB : ISD::USUBSAT;
10616 SDValue Sub = DAG.getNode(SubOpcode, dl, IdxVT, VS,
10617 DAG.getConstant(NumSubElts, dl, IdxVT));
10618 return DAG.getNode(ISD::UMIN, dl, IdxVT, Idx, Sub);
10619 }
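// A power-of-two vector length lets a single-element index be clamped with a
// cheap mask: Idx & (NElts - 1) always stays in bounds.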
10620 if (isPowerOf2_32(NElts) && NumSubElts == 1) {
10621 APInt Imm = APInt::getLowBitsSet(IdxVT.getSizeInBits(), Log2_32(NElts));
10622 return DAG.getNode(ISD::AND, dl, IdxVT, Idx,
10623 DAG.getConstant(Imm, dl, IdxVT));
10624 }
10625 unsigned MaxIndex = NumSubElts < NElts ? NElts - NumSubElts : 0;
10626 return DAG.getNode(ISD::UMIN, dl, IdxVT, Idx,
10627 DAG.getConstant(MaxIndex, dl, IdxVT));
10628 }
10629
10630 SDValue TargetLowering::getVectorElementPointer(SelectionDAG &DAG,
10631 SDValue VecPtr, EVT VecVT,
10632 SDValue Index) const {
10633 return getVectorSubVecPointer(
10634 DAG, VecPtr, VecVT,
10635 EVT::getVectorVT(*DAG.getContext(), VecVT.getVectorElementType(), 1),
10636 Index);
10637 }
10638
10639 SDValue TargetLowering::getVectorSubVecPointer(SelectionDAG &DAG,
10640 SDValue VecPtr, EVT VecVT,
10641 EVT SubVecVT,
10642 SDValue Index) const {
10643 SDLoc dl(Index);
10644 // Make sure the index type is big enough to compute in.
10645 Index = DAG.getZExtOrTrunc(Index, dl, VecPtr.getValueType());
10646
10647 EVT EltVT = VecVT.getVectorElementType();
10648
10649 // Calculate the element offset and add it to the pointer.
10650 unsigned EltSize = EltVT.getFixedSizeInBits() / 8; // FIXME: should be ABI size.
10651 assert(EltSize * 8 == EltVT.getFixedSizeInBits() &&
10652 "Converting bits to bytes lost precision");
10653 assert(SubVecVT.getVectorElementType() == EltVT &&
10654 "Sub-vector must be a vector with matching element type");
10655 Index = clampDynamicVectorIndex(DAG, Index, VecVT, dl,
10656 SubVecVT.getVectorElementCount());
10657
10658 EVT IdxVT = Index.getValueType();
10659 if (SubVecVT.isScalableVector())
10660 Index =
10661 DAG.getNode(ISD::MUL, dl, IdxVT, Index,
10662 DAG.getVScale(dl, IdxVT, APInt(IdxVT.getSizeInBits(), 1)));
10663
10664 Index = DAG.getNode(ISD::MUL, dl, IdxVT, Index,
10665 DAG.getConstant(EltSize, dl, IdxVT));
10666 return DAG.getMemBasePlusOffset(VecPtr, Index, dl);
10667 }
10668
10669 //===----------------------------------------------------------------------===//
10670 // Implementation of Emulated TLS Model
10671 //===----------------------------------------------------------------------===//
10672
10673 SDValue TargetLowering::LowerToTLSEmulatedModel(const GlobalAddressSDNode *GA,
10674 SelectionDAG &DAG) const {
10675 // Access to the address of TLS variable xyz is lowered to a function call:
10676 // __emutls_get_address( address of global variable named "__emutls_v.xyz" )
10677 EVT PtrVT = getPointerTy(DAG.getDataLayout());
10678 PointerType *VoidPtrType = PointerType::get(*DAG.getContext(), 0);
10679 SDLoc dl(GA);
10680
10681 ArgListTy Args;
10682 ArgListEntry Entry;
10683 const GlobalValue *GV =
10684 cast<GlobalValue>(GA->getGlobal()->stripPointerCastsAndAliases());
10685 SmallString<32> NameString("__emutls_v.");
10686 NameString += GV->getName();
10687 StringRef EmuTlsVarName(NameString);
10688 const GlobalVariable *EmuTlsVar =
10689 GV->getParent()->getNamedGlobal(EmuTlsVarName);
10690 assert(EmuTlsVar && "Cannot find EmuTlsVar ");
10691 Entry.Node = DAG.getGlobalAddress(EmuTlsVar, dl, PtrVT);
10692 Entry.Ty = VoidPtrType;
10693 Args.push_back(Entry);
10694
10695 SDValue EmuTlsGetAddr = DAG.getExternalSymbol("__emutls_get_address", PtrVT);
10696
10697 TargetLowering::CallLoweringInfo CLI(DAG);
10698 CLI.setDebugLoc(dl).setChain(DAG.getEntryNode());
10699 CLI.setLibCallee(CallingConv::C, VoidPtrType, EmuTlsGetAddr, std::move(Args));
10700 std::pair<SDValue, SDValue> CallResult = LowerCallTo(CLI);
10701
10702 // TLSADDR will be codegen'ed as a call. Inform MFI that the function has calls.
10703 // At least for X86 targets; maybe good for other targets too?
10704 MachineFrameInfo &MFI = DAG.getMachineFunction().getFrameInfo();
10705 MFI.setAdjustsStack(true); // Is this only for X86 target?
10706 MFI.setHasCalls(true);
10707
10708 assert((GA->getOffset() == 0) &&
10709 "Emulated TLS must have zero offset in GlobalAddressSDNode");
10710 return CallResult.first;
10711 }
10712
10713 SDValue TargetLowering::lowerCmpEqZeroToCtlzSrl(SDValue Op,
10714 SelectionDAG &DAG) const {
10715 assert((Op->getOpcode() == ISD::SETCC) && "Input has to be a SETCC node.");
10716 if (!isCtlzFast())
10717 return SDValue();
10718 ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(2))->get();
10719 SDLoc dl(Op);
10720 if (isNullConstant(Op.getOperand(1)) && CC == ISD::SETEQ) {
10721 EVT VT = Op.getOperand(0).getValueType();
10722 SDValue Zext = Op.getOperand(0);
10723 if (VT.bitsLT(MVT::i32)) {
10724 VT = MVT::i32;
10725 Zext = DAG.getNode(ISD::ZERO_EXTEND, dl, VT, Op.getOperand(0));
10726 }
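// CTLZ of a nonzero value is strictly less than the bit width, so only x == 0
// produces a count with the Log2b bit set; shifting that bit down yields the
// 0/1 result of the SETEQ.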
10727 unsigned Log2b = Log2_32(VT.getSizeInBits());
10728 SDValue Clz = DAG.getNode(ISD::CTLZ, dl, VT, Zext);
10729 SDValue Scc = DAG.getNode(ISD::SRL, dl, VT, Clz,
10730 DAG.getConstant(Log2b, dl, MVT::i32));
10731 return DAG.getNode(ISD::TRUNCATE, dl, MVT::i32, Scc);
10732 }
10733 return SDValue();
10734 }
10735
10736 SDValue TargetLowering::expandIntMINMAX(SDNode *Node, SelectionDAG &DAG) const {
10737 SDValue Op0 = Node->getOperand(0);
10738 SDValue Op1 = Node->getOperand(1);
10739 EVT VT = Op0.getValueType();
10740 EVT BoolVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
10741 unsigned Opcode = Node->getOpcode();
10742 SDLoc DL(Node);
10743
10744 // umax(x,1) --> sub(x,cmpeq(x,0)) iff cmp result is allbits
10745 if (Opcode == ISD::UMAX && llvm::isOneOrOneSplat(Op1, true) && BoolVT == VT &&
10746 getBooleanContents(VT) == ZeroOrNegativeOneBooleanContent) {
10747 Op0 = DAG.getFreeze(Op0);
10748 SDValue Zero = DAG.getConstant(0, DL, VT);
10749 return DAG.getNode(ISD::SUB, DL, VT, Op0,
10750 DAG.getSetCC(DL, VT, Op0, Zero, ISD::SETEQ));
10751 }
10752
10753 // umin(x,y) -> sub(x,usubsat(x,y))
10754 // TODO: Missing freeze(Op0)?
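// (If x >= y then usubsat(x,y) == x - y and the SUB yields y; otherwise
// usubsat(x,y) == 0 and the SUB yields x.)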
10755 if (Opcode == ISD::UMIN && isOperationLegal(ISD::SUB, VT) &&
10756 isOperationLegal(ISD::USUBSAT, VT)) {
10757 return DAG.getNode(ISD::SUB, DL, VT, Op0,
10758 DAG.getNode(ISD::USUBSAT, DL, VT, Op0, Op1));
10759 }
10760
10761 // umax(x,y) -> add(x,usubsat(y,x))
10762 // TODO: Missing freeze(Op0)?
10763 if (Opcode == ISD::UMAX && isOperationLegal(ISD::ADD, VT) &&
10764 isOperationLegal(ISD::USUBSAT, VT)) {
10765 return DAG.getNode(ISD::ADD, DL, VT, Op0,
10766 DAG.getNode(ISD::USUBSAT, DL, VT, Op1, Op0));
10767 }
10768
10769 // FIXME: Should really try to split the vector in case it's legal on a
10770 // subvector.
10771 if (VT.isVector() && !isOperationLegalOrCustom(ISD::VSELECT, VT))
10772 return DAG.UnrollVectorOp(Node);
10773
10774 // Attempt to find an existing SETCC node that we can reuse.
10775 // TODO: Do we need a generic doesSETCCNodeExist?
10776 // TODO: Missing freeze(Op0)/freeze(Op1)?
10777 auto buildMinMax = [&](ISD::CondCode PrefCC, ISD::CondCode AltCC,
10778 ISD::CondCode PrefCommuteCC,
10779 ISD::CondCode AltCommuteCC) {
10780 SDVTList BoolVTList = DAG.getVTList(BoolVT);
10781 for (ISD::CondCode CC : {PrefCC, AltCC}) {
10782 if (DAG.doesNodeExist(ISD::SETCC, BoolVTList,
10783 {Op0, Op1, DAG.getCondCode(CC)})) {
10784 SDValue Cond = DAG.getSetCC(DL, BoolVT, Op0, Op1, CC);
10785 return DAG.getSelect(DL, VT, Cond, Op0, Op1);
10786 }
10787 }
10788 for (ISD::CondCode CC : {PrefCommuteCC, AltCommuteCC}) {
10789 if (DAG.doesNodeExist(ISD::SETCC, BoolVTList,
10790 {Op0, Op1, DAG.getCondCode(CC)})) {
10791 SDValue Cond = DAG.getSetCC(DL, BoolVT, Op0, Op1, CC);
10792 return DAG.getSelect(DL, VT, Cond, Op1, Op0);
10793 }
10794 }
10795 SDValue Cond = DAG.getSetCC(DL, BoolVT, Op0, Op1, PrefCC);
10796 return DAG.getSelect(DL, VT, Cond, Op0, Op1);
10797 };
10798
10799 // Expand Y = MAX(A, B) -> Y = (A > B) ? A : B
10800 // -> Y = (A < B) ? B : A
10801 // -> Y = (A >= B) ? A : B
10802 // -> Y = (A <= B) ? B : A
10803 switch (Opcode) {
10804 case ISD::SMAX:
10805 return buildMinMax(ISD::SETGT, ISD::SETGE, ISD::SETLT, ISD::SETLE);
10806 case ISD::SMIN:
10807 return buildMinMax(ISD::SETLT, ISD::SETLE, ISD::SETGT, ISD::SETGE);
10808 case ISD::UMAX:
10809 return buildMinMax(ISD::SETUGT, ISD::SETUGE, ISD::SETULT, ISD::SETULE);
10810 case ISD::UMIN:
10811 return buildMinMax(ISD::SETULT, ISD::SETULE, ISD::SETUGT, ISD::SETUGE);
10812 }
10813
10814 llvm_unreachable("How did we get here?");
10815 }
10816
10817 SDValue TargetLowering::expandAddSubSat(SDNode *Node, SelectionDAG &DAG) const {
10818 unsigned Opcode = Node->getOpcode();
10819 SDValue LHS = Node->getOperand(0);
10820 SDValue RHS = Node->getOperand(1);
10821 EVT VT = LHS.getValueType();
10822 SDLoc dl(Node);
10823
10824 assert(VT == RHS.getValueType() && "Expected operands to be the same type");
10825 assert(VT.isInteger() && "Expected operands to be integers");
10826
10827 // usub.sat(a, b) -> umax(a, b) - b
10828 if (Opcode == ISD::USUBSAT && isOperationLegal(ISD::UMAX, VT)) {
10829 SDValue Max = DAG.getNode(ISD::UMAX, dl, VT, LHS, RHS);
10830 return DAG.getNode(ISD::SUB, dl, VT, Max, RHS);
10831 }
10832
10833 // uadd.sat(a, b) -> umin(a, ~b) + b
10834 if (Opcode == ISD::UADDSAT && isOperationLegal(ISD::UMIN, VT)) {
10835 SDValue InvRHS = DAG.getNOT(dl, RHS, VT);
10836 SDValue Min = DAG.getNode(ISD::UMIN, dl, VT, LHS, InvRHS);
10837 return DAG.getNode(ISD::ADD, dl, VT, Min, RHS);
10838 }
10839
10840 unsigned OverflowOp;
10841 switch (Opcode) {
10842 case ISD::SADDSAT:
10843 OverflowOp = ISD::SADDO;
10844 break;
10845 case ISD::UADDSAT:
10846 OverflowOp = ISD::UADDO;
10847 break;
10848 case ISD::SSUBSAT:
10849 OverflowOp = ISD::SSUBO;
10850 break;
10851 case ISD::USUBSAT:
10852 OverflowOp = ISD::USUBO;
10853 break;
10854 default:
10855 llvm_unreachable("Expected method to receive signed or unsigned saturation "
10856 "addition or subtraction node.");
10857 }
10858
10859 // FIXME: Should really try to split the vector in case it's legal on a
10860 // subvector.
10861 if (VT.isVector() && !isOperationLegalOrCustom(ISD::VSELECT, VT))
10862 return DAG.UnrollVectorOp(Node);
10863
10864 unsigned BitWidth = LHS.getScalarValueSizeInBits();
10865 EVT BoolVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
10866 SDValue Result = DAG.getNode(OverflowOp, dl, DAG.getVTList(VT, BoolVT), LHS, RHS);
10867 SDValue SumDiff = Result.getValue(0);
10868 SDValue Overflow = Result.getValue(1);
10869 SDValue Zero = DAG.getConstant(0, dl, VT);
10870 SDValue AllOnes = DAG.getAllOnesConstant(dl, VT);
10871
10872 if (Opcode == ISD::UADDSAT) {
10873 if (getBooleanContents(VT) == ZeroOrNegativeOneBooleanContent) {
10874 // (LHS + RHS) | OverflowMask
10875 SDValue OverflowMask = DAG.getSExtOrTrunc(Overflow, dl, VT);
10876 return DAG.getNode(ISD::OR, dl, VT, SumDiff, OverflowMask);
10877 }
10878 // Overflow ? 0xffff.... : (LHS + RHS)
10879 return DAG.getSelect(dl, VT, Overflow, AllOnes, SumDiff);
10880 }
10881
10882 if (Opcode == ISD::USUBSAT) {
10883 if (getBooleanContents(VT) == ZeroOrNegativeOneBooleanContent) {
10884 // (LHS - RHS) & ~OverflowMask
10885 SDValue OverflowMask = DAG.getSExtOrTrunc(Overflow, dl, VT);
10886 SDValue Not = DAG.getNOT(dl, OverflowMask, VT);
10887 return DAG.getNode(ISD::AND, dl, VT, SumDiff, Not);
10888 }
10889 // Overflow ? 0 : (LHS - RHS)
10890 return DAG.getSelect(dl, VT, Overflow, Zero, SumDiff);
10891 }
10892
10893 if (Opcode == ISD::SADDSAT || Opcode == ISD::SSUBSAT) {
10894 APInt MinVal = APInt::getSignedMinValue(BitWidth);
10895 APInt MaxVal = APInt::getSignedMaxValue(BitWidth);
10896
10897 KnownBits KnownLHS = DAG.computeKnownBits(LHS);
10898 KnownBits KnownRHS = DAG.computeKnownBits(RHS);
10899
10900 // If either of the operand signs are known, then they are guaranteed to
10901 // only saturate in one direction. If non-negative they will saturate
10902 // towards SIGNED_MAX, if negative they will saturate towards SIGNED_MIN.
10903 //
10904 // In the case of ISD::SSUBSAT, 'x - y' is equivalent to 'x + (-y)', so the
10905 // sign of 'y' has to be flipped.
10906
10907 bool LHSIsNonNegative = KnownLHS.isNonNegative();
10908 bool RHSIsNonNegative = Opcode == ISD::SADDSAT ? KnownRHS.isNonNegative()
10909 : KnownRHS.isNegative();
10910 if (LHSIsNonNegative || RHSIsNonNegative) {
10911 SDValue SatMax = DAG.getConstant(MaxVal, dl, VT);
10912 return DAG.getSelect(dl, VT, Overflow, SatMax, SumDiff);
10913 }
10914
10915 bool LHSIsNegative = KnownLHS.isNegative();
10916 bool RHSIsNegative = Opcode == ISD::SADDSAT ? KnownRHS.isNegative()
10917 : KnownRHS.isNonNegative();
10918 if (LHSIsNegative || RHSIsNegative) {
10919 SDValue SatMin = DAG.getConstant(MinVal, dl, VT);
10920 return DAG.getSelect(dl, VT, Overflow, SatMin, SumDiff);
10921 }
10922 }
10923
10924 // Overflow ? (SumDiff >> (BW - 1)) ^ MinVal : SumDiff
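// On overflow the sign bit of SumDiff is the inverse of the true result's
// sign, so broadcasting it with SRA gives all-ones (saturate to MAX) or
// all-zeros (saturate to MIN), and XOR with MinVal produces SIGNED_MAX or
// SIGNED_MIN respectively.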
10925 APInt MinVal = APInt::getSignedMinValue(BitWidth);
10926 SDValue SatMin = DAG.getConstant(MinVal, dl, VT);
10927 SDValue Shift = DAG.getNode(ISD::SRA, dl, VT, SumDiff,
10928 DAG.getConstant(BitWidth - 1, dl, VT));
10929 Result = DAG.getNode(ISD::XOR, dl, VT, Shift, SatMin);
10930 return DAG.getSelect(dl, VT, Overflow, Result, SumDiff);
10931 }
10932
10933 SDValue TargetLowering::expandCMP(SDNode *Node, SelectionDAG &DAG) const {
10934 unsigned Opcode = Node->getOpcode();
10935 SDValue LHS = Node->getOperand(0);
10936 SDValue RHS = Node->getOperand(1);
10937 EVT VT = LHS.getValueType();
10938 EVT ResVT = Node->getValueType(0);
10939 EVT BoolVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
10940 SDLoc dl(Node);
10941
10942 auto LTPredicate = (Opcode == ISD::UCMP ? ISD::SETULT : ISD::SETLT);
10943 auto GTPredicate = (Opcode == ISD::UCMP ? ISD::SETUGT : ISD::SETGT);
10944 SDValue IsLT = DAG.getSetCC(dl, BoolVT, LHS, RHS, LTPredicate);
10945 SDValue IsGT = DAG.getSetCC(dl, BoolVT, LHS, RHS, GTPredicate);
10946
10947 // We can't perform arithmetic on i1 values. Extending them would
10948 // probably result in worse codegen, so let's just use two selects instead.
10949 // Some targets are also just better off using selects rather than subtraction
10950 // because one of the conditions can be merged with one of the selects.
10951 // And finally, if we don't know the contents of the high bits of a boolean
10952 // value, we can't perform any arithmetic either.
10953 if (shouldExpandCmpUsingSelects(VT) || BoolVT.getScalarSizeInBits() == 1 ||
10954 getBooleanContents(BoolVT) == UndefinedBooleanContent) {
10955 SDValue SelectZeroOrOne =
10956 DAG.getSelect(dl, ResVT, IsGT, DAG.getConstant(1, dl, ResVT),
10957 DAG.getConstant(0, dl, ResVT));
10958 return DAG.getSelect(dl, ResVT, IsLT, DAG.getAllOnesConstant(dl, ResVT),
10959 SelectZeroOrOne);
10960 }
10961
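// Otherwise the result is IsGT - IsLT, which is -1, 0 or +1. With
// zero-or-negative-one booleans the setcc results are negated, so swap the
// operands of the subtraction to keep the same mapping.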
10962 if (getBooleanContents(BoolVT) == ZeroOrNegativeOneBooleanContent)
10963 std::swap(IsGT, IsLT);
10964 return DAG.getSExtOrTrunc(DAG.getNode(ISD::SUB, dl, BoolVT, IsGT, IsLT), dl,
10965 ResVT);
10966 }
10967
10968 SDValue TargetLowering::expandShlSat(SDNode *Node, SelectionDAG &DAG) const {
10969 unsigned Opcode = Node->getOpcode();
10970 bool IsSigned = Opcode == ISD::SSHLSAT;
10971 SDValue LHS = Node->getOperand(0);
10972 SDValue RHS = Node->getOperand(1);
10973 EVT VT = LHS.getValueType();
10974 SDLoc dl(Node);
10975
10976 assert((Node->getOpcode() == ISD::SSHLSAT ||
10977 Node->getOpcode() == ISD::USHLSAT) &&
10978 "Expected a SHLSAT opcode");
10979 assert(VT == RHS.getValueType() && "Expected operands to be the same type");
10980 assert(VT.isInteger() && "Expected operands to be integers");
10981
10982 if (VT.isVector() && !isOperationLegalOrCustom(ISD::VSELECT, VT))
10983 return DAG.UnrollVectorOp(Node);
10984
10985 // If LHS != (LHS << RHS) >> RHS, we have overflow and must saturate.
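// For example, ushl.sat i8 100, 2: (100 << 2) wraps to 144 and (144 >> 2) ==
// 36 != 100, so the result saturates to 255.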
10986
10987 unsigned BW = VT.getScalarSizeInBits();
10988 EVT BoolVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
10989 SDValue Result = DAG.getNode(ISD::SHL, dl, VT, LHS, RHS);
10990 SDValue Orig =
10991 DAG.getNode(IsSigned ? ISD::SRA : ISD::SRL, dl, VT, Result, RHS);
10992
10993 SDValue SatVal;
10994 if (IsSigned) {
10995 SDValue SatMin = DAG.getConstant(APInt::getSignedMinValue(BW), dl, VT);
10996 SDValue SatMax = DAG.getConstant(APInt::getSignedMaxValue(BW), dl, VT);
10997 SDValue Cond =
10998 DAG.getSetCC(dl, BoolVT, LHS, DAG.getConstant(0, dl, VT), ISD::SETLT);
10999 SatVal = DAG.getSelect(dl, VT, Cond, SatMin, SatMax);
11000 } else {
11001 SatVal = DAG.getConstant(APInt::getMaxValue(BW), dl, VT);
11002 }
11003 SDValue Cond = DAG.getSetCC(dl, BoolVT, LHS, Orig, ISD::SETNE);
11004 return DAG.getSelect(dl, VT, Cond, SatVal, Result);
11005 }
11006
11007 void TargetLowering::forceExpandMultiply(SelectionDAG &DAG, const SDLoc &dl,
11008 bool Signed, SDValue &Lo, SDValue &Hi,
11009 SDValue LHS, SDValue RHS,
11010 SDValue HiLHS, SDValue HiRHS) const {
11011 EVT VT = LHS.getValueType();
11012 assert(RHS.getValueType() == VT && "Mismatching operand types");
11013
11014 assert((HiLHS && HiRHS) || (!HiLHS && !HiRHS));
11015 assert((!Signed || !HiLHS) &&
11016 "Signed flag should only be set when HiLHS and RiRHS are null");
11017
11018 // We'll expand the multiplication by brute force because we have no other
11019 // options. This is a trivially-generalized version of the code from
11020 // Hacker's Delight (itself derived from Knuth's Algorithm M from section
11021 // 4.3.1). If Signed is set, we can use arithmetic right shifts to propagate
11022 // sign bits while calculating the Hi half.
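// With H = HalfBits, LHS = LH * 2^H + LL and RHS = RH * 2^H + RL, so the full
// product is LL*RL + (LH*RL + LL*RH) * 2^H + LH*RH * 2^(2*H); T, U and V below
// accumulate the partial products while propagating carries between halves.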
11023 unsigned Bits = VT.getSizeInBits();
11024 unsigned HalfBits = Bits / 2;
11025 SDValue Mask = DAG.getConstant(APInt::getLowBitsSet(Bits, HalfBits), dl, VT);
11026 SDValue LL = DAG.getNode(ISD::AND, dl, VT, LHS, Mask);
11027 SDValue RL = DAG.getNode(ISD::AND, dl, VT, RHS, Mask);
11028
11029 SDValue T = DAG.getNode(ISD::MUL, dl, VT, LL, RL);
11030 SDValue TL = DAG.getNode(ISD::AND, dl, VT, T, Mask);
11031
11032 SDValue Shift = DAG.getShiftAmountConstant(HalfBits, VT, dl);
11033 // This is always an unsigned shift.
11034 SDValue TH = DAG.getNode(ISD::SRL, dl, VT, T, Shift);
11035
11036 unsigned ShiftOpc = Signed ? ISD::SRA : ISD::SRL;
11037 SDValue LH = DAG.getNode(ShiftOpc, dl, VT, LHS, Shift);
11038 SDValue RH = DAG.getNode(ShiftOpc, dl, VT, RHS, Shift);
11039
11040 SDValue U =
11041 DAG.getNode(ISD::ADD, dl, VT, DAG.getNode(ISD::MUL, dl, VT, LH, RL), TH);
11042 SDValue UL = DAG.getNode(ISD::AND, dl, VT, U, Mask);
11043 SDValue UH = DAG.getNode(ShiftOpc, dl, VT, U, Shift);
11044
11045 SDValue V =
11046 DAG.getNode(ISD::ADD, dl, VT, DAG.getNode(ISD::MUL, dl, VT, LL, RH), UL);
11047 SDValue VH = DAG.getNode(ShiftOpc, dl, VT, V, Shift);
11048
11049 Lo = DAG.getNode(ISD::ADD, dl, VT, TL,
11050 DAG.getNode(ISD::SHL, dl, VT, V, Shift));
11051
11052 Hi = DAG.getNode(ISD::ADD, dl, VT, DAG.getNode(ISD::MUL, dl, VT, LH, RH),
11053 DAG.getNode(ISD::ADD, dl, VT, UH, VH));
11054
11055 // If HiLHS and HiRHS are set, multiply them by the opposite low part and add
11056 // the products to Hi.
11057 if (HiLHS) {
11058 Hi = DAG.getNode(ISD::ADD, dl, VT, Hi,
11059 DAG.getNode(ISD::ADD, dl, VT,
11060 DAG.getNode(ISD::MUL, dl, VT, HiRHS, LHS),
11061 DAG.getNode(ISD::MUL, dl, VT, RHS, HiLHS)));
11062 }
11063 }
11064
11065 void TargetLowering::forceExpandWideMUL(SelectionDAG &DAG, const SDLoc &dl,
11066 bool Signed, const SDValue LHS,
11067 const SDValue RHS, SDValue &Lo,
11068 SDValue &Hi) const {
11069 EVT VT = LHS.getValueType();
11070 assert(RHS.getValueType() == VT && "Mismatching operand types");
11071 EVT WideVT = EVT::getIntegerVT(*DAG.getContext(), VT.getSizeInBits() * 2);
11072 // We can fall back to a libcall with an illegal type for the MUL if we
11073 // have a libcall big enough.
11074 RTLIB::Libcall LC = RTLIB::UNKNOWN_LIBCALL;
11075 if (WideVT == MVT::i16)
11076 LC = RTLIB::MUL_I16;
11077 else if (WideVT == MVT::i32)
11078 LC = RTLIB::MUL_I32;
11079 else if (WideVT == MVT::i64)
11080 LC = RTLIB::MUL_I64;
11081 else if (WideVT == MVT::i128)
11082 LC = RTLIB::MUL_I128;
11083
11084 if (LC == RTLIB::UNKNOWN_LIBCALL || !getLibcallName(LC)) {
11085 forceExpandMultiply(DAG, dl, Signed, Lo, Hi, LHS, RHS);
11086 return;
11087 }
11088
11089 SDValue HiLHS, HiRHS;
11090 if (Signed) {
11091 // The high part is obtained by SRA'ing all but one of the bits of the low
11092 // part.
11093 unsigned LoSize = VT.getFixedSizeInBits();
11094 SDValue Shift = DAG.getShiftAmountConstant(LoSize - 1, VT, dl);
11095 HiLHS = DAG.getNode(ISD::SRA, dl, VT, LHS, Shift);
11096 HiRHS = DAG.getNode(ISD::SRA, dl, VT, RHS, Shift);
11097 } else {
11098 HiLHS = DAG.getConstant(0, dl, VT);
11099 HiRHS = DAG.getConstant(0, dl, VT);
11100 }
11101
11102 // Attempt a libcall.
11103 SDValue Ret;
11104 TargetLowering::MakeLibCallOptions CallOptions;
11105 CallOptions.setIsSigned(Signed);
11106 CallOptions.setIsPostTypeLegalization(true);
11107 if (shouldSplitFunctionArgumentsAsLittleEndian(DAG.getDataLayout())) {
11108 // Halves of WideVT are packed into registers in different order
11109 // depending on platform endianness. This is usually handled by
11110 // the C calling convention, but we can't defer to it in
11111 // the legalizer.
11112 SDValue Args[] = {LHS, HiLHS, RHS, HiRHS};
11113 Ret = makeLibCall(DAG, LC, WideVT, Args, CallOptions, dl).first;
11114 } else {
11115 SDValue Args[] = {HiLHS, LHS, HiRHS, RHS};
11116 Ret = makeLibCall(DAG, LC, WideVT, Args, CallOptions, dl).first;
11117 }
11118 assert(Ret.getOpcode() == ISD::MERGE_VALUES &&
11119 "Ret value is a collection of constituent nodes holding result.");
11120 if (DAG.getDataLayout().isLittleEndian()) {
11121 // Same as above.
11122 Lo = Ret.getOperand(0);
11123 Hi = Ret.getOperand(1);
11124 } else {
11125 Lo = Ret.getOperand(1);
11126 Hi = Ret.getOperand(0);
11127 }
11128 }
11129
11130 SDValue
11131 TargetLowering::expandFixedPointMul(SDNode *Node, SelectionDAG &DAG) const {
11132 assert((Node->getOpcode() == ISD::SMULFIX ||
11133 Node->getOpcode() == ISD::UMULFIX ||
11134 Node->getOpcode() == ISD::SMULFIXSAT ||
11135 Node->getOpcode() == ISD::UMULFIXSAT) &&
11136 "Expected a fixed point multiplication opcode");
11137
11138 SDLoc dl(Node);
11139 SDValue LHS = Node->getOperand(0);
11140 SDValue RHS = Node->getOperand(1);
11141 EVT VT = LHS.getValueType();
11142 unsigned Scale = Node->getConstantOperandVal(2);
11143 bool Saturating = (Node->getOpcode() == ISD::SMULFIXSAT ||
11144 Node->getOpcode() == ISD::UMULFIXSAT);
11145 bool Signed = (Node->getOpcode() == ISD::SMULFIX ||
11146 Node->getOpcode() == ISD::SMULFIXSAT);
11147 EVT BoolVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
11148 unsigned VTSize = VT.getScalarSizeInBits();
11149
11150 if (!Scale) {
11151 // [us]mul.fix(a, b, 0) -> mul(a, b)
11152 if (!Saturating) {
11153 if (isOperationLegalOrCustom(ISD::MUL, VT))
11154 return DAG.getNode(ISD::MUL, dl, VT, LHS, RHS);
11155 } else if (Signed && isOperationLegalOrCustom(ISD::SMULO, VT)) {
11156 SDValue Result =
11157 DAG.getNode(ISD::SMULO, dl, DAG.getVTList(VT, BoolVT), LHS, RHS);
11158 SDValue Product = Result.getValue(0);
11159 SDValue Overflow = Result.getValue(1);
11160 SDValue Zero = DAG.getConstant(0, dl, VT);
11161
11162 APInt MinVal = APInt::getSignedMinValue(VTSize);
11163 APInt MaxVal = APInt::getSignedMaxValue(VTSize);
11164 SDValue SatMin = DAG.getConstant(MinVal, dl, VT);
11165 SDValue SatMax = DAG.getConstant(MaxVal, dl, VT);
11166 // Xor the inputs; if the resulting sign bit is 0 the product will be
11167 // positive, else negative.
11168 SDValue Xor = DAG.getNode(ISD::XOR, dl, VT, LHS, RHS);
11169 SDValue ProdNeg = DAG.getSetCC(dl, BoolVT, Xor, Zero, ISD::SETLT);
11170 Result = DAG.getSelect(dl, VT, ProdNeg, SatMin, SatMax);
11171 return DAG.getSelect(dl, VT, Overflow, Result, Product);
11172 } else if (!Signed && isOperationLegalOrCustom(ISD::UMULO, VT)) {
11173 SDValue Result =
11174 DAG.getNode(ISD::UMULO, dl, DAG.getVTList(VT, BoolVT), LHS, RHS);
11175 SDValue Product = Result.getValue(0);
11176 SDValue Overflow = Result.getValue(1);
11177
11178 APInt MaxVal = APInt::getMaxValue(VTSize);
11179 SDValue SatMax = DAG.getConstant(MaxVal, dl, VT);
11180 return DAG.getSelect(dl, VT, Overflow, SatMax, Product);
11181 }
11182 }
11183
11184 assert(((Signed && Scale < VTSize) || (!Signed && Scale <= VTSize)) &&
11185 "Expected scale to be less than the number of bits if signed or at "
11186 "most the number of bits if unsigned.");
11187 assert(LHS.getValueType() == RHS.getValueType() &&
11188 "Expected both operands to be the same type");
11189
11190 // Get the upper and lower bits of the result.
11191 SDValue Lo, Hi;
11192 unsigned LoHiOp = Signed ? ISD::SMUL_LOHI : ISD::UMUL_LOHI;
11193 unsigned HiOp = Signed ? ISD::MULHS : ISD::MULHU;
11194 EVT WideVT = EVT::getIntegerVT(*DAG.getContext(), VTSize * 2);
11195 if (VT.isVector())
11196 WideVT =
11197 EVT::getVectorVT(*DAG.getContext(), WideVT, VT.getVectorElementCount());
11198 if (isOperationLegalOrCustom(LoHiOp, VT)) {
11199 SDValue Result = DAG.getNode(LoHiOp, dl, DAG.getVTList(VT, VT), LHS, RHS);
11200 Lo = Result.getValue(0);
11201 Hi = Result.getValue(1);
11202 } else if (isOperationLegalOrCustom(HiOp, VT)) {
11203 Lo = DAG.getNode(ISD::MUL, dl, VT, LHS, RHS);
11204 Hi = DAG.getNode(HiOp, dl, VT, LHS, RHS);
11205 } else if (isOperationLegalOrCustom(ISD::MUL, WideVT)) {
11206 // Try for a multiplication using a wider type.
11207 unsigned Ext = Signed ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
11208 SDValue LHSExt = DAG.getNode(Ext, dl, WideVT, LHS);
11209 SDValue RHSExt = DAG.getNode(Ext, dl, WideVT, RHS);
11210 SDValue Res = DAG.getNode(ISD::MUL, dl, WideVT, LHSExt, RHSExt);
11211 Lo = DAG.getNode(ISD::TRUNCATE, dl, VT, Res);
11212 SDValue Shifted =
11213 DAG.getNode(ISD::SRA, dl, WideVT, Res,
11214 DAG.getShiftAmountConstant(VTSize, WideVT, dl));
11215 Hi = DAG.getNode(ISD::TRUNCATE, dl, VT, Shifted);
11216 } else if (VT.isVector()) {
11217 return SDValue();
11218 } else {
11219 forceExpandWideMUL(DAG, dl, Signed, LHS, RHS, Lo, Hi);
11220 }
11221
11222 if (Scale == VTSize)
11223 // Result is just the top half since we'd be shifting by the width of the
11224 // operand. Overflow impossible so this works for both UMULFIX and
11225 // UMULFIXSAT.
11226 return Hi;
11227
11228 // The result will need to be shifted right by the scale since both operands
11229 // are scaled. The result is given to us in 2 halves, so we only want part of
11230 // both in the result.
11231 SDValue Result = DAG.getNode(ISD::FSHR, dl, VT, Hi, Lo,
11232 DAG.getShiftAmountConstant(Scale, VT, dl));
11233 if (!Saturating)
11234 return Result;
11235
11236 if (!Signed) {
11237 // Unsigned overflow happened if the upper (VTSize - Scale) bits (of the
11238 // widened multiplication) aren't all zeroes.
11239
11240 // Saturate to max if ((Hi >> Scale) != 0),
11241 // which is the same as if (Hi > ((1 << Scale) - 1))
11242 APInt MaxVal = APInt::getMaxValue(VTSize);
11243 SDValue LowMask = DAG.getConstant(APInt::getLowBitsSet(VTSize, Scale),
11244 dl, VT);
11245 Result = DAG.getSelectCC(dl, Hi, LowMask,
11246 DAG.getConstant(MaxVal, dl, VT), Result,
11247 ISD::SETUGT);
11248
11249 return Result;
11250 }
11251
11252 // Signed overflow happened if the upper (VTSize - Scale + 1) bits (of the
11253 // widened multiplication) aren't all ones or all zeroes.
11254
11255 SDValue SatMin = DAG.getConstant(APInt::getSignedMinValue(VTSize), dl, VT);
11256 SDValue SatMax = DAG.getConstant(APInt::getSignedMaxValue(VTSize), dl, VT);
11257
11258 if (Scale == 0) {
11259 SDValue Sign = DAG.getNode(ISD::SRA, dl, VT, Lo,
11260 DAG.getShiftAmountConstant(VTSize - 1, VT, dl));
11261 SDValue Overflow = DAG.getSetCC(dl, BoolVT, Hi, Sign, ISD::SETNE);
11262 // Saturated to SatMin if wide product is negative, and SatMax if wide
11263 // product is positive ...
11264 SDValue Zero = DAG.getConstant(0, dl, VT);
11265 SDValue ResultIfOverflow = DAG.getSelectCC(dl, Hi, Zero, SatMin, SatMax,
11266 ISD::SETLT);
11267 // ... but only if we overflowed.
11268 return DAG.getSelect(dl, VT, Overflow, ResultIfOverflow, Result);
11269 }
11270
11271 // We handled Scale==0 above so all the bits to examine are in Hi.
11272
11273 // Saturate to max if ((Hi >> (Scale - 1)) > 0),
11274 // which is the same as if (Hi > (1 << (Scale - 1)) - 1)
11275 SDValue LowMask = DAG.getConstant(APInt::getLowBitsSet(VTSize, Scale - 1),
11276 dl, VT);
11277 Result = DAG.getSelectCC(dl, Hi, LowMask, SatMax, Result, ISD::SETGT);
11278 // Saturate to min if ((Hi >> (Scale - 1)) < -1),
11279 // which is the same as if (Hi < (-1 << (Scale - 1))).
11280 SDValue HighMask =
11281 DAG.getConstant(APInt::getHighBitsSet(VTSize, VTSize - Scale + 1),
11282 dl, VT);
11283 Result = DAG.getSelectCC(dl, Hi, HighMask, SatMin, Result, ISD::SETLT);
11284 return Result;
11285 }
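
// Editorial illustration (not part of the original source): a minimal scalar
// model of the expansion above for an 8-bit Q4.4 format (Scale == 4). The
// helper name and fixed-width types are assumptions of the sketch only.
//
//   int8_t smulfix_q44(int8_t A, int8_t B) {
//     int16_t Wide = (int16_t)A * (int16_t)B;  // SMUL_LOHI / widened MUL
//     return (int8_t)(Wide >> 4);              // FSHR(Hi, Lo, Scale)
//   }
//
//   // smulfix_q44(0x28 /*2.5*/, 0x18 /*1.5*/) == 0x3C /*3.75*/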
11286
11287 SDValue
11288 TargetLowering::expandFixedPointDiv(unsigned Opcode, const SDLoc &dl,
11289 SDValue LHS, SDValue RHS,
11290 unsigned Scale, SelectionDAG &DAG) const {
11291 assert((Opcode == ISD::SDIVFIX || Opcode == ISD::SDIVFIXSAT ||
11292 Opcode == ISD::UDIVFIX || Opcode == ISD::UDIVFIXSAT) &&
11293 "Expected a fixed point division opcode");
11294
11295 EVT VT = LHS.getValueType();
11296 bool Signed = Opcode == ISD::SDIVFIX || Opcode == ISD::SDIVFIXSAT;
11297 bool Saturating = Opcode == ISD::SDIVFIXSAT || Opcode == ISD::UDIVFIXSAT;
11298 EVT BoolVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
11299
11300 // If there is enough room in the type to upscale the LHS or downscale the
11301 // RHS before the division, we can perform it in this type without having to
11302 // resize. For signed operations, the LHS headroom is the number of
11303 // redundant sign bits, and for unsigned ones it is the number of zeroes.
11304 // The headroom for the RHS is the number of trailing zeroes.
11305 unsigned LHSLead = Signed ? DAG.ComputeNumSignBits(LHS) - 1
11306 : DAG.computeKnownBits(LHS).countMinLeadingZeros();
11307 unsigned RHSTrail = DAG.computeKnownBits(RHS).countMinTrailingZeros();
11308
11309 // For signed saturating operations, we need to be able to detect true integer
11310 // division overflow; that is, when you have MIN / -EPS. However, this
11311 // is undefined behavior and if we emit divisions that could take such
11312 // values it may cause undesired behavior (arithmetic exceptions on x86, for
11313 // example).
11314 // Avoid this by requiring an extra bit so that we never get this case.
11315 // FIXME: This is a bit unfortunate as it means that for an 8-bit 7-scale
11316 // signed saturating division, we need to emit a whopping 32-bit division.
11317 if (LHSLead + RHSTrail < Scale + (unsigned)(Saturating && Signed))
11318 return SDValue();
11319
11320 unsigned LHSShift = std::min(LHSLead, Scale);
11321 unsigned RHSShift = Scale - LHSShift;
11322
11323 // At this point, we know that if we shift the LHS up by LHSShift and the
11324 // RHS down by RHSShift, we can emit a regular division with a final scaling
11325 // factor of Scale.
11326
11327 if (LHSShift)
11328 LHS = DAG.getNode(ISD::SHL, dl, VT, LHS,
11329 DAG.getShiftAmountConstant(LHSShift, VT, dl));
11330 if (RHSShift)
11331 RHS = DAG.getNode(Signed ? ISD::SRA : ISD::SRL, dl, VT, RHS,
11332 DAG.getShiftAmountConstant(RHSShift, VT, dl));
11333
11334 SDValue Quot;
11335 if (Signed) {
11336 // For signed operations, if the resulting quotient is negative and the
11337 // remainder is nonzero, subtract 1 from the quotient to round towards
11338 // negative infinity.
11339 SDValue Rem;
11340 // FIXME: Ideally we would always produce an SDIVREM here, but if the
11341 // type isn't legal, SDIVREM cannot be expanded. There is no reason why
11342 // we couldn't just form a libcall, but the type legalizer doesn't do it.
11343 if (isTypeLegal(VT) &&
11344 isOperationLegalOrCustom(ISD::SDIVREM, VT)) {
11345 Quot = DAG.getNode(ISD::SDIVREM, dl,
11346 DAG.getVTList(VT, VT),
11347 LHS, RHS);
11348 Rem = Quot.getValue(1);
11349 Quot = Quot.getValue(0);
11350 } else {
11351 Quot = DAG.getNode(ISD::SDIV, dl, VT,
11352 LHS, RHS);
11353 Rem = DAG.getNode(ISD::SREM, dl, VT,
11354 LHS, RHS);
11355 }
11356 SDValue Zero = DAG.getConstant(0, dl, VT);
11357 SDValue RemNonZero = DAG.getSetCC(dl, BoolVT, Rem, Zero, ISD::SETNE);
11358 SDValue LHSNeg = DAG.getSetCC(dl, BoolVT, LHS, Zero, ISD::SETLT);
11359 SDValue RHSNeg = DAG.getSetCC(dl, BoolVT, RHS, Zero, ISD::SETLT);
11360 SDValue QuotNeg = DAG.getNode(ISD::XOR, dl, BoolVT, LHSNeg, RHSNeg);
11361 SDValue Sub1 = DAG.getNode(ISD::SUB, dl, VT, Quot,
11362 DAG.getConstant(1, dl, VT));
11363 Quot = DAG.getSelect(dl, VT,
11364 DAG.getNode(ISD::AND, dl, BoolVT, RemNonZero, QuotNeg),
11365 Sub1, Quot);
11366 } else
11367 Quot = DAG.getNode(ISD::UDIV, dl, VT,
11368 LHS, RHS);
11369
11370 return Quot;
11371 }
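
// Editorial illustration (not part of the original source): when the dividend
// has at least Scale bits of headroom, the fixed-point division reduces to an
// upscaled integer division. A scalar Q4.4 sketch, relying on C's integer
// promotion where the DAG expansion must instead prove the headroom above:
//
//   uint8_t udivfix_q44(uint8_t A, uint8_t B) {
//     return (uint8_t)((A << 4) / B);  // shift LHS up by Scale, then plain UDIV
//   }
//
//   // udivfix_q44(0x30 /*3.0*/, 0x18 /*1.5*/) == 0x20 /*2.0*/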
11372
11373 void TargetLowering::expandUADDSUBO(
11374 SDNode *Node, SDValue &Result, SDValue &Overflow, SelectionDAG &DAG) const {
11375 SDLoc dl(Node);
11376 SDValue LHS = Node->getOperand(0);
11377 SDValue RHS = Node->getOperand(1);
11378 bool IsAdd = Node->getOpcode() == ISD::UADDO;
11379
11380 // If UADDO_CARRY/SUBO_CARRY is legal, use that instead.
11381 unsigned OpcCarry = IsAdd ? ISD::UADDO_CARRY : ISD::USUBO_CARRY;
11382 if (isOperationLegalOrCustom(OpcCarry, Node->getValueType(0))) {
11383 SDValue CarryIn = DAG.getConstant(0, dl, Node->getValueType(1));
11384 SDValue NodeCarry = DAG.getNode(OpcCarry, dl, Node->getVTList(),
11385 { LHS, RHS, CarryIn });
11386 Result = SDValue(NodeCarry.getNode(), 0);
11387 Overflow = SDValue(NodeCarry.getNode(), 1);
11388 return;
11389 }
11390
11391 Result = DAG.getNode(IsAdd ? ISD::ADD : ISD::SUB, dl,
11392 LHS.getValueType(), LHS, RHS);
11393
11394 EVT ResultType = Node->getValueType(1);
11395 EVT SetCCType = getSetCCResultType(
11396 DAG.getDataLayout(), *DAG.getContext(), Node->getValueType(0));
11397 SDValue SetCC;
11398 if (IsAdd && isOneConstant(RHS)) {
11399 // Special case: uaddo X, 1 overflowed if X+1 is 0. This potentially
11400 // reduces the live range of X. We assume comparing with 0 is cheap.
11401 // The general case (X + C) < C is not necessarily beneficial. Although we
11402 // reduce the live range of X, we may introduce the materialization of
11403 // constant C.
11404 SetCC =
11405 DAG.getSetCC(dl, SetCCType, Result,
11406 DAG.getConstant(0, dl, Node->getValueType(0)), ISD::SETEQ);
11407 } else if (IsAdd && isAllOnesConstant(RHS)) {
11408 // Special case: uaddo X, -1 overflows if X != 0.
11409 SetCC =
11410 DAG.getSetCC(dl, SetCCType, LHS,
11411 DAG.getConstant(0, dl, Node->getValueType(0)), ISD::SETNE);
11412 } else {
11413 ISD::CondCode CC = IsAdd ? ISD::SETULT : ISD::SETUGT;
11414 SetCC = DAG.getSetCC(dl, SetCCType, Result, LHS, CC);
11415 }
11416 Overflow = DAG.getBoolExtOrTrunc(SetCC, dl, ResultType, ResultType);
11417 }
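
// Editorial illustration (not part of the original source): the generic path
// above detects unsigned overflow purely from the wrapped result. Scalar
// sketch; the helper name is an assumption of the example:
//
//   bool uaddo(unsigned X, unsigned Y, unsigned &Sum) {
//     Sum = X + Y;
//     return Sum < X;   // SETULT(Result, LHS); for USUBO use Sum > X instead
//   }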
11418
11419 void TargetLowering::expandSADDSUBO(
11420 SDNode *Node, SDValue &Result, SDValue &Overflow, SelectionDAG &DAG) const {
11421 SDLoc dl(Node);
11422 SDValue LHS = Node->getOperand(0);
11423 SDValue RHS = Node->getOperand(1);
11424 bool IsAdd = Node->getOpcode() == ISD::SADDO;
11425
11426 Result = DAG.getNode(IsAdd ? ISD::ADD : ISD::SUB, dl,
11427 LHS.getValueType(), LHS, RHS);
11428
11429 EVT ResultType = Node->getValueType(1);
11430 EVT OType = getSetCCResultType(
11431 DAG.getDataLayout(), *DAG.getContext(), Node->getValueType(0));
11432
11433 // If SADDSAT/SSUBSAT is legal, compare results to detect overflow.
11434 unsigned OpcSat = IsAdd ? ISD::SADDSAT : ISD::SSUBSAT;
11435 if (isOperationLegal(OpcSat, LHS.getValueType())) {
11436 SDValue Sat = DAG.getNode(OpcSat, dl, LHS.getValueType(), LHS, RHS);
11437 SDValue SetCC = DAG.getSetCC(dl, OType, Result, Sat, ISD::SETNE);
11438 Overflow = DAG.getBoolExtOrTrunc(SetCC, dl, ResultType, ResultType);
11439 return;
11440 }
11441
11442 SDValue Zero = DAG.getConstant(0, dl, LHS.getValueType());
11443
11444 // For an addition, the result should be less than one of the operands (LHS)
11445 // if and only if the other operand (RHS) is negative, otherwise there will
11446 // be overflow.
11447 // For a subtraction, the result should be less than one of the operands
11448 // (LHS) if and only if the other operand (RHS) is (non-zero) positive,
11449 // otherwise there will be overflow.
11450 SDValue ResultLowerThanLHS = DAG.getSetCC(dl, OType, Result, LHS, ISD::SETLT);
11451 SDValue ConditionRHS =
11452 DAG.getSetCC(dl, OType, RHS, Zero, IsAdd ? ISD::SETLT : ISD::SETGT);
11453
11454 Overflow = DAG.getBoolExtOrTrunc(
11455 DAG.getNode(ISD::XOR, dl, OType, ConditionRHS, ResultLowerThanLHS), dl,
11456 ResultType, ResultType);
11457 }
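
// Editorial illustration (not part of the original source): a scalar model of
// the XOR-of-conditions trick above, assuming two's-complement wrap-around
// (the unsigned casts stand in for the non-trapping ISD::ADD):
//
//   bool saddo(int X, int Y, int &Sum) {
//     Sum = (int)((unsigned)X + (unsigned)Y);  // wrapping add
//     return (Y < 0) != (Sum < X);             // ConditionRHS ^ ResultLowerThanLHS
//   }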
11458
11459 bool TargetLowering::expandMULO(SDNode *Node, SDValue &Result,
11460 SDValue &Overflow, SelectionDAG &DAG) const {
11461 SDLoc dl(Node);
11462 EVT VT = Node->getValueType(0);
11463 EVT SetCCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
11464 SDValue LHS = Node->getOperand(0);
11465 SDValue RHS = Node->getOperand(1);
11466 bool isSigned = Node->getOpcode() == ISD::SMULO;
11467
11468 // For power-of-two multiplications we can use a simpler shift expansion.
11469 if (ConstantSDNode *RHSC = isConstOrConstSplat(RHS)) {
11470 const APInt &C = RHSC->getAPIntValue();
11471 // mulo(X, 1 << S) -> { X << S, (X << S) >> S != X }
11472 if (C.isPowerOf2()) {
11473 // smulo(x, signed_min) is same as umulo(x, signed_min).
11474 bool UseArithShift = isSigned && !C.isMinSignedValue();
11475 SDValue ShiftAmt = DAG.getShiftAmountConstant(C.logBase2(), VT, dl);
11476 Result = DAG.getNode(ISD::SHL, dl, VT, LHS, ShiftAmt);
11477 Overflow = DAG.getSetCC(dl, SetCCVT,
11478 DAG.getNode(UseArithShift ? ISD::SRA : ISD::SRL,
11479 dl, VT, Result, ShiftAmt),
11480 LHS, ISD::SETNE);
11481 return true;
11482 }
11483 }
11484
11485 EVT WideVT = EVT::getIntegerVT(*DAG.getContext(), VT.getScalarSizeInBits() * 2);
11486 if (VT.isVector())
11487 WideVT =
11488 EVT::getVectorVT(*DAG.getContext(), WideVT, VT.getVectorElementCount());
11489
11490 SDValue BottomHalf;
11491 SDValue TopHalf;
11492 static const unsigned Ops[2][3] =
11493 { { ISD::MULHU, ISD::UMUL_LOHI, ISD::ZERO_EXTEND },
11494 { ISD::MULHS, ISD::SMUL_LOHI, ISD::SIGN_EXTEND }};
11495 if (isOperationLegalOrCustom(Ops[isSigned][0], VT)) {
11496 BottomHalf = DAG.getNode(ISD::MUL, dl, VT, LHS, RHS);
11497 TopHalf = DAG.getNode(Ops[isSigned][0], dl, VT, LHS, RHS);
11498 } else if (isOperationLegalOrCustom(Ops[isSigned][1], VT)) {
11499 BottomHalf = DAG.getNode(Ops[isSigned][1], dl, DAG.getVTList(VT, VT), LHS,
11500 RHS);
11501 TopHalf = BottomHalf.getValue(1);
11502 } else if (isTypeLegal(WideVT)) {
11503 LHS = DAG.getNode(Ops[isSigned][2], dl, WideVT, LHS);
11504 RHS = DAG.getNode(Ops[isSigned][2], dl, WideVT, RHS);
11505 SDValue Mul = DAG.getNode(ISD::MUL, dl, WideVT, LHS, RHS);
11506 BottomHalf = DAG.getNode(ISD::TRUNCATE, dl, VT, Mul);
11507 SDValue ShiftAmt =
11508 DAG.getShiftAmountConstant(VT.getScalarSizeInBits(), WideVT, dl);
11509 TopHalf = DAG.getNode(ISD::TRUNCATE, dl, VT,
11510 DAG.getNode(ISD::SRL, dl, WideVT, Mul, ShiftAmt));
11511 } else {
11512 if (VT.isVector())
11513 return false;
11514
11515 forceExpandWideMUL(DAG, dl, isSigned, LHS, RHS, BottomHalf, TopHalf);
11516 }
11517
11518 Result = BottomHalf;
11519 if (isSigned) {
11520 SDValue ShiftAmt = DAG.getShiftAmountConstant(
11521 VT.getScalarSizeInBits() - 1, BottomHalf.getValueType(), dl);
11522 SDValue Sign = DAG.getNode(ISD::SRA, dl, VT, BottomHalf, ShiftAmt);
11523 Overflow = DAG.getSetCC(dl, SetCCVT, TopHalf, Sign, ISD::SETNE);
11524 } else {
11525 Overflow = DAG.getSetCC(dl, SetCCVT, TopHalf,
11526 DAG.getConstant(0, dl, VT), ISD::SETNE);
11527 }
11528
11529 // Truncate the result if SetCC returns a larger type than needed.
11530 EVT RType = Node->getValueType(1);
11531 if (RType.bitsLT(Overflow.getValueType()))
11532 Overflow = DAG.getNode(ISD::TRUNCATE, dl, RType, Overflow);
11533
11534 assert(RType.getSizeInBits() == Overflow.getValueSizeInBits() &&
11535 "Unexpected result type for S/UMULO legalization");
11536 return true;
11537 }
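
// Editorial illustration (not part of the original source): the widened
// multiply path above in scalar form, assuming 32-bit unsigned and 64-bit
// unsigned long long:
//
//   bool umulo32(unsigned X, unsigned Y, unsigned &Prod) {
//     unsigned long long Wide = (unsigned long long)X * Y;  // MUL in WideVT
//     Prod = (unsigned)Wide;                                 // BottomHalf
//     return (Wide >> 32) != 0;                              // TopHalf != 0
//   }
//
// For SMULO the check is instead that the top half equals the sign-extension
// of the bottom half, i.e. TopHalf != (BottomHalf arithmetically >> 31).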
11538
11539 SDValue TargetLowering::expandVecReduce(SDNode *Node, SelectionDAG &DAG) const {
11540 SDLoc dl(Node);
11541 unsigned BaseOpcode = ISD::getVecReduceBaseOpcode(Node->getOpcode());
11542 SDValue Op = Node->getOperand(0);
11543 EVT VT = Op.getValueType();
11544
11545 // Try to use a shuffle reduction for power of two vectors.
11546 if (VT.isPow2VectorType()) {
11547 while (VT.getVectorElementCount().isKnownMultipleOf(2)) {
11548 EVT HalfVT = VT.getHalfNumVectorElementsVT(*DAG.getContext());
11549 if (!isOperationLegalOrCustom(BaseOpcode, HalfVT))
11550 break;
11551
11552 SDValue Lo, Hi;
11553 std::tie(Lo, Hi) = DAG.SplitVector(Op, dl);
11554 Op = DAG.getNode(BaseOpcode, dl, HalfVT, Lo, Hi, Node->getFlags());
11555 VT = HalfVT;
11556
11557 // Stop if splitting is enough to make the reduction legal.
11558 if (isOperationLegalOrCustom(Node->getOpcode(), HalfVT))
11559 return DAG.getNode(Node->getOpcode(), dl, Node->getValueType(0), Op,
11560 Node->getFlags());
11561 }
11562 }
11563
11564 if (VT.isScalableVector())
11565 reportFatalInternalError(
11566 "Expanding reductions for scalable vectors is undefined.");
11567
11568 EVT EltVT = VT.getVectorElementType();
11569 unsigned NumElts = VT.getVectorNumElements();
11570
11571 SmallVector<SDValue, 8> Ops;
11572 DAG.ExtractVectorElements(Op, Ops, 0, NumElts);
11573
11574 SDValue Res = Ops[0];
11575 for (unsigned i = 1; i < NumElts; i++)
11576 Res = DAG.getNode(BaseOpcode, dl, EltVT, Res, Ops[i], Node->getFlags());
11577
11578 // Result type may be wider than element type.
11579 if (EltVT != Node->getValueType(0))
11580 Res = DAG.getNode(ISD::ANY_EXTEND, dl, Node->getValueType(0), Res);
11581 return Res;
11582 }
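
// Editorial illustration (not part of the original source): the splitting
// loop above turns a reduction of <4 x i32> into one vector add of the two
// halves plus a scalar combine. A scalar model (for FP opcodes this reorders
// the operations, which is why the node's fast-math flags are propagated):
//
//   int reduce_add_v4(const int V[4]) {
//     int Half[2] = {V[0] + V[2], V[1] + V[3]}; // BaseOpcode on the two halves
//     return Half[0] + Half[1];                 // final elementwise chain
//   }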
11583
11584 SDValue TargetLowering::expandVecReduceSeq(SDNode *Node, SelectionDAG &DAG) const {
11585 SDLoc dl(Node);
11586 SDValue AccOp = Node->getOperand(0);
11587 SDValue VecOp = Node->getOperand(1);
11588 SDNodeFlags Flags = Node->getFlags();
11589
11590 EVT VT = VecOp.getValueType();
11591 EVT EltVT = VT.getVectorElementType();
11592
11593 if (VT.isScalableVector())
11594 report_fatal_error(
11595 "Expanding reductions for scalable vectors is undefined.");
11596
11597 unsigned NumElts = VT.getVectorNumElements();
11598
11599 SmallVector<SDValue, 8> Ops;
11600 DAG.ExtractVectorElements(VecOp, Ops, 0, NumElts);
11601
11602 unsigned BaseOpcode = ISD::getVecReduceBaseOpcode(Node->getOpcode());
11603
11604 SDValue Res = AccOp;
11605 for (unsigned i = 0; i < NumElts; i++)
11606 Res = DAG.getNode(BaseOpcode, dl, EltVT, Res, Ops[i], Flags);
11607
11608 return Res;
11609 }
11610
11611 bool TargetLowering::expandREM(SDNode *Node, SDValue &Result,
11612 SelectionDAG &DAG) const {
11613 EVT VT = Node->getValueType(0);
11614 SDLoc dl(Node);
11615 bool isSigned = Node->getOpcode() == ISD::SREM;
11616 unsigned DivOpc = isSigned ? ISD::SDIV : ISD::UDIV;
11617 unsigned DivRemOpc = isSigned ? ISD::SDIVREM : ISD::UDIVREM;
11618 SDValue Dividend = Node->getOperand(0);
11619 SDValue Divisor = Node->getOperand(1);
11620 if (isOperationLegalOrCustom(DivRemOpc, VT)) {
11621 SDVTList VTs = DAG.getVTList(VT, VT);
11622 Result = DAG.getNode(DivRemOpc, dl, VTs, Dividend, Divisor).getValue(1);
11623 return true;
11624 }
11625 if (isOperationLegalOrCustom(DivOpc, VT)) {
11626 // X % Y -> X-X/Y*Y
11627 SDValue Divide = DAG.getNode(DivOpc, dl, VT, Dividend, Divisor);
11628 SDValue Mul = DAG.getNode(ISD::MUL, dl, VT, Divide, Divisor);
11629 Result = DAG.getNode(ISD::SUB, dl, VT, Dividend, Mul);
11630 return true;
11631 }
11632 return false;
11633 }
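
// Editorial illustration (not part of the original source): the fallback
// above recovers the remainder from the quotient. Scalar form:
//
//   unsigned urem_expand(unsigned X, unsigned Y) {
//     return X - (X / Y) * Y;   // X % Y when only [SU]DIV is legal
//   }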
11634
11635 SDValue TargetLowering::expandFP_TO_INT_SAT(SDNode *Node,
11636 SelectionDAG &DAG) const {
11637 bool IsSigned = Node->getOpcode() == ISD::FP_TO_SINT_SAT;
11638 SDLoc dl(SDValue(Node, 0));
11639 SDValue Src = Node->getOperand(0);
11640
11641 // DstVT is the result type, while SatVT is the size to which we saturate
11642 EVT SrcVT = Src.getValueType();
11643 EVT DstVT = Node->getValueType(0);
11644
11645 EVT SatVT = cast<VTSDNode>(Node->getOperand(1))->getVT();
11646 unsigned SatWidth = SatVT.getScalarSizeInBits();
11647 unsigned DstWidth = DstVT.getScalarSizeInBits();
11648 assert(SatWidth <= DstWidth &&
11649 "Expected saturation width smaller than result width");
11650
11651 // Determine minimum and maximum integer values and their corresponding
11652 // floating-point values.
11653 APInt MinInt, MaxInt;
11654 if (IsSigned) {
11655 MinInt = APInt::getSignedMinValue(SatWidth).sext(DstWidth);
11656 MaxInt = APInt::getSignedMaxValue(SatWidth).sext(DstWidth);
11657 } else {
11658 MinInt = APInt::getMinValue(SatWidth).zext(DstWidth);
11659 MaxInt = APInt::getMaxValue(SatWidth).zext(DstWidth);
11660 }
11661
11662 // We cannot risk emitting FP_TO_XINT nodes with a source VT of [b]f16, as
11663 // libcall emission cannot handle this. Large result types will fail.
11664 if (SrcVT == MVT::f16 || SrcVT == MVT::bf16) {
11665 Src = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f32, Src);
11666 SrcVT = Src.getValueType();
11667 }
11668
11669 const fltSemantics &Sem = SrcVT.getFltSemantics();
11670 APFloat MinFloat(Sem);
11671 APFloat MaxFloat(Sem);
11672
11673 APFloat::opStatus MinStatus =
11674 MinFloat.convertFromAPInt(MinInt, IsSigned, APFloat::rmTowardZero);
11675 APFloat::opStatus MaxStatus =
11676 MaxFloat.convertFromAPInt(MaxInt, IsSigned, APFloat::rmTowardZero);
11677 bool AreExactFloatBounds = !(MinStatus & APFloat::opStatus::opInexact) &&
11678 !(MaxStatus & APFloat::opStatus::opInexact);
11679
11680 SDValue MinFloatNode = DAG.getConstantFP(MinFloat, dl, SrcVT);
11681 SDValue MaxFloatNode = DAG.getConstantFP(MaxFloat, dl, SrcVT);
11682
11683 // If the integer bounds are exactly representable as floats and min/max are
11684 // legal, emit a min+max+fptoi sequence. Otherwise we have to use a sequence
11685 // of comparisons and selects.
11686 bool MinMaxLegal = isOperationLegal(ISD::FMINNUM, SrcVT) &&
11687 isOperationLegal(ISD::FMAXNUM, SrcVT);
11688 if (AreExactFloatBounds && MinMaxLegal) {
11689 SDValue Clamped = Src;
11690
11691 // Clamp Src by MinFloat from below. If Src is NaN the result is MinFloat.
11692 Clamped = DAG.getNode(ISD::FMAXNUM, dl, SrcVT, Clamped, MinFloatNode);
11693 // Clamp by MaxFloat from above. NaN cannot occur.
11694 Clamped = DAG.getNode(ISD::FMINNUM, dl, SrcVT, Clamped, MaxFloatNode);
11695 // Convert clamped value to integer.
11696 SDValue FpToInt = DAG.getNode(IsSigned ? ISD::FP_TO_SINT : ISD::FP_TO_UINT,
11697 dl, DstVT, Clamped);
11698
11699 // In the unsigned case we're done, because we mapped NaN to MinFloat,
11700 // which will cast to zero.
11701 if (!IsSigned)
11702 return FpToInt;
11703
11704 // Otherwise, select 0 if Src is NaN.
11705 SDValue ZeroInt = DAG.getConstant(0, dl, DstVT);
11706 EVT SetCCVT =
11707 getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), SrcVT);
11708 SDValue IsNan = DAG.getSetCC(dl, SetCCVT, Src, Src, ISD::CondCode::SETUO);
11709 return DAG.getSelect(dl, DstVT, IsNan, ZeroInt, FpToInt);
11710 }
11711
11712 SDValue MinIntNode = DAG.getConstant(MinInt, dl, DstVT);
11713 SDValue MaxIntNode = DAG.getConstant(MaxInt, dl, DstVT);
11714
11715 // Result of direct conversion. The assumption here is that the operation is
11716 // non-trapping and it's fine to apply it to an out-of-range value if we
11717 // select it away later.
11718 SDValue FpToInt =
11719 DAG.getNode(IsSigned ? ISD::FP_TO_SINT : ISD::FP_TO_UINT, dl, DstVT, Src);
11720
11721 SDValue Select = FpToInt;
11722
11723 EVT SetCCVT =
11724 getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), SrcVT);
11725
11726 // If Src ULT MinFloat, select MinInt. In particular, this also selects
11727 // MinInt if Src is NaN.
11728 SDValue ULT = DAG.getSetCC(dl, SetCCVT, Src, MinFloatNode, ISD::SETULT);
11729 Select = DAG.getSelect(dl, DstVT, ULT, MinIntNode, Select);
11730 // If Src OGT MaxFloat, select MaxInt.
11731 SDValue OGT = DAG.getSetCC(dl, SetCCVT, Src, MaxFloatNode, ISD::SETOGT);
11732 Select = DAG.getSelect(dl, DstVT, OGT, MaxIntNode, Select);
11733
11734 // In the unsigned case we are done, because we mapped NaN to MinInt, which
11735 // is already zero.
11736 if (!IsSigned)
11737 return Select;
11738
11739 // Otherwise, select 0 if Src is NaN.
11740 SDValue ZeroInt = DAG.getConstant(0, dl, DstVT);
11741 SDValue IsNan = DAG.getSetCC(dl, SetCCVT, Src, Src, ISD::CondCode::SETUO);
11742 return DAG.getSelect(dl, DstVT, IsNan, ZeroInt, Select);
11743 }
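
// Editorial illustration (not part of the original source): the compare-and-
// select path above, written as a scalar f32 -> signed i8 saturating
// conversion (int8_t assumed from <cstdint>):
//
//   int8_t fptosi_sat_i8(float X) {
//     if (X != X) return 0;             // NaN -> 0 (the SETUO select)
//     if (X < -128.0f) return -128;     // SETULT MinFloat -> MinInt
//     if (X > 127.0f) return 127;       // SETOGT MaxFloat -> MaxInt
//     return (int8_t)X;                 // in range: plain FP_TO_SINT
//   }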
11744
11745 SDValue TargetLowering::expandRoundInexactToOdd(EVT ResultVT, SDValue Op,
11746 const SDLoc &dl,
11747 SelectionDAG &DAG) const {
11748 EVT OperandVT = Op.getValueType();
11749 if (OperandVT.getScalarType() == ResultVT.getScalarType())
11750 return Op;
11751 EVT ResultIntVT = ResultVT.changeTypeToInteger();
11752 // We are rounding binary64/binary128 -> binary32 -> bfloat16. This
11753 // can induce double-rounding which may alter the results. We can
11754 // correct for this using a trick explained in: Boldo, Sylvie, and
11755 // Guillaume Melquiond. "When double rounding is odd." 17th IMACS
11756 // World Congress. 2005.
11757 SDValue Narrow = DAG.getFPExtendOrRound(Op, dl, ResultVT);
11758 SDValue NarrowAsWide = DAG.getFPExtendOrRound(Narrow, dl, OperandVT);
11759
11760 // We can keep the narrow value as-is if narrowing was exact (no
11761 // rounding error), the wide value was NaN (the narrow value is also
11762 // NaN and should be preserved) or if we rounded to the odd value.
11763 SDValue NarrowBits = DAG.getNode(ISD::BITCAST, dl, ResultIntVT, Narrow);
11764 SDValue One = DAG.getConstant(1, dl, ResultIntVT);
11765 SDValue NegativeOne = DAG.getAllOnesConstant(dl, ResultIntVT);
11766 SDValue And = DAG.getNode(ISD::AND, dl, ResultIntVT, NarrowBits, One);
11767 EVT ResultIntVTCCVT = getSetCCResultType(
11768 DAG.getDataLayout(), *DAG.getContext(), And.getValueType());
11769 SDValue Zero = DAG.getConstant(0, dl, ResultIntVT);
11770 // The result is already odd so we don't need to do anything.
11771 SDValue AlreadyOdd = DAG.getSetCC(dl, ResultIntVTCCVT, And, Zero, ISD::SETNE);
11772
11773 EVT WideSetCCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(),
11774 Op.getValueType());
11775 // We keep results which are exact, odd or NaN.
11776 SDValue KeepNarrow =
11777 DAG.getSetCC(dl, WideSetCCVT, Op, NarrowAsWide, ISD::SETUEQ);
11778 KeepNarrow = DAG.getNode(ISD::OR, dl, WideSetCCVT, KeepNarrow, AlreadyOdd);
11779 // We morally performed a round-down if AbsNarrow is smaller than
11780 // AbsWide.
11781 SDValue AbsWide = DAG.getNode(ISD::FABS, dl, OperandVT, Op);
11782 SDValue AbsNarrowAsWide = DAG.getNode(ISD::FABS, dl, OperandVT, NarrowAsWide);
11783 SDValue NarrowIsRd =
11784 DAG.getSetCC(dl, WideSetCCVT, AbsWide, AbsNarrowAsWide, ISD::SETOGT);
11785 // If the narrow value is odd or exact, pick it.
11786 // Otherwise, narrow is even and corresponds to either the rounded-up
11787 // or rounded-down value. If narrow is the rounded-down value, we want
11788 // the rounded-up value as it will be odd.
11789 SDValue Adjust = DAG.getSelect(dl, ResultIntVT, NarrowIsRd, One, NegativeOne);
11790 SDValue Adjusted = DAG.getNode(ISD::ADD, dl, ResultIntVT, NarrowBits, Adjust);
11791 Op = DAG.getSelect(dl, ResultIntVT, KeepNarrow, NarrowBits, Adjusted);
11792 return DAG.getNode(ISD::BITCAST, dl, ResultVT, Op);
11793 }
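
// Editorial note (not part of the original source): the adjustment above
// implements "round to odd" on the intermediate value. For finite IEEE-754
// encodings, adding 1 to the bit pattern increases the magnitude, so the
// select can be summarised as:
//
//   // If narrowing was inexact and produced an even significand, step to the
//   // neighbouring odd value in the direction of the wide input.
//   Bits += (|Wide| > |NarrowAsWide|) ? +1 : -1;
//
// The odd low bit then acts as a sticky bit for the second rounding step,
// avoiding the double-rounding error described in the comment above.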
11794
11795 SDValue TargetLowering::expandFP_ROUND(SDNode *Node, SelectionDAG &DAG) const {
11796 assert(Node->getOpcode() == ISD::FP_ROUND && "Unexpected opcode!");
11797 SDValue Op = Node->getOperand(0);
11798 EVT VT = Node->getValueType(0);
11799 SDLoc dl(Node);
11800 if (VT.getScalarType() == MVT::bf16) {
11801 if (Node->getConstantOperandVal(1) == 1) {
11802 return DAG.getNode(ISD::FP_TO_BF16, dl, VT, Node->getOperand(0));
11803 }
11804 EVT OperandVT = Op.getValueType();
11805 SDValue IsNaN = DAG.getSetCC(
11806 dl,
11807 getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), OperandVT),
11808 Op, Op, ISD::SETUO);
11809
11810 // We are rounding binary64/binary128 -> binary32 -> bfloat16. This
11811 // can induce double-rounding which may alter the results. We can
11812 // correct for this using a trick explained in: Boldo, Sylvie, and
11813 // Guillaume Melquiond. "When double rounding is odd." 17th IMACS
11814 // World Congress. 2005.
11815 EVT F32 = VT.isVector() ? VT.changeVectorElementType(MVT::f32) : MVT::f32;
11816 EVT I32 = F32.changeTypeToInteger();
11817 Op = expandRoundInexactToOdd(F32, Op, dl, DAG);
11818 Op = DAG.getNode(ISD::BITCAST, dl, I32, Op);
11819
11820 // Conversions should set NaN's quiet bit. This also prevents NaNs from
11821 // turning into infinities.
11822 SDValue NaN =
11823 DAG.getNode(ISD::OR, dl, I32, Op, DAG.getConstant(0x400000, dl, I32));
11824
11825 // Factor in the contribution of the low 16 bits.
11826 SDValue One = DAG.getConstant(1, dl, I32);
11827 SDValue Lsb = DAG.getNode(ISD::SRL, dl, I32, Op,
11828 DAG.getShiftAmountConstant(16, I32, dl));
11829 Lsb = DAG.getNode(ISD::AND, dl, I32, Lsb, One);
11830 SDValue RoundingBias =
11831 DAG.getNode(ISD::ADD, dl, I32, DAG.getConstant(0x7fff, dl, I32), Lsb);
11832 SDValue Add = DAG.getNode(ISD::ADD, dl, I32, Op, RoundingBias);
11833
11834 // Don't round if we had a NaN, we don't want to turn 0x7fffffff into
11835 // 0x80000000.
11836 Op = DAG.getSelect(dl, I32, IsNaN, NaN, Add);
11837
11838 // Now that we have rounded, shift the bits into position.
11839 Op = DAG.getNode(ISD::SRL, dl, I32, Op,
11840 DAG.getShiftAmountConstant(16, I32, dl));
11841 Op = DAG.getNode(ISD::BITCAST, dl, I32, Op);
11842 EVT I16 = I32.isVector() ? I32.changeVectorElementType(MVT::i16) : MVT::i16;
11843 Op = DAG.getNode(ISD::TRUNCATE, dl, I16, Op);
11844 return DAG.getNode(ISD::BITCAST, dl, VT, Op);
11845 }
11846 return SDValue();
11847 }
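
// Editorial illustration (not part of the original source): the bias-and-
// shift sequence above is the usual round-to-nearest-even truncation of a
// binary32 encoding to bfloat16. Scalar model (uint32_t/uint16_t assumed from
// <cstdint>; NaN quieting omitted):
//
//   uint16_t f32_to_bf16_rne(uint32_t Bits) {
//     uint32_t Lsb = (Bits >> 16) & 1;         // low bit of the kept half
//     uint32_t Rounded = Bits + 0x7fff + Lsb;  // ties round to the even value
//     return (uint16_t)(Rounded >> 16);
//   }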
11848
11849 SDValue TargetLowering::expandVectorSplice(SDNode *Node,
11850 SelectionDAG &DAG) const {
11851 assert(Node->getOpcode() == ISD::VECTOR_SPLICE && "Unexpected opcode!");
11852 assert(Node->getValueType(0).isScalableVector() &&
11853 "Fixed length vector types expected to use SHUFFLE_VECTOR!");
11854
11855 EVT VT = Node->getValueType(0);
11856 SDValue V1 = Node->getOperand(0);
11857 SDValue V2 = Node->getOperand(1);
11858 int64_t Imm = cast<ConstantSDNode>(Node->getOperand(2))->getSExtValue();
11859 SDLoc DL(Node);
11860
11861 // Expand through memory thusly:
11862 // Alloca CONCAT_VECTORS_TYPES(V1, V2) Ptr
11863 // Store V1, Ptr
11864 // Store V2, Ptr + sizeof(V1)
11865 // If (Imm < 0)
11866 // TrailingElts = -Imm
11867 // Ptr = Ptr + sizeof(V1) - (TrailingElts * sizeof(VT.Elt))
11868 // else
11869 // Ptr = Ptr + (Imm * sizeof(VT.Elt))
11870 // Res = Load Ptr
11871
11872 Align Alignment = DAG.getReducedAlign(VT, /*UseABI=*/false);
11873
11874 EVT MemVT = EVT::getVectorVT(*DAG.getContext(), VT.getVectorElementType(),
11875 VT.getVectorElementCount() * 2);
11876 SDValue StackPtr = DAG.CreateStackTemporary(MemVT.getStoreSize(), Alignment);
11877 EVT PtrVT = StackPtr.getValueType();
11878 auto &MF = DAG.getMachineFunction();
11879 auto FrameIndex = cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex();
11880 auto PtrInfo = MachinePointerInfo::getFixedStack(MF, FrameIndex);
11881
11882 // Store the lo part of CONCAT_VECTORS(V1, V2)
11883 SDValue StoreV1 = DAG.getStore(DAG.getEntryNode(), DL, V1, StackPtr, PtrInfo);
11884 // Store the hi part of CONCAT_VECTORS(V1, V2)
11885 SDValue OffsetToV2 = DAG.getVScale(
11886 DL, PtrVT,
11887 APInt(PtrVT.getFixedSizeInBits(), VT.getStoreSize().getKnownMinValue()));
11888 SDValue StackPtr2 = DAG.getNode(ISD::ADD, DL, PtrVT, StackPtr, OffsetToV2);
11889 SDValue StoreV2 = DAG.getStore(StoreV1, DL, V2, StackPtr2, PtrInfo);
11890
11891 if (Imm >= 0) {
11892 // Load back the required element. getVectorElementPointer takes care of
11893 // clamping the index if it's out-of-bounds.
11894 StackPtr = getVectorElementPointer(DAG, StackPtr, VT, Node->getOperand(2));
11895 // Load the spliced result
11896 return DAG.getLoad(VT, DL, StoreV2, StackPtr,
11897 MachinePointerInfo::getUnknownStack(MF));
11898 }
11899
11900 uint64_t TrailingElts = -Imm;
11901
11902 // NOTE: TrailingElts must be clamped so as not to read outside of V1:V2.
11903 TypeSize EltByteSize = VT.getVectorElementType().getStoreSize();
11904 SDValue TrailingBytes =
11905 DAG.getConstant(TrailingElts * EltByteSize, DL, PtrVT);
11906
11907 if (TrailingElts > VT.getVectorMinNumElements()) {
11908 SDValue VLBytes =
11909 DAG.getVScale(DL, PtrVT,
11910 APInt(PtrVT.getFixedSizeInBits(),
11911 VT.getStoreSize().getKnownMinValue()));
11912 TrailingBytes = DAG.getNode(ISD::UMIN, DL, PtrVT, TrailingBytes, VLBytes);
11913 }
11914
11915 // Calculate the start address of the spliced result.
11916 StackPtr2 = DAG.getNode(ISD::SUB, DL, PtrVT, StackPtr2, TrailingBytes);
11917
11918 // Load the spliced result
11919 return DAG.getLoad(VT, DL, StoreV2, StackPtr2,
11920 MachinePointerInfo::getUnknownStack(MF));
11921 }
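
// Editorial illustration (not part of the original source): for a concrete
// <4 x i32> splice the stack holds V1 followed by V2, and the result is one
// contiguous load from within that buffer:
//
//   splice(V1, V2, 1)  loads from byte offset 1 * sizeof(i32)
//                      -> { V1[1], V1[2], V1[3], V2[0] }
//   splice(V1, V2, -1) loads from sizeof(V1) - 1 * sizeof(i32)
//                      -> { V1[3], V2[0], V2[1], V2[2] }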
11922
11923 SDValue TargetLowering::expandVECTOR_COMPRESS(SDNode *Node,
11924 SelectionDAG &DAG) const {
11925 SDLoc DL(Node);
11926 SDValue Vec = Node->getOperand(0);
11927 SDValue Mask = Node->getOperand(1);
11928 SDValue Passthru = Node->getOperand(2);
11929
11930 EVT VecVT = Vec.getValueType();
11931 EVT ScalarVT = VecVT.getScalarType();
11932 EVT MaskVT = Mask.getValueType();
11933 EVT MaskScalarVT = MaskVT.getScalarType();
11934
11935 // Needs to be handled by targets that have scalable vector types.
11936 if (VecVT.isScalableVector())
11937 report_fatal_error("Cannot expand masked_compress for scalable vectors.");
11938
11939 SDValue StackPtr = DAG.CreateStackTemporary(
11940 VecVT.getStoreSize(), DAG.getReducedAlign(VecVT, /*UseABI=*/false));
11941 int FI = cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex();
11942 MachinePointerInfo PtrInfo =
11943 MachinePointerInfo::getFixedStack(DAG.getMachineFunction(), FI);
11944
11945 MVT PositionVT = getVectorIdxTy(DAG.getDataLayout());
11946 SDValue Chain = DAG.getEntryNode();
11947 SDValue OutPos = DAG.getConstant(0, DL, PositionVT);
11948
11949 bool HasPassthru = !Passthru.isUndef();
11950
11951 // If we have a passthru vector, store it on the stack, overwrite the matching
11952 // positions and then re-write the last element that was potentially
11953 // overwritten even though mask[i] = false.
11954 if (HasPassthru)
11955 Chain = DAG.getStore(Chain, DL, Passthru, StackPtr, PtrInfo);
11956
11957 SDValue LastWriteVal;
11958 APInt PassthruSplatVal;
11959 bool IsSplatPassthru =
11960 ISD::isConstantSplatVector(Passthru.getNode(), PassthruSplatVal);
11961
11962 if (IsSplatPassthru) {
11963 // As we do not know which position we wrote to last, we cannot simply
11964 // access that index from the passthru vector. So we first check if passthru
11965 // is a splat vector, to use any element ...
11966 LastWriteVal = DAG.getConstant(PassthruSplatVal, DL, ScalarVT);
11967 } else if (HasPassthru) {
11968 // ... if it is not a splat vector, we need to get the passthru value at
11969 // position = popcount(mask) and re-load it from the stack before it is
11970 // overwritten in the loop below.
11971 EVT PopcountVT = ScalarVT.changeTypeToInteger();
11972 SDValue Popcount = DAG.getNode(
11973 ISD::TRUNCATE, DL, MaskVT.changeVectorElementType(MVT::i1), Mask);
11974 Popcount =
11975 DAG.getNode(ISD::ZERO_EXTEND, DL,
11976 MaskVT.changeVectorElementType(PopcountVT), Popcount);
11977 Popcount = DAG.getNode(ISD::VECREDUCE_ADD, DL, PopcountVT, Popcount);
11978 SDValue LastElmtPtr =
11979 getVectorElementPointer(DAG, StackPtr, VecVT, Popcount);
11980 LastWriteVal = DAG.getLoad(
11981 ScalarVT, DL, Chain, LastElmtPtr,
11982 MachinePointerInfo::getUnknownStack(DAG.getMachineFunction()));
11983 Chain = LastWriteVal.getValue(1);
11984 }
11985
11986 unsigned NumElms = VecVT.getVectorNumElements();
11987 for (unsigned I = 0; I < NumElms; I++) {
11988 SDValue ValI = DAG.getExtractVectorElt(DL, ScalarVT, Vec, I);
11989 SDValue OutPtr = getVectorElementPointer(DAG, StackPtr, VecVT, OutPos);
11990 Chain = DAG.getStore(
11991 Chain, DL, ValI, OutPtr,
11992 MachinePointerInfo::getUnknownStack(DAG.getMachineFunction()));
11993
11994 // Get the mask value and add it to the current output position. This
11995 // either increments by 1 if MaskI is true or adds 0 otherwise.
11996 // Freeze in case we have poison/undef mask entries.
11997 SDValue MaskI =
11998 DAG.getFreeze(DAG.getExtractVectorElt(DL, MaskScalarVT, Mask, I));
12000 MaskI = DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, MaskI);
12001 MaskI = DAG.getNode(ISD::ZERO_EXTEND, DL, PositionVT, MaskI);
12002 OutPos = DAG.getNode(ISD::ADD, DL, PositionVT, OutPos, MaskI);
12003
12004 if (HasPassthru && I == NumElms - 1) {
12005 SDValue EndOfVector =
12006 DAG.getConstant(VecVT.getVectorNumElements() - 1, DL, PositionVT);
12007 SDValue AllLanesSelected =
12008 DAG.getSetCC(DL, MVT::i1, OutPos, EndOfVector, ISD::CondCode::SETUGT);
12009 OutPos = DAG.getNode(ISD::UMIN, DL, PositionVT, OutPos, EndOfVector);
12010 OutPtr = getVectorElementPointer(DAG, StackPtr, VecVT, OutPos);
12011
12012 // Re-write the last ValI if all lanes were selected. Otherwise,
12013 // overwrite the last write with the passthru value.
12014 LastWriteVal = DAG.getSelect(DL, ScalarVT, AllLanesSelected, ValI,
12015 LastWriteVal, SDNodeFlags::Unpredictable);
12016 Chain = DAG.getStore(
12017 Chain, DL, LastWriteVal, OutPtr,
12018 MachinePointerInfo::getUnknownStack(DAG.getMachineFunction()));
12019 }
12020 }
12021
12022 return DAG.getLoad(VecVT, DL, Chain, StackPtr, PtrInfo);
12023 }
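
// Editorial illustration (not part of the original source): the store loop
// above behaves like the following scalar model (passthru handling omitted,
// so trailing lanes are unspecified, just as in the no-passthru expansion):
//
//   void compress4(const int Vec[4], const bool Mask[4], int Out[4]) {
//     unsigned Pos = 0;
//     for (unsigned I = 0; I < 4; ++I) {
//       Out[Pos] = Vec[I];        // unconditional store, like the loop above
//       Pos += Mask[I] ? 1 : 0;   // advance only for selected lanes
//     }
//   }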
12024
12025 SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
12026 SelectionDAG &DAG) const {
12027 SDLoc DL(N);
12028 SDValue Acc = N->getOperand(0);
12029 SDValue MulLHS = N->getOperand(1);
12030 SDValue MulRHS = N->getOperand(2);
12031 EVT AccVT = Acc.getValueType();
12032 EVT MulOpVT = MulLHS.getValueType();
12033
12034 EVT ExtMulOpVT =
12035 EVT::getVectorVT(*DAG.getContext(), AccVT.getVectorElementType(),
12036 MulOpVT.getVectorElementCount());
12037
12038 unsigned ExtOpcLHS = N->getOpcode() == ISD::PARTIAL_REDUCE_UMLA
12039 ? ISD::ZERO_EXTEND
12040 : ISD::SIGN_EXTEND;
12041 unsigned ExtOpcRHS = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA
12042 ? ISD::SIGN_EXTEND
12043 : ISD::ZERO_EXTEND;
12044
12045 if (ExtMulOpVT != MulOpVT) {
12046 MulLHS = DAG.getNode(ExtOpcLHS, DL, ExtMulOpVT, MulLHS);
12047 MulRHS = DAG.getNode(ExtOpcRHS, DL, ExtMulOpVT, MulRHS);
12048 }
12049 SDValue Input = MulLHS;
12050 APInt ConstantOne;
12051 if (!ISD::isConstantSplatVector(MulRHS.getNode(), ConstantOne) ||
12052 !ConstantOne.isOne())
12053 Input = DAG.getNode(ISD::MUL, DL, ExtMulOpVT, MulLHS, MulRHS);
12054
12055 unsigned Stride = AccVT.getVectorMinNumElements();
12056 unsigned ScaleFactor = MulOpVT.getVectorMinNumElements() / Stride;
12057
12058 // Collect all of the subvectors
12059 std::deque<SDValue> Subvectors = {Acc};
12060 for (unsigned I = 0; I < ScaleFactor; I++)
12061 Subvectors.push_back(DAG.getExtractSubvector(DL, AccVT, Input, I * Stride));
12062
12063 // Flatten the subvector tree
12064 while (Subvectors.size() > 1) {
12065 Subvectors.push_back(
12066 DAG.getNode(ISD::ADD, DL, AccVT, {Subvectors[0], Subvectors[1]}));
12067 Subvectors.pop_front();
12068 Subvectors.pop_front();
12069 }
12070
12071 assert(Subvectors.size() == 1 &&
12072 "There should only be one subvector after tree flattening");
12073
12074 return Subvectors[0];
12075 }
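
// Editorial illustration (not part of the original source): for
// Acc : <2 x i32> and MulLHS/MulRHS : <4 x i8> (ScaleFactor == 2) the
// expansion above computes
//
//   Ext    = [sz]ext(<4 x i8>)  -> <4 x i32>
//   Input  = mul(ExtLHS, ExtRHS)
//   Result = Acc + Input[0..1] + Input[2..3]   // subvector adds
//
// i.e. each accumulator lane receives the sum of ScaleFactor products.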
12076
12077 bool TargetLowering::LegalizeSetCCCondCode(SelectionDAG &DAG, EVT VT,
12078 SDValue &LHS, SDValue &RHS,
12079 SDValue &CC, SDValue Mask,
12080 SDValue EVL, bool &NeedInvert,
12081 const SDLoc &dl, SDValue &Chain,
12082 bool IsSignaling) const {
12083 MVT OpVT = LHS.getSimpleValueType();
12084 ISD::CondCode CCCode = cast<CondCodeSDNode>(CC)->get();
12085 NeedInvert = false;
12086 assert(!EVL == !Mask && "VP Mask and EVL must either both be set or unset");
12087 bool IsNonVP = !EVL;
12088 switch (getCondCodeAction(CCCode, OpVT)) {
12089 default:
12090 llvm_unreachable("Unknown condition code action!");
12091 case TargetLowering::Legal:
12092 // Nothing to do.
12093 break;
12094 case TargetLowering::Expand: {
12095 ISD::CondCode InvCC = ISD::getSetCCSwappedOperands(CCCode);
12096 if (isCondCodeLegalOrCustom(InvCC, OpVT)) {
12097 std::swap(LHS, RHS);
12098 CC = DAG.getCondCode(InvCC);
12099 return true;
12100 }
12101 // Swapping operands didn't work. Try inverting the condition.
12102 bool NeedSwap = false;
12103 InvCC = getSetCCInverse(CCCode, OpVT);
12104 if (!isCondCodeLegalOrCustom(InvCC, OpVT)) {
12105 // If inverting the condition is not enough, try swapping operands
12106 // on top of it.
12107 InvCC = ISD::getSetCCSwappedOperands(InvCC);
12108 NeedSwap = true;
12109 }
12110 if (isCondCodeLegalOrCustom(InvCC, OpVT)) {
12111 CC = DAG.getCondCode(InvCC);
12112 NeedInvert = true;
12113 if (NeedSwap)
12114 std::swap(LHS, RHS);
12115 return true;
12116 }
12117
12118 // Special case: expand i1 comparisons using logical operations.
12119 if (OpVT == MVT::i1) {
12120 SDValue Ret;
12121 switch (CCCode) {
12122 default:
12123 llvm_unreachable("Unknown integer setcc!");
12124 case ISD::SETEQ: // X == Y --> ~(X ^ Y)
12125 Ret = DAG.getNOT(dl, DAG.getNode(ISD::XOR, dl, MVT::i1, LHS, RHS),
12126 MVT::i1);
12127 break;
12128 case ISD::SETNE: // X != Y --> (X ^ Y)
12129 Ret = DAG.getNode(ISD::XOR, dl, MVT::i1, LHS, RHS);
12130 break;
12131 case ISD::SETGT: // X >s Y --> X == 0 & Y == 1 --> ~X & Y
12132 case ISD::SETULT: // X <u Y --> X == 0 & Y == 1 --> ~X & Y
12133 Ret = DAG.getNode(ISD::AND, dl, MVT::i1, RHS,
12134 DAG.getNOT(dl, LHS, MVT::i1));
12135 break;
12136 case ISD::SETLT: // X <s Y --> X == 1 & Y == 0 --> ~Y & X
12137 case ISD::SETUGT: // X >u Y --> X == 1 & Y == 0 --> ~Y & X
12138 Ret = DAG.getNode(ISD::AND, dl, MVT::i1, LHS,
12139 DAG.getNOT(dl, RHS, MVT::i1));
12140 break;
12141 case ISD::SETULE: // X <=u Y --> X == 0 | Y == 1 --> ~X | Y
12142 case ISD::SETGE: // X >=s Y --> X == 0 | Y == 1 --> ~X | Y
12143 Ret = DAG.getNode(ISD::OR, dl, MVT::i1, RHS,
12144 DAG.getNOT(dl, LHS, MVT::i1));
12145 break;
12146 case ISD::SETUGE: // X >=u Y --> X == 1 | Y == 0 --> ~Y | X
12147 case ISD::SETLE: // X <=s Y --> X == 1 | Y == 0 --> ~Y | X
12148 Ret = DAG.getNode(ISD::OR, dl, MVT::i1, LHS,
12149 DAG.getNOT(dl, RHS, MVT::i1));
12150 break;
12151 }
12152
12153 LHS = DAG.getZExtOrTrunc(Ret, dl, VT);
12154 RHS = SDValue();
12155 CC = SDValue();
12156 return true;
12157 }
12158
12159 ISD::CondCode CC1 = ISD::SETCC_INVALID, CC2 = ISD::SETCC_INVALID;
12160 unsigned Opc = 0;
12161 switch (CCCode) {
12162 default:
12163 llvm_unreachable("Don't know how to expand this condition!");
12164 case ISD::SETUO:
12165 if (isCondCodeLegal(ISD::SETUNE, OpVT)) {
12166 CC1 = ISD::SETUNE;
12167 CC2 = ISD::SETUNE;
12168 Opc = ISD::OR;
12169 break;
12170 }
12171 assert(isCondCodeLegal(ISD::SETOEQ, OpVT) &&
12172 "If SETUO is expanded, SETOEQ or SETUNE must be legal!");
12173 NeedInvert = true;
12174 [[fallthrough]];
12175 case ISD::SETO:
12176 assert(isCondCodeLegal(ISD::SETOEQ, OpVT) &&
12177 "If SETO is expanded, SETOEQ must be legal!");
12178 CC1 = ISD::SETOEQ;
12179 CC2 = ISD::SETOEQ;
12180 Opc = ISD::AND;
12181 break;
12182 case ISD::SETONE:
12183 case ISD::SETUEQ:
12184 // If the SETUO or SETO CC isn't legal, we might be able to use
12185 // SETOGT || SETOLT, inverting the result for SETUEQ. We only need one
12186 // of SETOGT/SETOLT to be legal, the other can be emulated by swapping
12187 // the operands.
12188 CC2 = ((unsigned)CCCode & 0x8U) ? ISD::SETUO : ISD::SETO;
12189 if (!isCondCodeLegal(CC2, OpVT) && (isCondCodeLegal(ISD::SETOGT, OpVT) ||
12190 isCondCodeLegal(ISD::SETOLT, OpVT))) {
12191 CC1 = ISD::SETOGT;
12192 CC2 = ISD::SETOLT;
12193 Opc = ISD::OR;
12194 NeedInvert = ((unsigned)CCCode & 0x8U);
12195 break;
12196 }
12197 [[fallthrough]];
12198 case ISD::SETOEQ:
12199 case ISD::SETOGT:
12200 case ISD::SETOGE:
12201 case ISD::SETOLT:
12202 case ISD::SETOLE:
12203 case ISD::SETUNE:
12204 case ISD::SETUGT:
12205 case ISD::SETUGE:
12206 case ISD::SETULT:
12207 case ISD::SETULE:
12208 // If we are floating point, assign and break, otherwise fall through.
12209 if (!OpVT.isInteger()) {
12210 // We can use the 4th bit to tell if we are the unordered
12211 // or ordered version of the opcode.
12212 CC2 = ((unsigned)CCCode & 0x8U) ? ISD::SETUO : ISD::SETO;
12213 Opc = ((unsigned)CCCode & 0x8U) ? ISD::OR : ISD::AND;
12214 CC1 = (ISD::CondCode)(((int)CCCode & 0x7) | 0x10);
12215 break;
12216 }
12217 // Fallthrough if we are unsigned integer.
12218 [[fallthrough]];
12219 case ISD::SETLE:
12220 case ISD::SETGT:
12221 case ISD::SETGE:
12222 case ISD::SETLT:
12223 case ISD::SETNE:
12224 case ISD::SETEQ:
12225 // If all combinations of inverting the condition and swapping operands
12226 // didn't work then we have no means to expand the condition.
12227 llvm_unreachable("Don't know how to expand this condition!");
12228 }
12229
12230 SDValue SetCC1, SetCC2;
12231 if (CCCode != ISD::SETO && CCCode != ISD::SETUO) {
12232 // If we aren't the ordered or unordered operation,
12233 // then the pattern is (LHS CC1 RHS) Opc (LHS CC2 RHS).
12234 if (IsNonVP) {
12235 SetCC1 = DAG.getSetCC(dl, VT, LHS, RHS, CC1, Chain, IsSignaling);
12236 SetCC2 = DAG.getSetCC(dl, VT, LHS, RHS, CC2, Chain, IsSignaling);
12237 } else {
12238 SetCC1 = DAG.getSetCCVP(dl, VT, LHS, RHS, CC1, Mask, EVL);
12239 SetCC2 = DAG.getSetCCVP(dl, VT, LHS, RHS, CC2, Mask, EVL);
12240 }
12241 } else {
12242 // Otherwise, the pattern is (LHS CC1 LHS) Opc (RHS CC2 RHS)
12243 if (IsNonVP) {
12244 SetCC1 = DAG.getSetCC(dl, VT, LHS, LHS, CC1, Chain, IsSignaling);
12245 SetCC2 = DAG.getSetCC(dl, VT, RHS, RHS, CC2, Chain, IsSignaling);
12246 } else {
12247 SetCC1 = DAG.getSetCCVP(dl, VT, LHS, LHS, CC1, Mask, EVL);
12248 SetCC2 = DAG.getSetCCVP(dl, VT, RHS, RHS, CC2, Mask, EVL);
12249 }
12250 }
12251 if (Chain)
12252 Chain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, SetCC1.getValue(1),
12253 SetCC2.getValue(1));
12254 if (IsNonVP)
12255 LHS = DAG.getNode(Opc, dl, VT, SetCC1, SetCC2);
12256 else {
12257 // Transform the binary opcode to the VP equivalent.
12258 assert((Opc == ISD::OR || Opc == ISD::AND) && "Unexpected opcode");
12259 Opc = Opc == ISD::OR ? ISD::VP_OR : ISD::VP_AND;
12260 LHS = DAG.getNode(Opc, dl, VT, SetCC1, SetCC2, Mask, EVL);
12261 }
12262 RHS = SDValue();
12263 CC = SDValue();
12264 return true;
12265 }
12266 }
12267 return false;
12268 }
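
// Editorial illustration (not part of the original source): two typical
// expansions produced by the Expand path above when an FP condition code is
// not legal but its ordered/unordered building blocks are:
//
//   setcc(X, Y, SETOLT) -> and(setcc(X, Y, SETLT), setcc(X, Y, SETO))
//   setcc(X, Y, SETUEQ) -> not(or(setcc(X, Y, SETOGT), setcc(X, Y, SETOLT)))
//
// with the final 'not' coming from the NeedInvert flag returned to the caller.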
12269
12270 SDValue TargetLowering::expandVectorNaryOpBySplitting(SDNode *Node,
12271 SelectionDAG &DAG) const {
12272 EVT VT = Node->getValueType(0);
12273 // Despite its documentation, GetSplitDestVTs will assert if VT cannot be
12274 // split into two equal parts.
12275 if (!VT.isVector() || !VT.getVectorElementCount().isKnownMultipleOf(2))
12276 return SDValue();
12277
12278 // Restrict expansion to cases where both parts can be concatenated.
12279 auto [LoVT, HiVT] = DAG.GetSplitDestVTs(VT);
12280 if (LoVT != HiVT || !isTypeLegal(LoVT))
12281 return SDValue();
12282
12283 SDLoc DL(Node);
12284 unsigned Opcode = Node->getOpcode();
12285
12286 // Don't expand if the result is likely to be unrolled anyway.
12287 if (!isOperationLegalOrCustomOrPromote(Opcode, LoVT))
12288 return SDValue();
12289
12290 SmallVector<SDValue, 4> LoOps, HiOps;
12291 for (const SDValue &V : Node->op_values()) {
12292 auto [Lo, Hi] = DAG.SplitVector(V, DL, LoVT, HiVT);
12293 LoOps.push_back(Lo);
12294 HiOps.push_back(Hi);
12295 }
12296
12297 SDValue SplitOpLo = DAG.getNode(Opcode, DL, LoVT, LoOps);
12298 SDValue SplitOpHi = DAG.getNode(Opcode, DL, HiVT, HiOps);
12299 return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, SplitOpLo, SplitOpHi);
12300 }
12301
12302 SDValue TargetLowering::scalarizeExtractedVectorLoad(EVT ResultVT,
12303 const SDLoc &DL,
12304 EVT InVecVT, SDValue EltNo,
12305 LoadSDNode *OriginalLoad,
12306 SelectionDAG &DAG) const {
12307 assert(OriginalLoad->isSimple());
12308
12309 EVT VecEltVT = InVecVT.getVectorElementType();
12310
12311 // If the vector element type is not a multiple of a byte then we are unable
12312 // to correctly compute an address to load only the extracted element as a
12313 // scalar.
12314 if (!VecEltVT.isByteSized())
12315 return SDValue();
12316
12317 ISD::LoadExtType ExtTy =
12318 ResultVT.bitsGT(VecEltVT) ? ISD::EXTLOAD : ISD::NON_EXTLOAD;
12319 if (!isOperationLegalOrCustom(ISD::LOAD, VecEltVT))
12320 return SDValue();
12321
12322 std::optional<unsigned> ByteOffset;
12323 Align Alignment = OriginalLoad->getAlign();
12324 MachinePointerInfo MPI;
12325 if (auto *ConstEltNo = dyn_cast<ConstantSDNode>(EltNo)) {
12326 int Elt = ConstEltNo->getZExtValue();
12327 ByteOffset = VecEltVT.getSizeInBits() * Elt / 8;
12328 MPI = OriginalLoad->getPointerInfo().getWithOffset(*ByteOffset);
12329 Alignment = commonAlignment(Alignment, *ByteOffset);
12330 } else {
12331 // Discard the pointer info except the address space because the memory
12332 // operand can't represent this new access since the offset is variable.
12333 MPI = MachinePointerInfo(OriginalLoad->getPointerInfo().getAddrSpace());
12334 Alignment = commonAlignment(Alignment, VecEltVT.getSizeInBits() / 8);
12335 }
12336
12337 if (!shouldReduceLoadWidth(OriginalLoad, ExtTy, VecEltVT, ByteOffset))
12338 return SDValue();
12339
12340 unsigned IsFast = 0;
12341 if (!allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VecEltVT,
12342 OriginalLoad->getAddressSpace(), Alignment,
12343 OriginalLoad->getMemOperand()->getFlags(), &IsFast) ||
12344 !IsFast)
12345 return SDValue();
12346
12347 SDValue NewPtr =
12348 getVectorElementPointer(DAG, OriginalLoad->getBasePtr(), InVecVT, EltNo);
12349
12350 // We are replacing a vector load with a scalar load. The new load must have
12351 // identical memory op ordering to the original.
12352 SDValue Load;
12353 if (ResultVT.bitsGT(VecEltVT)) {
12354 // If the result type of vextract is wider than the load, then issue an
12355 // extending load instead.
12356 ISD::LoadExtType ExtType = isLoadExtLegal(ISD::ZEXTLOAD, ResultVT, VecEltVT)
12357 ? ISD::ZEXTLOAD
12358 : ISD::EXTLOAD;
12359 Load = DAG.getExtLoad(ExtType, DL, ResultVT, OriginalLoad->getChain(),
12360 NewPtr, MPI, VecEltVT, Alignment,
12361 OriginalLoad->getMemOperand()->getFlags(),
12362 OriginalLoad->getAAInfo());
12363 DAG.makeEquivalentMemoryOrdering(OriginalLoad, Load);
12364 } else {
12365 // The result type is narrower or the same width as the vector element
12366 Load = DAG.getLoad(VecEltVT, DL, OriginalLoad->getChain(), NewPtr, MPI,
12367 Alignment, OriginalLoad->getMemOperand()->getFlags(),
12368 OriginalLoad->getAAInfo());
12369 DAG.makeEquivalentMemoryOrdering(OriginalLoad, Load);
12370 if (ResultVT.bitsLT(VecEltVT))
12371 Load = DAG.getNode(ISD::TRUNCATE, DL, ResultVT, Load);
12372 else
12373 Load = DAG.getBitcast(ResultVT, Load);
12374 }
12375
12376 return Load;
12377 }
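
// Editorial illustration (not part of the original source): for a constant
// index the transform above replaces the vector load with a narrow scalar
// load at a fixed offset, e.g.
//
//   (extractelt (load <4 x i32>, ptr %P), 2)
//     -> (load i32, ptr %P + 8)   // ByteOffset = 2 * 4 bytes, alignment
//                                 // reduced via commonAlignment(OrigAlign, 8)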
12378