xref: /freebsd/contrib/llvm-project/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (revision 6e516c87b6d779911edde7481d8aef165b837a03)
//===- DAGCombiner.cpp - Implement a DAG node combiner --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass combines DAG nodes to form fewer, simpler DAG nodes.  It can be run
// both before and after the DAG is legalized.
//
// This pass is not a substitute for the LLVM IR instcombine pass. This pass is
// primarily intended to handle simplification opportunities that are implicit
// in the LLVM IR and exposed by the various codegen lowering phases.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/IntervalMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/ByteProvider.h"
#include "llvm/CodeGen/DAGCombine.h"
#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineMemOperand.h"
#include "llvm/CodeGen/RuntimeLibcallUtil.h"
#include "llvm/CodeGen/SDPatternMatch.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/CodeGen/SelectionDAGTargetInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/CodeGenTypes/MachineValueType.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Metadata.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugCounter.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <iterator>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <variant>

#include "MatchContext.h"

using namespace llvm;
using namespace llvm::SDPatternMatch;

#define DEBUG_TYPE "dagcombine"

STATISTIC(NodesCombined   , "Number of dag nodes combined");
STATISTIC(PreIndexedNodes , "Number of pre-indexed nodes created");
STATISTIC(PostIndexedNodes, "Number of post-indexed nodes created");
STATISTIC(OpsNarrowed     , "Number of load/op/store sequences narrowed");
STATISTIC(LdStFP2Int      , "Number of fp load/store pairs transformed to int");
STATISTIC(SlicedLoads, "Number of loads sliced");
STATISTIC(NumFPLogicOpsConv, "Number of logic ops converted to fp ops");

DEBUG_COUNTER(DAGCombineCounter, "dagcombine",
              "Controls whether a DAG combine is performed for a node");
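
// Illustrative use of the counter declared above for bisecting a misbehaving
// combine (a sketch, assuming the standard DebugCounter command-line syntax
// -debug-counter=<name>-skip=N,<name>-count=M and an llc-style driver):
//   llc -debug-counter=dagcombine-skip=41,dagcombine-count=1 test.ll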

static cl::opt<bool>
CombinerGlobalAA("combiner-global-alias-analysis", cl::Hidden,
                 cl::desc("Enable DAG combiner's use of IR alias analysis"));

static cl::opt<bool>
UseTBAA("combiner-use-tbaa", cl::Hidden, cl::init(true),
        cl::desc("Enable DAG combiner's use of TBAA"));

#ifndef NDEBUG
static cl::opt<std::string>
CombinerAAOnlyFunc("combiner-aa-only-func", cl::Hidden,
                   cl::desc("Only use DAG-combiner alias analysis in this"
                            " function"));
#endif

/// Hidden option to stress test load slicing, i.e., when this option
/// is enabled, load slicing bypasses most of its profitability guards.
static cl::opt<bool>
StressLoadSlicing("combiner-stress-load-slicing", cl::Hidden,
                  cl::desc("Bypass the profitability model of load slicing"),
                  cl::init(false));

static cl::opt<bool>
  MaySplitLoadIndex("combiner-split-load-index", cl::Hidden, cl::init(true),
                    cl::desc("DAG combiner may split indexing from loads"));

static cl::opt<bool>
    EnableStoreMerging("combiner-store-merging", cl::Hidden, cl::init(true),
                       cl::desc("DAG combiner enable merging multiple stores "
                                "into a wider store"));

static cl::opt<unsigned> TokenFactorInlineLimit(
    "combiner-tokenfactor-inline-limit", cl::Hidden, cl::init(2048),
    cl::desc("Limit the number of operands to inline for Token Factors"));

static cl::opt<unsigned> StoreMergeDependenceLimit(
    "combiner-store-merge-dependence-limit", cl::Hidden, cl::init(10),
    cl::desc("Limit the number of times for the same StoreNode and RootNode "
             "to bail out in store merging dependence check"));

static cl::opt<bool> EnableReduceLoadOpStoreWidth(
    "combiner-reduce-load-op-store-width", cl::Hidden, cl::init(true),
    cl::desc("DAG combiner enable reducing the width of load/op/store "
             "sequence"));

static cl::opt<bool> EnableShrinkLoadReplaceStoreWithStore(
    "combiner-shrink-load-replace-store-with-store", cl::Hidden, cl::init(true),
    cl::desc("DAG combiner enable load/<replace bytes>/store with "
             "a narrower store"));

static cl::opt<bool> EnableVectorFCopySignExtendRound(
    "combiner-vector-fcopysign-extend-round", cl::Hidden, cl::init(false),
    cl::desc(
        "Enable merging extends and rounds into FCOPYSIGN on vector types"));
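
// These are ordinary cl::opt flags, so they can be toggled from any tool that
// links this library (an illustrative sketch assuming an llc-style driver;
// the flag names are the ones registered above):
//   llc -combiner-store-merging=false -combiner-stress-load-slicing test.ll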

namespace {

  class DAGCombiner {
    SelectionDAG &DAG;
    const TargetLowering &TLI;
    const SelectionDAGTargetInfo *STI;
    CombineLevel Level = BeforeLegalizeTypes;
    CodeGenOptLevel OptLevel;
    bool LegalDAG = false;
    bool LegalOperations = false;
    bool LegalTypes = false;
    bool ForCodeSize;
    bool DisableGenericCombines;

    /// Worklist of all of the nodes that need to be simplified.
    ///
    /// This must behave as a stack -- new nodes to process are pushed onto the
    /// back and when processing we pop off of the back.
    ///
    /// The worklist will not contain duplicates but may contain null entries
    /// due to nodes being deleted from the underlying DAG. For fast lookup and
    /// deduplication, the index of the node in this vector is stored in the
    /// node in SDNode::CombinerWorklistIndex.
    SmallVector<SDNode *, 64> Worklist;

    /// This records all nodes attempted to be added to the worklist since we
    /// last considered a new worklist entry. Because we do not add duplicate
    /// nodes to the worklist, this can differ from the tail of the worklist.
    SmallSetVector<SDNode *, 32> PruningList;

    /// Map from a candidate StoreNode to a pair of its RootNode and a count.
    /// The count tracks how many times this StoreNode has bailed out of the
    /// dependence check with the same RootNode. Once a pair has bailed out
    /// more than a set limit, the StoreNode is no longer considered a store
    /// merging candidate with that RootNode.
    DenseMap<SDNode *, std::pair<SDNode *, unsigned>> StoreRootCountMap;

    // AA - Used for DAG load/store alias analysis.
    AliasAnalysis *AA;

    /// This caches all chains that have already been processed in
    /// DAGCombiner::getStoreMergeCandidates() and found to have no mergeable
    /// store candidates.
    SmallPtrSet<SDNode *, 4> ChainsWithoutMergeableStores;

    /// When an instruction is simplified, add all users of the instruction to
    /// the worklist, because they might now be simplified further.
    void AddUsersToWorklist(SDNode *N) {
      for (SDNode *Node : N->uses())
        AddToWorklist(Node);
    }

    /// Convenient shorthand to add a node and all of its users to the worklist.
    void AddToWorklistWithUsers(SDNode *N) {
      AddUsersToWorklist(N);
      AddToWorklist(N);
    }

    // Prune potentially dangling nodes. This is called after
    // any visit to a node, but should also be called during a visit after any
    // failed combine which may have created a DAG node.
    void clearAddedDanglingWorklistEntries() {
      // Check any nodes added to the worklist to see if they are prunable.
      while (!PruningList.empty()) {
        auto *N = PruningList.pop_back_val();
        if (N->use_empty())
          recursivelyDeleteUnusedNodes(N);
      }
    }

    SDNode *getNextWorklistEntry() {
      // Before we do any work, remove nodes that are not in use.
      clearAddedDanglingWorklistEntries();
      SDNode *N = nullptr;
      // The Worklist holds the SDNodes in order, but it may contain null
      // entries.
      while (!N && !Worklist.empty()) {
        N = Worklist.pop_back_val();
      }

      if (N) {
        assert(N->getCombinerWorklistIndex() >= 0 &&
               "Found a worklist entry without a corresponding map entry!");
        // Set to -2 to indicate that we combined the node.
        N->setCombinerWorklistIndex(-2);
      }
      return N;
    }

    /// Call the node-specific routine that folds each particular type of node.
    SDValue visit(SDNode *N);

  public:
    DAGCombiner(SelectionDAG &D, AliasAnalysis *AA, CodeGenOptLevel OL)
        : DAG(D), TLI(D.getTargetLoweringInfo()),
          STI(D.getSubtarget().getSelectionDAGInfo()), OptLevel(OL), AA(AA) {
      ForCodeSize = DAG.shouldOptForSize();
      DisableGenericCombines = STI && STI->disableGenericCombines(OptLevel);

      MaximumLegalStoreInBits = 0;
      // We use the minimum store size here, since that's all we can guarantee
      // for the scalable vector types.
      for (MVT VT : MVT::all_valuetypes())
        if (EVT(VT).isSimple() && VT != MVT::Other &&
            TLI.isTypeLegal(EVT(VT)) &&
            VT.getSizeInBits().getKnownMinValue() >= MaximumLegalStoreInBits)
          MaximumLegalStoreInBits = VT.getSizeInBits().getKnownMinValue();
    }

    void ConsiderForPruning(SDNode *N) {
      // Mark this for potential pruning.
      PruningList.insert(N);
    }

    /// Add to the worklist, making sure its instance is at the back (next to
    /// be processed).
    void AddToWorklist(SDNode *N, bool IsCandidateForPruning = true,
                       bool SkipIfCombinedBefore = false) {
      assert(N->getOpcode() != ISD::DELETED_NODE &&
             "Deleted Node added to Worklist");

      // Skip handle nodes as they can't usefully be combined and confuse the
      // zero-use deletion strategy.
      if (N->getOpcode() == ISD::HANDLENODE)
        return;

      if (SkipIfCombinedBefore && N->getCombinerWorklistIndex() == -2)
        return;

      if (IsCandidateForPruning)
        ConsiderForPruning(N);

      if (N->getCombinerWorklistIndex() < 0) {
        N->setCombinerWorklistIndex(Worklist.size());
        Worklist.push_back(N);
      }
    }

    /// Remove all instances of N from the worklist.
    void removeFromWorklist(SDNode *N) {
      PruningList.remove(N);
      StoreRootCountMap.erase(N);

      int WorklistIndex = N->getCombinerWorklistIndex();
      // If not in the worklist, the index might be -1 or -2 (was combined
      // before). As the node gets deleted anyway, there's no need to update
      // the index.
      if (WorklistIndex < 0)
        return; // Not in the worklist.

      // Null out the entry rather than erasing it to avoid a linear operation.
      Worklist[WorklistIndex] = nullptr;
      N->setCombinerWorklistIndex(-1);
    }

    void deleteAndRecombine(SDNode *N);
    bool recursivelyDeleteUnusedNodes(SDNode *N);

    /// Replaces all uses of the results of one DAG node with new values.
    SDValue CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
                      bool AddTo = true);

    /// Replaces all uses of the results of one DAG node with new values.
    SDValue CombineTo(SDNode *N, SDValue Res, bool AddTo = true) {
      return CombineTo(N, &Res, 1, AddTo);
    }

    /// Replaces all uses of the results of one DAG node with new values.
    SDValue CombineTo(SDNode *N, SDValue Res0, SDValue Res1,
                      bool AddTo = true) {
      SDValue To[] = { Res0, Res1 };
      return CombineTo(N, To, 2, AddTo);
    }

    void CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO);

  private:
    unsigned MaximumLegalStoreInBits;

    /// Check the specified integer node value to see if it can be simplified or
    /// if things it uses can be simplified by bit propagation.
    /// If so, return true.
    bool SimplifyDemandedBits(SDValue Op) {
      unsigned BitWidth = Op.getScalarValueSizeInBits();
      APInt DemandedBits = APInt::getAllOnes(BitWidth);
      return SimplifyDemandedBits(Op, DemandedBits);
    }

    bool SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits) {
      EVT VT = Op.getValueType();
      APInt DemandedElts = VT.isFixedLengthVector()
                               ? APInt::getAllOnes(VT.getVectorNumElements())
                               : APInt(1, 1);
      return SimplifyDemandedBits(Op, DemandedBits, DemandedElts, false);
    }

    /// Check the specified vector node value to see if it can be simplified or
    /// if things it uses can be simplified as it only uses some of the
    /// elements. If so, return true.
    bool SimplifyDemandedVectorElts(SDValue Op) {
      // TODO: For now just pretend it cannot be simplified.
      if (Op.getValueType().isScalableVector())
        return false;

      unsigned NumElts = Op.getValueType().getVectorNumElements();
      APInt DemandedElts = APInt::getAllOnes(NumElts);
      return SimplifyDemandedVectorElts(Op, DemandedElts);
    }

    bool SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
                              const APInt &DemandedElts,
                              bool AssumeSingleUse = false);
    bool SimplifyDemandedVectorElts(SDValue Op, const APInt &DemandedElts,
                                    bool AssumeSingleUse = false);

    bool CombineToPreIndexedLoadStore(SDNode *N);
    bool CombineToPostIndexedLoadStore(SDNode *N);
    SDValue SplitIndexingFromLoad(LoadSDNode *LD);
    bool SliceUpLoad(SDNode *N);

    // Looks up the chain to find a unique (unaliased) store feeding the passed
    // load. If no such store is found, returns a nullptr.
    // Note: This will look past a CALLSEQ_START if the load is chained to it,
    //       so that it can find stack stores for byval params.
    StoreSDNode *getUniqueStoreFeeding(LoadSDNode *LD, int64_t &Offset);
    // Scalars have size 0 to distinguish from singleton vectors.
    SDValue ForwardStoreValueToDirectLoad(LoadSDNode *LD);
    bool getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val);
    bool extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val);

    /// Replace an ISD::EXTRACT_VECTOR_ELT of a load with a narrowed
    ///   load.
    ///
    /// \param EVE ISD::EXTRACT_VECTOR_ELT to be replaced.
    /// \param InVecVT type of the input vector to EVE with bitcasts resolved.
    /// \param EltNo index of the vector element to load.
    /// \param OriginalLoad load that EVE came from to be replaced.
    /// \returns EVE on success, SDValue() on failure.
    SDValue scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
                                         SDValue EltNo,
                                         LoadSDNode *OriginalLoad);
    void ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad);
    SDValue PromoteOperand(SDValue Op, EVT PVT, bool &Replace);
    SDValue SExtPromoteOperand(SDValue Op, EVT PVT);
    SDValue ZExtPromoteOperand(SDValue Op, EVT PVT);
    SDValue PromoteIntBinOp(SDValue Op);
    SDValue PromoteIntShiftOp(SDValue Op);
    SDValue PromoteExtend(SDValue Op);
    bool PromoteLoad(SDValue Op);

    SDValue combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
                                SDValue RHS, SDValue True, SDValue False,
                                ISD::CondCode CC);

    /// Call the node-specific routine that knows how to fold each
    /// particular type of node. If that doesn't do anything, try the
    /// target-specific DAG combines.
    SDValue combine(SDNode *N);

    // Visitation implementation - Implement dag node combining for different
    // node types.  The semantics are as follows:
    // Return Value:
    //   SDValue.getNode() == 0 - No change was made
    //   SDValue.getNode() == N - N was replaced, is dead and has been handled.
    //   otherwise              - N should be replaced by the returned Operand.
    //
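    // An illustrative (hypothetical) visit routine following this convention;
    // visitFOO and ISD::BAR are placeholders, not opcodes handled below:
    //   SDValue visitFOO(SDNode *N) {
    //     if (/* no profitable fold found */)
    //       return SDValue();                  // no change was made
    //     // Otherwise build a replacement; the caller replaces all uses of N.
    //     return DAG.getNode(ISD::BAR, SDLoc(N), N->getValueType(0),
    //                        N->getOperand(0));
    //   }
    //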
    SDValue visitTokenFactor(SDNode *N);
    SDValue visitMERGE_VALUES(SDNode *N);
    SDValue visitADD(SDNode *N);
    SDValue visitADDLike(SDNode *N);
    SDValue visitADDLikeCommutative(SDValue N0, SDValue N1, SDNode *LocReference);
    SDValue visitSUB(SDNode *N);
    SDValue visitADDSAT(SDNode *N);
    SDValue visitSUBSAT(SDNode *N);
    SDValue visitADDC(SDNode *N);
    SDValue visitADDO(SDNode *N);
    SDValue visitUADDOLike(SDValue N0, SDValue N1, SDNode *N);
    SDValue visitSUBC(SDNode *N);
    SDValue visitSUBO(SDNode *N);
    SDValue visitADDE(SDNode *N);
    SDValue visitUADDO_CARRY(SDNode *N);
    SDValue visitSADDO_CARRY(SDNode *N);
    SDValue visitUADDO_CARRYLike(SDValue N0, SDValue N1, SDValue CarryIn,
                                 SDNode *N);
    SDValue visitSADDO_CARRYLike(SDValue N0, SDValue N1, SDValue CarryIn,
                                 SDNode *N);
    SDValue visitSUBE(SDNode *N);
    SDValue visitUSUBO_CARRY(SDNode *N);
    SDValue visitSSUBO_CARRY(SDNode *N);
    template <class MatchContextClass> SDValue visitMUL(SDNode *N);
    SDValue visitMULFIX(SDNode *N);
    SDValue useDivRem(SDNode *N);
    SDValue visitSDIV(SDNode *N);
    SDValue visitSDIVLike(SDValue N0, SDValue N1, SDNode *N);
    SDValue visitUDIV(SDNode *N);
    SDValue visitUDIVLike(SDValue N0, SDValue N1, SDNode *N);
    SDValue visitREM(SDNode *N);
    SDValue visitMULHU(SDNode *N);
    SDValue visitMULHS(SDNode *N);
    SDValue visitAVG(SDNode *N);
    SDValue visitABD(SDNode *N);
    SDValue visitSMUL_LOHI(SDNode *N);
    SDValue visitUMUL_LOHI(SDNode *N);
    SDValue visitMULO(SDNode *N);
    SDValue visitIMINMAX(SDNode *N);
    SDValue visitAND(SDNode *N);
    SDValue visitANDLike(SDValue N0, SDValue N1, SDNode *N);
    SDValue visitOR(SDNode *N);
    SDValue visitORLike(SDValue N0, SDValue N1, const SDLoc &DL);
    SDValue visitXOR(SDNode *N);
    SDValue SimplifyVCastOp(SDNode *N, const SDLoc &DL);
    SDValue SimplifyVBinOp(SDNode *N, const SDLoc &DL);
    SDValue visitSHL(SDNode *N);
    SDValue visitSRA(SDNode *N);
    SDValue visitSRL(SDNode *N);
    SDValue visitFunnelShift(SDNode *N);
    SDValue visitSHLSAT(SDNode *N);
    SDValue visitRotate(SDNode *N);
    SDValue visitABS(SDNode *N);
    SDValue visitBSWAP(SDNode *N);
    SDValue visitBITREVERSE(SDNode *N);
    SDValue visitCTLZ(SDNode *N);
    SDValue visitCTLZ_ZERO_UNDEF(SDNode *N);
    SDValue visitCTTZ(SDNode *N);
    SDValue visitCTTZ_ZERO_UNDEF(SDNode *N);
    SDValue visitCTPOP(SDNode *N);
    SDValue visitSELECT(SDNode *N);
    SDValue visitVSELECT(SDNode *N);
    SDValue visitVP_SELECT(SDNode *N);
    SDValue visitSELECT_CC(SDNode *N);
    SDValue visitSETCC(SDNode *N);
    SDValue visitSETCCCARRY(SDNode *N);
    SDValue visitSIGN_EXTEND(SDNode *N);
    SDValue visitZERO_EXTEND(SDNode *N);
    SDValue visitANY_EXTEND(SDNode *N);
    SDValue visitAssertExt(SDNode *N);
    SDValue visitAssertAlign(SDNode *N);
    SDValue visitSIGN_EXTEND_INREG(SDNode *N);
    SDValue visitEXTEND_VECTOR_INREG(SDNode *N);
    SDValue visitTRUNCATE(SDNode *N);
    SDValue visitBITCAST(SDNode *N);
    SDValue visitFREEZE(SDNode *N);
    SDValue visitBUILD_PAIR(SDNode *N);
    SDValue visitFADD(SDNode *N);
    SDValue visitVP_FADD(SDNode *N);
    SDValue visitVP_FSUB(SDNode *N);
    SDValue visitSTRICT_FADD(SDNode *N);
    SDValue visitFSUB(SDNode *N);
    SDValue visitFMUL(SDNode *N);
    template <class MatchContextClass> SDValue visitFMA(SDNode *N);
    SDValue visitFMAD(SDNode *N);
    SDValue visitFDIV(SDNode *N);
    SDValue visitFREM(SDNode *N);
    SDValue visitFSQRT(SDNode *N);
    SDValue visitFCOPYSIGN(SDNode *N);
    SDValue visitFPOW(SDNode *N);
    SDValue visitSINT_TO_FP(SDNode *N);
    SDValue visitUINT_TO_FP(SDNode *N);
    SDValue visitFP_TO_SINT(SDNode *N);
    SDValue visitFP_TO_UINT(SDNode *N);
    SDValue visitXRINT(SDNode *N);
    SDValue visitFP_ROUND(SDNode *N);
    SDValue visitFP_EXTEND(SDNode *N);
    SDValue visitFNEG(SDNode *N);
    SDValue visitFABS(SDNode *N);
    SDValue visitFCEIL(SDNode *N);
    SDValue visitFTRUNC(SDNode *N);
    SDValue visitFFREXP(SDNode *N);
    SDValue visitFFLOOR(SDNode *N);
    SDValue visitFMinMax(SDNode *N);
    SDValue visitBRCOND(SDNode *N);
    SDValue visitBR_CC(SDNode *N);
    SDValue visitLOAD(SDNode *N);

    SDValue replaceStoreChain(StoreSDNode *ST, SDValue BetterChain);
    SDValue replaceStoreOfFPConstant(StoreSDNode *ST);
    SDValue replaceStoreOfInsertLoad(StoreSDNode *ST);

    bool refineExtractVectorEltIntoMultipleNarrowExtractVectorElts(SDNode *N);

    SDValue visitSTORE(SDNode *N);
    SDValue visitATOMIC_STORE(SDNode *N);
    SDValue visitLIFETIME_END(SDNode *N);
    SDValue visitINSERT_VECTOR_ELT(SDNode *N);
    SDValue visitEXTRACT_VECTOR_ELT(SDNode *N);
    SDValue visitBUILD_VECTOR(SDNode *N);
    SDValue visitCONCAT_VECTORS(SDNode *N);
    SDValue visitEXTRACT_SUBVECTOR(SDNode *N);
    SDValue visitVECTOR_SHUFFLE(SDNode *N);
    SDValue visitSCALAR_TO_VECTOR(SDNode *N);
    SDValue visitINSERT_SUBVECTOR(SDNode *N);
    SDValue visitVECTOR_COMPRESS(SDNode *N);
    SDValue visitMLOAD(SDNode *N);
    SDValue visitMSTORE(SDNode *N);
    SDValue visitMGATHER(SDNode *N);
    SDValue visitMSCATTER(SDNode *N);
    SDValue visitVPGATHER(SDNode *N);
    SDValue visitVPSCATTER(SDNode *N);
    SDValue visitVP_STRIDED_LOAD(SDNode *N);
    SDValue visitVP_STRIDED_STORE(SDNode *N);
    SDValue visitFP_TO_FP16(SDNode *N);
    SDValue visitFP16_TO_FP(SDNode *N);
    SDValue visitFP_TO_BF16(SDNode *N);
    SDValue visitBF16_TO_FP(SDNode *N);
    SDValue visitVECREDUCE(SDNode *N);
    SDValue visitVPOp(SDNode *N);
    SDValue visitGET_FPENV_MEM(SDNode *N);
    SDValue visitSET_FPENV_MEM(SDNode *N);

    template <class MatchContextClass>
    SDValue visitFADDForFMACombine(SDNode *N);
    template <class MatchContextClass>
    SDValue visitFSUBForFMACombine(SDNode *N);
    SDValue visitFMULForFMADistributiveCombine(SDNode *N);

    SDValue XformToShuffleWithZero(SDNode *N);
    bool reassociationCanBreakAddressingModePattern(unsigned Opc,
                                                    const SDLoc &DL,
                                                    SDNode *N,
                                                    SDValue N0,
                                                    SDValue N1);
    SDValue reassociateOpsCommutative(unsigned Opc, const SDLoc &DL, SDValue N0,
                                      SDValue N1, SDNodeFlags Flags);
    SDValue reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
                           SDValue N1, SDNodeFlags Flags);
    SDValue reassociateReduction(unsigned RedOpc, unsigned Opc, const SDLoc &DL,
                                 EVT VT, SDValue N0, SDValue N1,
                                 SDNodeFlags Flags = SDNodeFlags());

    SDValue visitShiftByConstant(SDNode *N);

    SDValue foldSelectOfConstants(SDNode *N);
    SDValue foldVSelectOfConstants(SDNode *N);
    SDValue foldBinOpIntoSelect(SDNode *BO);
    bool SimplifySelectOps(SDNode *SELECT, SDValue LHS, SDValue RHS);
    SDValue hoistLogicOpWithSameOpcodeHands(SDNode *N);
    SDValue SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2);
    SDValue SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
                             SDValue N2, SDValue N3, ISD::CondCode CC,
                             bool NotExtCompare = false);
    SDValue convertSelectOfFPConstantsToLoadOffset(
        const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2, SDValue N3,
        ISD::CondCode CC);
    SDValue foldSignChangeInBitcast(SDNode *N);
    SDValue foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0, SDValue N1,
                                   SDValue N2, SDValue N3, ISD::CondCode CC);
    SDValue foldSelectOfBinops(SDNode *N);
    SDValue foldSextSetcc(SDNode *N);
    SDValue foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1,
                              const SDLoc &DL);
    SDValue foldSubToUSubSat(EVT DstVT, SDNode *N, const SDLoc &DL);
    SDValue foldABSToABD(SDNode *N, const SDLoc &DL);
    SDValue unfoldMaskedMerge(SDNode *N);
    SDValue unfoldExtremeBitClearingToShifts(SDNode *N);
    SDValue SimplifySetCC(EVT VT, SDValue N0, SDValue N1, ISD::CondCode Cond,
                          const SDLoc &DL, bool foldBooleans);
    SDValue rebuildSetCC(SDValue N);

    bool isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS,
                           SDValue &CC, bool MatchStrict = false) const;
    bool isOneUseSetCC(SDValue N) const;

    SDValue foldAddToAvg(SDNode *N, const SDLoc &DL);
    SDValue foldSubToAvg(SDNode *N, const SDLoc &DL);

    SDValue SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp,
                                       unsigned HiOp);
    SDValue CombineConsecutiveLoads(SDNode *N, EVT VT);
    SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
                                 const TargetLowering &TLI);

    SDValue CombineExtLoad(SDNode *N);
    SDValue CombineZExtLogicopShiftLoad(SDNode *N);
    SDValue combineRepeatedFPDivisors(SDNode *N);
    SDValue combineFMulOrFDivWithIntPow2(SDNode *N);
    SDValue mergeInsertEltWithShuffle(SDNode *N, unsigned InsIndex);
    SDValue combineInsertEltToShuffle(SDNode *N, unsigned InsIndex);
    SDValue combineInsertEltToLoad(SDNode *N, unsigned InsIndex);
    SDValue ConstantFoldBITCASTofBUILD_VECTOR(SDNode *, EVT);
    SDValue BuildSDIV(SDNode *N);
    SDValue BuildSDIVPow2(SDNode *N);
    SDValue BuildUDIV(SDNode *N);
    SDValue BuildSREMPow2(SDNode *N);
    SDValue buildOptimizedSREM(SDValue N0, SDValue N1, SDNode *N);
    SDValue BuildLogBase2(SDValue V, const SDLoc &DL,
                          bool KnownNeverZero = false,
                          bool InexpensiveOnly = false,
                          std::optional<EVT> OutVT = std::nullopt);
    SDValue BuildDivEstimate(SDValue N, SDValue Op, SDNodeFlags Flags);
    SDValue buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags);
    SDValue buildSqrtEstimate(SDValue Op, SDNodeFlags Flags);
    SDValue buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags, bool Recip);
    SDValue buildSqrtNROneConst(SDValue Arg, SDValue Est, unsigned Iterations,
                                SDNodeFlags Flags, bool Reciprocal);
    SDValue buildSqrtNRTwoConst(SDValue Arg, SDValue Est, unsigned Iterations,
                                SDNodeFlags Flags, bool Reciprocal);
    SDValue MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
                               bool DemandHighBits = true);
    SDValue MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1);
    SDValue MatchRotatePosNeg(SDValue Shifted, SDValue Pos, SDValue Neg,
                              SDValue InnerPos, SDValue InnerNeg, bool HasPos,
                              unsigned PosOpcode, unsigned NegOpcode,
                              const SDLoc &DL);
    SDValue MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos, SDValue Neg,
                              SDValue InnerPos, SDValue InnerNeg, bool HasPos,
                              unsigned PosOpcode, unsigned NegOpcode,
                              const SDLoc &DL);
    SDValue MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL);
    SDValue MatchLoadCombine(SDNode *N);
    SDValue mergeTruncStores(StoreSDNode *N);
    SDValue reduceLoadWidth(SDNode *N);
    SDValue ReduceLoadOpStoreWidth(SDNode *N);
    SDValue splitMergedValStore(StoreSDNode *ST);
    SDValue TransformFPLoadStorePair(SDNode *N);
    SDValue convertBuildVecZextToZext(SDNode *N);
    SDValue convertBuildVecZextToBuildVecWithZeros(SDNode *N);
    SDValue reduceBuildVecExtToExtBuildVec(SDNode *N);
    SDValue reduceBuildVecTruncToBitCast(SDNode *N);
    SDValue reduceBuildVecToShuffle(SDNode *N);
    SDValue createBuildVecShuffle(const SDLoc &DL, SDNode *N,
                                  ArrayRef<int> VectorMask, SDValue VecIn1,
                                  SDValue VecIn2, unsigned LeftIdx,
                                  bool DidSplitVec);
    SDValue matchVSelectOpSizesWithSetCC(SDNode *Cast);

    /// Walk up chain skipping non-aliasing memory nodes,
    /// looking for aliasing nodes and adding them to the Aliases vector.
    void GatherAllAliases(SDNode *N, SDValue OriginalChain,
                          SmallVectorImpl<SDValue> &Aliases);

    /// Return true if there is any possibility that the two addresses overlap.
    bool mayAlias(SDNode *Op0, SDNode *Op1) const;

    /// Walk up chain skipping non-aliasing memory nodes, looking for a better
    /// chain (aliasing node).
    SDValue FindBetterChain(SDNode *N, SDValue Chain);

    /// Try to replace a store and any possibly adjacent stores on
    /// consecutive chains with better chains. Return true only if St is
    /// replaced.
    ///
    /// Notice that other chains may still be replaced even if the function
    /// returns false.
    bool findBetterNeighborChains(StoreSDNode *St);

    // Helper for findBetterNeighborChains. Walk up the store chain, adding
    // additional chained stores that do not overlap and can be parallelized.
    bool parallelizeChainedStores(StoreSDNode *St);

    /// Holds a pointer to an LSBaseSDNode as well as information on where it
    /// is located in a sequence of memory operations connected by a chain.
    struct MemOpLink {
      // Ptr to the mem node.
      LSBaseSDNode *MemNode;

      // Offset from the base ptr.
      int64_t OffsetFromBase;

      MemOpLink(LSBaseSDNode *N, int64_t Offset)
          : MemNode(N), OffsetFromBase(Offset) {}
    };

    // Classify the origin of a stored value.
    enum class StoreSource { Unknown, Constant, Extract, Load };
    StoreSource getStoreSource(SDValue StoreVal) {
      switch (StoreVal.getOpcode()) {
      case ISD::Constant:
      case ISD::ConstantFP:
        return StoreSource::Constant;
      case ISD::BUILD_VECTOR:
        if (ISD::isBuildVectorOfConstantSDNodes(StoreVal.getNode()) ||
            ISD::isBuildVectorOfConstantFPSDNodes(StoreVal.getNode()))
          return StoreSource::Constant;
        return StoreSource::Unknown;
      case ISD::EXTRACT_VECTOR_ELT:
      case ISD::EXTRACT_SUBVECTOR:
        return StoreSource::Extract;
      case ISD::LOAD:
        return StoreSource::Load;
      default:
        return StoreSource::Unknown;
      }
    }

    /// This is a helper function for visitMUL to check the profitability
    /// of folding (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2).
    /// MulNode is the original multiply, AddNode is (add x, c1),
    /// and ConstNode is c2.
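    /// For example (illustrative): (mul (add x, 3), 5) -> (add (mul x, 5), 15).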
    bool isMulAddWithConstProfitable(SDNode *MulNode, SDValue AddNode,
                                     SDValue ConstNode);

    /// This is a helper function for visitAND and visitZERO_EXTEND.  Returns
    /// true if the (and (load x) c) pattern matches an extload.  ExtVT returns
    /// the type of the loaded value to be extended.
    bool isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
                          EVT LoadResultTy, EVT &ExtVT);

    /// Helper function to calculate whether the given Load/Store can have its
    /// width reduced to ExtVT.
    bool isLegalNarrowLdSt(LSBaseSDNode *LDSTN, ISD::LoadExtType ExtType,
                           EVT &MemVT, unsigned ShAmt = 0);

    /// Used by BackwardsPropagateMask to find suitable loads.
    bool SearchForAndLoads(SDNode *N, SmallVectorImpl<LoadSDNode*> &Loads,
                           SmallPtrSetImpl<SDNode*> &NodesWithConsts,
                           ConstantSDNode *Mask, SDNode *&NodeToMask);
    /// Attempt to propagate a given AND node back to load leaves so that they
    /// can be combined into narrow loads.
    bool BackwardsPropagateMask(SDNode *N);

    /// Helper function for mergeConsecutiveStores which merges the component
    /// store chains.
    SDValue getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes,
                                unsigned NumStores);

    /// Helper function for mergeConsecutiveStores which checks if all the store
    /// nodes have the same underlying object. We can still reuse the first
    /// store's pointer info if all the stores are from the same object.
    bool hasSameUnderlyingObj(ArrayRef<MemOpLink> StoreNodes);

    /// This is a helper function for mergeConsecutiveStores. When the source
    /// elements of the consecutive stores are all constants or all extracted
    /// vector elements, try to merge them into one larger store introducing
    /// bitcasts if necessary.  \return True if a merged store was created.
    bool mergeStoresOfConstantsOrVecElts(SmallVectorImpl<MemOpLink> &StoreNodes,
                                         EVT MemVT, unsigned NumStores,
                                         bool IsConstantSrc, bool UseVector,
                                         bool UseTrunc);

    /// This is a helper function for mergeConsecutiveStores. Stores that
    /// potentially may be merged with St are placed in StoreNodes. On success,
    /// returns a chain predecessor to all store candidates.
    SDNode *getStoreMergeCandidates(StoreSDNode *St,
                                    SmallVectorImpl<MemOpLink> &StoreNodes);

    /// Helper function for mergeConsecutiveStores. Checks if candidate stores
    /// have indirect dependency through their operands. RootNode is the
    /// predecessor to all stores calculated by getStoreMergeCandidates and is
    /// used to prune the dependency check. \return True if safe to merge.
    bool checkMergeStoreCandidatesForDependencies(
        SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores,
        SDNode *RootNode);

    /// This is a helper function for mergeConsecutiveStores. Given a list of
    /// store candidates, find the first N that are consecutive in memory.
    /// Returns 0 if there are not at least 2 consecutive stores to try merging.
    unsigned getConsecutiveStores(SmallVectorImpl<MemOpLink> &StoreNodes,
                                  int64_t ElementSizeBytes) const;

    /// This is a helper function for mergeConsecutiveStores. It is used for
    /// store chains that are composed entirely of constant values.
    bool tryStoreMergeOfConstants(SmallVectorImpl<MemOpLink> &StoreNodes,
                                  unsigned NumConsecutiveStores,
                                  EVT MemVT, SDNode *Root, bool AllowVectors);

    /// This is a helper function for mergeConsecutiveStores. It is used for
    /// store chains that are composed entirely of extracted vector elements.
    /// When extracting multiple vector elements, try to store them in one
    /// vector store rather than a sequence of scalar stores.
    bool tryStoreMergeOfExtracts(SmallVectorImpl<MemOpLink> &StoreNodes,
                                 unsigned NumConsecutiveStores, EVT MemVT,
                                 SDNode *Root);

    /// This is a helper function for mergeConsecutiveStores. It is used for
    /// store chains that are composed entirely of loaded values.
    bool tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes,
                              unsigned NumConsecutiveStores, EVT MemVT,
                              SDNode *Root, bool AllowVectors,
                              bool IsNonTemporalStore, bool IsNonTemporalLoad);

    /// Merge consecutive store operations into a wide store.
    /// This optimization uses wide integers or vectors when possible.
    /// \return true if stores were merged.
    bool mergeConsecutiveStores(StoreSDNode *St);

    /// Try to transform a truncation where C is a constant:
    ///     (trunc (and X, C)) -> (and (trunc X), (trunc C))
    ///
    /// \p N needs to be a truncation and its first operand an AND. Other
    /// requirements are checked by the function (e.g. that trunc is
    /// single-use); if they are not met, an empty SDValue is returned.
    SDValue distributeTruncateThroughAnd(SDNode *N);

    /// Helper function to determine whether the target supports the operation
    /// given by \p Opcode for type \p VT, that is, whether the operation
    /// is legal or custom before legalizing operations, and whether it is
    /// legal (but not custom) after legalization.
    bool hasOperation(unsigned Opcode, EVT VT) {
      return TLI.isOperationLegalOrCustom(Opcode, VT, LegalOperations);
    }

  public:
    /// Runs the dag combiner on all nodes in the worklist.
    void Run(CombineLevel AtLevel);

    SelectionDAG &getDAG() const { return DAG; }

    /// Convenience wrapper around TargetLowering::getShiftAmountTy.
    EVT getShiftAmountTy(EVT LHSTy) {
      return TLI.getShiftAmountTy(LHSTy, DAG.getDataLayout());
    }

    /// This method returns true if we are running before type legalization or
    /// if the specified VT is legal.
    bool isTypeLegal(const EVT &VT) {
      if (!LegalTypes) return true;
      return TLI.isTypeLegal(VT);
    }

    /// Convenience wrapper around TargetLowering::getSetCCResultType
    EVT getSetCCResultType(EVT VT) const {
      return TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
    }

    void ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs,
                         SDValue OrigLoad, SDValue ExtLoad,
                         ISD::NodeType ExtType);
  };

/// This class is a DAGUpdateListener that removes any deleted
/// nodes from the worklist.
class WorklistRemover : public SelectionDAG::DAGUpdateListener {
  DAGCombiner &DC;

public:
  explicit WorklistRemover(DAGCombiner &dc)
    : SelectionDAG::DAGUpdateListener(dc.getDAG()), DC(dc) {}

  void NodeDeleted(SDNode *N, SDNode *E) override {
    DC.removeFromWorklist(N);
  }
};

class WorklistInserter : public SelectionDAG::DAGUpdateListener {
  DAGCombiner &DC;

public:
  explicit WorklistInserter(DAGCombiner &dc)
      : SelectionDAG::DAGUpdateListener(dc.getDAG()), DC(dc) {}

  // FIXME: Ideally we could add N to the worklist, but this causes exponential
  //        compile time costs in large DAGs, e.g. Halide.
  void NodeInserted(SDNode *N) override { DC.ConsiderForPruning(N); }
};

} // end anonymous namespace

//===----------------------------------------------------------------------===//
//  TargetLowering::DAGCombinerInfo implementation
//===----------------------------------------------------------------------===//

void TargetLowering::DAGCombinerInfo::AddToWorklist(SDNode *N) {
  ((DAGCombiner*)DC)->AddToWorklist(N);
}

SDValue TargetLowering::DAGCombinerInfo::
CombineTo(SDNode *N, ArrayRef<SDValue> To, bool AddTo) {
  return ((DAGCombiner*)DC)->CombineTo(N, &To[0], To.size(), AddTo);
}

SDValue TargetLowering::DAGCombinerInfo::
CombineTo(SDNode *N, SDValue Res, bool AddTo) {
  return ((DAGCombiner*)DC)->CombineTo(N, Res, AddTo);
}

SDValue TargetLowering::DAGCombinerInfo::
CombineTo(SDNode *N, SDValue Res0, SDValue Res1, bool AddTo) {
  return ((DAGCombiner*)DC)->CombineTo(N, Res0, Res1, AddTo);
}

bool TargetLowering::DAGCombinerInfo::
recursivelyDeleteUnusedNodes(SDNode *N) {
  return ((DAGCombiner*)DC)->recursivelyDeleteUnusedNodes(N);
}

void TargetLowering::DAGCombinerInfo::
CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) {
  return ((DAGCombiner*)DC)->CommitTargetLoweringOpt(TLO);
}

//===----------------------------------------------------------------------===//
// Helper Functions
//===----------------------------------------------------------------------===//

void DAGCombiner::deleteAndRecombine(SDNode *N) {
  removeFromWorklist(N);

  // If the operands of this node are only used by the node, they will now be
  // dead. Make sure to re-visit them and recursively delete dead nodes.
  for (const SDValue &Op : N->ops())
    // For an operand generating multiple values, one of the values may
    // become dead allowing further simplification (e.g. split index
    // arithmetic from an indexed load).
    if (Op->hasOneUse() || Op->getNumValues() > 1)
      AddToWorklist(Op.getNode());

  DAG.DeleteNode(N);
}

// APInts must be the same size for most operations; this helper
// function zero extends the shorter of the pair so that they match.
// We provide an Offset so that we can create bitwidths that won't overflow.
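// For example (illustrative): given an 8-bit LHS and a 16-bit RHS with
// Offset = 1, both are zero-extended to 17 bits, leaving one bit of headroom
// for the caller's subsequent arithmetic.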
static void zeroExtendToMatch(APInt &LHS, APInt &RHS, unsigned Offset = 0) {
  unsigned Bits = Offset + std::max(LHS.getBitWidth(), RHS.getBitWidth());
  LHS = LHS.zext(Bits);
  RHS = RHS.zext(Bits);
}

// Return true if this node is a setcc, or is a select_cc
// that selects between the target values used for true and false, making it
// equivalent to a setcc. Also, set the incoming LHS, RHS, and CC references to
// the appropriate nodes based on the type of node we are checking. This
// simplifies life a bit for the callers.
bool DAGCombiner::isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS,
                                    SDValue &CC, bool MatchStrict) const {
  if (N.getOpcode() == ISD::SETCC) {
    LHS = N.getOperand(0);
    RHS = N.getOperand(1);
    CC  = N.getOperand(2);
    return true;
  }

  if (MatchStrict &&
      (N.getOpcode() == ISD::STRICT_FSETCC ||
       N.getOpcode() == ISD::STRICT_FSETCCS)) {
    LHS = N.getOperand(1);
    RHS = N.getOperand(2);
    CC  = N.getOperand(3);
    return true;
  }

  if (N.getOpcode() != ISD::SELECT_CC || !TLI.isConstTrueVal(N.getOperand(2)) ||
      !TLI.isConstFalseVal(N.getOperand(3)))
    return false;

  if (TLI.getBooleanContents(N.getValueType()) ==
      TargetLowering::UndefinedBooleanContent)
    return false;

  LHS = N.getOperand(0);
  RHS = N.getOperand(1);
  CC  = N.getOperand(4);
  return true;
}

/// Return true if this is a SetCC-equivalent operation with only one use.
/// If this is true, it allows the users to invert the operation for free when
/// it is profitable to do so.
bool DAGCombiner::isOneUseSetCC(SDValue N) const {
  SDValue N0, N1, N2;
  if (isSetCCEquivalent(N, N0, N1, N2) && N->hasOneUse())
    return true;
  return false;
}

static bool isConstantSplatVectorMaskForType(SDNode *N, EVT ScalarTy) {
  if (!ScalarTy.isSimple())
    return false;

  uint64_t MaskForTy = 0ULL;
  switch (ScalarTy.getSimpleVT().SimpleTy) {
  case MVT::i8:
    MaskForTy = 0xFFULL;
    break;
  case MVT::i16:
    MaskForTy = 0xFFFFULL;
    break;
  case MVT::i32:
    MaskForTy = 0xFFFFFFFFULL;
    break;
  default:
    return false;
  }

  APInt Val;
  if (ISD::isConstantSplatVector(N, Val))
    return Val.getLimitedValue() == MaskForTy;

  return false;
}

// Determines if it is a constant integer or a splat/build vector of constant
// integers (and undefs).
// Do not permit build vector implicit truncation.
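// For example (illustrative): a BUILD_VECTOR of i32 constants, possibly with
// undef elements, qualifies; a BUILD_VECTOR whose i64 operands would be
// implicitly truncated to an i32 element type does not.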
static bool isConstantOrConstantVector(SDValue N, bool NoOpaques = false) {
  if (ConstantSDNode *Const = dyn_cast<ConstantSDNode>(N))
    return !(Const->isOpaque() && NoOpaques);
  if (N.getOpcode() != ISD::BUILD_VECTOR && N.getOpcode() != ISD::SPLAT_VECTOR)
    return false;
  unsigned BitWidth = N.getScalarValueSizeInBits();
  for (const SDValue &Op : N->op_values()) {
    if (Op.isUndef())
      continue;
    ConstantSDNode *Const = dyn_cast<ConstantSDNode>(Op);
    if (!Const || Const->getAPIntValue().getBitWidth() != BitWidth ||
        (Const->isOpaque() && NoOpaques))
      return false;
  }
  return true;
}

// Determines if a BUILD_VECTOR is composed of all-constants possibly mixed
// with undefs.
static bool isAnyConstantBuildVector(SDValue V, bool NoOpaques = false) {
  if (V.getOpcode() != ISD::BUILD_VECTOR)
    return false;
  return isConstantOrConstantVector(V, NoOpaques) ||
         ISD::isBuildVectorOfConstantFPSDNodes(V.getNode());
}

// Determine whether we may split off the index computation of this indexed
// load, i.e. its index operand is not an opaque target constant.
static bool canSplitIdx(LoadSDNode *LD) {
  return MaySplitLoadIndex &&
         (LD->getOperand(2).getOpcode() != ISD::TargetConstant ||
          !cast<ConstantSDNode>(LD->getOperand(2))->isOpaque());
}

bool DAGCombiner::reassociationCanBreakAddressingModePattern(unsigned Opc,
                                                             const SDLoc &DL,
                                                             SDNode *N,
                                                             SDValue N0,
                                                             SDValue N1) {
  // Currently this only tries to ensure we don't undo the GEP splits done by
  // CodeGenPrepare when shouldConsiderGEPOffsetSplit is true. To ensure this,
  // we check if the following transformation would be problematic:
  // (load/store (add, (add, x, offset1), offset2)) ->
  // (load/store (add, x, offset1+offset2)).

  // (load/store (add, (add, x, y), offset2)) ->
  // (load/store (add, (add, x, offset2), y)).

  if (N0.getOpcode() != ISD::ADD)
    return false;

  // Check for vscale addressing modes.
  // (load/store (add/sub (add x, y), vscale))
  // (load/store (add/sub (add x, y), (lsl vscale, C)))
  // (load/store (add/sub (add x, y), (mul vscale, C)))
  if ((N1.getOpcode() == ISD::VSCALE ||
       ((N1.getOpcode() == ISD::SHL || N1.getOpcode() == ISD::MUL) &&
        N1.getOperand(0).getOpcode() == ISD::VSCALE &&
        isa<ConstantSDNode>(N1.getOperand(1)))) &&
      N1.getValueType().getFixedSizeInBits() <= 64) {
    int64_t ScalableOffset = N1.getOpcode() == ISD::VSCALE
                                 ? N1.getConstantOperandVal(0)
                                 : (N1.getOperand(0).getConstantOperandVal(0) *
                                    (N1.getOpcode() == ISD::SHL
                                         ? (1LL << N1.getConstantOperandVal(1))
                                         : N1.getConstantOperandVal(1)));
    if (Opc == ISD::SUB)
      ScalableOffset = -ScalableOffset;
    if (all_of(N->uses(), [&](SDNode *Node) {
          if (auto *LoadStore = dyn_cast<MemSDNode>(Node);
              LoadStore && LoadStore->getBasePtr().getNode() == N) {
            TargetLoweringBase::AddrMode AM;
            AM.HasBaseReg = true;
            AM.ScalableOffset = ScalableOffset;
            EVT VT = LoadStore->getMemoryVT();
            unsigned AS = LoadStore->getAddressSpace();
            Type *AccessTy = VT.getTypeForEVT(*DAG.getContext());
            return TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy,
                                             AS);
          }
          return false;
        }))
      return true;
  }

  if (Opc != ISD::ADD)
    return false;

  auto *C2 = dyn_cast<ConstantSDNode>(N1);
  if (!C2)
    return false;

  const APInt &C2APIntVal = C2->getAPIntValue();
  if (C2APIntVal.getSignificantBits() > 64)
    return false;

  if (auto *C1 = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
    if (N0.hasOneUse())
      return false;

    const APInt &C1APIntVal = C1->getAPIntValue();
    const APInt CombinedValueIntVal = C1APIntVal + C2APIntVal;
    if (CombinedValueIntVal.getSignificantBits() > 64)
      return false;
    const int64_t CombinedValue = CombinedValueIntVal.getSExtValue();

    for (SDNode *Node : N->uses()) {
      if (auto *LoadStore = dyn_cast<MemSDNode>(Node)) {
        // Is x[offset2] already not a legal addressing mode? If so then
        // reassociating the constants breaks nothing (we test offset2 because
        // that's the one we hope to fold into the load or store).
        TargetLoweringBase::AddrMode AM;
        AM.HasBaseReg = true;
        AM.BaseOffs = C2APIntVal.getSExtValue();
        EVT VT = LoadStore->getMemoryVT();
        unsigned AS = LoadStore->getAddressSpace();
        Type *AccessTy = VT.getTypeForEVT(*DAG.getContext());
        if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
          continue;

        // Would x[offset1+offset2] still be a legal addressing mode?
        AM.BaseOffs = CombinedValue;
        if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
          return true;
      }
    }
  } else {
    if (auto *GA = dyn_cast<GlobalAddressSDNode>(N0.getOperand(1)))
      if (GA->getOpcode() == ISD::GlobalAddress && TLI.isOffsetFoldingLegal(GA))
        return false;

    for (SDNode *Node : N->uses()) {
      auto *LoadStore = dyn_cast<MemSDNode>(Node);
      if (!LoadStore)
        return false;

      // Is x[offset2] a legal addressing mode? If so then
      // reassociating the constants breaks the address pattern.
      TargetLoweringBase::AddrMode AM;
      AM.HasBaseReg = true;
      AM.BaseOffs = C2APIntVal.getSExtValue();
      EVT VT = LoadStore->getMemoryVT();
      unsigned AS = LoadStore->getAddressSpace();
      Type *AccessTy = VT.getTypeForEVT(*DAG.getContext());
      if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
        return false;
    }
    return true;
  }

  return false;
}

/// Helper for DAGCombiner::reassociateOps. Try to reassociate (Opc N0, N1) if
/// \p N0 is the same kind of operation as \p Opc.
SDValue DAGCombiner::reassociateOpsCommutative(unsigned Opc, const SDLoc &DL,
                                               SDValue N0, SDValue N1,
                                               SDNodeFlags Flags) {
  EVT VT = N0.getValueType();

  if (N0.getOpcode() != Opc)
    return SDValue();

  SDValue N00 = N0.getOperand(0);
  SDValue N01 = N0.getOperand(1);

  if (DAG.isConstantIntBuildVectorOrConstantInt(peekThroughBitcasts(N01))) {
    SDNodeFlags NewFlags;
    if (N0.getOpcode() == ISD::ADD && N0->getFlags().hasNoUnsignedWrap() &&
        Flags.hasNoUnsignedWrap())
      NewFlags.setNoUnsignedWrap(true);

    if (DAG.isConstantIntBuildVectorOrConstantInt(peekThroughBitcasts(N1))) {
      // Reassociate: (op (op x, c1), c2) -> (op x, (op c1, c2))
      if (SDValue OpNode = DAG.FoldConstantArithmetic(Opc, DL, VT, {N01, N1}))
        return DAG.getNode(Opc, DL, VT, N00, OpNode, NewFlags);
      return SDValue();
    }
    if (TLI.isReassocProfitable(DAG, N0, N1)) {
      // Reassociate: (op (op x, c1), y) -> (op (op x, y), c1)
      //              iff (op x, c1) has one use
      SDValue OpNode = DAG.getNode(Opc, SDLoc(N0), VT, N00, N1, NewFlags);
      return DAG.getNode(Opc, DL, VT, OpNode, N01, NewFlags);
    }
  }

1224   // Check for repeated operand logic simplifications.
1225   if (Opc == ISD::AND || Opc == ISD::OR) {
1226     // (N00 & N01) & N00 --> N00 & N01
1227     // (N00 & N01) & N01 --> N00 & N01
1228     // (N00 | N01) | N00 --> N00 | N01
1229     // (N00 | N01) | N01 --> N00 | N01
1230     if (N1 == N00 || N1 == N01)
1231       return N0;
1232   }
1233   if (Opc == ISD::XOR) {
1234     // (N00 ^ N01) ^ N00 --> N01
1235     if (N1 == N00)
1236       return N01;
1237     // (N00 ^ N01) ^ N01 --> N00
1238     if (N1 == N01)
1239       return N00;
1240   }
1241 
1242   if (TLI.isReassocProfitable(DAG, N0, N1)) {
1243     if (N1 != N01) {
1244       // Reassociate if (op N00, N1) already exists
1245       if (SDNode *NE = DAG.getNodeIfExists(Opc, DAG.getVTList(VT), {N00, N1})) {
1246         // If Op (Op N00, N1), N01 already exists, we need to stop
1247         // reassociating to avoid an infinite loop.
1248         if (!DAG.doesNodeExist(Opc, DAG.getVTList(VT), {SDValue(NE, 0), N01}))
1249           return DAG.getNode(Opc, DL, VT, SDValue(NE, 0), N01);
1250       }
1251     }
1252 
1253     if (N1 != N00) {
1254       // Reassociate if (op N01, N1) already exists
1255       if (SDNode *NE = DAG.getNodeIfExists(Opc, DAG.getVTList(VT), {N01, N1})) {
1256         // If Op (Op N01, N1), N00 already exists, we need to stop
1257         // reassociating to avoid an infinite loop.
1258         if (!DAG.doesNodeExist(Opc, DAG.getVTList(VT), {SDValue(NE, 0), N00}))
1259           return DAG.getNode(Opc, DL, VT, SDValue(NE, 0), N00);
1260       }
1261     }
1262 
1263     // Reassociate the operands from (OR/AND (OR/AND(N00, N01)), N1) to (OR/AND
1264     // (OR/AND(N00, N1)), N01) when N00 and N1 are comparisons with the same
1265     // predicate, or to (OR/AND (OR/AND(N1, N01)), N00) when N01 and N1 are
1266     // comparisons with the same predicate. This enables optimizations such as
1267     // the following:
1268     // CMP(A,C)||CMP(B,C) => CMP(MIN/MAX(A,B), C)
1269     // CMP(A,C)&&CMP(B,C) => CMP(MIN/MAX(A,B), C)
1270     if (Opc == ISD::AND || Opc == ISD::OR) {
1271       if (N1->getOpcode() == ISD::SETCC && N00->getOpcode() == ISD::SETCC &&
1272           N01->getOpcode() == ISD::SETCC) {
1273         ISD::CondCode CC1 = cast<CondCodeSDNode>(N1.getOperand(2))->get();
1274         ISD::CondCode CC00 = cast<CondCodeSDNode>(N00.getOperand(2))->get();
1275         ISD::CondCode CC01 = cast<CondCodeSDNode>(N01.getOperand(2))->get();
1276         if (CC1 == CC00 && CC1 != CC01) {
1277           SDValue OpNode = DAG.getNode(Opc, SDLoc(N0), VT, N00, N1, Flags);
1278           return DAG.getNode(Opc, DL, VT, OpNode, N01, Flags);
1279         }
1280         if (CC1 == CC01 && CC1 != CC00) {
1281           SDValue OpNode = DAG.getNode(Opc, SDLoc(N0), VT, N01, N1, Flags);
1282           return DAG.getNode(Opc, DL, VT, OpNode, N00, Flags);
1283         }
1284       }
1285     }
1286   }
1287 
1288   return SDValue();
1289 }
1290 
1291 /// Try to reassociate commutative (Opc N0, N1) if either \p N0 or \p N1 is the
1292 /// same kind of operation as \p Opc.
1293 SDValue DAGCombiner::reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
1294                                     SDValue N1, SDNodeFlags Flags) {
1295   assert(TLI.isCommutativeBinOp(Opc) && "Operation not commutative.");
1296 
1297   // Floating-point reassociation is not allowed without loose FP math.
1298   if (N0.getValueType().isFloatingPoint() ||
1299       N1.getValueType().isFloatingPoint())
1300     if (!Flags.hasAllowReassociation() || !Flags.hasNoSignedZeros())
1301       return SDValue();
1302 
1303   if (SDValue Combined = reassociateOpsCommutative(Opc, DL, N0, N1, Flags))
1304     return Combined;
1305   if (SDValue Combined = reassociateOpsCommutative(Opc, DL, N1, N0, Flags))
1306     return Combined;
1307   return SDValue();
1308 }
1309 
1310 // Try to fold Opc(vecreduce(x), vecreduce(y)) -> vecreduce(Opc(x, y))
1311 // Note that we only expect Flags to be passed from FP operations. For integer
1312 // operations they need to be dropped.
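// For example: fadd (vecreduce_fadd x), (vecreduce_fadd y)
//                -> vecreduce_fadd (fadd x, y)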
1313 SDValue DAGCombiner::reassociateReduction(unsigned RedOpc, unsigned Opc,
1314                                           const SDLoc &DL, EVT VT, SDValue N0,
1315                                           SDValue N1, SDNodeFlags Flags) {
1316   if (N0.getOpcode() == RedOpc && N1.getOpcode() == RedOpc &&
1317       N0.getOperand(0).getValueType() == N1.getOperand(0).getValueType() &&
1318       N0->hasOneUse() && N1->hasOneUse() &&
1319       TLI.isOperationLegalOrCustom(Opc, N0.getOperand(0).getValueType()) &&
1320       TLI.shouldReassociateReduction(RedOpc, N0.getOperand(0).getValueType())) {
1321     SelectionDAG::FlagInserter FlagsInserter(DAG, Flags);
1322     return DAG.getNode(RedOpc, DL, VT,
1323                        DAG.getNode(Opc, DL, N0.getOperand(0).getValueType(),
1324                                    N0.getOperand(0), N1.getOperand(0)));
1325   }
1326   return SDValue();
1327 }
1328 
1329 SDValue DAGCombiner::CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
1330                                bool AddTo) {
1331   assert(N->getNumValues() == NumTo && "Broken CombineTo call!");
1332   ++NodesCombined;
1333   LLVM_DEBUG(dbgs() << "\nReplacing.1 "; N->dump(&DAG); dbgs() << "\nWith: ";
1334              To[0].dump(&DAG);
1335              dbgs() << " and " << NumTo - 1 << " other values\n");
1336   for (unsigned i = 0, e = NumTo; i != e; ++i)
1337     assert((!To[i].getNode() ||
1338             N->getValueType(i) == To[i].getValueType()) &&
1339            "Cannot combine value to value of different type!");
1340 
1341   WorklistRemover DeadNodes(*this);
1342   DAG.ReplaceAllUsesWith(N, To);
1343   if (AddTo) {
1344     // Push the new nodes and any users onto the worklist
1345     for (unsigned i = 0, e = NumTo; i != e; ++i) {
1346       if (To[i].getNode())
1347         AddToWorklistWithUsers(To[i].getNode());
1348     }
1349   }
1350 
1351   // Finally, if the node is now dead, remove it from the graph.  The node
1352   // may not be dead if the replacement process recursively simplified to
1353   // something else needing this node.
1354   if (N->use_empty())
1355     deleteAndRecombine(N);
1356   return SDValue(N, 0);
1357 }
1358 
1359 void DAGCombiner::
1360 CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) {
1361   // Replace the old value with the new one.
1362   ++NodesCombined;
1363   LLVM_DEBUG(dbgs() << "\nReplacing.2 "; TLO.Old.dump(&DAG);
1364              dbgs() << "\nWith: "; TLO.New.dump(&DAG); dbgs() << '\n');
1365 
1366   // Replace all uses.
1367   DAG.ReplaceAllUsesOfValueWith(TLO.Old, TLO.New);
1368 
1369   // Push the new node and any (possibly new) users onto the worklist.
1370   AddToWorklistWithUsers(TLO.New.getNode());
1371 
1372   // Finally, if the node is now dead, remove it from the graph.
1373   recursivelyDeleteUnusedNodes(TLO.Old.getNode());
1374 }
1375 
1376 /// Check the specified integer node value to see if it can be simplified or if
1377 /// things it uses can be simplified by bit propagation. If so, return true.
1378 bool DAGCombiner::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
1379                                        const APInt &DemandedElts,
1380                                        bool AssumeSingleUse) {
1381   TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
1382   KnownBits Known;
1383   if (!TLI.SimplifyDemandedBits(Op, DemandedBits, DemandedElts, Known, TLO, 0,
1384                                 AssumeSingleUse))
1385     return false;
1386 
1387   // Revisit the node.
1388   AddToWorklist(Op.getNode());
1389 
1390   CommitTargetLoweringOpt(TLO);
1391   return true;
1392 }
1393 
1394 /// Check the specified vector node value to see if it can be simplified or
1395 /// if things it uses can be simplified as it only uses some of the elements.
1396 /// If so, return true.
1397 bool DAGCombiner::SimplifyDemandedVectorElts(SDValue Op,
1398                                              const APInt &DemandedElts,
1399                                              bool AssumeSingleUse) {
1400   TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
1401   APInt KnownUndef, KnownZero;
1402   if (!TLI.SimplifyDemandedVectorElts(Op, DemandedElts, KnownUndef, KnownZero,
1403                                       TLO, 0, AssumeSingleUse))
1404     return false;
1405 
1406   // Revisit the node.
1407   AddToWorklist(Op.getNode());
1408 
1409   CommitTargetLoweringOpt(TLO);
1410   return true;
1411 }
1412 
1413 void DAGCombiner::ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad) {
1414   SDLoc DL(Load);
1415   EVT VT = Load->getValueType(0);
1416   SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, VT, SDValue(ExtLoad, 0));
1417 
1418   LLVM_DEBUG(dbgs() << "\nReplacing.9 "; Load->dump(&DAG); dbgs() << "\nWith: ";
1419              Trunc.dump(&DAG); dbgs() << '\n');
1420 
1421   DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 0), Trunc);
1422   DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 1), SDValue(ExtLoad, 1));
1423 
1424   AddToWorklist(Trunc.getNode());
1425   recursivelyDeleteUnusedNodes(Load);
1426 }
1427 
1428 SDValue DAGCombiner::PromoteOperand(SDValue Op, EVT PVT, bool &Replace) {
1429   Replace = false;
1430   SDLoc DL(Op);
1431   if (ISD::isUNINDEXEDLoad(Op.getNode())) {
1432     LoadSDNode *LD = cast<LoadSDNode>(Op);
1433     EVT MemVT = LD->getMemoryVT();
1434     ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(LD) ? ISD::EXTLOAD
1435                                                       : LD->getExtensionType();
1436     Replace = true;
1437     return DAG.getExtLoad(ExtType, DL, PVT,
1438                           LD->getChain(), LD->getBasePtr(),
1439                           MemVT, LD->getMemOperand());
1440   }
1441 
1442   unsigned Opc = Op.getOpcode();
1443   switch (Opc) {
1444   default: break;
1445   case ISD::AssertSext:
1446     if (SDValue Op0 = SExtPromoteOperand(Op.getOperand(0), PVT))
1447       return DAG.getNode(ISD::AssertSext, DL, PVT, Op0, Op.getOperand(1));
1448     break;
1449   case ISD::AssertZext:
1450     if (SDValue Op0 = ZExtPromoteOperand(Op.getOperand(0), PVT))
1451       return DAG.getNode(ISD::AssertZext, DL, PVT, Op0, Op.getOperand(1));
1452     break;
1453   case ISD::Constant: {
1454     unsigned ExtOpc =
1455       Op.getValueType().isByteSized() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
1456     return DAG.getNode(ExtOpc, DL, PVT, Op);
1457   }
1458   }
1459 
1460   if (!TLI.isOperationLegal(ISD::ANY_EXTEND, PVT))
1461     return SDValue();
1462   return DAG.getNode(ISD::ANY_EXTEND, DL, PVT, Op);
1463 }
1464 
1465 SDValue DAGCombiner::SExtPromoteOperand(SDValue Op, EVT PVT) {
1466   if (!TLI.isOperationLegal(ISD::SIGN_EXTEND_INREG, PVT))
1467     return SDValue();
1468   EVT OldVT = Op.getValueType();
1469   SDLoc DL(Op);
1470   bool Replace = false;
1471   SDValue NewOp = PromoteOperand(Op, PVT, Replace);
1472   if (!NewOp.getNode())
1473     return SDValue();
1474   AddToWorklist(NewOp.getNode());
1475 
1476   if (Replace)
1477     ReplaceLoadWithPromotedLoad(Op.getNode(), NewOp.getNode());
1478   return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, NewOp.getValueType(), NewOp,
1479                      DAG.getValueType(OldVT));
1480 }
1481 
1482 SDValue DAGCombiner::ZExtPromoteOperand(SDValue Op, EVT PVT) {
1483   EVT OldVT = Op.getValueType();
1484   SDLoc DL(Op);
1485   bool Replace = false;
1486   SDValue NewOp = PromoteOperand(Op, PVT, Replace);
1487   if (!NewOp.getNode())
1488     return SDValue();
1489   AddToWorklist(NewOp.getNode());
1490 
1491   if (Replace)
1492     ReplaceLoadWithPromotedLoad(Op.getNode(), NewOp.getNode());
1493   return DAG.getZeroExtendInReg(NewOp, DL, OldVT);
1494 }
1495 
1496 /// Promote the specified integer binary operation if the target indicates it is
1497 /// beneficial. e.g. On x86, it's usually better to promote i16 operations to
1498 /// i32 since i16 instructions are longer.
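/// For example, an i16 add typically becomes
///   (i16 (trunc (i32 (add (any_extend x), (any_extend y)))))
/// with the exact extension kind chosen per operand by PromoteOperand.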
1499 SDValue DAGCombiner::PromoteIntBinOp(SDValue Op) {
1500   if (!LegalOperations)
1501     return SDValue();
1502 
1503   EVT VT = Op.getValueType();
1504   if (VT.isVector() || !VT.isInteger())
1505     return SDValue();
1506 
1507   // If operation type is 'undesirable', e.g. i16 on x86, consider
1508   // promoting it.
1509   unsigned Opc = Op.getOpcode();
1510   if (TLI.isTypeDesirableForOp(Opc, VT))
1511     return SDValue();
1512 
1513   EVT PVT = VT;
1514   // Consult target whether it is a good idea to promote this operation and
1515   // what's the right type to promote it to.
1516   if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1517     assert(PVT != VT && "Don't know what type to promote to!");
1518 
1519     LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG));
1520 
1521     bool Replace0 = false;
1522     SDValue N0 = Op.getOperand(0);
1523     SDValue NN0 = PromoteOperand(N0, PVT, Replace0);
1524 
1525     bool Replace1 = false;
1526     SDValue N1 = Op.getOperand(1);
1527     SDValue NN1 = PromoteOperand(N1, PVT, Replace1);
1528     SDLoc DL(Op);
1529 
1530     SDValue RV =
1531         DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getNode(Opc, DL, PVT, NN0, NN1));
1532 
1533     // We are always replacing N0/N1's use in N and only need additional
1534     // replacements if there are additional uses.
1535     // Note: We are checking uses of the *nodes* (SDNode) rather than values
1536     //       (SDValue) here because the node may reference multiple values
1537     //       (for example, the chain value of a load node).
1538     Replace0 &= !N0->hasOneUse();
1539     Replace1 &= (N0 != N1) && !N1->hasOneUse();
1540 
1541     // Combine Op here so it is preserved past replacements.
1542     CombineTo(Op.getNode(), RV);
1543 
1544     // If operands have a use ordering, make sure we deal with
1545     // predecessor first.
1546     if (Replace0 && Replace1 && N0->isPredecessorOf(N1.getNode())) {
1547       std::swap(N0, N1);
1548       std::swap(NN0, NN1);
1549     }
1550 
1551     if (Replace0) {
1552       AddToWorklist(NN0.getNode());
1553       ReplaceLoadWithPromotedLoad(N0.getNode(), NN0.getNode());
1554     }
1555     if (Replace1) {
1556       AddToWorklist(NN1.getNode());
1557       ReplaceLoadWithPromotedLoad(N1.getNode(), NN1.getNode());
1558     }
1559     return Op;
1560   }
1561   return SDValue();
1562 }
1563 
1564 /// Promote the specified integer shift operation if the target indicates it is
1565 /// beneficial. e.g. On x86, it's usually better to promote i16 operations to
1566 /// i32 since i16 instructions are longer.
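/// For right shifts the bits shifted in from above must be correct, so the
/// shifted operand is sign-extended for SRA and zero-extended for SRL; for
/// SHL an any-extend suffices because the promoted high bits are truncated
/// away afterwards.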
1567 SDValue DAGCombiner::PromoteIntShiftOp(SDValue Op) {
1568   if (!LegalOperations)
1569     return SDValue();
1570 
1571   EVT VT = Op.getValueType();
1572   if (VT.isVector() || !VT.isInteger())
1573     return SDValue();
1574 
1575   // If operation type is 'undesirable', e.g. i16 on x86, consider
1576   // promoting it.
1577   unsigned Opc = Op.getOpcode();
1578   if (TLI.isTypeDesirableForOp(Opc, VT))
1579     return SDValue();
1580 
1581   EVT PVT = VT;
1582   // Consult target whether it is a good idea to promote this operation and
1583   // what's the right type to promote it to.
1584   if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1585     assert(PVT != VT && "Don't know what type to promote to!");
1586 
1587     LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG));
1588 
1589     bool Replace = false;
1590     SDValue N0 = Op.getOperand(0);
1591     if (Opc == ISD::SRA)
1592       N0 = SExtPromoteOperand(N0, PVT);
1593     else if (Opc == ISD::SRL)
1594       N0 = ZExtPromoteOperand(N0, PVT);
1595     else
1596       N0 = PromoteOperand(N0, PVT, Replace);
1597 
1598     if (!N0.getNode())
1599       return SDValue();
1600 
1601     SDLoc DL(Op);
1602     SDValue N1 = Op.getOperand(1);
1603     SDValue RV =
1604         DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getNode(Opc, DL, PVT, N0, N1));
1605 
1606     if (Replace)
1607       ReplaceLoadWithPromotedLoad(Op.getOperand(0).getNode(), N0.getNode());
1608 
1609     // Deal with Op being deleted.
1610     if (Op && Op.getOpcode() != ISD::DELETED_NODE)
1611       return RV;
1612   }
1613   return SDValue();
1614 }
1615 
1616 SDValue DAGCombiner::PromoteExtend(SDValue Op) {
1617   if (!LegalOperations)
1618     return SDValue();
1619 
1620   EVT VT = Op.getValueType();
1621   if (VT.isVector() || !VT.isInteger())
1622     return SDValue();
1623 
1624   // If operation type is 'undesirable', e.g. i16 on x86, consider
1625   // promoting it.
1626   unsigned Opc = Op.getOpcode();
1627   if (TLI.isTypeDesirableForOp(Opc, VT))
1628     return SDValue();
1629 
1630   EVT PVT = VT;
1631   // Consult target whether it is a good idea to promote this operation and
1632   // what's the right type to promote it to.
1633   if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1634     assert(PVT != VT && "Don't know what type to promote to!");
1635     // fold (aext (aext x)) -> (aext x)
1636     // fold (aext (zext x)) -> (zext x)
1637     // fold (aext (sext x)) -> (sext x)
1638     LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG));
1639     return DAG.getNode(Op.getOpcode(), SDLoc(Op), VT, Op.getOperand(0));
1640   }
1641   return SDValue();
1642 }
1643 
1644 bool DAGCombiner::PromoteLoad(SDValue Op) {
1645   if (!LegalOperations)
1646     return false;
1647 
1648   if (!ISD::isUNINDEXEDLoad(Op.getNode()))
1649     return false;
1650 
1651   EVT VT = Op.getValueType();
1652   if (VT.isVector() || !VT.isInteger())
1653     return false;
1654 
1655   // If operation type is 'undesirable', e.g. i16 on x86, consider
1656   // promoting it.
1657   unsigned Opc = Op.getOpcode();
1658   if (TLI.isTypeDesirableForOp(Opc, VT))
1659     return false;
1660 
1661   EVT PVT = VT;
1662   // Consult target whether it is a good idea to promote this operation and
1663   // what's the right type to promote it to.
1664   if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1665     assert(PVT != VT && "Don't know what type to promote to!");
1666 
1667     SDLoc DL(Op);
1668     SDNode *N = Op.getNode();
1669     LoadSDNode *LD = cast<LoadSDNode>(N);
1670     EVT MemVT = LD->getMemoryVT();
1671     ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(LD) ? ISD::EXTLOAD
1672                                                       : LD->getExtensionType();
1673     SDValue NewLD = DAG.getExtLoad(ExtType, DL, PVT,
1674                                    LD->getChain(), LD->getBasePtr(),
1675                                    MemVT, LD->getMemOperand());
1676     SDValue Result = DAG.getNode(ISD::TRUNCATE, DL, VT, NewLD);
1677 
1678     LLVM_DEBUG(dbgs() << "\nPromoting "; N->dump(&DAG); dbgs() << "\nTo: ";
1679                Result.dump(&DAG); dbgs() << '\n');
1680 
1681     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result);
1682     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), NewLD.getValue(1));
1683 
1684     AddToWorklist(Result.getNode());
1685     recursivelyDeleteUnusedNodes(N);
1686     return true;
1687   }
1688 
1689   return false;
1690 }
1691 
1692 /// Recursively delete a node which has no uses and any operands for
1693 /// which it is the only use.
1694 ///
1695 /// Note that this both deletes the nodes and removes them from the worklist.
1696 /// It also adds any nodes that have had a user deleted to the worklist, as
1697 /// they may now have only one use and be subject to other combines.
1698 bool DAGCombiner::recursivelyDeleteUnusedNodes(SDNode *N) {
1699   if (!N->use_empty())
1700     return false;
1701 
1702   SmallSetVector<SDNode *, 16> Nodes;
1703   Nodes.insert(N);
1704   do {
1705     N = Nodes.pop_back_val();
1706     if (!N)
1707       continue;
1708 
1709     if (N->use_empty()) {
1710       for (const SDValue &ChildN : N->op_values())
1711         Nodes.insert(ChildN.getNode());
1712 
1713       removeFromWorklist(N);
1714       DAG.DeleteNode(N);
1715     } else {
1716       AddToWorklist(N);
1717     }
1718   } while (!Nodes.empty());
1719   return true;
1720 }
1721 
1722 //===----------------------------------------------------------------------===//
1723 //  Main DAG Combiner implementation
1724 //===----------------------------------------------------------------------===//
1725 
1726 void DAGCombiner::Run(CombineLevel AtLevel) {
1727   // Set the instance variables so that the various visit routines may use them.
1728   Level = AtLevel;
1729   LegalDAG = Level >= AfterLegalizeDAG;
1730   LegalOperations = Level >= AfterLegalizeVectorOps;
1731   LegalTypes = Level >= AfterLegalizeTypes;
1732 
1733   WorklistInserter AddNodes(*this);
1734 
1735   // Add all the dag nodes to the worklist.
1736   //
1737   // Note: Not all nodes are added to PruningList here. This is because the
1738   // only nodes which can be deleted are those which have no uses, and all
1739   // other nodes which would otherwise be added to the worklist by the first
1740   // call to getNextWorklistEntry are already present in it.
1741   for (SDNode &Node : DAG.allnodes())
1742     AddToWorklist(&Node, /* IsCandidateForPruning */ Node.use_empty());
1743 
1744   // Create a dummy node (which is not added to allnodes) that adds a reference
1745   // to the root node, preventing it from being deleted, and tracking any
1746   // changes of the root.
1747   HandleSDNode Dummy(DAG.getRoot());
1748 
1749   // While we have a valid worklist entry node, try to combine it.
1750   while (SDNode *N = getNextWorklistEntry()) {
1751     // If N has no uses, it is dead.  Make sure to revisit all N's operands once
1752     // N is deleted from the DAG, since they too may now be dead or may have a
1753     // reduced number of uses, allowing other xforms.
1754     if (recursivelyDeleteUnusedNodes(N))
1755       continue;
1756 
1757     WorklistRemover DeadNodes(*this);
1758 
1759     // If this combine is running after legalizing the DAG, re-legalize any
1760     // nodes pulled off the worklist.
1761     if (LegalDAG) {
1762       SmallSetVector<SDNode *, 16> UpdatedNodes;
1763       bool NIsValid = DAG.LegalizeOp(N, UpdatedNodes);
1764 
1765       for (SDNode *LN : UpdatedNodes)
1766         AddToWorklistWithUsers(LN);
1767 
1768       if (!NIsValid)
1769         continue;
1770     }
1771 
1772     LLVM_DEBUG(dbgs() << "\nCombining: "; N->dump(&DAG));
1773 
1774     // Add any operands of the new node which have not yet been combined to the
1775     // worklist as well. getNextWorklistEntry flags nodes that have been
1776     // combined before. Because the worklist uniques things already, this won't
1777     // repeatedly process the same operand.
1778     for (const SDValue &ChildN : N->op_values())
1779       AddToWorklist(ChildN.getNode(), /*IsCandidateForPruning=*/true,
1780                     /*SkipIfCombinedBefore=*/true);
1781 
1782     SDValue RV = combine(N);
1783 
1784     if (!RV.getNode())
1785       continue;
1786 
1787     ++NodesCombined;
1788 
1789     // Invalidate cached info.
1790     ChainsWithoutMergeableStores.clear();
1791 
1792     // If we get back the same node we passed in, rather than a new node or
1793     // zero, we know that the node must have defined multiple values and
1794     // CombineTo was used.  Since CombineTo takes care of the worklist
1795     // mechanics for us, we have no work to do in this case.
1796     if (RV.getNode() == N)
1797       continue;
1798 
1799     assert(N->getOpcode() != ISD::DELETED_NODE &&
1800            RV.getOpcode() != ISD::DELETED_NODE &&
1801            "Node was deleted but visit returned new node!");
1802 
1803     LLVM_DEBUG(dbgs() << " ... into: "; RV.dump(&DAG));
1804 
1805     if (N->getNumValues() == RV->getNumValues())
1806       DAG.ReplaceAllUsesWith(N, RV.getNode());
1807     else {
1808       assert(N->getValueType(0) == RV.getValueType() &&
1809              N->getNumValues() == 1 && "Type mismatch");
1810       DAG.ReplaceAllUsesWith(N, &RV);
1811     }
1812 
1813     // Push the new node and any users onto the worklist.  Omit this if the
1814     // new node is the EntryToken (e.g. if a store managed to get optimized
1815     // out), because re-visiting the EntryToken and its users will not uncover
1816     // any additional opportunities, but there may be a large number of such
1817     // users, potentially causing compile time explosion.
1818     if (RV.getOpcode() != ISD::EntryToken)
1819       AddToWorklistWithUsers(RV.getNode());
1820 
1821     // Finally, if the node is now dead, remove it from the graph.  The node
1822     // may not be dead if the replacement process recursively simplified to
1823     // something else needing this node. This will also take care of adding any
1824     // operands which have lost a user to the worklist.
1825     recursivelyDeleteUnusedNodes(N);
1826   }
1827 
1828   // If the root changed (e.g. it was a dead load), update the root.
1829   DAG.setRoot(Dummy.getValue());
1830   DAG.RemoveDeadNodes();
1831 }
1832 
1833 SDValue DAGCombiner::visit(SDNode *N) {
1834   // clang-format off
1835   switch (N->getOpcode()) {
1836   default: break;
1837   case ISD::TokenFactor:        return visitTokenFactor(N);
1838   case ISD::MERGE_VALUES:       return visitMERGE_VALUES(N);
1839   case ISD::ADD:                return visitADD(N);
1840   case ISD::SUB:                return visitSUB(N);
1841   case ISD::SADDSAT:
1842   case ISD::UADDSAT:            return visitADDSAT(N);
1843   case ISD::SSUBSAT:
1844   case ISD::USUBSAT:            return visitSUBSAT(N);
1845   case ISD::ADDC:               return visitADDC(N);
1846   case ISD::SADDO:
1847   case ISD::UADDO:              return visitADDO(N);
1848   case ISD::SUBC:               return visitSUBC(N);
1849   case ISD::SSUBO:
1850   case ISD::USUBO:              return visitSUBO(N);
1851   case ISD::ADDE:               return visitADDE(N);
1852   case ISD::UADDO_CARRY:        return visitUADDO_CARRY(N);
1853   case ISD::SADDO_CARRY:        return visitSADDO_CARRY(N);
1854   case ISD::SUBE:               return visitSUBE(N);
1855   case ISD::USUBO_CARRY:        return visitUSUBO_CARRY(N);
1856   case ISD::SSUBO_CARRY:        return visitSSUBO_CARRY(N);
1857   case ISD::SMULFIX:
1858   case ISD::SMULFIXSAT:
1859   case ISD::UMULFIX:
1860   case ISD::UMULFIXSAT:         return visitMULFIX(N);
1861   case ISD::MUL:                return visitMUL<EmptyMatchContext>(N);
1862   case ISD::SDIV:               return visitSDIV(N);
1863   case ISD::UDIV:               return visitUDIV(N);
1864   case ISD::SREM:
1865   case ISD::UREM:               return visitREM(N);
1866   case ISD::MULHU:              return visitMULHU(N);
1867   case ISD::MULHS:              return visitMULHS(N);
1868   case ISD::AVGFLOORS:
1869   case ISD::AVGFLOORU:
1870   case ISD::AVGCEILS:
1871   case ISD::AVGCEILU:           return visitAVG(N);
1872   case ISD::ABDS:
1873   case ISD::ABDU:               return visitABD(N);
1874   case ISD::SMUL_LOHI:          return visitSMUL_LOHI(N);
1875   case ISD::UMUL_LOHI:          return visitUMUL_LOHI(N);
1876   case ISD::SMULO:
1877   case ISD::UMULO:              return visitMULO(N);
1878   case ISD::SMIN:
1879   case ISD::SMAX:
1880   case ISD::UMIN:
1881   case ISD::UMAX:               return visitIMINMAX(N);
1882   case ISD::AND:                return visitAND(N);
1883   case ISD::OR:                 return visitOR(N);
1884   case ISD::XOR:                return visitXOR(N);
1885   case ISD::SHL:                return visitSHL(N);
1886   case ISD::SRA:                return visitSRA(N);
1887   case ISD::SRL:                return visitSRL(N);
1888   case ISD::ROTR:
1889   case ISD::ROTL:               return visitRotate(N);
1890   case ISD::FSHL:
1891   case ISD::FSHR:               return visitFunnelShift(N);
1892   case ISD::SSHLSAT:
1893   case ISD::USHLSAT:            return visitSHLSAT(N);
1894   case ISD::ABS:                return visitABS(N);
1895   case ISD::BSWAP:              return visitBSWAP(N);
1896   case ISD::BITREVERSE:         return visitBITREVERSE(N);
1897   case ISD::CTLZ:               return visitCTLZ(N);
1898   case ISD::CTLZ_ZERO_UNDEF:    return visitCTLZ_ZERO_UNDEF(N);
1899   case ISD::CTTZ:               return visitCTTZ(N);
1900   case ISD::CTTZ_ZERO_UNDEF:    return visitCTTZ_ZERO_UNDEF(N);
1901   case ISD::CTPOP:              return visitCTPOP(N);
1902   case ISD::SELECT:             return visitSELECT(N);
1903   case ISD::VSELECT:            return visitVSELECT(N);
1904   case ISD::SELECT_CC:          return visitSELECT_CC(N);
1905   case ISD::SETCC:              return visitSETCC(N);
1906   case ISD::SETCCCARRY:         return visitSETCCCARRY(N);
1907   case ISD::SIGN_EXTEND:        return visitSIGN_EXTEND(N);
1908   case ISD::ZERO_EXTEND:        return visitZERO_EXTEND(N);
1909   case ISD::ANY_EXTEND:         return visitANY_EXTEND(N);
1910   case ISD::AssertSext:
1911   case ISD::AssertZext:         return visitAssertExt(N);
1912   case ISD::AssertAlign:        return visitAssertAlign(N);
1913   case ISD::SIGN_EXTEND_INREG:  return visitSIGN_EXTEND_INREG(N);
1914   case ISD::SIGN_EXTEND_VECTOR_INREG:
1915   case ISD::ZERO_EXTEND_VECTOR_INREG:
1916   case ISD::ANY_EXTEND_VECTOR_INREG: return visitEXTEND_VECTOR_INREG(N);
1917   case ISD::TRUNCATE:           return visitTRUNCATE(N);
1918   case ISD::BITCAST:            return visitBITCAST(N);
1919   case ISD::BUILD_PAIR:         return visitBUILD_PAIR(N);
1920   case ISD::FADD:               return visitFADD(N);
1921   case ISD::STRICT_FADD:        return visitSTRICT_FADD(N);
1922   case ISD::FSUB:               return visitFSUB(N);
1923   case ISD::FMUL:               return visitFMUL(N);
1924   case ISD::FMA:                return visitFMA<EmptyMatchContext>(N);
1925   case ISD::FMAD:               return visitFMAD(N);
1926   case ISD::FDIV:               return visitFDIV(N);
1927   case ISD::FREM:               return visitFREM(N);
1928   case ISD::FSQRT:              return visitFSQRT(N);
1929   case ISD::FCOPYSIGN:          return visitFCOPYSIGN(N);
1930   case ISD::FPOW:               return visitFPOW(N);
1931   case ISD::SINT_TO_FP:         return visitSINT_TO_FP(N);
1932   case ISD::UINT_TO_FP:         return visitUINT_TO_FP(N);
1933   case ISD::FP_TO_SINT:         return visitFP_TO_SINT(N);
1934   case ISD::FP_TO_UINT:         return visitFP_TO_UINT(N);
1935   case ISD::LRINT:
1936   case ISD::LLRINT:             return visitXRINT(N);
1937   case ISD::FP_ROUND:           return visitFP_ROUND(N);
1938   case ISD::FP_EXTEND:          return visitFP_EXTEND(N);
1939   case ISD::FNEG:               return visitFNEG(N);
1940   case ISD::FABS:               return visitFABS(N);
1941   case ISD::FFLOOR:             return visitFFLOOR(N);
1942   case ISD::FMINNUM:
1943   case ISD::FMAXNUM:
1944   case ISD::FMINIMUM:
1945   case ISD::FMAXIMUM:           return visitFMinMax(N);
1946   case ISD::FCEIL:              return visitFCEIL(N);
1947   case ISD::FTRUNC:             return visitFTRUNC(N);
1948   case ISD::FFREXP:             return visitFFREXP(N);
1949   case ISD::BRCOND:             return visitBRCOND(N);
1950   case ISD::BR_CC:              return visitBR_CC(N);
1951   case ISD::LOAD:               return visitLOAD(N);
1952   case ISD::STORE:              return visitSTORE(N);
1953   case ISD::ATOMIC_STORE:       return visitATOMIC_STORE(N);
1954   case ISD::INSERT_VECTOR_ELT:  return visitINSERT_VECTOR_ELT(N);
1955   case ISD::EXTRACT_VECTOR_ELT: return visitEXTRACT_VECTOR_ELT(N);
1956   case ISD::BUILD_VECTOR:       return visitBUILD_VECTOR(N);
1957   case ISD::CONCAT_VECTORS:     return visitCONCAT_VECTORS(N);
1958   case ISD::EXTRACT_SUBVECTOR:  return visitEXTRACT_SUBVECTOR(N);
1959   case ISD::VECTOR_SHUFFLE:     return visitVECTOR_SHUFFLE(N);
1960   case ISD::SCALAR_TO_VECTOR:   return visitSCALAR_TO_VECTOR(N);
1961   case ISD::INSERT_SUBVECTOR:   return visitINSERT_SUBVECTOR(N);
1962   case ISD::MGATHER:            return visitMGATHER(N);
1963   case ISD::MLOAD:              return visitMLOAD(N);
1964   case ISD::MSCATTER:           return visitMSCATTER(N);
1965   case ISD::MSTORE:             return visitMSTORE(N);
1966   case ISD::VECTOR_COMPRESS:    return visitVECTOR_COMPRESS(N);
1967   case ISD::LIFETIME_END:       return visitLIFETIME_END(N);
1968   case ISD::FP_TO_FP16:         return visitFP_TO_FP16(N);
1969   case ISD::FP16_TO_FP:         return visitFP16_TO_FP(N);
1970   case ISD::FP_TO_BF16:         return visitFP_TO_BF16(N);
1971   case ISD::BF16_TO_FP:         return visitBF16_TO_FP(N);
1972   case ISD::FREEZE:             return visitFREEZE(N);
1973   case ISD::GET_FPENV_MEM:      return visitGET_FPENV_MEM(N);
1974   case ISD::SET_FPENV_MEM:      return visitSET_FPENV_MEM(N);
1975   case ISD::VECREDUCE_FADD:
1976   case ISD::VECREDUCE_FMUL:
1977   case ISD::VECREDUCE_ADD:
1978   case ISD::VECREDUCE_MUL:
1979   case ISD::VECREDUCE_AND:
1980   case ISD::VECREDUCE_OR:
1981   case ISD::VECREDUCE_XOR:
1982   case ISD::VECREDUCE_SMAX:
1983   case ISD::VECREDUCE_SMIN:
1984   case ISD::VECREDUCE_UMAX:
1985   case ISD::VECREDUCE_UMIN:
1986   case ISD::VECREDUCE_FMAX:
1987   case ISD::VECREDUCE_FMIN:
1988   case ISD::VECREDUCE_FMAXIMUM:
1989   case ISD::VECREDUCE_FMINIMUM:     return visitVECREDUCE(N);
1990 #define BEGIN_REGISTER_VP_SDNODE(SDOPC, ...) case ISD::SDOPC:
1991 #include "llvm/IR/VPIntrinsics.def"
1992     return visitVPOp(N);
1993   }
1994   // clang-format on
1995   return SDValue();
1996 }
1997 
1998 SDValue DAGCombiner::combine(SDNode *N) {
1999   if (!DebugCounter::shouldExecute(DAGCombineCounter))
2000     return SDValue();
2001 
2002   SDValue RV;
2003   if (!DisableGenericCombines)
2004     RV = visit(N);
2005 
2006   // If nothing happened, try a target-specific DAG combine.
2007   if (!RV.getNode()) {
2008     assert(N->getOpcode() != ISD::DELETED_NODE &&
2009            "Node was deleted but visit returned NULL!");
2010 
2011     if (N->getOpcode() >= ISD::BUILTIN_OP_END ||
2012         TLI.hasTargetDAGCombine((ISD::NodeType)N->getOpcode())) {
2013 
2014       // Expose the DAG combiner to the target combiner impls.
2015       TargetLowering::DAGCombinerInfo
2016         DagCombineInfo(DAG, Level, false, this);
2017 
2018       RV = TLI.PerformDAGCombine(N, DagCombineInfo);
2019     }
2020   }
2021 
2022   // If nothing happened still, try promoting the operation.
2023   if (!RV.getNode()) {
2024     switch (N->getOpcode()) {
2025     default: break;
2026     case ISD::ADD:
2027     case ISD::SUB:
2028     case ISD::MUL:
2029     case ISD::AND:
2030     case ISD::OR:
2031     case ISD::XOR:
2032       RV = PromoteIntBinOp(SDValue(N, 0));
2033       break;
2034     case ISD::SHL:
2035     case ISD::SRA:
2036     case ISD::SRL:
2037       RV = PromoteIntShiftOp(SDValue(N, 0));
2038       break;
2039     case ISD::SIGN_EXTEND:
2040     case ISD::ZERO_EXTEND:
2041     case ISD::ANY_EXTEND:
2042       RV = PromoteExtend(SDValue(N, 0));
2043       break;
2044     case ISD::LOAD:
2045       if (PromoteLoad(SDValue(N, 0)))
2046         RV = SDValue(N, 0);
2047       break;
2048     }
2049   }
2050 
2051   // If N is a commutative binary node, try to eliminate it if the commuted
2052   // version is already present in the DAG.
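  // For example, if (add y, x) is already in the DAG, a newly visited
  // (add x, y) can be replaced by the existing node.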
2053   if (!RV.getNode() && TLI.isCommutativeBinOp(N->getOpcode())) {
2054     SDValue N0 = N->getOperand(0);
2055     SDValue N1 = N->getOperand(1);
2056 
2057     // Constant operands are canonicalized to RHS.
2058     if (N0 != N1 && (isa<ConstantSDNode>(N0) || !isa<ConstantSDNode>(N1))) {
2059       SDValue Ops[] = {N1, N0};
2060       SDNode *CSENode = DAG.getNodeIfExists(N->getOpcode(), N->getVTList(), Ops,
2061                                             N->getFlags());
2062       if (CSENode)
2063         return SDValue(CSENode, 0);
2064     }
2065   }
2066 
2067   return RV;
2068 }
2069 
2070 /// Given a node, return its input chain if it has one, otherwise return a null
2071 /// sd operand.
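/// Chain operands have type MVT::Other; the common positions (first and last
/// operand) are checked before scanning any interior operands.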
2072 static SDValue getInputChainForNode(SDNode *N) {
2073   if (unsigned NumOps = N->getNumOperands()) {
2074     if (N->getOperand(0).getValueType() == MVT::Other)
2075       return N->getOperand(0);
2076     if (N->getOperand(NumOps-1).getValueType() == MVT::Other)
2077       return N->getOperand(NumOps-1);
2078     for (unsigned i = 1; i < NumOps-1; ++i)
2079       if (N->getOperand(i).getValueType() == MVT::Other)
2080         return N->getOperand(i);
2081   }
2082   return SDValue();
2083 }
2084 
2085 SDValue DAGCombiner::visitTokenFactor(SDNode *N) {
2086   // If N has two operands, where one has an input chain equal to the other,
2087   // the 'other' chain is redundant.
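  // For example, in TokenFactor(Ld.chain, Ch) where load Ld's input chain is
  // Ch, Ld.chain already orders after Ch, so the token factor reduces to
  // Ld.chain.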
2088   if (N->getNumOperands() == 2) {
2089     if (getInputChainForNode(N->getOperand(0).getNode()) == N->getOperand(1))
2090       return N->getOperand(0);
2091     if (getInputChainForNode(N->getOperand(1).getNode()) == N->getOperand(0))
2092       return N->getOperand(1);
2093   }
2094 
2095   // Don't simplify token factors if optnone.
2096   if (OptLevel == CodeGenOptLevel::None)
2097     return SDValue();
2098 
2099   // Don't simplify the token factor if the node itself has too many operands.
2100   if (N->getNumOperands() > TokenFactorInlineLimit)
2101     return SDValue();
2102 
2103   // If the sole user is a token factor, we should make sure we have a
2104   // chance to merge them together. This prevents TF chains from inhibiting
2105   // optimizations.
2106   if (N->hasOneUse() && N->use_begin()->getOpcode() == ISD::TokenFactor)
2107     AddToWorklist(*(N->use_begin()));
2108 
2109   SmallVector<SDNode *, 8> TFs;     // List of token factors to visit.
2110   SmallVector<SDValue, 8> Ops;      // Ops for replacing token factor.
2111   SmallPtrSet<SDNode*, 16> SeenOps;
2112   bool Changed = false;             // If we should replace this token factor.
2113 
2114   // Start out with this token factor.
2115   TFs.push_back(N);
2116 
2117   // Iterate through token factors. The TFs list grows when new token factors
2118   // are encountered.
2119   for (unsigned i = 0; i < TFs.size(); ++i) {
2120     // Limit number of nodes to inline, to avoid quadratic compile times.
2121     // We have to add the outstanding Token Factors to Ops, otherwise we might
2122     // drop Ops from the resulting Token Factors.
2123     if (Ops.size() > TokenFactorInlineLimit) {
2124       for (unsigned j = i; j < TFs.size(); j++)
2125         Ops.emplace_back(TFs[j], 0);
2126       // Drop unprocessed Token Factors from TFs, so we do not add them to the
2127       // combiner worklist later.
2128       TFs.resize(i);
2129       break;
2130     }
2131 
2132     SDNode *TF = TFs[i];
2133     // Check each of the operands.
2134     for (const SDValue &Op : TF->op_values()) {
2135       switch (Op.getOpcode()) {
2136       case ISD::EntryToken:
2137         // Entry tokens don't need to be added to the list. They are
2138         // redundant.
2139         Changed = true;
2140         break;
2141 
2142       case ISD::TokenFactor:
2143         if (Op.hasOneUse() && !is_contained(TFs, Op.getNode())) {
2144           // Queue up for processing.
2145           TFs.push_back(Op.getNode());
2146           Changed = true;
2147           break;
2148         }
2149         [[fallthrough]];
2150 
2151       default:
2152         // Only add if it isn't already in the list.
2153         if (SeenOps.insert(Op.getNode()).second)
2154           Ops.push_back(Op);
2155         else
2156           Changed = true;
2157         break;
2158       }
2159     }
2160   }
2161 
2162   // Re-visit inlined Token Factors, to clean them up in case they have been
2163   // removed. Skip the first Token Factor, as this is the current node.
2164   for (unsigned i = 1, e = TFs.size(); i < e; i++)
2165     AddToWorklist(TFs[i]);
2166 
2167   // Remove nodes that are chained to another node in the list. Do so
2168   // by walking up chains breadth-first, stopping when we've seen
2169   // another operand. In general we must climb to the EntryNode, but we can
2170   // exit early if we find all remaining work is associated with just one
2171   // operand, as no further pruning is possible.
2172 
2173   // List of nodes to search through and original Ops from which they originate.
2174   SmallVector<std::pair<SDNode *, unsigned>, 8> Worklist;
2175   SmallVector<unsigned, 8> OpWorkCount; // Count of work for each Op.
2176   SmallPtrSet<SDNode *, 16> SeenChains;
2177   bool DidPruneOps = false;
2178 
2179   unsigned NumLeftToConsider = 0;
2180   for (const SDValue &Op : Ops) {
2181     Worklist.push_back(std::make_pair(Op.getNode(), NumLeftToConsider++));
2182     OpWorkCount.push_back(1);
2183   }
2184 
2185   auto AddToWorklist = [&](unsigned CurIdx, SDNode *Op, unsigned OpNumber) {
2186     // If this is an Op, we can remove the op from the list. Re-mark any
2187     // search associated with it as coming from the current OpNumber.
2188     if (SeenOps.contains(Op)) {
2189       Changed = true;
2190       DidPruneOps = true;
2191       unsigned OrigOpNumber = 0;
2192       while (OrigOpNumber < Ops.size() && Ops[OrigOpNumber].getNode() != Op)
2193         OrigOpNumber++;
2194       assert((OrigOpNumber != Ops.size()) &&
2195              "expected to find TokenFactor Operand");
2196       // Re-mark worklist from OrigOpNumber to OpNumber
2197       for (unsigned i = CurIdx + 1; i < Worklist.size(); ++i) {
2198         if (Worklist[i].second == OrigOpNumber) {
2199           Worklist[i].second = OpNumber;
2200         }
2201       }
2202       OpWorkCount[OpNumber] += OpWorkCount[OrigOpNumber];
2203       OpWorkCount[OrigOpNumber] = 0;
2204       NumLeftToConsider--;
2205     }
2206     // Add if it's a new chain
2207     if (SeenChains.insert(Op).second) {
2208       OpWorkCount[OpNumber]++;
2209       Worklist.push_back(std::make_pair(Op, OpNumber));
2210     }
2211   };
2212 
2213   for (unsigned i = 0; i < Worklist.size() && i < 1024; ++i) {
2214     // We need to consider at least 2 Ops to prune.
2215     if (NumLeftToConsider <= 1)
2216       break;
2217     auto CurNode = Worklist[i].first;
2218     auto CurOpNumber = Worklist[i].second;
2219     assert((OpWorkCount[CurOpNumber] > 0) &&
2220            "Node should not appear in worklist");
2221     switch (CurNode->getOpcode()) {
2222     case ISD::EntryToken:
2223       // Hitting EntryToken is the only way for the search to terminate
2224       // without hitting another operand's search. Prevent us from marking
2225       // this operand considered.
2227       NumLeftToConsider++;
2228       break;
2229     case ISD::TokenFactor:
2230       for (const SDValue &Op : CurNode->op_values())
2231         AddToWorklist(i, Op.getNode(), CurOpNumber);
2232       break;
2233     case ISD::LIFETIME_START:
2234     case ISD::LIFETIME_END:
2235     case ISD::CopyFromReg:
2236     case ISD::CopyToReg:
2237       AddToWorklist(i, CurNode->getOperand(0).getNode(), CurOpNumber);
2238       break;
2239     default:
2240       if (auto *MemNode = dyn_cast<MemSDNode>(CurNode))
2241         AddToWorklist(i, MemNode->getChain().getNode(), CurOpNumber);
2242       break;
2243     }
2244     OpWorkCount[CurOpNumber]--;
2245     if (OpWorkCount[CurOpNumber] == 0)
2246       NumLeftToConsider--;
2247   }
2248 
2249   // If we've changed things around then replace token factor.
2250   // If we've changed things around, replace the token factor.
2251     SDValue Result;
2252     if (Ops.empty()) {
2253       // The entry token is the only possible outcome.
2254       Result = DAG.getEntryNode();
2255     } else {
2256       if (DidPruneOps) {
2257         SmallVector<SDValue, 8> PrunedOps;
2259         for (const SDValue &Op : Ops) {
2260           if (SeenChains.count(Op.getNode()) == 0)
2261             PrunedOps.push_back(Op);
2262         }
2263         Result = DAG.getTokenFactor(SDLoc(N), PrunedOps);
2264       } else {
2265         Result = DAG.getTokenFactor(SDLoc(N), Ops);
2266       }
2267     }
2268     return Result;
2269   }
2270   return SDValue();
2271 }
2272 
2273 /// MERGE_VALUES can always be eliminated.
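/// Each result value of a MERGE_VALUES node is just the corresponding
/// operand, so every use can be rewritten to use that operand directly.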
2274 SDValue DAGCombiner::visitMERGE_VALUES(SDNode *N) {
2275   WorklistRemover DeadNodes(*this);
2276   // Replacing results may cause a different MERGE_VALUES to suddenly
2277   // be CSE'd with N, and carry its uses with it. Iterate until no
2278   // uses remain, to ensure that the node can be safely deleted.
2279   // First add the users of this node to the work list so that they
2280   // can be tried again once they have new operands.
2281   AddUsersToWorklist(N);
2282   do {
2283     // Do as a single replacement to avoid rewalking use lists.
2284     SmallVector<SDValue, 8> Ops;
2285     for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
2286       Ops.push_back(N->getOperand(i));
2287     DAG.ReplaceAllUsesWith(N, Ops.data());
2288   } while (!N->use_empty());
2289   deleteAndRecombine(N);
2290   return SDValue(N, 0);   // Return N so it doesn't get rechecked!
2291 }
2292 
2293 /// If \p N is a ConstantSDNode with isOpaque() == false, return it cast to a
2294 /// ConstantSDNode pointer, else nullptr.
2295 static ConstantSDNode *getAsNonOpaqueConstant(SDValue N) {
2296   ConstantSDNode *Const = dyn_cast<ConstantSDNode>(N);
2297   return Const != nullptr && !Const->isOpaque() ? Const : nullptr;
2298 }
2299 
2300 // isTruncateOf - If N is a truncate of some other value, return true and
2301 // record the value being truncated in Op and which of Op's bits are zero/one
2302 // in Known. This function computes KnownBits to avoid a duplicated call to
2303 // computeKnownBits in the caller.
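// Besides a plain TRUNCATE, this also recognizes (setcc x, 0, ne) with an i1
// result as a truncate of x when every bit of x above bit 0 is known zero.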
2304 static bool isTruncateOf(SelectionDAG &DAG, SDValue N, SDValue &Op,
2305                          KnownBits &Known) {
2306   if (N->getOpcode() == ISD::TRUNCATE) {
2307     Op = N->getOperand(0);
2308     Known = DAG.computeKnownBits(Op);
2309     return true;
2310   }
2311 
2312   if (N.getValueType().getScalarType() != MVT::i1 ||
2313       !sd_match(
2314           N, m_c_SetCC(m_Value(Op), m_Zero(), m_SpecificCondCode(ISD::SETNE))))
2315     return false;
2316 
2317   Known = DAG.computeKnownBits(Op);
2318   return (Known.Zero | 1).isAllOnes();
2319 }
2320 
2321 /// Return true if 'Use' is a load or a store that uses N as its base pointer
2322 /// and that N may be folded in the load / store addressing mode.
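/// For example, if N is (add base, 16) feeding a load, this asks the target
/// whether a [reg + 16] addressing mode is legal for the loaded type.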
2323 static bool canFoldInAddressingMode(SDNode *N, SDNode *Use, SelectionDAG &DAG,
2324                                     const TargetLowering &TLI) {
2325   EVT VT;
2326   unsigned AS;
2327 
2328   if (LoadSDNode *LD = dyn_cast<LoadSDNode>(Use)) {
2329     if (LD->isIndexed() || LD->getBasePtr().getNode() != N)
2330       return false;
2331     VT = LD->getMemoryVT();
2332     AS = LD->getAddressSpace();
2333   } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(Use)) {
2334     if (ST->isIndexed() || ST->getBasePtr().getNode() != N)
2335       return false;
2336     VT = ST->getMemoryVT();
2337     AS = ST->getAddressSpace();
2338   } else if (MaskedLoadSDNode *LD = dyn_cast<MaskedLoadSDNode>(Use)) {
2339     if (LD->isIndexed() || LD->getBasePtr().getNode() != N)
2340       return false;
2341     VT = LD->getMemoryVT();
2342     AS = LD->getAddressSpace();
2343   } else if (MaskedStoreSDNode *ST = dyn_cast<MaskedStoreSDNode>(Use)) {
2344     if (ST->isIndexed() || ST->getBasePtr().getNode() != N)
2345       return false;
2346     VT = ST->getMemoryVT();
2347     AS = ST->getAddressSpace();
2348   } else {
2349     return false;
2350   }
2351 
2352   TargetLowering::AddrMode AM;
2353   if (N->getOpcode() == ISD::ADD) {
2354     AM.HasBaseReg = true;
2355     ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(N->getOperand(1));
2356     if (Offset)
2357       // [reg +/- imm]
2358       AM.BaseOffs = Offset->getSExtValue();
2359     else
2360       // [reg +/- reg]
2361       AM.Scale = 1;
2362   } else if (N->getOpcode() == ISD::SUB) {
2363     AM.HasBaseReg = true;
2364     ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(N->getOperand(1));
2365     if (Offset)
2366       // [reg +/- imm]
2367       AM.BaseOffs = -Offset->getSExtValue();
2368     else
2369       // [reg +/- reg]
2370       AM.Scale = 1;
2371   } else {
2372     return false;
2373   }
2374 
2375   return TLI.isLegalAddressingMode(DAG.getDataLayout(), AM,
2376                                    VT.getTypeForEVT(*DAG.getContext()), AS);
2377 }
2378 
2379 /// This inverts a canonicalization in IR that replaces a variable select arm
2380 /// with an identity constant. Codegen improves if we re-use the variable
2381 /// operand rather than load a constant. This can also be converted into a
2382 /// masked vector operation if the target supports it.
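/// For example, using add's identity constant 0:
///   add X, (vselect Cond, 0, Y) --> vselect Cond, X, (add X, Y)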
2383 static SDValue foldSelectWithIdentityConstant(SDNode *N, SelectionDAG &DAG,
2384                                               bool ShouldCommuteOperands) {
2385   // Match a select as operand 1. The identity constant that we are looking for
2386   // is only valid as operand 1 of a non-commutative binop.
2387   SDValue N0 = N->getOperand(0);
2388   SDValue N1 = N->getOperand(1);
2389   if (ShouldCommuteOperands)
2390     std::swap(N0, N1);
2391 
2392   // TODO: Should this apply to scalar select too?
2393   if (N1.getOpcode() != ISD::VSELECT || !N1.hasOneUse())
2394     return SDValue();
2395 
2396   // We can't hoist all instructions because of immediate UB (not speculatable).
2397   // For example, div/rem by zero.
2398   if (!DAG.isSafeToSpeculativelyExecuteNode(N))
2399     return SDValue();
2400 
2401   unsigned Opcode = N->getOpcode();
2402   EVT VT = N->getValueType(0);
2403   SDValue Cond = N1.getOperand(0);
2404   SDValue TVal = N1.getOperand(1);
2405   SDValue FVal = N1.getOperand(2);
2406 
2407   // This transform increases uses of N0, so freeze it to be safe.
2408   // binop N0, (vselect Cond, IDC, FVal) --> vselect Cond, N0, (binop N0, FVal)
2409   unsigned OpNo = ShouldCommuteOperands ? 0 : 1;
2410   if (isNeutralConstant(Opcode, N->getFlags(), TVal, OpNo)) {
2411     SDValue F0 = DAG.getFreeze(N0);
2412     SDValue NewBO = DAG.getNode(Opcode, SDLoc(N), VT, F0, FVal, N->getFlags());
2413     return DAG.getSelect(SDLoc(N), VT, Cond, F0, NewBO);
2414   }
2415   // binop N0, (vselect Cond, TVal, IDC) --> vselect Cond, (binop N0, TVal), N0
2416   if (isNeutralConstant(Opcode, N->getFlags(), FVal, OpNo)) {
2417     SDValue F0 = DAG.getFreeze(N0);
2418     SDValue NewBO = DAG.getNode(Opcode, SDLoc(N), VT, F0, TVal, N->getFlags());
2419     return DAG.getSelect(SDLoc(N), VT, Cond, NewBO, F0);
2420   }
2421 
2422   return SDValue();
2423 }
2424 
2425 SDValue DAGCombiner::foldBinOpIntoSelect(SDNode *BO) {
2426   assert(TLI.isBinOp(BO->getOpcode()) && BO->getNumValues() == 1 &&
2427          "Unexpected binary operator");
2428 
2429   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
2430   auto BinOpcode = BO->getOpcode();
2431   EVT VT = BO->getValueType(0);
2432   if (TLI.shouldFoldSelectWithIdentityConstant(BinOpcode, VT)) {
2433     if (SDValue Sel = foldSelectWithIdentityConstant(BO, DAG, false))
2434       return Sel;
2435 
2436     if (TLI.isCommutativeBinOp(BO->getOpcode()))
2437       if (SDValue Sel = foldSelectWithIdentityConstant(BO, DAG, true))
2438         return Sel;
2439   }
2440 
2441   // Don't do this unless the old select is going away. We want to eliminate the
2442   // binary operator, not replace a binop with a select.
2443   // TODO: Handle ISD::SELECT_CC.
2444   unsigned SelOpNo = 0;
2445   SDValue Sel = BO->getOperand(0);
2446   if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse()) {
2447     SelOpNo = 1;
2448     Sel = BO->getOperand(1);
2449 
2450     // Peek through trunc to shift amount type.
2451     if ((BinOpcode == ISD::SHL || BinOpcode == ISD::SRA ||
2452          BinOpcode == ISD::SRL) && Sel.hasOneUse()) {
2453       // This is valid when the truncated bits of x are already zero.
2454       SDValue Op;
2455       KnownBits Known;
2456       if (isTruncateOf(DAG, Sel, Op, Known) &&
2457           Known.countMaxActiveBits() < Sel.getScalarValueSizeInBits())
2458         Sel = Op;
2459     }
2460   }
2461 
2462   if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse())
2463     return SDValue();
2464 
2465   SDValue CT = Sel.getOperand(1);
2466   if (!isConstantOrConstantVector(CT, true) &&
2467       !DAG.isConstantFPBuildVectorOrConstantFP(CT))
2468     return SDValue();
2469 
2470   SDValue CF = Sel.getOperand(2);
2471   if (!isConstantOrConstantVector(CF, true) &&
2472       !DAG.isConstantFPBuildVectorOrConstantFP(CF))
2473     return SDValue();
2474 
2475   // Bail out if any constants are opaque because we can't constant fold those.
2476   // The exception is "and" and "or" with either 0 or -1, in which case we can
2477   // propagate non-constant operands into the select. I.e.:
2478   // and (select Cond, 0, -1), X --> select Cond, 0, X
2479   // or X, (select Cond, -1, 0) --> select Cond, -1, X
2480   bool CanFoldNonConst =
2481       (BinOpcode == ISD::AND || BinOpcode == ISD::OR) &&
2482       ((isNullOrNullSplat(CT) && isAllOnesOrAllOnesSplat(CF)) ||
2483        (isNullOrNullSplat(CF) && isAllOnesOrAllOnesSplat(CT)));
2484 
2485   SDValue CBO = BO->getOperand(SelOpNo ^ 1);
2486   if (!CanFoldNonConst &&
2487       !isConstantOrConstantVector(CBO, true) &&
2488       !DAG.isConstantFPBuildVectorOrConstantFP(CBO))
2489     return SDValue();
2490 
2491   SDLoc DL(Sel);
2492   SDValue NewCT, NewCF;
2493 
2494   if (CanFoldNonConst) {
2495     // If CBO is an opaque constant, we can't rely on getNode to constant fold.
2496     if ((BinOpcode == ISD::AND && isNullOrNullSplat(CT)) ||
2497         (BinOpcode == ISD::OR && isAllOnesOrAllOnesSplat(CT)))
2498       NewCT = CT;
2499     else
2500       NewCT = CBO;
2501 
2502     if ((BinOpcode == ISD::AND && isNullOrNullSplat(CF)) ||
2503         (BinOpcode == ISD::OR && isAllOnesOrAllOnesSplat(CF)))
2504       NewCF = CF;
2505     else
2506       NewCF = CBO;
2507   } else {
2508     // We have a select-of-constants followed by a binary operator with a
2509     // constant. Eliminate the binop by pulling the constant math into the
2510     // select. Example:
2511     //   add (select Cond, CT, CF), CBO --> select Cond, CT + CBO, CF + CBO
2512     NewCT = SelOpNo ? DAG.FoldConstantArithmetic(BinOpcode, DL, VT, {CBO, CT})
2513                     : DAG.FoldConstantArithmetic(BinOpcode, DL, VT, {CT, CBO});
2514     if (!NewCT)
2515       return SDValue();
2516 
2517     NewCF = SelOpNo ? DAG.FoldConstantArithmetic(BinOpcode, DL, VT, {CBO, CF})
2518                     : DAG.FoldConstantArithmetic(BinOpcode, DL, VT, {CF, CBO});
2519     if (!NewCF)
2520       return SDValue();
2521   }
2522 
2523   SDValue SelectOp = DAG.getSelect(DL, VT, Sel.getOperand(0), NewCT, NewCF);
2524   SelectOp->setFlags(BO->getFlags());
2525   return SelectOp;
2526 }
2527 
2528 static SDValue foldAddSubBoolOfMaskedVal(SDNode *N, const SDLoc &DL,
2529                                          SelectionDAG &DAG) {
2530   assert((N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::SUB) &&
2531          "Expecting add or sub");
2532 
2533   // Match a constant operand and a zext operand for the math instruction:
2534   // add Z, C
2535   // sub C, Z
2536   bool IsAdd = N->getOpcode() == ISD::ADD;
2537   SDValue C = IsAdd ? N->getOperand(1) : N->getOperand(0);
2538   SDValue Z = IsAdd ? N->getOperand(0) : N->getOperand(1);
2539   auto *CN = dyn_cast<ConstantSDNode>(C);
2540   if (!CN || Z.getOpcode() != ISD::ZERO_EXTEND)
2541     return SDValue();
2542 
2543   // Match the zext operand as a setcc of a boolean.
2544   if (Z.getOperand(0).getValueType() != MVT::i1)
2545     return SDValue();
2546 
2547   // Match the compare as: setcc (X & 1), 0, eq.
2548   if (!sd_match(Z.getOperand(0), m_SetCC(m_And(m_Value(), m_One()), m_Zero(),
2549                                          m_SpecificCondCode(ISD::SETEQ))))
2550     return SDValue();
2551 
2552   // We are adding/subtracting a constant and an inverted low bit. Turn that
2553   // into a subtract/add of the low bit with incremented/decremented constant:
2554   // add (zext i1 (seteq (X & 1), 0)), C --> sub C+1, (zext (X & 1))
2555   // sub C, (zext i1 (seteq (X & 1), 0)) --> add C-1, (zext (X & 1))
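       // Illustrative check of the add form with C = 7 (example values, not
       // from the source): for even X, seteq((X & 1), 0) is true, so the LHS
       // is zext(1) + 7 = 8 and the RHS is (7 + 1) - 0 = 8; for odd X, the
       // LHS is 0 + 7 = 7 and the RHS is 8 - 1 = 7.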
2556   EVT VT = C.getValueType();
2557   SDValue LowBit = DAG.getZExtOrTrunc(Z.getOperand(0).getOperand(0), DL, VT);
2558   SDValue C1 = IsAdd ? DAG.getConstant(CN->getAPIntValue() + 1, DL, VT)
2559                      : DAG.getConstant(CN->getAPIntValue() - 1, DL, VT);
2560   return DAG.getNode(IsAdd ? ISD::SUB : ISD::ADD, DL, VT, C1, LowBit);
2561 }
2562 
2563 // Attempt to form avgceil(A, B) from (A | B) - ((A ^ B) >> 1)
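     // This uses the identity (A | B) - ((A ^ B) >> 1) == ceil((A + B) / 2),
     // computed without a wider add; the srl form matches the unsigned average
     // and the sra form the signed one. Illustrative example (our values):
     // A = 5, B = 2 gives 7 - (7 >> 1) = 7 - 3 = 4 == ceil(7 / 2).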
2564 SDValue DAGCombiner::foldSubToAvg(SDNode *N, const SDLoc &DL) {
2565   SDValue N0 = N->getOperand(0);
2566   EVT VT = N0.getValueType();
2567   SDValue A, B;
2568 
2569   if ((!LegalOperations || hasOperation(ISD::AVGCEILU, VT)) &&
2570       sd_match(N, m_Sub(m_Or(m_Value(A), m_Value(B)),
2571                         m_Srl(m_Xor(m_Deferred(A), m_Deferred(B)),
2572                               m_SpecificInt(1))))) {
2573     return DAG.getNode(ISD::AVGCEILU, DL, VT, A, B);
2574   }
2575   if ((!LegalOperations || hasOperation(ISD::AVGCEILS, VT)) &&
2576       sd_match(N, m_Sub(m_Or(m_Value(A), m_Value(B)),
2577                         m_Sra(m_Xor(m_Deferred(A), m_Deferred(B)),
2578                               m_SpecificInt(1))))) {
2579     return DAG.getNode(ISD::AVGCEILS, DL, VT, A, B);
2580   }
2581   return SDValue();
2582 }
2583 
2584 /// Try to fold an add/sub of a constant and a shifted-down 'not' of the sign
2585 /// bit into a shift and an add with a different constant.
2586 static SDValue foldAddSubOfSignBit(SDNode *N, const SDLoc &DL,
2587                                    SelectionDAG &DAG) {
2588   assert((N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::SUB) &&
2589          "Expecting add or sub");
2590 
2591   // We need a constant operand for the add/sub, and the other operand is a
2592   // logical shift right: add (srl), C or sub C, (srl).
2593   bool IsAdd = N->getOpcode() == ISD::ADD;
2594   SDValue ConstantOp = IsAdd ? N->getOperand(1) : N->getOperand(0);
2595   SDValue ShiftOp = IsAdd ? N->getOperand(0) : N->getOperand(1);
2596   if (!DAG.isConstantIntBuildVectorOrConstantInt(ConstantOp) ||
2597       ShiftOp.getOpcode() != ISD::SRL)
2598     return SDValue();
2599 
2600   // The shift must be of a 'not' value.
2601   SDValue Not = ShiftOp.getOperand(0);
2602   if (!Not.hasOneUse() || !isBitwiseNot(Not))
2603     return SDValue();
2604 
2605   // The shift must be moving the sign bit to the least significant bit.
2606   EVT VT = ShiftOp.getValueType();
2607   SDValue ShAmt = ShiftOp.getOperand(1);
2608   ConstantSDNode *ShAmtC = isConstOrConstSplat(ShAmt);
2609   if (!ShAmtC || ShAmtC->getAPIntValue() != (VT.getScalarSizeInBits() - 1))
2610     return SDValue();
2611 
2612   // Eliminate the 'not' by adjusting the shift and add/sub constant:
2613   // add (srl (not X), 31), C --> add (sra X, 31), (C + 1)
2614   // sub C, (srl (not X), 31) --> add (srl X, 31), (C - 1)
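       // Illustrative check of the add form (our values, i32): for negative X,
       // srl (not X), 31 is 0, so the LHS is C + 0; on the right, sra X, 31 is
       // -1 and (C + 1) + (-1) is also C. Non-negative X gives C + 1 on both
       // sides.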
2615   if (SDValue NewC = DAG.FoldConstantArithmetic(
2616           IsAdd ? ISD::ADD : ISD::SUB, DL, VT,
2617           {ConstantOp, DAG.getConstant(1, DL, VT)})) {
2618     SDValue NewShift = DAG.getNode(IsAdd ? ISD::SRA : ISD::SRL, DL, VT,
2619                                    Not.getOperand(0), ShAmt);
2620     return DAG.getNode(ISD::ADD, DL, VT, NewShift, NewC);
2621   }
2622 
2623   return SDValue();
2624 }
2625 
2626 static bool
2627 areBitwiseNotOfEachother(SDValue Op0, SDValue Op1) {
2628   return (isBitwiseNot(Op0) && Op0.getOperand(0) == Op1) ||
2629          (isBitwiseNot(Op1) && Op1.getOperand(0) == Op0);
2630 }
2631 
2632 /// Try to fold a node that behaves like an ADD (note that N isn't necessarily
2633 /// an ISD::ADD here, it could for example be an ISD::OR if we know that there
2634 /// are no common bits set in the operands).
2635 SDValue DAGCombiner::visitADDLike(SDNode *N) {
2636   SDValue N0 = N->getOperand(0);
2637   SDValue N1 = N->getOperand(1);
2638   EVT VT = N0.getValueType();
2639   SDLoc DL(N);
2640 
2641   // fold (add x, undef) -> undef
2642   if (N0.isUndef())
2643     return N0;
2644   if (N1.isUndef())
2645     return N1;
2646 
2647   // fold (add c1, c2) -> c1+c2
2648   if (SDValue C = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N0, N1}))
2649     return C;
2650 
2651   // canonicalize constant to RHS
2652   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
2653       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
2654     return DAG.getNode(ISD::ADD, DL, VT, N1, N0);
2655 
2656   if (areBitwiseNotOfEachother(N0, N1))
2657     return DAG.getConstant(APInt::getAllOnes(VT.getScalarSizeInBits()), DL, VT);
2658 
2659   // fold vector ops
2660   if (VT.isVector()) {
2661     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
2662       return FoldedVOp;
2663 
2664     // fold (add x, 0) -> x, vector edition
2665     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
2666       return N0;
2667   }
2668 
2669   // fold (add x, 0) -> x
2670   if (isNullConstant(N1))
2671     return N0;
2672 
2673   if (N0.getOpcode() == ISD::SUB) {
2674     SDValue N00 = N0.getOperand(0);
2675     SDValue N01 = N0.getOperand(1);
2676 
2677     // fold ((A-c1)+c2) -> (A+(c2-c1))
2678     if (SDValue Sub = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N1, N01}))
2679       return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), Sub);
2680 
2681     // fold ((c1-A)+c2) -> (c1+c2)-A
2682     if (SDValue Add = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N1, N00}))
2683       return DAG.getNode(ISD::SUB, DL, VT, Add, N0.getOperand(1));
2684   }
2685 
2686   // add (sext i1 X), 1 -> zext (not i1 X)
2687   // We don't transform this pattern:
2688   //   add (zext i1 X), -1 -> sext (not i1 X)
2689   // because most (?) targets generate better code for the zext form.
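       // Illustrative check over both i1 values (our arithmetic): X = 1 gives
       // sext(X) + 1 = -1 + 1 = 0 and zext(not X) = zext(0) = 0; X = 0 gives
       // 0 + 1 = 1 and zext(not X) = zext(1) = 1.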
2690   if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.hasOneUse() &&
2691       isOneOrOneSplat(N1)) {
2692     SDValue X = N0.getOperand(0);
2693     if ((!LegalOperations ||
2694          (TLI.isOperationLegal(ISD::XOR, X.getValueType()) &&
2695           TLI.isOperationLegal(ISD::ZERO_EXTEND, VT))) &&
2696         X.getScalarValueSizeInBits() == 1) {
2697       SDValue Not = DAG.getNOT(DL, X, X.getValueType());
2698       return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Not);
2699     }
2700   }
2701 
2702   // Fold (add (or x, c0), c1) -> (add x, (c0 + c1))
2703   // iff (or x, c0) is equivalent to (add x, c0).
2704   // Fold (add (xor x, c0), c1) -> (add x, (c0 + c1))
2705   // iff (xor x, c0) is equivalent to (add x, c0).
2706   if (DAG.isADDLike(N0)) {
2707     SDValue N01 = N0.getOperand(1);
2708     if (SDValue Add = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N1, N01}))
2709       return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), Add);
2710   }
2711 
2712   if (SDValue NewSel = foldBinOpIntoSelect(N))
2713     return NewSel;
2714 
2715   // reassociate add
2716   if (!reassociationCanBreakAddressingModePattern(ISD::ADD, DL, N, N0, N1)) {
2717     if (SDValue RADD = reassociateOps(ISD::ADD, DL, N0, N1, N->getFlags()))
2718       return RADD;
2719 
2720     // Reassociate (add (or x, c), y) -> (add (add x, y), c) if (or x, c) is
2721     // equivalent to (add x, c).
2722     // Reassociate (add (xor x, c), y) -> (add (add x, y), c) if (xor x, c) is
2723     // equivalent to (add x, c).
2724     // Do this optimization only when adding c does not introduce instructions
2725     // for adding carries.
2726     auto ReassociateAddOr = [&](SDValue N0, SDValue N1) {
2727       if (DAG.isADDLike(N0) && N0.hasOneUse() &&
2728           isConstantOrConstantVector(N0.getOperand(1), /* NoOpaque */ true)) {
2729         // If N0's type does not need to be split, or the constant is the
2730         // sign mask, adding the constant does not introduce carry handling.
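             // (A sign-mask constant has an all-zero low half, so even when
             // the type is split no carry propagates between the halves.)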
2731         auto TyActn = TLI.getTypeAction(*DAG.getContext(), N0.getValueType());
2732         bool NoAddCarry = TyActn == TargetLoweringBase::TypeLegal ||
2733                           TyActn == TargetLoweringBase::TypePromoteInteger ||
2734                           isMinSignedConstant(N0.getOperand(1));
2735         if (NoAddCarry)
2736           return DAG.getNode(
2737               ISD::ADD, DL, VT,
2738               DAG.getNode(ISD::ADD, DL, VT, N1, N0.getOperand(0)),
2739               N0.getOperand(1));
2740       }
2741       return SDValue();
2742     };
2743     if (SDValue Add = ReassociateAddOr(N0, N1))
2744       return Add;
2745     if (SDValue Add = ReassociateAddOr(N1, N0))
2746       return Add;
2747 
2748     // Fold add(vecreduce(x), vecreduce(y)) -> vecreduce(add(x, y))
2749     if (SDValue SD =
2750             reassociateReduction(ISD::VECREDUCE_ADD, ISD::ADD, DL, VT, N0, N1))
2751       return SD;
2752   }
2753 
2754   SDValue A, B, C, D;
2755 
2756   // fold ((0-A) + B) -> B-A
2757   if (sd_match(N0, m_Neg(m_Value(A))))
2758     return DAG.getNode(ISD::SUB, DL, VT, N1, A);
2759 
2760   // fold (A + (0-B)) -> A-B
2761   if (sd_match(N1, m_Neg(m_Value(B))))
2762     return DAG.getNode(ISD::SUB, DL, VT, N0, B);
2763 
2764   // fold (A+(B-A)) -> B
2765   if (sd_match(N1, m_Sub(m_Value(B), m_Specific(N0))))
2766     return B;
2767 
2768   // fold ((B-A)+A) -> B
2769   if (sd_match(N0, m_Sub(m_Value(B), m_Specific(N1))))
2770     return B;
2771 
2772   // fold ((A-B)+(C-A)) -> (C-B)
2773   if (sd_match(N0, m_Sub(m_Value(A), m_Value(B))) &&
2774       sd_match(N1, m_Sub(m_Value(C), m_Specific(A))))
2775     return DAG.getNode(ISD::SUB, DL, VT, C, B);
2776 
2777   // fold ((A-B)+(B-C)) -> (A-C)
2778   if (sd_match(N0, m_Sub(m_Value(A), m_Value(B))) &&
2779       sd_match(N1, m_Sub(m_Specific(B), m_Value(C))))
2780     return DAG.getNode(ISD::SUB, DL, VT, A, C);
2781 
2782   // fold (A+(B-(A+C))) to (B-C)
2783   // fold (A+(B-(C+A))) to (B-C)
2784   if (sd_match(N1, m_Sub(m_Value(B), m_Add(m_Specific(N0), m_Value(C)))))
2785     return DAG.getNode(ISD::SUB, DL, VT, B, C);
2786 
2787   // fold (A+((B-A)+or-C)) to (B+or-C)
2788   if (sd_match(N1,
2789                m_AnyOf(m_Add(m_Sub(m_Value(B), m_Specific(N0)), m_Value(C)),
2790                        m_Sub(m_Sub(m_Value(B), m_Specific(N0)), m_Value(C)))))
2791     return DAG.getNode(N1.getOpcode(), DL, VT, B, C);
2792 
2793   // fold (A-B)+(C-D) to (A+C)-(B+D) when A or C is constant
2794   if (sd_match(N0, m_OneUse(m_Sub(m_Value(A), m_Value(B)))) &&
2795       sd_match(N1, m_OneUse(m_Sub(m_Value(C), m_Value(D)))) &&
2796       (isConstantOrConstantVector(A) || isConstantOrConstantVector(C)))
2797     return DAG.getNode(ISD::SUB, DL, VT,
2798                        DAG.getNode(ISD::ADD, SDLoc(N0), VT, A, C),
2799                        DAG.getNode(ISD::ADD, SDLoc(N1), VT, B, D));
2800 
2801   // fold (add (umax X, C), -C) --> (usubsat X, C)
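       // (usubsat(X, C) computes umax(X, C) - C; e.g., illustratively, X = 3,
       // C = 5 gives umax(3, 5) + (-5) = 0 == usubsat(3, 5), and X = 9 gives
       // 9 - 5 = 4.)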
2802   if (N0.getOpcode() == ISD::UMAX && hasOperation(ISD::USUBSAT, VT)) {
2803     auto MatchUSUBSAT = [](ConstantSDNode *Max, ConstantSDNode *Op) {
2804       return (!Max && !Op) ||
2805              (Max && Op && Max->getAPIntValue() == (-Op->getAPIntValue()));
2806     };
2807     if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchUSUBSAT,
2808                                   /*AllowUndefs*/ true))
2809       return DAG.getNode(ISD::USUBSAT, DL, VT, N0.getOperand(0),
2810                          N0.getOperand(1));
2811   }
2812 
2813   if (SimplifyDemandedBits(SDValue(N, 0)))
2814     return SDValue(N, 0);
2815 
2816   if (isOneOrOneSplat(N1)) {
2817     // fold (add (xor a, -1), 1) -> (sub 0, a)
2818     if (isBitwiseNot(N0))
2819       return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
2820                          N0.getOperand(0));
2821 
2822     // fold (add (add (xor a, -1), b), 1) -> (sub b, a)
2823     if (N0.getOpcode() == ISD::ADD) {
2824       SDValue A, Xor;
2825 
2826       if (isBitwiseNot(N0.getOperand(0))) {
2827         A = N0.getOperand(1);
2828         Xor = N0.getOperand(0);
2829       } else if (isBitwiseNot(N0.getOperand(1))) {
2830         A = N0.getOperand(0);
2831         Xor = N0.getOperand(1);
2832       }
2833 
2834       if (Xor)
2835         return DAG.getNode(ISD::SUB, DL, VT, A, Xor.getOperand(0));
2836     }
2837 
2838     // Look for:
2839     //   add (add x, y), 1
2840     // And if the target does not like this form then turn into:
2841     //   sub y, (xor x, -1)
2842     if (!TLI.preferIncOfAddToSubOfNot(VT) && N0.getOpcode() == ISD::ADD &&
2843         N0.hasOneUse() &&
2844         // Limit this to after legalization if the add has wrap flags
2845         (Level >= AfterLegalizeDAG || (!N->getFlags().hasNoUnsignedWrap() &&
2846                                        !N->getFlags().hasNoSignedWrap()))) {
2847       SDValue Not = DAG.getNOT(DL, N0.getOperand(0), VT);
2848       return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(1), Not);
2849     }
2850   }
2851 
2852   // (x - y) + -1  ->  add (xor y, -1), x
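       // (This holds because xor y, -1 == ~y == -y - 1 in two's complement,
       // so ~y + x == x - y - 1.)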
2853   if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
2854       isAllOnesOrAllOnesSplat(N1, /*AllowUndefs=*/true)) {
2855     SDValue Not = DAG.getNOT(DL, N0.getOperand(1), VT);
2856     return DAG.getNode(ISD::ADD, DL, VT, Not, N0.getOperand(0));
2857   }
2858 
2859   // Fold add(mul(add(A, CA), CM), CB) -> add(mul(A, CM), CM*CA+CB).
2860   // This can help if the inner add has multiple uses.
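       // (Algebraically this is just distribution:
       // (A + CA) * CM + CB == A * CM + (CM * CA + CB).)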
2861   APInt CM, CA;
2862   if (ConstantSDNode *CB = dyn_cast<ConstantSDNode>(N1)) {
2863     if (VT.getScalarSizeInBits() <= 64) {
2864       if (sd_match(N0, m_OneUse(m_Mul(m_Add(m_Value(A), m_ConstInt(CA)),
2865                                       m_ConstInt(CM)))) &&
2866           TLI.isLegalAddImmediate(
2867               (CA * CM + CB->getAPIntValue()).getSExtValue())) {
2868         SDNodeFlags Flags;
2869         // If all the inputs are nuw, the outputs can be nuw. If all the inputs
2870         // are _also_ nsw, the outputs can be too.
2871         if (N->getFlags().hasNoUnsignedWrap() &&
2872             N0->getFlags().hasNoUnsignedWrap() &&
2873             N0.getOperand(0)->getFlags().hasNoUnsignedWrap()) {
2874           Flags.setNoUnsignedWrap(true);
2875           if (N->getFlags().hasNoSignedWrap() &&
2876               N0->getFlags().hasNoSignedWrap() &&
2877               N0.getOperand(0)->getFlags().hasNoSignedWrap())
2878             Flags.setNoSignedWrap(true);
2879         }
2880         SDValue Mul = DAG.getNode(ISD::MUL, SDLoc(N1), VT, A,
2881                                   DAG.getConstant(CM, DL, VT), Flags);
2882         return DAG.getNode(
2883             ISD::ADD, DL, VT, Mul,
2884             DAG.getConstant(CA * CM + CB->getAPIntValue(), DL, VT), Flags);
2885       }
2886       // Also handle the case where there is an intermediate add.
2887       if (sd_match(N0, m_OneUse(m_Add(
2888                            m_OneUse(m_Mul(m_Add(m_Value(A), m_ConstInt(CA)),
2889                                           m_ConstInt(CM))),
2890                            m_Value(B)))) &&
2891           TLI.isLegalAddImmediate(
2892               (CA * CM + CB->getAPIntValue()).getSExtValue())) {
2893         SDNodeFlags Flags;
2894         // If all the inputs are nuw, the outputs can be nuw. If all the inputs
2895         // are _also_ nsw, the outputs can be too.
2896         SDValue OMul =
2897             N0.getOperand(0) == B ? N0.getOperand(1) : N0.getOperand(0);
2898         if (N->getFlags().hasNoUnsignedWrap() &&
2899             N0->getFlags().hasNoUnsignedWrap() &&
2900             OMul->getFlags().hasNoUnsignedWrap() &&
2901             OMul.getOperand(0)->getFlags().hasNoUnsignedWrap()) {
2902           Flags.setNoUnsignedWrap(true);
2903           if (N->getFlags().hasNoSignedWrap() &&
2904               N0->getFlags().hasNoSignedWrap() &&
2905               OMul->getFlags().hasNoSignedWrap() &&
2906               OMul.getOperand(0)->getFlags().hasNoSignedWrap())
2907             Flags.setNoSignedWrap(true);
2908         }
2909         SDValue Mul = DAG.getNode(ISD::MUL, SDLoc(N1), VT, A,
2910                                   DAG.getConstant(CM, DL, VT), Flags);
2911         SDValue Add = DAG.getNode(ISD::ADD, SDLoc(N1), VT, Mul, B, Flags);
2912         return DAG.getNode(
2913             ISD::ADD, DL, VT, Add,
2914             DAG.getConstant(CA * CM + CB->getAPIntValue(), DL, VT), Flags);
2915       }
2916     }
2917   }
2918 
2919   if (SDValue Combined = visitADDLikeCommutative(N0, N1, N))
2920     return Combined;
2921 
2922   if (SDValue Combined = visitADDLikeCommutative(N1, N0, N))
2923     return Combined;
2924 
2925   return SDValue();
2926 }
2927 
2928 // Attempt to form avgfloor(A, B) from (A & B) + ((A ^ B) >> 1)
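     // This uses the identity (A & B) + ((A ^ B) >> 1) == floor((A + B) / 2),
     // computed without a wider add; the srl form matches the unsigned average
     // and the sra form the signed one. Illustrative example (our values):
     // A = 5, B = 2 gives 0 + (7 >> 1) = 0 + 3 = 3 == floor(7 / 2).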
2929 SDValue DAGCombiner::foldAddToAvg(SDNode *N, const SDLoc &DL) {
2930   SDValue N0 = N->getOperand(0);
2931   EVT VT = N0.getValueType();
2932   SDValue A, B;
2933 
2934   if ((!LegalOperations || hasOperation(ISD::AVGFLOORU, VT)) &&
2935       sd_match(N, m_Add(m_And(m_Value(A), m_Value(B)),
2936                         m_Srl(m_Xor(m_Deferred(A), m_Deferred(B)),
2937                               m_SpecificInt(1))))) {
2938     return DAG.getNode(ISD::AVGFLOORU, DL, VT, A, B);
2939   }
2940   if ((!LegalOperations || hasOperation(ISD::AVGFLOORS, VT)) &&
2941       sd_match(N, m_Add(m_And(m_Value(A), m_Value(B)),
2942                         m_Sra(m_Xor(m_Deferred(A), m_Deferred(B)),
2943                               m_SpecificInt(1))))) {
2944     return DAG.getNode(ISD::AVGFLOORS, DL, VT, A, B);
2945   }
2946 
2947   return SDValue();
2948 }
2949 
2950 SDValue DAGCombiner::visitADD(SDNode *N) {
2951   SDValue N0 = N->getOperand(0);
2952   SDValue N1 = N->getOperand(1);
2953   EVT VT = N0.getValueType();
2954   SDLoc DL(N);
2955 
2956   if (SDValue Combined = visitADDLike(N))
2957     return Combined;
2958 
2959   if (SDValue V = foldAddSubBoolOfMaskedVal(N, DL, DAG))
2960     return V;
2961 
2962   if (SDValue V = foldAddSubOfSignBit(N, DL, DAG))
2963     return V;
2964 
2965   // Try to match the fixed-width AVGFLOOR pattern.
2966   if (SDValue V = foldAddToAvg(N, DL))
2967     return V;
2968 
2969   // fold (a+b) -> (a|b) iff a and b share no bits.
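       // (With no common bits set there are no carries, so addition and OR
       // agree bit-for-bit; e.g., illustratively, 0b0101 + 0b1010 == 0b1111 ==
       // 0b0101 | 0b1010.)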
2970   if ((!LegalOperations || TLI.isOperationLegal(ISD::OR, VT)) &&
2971       DAG.haveNoCommonBitsSet(N0, N1)) {
2972     SDNodeFlags Flags;
2973     Flags.setDisjoint(true);
2974     return DAG.getNode(ISD::OR, DL, VT, N0, N1, Flags);
2975   }
2976 
2977   // Fold (add (vscale * C0), (vscale * C1)) to (vscale * (C0 + C1)).
2978   if (N0.getOpcode() == ISD::VSCALE && N1.getOpcode() == ISD::VSCALE) {
2979     const APInt &C0 = N0->getConstantOperandAPInt(0);
2980     const APInt &C1 = N1->getConstantOperandAPInt(0);
2981     return DAG.getVScale(DL, VT, C0 + C1);
2982   }
2983 
2984   // fold a+vscale(c1)+vscale(c2) -> a+vscale(c1+c2)
2985   if (N0.getOpcode() == ISD::ADD &&
2986       N0.getOperand(1).getOpcode() == ISD::VSCALE &&
2987       N1.getOpcode() == ISD::VSCALE) {
2988     const APInt &VS0 = N0.getOperand(1)->getConstantOperandAPInt(0);
2989     const APInt &VS1 = N1->getConstantOperandAPInt(0);
2990     SDValue VS = DAG.getVScale(DL, VT, VS0 + VS1);
2991     return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), VS);
2992   }
2993 
2994   // Fold (add step_vector(c1), step_vector(c2)) to step_vector(c1+c2)
2995   if (N0.getOpcode() == ISD::STEP_VECTOR &&
2996       N1.getOpcode() == ISD::STEP_VECTOR) {
2997     const APInt &C0 = N0->getConstantOperandAPInt(0);
2998     const APInt &C1 = N1->getConstantOperandAPInt(0);
2999     APInt NewStep = C0 + C1;
3000     return DAG.getStepVector(DL, VT, NewStep);
3001   }
3002 
3003   // Fold a + step_vector(c1) + step_vector(c2) to a + step_vector(c1+c2)
3004   if (N0.getOpcode() == ISD::ADD &&
3005       N0.getOperand(1).getOpcode() == ISD::STEP_VECTOR &&
3006       N1.getOpcode() == ISD::STEP_VECTOR) {
3007     const APInt &SV0 = N0.getOperand(1)->getConstantOperandAPInt(0);
3008     const APInt &SV1 = N1->getConstantOperandAPInt(0);
3009     APInt NewStep = SV0 + SV1;
3010     SDValue SV = DAG.getStepVector(DL, VT, NewStep);
3011     return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), SV);
3012   }
3013 
3014   return SDValue();
3015 }
3016 
3017 SDValue DAGCombiner::visitADDSAT(SDNode *N) {
3018   unsigned Opcode = N->getOpcode();
3019   SDValue N0 = N->getOperand(0);
3020   SDValue N1 = N->getOperand(1);
3021   EVT VT = N0.getValueType();
3022   bool IsSigned = Opcode == ISD::SADDSAT;
3023   SDLoc DL(N);
3024 
3025   // fold (add_sat x, undef) -> -1
3026   if (N0.isUndef() || N1.isUndef())
3027     return DAG.getAllOnesConstant(DL, VT);
3028 
3029   // fold (add_sat c1, c2) -> c3
3030   if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
3031     return C;
3032 
3033   // canonicalize constant to RHS
3034   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
3035       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
3036     return DAG.getNode(Opcode, DL, VT, N1, N0);
3037 
3038   // fold vector ops
3039   if (VT.isVector()) {
3040     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
3041       return FoldedVOp;
3042 
3043     // fold (add_sat x, 0) -> x, vector edition
3044     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
3045       return N0;
3046   }
3047 
3048   // fold (add_sat x, 0) -> x
3049   if (isNullConstant(N1))
3050     return N0;
3051 
3052   // If it cannot overflow, transform into an add.
3053   if (DAG.willNotOverflowAdd(IsSigned, N0, N1))
3054     return DAG.getNode(ISD::ADD, DL, VT, N0, N1);
3055 
3056   return SDValue();
3057 }
3058 
3059 static SDValue getAsCarry(const TargetLowering &TLI, SDValue V,
3060                           bool ForceCarryReconstruction = false) {
3061   bool Masked = false;
3062 
3063   // First, peel away TRUNCATE/ZERO_EXTEND/AND nodes due to legalization.
3064   while (true) {
3065     if (V.getOpcode() == ISD::TRUNCATE || V.getOpcode() == ISD::ZERO_EXTEND) {
3066       V = V.getOperand(0);
3067       continue;
3068     }
3069 
3070     if (V.getOpcode() == ISD::AND && isOneConstant(V.getOperand(1))) {
3071       if (ForceCarryReconstruction)
3072         return V;
3073 
3074       Masked = true;
3075       V = V.getOperand(0);
3076       continue;
3077     }
3078 
3079     if (ForceCarryReconstruction && V.getValueType() == MVT::i1)
3080       return V;
3081 
3082     break;
3083   }
3084 
3085   // If this is not a carry, return.
3086   if (V.getResNo() != 1)
3087     return SDValue();
3088 
3089   if (V.getOpcode() != ISD::UADDO_CARRY && V.getOpcode() != ISD::USUBO_CARRY &&
3090       V.getOpcode() != ISD::UADDO && V.getOpcode() != ISD::USUBO)
3091     return SDValue();
3092 
3093   EVT VT = V->getValueType(0);
3094   if (!TLI.isOperationLegalOrCustom(V.getOpcode(), VT))
3095     return SDValue();
3096 
3097   // If the result is masked, then no matter what kind of bool it is we can
3098   // return it. If it isn't, then we need to make sure the boolean is
3099   // represented as 0 or 1 and not some other value.
3100   if (Masked ||
3101       TLI.getBooleanContents(V.getValueType()) ==
3102           TargetLoweringBase::ZeroOrOneBooleanContent)
3103     return V;
3104 
3105   return SDValue();
3106 }
3107 
3108 /// Given the operands of an add/sub operation, see if the 2nd operand is a
3109 /// masked 0/1 whose source operand is actually known to be 0/-1. If so, invert
3110 /// the opcode and bypass the mask operation.
3111 static SDValue foldAddSubMasked1(bool IsAdd, SDValue N0, SDValue N1,
3112                                  SelectionDAG &DAG, const SDLoc &DL) {
3113   if (N1.getOpcode() == ISD::ZERO_EXTEND)
3114     N1 = N1.getOperand(0);
3115 
3116   if (N1.getOpcode() != ISD::AND || !isOneOrOneSplat(N1->getOperand(1)))
3117     return SDValue();
3118 
3119   EVT VT = N0.getValueType();
3120   SDValue N10 = N1.getOperand(0);
3121   if (N10.getValueType() != VT && N10.getOpcode() == ISD::TRUNCATE)
3122     N10 = N10.getOperand(0);
3123 
3124   if (N10.getValueType() != VT)
3125     return SDValue();
3126 
3127   if (DAG.ComputeNumSignBits(N10) != VT.getScalarSizeInBits())
3128     return SDValue();
3129 
3130   // add N0, (and (AssertSext X, i1), 1) --> sub N0, X
3131   // sub N0, (and (AssertSext X, i1), 1) --> add N0, X
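       // Illustrative check of the add form (X is known to be 0 or -1 here):
       // for X = -1, (X & 1) is 1 and N0 + 1 == N0 - (-1); for X = 0, both
       // sides are N0.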
3132   return DAG.getNode(IsAdd ? ISD::SUB : ISD::ADD, DL, VT, N0, N10);
3133 }
3134 
3135 /// Helper for doing combines based on N0 and N1 being added to each other.
3136 SDValue DAGCombiner::visitADDLikeCommutative(SDValue N0, SDValue N1,
3137                                              SDNode *LocReference) {
3138   EVT VT = N0.getValueType();
3139   SDLoc DL(LocReference);
3140 
3141   // fold (add x, shl(0 - y, n)) -> sub(x, shl(y, n))
3142   SDValue Y, N;
3143   if (sd_match(N1, m_Shl(m_Neg(m_Value(Y)), m_Value(N))))
3144     return DAG.getNode(ISD::SUB, DL, VT, N0,
3145                        DAG.getNode(ISD::SHL, DL, VT, Y, N));
3146 
3147   if (SDValue V = foldAddSubMasked1(true, N0, N1, DAG, DL))
3148     return V;
3149 
3150   // Look for:
3151   //   add (add x, 1), y
3152   // And if the target does not like this form then turn into:
3153   //   sub y, (xor x, -1)
3154   if (!TLI.preferIncOfAddToSubOfNot(VT) && N0.getOpcode() == ISD::ADD &&
3155       N0.hasOneUse() && isOneOrOneSplat(N0.getOperand(1)) &&
3156       // Limit this to after legalization if the add has wrap flags
3157       (Level >= AfterLegalizeDAG || (!N0->getFlags().hasNoUnsignedWrap() &&
3158                                      !N0->getFlags().hasNoSignedWrap()))) {
3159     SDValue Not = DAG.getNOT(DL, N0.getOperand(0), VT);
3160     return DAG.getNode(ISD::SUB, DL, VT, N1, Not);
3161   }
3162 
3163   if (N0.getOpcode() == ISD::SUB && N0.hasOneUse()) {
3164     // Hoist one-use subtraction by non-opaque constant:
3165     //   (x - C) + y  ->  (x + y) - C
3166     // This is necessary because SUB(X,C) -> ADD(X,-C) doesn't work for vectors.
3167     if (isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true)) {
3168       SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), N1);
3169       return DAG.getNode(ISD::SUB, DL, VT, Add, N0.getOperand(1));
3170     }
3171     // Hoist one-use subtraction from non-opaque constant:
3172     //   (C - x) + y  ->  (y - x) + C
3173     if (isConstantOrConstantVector(N0.getOperand(0), /*NoOpaques=*/true)) {
3174       SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N1, N0.getOperand(1));
3175       return DAG.getNode(ISD::ADD, DL, VT, Sub, N0.getOperand(0));
3176     }
3177   }
3178 
3179   // add (mul x, C), x -> mul x, C+1
3180   if (N0.getOpcode() == ISD::MUL && N0.getOperand(0) == N1 &&
3181       isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true) &&
3182       N0.hasOneUse()) {
3183     SDValue NewC = DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(1),
3184                                DAG.getConstant(1, DL, VT));
3185     return DAG.getNode(ISD::MUL, DL, VT, N0.getOperand(0), NewC);
3186   }
3187 
3188   // If the target's bool is represented as 0/1, prefer to make this 'sub 0/1'
3189   // rather than 'add 0/-1' (the zext should get folded).
3190   // add (sext i1 Y), X --> sub X, (zext i1 Y)
3191   if (N0.getOpcode() == ISD::SIGN_EXTEND &&
3192       N0.getOperand(0).getScalarValueSizeInBits() == 1 &&
3193       TLI.getBooleanContents(VT) == TargetLowering::ZeroOrOneBooleanContent) {
3194     SDValue ZExt = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0));
3195     return DAG.getNode(ISD::SUB, DL, VT, N1, ZExt);
3196   }
3197 
3198   // add X, (sextinreg Y i1) -> sub X, (and Y 1)
3199   if (N1.getOpcode() == ISD::SIGN_EXTEND_INREG) {
3200     VTSDNode *TN = cast<VTSDNode>(N1.getOperand(1));
3201     if (TN->getVT() == MVT::i1) {
3202       SDValue ZExt = DAG.getNode(ISD::AND, DL, VT, N1.getOperand(0),
3203                                  DAG.getConstant(1, DL, VT));
3204       return DAG.getNode(ISD::SUB, DL, VT, N0, ZExt);
3205     }
3206   }
3207 
3208   // (add X, (uaddo_carry Y, 0, Carry)) -> (uaddo_carry X, Y, Carry)
3209   if (N1.getOpcode() == ISD::UADDO_CARRY && isNullConstant(N1.getOperand(1)) &&
3210       N1.getResNo() == 0)
3211     return DAG.getNode(ISD::UADDO_CARRY, DL, N1->getVTList(),
3212                        N0, N1.getOperand(0), N1.getOperand(2));
3213 
3214   // (add X, Carry) -> (uaddo_carry X, 0, Carry)
3215   if (TLI.isOperationLegalOrCustom(ISD::UADDO_CARRY, VT))
3216     if (SDValue Carry = getAsCarry(TLI, N1))
3217       return DAG.getNode(ISD::UADDO_CARRY, DL,
3218                          DAG.getVTList(VT, Carry.getValueType()), N0,
3219                          DAG.getConstant(0, DL, VT), Carry);
3220 
3221   return SDValue();
3222 }
3223 
3224 SDValue DAGCombiner::visitADDC(SDNode *N) {
3225   SDValue N0 = N->getOperand(0);
3226   SDValue N1 = N->getOperand(1);
3227   EVT VT = N0.getValueType();
3228   SDLoc DL(N);
3229 
3230   // If the flag result is dead, turn this into an ADD.
3231   if (!N->hasAnyUseOfValue(1))
3232     return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
3233                      DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
3234 
3235   // canonicalize constant to RHS.
3236   ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
3237   ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
3238   if (N0C && !N1C)
3239     return DAG.getNode(ISD::ADDC, DL, N->getVTList(), N1, N0);
3240 
3241   // fold (addc x, 0) -> x + no carry out
3242   if (isNullConstant(N1))
3243     return CombineTo(N, N0, DAG.getNode(ISD::CARRY_FALSE,
3244                                         DL, MVT::Glue));
3245 
3246   // If it cannot overflow, transform into an add.
3247   if (DAG.computeOverflowForUnsignedAdd(N0, N1) == SelectionDAG::OFK_Never)
3248     return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
3249                      DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
3250 
3251   return SDValue();
3252 }
3253 
3254 /**
3255  * Flips a boolean if it is cheaper to compute. If the Force parameter is set,
3256  * then the flip also occurs if computing the inverse is the same cost.
3257  * This function returns an empty SDValue in case it cannot flip the boolean
3258  * without increasing the cost of the computation. If you want to flip a boolean
3259  * no matter what, use DAG.getLogicalNOT.
3260  */
3261 static SDValue extractBooleanFlip(SDValue V, SelectionDAG &DAG,
3262                                   const TargetLowering &TLI,
3263                                   bool Force) {
3264   if (Force && isa<ConstantSDNode>(V))
3265     return DAG.getLogicalNOT(SDLoc(V), V, V.getValueType());
3266 
3267   if (V.getOpcode() != ISD::XOR)
3268     return SDValue();
3269 
3270   ConstantSDNode *Const = isConstOrConstSplat(V.getOperand(1), false);
3271   if (!Const)
3272     return SDValue();
3273 
3274   EVT VT = V.getValueType();
3275 
3276   bool IsFlip = false;
3277   switch(TLI.getBooleanContents(VT)) {
3278     case TargetLowering::ZeroOrOneBooleanContent:
3279       IsFlip = Const->isOne();
3280       break;
3281     case TargetLowering::ZeroOrNegativeOneBooleanContent:
3282       IsFlip = Const->isAllOnes();
3283       break;
3284     case TargetLowering::UndefinedBooleanContent:
3285       IsFlip = (Const->getAPIntValue() & 0x01) == 1;
3286       break;
3287   }
3288 
3289   if (IsFlip)
3290     return V.getOperand(0);
3291   if (Force)
3292     return DAG.getLogicalNOT(SDLoc(V), V, V.getValueType());
3293   return SDValue();
3294 }
3295 
3296 SDValue DAGCombiner::visitADDO(SDNode *N) {
3297   SDValue N0 = N->getOperand(0);
3298   SDValue N1 = N->getOperand(1);
3299   EVT VT = N0.getValueType();
3300   bool IsSigned = (ISD::SADDO == N->getOpcode());
3301 
3302   EVT CarryVT = N->getValueType(1);
3303   SDLoc DL(N);
3304 
3305   // If the flag result is dead, turn this into an ADD.
3306   if (!N->hasAnyUseOfValue(1))
3307     return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
3308                      DAG.getUNDEF(CarryVT));
3309 
3310   // canonicalize constant to RHS.
3311   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
3312       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
3313     return DAG.getNode(N->getOpcode(), DL, N->getVTList(), N1, N0);
3314 
3315   // fold (addo x, 0) -> x + no carry out
3316   if (isNullOrNullSplat(N1))
3317     return CombineTo(N, N0, DAG.getConstant(0, DL, CarryVT));
3318 
3319   // If it cannot overflow, transform into an add.
3320   if (DAG.willNotOverflowAdd(IsSigned, N0, N1))
3321     return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
3322                      DAG.getConstant(0, DL, CarryVT));
3323 
3324   if (IsSigned) {
3325     // fold (saddo (xor a, -1), 1) -> (ssub 0, a).
3326     if (isBitwiseNot(N0) && isOneOrOneSplat(N1))
3327       return DAG.getNode(ISD::SSUBO, DL, N->getVTList(),
3328                          DAG.getConstant(0, DL, VT), N0.getOperand(0));
3329   } else {
3330     // fold (uaddo (xor a, -1), 1) -> (usub 0, a) and flip carry.
3331     if (isBitwiseNot(N0) && isOneOrOneSplat(N1)) {
3332       SDValue Sub = DAG.getNode(ISD::USUBO, DL, N->getVTList(),
3333                                 DAG.getConstant(0, DL, VT), N0.getOperand(0));
3334       return CombineTo(
3335           N, Sub, DAG.getLogicalNOT(DL, Sub.getValue(1), Sub->getValueType(1)));
3336     }
3337 
3338     if (SDValue Combined = visitUADDOLike(N0, N1, N))
3339       return Combined;
3340 
3341     if (SDValue Combined = visitUADDOLike(N1, N0, N))
3342       return Combined;
3343   }
3344 
3345   return SDValue();
3346 }
3347 
3348 SDValue DAGCombiner::visitUADDOLike(SDValue N0, SDValue N1, SDNode *N) {
3349   EVT VT = N0.getValueType();
3350   if (VT.isVector())
3351     return SDValue();
3352 
3353   // (uaddo X, (uaddo_carry Y, 0, Carry)) -> (uaddo_carry X, Y, Carry)
3354   // If Y + 1 cannot overflow.
3355   if (N1.getOpcode() == ISD::UADDO_CARRY && isNullConstant(N1.getOperand(1))) {
3356     SDValue Y = N1.getOperand(0);
3357     SDValue One = DAG.getConstant(1, SDLoc(N), Y.getValueType());
3358     if (DAG.computeOverflowForUnsignedAdd(Y, One) == SelectionDAG::OFK_Never)
3359       return DAG.getNode(ISD::UADDO_CARRY, SDLoc(N), N->getVTList(), N0, Y,
3360                          N1.getOperand(2));
3361   }
3362 
3363   // (uaddo X, Carry) -> (uaddo_carry X, 0, Carry)
3364   if (TLI.isOperationLegalOrCustom(ISD::UADDO_CARRY, VT))
3365     if (SDValue Carry = getAsCarry(TLI, N1))
3366       return DAG.getNode(ISD::UADDO_CARRY, SDLoc(N), N->getVTList(), N0,
3367                          DAG.getConstant(0, SDLoc(N), VT), Carry);
3368 
3369   return SDValue();
3370 }
3371 
3372 SDValue DAGCombiner::visitADDE(SDNode *N) {
3373   SDValue N0 = N->getOperand(0);
3374   SDValue N1 = N->getOperand(1);
3375   SDValue CarryIn = N->getOperand(2);
3376 
3377   // canonicalize constant to RHS
3378   ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
3379   ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
3380   if (N0C && !N1C)
3381     return DAG.getNode(ISD::ADDE, SDLoc(N), N->getVTList(),
3382                        N1, N0, CarryIn);
3383 
3384   // fold (adde x, y, false) -> (addc x, y)
3385   if (CarryIn.getOpcode() == ISD::CARRY_FALSE)
3386     return DAG.getNode(ISD::ADDC, SDLoc(N), N->getVTList(), N0, N1);
3387 
3388   return SDValue();
3389 }
3390 
3391 SDValue DAGCombiner::visitUADDO_CARRY(SDNode *N) {
3392   SDValue N0 = N->getOperand(0);
3393   SDValue N1 = N->getOperand(1);
3394   SDValue CarryIn = N->getOperand(2);
3395   SDLoc DL(N);
3396 
3397   // canonicalize constant to RHS
3398   ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
3399   ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
3400   if (N0C && !N1C)
3401     return DAG.getNode(ISD::UADDO_CARRY, DL, N->getVTList(), N1, N0, CarryIn);
3402 
3403   // fold (uaddo_carry x, y, false) -> (uaddo x, y)
3404   if (isNullConstant(CarryIn)) {
3405     if (!LegalOperations ||
3406         TLI.isOperationLegalOrCustom(ISD::UADDO, N->getValueType(0)))
3407       return DAG.getNode(ISD::UADDO, DL, N->getVTList(), N0, N1);
3408   }
3409 
3410   // fold (uaddo_carry 0, 0, X) -> (and (ext/trunc X), 1) and no carry.
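       // (0 + 0 + carry can never overflow, and the sum is just the carry-in
       // bit itself, hence the mask with 1.)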
3411   if (isNullConstant(N0) && isNullConstant(N1)) {
3412     EVT VT = N0.getValueType();
3413     EVT CarryVT = CarryIn.getValueType();
3414     SDValue CarryExt = DAG.getBoolExtOrTrunc(CarryIn, DL, VT, CarryVT);
3415     AddToWorklist(CarryExt.getNode());
3416     return CombineTo(N, DAG.getNode(ISD::AND, DL, VT, CarryExt,
3417                                     DAG.getConstant(1, DL, VT)),
3418                      DAG.getConstant(0, DL, CarryVT));
3419   }
3420 
3421   if (SDValue Combined = visitUADDO_CARRYLike(N0, N1, CarryIn, N))
3422     return Combined;
3423 
3424   if (SDValue Combined = visitUADDO_CARRYLike(N1, N0, CarryIn, N))
3425     return Combined;
3426 
3427   // We want to avoid useless duplication.
3428   // TODO: This is done automatically for binary operations. As UADDO_CARRY is
3429   // not a binary operation, it is not possible to leverage this existing
3430   // mechanism for it. However, if more operations require the same
3431   // deduplication logic, it may be worth generalizing.
3432   SDValue Ops[] = {N1, N0, CarryIn};
3433   SDNode *CSENode =
3434       DAG.getNodeIfExists(ISD::UADDO_CARRY, N->getVTList(), Ops, N->getFlags());
3435   if (CSENode)
3436     return SDValue(CSENode, 0);
3437 
3438   return SDValue();
3439 }
3440 
3441 /**
3442  * If we are facing some sort of diamond carry propagation pattern try to
3443  * break it up to generate something like:
3444  *   (uaddo_carry X, 0, (uaddo_carry A, B, Z):Carry)
3445  *
3446  * The end result is usually an increase in the operations required, but because the
3447  * carry is now linearized, other transforms can kick in and optimize the DAG.
3448  *
3449  * Patterns typically look something like
3450  *                (uaddo A, B)
3451  *                /          \
3452  *             Carry         Sum
3453  *               |             \
3454  *               | (uaddo_carry *, 0, Z)
3455  *               |       /
3456  *                \   Carry
3457  *                 |   /
3458  * (uaddo_carry X, *, *)
3459  *
3460  * But numerous variations exist. Our goal is to identify A, B, X and Z and
3461  * produce a combine with a single path for carry propagation.
3462  */
3463 static SDValue combineUADDO_CARRYDiamond(DAGCombiner &Combiner,
3464                                          SelectionDAG &DAG, SDValue X,
3465                                          SDValue Carry0, SDValue Carry1,
3466                                          SDNode *N) {
3467   if (Carry1.getResNo() != 1 || Carry0.getResNo() != 1)
3468     return SDValue();
3469   if (Carry1.getOpcode() != ISD::UADDO)
3470     return SDValue();
3471 
3472   SDValue Z;
3473 
3474   /**
3475    * First look for a suitable Z. It will present itself in the form of
3476    * (uaddo_carry Y, 0, Z) or its equivalent (uaddo Y, 1) for Z=true
3477    */
3478   if (Carry0.getOpcode() == ISD::UADDO_CARRY &&
3479       isNullConstant(Carry0.getOperand(1))) {
3480     Z = Carry0.getOperand(2);
3481   } else if (Carry0.getOpcode() == ISD::UADDO &&
3482              isOneConstant(Carry0.getOperand(1))) {
3483     EVT VT = Carry0->getValueType(1);
3484     Z = DAG.getConstant(1, SDLoc(Carry0.getOperand(1)), VT);
3485   } else {
3486     // We couldn't find a suitable Z.
3487     return SDValue();
3488   }
3489 
3490 
3491   auto cancelDiamond = [&](SDValue A,SDValue B) {
3492     SDLoc DL(N);
3493     SDValue NewY =
3494         DAG.getNode(ISD::UADDO_CARRY, DL, Carry0->getVTList(), A, B, Z);
3495     Combiner.AddToWorklist(NewY.getNode());
3496     return DAG.getNode(ISD::UADDO_CARRY, DL, N->getVTList(), X,
3497                        DAG.getConstant(0, DL, X.getValueType()),
3498                        NewY.getValue(1));
3499   };
3500 
3501   /**
3502    *         (uaddo A, B)
3503    *              |
3504    *             Sum
3505    *              |
3506    * (uaddo_carry *, 0, Z)
3507    */
3508   if (Carry0.getOperand(0) == Carry1.getValue(0)) {
3509     return cancelDiamond(Carry1.getOperand(0), Carry1.getOperand(1));
3510   }
3511 
3512   /**
3513    * (uaddo_carry A, 0, Z)
3514    *         |
3515    *        Sum
3516    *         |
3517    *  (uaddo *, B)
3518    */
3519   if (Carry1.getOperand(0) == Carry0.getValue(0)) {
3520     return cancelDiamond(Carry0.getOperand(0), Carry1.getOperand(1));
3521   }
3522 
3523   if (Carry1.getOperand(1) == Carry0.getValue(0)) {
3524     return cancelDiamond(Carry1.getOperand(0), Carry0.getOperand(0));
3525   }
3526 
3527   return SDValue();
3528 }
3529 
3530 // If we are facing some sort of diamond carry/borrow in/out pattern try to
3531 // match patterns like:
3532 //
3533 //          (uaddo A, B)            CarryIn
3534 //            |  \                     |
3535 //            |   \                    |
3536 //    PartialSum   PartialCarryOutX   /
3537 //            |        |             /
3538 //            |    ____|____________/
3539 //            |   /    |
3540 //     (uaddo *, *)    \________
3541 //       |  \                   \
3542 //       |   \                   |
3543 //       |    PartialCarryOutY   |
3544 //       |        \              |
3545 //       |         \            /
3546 //   AddCarrySum    |    ______/
3547 //                  |   /
3548 //   CarryOut = (or *, *)
3549 //
3550 // And generate UADDO_CARRY (or USUBO_CARRY) with two result values:
3551 //
3552 //    {AddCarrySum, CarryOut} = (uaddo_carry A, B, CarryIn)
3553 //
3554 // Our goal is to identify A, B, and CarryIn and produce UADDO_CARRY/USUBO_CARRY
3555 // with a single path for carry/borrow out propagation.
3556 static SDValue combineCarryDiamond(SelectionDAG &DAG, const TargetLowering &TLI,
3557                                    SDValue N0, SDValue N1, SDNode *N) {
3558   SDValue Carry0 = getAsCarry(TLI, N0);
3559   if (!Carry0)
3560     return SDValue();
3561   SDValue Carry1 = getAsCarry(TLI, N1);
3562   if (!Carry1)
3563     return SDValue();
3564 
3565   unsigned Opcode = Carry0.getOpcode();
3566   if (Opcode != Carry1.getOpcode())
3567     return SDValue();
3568   if (Opcode != ISD::UADDO && Opcode != ISD::USUBO)
3569     return SDValue();
3570   // Guarantee identical type of CarryOut
3571   EVT CarryOutType = N->getValueType(0);
3572   if (CarryOutType != Carry0.getValue(1).getValueType() ||
3573       CarryOutType != Carry1.getValue(1).getValueType())
3574     return SDValue();
3575 
3576   // Canonicalize the add/sub of A and B (the top node in the above ASCII art)
3577   // as Carry0 and the add/sub of the carry in as Carry1 (the middle node).
3578   if (Carry1.getNode()->isOperandOf(Carry0.getNode()))
3579     std::swap(Carry0, Carry1);
3580 
3581   // Check if nodes are connected in expected way.
3582   if (Carry1.getOperand(0) != Carry0.getValue(0) &&
3583       Carry1.getOperand(1) != Carry0.getValue(0))
3584     return SDValue();
3585 
3586   // The carry-in value must be on the right-hand side for subtraction.
3587   unsigned CarryInOperandNum =
3588       Carry1.getOperand(0) == Carry0.getValue(0) ? 1 : 0;
3589   if (Opcode == ISD::USUBO && CarryInOperandNum != 1)
3590     return SDValue();
3591   SDValue CarryIn = Carry1.getOperand(CarryInOperandNum);
3592 
3593   unsigned NewOp = Opcode == ISD::UADDO ? ISD::UADDO_CARRY : ISD::USUBO_CARRY;
3594   if (!TLI.isOperationLegalOrCustom(NewOp, Carry0.getValue(0).getValueType()))
3595     return SDValue();
3596 
3597   // Verify that the carry/borrow in is plausibly a carry/borrow bit.
3598   CarryIn = getAsCarry(TLI, CarryIn, true);
3599   if (!CarryIn)
3600     return SDValue();
3601 
3602   SDLoc DL(N);
3603   CarryIn = DAG.getBoolExtOrTrunc(CarryIn, DL, Carry1->getValueType(1),
3604                                   Carry1->getValueType(0));
3605   SDValue Merged =
3606       DAG.getNode(NewOp, DL, Carry1->getVTList(), Carry0.getOperand(0),
3607                   Carry0.getOperand(1), CarryIn);
3608 
3609   // Because we have proven that the result of the UADDO/USUBO
3610   // of A and B feeds into the UADDO/USUBO that consumes the carry/borrow in,
3611   // we know that if the first UADDO/USUBO overflows, the second
3612   // UADDO/USUBO cannot. For example, consider 8-bit numbers where 0xFF is the
3613   // maximum value.
3614   //
3615   //   0xFF + 0xFF == 0xFE with carry but 0xFE + 1 does not carry
3616   //   0x00 - 0xFF == 1 with a carry/borrow but 1 - 1 == 0 (no carry/borrow)
3617   //
3618   // This is important because it means that OR and XOR can be used to merge
3619   // carry flags, and that AND can return a constant zero.
3620   //
3621   // TODO: match other operations that can merge flags (ADD, etc)
3622   DAG.ReplaceAllUsesOfValueWith(Carry1.getValue(0), Merged.getValue(0));
3623   if (N->getOpcode() == ISD::AND)
3624     return DAG.getConstant(0, DL, CarryOutType);
3625   return Merged.getValue(1);
3626 }
3627 
3628 SDValue DAGCombiner::visitUADDO_CARRYLike(SDValue N0, SDValue N1,
3629                                           SDValue CarryIn, SDNode *N) {
3630   // fold (uaddo_carry (xor a, -1), b, c) -> (usubo_carry b, a, !c) and flip
3631   // carry.
3632   if (isBitwiseNot(N0))
3633     if (SDValue NotC = extractBooleanFlip(CarryIn, DAG, TLI, true)) {
3634       SDLoc DL(N);
3635       SDValue Sub = DAG.getNode(ISD::USUBO_CARRY, DL, N->getVTList(), N1,
3636                                 N0.getOperand(0), NotC);
3637       return CombineTo(
3638           N, Sub, DAG.getLogicalNOT(DL, Sub.getValue(1), Sub->getValueType(1)));
3639     }
3640 
3641   // Iff the flag result is dead:
3642   // (uaddo_carry (add|uaddo X, Y), 0, Carry) -> (uaddo_carry X, Y, Carry)
3643   // Don't do this if the Carry comes from the uaddo. It won't remove the uaddo
3644   // or the dependency between the instructions.
3645   if ((N0.getOpcode() == ISD::ADD ||
3646        (N0.getOpcode() == ISD::UADDO && N0.getResNo() == 0 &&
3647         N0.getValue(1) != CarryIn)) &&
3648       isNullConstant(N1) && !N->hasAnyUseOfValue(1))
3649     return DAG.getNode(ISD::UADDO_CARRY, SDLoc(N), N->getVTList(),
3650                        N0.getOperand(0), N0.getOperand(1), CarryIn);
3651 
3652   /**
3653    * When one of the uaddo_carry arguments is itself a carry, we may be facing
3654    * a diamond carry propagation, in which case we try to transform the DAG
3655    * to ensure linear carry propagation if that is possible.
3656    */
3657   if (auto Y = getAsCarry(TLI, N1)) {
3658     // Because both are carries, Y and Z can be swapped.
3659     if (auto R = combineUADDO_CARRYDiamond(*this, DAG, N0, Y, CarryIn, N))
3660       return R;
3661     if (auto R = combineUADDO_CARRYDiamond(*this, DAG, N0, CarryIn, Y, N))
3662       return R;
3663   }
3664 
3665   return SDValue();
3666 }
3667 
3668 SDValue DAGCombiner::visitSADDO_CARRYLike(SDValue N0, SDValue N1,
3669                                           SDValue CarryIn, SDNode *N) {
3670   // fold (saddo_carry (xor a, -1), b, c) -> (ssubo_carry b, a, !c)
3671   if (isBitwiseNot(N0)) {
3672     if (SDValue NotC = extractBooleanFlip(CarryIn, DAG, TLI, true))
3673       return DAG.getNode(ISD::SSUBO_CARRY, SDLoc(N), N->getVTList(), N1,
3674                          N0.getOperand(0), NotC);
3675   }
3676 
3677   return SDValue();
3678 }
3679 
3680 SDValue DAGCombiner::visitSADDO_CARRY(SDNode *N) {
3681   SDValue N0 = N->getOperand(0);
3682   SDValue N1 = N->getOperand(1);
3683   SDValue CarryIn = N->getOperand(2);
3684   SDLoc DL(N);
3685 
3686   // canonicalize constant to RHS
3687   ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
3688   ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
3689   if (N0C && !N1C)
3690     return DAG.getNode(ISD::SADDO_CARRY, DL, N->getVTList(), N1, N0, CarryIn);
3691 
3692   // fold (saddo_carry x, y, false) -> (saddo x, y)
3693   if (isNullConstant(CarryIn)) {
3694     if (!LegalOperations ||
3695         TLI.isOperationLegalOrCustom(ISD::SADDO, N->getValueType(0)))
3696       return DAG.getNode(ISD::SADDO, DL, N->getVTList(), N0, N1);
3697   }
3698 
3699   if (SDValue Combined = visitSADDO_CARRYLike(N0, N1, CarryIn, N))
3700     return Combined;
3701 
3702   if (SDValue Combined = visitSADDO_CARRYLike(N1, N0, CarryIn, N))
3703     return Combined;
3704 
3705   return SDValue();
3706 }
3707 
3708 // Attempt to create a USUBSAT(LHS, RHS) node with DstVT, performing a
3709 // clamp/truncation if necessary.
3710 static SDValue getTruncatedUSUBSAT(EVT DstVT, EVT SrcVT, SDValue LHS,
3711                                    SDValue RHS, SelectionDAG &DAG,
3712                                    const SDLoc &DL) {
3713   assert(DstVT.getScalarSizeInBits() <= SrcVT.getScalarSizeInBits() &&
3714          "Illegal truncation");
3715 
3716   if (DstVT == SrcVT)
3717     return DAG.getNode(ISD::USUBSAT, DL, DstVT, LHS, RHS);
3718 
3719   // If the LHS is zero-extended then we can perform the USUBSAT as DstVT by
3720   // clamping RHS.
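       // (Illustratively, for SrcVT = i32 and DstVT = i16: SatLimit is 0xFFFF,
       // and once RHS is clamped to at most 0xFFFF the truncated 16-bit
       // usubsat matches the wide result, since LHS already fits in 16 bits.)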
3721   APInt UpperBits = APInt::getBitsSetFrom(SrcVT.getScalarSizeInBits(),
3722                                           DstVT.getScalarSizeInBits());
3723   if (!DAG.MaskedValueIsZero(LHS, UpperBits))
3724     return SDValue();
3725 
3726   SDValue SatLimit =
3727       DAG.getConstant(APInt::getLowBitsSet(SrcVT.getScalarSizeInBits(),
3728                                            DstVT.getScalarSizeInBits()),
3729                       DL, SrcVT);
3730   RHS = DAG.getNode(ISD::UMIN, DL, SrcVT, RHS, SatLimit);
3731   RHS = DAG.getNode(ISD::TRUNCATE, DL, DstVT, RHS);
3732   LHS = DAG.getNode(ISD::TRUNCATE, DL, DstVT, LHS);
3733   return DAG.getNode(ISD::USUBSAT, DL, DstVT, LHS, RHS);
3734 }
3735 
3736 // Try to find umax(a,b) - b or a - umin(a,b) patterns that may be converted to
3737 // usubsat(a,b), optionally as a truncated type.
3738 SDValue DAGCombiner::foldSubToUSubSat(EVT DstVT, SDNode *N, const SDLoc &DL) {
3739   if (N->getOpcode() != ISD::SUB ||
3740       !(!LegalOperations || hasOperation(ISD::USUBSAT, DstVT)))
3741     return SDValue();
3742 
3743   EVT SubVT = N->getValueType(0);
3744   SDValue Op0 = N->getOperand(0);
3745   SDValue Op1 = N->getOperand(1);
3746 
3747   // Try to find umax(a,b) - b or a - umin(a,b) patterns;
3748   // they may be converted to usubsat(a,b).
3749   if (Op0.getOpcode() == ISD::UMAX && Op0.hasOneUse()) {
3750     SDValue MaxLHS = Op0.getOperand(0);
3751     SDValue MaxRHS = Op0.getOperand(1);
3752     if (MaxLHS == Op1)
3753       return getTruncatedUSUBSAT(DstVT, SubVT, MaxRHS, Op1, DAG, DL);
3754     if (MaxRHS == Op1)
3755       return getTruncatedUSUBSAT(DstVT, SubVT, MaxLHS, Op1, DAG, DL);
3756   }
3757 
3758   if (Op1.getOpcode() == ISD::UMIN && Op1.hasOneUse()) {
3759     SDValue MinLHS = Op1.getOperand(0);
3760     SDValue MinRHS = Op1.getOperand(1);
3761     if (MinLHS == Op0)
3762       return getTruncatedUSUBSAT(DstVT, SubVT, Op0, MinRHS, DAG, DL);
3763     if (MinRHS == Op0)
3764       return getTruncatedUSUBSAT(DstVT, SubVT, Op0, MinLHS, DAG, DL);
3765   }
3766 
3767   // sub(a,trunc(umin(zext(a),b))) -> usubsat(a,trunc(umin(b,SatLimit)))
3768   if (Op1.getOpcode() == ISD::TRUNCATE &&
3769       Op1.getOperand(0).getOpcode() == ISD::UMIN &&
3770       Op1.getOperand(0).hasOneUse()) {
3771     SDValue MinLHS = Op1.getOperand(0).getOperand(0);
3772     SDValue MinRHS = Op1.getOperand(0).getOperand(1);
3773     if (MinLHS.getOpcode() == ISD::ZERO_EXTEND && MinLHS.getOperand(0) == Op0)
3774       return getTruncatedUSUBSAT(DstVT, MinLHS.getValueType(), MinLHS, MinRHS,
3775                                  DAG, DL);
3776     if (MinRHS.getOpcode() == ISD::ZERO_EXTEND && MinRHS.getOperand(0) == Op0)
3777       return getTruncatedUSUBSAT(DstVT, MinLHS.getValueType(), MinRHS, MinLHS,
3778                                  DAG, DL);
3779   }
3780 
3781   return SDValue();
3782 }
3783 
3784 // Since it may not be valid to emit a fold to zero for vector initializers,
3785 // check if we can before folding.
3786 static SDValue tryFoldToZero(const SDLoc &DL, const TargetLowering &TLI, EVT VT,
3787                              SelectionDAG &DAG, bool LegalOperations) {
3788   if (!VT.isVector())
3789     return DAG.getConstant(0, DL, VT);
3790   if (!LegalOperations || TLI.isOperationLegal(ISD::BUILD_VECTOR, VT))
3791     return DAG.getConstant(0, DL, VT);
3792   return SDValue();
3793 }
3794 
3795 SDValue DAGCombiner::visitSUB(SDNode *N) {
3796   SDValue N0 = N->getOperand(0);
3797   SDValue N1 = N->getOperand(1);
3798   EVT VT = N0.getValueType();
3799   unsigned BitWidth = VT.getScalarSizeInBits();
3800   SDLoc DL(N);
3801 
3802   auto PeekThroughFreeze = [](SDValue N) {
3803     if (N->getOpcode() == ISD::FREEZE && N.hasOneUse())
3804       return N->getOperand(0);
3805     return N;
3806   };
3807 
3808   // fold (sub x, x) -> 0
3809   // FIXME: Refactor this and xor and other similar operations together.
3810   if (PeekThroughFreeze(N0) == PeekThroughFreeze(N1))
3811     return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
3812 
3813   // fold (sub c1, c2) -> c3
3814   if (SDValue C = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N0, N1}))
3815     return C;
3816 
3817   // fold vector ops
3818   if (VT.isVector()) {
3819     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
3820       return FoldedVOp;
3821 
3822     // fold (sub x, 0) -> x, vector edition
3823     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
3824       return N0;
3825   }
3826 
3827   if (SDValue NewSel = foldBinOpIntoSelect(N))
3828     return NewSel;
3829 
3830   // fold (sub x, c) -> (add x, -c)
3831   if (ConstantSDNode *N1C = getAsNonOpaqueConstant(N1))
3832     return DAG.getNode(ISD::ADD, DL, VT, N0,
3833                        DAG.getConstant(-N1C->getAPIntValue(), DL, VT));
3834 
3835   if (isNullOrNullSplat(N0)) {
3836     // Right-shifting everything out but the sign bit followed by negation is
3837     // the same as flipping arithmetic/logical shift type without the negation:
3838     // -(X >>u 31) -> (X >>s 31)
3839     // -(X >>s 31) -> (X >>u 31)
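    // Illustrative i32 case: if the sign bit of X is set, X >>u 31 == 1 and
    // 0 - 1 == -1 == X >>s 31; if it is clear, both sides are 0.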
3840     if (N1->getOpcode() == ISD::SRA || N1->getOpcode() == ISD::SRL) {
3841       ConstantSDNode *ShiftAmt = isConstOrConstSplat(N1.getOperand(1));
3842       if (ShiftAmt && ShiftAmt->getAPIntValue() == (BitWidth - 1)) {
3843         auto NewSh = N1->getOpcode() == ISD::SRA ? ISD::SRL : ISD::SRA;
3844         if (!LegalOperations || TLI.isOperationLegal(NewSh, VT))
3845           return DAG.getNode(NewSh, DL, VT, N1.getOperand(0), N1.getOperand(1));
3846       }
3847     }
3848 
3849     // 0 - X --> 0 if the sub is NUW.
3850     if (N->getFlags().hasNoUnsignedWrap())
3851       return N0;
3852 
3853     if (DAG.MaskedValueIsZero(N1, ~APInt::getSignMask(BitWidth))) {
3854       // N1 is either 0 or the minimum signed value. If the sub is NSW, then
3855       // N1 must be 0 because negating the minimum signed value is undefined.
3856       if (N->getFlags().hasNoSignedWrap())
3857         return N0;
3858 
3859       // 0 - X --> X if X is 0 or the minimum signed value.
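      // (0 - INT_MIN wraps back to INT_MIN in two's complement, so returning
      // N1 unchanged is correct in both cases.)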
3860       return N1;
3861     }
3862 
3863     // Convert 0 - abs(x) by expanding it to a negated ABS sequence.
3864     if (N1.getOpcode() == ISD::ABS && N1.hasOneUse() &&
3865         !TLI.isOperationLegalOrCustom(ISD::ABS, VT))
3866       if (SDValue Result = TLI.expandABS(N1.getNode(), DAG, true))
3867         return Result;
3868 
3869     // Fold neg(splat(neg(x))) -> splat(x)
3870     if (VT.isVector()) {
3871       SDValue N1S = DAG.getSplatValue(N1, true);
3872       if (N1S && N1S.getOpcode() == ISD::SUB &&
3873           isNullConstant(N1S.getOperand(0)))
3874         return DAG.getSplat(VT, DL, N1S.getOperand(1));
3875     }
3876   }
3877 
3878   // Canonicalize (sub -1, x) -> ~x, i.e. (xor x, -1)
3879   if (isAllOnesOrAllOnesSplat(N0))
3880     return DAG.getNode(ISD::XOR, DL, VT, N1, N0);
3881 
3882   // fold (A - (0-B)) -> A+B
3883   if (N1.getOpcode() == ISD::SUB && isNullOrNullSplat(N1.getOperand(0)))
3884     return DAG.getNode(ISD::ADD, DL, VT, N0, N1.getOperand(1));
3885 
3886   // fold A-(A-B) -> B
3887   if (N1.getOpcode() == ISD::SUB && N0 == N1.getOperand(0))
3888     return N1.getOperand(1);
3889 
3890   // fold (A+B)-A -> B
3891   if (N0.getOpcode() == ISD::ADD && N0.getOperand(0) == N1)
3892     return N0.getOperand(1);
3893 
3894   // fold (A+B)-B -> A
3895   if (N0.getOpcode() == ISD::ADD && N0.getOperand(1) == N1)
3896     return N0.getOperand(0);
3897 
3898   // fold (A+C1)-C2 -> A+(C1-C2)
3899   if (N0.getOpcode() == ISD::ADD) {
3900     SDValue N01 = N0.getOperand(1);
3901     if (SDValue NewC = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N01, N1}))
3902       return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), NewC);
3903   }
3904 
3905   // fold C2-(A+C1) -> (C2-C1)-A
3906   if (N1.getOpcode() == ISD::ADD) {
3907     SDValue N11 = N1.getOperand(1);
3908     if (SDValue NewC = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N0, N11}))
3909       return DAG.getNode(ISD::SUB, DL, VT, NewC, N1.getOperand(0));
3910   }
3911 
3912   // fold (A-C1)-C2 -> A-(C1+C2)
3913   if (N0.getOpcode() == ISD::SUB) {
3914     SDValue N01 = N0.getOperand(1);
3915     if (SDValue NewC = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N01, N1}))
3916       return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), NewC);
3917   }
3918 
3919   // fold (c1-A)-c2 -> (c1-c2)-A
3920   if (N0.getOpcode() == ISD::SUB) {
3921     SDValue N00 = N0.getOperand(0);
3922     if (SDValue NewC = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N00, N1}))
3923       return DAG.getNode(ISD::SUB, DL, VT, NewC, N0.getOperand(1));
3924   }
3925 
3926   SDValue A, B, C;
3927 
3928   // fold ((A+(B+C))-B) -> A+C
3929   if (sd_match(N0, m_Add(m_Value(A), m_Add(m_Specific(N1), m_Value(C)))))
3930     return DAG.getNode(ISD::ADD, DL, VT, A, C);
3931 
3932   // fold ((A+(B-C))-B) -> A-C
3933   if (sd_match(N0, m_Add(m_Value(A), m_Sub(m_Specific(N1), m_Value(C)))))
3934     return DAG.getNode(ISD::SUB, DL, VT, A, C);
3935 
3936   // fold ((A-(B-C))-C) -> A-B
3937   if (sd_match(N0, m_Sub(m_Value(A), m_Sub(m_Value(B), m_Specific(N1)))))
3938     return DAG.getNode(ISD::SUB, DL, VT, A, B);
3939 
3940   // fold (A-(B-C)) -> A+(C-B)
3941   if (sd_match(N1, m_OneUse(m_Sub(m_Value(B), m_Value(C)))))
3942     return DAG.getNode(ISD::ADD, DL, VT, N0,
3943                        DAG.getNode(ISD::SUB, DL, VT, C, B));
3944 
3945   // A - (A & B)  ->  A & (~B)
3946   if (sd_match(N1, m_And(m_Specific(N0), m_Value(B))) &&
3947       (N1.hasOneUse() || isConstantOrConstantVector(B, /*NoOpaques=*/true)))
3948     return DAG.getNode(ISD::AND, DL, VT, N0, DAG.getNOT(DL, B, VT));
3949 
3950   // fold (A - (-B * C)) -> (A + (B * C))
3951   if (sd_match(N1, m_OneUse(m_Mul(m_Neg(m_Value(B)), m_Value(C)))))
3952     return DAG.getNode(ISD::ADD, DL, VT, N0,
3953                        DAG.getNode(ISD::MUL, DL, VT, B, C));
3954 
3955   // If either operand of a sub is undef, the result is undef
3956   if (N0.isUndef())
3957     return N0;
3958   if (N1.isUndef())
3959     return N1;
3960 
3961   if (SDValue V = foldAddSubBoolOfMaskedVal(N, DL, DAG))
3962     return V;
3963 
3964   if (SDValue V = foldAddSubOfSignBit(N, DL, DAG))
3965     return V;
3966 
3967   // Try to match the fixed-width AVGCEIL pattern.
3968   if (SDValue V = foldSubToAvg(N, DL))
3969     return V;
3970 
3971   if (SDValue V = foldAddSubMasked1(false, N0, N1, DAG, DL))
3972     return V;
3973 
3974   if (SDValue V = foldSubToUSubSat(VT, N, DL))
3975     return V;
3976 
3977   // (A - B) - 1  ->  add (xor B, -1), A
3978   if (sd_match(N, m_Sub(m_OneUse(m_Sub(m_Value(A), m_Value(B))), m_One())))
3979     return DAG.getNode(ISD::ADD, DL, VT, A, DAG.getNOT(DL, B, VT));
3980 
3981   // Look for:
3982   //   sub y, (xor x, -1)
3983   // And if the target does not like this form then turn into:
3984   //   add (add x, y), 1
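  // (In two's complement ~x == -x - 1, so y - ~x == x + y + 1.)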
3985   if (TLI.preferIncOfAddToSubOfNot(VT) && N1.hasOneUse() && isBitwiseNot(N1)) {
3986     SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0, N1.getOperand(0));
3987     return DAG.getNode(ISD::ADD, DL, VT, Add, DAG.getConstant(1, DL, VT));
3988   }
3989 
3990   // Hoist one-use addition by non-opaque constant:
3991   //   (x + C) - y  ->  (x - y) + C
3992   if (!reassociationCanBreakAddressingModePattern(ISD::SUB, DL, N, N0, N1) &&
3993       N0.getOpcode() == ISD::ADD && N0.hasOneUse() &&
3994       isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true)) {
3995     SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), N1);
3996     return DAG.getNode(ISD::ADD, DL, VT, Sub, N0.getOperand(1));
3997   }
3998   // y - (x + C)  ->  (y - x) - C
3999   if (N1.getOpcode() == ISD::ADD && N1.hasOneUse() &&
4000       isConstantOrConstantVector(N1.getOperand(1), /*NoOpaques=*/true)) {
4001     SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, N1.getOperand(0));
4002     return DAG.getNode(ISD::SUB, DL, VT, Sub, N1.getOperand(1));
4003   }
4004   // (x - C) - y  ->  (x - y) - C
4005   // This is necessary because SUB(X,C) -> ADD(X,-C) doesn't work for vectors.
4006   if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
4007       isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true)) {
4008     SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), N1);
4009     return DAG.getNode(ISD::SUB, DL, VT, Sub, N0.getOperand(1));
4010   }
4011   // (C - x) - y  ->  C - (x + y)
4012   if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
4013       isConstantOrConstantVector(N0.getOperand(0), /*NoOpaques=*/true)) {
4014     SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(1), N1);
4015     return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), Add);
4016   }
4017 
4018   // If the target's bool is represented as 0/-1, prefer to make this 'add 0/-1'
4019   // rather than 'sub 0/1' (the sext should get folded).
4020   // sub X, (zext i1 Y) --> add X, (sext i1 Y)
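  // ((zext i1 Y) is 0/1 and (sext i1 Y) is 0/-1, so subtracting the former
  // is the same as adding the latter.)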
4021   if (N1.getOpcode() == ISD::ZERO_EXTEND &&
4022       N1.getOperand(0).getScalarValueSizeInBits() == 1 &&
4023       TLI.getBooleanContents(VT) ==
4024           TargetLowering::ZeroOrNegativeOneBooleanContent) {
4025     SDValue SExt = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, N1.getOperand(0));
4026     return DAG.getNode(ISD::ADD, DL, VT, N0, SExt);
4027   }
4028 
4029   // fold B = sra (A, size(A)-1); sub (xor (A, B), B) -> (abs A)
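  // (B is all copies of A's sign bit: 0 for non-negative A, -1 otherwise.
  // A ^ B is then A or ~A, and subtracting B adds back the +1, yielding |A|.)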
4030   if ((!LegalOperations || hasOperation(ISD::ABS, VT)) &&
4031       sd_match(N1, m_Sra(m_Value(A), m_SpecificInt(BitWidth - 1))) &&
4032       sd_match(N0, m_Xor(m_Specific(A), m_Specific(N1))))
4033     return DAG.getNode(ISD::ABS, DL, VT, A);
4034 
4035   // If the relocation model supports it, consider symbol offsets.
4036   if (GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(N0))
4037     if (!LegalOperations && TLI.isOffsetFoldingLegal(GA)) {
4038       // fold (sub Sym+c1, Sym+c2) -> c1-c2
4039       if (GlobalAddressSDNode *GB = dyn_cast<GlobalAddressSDNode>(N1))
4040         if (GA->getGlobal() == GB->getGlobal())
4041           return DAG.getConstant((uint64_t)GA->getOffset() - GB->getOffset(),
4042                                  DL, VT);
4043     }
4044 
4045   // sub X, (sextinreg Y i1) -> add X, (and Y 1)
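  // ((sextinreg Y i1) is 0 or -1 depending on bit 0 of Y, and subtracting -1
  // is the same as adding (Y & 1).)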
4046   if (N1.getOpcode() == ISD::SIGN_EXTEND_INREG) {
4047     VTSDNode *TN = cast<VTSDNode>(N1.getOperand(1));
4048     if (TN->getVT() == MVT::i1) {
4049       SDValue ZExt = DAG.getNode(ISD::AND, DL, VT, N1.getOperand(0),
4050                                  DAG.getConstant(1, DL, VT));
4051       return DAG.getNode(ISD::ADD, DL, VT, N0, ZExt);
4052     }
4053   }
4054 
4055   // canonicalize (sub X, (vscale * C)) to (add X, (vscale * -C))
4056   if (N1.getOpcode() == ISD::VSCALE && N1.hasOneUse()) {
4057     const APInt &IntVal = N1.getConstantOperandAPInt(0);
4058     return DAG.getNode(ISD::ADD, DL, VT, N0, DAG.getVScale(DL, VT, -IntVal));
4059   }
4060 
4061   // canonicalize (sub X, step_vector(C)) to (add X, step_vector(-C))
4062   if (N1.getOpcode() == ISD::STEP_VECTOR && N1.hasOneUse()) {
4063     APInt NewStep = -N1.getConstantOperandAPInt(0);
4064     return DAG.getNode(ISD::ADD, DL, VT, N0,
4065                        DAG.getStepVector(DL, VT, NewStep));
4066   }
4067 
4068   // Prefer an add for more folding potential and possibly better codegen:
4069   // sub N0, (lshr N10, width-1) --> add N0, (ashr N10, width-1)
4070   if (!LegalOperations && N1.getOpcode() == ISD::SRL && N1.hasOneUse()) {
4071     SDValue ShAmt = N1.getOperand(1);
4072     ConstantSDNode *ShAmtC = isConstOrConstSplat(ShAmt);
4073     if (ShAmtC && ShAmtC->getAPIntValue() == (BitWidth - 1)) {
4074       SDValue SRA = DAG.getNode(ISD::SRA, DL, VT, N1.getOperand(0), ShAmt);
4075       return DAG.getNode(ISD::ADD, DL, VT, N0, SRA);
4076     }
4077   }
4078 
4079   // As with the previous fold, prefer add for more folding potential.
4080   // Subtracting SMIN/0 is the same as adding SMIN/0:
4081   // N0 - (X << BW-1) --> N0 + (X << BW-1)
4082   if (N1.getOpcode() == ISD::SHL) {
4083     ConstantSDNode *ShlC = isConstOrConstSplat(N1.getOperand(1));
4084     if (ShlC && ShlC->getAPIntValue() == (BitWidth - 1))
4085       return DAG.getNode(ISD::ADD, DL, VT, N1, N0);
4086   }
4087 
4088   // (sub (usubo_carry X, 0, Carry), Y) -> (usubo_carry X, Y, Carry)
4089   if (N0.getOpcode() == ISD::USUBO_CARRY && isNullConstant(N0.getOperand(1)) &&
4090       N0.getResNo() == 0 && N0.hasOneUse())
4091     return DAG.getNode(ISD::USUBO_CARRY, DL, N0->getVTList(),
4092                        N0.getOperand(0), N1, N0.getOperand(2));
4093 
4094   if (TLI.isOperationLegalOrCustom(ISD::UADDO_CARRY, VT)) {
4095     // (sub Carry, X)  ->  (uaddo_carry (sub 0, X), 0, Carry)
4096     if (SDValue Carry = getAsCarry(TLI, N0)) {
4097       SDValue X = N1;
4098       SDValue Zero = DAG.getConstant(0, DL, VT);
4099       SDValue NegX = DAG.getNode(ISD::SUB, DL, VT, Zero, X);
4100       return DAG.getNode(ISD::UADDO_CARRY, DL,
4101                          DAG.getVTList(VT, Carry.getValueType()), NegX, Zero,
4102                          Carry);
4103     }
4104   }
4105 
4106   // If there's no chance of borrowing from adjacent bits, then sub is xor:
4107   // sub C0, X --> xor X, C0
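  // e.g. for i8: 0x0F - X == 0x0F ^ X whenever every bit that may be set in X
  // is also set in 0x0F, since each set bit clears without borrowing.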
4108   if (ConstantSDNode *C0 = isConstOrConstSplat(N0)) {
4109     if (!C0->isOpaque()) {
4110       const APInt &C0Val = C0->getAPIntValue();
4111       const APInt &MaybeOnes = ~DAG.computeKnownBits(N1).Zero;
4112       if ((C0Val - MaybeOnes) == (C0Val ^ MaybeOnes))
4113         return DAG.getNode(ISD::XOR, DL, VT, N1, N0);
4114     }
4115   }
4116 
4117   // smax(a,b) - smin(a,b) --> abds(a,b)
4118   if (hasOperation(ISD::ABDS, VT) &&
4119       sd_match(N0, m_SMax(m_Value(A), m_Value(B))) &&
4120       sd_match(N1, m_SMin(m_Specific(A), m_Specific(B))))
4121     return DAG.getNode(ISD::ABDS, DL, VT, A, B);
4122 
4123   // umax(a,b) - umin(a,b) --> abdu(a,b)
4124   if (hasOperation(ISD::ABDU, VT) &&
4125       sd_match(N0, m_UMax(m_Value(A), m_Value(B))) &&
4126       sd_match(N1, m_UMin(m_Specific(A), m_Specific(B))))
4127     return DAG.getNode(ISD::ABDU, DL, VT, A, B);
4128 
4129   return SDValue();
4130 }
4131 
4132 SDValue DAGCombiner::visitSUBSAT(SDNode *N) {
4133   unsigned Opcode = N->getOpcode();
4134   SDValue N0 = N->getOperand(0);
4135   SDValue N1 = N->getOperand(1);
4136   EVT VT = N0.getValueType();
4137   bool IsSigned = Opcode == ISD::SSUBSAT;
4138   SDLoc DL(N);
4139 
4140   // fold (sub_sat x, undef) -> 0
4141   if (N0.isUndef() || N1.isUndef())
4142     return DAG.getConstant(0, DL, VT);
4143 
4144   // fold (sub_sat x, x) -> 0
4145   if (N0 == N1)
4146     return DAG.getConstant(0, DL, VT);
4147 
4148   // fold (sub_sat c1, c2) -> c3
4149   if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
4150     return C;
4151 
4152   // fold vector ops
4153   if (VT.isVector()) {
4154     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4155       return FoldedVOp;
4156 
4157     // fold (sub_sat x, 0) -> x, vector edition
4158     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
4159       return N0;
4160   }
4161 
4162   // fold (sub_sat x, 0) -> x
4163   if (isNullConstant(N1))
4164     return N0;
4165 
4166   // If it cannot overflow, transform into a sub.
4167   if (DAG.willNotOverflowSub(IsSigned, N0, N1))
4168     return DAG.getNode(ISD::SUB, DL, VT, N0, N1);
4169 
4170   return SDValue();
4171 }
4172 
4173 SDValue DAGCombiner::visitSUBC(SDNode *N) {
4174   SDValue N0 = N->getOperand(0);
4175   SDValue N1 = N->getOperand(1);
4176   EVT VT = N0.getValueType();
4177   SDLoc DL(N);
4178 
4179   // If the flag result is dead, turn this into an SUB.
4180   if (!N->hasAnyUseOfValue(1))
4181     return CombineTo(N, DAG.getNode(ISD::SUB, DL, VT, N0, N1),
4182                      DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
4183 
4184   // fold (subc x, x) -> 0 + no borrow
4185   if (N0 == N1)
4186     return CombineTo(N, DAG.getConstant(0, DL, VT),
4187                      DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
4188 
4189   // fold (subc x, 0) -> x + no borrow
4190   if (isNullConstant(N1))
4191     return CombineTo(N, N0, DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
4192 
4193   // Canonicalize (sub -1, x) -> ~x, i.e. (xor x, -1) + no borrow
4194   if (isAllOnesConstant(N0))
4195     return CombineTo(N, DAG.getNode(ISD::XOR, DL, VT, N1, N0),
4196                      DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
4197 
4198   return SDValue();
4199 }
4200 
4201 SDValue DAGCombiner::visitSUBO(SDNode *N) {
4202   SDValue N0 = N->getOperand(0);
4203   SDValue N1 = N->getOperand(1);
4204   EVT VT = N0.getValueType();
4205   bool IsSigned = (ISD::SSUBO == N->getOpcode());
4206 
4207   EVT CarryVT = N->getValueType(1);
4208   SDLoc DL(N);
4209 
4210   // If the flag result is dead, turn this into an SUB.
4211   if (!N->hasAnyUseOfValue(1))
4212     return CombineTo(N, DAG.getNode(ISD::SUB, DL, VT, N0, N1),
4213                      DAG.getUNDEF(CarryVT));
4214 
4215   // fold (subo x, x) -> 0 + no borrow
4216   if (N0 == N1)
4217     return CombineTo(N, DAG.getConstant(0, DL, VT),
4218                      DAG.getConstant(0, DL, CarryVT));
4219 
4220   // fold (subo x, c) -> (addo x, -c)
4221   if (ConstantSDNode *N1C = getAsNonOpaqueConstant(N1))
4222     if (IsSigned && !N1C->isMinSignedValue())
4223       return DAG.getNode(ISD::SADDO, DL, N->getVTList(), N0,
4224                          DAG.getConstant(-N1C->getAPIntValue(), DL, VT));
4225 
4226   // fold (subo x, 0) -> x + no borrow
4227   if (isNullOrNullSplat(N1))
4228     return CombineTo(N, N0, DAG.getConstant(0, DL, CarryVT));
4229 
4230   // If it cannot overflow, transform into a sub.
4231   if (DAG.willNotOverflowSub(IsSigned, N0, N1))
4232     return CombineTo(N, DAG.getNode(ISD::SUB, DL, VT, N0, N1),
4233                      DAG.getConstant(0, DL, CarryVT));
4234 
4235   // Canonicalize (usubo -1, x) -> ~x, i.e. (xor x, -1) + no borrow
4236   if (!IsSigned && isAllOnesOrAllOnesSplat(N0))
4237     return CombineTo(N, DAG.getNode(ISD::XOR, DL, VT, N1, N0),
4238                      DAG.getConstant(0, DL, CarryVT));
4239 
4240   return SDValue();
4241 }
4242 
4243 SDValue DAGCombiner::visitSUBE(SDNode *N) {
4244   SDValue N0 = N->getOperand(0);
4245   SDValue N1 = N->getOperand(1);
4246   SDValue CarryIn = N->getOperand(2);
4247 
4248   // fold (sube x, y, false) -> (subc x, y)
4249   if (CarryIn.getOpcode() == ISD::CARRY_FALSE)
4250     return DAG.getNode(ISD::SUBC, SDLoc(N), N->getVTList(), N0, N1);
4251 
4252   return SDValue();
4253 }
4254 
4255 SDValue DAGCombiner::visitUSUBO_CARRY(SDNode *N) {
4256   SDValue N0 = N->getOperand(0);
4257   SDValue N1 = N->getOperand(1);
4258   SDValue CarryIn = N->getOperand(2);
4259 
4260   // fold (usubo_carry x, y, false) -> (usubo x, y)
4261   if (isNullConstant(CarryIn)) {
4262     if (!LegalOperations ||
4263         TLI.isOperationLegalOrCustom(ISD::USUBO, N->getValueType(0)))
4264       return DAG.getNode(ISD::USUBO, SDLoc(N), N->getVTList(), N0, N1);
4265   }
4266 
4267   return SDValue();
4268 }
4269 
4270 SDValue DAGCombiner::visitSSUBO_CARRY(SDNode *N) {
4271   SDValue N0 = N->getOperand(0);
4272   SDValue N1 = N->getOperand(1);
4273   SDValue CarryIn = N->getOperand(2);
4274 
4275   // fold (ssubo_carry x, y, false) -> (ssubo x, y)
4276   if (isNullConstant(CarryIn)) {
4277     if (!LegalOperations ||
4278         TLI.isOperationLegalOrCustom(ISD::SSUBO, N->getValueType(0)))
4279       return DAG.getNode(ISD::SSUBO, SDLoc(N), N->getVTList(), N0, N1);
4280   }
4281 
4282   return SDValue();
4283 }
4284 
4285 // Notice that "mulfix" can be any of SMULFIX, SMULFIXSAT, UMULFIX and
4286 // UMULFIXSAT here.
4287 SDValue DAGCombiner::visitMULFIX(SDNode *N) {
4288   SDValue N0 = N->getOperand(0);
4289   SDValue N1 = N->getOperand(1);
4290   SDValue Scale = N->getOperand(2);
4291   EVT VT = N0.getValueType();
4292 
4293   // fold (mulfix x, undef, scale) -> 0
4294   if (N0.isUndef() || N1.isUndef())
4295     return DAG.getConstant(0, SDLoc(N), VT);
4296 
4297   // Canonicalize constant to RHS (vector doesn't have to splat)
4298   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
4299      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
4300     return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N1, N0, Scale);
4301 
4302   // fold (mulfix x, 0, scale) -> 0
4303   if (isNullConstant(N1))
4304     return DAG.getConstant(0, SDLoc(N), VT);
4305 
4306   return SDValue();
4307 }
4308 
4309 template <class MatchContextClass> SDValue DAGCombiner::visitMUL(SDNode *N) {
4310   SDValue N0 = N->getOperand(0);
4311   SDValue N1 = N->getOperand(1);
4312   EVT VT = N0.getValueType();
4313   unsigned BitWidth = VT.getScalarSizeInBits();
4314   SDLoc DL(N);
4315   bool UseVP = std::is_same_v<MatchContextClass, VPMatchContext>;
4316   MatchContextClass Matcher(DAG, TLI, N);
4317 
4318   // fold (mul x, undef) -> 0
4319   if (N0.isUndef() || N1.isUndef())
4320     return DAG.getConstant(0, DL, VT);
4321 
4322   // fold (mul c1, c2) -> c1*c2
4323   if (SDValue C = DAG.FoldConstantArithmetic(ISD::MUL, DL, VT, {N0, N1}))
4324     return C;
4325 
4326   // canonicalize constant to RHS (vector doesn't have to splat)
4327   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
4328       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
4329     return Matcher.getNode(ISD::MUL, DL, VT, N1, N0);
4330 
4331   bool N1IsConst = false;
4332   bool N1IsOpaqueConst = false;
4333   APInt ConstValue1;
4334 
4335   // fold vector ops
4336   if (VT.isVector()) {
4337     // TODO: Change this to use SimplifyVBinOp when it supports VP op.
4338     if (!UseVP)
4339       if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4340         return FoldedVOp;
4341 
4342     N1IsConst = ISD::isConstantSplatVector(N1.getNode(), ConstValue1);
4343     assert((!N1IsConst || ConstValue1.getBitWidth() == BitWidth) &&
4344            "Splat APInt should be element width");
4345   } else {
4346     N1IsConst = isa<ConstantSDNode>(N1);
4347     if (N1IsConst) {
4348       ConstValue1 = N1->getAsAPIntVal();
4349       N1IsOpaqueConst = cast<ConstantSDNode>(N1)->isOpaque();
4350     }
4351   }
4352 
4353   // fold (mul x, 0) -> 0
4354   if (N1IsConst && ConstValue1.isZero())
4355     return N1;
4356 
4357   // fold (mul x, 1) -> x
4358   if (N1IsConst && ConstValue1.isOne())
4359     return N0;
4360 
4361   if (!UseVP)
4362     if (SDValue NewSel = foldBinOpIntoSelect(N))
4363       return NewSel;
4364 
4365   // fold (mul x, -1) -> 0-x
4366   if (N1IsConst && ConstValue1.isAllOnes())
4367     return Matcher.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), N0);
4368 
4369   // fold (mul x, (1 << c)) -> x << c
4370   if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
4371       (!VT.isVector() || Level <= AfterLegalizeVectorOps)) {
4372     if (SDValue LogBase2 = BuildLogBase2(N1, DL)) {
4373       EVT ShiftVT = getShiftAmountTy(N0.getValueType());
4374       SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
4375       return Matcher.getNode(ISD::SHL, DL, VT, N0, Trunc);
4376     }
4377   }
4378 
4379   // fold (mul x, -(1 << c)) -> -(x << c) or (-x) << c
4380   if (N1IsConst && !N1IsOpaqueConst && ConstValue1.isNegatedPowerOf2()) {
4381     unsigned Log2Val = (-ConstValue1).logBase2();
4382 
4383     // FIXME: If the input is something that is easily negated (e.g. a
4384     // single-use add), we should put the negate there.
4385     return Matcher.getNode(
4386         ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
4387         Matcher.getNode(ISD::SHL, DL, VT, N0,
4388                         DAG.getShiftAmountConstant(Log2Val, VT, DL)));
4389   }
4390 
4391   // Attempt to reuse an existing umul_lohi/smul_lohi node, but only if the
4392   // hi result is in use in case we hit this mid-legalization.
4393   if (!UseVP) {
4394     for (unsigned LoHiOpc : {ISD::UMUL_LOHI, ISD::SMUL_LOHI}) {
4395       if (!LegalOperations || TLI.isOperationLegalOrCustom(LoHiOpc, VT)) {
4396         SDVTList LoHiVT = DAG.getVTList(VT, VT);
4397         // TODO: Can we match commutable operands with getNodeIfExists?
4398         if (SDNode *LoHi = DAG.getNodeIfExists(LoHiOpc, LoHiVT, {N0, N1}))
4399           if (LoHi->hasAnyUseOfValue(1))
4400             return SDValue(LoHi, 0);
4401         if (SDNode *LoHi = DAG.getNodeIfExists(LoHiOpc, LoHiVT, {N1, N0}))
4402           if (LoHi->hasAnyUseOfValue(1))
4403             return SDValue(LoHi, 0);
4404       }
4405     }
4406   }
4407 
4408   // Try to transform:
4409   // (1) multiply-by-(power-of-2 +/- 1) into shift and add/sub.
4410   // mul x, (2^N + 1) --> add (shl x, N), x
4411   // mul x, (2^N - 1) --> sub (shl x, N), x
4412   // Examples: x * 33 --> (x << 5) + x
4413   //           x * 15 --> (x << 4) - x
4414   //           x * -33 --> -((x << 5) + x)
4415   //           x * -15 --> -((x << 4) - x) ; this reduces --> x - (x << 4)
4416   // (2) multiply-by-(power-of-2 +/- power-of-2) into shifts and add/sub.
4417   // mul x, (2^N + 2^M) --> (add (shl x, N), (shl x, M))
4418   // mul x, (2^N - 2^M) --> (sub (shl x, N), (shl x, M))
4419   // Examples: x * 0x8800 --> (x << 15) + (x << 11)
4420   //           x * 0xf800 --> (x << 16) - (x << 11)
4421   //           x * -0x8800 --> -((x << 15) + (x << 11))
4422   //           x * -0xf800 --> -((x << 16) - (x << 11)) ; (x << 11) - (x << 16)
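  // A worked trailing-zeros case: x * 20 --> (x << 4) + (x << 2), since
  // 20 == 0b10100 == (2^2 + 1) << 2.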
4423   if (!UseVP && N1IsConst &&
4424       TLI.decomposeMulByConstant(*DAG.getContext(), VT, N1)) {
4425     // TODO: We could handle more general decomposition of any constant by
4426     //       having the target set a limit on number of ops and making a
4427     //       callback to determine that sequence (similar to sqrt expansion).
4428     unsigned MathOp = ISD::DELETED_NODE;
4429     APInt MulC = ConstValue1.abs();
4430     // The constant `2` should be treated as (2^0 + 1).
4431     unsigned TZeros = MulC == 2 ? 0 : MulC.countr_zero();
4432     MulC.lshrInPlace(TZeros);
4433     if ((MulC - 1).isPowerOf2())
4434       MathOp = ISD::ADD;
4435     else if ((MulC + 1).isPowerOf2())
4436       MathOp = ISD::SUB;
4437 
4438     if (MathOp != ISD::DELETED_NODE) {
4439       unsigned ShAmt =
4440           MathOp == ISD::ADD ? (MulC - 1).logBase2() : (MulC + 1).logBase2();
4441       ShAmt += TZeros;
4442       assert(ShAmt < BitWidth &&
4443              "multiply-by-constant generated out of bounds shift");
4444       SDValue Shl =
4445           DAG.getNode(ISD::SHL, DL, VT, N0, DAG.getConstant(ShAmt, DL, VT));
4446       SDValue R =
4447           TZeros ? DAG.getNode(MathOp, DL, VT, Shl,
4448                                DAG.getNode(ISD::SHL, DL, VT, N0,
4449                                            DAG.getConstant(TZeros, DL, VT)))
4450                  : DAG.getNode(MathOp, DL, VT, Shl, N0);
4451       if (ConstValue1.isNegative())
4452         R = DAG.getNegative(R, DL, VT);
4453       return R;
4454     }
4455   }
4456 
4457   // (mul (shl X, c1), c2) -> (mul X, c2 << c1)
4458   if (sd_context_match(N0, Matcher, m_Opc(ISD::SHL))) {
4459     SDValue N01 = N0.getOperand(1);
4460     if (SDValue C3 = DAG.FoldConstantArithmetic(ISD::SHL, DL, VT, {N1, N01}))
4461       return DAG.getNode(ISD::MUL, DL, VT, N0.getOperand(0), C3);
4462   }
4463 
4464   // Change (mul (shl X, C), Y) -> (shl (mul X, Y), C) when the shift has one
4465   // use.
4466   {
4467     SDValue Sh, Y;
4468 
4469     // Check for both (mul (shl X, C), Y)  and  (mul Y, (shl X, C)).
4470     if (sd_context_match(N0, Matcher, m_OneUse(m_Opc(ISD::SHL))) &&
4471         isConstantOrConstantVector(N0.getOperand(1))) {
4472       Sh = N0; Y = N1;
4473     } else if (sd_context_match(N1, Matcher, m_OneUse(m_Opc(ISD::SHL))) &&
4474                isConstantOrConstantVector(N1.getOperand(1))) {
4475       Sh = N1; Y = N0;
4476     }
4477 
4478     if (Sh.getNode()) {
4479       SDValue Mul = Matcher.getNode(ISD::MUL, DL, VT, Sh.getOperand(0), Y);
4480       return Matcher.getNode(ISD::SHL, DL, VT, Mul, Sh.getOperand(1));
4481     }
4482   }
4483 
4484   // fold (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2)
4485   if (sd_context_match(N0, Matcher, m_Opc(ISD::ADD)) &&
4486       DAG.isConstantIntBuildVectorOrConstantInt(N1) &&
4487       DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1)) &&
4488       isMulAddWithConstProfitable(N, N0, N1))
4489     return Matcher.getNode(
4490         ISD::ADD, DL, VT,
4491         Matcher.getNode(ISD::MUL, SDLoc(N0), VT, N0.getOperand(0), N1),
4492         Matcher.getNode(ISD::MUL, SDLoc(N1), VT, N0.getOperand(1), N1));
4493 
4494   // Fold (mul (vscale * C0), C1) to (vscale * (C0 * C1)).
4495   ConstantSDNode *NC1 = isConstOrConstSplat(N1);
4496   if (!UseVP && N0.getOpcode() == ISD::VSCALE && NC1) {
4497     const APInt &C0 = N0.getConstantOperandAPInt(0);
4498     const APInt &C1 = NC1->getAPIntValue();
4499     return DAG.getVScale(DL, VT, C0 * C1);
4500   }
4501 
4502   // Fold (mul step_vector(C0), C1) to (step_vector(C0 * C1)).
4503   APInt MulVal;
4504   if (!UseVP && N0.getOpcode() == ISD::STEP_VECTOR &&
4505       ISD::isConstantSplatVector(N1.getNode(), MulVal)) {
4506     const APInt &C0 = N0.getConstantOperandAPInt(0);
4507     APInt NewStep = C0 * MulVal;
4508     return DAG.getStepVector(DL, VT, NewStep);
4509   }
4510 
4511   // Fold Y = sra (X, size(X)-1); mul (or (Y, 1), X) -> (abs X)
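  // (Y is 0 or -1, so (Y | 1) is +1/-1, i.e. the sign of X; multiplying X by
  // its own sign gives |X|, matching ABS semantics even for X == INT_MIN.)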
4512   SDValue X;
4513   if (!UseVP && (!LegalOperations || hasOperation(ISD::ABS, VT)) &&
4514       sd_context_match(
4515           N, Matcher,
4516           m_Mul(m_Or(m_Sra(m_Value(X), m_SpecificInt(BitWidth - 1)), m_One()),
4517                 m_Deferred(X)))) {
4518     return Matcher.getNode(ISD::ABS, DL, VT, X);
4519   }
4520 
4521   // Fold (mul x, 0/undef) -> 0 and
4522   //      (mul x, 1) -> x
4523   // into and(x, mask).
4524   // We can replace vectors with '0' and '1' factors with a clearing mask.
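  // e.g. mul X, <i32 1, i32 0, i32 1, i32 0> --> and X, <-1, 0, -1, 0>.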
4525   if (VT.isFixedLengthVector()) {
4526     unsigned NumElts = VT.getVectorNumElements();
4527     SmallBitVector ClearMask;
4528     ClearMask.reserve(NumElts);
4529     auto IsClearMask = [&ClearMask](ConstantSDNode *V) {
4530       if (!V || V->isZero()) {
4531         ClearMask.push_back(true);
4532         return true;
4533       }
4534       ClearMask.push_back(false);
4535       return V->isOne();
4536     };
4537     if ((!LegalOperations || TLI.isOperationLegalOrCustom(ISD::AND, VT)) &&
4538         ISD::matchUnaryPredicate(N1, IsClearMask, /*AllowUndefs*/ true)) {
4539       assert(N1.getOpcode() == ISD::BUILD_VECTOR && "Unknown constant vector");
4540       EVT LegalSVT = N1.getOperand(0).getValueType();
4541       SDValue Zero = DAG.getConstant(0, DL, LegalSVT);
4542       SDValue AllOnes = DAG.getAllOnesConstant(DL, LegalSVT);
4543       SmallVector<SDValue, 16> Mask(NumElts, AllOnes);
4544       for (unsigned I = 0; I != NumElts; ++I)
4545         if (ClearMask[I])
4546           Mask[I] = Zero;
4547       return DAG.getNode(ISD::AND, DL, VT, N0, DAG.getBuildVector(VT, DL, Mask));
4548     }
4549   }
4550 
4551   // reassociate mul
4552   // TODO: Change reassociateOps to support vp ops.
4553   if (!UseVP)
4554     if (SDValue RMUL = reassociateOps(ISD::MUL, DL, N0, N1, N->getFlags()))
4555       return RMUL;
4556 
4557   // Fold mul(vecreduce(x), vecreduce(y)) -> vecreduce(mul(x, y))
4558   // TODO: Change reassociateReduction to support vp ops.
4559   if (!UseVP)
4560     if (SDValue SD =
4561             reassociateReduction(ISD::VECREDUCE_MUL, ISD::MUL, DL, VT, N0, N1))
4562       return SD;
4563 
4564   // Simplify the operands using demanded-bits information.
4565   if (SimplifyDemandedBits(SDValue(N, 0)))
4566     return SDValue(N, 0);
4567 
4568   return SDValue();
4569 }
4570 
4571 /// Return true if divmod libcall is available.
4572 static bool isDivRemLibcallAvailable(SDNode *Node, bool isSigned,
4573                                      const TargetLowering &TLI) {
4574   RTLIB::Libcall LC;
4575   EVT NodeType = Node->getValueType(0);
4576   if (!NodeType.isSimple())
4577     return false;
4578   switch (NodeType.getSimpleVT().SimpleTy) {
4579   default: return false; // No libcall for vector types.
4580   case MVT::i8:   LC= isSigned ? RTLIB::SDIVREM_I8  : RTLIB::UDIVREM_I8;  break;
4581   case MVT::i16:  LC= isSigned ? RTLIB::SDIVREM_I16 : RTLIB::UDIVREM_I16; break;
4582   case MVT::i32:  LC= isSigned ? RTLIB::SDIVREM_I32 : RTLIB::UDIVREM_I32; break;
4583   case MVT::i64:  LC= isSigned ? RTLIB::SDIVREM_I64 : RTLIB::UDIVREM_I64; break;
4584   case MVT::i128: LC= isSigned ? RTLIB::SDIVREM_I128:RTLIB::UDIVREM_I128; break;
4585   }
4586 
4587   return TLI.getLibcallName(LC) != nullptr;
4588 }
4589 
4590 /// Issue divrem if both quotient and remainder are needed.
4591 SDValue DAGCombiner::useDivRem(SDNode *Node) {
4592   if (Node->use_empty())
4593     return SDValue(); // This is a dead node, leave it alone.
4594 
4595   unsigned Opcode = Node->getOpcode();
4596   bool isSigned = (Opcode == ISD::SDIV) || (Opcode == ISD::SREM);
4597   unsigned DivRemOpc = isSigned ? ISD::SDIVREM : ISD::UDIVREM;
4598 
4599   // DivMod libcalls can still work on non-legal types.
4600   EVT VT = Node->getValueType(0);
4601   if (VT.isVector() || !VT.isInteger())
4602     return SDValue();
4603 
4604   if (!TLI.isTypeLegal(VT) && !TLI.isOperationCustom(DivRemOpc, VT))
4605     return SDValue();
4606 
4607   // If DIVREM is going to get expanded into a libcall,
4608   // but there is no libcall available, then don't combine.
4609   if (!TLI.isOperationLegalOrCustom(DivRemOpc, VT) &&
4610       !isDivRemLibcallAvailable(Node, isSigned, TLI))
4611     return SDValue();
4612 
4613   // If div is legal, it's better to do the normal expansion
4614   unsigned OtherOpcode = 0;
4615   if ((Opcode == ISD::SDIV) || (Opcode == ISD::UDIV)) {
4616     OtherOpcode = isSigned ? ISD::SREM : ISD::UREM;
4617     if (TLI.isOperationLegalOrCustom(Opcode, VT))
4618       return SDValue();
4619   } else {
4620     OtherOpcode = isSigned ? ISD::SDIV : ISD::UDIV;
4621     if (TLI.isOperationLegalOrCustom(OtherOpcode, VT))
4622       return SDValue();
4623   }
4624 
4625   SDValue Op0 = Node->getOperand(0);
4626   SDValue Op1 = Node->getOperand(1);
4627   SDValue combined;
4628   for (SDNode *User : Op0->uses()) {
4629     if (User == Node || User->getOpcode() == ISD::DELETED_NODE ||
4630         User->use_empty())
4631       continue;
4632     // Convert the other matching node(s), too;
4633     // otherwise, the DIVREM may get target-legalized into something
4634     // target-specific that we won't be able to recognize.
4635     unsigned UserOpc = User->getOpcode();
4636     if ((UserOpc == Opcode || UserOpc == OtherOpcode || UserOpc == DivRemOpc) &&
4637         User->getOperand(0) == Op0 &&
4638         User->getOperand(1) == Op1) {
4639       if (!combined) {
4640         if (UserOpc == OtherOpcode) {
4641           SDVTList VTs = DAG.getVTList(VT, VT);
4642           combined = DAG.getNode(DivRemOpc, SDLoc(Node), VTs, Op0, Op1);
4643         } else if (UserOpc == DivRemOpc) {
4644           combined = SDValue(User, 0);
4645         } else {
4646           assert(UserOpc == Opcode);
4647           continue;
4648         }
4649       }
4650       if (UserOpc == ISD::SDIV || UserOpc == ISD::UDIV)
4651         CombineTo(User, combined);
4652       else if (UserOpc == ISD::SREM || UserOpc == ISD::UREM)
4653         CombineTo(User, combined.getValue(1));
4654     }
4655   }
4656   return combined;
4657 }
4658 
4659 static SDValue simplifyDivRem(SDNode *N, SelectionDAG &DAG) {
4660   SDValue N0 = N->getOperand(0);
4661   SDValue N1 = N->getOperand(1);
4662   EVT VT = N->getValueType(0);
4663   SDLoc DL(N);
4664 
4665   unsigned Opc = N->getOpcode();
4666   bool IsDiv = (ISD::SDIV == Opc) || (ISD::UDIV == Opc);
4667   ConstantSDNode *N1C = isConstOrConstSplat(N1);
4668 
4669   // X / undef -> undef
4670   // X % undef -> undef
4671   // X / 0 -> undef
4672   // X % 0 -> undef
4673   // NOTE: This includes vectors where any divisor element is zero/undef.
4674   if (DAG.isUndef(Opc, {N0, N1}))
4675     return DAG.getUNDEF(VT);
4676 
4677   // undef / X -> 0
4678   // undef % X -> 0
4679   if (N0.isUndef())
4680     return DAG.getConstant(0, DL, VT);
4681 
4682   // 0 / X -> 0
4683   // 0 % X -> 0
4684   ConstantSDNode *N0C = isConstOrConstSplat(N0);
4685   if (N0C && N0C->isZero())
4686     return N0;
4687 
4688   // X / X -> 1
4689   // X % X -> 0
4690   if (N0 == N1)
4691     return DAG.getConstant(IsDiv ? 1 : 0, DL, VT);
4692 
4693   // X / 1 -> X
4694   // X % 1 -> 0
4695   // If this is a boolean op (single-bit element type), we can't have
4696   // division-by-zero or remainder-by-zero, so assume the divisor is 1.
4697   // TODO: Similarly, if we're zero-extending a boolean divisor, then assume
4698   // it's a 1.
4699   if ((N1C && N1C->isOne()) || (VT.getScalarType() == MVT::i1))
4700     return IsDiv ? N0 : DAG.getConstant(0, DL, VT);
4701 
4702   return SDValue();
4703 }
4704 
4705 SDValue DAGCombiner::visitSDIV(SDNode *N) {
4706   SDValue N0 = N->getOperand(0);
4707   SDValue N1 = N->getOperand(1);
4708   EVT VT = N->getValueType(0);
4709   EVT CCVT = getSetCCResultType(VT);
4710   SDLoc DL(N);
4711 
4712   // fold (sdiv c1, c2) -> c1/c2
4713   if (SDValue C = DAG.FoldConstantArithmetic(ISD::SDIV, DL, VT, {N0, N1}))
4714     return C;
4715 
4716   // fold vector ops
4717   if (VT.isVector())
4718     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4719       return FoldedVOp;
4720 
4721   // fold (sdiv X, -1) -> 0-X
4722   ConstantSDNode *N1C = isConstOrConstSplat(N1);
4723   if (N1C && N1C->isAllOnes())
4724     return DAG.getNegative(N0, DL, VT);
4725 
4726   // fold (sdiv X, MIN_SIGNED) -> select(X == MIN_SIGNED, 1, 0)
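  // (Only MIN_SIGNED / MIN_SIGNED == 1; any other dividend has a smaller
  // magnitude than the divisor, so the quotient is 0.)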
4727   if (N1C && N1C->isMinSignedValue())
4728     return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ),
4729                          DAG.getConstant(1, DL, VT),
4730                          DAG.getConstant(0, DL, VT));
4731 
4732   if (SDValue V = simplifyDivRem(N, DAG))
4733     return V;
4734 
4735   if (SDValue NewSel = foldBinOpIntoSelect(N))
4736     return NewSel;
4737 
4738   // If we know the sign bits of both operands are zero, strength reduce to a
4739   // udiv instead.  Handles (X&15) /s 4 -> X&15 >> 2
4740   if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0))
4741     return DAG.getNode(ISD::UDIV, DL, N1.getValueType(), N0, N1);
4742 
4743   if (SDValue V = visitSDIVLike(N0, N1, N)) {
4744     // If the corresponding remainder node exists, update its users with
4745     // (Dividend - (Quotient * Divisor)).
4746     if (SDNode *RemNode = DAG.getNodeIfExists(ISD::SREM, N->getVTList(),
4747                                               { N0, N1 })) {
4748       SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, V, N1);
4749       SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
4750       AddToWorklist(Mul.getNode());
4751       AddToWorklist(Sub.getNode());
4752       CombineTo(RemNode, Sub);
4753     }
4754     return V;
4755   }
4756 
4757   // sdiv, srem -> sdivrem
4758   // If the divisor is constant, then return DIVREM only if isIntDivCheap() is
4759   // true.  Otherwise, we break the simplification logic in visitREM().
4760   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4761   if (!N1C || TLI.isIntDivCheap(N->getValueType(0), Attr))
4762     if (SDValue DivRem = useDivRem(N))
4763         return DivRem;
4764 
4765   return SDValue();
4766 }
4767 
4768 static bool isDivisorPowerOfTwo(SDValue Divisor) {
4769   // Helper for determining whether a value is a power-2 constant scalar or a
4770   // vector of such elements.
4771   auto IsPowerOfTwo = [](ConstantSDNode *C) {
4772     if (C->isZero() || C->isOpaque())
4773       return false;
4774     if (C->getAPIntValue().isPowerOf2())
4775       return true;
4776     if (C->getAPIntValue().isNegatedPowerOf2())
4777       return true;
4778     return false;
4779   };
4780 
4781   return ISD::matchUnaryPredicate(Divisor, IsPowerOfTwo);
4782 }
4783 
4784 SDValue DAGCombiner::visitSDIVLike(SDValue N0, SDValue N1, SDNode *N) {
4785   SDLoc DL(N);
4786   EVT VT = N->getValueType(0);
4787   EVT CCVT = getSetCCResultType(VT);
4788   unsigned BitWidth = VT.getScalarSizeInBits();
4789 
4790   // fold (sdiv X, pow2) -> simple ops after legalize
4791   // FIXME: We check for the exact bit here because the generic lowering gives
4792   // better results in that case. The target-specific lowering should learn how
4793   // to handle exact sdivs efficiently.
4794   if (!N->getFlags().hasExact() && isDivisorPowerOfTwo(N1)) {
4795     // Target-specific implementation of sdiv x, pow2.
4796     if (SDValue Res = BuildSDIVPow2(N))
4797       return Res;
4798 
4799     // Create constants that are functions of the shift amount value.
4800     EVT ShiftAmtTy = getShiftAmountTy(N0.getValueType());
4801     SDValue Bits = DAG.getConstant(BitWidth, DL, ShiftAmtTy);
4802     SDValue C1 = DAG.getNode(ISD::CTTZ, DL, VT, N1);
4803     C1 = DAG.getZExtOrTrunc(C1, DL, ShiftAmtTy);
4804     SDValue Inexact = DAG.getNode(ISD::SUB, DL, ShiftAmtTy, Bits, C1);
4805     if (!isConstantOrConstantVector(Inexact))
4806       return SDValue();
4807 
4808     // Splat the sign bit into the register
4809     SDValue Sign = DAG.getNode(ISD::SRA, DL, VT, N0,
4810                                DAG.getConstant(BitWidth - 1, DL, ShiftAmtTy));
4811     AddToWorklist(Sign.getNode());
4812 
4813     // Add (N0 < 0) ? abs2 - 1 : 0;
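    // Worked i32 example for X /s 8: Inexact == 29, so Srl == 7 exactly when
    // X is negative; X + 7 then rounds the arithmetic shift toward zero.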
4814     SDValue Srl = DAG.getNode(ISD::SRL, DL, VT, Sign, Inexact);
4815     AddToWorklist(Srl.getNode());
4816     SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0, Srl);
4817     AddToWorklist(Add.getNode());
4818     SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Add, C1);
4819     AddToWorklist(Sra.getNode());
4820 
4821     // Special case: (sdiv X, 1) -> X
4822     // Special Case: (sdiv X, -1) -> 0-X
4823     SDValue One = DAG.getConstant(1, DL, VT);
4824     SDValue AllOnes = DAG.getAllOnesConstant(DL, VT);
4825     SDValue IsOne = DAG.getSetCC(DL, CCVT, N1, One, ISD::SETEQ);
4826     SDValue IsAllOnes = DAG.getSetCC(DL, CCVT, N1, AllOnes, ISD::SETEQ);
4827     SDValue IsOneOrAllOnes = DAG.getNode(ISD::OR, DL, CCVT, IsOne, IsAllOnes);
4828     Sra = DAG.getSelect(DL, VT, IsOneOrAllOnes, N0, Sra);
4829 
4830     // If dividing by a positive value, we're done. Otherwise, the result must
4831     // be negated.
4832     SDValue Zero = DAG.getConstant(0, DL, VT);
4833     SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, Zero, Sra);
4834 
4835     // FIXME: Use SELECT_CC once we improve SELECT_CC constant-folding.
4836     SDValue IsNeg = DAG.getSetCC(DL, CCVT, N1, Zero, ISD::SETLT);
4837     SDValue Res = DAG.getSelect(DL, VT, IsNeg, Sub, Sra);
4838     return Res;
4839   }
4840 
4841   // If integer divide is expensive and we satisfy the requirements, emit an
4842   // alternate sequence.  Targets may check function attributes for size/speed
4843   // trade-offs.
4844   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4845   if (isConstantOrConstantVector(N1) &&
4846       !TLI.isIntDivCheap(N->getValueType(0), Attr))
4847     if (SDValue Op = BuildSDIV(N))
4848       return Op;
4849 
4850   return SDValue();
4851 }
4852 
4853 SDValue DAGCombiner::visitUDIV(SDNode *N) {
4854   SDValue N0 = N->getOperand(0);
4855   SDValue N1 = N->getOperand(1);
4856   EVT VT = N->getValueType(0);
4857   EVT CCVT = getSetCCResultType(VT);
4858   SDLoc DL(N);
4859 
4860   // fold (udiv c1, c2) -> c1/c2
4861   if (SDValue C = DAG.FoldConstantArithmetic(ISD::UDIV, DL, VT, {N0, N1}))
4862     return C;
4863 
4864   // fold vector ops
4865   if (VT.isVector())
4866     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4867       return FoldedVOp;
4868 
4869   // fold (udiv X, -1) -> select(X == -1, 1, 0)
4870   ConstantSDNode *N1C = isConstOrConstSplat(N1);
4871   if (N1C && N1C->isAllOnes() && CCVT.isVector() == VT.isVector()) {
4872     return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ),
4873                          DAG.getConstant(1, DL, VT),
4874                          DAG.getConstant(0, DL, VT));
4875   }
4876 
4877   if (SDValue V = simplifyDivRem(N, DAG))
4878     return V;
4879 
4880   if (SDValue NewSel = foldBinOpIntoSelect(N))
4881     return NewSel;
4882 
4883   if (SDValue V = visitUDIVLike(N0, N1, N)) {
4884     // If the corresponding remainder node exists, update its users with
4885     // (Dividend - (Quotient * Divisor)).
4886     if (SDNode *RemNode = DAG.getNodeIfExists(ISD::UREM, N->getVTList(),
4887                                               { N0, N1 })) {
4888       SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, V, N1);
4889       SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
4890       AddToWorklist(Mul.getNode());
4891       AddToWorklist(Sub.getNode());
4892       CombineTo(RemNode, Sub);
4893     }
4894     return V;
4895   }
4896 
4897   // udiv, urem -> udivrem
4898   // If the divisor is constant, then return DIVREM only if isIntDivCheap() is
4899   // true.  Otherwise, we break the simplification logic in visitREM().
4900   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4901   if (!N1C || TLI.isIntDivCheap(N->getValueType(0), Attr))
4902     if (SDValue DivRem = useDivRem(N))
4903         return DivRem;
4904 
4905   return SDValue();
4906 }
4907 
4908 SDValue DAGCombiner::visitUDIVLike(SDValue N0, SDValue N1, SDNode *N) {
4909   SDLoc DL(N);
4910   EVT VT = N->getValueType(0);
4911 
4912   // fold (udiv x, (1 << c)) -> x >>u c
4913   if (isConstantOrConstantVector(N1, /*NoOpaques*/ true)) {
4914     if (SDValue LogBase2 = BuildLogBase2(N1, DL)) {
4915       AddToWorklist(LogBase2.getNode());
4916 
4917       EVT ShiftVT = getShiftAmountTy(N0.getValueType());
4918       SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
4919       AddToWorklist(Trunc.getNode());
4920       return DAG.getNode(ISD::SRL, DL, VT, N0, Trunc);
4921     }
4922   }
4923 
4924   // fold (udiv x, (shl c, y)) -> x >>u (log2(c)+y) iff c is power of 2
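  // e.g. udiv x, (shl 4, y) --> x >>u (2 + y).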
4925   if (N1.getOpcode() == ISD::SHL) {
4926     SDValue N10 = N1.getOperand(0);
4927     if (isConstantOrConstantVector(N10, /*NoOpaques*/ true)) {
4928       if (SDValue LogBase2 = BuildLogBase2(N10, DL)) {
4929         AddToWorklist(LogBase2.getNode());
4930 
4931         EVT ADDVT = N1.getOperand(1).getValueType();
4932         SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ADDVT);
4933         AddToWorklist(Trunc.getNode());
4934         SDValue Add = DAG.getNode(ISD::ADD, DL, ADDVT, N1.getOperand(1), Trunc);
4935         AddToWorklist(Add.getNode());
4936         return DAG.getNode(ISD::SRL, DL, VT, N0, Add);
4937       }
4938     }
4939   }
4940 
4941   // fold (udiv x, c) -> alternate
4942   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4943   if (isConstantOrConstantVector(N1) &&
4944       !TLI.isIntDivCheap(N->getValueType(0), Attr))
4945     if (SDValue Op = BuildUDIV(N))
4946       return Op;
4947 
4948   return SDValue();
4949 }
4950 
4951 SDValue DAGCombiner::buildOptimizedSREM(SDValue N0, SDValue N1, SDNode *N) {
4952   if (!N->getFlags().hasExact() && isDivisorPowerOfTwo(N1) &&
4953       !DAG.doesNodeExist(ISD::SDIV, N->getVTList(), {N0, N1})) {
4954     // Target-specific implementation of srem x, pow2.
4955     if (SDValue Res = BuildSREMPow2(N))
4956       return Res;
4957   }
4958   return SDValue();
4959 }
4960 
4961 // Handles both ISD::SREM and ISD::UREM.
4962 SDValue DAGCombiner::visitREM(SDNode *N) {
4963   unsigned Opcode = N->getOpcode();
4964   SDValue N0 = N->getOperand(0);
4965   SDValue N1 = N->getOperand(1);
4966   EVT VT = N->getValueType(0);
4967   EVT CCVT = getSetCCResultType(VT);
4968 
4969   bool isSigned = (Opcode == ISD::SREM);
4970   SDLoc DL(N);
4971 
4972   // fold (rem c1, c2) -> c1%c2
4973   if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
4974     return C;
4975 
4976   // fold (urem X, -1) -> select(FX == -1, 0, FX)
4977   // Freeze the numerator to avoid a miscompile with an undefined value.
4978   if (!isSigned && llvm::isAllOnesOrAllOnesSplat(N1, /*AllowUndefs*/ false) &&
4979       CCVT.isVector() == VT.isVector()) {
4980     SDValue F0 = DAG.getFreeze(N0);
4981     SDValue EqualsNeg1 = DAG.getSetCC(DL, CCVT, F0, N1, ISD::SETEQ);
4982     return DAG.getSelect(DL, VT, EqualsNeg1, DAG.getConstant(0, DL, VT), F0);
4983   }
4984 
4985   if (SDValue V = simplifyDivRem(N, DAG))
4986     return V;
4987 
4988   if (SDValue NewSel = foldBinOpIntoSelect(N))
4989     return NewSel;
4990 
4991   if (isSigned) {
4992     // If we know the sign bits of both operands are zero, strength reduce to a
4993     // urem instead.  Handles (X & 0x0FFFFFFF) %s 16 -> X&15
4994     if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0))
4995       return DAG.getNode(ISD::UREM, DL, VT, N0, N1);
4996   } else {
4997     if (DAG.isKnownToBeAPowerOfTwo(N1)) {
4998       // fold (urem x, pow2) -> (and x, pow2-1)
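      // e.g. x %u 8 --> x & 7.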
4999       SDValue NegOne = DAG.getAllOnesConstant(DL, VT);
5000       SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N1, NegOne);
5001       AddToWorklist(Add.getNode());
5002       return DAG.getNode(ISD::AND, DL, VT, N0, Add);
5003     }
5004     // fold (urem x, (shl pow2, y)) -> (and x, (add (shl pow2, y), -1))
5005     // fold (urem x, (lshr pow2, y)) -> (and x, (add (lshr pow2, y), -1))
5006     // TODO: We should sink the following into isKnownToBePowerOfTwo
5007     // using a OrZero parameter analogous to our handling in ValueTracking.
5008     if ((N1.getOpcode() == ISD::SHL || N1.getOpcode() == ISD::SRL) &&
5009         DAG.isKnownToBeAPowerOfTwo(N1.getOperand(0))) {
5010       SDValue NegOne = DAG.getAllOnesConstant(DL, VT);
5011       SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N1, NegOne);
5012       AddToWorklist(Add.getNode());
5013       return DAG.getNode(ISD::AND, DL, VT, N0, Add);
5014     }
5015   }
5016 
5017   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
5018 
5019   // If X/C can be simplified by the division-by-constant logic, lower
5020   // X%C to the equivalent of X-X/C*C.
5021   // Reuse the SDIVLike/UDIVLike combines - to avoid mangling nodes, the
5022   // speculative DIV must not cause a DIVREM conversion.  We guard against this
5023   // by skipping the simplification if isIntDivCheap().  When div is not cheap,
5024   // combine will not return a DIVREM.  Regardless, checking cheapness here
5025   // makes sense since the simplification results in fatter code.
5026   if (DAG.isKnownNeverZero(N1) && !TLI.isIntDivCheap(VT, Attr)) {
5027     if (isSigned) {
5028       // check if we can build faster implementation for srem
5029       if (SDValue OptimizedRem = buildOptimizedSREM(N0, N1, N))
5030         return OptimizedRem;
5031     }
5032 
5033     SDValue OptimizedDiv =
5034         isSigned ? visitSDIVLike(N0, N1, N) : visitUDIVLike(N0, N1, N);
5035     if (OptimizedDiv.getNode() && OptimizedDiv.getNode() != N) {
5036       // If the equivalent Div node also exists, update its users.
5037       unsigned DivOpcode = isSigned ? ISD::SDIV : ISD::UDIV;
5038       if (SDNode *DivNode = DAG.getNodeIfExists(DivOpcode, N->getVTList(),
5039                                                 { N0, N1 }))
5040         CombineTo(DivNode, OptimizedDiv);
5041       SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, OptimizedDiv, N1);
5042       SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
5043       AddToWorklist(OptimizedDiv.getNode());
5044       AddToWorklist(Mul.getNode());
5045       return Sub;
5046     }
5047   }
5048 
5049   // sdiv, srem -> sdivrem
5050   if (SDValue DivRem = useDivRem(N))
5051     return DivRem.getValue(1);
5052 
5053   return SDValue();
5054 }
5055 
5056 SDValue DAGCombiner::visitMULHS(SDNode *N) {
5057   SDValue N0 = N->getOperand(0);
5058   SDValue N1 = N->getOperand(1);
5059   EVT VT = N->getValueType(0);
5060   SDLoc DL(N);
5061 
5062   // fold (mulhs c1, c2)
5063   if (SDValue C = DAG.FoldConstantArithmetic(ISD::MULHS, DL, VT, {N0, N1}))
5064     return C;
5065 
5066   // canonicalize constant to RHS.
5067   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
5068       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
5069     return DAG.getNode(ISD::MULHS, DL, N->getVTList(), N1, N0);
5070 
5071   if (VT.isVector()) {
5072     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5073       return FoldedVOp;
5074 
5075     // fold (mulhs x, 0) -> 0
5076     // do not return N1, because it may contain undef elements.
5077     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
5078       return DAG.getConstant(0, DL, VT);
5079   }
5080 
5081   // fold (mulhs x, 0) -> 0
5082   if (isNullConstant(N1))
5083     return N1;
5084 
5085   // fold (mulhs x, 1) -> (sra x, size(x)-1)
5086   if (isOneConstant(N1))
5087     return DAG.getNode(
5088         ISD::SRA, DL, VT, N0,
5089         DAG.getShiftAmountConstant(N0.getScalarValueSizeInBits() - 1, VT, DL));
5090 
5091   // fold (mulhs x, undef) -> 0
5092   if (N0.isUndef() || N1.isUndef())
5093     return DAG.getConstant(0, DL, VT);
5094 
5095   // If the type twice as wide is legal, transform the mulhs to a wider multiply
5096   // plus a shift.
5097   if (!TLI.isOperationLegalOrCustom(ISD::MULHS, VT) && VT.isSimple() &&
5098       !VT.isVector()) {
5099     MVT Simple = VT.getSimpleVT();
5100     unsigned SimpleSize = Simple.getSizeInBits();
5101     EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
5102     if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
5103       N0 = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N0);
5104       N1 = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N1);
5105       N1 = DAG.getNode(ISD::MUL, DL, NewVT, N0, N1);
5106       N1 = DAG.getNode(ISD::SRL, DL, NewVT, N1,
5107                        DAG.getShiftAmountConstant(SimpleSize, NewVT, DL));
5108       return DAG.getNode(ISD::TRUNCATE, DL, VT, N1);
5109     }
5110   }
5111 
5112   return SDValue();
5113 }
5114 
5115 SDValue DAGCombiner::visitMULHU(SDNode *N) {
5116   SDValue N0 = N->getOperand(0);
5117   SDValue N1 = N->getOperand(1);
5118   EVT VT = N->getValueType(0);
5119   SDLoc DL(N);
5120 
5121   // fold (mulhu c1, c2)
5122   if (SDValue C = DAG.FoldConstantArithmetic(ISD::MULHU, DL, VT, {N0, N1}))
5123     return C;
5124 
5125   // canonicalize constant to RHS.
5126   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
5127       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
5128     return DAG.getNode(ISD::MULHU, DL, N->getVTList(), N1, N0);
5129 
5130   if (VT.isVector()) {
5131     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5132       return FoldedVOp;
5133 
5134     // fold (mulhu x, 0) -> 0
5135     // do not return N1, because it may contain undef elements.
5136     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
5137       return DAG.getConstant(0, DL, VT);
5138   }
5139 
5140   // fold (mulhu x, 0) -> 0
5141   if (isNullConstant(N1))
5142     return N1;
5143 
5144   // fold (mulhu x, 1) -> 0
5145   if (isOneConstant(N1))
5146     return DAG.getConstant(0, DL, VT);
5147 
5148   // fold (mulhu x, undef) -> 0
5149   if (N0.isUndef() || N1.isUndef())
5150     return DAG.getConstant(0, DL, VT);
5151 
5152   // fold (mulhu x, (1 << c)) -> x >> (bitwidth - c)
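  // e.g. for i32: mulhu x, 16 == (zext(x) * 16) >> 32 == x >>u 28.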
5153   if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
5154       hasOperation(ISD::SRL, VT)) {
5155     if (SDValue LogBase2 = BuildLogBase2(N1, DL)) {
5156       unsigned NumEltBits = VT.getScalarSizeInBits();
5157       SDValue SRLAmt = DAG.getNode(
5158           ISD::SUB, DL, VT, DAG.getConstant(NumEltBits, DL, VT), LogBase2);
5159       EVT ShiftVT = getShiftAmountTy(N0.getValueType());
5160       SDValue Trunc = DAG.getZExtOrTrunc(SRLAmt, DL, ShiftVT);
5161       return DAG.getNode(ISD::SRL, DL, VT, N0, Trunc);
5162     }
5163   }
5164 
5165   // If the type twice as wide is legal, transform the mulhu to a wider multiply
5166   // plus a shift.
5167   if (!TLI.isOperationLegalOrCustom(ISD::MULHU, VT) && VT.isSimple() &&
5168       !VT.isVector()) {
5169     MVT Simple = VT.getSimpleVT();
5170     unsigned SimpleSize = Simple.getSizeInBits();
5171     EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
5172     if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
5173       N0 = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N0);
5174       N1 = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N1);
5175       N1 = DAG.getNode(ISD::MUL, DL, NewVT, N0, N1);
5176       N1 = DAG.getNode(ISD::SRL, DL, NewVT, N1,
5177                        DAG.getShiftAmountConstant(SimpleSize, NewVT, DL));
5178       return DAG.getNode(ISD::TRUNCATE, DL, VT, N1);
5179     }
5180   }
5181 
5182   // Simplify the operands using demanded-bits information.
5183   // We don't have demanded bits support for MULHU so this just enables constant
5184   // folding based on known bits.
5185   if (SimplifyDemandedBits(SDValue(N, 0)))
5186     return SDValue(N, 0);
5187 
5188   return SDValue();
5189 }
5190 
5191 SDValue DAGCombiner::visitAVG(SDNode *N) {
5192   unsigned Opcode = N->getOpcode();
5193   SDValue N0 = N->getOperand(0);
5194   SDValue N1 = N->getOperand(1);
5195   EVT VT = N->getValueType(0);
5196   SDLoc DL(N);
5197   bool IsSigned = Opcode == ISD::AVGCEILS || Opcode == ISD::AVGFLOORS;
5198 
5199   // fold (avg c1, c2)
5200   if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
5201     return C;
5202 
5203   // canonicalize constant to RHS.
5204   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
5205       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
5206     return DAG.getNode(Opcode, DL, N->getVTList(), N1, N0);
5207 
5208   if (VT.isVector())
5209     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5210       return FoldedVOp;
5211 
5212   // fold (avg x, undef) -> x
5213   if (N0.isUndef())
5214     return N1;
5215   if (N1.isUndef())
5216     return N0;
5217 
5218   // fold (avg x, x) --> x
5219   if (N0 == N1 && Level >= AfterLegalizeTypes)
5220     return N0;
5221 
5222   // fold (avgfloor x, 0) -> x >> 1
5223   SDValue X, Y;
5224   if (sd_match(N, m_c_BinOp(ISD::AVGFLOORS, m_Value(X), m_Zero())))
5225     return DAG.getNode(ISD::SRA, DL, VT, X,
5226                        DAG.getShiftAmountConstant(1, VT, DL));
5227   if (sd_match(N, m_c_BinOp(ISD::AVGFLOORU, m_Value(X), m_Zero())))
5228     return DAG.getNode(ISD::SRL, DL, VT, X,
5229                        DAG.getShiftAmountConstant(1, VT, DL));
5230 
5231   // fold avgu(zext(x), zext(y)) -> zext(avgu(x, y))
5232   // fold avgs(sext(x), sext(y)) -> sext(avgs(x, y))
5233   if (!IsSigned &&
5234       sd_match(N, m_BinOp(Opcode, m_ZExt(m_Value(X)), m_ZExt(m_Value(Y)))) &&
5235       X.getValueType() == Y.getValueType() &&
5236       hasOperation(Opcode, X.getValueType())) {
5237     SDValue AvgU = DAG.getNode(Opcode, DL, X.getValueType(), X, Y);
5238     return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, AvgU);
5239   }
5240   if (IsSigned &&
5241       sd_match(N, m_BinOp(Opcode, m_SExt(m_Value(X)), m_SExt(m_Value(Y)))) &&
5242       X.getValueType() == Y.getValueType() &&
5243       hasOperation(Opcode, X.getValueType())) {
5244     SDValue AvgS = DAG.getNode(Opcode, DL, X.getValueType(), X, Y);
5245     return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, AvgS);
5246   }
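  // The narrowing above is safe because an average never exceeds the larger
  // of its operands, so it cannot overflow the narrow type. E.g. for i8
  // values x = 200, y = 100, avgflooru(x, y) == 150 still fits in i8, so
  // computing the average pre-extension and extending the result is
  // equivalent.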
5247 
5248   // Fold avgflooru(x,y) -> avgceilu(x,y-1) iff y != 0
5249   // Fold avgflooru(x,y) -> avgceilu(x-1,y) iff x != 0
5250   // Check if avgflooru isn't legal/custom but avgceilu is.
5251   if (Opcode == ISD::AVGFLOORU && !hasOperation(ISD::AVGFLOORU, VT) &&
5252       (!LegalOperations || hasOperation(ISD::AVGCEILU, VT))) {
5253     if (DAG.isKnownNeverZero(N1))
5254       return DAG.getNode(
5255           ISD::AVGCEILU, DL, VT, N0,
5256           DAG.getNode(ISD::ADD, DL, VT, N1, DAG.getAllOnesConstant(DL, VT)));
5257     if (DAG.isKnownNeverZero(N0))
5258       return DAG.getNode(
5259           ISD::AVGCEILU, DL, VT, N1,
5260           DAG.getNode(ISD::ADD, DL, VT, N0, DAG.getAllOnesConstant(DL, VT)));
5261   }
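  // This rewrite relies on avgflooru(x, y) == (x + y) >> 1 and
  // avgceilu(x, y) == (x + y + 1) >> 1, both computed as if in a wider
  // domain, so for y != 0: (x + y) >> 1 == (x + (y - 1) + 1) >> 1.
  // E.g. x = 6, y = 3: avgflooru(6, 3) == 4 and
  // avgceilu(6, 2) == (6 + 2 + 1) >> 1 == 4. The y - 1 (or x - 1) is
  // emitted as an ADD of the all-ones constant.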
5262 
5263   return SDValue();
5264 }
5265 
5266 SDValue DAGCombiner::visitABD(SDNode *N) {
5267   unsigned Opcode = N->getOpcode();
5268   SDValue N0 = N->getOperand(0);
5269   SDValue N1 = N->getOperand(1);
5270   EVT VT = N->getValueType(0);
5271   SDLoc DL(N);
5272 
5273   // fold (abd c1, c2)
5274   if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
5275     return C;
5276 
5277   // canonicalize constant to RHS.
5278   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
5279       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
5280     return DAG.getNode(Opcode, DL, N->getVTList(), N1, N0);
5281 
5282   if (VT.isVector())
5283     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5284       return FoldedVOp;
5285 
5286   // fold (abd x, undef) -> 0
5287   if (N0.isUndef() || N1.isUndef())
5288     return DAG.getConstant(0, DL, VT);
5289 
5290   SDValue X;
5291 
5292   // fold (abds x, 0) -> abs x
5293   if (sd_match(N, m_c_BinOp(ISD::ABDS, m_Value(X), m_Zero())) &&
5294       (!LegalOperations || hasOperation(ISD::ABS, VT)))
5295     return DAG.getNode(ISD::ABS, DL, VT, X);
5296 
5297   // fold (abdu x, 0) -> x
5298   if (sd_match(N, m_c_BinOp(ISD::ABDU, m_Value(X), m_Zero())))
5299     return X;
5300 
5301   // fold (abds x, y) -> (abdu x, y) iff both args are known positive
5302   if (Opcode == ISD::ABDS && hasOperation(ISD::ABDU, VT) &&
5303       DAG.SignBitIsZero(N0) && DAG.SignBitIsZero(N1))
5304     return DAG.getNode(ISD::ABDU, DL, VT, N1, N0);
5305 
5306   return SDValue();
5307 }
5308 
5309 /// Perform optimizations common to nodes that compute two values. LoOp and HiOp
5310 /// give the opcodes for the two computations that are being performed. Return
5311 /// the combined value if a simplification was made.
5312 SDValue DAGCombiner::SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp,
5313                                                 unsigned HiOp) {
5314   // If the high half is not needed, just compute the low half.
5315   bool HiExists = N->hasAnyUseOfValue(1);
5316   if (!HiExists && (!LegalOperations ||
5317                     TLI.isOperationLegalOrCustom(LoOp, N->getValueType(0)))) {
5318     SDValue Res = DAG.getNode(LoOp, SDLoc(N), N->getValueType(0), N->ops());
5319     return CombineTo(N, Res, Res);
5320   }
5321 
5322   // If the low half is not needed, just compute the high half.
5323   bool LoExists = N->hasAnyUseOfValue(0);
5324   if (!LoExists && (!LegalOperations ||
5325                     TLI.isOperationLegalOrCustom(HiOp, N->getValueType(1)))) {
5326     SDValue Res = DAG.getNode(HiOp, SDLoc(N), N->getValueType(1), N->ops());
5327     return CombineTo(N, Res, Res);
5328   }
5329 
5330   // If both halves are used, return as it is.
5331   if (LoExists && HiExists)
5332     return SDValue();
5333 
5334   // If the two computed results can be simplified separately, separate them.
5335   if (LoExists) {
5336     SDValue Lo = DAG.getNode(LoOp, SDLoc(N), N->getValueType(0), N->ops());
5337     AddToWorklist(Lo.getNode());
5338     SDValue LoOpt = combine(Lo.getNode());
5339     if (LoOpt.getNode() && LoOpt.getNode() != Lo.getNode() &&
5340         (!LegalOperations ||
5341          TLI.isOperationLegalOrCustom(LoOpt.getOpcode(), LoOpt.getValueType())))
5342       return CombineTo(N, LoOpt, LoOpt);
5343   }
5344 
5345   if (HiExists) {
5346     SDValue Hi = DAG.getNode(HiOp, SDLoc(N), N->getValueType(1), N->ops());
5347     AddToWorklist(Hi.getNode());
5348     SDValue HiOpt = combine(Hi.getNode());
5349     if (HiOpt.getNode() && HiOpt != Hi &&
5350         (!LegalOperations ||
5351          TLI.isOperationLegalOrCustom(HiOpt.getOpcode(), HiOpt.getValueType())))
5352       return CombineTo(N, HiOpt, HiOpt);
5353   }
5354 
5355   return SDValue();
5356 }
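// For example, an SMUL_LOHI whose high result (value 1) is unused collapses
// to a plain MUL via the first case above, and one whose low result (value
// 0) is unused collapses to MULHS via the second case.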
5357 
5358 SDValue DAGCombiner::visitSMUL_LOHI(SDNode *N) {
5359   if (SDValue Res = SimplifyNodeWithTwoResults(N, ISD::MUL, ISD::MULHS))
5360     return Res;
5361 
5362   SDValue N0 = N->getOperand(0);
5363   SDValue N1 = N->getOperand(1);
5364   EVT VT = N->getValueType(0);
5365   SDLoc DL(N);
5366 
5367   // Constant fold.
5368   if (isa<ConstantSDNode>(N0) && isa<ConstantSDNode>(N1))
5369     return DAG.getNode(ISD::SMUL_LOHI, DL, N->getVTList(), N0, N1);
5370 
5371   // canonicalize constant to RHS (the vector constant doesn't have to be a splat)
5372   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
5373       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
5374     return DAG.getNode(ISD::SMUL_LOHI, DL, N->getVTList(), N1, N0);
5375 
5376   // If the type that is twice as wide is legal, transform the smul_lohi to a
5377   // wider multiply plus a shift.
5378   if (VT.isSimple() && !VT.isVector()) {
5379     MVT Simple = VT.getSimpleVT();
5380     unsigned SimpleSize = Simple.getSizeInBits();
5381     EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
5382     if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
5383       SDValue Lo = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N0);
5384       SDValue Hi = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N1);
5385       Lo = DAG.getNode(ISD::MUL, DL, NewVT, Lo, Hi);
5386       // Compute the high part (result value 1).
5387       Hi = DAG.getNode(ISD::SRL, DL, NewVT, Lo,
5388                        DAG.getShiftAmountConstant(SimpleSize, NewVT, DL));
5389       Hi = DAG.getNode(ISD::TRUNCATE, DL, VT, Hi);
5390       // Compute the low part (result value 0).
5391       Lo = DAG.getNode(ISD::TRUNCATE, DL, VT, Lo);
5392       return CombineTo(N, Lo, Hi);
5393     }
5394   }
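  // Worked example of the widening above: for VT == i8 via i16, with
  // N0 = -3 and N1 = 5 the i16 product is -15 (0xFFF1); the low result is
  // trunc(0xFFF1) == 0xF1 and the high result is trunc(0xFFF1 >> 8) == 0xFF,
  // matching the two halves of the signed 16-bit product.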
5395 
5396   return SDValue();
5397 }
5398 
5399 SDValue DAGCombiner::visitUMUL_LOHI(SDNode *N) {
5400   if (SDValue Res = SimplifyNodeWithTwoResults(N, ISD::MUL, ISD::MULHU))
5401     return Res;
5402 
5403   SDValue N0 = N->getOperand(0);
5404   SDValue N1 = N->getOperand(1);
5405   EVT VT = N->getValueType(0);
5406   SDLoc DL(N);
5407 
5408   // Constant fold.
5409   if (isa<ConstantSDNode>(N0) && isa<ConstantSDNode>(N1))
5410     return DAG.getNode(ISD::UMUL_LOHI, DL, N->getVTList(), N0, N1);
5411 
5412   // canonicalize constant to RHS (the vector constant doesn't have to be a splat)
5413   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
5414       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
5415     return DAG.getNode(ISD::UMUL_LOHI, DL, N->getVTList(), N1, N0);
5416 
5417   // (umul_lohi N0, 0) -> (0, 0)
5418   if (isNullConstant(N1)) {
5419     SDValue Zero = DAG.getConstant(0, DL, VT);
5420     return CombineTo(N, Zero, Zero);
5421   }
5422 
5423   // (umul_lohi N0, 1) -> (N0, 0)
5424   if (isOneConstant(N1)) {
5425     SDValue Zero = DAG.getConstant(0, DL, VT);
5426     return CombineTo(N, N0, Zero);
5427   }
5428 
5429   // If the type that is twice as wide is legal, transform the umul_lohi to a
5430   // wider multiply plus a shift.
5431   if (VT.isSimple() && !VT.isVector()) {
5432     MVT Simple = VT.getSimpleVT();
5433     unsigned SimpleSize = Simple.getSizeInBits();
5434     EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
5435     if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
5436       SDValue Lo = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N0);
5437       SDValue Hi = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N1);
5438       Lo = DAG.getNode(ISD::MUL, DL, NewVT, Lo, Hi);
5439       // Compute the high part (result value 1).
5440       Hi = DAG.getNode(ISD::SRL, DL, NewVT, Lo,
5441                        DAG.getShiftAmountConstant(SimpleSize, NewVT, DL));
5442       Hi = DAG.getNode(ISD::TRUNCATE, DL, VT, Hi);
5443       // Compute the low part (result value 0).
5444       Lo = DAG.getNode(ISD::TRUNCATE, DL, VT, Lo);
5445       return CombineTo(N, Lo, Hi);
5446     }
5447   }
5448 
5449   return SDValue();
5450 }
5451 
5452 SDValue DAGCombiner::visitMULO(SDNode *N) {
5453   SDValue N0 = N->getOperand(0);
5454   SDValue N1 = N->getOperand(1);
5455   EVT VT = N0.getValueType();
5456   bool IsSigned = (ISD::SMULO == N->getOpcode());
5457 
5458   EVT CarryVT = N->getValueType(1);
5459   SDLoc DL(N);
5460 
5461   ConstantSDNode *N0C = isConstOrConstSplat(N0);
5462   ConstantSDNode *N1C = isConstOrConstSplat(N1);
5463 
5464   // fold operation with constant operands.
5465   // TODO: Move this to FoldConstantArithmetic when it supports nodes with
5466   // multiple results.
5467   if (N0C && N1C) {
5468     bool Overflow;
5469     APInt Result =
5470         IsSigned ? N0C->getAPIntValue().smul_ov(N1C->getAPIntValue(), Overflow)
5471                  : N0C->getAPIntValue().umul_ov(N1C->getAPIntValue(), Overflow);
5472     return CombineTo(N, DAG.getConstant(Result, DL, VT),
5473                      DAG.getBoolConstant(Overflow, DL, CarryVT, CarryVT));
5474   }
5475 
5476   // canonicalize constant to RHS.
5477   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
5478       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
5479     return DAG.getNode(N->getOpcode(), DL, N->getVTList(), N1, N0);
5480 
5481   // fold (mulo x, 0) -> 0 + no carry out
5482   if (isNullOrNullSplat(N1))
5483     return CombineTo(N, DAG.getConstant(0, DL, VT),
5484                      DAG.getConstant(0, DL, CarryVT));
5485 
5486   // (mulo x, 2) -> (addo x, x)
5487   // FIXME: This needs a freeze.
5488   if (N1C && N1C->getAPIntValue() == 2 &&
5489       (!IsSigned || VT.getScalarSizeInBits() > 2))
5490     return DAG.getNode(IsSigned ? ISD::SADDO : ISD::UADDO, DL,
5491                        N->getVTList(), N0, N0);
5492 
5493   // A 1 bit SMULO overflows if both inputs are 1.
5494   if (IsSigned && VT.getScalarSizeInBits() == 1) {
5495     SDValue And = DAG.getNode(ISD::AND, DL, VT, N0, N1);
5496     SDValue Cmp = DAG.getSetCC(DL, CarryVT, And,
5497                                DAG.getConstant(0, DL, VT), ISD::SETNE);
5498     return CombineTo(N, And, Cmp);
5499   }
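  // Rationale: the only i1 values are 0 and -1, so the sole overflowing
  // product is (-1) * (-1) == +1, which is not representable in i1. That
  // happens exactly when both inputs are -1, i.e. when (and N0, N1) is
  // nonzero; the AND is also the correct (wrapped) low result.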
5500 
5501   // If it cannot overflow, transform into a mul.
5502   if (DAG.willNotOverflowMul(IsSigned, N0, N1))
5503     return CombineTo(N, DAG.getNode(ISD::MUL, DL, VT, N0, N1),
5504                      DAG.getConstant(0, DL, CarryVT));
5505   return SDValue();
5506 }
5507 
5508 // Function to calculate whether the Min/Max pair of SDNodes (potentially
5509 // swapped around) make a signed saturate pattern, clamping to between a signed
5510 // saturate of -2^(BW-1) and 2^(BW-1)-1, or an unsigned saturate of 0 and 2^BW.
5511 // Returns the node being clamped and the bitwidth of the clamp in BW. Should
5512 // work with both SMIN/SMAX nodes and setcc/select combo. The operands are the
5513 // same as SimplifySelectCC. N0<N1 ? N2 : N3.
5514 static SDValue isSaturatingMinMax(SDValue N0, SDValue N1, SDValue N2,
5515                                   SDValue N3, ISD::CondCode CC, unsigned &BW,
5516                                   bool &Unsigned, SelectionDAG &DAG) {
5517   auto isSignedMinMax = [&](SDValue N0, SDValue N1, SDValue N2, SDValue N3,
5518                             ISD::CondCode CC) {
5519     // The compare and select operand should be the same or the select operands
5520     // should be truncated versions of the comparison.
5521     if (N0 != N2 && (N2.getOpcode() != ISD::TRUNCATE || N0 != N2.getOperand(0)))
5522       return 0;
5523     // The constants need to be the same or a truncated version of each other.
5524     ConstantSDNode *N1C = isConstOrConstSplat(peekThroughTruncates(N1));
5525     ConstantSDNode *N3C = isConstOrConstSplat(peekThroughTruncates(N3));
5526     if (!N1C || !N3C)
5527       return 0;
5528     const APInt &C1 = N1C->getAPIntValue().trunc(N1.getScalarValueSizeInBits());
5529     const APInt &C2 = N3C->getAPIntValue().trunc(N3.getScalarValueSizeInBits());
5530     if (C1.getBitWidth() < C2.getBitWidth() || C1 != C2.sext(C1.getBitWidth()))
5531       return 0;
5532     return CC == ISD::SETLT ? ISD::SMIN : (CC == ISD::SETGT ? ISD::SMAX : 0);
5533   };
5534 
5535   // Check the initial value is a SMIN/SMAX equivalent.
5536   unsigned Opcode0 = isSignedMinMax(N0, N1, N2, N3, CC);
5537   if (!Opcode0)
5538     return SDValue();
5539 
5540   // We may need only one range check if the fptosi can never produce
5541   // the upper value.
5542   if (N0.getOpcode() == ISD::FP_TO_SINT && Opcode0 == ISD::SMAX) {
5543     if (isNullOrNullSplat(N3)) {
5544       EVT IntVT = N0.getValueType().getScalarType();
5545       EVT FPVT = N0.getOperand(0).getValueType().getScalarType();
5546       if (FPVT.isSimple()) {
5547         Type *InputTy = FPVT.getTypeForEVT(*DAG.getContext());
5548         const fltSemantics &Semantics = InputTy->getFltSemantics();
5549         uint32_t MinBitWidth =
5550           APFloatBase::semanticsIntSizeInBits(Semantics, /*isSigned*/ true);
5551         if (IntVT.getSizeInBits() >= MinBitWidth) {
5552           Unsigned = true;
5553           BW = PowerOf2Ceil(MinBitWidth);
5554           return N0;
5555         }
5556       }
5557     }
5558   }
5559 
5560   SDValue N00, N01, N02, N03;
5561   ISD::CondCode N0CC;
5562   switch (N0.getOpcode()) {
5563   case ISD::SMIN:
5564   case ISD::SMAX:
5565     N00 = N02 = N0.getOperand(0);
5566     N01 = N03 = N0.getOperand(1);
5567     N0CC = N0.getOpcode() == ISD::SMIN ? ISD::SETLT : ISD::SETGT;
5568     break;
5569   case ISD::SELECT_CC:
5570     N00 = N0.getOperand(0);
5571     N01 = N0.getOperand(1);
5572     N02 = N0.getOperand(2);
5573     N03 = N0.getOperand(3);
5574     N0CC = cast<CondCodeSDNode>(N0.getOperand(4))->get();
5575     break;
5576   case ISD::SELECT:
5577   case ISD::VSELECT:
5578     if (N0.getOperand(0).getOpcode() != ISD::SETCC)
5579       return SDValue();
5580     N00 = N0.getOperand(0).getOperand(0);
5581     N01 = N0.getOperand(0).getOperand(1);
5582     N02 = N0.getOperand(1);
5583     N03 = N0.getOperand(2);
5584     N0CC = cast<CondCodeSDNode>(N0.getOperand(0).getOperand(2))->get();
5585     break;
5586   default:
5587     return SDValue();
5588   }
5589 
5590   unsigned Opcode1 = isSignedMinMax(N00, N01, N02, N03, N0CC);
5591   if (!Opcode1 || Opcode0 == Opcode1)
5592     return SDValue();
5593 
5594   ConstantSDNode *MinCOp = isConstOrConstSplat(Opcode0 == ISD::SMIN ? N1 : N01);
5595   ConstantSDNode *MaxCOp = isConstOrConstSplat(Opcode0 == ISD::SMIN ? N01 : N1);
5596   if (!MinCOp || !MaxCOp || MinCOp->getValueType(0) != MaxCOp->getValueType(0))
5597     return SDValue();
5598 
5599   const APInt &MinC = MinCOp->getAPIntValue();
5600   const APInt &MaxC = MaxCOp->getAPIntValue();
5601   APInt MinCPlus1 = MinC + 1;
5602   if (-MaxC == MinCPlus1 && MinCPlus1.isPowerOf2()) {
5603     BW = MinCPlus1.exactLogBase2() + 1;
5604     Unsigned = false;
5605     return N02;
5606   }
5607 
5608   if (MaxC == 0 && MinCPlus1.isPowerOf2()) {
5609     BW = MinCPlus1.exactLogBase2();
5610     Unsigned = true;
5611     return N02;
5612   }
5613 
5614   return SDValue();
5615 }
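// Illustrative clamp patterns recognized above: smin(smax(x, -128), 127)
// yields BW == 8 with Unsigned == false (MinC == 127, MinC + 1 == 128 is a
// power of two and equals -MaxC), while smin(smax(x, 0), 255) yields
// BW == 8 with Unsigned == true (MaxC == 0, MinC + 1 == 256). Note that
// MinCOp/MaxCOp name the SMIN/SMAX constants, not the smaller/larger bound.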
5616 
5617 static SDValue PerformMinMaxFpToSatCombine(SDValue N0, SDValue N1, SDValue N2,
5618                                            SDValue N3, ISD::CondCode CC,
5619                                            SelectionDAG &DAG) {
5620   unsigned BW;
5621   bool Unsigned;
5622   SDValue Fp = isSaturatingMinMax(N0, N1, N2, N3, CC, BW, Unsigned, DAG);
5623   if (!Fp || Fp.getOpcode() != ISD::FP_TO_SINT)
5624     return SDValue();
5625   EVT FPVT = Fp.getOperand(0).getValueType();
5626   EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), BW);
5627   if (FPVT.isVector())
5628     NewVT = EVT::getVectorVT(*DAG.getContext(), NewVT,
5629                              FPVT.getVectorElementCount());
5630   unsigned NewOpc = Unsigned ? ISD::FP_TO_UINT_SAT : ISD::FP_TO_SINT_SAT;
5631   if (!DAG.getTargetLoweringInfo().shouldConvertFpToSat(NewOpc, FPVT, NewVT))
5632     return SDValue();
5633   SDLoc DL(Fp);
5634   SDValue Sat = DAG.getNode(NewOpc, DL, NewVT, Fp.getOperand(0),
5635                             DAG.getValueType(NewVT.getScalarType()));
5636   return DAG.getExtOrTrunc(!Unsigned, Sat, DL, N2->getValueType(0));
5637 }
5638 
5639 static SDValue PerformUMinFpToSatCombine(SDValue N0, SDValue N1, SDValue N2,
5640                                          SDValue N3, ISD::CondCode CC,
5641                                          SelectionDAG &DAG) {
5642   // We are looking for UMIN(FPTOUI(X), (2^n)-1), which may have come via a
5643   // select/vselect/select_cc. The two operands pairs for the select (N2/N3) may
5644   // be truncated versions of the setcc (N0/N1).
5645   if ((N0 != N2 &&
5646        (N2.getOpcode() != ISD::TRUNCATE || N0 != N2.getOperand(0))) ||
5647       N0.getOpcode() != ISD::FP_TO_UINT || CC != ISD::SETULT)
5648     return SDValue();
5649   ConstantSDNode *N1C = isConstOrConstSplat(N1);
5650   ConstantSDNode *N3C = isConstOrConstSplat(N3);
5651   if (!N1C || !N3C)
5652     return SDValue();
5653   const APInt &C1 = N1C->getAPIntValue();
5654   const APInt &C3 = N3C->getAPIntValue();
5655   if (!(C1 + 1).isPowerOf2() || C1.getBitWidth() < C3.getBitWidth() ||
5656       C1 != C3.zext(C1.getBitWidth()))
5657     return SDValue();
5658 
5659   unsigned BW = (C1 + 1).exactLogBase2();
5660   EVT FPVT = N0.getOperand(0).getValueType();
5661   EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), BW);
5662   if (FPVT.isVector())
5663     NewVT = EVT::getVectorVT(*DAG.getContext(), NewVT,
5664                              FPVT.getVectorElementCount());
5665   if (!DAG.getTargetLoweringInfo().shouldConvertFpToSat(ISD::FP_TO_UINT_SAT,
5666                                                         FPVT, NewVT))
5667     return SDValue();
5668 
5669   SDValue Sat =
5670       DAG.getNode(ISD::FP_TO_UINT_SAT, SDLoc(N0), NewVT, N0.getOperand(0),
5671                   DAG.getValueType(NewVT.getScalarType()));
5672   return DAG.getZExtOrTrunc(Sat, SDLoc(N0), N3.getValueType());
5673 }
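// For example, umin(fptoui(f) to i32, 255) matches with C1 == 255, so
// BW == log2(256) == 8 and the node becomes
// zext(fp_to_uint_sat(f) to i8) to i32, assuming the target reports the
// conversion as profitable via shouldConvertFpToSat.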
5674 
5675 SDValue DAGCombiner::visitIMINMAX(SDNode *N) {
5676   SDValue N0 = N->getOperand(0);
5677   SDValue N1 = N->getOperand(1);
5678   EVT VT = N0.getValueType();
5679   unsigned Opcode = N->getOpcode();
5680   SDLoc DL(N);
5681 
5682   // fold operation with constant operands.
5683   if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
5684     return C;
5685 
5686   // If the operands are the same, this is a no-op.
5687   if (N0 == N1)
5688     return N0;
5689 
5690   // canonicalize constant to RHS
5691   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
5692       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
5693     return DAG.getNode(Opcode, DL, VT, N1, N0);
5694 
5695   // fold vector ops
5696   if (VT.isVector())
5697     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5698       return FoldedVOp;
5699 
5700   // reassociate minmax
5701   if (SDValue RMINMAX = reassociateOps(Opcode, DL, N0, N1, N->getFlags()))
5702     return RMINMAX;
5703 
5704   // If the sign bits are zero, flip between UMIN/UMAX and SMIN/SMAX.
5705   // Only do this if either:
5706   // 1. The current op isn't legal and the flipped one is.
5707   // 2. The saturation pattern is broken by canonicalization in InstCombine.
5708   bool IsOpIllegal = !TLI.isOperationLegal(Opcode, VT);
5709   bool IsSatBroken = Opcode == ISD::UMIN && N0.getOpcode() == ISD::SMAX;
5710   if ((IsSatBroken || IsOpIllegal) && (N0.isUndef() || DAG.SignBitIsZero(N0)) &&
5711       (N1.isUndef() || DAG.SignBitIsZero(N1))) {
5712     unsigned AltOpcode;
5713     switch (Opcode) {
5714     case ISD::SMIN: AltOpcode = ISD::UMIN; break;
5715     case ISD::SMAX: AltOpcode = ISD::UMAX; break;
5716     case ISD::UMIN: AltOpcode = ISD::SMIN; break;
5717     case ISD::UMAX: AltOpcode = ISD::SMAX; break;
5718     default: llvm_unreachable("Unknown MINMAX opcode");
5719     }
5720     if ((IsSatBroken && IsOpIllegal) || TLI.isOperationLegal(AltOpcode, VT))
5721       return DAG.getNode(AltOpcode, DL, VT, N0, N1);
5722   }
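  // This is sound because when neither operand can have its sign bit set,
  // the signed and unsigned orderings coincide: e.g. for i8 values known to
  // lie in [0, 127], umin(x, y) == smin(x, y).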
5723 
5724   if (Opcode == ISD::SMIN || Opcode == ISD::SMAX)
5725     if (SDValue S = PerformMinMaxFpToSatCombine(
5726             N0, N1, N0, N1, Opcode == ISD::SMIN ? ISD::SETLT : ISD::SETGT, DAG))
5727       return S;
5728   if (Opcode == ISD::UMIN)
5729     if (SDValue S = PerformUMinFpToSatCombine(N0, N1, N0, N1, ISD::SETULT, DAG))
5730       return S;
5731 
5732   // Fold min/max(vecreduce(x), vecreduce(y)) -> vecreduce(min/max(x, y))
5733   auto ReductionOpcode = [](unsigned Opcode) {
5734     switch (Opcode) {
5735     case ISD::SMIN:
5736       return ISD::VECREDUCE_SMIN;
5737     case ISD::SMAX:
5738       return ISD::VECREDUCE_SMAX;
5739     case ISD::UMIN:
5740       return ISD::VECREDUCE_UMIN;
5741     case ISD::UMAX:
5742       return ISD::VECREDUCE_UMAX;
5743     default:
5744       llvm_unreachable("Unexpected opcode");
5745     }
5746   };
5747   if (SDValue SD = reassociateReduction(ReductionOpcode(Opcode), Opcode,
5748                                         SDLoc(N), VT, N0, N1))
5749     return SD;
5750 
5751   // Simplify the operands using demanded-bits information.
5752   if (SimplifyDemandedBits(SDValue(N, 0)))
5753     return SDValue(N, 0);
5754 
5755   return SDValue();
5756 }
5757 
5758 /// If this is a bitwise logic instruction and both operands have the same
5759 /// opcode, try to sink the other opcode after the logic instruction.
5760 SDValue DAGCombiner::hoistLogicOpWithSameOpcodeHands(SDNode *N) {
5761   SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
5762   EVT VT = N0.getValueType();
5763   unsigned LogicOpcode = N->getOpcode();
5764   unsigned HandOpcode = N0.getOpcode();
5765   assert(ISD::isBitwiseLogicOp(LogicOpcode) && "Expected logic opcode");
5766   assert(HandOpcode == N1.getOpcode() && "Bad input!");
5767 
5768   // Bail early if none of these transforms apply.
5769   if (N0.getNumOperands() == 0)
5770     return SDValue();
5771 
5772   // FIXME: We should check number of uses of the operands to not increase
5773   //        the instruction count for all transforms.
5774 
5775   // Handle size-changing casts (or sign_extend_inreg).
5776   SDValue X = N0.getOperand(0);
5777   SDValue Y = N1.getOperand(0);
5778   EVT XVT = X.getValueType();
5779   SDLoc DL(N);
5780   if (ISD::isExtOpcode(HandOpcode) || ISD::isExtVecInRegOpcode(HandOpcode) ||
5781       (HandOpcode == ISD::SIGN_EXTEND_INREG &&
5782        N0.getOperand(1) == N1.getOperand(1))) {
5783     // If both operands have other uses, this transform would create extra
5784     // instructions without eliminating anything.
5785     if (!N0.hasOneUse() && !N1.hasOneUse())
5786       return SDValue();
5787     // We need matching integer source types.
5788     if (XVT != Y.getValueType())
5789       return SDValue();
5790     // Don't create an illegal op during or after legalization. Don't ever
5791     // create an unsupported vector op.
5792     if ((VT.isVector() || LegalOperations) &&
5793         !TLI.isOperationLegalOrCustom(LogicOpcode, XVT))
5794       return SDValue();
5795     // Avoid infinite looping with PromoteIntBinOp.
5796     // TODO: Should we apply desirable/legal constraints to all opcodes?
5797     if ((HandOpcode == ISD::ANY_EXTEND ||
5798          HandOpcode == ISD::ANY_EXTEND_VECTOR_INREG) &&
5799         LegalTypes && !TLI.isTypeDesirableForOp(LogicOpcode, XVT))
5800       return SDValue();
5801     // logic_op (hand_op X), (hand_op Y) --> hand_op (logic_op X, Y)
5802     SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
5803     if (HandOpcode == ISD::SIGN_EXTEND_INREG)
5804       return DAG.getNode(HandOpcode, DL, VT, Logic, N0.getOperand(1));
5805     return DAG.getNode(HandOpcode, DL, VT, Logic);
5806   }
5807 
5808   // logic_op (truncate x), (truncate y) --> truncate (logic_op x, y)
5809   if (HandOpcode == ISD::TRUNCATE) {
5810     // If both operands have other uses, this transform would create extra
5811     // instructions without eliminating anything.
5812     if (!N0.hasOneUse() && !N1.hasOneUse())
5813       return SDValue();
5814     // We need matching source types.
5815     if (XVT != Y.getValueType())
5816       return SDValue();
5817     // Don't create an illegal op during or after legalization.
5818     if (LegalOperations && !TLI.isOperationLegal(LogicOpcode, XVT))
5819       return SDValue();
5820     // Be extra careful sinking truncate. If it's free, there's no benefit in
5821     // widening a binop. Also, don't create a logic op on an illegal type.
5822     if (TLI.isZExtFree(VT, XVT) && TLI.isTruncateFree(XVT, VT))
5823       return SDValue();
5824     if (!TLI.isTypeLegal(XVT))
5825       return SDValue();
5826     SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
5827     return DAG.getNode(HandOpcode, DL, VT, Logic);
5828   }
5829 
5830   // For binops SHL/SRL/SRA/AND:
5831   //   logic_op (OP x, z), (OP y, z) --> OP (logic_op x, y), z
5832   if ((HandOpcode == ISD::SHL || HandOpcode == ISD::SRL ||
5833        HandOpcode == ISD::SRA || HandOpcode == ISD::AND) &&
5834       N0.getOperand(1) == N1.getOperand(1)) {
5835     // If either operand has other uses, this transform is not an improvement.
5836     if (!N0.hasOneUse() || !N1.hasOneUse())
5837       return SDValue();
5838     SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
5839     return DAG.getNode(HandOpcode, DL, VT, Logic, N0.getOperand(1));
5840   }
5841 
5842   // Unary ops: logic_op (bswap x), (bswap y) --> bswap (logic_op x, y)
5843   if (HandOpcode == ISD::BSWAP) {
5844     // If either operand has other uses, this transform is not an improvement.
5845     if (!N0.hasOneUse() || !N1.hasOneUse())
5846       return SDValue();
5847     SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
5848     return DAG.getNode(HandOpcode, DL, VT, Logic);
5849   }
5850 
5851   // For funnel shifts FSHL/FSHR:
5852   // logic_op (OP x, x1, s), (OP y, y1, s) -->
5853   //   OP (logic_op x, y), (logic_op x1, y1), s
5854   if ((HandOpcode == ISD::FSHL || HandOpcode == ISD::FSHR) &&
5855       N0.getOperand(2) == N1.getOperand(2)) {
5856     if (!N0.hasOneUse() || !N1.hasOneUse())
5857       return SDValue();
5858     SDValue X1 = N0.getOperand(1);
5859     SDValue Y1 = N1.getOperand(1);
5860     SDValue S = N0.getOperand(2);
5861     SDValue Logic0 = DAG.getNode(LogicOpcode, DL, VT, X, Y);
5862     SDValue Logic1 = DAG.getNode(LogicOpcode, DL, VT, X1, Y1);
5863     return DAG.getNode(HandOpcode, DL, VT, Logic0, Logic1, S);
5864   }
5865 
5866   // Simplify xor/and/or (bitcast(A), bitcast(B)) -> bitcast(op (A,B))
5867   // Only perform this optimization up until type legalization, before
5868   // LegalizeVectorOprs. LegalizeVectorOprs promotes vector operations by
5869   // adding bitcasts. For example (xor v4i32) is promoted to (v2i64), and
5870   // we don't want to undo this promotion.
5871   // We also handle SCALAR_TO_VECTOR because xor/or/and operations are cheaper
5872   // on scalars.
5873   if ((HandOpcode == ISD::BITCAST || HandOpcode == ISD::SCALAR_TO_VECTOR) &&
5874        Level <= AfterLegalizeTypes) {
5875     // Input types must be integer and the same.
5876     if (XVT.isInteger() && XVT == Y.getValueType() &&
5877         !(VT.isVector() && TLI.isTypeLegal(VT) &&
5878           !XVT.isVector() && !TLI.isTypeLegal(XVT))) {
5879       SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
5880       return DAG.getNode(HandOpcode, DL, VT, Logic);
5881     }
5882   }
5883 
5884   // Xor/and/or are indifferent to the swizzle operation (shuffle of one value).
5885   // Simplify xor/and/or (shuff(A), shuff(B)) -> shuff(op (A,B))
5886   // If both shuffles use the same mask, and both shuffle within a single
5887   // vector, then it is worthwhile to move the swizzle after the operation.
5888   // The type-legalizer generates this pattern when loading illegal
5889   // vector types from memory. In many cases this allows additional shuffle
5890   // optimizations.
5891   // There are other cases where moving the shuffle after the xor/and/or
5892   // is profitable even if shuffles don't perform a swizzle.
5893   // If both shuffles use the same mask, and both shuffles have the same first
5894   // or second operand, then it might still be profitable to move the shuffle
5895   // after the xor/and/or operation.
5896   if (HandOpcode == ISD::VECTOR_SHUFFLE && Level < AfterLegalizeDAG) {
5897     auto *SVN0 = cast<ShuffleVectorSDNode>(N0);
5898     auto *SVN1 = cast<ShuffleVectorSDNode>(N1);
5899     assert(X.getValueType() == Y.getValueType() &&
5900            "Inputs to shuffles are not the same type");
5901 
5902     // Check that both shuffles use the same mask. The masks are known to be of
5903     // the same length because the result vector type is the same.
5904     // Check also that shuffles have only one use to avoid introducing extra
5905     // instructions.
5906     if (!SVN0->hasOneUse() || !SVN1->hasOneUse() ||
5907         !SVN0->getMask().equals(SVN1->getMask()))
5908       return SDValue();
5909 
5910     // Don't try to fold this node if it requires introducing a
5911     // build vector of all zeros that might be illegal at this stage.
5912     SDValue ShOp = N0.getOperand(1);
5913     if (LogicOpcode == ISD::XOR && !ShOp.isUndef())
5914       ShOp = tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
5915 
5916     // (logic_op (shuf (A, C), shuf (B, C))) --> shuf (logic_op (A, B), C)
5917     if (N0.getOperand(1) == N1.getOperand(1) && ShOp.getNode()) {
5918       SDValue Logic = DAG.getNode(LogicOpcode, DL, VT,
5919                                   N0.getOperand(0), N1.getOperand(0));
5920       return DAG.getVectorShuffle(VT, DL, Logic, ShOp, SVN0->getMask());
5921     }
5922 
5923     // Don't try to fold this node if it requires introducing a
5924     // build vector of all zeros that might be illegal at this stage.
5925     ShOp = N0.getOperand(0);
5926     if (LogicOpcode == ISD::XOR && !ShOp.isUndef())
5927       ShOp = tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
5928 
5929     // (logic_op (shuf (C, A), shuf (C, B))) --> shuf (C, logic_op (A, B))
5930     if (N0.getOperand(0) == N1.getOperand(0) && ShOp.getNode()) {
5931       SDValue Logic = DAG.getNode(LogicOpcode, DL, VT, N0.getOperand(1),
5932                                   N1.getOperand(1));
5933       return DAG.getVectorShuffle(VT, DL, ShOp, Logic, SVN0->getMask());
5934     }
5935   }
5936 
5937   return SDValue();
5938 }
5939 
5940 /// Try to make (and/or setcc (LL, LR), setcc (RL, RR)) more efficient.
5941 SDValue DAGCombiner::foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1,
5942                                        const SDLoc &DL) {
5943   SDValue LL, LR, RL, RR, N0CC, N1CC;
5944   if (!isSetCCEquivalent(N0, LL, LR, N0CC) ||
5945       !isSetCCEquivalent(N1, RL, RR, N1CC))
5946     return SDValue();
5947 
5948   assert(N0.getValueType() == N1.getValueType() &&
5949          "Unexpected operand types for bitwise logic op");
5950   assert(LL.getValueType() == LR.getValueType() &&
5951          RL.getValueType() == RR.getValueType() &&
5952          "Unexpected operand types for setcc");
5953 
5954   // If we're here post-legalization or the logic op type is not i1, the logic
5955   // op type must match a setcc result type. Also, all folds require new
5956   // operations on the left and right operands, so those types must match.
5957   EVT VT = N0.getValueType();
5958   EVT OpVT = LL.getValueType();
5959   if (LegalOperations || VT.getScalarType() != MVT::i1)
5960     if (VT != getSetCCResultType(OpVT))
5961       return SDValue();
5962   if (OpVT != RL.getValueType())
5963     return SDValue();
5964 
5965   ISD::CondCode CC0 = cast<CondCodeSDNode>(N0CC)->get();
5966   ISD::CondCode CC1 = cast<CondCodeSDNode>(N1CC)->get();
5967   bool IsInteger = OpVT.isInteger();
5968   if (LR == RR && CC0 == CC1 && IsInteger) {
5969     bool IsZero = isNullOrNullSplat(LR);
5970     bool IsNeg1 = isAllOnesOrAllOnesSplat(LR);
5971 
5972     // All bits clear?
5973     bool AndEqZero = IsAnd && CC1 == ISD::SETEQ && IsZero;
5974     // All sign bits clear?
5975     bool AndGtNeg1 = IsAnd && CC1 == ISD::SETGT && IsNeg1;
5976     // Any bits set?
5977     bool OrNeZero = !IsAnd && CC1 == ISD::SETNE && IsZero;
5978     // Any sign bits set?
5979     bool OrLtZero = !IsAnd && CC1 == ISD::SETLT && IsZero;
5980 
5981     // (and (seteq X,  0), (seteq Y,  0)) --> (seteq (or X, Y),  0)
5982     // (and (setgt X, -1), (setgt Y, -1)) --> (setgt (or X, Y), -1)
5983     // (or  (setne X,  0), (setne Y,  0)) --> (setne (or X, Y),  0)
5984     // (or  (setlt X,  0), (setlt Y,  0)) --> (setlt (or X, Y),  0)
5985     if (AndEqZero || AndGtNeg1 || OrNeZero || OrLtZero) {
5986       SDValue Or = DAG.getNode(ISD::OR, SDLoc(N0), OpVT, LL, RL);
5987       AddToWorklist(Or.getNode());
5988       return DAG.getSetCC(DL, VT, Or, LR, CC1);
5989     }
5990 
5991     // All bits set?
5992     bool AndEqNeg1 = IsAnd && CC1 == ISD::SETEQ && IsNeg1;
5993     // All sign bits set?
5994     bool AndLtZero = IsAnd && CC1 == ISD::SETLT && IsZero;
5995     // Any bits clear?
5996     bool OrNeNeg1 = !IsAnd && CC1 == ISD::SETNE && IsNeg1;
5997     // Any sign bits clear?
5998     bool OrGtNeg1 = !IsAnd && CC1 == ISD::SETGT && IsNeg1;
5999 
6000     // (and (seteq X, -1), (seteq Y, -1)) --> (seteq (and X, Y), -1)
6001     // (and (setlt X,  0), (setlt Y,  0)) --> (setlt (and X, Y),  0)
6002     // (or  (setne X, -1), (setne Y, -1)) --> (setne (and X, Y), -1)
6003     // (or  (setgt X, -1), (setgt Y, -1)) --> (setgt (and X, Y), -1)
6004     if (AndEqNeg1 || AndLtZero || OrNeNeg1 || OrGtNeg1) {
6005       SDValue And = DAG.getNode(ISD::AND, SDLoc(N0), OpVT, LL, RL);
6006       AddToWorklist(And.getNode());
6007       return DAG.getSetCC(DL, VT, And, LR, CC1);
6008     }
6009   }
6010 
6011   // TODO: What is the 'or' equivalent of this fold?
6012   // (and (setne X, 0), (setne X, -1)) --> (setuge (add X, 1), 2)
6013   if (IsAnd && LL == RL && CC0 == CC1 && OpVT.getScalarSizeInBits() > 1 &&
6014       IsInteger && CC0 == ISD::SETNE &&
6015       ((isNullConstant(LR) && isAllOnesConstant(RR)) ||
6016        (isAllOnesConstant(LR) && isNullConstant(RR)))) {
6017     SDValue One = DAG.getConstant(1, DL, OpVT);
6018     SDValue Two = DAG.getConstant(2, DL, OpVT);
6019     SDValue Add = DAG.getNode(ISD::ADD, SDLoc(N0), OpVT, LL, One);
6020     AddToWorklist(Add.getNode());
6021     return DAG.getSetCC(DL, VT, Add, Two, ISD::SETUGE);
6022   }
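  // Rationale for the fold above: adding 1 maps X == 0 to 1 and X == -1 to
  // 0, the only two values below 2 after the (wrapping) increment, so
  // (X != 0 && X != -1) is exactly (X + 1) u>= 2.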
6023 
6024   // Try more general transforms if the predicates match and the only user of
6025   // the compares is the 'and' or 'or'.
6026   if (IsInteger && TLI.convertSetCCLogicToBitwiseLogic(OpVT) && CC0 == CC1 &&
6027       N0.hasOneUse() && N1.hasOneUse()) {
6028     // and (seteq A, B), (seteq C, D) --> seteq (or (xor A, B), (xor C, D)), 0
6029     // or  (setne A, B), (setne C, D) --> setne (or (xor A, B), (xor C, D)), 0
6030     if ((IsAnd && CC1 == ISD::SETEQ) || (!IsAnd && CC1 == ISD::SETNE)) {
6031       SDValue XorL = DAG.getNode(ISD::XOR, SDLoc(N0), OpVT, LL, LR);
6032       SDValue XorR = DAG.getNode(ISD::XOR, SDLoc(N1), OpVT, RL, RR);
6033       SDValue Or = DAG.getNode(ISD::OR, DL, OpVT, XorL, XorR);
6034       SDValue Zero = DAG.getConstant(0, DL, OpVT);
6035       return DAG.getSetCC(DL, VT, Or, Zero, CC1);
6036     }
6037 
6038     // Turn compare of constants whose difference is 1 bit into add+and+setcc.
6039     if ((IsAnd && CC1 == ISD::SETNE) || (!IsAnd && CC1 == ISD::SETEQ)) {
6040       // Match a shared variable operand and 2 non-opaque constant operands.
6041       auto MatchDiffPow2 = [&](ConstantSDNode *C0, ConstantSDNode *C1) {
6042         // The difference of the constants must be a single bit.
6043         const APInt &CMax =
6044             APIntOps::umax(C0->getAPIntValue(), C1->getAPIntValue());
6045         const APInt &CMin =
6046             APIntOps::umin(C0->getAPIntValue(), C1->getAPIntValue());
6047         return !C0->isOpaque() && !C1->isOpaque() && (CMax - CMin).isPowerOf2();
6048       };
6049       if (LL == RL && ISD::matchBinaryPredicate(LR, RR, MatchDiffPow2)) {
6050         // and/or (setcc X, CMax, ne), (setcc X, CMin, ne/eq) -->
6051         // setcc ((sub X, CMin), ~(CMax - CMin)), 0, ne/eq
6052         SDValue Max = DAG.getNode(ISD::UMAX, DL, OpVT, LR, RR);
6053         SDValue Min = DAG.getNode(ISD::UMIN, DL, OpVT, LR, RR);
6054         SDValue Offset = DAG.getNode(ISD::SUB, DL, OpVT, LL, Min);
6055         SDValue Diff = DAG.getNode(ISD::SUB, DL, OpVT, Max, Min);
6056         SDValue Mask = DAG.getNOT(DL, Diff, OpVT);
6057         SDValue And = DAG.getNode(ISD::AND, DL, OpVT, Offset, Mask);
6058         SDValue Zero = DAG.getConstant(0, DL, OpVT);
6059         return DAG.getSetCC(DL, VT, And, Zero, CC0);
6060       }
6061     }
6062   }
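  // Worked example of the CMax/CMin fold: (X != 5 && X != 7) has
  // CMin == 5, CMax == 7, CMax - CMin == 2 (a power of two), and becomes
  // setcc((X - 5) & ~2, 0, ne). X == 5 gives 0 & ~2 == 0 and X == 7 gives
  // 2 & ~2 == 0 (both compare false), while e.g. X == 6 gives 1 & ~2 != 0.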
6063 
6064   // Canonicalize equivalent operands to LL == RL.
6065   if (LL == RR && LR == RL) {
6066     CC1 = ISD::getSetCCSwappedOperands(CC1);
6067     std::swap(RL, RR);
6068   }
6069 
6070   // (and (setcc X, Y, CC0), (setcc X, Y, CC1)) --> (setcc X, Y, NewCC)
6071   // (or  (setcc X, Y, CC0), (setcc X, Y, CC1)) --> (setcc X, Y, NewCC)
6072   if (LL == RL && LR == RR) {
6073     ISD::CondCode NewCC = IsAnd ? ISD::getSetCCAndOperation(CC0, CC1, OpVT)
6074                                 : ISD::getSetCCOrOperation(CC0, CC1, OpVT);
6075     if (NewCC != ISD::SETCC_INVALID &&
6076         (!LegalOperations ||
6077          (TLI.isCondCodeLegal(NewCC, LL.getSimpleValueType()) &&
6078           TLI.isOperationLegal(ISD::SETCC, OpVT))))
6079       return DAG.getSetCC(DL, VT, LL, LR, NewCC);
6080   }
6081 
6082   return SDValue();
6083 }
6084 
6085 static bool arebothOperandsNotSNan(SDValue Operand1, SDValue Operand2,
6086                                    SelectionDAG &DAG) {
6087   return DAG.isKnownNeverSNaN(Operand2) && DAG.isKnownNeverSNaN(Operand1);
6088 }
6089 
6090 static bool arebothOperandsNotNan(SDValue Operand1, SDValue Operand2,
6091                                   SelectionDAG &DAG) {
6092   return DAG.isKnownNeverNaN(Operand2) && DAG.isKnownNeverNaN(Operand1);
6093 }
6094 
6095 static unsigned getMinMaxOpcodeForFP(SDValue Operand1, SDValue Operand2,
6096                                      ISD::CondCode CC, unsigned OrAndOpcode,
6097                                      SelectionDAG &DAG,
6098                                      bool isFMAXNUMFMINNUM_IEEE,
6099                                      bool isFMAXNUMFMINNUM) {
6100   // The optimization cannot be applied for all the predicates because
6101   // of the way FMINNUM/FMAXNUM and FMINNUM_IEEE/FMAXNUM_IEEE handle
6102   // NaNs. For FMINNUM_IEEE/FMAXNUM_IEEE, the optimization cannot be
6103   // applied at all if one of the operands is a signaling NaN.
6104 
6105   // It is safe to use FMINNUM_IEEE/FMAXNUM_IEEE if all the operands
6106   // are non-NaN values.
6107   if (((CC == ISD::SETLT || CC == ISD::SETLE) && (OrAndOpcode == ISD::OR)) ||
6108       ((CC == ISD::SETGT || CC == ISD::SETGE) && (OrAndOpcode == ISD::AND)))
6109     return arebothOperandsNotNan(Operand1, Operand2, DAG) &&
6110                    isFMAXNUMFMINNUM_IEEE
6111                ? ISD::FMINNUM_IEEE
6112                : ISD::DELETED_NODE;
6113   else if (((CC == ISD::SETGT || CC == ISD::SETGE) &&
6114             (OrAndOpcode == ISD::OR)) ||
6115            ((CC == ISD::SETLT || CC == ISD::SETLE) &&
6116             (OrAndOpcode == ISD::AND)))
6117     return arebothOperandsNotNan(Operand1, Operand2, DAG) &&
6118                    isFMAXNUMFMINNUM_IEEE
6119                ? ISD::FMAXNUM_IEEE
6120                : ISD::DELETED_NODE;
6121   // Both FMINNUM/FMAXNUM and FMINNUM_IEEE/FMAXNUM_IEEE handle quiet
6122   // NaNs in the same way. But, FMINNUM/FMAXNUM and FMINNUM_IEEE/
6123   // FMAXNUM_IEEE handle signaling NaNs differently. If we cannot prove
6124   // that there are not any sNaNs, then the optimization is not valid
6125   // for FMINNUM_IEEE/FMAXNUM_IEEE. In the presence of sNaNs, we apply
6126   // the optimization using FMINNUM/FMAXNUM for the following cases. If
6127   // we can prove that we do not have any sNaNs, then we can do the
6128   // optimization using FMINNUM_IEEE/FMAXNUM_IEEE for the following
6129   // cases.
6130   else if (((CC == ISD::SETOLT || CC == ISD::SETOLE) &&
6131             (OrAndOpcode == ISD::OR)) ||
6132            ((CC == ISD::SETUGT || CC == ISD::SETUGE) &&
6133             (OrAndOpcode == ISD::AND)))
6134     return isFMAXNUMFMINNUM ? ISD::FMINNUM
6135                             : arebothOperandsNotSNan(Operand1, Operand2, DAG) &&
6136                                       isFMAXNUMFMINNUM_IEEE
6137                                   ? ISD::FMINNUM_IEEE
6138                                   : ISD::DELETED_NODE;
6139   else if (((CC == ISD::SETOGT || CC == ISD::SETOGE) &&
6140             (OrAndOpcode == ISD::OR)) ||
6141            ((CC == ISD::SETULT || CC == ISD::SETULE) &&
6142             (OrAndOpcode == ISD::AND)))
6143     return isFMAXNUMFMINNUM ? ISD::FMAXNUM
6144                             : arebothOperandsNotSNan(Operand1, Operand2, DAG) &&
6145                                       isFMAXNUMFMINNUM_IEEE
6146                                   ? ISD::FMAXNUM_IEEE
6147                                   : ISD::DELETED_NODE;
6148   return ISD::DELETED_NODE;
6149 }
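// For instance, (setolt x, a) | (setolt x, b) becomes
// setogt(fmaxnum(a, b), x), since x < a || x < b iff max(a, b) > x; the
// predicate/logic-op table above chooses between FMINNUM/FMAXNUM and their
// IEEE variants according to what can be proven about NaN operands.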
6150 
6151 static SDValue foldAndOrOfSETCC(SDNode *LogicOp, SelectionDAG &DAG) {
6152   using AndOrSETCCFoldKind = TargetLowering::AndOrSETCCFoldKind;
6153   assert(
6154       (LogicOp->getOpcode() == ISD::AND || LogicOp->getOpcode() == ISD::OR) &&
6155       "Invalid Op to combine SETCC with");
6156 
6157   // TODO: Search past casts/truncates.
6158   SDValue LHS = LogicOp->getOperand(0);
6159   SDValue RHS = LogicOp->getOperand(1);
6160   if (LHS->getOpcode() != ISD::SETCC || RHS->getOpcode() != ISD::SETCC ||
6161       !LHS->hasOneUse() || !RHS->hasOneUse())
6162     return SDValue();
6163 
6164   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
6165   AndOrSETCCFoldKind TargetPreference = TLI.isDesirableToCombineLogicOpOfSETCC(
6166       LogicOp, LHS.getNode(), RHS.getNode());
6167 
6168   SDValue LHS0 = LHS->getOperand(0);
6169   SDValue RHS0 = RHS->getOperand(0);
6170   SDValue LHS1 = LHS->getOperand(1);
6171   SDValue RHS1 = RHS->getOperand(1);
6172   // TODO: We don't actually need a splat here; for vectors we just need the
6173   // invariants to hold for each element.
6174   auto *LHS1C = isConstOrConstSplat(LHS1);
6175   auto *RHS1C = isConstOrConstSplat(RHS1);
6176   ISD::CondCode CCL = cast<CondCodeSDNode>(LHS.getOperand(2))->get();
6177   ISD::CondCode CCR = cast<CondCodeSDNode>(RHS.getOperand(2))->get();
6178   EVT VT = LogicOp->getValueType(0);
6179   EVT OpVT = LHS0.getValueType();
6180   SDLoc DL(LogicOp);
6181 
6182   // Check if the operands of an and/or operation are comparisons and if they
6183   // compare against the same value. Replace the and/or-cmp-cmp sequence with
6184   // min/max cmp sequence. If LHS1 is equal to RHS1, then the or-cmp-cmp
6185   // sequence will be replaced with min-cmp sequence:
6186   // (LHS0 < LHS1) | (RHS0 < RHS1) -> min(LHS0, RHS0) < LHS1
6187   // and and-cmp-cmp will be replaced with max-cmp sequence:
6188   // (LHS0 < LHS1) & (RHS0 < RHS1) -> max(LHS0, RHS0) < LHS1
6189   // The optimization does not work for `==` or `!=` .
6190   // The two comparisons should have either the same predicate or the
6191   // predicate of one of the comparisons is the opposite of the other one.
6192   bool isFMAXNUMFMINNUM_IEEE = TLI.isOperationLegal(ISD::FMAXNUM_IEEE, OpVT) &&
6193                                TLI.isOperationLegal(ISD::FMINNUM_IEEE, OpVT);
6194   bool isFMAXNUMFMINNUM = TLI.isOperationLegalOrCustom(ISD::FMAXNUM, OpVT) &&
6195                           TLI.isOperationLegalOrCustom(ISD::FMINNUM, OpVT);
6196   if (((OpVT.isInteger() && TLI.isOperationLegal(ISD::UMAX, OpVT) &&
6197         TLI.isOperationLegal(ISD::SMAX, OpVT) &&
6198         TLI.isOperationLegal(ISD::UMIN, OpVT) &&
6199         TLI.isOperationLegal(ISD::SMIN, OpVT)) ||
6200        (OpVT.isFloatingPoint() &&
6201         (isFMAXNUMFMINNUM_IEEE || isFMAXNUMFMINNUM))) &&
6202       !ISD::isIntEqualitySetCC(CCL) && !ISD::isFPEqualitySetCC(CCL) &&
6203       CCL != ISD::SETFALSE && CCL != ISD::SETO && CCL != ISD::SETUO &&
6204       CCL != ISD::SETTRUE &&
6205       (CCL == CCR || CCL == ISD::getSetCCSwappedOperands(CCR))) {
6206 
6207     SDValue CommonValue, Operand1, Operand2;
6208     ISD::CondCode CC = ISD::SETCC_INVALID;
6209     if (CCL == CCR) {
6210       if (LHS0 == RHS0) {
6211         CommonValue = LHS0;
6212         Operand1 = LHS1;
6213         Operand2 = RHS1;
6214         CC = ISD::getSetCCSwappedOperands(CCL);
6215       } else if (LHS1 == RHS1) {
6216         CommonValue = LHS1;
6217         Operand1 = LHS0;
6218         Operand2 = RHS0;
6219         CC = CCL;
6220       }
6221     } else {
6222       assert(CCL == ISD::getSetCCSwappedOperands(CCR) && "Unexpected CC");
6223       if (LHS0 == RHS1) {
6224         CommonValue = LHS0;
6225         Operand1 = LHS1;
6226         Operand2 = RHS0;
6227         CC = CCR;
6228       } else if (RHS0 == LHS1) {
6229         CommonValue = LHS1;
6230         Operand1 = LHS0;
6231         Operand2 = RHS1;
6232         CC = CCL;
6233       }
6234     }
6235 
6236     // Don't do this transform for sign bit tests. Let foldLogicOfSetCCs
6237     // handle it using OR/AND.
6238     if (CC == ISD::SETLT && isNullOrNullSplat(CommonValue))
6239       CC = ISD::SETCC_INVALID;
6240     else if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(CommonValue))
6241       CC = ISD::SETCC_INVALID;
6242 
6243     if (CC != ISD::SETCC_INVALID) {
6244       unsigned NewOpcode = ISD::DELETED_NODE;
6245       bool IsSigned = isSignedIntSetCC(CC);
6246       if (OpVT.isInteger()) {
6247         bool IsLess = (CC == ISD::SETLE || CC == ISD::SETULE ||
6248                        CC == ISD::SETLT || CC == ISD::SETULT);
6249         bool IsOr = (LogicOp->getOpcode() == ISD::OR);
6250         if (IsLess == IsOr)
6251           NewOpcode = IsSigned ? ISD::SMIN : ISD::UMIN;
6252         else
6253           NewOpcode = IsSigned ? ISD::SMAX : ISD::UMAX;
6254       } else if (OpVT.isFloatingPoint())
6255         NewOpcode =
6256             getMinMaxOpcodeForFP(Operand1, Operand2, CC, LogicOp->getOpcode(),
6257                                  DAG, isFMAXNUMFMINNUM_IEEE, isFMAXNUMFMINNUM);
6258 
6259       if (NewOpcode != ISD::DELETED_NODE) {
6260         SDValue MinMaxValue =
6261             DAG.getNode(NewOpcode, DL, OpVT, Operand1, Operand2);
6262         return DAG.getSetCC(DL, VT, MinMaxValue, CommonValue, CC);
6263       }
6264     }
6265   }
6266 
6267   if (TargetPreference == AndOrSETCCFoldKind::None)
6268     return SDValue();
6269 
6270   if (CCL == CCR &&
6271       CCL == (LogicOp->getOpcode() == ISD::AND ? ISD::SETNE : ISD::SETEQ) &&
6272       LHS0 == RHS0 && LHS1C && RHS1C && OpVT.isInteger()) {
6273     const APInt &APLhs = LHS1C->getAPIntValue();
6274     const APInt &APRhs = RHS1C->getAPIntValue();
6275 
6276     // Preference is to use ISD::ABS or we already have an ISD::ABS (in which
6277     // case this is just a compare).
6278     if (APLhs == (-APRhs) &&
6279         ((TargetPreference & AndOrSETCCFoldKind::ABS) ||
6280          DAG.doesNodeExist(ISD::ABS, DAG.getVTList(OpVT), {LHS0}))) {
6281       const APInt &C = APLhs.isNegative() ? APRhs : APLhs;
6282       // (icmp eq A, C) | (icmp eq A, -C)
6283       //    -> (icmp eq Abs(A), C)
6284       // (icmp ne A, C) & (icmp ne A, -C)
6285       //    -> (icmp ne Abs(A), C)
6286       SDValue AbsOp = DAG.getNode(ISD::ABS, DL, OpVT, LHS0);
6287       return DAG.getNode(ISD::SETCC, DL, VT, AbsOp,
6288                          DAG.getConstant(C, DL, OpVT), LHS.getOperand(2));
6289     } else if (TargetPreference &
6290                (AndOrSETCCFoldKind::AddAnd | AndOrSETCCFoldKind::NotAnd)) {
6291 
6292       // AndOrSETCCFoldKind::AddAnd:
6293       // A == C0 | A == C1
6294       //  IF IsPow2(smax(C0, C1)-smin(C0, C1))
6295       //    -> ((A - smin(C0, C1)) & ~(smax(C0, C1)-smin(C0, C1))) == 0
6296       // A != C0 & A != C1
6297       //  IF IsPow2(smax(C0, C1)-smin(C0, C1))
6298       //    -> ((A - smin(C0, C1)) & ~(smax(C0, C1)-smin(C0, C1))) != 0
6299 
6300       // AndOrSETCCFoldKind::NotAnd:
6301       // A == C0 | A == C1
6302       //  IF smax(C0, C1) == -1 AND IsPow2(smax(C0, C1) - smin(C0, C1))
6303       //    -> ~A & smin(C0, C1) == 0
6304       // A != C0 & A != C1
6305       //  IF smax(C0, C1) == -1 AND IsPow2(smax(C0, C1) - smin(C0, C1))
6306       //    -> ~A & smin(C0, C1) != 0
6307 
6308       const APInt &MaxC = APIntOps::smax(APRhs, APLhs);
6309       const APInt &MinC = APIntOps::smin(APRhs, APLhs);
6310       APInt Dif = MaxC - MinC;
6311       if (!Dif.isZero() && Dif.isPowerOf2()) {
6312         if (MaxC.isAllOnes() &&
6313             (TargetPreference & AndOrSETCCFoldKind::NotAnd)) {
6314           SDValue NotOp = DAG.getNOT(DL, LHS0, OpVT);
6315           SDValue AndOp = DAG.getNode(ISD::AND, DL, OpVT, NotOp,
6316                                       DAG.getConstant(MinC, DL, OpVT));
6317           return DAG.getNode(ISD::SETCC, DL, VT, AndOp,
6318                              DAG.getConstant(0, DL, OpVT), LHS.getOperand(2));
6319         } else if (TargetPreference & AndOrSETCCFoldKind::AddAnd) {
6320 
6321           SDValue AddOp = DAG.getNode(ISD::ADD, DL, OpVT, LHS0,
6322                                       DAG.getConstant(-MinC, DL, OpVT));
6323           SDValue AndOp = DAG.getNode(ISD::AND, DL, OpVT, AddOp,
6324                                       DAG.getConstant(~Dif, DL, OpVT));
6325           return DAG.getNode(ISD::SETCC, DL, VT, AndOp,
6326                              DAG.getConstant(0, DL, OpVT), LHS.getOperand(2));
6327         }
6328       }
6329     }
6330   }
6331 
6332   return SDValue();
6333 }
6334 
6335 // Combine `(select c, (X & 1), 0)` -> `(and (zext c), X)`.
6336 // We canonicalize to the `select` form in the middle end, but the `and` form
6337 // gets better codegen on all tested targets (arm, x86, riscv).
6338 static SDValue combineSelectAsExtAnd(SDValue Cond, SDValue T, SDValue F,
6339                                      const SDLoc &DL, SelectionDAG &DAG) {
6340   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
6341   if (!isNullConstant(F))
6342     return SDValue();
6343 
6344   EVT CondVT = Cond.getValueType();
6345   if (TLI.getBooleanContents(CondVT) !=
6346       TargetLoweringBase::ZeroOrOneBooleanContent)
6347     return SDValue();
6348 
6349   if (T.getOpcode() != ISD::AND)
6350     return SDValue();
6351 
6352   if (!isOneConstant(T.getOperand(1)))
6353     return SDValue();
6354 
6355   EVT OpVT = T.getValueType();
6356 
6357   SDValue CondMask =
6358       OpVT == CondVT ? Cond : DAG.getBoolExtOrTrunc(Cond, DL, OpVT, CondVT);
6359   return DAG.getNode(ISD::AND, DL, OpVT, CondMask, T.getOperand(0));
6360 }
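// Quick check of the equivalence: with a ZeroOrOne boolean c,
// select(c, X & 1, 0) is X & 1 when c == 1 and 0 when c == 0, which is
// exactly (zext c) & X, since ANDing with the single set bit of the
// extended boolean keeps only bit 0 of X.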
6361 
6362 /// This contains all DAGCombine rules which reduce two values combined by
6363 /// an And operation to a single value. This makes them reusable in the context
6364 /// of visitSELECT(). Rules involving constants are not included as
6365 /// visitSELECT() already handles those cases.
6366 SDValue DAGCombiner::visitANDLike(SDValue N0, SDValue N1, SDNode *N) {
6367   EVT VT = N1.getValueType();
6368   SDLoc DL(N);
6369 
6370   // fold (and x, undef) -> 0
6371   if (N0.isUndef() || N1.isUndef())
6372     return DAG.getConstant(0, DL, VT);
6373 
6374   if (SDValue V = foldLogicOfSetCCs(true, N0, N1, DL))
6375     return V;
6376 
6377   // Canonicalize:
6378   //   and(x, add) -> and(add, x)
6379   if (N1.getOpcode() == ISD::ADD)
6380     std::swap(N0, N1);
6381 
6382   // TODO: Rewrite this to return a new 'AND' instead of using CombineTo.
6383   if (N0.getOpcode() == ISD::ADD && N1.getOpcode() == ISD::SRL &&
6384       VT.isScalarInteger() && VT.getSizeInBits() <= 64 && N0->hasOneUse()) {
6385     if (ConstantSDNode *ADDI = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
6386       if (ConstantSDNode *SRLI = dyn_cast<ConstantSDNode>(N1.getOperand(1))) {
6387         // Look for (and (add x, c1), (lshr y, c2)). If C1 wasn't a legal
6388         // immediate for an add, but it is legal if its top c2 bits are set,
6389         // transform the ADD so the immediate doesn't need to be materialized
6390         // in a register.
6391         APInt ADDC = ADDI->getAPIntValue();
6392         APInt SRLC = SRLI->getAPIntValue();
6393         if (ADDC.getSignificantBits() <= 64 && SRLC.ult(VT.getSizeInBits()) &&
6394             !TLI.isLegalAddImmediate(ADDC.getSExtValue())) {
6395           APInt Mask = APInt::getHighBitsSet(VT.getSizeInBits(),
6396                                              SRLC.getZExtValue());
6397           if (DAG.MaskedValueIsZero(N0.getOperand(1), Mask)) {
6398             ADDC |= Mask;
6399             if (TLI.isLegalAddImmediate(ADDC.getSExtValue())) {
6400               SDLoc DL0(N0);
6401               SDValue NewAdd =
6402                 DAG.getNode(ISD::ADD, DL0, VT,
6403                             N0.getOperand(0), DAG.getConstant(ADDC, DL, VT));
6404               CombineTo(N0.getNode(), NewAdd);
6405               // Return N so it doesn't get rechecked!
6406               return SDValue(N, 0);
6407             }
6408           }
6409         }
6410       }
6411     }
6412   }
6413 
6414   return SDValue();
6415 }
6416 
6417 bool DAGCombiner::isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
6418                                    EVT LoadResultTy, EVT &ExtVT) {
6419   if (!AndC->getAPIntValue().isMask())
6420     return false;
6421 
6422   unsigned ActiveBits = AndC->getAPIntValue().countr_one();
6423 
6424   ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
6425   EVT LoadedVT = LoadN->getMemoryVT();
6426 
6427   if (ExtVT == LoadedVT &&
6428       (!LegalOperations ||
6429        TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT))) {
6430     // ZEXTLOAD will match without needing to change the size of the value being
6431     // loaded.
6432     return true;
6433   }
6434 
6435   // Do not change the width of volatile or atomic loads.
6436   if (!LoadN->isSimple())
6437     return false;
6438 
6439   // Do not generate loads of non-round integer types since these can
6440   // be expensive (and would be wrong if the type is not byte sized).
6441   if (!LoadedVT.bitsGT(ExtVT) || !ExtVT.isRound())
6442     return false;
6443 
6444   if (LegalOperations &&
6445       !TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT))
6446     return false;
6447 
6448   if (!TLI.shouldReduceLoadWidth(LoadN, ISD::ZEXTLOAD, ExtVT))
6449     return false;
6450 
6451   return true;
6452 }
6453 
6454 bool DAGCombiner::isLegalNarrowLdSt(LSBaseSDNode *LDST,
6455                                     ISD::LoadExtType ExtType, EVT &MemVT,
6456                                     unsigned ShAmt) {
6457   if (!LDST)
6458     return false;
6459   // Only allow byte offsets.
6460   if (ShAmt % 8)
6461     return false;
6462 
6463   // Do not generate loads of non-round integer types since these can
6464   // be expensive (and would be wrong if the type is not byte sized).
6465   if (!MemVT.isRound())
6466     return false;
6467 
6468   // Don't change the width of volatile or atomic loads.
6469   if (!LDST->isSimple())
6470     return false;
6471 
6472   EVT LdStMemVT = LDST->getMemoryVT();
6473 
6474   // Bail out when changing the scalable property, since we can't be sure that
6475   // we're actually narrowing here.
6476   if (LdStMemVT.isScalableVector() != MemVT.isScalableVector())
6477     return false;
6478 
6479   // Verify that we are actually reducing a load width here.
6480   if (LdStMemVT.bitsLT(MemVT))
6481     return false;
6482 
6483   // Ensure that this isn't going to produce an unsupported memory access.
6484   if (ShAmt) {
6485     assert(ShAmt % 8 == 0 && "ShAmt is byte offset");
6486     const unsigned ByteShAmt = ShAmt / 8;
6487     const Align LDSTAlign = LDST->getAlign();
6488     const Align NarrowAlign = commonAlignment(LDSTAlign, ByteShAmt);
6489     if (!TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), MemVT,
6490                                 LDST->getAddressSpace(), NarrowAlign,
6491                                 LDST->getMemOperand()->getFlags()))
6492       return false;
6493   }
6494 
6495   // It's not possible to generate a constant of extended or untyped type.
6496   EVT PtrType = LDST->getBasePtr().getValueType();
6497   if (PtrType == MVT::Untyped || PtrType.isExtended())
6498     return false;
6499 
6500   if (isa<LoadSDNode>(LDST)) {
6501     LoadSDNode *Load = cast<LoadSDNode>(LDST);
6502     // Don't transform one with multiple uses, this would require adding a new
6503     // load.
6504     if (!SDValue(Load, 0).hasOneUse())
6505       return false;
6506 
6507     if (LegalOperations &&
6508         !TLI.isLoadExtLegal(ExtType, Load->getValueType(0), MemVT))
6509       return false;
6510 
6511     // For the transform to be legal, the load must produce only two values
6512     // (the value loaded and the chain).  Don't transform a pre-increment
6513     // load, for example, which produces an extra value.  Otherwise the
6514     // transformation is not equivalent, and the downstream logic to replace
6515     // uses gets things wrong.
6516     if (Load->getNumValues() > 2)
6517       return false;
6518 
6519     // If the load that we're shrinking is an extload and we're not just
6520     // discarding the extension we can't simply shrink the load. Bail.
6521     // TODO: It would be possible to merge the extensions in some cases.
6522     if (Load->getExtensionType() != ISD::NON_EXTLOAD &&
6523         Load->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits() + ShAmt)
6524       return false;
6525 
6526     if (!TLI.shouldReduceLoadWidth(Load, ExtType, MemVT))
6527       return false;
6528   } else {
6529     assert(isa<StoreSDNode>(LDST) && "It is not a Load nor a Store SDNode");
6530     StoreSDNode *Store = cast<StoreSDNode>(LDST);
6531     // Can't write outside the original store
6532     if (Store->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits() + ShAmt)
6533       return false;
6534 
6535     if (LegalOperations &&
6536         !TLI.isTruncStoreLegal(Store->getValue().getValueType(), MemVT))
6537       return false;
6538   }
6539   return true;
6540 }
6541 
6542 bool DAGCombiner::SearchForAndLoads(SDNode *N,
6543                                     SmallVectorImpl<LoadSDNode*> &Loads,
6544                                     SmallPtrSetImpl<SDNode*> &NodesWithConsts,
6545                                     ConstantSDNode *Mask,
6546                                     SDNode *&NodeToMask) {
6547   // Recursively search for the operands, looking for loads which can be
6548   // narrowed.
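       // For instance (shapes assumed for illustration): given
       // (and (or (load a), (load b)), 0xFF), the walk looks through the OR
       // and records both loads, so BackwardsPropagateMask can shrink each to
       // an i8 zextload and then drop the mask entirely.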
6549   for (SDValue Op : N->op_values()) {
6550     if (Op.getValueType().isVector())
6551       return false;
6552 
6553     // Some constants may need fixing up later if they are too large.
6554     if (auto *C = dyn_cast<ConstantSDNode>(Op)) {
6555       if ((N->getOpcode() == ISD::OR || N->getOpcode() == ISD::XOR) &&
6556           (Mask->getAPIntValue() & C->getAPIntValue()) != C->getAPIntValue())
6557         NodesWithConsts.insert(N);
6558       continue;
6559     }
6560 
6561     if (!Op.hasOneUse())
6562       return false;
6563 
6564     switch(Op.getOpcode()) {
6565     case ISD::LOAD: {
6566       auto *Load = cast<LoadSDNode>(Op);
6567       EVT ExtVT;
6568       if (isAndLoadExtLoad(Mask, Load, Load->getValueType(0), ExtVT) &&
6569           isLegalNarrowLdSt(Load, ISD::ZEXTLOAD, ExtVT)) {
6570 
6571         // ZEXTLOAD is already small enough.
6572         if (Load->getExtensionType() == ISD::ZEXTLOAD &&
6573             ExtVT.bitsGE(Load->getMemoryVT()))
6574           continue;
6575 
6576         // Use bitsLE to convert equal-sized loads to zext.
6577         if (ExtVT.bitsLE(Load->getMemoryVT()))
6578           Loads.push_back(Load);
6579 
6580         continue;
6581       }
6582       return false;
6583     }
6584     case ISD::ZERO_EXTEND:
6585     case ISD::AssertZext: {
6586       unsigned ActiveBits = Mask->getAPIntValue().countr_one();
6587       EVT ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
6588       EVT VT = Op.getOpcode() == ISD::AssertZext ?
6589         cast<VTSDNode>(Op.getOperand(1))->getVT() :
6590         Op.getOperand(0).getValueType();
6591 
6592       // We can accept extending nodes if the mask is wider than, or equal
6593       // in width to, the original type.
6594       if (ExtVT.bitsGE(VT))
6595         continue;
6596       break;
6597     }
6598     case ISD::OR:
6599     case ISD::XOR:
6600     case ISD::AND:
6601       if (!SearchForAndLoads(Op.getNode(), Loads, NodesWithConsts, Mask,
6602                              NodeToMask))
6603         return false;
6604       continue;
6605     }
6606 
6607     // Allow one node which will be masked along with any loads found.
6608     if (NodeToMask)
6609       return false;
6610 
6611     // Also ensure that the node to be masked only produces one data result.
6612     NodeToMask = Op.getNode();
6613     if (NodeToMask->getNumValues() > 1) {
6614       bool HasValue = false;
6615       for (unsigned i = 0, e = NodeToMask->getNumValues(); i < e; ++i) {
6616         MVT VT = SDValue(NodeToMask, i).getSimpleValueType();
6617         if (VT != MVT::Glue && VT != MVT::Other) {
6618           if (HasValue) {
6619             NodeToMask = nullptr;
6620             return false;
6621           }
6622           HasValue = true;
6623         }
6624       }
6625       assert(HasValue && "Node to be masked has no data result?");
6626     }
6627   }
6628   return true;
6629 }
6630 
6631 bool DAGCombiner::BackwardsPropagateMask(SDNode *N) {
6632   auto *Mask = dyn_cast<ConstantSDNode>(N->getOperand(1));
6633   if (!Mask)
6634     return false;
6635 
6636   if (!Mask->getAPIntValue().isMask())
6637     return false;
6638 
6639   // No need to do anything if the and directly uses a load.
6640   if (isa<LoadSDNode>(N->getOperand(0)))
6641     return false;
6642 
6643   SmallVector<LoadSDNode*, 8> Loads;
6644   SmallPtrSet<SDNode*, 2> NodesWithConsts;
6645   SDNode *FixupNode = nullptr;
6646   if (SearchForAndLoads(N, Loads, NodesWithConsts, Mask, FixupNode)) {
6647     if (Loads.empty())
6648       return false;
6649 
6650     LLVM_DEBUG(dbgs() << "Backwards propagate AND: "; N->dump());
6651     SDValue MaskOp = N->getOperand(1);
6652 
6653     // If it exists, fixup the single node we allow in the tree that needs
6654     // masking.
6655     if (FixupNode) {
6656       LLVM_DEBUG(dbgs() << "First, need to fix up: "; FixupNode->dump());
6657       SDValue And = DAG.getNode(ISD::AND, SDLoc(FixupNode),
6658                                 FixupNode->getValueType(0),
6659                                 SDValue(FixupNode, 0), MaskOp);
6660       DAG.ReplaceAllUsesOfValueWith(SDValue(FixupNode, 0), And);
6661       if (And.getOpcode() == ISD::AND)
6662         DAG.UpdateNodeOperands(And.getNode(), SDValue(FixupNode, 0), MaskOp);
6663     }
6664 
6665     // Narrow any constants that need it.
6666     for (auto *LogicN : NodesWithConsts) {
6667       SDValue Op0 = LogicN->getOperand(0);
6668       SDValue Op1 = LogicN->getOperand(1);
6669 
6670       if (isa<ConstantSDNode>(Op0))
6671         Op0 =
6672             DAG.getNode(ISD::AND, SDLoc(Op0), Op0.getValueType(), Op0, MaskOp);
6673 
6674       if (isa<ConstantSDNode>(Op1))
6675         Op1 =
6676             DAG.getNode(ISD::AND, SDLoc(Op1), Op1.getValueType(), Op1, MaskOp);
6677 
6678       if (isa<ConstantSDNode>(Op0) && !isa<ConstantSDNode>(Op1))
6679         std::swap(Op0, Op1);
6680 
6681       DAG.UpdateNodeOperands(LogicN, Op0, Op1);
6682     }
6683 
6684     // Create narrow loads.
6685     for (auto *Load : Loads) {
6686       LLVM_DEBUG(dbgs() << "Propagate AND back to: "; Load->dump());
6687       SDValue And = DAG.getNode(ISD::AND, SDLoc(Load), Load->getValueType(0),
6688                                 SDValue(Load, 0), MaskOp);
6689       DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 0), And);
6690       if (And.getOpcode() == ISD::AND)
6691         And = SDValue(
6692             DAG.UpdateNodeOperands(And.getNode(), SDValue(Load, 0), MaskOp), 0);
6693       SDValue NewLoad = reduceLoadWidth(And.getNode());
6694       assert(NewLoad &&
6695              "Shouldn't be masking the load if it can't be narrowed");
6696       CombineTo(Load, NewLoad, NewLoad.getValue(1));
6697     }
6698     DAG.ReplaceAllUsesWith(N, N->getOperand(0).getNode());
6699     return true;
6700   }
6701   return false;
6702 }
6703 
6704 // Unfold
6705 //    x &  (-1 'logical shift' y)
6706 // To
6707 //    (x 'opposite logical shift' y) 'logical shift' y
6708 // if it is better for performance.
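     // Illustrative instance (assumed, not in the original comments): with
     // y = 3, x & (-1 << 3) clears the three low bits, and so does
     // (x >> 3) << 3, without having to materialize the shifted all-ones mask.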
6709 SDValue DAGCombiner::unfoldExtremeBitClearingToShifts(SDNode *N) {
6710   assert(N->getOpcode() == ISD::AND);
6711 
6712   SDValue N0 = N->getOperand(0);
6713   SDValue N1 = N->getOperand(1);
6714 
6715   // Do we actually prefer shifts over mask?
6716   if (!TLI.shouldFoldMaskToVariableShiftPair(N0))
6717     return SDValue();
6718 
6719   // Try to match  (-1 '[outer] logical shift' y)
6720   unsigned OuterShift;
6721   unsigned InnerShift; // The opposite direction to the OuterShift.
6722   SDValue Y;           // Shift amount.
6723   auto matchMask = [&OuterShift, &InnerShift, &Y](SDValue M) -> bool {
6724     if (!M.hasOneUse())
6725       return false;
6726     OuterShift = M->getOpcode();
6727     if (OuterShift == ISD::SHL)
6728       InnerShift = ISD::SRL;
6729     else if (OuterShift == ISD::SRL)
6730       InnerShift = ISD::SHL;
6731     else
6732       return false;
6733     if (!isAllOnesConstant(M->getOperand(0)))
6734       return false;
6735     Y = M->getOperand(1);
6736     return true;
6737   };
6738 
6739   SDValue X;
6740   if (matchMask(N1))
6741     X = N0;
6742   else if (matchMask(N0))
6743     X = N1;
6744   else
6745     return SDValue();
6746 
6747   SDLoc DL(N);
6748   EVT VT = N->getValueType(0);
6749 
6750   //     tmp = x   'opposite logical shift' y
6751   SDValue T0 = DAG.getNode(InnerShift, DL, VT, X, Y);
6752   //     ret = tmp 'logical shift' y
6753   SDValue T1 = DAG.getNode(OuterShift, DL, VT, T0, Y);
6754 
6755   return T1;
6756 }
6757 
6758 /// Try to replace shift/logic that tests if a bit is clear with mask + setcc.
6759 /// For a target with a bit test, this is expected to become test + set and save
6760 /// at least 1 instruction.
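     /// For instance, a target with a bit-test instruction (x86's BT being one
     /// example) may lower the resulting (and X, 1<<C) == 0 to a single bit
     /// test plus a flag-to-register copy rather than shift+not+and.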
6761 static SDValue combineShiftAnd1ToBitTest(SDNode *And, SelectionDAG &DAG) {
6762   assert(And->getOpcode() == ISD::AND && "Expected an 'and' op");
6763 
6764   // Look through an optional extension.
6765   SDValue And0 = And->getOperand(0), And1 = And->getOperand(1);
6766   if (And0.getOpcode() == ISD::ANY_EXTEND && And0.hasOneUse())
6767     And0 = And0.getOperand(0);
6768   if (!isOneConstant(And1) || !And0.hasOneUse())
6769     return SDValue();
6770 
6771   SDValue Src = And0;
6772 
6773   // Attempt to find a 'not' op.
6774   // TODO: Should we favor test+set even without the 'not' op?
6775   bool FoundNot = false;
6776   if (isBitwiseNot(Src)) {
6777     FoundNot = true;
6778     Src = Src.getOperand(0);
6779 
6780     // Look through an optional truncation. The source operand may not be the
6781     // same type as the original 'and', but that is ok because we are masking
6782     // off everything but the low bit.
6783     if (Src.getOpcode() == ISD::TRUNCATE && Src.hasOneUse())
6784       Src = Src.getOperand(0);
6785   }
6786 
6787   // Match a shift-right by constant.
6788   if (Src.getOpcode() != ISD::SRL || !Src.hasOneUse())
6789     return SDValue();
6790 
6791   // This is probably not worthwhile without a supported type.
6792   EVT SrcVT = Src.getValueType();
6793   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
6794   if (!TLI.isTypeLegal(SrcVT))
6795     return SDValue();
6796 
6797   // We might have looked through casts that make this transform invalid.
6798   unsigned BitWidth = SrcVT.getScalarSizeInBits();
6799   SDValue ShiftAmt = Src.getOperand(1);
6800   auto *ShiftAmtC = dyn_cast<ConstantSDNode>(ShiftAmt);
6801   if (!ShiftAmtC || !ShiftAmtC->getAPIntValue().ult(BitWidth))
6802     return SDValue();
6803 
6804   // Set source to shift source.
6805   Src = Src.getOperand(0);
6806 
6807   // Try again to find a 'not' op.
6808   // TODO: Should we favor test+set even with two 'not' ops?
6809   if (!FoundNot) {
6810     if (!isBitwiseNot(Src))
6811       return SDValue();
6812     Src = Src.getOperand(0);
6813   }
6814 
6815   if (!TLI.hasBitTest(Src, ShiftAmt))
6816     return SDValue();
6817 
6818   // Turn this into a bit-test pattern using mask op + setcc:
6819   // and (not (srl X, C)), 1 --> (and X, 1<<C) == 0
6820   // and (srl (not X), C), 1 --> (and X, 1<<C) == 0
6821   SDLoc DL(And);
6822   SDValue X = DAG.getZExtOrTrunc(Src, DL, SrcVT);
6823   EVT CCVT =
6824       TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), SrcVT);
6825   SDValue Mask = DAG.getConstant(
6826       APInt::getOneBitSet(BitWidth, ShiftAmtC->getZExtValue()), DL, SrcVT);
6827   SDValue NewAnd = DAG.getNode(ISD::AND, DL, SrcVT, X, Mask);
6828   SDValue Zero = DAG.getConstant(0, DL, SrcVT);
6829   SDValue Setcc = DAG.getSetCC(DL, CCVT, NewAnd, Zero, ISD::SETEQ);
6830   return DAG.getZExtOrTrunc(Setcc, DL, And->getValueType(0));
6831 }
6832 
6833 /// For targets that support usubsat, match a bit-hack form of that operation
6834 /// that ends in 'and' and convert it.
6835 static SDValue foldAndToUsubsat(SDNode *N, SelectionDAG &DAG, const SDLoc &DL) {
6836   EVT VT = N->getValueType(0);
6837   unsigned BitWidth = VT.getScalarSizeInBits();
6838   APInt SignMask = APInt::getSignMask(BitWidth);
6839 
6840   // (i8 X ^ 128) & (i8 X s>> 7) --> usubsat X, 128
6841   // (i8 X + 128) & (i8 X s>> 7) --> usubsat X, 128
6842   // xor/add with SMIN (signmask) are logically equivalent.
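       // Spot check of the equivalence for i8 (values assumed): X = 200 gives
       // (200 ^ 128) = 72 and (200 s>> 7) = 0xFF, so the AND is 72, matching
       // usubsat(200, 128); X = 50 gives (50 s>> 7) = 0, so the AND is 0,
       // matching usubsat(50, 128).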
6843   SDValue X;
6844   if (!sd_match(N, m_And(m_OneUse(m_Xor(m_Value(X), m_SpecificInt(SignMask))),
6845                          m_OneUse(m_Sra(m_Deferred(X),
6846                                         m_SpecificInt(BitWidth - 1))))) &&
6847       !sd_match(N, m_And(m_OneUse(m_Add(m_Value(X), m_SpecificInt(SignMask))),
6848                          m_OneUse(m_Sra(m_Deferred(X),
6849                                         m_SpecificInt(BitWidth - 1))))))
6850     return SDValue();
6851 
6852   return DAG.getNode(ISD::USUBSAT, DL, VT, X,
6853                      DAG.getConstant(SignMask, DL, VT));
6854 }
6855 
6856 /// Given a bitwise logic operation N with a matching bitwise logic operand,
6857 /// fold a pattern where 2 of the source operands are identically shifted
6858 /// values. For example:
6859 /// ((X0 << Y) | Z) | (X1 << Y) --> ((X0 | X1) << Y) | Z
6860 static SDValue foldLogicOfShifts(SDNode *N, SDValue LogicOp, SDValue ShiftOp,
6861                                  SelectionDAG &DAG) {
6862   unsigned LogicOpcode = N->getOpcode();
6863   assert(ISD::isBitwiseLogicOp(LogicOpcode) &&
6864          "Expected bitwise logic operation");
6865 
6866   if (!LogicOp.hasOneUse() || !ShiftOp.hasOneUse())
6867     return SDValue();
6868 
6869   // Match another bitwise logic op and a shift.
6870   unsigned ShiftOpcode = ShiftOp.getOpcode();
6871   if (LogicOp.getOpcode() != LogicOpcode ||
6872       !(ShiftOpcode == ISD::SHL || ShiftOpcode == ISD::SRL ||
6873         ShiftOpcode == ISD::SRA))
6874     return SDValue();
6875 
6876   // Match another shift op inside the first logic operand. Handle both commuted
6877   // possibilities.
6878   // LOGIC (LOGIC (SH X0, Y), Z), (SH X1, Y) --> LOGIC (SH (LOGIC X0, X1), Y), Z
6879   // LOGIC (LOGIC Z, (SH X0, Y)), (SH X1, Y) --> LOGIC (SH (LOGIC X0, X1), Y), Z
6880   SDValue X1 = ShiftOp.getOperand(0);
6881   SDValue Y = ShiftOp.getOperand(1);
6882   SDValue X0, Z;
6883   if (LogicOp.getOperand(0).getOpcode() == ShiftOpcode &&
6884       LogicOp.getOperand(0).getOperand(1) == Y) {
6885     X0 = LogicOp.getOperand(0).getOperand(0);
6886     Z = LogicOp.getOperand(1);
6887   } else if (LogicOp.getOperand(1).getOpcode() == ShiftOpcode &&
6888              LogicOp.getOperand(1).getOperand(1) == Y) {
6889     X0 = LogicOp.getOperand(1).getOperand(0);
6890     Z = LogicOp.getOperand(0);
6891   } else {
6892     return SDValue();
6893   }
6894 
6895   EVT VT = N->getValueType(0);
6896   SDLoc DL(N);
6897   SDValue LogicX = DAG.getNode(LogicOpcode, DL, VT, X0, X1);
6898   SDValue NewShift = DAG.getNode(ShiftOpcode, DL, VT, LogicX, Y);
6899   return DAG.getNode(LogicOpcode, DL, VT, NewShift, Z);
6900 }
6901 
6902 /// Given a tree of logic operations with shape like
6903 /// (LOGIC (LOGIC (X, Y), LOGIC (Z, Y)))
6904 /// try to match and fold shift operations with the same shift amount.
6905 /// For example:
6906 /// LOGIC (LOGIC (SH X0, Y), Z), (LOGIC (SH X1, Y), W) -->
6907 /// --> LOGIC (SH (LOGIC X0, X1), Y), (LOGIC Z, W)
6908 static SDValue foldLogicTreeOfShifts(SDNode *N, SDValue LeftHand,
6909                                      SDValue RightHand, SelectionDAG &DAG) {
6910   unsigned LogicOpcode = N->getOpcode();
6911   assert(ISD::isBitwiseLogicOp(LogicOpcode) &&
6912          "Expected bitwise logic operation");
6913   if (LeftHand.getOpcode() != LogicOpcode ||
6914       RightHand.getOpcode() != LogicOpcode)
6915     return SDValue();
6916   if (!LeftHand.hasOneUse() || !RightHand.hasOneUse())
6917     return SDValue();
6918 
6919   // Try to match one of the following patterns:
6920   // LOGIC (LOGIC (SH X0, Y), Z), (LOGIC (SH X1, Y), W)
6921   // LOGIC (LOGIC (SH X0, Y), Z), (LOGIC W, (SH X1, Y))
6922   // Note that foldLogicOfShifts will handle commuted versions of the left hand
6923   // itself.
6924   SDValue CombinedShifts, W;
6925   SDValue R0 = RightHand.getOperand(0);
6926   SDValue R1 = RightHand.getOperand(1);
6927   if ((CombinedShifts = foldLogicOfShifts(N, LeftHand, R0, DAG)))
6928     W = R1;
6929   else if ((CombinedShifts = foldLogicOfShifts(N, LeftHand, R1, DAG)))
6930     W = R0;
6931   else
6932     return SDValue();
6933 
6934   EVT VT = N->getValueType(0);
6935   SDLoc DL(N);
6936   return DAG.getNode(LogicOpcode, DL, VT, CombinedShifts, W);
6937 }
6938 
6939 SDValue DAGCombiner::visitAND(SDNode *N) {
6940   SDValue N0 = N->getOperand(0);
6941   SDValue N1 = N->getOperand(1);
6942   EVT VT = N1.getValueType();
6943   SDLoc DL(N);
6944 
6945   // x & x --> x
6946   if (N0 == N1)
6947     return N0;
6948 
6949   // fold (and c1, c2) -> c1&c2
6950   if (SDValue C = DAG.FoldConstantArithmetic(ISD::AND, DL, VT, {N0, N1}))
6951     return C;
6952 
6953   // canonicalize constant to RHS
6954   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
6955       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
6956     return DAG.getNode(ISD::AND, DL, VT, N1, N0);
6957 
6958   if (areBitwiseNotOfEachother(N0, N1))
6959     return DAG.getConstant(APInt::getZero(VT.getScalarSizeInBits()), DL, VT);
6960 
6961   // fold vector ops
6962   if (VT.isVector()) {
6963     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
6964       return FoldedVOp;
6965 
6966     // fold (and x, 0) -> 0, vector edition
6967     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
6968       // do not return N1, because an undef node may exist in N1
6969       return DAG.getConstant(APInt::getZero(N1.getScalarValueSizeInBits()), DL,
6970                              N1.getValueType());
6971 
6972     // fold (and x, -1) -> x, vector edition
6973     if (ISD::isConstantSplatVectorAllOnes(N1.getNode()))
6974       return N0;
6975 
6976     // fold (and (masked_load) (splat_vec (x, ...))) to zext_masked_load
6977     auto *MLoad = dyn_cast<MaskedLoadSDNode>(N0);
6978     ConstantSDNode *Splat = isConstOrConstSplat(N1, true, true);
6979     if (MLoad && MLoad->getExtensionType() == ISD::EXTLOAD && Splat &&
6980         N1.hasOneUse()) {
6981       EVT LoadVT = MLoad->getMemoryVT();
6982       EVT ExtVT = VT;
6983       if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, ExtVT, LoadVT)) {
6984         // For this AND to be a zero extension of the masked load the elements
6985         // of the BuildVec must mask the bottom bits of the extended element
6986         // type
6987         uint64_t ElementSize =
6988             LoadVT.getVectorElementType().getScalarSizeInBits();
6989         if (Splat->getAPIntValue().isMask(ElementSize)) {
6990           SDValue NewLoad = DAG.getMaskedLoad(
6991               ExtVT, DL, MLoad->getChain(), MLoad->getBasePtr(),
6992               MLoad->getOffset(), MLoad->getMask(), MLoad->getPassThru(),
6993               LoadVT, MLoad->getMemOperand(), MLoad->getAddressingMode(),
6994               ISD::ZEXTLOAD, MLoad->isExpandingLoad());
6995           bool LoadHasOtherUsers = !N0.hasOneUse();
6996           CombineTo(N, NewLoad);
6997           if (LoadHasOtherUsers)
6998             CombineTo(MLoad, NewLoad.getValue(0), NewLoad.getValue(1));
6999           return SDValue(N, 0);
7000         }
7001       }
7002     }
7003   }
7004 
7005   // fold (and x, -1) -> x
7006   if (isAllOnesConstant(N1))
7007     return N0;
7008 
7009   // if (and x, c) is known to be zero, return 0
7010   unsigned BitWidth = VT.getScalarSizeInBits();
7011   ConstantSDNode *N1C = isConstOrConstSplat(N1);
7012   if (N1C && DAG.MaskedValueIsZero(SDValue(N, 0), APInt::getAllOnes(BitWidth)))
7013     return DAG.getConstant(0, DL, VT);
7014 
7015   if (SDValue R = foldAndOrOfSETCC(N, DAG))
7016     return R;
7017 
7018   if (SDValue NewSel = foldBinOpIntoSelect(N))
7019     return NewSel;
7020 
7021   // reassociate and
7022   if (SDValue RAND = reassociateOps(ISD::AND, DL, N0, N1, N->getFlags()))
7023     return RAND;
7024 
7025   // Fold and(vecreduce(x), vecreduce(y)) -> vecreduce(and(x, y))
7026   if (SDValue SD =
7027           reassociateReduction(ISD::VECREDUCE_AND, ISD::AND, DL, VT, N0, N1))
7028     return SD;
7029 
7030   // fold (and (or x, C), D) -> D if (C & D) == D
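       // e.g. with assumed constants C = 0xFF, D = 0x0F: (and (or x, 0xFF),
       // 0x0F) is always 0x0F, because the OR sets every bit that D keeps.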
7031   auto MatchSubset = [](ConstantSDNode *LHS, ConstantSDNode *RHS) {
7032     return RHS->getAPIntValue().isSubsetOf(LHS->getAPIntValue());
7033   };
7034   if (N0.getOpcode() == ISD::OR &&
7035       ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchSubset))
7036     return N1;
7037 
7038   if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) {
7039     SDValue N0Op0 = N0.getOperand(0);
7040     EVT SrcVT = N0Op0.getValueType();
7041     unsigned SrcBitWidth = SrcVT.getScalarSizeInBits();
7042     APInt Mask = ~N1C->getAPIntValue();
7043     Mask = Mask.trunc(SrcBitWidth);
7044 
7045     // fold (and (any_ext V), c) -> (zero_ext V) if 'and' only clears top bits.
7046     if (DAG.MaskedValueIsZero(N0Op0, Mask))
7047       return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0Op0);
7048 
7049     // fold (and (any_ext V), c) -> (zero_ext (and (trunc V), c)) if profitable.
7050     if (N1C->getAPIntValue().countLeadingZeros() >= (BitWidth - SrcBitWidth) &&
7051         TLI.isTruncateFree(VT, SrcVT) && TLI.isZExtFree(SrcVT, VT) &&
7052         TLI.isTypeDesirableForOp(ISD::AND, SrcVT) &&
7053         TLI.isNarrowingProfitable(VT, SrcVT))
7054       return DAG.getNode(ISD::ZERO_EXTEND, DL, VT,
7055                          DAG.getNode(ISD::AND, DL, SrcVT, N0Op0,
7056                                      DAG.getZExtOrTrunc(N1, DL, SrcVT)));
7057   }
7058 
7059   // fold (and (ext (and V, c1)), c2) -> (and (ext V), (and c1, (ext c2)))
7060   if (ISD::isExtOpcode(N0.getOpcode())) {
7061     unsigned ExtOpc = N0.getOpcode();
7062     SDValue N0Op0 = N0.getOperand(0);
7063     if (N0Op0.getOpcode() == ISD::AND &&
7064         (ExtOpc != ISD::ZERO_EXTEND || !TLI.isZExtFree(N0Op0, VT)) &&
7065         DAG.isConstantIntBuildVectorOrConstantInt(N1) &&
7066         DAG.isConstantIntBuildVectorOrConstantInt(N0Op0.getOperand(1)) &&
7067         N0->hasOneUse() && N0Op0->hasOneUse()) {
7068       SDValue NewMask =
7069           DAG.getNode(ISD::AND, DL, VT, N1,
7070                       DAG.getNode(ExtOpc, DL, VT, N0Op0.getOperand(1)));
7071       return DAG.getNode(ISD::AND, DL, VT,
7072                          DAG.getNode(ExtOpc, DL, VT, N0Op0.getOperand(0)),
7073                          NewMask);
7074     }
7075   }
7076 
7077   // Similarly, fold (and (X (load ([non_ext|any_ext|zero_ext] V))), c) ->
7078   // (X (load ([non_ext|zero_ext] V))) if 'and' only clears top bits which must
7079   // already be zero by virtue of the width of the base type of the load.
7080   //
7081   // the 'X' node here can either be nothing or an extract_vector_elt to catch
7082   // more cases.
7083   if ((N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
7084        N0.getValueSizeInBits() == N0.getOperand(0).getScalarValueSizeInBits() &&
7085        N0.getOperand(0).getOpcode() == ISD::LOAD &&
7086        N0.getOperand(0).getResNo() == 0) ||
7087       (N0.getOpcode() == ISD::LOAD && N0.getResNo() == 0)) {
7088     auto *Load =
7089         cast<LoadSDNode>((N0.getOpcode() == ISD::LOAD) ? N0 : N0.getOperand(0));
7090 
7091     // Get the constant (if applicable) the zero'th operand is being ANDed with.
7092     // This can be a pure constant or a vector splat, in which case we treat the
7093     // vector as a scalar and use the splat value.
7094     APInt Constant = APInt::getZero(1);
7095     if (const ConstantSDNode *C = isConstOrConstSplat(
7096             N1, /*AllowUndef=*/false, /*AllowTruncation=*/true)) {
7097       Constant = C->getAPIntValue();
7098     } else if (BuildVectorSDNode *Vector = dyn_cast<BuildVectorSDNode>(N1)) {
7099       unsigned EltBitWidth = Vector->getValueType(0).getScalarSizeInBits();
7100       APInt SplatValue, SplatUndef;
7101       unsigned SplatBitSize;
7102       bool HasAnyUndefs;
7103       // Endianness should not matter here. The code below makes sure that we
7104       // only use the result if SplatBitSize is a multiple of the vector
7105       // element size, and after that we AND all element-sized parts of the
7106       // splat together, so the end result is the same regardless of the
7107       // order in which those operations are done.
7108       const bool IsBigEndian = false;
7109       bool IsSplat =
7110           Vector->isConstantSplat(SplatValue, SplatUndef, SplatBitSize,
7111                                   HasAnyUndefs, EltBitWidth, IsBigEndian);
7112 
7113       // Make sure that variable 'Constant' is only set if 'SplatBitSize' is a
7114       // multiple of 'BitWidth'. Otherwise, we could propagate a wrong value.
7115       if (IsSplat && (SplatBitSize % EltBitWidth) == 0) {
7116         // Undef bits can contribute to a possible optimisation if set, so
7117         // set them.
7118         SplatValue |= SplatUndef;
7119 
7120         // The splat value may be something like "0x00FFFFFF", which means 0 for
7121         // the first vector value and FF for the rest, repeating. We need a mask
7122         // that will apply equally to all members of the vector, so AND all the
7123         // lanes of the constant together.
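             // Example with assumed values: a 16-bit splat of 0x00FF over i8
             // elements gives Constant = 0xFF & 0x00 = 0x00, the only mask
             // that is correct for every lane.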
7124         Constant = APInt::getAllOnes(EltBitWidth);
7125         for (unsigned i = 0, n = (SplatBitSize / EltBitWidth); i < n; ++i)
7126           Constant &= SplatValue.extractBits(EltBitWidth, i * EltBitWidth);
7127       }
7128     }
7129 
7130     // If we want to change an EXTLOAD to a ZEXTLOAD, ensure a ZEXTLOAD is
7131     // actually legal and isn't going to get expanded, else this is a false
7132     // optimisation.
7133     bool CanZextLoadProfitably = TLI.isLoadExtLegal(ISD::ZEXTLOAD,
7134                                                     Load->getValueType(0),
7135                                                     Load->getMemoryVT());
7136 
7137     // Resize the constant to the same size as the original memory access before
7138     // extension. If it is still the AllOnesValue then this AND is completely
7139     // unneeded.
7140     Constant = Constant.zextOrTrunc(Load->getMemoryVT().getScalarSizeInBits());
7141 
7142     bool B;
7143     switch (Load->getExtensionType()) {
7144     default: B = false; break;
7145     case ISD::EXTLOAD: B = CanZextLoadProfitably; break;
7146     case ISD::ZEXTLOAD:
7147     case ISD::NON_EXTLOAD: B = true; break;
7148     }
7149 
7150     if (B && Constant.isAllOnes()) {
7151       // If the load type was an EXTLOAD, convert to ZEXTLOAD in order to
7152       // preserve semantics once we get rid of the AND.
7153       SDValue NewLoad(Load, 0);
7154 
7155       // Fold the AND away. NewLoad may get replaced immediately.
7156       CombineTo(N, (N0.getNode() == Load) ? NewLoad : N0);
7157 
7158       if (Load->getExtensionType() == ISD::EXTLOAD) {
7159         NewLoad = DAG.getLoad(Load->getAddressingMode(), ISD::ZEXTLOAD,
7160                               Load->getValueType(0), SDLoc(Load),
7161                               Load->getChain(), Load->getBasePtr(),
7162                               Load->getOffset(), Load->getMemoryVT(),
7163                               Load->getMemOperand());
7164         // Replace uses of the EXTLOAD with the new ZEXTLOAD.
7165         if (Load->getNumValues() == 3) {
7166           // PRE/POST_INC loads have 3 values.
7167           SDValue To[] = { NewLoad.getValue(0), NewLoad.getValue(1),
7168                            NewLoad.getValue(2) };
7169           CombineTo(Load, To, 3, true);
7170         } else {
7171           CombineTo(Load, NewLoad.getValue(0), NewLoad.getValue(1));
7172         }
7173       }
7174 
7175       return SDValue(N, 0); // Return N so it doesn't get rechecked!
7176     }
7177   }
7178 
7179   // Try to convert a constant mask AND into a shuffle clear mask.
7180   if (VT.isVector())
7181     if (SDValue Shuffle = XformToShuffleWithZero(N))
7182       return Shuffle;
7183 
7184   if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
7185     return Combined;
7186 
7187   if (N0.getOpcode() == ISD::EXTRACT_SUBVECTOR && N0.hasOneUse() && N1C &&
7188       ISD::isExtOpcode(N0.getOperand(0).getOpcode())) {
7189     SDValue Ext = N0.getOperand(0);
7190     EVT ExtVT = Ext->getValueType(0);
7191     SDValue Extendee = Ext->getOperand(0);
7192 
7193     unsigned ScalarWidth = Extendee.getValueType().getScalarSizeInBits();
7194     if (N1C->getAPIntValue().isMask(ScalarWidth) &&
7195         (!LegalOperations || TLI.isOperationLegal(ISD::ZERO_EXTEND, ExtVT))) {
7196       //    (and (extract_subvector (zext|anyext|sext v) _) iN_mask)
7197       // => (extract_subvector (iN_zeroext v))
7198       SDValue ZeroExtExtendee =
7199           DAG.getNode(ISD::ZERO_EXTEND, DL, ExtVT, Extendee);
7200 
7201       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, ZeroExtExtendee,
7202                          N0.getOperand(1));
7203     }
7204   }
7205 
7206   // fold (and (masked_gather x)) -> (zext_masked_gather x)
7207   if (auto *GN0 = dyn_cast<MaskedGatherSDNode>(N0)) {
7208     EVT MemVT = GN0->getMemoryVT();
7209     EVT ScalarVT = MemVT.getScalarType();
7210 
7211     if (SDValue(GN0, 0).hasOneUse() &&
7212         isConstantSplatVectorMaskForType(N1.getNode(), ScalarVT) &&
7213         TLI.isVectorLoadExtDesirable(SDValue(SDValue(GN0, 0)))) {
7214       SDValue Ops[] = {GN0->getChain(),   GN0->getPassThru(), GN0->getMask(),
7215                        GN0->getBasePtr(), GN0->getIndex(),    GN0->getScale()};
7216 
7217       SDValue ZExtLoad = DAG.getMaskedGather(
7218           DAG.getVTList(VT, MVT::Other), MemVT, DL, Ops, GN0->getMemOperand(),
7219           GN0->getIndexType(), ISD::ZEXTLOAD);
7220 
7221       CombineTo(N, ZExtLoad);
7222       AddToWorklist(ZExtLoad.getNode());
7223       // Avoid recheck of N.
7224       return SDValue(N, 0);
7225     }
7226   }
7227 
7228   // fold (and (load x), 255) -> (zextload x, i8)
7229   // fold (and (extload x, i16), 255) -> (zextload x, i8)
7230   if (N1C && N0.getOpcode() == ISD::LOAD && !VT.isVector())
7231     if (SDValue Res = reduceLoadWidth(N))
7232       return Res;
7233 
7234   if (LegalTypes) {
7235     // Attempt to propagate the AND back up to the leaves which, if they're
7236     // loads, can be combined to narrow loads and the AND node can be removed.
7237     // Perform after legalization so that extend nodes will already be
7238     // combined into the loads.
7239     if (BackwardsPropagateMask(N))
7240       return SDValue(N, 0);
7241   }
7242 
7243   if (SDValue Combined = visitANDLike(N0, N1, N))
7244     return Combined;
7245 
7246   // Simplify: (and (op x...), (op y...))  -> (op (and x, y))
7247   if (N0.getOpcode() == N1.getOpcode())
7248     if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
7249       return V;
7250 
7251   if (SDValue R = foldLogicOfShifts(N, N0, N1, DAG))
7252     return R;
7253   if (SDValue R = foldLogicOfShifts(N, N1, N0, DAG))
7254     return R;
7255 
7256   // Masking the negated extension of a boolean is just the zero-extended
7257   // boolean:
7258   // and (sub 0, zext(bool X)), 1 --> zext(bool X)
7259   // and (sub 0, sext(bool X)), 1 --> zext(bool X)
7260   //
7261   // Note: the SimplifyDemandedBits fold below can make an information-losing
7262   // transform, and then we have no way to find this better fold.
7263   if (N1C && N1C->isOne() && N0.getOpcode() == ISD::SUB) {
7264     if (isNullOrNullSplat(N0.getOperand(0))) {
7265       SDValue SubRHS = N0.getOperand(1);
7266       if (SubRHS.getOpcode() == ISD::ZERO_EXTEND &&
7267           SubRHS.getOperand(0).getScalarValueSizeInBits() == 1)
7268         return SubRHS;
7269       if (SubRHS.getOpcode() == ISD::SIGN_EXTEND &&
7270           SubRHS.getOperand(0).getScalarValueSizeInBits() == 1)
7271         return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, SubRHS.getOperand(0));
7272     }
7273   }
7274 
7275   // fold (and (sign_extend_inreg x, i16 to i32), 1) -> (and x, 1)
7276   // fold (and (sra)) -> (and (srl)) when possible.
7277   if (SimplifyDemandedBits(SDValue(N, 0)))
7278     return SDValue(N, 0);
7279 
7280   // fold (zext_inreg (extload x)) -> (zextload x)
7281   // fold (zext_inreg (sextload x)) -> (zextload x) iff load has one use
7282   if (ISD::isUNINDEXEDLoad(N0.getNode()) &&
7283       (ISD::isEXTLoad(N0.getNode()) ||
7284        (ISD::isSEXTLoad(N0.getNode()) && N0.hasOneUse()))) {
7285     auto *LN0 = cast<LoadSDNode>(N0);
7286     EVT MemVT = LN0->getMemoryVT();
7287     // If we zero all the possible extended bits, then we can turn this into
7288     // a zextload if we are running before legalize or the operation is legal.
7289     unsigned ExtBitSize = N1.getScalarValueSizeInBits();
7290     unsigned MemBitSize = MemVT.getScalarSizeInBits();
7291     APInt ExtBits = APInt::getHighBitsSet(ExtBitSize, ExtBitSize - MemBitSize);
7292     if (DAG.MaskedValueIsZero(N1, ExtBits) &&
7293         ((!LegalOperations && LN0->isSimple()) ||
7294          TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT))) {
7295       SDValue ExtLoad =
7296           DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(N0), VT, LN0->getChain(),
7297                          LN0->getBasePtr(), MemVT, LN0->getMemOperand());
7298       AddToWorklist(N);
7299       CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
7300       return SDValue(N, 0); // Return N so it doesn't get rechecked!
7301     }
7302   }
7303 
7304   // fold (and (or (srl N, 8), (shl N, 8)), 0xffff) -> (srl (bswap N), const)
7305   if (N1C && N1C->getAPIntValue() == 0xffff && N0.getOpcode() == ISD::OR) {
7306     if (SDValue BSwap = MatchBSwapHWordLow(N0.getNode(), N0.getOperand(0),
7307                                            N0.getOperand(1), false))
7308       return BSwap;
7309   }
7310 
7311   if (SDValue Shifts = unfoldExtremeBitClearingToShifts(N))
7312     return Shifts;
7313 
7314   if (SDValue V = combineShiftAnd1ToBitTest(N, DAG))
7315     return V;
7316 
7317   // Recognize the following pattern:
7318   //
7319   // AndVT = (and (sign_extend NarrowVT to AndVT) #bitmask)
7320   //
7321   // where bitmask is a mask that clears the upper bits of AndVT. The
7322   // number of bits in bitmask must be a power of two.
7323   auto IsAndZeroExtMask = [](SDValue LHS, SDValue RHS) {
7324     if (LHS->getOpcode() != ISD::SIGN_EXTEND)
7325       return false;
7326 
7327     auto *C = dyn_cast<ConstantSDNode>(RHS);
7328     if (!C)
7329       return false;
7330 
7331     if (!C->getAPIntValue().isMask(
7332             LHS.getOperand(0).getValueType().getFixedSizeInBits()))
7333       return false;
7334 
7335     return true;
7336   };
7337 
7338   // Replace (and (sign_extend ...) #bitmask) with (zero_extend ...).
7339   if (IsAndZeroExtMask(N0, N1))
7340     return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0));
7341 
7342   if (hasOperation(ISD::USUBSAT, VT))
7343     if (SDValue V = foldAndToUsubsat(N, DAG, DL))
7344       return V;
7345 
7346   // Postpone until legalization completed to avoid interference with bswap
7347   // folding
7348   if (LegalOperations || VT.isVector())
7349     if (SDValue R = foldLogicTreeOfShifts(N, N0, N1, DAG))
7350       return R;
7351 
7352   return SDValue();
7353 }
7354 
7355 /// Match (a >> 8) | (a << 8) as (bswap a) >> 16.
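     /// For example (illustrative, i32): a = 0x00001234 gives
     /// ((a & 0xff00) >> 8) | ((a & 0xff) << 8) = 0x12 | 0x3400 = 0x3412,
     /// and (bswap a) >> 16 = 0x34120000 >> 16 = 0x3412 as well.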
7356 SDValue DAGCombiner::MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
7357                                         bool DemandHighBits) {
7358   if (!LegalOperations)
7359     return SDValue();
7360 
7361   EVT VT = N->getValueType(0);
7362   if (VT != MVT::i64 && VT != MVT::i32 && VT != MVT::i16)
7363     return SDValue();
7364   if (!TLI.isOperationLegalOrCustom(ISD::BSWAP, VT))
7365     return SDValue();
7366 
7367   // Recognize (and (shl a, 8), 0xff00), (and (srl a, 8), 0xff)
7368   bool LookPassAnd0 = false;
7369   bool LookPassAnd1 = false;
7370   if (N0.getOpcode() == ISD::AND && N0.getOperand(0).getOpcode() == ISD::SRL)
7371     std::swap(N0, N1);
7372   if (N1.getOpcode() == ISD::AND && N1.getOperand(0).getOpcode() == ISD::SHL)
7373     std::swap(N0, N1);
7374   if (N0.getOpcode() == ISD::AND) {
7375     if (!N0->hasOneUse())
7376       return SDValue();
7377     ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
7378     // Also handle 0xffff since the LHS is guaranteed to have zeros there.
7379     // This is needed for X86.
7380     if (!N01C || (N01C->getZExtValue() != 0xFF00 &&
7381                   N01C->getZExtValue() != 0xFFFF))
7382       return SDValue();
7383     N0 = N0.getOperand(0);
7384     LookPassAnd0 = true;
7385   }
7386 
7387   if (N1.getOpcode() == ISD::AND) {
7388     if (!N1->hasOneUse())
7389       return SDValue();
7390     ConstantSDNode *N11C = dyn_cast<ConstantSDNode>(N1.getOperand(1));
7391     if (!N11C || N11C->getZExtValue() != 0xFF)
7392       return SDValue();
7393     N1 = N1.getOperand(0);
7394     LookPassAnd1 = true;
7395   }
7396 
7397   if (N0.getOpcode() == ISD::SRL && N1.getOpcode() == ISD::SHL)
7398     std::swap(N0, N1);
7399   if (N0.getOpcode() != ISD::SHL || N1.getOpcode() != ISD::SRL)
7400     return SDValue();
7401   if (!N0->hasOneUse() || !N1->hasOneUse())
7402     return SDValue();
7403 
7404   ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
7405   ConstantSDNode *N11C = dyn_cast<ConstantSDNode>(N1.getOperand(1));
7406   if (!N01C || !N11C)
7407     return SDValue();
7408   if (N01C->getZExtValue() != 8 || N11C->getZExtValue() != 8)
7409     return SDValue();
7410 
7411   // Look for (shl (and a, 0xff), 8), (srl (and a, 0xff00), 8)
7412   SDValue N00 = N0->getOperand(0);
7413   if (!LookPassAnd0 && N00.getOpcode() == ISD::AND) {
7414     if (!N00->hasOneUse())
7415       return SDValue();
7416     ConstantSDNode *N001C = dyn_cast<ConstantSDNode>(N00.getOperand(1));
7417     if (!N001C || N001C->getZExtValue() != 0xFF)
7418       return SDValue();
7419     N00 = N00.getOperand(0);
7420     LookPassAnd0 = true;
7421   }
7422 
7423   SDValue N10 = N1->getOperand(0);
7424   if (!LookPassAnd1 && N10.getOpcode() == ISD::AND) {
7425     if (!N10->hasOneUse())
7426       return SDValue();
7427     ConstantSDNode *N101C = dyn_cast<ConstantSDNode>(N10.getOperand(1));
7428     // Also allow 0xFFFF since the bits will be shifted out. This is needed
7429     // for X86.
7430     if (!N101C || (N101C->getZExtValue() != 0xFF00 &&
7431                    N101C->getZExtValue() != 0xFFFF))
7432       return SDValue();
7433     N10 = N10.getOperand(0);
7434     LookPassAnd1 = true;
7435   }
7436 
7437   if (N00 != N10)
7438     return SDValue();
7439 
7440   // Make sure everything beyond the low halfword gets set to zero since the SRL
7441   // 16 will clear the top bits.
7442   unsigned OpSizeInBits = VT.getSizeInBits();
7443   if (OpSizeInBits > 16) {
7444     // If the left-shift isn't masked out then the only way this is a bswap is
7445     // if all bits beyond the low 8 are 0. In that case the entire pattern
7446     // reduces to a left shift anyway: leave it for other parts of the combiner.
7447     if (DemandHighBits && !LookPassAnd0)
7448       return SDValue();
7449 
7450     // However, if the right shift isn't masked out then it might be because
7451     // it's not needed. See if we can spot that too. If the high bits aren't
7452     // demanded, we only need bits 23:16 to be zero. Otherwise, we need all
7453     // upper bits to be zero.
7454     if (!LookPassAnd1) {
7455       unsigned HighBit = DemandHighBits ? OpSizeInBits : 24;
7456       if (!DAG.MaskedValueIsZero(N10,
7457                                  APInt::getBitsSet(OpSizeInBits, 16, HighBit)))
7458         return SDValue();
7459     }
7460   }
7461 
7462   SDValue Res = DAG.getNode(ISD::BSWAP, SDLoc(N), VT, N00);
7463   if (OpSizeInBits > 16) {
7464     SDLoc DL(N);
7465     Res = DAG.getNode(ISD::SRL, DL, VT, Res,
7466                       DAG.getShiftAmountConstant(OpSizeInBits - 16, VT, DL));
7467   }
7468   return Res;
7469 }
7470 
7471 /// Return true if the specified node is an element that makes up a 32-bit
7472 /// packed halfword byteswap.
7473 /// ((x & 0x000000ff) << 8) |
7474 /// ((x & 0x0000ff00) >> 8) |
7475 /// ((x & 0x00ff0000) << 8) |
7476 /// ((x & 0xff000000) >> 8)
7477 static bool isBSwapHWordElement(SDValue N, MutableArrayRef<SDNode *> Parts) {
7478   if (!N->hasOneUse())
7479     return false;
7480 
7481   unsigned Opc = N.getOpcode();
7482   if (Opc != ISD::AND && Opc != ISD::SHL && Opc != ISD::SRL)
7483     return false;
7484 
7485   SDValue N0 = N.getOperand(0);
7486   unsigned Opc0 = N0.getOpcode();
7487   if (Opc0 != ISD::AND && Opc0 != ISD::SHL && Opc0 != ISD::SRL)
7488     return false;
7489 
7490   ConstantSDNode *N1C = nullptr;
7491   // SHL or SRL: look upstream for AND mask operand
7492   if (Opc == ISD::AND)
7493     N1C = dyn_cast<ConstantSDNode>(N.getOperand(1));
7494   else if (Opc0 == ISD::AND)
7495     N1C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
7496   if (!N1C)
7497     return false;
7498 
7499   unsigned MaskByteOffset;
7500   switch (N1C->getZExtValue()) {
7501   default:
7502     return false;
7503   case 0xFF:       MaskByteOffset = 0; break;
7504   case 0xFF00:     MaskByteOffset = 1; break;
7505   case 0xFFFF:
7506     // In case demanded bits didn't clear the bits that will be shifted out.
7507     // This is needed for X86.
7508     if (Opc == ISD::SRL || (Opc == ISD::AND && Opc0 == ISD::SHL)) {
7509       MaskByteOffset = 1;
7510       break;
7511     }
7512     return false;
7513   case 0xFF0000:   MaskByteOffset = 2; break;
7514   case 0xFF000000: MaskByteOffset = 3; break;
7515   }
7516 
7517   // Look for (x & 0xff) << 8 as well as ((x << 8) & 0xff00).
7518   if (Opc == ISD::AND) {
7519     if (MaskByteOffset == 0 || MaskByteOffset == 2) {
7520       // (x >> 8) & 0xff
7521       // (x >> 8) & 0xff0000
7522       if (Opc0 != ISD::SRL)
7523         return false;
7524       ConstantSDNode *C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
7525       if (!C || C->getZExtValue() != 8)
7526         return false;
7527     } else {
7528       // (x << 8) & 0xff00
7529       // (x << 8) & 0xff000000
7530       if (Opc0 != ISD::SHL)
7531         return false;
7532       ConstantSDNode *C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
7533       if (!C || C->getZExtValue() != 8)
7534         return false;
7535     }
7536   } else if (Opc == ISD::SHL) {
7537     // (x & 0xff) << 8
7538     // (x & 0xff0000) << 8
7539     if (MaskByteOffset != 0 && MaskByteOffset != 2)
7540       return false;
7541     ConstantSDNode *C = dyn_cast<ConstantSDNode>(N.getOperand(1));
7542     if (!C || C->getZExtValue() != 8)
7543       return false;
7544   } else { // Opc == ISD::SRL
7545     // (x & 0xff00) >> 8
7546     // (x & 0xff000000) >> 8
7547     if (MaskByteOffset != 1 && MaskByteOffset != 3)
7548       return false;
7549     ConstantSDNode *C = dyn_cast<ConstantSDNode>(N.getOperand(1));
7550     if (!C || C->getZExtValue() != 8)
7551       return false;
7552   }
7553 
7554   if (Parts[MaskByteOffset])
7555     return false;
7556 
7557   Parts[MaskByteOffset] = N0.getOperand(0).getNode();
7558   return true;
7559 }
7560 
7561 // Match 2 elements of a packed halfword bswap.
7562 static bool isBSwapHWordPair(SDValue N, MutableArrayRef<SDNode *> Parts) {
7563   if (N.getOpcode() == ISD::OR)
7564     return isBSwapHWordElement(N.getOperand(0), Parts) &&
7565            isBSwapHWordElement(N.getOperand(1), Parts);
7566 
7567   if (N.getOpcode() == ISD::SRL && N.getOperand(0).getOpcode() == ISD::BSWAP) {
7568     ConstantSDNode *C = isConstOrConstSplat(N.getOperand(1));
7569     if (!C || C->getAPIntValue() != 16)
7570       return false;
7571     Parts[0] = Parts[1] = N.getOperand(0).getOperand(0).getNode();
7572     return true;
7573   }
7574 
7575   return false;
7576 }
7577 
7578 // Match this pattern:
7579 //   (or (and (shl (A, 8)), 0xff00ff00), (and (srl (A, 8)), 0x00ff00ff))
7580 // And rewrite this to:
7581 //   (rotr (bswap A), 16)
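     // Sanity check with an assumed A = 0x11223344:
     //   (shl A, 8) & 0xff00ff00 = 0x22004400 and
     //   (srl A, 8) & 0x00ff00ff = 0x00110033, whose OR is 0x22114433;
     //   bswap(A) = 0x44332211 and rotr 16 gives 0x22114433 as well.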
7582 static SDValue matchBSwapHWordOrAndAnd(const TargetLowering &TLI,
7583                                        SelectionDAG &DAG, SDNode *N, SDValue N0,
7584                                        SDValue N1, EVT VT) {
7585   assert(N->getOpcode() == ISD::OR && VT == MVT::i32 &&
7586          "MatchBSwapHWordOrAndAnd: expecting i32");
7587   if (!TLI.isOperationLegalOrCustom(ISD::ROTR, VT))
7588     return SDValue();
7589   if (N0.getOpcode() != ISD::AND || N1.getOpcode() != ISD::AND)
7590     return SDValue();
7591   // TODO: this is too restrictive; lifting this restriction requires more tests
7592   if (!N0->hasOneUse() || !N1->hasOneUse())
7593     return SDValue();
7594   ConstantSDNode *Mask0 = isConstOrConstSplat(N0.getOperand(1));
7595   ConstantSDNode *Mask1 = isConstOrConstSplat(N1.getOperand(1));
7596   if (!Mask0 || !Mask1)
7597     return SDValue();
7598   if (Mask0->getAPIntValue() != 0xff00ff00 ||
7599       Mask1->getAPIntValue() != 0x00ff00ff)
7600     return SDValue();
7601   SDValue Shift0 = N0.getOperand(0);
7602   SDValue Shift1 = N1.getOperand(0);
7603   if (Shift0.getOpcode() != ISD::SHL || Shift1.getOpcode() != ISD::SRL)
7604     return SDValue();
7605   ConstantSDNode *ShiftAmt0 = isConstOrConstSplat(Shift0.getOperand(1));
7606   ConstantSDNode *ShiftAmt1 = isConstOrConstSplat(Shift1.getOperand(1));
7607   if (!ShiftAmt0 || !ShiftAmt1)
7608     return SDValue();
7609   if (ShiftAmt0->getAPIntValue() != 8 || ShiftAmt1->getAPIntValue() != 8)
7610     return SDValue();
7611   if (Shift0.getOperand(0) != Shift1.getOperand(0))
7612     return SDValue();
7613 
7614   SDLoc DL(N);
7615   SDValue BSwap = DAG.getNode(ISD::BSWAP, DL, VT, Shift0.getOperand(0));
7616   SDValue ShAmt = DAG.getShiftAmountConstant(16, VT, DL);
7617   return DAG.getNode(ISD::ROTR, DL, VT, BSwap, ShAmt);
7618 }
7619 
7620 /// Match a 32-bit packed halfword bswap. That is
7621 /// ((x & 0x000000ff) << 8) |
7622 /// ((x & 0x0000ff00) >> 8) |
7623 /// ((x & 0x00ff0000) << 8) |
7624 /// ((x & 0xff000000) >> 8)
7625 /// => (rotl (bswap x), 16)
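     /// e.g. (illustrative) x = 0x11223344: the four masked terms are
     /// 0x00004400, 0x00000033, 0x22000000 and 0x00110000, which OR to
     /// 0x22114433 = rotl(bswap(x), 16) = rotl(0x44332211, 16).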
7626 SDValue DAGCombiner::MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1) {
7627   if (!LegalOperations)
7628     return SDValue();
7629 
7630   EVT VT = N->getValueType(0);
7631   if (VT != MVT::i32)
7632     return SDValue();
7633   if (!TLI.isOperationLegalOrCustom(ISD::BSWAP, VT))
7634     return SDValue();
7635 
7636   if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N0, N1, VT))
7637     return BSwap;
7638 
7639   // Try again with commuted operands.
7640   if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N1, N0, VT))
7641     return BSwap;
7642 
7643 
7644   // Look for either
7645   // (or (bswaphpair), (bswaphpair))
7646   // (or (or (bswaphpair), (and)), (and))
7647   // (or (or (and), (bswaphpair)), (and))
7648   SDNode *Parts[4] = {};
7649 
7650   if (isBSwapHWordPair(N0, Parts)) {
7651     // (or (or (and), (and)), (or (and), (and)))
7652     if (!isBSwapHWordPair(N1, Parts))
7653       return SDValue();
7654   } else if (N0.getOpcode() == ISD::OR) {
7655     // (or (or (or (and), (and)), (and)), (and))
7656     if (!isBSwapHWordElement(N1, Parts))
7657       return SDValue();
7658     SDValue N00 = N0.getOperand(0);
7659     SDValue N01 = N0.getOperand(1);
7660     if (!(isBSwapHWordElement(N01, Parts) && isBSwapHWordPair(N00, Parts)) &&
7661         !(isBSwapHWordElement(N00, Parts) && isBSwapHWordPair(N01, Parts)))
7662       return SDValue();
7663   } else {
7664     return SDValue();
7665   }
7666 
7667   // Make sure the parts are all coming from the same node.
7668   if (Parts[0] != Parts[1] || Parts[0] != Parts[2] || Parts[0] != Parts[3])
7669     return SDValue();
7670 
7671   SDLoc DL(N);
7672   SDValue BSwap = DAG.getNode(ISD::BSWAP, DL, VT,
7673                               SDValue(Parts[0], 0));
7674 
7675   // Result of the bswap should be rotated by 16. If it's not legal, then
7676   // do (x << 16) | (x >> 16).
7677   SDValue ShAmt = DAG.getShiftAmountConstant(16, VT, DL);
7678   if (TLI.isOperationLegalOrCustom(ISD::ROTL, VT))
7679     return DAG.getNode(ISD::ROTL, DL, VT, BSwap, ShAmt);
7680   if (TLI.isOperationLegalOrCustom(ISD::ROTR, VT))
7681     return DAG.getNode(ISD::ROTR, DL, VT, BSwap, ShAmt);
7682   return DAG.getNode(ISD::OR, DL, VT,
7683                      DAG.getNode(ISD::SHL, DL, VT, BSwap, ShAmt),
7684                      DAG.getNode(ISD::SRL, DL, VT, BSwap, ShAmt));
7685 }
7686 
7687 /// This contains all DAGCombine rules which reduce two values combined by
7688 /// an Or operation to a single value \see visitANDLike().
7689 SDValue DAGCombiner::visitORLike(SDValue N0, SDValue N1, const SDLoc &DL) {
7690   EVT VT = N1.getValueType();
7691 
7692   // fold (or x, undef) -> -1
7693   if (!LegalOperations && (N0.isUndef() || N1.isUndef()))
7694     return DAG.getAllOnesConstant(DL, VT);
7695 
7696   if (SDValue V = foldLogicOfSetCCs(false, N0, N1, DL))
7697     return V;
7698 
7699   // (or (and X, C1), (and Y, C2))  -> (and (or X, Y), C3) if possible.
7700   if (N0.getOpcode() == ISD::AND && N1.getOpcode() == ISD::AND &&
7701       // Don't increase # computations.
7702       (N0->hasOneUse() || N1->hasOneUse())) {
7703     // We can only do this xform if we know that bits from X that are set in C2
7704     // but not in C1 are already zero.  Likewise for Y.
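         // Assumed example: C1 = 0xFF00, C2 = 0x00FF. If X's low byte and Y's
         // high byte are known zero, (or (and X, 0xFF00), (and Y, 0x00FF))
         // becomes (and (or X, Y), 0xFFFF).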
7705     if (const ConstantSDNode *N0O1C =
7706         getAsNonOpaqueConstant(N0.getOperand(1))) {
7707       if (const ConstantSDNode *N1O1C =
7708           getAsNonOpaqueConstant(N1.getOperand(1))) {
7709         // We can only do this xform if we know that bits from X that are set in
7710         // C2 but not in C1 are already zero.  Likewise for Y.
7711         const APInt &LHSMask = N0O1C->getAPIntValue();
7712         const APInt &RHSMask = N1O1C->getAPIntValue();
7713 
7714         if (DAG.MaskedValueIsZero(N0.getOperand(0), RHSMask&~LHSMask) &&
7715             DAG.MaskedValueIsZero(N1.getOperand(0), LHSMask&~RHSMask)) {
7716           SDValue X = DAG.getNode(ISD::OR, SDLoc(N0), VT,
7717                                   N0.getOperand(0), N1.getOperand(0));
7718           return DAG.getNode(ISD::AND, DL, VT, X,
7719                              DAG.getConstant(LHSMask | RHSMask, DL, VT));
7720         }
7721       }
7722     }
7723   }
7724 
7725   // (or (and X, M), (and X, N)) -> (and X, (or M, N))
7726   if (N0.getOpcode() == ISD::AND &&
7727       N1.getOpcode() == ISD::AND &&
7728       N0.getOperand(0) == N1.getOperand(0) &&
7729       // Don't increase # computations.
7730       (N0->hasOneUse() || N1->hasOneUse())) {
7731     SDValue X = DAG.getNode(ISD::OR, SDLoc(N0), VT,
7732                             N0.getOperand(1), N1.getOperand(1));
7733     return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), X);
7734   }
7735 
7736   return SDValue();
7737 }
7738 
7739 /// OR combines for which the commuted variant will be tried as well.
7740 static SDValue visitORCommutative(SelectionDAG &DAG, SDValue N0, SDValue N1,
7741                                   SDNode *N) {
7742   EVT VT = N0.getValueType();
7743   unsigned BW = VT.getScalarSizeInBits();
7744   SDLoc DL(N);
7745 
7746   auto peekThroughResize = [](SDValue V) {
7747     if (V->getOpcode() == ISD::ZERO_EXTEND || V->getOpcode() == ISD::TRUNCATE)
7748       return V->getOperand(0);
7749     return V;
7750   };
7751 
7752   SDValue N0Resized = peekThroughResize(N0);
7753   if (N0Resized.getOpcode() == ISD::AND) {
7754     SDValue N1Resized = peekThroughResize(N1);
7755     SDValue N00 = N0Resized.getOperand(0);
7756     SDValue N01 = N0Resized.getOperand(1);
7757 
7758     // fold or (and x, y), x --> x
7759     if (N00 == N1Resized || N01 == N1Resized)
7760       return N1;
7761 
7762     // fold (or (and X, (xor Y, -1)), Y) -> (or X, Y)
7763     // TODO: Set AllowUndefs = true.
7764     if (SDValue NotOperand = getBitwiseNotOperand(N01, N00,
7765                                                   /* AllowUndefs */ false)) {
7766       if (peekThroughResize(NotOperand) == N1Resized)
7767         return DAG.getNode(ISD::OR, DL, VT, DAG.getZExtOrTrunc(N00, DL, VT),
7768                            N1);
7769     }
7770 
7771     // fold (or (and (xor Y, -1), X), Y) -> (or X, Y)
7772     if (SDValue NotOperand = getBitwiseNotOperand(N00, N01,
7773                                                   /* AllowUndefs */ false)) {
7774       if (peekThroughResize(NotOperand) == N1Resized)
7775         return DAG.getNode(ISD::OR, DL, VT, DAG.getZExtOrTrunc(N01, DL, VT),
7776                            N1);
7777     }
7778   }
7779 
7780   SDValue X, Y;
7781 
7782   // fold or (xor X, N1), N1 --> or X, N1
7783   if (sd_match(N0, m_Xor(m_Value(X), m_Specific(N1))))
7784     return DAG.getNode(ISD::OR, DL, VT, X, N1);
7785 
7786   // fold or (xor x, y), (x and/or y) --> or x, y
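  // This holds because (x ^ y) sets exactly the bits where x and y differ,
  // so or'ing it with either (x & y) or (x | y) reproduces (x | y).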
7787   if (sd_match(N0, m_Xor(m_Value(X), m_Value(Y))) &&
7788       (sd_match(N1, m_And(m_Specific(X), m_Specific(Y))) ||
7789        sd_match(N1, m_Or(m_Specific(X), m_Specific(Y)))))
7790     return DAG.getNode(ISD::OR, DL, VT, X, Y);
7791 
7792   if (SDValue R = foldLogicOfShifts(N, N0, N1, DAG))
7793     return R;
7794 
7795   auto peekThroughZext = [](SDValue V) {
7796     if (V->getOpcode() == ISD::ZERO_EXTEND)
7797       return V->getOperand(0);
7798     return V;
7799   };
7800 
7801   // (fshl X, ?, Y) | (shl X, Y) --> fshl X, ?, Y
7802   if (N0.getOpcode() == ISD::FSHL && N1.getOpcode() == ISD::SHL &&
7803       N0.getOperand(0) == N1.getOperand(0) &&
7804       peekThroughZext(N0.getOperand(2)) == peekThroughZext(N1.getOperand(1)))
7805     return N0;
7806 
7807   // (fshr ?, X, Y) | (srl X, Y) --> fshr ?, X, Y
7808   if (N0.getOpcode() == ISD::FSHR && N1.getOpcode() == ISD::SRL &&
7809       N0.getOperand(1) == N1.getOperand(0) &&
7810       peekThroughZext(N0.getOperand(2)) == peekThroughZext(N1.getOperand(1)))
7811     return N0;
7812 
7813   // Attempt to match a legalized build_pair-esque pattern:
7814   // or(shl(aext(Hi),BW/2),zext(Lo))
7815   SDValue Lo, Hi;
7816   if (sd_match(N0,
7817                m_OneUse(m_Shl(m_AnyExt(m_Value(Hi)), m_SpecificInt(BW / 2)))) &&
7818       sd_match(N1, m_ZExt(m_Value(Lo))) &&
7819       Lo.getScalarValueSizeInBits() == (BW / 2) &&
7820       Lo.getValueType() == Hi.getValueType()) {
7821     // Fold build_pair(not(Lo),not(Hi)) -> not(build_pair(Lo,Hi)).
7822     SDValue NotLo, NotHi;
7823     if (sd_match(Lo, m_OneUse(m_Not(m_Value(NotLo)))) &&
7824         sd_match(Hi, m_OneUse(m_Not(m_Value(NotHi))))) {
7825       Lo = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, NotLo);
7826       Hi = DAG.getNode(ISD::ANY_EXTEND, DL, VT, NotHi);
7827       Hi = DAG.getNode(ISD::SHL, DL, VT, Hi,
7828                        DAG.getShiftAmountConstant(BW / 2, VT, DL));
7829       return DAG.getNOT(DL, DAG.getNode(ISD::OR, DL, VT, Lo, Hi), VT);
7830     }
7831   }
7832 
7833   return SDValue();
7834 }
7835 
7836 SDValue DAGCombiner::visitOR(SDNode *N) {
7837   SDValue N0 = N->getOperand(0);
7838   SDValue N1 = N->getOperand(1);
7839   EVT VT = N1.getValueType();
7840   SDLoc DL(N);
7841 
7842   // x | x --> x
7843   if (N0 == N1)
7844     return N0;
7845 
7846   // fold (or c1, c2) -> c1|c2
7847   if (SDValue C = DAG.FoldConstantArithmetic(ISD::OR, DL, VT, {N0, N1}))
7848     return C;
7849 
7850   // canonicalize constant to RHS
7851   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
7852       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
7853     return DAG.getNode(ISD::OR, DL, VT, N1, N0);
7854 
7855   // fold vector ops
7856   if (VT.isVector()) {
7857     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
7858       return FoldedVOp;
7859 
7860     // fold (or x, 0) -> x, vector edition
7861     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
7862       return N0;
7863 
7864     // fold (or x, -1) -> -1, vector edition
7865     if (ISD::isConstantSplatVectorAllOnes(N1.getNode()))
7866       // do not return N1, because an undef node may exist in N1
7867       return DAG.getAllOnesConstant(DL, N1.getValueType());
7868 
7869     // fold (or (shuf A, V_0, MA), (shuf B, V_0, MB)) -> (shuf A, B, Mask)
7870     // Do this only if the resulting type / shuffle is legal.
7871     auto *SV0 = dyn_cast<ShuffleVectorSDNode>(N0);
7872     auto *SV1 = dyn_cast<ShuffleVectorSDNode>(N1);
7873     if (SV0 && SV1 && TLI.isTypeLegal(VT)) {
7874       bool ZeroN00 = ISD::isBuildVectorAllZeros(N0.getOperand(0).getNode());
7875       bool ZeroN01 = ISD::isBuildVectorAllZeros(N0.getOperand(1).getNode());
7876       bool ZeroN10 = ISD::isBuildVectorAllZeros(N1.getOperand(0).getNode());
7877       bool ZeroN11 = ISD::isBuildVectorAllZeros(N1.getOperand(1).getNode());
7878       // Ensure both shuffles have a zero input.
7879       if ((ZeroN00 != ZeroN01) && (ZeroN10 != ZeroN11)) {
7880         assert((!ZeroN00 || !ZeroN01) && "Both inputs zero!");
7881         assert((!ZeroN10 || !ZeroN11) && "Both inputs zero!");
7882         bool CanFold = true;
7883         int NumElts = VT.getVectorNumElements();
7884         SmallVector<int, 4> Mask(NumElts, -1);
7885 
7886         for (int i = 0; i != NumElts; ++i) {
7887           int M0 = SV0->getMaskElt(i);
7888           int M1 = SV1->getMaskElt(i);
7889 
7890           // Determine if either index is pointing to a zero vector.
7891           bool M0Zero = M0 < 0 || (ZeroN00 == (M0 < NumElts));
7892           bool M1Zero = M1 < 0 || (ZeroN10 == (M1 < NumElts));
7893 
7894           // If one element is zero and the other side is undef, keep undef.
7895           // This also handles the case where both are undef.
7896           if ((M0Zero && M1 < 0) || (M1Zero && M0 < 0))
7897             continue;
7898 
7899           // Make sure only one of the elements is zero.
7900           if (M0Zero == M1Zero) {
7901             CanFold = false;
7902             break;
7903           }
7904 
7905           assert((M0 >= 0 || M1 >= 0) && "Undef index!");
7906 
7907           // We have a zero and a non-zero element. If the non-zero came from
7908           // SV0, make the index an LHS index. If it came from SV1, make it
7909           // an RHS index. We need to mod by NumElts because we don't care
7910           // which operand it came from in the original shuffles.
7911           Mask[i] = M1Zero ? M0 % NumElts : (M1 % NumElts) + NumElts;
7912         }
7913 
7914         if (CanFold) {
7915           SDValue NewLHS = ZeroN00 ? N0.getOperand(1) : N0.getOperand(0);
7916           SDValue NewRHS = ZeroN10 ? N1.getOperand(1) : N1.getOperand(0);
7917           SDValue LegalShuffle =
7918               TLI.buildLegalVectorShuffle(VT, DL, NewLHS, NewRHS, Mask, DAG);
7919           if (LegalShuffle)
7920             return LegalShuffle;
7921         }
7922       }
7923     }
7924   }
7925 
7926   // fold (or x, 0) -> x
7927   if (isNullConstant(N1))
7928     return N0;
7929 
7930   // fold (or x, -1) -> -1
7931   if (isAllOnesConstant(N1))
7932     return N1;
7933 
7934   if (SDValue NewSel = foldBinOpIntoSelect(N))
7935     return NewSel;
7936 
7937   // fold (or x, c) -> c iff (x & ~c) == 0
7938   ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
7939   if (N1C && DAG.MaskedValueIsZero(N0, ~N1C->getAPIntValue()))
7940     return N1;
7941 
7942   if (SDValue R = foldAndOrOfSETCC(N, DAG))
7943     return R;
7944 
7945   if (SDValue Combined = visitORLike(N0, N1, DL))
7946     return Combined;
7947 
7948   if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
7949     return Combined;
7950 
7951   // Recognize halfword bswaps as (bswap + rotl 16) or (bswap + shl 16)
7952   if (SDValue BSwap = MatchBSwapHWord(N, N0, N1))
7953     return BSwap;
7954   if (SDValue BSwap = MatchBSwapHWordLow(N, N0, N1))
7955     return BSwap;
7956 
7957   // reassociate or
7958   if (SDValue ROR = reassociateOps(ISD::OR, DL, N0, N1, N->getFlags()))
7959     return ROR;
7960 
7961   // Fold or(vecreduce(x), vecreduce(y)) -> vecreduce(or(x, y))
7962   if (SDValue SD =
7963           reassociateReduction(ISD::VECREDUCE_OR, ISD::OR, DL, VT, N0, N1))
7964     return SD;
7965 
7966   // Canonicalize (or (and X, c1), c2) -> (and (or X, c2), c1|c2)
7967   // iff (c1 & c2) != 0 or c1/c2 are undef.
7968   auto MatchIntersect = [](ConstantSDNode *C1, ConstantSDNode *C2) {
7969     return !C1 || !C2 || C1->getAPIntValue().intersects(C2->getAPIntValue());
7970   };
7971   if (N0.getOpcode() == ISD::AND && N0->hasOneUse() &&
7972       ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchIntersect, true)) {
7973     if (SDValue COR = DAG.FoldConstantArithmetic(ISD::OR, SDLoc(N1), VT,
7974                                                  {N1, N0.getOperand(1)})) {
7975       SDValue IOR = DAG.getNode(ISD::OR, SDLoc(N0), VT, N0.getOperand(0), N1);
7976       AddToWorklist(IOR.getNode());
7977       return DAG.getNode(ISD::AND, DL, VT, COR, IOR);
7978     }
7979   }
7980 
7981   if (SDValue Combined = visitORCommutative(DAG, N0, N1, N))
7982     return Combined;
7983   if (SDValue Combined = visitORCommutative(DAG, N1, N0, N))
7984     return Combined;
7985 
7986   // Simplify: (or (op x...), (op y...))  -> (op (or x, y))
7987   if (N0.getOpcode() == N1.getOpcode())
7988     if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
7989       return V;
7990 
7991   // See if this is some rotate idiom.
7992   if (SDValue Rot = MatchRotate(N0, N1, DL))
7993     return Rot;
7994 
7995   if (SDValue Load = MatchLoadCombine(N))
7996     return Load;
7997 
7998   // Simplify the operands using demanded-bits information.
7999   if (SimplifyDemandedBits(SDValue(N, 0)))
8000     return SDValue(N, 0);
8001 
8002   // If OR can be rewritten into ADD, try combines based on ADD.
8003   if ((!LegalOperations || TLI.isOperationLegal(ISD::ADD, VT)) &&
8004       DAG.isADDLike(SDValue(N, 0)))
8005     if (SDValue Combined = visitADDLike(N))
8006       return Combined;
8007 
8008   // Postpone until legalization completed to avoid interference with bswap
8009   // folding
8010   if (LegalOperations || VT.isVector())
8011     if (SDValue R = foldLogicTreeOfShifts(N, N0, N1, DAG))
8012       return R;
8013 
8014   return SDValue();
8015 }
8016 
8017 static SDValue stripConstantMask(const SelectionDAG &DAG, SDValue Op,
8018                                  SDValue &Mask) {
8019   if (Op.getOpcode() == ISD::AND &&
8020       DAG.isConstantIntBuildVectorOrConstantInt(Op.getOperand(1))) {
8021     Mask = Op.getOperand(1);
8022     return Op.getOperand(0);
8023   }
8024   return Op;
8025 }
8026 
8027 /// Match "(X shl/srl V1) & V2" where V2 may not be present.
8028 static bool matchRotateHalf(const SelectionDAG &DAG, SDValue Op, SDValue &Shift,
8029                             SDValue &Mask) {
8030   Op = stripConstantMask(DAG, Op, Mask);
8031   if (Op.getOpcode() == ISD::SRL || Op.getOpcode() == ISD::SHL) {
8032     Shift = Op;
8033     return true;
8034   }
8035   return false;
8036 }
8037 
8038 /// Helper function for visitOR to extract the needed side of a rotate idiom
8039 /// from a shl/srl/mul/udiv.  This is meant to handle cases where
8040 /// InstCombine merged some outside op with one of the shifts from
8041 /// the rotate pattern.
8042 /// \returns An empty \c SDValue if the needed shift couldn't be extracted.
8043 /// Otherwise, returns an expansion of \p ExtractFrom based on the following
8044 /// patterns:
8045 ///
8046 ///   (or (add v v) (shrl v bitwidth-1)):
8047 ///     expands (add v v) -> (shl v 1)
8048 ///
8049 ///   (or (mul v c0) (shrl (mul v c1) c2)):
8050 ///     expands (mul v c0) -> (shl (mul v c1) c3)
8051 ///
8052 ///   (or (udiv v c0) (shl (udiv v c1) c2)):
8053 ///     expands (udiv v c0) -> (shrl (udiv v c1) c3)
8054 ///
8055 ///   (or (shl v c0) (shrl (shl v c1) c2)):
8056 ///     expands (shl v c0) -> (shl (shl v c1) c3)
8057 ///
8058 ///   (or (shrl v c0) (shl (shrl v c1) c2)):
8059 ///     expands (shrl v c0) -> (shrl (shrl v c1) c3)
8060 ///
8061 /// Such that in all cases, c3+c2==bitwidth(op v c1).
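///
/// For example, with i32 v: (or (mul v 24) (srl (mul v 3) 29)) expands
/// (mul v 24) -> (shl (mul v 3) 3), since 24 == 3 << 3 and 3 + 29 == 32,
/// exposing a rotate of (mul v 3) by 3.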
8062 static SDValue extractShiftForRotate(SelectionDAG &DAG, SDValue OppShift,
8063                                      SDValue ExtractFrom, SDValue &Mask,
8064                                      const SDLoc &DL) {
8065   assert(OppShift && ExtractFrom && "Empty SDValue");
8066   if (OppShift.getOpcode() != ISD::SHL && OppShift.getOpcode() != ISD::SRL)
8067     return SDValue();
8068 
8069   ExtractFrom = stripConstantMask(DAG, ExtractFrom, Mask);
8070 
8071   // Value and Type of the shift.
8072   SDValue OppShiftLHS = OppShift.getOperand(0);
8073   EVT ShiftedVT = OppShiftLHS.getValueType();
8074 
8075   // Amount of the existing shift.
8076   ConstantSDNode *OppShiftCst = isConstOrConstSplat(OppShift.getOperand(1));
8077 
8078   // (add v v) -> (shl v 1)
8079   // TODO: Should this be a general DAG canonicalization?
8080   if (OppShift.getOpcode() == ISD::SRL && OppShiftCst &&
8081       ExtractFrom.getOpcode() == ISD::ADD &&
8082       ExtractFrom.getOperand(0) == ExtractFrom.getOperand(1) &&
8083       ExtractFrom.getOperand(0) == OppShiftLHS &&
8084       OppShiftCst->getAPIntValue() == ShiftedVT.getScalarSizeInBits() - 1)
8085     return DAG.getNode(ISD::SHL, DL, ShiftedVT, OppShiftLHS,
8086                        DAG.getShiftAmountConstant(1, ShiftedVT, DL));
8087 
8088   // Preconditions:
8089   //    (or (op0 v c0) (shiftl/r (op0 v c1) c2))
8090   //
8091   // Find opcode of the needed shift to be extracted from (op0 v c0).
8092   unsigned Opcode = ISD::DELETED_NODE;
8093   bool IsMulOrDiv = false;
8094   // Set Opcode and IsMulOrDiv if the extract opcode matches the needed shift
8095   // opcode or its arithmetic (mul or udiv) variant.
8096   auto SelectOpcode = [&](unsigned NeededShift, unsigned MulOrDivVariant) {
8097     IsMulOrDiv = ExtractFrom.getOpcode() == MulOrDivVariant;
8098     if (!IsMulOrDiv && ExtractFrom.getOpcode() != NeededShift)
8099       return false;
8100     Opcode = NeededShift;
8101     return true;
8102   };
8103   // op0 must be either the needed shift opcode or the mul/udiv equivalent
8104   // that the needed shift can be extracted from.
8105   if ((OppShift.getOpcode() != ISD::SRL || !SelectOpcode(ISD::SHL, ISD::MUL)) &&
8106       (OppShift.getOpcode() != ISD::SHL || !SelectOpcode(ISD::SRL, ISD::UDIV)))
8107     return SDValue();
8108 
8109   // op0 must be the same opcode on both sides, have the same LHS argument,
8110   // and produce the same value type.
8111   if (OppShiftLHS.getOpcode() != ExtractFrom.getOpcode() ||
8112       OppShiftLHS.getOperand(0) != ExtractFrom.getOperand(0) ||
8113       ShiftedVT != ExtractFrom.getValueType())
8114     return SDValue();
8115 
8116   // Constant mul/udiv/shift amount from the RHS of the shift's LHS op.
8117   ConstantSDNode *OppLHSCst = isConstOrConstSplat(OppShiftLHS.getOperand(1));
8118   // Constant mul/udiv/shift amount from the RHS of the ExtractFrom op.
8119   ConstantSDNode *ExtractFromCst =
8120       isConstOrConstSplat(ExtractFrom.getOperand(1));
8121   // TODO: We should be able to handle non-uniform constant vectors for these values.
8122   // Check that we have constant values.
8123   if (!OppShiftCst || !OppShiftCst->getAPIntValue() ||
8124       !OppLHSCst || !OppLHSCst->getAPIntValue() ||
8125       !ExtractFromCst || !ExtractFromCst->getAPIntValue())
8126     return SDValue();
8127 
8128   // Compute the shift amount we need to extract to complete the rotate.
8129   const unsigned VTWidth = ShiftedVT.getScalarSizeInBits();
8130   if (OppShiftCst->getAPIntValue().ugt(VTWidth))
8131     return SDValue();
8132   APInt NeededShiftAmt = VTWidth - OppShiftCst->getAPIntValue();
8133   // Normalize the bitwidth of the two mul/udiv/shift constant operands.
8134   APInt ExtractFromAmt = ExtractFromCst->getAPIntValue();
8135   APInt OppLHSAmt = OppLHSCst->getAPIntValue();
8136   zeroExtendToMatch(ExtractFromAmt, OppLHSAmt);
8137 
8138   // Now try extract the needed shift from the ExtractFrom op and see if the
8139   // result matches up with the existing shift's LHS op.
8140   if (IsMulOrDiv) {
8141     // Op to extract from is a mul or udiv by a constant.
8142     // Check:
8143     //     c2 / (1 << (bitwidth(op0 v c0) - c1)) == c0
8144     //     c2 % (1 << (bitwidth(op0 v c0) - c1)) == 0
8145     const APInt ExtractDiv = APInt::getOneBitSet(ExtractFromAmt.getBitWidth(),
8146                                                  NeededShiftAmt.getZExtValue());
8147     APInt ResultAmt;
8148     APInt Rem;
8149     APInt::udivrem(ExtractFromAmt, ExtractDiv, ResultAmt, Rem);
8150     if (Rem != 0 || ResultAmt != OppLHSAmt)
8151       return SDValue();
8152   } else {
8153     // Op to extract from is a shift by a constant.
8154     // Check:
8155     //      c2 - (bitwidth(op0 v c0) - c1) == c0
8156     if (OppLHSAmt != ExtractFromAmt - NeededShiftAmt.zextOrTrunc(
8157                                           ExtractFromAmt.getBitWidth()))
8158       return SDValue();
8159   }
8160 
8161   // Return the expanded shift op that should allow a rotate to be formed.
8162   EVT ShiftVT = OppShift.getOperand(1).getValueType();
8163   EVT ResVT = ExtractFrom.getValueType();
8164   SDValue NewShiftNode = DAG.getConstant(NeededShiftAmt, DL, ShiftVT);
8165   return DAG.getNode(Opcode, DL, ResVT, OppShiftLHS, NewShiftNode);
8166 }
8167 
8168 // Return true if we can prove that, whenever Neg and Pos are both in the
8169 // range [0, EltSize), Neg == (Pos == 0 ? 0 : EltSize - Pos).  This means that
8170 // for two opposing shifts shift1 and shift2 and a value X with OpBits bits:
8171 //
8172 //     (or (shift1 X, Neg), (shift2 X, Pos))
8173 //
8174 // reduces to a rotate in direction shift2 by Pos or (equivalently) a rotate
8175 // in direction shift1 by Neg.  The range [0, EltSize) means that we only need
8176 // to consider shift amounts with defined behavior.
8177 //
8178 // The IsRotate flag should be set when the LHS of both shifts is the same.
8179 // Otherwise if matching a general funnel shift, it should be clear.
8180 static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize,
8181                            SelectionDAG &DAG, bool IsRotate) {
8182   const auto &TLI = DAG.getTargetLoweringInfo();
8183   // If EltSize is a power of 2 then:
8184   //
8185   //  (a) (Pos == 0 ? 0 : EltSize - Pos) == (EltSize - Pos) & (EltSize - 1)
8186   //  (b) Neg == Neg & (EltSize - 1) whenever Neg is in [0, EltSize).
8187   //
8188   // So if EltSize is a power of 2 and Neg is (and Neg', EltSize-1), we check
8189   // for the stronger condition:
8190   //
8191   //     Neg & (EltSize - 1) == (EltSize - Pos) & (EltSize - 1)    [A]
8192   //
8193   // for all Neg and Pos.  Since Neg & (EltSize - 1) == Neg' & (EltSize - 1)
8194   // we can just replace Neg with Neg' for the rest of the function.
8195   //
8196   // In other cases we check for the even stronger condition:
8197   //
8198   //     Neg == EltSize - Pos                                    [B]
8199   //
8200   // for all Neg and Pos.  Note that the (or ...) then invokes undefined
8201   // behavior if Pos == 0 (and consequently Neg == EltSize).
8202   //
8203   // We could actually use [A] whenever EltSize is a power of 2, but the
8204   // only extra cases that it would match are those uninteresting ones
8205   // where Neg and Pos are never in range at the same time.  E.g. for
8206   // EltSize == 32, using [A] would allow a Neg of the form (sub 64, Pos)
8207   // as well as (sub 32, Pos), but:
8208   //
8209   //     (or (shift1 X, (sub 64, Pos)), (shift2 X, Pos))
8210   //
8211   // always invokes undefined behavior for 32-bit X.
8212   //
8213   // Below, Mask == EltSize - 1 when using [A] and is all-ones otherwise.
8214   // This allows us to peek through any operations that only affect Mask's
8215   // un-demanded bits.
8216   //
8217   // NOTE: We can only do this when matching operations which won't modify the
8218   // least Log2(EltSize) significant bits and not a general funnel shift.
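  // Example (EltSize == 32): Neg == (sub 32, Pos) satisfies [B], while
  // Neg == (and (sub 64, Pos), 31) satisfies [A], because only the low
  // 5 bits of the shift amount are demanded.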
8219   unsigned MaskLoBits = 0;
8220   if (IsRotate && isPowerOf2_64(EltSize)) {
8221     unsigned Bits = Log2_64(EltSize);
8222     unsigned NegBits = Neg.getScalarValueSizeInBits();
8223     if (NegBits >= Bits) {
8224       APInt DemandedBits = APInt::getLowBitsSet(NegBits, Bits);
8225       if (SDValue Inner =
8226               TLI.SimplifyMultipleUseDemandedBits(Neg, DemandedBits, DAG)) {
8227         Neg = Inner;
8228         MaskLoBits = Bits;
8229       }
8230     }
8231   }
8232 
8233   // Check whether Neg has the form (sub NegC, NegOp1) for some NegC and NegOp1.
8234   if (Neg.getOpcode() != ISD::SUB)
8235     return false;
8236   ConstantSDNode *NegC = isConstOrConstSplat(Neg.getOperand(0));
8237   if (!NegC)
8238     return false;
8239   SDValue NegOp1 = Neg.getOperand(1);
8240 
8241   // On the RHS of [A], if Pos is the result of operation on Pos' that won't
8242   // affect Mask's demanded bits, just replace Pos with Pos'. These operations
8243   // are redundant for the purpose of the equality.
8244   if (MaskLoBits) {
8245     unsigned PosBits = Pos.getScalarValueSizeInBits();
8246     if (PosBits >= MaskLoBits) {
8247       APInt DemandedBits = APInt::getLowBitsSet(PosBits, MaskLoBits);
8248       if (SDValue Inner =
8249               TLI.SimplifyMultipleUseDemandedBits(Pos, DemandedBits, DAG)) {
8250         Pos = Inner;
8251       }
8252     }
8253   }
8254 
8255   // The condition we need is now:
8256   //
8257   //     (NegC - NegOp1) & Mask == (EltSize - Pos) & Mask
8258   //
8259   // If NegOp1 == Pos then we need:
8260   //
8261   //              EltSize & Mask == NegC & Mask
8262   //
8263   // (because "x & Mask" is a truncation and distributes through subtraction).
8264   //
8265   // We also need to account for a potential truncation of NegOp1 if the amount
8266   // has already been legalized to a shift amount type.
8267   APInt Width;
8268   if ((Pos == NegOp1) ||
8269       (NegOp1.getOpcode() == ISD::TRUNCATE && Pos == NegOp1.getOperand(0)))
8270     Width = NegC->getAPIntValue();
8271 
8272   // Check for cases where Pos has the form (add NegOp1, PosC) for some PosC.
8273   // Then the condition we want to prove becomes:
8274   //
8275   //     (NegC - NegOp1) & Mask == (EltSize - (NegOp1 + PosC)) & Mask
8276   //
8277   // which, again because "x & Mask" is a truncation, becomes:
8278   //
8279   //                NegC & Mask == (EltSize - PosC) & Mask
8280   //             EltSize & Mask == (NegC + PosC) & Mask
8281   else if (Pos.getOpcode() == ISD::ADD && Pos.getOperand(0) == NegOp1) {
8282     if (ConstantSDNode *PosC = isConstOrConstSplat(Pos.getOperand(1)))
8283       Width = PosC->getAPIntValue() + NegC->getAPIntValue();
8284     else
8285       return false;
8286   } else
8287     return false;
8288 
8289   // Now we just need to check that EltSize & Mask == Width & Mask.
8290   if (MaskLoBits)
8291     // EltSize & Mask is 0 since Mask is EltSize - 1.
8292     return Width.getLoBits(MaskLoBits) == 0;
8293   return Width == EltSize;
8294 }
8295 
8296 // A subroutine of MatchRotate used once we have found an OR of two opposite
8297 // shifts of Shifted.  If Neg == <operand size> - Pos then the OR reduces
8298 // to both (PosOpcode Shifted, Pos) and (NegOpcode Shifted, Neg), with the
8299 // former being preferred if supported.  InnerPos and InnerNeg are Pos and
8300 // Neg with outer conversions stripped away.
8301 SDValue DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos,
8302                                        SDValue Neg, SDValue InnerPos,
8303                                        SDValue InnerNeg, bool HasPos,
8304                                        unsigned PosOpcode, unsigned NegOpcode,
8305                                        const SDLoc &DL) {
8306   // fold (or (shl x, (*ext y)),
8307   //          (srl x, (*ext (sub 32, y)))) ->
8308   //   (rotl x, y) or (rotr x, (sub 32, y))
8309   //
8310   // fold (or (shl x, (*ext (sub 32, y))),
8311   //          (srl x, (*ext y))) ->
8312   //   (rotr x, y) or (rotl x, (sub 32, y))
8313   EVT VT = Shifted.getValueType();
8314   if (matchRotateSub(InnerPos, InnerNeg, VT.getScalarSizeInBits(), DAG,
8315                      /*IsRotate*/ true)) {
8316     return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, Shifted,
8317                        HasPos ? Pos : Neg);
8318   }
8319 
8320   return SDValue();
8321 }
8322 
8323 // A subroutine of MatchRotate used once we have found an OR of two opposite
8324 // shifts of N0 + N1.  If Neg == <operand size> - Pos then the OR reduces
8325 // to both (PosOpcode N0, N1, Pos) and (NegOpcode N0, N1, Neg), with the
8326 // former being preferred if supported.  InnerPos and InnerNeg are Pos and
8327 // Neg with outer conversions stripped away.
8328 // TODO: Merge with MatchRotatePosNeg.
8329 SDValue DAGCombiner::MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos,
8330                                        SDValue Neg, SDValue InnerPos,
8331                                        SDValue InnerNeg, bool HasPos,
8332                                        unsigned PosOpcode, unsigned NegOpcode,
8333                                        const SDLoc &DL) {
8334   EVT VT = N0.getValueType();
8335   unsigned EltBits = VT.getScalarSizeInBits();
8336 
8337   // fold (or (shl x0, (*ext y)),
8338   //          (srl x1, (*ext (sub 32, y)))) ->
8339   //   (fshl x0, x1, y) or (fshr x0, x1, (sub 32, y))
8340   //
8341   // fold (or (shl x0, (*ext (sub 32, y))),
8342   //          (srl x1, (*ext y))) ->
8343   //   (fshr x0, x1, y) or (fshl x0, x1, (sub 32, y))
8344   if (matchRotateSub(InnerPos, InnerNeg, EltBits, DAG, /*IsRotate*/ N0 == N1)) {
8345     return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, N0, N1,
8346                        HasPos ? Pos : Neg);
8347   }
8348 
8349   // When matching the shift+xor cases, we can't easily use the xor'd shift
8350   // amount, so for now just use the PosOpcode case if it's legal.
8351   // TODO: When can we use the NegOpcode case?
8352   if (PosOpcode == ISD::FSHL && isPowerOf2_32(EltBits)) {
8353     auto IsBinOpImm = [](SDValue Op, unsigned BinOpc, unsigned Imm) {
8354       if (Op.getOpcode() != BinOpc)
8355         return false;
8356       ConstantSDNode *Cst = isConstOrConstSplat(Op.getOperand(1));
8357       return Cst && (Cst->getAPIntValue() == Imm);
8358     };
8359 
8360     // fold (or (shl x0, y), (srl (srl x1, 1), (xor y, 31)))
8361     //   -> (fshl x0, x1, y)
8362     if (IsBinOpImm(N1, ISD::SRL, 1) &&
8363         IsBinOpImm(InnerNeg, ISD::XOR, EltBits - 1) &&
8364         InnerPos == InnerNeg.getOperand(0) &&
8365         TLI.isOperationLegalOrCustom(ISD::FSHL, VT)) {
8366       return DAG.getNode(ISD::FSHL, DL, VT, N0, N1.getOperand(0), Pos);
8367     }
8368 
8369     // fold (or (shl (shl x0, 1), (xor y, 31)), (srl x1, y))
8370     //   -> (fshr x0, x1, y)
8371     if (IsBinOpImm(N0, ISD::SHL, 1) &&
8372         IsBinOpImm(InnerPos, ISD::XOR, EltBits - 1) &&
8373         InnerNeg == InnerPos.getOperand(0) &&
8374         TLI.isOperationLegalOrCustom(ISD::FSHR, VT)) {
8375       return DAG.getNode(ISD::FSHR, DL, VT, N0.getOperand(0), N1, Neg);
8376     }
8377 
8378     // fold (or (shl (add x0, x0), (xor y, 31)), (srl x1, y))
8379     //   -> (fshr x0, x1, y)
8380     // TODO: Should add(x,x) -> shl(x,1) be a general DAG canonicalization?
8381     if (N0.getOpcode() == ISD::ADD && N0.getOperand(0) == N0.getOperand(1) &&
8382         IsBinOpImm(InnerPos, ISD::XOR, EltBits - 1) &&
8383         InnerNeg == InnerPos.getOperand(0) &&
8384         TLI.isOperationLegalOrCustom(ISD::FSHR, VT)) {
8385       return DAG.getNode(ISD::FSHR, DL, VT, N0.getOperand(0), N1, Neg);
8386     }
8387   }
8388 
8389   return SDValue();
8390 }
8391 
8392 // MatchRotate - Handle an 'or' of two operands.  If this is one of the many
8393 // idioms for rotate, and if the target supports rotation instructions, generate
8394 // a rot[lr]. This also matches funnel shift patterns, similar to rotation but
8395 // with different shifted sources.
8396 SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) {
8397   EVT VT = LHS.getValueType();
8398 
8399   // The target must have at least one rotate/funnel flavor.
8400   // We still try to match rotate by constant pre-legalization.
8401   // TODO: Support pre-legalization funnel-shift by constant.
8402   bool HasROTL = hasOperation(ISD::ROTL, VT);
8403   bool HasROTR = hasOperation(ISD::ROTR, VT);
8404   bool HasFSHL = hasOperation(ISD::FSHL, VT);
8405   bool HasFSHR = hasOperation(ISD::FSHR, VT);
8406 
8407   // If the type is going to be promoted and the target has enabled custom
8408   // lowering for rotate, allow matching rotate by non-constants. Only allow
8409   // this for scalar types.
8410   if (VT.isScalarInteger() && TLI.getTypeAction(*DAG.getContext(), VT) ==
8411                                   TargetLowering::TypePromoteInteger) {
8412     HasROTL |= TLI.getOperationAction(ISD::ROTL, VT) == TargetLowering::Custom;
8413     HasROTR |= TLI.getOperationAction(ISD::ROTR, VT) == TargetLowering::Custom;
8414   }
8415 
8416   if (LegalOperations && !HasROTL && !HasROTR && !HasFSHL && !HasFSHR)
8417     return SDValue();
8418 
8419   // Check for truncated rotate.
8420   if (LHS.getOpcode() == ISD::TRUNCATE && RHS.getOpcode() == ISD::TRUNCATE &&
8421       LHS.getOperand(0).getValueType() == RHS.getOperand(0).getValueType()) {
8422     assert(LHS.getValueType() == RHS.getValueType());
8423     if (SDValue Rot = MatchRotate(LHS.getOperand(0), RHS.getOperand(0), DL)) {
8424       return DAG.getNode(ISD::TRUNCATE, SDLoc(LHS), LHS.getValueType(), Rot);
8425     }
8426   }
8427 
8428   // Match "(X shl/srl V1) & V2" where V2 may not be present.
8429   SDValue LHSShift;   // The shift.
8430   SDValue LHSMask;    // AND value if any.
8431   matchRotateHalf(DAG, LHS, LHSShift, LHSMask);
8432 
8433   SDValue RHSShift;   // The shift.
8434   SDValue RHSMask;    // AND value if any.
8435   matchRotateHalf(DAG, RHS, RHSShift, RHSMask);
8436 
8437   // If neither side matched a rotate half, bail
8438   if (!LHSShift && !RHSShift)
8439     return SDValue();
8440 
8441   // InstCombine may have combined a constant shl, srl, mul, or udiv with one
8442   // side of the rotate, so try to handle that here. In all cases we need to
8443   // pass the matched shift from the opposite side to compute the opcode and
8444   // needed shift amount to extract.  We still want to do this if both sides
8445   // matched a rotate half because one half may be a potential overshift that
8446   // can be broken down (i.e. if InstCombine merged two shl or srl ops into a
8447   // single one).
8448 
8449   // Have LHS side of the rotate, try to extract the needed shift from the RHS.
8450   if (LHSShift)
8451     if (SDValue NewRHSShift =
8452             extractShiftForRotate(DAG, LHSShift, RHS, RHSMask, DL))
8453       RHSShift = NewRHSShift;
8454   // Have RHS side of the rotate, try to extract the needed shift from the LHS.
8455   if (RHSShift)
8456     if (SDValue NewLHSShift =
8457             extractShiftForRotate(DAG, RHSShift, LHS, LHSMask, DL))
8458       LHSShift = NewLHSShift;
8459 
8460   // If a side is still missing, there is nothing else we can do.
8461   if (!RHSShift || !LHSShift)
8462     return SDValue();
8463 
8464   // At this point we've matched or extracted a shift op on each side.
8465 
8466   if (LHSShift.getOpcode() == RHSShift.getOpcode())
8467     return SDValue(); // Shifts must disagree.
8468 
8469   // Canonicalize shl to left side in a shl/srl pair.
8470   if (RHSShift.getOpcode() == ISD::SHL) {
8471     std::swap(LHS, RHS);
8472     std::swap(LHSShift, RHSShift);
8473     std::swap(LHSMask, RHSMask);
8474   }
8475 
8476   // Something has gone wrong - we've lost the shl/srl pair - bail.
8477   if (LHSShift.getOpcode() != ISD::SHL || RHSShift.getOpcode() != ISD::SRL)
8478     return SDValue();
8479 
8480   unsigned EltSizeInBits = VT.getScalarSizeInBits();
8481   SDValue LHSShiftArg = LHSShift.getOperand(0);
8482   SDValue LHSShiftAmt = LHSShift.getOperand(1);
8483   SDValue RHSShiftArg = RHSShift.getOperand(0);
8484   SDValue RHSShiftAmt = RHSShift.getOperand(1);
8485 
8486   auto MatchRotateSum = [EltSizeInBits](ConstantSDNode *LHS,
8487                                         ConstantSDNode *RHS) {
8488     return (LHS->getAPIntValue() + RHS->getAPIntValue()) == EltSizeInBits;
8489   };
8490 
8491   auto ApplyMasks = [&](SDValue Res) {
8492     // If there is an AND of either shifted operand, apply it to the result.
8493     if (LHSMask.getNode() || RHSMask.getNode()) {
8494       SDValue AllOnes = DAG.getAllOnesConstant(DL, VT);
8495       SDValue Mask = AllOnes;
8496 
8497       if (LHSMask.getNode()) {
8498         SDValue RHSBits = DAG.getNode(ISD::SRL, DL, VT, AllOnes, RHSShiftAmt);
8499         Mask = DAG.getNode(ISD::AND, DL, VT, Mask,
8500                            DAG.getNode(ISD::OR, DL, VT, LHSMask, RHSBits));
8501       }
8502       if (RHSMask.getNode()) {
8503         SDValue LHSBits = DAG.getNode(ISD::SHL, DL, VT, AllOnes, LHSShiftAmt);
8504         Mask = DAG.getNode(ISD::AND, DL, VT, Mask,
8505                            DAG.getNode(ISD::OR, DL, VT, RHSMask, LHSBits));
8506       }
8507 
8508       Res = DAG.getNode(ISD::AND, DL, VT, Res, Mask);
8509     }
8510 
8511     return Res;
8512   };
8513 
8514   // TODO: Support pre-legalization funnel-shift by constant.
8515   bool IsRotate = LHSShiftArg == RHSShiftArg;
8516   if (!IsRotate && !(HasFSHL || HasFSHR)) {
8517     if (TLI.isTypeLegal(VT) && LHS.hasOneUse() && RHS.hasOneUse() &&
8518         ISD::matchBinaryPredicate(LHSShiftAmt, RHSShiftAmt, MatchRotateSum)) {
8519       // Look for a disguised rotate by constant.
8520       // The common shifted operand X may be hidden inside another 'or'.
8521       SDValue X, Y;
8522       auto matchOr = [&X, &Y](SDValue Or, SDValue CommonOp) {
8523         if (!Or.hasOneUse() || Or.getOpcode() != ISD::OR)
8524           return false;
8525         if (CommonOp == Or.getOperand(0)) {
8526           X = CommonOp;
8527           Y = Or.getOperand(1);
8528           return true;
8529         }
8530         if (CommonOp == Or.getOperand(1)) {
8531           X = CommonOp;
8532           Y = Or.getOperand(0);
8533           return true;
8534         }
8535         return false;
8536       };
8537 
8538       SDValue Res;
8539       if (matchOr(LHSShiftArg, RHSShiftArg)) {
8540         // (shl (X | Y), C1) | (srl X, C2) --> (rotl X, C1) | (shl Y, C1)
8541         SDValue RotX = DAG.getNode(ISD::ROTL, DL, VT, X, LHSShiftAmt);
8542         SDValue ShlY = DAG.getNode(ISD::SHL, DL, VT, Y, LHSShiftAmt);
8543         Res = DAG.getNode(ISD::OR, DL, VT, RotX, ShlY);
8544       } else if (matchOr(RHSShiftArg, LHSShiftArg)) {
8545         // (shl X, C1) | (srl (X | Y), C2) --> (rotl X, C1) | (srl Y, C2)
8546         SDValue RotX = DAG.getNode(ISD::ROTL, DL, VT, X, LHSShiftAmt);
8547         SDValue SrlY = DAG.getNode(ISD::SRL, DL, VT, Y, RHSShiftAmt);
8548         Res = DAG.getNode(ISD::OR, DL, VT, RotX, SrlY);
8549       } else {
8550         return SDValue();
8551       }
8552 
8553       return ApplyMasks(Res);
8554     }
8555 
8556     return SDValue(); // Requires funnel shift support.
8557   }
8558 
8559   // fold (or (shl x, C1), (srl x, C2)) -> (rotl x, C1)
8560   // fold (or (shl x, C1), (srl x, C2)) -> (rotr x, C2)
8561   // fold (or (shl x, C1), (srl y, C2)) -> (fshl x, y, C1)
8562   // fold (or (shl x, C1), (srl y, C2)) -> (fshr x, y, C2)
8563   // iff C1+C2 == EltSizeInBits
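  // e.g. for i32: (or (shl x, 8), (srl x, 24)) -> (rotl x, 8) or (rotr x, 24).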
8564   if (ISD::matchBinaryPredicate(LHSShiftAmt, RHSShiftAmt, MatchRotateSum)) {
8565     SDValue Res;
8566     if (IsRotate && (HasROTL || HasROTR || !(HasFSHL || HasFSHR))) {
8567       bool UseROTL = !LegalOperations || HasROTL;
8568       Res = DAG.getNode(UseROTL ? ISD::ROTL : ISD::ROTR, DL, VT, LHSShiftArg,
8569                         UseROTL ? LHSShiftAmt : RHSShiftAmt);
8570     } else {
8571       bool UseFSHL = !LegalOperations || HasFSHL;
8572       Res = DAG.getNode(UseFSHL ? ISD::FSHL : ISD::FSHR, DL, VT, LHSShiftArg,
8573                         RHSShiftArg, UseFSHL ? LHSShiftAmt : RHSShiftAmt);
8574     }
8575 
8576     return ApplyMasks(Res);
8577   }
8578 
8579   // Even pre-legalization, we can't easily rotate/funnel-shift by a variable
8580   // shift.
8581   if (!HasROTL && !HasROTR && !HasFSHL && !HasFSHR)
8582     return SDValue();
8583 
8584   // If there is a mask here, and we have a variable shift, we can't be sure
8585   // that we're masking out the right stuff.
8586   if (LHSMask.getNode() || RHSMask.getNode())
8587     return SDValue();
8588 
8589   // If the shift amount is sign/zext/any-extended just peel it off.
8590   SDValue LExtOp0 = LHSShiftAmt;
8591   SDValue RExtOp0 = RHSShiftAmt;
8592   if ((LHSShiftAmt.getOpcode() == ISD::SIGN_EXTEND ||
8593        LHSShiftAmt.getOpcode() == ISD::ZERO_EXTEND ||
8594        LHSShiftAmt.getOpcode() == ISD::ANY_EXTEND ||
8595        LHSShiftAmt.getOpcode() == ISD::TRUNCATE) &&
8596       (RHSShiftAmt.getOpcode() == ISD::SIGN_EXTEND ||
8597        RHSShiftAmt.getOpcode() == ISD::ZERO_EXTEND ||
8598        RHSShiftAmt.getOpcode() == ISD::ANY_EXTEND ||
8599        RHSShiftAmt.getOpcode() == ISD::TRUNCATE)) {
8600     LExtOp0 = LHSShiftAmt.getOperand(0);
8601     RExtOp0 = RHSShiftAmt.getOperand(0);
8602   }
8603 
8604   if (IsRotate && (HasROTL || HasROTR)) {
8605     SDValue TryL =
8606         MatchRotatePosNeg(LHSShiftArg, LHSShiftAmt, RHSShiftAmt, LExtOp0,
8607                           RExtOp0, HasROTL, ISD::ROTL, ISD::ROTR, DL);
8608     if (TryL)
8609       return TryL;
8610 
8611     SDValue TryR =
8612         MatchRotatePosNeg(RHSShiftArg, RHSShiftAmt, LHSShiftAmt, RExtOp0,
8613                           LExtOp0, HasROTR, ISD::ROTR, ISD::ROTL, DL);
8614     if (TryR)
8615       return TryR;
8616   }
8617 
8618   SDValue TryL =
8619       MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, LHSShiftAmt, RHSShiftAmt,
8620                         LExtOp0, RExtOp0, HasFSHL, ISD::FSHL, ISD::FSHR, DL);
8621   if (TryL)
8622     return TryL;
8623 
8624   SDValue TryR =
8625       MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, RHSShiftAmt, LHSShiftAmt,
8626                         RExtOp0, LExtOp0, HasFSHR, ISD::FSHR, ISD::FSHL, DL);
8627   if (TryR)
8628     return TryR;
8629 
8630   return SDValue();
8631 }
8632 
8633 /// Recursively traverses the expression calculating the origin of the requested
8634 /// byte of the given value. Returns std::nullopt if the provider can't be
8635 /// calculated.
8636 ///
8637 /// For all the values except the root of the expression, we verify that the
8638 /// value has exactly one use and if not then return std::nullopt. This way if
8639 /// the origin of the byte is returned it's guaranteed that the values which
8640 /// contribute to the byte are not used outside of this expression.
8641 ///
8642 /// However, there is a special case when dealing with vector loads -- we allow
8643 /// more than one use if the load is a vector type.  Since the values that
8644 /// contribute to the byte ultimately come from the ExtractVectorElements of the
8645 /// Load, we don't care if the Load has uses other than ExtractVectorElements,
8646 /// because those operations are independent from the pattern to be combined.
8647 /// For vector loads, we simply care that the ByteProviders are adjacent
8648 /// positions of the same vector, and their index matches the byte that is being
8649 /// provided. This is captured by the \p VectorIndex algorithm. \p VectorIndex
8650 /// is the index used in an ExtractVectorElement, and \p StartingIndex is the
8651 /// byte position we are trying to provide for the LoadCombine. If these do
8652 /// not match, then we can not combine the vector loads. \p Index uses the
8653 /// byte position we are trying to provide for and is matched against the
8654 /// shl and load size. The \p Index algorithm ensures the requested byte is
8655 /// provided for by the pattern, and the pattern does not over provide bytes.
8656 ///
8657 ///
8658 /// The supported LoadCombine pattern for vector loads is as follows
8659 ///                              or
8660 ///                          /        \
8661 ///                         or        shl
8662 ///                       /     \      |
8663 ///                     or      shl   zext
8664 ///                   /    \     |     |
8665 ///                 shl   zext  zext  EVE*
8666 ///                  |     |     |     |
8667 ///                 zext  EVE*  EVE*  LOAD
8668 ///                  |     |     |
8669 ///                 EVE*  LOAD  LOAD
8670 ///                  |
8671 ///                 LOAD
8672 ///
8673 /// *ExtractVectorElement
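///
/// For example, for i32 (or (zext (load i8 p)), (shl (zext (load i8 p1)), 8))
/// where p1 == p + 1, byte 0 is provided by the load of p and byte 1 by the
/// load of p1.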
8674 using SDByteProvider = ByteProvider<SDNode *>;
8675 
8676 static std::optional<SDByteProvider>
8677 calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
8678                       std::optional<uint64_t> VectorIndex,
8679                       unsigned StartingIndex = 0) {
8680 
8681   // A typical i64-by-i8 pattern requires recursion up to a depth of 8 calls.
8682   if (Depth == 10)
8683     return std::nullopt;
8684 
8685   // Only allow multiple uses if the instruction is a vector load (in which
8686   // case we will use the load for every ExtractVectorElement)
8687   if (Depth && !Op.hasOneUse() &&
8688       (Op.getOpcode() != ISD::LOAD || !Op.getValueType().isVector()))
8689     return std::nullopt;
8690 
8691   // Fail to combine if we have encountered anything but a LOAD after handling
8692   // an ExtractVectorElement.
8693   if (Op.getOpcode() != ISD::LOAD && VectorIndex.has_value())
8694     return std::nullopt;
8695 
8696   unsigned BitWidth = Op.getValueSizeInBits();
8697   if (BitWidth % 8 != 0)
8698     return std::nullopt;
8699   unsigned ByteWidth = BitWidth / 8;
8700   assert(Index < ByteWidth && "invalid index requested");
8701   (void) ByteWidth;
8702 
8703   switch (Op.getOpcode()) {
8704   case ISD::OR: {
8705     auto LHS =
8706         calculateByteProvider(Op->getOperand(0), Index, Depth + 1, VectorIndex);
8707     if (!LHS)
8708       return std::nullopt;
8709     auto RHS =
8710         calculateByteProvider(Op->getOperand(1), Index, Depth + 1, VectorIndex);
8711     if (!RHS)
8712       return std::nullopt;
8713 
8714     if (LHS->isConstantZero())
8715       return RHS;
8716     if (RHS->isConstantZero())
8717       return LHS;
8718     return std::nullopt;
8719   }
8720   case ISD::SHL: {
8721     auto ShiftOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
8722     if (!ShiftOp)
8723       return std::nullopt;
8724 
8725     uint64_t BitShift = ShiftOp->getZExtValue();
8726 
8727     if (BitShift % 8 != 0)
8728       return std::nullopt;
8729     uint64_t ByteShift = BitShift / 8;
8730 
8731     // If we are shifting by an amount greater than the index we are trying to
8732     // provide, then do not provide anything. Otherwise, subtract the shift
8733     // amount (in bytes) from the index.
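    // e.g. for a shl by 16 (ByteShift == 2), byte 3 of the result is byte 1
    // of the operand, while bytes 0 and 1 are known to be zero.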
8734     return Index < ByteShift
8735                ? SDByteProvider::getConstantZero()
8736                : calculateByteProvider(Op->getOperand(0), Index - ByteShift,
8737                                        Depth + 1, VectorIndex, Index);
8738   }
8739   case ISD::ANY_EXTEND:
8740   case ISD::SIGN_EXTEND:
8741   case ISD::ZERO_EXTEND: {
8742     SDValue NarrowOp = Op->getOperand(0);
8743     unsigned NarrowBitWidth = NarrowOp.getScalarValueSizeInBits();
8744     if (NarrowBitWidth % 8 != 0)
8745       return std::nullopt;
8746     uint64_t NarrowByteWidth = NarrowBitWidth / 8;
8747 
8748     if (Index >= NarrowByteWidth)
8749       return Op.getOpcode() == ISD::ZERO_EXTEND
8750                  ? std::optional<SDByteProvider>(
8751                        SDByteProvider::getConstantZero())
8752                  : std::nullopt;
8753     return calculateByteProvider(NarrowOp, Index, Depth + 1, VectorIndex,
8754                                  StartingIndex);
8755   }
8756   case ISD::BSWAP:
8757     return calculateByteProvider(Op->getOperand(0), ByteWidth - Index - 1,
8758                                  Depth + 1, VectorIndex, StartingIndex);
8759   case ISD::EXTRACT_VECTOR_ELT: {
8760     auto OffsetOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
8761     if (!OffsetOp)
8762       return std::nullopt;
8763 
8764     VectorIndex = OffsetOp->getZExtValue();
8765 
8766     SDValue NarrowOp = Op->getOperand(0);
8767     unsigned NarrowBitWidth = NarrowOp.getScalarValueSizeInBits();
8768     if (NarrowBitWidth % 8 != 0)
8769       return std::nullopt;
8770     uint64_t NarrowByteWidth = NarrowBitWidth / 8;
8771     // EXTRACT_VECTOR_ELT can extend the element type to the width of the return
8772     // type, leaving the high bits undefined.
8773     if (Index >= NarrowByteWidth)
8774       return std::nullopt;
8775 
8776     // Check to see if the position of the element in the vector corresponds
8777     // with the byte we are trying to provide for. In the case of a vector of
8778     // i8, this simply means the VectorIndex == StartingIndex. For non i8 cases,
8779     // the element will provide a range of bytes. For example, if we have a
8780     // vector of i16s, each element provides two bytes (V[1] provides byte 2 and
8781     // 3).
8782     if (*VectorIndex * NarrowByteWidth > StartingIndex)
8783       return std::nullopt;
8784     if ((*VectorIndex + 1) * NarrowByteWidth <= StartingIndex)
8785       return std::nullopt;
8786 
8787     return calculateByteProvider(Op->getOperand(0), Index, Depth + 1,
8788                                  VectorIndex, StartingIndex);
8789   }
8790   case ISD::LOAD: {
8791     auto L = cast<LoadSDNode>(Op.getNode());
8792     if (!L->isSimple() || L->isIndexed())
8793       return std::nullopt;
8794 
8795     unsigned NarrowBitWidth = L->getMemoryVT().getSizeInBits();
8796     if (NarrowBitWidth % 8 != 0)
8797       return std::nullopt;
8798     uint64_t NarrowByteWidth = NarrowBitWidth / 8;
8799 
8800     // If the width of the load does not reach the byte we are trying to
8801     // provide for and it is not a ZEXTLOAD, then the load does not provide
8802     // the byte in question.
8803     if (Index >= NarrowByteWidth)
8804       return L->getExtensionType() == ISD::ZEXTLOAD
8805                  ? std::optional<SDByteProvider>(
8806                        SDByteProvider::getConstantZero())
8807                  : std::nullopt;
8808 
8809     unsigned BPVectorIndex = VectorIndex.value_or(0U);
8810     return SDByteProvider::getSrc(L, Index, BPVectorIndex);
8811   }
8812   }
8813 
8814   return std::nullopt;
8815 }
8816 
8817 static unsigned littleEndianByteAt(unsigned BW, unsigned i) {
8818   return i;
8819 }
8820 
8821 static unsigned bigEndianByteAt(unsigned BW, unsigned i) {
8822   return BW - i - 1;
8823 }
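// e.g. for BW == 4 (an i32 value), value byte 2 sits at memory offset 2 in a
// little-endian layout but at offset BW - 2 - 1 == 1 in a big-endian layout.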
8824 
8825 // Check if the byte offsets we are looking at match either a big- or
8826 // little-endian value load. Return true for big endian, false for little
8827 // endian, and std::nullopt if the match failed.
8828 static std::optional<bool> isBigEndian(const ArrayRef<int64_t> ByteOffsets,
8829                                        int64_t FirstOffset) {
8830   // The endianness can be decided only for widths of at least 2 bytes.
8831   unsigned Width = ByteOffsets.size();
8832   if (Width < 2)
8833     return std::nullopt;
8834 
8835   bool BigEndian = true, LittleEndian = true;
8836   for (unsigned i = 0; i < Width; i++) {
8837     int64_t CurrentByteOffset = ByteOffsets[i] - FirstOffset;
8838     LittleEndian &= CurrentByteOffset == littleEndianByteAt(Width, i);
8839     BigEndian &= CurrentByteOffset == bigEndianByteAt(Width, i);
8840     if (!BigEndian && !LittleEndian)
8841       return std::nullopt;
8842   }
8843 
8844   assert((BigEndian != LittleEndian) && "It should be either big endian or "
8845                                         "little endian");
8846   return BigEndian;
8847 }
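// e.g. relative byte offsets {0, 1, 2, 3} match a little-endian load (returns
// false), while {3, 2, 1, 0} match a big-endian load (returns true).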
8848 
8849 // Look through one layer of truncate or extend.
8850 static SDValue stripTruncAndExt(SDValue Value) {
8851   switch (Value.getOpcode()) {
8852   case ISD::TRUNCATE:
8853   case ISD::ZERO_EXTEND:
8854   case ISD::SIGN_EXTEND:
8855   case ISD::ANY_EXTEND:
8856     return Value.getOperand(0);
8857   }
8858   return SDValue();
8859 }
8860 
8861 /// Match a pattern where a wide type scalar value is stored by several narrow
8862 /// stores. Fold it into a single store or a BSWAP and a store if the target
8863 /// supports it.
8864 ///
8865 /// Assuming little endian target:
8866 ///  i8 *p = ...
8867 ///  i32 val = ...
8868 ///  p[0] = (val >> 0) & 0xFF;
8869 ///  p[1] = (val >> 8) & 0xFF;
8870 ///  p[2] = (val >> 16) & 0xFF;
8871 ///  p[3] = (val >> 24) & 0xFF;
8872 /// =>
8873 ///  *((i32)p) = val;
8874 ///
8875 ///  i8 *p = ...
8876 ///  i32 val = ...
8877 ///  p[0] = (val >> 24) & 0xFF;
8878 ///  p[1] = (val >> 16) & 0xFF;
8879 ///  p[2] = (val >> 8) & 0xFF;
8880 ///  p[3] = (val >> 0) & 0xFF;
8881 /// =>
8882 ///  *((i32)p) = BSWAP(val);
8883 SDValue DAGCombiner::mergeTruncStores(StoreSDNode *N) {
8884   // The matching looks for "store (trunc x)" patterns that appear early but are
8885   // likely to be replaced by truncating store nodes during combining.
8886   // TODO: If there is evidence that running this later would help, this
8887   //       limitation could be removed. Legality checks may need to be added
8888   //       for the created store and optional bswap/rotate.
8889   if (LegalOperations || OptLevel == CodeGenOptLevel::None)
8890     return SDValue();
8891 
8892   // We only handle merging simple stores of 1-4 bytes.
8893   // TODO: Allow unordered atomics when wider type is legal (see D66309)
8894   EVT MemVT = N->getMemoryVT();
8895   if (!(MemVT == MVT::i8 || MemVT == MVT::i16 || MemVT == MVT::i32) ||
8896       !N->isSimple() || N->isIndexed())
8897     return SDValue();
8898 
8899   // Collect all of the stores in the chain, up to the maximum store width (i64).
8900   SDValue Chain = N->getChain();
8901   SmallVector<StoreSDNode *, 8> Stores = {N};
8902   unsigned NarrowNumBits = MemVT.getScalarSizeInBits();
8903   unsigned MaxWideNumBits = 64;
8904   unsigned MaxStores = MaxWideNumBits / NarrowNumBits;
8905   while (auto *Store = dyn_cast<StoreSDNode>(Chain)) {
8906     // All stores must be the same size to ensure that we are writing all of the
8907     // bytes in the wide value.
8908     // This store should have exactly one use as a chain operand for another
8909     // store in the merging set. If there are other chain uses, then the
8910     // transform may not be safe because order of loads/stores outside of this
8911     // set may not be preserved.
8912     // TODO: We could allow multiple sizes by tracking each stored byte.
8913     if (Store->getMemoryVT() != MemVT || !Store->isSimple() ||
8914         Store->isIndexed() || !Store->hasOneUse())
8915       return SDValue();
8916     Stores.push_back(Store);
8917     Chain = Store->getChain();
8918     if (MaxStores < Stores.size())
8919       return SDValue();
8920   }
8921   // There is no reason to continue if we do not have at least a pair of stores.
8922   if (Stores.size() < 2)
8923     return SDValue();
8924 
8925   // Handle simple types only.
8926   LLVMContext &Context = *DAG.getContext();
8927   unsigned NumStores = Stores.size();
8928   unsigned WideNumBits = NumStores * NarrowNumBits;
8929   EVT WideVT = EVT::getIntegerVT(Context, WideNumBits);
8930   if (WideVT != MVT::i16 && WideVT != MVT::i32 && WideVT != MVT::i64)
8931     return SDValue();
8932 
8933   // Check if all bytes of the source value that we are looking at are stored
8934   // to the same base address. Collect offsets from Base address into OffsetMap.
8935   SDValue SourceValue;
8936   SmallVector<int64_t, 8> OffsetMap(NumStores, INT64_MAX);
8937   int64_t FirstOffset = INT64_MAX;
8938   StoreSDNode *FirstStore = nullptr;
8939   std::optional<BaseIndexOffset> Base;
8940   for (auto *Store : Stores) {
8941     // All the stores store different parts of the combined wide value. A
8942     // truncate is required to get the partial value.
8943     SDValue Trunc = Store->getValue();
8944     if (Trunc.getOpcode() != ISD::TRUNCATE)
8945       return SDValue();
8946     // Other than the first/last part, a shift operation is required to get the
8947     // offset.
8948     int64_t Offset = 0;
8949     SDValue WideVal = Trunc.getOperand(0);
8950     if ((WideVal.getOpcode() == ISD::SRL || WideVal.getOpcode() == ISD::SRA) &&
8951         isa<ConstantSDNode>(WideVal.getOperand(1))) {
8952       // The shift amount must be a constant multiple of the narrow type width.
8953       // It is translated to the byte offset in the wide source value "y".
8954       //
8955       // x = srl y, ShiftAmtC
8956       // i8 z = trunc x
8957       // store z, ...
8958       uint64_t ShiftAmtC = WideVal.getConstantOperandVal(1);
8959       if (ShiftAmtC % NarrowNumBits != 0)
8960         return SDValue();
8961 
8962       // Make sure we aren't reading bits that are shifted in.
8963       if (ShiftAmtC > WideVal.getScalarValueSizeInBits() - NarrowNumBits)
8964         return SDValue();
8965 
8966       Offset = ShiftAmtC / NarrowNumBits;
8967       WideVal = WideVal.getOperand(0);
8968     }
8969 
8970     // Stores must share the same source value with different offsets.
8971     if (!SourceValue)
8972       SourceValue = WideVal;
8973     else if (SourceValue != WideVal) {
8974       // Truncates and extends can be stripped to see if the values are related.
8975       if (stripTruncAndExt(SourceValue) != WideVal &&
8976           stripTruncAndExt(WideVal) != SourceValue)
8977         return SDValue();
8978 
8979       if (WideVal.getScalarValueSizeInBits() >
8980           SourceValue.getScalarValueSizeInBits())
8981         SourceValue = WideVal;
8982 
8983       // Give up if the source value type is smaller than the store size.
8984       if (SourceValue.getScalarValueSizeInBits() < WideVT.getScalarSizeInBits())
8985         return SDValue();
8986     }
8987 
8988     // Stores must share the same base address.
8989     BaseIndexOffset Ptr = BaseIndexOffset::match(Store, DAG);
8990     int64_t ByteOffsetFromBase = 0;
8991     if (!Base)
8992       Base = Ptr;
8993     else if (!Base->equalBaseIndex(Ptr, DAG, ByteOffsetFromBase))
8994       return SDValue();
8995 
8996     // Remember the first store.
8997     if (ByteOffsetFromBase < FirstOffset) {
8998       FirstStore = Store;
8999       FirstOffset = ByteOffsetFromBase;
9000     }
9001     // Map the offset in the store to the offset in the combined value, and
9002     // return early if it has been set before.
9003     if (Offset < 0 || Offset >= NumStores || OffsetMap[Offset] != INT64_MAX)
9004       return SDValue();
9005     OffsetMap[Offset] = ByteOffsetFromBase;
9006   }
9007 
9008   assert(FirstOffset != INT64_MAX && "First byte offset must be set");
9009   assert(FirstStore && "First store must be set");
9010 
9011   // Check that a store of the wide type is both allowed and fast on the target
9012   const DataLayout &Layout = DAG.getDataLayout();
9013   unsigned Fast = 0;
9014   bool Allowed = TLI.allowsMemoryAccess(Context, Layout, WideVT,
9015                                         *FirstStore->getMemOperand(), &Fast);
9016   if (!Allowed || !Fast)
9017     return SDValue();
9018 
9019   // Check if the pieces of the value are going to the expected places in memory
9020   // to merge the stores.
9021   auto checkOffsets = [&](bool MatchLittleEndian) {
9022     if (MatchLittleEndian) {
9023       for (unsigned i = 0; i != NumStores; ++i)
9024         if (OffsetMap[i] != i * (NarrowNumBits / 8) + FirstOffset)
9025           return false;
9026     } else { // MatchBigEndian by reversing loop counter.
9027       for (unsigned i = 0, j = NumStores - 1; i != NumStores; ++i, --j)
9028         if (OffsetMap[j] != i * (NarrowNumBits / 8) + FirstOffset)
9029           return false;
9030     }
9031     return true;
9032   };
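  // As an illustration, with NumStores == 4, NarrowNumBits == 8 and
  // FirstOffset == F, a little-endian match expects
  // OffsetMap == {F, F+1, F+2, F+3} (byte 0 of the wide value at the lowest
  // address), while a big-endian match expects OffsetMap == {F+3, F+2, F+1, F}.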
9033 
9034   // Check if the offsets line up for the native data layout of this target.
9035   bool NeedBswap = false;
9036   bool NeedRotate = false;
9037   if (!checkOffsets(Layout.isLittleEndian())) {
9038     // Special-case: check if byte offsets line up for the opposite endian.
9039     if (NarrowNumBits == 8 && checkOffsets(Layout.isBigEndian()))
9040       NeedBswap = true;
9041     else if (NumStores == 2 && checkOffsets(Layout.isBigEndian()))
9042       NeedRotate = true;
9043     else
9044       return SDValue();
9045   }
9046 
9047   SDLoc DL(N);
9048   if (WideVT != SourceValue.getValueType()) {
9049     assert(SourceValue.getValueType().getScalarSizeInBits() > WideNumBits &&
9050            "Unexpected store value to merge");
9051     SourceValue = DAG.getNode(ISD::TRUNCATE, DL, WideVT, SourceValue);
9052   }
9053 
9054   // Before legalization we can introduce illegal bswaps/rotates, which will
9055   // later be converted to an explicit bswap sequence. This way we end up with a
9056   // single store and byte shuffling instead of several stores and byte shuffling.
9057   if (NeedBswap) {
9058     SourceValue = DAG.getNode(ISD::BSWAP, DL, WideVT, SourceValue);
9059   } else if (NeedRotate) {
9060     assert(WideNumBits % 2 == 0 && "Unexpected type for rotate");
9061     SDValue RotAmt = DAG.getConstant(WideNumBits / 2, DL, WideVT);
9062     SourceValue = DAG.getNode(ISD::ROTR, DL, WideVT, SourceValue, RotAmt);
9063   }
9064 
9065   SDValue NewStore =
9066       DAG.getStore(Chain, DL, SourceValue, FirstStore->getBasePtr(),
9067                    FirstStore->getPointerInfo(), FirstStore->getAlign());
9068 
9069   // Rely on other DAG combine rules to remove the other individual stores.
9070   DAG.ReplaceAllUsesWith(N, NewStore.getNode());
9071   return NewStore;
9072 }
9073 
9074 /// Match a pattern where a wide type scalar value is loaded by several narrow
9075 /// loads and combined by shifts and ors. Fold it into a single load or a load
9076 /// and a BSWAP if the target supports it.
9077 ///
9078 /// Assuming little endian target:
9079 ///  i8 *a = ...
9080 ///  i32 val = a[0] | (a[1] << 8) | (a[2] << 16) | (a[3] << 24)
9081 /// =>
9082 ///  i32 val = *((i32)a)
9083 ///
9084 ///  i8 *a = ...
9085 ///  i32 val = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3]
9086 /// =>
9087 ///  i32 val = BSWAP(*((i32)a))
9088 ///
9089 /// TODO: This rule matches complex patterns with OR node roots and doesn't
9090 /// interact well with the worklist mechanism. When a part of the pattern is
9091 /// updated (e.g. one of the loads) its direct users are put into the worklist,
9092 /// but the root node of the pattern which triggers the load combine is not
9093 /// necessarily a direct user of the changed node. For example, once the address
9094 /// of t28 load is reassociated load combine won't be triggered:
9095 ///             t25: i32 = add t4, Constant:i32<2>
9096 ///           t26: i64 = sign_extend t25
9097 ///        t27: i64 = add t2, t26
9098 ///       t28: i8,ch = load<LD1[%tmp9]> t0, t27, undef:i64
9099 ///     t29: i32 = zero_extend t28
9100 ///   t32: i32 = shl t29, Constant:i8<8>
9101 /// t33: i32 = or t23, t32
9102 /// As a possible fix visitLoad can check if the load can be a part of a load
9103 /// combine pattern and add corresponding OR roots to the worklist.
9104 SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
9105   assert(N->getOpcode() == ISD::OR &&
9106          "Can only match load combining against OR nodes");
9107 
9108   // Handles simple types only
9109   EVT VT = N->getValueType(0);
9110   if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
9111     return SDValue();
9112   unsigned ByteWidth = VT.getSizeInBits() / 8;
9113 
9114   bool IsBigEndianTarget = DAG.getDataLayout().isBigEndian();
9115   auto MemoryByteOffset = [&](SDByteProvider P) {
9116     assert(P.hasSrc() && "Must be a memory byte provider");
9117     auto *Load = cast<LoadSDNode>(P.Src.value());
9118 
9119     unsigned LoadBitWidth = Load->getMemoryVT().getScalarSizeInBits();
9120 
9121     assert(LoadBitWidth % 8 == 0 &&
9122            "can only analyze providers for individual bytes, not bits");
9123     unsigned LoadByteWidth = LoadBitWidth / 8;
9124     return IsBigEndianTarget ? bigEndianByteAt(LoadByteWidth, P.DestOffset)
9125                              : littleEndianByteAt(LoadByteWidth, P.DestOffset);
9126   };
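  // For example, for a 4-byte load the byte at DestOffset 1 sits at memory
  // offset 1 on a little-endian target but at offset 2 (i.e. 4 - 1 - 1) on a
  // big-endian target.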
9127 
9128   std::optional<BaseIndexOffset> Base;
9129   SDValue Chain;
9130 
9131   SmallPtrSet<LoadSDNode *, 8> Loads;
9132   std::optional<SDByteProvider> FirstByteProvider;
9133   int64_t FirstOffset = INT64_MAX;
9134 
9135   // Check if all the bytes of the OR we are looking at are loaded from the same
9136   // base address. Collect byte offsets from the Base address into ByteOffsets.
9137   SmallVector<int64_t, 8> ByteOffsets(ByteWidth);
9138   unsigned ZeroExtendedBytes = 0;
9139   for (int i = ByteWidth - 1; i >= 0; --i) {
9140     auto P =
9141         calculateByteProvider(SDValue(N, 0), i, 0, /*VectorIndex*/ std::nullopt,
9142                               /*StartingIndex*/ i);
9143     if (!P)
9144       return SDValue();
9145 
9146     if (P->isConstantZero()) {
9147       // It's OK for the N most significant bytes to be 0; we can just
9148       // zero-extend the load.
9149       if (++ZeroExtendedBytes != (ByteWidth - static_cast<unsigned>(i)))
9150         return SDValue();
9151       continue;
9152     }
9153     assert(P->hasSrc() && "provenance should either be memory or zero");
9154     auto *L = cast<LoadSDNode>(P->Src.value());
9155 
9156     // All loads must share the same chain
9157     SDValue LChain = L->getChain();
9158     if (!Chain)
9159       Chain = LChain;
9160     else if (Chain != LChain)
9161       return SDValue();
9162 
9163     // Loads must share the same base address
9164     BaseIndexOffset Ptr = BaseIndexOffset::match(L, DAG);
9165     int64_t ByteOffsetFromBase = 0;
9166 
9167     // For vector loads, the expected load combine pattern will have an
9168     // ExtractElement for each index in the vector. While each of these
9169     // ExtractElements will be accessing the same base address as determined
9170     // by the load instruction, the actual bytes they interact with will differ
9171     // due to different ExtractElement indices. To accurately determine the
9172     // byte position of an ExtractElement, we offset the base load ptr with
9173     // the index multiplied by the byte size of each element in the vector.
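    // For instance, for a <4 x i8> load the element at vector index 2
    // contributes bytes starting at offset 2 * (8 / 8) == 2 from the load's
    // base address.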
9174     if (L->getMemoryVT().isVector()) {
9175       unsigned LoadWidthInBit = L->getMemoryVT().getScalarSizeInBits();
9176       if (LoadWidthInBit % 8 != 0)
9177         return SDValue();
9178       unsigned ByteOffsetFromVector = P->SrcOffset * LoadWidthInBit / 8;
9179       Ptr.addToOffset(ByteOffsetFromVector);
9180     }
9181 
9182     if (!Base)
9183       Base = Ptr;
9184 
9185     else if (!Base->equalBaseIndex(Ptr, DAG, ByteOffsetFromBase))
9186       return SDValue();
9187 
9188     // Calculate the offset of the current byte from the base address
9189     ByteOffsetFromBase += MemoryByteOffset(*P);
9190     ByteOffsets[i] = ByteOffsetFromBase;
9191 
9192     // Remember the first byte load
9193     if (ByteOffsetFromBase < FirstOffset) {
9194       FirstByteProvider = P;
9195       FirstOffset = ByteOffsetFromBase;
9196     }
9197 
9198     Loads.insert(L);
9199   }
9200 
9201   assert(!Loads.empty() && "All the bytes of the value must be loaded from "
9202          "memory, so there must be at least one load which produces the value");
9203   assert(Base && "Base address of the accessed memory location must be set");
9204   assert(FirstOffset != INT64_MAX && "First byte offset must be set");
9205 
9206   bool NeedsZext = ZeroExtendedBytes > 0;
9207 
9208   EVT MemVT =
9209       EVT::getIntegerVT(*DAG.getContext(), (ByteWidth - ZeroExtendedBytes) * 8);
9210 
9211   if (!MemVT.isSimple())
9212     return SDValue();
9213 
9214   // Before legalization we can introduce illegal loads that are too wide and
9215   // will later be split into legal-sized loads. This lets us combine an i64
9216   // load-by-i8 pattern into a couple of i32 loads on 32-bit targets.
9217   if (LegalOperations &&
9218       !TLI.isOperationLegal(NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD,
9219                             MemVT))
9220     return SDValue();
9221 
9222   // Check if the bytes of the OR we are looking at match with either big or
9223   // little endian value load
9224   std::optional<bool> IsBigEndian = isBigEndian(
9225       ArrayRef(ByteOffsets).drop_back(ZeroExtendedBytes), FirstOffset);
9226   if (!IsBigEndian)
9227     return SDValue();
9228 
9229   assert(FirstByteProvider && "must be set");
9230 
9231   // Ensure that the first byte is loaded from the zero offset of the first
9232   // load, so the combined value can be loaded from the first load's address.
9233   if (MemoryByteOffset(*FirstByteProvider) != 0)
9234     return SDValue();
9235   auto *FirstLoad = cast<LoadSDNode>(FirstByteProvider->Src.value());
9236 
9237   // The node we are looking at matches the pattern; check if we can
9238   // replace it with a single (possibly zero-extended) load and bswap + shift if
9239   // needed.
9240 
9241   // If the load needs byte swap check if the target supports it
9242   bool NeedsBswap = IsBigEndianTarget != *IsBigEndian;
9243 
9244   // Before legalization we can introduce illegal bswaps, which will later be
9245   // converted to an explicit bswap sequence. This way we end up with a single
9246   // load and byte shuffling instead of several loads and byte shuffling.
9247   // We do not introduce illegal bswaps when zero-extending as this tends to
9248   // introduce too many arithmetic instructions.
9249   if (NeedsBswap && (LegalOperations || NeedsZext) &&
9250       !TLI.isOperationLegal(ISD::BSWAP, VT))
9251     return SDValue();
9252 
9253   // If we need to bswap and zero extend, we have to insert a shift. Check that
9254   // it is legal.
9255   if (NeedsBswap && NeedsZext && LegalOperations &&
9256       !TLI.isOperationLegal(ISD::SHL, VT))
9257     return SDValue();
9258 
9259   // Check that a load of the wide type is both allowed and fast on the target
9260   unsigned Fast = 0;
9261   bool Allowed =
9262       TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), MemVT,
9263                              *FirstLoad->getMemOperand(), &Fast);
9264   if (!Allowed || !Fast)
9265     return SDValue();
9266 
9267   SDValue NewLoad =
9268       DAG.getExtLoad(NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD, SDLoc(N), VT,
9269                      Chain, FirstLoad->getBasePtr(),
9270                      FirstLoad->getPointerInfo(), MemVT, FirstLoad->getAlign());
9271 
9272   // Transfer chain users from old loads to the new load.
9273   for (LoadSDNode *L : Loads)
9274     DAG.makeEquivalentMemoryOrdering(L, NewLoad);
9275 
9276   if (!NeedsBswap)
9277     return NewLoad;
9278 
9279   SDValue ShiftedLoad =
9280       NeedsZext ? DAG.getNode(ISD::SHL, SDLoc(N), VT, NewLoad,
9281                               DAG.getShiftAmountConstant(ZeroExtendedBytes * 8,
9282                                                          VT, SDLoc(N)))
9283                 : NewLoad;
9284   return DAG.getNode(ISD::BSWAP, SDLoc(N), VT, ShiftedLoad);
9285 }
9286 
9287 // If the target has andn, bsl, or a similar bit-select instruction,
9288 // we want to unfold masked merge, with canonical pattern of:
9289 //   |        A  |  |B|
9290 //   ((x ^ y) & m) ^ y
9291 //    |  D  |
9292 // Into:
9293 //   (x & m) | (y & ~m)
9294 // If y is a constant, m is not a 'not', and the 'andn' does not work with
9295 // immediates, we unfold into a different pattern:
9296 //   ~(~x & m) & (m | y)
9297 // If x is a constant, m is a 'not', and the 'andn' does not work with
9298 // immediates, we unfold into a different pattern:
9299 //   (x | ~m) & ~(~m & ~y)
9300 // NOTE: we don't unfold the pattern if 'xor' is actually a 'not', because at
9301 //       the very least that breaks andnpd / andnps patterns, and because those
9302 //       patterns are simplified in IR and shouldn't be created in the DAG
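// As a 4-bit illustration: with x = 0b1100, y = 0b1010 and m = 0b0110, both
// forms select the bits of x where m is set and the bits of y elsewhere:
//   ((x ^ y) & m) ^ y  == (0b0110 & 0b0110) ^ 0b1010 == 0b1100
//   (x & m) | (y & ~m) ==  0b0100 | 0b1000           == 0b1100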
9303 SDValue DAGCombiner::unfoldMaskedMerge(SDNode *N) {
9304   assert(N->getOpcode() == ISD::XOR);
9305 
9306   // Don't touch 'not' (i.e. where y = -1).
9307   if (isAllOnesOrAllOnesSplat(N->getOperand(1)))
9308     return SDValue();
9309 
9310   EVT VT = N->getValueType(0);
9311 
9312   // There are 3 commutable operators in the pattern,
9313   // so we have to deal with 8 possible variants of the basic pattern.
9314   SDValue X, Y, M;
9315   auto matchAndXor = [&X, &Y, &M](SDValue And, unsigned XorIdx, SDValue Other) {
9316     if (And.getOpcode() != ISD::AND || !And.hasOneUse())
9317       return false;
9318     SDValue Xor = And.getOperand(XorIdx);
9319     if (Xor.getOpcode() != ISD::XOR || !Xor.hasOneUse())
9320       return false;
9321     SDValue Xor0 = Xor.getOperand(0);
9322     SDValue Xor1 = Xor.getOperand(1);
9323     // Don't touch 'not' (i.e. where y = -1).
9324     if (isAllOnesOrAllOnesSplat(Xor1))
9325       return false;
9326     if (Other == Xor0)
9327       std::swap(Xor0, Xor1);
9328     if (Other != Xor1)
9329       return false;
9330     X = Xor0;
9331     Y = Xor1;
9332     M = And.getOperand(XorIdx ? 0 : 1);
9333     return true;
9334   };
9335 
9336   SDValue N0 = N->getOperand(0);
9337   SDValue N1 = N->getOperand(1);
9338   if (!matchAndXor(N0, 0, N1) && !matchAndXor(N0, 1, N1) &&
9339       !matchAndXor(N1, 0, N0) && !matchAndXor(N1, 1, N0))
9340     return SDValue();
9341 
9342   // Don't do anything if the mask is constant. This should not be reachable.
9343   // InstCombine should have already unfolded this pattern, and DAGCombiner
9344   // probably shouldn't produce it either.
9345   if (isa<ConstantSDNode>(M.getNode()))
9346     return SDValue();
9347 
9348   // We can transform if the target has AndNot
9349   if (!TLI.hasAndNot(M))
9350     return SDValue();
9351 
9352   SDLoc DL(N);
9353 
9354   // If Y is a constant, check that 'andn' works with immediates, unless M is
9355   // a bitwise not that would already allow ANDN to be used.
9356   if (!TLI.hasAndNot(Y) && !isBitwiseNot(M)) {
9357     assert(TLI.hasAndNot(X) && "Only mask is a variable? Unreachable.");
9358     // If not, we need to do a bit more work to make sure andn is still used.
9359     SDValue NotX = DAG.getNOT(DL, X, VT);
9360     SDValue LHS = DAG.getNode(ISD::AND, DL, VT, NotX, M);
9361     SDValue NotLHS = DAG.getNOT(DL, LHS, VT);
9362     SDValue RHS = DAG.getNode(ISD::OR, DL, VT, M, Y);
9363     return DAG.getNode(ISD::AND, DL, VT, NotLHS, RHS);
9364   }
9365 
9366   // If X is a constant and M is a bitwise not, check that 'andn' works with
9367   // immediates.
9368   if (!TLI.hasAndNot(X) && isBitwiseNot(M)) {
9369     assert(TLI.hasAndNot(Y) && "Only mask is a variable? Unreachable.");
9370     // If not, we need to do a bit more work to make sure andn is still used.
9371     SDValue NotM = M.getOperand(0);
9372     SDValue LHS = DAG.getNode(ISD::OR, DL, VT, X, NotM);
9373     SDValue NotY = DAG.getNOT(DL, Y, VT);
9374     SDValue RHS = DAG.getNode(ISD::AND, DL, VT, NotM, NotY);
9375     SDValue NotRHS = DAG.getNOT(DL, RHS, VT);
9376     return DAG.getNode(ISD::AND, DL, VT, LHS, NotRHS);
9377   }
9378 
9379   SDValue LHS = DAG.getNode(ISD::AND, DL, VT, X, M);
9380   SDValue NotM = DAG.getNOT(DL, M, VT);
9381   SDValue RHS = DAG.getNode(ISD::AND, DL, VT, Y, NotM);
9382 
9383   return DAG.getNode(ISD::OR, DL, VT, LHS, RHS);
9384 }
9385 
9386 SDValue DAGCombiner::visitXOR(SDNode *N) {
9387   SDValue N0 = N->getOperand(0);
9388   SDValue N1 = N->getOperand(1);
9389   EVT VT = N0.getValueType();
9390   SDLoc DL(N);
9391 
9392   // fold (xor undef, undef) -> 0. This is a common idiom (misuse).
9393   if (N0.isUndef() && N1.isUndef())
9394     return DAG.getConstant(0, DL, VT);
9395 
9396   // fold (xor x, undef) -> undef
9397   if (N0.isUndef())
9398     return N0;
9399   if (N1.isUndef())
9400     return N1;
9401 
9402   // fold (xor c1, c2) -> c1^c2
9403   if (SDValue C = DAG.FoldConstantArithmetic(ISD::XOR, DL, VT, {N0, N1}))
9404     return C;
9405 
9406   // canonicalize constant to RHS
9407   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
9408       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
9409     return DAG.getNode(ISD::XOR, DL, VT, N1, N0);
9410 
9411   // fold vector ops
9412   if (VT.isVector()) {
9413     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
9414       return FoldedVOp;
9415 
9416     // fold (xor x, 0) -> x, vector edition
9417     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
9418       return N0;
9419   }
9420 
9421   // fold (xor x, 0) -> x
9422   if (isNullConstant(N1))
9423     return N0;
9424 
9425   if (SDValue NewSel = foldBinOpIntoSelect(N))
9426     return NewSel;
9427 
9428   // reassociate xor
9429   if (SDValue RXOR = reassociateOps(ISD::XOR, DL, N0, N1, N->getFlags()))
9430     return RXOR;
9431 
9432   // Fold xor(vecreduce(x), vecreduce(y)) -> vecreduce(xor(x, y))
9433   if (SDValue SD =
9434           reassociateReduction(ISD::VECREDUCE_XOR, ISD::XOR, DL, VT, N0, N1))
9435     return SD;
9436 
9437   // fold (a^b) -> (a|b) iff a and b share no bits.
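  // For example, 0b0101 ^ 0b1010 == 0b0101 | 0b1010 == 0b1111; XOR and OR
  // only differ in bit positions where both inputs are 1.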
9438   if ((!LegalOperations || TLI.isOperationLegal(ISD::OR, VT)) &&
9439       DAG.haveNoCommonBitsSet(N0, N1)) {
9440     SDNodeFlags Flags;
9441     Flags.setDisjoint(true);
9442     return DAG.getNode(ISD::OR, DL, VT, N0, N1, Flags);
9443   }
9444 
9445   // look for 'add-like' folds:
9446   // XOR(N0,MIN_SIGNED_VALUE) == ADD(N0,MIN_SIGNED_VALUE)
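  // For instance, for i8: X ^ 0x80 == X + 0x80, since adding into the sign
  // bit can never carry into a higher bit.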
9447   if ((!LegalOperations || TLI.isOperationLegal(ISD::ADD, VT)) &&
9448       isMinSignedConstant(N1))
9449     if (SDValue Combined = visitADDLike(N))
9450       return Combined;
9451 
9452   // fold !(x cc y) -> (x !cc y)
9453   unsigned N0Opcode = N0.getOpcode();
9454   SDValue LHS, RHS, CC;
9455   if (TLI.isConstTrueVal(N1) &&
9456       isSetCCEquivalent(N0, LHS, RHS, CC, /*MatchStrict*/ true)) {
9457     ISD::CondCode NotCC = ISD::getSetCCInverse(cast<CondCodeSDNode>(CC)->get(),
9458                                                LHS.getValueType());
9459     if (!LegalOperations ||
9460         TLI.isCondCodeLegal(NotCC, LHS.getSimpleValueType())) {
9461       switch (N0Opcode) {
9462       default:
9463         llvm_unreachable("Unhandled SetCC Equivalent!");
9464       case ISD::SETCC:
9465         return DAG.getSetCC(SDLoc(N0), VT, LHS, RHS, NotCC);
9466       case ISD::SELECT_CC:
9467         return DAG.getSelectCC(SDLoc(N0), LHS, RHS, N0.getOperand(2),
9468                                N0.getOperand(3), NotCC);
9469       case ISD::STRICT_FSETCC:
9470       case ISD::STRICT_FSETCCS: {
9471         if (N0.hasOneUse()) {
9472           // FIXME Can we handle multiple uses? Could we token factor the chain
9473           // results from the new/old setcc?
9474           SDValue SetCC =
9475               DAG.getSetCC(SDLoc(N0), VT, LHS, RHS, NotCC,
9476                            N0.getOperand(0), N0Opcode == ISD::STRICT_FSETCCS);
9477           CombineTo(N, SetCC);
9478           DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), SetCC.getValue(1));
9479           recursivelyDeleteUnusedNodes(N0.getNode());
9480           return SDValue(N, 0); // Return N so it doesn't get rechecked!
9481         }
9482         break;
9483       }
9484       }
9485     }
9486   }
9487 
9488   // fold (not (zext (setcc x, y))) -> (zext (not (setcc x, y)))
9489   if (isOneConstant(N1) && N0Opcode == ISD::ZERO_EXTEND && N0.hasOneUse() &&
9490       isSetCCEquivalent(N0.getOperand(0), LHS, RHS, CC)){
9491     SDValue V = N0.getOperand(0);
9492     SDLoc DL0(N0);
9493     V = DAG.getNode(ISD::XOR, DL0, V.getValueType(), V,
9494                     DAG.getConstant(1, DL0, V.getValueType()));
9495     AddToWorklist(V.getNode());
9496     return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, V);
9497   }
9498 
9499   // fold (not (or x, y)) -> (and (not x), (not y)) iff x or y are setcc
9500   if (isOneConstant(N1) && VT == MVT::i1 && N0.hasOneUse() &&
9501       (N0Opcode == ISD::OR || N0Opcode == ISD::AND)) {
9502     SDValue N00 = N0.getOperand(0), N01 = N0.getOperand(1);
9503     if (isOneUseSetCC(N01) || isOneUseSetCC(N00)) {
9504       unsigned NewOpcode = N0Opcode == ISD::AND ? ISD::OR : ISD::AND;
9505       N00 = DAG.getNode(ISD::XOR, SDLoc(N00), VT, N00, N1); // N00 = ~N00
9506       N01 = DAG.getNode(ISD::XOR, SDLoc(N01), VT, N01, N1); // N01 = ~N01
9507       AddToWorklist(N00.getNode()); AddToWorklist(N01.getNode());
9508       return DAG.getNode(NewOpcode, DL, VT, N00, N01);
9509     }
9510   }
9511   // fold (not (or x, y)) -> (and (not x), (not y)) iff x or y are constants
9512   if (isAllOnesConstant(N1) && N0.hasOneUse() &&
9513       (N0Opcode == ISD::OR || N0Opcode == ISD::AND)) {
9514     SDValue N00 = N0.getOperand(0), N01 = N0.getOperand(1);
9515     if (isa<ConstantSDNode>(N01) || isa<ConstantSDNode>(N00)) {
9516       unsigned NewOpcode = N0Opcode == ISD::AND ? ISD::OR : ISD::AND;
9517       N00 = DAG.getNode(ISD::XOR, SDLoc(N00), VT, N00, N1); // N00 = ~N00
9518       N01 = DAG.getNode(ISD::XOR, SDLoc(N01), VT, N01, N1); // N01 = ~N01
9519       AddToWorklist(N00.getNode()); AddToWorklist(N01.getNode());
9520       return DAG.getNode(NewOpcode, DL, VT, N00, N01);
9521     }
9522   }
9523 
9524   // fold (not (neg x)) -> (add X, -1)
9525   // FIXME: This can be generalized to (not (sub Y, X)) -> (add X, ~Y) if
9526   // Y is a constant or the subtract has a single use.
9527   if (isAllOnesConstant(N1) && N0.getOpcode() == ISD::SUB &&
9528       isNullConstant(N0.getOperand(0))) {
9529     return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(1),
9530                        DAG.getAllOnesConstant(DL, VT));
9531   }
9532 
9533   // fold (not (add X, -1)) -> (neg X)
9534   if (isAllOnesConstant(N1) && N0.getOpcode() == ISD::ADD &&
9535       isAllOnesOrAllOnesSplat(N0.getOperand(1))) {
9536     return DAG.getNegative(N0.getOperand(0), DL, VT);
9537   }
9538 
9539   // fold (xor (and x, y), y) -> (and (not x), y)
9540   if (N0Opcode == ISD::AND && N0.hasOneUse() && N0->getOperand(1) == N1) {
9541     SDValue X = N0.getOperand(0);
9542     SDValue NotX = DAG.getNOT(SDLoc(X), X, VT);
9543     AddToWorklist(NotX.getNode());
9544     return DAG.getNode(ISD::AND, DL, VT, NotX, N1);
9545   }
9546 
9547   // fold Y = sra (X, size(X)-1); xor (add (X, Y), Y) -> (abs X)
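  // This is the classic branch-free abs: Y is 0 when X is non-negative and
  // all-ones when X is negative, so (X + Y) ^ Y yields X when X >= 0 and
  // ~(X - 1) == -X otherwise.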
9548   if (!LegalOperations || hasOperation(ISD::ABS, VT)) {
9549     SDValue A = N0Opcode == ISD::ADD ? N0 : N1;
9550     SDValue S = N0Opcode == ISD::SRA ? N0 : N1;
9551     if (A.getOpcode() == ISD::ADD && S.getOpcode() == ISD::SRA) {
9552       SDValue A0 = A.getOperand(0), A1 = A.getOperand(1);
9553       SDValue S0 = S.getOperand(0);
9554       if ((A0 == S && A1 == S0) || (A1 == S && A0 == S0))
9555         if (ConstantSDNode *C = isConstOrConstSplat(S.getOperand(1)))
9556           if (C->getAPIntValue() == (VT.getScalarSizeInBits() - 1))
9557             return DAG.getNode(ISD::ABS, DL, VT, S0);
9558     }
9559   }
9560 
9561   // fold (xor x, x) -> 0
9562   if (N0 == N1)
9563     return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
9564 
9565   // fold (xor (shl 1, x), -1) -> (rotl ~1, x)
9566   // Here is a concrete example of this equivalence:
9567   // i16   x ==  14
9568   // i16 shl ==   1 << 14  == 16384 == 0b0100000000000000
9569   // i16 xor == ~(1 << 14) == 49151 == 0b1011111111111111
9570   //
9571   // =>
9572   //
9573   // i16     ~1      == 0b1111111111111110
9574   // i16 rol(~1, 14) == 0b1011111111111111
9575   //
9576   // Some additional tips to help conceptualize this transform:
9577   // - Try to see the operation as placing a single zero in a value of all ones.
9578   // - There exists no value for x which would allow the result to contain zero.
9579   // - Values of x larger than the bitwidth are undefined and do not require a
9580   //   consistent result.
9581   // - Pushing the zero left requires shifting one-bits in from the right.
9582   // A rotate left of ~1 is a nice way of achieving the desired result.
9583   if (TLI.isOperationLegalOrCustom(ISD::ROTL, VT) && N0Opcode == ISD::SHL &&
9584       isAllOnesConstant(N1) && isOneConstant(N0.getOperand(0))) {
9585     return DAG.getNode(ISD::ROTL, DL, VT, DAG.getConstant(~1, DL, VT),
9586                        N0.getOperand(1));
9587   }
9588 
9589   // Simplify: xor (op x...), (op y...)  -> (op (xor x, y))
9590   if (N0Opcode == N1.getOpcode())
9591     if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
9592       return V;
9593 
9594   if (SDValue R = foldLogicOfShifts(N, N0, N1, DAG))
9595     return R;
9596   if (SDValue R = foldLogicOfShifts(N, N1, N0, DAG))
9597     return R;
9598   if (SDValue R = foldLogicTreeOfShifts(N, N0, N1, DAG))
9599     return R;
9600 
9601   // Unfold  ((x ^ y) & m) ^ y  into  (x & m) | (y & ~m)  if profitable
9602   if (SDValue MM = unfoldMaskedMerge(N))
9603     return MM;
9604 
9605   // Simplify the expression using non-local knowledge.
9606   if (SimplifyDemandedBits(SDValue(N, 0)))
9607     return SDValue(N, 0);
9608 
9609   if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
9610     return Combined;
9611 
9612   return SDValue();
9613 }
9614 
9615 /// If we have a shift-by-constant of a bitwise logic op that itself has a
9616 /// shift-by-constant operand with identical opcode, we may be able to convert
9617 /// that into 2 independent shifts followed by the logic op. This is a
9618 /// throughput improvement.
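/// For example (illustrative):
///   srl (and (srl X, 2), Y), 3 --> and (srl X, 5), (srl Y, 3)
/// The two new shifts are independent and can execute in parallel, whereas
/// the original shifts had to execute back to back.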
9619 static SDValue combineShiftOfShiftedLogic(SDNode *Shift, SelectionDAG &DAG) {
9620   // Match a one-use bitwise logic op.
9621   SDValue LogicOp = Shift->getOperand(0);
9622   if (!LogicOp.hasOneUse())
9623     return SDValue();
9624 
9625   unsigned LogicOpcode = LogicOp.getOpcode();
9626   if (LogicOpcode != ISD::AND && LogicOpcode != ISD::OR &&
9627       LogicOpcode != ISD::XOR)
9628     return SDValue();
9629 
9630   // Find a matching one-use shift by constant.
9631   unsigned ShiftOpcode = Shift->getOpcode();
9632   SDValue C1 = Shift->getOperand(1);
9633   ConstantSDNode *C1Node = isConstOrConstSplat(C1);
9634   assert(C1Node && "Expected a shift with constant operand");
9635   const APInt &C1Val = C1Node->getAPIntValue();
9636   auto matchFirstShift = [&](SDValue V, SDValue &ShiftOp,
9637                              const APInt *&ShiftAmtVal) {
9638     if (V.getOpcode() != ShiftOpcode || !V.hasOneUse())
9639       return false;
9640 
9641     ConstantSDNode *ShiftCNode = isConstOrConstSplat(V.getOperand(1));
9642     if (!ShiftCNode)
9643       return false;
9644 
9645     // Capture the shifted operand and shift amount value.
9646     ShiftOp = V.getOperand(0);
9647     ShiftAmtVal = &ShiftCNode->getAPIntValue();
9648 
9649     // Shift amount types do not have to match their operand type, so check that
9650     // the constants are the same width.
9651     if (ShiftAmtVal->getBitWidth() != C1Val.getBitWidth())
9652       return false;
9653 
9654     // The fold is not valid if the sum of the shift values doesn't fit in the
9655     // given shift amount type.
9656     bool Overflow = false;
9657     APInt NewShiftAmt = C1Val.uadd_ov(*ShiftAmtVal, Overflow);
9658     if (Overflow)
9659       return false;
9660 
9661     // The fold is not valid if the sum of the shift values exceeds bitwidth.
9662     if (NewShiftAmt.uge(V.getScalarValueSizeInBits()))
9663       return false;
9664 
9665     return true;
9666   };
9667 
9668   // Logic ops are commutative, so check each operand for a match.
9669   SDValue X, Y;
9670   const APInt *C0Val;
9671   if (matchFirstShift(LogicOp.getOperand(0), X, C0Val))
9672     Y = LogicOp.getOperand(1);
9673   else if (matchFirstShift(LogicOp.getOperand(1), X, C0Val))
9674     Y = LogicOp.getOperand(0);
9675   else
9676     return SDValue();
9677 
9678   // shift (logic (shift X, C0), Y), C1 -> logic (shift X, C0+C1), (shift Y, C1)
9679   SDLoc DL(Shift);
9680   EVT VT = Shift->getValueType(0);
9681   EVT ShiftAmtVT = Shift->getOperand(1).getValueType();
9682   SDValue ShiftSumC = DAG.getConstant(*C0Val + C1Val, DL, ShiftAmtVT);
9683   SDValue NewShift1 = DAG.getNode(ShiftOpcode, DL, VT, X, ShiftSumC);
9684   SDValue NewShift2 = DAG.getNode(ShiftOpcode, DL, VT, Y, C1);
9685   return DAG.getNode(LogicOpcode, DL, VT, NewShift1, NewShift2,
9686                      LogicOp->getFlags());
9687 }
9688 
9689 /// Handle transforms common to the three shifts, when the shift amount is a
9690 /// constant.
9691 /// We are looking for: (shift being one of shl/sra/srl)
9692 ///   shift (binop X, C0), C1
9693 /// And want to transform into:
9694 ///   binop (shift X, C1), (shift C0, C1)
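/// For example (illustrative):
///   shl (and X:i32, 0xF0), 4 --> and (shl X, 4), 0xF00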
9695 SDValue DAGCombiner::visitShiftByConstant(SDNode *N) {
9696   assert(isConstOrConstSplat(N->getOperand(1)) && "Expected constant operand");
9697 
9698   // Do not turn a 'not' into a regular xor.
9699   if (isBitwiseNot(N->getOperand(0)))
9700     return SDValue();
9701 
9702   // The inner binop must be one-use, since we want to replace it.
9703   SDValue LHS = N->getOperand(0);
9704   if (!LHS.hasOneUse() || !TLI.isDesirableToCommuteWithShift(N, Level))
9705     return SDValue();
9706 
9707   // Fold shift(bitop(shift(x,c1),y), c2) -> bitop(shift(x,c1+c2),shift(y,c2)).
9708   if (SDValue R = combineShiftOfShiftedLogic(N, DAG))
9709     return R;
9710 
9711   // We want to pull some binops through shifts, so that we have (and (shift))
9712   // instead of (shift (and)), likewise for add, or, xor, etc.  This sort of
9713   // thing happens with address calculations, so it's important to canonicalize
9714   // it.
9715   switch (LHS.getOpcode()) {
9716   default:
9717     return SDValue();
9718   case ISD::OR:
9719   case ISD::XOR:
9720   case ISD::AND:
9721     break;
9722   case ISD::ADD:
9723     if (N->getOpcode() != ISD::SHL)
9724       return SDValue(); // only shl(add) not sr[al](add).
9725     break;
9726   }
9727 
9728   // FIXME: disable this unless the input to the binop is a shift by a constant
9729   // or is copy/select. Enable this in other cases once we figure out when it
9730   // is exactly profitable.
9731   SDValue BinOpLHSVal = LHS.getOperand(0);
9732   bool IsShiftByConstant = (BinOpLHSVal.getOpcode() == ISD::SHL ||
9733                             BinOpLHSVal.getOpcode() == ISD::SRA ||
9734                             BinOpLHSVal.getOpcode() == ISD::SRL) &&
9735                            isa<ConstantSDNode>(BinOpLHSVal.getOperand(1));
9736   bool IsCopyOrSelect = BinOpLHSVal.getOpcode() == ISD::CopyFromReg ||
9737                         BinOpLHSVal.getOpcode() == ISD::SELECT;
9738 
9739   if (!IsShiftByConstant && !IsCopyOrSelect)
9740     return SDValue();
9741 
9742   if (IsCopyOrSelect && N->hasOneUse())
9743     return SDValue();
9744 
9745   // Attempt to fold the constants, shifting the binop RHS by the shift amount.
9746   SDLoc DL(N);
9747   EVT VT = N->getValueType(0);
9748   if (SDValue NewRHS = DAG.FoldConstantArithmetic(
9749           N->getOpcode(), DL, VT, {LHS.getOperand(1), N->getOperand(1)})) {
9750     SDValue NewShift = DAG.getNode(N->getOpcode(), DL, VT, LHS.getOperand(0),
9751                                    N->getOperand(1));
9752     return DAG.getNode(LHS.getOpcode(), DL, VT, NewShift, NewRHS);
9753   }
9754 
9755   return SDValue();
9756 }
9757 
9758 SDValue DAGCombiner::distributeTruncateThroughAnd(SDNode *N) {
9759   assert(N->getOpcode() == ISD::TRUNCATE);
9760   assert(N->getOperand(0).getOpcode() == ISD::AND);
9761 
9762   // (truncate:TruncVT (and N00, N01C)) -> (and (truncate:TruncVT N00), TruncC)
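  // For example (illustrative), with TruncVT == i32:
  //   trunc (and X:i64, 255) --> and (trunc X), 255
  // so the AND is performed in the narrower type.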
9763   EVT TruncVT = N->getValueType(0);
9764   if (N->hasOneUse() && N->getOperand(0).hasOneUse() &&
9765       TLI.isTypeDesirableForOp(ISD::AND, TruncVT)) {
9766     SDValue N01 = N->getOperand(0).getOperand(1);
9767     if (isConstantOrConstantVector(N01, /* NoOpaques */ true)) {
9768       SDLoc DL(N);
9769       SDValue N00 = N->getOperand(0).getOperand(0);
9770       SDValue Trunc00 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, N00);
9771       SDValue Trunc01 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, N01);
9772       AddToWorklist(Trunc00.getNode());
9773       AddToWorklist(Trunc01.getNode());
9774       return DAG.getNode(ISD::AND, DL, TruncVT, Trunc00, Trunc01);
9775     }
9776   }
9777 
9778   return SDValue();
9779 }
9780 
9781 SDValue DAGCombiner::visitRotate(SDNode *N) {
9782   SDLoc dl(N);
9783   SDValue N0 = N->getOperand(0);
9784   SDValue N1 = N->getOperand(1);
9785   EVT VT = N->getValueType(0);
9786   unsigned Bitsize = VT.getScalarSizeInBits();
9787 
9788   // fold (rot x, 0) -> x
9789   if (isNullOrNullSplat(N1))
9790     return N0;
9791 
9792   // fold (rot x, c) -> x iff (c % BitSize) == 0
9793   if (isPowerOf2_32(Bitsize) && Bitsize > 1) {
9794     APInt ModuloMask(N1.getScalarValueSizeInBits(), Bitsize - 1);
9795     if (DAG.MaskedValueIsZero(N1, ModuloMask))
9796       return N0;
9797   }
9798 
9799   // fold (rot x, c) -> (rot x, c % BitSize)
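  // e.g. rotl i8 X, 11 --> rotl i8 X, 3, since rotating an 8-bit value by 11
  // is the same as rotating it by 11 % 8 == 3.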
9800   bool OutOfRange = false;
9801   auto MatchOutOfRange = [Bitsize, &OutOfRange](ConstantSDNode *C) {
9802     OutOfRange |= C->getAPIntValue().uge(Bitsize);
9803     return true;
9804   };
9805   if (ISD::matchUnaryPredicate(N1, MatchOutOfRange) && OutOfRange) {
9806     EVT AmtVT = N1.getValueType();
9807     SDValue Bits = DAG.getConstant(Bitsize, dl, AmtVT);
9808     if (SDValue Amt =
9809             DAG.FoldConstantArithmetic(ISD::UREM, dl, AmtVT, {N1, Bits}))
9810       return DAG.getNode(N->getOpcode(), dl, VT, N0, Amt);
9811   }
9812 
9813   // rot i16 X, 8 --> bswap X
9814   auto *RotAmtC = isConstOrConstSplat(N1);
9815   if (RotAmtC && RotAmtC->getAPIntValue() == 8 &&
9816       VT.getScalarSizeInBits() == 16 && hasOperation(ISD::BSWAP, VT))
9817     return DAG.getNode(ISD::BSWAP, dl, VT, N0);
9818 
9819   // Simplify the operands using demanded-bits information.
9820   if (SimplifyDemandedBits(SDValue(N, 0)))
9821     return SDValue(N, 0);
9822 
9823   // fold (rot* x, (trunc (and y, c))) -> (rot* x, (and (trunc y), (trunc c))).
9824   if (N1.getOpcode() == ISD::TRUNCATE &&
9825       N1.getOperand(0).getOpcode() == ISD::AND) {
9826     if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
9827       return DAG.getNode(N->getOpcode(), dl, VT, N0, NewOp1);
9828   }
9829 
9830   unsigned NextOp = N0.getOpcode();
9831 
9832   // fold (rot* (rot* x, c2), c1)
9833   //   -> (rot* x, ((c1 % bitsize) +- (c2 % bitsize) + bitsize) % bitsize)
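  // e.g. for i8: rotl (rotr X, 3), 5 --> rotl X, ((5 - 3) + 8) % 8 == rotl X, 2,
  // while rotl (rotl X, 3), 6 --> rotl X, ((6 + 3) + 8) % 8 == rotl X, 1.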
9834   if (NextOp == ISD::ROTL || NextOp == ISD::ROTR) {
9835     SDNode *C1 = DAG.isConstantIntBuildVectorOrConstantInt(N1);
9836     SDNode *C2 = DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1));
9837     if (C1 && C2 && C1->getValueType(0) == C2->getValueType(0)) {
9838       EVT ShiftVT = C1->getValueType(0);
9839       bool SameSide = (N->getOpcode() == NextOp);
9840       unsigned CombineOp = SameSide ? ISD::ADD : ISD::SUB;
9841       SDValue BitsizeC = DAG.getConstant(Bitsize, dl, ShiftVT);
9842       SDValue Norm1 = DAG.FoldConstantArithmetic(ISD::UREM, dl, ShiftVT,
9843                                                  {N1, BitsizeC});
9844       SDValue Norm2 = DAG.FoldConstantArithmetic(ISD::UREM, dl, ShiftVT,
9845                                                  {N0.getOperand(1), BitsizeC});
9846       if (Norm1 && Norm2)
9847         if (SDValue CombinedShift = DAG.FoldConstantArithmetic(
9848                 CombineOp, dl, ShiftVT, {Norm1, Norm2})) {
9849           CombinedShift = DAG.FoldConstantArithmetic(ISD::ADD, dl, ShiftVT,
9850                                                      {CombinedShift, BitsizeC});
9851           SDValue CombinedShiftNorm = DAG.FoldConstantArithmetic(
9852               ISD::UREM, dl, ShiftVT, {CombinedShift, BitsizeC});
9853           return DAG.getNode(N->getOpcode(), dl, VT, N0->getOperand(0),
9854                              CombinedShiftNorm);
9855         }
9856     }
9857   }
9858   return SDValue();
9859 }
9860 
9861 SDValue DAGCombiner::visitSHL(SDNode *N) {
9862   SDValue N0 = N->getOperand(0);
9863   SDValue N1 = N->getOperand(1);
9864   if (SDValue V = DAG.simplifyShift(N0, N1))
9865     return V;
9866 
9867   SDLoc DL(N);
9868   EVT VT = N0.getValueType();
9869   EVT ShiftVT = N1.getValueType();
9870   unsigned OpSizeInBits = VT.getScalarSizeInBits();
9871 
9872   // fold (shl c1, c2) -> c1<<c2
9873   if (SDValue C = DAG.FoldConstantArithmetic(ISD::SHL, DL, VT, {N0, N1}))
9874     return C;
9875 
9876   // fold vector ops
9877   if (VT.isVector()) {
9878     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
9879       return FoldedVOp;
9880 
9881     BuildVectorSDNode *N1CV = dyn_cast<BuildVectorSDNode>(N1);
9882     // If setcc produces an all-ones true value then:
9883     // (shl (and (setcc) N01CV) N1CV) -> (and (setcc) N01CV<<N1CV)
9884     if (N1CV && N1CV->isConstant()) {
9885       if (N0.getOpcode() == ISD::AND) {
9886         SDValue N00 = N0->getOperand(0);
9887         SDValue N01 = N0->getOperand(1);
9888         BuildVectorSDNode *N01CV = dyn_cast<BuildVectorSDNode>(N01);
9889 
9890         if (N01CV && N01CV->isConstant() && N00.getOpcode() == ISD::SETCC &&
9891             TLI.getBooleanContents(N00.getOperand(0).getValueType()) ==
9892                 TargetLowering::ZeroOrNegativeOneBooleanContent) {
9893           if (SDValue C =
9894                   DAG.FoldConstantArithmetic(ISD::SHL, DL, VT, {N01, N1}))
9895             return DAG.getNode(ISD::AND, DL, VT, N00, C);
9896         }
9897       }
9898     }
9899   }
9900 
9901   if (SDValue NewSel = foldBinOpIntoSelect(N))
9902     return NewSel;
9903 
9904   // if (shl x, c) is known to be zero, return 0
9905   if (DAG.MaskedValueIsZero(SDValue(N, 0), APInt::getAllOnes(OpSizeInBits)))
9906     return DAG.getConstant(0, DL, VT);
9907 
9908   // fold (shl x, (trunc (and y, c))) -> (shl x, (and (trunc y), (trunc c))).
9909   if (N1.getOpcode() == ISD::TRUNCATE &&
9910       N1.getOperand(0).getOpcode() == ISD::AND) {
9911     if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
9912       return DAG.getNode(ISD::SHL, DL, VT, N0, NewOp1);
9913   }
9914 
9915   // fold (shl (shl x, c1), c2) -> 0 or (shl x, (add c1, c2))
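  // e.g. for i32: shl (shl X, 3), 4 --> shl X, 7, whereas
  // shl (shl X, 20), 15 --> 0 because 20 + 15 >= 32 shifts out every bit.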
9916   if (N0.getOpcode() == ISD::SHL) {
9917     auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS,
9918                                           ConstantSDNode *RHS) {
9919       APInt c1 = LHS->getAPIntValue();
9920       APInt c2 = RHS->getAPIntValue();
9921       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
9922       return (c1 + c2).uge(OpSizeInBits);
9923     };
9924     if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange))
9925       return DAG.getConstant(0, DL, VT);
9926 
9927     auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
9928                                        ConstantSDNode *RHS) {
9929       APInt c1 = LHS->getAPIntValue();
9930       APInt c2 = RHS->getAPIntValue();
9931       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
9932       return (c1 + c2).ult(OpSizeInBits);
9933     };
9934     if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) {
9935       SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1));
9936       return DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Sum);
9937     }
9938   }
9939 
9940   // fold (shl (ext (shl x, c1)), c2) -> (shl (ext x), (add c1, c2))
9941   // For this to be valid, the second form must not preserve any of the bits
9942   // that are shifted out by the inner shift in the first form.  This means
9943   // the outer shift size must be >= the number of bits added by the ext.
9944   // As a corollary, we don't care what kind of ext it is.
9945   if ((N0.getOpcode() == ISD::ZERO_EXTEND ||
9946        N0.getOpcode() == ISD::ANY_EXTEND ||
9947        N0.getOpcode() == ISD::SIGN_EXTEND) &&
9948       N0.getOperand(0).getOpcode() == ISD::SHL) {
9949     SDValue N0Op0 = N0.getOperand(0);
9950     SDValue InnerShiftAmt = N0Op0.getOperand(1);
9951     EVT InnerVT = N0Op0.getValueType();
9952     uint64_t InnerBitwidth = InnerVT.getScalarSizeInBits();
9953 
9954     auto MatchOutOfRange = [OpSizeInBits, InnerBitwidth](ConstantSDNode *LHS,
9955                                                          ConstantSDNode *RHS) {
9956       APInt c1 = LHS->getAPIntValue();
9957       APInt c2 = RHS->getAPIntValue();
9958       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
9959       return c2.uge(OpSizeInBits - InnerBitwidth) &&
9960              (c1 + c2).uge(OpSizeInBits);
9961     };
9962     if (ISD::matchBinaryPredicate(InnerShiftAmt, N1, MatchOutOfRange,
9963                                   /*AllowUndefs*/ false,
9964                                   /*AllowTypeMismatch*/ true))
9965       return DAG.getConstant(0, DL, VT);
9966 
9967     auto MatchInRange = [OpSizeInBits, InnerBitwidth](ConstantSDNode *LHS,
9968                                                       ConstantSDNode *RHS) {
9969       APInt c1 = LHS->getAPIntValue();
9970       APInt c2 = RHS->getAPIntValue();
9971       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
9972       return c2.uge(OpSizeInBits - InnerBitwidth) &&
9973              (c1 + c2).ult(OpSizeInBits);
9974     };
9975     if (ISD::matchBinaryPredicate(InnerShiftAmt, N1, MatchInRange,
9976                                   /*AllowUndefs*/ false,
9977                                   /*AllowTypeMismatch*/ true)) {
9978       SDValue Ext = DAG.getNode(N0.getOpcode(), DL, VT, N0Op0.getOperand(0));
9979       SDValue Sum = DAG.getZExtOrTrunc(InnerShiftAmt, DL, ShiftVT);
9980       Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, Sum, N1);
9981       return DAG.getNode(ISD::SHL, DL, VT, Ext, Sum);
9982     }
9983   }
9984 
9985   // fold (shl (zext (srl x, C)), C) -> (zext (shl (srl x, C), C))
9986   // Only fold this if the inner zext has no other uses to avoid increasing
9987   // the total number of instructions.
9988   if (N0.getOpcode() == ISD::ZERO_EXTEND && N0.hasOneUse() &&
9989       N0.getOperand(0).getOpcode() == ISD::SRL) {
9990     SDValue N0Op0 = N0.getOperand(0);
9991     SDValue InnerShiftAmt = N0Op0.getOperand(1);
9992 
9993     auto MatchEqual = [VT](ConstantSDNode *LHS, ConstantSDNode *RHS) {
9994       APInt c1 = LHS->getAPIntValue();
9995       APInt c2 = RHS->getAPIntValue();
9996       zeroExtendToMatch(c1, c2);
9997       return c1.ult(VT.getScalarSizeInBits()) && (c1 == c2);
9998     };
9999     if (ISD::matchBinaryPredicate(InnerShiftAmt, N1, MatchEqual,
10000                                   /*AllowUndefs*/ false,
10001                                   /*AllowTypeMismatch*/ true)) {
10002       EVT InnerShiftAmtVT = N0Op0.getOperand(1).getValueType();
10003       SDValue NewSHL = DAG.getZExtOrTrunc(N1, DL, InnerShiftAmtVT);
10004       NewSHL = DAG.getNode(ISD::SHL, DL, N0Op0.getValueType(), N0Op0, NewSHL);
10005       AddToWorklist(NewSHL.getNode());
10006       return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N0), VT, NewSHL);
10007     }
10008   }
10009 
10010   if (N0.getOpcode() == ISD::SRL || N0.getOpcode() == ISD::SRA) {
10011     auto MatchShiftAmount = [OpSizeInBits](ConstantSDNode *LHS,
10012                                            ConstantSDNode *RHS) {
10013       const APInt &LHSC = LHS->getAPIntValue();
10014       const APInt &RHSC = RHS->getAPIntValue();
10015       return LHSC.ult(OpSizeInBits) && RHSC.ult(OpSizeInBits) &&
10016              LHSC.getZExtValue() <= RHSC.getZExtValue();
10017     };
10018 
10019     // fold (shl (sr[la] exact X,  C1), C2) -> (shl    X, (C2-C1)) if C1 <= C2
10020     // fold (shl (sr[la] exact X,  C1), C2) -> (sr[la] X, (C2-C1)) if C1 >= C2
10021     if (N0->getFlags().hasExact()) {
10022       if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchShiftAmount,
10023                                     /*AllowUndefs*/ false,
10024                                     /*AllowTypeMismatch*/ true)) {
10025         SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
10026         SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N1, N01);
10027         return DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Diff);
10028       }
10029       if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchShiftAmount,
10030                                     /*AllowUndefs*/ false,
10031                                     /*AllowTypeMismatch*/ true)) {
10032         SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
10033         SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N01, N1);
10034         return DAG.getNode(N0.getOpcode(), DL, VT, N0.getOperand(0), Diff);
10035       }
10036     }
10037 
10038     // fold (shl (srl x, c1), c2) -> (and (shl x, (sub c2, c1)), MASK) or
10039     //                               (and (srl x, (sub c1, c2)), MASK)
10040     // Only fold this if the inner shift has no other uses -- if it does,
10041     // folding this will increase the total number of instructions.
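    // e.g. for i8: shl (srl X, 2), 5 --> and (shl X, 3), 0xE0.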
10042     if (N0.getOpcode() == ISD::SRL &&
10043         (N0.getOperand(1) == N1 || N0.hasOneUse()) &&
10044         TLI.shouldFoldConstantShiftPairToMask(N, Level)) {
10045       if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchShiftAmount,
10046                                     /*AllowUndefs*/ false,
10047                                     /*AllowTypeMismatch*/ true)) {
10048         SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
10049         SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N01, N1);
10050         SDValue Mask = DAG.getAllOnesConstant(DL, VT);
10051         Mask = DAG.getNode(ISD::SHL, DL, VT, Mask, N01);
10052         Mask = DAG.getNode(ISD::SRL, DL, VT, Mask, Diff);
10053         SDValue Shift = DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), Diff);
10054         return DAG.getNode(ISD::AND, DL, VT, Shift, Mask);
10055       }
10056       if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchShiftAmount,
10057                                     /*AllowUndefs*/ false,
10058                                     /*AllowTypeMismatch*/ true)) {
10059         SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
10060         SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N1, N01);
10061         SDValue Mask = DAG.getAllOnesConstant(DL, VT);
10062         Mask = DAG.getNode(ISD::SHL, DL, VT, Mask, N1);
10063         SDValue Shift = DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Diff);
10064         return DAG.getNode(ISD::AND, DL, VT, Shift, Mask);
10065       }
10066     }
10067   }
10068 
10069   // fold (shl (sra x, c1), c1) -> (and x, (shl -1, c1))
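  // e.g. for i8: shl (sra X, 3), 3 --> and X, 0xF8; the round trip merely
  // clears the low three bits.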
10070   if (N0.getOpcode() == ISD::SRA && N1 == N0.getOperand(1) &&
10071       isConstantOrConstantVector(N1, /* No Opaques */ true)) {
10072     SDValue AllBits = DAG.getAllOnesConstant(DL, VT);
10073     SDValue HiBitsMask = DAG.getNode(ISD::SHL, DL, VT, AllBits, N1);
10074     return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), HiBitsMask);
10075   }
10076 
10077   // fold (shl (add x, c1), c2) -> (add (shl x, c2), c1 << c2)
10078   // fold (shl (or x, c1), c2) -> (or (shl x, c2), c1 << c2)
10079   // A variant of the fold done on multiply, except that a mul by a power of 2
10080   // is turned into a shift.
10081   if ((N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::OR) &&
10082       N0->hasOneUse() && TLI.isDesirableToCommuteWithShift(N, Level)) {
10083     SDValue N01 = N0.getOperand(1);
10084     if (SDValue Shl1 =
10085             DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N1), VT, {N01, N1})) {
10086       SDValue Shl0 = DAG.getNode(ISD::SHL, SDLoc(N0), VT, N0.getOperand(0), N1);
10087       AddToWorklist(Shl0.getNode());
10088       SDNodeFlags Flags;
10089       // Preserve the disjoint flag for Or.
10090       if (N0.getOpcode() == ISD::OR && N0->getFlags().hasDisjoint())
10091         Flags.setDisjoint(true);
10092       return DAG.getNode(N0.getOpcode(), DL, VT, Shl0, Shl1, Flags);
10093     }
10094   }
10095 
10096   // fold (shl (sext (add_nsw x, c1)), c2) -> (add (shl (sext x), c2), c1 << c2)
10097   // TODO: Add zext/add_nuw variant with suitable test coverage
10098   // TODO: Should we limit this with isLegalAddImmediate?
10099   if (N0.getOpcode() == ISD::SIGN_EXTEND &&
10100       N0.getOperand(0).getOpcode() == ISD::ADD &&
10101       N0.getOperand(0)->getFlags().hasNoSignedWrap() && N0->hasOneUse() &&
10102       N0.getOperand(0)->hasOneUse() &&
10103       TLI.isDesirableToCommuteWithShift(N, Level)) {
10104     SDValue Add = N0.getOperand(0);
10105     SDLoc DL(N0);
10106     if (SDValue ExtC = DAG.FoldConstantArithmetic(N0.getOpcode(), DL, VT,
10107                                                   {Add.getOperand(1)})) {
10108       if (SDValue ShlC =
10109               DAG.FoldConstantArithmetic(ISD::SHL, DL, VT, {ExtC, N1})) {
10110         SDValue ExtX = DAG.getNode(N0.getOpcode(), DL, VT, Add.getOperand(0));
10111         SDValue ShlX = DAG.getNode(ISD::SHL, DL, VT, ExtX, N1);
10112         return DAG.getNode(ISD::ADD, DL, VT, ShlX, ShlC);
10113       }
10114     }
10115   }
10116 
10117   // fold (shl (mul x, c1), c2) -> (mul x, c1 << c2)
10118   if (N0.getOpcode() == ISD::MUL && N0->hasOneUse()) {
10119     SDValue N01 = N0.getOperand(1);
10120     if (SDValue Shl =
10121             DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N1), VT, {N01, N1}))
10122       return DAG.getNode(ISD::MUL, DL, VT, N0.getOperand(0), Shl);
10123   }
10124 
10125   ConstantSDNode *N1C = isConstOrConstSplat(N1);
10126   if (N1C && !N1C->isOpaque())
10127     if (SDValue NewSHL = visitShiftByConstant(N))
10128       return NewSHL;
10129 
10130   // fold (shl X, cttz(Y)) -> (mul (Y & -Y), X) if cttz is unsupported on the
10131   // target.
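  // Y & -Y isolates the lowest set bit of Y, i.e. 1 << cttz(Y), so the shift
  // is the same as multiplying X by that power of two (e.g. Y == 12 gives
  // Y & -Y == 4 == 1 << cttz(12)).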
10132   if (((N1.getOpcode() == ISD::CTTZ &&
10133         VT.getScalarSizeInBits() <= ShiftVT.getScalarSizeInBits()) ||
10134        N1.getOpcode() == ISD::CTTZ_ZERO_UNDEF) &&
10135       N1.hasOneUse() && !TLI.isOperationLegalOrCustom(ISD::CTTZ, ShiftVT) &&
10136       TLI.isOperationLegalOrCustom(ISD::MUL, VT)) {
10137     SDValue Y = N1.getOperand(0);
10138     SDLoc DL(N);
10139     SDValue NegY = DAG.getNegative(Y, DL, ShiftVT);
10140     SDValue And =
10141         DAG.getZExtOrTrunc(DAG.getNode(ISD::AND, DL, ShiftVT, Y, NegY), DL, VT);
10142     return DAG.getNode(ISD::MUL, DL, VT, And, N0);
10143   }
10144 
10145   if (SimplifyDemandedBits(SDValue(N, 0)))
10146     return SDValue(N, 0);
10147 
10148   // Fold (shl (vscale * C0), C1) to (vscale * (C0 << C1)).
10149   if (N0.getOpcode() == ISD::VSCALE && N1C) {
10150     const APInt &C0 = N0.getConstantOperandAPInt(0);
10151     const APInt &C1 = N1C->getAPIntValue();
10152     return DAG.getVScale(DL, VT, C0 << C1);
10153   }
10154 
10155   // Fold (shl step_vector(C0), C1) to (step_vector(C0 << C1)).
10156   APInt ShlVal;
10157   if (N0.getOpcode() == ISD::STEP_VECTOR &&
10158       ISD::isConstantSplatVector(N1.getNode(), ShlVal)) {
10159     const APInt &C0 = N0.getConstantOperandAPInt(0);
10160     if (ShlVal.ult(C0.getBitWidth())) {
10161       APInt NewStep = C0 << ShlVal;
10162       return DAG.getStepVector(DL, VT, NewStep);
10163     }
10164   }
10165 
10166   return SDValue();
10167 }
10168 
10169 // Transform a right shift of a multiply into a multiply-high.
10170 // Examples:
10171 // (srl (mul (zext i32:$a to i64), (zext i32:$b to i64)), 32) -> (mulhu $a, $b)
10172 // (sra (mul (sext i32:$a to i64), (sext i32:$b to i64)), 32) -> (mulhs $a, $b)
10173 static SDValue combineShiftToMULH(SDNode *N, const SDLoc &DL, SelectionDAG &DAG,
10174                                   const TargetLowering &TLI) {
10175   assert((N->getOpcode() == ISD::SRL || N->getOpcode() == ISD::SRA) &&
10176          "SRL or SRA node is required here!");
10177 
10178   // Check the shift amount. Proceed with the transformation if the shift
10179   // amount is constant.
10180   ConstantSDNode *ShiftAmtSrc = isConstOrConstSplat(N->getOperand(1));
10181   if (!ShiftAmtSrc)
10182     return SDValue();
10183 
10184   // The operation feeding into the shift must be a multiply.
10185   SDValue ShiftOperand = N->getOperand(0);
10186   if (ShiftOperand.getOpcode() != ISD::MUL)
10187     return SDValue();
10188 
10189   // Both operands must be equivalent extend nodes.
10190   SDValue LeftOp = ShiftOperand.getOperand(0);
10191   SDValue RightOp = ShiftOperand.getOperand(1);
10192 
10193   bool IsSignExt = LeftOp.getOpcode() == ISD::SIGN_EXTEND;
10194   bool IsZeroExt = LeftOp.getOpcode() == ISD::ZERO_EXTEND;
10195 
10196   if (!IsSignExt && !IsZeroExt)
10197     return SDValue();
10198 
10199   EVT NarrowVT = LeftOp.getOperand(0).getValueType();
10200   unsigned NarrowVTSize = NarrowVT.getScalarSizeInBits();
10201 
10202   // return true if U may use the lower bits of its operands
10203   auto UserOfLowerBits = [NarrowVTSize](SDNode *U) {
10204     if (U->getOpcode() != ISD::SRL && U->getOpcode() != ISD::SRA) {
10205       return true;
10206     }
10207     ConstantSDNode *UShiftAmtSrc = isConstOrConstSplat(U->getOperand(1));
10208     if (!UShiftAmtSrc) {
10209       return true;
10210     }
10211     unsigned UShiftAmt = UShiftAmtSrc->getZExtValue();
10212     return UShiftAmt < NarrowVTSize;
10213   };
10214 
10215   // If the lower part of the MUL is also used and MUL_LOHI is supported,
10216   // do not introduce the MULH; prefer MUL_LOHI.
10217   unsigned MulLoHiOp = IsSignExt ? ISD::SMUL_LOHI : ISD::UMUL_LOHI;
10218   if (!ShiftOperand.hasOneUse() &&
10219       TLI.isOperationLegalOrCustom(MulLoHiOp, NarrowVT) &&
10220       llvm::any_of(ShiftOperand->uses(), UserOfLowerBits)) {
10221     return SDValue();
10222   }
10223 
10224   SDValue MulhRightOp;
10225   if (ConstantSDNode *Constant = isConstOrConstSplat(RightOp)) {
10226     unsigned ActiveBits = IsSignExt
10227                               ? Constant->getAPIntValue().getSignificantBits()
10228                               : Constant->getAPIntValue().getActiveBits();
10229     if (ActiveBits > NarrowVTSize)
10230       return SDValue();
10231     MulhRightOp = DAG.getConstant(
10232         Constant->getAPIntValue().trunc(NarrowVT.getScalarSizeInBits()), DL,
10233         NarrowVT);
10234   } else {
10235     if (LeftOp.getOpcode() != RightOp.getOpcode())
10236       return SDValue();
10237     // Check that the two extend nodes are the same type.
10238     if (NarrowVT != RightOp.getOperand(0).getValueType())
10239       return SDValue();
10240     MulhRightOp = RightOp.getOperand(0);
10241   }
10242 
10243   EVT WideVT = LeftOp.getValueType();
10244   // Proceed with the transformation if the wide types match.
10245   assert((WideVT == RightOp.getValueType()) &&
10246          "Cannot have a multiply node with two different operand types.");
10247 
10248   // Proceed with the transformation if the wide type is twice as large
10249   // as the narrow type.
10250   if (WideVT.getScalarSizeInBits() != 2 * NarrowVTSize)
10251     return SDValue();
10252 
10253   // Check the shift amount with the narrow type size.
10254   // Proceed with the transformation if the shift amount is the width
10255   // of the narrow type.
10256   unsigned ShiftAmt = ShiftAmtSrc->getZExtValue();
10257   if (ShiftAmt != NarrowVTSize)
10258     return SDValue();
10259 
10260   // If the operation feeding into the MUL is a sign extend (sext),
10261   // we use mulhs. Otherwise, zero extends (zext) use mulhu.
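  // Illustrative example (added): with MULHS legal for i32, the combine
  // rewrites (i64 (sra (mul (sext i32 a), (sext i32 b)), 32)) into
  // (i64 (sext (mulhs a, b))), computing only the high half of the product.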
10262   unsigned MulhOpcode = IsSignExt ? ISD::MULHS : ISD::MULHU;
10263 
10264   // Combine to mulh if mulh is legal/custom for the narrow type on the target
10265   // or if it is a vector type then we could transform to an acceptable type and
10266   // rely on legalization to split/combine the result.
10267   if (NarrowVT.isVector()) {
10268     EVT TransformVT = TLI.getTypeToTransformTo(*DAG.getContext(), NarrowVT);
10269     if (TransformVT.getVectorElementType() != NarrowVT.getVectorElementType() ||
10270         !TLI.isOperationLegalOrCustom(MulhOpcode, TransformVT))
10271       return SDValue();
10272   } else {
10273     if (!TLI.isOperationLegalOrCustom(MulhOpcode, NarrowVT))
10274       return SDValue();
10275   }
10276 
10277   SDValue Result =
10278       DAG.getNode(MulhOpcode, DL, NarrowVT, LeftOp.getOperand(0), MulhRightOp);
10279   bool IsSigned = N->getOpcode() == ISD::SRA;
10280   return DAG.getExtOrTrunc(IsSigned, Result, DL, WideVT);
10281 }
10282 
10283 // fold (bswap (logic_op(bswap(x),y))) -> logic_op(x,bswap(y))
10284 // This helper function accepts SDNodes with opcode ISD::BSWAP or ISD::BITREVERSE.
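// Illustrative example (added): bswap (xor (bswap X), Y) --> xor (X, bswap Y),
// which drops one byte swap when the inner bswap has no other uses.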
10285 static SDValue foldBitOrderCrossLogicOp(SDNode *N, SelectionDAG &DAG) {
10286   unsigned Opcode = N->getOpcode();
10287   if (Opcode != ISD::BSWAP && Opcode != ISD::BITREVERSE)
10288     return SDValue();
10289 
10290   SDValue N0 = N->getOperand(0);
10291   EVT VT = N->getValueType(0);
10292   SDLoc DL(N);
10293   if (ISD::isBitwiseLogicOp(N0.getOpcode()) && N0.hasOneUse()) {
10294     SDValue OldLHS = N0.getOperand(0);
10295     SDValue OldRHS = N0.getOperand(1);
10296 
10297     // If both operands are bswap/bitreverse, we can skip the one-use checks.
10298     // Otherwise, ensure the logic_op and bswap/bitreverse(x) each have one use.
10299     if (OldLHS.getOpcode() == Opcode && OldRHS.getOpcode() == Opcode) {
10300       return DAG.getNode(N0.getOpcode(), DL, VT, OldLHS.getOperand(0),
10301                          OldRHS.getOperand(0));
10302     }
10303 
10304     if (OldLHS.getOpcode() == Opcode && OldLHS.hasOneUse()) {
10305       SDValue NewBitReorder = DAG.getNode(Opcode, DL, VT, OldRHS);
10306       return DAG.getNode(N0.getOpcode(), DL, VT, OldLHS.getOperand(0),
10307                          NewBitReorder);
10308     }
10309 
10310     if (OldRHS.getOpcode() == Opcode && OldRHS.hasOneUse()) {
10311       SDValue NewBitReorder = DAG.getNode(Opcode, DL, VT, OldLHS);
10312       return DAG.getNode(N0.getOpcode(), DL, VT, NewBitReorder,
10313                          OldRHS.getOperand(0));
10314     }
10315   }
10316   return SDValue();
10317 }
10318 
10319 SDValue DAGCombiner::visitSRA(SDNode *N) {
10320   SDValue N0 = N->getOperand(0);
10321   SDValue N1 = N->getOperand(1);
10322   if (SDValue V = DAG.simplifyShift(N0, N1))
10323     return V;
10324 
10325   SDLoc DL(N);
10326   EVT VT = N0.getValueType();
10327   unsigned OpSizeInBits = VT.getScalarSizeInBits();
10328 
10329   // fold (sra c1, c2) -> c1 >>s c2
10330   if (SDValue C = DAG.FoldConstantArithmetic(ISD::SRA, DL, VT, {N0, N1}))
10331     return C;
10332 
10333   // Arithmetic shifting an all-sign-bit value is a no-op.
10334   // fold (sra 0, x) -> 0
10335   // fold (sra -1, x) -> -1
10336   if (DAG.ComputeNumSignBits(N0) == OpSizeInBits)
10337     return N0;
10338 
10339   // fold vector ops
10340   if (VT.isVector())
10341     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
10342       return FoldedVOp;
10343 
10344   if (SDValue NewSel = foldBinOpIntoSelect(N))
10345     return NewSel;
10346 
10347   ConstantSDNode *N1C = isConstOrConstSplat(N1);
10348 
10349   // fold (sra (sra x, c1), c2) -> (sra x, (add c1, c2))
10350   // clamp (add c1, c2) to max shift.
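  // Illustrative example (added): for i32, (sra (sra x, 25), 10) becomes
  // (sra x, 31); the sum 35 is clamped to 31, which is safe because sra only
  // replicates the sign bit beyond that point.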
10351   if (N0.getOpcode() == ISD::SRA) {
10352     EVT ShiftVT = N1.getValueType();
10353     EVT ShiftSVT = ShiftVT.getScalarType();
10354     SmallVector<SDValue, 16> ShiftValues;
10355 
10356     auto SumOfShifts = [&](ConstantSDNode *LHS, ConstantSDNode *RHS) {
10357       APInt c1 = LHS->getAPIntValue();
10358       APInt c2 = RHS->getAPIntValue();
10359       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
10360       APInt Sum = c1 + c2;
10361       unsigned ShiftSum =
10362           Sum.uge(OpSizeInBits) ? (OpSizeInBits - 1) : Sum.getZExtValue();
10363       ShiftValues.push_back(DAG.getConstant(ShiftSum, DL, ShiftSVT));
10364       return true;
10365     };
10366     if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), SumOfShifts)) {
10367       SDValue ShiftValue;
10368       if (N1.getOpcode() == ISD::BUILD_VECTOR)
10369         ShiftValue = DAG.getBuildVector(ShiftVT, DL, ShiftValues);
10370       else if (N1.getOpcode() == ISD::SPLAT_VECTOR) {
10371         assert(ShiftValues.size() == 1 &&
10372                "Expected matchBinaryPredicate to return one element for "
10373                "SPLAT_VECTORs");
10374         ShiftValue = DAG.getSplatVector(ShiftVT, DL, ShiftValues[0]);
10375       } else
10376         ShiftValue = ShiftValues[0];
10377       return DAG.getNode(ISD::SRA, DL, VT, N0.getOperand(0), ShiftValue);
10378     }
10379   }
10380 
10381   // fold (sra (shl X, m), (sub result_size, n))
10382   // -> (sign_extend (trunc (srl X, (sub (sub result_size, n), m)))) for
10383   // result_size - n > m.
10384   // If truncate is free for the target, sext(shl) is likely to result in
10385   // better code.
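  // Illustrative example (added): for i32 X, (sra (shl X, 8), 24) extracts
  // bits [23:16] sign-extended and becomes
  // (sext (trunc (srl X, 16) to i8)) when truncation to i8 is free.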
10386   if (N0.getOpcode() == ISD::SHL && N1C) {
10387     // Get the two constants of the shifts, CN0 = m, CN = n.
10388     const ConstantSDNode *N01C = isConstOrConstSplat(N0.getOperand(1));
10389     if (N01C) {
10390       LLVMContext &Ctx = *DAG.getContext();
10391       // Determine what the truncate's result bitsize and type would be.
10392       EVT TruncVT = EVT::getIntegerVT(Ctx, OpSizeInBits - N1C->getZExtValue());
10393 
10394       if (VT.isVector())
10395         TruncVT = EVT::getVectorVT(Ctx, TruncVT, VT.getVectorElementCount());
10396 
10397       // Determine the residual right-shift amount.
10398       int ShiftAmt = N1C->getZExtValue() - N01C->getZExtValue();
10399 
10400       // If the shift is not a no-op (in which case this should be just a sign
10401       // extend already), the truncated-to type is legal, sign_extend is legal
10402       // on that type, and the truncate to that type is both legal and free,
10403       // perform the transform.
10404       if ((ShiftAmt > 0) &&
10405           TLI.isOperationLegalOrCustom(ISD::SIGN_EXTEND, TruncVT) &&
10406           TLI.isOperationLegalOrCustom(ISD::TRUNCATE, VT) &&
10407           TLI.isTruncateFree(VT, TruncVT)) {
10408         SDValue Amt = DAG.getShiftAmountConstant(ShiftAmt, VT, DL);
10409         SDValue Shift = DAG.getNode(ISD::SRL, DL, VT,
10410                                     N0.getOperand(0), Amt);
10411         SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, TruncVT,
10412                                     Shift);
10413         return DAG.getNode(ISD::SIGN_EXTEND, DL,
10414                            N->getValueType(0), Trunc);
10415       }
10416     }
10417   }
10418 
10419   // We convert trunc/ext to opposing shifts in IR, but casts may be cheaper.
10420   //   sra (add (shl X, N1C), AddC), N1C -->
10421   //   sext (add (trunc X to (width - N1C)), AddC')
10422   //   sra (sub AddC, (shl X, N1C)), N1C -->
10423   //   sext (sub AddC1',(trunc X to (width - N1C)))
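  // Illustrative example (added): for i32 with N1C == 16,
  //   sra (add (shl X, 16), AddC), 16
  // becomes sext (add (trunc X to i16), (AddC >> 16) truncated to i16).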
10424   if ((N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::SUB) && N1C &&
10425       N0.hasOneUse()) {
10426     bool IsAdd = N0.getOpcode() == ISD::ADD;
10427     SDValue Shl = N0.getOperand(IsAdd ? 0 : 1);
10428     if (Shl.getOpcode() == ISD::SHL && Shl.getOperand(1) == N1 &&
10429         Shl.hasOneUse()) {
10430       // TODO: AddC does not need to be a splat.
10431       if (ConstantSDNode *AddC =
10432               isConstOrConstSplat(N0.getOperand(IsAdd ? 1 : 0))) {
10433         // Determine what the truncate's type would be and ask the target if
10434         // that is a free operation.
10435         LLVMContext &Ctx = *DAG.getContext();
10436         unsigned ShiftAmt = N1C->getZExtValue();
10437         EVT TruncVT = EVT::getIntegerVT(Ctx, OpSizeInBits - ShiftAmt);
10438         if (VT.isVector())
10439           TruncVT = EVT::getVectorVT(Ctx, TruncVT, VT.getVectorElementCount());
10440 
10441         // TODO: The simple type check probably belongs in the default hook
10442         //       implementation and/or target-specific overrides (because
10443         //       non-simple types likely require masking when legalized), but
10444         //       that restriction may conflict with other transforms.
10445         if (TruncVT.isSimple() && isTypeLegal(TruncVT) &&
10446             TLI.isTruncateFree(VT, TruncVT)) {
10447           SDValue Trunc = DAG.getZExtOrTrunc(Shl.getOperand(0), DL, TruncVT);
10448           SDValue ShiftC =
10449               DAG.getConstant(AddC->getAPIntValue().lshr(ShiftAmt).trunc(
10450                                   TruncVT.getScalarSizeInBits()),
10451                               DL, TruncVT);
10452           SDValue Add;
10453           if (IsAdd)
10454             Add = DAG.getNode(ISD::ADD, DL, TruncVT, Trunc, ShiftC);
10455           else
10456             Add = DAG.getNode(ISD::SUB, DL, TruncVT, ShiftC, Trunc);
10457           return DAG.getSExtOrTrunc(Add, DL, VT);
10458         }
10459       }
10460     }
10461   }
10462 
10463   // fold (sra x, (trunc (and y, c))) -> (sra x, (and (trunc y), (trunc c))).
10464   if (N1.getOpcode() == ISD::TRUNCATE &&
10465       N1.getOperand(0).getOpcode() == ISD::AND) {
10466     if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
10467       return DAG.getNode(ISD::SRA, DL, VT, N0, NewOp1);
10468   }
10469 
10470   // fold (sra (trunc (sra x, c1)), c2) -> (trunc (sra x, c1 + c2))
10471   // fold (sra (trunc (srl x, c1)), c2) -> (trunc (sra x, c1 + c2))
10472   //      if c1 is equal to the number of bits the trunc removes
10473   // TODO - support non-uniform vector shift amounts.
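  // Illustrative example (added): truncating i64 to i32 removes 32 bits, so
  // (sra (trunc (srl x:i64, 32)), 5) --> (trunc (sra x, 37)).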
10474   if (N0.getOpcode() == ISD::TRUNCATE &&
10475       (N0.getOperand(0).getOpcode() == ISD::SRL ||
10476        N0.getOperand(0).getOpcode() == ISD::SRA) &&
10477       N0.getOperand(0).hasOneUse() &&
10478       N0.getOperand(0).getOperand(1).hasOneUse() && N1C) {
10479     SDValue N0Op0 = N0.getOperand(0);
10480     if (ConstantSDNode *LargeShift = isConstOrConstSplat(N0Op0.getOperand(1))) {
10481       EVT LargeVT = N0Op0.getValueType();
10482       unsigned TruncBits = LargeVT.getScalarSizeInBits() - OpSizeInBits;
10483       if (LargeShift->getAPIntValue() == TruncBits) {
10484         EVT LargeShiftVT = getShiftAmountTy(LargeVT);
10485         SDValue Amt = DAG.getZExtOrTrunc(N1, DL, LargeShiftVT);
10486         Amt = DAG.getNode(ISD::ADD, DL, LargeShiftVT, Amt,
10487                           DAG.getConstant(TruncBits, DL, LargeShiftVT));
10488         SDValue SRA =
10489             DAG.getNode(ISD::SRA, DL, LargeVT, N0Op0.getOperand(0), Amt);
10490         return DAG.getNode(ISD::TRUNCATE, DL, VT, SRA);
10491       }
10492     }
10493   }
10494 
10495   // Simplify, based on bits shifted out of the LHS.
10496   if (SimplifyDemandedBits(SDValue(N, 0)))
10497     return SDValue(N, 0);
10498 
10499   // If the sign bit is known to be zero, switch this to a SRL.
10500   if (DAG.SignBitIsZero(N0))
10501     return DAG.getNode(ISD::SRL, DL, VT, N0, N1);
10502 
10503   if (N1C && !N1C->isOpaque())
10504     if (SDValue NewSRA = visitShiftByConstant(N))
10505       return NewSRA;
10506 
10507   // Try to transform this shift into a multiply-high if
10508   // it matches the appropriate pattern detected in combineShiftToMULH.
10509   if (SDValue MULH = combineShiftToMULH(N, DL, DAG, TLI))
10510     return MULH;
10511 
10512   // Attempt to convert a sra of a load into a narrower sign-extending load.
10513   if (SDValue NarrowLoad = reduceLoadWidth(N))
10514     return NarrowLoad;
10515 
10516   return SDValue();
10517 }
10518 
10519 SDValue DAGCombiner::visitSRL(SDNode *N) {
10520   SDValue N0 = N->getOperand(0);
10521   SDValue N1 = N->getOperand(1);
10522   if (SDValue V = DAG.simplifyShift(N0, N1))
10523     return V;
10524 
10525   SDLoc DL(N);
10526   EVT VT = N0.getValueType();
10527   EVT ShiftVT = N1.getValueType();
10528   unsigned OpSizeInBits = VT.getScalarSizeInBits();
10529 
10530   // fold (srl c1, c2) -> c1 >>u c2
10531   if (SDValue C = DAG.FoldConstantArithmetic(ISD::SRL, DL, VT, {N0, N1}))
10532     return C;
10533 
10534   // fold vector ops
10535   if (VT.isVector())
10536     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
10537       return FoldedVOp;
10538 
10539   if (SDValue NewSel = foldBinOpIntoSelect(N))
10540     return NewSel;
10541 
10542   // if (srl x, c) is known to be zero, return 0
10543   ConstantSDNode *N1C = isConstOrConstSplat(N1);
10544   if (N1C &&
10545       DAG.MaskedValueIsZero(SDValue(N, 0), APInt::getAllOnes(OpSizeInBits)))
10546     return DAG.getConstant(0, DL, VT);
10547 
10548   // fold (srl (srl x, c1), c2) -> 0 or (srl x, (add c1, c2))
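  // Illustrative example (added): for i32, (srl (srl x, 10), 5) becomes
  // (srl x, 15), while (srl (srl x, 20), 12) becomes 0 because all 32 bits
  // are shifted out.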
10549   if (N0.getOpcode() == ISD::SRL) {
10550     auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS,
10551                                           ConstantSDNode *RHS) {
10552       APInt c1 = LHS->getAPIntValue();
10553       APInt c2 = RHS->getAPIntValue();
10554       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
10555       return (c1 + c2).uge(OpSizeInBits);
10556     };
10557     if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange))
10558       return DAG.getConstant(0, DL, VT);
10559 
10560     auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
10561                                        ConstantSDNode *RHS) {
10562       APInt c1 = LHS->getAPIntValue();
10563       APInt c2 = RHS->getAPIntValue();
10564       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
10565       return (c1 + c2).ult(OpSizeInBits);
10566     };
10567     if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) {
10568       SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1));
10569       return DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), Sum);
10570     }
10571   }
10572 
10573   if (N1C && N0.getOpcode() == ISD::TRUNCATE &&
10574       N0.getOperand(0).getOpcode() == ISD::SRL) {
10575     SDValue InnerShift = N0.getOperand(0);
10576     // TODO - support non-uniform vector shift amounts.
10577     if (auto *N001C = isConstOrConstSplat(InnerShift.getOperand(1))) {
10578       uint64_t c1 = N001C->getZExtValue();
10579       uint64_t c2 = N1C->getZExtValue();
10580       EVT InnerShiftVT = InnerShift.getValueType();
10581       EVT ShiftAmtVT = InnerShift.getOperand(1).getValueType();
10582       uint64_t InnerShiftSize = InnerShiftVT.getScalarSizeInBits();
10583       // srl (trunc (srl x, c1)), c2 --> 0 or (trunc (srl x, (add c1, c2)))
10584       // This is only valid if the OpSizeInBits + c1 = size of inner shift.
10585       if (c1 + OpSizeInBits == InnerShiftSize) {
10586         if (c1 + c2 >= InnerShiftSize)
10587           return DAG.getConstant(0, DL, VT);
10588         SDValue NewShiftAmt = DAG.getConstant(c1 + c2, DL, ShiftAmtVT);
10589         SDValue NewShift = DAG.getNode(ISD::SRL, DL, InnerShiftVT,
10590                                        InnerShift.getOperand(0), NewShiftAmt);
10591         return DAG.getNode(ISD::TRUNCATE, DL, VT, NewShift);
10592       }
10593       // In the more general case, we can clear the high bits after the shift:
10594       // srl (trunc (srl x, c1)), c2 --> trunc (and (srl x, (c1+c2)), Mask)
10595       if (N0.hasOneUse() && InnerShift.hasOneUse() &&
10596           c1 + c2 < InnerShiftSize) {
10597         SDValue NewShiftAmt = DAG.getConstant(c1 + c2, DL, ShiftAmtVT);
10598         SDValue NewShift = DAG.getNode(ISD::SRL, DL, InnerShiftVT,
10599                                        InnerShift.getOperand(0), NewShiftAmt);
10600         SDValue Mask = DAG.getConstant(APInt::getLowBitsSet(InnerShiftSize,
10601                                                             OpSizeInBits - c2),
10602                                        DL, InnerShiftVT);
10603         SDValue And = DAG.getNode(ISD::AND, DL, InnerShiftVT, NewShift, Mask);
10604         return DAG.getNode(ISD::TRUNCATE, DL, VT, And);
10605       }
10606     }
10607   }
10608 
10609   // fold (srl (shl x, c1), c2) -> (and (shl x, (sub c1, c2)), MASK) or
10610   //                               (and (srl x, (sub c2, c1)), MASK)
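  // Illustrative example (added): for i8, (srl (shl x, 3), 2) becomes
  // (and (shl x, 1), 0x3E).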
10611   if (N0.getOpcode() == ISD::SHL &&
10612       (N0.getOperand(1) == N1 || N0->hasOneUse()) &&
10613       TLI.shouldFoldConstantShiftPairToMask(N, Level)) {
10614     auto MatchShiftAmount = [OpSizeInBits](ConstantSDNode *LHS,
10615                                            ConstantSDNode *RHS) {
10616       const APInt &LHSC = LHS->getAPIntValue();
10617       const APInt &RHSC = RHS->getAPIntValue();
10618       return LHSC.ult(OpSizeInBits) && RHSC.ult(OpSizeInBits) &&
10619              LHSC.getZExtValue() <= RHSC.getZExtValue();
10620     };
10621     if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchShiftAmount,
10622                                   /*AllowUndefs*/ false,
10623                                   /*AllowTypeMismatch*/ true)) {
10624       SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
10625       SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N01, N1);
10626       SDValue Mask = DAG.getAllOnesConstant(DL, VT);
10627       Mask = DAG.getNode(ISD::SRL, DL, VT, Mask, N01);
10628       Mask = DAG.getNode(ISD::SHL, DL, VT, Mask, Diff);
10629       SDValue Shift = DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Diff);
10630       return DAG.getNode(ISD::AND, DL, VT, Shift, Mask);
10631     }
10632     if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchShiftAmount,
10633                                   /*AllowUndefs*/ false,
10634                                   /*AllowTypeMismatch*/ true)) {
10635       SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
10636       SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N1, N01);
10637       SDValue Mask = DAG.getAllOnesConstant(DL, VT);
10638       Mask = DAG.getNode(ISD::SRL, DL, VT, Mask, N1);
10639       SDValue Shift = DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), Diff);
10640       return DAG.getNode(ISD::AND, DL, VT, Shift, Mask);
10641     }
10642   }
10643 
10644   // fold (srl (anyextend x), c) -> (and (anyextend (srl x, c)), mask)
10645   // TODO - support non-uniform vector shift amounts.
10646   if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) {
10647     // Shifting in all undef bits?
10648     EVT SmallVT = N0.getOperand(0).getValueType();
10649     unsigned BitSize = SmallVT.getScalarSizeInBits();
10650     if (N1C->getAPIntValue().uge(BitSize))
10651       return DAG.getUNDEF(VT);
10652 
10653     if (!LegalTypes || TLI.isTypeDesirableForOp(ISD::SRL, SmallVT)) {
10654       uint64_t ShiftAmt = N1C->getZExtValue();
10655       SDLoc DL0(N0);
10656       SDValue SmallShift =
10657           DAG.getNode(ISD::SRL, DL0, SmallVT, N0.getOperand(0),
10658                       DAG.getShiftAmountConstant(ShiftAmt, SmallVT, DL0));
10659       AddToWorklist(SmallShift.getNode());
10660       APInt Mask = APInt::getLowBitsSet(OpSizeInBits, OpSizeInBits - ShiftAmt);
10661       return DAG.getNode(ISD::AND, DL, VT,
10662                          DAG.getNode(ISD::ANY_EXTEND, DL, VT, SmallShift),
10663                          DAG.getConstant(Mask, DL, VT));
10664     }
10665   }
10666 
10667   // fold (srl (sra X, Y), 31) -> (srl X, 31).  This srl only looks at the sign
10668   // bit, which is unmodified by sra.
10669   if (N1C && N1C->getAPIntValue() == (OpSizeInBits - 1)) {
10670     if (N0.getOpcode() == ISD::SRA)
10671       return DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), N1);
10672   }
10673 
10674   // fold (srl (ctlz x), "5") -> x  iff x has one bit set (the low bit), and x has a power
10675   // of two bitwidth. The "5" represents (log2 (bitwidth x)).
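  // Illustrative example (added): on i32, if only bit 4 of x can be nonzero,
  // the CTLZ/SRL pair is rewritten below as (xor (srl x, 4), 1): 0 when bit 4
  // is set, 1 when it is clear.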
10676   if (N1C && N0.getOpcode() == ISD::CTLZ &&
10677       isPowerOf2_32(OpSizeInBits) &&
10678       N1C->getAPIntValue() == Log2_32(OpSizeInBits)) {
10679     KnownBits Known = DAG.computeKnownBits(N0.getOperand(0));
10680 
10681     // If any of the input bits are KnownOne, then the input couldn't be all
10682     // zeros, thus the result of the srl will always be zero.
10683     if (Known.One.getBoolValue()) return DAG.getConstant(0, SDLoc(N0), VT);
10684 
10685     // If all of the bits input to the ctlz node are known to be zero, then
10686     // the result of the ctlz is "32" and the result of the shift is one.
10687     APInt UnknownBits = ~Known.Zero;
10688     if (UnknownBits == 0) return DAG.getConstant(1, SDLoc(N0), VT);
10689 
10690     // Otherwise, check to see if there is exactly one bit input to the ctlz.
10691     if (UnknownBits.isPowerOf2()) {
10692       // Okay, we know that only the single bit specified by UnknownBits
10693       // could be set on input to the CTLZ node. If this bit is set, the SRL
10694       // will return 0, if it is clear, it returns 1. Change the CTLZ/SRL pair
10695       // to an SRL/XOR pair, which is likely to simplify more.
10696       unsigned ShAmt = UnknownBits.countr_zero();
10697       SDValue Op = N0.getOperand(0);
10698 
10699       if (ShAmt) {
10700         SDLoc DL(N0);
10701         Op = DAG.getNode(ISD::SRL, DL, VT, Op,
10702                          DAG.getShiftAmountConstant(ShAmt, VT, DL));
10703         AddToWorklist(Op.getNode());
10704       }
10705       return DAG.getNode(ISD::XOR, DL, VT, Op, DAG.getConstant(1, DL, VT));
10706     }
10707   }
10708 
10709   // fold (srl x, (trunc (and y, c))) -> (srl x, (and (trunc y), (trunc c))).
10710   if (N1.getOpcode() == ISD::TRUNCATE &&
10711       N1.getOperand(0).getOpcode() == ISD::AND) {
10712     if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
10713       return DAG.getNode(ISD::SRL, DL, VT, N0, NewOp1);
10714   }
10715 
10716   // fold operands of srl based on knowledge that the low bits are not
10717   // demanded.
10718   if (SimplifyDemandedBits(SDValue(N, 0)))
10719     return SDValue(N, 0);
10720 
10721   if (N1C && !N1C->isOpaque())
10722     if (SDValue NewSRL = visitShiftByConstant(N))
10723       return NewSRL;
10724 
10725   // Attempt to convert a srl of a load into a narrower zero-extending load.
10726   if (SDValue NarrowLoad = reduceLoadWidth(N))
10727     return NarrowLoad;
10728 
10729   // Here is a common situation. We want to optimize:
10730   //
10731   //   %a = ...
10732   //   %b = and i32 %a, 2
10733   //   %c = srl i32 %b, 1
10734   //   brcond i32 %c ...
10735   //
10736   // into
10737   //
10738   //   %a = ...
10739   //   %b = and %a, 2
10740   //   %c = setcc eq %b, 0
10741   //   brcond %c ...
10742   //
10743   // However, after the source operand of the SRL is optimized into an AND,
10744   // the SRL itself may not be optimized further. Look for it and add the
10745   // BRCOND into the worklist.
10746   //
10747   // This also tends to happen for binary operations when
10748   // SimplifyDemandedBits is involved.
10749   //
10750   // FIXME: This is unnecessary if we process the DAG in topological order,
10751   // which we plan to do. This workaround can be removed once the DAG is
10752   // processed in topological order.
10753   if (N->hasOneUse()) {
10754     SDNode *Use = *N->use_begin();
10755 
10756     // Look past the truncate.
10757     if (Use->getOpcode() == ISD::TRUNCATE && Use->hasOneUse())
10758       Use = *Use->use_begin();
10759 
10760     if (Use->getOpcode() == ISD::BRCOND || Use->getOpcode() == ISD::AND ||
10761         Use->getOpcode() == ISD::OR || Use->getOpcode() == ISD::XOR)
10762       AddToWorklist(Use);
10763   }
10764 
10765   // Try to transform this shift into a multiply-high if
10766   // it matches the appropriate pattern detected in combineShiftToMULH.
10767   if (SDValue MULH = combineShiftToMULH(N, DL, DAG, TLI))
10768     return MULH;
10769 
10770   return SDValue();
10771 }
10772 
10773 SDValue DAGCombiner::visitFunnelShift(SDNode *N) {
10774   EVT VT = N->getValueType(0);
10775   SDValue N0 = N->getOperand(0);
10776   SDValue N1 = N->getOperand(1);
10777   SDValue N2 = N->getOperand(2);
10778   bool IsFSHL = N->getOpcode() == ISD::FSHL;
10779   unsigned BitWidth = VT.getScalarSizeInBits();
10780   SDLoc DL(N);
10781 
10782   // fold (fshl N0, N1, 0) -> N0
10783   // fold (fshr N0, N1, 0) -> N1
10784   if (isPowerOf2_32(BitWidth))
10785     if (DAG.MaskedValueIsZero(
10786             N2, APInt(N2.getScalarValueSizeInBits(), BitWidth - 1)))
10787       return IsFSHL ? N0 : N1;
10788 
10789   auto IsUndefOrZero = [](SDValue V) {
10790     return V.isUndef() || isNullOrNullSplat(V, /*AllowUndefs*/ true);
10791   };
10792 
10793   // TODO - support non-uniform vector shift amounts.
10794   if (ConstantSDNode *Cst = isConstOrConstSplat(N2)) {
10795     EVT ShAmtTy = N2.getValueType();
10796 
10797     // fold (fsh* N0, N1, c) -> (fsh* N0, N1, c % BitWidth)
10798     if (Cst->getAPIntValue().uge(BitWidth)) {
10799       uint64_t RotAmt = Cst->getAPIntValue().urem(BitWidth);
10800       return DAG.getNode(N->getOpcode(), DL, VT, N0, N1,
10801                          DAG.getConstant(RotAmt, DL, ShAmtTy));
10802     }
10803 
10804     unsigned ShAmt = Cst->getZExtValue();
10805     if (ShAmt == 0)
10806       return IsFSHL ? N0 : N1;
10807 
10808     // fold fshl(undef_or_zero, N1, C) -> lshr(N1, BW-C)
10809     // fold fshr(undef_or_zero, N1, C) -> lshr(N1, C)
10810     // fold fshl(N0, undef_or_zero, C) -> shl(N0, C)
10811     // fold fshr(N0, undef_or_zero, C) -> shl(N0, BW-C)
10812     if (IsUndefOrZero(N0))
10813       return DAG.getNode(
10814           ISD::SRL, DL, VT, N1,
10815           DAG.getConstant(IsFSHL ? BitWidth - ShAmt : ShAmt, DL, ShAmtTy));
10816     if (IsUndefOrZero(N1))
10817       return DAG.getNode(
10818           ISD::SHL, DL, VT, N0,
10819           DAG.getConstant(IsFSHL ? ShAmt : BitWidth - ShAmt, DL, ShAmtTy));
10820 
10821     // fold (fshl ld1, ld0, c) -> (ld0[ofs]) iff ld0 and ld1 are consecutive.
10822     // fold (fshr ld1, ld0, c) -> (ld0[ofs]) iff ld0 and ld1 are consecutive.
10823     // TODO - bigendian support once we have test coverage.
10824     // TODO - can we merge this with CombineConsecutiveLoads/MatchLoadCombine?
10825     // TODO - permit LHS EXTLOAD if extensions are shifted out.
10826     if ((BitWidth % 8) == 0 && (ShAmt % 8) == 0 && !VT.isVector() &&
10827         !DAG.getDataLayout().isBigEndian()) {
10828       auto *LHS = dyn_cast<LoadSDNode>(N0);
10829       auto *RHS = dyn_cast<LoadSDNode>(N1);
10830       if (LHS && RHS && LHS->isSimple() && RHS->isSimple() &&
10831           LHS->getAddressSpace() == RHS->getAddressSpace() &&
10832           (LHS->hasOneUse() || RHS->hasOneUse()) && ISD::isNON_EXTLoad(RHS) &&
10833           ISD::isNON_EXTLoad(LHS)) {
10834         if (DAG.areNonVolatileConsecutiveLoads(LHS, RHS, BitWidth / 8, 1)) {
10835           SDLoc DL(RHS);
10836           uint64_t PtrOff =
10837               IsFSHL ? (((BitWidth - ShAmt) % BitWidth) / 8) : (ShAmt / 8);
10838           Align NewAlign = commonAlignment(RHS->getAlign(), PtrOff);
10839           unsigned Fast = 0;
10840           if (TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT,
10841                                      RHS->getAddressSpace(), NewAlign,
10842                                      RHS->getMemOperand()->getFlags(), &Fast) &&
10843               Fast) {
10844             SDValue NewPtr = DAG.getMemBasePlusOffset(
10845                 RHS->getBasePtr(), TypeSize::getFixed(PtrOff), DL);
10846             AddToWorklist(NewPtr.getNode());
10847             SDValue Load = DAG.getLoad(
10848                 VT, DL, RHS->getChain(), NewPtr,
10849                 RHS->getPointerInfo().getWithOffset(PtrOff), NewAlign,
10850                 RHS->getMemOperand()->getFlags(), RHS->getAAInfo());
10851             // Replace the old load's chain with the new load's chain.
10852             WorklistRemover DeadNodes(*this);
10853             DAG.ReplaceAllUsesOfValueWith(N1.getValue(1), Load.getValue(1));
10854             return Load;
10855           }
10856         }
10857       }
10858     }
10859   }
10860 
10861   // fold fshr(undef_or_zero, N1, N2) -> lshr(N1, N2)
10862   // fold fshl(N0, undef_or_zero, N2) -> shl(N0, N2)
10863   // iff we know the shift amount is in range.
10864   // TODO: when is it worth doing SUB(BW, N2) as well?
10865   if (isPowerOf2_32(BitWidth)) {
10866     APInt ModuloBits(N2.getScalarValueSizeInBits(), BitWidth - 1);
10867     if (IsUndefOrZero(N0) && !IsFSHL && DAG.MaskedValueIsZero(N2, ~ModuloBits))
10868       return DAG.getNode(ISD::SRL, DL, VT, N1, N2);
10869     if (IsUndefOrZero(N1) && IsFSHL && DAG.MaskedValueIsZero(N2, ~ModuloBits))
10870       return DAG.getNode(ISD::SHL, DL, VT, N0, N2);
10871   }
10872 
10873   // fold (fshl N0, N0, N2) -> (rotl N0, N2)
10874   // fold (fshr N0, N0, N2) -> (rotr N0, N2)
10875   // TODO: Investigate flipping this rotate if only one is legal.
10876   // If funnel shift is legal as well, we might be better off avoiding
10877   // non-constant (BW - N2).
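  // Illustrative example (added): (fshl x, x, 8) concatenates x with itself
  // and takes the top bits, which is exactly (rotl x, 8).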
10878   unsigned RotOpc = IsFSHL ? ISD::ROTL : ISD::ROTR;
10879   if (N0 == N1 && hasOperation(RotOpc, VT))
10880     return DAG.getNode(RotOpc, DL, VT, N0, N2);
10881 
10882   // Simplify, based on bits shifted out of N0/N1.
10883   if (SimplifyDemandedBits(SDValue(N, 0)))
10884     return SDValue(N, 0);
10885 
10886   return SDValue();
10887 }
10888 
10889 SDValue DAGCombiner::visitSHLSAT(SDNode *N) {
10890   SDValue N0 = N->getOperand(0);
10891   SDValue N1 = N->getOperand(1);
10892   if (SDValue V = DAG.simplifyShift(N0, N1))
10893     return V;
10894 
10895   SDLoc DL(N);
10896   EVT VT = N0.getValueType();
10897 
10898   // fold (*shlsat c1, c2) -> c1<<c2
10899   if (SDValue C = DAG.FoldConstantArithmetic(N->getOpcode(), DL, VT, {N0, N1}))
10900     return C;
10901 
10902   ConstantSDNode *N1C = isConstOrConstSplat(N1);
10903 
10904   if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::SHL, VT)) {
10905     // fold (sshlsat x, c) -> (shl x, c)
10906     if (N->getOpcode() == ISD::SSHLSAT && N1C &&
10907         N1C->getAPIntValue().ult(DAG.ComputeNumSignBits(N0)))
10908       return DAG.getNode(ISD::SHL, DL, VT, N0, N1);
10909 
10910     // fold (ushlsat x, c) -> (shl x, c)
10911     if (N->getOpcode() == ISD::USHLSAT && N1C &&
10912         N1C->getAPIntValue().ule(
10913             DAG.computeKnownBits(N0).countMinLeadingZeros()))
10914       return DAG.getNode(ISD::SHL, DL, VT, N0, N1);
10915   }
10916 
10917   return SDValue();
10918 }
10919 
10920 // Given an ABS node, detect the following patterns:
10921 // (ABS (SUB (EXTEND a), (EXTEND b))).
10922 // (TRUNC (ABS (SUB (EXTEND a), (EXTEND b)))).
10923 // Generates a UABD/SABD instruction.
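// Illustrative example (added): abs (sub (sext i8 a to i32), (sext i8 b to i32))
// becomes (zext (abds a, b) to i32) when ABDS is available for i8.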
10924 SDValue DAGCombiner::foldABSToABD(SDNode *N, const SDLoc &DL) {
10925   EVT SrcVT = N->getValueType(0);
10926 
10927   if (N->getOpcode() == ISD::TRUNCATE)
10928     N = N->getOperand(0).getNode();
10929 
10930   if (N->getOpcode() != ISD::ABS)
10931     return SDValue();
10932 
10933   EVT VT = N->getValueType(0);
10934   SDValue AbsOp1 = N->getOperand(0);
10935   SDValue Op0, Op1;
10936 
10937   if (AbsOp1.getOpcode() != ISD::SUB)
10938     return SDValue();
10939 
10940   Op0 = AbsOp1.getOperand(0);
10941   Op1 = AbsOp1.getOperand(1);
10942 
10943   unsigned Opc0 = Op0.getOpcode();
10944 
10945   // Check if the operands of the sub are (zero|sign)-extended.
10946   // TODO: Should we use ValueTracking instead?
10947   if (Opc0 != Op1.getOpcode() ||
10948       (Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND &&
10949        Opc0 != ISD::SIGN_EXTEND_INREG)) {
10950     // fold (abs (sub nsw x, y)) -> abds(x, y)
10951     if (AbsOp1->getFlags().hasNoSignedWrap() && hasOperation(ISD::ABDS, VT) &&
10952         TLI.preferABDSToABSWithNSW(VT)) {
10953       SDValue ABD = DAG.getNode(ISD::ABDS, DL, VT, Op0, Op1);
10954       return DAG.getZExtOrTrunc(ABD, DL, SrcVT);
10955     }
10956     return SDValue();
10957   }
10958 
10959   EVT VT0, VT1;
10960   if (Opc0 == ISD::SIGN_EXTEND_INREG) {
10961     VT0 = cast<VTSDNode>(Op0.getOperand(1))->getVT();
10962     VT1 = cast<VTSDNode>(Op1.getOperand(1))->getVT();
10963   } else {
10964     VT0 = Op0.getOperand(0).getValueType();
10965     VT1 = Op1.getOperand(0).getValueType();
10966   }
10967   unsigned ABDOpcode = (Opc0 == ISD::ZERO_EXTEND) ? ISD::ABDU : ISD::ABDS;
10968 
10969   // fold abs(sext(x) - sext(y)) -> zext(abds(x, y))
10970   // fold abs(zext(x) - zext(y)) -> zext(abdu(x, y))
10971   EVT MaxVT = VT0.bitsGT(VT1) ? VT0 : VT1;
10972   if ((VT0 == MaxVT || Op0->hasOneUse()) &&
10973       (VT1 == MaxVT || Op1->hasOneUse()) && hasOperation(ABDOpcode, MaxVT)) {
10974     SDValue ABD = DAG.getNode(ABDOpcode, DL, MaxVT,
10975                               DAG.getNode(ISD::TRUNCATE, DL, MaxVT, Op0),
10976                               DAG.getNode(ISD::TRUNCATE, DL, MaxVT, Op1));
10977     ABD = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, ABD);
10978     return DAG.getZExtOrTrunc(ABD, DL, SrcVT);
10979   }
10980 
10981   // fold abs(sext(x) - sext(y)) -> abds(sext(x), sext(y))
10982   // fold abs(zext(x) - zext(y)) -> abdu(zext(x), zext(y))
10983   if (hasOperation(ABDOpcode, VT)) {
10984     SDValue ABD = DAG.getNode(ABDOpcode, DL, VT, Op0, Op1);
10985     return DAG.getZExtOrTrunc(ABD, DL, SrcVT);
10986   }
10987 
10988   return SDValue();
10989 }
10990 
10991 SDValue DAGCombiner::visitABS(SDNode *N) {
10992   SDValue N0 = N->getOperand(0);
10993   EVT VT = N->getValueType(0);
10994   SDLoc DL(N);
10995 
10996   // fold (abs c1) -> c2
10997   if (SDValue C = DAG.FoldConstantArithmetic(ISD::ABS, DL, VT, {N0}))
10998     return C;
10999   // fold (abs (abs x)) -> (abs x)
11000   if (N0.getOpcode() == ISD::ABS)
11001     return N0;
11002   // fold (abs x) -> x iff not-negative
11003   if (DAG.SignBitIsZero(N0))
11004     return N0;
11005 
11006   if (SDValue ABD = foldABSToABD(N, DL))
11007     return ABD;
11008 
11009   // fold (abs (sign_extend_inreg x)) -> (zero_extend (abs (truncate x)))
11010   // iff zero_extend/truncate are free.
11011   if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG) {
11012     EVT ExtVT = cast<VTSDNode>(N0.getOperand(1))->getVT();
11013     if (TLI.isTruncateFree(VT, ExtVT) && TLI.isZExtFree(ExtVT, VT) &&
11014         TLI.isTypeDesirableForOp(ISD::ABS, ExtVT) &&
11015         hasOperation(ISD::ABS, ExtVT)) {
11016       return DAG.getNode(
11017           ISD::ZERO_EXTEND, DL, VT,
11018           DAG.getNode(ISD::ABS, DL, ExtVT,
11019                       DAG.getNode(ISD::TRUNCATE, DL, ExtVT, N0.getOperand(0))));
11020     }
11021   }
11022 
11023   return SDValue();
11024 }
11025 
11026 SDValue DAGCombiner::visitBSWAP(SDNode *N) {
11027   SDValue N0 = N->getOperand(0);
11028   EVT VT = N->getValueType(0);
11029   SDLoc DL(N);
11030 
11031   // fold (bswap c1) -> c2
11032   if (SDValue C = DAG.FoldConstantArithmetic(ISD::BSWAP, DL, VT, {N0}))
11033     return C;
11034   // fold (bswap (bswap x)) -> x
11035   if (N0.getOpcode() == ISD::BSWAP)
11036     return N0.getOperand(0);
11037 
11038   // Canonicalize bswap(bitreverse(x)) -> bitreverse(bswap(x)). If bitreverse
11039   // isn't supported, it will be expanded to bswap followed by a manual reversal
11040   // of bits in each byte. By placing bswaps before bitreverse, we can remove
11041   // the two bswaps if the bitreverse gets expanded.
11042   if (N0.getOpcode() == ISD::BITREVERSE && N0.hasOneUse()) {
11043     SDValue BSwap = DAG.getNode(ISD::BSWAP, DL, VT, N0.getOperand(0));
11044     return DAG.getNode(ISD::BITREVERSE, DL, VT, BSwap);
11045   }
11046 
11047   // fold (bswap shl(x,c)) -> (zext(bswap(trunc(shl(x,sub(c,bw/2))))))
11048   // iff c >= bw/2 (i.e. lower half is known zero)
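  // Illustrative example (added): for i64, bswap (shl x, 48) becomes
  // (zext (bswap (trunc (shl x, 16) to i32)) to i64).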
11049   unsigned BW = VT.getScalarSizeInBits();
11050   if (BW >= 32 && N0.getOpcode() == ISD::SHL && N0.hasOneUse()) {
11051     auto *ShAmt = dyn_cast<ConstantSDNode>(N0.getOperand(1));
11052     EVT HalfVT = EVT::getIntegerVT(*DAG.getContext(), BW / 2);
11053     if (ShAmt && ShAmt->getAPIntValue().ult(BW) &&
11054         ShAmt->getZExtValue() >= (BW / 2) &&
11055         (ShAmt->getZExtValue() % 16) == 0 && TLI.isTypeLegal(HalfVT) &&
11056         TLI.isTruncateFree(VT, HalfVT) &&
11057         (!LegalOperations || hasOperation(ISD::BSWAP, HalfVT))) {
11058       SDValue Res = N0.getOperand(0);
11059       if (uint64_t NewShAmt = (ShAmt->getZExtValue() - (BW / 2)))
11060         Res = DAG.getNode(ISD::SHL, DL, VT, Res,
11061                           DAG.getShiftAmountConstant(NewShAmt, VT, DL));
11062       Res = DAG.getZExtOrTrunc(Res, DL, HalfVT);
11063       Res = DAG.getNode(ISD::BSWAP, DL, HalfVT, Res);
11064       return DAG.getZExtOrTrunc(Res, DL, VT);
11065     }
11066   }
11067 
11068   // Try to canonicalize bswap-of-logical-shift-by-8-bit-multiple as
11069   // inverse-shift-of-bswap:
11070   // bswap (X u<< C) --> (bswap X) u>> C
11071   // bswap (X u>> C) --> (bswap X) u<< C
11072   if ((N0.getOpcode() == ISD::SHL || N0.getOpcode() == ISD::SRL) &&
11073       N0.hasOneUse()) {
11074     auto *ShAmt = dyn_cast<ConstantSDNode>(N0.getOperand(1));
11075     if (ShAmt && ShAmt->getAPIntValue().ult(BW) &&
11076         ShAmt->getZExtValue() % 8 == 0) {
11077       SDValue NewSwap = DAG.getNode(ISD::BSWAP, DL, VT, N0.getOperand(0));
11078       unsigned InverseShift = N0.getOpcode() == ISD::SHL ? ISD::SRL : ISD::SHL;
11079       return DAG.getNode(InverseShift, DL, VT, NewSwap, N0.getOperand(1));
11080     }
11081   }
11082 
11083   if (SDValue V = foldBitOrderCrossLogicOp(N, DAG))
11084     return V;
11085 
11086   return SDValue();
11087 }
11088 
11089 SDValue DAGCombiner::visitBITREVERSE(SDNode *N) {
11090   SDValue N0 = N->getOperand(0);
11091   EVT VT = N->getValueType(0);
11092   SDLoc DL(N);
11093 
11094   // fold (bitreverse c1) -> c2
11095   if (SDValue C = DAG.FoldConstantArithmetic(ISD::BITREVERSE, DL, VT, {N0}))
11096     return C;
11097 
11098   // fold (bitreverse (bitreverse x)) -> x
11099   if (N0.getOpcode() == ISD::BITREVERSE)
11100     return N0.getOperand(0);
11101 
11102   SDValue X, Y;
11103 
11104   // fold (bitreverse (lshr (bitreverse x), y)) -> (shl x, y)
11105   if ((!LegalOperations || TLI.isOperationLegal(ISD::SHL, VT)) &&
11106       sd_match(N, m_BitReverse(m_Srl(m_BitReverse(m_Value(X)), m_Value(Y)))))
11107     return DAG.getNode(ISD::SHL, DL, VT, X, Y);
11108 
11109   // fold (bitreverse (shl (bitreverse x), y)) -> (lshr x, y)
11110   if ((!LegalOperations || TLI.isOperationLegal(ISD::SRL, VT)) &&
11111       sd_match(N, m_BitReverse(m_Shl(m_BitReverse(m_Value(X)), m_Value(Y)))))
11112     return DAG.getNode(ISD::SRL, DL, VT, X, Y);
11113 
11114   return SDValue();
11115 }
11116 
11117 SDValue DAGCombiner::visitCTLZ(SDNode *N) {
11118   SDValue N0 = N->getOperand(0);
11119   EVT VT = N->getValueType(0);
11120   SDLoc DL(N);
11121 
11122   // fold (ctlz c1) -> c2
11123   if (SDValue C = DAG.FoldConstantArithmetic(ISD::CTLZ, DL, VT, {N0}))
11124     return C;
11125 
11126   // If the value is known never to be zero, switch to the undef version.
11127   if (!LegalOperations || TLI.isOperationLegal(ISD::CTLZ_ZERO_UNDEF, VT))
11128     if (DAG.isKnownNeverZero(N0))
11129       return DAG.getNode(ISD::CTLZ_ZERO_UNDEF, DL, VT, N0);
11130 
11131   return SDValue();
11132 }
11133 
11134 SDValue DAGCombiner::visitCTLZ_ZERO_UNDEF(SDNode *N) {
11135   SDValue N0 = N->getOperand(0);
11136   EVT VT = N->getValueType(0);
11137   SDLoc DL(N);
11138 
11139   // fold (ctlz_zero_undef c1) -> c2
11140   if (SDValue C =
11141           DAG.FoldConstantArithmetic(ISD::CTLZ_ZERO_UNDEF, DL, VT, {N0}))
11142     return C;
11143   return SDValue();
11144 }
11145 
11146 SDValue DAGCombiner::visitCTTZ(SDNode *N) {
11147   SDValue N0 = N->getOperand(0);
11148   EVT VT = N->getValueType(0);
11149   SDLoc DL(N);
11150 
11151   // fold (cttz c1) -> c2
11152   if (SDValue C = DAG.FoldConstantArithmetic(ISD::CTTZ, DL, VT, {N0}))
11153     return C;
11154 
11155   // If the value is known never to be zero, switch to the undef version.
11156   if (!LegalOperations || TLI.isOperationLegal(ISD::CTTZ_ZERO_UNDEF, VT))
11157     if (DAG.isKnownNeverZero(N0))
11158       return DAG.getNode(ISD::CTTZ_ZERO_UNDEF, DL, VT, N0);
11159 
11160   return SDValue();
11161 }
11162 
11163 SDValue DAGCombiner::visitCTTZ_ZERO_UNDEF(SDNode *N) {
11164   SDValue N0 = N->getOperand(0);
11165   EVT VT = N->getValueType(0);
11166   SDLoc DL(N);
11167 
11168   // fold (cttz_zero_undef c1) -> c2
11169   if (SDValue C =
11170           DAG.FoldConstantArithmetic(ISD::CTTZ_ZERO_UNDEF, DL, VT, {N0}))
11171     return C;
11172   return SDValue();
11173 }
11174 
11175 SDValue DAGCombiner::visitCTPOP(SDNode *N) {
11176   SDValue N0 = N->getOperand(0);
11177   EVT VT = N->getValueType(0);
11178   unsigned NumBits = VT.getScalarSizeInBits();
11179   SDLoc DL(N);
11180 
11181   // fold (ctpop c1) -> c2
11182   if (SDValue C = DAG.FoldConstantArithmetic(ISD::CTPOP, DL, VT, {N0}))
11183     return C;
11184 
11185   // If the source is being shifted, but the shift doesn't affect any active
11186   // bits, then we can call CTPOP on the shift source directly.
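  // Illustrative example (added): (ctpop (srl x, 8)) --> (ctpop x) when the
  // low 8 bits of x are known zero, since the shift discards only zero bits.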
11187   if (N0.getOpcode() == ISD::SRL || N0.getOpcode() == ISD::SHL) {
11188     if (ConstantSDNode *AmtC = isConstOrConstSplat(N0.getOperand(1))) {
11189       const APInt &Amt = AmtC->getAPIntValue();
11190       if (Amt.ult(NumBits)) {
11191         KnownBits KnownSrc = DAG.computeKnownBits(N0.getOperand(0));
11192         if ((N0.getOpcode() == ISD::SRL &&
11193              Amt.ule(KnownSrc.countMinTrailingZeros())) ||
11194             (N0.getOpcode() == ISD::SHL &&
11195              Amt.ule(KnownSrc.countMinLeadingZeros()))) {
11196           return DAG.getNode(ISD::CTPOP, DL, VT, N0.getOperand(0));
11197         }
11198       }
11199     }
11200   }
11201 
11202   // If the upper bits are known to be zero, then see if it's profitable to
11203   // only count the lower bits.
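  // Illustrative example (added): an i64 ctpop whose upper 32 bits are known
  // zero becomes (zext (ctpop (trunc x to i32)) to i64) when i32 CTPOP is
  // cheaper for the target.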
11204   if (VT.isScalarInteger() && NumBits > 8 && (NumBits & 1) == 0) {
11205     EVT HalfVT = EVT::getIntegerVT(*DAG.getContext(), NumBits / 2);
11206     if (hasOperation(ISD::CTPOP, HalfVT) &&
11207         TLI.isTypeDesirableForOp(ISD::CTPOP, HalfVT) &&
11208         TLI.isTruncateFree(N0, HalfVT) && TLI.isZExtFree(HalfVT, VT)) {
11209       APInt UpperBits = APInt::getHighBitsSet(NumBits, NumBits / 2);
11210       if (DAG.MaskedValueIsZero(N0, UpperBits)) {
11211         SDValue PopCnt = DAG.getNode(ISD::CTPOP, DL, HalfVT,
11212                                      DAG.getZExtOrTrunc(N0, DL, HalfVT));
11213         return DAG.getZExtOrTrunc(PopCnt, DL, VT);
11214       }
11215     }
11216   }
11217 
11218   return SDValue();
11219 }
11220 
11221 static bool isLegalToCombineMinNumMaxNum(SelectionDAG &DAG, SDValue LHS,
11222                                          SDValue RHS, const SDNodeFlags Flags,
11223                                          const TargetLowering &TLI) {
11224   EVT VT = LHS.getValueType();
11225   if (!VT.isFloatingPoint())
11226     return false;
11227 
11228   const TargetOptions &Options = DAG.getTarget().Options;
11229 
11230   return (Flags.hasNoSignedZeros() || Options.NoSignedZerosFPMath) &&
11231          TLI.isProfitableToCombineMinNumMaxNum(VT) &&
11232          (Flags.hasNoNaNs() ||
11233           (DAG.isKnownNeverNaN(RHS) && DAG.isKnownNeverNaN(LHS)));
11234 }
11235 
11236 static SDValue combineMinNumMaxNumImpl(const SDLoc &DL, EVT VT, SDValue LHS,
11237                                        SDValue RHS, SDValue True, SDValue False,
11238                                        ISD::CondCode CC,
11239                                        const TargetLowering &TLI,
11240                                        SelectionDAG &DAG) {
11241   EVT TransformVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
11242   switch (CC) {
11243   case ISD::SETOLT:
11244   case ISD::SETOLE:
11245   case ISD::SETLT:
11246   case ISD::SETLE:
11247   case ISD::SETULT:
11248   case ISD::SETULE: {
11249     // Since the operands are already known never to be NaN at this point,
11250     // either fminnum or fminnum_ieee is OK. Try the IEEE version first,
11251     // since fminnum is expanded in terms of it.
11252     unsigned IEEEOpcode = (LHS == True) ? ISD::FMINNUM_IEEE : ISD::FMAXNUM_IEEE;
11253     if (TLI.isOperationLegalOrCustom(IEEEOpcode, VT))
11254       return DAG.getNode(IEEEOpcode, DL, VT, LHS, RHS);
11255 
11256     unsigned Opcode = (LHS == True) ? ISD::FMINNUM : ISD::FMAXNUM;
11257     if (TLI.isOperationLegalOrCustom(Opcode, TransformVT))
11258       return DAG.getNode(Opcode, DL, VT, LHS, RHS);
11259     return SDValue();
11260   }
11261   case ISD::SETOGT:
11262   case ISD::SETOGE:
11263   case ISD::SETGT:
11264   case ISD::SETGE:
11265   case ISD::SETUGT:
11266   case ISD::SETUGE: {
11267     unsigned IEEEOpcode = (LHS == True) ? ISD::FMAXNUM_IEEE : ISD::FMINNUM_IEEE;
11268     if (TLI.isOperationLegalOrCustom(IEEEOpcode, VT))
11269       return DAG.getNode(IEEEOpcode, DL, VT, LHS, RHS);
11270 
11271     unsigned Opcode = (LHS == True) ? ISD::FMAXNUM : ISD::FMINNUM;
11272     if (TLI.isOperationLegalOrCustom(Opcode, TransformVT))
11273       return DAG.getNode(Opcode, DL, VT, LHS, RHS);
11274     return SDValue();
11275   }
11276   default:
11277     return SDValue();
11278   }
11279 }
11280 
11281 /// Generate Min/Max node
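/// Illustrative example (added): select (setcc olt X, Y), X, Y becomes
/// fminnum(X, Y) (or fminnum_ieee) when the no-NaN / no-signed-zeros
/// preconditions checked by isLegalToCombineMinNumMaxNum hold.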
11282 SDValue DAGCombiner::combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
11283                                          SDValue RHS, SDValue True,
11284                                          SDValue False, ISD::CondCode CC) {
11285   if ((LHS == True && RHS == False) || (LHS == False && RHS == True))
11286     return combineMinNumMaxNumImpl(DL, VT, LHS, RHS, True, False, CC, TLI, DAG);
11287 
11288   // If we can't directly match this, try to see if we can pull an fneg out of
11289   // the select.
11290   SDValue NegTrue = TLI.getCheaperOrNeutralNegatedExpression(
11291       True, DAG, LegalOperations, ForCodeSize);
11292   if (!NegTrue)
11293     return SDValue();
11294 
11295   HandleSDNode NegTrueHandle(NegTrue);
11296 
11297   // Try to unfold an fneg from the select if we are comparing the negated
11298   // constant.
11299   //
11300   // select (setcc x, K) (fneg x), -K -> fneg(minnum(x, K))
11301   //
11302   // TODO: Handle fabs
11303   if (LHS == NegTrue) {
11304     // If we can't directly match this, try to see if we can pull an fneg out of
11305     // the select.
11306     SDValue NegRHS = TLI.getCheaperOrNeutralNegatedExpression(
11307         RHS, DAG, LegalOperations, ForCodeSize);
11308     if (NegRHS) {
11309       HandleSDNode NegRHSHandle(NegRHS);
11310       if (NegRHS == False) {
11311         SDValue Combined = combineMinNumMaxNumImpl(DL, VT, LHS, RHS, NegTrue,
11312                                                    False, CC, TLI, DAG);
11313         if (Combined)
11314           return DAG.getNode(ISD::FNEG, DL, VT, Combined);
11315       }
11316     }
11317   }
11318 
11319   return SDValue();
11320 }
11321 
11322 /// If a (v)select has a condition value that is a sign-bit test, try to smear
11323 /// the condition operand sign-bit across the value width and use it as a mask.
11324 static SDValue foldSelectOfConstantsUsingSra(SDNode *N, const SDLoc &DL,
11325                                              SelectionDAG &DAG) {
11326   SDValue Cond = N->getOperand(0);
11327   SDValue C1 = N->getOperand(1);
11328   SDValue C2 = N->getOperand(2);
11329   if (!isConstantOrConstantVector(C1) || !isConstantOrConstantVector(C2))
11330     return SDValue();
11331 
11332   EVT VT = N->getValueType(0);
11333   if (Cond.getOpcode() != ISD::SETCC || !Cond.hasOneUse() ||
11334       VT != Cond.getOperand(0).getValueType())
11335     return SDValue();
11336 
11337   // The inverted-condition + commuted-select variants of these patterns are
11338   // canonicalized to these forms in IR.
11339   SDValue X = Cond.getOperand(0);
11340   SDValue CondC = Cond.getOperand(1);
11341   ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
11342   if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(CondC) &&
11343       isAllOnesOrAllOnesSplat(C2)) {
11344     // i32 X > -1 ? C1 : -1 --> (X >>s 31) | C1
11345     SDValue ShAmtC = DAG.getConstant(X.getScalarValueSizeInBits() - 1, DL, VT);
11346     SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, X, ShAmtC);
11347     return DAG.getNode(ISD::OR, DL, VT, Sra, C1);
11348   }
11349   if (CC == ISD::SETLT && isNullOrNullSplat(CondC) && isNullOrNullSplat(C2)) {
11350     // i8 X < 0 ? C1 : 0 --> (X >>s 7) & C1
11351     SDValue ShAmtC = DAG.getConstant(X.getScalarValueSizeInBits() - 1, DL, VT);
11352     SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, X, ShAmtC);
11353     return DAG.getNode(ISD::AND, DL, VT, Sra, C1);
11354   }
11355   return SDValue();
11356 }
11357 
11358 static bool shouldConvertSelectOfConstantsToMath(const SDValue &Cond, EVT VT,
11359                                                  const TargetLowering &TLI) {
11360   if (!TLI.convertSelectOfConstantsToMath(VT))
11361     return false;
11362 
11363   if (Cond.getOpcode() != ISD::SETCC || !Cond->hasOneUse())
11364     return true;
11365   if (!TLI.isOperationLegalOrCustom(ISD::SELECT_CC, VT))
11366     return true;
11367 
11368   ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
11369   if (CC == ISD::SETLT && isNullOrNullSplat(Cond.getOperand(1)))
11370     return true;
11371   if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(Cond.getOperand(1)))
11372     return true;
11373 
11374   return false;
11375 }
11376 
11377 SDValue DAGCombiner::foldSelectOfConstants(SDNode *N) {
11378   SDValue Cond = N->getOperand(0);
11379   SDValue N1 = N->getOperand(1);
11380   SDValue N2 = N->getOperand(2);
11381   EVT VT = N->getValueType(0);
11382   EVT CondVT = Cond.getValueType();
11383   SDLoc DL(N);
11384 
11385   if (!VT.isInteger())
11386     return SDValue();
11387 
11388   auto *C1 = dyn_cast<ConstantSDNode>(N1);
11389   auto *C2 = dyn_cast<ConstantSDNode>(N2);
11390   if (!C1 || !C2)
11391     return SDValue();
11392 
11393   if (CondVT != MVT::i1 || LegalOperations) {
11394     // fold (select Cond, 0, 1) -> (xor Cond, 1)
11395     // We can't do this reliably if integer-based booleans have different contents
11396     // from floating-point-based booleans. This is because we can't tell whether we
11397     // have an integer-based boolean or a floating-point-based boolean unless we
11398     // can find the SETCC that produced it and inspect its operands. This is
11399     // fairly easy if C is the SETCC node, but it can potentially be
11400     // undiscoverable (or not reasonably discoverable). For example, it could be
11401     // in another basic block or it could require searching a complicated
11402     // expression.
11403     if (CondVT.isInteger() &&
11404         TLI.getBooleanContents(/*isVec*/false, /*isFloat*/true) ==
11405             TargetLowering::ZeroOrOneBooleanContent &&
11406         TLI.getBooleanContents(/*isVec*/false, /*isFloat*/false) ==
11407             TargetLowering::ZeroOrOneBooleanContent &&
11408         C1->isZero() && C2->isOne()) {
11409       SDValue NotCond =
11410           DAG.getNode(ISD::XOR, DL, CondVT, Cond, DAG.getConstant(1, DL, CondVT));
11411       if (VT.bitsEq(CondVT))
11412         return NotCond;
11413       return DAG.getZExtOrTrunc(NotCond, DL, VT);
11414     }
11415 
11416     return SDValue();
11417   }
11418 
11419   // Only do this before legalization to avoid conflicting with target-specific
11420   // transforms in the other direction (create a select from a zext/sext). There
11421   // is also a target-independent combine here in DAGCombiner in the other
11422   // direction for (select Cond, -1, 0) when the condition is not i1.
11423   assert(CondVT == MVT::i1 && !LegalOperations);
11424 
11425   // select Cond, 1, 0 --> zext (Cond)
11426   if (C1->isOne() && C2->isZero())
11427     return DAG.getZExtOrTrunc(Cond, DL, VT);
11428 
11429   // select Cond, -1, 0 --> sext (Cond)
11430   if (C1->isAllOnes() && C2->isZero())
11431     return DAG.getSExtOrTrunc(Cond, DL, VT);
11432 
11433   // select Cond, 0, 1 --> zext (!Cond)
11434   if (C1->isZero() && C2->isOne()) {
11435     SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
11436     NotCond = DAG.getZExtOrTrunc(NotCond, DL, VT);
11437     return NotCond;
11438   }
11439 
11440   // select Cond, 0, -1 --> sext (!Cond)
11441   if (C1->isZero() && C2->isAllOnes()) {
11442     SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
11443     NotCond = DAG.getSExtOrTrunc(NotCond, DL, VT);
11444     return NotCond;
11445   }
11446 
11447   // Use a target hook because some targets may prefer to transform in the
11448   // other direction.
11449   if (!shouldConvertSelectOfConstantsToMath(Cond, VT, TLI))
11450     return SDValue();
11451 
11452   // For any constants that differ by 1, we can transform the select into
11453   // an extend and add.
11454   const APInt &C1Val = C1->getAPIntValue();
11455   const APInt &C2Val = C2->getAPIntValue();
11456 
11457   // select Cond, C1, C1-1 --> add (zext Cond), C1-1
11458   if (C1Val - 1 == C2Val) {
11459     Cond = DAG.getZExtOrTrunc(Cond, DL, VT);
11460     return DAG.getNode(ISD::ADD, DL, VT, Cond, N2);
11461   }
11462 
11463   // select Cond, C1, C1+1 --> add (sext Cond), C1+1
11464   if (C1Val + 1 == C2Val) {
11465     Cond = DAG.getSExtOrTrunc(Cond, DL, VT);
11466     return DAG.getNode(ISD::ADD, DL, VT, Cond, N2);
11467   }
11468 
11469   // select Cond, Pow2, 0 --> (zext Cond) << log2(Pow2)
11470   if (C1Val.isPowerOf2() && C2Val.isZero()) {
11471     Cond = DAG.getZExtOrTrunc(Cond, DL, VT);
11472     SDValue ShAmtC =
11473         DAG.getShiftAmountConstant(C1Val.exactLogBase2(), VT, DL);
11474     return DAG.getNode(ISD::SHL, DL, VT, Cond, ShAmtC);
11475   }
11476 
11477   // select Cond, -1, C --> or (sext Cond), C
11478   if (C1->isAllOnes()) {
11479     Cond = DAG.getSExtOrTrunc(Cond, DL, VT);
11480     return DAG.getNode(ISD::OR, DL, VT, Cond, N2);
11481   }
11482 
11483   // select Cond, C, -1 --> or (sext (not Cond)), C
11484   if (C2->isAllOnes()) {
11485     SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
11486     NotCond = DAG.getSExtOrTrunc(NotCond, DL, VT);
11487     return DAG.getNode(ISD::OR, DL, VT, NotCond, N1);
11488   }
11489 
11490   if (SDValue V = foldSelectOfConstantsUsingSra(N, DL, DAG))
11491     return V;
11492 
11493   return SDValue();
11494 }
11495 
11496 template <class MatchContextClass>
11497 static SDValue foldBoolSelectToLogic(SDNode *N, const SDLoc &DL,
11498                                      SelectionDAG &DAG) {
11499   assert((N->getOpcode() == ISD::SELECT || N->getOpcode() == ISD::VSELECT ||
11500           N->getOpcode() == ISD::VP_SELECT) &&
11501          "Expected a (v)(vp.)select");
11502   SDValue Cond = N->getOperand(0);
11503   SDValue T = N->getOperand(1), F = N->getOperand(2);
11504   EVT VT = N->getValueType(0);
11505   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
11506   MatchContextClass matcher(DAG, TLI, N);
11507 
11508   if (VT != Cond.getValueType() || VT.getScalarSizeInBits() != 1)
11509     return SDValue();
11510 
11511   // select Cond, Cond, F --> or Cond, freeze(F)
11512   // select Cond, 1, F    --> or Cond, freeze(F)
11513   if (Cond == T || isOneOrOneSplat(T, /* AllowUndefs */ true))
11514     return matcher.getNode(ISD::OR, DL, VT, Cond, DAG.getFreeze(F));
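  // Note the freeze above: with Cond == 1 the OR result is 1 for any F, but
  // an unfrozen poison F would make the whole OR poison even though the
  // select would have ignored that arm.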
11515 
11516   // select Cond, T, Cond --> and Cond, freeze(T)
11517   // select Cond, T, 0    --> and Cond, freeze(T)
11518   if (Cond == F || isNullOrNullSplat(F, /* AllowUndefs */ true))
11519     return matcher.getNode(ISD::AND, DL, VT, Cond, DAG.getFreeze(T));
11520 
11521   // select Cond, T, 1 --> or (not Cond), freeze(T)
11522   if (isOneOrOneSplat(F, /* AllowUndefs */ true)) {
11523     SDValue NotCond =
11524         matcher.getNode(ISD::XOR, DL, VT, Cond, DAG.getAllOnesConstant(DL, VT));
11525     return matcher.getNode(ISD::OR, DL, VT, NotCond, DAG.getFreeze(T));
11526   }
11527 
11528   // select Cond, 0, F --> and (not Cond), freeze(F)
11529   if (isNullOrNullSplat(T, /* AllowUndefs */ true)) {
11530     SDValue NotCond =
11531         matcher.getNode(ISD::XOR, DL, VT, Cond, DAG.getAllOnesConstant(DL, VT));
11532     return matcher.getNode(ISD::AND, DL, VT, NotCond, DAG.getFreeze(F));
11533   }
11534 
11535   return SDValue();
11536 }
11537 
11538 static SDValue foldVSelectToSignBitSplatMask(SDNode *N, SelectionDAG &DAG) {
11539   SDValue N0 = N->getOperand(0);
11540   SDValue N1 = N->getOperand(1);
11541   SDValue N2 = N->getOperand(2);
11542   EVT VT = N->getValueType(0);
11543 
11544   SDValue Cond0, Cond1;
11545   ISD::CondCode CC;
11546   if (!sd_match(N0, m_OneUse(m_SetCC(m_Value(Cond0), m_Value(Cond1),
11547                                      m_CondCode(CC)))) ||
11548       VT != Cond0.getValueType())
11549     return SDValue();
11550 
11551   // Match a signbit check of Cond0 as "Cond0 s<0". Swap select operands if the
11552   // compare is inverted from that pattern ("Cond0 s> -1").
11553   if (CC == ISD::SETLT && isNullOrNullSplat(Cond1))
11554     ; // This is the pattern we are looking for.
11555   else if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(Cond1))
11556     std::swap(N1, N2);
11557   else
11558     return SDValue();
11559 
11560   // (Cond0 s< 0) ? N1 : 0 --> (Cond0 s>> BW-1) & freeze(N1)
11561   if (isNullOrNullSplat(N2)) {
11562     SDLoc DL(N);
11563     SDValue ShiftAmt = DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, VT);
11564     SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Cond0, ShiftAmt);
11565     return DAG.getNode(ISD::AND, DL, VT, Sra, DAG.getFreeze(N1));
11566   }
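  // Illustrative (v4i32): a negative lane of Cond0 gives (Cond0 s>> 31) ==
  // all-ones, so the AND passes the frozen N1 lane through; a non-negative
  // lane gives 0, matching the original select of 0.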
11567 
11568   // (Cond0 s< 0) ? -1 : N2 --> (Cond0 s>> BW-1) | freeze(N2)
11569   if (isAllOnesOrAllOnesSplat(N1)) {
11570     SDLoc DL(N);
11571     SDValue ShiftAmt = DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, VT);
11572     SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Cond0, ShiftAmt);
11573     return DAG.getNode(ISD::OR, DL, VT, Sra, DAG.getFreeze(N2));
11574   }
11575 
11576   // If we have to invert the sign bit mask, only do that transform if the
11577   // target has a bitwise 'and not' instruction (the invert is free).
11578   // (Cond0 s< 0) ? 0 : N2 --> ~(Cond0 s>> BW-1) & freeze(N2)
11579   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
11580   if (isNullOrNullSplat(N1) && TLI.hasAndNot(N1)) {
11581     SDLoc DL(N);
11582     SDValue ShiftAmt = DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, VT);
11583     SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Cond0, ShiftAmt);
11584     SDValue Not = DAG.getNOT(DL, Sra, VT);
11585     return DAG.getNode(ISD::AND, DL, VT, Not, DAG.getFreeze(N2));
11586   }
11587 
11588   // TODO: There's another pattern in this family, but it may require
11589   //       implementing hasOrNot() to check for profitability:
11590   //       (Cond0 s> -1) ? -1 : N2 --> ~(Cond0 s>> BW-1) | freeze(N2)
11591 
11592   return SDValue();
11593 }
11594 
11595 SDValue DAGCombiner::visitSELECT(SDNode *N) {
11596   SDValue N0 = N->getOperand(0);
11597   SDValue N1 = N->getOperand(1);
11598   SDValue N2 = N->getOperand(2);
11599   EVT VT = N->getValueType(0);
11600   EVT VT0 = N0.getValueType();
11601   SDLoc DL(N);
11602   SDNodeFlags Flags = N->getFlags();
11603 
11604   if (SDValue V = DAG.simplifySelect(N0, N1, N2))
11605     return V;
11606 
11607   if (SDValue V = foldBoolSelectToLogic<EmptyMatchContext>(N, DL, DAG))
11608     return V;
11609 
11610   // select (not Cond), N1, N2 -> select Cond, N2, N1
11611   if (SDValue F = extractBooleanFlip(N0, DAG, TLI, false)) {
11612     SDValue SelectOp = DAG.getSelect(DL, VT, F, N2, N1);
11613     SelectOp->setFlags(Flags);
11614     return SelectOp;
11615   }
11616 
11617   if (SDValue V = foldSelectOfConstants(N))
11618     return V;
11619 
11620   // If we can fold this based on the true/false value, do so.
11621   if (SimplifySelectOps(N, N1, N2))
11622     return SDValue(N, 0); // Don't revisit N.
11623 
11624   if (VT0 == MVT::i1) {
11625     // The code in this block deals with the following 2 equivalences:
11626     //    select(C0|C1, x, y) <=> select(C0, x, select(C1, x, y))
11627     //    select(C0&C1, x, y) <=> select(C0, select(C1, x, y), y)
11628     // The target can specify its preferred form with the
11629     // shouldNormalizeToSelectSequence() callback. However we always transform
11630     // to the right anyway if we find that the inner select already exists
11631     // in the DAG, and we always transform to the left side if we know that
11632     // we can further optimize the combination of the conditions.
11633     bool normalizeToSequence =
11634         TLI.shouldNormalizeToSelectSequence(*DAG.getContext(), VT);
11635     // select (and Cond0, Cond1), X, Y
11636     //   -> select Cond0, (select Cond1, X, Y), Y
11637     if (N0->getOpcode() == ISD::AND && N0->hasOneUse()) {
11638       SDValue Cond0 = N0->getOperand(0);
11639       SDValue Cond1 = N0->getOperand(1);
11640       SDValue InnerSelect =
11641           DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond1, N1, N2, Flags);
11642       if (normalizeToSequence || !InnerSelect.use_empty())
11643         return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond0,
11644                            InnerSelect, N2, Flags);
11645       // Cleanup on failure.
11646       if (InnerSelect.use_empty())
11647         recursivelyDeleteUnusedNodes(InnerSelect.getNode());
11648     }
11649     // select (or Cond0, Cond1), X, Y -> select Cond0, X, (select Cond1, X, Y)
11650     if (N0->getOpcode() == ISD::OR && N0->hasOneUse()) {
11651       SDValue Cond0 = N0->getOperand(0);
11652       SDValue Cond1 = N0->getOperand(1);
11653       SDValue InnerSelect = DAG.getNode(ISD::SELECT, DL, N1.getValueType(),
11654                                         Cond1, N1, N2, Flags);
11655       if (normalizeToSequence || !InnerSelect.use_empty())
11656         return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond0, N1,
11657                            InnerSelect, Flags);
11658       // Cleanup on failure.
11659       if (InnerSelect.use_empty())
11660         recursivelyDeleteUnusedNodes(InnerSelect.getNode());
11661     }
11662 
11663     // select Cond0, (select Cond1, X, Y), Y -> select (and Cond0, Cond1), X, Y
11664     if (N1->getOpcode() == ISD::SELECT && N1->hasOneUse()) {
11665       SDValue N1_0 = N1->getOperand(0);
11666       SDValue N1_1 = N1->getOperand(1);
11667       SDValue N1_2 = N1->getOperand(2);
11668       if (N1_2 == N2 && N0.getValueType() == N1_0.getValueType()) {
11669         // Create the actual and node if we can generate good code for it.
11670         if (!normalizeToSequence) {
11671           SDValue And = DAG.getNode(ISD::AND, DL, N0.getValueType(), N0, N1_0);
11672           return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), And, N1_1,
11673                              N2, Flags);
11674         }
11675         // Otherwise see if we can optimize the "and" to a better pattern.
11676         if (SDValue Combined = visitANDLike(N0, N1_0, N)) {
11677           return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Combined, N1_1,
11678                              N2, Flags);
11679         }
11680       }
11681     }
11682     // select Cond0, X, (select Cond1, X, Y) -> select (or Cond0, Cond1), X, Y
11683     if (N2->getOpcode() == ISD::SELECT && N2->hasOneUse()) {
11684       SDValue N2_0 = N2->getOperand(0);
11685       SDValue N2_1 = N2->getOperand(1);
11686       SDValue N2_2 = N2->getOperand(2);
11687       if (N2_1 == N1 && N0.getValueType() == N2_0.getValueType()) {
11688         // Create the actual or node if we can generate good code for it.
11689         if (!normalizeToSequence) {
11690           SDValue Or = DAG.getNode(ISD::OR, DL, N0.getValueType(), N0, N2_0);
11691           return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Or, N1,
11692                              N2_2, Flags);
11693         }
11694         // Otherwise see if we can optimize to a better pattern.
11695         if (SDValue Combined = visitORLike(N0, N2_0, DL))
11696           return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Combined, N1,
11697                              N2_2, Flags);
11698       }
11699     }
11700   }
11701 
11702   // Fold selects based on a setcc into other things, such as min/max/abs.
11703   if (N0.getOpcode() == ISD::SETCC) {
11704     SDValue Cond0 = N0.getOperand(0), Cond1 = N0.getOperand(1);
11705     ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
11706 
11707     // select (fcmp lt x, y), x, y -> fminnum x, y
11708     // select (fcmp gt x, y), x, y -> fmaxnum x, y
11709     //
11710     // This is OK if we don't care what happens if either operand is a NaN.
11711     if (N0.hasOneUse() && isLegalToCombineMinNumMaxNum(DAG, N1, N2, Flags, TLI))
11712       if (SDValue FMinMax =
11713               combineMinNumMaxNum(DL, VT, Cond0, Cond1, N1, N2, CC))
11714         return FMinMax;
11715 
11716     // Use 'unsigned add with overflow' to optimize an unsigned saturating add.
11717     // This is conservatively limited to pre-legal-operations to give targets
11718     // a chance to reverse the transform if they want to do that. Also, it is
11719     // unlikely that the pattern would be formed late, so it's probably not
11720     // worth going through the other checks.
11721     if (!LegalOperations && TLI.isOperationLegalOrCustom(ISD::UADDO, VT) &&
11722         CC == ISD::SETUGT && N0.hasOneUse() && isAllOnesConstant(N1) &&
11723         N2.getOpcode() == ISD::ADD && Cond0 == N2.getOperand(0)) {
11724       auto *C = dyn_cast<ConstantSDNode>(N2.getOperand(1));
11725       auto *NotC = dyn_cast<ConstantSDNode>(Cond1);
11726       if (C && NotC && C->getAPIntValue() == ~NotC->getAPIntValue()) {
11727         // select (setcc Cond0, ~C, ugt), -1, (add Cond0, C) -->
11728         // uaddo Cond0, C; select uaddo.1, -1, uaddo.0
11729         //
11730         // The IR equivalent of this transform would have this form:
11731         //   %a = add %x, C
11732         //   %c = icmp ugt %x, ~C
11733         //   %r = select %c, -1, %a
11734         //   =>
11735         //   %u = call {iN,i1} llvm.uadd.with.overflow(%x, C)
11736         //   %u0 = extractvalue %u, 0
11737         //   %u1 = extractvalue %u, 1
11738         //   %r = select %u1, -1, %u0
11739         SDVTList VTs = DAG.getVTList(VT, VT0);
11740         SDValue UAO = DAG.getNode(ISD::UADDO, DL, VTs, Cond0, N2.getOperand(1));
11741         return DAG.getSelect(DL, VT, UAO.getValue(1), N1, UAO.getValue(0));
11742       }
11743     }
11744 
11745     if (TLI.isOperationLegal(ISD::SELECT_CC, VT) ||
11746         (!LegalOperations &&
11747          TLI.isOperationLegalOrCustom(ISD::SELECT_CC, VT))) {
11748       // Any flags available in a select/setcc fold will be on the setcc as they
11749       // migrated from fcmp
11750       Flags = N0->getFlags();
11751       SDValue SelectNode = DAG.getNode(ISD::SELECT_CC, DL, VT, Cond0, Cond1, N1,
11752                                        N2, N0.getOperand(2));
11753       SelectNode->setFlags(Flags);
11754       return SelectNode;
11755     }
11756 
11757     if (SDValue NewSel = SimplifySelect(DL, N0, N1, N2))
11758       return NewSel;
11759   }
11760 
11761   if (!VT.isVector())
11762     if (SDValue BinOp = foldSelectOfBinops(N))
11763       return BinOp;
11764 
11765   if (SDValue R = combineSelectAsExtAnd(N0, N1, N2, DL, DAG))
11766     return R;
11767 
11768   return SDValue();
11769 }
11770 
11771 // This function assumes all the vselect's arguments are CONCAT_VECTOR
11772 // nodes and that the condition is a BV of ConstantSDNodes (or undefs).
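// Illustrative example: with condition <0,0,-1,-1>, LHS = concat(A,B) and
// RHS = concat(C,D), the bottom half is uniformly false and the top half
// uniformly true, so the whole vselect folds to concat(C,B).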
11773 static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) {
11774   SDLoc DL(N);
11775   SDValue Cond = N->getOperand(0);
11776   SDValue LHS = N->getOperand(1);
11777   SDValue RHS = N->getOperand(2);
11778   EVT VT = N->getValueType(0);
11779   int NumElems = VT.getVectorNumElements();
11780   assert(LHS.getOpcode() == ISD::CONCAT_VECTORS &&
11781          RHS.getOpcode() == ISD::CONCAT_VECTORS &&
11782          Cond.getOpcode() == ISD::BUILD_VECTOR);
11783 
11784   // CONCAT_VECTORS can take an arbitrary number of operands. We only care
11785   // about the binary (two-operand) case here.
11786   if (LHS->getNumOperands() != 2 || RHS->getNumOperands() != 2)
11787     return SDValue();
11788 
11789   // We're sure we have an even number of elements due to the
11790   // concat_vectors we have as arguments to vselect.
11791   // Skip BV elements until we find one that's not an UNDEF. After we find
11792   // that first non-UNDEF element, keep looping until we get to half the
11793   // length of the BV and check that all the non-undef nodes are the same.
11794   ConstantSDNode *BottomHalf = nullptr;
11795   for (int i = 0; i < NumElems / 2; ++i) {
11796     if (Cond->getOperand(i)->isUndef())
11797       continue;
11798 
11799     if (BottomHalf == nullptr)
11800       BottomHalf = cast<ConstantSDNode>(Cond.getOperand(i));
11801     else if (Cond->getOperand(i).getNode() != BottomHalf)
11802       return SDValue();
11803   }
11804 
11805   // Do the same for the second half of the BuildVector
11806   ConstantSDNode *TopHalf = nullptr;
11807   for (int i = NumElems / 2; i < NumElems; ++i) {
11808     if (Cond->getOperand(i)->isUndef())
11809       continue;
11810 
11811     if (TopHalf == nullptr)
11812       TopHalf = cast<ConstantSDNode>(Cond.getOperand(i));
11813     else if (Cond->getOperand(i).getNode() != TopHalf)
11814       return SDValue();
11815   }
11816 
11817   assert(TopHalf && BottomHalf &&
11818          "One half of the selector was all UNDEFs and the other was all the "
11819          "same value. This should have been addressed before this function.");
11820   return DAG.getNode(
11821       ISD::CONCAT_VECTORS, DL, VT,
11822       BottomHalf->isZero() ? RHS->getOperand(0) : LHS->getOperand(0),
11823       TopHalf->isZero() ? RHS->getOperand(1) : LHS->getOperand(1));
11824 }
11825 
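// Fold a uniform (splat) contribution of a gather/scatter index into the
// scalar base pointer. Illustrative: BasePtr = %p with Index = splat(%ofs)
// becomes BasePtr = %p + %ofs with Index = splat(0); for
// Index = add(splat(%ofs), %v), only %v remains as the vector index.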
11826 bool refineUniformBase(SDValue &BasePtr, SDValue &Index, bool IndexIsScaled,
11827                        SelectionDAG &DAG, const SDLoc &DL) {
11828 
11829   // Only perform the transformation when existing operands can be reused.
11830   if (IndexIsScaled)
11831     return false;
11832 
11833   if (!isNullConstant(BasePtr) && !Index.hasOneUse())
11834     return false;
11835 
11836   EVT VT = BasePtr.getValueType();
11837 
11838   if (SDValue SplatVal = DAG.getSplatValue(Index);
11839       SplatVal && !isNullConstant(SplatVal) &&
11840       SplatVal.getValueType() == VT) {
11841     BasePtr = DAG.getNode(ISD::ADD, DL, VT, BasePtr, SplatVal);
11842     Index = DAG.getSplat(Index.getValueType(), DL, DAG.getConstant(0, DL, VT));
11843     return true;
11844   }
11845 
11846   if (Index.getOpcode() != ISD::ADD)
11847     return false;
11848 
11849   if (SDValue SplatVal = DAG.getSplatValue(Index.getOperand(0));
11850       SplatVal && SplatVal.getValueType() == VT) {
11851     BasePtr = DAG.getNode(ISD::ADD, DL, VT, BasePtr, SplatVal);
11852     Index = Index.getOperand(1);
11853     return true;
11854   }
11855   if (SDValue SplatVal = DAG.getSplatValue(Index.getOperand(1));
11856       SplatVal && SplatVal.getValueType() == VT) {
11857     BasePtr = DAG.getNode(ISD::ADD, DL, VT, BasePtr, SplatVal);
11858     Index = Index.getOperand(0);
11859     return true;
11860   }
11861   return false;
11862 }
11863 
11864 // Fold sext/zext of index into index type.
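// Illustrative: an index of (zext %idx) can drop the extend entirely when the
// target allows it (shouldRemoveExtendFromGSIndex), or can at least be
// re-tagged as an unsigned index, since a zero-extended value is known to be
// non-negative.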
11865 bool refineIndexType(SDValue &Index, ISD::MemIndexType &IndexType, EVT DataVT,
11866                      SelectionDAG &DAG) {
11867   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
11868 
11869   // It's always safe to look through zero extends.
11870   if (Index.getOpcode() == ISD::ZERO_EXTEND) {
11871     if (TLI.shouldRemoveExtendFromGSIndex(Index, DataVT)) {
11872       IndexType = ISD::UNSIGNED_SCALED;
11873       Index = Index.getOperand(0);
11874       return true;
11875     }
11876     if (ISD::isIndexTypeSigned(IndexType)) {
11877       IndexType = ISD::UNSIGNED_SCALED;
11878       return true;
11879     }
11880   }
11881 
11882   // It's only safe to look through sign extends when Index is signed.
11883   if (Index.getOpcode() == ISD::SIGN_EXTEND &&
11884       ISD::isIndexTypeSigned(IndexType) &&
11885       TLI.shouldRemoveExtendFromGSIndex(Index, DataVT)) {
11886     Index = Index.getOperand(0);
11887     return true;
11888   }
11889 
11890   return false;
11891 }
11892 
11893 SDValue DAGCombiner::visitVPSCATTER(SDNode *N) {
11894   VPScatterSDNode *MSC = cast<VPScatterSDNode>(N);
11895   SDValue Mask = MSC->getMask();
11896   SDValue Chain = MSC->getChain();
11897   SDValue Index = MSC->getIndex();
11898   SDValue Scale = MSC->getScale();
11899   SDValue StoreVal = MSC->getValue();
11900   SDValue BasePtr = MSC->getBasePtr();
11901   SDValue VL = MSC->getVectorLength();
11902   ISD::MemIndexType IndexType = MSC->getIndexType();
11903   SDLoc DL(N);
11904 
11905   // Zap scatters with a zero mask.
11906   if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
11907     return Chain;
11908 
11909   if (refineUniformBase(BasePtr, Index, MSC->isIndexScaled(), DAG, DL)) {
11910     SDValue Ops[] = {Chain, StoreVal, BasePtr, Index, Scale, Mask, VL};
11911     return DAG.getScatterVP(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
11912                             DL, Ops, MSC->getMemOperand(), IndexType);
11913   }
11914 
11915   if (refineIndexType(Index, IndexType, StoreVal.getValueType(), DAG)) {
11916     SDValue Ops[] = {Chain, StoreVal, BasePtr, Index, Scale, Mask, VL};
11917     return DAG.getScatterVP(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
11918                             DL, Ops, MSC->getMemOperand(), IndexType);
11919   }
11920 
11921   return SDValue();
11922 }
11923 
11924 SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
11925   MaskedScatterSDNode *MSC = cast<MaskedScatterSDNode>(N);
11926   SDValue Mask = MSC->getMask();
11927   SDValue Chain = MSC->getChain();
11928   SDValue Index = MSC->getIndex();
11929   SDValue Scale = MSC->getScale();
11930   SDValue StoreVal = MSC->getValue();
11931   SDValue BasePtr = MSC->getBasePtr();
11932   ISD::MemIndexType IndexType = MSC->getIndexType();
11933   SDLoc DL(N);
11934 
11935   // Zap scatters with a zero mask.
11936   if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
11937     return Chain;
11938 
11939   if (refineUniformBase(BasePtr, Index, MSC->isIndexScaled(), DAG, DL)) {
11940     SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
11941     return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
11942                                 DL, Ops, MSC->getMemOperand(), IndexType,
11943                                 MSC->isTruncatingStore());
11944   }
11945 
11946   if (refineIndexType(Index, IndexType, StoreVal.getValueType(), DAG)) {
11947     SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
11948     return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
11949                                 DL, Ops, MSC->getMemOperand(), IndexType,
11950                                 MSC->isTruncatingStore());
11951   }
11952 
11953   return SDValue();
11954 }
11955 
11956 SDValue DAGCombiner::visitMSTORE(SDNode *N) {
11957   MaskedStoreSDNode *MST = cast<MaskedStoreSDNode>(N);
11958   SDValue Mask = MST->getMask();
11959   SDValue Chain = MST->getChain();
11960   SDValue Value = MST->getValue();
11961   SDValue Ptr = MST->getBasePtr();
11962   SDLoc DL(N);
11963 
11964   // Zap masked stores with a zero mask.
11965   if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
11966     return Chain;
11967 
11968   // Remove a masked store if base pointers and masks are equal.
11969   if (MaskedStoreSDNode *MST1 = dyn_cast<MaskedStoreSDNode>(Chain)) {
11970     if (MST->isUnindexed() && MST->isSimple() && MST1->isUnindexed() &&
11971         MST1->isSimple() && MST1->getBasePtr() == Ptr &&
11972         !MST->getBasePtr().isUndef() &&
11973         ((Mask == MST1->getMask() && MST->getMemoryVT().getStoreSize() ==
11974                                          MST1->getMemoryVT().getStoreSize()) ||
11975          ISD::isConstantSplatVectorAllOnes(Mask.getNode())) &&
11976         TypeSize::isKnownLE(MST1->getMemoryVT().getStoreSize(),
11977                             MST->getMemoryVT().getStoreSize())) {
11978       CombineTo(MST1, MST1->getChain());
11979       if (N->getOpcode() != ISD::DELETED_NODE)
11980         AddToWorklist(N);
11981       return SDValue(N, 0);
11982     }
11983   }
11984 
11985   // If this is a masked store with an all-ones mask, we can use an unmasked store.
11986   // FIXME: Can we do this for indexed, compressing, or truncating stores?
11987   if (ISD::isConstantSplatVectorAllOnes(Mask.getNode()) && MST->isUnindexed() &&
11988       !MST->isCompressingStore() && !MST->isTruncatingStore())
11989     return DAG.getStore(MST->getChain(), SDLoc(N), MST->getValue(),
11990                         MST->getBasePtr(), MST->getPointerInfo(),
11991                         MST->getOriginalAlign(),
11992                         MST->getMemOperand()->getFlags(), MST->getAAInfo());
11993 
11994   // Try transforming N to an indexed store.
11995   if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
11996     return SDValue(N, 0);
11997 
11998   if (MST->isTruncatingStore() && MST->isUnindexed() &&
11999       Value.getValueType().isInteger() &&
12000       (!isa<ConstantSDNode>(Value) ||
12001        !cast<ConstantSDNode>(Value)->isOpaque())) {
12002     APInt TruncDemandedBits =
12003         APInt::getLowBitsSet(Value.getScalarValueSizeInBits(),
12004                              MST->getMemoryVT().getScalarSizeInBits());
12005 
12006     // See if we can simplify the operation with
12007     // SimplifyDemandedBits, which only works if the value has a single use.
12008     if (SimplifyDemandedBits(Value, TruncDemandedBits)) {
12009       // Re-visit the store if anything changed and the store hasn't been merged
12010       // with another node (N is deleted). SimplifyDemandedBits will add Value's
12011       // node back to the worklist if necessary, but we also need to re-visit
12012       // the Store node itself.
12013       if (N->getOpcode() != ISD::DELETED_NODE)
12014         AddToWorklist(N);
12015       return SDValue(N, 0);
12016     }
12017   }
12018 
12019   // If this is a TRUNC followed by a masked store, fold this into a masked
12020   // truncating store.  We can do this even if this is already a masked
12021   // truncstore.
12022   // TODO: Try combining to a masked compress store if possible.
12023   if ((Value.getOpcode() == ISD::TRUNCATE) && Value->hasOneUse() &&
12024       MST->isUnindexed() && !MST->isCompressingStore() &&
12025       TLI.canCombineTruncStore(Value.getOperand(0).getValueType(),
12026                                MST->getMemoryVT(), LegalOperations)) {
12027     auto Mask = TLI.promoteTargetBoolean(DAG, MST->getMask(),
12028                                          Value.getOperand(0).getValueType());
12029     return DAG.getMaskedStore(Chain, SDLoc(N), Value.getOperand(0), Ptr,
12030                               MST->getOffset(), Mask, MST->getMemoryVT(),
12031                               MST->getMemOperand(), MST->getAddressingMode(),
12032                               /*IsTruncating=*/true);
12033   }
12034 
12035   return SDValue();
12036 }
12037 
12038 SDValue DAGCombiner::visitVP_STRIDED_STORE(SDNode *N) {
12039   auto *SST = cast<VPStridedStoreSDNode>(N);
12040   EVT EltVT = SST->getValue().getValueType().getVectorElementType();
12041   // Combine strided stores with unit-stride to a regular VP store.
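  // Illustrative: a strided store of v4f32 whose constant stride equals the
  // f32 store size (4 bytes) writes contiguous memory, so it is equivalent to
  // a plain VP store with the same mask, EVL, and memory operand.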
12042   if (auto *CStride = dyn_cast<ConstantSDNode>(SST->getStride());
12043       CStride && CStride->getZExtValue() == EltVT.getStoreSize()) {
12044     return DAG.getStoreVP(SST->getChain(), SDLoc(N), SST->getValue(),
12045                           SST->getBasePtr(), SST->getOffset(), SST->getMask(),
12046                           SST->getVectorLength(), SST->getMemoryVT(),
12047                           SST->getMemOperand(), SST->getAddressingMode(),
12048                           SST->isTruncatingStore(), SST->isCompressingStore());
12049   }
12050   return SDValue();
12051 }
12052 
12053 SDValue DAGCombiner::visitVECTOR_COMPRESS(SDNode *N) {
12054   SDLoc DL(N);
12055   SDValue Vec = N->getOperand(0);
12056   SDValue Mask = N->getOperand(1);
12057   SDValue Passthru = N->getOperand(2);
12058   EVT VecVT = Vec.getValueType();
12059 
12060   bool HasPassthru = !Passthru.isUndef();
12061 
12062   APInt SplatVal;
12063   if (ISD::isConstantSplatVector(Mask.getNode(), SplatVal))
12064     return TLI.isConstTrueVal(Mask) ? Vec : Passthru;
12065 
12066   if (Vec.isUndef() || Mask.isUndef())
12067     return Passthru;
12068 
12069   // No need for potentially expensive compress if the mask is constant.
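  // Illustrative: vector_compress <a,b,c,d>, mask <1,0,1,0>, passthru
  // <p0,p1,p2,p3> selects a and c and pads the tail from the passthru,
  // yielding build_vector <a, c, p2, p3> (undef padding without a passthru).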
12070   if (ISD::isBuildVectorOfConstantSDNodes(Mask.getNode())) {
12071     SmallVector<SDValue, 16> Ops;
12072     EVT ScalarVT = VecVT.getVectorElementType();
12073     unsigned NumSelected = 0;
12074     unsigned NumElmts = VecVT.getVectorNumElements();
12075     for (unsigned I = 0; I < NumElmts; ++I) {
12076       SDValue MaskI = Mask.getOperand(I);
12077       // We treat undef mask entries as "false".
12078       if (MaskI.isUndef())
12079         continue;
12080 
12081       if (TLI.isConstTrueVal(MaskI)) {
12082         SDValue VecI = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT, Vec,
12083                                    DAG.getVectorIdxConstant(I, DL));
12084         Ops.push_back(VecI);
12085         NumSelected++;
12086       }
12087     }
12088     for (unsigned Rest = NumSelected; Rest < NumElmts; ++Rest) {
12089       SDValue Val =
12090           HasPassthru
12091               ? DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT, Passthru,
12092                             DAG.getVectorIdxConstant(Rest, DL))
12093               : DAG.getUNDEF(ScalarVT);
12094       Ops.push_back(Val);
12095     }
12096     return DAG.getBuildVector(VecVT, DL, Ops);
12097   }
12098 
12099   return SDValue();
12100 }
12101 
12102 SDValue DAGCombiner::visitVPGATHER(SDNode *N) {
12103   VPGatherSDNode *MGT = cast<VPGatherSDNode>(N);
12104   SDValue Mask = MGT->getMask();
12105   SDValue Chain = MGT->getChain();
12106   SDValue Index = MGT->getIndex();
12107   SDValue Scale = MGT->getScale();
12108   SDValue BasePtr = MGT->getBasePtr();
12109   SDValue VL = MGT->getVectorLength();
12110   ISD::MemIndexType IndexType = MGT->getIndexType();
12111   SDLoc DL(N);
12112 
12113   if (refineUniformBase(BasePtr, Index, MGT->isIndexScaled(), DAG, DL)) {
12114     SDValue Ops[] = {Chain, BasePtr, Index, Scale, Mask, VL};
12115     return DAG.getGatherVP(
12116         DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
12117         Ops, MGT->getMemOperand(), IndexType);
12118   }
12119 
12120   if (refineIndexType(Index, IndexType, N->getValueType(0), DAG)) {
12121     SDValue Ops[] = {Chain, BasePtr, Index, Scale, Mask, VL};
12122     return DAG.getGatherVP(
12123         DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
12124         Ops, MGT->getMemOperand(), IndexType);
12125   }
12126 
12127   return SDValue();
12128 }
12129 
12130 SDValue DAGCombiner::visitMGATHER(SDNode *N) {
12131   MaskedGatherSDNode *MGT = cast<MaskedGatherSDNode>(N);
12132   SDValue Mask = MGT->getMask();
12133   SDValue Chain = MGT->getChain();
12134   SDValue Index = MGT->getIndex();
12135   SDValue Scale = MGT->getScale();
12136   SDValue PassThru = MGT->getPassThru();
12137   SDValue BasePtr = MGT->getBasePtr();
12138   ISD::MemIndexType IndexType = MGT->getIndexType();
12139   SDLoc DL(N);
12140 
12141   // Zap gathers with a zero mask.
12142   if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
12143     return CombineTo(N, PassThru, MGT->getChain());
12144 
12145   if (refineUniformBase(BasePtr, Index, MGT->isIndexScaled(), DAG, DL)) {
12146     SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
12147     return DAG.getMaskedGather(
12148         DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
12149         Ops, MGT->getMemOperand(), IndexType, MGT->getExtensionType());
12150   }
12151 
12152   if (refineIndexType(Index, IndexType, N->getValueType(0), DAG)) {
12153     SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
12154     return DAG.getMaskedGather(
12155         DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
12156         Ops, MGT->getMemOperand(), IndexType, MGT->getExtensionType());
12157   }
12158 
12159   return SDValue();
12160 }
12161 
12162 SDValue DAGCombiner::visitMLOAD(SDNode *N) {
12163   MaskedLoadSDNode *MLD = cast<MaskedLoadSDNode>(N);
12164   SDValue Mask = MLD->getMask();
12165   SDLoc DL(N);
12166 
12167   // Zap masked loads with a zero mask.
12168   if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
12169     return CombineTo(N, MLD->getPassThru(), MLD->getChain());
12170 
12171   // If this is a masked load with an all-ones mask, we can use an unmasked load.
12172   // FIXME: Can we do this for indexed, expanding, or extending loads?
12173   if (ISD::isConstantSplatVectorAllOnes(Mask.getNode()) && MLD->isUnindexed() &&
12174       !MLD->isExpandingLoad() && MLD->getExtensionType() == ISD::NON_EXTLOAD) {
12175     SDValue NewLd = DAG.getLoad(
12176         N->getValueType(0), SDLoc(N), MLD->getChain(), MLD->getBasePtr(),
12177         MLD->getPointerInfo(), MLD->getOriginalAlign(),
12178         MLD->getMemOperand()->getFlags(), MLD->getAAInfo(), MLD->getRanges());
12179     return CombineTo(N, NewLd, NewLd.getValue(1));
12180   }
12181 
12182   // Try transforming N to an indexed load.
12183   if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
12184     return SDValue(N, 0);
12185 
12186   return SDValue();
12187 }
12188 
12189 SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {
12190   auto *SLD = cast<VPStridedLoadSDNode>(N);
12191   EVT EltVT = SLD->getValueType(0).getVectorElementType();
12192   // Combine strided loads with unit-stride to a regular VP load.
12193   if (auto *CStride = dyn_cast<ConstantSDNode>(SLD->getStride());
12194       CStride && CStride->getZExtValue() == EltVT.getStoreSize()) {
12195     SDValue NewLd = DAG.getLoadVP(
12196         SLD->getAddressingMode(), SLD->getExtensionType(), SLD->getValueType(0),
12197         SDLoc(N), SLD->getChain(), SLD->getBasePtr(), SLD->getOffset(),
12198         SLD->getMask(), SLD->getVectorLength(), SLD->getMemoryVT(),
12199         SLD->getMemOperand(), SLD->isExpandingLoad());
12200     return CombineTo(N, NewLd, NewLd.getValue(1));
12201   }
12202   return SDValue();
12203 }
12204 
12205 /// A vector select of 2 constant vectors can be simplified to math/logic to
12206 /// avoid a variable select instruction and possibly avoid constant loads.
12207 SDValue DAGCombiner::foldVSelectOfConstants(SDNode *N) {
12208   SDValue Cond = N->getOperand(0);
12209   SDValue N1 = N->getOperand(1);
12210   SDValue N2 = N->getOperand(2);
12211   EVT VT = N->getValueType(0);
12212   if (!Cond.hasOneUse() || Cond.getScalarValueSizeInBits() != 1 ||
12213       !shouldConvertSelectOfConstantsToMath(Cond, VT, TLI) ||
12214       !ISD::isBuildVectorOfConstantSDNodes(N1.getNode()) ||
12215       !ISD::isBuildVectorOfConstantSDNodes(N2.getNode()))
12216     return SDValue();
12217 
12218   // Check if we can use the condition value to increment/decrement a single
12219   // constant value. This simplifies a select to an add and removes a constant
12220   // load/materialization from the general case.
12221   bool AllAddOne = true;
12222   bool AllSubOne = true;
12223   unsigned Elts = VT.getVectorNumElements();
12224   for (unsigned i = 0; i != Elts; ++i) {
12225     SDValue N1Elt = N1.getOperand(i);
12226     SDValue N2Elt = N2.getOperand(i);
12227     if (N1Elt.isUndef() || N2Elt.isUndef())
12228       continue;
12229     if (N1Elt.getValueType() != N2Elt.getValueType()) {
12230       AllAddOne = false;
12231       AllSubOne = false;
12232       break;
12233     }
12234 
12235     const APInt &C1 = N1Elt->getAsAPIntVal();
12236     const APInt &C2 = N2Elt->getAsAPIntVal();
12237     if (C1 != C2 + 1)
12238       AllAddOne = false;
12239     if (C1 != C2 - 1)
12240       AllSubOne = false;
12241   }
12242 
12243   // Further simplifications for the extra-special cases where the constants are
12244   // all 0 or all -1 should be implemented as folds of these patterns.
12245   SDLoc DL(N);
12246   if (AllAddOne || AllSubOne) {
12247     // vselect <N x i1> Cond, C+1, C --> add (zext Cond), C
12248     // vselect <N x i1> Cond, C-1, C --> add (sext Cond), C
12249     auto ExtendOpcode = AllAddOne ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND;
12250     SDValue ExtendedCond = DAG.getNode(ExtendOpcode, DL, VT, Cond);
12251     return DAG.getNode(ISD::ADD, DL, VT, ExtendedCond, N2);
12252   }
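  // Illustrative (v2i32): vselect Cond, <5, 11>, <4, 10> is the AllAddOne
  // case and becomes add (zext Cond), <4, 10>; a true arm of <3, 9> would be
  // the AllSubOne case and sign-extend before the add.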
12253 
12254   // select Cond, Pow2C, 0 --> (zext Cond) << log2(Pow2C)
12255   APInt Pow2C;
12256   if (ISD::isConstantSplatVector(N1.getNode(), Pow2C) && Pow2C.isPowerOf2() &&
12257       isNullOrNullSplat(N2)) {
12258     SDValue ZextCond = DAG.getZExtOrTrunc(Cond, DL, VT);
12259     SDValue ShAmtC = DAG.getConstant(Pow2C.exactLogBase2(), DL, VT);
12260     return DAG.getNode(ISD::SHL, DL, VT, ZextCond, ShAmtC);
12261   }
12262 
12263   if (SDValue V = foldSelectOfConstantsUsingSra(N, DL, DAG))
12264     return V;
12265 
12266   // The general case for select-of-constants:
12267   // vselect <N x i1> Cond, C1, C2 --> xor (and (sext Cond), (C1^C2)), C2
12268   // ...but that only makes sense if a vselect is slower than 2 logic ops, so
12269   // leave that to a machine-specific pass.
12270   return SDValue();
12271 }
12272 
12273 SDValue DAGCombiner::visitVP_SELECT(SDNode *N) {
12274   SDValue N0 = N->getOperand(0);
12275   SDValue N1 = N->getOperand(1);
12276   SDValue N2 = N->getOperand(2);
12277   SDLoc DL(N);
12278 
12279   if (SDValue V = DAG.simplifySelect(N0, N1, N2))
12280     return V;
12281 
12282   if (SDValue V = foldBoolSelectToLogic<VPMatchContext>(N, DL, DAG))
12283     return V;
12284 
12285   return SDValue();
12286 }
12287 
12288 SDValue DAGCombiner::visitVSELECT(SDNode *N) {
12289   SDValue N0 = N->getOperand(0);
12290   SDValue N1 = N->getOperand(1);
12291   SDValue N2 = N->getOperand(2);
12292   EVT VT = N->getValueType(0);
12293   SDLoc DL(N);
12294 
12295   if (SDValue V = DAG.simplifySelect(N0, N1, N2))
12296     return V;
12297 
12298   if (SDValue V = foldBoolSelectToLogic<EmptyMatchContext>(N, DL, DAG))
12299     return V;
12300 
12301   // vselect (not Cond), N1, N2 -> vselect Cond, N2, N1
12302   if (SDValue F = extractBooleanFlip(N0, DAG, TLI, false))
12303     return DAG.getSelect(DL, VT, F, N2, N1);
12304 
12305   // select (sext m), (add X, C), X --> (add X, (and C, (sext m)))
12306   if (N1.getOpcode() == ISD::ADD && N1.getOperand(0) == N2 && N1->hasOneUse() &&
12307       DAG.isConstantIntBuildVectorOrConstantInt(N1.getOperand(1)) &&
12308       N0.getScalarValueSizeInBits() == N1.getScalarValueSizeInBits() &&
12309       TLI.getBooleanContents(N0.getValueType()) ==
12310           TargetLowering::ZeroOrNegativeOneBooleanContent) {
12311     return DAG.getNode(
12312         ISD::ADD, DL, N1.getValueType(), N2,
12313         DAG.getNode(ISD::AND, DL, N0.getValueType(), N1.getOperand(1), N0));
12314   }
12315 
12316   // Canonicalize integer abs.
12317   // vselect (setg[te] X,  0),  X, -X ->
12318   // vselect (setgt    X, -1),  X, -X ->
12319   // vselect (setl[te] X,  0), -X,  X ->
12320   // Y = sra (X, size(X)-1); xor (add (X, Y), Y)
12321   if (N0.getOpcode() == ISD::SETCC) {
12322     SDValue LHS = N0.getOperand(0), RHS = N0.getOperand(1);
12323     ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
12324     bool isAbs = false;
12325     bool RHSIsAllZeros = ISD::isBuildVectorAllZeros(RHS.getNode());
12326 
12327     if (((RHSIsAllZeros && (CC == ISD::SETGT || CC == ISD::SETGE)) ||
12328          (ISD::isBuildVectorAllOnes(RHS.getNode()) && CC == ISD::SETGT)) &&
12329         N1 == LHS && N2.getOpcode() == ISD::SUB && N1 == N2.getOperand(1))
12330       isAbs = ISD::isBuildVectorAllZeros(N2.getOperand(0).getNode());
12331     else if ((RHSIsAllZeros && (CC == ISD::SETLT || CC == ISD::SETLE)) &&
12332              N2 == LHS && N1.getOpcode() == ISD::SUB && N2 == N1.getOperand(1))
12333       isAbs = ISD::isBuildVectorAllZeros(N1.getOperand(0).getNode());
12334 
12335     if (isAbs) {
12336       if (TLI.isOperationLegalOrCustom(ISD::ABS, VT))
12337         return DAG.getNode(ISD::ABS, DL, VT, LHS);
12338 
12339       SDValue Shift = DAG.getNode(
12340           ISD::SRA, DL, VT, LHS,
12341           DAG.getShiftAmountConstant(VT.getScalarSizeInBits() - 1, VT, DL));
12342       SDValue Add = DAG.getNode(ISD::ADD, DL, VT, LHS, Shift);
12343       AddToWorklist(Shift.getNode());
12344       AddToWorklist(Add.getNode());
12345       return DAG.getNode(ISD::XOR, DL, VT, Add, Shift);
12346     }
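    // Illustrative expansion (i32 lanes, no legal ABS): for x = -5,
    // Shift = x s>> 31 = -1, Add = x + (-1) = -6, and -6 ^ -1 = 5; for
    // non-negative x the shift is 0 and x passes through unchanged.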
12347 
12348     // vselect x, y (fcmp lt x, y) -> fminnum x, y
12349     // vselect x, y (fcmp gt x, y) -> fmaxnum x, y
12350     //
12351     // This is OK if we don't care about what happens if either operand is a
12352     // NaN.
12353     //
12354     if (N0.hasOneUse() &&
12355         isLegalToCombineMinNumMaxNum(DAG, LHS, RHS, N->getFlags(), TLI)) {
12356       if (SDValue FMinMax = combineMinNumMaxNum(DL, VT, LHS, RHS, N1, N2, CC))
12357         return FMinMax;
12358     }
12359 
12360     if (SDValue S = PerformMinMaxFpToSatCombine(LHS, RHS, N1, N2, CC, DAG))
12361       return S;
12362     if (SDValue S = PerformUMinFpToSatCombine(LHS, RHS, N1, N2, CC, DAG))
12363       return S;
12364 
12365     // If this select has a condition (setcc) with narrower operands than the
12366     // select, try to widen the compare to match the select width.
12367     // TODO: This should be extended to handle any constant.
12368     // TODO: This could be extended to handle non-loading patterns, but that
12369     //       requires thorough testing to avoid regressions.
12370     if (isNullOrNullSplat(RHS)) {
12371       EVT NarrowVT = LHS.getValueType();
12372       EVT WideVT = N1.getValueType().changeVectorElementTypeToInteger();
12373       EVT SetCCVT = getSetCCResultType(LHS.getValueType());
12374       unsigned SetCCWidth = SetCCVT.getScalarSizeInBits();
12375       unsigned WideWidth = WideVT.getScalarSizeInBits();
12376       bool IsSigned = isSignedIntSetCC(CC);
12377       auto LoadExtOpcode = IsSigned ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
12378       if (LHS.getOpcode() == ISD::LOAD && LHS.hasOneUse() &&
12379           SetCCWidth != 1 && SetCCWidth < WideWidth &&
12380           TLI.isLoadExtLegalOrCustom(LoadExtOpcode, WideVT, NarrowVT) &&
12381           TLI.isOperationLegalOrCustom(ISD::SETCC, WideVT)) {
12382         // Both compare operands can be widened for free. The LHS can use an
12383         // extended load, and the RHS is a constant:
12384         //   vselect (ext (setcc load(X), C)), N1, N2 -->
12385         //   vselect (setcc extload(X), C'), N1, N2
12386         auto ExtOpcode = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
12387         SDValue WideLHS = DAG.getNode(ExtOpcode, DL, WideVT, LHS);
12388         SDValue WideRHS = DAG.getNode(ExtOpcode, DL, WideVT, RHS);
12389         EVT WideSetCCVT = getSetCCResultType(WideVT);
12390         SDValue WideSetCC = DAG.getSetCC(DL, WideSetCCVT, WideLHS, WideRHS, CC);
12391         return DAG.getSelect(DL, N1.getValueType(), WideSetCC, N1, N2);
12392       }
12393     }
12394 
12395     // Match VSELECTs with absolute difference patterns.
12396     // (vselect (setcc a, b, set?gt), (sub a, b), (sub b, a)) --> (abd? a, b)
12397     // (vselect (setcc a, b, set?ge), (sub a, b), (sub b, a)) --> (abd? a, b)
12398     // (vselect (setcc a, b, set?lt), (sub b, a), (sub a, b)) --> (abd? a, b)
12399     // (vselect (setcc a, b, set?le), (sub b, a), (sub a, b)) --> (abd? a, b)
12400     if (N1.getOpcode() == ISD::SUB && N2.getOpcode() == ISD::SUB &&
12401         N1.getOperand(0) == N2.getOperand(1) &&
12402         N1.getOperand(1) == N2.getOperand(0)) {
12403       bool IsSigned = isSignedIntSetCC(CC);
12404       unsigned ABDOpc = IsSigned ? ISD::ABDS : ISD::ABDU;
12405       if (hasOperation(ABDOpc, VT)) {
12406         switch (CC) {
12407         case ISD::SETGT:
12408         case ISD::SETGE:
12409         case ISD::SETUGT:
12410         case ISD::SETUGE:
12411           if (LHS == N1.getOperand(0) && RHS == N1.getOperand(1))
12412             return DAG.getNode(ABDOpc, DL, VT, LHS, RHS);
12413           break;
12414         case ISD::SETLT:
12415         case ISD::SETLE:
12416         case ISD::SETULT:
12417         case ISD::SETULE:
12418           if (RHS == N1.getOperand(0) && LHS == N1.getOperand(1))
12419             return DAG.getNode(ABDOpc, DL, VT, LHS, RHS);
12420           break;
12421         default:
12422           break;
12423         }
12424       }
12425     }
12426 
12427     // Match VSELECTs into add with unsigned saturation.
12428     if (hasOperation(ISD::UADDSAT, VT)) {
12429       // Check if one of the arms of the VSELECT is a vector with all bits set.
12430       // If it's on the left side, invert the predicate to simplify logic below.
12431       SDValue Other;
12432       ISD::CondCode SatCC = CC;
12433       if (ISD::isConstantSplatVectorAllOnes(N1.getNode())) {
12434         Other = N2;
12435         SatCC = ISD::getSetCCInverse(SatCC, VT.getScalarType());
12436       } else if (ISD::isConstantSplatVectorAllOnes(N2.getNode())) {
12437         Other = N1;
12438       }
12439 
12440       if (Other && Other.getOpcode() == ISD::ADD) {
12441         SDValue CondLHS = LHS, CondRHS = RHS;
12442         SDValue OpLHS = Other.getOperand(0), OpRHS = Other.getOperand(1);
12443 
12444         // Canonicalize condition operands.
12445         if (SatCC == ISD::SETUGE) {
12446           std::swap(CondLHS, CondRHS);
12447           SatCC = ISD::SETULE;
12448         }
12449 
12450         // We can test against either of the addition operands.
12451         // x <= x+y ? x+y : ~0 --> uaddsat x, y
12452         // x+y >= x ? x+y : ~0 --> uaddsat x, y
12453         if (SatCC == ISD::SETULE && Other == CondRHS &&
12454             (OpLHS == CondLHS || OpRHS == CondLHS))
12455           return DAG.getNode(ISD::UADDSAT, DL, VT, OpLHS, OpRHS);
12456 
12457         if (OpRHS.getOpcode() == CondRHS.getOpcode() &&
12458             (OpRHS.getOpcode() == ISD::BUILD_VECTOR ||
12459              OpRHS.getOpcode() == ISD::SPLAT_VECTOR) &&
12460             CondLHS == OpLHS) {
12461           // If the RHS is a constant we have to reverse the const
12462           // canonicalization.
12463           // x >= ~C ? x+C : ~0 --> uaddsat x, C
12464           auto MatchUADDSAT = [](ConstantSDNode *Op, ConstantSDNode *Cond) {
12465             return Cond->getAPIntValue() == ~Op->getAPIntValue();
12466           };
12467           if (SatCC == ISD::SETULE &&
12468               ISD::matchBinaryPredicate(OpRHS, CondRHS, MatchUADDSAT))
12469             return DAG.getNode(ISD::UADDSAT, DL, VT, OpLHS, OpRHS);
12470         }
12471       }
12472     }
12473 
12474     // Match VSELECTs into sub with unsigned saturation.
12475     if (hasOperation(ISD::USUBSAT, VT)) {
12476       // Check if one of the arms of the VSELECT is a zero vector. If it's on
12477       // the left side, invert the predicate to simplify logic below.
12478       SDValue Other;
12479       ISD::CondCode SatCC = CC;
12480       if (ISD::isConstantSplatVectorAllZeros(N1.getNode())) {
12481         Other = N2;
12482         SatCC = ISD::getSetCCInverse(SatCC, VT.getScalarType());
12483       } else if (ISD::isConstantSplatVectorAllZeros(N2.getNode())) {
12484         Other = N1;
12485       }
12486 
12487       // zext(x) >= y ? trunc(zext(x) - y) : 0
12488       // --> usubsat(trunc(zext(x)),trunc(umin(y,SatLimit)))
12489       // zext(x) >  y ? trunc(zext(x) - y) : 0
12490       // --> usubsat(trunc(zext(x)),trunc(umin(y,SatLimit)))
12491       if (Other && Other.getOpcode() == ISD::TRUNCATE &&
12492           Other.getOperand(0).getOpcode() == ISD::SUB &&
12493           (SatCC == ISD::SETUGE || SatCC == ISD::SETUGT)) {
12494         SDValue OpLHS = Other.getOperand(0).getOperand(0);
12495         SDValue OpRHS = Other.getOperand(0).getOperand(1);
12496         if (LHS == OpLHS && RHS == OpRHS && LHS.getOpcode() == ISD::ZERO_EXTEND)
12497           if (SDValue R = getTruncatedUSUBSAT(VT, LHS.getValueType(), LHS, RHS,
12498                                               DAG, DL))
12499             return R;
12500       }
12501 
12502       if (Other && Other.getNumOperands() == 2) {
12503         SDValue CondRHS = RHS;
12504         SDValue OpLHS = Other.getOperand(0), OpRHS = Other.getOperand(1);
12505 
12506         if (OpLHS == LHS) {
12507           // Look for a general sub with unsigned saturation first.
12508           // x >= y ? x-y : 0 --> usubsat x, y
12509           // x >  y ? x-y : 0 --> usubsat x, y
12510           if ((SatCC == ISD::SETUGE || SatCC == ISD::SETUGT) &&
12511               Other.getOpcode() == ISD::SUB && OpRHS == CondRHS)
12512             return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
12513 
12514           if (OpRHS.getOpcode() == ISD::BUILD_VECTOR ||
12515               OpRHS.getOpcode() == ISD::SPLAT_VECTOR) {
12516             if (CondRHS.getOpcode() == ISD::BUILD_VECTOR ||
12517                 CondRHS.getOpcode() == ISD::SPLAT_VECTOR) {
12518               // If the RHS is a constant we have to reverse the const
12519               // canonicalization.
12520               // x > C-1 ? x+-C : 0 --> usubsat x, C
12521               auto MatchUSUBSAT = [](ConstantSDNode *Op, ConstantSDNode *Cond) {
12522                 return (!Op && !Cond) ||
12523                        (Op && Cond &&
12524                         Cond->getAPIntValue() == (-Op->getAPIntValue() - 1));
12525               };
12526               if (SatCC == ISD::SETUGT && Other.getOpcode() == ISD::ADD &&
12527                   ISD::matchBinaryPredicate(OpRHS, CondRHS, MatchUSUBSAT,
12528                                             /*AllowUndefs*/ true)) {
12529                 OpRHS = DAG.getNegative(OpRHS, DL, VT);
12530                 return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
12531               }
12532 
12533               // Another special case: If C was a sign bit, the sub has been
12534               // canonicalized into a xor.
12535               // FIXME: Would it be better to use computeKnownBits to
12536               // determine whether it's safe to decanonicalize the xor?
12537               // x s< 0 ? x^C : 0 --> usubsat x, C
12538               APInt SplatValue;
12539               if (SatCC == ISD::SETLT && Other.getOpcode() == ISD::XOR &&
12540                   ISD::isConstantSplatVector(OpRHS.getNode(), SplatValue) &&
12541                   ISD::isConstantSplatVectorAllZeros(CondRHS.getNode()) &&
12542                   SplatValue.isSignMask()) {
12543                 // Note that we have to rebuild the RHS constant here to
12544                 // ensure we don't rely on particular values of undef lanes.
12545                 OpRHS = DAG.getConstant(SplatValue, DL, VT);
12546                 return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
12547               }
12548             }
12549           }
12550         }
12551       }
12552     }
12553   }
12554 
12555   if (SimplifySelectOps(N, N1, N2))
12556     return SDValue(N, 0);  // Don't revisit N.
12557 
12558   // Fold (vselect all_ones, N1, N2) -> N1
12559   if (ISD::isConstantSplatVectorAllOnes(N0.getNode()))
12560     return N1;
12561   // Fold (vselect all_zeros, N1, N2) -> N2
12562   if (ISD::isConstantSplatVectorAllZeros(N0.getNode()))
12563     return N2;
12564 
12565   // The ConvertSelectToConcatVector function is assuming both the above
12566   // checks for (vselect (build_vector all{ones,zeros}) ...) have been made
12567   // and addressed.
12568   if (N1.getOpcode() == ISD::CONCAT_VECTORS &&
12569       N2.getOpcode() == ISD::CONCAT_VECTORS &&
12570       ISD::isBuildVectorOfConstantSDNodes(N0.getNode())) {
12571     if (SDValue CV = ConvertSelectToConcatVector(N, DAG))
12572       return CV;
12573   }
12574 
12575   if (SDValue V = foldVSelectOfConstants(N))
12576     return V;
12577 
12578   if (hasOperation(ISD::SRA, VT))
12579     if (SDValue V = foldVSelectToSignBitSplatMask(N, DAG))
12580       return V;
12581 
12582   if (SimplifyDemandedVectorElts(SDValue(N, 0)))
12583     return SDValue(N, 0);
12584 
12585   return SDValue();
12586 }
12587 
12588 SDValue DAGCombiner::visitSELECT_CC(SDNode *N) {
12589   SDValue N0 = N->getOperand(0);
12590   SDValue N1 = N->getOperand(1);
12591   SDValue N2 = N->getOperand(2);
12592   SDValue N3 = N->getOperand(3);
12593   SDValue N4 = N->getOperand(4);
12594   ISD::CondCode CC = cast<CondCodeSDNode>(N4)->get();
12595   SDLoc DL(N);
12596 
12597   // fold select_cc lhs, rhs, x, x, cc -> x
12598   if (N2 == N3)
12599     return N2;
12600 
12601   // select_cc bool, 0, x, y, seteq -> select bool, y, x
12602   if (CC == ISD::SETEQ && !LegalTypes && N0.getValueType() == MVT::i1 &&
12603       isNullConstant(N1))
12604     return DAG.getSelect(DL, N2.getValueType(), N0, N3, N2);
12605 
12606   // Determine if the condition we're dealing with is constant
12607   if (SDValue SCC = SimplifySetCC(getSetCCResultType(N0.getValueType()), N0, N1,
12608                                   CC, DL, false)) {
12609     AddToWorklist(SCC.getNode());
12610 
12611     // cond always true -> true val
12612     // cond always false -> false val
12613     if (auto *SCCC = dyn_cast<ConstantSDNode>(SCC.getNode()))
12614       return SCCC->isZero() ? N3 : N2;
12615 
12616     // When the condition is UNDEF, just return the first operand. This is
12617     // coherent with DAG creation; no setcc node is created in this case.
12618     if (SCC->isUndef())
12619       return N2;
12620 
12621     // Fold to a simpler select_cc
12622     if (SCC.getOpcode() == ISD::SETCC) {
12623       SDValue SelectOp =
12624           DAG.getNode(ISD::SELECT_CC, DL, N2.getValueType(), SCC.getOperand(0),
12625                       SCC.getOperand(1), N2, N3, SCC.getOperand(2));
12626       SelectOp->setFlags(SCC->getFlags());
12627       return SelectOp;
12628     }
12629   }
12630 
12631   // If we can fold this based on the true/false value, do so.
12632   if (SimplifySelectOps(N, N2, N3))
12633     return SDValue(N, 0); // Don't revisit N.
12634 
12635   // fold select_cc into other things, such as min/max/abs
12636   return SimplifySelectCC(DL, N0, N1, N2, N3, CC);
12637 }
12638 
12639 SDValue DAGCombiner::visitSETCC(SDNode *N) {
12640   // setcc is very commonly used as an argument to brcond. This pattern
12641   // also lends itself to numerous combines and, as a result, it is desirable
12642   // that we keep the argument to a brcond as a setcc as much as possible.
12643   bool PreferSetCC =
12644       N->hasOneUse() && N->use_begin()->getOpcode() == ISD::BRCOND;
12645 
12646   ISD::CondCode Cond = cast<CondCodeSDNode>(N->getOperand(2))->get();
12647   EVT VT = N->getValueType(0);
12648   SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
12649   SDLoc DL(N);
12650 
12651   if (SDValue Combined = SimplifySetCC(VT, N0, N1, Cond, DL, !PreferSetCC)) {
12652     // If we prefer to have a setcc and the combined node is not one, we'll
12653     // try our best to recreate it using rebuildSetCC.
12654     if (PreferSetCC && Combined.getOpcode() != ISD::SETCC) {
12655       SDValue NewSetCC = rebuildSetCC(Combined);
12656 
12657       // We don't have anything interesting to combine to.
12658       if (NewSetCC.getNode() == N)
12659         return SDValue();
12660 
12661       if (NewSetCC)
12662         return NewSetCC;
12663     }
12664     return Combined;
12665   }
12666 
12667   // Optimize
12668   //    1) (icmp eq/ne (and X, C0), (shift X, C1))
12669   // or
12670   //    2) (icmp eq/ne X, (rotate X, C1))
12671   // If C0 is a mask or shifted mask and the shift amount (C1) isolates the
12672   // remaining bits (i.e. something like `(x64 & UINT32_MAX) == (x64 >> 32)`),
12673   // then:
12674   // If C1 is a power of 2, then the rotate and shift+and versions are
12675   // equivalent, so we can interchange them depending on target preference.
12676   // Otherwise, if we have the shift+and version we can interchange srl/shl,
12677   // which in turn affects the constant C0. We can use this to get better
12678   // constants again determined by target preference.
12679   if (Cond == ISD::SETNE || Cond == ISD::SETEQ) {
12680     auto IsAndWithShift = [](SDValue A, SDValue B) {
12681       return A.getOpcode() == ISD::AND &&
12682              (B.getOpcode() == ISD::SRL || B.getOpcode() == ISD::SHL) &&
12683              A.getOperand(0) == B.getOperand(0);
12684     };
12685     auto IsRotateWithOp = [](SDValue A, SDValue B) {
12686       return (B.getOpcode() == ISD::ROTL || B.getOpcode() == ISD::ROTR) &&
12687              B.getOperand(0) == A;
12688     };
12689     SDValue AndOrOp = SDValue(), ShiftOrRotate = SDValue();
12690     bool IsRotate = false;
12691 
12692     // Find either shift+and or rotate pattern.
12693     if (IsAndWithShift(N0, N1)) {
12694       AndOrOp = N0;
12695       ShiftOrRotate = N1;
12696     } else if (IsAndWithShift(N1, N0)) {
12697       AndOrOp = N1;
12698       ShiftOrRotate = N0;
12699     } else if (IsRotateWithOp(N0, N1)) {
12700       IsRotate = true;
12701       AndOrOp = N0;
12702       ShiftOrRotate = N1;
12703     } else if (IsRotateWithOp(N1, N0)) {
12704       IsRotate = true;
12705       AndOrOp = N1;
12706       ShiftOrRotate = N0;
12707     }
12708 
12709     if (AndOrOp && ShiftOrRotate && ShiftOrRotate.hasOneUse() &&
12710         (IsRotate || AndOrOp.hasOneUse())) {
12711       EVT OpVT = N0.getValueType();
12712       // Get the constant shift/rotate amount and possibly the mask (if it's
12713       // the shift+and variant).
12714       auto GetAPIntValue = [](SDValue Op) -> std::optional<APInt> {
12715         ConstantSDNode *CNode = isConstOrConstSplat(Op, /*AllowUndefs*/ false,
12716                                                     /*AllowTrunc*/ false);
12717         if (CNode == nullptr)
12718           return std::nullopt;
12719         return CNode->getAPIntValue();
12720       };
12721       std::optional<APInt> AndCMask =
12722           IsRotate ? std::nullopt : GetAPIntValue(AndOrOp.getOperand(1));
12723       std::optional<APInt> ShiftCAmt =
12724           GetAPIntValue(ShiftOrRotate.getOperand(1));
12725       unsigned NumBits = OpVT.getScalarSizeInBits();
12726 
12727       // We found constants.
12728       if (ShiftCAmt && (IsRotate || AndCMask) && ShiftCAmt->ult(NumBits)) {
12729         unsigned ShiftOpc = ShiftOrRotate.getOpcode();
12730         // Check that the constants meet the constraints.
12731         bool CanTransform = IsRotate;
12732         if (!CanTransform) {
12733           // Check that the mask and shift complement each other.
12734           CanTransform = *ShiftCAmt == (~*AndCMask).popcount();
12735           // Check that we are comparing all bits
12736           CanTransform &= (*ShiftCAmt + AndCMask->popcount()) == NumBits;
12737           // Check that the and mask is correct for the shift
12738           CanTransform &=
12739               ShiftOpc == ISD::SHL ? (~*AndCMask).isMask() : AndCMask->isMask();
12740         }
12741 
12742         // See if target prefers another shift/rotate opcode.
12743         unsigned NewShiftOpc = TLI.preferedOpcodeForCmpEqPiecesOfOperand(
12744             OpVT, ShiftOpc, ShiftCAmt->isPowerOf2(), *ShiftCAmt, AndCMask);
12745         // Transform is valid and we have a new preference.
12746         if (CanTransform && NewShiftOpc != ShiftOpc) {
12747           SDValue NewShiftOrRotate =
12748               DAG.getNode(NewShiftOpc, DL, OpVT, ShiftOrRotate.getOperand(0),
12749                           ShiftOrRotate.getOperand(1));
12750           SDValue NewAndOrOp = SDValue();
12751 
12752           if (NewShiftOpc == ISD::SHL || NewShiftOpc == ISD::SRL) {
12753             APInt NewMask =
12754                 NewShiftOpc == ISD::SHL
12755                     ? APInt::getHighBitsSet(NumBits,
12756                                             NumBits - ShiftCAmt->getZExtValue())
12757                     : APInt::getLowBitsSet(NumBits,
12758                                            NumBits - ShiftCAmt->getZExtValue());
12759             NewAndOrOp =
12760                 DAG.getNode(ISD::AND, DL, OpVT, ShiftOrRotate.getOperand(0),
12761                             DAG.getConstant(NewMask, DL, OpVT));
12762           } else {
12763             NewAndOrOp = ShiftOrRotate.getOperand(0);
12764           }
12765 
12766           return DAG.getSetCC(DL, VT, NewAndOrOp, NewShiftOrRotate, Cond);
12767         }
12768       }
12769     }
12770   }
12771   return SDValue();
12772 }
12773 
12774 SDValue DAGCombiner::visitSETCCCARRY(SDNode *N) {
12775   SDValue LHS = N->getOperand(0);
12776   SDValue RHS = N->getOperand(1);
12777   SDValue Carry = N->getOperand(2);
12778   SDValue Cond = N->getOperand(3);
12779 
12780   // If Carry is false, fold to a regular SETCC.
12781   if (isNullConstant(Carry))
12782     return DAG.getNode(ISD::SETCC, SDLoc(N), N->getVTList(), LHS, RHS, Cond);
12783 
12784   return SDValue();
12785 }
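// A sketch of the carry-is-false fold above (hypothetical DAG, written in the
// dump-style notation used by other comments in this file):
//   t3: i1 = setcccarry t0, t1, Constant:i1<0>, setult
// becomes
//   t3: i1 = setcc t0, t1, setult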
12786 
12787 /// Check if N satisfies:
12788 ///   N is used once.
12789 ///   N is a load.
12790 ///   The load is compatible with ExtOpcode, meaning:
12791 ///     If the load has an explicit zero/sign extension, ExtOpcode must
12792 ///     perform the same extension.
12793 ///     Otherwise, any ExtOpcode is compatible.
12794 static bool isCompatibleLoad(SDValue N, unsigned ExtOpcode) {
12795   if (!N.hasOneUse())
12796     return false;
12797 
12798   if (!isa<LoadSDNode>(N))
12799     return false;
12800 
12801   LoadSDNode *Load = cast<LoadSDNode>(N);
12802   ISD::LoadExtType LoadExt = Load->getExtensionType();
12803   if (LoadExt == ISD::NON_EXTLOAD || LoadExt == ISD::EXTLOAD)
12804     return true;
12805 
12806   // Now LoadExt is either SEXTLOAD or ZEXTLOAD, ExtOpcode must have the same
12807   // extension.
12808   if ((LoadExt == ISD::SEXTLOAD && ExtOpcode != ISD::SIGN_EXTEND) ||
12809       (LoadExt == ISD::ZEXTLOAD && ExtOpcode != ISD::ZERO_EXTEND))
12810     return false;
12811 
12812   return true;
12813 }
12814 
12815 /// Fold
12816 ///   (sext (select c, load x, load y)) -> (select c, sextload x, sextload y)
12817 ///   (zext (select c, load x, load y)) -> (select c, zextload x, zextload y)
12818 ///   (aext (select c, load x, load y)) -> (select c, extload x, extload y)
12819 /// This function is called by the DAGCombiner when visiting sext/zext/aext
12820 /// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND).
12821 static SDValue tryToFoldExtendSelectLoad(SDNode *N, const TargetLowering &TLI,
12822                                          SelectionDAG &DAG, const SDLoc &DL,
12823                                          CombineLevel Level) {
12824   unsigned Opcode = N->getOpcode();
12825   SDValue N0 = N->getOperand(0);
12826   EVT VT = N->getValueType(0);
12827   assert((Opcode == ISD::SIGN_EXTEND || Opcode == ISD::ZERO_EXTEND ||
12828           Opcode == ISD::ANY_EXTEND) &&
12829          "Expected EXTEND dag node in input!");
12830 
12831   if (!(N0->getOpcode() == ISD::SELECT || N0->getOpcode() == ISD::VSELECT) ||
12832       !N0.hasOneUse())
12833     return SDValue();
12834 
12835   SDValue Op1 = N0->getOperand(1);
12836   SDValue Op2 = N0->getOperand(2);
12837   if (!isCompatibleLoad(Op1, Opcode) || !isCompatibleLoad(Op2, Opcode))
12838     return SDValue();
12839 
12840   auto ExtLoadOpcode = ISD::EXTLOAD;
12841   if (Opcode == ISD::SIGN_EXTEND)
12842     ExtLoadOpcode = ISD::SEXTLOAD;
12843   else if (Opcode == ISD::ZERO_EXTEND)
12844     ExtLoadOpcode = ISD::ZEXTLOAD;
12845 
12846   // An illegal VSELECT may cause ISel to fail if it appears after
12847   // legalization (DAG Combine2), so conservatively check the OperationAction.
12848   LoadSDNode *Load1 = cast<LoadSDNode>(Op1);
12849   LoadSDNode *Load2 = cast<LoadSDNode>(Op2);
12850   if (!TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load1->getMemoryVT()) ||
12851       !TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load2->getMemoryVT()) ||
12852       (N0->getOpcode() == ISD::VSELECT && Level >= AfterLegalizeTypes &&
12853        TLI.getOperationAction(ISD::VSELECT, VT) != TargetLowering::Legal))
12854     return SDValue();
12855 
12856   SDValue Ext1 = DAG.getNode(Opcode, DL, VT, Op1);
12857   SDValue Ext2 = DAG.getNode(Opcode, DL, VT, Op2);
12858   return DAG.getSelect(DL, VT, N0->getOperand(0), Ext1, Ext2);
12859 }
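// A sketch of the select-of-loads fold (hypothetical types and addresses):
//   t1: i8,ch = load t0, x        t2: i8,ch = load t0, y
//   t3: i8 = select tc, t1, t2
//   t4: i32 = sign_extend t3
// becomes
//   t5: i32,ch = sextload t0, x   t6: i32,ch = sextload t0, y
//   t4: i32 = select tc, t5, t6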
12860 
12861 /// Try to fold a sext/zext/aext dag node into a ConstantSDNode or
12862 /// a build_vector of constants.
12863 /// This function is called by the DAGCombiner when visiting sext/zext/aext
12864 /// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND).
12865 /// Vector extends are not folded if operations are legal; this is to
12866 /// avoid introducing illegal build_vector dag nodes.
12867 static SDValue tryToFoldExtendOfConstant(SDNode *N, const SDLoc &DL,
12868                                          const TargetLowering &TLI,
12869                                          SelectionDAG &DAG, bool LegalTypes) {
12870   unsigned Opcode = N->getOpcode();
12871   SDValue N0 = N->getOperand(0);
12872   EVT VT = N->getValueType(0);
12873 
12874   assert((ISD::isExtOpcode(Opcode) || ISD::isExtVecInRegOpcode(Opcode)) &&
12875          "Expected EXTEND dag node in input!");
12876 
12877   // fold (sext c1) -> c1
12878   // fold (zext c1) -> c1
12879   // fold (aext c1) -> c1
12880   if (isa<ConstantSDNode>(N0))
12881     return DAG.getNode(Opcode, DL, VT, N0);
12882 
12883   // fold (sext (select cond, c1, c2)) -> (select cond, sext c1, sext c2)
12884   // fold (zext (select cond, c1, c2)) -> (select cond, zext c1, zext c2)
12885   // fold (aext (select cond, c1, c2)) -> (select cond, sext c1, sext c2)
12886   if (N0->getOpcode() == ISD::SELECT) {
12887     SDValue Op1 = N0->getOperand(1);
12888     SDValue Op2 = N0->getOperand(2);
12889     if (isa<ConstantSDNode>(Op1) && isa<ConstantSDNode>(Op2) &&
12890         (Opcode != ISD::ZERO_EXTEND || !TLI.isZExtFree(N0.getValueType(), VT))) {
12891       // For any_extend, choose sign extension of the constants to allow a
12892       // possible further transform to sign_extend_inreg, i.e.:
12893       //
12894       // t1: i8 = select t0, Constant:i8<-1>, Constant:i8<0>
12895       // t2: i64 = any_extend t1
12896       // -->
12897       // t3: i64 = select t0, Constant:i64<-1>, Constant:i64<0>
12898       // -->
12899       // t4: i64 = sign_extend_inreg t3
12900       unsigned FoldOpc = Opcode;
12901       if (FoldOpc == ISD::ANY_EXTEND)
12902         FoldOpc = ISD::SIGN_EXTEND;
12903       return DAG.getSelect(DL, VT, N0->getOperand(0),
12904                            DAG.getNode(FoldOpc, DL, VT, Op1),
12905                            DAG.getNode(FoldOpc, DL, VT, Op2));
12906     }
12907   }
12908 
12909   // fold (sext (build_vector AllConstants)) -> (build_vector AllConstants)
12910   // fold (zext (build_vector AllConstants)) -> (build_vector AllConstants)
12911   // fold (aext (build_vector AllConstants)) -> (build_vector AllConstants)
12912   EVT SVT = VT.getScalarType();
12913   if (!(VT.isVector() && (!LegalTypes || TLI.isTypeLegal(SVT)) &&
12914       ISD::isBuildVectorOfConstantSDNodes(N0.getNode())))
12915     return SDValue();
12916 
12917   // We can fold this node into a build_vector.
12918   unsigned VTBits = SVT.getSizeInBits();
12919   unsigned EVTBits = N0->getValueType(0).getScalarSizeInBits();
12920   SmallVector<SDValue, 8> Elts;
12921   unsigned NumElts = VT.getVectorNumElements();
12922 
12923   for (unsigned i = 0; i != NumElts; ++i) {
12924     SDValue Op = N0.getOperand(i);
12925     if (Op.isUndef()) {
12926       if (Opcode == ISD::ANY_EXTEND || Opcode == ISD::ANY_EXTEND_VECTOR_INREG)
12927         Elts.push_back(DAG.getUNDEF(SVT));
12928       else
12929         Elts.push_back(DAG.getConstant(0, DL, SVT));
12930       continue;
12931     }
12932 
12933     SDLoc DL(Op);
12934     // Get the constant value and if needed trunc it to the size of the type.
12935     // Nodes like build_vector might have constants wider than the scalar type.
12936     APInt C = Op->getAsAPIntVal().zextOrTrunc(EVTBits);
12937     if (Opcode == ISD::SIGN_EXTEND || Opcode == ISD::SIGN_EXTEND_VECTOR_INREG)
12938       Elts.push_back(DAG.getConstant(C.sext(VTBits), DL, SVT));
12939     else
12940       Elts.push_back(DAG.getConstant(C.zext(VTBits), DL, SVT));
12941   }
12942 
12943   return DAG.getBuildVector(VT, DL, Elts);
12944 }
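// E.g. (a hypothetical vector case):
//   (v4i32 (sext (v4i16 (build_vector -1, 0, undef, 1))))
// folds to
//   (v4i32 (build_vector -1, 0, 0, 1))
// where the undef lane becomes 0 so the implied extended bits stay correct.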
12945 
12946 // ExtendUsesToFormExtLoad - Try to extend the uses of a load to enable this:
12947 // "fold ({s|z|a}ext (load x)) -> ({s|z|a}ext (truncate ({s|z|a}extload x)))"
12948 // transformation. Returns true if the extensions are possible and the above
12949 // mentioned transformation is profitable.
12950 static bool ExtendUsesToFormExtLoad(EVT VT, SDNode *N, SDValue N0,
12951                                     unsigned ExtOpc,
12952                                     SmallVectorImpl<SDNode *> &ExtendNodes,
12953                                     const TargetLowering &TLI) {
12954   bool HasCopyToRegUses = false;
12955   bool isTruncFree = TLI.isTruncateFree(VT, N0.getValueType());
12956   for (SDNode::use_iterator UI = N0->use_begin(), UE = N0->use_end(); UI != UE;
12957        ++UI) {
12958     SDNode *User = *UI;
12959     if (User == N)
12960       continue;
12961     if (UI.getUse().getResNo() != N0.getResNo())
12962       continue;
12963     // FIXME: Only extend SETCC N, N and SETCC N, c for now.
12964     if (ExtOpc != ISD::ANY_EXTEND && User->getOpcode() == ISD::SETCC) {
12965       ISD::CondCode CC = cast<CondCodeSDNode>(User->getOperand(2))->get();
12966       if (ExtOpc == ISD::ZERO_EXTEND && ISD::isSignedIntSetCC(CC))
12967         // Sign bits will be lost after a zext.
12968         return false;
12969       bool Add = false;
12970       for (unsigned i = 0; i != 2; ++i) {
12971         SDValue UseOp = User->getOperand(i);
12972         if (UseOp == N0)
12973           continue;
12974         if (!isa<ConstantSDNode>(UseOp))
12975           return false;
12976         Add = true;
12977       }
12978       if (Add)
12979         ExtendNodes.push_back(User);
12980       continue;
12981     }
12982     // If truncates aren't free and there are users we can't
12983     // extend, it isn't worthwhile.
12984     if (!isTruncFree)
12985       return false;
12986     // Remember if this value is live-out.
12987     if (User->getOpcode() == ISD::CopyToReg)
12988       HasCopyToRegUses = true;
12989   }
12990 
12991   if (HasCopyToRegUses) {
12992     bool BothLiveOut = false;
12993     for (SDNode::use_iterator UI = N->use_begin(), UE = N->use_end();
12994          UI != UE; ++UI) {
12995       SDUse &Use = UI.getUse();
12996       if (Use.getResNo() == 0 && Use.getUser()->getOpcode() == ISD::CopyToReg) {
12997         BothLiveOut = true;
12998         break;
12999       }
13000     }
13001     if (BothLiveOut)
13002       // Both unextended and extended values are live out. There had better be
13003       // a good reason for the transformation.
13004       return !ExtendNodes.empty();
13005   }
13006   return true;
13007 }
13008 
13009 void DAGCombiner::ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs,
13010                                   SDValue OrigLoad, SDValue ExtLoad,
13011                                   ISD::NodeType ExtType) {
13012   // Extend SetCC uses if necessary.
13013   SDLoc DL(ExtLoad);
13014   for (SDNode *SetCC : SetCCs) {
13015     SmallVector<SDValue, 4> Ops;
13016 
13017     for (unsigned j = 0; j != 2; ++j) {
13018       SDValue SOp = SetCC->getOperand(j);
13019       if (SOp == OrigLoad)
13020         Ops.push_back(ExtLoad);
13021       else
13022         Ops.push_back(DAG.getNode(ExtType, DL, ExtLoad->getValueType(0), SOp));
13023     }
13024 
13025     Ops.push_back(SetCC->getOperand(2));
13026     CombineTo(SetCC, DAG.getNode(ISD::SETCC, DL, SetCC->getValueType(0), Ops));
13027   }
13028 }
13029 
13030 // FIXME: Bring more similar combines here, common to sext/zext (maybe aext?).
13031 SDValue DAGCombiner::CombineExtLoad(SDNode *N) {
13032   SDValue N0 = N->getOperand(0);
13033   EVT DstVT = N->getValueType(0);
13034   EVT SrcVT = N0.getValueType();
13035 
13036   assert((N->getOpcode() == ISD::SIGN_EXTEND ||
13037           N->getOpcode() == ISD::ZERO_EXTEND) &&
13038          "Unexpected node type (not an extend)!");
13039 
13040   // fold (sext (load x)) to multiple smaller sextloads; same for zext.
13041   // For example, on a target with legal v4i32, but illegal v8i32, turn:
13042   //   (v8i32 (sext (v8i16 (load x))))
13043   // into:
13044   //   (v8i32 (concat_vectors (v4i32 (sextload x)),
13045   //                          (v4i32 (sextload (x + 16)))))
13046   // Where uses of the original load, i.e.:
13047   //   (v8i16 (load x))
13048   // are replaced with:
13049   //   (v8i16 (truncate
13050   //     (v8i32 (concat_vectors (v4i32 (sextload x)),
13051   //                            (v4i32 (sextload (x + 16)))))))
13052   //
13053   // This combine is only applicable to illegal, but splittable, vectors.
13054   // All legal types, and illegal non-vector types, are handled elsewhere.
13055   // This combine is controlled by TargetLowering::isVectorLoadExtDesirable.
13056   //
13057   if (N0->getOpcode() != ISD::LOAD)
13058     return SDValue();
13059 
13060   LoadSDNode *LN0 = cast<LoadSDNode>(N0);
13061 
13062   if (!ISD::isNON_EXTLoad(LN0) || !ISD::isUNINDEXEDLoad(LN0) ||
13063       !N0.hasOneUse() || !LN0->isSimple() ||
13064       !DstVT.isVector() || !DstVT.isPow2VectorType() ||
13065       !TLI.isVectorLoadExtDesirable(SDValue(N, 0)))
13066     return SDValue();
13067 
13068   SmallVector<SDNode *, 4> SetCCs;
13069   if (!ExtendUsesToFormExtLoad(DstVT, N, N0, N->getOpcode(), SetCCs, TLI))
13070     return SDValue();
13071 
13072   ISD::LoadExtType ExtType =
13073       N->getOpcode() == ISD::SIGN_EXTEND ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
13074 
13075   // Try to split the vector types to get down to legal types.
13076   EVT SplitSrcVT = SrcVT;
13077   EVT SplitDstVT = DstVT;
13078   while (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT) &&
13079          SplitSrcVT.getVectorNumElements() > 1) {
13080     SplitDstVT = DAG.GetSplitDestVTs(SplitDstVT).first;
13081     SplitSrcVT = DAG.GetSplitDestVTs(SplitSrcVT).first;
13082   }
13083 
13084   if (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT))
13085     return SDValue();
13086 
13087   assert(!DstVT.isScalableVector() && "Unexpected scalable vector type");
13088 
13089   SDLoc DL(N);
13090   const unsigned NumSplits =
13091       DstVT.getVectorNumElements() / SplitDstVT.getVectorNumElements();
13092   const unsigned Stride = SplitSrcVT.getStoreSize();
13093   SmallVector<SDValue, 4> Loads;
13094   SmallVector<SDValue, 4> Chains;
13095 
13096   SDValue BasePtr = LN0->getBasePtr();
13097   for (unsigned Idx = 0; Idx < NumSplits; Idx++) {
13098     const unsigned Offset = Idx * Stride;
13099 
13100     SDValue SplitLoad =
13101         DAG.getExtLoad(ExtType, SDLoc(LN0), SplitDstVT, LN0->getChain(),
13102                        BasePtr, LN0->getPointerInfo().getWithOffset(Offset),
13103                        SplitSrcVT, LN0->getOriginalAlign(),
13104                        LN0->getMemOperand()->getFlags(), LN0->getAAInfo());
13105 
13106     BasePtr = DAG.getMemBasePlusOffset(BasePtr, TypeSize::getFixed(Stride), DL);
13107 
13108     Loads.push_back(SplitLoad.getValue(0));
13109     Chains.push_back(SplitLoad.getValue(1));
13110   }
13111 
13112   SDValue NewChain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Chains);
13113   SDValue NewValue = DAG.getNode(ISD::CONCAT_VECTORS, DL, DstVT, Loads);
13114 
13115   // Simplify TF.
13116   AddToWorklist(NewChain.getNode());
13117 
13118   CombineTo(N, NewValue);
13119 
13120   // Replace uses of the original load (before extension)
13121   // with a truncate of the concatenated sextloaded vectors.
13122   SDValue Trunc =
13123       DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), NewValue);
13124   ExtendSetCCUses(SetCCs, N0, NewValue, (ISD::NodeType)N->getOpcode());
13125   CombineTo(N0.getNode(), Trunc, NewChain);
13126   return SDValue(N, 0); // Return N so it doesn't get rechecked!
13127 }
13128 
13129 // fold (zext (and/or/xor (shl/shr (load x), cst), cst)) ->
13130 //      (and/or/xor (shl/shr (zextload x), (zext cst)), (zext cst))
13131 SDValue DAGCombiner::CombineZExtLogicopShiftLoad(SDNode *N) {
13132   assert(N->getOpcode() == ISD::ZERO_EXTEND);
13133   EVT VT = N->getValueType(0);
13134   EVT OrigVT = N->getOperand(0).getValueType();
13135   if (TLI.isZExtFree(OrigVT, VT))
13136     return SDValue();
13137 
13138   // and/or/xor
13139   SDValue N0 = N->getOperand(0);
13140   if (!ISD::isBitwiseLogicOp(N0.getOpcode()) ||
13141       N0.getOperand(1).getOpcode() != ISD::Constant ||
13142       (LegalOperations && !TLI.isOperationLegal(N0.getOpcode(), VT)))
13143     return SDValue();
13144 
13145   // shl/shr
13146   SDValue N1 = N0->getOperand(0);
13147   if (!(N1.getOpcode() == ISD::SHL || N1.getOpcode() == ISD::SRL) ||
13148       N1.getOperand(1).getOpcode() != ISD::Constant ||
13149       (LegalOperations && !TLI.isOperationLegal(N1.getOpcode(), VT)))
13150     return SDValue();
13151 
13152   // load
13153   if (!isa<LoadSDNode>(N1.getOperand(0)))
13154     return SDValue();
13155   LoadSDNode *Load = cast<LoadSDNode>(N1.getOperand(0));
13156   EVT MemVT = Load->getMemoryVT();
13157   if (!TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT) ||
13158       Load->getExtensionType() == ISD::SEXTLOAD || Load->isIndexed())
13159     return SDValue();
13160 
13161 
13162   // If the shift op is SHL, the logic op must be AND, otherwise the result
13163   // will be wrong.
13164   if (N1.getOpcode() == ISD::SHL && N0.getOpcode() != ISD::AND)
13165     return SDValue();
13166 
13167   if (!N0.hasOneUse() || !N1.hasOneUse())
13168     return SDValue();
13169 
13170   SmallVector<SDNode*, 4> SetCCs;
13171   if (!ExtendUsesToFormExtLoad(VT, N1.getNode(), N1.getOperand(0),
13172                                ISD::ZERO_EXTEND, SetCCs, TLI))
13173     return SDValue();
13174 
13175   // Actually do the transformation.
13176   SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(Load), VT,
13177                                    Load->getChain(), Load->getBasePtr(),
13178                                    Load->getMemoryVT(), Load->getMemOperand());
13179 
13180   SDLoc DL1(N1);
13181   SDValue Shift = DAG.getNode(N1.getOpcode(), DL1, VT, ExtLoad,
13182                               N1.getOperand(1));
13183 
13184   APInt Mask = N0.getConstantOperandAPInt(1).zext(VT.getSizeInBits());
13185   SDLoc DL0(N0);
13186   SDValue And = DAG.getNode(N0.getOpcode(), DL0, VT, Shift,
13187                             DAG.getConstant(Mask, DL0, VT));
13188 
13189   ExtendSetCCUses(SetCCs, N1.getOperand(0), ExtLoad, ISD::ZERO_EXTEND);
13190   CombineTo(N, And);
13191   if (SDValue(Load, 0).hasOneUse()) {
13192     DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 1), ExtLoad.getValue(1));
13193   } else {
13194     SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(Load),
13195                                 Load->getValueType(0), ExtLoad);
13196     CombineTo(Load, Trunc, ExtLoad.getValue(1));
13197   }
13198 
13199   // N0 is dead at this point.
13200   recursivelyDeleteUnusedNodes(N0.getNode());
13201 
13202   return SDValue(N,0); // Return N so it doesn't get rechecked!
13203 }
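// A sketch with hypothetical constants:
//   (i32 (zext (and (srl (i16 (load x)), 4), 0xFF)))
// becomes
//   (i32 (and (srl (i32 (zextload<i16> x)), 4), 0xFF))
// the shift amount is reused as-is, and the mask constant is zero-extended
// to the wider type along with the loaded value.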
13204 
13205 /// If we're narrowing or widening the result of a vector select and the final
13206 /// size is the same size as a setcc (compare) feeding the select, then try to
13207 /// apply the cast operation to the select's operands because matching vector
13208 /// sizes for a select condition and other operands should be more efficient.
13209 SDValue DAGCombiner::matchVSelectOpSizesWithSetCC(SDNode *Cast) {
13210   unsigned CastOpcode = Cast->getOpcode();
13211   assert((CastOpcode == ISD::SIGN_EXTEND || CastOpcode == ISD::ZERO_EXTEND ||
13212           CastOpcode == ISD::TRUNCATE || CastOpcode == ISD::FP_EXTEND ||
13213           CastOpcode == ISD::FP_ROUND) &&
13214          "Unexpected opcode for vector select narrowing/widening");
13215 
13216   // We only do this transform before legal ops because the pattern may be
13217   // obfuscated by target-specific operations after legalization. Do not create
13218   // an illegal select op, however, because that may be difficult to lower.
13219   EVT VT = Cast->getValueType(0);
13220   if (LegalOperations || !TLI.isOperationLegalOrCustom(ISD::VSELECT, VT))
13221     return SDValue();
13222 
13223   SDValue VSel = Cast->getOperand(0);
13224   if (VSel.getOpcode() != ISD::VSELECT || !VSel.hasOneUse() ||
13225       VSel.getOperand(0).getOpcode() != ISD::SETCC)
13226     return SDValue();
13227 
13228   // Does the setcc have the same vector size as the casted select?
13229   SDValue SetCC = VSel.getOperand(0);
13230   EVT SetCCVT = getSetCCResultType(SetCC.getOperand(0).getValueType());
13231   if (SetCCVT.getSizeInBits() != VT.getSizeInBits())
13232     return SDValue();
13233 
13234   // cast (vsel (setcc X), A, B) --> vsel (setcc X), (cast A), (cast B)
13235   SDValue A = VSel.getOperand(1);
13236   SDValue B = VSel.getOperand(2);
13237   SDValue CastA, CastB;
13238   SDLoc DL(Cast);
13239   if (CastOpcode == ISD::FP_ROUND) {
13240     // FP_ROUND (fptrunc) has an extra flag operand to pass along.
13241     CastA = DAG.getNode(CastOpcode, DL, VT, A, Cast->getOperand(1));
13242     CastB = DAG.getNode(CastOpcode, DL, VT, B, Cast->getOperand(1));
13243   } else {
13244     CastA = DAG.getNode(CastOpcode, DL, VT, A);
13245     CastB = DAG.getNode(CastOpcode, DL, VT, B);
13246   }
13247   return DAG.getNode(ISD::VSELECT, DL, VT, SetCC, CastA, CastB);
13248 }
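// A sketch of the cast-through-vselect fold (hypothetical vector types, on a
// target where a v4i32 compare produces a v4i32 mask):
//   t2: v4i32 = setcc t0, t1, setlt
//   t3: v4i16 = vselect t2, tA, tB
//   t4: v4i32 = sign_extend t3
// becomes
//   t4: v4i32 = vselect t2, (sign_extend tA), (sign_extend tB)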
13249 
13250 // fold ([s|z]ext ([s|z]extload x)) -> ([s|z]ext (truncate ([s|z]extload x)))
13251 // fold ([s|z]ext (     extload x)) -> ([s|z]ext (truncate ([s|z]extload x)))
13252 static SDValue tryToFoldExtOfExtload(SelectionDAG &DAG, DAGCombiner &Combiner,
13253                                      const TargetLowering &TLI, EVT VT,
13254                                      bool LegalOperations, SDNode *N,
13255                                      SDValue N0, ISD::LoadExtType ExtLoadType) {
13256   SDNode *N0Node = N0.getNode();
13257   bool isAExtLoad = (ExtLoadType == ISD::SEXTLOAD) ? ISD::isSEXTLoad(N0Node)
13258                                                    : ISD::isZEXTLoad(N0Node);
13259   if ((!isAExtLoad && !ISD::isEXTLoad(N0Node)) ||
13260       !ISD::isUNINDEXEDLoad(N0Node) || !N0.hasOneUse())
13261     return SDValue();
13262 
13263   LoadSDNode *LN0 = cast<LoadSDNode>(N0);
13264   EVT MemVT = LN0->getMemoryVT();
13265   if ((LegalOperations || !LN0->isSimple() ||
13266        VT.isVector()) &&
13267       !TLI.isLoadExtLegal(ExtLoadType, VT, MemVT))
13268     return SDValue();
13269 
13270   SDValue ExtLoad =
13271       DAG.getExtLoad(ExtLoadType, SDLoc(LN0), VT, LN0->getChain(),
13272                      LN0->getBasePtr(), MemVT, LN0->getMemOperand());
13273   Combiner.CombineTo(N, ExtLoad);
13274   DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
13275   if (LN0->use_empty())
13276     Combiner.recursivelyDeleteUnusedNodes(LN0);
13277   return SDValue(N, 0); // Return N so it doesn't get rechecked!
13278 }
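// A sketch (hypothetical types):
//   t1: i32,ch = sextload<i16> t0, x
//   t2: i64 = sign_extend t1
// folds to a single wider extending load:
//   t2: i64,ch = sextload<i16> t0, x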
13279 
13280 // fold ([s|z]ext (load x)) -> ([s|z]ext (truncate ([s|z]extload x)))
13281 // Only generate vector extloads when 1) they're legal, and 2) they are
13282 // deemed desirable by the target. NonNegZExt can be set to true if a zero
13283 // extend has the nonneg flag to allow use of sextload if profitable.
13284 static SDValue tryToFoldExtOfLoad(SelectionDAG &DAG, DAGCombiner &Combiner,
13285                                   const TargetLowering &TLI, EVT VT,
13286                                   bool LegalOperations, SDNode *N, SDValue N0,
13287                                   ISD::LoadExtType ExtLoadType,
13288                                   ISD::NodeType ExtOpc,
13289                                   bool NonNegZExt = false) {
13290   if (!ISD::isNON_EXTLoad(N0.getNode()) || !ISD::isUNINDEXEDLoad(N0.getNode()))
13291     return {};
13292 
13293   // If this is zext nneg, see if it would make sense to treat it as a sext.
13294   if (NonNegZExt) {
13295     assert(ExtLoadType == ISD::ZEXTLOAD && ExtOpc == ISD::ZERO_EXTEND &&
13296            "Unexpected load type or opcode");
13297     for (SDNode *User : N0->uses()) {
13298       if (User->getOpcode() == ISD::SETCC) {
13299         ISD::CondCode CC = cast<CondCodeSDNode>(User->getOperand(2))->get();
13300         if (ISD::isSignedIntSetCC(CC)) {
13301           ExtLoadType = ISD::SEXTLOAD;
13302           ExtOpc = ISD::SIGN_EXTEND;
13303           break;
13304         }
13305       }
13306     }
13307   }
13308 
13309   // TODO: isFixedLengthVector() should be removed, leaving any negative
13310   // effects on code generation to be handled by that target's
13311   // implementation of isVectorLoadExtDesirable().
13312   if ((LegalOperations || VT.isFixedLengthVector() ||
13313        !cast<LoadSDNode>(N0)->isSimple()) &&
13314       !TLI.isLoadExtLegal(ExtLoadType, VT, N0.getValueType()))
13315     return {};
13316 
13317   bool DoXform = true;
13318   SmallVector<SDNode *, 4> SetCCs;
13319   if (!N0.hasOneUse())
13320     DoXform = ExtendUsesToFormExtLoad(VT, N, N0, ExtOpc, SetCCs, TLI);
13321   if (VT.isVector())
13322     DoXform &= TLI.isVectorLoadExtDesirable(SDValue(N, 0));
13323   if (!DoXform)
13324     return {};
13325 
13326   LoadSDNode *LN0 = cast<LoadSDNode>(N0);
13327   SDValue ExtLoad = DAG.getExtLoad(ExtLoadType, SDLoc(LN0), VT, LN0->getChain(),
13328                                    LN0->getBasePtr(), N0.getValueType(),
13329                                    LN0->getMemOperand());
13330   Combiner.ExtendSetCCUses(SetCCs, N0, ExtLoad, ExtOpc);
13331   // If the load value is used only by N, replace it via CombineTo N.
13332   bool NoReplaceTrunc = SDValue(LN0, 0).hasOneUse();
13333   Combiner.CombineTo(N, ExtLoad);
13334   if (NoReplaceTrunc) {
13335     DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
13336     Combiner.recursivelyDeleteUnusedNodes(LN0);
13337   } else {
13338     SDValue Trunc =
13339         DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), ExtLoad);
13340     Combiner.CombineTo(LN0, Trunc, ExtLoad.getValue(1));
13341   }
13342   return SDValue(N, 0); // Return N so it doesn't get rechecked!
13343 }
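// A sketch (hypothetical types):
//   t1: i16,ch = load t0, x
//   t2: i32 = zero_extend t1
// becomes, when an i16->i32 zextload is legal (or we are pre-legalization):
//   t2: i32,ch = zextload<i16> t0, x
// with t1's chain result redirected to the new load.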
13344 
13345 static SDValue
13346 tryToFoldExtOfMaskedLoad(SelectionDAG &DAG, const TargetLowering &TLI, EVT VT,
13347                          bool LegalOperations, SDNode *N, SDValue N0,
13348                          ISD::LoadExtType ExtLoadType, ISD::NodeType ExtOpc) {
13349   if (!N0.hasOneUse())
13350     return SDValue();
13351 
13352   MaskedLoadSDNode *Ld = dyn_cast<MaskedLoadSDNode>(N0);
13353   if (!Ld || Ld->getExtensionType() != ISD::NON_EXTLOAD)
13354     return SDValue();
13355 
13356   if ((LegalOperations || !cast<MaskedLoadSDNode>(N0)->isSimple()) &&
13357       !TLI.isLoadExtLegalOrCustom(ExtLoadType, VT, Ld->getValueType(0)))
13358     return SDValue();
13359 
13360   if (!TLI.isVectorLoadExtDesirable(SDValue(N, 0)))
13361     return SDValue();
13362 
13363   SDLoc dl(Ld);
13364   SDValue PassThru = DAG.getNode(ExtOpc, dl, VT, Ld->getPassThru());
13365   SDValue NewLoad = DAG.getMaskedLoad(
13366       VT, dl, Ld->getChain(), Ld->getBasePtr(), Ld->getOffset(), Ld->getMask(),
13367       PassThru, Ld->getMemoryVT(), Ld->getMemOperand(), Ld->getAddressingMode(),
13368       ExtLoadType, Ld->isExpandingLoad());
13369   DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), SDValue(NewLoad.getNode(), 1));
13370   return NewLoad;
13371 }
13372 
13373 // fold ([s|z]ext (atomic_load)) -> ([s|z]ext (truncate ([s|z]ext atomic_load)))
13374 static SDValue tryToFoldExtOfAtomicLoad(SelectionDAG &DAG,
13375                                         const TargetLowering &TLI, EVT VT,
13376                                         SDValue N0,
13377                                         ISD::LoadExtType ExtLoadType) {
13378   auto *ALoad = dyn_cast<AtomicSDNode>(N0);
13379   if (!ALoad || ALoad->getOpcode() != ISD::ATOMIC_LOAD)
13380     return {};
13381   EVT MemoryVT = ALoad->getMemoryVT();
13382   if (!TLI.isAtomicLoadExtLegal(ExtLoadType, VT, MemoryVT))
13383     return {};
13384   // Can't fold into ALoad if it is already extending differently.
13385   ISD::LoadExtType ALoadExtTy = ALoad->getExtensionType();
13386   if ((ALoadExtTy == ISD::ZEXTLOAD && ExtLoadType == ISD::SEXTLOAD) ||
13387       (ALoadExtTy == ISD::SEXTLOAD && ExtLoadType == ISD::ZEXTLOAD))
13388     return {};
13389 
13390   EVT OrigVT = ALoad->getValueType(0);
13391   assert(OrigVT.getSizeInBits() < VT.getSizeInBits() && "VT should be wider.");
13392   auto *NewALoad = cast<AtomicSDNode>(DAG.getAtomic(
13393       ISD::ATOMIC_LOAD, SDLoc(ALoad), MemoryVT, VT, ALoad->getChain(),
13394       ALoad->getBasePtr(), ALoad->getMemOperand()));
13395   NewALoad->setExtensionType(ExtLoadType);
13396   DAG.ReplaceAllUsesOfValueWith(
13397       SDValue(ALoad, 0),
13398       DAG.getNode(ISD::TRUNCATE, SDLoc(ALoad), OrigVT, SDValue(NewALoad, 0)));
13399   // Update the chain uses.
13400   DAG.ReplaceAllUsesOfValueWith(SDValue(ALoad, 1), SDValue(NewALoad, 1));
13401   return SDValue(NewALoad, 0);
13402 }
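// A sketch (hypothetical types):
//   t1: i16,ch = atomic_load t0, x
//   t2: i32 = sign_extend t1
// becomes, when the target reports the extending atomic load as legal:
//   t2: i32,ch = atomic_load<sext i16> t0, x
// with any other users of t1 rewritten to (truncate t2 to i16).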
13403 
13404 static SDValue foldExtendedSignBitTest(SDNode *N, SelectionDAG &DAG,
13405                                        bool LegalOperations) {
13406   assert((N->getOpcode() == ISD::SIGN_EXTEND ||
13407           N->getOpcode() == ISD::ZERO_EXTEND) && "Expected sext or zext");
13408 
13409   SDValue SetCC = N->getOperand(0);
13410   if (LegalOperations || SetCC.getOpcode() != ISD::SETCC ||
13411       !SetCC.hasOneUse() || SetCC.getValueType() != MVT::i1)
13412     return SDValue();
13413 
13414   SDValue X = SetCC.getOperand(0);
13415   SDValue Ones = SetCC.getOperand(1);
13416   ISD::CondCode CC = cast<CondCodeSDNode>(SetCC.getOperand(2))->get();
13417   EVT VT = N->getValueType(0);
13418   EVT XVT = X.getValueType();
13419   // setge X, C is canonicalized to setgt, so we do not need to match that
13420   // pattern. The setlt sibling is folded in SimplifySelectCC() because it does
13421   // not require the 'not' op.
13422   if (CC == ISD::SETGT && isAllOnesConstant(Ones) && VT == XVT) {
13423     // Invert and smear/shift the sign bit:
13424     // sext i1 (setgt iN X, -1) --> sra (not X), (N - 1)
13425     // zext i1 (setgt iN X, -1) --> srl (not X), (N - 1)
13426     SDLoc DL(N);
13427     unsigned ShCt = VT.getSizeInBits() - 1;
13428     const TargetLowering &TLI = DAG.getTargetLoweringInfo();
13429     if (!TLI.shouldAvoidTransformToShift(VT, ShCt)) {
13430       SDValue NotX = DAG.getNOT(DL, X, VT);
13431       SDValue ShiftAmount = DAG.getConstant(ShCt, DL, VT);
13432       auto ShiftOpcode =
13433         N->getOpcode() == ISD::SIGN_EXTEND ? ISD::SRA : ISD::SRL;
13434       return DAG.getNode(ShiftOpcode, DL, VT, NotX, ShiftAmount);
13435     }
13436   }
13437   return SDValue();
13438 }
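// A worked instance (hypothetical i32 input):
//   sext i1 (setgt i32 X, -1) --> sra (xor X, -1), 31
// X >= 0 yields all-ones and X < 0 yields zero; the zext form uses srl and
// yields 1 or 0 instead.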
13439 
13440 SDValue DAGCombiner::foldSextSetcc(SDNode *N) {
13441   SDValue N0 = N->getOperand(0);
13442   if (N0.getOpcode() != ISD::SETCC)
13443     return SDValue();
13444 
13445   SDValue N00 = N0.getOperand(0);
13446   SDValue N01 = N0.getOperand(1);
13447   ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
13448   EVT VT = N->getValueType(0);
13449   EVT N00VT = N00.getValueType();
13450   SDLoc DL(N);
13451 
13452   // Propagate fast-math-flags.
13453   SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags());
13454 
13455   // On some architectures (such as SSE/NEON/etc) the SETCC result type is
13456   // the same size as the compared operands. Try to optimize sext(setcc())
13457   // if this is the case.
13458   if (VT.isVector() && !LegalOperations &&
13459       TLI.getBooleanContents(N00VT) ==
13460           TargetLowering::ZeroOrNegativeOneBooleanContent) {
13461     EVT SVT = getSetCCResultType(N00VT);
13462 
13463     // If we already have the desired type, don't change it.
13464     if (SVT != N0.getValueType()) {
13465       // We know that the # elements of the results is the same as the
13466       // # elements of the compare (and the # elements of the compare result
13467       // for that matter).  Check to see that they are the same size.  If so,
13468       // we know that the element size of the sext'd result matches the
13469       // element size of the compare operands.
13470       if (VT.getSizeInBits() == SVT.getSizeInBits())
13471         return DAG.getSetCC(DL, VT, N00, N01, CC);
13472 
13473       // If the desired elements are smaller or larger than the source
13474       // elements, we can use a matching integer vector type and then
13475       // truncate/sign extend.
13476       EVT MatchingVecType = N00VT.changeVectorElementTypeToInteger();
13477       if (SVT == MatchingVecType) {
13478         SDValue VsetCC = DAG.getSetCC(DL, MatchingVecType, N00, N01, CC);
13479         return DAG.getSExtOrTrunc(VsetCC, DL, VT);
13480       }
13481     }
13482 
13483     // Try to eliminate the sext of a setcc by zexting the compare operands.
13484     if (N0.hasOneUse() && TLI.isOperationLegalOrCustom(ISD::SETCC, VT) &&
13485         !TLI.isOperationLegalOrCustom(ISD::SETCC, SVT)) {
13486       bool IsSignedCmp = ISD::isSignedIntSetCC(CC);
13487       unsigned LoadOpcode = IsSignedCmp ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
13488       unsigned ExtOpcode = IsSignedCmp ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
13489 
13490       // We have an unsupported narrow vector compare op that would be legal
13491       // if extended to the destination type. See if the compare operands
13492       // can be freely extended to the destination type.
13493       auto IsFreeToExtend = [&](SDValue V) {
13494         if (isConstantOrConstantVector(V, /*NoOpaques*/ true))
13495           return true;
13496         // Match a simple, non-extended load that can be converted to a
13497         // legal {z/s}ext-load.
13498         // TODO: Allow widening of an existing {z/s}ext-load?
13499         if (!(ISD::isNON_EXTLoad(V.getNode()) &&
13500               ISD::isUNINDEXEDLoad(V.getNode()) &&
13501               cast<LoadSDNode>(V)->isSimple() &&
13502               TLI.isLoadExtLegal(LoadOpcode, VT, V.getValueType())))
13503           return false;
13504 
13505         // Non-chain users of this value must either be the setcc in this
13506         // sequence or extends that can be folded into the new {z/s}ext-load.
13507         for (SDNode::use_iterator UI = V->use_begin(), UE = V->use_end();
13508              UI != UE; ++UI) {
13509           // Skip uses of the chain and the setcc.
13510           SDNode *User = *UI;
13511           if (UI.getUse().getResNo() != 0 || User == N0.getNode())
13512             continue;
13513           // Extra users must have exactly the same cast we are about to create.
13514           // TODO: This restriction could be eased if ExtendUsesToFormExtLoad()
13515           //       is enhanced similarly.
13516           if (User->getOpcode() != ExtOpcode || User->getValueType(0) != VT)
13517             return false;
13518         }
13519         return true;
13520       };
13521 
13522       if (IsFreeToExtend(N00) && IsFreeToExtend(N01)) {
13523         SDValue Ext0 = DAG.getNode(ExtOpcode, DL, VT, N00);
13524         SDValue Ext1 = DAG.getNode(ExtOpcode, DL, VT, N01);
13525         return DAG.getSetCC(DL, VT, Ext0, Ext1, CC);
13526       }
13527     }
13528   }
13529 
13530   // sext(setcc x, y, cc) -> (select (setcc x, y, cc), T, 0)
13531   // Here, T can be 1 or -1, depending on the type of the setcc and
13532   // getBooleanContents().
13533   unsigned SetCCWidth = N0.getScalarValueSizeInBits();
13534 
13535   // To determine the "true" side of the select, we need to know the high bit
13536   // of the value returned by the setcc if it evaluates to true.
13537   // If the type of the setcc is i1, then the true case of the select is just
13538   // sext(i1 1), that is, -1.
13539   // If the type of the setcc is larger (say, i8) then the value of the high
13540   // bit depends on getBooleanContents(), so ask TLI for a real "true" value
13541   // of the appropriate width.
13542   SDValue ExtTrueVal = (SetCCWidth == 1)
13543                            ? DAG.getAllOnesConstant(DL, VT)
13544                            : DAG.getBoolConstant(true, DL, VT, N00VT);
13545   SDValue Zero = DAG.getConstant(0, DL, VT);
13546   if (SDValue SCC = SimplifySelectCC(DL, N00, N01, ExtTrueVal, Zero, CC, true))
13547     return SCC;
13548 
13549   if (!VT.isVector() && !shouldConvertSelectOfConstantsToMath(N0, VT, TLI)) {
13550     EVT SetCCVT = getSetCCResultType(N00VT);
13551     // Don't do this transform for i1 because there's a select transform
13552     // that would reverse it.
13553     // TODO: We should not do this transform at all without a target hook
13554     // because a sext is likely cheaper than a select?
13555     if (SetCCVT.getScalarSizeInBits() != 1 &&
13556         (!LegalOperations || TLI.isOperationLegal(ISD::SETCC, N00VT))) {
13557       SDValue SetCC = DAG.getSetCC(DL, SetCCVT, N00, N01, CC);
13558       return DAG.getSelect(DL, VT, SetCC, ExtTrueVal, Zero);
13559     }
13560   }
13561 
13562   return SDValue();
13563 }
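// A sketch of the scalar fallback at the end of foldSextSetcc (hypothetical
// types):
//   t2: i32 = sign_extend (t1: i1 = setcc t0, Constant:i32<0>, setlt)
// may become
//   t2: i32 = select t1, Constant:i32<-1>, Constant:i32<0>
// when a select of constants is preferred over math on this target.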
13564 
13565 SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
13566   SDValue N0 = N->getOperand(0);
13567   EVT VT = N->getValueType(0);
13568   SDLoc DL(N);
13569 
13570   if (VT.isVector())
13571     if (SDValue FoldedVOp = SimplifyVCastOp(N, DL))
13572       return FoldedVOp;
13573 
13574   // sext(undef) = 0 because the top bits will all be the same.
13575   if (N0.isUndef())
13576     return DAG.getConstant(0, DL, VT);
13577 
13578   if (SDValue Res = tryToFoldExtendOfConstant(N, DL, TLI, DAG, LegalTypes))
13579     return Res;
13580 
13581   // fold (sext (sext x)) -> (sext x)
13582   // fold (sext (aext x)) -> (sext x)
13583   if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND)
13584     return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, N0.getOperand(0));
13585 
13586   // fold (sext (aext_extend_vector_inreg x)) -> (sext_extend_vector_inreg x)
13587   // fold (sext (sext_extend_vector_inreg x)) -> (sext_extend_vector_inreg x)
13588   if (N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG ||
13589       N0.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG)
13590     return DAG.getNode(ISD::SIGN_EXTEND_VECTOR_INREG, SDLoc(N), VT,
13591                        N0.getOperand(0));
13592 
13593   // fold (sext (sext_inreg x)) -> (sext (trunc x))
13594   if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG) {
13595     SDValue N00 = N0.getOperand(0);
13596     EVT ExtVT = cast<VTSDNode>(N0->getOperand(1))->getVT();
13597     if ((N00.getOpcode() == ISD::TRUNCATE || TLI.isTruncateFree(N00, ExtVT)) &&
13598         (!LegalTypes || TLI.isTypeLegal(ExtVT))) {
13599       SDValue T = DAG.getNode(ISD::TRUNCATE, DL, ExtVT, N00);
13600       return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, T);
13601     }
13602   }
13603 
13604   if (N0.getOpcode() == ISD::TRUNCATE) {
13605     // fold (sext (truncate (load x))) -> (sext (smaller load x))
13606     // fold (sext (truncate (srl (load x), c))) -> (sext (smaller load (x+c/n)))
13607     if (SDValue NarrowLoad = reduceLoadWidth(N0.getNode())) {
13608       SDNode *oye = N0.getOperand(0).getNode();
13609       if (NarrowLoad.getNode() != N0.getNode()) {
13610         CombineTo(N0.getNode(), NarrowLoad);
13611         // CombineTo deleted the truncate, if needed, but not what's under it.
13612         AddToWorklist(oye);
13613       }
13614       return SDValue(N, 0);   // Return N so it doesn't get rechecked!
13615     }
13616 
13617     // See if the value being truncated is already sign extended.  If so, just
13618     // eliminate the trunc/sext pair.
13619     SDValue Op = N0.getOperand(0);
13620     unsigned OpBits   = Op.getScalarValueSizeInBits();
13621     unsigned MidBits  = N0.getScalarValueSizeInBits();
13622     unsigned DestBits = VT.getScalarSizeInBits();
13623     unsigned NumSignBits = DAG.ComputeNumSignBits(Op);
13624 
13625     if (OpBits == DestBits) {
13626       // Op is i32, Mid is i8, and Dest is i32.  If Op has more than 24 sign
13627       // bits, it is already ready.
13628       if (NumSignBits > DestBits-MidBits)
13629         return Op;
13630     } else if (OpBits < DestBits) {
13631       // Op is i32, Mid is i8, and Dest is i64.  If Op has more than 24 sign
13632       // bits, just sext from i32.
13633       if (NumSignBits > OpBits-MidBits)
13634         return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Op);
13635     } else {
13636       // Op is i64, Mid is i8, and Dest is i32.  If Op has more than 56 sign
13637       // bits, just truncate to i32.
13638       if (NumSignBits > OpBits-MidBits)
13639         return DAG.getNode(ISD::TRUNCATE, DL, VT, Op);
13640     }
13641 
13642     // fold (sext (truncate x)) -> (sextinreg x).
13643     if (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND_INREG,
13644                                                  N0.getValueType())) {
13645       if (OpBits < DestBits)
13646         Op = DAG.getNode(ISD::ANY_EXTEND, SDLoc(N0), VT, Op);
13647       else if (OpBits > DestBits)
13648         Op = DAG.getNode(ISD::TRUNCATE, SDLoc(N0), VT, Op);
13649       return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, Op,
13650                          DAG.getValueType(N0.getValueType()));
13651     }
13652   }
13653 
13654   // Try to simplify (sext (load x)).
13655   if (SDValue foldedExt =
13656           tryToFoldExtOfLoad(DAG, *this, TLI, VT, LegalOperations, N, N0,
13657                              ISD::SEXTLOAD, ISD::SIGN_EXTEND))
13658     return foldedExt;
13659 
13660   if (SDValue foldedExt =
13661           tryToFoldExtOfMaskedLoad(DAG, TLI, VT, LegalOperations, N, N0,
13662                                    ISD::SEXTLOAD, ISD::SIGN_EXTEND))
13663     return foldedExt;
13664 
13665   // fold (sext (load x)) to multiple smaller sextloads.
13666   // Only on illegal but splittable vectors.
13667   if (SDValue ExtLoad = CombineExtLoad(N))
13668     return ExtLoad;
13669 
13670   // Try to simplify (sext (sextload x)).
13671   if (SDValue foldedExt = tryToFoldExtOfExtload(
13672           DAG, *this, TLI, VT, LegalOperations, N, N0, ISD::SEXTLOAD))
13673     return foldedExt;
13674 
13675   // Try to simplify (sext (atomic_load x)).
13676   if (SDValue foldedExt =
13677           tryToFoldExtOfAtomicLoad(DAG, TLI, VT, N0, ISD::SEXTLOAD))
13678     return foldedExt;
13679 
13680   // fold (sext (and/or/xor (load x), cst)) ->
13681   //      (and/or/xor (sextload x), (sext cst))
13682   if (ISD::isBitwiseLogicOp(N0.getOpcode()) &&
13683       isa<LoadSDNode>(N0.getOperand(0)) &&
13684       N0.getOperand(1).getOpcode() == ISD::Constant &&
13685       (!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) {
13686     LoadSDNode *LN00 = cast<LoadSDNode>(N0.getOperand(0));
13687     EVT MemVT = LN00->getMemoryVT();
13688     if (TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, MemVT) &&
13689       LN00->getExtensionType() != ISD::ZEXTLOAD && LN00->isUnindexed()) {
13690       SmallVector<SDNode*, 4> SetCCs;
13691       bool DoXform = ExtendUsesToFormExtLoad(VT, N0.getNode(), N0.getOperand(0),
13692                                              ISD::SIGN_EXTEND, SetCCs, TLI);
13693       if (DoXform) {
13694         SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(LN00), VT,
13695                                          LN00->getChain(), LN00->getBasePtr(),
13696                                          LN00->getMemoryVT(),
13697                                          LN00->getMemOperand());
13698         APInt Mask = N0.getConstantOperandAPInt(1).sext(VT.getSizeInBits());
13699         SDValue And = DAG.getNode(N0.getOpcode(), DL, VT,
13700                                   ExtLoad, DAG.getConstant(Mask, DL, VT));
13701         ExtendSetCCUses(SetCCs, N0.getOperand(0), ExtLoad, ISD::SIGN_EXTEND);
13702         bool NoReplaceTruncAnd = !N0.hasOneUse();
13703         bool NoReplaceTrunc = SDValue(LN00, 0).hasOneUse();
13704         CombineTo(N, And);
13705         // If N0 has multiple uses, change other uses as well.
13706         if (NoReplaceTruncAnd) {
13707           SDValue TruncAnd =
13708               DAG.getNode(ISD::TRUNCATE, DL, N0.getValueType(), And);
13709           CombineTo(N0.getNode(), TruncAnd);
13710         }
13711         if (NoReplaceTrunc) {
13712           DAG.ReplaceAllUsesOfValueWith(SDValue(LN00, 1), ExtLoad.getValue(1));
13713         } else {
13714           SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(LN00),
13715                                       LN00->getValueType(0), ExtLoad);
13716           CombineTo(LN00, Trunc, ExtLoad.getValue(1));
13717         }
13718         return SDValue(N,0); // Return N so it doesn't get rechecked!
13719       }
13720     }
13721   }
13722 
13723   if (SDValue V = foldExtendedSignBitTest(N, DAG, LegalOperations))
13724     return V;
13725 
13726   if (SDValue V = foldSextSetcc(N))
13727     return V;
13728 
13729   // fold (sext x) -> (zext x) if the sign bit is known zero.
13730   if (!TLI.isSExtCheaperThanZExt(N0.getValueType(), VT) &&
13731       (!LegalOperations || TLI.isOperationLegal(ISD::ZERO_EXTEND, VT)) &&
13732       DAG.SignBitIsZero(N0)) {
13733     SDNodeFlags Flags;
13734     Flags.setNonNeg(true);
13735     return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0, Flags);
13736   }
13737 
13738   if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
13739     return NewVSel;
13740 
13741   // Eliminate this sign extend by doing a negation in the destination type:
13742   // sext i32 (0 - (zext i8 X to i32)) to i64 --> 0 - (zext i8 X to i64)
13743   if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
13744       isNullOrNullSplat(N0.getOperand(0)) &&
13745       N0.getOperand(1).getOpcode() == ISD::ZERO_EXTEND &&
13746       TLI.isOperationLegalOrCustom(ISD::SUB, VT)) {
13747     SDValue Zext = DAG.getZExtOrTrunc(N0.getOperand(1).getOperand(0), DL, VT);
13748     return DAG.getNegative(Zext, DL, VT);
13749   }
13750   // Eliminate this sign extend by doing a decrement in the destination type:
13751   // sext i32 ((zext i8 X to i32) + (-1)) to i64 --> (zext i8 X to i64) + (-1)
13752   if (N0.getOpcode() == ISD::ADD && N0.hasOneUse() &&
13753       isAllOnesOrAllOnesSplat(N0.getOperand(1)) &&
13754       N0.getOperand(0).getOpcode() == ISD::ZERO_EXTEND &&
13755       TLI.isOperationLegalOrCustom(ISD::ADD, VT)) {
13756     SDValue Zext = DAG.getZExtOrTrunc(N0.getOperand(0).getOperand(0), DL, VT);
13757     return DAG.getNode(ISD::ADD, DL, VT, Zext, DAG.getAllOnesConstant(DL, VT));
13758   }
13759 
13760   // fold sext (not i1 X) -> add (zext i1 X), -1
13761   // TODO: This could be extended to handle bool vectors.
13762   if (N0.getValueType() == MVT::i1 && isBitwiseNot(N0) && N0.hasOneUse() &&
13763       (!LegalOperations || (TLI.isOperationLegal(ISD::ZERO_EXTEND, VT) &&
13764                             TLI.isOperationLegal(ISD::ADD, VT)))) {
13765     // If we can eliminate the 'not', the sext form should be better
13766     if (SDValue NewXor = visitXOR(N0.getNode())) {
13767       // Returning N0 is a form of in-visit replacement that may have
13768       // invalidated N0.
13769       if (NewXor.getNode() == N0.getNode()) {
13770         // Return SDValue here as the xor should have already been replaced in
13771         // this sext.
13772         return SDValue();
13773       }
13774 
13775       // Return a new sext with the new xor.
13776       return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, NewXor);
13777     }
13778 
13779     SDValue Zext = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0));
13780     return DAG.getNode(ISD::ADD, DL, VT, Zext, DAG.getAllOnesConstant(DL, VT));
13781   }
13782 
13783   if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG, DL, Level))
13784     return Res;
13785 
13786   return SDValue();
13787 }
13788 
13789 /// Given an extending node with a pop-count operand, if the target does not
13790 /// support a pop-count in the narrow source type but does support it in the
13791 /// destination type, widen the pop-count to the destination type.
13792 static SDValue widenCtPop(SDNode *Extend, SelectionDAG &DAG, const SDLoc &DL) {
13793   assert((Extend->getOpcode() == ISD::ZERO_EXTEND ||
13794           Extend->getOpcode() == ISD::ANY_EXTEND) &&
13795          "Expected extend op");
13796 
13797   SDValue CtPop = Extend->getOperand(0);
13798   if (CtPop.getOpcode() != ISD::CTPOP || !CtPop.hasOneUse())
13799     return SDValue();
13800 
13801   EVT VT = Extend->getValueType(0);
13802   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
13803   if (TLI.isOperationLegalOrCustom(ISD::CTPOP, CtPop.getValueType()) ||
13804       !TLI.isOperationLegalOrCustom(ISD::CTPOP, VT))
13805     return SDValue();
13806 
13807   // zext (ctpop X) --> ctpop (zext X)
13808   SDValue NewZext = DAG.getZExtOrTrunc(CtPop.getOperand(0), DL, VT);
13809   return DAG.getNode(ISD::CTPOP, DL, VT, NewZext);
13810 }
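// E.g. (hypothetical types, assuming CTPOP is only legal/custom at i32):
//   (i32 (zext (i16 (ctpop X)))) --> (i32 (ctpop (zext X)))
// which is safe because zero-extending X introduces no set bits.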
13811 
13812 // If we have (zext (abs X)) where X is a type that will be promoted by type
13813 // legalization, convert to (abs (sext X)). But don't extend past a legal type.
13814 static SDValue widenAbs(SDNode *Extend, SelectionDAG &DAG) {
13815   assert(Extend->getOpcode() == ISD::ZERO_EXTEND && "Expected zero extend.");
13816 
13817   EVT VT = Extend->getValueType(0);
13818   if (VT.isVector())
13819     return SDValue();
13820 
13821   SDValue Abs = Extend->getOperand(0);
13822   if (Abs.getOpcode() != ISD::ABS || !Abs.hasOneUse())
13823     return SDValue();
13824 
13825   EVT AbsVT = Abs.getValueType();
13826   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
13827   if (TLI.getTypeAction(*DAG.getContext(), AbsVT) !=
13828       TargetLowering::TypePromoteInteger)
13829     return SDValue();
13830 
13831   EVT LegalVT = TLI.getTypeToTransformTo(*DAG.getContext(), AbsVT);
13832 
13833   SDValue SExt =
13834       DAG.getNode(ISD::SIGN_EXTEND, SDLoc(Abs), LegalVT, Abs.getOperand(0));
13835   SDValue NewAbs = DAG.getNode(ISD::ABS, SDLoc(Abs), LegalVT, SExt);
13836   return DAG.getZExtOrTrunc(NewAbs, SDLoc(Extend), VT);
13837 }
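// E.g. (hypothetical types, assuming i16 promotes to i32):
//   (i32 (zext (i16 (abs X)))) --> (i32 (abs (sext X to i32)))
// sign-extending first preserves the result for every i16 value, including
// the wrapping abs of INT16_MIN (both sides yield 32768).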
13838 
13839 SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
13840   SDValue N0 = N->getOperand(0);
13841   EVT VT = N->getValueType(0);
13842   SDLoc DL(N);
13843 
13844   if (VT.isVector())
13845     if (SDValue FoldedVOp = SimplifyVCastOp(N, DL))
13846       return FoldedVOp;
13847 
13848   // zext(undef) = 0
13849   if (N0.isUndef())
13850     return DAG.getConstant(0, DL, VT);
13851 
13852   if (SDValue Res = tryToFoldExtendOfConstant(N, DL, TLI, DAG, LegalTypes))
13853     return Res;
13854 
13855   // fold (zext (zext x)) -> (zext x)
13856   // fold (zext (aext x)) -> (zext x)
13857   if (N0.getOpcode() == ISD::ZERO_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND) {
13858     SDNodeFlags Flags;
13859     if (N0.getOpcode() == ISD::ZERO_EXTEND)
13860       Flags.setNonNeg(N0->getFlags().hasNonNeg());
13861     return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0), Flags);
13862   }
13863 
13864   // fold (zext (aext_extend_vector_inreg x)) -> (zext_extend_vector_inreg x)
13865   // fold (zext (zext_extend_vector_inreg x)) -> (zext_extend_vector_inreg x)
13866   if (N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG ||
13867       N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG)
13868     return DAG.getNode(ISD::ZERO_EXTEND_VECTOR_INREG, DL, VT, N0.getOperand(0));
13869 
13870   // fold (zext (truncate x)) -> (zext x) or
13871   //      (zext (truncate x)) -> (truncate x)
13872   // This is valid when the truncated bits of x are already zero.
13873   SDValue Op;
13874   KnownBits Known;
13875   if (isTruncateOf(DAG, N0, Op, Known)) {
13876     APInt TruncatedBits =
13877       (Op.getScalarValueSizeInBits() == N0.getScalarValueSizeInBits()) ?
13878       APInt(Op.getScalarValueSizeInBits(), 0) :
13879       APInt::getBitsSet(Op.getScalarValueSizeInBits(),
13880                         N0.getScalarValueSizeInBits(),
13881                         std::min(Op.getScalarValueSizeInBits(),
13882                                  VT.getScalarSizeInBits()));
13883     if (TruncatedBits.isSubsetOf(Known.Zero)) {
13884       SDValue ZExtOrTrunc = DAG.getZExtOrTrunc(Op, DL, VT);
13885       DAG.salvageDebugInfo(*N0.getNode());
13886 
13887       return ZExtOrTrunc;
13888     }
13889   }
13890 
13891   // fold (zext (truncate x)) -> (and x, mask)
13892   if (N0.getOpcode() == ISD::TRUNCATE) {
13893     // fold (zext (truncate (load x))) -> (zext (smaller load x))
13894     // fold (zext (truncate (srl (load x), c))) -> (zext (smaller load (x+c/n)))
13895     if (SDValue NarrowLoad = reduceLoadWidth(N0.getNode())) {
13896       SDNode *oye = N0.getOperand(0).getNode();
13897       if (NarrowLoad.getNode() != N0.getNode()) {
13898         CombineTo(N0.getNode(), NarrowLoad);
13899         // CombineTo deleted the truncate, if needed, but not what's under it.
13900         AddToWorklist(oye);
13901       }
13902       return SDValue(N, 0); // Return N so it doesn't get rechecked!
13903     }
13904 
13905     EVT SrcVT = N0.getOperand(0).getValueType();
13906     EVT MinVT = N0.getValueType();
13907 
13908     if (N->getFlags().hasNonNeg()) {
13909       SDValue Op = N0.getOperand(0);
13910       unsigned OpBits = SrcVT.getScalarSizeInBits();
13911       unsigned MidBits = MinVT.getScalarSizeInBits();
13912       unsigned DestBits = VT.getScalarSizeInBits();
13913       unsigned NumSignBits = DAG.ComputeNumSignBits(Op);
13914 
13915       if (OpBits == DestBits) {
13916         // Op is i32, Mid is i8, and Dest is i32.  If Op has more than 24 sign
13917         // bits, the trunc+zext leaves the value unchanged, so use Op directly.
13918         if (NumSignBits > DestBits - MidBits)
13919           return Op;
13920       } else if (OpBits < DestBits) {
13921         // Op is i32, Mid is i8, and Dest is i64.  If Op has more than 24 sign
13922         // bits, just sext from i32.
13923         // FIXME: This can probably be ZERO_EXTEND nneg?
13924         if (NumSignBits > OpBits - MidBits)
13925           return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Op);
13926       } else {
13927         // Op is i64, Mid is i8, and Dest is i32.  If Op has more than 56 sign
13928         // bits, just truncate to i32.
13929         if (NumSignBits > OpBits - MidBits)
13930           return DAG.getNode(ISD::TRUNCATE, DL, VT, Op);
13931       }
13932     }
13933 
13934     // Try to mask before the extension to avoid having to generate a larger mask,
13935     // possibly over several sub-vectors.
13936     if (SrcVT.bitsLT(VT) && VT.isVector()) {
13937       if (!LegalOperations || (TLI.isOperationLegal(ISD::AND, SrcVT) &&
13938                                TLI.isOperationLegal(ISD::ZERO_EXTEND, VT))) {
13939         SDValue Op = N0.getOperand(0);
13940         Op = DAG.getZeroExtendInReg(Op, DL, MinVT);
13941         AddToWorklist(Op.getNode());
13942         SDValue ZExtOrTrunc = DAG.getZExtOrTrunc(Op, DL, VT);
13943         // Transfer the debug info; the new node is equivalent to N0.
13944         DAG.transferDbgValues(N0, ZExtOrTrunc);
13945         return ZExtOrTrunc;
13946       }
13947     }
13948 
13949     if (!LegalOperations || TLI.isOperationLegal(ISD::AND, VT)) {
13950       SDValue Op = DAG.getAnyExtOrTrunc(N0.getOperand(0), DL, VT);
13951       AddToWorklist(Op.getNode());
13952       SDValue And = DAG.getZeroExtendInReg(Op, DL, MinVT);
13953       // We may safely transfer the debug info describing the truncate node over
13954       // to the equivalent and operation.
13955       DAG.transferDbgValues(N0, And);
13956       return And;
13957     }
13958   }
13959 
13960   // Fold (zext (and (trunc x), cst)) -> (and x, cst),
13961   // if either of the casts is not free.
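  // For example, with x : i64,
  //   (zext (and (trunc x to i32), 7) to i64) -> (and x, 7)
  // since the constant mask is simply zero-extended to the wide type.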
13962   if (N0.getOpcode() == ISD::AND &&
13963       N0.getOperand(0).getOpcode() == ISD::TRUNCATE &&
13964       N0.getOperand(1).getOpcode() == ISD::Constant &&
13965       (!TLI.isTruncateFree(N0.getOperand(0).getOperand(0), N0.getValueType()) ||
13966        !TLI.isZExtFree(N0.getValueType(), VT))) {
13967     SDValue X = N0.getOperand(0).getOperand(0);
13968     X = DAG.getAnyExtOrTrunc(X, SDLoc(X), VT);
13969     APInt Mask = N0.getConstantOperandAPInt(1).zext(VT.getSizeInBits());
13970     return DAG.getNode(ISD::AND, DL, VT,
13971                        X, DAG.getConstant(Mask, DL, VT));
13972   }
13973 
13974   // Try to simplify (zext (load x)).
13975   if (SDValue foldedExt = tryToFoldExtOfLoad(
13976           DAG, *this, TLI, VT, LegalOperations, N, N0, ISD::ZEXTLOAD,
13977           ISD::ZERO_EXTEND, N->getFlags().hasNonNeg()))
13978     return foldedExt;
13979 
13980   if (SDValue foldedExt =
13981           tryToFoldExtOfMaskedLoad(DAG, TLI, VT, LegalOperations, N, N0,
13982                                    ISD::ZEXTLOAD, ISD::ZERO_EXTEND))
13983     return foldedExt;
13984 
13985   // fold (zext (load x)) to multiple smaller zextloads.
13986   // Only on illegal but splittable vectors.
13987   if (SDValue ExtLoad = CombineExtLoad(N))
13988     return ExtLoad;
13989 
13990   // Try to simplify (zext (atomic_load x)).
13991   if (SDValue foldedExt =
13992           tryToFoldExtOfAtomicLoad(DAG, TLI, VT, N0, ISD::ZEXTLOAD))
13993     return foldedExt;
13994 
13995   // fold (zext (and/or/xor (load x), cst)) ->
13996   //      (and/or/xor (zextload x), (zext cst))
13997   // Unless (and (load x) cst) will match as a zextload already and has
13998   // additional users, or the zext is already free.
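  // For example,
  //   (zext (xor (load i16 p), 0x7F) to i32)
  //     -> (xor (zextload i16 p to i32), 0x7F)
  // which lets the extension fold into the load.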
13999   if (ISD::isBitwiseLogicOp(N0.getOpcode()) && !TLI.isZExtFree(N0, VT) &&
14000       isa<LoadSDNode>(N0.getOperand(0)) &&
14001       N0.getOperand(1).getOpcode() == ISD::Constant &&
14002       (!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) {
14003     LoadSDNode *LN00 = cast<LoadSDNode>(N0.getOperand(0));
14004     EVT MemVT = LN00->getMemoryVT();
14005     if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT) &&
14006         LN00->getExtensionType() != ISD::SEXTLOAD && LN00->isUnindexed()) {
14007       bool DoXform = true;
14008       SmallVector<SDNode*, 4> SetCCs;
14009       if (!N0.hasOneUse()) {
14010         if (N0.getOpcode() == ISD::AND) {
14011           auto *AndC = cast<ConstantSDNode>(N0.getOperand(1));
14012           EVT LoadResultTy = AndC->getValueType(0);
14013           EVT ExtVT;
14014           if (isAndLoadExtLoad(AndC, LN00, LoadResultTy, ExtVT))
14015             DoXform = false;
14016         }
14017       }
14018       if (DoXform)
14019         DoXform = ExtendUsesToFormExtLoad(VT, N0.getNode(), N0.getOperand(0),
14020                                           ISD::ZERO_EXTEND, SetCCs, TLI);
14021       if (DoXform) {
14022         SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(LN00), VT,
14023                                          LN00->getChain(), LN00->getBasePtr(),
14024                                          LN00->getMemoryVT(),
14025                                          LN00->getMemOperand());
14026         APInt Mask = N0.getConstantOperandAPInt(1).zext(VT.getSizeInBits());
14027         SDValue And = DAG.getNode(N0.getOpcode(), DL, VT,
14028                                   ExtLoad, DAG.getConstant(Mask, DL, VT));
14029         ExtendSetCCUses(SetCCs, N0.getOperand(0), ExtLoad, ISD::ZERO_EXTEND);
14030         bool NoReplaceTruncAnd = !N0.hasOneUse();
14031         bool NoReplaceTrunc = SDValue(LN00, 0).hasOneUse();
14032         CombineTo(N, And);
14033         // If N0 has multiple uses, change other uses as well.
14034         if (NoReplaceTruncAnd) {
14035           SDValue TruncAnd =
14036               DAG.getNode(ISD::TRUNCATE, DL, N0.getValueType(), And);
14037           CombineTo(N0.getNode(), TruncAnd);
14038         }
14039         if (NoReplaceTrunc) {
14040           DAG.ReplaceAllUsesOfValueWith(SDValue(LN00, 1), ExtLoad.getValue(1));
14041         } else {
14042           SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(LN00),
14043                                       LN00->getValueType(0), ExtLoad);
14044           CombineTo(LN00, Trunc, ExtLoad.getValue(1));
14045         }
14046         return SDValue(N,0); // Return N so it doesn't get rechecked!
14047       }
14048     }
14049   }
14050 
14051   // fold (zext (and/or/xor (shl/shr (load x), cst), cst)) ->
14052   //      (and/or/xor (shl/shr (zextload x), (zext cst)), (zext cst))
14053   if (SDValue ZExtLoad = CombineZExtLogicopShiftLoad(N))
14054     return ZExtLoad;
14055 
14056   // Try to simplify (zext (zextload x)).
14057   if (SDValue foldedExt = tryToFoldExtOfExtload(
14058           DAG, *this, TLI, VT, LegalOperations, N, N0, ISD::ZEXTLOAD))
14059     return foldedExt;
14060 
14061   if (SDValue V = foldExtendedSignBitTest(N, DAG, LegalOperations))
14062     return V;
14063 
14064   if (N0.getOpcode() == ISD::SETCC) {
14065     // Propagate fast-math-flags.
14066     SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags());
14067 
14068     // Only do this before legalize for now.
14069     if (!LegalOperations && VT.isVector() &&
14070         N0.getValueType().getVectorElementType() == MVT::i1) {
14071       EVT N00VT = N0.getOperand(0).getValueType();
14072       if (getSetCCResultType(N00VT) == N0.getValueType())
14073         return SDValue();
14074 
14075       // We know that the # elements of the result is the same as the #
14076       // elements of the compare (and the # elements of the compare result for
14077       // that matter). Check to see that they are the same size. If so, we know
14078       // that the element size of the zext'd result matches the element size of
14079       // the compare operands.
14080       if (VT.getSizeInBits() == N00VT.getSizeInBits()) {
14081         // zext(setcc) -> zext_in_reg(vsetcc) for vectors.
14082         SDValue VSetCC = DAG.getNode(ISD::SETCC, DL, VT, N0.getOperand(0),
14083                                      N0.getOperand(1), N0.getOperand(2));
14084         return DAG.getZeroExtendInReg(VSetCC, DL, N0.getValueType());
14085       }
14086 
14087       // If the desired elements are smaller or larger than the source
14088       // elements we can use a matching integer vector type and then
14089       // truncate/any extend followed by zext_in_reg.
14090       EVT MatchingVectorType = N00VT.changeVectorElementTypeToInteger();
14091       SDValue VsetCC =
14092           DAG.getNode(ISD::SETCC, DL, MatchingVectorType, N0.getOperand(0),
14093                       N0.getOperand(1), N0.getOperand(2));
14094       return DAG.getZeroExtendInReg(DAG.getAnyExtOrTrunc(VsetCC, DL, VT), DL,
14095                                     N0.getValueType());
14096     }
14097 
14098     // zext(setcc x,y,cc) -> zext(select x, y, true, false, cc)
14099     EVT N0VT = N0.getValueType();
14100     EVT N00VT = N0.getOperand(0).getValueType();
14101     if (SDValue SCC = SimplifySelectCC(
14102             DL, N0.getOperand(0), N0.getOperand(1),
14103             DAG.getBoolConstant(true, DL, N0VT, N00VT),
14104             DAG.getBoolConstant(false, DL, N0VT, N00VT),
14105             cast<CondCodeSDNode>(N0.getOperand(2))->get(), true))
14106       return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, SCC);
14107   }
14108 
14109   // (zext (shl (zext x), cst)) -> (shl (zext x), cst)
14110   if ((N0.getOpcode() == ISD::SHL || N0.getOpcode() == ISD::SRL) &&
14111       !TLI.isZExtFree(N0, VT)) {
14112     SDValue ShVal = N0.getOperand(0);
14113     SDValue ShAmt = N0.getOperand(1);
14114     if (auto *ShAmtC = dyn_cast<ConstantSDNode>(ShAmt)) {
14115       if (ShVal.getOpcode() == ISD::ZERO_EXTEND && N0.hasOneUse()) {
14116         if (N0.getOpcode() == ISD::SHL) {
14117           // If the original shl may be shifting out bits, do not perform this
14118           // transformation.
14119           unsigned KnownZeroBits = ShVal.getValueSizeInBits() -
14120                                    ShVal.getOperand(0).getValueSizeInBits();
14121           if (ShAmtC->getAPIntValue().ugt(KnownZeroBits)) {
14122             // If the shift is too large, then see if we can deduce that the
14123             // shift is safe anyway.
14124             // Create a mask that has ones for the bits being shifted out.
14125             APInt ShiftOutMask =
14126                 APInt::getHighBitsSet(ShVal.getValueSizeInBits(),
14127                                       ShAmtC->getAPIntValue().getZExtValue());
14128 
14129             // Check if the bits being shifted out are known to be zero.
14130             if (!DAG.MaskedValueIsZero(ShVal, ShiftOutMask))
14131               return SDValue();
14132           }
14133         }
14134 
14135         // Ensure that the shift amount is wide enough for the shifted value.
14136         if (Log2_32_Ceil(VT.getSizeInBits()) > ShAmt.getValueSizeInBits())
14137           ShAmt = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, ShAmt);
14138 
14139         return DAG.getNode(N0.getOpcode(), DL, VT,
14140                            DAG.getNode(ISD::ZERO_EXTEND, DL, VT, ShVal), ShAmt);
14141       }
14142     }
14143   }
14144 
14145   if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
14146     return NewVSel;
14147 
14148   if (SDValue NewCtPop = widenCtPop(N, DAG, DL))
14149     return NewCtPop;
14150 
14151   if (SDValue V = widenAbs(N, DAG))
14152     return V;
14153 
14154   if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG, DL, Level))
14155     return Res;
14156 
14157   // CSE zext nneg with sext if the zext is not free.
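  // The nneg flag guarantees that the operand is non-negative, in which case
  // sign- and zero-extension produce the same value, so an existing (sext x)
  // node can be reused instead of emitting a new zext.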
14158   if (N->getFlags().hasNonNeg() && !TLI.isZExtFree(N0.getValueType(), VT)) {
14159     SDNode *CSENode = DAG.getNodeIfExists(ISD::SIGN_EXTEND, N->getVTList(), N0);
14160     if (CSENode)
14161       return SDValue(CSENode, 0);
14162   }
14163 
14164   return SDValue();
14165 }
14166 
14167 SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) {
14168   SDValue N0 = N->getOperand(0);
14169   EVT VT = N->getValueType(0);
14170   SDLoc DL(N);
14171 
14172   // aext(undef) = undef
14173   if (N0.isUndef())
14174     return DAG.getUNDEF(VT);
14175 
14176   if (SDValue Res = tryToFoldExtendOfConstant(N, DL, TLI, DAG, LegalTypes))
14177     return Res;
14178 
14179   // fold (aext (aext x)) -> (aext x)
14180   // fold (aext (zext x)) -> (zext x)
14181   // fold (aext (sext x)) -> (sext x)
14182   if (N0.getOpcode() == ISD::ANY_EXTEND || N0.getOpcode() == ISD::ZERO_EXTEND ||
14183       N0.getOpcode() == ISD::SIGN_EXTEND) {
14184     SDNodeFlags Flags;
14185     if (N0.getOpcode() == ISD::ZERO_EXTEND)
14186       Flags.setNonNeg(N0->getFlags().hasNonNeg());
14187     return DAG.getNode(N0.getOpcode(), DL, VT, N0.getOperand(0), Flags);
14188   }
14189 
14190   // fold (aext (aext_extend_vector_inreg x)) -> (aext_extend_vector_inreg x)
14191   // fold (aext (zext_extend_vector_inreg x)) -> (zext_extend_vector_inreg x)
14192   // fold (aext (sext_extend_vector_inreg x)) -> (sext_extend_vector_inreg x)
14193   if (N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG ||
14194       N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG ||
14195       N0.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG)
14196     return DAG.getNode(N0.getOpcode(), DL, VT, N0.getOperand(0));
14197 
14198   // fold (aext (truncate (load x))) -> (aext (smaller load x))
14199   // fold (aext (truncate (srl (load x), c))) -> (aext (smaller load (x+c/n)))
14200   if (N0.getOpcode() == ISD::TRUNCATE) {
14201     if (SDValue NarrowLoad = reduceLoadWidth(N0.getNode())) {
14202       SDNode *oye = N0.getOperand(0).getNode();
14203       if (NarrowLoad.getNode() != N0.getNode()) {
14204         CombineTo(N0.getNode(), NarrowLoad);
14205         // CombineTo deleted the truncate, if needed, but not what's under it.
14206         AddToWorklist(oye);
14207       }
14208       return SDValue(N, 0);   // Return N so it doesn't get rechecked!
14209     }
14210   }
14211 
14212   // fold (aext (truncate x))
14213   if (N0.getOpcode() == ISD::TRUNCATE)
14214     return DAG.getAnyExtOrTrunc(N0.getOperand(0), DL, VT);
14215 
14216   // Fold (aext (and (trunc x), cst)) -> (and x, cst)
14217   // if the trunc is not free.
14218   if (N0.getOpcode() == ISD::AND &&
14219       N0.getOperand(0).getOpcode() == ISD::TRUNCATE &&
14220       N0.getOperand(1).getOpcode() == ISD::Constant &&
14221       !TLI.isTruncateFree(N0.getOperand(0).getOperand(0), N0.getValueType())) {
14222     SDValue X = DAG.getAnyExtOrTrunc(N0.getOperand(0).getOperand(0), DL, VT);
14223     SDValue Y = DAG.getNode(ISD::ANY_EXTEND, DL, VT, N0.getOperand(1));
14224     assert(isa<ConstantSDNode>(Y) && "Expected constant to be folded!");
14225     return DAG.getNode(ISD::AND, DL, VT, X, Y);
14226   }
14227 
14228   // fold (aext (load x)) -> (aext (truncate (extload x)))
14229   // None of the supported targets knows how to perform load and any_ext
14230   // on vectors in one instruction, so attempt to fold to zext instead.
14231   if (VT.isVector()) {
14232     // Try to simplify (zext (load x)).
14233     if (SDValue foldedExt =
14234             tryToFoldExtOfLoad(DAG, *this, TLI, VT, LegalOperations, N, N0,
14235                                ISD::ZEXTLOAD, ISD::ZERO_EXTEND))
14236       return foldedExt;
14237   } else if (ISD::isNON_EXTLoad(N0.getNode()) &&
14238              ISD::isUNINDEXEDLoad(N0.getNode()) &&
14239              TLI.isLoadExtLegal(ISD::EXTLOAD, VT, N0.getValueType())) {
14240     bool DoXform = true;
14241     SmallVector<SDNode *, 4> SetCCs;
14242     if (!N0.hasOneUse())
14243       DoXform =
14244           ExtendUsesToFormExtLoad(VT, N, N0, ISD::ANY_EXTEND, SetCCs, TLI);
14245     if (DoXform) {
14246       LoadSDNode *LN0 = cast<LoadSDNode>(N0);
14247       SDValue ExtLoad = DAG.getExtLoad(ISD::EXTLOAD, DL, VT, LN0->getChain(),
14248                                        LN0->getBasePtr(), N0.getValueType(),
14249                                        LN0->getMemOperand());
14250       ExtendSetCCUses(SetCCs, N0, ExtLoad, ISD::ANY_EXTEND);
14251       // If the load value is used only by N, replace it via CombineTo N.
14252       bool NoReplaceTrunc = N0.hasOneUse();
14253       CombineTo(N, ExtLoad);
14254       if (NoReplaceTrunc) {
14255         DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
14256         recursivelyDeleteUnusedNodes(LN0);
14257       } else {
14258         SDValue Trunc =
14259             DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), ExtLoad);
14260         CombineTo(LN0, Trunc, ExtLoad.getValue(1));
14261       }
14262       return SDValue(N, 0); // Return N so it doesn't get rechecked!
14263     }
14264   }
14265 
14266   // fold (aext (zextload x)) -> (aext (truncate (zextload x)))
14267   // fold (aext (sextload x)) -> (aext (truncate (sextload x)))
14268   // fold (aext ( extload x)) -> (aext (truncate (extload  x)))
14269   if (N0.getOpcode() == ISD::LOAD && !ISD::isNON_EXTLoad(N0.getNode()) &&
14270       ISD::isUNINDEXEDLoad(N0.getNode()) && N0.hasOneUse()) {
14271     LoadSDNode *LN0 = cast<LoadSDNode>(N0);
14272     ISD::LoadExtType ExtType = LN0->getExtensionType();
14273     EVT MemVT = LN0->getMemoryVT();
14274     if (!LegalOperations || TLI.isLoadExtLegal(ExtType, VT, MemVT)) {
14275       SDValue ExtLoad =
14276           DAG.getExtLoad(ExtType, DL, VT, LN0->getChain(), LN0->getBasePtr(),
14277                          MemVT, LN0->getMemOperand());
14278       CombineTo(N, ExtLoad);
14279       DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
14280       recursivelyDeleteUnusedNodes(LN0);
14281       return SDValue(N, 0);   // Return N so it doesn't get rechecked!
14282     }
14283   }
14284 
14285   if (N0.getOpcode() == ISD::SETCC) {
14286     // Propagate fast-math-flags.
14287     SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags());
14288 
14289     // For vectors:
14290     // aext(setcc) -> vsetcc
14291     // aext(setcc) -> truncate(vsetcc)
14292     // aext(setcc) -> aext(vsetcc)
14293     // Only do this before legalize for now.
14294     if (VT.isVector() && !LegalOperations) {
14295       EVT N00VT = N0.getOperand(0).getValueType();
14296       if (getSetCCResultType(N00VT) == N0.getValueType())
14297         return SDValue();
14298 
14299       // We know that the # elements of the result is the same as the
14300       // # elements of the compare (and the # elements of the compare result
14301       // for that matter).  Check to see that they are the same size.  If so,
14302       // we know that the element size of the extended result matches the
14303       // element size of the compare operands.
14304       if (VT.getSizeInBits() == N00VT.getSizeInBits())
14305         return DAG.getSetCC(DL, VT, N0.getOperand(0), N0.getOperand(1),
14306                             cast<CondCodeSDNode>(N0.getOperand(2))->get());
14307 
14308       // If the desired elements are smaller or larger than the source
14309       // elements we can use a matching integer vector type and then
14310       // truncate/any extend
14311       EVT MatchingVectorType = N00VT.changeVectorElementTypeToInteger();
14312       SDValue VsetCC = DAG.getSetCC(
14313           DL, MatchingVectorType, N0.getOperand(0), N0.getOperand(1),
14314           cast<CondCodeSDNode>(N0.getOperand(2))->get());
14315       return DAG.getAnyExtOrTrunc(VsetCC, DL, VT);
14316     }
14317 
14318     // aext(setcc x,y,cc) -> select_cc x, y, 1, 0, cc
14319     if (SDValue SCC = SimplifySelectCC(
14320             DL, N0.getOperand(0), N0.getOperand(1), DAG.getConstant(1, DL, VT),
14321             DAG.getConstant(0, DL, VT),
14322             cast<CondCodeSDNode>(N0.getOperand(2))->get(), true))
14323       return SCC;
14324   }
14325 
14326   if (SDValue NewCtPop = widenCtPop(N, DAG, DL))
14327     return NewCtPop;
14328 
14329   if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG, DL, Level))
14330     return Res;
14331 
14332   return SDValue();
14333 }
14334 
14335 SDValue DAGCombiner::visitAssertExt(SDNode *N) {
14336   unsigned Opcode = N->getOpcode();
14337   SDValue N0 = N->getOperand(0);
14338   SDValue N1 = N->getOperand(1);
14339   EVT AssertVT = cast<VTSDNode>(N1)->getVT();
14340 
14341   // fold (assert?ext (assert?ext x, vt), vt) -> (assert?ext x, vt)
14342   if (N0.getOpcode() == Opcode &&
14343       AssertVT == cast<VTSDNode>(N0.getOperand(1))->getVT())
14344     return N0;
14345 
14346   if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() &&
14347       N0.getOperand(0).getOpcode() == Opcode) {
14348     // We have an assert, truncate, assert sandwich. Make one stronger assert
14349     // by applying the smaller of the two asserted types to the larger source
14350     // type. This eliminates the later assert:
14351     // assert (trunc (assert X, i8) to iN), i1 --> trunc (assert X, i1) to iN
14352     // assert (trunc (assert X, i1) to iN), i8 --> trunc (assert X, i1) to iN
14353     SDLoc DL(N);
14354     SDValue BigA = N0.getOperand(0);
14355     EVT BigA_AssertVT = cast<VTSDNode>(BigA.getOperand(1))->getVT();
14356     EVT MinAssertVT = AssertVT.bitsLT(BigA_AssertVT) ? AssertVT : BigA_AssertVT;
14357     SDValue MinAssertVTVal = DAG.getValueType(MinAssertVT);
14358     SDValue NewAssert = DAG.getNode(Opcode, DL, BigA.getValueType(),
14359                                     BigA.getOperand(0), MinAssertVTVal);
14360     return DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), NewAssert);
14361   }
14362 
14363   // If we have (AssertZext (truncate (AssertSext X, iX)), iY) and Y is smaller
14364   // than X, just move the AssertZext in front of the truncate and drop the
14365   // AssertSext.
14366   if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() &&
14367       N0.getOperand(0).getOpcode() == ISD::AssertSext &&
14368       Opcode == ISD::AssertZext) {
14369     SDValue BigA = N0.getOperand(0);
14370     EVT BigA_AssertVT = cast<VTSDNode>(BigA.getOperand(1))->getVT();
14371     if (AssertVT.bitsLT(BigA_AssertVT)) {
14372       SDLoc DL(N);
14373       SDValue NewAssert = DAG.getNode(Opcode, DL, BigA.getValueType(),
14374                                       BigA.getOperand(0), N1);
14375       return DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), NewAssert);
14376     }
14377   }
14378 
14379   return SDValue();
14380 }
14381 
14382 SDValue DAGCombiner::visitAssertAlign(SDNode *N) {
14383   SDLoc DL(N);
14384 
14385   Align AL = cast<AssertAlignSDNode>(N)->getAlign();
14386   SDValue N0 = N->getOperand(0);
14387 
14388   // Fold (assertalign (assertalign x, AL0), AL1) ->
14389   // (assertalign x, max(AL0, AL1))
14390   if (auto *AAN = dyn_cast<AssertAlignSDNode>(N0))
14391     return DAG.getAssertAlign(DL, N0.getOperand(0),
14392                               std::max(AL, AAN->getAlign()));
14393 
14394   // In rare cases, there are trivial arithmetic ops in source operands. Sink
14395   // this assert down to the source operands so that those arithmetic ops can
14396   // be exposed to DAG combining.
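  // For example, with AL = 8 and N0 = (add p, 16), the constant 16 already has
  // at least three known trailing zero bits, so the node can be rewritten as
  //   (add (assertalign p, 8), 16)
  // exposing the add to further combines.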
14397   switch (N0.getOpcode()) {
14398   default:
14399     break;
14400   case ISD::ADD:
14401   case ISD::SUB: {
14402     unsigned AlignShift = Log2(AL);
14403     SDValue LHS = N0.getOperand(0);
14404     SDValue RHS = N0.getOperand(1);
14405     unsigned LHSAlignShift = DAG.computeKnownBits(LHS).countMinTrailingZeros();
14406     unsigned RHSAlignShift = DAG.computeKnownBits(RHS).countMinTrailingZeros();
14407     if (LHSAlignShift >= AlignShift || RHSAlignShift >= AlignShift) {
14408       if (LHSAlignShift < AlignShift)
14409         LHS = DAG.getAssertAlign(DL, LHS, AL);
14410       if (RHSAlignShift < AlignShift)
14411         RHS = DAG.getAssertAlign(DL, RHS, AL);
14412       return DAG.getNode(N0.getOpcode(), DL, N0.getValueType(), LHS, RHS);
14413     }
14414     break;
14415   }
14416   }
14417 
14418   return SDValue();
14419 }
14420 
14421 /// If the result of a load is shifted/masked/truncated to an effectively
14422 /// narrower type, try to transform the load to a narrower type and/or
14423 /// use an extending load.
14424 SDValue DAGCombiner::reduceLoadWidth(SDNode *N) {
14425   unsigned Opc = N->getOpcode();
14426 
14427   ISD::LoadExtType ExtType = ISD::NON_EXTLOAD;
14428   SDValue N0 = N->getOperand(0);
14429   EVT VT = N->getValueType(0);
14430   EVT ExtVT = VT;
14431 
14432   // This transformation isn't valid for vector loads.
14433   if (VT.isVector())
14434     return SDValue();
14435 
14436   // The ShAmt variable is used to indicate that we've consumed a right
14437   // shift. I.e. we want to narrow the width of the load by skipping the load
14438   // of the ShAmt least significant bits.
14439   unsigned ShAmt = 0;
14440   // A special case is when the least significant bits from the load are masked
14441   // away, but using an AND rather than a right shift. ShiftedOffset is used to
14442   // indicate that the narrowed load should be left-shifted ShiftedOffset bits
14443   // to get the result.
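  // For example, on a little-endian target,
  //   (and (load i32 p), 0xFF00)
  // can become
  //   (shl (zextload i8 (p + 1) to i32), 8)
  // where ShAmt = 8 selects the second byte and ShiftedOffset = 8 shifts the
  // narrowed load back into position.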
14444   unsigned ShiftedOffset = 0;
14445   // Special case: SIGN_EXTEND_INREG is basically truncating to ExtVT then
14446   // sign-extending back to VT.
14447   if (Opc == ISD::SIGN_EXTEND_INREG) {
14448     ExtType = ISD::SEXTLOAD;
14449     ExtVT = cast<VTSDNode>(N->getOperand(1))->getVT();
14450   } else if (Opc == ISD::SRL || Opc == ISD::SRA) {
14451     // Another special-case: SRL/SRA is basically zero/sign-extending a narrower
14452     // value, or it may be shifting a higher subword, half or byte into the
14453     // lowest bits.
14454 
14455     // Only handle shift with constant shift amount, and the shiftee must be a
14456     // load.
14457     auto *LN = dyn_cast<LoadSDNode>(N0);
14458     auto *N1C = dyn_cast<ConstantSDNode>(N->getOperand(1));
14459     if (!N1C || !LN)
14460       return SDValue();
14461     // If the shift amount is larger than the memory type then we're not
14462     // accessing any of the loaded bytes.
14463     ShAmt = N1C->getZExtValue();
14464     uint64_t MemoryWidth = LN->getMemoryVT().getScalarSizeInBits();
14465     if (MemoryWidth <= ShAmt)
14466       return SDValue();
14467     // Attempt to fold away the SRL by using ZEXTLOAD and SRA by using SEXTLOAD.
14468     ExtType = Opc == ISD::SRL ? ISD::ZEXTLOAD : ISD::SEXTLOAD;
14469     ExtVT = EVT::getIntegerVT(*DAG.getContext(), MemoryWidth - ShAmt);
14470     // If the original load is a SEXTLOAD then we can't simply replace it by a
14471     // ZEXTLOAD (we could potentially replace it by a narrower SEXTLOAD followed
14472     // by a ZEXT, but that is not handled at the moment). Similarly if the
14473     // original load is a ZEXTLOAD and we want to use a SEXTLOAD.
14474     if ((LN->getExtensionType() == ISD::SEXTLOAD ||
14475          LN->getExtensionType() == ISD::ZEXTLOAD) &&
14476         LN->getExtensionType() != ExtType)
14477       return SDValue();
14478   } else if (Opc == ISD::AND) {
14479     // An AND with a constant mask is the same as a truncate + zero-extend.
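    // For example, (and (load i32 p), 0xFFFF) behaves like
    // (zext (trunc (load i32 p) to i16) to i32), so the load can be narrowed
    // to a zextload of i16.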
14480     auto AndC = dyn_cast<ConstantSDNode>(N->getOperand(1));
14481     if (!AndC)
14482       return SDValue();
14483 
14484     const APInt &Mask = AndC->getAPIntValue();
14485     unsigned ActiveBits = 0;
14486     if (Mask.isMask()) {
14487       ActiveBits = Mask.countr_one();
14488     } else if (Mask.isShiftedMask(ShAmt, ActiveBits)) {
14489       ShiftedOffset = ShAmt;
14490     } else {
14491       return SDValue();
14492     }
14493 
14494     ExtType = ISD::ZEXTLOAD;
14495     ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
14496   }
14497 
14498   // In case Opc==SRL we've already prepared ExtVT/ExtType/ShAmt based on doing
14499   // a right shift. Here we redo some of those checks, to possibly adjust the
14500   // ExtVT even further based on "a masking AND". We could also end up here for
14501   // other reasons (e.g. based on Opc==TRUNCATE) and that is why some checks
14502   // need to be done here as well.
14503   if (Opc == ISD::SRL || N0.getOpcode() == ISD::SRL) {
14504     SDValue SRL = Opc == ISD::SRL ? SDValue(N, 0) : N0;
14505     // Bail out when the SRL has more than one use. This is done for historical
14506     // (undocumented) reasons. Maybe the intent was to guard the AND-masking
14507     // check below? And maybe it could be non-profitable to do the transform in
14508     // case the SRL has multiple uses and we get here with Opc!=ISD::SRL?
14509     // FIXME: Can't we just skip this check for the Opc==ISD::SRL case?
14510     if (!SRL.hasOneUse())
14511       return SDValue();
14512 
14513     // Only handle shift with constant shift amount, and the shiftee must be a
14514     // load.
14515     auto *LN = dyn_cast<LoadSDNode>(SRL.getOperand(0));
14516     auto *SRL1C = dyn_cast<ConstantSDNode>(SRL.getOperand(1));
14517     if (!SRL1C || !LN)
14518       return SDValue();
14519 
14520     // If the shift amount is larger than the input type then we're not
14521     // accessing any of the loaded bytes.  If the load was a zextload/extload
14522     // then the result of the shift+trunc is zero/undef (handled elsewhere).
14523     ShAmt = SRL1C->getZExtValue();
14524     uint64_t MemoryWidth = LN->getMemoryVT().getSizeInBits();
14525     if (ShAmt >= MemoryWidth)
14526       return SDValue();
14527 
14528     // Because a SRL must be assumed to *need* to zero-extend the high bits
14529     // (as opposed to anyext the high bits), we can't combine the zextload
14530     // lowering of SRL and an sextload.
14531     if (LN->getExtensionType() == ISD::SEXTLOAD)
14532       return SDValue();
14533 
14534     // Avoid reading outside the memory accessed by the original load (as could
14535     // happen if we only adjusted the load base pointer by ShAmt). Instead we
14536     // try to narrow the load even further. The typical scenario here is:
14537     //   (i64 (truncate (i96 (srl (load x), 64)))) ->
14538     //     (i64 (truncate (i96 (zextload (load i32 + offset) from i32))))
14539     if (ExtVT.getScalarSizeInBits() > MemoryWidth - ShAmt) {
14540       // Don't replace sextload by zextload.
14541       if (ExtType == ISD::SEXTLOAD)
14542         return SDValue();
14543       // Narrow the load.
14544       ExtType = ISD::ZEXTLOAD;
14545       ExtVT = EVT::getIntegerVT(*DAG.getContext(), MemoryWidth - ShAmt);
14546     }
14547 
14548     // If the SRL is only used by a masking AND, we may be able to adjust
14549     // the ExtVT to make the AND redundant.
14550     SDNode *Mask = *(SRL->use_begin());
14551     if (SRL.hasOneUse() && Mask->getOpcode() == ISD::AND &&
14552         isa<ConstantSDNode>(Mask->getOperand(1))) {
14553       unsigned Offset, ActiveBits;
14554       const APInt& ShiftMask = Mask->getConstantOperandAPInt(1);
14555       if (ShiftMask.isMask()) {
14556         EVT MaskedVT =
14557             EVT::getIntegerVT(*DAG.getContext(), ShiftMask.countr_one());
14558         // If the mask is smaller, recompute the type.
14559         if ((ExtVT.getScalarSizeInBits() > MaskedVT.getScalarSizeInBits()) &&
14560             TLI.isLoadExtLegal(ExtType, SRL.getValueType(), MaskedVT))
14561           ExtVT = MaskedVT;
14562       } else if (ExtType == ISD::ZEXTLOAD &&
14563                  ShiftMask.isShiftedMask(Offset, ActiveBits) &&
14564                  (Offset + ShAmt) < VT.getScalarSizeInBits()) {
14565         EVT MaskedVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
14566         // If the mask is shifted we can use a narrower load and a shl to insert
14567         // the trailing zeros.
14568         if (((Offset + ActiveBits) <= ExtVT.getScalarSizeInBits()) &&
14569             TLI.isLoadExtLegal(ExtType, SRL.getValueType(), MaskedVT)) {
14570           ExtVT = MaskedVT;
14571           ShAmt = Offset + ShAmt;
14572           ShiftedOffset = Offset;
14573         }
14574       }
14575     }
14576 
14577     N0 = SRL.getOperand(0);
14578   }
14579 
14580   // If the load is shifted left (and the result isn't shifted back right), we
14581   // can fold a truncate through the shift. The typical scenario is that N
14582   // points at a TRUNCATE here so the attempted fold is:
14583   //   (truncate (shl (load x), c))) -> (shl (narrow load x), c)
14584   // ShLeftAmt will indicate how much a narrowed load should be shifted left.
14585   unsigned ShLeftAmt = 0;
14586   if (ShAmt == 0 && N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
14587       ExtVT == VT && TLI.isNarrowingProfitable(N0.getValueType(), VT)) {
14588     if (ConstantSDNode *N01 = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
14589       ShLeftAmt = N01->getZExtValue();
14590       N0 = N0.getOperand(0);
14591     }
14592   }
14593 
14594   // If we haven't found a load, we can't narrow it.
14595   if (!isa<LoadSDNode>(N0))
14596     return SDValue();
14597 
14598   LoadSDNode *LN0 = cast<LoadSDNode>(N0);
14599   // Reducing the width of a volatile load is illegal.  For atomics, we may be
14600   // able to reduce the width provided we never widen again. (see D66309)
14601   if (!LN0->isSimple() ||
14602       !isLegalNarrowLdSt(LN0, ExtType, ExtVT, ShAmt))
14603     return SDValue();
14604 
14605   auto AdjustBigEndianShift = [&](unsigned ShAmt) {
14606     unsigned LVTStoreBits =
14607         LN0->getMemoryVT().getStoreSizeInBits().getFixedValue();
14608     unsigned EVTStoreBits = ExtVT.getStoreSizeInBits().getFixedValue();
14609     return LVTStoreBits - EVTStoreBits - ShAmt;
14610   };
14611 
14612   // We need to adjust the pointer to the load by ShAmt bits in order to load
14613   // the correct bytes.
14614   unsigned PtrAdjustmentInBits =
14615       DAG.getDataLayout().isBigEndian() ? AdjustBigEndianShift(ShAmt) : ShAmt;
14616 
14617   uint64_t PtrOff = PtrAdjustmentInBits / 8;
14618   SDLoc DL(LN0);
14619   // The original load itself didn't wrap, so an offset within it doesn't.
14620   SDNodeFlags Flags;
14621   Flags.setNoUnsignedWrap(true);
14622   SDValue NewPtr = DAG.getMemBasePlusOffset(
14623       LN0->getBasePtr(), TypeSize::getFixed(PtrOff), DL, Flags);
14624   AddToWorklist(NewPtr.getNode());
14625 
14626   SDValue Load;
14627   if (ExtType == ISD::NON_EXTLOAD)
14628     Load = DAG.getLoad(VT, DL, LN0->getChain(), NewPtr,
14629                        LN0->getPointerInfo().getWithOffset(PtrOff),
14630                        LN0->getOriginalAlign(),
14631                        LN0->getMemOperand()->getFlags(), LN0->getAAInfo());
14632   else
14633     Load = DAG.getExtLoad(ExtType, DL, VT, LN0->getChain(), NewPtr,
14634                           LN0->getPointerInfo().getWithOffset(PtrOff), ExtVT,
14635                           LN0->getOriginalAlign(),
14636                           LN0->getMemOperand()->getFlags(), LN0->getAAInfo());
14637 
14638   // Replace the old load's chain with the new load's chain.
14639   WorklistRemover DeadNodes(*this);
14640   DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), Load.getValue(1));
14641 
14642   // Shift the result left, if we've swallowed a left shift.
14643   SDValue Result = Load;
14644   if (ShLeftAmt != 0) {
14645     // If the shift amount is as large as the result size (but, presumably,
14646     // no larger than the source) then the useful bits of the result are
14647     // zero; we can't simply return the shortened shift, because the result
14648     // of that operation is undefined.
14649     if (ShLeftAmt >= VT.getScalarSizeInBits())
14650       Result = DAG.getConstant(0, DL, VT);
14651     else
14652       Result = DAG.getNode(ISD::SHL, DL, VT, Result,
14653                            DAG.getShiftAmountConstant(ShLeftAmt, VT, DL));
14654   }
14655 
14656   if (ShiftedOffset != 0) {
14657     // We're using a shifted mask, so the load now has an offset. This means
14658     // that data has been loaded into lower bytes than it would have been
14659     // before, so we need to shl the loaded data into the correct position in
14660     // the register.
14661     SDValue ShiftC = DAG.getConstant(ShiftedOffset, DL, VT);
14662     Result = DAG.getNode(ISD::SHL, DL, VT, Result, ShiftC);
14663     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result);
14664   }
14665 
14666   // Return the new loaded value.
14667   return Result;
14668 }
14669 
14670 SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) {
14671   SDValue N0 = N->getOperand(0);
14672   SDValue N1 = N->getOperand(1);
14673   EVT VT = N->getValueType(0);
14674   EVT ExtVT = cast<VTSDNode>(N1)->getVT();
14675   unsigned VTBits = VT.getScalarSizeInBits();
14676   unsigned ExtVTBits = ExtVT.getScalarSizeInBits();
14677 
14678   // sext_in_reg(undef) = 0 because the top bits will all be the same.
14679   if (N0.isUndef())
14680     return DAG.getConstant(0, SDLoc(N), VT);
14681 
14682   // fold (sext_in_reg c1) -> c1
14683   if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
14684     return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, N0, N1);
14685 
14686   // If the input is already sign extended, just drop the extension.
14687   if (ExtVTBits >= DAG.ComputeMaxSignificantBits(N0))
14688     return N0;
14689 
14690   // fold (sext_in_reg (sext_in_reg x, VT2), VT1) -> (sext_in_reg x, minVT) pt2
14691   if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG &&
14692       ExtVT.bitsLT(cast<VTSDNode>(N0.getOperand(1))->getVT()))
14693     return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, N0.getOperand(0),
14694                        N1);
14695 
14696   // fold (sext_in_reg (sext x)) -> (sext x)
14697   // fold (sext_in_reg (aext x)) -> (sext x)
14698   // if x is small enough or if we know that x has more than 1 sign bit and the
14699   // sign_extend_inreg is extending from one of them.
14700   if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND) {
14701     SDValue N00 = N0.getOperand(0);
14702     unsigned N00Bits = N00.getScalarValueSizeInBits();
14703     if ((N00Bits <= ExtVTBits ||
14704          DAG.ComputeMaxSignificantBits(N00) <= ExtVTBits) &&
14705         (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND, VT)))
14706       return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, N00);
14707   }
14708 
14709   // fold (sext_in_reg (*_extend_vector_inreg x)) -> (sext_vector_inreg x)
14710   // if x is small enough or if we know that x has more than 1 sign bit and the
14711   // sign_extend_inreg is extending from one of them.
14712   if (ISD::isExtVecInRegOpcode(N0.getOpcode())) {
14713     SDValue N00 = N0.getOperand(0);
14714     unsigned N00Bits = N00.getScalarValueSizeInBits();
14715     unsigned DstElts = N0.getValueType().getVectorMinNumElements();
14716     unsigned SrcElts = N00.getValueType().getVectorMinNumElements();
14717     bool IsZext = N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG;
14718     APInt DemandedSrcElts = APInt::getLowBitsSet(SrcElts, DstElts);
14719     if ((N00Bits == ExtVTBits ||
14720          (!IsZext && (N00Bits < ExtVTBits ||
14721                       DAG.ComputeMaxSignificantBits(N00) <= ExtVTBits))) &&
14722         (!LegalOperations ||
14723          TLI.isOperationLegal(ISD::SIGN_EXTEND_VECTOR_INREG, VT)))
14724       return DAG.getNode(ISD::SIGN_EXTEND_VECTOR_INREG, SDLoc(N), VT, N00);
14725   }
14726 
14727   // fold (sext_in_reg (zext x)) -> (sext x)
14728   // iff we are extending the source sign bit.
14729   if (N0.getOpcode() == ISD::ZERO_EXTEND) {
14730     SDValue N00 = N0.getOperand(0);
14731     if (N00.getScalarValueSizeInBits() == ExtVTBits &&
14732         (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND, VT)))
14733       return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, N00);
14734   }
14735 
14736   // fold (sext_in_reg x) -> (zext_in_reg x) if the sign bit is known zero.
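  // E.g. for (sext_in_reg x, i8): if bit 7 of x is known zero, the sign bit
  // being copied is zero, so zero-extending in-reg gives the same value.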
14737   if (DAG.MaskedValueIsZero(N0, APInt::getOneBitSet(VTBits, ExtVTBits - 1)))
14738     return DAG.getZeroExtendInReg(N0, SDLoc(N), ExtVT);
14739 
14740   // fold operands of sext_in_reg based on knowledge that the top bits are not
14741   // demanded.
14742   if (SimplifyDemandedBits(SDValue(N, 0)))
14743     return SDValue(N, 0);
14744 
14745   // fold (sext_in_reg (load x)) -> (smaller sextload x)
14746   // fold (sext_in_reg (srl (load x), c)) -> (smaller sextload (x+c/evtbits))
14747   if (SDValue NarrowLoad = reduceLoadWidth(N))
14748     return NarrowLoad;
14749 
14750   // fold (sext_in_reg (srl X, 24), i8) -> (sra X, 24)
14751   // fold (sext_in_reg (srl X, 23), i8) -> (sra X, 23) iff possible.
14752   // We already fold "(sext_in_reg (srl X, 25), i8) -> srl X, 25" above.
14753   if (N0.getOpcode() == ISD::SRL) {
14754     if (auto *ShAmt = dyn_cast<ConstantSDNode>(N0.getOperand(1)))
14755       if (ShAmt->getAPIntValue().ule(VTBits - ExtVTBits)) {
14756         // We can turn this into an SRA iff the input to the SRL is already sign
14757         // extended enough.
14758         unsigned InSignBits = DAG.ComputeNumSignBits(N0.getOperand(0));
14759         if (((VTBits - ExtVTBits) - ShAmt->getZExtValue()) < InSignBits)
14760           return DAG.getNode(ISD::SRA, SDLoc(N), VT, N0.getOperand(0),
14761                              N0.getOperand(1));
14762       }
14763   }
14764 
14765   // fold (sext_inreg (extload x)) -> (sextload x)
14766   // If sextload is not supported by target, we can only do the combine when
14767   // load has one use. Doing otherwise can block folding the extload with other
14768   // extends that the target does support.
14769   if (ISD::isEXTLoad(N0.getNode()) &&
14770       ISD::isUNINDEXEDLoad(N0.getNode()) &&
14771       ExtVT == cast<LoadSDNode>(N0)->getMemoryVT() &&
14772       ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple() &&
14773         N0.hasOneUse()) ||
14774        TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT))) {
14775     LoadSDNode *LN0 = cast<LoadSDNode>(N0);
14776     SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(N), VT,
14777                                      LN0->getChain(),
14778                                      LN0->getBasePtr(), ExtVT,
14779                                      LN0->getMemOperand());
14780     CombineTo(N, ExtLoad);
14781     CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
14782     AddToWorklist(ExtLoad.getNode());
14783     return SDValue(N, 0);   // Return N so it doesn't get rechecked!
14784   }
14785 
14786   // fold (sext_inreg (zextload x)) -> (sextload x) iff load has one use
14787   if (ISD::isZEXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) &&
14788       N0.hasOneUse() &&
14789       ExtVT == cast<LoadSDNode>(N0)->getMemoryVT() &&
14790       ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple()) &&
14791        TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT))) {
14792     LoadSDNode *LN0 = cast<LoadSDNode>(N0);
14793     SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(N), VT,
14794                                      LN0->getChain(),
14795                                      LN0->getBasePtr(), ExtVT,
14796                                      LN0->getMemOperand());
14797     CombineTo(N, ExtLoad);
14798     CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
14799     return SDValue(N, 0);   // Return N so it doesn't get rechecked!
14800   }
14801 
14802   // fold (sext_inreg (masked_load x)) -> (sext_masked_load x)
14803   // ignore it if the masked load is already sign extended
14804   if (MaskedLoadSDNode *Ld = dyn_cast<MaskedLoadSDNode>(N0)) {
14805     if (ExtVT == Ld->getMemoryVT() && N0.hasOneUse() &&
14806         Ld->getExtensionType() != ISD::LoadExtType::NON_EXTLOAD &&
14807         TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT)) {
14808       SDValue ExtMaskedLoad = DAG.getMaskedLoad(
14809           VT, SDLoc(N), Ld->getChain(), Ld->getBasePtr(), Ld->getOffset(),
14810           Ld->getMask(), Ld->getPassThru(), ExtVT, Ld->getMemOperand(),
14811           Ld->getAddressingMode(), ISD::SEXTLOAD, Ld->isExpandingLoad());
14812       CombineTo(N, ExtMaskedLoad);
14813       CombineTo(N0.getNode(), ExtMaskedLoad, ExtMaskedLoad.getValue(1));
14814       return SDValue(N, 0); // Return N so it doesn't get rechecked!
14815     }
14816   }
14817 
14818   // fold (sext_inreg (masked_gather x)) -> (sext_masked_gather x)
14819   if (auto *GN0 = dyn_cast<MaskedGatherSDNode>(N0)) {
14820     if (SDValue(GN0, 0).hasOneUse() &&
14821         ExtVT == GN0->getMemoryVT() &&
14822         TLI.isVectorLoadExtDesirable(SDValue(SDValue(GN0, 0)))) {
14823       SDValue Ops[] = {GN0->getChain(),   GN0->getPassThru(), GN0->getMask(),
14824                        GN0->getBasePtr(), GN0->getIndex(),    GN0->getScale()};
14825 
14826       SDValue ExtLoad = DAG.getMaskedGather(
14827           DAG.getVTList(VT, MVT::Other), ExtVT, SDLoc(N), Ops,
14828           GN0->getMemOperand(), GN0->getIndexType(), ISD::SEXTLOAD);
14829 
14830       CombineTo(N, ExtLoad);
14831       CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
14832       AddToWorklist(ExtLoad.getNode());
14833       return SDValue(N, 0); // Return N so it doesn't get rechecked!
14834     }
14835   }
14836 
14837   // Form (sext_inreg (bswap >> 16)) or (sext_inreg (rotl (bswap) 16))
14838   if (ExtVTBits <= 16 && N0.getOpcode() == ISD::OR) {
14839     if (SDValue BSwap = MatchBSwapHWordLow(N0.getNode(), N0.getOperand(0),
14840                                            N0.getOperand(1), false))
14841       return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, BSwap, N1);
14842   }
14843 
14844   // Fold (iM_signext_inreg
14845   //        (extract_subvector (zext|anyext|sext iN_v to _) _)
14846   //        from iN)
14847   //      -> (extract_subvector (signext iN_v to iM))
14848   if (N0.getOpcode() == ISD::EXTRACT_SUBVECTOR && N0.hasOneUse() &&
14849       ISD::isExtOpcode(N0.getOperand(0).getOpcode())) {
14850     SDValue InnerExt = N0.getOperand(0);
14851     EVT InnerExtVT = InnerExt->getValueType(0);
14852     SDValue Extendee = InnerExt->getOperand(0);
14853 
14854     if (ExtVTBits == Extendee.getValueType().getScalarSizeInBits() &&
14855         (!LegalOperations ||
14856          TLI.isOperationLegal(ISD::SIGN_EXTEND, InnerExtVT))) {
14857       SDValue SignExtExtendee =
14858           DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), InnerExtVT, Extendee);
14859       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), VT, SignExtExtendee,
14860                          N0.getOperand(1));
14861     }
14862   }
14863 
14864   return SDValue();
14865 }
14866 
14867 static SDValue foldExtendVectorInregToExtendOfSubvector(
14868     SDNode *N, const SDLoc &DL, const TargetLowering &TLI, SelectionDAG &DAG,
14869     bool LegalOperations) {
14870   unsigned InregOpcode = N->getOpcode();
14871   unsigned Opcode = DAG.getOpcode_EXTEND(InregOpcode);
14872 
14873   SDValue Src = N->getOperand(0);
14874   EVT VT = N->getValueType(0);
14875   EVT SrcVT = EVT::getVectorVT(*DAG.getContext(),
14876                                Src.getValueType().getVectorElementType(),
14877                                VT.getVectorElementCount());
14878 
14879   assert(ISD::isExtVecInRegOpcode(InregOpcode) &&
14880          "Expected EXTEND_VECTOR_INREG dag node in input!");
14881 
14882   // Profitability check: our operand must be a one-use CONCAT_VECTORS.
14883   // FIXME: one-use check may be overly restrictive
14884   if (!Src.hasOneUse() || Src.getOpcode() != ISD::CONCAT_VECTORS)
14885     return SDValue();
14886 
14887   // Profitability check: we must be extending exactly one of its operands.
14888   // FIXME: this is probably overly restrictive.
14889   Src = Src.getOperand(0);
14890   if (Src.getValueType() != SrcVT)
14891     return SDValue();
14892 
14893   if (LegalOperations && !TLI.isOperationLegal(Opcode, VT))
14894     return SDValue();
14895 
14896   return DAG.getNode(Opcode, DL, VT, Src);
14897 }
14898 
14899 SDValue DAGCombiner::visitEXTEND_VECTOR_INREG(SDNode *N) {
14900   SDValue N0 = N->getOperand(0);
14901   EVT VT = N->getValueType(0);
14902   SDLoc DL(N);
14903 
14904   if (N0.isUndef()) {
14905     // aext_vector_inreg(undef) = undef because the top bits are undefined.
14906     // {s/z}ext_vector_inreg(undef) = 0 because the top bits must be the same.
14907     return N->getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG
14908                ? DAG.getUNDEF(VT)
14909                : DAG.getConstant(0, DL, VT);
14910   }
14911 
14912   if (SDValue Res = tryToFoldExtendOfConstant(N, DL, TLI, DAG, LegalTypes))
14913     return Res;
14914 
14915   if (SimplifyDemandedVectorElts(SDValue(N, 0)))
14916     return SDValue(N, 0);
14917 
14918   if (SDValue R = foldExtendVectorInregToExtendOfSubvector(N, DL, TLI, DAG,
14919                                                            LegalOperations))
14920     return R;
14921 
14922   return SDValue();
14923 }
14924 
14925 SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
14926   SDValue N0 = N->getOperand(0);
14927   EVT VT = N->getValueType(0);
14928   EVT SrcVT = N0.getValueType();
14929   bool isLE = DAG.getDataLayout().isLittleEndian();
14930   SDLoc DL(N);
14931 
14932   // trunc(undef) = undef
14933   if (N0.isUndef())
14934     return DAG.getUNDEF(VT);
14935 
14936   // fold (truncate (truncate x)) -> (truncate x)
14937   if (N0.getOpcode() == ISD::TRUNCATE)
14938     return DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0));
14939 
14940   // fold (truncate c1) -> c1
14941   if (SDValue C = DAG.FoldConstantArithmetic(ISD::TRUNCATE, DL, VT, {N0}))
14942     return C;
14943 
14944   // fold (truncate (ext x)) -> (ext x) or (truncate x) or x
14945   if (N0.getOpcode() == ISD::ZERO_EXTEND ||
14946       N0.getOpcode() == ISD::SIGN_EXTEND ||
14947       N0.getOpcode() == ISD::ANY_EXTEND) {
14948     // if the source is smaller than the dest, we still need an extend.
14949     if (N0.getOperand(0).getValueType().bitsLT(VT))
14950       return DAG.getNode(N0.getOpcode(), DL, VT, N0.getOperand(0));
14951     // if the source is larger than the dest, then we just need the truncate.
14952     if (N0.getOperand(0).getValueType().bitsGT(VT))
14953       return DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0));
14954     // if the source and dest are the same type, we can drop both the extend
14955     // and the truncate.
14956     return N0.getOperand(0);
14957   }
14958 
14959   // Try to narrow a truncate-of-sext_in_reg to the destination type:
14960   // trunc (sign_ext_inreg X, iM) to iN --> sign_ext_inreg (trunc X to iN), iM
14961   if (!LegalTypes && N0.getOpcode() == ISD::SIGN_EXTEND_INREG &&
14962       N0.hasOneUse()) {
14963     SDValue X = N0.getOperand(0);
14964     SDValue ExtVal = N0.getOperand(1);
14965     EVT ExtVT = cast<VTSDNode>(ExtVal)->getVT();
14966     if (ExtVT.bitsLT(VT) && TLI.preferSextInRegOfTruncate(VT, SrcVT, ExtVT)) {
14967       SDValue TrX = DAG.getNode(ISD::TRUNCATE, DL, VT, X);
14968       return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, TrX, ExtVal);
14969     }
14970   }
14971 
14972   // If this is anyext(trunc), don't fold it, allow ourselves to be folded.
14973   if (N->hasOneUse() && (N->use_begin()->getOpcode() == ISD::ANY_EXTEND))
14974     return SDValue();
14975 
14976   // Fold extract-and-trunc into a narrow extract. For example:
14977   //   i64 x = EXTRACT_VECTOR_ELT(v2i64 val, i32 1)
14978   //   i32 y = TRUNCATE(i64 x)
14979   //        -- becomes --
14980   //   v16i8 b = BITCAST (v2i64 val)
14981   //   i8 x = EXTRACT_VECTOR_ELT(v16i8 b, i32 8)
14982   //
14983   // Note: We only run this optimization after type legalization (which often
14984   // creates this pattern) and before operation legalization after which
14985   // we need to be more careful about the vector instructions that we generate.
14986   if (N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
14987       LegalTypes && !LegalOperations && N0->hasOneUse() && VT != MVT::i1) {
14988     EVT VecTy = N0.getOperand(0).getValueType();
14989     EVT ExTy = N0.getValueType();
14990     EVT TrTy = N->getValueType(0);
14991 
14992     auto EltCnt = VecTy.getVectorElementCount();
14993     unsigned SizeRatio = ExTy.getSizeInBits()/TrTy.getSizeInBits();
14994     auto NewEltCnt = EltCnt * SizeRatio;
14995 
14996     EVT NVT = EVT::getVectorVT(*DAG.getContext(), TrTy, NewEltCnt);
14997     assert(NVT.getSizeInBits() == VecTy.getSizeInBits() && "Invalid Size");
14998 
14999     SDValue EltNo = N0->getOperand(1);
15000     if (isa<ConstantSDNode>(EltNo) && isTypeLegal(NVT)) {
15001       int Elt = EltNo->getAsZExtVal();
15002       int Index = isLE ? (Elt*SizeRatio) : (Elt*SizeRatio + (SizeRatio-1));
15003       return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, TrTy,
15004                          DAG.getBitcast(NVT, N0.getOperand(0)),
15005                          DAG.getVectorIdxConstant(Index, DL));
15006     }
15007   }
15008 
15009   // trunc (select c, a, b) -> select c, (trunc a), (trunc b)
15010   if (N0.getOpcode() == ISD::SELECT && N0.hasOneUse()) {
15011     if ((!LegalOperations || TLI.isOperationLegal(ISD::SELECT, SrcVT)) &&
15012         TLI.isTruncateFree(SrcVT, VT)) {
15013       SDLoc SL(N0);
15014       SDValue Cond = N0.getOperand(0);
15015       SDValue TruncOp0 = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(1));
15016       SDValue TruncOp1 = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(2));
15017       return DAG.getNode(ISD::SELECT, DL, VT, Cond, TruncOp0, TruncOp1);
15018     }
15019   }
15020 
15021   // trunc (shl x, K) -> shl (trunc x), K => K < VT.getScalarSizeInBits()
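  // For example, with x : i64,
  //   (trunc (shl x, 5) to i32) -> (shl (trunc x to i32), 5)
  // which is valid because the shift amount is known to be less than 32.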
15022   if (N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
15023       (!LegalOperations || TLI.isOperationLegal(ISD::SHL, VT)) &&
15024       TLI.isTypeDesirableForOp(ISD::SHL, VT)) {
15025     SDValue Amt = N0.getOperand(1);
15026     KnownBits Known = DAG.computeKnownBits(Amt);
15027     unsigned Size = VT.getScalarSizeInBits();
15028     if (Known.countMaxActiveBits() <= Log2_32(Size)) {
15029       EVT AmtVT = TLI.getShiftAmountTy(VT, DAG.getDataLayout());
15030       SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0));
15031       if (AmtVT != Amt.getValueType()) {
15032         Amt = DAG.getZExtOrTrunc(Amt, DL, AmtVT);
15033         AddToWorklist(Amt.getNode());
15034       }
15035       return DAG.getNode(ISD::SHL, DL, VT, Trunc, Amt);
15036     }
15037   }
15038 
15039   if (SDValue V = foldSubToUSubSat(VT, N0.getNode(), DL))
15040     return V;
15041 
15042   if (SDValue ABD = foldABSToABD(N, DL))
15043     return ABD;
15044 
15045   // Attempt to pre-truncate BUILD_VECTOR sources.
15046   if (N0.getOpcode() == ISD::BUILD_VECTOR && !LegalOperations &&
15047       N0.hasOneUse() &&
15048       TLI.isTruncateFree(SrcVT.getScalarType(), VT.getScalarType()) &&
15049       // Avoid creating illegal types if running after type legalizer.
15050       (!LegalTypes || TLI.isTypeLegal(VT.getScalarType()))) {
15051     EVT SVT = VT.getScalarType();
15052     SmallVector<SDValue, 8> TruncOps;
15053     for (const SDValue &Op : N0->op_values()) {
15054       SDValue TruncOp = DAG.getNode(ISD::TRUNCATE, DL, SVT, Op);
15055       TruncOps.push_back(TruncOp);
15056     }
15057     return DAG.getBuildVector(VT, DL, TruncOps);
15058   }
15059 
15060   // trunc (splat_vector x) -> splat_vector (trunc x)
15061   if (N0.getOpcode() == ISD::SPLAT_VECTOR &&
15062       (!LegalTypes || TLI.isTypeLegal(VT.getScalarType())) &&
15063       (!LegalOperations || TLI.isOperationLegal(ISD::SPLAT_VECTOR, VT))) {
15064     EVT SVT = VT.getScalarType();
15065     return DAG.getSplatVector(
15066         VT, DL, DAG.getNode(ISD::TRUNCATE, DL, SVT, N0->getOperand(0)));
15067   }
15068 
15069   // Fold a series of buildvector, bitcast, and truncate if possible.
15070   // For example fold
15071   //   (2xi32 trunc (bitcast ((4xi32)buildvector x, x, y, y) 2xi64)) to
15072   //   (2xi32 (buildvector x, y)).
15073   if (Level == AfterLegalizeVectorOps && VT.isVector() &&
15074       N0.getOpcode() == ISD::BITCAST && N0.hasOneUse() &&
15075       N0.getOperand(0).getOpcode() == ISD::BUILD_VECTOR &&
15076       N0.getOperand(0).hasOneUse()) {
15077     SDValue BuildVect = N0.getOperand(0);
15078     EVT BuildVectEltTy = BuildVect.getValueType().getVectorElementType();
15079     EVT TruncVecEltTy = VT.getVectorElementType();
15080 
15081     // Check that the element types match.
15082     if (BuildVectEltTy == TruncVecEltTy) {
15083       // Now we only need to compute the offset of the truncated elements.
15084       unsigned BuildVecNumElts =  BuildVect.getNumOperands();
15085       unsigned TruncVecNumElts = VT.getVectorNumElements();
15086       unsigned TruncEltOffset = BuildVecNumElts / TruncVecNumElts;
15087 
15088       assert((BuildVecNumElts % TruncVecNumElts) == 0 &&
15089              "Invalid number of elements");
15090 
15091       SmallVector<SDValue, 8> Opnds;
15092       for (unsigned i = 0, e = BuildVecNumElts; i != e; i += TruncEltOffset)
15093         Opnds.push_back(BuildVect.getOperand(i));
15094 
15095       return DAG.getBuildVector(VT, DL, Opnds);
15096     }
15097   }
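  // Worked instance of the example above: BuildVecNumElts = 4 and
  // TruncVecNumElts = 2 give TruncEltOffset = 2, so build_vector operands
  // 0 and 2 (x and y) are kept.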
15098 
15099   // fold (truncate (load x)) -> (smaller load x)
15100   // fold (truncate (srl (load x), c)) -> (smaller load (x+c/evtbits))
15101   if (!LegalTypes || TLI.isTypeDesirableForOp(N0.getOpcode(), VT)) {
15102     if (SDValue Reduced = reduceLoadWidth(N))
15103       return Reduced;
15104 
15105     // Handle the case where the truncated result is at least as wide as the
15106     // loaded type.
15107     if (N0.hasOneUse() && ISD::isUNINDEXEDLoad(N0.getNode())) {
15108       auto *LN0 = cast<LoadSDNode>(N0);
15109       if (LN0->isSimple() && LN0->getMemoryVT().bitsLE(VT)) {
15110         SDValue NewLoad = DAG.getExtLoad(
15111             LN0->getExtensionType(), SDLoc(LN0), VT, LN0->getChain(),
15112             LN0->getBasePtr(), LN0->getMemoryVT(), LN0->getMemOperand());
15113         DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), NewLoad.getValue(1));
15114         return NewLoad;
15115       }
15116     }
15117   }
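  // Illustrative instance of the case above: for
  //   (i16 trunc (i32 extload [mem i8]))
  // the memory type i8 already fits in i16, so the node is rebuilt as an
  // i16 extload of the same location and the chain use is redirected.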
15118 
15119   // fold (trunc (concat ... x ...)) -> (concat ..., (trunc x), ...),
15120   // where ... are all 'undef'.
15121   if (N0.getOpcode() == ISD::CONCAT_VECTORS && !LegalTypes) {
15122     SmallVector<EVT, 8> VTs;
15123     SDValue V;
15124     unsigned Idx = 0;
15125     unsigned NumDefs = 0;
15126 
15127     for (unsigned i = 0, e = N0.getNumOperands(); i != e; ++i) {
15128       SDValue X = N0.getOperand(i);
15129       if (!X.isUndef()) {
15130         V = X;
15131         Idx = i;
15132         NumDefs++;
15133       }
15134       // Stop if more than one member is non-undef.
15135       if (NumDefs > 1)
15136         break;
15137 
15138       VTs.push_back(EVT::getVectorVT(*DAG.getContext(),
15139                                      VT.getVectorElementType(),
15140                                      X.getValueType().getVectorElementCount()));
15141     }
15142 
15143     if (NumDefs == 0)
15144       return DAG.getUNDEF(VT);
15145 
15146     if (NumDefs == 1) {
15147       assert(V.getNode() && "The single defined operand is empty!");
15148       SmallVector<SDValue, 8> Opnds;
15149       for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
15150         if (i != Idx) {
15151           Opnds.push_back(DAG.getUNDEF(VTs[i]));
15152           continue;
15153         }
15154         SDValue NV = DAG.getNode(ISD::TRUNCATE, SDLoc(V), VTs[i], V);
15155         AddToWorklist(NV.getNode());
15156         Opnds.push_back(NV);
15157       }
15158       return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Opnds);
15159     }
15160   }
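  // Illustrative instance of the single-defined-operand case:
  //   (v4i32 trunc (v4i64 concat_vectors undef, v2i64:x))
  //   -> (v4i32 concat_vectors (v2i32 undef), (v2i32 trunc x))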
15161 
15162   // Fold truncate of a bitcast of a vector to an extract of the low vector
15163   // element.
15164   //
15165   // e.g. trunc (i64 (bitcast v2i32:x)) -> extract_vector_elt v2i32:x, idx
15166   if (N0.getOpcode() == ISD::BITCAST && !VT.isVector()) {
15167     SDValue VecSrc = N0.getOperand(0);
15168     EVT VecSrcVT = VecSrc.getValueType();
15169     if (VecSrcVT.isVector() && VecSrcVT.getScalarType() == VT &&
15170         (!LegalOperations ||
15171          TLI.isOperationLegal(ISD::EXTRACT_VECTOR_ELT, VecSrcVT))) {
15172       unsigned Idx = isLE ? 0 : VecSrcVT.getVectorNumElements() - 1;
15173       return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, VecSrc,
15174                          DAG.getVectorIdxConstant(Idx, DL));
15175     }
15176   }
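  // Illustrative example: on a little-endian target
  //   (i32 trunc (i64 bitcast v2i32:x)) -> extract_vector_elt v2i32:x, 0
  // whereas a big-endian target extracts the highest-indexed element, where
  // the low-order bits live.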
15177 
15178   // Simplify the operands using demanded-bits information.
15179   if (SimplifyDemandedBits(SDValue(N, 0)))
15180     return SDValue(N, 0);
15181 
15182   // fold (truncate (extract_subvector(ext x))) ->
15183   //      (extract_subvector x)
15184   // TODO: This can be generalized to cover cases where the truncate and extract
15185   // do not fully cancel each other out.
15186   if (!LegalTypes && N0.getOpcode() == ISD::EXTRACT_SUBVECTOR) {
15187     SDValue N00 = N0.getOperand(0);
15188     if (N00.getOpcode() == ISD::SIGN_EXTEND ||
15189         N00.getOpcode() == ISD::ZERO_EXTEND ||
15190         N00.getOpcode() == ISD::ANY_EXTEND) {
15191       if (N00.getOperand(0)->getValueType(0).getVectorElementType() ==
15192           VT.getVectorElementType())
15193         return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N0->getOperand(0)), VT,
15194                            N00.getOperand(0), N0.getOperand(1));
15195     }
15196   }
15197 
15198   if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
15199     return NewVSel;
15200 
15201   // Narrow a suitable binary operation with a non-opaque constant operand by
15202   // moving it ahead of the truncate. This is limited to pre-legalization
15203   // because targets may prefer a wider type during later combines and invert
15204   // this transform.
15205   switch (N0.getOpcode()) {
15206   case ISD::ADD:
15207   case ISD::SUB:
15208   case ISD::MUL:
15209   case ISD::AND:
15210   case ISD::OR:
15211   case ISD::XOR:
15212     if (!LegalOperations && N0.hasOneUse() &&
15213         (isConstantOrConstantVector(N0.getOperand(0), true) ||
15214          isConstantOrConstantVector(N0.getOperand(1), true))) {
15215       // TODO: We already restricted this to pre-legalization, but for vectors
15216       // we are extra cautious to not create an unsupported operation.
15217       // Target-specific changes are likely needed to avoid regressions here.
15218       if (VT.isScalarInteger() || TLI.isOperationLegal(N0.getOpcode(), VT)) {
15219         SDValue NarrowL = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0));
15220         SDValue NarrowR = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(1));
15221         return DAG.getNode(N0.getOpcode(), DL, VT, NarrowL, NarrowR);
15222       }
15223     }
15224     break;
15225   case ISD::ADDE:
15226   case ISD::UADDO_CARRY:
15227     // (trunc adde(X, Y, Carry)) -> (adde trunc(X), trunc(Y), Carry)
15228     // (trunc uaddo_carry(X, Y, Carry)) ->
15229     //     (uaddo_carry trunc(X), trunc(Y), Carry)
15230     // When the adde's carry is not used.
15231     // We only do this for uaddo_carry before operation legalization.
15232     if (((!LegalOperations && N0.getOpcode() == ISD::UADDO_CARRY) ||
15233          TLI.isOperationLegal(N0.getOpcode(), VT)) &&
15234         N0.hasOneUse() && !N0->hasAnyUseOfValue(1)) {
15235       SDValue X = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0));
15236       SDValue Y = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(1));
15237       SDVTList VTs = DAG.getVTList(VT, N0->getValueType(1));
15238       return DAG.getNode(N0.getOpcode(), DL, VTs, X, Y, N0.getOperand(2));
15239     }
15240     break;
15241   case ISD::USUBSAT:
15242     // Truncate the USUBSAT only if LHS is a known zero-extension; it's not
15243     // enough to know that the upper bits are zero, we must also ensure that
15244     // we don't introduce an extra truncate.
15245     if (!LegalOperations && N0.hasOneUse() &&
15246         N0.getOperand(0).getOpcode() == ISD::ZERO_EXTEND &&
15247         N0.getOperand(0).getOperand(0).getScalarValueSizeInBits() <=
15248             VT.getScalarSizeInBits() &&
15249         hasOperation(N0.getOpcode(), VT)) {
15250       return getTruncatedUSUBSAT(VT, SrcVT, N0.getOperand(0), N0.getOperand(1),
15251                                  DAG, DL);
15252     }
15253     break;
15254   }
15255 
15256   return SDValue();
15257 }
15258 
15259 static SDNode *getBuildPairElt(SDNode *N, unsigned i) {
15260   SDValue Elt = N->getOperand(i);
15261   if (Elt.getOpcode() != ISD::MERGE_VALUES)
15262     return Elt.getNode();
15263   return Elt.getOperand(Elt.getResNo()).getNode();
15264 }
15265 
15266 /// build_pair (load, load) -> load
15267 /// if load locations are consecutive.
15268 SDValue DAGCombiner::CombineConsecutiveLoads(SDNode *N, EVT VT) {
15269   assert(N->getOpcode() == ISD::BUILD_PAIR);
15270 
15271   auto *LD1 = dyn_cast<LoadSDNode>(getBuildPairElt(N, 0));
15272   auto *LD2 = dyn_cast<LoadSDNode>(getBuildPairElt(N, 1));
15273 
15274   // A BUILD_PAIR always has the least significant part in elt 0 and the
15275   // most significant part in elt 1. So when combining into one large load, we
15276   // need to consider the endianness.
15277   if (DAG.getDataLayout().isBigEndian())
15278     std::swap(LD1, LD2);
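  // Illustrative little-endian instance (hypothetical addresses): for
  //   (i64 build_pair (i32 load [p]), (i32 load [p+4]))
  // LD1 is the load at [p]; if the two loads are consecutive, simple, and a
  // fast i64 access is allowed, the pair becomes a single (i64 load [p]).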
15279 
15280   if (!LD1 || !LD2 || !ISD::isNON_EXTLoad(LD1) || !ISD::isNON_EXTLoad(LD2) ||
15281       !LD1->hasOneUse() || !LD2->hasOneUse() ||
15282       LD1->getAddressSpace() != LD2->getAddressSpace())
15283     return SDValue();
15284 
15285   unsigned LD1Fast = 0;
15286   EVT LD1VT = LD1->getValueType(0);
15287   unsigned LD1Bytes = LD1VT.getStoreSize();
15288   if ((!LegalOperations || TLI.isOperationLegal(ISD::LOAD, VT)) &&
15289       DAG.areNonVolatileConsecutiveLoads(LD2, LD1, LD1Bytes, 1) &&
15290       TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT,
15291                              *LD1->getMemOperand(), &LD1Fast) && LD1Fast)
15292     return DAG.getLoad(VT, SDLoc(N), LD1->getChain(), LD1->getBasePtr(),
15293                        LD1->getPointerInfo(), LD1->getAlign());
15294 
15295   return SDValue();
15296 }
15297 
15298 static unsigned getPPCf128HiElementSelector(const SelectionDAG &DAG) {
15299   // On little-endian machines, bitcasting from ppcf128 to i128 does swap the Hi
15300   // and Lo parts; on big-endian machines it doesn't.
15301   return DAG.getDataLayout().isBigEndian() ? 1 : 0;
15302 }
15303 
15304 SDValue DAGCombiner::foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
15305                                           const TargetLowering &TLI) {
15306   // If this is not a bitcast to an FP type or if the target doesn't have
15307   // IEEE754-compliant FP logic, we're done.
15308   EVT VT = N->getValueType(0);
15309   SDValue N0 = N->getOperand(0);
15310   EVT SourceVT = N0.getValueType();
15311 
15312   if (!VT.isFloatingPoint())
15313     return SDValue();
15314 
15315   // TODO: Handle cases where the integer constant is a different scalar
15316   // bitwidth to the FP.
15317   if (VT.getScalarSizeInBits() != SourceVT.getScalarSizeInBits())
15318     return SDValue();
15319 
15320   unsigned FPOpcode;
15321   APInt SignMask;
15322   switch (N0.getOpcode()) {
15323   case ISD::AND:
15324     FPOpcode = ISD::FABS;
15325     SignMask = ~APInt::getSignMask(SourceVT.getScalarSizeInBits());
15326     break;
15327   case ISD::XOR:
15328     FPOpcode = ISD::FNEG;
15329     SignMask = APInt::getSignMask(SourceVT.getScalarSizeInBits());
15330     break;
15331   case ISD::OR:
15332     FPOpcode = ISD::FABS;
15333     SignMask = APInt::getSignMask(SourceVT.getScalarSizeInBits());
15334     break;
15335   default:
15336     return SDValue();
15337   }
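  // Concrete f32/i32 instance (illustrative): getSignMask(32) = 0x80000000,
  // so the masks checked below are
  //   and (bitcast x), 0x7fffffff -> fabs x
  //   xor (bitcast x), 0x80000000 -> fneg x
  //   or  (bitcast x), 0x80000000 -> fneg (fabs x)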
15338 
15339   if (LegalOperations && !TLI.isOperationLegal(FPOpcode, VT))
15340     return SDValue();
15341 
15342   // This needs to be the inverse of logic in foldSignChangeInBitcast.
15343   // FIXME: I don't think looking for bitcast intrinsically makes sense, but
15344   // removing this would require more changes.
15345   auto IsBitCastOrFree = [&TLI, FPOpcode](SDValue Op, EVT VT) {
15346     if (Op.getOpcode() == ISD::BITCAST && Op.getOperand(0).getValueType() == VT)
15347       return true;
15348 
15349     return FPOpcode == ISD::FABS ? TLI.isFAbsFree(VT) : TLI.isFNegFree(VT);
15350   };
15351 
15352   // Fold (bitcast int (and (bitcast fp X to int), 0x7fff...) to fp) -> fabs X
15353   // Fold (bitcast int (xor (bitcast fp X to int), 0x8000...) to fp) -> fneg X
15354   // Fold (bitcast int (or (bitcast fp X to int), 0x8000...) to fp) ->
15355   //   fneg (fabs X)
15356   SDValue LogicOp0 = N0.getOperand(0);
15357   ConstantSDNode *LogicOp1 = isConstOrConstSplat(N0.getOperand(1), true);
15358   if (LogicOp1 && LogicOp1->getAPIntValue() == SignMask &&
15359       IsBitCastOrFree(LogicOp0, VT)) {
15360     SDValue CastOp0 = DAG.getNode(ISD::BITCAST, SDLoc(N), VT, LogicOp0);
15361     SDValue FPOp = DAG.getNode(FPOpcode, SDLoc(N), VT, CastOp0);
15362     NumFPLogicOpsConv++;
15363     if (N0.getOpcode() == ISD::OR)
15364       return DAG.getNode(ISD::FNEG, SDLoc(N), VT, FPOp);
15365     return FPOp;
15366   }
15367 
15368   return SDValue();
15369 }
15370 
15371 SDValue DAGCombiner::visitBITCAST(SDNode *N) {
15372   SDValue N0 = N->getOperand(0);
15373   EVT VT = N->getValueType(0);
15374 
15375   if (N0.isUndef())
15376     return DAG.getUNDEF(VT);
15377 
15378   // If the input is a BUILD_VECTOR with all constant elements, fold this now.
15379   // Only do this before legalize types, unless both types are integer and the
15380   // scalar type is legal. Only do this before legalize ops, since the target
15381   // maybe depending on the bitcast.
15382   // First check to see if this is all constant.
15383   // TODO: Support FP bitcasts after legalize types.
15384   if (VT.isVector() &&
15385       (!LegalTypes ||
15386        (!LegalOperations && VT.isInteger() && N0.getValueType().isInteger() &&
15387         TLI.isTypeLegal(VT.getVectorElementType()))) &&
15388       N0.getOpcode() == ISD::BUILD_VECTOR && N0->hasOneUse() &&
15389       cast<BuildVectorSDNode>(N0)->isConstant())
15390     return ConstantFoldBITCASTofBUILD_VECTOR(N0.getNode(),
15391                                              VT.getVectorElementType());
15392 
15393   // If the input is a constant, let getNode fold it.
15394   if (isIntOrFPConstant(N0)) {
15395     // If we can't allow illegal operations, we need to check that this is just
15396     // an fp -> int or int -> fp conversion and that the resulting operation
15397     // will be legal.
15398     if (!LegalOperations ||
15399         (isa<ConstantSDNode>(N0) && VT.isFloatingPoint() && !VT.isVector() &&
15400          TLI.isOperationLegal(ISD::ConstantFP, VT)) ||
15401         (isa<ConstantFPSDNode>(N0) && VT.isInteger() && !VT.isVector() &&
15402          TLI.isOperationLegal(ISD::Constant, VT))) {
15403       SDValue C = DAG.getBitcast(VT, N0);
15404       if (C.getNode() != N)
15405         return C;
15406     }
15407   }
15408 
15409   // (conv (conv x, t1), t2) -> (conv x, t2)
15410   if (N0.getOpcode() == ISD::BITCAST)
15411     return DAG.getBitcast(VT, N0.getOperand(0));
15412 
15413   // fold (conv (logicop (conv x), (c))) -> (logicop x, (conv c))
15414   // iff the current bitwise logicop type isn't legal
15415   if (ISD::isBitwiseLogicOp(N0.getOpcode()) && VT.isInteger() &&
15416       !TLI.isTypeLegal(N0.getOperand(0).getValueType())) {
15417     auto IsFreeBitcast = [VT](SDValue V) {
15418       return (V.getOpcode() == ISD::BITCAST &&
15419               V.getOperand(0).getValueType() == VT) ||
15420              (ISD::isBuildVectorOfConstantSDNodes(V.getNode()) &&
15421               V->hasOneUse());
15422     };
15423     if (IsFreeBitcast(N0.getOperand(0)) && IsFreeBitcast(N0.getOperand(1)))
15424       return DAG.getNode(N0.getOpcode(), SDLoc(N), VT,
15425                          DAG.getBitcast(VT, N0.getOperand(0)),
15426                          DAG.getBitcast(VT, N0.getOperand(1)));
15427   }
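  // Illustrative instance (assuming v4i8 is not a legal type): with VT = i32,
  //   (i32 bitcast (xor (v4i8 bitcast i32:x), v4i8-constant))
  //   -> (xor x, (i32 bitcast v4i8-constant))
  // so the logic op is performed in the legal integer type.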
15428 
15429   // fold (conv (load x)) -> (load (conv*)x)
15430   // if the resultant load doesn't need a higher alignment than the original.
15431   if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
15432       // Do not remove the cast if the types differ in endian layout.
15433       TLI.hasBigEndianPartOrdering(N0.getValueType(), DAG.getDataLayout()) ==
15434           TLI.hasBigEndianPartOrdering(VT, DAG.getDataLayout()) &&
15435       // If the load is volatile, we only want to change the load type if the
15436       // resulting load is legal. Otherwise we might increase the number of
15437       // memory accesses. We don't care if the original type was legal or not
15438       // as we assume software couldn't rely on the number of accesses of an
15439       // illegal type.
15440       ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple()) ||
15441        TLI.isOperationLegal(ISD::LOAD, VT))) {
15442     LoadSDNode *LN0 = cast<LoadSDNode>(N0);
15443 
15444     if (TLI.isLoadBitCastBeneficial(N0.getValueType(), VT, DAG,
15445                                     *LN0->getMemOperand())) {
15446       SDValue Load =
15447           DAG.getLoad(VT, SDLoc(N), LN0->getChain(), LN0->getBasePtr(),
15448                       LN0->getMemOperand());
15449       DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), Load.getValue(1));
15450       return Load;
15451     }
15452   }
15453 
15454   if (SDValue V = foldBitcastedFPLogic(N, DAG, TLI))
15455     return V;
15456 
15457   // fold (bitconvert (fneg x)) -> (xor (bitconvert x), signbit)
15458   // fold (bitconvert (fabs x)) -> (and (bitconvert x), (not signbit))
15459   //
15460   // For ppc_fp128:
15461   // fold (bitcast (fneg x)) ->
15462   //     flipbit = signbit
15463   //     (xor (bitcast x) (build_pair flipbit, flipbit))
15464   //
15465   // fold (bitcast (fabs x)) ->
15466   //     flipbit = (and (extract_element (bitcast x), 0), signbit)
15467   //     (xor (bitcast x) (build_pair flipbit, flipbit))
15468   // This often reduces constant pool loads.
15469   if (((N0.getOpcode() == ISD::FNEG && !TLI.isFNegFree(N0.getValueType())) ||
15470        (N0.getOpcode() == ISD::FABS && !TLI.isFAbsFree(N0.getValueType()))) &&
15471       N0->hasOneUse() && VT.isInteger() && !VT.isVector() &&
15472       !N0.getValueType().isVector()) {
15473     SDValue NewConv = DAG.getBitcast(VT, N0.getOperand(0));
15474     AddToWorklist(NewConv.getNode());
15475 
15476     SDLoc DL(N);
15477     if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) {
15478       assert(VT.getSizeInBits() == 128);
15479       SDValue SignBit = DAG.getConstant(
15480           APInt::getSignMask(VT.getSizeInBits() / 2), SDLoc(N0), MVT::i64);
15481       SDValue FlipBit;
15482       if (N0.getOpcode() == ISD::FNEG) {
15483         FlipBit = SignBit;
15484         AddToWorklist(FlipBit.getNode());
15485       } else {
15486         assert(N0.getOpcode() == ISD::FABS);
15487         SDValue Hi =
15488             DAG.getNode(ISD::EXTRACT_ELEMENT, SDLoc(NewConv), MVT::i64, NewConv,
15489                         DAG.getIntPtrConstant(getPPCf128HiElementSelector(DAG),
15490                                               SDLoc(NewConv)));
15491         AddToWorklist(Hi.getNode());
15492         FlipBit = DAG.getNode(ISD::AND, SDLoc(N0), MVT::i64, Hi, SignBit);
15493         AddToWorklist(FlipBit.getNode());
15494       }
15495       SDValue FlipBits =
15496           DAG.getNode(ISD::BUILD_PAIR, SDLoc(N0), VT, FlipBit, FlipBit);
15497       AddToWorklist(FlipBits.getNode());
15498       return DAG.getNode(ISD::XOR, DL, VT, NewConv, FlipBits);
15499     }
15500     APInt SignBit = APInt::getSignMask(VT.getSizeInBits());
15501     if (N0.getOpcode() == ISD::FNEG)
15502       return DAG.getNode(ISD::XOR, DL, VT,
15503                          NewConv, DAG.getConstant(SignBit, DL, VT));
15504     assert(N0.getOpcode() == ISD::FABS);
15505     return DAG.getNode(ISD::AND, DL, VT,
15506                        NewConv, DAG.getConstant(~SignBit, DL, VT));
15507   }
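  // Concrete f32 -> i32 instance of the folds above (illustrative):
  //   (i32 bitcast (fneg f32:x)) -> (xor (i32 bitcast x), 0x80000000)
  //   (i32 bitcast (fabs f32:x)) -> (and (i32 bitcast x), 0x7fffffff)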
15508 
15509   // fold (bitconvert (fcopysign cst, x)) ->
15510   //         (or (and (bitconvert x), sign), (and cst, (not sign)))
15511   // Note that we don't handle (copysign x, cst) because this can always be
15512   // folded to an fneg or fabs.
15513   //
15514   // For ppc_fp128:
15515   // fold (bitcast (fcopysign cst, x)) ->
15516   //     flipbit = (and (extract_element
15517   //                     (xor (bitcast cst), (bitcast x)), 0),
15518   //                    signbit)
15519   //     (xor (bitcast cst) (build_pair flipbit, flipbit))
15520   if (N0.getOpcode() == ISD::FCOPYSIGN && N0->hasOneUse() &&
15521       isa<ConstantFPSDNode>(N0.getOperand(0)) && VT.isInteger() &&
15522       !VT.isVector()) {
15523     unsigned OrigXWidth = N0.getOperand(1).getValueSizeInBits();
15524     EVT IntXVT = EVT::getIntegerVT(*DAG.getContext(), OrigXWidth);
15525     if (isTypeLegal(IntXVT)) {
15526       SDValue X = DAG.getBitcast(IntXVT, N0.getOperand(1));
15527       AddToWorklist(X.getNode());
15528 
15529       // If X has a different width than the result/lhs, sext it or truncate it.
15530       unsigned VTWidth = VT.getSizeInBits();
15531       if (OrigXWidth < VTWidth) {
15532         X = DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, X);
15533         AddToWorklist(X.getNode());
15534       } else if (OrigXWidth > VTWidth) {
15535         // To get the sign bit in the right place, we have to shift it right
15536         // before truncating.
15537         SDLoc DL(X);
15538         X = DAG.getNode(ISD::SRL, DL,
15539                         X.getValueType(), X,
15540                         DAG.getConstant(OrigXWidth-VTWidth, DL,
15541                                         X.getValueType()));
15542         AddToWorklist(X.getNode());
15543         X = DAG.getNode(ISD::TRUNCATE, SDLoc(X), VT, X);
15544         AddToWorklist(X.getNode());
15545       }
15546 
15547       if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) {
15548         APInt SignBit = APInt::getSignMask(VT.getSizeInBits() / 2);
15549         SDValue Cst = DAG.getBitcast(VT, N0.getOperand(0));
15550         AddToWorklist(Cst.getNode());
15551         SDValue X = DAG.getBitcast(VT, N0.getOperand(1));
15552         AddToWorklist(X.getNode());
15553         SDValue XorResult = DAG.getNode(ISD::XOR, SDLoc(N0), VT, Cst, X);
15554         AddToWorklist(XorResult.getNode());
15555         SDValue XorResult64 = DAG.getNode(
15556             ISD::EXTRACT_ELEMENT, SDLoc(XorResult), MVT::i64, XorResult,
15557             DAG.getIntPtrConstant(getPPCf128HiElementSelector(DAG),
15558                                   SDLoc(XorResult)));
15559         AddToWorklist(XorResult64.getNode());
15560         SDValue FlipBit =
15561             DAG.getNode(ISD::AND, SDLoc(XorResult64), MVT::i64, XorResult64,
15562                         DAG.getConstant(SignBit, SDLoc(XorResult64), MVT::i64));
15563         AddToWorklist(FlipBit.getNode());
15564         SDValue FlipBits =
15565             DAG.getNode(ISD::BUILD_PAIR, SDLoc(N0), VT, FlipBit, FlipBit);
15566         AddToWorklist(FlipBits.getNode());
15567         return DAG.getNode(ISD::XOR, SDLoc(N), VT, Cst, FlipBits);
15568       }
15569       APInt SignBit = APInt::getSignMask(VT.getSizeInBits());
15570       X = DAG.getNode(ISD::AND, SDLoc(X), VT,
15571                       X, DAG.getConstant(SignBit, SDLoc(X), VT));
15572       AddToWorklist(X.getNode());
15573 
15574       SDValue Cst = DAG.getBitcast(VT, N0.getOperand(0));
15575       Cst = DAG.getNode(ISD::AND, SDLoc(Cst), VT,
15576                         Cst, DAG.getConstant(~SignBit, SDLoc(Cst), VT));
15577       AddToWorklist(Cst.getNode());
15578 
15579       return DAG.getNode(ISD::OR, SDLoc(N), VT, X, Cst);
15580     }
15581   }
15582 
15583   // bitconvert(build_pair(ld, ld)) -> ld iff load locations are consecutive.
15584   if (N0.getOpcode() == ISD::BUILD_PAIR)
15585     if (SDValue CombineLD = CombineConsecutiveLoads(N0.getNode(), VT))
15586       return CombineLD;
15587 
15588   // Remove double bitcasts from shuffles - this is often a legacy of
15589   // XformToShuffleWithZero being used to combine bitmaskings (of
15590   // float vectors bitcast to integer vectors) into shuffles.
15591   // bitcast(shuffle(bitcast(s0),bitcast(s1))) -> shuffle(s0,s1)
15592   if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT) && VT.isVector() &&
15593       N0->getOpcode() == ISD::VECTOR_SHUFFLE && N0.hasOneUse() &&
15594       VT.getVectorNumElements() >= N0.getValueType().getVectorNumElements() &&
15595       !(VT.getVectorNumElements() % N0.getValueType().getVectorNumElements())) {
15596     ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N0);
15597 
15598     // If an operand is a bitcast, peek through it if it casts the original
15599     // VT. If an operand is a constant, just bitcast back to the original VT.
15600     auto PeekThroughBitcast = [&](SDValue Op) {
15601       if (Op.getOpcode() == ISD::BITCAST &&
15602           Op.getOperand(0).getValueType() == VT)
15603         return SDValue(Op.getOperand(0));
15604       if (Op.isUndef() || isAnyConstantBuildVector(Op))
15605         return DAG.getBitcast(VT, Op);
15606       return SDValue();
15607     };
15608 
15609     // FIXME: If either input vector is bitcast, try to convert the shuffle to
15610     // the result type of this bitcast. This would eliminate at least one
15611     // bitcast. See the transform in InstCombine.
15612     SDValue SV0 = PeekThroughBitcast(N0->getOperand(0));
15613     SDValue SV1 = PeekThroughBitcast(N0->getOperand(1));
15614     if (!(SV0 && SV1))
15615       return SDValue();
15616 
15617     int MaskScale =
15618         VT.getVectorNumElements() / N0.getValueType().getVectorNumElements();
15619     SmallVector<int, 8> NewMask;
15620     for (int M : SVN->getMask())
15621       for (int i = 0; i != MaskScale; ++i)
15622         NewMask.push_back(M < 0 ? -1 : M * MaskScale + i);
15623 
15624     SDValue LegalShuffle =
15625         TLI.buildLegalVectorShuffle(VT, SDLoc(N), SV0, SV1, NewMask, DAG);
15626     if (LegalShuffle)
15627       return LegalShuffle;
15628   }
15629 
15630   return SDValue();
15631 }
15632 
15633 SDValue DAGCombiner::visitBUILD_PAIR(SDNode *N) {
15634   EVT VT = N->getValueType(0);
15635   return CombineConsecutiveLoads(N, VT);
15636 }
15637 
15638 SDValue DAGCombiner::visitFREEZE(SDNode *N) {
15639   SDValue N0 = N->getOperand(0);
15640 
15641   if (DAG.isGuaranteedNotToBeUndefOrPoison(N0, /*PoisonOnly*/ false))
15642     return N0;
15643 
15644   // We currently avoid folding freeze over SRA/SRL, due to the problems seen
15645   // with (freeze (assert ext)) blocking simplifications of SRA/SRL. See for
15646   // example https://reviews.llvm.org/D136529#4120959.
15647   if (N0.getOpcode() == ISD::SRA || N0.getOpcode() == ISD::SRL)
15648     return SDValue();
15649 
15650   // Fold freeze(op(x, ...)) -> op(freeze(x), ...).
15651   // Try to push freeze through instructions that propagate but don't produce
15652   // poison as far as possible. If an operand of freeze satisfies three
15653   // conditions: 1) one use, 2) does not produce poison, and 3) has all but
15654   // one operand guaranteed non-poison (or is a BUILD_VECTOR or similar), then
15655   // push the freeze through to the operands that are not guaranteed non-poison.
15656   // NOTE: we will strip poison-generating flags, so ignore them here.
15657   if (DAG.canCreateUndefOrPoison(N0, /*PoisonOnly*/ false,
15658                                  /*ConsiderFlags*/ false) ||
15659       N0->getNumValues() != 1 || !N0->hasOneUse())
15660     return SDValue();
15661 
15662   bool AllowMultipleMaybePoisonOperands =
15663       N0.getOpcode() == ISD::SELECT_CC ||
15664       N0.getOpcode() == ISD::SETCC ||
15665       N0.getOpcode() == ISD::BUILD_VECTOR ||
15666       N0.getOpcode() == ISD::BUILD_PAIR ||
15667       N0.getOpcode() == ISD::VECTOR_SHUFFLE ||
15668       N0.getOpcode() == ISD::CONCAT_VECTORS;
15669 
15670   // Avoid turning a BUILD_VECTOR that can be recognized as "all zeros", "all
15671   // ones" or "constant" into something that depends on FrozenUndef. We can
15672   // instead pick undef values to keep those properties, while at the same time
15673   // folding away the freeze.
15674   // If we implement a more general solution for folding away freeze(undef) in
15675   // the future, then this special handling can be removed.
15676   if (N0.getOpcode() == ISD::BUILD_VECTOR) {
15677     SDLoc DL(N0);
15678     EVT VT = N0.getValueType();
15679     if (llvm::ISD::isBuildVectorAllOnes(N0.getNode()))
15680       return DAG.getAllOnesConstant(DL, VT);
15681     if (llvm::ISD::isBuildVectorOfConstantSDNodes(N0.getNode())) {
15682       SmallVector<SDValue, 8> NewVecC;
15683       for (const SDValue &Op : N0->op_values())
15684         NewVecC.push_back(
15685             Op.isUndef() ? DAG.getConstant(0, DL, Op.getValueType()) : Op);
15686       return DAG.getBuildVector(VT, DL, NewVecC);
15687     }
15688   }
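  // Illustrative instance of the constant case above:
  //   freeze (build_vector c0, undef, c1, undef)
  //   -> build_vector c0, 0, c1, 0
  // Picking zero for the undef lanes keeps the vector fully constant instead
  // of making it depend on a frozen undef.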
15689 
15690   SmallSet<SDValue, 8> MaybePoisonOperands;
15691   SmallVector<unsigned, 8> MaybePoisonOperandNumbers;
15692   for (auto [OpNo, Op] : enumerate(N0->ops())) {
15693     if (DAG.isGuaranteedNotToBeUndefOrPoison(Op, /*PoisonOnly*/ false,
15694                                              /*Depth*/ 1))
15695       continue;
15696     bool HadMaybePoisonOperands = !MaybePoisonOperands.empty();
15697     bool IsNewMaybePoisonOperand = MaybePoisonOperands.insert(Op).second;
15698     if (IsNewMaybePoisonOperand)
15699       MaybePoisonOperandNumbers.push_back(OpNo);
15700     if (!HadMaybePoisonOperands)
15701       continue;
15702     if (IsNewMaybePoisonOperand && !AllowMultipleMaybePoisonOperands) {
15703       // Multiple maybe-poison ops when not allowed - bail out.
15704       return SDValue();
15705     }
15706   }
15707   // NOTE: the whole op may still not be guaranteed not to be undef or poison,
15708   // because it could create undef or poison due to its poison-generating
15709   // flags. So not finding any maybe-poison operands is fine.
15710 
15711   for (unsigned OpNo : MaybePoisonOperandNumbers) {
15712     // N0 can mutate during iteration, so make sure to refetch the maybe poison
15713     // operands via the operand numbers. The typical scenario is that we have
15714     // something like this
15715     //   t262: i32 = freeze t181
15716     //   t150: i32 = ctlz_zero_undef t262
15717     //   t184: i32 = ctlz_zero_undef t181
15718     //   t268: i32 = select_cc t181, Constant:i32<0>, t184, t186, setne:ch
15719     // When freezing the t181 operand we get t262 back, and then the
15720     // ReplaceAllUsesOfValueWith call will not only replace t181 by t262, but
15721     // also recursively replace t184 by t150.
15722     SDValue MaybePoisonOperand = N->getOperand(0).getOperand(OpNo);
15723     // Don't replace every single UNDEF everywhere with frozen UNDEF, though.
15724     if (MaybePoisonOperand.getOpcode() == ISD::UNDEF)
15725       continue;
15726     // First, freeze each offending operand.
15727     SDValue FrozenMaybePoisonOperand = DAG.getFreeze(MaybePoisonOperand);
15728     // Then, change all other uses of unfrozen operand to use frozen operand.
15729     DAG.ReplaceAllUsesOfValueWith(MaybePoisonOperand, FrozenMaybePoisonOperand);
15730     if (FrozenMaybePoisonOperand.getOpcode() == ISD::FREEZE &&
15731         FrozenMaybePoisonOperand.getOperand(0) == FrozenMaybePoisonOperand) {
15732       // But, that also updated the use in the freeze we just created, thus
15733       // creating a cycle in a DAG. Let's undo that by mutating the freeze.
15734       DAG.UpdateNodeOperands(FrozenMaybePoisonOperand.getNode(),
15735                              MaybePoisonOperand);
15736     }
15737   }
15738 
15739   // This node has been merged with another.
15740   if (N->getOpcode() == ISD::DELETED_NODE)
15741     return SDValue(N, 0);
15742 
15743   // The whole node may have been updated, so the value we were holding
15744   // may no longer be valid. Re-fetch the operand we're `freeze`ing.
15745   N0 = N->getOperand(0);
15746 
15747   // Finally, recreate the node; its operands were updated to use
15748   // frozen operands, so we just need to use its "original" operands.
15749   SmallVector<SDValue> Ops(N0->op_begin(), N0->op_end());
15750   // Special-handle ISD::UNDEF; each single one of them can be its own thing.
15751   for (SDValue &Op : Ops) {
15752     if (Op.getOpcode() == ISD::UNDEF)
15753       Op = DAG.getFreeze(Op);
15754   }
15755 
15756   SDValue R;
15757   if (auto *SVN = dyn_cast<ShuffleVectorSDNode>(N0)) {
15758     // Special case handling for ShuffleVectorSDNode nodes.
15759     R = DAG.getVectorShuffle(N0.getValueType(), SDLoc(N0), Ops[0], Ops[1],
15760                              SVN->getMask());
15761   } else {
15762     // NOTE: this strips poison generating flags.
15763     R = DAG.getNode(N0.getOpcode(), SDLoc(N0), N0->getVTList(), Ops);
15764   }
15765   assert(DAG.isGuaranteedNotToBeUndefOrPoison(R, /*PoisonOnly*/ false) &&
15766          "Can't create node that may be undef/poison!");
15767   return R;
15768 }
15769 
15770 /// We know that BV is a build_vector node with Constant, ConstantFP or Undef
15771 /// operands. DstEltVT indicates the destination element value type.
15772 SDValue DAGCombiner::
15773 ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) {
15774   EVT SrcEltVT = BV->getValueType(0).getVectorElementType();
15775 
15776   // If this is already the right type, we're done.
15777   if (SrcEltVT == DstEltVT) return SDValue(BV, 0);
15778 
15779   unsigned SrcBitSize = SrcEltVT.getSizeInBits();
15780   unsigned DstBitSize = DstEltVT.getSizeInBits();
15781 
15782   // If this is a conversion of N elements of one type to N elements of another
15783   // type, convert each element.  This handles FP<->INT cases.
15784   if (SrcBitSize == DstBitSize) {
15785     SmallVector<SDValue, 8> Ops;
15786     for (SDValue Op : BV->op_values()) {
15787       // If the vector element type is not legal, the BUILD_VECTOR operands
15788       // are promoted and implicitly truncated.  Make that explicit here.
15789       if (Op.getValueType() != SrcEltVT)
15790         Op = DAG.getNode(ISD::TRUNCATE, SDLoc(BV), SrcEltVT, Op);
15791       Ops.push_back(DAG.getBitcast(DstEltVT, Op));
15792       AddToWorklist(Ops.back().getNode());
15793     }
15794     EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT,
15795                               BV->getValueType(0).getVectorNumElements());
15796     return DAG.getBuildVector(VT, SDLoc(BV), Ops);
15797   }
15798 
15799   // Otherwise, we're growing or shrinking the elements.  To avoid having to
15800   // handle annoying details of growing/shrinking FP values, we convert them to
15801   // int first.
15802   if (SrcEltVT.isFloatingPoint()) {
15803     // Convert the input float vector to an int vector where the elements
15804     // are the same size.
15805     EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), SrcEltVT.getSizeInBits());
15806     BV = ConstantFoldBITCASTofBUILD_VECTOR(BV, IntVT).getNode();
15807     SrcEltVT = IntVT;
15808   }
15809 
15810   // Now we know the input is an integer vector.  If the output is a FP type,
15811   // convert to integer first, then to FP of the right size.
15812   if (DstEltVT.isFloatingPoint()) {
15813     EVT TmpVT = EVT::getIntegerVT(*DAG.getContext(), DstEltVT.getSizeInBits());
15814     SDNode *Tmp = ConstantFoldBITCASTofBUILD_VECTOR(BV, TmpVT).getNode();
15815 
15816     // Next, convert to FP elements of the same size.
15817     return ConstantFoldBITCASTofBUILD_VECTOR(Tmp, DstEltVT);
15818   }
15819 
15820   // Okay, we know the src/dst types are both integers of differing types.
15821   assert(SrcEltVT.isInteger() && DstEltVT.isInteger());
15822 
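  // Illustrative repack (assuming a little-endian layout): bitcasting
  //   (v4i16 build_vector 0x0001, 0x0002, 0x0003, 0x0004)
  // to v2i32 concatenates each pair of i16 lanes, yielding
  //   (v2i32 build_vector 0x00020001, 0x00040003)
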
15823   // TODO: Should ConstantFoldBITCASTofBUILD_VECTOR always take a
15824   // BuildVectorSDNode?
15825   auto *BVN = cast<BuildVectorSDNode>(BV);
15826 
15827   // Extract the constant raw bit data.
15828   BitVector UndefElements;
15829   SmallVector<APInt> RawBits;
15830   bool IsLE = DAG.getDataLayout().isLittleEndian();
15831   if (!BVN->getConstantRawBits(IsLE, DstBitSize, RawBits, UndefElements))
15832     return SDValue();
15833 
15834   SDLoc DL(BV);
15835   SmallVector<SDValue, 8> Ops;
15836   for (unsigned I = 0, E = RawBits.size(); I != E; ++I) {
15837     if (UndefElements[I])
15838       Ops.push_back(DAG.getUNDEF(DstEltVT));
15839     else
15840       Ops.push_back(DAG.getConstant(RawBits[I], DL, DstEltVT));
15841   }
15842 
15843   EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT, Ops.size());
15844   return DAG.getBuildVector(VT, DL, Ops);
15845 }
15846 
15847 // Returns true if floating-point contraction is allowed on the FMUL-SDValue
15848 // `N`.
15849 static bool isContractableFMUL(const TargetOptions &Options, SDValue N) {
15850   assert(N.getOpcode() == ISD::FMUL);
15851 
15852   return Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath ||
15853          N->getFlags().hasAllowContract();
15854 }
15855 
15856 // Returns true if `N` can assume no infinities involved in its computation.
15857 static bool hasNoInfs(const TargetOptions &Options, SDValue N) {
15858   return Options.NoInfsFPMath || N->getFlags().hasNoInfs();
15859 }
15860 
15861 /// Try to perform FMA combining on a given FADD node.
15862 template <class MatchContextClass>
15863 SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
15864   SDValue N0 = N->getOperand(0);
15865   SDValue N1 = N->getOperand(1);
15866   EVT VT = N->getValueType(0);
15867   SDLoc SL(N);
15868   MatchContextClass matcher(DAG, TLI, N);
15869   const TargetOptions &Options = DAG.getTarget().Options;
15870 
15871   bool UseVP = std::is_same_v<MatchContextClass, VPMatchContext>;
15872 
15873   // Floating-point multiply-add with intermediate rounding.
15874   // FIXME: Make isFMADLegal have specific behavior when using VPMatchContext.
15875   // FIXME: Add VP_FMAD opcode.
15876   bool HasFMAD = !UseVP && (LegalOperations && TLI.isFMADLegal(DAG, N));
15877 
15878   // Floating-point multiply-add without intermediate rounding.
15879   bool HasFMA =
15880       TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) &&
15881       (!LegalOperations || matcher.isOperationLegalOrCustom(ISD::FMA, VT));
15882 
15883   // No valid opcode, do not combine.
15884   if (!HasFMAD && !HasFMA)
15885     return SDValue();
15886 
15887   bool AllowFusionGlobally = (Options.AllowFPOpFusion == FPOpFusion::Fast ||
15888                               Options.UnsafeFPMath || HasFMAD);
15889   // If the addition is not contractable, do not combine.
15890   if (!AllowFusionGlobally && !N->getFlags().hasAllowContract())
15891     return SDValue();
15892 
15893   // Folding fadd (fmul x, y), (fmul x, y) -> fma x, y, (fmul x, y) is never
15894   // beneficial. It does not reduce latency. It increases register pressure. It
15895   // replaces an fadd with an fma which is a more complex instruction, so is
15896   // likely to have a larger encoding, use more functional units, etc.
15897   if (N0 == N1)
15898     return SDValue();
15899 
15900   if (TLI.generateFMAsInMachineCombiner(VT, OptLevel))
15901     return SDValue();
15902 
15903   // Always prefer FMAD to FMA for precision.
15904   unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
15905   bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
15906 
15907   auto isFusedOp = [&](SDValue N) {
15908     return matcher.match(N, ISD::FMA) || matcher.match(N, ISD::FMAD);
15909   };
15910 
15911   // Is the node an FMUL and contractable either due to global flags or
15912   // SDNodeFlags.
15913   auto isContractableFMUL = [AllowFusionGlobally, &matcher](SDValue N) {
15914     if (!matcher.match(N, ISD::FMUL))
15915       return false;
15916     return AllowFusionGlobally || N->getFlags().hasAllowContract();
15917   };
15918   // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
15919   // prefer to fold the multiply with fewer uses.
15920   if (Aggressive && isContractableFMUL(N0) && isContractableFMUL(N1)) {
15921     if (N0->use_size() > N1->use_size())
15922       std::swap(N0, N1);
15923   }
15924 
15925   // fold (fadd (fmul x, y), z) -> (fma x, y, z)
15926   if (isContractableFMUL(N0) && (Aggressive || N0->hasOneUse())) {
15927     return matcher.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0),
15928                            N0.getOperand(1), N1);
15929   }
15930 
15931   // fold (fadd x, (fmul y, z)) -> (fma y, z, x)
15932   // Note: Commutes FADD operands.
15933   if (isContractableFMUL(N1) && (Aggressive || N1->hasOneUse())) {
15934     return matcher.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(0),
15935                            N1.getOperand(1), N0);
15936   }
15937 
15938   // fadd (fma A, B, (fmul C, D)), E --> fma A, B, (fma C, D, E)
15939   // fadd E, (fma A, B, (fmul C, D)) --> fma A, B, (fma C, D, E)
15940   // This also works with nested fma instructions:
15941   // fadd (fma A, B, (fma C, D, (fmul E, F))), G -->
15942   // fma A, B, (fma C, D, (fma E, F, G))
15943   // fadd G, (fma A, B, (fma C, D, (fmul E, F))) -->
15944   // fma A, B, (fma C, D, (fma E, F, G)).
15945   // This requires reassociation because it changes the order of operations.
15946   bool CanReassociate =
15947       Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();
15948   if (CanReassociate) {
15949     SDValue FMA, E;
15950     if (isFusedOp(N0) && N0.hasOneUse()) {
15951       FMA = N0;
15952       E = N1;
15953     } else if (isFusedOp(N1) && N1.hasOneUse()) {
15954       FMA = N1;
15955       E = N0;
15956     }
15957 
15958     SDValue TmpFMA = FMA;
15959     while (E && isFusedOp(TmpFMA) && TmpFMA.hasOneUse()) {
15960       SDValue FMul = TmpFMA->getOperand(2);
15961       if (matcher.match(FMul, ISD::FMUL) && FMul.hasOneUse()) {
15962         SDValue C = FMul.getOperand(0);
15963         SDValue D = FMul.getOperand(1);
15964         SDValue CDE = matcher.getNode(PreferredFusedOpcode, SL, VT, C, D, E);
15965         DAG.ReplaceAllUsesOfValueWith(FMul, CDE);
15966         // Replacing the inner FMul could cause the outer FMA to be simplified
15967         // away.
15968         return FMA.getOpcode() == ISD::DELETED_NODE ? SDValue(N, 0) : FMA;
15969       }
15970 
15971       TmpFMA = TmpFMA->getOperand(2);
15972     }
15973   }
15974 
15975   // Look through FP_EXTEND nodes to do more combining.
15976 
15977   // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
15978   if (matcher.match(N0, ISD::FP_EXTEND)) {
15979     SDValue N00 = N0.getOperand(0);
15980     if (isContractableFMUL(N00) &&
15981         TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
15982                             N00.getValueType())) {
15983       return matcher.getNode(
15984           PreferredFusedOpcode, SL, VT,
15985           matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)),
15986           matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)), N1);
15987     }
15988   }
15989 
15990   // fold (fadd x, (fpext (fmul y, z))) -> (fma (fpext y), (fpext z), x)
15991   // Note: Commutes FADD operands.
15992   if (matcher.match(N1, ISD::FP_EXTEND)) {
15993     SDValue N10 = N1.getOperand(0);
15994     if (isContractableFMUL(N10) &&
15995         TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
15996                             N10.getValueType())) {
15997       return matcher.getNode(
15998           PreferredFusedOpcode, SL, VT,
15999           matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(0)),
16000           matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(1)), N0);
16001     }
16002   }
16003 
16004   // More folding opportunities when target permits.
16005   if (Aggressive) {
16006     // fold (fadd (fma x, y, (fpext (fmul u, v))), z)
16007     //   -> (fma x, y, (fma (fpext u), (fpext v), z))
16008     auto FoldFAddFMAFPExtFMul = [&](SDValue X, SDValue Y, SDValue U, SDValue V,
16009                                     SDValue Z) {
16010       return matcher.getNode(
16011           PreferredFusedOpcode, SL, VT, X, Y,
16012           matcher.getNode(PreferredFusedOpcode, SL, VT,
16013                           matcher.getNode(ISD::FP_EXTEND, SL, VT, U),
16014                           matcher.getNode(ISD::FP_EXTEND, SL, VT, V), Z));
16015     };
16016     if (isFusedOp(N0)) {
16017       SDValue N02 = N0.getOperand(2);
16018       if (matcher.match(N02, ISD::FP_EXTEND)) {
16019         SDValue N020 = N02.getOperand(0);
16020         if (isContractableFMUL(N020) &&
16021             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
16022                                 N020.getValueType())) {
16023           return FoldFAddFMAFPExtFMul(N0.getOperand(0), N0.getOperand(1),
16024                                       N020.getOperand(0), N020.getOperand(1),
16025                                       N1);
16026         }
16027       }
16028     }
16029 
16030     // fold (fadd (fpext (fma x, y, (fmul u, v))), z)
16031     //   -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z))
16032     // FIXME: This turns two single-precision and one double-precision
16033     // operation into two double-precision operations, which might not be
16034     // interesting for all targets, especially GPUs.
16035     auto FoldFAddFPExtFMAFMul = [&](SDValue X, SDValue Y, SDValue U, SDValue V,
16036                                     SDValue Z) {
16037       return matcher.getNode(
16038           PreferredFusedOpcode, SL, VT,
16039           matcher.getNode(ISD::FP_EXTEND, SL, VT, X),
16040           matcher.getNode(ISD::FP_EXTEND, SL, VT, Y),
16041           matcher.getNode(PreferredFusedOpcode, SL, VT,
16042                           matcher.getNode(ISD::FP_EXTEND, SL, VT, U),
16043                           matcher.getNode(ISD::FP_EXTEND, SL, VT, V), Z));
16044     };
16045     if (N0.getOpcode() == ISD::FP_EXTEND) {
16046       SDValue N00 = N0.getOperand(0);
16047       if (isFusedOp(N00)) {
16048         SDValue N002 = N00.getOperand(2);
16049         if (isContractableFMUL(N002) &&
16050             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
16051                                 N00.getValueType())) {
16052           return FoldFAddFPExtFMAFMul(N00.getOperand(0), N00.getOperand(1),
16053                                       N002.getOperand(0), N002.getOperand(1),
16054                                       N1);
16055         }
16056       }
16057     }
16058 
16059     // fold (fadd x, (fma y, z, (fpext (fmul u, v))))
16060     //   -> (fma y, z, (fma (fpext u), (fpext v), x))
16061     if (isFusedOp(N1)) {
16062       SDValue N12 = N1.getOperand(2);
16063       if (N12.getOpcode() == ISD::FP_EXTEND) {
16064         SDValue N120 = N12.getOperand(0);
16065         if (isContractableFMUL(N120) &&
16066             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
16067                                 N120.getValueType())) {
16068           return FoldFAddFMAFPExtFMul(N1.getOperand(0), N1.getOperand(1),
16069                                       N120.getOperand(0), N120.getOperand(1),
16070                                       N0);
16071         }
16072       }
16073     }
16074 
16075     // fold (fadd x, (fpext (fma y, z, (fmul u, v))))
16076     //   -> (fma (fpext y), (fpext z), (fma (fpext u), (fpext v), x))
16077     // FIXME: This turns two single-precision and one double-precision
16078     // operation into two double-precision operations, which might not be
16079     // interesting for all targets, especially GPUs.
16080     if (N1.getOpcode() == ISD::FP_EXTEND) {
16081       SDValue N10 = N1.getOperand(0);
16082       if (isFusedOp(N10)) {
16083         SDValue N102 = N10.getOperand(2);
16084         if (isContractableFMUL(N102) &&
16085             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
16086                                 N10.getValueType())) {
16087           return FoldFAddFPExtFMAFMul(N10.getOperand(0), N10.getOperand(1),
16088                                       N102.getOperand(0), N102.getOperand(1),
16089                                       N0);
16090         }
16091       }
16092     }
16093   }
16094 
16095   return SDValue();
16096 }
16097 
16098 /// Try to perform FMA combining on a given FSUB node.
16099 template <class MatchContextClass>
16100 SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
16101   SDValue N0 = N->getOperand(0);
16102   SDValue N1 = N->getOperand(1);
16103   EVT VT = N->getValueType(0);
16104   SDLoc SL(N);
16105   MatchContextClass matcher(DAG, TLI, N);
16106   const TargetOptions &Options = DAG.getTarget().Options;
16107 
16108   bool UseVP = std::is_same_v<MatchContextClass, VPMatchContext>;
16109 
16110   // Floating-point multiply-add with intermediate rounding.
16111   // FIXME: Make isFMADLegal have specific behavior when using VPMatchContext.
16112   // FIXME: Add VP_FMAD opcode.
16113   bool HasFMAD = !UseVP && (LegalOperations && TLI.isFMADLegal(DAG, N));
16114 
16115   // Floating-point multiply-add without intermediate rounding.
16116   bool HasFMA =
16117       TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) &&
16118       (!LegalOperations || matcher.isOperationLegalOrCustom(ISD::FMA, VT));
16119 
16120   // No valid opcode, do not combine.
16121   if (!HasFMAD && !HasFMA)
16122     return SDValue();
16123 
16124   const SDNodeFlags Flags = N->getFlags();
16125   bool AllowFusionGlobally = (Options.AllowFPOpFusion == FPOpFusion::Fast ||
16126                               Options.UnsafeFPMath || HasFMAD);
16127 
16128   // If the subtraction is not contractable, do not combine.
16129   if (!AllowFusionGlobally && !N->getFlags().hasAllowContract())
16130     return SDValue();
16131 
16132   if (TLI.generateFMAsInMachineCombiner(VT, OptLevel))
16133     return SDValue();
16134 
16135   // Always prefer FMAD to FMA for precision.
16136   unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
16137   bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
16138   bool NoSignedZero = Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros();
16139 
16140   // Is the node an FMUL and contractable either due to global flags or
16141   // SDNodeFlags.
16142   auto isContractableFMUL = [AllowFusionGlobally, &matcher](SDValue N) {
16143     if (!matcher.match(N, ISD::FMUL))
16144       return false;
16145     return AllowFusionGlobally || N->getFlags().hasAllowContract();
16146   };
16147 
16148   // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z))
16149   auto tryToFoldXYSubZ = [&](SDValue XY, SDValue Z) {
16150     if (isContractableFMUL(XY) && (Aggressive || XY->hasOneUse())) {
16151       return matcher.getNode(PreferredFusedOpcode, SL, VT, XY.getOperand(0),
16152                              XY.getOperand(1),
16153                              matcher.getNode(ISD::FNEG, SL, VT, Z));
16154     }
16155     return SDValue();
16156   };
16157 
16158   // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x)
16159   // Note: Commutes FSUB operands.
16160   auto tryToFoldXSubYZ = [&](SDValue X, SDValue YZ) {
16161     if (isContractableFMUL(YZ) && (Aggressive || YZ->hasOneUse())) {
16162       return matcher.getNode(
16163           PreferredFusedOpcode, SL, VT,
16164           matcher.getNode(ISD::FNEG, SL, VT, YZ.getOperand(0)),
16165           YZ.getOperand(1), X);
16166     }
16167     return SDValue();
16168   };
16169 
16170   // If we have two choices trying to fold (fsub (fmul u, v), (fmul x, y)),
16171   // prefer to fold the multiply with fewer uses.
16172   if (isContractableFMUL(N0) && isContractableFMUL(N1) &&
16173       (N0->use_size() > N1->use_size())) {
16174     // fold (fsub (fmul a, b), (fmul c, d)) -> (fma (fneg c), d, (fmul a, b))
16175     if (SDValue V = tryToFoldXSubYZ(N0, N1))
16176       return V;
16177     // fold (fsub (fmul a, b), (fmul c, d)) -> (fma a, b, (fneg (fmul c, d)))
16178     if (SDValue V = tryToFoldXYSubZ(N0, N1))
16179       return V;
16180   } else {
16181     // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z))
16182     if (SDValue V = tryToFoldXYSubZ(N0, N1))
16183       return V;
16184     // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x)
16185     if (SDValue V = tryToFoldXSubYZ(N0, N1))
16186       return V;
16187   }
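  // Illustrative instance of the use-count heuristic above: if (fmul u, v)
  // feeds several users but (fmul x, y) feeds only this fsub, negate-and-fuse
  // the single-use multiply so the shared multiply remains available:
  //   (fsub (fmul u, v), (fmul x, y)) -> (fma (fneg x), y, (fmul u, v))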
16188 
16189   // fold (fsub (fneg (fmul x, y)), z) -> (fma (fneg x), y, (fneg z))
16190   if (matcher.match(N0, ISD::FNEG) && isContractableFMUL(N0.getOperand(0)) &&
16191       (Aggressive || (N0->hasOneUse() && N0.getOperand(0).hasOneUse()))) {
16192     SDValue N00 = N0.getOperand(0).getOperand(0);
16193     SDValue N01 = N0.getOperand(0).getOperand(1);
16194     return matcher.getNode(PreferredFusedOpcode, SL, VT,
16195                            matcher.getNode(ISD::FNEG, SL, VT, N00), N01,
16196                            matcher.getNode(ISD::FNEG, SL, VT, N1));
16197   }
16198 
16199   // Look through FP_EXTEND nodes to do more combining.
16200 
16201   // fold (fsub (fpext (fmul x, y)), z)
16202   //   -> (fma (fpext x), (fpext y), (fneg z))
16203   if (matcher.match(N0, ISD::FP_EXTEND)) {
16204     SDValue N00 = N0.getOperand(0);
16205     if (isContractableFMUL(N00) &&
16206         TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
16207                             N00.getValueType())) {
16208       return matcher.getNode(
16209           PreferredFusedOpcode, SL, VT,
16210           matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)),
16211           matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)),
16212           matcher.getNode(ISD::FNEG, SL, VT, N1));
16213     }
16214   }
16215 
16216   // fold (fsub x, (fpext (fmul y, z)))
16217   //   -> (fma (fneg (fpext y)), (fpext z), x)
16218   // Note: Commutes FSUB operands.
16219   if (matcher.match(N1, ISD::FP_EXTEND)) {
16220     SDValue N10 = N1.getOperand(0);
16221     if (isContractableFMUL(N10) &&
16222         TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
16223                             N10.getValueType())) {
16224       return matcher.getNode(
16225           PreferredFusedOpcode, SL, VT,
16226           matcher.getNode(
16227               ISD::FNEG, SL, VT,
16228               matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(0))),
16229           matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(1)), N0);
16230     }
16231   }
16232 
16233   // fold (fsub (fpext (fneg (fmul x, y))), z)
16234   //   -> (fneg (fma (fpext x), (fpext y), z))
16235   // Note: This could be removed with appropriate canonicalization of the
16236   // input expression into (fneg (fadd (fpext (fmul x, y)), z)). However, the
16237   // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent us
16238   // from implementing the canonicalization in visitFSUB.
16239   if (matcher.match(N0, ISD::FP_EXTEND)) {
16240     SDValue N00 = N0.getOperand(0);
16241     if (matcher.match(N00, ISD::FNEG)) {
16242       SDValue N000 = N00.getOperand(0);
16243       if (isContractableFMUL(N000) &&
16244           TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
16245                               N00.getValueType())) {
16246         return matcher.getNode(
16247             ISD::FNEG, SL, VT,
16248             matcher.getNode(
16249                 PreferredFusedOpcode, SL, VT,
16250                 matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(0)),
16251                 matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(1)),
16252                 N1));
16253       }
16254     }
16255   }
16256 
16257   // fold (fsub (fneg (fpext (fmul x, y))), z)
16258   //   -> (fneg (fma (fpext x), (fpext y), z))
16259   // Note: This could be removed with appropriate canonicalization of the
16260   // input expression into (fneg (fadd (fpext (fmul x, y)), z)). However, the
16261   // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent us
16262   // from implementing the canonicalization in visitFSUB.
16263   if (matcher.match(N0, ISD::FNEG)) {
16264     SDValue N00 = N0.getOperand(0);
16265     if (matcher.match(N00, ISD::FP_EXTEND)) {
16266       SDValue N000 = N00.getOperand(0);
16267       if (isContractableFMUL(N000) &&
16268           TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
16269                               N000.getValueType())) {
16270         return matcher.getNode(
16271             ISD::FNEG, SL, VT,
16272             matcher.getNode(
16273                 PreferredFusedOpcode, SL, VT,
16274                 matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(0)),
16275                 matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(1)),
16276                 N1));
16277       }
16278     }
16279   }
16280 
16281   auto isReassociable = [&Options](SDNode *N) {
16282     return Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();
16283   };
16284 
16285   auto isContractableAndReassociableFMUL = [&isContractableFMUL,
16286                                             &isReassociable](SDValue N) {
16287     return isContractableFMUL(N) && isReassociable(N.getNode());
16288   };
16289 
16290   auto isFusedOp = [&](SDValue N) {
16291     return matcher.match(N, ISD::FMA) || matcher.match(N, ISD::FMAD);
16292   };
16293 
16294   // More folding opportunities when target permits.
16295   if (Aggressive && isReassociable(N)) {
16296     bool CanFuse = Options.UnsafeFPMath || N->getFlags().hasAllowContract();
16297     // fold (fsub (fma x, y, (fmul u, v)), z)
16298     //   -> (fma x, y, (fma u, v, (fneg z)))
16299     if (CanFuse && isFusedOp(N0) &&
16300         isContractableAndReassociableFMUL(N0.getOperand(2)) &&
16301         N0->hasOneUse() && N0.getOperand(2)->hasOneUse()) {
16302       return matcher.getNode(
16303           PreferredFusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1),
16304           matcher.getNode(PreferredFusedOpcode, SL, VT,
16305                           N0.getOperand(2).getOperand(0),
16306                           N0.getOperand(2).getOperand(1),
16307                           matcher.getNode(ISD::FNEG, SL, VT, N1)));
16308     }
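    // Note: the one-use checks above keep this rewrite from duplicating the
    // inner fma/fmul when those nodes have other users.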
16309 
16310     // fold (fsub x, (fma y, z, (fmul u, v)))
16311     //   -> (fma (fneg y), z, (fma (fneg u), v, x))
16312     if (CanFuse && isFusedOp(N1) &&
16313         isContractableAndReassociableFMUL(N1.getOperand(2)) &&
16314         N1->hasOneUse() && NoSignedZero) {
16315       SDValue N20 = N1.getOperand(2).getOperand(0);
16316       SDValue N21 = N1.getOperand(2).getOperand(1);
16317       return matcher.getNode(
16318           PreferredFusedOpcode, SL, VT,
16319           matcher.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)),
16320           N1.getOperand(1),
16321           matcher.getNode(PreferredFusedOpcode, SL, VT,
16322                           matcher.getNode(ISD::FNEG, SL, VT, N20), N21, N0));
16323     }
16324 
16325     // fold (fsub (fma x, y, (fpext (fmul u, v))), z)
16326     //   -> (fma x, y, (fma (fpext u), (fpext v), (fneg z)))
16327     if (isFusedOp(N0) && N0->hasOneUse()) {
16328       SDValue N02 = N0.getOperand(2);
16329       if (matcher.match(N02, ISD::FP_EXTEND)) {
16330         SDValue N020 = N02.getOperand(0);
16331         if (isContractableAndReassociableFMUL(N020) &&
16332             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
16333                                 N020.getValueType())) {
16334           return matcher.getNode(
16335               PreferredFusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1),
16336               matcher.getNode(
16337                   PreferredFusedOpcode, SL, VT,
16338                   matcher.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(0)),
16339                   matcher.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(1)),
16340                   matcher.getNode(ISD::FNEG, SL, VT, N1)));
16341         }
16342       }
16343     }
16344 
16345     // fold (fsub (fpext (fma x, y, (fmul u, v))), z)
16346     //   -> (fma (fpext x), (fpext y),
16347     //           (fma (fpext u), (fpext v), (fneg z)))
16348     // FIXME: This turns two single-precision and one double-precision
16349     // operation into two double-precision operations, which might not be
16350     // interesting for all targets, especially GPUs.
16351     if (matcher.match(N0, ISD::FP_EXTEND)) {
16352       SDValue N00 = N0.getOperand(0);
16353       if (isFusedOp(N00)) {
16354         SDValue N002 = N00.getOperand(2);
16355         if (isContractableAndReassociableFMUL(N002) &&
16356             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
16357                                 N00.getValueType())) {
16358           return matcher.getNode(
16359               PreferredFusedOpcode, SL, VT,
16360               matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)),
16361               matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)),
16362               matcher.getNode(
16363                   PreferredFusedOpcode, SL, VT,
16364                   matcher.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(0)),
16365                   matcher.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(1)),
16366                   matcher.getNode(ISD::FNEG, SL, VT, N1)));
16367         }
16368       }
16369     }
16370 
16371     // fold (fsub x, (fma y, z, (fpext (fmul u, v))))
16372     //   -> (fma (fneg y), z, (fma (fneg (fpext u)), (fpext v), x))
16373     if (isFusedOp(N1) && matcher.match(N1.getOperand(2), ISD::FP_EXTEND) &&
16374         N1->hasOneUse()) {
16375       SDValue N120 = N1.getOperand(2).getOperand(0);
16376       if (isContractableAndReassociableFMUL(N120) &&
16377           TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
16378                               N120.getValueType())) {
16379         SDValue N1200 = N120.getOperand(0);
16380         SDValue N1201 = N120.getOperand(1);
16381         return matcher.getNode(
16382             PreferredFusedOpcode, SL, VT,
16383             matcher.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)),
16384             N1.getOperand(1),
16385             matcher.getNode(
16386                 PreferredFusedOpcode, SL, VT,
16387                 matcher.getNode(ISD::FNEG, SL, VT,
16388                                 matcher.getNode(ISD::FP_EXTEND, SL, VT, N1200)),
16389                 matcher.getNode(ISD::FP_EXTEND, SL, VT, N1201), N0));
16390       }
16391     }
16392 
16393     // fold (fsub x, (fpext (fma y, z, (fmul u, v))))
16394     //   -> (fma (fneg (fpext y)), (fpext z),
16395     //           (fma (fneg (fpext u)), (fpext v), x))
16396     // FIXME: This turns two single-precision and one double-precision
16397     // operation into two double-precision operations, which might not be
16398     // interesting for all targets, especially GPUs.
16399     if (matcher.match(N1, ISD::FP_EXTEND) && isFusedOp(N1.getOperand(0))) {
16400       SDValue CvtSrc = N1.getOperand(0);
16401       SDValue N100 = CvtSrc.getOperand(0);
16402       SDValue N101 = CvtSrc.getOperand(1);
16403       SDValue N102 = CvtSrc.getOperand(2);
16404       if (isContractableAndReassociableFMUL(N102) &&
16405           TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
16406                               CvtSrc.getValueType())) {
16407         SDValue N1020 = N102.getOperand(0);
16408         SDValue N1021 = N102.getOperand(1);
16409         return matcher.getNode(
16410             PreferredFusedOpcode, SL, VT,
16411             matcher.getNode(ISD::FNEG, SL, VT,
16412                             matcher.getNode(ISD::FP_EXTEND, SL, VT, N100)),
16413             matcher.getNode(ISD::FP_EXTEND, SL, VT, N101),
16414             matcher.getNode(
16415                 PreferredFusedOpcode, SL, VT,
16416                 matcher.getNode(ISD::FNEG, SL, VT,
16417                                 matcher.getNode(ISD::FP_EXTEND, SL, VT, N1020)),
16418                 matcher.getNode(ISD::FP_EXTEND, SL, VT, N1021), N0));
16419       }
16420     }
16421   }
16422 
16423   return SDValue();
16424 }
16425 
16426 /// Try to perform FMA combining on a given FMUL node based on the distributive
16427 /// law x * (y + 1) = x * y + x and variants thereof (commuted versions,
16428 /// subtraction instead of addition).
16429 SDValue DAGCombiner::visitFMULForFMADistributiveCombine(SDNode *N) {
16430   SDValue N0 = N->getOperand(0);
16431   SDValue N1 = N->getOperand(1);
16432   EVT VT = N->getValueType(0);
16433   SDLoc SL(N);
16434 
16435   assert(N->getOpcode() == ISD::FMUL && "Expected FMUL Operation");
16436 
16437   const TargetOptions &Options = DAG.getTarget().Options;
16438 
16439   // The transforms below are incorrect when x == 0 and y == inf, because the
16440   // intermediate multiplication produces a NaN.
16441   SDValue FAdd = N0.getOpcode() == ISD::FADD ? N0 : N1;
16442   if (!hasNoInfs(Options, FAdd))
16443     return SDValue();
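  // Illustrative hazard: with x == 0.0 and y == inf, (x + 1.0) * y == inf,
  // but the fused form fma(x, y, y) computes 0.0 * inf == NaN, and
  // NaN + inf == NaN.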
16444 
16445   // Floating-point multiply-add without intermediate rounding.
16446   bool HasFMA =
16447       isContractableFMUL(Options, SDValue(N, 0)) &&
16448       TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) &&
16449       (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));
16450 
16451   // Floating-point multiply-add with intermediate rounding. This can result
16452   // in a less precise result due to the changed rounding order.
16453   bool HasFMAD = Options.UnsafeFPMath &&
16454                  (LegalOperations && TLI.isFMADLegal(DAG, N));
16455 
16456   // No valid opcode, do not combine.
16457   if (!HasFMAD && !HasFMA)
16458     return SDValue();
16459 
16460   // Always prefer FMAD to FMA for precision.
16461   unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
16462   bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
16463 
16464   // fold (fmul (fadd x0, +1.0), y) -> (fma x0, y, y)
16465   // fold (fmul (fadd x0, -1.0), y) -> (fma x0, y, (fneg y))
16466   auto FuseFADD = [&](SDValue X, SDValue Y) {
16467     if (X.getOpcode() == ISD::FADD && (Aggressive || X->hasOneUse())) {
16468       if (auto *C = isConstOrConstSplatFP(X.getOperand(1), true)) {
16469         if (C->isExactlyValue(+1.0))
16470           return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
16471                              Y);
16472         if (C->isExactlyValue(-1.0))
16473           return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
16474                              DAG.getNode(ISD::FNEG, SL, VT, Y));
16475       }
16476     }
16477     return SDValue();
16478   };
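  // Quick algebra check (illustrative): with x0 == 3.0 and y == 2.0,
  // (3.0 + 1.0) * 2.0 == 8.0 and fma(3.0, 2.0, 2.0) == 3.0 * 2.0 + 2.0 == 8.0.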
16479 
16480   if (SDValue FMA = FuseFADD(N0, N1))
16481     return FMA;
16482   if (SDValue FMA = FuseFADD(N1, N0))
16483     return FMA;
16484 
16485   // fold (fmul (fsub +1.0, x1), y) -> (fma (fneg x1), y, y)
16486   // fold (fmul (fsub -1.0, x1), y) -> (fma (fneg x1), y, (fneg y))
16487   // fold (fmul (fsub x0, +1.0), y) -> (fma x0, y, (fneg y))
16488   // fold (fmul (fsub x0, -1.0), y) -> (fma x0, y, y)
16489   auto FuseFSUB = [&](SDValue X, SDValue Y) {
16490     if (X.getOpcode() == ISD::FSUB && (Aggressive || X->hasOneUse())) {
16491       if (auto *C0 = isConstOrConstSplatFP(X.getOperand(0), true)) {
16492         if (C0->isExactlyValue(+1.0))
16493           return DAG.getNode(PreferredFusedOpcode, SL, VT,
16494                              DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y,
16495                              Y);
16496         if (C0->isExactlyValue(-1.0))
16497           return DAG.getNode(PreferredFusedOpcode, SL, VT,
16498                              DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y,
16499                              DAG.getNode(ISD::FNEG, SL, VT, Y));
16500       }
16501       if (auto *C1 = isConstOrConstSplatFP(X.getOperand(1), true)) {
16502         if (C1->isExactlyValue(+1.0))
16503           return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
16504                              DAG.getNode(ISD::FNEG, SL, VT, Y));
16505         if (C1->isExactlyValue(-1.0))
16506           return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
16507                              Y);
16508       }
16509     }
16510     return SDValue();
16511   };
16512 
16513   if (SDValue FMA = FuseFSUB(N0, N1))
16514     return FMA;
16515   if (SDValue FMA = FuseFSUB(N1, N0))
16516     return FMA;
16517 
16518   return SDValue();
16519 }
16520 
16521 SDValue DAGCombiner::visitVP_FADD(SDNode *N) {
16522   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
16523 
16524   // FADD -> FMA combines:
16525   if (SDValue Fused = visitFADDForFMACombine<VPMatchContext>(N)) {
16526     if (Fused.getOpcode() != ISD::DELETED_NODE)
16527       AddToWorklist(Fused.getNode());
16528     return Fused;
16529   }
16530   return SDValue();
16531 }
16532 
16533 SDValue DAGCombiner::visitFADD(SDNode *N) {
16534   SDValue N0 = N->getOperand(0);
16535   SDValue N1 = N->getOperand(1);
16536   SDNode *N0CFP = DAG.isConstantFPBuildVectorOrConstantFP(N0);
16537   SDNode *N1CFP = DAG.isConstantFPBuildVectorOrConstantFP(N1);
16538   EVT VT = N->getValueType(0);
16539   SDLoc DL(N);
16540   const TargetOptions &Options = DAG.getTarget().Options;
16541   SDNodeFlags Flags = N->getFlags();
16542   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
16543 
16544   if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
16545     return R;
16546 
16547   // fold (fadd c1, c2) -> c1 + c2
16548   if (SDValue C = DAG.FoldConstantArithmetic(ISD::FADD, DL, VT, {N0, N1}))
16549     return C;
16550 
16551   // canonicalize constant to RHS
16552   if (N0CFP && !N1CFP)
16553     return DAG.getNode(ISD::FADD, DL, VT, N1, N0);
16554 
16555   // fold vector ops
16556   if (VT.isVector())
16557     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
16558       return FoldedVOp;
16559 
16560   // N0 + -0.0 --> N0 (also allowed with +0.0 and fast-math)
16561   ConstantFPSDNode *N1C = isConstOrConstSplatFP(N1, true);
16562   if (N1C && N1C->isZero())
16563     if (N1C->isNegative() || Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros())
16564       return N0;
16565 
16566   if (SDValue NewSel = foldBinOpIntoSelect(N))
16567     return NewSel;
16568 
16569   // fold (fadd A, (fneg B)) -> (fsub A, B)
16570   if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT))
16571     if (SDValue NegN1 = TLI.getCheaperNegatedExpression(
16572             N1, DAG, LegalOperations, ForCodeSize))
16573       return DAG.getNode(ISD::FSUB, DL, VT, N0, NegN1);
16574 
16575   // fold (fadd (fneg A), B) -> (fsub B, A)
16576   if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT))
16577     if (SDValue NegN0 = TLI.getCheaperNegatedExpression(
16578             N0, DAG, LegalOperations, ForCodeSize))
16579       return DAG.getNode(ISD::FSUB, DL, VT, N1, NegN0);
16580 
16581   auto isFMulNegTwo = [](SDValue FMul) {
16582     if (!FMul.hasOneUse() || FMul.getOpcode() != ISD::FMUL)
16583       return false;
16584     auto *C = isConstOrConstSplatFP(FMul.getOperand(1), true);
16585     return C && C->isExactlyValue(-2.0);
16586   };
16587 
16588   // fadd (fmul B, -2.0), A --> fsub A, (fadd B, B)
16589   if (isFMulNegTwo(N0)) {
16590     SDValue B = N0.getOperand(0);
16591     SDValue Add = DAG.getNode(ISD::FADD, DL, VT, B, B);
16592     return DAG.getNode(ISD::FSUB, DL, VT, N1, Add);
16593   }
16594   // fadd A, (fmul B, -2.0) --> fsub A, (fadd B, B)
16595   if (isFMulNegTwo(N1)) {
16596     SDValue B = N1.getOperand(0);
16597     SDValue Add = DAG.getNode(ISD::FADD, DL, VT, B, B);
16598     return DAG.getNode(ISD::FSUB, DL, VT, N0, Add);
16599   }
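  // Both rewrites rely on A + (B * -2.0) == A - (B + B), which trades the
  // multiply for an add and avoids materializing the -2.0 constant.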
16600 
16601   // No FP constant should be created after legalization, as the instruction
16602   // selection pass has a hard time dealing with FP constants.
16603   bool AllowNewConst = (Level < AfterLegalizeDAG);
16604 
16605   // If nnan is enabled, fold lots of things.
16606   if ((Options.NoNaNsFPMath || Flags.hasNoNaNs()) && AllowNewConst) {
16607     // If allowed, fold (fadd (fneg x), x) -> 0.0
16608     if (N0.getOpcode() == ISD::FNEG && N0.getOperand(0) == N1)
16609       return DAG.getConstantFP(0.0, DL, VT);
16610 
16611     // If allowed, fold (fadd x, (fneg x)) -> 0.0
16612     if (N1.getOpcode() == ISD::FNEG && N1.getOperand(0) == N0)
16613       return DAG.getConstantFP(0.0, DL, VT);
16614   }
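  // Without nnan these folds are unsound: for x == NaN or x == inf,
  // (-x) + x is NaN rather than 0.0.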
16615 
16616   // If 'unsafe math' or reassoc and nsz, fold lots of things.
16617   // TODO: break out portions of the transformations below for which Unsafe is
16618   //       considered and which do not require both nsz and reassoc
16619   if (((Options.UnsafeFPMath && Options.NoSignedZerosFPMath) ||
16620        (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros())) &&
16621       AllowNewConst) {
16622     // fadd (fadd x, c1), c2 -> fadd x, c1 + c2
16623     if (N1CFP && N0.getOpcode() == ISD::FADD &&
16624         DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(1))) {
16625       SDValue NewC = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1), N1);
16626       return DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(0), NewC);
16627     }
16628 
16629     // We can fold chains of FADD's of the same value into multiplications.
16630     // This transform is not safe in general because we are reducing the number
16631     // of rounding steps.
16632     if (TLI.isOperationLegalOrCustom(ISD::FMUL, VT) && !N0CFP && !N1CFP) {
16633       if (N0.getOpcode() == ISD::FMUL) {
16634         SDNode *CFP00 =
16635             DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(0));
16636         SDNode *CFP01 =
16637             DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(1));
16638 
16639         // (fadd (fmul x, c), x) -> (fmul x, c+1)
16640         if (CFP01 && !CFP00 && N0.getOperand(0) == N1) {
16641           SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1),
16642                                        DAG.getConstantFP(1.0, DL, VT));
16643           return DAG.getNode(ISD::FMUL, DL, VT, N1, NewCFP);
16644         }
16645 
16646         // (fadd (fmul x, c), (fadd x, x)) -> (fmul x, c+2)
16647         if (CFP01 && !CFP00 && N1.getOpcode() == ISD::FADD &&
16648             N1.getOperand(0) == N1.getOperand(1) &&
16649             N0.getOperand(0) == N1.getOperand(0)) {
16650           SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1),
16651                                        DAG.getConstantFP(2.0, DL, VT));
16652           return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), NewCFP);
16653         }
16654       }
16655 
16656       if (N1.getOpcode() == ISD::FMUL) {
16657         SDNode *CFP10 =
16658             DAG.isConstantFPBuildVectorOrConstantFP(N1.getOperand(0));
16659         SDNode *CFP11 =
16660             DAG.isConstantFPBuildVectorOrConstantFP(N1.getOperand(1));
16661 
16662         // (fadd x, (fmul x, c)) -> (fmul x, c+1)
16663         if (CFP11 && !CFP10 && N1.getOperand(0) == N0) {
16664           SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N1.getOperand(1),
16665                                        DAG.getConstantFP(1.0, DL, VT));
16666           return DAG.getNode(ISD::FMUL, DL, VT, N0, NewCFP);
16667         }
16668 
16669         // (fadd (fadd x, x), (fmul x, c)) -> (fmul x, c+2)
16670         if (CFP11 && !CFP10 && N0.getOpcode() == ISD::FADD &&
16671             N0.getOperand(0) == N0.getOperand(1) &&
16672             N1.getOperand(0) == N0.getOperand(0)) {
16673           SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N1.getOperand(1),
16674                                        DAG.getConstantFP(2.0, DL, VT));
16675           return DAG.getNode(ISD::FMUL, DL, VT, N1.getOperand(0), NewCFP);
16676         }
16677       }
16678 
16679       if (N0.getOpcode() == ISD::FADD) {
16680         SDNode *CFP00 =
16681             DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(0));
16682         // (fadd (fadd x, x), x) -> (fmul x, 3.0)
16683         if (!CFP00 && N0.getOperand(0) == N0.getOperand(1) &&
16684             (N0.getOperand(0) == N1)) {
16685           return DAG.getNode(ISD::FMUL, DL, VT, N1,
16686                              DAG.getConstantFP(3.0, DL, VT));
16687         }
16688       }
16689 
16690       if (N1.getOpcode() == ISD::FADD) {
16691         SDNode *CFP10 =
16692             DAG.isConstantFPBuildVectorOrConstantFP(N1.getOperand(0));
16693         // (fadd x, (fadd x, x)) -> (fmul x, 3.0)
16694         if (!CFP10 && N1.getOperand(0) == N1.getOperand(1) &&
16695             N1.getOperand(0) == N0) {
16696           return DAG.getNode(ISD::FMUL, DL, VT, N0,
16697                              DAG.getConstantFP(3.0, DL, VT));
16698         }
16699       }
16700 
16701       // (fadd (fadd x, x), (fadd x, x)) -> (fmul x, 4.0)
16702       if (N0.getOpcode() == ISD::FADD && N1.getOpcode() == ISD::FADD &&
16703           N0.getOperand(0) == N0.getOperand(1) &&
16704           N1.getOperand(0) == N1.getOperand(1) &&
16705           N0.getOperand(0) == N1.getOperand(0)) {
16706         return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0),
16707                            DAG.getConstantFP(4.0, DL, VT));
16708       }
16709     }
16710 
16711     // Fold fadd(vecreduce(x), vecreduce(y)) -> vecreduce(fadd(x, y))
16712     if (SDValue SD = reassociateReduction(ISD::VECREDUCE_FADD, ISD::FADD, DL,
16713                                           VT, N0, N1, Flags))
16714       return SD;
16715   } // enable-unsafe-fp-math
16716 
16717   // FADD -> FMA combines:
16718   if (SDValue Fused = visitFADDForFMACombine<EmptyMatchContext>(N)) {
16719     if (Fused.getOpcode() != ISD::DELETED_NODE)
16720       AddToWorklist(Fused.getNode());
16721     return Fused;
16722   }
16723   return SDValue();
16724 }
16725 
16726 SDValue DAGCombiner::visitSTRICT_FADD(SDNode *N) {
16727   SDValue Chain = N->getOperand(0);
16728   SDValue N0 = N->getOperand(1);
16729   SDValue N1 = N->getOperand(2);
16730   EVT VT = N->getValueType(0);
16731   EVT ChainVT = N->getValueType(1);
16732   SDLoc DL(N);
16733   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
16734 
16735   // fold (strict_fadd A, (fneg B)) -> (strict_fsub A, B)
16736   if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::STRICT_FSUB, VT))
16737     if (SDValue NegN1 = TLI.getCheaperNegatedExpression(
16738             N1, DAG, LegalOperations, ForCodeSize)) {
16739       return DAG.getNode(ISD::STRICT_FSUB, DL, DAG.getVTList(VT, ChainVT),
16740                          {Chain, N0, NegN1});
16741     }
16742 
16743   // fold (strict_fadd (fneg A), B) -> (strict_fsub B, A)
16744   if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::STRICT_FSUB, VT))
16745     if (SDValue NegN0 = TLI.getCheaperNegatedExpression(
16746             N0, DAG, LegalOperations, ForCodeSize)) {
16747       return DAG.getNode(ISD::STRICT_FSUB, DL, DAG.getVTList(VT, ChainVT),
16748                          {Chain, N1, NegN0});
16749     }
16750   return SDValue();
16751 }
16752 
16753 SDValue DAGCombiner::visitFSUB(SDNode *N) {
16754   SDValue N0 = N->getOperand(0);
16755   SDValue N1 = N->getOperand(1);
16756   ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0, true);
16757   ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1, true);
16758   EVT VT = N->getValueType(0);
16759   SDLoc DL(N);
16760   const TargetOptions &Options = DAG.getTarget().Options;
16761   const SDNodeFlags Flags = N->getFlags();
16762   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
16763 
16764   if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
16765     return R;
16766 
16767   // fold (fsub c1, c2) -> c1-c2
16768   if (SDValue C = DAG.FoldConstantArithmetic(ISD::FSUB, DL, VT, {N0, N1}))
16769     return C;
16770 
16771   // fold vector ops
16772   if (VT.isVector())
16773     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
16774       return FoldedVOp;
16775 
16776   if (SDValue NewSel = foldBinOpIntoSelect(N))
16777     return NewSel;
16778 
16779   // (fsub A, 0) -> A
16780   if (N1CFP && N1CFP->isZero()) {
16781     if (!N1CFP->isNegative() || Options.NoSignedZerosFPMath ||
16782         Flags.hasNoSignedZeros()) {
16783       return N0;
16784     }
16785   }
16786 
16787   if (N0 == N1) {
16788     // (fsub x, x) -> 0.0
16789     if (Options.NoNaNsFPMath || Flags.hasNoNaNs())
16790       return DAG.getConstantFP(0.0f, DL, VT);
16791   }
16792 
16793   // (fsub -0.0, N1) -> -N1
16794   if (N0CFP && N0CFP->isZero()) {
16795     if (N0CFP->isNegative() ||
16796         (Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros())) {
16797       // We cannot replace an FSUB(+-0.0,X) with FNEG(X) when denormals are
16798       // flushed to zero, unless all users treat denorms as zero (DAZ).
16799       // FIXME: This transform will change the sign of a NaN and the behavior
16800       // of a signaling NaN. It is only valid when a NoNaN flag is present.
16801       DenormalMode DenormMode = DAG.getDenormalMode(VT);
16802       if (DenormMode == DenormalMode::getIEEE()) {
16803         if (SDValue NegN1 =
16804                 TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize))
16805           return NegN1;
16806         if (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))
16807           return DAG.getNode(ISD::FNEG, DL, VT, N1);
16808       }
16809     }
16810   }
16811 
16812   if (((Options.UnsafeFPMath && Options.NoSignedZerosFPMath) ||
16813        (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros())) &&
16814       N1.getOpcode() == ISD::FADD) {
16815     // X - (X + Y) -> -Y
16816     if (N0 == N1->getOperand(0))
16817       return DAG.getNode(ISD::FNEG, DL, VT, N1->getOperand(1));
16818     // X - (Y + X) -> -Y
16819     if (N0 == N1->getOperand(1))
16820       return DAG.getNode(ISD::FNEG, DL, VT, N1->getOperand(0));
16821   }
16822 
16823   // fold (fsub A, (fneg B)) -> (fadd A, B)
16824   if (SDValue NegN1 =
16825           TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize))
16826     return DAG.getNode(ISD::FADD, DL, VT, N0, NegN1);
16827 
16828   // FSUB -> FMA combines:
16829   if (SDValue Fused = visitFSUBForFMACombine<EmptyMatchContext>(N)) {
16830     AddToWorklist(Fused.getNode());
16831     return Fused;
16832   }
16833 
16834   return SDValue();
16835 }
16836 
16837 // Transform IEEE Floats:
16838 //      (fmul C, (uitofp Pow2))
16839 //          -> (bitcast_to_FP (add (bitcast_to_INT C), Log2(Pow2) << mantissa))
16840 //      (fdiv C, (uitofp Pow2))
16841 //          -> (bitcast_to_FP (sub (bitcast_to_INT C), Log2(Pow2) << mantissa))
16842 //
16843 // The rationale is that fmul/fdiv by a power of 2 just changes the exponent, so
16844 // there is no need for more than an add/sub.
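// Worked f32 example (illustrative): bitcast(1.0f) == 0x3F800000, and adding
// log2(2) << 23 == 1 << 23 gives 0x40000000 == bitcast(2.0f), i.e. 1.0f * 2.0f.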
16845 //
16846 // This is valid under the following circumstances:
16847 // 1) We are dealing with IEEE floats
16848 // 2) C is normal
16849 // 3) The fmul/fdiv add/sub will not go outside of min/max exponent bounds.
16850 // TODO: Much of this could also be used for generating `ldexp` on targets that
16851 // prefer it.
16852 SDValue DAGCombiner::combineFMulOrFDivWithIntPow2(SDNode *N) {
16853   EVT VT = N->getValueType(0);
16854   SDValue ConstOp, Pow2Op;
16855 
16856   std::optional<int> Mantissa;
16857   auto GetConstAndPow2Ops = [&](unsigned ConstOpIdx) {
16858     if (ConstOpIdx == 1 && N->getOpcode() == ISD::FDIV)
16859       return false;
16860 
16861     ConstOp = peekThroughBitcasts(N->getOperand(ConstOpIdx));
16862     Pow2Op = N->getOperand(1 - ConstOpIdx);
16863     if (Pow2Op.getOpcode() != ISD::UINT_TO_FP &&
16864         (Pow2Op.getOpcode() != ISD::SINT_TO_FP ||
16865          !DAG.computeKnownBits(Pow2Op).isNonNegative()))
16866       return false;
16867 
16868     Pow2Op = Pow2Op.getOperand(0);
16869 
16870     // `Log2(Pow2Op) < Pow2Op.getScalarSizeInBits()`.
16871     // TODO: We could use knownbits to make this bound more precise.
16872     int MaxExpChange = Pow2Op.getValueType().getScalarSizeInBits();
16873 
16874     auto IsFPConstValid = [N, MaxExpChange, &Mantissa](ConstantFPSDNode *CFP) {
16875       if (CFP == nullptr)
16876         return false;
16877 
16878       const APFloat &APF = CFP->getValueAPF();
16879 
16880       // Make sure we have a normal IEEE constant.
16881       if (!APF.isNormal() || !APF.isIEEE())
16882         return false;
16883 
16884       // Make sure the float's exponent is within the bounds for which this
16885       // transform produces a bitwise-equal value.
16886       int CurExp = ilogb(APF);
16887       // FMul by pow2 will only increase exponent.
16888       int MinExp =
16889           N->getOpcode() == ISD::FMUL ? CurExp : (CurExp - MaxExpChange);
16890       // FDiv by pow2 will only decrease exponent.
16891       int MaxExp =
16892           N->getOpcode() == ISD::FDIV ? CurExp : (CurExp + MaxExpChange);
16893       if (MinExp <= APFloat::semanticsMinExponent(APF.getSemantics()) ||
16894           MaxExp >= APFloat::semanticsMaxExponent(APF.getSemantics()))
16895         return false;
16896 
16897       // Finally make sure we actually know the mantissa for the float type.
16898       int ThisMantissa = APFloat::semanticsPrecision(APF.getSemantics()) - 1;
16899       if (!Mantissa)
16900         Mantissa = ThisMantissa;
16901 
16902       return *Mantissa == ThisMantissa && ThisMantissa > 0;
16903     };
16904 
16905     // TODO: We may be able to include undefs.
16906     return ISD::matchUnaryFpPredicate(ConstOp, IsFPConstValid);
16907   };
16908 
16909   if (!GetConstAndPow2Ops(0) && !GetConstAndPow2Ops(1))
16910     return SDValue();
16911 
16912   if (!TLI.optimizeFMulOrFDivAsShiftAddBitcast(N, ConstOp, Pow2Op))
16913     return SDValue();
16914 
16915   // Get log2 after all other checks have taken place. This is because
16916   // BuildLogBase2 may create a new node.
16917   SDLoc DL(N);
16918   // Get Log2 type with same bitwidth as the float type (VT).
16919   EVT NewIntVT = EVT::getIntegerVT(*DAG.getContext(), VT.getScalarSizeInBits());
16920   if (VT.isVector())
16921     NewIntVT = EVT::getVectorVT(*DAG.getContext(), NewIntVT,
16922                                 VT.getVectorElementCount());
16923 
16924   SDValue Log2 = BuildLogBase2(Pow2Op, DL, DAG.isKnownNeverZero(Pow2Op),
16925                                /*InexpensiveOnly*/ true, NewIntVT);
16926   if (!Log2)
16927     return SDValue();
16928 
16929   // Perform actual transform.
16930   SDValue MantissaShiftCnt =
16931       DAG.getShiftAmountConstant(*Mantissa, NewIntVT, DL);
16932   // TODO: Sometimes Log2 is of the form `(X + C)`. `(X + C) << C1` should fold
16933   // to `(X << C1) + (C << C1)`, but that isn't always the case because of the
16934   // cast. We could implement that fold here by handling the casts.
16935   SDValue Shift = DAG.getNode(ISD::SHL, DL, NewIntVT, Log2, MantissaShiftCnt);
16936   SDValue ResAsInt =
16937       DAG.getNode(N->getOpcode() == ISD::FMUL ? ISD::ADD : ISD::SUB, DL,
16938                   NewIntVT, DAG.getBitcast(NewIntVT, ConstOp), Shift);
16939   SDValue ResAsFP = DAG.getBitcast(VT, ResAsInt);
16940   return ResAsFP;
16941 }
16942 
16943 SDValue DAGCombiner::visitFMUL(SDNode *N) {
16944   SDValue N0 = N->getOperand(0);
16945   SDValue N1 = N->getOperand(1);
16946   ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1, true);
16947   EVT VT = N->getValueType(0);
16948   SDLoc DL(N);
16949   const TargetOptions &Options = DAG.getTarget().Options;
16950   const SDNodeFlags Flags = N->getFlags();
16951   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
16952 
16953   if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
16954     return R;
16955 
16956   // fold (fmul c1, c2) -> c1*c2
16957   if (SDValue C = DAG.FoldConstantArithmetic(ISD::FMUL, DL, VT, {N0, N1}))
16958     return C;
16959 
16960   // canonicalize constant to RHS
16961   if (DAG.isConstantFPBuildVectorOrConstantFP(N0) &&
16962      !DAG.isConstantFPBuildVectorOrConstantFP(N1))
16963     return DAG.getNode(ISD::FMUL, DL, VT, N1, N0);
16964 
16965   // fold vector ops
16966   if (VT.isVector())
16967     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
16968       return FoldedVOp;
16969 
16970   if (SDValue NewSel = foldBinOpIntoSelect(N))
16971     return NewSel;
16972 
16973   if (Options.UnsafeFPMath || Flags.hasAllowReassociation()) {
16974     // fmul (fmul X, C1), C2 -> fmul X, C1 * C2
16975     if (DAG.isConstantFPBuildVectorOrConstantFP(N1) &&
16976         N0.getOpcode() == ISD::FMUL) {
16977       SDValue N00 = N0.getOperand(0);
16978       SDValue N01 = N0.getOperand(1);
16979       // Avoid an infinite loop by making sure that N00 is not a constant
16980       // (the inner multiply has not been constant folded yet).
16981       if (DAG.isConstantFPBuildVectorOrConstantFP(N01) &&
16982           !DAG.isConstantFPBuildVectorOrConstantFP(N00)) {
16983         SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, N01, N1);
16984         return DAG.getNode(ISD::FMUL, DL, VT, N00, MulConsts);
16985       }
16986     }
16987 
16988     // Match a special case: we convert X * 2.0 into fadd.
16989     // fmul (fadd X, X), C -> fmul X, 2.0 * C
16990     if (N0.getOpcode() == ISD::FADD && N0.hasOneUse() &&
16991         N0.getOperand(0) == N0.getOperand(1)) {
16992       const SDValue Two = DAG.getConstantFP(2.0, DL, VT);
16993       SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, Two, N1);
16994       return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), MulConsts);
16995     }
16996 
16997     // Fold fmul(vecreduce(x), vecreduce(y)) -> vecreduce(fmul(x, y))
16998     if (SDValue SD = reassociateReduction(ISD::VECREDUCE_FMUL, ISD::FMUL, DL,
16999                                           VT, N0, N1, Flags))
17000       return SD;
17001   }
17002 
17003   // fold (fmul X, 2.0) -> (fadd X, X)
17004   if (N1CFP && N1CFP->isExactlyValue(+2.0))
17005     return DAG.getNode(ISD::FADD, DL, VT, N0, N0);
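  // No FMF are required above: X * 2.0 == X + X exactly in IEEE arithmetic,
  // including for NaN, infinities, and signed zeros.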
17006 
17007   // fold (fmul X, -1.0) -> (fsub -0.0, X)
17008   if (N1CFP && N1CFP->isExactlyValue(-1.0)) {
17009     if (!LegalOperations || TLI.isOperationLegal(ISD::FSUB, VT)) {
17010       return DAG.getNode(ISD::FSUB, DL, VT,
17011                          DAG.getConstantFP(-0.0, DL, VT), N0, Flags);
17012     }
17013   }
17014 
17015   // -N0 * -N1 --> N0 * N1
17016   TargetLowering::NegatibleCost CostN0 =
17017       TargetLowering::NegatibleCost::Expensive;
17018   TargetLowering::NegatibleCost CostN1 =
17019       TargetLowering::NegatibleCost::Expensive;
17020   SDValue NegN0 =
17021       TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize, CostN0);
17022   if (NegN0) {
17023     HandleSDNode NegN0Handle(NegN0);
17024     SDValue NegN1 =
17025         TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize, CostN1);
17026     if (NegN1 && (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
17027                   CostN1 == TargetLowering::NegatibleCost::Cheaper))
17028       return DAG.getNode(ISD::FMUL, DL, VT, NegN0, NegN1);
17029   }
17030 
17031   // fold (fmul X, (select (fcmp X > 0.0), -1.0, 1.0)) -> (fneg (fabs X))
17032   // fold (fmul X, (select (fcmp X > 0.0), 1.0, -1.0)) -> (fabs X)
17033   if (Flags.hasNoNaNs() && Flags.hasNoSignedZeros() &&
17034       (N0.getOpcode() == ISD::SELECT || N1.getOpcode() == ISD::SELECT) &&
17035       TLI.isOperationLegal(ISD::FABS, VT)) {
17036     SDValue Select = N0, X = N1;
17037     if (Select.getOpcode() != ISD::SELECT)
17038       std::swap(Select, X);
17039 
17040     SDValue Cond = Select.getOperand(0);
17041     auto TrueOpnd  = dyn_cast<ConstantFPSDNode>(Select.getOperand(1));
17042     auto FalseOpnd = dyn_cast<ConstantFPSDNode>(Select.getOperand(2));
17043 
17044     if (TrueOpnd && FalseOpnd &&
17045         Cond.getOpcode() == ISD::SETCC && Cond.getOperand(0) == X &&
17046         isa<ConstantFPSDNode>(Cond.getOperand(1)) &&
17047         cast<ConstantFPSDNode>(Cond.getOperand(1))->isExactlyValue(0.0)) {
17048       ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
17049       switch (CC) {
17050       default: break;
17051       case ISD::SETOLT:
17052       case ISD::SETULT:
17053       case ISD::SETOLE:
17054       case ISD::SETULE:
17055       case ISD::SETLT:
17056       case ISD::SETLE:
17057         std::swap(TrueOpnd, FalseOpnd);
17058         [[fallthrough]];
17059       case ISD::SETOGT:
17060       case ISD::SETUGT:
17061       case ISD::SETOGE:
17062       case ISD::SETUGE:
17063       case ISD::SETGT:
17064       case ISD::SETGE:
17065         if (TrueOpnd->isExactlyValue(-1.0) && FalseOpnd->isExactlyValue(1.0) &&
17066             TLI.isOperationLegal(ISD::FNEG, VT))
17067           return DAG.getNode(ISD::FNEG, DL, VT,
17068                    DAG.getNode(ISD::FABS, DL, VT, X));
17069         if (TrueOpnd->isExactlyValue(1.0) && FalseOpnd->isExactlyValue(-1.0))
17070           return DAG.getNode(ISD::FABS, DL, VT, X);
17071 
17072         break;
17073       }
17074     }
17075   }
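  // Rationale for the select folds: when X > 0.0, X * -1.0 == -fabs(X) and
  // X * 1.0 == fabs(X); when X <= 0.0, X * 1.0 == X == -fabs(X) and
  // X * -1.0 == -X == fabs(X). nnan/nsz cover the NaN and -0.0 inputs that
  // the comparison cannot classify.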
17076 
17077   // FMUL -> FMA combines:
17078   if (SDValue Fused = visitFMULForFMADistributiveCombine(N)) {
17079     AddToWorklist(Fused.getNode());
17080     return Fused;
17081   }
17082 
17083   // Don't do `combineFMulOrFDivWithIntPow2` until after FMUL -> FMA has been
17084   // able to run.
17085   if (SDValue R = combineFMulOrFDivWithIntPow2(N))
17086     return R;
17087 
17088   return SDValue();
17089 }
17090 
17091 template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
17092   SDValue N0 = N->getOperand(0);
17093   SDValue N1 = N->getOperand(1);
17094   SDValue N2 = N->getOperand(2);
17095   ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(N0);
17096   ConstantFPSDNode *N1CFP = dyn_cast<ConstantFPSDNode>(N1);
17097   EVT VT = N->getValueType(0);
17098   SDLoc DL(N);
17099   const TargetOptions &Options = DAG.getTarget().Options;
17100   // FMA nodes have flags that propagate to the created nodes.
17101   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
17102   MatchContextClass matcher(DAG, TLI, N);
17103 
17104   // Constant fold FMA.
17105   if (isa<ConstantFPSDNode>(N0) &&
17106       isa<ConstantFPSDNode>(N1) &&
17107       isa<ConstantFPSDNode>(N2)) {
17108     return matcher.getNode(ISD::FMA, DL, VT, N0, N1, N2);
17109   }
17110 
17111   // (-N0 * -N1) + N2 --> (N0 * N1) + N2
17112   TargetLowering::NegatibleCost CostN0 =
17113       TargetLowering::NegatibleCost::Expensive;
17114   TargetLowering::NegatibleCost CostN1 =
17115       TargetLowering::NegatibleCost::Expensive;
17116   SDValue NegN0 =
17117       TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize, CostN0);
17118   if (NegN0) {
17119     HandleSDNode NegN0Handle(NegN0);
17120     SDValue NegN1 =
17121         TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize, CostN1);
17122     if (NegN1 && (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
17123                   CostN1 == TargetLowering::NegatibleCost::Cheaper))
17124       return matcher.getNode(ISD::FMA, DL, VT, NegN0, NegN1, N2);
17125   }
17126 
17127   // FIXME: use fast math flags instead of Options.UnsafeFPMath
17128   if (Options.UnsafeFPMath) {
17129     if (N0CFP && N0CFP->isZero())
17130       return N2;
17131     if (N1CFP && N1CFP->isZero())
17132       return N2;
17133   }
17134 
17135   // FIXME: Support splat of constant.
17136   if (N0CFP && N0CFP->isExactlyValue(1.0))
17137     return matcher.getNode(ISD::FADD, SDLoc(N), VT, N1, N2);
17138   if (N1CFP && N1CFP->isExactlyValue(1.0))
17139     return matcher.getNode(ISD::FADD, SDLoc(N), VT, N0, N2);
17140 
17141   // Canonicalize (fma c, x, y) -> (fma x, c, y)
17142   if (DAG.isConstantFPBuildVectorOrConstantFP(N0) &&
17143      !DAG.isConstantFPBuildVectorOrConstantFP(N1))
17144     return matcher.getNode(ISD::FMA, SDLoc(N), VT, N1, N0, N2);
17145 
17146   bool CanReassociate =
17147       Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();
17148   if (CanReassociate) {
17149     // (fma x, c1, (fmul x, c2)) -> (fmul x, c1+c2)
17150     if (matcher.match(N2, ISD::FMUL) && N0 == N2.getOperand(0) &&
17151         DAG.isConstantFPBuildVectorOrConstantFP(N1) &&
17152         DAG.isConstantFPBuildVectorOrConstantFP(N2.getOperand(1))) {
17153       return matcher.getNode(
17154           ISD::FMUL, DL, VT, N0,
17155           matcher.getNode(ISD::FADD, DL, VT, N1, N2.getOperand(1)));
17156     }
17157 
17158     // (fma (fmul x, c1), c2, y) -> (fma x, c1*c2, y)
17159     if (matcher.match(N0, ISD::FMUL) &&
17160         DAG.isConstantFPBuildVectorOrConstantFP(N1) &&
17161         DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(1))) {
17162       return matcher.getNode(
17163           ISD::FMA, DL, VT, N0.getOperand(0),
17164           matcher.getNode(ISD::FMUL, DL, VT, N1, N0.getOperand(1)), N2);
17165     }
17166   }
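  // Both folds above regroup constant arithmetic ((c1 + c2) resp. (c1 * c2)),
  // which changes intermediate rounding; hence the reassociation gate.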
17167 
17168   // (fma x, 1.0, y) -> (fadd x, y); (fma x, -1.0, y) -> (fadd (fneg x), y)
17169   // FIXME: Support splat of constant.
17170   if (N1CFP) {
17171     if (N1CFP->isExactlyValue(1.0))
17172       return matcher.getNode(ISD::FADD, DL, VT, N0, N2);
17173 
17174     if (N1CFP->isExactlyValue(-1.0) &&
17175         (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))) {
17176       SDValue RHSNeg = matcher.getNode(ISD::FNEG, DL, VT, N0);
17177       AddToWorklist(RHSNeg.getNode());
17178       return matcher.getNode(ISD::FADD, DL, VT, N2, RHSNeg);
17179     }
17180 
17181     // fma (fneg x), K, y -> fma x, -K, y
17182     if (matcher.match(N0, ISD::FNEG) &&
17183         (TLI.isOperationLegal(ISD::ConstantFP, VT) ||
17184          (N1.hasOneUse() &&
17185           !TLI.isFPImmLegal(N1CFP->getValueAPF(), VT, ForCodeSize)))) {
17186       return matcher.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
17187                              matcher.getNode(ISD::FNEG, DL, VT, N1), N2);
17188     }
17189   }
17190 
17191   // FIXME: Support splat of constant.
17192   if (CanReassociate) {
17193     // (fma x, c, x) -> (fmul x, (c+1))
17194     if (N1CFP && N0 == N2) {
17195       return matcher.getNode(ISD::FMUL, DL, VT, N0,
17196                              matcher.getNode(ISD::FADD, DL, VT, N1,
17197                                              DAG.getConstantFP(1.0, DL, VT)));
17198     }
17199 
17200     // (fma x, c, (fneg x)) -> (fmul x, (c-1))
17201     if (N1CFP && matcher.match(N2, ISD::FNEG) && N2.getOperand(0) == N0) {
17202       return matcher.getNode(ISD::FMUL, DL, VT, N0,
17203                              matcher.getNode(ISD::FADD, DL, VT, N1,
17204                                              DAG.getConstantFP(-1.0, DL, VT)));
17205     }
17206   }
17207 
17208   // fold ((fma (fneg X), Y, (fneg Z)) -> fneg (fma X, Y, Z))
17209   // fold ((fma X, (fneg Y), (fneg Z)) -> fneg (fma X, Y, Z))
17210   if (!TLI.isFNegFree(VT))
17211     if (SDValue Neg = TLI.getCheaperNegatedExpression(
17212             SDValue(N, 0), DAG, LegalOperations, ForCodeSize))
17213       return matcher.getNode(ISD::FNEG, DL, VT, Neg);
17214   return SDValue();
17215 }
17216 
17217 SDValue DAGCombiner::visitFMAD(SDNode *N) {
17218   SDValue N0 = N->getOperand(0);
17219   SDValue N1 = N->getOperand(1);
17220   SDValue N2 = N->getOperand(2);
17221   EVT VT = N->getValueType(0);
17222   SDLoc DL(N);
17223 
17224   // Constant fold FMAD.
17225   if (isa<ConstantFPSDNode>(N0) && isa<ConstantFPSDNode>(N1) &&
17226       isa<ConstantFPSDNode>(N2))
17227     return DAG.getNode(ISD::FMAD, DL, VT, N0, N1, N2);
17228 
17229   return SDValue();
17230 }
17231 
17232 // Combine multiple FDIVs with the same divisor into multiple FMULs by the
17233 // reciprocal.
17234 // E.g., (a / D; b / D;) -> (recip = 1.0 / D; a * recip; b * recip)
17235 // Notice that this is not always beneficial. One reason is different targets
17236 // may have different costs for FDIV and FMUL, so sometimes the cost of two
17237 // FDIVs may be lower than the cost of one FDIV and two FMULs. Another reason
17238 // is that the critical path increases from "one FDIV" to "one FDIV + one FMUL".
17239 SDValue DAGCombiner::combineRepeatedFPDivisors(SDNode *N) {
17240   // TODO: Limit this transform based on optsize/minsize - it always creates at
17241   //       least 1 extra instruction. But the perf win may be substantial enough
17242   //       that only minsize should restrict this.
17243   bool UnsafeMath = DAG.getTarget().Options.UnsafeFPMath;
17244   const SDNodeFlags Flags = N->getFlags();
17245   if (LegalDAG || (!UnsafeMath && !Flags.hasAllowReciprocal()))
17246     return SDValue();
17247 
17248   // Skip if current node is a reciprocal/fneg-reciprocal.
17249   SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
17250   ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0, /* AllowUndefs */ true);
17251   if (N0CFP && (N0CFP->isExactlyValue(1.0) || N0CFP->isExactlyValue(-1.0)))
17252     return SDValue();
17253 
17254   // Exit early if the target does not want this transform or if there can't
17255   // possibly be enough uses of the divisor to make the transform worthwhile.
17256   unsigned MinUses = TLI.combineRepeatedFPDivisors();
17257 
17258   // For splat vectors, scale the number of uses by the splat factor. If we can
17259   // convert the division into a scalar op, that will likely be much faster.
17260   unsigned NumElts = 1;
17261   EVT VT = N->getValueType(0);
17262   if (VT.isVector() && DAG.isSplatValue(N1))
17263     NumElts = VT.getVectorMinNumElements();
17264 
17265   if (!MinUses || (N1->use_size() * NumElts) < MinUses)
17266     return SDValue();
17267 
17268   // Find all FDIV users of the same divisor.
17269   // Use a set because duplicates may be present in the user list.
17270   SetVector<SDNode *> Users;
17271   for (auto *U : N1->uses()) {
17272     if (U->getOpcode() == ISD::FDIV && U->getOperand(1) == N1) {
17273       // Skip X/sqrt(X) that has not been simplified to sqrt(X) yet.
17274       if (U->getOperand(1).getOpcode() == ISD::FSQRT &&
17275           U->getOperand(0) == U->getOperand(1).getOperand(0) &&
17276           U->getFlags().hasAllowReassociation() &&
17277           U->getFlags().hasNoSignedZeros())
17278         continue;
17279 
17280       // This division is eligible for optimization only if global unsafe math
17281       // is enabled or if this division allows reciprocal formation.
17282       if (UnsafeMath || U->getFlags().hasAllowReciprocal())
17283         Users.insert(U);
17284     }
17285   }
17286 
17287   // Now that we have the actual number of divisor uses, make sure it meets
17288   // the minimum threshold specified by the target.
17289   if ((Users.size() * NumElts) < MinUses)
17290     return SDValue();
17291 
17292   SDLoc DL(N);
17293   SDValue FPOne = DAG.getConstantFP(1.0, DL, VT);
17294   SDValue Reciprocal = DAG.getNode(ISD::FDIV, DL, VT, FPOne, N1, Flags);
17295 
17296   // Dividend / Divisor -> Dividend * Reciprocal
17297   for (auto *U : Users) {
17298     SDValue Dividend = U->getOperand(0);
17299     if (Dividend != FPOne) {
17300       SDValue NewNode = DAG.getNode(ISD::FMUL, SDLoc(U), VT, Dividend,
17301                                     Reciprocal, Flags);
17302       CombineTo(U, NewNode);
17303     } else if (U != Reciprocal.getNode()) {
17304       // In the absence of fast-math-flags, this user node is always the
17305       // same node as Reciprocal, but with FMF they may be different nodes.
17306       CombineTo(U, Reciprocal);
17307     }
17308   }
17309   return SDValue(N, 0);  // N was replaced.
17310 }
17311 
17312 SDValue DAGCombiner::visitFDIV(SDNode *N) {
17313   SDValue N0 = N->getOperand(0);
17314   SDValue N1 = N->getOperand(1);
17315   EVT VT = N->getValueType(0);
17316   SDLoc DL(N);
17317   const TargetOptions &Options = DAG.getTarget().Options;
17318   SDNodeFlags Flags = N->getFlags();
17319   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
17320 
17321   if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
17322     return R;
17323 
17324   // fold (fdiv c1, c2) -> c1/c2
17325   if (SDValue C = DAG.FoldConstantArithmetic(ISD::FDIV, DL, VT, {N0, N1}))
17326     return C;
17327 
17328   // fold vector ops
17329   if (VT.isVector())
17330     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
17331       return FoldedVOp;
17332 
17333   if (SDValue NewSel = foldBinOpIntoSelect(N))
17334     return NewSel;
17335 
17336   if (SDValue V = combineRepeatedFPDivisors(N))
17337     return V;
17338 
17339   // fold (fdiv X, c2) -> (fmul X, 1/c2) if there is no loss in precision, or
17340   // the loss is acceptable with AllowReciprocal.
17341   if (auto *N1CFP = isConstOrConstSplatFP(N1, true)) {
17342     // Compute the reciprocal 1.0 / c2.
17343     const APFloat &N1APF = N1CFP->getValueAPF();
17344     APFloat Recip = APFloat::getOne(N1APF.getSemantics());
17345     APFloat::opStatus st = Recip.divide(N1APF, APFloat::rmNearestTiesToEven);
17346     // Only do the transform if the reciprocal is a legal fp immediate that
17347     // isn't too nasty (eg NaN, denormal, ...).
17348     if (((st == APFloat::opOK && !Recip.isDenormal()) ||
17349          (st == APFloat::opInexact &&
17350           (Options.UnsafeFPMath || Flags.hasAllowReciprocal()))) &&
17351         (!LegalOperations ||
17352          // FIXME: custom lowering of ConstantFP might fail (see e.g. ARM
17353          // backend)... we should handle this gracefully after Legalize.
17354          // TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT) ||
17355          TLI.isOperationLegal(ISD::ConstantFP, VT) ||
17356          TLI.isFPImmLegal(Recip, VT, ForCodeSize)))
17357       return DAG.getNode(ISD::FMUL, DL, VT, N0,
17358                          DAG.getConstantFP(Recip, DL, VT));
17359   }
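  // For example, 1.0 / 4.0 == 0.25 is exact (opOK), so that fold is always
  // safe; 1.0 / 3.0 is inexact and is used only under AllowReciprocal or
  // unsafe math.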
17360 
17361   if (Options.UnsafeFPMath || Flags.hasAllowReciprocal()) {
17362     // If this FDIV is part of a reciprocal square root, it may be folded
17363     // into a target-specific square root estimate instruction.
17364     if (N1.getOpcode() == ISD::FSQRT) {
17365       if (SDValue RV = buildRsqrtEstimate(N1.getOperand(0), Flags))
17366         return DAG.getNode(ISD::FMUL, DL, VT, N0, RV);
17367     } else if (N1.getOpcode() == ISD::FP_EXTEND &&
17368                N1.getOperand(0).getOpcode() == ISD::FSQRT) {
17369       if (SDValue RV =
17370               buildRsqrtEstimate(N1.getOperand(0).getOperand(0), Flags)) {
17371         RV = DAG.getNode(ISD::FP_EXTEND, SDLoc(N1), VT, RV);
17372         AddToWorklist(RV.getNode());
17373         return DAG.getNode(ISD::FMUL, DL, VT, N0, RV);
17374       }
17375     } else if (N1.getOpcode() == ISD::FP_ROUND &&
17376                N1.getOperand(0).getOpcode() == ISD::FSQRT) {
17377       if (SDValue RV =
17378               buildRsqrtEstimate(N1.getOperand(0).getOperand(0), Flags)) {
17379         RV = DAG.getNode(ISD::FP_ROUND, SDLoc(N1), VT, RV, N1.getOperand(1));
17380         AddToWorklist(RV.getNode());
17381         return DAG.getNode(ISD::FMUL, DL, VT, N0, RV);
17382       }
17383     } else if (N1.getOpcode() == ISD::FMUL) {
17384       // Look through an FMUL. Even though this won't remove the FDIV directly,
17385       // it's still worthwhile to get rid of the FSQRT if possible.
17386       SDValue Sqrt, Y;
17387       if (N1.getOperand(0).getOpcode() == ISD::FSQRT) {
17388         Sqrt = N1.getOperand(0);
17389         Y = N1.getOperand(1);
17390       } else if (N1.getOperand(1).getOpcode() == ISD::FSQRT) {
17391         Sqrt = N1.getOperand(1);
17392         Y = N1.getOperand(0);
17393       }
17394       if (Sqrt.getNode()) {
17395         // If the other multiply operand is known positive, pull it into the
17396         // sqrt. That will eliminate the division if we convert to an estimate.
17397         if (Flags.hasAllowReassociation() && N1.hasOneUse() &&
17398             N1->getFlags().hasAllowReassociation() && Sqrt.hasOneUse()) {
17399           SDValue A;
17400           if (Y.getOpcode() == ISD::FABS && Y.hasOneUse())
17401             A = Y.getOperand(0);
17402           else if (Y == Sqrt.getOperand(0))
17403             A = Y;
17404           if (A) {
17405             // X / (fabs(A) * sqrt(Z)) --> X / sqrt(A*A*Z) --> X * rsqrt(A*A*Z)
17406             // X / (A * sqrt(A))       --> X / sqrt(A*A*A) --> X * rsqrt(A*A*A)
17407             SDValue AA = DAG.getNode(ISD::FMUL, DL, VT, A, A);
17408             SDValue AAZ =
17409                 DAG.getNode(ISD::FMUL, DL, VT, AA, Sqrt.getOperand(0));
17410             if (SDValue Rsqrt = buildRsqrtEstimate(AAZ, Flags))
17411               return DAG.getNode(ISD::FMUL, DL, VT, N0, Rsqrt);
17412 
17413             // Estimate creation failed. Clean up speculatively created nodes.
17414             recursivelyDeleteUnusedNodes(AAZ.getNode());
17415           }
17416         }
17417 
17418         // We found a FSQRT, so try to make this fold:
17419         // X / (Y * sqrt(Z)) -> X * (rsqrt(Z) / Y)
17420         if (SDValue Rsqrt = buildRsqrtEstimate(Sqrt.getOperand(0), Flags)) {
17421           SDValue Div = DAG.getNode(ISD::FDIV, SDLoc(N1), VT, Rsqrt, Y);
17422           AddToWorklist(Div.getNode());
17423           return DAG.getNode(ISD::FMUL, DL, VT, N0, Div);
17424         }
17425       }
17426     }
17427 
17428     // Fold into a reciprocal estimate and multiply instead of a real divide.
17429     if (Options.NoInfsFPMath || Flags.hasNoInfs())
17430       if (SDValue RV = BuildDivEstimate(N0, N1, Flags))
17431         return RV;
17432   }
17433 
17434   // Fold X/Sqrt(X) -> Sqrt(X)
17435   if ((Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros()) &&
17436       (Options.UnsafeFPMath || Flags.hasAllowReassociation()))
17437     if (N1.getOpcode() == ISD::FSQRT && N0 == N1.getOperand(0))
17438       return N1;
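  // Algebraically X / sqrt(X) == sqrt(X), but for X == +0.0 the left side is
  // 0.0 / 0.0 == NaN while sqrt(0.0) == 0.0, hence the fast-math gating above.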
17439 
17440   // (fdiv (fneg X), (fneg Y)) -> (fdiv X, Y)
17441   TargetLowering::NegatibleCost CostN0 =
17442       TargetLowering::NegatibleCost::Expensive;
17443   TargetLowering::NegatibleCost CostN1 =
17444       TargetLowering::NegatibleCost::Expensive;
17445   SDValue NegN0 =
17446       TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize, CostN0);
17447   if (NegN0) {
17448     HandleSDNode NegN0Handle(NegN0);
17449     SDValue NegN1 =
17450         TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize, CostN1);
17451     if (NegN1 && (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
17452                   CostN1 == TargetLowering::NegatibleCost::Cheaper))
17453       return DAG.getNode(ISD::FDIV, SDLoc(N), VT, NegN0, NegN1);
17454   }
17455 
17456   if (SDValue R = combineFMulOrFDivWithIntPow2(N))
17457     return R;
17458 
17459   return SDValue();
17460 }
17461 
17462 SDValue DAGCombiner::visitFREM(SDNode *N) {
17463   SDValue N0 = N->getOperand(0);
17464   SDValue N1 = N->getOperand(1);
17465   EVT VT = N->getValueType(0);
17466   SDNodeFlags Flags = N->getFlags();
17467   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
17468   SDLoc DL(N);
17469 
17470   if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
17471     return R;
17472 
17473   // fold (frem c1, c2) -> fmod(c1,c2)
17474   if (SDValue C = DAG.FoldConstantArithmetic(ISD::FREM, DL, VT, {N0, N1}))
17475     return C;
17476 
17477   if (SDValue NewSel = foldBinOpIntoSelect(N))
17478     return NewSel;
17479 
17480   // Lower frem N0, N1 => N0 - trunc(N0 / N1) * N1, provided N1 is an integer
17481   // power of 2.
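  // Worked example (illustrative): frem(7.5, 2.0): div == 3.75, trunc == 3.0,
  // and 7.5 - 3.0 * 2.0 == 1.5 == fmod(7.5, 2.0).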
17482   if (!TLI.isOperationLegal(ISD::FREM, VT) &&
17483       TLI.isOperationLegalOrCustom(ISD::FMUL, VT) &&
17484       TLI.isOperationLegalOrCustom(ISD::FDIV, VT) &&
17485       TLI.isOperationLegalOrCustom(ISD::FTRUNC, VT) &&
17486       DAG.isKnownToBeAPowerOfTwoFP(N1)) {
17487     bool NeedsCopySign =
17488         !Flags.hasNoSignedZeros() && !DAG.cannotBeOrderedNegativeFP(N0);
17489     SDValue Div = DAG.getNode(ISD::FDIV, DL, VT, N0, N1);
17490     SDValue Rnd = DAG.getNode(ISD::FTRUNC, DL, VT, Div);
17491     SDValue MLA;
17492     if (TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT)) {
17493       MLA = DAG.getNode(ISD::FMA, DL, VT, DAG.getNode(ISD::FNEG, DL, VT, Rnd),
17494                         N1, N0);
17495     } else {
17496       SDValue Mul = DAG.getNode(ISD::FMUL, DL, VT, Rnd, N1);
17497       MLA = DAG.getNode(ISD::FSUB, DL, VT, N0, Mul);
17498     }
17499     return NeedsCopySign ? DAG.getNode(ISD::FCOPYSIGN, DL, VT, MLA, N0) : MLA;
17500   }
17501 
17502   return SDValue();
17503 }
17504 
17505 SDValue DAGCombiner::visitFSQRT(SDNode *N) {
17506   SDNodeFlags Flags = N->getFlags();
17507   const TargetOptions &Options = DAG.getTarget().Options;
17508 
17509   // Require 'ninf' flag since sqrt(+Inf) = +Inf, but the estimation goes as:
17510   // sqrt(+Inf) == rsqrt(+Inf) * +Inf = 0 * +Inf = NaN
17511   if (!Flags.hasApproximateFuncs() ||
17512       (!Options.NoInfsFPMath && !Flags.hasNoInfs()))
17513     return SDValue();
17514 
17515   SDValue N0 = N->getOperand(0);
17516   if (TLI.isFsqrtCheap(N0, DAG))
17517     return SDValue();
17518 
17519   // FSQRT nodes have flags that propagate to the created nodes.
17520   // TODO: If this is N0/sqrt(N0), and we reach this node before trying to
17521   //       transform the fdiv, we may produce a sub-optimal estimate sequence
17522   //       because the reciprocal calculation may not have to filter out a
17523   //       0.0 input.
17524   return buildSqrtEstimate(N0, Flags);
17525 }
17526 
17527 /// copysign(x, fp_extend(y)) -> copysign(x, y)
17528 /// copysign(x, fp_round(y)) -> copysign(x, y)
17529 /// Operands to the functions are the type of X and Y respectively.
17530 static inline bool CanCombineFCOPYSIGN_EXTEND_ROUND(EVT XTy, EVT YTy) {
17531   // Always fold no-op FP casts.
17532   if (XTy == YTy)
17533     return true;
17534 
17535   // Do not optimize out type conversion of f128 type yet.
17536   // For some targets like x86_64, configuration is changed to keep one f128
17537   // value in one SSE register, but instruction selection cannot handle
17538   // FCOPYSIGN on SSE registers yet.
17539   if (YTy == MVT::f128)
17540     return false;
17541 
17542   return !YTy.isVector() || EnableVectorFCopySignExtendRound;
17543 }
17544 
17545 static inline bool CanCombineFCOPYSIGN_EXTEND_ROUND(SDNode *N) {
17546   SDValue N1 = N->getOperand(1);
17547   if (N1.getOpcode() != ISD::FP_EXTEND &&
17548       N1.getOpcode() != ISD::FP_ROUND)
17549     return false;
17550   EVT N1VT = N1->getValueType(0);
17551   EVT N1Op0VT = N1->getOperand(0).getValueType();
17552   return CanCombineFCOPYSIGN_EXTEND_ROUND(N1VT, N1Op0VT);
17553 }
17554 
17555 SDValue DAGCombiner::visitFCOPYSIGN(SDNode *N) {
17556   SDValue N0 = N->getOperand(0);
17557   SDValue N1 = N->getOperand(1);
17558   EVT VT = N->getValueType(0);
17559   SDLoc DL(N);
17560 
17561   // fold (fcopysign c1, c2) -> fcopysign(c1,c2)
17562   if (SDValue C = DAG.FoldConstantArithmetic(ISD::FCOPYSIGN, DL, VT, {N0, N1}))
17563     return C;
17564 
17565   if (ConstantFPSDNode *N1C = isConstOrConstSplatFP(N->getOperand(1))) {
17566     const APFloat &V = N1C->getValueAPF();
17567     // copysign(x, c1) -> fabs(x)       iff ispos(c1)
17568     // copysign(x, c1) -> fneg(fabs(x)) iff isneg(c1)
17569     if (!V.isNegative()) {
17570       if (!LegalOperations || TLI.isOperationLegal(ISD::FABS, VT))
17571         return DAG.getNode(ISD::FABS, DL, VT, N0);
17572     } else {
17573       if (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))
17574         return DAG.getNode(ISD::FNEG, DL, VT,
17575                            DAG.getNode(ISD::FABS, SDLoc(N0), VT, N0));
17576     }
17577   }
17578 
17579   // copysign(fabs(x), y) -> copysign(x, y)
17580   // copysign(fneg(x), y) -> copysign(x, y)
17581   // copysign(copysign(x,z), y) -> copysign(x, y)
17582   if (N0.getOpcode() == ISD::FABS || N0.getOpcode() == ISD::FNEG ||
17583       N0.getOpcode() == ISD::FCOPYSIGN)
17584     return DAG.getNode(ISD::FCOPYSIGN, DL, VT, N0.getOperand(0), N1);
17585 
17586   // copysign(x, abs(y)) -> abs(x)
17587   if (N1.getOpcode() == ISD::FABS)
17588     return DAG.getNode(ISD::FABS, DL, VT, N0);
17589 
17590   // copysign(x, copysign(y,z)) -> copysign(x, z)
17591   if (N1.getOpcode() == ISD::FCOPYSIGN)
17592     return DAG.getNode(ISD::FCOPYSIGN, DL, VT, N0, N1.getOperand(1));
17593 
17594   // copysign(x, fp_extend(y)) -> copysign(x, y)
17595   // copysign(x, fp_round(y)) -> copysign(x, y)
17596   if (CanCombineFCOPYSIGN_EXTEND_ROUND(N))
17597     return DAG.getNode(ISD::FCOPYSIGN, DL, VT, N0, N1.getOperand(0));
17598 
17599   // We only take the sign bit from the sign operand.
17600   EVT SignVT = N1.getValueType();
17601   if (SimplifyDemandedBits(N1,
17602                            APInt::getSignMask(SignVT.getScalarSizeInBits())))
17603     return SDValue(N, 0);
17604 
17605   // We only take the non-sign bits from the value operand
17606   if (SimplifyDemandedBits(N0,
17607                            APInt::getSignedMaxValue(VT.getScalarSizeInBits())))
17608     return SDValue(N, 0);
17609 
17610   return SDValue();
17611 }
17612 
17613 SDValue DAGCombiner::visitFPOW(SDNode *N) {
17614   ConstantFPSDNode *ExponentC = isConstOrConstSplatFP(N->getOperand(1));
17615   if (!ExponentC)
17616     return SDValue();
17617   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
17618 
17619   // Try to convert x ** (1/3) into cube root.
17620   // TODO: Handle the various flavors of long double.
17621   // TODO: Since we're approximating, we don't need an exact 1/3 exponent.
17622   //       Some range near 1/3 should be fine.
17623   EVT VT = N->getValueType(0);
17624   if ((VT == MVT::f32 && ExponentC->getValueAPF().isExactlyValue(1.0f/3.0f)) ||
17625       (VT == MVT::f64 && ExponentC->getValueAPF().isExactlyValue(1.0/3.0))) {
17626     // pow(-0.0, 1/3) = +0.0; cbrt(-0.0) = -0.0.
17627     // pow(-inf, 1/3) = +inf; cbrt(-inf) = -inf.
17628     // pow(-val, 1/3) =  nan; cbrt(-val) = -cbrt(val).
17629     // For regular numbers, rounding may cause the results to differ.
17630     // Therefore, we require { nsz ninf nnan afn } for this transform.
17631     // TODO: We could select out the special cases if we don't have nsz/ninf.
17632     SDNodeFlags Flags = N->getFlags();
17633     if (!Flags.hasNoSignedZeros() || !Flags.hasNoInfs() || !Flags.hasNoNaNs() ||
17634         !Flags.hasApproximateFuncs())
17635       return SDValue();
17636 
17637     // Do not create a cbrt() libcall if the target does not have it, and do not
17638     // turn a pow that has lowering support into a cbrt() libcall.
17639     if (!DAG.getLibInfo().has(LibFunc_cbrt) ||
17640         (!DAG.getTargetLoweringInfo().isOperationExpand(ISD::FPOW, VT) &&
17641          DAG.getTargetLoweringInfo().isOperationExpand(ISD::FCBRT, VT)))
17642       return SDValue();
17643 
17644     return DAG.getNode(ISD::FCBRT, SDLoc(N), VT, N->getOperand(0));
17645   }
17646 
17647   // Try to convert x ** (1/4) and x ** (3/4) into square roots.
17648   // x ** (1/2) is canonicalized to sqrt, so we do not bother with that case.
17649   // TODO: This could be extended (using a target hook) to handle smaller
17650   // power-of-2 fractional exponents.
17651   bool ExponentIs025 = ExponentC->getValueAPF().isExactlyValue(0.25);
17652   bool ExponentIs075 = ExponentC->getValueAPF().isExactlyValue(0.75);
17653   if (ExponentIs025 || ExponentIs075) {
17654     // pow(-0.0, 0.25) = +0.0; sqrt(sqrt(-0.0)) = -0.0.
17655     // pow(-inf, 0.25) = +inf; sqrt(sqrt(-inf)) =  NaN.
17656     // pow(-0.0, 0.75) = +0.0; sqrt(-0.0) * sqrt(sqrt(-0.0)) = +0.0.
17657     // pow(-inf, 0.75) = +inf; sqrt(-inf) * sqrt(sqrt(-inf)) =  NaN.
17658     // For regular numbers, rounding may cause the results to differ.
17659     // Therefore, we require { nsz ninf afn } for this transform.
17660     // TODO: We could select out the special cases if we don't have nsz/ninf.
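    // Quick sanity check of the decomposition (illustrative):
    //   pow(16.0, 0.75) == 8.0 == sqrt(16.0) * sqrt(sqrt(16.0)) == 4.0 * 2.0.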
17661     SDNodeFlags Flags = N->getFlags();
17662 
17663     // We only need no signed zeros for the 0.25 case.
17664     if ((!Flags.hasNoSignedZeros() && ExponentIs025) || !Flags.hasNoInfs() ||
17665         !Flags.hasApproximateFuncs())
17666       return SDValue();
17667 
17668     // Don't double the number of libcalls. We are trying to inline fast code.
17669     if (!DAG.getTargetLoweringInfo().isOperationLegalOrCustom(ISD::FSQRT, VT))
17670       return SDValue();
17671 
17672     // Assume that libcalls are the smallest code.
17673     // TODO: This restriction should probably be lifted for vectors.
17674     if (ForCodeSize)
17675       return SDValue();
17676 
17677     // pow(X, 0.25) --> sqrt(sqrt(X))
17678     SDLoc DL(N);
17679     SDValue Sqrt = DAG.getNode(ISD::FSQRT, DL, VT, N->getOperand(0));
17680     SDValue SqrtSqrt = DAG.getNode(ISD::FSQRT, DL, VT, Sqrt);
17681     if (ExponentIs025)
17682       return SqrtSqrt;
17683     // pow(X, 0.75) --> sqrt(X) * sqrt(sqrt(X))
17684     return DAG.getNode(ISD::FMUL, DL, VT, Sqrt, SqrtSqrt);
17685   }
17686 
17687   return SDValue();
17688 }
17689 
17690 static SDValue foldFPToIntToFP(SDNode *N, SelectionDAG &DAG,
17691                                const TargetLowering &TLI) {
17692   // We only do this if the target has legal ftrunc. Otherwise, we'd likely be
17693   // replacing casts with a libcall. We also must be allowed to ignore -0.0
17694   // because FTRUNC will return -0.0 for (-1.0, -0.0), but using integer
17695   // conversions would return +0.0.
17696   // FIXME: We should be able to use node-level FMF here.
17697   // TODO: If strict math, should we use FABS (+ range check for signed cast)?
17698   EVT VT = N->getValueType(0);
17699   if (!TLI.isOperationLegal(ISD::FTRUNC, VT) ||
17700       !DAG.getTarget().Options.NoSignedZerosFPMath)
17701     return SDValue();
17702 
17703   // fptosi/fptoui round towards zero, so converting from FP to integer and
17704   // back is the same as an 'ftrunc': [us]itofp (fpto[us]i X) --> ftrunc X
17705   SDValue N0 = N->getOperand(0);
17706   if (N->getOpcode() == ISD::SINT_TO_FP && N0.getOpcode() == ISD::FP_TO_SINT &&
17707       N0.getOperand(0).getValueType() == VT)
17708     return DAG.getNode(ISD::FTRUNC, SDLoc(N), VT, N0.getOperand(0));
17709 
17710   if (N->getOpcode() == ISD::UINT_TO_FP && N0.getOpcode() == ISD::FP_TO_UINT &&
17711       N0.getOperand(0).getValueType() == VT)
17712     return DAG.getNode(ISD::FTRUNC, SDLoc(N), VT, N0.getOperand(0));
17713 
17714   return SDValue();
17715 }
17716 
17717 SDValue DAGCombiner::visitSINT_TO_FP(SDNode *N) {
17718   SDValue N0 = N->getOperand(0);
17719   EVT VT = N->getValueType(0);
17720   EVT OpVT = N0.getValueType();
17721 
17722   // [us]itofp(undef) = 0, because the result value is bounded.
17723   if (N0.isUndef())
17724     return DAG.getConstantFP(0.0, SDLoc(N), VT);
17725 
17726   // fold (sint_to_fp c1) -> c1fp
17727   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
17728       // ...but only if the target supports immediate floating-point values
17729       (!LegalOperations ||
17730        TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT)))
17731     return DAG.getNode(ISD::SINT_TO_FP, SDLoc(N), VT, N0);
17732 
17733   // If the input is a legal type, and SINT_TO_FP is not legal on this target,
17734   // but UINT_TO_FP is legal on this target, try to convert.
17735   if (!hasOperation(ISD::SINT_TO_FP, OpVT) &&
17736       hasOperation(ISD::UINT_TO_FP, OpVT)) {
17737     // If the sign bit is known to be zero, we can change this to UINT_TO_FP.
17738     if (DAG.SignBitIsZero(N0))
17739       return DAG.getNode(ISD::UINT_TO_FP, SDLoc(N), VT, N0);
17740   }
17741 
17742   // The next optimizations are desirable only if SELECT_CC can be lowered.
17743   // fold (sint_to_fp (setcc x, y, cc)) -> (select (setcc x, y, cc), -1.0, 0.0)
17744   if (N0.getOpcode() == ISD::SETCC && N0.getValueType() == MVT::i1 &&
17745       !VT.isVector() &&
17746       (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) {
17747     SDLoc DL(N);
17748     return DAG.getSelect(DL, VT, N0, DAG.getConstantFP(-1.0, DL, VT),
17749                          DAG.getConstantFP(0.0, DL, VT));
17750   }
17751 
17752   // fold (sint_to_fp (zext (setcc x, y, cc))) ->
17753   //      (select (setcc x, y, cc), 1.0, 0.0)
17754   if (N0.getOpcode() == ISD::ZERO_EXTEND &&
17755       N0.getOperand(0).getOpcode() == ISD::SETCC && !VT.isVector() &&
17756       (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) {
17757     SDLoc DL(N);
17758     return DAG.getSelect(DL, VT, N0.getOperand(0),
17759                          DAG.getConstantFP(1.0, DL, VT),
17760                          DAG.getConstantFP(0.0, DL, VT));
17761   }
17762 
17763   if (SDValue FTrunc = foldFPToIntToFP(N, DAG, TLI))
17764     return FTrunc;
17765 
17766   return SDValue();
17767 }
17768 
17769 SDValue DAGCombiner::visitUINT_TO_FP(SDNode *N) {
17770   SDValue N0 = N->getOperand(0);
17771   EVT VT = N->getValueType(0);
17772   EVT OpVT = N0.getValueType();
17773 
17774   // [us]itofp(undef) = 0, because the result value is bounded.
17775   if (N0.isUndef())
17776     return DAG.getConstantFP(0.0, SDLoc(N), VT);
17777 
17778   // fold (uint_to_fp c1) -> c1fp
17779   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
17780       // ...but only if the target supports immediate floating-point values
17781       (!LegalOperations ||
17782        TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT)))
17783     return DAG.getNode(ISD::UINT_TO_FP, SDLoc(N), VT, N0);
17784 
17785   // If the input is a legal type, and UINT_TO_FP is not legal on this target,
17786   // but SINT_TO_FP is legal on this target, try to convert.
17787   if (!hasOperation(ISD::UINT_TO_FP, OpVT) &&
17788       hasOperation(ISD::SINT_TO_FP, OpVT)) {
17789     // If the sign bit is known to be zero, we can change this to SINT_TO_FP.
17790     if (DAG.SignBitIsZero(N0))
17791       return DAG.getNode(ISD::SINT_TO_FP, SDLoc(N), VT, N0);
17792   }
17793 
17794   // fold (uint_to_fp (setcc x, y, cc)) -> (select (setcc x, y, cc), 1.0, 0.0)
17795   if (N0.getOpcode() == ISD::SETCC && !VT.isVector() &&
17796       (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) {
17797     SDLoc DL(N);
17798     return DAG.getSelect(DL, VT, N0, DAG.getConstantFP(1.0, DL, VT),
17799                          DAG.getConstantFP(0.0, DL, VT));
17800   }
17801 
17802   if (SDValue FTrunc = foldFPToIntToFP(N, DAG, TLI))
17803     return FTrunc;
17804 
17805   return SDValue();
17806 }
17807 
17808 // Fold (fp_to_{s/u}int ({s/u}int_to_fp x)) -> zext x, sext x, trunc x, or x
17809 static SDValue FoldIntToFPToInt(SDNode *N, SelectionDAG &DAG) {
17810   SDValue N0 = N->getOperand(0);
17811   EVT VT = N->getValueType(0);
17812 
17813   if (N0.getOpcode() != ISD::UINT_TO_FP && N0.getOpcode() != ISD::SINT_TO_FP)
17814     return SDValue();
17815 
17816   SDValue Src = N0.getOperand(0);
17817   EVT SrcVT = Src.getValueType();
17818   bool IsInputSigned = N0.getOpcode() == ISD::SINT_TO_FP;
17819   bool IsOutputSigned = N->getOpcode() == ISD::FP_TO_SINT;
17820 
17821   // We can safely assume the conversion won't overflow the output range,
17822   // because (for example) (uint8_t)18293.f is undefined behavior.
17823 
17824   // Since we can assume the conversion won't overflow, our decision as to
17825   // whether the input will fit in the float should depend on the minimum
17826   // of the input range and output range.
17827 
17828   // This means this is also safe for a signed input and unsigned output, since
17829   // a negative input would lead to undefined behavior.
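  // Worked example (illustrative): for signed i16 -> f32 -> i32, the input
  // provides at most 15 magnitude bits, f32 carries 24 bits of precision, and
  // 15 <= 24, so the round trip is exact and folds to a sign_extend of the
  // original i16 value.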
17830   unsigned InputSize = (int)SrcVT.getScalarSizeInBits() - IsInputSigned;
17831   unsigned OutputSize = (int)VT.getScalarSizeInBits();
17832   unsigned ActualSize = std::min(InputSize, OutputSize);
17833   const fltSemantics &sem = DAG.EVTToAPFloatSemantics(N0.getValueType());
17834 
17835   // We can only fold away the float conversion if the input range can be
17836   // represented exactly in the float range.
17837   if (APFloat::semanticsPrecision(sem) >= ActualSize) {
17838     if (VT.getScalarSizeInBits() > SrcVT.getScalarSizeInBits()) {
17839       unsigned ExtOp = IsInputSigned && IsOutputSigned ? ISD::SIGN_EXTEND
17840                                                        : ISD::ZERO_EXTEND;
17841       return DAG.getNode(ExtOp, SDLoc(N), VT, Src);
17842     }
17843     if (VT.getScalarSizeInBits() < SrcVT.getScalarSizeInBits())
17844       return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, Src);
17845     return DAG.getBitcast(VT, Src);
17846   }
17847   return SDValue();
17848 }
17849 
17850 SDValue DAGCombiner::visitFP_TO_SINT(SDNode *N) {
17851   SDValue N0 = N->getOperand(0);
17852   EVT VT = N->getValueType(0);
17853 
17854   // fold (fp_to_sint undef) -> undef
17855   if (N0.isUndef())
17856     return DAG.getUNDEF(VT);
17857 
17858   // fold (fp_to_sint c1fp) -> c1
17859   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
17860     return DAG.getNode(ISD::FP_TO_SINT, SDLoc(N), VT, N0);
17861 
17862   return FoldIntToFPToInt(N, DAG);
17863 }
17864 
17865 SDValue DAGCombiner::visitFP_TO_UINT(SDNode *N) {
17866   SDValue N0 = N->getOperand(0);
17867   EVT VT = N->getValueType(0);
17868 
17869   // fold (fp_to_uint undef) -> undef
17870   if (N0.isUndef())
17871     return DAG.getUNDEF(VT);
17872 
17873   // fold (fp_to_uint c1fp) -> c1
17874   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
17875     return DAG.getNode(ISD::FP_TO_UINT, SDLoc(N), VT, N0);
17876 
17877   return FoldIntToFPToInt(N, DAG);
17878 }
17879 
17880 SDValue DAGCombiner::visitXRINT(SDNode *N) {
17881   SDValue N0 = N->getOperand(0);
17882   EVT VT = N->getValueType(0);
17883 
17884   // fold (lrint|llrint undef) -> undef
17885   if (N0.isUndef())
17886     return DAG.getUNDEF(VT);
17887 
17888   // fold (lrint|llrint c1fp) -> c1
17889   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
17890     return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N0);
17891 
17892   return SDValue();
17893 }
17894 
17895 SDValue DAGCombiner::visitFP_ROUND(SDNode *N) {
17896   SDValue N0 = N->getOperand(0);
17897   SDValue N1 = N->getOperand(1);
17898   EVT VT = N->getValueType(0);
17899 
17900   // fold (fp_round c1fp) -> c1fp
17901   if (SDValue C =
17902           DAG.FoldConstantArithmetic(ISD::FP_ROUND, SDLoc(N), VT, {N0, N1}))
17903     return C;
17904 
17905   // fold (fp_round (fp_extend x)) -> x
17906   if (N0.getOpcode() == ISD::FP_EXTEND && VT == N0.getOperand(0).getValueType())
17907     return N0.getOperand(0);
17908 
17909   // fold (fp_round (fp_round x)) -> (fp_round x)
17910   if (N0.getOpcode() == ISD::FP_ROUND) {
17911     const bool NIsTrunc = N->getConstantOperandVal(1) == 1;
17912     const bool N0IsTrunc = N0.getConstantOperandVal(1) == 1;
17913 
17914     // Avoid folding legal fp_rounds into non-legal ones.
17915     if (!hasOperation(ISD::FP_ROUND, VT))
17916       return SDValue();
17917 
17918     // Skip this folding if it results in an fp_round from f80 to f16.
17919     //
17920     // f80 to f16 always generates an expensive (and as yet, unimplemented)
17921     // libcall to __truncxfhf2 instead of selecting native f16 conversion
17922     // instructions from f32 or f64.  Moreover, the first (value-preserving)
17923     // fp_round from f80 to either f32 or f64 may become a NOP in platforms like
17924     // x86.
17925     if (N0.getOperand(0).getValueType() == MVT::f80 && VT == MVT::f16)
17926       return SDValue();
17927 
17928     // If the first fp_round isn't a value preserving truncation, it might
17929     // introduce a tie in the second fp_round, that wouldn't occur in the
17930     // single-step fp_round we want to fold to.
17931     // In other words, double rounding isn't the same as rounding.
17932     // Also, this is a value preserving truncation iff both fp_round's are.
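    // Illustrative double-rounding case: for x = 1 + 2^-11 + 2^-25 (exact in
    // f64), rounding f64->f32 yields 1 + 2^-11, which then ties to 1.0 in
    // f16, whereas a single f64->f16 rounding yields 1 + 2^-10.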
17933     if (DAG.getTarget().Options.UnsafeFPMath || N0IsTrunc) {
17934       SDLoc DL(N);
17935       return DAG.getNode(
17936           ISD::FP_ROUND, DL, VT, N0.getOperand(0),
17937           DAG.getIntPtrConstant(NIsTrunc && N0IsTrunc, DL, /*isTarget=*/true));
17938     }
17939   }
17940 
17941   // fold (fp_round (copysign X, Y)) -> (copysign (fp_round X), Y)
17942   // Note: From a legality perspective, this is a two step transform.  First,
17943   // we duplicate the fp_round to the arguments of the copysign, then we
17944   // eliminate the fp_round on Y.  The second step requires an additional
17945   // predicate to match the implementation above.
17946   if (N0.getOpcode() == ISD::FCOPYSIGN && N0->hasOneUse() &&
17947       CanCombineFCOPYSIGN_EXTEND_ROUND(VT,
17948                                        N0.getValueType())) {
17949     SDValue Tmp = DAG.getNode(ISD::FP_ROUND, SDLoc(N0), VT,
17950                               N0.getOperand(0), N1);
17951     AddToWorklist(Tmp.getNode());
17952     return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT,
17953                        Tmp, N0.getOperand(1));
17954   }
17955 
17956   if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
17957     return NewVSel;
17958 
17959   return SDValue();
17960 }
17961 
17962 SDValue DAGCombiner::visitFP_EXTEND(SDNode *N) {
17963   SDValue N0 = N->getOperand(0);
17964   EVT VT = N->getValueType(0);
17965 
17966   if (VT.isVector())
17967     if (SDValue FoldedVOp = SimplifyVCastOp(N, SDLoc(N)))
17968       return FoldedVOp;
17969 
17970   // If this is fp_round(fp_extend), don't fold it; allow ourselves to be folded.
17971   if (N->hasOneUse() &&
17972       N->use_begin()->getOpcode() == ISD::FP_ROUND)
17973     return SDValue();
17974 
17975   // fold (fp_extend c1fp) -> c1fp
17976   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
17977     return DAG.getNode(ISD::FP_EXTEND, SDLoc(N), VT, N0);
17978 
17979   // fold (fp_extend (fp16_to_fp op)) -> (fp16_to_fp op)
17980   if (N0.getOpcode() == ISD::FP16_TO_FP &&
17981       TLI.getOperationAction(ISD::FP16_TO_FP, VT) == TargetLowering::Legal)
17982     return DAG.getNode(ISD::FP16_TO_FP, SDLoc(N), VT, N0.getOperand(0));
17983 
17984   // Turn fp_extend(fp_round(X, 1)) -> x since the fp_round doesn't affect the
17985   // value of X.
17986   if (N0.getOpcode() == ISD::FP_ROUND
17987       && N0.getConstantOperandVal(1) == 1) {
17988     SDValue In = N0.getOperand(0);
17989     if (In.getValueType() == VT) return In;
17990     if (VT.bitsLT(In.getValueType()))
17991       return DAG.getNode(ISD::FP_ROUND, SDLoc(N), VT,
17992                          In, N0.getOperand(1));
17993     return DAG.getNode(ISD::FP_EXTEND, SDLoc(N), VT, In);
17994   }
17995 
17996   // fold (fpext (load x)) -> (fpext (fptrunc (extload x)))
17997   if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
17998       TLI.isLoadExtLegalOrCustom(ISD::EXTLOAD, VT, N0.getValueType())) {
17999     LoadSDNode *LN0 = cast<LoadSDNode>(N0);
18000     SDValue ExtLoad = DAG.getExtLoad(ISD::EXTLOAD, SDLoc(N), VT,
18001                                      LN0->getChain(),
18002                                      LN0->getBasePtr(), N0.getValueType(),
18003                                      LN0->getMemOperand());
18004     CombineTo(N, ExtLoad);
18005     CombineTo(
18006         N0.getNode(),
18007         DAG.getNode(ISD::FP_ROUND, SDLoc(N0), N0.getValueType(), ExtLoad,
18008                     DAG.getIntPtrConstant(1, SDLoc(N0), /*isTarget=*/true)),
18009         ExtLoad.getValue(1));
18010     return SDValue(N, 0);   // Return N so it doesn't get rechecked!
18011   }
18012 
18013   if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
18014     return NewVSel;
18015 
18016   return SDValue();
18017 }
18018 
18019 SDValue DAGCombiner::visitFCEIL(SDNode *N) {
18020   SDValue N0 = N->getOperand(0);
18021   EVT VT = N->getValueType(0);
18022 
18023   // fold (fceil c1) -> fceil(c1)
18024   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
18025     return DAG.getNode(ISD::FCEIL, SDLoc(N), VT, N0);
18026 
18027   return SDValue();
18028 }
18029 
18030 SDValue DAGCombiner::visitFTRUNC(SDNode *N) {
18031   SDValue N0 = N->getOperand(0);
18032   EVT VT = N->getValueType(0);
18033 
18034   // fold (ftrunc c1) -> ftrunc(c1)
18035   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
18036     return DAG.getNode(ISD::FTRUNC, SDLoc(N), VT, N0);
18037 
18038   // fold ftrunc (known rounded int x) -> x
18039   // ftrunc is a part of fptosi/fptoui expansion on some targets, so this is
18040   // likely to be generated to extract integer from a rounded floating value.
18041   switch (N0.getOpcode()) {
18042   default: break;
18043   case ISD::FRINT:
18044   case ISD::FTRUNC:
18045   case ISD::FNEARBYINT:
18046   case ISD::FROUNDEVEN:
18047   case ISD::FFLOOR:
18048   case ISD::FCEIL:
18049     return N0;
18050   }
18051 
18052   return SDValue();
18053 }
18054 
18055 SDValue DAGCombiner::visitFFREXP(SDNode *N) {
18056   SDValue N0 = N->getOperand(0);
18057 
18058   // fold (ffrexp c1) -> ffrexp(c1)
18059   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
18060     return DAG.getNode(ISD::FFREXP, SDLoc(N), N->getVTList(), N0);
18061   return SDValue();
18062 }
18063 
18064 SDValue DAGCombiner::visitFFLOOR(SDNode *N) {
18065   SDValue N0 = N->getOperand(0);
18066   EVT VT = N->getValueType(0);
18067 
18068   // fold (ffloor c1) -> ffloor(c1)
18069   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
18070     return DAG.getNode(ISD::FFLOOR, SDLoc(N), VT, N0);
18071 
18072   return SDValue();
18073 }
18074 
18075 SDValue DAGCombiner::visitFNEG(SDNode *N) {
18076   SDValue N0 = N->getOperand(0);
18077   EVT VT = N->getValueType(0);
18078   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
18079 
18080   // Constant fold FNEG.
18081   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
18082     return DAG.getNode(ISD::FNEG, SDLoc(N), VT, N0);
18083 
18084   if (SDValue NegN0 =
18085           TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize))
18086     return NegN0;
18087 
18088   // -(X-Y) -> (Y-X) is unsafe because when X==Y, -0.0 != +0.0
18089   // FIXME: This is duplicated in getNegatibleCost, but getNegatibleCost doesn't
18090   // know it was called from a context with a nsz flag if the input fsub does
18091   // not.
18092   if (N0.getOpcode() == ISD::FSUB &&
18093       (DAG.getTarget().Options.NoSignedZerosFPMath ||
18094        N->getFlags().hasNoSignedZeros()) && N0.hasOneUse()) {
18095     return DAG.getNode(ISD::FSUB, SDLoc(N), VT, N0.getOperand(1),
18096                        N0.getOperand(0));
18097   }
18098 
18099   if (SDValue Cast = foldSignChangeInBitcast(N))
18100     return Cast;
18101 
18102   return SDValue();
18103 }
18104 
18105 SDValue DAGCombiner::visitFMinMax(SDNode *N) {
18106   SDValue N0 = N->getOperand(0);
18107   SDValue N1 = N->getOperand(1);
18108   EVT VT = N->getValueType(0);
18109   const SDNodeFlags Flags = N->getFlags();
18110   unsigned Opc = N->getOpcode();
18111   bool PropagatesNaN = Opc == ISD::FMINIMUM || Opc == ISD::FMAXIMUM;
18112   bool IsMin = Opc == ISD::FMINNUM || Opc == ISD::FMINIMUM;
18113   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
18114 
18115   // Constant fold.
18116   if (SDValue C = DAG.FoldConstantArithmetic(Opc, SDLoc(N), VT, {N0, N1}))
18117     return C;
18118 
18119   // Canonicalize to constant on RHS.
18120   if (DAG.isConstantFPBuildVectorOrConstantFP(N0) &&
18121       !DAG.isConstantFPBuildVectorOrConstantFP(N1))
18122     return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N1, N0);
18123 
18124   if (const ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1)) {
18125     const APFloat &AF = N1CFP->getValueAPF();
18126 
18127     // minnum(X, nan) -> X
18128     // maxnum(X, nan) -> X
18129     // minimum(X, nan) -> nan
18130     // maximum(X, nan) -> nan
18131     if (AF.isNaN())
18132       return PropagatesNaN ? N->getOperand(1) : N->getOperand(0);
18133 
18134     // In the following folds, inf can be replaced with the largest finite
18135     // float, if the ninf flag is set.
18136     if (AF.isInfinity() || (Flags.hasNoInfs() && AF.isLargest())) {
18137       // minnum(X, -inf) -> -inf
18138       // maxnum(X, +inf) -> +inf
18139       // minimum(X, -inf) -> -inf if nnan
18140       // maximum(X, +inf) -> +inf if nnan
18141       if (IsMin == AF.isNegative() && (!PropagatesNaN || Flags.hasNoNaNs()))
18142         return N->getOperand(1);
18143 
18144       // minnum(X, +inf) -> X if nnan
18145       // maxnum(X, -inf) -> X if nnan
18146       // minimum(X, +inf) -> X
18147       // maximum(X, -inf) -> X
18148       if (IsMin != AF.isNegative() && (PropagatesNaN || Flags.hasNoNaNs()))
18149         return N->getOperand(0);
18150     }
18151   }
18152 
18153   if (SDValue SD = reassociateReduction(
18154           PropagatesNaN
18155               ? (IsMin ? ISD::VECREDUCE_FMINIMUM : ISD::VECREDUCE_FMAXIMUM)
18156               : (IsMin ? ISD::VECREDUCE_FMIN : ISD::VECREDUCE_FMAX),
18157           Opc, SDLoc(N), VT, N0, N1, Flags))
18158     return SD;
18159 
18160   return SDValue();
18161 }
18162 
18163 SDValue DAGCombiner::visitFABS(SDNode *N) {
18164   SDValue N0 = N->getOperand(0);
18165   EVT VT = N->getValueType(0);
18166 
18167   // fold (fabs c1) -> fabs(c1)
18168   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
18169     return DAG.getNode(ISD::FABS, SDLoc(N), VT, N0);
18170 
18171   // fold (fabs (fabs x)) -> (fabs x)
18172   if (N0.getOpcode() == ISD::FABS)
18173     return N->getOperand(0);
18174 
18175   // fold (fabs (fneg x)) -> (fabs x)
18176   // fold (fabs (fcopysign x, y)) -> (fabs x)
18177   if (N0.getOpcode() == ISD::FNEG || N0.getOpcode() == ISD::FCOPYSIGN)
18178     return DAG.getNode(ISD::FABS, SDLoc(N), VT, N0.getOperand(0));
18179 
18180   if (SDValue Cast = foldSignChangeInBitcast(N))
18181     return Cast;
18182 
18183   return SDValue();
18184 }
18185 
18186 SDValue DAGCombiner::visitBRCOND(SDNode *N) {
18187   SDValue Chain = N->getOperand(0);
18188   SDValue N1 = N->getOperand(1);
18189   SDValue N2 = N->getOperand(2);
18190 
18191   // BRCOND(FREEZE(cond)) is equivalent to BRCOND(cond) (both are
18192   // nondeterministic jumps).
18193   if (N1->getOpcode() == ISD::FREEZE && N1.hasOneUse()) {
18194     return DAG.getNode(ISD::BRCOND, SDLoc(N), MVT::Other, Chain,
18195                        N1->getOperand(0), N2);
18196   }
18197 
18198   // Variant of the previous fold where there is a SETCC in between:
18199   //   BRCOND(SETCC(FREEZE(X), CONST, Cond))
18200   // =>
18201   //   BRCOND(FREEZE(SETCC(X, CONST, Cond)))
18202   // =>
18203   //   BRCOND(SETCC(X, CONST, Cond))
18204   // This is correct if FREEZE(X) has one use and SETCC(FREEZE(X), CONST, Cond)
18205   // isn't equivalent to true or false.
18206   // For example, SETCC(FREEZE(X), -128, SETULT) cannot be folded to
18207   // FREEZE(SETCC(X, -128, SETULT)) because X can be poison.
18208   if (N1->getOpcode() == ISD::SETCC && N1.hasOneUse()) {
18209     SDValue S0 = N1->getOperand(0), S1 = N1->getOperand(1);
18210     ISD::CondCode Cond = cast<CondCodeSDNode>(N1->getOperand(2))->get();
18211     ConstantSDNode *S0C = dyn_cast<ConstantSDNode>(S0);
18212     ConstantSDNode *S1C = dyn_cast<ConstantSDNode>(S1);
18213     bool Updated = false;
18214 
18215     // Is 'X Cond C' always true or false?
18216     auto IsAlwaysTrueOrFalse = [](ISD::CondCode Cond, ConstantSDNode *C) {
18217       bool False = (Cond == ISD::SETULT && C->isZero()) ||
18218                    (Cond == ISD::SETLT && C->isMinSignedValue()) ||
18219                    (Cond == ISD::SETUGT && C->isAllOnes()) ||
18220                    (Cond == ISD::SETGT && C->isMaxSignedValue());
18221       bool True = (Cond == ISD::SETULE && C->isAllOnes()) ||
18222                   (Cond == ISD::SETLE && C->isMaxSignedValue()) ||
18223                   (Cond == ISD::SETUGE && C->isZero()) ||
18224                   (Cond == ISD::SETGE && C->isMinSignedValue());
18225       return True || False;
18226     };
18227 
18228     if (S0->getOpcode() == ISD::FREEZE && S0.hasOneUse() && S1C) {
18229       if (!IsAlwaysTrueOrFalse(Cond, S1C)) {
18230         S0 = S0->getOperand(0);
18231         Updated = true;
18232       }
18233     }
18234     if (S1->getOpcode() == ISD::FREEZE && S1.hasOneUse() && S0C) {
18235       if (!IsAlwaysTrueOrFalse(ISD::getSetCCSwappedOperands(Cond), S0C)) {
18236         S1 = S1->getOperand(0);
18237         Updated = true;
18238       }
18239     }
18240 
18241     if (Updated)
18242       return DAG.getNode(
18243           ISD::BRCOND, SDLoc(N), MVT::Other, Chain,
18244           DAG.getSetCC(SDLoc(N1), N1->getValueType(0), S0, S1, Cond), N2);
18245   }
18246 
18247   // If N is a constant we could fold this into a fallthrough or unconditional
18248   // branch. However that doesn't happen very often in normal code, because
18249   // Instcombine/SimplifyCFG should have handled the available opportunities.
18250   // If we did this folding here, it would be necessary to update the
18251   // MachineBasicBlock CFG, which is awkward.
18252 
18253   // fold a brcond with a setcc condition into a BR_CC node if BR_CC is legal
18254   // on the target.
18255   if (N1.getOpcode() == ISD::SETCC &&
18256       TLI.isOperationLegalOrCustom(ISD::BR_CC,
18257                                    N1.getOperand(0).getValueType())) {
18258     return DAG.getNode(ISD::BR_CC, SDLoc(N), MVT::Other,
18259                        Chain, N1.getOperand(2),
18260                        N1.getOperand(0), N1.getOperand(1), N2);
18261   }
18262 
18263   if (N1.hasOneUse()) {
18264     // rebuildSetCC calls visitXor which may change the Chain when there is a
18265     // STRICT_FSETCC/STRICT_FSETCCS involved. Use a handle to track changes.
18266     HandleSDNode ChainHandle(Chain);
18267     if (SDValue NewN1 = rebuildSetCC(N1))
18268       return DAG.getNode(ISD::BRCOND, SDLoc(N), MVT::Other,
18269                          ChainHandle.getValue(), NewN1, N2);
18270   }
18271 
18272   return SDValue();
18273 }
18274 
18275 SDValue DAGCombiner::rebuildSetCC(SDValue N) {
18276   if (N.getOpcode() == ISD::SRL ||
18277       (N.getOpcode() == ISD::TRUNCATE &&
18278        (N.getOperand(0).hasOneUse() &&
18279         N.getOperand(0).getOpcode() == ISD::SRL))) {
18280     // Look past the truncate.
18281     if (N.getOpcode() == ISD::TRUNCATE)
18282       N = N.getOperand(0);
18283 
18284     // Match this pattern so that we can generate simpler code:
18285     //
18286     //   %a = ...
18287     //   %b = and i32 %a, 2
18288     //   %c = srl i32 %b, 1
18289     //   brcond i32 %c ...
18290     //
18291     // into
18292     //
18293     //   %a = ...
18294     //   %b = and i32 %a, 2
18295     //   %c = setcc eq %b, 0
18296     //   brcond %c ...
18297     //
18298     // This applies only when the AND constant value has one bit set and the
18299     // SRL constant is equal to the log2 of the AND constant. The back-end is
18300     // smart enough to convert the result into a TEST/JMP sequence.
18301     SDValue Op0 = N.getOperand(0);
18302     SDValue Op1 = N.getOperand(1);
18303 
18304     if (Op0.getOpcode() == ISD::AND && Op1.getOpcode() == ISD::Constant) {
18305       SDValue AndOp1 = Op0.getOperand(1);
18306 
18307       if (AndOp1.getOpcode() == ISD::Constant) {
18308         const APInt &AndConst = AndOp1->getAsAPIntVal();
18309 
18310         if (AndConst.isPowerOf2() &&
18311             Op1->getAsAPIntVal() == AndConst.logBase2()) {
18312           SDLoc DL(N);
18313           return DAG.getSetCC(DL, getSetCCResultType(Op0.getValueType()),
18314                               Op0, DAG.getConstant(0, DL, Op0.getValueType()),
18315                               ISD::SETNE);
18316         }
18317       }
18318     }
18319   }
18320 
18321   // Transform (brcond (xor x, y)) -> (brcond (setcc, x, y, ne))
18322   // Transform (brcond (xor (xor x, y), -1)) -> (brcond (setcc, x, y, eq))
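  // This is sound because (xor x, y) is nonzero exactly when x != y, and
  // xor'ing an i1 value with -1 negates it, turning the setne into a seteq.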
18323   if (N.getOpcode() == ISD::XOR) {
18324     // Because we may call this on a speculatively constructed
18325     // SimplifiedSetCC Node, we need to simplify this node first.
18326     // Ideally this should be folded into SimplifySetCC and not
18327     // here. For now, grab a handle to N so we don't lose it from
18328     // replacements internal to the visit.
18329     HandleSDNode XORHandle(N);
18330     while (N.getOpcode() == ISD::XOR) {
18331       SDValue Tmp = visitXOR(N.getNode());
18332       // No simplification done.
18333       if (!Tmp.getNode())
18334         break;
18335       // Returning N is a form of in-visit replacement that may invalidate
18336       // N. Grab the value from the handle.
18337       if (Tmp.getNode() == N.getNode())
18338         N = XORHandle.getValue();
18339       else // Node simplified. Try simplifying again.
18340         N = Tmp;
18341     }
18342 
18343     if (N.getOpcode() != ISD::XOR)
18344       return N;
18345 
18346     SDValue Op0 = N->getOperand(0);
18347     SDValue Op1 = N->getOperand(1);
18348 
18349     if (Op0.getOpcode() != ISD::SETCC && Op1.getOpcode() != ISD::SETCC) {
18350       bool Equal = false;
18351       // (brcond (xor (xor x, y), -1)) -> (brcond (setcc x, y, eq))
18352       if (isBitwiseNot(N) && Op0.hasOneUse() && Op0.getOpcode() == ISD::XOR &&
18353           Op0.getValueType() == MVT::i1) {
18354         N = Op0;
18355         Op0 = N->getOperand(0);
18356         Op1 = N->getOperand(1);
18357         Equal = true;
18358       }
18359 
18360       EVT SetCCVT = N.getValueType();
18361       if (LegalTypes)
18362         SetCCVT = getSetCCResultType(SetCCVT);
18363       // Replace the uses of XOR with SETCC
18364       return DAG.getSetCC(SDLoc(N), SetCCVT, Op0, Op1,
18365                           Equal ? ISD::SETEQ : ISD::SETNE);
18366     }
18367   }
18368 
18369   return SDValue();
18370 }
18371 
18372 // Operand List for BR_CC: Chain, CondCC, CondLHS, CondRHS, DestBB.
18373 //
18374 SDValue DAGCombiner::visitBR_CC(SDNode *N) {
18375   CondCodeSDNode *CC = cast<CondCodeSDNode>(N->getOperand(1));
18376   SDValue CondLHS = N->getOperand(2), CondRHS = N->getOperand(3);
18377 
18378   // If N is a constant we could fold this into a fallthrough or unconditional
18379   // branch. However that doesn't happen very often in normal code, because
18380   // Instcombine/SimplifyCFG should have handled the available opportunities.
18381   // If we did this folding here, it would be necessary to update the
18382   // MachineBasicBlock CFG, which is awkward.
18383 
18384   // Use SimplifySetCC to simplify SETCC's.
18385   SDValue Simp = SimplifySetCC(getSetCCResultType(CondLHS.getValueType()),
18386                                CondLHS, CondRHS, CC->get(), SDLoc(N),
18387                                false);
18388   if (Simp.getNode()) AddToWorklist(Simp.getNode());
18389 
18390   // fold to a simpler setcc
18391   if (Simp.getNode() && Simp.getOpcode() == ISD::SETCC)
18392     return DAG.getNode(ISD::BR_CC, SDLoc(N), MVT::Other,
18393                        N->getOperand(0), Simp.getOperand(2),
18394                        Simp.getOperand(0), Simp.getOperand(1),
18395                        N->getOperand(4));
18396 
18397   return SDValue();
18398 }
18399 
18400 static bool getCombineLoadStoreParts(SDNode *N, unsigned Inc, unsigned Dec,
18401                                      bool &IsLoad, bool &IsMasked, SDValue &Ptr,
18402                                      const TargetLowering &TLI) {
18403   if (LoadSDNode *LD = dyn_cast<LoadSDNode>(N)) {
18404     if (LD->isIndexed())
18405       return false;
18406     EVT VT = LD->getMemoryVT();
18407     if (!TLI.isIndexedLoadLegal(Inc, VT) && !TLI.isIndexedLoadLegal(Dec, VT))
18408       return false;
18409     Ptr = LD->getBasePtr();
18410   } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(N)) {
18411     if (ST->isIndexed())
18412       return false;
18413     EVT VT = ST->getMemoryVT();
18414     if (!TLI.isIndexedStoreLegal(Inc, VT) && !TLI.isIndexedStoreLegal(Dec, VT))
18415       return false;
18416     Ptr = ST->getBasePtr();
18417     IsLoad = false;
18418   } else if (MaskedLoadSDNode *LD = dyn_cast<MaskedLoadSDNode>(N)) {
18419     if (LD->isIndexed())
18420       return false;
18421     EVT VT = LD->getMemoryVT();
18422     if (!TLI.isIndexedMaskedLoadLegal(Inc, VT) &&
18423         !TLI.isIndexedMaskedLoadLegal(Dec, VT))
18424       return false;
18425     Ptr = LD->getBasePtr();
18426     IsMasked = true;
18427   } else if (MaskedStoreSDNode *ST = dyn_cast<MaskedStoreSDNode>(N)) {
18428     if (ST->isIndexed())
18429       return false;
18430     EVT VT = ST->getMemoryVT();
18431     if (!TLI.isIndexedMaskedStoreLegal(Inc, VT) &&
18432         !TLI.isIndexedMaskedStoreLegal(Dec, VT))
18433       return false;
18434     Ptr = ST->getBasePtr();
18435     IsLoad = false;
18436     IsMasked = true;
18437   } else {
18438     return false;
18439   }
18440   return true;
18441 }
18442 
18443 /// Try turning a load/store into a pre-indexed load/store when the base
18444 /// pointer is an add or subtract and it has other uses besides the load/store.
18445 /// After the transformation, the new indexed load/store has effectively folded
18446 /// the add/subtract in and all of its other uses are redirected to the
18447 /// new load/store.
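/// Illustrative shape of the rewrite (pseudo-DAG, assuming a target with
/// pre-indexed addressing):
///   (load (add x, 8)), ... other uses of (add x, 8) ...
/// becomes a pre-indexed load producing both the loaded value and the updated
/// pointer x + 8, to which the other uses are redirected.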
18448 bool DAGCombiner::CombineToPreIndexedLoadStore(SDNode *N) {
18449   if (Level < AfterLegalizeDAG)
18450     return false;
18451 
18452   bool IsLoad = true;
18453   bool IsMasked = false;
18454   SDValue Ptr;
18455   if (!getCombineLoadStoreParts(N, ISD::PRE_INC, ISD::PRE_DEC, IsLoad, IsMasked,
18456                                 Ptr, TLI))
18457     return false;
18458 
18459   // If the pointer is not an add/sub, or if it doesn't have multiple uses, bail
18460   // out.  There is no reason to make this a preinc/predec.
18461   if ((Ptr.getOpcode() != ISD::ADD && Ptr.getOpcode() != ISD::SUB) ||
18462       Ptr->hasOneUse())
18463     return false;
18464 
18465   // Ask the target to do addressing mode selection.
18466   SDValue BasePtr;
18467   SDValue Offset;
18468   ISD::MemIndexedMode AM = ISD::UNINDEXED;
18469   if (!TLI.getPreIndexedAddressParts(N, BasePtr, Offset, AM, DAG))
18470     return false;
18471 
18472   // Backends without true r+i pre-indexed forms may need to pass a
18473   // constant base with a variable offset so that constant coercion
18474   // will work with the patterns in canonical form.
18475   bool Swapped = false;
18476   if (isa<ConstantSDNode>(BasePtr)) {
18477     std::swap(BasePtr, Offset);
18478     Swapped = true;
18479   }
18480 
18481   // Don't create an indexed load / store with zero offset.
18482   if (isNullConstant(Offset))
18483     return false;
18484 
18485   // Try turning it into a pre-indexed load / store except when:
18486   // 1) The new base ptr is a frame index.
18487   // 2) If N is a store and the new base ptr is either the same as or is a
18488   //    predecessor of the value being stored.
18489   // 3) Another use of old base ptr is a predecessor of N. If ptr is folded
18490   //    that would create a cycle.
18491   // 4) All uses are load / store ops that use it as old base ptr.
18492 
18493   // Check #1.  Preinc'ing a frame index would require copying the stack pointer
18494   // (plus the implicit offset) to a register to preinc anyway.
18495   if (isa<FrameIndexSDNode>(BasePtr) || isa<RegisterSDNode>(BasePtr))
18496     return false;
18497 
18498   // Check #2.
18499   if (!IsLoad) {
18500     SDValue Val = IsMasked ? cast<MaskedStoreSDNode>(N)->getValue()
18501                            : cast<StoreSDNode>(N)->getValue();
18502 
18503     // Would require a copy.
18504     if (Val == BasePtr)
18505       return false;
18506 
18507     // Would create a cycle.
18508     if (Val == Ptr || Ptr->isPredecessorOf(Val.getNode()))
18509       return false;
18510   }
18511 
18512   // Caches for hasPredecessorHelper.
18513   SmallPtrSet<const SDNode *, 32> Visited;
18514   SmallVector<const SDNode *, 16> Worklist;
18515   Worklist.push_back(N);
18516 
18517   // If the offset is a constant, there may be other adds of constants that
18518   // can be folded with this one. We should do this to avoid having to keep
18519   // a copy of the original base pointer.
18520   SmallVector<SDNode *, 16> OtherUses;
18521   constexpr unsigned int MaxSteps = 8192;
18522   if (isa<ConstantSDNode>(Offset))
18523     for (SDNode::use_iterator UI = BasePtr->use_begin(),
18524                               UE = BasePtr->use_end();
18525          UI != UE; ++UI) {
18526       SDUse &Use = UI.getUse();
18527       // Skip the use that is Ptr and uses of other results from BasePtr's
18528       // node (important for nodes that return multiple results).
18529       if (Use.getUser() == Ptr.getNode() || Use != BasePtr)
18530         continue;
18531 
18532       if (SDNode::hasPredecessorHelper(Use.getUser(), Visited, Worklist,
18533                                        MaxSteps))
18534         continue;
18535 
18536       if (Use.getUser()->getOpcode() != ISD::ADD &&
18537           Use.getUser()->getOpcode() != ISD::SUB) {
18538         OtherUses.clear();
18539         break;
18540       }
18541 
18542       SDValue Op1 = Use.getUser()->getOperand((UI.getOperandNo() + 1) & 1);
18543       if (!isa<ConstantSDNode>(Op1)) {
18544         OtherUses.clear();
18545         break;
18546       }
18547 
18548       // FIXME: In some cases, we can be smarter about this.
18549       if (Op1.getValueType() != Offset.getValueType()) {
18550         OtherUses.clear();
18551         break;
18552       }
18553 
18554       OtherUses.push_back(Use.getUser());
18555     }
18556 
18557   if (Swapped)
18558     std::swap(BasePtr, Offset);
18559 
18560   // Now check for #3 and #4.
18561   bool RealUse = false;
18562 
18563   for (SDNode *Use : Ptr->uses()) {
18564     if (Use == N)
18565       continue;
18566     if (SDNode::hasPredecessorHelper(Use, Visited, Worklist, MaxSteps))
18567       return false;
18568 
18569     // If Ptr may be folded in addressing mode of other use, then it's
18570     // not profitable to do this transformation.
18571     if (!canFoldInAddressingMode(Ptr.getNode(), Use, DAG, TLI))
18572       RealUse = true;
18573   }
18574 
18575   if (!RealUse)
18576     return false;
18577 
18578   SDValue Result;
18579   if (!IsMasked) {
18580     if (IsLoad)
18581       Result = DAG.getIndexedLoad(SDValue(N, 0), SDLoc(N), BasePtr, Offset, AM);
18582     else
18583       Result =
18584           DAG.getIndexedStore(SDValue(N, 0), SDLoc(N), BasePtr, Offset, AM);
18585   } else {
18586     if (IsLoad)
18587       Result = DAG.getIndexedMaskedLoad(SDValue(N, 0), SDLoc(N), BasePtr,
18588                                         Offset, AM);
18589     else
18590       Result = DAG.getIndexedMaskedStore(SDValue(N, 0), SDLoc(N), BasePtr,
18591                                          Offset, AM);
18592   }
18593   ++PreIndexedNodes;
18594   ++NodesCombined;
18595   LLVM_DEBUG(dbgs() << "\nReplacing.4 "; N->dump(&DAG); dbgs() << "\nWith: ";
18596              Result.dump(&DAG); dbgs() << '\n');
18597   WorklistRemover DeadNodes(*this);
18598   if (IsLoad) {
18599     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(0));
18600     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Result.getValue(2));
18601   } else {
18602     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(1));
18603   }
18604 
18605   // Finally, since the node is now dead, remove it from the graph.
18606   deleteAndRecombine(N);
18607 
18608   if (Swapped)
18609     std::swap(BasePtr, Offset);
18610 
18611   // Replace other uses of BasePtr that can be updated to use Ptr
18612   for (unsigned i = 0, e = OtherUses.size(); i != e; ++i) {
18613     unsigned OffsetIdx = 1;
18614     if (OtherUses[i]->getOperand(OffsetIdx).getNode() == BasePtr.getNode())
18615       OffsetIdx = 0;
18616     assert(OtherUses[i]->getOperand(!OffsetIdx).getNode() ==
18617            BasePtr.getNode() && "Expected BasePtr operand");
18618 
18619     // We need to replace ptr0 in the following expression:
18620     //   x0 * offset0 + y0 * ptr0 = t0
18621     // knowing that
18622     //   x1 * offset1 + y1 * ptr0 = t1 (the indexed load/store)
18623     //
18624     // where x0, x1, y0 and y1 in {-1, 1} are given by the types of the
18625     // indexed load/store and the expression that needs to be re-written.
18626     //
18627     // Therefore, we have:
18628     //   t0 = (x0 * offset0 - x1 * y0 * y1 * offset1) + (y0 * y1) * t1
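    // Worked instance (illustrative): if the other use is t0 = ptr0 + 4 and
    // the indexed node is a non-swapped PRE_DEC with offset1 = 8 (so
    // t1 = ptr0 - 8), then x0 = y0 = y1 = 1 and x1 = -1, giving
    // CNV = 4 + 8 = 12 and t0 = t1 + 12.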
18629 
18630     auto *CN = cast<ConstantSDNode>(OtherUses[i]->getOperand(OffsetIdx));
18631     const APInt &Offset0 = CN->getAPIntValue();
18632     const APInt &Offset1 = Offset->getAsAPIntVal();
18633     int X0 = (OtherUses[i]->getOpcode() == ISD::SUB && OffsetIdx == 1) ? -1 : 1;
18634     int Y0 = (OtherUses[i]->getOpcode() == ISD::SUB && OffsetIdx == 0) ? -1 : 1;
18635     int X1 = (AM == ISD::PRE_DEC && !Swapped) ? -1 : 1;
18636     int Y1 = (AM == ISD::PRE_DEC && Swapped) ? -1 : 1;
18637 
18638     unsigned Opcode = (Y0 * Y1 < 0) ? ISD::SUB : ISD::ADD;
18639 
18640     APInt CNV = Offset0;
18641     if (X0 < 0) CNV = -CNV;
18642     if (X1 * Y0 * Y1 < 0) CNV = CNV + Offset1;
18643     else CNV = CNV - Offset1;
18644 
18645     SDLoc DL(OtherUses[i]);
18646 
18647     // We can now generate the new expression.
18648     SDValue NewOp1 = DAG.getConstant(CNV, DL, CN->getValueType(0));
18649     SDValue NewOp2 = Result.getValue(IsLoad ? 1 : 0);
18650 
18651     SDValue NewUse = DAG.getNode(Opcode,
18652                                  DL,
18653                                  OtherUses[i]->getValueType(0), NewOp1, NewOp2);
18654     DAG.ReplaceAllUsesOfValueWith(SDValue(OtherUses[i], 0), NewUse);
18655     deleteAndRecombine(OtherUses[i]);
18656   }
18657 
18658   // Replace the uses of Ptr with uses of the updated base value.
18659   DAG.ReplaceAllUsesOfValueWith(Ptr, Result.getValue(IsLoad ? 1 : 0));
18660   deleteAndRecombine(Ptr.getNode());
18661   AddToWorklist(Result.getNode());
18662 
18663   return true;
18664 }
18665 
18666 static bool shouldCombineToPostInc(SDNode *N, SDValue Ptr, SDNode *PtrUse,
18667                                    SDValue &BasePtr, SDValue &Offset,
18668                                    ISD::MemIndexedMode &AM,
18669                                    SelectionDAG &DAG,
18670                                    const TargetLowering &TLI) {
18671   if (PtrUse == N ||
18672       (PtrUse->getOpcode() != ISD::ADD && PtrUse->getOpcode() != ISD::SUB))
18673     return false;
18674 
18675   if (!TLI.getPostIndexedAddressParts(N, PtrUse, BasePtr, Offset, AM, DAG))
18676     return false;
18677 
18678   // Don't create an indexed load / store with zero offset.
18679   if (isNullConstant(Offset))
18680     return false;
18681 
18682   if (isa<FrameIndexSDNode>(BasePtr) || isa<RegisterSDNode>(BasePtr))
18683     return false;
18684 
18685   SmallPtrSet<const SDNode *, 32> Visited;
18686   for (SDNode *Use : BasePtr->uses()) {
18687     if (Use == Ptr.getNode())
18688       continue;
18689 
18690     // Say no if there's a later user which could perform the indexing instead.
18691     if (isa<MemSDNode>(Use)) {
18692       bool IsLoad = true;
18693       bool IsMasked = false;
18694       SDValue OtherPtr;
18695       if (getCombineLoadStoreParts(Use, ISD::POST_INC, ISD::POST_DEC, IsLoad,
18696                                    IsMasked, OtherPtr, TLI)) {
18697         SmallVector<const SDNode *, 2> Worklist;
18698         Worklist.push_back(Use);
18699         if (SDNode::hasPredecessorHelper(N, Visited, Worklist))
18700           return false;
18701       }
18702     }
18703 
18704     // If all the uses are load / store addresses, then don't do the
18705     // transformation.
18706     if (Use->getOpcode() == ISD::ADD || Use->getOpcode() == ISD::SUB) {
18707       for (SDNode *UseUse : Use->uses())
18708         if (canFoldInAddressingMode(Use, UseUse, DAG, TLI))
18709           return false;
18710     }
18711   }
18712   return true;
18713 }
18714 
18715 static SDNode *getPostIndexedLoadStoreOp(SDNode *N, bool &IsLoad,
18716                                          bool &IsMasked, SDValue &Ptr,
18717                                          SDValue &BasePtr, SDValue &Offset,
18718                                          ISD::MemIndexedMode &AM,
18719                                          SelectionDAG &DAG,
18720                                          const TargetLowering &TLI) {
18721   if (!getCombineLoadStoreParts(N, ISD::POST_INC, ISD::POST_DEC, IsLoad,
18722                                 IsMasked, Ptr, TLI) ||
18723       Ptr->hasOneUse())
18724     return nullptr;
18725 
18726   // Try turning it into a post-indexed load / store except when
18727   // 1) All uses are load / store ops that use it as base ptr (and
18728   //    it may be folded into the addressing mode).
18729   // 2) Op must be independent of N, i.e. Op is neither a predecessor
18730   //    nor a successor of N. Otherwise, if Op is folded that would
18731   //    create a cycle.
18732   for (SDNode *Op : Ptr->uses()) {
18733     // Check for #1.
18734     if (!shouldCombineToPostInc(N, Ptr, Op, BasePtr, Offset, AM, DAG, TLI))
18735       continue;
18736 
18737     // Check for #2.
18738     SmallPtrSet<const SDNode *, 32> Visited;
18739     SmallVector<const SDNode *, 8> Worklist;
18740     constexpr unsigned int MaxSteps = 8192;
18741     // Ptr is predecessor to both N and Op.
18742     Visited.insert(Ptr.getNode());
18743     Worklist.push_back(N);
18744     Worklist.push_back(Op);
18745     if (!SDNode::hasPredecessorHelper(N, Visited, Worklist, MaxSteps) &&
18746         !SDNode::hasPredecessorHelper(Op, Visited, Worklist, MaxSteps))
18747       return Op;
18748   }
18749   return nullptr;
18750 }
18751 
18752 /// Try to combine a load/store with an add/sub of the base pointer node into a
18753 /// post-indexed load/store. The transformation effectively folds the
18754 /// add/subtract into the new indexed load/store, and all of its other uses are
18755 /// redirected to the new load/store.
18756 bool DAGCombiner::CombineToPostIndexedLoadStore(SDNode *N) {
18757   if (Level < AfterLegalizeDAG)
18758     return false;
18759 
18760   bool IsLoad = true;
18761   bool IsMasked = false;
18762   SDValue Ptr;
18763   SDValue BasePtr;
18764   SDValue Offset;
18765   ISD::MemIndexedMode AM = ISD::UNINDEXED;
18766   SDNode *Op = getPostIndexedLoadStoreOp(N, IsLoad, IsMasked, Ptr, BasePtr,
18767                                          Offset, AM, DAG, TLI);
18768   if (!Op)
18769     return false;
18770 
18771   SDValue Result;
18772   if (!IsMasked)
18773     Result = IsLoad ? DAG.getIndexedLoad(SDValue(N, 0), SDLoc(N), BasePtr,
18774                                          Offset, AM)
18775                     : DAG.getIndexedStore(SDValue(N, 0), SDLoc(N),
18776                                           BasePtr, Offset, AM);
18777   else
18778     Result = IsLoad ? DAG.getIndexedMaskedLoad(SDValue(N, 0), SDLoc(N),
18779                                                BasePtr, Offset, AM)
18780                     : DAG.getIndexedMaskedStore(SDValue(N, 0), SDLoc(N),
18781                                                 BasePtr, Offset, AM);
18782   ++PostIndexedNodes;
18783   ++NodesCombined;
18784   LLVM_DEBUG(dbgs() << "\nReplacing.5 "; N->dump(&DAG); dbgs() << "\nWith: ";
18785              Result.dump(&DAG); dbgs() << '\n');
18786   WorklistRemover DeadNodes(*this);
18787   if (IsLoad) {
18788     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(0));
18789     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Result.getValue(2));
18790   } else {
18791     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(1));
18792   }
18793 
18794   // Finally, since the node is now dead, remove it from the graph.
18795   deleteAndRecombine(N);
18796 
18797   // Replace the uses of Op with uses of the updated base value.
18798   DAG.ReplaceAllUsesOfValueWith(SDValue(Op, 0),
18799                                 Result.getValue(IsLoad ? 1 : 0));
18800   deleteAndRecombine(Op);
18801   return true;
18802 }
18803 
18804 /// Return the base-pointer arithmetic from an indexed \p LD.
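/// For example (illustrative): a PRE_INC load with base BP and increment 4
/// yields (add BP, 4); the PRE_DEC/POST_DEC forms yield the corresponding sub.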
18805 SDValue DAGCombiner::SplitIndexingFromLoad(LoadSDNode *LD) {
18806   ISD::MemIndexedMode AM = LD->getAddressingMode();
18807   assert(AM != ISD::UNINDEXED);
18808   SDValue BP = LD->getOperand(1);
18809   SDValue Inc = LD->getOperand(2);
18810 
18811   // Some backends use TargetConstants for load offsets, but don't expect
18812   // TargetConstants in general ADD nodes. We can convert these constants into
18813   // regular Constants (if the constant is not opaque).
18814   assert((Inc.getOpcode() != ISD::TargetConstant ||
18815           !cast<ConstantSDNode>(Inc)->isOpaque()) &&
18816          "Cannot split out indexing using opaque target constants");
18817   if (Inc.getOpcode() == ISD::TargetConstant) {
18818     ConstantSDNode *ConstInc = cast<ConstantSDNode>(Inc);
18819     Inc = DAG.getConstant(*ConstInc->getConstantIntValue(), SDLoc(Inc),
18820                           ConstInc->getValueType(0));
18821   }
18822 
18823   unsigned Opc =
18824       (AM == ISD::PRE_INC || AM == ISD::POST_INC ? ISD::ADD : ISD::SUB);
18825   return DAG.getNode(Opc, SDLoc(LD), BP.getSimpleValueType(), BP, Inc);
18826 }
18827 
18828 static inline ElementCount numVectorEltsOrZero(EVT T) {
18829   return T.isVector() ? T.getVectorElementCount() : ElementCount::getFixed(0);
18830 }
18831 
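/// Try to rewrite \p Val, the value stored by \p ST, into ST's memory type so
/// it can be forwarded to a load: via FTRUNC for floating point, TRUNCATE for
/// integers with matching element counts, or a plain bitcast when the bit
/// widths already agree. Returns true and updates \p Val on success.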
18832 bool DAGCombiner::getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val) {
18833   EVT STType = Val.getValueType();
18834   EVT STMemType = ST->getMemoryVT();
18835   if (STType == STMemType)
18836     return true;
18837   if (isTypeLegal(STMemType))
18838     return false; // fail.
18839   if (STType.isFloatingPoint() && STMemType.isFloatingPoint() &&
18840       TLI.isOperationLegal(ISD::FTRUNC, STMemType)) {
18841     Val = DAG.getNode(ISD::FTRUNC, SDLoc(ST), STMemType, Val);
18842     return true;
18843   }
18844   if (numVectorEltsOrZero(STType) == numVectorEltsOrZero(STMemType) &&
18845       STType.isInteger() && STMemType.isInteger()) {
18846     Val = DAG.getNode(ISD::TRUNCATE, SDLoc(ST), STMemType, Val);
18847     return true;
18848   }
18849   if (STType.getSizeInBits() == STMemType.getSizeInBits()) {
18850     Val = DAG.getBitcast(STMemType, Val);
18851     return true;
18852   }
18853   return false; // fail.
18854 }
18855 
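/// Extend \p Val, which has \p LD's memory type, to LD's result type using the
/// extension the load itself performs (any-, sign-, or zero-extend, or a plain
/// bitcast for a non-extending load). Returns true and updates \p Val on
/// success.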
18856 bool DAGCombiner::extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val) {
18857   EVT LDMemType = LD->getMemoryVT();
18858   EVT LDType = LD->getValueType(0);
18859   assert(Val.getValueType() == LDMemType &&
18860          "Attempting to extend value of non-matching type");
18861   if (LDType == LDMemType)
18862     return true;
18863   if (LDMemType.isInteger() && LDType.isInteger()) {
18864     switch (LD->getExtensionType()) {
18865     case ISD::NON_EXTLOAD:
18866       Val = DAG.getBitcast(LDType, Val);
18867       return true;
18868     case ISD::EXTLOAD:
18869       Val = DAG.getNode(ISD::ANY_EXTEND, SDLoc(LD), LDType, Val);
18870       return true;
18871     case ISD::SEXTLOAD:
18872       Val = DAG.getNode(ISD::SIGN_EXTEND, SDLoc(LD), LDType, Val);
18873       return true;
18874     case ISD::ZEXTLOAD:
18875       Val = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(LD), LDType, Val);
18876       return true;
18877     }
18878   }
18879   return false;
18880 }
18881 
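/// Walk \p LD's chain (looking through CALLSEQ_START and into a TokenFactor)
/// for a unique, non-aliased store to the same base address. On success,
/// \p Offset is set to the byte offset of the load within the stored value.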
18882 StoreSDNode *DAGCombiner::getUniqueStoreFeeding(LoadSDNode *LD,
18883                                                 int64_t &Offset) {
18884   SDValue Chain = LD->getOperand(0);
18885 
18886   // Look through CALLSEQ_START.
18887   if (Chain.getOpcode() == ISD::CALLSEQ_START)
18888     Chain = Chain->getOperand(0);
18889 
18890   StoreSDNode *ST = nullptr;
18891   SmallVector<SDValue, 8> Aliases;
18892   if (Chain.getOpcode() == ISD::TokenFactor) {
18893     // Look for unique store within the TokenFactor.
18894     for (SDValue Op : Chain->ops()) {
18895       StoreSDNode *Store = dyn_cast<StoreSDNode>(Op.getNode());
18896       if (!Store)
18897         continue;
18898       BaseIndexOffset BasePtrLD = BaseIndexOffset::match(LD, DAG);
18899       BaseIndexOffset BasePtrST = BaseIndexOffset::match(Store, DAG);
18900       if (!BasePtrST.equalBaseIndex(BasePtrLD, DAG, Offset))
18901         continue;
18902       // Make sure the store is not aliased with any nodes in TokenFactor.
18903       GatherAllAliases(Store, Chain, Aliases);
18904       if (Aliases.empty() ||
18905           (Aliases.size() == 1 && Aliases.front().getNode() == Store))
18906         ST = Store;
18907       break;
18908     }
18909   } else {
18910     StoreSDNode *Store = dyn_cast<StoreSDNode>(Chain.getNode());
18911     if (Store) {
18912       BaseIndexOffset BasePtrLD = BaseIndexOffset::match(LD, DAG);
18913       BaseIndexOffset BasePtrST = BaseIndexOffset::match(Store, DAG);
18914       if (BasePtrST.equalBaseIndex(BasePtrLD, DAG, Offset))
18915         ST = Store;
18916     }
18917   }
18918 
18919   return ST;
18920 }
18921 
18922 SDValue DAGCombiner::ForwardStoreValueToDirectLoad(LoadSDNode *LD) {
18923   if (OptLevel == CodeGenOptLevel::None || !LD->isSimple())
18924     return SDValue();
18925   SDValue Chain = LD->getOperand(0);
18926   int64_t Offset;
18927 
18928   StoreSDNode *ST = getUniqueStoreFeeding(LD, Offset);
18929   // TODO: Relax this restriction for unordered atomics (see D66309)
18930   if (!ST || !ST->isSimple() || ST->getAddressSpace() != LD->getAddressSpace())
18931     return SDValue();
18932 
18933   EVT LDType = LD->getValueType(0);
18934   EVT LDMemType = LD->getMemoryVT();
18935   EVT STMemType = ST->getMemoryVT();
18936   EVT STType = ST->getValue().getValueType();
18937 
18938   // There are two cases to consider here:
18939   //  1. The store is fixed width and the load is scalable. In this case we
18940   //     don't know at compile time if the store completely envelops the load
18941   //     so we abandon the optimisation.
18942   //  2. The store is scalable and the load is fixed width. We could
18943   //     potentially support a limited number of cases here, but there has been
18944   //     no cost-benefit analysis to prove it's worth it.
18945   bool LdStScalable = LDMemType.isScalableVT();
18946   if (LdStScalable != STMemType.isScalableVT())
18947     return SDValue();
18948 
18949   // If we are dealing with scalable vectors on a big endian platform the
18950   // calculation of offsets below becomes trickier, since we do not know at
18951   // compile time the absolute size of the vector. Until we've done more
18952   // analysis on big-endian platforms it seems better to bail out for now.
18953   if (LdStScalable && DAG.getDataLayout().isBigEndian())
18954     return SDValue();
18955 
18956   // Normalize for endianness. After this, Offset=0 will denote that the least
18957   // significant bit in the loaded value maps to the least significant bit in
18958   // the stored value. With Offset=n (for n > 0) the loaded value starts at the
18959   // n:th least significant byte of the stored value.
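  // Worked example (illustrative sizes): a 4-byte store feeding a 2-byte load
  // at store offset 0 on a big-endian target loads the most significant bytes,
  // so Offset becomes (4 - 2) - 0 = 2.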
18960   int64_t OrigOffset = Offset;
18961   if (DAG.getDataLayout().isBigEndian())
18962     Offset = ((int64_t)STMemType.getStoreSizeInBits().getFixedValue() -
18963               (int64_t)LDMemType.getStoreSizeInBits().getFixedValue()) /
18964                  8 -
18965              Offset;
18966 
18967   // Check that the stored value covers all bits that are loaded.
18968   bool STCoversLD;
18969 
18970   TypeSize LdMemSize = LDMemType.getSizeInBits();
18971   TypeSize StMemSize = STMemType.getSizeInBits();
18972   if (LdStScalable)
18973     STCoversLD = (Offset == 0) && LdMemSize == StMemSize;
18974   else
18975     STCoversLD = (Offset >= 0) && (Offset * 8 + LdMemSize.getFixedValue() <=
18976                                    StMemSize.getFixedValue());
18977 
18978   auto ReplaceLd = [&](LoadSDNode *LD, SDValue Val, SDValue Chain) -> SDValue {
18979     if (LD->isIndexed()) {
18980       // Cannot handle opaque target constants and we must respect the user's
18981       // request not to split indexes from loads.
18982       if (!canSplitIdx(LD))
18983         return SDValue();
18984       SDValue Idx = SplitIndexingFromLoad(LD);
18985       SDValue Ops[] = {Val, Idx, Chain};
18986       return CombineTo(LD, Ops, 3);
18987     }
18988     return CombineTo(LD, Val, Chain);
18989   };
18990 
18991   if (!STCoversLD)
18992     return SDValue();
18993 
18994   // Memory as copy space (potentially masked).
18995   if (Offset == 0 && LDType == STType && STMemType == LDMemType) {
18996     // Simple case: Direct non-truncating forwarding
18997     if (LDType.getSizeInBits() == LdMemSize)
18998       return ReplaceLd(LD, ST->getValue(), Chain);
18999     // Can we model the truncate and extension with an and mask?
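    // For example (a sketch): a truncating i16 store of an i32 value x feeding
    // a zextload i16 can be forwarded as (and x, 0xFFFF); the mask reproduces
    // the truncation to memory followed by the zero-extension.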
19000     if (STType.isInteger() && LDMemType.isInteger() && !STType.isVector() &&
19001         !LDMemType.isVector() && LD->getExtensionType() != ISD::SEXTLOAD) {
19002       // Mask to size of LDMemType
19003       auto Mask =
19004           DAG.getConstant(APInt::getLowBitsSet(STType.getFixedSizeInBits(),
19005                                                StMemSize.getFixedValue()),
19006                           SDLoc(ST), STType);
19007       auto Val = DAG.getNode(ISD::AND, SDLoc(LD), LDType, ST->getValue(), Mask);
19008       return ReplaceLd(LD, Val, Chain);
19009     }
19010   }
19011 
19012   // Handle some big-endian cases that would have Offset 0 on little-endian
19013   // targets and are handled above in that form.
19014   SDValue Val = ST->getValue();
19015   if (DAG.getDataLayout().isBigEndian() && Offset > 0 && OrigOffset == 0) {
19016     if (STType.isInteger() && !STType.isVector() && LDType.isInteger() &&
19017         !LDType.isVector() && isTypeLegal(STType) &&
19018         TLI.isOperationLegal(ISD::SRL, STType)) {
19019       Val = DAG.getNode(ISD::SRL, SDLoc(LD), STType, Val,
19020                         DAG.getConstant(Offset * 8, SDLoc(LD), STType));
19021       Offset = 0;
19022     }
19023   }
19024 
19025   // TODO: Deal with nonzero offset.
19026   if (LD->getBasePtr().isUndef() || Offset != 0)
19027     return SDValue();
19028   // Model necessary truncations / extensions.
19029   // Truncate the value to the stored memory size.
19030   do {
19031     if (!getTruncatedStoreValue(ST, Val))
19032       break;
19033     if (!isTypeLegal(LDMemType))
19034       break;
19035     if (STMemType != LDMemType) {
19036       // TODO: Support vectors? This requires extract_subvector/bitcast.
19037       if (!STMemType.isVector() && !LDMemType.isVector() &&
19038           STMemType.isInteger() && LDMemType.isInteger())
19039         Val = DAG.getNode(ISD::TRUNCATE, SDLoc(LD), LDMemType, Val);
19040       else
19041         break;
19042     }
19043     if (!extendLoadedValueToExtension(LD, Val))
19044       break;
19045     return ReplaceLd(LD, Val, Chain);
19046   } while (false);
19047 
19048   // On failure, cleanup dead nodes we may have created.
19049   if (Val->use_empty())
19050     deleteAndRecombine(Val.getNode());
19051   return SDValue();
19052 }
19053 
19054 SDValue DAGCombiner::visitLOAD(SDNode *N) {
19055   LoadSDNode *LD  = cast<LoadSDNode>(N);
19056   SDValue Chain = LD->getChain();
19057   SDValue Ptr   = LD->getBasePtr();
19058 
19059   // If load is not volatile and there are no uses of the loaded value (and
19060   // the updated indexed value in case of indexed loads), change uses of the
19061   // chain value into uses of the chain input (i.e. delete the dead load).
19062   // TODO: Allow this for unordered atomics (see D66309)
19063   if (LD->isSimple()) {
19064     if (N->getValueType(1) == MVT::Other) {
19065       // Unindexed loads.
19066       if (!N->hasAnyUseOfValue(0)) {
19067         // It's not safe to use the two-value CombineTo variant here, e.g.:
19068         // v1, chain2 = load chain1, loc
19069         // v2, chain3 = load chain2, loc
19070         // v3         = add v2, c
19071         // Now we replace use of chain2 with chain1.  This makes the second load
19072         // isomorphic to the one we are deleting, and thus makes this load live.
19073         LLVM_DEBUG(dbgs() << "\nReplacing.6 "; N->dump(&DAG);
19074                    dbgs() << "\nWith chain: "; Chain.dump(&DAG);
19075                    dbgs() << "\n");
19076         WorklistRemover DeadNodes(*this);
19077         DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Chain);
19078         AddUsersToWorklist(Chain.getNode());
19079         if (N->use_empty())
19080           deleteAndRecombine(N);
19081 
19082         return SDValue(N, 0);   // Return N so it doesn't get rechecked!
19083       }
19084     } else {
19085       // Indexed loads.
19086       assert(N->getValueType(2) == MVT::Other && "Malformed indexed loads?");
19087 
19088       // If this load has an opaque TargetConstant offset, then we cannot split
19089       // the indexing into an add/sub directly (that TargetConstant may not be
19090       // valid for a different type of node, and we cannot convert an opaque
19091       // target constant into a regular constant).
19092       bool CanSplitIdx = canSplitIdx(LD);
19093 
19094       if (!N->hasAnyUseOfValue(0) && (CanSplitIdx || !N->hasAnyUseOfValue(1))) {
19095         SDValue Undef = DAG.getUNDEF(N->getValueType(0));
19096         SDValue Index;
19097         if (N->hasAnyUseOfValue(1) && CanSplitIdx) {
19098           Index = SplitIndexingFromLoad(LD);
19099           // Try to fold the base pointer arithmetic into subsequent loads and
19100           // stores.
19101           AddUsersToWorklist(N);
19102         } else
19103           Index = DAG.getUNDEF(N->getValueType(1));
19104         LLVM_DEBUG(dbgs() << "\nReplacing.7 "; N->dump(&DAG);
19105                    dbgs() << "\nWith: "; Undef.dump(&DAG);
19106                    dbgs() << " and 2 other values\n");
19107         WorklistRemover DeadNodes(*this);
19108         DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Undef);
19109         DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Index);
19110         DAG.ReplaceAllUsesOfValueWith(SDValue(N, 2), Chain);
19111         deleteAndRecombine(N);
19112         return SDValue(N, 0);   // Return N so it doesn't get rechecked!
19113       }
19114     }
19115   }
19116 
19117   // If this load is directly stored, replace the load value with the stored
19118   // value.
19119   if (auto V = ForwardStoreValueToDirectLoad(LD))
19120     return V;
19121 
19122   // Try to infer better alignment information than the load already has.
19123   if (OptLevel != CodeGenOptLevel::None && LD->isUnindexed() &&
19124       !LD->isAtomic()) {
19125     if (MaybeAlign Alignment = DAG.InferPtrAlign(Ptr)) {
19126       if (*Alignment > LD->getAlign() &&
19127           isAligned(*Alignment, LD->getSrcValueOffset())) {
19128         SDValue NewLoad = DAG.getExtLoad(
19129             LD->getExtensionType(), SDLoc(N), LD->getValueType(0), Chain, Ptr,
19130             LD->getPointerInfo(), LD->getMemoryVT(), *Alignment,
19131             LD->getMemOperand()->getFlags(), LD->getAAInfo());
19132         // NewLoad will always be N as we are only refining the alignment
19133         assert(NewLoad.getNode() == N);
19134         (void)NewLoad;
19135       }
19136     }
19137   }
19138 
19139   if (LD->isUnindexed()) {
19140     // Walk up chain skipping non-aliasing memory nodes.
19141     SDValue BetterChain = FindBetterChain(LD, Chain);
19142 
19143     // If there is a better chain.
19144     if (Chain != BetterChain) {
19145       SDValue ReplLoad;
19146 
19147       // Replace the chain to avoid the dependency.
19148       if (LD->getExtensionType() == ISD::NON_EXTLOAD) {
19149         ReplLoad = DAG.getLoad(N->getValueType(0), SDLoc(LD),
19150                                BetterChain, Ptr, LD->getMemOperand());
19151       } else {
19152         ReplLoad = DAG.getExtLoad(LD->getExtensionType(), SDLoc(LD),
19153                                   LD->getValueType(0),
19154                                   BetterChain, Ptr, LD->getMemoryVT(),
19155                                   LD->getMemOperand());
19156       }
19157 
19158       // Create token factor to keep old chain connected.
19159       SDValue Token = DAG.getNode(ISD::TokenFactor, SDLoc(N),
19160                                   MVT::Other, Chain, ReplLoad.getValue(1));
19161 
19162       // Replace uses with load result and token factor
19163       return CombineTo(N, ReplLoad.getValue(0), Token);
19164     }
19165   }
19166 
19167   // Try transforming N to an indexed load.
19168   if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
19169     return SDValue(N, 0);
19170 
19171   // Try to slice up N to more direct loads if the slices are mapped to
19172   // different register banks or pairing can take place.
19173   if (SliceUpLoad(N))
19174     return SDValue(N, 0);
19175 
19176   return SDValue();
19177 }
19178 
19179 namespace {
19180 
19181 /// Helper structure used to slice a load in smaller loads.
19182 /// Basically a slice is obtained from the following sequence:
19183 /// Origin = load Ty1, Base
19184 /// Shift = srl Ty1 Origin, CstTy Amount
19185 /// Inst = trunc Shift to Ty2
19186 ///
19187 /// Then, it will be rewritten into:
19188 /// Slice = load SliceTy, Base + SliceOffset
19189 /// [Inst = zext Slice to Ty2], only if SliceTy <> Ty2
19190 ///
19191 /// SliceTy is deduced from the number of bits that are actually used to
19192 /// build Inst.
19193 struct LoadedSlice {
19194   /// Helper structure used to compute the cost of a slice.
19195   struct Cost {
19196     /// Are we optimizing for code size.
19197     bool ForCodeSize = false;
19198 
19199     /// Various cost.
19200     unsigned Loads = 0;
19201     unsigned Truncates = 0;
19202     unsigned CrossRegisterBanksCopies = 0;
19203     unsigned ZExts = 0;
19204     unsigned Shift = 0;
19205 
19206     explicit Cost(bool ForCodeSize) : ForCodeSize(ForCodeSize) {}
19207 
19208     /// Get the cost of one isolated slice.
19209     Cost(const LoadedSlice &LS, bool ForCodeSize)
19210         : ForCodeSize(ForCodeSize), Loads(1) {
19211       EVT TruncType = LS.Inst->getValueType(0);
19212       EVT LoadedType = LS.getLoadedType();
19213       if (TruncType != LoadedType &&
19214           !LS.DAG->getTargetLoweringInfo().isZExtFree(LoadedType, TruncType))
19215         ZExts = 1;
19216     }
19217 
19218     /// Account for slicing gain in the current cost.
19219     /// Slicing provides a few gains, like removing a shift or a
19220     /// truncate. This method allows growing the cost of the original
19221     /// load with the gain from this slice.
19222     void addSliceGain(const LoadedSlice &LS) {
19223       // Each slice saves a truncate.
19224       const TargetLowering &TLI = LS.DAG->getTargetLoweringInfo();
19225       if (!TLI.isTruncateFree(LS.Inst->getOperand(0), LS.Inst->getValueType(0)))
19226         ++Truncates;
19227       // If there is a shift amount, this slice gets rid of it.
19228       if (LS.Shift)
19229         ++Shift;
19230       // If this slice can merge a cross register bank copy, account for it.
19231       if (LS.canMergeExpensiveCrossRegisterBankCopy())
19232         ++CrossRegisterBanksCopies;
19233     }
19234 
19235     Cost &operator+=(const Cost &RHS) {
19236       Loads += RHS.Loads;
19237       Truncates += RHS.Truncates;
19238       CrossRegisterBanksCopies += RHS.CrossRegisterBanksCopies;
19239       ZExts += RHS.ZExts;
19240       Shift += RHS.Shift;
19241       return *this;
19242     }
19243 
19244     bool operator==(const Cost &RHS) const {
19245       return Loads == RHS.Loads && Truncates == RHS.Truncates &&
19246              CrossRegisterBanksCopies == RHS.CrossRegisterBanksCopies &&
19247              ZExts == RHS.ZExts && Shift == RHS.Shift;
19248     }
19249 
19250     bool operator!=(const Cost &RHS) const { return !(*this == RHS); }
19251 
19252     bool operator<(const Cost &RHS) const {
19253       // Assume cross register banks copies are as expensive as loads.
19254       // FIXME: Do we want some more target hooks?
19255       unsigned ExpensiveOpsLHS = Loads + CrossRegisterBanksCopies;
19256       unsigned ExpensiveOpsRHS = RHS.Loads + RHS.CrossRegisterBanksCopies;
19257       // Unless we are optimizing for code size, consider the
19258       // expensive operation first.
19259       if (!ForCodeSize && ExpensiveOpsLHS != ExpensiveOpsRHS)
19260         return ExpensiveOpsLHS < ExpensiveOpsRHS;
19261       return (Truncates + ZExts + Shift + ExpensiveOpsLHS) <
19262              (RHS.Truncates + RHS.ZExts + RHS.Shift + ExpensiveOpsRHS);
19263     }
19264 
19265     bool operator>(const Cost &RHS) const { return RHS < *this; }
19266 
19267     bool operator<=(const Cost &RHS) const { return !(RHS < *this); }
19268 
19269     bool operator>=(const Cost &RHS) const { return !(*this < RHS); }
19270   };
19271 
19272   // The last instruction that represents the slice. This should be a
19273   // truncate instruction.
19274   SDNode *Inst;
19275 
19276   // The original load instruction.
19277   LoadSDNode *Origin;
19278 
19279   // The right shift amount in bits from the original load.
19280   unsigned Shift;
19281 
19282   // The DAG from which Origin came.
19283   // This is used to get some contextual information about legal types, etc.
19284   SelectionDAG *DAG;
19285 
19286   LoadedSlice(SDNode *Inst = nullptr, LoadSDNode *Origin = nullptr,
19287               unsigned Shift = 0, SelectionDAG *DAG = nullptr)
19288       : Inst(Inst), Origin(Origin), Shift(Shift), DAG(DAG) {}
19289 
19290   /// Get the bits used in a chunk of bits \p BitWidth large.
19291   /// \return Result is \p BitWidth bits wide and has used bits set to 1 and
19292   ///         not used bits set to 0.
19293   APInt getUsedBits() const {
19294     // Reproduce the trunc(lshr) sequence:
19295     // - Start from the truncated value.
19296     // - Zero extend to the desired bit width.
19297     // - Shift left.
19298     assert(Origin && "No original load to compare against.");
19299     unsigned BitWidth = Origin->getValueSizeInBits(0);
19300     assert(Inst && "This slice is not bound to an instruction");
19301     assert(Inst->getValueSizeInBits(0) <= BitWidth &&
19302            "Extracted slice is bigger than the whole type!");
19303     APInt UsedBits(Inst->getValueSizeInBits(0), 0);
19304     UsedBits.setAllBits();
19305     UsedBits = UsedBits.zext(BitWidth);
19306     UsedBits <<= Shift;
19307     return UsedBits;
19308   }
19309 
19310   /// Get the size of the slice to be loaded in bytes.
19311   unsigned getLoadedSize() const {
19312     unsigned SliceSize = getUsedBits().popcount();
19313     assert(!(SliceSize & 0x7) && "Size is not a multiple of a byte.");
19314     return SliceSize / 8;
19315   }
19316 
19317   /// Get the type that will be loaded for this slice.
19318   /// Note: This may not be the final type for the slice.
19319   EVT getLoadedType() const {
19320     assert(DAG && "Missing context");
19321     LLVMContext &Ctxt = *DAG->getContext();
19322     return EVT::getIntegerVT(Ctxt, getLoadedSize() * 8);
19323   }
19324 
19325   /// Get the alignment of the load used for this slice.
19326   Align getAlign() const {
19327     Align Alignment = Origin->getAlign();
19328     uint64_t Offset = getOffsetFromBase();
19329     if (Offset != 0)
19330       Alignment = commonAlignment(Alignment, Alignment.value() + Offset);
19331     return Alignment;
19332   }
19333 
19334   /// Check if this slice can be rewritten with legal operations.
19335   bool isLegal() const {
19336     // An invalid slice is not legal.
19337     if (!Origin || !Inst || !DAG)
19338       return false;
19339 
19340     // Offsets are for indexed loads only; we do not handle that.
19341     if (!Origin->getOffset().isUndef())
19342       return false;
19343 
19344     const TargetLowering &TLI = DAG->getTargetLoweringInfo();
19345 
19346     // Check that the type is legal.
19347     EVT SliceType = getLoadedType();
19348     if (!TLI.isTypeLegal(SliceType))
19349       return false;
19350 
19351     // Check that the load is legal for this type.
19352     if (!TLI.isOperationLegal(ISD::LOAD, SliceType))
19353       return false;
19354 
19355     // Check that the offset can be computed.
19356     // 1. Check its type.
19357     EVT PtrType = Origin->getBasePtr().getValueType();
19358     if (PtrType == MVT::Untyped || PtrType.isExtended())
19359       return false;
19360 
19361     // 2. Check that it fits in the immediate.
19362     if (!TLI.isLegalAddImmediate(getOffsetFromBase()))
19363       return false;
19364 
19365     // 3. Check that the computation is legal.
19366     if (!TLI.isOperationLegal(ISD::ADD, PtrType))
19367       return false;
19368 
19369     // Check that the zext is legal if it needs one.
19370     EVT TruncateType = Inst->getValueType(0);
19371     if (TruncateType != SliceType &&
19372         !TLI.isOperationLegal(ISD::ZERO_EXTEND, TruncateType))
19373       return false;
19374 
19375     return true;
19376   }
19377 
19378   /// Get the offset in bytes of this slice in the original chunk of
19379   /// bits.
19380   /// \pre DAG != nullptr.
19381   uint64_t getOffsetFromBase() const {
19382     assert(DAG && "Missing context.");
19383     bool IsBigEndian = DAG->getDataLayout().isBigEndian();
19384     assert(!(Shift & 0x7) && "Shifts not aligned on Bytes are not supported.");
19385     uint64_t Offset = Shift / 8;
19386     unsigned TySizeInBytes = Origin->getValueSizeInBits(0) / 8;
19387     assert(!(Origin->getValueSizeInBits(0) & 0x7) &&
19388            "The size of the original loaded type is not a multiple of a"
19389            " byte.");
19390     // If Offset is bigger than TySizeInBytes, it means we are loading all
19391     // zeros. This should have been optimized before in the process.
19392     assert(TySizeInBytes > Offset &&
19393            "Invalid shift amount for given loaded size");
19394     if (IsBigEndian)
19395       Offset = TySizeInBytes - Offset - getLoadedSize();
19396     return Offset;
19397   }
19398 
19399   /// Generate the sequence of instructions to load the slice
19400   /// represented by this object and redirect the uses of this slice to
19401   /// this new sequence of instructions.
19402   /// \pre this->Inst && this->Origin are valid Instructions and this
19403   /// object passed the legal check: LoadedSlice::isLegal returned true.
19404   /// \return The last instruction of the sequence used to load the slice.
19405   SDValue loadSlice() const {
19406     assert(Inst && Origin && "Unable to replace a non-existing slice.");
19407     const SDValue &OldBaseAddr = Origin->getBasePtr();
19408     SDValue BaseAddr = OldBaseAddr;
19409     // Get the offset in that chunk of bytes w.r.t. the endianness.
19410     int64_t Offset = static_cast<int64_t>(getOffsetFromBase());
19411     assert(Offset >= 0 && "Offset too big to fit in int64_t!");
19412     if (Offset) {
19413       // BaseAddr = BaseAddr + Offset.
19414       EVT ArithType = BaseAddr.getValueType();
19415       SDLoc DL(Origin);
19416       BaseAddr = DAG->getNode(ISD::ADD, DL, ArithType, BaseAddr,
19417                               DAG->getConstant(Offset, DL, ArithType));
19418     }
19419 
19420     // Create the type of the loaded slice according to its size.
19421     EVT SliceType = getLoadedType();
19422 
19423     // Create the load for the slice.
19424     SDValue LastInst =
19425         DAG->getLoad(SliceType, SDLoc(Origin), Origin->getChain(), BaseAddr,
19426                      Origin->getPointerInfo().getWithOffset(Offset), getAlign(),
19427                      Origin->getMemOperand()->getFlags());
19428     // If the final type is not the same as the loaded type, this means that
19429     // we have to pad with zero. Create a zero extend for that.
19430     EVT FinalType = Inst->getValueType(0);
19431     if (SliceType != FinalType)
19432       LastInst =
19433           DAG->getNode(ISD::ZERO_EXTEND, SDLoc(LastInst), FinalType, LastInst);
19434     return LastInst;
19435   }
19436 
19437   /// Check if this slice can be merged with an expensive cross register
19438   /// bank copy. E.g.,
19439   /// i = load i32
19440   /// f = bitcast i32 i to float
19441   bool canMergeExpensiveCrossRegisterBankCopy() const {
19442     if (!Inst || !Inst->hasOneUse())
19443       return false;
19444     SDNode *Use = *Inst->use_begin();
19445     if (Use->getOpcode() != ISD::BITCAST)
19446       return false;
19447     assert(DAG && "Missing context");
19448     const TargetLowering &TLI = DAG->getTargetLoweringInfo();
19449     EVT ResVT = Use->getValueType(0);
19450     const TargetRegisterClass *ResRC =
19451         TLI.getRegClassFor(ResVT.getSimpleVT(), Use->isDivergent());
19452     const TargetRegisterClass *ArgRC =
19453         TLI.getRegClassFor(Use->getOperand(0).getValueType().getSimpleVT(),
19454                            Use->getOperand(0)->isDivergent());
19455     if (ArgRC == ResRC || !TLI.isOperationLegal(ISD::LOAD, ResVT))
19456       return false;
19457 
19458     // At this point, we know that we perform a cross-register-bank copy.
19459     // Check if it is expensive.
19460     const TargetRegisterInfo *TRI = DAG->getSubtarget().getRegisterInfo();
19461     // Assume bitcasts are cheap, unless both register classes do not
19462     // explicitly share a common sub class.
19463     if (!TRI || TRI->getCommonSubClass(ArgRC, ResRC))
19464       return false;
19465 
19466     // Check if it will be merged with the load.
19467     // 1. Check the alignment / fast memory access constraint.
19468     unsigned IsFast = 0;
19469     if (!TLI.allowsMemoryAccess(*DAG->getContext(), DAG->getDataLayout(), ResVT,
19470                                 Origin->getAddressSpace(), getAlign(),
19471                                 Origin->getMemOperand()->getFlags(), &IsFast) ||
19472         !IsFast)
19473       return false;
19474 
19475     // 2. Check that the load is a legal operation for that type.
19476     if (!TLI.isOperationLegal(ISD::LOAD, ResVT))
19477       return false;
19478 
19479     // 3. Check that we do not have a zext in the way.
19480     if (Inst->getValueType(0) != getLoadedType())
19481       return false;
19482 
19483     return true;
19484   }
19485 };
19486 
19487 } // end anonymous namespace
19488 
19489 /// Check that all bits set in \p UsedBits form a dense region, i.e.,
19490 /// \p UsedBits looks like 0..0 1..1 0..0.
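/// For example (illustrative 24-bit values): 0x00FF00 is dense (shifting out
/// the trailing zeros leaves 0xFF, a solid run of ones), while 0x00F0F0 is not.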
19491 static bool areUsedBitsDense(const APInt &UsedBits) {
19492   // If all the bits are one, this is dense!
19493   if (UsedBits.isAllOnes())
19494     return true;
19495 
19496   // Get rid of the unused bits on the right.
19497   APInt NarrowedUsedBits = UsedBits.lshr(UsedBits.countr_zero());
19498   // Get rid of the unused bits on the left.
19499   if (NarrowedUsedBits.countl_zero())
19500     NarrowedUsedBits = NarrowedUsedBits.trunc(NarrowedUsedBits.getActiveBits());
19501   // Check that the chunk of bits is completely used.
19502   return NarrowedUsedBits.isAllOnes();
19503 }
19504 
19505 /// Check whether or not \p First and \p Second are next to each other
19506 /// in memory. This means that there is no hole between the bits loaded
19507 /// by \p First and the bits loaded by \p Second.
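/// For example (a sketch): if \p First covers bits [0, 16) of the original
/// load and \p Second covers bits [16, 32), their union is the dense run
/// [0, 32) and the slices are adjacent; a Second covering [24, 40) would
/// leave a hole.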
19508 static bool areSlicesNextToEachOther(const LoadedSlice &First,
19509                                      const LoadedSlice &Second) {
19510   assert(First.Origin == Second.Origin && First.Origin &&
19511          "Unable to match different memory origins.");
19512   APInt UsedBits = First.getUsedBits();
19513   assert((UsedBits & Second.getUsedBits()) == 0 &&
19514          "Slices are not supposed to overlap.");
19515   UsedBits |= Second.getUsedBits();
19516   return areUsedBitsDense(UsedBits);
19517 }
19518 
19519 /// Adjust the \p GlobalLSCost according to the target
19520 /// pairing capabilities and the layout of the slices.
19521 /// \pre \p GlobalLSCost should account for at least as many loads as
19522 /// there are in the slices in \p LoadedSlices.
19523 static void adjustCostForPairing(SmallVectorImpl<LoadedSlice> &LoadedSlices,
19524                                  LoadedSlice::Cost &GlobalLSCost) {
19525   unsigned NumberOfSlices = LoadedSlices.size();
19526   // If there is less than 2 elements, no pairing is possible.
19527   if (NumberOfSlices < 2)
19528     return;
19529 
19530   // Sort the slices so that elements that are likely to be next to each
19531   // other in memory are next to each other in the list.
19532   llvm::sort(LoadedSlices, [](const LoadedSlice &LHS, const LoadedSlice &RHS) {
19533     assert(LHS.Origin == RHS.Origin && "Different bases not implemented.");
19534     return LHS.getOffsetFromBase() < RHS.getOffsetFromBase();
19535   });
19536   const TargetLowering &TLI = LoadedSlices[0].DAG->getTargetLoweringInfo();
19537   // First (resp. Second) is the first (resp. second) potential candidate
19538   // to be placed in a paired load.
19539   const LoadedSlice *First = nullptr;
19540   const LoadedSlice *Second = nullptr;
19541   for (unsigned CurrSlice = 0; CurrSlice < NumberOfSlices; ++CurrSlice,
19542                 // Set the beginning of the pair.
19543                                                            First = Second) {
19544     Second = &LoadedSlices[CurrSlice];
19545 
19546     // If First is NULL, it means we start a new pair.
19547     // Get to the next slice.
19548     if (!First)
19549       continue;
19550 
19551     EVT LoadedType = First->getLoadedType();
19552 
19553     // If the types of the slices are different, we cannot pair them.
19554     if (LoadedType != Second->getLoadedType())
19555       continue;
19556 
19557     // Check if the target supplies paired loads for this type.
19558     Align RequiredAlignment;
19559     if (!TLI.hasPairedLoad(LoadedType, RequiredAlignment)) {
19560       // Move to the next pair; this type is hopeless.
19561       Second = nullptr;
19562       continue;
19563     }
19564     // Check if we meet the alignment requirement.
19565     if (First->getAlign() < RequiredAlignment)
19566       continue;
19567 
19568     // Check that both loads are next to each other in memory.
19569     if (!areSlicesNextToEachOther(*First, *Second))
19570       continue;
19571 
19572     assert(GlobalLSCost.Loads > 0 && "We save more loads than we created!");
19573     --GlobalLSCost.Loads;
19574     // Move to the next pair.
19575     Second = nullptr;
19576   }
19577 }
19578 
19579 /// Check the profitability of all involved LoadedSlice.
19580 /// Currently, it is considered profitable if there are exactly two
19581 /// involved slices (1) which are (2) next to each other in memory, and
19582 /// whose cost (\see LoadedSlice::Cost) is smaller than the original load (3).
19583 ///
19584 /// Note: The order of the elements in \p LoadedSlices may be modified, but not
19585 /// the elements themselves.
19586 ///
19587 /// FIXME: When the cost model will be mature enough, we can relax
19588 /// constraints (1) and (2).
19589 static bool isSlicingProfitable(SmallVectorImpl<LoadedSlice> &LoadedSlices,
19590                                 const APInt &UsedBits, bool ForCodeSize) {
19591   unsigned NumberOfSlices = LoadedSlices.size();
19592   if (StressLoadSlicing)
19593     return NumberOfSlices > 1;
19594 
19595   // Check (1).
19596   if (NumberOfSlices != 2)
19597     return false;
19598 
19599   // Check (2).
19600   if (!areUsedBitsDense(UsedBits))
19601     return false;
19602 
19603   // Check (3).
19604   LoadedSlice::Cost OrigCost(ForCodeSize), GlobalSlicingCost(ForCodeSize);
19605   // The original code has one big load.
19606   OrigCost.Loads = 1;
19607   for (unsigned CurrSlice = 0; CurrSlice < NumberOfSlices; ++CurrSlice) {
19608     const LoadedSlice &LS = LoadedSlices[CurrSlice];
19609     // Accumulate the cost of all the slices.
19610     LoadedSlice::Cost SliceCost(LS, ForCodeSize);
19611     GlobalSlicingCost += SliceCost;
19612 
19613     // Account as cost in the original configuration the gain obtained
19614     // with the current slices.
19615     OrigCost.addSliceGain(LS);
19616   }
19617 
19618   // If the target supports paired load, adjust the cost accordingly.
19619   adjustCostForPairing(LoadedSlices, GlobalSlicingCost);
19620   return OrigCost > GlobalSlicingCost;
19621 }
19622 
19623 /// If the given load, \p LI, is used only by trunc or trunc(lshr)
19624 /// operations, split it into the various pieces being extracted.
19625 ///
19626 /// This sort of thing is introduced by SROA.
19627 /// This slicing takes care not to insert overlapping loads.
19628 /// \pre LI is a simple load (i.e., not an atomic or volatile load).
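///
/// For example (a sketch with illustrative types):
///   x  = load i32, ptr
///   lo = trunc x to i16
///   hi = trunc (srl x, 16) to i16
/// may be rewritten as two independent i16 loads from ptr and ptr+2 (with the
/// offsets mirrored on big-endian targets).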
19629 bool DAGCombiner::SliceUpLoad(SDNode *N) {
19630   if (Level < AfterLegalizeDAG)
19631     return false;
19632 
19633   LoadSDNode *LD = cast<LoadSDNode>(N);
19634   if (!LD->isSimple() || !ISD::isNormalLoad(LD) ||
19635       !LD->getValueType(0).isInteger())
19636     return false;
19637 
19638   // The algorithm to split up a load of a scalable vector into individual
19639   // elements currently requires knowing the length of the loaded type,
19640   // so will need adjusting to work on scalable vectors.
19641   if (LD->getValueType(0).isScalableVector())
19642     return false;
19643 
19644   // Keep track of already used bits to detect overlapping values.
19645   // In that case, we will just abort the transformation.
19646   APInt UsedBits(LD->getValueSizeInBits(0), 0);
19647 
19648   SmallVector<LoadedSlice, 4> LoadedSlices;
19649 
19650   // Check if this load is used as several smaller chunks of bits.
19651   // Basically, look for uses in trunc or trunc(lshr) and record a new chain
19652   // of computation for each trunc.
19653   for (SDNode::use_iterator UI = LD->use_begin(), UIEnd = LD->use_end();
19654        UI != UIEnd; ++UI) {
19655     // Skip the uses of the chain.
19656     if (UI.getUse().getResNo() != 0)
19657       continue;
19658 
19659     SDNode *User = *UI;
19660     unsigned Shift = 0;
19661 
19662     // Check if this is a trunc(lshr).
19663     if (User->getOpcode() == ISD::SRL && User->hasOneUse() &&
19664         isa<ConstantSDNode>(User->getOperand(1))) {
19665       Shift = User->getConstantOperandVal(1);
19666       User = *User->use_begin();
19667     }
19668 
19669     // At this point, User is a truncate iff we encountered trunc or
19670     // trunc(lshr).
19671     if (User->getOpcode() != ISD::TRUNCATE)
19672       return false;
19673 
19674     // The width of the type must be a power of 2 and at least 8 bits.
19675     // Otherwise the load cannot be represented in LLVM IR.
19676     // Moreover, if we shifted by a non-multiple of 8 bits, the slice
19677     // would cross byte boundaries. We do not support that.
19678     unsigned Width = User->getValueSizeInBits(0);
19679     if (Width < 8 || !isPowerOf2_32(Width) || (Shift & 0x7))
19680       return false;
19681 
19682     // Build the slice for this chain of computations.
19683     LoadedSlice LS(User, LD, Shift, &DAG);
19684     APInt CurrentUsedBits = LS.getUsedBits();
19685 
19686     // Check if this slice overlaps with another.
19687     if ((CurrentUsedBits & UsedBits) != 0)
19688       return false;
19689     // Update the bits used globally.
19690     UsedBits |= CurrentUsedBits;
19691 
19692     // Check if the new slice would be legal.
19693     if (!LS.isLegal())
19694       return false;
19695 
19696     // Record the slice.
19697     LoadedSlices.push_back(LS);
19698   }
19699 
19700   // Abort slicing if it does not seem to be profitable.
19701   if (!isSlicingProfitable(LoadedSlices, UsedBits, ForCodeSize))
19702     return false;
19703 
19704   ++SlicedLoads;
19705 
19706   // Rewrite each chain to use an independent load.
19707   // By construction, each chain can be represented by a unique load.
19708 
19709   // Prepare the argument for the new token factor for all the slices.
19710   SmallVector<SDValue, 8> ArgChains;
19711   for (const LoadedSlice &LS : LoadedSlices) {
19712     SDValue SliceInst = LS.loadSlice();
19713     CombineTo(LS.Inst, SliceInst, true);
19714     if (SliceInst.getOpcode() != ISD::LOAD)
19715       SliceInst = SliceInst.getOperand(0);
19716     assert(SliceInst->getOpcode() == ISD::LOAD &&
19717            "It takes more than a zext to get to the loaded slice!!");
19718     ArgChains.push_back(SliceInst.getValue(1));
19719   }
19720 
19721   SDValue Chain = DAG.getNode(ISD::TokenFactor, SDLoc(LD), MVT::Other,
19722                               ArgChains);
19723   DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Chain);
19724   AddToWorklist(Chain.getNode());
19725   return true;
19726 }
19727 
19728 /// Check to see if V is (and load (ptr), imm), where the load has specific
19729 /// bytes cleared out.  If so, return the number of bytes being masked out
19730 /// and the shift amount, both in bytes.
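///
/// Worked example (illustrative): for V = (and (load i32 ptr), 0xFFFF00FF)
/// the inverted mask is 0x0000FF00, i.e. exactly one byte is cleared at a
/// one-byte offset, so the result is {1, 1}.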
19731 static std::pair<unsigned, unsigned>
19732 CheckForMaskedLoad(SDValue V, SDValue Ptr, SDValue Chain) {
19733   std::pair<unsigned, unsigned> Result(0, 0);
19734 
19735   // Check for the structure we're looking for.
19736   if (V->getOpcode() != ISD::AND ||
19737       !isa<ConstantSDNode>(V->getOperand(1)) ||
19738       !ISD::isNormalLoad(V->getOperand(0).getNode()))
19739     return Result;
19740 
19741   // Check the chain and pointer.
19742   LoadSDNode *LD = cast<LoadSDNode>(V->getOperand(0));
19743   if (LD->getBasePtr() != Ptr) return Result;  // Not from same pointer.
19744 
19745   // This only handles simple types.
19746   if (V.getValueType() != MVT::i16 &&
19747       V.getValueType() != MVT::i32 &&
19748       V.getValueType() != MVT::i64)
19749     return Result;
19750 
19751   // Check the constant mask.  Invert it so that the bits being masked out are
19752   // 0 and the bits being kept are 1.  Use getSExtValue so that leading bits
19753   // follow the sign bit for uniformity.
19754   uint64_t NotMask = ~cast<ConstantSDNode>(V->getOperand(1))->getSExtValue();
19755   unsigned NotMaskLZ = llvm::countl_zero(NotMask);
19756   if (NotMaskLZ & 7) return Result;  // Must be multiple of a byte.
19757   unsigned NotMaskTZ = llvm::countr_zero(NotMask);
19758   if (NotMaskTZ & 7) return Result;  // Must be multiple of a byte.
19759   if (NotMaskLZ == 64) return Result;  // All zero mask.
19760 
19761   // See if we have a continuous run of bits.  If so, we have 0*1+0*
19762   if (llvm::countr_one(NotMask >> NotMaskTZ) + NotMaskTZ + NotMaskLZ != 64)
19763     return Result;
19764 
19765   // Adjust NotMaskLZ down to be from the actual size of the int instead of i64.
19766   if (V.getValueType() != MVT::i64 && NotMaskLZ)
19767     NotMaskLZ -= 64-V.getValueSizeInBits();
19768 
19769   unsigned MaskedBytes = (V.getValueSizeInBits()-NotMaskLZ-NotMaskTZ)/8;
19770   switch (MaskedBytes) {
19771   case 1:
19772   case 2:
19773   case 4: break;
19774   default: return Result; // All one mask, or 5-byte mask.
19775   }
19776 
19777   // Verify that the masked bytes start at a multiple of the mask size so that
19778   // the access is aligned the same as the access width.
19779   if (NotMaskTZ && NotMaskTZ/8 % MaskedBytes) return Result;
19780 
19781   // For narrowing to be valid, it must be the case that the load is the
19782   // immediately preceding memory operation before the store.
19783   if (LD == Chain.getNode())
19784     ; // ok.
19785   else if (Chain->getOpcode() == ISD::TokenFactor &&
19786            SDValue(LD, 1).hasOneUse()) {
19787     // LD has only 1 chain use, so there are no indirect dependencies.
19788     if (!LD->isOperandOf(Chain.getNode()))
19789       return Result;
19790   } else
19791     return Result; // Fail.
19792 
19793   Result.first = MaskedBytes;
19794   Result.second = NotMaskTZ/8;
19795   return Result;
19796 }
19797 
19798 /// Check to see if IVal is something that provides a value as specified by
19799 /// MaskInfo. If so, replace the specified store with a narrower store of
19800 /// truncated IVal.
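///
/// For example (a sketch, continuing the {1, 1} case above):
///   store (or (and (load p), 0xFFFF00FF), IVal), p
/// can become a one-byte store of (trunc (srl IVal, 8)) at p+1 on a
/// little-endian target, making the wide load/store pair dead.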
19801 static SDValue
19802 ShrinkLoadReplaceStoreWithStore(const std::pair<unsigned, unsigned> &MaskInfo,
19803                                 SDValue IVal, StoreSDNode *St,
19804                                 DAGCombiner *DC) {
19805   unsigned NumBytes = MaskInfo.first;
19806   unsigned ByteShift = MaskInfo.second;
19807   SelectionDAG &DAG = DC->getDAG();
19808 
19809   // Check to see if IVal is all zeros in the part being masked in by the 'or'
19810   // that uses this.  If not, this is not a replacement.
19811   APInt Mask = ~APInt::getBitsSet(IVal.getValueSizeInBits(),
19812                                   ByteShift*8, (ByteShift+NumBytes)*8);
19813   if (!DAG.MaskedValueIsZero(IVal, Mask)) return SDValue();
19814 
19815   // Check that it is legal on the target to do this.  It is legal if the new
19816   // VT we're shrinking to (i8/i16/i32) is legal or we're still before type
19817   // legalization. If the source type is legal, but the store type isn't, see
19818   // if we can use a truncating store.
19819   MVT VT = MVT::getIntegerVT(NumBytes * 8);
19820   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
19821   bool UseTruncStore;
19822   if (DC->isTypeLegal(VT))
19823     UseTruncStore = false;
19824   else if (TLI.isTypeLegal(IVal.getValueType()) &&
19825            TLI.isTruncStoreLegal(IVal.getValueType(), VT))
19826     UseTruncStore = true;
19827   else
19828     return SDValue();
19829 
19830   // Can't do this for indexed stores.
19831   if (St->isIndexed())
19832     return SDValue();
19833 
19834   // Check that the target doesn't think this is a bad idea.
19835   if (St->getMemOperand() &&
19836       !TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT,
19837                               *St->getMemOperand()))
19838     return SDValue();
19839 
19840   // Okay, we can do this!  Replace the 'St' store with a store of IVal that is
19841   // shifted by ByteShift and truncated down to NumBytes.
19842   if (ByteShift) {
19843     SDLoc DL(IVal);
19844     IVal = DAG.getNode(
19845         ISD::SRL, DL, IVal.getValueType(), IVal,
19846         DAG.getShiftAmountConstant(ByteShift * 8, IVal.getValueType(), DL));
19847   }
19848 
19849   // Figure out the offset for the store and the alignment of the access.
19850   unsigned StOffset;
19851   if (DAG.getDataLayout().isLittleEndian())
19852     StOffset = ByteShift;
19853   else
19854     StOffset = IVal.getValueType().getStoreSize() - ByteShift - NumBytes;
19855 
19856   SDValue Ptr = St->getBasePtr();
19857   if (StOffset) {
19858     SDLoc DL(IVal);
19859     Ptr = DAG.getMemBasePlusOffset(Ptr, TypeSize::getFixed(StOffset), DL);
19860   }
19861 
19862   ++OpsNarrowed;
19863   if (UseTruncStore)
19864     return DAG.getTruncStore(St->getChain(), SDLoc(St), IVal, Ptr,
19865                              St->getPointerInfo().getWithOffset(StOffset),
19866                              VT, St->getOriginalAlign());
19867 
19868   // Truncate down to the new size.
19869   IVal = DAG.getNode(ISD::TRUNCATE, SDLoc(IVal), VT, IVal);
19870 
19871   return DAG
19872       .getStore(St->getChain(), SDLoc(St), IVal, Ptr,
19873                 St->getPointerInfo().getWithOffset(StOffset),
19874                 St->getOriginalAlign());
19875 }
19876 
19877 /// Look for a sequence of load / op / store where op is one of 'or', 'xor',
19878 /// and 'and' of immediates. If 'op' only touches some of the loaded bits, try
19879 /// narrowing the load and store if it would end up being a win for performance
19880 /// or code size.
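///
/// For example (illustrative types and constants):
///   x = load i32, p
///   store (or x, 0xFF00), p
/// only changes byte 1 of the stored word, so on a little-endian target it can
/// be narrowed to an i8 load/or/store at p+1.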
19881 SDValue DAGCombiner::ReduceLoadOpStoreWidth(SDNode *N) {
19882   StoreSDNode *ST  = cast<StoreSDNode>(N);
19883   if (!ST->isSimple())
19884     return SDValue();
19885 
19886   SDValue Chain = ST->getChain();
19887   SDValue Value = ST->getValue();
19888   SDValue Ptr   = ST->getBasePtr();
19889   EVT VT = Value.getValueType();
19890 
19891   if (ST->isTruncatingStore() || VT.isVector())
19892     return SDValue();
19893 
19894   unsigned Opc = Value.getOpcode();
19895 
19896   if ((Opc != ISD::OR && Opc != ISD::XOR && Opc != ISD::AND) ||
19897       !Value.hasOneUse())
19898     return SDValue();
19899 
19900   // If this is "store (or X, Y), P" and X is "(and (load P), cst)", where cst
19901   // is a byte mask indicating a consecutive number of bytes, check to see if
19902   // Y is known to provide just those bytes.  If so, we try to replace the
19903   // load + replace + store sequence with a single (narrower) store, which makes
19904   // the load dead.
19905   if (Opc == ISD::OR && EnableShrinkLoadReplaceStoreWithStore) {
19906     std::pair<unsigned, unsigned> MaskedLoad;
19907     MaskedLoad = CheckForMaskedLoad(Value.getOperand(0), Ptr, Chain);
19908     if (MaskedLoad.first)
19909       if (SDValue NewST = ShrinkLoadReplaceStoreWithStore(MaskedLoad,
19910                                                   Value.getOperand(1), ST,this))
19911         return NewST;
19912 
19913     // Or is commutative, so try swapping X and Y.
19914     MaskedLoad = CheckForMaskedLoad(Value.getOperand(1), Ptr, Chain);
19915     if (MaskedLoad.first)
19916       if (SDValue NewST = ShrinkLoadReplaceStoreWithStore(MaskedLoad,
19917                                                   Value.getOperand(0), ST,this))
19918         return NewST;
19919   }
19920 
19921   if (!EnableReduceLoadOpStoreWidth)
19922     return SDValue();
19923 
19924   if (Value.getOperand(1).getOpcode() != ISD::Constant)
19925     return SDValue();
19926 
19927   SDValue N0 = Value.getOperand(0);
19928   if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
19929       Chain == SDValue(N0.getNode(), 1)) {
19930     LoadSDNode *LD = cast<LoadSDNode>(N0);
19931     if (LD->getBasePtr() != Ptr ||
19932         LD->getPointerInfo().getAddrSpace() !=
19933         ST->getPointerInfo().getAddrSpace())
19934       return SDValue();
19935 
19936     // Find the type to which to narrow the load / op / store.
19937     SDValue N1 = Value.getOperand(1);
19938     unsigned BitWidth = N1.getValueSizeInBits();
19939     APInt Imm = N1->getAsAPIntVal();
19940     if (Opc == ISD::AND)
19941       Imm ^= APInt::getAllOnes(BitWidth);
19942     if (Imm == 0 || Imm.isAllOnes())
19943       return SDValue();
19944     unsigned ShAmt = Imm.countr_zero();
19945     unsigned MSB = BitWidth - Imm.countl_zero() - 1;
19946     unsigned NewBW = NextPowerOf2(MSB - ShAmt);
19947     EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), NewBW);
19948     // The narrowing should be profitable, the load/store operation should be
19949     // legal (or custom) and the store size should be equal to the NewVT width.
19950     while (NewBW < BitWidth &&
19951            (NewVT.getStoreSizeInBits() != NewBW ||
19952             !TLI.isOperationLegalOrCustom(Opc, NewVT) ||
19953             !TLI.isNarrowingProfitable(VT, NewVT))) {
19954       NewBW = NextPowerOf2(NewBW);
19955       NewVT = EVT::getIntegerVT(*DAG.getContext(), NewBW);
19956     }
19957     if (NewBW >= BitWidth)
19958       return SDValue();
19959 
19960     // If the changed lsb does not start at a type-bitwidth boundary,
19961     // start at the previous one.
19962     if (ShAmt % NewBW)
19963       ShAmt = (((ShAmt + NewBW - 1) / NewBW) * NewBW) - NewBW;
19964     APInt Mask = APInt::getBitsSet(BitWidth, ShAmt,
19965                                    std::min(BitWidth, ShAmt + NewBW));
19966     if ((Imm & Mask) == Imm) {
19967       APInt NewImm = (Imm & Mask).lshr(ShAmt).trunc(NewBW);
19968       if (Opc == ISD::AND)
19969         NewImm ^= APInt::getAllOnes(NewBW);
19970       uint64_t PtrOff = ShAmt / 8;
19971       // For big endian targets, we need to adjust the offset to the pointer to
19972       // load the correct bytes.
19973       if (DAG.getDataLayout().isBigEndian())
19974         PtrOff = (BitWidth + 7 - NewBW) / 8 - PtrOff;
19975 
19976       unsigned IsFast = 0;
19977       Align NewAlign = commonAlignment(LD->getAlign(), PtrOff);
19978       if (!TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), NewVT,
19979                                   LD->getAddressSpace(), NewAlign,
19980                                   LD->getMemOperand()->getFlags(), &IsFast) ||
19981           !IsFast)
19982         return SDValue();
19983 
19984       SDValue NewPtr =
19985           DAG.getMemBasePlusOffset(Ptr, TypeSize::getFixed(PtrOff), SDLoc(LD));
19986       SDValue NewLD =
19987           DAG.getLoad(NewVT, SDLoc(N0), LD->getChain(), NewPtr,
19988                       LD->getPointerInfo().getWithOffset(PtrOff), NewAlign,
19989                       LD->getMemOperand()->getFlags(), LD->getAAInfo());
19990       SDValue NewVal = DAG.getNode(Opc, SDLoc(Value), NewVT, NewLD,
19991                                    DAG.getConstant(NewImm, SDLoc(Value),
19992                                                    NewVT));
19993       SDValue NewST =
19994           DAG.getStore(Chain, SDLoc(N), NewVal, NewPtr,
19995                        ST->getPointerInfo().getWithOffset(PtrOff), NewAlign);
19996 
19997       AddToWorklist(NewPtr.getNode());
19998       AddToWorklist(NewLD.getNode());
19999       AddToWorklist(NewVal.getNode());
20000       WorklistRemover DeadNodes(*this);
20001       DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), NewLD.getValue(1));
20002       ++OpsNarrowed;
20003       return NewST;
20004     }
20005   }
20006 
20007   return SDValue();
20008 }
20009 
20010 /// For a given floating point load / store pair, if the load value isn't used
20011 /// by any other operations, then consider transforming the pair to integer
20012 /// load / store operations if the target deems the transformation profitable.
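///
/// For example (a sketch): on such a target,
///   f = load float, p
///   store f, q
/// may become
///   i = load i32, p
///   store i, q
/// avoiding a round trip through the floating-point register file.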
20013 SDValue DAGCombiner::TransformFPLoadStorePair(SDNode *N) {
20014   StoreSDNode *ST  = cast<StoreSDNode>(N);
20015   SDValue Value = ST->getValue();
20016   if (ISD::isNormalStore(ST) && ISD::isNormalLoad(Value.getNode()) &&
20017       Value.hasOneUse()) {
20018     LoadSDNode *LD = cast<LoadSDNode>(Value);
20019     EVT VT = LD->getMemoryVT();
20020     if (!VT.isFloatingPoint() ||
20021         VT != ST->getMemoryVT() ||
20022         LD->isNonTemporal() ||
20023         ST->isNonTemporal() ||
20024         LD->getPointerInfo().getAddrSpace() != 0 ||
20025         ST->getPointerInfo().getAddrSpace() != 0)
20026       return SDValue();
20027 
20028     TypeSize VTSize = VT.getSizeInBits();
20029 
20030     // We don't know the size of scalable types at compile time so we cannot
20031     // create an integer of the equivalent size.
20032     if (VTSize.isScalable())
20033       return SDValue();
20034 
20035     unsigned FastLD = 0, FastST = 0;
20036     EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), VTSize.getFixedValue());
20037     if (!TLI.isOperationLegal(ISD::LOAD, IntVT) ||
20038         !TLI.isOperationLegal(ISD::STORE, IntVT) ||
20039         !TLI.isDesirableToTransformToIntegerOp(ISD::LOAD, VT) ||
20040         !TLI.isDesirableToTransformToIntegerOp(ISD::STORE, VT) ||
20041         !TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), IntVT,
20042                                 *LD->getMemOperand(), &FastLD) ||
20043         !TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), IntVT,
20044                                 *ST->getMemOperand(), &FastST) ||
20045         !FastLD || !FastST)
20046       return SDValue();
20047 
20048     SDValue NewLD =
20049         DAG.getLoad(IntVT, SDLoc(Value), LD->getChain(), LD->getBasePtr(),
20050                     LD->getPointerInfo(), LD->getAlign());
20051 
20052     SDValue NewST =
20053         DAG.getStore(ST->getChain(), SDLoc(N), NewLD, ST->getBasePtr(),
20054                      ST->getPointerInfo(), ST->getAlign());
20055 
20056     AddToWorklist(NewLD.getNode());
20057     AddToWorklist(NewST.getNode());
20058     WorklistRemover DeadNodes(*this);
20059     DAG.ReplaceAllUsesOfValueWith(Value.getValue(1), NewLD.getValue(1));
20060     ++LdStFP2Int;
20061     return NewST;
20062   }
20063 
20064   return SDValue();
20065 }
20066 
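// Illustrative sketch (editor's example, not from the original source): for
// a copy such as "*q = *p" of f64 values, the pair
//   t0 = load f64 %p
//   store f64 t0, %q
// can be rewritten, when the target reports the integer forms legal, fast,
// and desirable, into
//   t1 = load i64 %p
//   store i64 t1, %q
// so the copied value never occupies a floating-point register.
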
20067 // This is a helper function for visitMUL to check the profitability
20068 // of folding (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2).
20069 // MulNode is the original multiply, AddNode is (add x, c1),
20070 // and ConstNode is c2.
20071 //
20072 // If the (add x, c1) has multiple uses, we could increase
20073 // the number of adds if we make this transformation.
20074 // It would only be worth doing this if we can remove a
20075 // multiply in the process. Check for that here.
20076 // To illustrate:
20077 //     (A + c1) * c3
20078 //     (A + c2) * c3
20079 // We're checking for cases where we have common "c3 * A" expressions.
20080 bool DAGCombiner::isMulAddWithConstProfitable(SDNode *MulNode, SDValue AddNode,
20081                                               SDValue ConstNode) {
20082   APInt Val;
20083 
20084   // If the add has only one use, and the target thinks the folding is
20085   // profitable or at least does not lead to worse code, it is OK to do.
20086   if (AddNode->hasOneUse() &&
20087       TLI.isMulAddWithConstProfitable(AddNode, ConstNode))
20088     return true;
20089 
20090   // Walk all the users of the constant with which we're multiplying.
20091   for (SDNode *Use : ConstNode->uses()) {
20092     if (Use == MulNode) // This use is the one we're on right now. Skip it.
20093       continue;
20094 
20095     if (Use->getOpcode() == ISD::MUL) { // We have another multiply use.
20096       SDNode *OtherOp;
20097       SDNode *MulVar = AddNode.getOperand(0).getNode();
20098 
20099       // OtherOp is what we're multiplying against the constant.
20100       if (Use->getOperand(0) == ConstNode)
20101         OtherOp = Use->getOperand(1).getNode();
20102       else
20103         OtherOp = Use->getOperand(0).getNode();
20104 
20105       // Check to see if the multiply is by the same operand as our "add".
20106       //
20107       //     ConstNode  = CONST
20108       //     Use = ConstNode * A  <-- visiting Use. OtherOp is A.
20109       //     ...
20110       //     AddNode  = (A + c1)  <-- MulVar is A.
20111       //         = AddNode * ConstNode   <-- current visiting instruction.
20112       //
20113       // If we make this transformation, we will have a common
20114       // multiply (ConstNode * A) that we can save.
20115       if (OtherOp == MulVar)
20116         return true;
20117 
20118       // Now check to see if a future expansion will give us a common
20119       // multiply.
20120       //
20121       //     ConstNode  = CONST
20122       //     AddNode    = (A + c1)
20123       //     ...   = AddNode * ConstNode <-- current visiting instruction.
20124       //     ...
20125       //     OtherOp = (A + c2)
20126       //     Use     = OtherOp * ConstNode <-- visiting Use.
20127       //
20128       // If we make this transformation, we will have a common
20129       // multiply (CONST * A) after we also do the same transformation
20130       // to the "t2" instruction.
20131       if (OtherOp->getOpcode() == ISD::ADD &&
20132           DAG.isConstantIntBuildVectorOrConstantInt(OtherOp->getOperand(1)) &&
20133           OtherOp->getOperand(0).getNode() == MulVar)
20134         return true;
20135     }
20136   }
20137 
20138   // Didn't find a case where this would be profitable.
20139   return false;
20140 }
20141 
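// Worked example (editor's sketch): suppose ConstNode c3 already feeds two
// multiplies:
//   X = (A + c1) * c3
//   Y = (A + c2) * c3
// Folding both to
//   X = (A * c3) + c1*c3
//   Y = (A * c3) + c2*c3
// duplicates the adds but exposes the common subexpression (A * c3), so one
// multiply can be reused and the transformation is still a net win.
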
20142 SDValue DAGCombiner::getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes,
20143                                          unsigned NumStores) {
20144   SmallVector<SDValue, 8> Chains;
20145   SmallPtrSet<const SDNode *, 8> Visited;
20146   SDLoc StoreDL(StoreNodes[0].MemNode);
20147 
20148   for (unsigned i = 0; i < NumStores; ++i) {
20149     Visited.insert(StoreNodes[i].MemNode);
20150   }
20151 
20152   // Don't include nodes that are children (candidate stores) or repeated nodes.
20153   for (unsigned i = 0; i < NumStores; ++i) {
20154     if (Visited.insert(StoreNodes[i].MemNode->getChain().getNode()).second)
20155       Chains.push_back(StoreNodes[i].MemNode->getChain());
20156   }
20157 
20158   assert(!Chains.empty() && "Chain should have generated a chain");
20159   return DAG.getTokenFactor(StoreDL, Chains);
20160 }
20161 
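// Illustrative sketch (editor's example): if the stores being merged hang
// off chains Ch0, Ch0, and Ch1, the merged store's chain becomes
//   TokenFactor Ch0, Ch1
// Duplicates are dropped via the Visited set, and chains that are
// themselves candidate stores are skipped, so the merged store never
// depends on a node it is about to replace.
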
20162 bool DAGCombiner::hasSameUnderlyingObj(ArrayRef<MemOpLink> StoreNodes) {
20163   const Value *UnderlyingObj = nullptr;
20164   for (const auto &MemOp : StoreNodes) {
20165     const MachineMemOperand *MMO = MemOp.MemNode->getMemOperand();
20166     // A pseudo value such as a stack frame has its own frame index and
20167     // size; we must not reuse the first store's frame index for other frames.
20168     if (MMO->getPseudoValue())
20169       return false;
20170 
20171     if (!MMO->getValue())
20172       return false;
20173 
20174     const Value *Obj = getUnderlyingObject(MMO->getValue());
20175 
20176     if (UnderlyingObj && UnderlyingObj != Obj)
20177       return false;
20178 
20179     if (!UnderlyingObj)
20180       UnderlyingObj = Obj;
20181   }
20182 
20183   return true;
20184 }
20185 
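// Illustrative sketch (editor's example): stores to &Buf[0] and &Buf[4] of
// some hypothetical object Buf both strip back to the same underlying
// object via getUnderlyingObject, so the merged store may keep the first
// store's MachinePointerInfo. Stores into distinct objects (or frame-index
// pseudo values) must not, because the widened access would span memory the
// original pointer info never described.
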
20186 bool DAGCombiner::mergeStoresOfConstantsOrVecElts(
20187     SmallVectorImpl<MemOpLink> &StoreNodes, EVT MemVT, unsigned NumStores,
20188     bool IsConstantSrc, bool UseVector, bool UseTrunc) {
20189   // Make sure we have something to merge.
20190   if (NumStores < 2)
20191     return false;
20192 
20193   assert((!UseTrunc || !UseVector) &&
20194          "This optimization cannot emit a vector truncating store");
20195 
20196   // The latest Node in the DAG.
20197   SDLoc DL(StoreNodes[0].MemNode);
20198 
20199   TypeSize ElementSizeBits = MemVT.getStoreSizeInBits();
20200   unsigned SizeInBits = NumStores * ElementSizeBits;
20201   unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
20202 
20203   std::optional<MachineMemOperand::Flags> Flags;
20204   AAMDNodes AAInfo;
20205   for (unsigned I = 0; I != NumStores; ++I) {
20206     StoreSDNode *St = cast<StoreSDNode>(StoreNodes[I].MemNode);
20207     if (!Flags) {
20208       Flags = St->getMemOperand()->getFlags();
20209       AAInfo = St->getAAInfo();
20210       continue;
20211     }
20212     // Skip merging if there's an inconsistent flag.
20213     if (Flags != St->getMemOperand()->getFlags())
20214       return false;
20215     // Concatenate AA metadata.
20216     AAInfo = AAInfo.concat(St->getAAInfo());
20217   }
20218 
20219   EVT StoreTy;
20220   if (UseVector) {
20221     unsigned Elts = NumStores * NumMemElts;
20222     // Get the type for the merged vector store.
20223     StoreTy = EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts);
20224   } else
20225     StoreTy = EVT::getIntegerVT(*DAG.getContext(), SizeInBits);
20226 
20227   SDValue StoredVal;
20228   if (UseVector) {
20229     if (IsConstantSrc) {
20230       SmallVector<SDValue, 8> BuildVector;
20231       for (unsigned I = 0; I != NumStores; ++I) {
20232         StoreSDNode *St = cast<StoreSDNode>(StoreNodes[I].MemNode);
20233         SDValue Val = St->getValue();
20234         // If constant is of the wrong type, convert it now.  This comes up
20235         // when one of our stores was truncating.
20236         if (MemVT != Val.getValueType()) {
20237           Val = peekThroughBitcasts(Val);
20238           // Deal with constants of wrong size.
20239           if (ElementSizeBits != Val.getValueSizeInBits()) {
20240             auto *C = dyn_cast<ConstantSDNode>(Val);
20241             if (!C)
20242               // Not clear how to truncate FP values.
20243               // TODO: Handle truncation of build_vector constants
20244               return false;
20245 
20246             EVT IntMemVT =
20247                 EVT::getIntegerVT(*DAG.getContext(), MemVT.getSizeInBits());
20248             Val = DAG.getConstant(C->getAPIntValue()
20249                                       .zextOrTrunc(Val.getValueSizeInBits())
20250                                       .zextOrTrunc(ElementSizeBits),
20251                                   SDLoc(C), IntMemVT);
20252           }
20253           // Bitcast so the correctly-sized value also has the correct type.
20254           Val = DAG.getBitcast(MemVT, Val);
20255         }
20256         BuildVector.push_back(Val);
20257       }
20258       StoredVal = DAG.getNode(MemVT.isVector() ? ISD::CONCAT_VECTORS
20259                                                : ISD::BUILD_VECTOR,
20260                               DL, StoreTy, BuildVector);
20261     } else {
20262       SmallVector<SDValue, 8> Ops;
20263       for (unsigned i = 0; i < NumStores; ++i) {
20264         StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
20265         SDValue Val = peekThroughBitcasts(St->getValue());
20266         // All operands of BUILD_VECTOR / CONCAT_VECTOR must be of
20267         // type MemVT. If the underlying value is not the correct
20268         // type, but it is an extraction of an appropriate vector we
20269         // can recast Val to be of the correct type. This may require
20270         // converting between EXTRACT_VECTOR_ELT and
20271         // EXTRACT_SUBVECTOR.
20272         if ((MemVT != Val.getValueType()) &&
20273             (Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT ||
20274              Val.getOpcode() == ISD::EXTRACT_SUBVECTOR)) {
20275           EVT MemVTScalarTy = MemVT.getScalarType();
20276           // We may need to add a bitcast here to get types to line up.
20277           if (MemVTScalarTy != Val.getValueType().getScalarType()) {
20278             Val = DAG.getBitcast(MemVT, Val);
20279           } else if (MemVT.isVector() &&
20280                      Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
20281             Val = DAG.getNode(ISD::BUILD_VECTOR, DL, MemVT, Val);
20282           } else {
20283             unsigned OpC = MemVT.isVector() ? ISD::EXTRACT_SUBVECTOR
20284                                             : ISD::EXTRACT_VECTOR_ELT;
20285             SDValue Vec = Val.getOperand(0);
20286             SDValue Idx = Val.getOperand(1);
20287             Val = DAG.getNode(OpC, SDLoc(Val), MemVT, Vec, Idx);
20288           }
20289         }
20290         Ops.push_back(Val);
20291       }
20292 
20293       // Build the extracted vector elements back into a vector.
20294       StoredVal = DAG.getNode(MemVT.isVector() ? ISD::CONCAT_VECTORS
20295                                                : ISD::BUILD_VECTOR,
20296                               DL, StoreTy, Ops);
20297     }
20298   } else {
20299     // We should always use a vector store when merging extracted vector
20300     // elements, so this path implies a store of constants.
20301     assert(IsConstantSrc && "Merged vector elements should use vector store");
20302 
20303     APInt StoreInt(SizeInBits, 0);
20304 
20305     // Construct a single integer constant which is made of the smaller
20306     // constant inputs.
20307     bool IsLE = DAG.getDataLayout().isLittleEndian();
20308     for (unsigned i = 0; i < NumStores; ++i) {
20309       unsigned Idx = IsLE ? (NumStores - 1 - i) : i;
20310       StoreSDNode *St  = cast<StoreSDNode>(StoreNodes[Idx].MemNode);
20311 
20312       SDValue Val = St->getValue();
20313       Val = peekThroughBitcasts(Val);
20314       StoreInt <<= ElementSizeBits;
20315       if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val)) {
20316         StoreInt |= C->getAPIntValue()
20317                         .zextOrTrunc(ElementSizeBits)
20318                         .zextOrTrunc(SizeInBits);
20319       } else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(Val)) {
20320         StoreInt |= C->getValueAPF()
20321                         .bitcastToAPInt()
20322                         .zextOrTrunc(ElementSizeBits)
20323                         .zextOrTrunc(SizeInBits);
20324         // If FP truncation is necessary, give up for now.
20325         if (MemVT.getSizeInBits() != ElementSizeBits)
20326           return false;
20327       } else if (ISD::isBuildVectorOfConstantSDNodes(Val.getNode()) ||
20328                  ISD::isBuildVectorOfConstantFPSDNodes(Val.getNode())) {
20329         // Not yet handled
20330         return false;
20331       } else {
20332         llvm_unreachable("Invalid constant element type");
20333       }
20334     }
20335 
20336     // Create the new Load and Store operations.
20337     StoredVal = DAG.getConstant(StoreInt, DL, StoreTy);
20338   }
20339 
20340   LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
20341   SDValue NewChain = getMergeStoreChains(StoreNodes, NumStores);
20342   bool CanReusePtrInfo = hasSameUnderlyingObj(StoreNodes);
20343 
20344   // Make sure we use a truncating store if that is necessary for legality.
20345   // When generating the new widened store, if the first store's pointer info
20346   // cannot be reused, discard it (keeping only the address space), because
20347   // the widened store can no longer be represented by the original pointer
20348   // info, which describes the narrower memory object.
20349   SDValue NewStore;
20350   if (!UseTrunc) {
20351     NewStore = DAG.getStore(
20352         NewChain, DL, StoredVal, FirstInChain->getBasePtr(),
20353         CanReusePtrInfo
20354             ? FirstInChain->getPointerInfo()
20355             : MachinePointerInfo(FirstInChain->getPointerInfo().getAddrSpace()),
20356         FirstInChain->getAlign(), *Flags, AAInfo);
20357   } else { // Must be realized as a trunc store
20358     EVT LegalizedStoredValTy =
20359         TLI.getTypeToTransformTo(*DAG.getContext(), StoredVal.getValueType());
20360     unsigned LegalizedStoreSize = LegalizedStoredValTy.getSizeInBits();
20361     ConstantSDNode *C = cast<ConstantSDNode>(StoredVal);
20362     SDValue ExtendedStoreVal =
20363         DAG.getConstant(C->getAPIntValue().zextOrTrunc(LegalizedStoreSize), DL,
20364                         LegalizedStoredValTy);
20365     NewStore = DAG.getTruncStore(
20366         NewChain, DL, ExtendedStoreVal, FirstInChain->getBasePtr(),
20367         CanReusePtrInfo
20368             ? FirstInChain->getPointerInfo()
20369             : MachinePointerInfo(FirstInChain->getPointerInfo().getAddrSpace()),
20370         StoredVal.getValueType() /*TVT*/, FirstInChain->getAlign(), *Flags,
20371         AAInfo);
20372   }
20373 
20374   // Replace all merged stores with the new store.
20375   for (unsigned i = 0; i < NumStores; ++i)
20376     CombineTo(StoreNodes[i].MemNode, NewStore);
20377 
20378   AddToWorklist(NewChain.getNode());
20379   return true;
20380 }
20381 
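// Worked example (editor's sketch, assuming a little-endian target): merging
//   store i8 0x12, p+0
//   store i8 0x34, p+1
//   store i8 0x56, p+2
//   store i8 0x78, p+3
// folds the constants starting from the highest address, producing
//   store i32 0x78563412, p+0
// whose low byte 0x12 still lands at p+0. On a big-endian target the
// constants are folded in source order instead.
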
20382 SDNode *
20383 DAGCombiner::getStoreMergeCandidates(StoreSDNode *St,
20384                                      SmallVectorImpl<MemOpLink> &StoreNodes) {
20385   // This holds the base pointer, index, and the offset in bytes from the base
20386   // pointer. We must have a base and an offset. Do not handle stores to undef
20387   // base pointers.
20388   BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);
20389   if (!BasePtr.getBase().getNode() || BasePtr.getBase().isUndef())
20390     return nullptr;
20391 
20392   SDValue Val = peekThroughBitcasts(St->getValue());
20393   StoreSource StoreSrc = getStoreSource(Val);
20394   assert(StoreSrc != StoreSource::Unknown && "Expected known source for store");
20395 
20396   // Match on loadbaseptr if relevant.
20397   EVT MemVT = St->getMemoryVT();
20398   BaseIndexOffset LBasePtr;
20399   EVT LoadVT;
20400   if (StoreSrc == StoreSource::Load) {
20401     auto *Ld = cast<LoadSDNode>(Val);
20402     LBasePtr = BaseIndexOffset::match(Ld, DAG);
20403     LoadVT = Ld->getMemoryVT();
20404     // Load and store should be the same type.
20405     if (MemVT != LoadVT)
20406       return nullptr;
20407     // Loads must only have one use.
20408     if (!Ld->hasNUsesOfValue(1, 0))
20409       return nullptr;
20410     // The memory operands must not be volatile/indexed/atomic.
20411     // TODO: May be able to relax for unordered atomics (see D66309)
20412     if (!Ld->isSimple() || Ld->isIndexed())
20413       return nullptr;
20414   }
20415   auto CandidateMatch = [&](StoreSDNode *Other, BaseIndexOffset &Ptr,
20416                             int64_t &Offset) -> bool {
20417     // The memory operands must not be volatile/indexed/atomic.
20418     // TODO: May be able to relax for unordered atomics (see D66309)
20419     if (!Other->isSimple() || Other->isIndexed())
20420       return false;
20421     // Don't mix temporal stores with non-temporal stores.
20422     if (St->isNonTemporal() != Other->isNonTemporal())
20423       return false;
20424     if (!TLI.areTwoSDNodeTargetMMOFlagsMergeable(*St, *Other))
20425       return false;
20426     SDValue OtherBC = peekThroughBitcasts(Other->getValue());
20427     // Allow merging constants of different types as integers.
20428     bool NoTypeMatch = (MemVT.isInteger()) ? !MemVT.bitsEq(Other->getMemoryVT())
20429                                            : Other->getMemoryVT() != MemVT;
20430     switch (StoreSrc) {
20431     case StoreSource::Load: {
20432       if (NoTypeMatch)
20433         return false;
20434       // The Load's Base Ptr must also match.
20435       auto *OtherLd = dyn_cast<LoadSDNode>(OtherBC);
20436       if (!OtherLd)
20437         return false;
20438       BaseIndexOffset LPtr = BaseIndexOffset::match(OtherLd, DAG);
20439       if (LoadVT != OtherLd->getMemoryVT())
20440         return false;
20441       // Loads must only have one use.
20442       if (!OtherLd->hasNUsesOfValue(1, 0))
20443         return false;
20444       // The memory operands must not be volatile/indexed/atomic.
20445       // TODO: May be able to relax for unordered atomics (see D66309)
20446       if (!OtherLd->isSimple() || OtherLd->isIndexed())
20447         return false;
20448       // Don't mix temporal loads with non-temporal loads.
20449       if (cast<LoadSDNode>(Val)->isNonTemporal() != OtherLd->isNonTemporal())
20450         return false;
20451       if (!TLI.areTwoSDNodeTargetMMOFlagsMergeable(*cast<LoadSDNode>(Val),
20452                                                    *OtherLd))
20453         return false;
20454       if (!(LBasePtr.equalBaseIndex(LPtr, DAG)))
20455         return false;
20456       break;
20457     }
20458     case StoreSource::Constant:
20459       if (NoTypeMatch)
20460         return false;
20461       if (getStoreSource(OtherBC) != StoreSource::Constant)
20462         return false;
20463       break;
20464     case StoreSource::Extract:
20465       // Do not merge truncated stores here.
20466       if (Other->isTruncatingStore())
20467         return false;
20468       if (!MemVT.bitsEq(OtherBC.getValueType()))
20469         return false;
20470       if (OtherBC.getOpcode() != ISD::EXTRACT_VECTOR_ELT &&
20471           OtherBC.getOpcode() != ISD::EXTRACT_SUBVECTOR)
20472         return false;
20473       break;
20474     default:
20475       llvm_unreachable("Unhandled store source for merging");
20476     }
20477     Ptr = BaseIndexOffset::match(Other, DAG);
20478     return (BasePtr.equalBaseIndex(Ptr, DAG, Offset));
20479   };
20480 
20481   // We are looking for a root node which is an ancestor to all mergeable
20482   // stores. We search up through a load, to our root and then down
20483   // through all children. For instance we will find Store{1,2,3} if
20484   // St is Store1, Store2, or Store3 where the root is not a load,
20485   // which is always true for nonvolatile ops. TODO: Expand
20486   // the search to find all valid candidates through multiple layers of loads.
20487   //
20488   // Root
20489   // |-------|-------|
20490   // Load    Load    Store3
20491   // |       |
20492 // Store1  Store2
20493   //
20494   // FIXME: We should be able to climb and
20495   // descend TokenFactors to find candidates as well.
20496 
20497   SDNode *RootNode = St->getChain().getNode();
20498   // Bail out if we already analyzed this root node and found nothing.
20499   if (ChainsWithoutMergeableStores.contains(RootNode))
20500     return nullptr;
20501 
20502   // Check whether this pair of StoreNode and RootNode has already bailed out
20503   // of the dependence check more times than the limit allows.
20504   auto OverLimitInDependenceCheck = [&](SDNode *StoreNode,
20505                                         SDNode *RootNode) -> bool {
20506     auto RootCount = StoreRootCountMap.find(StoreNode);
20507     return RootCount != StoreRootCountMap.end() &&
20508            RootCount->second.first == RootNode &&
20509            RootCount->second.second > StoreMergeDependenceLimit;
20510   };
20511 
20512   auto TryToAddCandidate = [&](SDNode::use_iterator UseIter) {
20513     // This must be a chain use.
20514     if (UseIter.getOperandNo() != 0)
20515       return;
20516     if (auto *OtherStore = dyn_cast<StoreSDNode>(*UseIter)) {
20517       BaseIndexOffset Ptr;
20518       int64_t PtrDiff;
20519       if (CandidateMatch(OtherStore, Ptr, PtrDiff) &&
20520           !OverLimitInDependenceCheck(OtherStore, RootNode))
20521         StoreNodes.push_back(MemOpLink(OtherStore, PtrDiff));
20522     }
20523   };
20524 
20525   unsigned NumNodesExplored = 0;
20526   const unsigned MaxSearchNodes = 1024;
20527   if (auto *Ldn = dyn_cast<LoadSDNode>(RootNode)) {
20528     RootNode = Ldn->getChain().getNode();
20529     // Bail out if we already analyzed this root node and found nothing.
20530     if (ChainsWithoutMergeableStores.contains(RootNode))
20531       return nullptr;
20532     for (auto I = RootNode->use_begin(), E = RootNode->use_end();
20533          I != E && NumNodesExplored < MaxSearchNodes; ++I, ++NumNodesExplored) {
20534       if (I.getOperandNo() == 0 && isa<LoadSDNode>(*I)) { // walk down chain
20535         for (auto I2 = (*I)->use_begin(), E2 = (*I)->use_end(); I2 != E2; ++I2)
20536           TryToAddCandidate(I2);
20537       }
20538       // Check stores that depend on the root (e.g. Store 3 in the chart above).
20539       if (I.getOperandNo() == 0 && isa<StoreSDNode>(*I)) {
20540         TryToAddCandidate(I);
20541       }
20542     }
20543   } else {
20544     for (auto I = RootNode->use_begin(), E = RootNode->use_end();
20545          I != E && NumNodesExplored < MaxSearchNodes; ++I, ++NumNodesExplored)
20546       TryToAddCandidate(I);
20547   }
20548 
20549   return RootNode;
20550 }
20551 
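// Illustrative sketch (editor's example): given
//   store i32 %a, p        ; St, matched as (base p, offset 0)
//   store i32 %b, p + 4    ; equalBaseIndex yields Offset = 4
// the second store is recorded as MemOpLink(Other, 4). A store to an
// unrelated base, or one reachable only through several layers of loads,
// is not collected (see the TODO above).
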
20552 // We need to check that merging these stores does not cause a loop in the
20553 // DAG. Any store candidate may depend on another candidate indirectly through
20554 // its operands. Check in parallel by searching up from operands of candidates.
20555 bool DAGCombiner::checkMergeStoreCandidatesForDependencies(
20556     SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores,
20557     SDNode *RootNode) {
20558   // FIXME: We should be able to truncate a full search of
20559   // predecessors by doing a BFS and keeping tabs on the originating
20560   // stores from which the worklist nodes come, in a similar way to
20561   // TokenFactor simplification.
20562 
20563   SmallPtrSet<const SDNode *, 32> Visited;
20564   SmallVector<const SDNode *, 8> Worklist;
20565 
20566   // RootNode is a predecessor to all candidates so we need not search
20567   // past it. Add RootNode (peeking through TokenFactors). Do not count
20568   // these towards size check.
20569 
20570   Worklist.push_back(RootNode);
20571   while (!Worklist.empty()) {
20572     auto N = Worklist.pop_back_val();
20573     if (!Visited.insert(N).second)
20574       continue; // Already present in Visited.
20575     if (N->getOpcode() == ISD::TokenFactor) {
20576       for (SDValue Op : N->ops())
20577         Worklist.push_back(Op.getNode());
20578     }
20579   }
20580 
20581   // Don't count pruning nodes towards max.
20582   unsigned int Max = 1024 + Visited.size();
20583   // Search Ops of store candidates.
20584   for (unsigned i = 0; i < NumStores; ++i) {
20585     SDNode *N = StoreNodes[i].MemNode;
20586     // Of the 4 Store Operands:
20587     //   * Chain (Op 0) -> We have already considered these
20588     //                     in candidate selection, but only by following the
20589     //                     chain dependencies. We could still have a chain
20590     //                     dependency to a load, that has a non-chain dep to
20591     //                     another load, that depends on a store, etc. So it is
20592     //                     possible to have dependencies that consist of a mix
20593     //                     of chain and non-chain deps, and we need to include
20594     //                     chain operands in the analysis here.
20595     //   * Value (Op 1) -> Cycles may happen (e.g. through load chains)
20596     //   * Address (Op 2) -> Merged addresses may only vary by a fixed constant,
20597     //                       but aren't necessarily from the same base node, so
20598     //                       cycles are possible (e.g. via indexed store).
20599     //   * (Op 3) -> Represents the pre or post-indexing offset (or undef for
20600     //               non-indexed stores). Not constant on all targets (e.g. ARM)
20601     //               and so can participate in a cycle.
20602     for (const SDValue &Op : N->op_values())
20603       Worklist.push_back(Op.getNode());
20604   }
20605   // Search through DAG. We can stop early if we find a store node.
20606   for (unsigned i = 0; i < NumStores; ++i)
20607     if (SDNode::hasPredecessorHelper(StoreNodes[i].MemNode, Visited, Worklist,
20608                                      Max)) {
20609       // If the search bails out, record the StoreNode and RootNode in the
20610       // StoreRootCountMap. If we have seen the pair more times than the limit,
20611       // we won't add the StoreNode into the StoreNodes set again.
20612       if (Visited.size() >= Max) {
20613         auto &RootCount = StoreRootCountMap[StoreNodes[i].MemNode];
20614         if (RootCount.first == RootNode)
20615           RootCount.second++;
20616         else
20617           RootCount = {RootNode, 1};
20618       }
20619       return false;
20620     }
20621   return true;
20622 }
20623 
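// Illustrative sketch (editor's example) of the cycle being guarded
// against: if candidate store S1 stores a value loaded from an address that
// depends on candidate store S2, a single merged store replacing both would
// be its own predecessor. Walking up from every candidate's operands, and
// pruning at RootNode, detects such paths before any rewrite happens.
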
20624 unsigned
20625 DAGCombiner::getConsecutiveStores(SmallVectorImpl<MemOpLink> &StoreNodes,
20626                                   int64_t ElementSizeBytes) const {
20627   while (true) {
20628     // Find a store past the width of the first store.
20629     size_t StartIdx = 0;
20630     while ((StartIdx + 1 < StoreNodes.size()) &&
20631            StoreNodes[StartIdx].OffsetFromBase + ElementSizeBytes !=
20632               StoreNodes[StartIdx + 1].OffsetFromBase)
20633       ++StartIdx;
20634 
20635     // Bail if we don't have enough candidates to merge.
20636     if (StartIdx + 1 >= StoreNodes.size())
20637       return 0;
20638 
20639     // Trim stores that overlapped with the first store.
20640     if (StartIdx)
20641       StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + StartIdx);
20642 
20643     // Scan the memory operations on the chain and find the first
20644     // non-consecutive store memory address.
20645     unsigned NumConsecutiveStores = 1;
20646     int64_t StartAddress = StoreNodes[0].OffsetFromBase;
20647     // Check that the addresses are consecutive starting from the second
20648     // element in the list of stores.
20649     for (unsigned i = 1, e = StoreNodes.size(); i < e; ++i) {
20650       int64_t CurrAddress = StoreNodes[i].OffsetFromBase;
20651       if (CurrAddress - StartAddress != (ElementSizeBytes * i))
20652         break;
20653       NumConsecutiveStores = i + 1;
20654     }
20655     if (NumConsecutiveStores > 1)
20656       return NumConsecutiveStores;
20657 
20658     // There are no consecutive stores at the start of the list.
20659     // Remove the first store and try again.
20660     StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + 1);
20661   }
20662 }
20663 
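// Worked example (editor's sketch): with 4-byte elements and sorted offsets
//   {0, 8, 12, 16, 17}
// offset 0 is not followed by offset 4, so StartIdx advances past it; the
// run {8, 12, 16} is then found and NumConsecutiveStores = 3 is returned.
// The overlapping store at offset 17 stays in the list for a later attempt.
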
20664 bool DAGCombiner::tryStoreMergeOfConstants(
20665     SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumConsecutiveStores,
20666     EVT MemVT, SDNode *RootNode, bool AllowVectors) {
20667   LLVMContext &Context = *DAG.getContext();
20668   const DataLayout &DL = DAG.getDataLayout();
20669   int64_t ElementSizeBytes = MemVT.getStoreSize();
20670   unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
20671   bool MadeChange = false;
20672 
20673   // Store the constants into memory as one consecutive store.
20674   while (NumConsecutiveStores >= 2) {
20675     LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
20676     unsigned FirstStoreAS = FirstInChain->getAddressSpace();
20677     Align FirstStoreAlign = FirstInChain->getAlign();
20678     unsigned LastLegalType = 1;
20679     unsigned LastLegalVectorType = 1;
20680     bool LastIntegerTrunc = false;
20681     bool NonZero = false;
20682     unsigned FirstZeroAfterNonZero = NumConsecutiveStores;
20683     for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
20684       StoreSDNode *ST = cast<StoreSDNode>(StoreNodes[i].MemNode);
20685       SDValue StoredVal = ST->getValue();
20686       bool IsElementZero = false;
20687       if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(StoredVal))
20688         IsElementZero = C->isZero();
20689       else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(StoredVal))
20690         IsElementZero = C->getConstantFPValue()->isNullValue();
20691       else if (ISD::isBuildVectorAllZeros(StoredVal.getNode()))
20692         IsElementZero = true;
20693       if (IsElementZero) {
20694         if (NonZero && FirstZeroAfterNonZero == NumConsecutiveStores)
20695           FirstZeroAfterNonZero = i;
20696       }
20697       NonZero |= !IsElementZero;
20698 
20699       // Find a legal type for the constant store.
20700       unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8;
20701       EVT StoreTy = EVT::getIntegerVT(Context, SizeInBits);
20702       unsigned IsFast = 0;
20703 
20704       // Break early when size is too large to be legal.
20705       if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits)
20706         break;
20707 
20708       if (TLI.isTypeLegal(StoreTy) &&
20709           TLI.canMergeStoresTo(FirstStoreAS, StoreTy,
20710                                DAG.getMachineFunction()) &&
20711           TLI.allowsMemoryAccess(Context, DL, StoreTy,
20712                                  *FirstInChain->getMemOperand(), &IsFast) &&
20713           IsFast) {
20714         LastIntegerTrunc = false;
20715         LastLegalType = i + 1;
20716         // Or check whether a truncstore is legal.
20717       } else if (TLI.getTypeAction(Context, StoreTy) ==
20718                  TargetLowering::TypePromoteInteger) {
20719         EVT LegalizedStoredValTy =
20720             TLI.getTypeToTransformTo(Context, StoredVal.getValueType());
20721         if (TLI.isTruncStoreLegal(LegalizedStoredValTy, StoreTy) &&
20722             TLI.canMergeStoresTo(FirstStoreAS, LegalizedStoredValTy,
20723                                  DAG.getMachineFunction()) &&
20724             TLI.allowsMemoryAccess(Context, DL, StoreTy,
20725                                    *FirstInChain->getMemOperand(), &IsFast) &&
20726             IsFast) {
20727           LastIntegerTrunc = true;
20728           LastLegalType = i + 1;
20729         }
20730       }
20731 
20732       // We only use vectors if the target allows it and the function is not
20733       // marked with the noimplicitfloat attribute.
20734       if (TLI.storeOfVectorConstantIsCheap(!NonZero, MemVT, i + 1, FirstStoreAS) &&
20735           AllowVectors) {
20736         // Find a legal type for the vector store.
20737         unsigned Elts = (i + 1) * NumMemElts;
20738         EVT Ty = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
20739         if (TLI.isTypeLegal(Ty) && TLI.isTypeLegal(MemVT) &&
20740             TLI.canMergeStoresTo(FirstStoreAS, Ty, DAG.getMachineFunction()) &&
20741             TLI.allowsMemoryAccess(Context, DL, Ty,
20742                                    *FirstInChain->getMemOperand(), &IsFast) &&
20743             IsFast)
20744           LastLegalVectorType = i + 1;
20745       }
20746     }
20747 
20748     bool UseVector = (LastLegalVectorType > LastLegalType) && AllowVectors;
20749     unsigned NumElem = (UseVector) ? LastLegalVectorType : LastLegalType;
20750     bool UseTrunc = LastIntegerTrunc && !UseVector;
20751 
20752     // Check if we found a legal integer type that creates a meaningful
20753     // merge.
20754     if (NumElem < 2) {
20755       // We know that candidate stores are in order and of correct
20756       // shape. While there is no mergeable sequence from the
20757       // beginning one may start later in the sequence. The only
20758       // reason a merge of size N could have failed where another of
20759       // the same size would not have, is if the alignment has
20760       // improved or we've dropped a non-zero value. Drop as many
20761       // candidates as we can here.
20762       unsigned NumSkip = 1;
20763       while ((NumSkip < NumConsecutiveStores) &&
20764              (NumSkip < FirstZeroAfterNonZero) &&
20765              (StoreNodes[NumSkip].MemNode->getAlign() <= FirstStoreAlign))
20766         NumSkip++;
20767 
20768       StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip);
20769       NumConsecutiveStores -= NumSkip;
20770       continue;
20771     }
20772 
20773     // Check that we can merge these candidates without causing a cycle.
20774     if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumElem,
20775                                                   RootNode)) {
20776       StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
20777       NumConsecutiveStores -= NumElem;
20778       continue;
20779     }
20780 
20781     MadeChange |= mergeStoresOfConstantsOrVecElts(StoreNodes, MemVT, NumElem,
20782                                                   /*IsConstantSrc*/ true,
20783                                                   UseVector, UseTrunc);
20784 
20785     // Remove merged stores for next iteration.
20786     StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
20787     NumConsecutiveStores -= NumElem;
20788   }
20789   return MadeChange;
20790 }
20791 
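// Illustrative sketch (editor's example): four consecutive "store i16 0"
// operations can become a single "store i64 0" when an i64 store is legal
// and fast; if only i32 qualifies, the loop merges the first two stores,
// erases them, and retries on the remainder. FirstZeroAfterNonZero bounds
// how far a failed attempt skips ahead, so the next attempt can begin at
// the first zero-valued store following a nonzero run.
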
20792 bool DAGCombiner::tryStoreMergeOfExtracts(
20793     SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumConsecutiveStores,
20794     EVT MemVT, SDNode *RootNode) {
20795   LLVMContext &Context = *DAG.getContext();
20796   const DataLayout &DL = DAG.getDataLayout();
20797   unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
20798   bool MadeChange = false;
20799 
20800   // Loop over the consecutive stores, retrying after each successful merge.
20801   while (NumConsecutiveStores >= 2) {
20802     LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
20803     unsigned FirstStoreAS = FirstInChain->getAddressSpace();
20804     Align FirstStoreAlign = FirstInChain->getAlign();
20805     unsigned NumStoresToMerge = 1;
20806     for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
20807       // Find a legal type for the vector store.
20808       unsigned Elts = (i + 1) * NumMemElts;
20809       EVT Ty = EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts);
20810       unsigned IsFast = 0;
20811 
20812       // Break early when size is too large to be legal.
20813       if (Ty.getSizeInBits() > MaximumLegalStoreInBits)
20814         break;
20815 
20816       if (TLI.isTypeLegal(Ty) &&
20817           TLI.canMergeStoresTo(FirstStoreAS, Ty, DAG.getMachineFunction()) &&
20818           TLI.allowsMemoryAccess(Context, DL, Ty,
20819                                  *FirstInChain->getMemOperand(), &IsFast) &&
20820           IsFast)
20821         NumStoresToMerge = i + 1;
20822     }
20823 
20824     // Check if we found a legal integer type creating a meaningful
20825     // merge.
20826     if (NumStoresToMerge < 2) {
20827       // We know that candidate stores are in order and of correct
20828       // shape. While there is no mergeable sequence from the
20829       // beginning one may start later in the sequence. The only
20830       // reason a merge of size N could have failed where another of
20831       // the same size would not have, is if the alignment has
20832       // improved. Drop as many candidates as we can here.
20833       unsigned NumSkip = 1;
20834       while ((NumSkip < NumConsecutiveStores) &&
20835              (StoreNodes[NumSkip].MemNode->getAlign() <= FirstStoreAlign))
20836         NumSkip++;
20837 
20838       StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip);
20839       NumConsecutiveStores -= NumSkip;
20840       continue;
20841     }
20842 
20843     // Check that we can merge these candidates without causing a cycle.
20844     if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumStoresToMerge,
20845                                                   RootNode)) {
20846       StoreNodes.erase(StoreNodes.begin(),
20847                        StoreNodes.begin() + NumStoresToMerge);
20848       NumConsecutiveStores -= NumStoresToMerge;
20849       continue;
20850     }
20851 
20852     MadeChange |= mergeStoresOfConstantsOrVecElts(
20853         StoreNodes, MemVT, NumStoresToMerge, /*IsConstantSrc*/ false,
20854         /*UseVector*/ true, /*UseTrunc*/ false);
20855 
20856     StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumStoresToMerge);
20857     NumConsecutiveStores -= NumStoresToMerge;
20858   }
20859   return MadeChange;
20860 }
20861 
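// Illustrative sketch (editor's example): stores of the two halves
//   extract_subvector <4 x f32> %v, 0  ->  p
//   extract_subvector <4 x f32> %v, 2  ->  p + 8
// can be merged into a single "store <4 x f32>" once a legal, fast vector
// type covering both is found; this path always emits plain vector stores
// (UseVector is true and UseTrunc is false).
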
20862 bool DAGCombiner::tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes,
20863                                        unsigned NumConsecutiveStores, EVT MemVT,
20864                                        SDNode *RootNode, bool AllowVectors,
20865                                        bool IsNonTemporalStore,
20866                                        bool IsNonTemporalLoad) {
20867   LLVMContext &Context = *DAG.getContext();
20868   const DataLayout &DL = DAG.getDataLayout();
20869   int64_t ElementSizeBytes = MemVT.getStoreSize();
20870   unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
20871   bool MadeChange = false;
20872 
20873   // Look for load nodes which are used by the stored values.
20874   SmallVector<MemOpLink, 8> LoadNodes;
20875 
20876   // Find acceptable loads. Loads need to have the same chain (token factor),
20877   // must not be zext/volatile/indexed, and they must be consecutive.
20878   BaseIndexOffset LdBasePtr;
20879 
20880   for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
20881     StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
20882     SDValue Val = peekThroughBitcasts(St->getValue());
20883     LoadSDNode *Ld = cast<LoadSDNode>(Val);
20884 
20885     BaseIndexOffset LdPtr = BaseIndexOffset::match(Ld, DAG);
20886     // If this is not the first ptr that we check.
20887     int64_t LdOffset = 0;
20888     if (LdBasePtr.getBase().getNode()) {
20889       // The base ptr must be the same.
20890       if (!LdBasePtr.equalBaseIndex(LdPtr, DAG, LdOffset))
20891         break;
20892     } else {
20893       // Check that all other base pointers are the same as this one.
20894       LdBasePtr = LdPtr;
20895     }
20896 
20897     // We found a potential memory operand to merge.
20898     LoadNodes.push_back(MemOpLink(Ld, LdOffset));
20899   }
20900 
20901   while (NumConsecutiveStores >= 2 && LoadNodes.size() >= 2) {
20902     Align RequiredAlignment;
20903     bool NeedRotate = false;
20904     if (LoadNodes.size() == 2) {
20905       // If we have load/store pair instructions and we only have two values,
20906       // don't bother merging.
20907       if (TLI.hasPairedLoad(MemVT, RequiredAlignment) &&
20908           StoreNodes[0].MemNode->getAlign() >= RequiredAlignment) {
20909         StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + 2);
20910         LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + 2);
20911         break;
20912       }
20913       // If the loads are reversed, see if we can rotate the halves into place.
20914       int64_t Offset0 = LoadNodes[0].OffsetFromBase;
20915       int64_t Offset1 = LoadNodes[1].OffsetFromBase;
20916       EVT PairVT = EVT::getIntegerVT(Context, ElementSizeBytes * 8 * 2);
20917       if (Offset0 - Offset1 == ElementSizeBytes &&
20918           (hasOperation(ISD::ROTL, PairVT) ||
20919            hasOperation(ISD::ROTR, PairVT))) {
20920         std::swap(LoadNodes[0], LoadNodes[1]);
20921         NeedRotate = true;
20922       }
20923     }
20924     LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
20925     unsigned FirstStoreAS = FirstInChain->getAddressSpace();
20926     Align FirstStoreAlign = FirstInChain->getAlign();
20927     LoadSDNode *FirstLoad = cast<LoadSDNode>(LoadNodes[0].MemNode);
20928 
20929     // Scan the memory operations on the chain and find the first
20930     // non-consecutive load memory address. These variables hold the index in
20931     // the store node array.
20932 
20933     unsigned LastConsecutiveLoad = 1;
20934 
20935     // These variables refer to a size, not an index into the array.
20936     unsigned LastLegalVectorType = 1;
20937     unsigned LastLegalIntegerType = 1;
20938     bool isDereferenceable = true;
20939     bool DoIntegerTruncate = false;
20940     int64_t StartAddress = LoadNodes[0].OffsetFromBase;
20941     SDValue LoadChain = FirstLoad->getChain();
20942     for (unsigned i = 1; i < LoadNodes.size(); ++i) {
20943       // All loads must share the same chain.
20944       if (LoadNodes[i].MemNode->getChain() != LoadChain)
20945         break;
20946 
20947       int64_t CurrAddress = LoadNodes[i].OffsetFromBase;
20948       if (CurrAddress - StartAddress != (ElementSizeBytes * i))
20949         break;
20950       LastConsecutiveLoad = i;
20951 
20952       if (isDereferenceable && !LoadNodes[i].MemNode->isDereferenceable())
20953         isDereferenceable = false;
20954 
20955       // Find a legal type for the vector store.
20956       unsigned Elts = (i + 1) * NumMemElts;
20957       EVT StoreTy = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
20958 
20959       // Break early when size is too large to be legal.
20960       if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits)
20961         break;
20962 
20963       unsigned IsFastSt = 0;
20964       unsigned IsFastLd = 0;
20965       // Don't try vector types if we need a rotate. We may still fail the
20966       // legality checks for the integer type, but we can't handle the rotate
20967       // case with vectors.
20968       // FIXME: We could use a shuffle in place of the rotate.
20969       if (!NeedRotate && TLI.isTypeLegal(StoreTy) &&
20970           TLI.canMergeStoresTo(FirstStoreAS, StoreTy,
20971                                DAG.getMachineFunction()) &&
20972           TLI.allowsMemoryAccess(Context, DL, StoreTy,
20973                                  *FirstInChain->getMemOperand(), &IsFastSt) &&
20974           IsFastSt &&
20975           TLI.allowsMemoryAccess(Context, DL, StoreTy,
20976                                  *FirstLoad->getMemOperand(), &IsFastLd) &&
20977           IsFastLd) {
20978         LastLegalVectorType = i + 1;
20979       }
20980 
20981       // Find a legal type for the integer store.
20982       unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8;
20983       StoreTy = EVT::getIntegerVT(Context, SizeInBits);
20984       if (TLI.isTypeLegal(StoreTy) &&
20985           TLI.canMergeStoresTo(FirstStoreAS, StoreTy,
20986                                DAG.getMachineFunction()) &&
20987           TLI.allowsMemoryAccess(Context, DL, StoreTy,
20988                                  *FirstInChain->getMemOperand(), &IsFastSt) &&
20989           IsFastSt &&
20990           TLI.allowsMemoryAccess(Context, DL, StoreTy,
20991                                  *FirstLoad->getMemOperand(), &IsFastLd) &&
20992           IsFastLd) {
20993         LastLegalIntegerType = i + 1;
20994         DoIntegerTruncate = false;
20995         // Or check whether a truncstore and extload is legal.
20996       } else if (TLI.getTypeAction(Context, StoreTy) ==
20997                  TargetLowering::TypePromoteInteger) {
20998         EVT LegalizedStoredValTy = TLI.getTypeToTransformTo(Context, StoreTy);
20999         if (TLI.isTruncStoreLegal(LegalizedStoredValTy, StoreTy) &&
21000             TLI.canMergeStoresTo(FirstStoreAS, LegalizedStoredValTy,
21001                                  DAG.getMachineFunction()) &&
21002             TLI.isLoadExtLegal(ISD::ZEXTLOAD, LegalizedStoredValTy, StoreTy) &&
21003             TLI.isLoadExtLegal(ISD::SEXTLOAD, LegalizedStoredValTy, StoreTy) &&
21004             TLI.isLoadExtLegal(ISD::EXTLOAD, LegalizedStoredValTy, StoreTy) &&
21005             TLI.allowsMemoryAccess(Context, DL, StoreTy,
21006                                    *FirstInChain->getMemOperand(), &IsFastSt) &&
21007             IsFastSt &&
21008             TLI.allowsMemoryAccess(Context, DL, StoreTy,
21009                                    *FirstLoad->getMemOperand(), &IsFastLd) &&
21010             IsFastLd) {
21011           LastLegalIntegerType = i + 1;
21012           DoIntegerTruncate = true;
21013         }
21014       }
21015     }
21016 
21017     // Only use vector types if the vector type is larger than the integer
21018     // type. If they are the same, use integers.
21019     bool UseVectorTy =
21020         LastLegalVectorType > LastLegalIntegerType && AllowVectors;
21021     unsigned LastLegalType =
21022         std::max(LastLegalVectorType, LastLegalIntegerType);
21023 
21024     // We add +1 here because the LastXXX variables refer to a position (an
21025     // index) while NumElem refers to a count of elements.
21026     unsigned NumElem = std::min(NumConsecutiveStores, LastConsecutiveLoad + 1);
21027     NumElem = std::min(LastLegalType, NumElem);
21028     Align FirstLoadAlign = FirstLoad->getAlign();
21029 
21030     if (NumElem < 2) {
21031       // We know that candidate stores are in order and of correct
21032       // shape. While there is no mergeable sequence from the
21033       // beginning one may start later in the sequence. The only
21034       // reason a merge of size N could have failed where another of
21035       // the same size would not have is if the alignment or either
21036       // the load or store has improved. Drop as many candidates as we
21037       // can here.
21038       unsigned NumSkip = 1;
21039       while ((NumSkip < LoadNodes.size()) &&
21040              (LoadNodes[NumSkip].MemNode->getAlign() <= FirstLoadAlign) &&
21041              (StoreNodes[NumSkip].MemNode->getAlign() <= FirstStoreAlign))
21042         NumSkip++;
21043       StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip);
21044       LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumSkip);
21045       NumConsecutiveStores -= NumSkip;
21046       continue;
21047     }
21048 
21049     // Check that we can merge these candidates without causing a cycle.
21050     if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumElem,
21051                                                   RootNode)) {
21052       StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
21053       LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem);
21054       NumConsecutiveStores -= NumElem;
21055       continue;
21056     }
21057 
21058     // Find if it is better to use vectors or integers to load and store
21059     // to memory.
21060     EVT JointMemOpVT;
21061     if (UseVectorTy) {
21062       // Find a legal type for the vector store.
21063       unsigned Elts = NumElem * NumMemElts;
21064       JointMemOpVT = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
21065     } else {
21066       unsigned SizeInBits = NumElem * ElementSizeBytes * 8;
21067       JointMemOpVT = EVT::getIntegerVT(Context, SizeInBits);
21068     }
21069 
21070     SDLoc LoadDL(LoadNodes[0].MemNode);
21071     SDLoc StoreDL(StoreNodes[0].MemNode);
21072 
21073     // The merged loads are required to have the same incoming chain, so
21074     // using the first's chain is acceptable.
21075 
21076     SDValue NewStoreChain = getMergeStoreChains(StoreNodes, NumElem);
21077     bool CanReusePtrInfo = hasSameUnderlyingObj(StoreNodes);
21078     AddToWorklist(NewStoreChain.getNode());
21079 
21080     MachineMemOperand::Flags LdMMOFlags =
21081         isDereferenceable ? MachineMemOperand::MODereferenceable
21082                           : MachineMemOperand::MONone;
21083     if (IsNonTemporalLoad)
21084       LdMMOFlags |= MachineMemOperand::MONonTemporal;
21085 
21086     LdMMOFlags |= TLI.getTargetMMOFlags(*FirstLoad);
21087 
21088     MachineMemOperand::Flags StMMOFlags = IsNonTemporalStore
21089                                               ? MachineMemOperand::MONonTemporal
21090                                               : MachineMemOperand::MONone;
21091 
21092     StMMOFlags |= TLI.getTargetMMOFlags(*StoreNodes[0].MemNode);
21093 
21094     SDValue NewLoad, NewStore;
21095     if (UseVectorTy || !DoIntegerTruncate) {
21096       NewLoad = DAG.getLoad(
21097           JointMemOpVT, LoadDL, FirstLoad->getChain(), FirstLoad->getBasePtr(),
21098           FirstLoad->getPointerInfo(), FirstLoadAlign, LdMMOFlags);
21099       SDValue StoreOp = NewLoad;
21100       if (NeedRotate) {
21101         unsigned LoadWidth = ElementSizeBytes * 8 * 2;
21102         assert(JointMemOpVT == EVT::getIntegerVT(Context, LoadWidth) &&
21103                "Unexpected type for rotate-able load pair");
21104         SDValue RotAmt =
21105             DAG.getShiftAmountConstant(LoadWidth / 2, JointMemOpVT, LoadDL);
21106         // Target can convert to the identical ROTR if it does not have ROTL.
21107         StoreOp = DAG.getNode(ISD::ROTL, LoadDL, JointMemOpVT, NewLoad, RotAmt);
21108       }
21109       NewStore = DAG.getStore(
21110           NewStoreChain, StoreDL, StoreOp, FirstInChain->getBasePtr(),
21111           CanReusePtrInfo ? FirstInChain->getPointerInfo()
21112                           : MachinePointerInfo(FirstStoreAS),
21113           FirstStoreAlign, StMMOFlags);
21114     } else { // This must be the truncstore/extload case
21115       EVT ExtendedTy =
21116           TLI.getTypeToTransformTo(*DAG.getContext(), JointMemOpVT);
21117       NewLoad = DAG.getExtLoad(ISD::EXTLOAD, LoadDL, ExtendedTy,
21118                                FirstLoad->getChain(), FirstLoad->getBasePtr(),
21119                                FirstLoad->getPointerInfo(), JointMemOpVT,
21120                                FirstLoadAlign, LdMMOFlags);
21121       NewStore = DAG.getTruncStore(
21122           NewStoreChain, StoreDL, NewLoad, FirstInChain->getBasePtr(),
21123           CanReusePtrInfo ? FirstInChain->getPointerInfo()
21124                           : MachinePointerInfo(FirstStoreAS),
21125           JointMemOpVT, FirstInChain->getAlign(),
21126           FirstInChain->getMemOperand()->getFlags());
21127     }
21128 
21129     // Transfer chain users from old loads to the new load.
21130     for (unsigned i = 0; i < NumElem; ++i) {
21131       LoadSDNode *Ld = cast<LoadSDNode>(LoadNodes[i].MemNode);
21132       DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1),
21133                                     SDValue(NewLoad.getNode(), 1));
21134     }
21135 
21136     // Replace all stores with the new store. Recursively remove corresponding
21137     // values if they are no longer used.
21138     for (unsigned i = 0; i < NumElem; ++i) {
21139       SDValue Val = StoreNodes[i].MemNode->getOperand(1);
21140       CombineTo(StoreNodes[i].MemNode, NewStore);
21141       if (Val->use_empty())
21142         recursivelyDeleteUnusedNodes(Val.getNode());
21143     }
21144 
21145     MadeChange = true;
21146     StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
21147     LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem);
21148     NumConsecutiveStores -= NumElem;
21149   }
21150   return MadeChange;
21151 }
21152 
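// Illustrative sketch (editor's example): a bytewise copy lowered as
//   t0 = load i16 %src;      store i16 t0, %dst
//   t1 = load i16 %src + 2;  store i16 t1, %dst + 2
// becomes one wide load feeding one wide store when both i32 accesses are
// legal and fast. If the two loads appear in reverse order relative to the
// stores, the pair is loaded as a single i32 and the halves are swapped
// with an ISD::ROTL by half the width before storing.
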
21153 bool DAGCombiner::mergeConsecutiveStores(StoreSDNode *St) {
21154   if (OptLevel == CodeGenOptLevel::None || !EnableStoreMerging)
21155     return false;
21156 
21157   // TODO: Extend this function to merge stores of scalable vectors.
21158   // (i.e. two <vscale x 8 x i8> stores can be merged to one <vscale x 16 x i8>
21159   // store since we know <vscale x 16 x i8> is exactly twice as large as
21160   // <vscale x 8 x i8>). Until then, bail out for scalable vectors.
21161   EVT MemVT = St->getMemoryVT();
21162   if (MemVT.isScalableVT())
21163     return false;
21164   if (!MemVT.isSimple() || MemVT.getSizeInBits() * 2 > MaximumLegalStoreInBits)
21165     return false;
21166 
21167   // This function cannot currently deal with non-byte-sized memory sizes.
21168   int64_t ElementSizeBytes = MemVT.getStoreSize();
21169   if (ElementSizeBytes * 8 != (int64_t)MemVT.getSizeInBits())
21170     return false;
21171 
21172   // Do not bother looking at stored values that are not constants, loads, or
21173   // extracted vector elements.
21174   SDValue StoredVal = peekThroughBitcasts(St->getValue());
21175   const StoreSource StoreSrc = getStoreSource(StoredVal);
21176   if (StoreSrc == StoreSource::Unknown)
21177     return false;
21178 
21179   SmallVector<MemOpLink, 8> StoreNodes;
21180   // Find potential store merge candidates by searching through chain sub-DAG
21181   SDNode *RootNode = getStoreMergeCandidates(St, StoreNodes);
21182 
21183   // Check if there is anything to merge.
21184   if (StoreNodes.size() < 2)
21185     return false;
21186 
21187   // Sort the memory operands according to their distance from the
21188   // base pointer.
21189   llvm::sort(StoreNodes, [](MemOpLink LHS, MemOpLink RHS) {
21190     return LHS.OffsetFromBase < RHS.OffsetFromBase;
21191   });
21192 
21193   bool AllowVectors = !DAG.getMachineFunction().getFunction().hasFnAttribute(
21194       Attribute::NoImplicitFloat);
21195   bool IsNonTemporalStore = St->isNonTemporal();
21196   bool IsNonTemporalLoad = StoreSrc == StoreSource::Load &&
21197                            cast<LoadSDNode>(StoredVal)->isNonTemporal();
21198 
21199   // Store merging attempts to merge the lowest-addressed stores first. This
21200   // generally works out because, when a merge succeeds, the remaining
21201   // stores are checked after the first collection of stores is merged.
21202   // However, in the case that a non-mergeable store is found first, e.g.,
21203   // {p[-2], p[0], p[1], p[2], p[3]}, we would fail and miss the subsequent
21204   // mergeable cases. To prevent this, we prune such stores from the
21205   // front of StoreNodes here.
21206   bool MadeChange = false;
21207   while (StoreNodes.size() > 1) {
21208     unsigned NumConsecutiveStores =
21209         getConsecutiveStores(StoreNodes, ElementSizeBytes);
21210     // There are no more stores in the list to examine.
21211     if (NumConsecutiveStores == 0)
21212       return MadeChange;
21213 
21214     // We have at least 2 consecutive stores. Try to merge them.
21215     assert(NumConsecutiveStores >= 2 && "Expected at least 2 stores");
21216     switch (StoreSrc) {
21217     case StoreSource::Constant:
21218       MadeChange |= tryStoreMergeOfConstants(StoreNodes, NumConsecutiveStores,
21219                                              MemVT, RootNode, AllowVectors);
21220       break;
21221 
21222     case StoreSource::Extract:
21223       MadeChange |= tryStoreMergeOfExtracts(StoreNodes, NumConsecutiveStores,
21224                                             MemVT, RootNode);
21225       break;
21226 
21227     case StoreSource::Load:
21228       MadeChange |= tryStoreMergeOfLoads(StoreNodes, NumConsecutiveStores,
21229                                          MemVT, RootNode, AllowVectors,
21230                                          IsNonTemporalStore, IsNonTemporalLoad);
21231       break;
21232 
21233     default:
21234       llvm_unreachable("Unhandled store source type");
21235     }
21236   }
21237 
21238   // Remember if we failed to optimize, to save compile time.
21239   if (!MadeChange)
21240     ChainsWithoutMergeableStores.insert(RootNode);
21241 
21242   return MadeChange;
21243 }
21244 
21245 SDValue DAGCombiner::replaceStoreChain(StoreSDNode *ST, SDValue BetterChain) {
21246   SDLoc SL(ST);
21247   SDValue ReplStore;
21248 
21249   // Replace the chain to avoid dependency.
21250   if (ST->isTruncatingStore()) {
21251     ReplStore = DAG.getTruncStore(BetterChain, SL, ST->getValue(),
21252                                   ST->getBasePtr(), ST->getMemoryVT(),
21253                                   ST->getMemOperand());
21254   } else {
21255     ReplStore = DAG.getStore(BetterChain, SL, ST->getValue(), ST->getBasePtr(),
21256                              ST->getMemOperand());
21257   }
21258 
21259   // Create token to keep both nodes around.
21260   SDValue Token = DAG.getNode(ISD::TokenFactor, SL,
21261                               MVT::Other, ST->getChain(), ReplStore);
21262 
21263   // Make sure the new and old chains are cleaned up.
21264   AddToWorklist(Token.getNode());
21265 
21266   // Don't add users to work list.
21267   return CombineTo(ST, Token, false);
21268 }
21269 
21270 SDValue DAGCombiner::replaceStoreOfFPConstant(StoreSDNode *ST) {
21271   SDValue Value = ST->getValue();
21272   if (Value.getOpcode() == ISD::TargetConstantFP)
21273     return SDValue();
21274 
21275   if (!ISD::isNormalStore(ST))
21276     return SDValue();
21277 
21278   SDLoc DL(ST);
21279 
21280   SDValue Chain = ST->getChain();
21281   SDValue Ptr = ST->getBasePtr();
21282 
21283   const ConstantFPSDNode *CFP = cast<ConstantFPSDNode>(Value);
21284 
21285   // NOTE: If the original store is volatile, this transform must not increase
21286   // the number of stores.  For example, on x86-32 an f64 can be stored in one
21287   // processor operation but an i64 (which is not legal) requires two.  So the
21288   // transform should not be done in this case.
21289 
21290   SDValue Tmp;
21291   switch (CFP->getSimpleValueType(0).SimpleTy) {
21292   default:
21293     llvm_unreachable("Unknown FP type");
21294   case MVT::f16:    // We don't do this for these yet.
21295   case MVT::bf16:
21296   case MVT::f80:
21297   case MVT::f128:
21298   case MVT::ppcf128:
21299     return SDValue();
21300   case MVT::f32:
21301     if ((isTypeLegal(MVT::i32) && !LegalOperations && ST->isSimple()) ||
21302         TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i32)) {
21303       Tmp = DAG.getConstant((uint32_t)CFP->getValueAPF().
21304                             bitcastToAPInt().getZExtValue(), SDLoc(CFP),
21305                             MVT::i32);
21306       return DAG.getStore(Chain, DL, Tmp, Ptr, ST->getMemOperand());
21307     }
21308 
21309     return SDValue();
21310   case MVT::f64:
21311     if ((TLI.isTypeLegal(MVT::i64) && !LegalOperations &&
21312          ST->isSimple()) ||
21313         TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i64)) {
21314       Tmp = DAG.getConstant(CFP->getValueAPF().bitcastToAPInt().
21315                             getZExtValue(), SDLoc(CFP), MVT::i64);
21316       return DAG.getStore(Chain, DL, Tmp,
21317                           Ptr, ST->getMemOperand());
21318     }
21319 
21320     if (ST->isSimple() && TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i32) &&
21321         !TLI.isFPImmLegal(CFP->getValueAPF(), MVT::f64)) {
21322       // Many FP stores are not made apparent until after legalize, e.g. for
21323       // argument passing.  Since this is so common, custom legalize the
21324       // 64-bit integer store into two 32-bit stores.
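      // Illustrative sketch (little-endian; constant chosen for illustration):
      // storing the f64 constant 1.0 (bits 0x3FF0000000000000) becomes
      //   (store i32 0x00000000, Ptr) and (store i32 0x3FF00000, Ptr+4),
      // joined by the TokenFactor below.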
21325       uint64_t Val = CFP->getValueAPF().bitcastToAPInt().getZExtValue();
21326       SDValue Lo = DAG.getConstant(Val & 0xFFFFFFFF, SDLoc(CFP), MVT::i32);
21327       SDValue Hi = DAG.getConstant(Val >> 32, SDLoc(CFP), MVT::i32);
21328       if (DAG.getDataLayout().isBigEndian())
21329         std::swap(Lo, Hi);
21330 
21331       MachineMemOperand::Flags MMOFlags = ST->getMemOperand()->getFlags();
21332       AAMDNodes AAInfo = ST->getAAInfo();
21333 
21334       SDValue St0 = DAG.getStore(Chain, DL, Lo, Ptr, ST->getPointerInfo(),
21335                                  ST->getOriginalAlign(), MMOFlags, AAInfo);
21336       Ptr = DAG.getMemBasePlusOffset(Ptr, TypeSize::getFixed(4), DL);
21337       SDValue St1 = DAG.getStore(Chain, DL, Hi, Ptr,
21338                                  ST->getPointerInfo().getWithOffset(4),
21339                                  ST->getOriginalAlign(), MMOFlags, AAInfo);
21340       return DAG.getNode(ISD::TokenFactor, DL, MVT::Other,
21341                          St0, St1);
21342     }
21343 
21344     return SDValue();
21345   }
21346 }
21347 
21348 // (store (insert_vector_elt (load p), x, i), p) -> (store x, p+offset)
21349 //
21350 // If a store of a load with an element inserted into it has no other
21351 // uses in between the chain, then we can consider the vector store
21352 // dead and replace it with just the single scalar element store.
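// A concrete sketch (types invented for illustration):
//   (store (insert_vector_elt (load <4 x i32> p), x, 2), p)
//     --> (store i32 x, p+8)
// provided nothing with side effects sits between the load and the store.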
21353 SDValue DAGCombiner::replaceStoreOfInsertLoad(StoreSDNode *ST) {
21354   SDLoc DL(ST);
21355   SDValue Value = ST->getValue();
21356   SDValue Ptr = ST->getBasePtr();
21357   SDValue Chain = ST->getChain();
21358   if (Value.getOpcode() != ISD::INSERT_VECTOR_ELT || !Value.hasOneUse())
21359     return SDValue();
21360 
21361   SDValue Elt = Value.getOperand(1);
21362   SDValue Idx = Value.getOperand(2);
21363 
21364   // If the element isn't byte sized or is implicitly truncated then we can't
21365   // compute an offset.
21366   EVT EltVT = Elt.getValueType();
21367   if (!EltVT.isByteSized() ||
21368       EltVT != Value.getOperand(0).getValueType().getVectorElementType())
21369     return SDValue();
21370 
21371   auto *Ld = dyn_cast<LoadSDNode>(Value.getOperand(0));
21372   if (!Ld || Ld->getBasePtr() != Ptr ||
21373       ST->getMemoryVT() != Ld->getMemoryVT() || !ST->isSimple() ||
21374       !ISD::isNormalStore(ST) ||
21375       Ld->getAddressSpace() != ST->getAddressSpace() ||
21376       !Chain.reachesChainWithoutSideEffects(SDValue(Ld, 1)))
21377     return SDValue();
21378 
21379   unsigned IsFast;
21380   if (!TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(),
21381                               Elt.getValueType(), ST->getAddressSpace(),
21382                               ST->getAlign(), ST->getMemOperand()->getFlags(),
21383                               &IsFast) ||
21384       !IsFast)
21385     return SDValue();
21386 
21387   MachinePointerInfo PointerInfo(ST->getAddressSpace());
21388 
21389   // If the offset is a known constant then try to recover the pointer
21390   // info.
21391   SDValue NewPtr;
21392   if (auto *CIdx = dyn_cast<ConstantSDNode>(Idx)) {
21393     unsigned COffset = CIdx->getSExtValue() * EltVT.getSizeInBits() / 8;
21394     NewPtr = DAG.getMemBasePlusOffset(Ptr, TypeSize::getFixed(COffset), DL);
21395     PointerInfo = ST->getPointerInfo().getWithOffset(COffset);
21396   } else {
21397     NewPtr = TLI.getVectorElementPointer(DAG, Ptr, Value.getValueType(), Idx);
21398   }
21399 
21400   return DAG.getStore(Chain, DL, Elt, NewPtr, PointerInfo, ST->getAlign(),
21401                       ST->getMemOperand()->getFlags());
21402 }
21403 
21404 SDValue DAGCombiner::visitATOMIC_STORE(SDNode *N) {
21405   AtomicSDNode *ST = cast<AtomicSDNode>(N);
21406   SDValue Val = ST->getVal();
21407   EVT VT = Val.getValueType();
21408   EVT MemVT = ST->getMemoryVT();
21409 
21410   if (MemVT.bitsLT(VT)) { // Is truncating store
21411     APInt TruncDemandedBits = APInt::getLowBitsSet(VT.getScalarSizeInBits(),
21412                                                    MemVT.getScalarSizeInBits());
21413     // See if we can simplify the operation with SimplifyDemandedBits, which
21414     // only works if the value has a single use.
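    // e.g. when an i32 value is stored as an i8, only the low 8 bits are
    // demanded, so a masking (and Val, 255) feeding the store can be removed
    // (an illustrative sketch, not an exhaustive description).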
21415     if (SimplifyDemandedBits(Val, TruncDemandedBits))
21416       return SDValue(N, 0);
21417   }
21418 
21419   return SDValue();
21420 }
21421 
21422 SDValue DAGCombiner::visitSTORE(SDNode *N) {
21423   StoreSDNode *ST  = cast<StoreSDNode>(N);
21424   SDValue Chain = ST->getChain();
21425   SDValue Value = ST->getValue();
21426   SDValue Ptr   = ST->getBasePtr();
21427 
21428   // If this is a store of a bit convert, store the input value if the
21429   // resultant store does not need a higher alignment than the original.
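  // e.g. (store (bitcast (v4i32 X) to v2i64), Ptr) --> (store v4i32 X, Ptr);
  // the types here are chosen purely for illustration.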
21430   if (Value.getOpcode() == ISD::BITCAST && !ST->isTruncatingStore() &&
21431       ST->isUnindexed()) {
21432     EVT SVT = Value.getOperand(0).getValueType();
21433     // If the store is volatile, we only want to change the store type if the
21434     // resulting store is legal. Otherwise we might increase the number of
21435     // memory accesses. We don't care if the original type was legal or not
21436     // as we assume software couldn't rely on the number of accesses of an
21437     // illegal type.
21438     // TODO: May be able to relax for unordered atomics (see D66309)
21439     if (((!LegalOperations && ST->isSimple()) ||
21440          TLI.isOperationLegal(ISD::STORE, SVT)) &&
21441         TLI.isStoreBitCastBeneficial(Value.getValueType(), SVT,
21442                                      DAG, *ST->getMemOperand())) {
21443       return DAG.getStore(Chain, SDLoc(N), Value.getOperand(0), Ptr,
21444                           ST->getMemOperand());
21445     }
21446   }
21447 
21448   // Turn 'store undef, Ptr' -> nothing.
21449   if (Value.isUndef() && ST->isUnindexed())
21450     return Chain;
21451 
21452   // Try to infer better alignment information than the store already has.
21453   if (OptLevel != CodeGenOptLevel::None && ST->isUnindexed() &&
21454       !ST->isAtomic()) {
21455     if (MaybeAlign Alignment = DAG.InferPtrAlign(Ptr)) {
21456       if (*Alignment > ST->getAlign() &&
21457           isAligned(*Alignment, ST->getSrcValueOffset())) {
21458         SDValue NewStore =
21459             DAG.getTruncStore(Chain, SDLoc(N), Value, Ptr, ST->getPointerInfo(),
21460                               ST->getMemoryVT(), *Alignment,
21461                               ST->getMemOperand()->getFlags(), ST->getAAInfo());
21462         // NewStore will always be N as we are only refining the alignment
21463         assert(NewStore.getNode() == N);
21464         (void)NewStore;
21465       }
21466     }
21467   }
21468 
21469   // Try transforming a pair of floating-point load / store ops to integer
21470   // load / store ops.
21471   if (SDValue NewST = TransformFPLoadStorePair(N))
21472     return NewST;
21473 
21474   // Try transforming several stores into STORE (BSWAP).
21475   if (SDValue Store = mergeTruncStores(ST))
21476     return Store;
21477 
21478   if (ST->isUnindexed()) {
21479     // Walk up chain skipping non-aliasing memory nodes, on this store and any
21480     // adjacent stores.
21481     if (findBetterNeighborChains(ST)) {
21482       // replaceStoreChain uses CombineTo, which handles all of the worklist
21483       // manipulation. Return the original node so that nothing else is done.
21484       return SDValue(ST, 0);
21485     }
21486     Chain = ST->getChain();
21487   }
21488 
21489   // FIXME: is there such a thing as a truncating indexed store?
21490   if (ST->isTruncatingStore() && ST->isUnindexed() &&
21491       Value.getValueType().isInteger() &&
21492       (!isa<ConstantSDNode>(Value) ||
21493        !cast<ConstantSDNode>(Value)->isOpaque())) {
21494     // Convert a truncating store of an extension into a standard store.
21495     if ((Value.getOpcode() == ISD::ZERO_EXTEND ||
21496          Value.getOpcode() == ISD::SIGN_EXTEND ||
21497          Value.getOpcode() == ISD::ANY_EXTEND) &&
21498         Value.getOperand(0).getValueType() == ST->getMemoryVT() &&
21499         TLI.isOperationLegalOrCustom(ISD::STORE, ST->getMemoryVT()))
21500       return DAG.getStore(Chain, SDLoc(N), Value.getOperand(0), Ptr,
21501                           ST->getMemOperand());
21502 
21503     APInt TruncDemandedBits =
21504         APInt::getLowBitsSet(Value.getScalarValueSizeInBits(),
21505                              ST->getMemoryVT().getScalarSizeInBits());
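    // e.g. for an i32 value truncstored as i16, TruncDemandedBits is
    // 0x0000FFFF: only the low 16 bits of Value are demanded.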
21506 
21507     // See if we can simplify the operation with SimplifyDemandedBits, which
21508     // only works if the value has a single use.
21509     AddToWorklist(Value.getNode());
21510     if (SimplifyDemandedBits(Value, TruncDemandedBits)) {
21511       // Re-visit the store if anything changed and the store hasn't been merged
21512       // with another node (N is deleted). SimplifyDemandedBits will add Value's
21513       // node back to the worklist if necessary, but we also need to re-visit
21514       // the Store node itself.
21515       if (N->getOpcode() != ISD::DELETED_NODE)
21516         AddToWorklist(N);
21517       return SDValue(N, 0);
21518     }
21519 
21520     // Otherwise, see if we can simplify the input to this truncstore with
21521     // knowledge that only the low bits are being used.  For example:
21522     // "truncstore (or (shl x, 8), y), i8"  -> "truncstore y, i8"
21523     if (SDValue Shorter =
21524             TLI.SimplifyMultipleUseDemandedBits(Value, TruncDemandedBits, DAG))
21525       return DAG.getTruncStore(Chain, SDLoc(N), Shorter, Ptr, ST->getMemoryVT(),
21526                                ST->getMemOperand());
21527 
21528     // If we're storing a truncated constant, see if we can simplify it.
21529     // TODO: Move this to targetShrinkDemandedConstant?
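    // e.g. truncstoring the i32 constant 0x12345678 as i16 only needs the
    // constant 0x5678, which may be cheaper to materialize.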
21530     if (auto *Cst = dyn_cast<ConstantSDNode>(Value))
21531       if (!Cst->isOpaque()) {
21532         const APInt &CValue = Cst->getAPIntValue();
21533         APInt NewVal = CValue & TruncDemandedBits;
21534         if (NewVal != CValue) {
21535           SDValue Shorter =
21536               DAG.getConstant(NewVal, SDLoc(N), Value.getValueType());
21537           return DAG.getTruncStore(Chain, SDLoc(N), Shorter, Ptr,
21538                                    ST->getMemoryVT(), ST->getMemOperand());
21539         }
21540       }
21541   }
21542 
21543   // If this is a load followed by a store to the same location, then the store
21544   // is dead/noop. Peek through any truncates if canCombineTruncStore failed.
21545   // TODO: Add big-endian truncate support with test coverage.
21546   // TODO: Can relax for unordered atomics (see D66309)
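  // e.g. (store (load p), p) with no intervening side effects is a no-op;
  // the store folds to its input chain.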
21547   SDValue TruncVal = DAG.getDataLayout().isLittleEndian()
21548                          ? peekThroughTruncates(Value)
21549                          : Value;
21550   if (auto *Ld = dyn_cast<LoadSDNode>(TruncVal)) {
21551     if (Ld->getBasePtr() == Ptr && ST->getMemoryVT() == Ld->getMemoryVT() &&
21552         ST->isUnindexed() && ST->isSimple() &&
21553         Ld->getAddressSpace() == ST->getAddressSpace() &&
21554         // There can't be any side effects between the load and store, such as
21555         // a call or store.
21556         Chain.reachesChainWithoutSideEffects(SDValue(Ld, 1))) {
21557       // The store is dead, remove it.
21558       return Chain;
21559     }
21560   }
21561 
21562   // Try scalarizing vector stores of loads where we only change one element
21563   if (SDValue NewST = replaceStoreOfInsertLoad(ST))
21564     return NewST;
21565 
21566   // TODO: Can relax for unordered atomics (see D66309)
21567   if (StoreSDNode *ST1 = dyn_cast<StoreSDNode>(Chain)) {
21568     if (ST->isUnindexed() && ST->isSimple() &&
21569         ST1->isUnindexed() && ST1->isSimple()) {
21570       if (OptLevel != CodeGenOptLevel::None && ST1->getBasePtr() == Ptr &&
21571           ST1->getValue() == Value && ST->getMemoryVT() == ST1->getMemoryVT() &&
21572           ST->getAddressSpace() == ST1->getAddressSpace()) {
21573         // If this is a store followed by a store with the same value to the
21574         // same location, then the store is dead/noop.
21575         return Chain;
21576       }
21577 
21578       if (OptLevel != CodeGenOptLevel::None && ST1->hasOneUse() &&
21579           !ST1->getBasePtr().isUndef() &&
21580           ST->getAddressSpace() == ST1->getAddressSpace()) {
21581         // If one of the two stores has a scalable vector type and the
21582         // other a larger fixed-size type, we cannot allow removal of the
21583         // scalable store, because we don't know its final size until
21584         // runtime.
21585         if (ST->getMemoryVT().isScalableVector() ||
21586             ST1->getMemoryVT().isScalableVector()) {
21587           if (ST1->getBasePtr() == Ptr &&
21588               TypeSize::isKnownLE(ST1->getMemoryVT().getStoreSize(),
21589                                   ST->getMemoryVT().getStoreSize())) {
21590             CombineTo(ST1, ST1->getChain());
21591             return SDValue(N, 0);
21592           }
21593         } else {
21594           const BaseIndexOffset STBase = BaseIndexOffset::match(ST, DAG);
21595           const BaseIndexOffset ChainBase = BaseIndexOffset::match(ST1, DAG);
21596           // If the preceding store writes to a subset of the current store's
21597           // location and no other node is chained to that store, we can
21598           // effectively drop that preceding store. Do not remove stores to
21599           // undef as they may be used as data sinks.
21600           if (STBase.contains(DAG, ST->getMemoryVT().getFixedSizeInBits(),
21601                               ChainBase,
21602                               ST1->getMemoryVT().getFixedSizeInBits())) {
21603             CombineTo(ST1, ST1->getChain());
21604             return SDValue(N, 0);
21605           }
21606         }
21607       }
21608     }
21609   }
21610 
21611   // If this is an FP_ROUND or TRUNC followed by a store, fold this into a
21612   // truncating store.  We can do this even if this is already a truncstore.
21613   if ((Value.getOpcode() == ISD::FP_ROUND ||
21614        Value.getOpcode() == ISD::TRUNCATE) &&
21615       Value->hasOneUse() && ST->isUnindexed() &&
21616       TLI.canCombineTruncStore(Value.getOperand(0).getValueType(),
21617                                ST->getMemoryVT(), LegalOperations)) {
21618     return DAG.getTruncStore(Chain, SDLoc(N), Value.getOperand(0),
21619                              Ptr, ST->getMemoryVT(), ST->getMemOperand());
21620   }
21621 
21622   // Always perform this optimization before types are legal. If the target
21623   // prefers, also try this after legalization to catch stores that were created
21624   // by intrinsics or other nodes.
21625   if (!LegalTypes || (TLI.mergeStoresAfterLegalization(ST->getMemoryVT()))) {
21626     while (true) {
21627       // There can be multiple store sequences on the same chain.
21628       // Keep trying to merge store sequences until we are unable to do so
21629       // or until we merge the last store on the chain.
21630       bool Changed = mergeConsecutiveStores(ST);
21631       if (!Changed) break;
21632       // Return N as merge only uses CombineTo and no worklist clean
21633       // up is necessary.
21634       if (N->getOpcode() == ISD::DELETED_NODE || !isa<StoreSDNode>(N))
21635         return SDValue(N, 0);
21636     }
21637   }
21638 
21639   // Try transforming N to an indexed store.
21640   if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
21641     return SDValue(N, 0);
21642 
21643   // Turn 'store float 1.0, Ptr' -> 'store int 0x3F800000, Ptr'
21644   //
21645   // Make sure to do this only after attempting to merge stores in order to
21646   //  avoid changing the types of some subset of stores due to visit order,
21647   //  preventing their merging.
21648   if (isa<ConstantFPSDNode>(ST->getValue())) {
21649     if (SDValue NewSt = replaceStoreOfFPConstant(ST))
21650       return NewSt;
21651   }
21652 
21653   if (SDValue NewSt = splitMergedValStore(ST))
21654     return NewSt;
21655 
21656   return ReduceLoadOpStoreWidth(N);
21657 }
21658 
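// If a store writes purely within an object just before that object's
// LIFETIME_END, the stored value can never be observed, so the store can be
// dropped (see the in-loop check below).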
21659 SDValue DAGCombiner::visitLIFETIME_END(SDNode *N) {
21660   const auto *LifetimeEnd = cast<LifetimeSDNode>(N);
21661   if (!LifetimeEnd->hasOffset())
21662     return SDValue();
21663 
21664   const BaseIndexOffset LifetimeEndBase(N->getOperand(1), SDValue(),
21665                                         LifetimeEnd->getOffset(), false);
21666 
21667   // We walk up the chains to find stores.
21668   SmallVector<SDValue, 8> Chains = {N->getOperand(0)};
21669   while (!Chains.empty()) {
21670     SDValue Chain = Chains.pop_back_val();
21671     if (!Chain.hasOneUse())
21672       continue;
21673     switch (Chain.getOpcode()) {
21674     case ISD::TokenFactor:
21675       for (unsigned Nops = Chain.getNumOperands(); Nops;)
21676         Chains.push_back(Chain.getOperand(--Nops));
21677       break;
21678     case ISD::LIFETIME_START:
21679     case ISD::LIFETIME_END:
21680       // We can forward past any lifetime start/end that can be proven not to
21681       // alias the node.
21682       if (!mayAlias(Chain.getNode(), N))
21683         Chains.push_back(Chain.getOperand(0));
21684       break;
21685     case ISD::STORE: {
21686       StoreSDNode *ST = dyn_cast<StoreSDNode>(Chain);
21687       // TODO: Can relax for unordered atomics (see D66309)
21688       if (!ST->isSimple() || ST->isIndexed())
21689         continue;
21690       const TypeSize StoreSize = ST->getMemoryVT().getStoreSize();
21691       // The bounds of a scalable store are not known until runtime, so this
21692       // store cannot be elided.
21693       if (StoreSize.isScalable())
21694         continue;
21695       const BaseIndexOffset StoreBase = BaseIndexOffset::match(ST, DAG);
21696       // If we store purely within object bounds just before its lifetime ends,
21697       // we can remove the store.
21698       if (LifetimeEndBase.contains(DAG, LifetimeEnd->getSize() * 8, StoreBase,
21699                                    StoreSize.getFixedValue() * 8)) {
21700         LLVM_DEBUG(dbgs() << "\nRemoving store:"; StoreBase.dump();
21701                    dbgs() << "\nwithin LIFETIME_END of : ";
21702                    LifetimeEndBase.dump(); dbgs() << "\n");
21703         CombineTo(ST, ST->getChain());
21704         return SDValue(N, 0);
21705       }
21706     }
21707     }
21708   }
21709   return SDValue();
21710 }
21711 
21712 /// In the store instruction sequence below, the F and I values
21713 /// are bundled together as an i64 value before being stored into memory.
21714 /// Sometimes it is more efficient to generate separate stores for F and I,
21715 /// which can remove the bitwise instructions or sink them to colder places.
21716 ///
21717 ///   (store (or (zext (bitcast F to i32) to i64),
21718 ///              (shl (zext I to i64), 32)), addr)  -->
21719 ///   (store F, addr) and (store I, addr+4)
21720 ///
21721 /// Similarly, splitting for other merged stores can also be beneficial, like:
21722 /// For pair of {i32, i32}, i64 store --> two i32 stores.
21723 /// For pair of {i32, i16}, i64 store --> two i32 stores.
21724 /// For pair of {i16, i16}, i32 store --> two i16 stores.
21725 /// For pair of {i16, i8},  i32 store --> two i16 stores.
21726 /// For pair of {i8, i8},   i16 store --> two i8 stores.
21727 ///
21728 /// We allow each target to determine specifically which kind of splitting is
21729 /// supported.
21730 ///
21731 /// The store patterns are commonly seen from the simple code snippet below
21732 /// if only std::make_pair(...) is SROA-transformed before being inlined into hoo.
21733 ///   void goo(const std::pair<int, float> &);
21734 ///   hoo() {
21735 ///     ...
21736 ///     goo(std::make_pair(tmp, ftmp));
21737 ///     ...
21738 ///   }
21739 ///
21740 SDValue DAGCombiner::splitMergedValStore(StoreSDNode *ST) {
21741   if (OptLevel == CodeGenOptLevel::None)
21742     return SDValue();
21743 
21744   // Can't change the number of memory accesses for a volatile store or break
21745   // atomicity for an atomic one.
21746   if (!ST->isSimple())
21747     return SDValue();
21748 
21749   SDValue Val = ST->getValue();
21750   SDLoc DL(ST);
21751 
21752   // Match OR operand.
21753   if (!Val.getValueType().isScalarInteger() || Val.getOpcode() != ISD::OR)
21754     return SDValue();
21755 
21756   // Match SHL operand and get Lower and Higher parts of Val.
21757   SDValue Op1 = Val.getOperand(0);
21758   SDValue Op2 = Val.getOperand(1);
21759   SDValue Lo, Hi;
21760   if (Op1.getOpcode() != ISD::SHL) {
21761     std::swap(Op1, Op2);
21762     if (Op1.getOpcode() != ISD::SHL)
21763       return SDValue();
21764   }
21765   Lo = Op2;
21766   Hi = Op1.getOperand(0);
21767   if (!Op1.hasOneUse())
21768     return SDValue();
21769 
21770   // Match shift amount to HalfValBitSize.
21771   unsigned HalfValBitSize = Val.getValueSizeInBits() / 2;
21772   ConstantSDNode *ShAmt = dyn_cast<ConstantSDNode>(Op1.getOperand(1));
21773   if (!ShAmt || ShAmt->getAPIntValue() != HalfValBitSize)
21774     return SDValue();
21775 
21776   // Lo and Hi are zero-extended from integer types no wider than
21777   // HalfValBitSize (e.g. from i32 to i64).
21778   if (Lo.getOpcode() != ISD::ZERO_EXTEND || !Lo.hasOneUse() ||
21779       !Lo.getOperand(0).getValueType().isScalarInteger() ||
21780       Lo.getOperand(0).getValueSizeInBits() > HalfValBitSize ||
21781       Hi.getOpcode() != ISD::ZERO_EXTEND || !Hi.hasOneUse() ||
21782       !Hi.getOperand(0).getValueType().isScalarInteger() ||
21783       Hi.getOperand(0).getValueSizeInBits() > HalfValBitSize)
21784     return SDValue();
21785 
21786   // Use the EVT of the low and high parts before bitcast as the input
21787   // to the target query.
21788   EVT LowTy = (Lo.getOperand(0).getOpcode() == ISD::BITCAST)
21789                   ? Lo.getOperand(0).getValueType()
21790                   : Lo.getValueType();
21791   EVT HighTy = (Hi.getOperand(0).getOpcode() == ISD::BITCAST)
21792                    ? Hi.getOperand(0).getValueType()
21793                    : Hi.getValueType();
21794   if (!TLI.isMultiStoresCheaperThanBitsMerge(LowTy, HighTy))
21795     return SDValue();
21796 
21797   // Start to split store.
21798   MachineMemOperand::Flags MMOFlags = ST->getMemOperand()->getFlags();
21799   AAMDNodes AAInfo = ST->getAAInfo();
21800 
21801   // Change the sizes of Lo and Hi's value types to HalfValBitSize.
21802   EVT VT = EVT::getIntegerVT(*DAG.getContext(), HalfValBitSize);
21803   Lo = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Lo.getOperand(0));
21804   Hi = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Hi.getOperand(0));
21805 
21806   SDValue Chain = ST->getChain();
21807   SDValue Ptr = ST->getBasePtr();
21808   // Lower value store.
21809   SDValue St0 = DAG.getStore(Chain, DL, Lo, Ptr, ST->getPointerInfo(),
21810                              ST->getOriginalAlign(), MMOFlags, AAInfo);
21811   Ptr =
21812       DAG.getMemBasePlusOffset(Ptr, TypeSize::getFixed(HalfValBitSize / 8), DL);
21813   // Higher value store.
21814   SDValue St1 = DAG.getStore(
21815       St0, DL, Hi, Ptr, ST->getPointerInfo().getWithOffset(HalfValBitSize / 8),
21816       ST->getOriginalAlign(), MMOFlags, AAInfo);
21817   return St1;
21818 }
21819 
21820 // Merge an insertion into an existing shuffle:
21821 // (insert_vector_elt (vector_shuffle X, Y, Mask),
21822 //                    (extract_vector_elt X, N), InsIndex)
21823 //   --> (vector_shuffle X, Y, NewMask)
21824 //  and variations where shuffle operands may be CONCAT_VECTORS.
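// A worked sketch (values invented): with v4i32 X and Y, Mask = <0,1,6,7>,
// Elt = (extract_vector_elt Y, 2) and InsIndex = 1, Y's lane 2 is shuffle
// index 4 + 2 = 6, so NewMask becomes <0,6,6,7>.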
21825 static bool mergeEltWithShuffle(SDValue &X, SDValue &Y, ArrayRef<int> Mask,
21826                                 SmallVectorImpl<int> &NewMask, SDValue Elt,
21827                                 unsigned InsIndex) {
21828   if (Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
21829       !isa<ConstantSDNode>(Elt.getOperand(1)))
21830     return false;
21831 
21832   // Vec's operand 0 is using indices from 0 to N-1 and
21833   // operand 1 from N to 2N - 1, where N is the number of
21834   // elements in the vectors.
21835   SDValue InsertVal0 = Elt.getOperand(0);
21836   int ElementOffset = -1;
21837 
21838   // We explore the inputs of the shuffle in order to see if we find the
21839   // source of the extract_vector_elt. If so, we can use it to modify the
21840   // shuffle rather than perform an insert_vector_elt.
21841   SmallVector<std::pair<int, SDValue>, 8> ArgWorkList;
21842   ArgWorkList.emplace_back(Mask.size(), Y);
21843   ArgWorkList.emplace_back(0, X);
21844 
21845   while (!ArgWorkList.empty()) {
21846     int ArgOffset;
21847     SDValue ArgVal;
21848     std::tie(ArgOffset, ArgVal) = ArgWorkList.pop_back_val();
21849 
21850     if (ArgVal == InsertVal0) {
21851       ElementOffset = ArgOffset;
21852       break;
21853     }
21854 
21855     // Peek through concat_vector.
21856     if (ArgVal.getOpcode() == ISD::CONCAT_VECTORS) {
21857       int CurrentArgOffset =
21858           ArgOffset + ArgVal.getValueType().getVectorNumElements();
21859       int Step = ArgVal.getOperand(0).getValueType().getVectorNumElements();
21860       for (SDValue Op : reverse(ArgVal->ops())) {
21861         CurrentArgOffset -= Step;
21862         ArgWorkList.emplace_back(CurrentArgOffset, Op);
21863       }
21864 
21865       // Make sure we went through all the elements and did not screw up index
21866       // computation.
21867       assert(CurrentArgOffset == ArgOffset);
21868     }
21869   }
21870 
21871   // If we failed to find a match, see if we can replace an UNDEF shuffle
21872   // operand.
21873   if (ElementOffset == -1) {
21874     if (!Y.isUndef() || InsertVal0.getValueType() != Y.getValueType())
21875       return false;
21876     ElementOffset = Mask.size();
21877     Y = InsertVal0;
21878   }
21879 
21880   NewMask.assign(Mask.begin(), Mask.end());
21881   NewMask[InsIndex] = ElementOffset + Elt.getConstantOperandVal(1);
21882   assert(NewMask[InsIndex] < (int)(2 * Mask.size()) && NewMask[InsIndex] >= 0 &&
21883          "NewMask[InsIndex] is out of bound");
21884   return true;
21885 }
21886 
21887 // Merge an insertion into an existing shuffle:
21888 // (insert_vector_elt (vector_shuffle X, Y), (extract_vector_elt X, N),
21889 // InsIndex)
21890 //   --> (vector_shuffle X, Y) and variations where shuffle operands may be
21891 //   CONCAT_VECTORS.
21892 SDValue DAGCombiner::mergeInsertEltWithShuffle(SDNode *N, unsigned InsIndex) {
21893   assert(N->getOpcode() == ISD::INSERT_VECTOR_ELT &&
21894          "Expected extract_vector_elt");
21895   SDValue InsertVal = N->getOperand(1);
21896   SDValue Vec = N->getOperand(0);
21897 
21898   auto *SVN = dyn_cast<ShuffleVectorSDNode>(Vec);
21899   if (!SVN || !Vec.hasOneUse())
21900     return SDValue();
21901 
21902   ArrayRef<int> Mask = SVN->getMask();
21903   SDValue X = Vec.getOperand(0);
21904   SDValue Y = Vec.getOperand(1);
21905 
21906   SmallVector<int, 16> NewMask(Mask);
21907   if (mergeEltWithShuffle(X, Y, Mask, NewMask, InsertVal, InsIndex)) {
21908     SDValue LegalShuffle = TLI.buildLegalVectorShuffle(
21909         Vec.getValueType(), SDLoc(N), X, Y, NewMask, DAG);
21910     if (LegalShuffle)
21911       return LegalShuffle;
21912   }
21913 
21914   return SDValue();
21915 }
21916 
21917 // Convert a disguised subvector insertion into a shuffle:
21918 // insert_vector_elt V, (bitcast X from vector type), IdxC -->
21919 // bitcast(shuffle (bitcast V), (extended X), Mask)
21920 // Note: We do not use an insert_subvector node because that requires a
21921 // legal subvector type.
21922 SDValue DAGCombiner::combineInsertEltToShuffle(SDNode *N, unsigned InsIndex) {
21923   assert(N->getOpcode() == ISD::INSERT_VECTOR_ELT &&
21924          "Expected extract_vector_elt");
21925   SDValue InsertVal = N->getOperand(1);
21926 
21927   if (InsertVal.getOpcode() != ISD::BITCAST || !InsertVal.hasOneUse() ||
21928       !InsertVal.getOperand(0).getValueType().isVector())
21929     return SDValue();
21930 
21931   SDValue SubVec = InsertVal.getOperand(0);
21932   SDValue DestVec = N->getOperand(0);
21933   EVT SubVecVT = SubVec.getValueType();
21934   EVT VT = DestVec.getValueType();
21935   unsigned NumSrcElts = SubVecVT.getVectorNumElements();
21936   // If the source only has a single vector element, the cost of creating a
21937   // whole new vector for it is likely to exceed the cost of an insert_vector_elt.
21938   if (NumSrcElts == 1)
21939     return SDValue();
21940   unsigned ExtendRatio = VT.getSizeInBits() / SubVecVT.getSizeInBits();
21941   unsigned NumMaskVals = ExtendRatio * NumSrcElts;
21942 
21943   // Step 1: Create a shuffle mask that implements this insert operation. The
21944   // vector that we are inserting into will be operand 0 of the shuffle, so
21945   // those elements are just 'i'. The inserted subvector is in the first
21946   // positions of operand 1 of the shuffle. Example:
21947   // insert v4i32 V, (v2i16 X), 2 --> shuffle v8i16 V', X', {0,1,2,3,8,9,6,7}
21948   SmallVector<int, 16> Mask(NumMaskVals);
21949   for (unsigned i = 0; i != NumMaskVals; ++i) {
21950     if (i / NumSrcElts == InsIndex)
21951       Mask[i] = (i % NumSrcElts) + NumMaskVals;
21952     else
21953       Mask[i] = i;
21954   }
21955 
21956   // Bail out if the target cannot handle the shuffle we want to create.
21957   EVT SubVecEltVT = SubVecVT.getVectorElementType();
21958   EVT ShufVT = EVT::getVectorVT(*DAG.getContext(), SubVecEltVT, NumMaskVals);
21959   if (!TLI.isShuffleMaskLegal(Mask, ShufVT))
21960     return SDValue();
21961 
21962   // Step 2: Create a wide vector from the inserted source vector by appending
21963   // undefined elements. This is the same size as our destination vector.
21964   SDLoc DL(N);
21965   SmallVector<SDValue, 8> ConcatOps(ExtendRatio, DAG.getUNDEF(SubVecVT));
21966   ConcatOps[0] = SubVec;
21967   SDValue PaddedSubV = DAG.getNode(ISD::CONCAT_VECTORS, DL, ShufVT, ConcatOps);
21968 
21969   // Step 3: Shuffle in the padded subvector.
21970   SDValue DestVecBC = DAG.getBitcast(ShufVT, DestVec);
21971   SDValue Shuf = DAG.getVectorShuffle(ShufVT, DL, DestVecBC, PaddedSubV, Mask);
21972   AddToWorklist(PaddedSubV.getNode());
21973   AddToWorklist(DestVecBC.getNode());
21974   AddToWorklist(Shuf.getNode());
21975   return DAG.getBitcast(VT, Shuf);
21976 }
21977 
21978 // Combine insert(shuffle(load, <u,0,1,2>), load, 0) into a single load if
21979 // possible and the new load will be quick. We use more loads but less shuffles
21980 // and inserts.
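// A sketch with invented operands: if %v = (load <4 x i32>, p) and
// %s = (load i32, p-4), then (insert (shuffle %v, <u,0,1,2>), %s, 0) can
// become a single (load <4 x i32>, p-4), provided the combined unaligned
// load is legal and fast for the target.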
21981 SDValue DAGCombiner::combineInsertEltToLoad(SDNode *N, unsigned InsIndex) {
21982   EVT VT = N->getValueType(0);
21983 
21984   // InsIndex is expected to be the first or last lane.
21985   if (!VT.isFixedLengthVector() ||
21986       (InsIndex != 0 && InsIndex != VT.getVectorNumElements() - 1))
21987     return SDValue();
21988 
21989   // Look for a shuffle with the mask u,0,1,2,3,4,5,6 or 1,2,3,4,5,6,7,u
21990   // depending on the InsIndex.
21991   auto *Shuffle = dyn_cast<ShuffleVectorSDNode>(N->getOperand(0));
21992   SDValue Scalar = N->getOperand(1);
21993   if (!Shuffle || !all_of(enumerate(Shuffle->getMask()), [&](auto P) {
21994         return InsIndex == P.index() || P.value() < 0 ||
21995                (InsIndex == 0 && P.value() == (int)P.index() - 1) ||
21996                (InsIndex == VT.getVectorNumElements() - 1 &&
21997                 P.value() == (int)P.index() + 1);
21998       }))
21999     return SDValue();
22000 
22001   // We optionally skip over an extend so long as both loads are extended in the
22002   // same way from the same type.
22003   unsigned Extend = 0;
22004   if (Scalar.getOpcode() == ISD::ZERO_EXTEND ||
22005       Scalar.getOpcode() == ISD::SIGN_EXTEND ||
22006       Scalar.getOpcode() == ISD::ANY_EXTEND) {
22007     Extend = Scalar.getOpcode();
22008     Scalar = Scalar.getOperand(0);
22009   }
22010 
22011   auto *ScalarLoad = dyn_cast<LoadSDNode>(Scalar);
22012   if (!ScalarLoad)
22013     return SDValue();
22014 
22015   SDValue Vec = Shuffle->getOperand(0);
22016   if (Extend) {
22017     if (Vec.getOpcode() != Extend)
22018       return SDValue();
22019     Vec = Vec.getOperand(0);
22020   }
22021   auto *VecLoad = dyn_cast<LoadSDNode>(Vec);
22022   if (!VecLoad || Vec.getValueType().getScalarType() != Scalar.getValueType())
22023     return SDValue();
22024 
22025   int EltSize = ScalarLoad->getValueType(0).getScalarSizeInBits();
22026   if (EltSize == 0 || EltSize % 8 != 0 || !ScalarLoad->isSimple() ||
22027       !VecLoad->isSimple() || VecLoad->getExtensionType() != ISD::NON_EXTLOAD ||
22028       ScalarLoad->getExtensionType() != ISD::NON_EXTLOAD ||
22029       ScalarLoad->getAddressSpace() != VecLoad->getAddressSpace())
22030     return SDValue();
22031 
22032   // Check that the offset between the pointers produces a single contiguous
22033   // load.
22034   if (InsIndex == 0) {
22035     if (!DAG.areNonVolatileConsecutiveLoads(ScalarLoad, VecLoad, EltSize / 8,
22036                                             -1))
22037       return SDValue();
22038   } else {
22039     if (!DAG.areNonVolatileConsecutiveLoads(
22040             VecLoad, ScalarLoad, VT.getVectorNumElements() * EltSize / 8, -1))
22041       return SDValue();
22042   }
22043 
22044   // And that the new unaligned load will be fast.
22045   unsigned IsFast = 0;
22046   Align NewAlign = commonAlignment(VecLoad->getAlign(), EltSize / 8);
22047   if (!TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(),
22048                               Vec.getValueType(), VecLoad->getAddressSpace(),
22049                               NewAlign, VecLoad->getMemOperand()->getFlags(),
22050                               &IsFast) ||
22051       !IsFast)
22052     return SDValue();
22053 
22054   // Calculate the new Ptr and create the new load.
22055   SDLoc DL(N);
22056   SDValue Ptr = ScalarLoad->getBasePtr();
22057   if (InsIndex != 0)
22058     Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), VecLoad->getBasePtr(),
22059                       DAG.getConstant(EltSize / 8, DL, Ptr.getValueType()));
22060   MachinePointerInfo PtrInfo =
22061       InsIndex == 0 ? ScalarLoad->getPointerInfo()
22062                     : VecLoad->getPointerInfo().getWithOffset(EltSize / 8);
22063 
22064   SDValue Load = DAG.getLoad(VecLoad->getValueType(0), DL,
22065                              ScalarLoad->getChain(), Ptr, PtrInfo, NewAlign);
22066   DAG.makeEquivalentMemoryOrdering(ScalarLoad, Load.getValue(1));
22067   DAG.makeEquivalentMemoryOrdering(VecLoad, Load.getValue(1));
22068   return Extend ? DAG.getNode(Extend, DL, VT, Load) : Load;
22069 }
22070 
22071 SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) {
22072   SDValue InVec = N->getOperand(0);
22073   SDValue InVal = N->getOperand(1);
22074   SDValue EltNo = N->getOperand(2);
22075   SDLoc DL(N);
22076 
22077   EVT VT = InVec.getValueType();
22078   auto *IndexC = dyn_cast<ConstantSDNode>(EltNo);
22079 
22080   // Inserting into an out-of-bounds element is undefined.
22081   if (IndexC && VT.isFixedLengthVector() &&
22082       IndexC->getZExtValue() >= VT.getVectorNumElements())
22083     return DAG.getUNDEF(VT);
22084 
22085   // Remove redundant insertions:
22086   // (insert_vector_elt x (extract_vector_elt x idx) idx) -> x
22087   if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
22088       InVec == InVal.getOperand(0) && EltNo == InVal.getOperand(1))
22089     return InVec;
22090 
22091   if (!IndexC) {
22092     // If this is variable insert to undef vector, it might be better to splat:
22093     // inselt undef, InVal, EltNo --> build_vector < InVal, InVal, ... >
22094     if (InVec.isUndef() && TLI.shouldSplatInsEltVarIndex(VT))
22095       return DAG.getSplat(VT, DL, InVal);
22096     return SDValue();
22097   }
22098 
22099   if (VT.isScalableVector())
22100     return SDValue();
22101 
22102   unsigned NumElts = VT.getVectorNumElements();
22103 
22104   // We must know which element is being inserted for folds below here.
22105   unsigned Elt = IndexC->getZExtValue();
22106 
22107   // Handle <1 x ???> vector insertion special cases.
22108   if (NumElts == 1) {
22109     // insert_vector_elt(x, extract_vector_elt(y, 0), 0) -> y
22110     if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
22111         InVal.getOperand(0).getValueType() == VT &&
22112         isNullConstant(InVal.getOperand(1)))
22113       return InVal.getOperand(0);
22114   }
22115 
22116   // Canonicalize insert_vector_elt dag nodes.
22117   // Example:
22118   // (insert_vector_elt (insert_vector_elt A, Idx0), Idx1)
22119   // -> (insert_vector_elt (insert_vector_elt A, Idx1), Idx0)
22120   //
22121   // Do this only if the child insert_vector node has one use; also
22122   // do this only if indices are both constants and Idx1 < Idx0.
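  // e.g. (insert_vector_elt (insert_vector_elt A, x, 3), y, 1)
  //   -> (insert_vector_elt (insert_vector_elt A, y, 1), x, 3)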
22123   if (InVec.getOpcode() == ISD::INSERT_VECTOR_ELT && InVec.hasOneUse()
22124       && isa<ConstantSDNode>(InVec.getOperand(2))) {
22125     unsigned OtherElt = InVec.getConstantOperandVal(2);
22126     if (Elt < OtherElt) {
22127       // Swap nodes.
22128       SDValue NewOp = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VT,
22129                                   InVec.getOperand(0), InVal, EltNo);
22130       AddToWorklist(NewOp.getNode());
22131       return DAG.getNode(ISD::INSERT_VECTOR_ELT, SDLoc(InVec.getNode()),
22132                          VT, NewOp, InVec.getOperand(1), InVec.getOperand(2));
22133     }
22134   }
22135 
22136   if (SDValue Shuf = mergeInsertEltWithShuffle(N, Elt))
22137     return Shuf;
22138 
22139   if (SDValue Shuf = combineInsertEltToShuffle(N, Elt))
22140     return Shuf;
22141 
22142   if (SDValue Shuf = combineInsertEltToLoad(N, Elt))
22143     return Shuf;
22144 
22145   // Attempt to convert an insert_vector_elt chain into a legal build_vector.
22146   if (!LegalOperations || TLI.isOperationLegal(ISD::BUILD_VECTOR, VT)) {
22147     // Single-element vector - we don't need to recurse.
22148     if (NumElts == 1)
22149       return DAG.getBuildVector(VT, DL, {InVal});
22150 
22151     // If we haven't already collected the element, insert into the op list.
22152     EVT MaxEltVT = InVal.getValueType();
22153     auto AddBuildVectorOp = [&](SmallVectorImpl<SDValue> &Ops, SDValue Elt,
22154                                 unsigned Idx) {
22155       if (!Ops[Idx]) {
22156         Ops[Idx] = Elt;
22157         if (VT.isInteger()) {
22158           EVT EltVT = Elt.getValueType();
22159           MaxEltVT = MaxEltVT.bitsGE(EltVT) ? MaxEltVT : EltVT;
22160         }
22161       }
22162     };
22163 
22164     // Ensure all the operands are the same value type, fill any missing
22165     // operands with UNDEF and create the BUILD_VECTOR.
22166     auto CanonicalizeBuildVector = [&](SmallVectorImpl<SDValue> &Ops) {
22167       assert(Ops.size() == NumElts && "Unexpected vector size");
22168       for (SDValue &Op : Ops) {
22169         if (Op)
22170           Op = VT.isInteger() ? DAG.getAnyExtOrTrunc(Op, DL, MaxEltVT) : Op;
22171         else
22172           Op = DAG.getUNDEF(MaxEltVT);
22173       }
22174       return DAG.getBuildVector(VT, DL, Ops);
22175     };
22176 
22177     SmallVector<SDValue, 8> Ops(NumElts, SDValue());
22178     Ops[Elt] = InVal;
22179 
22180     // Recurse up a INSERT_VECTOR_ELT chain to build a BUILD_VECTOR.
22181     for (SDValue CurVec = InVec; CurVec;) {
22182       // UNDEF - build new BUILD_VECTOR from already inserted operands.
22183       if (CurVec.isUndef())
22184         return CanonicalizeBuildVector(Ops);
22185 
22186       // BUILD_VECTOR - insert unused operands and build new BUILD_VECTOR.
22187       if (CurVec.getOpcode() == ISD::BUILD_VECTOR && CurVec.hasOneUse()) {
22188         for (unsigned I = 0; I != NumElts; ++I)
22189           AddBuildVectorOp(Ops, CurVec.getOperand(I), I);
22190         return CanonicalizeBuildVector(Ops);
22191       }
22192 
22193       // SCALAR_TO_VECTOR - insert unused scalar and build new BUILD_VECTOR.
22194       if (CurVec.getOpcode() == ISD::SCALAR_TO_VECTOR && CurVec.hasOneUse()) {
22195         AddBuildVectorOp(Ops, CurVec.getOperand(0), 0);
22196         return CanonicalizeBuildVector(Ops);
22197       }
22198 
22199       // INSERT_VECTOR_ELT - insert operand and continue up the chain.
22200       if (CurVec.getOpcode() == ISD::INSERT_VECTOR_ELT && CurVec.hasOneUse())
22201         if (auto *CurIdx = dyn_cast<ConstantSDNode>(CurVec.getOperand(2)))
22202           if (CurIdx->getAPIntValue().ult(NumElts)) {
22203             unsigned Idx = CurIdx->getZExtValue();
22204             AddBuildVectorOp(Ops, CurVec.getOperand(1), Idx);
22205 
22206             // Found entire BUILD_VECTOR.
22207             if (all_of(Ops, [](SDValue Op) { return !!Op; }))
22208               return CanonicalizeBuildVector(Ops);
22209 
22210             CurVec = CurVec->getOperand(0);
22211             continue;
22212           }
22213 
22214       // VECTOR_SHUFFLE - if all the operands match the shuffle's sources,
22215       // update the shuffle mask (and second operand if we started with unary
22216       // shuffle) and create a new legal shuffle.
22217       if (CurVec.getOpcode() == ISD::VECTOR_SHUFFLE && CurVec.hasOneUse()) {
22218         auto *SVN = cast<ShuffleVectorSDNode>(CurVec);
22219         SDValue LHS = SVN->getOperand(0);
22220         SDValue RHS = SVN->getOperand(1);
22221         SmallVector<int, 16> Mask(SVN->getMask());
22222         bool Merged = true;
22223         for (auto I : enumerate(Ops)) {
22224           SDValue &Op = I.value();
22225           if (Op) {
22226             SmallVector<int, 16> NewMask;
22227             if (!mergeEltWithShuffle(LHS, RHS, Mask, NewMask, Op, I.index())) {
22228               Merged = false;
22229               break;
22230             }
22231             Mask = std::move(NewMask);
22232           }
22233         }
22234         if (Merged)
22235           if (SDValue NewShuffle =
22236                   TLI.buildLegalVectorShuffle(VT, DL, LHS, RHS, Mask, DAG))
22237             return NewShuffle;
22238       }
22239 
22240       // If all insertions are zero value, try to convert to AND mask.
22241       // TODO: Do this for -1 with OR mask?
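      // e.g. inserting 0 into lanes 1 and 3 of a v4i32 becomes
      // (and CurVec, <-1, 0, -1, 0>).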
22242       if (!LegalOperations && llvm::isNullConstant(InVal) &&
22243           all_of(Ops, [InVal](SDValue Op) { return !Op || Op == InVal; }) &&
22244           count_if(Ops, [InVal](SDValue Op) { return Op == InVal; }) >= 2) {
22245         SDValue Zero = DAG.getConstant(0, DL, MaxEltVT);
22246         SDValue AllOnes = DAG.getAllOnesConstant(DL, MaxEltVT);
22247         SmallVector<SDValue, 8> Mask(NumElts);
22248         for (unsigned I = 0; I != NumElts; ++I)
22249           Mask[I] = Ops[I] ? Zero : AllOnes;
22250         return DAG.getNode(ISD::AND, DL, VT, CurVec,
22251                            DAG.getBuildVector(VT, DL, Mask));
22252       }
22253 
22254       // Failed to find a match in the chain - bail.
22255       break;
22256     }
22257 
22258     // See if we can fill in the missing constant elements as zeros.
22259     // TODO: Should we do this for any constant?
22260     APInt DemandedZeroElts = APInt::getZero(NumElts);
22261     for (unsigned I = 0; I != NumElts; ++I)
22262       if (!Ops[I])
22263         DemandedZeroElts.setBit(I);
22264 
22265     if (DAG.MaskedVectorIsZero(InVec, DemandedZeroElts)) {
22266       SDValue Zero = VT.isInteger() ? DAG.getConstant(0, DL, MaxEltVT)
22267                                     : DAG.getConstantFP(0, DL, MaxEltVT);
22268       for (unsigned I = 0; I != NumElts; ++I)
22269         if (!Ops[I])
22270           Ops[I] = Zero;
22271 
22272       return CanonicalizeBuildVector(Ops);
22273     }
22274   }
22275 
22276   return SDValue();
22277 }
22278 
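// Narrow a load feeding an extract_vector_elt to a scalar load of just the
// extracted lane, e.g. (a sketch, assuming the target allows the access):
//   (extract_vector_elt (load <4 x i32>, p), 2) --> (load i32, p+8)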
22279 SDValue DAGCombiner::scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
22280                                                   SDValue EltNo,
22281                                                   LoadSDNode *OriginalLoad) {
22282   assert(OriginalLoad->isSimple());
22283 
22284   EVT ResultVT = EVE->getValueType(0);
22285   EVT VecEltVT = InVecVT.getVectorElementType();
22286 
22287   // If the vector element type is not a multiple of a byte then we are unable
22288   // to correctly compute an address to load only the extracted element as a
22289   // scalar.
22290   if (!VecEltVT.isByteSized())
22291     return SDValue();
22292 
22293   ISD::LoadExtType ExtTy =
22294       ResultVT.bitsGT(VecEltVT) ? ISD::NON_EXTLOAD : ISD::EXTLOAD;
22295   if (!TLI.isOperationLegalOrCustom(ISD::LOAD, VecEltVT) ||
22296       !TLI.shouldReduceLoadWidth(OriginalLoad, ExtTy, VecEltVT))
22297     return SDValue();
22298 
22299   Align Alignment = OriginalLoad->getAlign();
22300   MachinePointerInfo MPI;
22301   SDLoc DL(EVE);
22302   if (auto *ConstEltNo = dyn_cast<ConstantSDNode>(EltNo)) {
22303     int Elt = ConstEltNo->getZExtValue();
22304     unsigned PtrOff = VecEltVT.getSizeInBits() * Elt / 8;
22305     MPI = OriginalLoad->getPointerInfo().getWithOffset(PtrOff);
22306     Alignment = commonAlignment(Alignment, PtrOff);
22307   } else {
22308     // Discard the pointer info except the address space because the memory
22309     // operand can't represent this new access since the offset is variable.
22310     MPI = MachinePointerInfo(OriginalLoad->getPointerInfo().getAddrSpace());
22311     Alignment = commonAlignment(Alignment, VecEltVT.getSizeInBits() / 8);
22312   }
22313 
22314   unsigned IsFast = 0;
22315   if (!TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VecEltVT,
22316                               OriginalLoad->getAddressSpace(), Alignment,
22317                               OriginalLoad->getMemOperand()->getFlags(),
22318                               &IsFast) ||
22319       !IsFast)
22320     return SDValue();
22321 
22322   SDValue NewPtr = TLI.getVectorElementPointer(DAG, OriginalLoad->getBasePtr(),
22323                                                InVecVT, EltNo);
22324 
22325   // We are replacing a vector load with a scalar load. The new load must have
22326   // identical memory op ordering to the original.
22327   SDValue Load;
22328   if (ResultVT.bitsGT(VecEltVT)) {
22329     // If the result type of vextract is wider than the load, then issue an
22330     // extending load instead.
22331     ISD::LoadExtType ExtType =
22332         TLI.isLoadExtLegal(ISD::ZEXTLOAD, ResultVT, VecEltVT) ? ISD::ZEXTLOAD
22333                                                               : ISD::EXTLOAD;
22334     Load = DAG.getExtLoad(ExtType, DL, ResultVT, OriginalLoad->getChain(),
22335                           NewPtr, MPI, VecEltVT, Alignment,
22336                           OriginalLoad->getMemOperand()->getFlags(),
22337                           OriginalLoad->getAAInfo());
22338     DAG.makeEquivalentMemoryOrdering(OriginalLoad, Load);
22339   } else {
22340     // The result type is narrower or the same width as the vector element
22341     Load = DAG.getLoad(VecEltVT, DL, OriginalLoad->getChain(), NewPtr, MPI,
22342                        Alignment, OriginalLoad->getMemOperand()->getFlags(),
22343                        OriginalLoad->getAAInfo());
22344     DAG.makeEquivalentMemoryOrdering(OriginalLoad, Load);
22345     if (ResultVT.bitsLT(VecEltVT))
22346       Load = DAG.getNode(ISD::TRUNCATE, DL, ResultVT, Load);
22347     else
22348       Load = DAG.getBitcast(ResultVT, Load);
22349   }
22350   ++OpsNarrowed;
22351   return Load;
22352 }
22353 
22354 /// Transform a vector binary operation into a scalar binary operation by moving
22355 /// the math/logic after an extract element of a vector.
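// For example: extractelt (add X, <1,2,3,4>), 2 --> add (extractelt X, 2), 3,
// since extracting lane 2 of the constant vector constant-folds to 3.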
22356 static SDValue scalarizeExtractedBinop(SDNode *ExtElt, SelectionDAG &DAG,
22357                                        const SDLoc &DL, bool LegalOperations) {
22358   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
22359   SDValue Vec = ExtElt->getOperand(0);
22360   SDValue Index = ExtElt->getOperand(1);
22361   auto *IndexC = dyn_cast<ConstantSDNode>(Index);
22362   if (!IndexC || !TLI.isBinOp(Vec.getOpcode()) || !Vec.hasOneUse() ||
22363       Vec->getNumValues() != 1)
22364     return SDValue();
22365 
22366   // Targets may want to avoid this to prevent an expensive register transfer.
22367   if (!TLI.shouldScalarizeBinop(Vec))
22368     return SDValue();
22369 
22370   // Extracting an element of a vector constant is constant-folded, so this
22371   // transform is just replacing a vector op with a scalar op while moving the
22372   // extract.
22373   SDValue Op0 = Vec.getOperand(0);
22374   SDValue Op1 = Vec.getOperand(1);
22375   APInt SplatVal;
22376   if (isAnyConstantBuildVector(Op0, true) ||
22377       ISD::isConstantSplatVector(Op0.getNode(), SplatVal) ||
22378       isAnyConstantBuildVector(Op1, true) ||
22379       ISD::isConstantSplatVector(Op1.getNode(), SplatVal)) {
22380     // extractelt (binop X, C), IndexC --> binop (extractelt X, IndexC), C'
22381     // extractelt (binop C, X), IndexC --> binop C', (extractelt X, IndexC)
22382     EVT VT = ExtElt->getValueType(0);
22383     SDValue Ext0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op0, Index);
22384     SDValue Ext1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op1, Index);
22385     return DAG.getNode(Vec.getOpcode(), DL, VT, Ext0, Ext1);
22386   }
22387 
22388   return SDValue();
22389 }
22390 
22391 // Given an ISD::EXTRACT_VECTOR_ELT, which is a glorified bit sequence extract,
22392 // recursively analyse all of its users, and try to model them as
22393 // bit sequence extractions. If all of them agree on the new, narrower element
22394 // type, and all of them can be modelled as ISD::EXTRACT_VECTOR_ELT's of that
22395 // new element type, do so now.
22396 // This is mainly useful for recovering from legalization that scalarized
22397 // the vector as wide elements; it tries to rebuild it with narrower elements.
22398 //
22399 // Some more nodes could be modelled if that helps cover interesting patterns.
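// A sketch (types invented, little-endian): if the only uses of
// (extract_vector_elt v2i64 %v, 0) are (trunc ... to i32) and
// (trunc (srl ..., 32) to i32), and every unmodelled user is a BUILD_VECTOR,
// the whole tree can be remodelled as i32 extractions of lanes 0 and 1 of %v
// viewed as v4i32.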
22400 bool DAGCombiner::refineExtractVectorEltIntoMultipleNarrowExtractVectorElts(
22401     SDNode *N) {
22402   // We perform this optimization post type-legalization because
22403   // the type-legalizer often scalarizes integer-promoted vectors.
22404   // Performing this optimization earlier may cause legalization cycles.
22405   if (Level != AfterLegalizeVectorOps && Level != AfterLegalizeTypes)
22406     return false;
22407 
22408   // TODO: Add support for big-endian.
22409   if (DAG.getDataLayout().isBigEndian())
22410     return false;
22411 
22412   SDValue VecOp = N->getOperand(0);
22413   EVT VecVT = VecOp.getValueType();
22414   assert(!VecVT.isScalableVector() && "Only for fixed vectors.");
22415 
22416   // We must start with a constant extraction index.
22417   auto *IndexC = dyn_cast<ConstantSDNode>(N->getOperand(1));
22418   if (!IndexC)
22419     return false;
22420 
22421   assert(IndexC->getZExtValue() < VecVT.getVectorNumElements() &&
22422          "Original ISD::EXTRACT_VECTOR_ELT is undefinend?");
22423 
22424   // TODO: deal with the case of implicit anyext of the extraction.
22425   unsigned VecEltBitWidth = VecVT.getScalarSizeInBits();
22426   EVT ScalarVT = N->getValueType(0);
22427   if (VecVT.getScalarType() != ScalarVT)
22428     return false;
22429 
22430   // TODO: deal with the cases other than everything being integer-typed.
22431   if (!ScalarVT.isScalarInteger())
22432     return false;
22433 
22434   struct Entry {
22435     SDNode *Producer;
22436 
22437     // Which bits of VecOp does it contain?
22438     unsigned BitPos;
22439     int NumBits;
22440     // NOTE: the actual width of \p Producer may be wider than NumBits!
22441 
22442     Entry(Entry &&) = default;
22443     Entry(SDNode *Producer_, unsigned BitPos_, int NumBits_)
22444         : Producer(Producer_), BitPos(BitPos_), NumBits(NumBits_) {}
22445 
22446     Entry() = delete;
22447     Entry(const Entry &) = delete;
22448     Entry &operator=(const Entry &) = delete;
22449     Entry &operator=(Entry &&) = delete;
22450   };
22451   SmallVector<Entry, 32> Worklist;
22452   SmallVector<Entry, 32> Leafs;
22453 
22454   // We start at the "root" ISD::EXTRACT_VECTOR_ELT.
22455   Worklist.emplace_back(N, /*BitPos=*/VecEltBitWidth * IndexC->getZExtValue(),
22456                         /*NumBits=*/VecEltBitWidth);
22457 
22458   while (!Worklist.empty()) {
22459     Entry E = Worklist.pop_back_val();
22460     // Does the node not even use any of the VecOp bits?
22461     if (!(E.NumBits > 0 && E.BitPos < VecVT.getSizeInBits() &&
22462           E.BitPos + E.NumBits <= VecVT.getSizeInBits()))
22463       return false; // Let the other combines clean this up first.
22464     // Did we fail to model any of the users of the Producer?
22465     bool ProducerIsLeaf = false;
22466     // Look at each user of this Producer.
22467     for (SDNode *User : E.Producer->uses()) {
22468       switch (User->getOpcode()) {
22469       // TODO: support ISD::BITCAST
22470       // TODO: support ISD::ANY_EXTEND
22471       // TODO: support ISD::ZERO_EXTEND
22472       // TODO: support ISD::SIGN_EXTEND
22473       case ISD::TRUNCATE:
22474         // Truncation simply means we keep the position, but extract fewer bits.
22475         Worklist.emplace_back(User, E.BitPos,
22476                               /*NumBits=*/User->getValueSizeInBits(0));
22477         break;
22478       // TODO: support ISD::SRA
22479       // TODO: support ISD::SHL
22480       case ISD::SRL:
22481         // We should be shifting the Producer by a constant amount.
22482         if (auto *ShAmtC = dyn_cast<ConstantSDNode>(User->getOperand(1));
22483             User->getOperand(0).getNode() == E.Producer && ShAmtC) {
22484           // Logical right-shift means that we start extraction later,
22485           // but stop it at the same position we did previously.
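          // E.g. (hypothetical values): if we were extracting bits [0, 64)
          // and the shift amount is 32, we now extract bits [32, 64):
          // BitPos becomes E.BitPos + 32 and NumBits shrinks from 64 to 32.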
22486           unsigned ShAmt = ShAmtC->getZExtValue();
22487           Worklist.emplace_back(User, E.BitPos + ShAmt, E.NumBits - ShAmt);
22488           break;
22489         }
22490         [[fallthrough]];
22491       default:
22492         // We cannot model this user of the Producer.
22493         // Which means the current Producer will be an ISD::EXTRACT_VECTOR_ELT.
22494         ProducerIsLeaf = true;
22495         // Profitability check: all users that we cannot model
22496         //                      must be ISD::BUILD_VECTOR's.
22497         if (User->getOpcode() != ISD::BUILD_VECTOR)
22498           return false;
22499         break;
22500       }
22501     }
22502     if (ProducerIsLeaf)
22503       Leafs.emplace_back(std::move(E));
22504   }
22505 
22506   unsigned NewVecEltBitWidth = Leafs.front().NumBits;
22507 
22508   // If we are still at the same element granularity, give up.
22509   if (NewVecEltBitWidth == VecEltBitWidth)
22510     return false;
22511 
22512   // The vector width must be a multiple of the new element width.
22513   if (VecVT.getSizeInBits() % NewVecEltBitWidth != 0)
22514     return false;
22515 
22516   // All leafs must agree on the new element width.
22517   // All leafs must not expect any "padding" bits on top of that width.
22518   // All leafs must start extraction from multiple of that width.
22519   if (!all_of(Leafs, [NewVecEltBitWidth](const Entry &E) {
22520         return (unsigned)E.NumBits == NewVecEltBitWidth &&
22521                E.Producer->getValueSizeInBits(0) == NewVecEltBitWidth &&
22522                E.BitPos % NewVecEltBitWidth == 0;
22523       }))
22524     return false;
22525 
22526   EVT NewScalarVT = EVT::getIntegerVT(*DAG.getContext(), NewVecEltBitWidth);
22527   EVT NewVecVT = EVT::getVectorVT(*DAG.getContext(), NewScalarVT,
22528                                   VecVT.getSizeInBits() / NewVecEltBitWidth);
22529 
22530   if (LegalTypes &&
22531       !(TLI.isTypeLegal(NewScalarVT) && TLI.isTypeLegal(NewVecVT)))
22532     return false;
22533 
22534   if (LegalOperations &&
22535       !(TLI.isOperationLegalOrCustom(ISD::BITCAST, NewVecVT) &&
22536         TLI.isOperationLegalOrCustom(ISD::EXTRACT_VECTOR_ELT, NewVecVT)))
22537     return false;
22538 
22539   SDValue NewVecOp = DAG.getBitcast(NewVecVT, VecOp);
22540   for (const Entry &E : Leafs) {
22541     SDLoc DL(E.Producer);
22542     unsigned NewIndex = E.BitPos / NewVecEltBitWidth;
22543     assert(NewIndex < NewVecVT.getVectorNumElements() &&
22544            "Creating out-of-bounds ISD::EXTRACT_VECTOR_ELT?");
22545     SDValue V = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, NewScalarVT, NewVecOp,
22546                             DAG.getVectorIdxConstant(NewIndex, DL));
22547     CombineTo(E.Producer, V);
22548   }
22549 
22550   return true;
22551 }
22552 
22553 SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
22554   SDValue VecOp = N->getOperand(0);
22555   SDValue Index = N->getOperand(1);
22556   EVT ScalarVT = N->getValueType(0);
22557   EVT VecVT = VecOp.getValueType();
22558   if (VecOp.isUndef())
22559     return DAG.getUNDEF(ScalarVT);
22560 
22561   // extract_vector_elt (insert_vector_elt vec, val, idx), idx -> val
22562   //
22563   // This only really matters if the index is non-constant since other combines
22564   // on the constant elements already work.
22565   SDLoc DL(N);
22566   if (VecOp.getOpcode() == ISD::INSERT_VECTOR_ELT &&
22567       Index == VecOp.getOperand(2)) {
22568     SDValue Elt = VecOp.getOperand(1);
22569     return VecVT.isInteger() ? DAG.getAnyExtOrTrunc(Elt, DL, ScalarVT) : Elt;
22570   }
22571 
22572   // (vextract (scalar_to_vector val), 0) -> val
22573   if (VecOp.getOpcode() == ISD::SCALAR_TO_VECTOR) {
22574     // Only the 0'th element of SCALAR_TO_VECTOR is defined.
22575     if (DAG.isKnownNeverZero(Index))
22576       return DAG.getUNDEF(ScalarVT);
22577 
22578     // Check if the result type doesn't match the inserted element type.
22579     // The inserted element and extracted element may have mismatched bitwidth.
22580     // As a result, EXTRACT_VECTOR_ELT may extend or truncate the extracted value.
22581     SDValue InOp = VecOp.getOperand(0);
22582     if (InOp.getValueType() != ScalarVT) {
22583       assert(InOp.getValueType().isInteger() && ScalarVT.isInteger());
22584       if (InOp.getValueType().bitsGT(ScalarVT))
22585         return DAG.getNode(ISD::TRUNCATE, DL, ScalarVT, InOp);
22586       return DAG.getNode(ISD::ANY_EXTEND, DL, ScalarVT, InOp);
22587     }
22588     return InOp;
22589   }
22590 
22591   // extract_vector_elt of out-of-bounds element -> UNDEF
22592   auto *IndexC = dyn_cast<ConstantSDNode>(Index);
22593   if (IndexC && VecVT.isFixedLengthVector() &&
22594       IndexC->getAPIntValue().uge(VecVT.getVectorNumElements()))
22595     return DAG.getUNDEF(ScalarVT);
22596 
22597   // extract_vector_elt (build_vector x, y), 1 -> y
22598   if (((IndexC && VecOp.getOpcode() == ISD::BUILD_VECTOR) ||
22599        VecOp.getOpcode() == ISD::SPLAT_VECTOR) &&
22600       TLI.isTypeLegal(VecVT)) {
22601     assert((VecOp.getOpcode() != ISD::BUILD_VECTOR ||
22602             VecVT.isFixedLengthVector()) &&
22603            "BUILD_VECTOR used for scalable vectors");
22604     unsigned IndexVal =
22605         VecOp.getOpcode() == ISD::BUILD_VECTOR ? IndexC->getZExtValue() : 0;
22606     SDValue Elt = VecOp.getOperand(IndexVal);
22607     EVT InEltVT = Elt.getValueType();
22608 
22609     if (VecOp.hasOneUse() || TLI.aggressivelyPreferBuildVectorSources(VecVT) ||
22610         isNullConstant(Elt)) {
22611       // Sometimes build_vector's scalar input types do not match result type.
22612       if (ScalarVT == InEltVT)
22613         return Elt;
22614 
22615       // TODO: It may be useful to truncate, if free, when the build_vector
22616       // implicitly converts.
22617     }
22618   }
22619 
22620   if (SDValue BO = scalarizeExtractedBinop(N, DAG, DL, LegalOperations))
22621     return BO;
22622 
22623   if (VecVT.isScalableVector())
22624     return SDValue();
22625 
22626   // All the code from this point onwards assumes fixed width vectors, but it's
22627   // possible that some of the combinations could be made to work for scalable
22628   // vectors too.
22629   unsigned NumElts = VecVT.getVectorNumElements();
22630   unsigned VecEltBitWidth = VecVT.getScalarSizeInBits();
22631 
22632   // See if the extracted element is constant, in which case fold it if it's
22633   // a legal fp immediate.
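  // E.g. extract_vector_elt (v2f32 bitcast (i64 C)), 1 can fold to an f32
  // constant built from the known high bits of C, if that immediate is legal
  // (a hypothetical little-endian example).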
22634   if (IndexC && ScalarVT.isFloatingPoint()) {
22635     APInt EltMask = APInt::getOneBitSet(NumElts, IndexC->getZExtValue());
22636     KnownBits KnownElt = DAG.computeKnownBits(VecOp, EltMask);
22637     if (KnownElt.isConstant()) {
22638       APFloat CstFP =
22639           APFloat(DAG.EVTToAPFloatSemantics(ScalarVT), KnownElt.getConstant());
22640       if (TLI.isFPImmLegal(CstFP, ScalarVT))
22641         return DAG.getConstantFP(CstFP, DL, ScalarVT);
22642     }
22643   }
22644 
22645   // TODO: These transforms should not require the 'hasOneUse' restriction, but
22646   // there are regressions on multiple targets without it. We can end up with a
22647   // mess of scalar and vector code if we reduce only part of the DAG to scalar.
22648   if (IndexC && VecOp.getOpcode() == ISD::BITCAST && VecVT.isInteger() &&
22649       VecOp.hasOneUse()) {
22650     // The vector index of the LSBs of the source depends on the endianness.
22651     bool IsLE = DAG.getDataLayout().isLittleEndian();
22652     unsigned ExtractIndex = IndexC->getZExtValue();
22653     // extract_elt (v2i32 (bitcast i64:x)), BCTruncElt -> i32 (trunc i64:x)
22654     unsigned BCTruncElt = IsLE ? 0 : NumElts - 1;
22655     SDValue BCSrc = VecOp.getOperand(0);
22656     if (ExtractIndex == BCTruncElt && BCSrc.getValueType().isScalarInteger())
22657       return DAG.getAnyExtOrTrunc(BCSrc, DL, ScalarVT);
22658 
22659     if (LegalTypes && BCSrc.getValueType().isInteger() &&
22660         BCSrc.getOpcode() == ISD::SCALAR_TO_VECTOR) {
22661       // ext_elt (bitcast (scalar_to_vec i64 X to v2i64) to v4i32), TruncElt -->
22662       // trunc i64 X to i32
22663       SDValue X = BCSrc.getOperand(0);
22664       assert(X.getValueType().isScalarInteger() && ScalarVT.isScalarInteger() &&
22665              "Extract element and scalar to vector can't change element type "
22666              "from FP to integer.");
22667       unsigned XBitWidth = X.getValueSizeInBits();
22668       BCTruncElt = IsLE ? 0 : XBitWidth / VecEltBitWidth - 1;
22669 
22670       // An extract element return value type can be wider than its vector
22671       // operand element type. In that case, the high bits are undefined, so
22672       // it's possible that we may need to extend rather than truncate.
22673       if (ExtractIndex == BCTruncElt && XBitWidth > VecEltBitWidth) {
22674         assert(XBitWidth % VecEltBitWidth == 0 &&
22675                "Scalar bitwidth must be a multiple of vector element bitwidth");
22676         return DAG.getAnyExtOrTrunc(X, DL, ScalarVT);
22677       }
22678     }
22679   }
22680 
22681   // Transform: (EXTRACT_VECTOR_ELT( VECTOR_SHUFFLE )) -> EXTRACT_VECTOR_ELT.
22682   // We only perform this optimization before the op legalization phase because
22683   // we may introduce new vector instructions which are not backed by TD
22684   // patterns. For example, on AVX there is no support for extracting elements
22685   // from a wide vector without using extract_subvector. However, if we can
22686   // find an underlying scalar value, then we can always use that.
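  // E.g. (hypothetical DAG):
  //   extract_vector_elt (vector_shuffle<2,u,u,u> v4i32:t1, t2), 0
  //     --> extract_vector_elt t1, 2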
22687   if (IndexC && VecOp.getOpcode() == ISD::VECTOR_SHUFFLE) {
22688     auto *Shuf = cast<ShuffleVectorSDNode>(VecOp);
22689     // Find the new index to extract from.
22690     int OrigElt = Shuf->getMaskElt(IndexC->getZExtValue());
22691 
22692     // Extracting an undef index is undef.
22693     if (OrigElt == -1)
22694       return DAG.getUNDEF(ScalarVT);
22695 
22696     // Select the right vector half to extract from.
22697     SDValue SVInVec;
22698     if (OrigElt < (int)NumElts) {
22699       SVInVec = VecOp.getOperand(0);
22700     } else {
22701       SVInVec = VecOp.getOperand(1);
22702       OrigElt -= NumElts;
22703     }
22704 
22705     if (SVInVec.getOpcode() == ISD::BUILD_VECTOR) {
22706       SDValue InOp = SVInVec.getOperand(OrigElt);
22707       if (InOp.getValueType() != ScalarVT) {
22708         assert(InOp.getValueType().isInteger() && ScalarVT.isInteger());
22709         InOp = DAG.getSExtOrTrunc(InOp, DL, ScalarVT);
22710       }
22711 
22712       return InOp;
22713     }
22714 
22715     // FIXME: We should handle recursing on other vector shuffles and
22716     // scalar_to_vector here as well.
22717 
22718     if (!LegalOperations ||
22719         // FIXME: Should really be just isOperationLegalOrCustom.
22720         TLI.isOperationLegal(ISD::EXTRACT_VECTOR_ELT, VecVT) ||
22721         TLI.isOperationExpand(ISD::VECTOR_SHUFFLE, VecVT)) {
22722       return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT, SVInVec,
22723                          DAG.getVectorIdxConstant(OrigElt, DL));
22724     }
22725   }
22726 
22727   // If only EXTRACT_VECTOR_ELT nodes use the source vector we can
22728   // simplify it based on the (valid) extraction indices.
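  // E.g. if only lanes 0 and 2 of a v4i32 source are ever extracted, then
  // DemandedElts is 0b0101 and the remaining lanes may be simplified away.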
22729   if (llvm::all_of(VecOp->uses(), [&](SDNode *Use) {
22730         return Use->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
22731                Use->getOperand(0) == VecOp &&
22732                isa<ConstantSDNode>(Use->getOperand(1));
22733       })) {
22734     APInt DemandedElts = APInt::getZero(NumElts);
22735     for (SDNode *Use : VecOp->uses()) {
22736       auto *CstElt = cast<ConstantSDNode>(Use->getOperand(1));
22737       if (CstElt->getAPIntValue().ult(NumElts))
22738         DemandedElts.setBit(CstElt->getZExtValue());
22739     }
22740     if (SimplifyDemandedVectorElts(VecOp, DemandedElts, true)) {
22741       // We simplified the vector operand of this extract element. If this
22742       // extract is not dead, visit it again so it is folded properly.
22743       if (N->getOpcode() != ISD::DELETED_NODE)
22744         AddToWorklist(N);
22745       return SDValue(N, 0);
22746     }
22747     APInt DemandedBits = APInt::getAllOnes(VecEltBitWidth);
22748     if (SimplifyDemandedBits(VecOp, DemandedBits, DemandedElts, true)) {
22749       // We simplified the vector operand of this extract element. If this
22750       // extract is not dead, visit it again so it is folded properly.
22751       if (N->getOpcode() != ISD::DELETED_NODE)
22752         AddToWorklist(N);
22753       return SDValue(N, 0);
22754     }
22755   }
22756 
22757   if (refineExtractVectorEltIntoMultipleNarrowExtractVectorElts(N))
22758     return SDValue(N, 0);
22759 
22760   // Everything under here is trying to match an extract of a loaded value.
22761   // If the result of the load has to be truncated, then it's not necessarily
22762   // profitable.
22763   bool BCNumEltsChanged = false;
22764   EVT ExtVT = VecVT.getVectorElementType();
22765   EVT LVT = ExtVT;
22766   if (ScalarVT.bitsLT(LVT) && !TLI.isTruncateFree(LVT, ScalarVT))
22767     return SDValue();
22768 
22769   if (VecOp.getOpcode() == ISD::BITCAST) {
22770     // Don't duplicate a load with other uses.
22771     if (!VecOp.hasOneUse())
22772       return SDValue();
22773 
22774     EVT BCVT = VecOp.getOperand(0).getValueType();
22775     if (!BCVT.isVector() || ExtVT.bitsGT(BCVT.getVectorElementType()))
22776       return SDValue();
22777     if (NumElts != BCVT.getVectorNumElements())
22778       BCNumEltsChanged = true;
22779     VecOp = VecOp.getOperand(0);
22780     ExtVT = BCVT.getVectorElementType();
22781   }
22782 
22783   // extract (vector load $addr), i --> load $addr + i * size
22784   if (!LegalOperations && !IndexC && VecOp.hasOneUse() &&
22785       ISD::isNormalLoad(VecOp.getNode()) &&
22786       !Index->hasPredecessor(VecOp.getNode())) {
22787     auto *VecLoad = dyn_cast<LoadSDNode>(VecOp);
22788     if (VecLoad && VecLoad->isSimple())
22789       return scalarizeExtractedVectorLoad(N, VecVT, Index, VecLoad);
22790   }
22791 
22792   // Perform only after legalization to ensure build_vector / vector_shuffle
22793   // optimizations have already been done.
22794   if (!LegalOperations || !IndexC)
22795     return SDValue();
22796 
22797   // (vextract (v4f32 load $addr), c) -> (f32 load $addr+c*size)
22798   // (vextract (v4f32 s2v (f32 load $addr)), c) -> (f32 load $addr+c*size)
22799   // (vextract (v4f32 shuffle (load $addr), <1,u,u,u>), 0) -> (f32 load $addr)
22800   int Elt = IndexC->getZExtValue();
22801   LoadSDNode *LN0 = nullptr;
22802   if (ISD::isNormalLoad(VecOp.getNode())) {
22803     LN0 = cast<LoadSDNode>(VecOp);
22804   } else if (VecOp.getOpcode() == ISD::SCALAR_TO_VECTOR &&
22805              VecOp.getOperand(0).getValueType() == ExtVT &&
22806              ISD::isNormalLoad(VecOp.getOperand(0).getNode())) {
22807     // Don't duplicate a load with other uses.
22808     if (!VecOp.hasOneUse())
22809       return SDValue();
22810 
22811     LN0 = cast<LoadSDNode>(VecOp.getOperand(0));
22812   }
22813   if (auto *Shuf = dyn_cast<ShuffleVectorSDNode>(VecOp)) {
22814     // (vextract (vector_shuffle (load $addr), v2, <1, u, u, u>), 1)
22815     // =>
22816     // (load $addr+1*size)
22817 
22818     // Don't duplicate a load with other uses.
22819     if (!VecOp.hasOneUse())
22820       return SDValue();
22821 
22822     // If the bit convert changed the number of elements, it is unsafe
22823     // to examine the mask.
22824     if (BCNumEltsChanged)
22825       return SDValue();
22826 
22827     // Select the input vector, guarding against an out-of-range extract index.
22828     int Idx = (Elt > (int)NumElts) ? -1 : Shuf->getMaskElt(Elt);
22829     VecOp = (Idx < (int)NumElts) ? VecOp.getOperand(0) : VecOp.getOperand(1);
22830 
22831     if (VecOp.getOpcode() == ISD::BITCAST) {
22832       // Don't duplicate a load with other uses.
22833       if (!VecOp.hasOneUse())
22834         return SDValue();
22835 
22836       VecOp = VecOp.getOperand(0);
22837     }
22838     if (ISD::isNormalLoad(VecOp.getNode())) {
22839       LN0 = cast<LoadSDNode>(VecOp);
22840       Elt = (Idx < (int)NumElts) ? Idx : Idx - (int)NumElts;
22841       Index = DAG.getConstant(Elt, DL, Index.getValueType());
22842     }
22843   } else if (VecOp.getOpcode() == ISD::CONCAT_VECTORS && !BCNumEltsChanged &&
22844              VecVT.getVectorElementType() == ScalarVT &&
22845              (!LegalTypes ||
22846               TLI.isTypeLegal(
22847                   VecOp.getOperand(0).getValueType().getVectorElementType()))) {
22848     // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 0
22849     //      -> extract_vector_elt a, 0
22850     // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 1
22851     //      -> extract_vector_elt a, 1
22852     // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 2
22853     //      -> extract_vector_elt b, 0
22854     // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 3
22855     //      -> extract_vector_elt b, 1
22856     EVT ConcatVT = VecOp.getOperand(0).getValueType();
22857     unsigned ConcatNumElts = ConcatVT.getVectorNumElements();
22858     SDValue NewIdx = DAG.getConstant(Elt % ConcatNumElts, DL,
22859                                      Index.getValueType());
22860 
22861     SDValue ConcatOp = VecOp.getOperand(Elt / ConcatNumElts);
22862     SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL,
22863                               ConcatVT.getVectorElementType(),
22864                               ConcatOp, NewIdx);
22865     return DAG.getNode(ISD::BITCAST, DL, ScalarVT, Elt);
22866   }
22867 
22868   // Make sure we found a non-volatile load and the extractelement is
22869   // the only use.
22870   if (!LN0 || !LN0->hasNUsesOfValue(1,0) || !LN0->isSimple())
22871     return SDValue();
22872 
22873   // If Idx was -1 above, Elt is going to be -1, so just return undef.
22874   if (Elt == -1)
22875     return DAG.getUNDEF(LVT);
22876 
22877   return scalarizeExtractedVectorLoad(N, VecVT, Index, LN0);
22878 }
22879 
22880 // Simplify (build_vec (ext )) to (bitcast (build_vec ))
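// E.g. on a little-endian target (hypothetical DAG):
//   v4i32 = build_vector (i32 zext i16:a), (i32 zext i16:b),
//                        (i32 zext i16:c), (i32 zext i16:d)
//     --> v4i32 = bitcast (v8i16 build_vector a, 0, b, 0, c, 0, d, 0)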
22881 SDValue DAGCombiner::reduceBuildVecExtToExtBuildVec(SDNode *N) {
22882   // We perform this optimization post type-legalization because
22883   // the type-legalizer often scalarizes integer-promoted vectors.
22884   // Performing this optimization before may create bit-casts which
22885   // will be type-legalized to complex code sequences.
22886   // We perform this optimization only before the operation legalizer because we
22887   // may introduce illegal operations.
22888   if (Level != AfterLegalizeVectorOps && Level != AfterLegalizeTypes)
22889     return SDValue();
22890 
22891   unsigned NumInScalars = N->getNumOperands();
22892   SDLoc DL(N);
22893   EVT VT = N->getValueType(0);
22894 
22895   // Check to see if this is a BUILD_VECTOR of a bunch of values
22896   // which come from any_extend or zero_extend nodes. If so, we can create
22897   // a new BUILD_VECTOR using bit-casts which may enable other BUILD_VECTOR
22898   // optimizations. We do not handle sign-extend because we can't fill the sign
22899   // using shuffles.
22900   EVT SourceType = MVT::Other;
22901   bool AllAnyExt = true;
22902 
22903   for (unsigned i = 0; i != NumInScalars; ++i) {
22904     SDValue In = N->getOperand(i);
22905     // Ignore undef inputs.
22906     if (In.isUndef()) continue;
22907 
22908     bool AnyExt  = In.getOpcode() == ISD::ANY_EXTEND;
22909     bool ZeroExt = In.getOpcode() == ISD::ZERO_EXTEND;
22910 
22911     // Abort if the element is not an extension.
22912     if (!ZeroExt && !AnyExt) {
22913       SourceType = MVT::Other;
22914       break;
22915     }
22916 
22917     // The input is a ZeroExt or AnyExt. Check the original type.
22918     EVT InTy = In.getOperand(0).getValueType();
22919 
22920     // Check that all of the widened source types are the same.
22921     if (SourceType == MVT::Other)
22922       // First time.
22923       SourceType = InTy;
22924     else if (InTy != SourceType) {
22925       // Multiple incoming types. Abort.
22926       SourceType = MVT::Other;
22927       break;
22928     }
22929 
22930     // Check if all of the extends are ANY_EXTENDs.
22931     AllAnyExt &= AnyExt;
22932   }
22933 
22934   // In order to have valid types, all of the inputs must be extended from the
22935   // same source type and all of the inputs must be any or zero extend.
22936   // Scalar sizes must be a power of two.
22937   EVT OutScalarTy = VT.getScalarType();
22938   bool ValidTypes =
22939       SourceType != MVT::Other &&
22940       llvm::has_single_bit<uint32_t>(OutScalarTy.getSizeInBits()) &&
22941       llvm::has_single_bit<uint32_t>(SourceType.getSizeInBits());
22942 
22943   // Create a new simpler BUILD_VECTOR sequence which other optimizations can
22944   // turn into a single shuffle instruction.
22945   if (!ValidTypes)
22946     return SDValue();
22947 
22948   // If we already have a splat buildvector, then don't fold it if it means
22949   // introducing zeros.
22950   if (!AllAnyExt && DAG.isSplatValue(SDValue(N, 0), /*AllowUndefs*/ true))
22951     return SDValue();
22952 
22953   bool isLE = DAG.getDataLayout().isLittleEndian();
22954   unsigned ElemRatio = OutScalarTy.getSizeInBits()/SourceType.getSizeInBits();
22955   assert(ElemRatio > 1 && "Invalid element size ratio");
22956   SDValue Filler = AllAnyExt ? DAG.getUNDEF(SourceType):
22957                                DAG.getConstant(0, DL, SourceType);
22958 
22959   unsigned NewBVElems = ElemRatio * VT.getVectorNumElements();
22960   SmallVector<SDValue, 8> Ops(NewBVElems, Filler);
22961 
22962   // Populate the new build_vector
22963   for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i) {
22964     SDValue Cast = N->getOperand(i);
22965     assert((Cast.getOpcode() == ISD::ANY_EXTEND ||
22966             Cast.getOpcode() == ISD::ZERO_EXTEND ||
22967             Cast.isUndef()) && "Invalid cast opcode");
22968     SDValue In;
22969     if (Cast.isUndef())
22970       In = DAG.getUNDEF(SourceType);
22971     else
22972       In = Cast->getOperand(0);
22973     unsigned Index = isLE ? (i * ElemRatio) :
22974                             (i * ElemRatio + (ElemRatio - 1));
22975 
22976     assert(Index < Ops.size() && "Invalid index");
22977     Ops[Index] = In;
22978   }
22979 
22980   // The type of the new BUILD_VECTOR node.
22981   EVT VecVT = EVT::getVectorVT(*DAG.getContext(), SourceType, NewBVElems);
22982   assert(VecVT.getSizeInBits() == VT.getSizeInBits() &&
22983          "Invalid vector size");
22984   // Check if the new vector type is legal.
22985   if (!isTypeLegal(VecVT) ||
22986       (!TLI.isOperationLegal(ISD::BUILD_VECTOR, VecVT) &&
22987        TLI.isOperationLegal(ISD::BUILD_VECTOR, VT)))
22988     return SDValue();
22989 
22990   // Make the new BUILD_VECTOR.
22991   SDValue BV = DAG.getBuildVector(VecVT, DL, Ops);
22992 
22993   // The new BUILD_VECTOR node has the potential to be further optimized.
22994   AddToWorklist(BV.getNode());
22995   // Bitcast to the desired type.
22996   return DAG.getBitcast(VT, BV);
22997 }
22998 
22999 // Simplify (build_vec (trunc $1)
23000 //                     (trunc (srl $1 half-width))
23001 //                     (trunc (srl $1 (2 * half-width))))
23002 // to (bitcast $1)
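// E.g. (little-endian, hypothetical DAG):
//   v2i32 = build_vector (i32 trunc i64:x), (i32 trunc (i64 srl x, 32))
//     --> v2i32 = bitcast x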
23003 SDValue DAGCombiner::reduceBuildVecTruncToBitCast(SDNode *N) {
23004   assert(N->getOpcode() == ISD::BUILD_VECTOR && "Expected build vector");
23005 
23006   EVT VT = N->getValueType(0);
23007 
23008   // Don't run this before LegalizeTypes if VT is legal.
23009   // Targets may have other preferences.
23010   if (Level < AfterLegalizeTypes && TLI.isTypeLegal(VT))
23011     return SDValue();
23012 
23013   // Only for little-endian targets.
23014   if (!DAG.getDataLayout().isLittleEndian())
23015     return SDValue();
23016 
23017   SDLoc DL(N);
23018   EVT OutScalarTy = VT.getScalarType();
23019   uint64_t ScalarTypeBitsize = OutScalarTy.getSizeInBits();
23020 
23021   // Only for power-of-two types, to be sure that the bitcast works well.
23022   if (!isPowerOf2_64(ScalarTypeBitsize))
23023     return SDValue();
23024 
23025   unsigned NumInScalars = N->getNumOperands();
23026 
23027   // Look through bitcasts
23028   auto PeekThroughBitcast = [](SDValue Op) {
23029     if (Op.getOpcode() == ISD::BITCAST)
23030       return Op.getOperand(0);
23031     return Op;
23032   };
23033 
23034   // The source value where all the parts are extracted.
23035   SDValue Src;
23036   for (unsigned i = 0; i != NumInScalars; ++i) {
23037     SDValue In = PeekThroughBitcast(N->getOperand(i));
23038     // Ignore undef inputs.
23039     if (In.isUndef()) continue;
23040 
23041     if (In.getOpcode() != ISD::TRUNCATE)
23042       return SDValue();
23043 
23044     In = PeekThroughBitcast(In.getOperand(0));
23045 
23046     if (In.getOpcode() != ISD::SRL) {
23047       // For now, only handle a build_vec without shuffling; handle shifts
23048       // here in the future.
23049       if (i != 0)
23050         return SDValue();
23051 
23052       Src = In;
23053     } else {
23054       // In is SRL
23055       SDValue part = PeekThroughBitcast(In.getOperand(0));
23056 
23057       if (!Src) {
23058         Src = part;
23059       } else if (Src != part) {
23060         // Vector parts do not stem from the same variable
23061         return SDValue();
23062       }
23063 
23064       SDValue ShiftAmtVal = In.getOperand(1);
23065       if (!isa<ConstantSDNode>(ShiftAmtVal))
23066         return SDValue();
23067 
23068       uint64_t ShiftAmt = In.getConstantOperandVal(1);
23069 
23070       // The extracted value is not extracted at the right position
23071       if (ShiftAmt != i * ScalarTypeBitsize)
23072         return SDValue();
23073     }
23074   }
23075 
23076   // Only cast if the size is the same
23077   if (!Src || Src.getValueType().getSizeInBits() != VT.getSizeInBits())
23078     return SDValue();
23079 
23080   return DAG.getBitcast(VT, Src);
23081 }
23082 
23083 SDValue DAGCombiner::createBuildVecShuffle(const SDLoc &DL, SDNode *N,
23084                                            ArrayRef<int> VectorMask,
23085                                            SDValue VecIn1, SDValue VecIn2,
23086                                            unsigned LeftIdx, bool DidSplitVec) {
23087   SDValue ZeroIdx = DAG.getVectorIdxConstant(0, DL);
23088 
23089   EVT VT = N->getValueType(0);
23090   EVT InVT1 = VecIn1.getValueType();
23091   EVT InVT2 = VecIn2.getNode() ? VecIn2.getValueType() : InVT1;
23092 
23093   unsigned NumElems = VT.getVectorNumElements();
23094   unsigned ShuffleNumElems = NumElems;
23095 
23096   // If we artificially split a vector in two already, then the offsets in the
23097   // operands will all be based off of VecIn1, even those in VecIn2.
23098   unsigned Vec2Offset = DidSplitVec ? 0 : InVT1.getVectorNumElements();
23099 
23100   uint64_t VTSize = VT.getFixedSizeInBits();
23101   uint64_t InVT1Size = InVT1.getFixedSizeInBits();
23102   uint64_t InVT2Size = InVT2.getFixedSizeInBits();
23103 
23104   assert(InVT2Size <= InVT1Size &&
23105          "Inputs must be sorted to be in non-increasing vector size order.");
23106 
23107   // We can't generate a shuffle node with mismatched input and output types.
23108   // Try to make the types match the type of the output.
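  // E.g. (hypothetical types): for VT = v4i32 with two v2i32 inputs, the
  // inputs are concatenated into a single v4i32 operand; for a single v8i32
  // input, the wide vector is split in two.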
23109   if (InVT1 != VT || InVT2 != VT) {
23110     if ((VTSize % InVT1Size == 0) && InVT1 == InVT2) {
23111       // If the output vector length is a multiple of both input lengths,
23112       // we can concatenate them and pad the rest with undefs.
23113       unsigned NumConcats = VTSize / InVT1Size;
23114       assert(NumConcats >= 2 && "Concat needs at least two inputs!");
23115       SmallVector<SDValue, 2> ConcatOps(NumConcats, DAG.getUNDEF(InVT1));
23116       ConcatOps[0] = VecIn1;
23117       ConcatOps[1] = VecIn2 ? VecIn2 : DAG.getUNDEF(InVT1);
23118       VecIn1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps);
23119       VecIn2 = SDValue();
23120     } else if (InVT1Size == VTSize * 2) {
23121       if (!TLI.isExtractSubvectorCheap(VT, InVT1, NumElems))
23122         return SDValue();
23123 
23124       if (!VecIn2.getNode()) {
23125         // If we only have one input vector, and it's twice the size of the
23126         // output, split it in two.
23127         VecIn2 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, VecIn1,
23128                              DAG.getVectorIdxConstant(NumElems, DL));
23129         VecIn1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, VecIn1, ZeroIdx);
23130         // Since we now have shorter input vectors, adjust the offset of the
23131         // second vector's start.
23132         Vec2Offset = NumElems;
23133       } else {
23134         assert(InVT2Size <= InVT1Size &&
23135                "Second input is not going to be larger than the first one.");
23136 
23137         // VecIn1 is wider than the output, and we have another, possibly
23138         // smaller input. Pad the smaller input with undefs, shuffle at the
23139         // input vector width, and extract the output.
23140         // The shuffle type is different than VT, so check legality again.
23141         if (LegalOperations &&
23142             !TLI.isOperationLegal(ISD::VECTOR_SHUFFLE, InVT1))
23143           return SDValue();
23144 
23145         // Legalizing INSERT_SUBVECTOR is tricky - you basically have to
23146         // lower it back into a BUILD_VECTOR. So if the inserted type is
23147         // illegal, don't even try.
23148         if (InVT1 != InVT2) {
23149           if (!TLI.isTypeLegal(InVT2))
23150             return SDValue();
23151           VecIn2 = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, InVT1,
23152                                DAG.getUNDEF(InVT1), VecIn2, ZeroIdx);
23153         }
23154         ShuffleNumElems = NumElems * 2;
23155       }
23156     } else if (InVT2Size * 2 == VTSize && InVT1Size == VTSize) {
23157       SmallVector<SDValue, 2> ConcatOps(2, DAG.getUNDEF(InVT2));
23158       ConcatOps[0] = VecIn2;
23159       VecIn2 = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps);
23160     } else if (InVT1Size / VTSize > 1 && InVT1Size % VTSize == 0) {
23161       if (!TLI.isExtractSubvectorCheap(VT, InVT1, NumElems) ||
23162           !TLI.isTypeLegal(InVT1) || !TLI.isTypeLegal(InVT2))
23163         return SDValue();
23164       // If the dest vector has fewer than two elements, then using a shuffle
23165       // and an extract from larger regs will cost even more.
23166       if (VT.getVectorNumElements() <= 2 || !VecIn2.getNode())
23167         return SDValue();
23168       assert(InVT2Size <= InVT1Size &&
23169              "Second input is not going to be larger than the first one.");
23170 
23171       // VecIn1 is wider than the output, and we have another, possibly
23172       // smaller input. Pad the smaller input with undefs, shuffle at the
23173       // input vector width, and extract the output.
23174       // The shuffle type is different than VT, so check legality again.
23175       if (LegalOperations && !TLI.isOperationLegal(ISD::VECTOR_SHUFFLE, InVT1))
23176         return SDValue();
23177 
23178       if (InVT1 != InVT2) {
23179         VecIn2 = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, InVT1,
23180                              DAG.getUNDEF(InVT1), VecIn2, ZeroIdx);
23181       }
23182       ShuffleNumElems = InVT1Size / VTSize * NumElems;
23183     } else {
23184       // TODO: Support cases where the length mismatch isn't exactly by a
23185       // factor of 2.
23186       // TODO: Move this check upwards, so that if we have bad type
23187       // mismatches, we don't create any DAG nodes.
23188       return SDValue();
23189     }
23190   }
23191 
23192   // Initialize mask to undef.
23193   SmallVector<int, 8> Mask(ShuffleNumElems, -1);
23194 
23195   // Only need to run up to the number of elements actually used, not the
23196   // total number of elements in the shuffle - if we are shuffling a wider
23197   // vector, the high lanes should be set to undef.
23198   for (unsigned i = 0; i != NumElems; ++i) {
23199     if (VectorMask[i] <= 0)
23200       continue;
23201 
23202     unsigned ExtIndex = N->getOperand(i).getConstantOperandVal(1);
23203     if (VectorMask[i] == (int)LeftIdx) {
23204       Mask[i] = ExtIndex;
23205     } else if (VectorMask[i] == (int)LeftIdx + 1) {
23206       Mask[i] = Vec2Offset + ExtIndex;
23207     }
23208   }
23209 
23210   // The type of the input vectors may have changed above.
23211   InVT1 = VecIn1.getValueType();
23212 
23213   // If we already have a VecIn2, it should have the same type as VecIn1.
23214   // If we don't, get an undef/zero vector of the appropriate type.
23215   VecIn2 = VecIn2.getNode() ? VecIn2 : DAG.getUNDEF(InVT1);
23216   assert(InVT1 == VecIn2.getValueType() && "Unexpected second input type.");
23217 
23218   SDValue Shuffle = DAG.getVectorShuffle(InVT1, DL, VecIn1, VecIn2, Mask);
23219   if (ShuffleNumElems > NumElems)
23220     Shuffle = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Shuffle, ZeroIdx);
23221 
23222   return Shuffle;
23223 }
23224 
23225 static SDValue reduceBuildVecToShuffleWithZero(SDNode *BV, SelectionDAG &DAG) {
23226   assert(BV->getOpcode() == ISD::BUILD_VECTOR && "Expected build vector");
23227 
23228   // First, determine where the build vector is not undef.
23229   // TODO: We could extend this to handle zero elements as well as undefs.
23230   int NumBVOps = BV->getNumOperands();
23231   int ZextElt = -1;
23232   for (int i = 0; i != NumBVOps; ++i) {
23233     SDValue Op = BV->getOperand(i);
23234     if (Op.isUndef())
23235       continue;
23236     if (ZextElt == -1)
23237       ZextElt = i;
23238     else
23239       return SDValue();
23240   }
23241   // Bail out if there's no non-undef element.
23242   if (ZextElt == -1)
23243     return SDValue();
23244 
23245   // The build vector contains some number of undef elements and exactly
23246   // one other element. That other element must be a zero-extended scalar
23247   // extracted from a vector at a constant index to turn this into a shuffle.
23248   // Also, require that the build vector does not implicitly truncate/extend
23249   // its elements.
23250   // TODO: This could be enhanced to allow ANY_EXTEND as well as ZERO_EXTEND.
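  // E.g. (little-endian, hypothetical DAG):
  //   v2i64 = build_vector (i64 zext (i32 extract_vector_elt v4i32:V, 1)), undef
  //     --> v2i64 = bitcast (vector_shuffle<1,4,u,u> V, zerovec)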
23251   EVT VT = BV->getValueType(0);
23252   SDValue Zext = BV->getOperand(ZextElt);
23253   if (Zext.getOpcode() != ISD::ZERO_EXTEND || !Zext.hasOneUse() ||
23254       Zext.getOperand(0).getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
23255       !isa<ConstantSDNode>(Zext.getOperand(0).getOperand(1)) ||
23256       Zext.getValueSizeInBits() != VT.getScalarSizeInBits())
23257     return SDValue();
23258 
23259   // The zero-extend width must be a multiple of the source size, and we must
23260   // be building a vector of the same size as the source of the extract element.
23261   SDValue Extract = Zext.getOperand(0);
23262   unsigned DestSize = Zext.getValueSizeInBits();
23263   unsigned SrcSize = Extract.getValueSizeInBits();
23264   if (DestSize % SrcSize != 0 ||
23265       Extract.getOperand(0).getValueSizeInBits() != VT.getSizeInBits())
23266     return SDValue();
23267 
23268   // Create a shuffle mask that will combine the extracted element with zeros
23269   // and undefs.
23270   int ZextRatio = DestSize / SrcSize;
23271   int NumMaskElts = NumBVOps * ZextRatio;
23272   SmallVector<int, 32> ShufMask(NumMaskElts, -1);
23273   for (int i = 0; i != NumMaskElts; ++i) {
23274     if (i / ZextRatio == ZextElt) {
23275       // The low bits of the (potentially translated) extracted element map to
23276       // the source vector. The high bits map to zero. We will use a zero vector
23277       // as the 2nd source operand of the shuffle, so use the 1st element of
23278       // that vector (mask value is number-of-elements) for the high bits.
23279       int Low = DAG.getDataLayout().isBigEndian() ? (ZextRatio - 1) : 0;
23280       ShufMask[i] = (i % ZextRatio == Low) ? Extract.getConstantOperandVal(1)
23281                                            : NumMaskElts;
23282     }
23283 
23284     // Undef elements of the build vector remain undef because we initialize
23285     // the shuffle mask with -1.
23286   }
23287 
23288   // buildvec undef, ..., (zext (extractelt V, IndexC)), undef... -->
23289   // bitcast (shuffle V, ZeroVec, VectorMask)
23290   SDLoc DL(BV);
23291   EVT VecVT = Extract.getOperand(0).getValueType();
23292   SDValue ZeroVec = DAG.getConstant(0, DL, VecVT);
23293   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
23294   SDValue Shuf = TLI.buildLegalVectorShuffle(VecVT, DL, Extract.getOperand(0),
23295                                              ZeroVec, ShufMask, DAG);
23296   if (!Shuf)
23297     return SDValue();
23298   return DAG.getBitcast(VT, Shuf);
23299 }
23300 
23301 // FIXME: promote to STLExtras.
23302 template <typename R, typename T>
23303 static auto getFirstIndexOf(R &&Range, const T &Val) {
23304   auto I = find(Range, Val);
23305   if (I == Range.end())
23306     return static_cast<decltype(std::distance(Range.begin(), I))>(-1);
23307   return std::distance(Range.begin(), I);
23308 }
23309 
23310 // Check to see if this is a BUILD_VECTOR of a bunch of EXTRACT_VECTOR_ELT
23311 // operations. If the types of the vectors we're extracting from allow it,
23312 // turn this into a vector_shuffle node.
23313 SDValue DAGCombiner::reduceBuildVecToShuffle(SDNode *N) {
23314   SDLoc DL(N);
23315   EVT VT = N->getValueType(0);
23316 
23317   // Only type-legal BUILD_VECTOR nodes are converted to shuffle nodes.
23318   if (!isTypeLegal(VT))
23319     return SDValue();
23320 
23321   if (SDValue V = reduceBuildVecToShuffleWithZero(N, DAG))
23322     return V;
23323 
23324   // May only combine to shuffle after legalize if shuffle is legal.
23325   if (LegalOperations && !TLI.isOperationLegal(ISD::VECTOR_SHUFFLE, VT))
23326     return SDValue();
23327 
23328   bool UsesZeroVector = false;
23329   unsigned NumElems = N->getNumOperands();
23330 
23331   // Record, for each element of the newly built vector, which input vector
23332   // that element comes from. -1 stands for undef, 0 for the zero vector,
23333   // and positive values for the input vectors.
23334   // VectorMask maps each element to its vector number, and VecIn maps vector
23335   // numbers to their initial SDValues.
23336 
23337   SmallVector<int, 8> VectorMask(NumElems, -1);
23338   SmallVector<SDValue, 8> VecIn;
23339   VecIn.push_back(SDValue());
23340 
23341   for (unsigned i = 0; i != NumElems; ++i) {
23342     SDValue Op = N->getOperand(i);
23343 
23344     if (Op.isUndef())
23345       continue;
23346 
23347     // See if we can use a blend with a zero vector.
23348     // TODO: Should we generalize this to a blend with an arbitrary constant
23349     // vector?
23350     if (isNullConstant(Op) || isNullFPConstant(Op)) {
23351       UsesZeroVector = true;
23352       VectorMask[i] = 0;
23353       continue;
23354     }
23355 
23356     // Not an undef or zero. If the input is something other than an
23357     // EXTRACT_VECTOR_ELT with an in-range constant index, bail out.
23358     if (Op.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
23359         !isa<ConstantSDNode>(Op.getOperand(1)))
23360       return SDValue();
23361     SDValue ExtractedFromVec = Op.getOperand(0);
23362 
23363     if (ExtractedFromVec.getValueType().isScalableVector())
23364       return SDValue();
23365 
23366     const APInt &ExtractIdx = Op.getConstantOperandAPInt(1);
23367     if (ExtractIdx.uge(ExtractedFromVec.getValueType().getVectorNumElements()))
23368       return SDValue();
23369 
23370     // All inputs must have the same element type as the output.
23371     if (VT.getVectorElementType() !=
23372         ExtractedFromVec.getValueType().getVectorElementType())
23373       return SDValue();
23374 
23375     // Have we seen this input vector before?
23376     // The vectors are expected to be tiny (usually 1 or 2 elements), so using
23377     // a map back from SDValues to numbers isn't worth it.
23378     int Idx = getFirstIndexOf(VecIn, ExtractedFromVec);
23379     if (Idx == -1) { // A new source vector?
23380       Idx = VecIn.size();
23381       VecIn.push_back(ExtractedFromVec);
23382     }
23383 
23384     VectorMask[i] = Idx;
23385   }
23386 
23387   // If we didn't find at least one input vector, bail out.
23388   if (VecIn.size() < 2)
23389     return SDValue();
23390 
23391   // If all the operands of BUILD_VECTOR extract from the same
23392   // vector, then split the vector efficiently based on the maximum
23393   // vector access index and adjust the VectorMask and
23394   // VecIn accordingly.
23395   bool DidSplitVec = false;
23396   if (VecIn.size() == 2) {
23397     unsigned MaxIndex = 0;
23398     unsigned NearestPow2 = 0;
23399     SDValue Vec = VecIn.back();
23400     EVT InVT = Vec.getValueType();
23401     SmallVector<unsigned, 8> IndexVec(NumElems, 0);
23402 
23403     for (unsigned i = 0; i < NumElems; i++) {
23404       if (VectorMask[i] <= 0)
23405         continue;
23406       unsigned Index = N->getOperand(i).getConstantOperandVal(1);
23407       IndexVec[i] = Index;
23408       MaxIndex = std::max(MaxIndex, Index);
23409     }
23410 
23411     NearestPow2 = PowerOf2Ceil(MaxIndex);
23412     if (InVT.isSimple() && NearestPow2 > 2 && MaxIndex < NearestPow2 &&
23413         NumElems * 2 < NearestPow2) {
23414       unsigned SplitSize = NearestPow2 / 2;
23415       EVT SplitVT = EVT::getVectorVT(*DAG.getContext(),
23416                                      InVT.getVectorElementType(), SplitSize);
23417       if (TLI.isTypeLegal(SplitVT) &&
23418           SplitSize + SplitVT.getVectorNumElements() <=
23419               InVT.getVectorNumElements()) {
23420         SDValue VecIn2 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SplitVT, Vec,
23421                                      DAG.getVectorIdxConstant(SplitSize, DL));
23422         SDValue VecIn1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SplitVT, Vec,
23423                                      DAG.getVectorIdxConstant(0, DL));
23424         VecIn.pop_back();
23425         VecIn.push_back(VecIn1);
23426         VecIn.push_back(VecIn2);
23427         DidSplitVec = true;
23428 
23429         for (unsigned i = 0; i < NumElems; i++) {
23430           if (VectorMask[i] <= 0)
23431             continue;
23432           VectorMask[i] = (IndexVec[i] < SplitSize) ? 1 : 2;
23433         }
23434       }
23435     }
23436   }
23437 
23438   // Sort input vectors by decreasing vector element count,
23439   // while preserving the relative order of equally-sized vectors.
23440   // Note that we keep the first "implicit" zero vector as-is.
23441   SmallVector<SDValue, 8> SortedVecIn(VecIn);
23442   llvm::stable_sort(MutableArrayRef<SDValue>(SortedVecIn).drop_front(),
23443                     [](const SDValue &a, const SDValue &b) {
23444                       return a.getValueType().getVectorNumElements() >
23445                              b.getValueType().getVectorNumElements();
23446                     });
23447 
23448   // We now also need to rebuild the VectorMask, because it referenced element
23449   // order in VecIn, and we just sorted them.
23450   for (int &SourceVectorIndex : VectorMask) {
23451     if (SourceVectorIndex <= 0)
23452       continue;
23453     unsigned Idx = getFirstIndexOf(SortedVecIn, VecIn[SourceVectorIndex]);
23454     assert(Idx > 0 && Idx < SortedVecIn.size() &&
23455            VecIn[SourceVectorIndex] == SortedVecIn[Idx] && "Remapping failure");
23456     SourceVectorIndex = Idx;
23457   }
23458 
23459   VecIn = std::move(SortedVecIn);
23460 
23461   // TODO: Should this fire if some of the input vectors have an illegal type
23462   // (like it does now), or should we let legalization run its course first?
23463 
23464   // Shuffle phase:
23465   // Take pairs of vectors, and shuffle them so that the result has elements
23466   // from these vectors in the correct places.
23467   // For example, given:
23468   // t10: i32 = extract_vector_elt t1, Constant:i64<0>
23469   // t11: i32 = extract_vector_elt t2, Constant:i64<0>
23470   // t12: i32 = extract_vector_elt t3, Constant:i64<0>
23471   // t13: i32 = extract_vector_elt t1, Constant:i64<1>
23472   // t14: v4i32 = BUILD_VECTOR t10, t11, t12, t13
23473   // We will generate:
23474   // t20: v4i32 = vector_shuffle<0,4,u,1> t1, t2
23475   // t21: v4i32 = vector_shuffle<u,u,0,u> t3, undef
23476   SmallVector<SDValue, 4> Shuffles;
23477   for (unsigned In = 0, Len = (VecIn.size() / 2); In < Len; ++In) {
23478     unsigned LeftIdx = 2 * In + 1;
23479     SDValue VecLeft = VecIn[LeftIdx];
23480     SDValue VecRight =
23481         (LeftIdx + 1) < VecIn.size() ? VecIn[LeftIdx + 1] : SDValue();
23482 
23483     if (SDValue Shuffle = createBuildVecShuffle(DL, N, VectorMask, VecLeft,
23484                                                 VecRight, LeftIdx, DidSplitVec))
23485       Shuffles.push_back(Shuffle);
23486     else
23487       return SDValue();
23488   }
23489 
23490   // If we need the zero vector as an "ingredient" in the blend tree, add it
23491   // to the list of shuffles.
23492   if (UsesZeroVector)
23493     Shuffles.push_back(VT.isInteger() ? DAG.getConstant(0, DL, VT)
23494                                       : DAG.getConstantFP(0.0, DL, VT));
23495 
23496   // If we only have one shuffle, we're done.
23497   if (Shuffles.size() == 1)
23498     return Shuffles[0];
23499 
23500   // Update the vector mask to point to the post-shuffle vectors.
23501   for (int &Vec : VectorMask)
23502     if (Vec == 0)
23503       Vec = Shuffles.size() - 1;
23504     else
23505       Vec = (Vec - 1) / 2;
23506 
23507   // More than one shuffle. Generate a binary tree of blends, e.g. if from
23508   // the previous step we got the set of shuffles t10, t11, t12, t13, we will
23509   // generate:
23510   // t10: v8i32 = vector_shuffle<0,8,u,u,u,u,u,u> t1, t2
23511   // t11: v8i32 = vector_shuffle<u,u,0,8,u,u,u,u> t3, t4
23512   // t12: v8i32 = vector_shuffle<u,u,u,u,0,8,u,u> t5, t6
23513   // t13: v8i32 = vector_shuffle<u,u,u,u,u,u,0,8> t7, t8
23514   // t20: v8i32 = vector_shuffle<0,1,10,11,u,u,u,u> t10, t11
23515   // t21: v8i32 = vector_shuffle<u,u,u,u,4,5,14,15> t12, t13
23516   // t30: v8i32 = vector_shuffle<0,1,2,3,12,13,14,15> t20, t21
23517 
23518   // Make sure the initial size of the shuffle list is even.
23519   if (Shuffles.size() % 2)
23520     Shuffles.push_back(DAG.getUNDEF(VT));
23521 
23522   for (unsigned CurSize = Shuffles.size(); CurSize > 1; CurSize /= 2) {
23523     if (CurSize % 2) {
23524       Shuffles[CurSize] = DAG.getUNDEF(VT);
23525       CurSize++;
23526     }
23527     for (unsigned In = 0, Len = CurSize / 2; In < Len; ++In) {
23528       int Left = 2 * In;
23529       int Right = 2 * In + 1;
23530       SmallVector<int, 8> Mask(NumElems, -1);
23531       SDValue L = Shuffles[Left];
23532       ArrayRef<int> LMask;
23533       bool IsLeftShuffle = L.getOpcode() == ISD::VECTOR_SHUFFLE &&
23534                            L.use_empty() && L.getOperand(1).isUndef() &&
23535                            L.getOperand(0).getValueType() == L.getValueType();
23536       if (IsLeftShuffle) {
23537         LMask = cast<ShuffleVectorSDNode>(L.getNode())->getMask();
23538         L = L.getOperand(0);
23539       }
23540       SDValue R = Shuffles[Right];
23541       ArrayRef<int> RMask;
23542       bool IsRightShuffle = R.getOpcode() == ISD::VECTOR_SHUFFLE &&
23543                             R.use_empty() && R.getOperand(1).isUndef() &&
23544                             R.getOperand(0).getValueType() == R.getValueType();
23545       if (IsRightShuffle) {
23546         RMask = cast<ShuffleVectorSDNode>(R.getNode())->getMask();
23547         R = R.getOperand(0);
23548       }
23549       for (unsigned I = 0; I != NumElems; ++I) {
23550         if (VectorMask[I] == Left) {
23551           Mask[I] = I;
23552           if (IsLeftShuffle)
23553             Mask[I] = LMask[I];
23554           VectorMask[I] = In;
23555         } else if (VectorMask[I] == Right) {
23556           Mask[I] = I + NumElems;
23557           if (IsRightShuffle)
23558             Mask[I] = RMask[I] + NumElems;
23559           VectorMask[I] = In;
23560         }
23561       }
23562 
23563       Shuffles[In] = DAG.getVectorShuffle(VT, DL, L, R, Mask);
23564     }
23565   }
23566   return Shuffles[0];
23567 }
23568 
23569 // Try to turn a build vector of zero extends of extract vector elts into
23570 // a vector zero extend and possibly an extract subvector.
23571 // TODO: Support sign extend?
23572 // TODO: Allow undef elements?
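// E.g. (hypothetical DAG, extracting elements 4..7 of x):
//   v4i32 = build_vector (i32 zext (i16 extract_vector_elt v8i16:x, 4)), ...,
//                        (i32 zext (i16 extract_vector_elt x, 7))
//     --> v4i32 = zero_extend (v4i16 extract_subvector x, 4)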
23573 SDValue DAGCombiner::convertBuildVecZextToZext(SDNode *N) {
23574   if (LegalOperations)
23575     return SDValue();
23576 
23577   EVT VT = N->getValueType(0);
23578 
23579   bool FoundZeroExtend = false;
23580   SDValue Op0 = N->getOperand(0);
23581   auto checkElem = [&](SDValue Op) -> int64_t {
23582     unsigned Opc = Op.getOpcode();
23583     FoundZeroExtend |= (Opc == ISD::ZERO_EXTEND);
23584     if ((Opc == ISD::ZERO_EXTEND || Opc == ISD::ANY_EXTEND) &&
23585         Op.getOperand(0).getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
23586         Op0.getOperand(0).getOperand(0) == Op.getOperand(0).getOperand(0))
23587       if (auto *C = dyn_cast<ConstantSDNode>(Op.getOperand(0).getOperand(1)))
23588         return C->getZExtValue();
23589     return -1;
23590   };
23591 
23592   // Make sure the first element matches
23593   // (zext (extract_vector_elt X, C))
23594   // Offset must be a constant multiple of the
23595   // known-minimum vector length of the result type.
23596   int64_t Offset = checkElem(Op0);
23597   if (Offset < 0 || (Offset % VT.getVectorNumElements()) != 0)
23598     return SDValue();
23599 
23600   unsigned NumElems = N->getNumOperands();
23601   SDValue In = Op0.getOperand(0).getOperand(0);
23602   EVT InSVT = In.getValueType().getScalarType();
23603   EVT InVT = EVT::getVectorVT(*DAG.getContext(), InSVT, NumElems);
23604 
23605   // Don't create an illegal input type after type legalization.
23606   if (LegalTypes && !TLI.isTypeLegal(InVT))
23607     return SDValue();
23608 
23609   // Ensure all the elements come from the same vector and are adjacent.
23610   for (unsigned i = 1; i != NumElems; ++i) {
23611     if ((Offset + i) != checkElem(N->getOperand(i)))
23612       return SDValue();
23613   }
23614 
23615   SDLoc DL(N);
23616   In = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, InVT, In,
23617                    Op0.getOperand(0).getOperand(1));
23618   return DAG.getNode(FoundZeroExtend ? ISD::ZERO_EXTEND : ISD::ANY_EXTEND, DL,
23619                      VT, In);
23620 }
23621 
23622 // If this is a very simple BUILD_VECTOR with first element being a ZERO_EXTEND,
23623 // and all other elements being constant zeros, granularize the BUILD_VECTOR's
23624 // element width, absorbing the ZERO_EXTEND, turning it into a constant zero op.
23625 // This pattern can appear during legalization.
23626 //
23627 // NOTE: This can be generalized to allow more than a single
23628 //       non-constant-zero op, UNDEFs, and to be KnownBits-based.
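// E.g. (little-endian, hypothetical DAG):
//   v2i64 = build_vector (i64 zero_extend i32:x), (i64 0)
//     --> v2i64 = bitcast (v4i32 build_vector x, 0, 0, 0)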
23629 SDValue DAGCombiner::convertBuildVecZextToBuildVecWithZeros(SDNode *N) {
23630   // Don't run this after legalization. Targets may have other preferences.
23631   if (Level >= AfterLegalizeDAG)
23632     return SDValue();
23633 
23634   // FIXME: support big-endian.
23635   if (DAG.getDataLayout().isBigEndian())
23636     return SDValue();
23637 
23638   EVT VT = N->getValueType(0);
23639   EVT OpVT = N->getOperand(0).getValueType();
23640   assert(!VT.isScalableVector() && "Encountered scalable BUILD_VECTOR?");
23641 
23642   EVT OpIntVT = EVT::getIntegerVT(*DAG.getContext(), OpVT.getSizeInBits());
23643 
23644   if (!TLI.isTypeLegal(OpIntVT) ||
23645       (LegalOperations && !TLI.isOperationLegalOrCustom(ISD::BITCAST, OpIntVT)))
23646     return SDValue();
23647 
23648   unsigned EltBitwidth = VT.getScalarSizeInBits();
23649   // NOTE: the actual width of operands may be wider than that!
23650 
23651   // Analyze all operands of this BUILD_VECTOR. What is the largest number of
23652   // active bits they all have? We'll want to truncate them all to that width.
23653   unsigned ActiveBits = 0;
23654   APInt KnownZeroOps(VT.getVectorNumElements(), 0);
23655   for (auto I : enumerate(N->ops())) {
23656     SDValue Op = I.value();
23657     // FIXME: support UNDEF elements?
23658     if (auto *Cst = dyn_cast<ConstantSDNode>(Op)) {
23659       unsigned OpActiveBits =
23660           Cst->getAPIntValue().trunc(EltBitwidth).getActiveBits();
23661       if (OpActiveBits == 0) {
23662         KnownZeroOps.setBit(I.index());
23663         continue;
23664       }
23665       // Profitability check: don't allow non-zero constant operands.
23666       return SDValue();
23667     }
23668     // Profitability check: there must only be a single non-zero operand,
23669     // and it must be the first operand of the BUILD_VECTOR.
23670     if (I.index() != 0)
23671       return SDValue();
23672     // The operand must be a zero-extension itself.
23673     // FIXME: this could be generalized to known leading zeros check.
23674     if (Op.getOpcode() != ISD::ZERO_EXTEND)
23675       return SDValue();
23676     unsigned CurrActiveBits =
23677         Op.getOperand(0).getValueSizeInBits().getFixedValue();
23678     assert(!ActiveBits && "Already encountered non-constant-zero operand?");
23679     ActiveBits = CurrActiveBits;
23680     // We want to at least halve the element size.
23681     if (2 * ActiveBits > EltBitwidth)
23682       return SDValue();
23683   }
23684 
23685   // This BUILD_VECTOR must have at least one non-constant-zero operand.
23686   if (ActiveBits == 0)
23687     return SDValue();
23688 
23689   // We have EltBitwidth bits; the *minimal* chunk size is ActiveBits. Into
23690   // how many chunks can we split our element width?
23691   EVT NewScalarIntVT, NewIntVT;
23692   std::optional<unsigned> Factor;
23693   // We can split the element into at least two chunks, but not into more
23694   // than |_ EltBitwidth / ActiveBits _| chunks. Find the largest split factor
23695   // for which the element width is a multiple of it,
23696   // and the resulting types/operations on that chunk width are legal.
23697   assert(2 * ActiveBits <= EltBitwidth &&
23698          "We know that half or less bits of the element are active.");
23699   for (unsigned Scale = EltBitwidth / ActiveBits; Scale >= 2; --Scale) {
23700     if (EltBitwidth % Scale != 0)
23701       continue;
23702     unsigned ChunkBitwidth = EltBitwidth / Scale;
23703     assert(ChunkBitwidth >= ActiveBits && "As per starting point.");
23704     NewScalarIntVT = EVT::getIntegerVT(*DAG.getContext(), ChunkBitwidth);
23705     NewIntVT = EVT::getVectorVT(*DAG.getContext(), NewScalarIntVT,
23706                                 Scale * N->getNumOperands());
23707     if (!TLI.isTypeLegal(NewScalarIntVT) || !TLI.isTypeLegal(NewIntVT) ||
23708         (LegalOperations &&
23709          !(TLI.isOperationLegalOrCustom(ISD::TRUNCATE, NewScalarIntVT) &&
23710            TLI.isOperationLegalOrCustom(ISD::BUILD_VECTOR, NewIntVT))))
23711       continue;
23712     Factor = Scale;
23713     break;
23714   }
23715   if (!Factor)
23716     return SDValue();
23717 
23718   SDLoc DL(N);
23719   SDValue ZeroOp = DAG.getConstant(0, DL, NewScalarIntVT);
23720 
23721   // Recreate the BUILD_VECTOR, with elements now being Factor times smaller.
23722   SmallVector<SDValue, 16> NewOps;
23723   NewOps.reserve(NewIntVT.getVectorNumElements());
23724   for (auto I : enumerate(N->ops())) {
23725     SDValue Op = I.value();
23726     assert(!Op.isUndef() && "FIXME: after allowing UNDEF's, handle them here.");
23727     unsigned SrcOpIdx = I.index();
23728     if (KnownZeroOps[SrcOpIdx]) {
23729       NewOps.append(*Factor, ZeroOp);
23730       continue;
23731     }
23732     Op = DAG.getBitcast(OpIntVT, Op);
23733     Op = DAG.getNode(ISD::TRUNCATE, DL, NewScalarIntVT, Op);
23734     NewOps.emplace_back(Op);
23735     NewOps.append(*Factor - 1, ZeroOp);
23736   }
23737   assert(NewOps.size() == NewIntVT.getVectorNumElements());
23738   SDValue NewBV = DAG.getBuildVector(NewIntVT, DL, NewOps);
23739   NewBV = DAG.getBitcast(VT, NewBV);
23740   return NewBV;
23741 }
23742 
23743 SDValue DAGCombiner::visitBUILD_VECTOR(SDNode *N) {
23744   EVT VT = N->getValueType(0);
23745 
23746   // A vector built entirely of undefs is undef.
23747   if (ISD::allOperandsUndef(N))
23748     return DAG.getUNDEF(VT);
23749 
23750   // If this is a splat of a bitcast from another vector, change to a
23751   // concat_vector.
23752   // For example:
23753   //   (build_vector (i64 (bitcast (v2i32 X))), (i64 (bitcast (v2i32 X)))) ->
23754   //     (v2i64 (bitcast (concat_vectors (v2i32 X), (v2i32 X))))
23755   //
23756   // If X is a build_vector itself, the concat can become a larger build_vector.
23757   // TODO: Maybe this is useful for non-splat too?
23758   if (!LegalOperations) {
23759     SDValue Splat = cast<BuildVectorSDNode>(N)->getSplatValue();
23760     // Only change build_vector to a concat_vector if the splat value type is
23761     // the same as the vector element type.
23762     if (Splat && Splat.getValueType() == VT.getVectorElementType()) {
23763       Splat = peekThroughBitcasts(Splat);
23764       EVT SrcVT = Splat.getValueType();
23765       if (SrcVT.isVector()) {
23766         unsigned NumElts = N->getNumOperands() * SrcVT.getVectorNumElements();
23767         EVT NewVT = EVT::getVectorVT(*DAG.getContext(),
23768                                      SrcVT.getVectorElementType(), NumElts);
23769         if (!LegalTypes || TLI.isTypeLegal(NewVT)) {
23770           SmallVector<SDValue, 8> Ops(N->getNumOperands(), Splat);
23771           SDValue Concat =
23772               DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), NewVT, Ops);
23773           return DAG.getBitcast(VT, Concat);
23774         }
23775       }
23776     }
23777   }
23778 
23779   // Check if we can express the BUILD_VECTOR via a subvector extract.
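  // An illustrative sketch, with v8i32 V:
  //   (v2i32 build_vector (extract_elt V, 4), (extract_elt V, 5))
  //    --> (v2i32 extract_subvector V, 4)
  // Note the start offset must be a multiple of the result's element count.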
23780   if (!LegalTypes && (N->getNumOperands() > 1)) {
23781     SDValue Op0 = N->getOperand(0);
23782     auto checkElem = [&](SDValue Op) -> uint64_t {
23783       if ((Op.getOpcode() == ISD::EXTRACT_VECTOR_ELT) &&
23784           (Op0.getOperand(0) == Op.getOperand(0)))
23785         if (auto CNode = dyn_cast<ConstantSDNode>(Op.getOperand(1)))
23786           return CNode->getZExtValue();
23787       return -1;
23788     };
23789 
23790     int Offset = checkElem(Op0);
23791     for (unsigned i = 0; i < N->getNumOperands(); ++i) {
23792       if (Offset + i != checkElem(N->getOperand(i))) {
23793         Offset = -1;
23794         break;
23795       }
23796     }
23797 
23798     if ((Offset == 0) &&
23799         (Op0.getOperand(0).getValueType() == N->getValueType(0)))
23800       return Op0.getOperand(0);
23801     if ((Offset != -1) &&
23802         ((Offset % N->getValueType(0).getVectorNumElements()) ==
23803          0)) // IDX must be multiple of output size.
23804       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), N->getValueType(0),
23805                          Op0.getOperand(0), Op0.getOperand(1));
23806   }
23807 
23808   if (SDValue V = convertBuildVecZextToZext(N))
23809     return V;
23810 
23811   if (SDValue V = convertBuildVecZextToBuildVecWithZeros(N))
23812     return V;
23813 
23814   if (SDValue V = reduceBuildVecExtToExtBuildVec(N))
23815     return V;
23816 
23817   if (SDValue V = reduceBuildVecTruncToBitCast(N))
23818     return V;
23819 
23820   if (SDValue V = reduceBuildVecToShuffle(N))
23821     return V;
23822 
23823   // A splat of a single element is a SPLAT_VECTOR if supported on the target.
23824   // Do this late as some of the above may replace the splat.
23825   if (TLI.getOperationAction(ISD::SPLAT_VECTOR, VT) != TargetLowering::Expand)
23826     if (SDValue V = cast<BuildVectorSDNode>(N)->getSplatValue()) {
23827       assert(!V.isUndef() && "Splat of undef should have been handled earlier");
23828       return DAG.getNode(ISD::SPLAT_VECTOR, SDLoc(N), VT, V);
23829     }
23830 
23831   return SDValue();
23832 }
23833 
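// Fold a CONCAT_VECTORS whose operands are all bitcast scalars (or UNDEF) into
// a BUILD_VECTOR of those scalars, bitcast to the result type. An illustrative
// sketch (X and Y are placeholders; assumes the types involved are legal):
//   (v4i32 concat (v2i32 bitcast (f64 X)), (v2i32 bitcast (i64 Y)))
//    --> (v4i32 bitcast (v2f64 build_vector X, (f64 bitcast Y)))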
23834 static SDValue combineConcatVectorOfScalars(SDNode *N, SelectionDAG &DAG) {
23835   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
23836   EVT OpVT = N->getOperand(0).getValueType();
23837 
23838   // If the operands are legal vectors, leave them alone.
23839   if (TLI.isTypeLegal(OpVT) || OpVT.isScalableVector())
23840     return SDValue();
23841 
23842   SDLoc DL(N);
23843   EVT VT = N->getValueType(0);
23844   SmallVector<SDValue, 8> Ops;
23845   EVT SVT = EVT::getIntegerVT(*DAG.getContext(), OpVT.getSizeInBits());
23846 
23847   // Keep track of what we encounter.
23848   bool AnyInteger = false;
23849   bool AnyFP = false;
23850   for (const SDValue &Op : N->ops()) {
23851     if (ISD::BITCAST == Op.getOpcode() &&
23852         !Op.getOperand(0).getValueType().isVector())
23853       Ops.push_back(Op.getOperand(0));
23854     else if (ISD::UNDEF == Op.getOpcode())
23855       Ops.push_back(DAG.getNode(ISD::UNDEF, DL, SVT));
23856     else
23857       return SDValue();
23858 
23859     // Note whether we encounter an integer or floating point scalar.
23860     // If it's neither, bail out, it could be something weird like x86mmx.
23861     EVT LastOpVT = Ops.back().getValueType();
23862     if (LastOpVT.isFloatingPoint())
23863       AnyFP = true;
23864     else if (LastOpVT.isInteger())
23865       AnyInteger = true;
23866     else
23867       return SDValue();
23868   }
23869 
23870   // If any of the operands is a floating point scalar bitcast to a vector,
23871   // use floating point types throughout, and bitcast everything.
23872   // Replace UNDEFs by another scalar UNDEF node, of the final desired type.
23873   if (AnyFP) {
23874     SVT = EVT::getFloatingPointVT(OpVT.getSizeInBits());
23875     if (AnyInteger) {
23876       for (SDValue &Op : Ops) {
23877         if (Op.getValueType() == SVT)
23878           continue;
23879         if (Op.isUndef())
23880           Op = DAG.getNode(ISD::UNDEF, DL, SVT);
23881         else
23882           Op = DAG.getBitcast(SVT, Op);
23883       }
23884     }
23885   }
23886 
23887   EVT VecVT = EVT::getVectorVT(*DAG.getContext(), SVT,
23888                                VT.getSizeInBits() / SVT.getSizeInBits());
23889   return DAG.getBitcast(VT, DAG.getBuildVector(VecVT, DL, Ops));
23890 }
23891 
23892 // Attempt to merge nested concat_vectors/undefs.
23893 // Fold concat_vectors(concat_vectors(x,y,z,w),u,u,concat_vectors(a,b,c,d))
23894 //  --> concat_vectors(x,y,z,w,u,u,u,u,u,u,u,u,a,b,c,d)
23895 static SDValue combineConcatVectorOfConcatVectors(SDNode *N,
23896                                                   SelectionDAG &DAG) {
23897   EVT VT = N->getValueType(0);
23898 
23899   // Ensure we're concatenating UNDEF and CONCAT_VECTORS nodes of similar types.
23900   EVT SubVT;
23901   SDValue FirstConcat;
23902   for (const SDValue &Op : N->ops()) {
23903     if (Op.isUndef())
23904       continue;
23905     if (Op.getOpcode() != ISD::CONCAT_VECTORS)
23906       return SDValue();
23907     if (!FirstConcat) {
23908       SubVT = Op.getOperand(0).getValueType();
23909       if (!DAG.getTargetLoweringInfo().isTypeLegal(SubVT))
23910         return SDValue();
23911       FirstConcat = Op;
23912       continue;
23913     }
23914     if (SubVT != Op.getOperand(0).getValueType())
23915       return SDValue();
23916   }
23917   assert(FirstConcat && "Concat of all-undefs found");
23918 
23919   SmallVector<SDValue> ConcatOps;
23920   for (const SDValue &Op : N->ops()) {
23921     if (Op.isUndef()) {
23922       ConcatOps.append(FirstConcat->getNumOperands(), DAG.getUNDEF(SubVT));
23923       continue;
23924     }
23925     ConcatOps.append(Op->op_begin(), Op->op_end());
23926   }
23927   return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, ConcatOps);
23928 }
23929 
23930 // Check to see if this is a CONCAT_VECTORS of a bunch of EXTRACT_SUBVECTOR
23931 // operations. If so, and if the EXTRACT_SUBVECTOR vector inputs come from at
23932 // most two distinct vectors the same size as the result, attempt to turn this
23933 // into a legal shuffle.
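// An illustrative sketch, with v4i32 A and B:
//   (v4i32 concat (v2i32 extract_subvector A, 2), (v2i32 extract_subvector B, 0))
//    --> (v4i32 vector_shuffle A, B, <2, 3, 4, 5>)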
23934 static SDValue combineConcatVectorOfExtracts(SDNode *N, SelectionDAG &DAG) {
23935   EVT VT = N->getValueType(0);
23936   EVT OpVT = N->getOperand(0).getValueType();
23937 
23938   // We currently can't generate an appropriate shuffle for a scalable vector.
23939   if (VT.isScalableVector())
23940     return SDValue();
23941 
23942   int NumElts = VT.getVectorNumElements();
23943   int NumOpElts = OpVT.getVectorNumElements();
23944 
23945   SDValue SV0 = DAG.getUNDEF(VT), SV1 = DAG.getUNDEF(VT);
23946   SmallVector<int, 8> Mask;
23947 
23948   for (SDValue Op : N->ops()) {
23949     Op = peekThroughBitcasts(Op);
23950 
23951     // UNDEF nodes convert to UNDEF shuffle mask values.
23952     if (Op.isUndef()) {
23953       Mask.append((unsigned)NumOpElts, -1);
23954       continue;
23955     }
23956 
23957     if (Op.getOpcode() != ISD::EXTRACT_SUBVECTOR)
23958       return SDValue();
23959 
23960     // What vector are we extracting the subvector from and at what index?
23961     SDValue ExtVec = Op.getOperand(0);
23962     int ExtIdx = Op.getConstantOperandVal(1);
23963 
23964     // We want the EVT of the original extraction to correctly scale the
23965     // extraction index.
23966     EVT ExtVT = ExtVec.getValueType();
23967     ExtVec = peekThroughBitcasts(ExtVec);
23968 
23969     // UNDEF nodes convert to UNDEF shuffle mask values.
23970     if (ExtVec.isUndef()) {
23971       Mask.append((unsigned)NumOpElts, -1);
23972       continue;
23973     }
23974 
23975     // Ensure that we are extracting a subvector from a vector the same
23976     // size as the result.
23977     if (ExtVT.getSizeInBits() != VT.getSizeInBits())
23978       return SDValue();
23979 
23980     // Scale the subvector index to account for any bitcast.
23981     int NumExtElts = ExtVT.getVectorNumElements();
23982     if (0 == (NumExtElts % NumElts))
23983       ExtIdx /= (NumExtElts / NumElts);
23984     else if (0 == (NumElts % NumExtElts))
23985       ExtIdx *= (NumElts / NumExtElts);
23986     else
23987       return SDValue();
23988 
23989     // At most we can reference 2 inputs in the final shuffle.
23990     if (SV0.isUndef() || SV0 == ExtVec) {
23991       SV0 = ExtVec;
23992       for (int i = 0; i != NumOpElts; ++i)
23993         Mask.push_back(i + ExtIdx);
23994     } else if (SV1.isUndef() || SV1 == ExtVec) {
23995       SV1 = ExtVec;
23996       for (int i = 0; i != NumOpElts; ++i)
23997         Mask.push_back(i + ExtIdx + NumElts);
23998     } else {
23999       return SDValue();
24000     }
24001   }
24002 
24003   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
24004   return TLI.buildLegalVectorShuffle(VT, SDLoc(N), DAG.getBitcast(VT, SV0),
24005                                      DAG.getBitcast(VT, SV1), Mask, DAG);
24006 }
24007 
24008 static SDValue combineConcatVectorOfCasts(SDNode *N, SelectionDAG &DAG) {
24009   unsigned CastOpcode = N->getOperand(0).getOpcode();
24010   switch (CastOpcode) {
24011   case ISD::SINT_TO_FP:
24012   case ISD::UINT_TO_FP:
24013   case ISD::FP_TO_SINT:
24014   case ISD::FP_TO_UINT:
24015     // TODO: Allow more opcodes?
24016     //  case ISD::BITCAST:
24017     //  case ISD::TRUNCATE:
24018     //  case ISD::ZERO_EXTEND:
24019     //  case ISD::SIGN_EXTEND:
24020     //  case ISD::FP_EXTEND:
24021     break;
24022   default:
24023     return SDValue();
24024   }
24025 
24026   EVT SrcVT = N->getOperand(0).getOperand(0).getValueType();
24027   if (!SrcVT.isVector())
24028     return SDValue();
24029 
24030   // All operands of the concat must be the same kind of cast from the same
24031   // source type.
24032   SmallVector<SDValue, 4> SrcOps;
24033   for (SDValue Op : N->ops()) {
24034     if (Op.getOpcode() != CastOpcode || !Op.hasOneUse() ||
24035         Op.getOperand(0).getValueType() != SrcVT)
24036       return SDValue();
24037     SrcOps.push_back(Op.getOperand(0));
24038   }
24039 
24040   // The wider cast must be supported by the target. This is unusual because
24041   // the operation support type parameter depends on the opcode. In addition,
24042   // check the other type in the cast to make sure this is really legal.
24043   EVT VT = N->getValueType(0);
24044   EVT SrcEltVT = SrcVT.getVectorElementType();
24045   ElementCount NumElts = SrcVT.getVectorElementCount() * N->getNumOperands();
24046   EVT ConcatSrcVT = EVT::getVectorVT(*DAG.getContext(), SrcEltVT, NumElts);
24047   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
24048   switch (CastOpcode) {
24049   case ISD::SINT_TO_FP:
24050   case ISD::UINT_TO_FP:
24051     if (!TLI.isOperationLegalOrCustom(CastOpcode, ConcatSrcVT) ||
24052         !TLI.isTypeLegal(VT))
24053       return SDValue();
24054     break;
24055   case ISD::FP_TO_SINT:
24056   case ISD::FP_TO_UINT:
24057     if (!TLI.isOperationLegalOrCustom(CastOpcode, VT) ||
24058         !TLI.isTypeLegal(ConcatSrcVT))
24059       return SDValue();
24060     break;
24061   default:
24062     llvm_unreachable("Unexpected cast opcode");
24063   }
24064 
24065   // concat (cast X), (cast Y)... -> cast (concat X, Y...)
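  // E.g. (illustrative, assuming the wider cast is legal):
  //   (v8f32 concat (v4f32 uint_to_fp X), (v4f32 uint_to_fp Y))
  //    --> (v8f32 uint_to_fp (v8i32 concat X, Y))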
24066   SDLoc DL(N);
24067   SDValue NewConcat = DAG.getNode(ISD::CONCAT_VECTORS, DL, ConcatSrcVT, SrcOps);
24068   return DAG.getNode(CastOpcode, DL, VT, NewConcat);
24069 }
24070 
24071 // See if this is a simple CONCAT_VECTORS with no UNDEF operands, and if one of
24072 // the operands is a SHUFFLE_VECTOR, and all other operands are also operands
24073 // to that SHUFFLE_VECTOR, create wider SHUFFLE_VECTOR.
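// An illustrative sketch, with v4i32 X as a placeholder:
//   (v8i32 concat (v4i32 shuffle X, undef, <3,2,1,0>), X)
//    --> (v8i32 shuffle (v8i32 concat X, undef), undef, <3,2,1,0,0,1,2,3>)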
24074 static SDValue combineConcatVectorOfShuffleAndItsOperands(
24075     SDNode *N, SelectionDAG &DAG, const TargetLowering &TLI, bool LegalTypes,
24076     bool LegalOperations) {
24077   EVT VT = N->getValueType(0);
24078   EVT OpVT = N->getOperand(0).getValueType();
24079   if (VT.isScalableVector())
24080     return SDValue();
24081 
24082   // For now, only allow simple 2-operand concatenations.
24083   if (N->getNumOperands() != 2)
24084     return SDValue();
24085 
24086   // Don't create illegal types/shuffles when not allowed to.
24087   if ((LegalTypes && !TLI.isTypeLegal(VT)) ||
24088       (LegalOperations &&
24089        !TLI.isOperationLegalOrCustom(ISD::VECTOR_SHUFFLE, VT)))
24090     return SDValue();
24091 
24092   // Analyze all of the operands of the CONCAT_VECTORS. Out of all of them,
24093   // we want to find one that is: (1) a SHUFFLE_VECTOR (2) only used by us,
24094   // and (3) all operands of CONCAT_VECTORS must be either that SHUFFLE_VECTOR,
24095   // or one of the operands of that SHUFFLE_VECTOR (but not UNDEF!).
24096   // (4) and for now, the SHUFFLE_VECTOR must be unary.
24097   ShuffleVectorSDNode *SVN = nullptr;
24098   for (SDValue Op : N->ops()) {
24099     if (auto *CurSVN = dyn_cast<ShuffleVectorSDNode>(Op);
24100         CurSVN && CurSVN->getOperand(1).isUndef() && N->isOnlyUserOf(CurSVN) &&
24101         all_of(N->ops(), [CurSVN](SDValue Op) {
24102           // FIXME: can we allow UNDEF operands?
24103           return !Op.isUndef() &&
24104                  (Op.getNode() == CurSVN || is_contained(CurSVN->ops(), Op));
24105         })) {
24106       SVN = CurSVN;
24107       break;
24108     }
24109   }
24110   if (!SVN)
24111     return SDValue();
24112 
24113   // We are going to pad the shuffle operands, so any index that was picking
24114   // from the second operand must be adjusted.
24115   SmallVector<int, 16> AdjustedMask;
24116   AdjustedMask.reserve(SVN->getMask().size());
24117   assert(SVN->getOperand(1).isUndef() && "Expected unary shuffle!");
24118   append_range(AdjustedMask, SVN->getMask());
24119 
24120   // Identity masks for the operands of the (padded) shuffle.
24121   SmallVector<int, 32> IdentityMask(2 * OpVT.getVectorNumElements());
24122   MutableArrayRef<int> FirstShufOpIdentityMask =
24123       MutableArrayRef<int>(IdentityMask)
24124           .take_front(OpVT.getVectorNumElements());
24125   MutableArrayRef<int> SecondShufOpIdentityMask =
24126       MutableArrayRef<int>(IdentityMask).take_back(OpVT.getVectorNumElements());
24127   std::iota(FirstShufOpIdentityMask.begin(), FirstShufOpIdentityMask.end(), 0);
24128   std::iota(SecondShufOpIdentityMask.begin(), SecondShufOpIdentityMask.end(),
24129             VT.getVectorNumElements());
24130 
24131   // New combined shuffle mask.
24132   SmallVector<int, 32> Mask;
24133   Mask.reserve(VT.getVectorNumElements());
24134   for (SDValue Op : N->ops()) {
24135     assert(!Op.isUndef() && "Not expecting to concatenate UNDEF.");
24136     if (Op.getNode() == SVN) {
24137       append_range(Mask, AdjustedMask);
24138       continue;
24139     }
24140     if (Op == SVN->getOperand(0)) {
24141       append_range(Mask, FirstShufOpIdentityMask);
24142       continue;
24143     }
24144     if (Op == SVN->getOperand(1)) {
24145       append_range(Mask, SecondShufOpIdentityMask);
24146       continue;
24147     }
24148     llvm_unreachable("Unexpected operand!");
24149   }
24150 
24151   // Don't create illegal shuffle masks.
24152   if (!TLI.isShuffleMaskLegal(Mask, VT))
24153     return SDValue();
24154 
24155   // Pad the shuffle operands with UNDEF.
24156   SDLoc dl(N);
24157   std::array<SDValue, 2> ShufOps;
24158   for (auto I : zip(SVN->ops(), ShufOps)) {
24159     SDValue ShufOp = std::get<0>(I);
24160     SDValue &NewShufOp = std::get<1>(I);
24161     if (ShufOp.isUndef())
24162       NewShufOp = DAG.getUNDEF(VT);
24163     else {
24164       SmallVector<SDValue, 2> ShufOpParts(N->getNumOperands(),
24165                                           DAG.getUNDEF(OpVT));
24166       ShufOpParts[0] = ShufOp;
24167       NewShufOp = DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, ShufOpParts);
24168     }
24169   }
24170   // Finally, create the new wide shuffle.
24171   return DAG.getVectorShuffle(VT, dl, ShufOps[0], ShufOps[1], Mask);
24172 }
24173 
24174 SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) {
24175   // If we only have one input vector, we don't need to do any concatenation.
24176   if (N->getNumOperands() == 1)
24177     return N->getOperand(0);
24178 
24179   // Check if all of the operands are undefs.
24180   EVT VT = N->getValueType(0);
24181   if (ISD::allOperandsUndef(N))
24182     return DAG.getUNDEF(VT);
24183 
24184   // Optimize concat_vectors where all but the first of the vectors are undef.
24185   if (all_of(drop_begin(N->ops()),
24186              [](const SDValue &Op) { return Op.isUndef(); })) {
24187     SDValue In = N->getOperand(0);
24188     assert(In.getValueType().isVector() && "Must concat vectors");
24189 
24190     // If the input is a concat_vectors, just make a larger concat by padding
24191     // with smaller undefs.
24192     //
24193     // Legalizing in AArch64TargetLowering::LowerCONCAT_VECTORS() and combining
24194     // here could cause an infinite loop. That legalization happens when LegalDAG
24195     // is true and the input of AArch64TargetLowering::LowerCONCAT_VECTORS() is
24196     // scalable.
24197     if (In.getOpcode() == ISD::CONCAT_VECTORS && In.hasOneUse() &&
24198         !(LegalDAG && In.getValueType().isScalableVector())) {
24199       unsigned NumOps = N->getNumOperands() * In.getNumOperands();
24200       SmallVector<SDValue, 4> Ops(In->op_begin(), In->op_end());
24201       Ops.resize(NumOps, DAG.getUNDEF(Ops[0].getValueType()));
24202       return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops);
24203     }
24204 
24205     SDValue Scalar = peekThroughOneUseBitcasts(In);
24206 
24207     // concat_vectors(scalar_to_vector(scalar), undef) ->
24208     //     scalar_to_vector(scalar)
24209     if (!LegalOperations && Scalar.getOpcode() == ISD::SCALAR_TO_VECTOR &&
24210          Scalar.hasOneUse()) {
24211       EVT SVT = Scalar.getValueType().getVectorElementType();
24212       if (SVT == Scalar.getOperand(0).getValueType())
24213         Scalar = Scalar.getOperand(0);
24214     }
24215 
24216     // concat_vectors(scalar, undef) -> scalar_to_vector(scalar)
24217     if (!Scalar.getValueType().isVector() && In.hasOneUse()) {
24218       // If the bitcast type isn't legal, it might be a trunc of a legal type;
24219       // look through the trunc so we can still do the transform:
24220       //   concat_vectors(trunc(scalar), undef) -> scalar_to_vector(scalar)
24221       if (Scalar->getOpcode() == ISD::TRUNCATE &&
24222           !TLI.isTypeLegal(Scalar.getValueType()) &&
24223           TLI.isTypeLegal(Scalar->getOperand(0).getValueType()))
24224         Scalar = Scalar->getOperand(0);
24225 
24226       EVT SclTy = Scalar.getValueType();
24227 
24228       if (!SclTy.isFloatingPoint() && !SclTy.isInteger())
24229         return SDValue();
24230 
24231       // Bail out if the vector size is not a multiple of the scalar size.
24232       if (VT.getSizeInBits() % SclTy.getSizeInBits())
24233         return SDValue();
24234 
24235       unsigned VNTNumElms = VT.getSizeInBits() / SclTy.getSizeInBits();
24236       if (VNTNumElms < 2)
24237         return SDValue();
24238 
24239       EVT NVT = EVT::getVectorVT(*DAG.getContext(), SclTy, VNTNumElms);
24240       if (!TLI.isTypeLegal(NVT) || !TLI.isTypeLegal(Scalar.getValueType()))
24241         return SDValue();
24242 
24243       SDValue Res = DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), NVT, Scalar);
24244       return DAG.getBitcast(VT, Res);
24245     }
24246   }
24247 
24248   // Fold any combination of BUILD_VECTOR or UNDEF nodes into one BUILD_VECTOR.
24249   // We have already tested above for an UNDEF only concatenation.
24250   // fold (concat_vectors (BUILD_VECTOR A, B, ...), (BUILD_VECTOR C, D, ...))
24251   // -> (BUILD_VECTOR A, B, ..., C, D, ...)
24252   auto IsBuildVectorOrUndef = [](const SDValue &Op) {
24253     return ISD::UNDEF == Op.getOpcode() || ISD::BUILD_VECTOR == Op.getOpcode();
24254   };
24255   if (llvm::all_of(N->ops(), IsBuildVectorOrUndef)) {
24256     SmallVector<SDValue, 8> Opnds;
24257     EVT SVT = VT.getScalarType();
24258 
24259     EVT MinVT = SVT;
24260     if (!SVT.isFloatingPoint()) {
24261       // If the BUILD_VECTORs are built from integers, they may have different
24262       // operand types. Get the smallest type and truncate all operands to it.
24263       bool FoundMinVT = false;
24264       for (const SDValue &Op : N->ops())
24265         if (ISD::BUILD_VECTOR == Op.getOpcode()) {
24266           EVT OpSVT = Op.getOperand(0).getValueType();
24267           MinVT = (!FoundMinVT || OpSVT.bitsLE(MinVT)) ? OpSVT : MinVT;
24268           FoundMinVT = true;
24269         }
24270       assert(FoundMinVT && "Concat vector type mismatch");
24271     }
24272 
24273     for (const SDValue &Op : N->ops()) {
24274       EVT OpVT = Op.getValueType();
24275       unsigned NumElts = OpVT.getVectorNumElements();
24276 
24277       if (ISD::UNDEF == Op.getOpcode())
24278         Opnds.append(NumElts, DAG.getUNDEF(MinVT));
24279 
24280       if (ISD::BUILD_VECTOR == Op.getOpcode()) {
24281         if (SVT.isFloatingPoint()) {
24282           assert(SVT == OpVT.getScalarType() && "Concat vector type mismatch");
24283           Opnds.append(Op->op_begin(), Op->op_begin() + NumElts);
24284         } else {
24285           for (unsigned i = 0; i != NumElts; ++i)
24286             Opnds.push_back(
24287                 DAG.getNode(ISD::TRUNCATE, SDLoc(N), MinVT, Op.getOperand(i)));
24288         }
24289       }
24290     }
24291 
24292     assert(VT.getVectorNumElements() == Opnds.size() &&
24293            "Concat vector type mismatch");
24294     return DAG.getBuildVector(VT, SDLoc(N), Opnds);
24295   }
24296 
24297   // Fold CONCAT_VECTORS of only bitcast scalars (or undef) to BUILD_VECTOR.
24298   // FIXME: Add support for concat_vectors(bitcast(vec0),bitcast(vec1),...).
24299   if (SDValue V = combineConcatVectorOfScalars(N, DAG))
24300     return V;
24301 
24302   if (Level < AfterLegalizeVectorOps && TLI.isTypeLegal(VT)) {
24303     // Fold CONCAT_VECTORS of CONCAT_VECTORS (or undef) to VECTOR_SHUFFLE.
24304     if (SDValue V = combineConcatVectorOfConcatVectors(N, DAG))
24305       return V;
24306 
24307     // Fold CONCAT_VECTORS of EXTRACT_SUBVECTOR (or undef) to VECTOR_SHUFFLE.
24308     if (SDValue V = combineConcatVectorOfExtracts(N, DAG))
24309       return V;
24310   }
24311 
24312   if (SDValue V = combineConcatVectorOfCasts(N, DAG))
24313     return V;
24314 
24315   if (SDValue V = combineConcatVectorOfShuffleAndItsOperands(
24316           N, DAG, TLI, LegalTypes, LegalOperations))
24317     return V;
24318 
24319   // Type legalization of vectors and DAG canonicalization of SHUFFLE_VECTOR
24320   // nodes often generate nop CONCAT_VECTOR nodes. Scan the CONCAT_VECTOR
24321   // operands and look for CONCAT operations that place the incoming vectors
24322   // at the exact same location.
24323   //
24324   // For scalable vectors, EXTRACT_SUBVECTOR indexes are implicitly scaled.
24325   SDValue SingleSource = SDValue();
24326   unsigned PartNumElem =
24327       N->getOperand(0).getValueType().getVectorMinNumElements();
24328 
24329   for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i) {
24330     SDValue Op = N->getOperand(i);
24331 
24332     if (Op.isUndef())
24333       continue;
24334 
24335     // Check if this is the identity extract:
24336     if (Op.getOpcode() != ISD::EXTRACT_SUBVECTOR)
24337       return SDValue();
24338 
24339     // Find the single incoming vector for the extract_subvector.
24340     if (SingleSource.getNode()) {
24341       if (Op.getOperand(0) != SingleSource)
24342         return SDValue();
24343     } else {
24344       SingleSource = Op.getOperand(0);
24345 
24346       // Check the source type is the same as the type of the result.
24347       // If not, this concat may extend the vector, so we cannot
24348       // optimize it away.
24349       if (SingleSource.getValueType() != N->getValueType(0))
24350         return SDValue();
24351     }
24352 
24353     // Check that we are reading from the identity index.
24354     unsigned IdentityIndex = i * PartNumElem;
24355     if (Op.getConstantOperandAPInt(1) != IdentityIndex)
24356       return SDValue();
24357   }
24358 
24359   if (SingleSource.getNode())
24360     return SingleSource;
24361 
24362   return SDValue();
24363 }
24364 
24365 // Helper that peeks through INSERT_SUBVECTOR/CONCAT_VECTORS to find
24366 // if the subvector can be sourced for free.
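// An illustrative sketch: with SubVT = v4i32 and Index = 4,
//   getSubVectorSrc((v12i32 concat_vectors A, B, C), 4, v4i32) returns B.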
24367 static SDValue getSubVectorSrc(SDValue V, SDValue Index, EVT SubVT) {
24368   if (V.getOpcode() == ISD::INSERT_SUBVECTOR &&
24369       V.getOperand(1).getValueType() == SubVT && V.getOperand(2) == Index) {
24370     return V.getOperand(1);
24371   }
24372   auto *IndexC = dyn_cast<ConstantSDNode>(Index);
24373   if (IndexC && V.getOpcode() == ISD::CONCAT_VECTORS &&
24374       V.getOperand(0).getValueType() == SubVT &&
24375       (IndexC->getZExtValue() % SubVT.getVectorMinNumElements()) == 0) {
24376     uint64_t SubIdx = IndexC->getZExtValue() / SubVT.getVectorMinNumElements();
24377     return V.getOperand(SubIdx);
24378   }
24379   return SDValue();
24380 }
24381 
24382 static SDValue narrowInsertExtractVectorBinOp(SDNode *Extract,
24383                                               SelectionDAG &DAG,
24384                                               bool LegalOperations) {
24385   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
24386   SDValue BinOp = Extract->getOperand(0);
24387   unsigned BinOpcode = BinOp.getOpcode();
24388   if (!TLI.isBinOp(BinOpcode) || BinOp->getNumValues() != 1)
24389     return SDValue();
24390 
24391   EVT VecVT = BinOp.getValueType();
24392   SDValue Bop0 = BinOp.getOperand(0), Bop1 = BinOp.getOperand(1);
24393   if (VecVT != Bop0.getValueType() || VecVT != Bop1.getValueType())
24394     return SDValue();
24395 
24396   SDValue Index = Extract->getOperand(1);
24397   EVT SubVT = Extract->getValueType(0);
24398   if (!TLI.isOperationLegalOrCustom(BinOpcode, SubVT, LegalOperations))
24399     return SDValue();
24400 
24401   SDValue Sub0 = getSubVectorSrc(Bop0, Index, SubVT);
24402   SDValue Sub1 = getSubVectorSrc(Bop1, Index, SubVT);
24403 
24404   // TODO: We could handle the case where only 1 operand is being inserted by
24405   //       creating an extract of the other operand, but that requires checking
24406   //       number of uses and/or costs.
24407   if (!Sub0 || !Sub1)
24408     return SDValue();
24409 
24410   // We are inserting both operands of the wide binop only to extract back
24411   // to the narrow vector size. Eliminate all of the insert/extract:
24412   // ext (binop (ins ?, X, Index), (ins ?, Y, Index)), Index --> binop X, Y
24413   return DAG.getNode(BinOpcode, SDLoc(Extract), SubVT, Sub0, Sub1,
24414                      BinOp->getFlags());
24415 }
24416 
24417 /// If we are extracting a subvector produced by a wide binary operator try
24418 /// to use a narrow binary operator and/or avoid concatenation and extraction.
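/// An illustrative sketch, assuming the narrow op is legal and subvector
/// extraction is cheap on the target:
///   (v2i64 extract_subvector (v4i64 and X, Y), 2)
///    --> (v2i64 and (extract_subvector X, 2), (extract_subvector Y, 2))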
24419 static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG,
24420                                           bool LegalOperations) {
24421   // TODO: Refactor with the caller (visitEXTRACT_SUBVECTOR), so we can share
24422   // some of these bailouts with other transforms.
24423 
24424   if (SDValue V = narrowInsertExtractVectorBinOp(Extract, DAG, LegalOperations))
24425     return V;
24426 
24427   // The extract index must be a constant, so we can map it to a concat operand.
24428   auto *ExtractIndexC = dyn_cast<ConstantSDNode>(Extract->getOperand(1));
24429   if (!ExtractIndexC)
24430     return SDValue();
24431 
24432   // We are looking for an optionally bitcasted wide vector binary operator
24433   // feeding an extract subvector.
24434   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
24435   SDValue BinOp = peekThroughBitcasts(Extract->getOperand(0));
24436   unsigned BOpcode = BinOp.getOpcode();
24437   if (!TLI.isBinOp(BOpcode) || BinOp->getNumValues() != 1)
24438     return SDValue();
24439 
24440   // Exclude the fake form of fneg (fsub -0.0, x) because that is likely to be
24441   // reduced to the unary fneg when it is visited, and we probably want to deal
24442   // with fneg in a target-specific way.
24443   if (BOpcode == ISD::FSUB) {
24444     auto *C = isConstOrConstSplatFP(BinOp.getOperand(0), /*AllowUndefs*/ true);
24445     if (C && C->getValueAPF().isNegZero())
24446       return SDValue();
24447   }
24448 
24449   // The binop must be a vector type, so we can extract some fraction of it.
24450   EVT WideBVT = BinOp.getValueType();
24451   // The optimisations below currently assume we are dealing with fixed length
24452   // vectors. It is possible to add support for scalable vectors, but at the
24453   // moment we've done no analysis to prove whether they are profitable or not.
24454   if (!WideBVT.isFixedLengthVector())
24455     return SDValue();
24456 
24457   EVT VT = Extract->getValueType(0);
24458   unsigned ExtractIndex = ExtractIndexC->getZExtValue();
24459   assert(ExtractIndex % VT.getVectorNumElements() == 0 &&
24460          "Extract index is not a multiple of the vector length.");
24461 
24462   // Bail out if this is not a proper multiple width extraction.
24463   unsigned WideWidth = WideBVT.getSizeInBits();
24464   unsigned NarrowWidth = VT.getSizeInBits();
24465   if (WideWidth % NarrowWidth != 0)
24466     return SDValue();
24467 
24468   // Bail out if we are extracting a fraction of a single operation. This can
24469   // occur because we potentially looked through a bitcast of the binop.
24470   unsigned NarrowingRatio = WideWidth / NarrowWidth;
24471   unsigned WideNumElts = WideBVT.getVectorNumElements();
24472   if (WideNumElts % NarrowingRatio != 0)
24473     return SDValue();
24474 
24475   // Bail out if the target does not support a narrower version of the binop.
24476   EVT NarrowBVT = EVT::getVectorVT(*DAG.getContext(), WideBVT.getScalarType(),
24477                                    WideNumElts / NarrowingRatio);
24478   if (!TLI.isOperationLegalOrCustomOrPromote(BOpcode, NarrowBVT,
24479                                              LegalOperations))
24480     return SDValue();
24481 
24482   // If extraction is cheap, we don't need to look at the binop operands
24483   // for concat ops. The narrow binop alone makes this transform profitable.
24484   // We can't just reuse the original extract index operand because we may have
24485   // bitcasted.
24486   unsigned ConcatOpNum = ExtractIndex / VT.getVectorNumElements();
24487   unsigned ExtBOIdx = ConcatOpNum * NarrowBVT.getVectorNumElements();
24488   if (TLI.isExtractSubvectorCheap(NarrowBVT, WideBVT, ExtBOIdx) &&
24489       BinOp.hasOneUse() && Extract->getOperand(0)->hasOneUse()) {
24490     // extract (binop B0, B1), N --> binop (extract B0, N), (extract B1, N)
24491     SDLoc DL(Extract);
24492     SDValue NewExtIndex = DAG.getVectorIdxConstant(ExtBOIdx, DL);
24493     SDValue X = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
24494                             BinOp.getOperand(0), NewExtIndex);
24495     SDValue Y = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
24496                             BinOp.getOperand(1), NewExtIndex);
24497     SDValue NarrowBinOp =
24498         DAG.getNode(BOpcode, DL, NarrowBVT, X, Y, BinOp->getFlags());
24499     return DAG.getBitcast(VT, NarrowBinOp);
24500   }
24501 
24502   // Only handle the case where we are doubling and then halving. A larger ratio
24503   // may require more than two narrow binops to replace the wide binop.
24504   if (NarrowingRatio != 2)
24505     return SDValue();
24506 
24507   // TODO: The motivating case for this transform is an x86 AVX1 target. That
24508   // target has temptingly almost legal versions of bitwise logic ops in 256-bit
24509   // flavors, but no other 256-bit integer support. This could be extended to
24510   // handle any binop, but that may require fixing/adding other folds to avoid
24511   // codegen regressions.
24512   if (BOpcode != ISD::AND && BOpcode != ISD::OR && BOpcode != ISD::XOR)
24513     return SDValue();
24514 
24515   // We need at least one concatenation operation of a binop operand to make
24516   // this transform worthwhile. The concat must double the input vector sizes.
24517   auto GetSubVector = [ConcatOpNum](SDValue V) -> SDValue {
24518     if (V.getOpcode() == ISD::CONCAT_VECTORS && V.getNumOperands() == 2)
24519       return V.getOperand(ConcatOpNum);
24520     return SDValue();
24521   };
24522   SDValue SubVecL = GetSubVector(peekThroughBitcasts(BinOp.getOperand(0)));
24523   SDValue SubVecR = GetSubVector(peekThroughBitcasts(BinOp.getOperand(1)));
24524 
24525   if (SubVecL || SubVecR) {
24526     // If a binop operand was not the result of a concat, we must extract a
24527     // half-sized operand for our new narrow binop:
24528     // extract (binop (concat X1, X2), (concat Y1, Y2)), N --> binop XN, YN
24529     // extract (binop (concat X1, X2), Y), N --> binop XN, (extract Y, IndexC)
24530     // extract (binop X, (concat Y1, Y2)), N --> binop (extract X, IndexC), YN
24531     SDLoc DL(Extract);
24532     SDValue IndexC = DAG.getVectorIdxConstant(ExtBOIdx, DL);
24533     SDValue X = SubVecL ? DAG.getBitcast(NarrowBVT, SubVecL)
24534                         : DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
24535                                       BinOp.getOperand(0), IndexC);
24536 
24537     SDValue Y = SubVecR ? DAG.getBitcast(NarrowBVT, SubVecR)
24538                         : DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
24539                                       BinOp.getOperand(1), IndexC);
24540 
24541     SDValue NarrowBinOp = DAG.getNode(BOpcode, DL, NarrowBVT, X, Y);
24542     return DAG.getBitcast(VT, NarrowBinOp);
24543   }
24544 
24545   return SDValue();
24546 }
24547 
24548 /// If we are extracting a subvector from a wide vector load, convert to a
24549 /// narrow load to eliminate the extraction:
24550 /// (extract_subvector (load wide vector)) --> (load narrow vector)
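/// An illustrative sketch (little-endian, byte-sized types):
///   (v2i32 extract_subvector (v8i32 load <p>), 2) --> (v2i32 load <p + 8>)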
24551 static SDValue narrowExtractedVectorLoad(SDNode *Extract, SelectionDAG &DAG) {
24552   // TODO: Add support for big-endian. The offset calculation must be adjusted.
24553   if (DAG.getDataLayout().isBigEndian())
24554     return SDValue();
24555 
24556   auto *Ld = dyn_cast<LoadSDNode>(Extract->getOperand(0));
24557   if (!Ld || Ld->getExtensionType() || !Ld->isSimple())
24558     return SDValue();
24559 
24560   // Allow targets to opt-out.
24561   EVT VT = Extract->getValueType(0);
24562 
24563   // We can only create byte sized loads.
24564   if (!VT.isByteSized())
24565     return SDValue();
24566 
24567   unsigned Index = Extract->getConstantOperandVal(1);
24568   unsigned NumElts = VT.getVectorMinNumElements();
24569   // A fixed length vector being extracted from a scalable vector
24570   // may not be any *smaller* than the scalable one.
24571   if (Index == 0 && NumElts >= Ld->getValueType(0).getVectorMinNumElements())
24572     return SDValue();
24573 
24574   // The definition of EXTRACT_SUBVECTOR states that the index must be a
24575   // multiple of the minimum number of elements in the result type.
24576   assert(Index % NumElts == 0 && "The extract subvector index is not a "
24577                                  "multiple of the result's element count");
24578 
24579   // It's fine to use TypeSize here as we know the offset will not be negative.
24580   TypeSize Offset = VT.getStoreSize() * (Index / NumElts);
24581 
24582   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
24583   if (!TLI.shouldReduceLoadWidth(Ld, Ld->getExtensionType(), VT))
24584     return SDValue();
24585 
24586   // The narrow load will be offset from the base address of the old load if
24587   // we are extracting from something besides index 0 (little-endian).
24588   SDLoc DL(Extract);
24589 
24590   // TODO: Use "BaseIndexOffset" to make this more effective.
24591   SDValue NewAddr = DAG.getMemBasePlusOffset(Ld->getBasePtr(), Offset, DL);
24592 
24593   LocationSize StoreSize = LocationSize::precise(VT.getStoreSize());
24594   MachineFunction &MF = DAG.getMachineFunction();
24595   MachineMemOperand *MMO;
24596   if (Offset.isScalable()) {
24597     MachinePointerInfo MPI =
24598         MachinePointerInfo(Ld->getPointerInfo().getAddrSpace());
24599     MMO = MF.getMachineMemOperand(Ld->getMemOperand(), MPI, StoreSize);
24600   } else
24601     MMO = MF.getMachineMemOperand(Ld->getMemOperand(), Offset.getFixedValue(),
24602                                   StoreSize);
24603 
24604   SDValue NewLd = DAG.getLoad(VT, DL, Ld->getChain(), NewAddr, MMO);
24605   DAG.makeEquivalentMemoryOrdering(Ld, NewLd);
24606   return NewLd;
24607 }
24608 
24609 /// Given  EXTRACT_SUBVECTOR(VECTOR_SHUFFLE(Op0, Op1, Mask)),
24610 /// try to produce  VECTOR_SHUFFLE(EXTRACT_SUBVECTOR(Op?, ?),
24611 ///                                EXTRACT_SUBVECTOR(Op?, ?),
24612 ///                                Mask')
24613 /// iff it is legal and profitable to do so. Notably, the trimmed mask
24614 /// (containing only the elements that are extracted)
24615 /// must reference at most two subvectors.
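/// An illustrative sketch, with v8i32 A and B:
///   (v4i32 extract_subvector (v8i32 shuffle A, B, <0,8,1,9,2,10,3,11>), 4)
///    --> (v4i32 shuffle (v4i32 extract_subvector A, 0),
///                       (v4i32 extract_subvector B, 0), <2,6,3,7>)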
24616 static SDValue foldExtractSubvectorFromShuffleVector(SDNode *N,
24617                                                      SelectionDAG &DAG,
24618                                                      const TargetLowering &TLI,
24619                                                      bool LegalOperations) {
24620   assert(N->getOpcode() == ISD::EXTRACT_SUBVECTOR &&
24621          "Must only be called on EXTRACT_SUBVECTOR's");
24622 
24623   SDValue N0 = N->getOperand(0);
24624 
24625   // Only deal with non-scalable vectors.
24626   EVT NarrowVT = N->getValueType(0);
24627   EVT WideVT = N0.getValueType();
24628   if (!NarrowVT.isFixedLengthVector() || !WideVT.isFixedLengthVector())
24629     return SDValue();
24630 
24631   // The operand must be a shufflevector.
24632   auto *WideShuffleVector = dyn_cast<ShuffleVectorSDNode>(N0);
24633   if (!WideShuffleVector)
24634     return SDValue();
24635 
24636   // The old shuffle needs to go away.
24637   if (!WideShuffleVector->hasOneUse())
24638     return SDValue();
24639 
24640   // And the narrow shufflevector that we'll form must be legal.
24641   if (LegalOperations &&
24642       !TLI.isOperationLegalOrCustom(ISD::VECTOR_SHUFFLE, NarrowVT))
24643     return SDValue();
24644 
24645   uint64_t FirstExtractedEltIdx = N->getConstantOperandVal(1);
24646   int NumEltsExtracted = NarrowVT.getVectorNumElements();
24647   assert((FirstExtractedEltIdx % NumEltsExtracted) == 0 &&
24648          "Extract index is not a multiple of the output vector length.");
24649 
24650   int WideNumElts = WideVT.getVectorNumElements();
24651 
24652   SmallVector<int, 16> NewMask;
24653   NewMask.reserve(NumEltsExtracted);
24654   SmallSetVector<std::pair<SDValue /*Op*/, int /*SubvectorIndex*/>, 2>
24655       DemandedSubvectors;
24656 
24657   // Try to decode the wide mask into narrow mask from at most two subvectors.
24658   for (int M : WideShuffleVector->getMask().slice(FirstExtractedEltIdx,
24659                                                   NumEltsExtracted)) {
24660     assert((M >= -1) && (M < (2 * WideNumElts)) &&
24661            "Out-of-bounds shuffle mask?");
24662 
24663     if (M < 0) {
24664       // Does not depend on operands, does not require adjustment.
24665       NewMask.emplace_back(M);
24666       continue;
24667     }
24668 
24669     // From which operand of the shuffle does this shuffle mask element pick?
24670     int WideShufOpIdx = M / WideNumElts;
24671     // Which element of that operand is picked?
24672     int OpEltIdx = M % WideNumElts;
24673 
24674     assert((OpEltIdx + WideShufOpIdx * WideNumElts) == M &&
24675            "Shuffle mask vector decomposition failure.");
24676 
24677     // And which NumEltsExtracted-sized subvector of that operand is that?
24678     int OpSubvecIdx = OpEltIdx / NumEltsExtracted;
24679     // And which element within that subvector of that operand is that?
24680     int OpEltIdxInSubvec = OpEltIdx % NumEltsExtracted;
24681 
24682     assert((OpEltIdxInSubvec + OpSubvecIdx * NumEltsExtracted) == OpEltIdx &&
24683            "Shuffle mask subvector decomposition failure.");
24684 
24685     assert((OpEltIdxInSubvec + OpSubvecIdx * NumEltsExtracted +
24686             WideShufOpIdx * WideNumElts) == M &&
24687            "Shuffle mask full decomposition failure.");
24688 
24689     SDValue Op = WideShuffleVector->getOperand(WideShufOpIdx);
24690 
24691     if (Op.isUndef()) {
24692       // Picking from an undef operand. Let's adjust the mask instead.
24693       NewMask.emplace_back(-1);
24694       continue;
24695     }
24696 
24697     const std::pair<SDValue, int> DemandedSubvector =
24698         std::make_pair(Op, OpSubvecIdx);
24699 
24700     if (DemandedSubvectors.insert(DemandedSubvector)) {
24701       if (DemandedSubvectors.size() > 2)
24702         return SDValue(); // We can't handle more than two subvectors.
24703       // How many elements into the WideVT does this subvector start?
24704       int Index = NumEltsExtracted * OpSubvecIdx;
24705       // Bail out if the extraction isn't going to be cheap.
24706       if (!TLI.isExtractSubvectorCheap(NarrowVT, WideVT, Index))
24707         return SDValue();
24708     }
24709 
24710     // Ok, but from which operand of the new shuffle will this element pick?
24711     int NewOpIdx =
24712         getFirstIndexOf(DemandedSubvectors.getArrayRef(), DemandedSubvector);
24713     assert((NewOpIdx == 0 || NewOpIdx == 1) && "Unexpected operand index.");
24714 
24715     int AdjM = OpEltIdxInSubvec + NewOpIdx * NumEltsExtracted;
24716     NewMask.emplace_back(AdjM);
24717   }
24718   assert(NewMask.size() == (unsigned)NumEltsExtracted && "Produced bad mask.");
24719   assert(DemandedSubvectors.size() <= 2 &&
24720          "Should have ended up demanding at most two subvectors.");
24721 
24722   // Did we discover that the shuffle does not actually depend on operands?
24723   if (DemandedSubvectors.empty())
24724     return DAG.getUNDEF(NarrowVT);
24725 
24726   // Profitability check: only deal with extractions from the first subvector
24727   // unless the mask becomes an identity mask.
24728   if (!ShuffleVectorInst::isIdentityMask(NewMask, NewMask.size()) ||
24729       any_of(NewMask, [](int M) { return M < 0; }))
24730     for (auto &DemandedSubvector : DemandedSubvectors)
24731       if (DemandedSubvector.second != 0)
24732         return SDValue();
24733 
24734   // We still perform the exact same EXTRACT_SUBVECTOR, just on different
24735   // operand[s]/index[es], so there is no point in checking for its legality.
24736 
24737   // Do not turn a legal shuffle into an illegal one.
24738   if (TLI.isShuffleMaskLegal(WideShuffleVector->getMask(), WideVT) &&
24739       !TLI.isShuffleMaskLegal(NewMask, NarrowVT))
24740     return SDValue();
24741 
24742   SDLoc DL(N);
24743 
24744   SmallVector<SDValue, 2> NewOps;
24745   for (const std::pair<SDValue /*Op*/, int /*SubvectorIndex*/>
24746            &DemandedSubvector : DemandedSubvectors) {
24747     // How many elements into the WideVT does this subvector start?
24748     int Index = NumEltsExtracted * DemandedSubvector.second;
24749     SDValue IndexC = DAG.getVectorIdxConstant(Index, DL);
24750     NewOps.emplace_back(DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowVT,
24751                                     DemandedSubvector.first, IndexC));
24752   }
24753   assert((NewOps.size() == 1 || NewOps.size() == 2) &&
24754          "Should end up with either one or two ops");
24755 
24756   // If we ended up with only one operand, pad with an undef.
24757   if (NewOps.size() == 1)
24758     NewOps.emplace_back(DAG.getUNDEF(NarrowVT));
24759 
24760   return DAG.getVectorShuffle(NarrowVT, DL, NewOps[0], NewOps[1], NewMask);
24761 }
24762 
24763 SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode *N) {
24764   EVT NVT = N->getValueType(0);
24765   SDValue V = N->getOperand(0);
24766   uint64_t ExtIdx = N->getConstantOperandVal(1);
24767   SDLoc DL(N);
24768 
24769   // Extract from UNDEF is UNDEF.
24770   if (V.isUndef())
24771     return DAG.getUNDEF(NVT);
24772 
24773   if (TLI.isOperationLegalOrCustomOrPromote(ISD::LOAD, NVT))
24774     if (SDValue NarrowLoad = narrowExtractedVectorLoad(N, DAG))
24775       return NarrowLoad;
24776 
24777   // Combine an extract of an extract into a single extract_subvector.
24778   // ext (ext X, C), 0 --> ext X, C
24779   if (ExtIdx == 0 && V.getOpcode() == ISD::EXTRACT_SUBVECTOR && V.hasOneUse()) {
24780     if (TLI.isExtractSubvectorCheap(NVT, V.getOperand(0).getValueType(),
24781                                     V.getConstantOperandVal(1)) &&
24782         TLI.isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, NVT)) {
24783       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NVT, V.getOperand(0),
24784                          V.getOperand(1));
24785     }
24786   }
24787 
24788   // ty1 extract_subvector(ty2 splat(V)) -> ty1 splat(V)
24789   if (V.getOpcode() == ISD::SPLAT_VECTOR)
24790     if (DAG.isConstantValueOfAnyType(V.getOperand(0)) || V.hasOneUse())
24791       if (!LegalOperations || TLI.isOperationLegal(ISD::SPLAT_VECTOR, NVT))
24792         return DAG.getSplatVector(NVT, DL, V.getOperand(0));
24793 
24794   // extract_subvector(insert_subvector(x,y,c1),c2)
24795   //  --> extract_subvector(y,c2-c1)
24796   // iff we're just extracting from the inserted subvector.
24797   if (V.getOpcode() == ISD::INSERT_SUBVECTOR) {
24798     SDValue InsSub = V.getOperand(1);
24799     EVT InsSubVT = InsSub.getValueType();
24800     unsigned NumInsElts = InsSubVT.getVectorMinNumElements();
24801     unsigned InsIdx = V.getConstantOperandVal(2);
24802     unsigned NumSubElts = NVT.getVectorMinNumElements();
24803     if (InsIdx <= ExtIdx && (ExtIdx + NumSubElts) <= (InsIdx + NumInsElts) &&
24804         TLI.isExtractSubvectorCheap(NVT, InsSubVT, ExtIdx - InsIdx) &&
24805         InsSubVT.isFixedLengthVector() && NVT.isFixedLengthVector() &&
24806         V.getValueType().isFixedLengthVector())
24807       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NVT, InsSub,
24808                          DAG.getVectorIdxConstant(ExtIdx - InsIdx, DL));
24809   }
24810 
24811   // Try to move vector bitcast after extract_subv by scaling extraction index:
24812   // extract_subv (bitcast X), Index --> bitcast (extract_subv X, Index')
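  // An illustrative sketch, with v8i32 X:
  //   (v2i64 extract_subv (v4i64 bitcast X), 2)
  //    --> (v2i64 bitcast (v4i32 extract_subv X, 4))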
24813   if (V.getOpcode() == ISD::BITCAST &&
24814       V.getOperand(0).getValueType().isVector() &&
24815       (!LegalOperations || TLI.isOperationLegal(ISD::BITCAST, NVT))) {
24816     SDValue SrcOp = V.getOperand(0);
24817     EVT SrcVT = SrcOp.getValueType();
24818     unsigned SrcNumElts = SrcVT.getVectorMinNumElements();
24819     unsigned DestNumElts = V.getValueType().getVectorMinNumElements();
24820     if ((SrcNumElts % DestNumElts) == 0) {
24821       unsigned SrcDestRatio = SrcNumElts / DestNumElts;
24822       ElementCount NewExtEC = NVT.getVectorElementCount() * SrcDestRatio;
24823       EVT NewExtVT =
24824           EVT::getVectorVT(*DAG.getContext(), SrcVT.getScalarType(), NewExtEC);
24825       if (TLI.isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, NewExtVT)) {
24826         SDValue NewIndex = DAG.getVectorIdxConstant(ExtIdx * SrcDestRatio, DL);
24827         SDValue NewExtract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NewExtVT,
24828                                          V.getOperand(0), NewIndex);
24829         return DAG.getBitcast(NVT, NewExtract);
24830       }
24831     }
24832     if ((DestNumElts % SrcNumElts) == 0) {
24833       unsigned DestSrcRatio = DestNumElts / SrcNumElts;
24834       if (NVT.getVectorElementCount().isKnownMultipleOf(DestSrcRatio)) {
24835         ElementCount NewExtEC =
24836             NVT.getVectorElementCount().divideCoefficientBy(DestSrcRatio);
24837         EVT ScalarVT = SrcVT.getScalarType();
24838         if ((ExtIdx % DestSrcRatio) == 0) {
24839           unsigned IndexValScaled = ExtIdx / DestSrcRatio;
24840           EVT NewExtVT =
24841               EVT::getVectorVT(*DAG.getContext(), ScalarVT, NewExtEC);
24842           if (TLI.isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, NewExtVT)) {
24843             SDValue NewIndex = DAG.getVectorIdxConstant(IndexValScaled, DL);
24844             SDValue NewExtract =
24845                 DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NewExtVT,
24846                             V.getOperand(0), NewIndex);
24847             return DAG.getBitcast(NVT, NewExtract);
24848           }
24849           if (NewExtEC.isScalar() &&
24850               TLI.isOperationLegalOrCustom(ISD::EXTRACT_VECTOR_ELT, ScalarVT)) {
24851             SDValue NewIndex = DAG.getVectorIdxConstant(IndexValScaled, DL);
24852             SDValue NewExtract =
24853                 DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT,
24854                             V.getOperand(0), NewIndex);
24855             return DAG.getBitcast(NVT, NewExtract);
24856           }
24857         }
24858       }
24859     }
24860   }
24861 
24862   if (V.getOpcode() == ISD::CONCAT_VECTORS) {
24863     unsigned ExtNumElts = NVT.getVectorMinNumElements();
24864     EVT ConcatSrcVT = V.getOperand(0).getValueType();
24865     assert(ConcatSrcVT.getVectorElementType() == NVT.getVectorElementType() &&
24866            "Concat and extract subvector do not change element type");
24867     assert((ExtIdx % ExtNumElts) == 0 &&
24868            "Extract index is not a multiple of the input vector length.");
24869 
24870     unsigned ConcatSrcNumElts = ConcatSrcVT.getVectorMinNumElements();
24871     unsigned ConcatOpIdx = ExtIdx / ConcatSrcNumElts;
24872 
24873     // If the concatenated source types match this extract, it's a direct
24874     // simplification:
24875     // extract_subvec (concat V1, V2, ...), i --> Vi
24876     if (NVT.getVectorElementCount() == ConcatSrcVT.getVectorElementCount())
24877       return V.getOperand(ConcatOpIdx);
24878 
24879     // If the concatenated source vectors are a multiple length of this extract,
24880     // then extract a fraction of one of those source vectors directly from a
24881     // concat operand. Example:
24882     //   v2i8 extract_subvec (v16i8 concat (v8i8 X), (v8i8 Y)), 14 -->
24883     //   v2i8 extract_subvec (v8i8 Y), 6
24884     if (NVT.isFixedLengthVector() && ConcatSrcVT.isFixedLengthVector() &&
24885         ConcatSrcNumElts % ExtNumElts == 0) {
24886       unsigned NewExtIdx = ExtIdx - ConcatOpIdx * ConcatSrcNumElts;
24887       assert(NewExtIdx + ExtNumElts <= ConcatSrcNumElts &&
24888              "Trying to extract from >1 concat operand?");
24889       assert(NewExtIdx % ExtNumElts == 0 &&
24890              "Extract index is not a multiple of the input vector length.");
24891       SDValue NewIndexC = DAG.getVectorIdxConstant(NewExtIdx, DL);
24892       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NVT,
24893                          V.getOperand(ConcatOpIdx), NewIndexC);
24894     }
24895   }
24896 
24897   if (SDValue V =
24898           foldExtractSubvectorFromShuffleVector(N, DAG, TLI, LegalOperations))
24899     return V;
24900 
24901   V = peekThroughBitcasts(V);
24902 
24903   // If the input is a build vector, try to make a smaller build vector.
24904   if (V.getOpcode() == ISD::BUILD_VECTOR) {
24905     EVT InVT = V.getValueType();
24906     unsigned ExtractSize = NVT.getSizeInBits();
24907     unsigned EltSize = InVT.getScalarSizeInBits();
24908     // Only do this if we won't split any elements.
24909     if (ExtractSize % EltSize == 0) {
24910       unsigned NumElems = ExtractSize / EltSize;
24911       EVT EltVT = InVT.getVectorElementType();
24912       EVT ExtractVT =
24913           NumElems == 1 ? EltVT
24914                         : EVT::getVectorVT(*DAG.getContext(), EltVT, NumElems);
24915       if ((Level < AfterLegalizeDAG ||
24916            (NumElems == 1 ||
24917             TLI.isOperationLegal(ISD::BUILD_VECTOR, ExtractVT))) &&
24918           (!LegalTypes || TLI.isTypeLegal(ExtractVT))) {
24919         unsigned IdxVal = (ExtIdx * NVT.getScalarSizeInBits()) / EltSize;
24920 
24921         if (NumElems == 1) {
24922           SDValue Src = V->getOperand(IdxVal);
24923           if (EltVT != Src.getValueType())
24924             Src = DAG.getNode(ISD::TRUNCATE, DL, EltVT, Src);
24925           return DAG.getBitcast(NVT, Src);
24926         }
24927 
24928         // Extract the pieces from the original build_vector.
24929         SDValue BuildVec =
24930             DAG.getBuildVector(ExtractVT, DL, V->ops().slice(IdxVal, NumElems));
24931         return DAG.getBitcast(NVT, BuildVec);
24932       }
24933     }
24934   }
24935 
24936   if (V.getOpcode() == ISD::INSERT_SUBVECTOR) {
24937     // Handle only the simple case where the vector being inserted and the
24938     // vector being extracted are the same size.
24939     EVT SmallVT = V.getOperand(1).getValueType();
24940     if (!NVT.bitsEq(SmallVT))
24941       return SDValue();
24942 
24943     // Combine:
24944     //    (extract_subvec (insert_subvec V1, V2, InsIdx), ExtIdx)
24945     // Into:
24946     //    indices are equal or bit offsets are equal => V2
24947     //    otherwise => (extract_subvec V1, ExtIdx)
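    // e.g. with v2i32 V2 inserted into v8i32 V1 at index 4 (equal bit offsets):
    //   (extract_subvec (insert_subvec V1, V2, 4), 4) --> V2
    //   (extract_subvec (insert_subvec V1, V2, 4), 0) --> (extract_subvec V1, 0)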
24948     uint64_t InsIdx = V.getConstantOperandVal(2);
24949     if (InsIdx * SmallVT.getScalarSizeInBits() ==
24950         ExtIdx * NVT.getScalarSizeInBits()) {
24951       if (LegalOperations && !TLI.isOperationLegal(ISD::BITCAST, NVT))
24952         return SDValue();
24953 
24954       return DAG.getBitcast(NVT, V.getOperand(1));
24955     }
24956     return DAG.getNode(
24957         ISD::EXTRACT_SUBVECTOR, DL, NVT,
24958         DAG.getBitcast(N->getOperand(0).getValueType(), V.getOperand(0)),
24959         N->getOperand(1));
24960   }
24961 
24962   if (SDValue NarrowBOp = narrowExtractedVectorBinOp(N, DAG, LegalOperations))
24963     return NarrowBOp;
24964 
24965   if (SimplifyDemandedVectorElts(SDValue(N, 0)))
24966     return SDValue(N, 0);
24967 
24968   return SDValue();
24969 }
24970 
24971 /// Try to convert a wide shuffle of concatenated vectors into 2 narrow shuffles
24972 /// followed by concatenation. Narrow vector ops may have better performance
24973 /// than wide ops, and this can unlock further narrowing of other vector ops.
24974 /// Targets can invert this transform later if it is not profitable.
24975 static SDValue foldShuffleOfConcatUndefs(ShuffleVectorSDNode *Shuf,
24976                                          SelectionDAG &DAG) {
24977   SDValue N0 = Shuf->getOperand(0), N1 = Shuf->getOperand(1);
24978   if (N0.getOpcode() != ISD::CONCAT_VECTORS || N0.getNumOperands() != 2 ||
24979       N1.getOpcode() != ISD::CONCAT_VECTORS || N1.getNumOperands() != 2 ||
24980       !N0.getOperand(1).isUndef() || !N1.getOperand(1).isUndef())
24981     return SDValue();
24982 
24983   // Split the wide shuffle mask into halves. Any mask element that is accessing
24984   // operand 1 is offset down to account for narrowing of the vectors.
24985   ArrayRef<int> Mask = Shuf->getMask();
24986   EVT VT = Shuf->getValueType(0);
24987   unsigned NumElts = VT.getVectorNumElements();
24988   unsigned HalfNumElts = NumElts / 2;
24989   SmallVector<int, 16> Mask0(HalfNumElts, -1);
24990   SmallVector<int, 16> Mask1(HalfNumElts, -1);
24991   for (unsigned i = 0; i != NumElts; ++i) {
24992     if (Mask[i] == -1)
24993       continue;
24994     // If we reference the upper (undef) subvector then the element is undef.
24995     if ((Mask[i] % NumElts) >= HalfNumElts)
24996       continue;
24997     int M = Mask[i] < (int)NumElts ? Mask[i] : Mask[i] - (int)HalfNumElts;
24998     if (i < HalfNumElts)
24999       Mask0[i] = M;
25000     else
25001       Mask1[i - HalfNumElts] = M;
25002   }
25003 
25004   // Ask the target if this is a valid transform.
25005   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
25006   EVT HalfVT = EVT::getVectorVT(*DAG.getContext(), VT.getScalarType(),
25007                                 HalfNumElts);
25008   if (!TLI.isShuffleMaskLegal(Mask0, HalfVT) ||
25009       !TLI.isShuffleMaskLegal(Mask1, HalfVT))
25010     return SDValue();
25011 
25012   // shuffle (concat X, undef), (concat Y, undef), Mask -->
25013   // concat (shuffle X, Y, Mask0), (shuffle X, Y, Mask1)
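  // e.g. with v4i32 X and Y (v8i32 shuffle):
  //   Mask <0,1,8,9,2,3,10,11> --> Mask0 <0,1,4,5>, Mask1 <2,3,6,7>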
25014   SDValue X = N0.getOperand(0), Y = N1.getOperand(0);
25015   SDLoc DL(Shuf);
25016   SDValue Shuf0 = DAG.getVectorShuffle(HalfVT, DL, X, Y, Mask0);
25017   SDValue Shuf1 = DAG.getVectorShuffle(HalfVT, DL, X, Y, Mask1);
25018   return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Shuf0, Shuf1);
25019 }
25020 
25021 // Tries to turn a shuffle of two CONCAT_VECTORS into a single concat, or to
25022 // turn a shuffle of a single concat into a simpler shuffle, then a concat.
25023 static SDValue partitionShuffleOfConcats(SDNode *N, SelectionDAG &DAG) {
25024   EVT VT = N->getValueType(0);
25025   unsigned NumElts = VT.getVectorNumElements();
25026 
25027   SDValue N0 = N->getOperand(0);
25028   SDValue N1 = N->getOperand(1);
25029   ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N);
25030   ArrayRef<int> Mask = SVN->getMask();
25031 
25032   SmallVector<SDValue, 4> Ops;
25033   EVT ConcatVT = N0.getOperand(0).getValueType();
25034   unsigned NumElemsPerConcat = ConcatVT.getVectorNumElements();
25035   unsigned NumConcats = NumElts / NumElemsPerConcat;
25036 
25037   auto IsUndefMaskElt = [](int i) { return i == -1; };
25038 
25039   // Special case: shuffle(concat(A,B)) can be more efficiently represented
25040   // as concat(shuffle(A,B),UNDEF) if the shuffle doesn't set any of the high
25041   // half vector elements.
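  // e.g. v8i32 shuffle (concat A,B), undef, <5,1,4,0,-1,-1,-1,-1>
  //      --> concat (v4i32 shuffle A, B, <5,1,4,0>), undef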
25042   if (NumElemsPerConcat * 2 == NumElts && N1.isUndef() &&
25043       llvm::all_of(Mask.slice(NumElemsPerConcat, NumElemsPerConcat),
25044                    IsUndefMaskElt)) {
25045     N0 = DAG.getVectorShuffle(ConcatVT, SDLoc(N), N0.getOperand(0),
25046                               N0.getOperand(1),
25047                               Mask.slice(0, NumElemsPerConcat));
25048     N1 = DAG.getUNDEF(ConcatVT);
25049     return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, N0, N1);
25050   }
25051 
25052   // Look at every vector that's inserted. We're looking for exact
25053   // subvector-sized copies from a concatenated vector.
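  // e.g. with v2i32 operands: v8i32 shuffle (concat A,B,C,D), (concat E,F,G,H),
  //      <2,3,12,13,-1,-1,6,7> --> concat B, G, undef, D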
25054   for (unsigned I = 0; I != NumConcats; ++I) {
25055     unsigned Begin = I * NumElemsPerConcat;
25056     ArrayRef<int> SubMask = Mask.slice(Begin, NumElemsPerConcat);
25057 
25058     // Make sure we're dealing with a copy.
25059     if (llvm::all_of(SubMask, IsUndefMaskElt)) {
25060       Ops.push_back(DAG.getUNDEF(ConcatVT));
25061       continue;
25062     }
25063 
25064     int OpIdx = -1;
25065     for (int i = 0; i != (int)NumElemsPerConcat; ++i) {
25066       if (IsUndefMaskElt(SubMask[i]))
25067         continue;
25068       if ((SubMask[i] % (int)NumElemsPerConcat) != i)
25069         return SDValue();
25070       int EltOpIdx = SubMask[i] / NumElemsPerConcat;
25071       if (0 <= OpIdx && EltOpIdx != OpIdx)
25072         return SDValue();
25073       OpIdx = EltOpIdx;
25074     }
25075     assert(0 <= OpIdx && "Unknown concat_vectors op");
25076 
25077     if (OpIdx < (int)N0.getNumOperands())
25078       Ops.push_back(N0.getOperand(OpIdx));
25079     else
25080       Ops.push_back(N1.getOperand(OpIdx - N0.getNumOperands()));
25081   }
25082 
25083   return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops);
25084 }
25085 
25086 // Attempt to combine a shuffle of 2 inputs of 'scalar sources' -
25087 // BUILD_VECTOR or SCALAR_TO_VECTOR into a single BUILD_VECTOR.
25088 //
25089 // SHUFFLE(BUILD_VECTOR(), BUILD_VECTOR()) -> BUILD_VECTOR() is always
25090 // a simplification in some sense, but it isn't appropriate in general: some
25091 // BUILD_VECTORs are substantially cheaper than others. The general case
25092 // of a BUILD_VECTOR requires inserting each element individually (or
25093 // performing the equivalent in a temporary stack variable). A BUILD_VECTOR of
25094 // all constants is a single constant pool load.  A BUILD_VECTOR where each
25095 // element is identical is a splat.  A BUILD_VECTOR where most of the operands
25096 // are undef lowers to a small number of element insertions.
25097 //
25098 // To deal with this, we currently use a bunch of mostly arbitrary heuristics.
25099 // We don't fold shuffles where one side is a non-zero constant, and we don't
25100 // fold shuffles if the resulting (non-splat) BUILD_VECTOR would have duplicate
25101 // non-constant operands. This seems to work out reasonably well in practice.
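// e.g. (assuming one-use, non-constant build_vector operands):
// shuffle (build_vector A,B,C,D), (build_vector E,F,G,H), <0,4,2,6>
//   --> build_vector A,E,C,G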
25102 static SDValue combineShuffleOfScalars(ShuffleVectorSDNode *SVN,
25103                                        SelectionDAG &DAG,
25104                                        const TargetLowering &TLI) {
25105   EVT VT = SVN->getValueType(0);
25106   unsigned NumElts = VT.getVectorNumElements();
25107   SDValue N0 = SVN->getOperand(0);
25108   SDValue N1 = SVN->getOperand(1);
25109 
25110   if (!N0->hasOneUse())
25111     return SDValue();
25112 
25113   // If only one of N0,N1 is constant, bail out if it is not ALL_ZEROS as
25114   // discussed above.
25115   if (!N1.isUndef()) {
25116     if (!N1->hasOneUse())
25117       return SDValue();
25118 
25119     bool N0AnyConst = isAnyConstantBuildVector(N0);
25120     bool N1AnyConst = isAnyConstantBuildVector(N1);
25121     if (N0AnyConst && !N1AnyConst && !ISD::isBuildVectorAllZeros(N0.getNode()))
25122       return SDValue();
25123     if (!N0AnyConst && N1AnyConst && !ISD::isBuildVectorAllZeros(N1.getNode()))
25124       return SDValue();
25125   }
25126 
25127   // If both inputs are splats of the same value then we can safely merge this
25128   // to a single BUILD_VECTOR with undef elements based on the shuffle mask.
25129   bool IsSplat = false;
25130   auto *BV0 = dyn_cast<BuildVectorSDNode>(N0);
25131   auto *BV1 = dyn_cast<BuildVectorSDNode>(N1);
25132   if (BV0 && BV1)
25133     if (SDValue Splat0 = BV0->getSplatValue())
25134       IsSplat = (Splat0 == BV1->getSplatValue());
25135 
25136   SmallVector<SDValue, 8> Ops;
25137   SmallSet<SDValue, 16> DuplicateOps;
25138   for (int M : SVN->getMask()) {
25139     SDValue Op = DAG.getUNDEF(VT.getScalarType());
25140     if (M >= 0) {
25141       int Idx = M < (int)NumElts ? M : M - NumElts;
25142       SDValue &S = (M < (int)NumElts ? N0 : N1);
25143       if (S.getOpcode() == ISD::BUILD_VECTOR) {
25144         Op = S.getOperand(Idx);
25145       } else if (S.getOpcode() == ISD::SCALAR_TO_VECTOR) {
25146         SDValue Op0 = S.getOperand(0);
25147         Op = Idx == 0 ? Op0 : DAG.getUNDEF(Op0.getValueType());
25148       } else {
25149         // Operand can't be combined - bail out.
25150         return SDValue();
25151       }
25152     }
25153 
25154     // Don't duplicate a non-constant BUILD_VECTOR operand unless we're
25155     // generating a splat; semantically, this is fine, but it's likely to
25156     // generate low-quality code if the target can't reconstruct an appropriate
25157     // shuffle.
25158     if (!Op.isUndef() && !isIntOrFPConstant(Op))
25159       if (!IsSplat && !DuplicateOps.insert(Op).second)
25160         return SDValue();
25161 
25162     Ops.push_back(Op);
25163   }
25164 
25165   // BUILD_VECTOR requires all inputs to be of the same type; find the
25166   // maximum type and extend them all.
25167   EVT SVT = VT.getScalarType();
25168   if (SVT.isInteger())
25169     for (SDValue &Op : Ops)
25170       SVT = (SVT.bitsLT(Op.getValueType()) ? Op.getValueType() : SVT);
25171   if (SVT != VT.getScalarType())
25172     for (SDValue &Op : Ops)
25173       Op = Op.isUndef() ? DAG.getUNDEF(SVT)
25174                         : (TLI.isZExtFree(Op.getValueType(), SVT)
25175                                ? DAG.getZExtOrTrunc(Op, SDLoc(SVN), SVT)
25176                                : DAG.getSExtOrTrunc(Op, SDLoc(SVN), SVT));
25177   return DAG.getBuildVector(VT, SDLoc(SVN), Ops);
25178 }
25179 
25180 // Match shuffles that can be converted to *_vector_extend_in_reg.
25181 // This is often generated during legalization.
25182 // e.g. v4i32 <0,u,1,u> -> (v2i64 any_vector_extend_in_reg(v4i32 src)),
25183 // and returns the EVT to which the extension should be performed.
25184 // NOTE: this assumes that the src is the first operand of the shuffle.
25185 static std::optional<EVT> canCombineShuffleToExtendVectorInreg(
25186     unsigned Opcode, EVT VT, std::function<bool(unsigned)> Match,
25187     SelectionDAG &DAG, const TargetLowering &TLI, bool LegalTypes,
25188     bool LegalOperations) {
25189   bool IsBigEndian = DAG.getDataLayout().isBigEndian();
25190 
25191   // TODO Add support for big-endian when we have a test case.
25192   if (!VT.isInteger() || IsBigEndian)
25193     return std::nullopt;
25194 
25195   unsigned NumElts = VT.getVectorNumElements();
25196   unsigned EltSizeInBits = VT.getScalarSizeInBits();
25197 
25198   // Attempt to match a '*_extend_vector_inreg' shuffle; we just search for
25199   // power-of-2 extensions as they are the most likely.
25200   // FIXME: should try the Scale == NumElts case too.
25201   for (unsigned Scale = 2; Scale < NumElts; Scale *= 2) {
25202     // The vector width must be a multiple of Scale.
25203     if (NumElts % Scale != 0)
25204       continue;
25205 
25206     EVT OutSVT = EVT::getIntegerVT(*DAG.getContext(), EltSizeInBits * Scale);
25207     EVT OutVT = EVT::getVectorVT(*DAG.getContext(), OutSVT, NumElts / Scale);
25208 
25209     if ((LegalTypes && !TLI.isTypeLegal(OutVT)) ||
25210         (LegalOperations && !TLI.isOperationLegalOrCustom(Opcode, OutVT)))
25211       continue;
25212 
25213     if (Match(Scale))
25214       return OutVT;
25215   }
25216 
25217   return std::nullopt;
25218 }
25219 
25220 // Match shuffles that can be converted to any_vector_extend_in_reg.
25221 // This is often generated during legalization.
25222 // e.g. v4i32 <0,u,1,u> -> (v2i64 any_vector_extend_in_reg(v4i32 src))
25223 static SDValue combineShuffleToAnyExtendVectorInreg(ShuffleVectorSDNode *SVN,
25224                                                     SelectionDAG &DAG,
25225                                                     const TargetLowering &TLI,
25226                                                     bool LegalOperations) {
25227   EVT VT = SVN->getValueType(0);
25228   bool IsBigEndian = DAG.getDataLayout().isBigEndian();
25229 
25230   // TODO Add support for big-endian when we have a test case.
25231   if (!VT.isInteger() || IsBigEndian)
25232     return SDValue();
25233 
25234   // shuffle<0,-1,1,-1> == (v2i64 anyextend_vector_inreg(v4i32))
25235   auto isAnyExtend = [NumElts = VT.getVectorNumElements(),
25236                       Mask = SVN->getMask()](unsigned Scale) {
25237     for (unsigned i = 0; i != NumElts; ++i) {
25238       if (Mask[i] < 0)
25239         continue;
25240       if ((i % Scale) == 0 && Mask[i] == (int)(i / Scale))
25241         continue;
25242       return false;
25243     }
25244     return true;
25245   };
25246 
25247   unsigned Opcode = ISD::ANY_EXTEND_VECTOR_INREG;
25248   SDValue N0 = SVN->getOperand(0);
25249   // Never create an illegal type. Only create unsupported operations if we
25250   // are pre-legalization.
25251   std::optional<EVT> OutVT = canCombineShuffleToExtendVectorInreg(
25252       Opcode, VT, isAnyExtend, DAG, TLI, /*LegalTypes=*/true, LegalOperations);
25253   if (!OutVT)
25254     return SDValue();
25255   return DAG.getBitcast(VT, DAG.getNode(Opcode, SDLoc(SVN), *OutVT, N0));
25256 }
25257 
25258 // Match shuffles that can be converted to zero_extend_vector_inreg.
25259 // This is often generated during legalization.
25260 // e.g. v4i32 <0,z,1,u> -> (v2i64 zero_extend_vector_inreg(v4i32 src))
25261 static SDValue combineShuffleToZeroExtendVectorInReg(ShuffleVectorSDNode *SVN,
25262                                                      SelectionDAG &DAG,
25263                                                      const TargetLowering &TLI,
25264                                                      bool LegalOperations) {
25265   bool LegalTypes = true;
25266   EVT VT = SVN->getValueType(0);
25267   assert(!VT.isScalableVector() && "Encountered scalable shuffle?");
25268   unsigned NumElts = VT.getVectorNumElements();
25269   unsigned EltSizeInBits = VT.getScalarSizeInBits();
25270 
25271   // TODO: add support for big-endian when we have a test case.
25272   bool IsBigEndian = DAG.getDataLayout().isBigEndian();
25273   if (!VT.isInteger() || IsBigEndian)
25274     return SDValue();
25275 
25276   SmallVector<int, 16> Mask(SVN->getMask().begin(), SVN->getMask().end());
25277   auto ForEachDecomposedIndice = [NumElts, &Mask](auto Fn) {
25278     for (int &Indice : Mask) {
25279       if (Indice < 0)
25280         continue;
25281       int OpIdx = (unsigned)Indice < NumElts ? 0 : 1;
25282       int OpEltIdx = (unsigned)Indice < NumElts ? Indice : Indice - NumElts;
25283       Fn(Indice, OpIdx, OpEltIdx);
25284     }
25285   };
25286 
25287   // Which elements of which operand does this shuffle demand?
25288   std::array<APInt, 2> OpsDemandedElts;
25289   for (APInt &OpDemandedElts : OpsDemandedElts)
25290     OpDemandedElts = APInt::getZero(NumElts);
25291   ForEachDecomposedIndice(
25292       [&OpsDemandedElts](int &Indice, int OpIdx, int OpEltIdx) {
25293         OpsDemandedElts[OpIdx].setBit(OpEltIdx);
25294       });
25295 
25296   // Element-wise(!), which of these demanded elements are known to be zero?
25297   std::array<APInt, 2> OpsKnownZeroElts;
25298   for (auto I : zip(SVN->ops(), OpsDemandedElts, OpsKnownZeroElts))
25299     std::get<2>(I) =
25300         DAG.computeVectorKnownZeroElements(std::get<0>(I), std::get<1>(I));
25301 
25302   // Manifest zeroable element knowledge in the shuffle mask.
25303   // NOTE: we don't have a 'zeroable' sentinel value in the generic DAG;
25304   //       this is a local invention, but it won't leak into the DAG.
25305   // FIXME: should we not manifest them, but just check when matching?
25306   bool HadZeroableElts = false;
25307   ForEachDecomposedIndice([&OpsKnownZeroElts, &HadZeroableElts](
25308                               int &Indice, int OpIdx, int OpEltIdx) {
25309     if (OpsKnownZeroElts[OpIdx][OpEltIdx]) {
25310       Indice = -2; // Zeroable element.
25311       HadZeroableElts = true;
25312     }
25313   });
25314 
25315   // Don't proceed unless we've refined at least one zeroable mask index.
25316   // If we didn't, then we are still trying to match the same shuffle mask
25317   // we previously tried to match as ISD::ANY_EXTEND_VECTOR_INREG,
25318   // and evidently failed. Proceeding will lead to endless combine loops.
25319   if (!HadZeroableElts)
25320     return SDValue();
25321 
25322   // The shuffle may be more fine-grained than we want. Widen elements first.
25323   // FIXME: should we do this before manifesting zeroable shuffle mask indices?
25324   SmallVector<int, 16> ScaledMask;
25325   getShuffleMaskWithWidestElts(Mask, ScaledMask);
25326   assert(Mask.size() >= ScaledMask.size() &&
25327          Mask.size() % ScaledMask.size() == 0 && "Unexpected mask widening.");
25328   int Prescale = Mask.size() / ScaledMask.size();
25329 
25330   NumElts = ScaledMask.size();
25331   EltSizeInBits *= Prescale;
25332 
25333   EVT PrescaledVT = EVT::getVectorVT(
25334       *DAG.getContext(), EVT::getIntegerVT(*DAG.getContext(), EltSizeInBits),
25335       NumElts);
25336 
25337   if (LegalTypes && !TLI.isTypeLegal(PrescaledVT) && TLI.isTypeLegal(VT))
25338     return SDValue();
25339 
25340   // For example,
25341   // shuffle<0,z,1,-1> == (v2i64 zero_extend_vector_inreg(v4i32))
25342   // But not shuffle<z,z,1,-1> and not shuffle<0,z,z,-1> ! (for same types)
25343   auto isZeroExtend = [NumElts, &ScaledMask](unsigned Scale) {
25344     assert(Scale >= 2 && Scale <= NumElts && NumElts % Scale == 0 &&
25345            "Unexpected mask scaling factor.");
25346     ArrayRef<int> Mask = ScaledMask;
25347     for (unsigned SrcElt = 0, NumSrcElts = NumElts / Scale;
25348          SrcElt != NumSrcElts; ++SrcElt) {
25349       // Analyze the shuffle mask in Scale-sized chunks.
25350       ArrayRef<int> MaskChunk = Mask.take_front(Scale);
25351       assert(MaskChunk.size() == Scale && "Unexpected mask size.");
25352       Mask = Mask.drop_front(MaskChunk.size());
25353       // The first index in this chunk must be SrcElt, but not zero!
25354       // FIXME: undef should be fine, but that results in a more-defined result.
25355       if (int FirstIndice = MaskChunk[0]; (unsigned)FirstIndice != SrcElt)
25356         return false;
25357       // The rest of the indices in this chunk must be zeros.
25358       // FIXME: undef should be fine, but that results in a more-defined result.
25359       if (!all_of(MaskChunk.drop_front(1),
25360                   [](int Indice) { return Indice == -2; }))
25361         return false;
25362     }
25363     assert(Mask.empty() && "Did not process the whole mask?");
25364     return true;
25365   };
25366 
25367   unsigned Opcode = ISD::ZERO_EXTEND_VECTOR_INREG;
25368   for (bool Commuted : {false, true}) {
25369     SDValue Op = SVN->getOperand(!Commuted ? 0 : 1);
25370     if (Commuted)
25371       ShuffleVectorSDNode::commuteMask(ScaledMask);
25372     std::optional<EVT> OutVT = canCombineShuffleToExtendVectorInreg(
25373         Opcode, PrescaledVT, isZeroExtend, DAG, TLI, LegalTypes,
25374         LegalOperations);
25375     if (OutVT)
25376       return DAG.getBitcast(VT, DAG.getNode(Opcode, SDLoc(SVN), *OutVT,
25377                                             DAG.getBitcast(PrescaledVT, Op)));
25378   }
25379   return SDValue();
25380 }
25381 
25382 // Detect 'truncate_vector_inreg' style shuffles that pack the lower parts of
25383 // each source element of a large type into the lowest elements of a smaller
25384 // destination type. This is often generated during legalization.
25385 // If the source node itself was a '*_extend_vector_inreg' node then we should
25386 // then be able to remove it.
25387 static SDValue combineTruncationShuffle(ShuffleVectorSDNode *SVN,
25388                                         SelectionDAG &DAG) {
25389   EVT VT = SVN->getValueType(0);
25390   bool IsBigEndian = DAG.getDataLayout().isBigEndian();
25391 
25392   // TODO Add support for big-endian when we have a test case.
25393   if (!VT.isInteger() || IsBigEndian)
25394     return SDValue();
25395 
25396   SDValue N0 = peekThroughBitcasts(SVN->getOperand(0));
25397 
25398   unsigned Opcode = N0.getOpcode();
25399   if (!ISD::isExtVecInRegOpcode(Opcode))
25400     return SDValue();
25401 
25402   SDValue N00 = N0.getOperand(0);
25403   ArrayRef<int> Mask = SVN->getMask();
25404   unsigned NumElts = VT.getVectorNumElements();
25405   unsigned EltSizeInBits = VT.getScalarSizeInBits();
25406   unsigned ExtSrcSizeInBits = N00.getScalarValueSizeInBits();
25407   unsigned ExtDstSizeInBits = N0.getScalarValueSizeInBits();
25408 
25409   if (ExtDstSizeInBits % ExtSrcSizeInBits != 0)
25410     return SDValue();
25411   unsigned ExtScale = ExtDstSizeInBits / ExtSrcSizeInBits;
25412 
25413   // (v4i32 truncate_vector_inreg(v2i64)) == shuffle<0,2,-1,-1>
25414   // (v8i16 truncate_vector_inreg(v4i32)) == shuffle<0,2,4,6,-1,-1,-1,-1>
25415   // (v8i16 truncate_vector_inreg(v2i64)) == shuffle<0,4,-1,-1,-1,-1,-1,-1>
25416   auto isTruncate = [&Mask, &NumElts](unsigned Scale) {
25417     for (unsigned i = 0; i != NumElts; ++i) {
25418       if (Mask[i] < 0)
25419         continue;
25420       if ((i * Scale) < NumElts && Mask[i] == (int)(i * Scale))
25421         continue;
25422       return false;
25423     }
25424     return true;
25425   };
25426 
25427   // At the moment we just handle the case where we've truncated back to the
25428   // same size as before the extension.
25429   // TODO: handle more extension/truncation cases as cases arise.
25430   if (EltSizeInBits != ExtSrcSizeInBits)
25431     return SDValue();
25432 
25433   // We can remove *extend_vector_inreg only if the truncation happens at
25434   // the same scale as the extension.
25435   if (isTruncate(ExtScale))
25436     return DAG.getBitcast(VT, N00);
25437 
25438   return SDValue();
25439 }
25440 
25441 // Combine shuffles of splat-shuffles of the form:
25442 // shuffle (shuffle V, undef, splat-mask), undef, M
25443 // If splat-mask contains undef elements, we need to be careful about
25444 // introducing undefs in the folded mask that are not the result of composing
25445 // the masks of the shuffles.
25446 static SDValue combineShuffleOfSplatVal(ShuffleVectorSDNode *Shuf,
25447                                         SelectionDAG &DAG) {
25448   EVT VT = Shuf->getValueType(0);
25449   unsigned NumElts = VT.getVectorNumElements();
25450 
25451   if (!Shuf->getOperand(1).isUndef())
25452     return SDValue();
25453 
25454   // See if this unary non-splat shuffle actually *is* a splat shuffle in
25455   // disguise, with all demanded elements being identical.
25456   // FIXME: this can be done per-operand.
25457   if (!Shuf->isSplat()) {
25458     APInt DemandedElts(NumElts, 0);
25459     for (int Idx : Shuf->getMask()) {
25460       if (Idx < 0)
25461         continue; // Ignore sentinel indices.
25462       assert((unsigned)Idx < NumElts && "Out-of-bounds shuffle indice?");
25463       DemandedElts.setBit(Idx);
25464     }
25465     assert(DemandedElts.popcount() > 1 && "Is a splat shuffle already?");
25466     APInt UndefElts;
25467     if (DAG.isSplatValue(Shuf->getOperand(0), DemandedElts, UndefElts)) {
25468       // Even if all demanded elements are splat, some of them could be undef.
25469       // Which lowest demanded element is *not* known-undef?
25470       std::optional<unsigned> MinNonUndefIdx;
25471       for (int Idx : Shuf->getMask()) {
25472         if (Idx < 0 || UndefElts[Idx])
25473           continue; // Ignore sentinel indices, and undef elements.
25474         MinNonUndefIdx = std::min<unsigned>(Idx, MinNonUndefIdx.value_or(~0U));
25475       }
25476       if (!MinNonUndefIdx)
25477         return DAG.getUNDEF(VT); // All undef - result is undef.
25478       assert(*MinNonUndefIdx < NumElts && "Expected valid element index.");
25479       SmallVector<int, 8> SplatMask(Shuf->getMask().begin(),
25480                                     Shuf->getMask().end());
25481       for (int &Idx : SplatMask) {
25482         if (Idx < 0)
25483           continue; // Passthrough sentinel indices.
25484         // Otherwise, just pick the lowest demanded non-undef element.
25485         // Or sentinel undef, if we know we'd pick a known-undef element.
25486         Idx = UndefElts[Idx] ? -1 : *MinNonUndefIdx;
25487       }
25488       assert(SplatMask != Shuf->getMask() && "Expected mask to change!");
25489       return DAG.getVectorShuffle(VT, SDLoc(Shuf), Shuf->getOperand(0),
25490                                   Shuf->getOperand(1), SplatMask);
25491     }
25492   }
25493 
25494   // If the inner operand is a known splat with no undefs, just return that directly.
25495   // TODO: Create DemandedElts mask from Shuf's mask.
25496   // TODO: Allow undef elements and merge with the shuffle code below.
25497   if (DAG.isSplatValue(Shuf->getOperand(0), /*AllowUndefs*/ false))
25498     return Shuf->getOperand(0);
25499 
25500   auto *Splat = dyn_cast<ShuffleVectorSDNode>(Shuf->getOperand(0));
25501   if (!Splat || !Splat->isSplat())
25502     return SDValue();
25503 
25504   ArrayRef<int> ShufMask = Shuf->getMask();
25505   ArrayRef<int> SplatMask = Splat->getMask();
25506   assert(ShufMask.size() == SplatMask.size() && "Mask length mismatch");
25507 
25508   // Prefer simplifying to the splat-shuffle, if possible. This is legal if
25509   // every undef mask element in the splat-shuffle has a corresponding undef
25510   // element in the user-shuffle's mask or if the composition of mask elements
25511   // would result in undef.
25512   // Examples for (shuffle (shuffle v, undef, SplatMask), undef, UserMask):
25513   // * UserMask=[0,2,u,u], SplatMask=[2,u,2,u] -> [2,2,u,u]
25514   //   In this case it is not legal to simplify to the splat-shuffle because we
25515   //   may be exposing the users of the shuffle to an undef element at index 1
25516   //   which was not there before the combine.
25517   // * UserMask=[0,u,2,u], SplatMask=[2,u,2,u] -> [2,u,2,u]
25518   //   In this case the composition of masks yields SplatMask, so it's ok to
25519   //   simplify to the splat-shuffle.
25520   // * UserMask=[3,u,2,u], SplatMask=[2,u,2,u] -> [u,u,2,u]
25521   //   In this case the composed mask includes all undef elements of SplatMask
25522   //   and in addition sets element zero to undef. It is safe to simplify to
25523   //   the splat-shuffle.
25524   auto CanSimplifyToExistingSplat = [](ArrayRef<int> UserMask,
25525                                        ArrayRef<int> SplatMask) {
25526     for (unsigned i = 0, e = UserMask.size(); i != e; ++i)
25527       if (UserMask[i] != -1 && SplatMask[i] == -1 &&
25528           SplatMask[UserMask[i]] != -1)
25529         return false;
25530     return true;
25531   };
25532   if (CanSimplifyToExistingSplat(ShufMask, SplatMask))
25533     return Shuf->getOperand(0);
25534 
25535   // Create a new shuffle with a mask that is composed of the two shuffles'
25536   // masks.
25537   SmallVector<int, 32> NewMask;
25538   for (int Idx : ShufMask)
25539     NewMask.push_back(Idx == -1 ? -1 : SplatMask[Idx]);
25540 
25541   return DAG.getVectorShuffle(Splat->getValueType(0), SDLoc(Splat),
25542                               Splat->getOperand(0), Splat->getOperand(1),
25543                               NewMask);
25544 }
25545 
25546 // Combine shuffles of bitcasts into a shuffle of the bitcast type, provided
25547 // the mask can be treated as the larger type.
25548 static SDValue combineShuffleOfBitcast(ShuffleVectorSDNode *SVN,
25549                                        SelectionDAG &DAG,
25550                                        const TargetLowering &TLI,
25551                                        bool LegalOperations) {
25552   SDValue Op0 = SVN->getOperand(0);
25553   SDValue Op1 = SVN->getOperand(1);
25554   EVT VT = SVN->getValueType(0);
25555   if (Op0.getOpcode() != ISD::BITCAST)
25556     return SDValue();
25557   EVT InVT = Op0.getOperand(0).getValueType();
25558   if (!InVT.isVector() ||
25559       (!Op1.isUndef() && (Op1.getOpcode() != ISD::BITCAST ||
25560                           Op1.getOperand(0).getValueType() != InVT)))
25561     return SDValue();
25562   if (isAnyConstantBuildVector(Op0.getOperand(0)) &&
25563       (Op1.isUndef() || isAnyConstantBuildVector(Op1.getOperand(0))))
25564     return SDValue();
25565 
25566   int VTLanes = VT.getVectorNumElements();
25567   int InLanes = InVT.getVectorNumElements();
25568   if (VTLanes <= InLanes || VTLanes % InLanes != 0 ||
25569       (LegalOperations &&
25570        !TLI.isOperationLegalOrCustom(ISD::VECTOR_SHUFFLE, InVT)))
25571     return SDValue();
25572   int Factor = VTLanes / InLanes;
25573 
25574   // Check that each group of lanes in the mask is either undef or makes a
25575   // valid mask for the wider lane type.
25576   ArrayRef<int> Mask = SVN->getMask();
25577   SmallVector<int> NewMask;
25578   if (!widenShuffleMaskElts(Factor, Mask, NewMask))
25579     return SDValue();
25580 
25581   if (!TLI.isShuffleMaskLegal(NewMask, InVT))
25582     return SDValue();
25583 
25584   // Create the new shuffle with the new mask and bitcast it back to the
25585   // original type.
25586   SDLoc DL(SVN);
25587   Op0 = Op0.getOperand(0);
25588   Op1 = Op1.isUndef() ? DAG.getUNDEF(InVT) : Op1.getOperand(0);
25589   SDValue NewShuf = DAG.getVectorShuffle(InVT, DL, Op0, Op1, NewMask);
25590   return DAG.getBitcast(VT, NewShuf);
25591 }
25592 
25593 /// Combine shuffle of shuffle of the form:
25594 /// shuf (shuf X, undef, InnerMask), undef, OuterMask --> splat X
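/// e.g. InnerMask = <u,3,u,3> and OuterMask = <1,1,3,3> both resolve to
/// element 3 of X, so the pair folds to shuf X, undef, <3,3,3,3>.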
25595 static SDValue formSplatFromShuffles(ShuffleVectorSDNode *OuterShuf,
25596                                      SelectionDAG &DAG) {
25597   if (!OuterShuf->getOperand(1).isUndef())
25598     return SDValue();
25599   auto *InnerShuf = dyn_cast<ShuffleVectorSDNode>(OuterShuf->getOperand(0));
25600   if (!InnerShuf || !InnerShuf->getOperand(1).isUndef())
25601     return SDValue();
25602 
25603   ArrayRef<int> OuterMask = OuterShuf->getMask();
25604   ArrayRef<int> InnerMask = InnerShuf->getMask();
25605   unsigned NumElts = OuterMask.size();
25606   assert(NumElts == InnerMask.size() && "Mask length mismatch");
25607   SmallVector<int, 32> CombinedMask(NumElts, -1);
25608   int SplatIndex = -1;
25609   for (unsigned i = 0; i != NumElts; ++i) {
25610     // Undef lanes remain undef.
25611     int OuterMaskElt = OuterMask[i];
25612     if (OuterMaskElt == -1)
25613       continue;
25614 
25615     // Peek through the shuffle masks to get the underlying source element.
25616     int InnerMaskElt = InnerMask[OuterMaskElt];
25617     if (InnerMaskElt == -1)
25618       continue;
25619 
25620     // Initialize the splatted element.
25621     if (SplatIndex == -1)
25622       SplatIndex = InnerMaskElt;
25623 
25624     // Non-matching index - this is not a splat.
25625     if (SplatIndex != InnerMaskElt)
25626       return SDValue();
25627 
25628     CombinedMask[i] = InnerMaskElt;
25629   }
25630   assert((all_of(CombinedMask, [](int M) { return M == -1; }) ||
25631           getSplatIndex(CombinedMask) != -1) &&
25632          "Expected a splat mask");
25633 
25634   // TODO: The transform may be a win even if the mask is not legal.
25635   EVT VT = OuterShuf->getValueType(0);
25636   assert(VT == InnerShuf->getValueType(0) && "Expected matching shuffle types");
25637   if (!DAG.getTargetLoweringInfo().isShuffleMaskLegal(CombinedMask, VT))
25638     return SDValue();
25639 
25640   return DAG.getVectorShuffle(VT, SDLoc(OuterShuf), InnerShuf->getOperand(0),
25641                               InnerShuf->getOperand(1), CombinedMask);
25642 }
25643 
25644 /// If the shuffle mask is taking exactly one element from the first vector
25645 /// operand and passing through all other elements from the second vector
25646 /// operand, return the index of the mask element that is choosing an element
25647 /// from the first operand. Otherwise, return -1.
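/// e.g. for 4 elements, Mask = <4,1,6,7> returns 1 (only element 1 is taken
/// from operand 0), while Mask = <4,5,0,1> returns -1 (two elements come from
/// operand 0).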
25648 static int getShuffleMaskIndexOfOneElementFromOp0IntoOp1(ArrayRef<int> Mask) {
25649   int MaskSize = Mask.size();
25650   int EltFromOp0 = -1;
25651   // TODO: This does not match if there are undef elements in the shuffle mask.
25652   // Should we ignore undefs in the shuffle mask instead? The trade-off is
25653   // removing an instruction (a shuffle), but losing the knowledge that some
25654   // vector lanes are not needed.
25655   for (int i = 0; i != MaskSize; ++i) {
25656     if (Mask[i] >= 0 && Mask[i] < MaskSize) {
25657       // We're looking for a shuffle of exactly one element from operand 0.
25658       if (EltFromOp0 != -1)
25659         return -1;
25660       EltFromOp0 = i;
25661     } else if (Mask[i] != i + MaskSize) {
25662       // Nothing from operand 1 can change lanes.
25663       return -1;
25664     }
25665   }
25666   return EltFromOp0;
25667 }
25668 
25669 /// If a shuffle inserts exactly one element from a source vector operand into
25670 /// another vector operand and we can access the specified element as a scalar,
25671 /// then we can eliminate the shuffle.
25672 static SDValue replaceShuffleOfInsert(ShuffleVectorSDNode *Shuf,
25673                                       SelectionDAG &DAG) {
25674   // First, check if we are taking one element of a vector and shuffling that
25675   // element into another vector.
25676   ArrayRef<int> Mask = Shuf->getMask();
25677   SmallVector<int, 16> CommutedMask(Mask);
25678   SDValue Op0 = Shuf->getOperand(0);
25679   SDValue Op1 = Shuf->getOperand(1);
25680   int ShufOp0Index = getShuffleMaskIndexOfOneElementFromOp0IntoOp1(Mask);
25681   if (ShufOp0Index == -1) {
25682     // Commute mask and check again.
25683     ShuffleVectorSDNode::commuteMask(CommutedMask);
25684     ShufOp0Index = getShuffleMaskIndexOfOneElementFromOp0IntoOp1(CommutedMask);
25685     if (ShufOp0Index == -1)
25686       return SDValue();
25687     // Commute operands to match the commuted shuffle mask.
25688     std::swap(Op0, Op1);
25689     Mask = CommutedMask;
25690   }
25691 
25692   // The shuffle inserts exactly one element from operand 0 into operand 1.
25693   // Now see if we can access that element as a scalar via a real insert element
25694   // instruction.
25695   // TODO: We can try harder to locate the element as a scalar. Examples: it
25696   // could be an operand of SCALAR_TO_VECTOR, BUILD_VECTOR, or a constant.
25697   assert(Mask[ShufOp0Index] >= 0 && Mask[ShufOp0Index] < (int)Mask.size() &&
25698          "Shuffle mask value must be from operand 0");
25699   if (Op0.getOpcode() != ISD::INSERT_VECTOR_ELT)
25700     return SDValue();
25701 
25702   auto *InsIndexC = dyn_cast<ConstantSDNode>(Op0.getOperand(2));
25703   if (!InsIndexC || InsIndexC->getSExtValue() != Mask[ShufOp0Index])
25704     return SDValue();
25705 
25706   // There's an existing insertelement with constant insertion index, so we
25707   // don't need to check the legality/profitability of a replacement operation
25708   // that differs at most in the constant value. The target should be able to
25709   // lower any of those in a similar way. If not, legalization will expand this
25710   // to a scalar-to-vector plus shuffle.
25711   //
25712   // Note that the shuffle may move the scalar from the position that the insert
25713   // element used. Therefore, our new insert element occurs at the shuffle's
25714   // mask index value, not the insert's index value.
25715   // shuffle (insertelt v1, x, C), v2, mask --> insertelt v2, x, C'
25716   SDValue NewInsIndex = DAG.getVectorIdxConstant(ShufOp0Index, SDLoc(Shuf));
25717   return DAG.getNode(ISD::INSERT_VECTOR_ELT, SDLoc(Shuf), Op0.getValueType(),
25718                      Op1, Op0.getOperand(1), NewInsIndex);
25719 }
25720 
25721 /// If we have a unary shuffle of a shuffle, see if it can be folded away
25722 /// completely. This has the potential to lose undef knowledge because the first
25723 /// shuffle may not have an undef mask element where the second one does. So
25724 /// only call this after doing simplifications based on demanded elements.
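/// e.g. with Mask0 = <0,0,2,2>, an outer mask of <1,0,3,2> chooses elements
/// identical to those the inner shuffle already produces, so the outer shuffle
/// folds to (shuf0 X, Y, Mask0) itself.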
25725 static SDValue simplifyShuffleOfShuffle(ShuffleVectorSDNode *Shuf) {
25726   // shuf (shuf0 X, Y, Mask0), undef, Mask
25727   auto *Shuf0 = dyn_cast<ShuffleVectorSDNode>(Shuf->getOperand(0));
25728   if (!Shuf0 || !Shuf->getOperand(1).isUndef())
25729     return SDValue();
25730 
25731   ArrayRef<int> Mask = Shuf->getMask();
25732   ArrayRef<int> Mask0 = Shuf0->getMask();
25733   for (int i = 0, e = (int)Mask.size(); i != e; ++i) {
25734     // Ignore undef elements.
25735     if (Mask[i] == -1)
25736       continue;
25737     assert(Mask[i] >= 0 && Mask[i] < e && "Unexpected shuffle mask value");
25738 
25739     // Is the element of the shuffle operand chosen by this shuffle the same as
25740     // the element chosen by the shuffle operand itself?
25741     if (Mask0[Mask[i]] != Mask0[i])
25742       return SDValue();
25743   }
25744   // Every element of this shuffle is identical to the result of the previous
25745   // shuffle, so we can replace this value.
25746   return Shuf->getOperand(0);
25747 }
25748 
25749 SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
25750   EVT VT = N->getValueType(0);
25751   unsigned NumElts = VT.getVectorNumElements();
25752 
25753   SDValue N0 = N->getOperand(0);
25754   SDValue N1 = N->getOperand(1);
25755 
25756   assert(N0.getValueType() == VT && "Vector shuffle must be normalized in DAG");
25757 
25758   // Canonicalize shuffle undef, undef -> undef
25759   if (N0.isUndef() && N1.isUndef())
25760     return DAG.getUNDEF(VT);
25761 
25762   ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N);
25763 
25764   // Canonicalize shuffle v, v -> v, undef
25765   if (N0 == N1)
25766     return DAG.getVectorShuffle(VT, SDLoc(N), N0, DAG.getUNDEF(VT),
25767                                 createUnaryMask(SVN->getMask(), NumElts));
25768 
25769   // Canonicalize shuffle undef, v -> v, undef.  Commute the shuffle mask.
25770   if (N0.isUndef())
25771     return DAG.getCommutedVectorShuffle(*SVN);
25772 
25773   // Remove references to rhs if it is undef
25774   if (N1.isUndef()) {
25775     bool Changed = false;
25776     SmallVector<int, 8> NewMask;
25777     for (unsigned i = 0; i != NumElts; ++i) {
25778       int Idx = SVN->getMaskElt(i);
25779       if (Idx >= (int)NumElts) {
25780         Idx = -1;
25781         Changed = true;
25782       }
25783       NewMask.push_back(Idx);
25784     }
25785     if (Changed)
25786       return DAG.getVectorShuffle(VT, SDLoc(N), N0, N1, NewMask);
25787   }
25788 
25789   if (SDValue InsElt = replaceShuffleOfInsert(SVN, DAG))
25790     return InsElt;
25791 
25792   // A shuffle of a single vector that is a splatted value can always be folded.
25793   if (SDValue V = combineShuffleOfSplatVal(SVN, DAG))
25794     return V;
25795 
25796   if (SDValue V = formSplatFromShuffles(SVN, DAG))
25797     return V;
25798 
25799   // If it is a splat, check if the argument vector is another splat or a
25800   // build_vector.
25801   if (SVN->isSplat() && SVN->getSplatIndex() < (int)NumElts) {
25802     int SplatIndex = SVN->getSplatIndex();
25803     if (N0.hasOneUse() && TLI.isExtractVecEltCheap(VT, SplatIndex) &&
25804         TLI.isBinOp(N0.getOpcode()) && N0->getNumValues() == 1) {
25805       // splat (vector_bo L, R), Index -->
25806       // splat (scalar_bo (extelt L, Index), (extelt R, Index))
25807       SDValue L = N0.getOperand(0), R = N0.getOperand(1);
25808       SDLoc DL(N);
25809       EVT EltVT = VT.getScalarType();
25810       SDValue Index = DAG.getVectorIdxConstant(SplatIndex, DL);
25811       SDValue ExtL = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, L, Index);
25812       SDValue ExtR = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, R, Index);
25813       SDValue NewBO =
25814           DAG.getNode(N0.getOpcode(), DL, EltVT, ExtL, ExtR, N0->getFlags());
25815       SDValue Insert = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VT, NewBO);
25816       SmallVector<int, 16> ZeroMask(VT.getVectorNumElements(), 0);
25817       return DAG.getVectorShuffle(VT, DL, Insert, DAG.getUNDEF(VT), ZeroMask);
25818     }
25819 
25820     // splat(scalar_to_vector(x), 0) -> build_vector(x,...,x)
25821     // splat(insert_vector_elt(v, x, c), c) -> build_vector(x,...,x)
25822     if ((!LegalOperations || TLI.isOperationLegal(ISD::BUILD_VECTOR, VT)) &&
25823         N0.hasOneUse()) {
25824       if (N0.getOpcode() == ISD::SCALAR_TO_VECTOR && SplatIndex == 0)
25825         return DAG.getSplatBuildVector(VT, SDLoc(N), N0.getOperand(0));
25826 
25827       if (N0.getOpcode() == ISD::INSERT_VECTOR_ELT)
25828         if (auto *Idx = dyn_cast<ConstantSDNode>(N0.getOperand(2)))
25829           if (Idx->getAPIntValue() == SplatIndex)
25830             return DAG.getSplatBuildVector(VT, SDLoc(N), N0.getOperand(1));
25831 
25832       // Look through a bitcast if LE and splatting lane 0, through to a
25833       // scalar_to_vector or a build_vector.
25834       if (N0.getOpcode() == ISD::BITCAST && N0.getOperand(0).hasOneUse() &&
25835           SplatIndex == 0 && DAG.getDataLayout().isLittleEndian() &&
25836           (N0.getOperand(0).getOpcode() == ISD::SCALAR_TO_VECTOR ||
25837            N0.getOperand(0).getOpcode() == ISD::BUILD_VECTOR)) {
25838         EVT N00VT = N0.getOperand(0).getValueType();
25839         if (VT.getScalarSizeInBits() <= N00VT.getScalarSizeInBits() &&
25840             VT.isInteger() && N00VT.isInteger()) {
25841           EVT InVT =
25842               TLI.getTypeToTransformTo(*DAG.getContext(), VT.getScalarType());
25843           SDValue Op = DAG.getZExtOrTrunc(N0.getOperand(0).getOperand(0),
25844                                           SDLoc(N), InVT);
25845           return DAG.getSplatBuildVector(VT, SDLoc(N), Op);
25846         }
25847       }
25848     }
25849 
25850     // If this is a bit convert that changes the element type of the vector but
25851     // not the number of vector elements, look through it.  Be careful not to
25852     // look through conversions that change things like v4f32 to v2f64.
25853     SDNode *V = N0.getNode();
25854     if (V->getOpcode() == ISD::BITCAST) {
25855       SDValue ConvInput = V->getOperand(0);
25856       if (ConvInput.getValueType().isVector() &&
25857           ConvInput.getValueType().getVectorNumElements() == NumElts)
25858         V = ConvInput.getNode();
25859     }
25860 
25861     if (V->getOpcode() == ISD::BUILD_VECTOR) {
25862       assert(V->getNumOperands() == NumElts &&
25863              "BUILD_VECTOR has wrong number of operands");
25864       SDValue Base;
25865       bool AllSame = true;
25866       for (unsigned i = 0; i != NumElts; ++i) {
25867         if (!V->getOperand(i).isUndef()) {
25868           Base = V->getOperand(i);
25869           break;
25870         }
25871       }
25872       // Splat of <u, u, u, u>, return <u, u, u, u>
25873       if (!Base.getNode())
25874         return N0;
25875       for (unsigned i = 0; i != NumElts; ++i) {
25876         if (V->getOperand(i) != Base) {
25877           AllSame = false;
25878           break;
25879         }
25880       }
25881       // Splat of <x, x, x, x>, return <x, x, x, x>
25882       if (AllSame)
25883         return N0;
25884 
25885       // Canonicalize any other splat as a build_vector.
25886       SDValue Splatted = V->getOperand(SplatIndex);
25887       SmallVector<SDValue, 8> Ops(NumElts, Splatted);
25888       SDValue NewBV = DAG.getBuildVector(V->getValueType(0), SDLoc(N), Ops);
25889 
25890       // We may have jumped through bitcasts, so the type of the
25891       // BUILD_VECTOR may not match the type of the shuffle.
25892       if (V->getValueType(0) != VT)
25893         NewBV = DAG.getBitcast(VT, NewBV);
25894       return NewBV;
25895     }
25896   }
25897 
25898   // Simplify source operands based on shuffle mask.
25899   if (SimplifyDemandedVectorElts(SDValue(N, 0)))
25900     return SDValue(N, 0);
25901 
25902   // This is intentionally placed after demanded elements simplification because
25903   // it could eliminate knowledge of undef elements created by this shuffle.
25904   if (SDValue ShufOp = simplifyShuffleOfShuffle(SVN))
25905     return ShufOp;
25906 
25907   // Match shuffles that can be converted to any_vector_extend_in_reg.
25908   if (SDValue V =
25909           combineShuffleToAnyExtendVectorInreg(SVN, DAG, TLI, LegalOperations))
25910     return V;
25911 
25912   // Combine "truncate_vector_in_reg" style shuffles.
25913   if (SDValue V = combineTruncationShuffle(SVN, DAG))
25914     return V;
25915 
25916   if (N0.getOpcode() == ISD::CONCAT_VECTORS &&
25917       Level < AfterLegalizeVectorOps &&
25918       (N1.isUndef() ||
25919       (N1.getOpcode() == ISD::CONCAT_VECTORS &&
25920        N0.getOperand(0).getValueType() == N1.getOperand(0).getValueType()))) {
25921     if (SDValue V = partitionShuffleOfConcats(N, DAG))
25922       return V;
25923   }
25924 
25925   // A shuffle of a concat of the same narrow vector can be reduced to use
25926   // only low-half elements of a concat with undef:
25927   // shuf (concat X, X), undef, Mask --> shuf (concat X, undef), undef, Mask'
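  // e.g. with v2i32 X: shuf (concat X, X), undef, <3,1,2,0>
  //      --> shuf (concat X, undef), undef, <1,1,0,0>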
25928   if (N0.getOpcode() == ISD::CONCAT_VECTORS && N1.isUndef() &&
25929       N0.getNumOperands() == 2 &&
25930       N0.getOperand(0) == N0.getOperand(1)) {
25931     int HalfNumElts = (int)NumElts / 2;
25932     SmallVector<int, 8> NewMask;
25933     for (unsigned i = 0; i != NumElts; ++i) {
25934       int Idx = SVN->getMaskElt(i);
25935       if (Idx >= HalfNumElts) {
25936         assert(Idx < (int)NumElts && "Shuffle mask chooses undef op");
25937         Idx -= HalfNumElts;
25938       }
25939       NewMask.push_back(Idx);
25940     }
25941     if (TLI.isShuffleMaskLegal(NewMask, VT)) {
25942       SDValue UndefVec = DAG.getUNDEF(N0.getOperand(0).getValueType());
25943       SDValue NewCat = DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT,
25944                                    N0.getOperand(0), UndefVec);
25945       return DAG.getVectorShuffle(VT, SDLoc(N), NewCat, N1, NewMask);
25946     }
25947   }
25948 
25949   // See if we can replace a shuffle with an insert_subvector.
25950   // e.g. v2i32 into v8i32:
25951   // shuffle(lhs,concat(rhs0,rhs1,rhs2,rhs3),0,1,2,3,10,11,6,7).
25952   // --> insert_subvector(lhs,rhs1,4).
25953   if (Level < AfterLegalizeVectorOps && TLI.isTypeLegal(VT) &&
25954       TLI.isOperationLegalOrCustom(ISD::INSERT_SUBVECTOR, VT)) {
25955     auto ShuffleToInsert = [&](SDValue LHS, SDValue RHS, ArrayRef<int> Mask) {
25956       // Ensure RHS subvectors are legal.
25957       assert(RHS.getOpcode() == ISD::CONCAT_VECTORS && "Can't find subvectors");
25958       EVT SubVT = RHS.getOperand(0).getValueType();
25959       int NumSubVecs = RHS.getNumOperands();
25960       int NumSubElts = SubVT.getVectorNumElements();
25961       assert((NumElts % NumSubElts) == 0 && "Subvector mismatch");
25962       if (!TLI.isTypeLegal(SubVT))
25963         return SDValue();
25964 
25965       // Don't bother if we have a unary shuffle (matches undef + LHS elts).
25966       if (all_of(Mask, [NumElts](int M) { return M < (int)NumElts; }))
25967         return SDValue();
25968 
25969       // Search [NumSubElts] spans for RHS sequence.
25970       // TODO: Can we avoid nested loops to increase performance?
25971       SmallVector<int> InsertionMask(NumElts);
25972       for (int SubVec = 0; SubVec != NumSubVecs; ++SubVec) {
25973         for (int SubIdx = 0; SubIdx != (int)NumElts; SubIdx += NumSubElts) {
25974           // Reset mask to identity.
25975           std::iota(InsertionMask.begin(), InsertionMask.end(), 0);
25976 
25977           // Add subvector insertion.
25978           std::iota(InsertionMask.begin() + SubIdx,
25979                     InsertionMask.begin() + SubIdx + NumSubElts,
25980                     NumElts + (SubVec * NumSubElts));
25981 
25982           // See if the shuffle mask matches the reference insertion mask.
25983           bool MatchingShuffle = true;
25984           for (int i = 0; i != (int)NumElts; ++i) {
25985             int ExpectIdx = InsertionMask[i];
25986             int ActualIdx = Mask[i];
25987             if (0 <= ActualIdx && ExpectIdx != ActualIdx) {
25988               MatchingShuffle = false;
25989               break;
25990             }
25991           }
25992 
25993           if (MatchingShuffle)
25994             return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT, LHS,
25995                                RHS.getOperand(SubVec),
25996                                DAG.getVectorIdxConstant(SubIdx, SDLoc(N)));
25997         }
25998       }
25999       return SDValue();
26000     };
26001     ArrayRef<int> Mask = SVN->getMask();
26002     if (N1.getOpcode() == ISD::CONCAT_VECTORS)
26003       if (SDValue InsertN1 = ShuffleToInsert(N0, N1, Mask))
26004         return InsertN1;
26005     if (N0.getOpcode() == ISD::CONCAT_VECTORS) {
26006       SmallVector<int> CommuteMask(Mask);
26007       ShuffleVectorSDNode::commuteMask(CommuteMask);
26008       if (SDValue InsertN0 = ShuffleToInsert(N1, N0, CommuteMask))
26009         return InsertN0;
26010     }
26011   }
26012 
26013   // If we're not performing a select/blend shuffle, see if we can convert the
26014   // shuffle into an AND node, with all the out-of-lane elements known zero.
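  // e.g. if element 3 of N1 is known zero:
  //   v4i32 shuffle X, N1, <0,7,2,7> --> and X, <-1,0,-1,0>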
26015   if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) {
26016     bool IsInLaneMask = true;
26017     ArrayRef<int> Mask = SVN->getMask();
26018     SmallVector<int, 16> ClearMask(NumElts, -1);
26019     APInt DemandedLHS = APInt::getZero(NumElts);
26020     APInt DemandedRHS = APInt::getZero(NumElts);
26021     for (int I = 0; I != (int)NumElts; ++I) {
26022       int M = Mask[I];
26023       if (M < 0)
26024         continue;
26025       ClearMask[I] = M == I ? I : (I + NumElts);
26026       IsInLaneMask &= (M == I) || (M == (int)(I + NumElts));
26027       if (M != I) {
26028         APInt &Demanded = M < (int)NumElts ? DemandedLHS : DemandedRHS;
26029         Demanded.setBit(M % NumElts);
26030       }
26031     }
26032     // TODO: Should we try to mask with N1 as well?
26033     if (!IsInLaneMask && (!DemandedLHS.isZero() || !DemandedRHS.isZero()) &&
26034         (DemandedLHS.isZero() || DAG.MaskedVectorIsZero(N0, DemandedLHS)) &&
26035         (DemandedRHS.isZero() || DAG.MaskedVectorIsZero(N1, DemandedRHS))) {
26036       SDLoc DL(N);
26037       EVT IntVT = VT.changeVectorElementTypeToInteger();
26038       EVT IntSVT = VT.getVectorElementType().changeTypeToInteger();
26039       // Transform the type to a legal type so that the buildvector constant
26040       // elements are not illegal. Make sure that the result is larger than the
26041       // original type, in case the value is split into two (e.g. i64->i32).
26042       if (!TLI.isTypeLegal(IntSVT) && LegalTypes)
26043         IntSVT = TLI.getTypeToTransformTo(*DAG.getContext(), IntSVT);
26044       if (IntSVT.getSizeInBits() >= IntVT.getScalarSizeInBits()) {
26045         SDValue ZeroElt = DAG.getConstant(0, DL, IntSVT);
26046         SDValue AllOnesElt = DAG.getAllOnesConstant(DL, IntSVT);
26047         SmallVector<SDValue, 16> AndMask(NumElts, DAG.getUNDEF(IntSVT));
26048         for (int I = 0; I != (int)NumElts; ++I)
26049           if (0 <= Mask[I])
26050             AndMask[I] = Mask[I] == I ? AllOnesElt : ZeroElt;
26051 
26052         // See if a clear mask is legal instead of going via
26053         // XformToShuffleWithZero which loses UNDEF mask elements.
26054         if (TLI.isVectorClearMaskLegal(ClearMask, IntVT))
26055           return DAG.getBitcast(
26056               VT, DAG.getVectorShuffle(IntVT, DL, DAG.getBitcast(IntVT, N0),
26057                                       DAG.getConstant(0, DL, IntVT), ClearMask));
26058 
26059         if (TLI.isOperationLegalOrCustom(ISD::AND, IntVT))
26060           return DAG.getBitcast(
26061               VT, DAG.getNode(ISD::AND, DL, IntVT, DAG.getBitcast(IntVT, N0),
26062                               DAG.getBuildVector(IntVT, DL, AndMask)));
26063       }
26064     }
26065   }
26066 
26067   // Attempt to combine a shuffle of 2 inputs of 'scalar sources' -
26068   // BUILD_VECTOR or SCALAR_TO_VECTOR into a single BUILD_VECTOR.
26069   if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT))
26070     if (SDValue Res = combineShuffleOfScalars(SVN, DAG, TLI))
26071       return Res;
26072 
26073   // If this shuffle only has a single input that is a bitcasted shuffle,
26074   // attempt to merge the 2 shuffles and suitably bitcast the inputs/output
26075   // back to their original types.
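  // e.g. for a v8i16 shuffle of (bitcast (v4i32 shuffle SV0, SV1, <1,0,3,2>)),
  //   the inner mask scales to the i16 mask <2,3,0,1,6,7,4,5>, which is then
  //   composed with this shuffle's mask over the bitcast inputs.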
26076   if (N0.getOpcode() == ISD::BITCAST && N0.hasOneUse() &&
26077       N1.isUndef() && Level < AfterLegalizeVectorOps &&
26078       TLI.isTypeLegal(VT)) {
26079 
26080     SDValue BC0 = peekThroughOneUseBitcasts(N0);
26081     if (BC0.getOpcode() == ISD::VECTOR_SHUFFLE && BC0.hasOneUse()) {
26082       EVT SVT = VT.getScalarType();
26083       EVT InnerVT = BC0->getValueType(0);
26084       EVT InnerSVT = InnerVT.getScalarType();
26085 
26086       // Determine which shuffle works with the smaller scalar type.
26087       EVT ScaleVT = SVT.bitsLT(InnerSVT) ? VT : InnerVT;
26088       EVT ScaleSVT = ScaleVT.getScalarType();
26089 
26090       if (TLI.isTypeLegal(ScaleVT) &&
26091           0 == (InnerSVT.getSizeInBits() % ScaleSVT.getSizeInBits()) &&
26092           0 == (SVT.getSizeInBits() % ScaleSVT.getSizeInBits())) {
26093         int InnerScale = InnerSVT.getSizeInBits() / ScaleSVT.getSizeInBits();
26094         int OuterScale = SVT.getSizeInBits() / ScaleSVT.getSizeInBits();
26095 
26096         // Scale the shuffle masks to the smaller scalar type.
26097         ShuffleVectorSDNode *InnerSVN = cast<ShuffleVectorSDNode>(BC0);
26098         SmallVector<int, 8> InnerMask;
26099         SmallVector<int, 8> OuterMask;
26100         narrowShuffleMaskElts(InnerScale, InnerSVN->getMask(), InnerMask);
26101         narrowShuffleMaskElts(OuterScale, SVN->getMask(), OuterMask);
26102 
26103         // Merge the shuffle masks.
26104         SmallVector<int, 8> NewMask;
26105         for (int M : OuterMask)
26106           NewMask.push_back(M < 0 ? -1 : InnerMask[M]);
26107 
26108         // Test for shuffle mask legality over both commutations.
26109         SDValue SV0 = BC0->getOperand(0);
26110         SDValue SV1 = BC0->getOperand(1);
26111         bool LegalMask = TLI.isShuffleMaskLegal(NewMask, ScaleVT);
26112         if (!LegalMask) {
26113           std::swap(SV0, SV1);
26114           ShuffleVectorSDNode::commuteMask(NewMask);
26115           LegalMask = TLI.isShuffleMaskLegal(NewMask, ScaleVT);
26116         }
26117 
26118         if (LegalMask) {
26119           SV0 = DAG.getBitcast(ScaleVT, SV0);
26120           SV1 = DAG.getBitcast(ScaleVT, SV1);
26121           return DAG.getBitcast(
26122               VT, DAG.getVectorShuffle(ScaleVT, SDLoc(N), SV0, SV1, NewMask));
26123         }
26124       }
26125     }
26126   }
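  // Illustrative example of the merge above (hypothetical types): if N0 is
  // (v4i32 bitcast (v2i64 shuffle(A, B, <0,2>))) and the outer mask is
  // <2,3,0,1>, the inner mask widens to <0,1,4,5> at i32 granularity and the
  // whole chain merges to (v4i32 shuffle(bitcast A, bitcast B, <4,5,0,1>)).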
26127 
26128   // Match shuffles of bitcasts, so long as the mask can be treated as the
26129   // larger type.
26130   if (SDValue V = combineShuffleOfBitcast(SVN, DAG, TLI, LegalOperations))
26131     return V;
26132 
26133   // Compute the combined shuffle mask for a shuffle with SV0 as the first
26134   // operand, and SV1 as the second operand.
26135   // i.e. Merge SVN(OtherSVN, N1) -> shuffle(SV0, SV1, Mask) iff Commute = false
26136   //      Merge SVN(N1, OtherSVN) -> shuffle(SV0, SV1, Mask') iff Commute = true
26137   auto MergeInnerShuffle =
26138       [NumElts, &VT](bool Commute, ShuffleVectorSDNode *SVN,
26139                      ShuffleVectorSDNode *OtherSVN, SDValue N1,
26140                      const TargetLowering &TLI, SDValue &SV0, SDValue &SV1,
26141                      SmallVectorImpl<int> &Mask) -> bool {
26142     // Don't try to fold splats; they're likely to simplify somehow, or they
26143     // might be free.
26144     if (OtherSVN->isSplat())
26145       return false;
26146 
26147     SV0 = SV1 = SDValue();
26148     Mask.clear();
26149 
26150     for (unsigned i = 0; i != NumElts; ++i) {
26151       int Idx = SVN->getMaskElt(i);
26152       if (Idx < 0) {
26153         // Propagate Undef.
26154         Mask.push_back(Idx);
26155         continue;
26156       }
26157 
26158       if (Commute)
26159         Idx = (Idx < (int)NumElts) ? (Idx + NumElts) : (Idx - NumElts);
26160 
26161       SDValue CurrentVec;
26162       if (Idx < (int)NumElts) {
26163         // This shuffle index refers to the inner shuffle N0. Lookup the inner
26164         // shuffle mask to identify which vector is actually referenced.
26165         Idx = OtherSVN->getMaskElt(Idx);
26166         if (Idx < 0) {
26167           // Propagate Undef.
26168           Mask.push_back(Idx);
26169           continue;
26170         }
26171         CurrentVec = (Idx < (int)NumElts) ? OtherSVN->getOperand(0)
26172                                           : OtherSVN->getOperand(1);
26173       } else {
26174         // This shuffle index references an element within N1.
26175         CurrentVec = N1;
26176       }
26177 
26178       // Simple case where 'CurrentVec' is UNDEF.
26179       if (CurrentVec.isUndef()) {
26180         Mask.push_back(-1);
26181         continue;
26182       }
26183 
26184       // Canonicalize the shuffle index. We don't know yet if CurrentVec
26185       // will be the first or second operand of the combined shuffle.
26186       Idx = Idx % NumElts;
26187       if (!SV0.getNode() || SV0 == CurrentVec) {
26188         // Ok. CurrentVec is the left hand side.
26189         // Update the mask accordingly.
26190         SV0 = CurrentVec;
26191         Mask.push_back(Idx);
26192         continue;
26193       }
26194       if (!SV1.getNode() || SV1 == CurrentVec) {
26195         // Ok. CurrentVec is the right hand side.
26196         // Update the mask accordingly.
26197         SV1 = CurrentVec;
26198         Mask.push_back(Idx + NumElts);
26199         continue;
26200       }
26201 
26202       // Last chance - see if the vector is another shuffle and if it
26203       // uses one of the existing candidate shuffle ops.
26204       if (auto *CurrentSVN = dyn_cast<ShuffleVectorSDNode>(CurrentVec)) {
26205         int InnerIdx = CurrentSVN->getMaskElt(Idx);
26206         if (InnerIdx < 0) {
26207           Mask.push_back(-1);
26208           continue;
26209         }
26210         SDValue InnerVec = (InnerIdx < (int)NumElts)
26211                                ? CurrentSVN->getOperand(0)
26212                                : CurrentSVN->getOperand(1);
26213         if (InnerVec.isUndef()) {
26214           Mask.push_back(-1);
26215           continue;
26216         }
26217         InnerIdx %= NumElts;
26218         if (InnerVec == SV0) {
26219           Mask.push_back(InnerIdx);
26220           continue;
26221         }
26222         if (InnerVec == SV1) {
26223           Mask.push_back(InnerIdx + NumElts);
26224           continue;
26225         }
26226       }
26227 
26228       // Bail out if we cannot convert the shuffle pair into a single shuffle.
26229       return false;
26230     }
26231 
26232     if (llvm::all_of(Mask, [](int M) { return M < 0; }))
26233       return true;
26234 
26235     // Avoid introducing shuffles with illegal mask.
26236     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2)
26237     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2)
26238     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2)
26239     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, A, M2)
26240     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, A, M2)
26241     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, B, M2)
26242     if (TLI.isShuffleMaskLegal(Mask, VT))
26243       return true;
26244 
26245     std::swap(SV0, SV1);
26246     ShuffleVectorSDNode::commuteMask(Mask);
26247     return TLI.isShuffleMaskLegal(Mask, VT);
26248   };
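  // Worked example for MergeInnerShuffle (hypothetical v4i32 values):
  // merging shuffle(shuffle(A, B, <0,4,1,5>), C, <0,2,4,6>) resolves outer
  // indices 0 and 2 through the inner mask to A[0] and A[1], and indices 4
  // and 6 directly to C[0] and C[2], giving SV0 = A, SV1 = C and the merged
  // mask <0,1,4,6>.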
26249 
26250   if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) {
26251     // Canonicalize shuffles according to rules:
26252     //  shuffle(A, shuffle(A, B)) -> shuffle(shuffle(A,B), A)
26253     //  shuffle(B, shuffle(A, B)) -> shuffle(shuffle(A,B), B)
26254     //  shuffle(B, shuffle(A, Undef)) -> shuffle(shuffle(A, Undef), B)
26255     if (N1.getOpcode() == ISD::VECTOR_SHUFFLE &&
26256         N0.getOpcode() != ISD::VECTOR_SHUFFLE) {
26257       // The incoming shuffle must be of the same type as the result of the
26258       // current shuffle.
26259       assert(N1->getOperand(0).getValueType() == VT &&
26260              "Shuffle types don't match");
26261 
26262       SDValue SV0 = N1->getOperand(0);
26263       SDValue SV1 = N1->getOperand(1);
26264       bool HasSameOp0 = N0 == SV0;
26265       bool IsSV1Undef = SV1.isUndef();
26266       if (HasSameOp0 || IsSV1Undef || N0 == SV1)
26267         // Commute the operands of this shuffle so merging below will trigger.
26268         return DAG.getCommutedVectorShuffle(*SVN);
26269     }
26270 
26271     // Canonicalize splat shuffles to the RHS to improve merging below.
26272     //  shuffle(splat(A,u), shuffle(C,D)) -> shuffle'(shuffle(C,D), splat(A,u))
26273     if (N0.getOpcode() == ISD::VECTOR_SHUFFLE &&
26274         N1.getOpcode() == ISD::VECTOR_SHUFFLE &&
26275         cast<ShuffleVectorSDNode>(N0)->isSplat() &&
26276         !cast<ShuffleVectorSDNode>(N1)->isSplat()) {
26277       return DAG.getCommutedVectorShuffle(*SVN);
26278     }
26279 
26280     // Try to fold according to rules:
26281     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2)
26282     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2)
26283     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2)
26284     // Don't try to fold shuffles with illegal type.
26285     // Only fold if this shuffle is the only user of the other shuffle.
26286     // Try matching shuffle(C,shuffle(A,B)) commuted patterns as well.
26287     for (int i = 0; i != 2; ++i) {
26288       if (N->getOperand(i).getOpcode() == ISD::VECTOR_SHUFFLE &&
26289           N->isOnlyUserOf(N->getOperand(i).getNode())) {
26290         // The incoming shuffle must be of the same type as the result of the
26291         // current shuffle.
26292         auto *OtherSV = cast<ShuffleVectorSDNode>(N->getOperand(i));
26293         assert(OtherSV->getOperand(0).getValueType() == VT &&
26294                "Shuffle types don't match");
26295 
26296         SDValue SV0, SV1;
26297         SmallVector<int, 4> Mask;
26298         if (MergeInnerShuffle(i != 0, SVN, OtherSV, N->getOperand(1 - i), TLI,
26299                               SV0, SV1, Mask)) {
26300           // Check if all indices in Mask are Undef. If so, propagate Undef.
26301           if (llvm::all_of(Mask, [](int M) { return M < 0; }))
26302             return DAG.getUNDEF(VT);
26303 
26304           return DAG.getVectorShuffle(VT, SDLoc(N),
26305                                       SV0 ? SV0 : DAG.getUNDEF(VT),
26306                                       SV1 ? SV1 : DAG.getUNDEF(VT), Mask);
26307         }
26308       }
26309     }
26310 
26311     // Merge shuffles through binops if we are able to merge them with at
26312     // least one other shuffle.
26313     // shuffle(bop(shuffle(x,y),shuffle(z,w)),undef)
26314     // shuffle(bop(shuffle(x,y),shuffle(z,w)),bop(shuffle(a,b),shuffle(c,d)))
26315     unsigned SrcOpcode = N0.getOpcode();
26316     if (TLI.isBinOp(SrcOpcode) && N->isOnlyUserOf(N0.getNode()) &&
26317         (N1.isUndef() ||
26318          (SrcOpcode == N1.getOpcode() && N->isOnlyUserOf(N1.getNode())))) {
26319       // Get binop source ops, or just pass on the undef.
26320       SDValue Op00 = N0.getOperand(0);
26321       SDValue Op01 = N0.getOperand(1);
26322       SDValue Op10 = N1.isUndef() ? N1 : N1.getOperand(0);
26323       SDValue Op11 = N1.isUndef() ? N1 : N1.getOperand(1);
26324       // TODO: We might be able to relax the VT check but we don't currently
26325       // have any isBinOp() that has different result/ops VTs so play safe until
26326       // we have test coverage.
26327       if (Op00.getValueType() == VT && Op10.getValueType() == VT &&
26328           Op01.getValueType() == VT && Op11.getValueType() == VT &&
26329           (Op00.getOpcode() == ISD::VECTOR_SHUFFLE ||
26330            Op10.getOpcode() == ISD::VECTOR_SHUFFLE ||
26331            Op01.getOpcode() == ISD::VECTOR_SHUFFLE ||
26332            Op11.getOpcode() == ISD::VECTOR_SHUFFLE)) {
26333         auto CanMergeInnerShuffle = [&](SDValue &SV0, SDValue &SV1,
26334                                         SmallVectorImpl<int> &Mask, bool LeftOp,
26335                                         bool Commute) {
26336           SDValue InnerN = Commute ? N1 : N0;
26337           SDValue Op0 = LeftOp ? Op00 : Op01;
26338           SDValue Op1 = LeftOp ? Op10 : Op11;
26339           if (Commute)
26340             std::swap(Op0, Op1);
26341           // Only accept the merged shuffle if we don't introduce undef elements,
26342           // or the inner shuffle already contained undef elements.
26343           auto *SVN0 = dyn_cast<ShuffleVectorSDNode>(Op0);
26344           return SVN0 && InnerN->isOnlyUserOf(SVN0) &&
26345                  MergeInnerShuffle(Commute, SVN, SVN0, Op1, TLI, SV0, SV1,
26346                                    Mask) &&
26347                  (llvm::any_of(SVN0->getMask(), [](int M) { return M < 0; }) ||
26348                   llvm::none_of(Mask, [](int M) { return M < 0; }));
26349         };
26350 
26351         // Ensure we don't increase the number of shuffles - we must merge a
26352         // shuffle from at least one of the LHS and RHS ops.
26353         bool MergedLeft = false;
26354         SDValue LeftSV0, LeftSV1;
26355         SmallVector<int, 4> LeftMask;
26356         if (CanMergeInnerShuffle(LeftSV0, LeftSV1, LeftMask, true, false) ||
26357             CanMergeInnerShuffle(LeftSV0, LeftSV1, LeftMask, true, true)) {
26358           MergedLeft = true;
26359         } else {
26360           LeftMask.assign(SVN->getMask().begin(), SVN->getMask().end());
26361           LeftSV0 = Op00, LeftSV1 = Op10;
26362         }
26363 
26364         bool MergedRight = false;
26365         SDValue RightSV0, RightSV1;
26366         SmallVector<int, 4> RightMask;
26367         if (CanMergeInnerShuffle(RightSV0, RightSV1, RightMask, false, false) ||
26368             CanMergeInnerShuffle(RightSV0, RightSV1, RightMask, false, true)) {
26369           MergedRight = true;
26370         } else {
26371           RightMask.assign(SVN->getMask().begin(), SVN->getMask().end());
26372           RightSV0 = Op01, RightSV1 = Op11;
26373         }
26374 
26375         if (MergedLeft || MergedRight) {
26376           SDLoc DL(N);
26377           SDValue LHS = DAG.getVectorShuffle(
26378               VT, DL, LeftSV0 ? LeftSV0 : DAG.getUNDEF(VT),
26379               LeftSV1 ? LeftSV1 : DAG.getUNDEF(VT), LeftMask);
26380           SDValue RHS = DAG.getVectorShuffle(
26381               VT, DL, RightSV0 ? RightSV0 : DAG.getUNDEF(VT),
26382               RightSV1 ? RightSV1 : DAG.getUNDEF(VT), RightMask);
26383           return DAG.getNode(SrcOpcode, DL, VT, LHS, RHS);
26384         }
26385       }
26386     }
26387   }
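  // Sketch of the binop merge above (hypothetical values): a pattern such as
  // shuffle(add(shuffle(A, B, M0), shuffle(C, D, M1)), undef, M) can become
  // add(shuffle(A, B, M0'), shuffle(C, D, M1')) when MergeInnerShuffle folds
  // M into at least one inner mask, so the shuffle count does not grow.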
26388 
26389   if (SDValue V = foldShuffleOfConcatUndefs(SVN, DAG))
26390     return V;
26391 
26392   // Match shuffles that can be converted to ISD::ZERO_EXTEND_VECTOR_INREG.
26393   // Perform this really late, because it could eliminate knowledge
26394   // of undef elements created by this shuffle.
26395   if (Level < AfterLegalizeTypes)
26396     if (SDValue V = combineShuffleToZeroExtendVectorInReg(SVN, DAG, TLI,
26397                                                           LegalOperations))
26398       return V;
26399 
26400   return SDValue();
26401 }
26402 
26403 SDValue DAGCombiner::visitSCALAR_TO_VECTOR(SDNode *N) {
26404   EVT VT = N->getValueType(0);
26405   if (!VT.isFixedLengthVector())
26406     return SDValue();
26407 
26408   // Try to convert a scalar binop with an extracted vector element to a vector
26409   // binop. This is intended to reduce potentially expensive register moves.
26410   // TODO: Check if both operands are extracted.
26411   // TODO: How to prefer scalar/vector ops with multiple uses of the extract?
26412   // TODO: Generalize this, so it can be called from visitINSERT_VECTOR_ELT().
26413   SDValue Scalar = N->getOperand(0);
26414   unsigned Opcode = Scalar.getOpcode();
26415   EVT VecEltVT = VT.getScalarType();
26416   if (Scalar.hasOneUse() && Scalar->getNumValues() == 1 &&
26417       TLI.isBinOp(Opcode) && Scalar.getValueType() == VecEltVT &&
26418       Scalar.getOperand(0).getValueType() == VecEltVT &&
26419       Scalar.getOperand(1).getValueType() == VecEltVT &&
26420       Scalar->isOnlyUserOf(Scalar.getOperand(0).getNode()) &&
26421       Scalar->isOnlyUserOf(Scalar.getOperand(1).getNode()) &&
26422       DAG.isSafeToSpeculativelyExecute(Opcode) && hasOperation(Opcode, VT)) {
26423     // Match an extract element and get a shuffle mask equivalent.
26424     SmallVector<int, 8> ShufMask(VT.getVectorNumElements(), -1);
26425 
26426     for (int i : {0, 1}) {
26427       // s2v (bo (extelt V, Idx), C) --> shuffle (bo V, C'), {Idx, -1, -1...}
26428       // s2v (bo C, (extelt V, Idx)) --> shuffle (bo C', V), {Idx, -1, -1...}
26429       SDValue EE = Scalar.getOperand(i);
26430       auto *C = dyn_cast<ConstantSDNode>(Scalar.getOperand(i ? 0 : 1));
26431       if (C && EE.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
26432           EE.getOperand(0).getValueType() == VT &&
26433           isa<ConstantSDNode>(EE.getOperand(1))) {
26434         // Mask = {ExtractIndex, undef, undef....}
26435         ShufMask[0] = EE.getConstantOperandVal(1);
26436         // Make sure the shuffle is legal if we are crossing lanes.
26437         if (TLI.isShuffleMaskLegal(ShufMask, VT)) {
26438           SDLoc DL(N);
26439           SDValue V[] = {EE.getOperand(0),
26440                          DAG.getConstant(C->getAPIntValue(), DL, VT)};
26441           SDValue VecBO = DAG.getNode(Opcode, DL, VT, V[i], V[1 - i]);
26442           return DAG.getVectorShuffle(VT, DL, VecBO, DAG.getUNDEF(VT),
26443                                       ShufMask);
26444         }
26445       }
26446     }
26447   }
26448 
26449   // Replace a SCALAR_TO_VECTOR(EXTRACT_VECTOR_ELT(V,C0)) pattern
26450   // with a VECTOR_SHUFFLE and possible truncate.
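  // e.g. (v4i16 scalar_to_vector (extractelt V:v8i16, 3)) can become
  // (extract_subvector (shuffle V, undef, <3,-1,-1,-1,-1,-1,-1,-1>), 0),
  // assuming the target can build that shuffle legally (illustrative types).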
26451   if (Opcode != ISD::EXTRACT_VECTOR_ELT ||
26452       !Scalar.getOperand(0).getValueType().isFixedLengthVector())
26453     return SDValue();
26454 
26455   // If we have an implicit truncate, truncate here if it is legal.
26456   if (VecEltVT != Scalar.getValueType() &&
26457       Scalar.getValueType().isScalarInteger() && isTypeLegal(VecEltVT)) {
26458     SDValue Val = DAG.getNode(ISD::TRUNCATE, SDLoc(Scalar), VecEltVT, Scalar);
26459     return DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), VT, Val);
26460   }
26461 
26462   auto *ExtIndexC = dyn_cast<ConstantSDNode>(Scalar.getOperand(1));
26463   if (!ExtIndexC)
26464     return SDValue();
26465 
26466   SDValue SrcVec = Scalar.getOperand(0);
26467   EVT SrcVT = SrcVec.getValueType();
26468   unsigned SrcNumElts = SrcVT.getVectorNumElements();
26469   unsigned VTNumElts = VT.getVectorNumElements();
26470   if (VecEltVT == SrcVT.getScalarType() && VTNumElts <= SrcNumElts) {
26471     // Create a shuffle equivalent for scalar-to-vector: {ExtIndex, -1, -1, ...}
26472     SmallVector<int, 8> Mask(SrcNumElts, -1);
26473     Mask[0] = ExtIndexC->getZExtValue();
26474     SDValue LegalShuffle = TLI.buildLegalVectorShuffle(
26475         SrcVT, SDLoc(N), SrcVec, DAG.getUNDEF(SrcVT), Mask, DAG);
26476     if (!LegalShuffle)
26477       return SDValue();
26478 
26479     // If the initial vector is the same size, the shuffle is the result.
26480     if (VT == SrcVT)
26481       return LegalShuffle;
26482 
26483     // If not, shorten the shuffled vector.
26484     if (VTNumElts != SrcNumElts) {
26485       SDValue ZeroIdx = DAG.getVectorIdxConstant(0, SDLoc(N));
26486       EVT SubVT = EVT::getVectorVT(*DAG.getContext(),
26487                                    SrcVT.getVectorElementType(), VTNumElts);
26488       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), SubVT, LegalShuffle,
26489                          ZeroIdx);
26490     }
26491   }
26492 
26493   return SDValue();
26494 }
26495 
26496 SDValue DAGCombiner::visitINSERT_SUBVECTOR(SDNode *N) {
26497   EVT VT = N->getValueType(0);
26498   SDValue N0 = N->getOperand(0);
26499   SDValue N1 = N->getOperand(1);
26500   SDValue N2 = N->getOperand(2);
26501   uint64_t InsIdx = N->getConstantOperandVal(2);
26502 
26503   // If inserting an UNDEF, just return the original vector.
26504   if (N1.isUndef())
26505     return N0;
26506 
26507   // If this is an insert of an extracted vector into an undef vector, we can
26508   // just use the input to the extract if the types match, and can simplify
26509   // in some cases even if they don't.
26510   if (N0.isUndef() && N1.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
26511       N1.getOperand(1) == N2) {
26512     EVT SrcVT = N1.getOperand(0).getValueType();
26513     if (SrcVT == VT)
26514       return N1.getOperand(0);
26515     // TODO: To remove the zero check, need to adjust the offset to
26516     // a multiple of the new src type.
26517     if (isNullConstant(N2)) {
26518       if (VT.knownBitsGE(SrcVT) &&
26519           !(VT.isFixedLengthVector() && SrcVT.isScalableVector()))
26520         return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N),
26521                            VT, N0, N1.getOperand(0), N2);
26522       else if (VT.knownBitsLE(SrcVT) &&
26523                !(VT.isScalableVector() && SrcVT.isFixedLengthVector()))
26524         return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N),
26525                            VT, N1.getOperand(0), N2);
26526     }
26527   }
26528 
26529   // Handle case where we've ended up inserting back into the source vector
26530   // we extracted the subvector from.
26531   // insert_subvector(N0, extract_subvector(N0, N2), N2) --> N0
26532   if (N1.getOpcode() == ISD::EXTRACT_SUBVECTOR && N1.getOperand(0) == N0 &&
26533       N1.getOperand(1) == N2)
26534     return N0;
26535 
26536   // Simplify scalar inserts into an undef vector:
26537   // insert_subvector undef, (splat X), N2 -> splat X
26538   if (N0.isUndef() && N1.getOpcode() == ISD::SPLAT_VECTOR)
26539     if (DAG.isConstantValueOfAnyType(N1.getOperand(0)) || N1.hasOneUse())
26540       return DAG.getNode(ISD::SPLAT_VECTOR, SDLoc(N), VT, N1.getOperand(0));
26541 
26542   // If we are inserting a bitcast value into an undef, with the same
26543   // number of elements, just use the bitcast input of the extract.
26544   // i.e. INSERT_SUBVECTOR UNDEF (BITCAST N1) N2 ->
26545   //        BITCAST (INSERT_SUBVECTOR UNDEF N1 N2)
26546   if (N0.isUndef() && N1.getOpcode() == ISD::BITCAST &&
26547       N1.getOperand(0).getOpcode() == ISD::EXTRACT_SUBVECTOR &&
26548       N1.getOperand(0).getOperand(1) == N2 &&
26549       N1.getOperand(0).getOperand(0).getValueType().getVectorElementCount() ==
26550           VT.getVectorElementCount() &&
26551       N1.getOperand(0).getOperand(0).getValueType().getSizeInBits() ==
26552           VT.getSizeInBits()) {
26553     return DAG.getBitcast(VT, N1.getOperand(0).getOperand(0));
26554   }
26555 
26556   // If both N0 and N1 are bitcast values on which insert_subvector
26557   // would make sense, pull the bitcast through.
26558   // i.e. INSERT_SUBVECTOR (BITCAST N0) (BITCAST N1) N2 ->
26559   //        BITCAST (INSERT_SUBVECTOR N0 N1 N2)
26560   if (N0.getOpcode() == ISD::BITCAST && N1.getOpcode() == ISD::BITCAST) {
26561     SDValue CN0 = N0.getOperand(0);
26562     SDValue CN1 = N1.getOperand(0);
26563     EVT CN0VT = CN0.getValueType();
26564     EVT CN1VT = CN1.getValueType();
26565     if (CN0VT.isVector() && CN1VT.isVector() &&
26566         CN0VT.getVectorElementType() == CN1VT.getVectorElementType() &&
26567         CN0VT.getVectorElementCount() == VT.getVectorElementCount()) {
26568       SDValue NewINSERT = DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N),
26569                                       CN0.getValueType(), CN0, CN1, N2);
26570       return DAG.getBitcast(VT, NewINSERT);
26571     }
26572   }
26573 
26574   // Combine INSERT_SUBVECTORs where we are inserting to the same index.
26575   // INSERT_SUBVECTOR( INSERT_SUBVECTOR( Vec, SubOld, Idx ), SubNew, Idx )
26576   // --> INSERT_SUBVECTOR( Vec, SubNew, Idx )
26577   if (N0.getOpcode() == ISD::INSERT_SUBVECTOR &&
26578       N0.getOperand(1).getValueType() == N1.getValueType() &&
26579       N0.getOperand(2) == N2)
26580     return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT, N0.getOperand(0),
26581                        N1, N2);
26582 
26583   // Eliminate an intermediate insert into an undef vector:
26584   // insert_subvector undef, (insert_subvector undef, X, 0), 0 -->
26585   // insert_subvector undef, X, 0
26586   if (N0.isUndef() && N1.getOpcode() == ISD::INSERT_SUBVECTOR &&
26587       N1.getOperand(0).isUndef() && isNullConstant(N1.getOperand(2)) &&
26588       isNullConstant(N2))
26589     return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT, N0,
26590                        N1.getOperand(1), N2);
26591 
26592   // Push subvector bitcasts to the output, adjusting the index as we go.
26593   // insert_subvector(bitcast(v), bitcast(s), c1)
26594   // -> bitcast(insert_subvector(v, s, c2))
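  // Illustrative instance (hypothetical types): if the outer vector is
  // (v8i32 bitcast (v:v4i64)) and we insert (v4i32 bitcast (s:v2i64)) at
  // index 4, this rescales to (v8i32 bitcast (insert_subvector v, s, 2)),
  // halving the index because each i64 covers two i32 lanes.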
26595   if ((N0.isUndef() || N0.getOpcode() == ISD::BITCAST) &&
26596       N1.getOpcode() == ISD::BITCAST) {
26597     SDValue N0Src = peekThroughBitcasts(N0);
26598     SDValue N1Src = peekThroughBitcasts(N1);
26599     EVT N0SrcSVT = N0Src.getValueType().getScalarType();
26600     EVT N1SrcSVT = N1Src.getValueType().getScalarType();
26601     if ((N0.isUndef() || N0SrcSVT == N1SrcSVT) &&
26602         N0Src.getValueType().isVector() && N1Src.getValueType().isVector()) {
26603       EVT NewVT;
26604       SDLoc DL(N);
26605       SDValue NewIdx;
26606       LLVMContext &Ctx = *DAG.getContext();
26607       ElementCount NumElts = VT.getVectorElementCount();
26608       unsigned EltSizeInBits = VT.getScalarSizeInBits();
26609       if ((EltSizeInBits % N1SrcSVT.getSizeInBits()) == 0) {
26610         unsigned Scale = EltSizeInBits / N1SrcSVT.getSizeInBits();
26611         NewVT = EVT::getVectorVT(Ctx, N1SrcSVT, NumElts * Scale);
26612         NewIdx = DAG.getVectorIdxConstant(InsIdx * Scale, DL);
26613       } else if ((N1SrcSVT.getSizeInBits() % EltSizeInBits) == 0) {
26614         unsigned Scale = N1SrcSVT.getSizeInBits() / EltSizeInBits;
26615         if (NumElts.isKnownMultipleOf(Scale) && (InsIdx % Scale) == 0) {
26616           NewVT = EVT::getVectorVT(Ctx, N1SrcSVT,
26617                                    NumElts.divideCoefficientBy(Scale));
26618           NewIdx = DAG.getVectorIdxConstant(InsIdx / Scale, DL);
26619         }
26620       }
26621       if (NewIdx && hasOperation(ISD::INSERT_SUBVECTOR, NewVT)) {
26622         SDValue Res = DAG.getBitcast(NewVT, N0Src);
26623         Res = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, NewVT, Res, N1Src, NewIdx);
26624         return DAG.getBitcast(VT, Res);
26625       }
26626     }
26627   }
26628 
26629   // Canonicalize insert_subvector dag nodes.
26630   // Example:
26631   // (insert_subvector (insert_subvector A, B, Idx0), C, Idx1)
26632   // -> (insert_subvector (insert_subvector A, C, Idx1), B, Idx0), when Idx1 < Idx0
26633   if (N0.getOpcode() == ISD::INSERT_SUBVECTOR && N0.hasOneUse() &&
26634       N1.getValueType() == N0.getOperand(1).getValueType()) {
26635     unsigned OtherIdx = N0.getConstantOperandVal(2);
26636     if (InsIdx < OtherIdx) {
26637       // Swap nodes.
26638       SDValue NewOp = DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT,
26639                                   N0.getOperand(0), N1, N2);
26640       AddToWorklist(NewOp.getNode());
26641       return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N0.getNode()),
26642                          VT, NewOp, N0.getOperand(1), N0.getOperand(2));
26643     }
26644   }
26645 
26646   // If the input vector is a concatenation, and the insert replaces
26647   // one of the pieces, we can optimize into a single concat_vectors.
26648   if (N0.getOpcode() == ISD::CONCAT_VECTORS && N0.hasOneUse() &&
26649       N0.getOperand(0).getValueType() == N1.getValueType() &&
26650       N0.getOperand(0).getValueType().isScalableVector() ==
26651           N1.getValueType().isScalableVector()) {
26652     unsigned Factor = N1.getValueType().getVectorMinNumElements();
26653     SmallVector<SDValue, 8> Ops(N0->op_begin(), N0->op_end());
26654     Ops[InsIdx / Factor] = N1;
26655     return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops);
26656   }
26657 
26658   // Simplify source operands based on insertion.
26659   if (SimplifyDemandedVectorElts(SDValue(N, 0)))
26660     return SDValue(N, 0);
26661 
26662   return SDValue();
26663 }
26664 
26665 SDValue DAGCombiner::visitFP_TO_FP16(SDNode *N) {
26666   SDValue N0 = N->getOperand(0);
26667 
26668   // fold (fp_to_fp16 (fp16_to_fp op)) -> op
26669   if (N0->getOpcode() == ISD::FP16_TO_FP)
26670     return N0->getOperand(0);
26671 
26672   return SDValue();
26673 }
26674 
26675 SDValue DAGCombiner::visitFP16_TO_FP(SDNode *N) {
26676   auto Op = N->getOpcode();
26677   assert((Op == ISD::FP16_TO_FP || Op == ISD::BF16_TO_FP) &&
26678          "opcode should be FP16_TO_FP or BF16_TO_FP.");
26679   SDValue N0 = N->getOperand(0);
26680 
26681   // fold fp16_to_fp(op & 0xffff) -> fp16_to_fp(op) or
26682   // fold bf16_to_fp(op & 0xffff) -> bf16_to_fp(op)
26683   if (!TLI.shouldKeepZExtForFP16Conv() && N0->getOpcode() == ISD::AND) {
26684     ConstantSDNode *AndConst = getAsNonOpaqueConstant(N0.getOperand(1));
26685     if (AndConst && AndConst->getAPIntValue() == 0xffff) {
26686       return DAG.getNode(Op, SDLoc(N), N->getValueType(0), N0.getOperand(0));
26687     }
26688   }
26689 
26690   // Sometimes constants manage to survive very late in the pipeline, e.g.,
26691   // because they are wrapped inside the <1 x f16> type. Try one last time to
26692   // get rid of them.
26693   SDValue Folded = DAG.FoldConstantArithmetic(N->getOpcode(), SDLoc(N),
26694                                               N->getValueType(0), {N0});
26695   return Folded;
26696 }
26697 
26698 SDValue DAGCombiner::visitFP_TO_BF16(SDNode *N) {
26699   SDValue N0 = N->getOperand(0);
26700 
26701   // fold (fp_to_bf16 (bf16_to_fp op)) -> op
26702   if (N0->getOpcode() == ISD::BF16_TO_FP)
26703     return N0->getOperand(0);
26704 
26705   return SDValue();
26706 }
26707 
26708 SDValue DAGCombiner::visitBF16_TO_FP(SDNode *N) {
26709   // fold bf16_to_fp(op & 0xffff) -> bf16_to_fp(op)
26710   return visitFP16_TO_FP(N);
26711 }
26712 
26713 SDValue DAGCombiner::visitVECREDUCE(SDNode *N) {
26714   SDValue N0 = N->getOperand(0);
26715   EVT VT = N0.getValueType();
26716   unsigned Opcode = N->getOpcode();
26717 
26718   // A VECREDUCE over a 1-element vector is just an extract.
26719   if (VT.getVectorElementCount().isScalar()) {
26720     SDLoc dl(N);
26721     SDValue Res =
26722         DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT.getVectorElementType(), N0,
26723                     DAG.getVectorIdxConstant(0, dl));
26724     if (Res.getValueType() != N->getValueType(0))
26725       Res = DAG.getNode(ISD::ANY_EXTEND, dl, N->getValueType(0), Res);
26726     return Res;
26727   }
26728 
26729   // On a boolean vector an and/or reduction is the same as a umin/umax
26730   // reduction. Convert them if the latter is legal while the former isn't.
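  // On 0/-1 boolean elements AND agrees with UMIN and OR with UMAX, e.g.
  // and(0, -1) == 0 == umin(0, -1); the sign-bit check below guarantees
  // every element is such a boolean value.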
26731   if (Opcode == ISD::VECREDUCE_AND || Opcode == ISD::VECREDUCE_OR) {
26732     unsigned NewOpcode = Opcode == ISD::VECREDUCE_AND
26733         ? ISD::VECREDUCE_UMIN : ISD::VECREDUCE_UMAX;
26734     if (!TLI.isOperationLegalOrCustom(Opcode, VT) &&
26735         TLI.isOperationLegalOrCustom(NewOpcode, VT) &&
26736         DAG.ComputeNumSignBits(N0) == VT.getScalarSizeInBits())
26737       return DAG.getNode(NewOpcode, SDLoc(N), N->getValueType(0), N0);
26738   }
26739 
26740   // vecreduce_or(insert_subvector(zero or undef, val)) -> vecreduce_or(val)
26741   // vecreduce_and(insert_subvector(ones or undef, val)) -> vecreduce_and(val)
26742   if (N0.getOpcode() == ISD::INSERT_SUBVECTOR &&
26743       TLI.isTypeLegal(N0.getOperand(1).getValueType())) {
26744     SDValue Vec = N0.getOperand(0);
26745     SDValue Subvec = N0.getOperand(1);
26746     if ((Opcode == ISD::VECREDUCE_OR &&
26747          (N0.getOperand(0).isUndef() || isNullOrNullSplat(Vec))) ||
26748         (Opcode == ISD::VECREDUCE_AND &&
26749          (N0.getOperand(0).isUndef() || isAllOnesOrAllOnesSplat(Vec))))
26750       return DAG.getNode(Opcode, SDLoc(N), N->getValueType(0), Subvec);
26751   }
26752 
26753   return SDValue();
26754 }
26755 
26756 SDValue DAGCombiner::visitVP_FSUB(SDNode *N) {
26757   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
26758 
26759   // FSUB -> FMA combines:
26760   if (SDValue Fused = visitFSUBForFMACombine<VPMatchContext>(N)) {
26761     AddToWorklist(Fused.getNode());
26762     return Fused;
26763   }
26764   return SDValue();
26765 }
26766 
26767 SDValue DAGCombiner::visitVPOp(SDNode *N) {
26768 
26769   if (N->getOpcode() == ISD::VP_GATHER)
26770     if (SDValue SD = visitVPGATHER(N))
26771       return SD;
26772 
26773   if (N->getOpcode() == ISD::VP_SCATTER)
26774     if (SDValue SD = visitVPSCATTER(N))
26775       return SD;
26776 
26777   if (N->getOpcode() == ISD::EXPERIMENTAL_VP_STRIDED_LOAD)
26778     if (SDValue SD = visitVP_STRIDED_LOAD(N))
26779       return SD;
26780 
26781   if (N->getOpcode() == ISD::EXPERIMENTAL_VP_STRIDED_STORE)
26782     if (SDValue SD = visitVP_STRIDED_STORE(N))
26783       return SD;
26784 
26785   // VP operations in which all vector elements are disabled - either by
26786   // determining that the mask is all false or that the EVL is 0 - can be
26787   // eliminated.
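  // e.g. (vp_add x, y, mask, 0) has an explicit vector length of zero, so it
  // performs no work and is replaced by undef further below.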
26788   bool AreAllEltsDisabled = false;
26789   if (auto EVLIdx = ISD::getVPExplicitVectorLengthIdx(N->getOpcode()))
26790     AreAllEltsDisabled |= isNullConstant(N->getOperand(*EVLIdx));
26791   if (auto MaskIdx = ISD::getVPMaskIdx(N->getOpcode()))
26792     AreAllEltsDisabled |=
26793         ISD::isConstantSplatVectorAllZeros(N->getOperand(*MaskIdx).getNode());
26794 
26795   // This is the only generic VP combine we support for now.
26796   if (!AreAllEltsDisabled) {
26797     switch (N->getOpcode()) {
26798     case ISD::VP_FADD:
26799       return visitVP_FADD(N);
26800     case ISD::VP_FSUB:
26801       return visitVP_FSUB(N);
26802     case ISD::VP_FMA:
26803       return visitFMA<VPMatchContext>(N);
26804     case ISD::VP_SELECT:
26805       return visitVP_SELECT(N);
26806     case ISD::VP_MUL:
26807       return visitMUL<VPMatchContext>(N);
26808     default:
26809       break;
26810     }
26811     return SDValue();
26812   }
26813 
26814   // Binary operations can be replaced by UNDEF.
26815   if (ISD::isVPBinaryOp(N->getOpcode()))
26816     return DAG.getUNDEF(N->getValueType(0));
26817 
26818   // VP Memory operations can be replaced by either the chain (stores) or the
26819   // chain + undef (loads).
26820   if (const auto *MemSD = dyn_cast<MemSDNode>(N)) {
26821     if (MemSD->writeMem())
26822       return MemSD->getChain();
26823     return CombineTo(N, DAG.getUNDEF(N->getValueType(0)), MemSD->getChain());
26824   }
26825 
26826   // Reduction operations return the start operand when no elements are active.
26827   if (ISD::isVPReduction(N->getOpcode()))
26828     return N->getOperand(0);
26829 
26830   return SDValue();
26831 }
26832 
26833 SDValue DAGCombiner::visitGET_FPENV_MEM(SDNode *N) {
26834   SDValue Chain = N->getOperand(0);
26835   SDValue Ptr = N->getOperand(1);
26836   EVT MemVT = cast<FPStateAccessSDNode>(N)->getMemoryVT();
26837 
26838   // Check that the memory where the FP state is written is used only in a
26839   // single load operation.
26840   LoadSDNode *LdNode = nullptr;
26841   for (auto *U : Ptr->uses()) {
26842     if (U == N)
26843       continue;
26844     if (auto *Ld = dyn_cast<LoadSDNode>(U)) {
26845       if (LdNode && LdNode != Ld)
26846         return SDValue();
26847       LdNode = Ld;
26848       continue;
26849     }
26850     return SDValue();
26851   }
26852   if (!LdNode || !LdNode->isSimple() || LdNode->isIndexed() ||
26853       !LdNode->getOffset().isUndef() || LdNode->getMemoryVT() != MemVT ||
26854       !LdNode->getChain().reachesChainWithoutSideEffects(SDValue(N, 0)))
26855     return SDValue();
26856 
26857   // Check if the loaded value is used only in a store operation.
26858   StoreSDNode *StNode = nullptr;
26859   for (auto I = LdNode->use_begin(), E = LdNode->use_end(); I != E; ++I) {
26860     SDUse &U = I.getUse();
26861     if (U.getResNo() == 0) {
26862       if (auto *St = dyn_cast<StoreSDNode>(U.getUser())) {
26863         if (StNode)
26864           return SDValue();
26865         StNode = St;
26866       } else {
26867         return SDValue();
26868       }
26869     }
26870   }
26871   if (!StNode || !StNode->isSimple() || StNode->isIndexed() ||
26872       !StNode->getOffset().isUndef() || StNode->getMemoryVT() != MemVT ||
26873       !StNode->getChain().reachesChainWithoutSideEffects(SDValue(LdNode, 1)))
26874     return SDValue();
26875 
26876   // Create a new GET_FPENV_MEM node, which uses the store address to write
26877   // the FP environment.
26878   SDValue Res = DAG.getGetFPEnv(Chain, SDLoc(N), StNode->getBasePtr(), MemVT,
26879                                 StNode->getMemOperand());
26880   CombineTo(StNode, Res, false);
26881   return Res;
26882 }
26883 
26884 SDValue DAGCombiner::visitSET_FPENV_MEM(SDNode *N) {
26885   SDValue Chain = N->getOperand(0);
26886   SDValue Ptr = N->getOperand(1);
26887   EVT MemVT = cast<FPStateAccessSDNode>(N)->getMemoryVT();
26888 
26889   // Check that the address of the FP state is used only in a single store
26890   // operation.
26890   StoreSDNode *StNode = nullptr;
26891   for (auto *U : Ptr->uses()) {
26892     if (U == N)
26893       continue;
26894     if (auto *St = dyn_cast<StoreSDNode>(U)) {
26895       if (StNode && StNode != St)
26896         return SDValue();
26897       StNode = St;
26898       continue;
26899     }
26900     return SDValue();
26901   }
26902   if (!StNode || !StNode->isSimple() || StNode->isIndexed() ||
26903       !StNode->getOffset().isUndef() || StNode->getMemoryVT() != MemVT ||
26904       !Chain.reachesChainWithoutSideEffects(SDValue(StNode, 0)))
26905     return SDValue();
26906 
26907   // Check if the stored value is loaded from some location and the loaded
26908   // value is used only in the store operation.
26909   SDValue StValue = StNode->getValue();
26910   auto *LdNode = dyn_cast<LoadSDNode>(StValue);
26911   if (!LdNode || !LdNode->isSimple() || LdNode->isIndexed() ||
26912       !LdNode->getOffset().isUndef() || LdNode->getMemoryVT() != MemVT ||
26913       !StNode->getChain().reachesChainWithoutSideEffects(SDValue(LdNode, 1)))
26914     return SDValue();
26915 
26916   // Create a new SET_FPENV_MEM node, which uses the load address to read
26917   // the FP environment.
26918   SDValue Res =
26919       DAG.getSetFPEnv(LdNode->getChain(), SDLoc(N), LdNode->getBasePtr(), MemVT,
26920                       LdNode->getMemOperand());
26921   return Res;
26922 }
26923 
26924 /// Returns a vector_shuffle if it is able to transform an AND to a vector_shuffle
26925 /// with the destination vector and a zero vector.
26926 /// e.g. AND V, <0xffffffff, 0, 0xffffffff, 0>. ==>
26927 ///      vector_shuffle V, Zero, <0, 4, 2, 4>
26928 SDValue DAGCombiner::XformToShuffleWithZero(SDNode *N) {
26929   assert(N->getOpcode() == ISD::AND && "Unexpected opcode!");
26930 
26931   EVT VT = N->getValueType(0);
26932   SDValue LHS = N->getOperand(0);
26933   SDValue RHS = peekThroughBitcasts(N->getOperand(1));
26934   SDLoc DL(N);
26935 
26936   // Make sure we're not running after operation legalization where it
26937   // may have custom lowered the vector shuffles.
26938   if (LegalOperations)
26939     return SDValue();
26940 
26941   if (RHS.getOpcode() != ISD::BUILD_VECTOR)
26942     return SDValue();
26943 
26944   EVT RVT = RHS.getValueType();
26945   unsigned NumElts = RHS.getNumOperands();
26946 
26947   // Attempt to create a valid clear mask, splitting the mask into
26948   // sub elements and checking to see if each is
26949   // all zeros or all ones - suitable for shuffle masking.
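  // e.g. (and V:v2i64, <0x00000000FFFFFFFF, 0xFFFFFFFF00000000>) splits at
  // i32 granularity (Split == 2) into the little-endian clear mask <0,5,6,3>
  // applied to (v4i32 bitcast V) and a zero vector (illustrative constants).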
26950   auto BuildClearMask = [&](int Split) {
26951     int NumSubElts = NumElts * Split;
26952     int NumSubBits = RVT.getScalarSizeInBits() / Split;
26953 
26954     SmallVector<int, 8> Indices;
26955     for (int i = 0; i != NumSubElts; ++i) {
26956       int EltIdx = i / Split;
26957       int SubIdx = i % Split;
26958       SDValue Elt = RHS.getOperand(EltIdx);
26959       // X & undef --> 0 (not undef). So this lane must be converted to choose
26960       // from the zero constant vector (same as if the element had all 0-bits).
26961       if (Elt.isUndef()) {
26962         Indices.push_back(i + NumSubElts);
26963         continue;
26964       }
26965 
26966       APInt Bits;
26967       if (auto *Cst = dyn_cast<ConstantSDNode>(Elt))
26968         Bits = Cst->getAPIntValue();
26969       else if (auto *CstFP = dyn_cast<ConstantFPSDNode>(Elt))
26970         Bits = CstFP->getValueAPF().bitcastToAPInt();
26971       else
26972         return SDValue();
26973 
26974       // Extract the sub element from the constant bit mask.
26975       if (DAG.getDataLayout().isBigEndian())
26976         Bits = Bits.extractBits(NumSubBits, (Split - SubIdx - 1) * NumSubBits);
26977       else
26978         Bits = Bits.extractBits(NumSubBits, SubIdx * NumSubBits);
26979 
26980       if (Bits.isAllOnes())
26981         Indices.push_back(i);
26982       else if (Bits == 0)
26983         Indices.push_back(i + NumSubElts);
26984       else
26985         return SDValue();
26986     }
26987 
26988     // Let's see if the target supports this vector_shuffle.
26989     EVT ClearSVT = EVT::getIntegerVT(*DAG.getContext(), NumSubBits);
26990     EVT ClearVT = EVT::getVectorVT(*DAG.getContext(), ClearSVT, NumSubElts);
26991     if (!TLI.isVectorClearMaskLegal(Indices, ClearVT))
26992       return SDValue();
26993 
26994     SDValue Zero = DAG.getConstant(0, DL, ClearVT);
26995     return DAG.getBitcast(VT, DAG.getVectorShuffle(ClearVT, DL,
26996                                                    DAG.getBitcast(ClearVT, LHS),
26997                                                    Zero, Indices));
26998   };
26999 
27000   // Determine maximum split level (byte level masking).
27001   int MaxSplit = 1;
27002   if (RVT.getScalarSizeInBits() % 8 == 0)
27003     MaxSplit = RVT.getScalarSizeInBits() / 8;
27004 
27005   for (int Split = 1; Split <= MaxSplit; ++Split)
27006     if (RVT.getScalarSizeInBits() % Split == 0)
27007       if (SDValue S = BuildClearMask(Split))
27008         return S;
27009 
27010   return SDValue();
27011 }
27012 
27013 /// If a vector binop is performed on splat values, it may be profitable to
27014 /// extract, scalarize, and insert/splat.
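/// e.g. (add (splat X), (splat Y)) --> (splat (add X, Y)), assuming the
/// scalar op is legal and the extracts are cheap.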
27015 static SDValue scalarizeBinOpOfSplats(SDNode *N, SelectionDAG &DAG,
27016                                       const SDLoc &DL) {
27017   SDValue N0 = N->getOperand(0);
27018   SDValue N1 = N->getOperand(1);
27019   unsigned Opcode = N->getOpcode();
27020   EVT VT = N->getValueType(0);
27021   EVT EltVT = VT.getVectorElementType();
27022   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
27023 
27024   // TODO: Remove/replace the extract cost check? If the elements are available
27025   //       as scalars, then there may be no extract cost. Should we ask if
27026   //       inserting a scalar back into a vector is cheap instead?
27027   int Index0, Index1;
27028   SDValue Src0 = DAG.getSplatSourceVector(N0, Index0);
27029   SDValue Src1 = DAG.getSplatSourceVector(N1, Index1);
27030   // Extracting an element from a splat_vector should be free.
27031   // TODO: use DAG.isSplatValue instead?
27032   bool IsBothSplatVector = N0.getOpcode() == ISD::SPLAT_VECTOR &&
27033                            N1.getOpcode() == ISD::SPLAT_VECTOR;
27034   if (!Src0 || !Src1 || Index0 != Index1 ||
27035       Src0.getValueType().getVectorElementType() != EltVT ||
27036       Src1.getValueType().getVectorElementType() != EltVT ||
27037       !(IsBothSplatVector || TLI.isExtractVecEltCheap(VT, Index0)) ||
27038       !TLI.isOperationLegalOrCustom(Opcode, EltVT))
27039     return SDValue();
27040 
27041   SDValue IndexC = DAG.getVectorIdxConstant(Index0, DL);
27042   SDValue X = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Src0, IndexC);
27043   SDValue Y = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Src1, IndexC);
27044   SDValue ScalarBO = DAG.getNode(Opcode, DL, EltVT, X, Y, N->getFlags());
27045 
27046   // If all lanes but 1 are undefined, no need to splat the scalar result.
27047   // TODO: Keep track of undefs and use that info in the general case.
27048   if (N0.getOpcode() == ISD::BUILD_VECTOR && N0.getOpcode() == N1.getOpcode() &&
27049       count_if(N0->ops(), [](SDValue V) { return !V.isUndef(); }) == 1 &&
27050       count_if(N1->ops(), [](SDValue V) { return !V.isUndef(); }) == 1) {
27051     // bo (build_vec ..undef, X, undef...), (build_vec ..undef, Y, undef...) -->
27052     // build_vec ..undef, (bo X, Y), undef...
27053     SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), DAG.getUNDEF(EltVT));
27054     Ops[Index0] = ScalarBO;
27055     return DAG.getBuildVector(VT, DL, Ops);
27056   }
27057 
27058   // bo (splat X, Index), (splat Y, Index) --> splat (bo X, Y), Index
27059   return DAG.getSplat(VT, DL, ScalarBO);
27060 }
27061 
27062 /// Visit a vector cast operation, like FP_EXTEND.
27063 SDValue DAGCombiner::SimplifyVCastOp(SDNode *N, const SDLoc &DL) {
27064   EVT VT = N->getValueType(0);
27065   assert(VT.isVector() && "SimplifyVCastOp only works on vectors!");
27066   EVT EltVT = VT.getVectorElementType();
27067   unsigned Opcode = N->getOpcode();
27068 
27069   SDValue N0 = N->getOperand(0);
27070   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
27071 
27072   // TODO: promoting the operation might also be good here?
27073   int Index0;
27074   SDValue Src0 = DAG.getSplatSourceVector(N0, Index0);
27075   if (Src0 &&
27076       (N0.getOpcode() == ISD::SPLAT_VECTOR ||
27077        TLI.isExtractVecEltCheap(VT, Index0)) &&
27078       TLI.isOperationLegalOrCustom(Opcode, EltVT) &&
27079       TLI.preferScalarizeSplat(N)) {
27080     EVT SrcVT = N0.getValueType();
27081     EVT SrcEltVT = SrcVT.getVectorElementType();
27082     SDValue IndexC = DAG.getVectorIdxConstant(Index0, DL);
27083     SDValue Elt =
27084         DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, SrcEltVT, Src0, IndexC);
27085     SDValue ScalarBO = DAG.getNode(Opcode, DL, EltVT, Elt, N->getFlags());
27086     if (VT.isScalableVector())
27087       return DAG.getSplatVector(VT, DL, ScalarBO);
27088     SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), ScalarBO);
27089     return DAG.getBuildVector(VT, DL, Ops);
27090   }
27091 
27092   return SDValue();
27093 }
27094 
27095 /// Visit a binary vector operation, like ADD.
27096 SDValue DAGCombiner::SimplifyVBinOp(SDNode *N, const SDLoc &DL) {
27097   EVT VT = N->getValueType(0);
27098   assert(VT.isVector() && "SimplifyVBinOp only works on vectors!");
27099 
27100   SDValue LHS = N->getOperand(0);
27101   SDValue RHS = N->getOperand(1);
27102   unsigned Opcode = N->getOpcode();
27103   SDNodeFlags Flags = N->getFlags();
27104 
27105   // Move unary shuffles with identical masks after a vector binop:
27106   // VBinOp (shuffle A, Undef, Mask), (shuffle B, Undef, Mask))
27107   //   --> shuffle (VBinOp A, B), Undef, Mask
27108   // This does not require type legality checks because we are creating the
27109   // same types of operations that are in the original sequence. We do have to
27110   // restrict ops like integer div that have immediate UB (e.g., div-by-zero)
27111   // though. This code is adapted from the identical transform in instcombine.
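  // e.g. (add (shuffle A, undef, <1,0>), (shuffle B, undef, <1,0>))
  //        --> (shuffle (add A, B), undef, <1,0>)  (illustrative mask).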
27112   if (DAG.isSafeToSpeculativelyExecute(Opcode)) {
27113     auto *Shuf0 = dyn_cast<ShuffleVectorSDNode>(LHS);
27114     auto *Shuf1 = dyn_cast<ShuffleVectorSDNode>(RHS);
27115     if (Shuf0 && Shuf1 && Shuf0->getMask().equals(Shuf1->getMask()) &&
27116         LHS.getOperand(1).isUndef() && RHS.getOperand(1).isUndef() &&
27117         (LHS.hasOneUse() || RHS.hasOneUse() || LHS == RHS)) {
27118       SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, LHS.getOperand(0),
27119                                      RHS.getOperand(0), Flags);
27120       SDValue UndefV = LHS.getOperand(1);
27121       return DAG.getVectorShuffle(VT, DL, NewBinOp, UndefV, Shuf0->getMask());
27122     }
27123 
27124     // Try to sink a splat shuffle after a binop with a uniform constant.
27125     // This is limited to cases where neither the shuffle nor the constant have
27126     // undefined elements because that could be poison-unsafe or inhibit
27127     // demanded elements analysis. It is further limited to not change a splat
27128     // of an inserted scalar because that may be optimized better by
27129     // load-folding or other target-specific behaviors.
27130     if (isConstOrConstSplat(RHS) && Shuf0 && all_equal(Shuf0->getMask()) &&
27131         Shuf0->hasOneUse() && Shuf0->getOperand(1).isUndef() &&
27132         Shuf0->getOperand(0).getOpcode() != ISD::INSERT_VECTOR_ELT) {
27133       // binop (splat X), (splat C) --> splat (binop X, C)
27134       SDValue X = Shuf0->getOperand(0);
27135       SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, X, RHS, Flags);
27136       return DAG.getVectorShuffle(VT, DL, NewBinOp, DAG.getUNDEF(VT),
27137                                   Shuf0->getMask());
27138     }
27139     if (isConstOrConstSplat(LHS) && Shuf1 && all_equal(Shuf1->getMask()) &&
27140         Shuf1->hasOneUse() && Shuf1->getOperand(1).isUndef() &&
27141         Shuf1->getOperand(0).getOpcode() != ISD::INSERT_VECTOR_ELT) {
27142       // binop (splat C), (splat X) --> splat (binop C, X)
27143       SDValue X = Shuf1->getOperand(0);
27144       SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, LHS, X, Flags);
27145       return DAG.getVectorShuffle(VT, DL, NewBinOp, DAG.getUNDEF(VT),
27146                                   Shuf1->getMask());
27147     }
27148   }
27149 
27150   // The following pattern is likely to emerge with vector reduction ops. Moving
27151   // the binary operation ahead of insertion may allow using a narrower vector
27152   // instruction that has better performance than the wide version of the op:
27153   // VBinOp (ins undef, X, Z), (ins undef, Y, Z) --> ins VecC, (VBinOp X, Y), Z
27154   if (LHS.getOpcode() == ISD::INSERT_SUBVECTOR && LHS.getOperand(0).isUndef() &&
27155       RHS.getOpcode() == ISD::INSERT_SUBVECTOR && RHS.getOperand(0).isUndef() &&
27156       LHS.getOperand(2) == RHS.getOperand(2) &&
27157       (LHS.hasOneUse() || RHS.hasOneUse())) {
27158     SDValue X = LHS.getOperand(1);
27159     SDValue Y = RHS.getOperand(1);
27160     SDValue Z = LHS.getOperand(2);
27161     EVT NarrowVT = X.getValueType();
27162     if (NarrowVT == Y.getValueType() &&
27163         TLI.isOperationLegalOrCustomOrPromote(Opcode, NarrowVT,
27164                                               LegalOperations)) {
27165       // (binop undef, undef) may not return undef, so compute that result.
27166       SDValue VecC =
27167           DAG.getNode(Opcode, DL, VT, DAG.getUNDEF(VT), DAG.getUNDEF(VT));
27168       SDValue NarrowBO = DAG.getNode(Opcode, DL, NarrowVT, X, Y);
27169       return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, VecC, NarrowBO, Z);
27170     }
27171   }
27172 
27173   // Make sure all but the first op are undef or constant.
27174   auto ConcatWithConstantOrUndef = [](SDValue Concat) {
27175     return Concat.getOpcode() == ISD::CONCAT_VECTORS &&
27176            all_of(drop_begin(Concat->ops()), [](const SDValue &Op) {
27177              return Op.isUndef() ||
27178                     ISD::isBuildVectorOfConstantSDNodes(Op.getNode());
27179            });
27180   };
27181 
27182   // The following pattern is likely to emerge with vector reduction ops. Moving
27183   // the binary operation ahead of the concat may allow using a narrower vector
27184   // instruction that has better performance than the wide version of the op:
27185   // VBinOp (concat X, undef/constant), (concat Y, undef/constant) -->
27186   //   concat (VBinOp X, Y), VecC
27187   if (ConcatWithConstantOrUndef(LHS) && ConcatWithConstantOrUndef(RHS) &&
27188       (LHS.hasOneUse() || RHS.hasOneUse())) {
27189     EVT NarrowVT = LHS.getOperand(0).getValueType();
27190     if (NarrowVT == RHS.getOperand(0).getValueType() &&
27191         TLI.isOperationLegalOrCustomOrPromote(Opcode, NarrowVT)) {
27192       unsigned NumOperands = LHS.getNumOperands();
27193       SmallVector<SDValue, 4> ConcatOps;
27194       for (unsigned i = 0; i != NumOperands; ++i) {
27195         // This constant folds for operands 1 and up.
27196         ConcatOps.push_back(DAG.getNode(Opcode, DL, NarrowVT, LHS.getOperand(i),
27197                                         RHS.getOperand(i)));
27198       }
27199 
27200       return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps);
27201     }
27202   }
27203 
27204   if (SDValue V = scalarizeBinOpOfSplats(N, DAG, DL))
27205     return V;
27206 
27207   return SDValue();
27208 }
27209 
27210 SDValue DAGCombiner::SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1,
27211                                     SDValue N2) {
27212   assert(N0.getOpcode() == ISD::SETCC &&
27213          "First argument must be a SetCC node!");
27214 
27215   SDValue SCC = SimplifySelectCC(DL, N0.getOperand(0), N0.getOperand(1), N1, N2,
27216                                  cast<CondCodeSDNode>(N0.getOperand(2))->get());
27217 
27218   // If we got a simplified select_cc node back from SimplifySelectCC, then
27219   // break it down into a new SETCC node, and a new SELECT node, and then return
27220   // the SELECT node, since we were called with a SELECT node.
27221   if (SCC.getNode()) {
27222     // Check to see if we got a select_cc back (to turn into setcc/select).
27223     // Otherwise, just return whatever node we got back, like fabs.
27224     if (SCC.getOpcode() == ISD::SELECT_CC) {
27225       const SDNodeFlags Flags = N0->getFlags();
27226       SDValue SETCC = DAG.getNode(ISD::SETCC, SDLoc(N0),
27227                                   N0.getValueType(),
27228                                   SCC.getOperand(0), SCC.getOperand(1),
27229                                   SCC.getOperand(4), Flags);
27230       AddToWorklist(SETCC.getNode());
27231       SDValue SelectNode = DAG.getSelect(SDLoc(SCC), SCC.getValueType(), SETCC,
27232                                          SCC.getOperand(2), SCC.getOperand(3));
27233       SelectNode->setFlags(Flags);
27234       return SelectNode;
27235     }
27236 
27237     return SCC;
27238   }
27239   return SDValue();
27240 }
27241 
27242 /// Given a SELECT or a SELECT_CC node, where LHS and RHS are the two values
27243 /// being selected between, see if we can simplify the select.  Callers of this
27244 /// should assume that TheSelect is deleted if this returns true.  As such, they
27245 /// should return the appropriate thing (e.g. the node) back to the top-level of
27246 /// the DAG combiner loop to avoid it being looked at.
27247 bool DAGCombiner::SimplifySelectOps(SDNode *TheSelect, SDValue LHS,
27248                                     SDValue RHS) {
27249   // fold (select (setcc x, [+-]0.0, *lt), NaN, (fsqrt x))
27250   // The select + setcc is redundant, because fsqrt returns NaN for X < 0.
27251   if (const ConstantFPSDNode *NaN = isConstOrConstSplatFP(LHS)) {
27252     if (NaN->isNaN() && RHS.getOpcode() == ISD::FSQRT) {
27253       // We have: (select (setcc ?, ?, ?), NaN, (fsqrt ?))
27254       SDValue Sqrt = RHS;
27255       ISD::CondCode CC;
27256       SDValue CmpLHS;
27257       const ConstantFPSDNode *Zero = nullptr;
27258 
27259       if (TheSelect->getOpcode() == ISD::SELECT_CC) {
27260         CC = cast<CondCodeSDNode>(TheSelect->getOperand(4))->get();
27261         CmpLHS = TheSelect->getOperand(0);
27262         Zero = isConstOrConstSplatFP(TheSelect->getOperand(1));
27263       } else {
27264         // SELECT or VSELECT
27265         SDValue Cmp = TheSelect->getOperand(0);
27266         if (Cmp.getOpcode() == ISD::SETCC) {
27267           CC = cast<CondCodeSDNode>(Cmp.getOperand(2))->get();
27268           CmpLHS = Cmp.getOperand(0);
27269           Zero = isConstOrConstSplatFP(Cmp.getOperand(1));
27270         }
27271       }
27272       if (Zero && Zero->isZero() &&
27273           Sqrt.getOperand(0) == CmpLHS && (CC == ISD::SETOLT ||
27274           CC == ISD::SETULT || CC == ISD::SETLT)) {
27275         // We have: (select (setcc x, [+-]0.0, *lt), NaN, (fsqrt x))
27276         CombineTo(TheSelect, Sqrt);
27277         return true;
27278       }
27279     }
27280   }
27281   // Cannot simplify select with vector condition
27282   if (TheSelect->getOperand(0).getValueType().isVector()) return false;
27283 
27284   // If this is a select from two identical things, try to pull the operation
27285   // through the select.
27286   if (LHS.getOpcode() != RHS.getOpcode() ||
27287       !LHS.hasOneUse() || !RHS.hasOneUse())
27288     return false;
27289 
27290   // If this is a load and the token chain is identical, replace the select
27291   // of two loads with a load through a select of the address to load from.
27292   // This triggers in things like "select bool X, 10.0, 123.0" after the FP
27293   // constants have been dropped into the constant pool.
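  // For example (an illustrative sketch; Ch is the shared token chain):
  //   (select C, (load Ch, A), (load Ch, B)) -> (load Ch, (select C, A, B))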
27294   if (LHS.getOpcode() == ISD::LOAD) {
27295     LoadSDNode *LLD = cast<LoadSDNode>(LHS);
27296     LoadSDNode *RLD = cast<LoadSDNode>(RHS);
27297 
27298     // Token chains must be identical.
27299     if (LHS.getOperand(0) != RHS.getOperand(0) ||
27300         // Do not let this transformation reduce the number of volatile loads.
27301         // Be conservative for atomics for the moment
27302         // TODO: This does appear to be legal for unordered atomics (see D66309)
27303         !LLD->isSimple() || !RLD->isSimple() ||
27304         // FIXME: If either is a pre/post inc/dec load,
27305         // we'd need to split out the address adjustment.
27306         LLD->isIndexed() || RLD->isIndexed() ||
27307         // If this is an EXTLOAD, the VT's must match.
27308         LLD->getMemoryVT() != RLD->getMemoryVT() ||
27309         // If this is an EXTLOAD, the kind of extension must match.
27310         (LLD->getExtensionType() != RLD->getExtensionType() &&
27311          // The only exception is if one of the extensions is anyext.
27312          LLD->getExtensionType() != ISD::EXTLOAD &&
27313          RLD->getExtensionType() != ISD::EXTLOAD) ||
27314         // FIXME: this discards src value information.  This is
27315         // over-conservative. It would be beneficial to be able to remember
27316         // both potential memory locations.  Since we are discarding
27317         // src value info, don't do the transformation if the memory
27318         // locations are not in the default address space.
27319         LLD->getPointerInfo().getAddrSpace() != 0 ||
27320         RLD->getPointerInfo().getAddrSpace() != 0 ||
27321         // We can't produce a CMOV of a TargetFrameIndex since we won't
27322         // generate the address generation required.
27323         LLD->getBasePtr().getOpcode() == ISD::TargetFrameIndex ||
27324         RLD->getBasePtr().getOpcode() == ISD::TargetFrameIndex ||
27325         !TLI.isOperationLegalOrCustom(TheSelect->getOpcode(),
27326                                       LLD->getBasePtr().getValueType()))
27327       return false;
27328 
27329     // The loads must not depend on one another.
27330     if (LLD->isPredecessorOf(RLD) || RLD->isPredecessorOf(LLD))
27331       return false;
27332 
27333     // Check that the select condition doesn't reach either load.  If so,
27334     // folding this will induce a cycle into the DAG.  If not, this is safe to
27335     // xform, so create a select of the addresses.
27336 
27337     SmallPtrSet<const SDNode *, 32> Visited;
27338     SmallVector<const SDNode *, 16> Worklist;
27339 
27340     // Always fail if LLD and RLD are not independent. TheSelect is a
27341     // predecessor to all Nodes in question so we need not search past it.
27342 
27343     Visited.insert(TheSelect);
27344     Worklist.push_back(LLD);
27345     Worklist.push_back(RLD);
27346 
27347     if (SDNode::hasPredecessorHelper(LLD, Visited, Worklist) ||
27348         SDNode::hasPredecessorHelper(RLD, Visited, Worklist))
27349       return false;
27350 
27351     SDValue Addr;
27352     if (TheSelect->getOpcode() == ISD::SELECT) {
27353       // We cannot do this optimization if either of {RLD, LLD} is a
27354       // predecessor of any node in {RLD, LLD, CondNode}. As we've already
27355       // compared the loads, we only need to check whether CondNode is a
27356       // successor of one of the loads. We can further avoid this if there's
27357       // no use of their chain value.
27358       SDNode *CondNode = TheSelect->getOperand(0).getNode();
27359       Worklist.push_back(CondNode);
27360 
27361       if ((LLD->hasAnyUseOfValue(1) &&
27362            SDNode::hasPredecessorHelper(LLD, Visited, Worklist)) ||
27363           (RLD->hasAnyUseOfValue(1) &&
27364            SDNode::hasPredecessorHelper(RLD, Visited, Worklist)))
27365         return false;
27366 
27367       Addr = DAG.getSelect(SDLoc(TheSelect),
27368                            LLD->getBasePtr().getValueType(),
27369                            TheSelect->getOperand(0), LLD->getBasePtr(),
27370                            RLD->getBasePtr());
27371     } else {  // Otherwise SELECT_CC
27372       // We cannot do this optimization if either of {RLD, LLD} is a
27373       // predecessor of any node in {RLD, LLD, CondLHS, CondRHS}. As we've
27374       // already compared the loads, we only need to check whether CondLHS or
27375       // CondRHS is a successor of one of the loads. We can further avoid this
27376       // if there's no use of their chain value.
27377 
27378       SDNode *CondLHS = TheSelect->getOperand(0).getNode();
27379       SDNode *CondRHS = TheSelect->getOperand(1).getNode();
27380       Worklist.push_back(CondLHS);
27381       Worklist.push_back(CondRHS);
27382 
27383       if ((LLD->hasAnyUseOfValue(1) &&
27384            SDNode::hasPredecessorHelper(LLD, Visited, Worklist)) ||
27385           (RLD->hasAnyUseOfValue(1) &&
27386            SDNode::hasPredecessorHelper(RLD, Visited, Worklist)))
27387         return false;
27388 
27389       Addr = DAG.getNode(ISD::SELECT_CC, SDLoc(TheSelect),
27390                          LLD->getBasePtr().getValueType(),
27391                          TheSelect->getOperand(0),
27392                          TheSelect->getOperand(1),
27393                          LLD->getBasePtr(), RLD->getBasePtr(),
27394                          TheSelect->getOperand(4));
27395     }
27396 
27397     SDValue Load;
27398     // It is safe to replace the two loads if they have different alignments,
27399     // but the new load must be the minimum (most restrictive) alignment of the
27400     // inputs.
27401     Align Alignment = std::min(LLD->getAlign(), RLD->getAlign());
27402     MachineMemOperand::Flags MMOFlags = LLD->getMemOperand()->getFlags();
27403     if (!RLD->isInvariant())
27404       MMOFlags &= ~MachineMemOperand::MOInvariant;
27405     if (!RLD->isDereferenceable())
27406       MMOFlags &= ~MachineMemOperand::MODereferenceable;
27407     if (LLD->getExtensionType() == ISD::NON_EXTLOAD) {
27408       // FIXME: Discards pointer and AA info.
27409       Load = DAG.getLoad(TheSelect->getValueType(0), SDLoc(TheSelect),
27410                          LLD->getChain(), Addr, MachinePointerInfo(), Alignment,
27411                          MMOFlags);
27412     } else {
27413       // FIXME: Discards pointer and AA info.
27414       Load = DAG.getExtLoad(
27415           LLD->getExtensionType() == ISD::EXTLOAD ? RLD->getExtensionType()
27416                                                   : LLD->getExtensionType(),
27417           SDLoc(TheSelect), TheSelect->getValueType(0), LLD->getChain(), Addr,
27418           MachinePointerInfo(), LLD->getMemoryVT(), Alignment, MMOFlags);
27419     }
27420 
27421     // Users of the select now use the result of the load.
27422     CombineTo(TheSelect, Load);
27423 
27424     // Users of the old loads now use the new load's chain.  We know the
27425     // old-load value is dead now.
27426     CombineTo(LHS.getNode(), Load.getValue(0), Load.getValue(1));
27427     CombineTo(RHS.getNode(), Load.getValue(0), Load.getValue(1));
27428     return true;
27429   }
27430 
27431   return false;
27432 }
27433 
27434 /// Try to fold an expression of the form (N0 cond N1) ? N2 : N3 to a shift and
27435 /// bitwise 'and'.
27436 SDValue DAGCombiner::foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0,
27437                                             SDValue N1, SDValue N2, SDValue N3,
27438                                             ISD::CondCode CC) {
27439   // If this is a select where the false operand is zero and the compare is a
27440   // check of the sign bit, see if we can perform the "gzip trick":
27441   // select_cc setlt X, 0, A, 0 -> and (sra X, size(X)-1), A
27442   // select_cc setgt X, 0, A, 0 -> and (not (sra X, size(X)-1)), A
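  // For example, with i32 X: (sra X, 31) is all-ones when X < 0 and zero
  // otherwise, so the AND produces A exactly when the setlt condition holds.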
27443   EVT XType = N0.getValueType();
27444   EVT AType = N2.getValueType();
27445   if (!isNullConstant(N3) || !XType.bitsGE(AType))
27446     return SDValue();
27447 
27448   // If the comparison is testing for a positive value, we have to invert
27449   // the sign bit mask, so only do that transform if the target has a bitwise
27450   // 'and not' instruction (the invert is free).
27451   if (CC == ISD::SETGT && TLI.hasAndNot(N2)) {
27452     // (X > -1) ? A : 0
27453     // (X >  0) ? X : 0 <-- This is canonical signed max.
27454     if (!(isAllOnesConstant(N1) || (isNullConstant(N1) && N0 == N2)))
27455       return SDValue();
27456   } else if (CC == ISD::SETLT) {
27457     // (X <  0) ? A : 0
27458     // (X <  1) ? X : 0 <-- This is un-canonicalized signed min.
27459     if (!(isNullConstant(N1) || (isOneConstant(N1) && N0 == N2)))
27460       return SDValue();
27461   } else {
27462     return SDValue();
27463   }
27464 
27465   // and (sra X, size(X)-1), A -> "and (srl X, C2), A" iff A is a single-bit
27466   // constant.
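  // For example (an i32 sketch with A == 8, i.e. only bit 3 set):
  //   ShCt = 32 - 3 - 1 = 28, so (srl X, 28) lands the sign bit in bit 3 and
  //   the AND with 8 keeps just that bit.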
27467   auto *N2C = dyn_cast<ConstantSDNode>(N2.getNode());
27468   if (N2C && ((N2C->getAPIntValue() & (N2C->getAPIntValue() - 1)) == 0)) {
27469     unsigned ShCt = XType.getSizeInBits() - N2C->getAPIntValue().logBase2() - 1;
27470     if (!TLI.shouldAvoidTransformToShift(XType, ShCt)) {
27471       SDValue ShiftAmt = DAG.getShiftAmountConstant(ShCt, XType, DL);
27472       SDValue Shift = DAG.getNode(ISD::SRL, DL, XType, N0, ShiftAmt);
27473       AddToWorklist(Shift.getNode());
27474 
27475       if (XType.bitsGT(AType)) {
27476         Shift = DAG.getNode(ISD::TRUNCATE, DL, AType, Shift);
27477         AddToWorklist(Shift.getNode());
27478       }
27479 
27480       if (CC == ISD::SETGT)
27481         Shift = DAG.getNOT(DL, Shift, AType);
27482 
27483       return DAG.getNode(ISD::AND, DL, AType, Shift, N2);
27484     }
27485   }
27486 
27487   unsigned ShCt = XType.getSizeInBits() - 1;
27488   if (TLI.shouldAvoidTransformToShift(XType, ShCt))
27489     return SDValue();
27490 
27491   SDValue ShiftAmt = DAG.getShiftAmountConstant(ShCt, XType, DL);
27492   SDValue Shift = DAG.getNode(ISD::SRA, DL, XType, N0, ShiftAmt);
27493   AddToWorklist(Shift.getNode());
27494 
27495   if (XType.bitsGT(AType)) {
27496     Shift = DAG.getNode(ISD::TRUNCATE, DL, AType, Shift);
27497     AddToWorklist(Shift.getNode());
27498   }
27499 
27500   if (CC == ISD::SETGT)
27501     Shift = DAG.getNOT(DL, Shift, AType);
27502 
27503   return DAG.getNode(ISD::AND, DL, AType, Shift, N2);
27504 }
27505 
27506 // Fold select(cc, binop(), binop()) -> binop(select(), select()) etc.
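// For example (a sketch of the two shapes handled below):
//   select(C, add(X, Y), add(Z, Y)) -> add(select(C, X, Z), Y)
//   select(C, shl(X, Y), shl(X, Z)) -> shl(X, select(C, Y, Z))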
27507 SDValue DAGCombiner::foldSelectOfBinops(SDNode *N) {
27508   SDValue N0 = N->getOperand(0);
27509   SDValue N1 = N->getOperand(1);
27510   SDValue N2 = N->getOperand(2);
27511   SDLoc DL(N);
27512 
27513   unsigned BinOpc = N1.getOpcode();
27514   if (!TLI.isBinOp(BinOpc) || (N2.getOpcode() != BinOpc) ||
27515       (N1.getResNo() != N2.getResNo()))
27516     return SDValue();
27517 
27518   // The use checks are intentionally on SDNode because we may be dealing
27519   // with opcodes that produce more than one SDValue.
27520   // TODO: Do we really need to check N0 (the condition operand of the select)?
27521   //       But removing that clause could cause an infinite loop...
27522   if (!N0->hasOneUse() || !N1->hasOneUse() || !N2->hasOneUse())
27523     return SDValue();
27524 
27525   // Binops may include opcodes that return multiple values, so all values
27526   // must be created/propagated from the newly created binops below.
27527   SDVTList OpVTs = N1->getVTList();
27528 
27529   // Fold select(cond, binop(x, y), binop(z, y))
27530   //  --> binop(select(cond, x, z), y)
27531   if (N1.getOperand(1) == N2.getOperand(1)) {
27532     SDValue N10 = N1.getOperand(0);
27533     SDValue N20 = N2.getOperand(0);
27534     SDValue NewSel = DAG.getSelect(DL, N10.getValueType(), N0, N10, N20);
27535     SDValue NewBinOp = DAG.getNode(BinOpc, DL, OpVTs, NewSel, N1.getOperand(1));
27536     NewBinOp->setFlags(N1->getFlags());
27537     NewBinOp->intersectFlagsWith(N2->getFlags());
27538     return SDValue(NewBinOp.getNode(), N1.getResNo());
27539   }
27540 
27541   // Fold select(cond, binop(x, y), binop(x, z))
27542   //  --> binop(x, select(cond, y, z))
27543   if (N1.getOperand(0) == N2.getOperand(0)) {
27544     SDValue N11 = N1.getOperand(1);
27545     SDValue N21 = N2.getOperand(1);
27546     // Second op VT might be different (e.g. shift amount type)
27547     if (N11.getValueType() == N21.getValueType()) {
27548       SDValue NewSel = DAG.getSelect(DL, N11.getValueType(), N0, N11, N21);
27549       SDValue NewBinOp =
27550           DAG.getNode(BinOpc, DL, OpVTs, N1.getOperand(0), NewSel);
27551       NewBinOp->setFlags(N1->getFlags());
27552       NewBinOp->intersectFlagsWith(N2->getFlags());
27553       return SDValue(NewBinOp.getNode(), N1.getResNo());
27554     }
27555   }
27556 
27557   // TODO: Handle isCommutativeBinOp patterns as well?
27558   return SDValue();
27559 }
27560 
27561 // Transform (fneg/fabs (bitconvert x)) to avoid loading constant pool values.
27562 SDValue DAGCombiner::foldSignChangeInBitcast(SDNode *N) {
27563   SDValue N0 = N->getOperand(0);
27564   EVT VT = N->getValueType(0);
27565   bool IsFabs = N->getOpcode() == ISD::FABS;
27566   bool IsFree = IsFabs ? TLI.isFAbsFree(VT) : TLI.isFNegFree(VT);
27567 
27568   if (IsFree || N0.getOpcode() != ISD::BITCAST || !N0.hasOneUse())
27569     return SDValue();
27570 
27571   SDValue Int = N0.getOperand(0);
27572   EVT IntVT = Int.getValueType();
27573 
27574   // The operand to cast should be integer.
27575   if (!IntVT.isInteger() || IntVT.isVector())
27576     return SDValue();
27577 
27578   // (fneg (bitconvert x)) -> (bitconvert (xor x sign))
27579   // (fabs (bitconvert x)) -> (bitconvert (and x ~sign))
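  // For example, for a scalar f32 bitcast from i32, the masks built below are:
  //   fneg: (xor x, 0x80000000)    fabs: (and x, 0x7FFFFFFF)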
27580   APInt SignMask;
27581   if (N0.getValueType().isVector()) {
27582     // For vector, create a sign mask (0x80...) or its inverse (for fabs,
27583     // 0x7f...) per element and splat it.
27584     SignMask = APInt::getSignMask(N0.getScalarValueSizeInBits());
27585     if (IsFabs)
27586       SignMask = ~SignMask;
27587     SignMask = APInt::getSplat(IntVT.getSizeInBits(), SignMask);
27588   } else {
27589     // For scalar, just use the sign mask (0x80... or the inverse, 0x7f...)
27590     SignMask = APInt::getSignMask(IntVT.getSizeInBits());
27591     if (IsFabs)
27592       SignMask = ~SignMask;
27593   }
27594   SDLoc DL(N0);
27595   Int = DAG.getNode(IsFabs ? ISD::AND : ISD::XOR, DL, IntVT, Int,
27596                     DAG.getConstant(SignMask, DL, IntVT));
27597   AddToWorklist(Int.getNode());
27598   return DAG.getBitcast(VT, Int);
27599 }
27600 
27601 /// Turn "(a cond b) ? 1.0f : 2.0f" into "load (tmp + ((a cond b) ? 0 : 4))"
27602 /// where "tmp" is a constant pool entry containing an array with 1.0 and 2.0
27603 /// in it. This may be a win when the constant is not otherwise available
27604 /// because it replaces two constant pool loads with one.
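/// In the code below the pool entry is laid out as {FV, TV}, so the select
/// produces an offset of EltSize (e.g. 4 for f32) when the condition is true,
/// picking TV.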
27605 SDValue DAGCombiner::convertSelectOfFPConstantsToLoadOffset(
27606     const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2, SDValue N3,
27607     ISD::CondCode CC) {
27608   if (!TLI.reduceSelectOfFPConstantLoads(N0.getValueType()))
27609     return SDValue();
27610 
27611   // If we are before legalize types, we want the other legalization to happen
27612   // first (for example, to avoid messing with soft float).
27613   auto *TV = dyn_cast<ConstantFPSDNode>(N2);
27614   auto *FV = dyn_cast<ConstantFPSDNode>(N3);
27615   EVT VT = N2.getValueType();
27616   if (!TV || !FV || !TLI.isTypeLegal(VT))
27617     return SDValue();
27618 
27619   // If a constant can be materialized without loads, this does not make sense.
27620   if (TLI.getOperationAction(ISD::ConstantFP, VT) == TargetLowering::Legal ||
27621       TLI.isFPImmLegal(TV->getValueAPF(), TV->getValueType(0), ForCodeSize) ||
27622       TLI.isFPImmLegal(FV->getValueAPF(), FV->getValueType(0), ForCodeSize))
27623     return SDValue();
27624 
27625   // If both constants have multiple uses, then we won't need to do an extra
27626   // load. The values are likely around in registers for other users.
27627   if (!TV->hasOneUse() && !FV->hasOneUse())
27628     return SDValue();
27629 
27630   Constant *Elts[] = { const_cast<ConstantFP*>(FV->getConstantFPValue()),
27631                        const_cast<ConstantFP*>(TV->getConstantFPValue()) };
27632   Type *FPTy = Elts[0]->getType();
27633   const DataLayout &TD = DAG.getDataLayout();
27634 
27635   // Create a ConstantArray of the two constants.
27636   Constant *CA = ConstantArray::get(ArrayType::get(FPTy, 2), Elts);
27637   SDValue CPIdx = DAG.getConstantPool(CA, TLI.getPointerTy(DAG.getDataLayout()),
27638                                       TD.getPrefTypeAlign(FPTy));
27639   Align Alignment = cast<ConstantPoolSDNode>(CPIdx)->getAlign();
27640 
27641   // Get offsets to the 0 and 1 elements of the array, so we can select between
27642   // them.
27643   SDValue Zero = DAG.getIntPtrConstant(0, DL);
27644   unsigned EltSize = (unsigned)TD.getTypeAllocSize(Elts[0]->getType());
27645   SDValue One = DAG.getIntPtrConstant(EltSize, SDLoc(FV));
27646   SDValue Cond =
27647       DAG.getSetCC(DL, getSetCCResultType(N0.getValueType()), N0, N1, CC);
27648   AddToWorklist(Cond.getNode());
27649   SDValue CstOffset = DAG.getSelect(DL, Zero.getValueType(), Cond, One, Zero);
27650   AddToWorklist(CstOffset.getNode());
27651   CPIdx = DAG.getNode(ISD::ADD, DL, CPIdx.getValueType(), CPIdx, CstOffset);
27652   AddToWorklist(CPIdx.getNode());
27653   return DAG.getLoad(TV->getValueType(0), DL, DAG.getEntryNode(), CPIdx,
27654                      MachinePointerInfo::getConstantPool(
27655                          DAG.getMachineFunction()), Alignment);
27656 }
27657 
27658 /// Simplify an expression of the form (N0 cond N1) ? N2 : N3
27659 /// where 'cond' is the comparison specified by CC.
27660 SDValue DAGCombiner::SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
27661                                       SDValue N2, SDValue N3, ISD::CondCode CC,
27662                                       bool NotExtCompare) {
27663   // (x ? y : y) -> y.
27664   if (N2 == N3) return N2;
27665 
27666   EVT CmpOpVT = N0.getValueType();
27667   EVT CmpResVT = getSetCCResultType(CmpOpVT);
27668   EVT VT = N2.getValueType();
27669   auto *N1C = dyn_cast<ConstantSDNode>(N1.getNode());
27670   auto *N2C = dyn_cast<ConstantSDNode>(N2.getNode());
27671   auto *N3C = dyn_cast<ConstantSDNode>(N3.getNode());
27672 
27673   // Determine if the condition we're dealing with is constant.
27674   if (SDValue SCC = DAG.FoldSetCC(CmpResVT, N0, N1, CC, DL)) {
27675     AddToWorklist(SCC.getNode());
27676     if (auto *SCCC = dyn_cast<ConstantSDNode>(SCC)) {
27677       // fold select_cc true, x, y -> x
27678       // fold select_cc false, x, y -> y
27679       return !(SCCC->isZero()) ? N2 : N3;
27680     }
27681   }
27682 
27683   if (SDValue V =
27684           convertSelectOfFPConstantsToLoadOffset(DL, N0, N1, N2, N3, CC))
27685     return V;
27686 
27687   if (SDValue V = foldSelectCCToShiftAnd(DL, N0, N1, N2, N3, CC))
27688     return V;
27689 
27690   // fold (select_cc seteq (and x, y), 0, 0, A) -> (and (sra (shl x)) A)
27691   // where y has a single bit set.
27692   // In plain terms, we can turn the SELECT_CC into an AND
27693   // when the condition can be materialized as an all-ones register.  Any
27694   // single bit-test can be materialized as an all-ones register with
27695   // shift-left and shift-right-arith.
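  // For example (an i32 sketch with y == 0x10): shl by countl_zero(y) == 27
  // moves bit 4 to bit 31, sra by 31 smears it, so the mask is all-ones iff
  // (x & 0x10) != 0, and the AND then yields A exactly when the seteq
  // condition is false.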
27696   if (CC == ISD::SETEQ && N0->getOpcode() == ISD::AND &&
27697       N0->getValueType(0) == VT && isNullConstant(N1) && isNullConstant(N2)) {
27698     SDValue AndLHS = N0->getOperand(0);
27699     auto *ConstAndRHS = dyn_cast<ConstantSDNode>(N0->getOperand(1));
27700     if (ConstAndRHS && ConstAndRHS->getAPIntValue().popcount() == 1) {
27701       // Shift the tested bit over the sign bit.
27702       const APInt &AndMask = ConstAndRHS->getAPIntValue();
27703       if (TLI.shouldFoldSelectWithSingleBitTest(VT, AndMask)) {
27704         unsigned ShCt = AndMask.getBitWidth() - 1;
27705         SDValue ShlAmt = DAG.getShiftAmountConstant(AndMask.countl_zero(), VT,
27706                                                     SDLoc(AndLHS));
27707         SDValue Shl = DAG.getNode(ISD::SHL, SDLoc(N0), VT, AndLHS, ShlAmt);
27708 
27709         // Now arithmetic right shift it all the way over, so the result is
27710         // either all-ones, or zero.
27711         SDValue ShrAmt = DAG.getShiftAmountConstant(ShCt, VT, SDLoc(Shl));
27712         SDValue Shr = DAG.getNode(ISD::SRA, SDLoc(N0), VT, Shl, ShrAmt);
27713 
27714         return DAG.getNode(ISD::AND, DL, VT, Shr, N3);
27715       }
27716     }
27717   }
27718 
27719   // fold select C, 16, 0 -> shl C, 4
27720   bool Fold = N2C && isNullConstant(N3) && N2C->getAPIntValue().isPowerOf2();
27721   bool Swap = N3C && isNullConstant(N2) && N3C->getAPIntValue().isPowerOf2();
27722 
27723   if ((Fold || Swap) &&
27724       TLI.getBooleanContents(CmpOpVT) ==
27725           TargetLowering::ZeroOrOneBooleanContent &&
27726       (!LegalOperations || TLI.isOperationLegal(ISD::SETCC, CmpOpVT))) {
27727 
27728     if (Swap) {
27729       CC = ISD::getSetCCInverse(CC, CmpOpVT);
27730       std::swap(N2C, N3C);
27731     }
27732 
27733     // If the caller doesn't want us to simplify this into a zext of a compare,
27734     // don't do it.
27735     if (NotExtCompare && N2C->isOne())
27736       return SDValue();
27737 
27738     SDValue Temp, SCC;
27739     // zext (setcc n0, n1)
27740     if (LegalTypes) {
27741       SCC = DAG.getSetCC(DL, CmpResVT, N0, N1, CC);
27742       Temp = DAG.getZExtOrTrunc(SCC, SDLoc(N2), VT);
27743     } else {
27744       SCC = DAG.getSetCC(SDLoc(N0), MVT::i1, N0, N1, CC);
27745       Temp = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N2), VT, SCC);
27746     }
27747 
27748     AddToWorklist(SCC.getNode());
27749     AddToWorklist(Temp.getNode());
27750 
27751     if (N2C->isOne())
27752       return Temp;
27753 
27754     unsigned ShCt = N2C->getAPIntValue().logBase2();
27755     if (TLI.shouldAvoidTransformToShift(VT, ShCt))
27756       return SDValue();
27757 
27758     // shl setcc result by log2 n2c
27759     return DAG.getNode(
27760         ISD::SHL, DL, N2.getValueType(), Temp,
27761         DAG.getShiftAmountConstant(ShCt, N2.getValueType(), SDLoc(Temp)));
27762   }
27763 
27764   // select_cc seteq X, 0, sizeof(X), ctlz(X) -> ctlz(X)
27765   // select_cc seteq X, 0, sizeof(X), ctlz_zero_undef(X) -> ctlz(X)
27766   // select_cc seteq X, 0, sizeof(X), cttz(X) -> cttz(X)
27767   // select_cc seteq X, 0, sizeof(X), cttz_zero_undef(X) -> cttz(X)
27768   // select_cc setne X, 0, ctlz(X), sizeof(X) -> ctlz(X)
27769   // select_cc setne X, 0, ctlz_zero_undef(X), sizeof(X) -> ctlz(X)
27770   // select_cc setne X, 0, cttz(X), sizeof(X) -> cttz(X)
27771   // select_cc setne X, 0, cttz_zero_undef(X), sizeof(X) -> cttz(X)
27772   if (N1C && N1C->isZero() && (CC == ISD::SETEQ || CC == ISD::SETNE)) {
27773     SDValue ValueOnZero = N2;
27774     SDValue Count = N3;
27775     // If the condition is NE instead of E, swap the operands.
27776     if (CC == ISD::SETNE)
27777       std::swap(ValueOnZero, Count);
27778     // Check if the value on zero is a constant equal to the bits in the type.
27779     if (auto *ValueOnZeroC = dyn_cast<ConstantSDNode>(ValueOnZero)) {
27780       if (ValueOnZeroC->getAPIntValue() == VT.getSizeInBits()) {
27781         // If the other operand is cttz/cttz_zero_undef of N0, and cttz is
27782         // legal, combine to just cttz.
27783         if ((Count.getOpcode() == ISD::CTTZ ||
27784              Count.getOpcode() == ISD::CTTZ_ZERO_UNDEF) &&
27785             N0 == Count.getOperand(0) &&
27786             (!LegalOperations || TLI.isOperationLegal(ISD::CTTZ, VT)))
27787           return DAG.getNode(ISD::CTTZ, DL, VT, N0);
27788         // If the other operand is ctlz/ctlz_zero_undef of N0, and ctlz is
27789         // legal, combine to just ctlz.
27790         if ((Count.getOpcode() == ISD::CTLZ ||
27791              Count.getOpcode() == ISD::CTLZ_ZERO_UNDEF) &&
27792             N0 == Count.getOperand(0) &&
27793             (!LegalOperations || TLI.isOperationLegal(ISD::CTLZ, VT)))
27794           return DAG.getNode(ISD::CTLZ, DL, VT, N0);
27795       }
27796     }
27797   }
27798 
27799   // Fold select_cc setgt X, -1, C, ~C -> xor (ashr X, BW-1), C
27800   // Fold select_cc setlt X, 0, C, ~C -> xor (ashr X, BW-1), ~C
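  // For example (i32, C == 5, the setgt case): X >= 0 gives (ashr X, 31) == 0
  // and 0 ^ 5 == 5; X < 0 gives (ashr X, 31) == -1 and -1 ^ 5 == ~5.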
27801   if (!NotExtCompare && N1C && N2C && N3C &&
27802       N2C->getAPIntValue() == ~N3C->getAPIntValue() &&
27803       ((N1C->isAllOnes() && CC == ISD::SETGT) ||
27804        (N1C->isZero() && CC == ISD::SETLT)) &&
27805       !TLI.shouldAvoidTransformToShift(VT, CmpOpVT.getScalarSizeInBits() - 1)) {
27806     SDValue ASR = DAG.getNode(
27807         ISD::SRA, DL, CmpOpVT, N0,
27808         DAG.getConstant(CmpOpVT.getScalarSizeInBits() - 1, DL, CmpOpVT));
27809     return DAG.getNode(ISD::XOR, DL, VT, DAG.getSExtOrTrunc(ASR, DL, VT),
27810                        DAG.getSExtOrTrunc(CC == ISD::SETLT ? N3 : N2, DL, VT));
27811   }
27812 
27813   if (SDValue S = PerformMinMaxFpToSatCombine(N0, N1, N2, N3, CC, DAG))
27814     return S;
27815   if (SDValue S = PerformUMinFpToSatCombine(N0, N1, N2, N3, CC, DAG))
27816     return S;
27817 
27818   return SDValue();
27819 }
27820 
27821 /// This is a stub for TargetLowering::SimplifySetCC.
27822 SDValue DAGCombiner::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
27823                                    ISD::CondCode Cond, const SDLoc &DL,
27824                                    bool foldBooleans) {
27825   TargetLowering::DAGCombinerInfo
27826     DagCombineInfo(DAG, Level, false, this);
27827   return TLI.SimplifySetCC(VT, N0, N1, Cond, foldBooleans, DagCombineInfo, DL);
27828 }
27829 
27830 /// Given an ISD::SDIV node expressing a divide by constant, return
27831 /// a DAG expression to select that will generate the same value by multiplying
27832 /// by a magic number.
27833 /// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
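/// For example (a Hacker's Delight sketch for i32 n / 7, magic M ==
/// 0x92492493, shift 2):
///   q = mulhs(n, M); q = add(q, n); q = sra(q, 2); q = add(q, srl(n, 31));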
27834 SDValue DAGCombiner::BuildSDIV(SDNode *N) {
27835   // When optimising for minimum size, we don't want to expand a div to a mul
27836   // and a shift.
27837   if (DAG.getMachineFunction().getFunction().hasMinSize())
27838     return SDValue();
27839 
27840   SmallVector<SDNode *, 8> Built;
27841   if (SDValue S = TLI.BuildSDIV(N, DAG, LegalOperations, Built)) {
27842     for (SDNode *N : Built)
27843       AddToWorklist(N);
27844     return S;
27845   }
27846 
27847   return SDValue();
27848 }
27849 
27850 /// Given an ISD::SDIV node expressing a divide by constant power of 2, return a
27851 /// DAG expression that will generate the same value by right shifting.
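/// For example (the usual expansion for i32 n / 8): add the bias
/// (srl (sra n, 31), 29), i.e. 7 when n is negative, then sra by 3 so the
/// shift rounds toward zero as sdiv requires.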
27852 SDValue DAGCombiner::BuildSDIVPow2(SDNode *N) {
27853   ConstantSDNode *C = isConstOrConstSplat(N->getOperand(1));
27854   if (!C)
27855     return SDValue();
27856 
27857   // Avoid division by zero.
27858   if (C->isZero())
27859     return SDValue();
27860 
27861   SmallVector<SDNode *, 8> Built;
27862   if (SDValue S = TLI.BuildSDIVPow2(N, C->getAPIntValue(), DAG, Built)) {
27863     for (SDNode *N : Built)
27864       AddToWorklist(N);
27865     return S;
27866   }
27867 
27868   return SDValue();
27869 }
27870 
27871 /// Given an ISD::UDIV node expressing a divide by constant, return a DAG
27872 /// expression that will generate the same value by multiplying by a magic
27873 /// number.
27874 /// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
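/// For example (a Hacker's Delight sketch for i32 n / 7, magic M ==
/// 0x24924925 with the "add" fixup):
///   q = mulhu(n, M); t = srl(sub(n, q), 1); q = srl(add(t, q), 2);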
27875 SDValue DAGCombiner::BuildUDIV(SDNode *N) {
27876   // When optimising for minimum size, we don't want to expand a div to a mul
27877   // and a shift.
27878   if (DAG.getMachineFunction().getFunction().hasMinSize())
27879     return SDValue();
27880 
27881   SmallVector<SDNode *, 8> Built;
27882   if (SDValue S = TLI.BuildUDIV(N, DAG, LegalOperations, Built)) {
27883     for (SDNode *N : Built)
27884       AddToWorklist(N);
27885     return S;
27886   }
27887 
27888   return SDValue();
27889 }
27890 
27891 /// Given an ISD::SREM node expressing a remainder by constant power of 2,
27892 /// return a DAG expression that will generate the same value.
27893 SDValue DAGCombiner::BuildSREMPow2(SDNode *N) {
27894   ConstantSDNode *C = isConstOrConstSplat(N->getOperand(1));
27895   if (!C)
27896     return SDValue();
27897 
27898   // Avoid division by zero.
27899   if (C->isZero())
27900     return SDValue();
27901 
27902   SmallVector<SDNode *, 8> Built;
27903   if (SDValue S = TLI.BuildSREMPow2(N, C->getAPIntValue(), DAG, Built)) {
27904     for (SDNode *N : Built)
27905       AddToWorklist(N);
27906     return S;
27907   }
27908 
27909   return SDValue();
27910 }
27911 
27912 // This is basically just a port of takeLog2 from InstCombineMulDivRem.cpp
27913 //
27914 // Returns the node that represents `Log2(Op)`. This may create a new node. If
27915 // we are unable to compute `Log2(Op)`, this returns `SDValue()`.
27916 //
27917 // All nodes will be created at `DL` and the output will be of type `VT`.
27918 //
27919 // This will only return `Log2(Op)` if we can prove `Op` is non-zero. Set
27920 // `AssumeNonZero` if this function should simply assume (rather than require
27921 // proving) that `Op` is non-zero.
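// For example, per the cases handled below: log2(X << Y) becomes
// log2(X) + Y (so log2(1 << Y) is just Y), and log2(select C, 16, 64)
// becomes (select C, 4, 6).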
27922 static SDValue takeInexpensiveLog2(SelectionDAG &DAG, const SDLoc &DL, EVT VT,
27923                                    SDValue Op, unsigned Depth,
27924                                    bool AssumeNonZero) {
27925   assert(VT.isInteger() && "Only integer types are supported!");
27926 
27927   auto PeekThroughCastsAndTrunc = [](SDValue V) {
27928     while (true) {
27929       switch (V.getOpcode()) {
27930       case ISD::TRUNCATE:
27931       case ISD::ZERO_EXTEND:
27932         V = V.getOperand(0);
27933         break;
27934       default:
27935         return V;
27936       }
27937     }
27938   };
27939 
27940   if (VT.isScalableVector())
27941     return SDValue();
27942 
27943   Op = PeekThroughCastsAndTrunc(Op);
27944 
27945   // Helper for determining whether a value is a power-of-2 constant scalar or a
27946   // vector of such elements.
27947   SmallVector<APInt> Pow2Constants;
27948   auto IsPowerOfTwo = [&Pow2Constants](ConstantSDNode *C) {
27949     if (C->isZero() || C->isOpaque())
27950       return false;
27951     // TODO: We may also be able to support negative powers of 2 here.
27952     if (C->getAPIntValue().isPowerOf2()) {
27953       Pow2Constants.emplace_back(C->getAPIntValue());
27954       return true;
27955     }
27956     return false;
27957   };
27958 
27959   if (ISD::matchUnaryPredicate(Op, IsPowerOfTwo)) {
27960     if (!VT.isVector())
27961       return DAG.getConstant(Pow2Constants.back().logBase2(), DL, VT);
27962     // We need to create a build vector
27963     if (Op.getOpcode() == ISD::SPLAT_VECTOR)
27964       return DAG.getSplat(VT, DL,
27965                           DAG.getConstant(Pow2Constants.back().logBase2(), DL,
27966                                           VT.getScalarType()));
27967     SmallVector<SDValue> Log2Ops;
27968     for (const APInt &Pow2 : Pow2Constants)
27969       Log2Ops.emplace_back(
27970           DAG.getConstant(Pow2.logBase2(), DL, VT.getScalarType()));
27971     return DAG.getBuildVector(VT, DL, Log2Ops);
27972   }
27973 
27974   if (Depth >= DAG.MaxRecursionDepth)
27975     return SDValue();
27976 
27977   auto CastToVT = [&](EVT NewVT, SDValue ToCast) {
27978     ToCast = PeekThroughCastsAndTrunc(ToCast);
27979     EVT CurVT = ToCast.getValueType();
27980     if (NewVT == CurVT)
27981       return ToCast;
27982 
27983     if (NewVT.getSizeInBits() == CurVT.getSizeInBits())
27984       return DAG.getBitcast(NewVT, ToCast);
27985 
27986     return DAG.getZExtOrTrunc(ToCast, DL, NewVT);
27987   };
27988 
27989   // log2(X << Y) -> log2(X) + Y
27990   if (Op.getOpcode() == ISD::SHL) {
27991     // 1 << Y and X nuw/nsw << Y are all non-zero.
27992     if (AssumeNonZero || Op->getFlags().hasNoUnsignedWrap() ||
27993         Op->getFlags().hasNoSignedWrap() || isOneConstant(Op.getOperand(0)))
27994       if (SDValue LogX = takeInexpensiveLog2(DAG, DL, VT, Op.getOperand(0),
27995                                              Depth + 1, AssumeNonZero))
27996         return DAG.getNode(ISD::ADD, DL, VT, LogX,
27997                            CastToVT(VT, Op.getOperand(1)));
27998   }
27999 
28000   // c ? X : Y -> c ? Log2(X) : Log2(Y)
28001   if ((Op.getOpcode() == ISD::SELECT || Op.getOpcode() == ISD::VSELECT) &&
28002       Op.hasOneUse()) {
28003     if (SDValue LogX = takeInexpensiveLog2(DAG, DL, VT, Op.getOperand(1),
28004                                            Depth + 1, AssumeNonZero))
28005       if (SDValue LogY = takeInexpensiveLog2(DAG, DL, VT, Op.getOperand(2),
28006                                              Depth + 1, AssumeNonZero))
28007         return DAG.getSelect(DL, VT, Op.getOperand(0), LogX, LogY);
28008   }
28009 
28010   // log2(umin(X, Y)) -> umin(log2(X), log2(Y))
28011   // log2(umax(X, Y)) -> umax(log2(X), log2(Y))
28012   if ((Op.getOpcode() == ISD::UMIN || Op.getOpcode() == ISD::UMAX) &&
28013       Op.hasOneUse()) {
28014     // Use AssumeNonZero as false here. Otherwise we can hit a case where
28015     // log2(umax(X, Y)) != umax(log2(X), log2(Y)) (because of overflow).
28016     if (SDValue LogX =
28017             takeInexpensiveLog2(DAG, DL, VT, Op.getOperand(0), Depth + 1,
28018                                 /*AssumeNonZero*/ false))
28019       if (SDValue LogY =
28020               takeInexpensiveLog2(DAG, DL, VT, Op.getOperand(1), Depth + 1,
28021                                   /*AssumeNonZero*/ false))
28022         return DAG.getNode(Op.getOpcode(), DL, VT, LogX, LogY);
28023   }
28024 
28025   return SDValue();
28026 }
28027 
28028 /// Determines the LogBase2 value for a non-null input value using the
28029 /// transform: LogBase2(V) = (EltBits - 1) - ctlz(V).
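/// For example, for i32 V == 16: ctlz(16) == 27, so LogBase2 == 31 - 27 == 4.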
28030 SDValue DAGCombiner::BuildLogBase2(SDValue V, const SDLoc &DL,
28031                                    bool KnownNonZero, bool InexpensiveOnly,
28032                                    std::optional<EVT> OutVT) {
28033   EVT VT = OutVT ? *OutVT : V.getValueType();
28034   SDValue InexpensiveLogBase2 =
28035       takeInexpensiveLog2(DAG, DL, VT, V, /*Depth*/ 0, KnownNonZero);
28036   if (InexpensiveLogBase2 || InexpensiveOnly || !DAG.isKnownToBeAPowerOfTwo(V))
28037     return InexpensiveLogBase2;
28038 
28039   SDValue Ctlz = DAG.getNode(ISD::CTLZ, DL, VT, V);
28040   SDValue Base = DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, VT);
28041   SDValue LogBase2 = DAG.getNode(ISD::SUB, DL, VT, Base, Ctlz);
28042   return LogBase2;
28043 }
28044 
28045 /// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
28046 /// For the reciprocal, we need to find the zero of the function:
28047 ///   F(X) = 1/X - A [which has a zero at X = 1/A]
28048 ///     =>
28049 ///   X_{i+1} = X_i (2 - A X_i) = X_i + X_i (1 - A X_i) [this second form
28050 ///     does not require additional intermediate precision]
28051 /// For the last iteration, put numerator N into it to gain more precision:
28052 ///   Result = N X_i + X_i (N - N A X_i)
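/// For example (illustrative arithmetic with A == 4, X_0 == 0.2):
///   X_1 = 0.2  * (2 - 4 * 0.2)  = 0.24
///   X_2 = 0.24 * (2 - 4 * 0.24) = 0.2496   (quadratic convergence to 0.25)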
28053 SDValue DAGCombiner::BuildDivEstimate(SDValue N, SDValue Op,
28054                                       SDNodeFlags Flags) {
28055   if (LegalDAG)
28056     return SDValue();
28057 
28058   // TODO: Handle extended types?
28059   EVT VT = Op.getValueType();
28060   if (VT.getScalarType() != MVT::f16 && VT.getScalarType() != MVT::f32 &&
28061       VT.getScalarType() != MVT::f64)
28062     return SDValue();
28063 
28064   // If estimates are explicitly disabled for this function, we're done.
28065   MachineFunction &MF = DAG.getMachineFunction();
28066   int Enabled = TLI.getRecipEstimateDivEnabled(VT, MF);
28067   if (Enabled == TLI.ReciprocalEstimate::Disabled)
28068     return SDValue();
28069 
28070   // Estimates may be explicitly enabled for this type with a custom number of
28071   // refinement steps.
28072   int Iterations = TLI.getDivRefinementSteps(VT, MF);
28073   if (SDValue Est = TLI.getRecipEstimate(Op, DAG, Enabled, Iterations)) {
28074     AddToWorklist(Est.getNode());
28075 
28076     SDLoc DL(Op);
28077     if (Iterations) {
28078       SDValue FPOne = DAG.getConstantFP(1.0, DL, VT);
28079 
28080       // Newton iterations: Est = Est + Est (N - Arg * Est)
28081       // If this is the last iteration, also multiply by the numerator.
28082       for (int i = 0; i < Iterations; ++i) {
28083         SDValue MulEst = Est;
28084 
28085         if (i == Iterations - 1) {
28086           MulEst = DAG.getNode(ISD::FMUL, DL, VT, N, Est, Flags);
28087           AddToWorklist(MulEst.getNode());
28088         }
28089 
28090         SDValue NewEst = DAG.getNode(ISD::FMUL, DL, VT, Op, MulEst, Flags);
28091         AddToWorklist(NewEst.getNode());
28092 
28093         NewEst = DAG.getNode(ISD::FSUB, DL, VT,
28094                              (i == Iterations - 1 ? N : FPOne), NewEst, Flags);
28095         AddToWorklist(NewEst.getNode());
28096 
28097         NewEst = DAG.getNode(ISD::FMUL, DL, VT, Est, NewEst, Flags);
28098         AddToWorklist(NewEst.getNode());
28099 
28100         Est = DAG.getNode(ISD::FADD, DL, VT, MulEst, NewEst, Flags);
28101         AddToWorklist(Est.getNode());
28102       }
28103     } else {
28104       // If no iterations are available, multiply with N.
28105       Est = DAG.getNode(ISD::FMUL, DL, VT, Est, N, Flags);
28106       AddToWorklist(Est.getNode());
28107     }
28108 
28109     return Est;
28110   }
28111 
28112   return SDValue();
28113 }
28114 
28115 /// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
28116 /// For the reciprocal sqrt, we need to find the zero of the function:
28117 ///   F(X) = 1/X^2 - A [which has a zero at X = 1/sqrt(A)]
28118 ///     =>
28119 ///   X_{i+1} = X_i (1.5 - A X_i^2 / 2)
28120 /// As a result, we precompute A/2 prior to the iteration loop.
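/// For example (illustrative arithmetic with A == 4, so 1/sqrt(A) == 0.5):
///   X_0 = 0.6:  X_1 = 0.6   * (1.5 - 2 * 0.6^2)   = 0.468
///               X_2 = 0.468 * (1.5 - 2 * 0.468^2) ~= 0.497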
28121 SDValue DAGCombiner::buildSqrtNROneConst(SDValue Arg, SDValue Est,
28122                                          unsigned Iterations,
28123                                          SDNodeFlags Flags, bool Reciprocal) {
28124   EVT VT = Arg.getValueType();
28125   SDLoc DL(Arg);
28126   SDValue ThreeHalves = DAG.getConstantFP(1.5, DL, VT);
28127 
28128   // We now need 0.5 * Arg which we can write as (1.5 * Arg - Arg) so that
28129   // this entire sequence requires only one FP constant.
28130   SDValue HalfArg = DAG.getNode(ISD::FMUL, DL, VT, ThreeHalves, Arg, Flags);
28131   HalfArg = DAG.getNode(ISD::FSUB, DL, VT, HalfArg, Arg, Flags);
28132 
28133   // Newton iterations: Est = Est * (1.5 - HalfArg * Est * Est)
28134   for (unsigned i = 0; i < Iterations; ++i) {
28135     SDValue NewEst = DAG.getNode(ISD::FMUL, DL, VT, Est, Est, Flags);
28136     NewEst = DAG.getNode(ISD::FMUL, DL, VT, HalfArg, NewEst, Flags);
28137     NewEst = DAG.getNode(ISD::FSUB, DL, VT, ThreeHalves, NewEst, Flags);
28138     Est = DAG.getNode(ISD::FMUL, DL, VT, Est, NewEst, Flags);
28139   }
28140 
28141   // If non-reciprocal square root is requested, multiply the result by Arg.
28142   if (!Reciprocal)
28143     Est = DAG.getNode(ISD::FMUL, DL, VT, Est, Arg, Flags);
28144 
28145   return Est;
28146 }
28147 
28148 /// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
28149 /// For the reciprocal sqrt, we need to find the zero of the function:
28150 ///   F(X) = 1/X^2 - A [which has a zero at X = 1/sqrt(A)]
28151 ///     =>
28152 ///   X_{i+1} = (-0.5 * X_i) * (A * X_i * X_i + (-3.0))
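/// Note that this is algebraically the same update as the one-constant form:
///   (-0.5 * X) * (A * X * X - 3.0) == X * (1.5 - (A/2) * X^2)
/// but it spends two constants (-0.5 and -3.0) instead of materializing A/2.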
28153 SDValue DAGCombiner::buildSqrtNRTwoConst(SDValue Arg, SDValue Est,
28154                                          unsigned Iterations,
28155                                          SDNodeFlags Flags, bool Reciprocal) {
28156   EVT VT = Arg.getValueType();
28157   SDLoc DL(Arg);
28158   SDValue MinusThree = DAG.getConstantFP(-3.0, DL, VT);
28159   SDValue MinusHalf = DAG.getConstantFP(-0.5, DL, VT);
28160 
28161   // This routine must enter the loop below to work correctly
28162   // when (Reciprocal == false).
28163   assert(Iterations > 0);
28164 
28165   // Newton iterations for reciprocal square root:
28166   // E = (E * -0.5) * ((A * E) * E + -3.0)
28167   for (unsigned i = 0; i < Iterations; ++i) {
28168     SDValue AE = DAG.getNode(ISD::FMUL, DL, VT, Arg, Est, Flags);
28169     SDValue AEE = DAG.getNode(ISD::FMUL, DL, VT, AE, Est, Flags);
28170     SDValue RHS = DAG.getNode(ISD::FADD, DL, VT, AEE, MinusThree, Flags);
28171 
28172     // When calculating a square root at the last iteration build:
28173     // S = ((A * E) * -0.5) * ((A * E) * E + -3.0)
28174     // (notice a common subexpression)
28175     SDValue LHS;
28176     if (Reciprocal || (i + 1) < Iterations) {
28177       // RSQRT: LHS = (E * -0.5)
28178       LHS = DAG.getNode(ISD::FMUL, DL, VT, Est, MinusHalf, Flags);
28179     } else {
28180       // SQRT: LHS = (A * E) * -0.5
28181       LHS = DAG.getNode(ISD::FMUL, DL, VT, AE, MinusHalf, Flags);
28182     }
28183 
28184     Est = DAG.getNode(ISD::FMUL, DL, VT, LHS, RHS, Flags);
28185   }
28186 
28187   return Est;
28188 }
28189 
28190 /// Build code to calculate either rsqrt(Op) or sqrt(Op). In the latter case
28191 /// Op*rsqrt(Op) is actually computed, so additional postprocessing is needed if
28192 /// Op can be zero.
28193 SDValue DAGCombiner::buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags,
28194                                            bool Reciprocal) {
28195   if (LegalDAG)
28196     return SDValue();
28197 
28198   // TODO: Handle extended types?
28199   EVT VT = Op.getValueType();
28200   if (VT.getScalarType() != MVT::f16 && VT.getScalarType() != MVT::f32 &&
28201       VT.getScalarType() != MVT::f64)
28202     return SDValue();
28203 
28204   // If estimates are explicitly disabled for this function, we're done.
28205   MachineFunction &MF = DAG.getMachineFunction();
28206   int Enabled = TLI.getRecipEstimateSqrtEnabled(VT, MF);
28207   if (Enabled == TLI.ReciprocalEstimate::Disabled)
28208     return SDValue();
28209 
28210   // Estimates may be explicitly enabled for this type with a custom number of
28211   // refinement steps.
28212   int Iterations = TLI.getSqrtRefinementSteps(VT, MF);
28213 
28214   bool UseOneConstNR = false;
28215   if (SDValue Est =
28216       TLI.getSqrtEstimate(Op, DAG, Enabled, Iterations, UseOneConstNR,
28217                           Reciprocal)) {
28218     AddToWorklist(Est.getNode());
28219 
28220     if (Iterations > 0)
28221       Est = UseOneConstNR
28222             ? buildSqrtNROneConst(Op, Est, Iterations, Flags, Reciprocal)
28223             : buildSqrtNRTwoConst(Op, Est, Iterations, Flags, Reciprocal);
28224     if (!Reciprocal) {
28225       SDLoc DL(Op);
28226       // Try the target specific test first.
28227       SDValue Test = TLI.getSqrtInputTest(Op, DAG, DAG.getDenormalMode(VT));
28228 
28229       // The estimate is now completely wrong if the input was exactly 0.0 or
28230       // possibly a denormal. Force the answer to 0.0 or value provided by
28231       // target for those cases.
28232       Est = DAG.getNode(
28233           Test.getValueType().isVector() ? ISD::VSELECT : ISD::SELECT, DL, VT,
28234           Test, TLI.getSqrtResultForDenormInput(Op, DAG), Est);
28235     }
28236     return Est;
28237   }
28238 
28239   return SDValue();
28240 }
28241 
28242 SDValue DAGCombiner::buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags) {
28243   return buildSqrtEstimateImpl(Op, Flags, true);
28244 }
28245 
28246 SDValue DAGCombiner::buildSqrtEstimate(SDValue Op, SDNodeFlags Flags) {
28247   return buildSqrtEstimateImpl(Op, Flags, false);
28248 }
28249 
28250 /// Return true if there is any possibility that the two addresses overlap.
28251 bool DAGCombiner::mayAlias(SDNode *Op0, SDNode *Op1) const {
28252 
28253   struct MemUseCharacteristics {
28254     bool IsVolatile;
28255     bool IsAtomic;
28256     SDValue BasePtr;
28257     int64_t Offset;
28258     LocationSize NumBytes;
28259     MachineMemOperand *MMO;
28260   };
28261 
28262   auto getCharacteristics = [](SDNode *N) -> MemUseCharacteristics {
28263     if (const auto *LSN = dyn_cast<LSBaseSDNode>(N)) {
28264       int64_t Offset = 0;
28265       if (auto *C = dyn_cast<ConstantSDNode>(LSN->getOffset()))
28266         Offset = (LSN->getAddressingMode() == ISD::PRE_INC) ? C->getSExtValue()
28267                  : (LSN->getAddressingMode() == ISD::PRE_DEC)
28268                      ? -1 * C->getSExtValue()
28269                      : 0;
28270       TypeSize Size = LSN->getMemoryVT().getStoreSize();
28271       return {LSN->isVolatile(),           LSN->isAtomic(),
28272               LSN->getBasePtr(),           Offset /*base offset*/,
28273               LocationSize::precise(Size), LSN->getMemOperand()};
28274     }
28275     if (const auto *LN = cast<LifetimeSDNode>(N))
28276       return {false /*isVolatile*/,
28277               /*isAtomic*/ false,
28278               LN->getOperand(1),
28279               (LN->hasOffset()) ? LN->getOffset() : 0,
28280               (LN->hasOffset()) ? LocationSize::precise(LN->getSize())
28281                                 : LocationSize::beforeOrAfterPointer(),
28282               (MachineMemOperand *)nullptr};
28283     // Default.
28284     return {false /*isVolatile*/,
28285             /*isAtomic*/ false,
28286             SDValue(),
28287             (int64_t)0 /*offset*/,
28288             LocationSize::beforeOrAfterPointer() /*size*/,
28289             (MachineMemOperand *)nullptr};
28290   };
28291 
28292   MemUseCharacteristics MUC0 = getCharacteristics(Op0),
28293                         MUC1 = getCharacteristics(Op1);
28294 
28295   // If they are to the same address, then they must be aliases.
28296   if (MUC0.BasePtr.getNode() && MUC0.BasePtr == MUC1.BasePtr &&
28297       MUC0.Offset == MUC1.Offset)
28298     return true;
28299 
28300   // If they are both volatile then they cannot be reordered.
28301   if (MUC0.IsVolatile && MUC1.IsVolatile)
28302     return true;
28303 
28304   // Be conservative about atomics for the moment
28305   // TODO: This is way overconservative for unordered atomics (see D66309)
28306   if (MUC0.IsAtomic && MUC1.IsAtomic)
28307     return true;
28308 
28309   if (MUC0.MMO && MUC1.MMO) {
28310     if ((MUC0.MMO->isInvariant() && MUC1.MMO->isStore()) ||
28311         (MUC1.MMO->isInvariant() && MUC0.MMO->isStore()))
28312       return false;
28313   }
28314 
28315   // If NumBytes is scalable and the offset is not 0, conservatively return
28316   // may-alias.
28317   if ((MUC0.NumBytes.hasValue() && MUC0.NumBytes.isScalable() &&
28318        MUC0.Offset != 0) ||
28319       (MUC1.NumBytes.hasValue() && MUC1.NumBytes.isScalable() &&
28320        MUC1.Offset != 0))
28321     return true;
28322   // Try to prove that there is aliasing, or that there is no aliasing. Either
28323   // way, we can return now. If nothing can be proved, proceed with more tests.
28324   bool IsAlias;
28325   if (BaseIndexOffset::computeAliasing(Op0, MUC0.NumBytes, Op1, MUC1.NumBytes,
28326                                        DAG, IsAlias))
28327     return IsAlias;
28328 
28329   // The following all rely on MMO0 and MMO1 being valid. Fail conservatively if
28330   // either are not known.
28331   if (!MUC0.MMO || !MUC1.MMO)
28332     return true;
28333 
28334   // If one operation reads from invariant memory, and the other may store, they
28335   // cannot alias. These should really be checking the equivalent of mayWrite,
28336   // but it only matters for memory nodes other than load/store.
28337   if ((MUC0.MMO->isInvariant() && MUC1.MMO->isStore()) ||
28338       (MUC1.MMO->isInvariant() && MUC0.MMO->isStore()))
28339     return false;
28340 
28341   // If we know required SrcValue1 and SrcValue2 have relatively large
28342   // alignment compared to the size and offset of the access, we may be able
28343   // to prove they do not alias. This check is conservative for now: aimed at
28344   // cases created by splitting vector types, it only works when the offsets
28345   // are multiples of the size of the data.
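  // For example (a sketch): two 4-byte accesses with base alignment 16 at
  // offsets 4 and 8 occupy the disjoint slots [4,8) and [8,12) of any
  // 16-byte-aligned block, so they cannot overlap.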
28346   int64_t SrcValOffset0 = MUC0.MMO->getOffset();
28347   int64_t SrcValOffset1 = MUC1.MMO->getOffset();
28348   Align OrigAlignment0 = MUC0.MMO->getBaseAlign();
28349   Align OrigAlignment1 = MUC1.MMO->getBaseAlign();
28350   LocationSize Size0 = MUC0.NumBytes;
28351   LocationSize Size1 = MUC1.NumBytes;
28352 
28353   if (OrigAlignment0 == OrigAlignment1 && SrcValOffset0 != SrcValOffset1 &&
28354       Size0.hasValue() && Size1.hasValue() && !Size0.isScalable() &&
28355       !Size1.isScalable() && Size0 == Size1 &&
28356       OrigAlignment0 > Size0.getValue().getKnownMinValue() &&
28357       SrcValOffset0 % Size0.getValue().getKnownMinValue() == 0 &&
28358       SrcValOffset1 % Size1.getValue().getKnownMinValue() == 0) {
28359     int64_t OffAlign0 = SrcValOffset0 % OrigAlignment0.value();
28360     int64_t OffAlign1 = SrcValOffset1 % OrigAlignment1.value();
28361 
28362     // There is no overlap between these relatively aligned accesses of
28363     // similar size. Return no alias.
28364     if ((OffAlign0 + static_cast<int64_t>(
28365                          Size0.getValue().getKnownMinValue())) <= OffAlign1 ||
28366         (OffAlign1 + static_cast<int64_t>(
28367                          Size1.getValue().getKnownMinValue())) <= OffAlign0)
28368       return false;
28369   }
28370 
28371   bool UseAA = CombinerGlobalAA.getNumOccurrences() > 0
28372                    ? CombinerGlobalAA
28373                    : DAG.getSubtarget().useAA();
28374 #ifndef NDEBUG
28375   if (CombinerAAOnlyFunc.getNumOccurrences() &&
28376       CombinerAAOnlyFunc != DAG.getMachineFunction().getName())
28377     UseAA = false;
28378 #endif
28379 
28380   if (UseAA && AA && MUC0.MMO->getValue() && MUC1.MMO->getValue() &&
28381       Size0.hasValue() && Size1.hasValue() &&
28382       // Can't represent a scalable size + fixed offset in LocationSize
28383       (!Size0.isScalable() || SrcValOffset0 == 0) &&
28384       (!Size1.isScalable() || SrcValOffset1 == 0)) {
28385     // Use alias analysis information.
28386     int64_t MinOffset = std::min(SrcValOffset0, SrcValOffset1);
28387     int64_t Overlap0 =
28388         Size0.getValue().getKnownMinValue() + SrcValOffset0 - MinOffset;
28389     int64_t Overlap1 =
28390         Size1.getValue().getKnownMinValue() + SrcValOffset1 - MinOffset;
28391     LocationSize Loc0 =
28392         Size0.isScalable() ? Size0 : LocationSize::precise(Overlap0);
28393     LocationSize Loc1 =
28394         Size1.isScalable() ? Size1 : LocationSize::precise(Overlap1);
28395     if (AA->isNoAlias(
28396             MemoryLocation(MUC0.MMO->getValue(), Loc0,
28397                            UseTBAA ? MUC0.MMO->getAAInfo() : AAMDNodes()),
28398             MemoryLocation(MUC1.MMO->getValue(), Loc1,
28399                            UseTBAA ? MUC1.MMO->getAAInfo() : AAMDNodes())))
28400       return false;
28401   }
28402 
28403   // Otherwise we have to assume they alias.
28404   return true;
28405 }
28406 
28407 /// Walk up chain skipping non-aliasing memory nodes,
28408 /// looking for aliasing nodes and adding them to the Aliases vector.
28409 void DAGCombiner::GatherAllAliases(SDNode *N, SDValue OriginalChain,
28410                                    SmallVectorImpl<SDValue> &Aliases) {
28411   SmallVector<SDValue, 8> Chains;     // List of chains to visit.
28412   SmallPtrSet<SDNode *, 16> Visited;  // Visited node set.
28413 
28414   // Get alias information for node.
28415   // TODO: relax aliasing for unordered atomics (see D66309)
28416   const bool IsLoad = isa<LoadSDNode>(N) && cast<LoadSDNode>(N)->isSimple();
28417 
28418   // Starting off.
28419   Chains.push_back(OriginalChain);
28420   unsigned Depth = 0;
28421 
28422   // Attempt to improve chain by a single step
28423   auto ImproveChain = [&](SDValue &C) -> bool {
28424     switch (C.getOpcode()) {
28425     case ISD::EntryToken:
28426       // No need to mark EntryToken.
28427       C = SDValue();
28428       return true;
28429     case ISD::LOAD:
28430     case ISD::STORE: {
28431       // Get alias information for C.
28432       // TODO: Relax aliasing for unordered atomics (see D66309)
28433       bool IsOpLoad = isa<LoadSDNode>(C.getNode()) &&
28434                       cast<LSBaseSDNode>(C.getNode())->isSimple();
28435       if ((IsLoad && IsOpLoad) || !mayAlias(N, C.getNode())) {
28436         // Look further up the chain.
28437         C = C.getOperand(0);
28438         return true;
28439       }
28440       // Alias, so stop here.
28441       return false;
28442     }
28443 
28444     case ISD::CopyFromReg:
28445       // Always forward past CopyFromReg.
28446       C = C.getOperand(0);
28447       return true;
28448 
28449     case ISD::LIFETIME_START:
28450     case ISD::LIFETIME_END: {
28451       // We can forward past any lifetime start/end that can be proven not to
28452       // alias the memory access.
28453       if (!mayAlias(N, C.getNode())) {
28454         // Look further up the chain.
28455         C = C.getOperand(0);
28456         return true;
28457       }
28458       return false;
28459     }
28460     default:
28461       return false;
28462     }
28463   };
28464 
28465   // Look at each chain and determine if it is an alias.  If so, add it to the
28466   // aliases list.  If not, then continue up the chain looking for the next
28467   // candidate.
28468   while (!Chains.empty()) {
28469     SDValue Chain = Chains.pop_back_val();
28470 
28471     // Don't bother if we've seen Chain before.
28472     if (!Visited.insert(Chain.getNode()).second)
28473       continue;
28474 
28475     // For TokenFactor nodes, look at each operand and only continue up the
28476     // chain until we reach the depth limit.
28477     //
28478     // FIXME: The depth check could be made to return the last non-aliasing
28479     // chain we found before we hit a tokenfactor rather than the original
28480     // chain.
28481     if (Depth > TLI.getGatherAllAliasesMaxDepth()) {
28482       Aliases.clear();
28483       Aliases.push_back(OriginalChain);
28484       return;
28485     }
28486 
28487     if (Chain.getOpcode() == ISD::TokenFactor) {
28488       // We have to check each of the operands of the token factor for "small"
28489       // token factors, so we queue them up.  Adding the operands to the queue
28490       // (stack) in reverse order maintains the original order and increases the
28491       // likelihood that getNode will find a matching token factor (CSE).
28492       if (Chain.getNumOperands() > 16) {
28493         Aliases.push_back(Chain);
28494         continue;
28495       }
28496       for (unsigned n = Chain.getNumOperands(); n;)
28497         Chains.push_back(Chain.getOperand(--n));
28498       ++Depth;
28499       continue;
28500     }
28501     // Everything else
28502     if (ImproveChain(Chain)) {
28503       // Updated chain found; consider the new chain if one exists.
28504       if (Chain.getNode())
28505         Chains.push_back(Chain);
28506       ++Depth;
28507       continue;
28508     }
28509     // No improved chain is possible, so treat this chain as an alias.
28510     Aliases.push_back(Chain);
28511   }
28512 }
28513 
28514 /// Walk up chain skipping non-aliasing memory nodes, looking for a better chain
28515 /// (aliasing node.)
28516 SDValue DAGCombiner::FindBetterChain(SDNode *N, SDValue OldChain) {
28517   if (OptLevel == CodeGenOptLevel::None)
28518     return OldChain;
28519 
28520   // Ops for replacing token factor.
28521   SmallVector<SDValue, 8> Aliases;
28522 
28523   // Accumulate all the aliases to this node.
28524   GatherAllAliases(N, OldChain, Aliases);
28525 
28526   // If no operands then chain to entry token.
28527   if (Aliases.empty())
28528     return DAG.getEntryNode();
28529 
28530   // If a single operand then chain to it.  We don't need to revisit it.
28531   if (Aliases.size() == 1)
28532     return Aliases[0];
28533 
28534   // Construct a custom tailored token factor.
28535   return DAG.getTokenFactor(SDLoc(N), Aliases);
28536 }
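
// The three outcomes above matter in practice (illustrative): a memory node
// with no aliases is re-chained to EntryToken and can be reordered freely
// with other memory operations; one with a single alias is chained to just
// that node; otherwise it hangs off a tailored TokenFactor rather than a
// wide one that would also serialize it behind unrelated operations.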
28537 
28538 // This function tries to collect a bunch of potentially interesting
28539 // nodes to improve the chains of, all at once. This might seem
28540 // redundant, as this function gets called when visiting every store
28541 // node, so why not let the work be done on each store as it's visited?
28542 //
28543 // I believe this is mainly important because mergeConsecutiveStores
28544 // is unable to deal with merging stores of different sizes, so unless
28545 // we improve the chains of all the potential candidates up-front
28546 // before running mergeConsecutiveStores, it might only see some of
28547 // the nodes that will eventually be candidates, and then not be able
28548 // to go from a partially-merged state to the desired final
28549 // fully-merged state.
28550 
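// As a sketch of the rewrite performed below (illustrative): four chained
// byte-sized stores
//
//   st b3 -> st b2 -> st b1 -> st b0 -> ch
//
// become independent stores that each take a better chain directly (here
// simply ch), joined by a single TokenFactor:
//
//   TokenFactor(st b0, st b1, st b2, st b3)
//
// which leaves mergeConsecutiveStores free to consider all four candidates
// at once.
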
28551 bool DAGCombiner::parallelizeChainedStores(StoreSDNode *St) {
28552   SmallVector<StoreSDNode *, 8> ChainedStores;
28553   StoreSDNode *STChain = St;
28554   // Intervals records which offsets from BaseIndex have been covered. In the
28555   // common case, each store writes to the address range adjacent to the
28556   // previous one and is thus merged with the previous interval at insertion time.
28557 
28558   using IMap = llvm::IntervalMap<int64_t, std::monostate, 8,
28559                                  IntervalMapHalfOpenInfo<int64_t>>;
28560   IMap::Allocator A;
28561   IMap Intervals(A);
28562 
28563   // This holds the base pointer, index, and the offset in bytes from the base
28564   // pointer.
28565   const BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);
28566 
28567   // We must have a base and an offset.
28568   if (!BasePtr.getBase().getNode())
28569     return false;
28570 
28571   // Do not handle stores to undef base pointers.
28572   if (BasePtr.getBase().isUndef())
28573     return false;
28574 
28575   // Do not handle stores to opaque types.
28576   if (St->getMemoryVT().isZeroSized())
28577     return false;
28578 
28579   // BaseIndexOffset assumes that offsets are fixed-size, which
28580   // is not valid for scalable vectors where the offsets are
28581   // scaled by `vscale`, so bail out early.
28582   if (St->getMemoryVT().isScalableVT())
28583     return false;
28584 
28585   // Add St's interval.
28586   Intervals.insert(0, (St->getMemoryVT().getSizeInBits() + 7) / 8,
28587                    std::monostate{});
28588 
28589   while (StoreSDNode *Chain = dyn_cast<StoreSDNode>(STChain->getChain())) {
28590     if (Chain->getMemoryVT().isScalableVector())
28591       return false;
28592 
28593     // If the chain has more than one use, then we can't reorder the mem ops.
28594     if (!SDValue(Chain, 0)->hasOneUse())
28595       break;
28596     // TODO: Relax for unordered atomics (see D66309)
28597     if (!Chain->isSimple() || Chain->isIndexed())
28598       break;
28599 
28600     // Find the base pointer and offset for this memory node.
28601     const BaseIndexOffset Ptr = BaseIndexOffset::match(Chain, DAG);
28602     // Check that the base pointer is the same as the original one.
28603     int64_t Offset;
28604     if (!BasePtr.equalBaseIndex(Ptr, DAG, Offset))
28605       break;
28606     int64_t Length = (Chain->getMemoryVT().getSizeInBits() + 7) / 8;
28607     // Make sure we don't overlap with other intervals by checking the ones to
28608     // the left or right before inserting.
28609     auto I = Intervals.find(Offset);
28610     // If there's a next interval, we should end before it.
28611     if (I != Intervals.end() && I.start() < (Offset + Length))
28612       break;
28613     // If there's a previous interval, we should start after it.
28614     if (I != Intervals.begin() && (--I).stop() <= Offset)
28615       break;
28616     Intervals.insert(Offset, Offset + Length, std::monostate{});
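    // Worked example (illustrative): St storing 4 bytes at offset 0 seeds
    // Intervals with [0, 4). A chained 4-byte store at Offset = -4 passes
    // both checks, and inserting [-4, 0) coalesces the map to [-4, 4). A
    // further 4-byte store at Offset = 2 would overlap: find(2) returns
    // [-4, 4), whose start (-4) is below 2 + 4, so the first check above
    // breaks out of the loop.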
28617 
28618     ChainedStores.push_back(Chain);
28619     STChain = Chain;
28620   }
28621 
28622   // If we didn't find a chained store, exit.
28623   if (ChainedStores.empty())
28624     return false;
28625 
28626   // Improve all chained stores (St and ChainedStores members) starting from
28627   // where the store chain ended, and return a single TokenFactor.
28628   SDValue NewChain = STChain->getChain();
28629   SmallVector<SDValue, 8> TFOps;
28630   for (unsigned I = ChainedStores.size(); I;) {
28631     StoreSDNode *S = ChainedStores[--I];
28632     SDValue BetterChain = FindBetterChain(S, NewChain);
28633     S = cast<StoreSDNode>(DAG.UpdateNodeOperands(
28634         S, BetterChain, S->getOperand(1), S->getOperand(2), S->getOperand(3)));
28635     TFOps.push_back(SDValue(S, 0));
28636     ChainedStores[I] = S;
28637   }
28638 
28639   // Improve St's chain. Use a new node to avoid creating a loop from CombineTo.
28640   SDValue BetterChain = FindBetterChain(St, NewChain);
28641   SDValue NewST;
28642   if (St->isTruncatingStore())
28643     NewST = DAG.getTruncStore(BetterChain, SDLoc(St), St->getValue(),
28644                               St->getBasePtr(), St->getMemoryVT(),
28645                               St->getMemOperand());
28646   else
28647     NewST = DAG.getStore(BetterChain, SDLoc(St), St->getValue(),
28648                          St->getBasePtr(), St->getMemOperand());
28649 
28650   TFOps.push_back(NewST);
28651 
28652   // If we improved every element of TFOps, then we've lost the dependence on
28653   // NewChain to successors of St and we need to add it back to TFOps. Do so at
28654   // the beginning to keep relative order consistent with FindBetterChain.
28655   auto hasImprovedChain = [&](SDValue ST) -> bool {
28656     return ST->getOperand(0) != NewChain;
28657   };
28658   bool AddNewChain = llvm::all_of(TFOps, hasImprovedChain);
28659   if (AddNewChain)
28660     TFOps.insert(TFOps.begin(), NewChain);
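  // Concretely (illustrative): if every store in TFOps found a chain better
  // than NewChain, the TokenFactor built below would no longer be ordered
  // after NewChain at all, yet the nodes that follow St still require that
  // ordering; prepending NewChain preserves it.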
28661 
28662   SDValue TF = DAG.getTokenFactor(SDLoc(STChain), TFOps);
28663   CombineTo(St, TF);
28664 
28665   // Add TF and its operands to the worklist.
28666   AddToWorklist(TF.getNode());
28667   for (const SDValue &Op : TF->ops())
28668     AddToWorklist(Op.getNode());
28669   AddToWorklist(STChain);
28670   return true;
28671 }
28672 
28673 bool DAGCombiner::findBetterNeighborChains(StoreSDNode *St) {
28674   if (OptLevel == CodeGenOptLevel::None)
28675     return false;
28676 
28677   const BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);
28678 
28679   // We must have a base and an offset.
28680   if (!BasePtr.getBase().getNode())
28681     return false;
28682 
28683   // Do not handle stores to undef base pointers.
28684   if (BasePtr.getBase().isUndef())
28685     return false;
28686 
28687   // Directly improve a chain of disjoint stores starting at St.
28688   if (parallelizeChainedStores(St))
28689     return true;
28690 
28691   // Improve St's chain.
28692   SDValue BetterChain = FindBetterChain(St, St->getChain());
28693   if (St->getChain() != BetterChain) {
28694     replaceStoreChain(St, BetterChain);
28695     return true;
28696   }
28697   return false;
28698 }
28699 
28700 /// This is the entry point for the file.
28701 void SelectionDAG::Combine(CombineLevel Level, AliasAnalysis *AA,
28702                            CodeGenOptLevel OptLevel) {
28703   // Construct a DAGCombiner for this SelectionDAG and run it at Level.
28704   DAGCombiner(*this, AA, OptLevel).Run(Level);
28705 }
28706