xref: /freebsd/contrib/llvm-project/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (revision e64bea71c21eb42e97aa615188ba91f6cce0d36d)
1 //===- DAGCombiner.cpp - Implement a DAG node combiner --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass combines dag nodes to form fewer, simpler DAG nodes.  It can be run
10 // both before and after the DAG is legalized.
11 //
12 // This pass is not a substitute for the LLVM IR instcombine pass. This pass is
13 // primarily intended to handle simplification opportunities that are implicit
14 // in the LLVM IR and exposed by the various codegen lowering phases.
15 //
16 //===----------------------------------------------------------------------===//
17 
18 #include "llvm/ADT/APFloat.h"
19 #include "llvm/ADT/APInt.h"
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/DenseMap.h"
22 #include "llvm/ADT/IntervalMap.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/SetVector.h"
25 #include "llvm/ADT/SmallBitVector.h"
26 #include "llvm/ADT/SmallPtrSet.h"
27 #include "llvm/ADT/SmallSet.h"
28 #include "llvm/ADT/SmallVector.h"
29 #include "llvm/ADT/Statistic.h"
30 #include "llvm/Analysis/AliasAnalysis.h"
31 #include "llvm/Analysis/MemoryLocation.h"
32 #include "llvm/Analysis/TargetLibraryInfo.h"
33 #include "llvm/Analysis/ValueTracking.h"
34 #include "llvm/Analysis/VectorUtils.h"
35 #include "llvm/CodeGen/ByteProvider.h"
36 #include "llvm/CodeGen/DAGCombine.h"
37 #include "llvm/CodeGen/ISDOpcodes.h"
38 #include "llvm/CodeGen/MachineFunction.h"
39 #include "llvm/CodeGen/MachineMemOperand.h"
40 #include "llvm/CodeGen/SDPatternMatch.h"
41 #include "llvm/CodeGen/SelectionDAG.h"
42 #include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
43 #include "llvm/CodeGen/SelectionDAGNodes.h"
44 #include "llvm/CodeGen/SelectionDAGTargetInfo.h"
45 #include "llvm/CodeGen/TargetLowering.h"
46 #include "llvm/CodeGen/TargetRegisterInfo.h"
47 #include "llvm/CodeGen/TargetSubtargetInfo.h"
48 #include "llvm/CodeGen/ValueTypes.h"
49 #include "llvm/CodeGenTypes/MachineValueType.h"
50 #include "llvm/IR/Attributes.h"
51 #include "llvm/IR/Constant.h"
52 #include "llvm/IR/DataLayout.h"
53 #include "llvm/IR/DerivedTypes.h"
54 #include "llvm/IR/Function.h"
55 #include "llvm/IR/Metadata.h"
56 #include "llvm/Support/Casting.h"
57 #include "llvm/Support/CodeGen.h"
58 #include "llvm/Support/CommandLine.h"
59 #include "llvm/Support/Compiler.h"
60 #include "llvm/Support/Debug.h"
61 #include "llvm/Support/DebugCounter.h"
62 #include "llvm/Support/ErrorHandling.h"
63 #include "llvm/Support/KnownBits.h"
64 #include "llvm/Support/MathExtras.h"
65 #include "llvm/Support/raw_ostream.h"
66 #include "llvm/Target/TargetMachine.h"
67 #include "llvm/Target/TargetOptions.h"
68 #include <algorithm>
69 #include <cassert>
70 #include <cstdint>
71 #include <functional>
72 #include <iterator>
73 #include <optional>
74 #include <string>
75 #include <tuple>
76 #include <utility>
77 #include <variant>
78 
79 #include "MatchContext.h"
80 
81 using namespace llvm;
82 using namespace llvm::SDPatternMatch;
83 
84 #define DEBUG_TYPE "dagcombine"
85 
/// Pass-wide statistic counters (llvm/ADT/Statistic.h), reported when LLVM
/// statistics are enabled. Presumably each is bumped where the transform its
/// description names fires — the increments are outside this view.
86 STATISTIC(NodesCombined   , "Number of dag nodes combined");
87 STATISTIC(PreIndexedNodes , "Number of pre-indexed nodes created");
88 STATISTIC(PostIndexedNodes, "Number of post-indexed nodes created");
89 STATISTIC(OpsNarrowed     , "Number of load/op/store narrowed");
90 STATISTIC(LdStFP2Int      , "Number of fp load/store pairs transformed to int");
91 STATISTIC(SlicedLoads, "Number of load sliced");
92 STATISTIC(NumFPLogicOpsConv, "Number of logic ops converted to fp ops");
93 
/// Debug counter named "dagcombine" (llvm/Support/DebugCounter.h). Per its
/// description, it gates whether an individual DAG combine is performed for
/// a node.
94 DEBUG_COUNTER(DAGCombineCounter, "dagcombine",
95               "Controls whether a DAG combine is performed for a node");
96 
/// Hidden flag: let the DAG combiner use IR-level alias analysis.
97 static cl::opt<bool>
98 CombinerGlobalAA("combiner-global-alias-analysis", cl::Hidden,
99                  cl::desc("Enable DAG combiner's use of IR alias analysis"));
100 
/// Hidden flag, on by default (cl::init(true)): let the DAG combiner use TBAA.
101 static cl::opt<bool>
102 UseTBAA("combiner-use-tbaa", cl::Hidden, cl::init(true),
103         cl::desc("Enable DAG combiner's use of TBAA"));
104 
/// Only present in assertion-enabled builds (guarded by NDEBUG): restrict the
/// combiner's alias analysis to the single function named here.
105 #ifndef NDEBUG
106 static cl::opt<std::string>
107 CombinerAAOnlyFunc("combiner-aa-only-func", cl::Hidden,
108                    cl::desc("Only use DAG-combiner alias analysis in this"
109                             " function"));
110 #endif
111 
112 /// Hidden option to stress test load slicing, i.e., when this option
113 /// is enabled, load slicing bypasses most of its profitability guards.
114 static cl::opt<bool>
115 StressLoadSlicing("combiner-stress-load-slicing", cl::Hidden,
116                   cl::desc("Bypass the profitability model of load slicing"),
117                   cl::init(false));
118 
/// On by default: allow the combiner to split the indexing computation off
/// indexed loads (presumably consulted around SplitIndexingFromLoad).
119 static cl::opt<bool>
120   MaySplitLoadIndex("combiner-split-load-index", cl::Hidden, cl::init(true),
121                     cl::desc("DAG combiner may split indexing from loads"));
122 
/// On by default: allow merging several stores into one wider store.
123 static cl::opt<bool>
124     EnableStoreMerging("combiner-store-merging", cl::Hidden, cl::init(true),
125                        cl::desc("DAG combiner enable merging multiple stores "
126                                 "into a wider store"));
127 
/// Upper bound (default 2048) on the number of operands inlined when
/// simplifying TokenFactor nodes.
128 static cl::opt<unsigned> TokenFactorInlineLimit(
129     "combiner-tokenfactor-inline-limit", cl::Hidden, cl::init(2048),
130     cl::desc("Limit the number of operands to inline for Token Factors"));
131 
/// Default 10. Presumably the limit checked against the counts kept in
/// StoreRootCountMap: once a (StoreNode, RootNode) pair has bailed out of the
/// dependence check this many times, the store stops being a merge candidate
/// for that root — TODO confirm against the store-merging code.
132 static cl::opt<unsigned> StoreMergeDependenceLimit(
133     "combiner-store-merge-dependence-limit", cl::Hidden, cl::init(10),
134     cl::desc("Limit the number of times for the same StoreNode and RootNode "
135              "to bail out in store merging dependence check"));
136 
/// On by default: allow narrowing load/op/store sequences.
137 static cl::opt<bool> EnableReduceLoadOpStoreWidth(
138     "combiner-reduce-load-op-store-width", cl::Hidden, cl::init(true),
139     cl::desc("DAG combiner enable reducing the width of load/op/store "
140              "sequence"));
/// Off by default: when set, the load/op/store width reduction skips its
/// "is narrowing profitable" check.
141 static cl::opt<bool> ReduceLoadOpStoreWidthForceNarrowingProfitable(
142     "combiner-reduce-load-op-store-width-force-narrowing-profitable",
143     cl::Hidden, cl::init(false),
144     cl::desc("DAG combiner force override the narrowing profitable check when "
145              "reducing the width of load/op/store sequences"));
146 
/// On by default: allow rewriting a load/<replace bytes>/store sequence as a
/// narrower store.
147 static cl::opt<bool> EnableShrinkLoadReplaceStoreWithStore(
148     "combiner-shrink-load-replace-store-with-store", cl::Hidden, cl::init(true),
149     cl::desc("DAG combiner enable load/<replace bytes>/store with "
150              "a narrower store"));
151 
/// Off by default: when set, the DAGCombiner constructor sets
/// DisableGenericCombines, turning the generic combines off.
152 static cl::opt<bool> DisableCombines("combiner-disabled", cl::Hidden,
153                                      cl::init(false),
154                                      cl::desc("Disable the DAG combiner"));
155 
156 namespace {
157 
158   class DAGCombiner {
159     SelectionDAG &DAG;
160     const TargetLowering &TLI;
161     const SelectionDAGTargetInfo *STI;
162     CombineLevel Level = BeforeLegalizeTypes;
163     CodeGenOptLevel OptLevel;
164     bool LegalDAG = false;
165     bool LegalOperations = false;
166     bool LegalTypes = false;
167     bool ForCodeSize;
168     bool DisableGenericCombines;
169 
170     /// Worklist of all of the nodes that need to be simplified.
171     ///
172     /// This must behave as a stack -- new nodes to process are pushed onto the
173     /// back and when processing we pop off of the back.
174     ///
175     /// The worklist will not contain duplicates but may contain null entries
176     /// due to nodes being deleted from the underlying DAG. For fast lookup and
177     /// deduplication, the index of the node in this vector is stored in the
178     /// node in SDNode::CombinerWorklistIndex.
179     SmallVector<SDNode *, 64> Worklist;
180 
181     /// This records all nodes attempted to be added to the worklist since we
182     /// considered a new worklist entry. As we do not add duplicate nodes
183     /// to the worklist, this is different from the tail of the worklist.
184     SmallSetVector<SDNode *, 32> PruningList;
185 
186     /// Map from candidate StoreNode to the pair of RootNode and count.
187     /// The count is used to track how many times we have seen the StoreNode
188     /// with the same RootNode bail out in dependence check. If we have seen
189     /// the bail out for the same pair many times over a limit, we won't
190     /// consider the StoreNode with the same RootNode as store merging
191     /// candidate again.
192     DenseMap<SDNode *, std::pair<SDNode *, unsigned>> StoreRootCountMap;
193 
194     // BatchAA - Used for DAG load/store alias analysis.
195     BatchAAResults *BatchAA;
196 
197     /// This caches all chains that have already been processed in
198     /// DAGCombiner::getStoreMergeCandidates() and found to have no mergeable
199     /// store candidates.
200     SmallPtrSet<SDNode *, 4> ChainsWithoutMergeableStores;
201 
202     /// When an instruction is simplified, add all users of the instruction to
203     /// the work lists because they might get more simplified now.
AddUsersToWorklist(SDNode * N)204     void AddUsersToWorklist(SDNode *N) {
205       for (SDNode *Node : N->users())
206         AddToWorklist(Node);
207     }
208 
209     /// Convenient shorthand to add a node and all of its user to the worklist.
AddToWorklistWithUsers(SDNode * N)210     void AddToWorklistWithUsers(SDNode *N) {
211       AddUsersToWorklist(N);
212       AddToWorklist(N);
213     }
214 
215     // Prune potentially dangling nodes. This is called after
216     // any visit to a node, but should also be called during a visit after any
217     // failed combine which may have created a DAG node.
clearAddedDanglingWorklistEntries()218     void clearAddedDanglingWorklistEntries() {
219       // Check any nodes added to the worklist to see if they are prunable.
220       while (!PruningList.empty()) {
221         auto *N = PruningList.pop_back_val();
222         if (N->use_empty())
223           recursivelyDeleteUnusedNodes(N);
224       }
225     }
226 
getNextWorklistEntry()227     SDNode *getNextWorklistEntry() {
228       // Before we do any work, remove nodes that are not in use.
229       clearAddedDanglingWorklistEntries();
230       SDNode *N = nullptr;
231       // The Worklist holds the SDNodes in order, but it may contain null
232       // entries.
233       while (!N && !Worklist.empty()) {
234         N = Worklist.pop_back_val();
235       }
236 
237       if (N) {
238         assert(N->getCombinerWorklistIndex() >= 0 &&
239                "Found a worklist entry without a corresponding map entry!");
240         // Set to -2 to indicate that we combined the node.
241         N->setCombinerWorklistIndex(-2);
242       }
243       return N;
244     }
245 
246     /// Call the node-specific routine that folds each particular type of node.
247     SDValue visit(SDNode *N);
248 
249   public:
/// Build a combiner for DAG \p D. \p BatchAA is used for load/store alias
/// queries and \p OL is the optimization level.
DAGCombiner(SelectionDAG &D, BatchAAResults *BatchAA, CodeGenOptLevel OL)
    : DAG(D), TLI(D.getTargetLoweringInfo()),
      STI(D.getSubtarget().getSelectionDAGInfo()), OptLevel(OL),
      ForCodeSize(D.shouldOptForSize()),
      DisableGenericCombines(DisableCombines ||
                             (STI && STI->disableGenericCombines(OptLevel))),
      BatchAA(BatchAA), MaximumLegalStoreInBits(0) {
  // Record the size of the widest legal type. The known-minimum size is used
  // because that is all that can be guaranteed for scalable vector types.
  for (MVT VT : MVT::all_valuetypes()) {
    // Keep this check order: getSizeInBits() must only be queried on legal,
    // non-Other types.
    if (!EVT(VT).isSimple() || VT == MVT::Other || !TLI.isTypeLegal(EVT(VT)))
      continue;
    uint64_t MinSizeInBits = VT.getSizeInBits().getKnownMinValue();
    if (MinSizeInBits >= MaximumLegalStoreInBits)
      MaximumLegalStoreInBits = MinSizeInBits;
  }
}
267 
ConsiderForPruning(SDNode * N)268     void ConsiderForPruning(SDNode *N) {
269       // Mark this for potential pruning.
270       PruningList.insert(N);
271     }
272 
273     /// Add to the worklist making sure its instance is at the back (next to be
274     /// processed.)
AddToWorklist(SDNode * N,bool IsCandidateForPruning=true,bool SkipIfCombinedBefore=false)275     void AddToWorklist(SDNode *N, bool IsCandidateForPruning = true,
276                        bool SkipIfCombinedBefore = false) {
277       assert(N->getOpcode() != ISD::DELETED_NODE &&
278              "Deleted Node added to Worklist");
279 
280       // Skip handle nodes as they can't usefully be combined and confuse the
281       // zero-use deletion strategy.
282       if (N->getOpcode() == ISD::HANDLENODE)
283         return;
284 
285       if (SkipIfCombinedBefore && N->getCombinerWorklistIndex() == -2)
286         return;
287 
288       if (IsCandidateForPruning)
289         ConsiderForPruning(N);
290 
291       if (N->getCombinerWorklistIndex() < 0) {
292         N->setCombinerWorklistIndex(Worklist.size());
293         Worklist.push_back(N);
294       }
295     }
296 
297     /// Remove all instances of N from the worklist.
removeFromWorklist(SDNode * N)298     void removeFromWorklist(SDNode *N) {
299       PruningList.remove(N);
300       StoreRootCountMap.erase(N);
301 
302       int WorklistIndex = N->getCombinerWorklistIndex();
303       // If not in the worklist, the index might be -1 or -2 (was combined
304       // before). As the node gets deleted anyway, there's no need to update
305       // the index.
306       if (WorklistIndex < 0)
307         return; // Not in the worklist.
308 
309       // Null out the entry rather than erasing it to avoid a linear operation.
310       Worklist[WorklistIndex] = nullptr;
311       N->setCombinerWorklistIndex(-1);
312     }
313 
314     void deleteAndRecombine(SDNode *N);
315     bool recursivelyDeleteUnusedNodes(SDNode *N);
316 
317     /// Replaces all uses of the results of one DAG node with new values.
318     SDValue CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
319                       bool AddTo = true);
320 
321     /// Replaces all uses of the results of one DAG node with new values.
CombineTo(SDNode * N,SDValue Res,bool AddTo=true)322     SDValue CombineTo(SDNode *N, SDValue Res, bool AddTo = true) {
323       return CombineTo(N, &Res, 1, AddTo);
324     }
325 
326     /// Replaces all uses of the results of one DAG node with new values.
CombineTo(SDNode * N,SDValue Res0,SDValue Res1,bool AddTo=true)327     SDValue CombineTo(SDNode *N, SDValue Res0, SDValue Res1,
328                       bool AddTo = true) {
329       SDValue To[] = { Res0, Res1 };
330       return CombineTo(N, To, 2, AddTo);
331     }
332 
333     void CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO);
334 
335   private:
336     unsigned MaximumLegalStoreInBits;
337 
338     /// Check the specified integer node value to see if it can be simplified or
339     /// if things it uses can be simplified by bit propagation.
340     /// If so, return true.
SimplifyDemandedBits(SDValue Op)341     bool SimplifyDemandedBits(SDValue Op) {
342       unsigned BitWidth = Op.getScalarValueSizeInBits();
343       APInt DemandedBits = APInt::getAllOnes(BitWidth);
344       return SimplifyDemandedBits(Op, DemandedBits);
345     }
346 
SimplifyDemandedBits(SDValue Op,const APInt & DemandedBits)347     bool SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits) {
348       EVT VT = Op.getValueType();
349       APInt DemandedElts = VT.isFixedLengthVector()
350                                ? APInt::getAllOnes(VT.getVectorNumElements())
351                                : APInt(1, 1);
352       return SimplifyDemandedBits(Op, DemandedBits, DemandedElts, false);
353     }
354 
355     /// Check the specified vector node value to see if it can be simplified or
356     /// if things it uses can be simplified as it only uses some of the
357     /// elements. If so, return true.
SimplifyDemandedVectorElts(SDValue Op)358     bool SimplifyDemandedVectorElts(SDValue Op) {
359       // TODO: For now just pretend it cannot be simplified.
360       if (Op.getValueType().isScalableVector())
361         return false;
362 
363       unsigned NumElts = Op.getValueType().getVectorNumElements();
364       APInt DemandedElts = APInt::getAllOnes(NumElts);
365       return SimplifyDemandedVectorElts(Op, DemandedElts);
366     }
367 
368     bool SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
369                               const APInt &DemandedElts,
370                               bool AssumeSingleUse = false);
371     bool SimplifyDemandedVectorElts(SDValue Op, const APInt &DemandedElts,
372                                     bool AssumeSingleUse = false);
373 
374     bool CombineToPreIndexedLoadStore(SDNode *N);
375     bool CombineToPostIndexedLoadStore(SDNode *N);
376     SDValue SplitIndexingFromLoad(LoadSDNode *LD);
377     bool SliceUpLoad(SDNode *N);
378 
379     // Looks up the chain to find a unique (unaliased) store feeding the passed
380     // load. If no such store is found, returns a nullptr.
381     // Note: This will look past a CALLSEQ_START if the load is chained to it
382     //       so that it can find stack stores for byval params.
383     StoreSDNode *getUniqueStoreFeeding(LoadSDNode *LD, int64_t &Offset);
384     // Scalars have size 0 to distinguish from singleton vectors.
385     SDValue ForwardStoreValueToDirectLoad(LoadSDNode *LD);
386     bool getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val);
387     bool extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val);
388 
389     void ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad);
390     SDValue PromoteOperand(SDValue Op, EVT PVT, bool &Replace);
391     SDValue SExtPromoteOperand(SDValue Op, EVT PVT);
392     SDValue ZExtPromoteOperand(SDValue Op, EVT PVT);
393     SDValue PromoteIntBinOp(SDValue Op);
394     SDValue PromoteIntShiftOp(SDValue Op);
395     SDValue PromoteExtend(SDValue Op);
396     bool PromoteLoad(SDValue Op);
397 
398     SDValue foldShiftToAvg(SDNode *N);
399     // Fold `a bitwiseop (~b +/- c)` -> `a bitwiseop ~(b -/+ c)`
400     SDValue foldBitwiseOpWithNeg(SDNode *N, const SDLoc &DL, EVT VT);
401 
402     SDValue combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
403                                 SDValue RHS, SDValue True, SDValue False,
404                                 ISD::CondCode CC);
405 
406     /// Call the node-specific routine that knows how to fold each
407     /// particular type of node. If that doesn't do anything, try the
408     /// target-specific DAG combines.
409     SDValue combine(SDNode *N);
410 
411     // Visitation implementation - Implement dag node combining for different
412     // node types.  The semantics are as follows:
413     // Return Value:
414     //   SDValue.getNode() == 0 - No change was made
415     //   SDValue.getNode() == N - N was replaced, is dead and has been handled.
416     //   otherwise              - N should be replaced by the returned Operand.
417     //
418     SDValue visitTokenFactor(SDNode *N);
419     SDValue visitMERGE_VALUES(SDNode *N);
420     SDValue visitADD(SDNode *N);
421     SDValue visitADDLike(SDNode *N);
422     SDValue visitADDLikeCommutative(SDValue N0, SDValue N1,
423                                     SDNode *LocReference);
424     SDValue visitPTRADD(SDNode *N);
425     SDValue visitSUB(SDNode *N);
426     SDValue visitADDSAT(SDNode *N);
427     SDValue visitSUBSAT(SDNode *N);
428     SDValue visitADDC(SDNode *N);
429     SDValue visitADDO(SDNode *N);
430     SDValue visitUADDOLike(SDValue N0, SDValue N1, SDNode *N);
431     SDValue visitSUBC(SDNode *N);
432     SDValue visitSUBO(SDNode *N);
433     SDValue visitADDE(SDNode *N);
434     SDValue visitUADDO_CARRY(SDNode *N);
435     SDValue visitSADDO_CARRY(SDNode *N);
436     SDValue visitUADDO_CARRYLike(SDValue N0, SDValue N1, SDValue CarryIn,
437                                  SDNode *N);
438     SDValue visitSADDO_CARRYLike(SDValue N0, SDValue N1, SDValue CarryIn,
439                                  SDNode *N);
440     SDValue visitSUBE(SDNode *N);
441     SDValue visitUSUBO_CARRY(SDNode *N);
442     SDValue visitSSUBO_CARRY(SDNode *N);
443     template <class MatchContextClass> SDValue visitMUL(SDNode *N);
444     SDValue visitMULFIX(SDNode *N);
445     SDValue useDivRem(SDNode *N);
446     SDValue visitSDIV(SDNode *N);
447     SDValue visitSDIVLike(SDValue N0, SDValue N1, SDNode *N);
448     SDValue visitUDIV(SDNode *N);
449     SDValue visitUDIVLike(SDValue N0, SDValue N1, SDNode *N);
450     SDValue visitREM(SDNode *N);
451     SDValue visitMULHU(SDNode *N);
452     SDValue visitMULHS(SDNode *N);
453     SDValue visitAVG(SDNode *N);
454     SDValue visitABD(SDNode *N);
455     SDValue visitSMUL_LOHI(SDNode *N);
456     SDValue visitUMUL_LOHI(SDNode *N);
457     SDValue visitMULO(SDNode *N);
458     SDValue visitIMINMAX(SDNode *N);
459     SDValue visitAND(SDNode *N);
460     SDValue visitANDLike(SDValue N0, SDValue N1, SDNode *N);
461     SDValue visitOR(SDNode *N);
462     SDValue visitORLike(SDValue N0, SDValue N1, const SDLoc &DL);
463     SDValue visitXOR(SDNode *N);
464     SDValue SimplifyVCastOp(SDNode *N, const SDLoc &DL);
465     SDValue SimplifyVBinOp(SDNode *N, const SDLoc &DL);
466     SDValue visitSHL(SDNode *N);
467     SDValue visitSRA(SDNode *N);
468     SDValue visitSRL(SDNode *N);
469     SDValue visitFunnelShift(SDNode *N);
470     SDValue visitSHLSAT(SDNode *N);
471     SDValue visitRotate(SDNode *N);
472     SDValue visitABS(SDNode *N);
473     SDValue visitBSWAP(SDNode *N);
474     SDValue visitBITREVERSE(SDNode *N);
475     SDValue visitCTLZ(SDNode *N);
476     SDValue visitCTLZ_ZERO_UNDEF(SDNode *N);
477     SDValue visitCTTZ(SDNode *N);
478     SDValue visitCTTZ_ZERO_UNDEF(SDNode *N);
479     SDValue visitCTPOP(SDNode *N);
480     SDValue visitSELECT(SDNode *N);
481     SDValue visitVSELECT(SDNode *N);
482     SDValue visitVP_SELECT(SDNode *N);
483     SDValue visitSELECT_CC(SDNode *N);
484     SDValue visitSETCC(SDNode *N);
485     SDValue visitSETCCCARRY(SDNode *N);
486     SDValue visitSIGN_EXTEND(SDNode *N);
487     SDValue visitZERO_EXTEND(SDNode *N);
488     SDValue visitANY_EXTEND(SDNode *N);
489     SDValue visitAssertExt(SDNode *N);
490     SDValue visitAssertAlign(SDNode *N);
491     SDValue visitSIGN_EXTEND_INREG(SDNode *N);
492     SDValue visitEXTEND_VECTOR_INREG(SDNode *N);
493     SDValue visitTRUNCATE(SDNode *N);
494     SDValue visitTRUNCATE_USAT_U(SDNode *N);
495     SDValue visitBITCAST(SDNode *N);
496     SDValue visitFREEZE(SDNode *N);
497     SDValue visitBUILD_PAIR(SDNode *N);
498     SDValue visitFADD(SDNode *N);
499     SDValue visitVP_FADD(SDNode *N);
500     SDValue visitVP_FSUB(SDNode *N);
501     SDValue visitSTRICT_FADD(SDNode *N);
502     SDValue visitFSUB(SDNode *N);
503     SDValue visitFMUL(SDNode *N);
504     template <class MatchContextClass> SDValue visitFMA(SDNode *N);
505     SDValue visitFMAD(SDNode *N);
506     SDValue visitFDIV(SDNode *N);
507     SDValue visitFREM(SDNode *N);
508     SDValue visitFSQRT(SDNode *N);
509     SDValue visitFCOPYSIGN(SDNode *N);
510     SDValue visitFPOW(SDNode *N);
511     SDValue visitFCANONICALIZE(SDNode *N);
512     SDValue visitSINT_TO_FP(SDNode *N);
513     SDValue visitUINT_TO_FP(SDNode *N);
514     SDValue visitFP_TO_SINT(SDNode *N);
515     SDValue visitFP_TO_UINT(SDNode *N);
516     SDValue visitXROUND(SDNode *N);
517     SDValue visitFP_ROUND(SDNode *N);
518     SDValue visitFP_EXTEND(SDNode *N);
519     SDValue visitFNEG(SDNode *N);
520     SDValue visitFABS(SDNode *N);
521     SDValue visitFCEIL(SDNode *N);
522     SDValue visitFTRUNC(SDNode *N);
523     SDValue visitFFREXP(SDNode *N);
524     SDValue visitFFLOOR(SDNode *N);
525     SDValue visitFMinMax(SDNode *N);
526     SDValue visitBRCOND(SDNode *N);
527     SDValue visitBR_CC(SDNode *N);
528     SDValue visitLOAD(SDNode *N);
529 
530     SDValue replaceStoreChain(StoreSDNode *ST, SDValue BetterChain);
531     SDValue replaceStoreOfFPConstant(StoreSDNode *ST);
532     SDValue replaceStoreOfInsertLoad(StoreSDNode *ST);
533 
534     bool refineExtractVectorEltIntoMultipleNarrowExtractVectorElts(SDNode *N);
535 
536     SDValue visitSTORE(SDNode *N);
537     SDValue visitATOMIC_STORE(SDNode *N);
538     SDValue visitLIFETIME_END(SDNode *N);
539     SDValue visitINSERT_VECTOR_ELT(SDNode *N);
540     SDValue visitEXTRACT_VECTOR_ELT(SDNode *N);
541     SDValue visitBUILD_VECTOR(SDNode *N);
542     SDValue visitCONCAT_VECTORS(SDNode *N);
543     SDValue visitEXTRACT_SUBVECTOR(SDNode *N);
544     SDValue visitVECTOR_SHUFFLE(SDNode *N);
545     SDValue visitSCALAR_TO_VECTOR(SDNode *N);
546     SDValue visitINSERT_SUBVECTOR(SDNode *N);
547     SDValue visitVECTOR_COMPRESS(SDNode *N);
548     SDValue visitMLOAD(SDNode *N);
549     SDValue visitMSTORE(SDNode *N);
550     SDValue visitMGATHER(SDNode *N);
551     SDValue visitMSCATTER(SDNode *N);
552     SDValue visitMHISTOGRAM(SDNode *N);
553     SDValue visitPARTIAL_REDUCE_MLA(SDNode *N);
554     SDValue visitVPGATHER(SDNode *N);
555     SDValue visitVPSCATTER(SDNode *N);
556     SDValue visitVP_STRIDED_LOAD(SDNode *N);
557     SDValue visitVP_STRIDED_STORE(SDNode *N);
558     SDValue visitFP_TO_FP16(SDNode *N);
559     SDValue visitFP16_TO_FP(SDNode *N);
560     SDValue visitFP_TO_BF16(SDNode *N);
561     SDValue visitBF16_TO_FP(SDNode *N);
562     SDValue visitVECREDUCE(SDNode *N);
563     SDValue visitVPOp(SDNode *N);
564     SDValue visitGET_FPENV_MEM(SDNode *N);
565     SDValue visitSET_FPENV_MEM(SDNode *N);
566 
567     template <class MatchContextClass>
568     SDValue visitFADDForFMACombine(SDNode *N);
569     template <class MatchContextClass>
570     SDValue visitFSUBForFMACombine(SDNode *N);
571     SDValue visitFMULForFMADistributiveCombine(SDNode *N);
572 
573     SDValue XformToShuffleWithZero(SDNode *N);
574     bool reassociationCanBreakAddressingModePattern(unsigned Opc,
575                                                     const SDLoc &DL,
576                                                     SDNode *N,
577                                                     SDValue N0,
578                                                     SDValue N1);
579     SDValue reassociateOpsCommutative(unsigned Opc, const SDLoc &DL, SDValue N0,
580                                       SDValue N1, SDNodeFlags Flags);
581     SDValue reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
582                            SDValue N1, SDNodeFlags Flags);
583     SDValue reassociateReduction(unsigned RedOpc, unsigned Opc, const SDLoc &DL,
584                                  EVT VT, SDValue N0, SDValue N1,
585                                  SDNodeFlags Flags = SDNodeFlags());
586 
587     SDValue visitShiftByConstant(SDNode *N);
588 
589     SDValue foldSelectOfConstants(SDNode *N);
590     SDValue foldVSelectOfConstants(SDNode *N);
591     SDValue foldBinOpIntoSelect(SDNode *BO);
592     bool SimplifySelectOps(SDNode *SELECT, SDValue LHS, SDValue RHS);
593     SDValue hoistLogicOpWithSameOpcodeHands(SDNode *N);
594     SDValue SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2);
595     SDValue SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
596                              SDValue N2, SDValue N3, ISD::CondCode CC,
597                              bool NotExtCompare = false);
598     SDValue convertSelectOfFPConstantsToLoadOffset(
599         const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2, SDValue N3,
600         ISD::CondCode CC);
601     SDValue foldSignChangeInBitcast(SDNode *N);
602     SDValue foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0, SDValue N1,
603                                    SDValue N2, SDValue N3, ISD::CondCode CC);
604     SDValue foldSelectOfBinops(SDNode *N);
605     SDValue foldSextSetcc(SDNode *N);
606     SDValue foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1,
607                               const SDLoc &DL);
608     SDValue foldSubToUSubSat(EVT DstVT, SDNode *N, const SDLoc &DL);
609     SDValue foldABSToABD(SDNode *N, const SDLoc &DL);
610     SDValue foldSelectToABD(SDValue LHS, SDValue RHS, SDValue True,
611                             SDValue False, ISD::CondCode CC, const SDLoc &DL);
612     SDValue unfoldMaskedMerge(SDNode *N);
613     SDValue unfoldExtremeBitClearingToShifts(SDNode *N);
614     SDValue SimplifySetCC(EVT VT, SDValue N0, SDValue N1, ISD::CondCode Cond,
615                           const SDLoc &DL, bool foldBooleans);
616     SDValue rebuildSetCC(SDValue N);
617 
618     bool isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS,
619                            SDValue &CC, bool MatchStrict = false) const;
620     bool isOneUseSetCC(SDValue N) const;
621 
622     SDValue foldAddToAvg(SDNode *N, const SDLoc &DL);
623     SDValue foldSubToAvg(SDNode *N, const SDLoc &DL);
624 
625     SDValue SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp,
626                                          unsigned HiOp);
627     SDValue CombineConsecutiveLoads(SDNode *N, EVT VT);
628     SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
629                                  const TargetLowering &TLI);
630     SDValue foldPartialReduceMLAMulOp(SDNode *N);
631     SDValue foldPartialReduceAdd(SDNode *N);
632 
633     SDValue CombineExtLoad(SDNode *N);
634     SDValue CombineZExtLogicopShiftLoad(SDNode *N);
635     SDValue combineRepeatedFPDivisors(SDNode *N);
636     SDValue combineFMulOrFDivWithIntPow2(SDNode *N);
637     SDValue replaceShuffleOfInsert(ShuffleVectorSDNode *Shuf);
638     SDValue mergeInsertEltWithShuffle(SDNode *N, unsigned InsIndex);
639     SDValue combineInsertEltToShuffle(SDNode *N, unsigned InsIndex);
640     SDValue combineInsertEltToLoad(SDNode *N, unsigned InsIndex);
641     SDValue BuildSDIV(SDNode *N);
642     SDValue BuildSDIVPow2(SDNode *N);
643     SDValue BuildUDIV(SDNode *N);
644     SDValue BuildSREMPow2(SDNode *N);
645     SDValue buildOptimizedSREM(SDValue N0, SDValue N1, SDNode *N);
646     SDValue BuildLogBase2(SDValue V, const SDLoc &DL,
647                           bool KnownNeverZero = false,
648                           bool InexpensiveOnly = false,
649                           std::optional<EVT> OutVT = std::nullopt);
650     SDValue BuildDivEstimate(SDValue N, SDValue Op, SDNodeFlags Flags);
651     SDValue buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags);
652     SDValue buildSqrtEstimate(SDValue Op, SDNodeFlags Flags);
653     SDValue buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags, bool Recip);
654     SDValue buildSqrtNROneConst(SDValue Arg, SDValue Est, unsigned Iterations,
655                                 SDNodeFlags Flags, bool Reciprocal);
656     SDValue buildSqrtNRTwoConst(SDValue Arg, SDValue Est, unsigned Iterations,
657                                 SDNodeFlags Flags, bool Reciprocal);
658     SDValue MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
659                                bool DemandHighBits = true);
660     SDValue MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1);
661     SDValue MatchRotatePosNeg(SDValue Shifted, SDValue Pos, SDValue Neg,
662                               SDValue InnerPos, SDValue InnerNeg, bool FromAdd,
663                               bool HasPos, unsigned PosOpcode,
664                               unsigned NegOpcode, const SDLoc &DL);
665     SDValue MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos, SDValue Neg,
666                               SDValue InnerPos, SDValue InnerNeg, bool FromAdd,
667                               bool HasPos, unsigned PosOpcode,
668                               unsigned NegOpcode, const SDLoc &DL);
669     SDValue MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL,
670                         bool FromAdd);
671     SDValue MatchLoadCombine(SDNode *N);
672     SDValue mergeTruncStores(StoreSDNode *N);
673     SDValue reduceLoadWidth(SDNode *N);
674     SDValue ReduceLoadOpStoreWidth(SDNode *N);
675     SDValue splitMergedValStore(StoreSDNode *ST);
676     SDValue TransformFPLoadStorePair(SDNode *N);
677     SDValue convertBuildVecZextToZext(SDNode *N);
678     SDValue convertBuildVecZextToBuildVecWithZeros(SDNode *N);
679     SDValue reduceBuildVecExtToExtBuildVec(SDNode *N);
680     SDValue reduceBuildVecTruncToBitCast(SDNode *N);
681     SDValue reduceBuildVecToShuffle(SDNode *N);
682     SDValue createBuildVecShuffle(const SDLoc &DL, SDNode *N,
683                                   ArrayRef<int> VectorMask, SDValue VecIn1,
684                                   SDValue VecIn2, unsigned LeftIdx,
685                                   bool DidSplitVec);
686     SDValue matchVSelectOpSizesWithSetCC(SDNode *Cast);
687 
688     /// Walk up chain skipping non-aliasing memory nodes,
689     /// looking for aliasing nodes and adding them to the Aliases vector.
690     void GatherAllAliases(SDNode *N, SDValue OriginalChain,
691                           SmallVectorImpl<SDValue> &Aliases);
692 
693     /// Return true if there is any possibility that the two addresses overlap.
694     bool mayAlias(SDNode *Op0, SDNode *Op1) const;
695 
696     /// Walk up chain skipping non-aliasing memory nodes, looking for a better
697     /// chain (aliasing node.)
698     SDValue FindBetterChain(SDNode *N, SDValue Chain);
699 
700     /// Try to replace a store and any possibly adjacent stores on
701     /// consecutive chains with better chains. Return true only if St is
702     /// replaced.
703     ///
704     /// Notice that other chains may still be replaced even if the function
705     /// returns false.
706     bool findBetterNeighborChains(StoreSDNode *St);
707 
708     // Helper for findBetterNeighborChains. Walk up store chain add additional
709     // chained stores that do not overlap and can be parallelized.
710     bool parallelizeChainedStores(StoreSDNode *St);
711 
712     /// Holds a pointer to an LSBaseSDNode as well as information on where it
713     /// is located in a sequence of memory operations connected by a chain.
714     struct MemOpLink {
715       // Ptr to the mem node.
716       LSBaseSDNode *MemNode;
717 
718       // Offset from the base ptr.
719       int64_t OffsetFromBase;
720 
MemOpLink__anon666e37100111::DAGCombiner::MemOpLink721       MemOpLink(LSBaseSDNode *N, int64_t Offset)
722           : MemNode(N), OffsetFromBase(Offset) {}
723     };
724 
725     // Classify the origin of a stored value.
726     enum class StoreSource { Unknown, Constant, Extract, Load };
getStoreSource(SDValue StoreVal)727     StoreSource getStoreSource(SDValue StoreVal) {
728       switch (StoreVal.getOpcode()) {
729       case ISD::Constant:
730       case ISD::ConstantFP:
731         return StoreSource::Constant;
732       case ISD::BUILD_VECTOR:
733         if (ISD::isBuildVectorOfConstantSDNodes(StoreVal.getNode()) ||
734             ISD::isBuildVectorOfConstantFPSDNodes(StoreVal.getNode()))
735           return StoreSource::Constant;
736         return StoreSource::Unknown;
737       case ISD::EXTRACT_VECTOR_ELT:
738       case ISD::EXTRACT_SUBVECTOR:
739         return StoreSource::Extract;
740       case ISD::LOAD:
741         return StoreSource::Load;
742       default:
743         return StoreSource::Unknown;
744       }
745     }
746 
747     /// This is a helper function for visitMUL to check the profitability
748     /// of folding (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2).
749     /// MulNode is the original multiply, AddNode is (add x, c1),
750     /// and ConstNode is c2.
751     bool isMulAddWithConstProfitable(SDNode *MulNode, SDValue AddNode,
752                                      SDValue ConstNode);
753 
754     /// This is a helper function for visitAND and visitZERO_EXTEND.  Returns
755     /// true if the (and (load x) c) pattern matches an extload.  ExtVT returns
756     /// the type of the loaded value to be extended.
757     bool isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
758                           EVT LoadResultTy, EVT &ExtVT);
759 
760     /// Helper function to calculate whether the given Load/Store can have its
761     /// width reduced to ExtVT.
762     bool isLegalNarrowLdSt(LSBaseSDNode *LDSTN, ISD::LoadExtType ExtType,
763                            EVT &MemVT, unsigned ShAmt = 0);
764 
765     /// Used by BackwardsPropagateMask to find suitable loads.
766     bool SearchForAndLoads(SDNode *N, SmallVectorImpl<LoadSDNode*> &Loads,
767                            SmallPtrSetImpl<SDNode*> &NodesWithConsts,
768                            ConstantSDNode *Mask, SDNode *&NodeToMask);
769     /// Attempt to propagate a given AND node back to load leaves so that they
770     /// can be combined into narrow loads.
771     bool BackwardsPropagateMask(SDNode *N);
772 
773     /// Helper function for mergeConsecutiveStores which merges the component
774     /// store chains.
775     SDValue getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes,
776                                 unsigned NumStores);
777 
778     /// Helper function for mergeConsecutiveStores which checks if all the store
779     /// nodes have the same underlying object. We can still reuse the first
780     /// store's pointer info if all the stores are from the same object.
781     bool hasSameUnderlyingObj(ArrayRef<MemOpLink> StoreNodes);
782 
783     /// This is a helper function for mergeConsecutiveStores. When the source
784     /// elements of the consecutive stores are all constants or all extracted
785     /// vector elements, try to merge them into one larger store introducing
786     /// bitcasts if necessary.  \return True if a merged store was created.
787     bool mergeStoresOfConstantsOrVecElts(SmallVectorImpl<MemOpLink> &StoreNodes,
788                                          EVT MemVT, unsigned NumStores,
789                                          bool IsConstantSrc, bool UseVector,
790                                          bool UseTrunc);
791 
792     /// This is a helper function for mergeConsecutiveStores. Stores that
793     /// potentially may be merged with St are placed in StoreNodes. On success,
794     /// returns a chain predecessor to all store candidates.
795     SDNode *getStoreMergeCandidates(StoreSDNode *St,
796                                     SmallVectorImpl<MemOpLink> &StoreNodes);
797 
798     /// Helper function for mergeConsecutiveStores. Checks if candidate stores
799     /// have indirect dependency through their operands. RootNode is the
800     /// predecessor to all stores calculated by getStoreMergeCandidates and is
801     /// used to prune the dependency check. \return True if safe to merge.
802     bool checkMergeStoreCandidatesForDependencies(
803         SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores,
804         SDNode *RootNode);
805 
806     /// Helper function for tryStoreMergeOfLoads. Checks if the load/store
807     /// chain has a call in it. \return True if a call is found.
808     bool hasCallInLdStChain(StoreSDNode *St, LoadSDNode *Ld);
809 
810     /// This is a helper function for mergeConsecutiveStores. Given a list of
811     /// store candidates, find the first N that are consecutive in memory.
812     /// Returns 0 if there are not at least 2 consecutive stores to try merging.
813     unsigned getConsecutiveStores(SmallVectorImpl<MemOpLink> &StoreNodes,
814                                   int64_t ElementSizeBytes) const;
815 
816     /// This is a helper function for mergeConsecutiveStores. It is used for
817     /// store chains that are composed entirely of constant values.
818     bool tryStoreMergeOfConstants(SmallVectorImpl<MemOpLink> &StoreNodes,
819                                   unsigned NumConsecutiveStores,
820                                   EVT MemVT, SDNode *Root, bool AllowVectors);
821 
822     /// This is a helper function for mergeConsecutiveStores. It is used for
823     /// store chains that are composed entirely of extracted vector elements.
824     /// When extracting multiple vector elements, try to store them in one
825     /// vector store rather than a sequence of scalar stores.
826     bool tryStoreMergeOfExtracts(SmallVectorImpl<MemOpLink> &StoreNodes,
827                                  unsigned NumConsecutiveStores, EVT MemVT,
828                                  SDNode *Root);
829 
830     /// This is a helper function for mergeConsecutiveStores. It is used for
831     /// store chains that are composed entirely of loaded values.
832     bool tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes,
833                               unsigned NumConsecutiveStores, EVT MemVT,
834                               SDNode *Root, bool AllowVectors,
835                               bool IsNonTemporalStore, bool IsNonTemporalLoad);
836 
837     /// Merge consecutive store operations into a wide store.
838     /// This optimization uses wide integers or vectors when possible.
839     /// \return true if stores were merged.
840     bool mergeConsecutiveStores(StoreSDNode *St);
841 
842     /// Try to transform a truncation where C is a constant:
843     ///     (trunc (and X, C)) -> (and (trunc X), (trunc C))
844     ///
845     /// \p N needs to be a truncation and its first operand an AND. Other
846     /// requirements are checked by the function (e.g. that trunc is
847     /// single-use) and if missed an empty SDValue is returned.
848     SDValue distributeTruncateThroughAnd(SDNode *N);
849 
850     /// Helper function to determine whether the target supports operation
851     /// given by \p Opcode for type \p VT, that is, whether the operation
852     /// is legal or custom before legalizing operations, and whether is
853     /// legal (but not custom) after legalization.
hasOperation(unsigned Opcode,EVT VT)854     bool hasOperation(unsigned Opcode, EVT VT) {
855       return TLI.isOperationLegalOrCustom(Opcode, VT, LegalOperations);
856     }
857 
hasUMin(EVT VT) const858     bool hasUMin(EVT VT) const {
859       auto LK = TLI.getTypeConversion(*DAG.getContext(), VT);
860       return (LK.first == TargetLoweringBase::TypeLegal ||
861               LK.first == TargetLoweringBase::TypePromoteInteger) &&
862              TLI.isOperationLegal(ISD::UMIN, LK.second);
863     }
864 
865   public:
866     /// Runs the dag combiner on all nodes in the work list
867     void Run(CombineLevel AtLevel);
868 
getDAG() const869     SelectionDAG &getDAG() const { return DAG; }
870 
871     /// Convenience wrapper around TargetLowering::getShiftAmountTy.
getShiftAmountTy(EVT LHSTy)872     EVT getShiftAmountTy(EVT LHSTy) {
873       return TLI.getShiftAmountTy(LHSTy, DAG.getDataLayout());
874     }
875 
876     /// This method returns true if we are running before type legalization or
877     /// if the specified VT is legal.
isTypeLegal(const EVT & VT)878     bool isTypeLegal(const EVT &VT) {
879       if (!LegalTypes) return true;
880       return TLI.isTypeLegal(VT);
881     }
882 
883     /// Convenience wrapper around TargetLowering::getSetCCResultType
getSetCCResultType(EVT VT) const884     EVT getSetCCResultType(EVT VT) const {
885       return TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
886     }
887 
888     void ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs,
889                          SDValue OrigLoad, SDValue ExtLoad,
890                          ISD::NodeType ExtType);
891   };
892 
893 /// This class is a DAGUpdateListener that removes any deleted
894 /// nodes from the worklist.
895 class WorklistRemover : public SelectionDAG::DAGUpdateListener {
896   DAGCombiner &DC;
897 
898 public:
WorklistRemover(DAGCombiner & dc)899   explicit WorklistRemover(DAGCombiner &dc)
900     : SelectionDAG::DAGUpdateListener(dc.getDAG()), DC(dc) {}
901 
NodeDeleted(SDNode * N,SDNode * E)902   void NodeDeleted(SDNode *N, SDNode *E) override {
903     DC.removeFromWorklist(N);
904   }
905 };
906 
907 class WorklistInserter : public SelectionDAG::DAGUpdateListener {
908   DAGCombiner &DC;
909 
910 public:
WorklistInserter(DAGCombiner & dc)911   explicit WorklistInserter(DAGCombiner &dc)
912       : SelectionDAG::DAGUpdateListener(dc.getDAG()), DC(dc) {}
913 
914   // FIXME: Ideally we could add N to the worklist, but this causes exponential
915   //        compile time costs in large DAGs, e.g. Halide.
NodeInserted(SDNode * N)916   void NodeInserted(SDNode *N) override { DC.ConsiderForPruning(N); }
917 };
918 
919 } // end anonymous namespace
920 
921 //===----------------------------------------------------------------------===//
922 //  TargetLowering::DAGCombinerInfo implementation
923 //===----------------------------------------------------------------------===//
924 
AddToWorklist(SDNode * N)925 void TargetLowering::DAGCombinerInfo::AddToWorklist(SDNode *N) {
926   ((DAGCombiner*)DC)->AddToWorklist(N);
927 }
928 
929 SDValue TargetLowering::DAGCombinerInfo::
CombineTo(SDNode * N,ArrayRef<SDValue> To,bool AddTo)930 CombineTo(SDNode *N, ArrayRef<SDValue> To, bool AddTo) {
931   return ((DAGCombiner*)DC)->CombineTo(N, &To[0], To.size(), AddTo);
932 }
933 
934 SDValue TargetLowering::DAGCombinerInfo::
CombineTo(SDNode * N,SDValue Res,bool AddTo)935 CombineTo(SDNode *N, SDValue Res, bool AddTo) {
936   return ((DAGCombiner*)DC)->CombineTo(N, Res, AddTo);
937 }
938 
939 SDValue TargetLowering::DAGCombinerInfo::
CombineTo(SDNode * N,SDValue Res0,SDValue Res1,bool AddTo)940 CombineTo(SDNode *N, SDValue Res0, SDValue Res1, bool AddTo) {
941   return ((DAGCombiner*)DC)->CombineTo(N, Res0, Res1, AddTo);
942 }
943 
944 bool TargetLowering::DAGCombinerInfo::
recursivelyDeleteUnusedNodes(SDNode * N)945 recursivelyDeleteUnusedNodes(SDNode *N) {
946   return ((DAGCombiner*)DC)->recursivelyDeleteUnusedNodes(N);
947 }
948 
949 void TargetLowering::DAGCombinerInfo::
CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt & TLO)950 CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) {
951   return ((DAGCombiner*)DC)->CommitTargetLoweringOpt(TLO);
952 }
953 
954 //===----------------------------------------------------------------------===//
955 // Helper Functions
956 //===----------------------------------------------------------------------===//
957 
deleteAndRecombine(SDNode * N)958 void DAGCombiner::deleteAndRecombine(SDNode *N) {
959   removeFromWorklist(N);
960 
961   // If the operands of this node are only used by the node, they will now be
962   // dead. Make sure to re-visit them and recursively delete dead nodes.
963   for (const SDValue &Op : N->ops())
964     // For an operand generating multiple values, one of the values may
965     // become dead allowing further simplification (e.g. split index
966     // arithmetic from an indexed load).
967     if (Op->hasOneUse() || Op->getNumValues() > 1)
968       AddToWorklist(Op.getNode());
969 
970   DAG.DeleteNode(N);
971 }
972 
973 // APInts must be the same size for most operations, this helper
974 // function zero extends the shorter of the pair so that they match.
975 // We provide an Offset so that we can create bitwidths that won't overflow.
zeroExtendToMatch(APInt & LHS,APInt & RHS,unsigned Offset=0)976 static void zeroExtendToMatch(APInt &LHS, APInt &RHS, unsigned Offset = 0) {
977   unsigned Bits = Offset + std::max(LHS.getBitWidth(), RHS.getBitWidth());
978   LHS = LHS.zext(Bits);
979   RHS = RHS.zext(Bits);
980 }
981 
982 // Return true if this node is a setcc, or is a select_cc
983 // that selects between the target values used for true and false, making it
984 // equivalent to a setcc. Also, set the incoming LHS, RHS, and CC references to
985 // the appropriate nodes based on the type of node we are checking. This
986 // simplifies life a bit for the callers.
isSetCCEquivalent(SDValue N,SDValue & LHS,SDValue & RHS,SDValue & CC,bool MatchStrict) const987 bool DAGCombiner::isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS,
988                                     SDValue &CC, bool MatchStrict) const {
989   if (N.getOpcode() == ISD::SETCC) {
990     LHS = N.getOperand(0);
991     RHS = N.getOperand(1);
992     CC  = N.getOperand(2);
993     return true;
994   }
995 
996   if (MatchStrict &&
997       (N.getOpcode() == ISD::STRICT_FSETCC ||
998        N.getOpcode() == ISD::STRICT_FSETCCS)) {
999     LHS = N.getOperand(1);
1000     RHS = N.getOperand(2);
1001     CC  = N.getOperand(3);
1002     return true;
1003   }
1004 
1005   if (N.getOpcode() != ISD::SELECT_CC || !TLI.isConstTrueVal(N.getOperand(2)) ||
1006       !TLI.isConstFalseVal(N.getOperand(3)))
1007     return false;
1008 
1009   if (TLI.getBooleanContents(N.getValueType()) ==
1010       TargetLowering::UndefinedBooleanContent)
1011     return false;
1012 
1013   LHS = N.getOperand(0);
1014   RHS = N.getOperand(1);
1015   CC  = N.getOperand(4);
1016   return true;
1017 }
1018 
1019 /// Return true if this is a SetCC-equivalent operation with only one use.
1020 /// If this is true, it allows the users to invert the operation for free when
1021 /// it is profitable to do so.
isOneUseSetCC(SDValue N) const1022 bool DAGCombiner::isOneUseSetCC(SDValue N) const {
1023   SDValue N0, N1, N2;
1024   if (isSetCCEquivalent(N, N0, N1, N2) && N->hasOneUse())
1025     return true;
1026   return false;
1027 }
1028 
isConstantSplatVectorMaskForType(SDNode * N,EVT ScalarTy)1029 static bool isConstantSplatVectorMaskForType(SDNode *N, EVT ScalarTy) {
1030   if (!ScalarTy.isSimple())
1031     return false;
1032 
1033   uint64_t MaskForTy = 0ULL;
1034   switch (ScalarTy.getSimpleVT().SimpleTy) {
1035   case MVT::i8:
1036     MaskForTy = 0xFFULL;
1037     break;
1038   case MVT::i16:
1039     MaskForTy = 0xFFFFULL;
1040     break;
1041   case MVT::i32:
1042     MaskForTy = 0xFFFFFFFFULL;
1043     break;
1044   default:
1045     return false;
1046     break;
1047   }
1048 
1049   APInt Val;
1050   if (ISD::isConstantSplatVector(N, Val))
1051     return Val.getLimitedValue() == MaskForTy;
1052 
1053   return false;
1054 }
1055 
1056 // Determines if it is a constant integer or a splat/build vector of constant
1057 // integers (and undefs).
1058 // Do not permit build vector implicit truncation.
isConstantOrConstantVector(SDValue N,bool NoOpaques=false)1059 static bool isConstantOrConstantVector(SDValue N, bool NoOpaques = false) {
1060   if (ConstantSDNode *Const = dyn_cast<ConstantSDNode>(N))
1061     return !(Const->isOpaque() && NoOpaques);
1062   if (N.getOpcode() != ISD::BUILD_VECTOR && N.getOpcode() != ISD::SPLAT_VECTOR)
1063     return false;
1064   unsigned BitWidth = N.getScalarValueSizeInBits();
1065   for (const SDValue &Op : N->op_values()) {
1066     if (Op.isUndef())
1067       continue;
1068     ConstantSDNode *Const = dyn_cast<ConstantSDNode>(Op);
1069     if (!Const || Const->getAPIntValue().getBitWidth() != BitWidth ||
1070         (Const->isOpaque() && NoOpaques))
1071       return false;
1072   }
1073   return true;
1074 }
1075 
1076 // Determines if a BUILD_VECTOR is composed of all-constants possibly mixed with
1077 // undef's.
isAnyConstantBuildVector(SDValue V,bool NoOpaques=false)1078 static bool isAnyConstantBuildVector(SDValue V, bool NoOpaques = false) {
1079   if (V.getOpcode() != ISD::BUILD_VECTOR)
1080     return false;
1081   return isConstantOrConstantVector(V, NoOpaques) ||
1082          ISD::isBuildVectorOfConstantFPSDNodes(V.getNode());
1083 }
1084 
1085 // Determine if this an indexed load with an opaque target constant index.
canSplitIdx(LoadSDNode * LD)1086 static bool canSplitIdx(LoadSDNode *LD) {
1087   return MaySplitLoadIndex &&
1088          (LD->getOperand(2).getOpcode() != ISD::TargetConstant ||
1089           !cast<ConstantSDNode>(LD->getOperand(2))->isOpaque());
1090 }
1091 
reassociationCanBreakAddressingModePattern(unsigned Opc,const SDLoc & DL,SDNode * N,SDValue N0,SDValue N1)1092 bool DAGCombiner::reassociationCanBreakAddressingModePattern(unsigned Opc,
1093                                                              const SDLoc &DL,
1094                                                              SDNode *N,
1095                                                              SDValue N0,
1096                                                              SDValue N1) {
1097   // Currently this only tries to ensure we don't undo the GEP splits done by
1098   // CodeGenPrepare when shouldConsiderGEPOffsetSplit is true. To ensure this,
1099   // we check if the following transformation would be problematic:
1100   // (load/store (add, (add, x, offset1), offset2)) ->
1101   // (load/store (add, x, offset1+offset2)).
1102 
1103   // (load/store (add, (add, x, y), offset2)) ->
1104   // (load/store (add, (add, x, offset2), y)).
1105 
1106   if (!N0.isAnyAdd())
1107     return false;
1108 
1109   // Check for vscale addressing modes.
1110   // (load/store (add/sub (add x, y), vscale))
1111   // (load/store (add/sub (add x, y), (lsl vscale, C)))
1112   // (load/store (add/sub (add x, y), (mul vscale, C)))
1113   if ((N1.getOpcode() == ISD::VSCALE ||
1114        ((N1.getOpcode() == ISD::SHL || N1.getOpcode() == ISD::MUL) &&
1115         N1.getOperand(0).getOpcode() == ISD::VSCALE &&
1116         isa<ConstantSDNode>(N1.getOperand(1)))) &&
1117       N1.getValueType().getFixedSizeInBits() <= 64) {
1118     int64_t ScalableOffset = N1.getOpcode() == ISD::VSCALE
1119                                  ? N1.getConstantOperandVal(0)
1120                                  : (N1.getOperand(0).getConstantOperandVal(0) *
1121                                     (N1.getOpcode() == ISD::SHL
1122                                          ? (1LL << N1.getConstantOperandVal(1))
1123                                          : N1.getConstantOperandVal(1)));
1124     if (Opc == ISD::SUB)
1125       ScalableOffset = -ScalableOffset;
1126     if (all_of(N->users(), [&](SDNode *Node) {
1127           if (auto *LoadStore = dyn_cast<MemSDNode>(Node);
1128               LoadStore && LoadStore->getBasePtr().getNode() == N) {
1129             TargetLoweringBase::AddrMode AM;
1130             AM.HasBaseReg = true;
1131             AM.ScalableOffset = ScalableOffset;
1132             EVT VT = LoadStore->getMemoryVT();
1133             unsigned AS = LoadStore->getAddressSpace();
1134             Type *AccessTy = VT.getTypeForEVT(*DAG.getContext());
1135             return TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy,
1136                                              AS);
1137           }
1138           return false;
1139         }))
1140       return true;
1141   }
1142 
1143   if (Opc != ISD::ADD && Opc != ISD::PTRADD)
1144     return false;
1145 
1146   auto *C2 = dyn_cast<ConstantSDNode>(N1);
1147   if (!C2)
1148     return false;
1149 
1150   const APInt &C2APIntVal = C2->getAPIntValue();
1151   if (C2APIntVal.getSignificantBits() > 64)
1152     return false;
1153 
1154   if (auto *C1 = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
1155     if (N0.hasOneUse())
1156       return false;
1157 
1158     const APInt &C1APIntVal = C1->getAPIntValue();
1159     const APInt CombinedValueIntVal = C1APIntVal + C2APIntVal;
1160     if (CombinedValueIntVal.getSignificantBits() > 64)
1161       return false;
1162     const int64_t CombinedValue = CombinedValueIntVal.getSExtValue();
1163 
1164     for (SDNode *Node : N->users()) {
1165       if (auto *LoadStore = dyn_cast<MemSDNode>(Node)) {
1166         // Is x[offset2] already not a legal addressing mode? If so then
1167         // reassociating the constants breaks nothing (we test offset2 because
1168         // that's the one we hope to fold into the load or store).
1169         TargetLoweringBase::AddrMode AM;
1170         AM.HasBaseReg = true;
1171         AM.BaseOffs = C2APIntVal.getSExtValue();
1172         EVT VT = LoadStore->getMemoryVT();
1173         unsigned AS = LoadStore->getAddressSpace();
1174         Type *AccessTy = VT.getTypeForEVT(*DAG.getContext());
1175         if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
1176           continue;
1177 
1178         // Would x[offset1+offset2] still be a legal addressing mode?
1179         AM.BaseOffs = CombinedValue;
1180         if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
1181           return true;
1182       }
1183     }
1184   } else {
1185     if (auto *GA = dyn_cast<GlobalAddressSDNode>(N0.getOperand(1)))
1186       if (GA->getOpcode() == ISD::GlobalAddress && TLI.isOffsetFoldingLegal(GA))
1187         return false;
1188 
1189     for (SDNode *Node : N->users()) {
1190       auto *LoadStore = dyn_cast<MemSDNode>(Node);
1191       if (!LoadStore)
1192         return false;
1193 
1194       // Is x[offset2] a legal addressing mode? If so then
1195       // reassociating the constants breaks address pattern
1196       TargetLoweringBase::AddrMode AM;
1197       AM.HasBaseReg = true;
1198       AM.BaseOffs = C2APIntVal.getSExtValue();
1199       EVT VT = LoadStore->getMemoryVT();
1200       unsigned AS = LoadStore->getAddressSpace();
1201       Type *AccessTy = VT.getTypeForEVT(*DAG.getContext());
1202       if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
1203         return false;
1204     }
1205     return true;
1206   }
1207 
1208   return false;
1209 }
1210 
1211 /// Helper for DAGCombiner::reassociateOps. Try to reassociate (Opc N0, N1) if
1212 /// \p N0 is the same kind of operation as \p Opc.
reassociateOpsCommutative(unsigned Opc,const SDLoc & DL,SDValue N0,SDValue N1,SDNodeFlags Flags)1213 SDValue DAGCombiner::reassociateOpsCommutative(unsigned Opc, const SDLoc &DL,
1214                                                SDValue N0, SDValue N1,
1215                                                SDNodeFlags Flags) {
1216   EVT VT = N0.getValueType();
1217 
1218   if (N0.getOpcode() != Opc)
1219     return SDValue();
1220 
1221   SDValue N00 = N0.getOperand(0);
1222   SDValue N01 = N0.getOperand(1);
1223 
1224   if (DAG.isConstantIntBuildVectorOrConstantInt(N01)) {
1225     SDNodeFlags NewFlags;
1226     if (N0.getOpcode() == ISD::ADD && N0->getFlags().hasNoUnsignedWrap() &&
1227         Flags.hasNoUnsignedWrap())
1228       NewFlags |= SDNodeFlags::NoUnsignedWrap;
1229 
1230     if (DAG.isConstantIntBuildVectorOrConstantInt(N1)) {
1231       // Reassociate: (op (op x, c1), c2) -> (op x, (op c1, c2))
1232       if (SDValue OpNode = DAG.FoldConstantArithmetic(Opc, DL, VT, {N01, N1})) {
1233         NewFlags.setDisjoint(Flags.hasDisjoint() &&
1234                              N0->getFlags().hasDisjoint());
1235         return DAG.getNode(Opc, DL, VT, N00, OpNode, NewFlags);
1236       }
1237       return SDValue();
1238     }
1239     if (TLI.isReassocProfitable(DAG, N0, N1)) {
1240       // Reassociate: (op (op x, c1), y) -> (op (op x, y), c1)
1241       //              iff (op x, c1) has one use
1242       SDValue OpNode = DAG.getNode(Opc, SDLoc(N0), VT, N00, N1, NewFlags);
1243       return DAG.getNode(Opc, DL, VT, OpNode, N01, NewFlags);
1244     }
1245   }
1246 
1247   // Check for repeated operand logic simplifications.
1248   if (Opc == ISD::AND || Opc == ISD::OR) {
1249     // (N00 & N01) & N00 --> N00 & N01
1250     // (N00 & N01) & N01 --> N00 & N01
1251     // (N00 | N01) | N00 --> N00 | N01
1252     // (N00 | N01) | N01 --> N00 | N01
1253     if (N1 == N00 || N1 == N01)
1254       return N0;
1255   }
1256   if (Opc == ISD::XOR) {
1257     // (N00 ^ N01) ^ N00 --> N01
1258     if (N1 == N00)
1259       return N01;
1260     // (N00 ^ N01) ^ N01 --> N00
1261     if (N1 == N01)
1262       return N00;
1263   }
1264 
1265   if (TLI.isReassocProfitable(DAG, N0, N1)) {
1266     if (N1 != N01) {
1267       // Reassociate if (op N00, N1) already exist
1268       if (SDNode *NE = DAG.getNodeIfExists(Opc, DAG.getVTList(VT), {N00, N1})) {
1269         // if Op (Op N00, N1), N01 already exist
1270         // we need to stop reassciate to avoid dead loop
1271         if (!DAG.doesNodeExist(Opc, DAG.getVTList(VT), {SDValue(NE, 0), N01}))
1272           return DAG.getNode(Opc, DL, VT, SDValue(NE, 0), N01);
1273       }
1274     }
1275 
1276     if (N1 != N00) {
1277       // Reassociate if (op N01, N1) already exist
1278       if (SDNode *NE = DAG.getNodeIfExists(Opc, DAG.getVTList(VT), {N01, N1})) {
1279         // if Op (Op N01, N1), N00 already exist
1280         // we need to stop reassciate to avoid dead loop
1281         if (!DAG.doesNodeExist(Opc, DAG.getVTList(VT), {SDValue(NE, 0), N00}))
1282           return DAG.getNode(Opc, DL, VT, SDValue(NE, 0), N00);
1283       }
1284     }
1285 
1286     // Reassociate the operands from (OR/AND (OR/AND(N00, N001)), N1) to (OR/AND
1287     // (OR/AND(N00, N1)), N01) when N00 and N1 are comparisons with the same
1288     // predicate or to (OR/AND (OR/AND(N1, N01)), N00) when N01 and N1 are
1289     // comparisons with the same predicate. This enables optimizations as the
1290     // following one:
1291     // CMP(A,C)||CMP(B,C) => CMP(MIN/MAX(A,B), C)
1292     // CMP(A,C)&&CMP(B,C) => CMP(MIN/MAX(A,B), C)
1293     if (Opc == ISD::AND || Opc == ISD::OR) {
1294       if (N1->getOpcode() == ISD::SETCC && N00->getOpcode() == ISD::SETCC &&
1295           N01->getOpcode() == ISD::SETCC) {
1296         ISD::CondCode CC1 = cast<CondCodeSDNode>(N1.getOperand(2))->get();
1297         ISD::CondCode CC00 = cast<CondCodeSDNode>(N00.getOperand(2))->get();
1298         ISD::CondCode CC01 = cast<CondCodeSDNode>(N01.getOperand(2))->get();
1299         if (CC1 == CC00 && CC1 != CC01) {
1300           SDValue OpNode = DAG.getNode(Opc, SDLoc(N0), VT, N00, N1, Flags);
1301           return DAG.getNode(Opc, DL, VT, OpNode, N01, Flags);
1302         }
1303         if (CC1 == CC01 && CC1 != CC00) {
1304           SDValue OpNode = DAG.getNode(Opc, SDLoc(N0), VT, N01, N1, Flags);
1305           return DAG.getNode(Opc, DL, VT, OpNode, N00, Flags);
1306         }
1307       }
1308     }
1309   }
1310 
1311   return SDValue();
1312 }
1313 
1314 /// Try to reassociate commutative (Opc N0, N1) if either \p N0 or \p N1 is the
1315 /// same kind of operation as \p Opc.
reassociateOps(unsigned Opc,const SDLoc & DL,SDValue N0,SDValue N1,SDNodeFlags Flags)1316 SDValue DAGCombiner::reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
1317                                     SDValue N1, SDNodeFlags Flags) {
1318   assert(TLI.isCommutativeBinOp(Opc) && "Operation not commutative.");
1319 
1320   // Floating-point reassociation is not allowed without loose FP math.
1321   if (N0.getValueType().isFloatingPoint() ||
1322       N1.getValueType().isFloatingPoint())
1323     if (!Flags.hasAllowReassociation() || !Flags.hasNoSignedZeros())
1324       return SDValue();
1325 
1326   if (SDValue Combined = reassociateOpsCommutative(Opc, DL, N0, N1, Flags))
1327     return Combined;
1328   if (SDValue Combined = reassociateOpsCommutative(Opc, DL, N1, N0, Flags))
1329     return Combined;
1330   return SDValue();
1331 }
1332 
1333 // Try to fold Opc(vecreduce(x), vecreduce(y)) -> vecreduce(Opc(x, y))
1334 // Note that we only expect Flags to be passed from FP operations. For integer
1335 // operations they need to be dropped.
reassociateReduction(unsigned RedOpc,unsigned Opc,const SDLoc & DL,EVT VT,SDValue N0,SDValue N1,SDNodeFlags Flags)1336 SDValue DAGCombiner::reassociateReduction(unsigned RedOpc, unsigned Opc,
1337                                           const SDLoc &DL, EVT VT, SDValue N0,
1338                                           SDValue N1, SDNodeFlags Flags) {
1339   if (N0.getOpcode() == RedOpc && N1.getOpcode() == RedOpc &&
1340       N0.getOperand(0).getValueType() == N1.getOperand(0).getValueType() &&
1341       N0->hasOneUse() && N1->hasOneUse() &&
1342       TLI.isOperationLegalOrCustom(Opc, N0.getOperand(0).getValueType()) &&
1343       TLI.shouldReassociateReduction(RedOpc, N0.getOperand(0).getValueType())) {
1344     SelectionDAG::FlagInserter FlagsInserter(DAG, Flags);
1345     return DAG.getNode(RedOpc, DL, VT,
1346                        DAG.getNode(Opc, DL, N0.getOperand(0).getValueType(),
1347                                    N0.getOperand(0), N1.getOperand(0)));
1348   }
1349 
1350   // Reassociate op(op(vecreduce(a), b), op(vecreduce(c), d)) into
1351   // op(vecreduce(op(a, c)), op(b, d)), to combine the reductions into a
1352   // single node.
1353   SDValue A, B, C, D, RedA, RedB;
1354   if (sd_match(N0, m_OneUse(m_c_BinOp(
1355                        Opc,
1356                        m_AllOf(m_OneUse(m_UnaryOp(RedOpc, m_Value(A))),
1357                                m_Value(RedA)),
1358                        m_Value(B)))) &&
1359       sd_match(N1, m_OneUse(m_c_BinOp(
1360                        Opc,
1361                        m_AllOf(m_OneUse(m_UnaryOp(RedOpc, m_Value(C))),
1362                                m_Value(RedB)),
1363                        m_Value(D)))) &&
1364       !sd_match(B, m_UnaryOp(RedOpc, m_Value())) &&
1365       !sd_match(D, m_UnaryOp(RedOpc, m_Value())) &&
1366       A.getValueType() == C.getValueType() &&
1367       hasOperation(Opc, A.getValueType()) &&
1368       TLI.shouldReassociateReduction(RedOpc, VT)) {
1369     if ((Opc == ISD::FADD || Opc == ISD::FMUL) &&
1370         (!N0->getFlags().hasAllowReassociation() ||
1371          !N1->getFlags().hasAllowReassociation() ||
1372          !RedA->getFlags().hasAllowReassociation() ||
1373          !RedB->getFlags().hasAllowReassociation()))
1374       return SDValue();
1375     SelectionDAG::FlagInserter FlagsInserter(
1376         DAG, Flags & N0->getFlags() & N1->getFlags() & RedA->getFlags() &
1377                  RedB->getFlags());
1378     SDValue Op = DAG.getNode(Opc, DL, A.getValueType(), A, C);
1379     SDValue Red = DAG.getNode(RedOpc, DL, VT, Op);
1380     SDValue Op2 = DAG.getNode(Opc, DL, VT, B, D);
1381     return DAG.getNode(Opc, DL, VT, Red, Op2);
1382   }
1383   return SDValue();
1384 }
1385 
CombineTo(SDNode * N,const SDValue * To,unsigned NumTo,bool AddTo)1386 SDValue DAGCombiner::CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
1387                                bool AddTo) {
1388   assert(N->getNumValues() == NumTo && "Broken CombineTo call!");
1389   ++NodesCombined;
1390   LLVM_DEBUG(dbgs() << "\nReplacing.1 "; N->dump(&DAG); dbgs() << "\nWith: ";
1391              To[0].dump(&DAG);
1392              dbgs() << " and " << NumTo - 1 << " other values\n");
1393   for (unsigned i = 0, e = NumTo; i != e; ++i)
1394     assert((!To[i].getNode() ||
1395             N->getValueType(i) == To[i].getValueType()) &&
1396            "Cannot combine value to value of different type!");
1397 
1398   WorklistRemover DeadNodes(*this);
1399   DAG.ReplaceAllUsesWith(N, To);
1400   if (AddTo) {
1401     // Push the new nodes and any users onto the worklist
1402     for (unsigned i = 0, e = NumTo; i != e; ++i) {
1403       if (To[i].getNode())
1404         AddToWorklistWithUsers(To[i].getNode());
1405     }
1406   }
1407 
1408   // Finally, if the node is now dead, remove it from the graph.  The node
1409   // may not be dead if the replacement process recursively simplified to
1410   // something else needing this node.
1411   if (N->use_empty())
1412     deleteAndRecombine(N);
1413   return SDValue(N, 0);
1414 }
1415 
1416 void DAGCombiner::
CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt & TLO)1417 CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) {
1418   // Replace the old value with the new one.
1419   ++NodesCombined;
1420   LLVM_DEBUG(dbgs() << "\nReplacing.2 "; TLO.Old.dump(&DAG);
1421              dbgs() << "\nWith: "; TLO.New.dump(&DAG); dbgs() << '\n');
1422 
1423   // Replace all uses.
1424   DAG.ReplaceAllUsesOfValueWith(TLO.Old, TLO.New);
1425 
1426   // Push the new node and any (possibly new) users onto the worklist.
1427   AddToWorklistWithUsers(TLO.New.getNode());
1428 
1429   // Finally, if the node is now dead, remove it from the graph.
1430   recursivelyDeleteUnusedNodes(TLO.Old.getNode());
1431 }
1432 
1433 /// Check the specified integer node value to see if it can be simplified or if
1434 /// things it uses can be simplified by bit propagation. If so, return true.
SimplifyDemandedBits(SDValue Op,const APInt & DemandedBits,const APInt & DemandedElts,bool AssumeSingleUse)1435 bool DAGCombiner::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
1436                                        const APInt &DemandedElts,
1437                                        bool AssumeSingleUse) {
1438   TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
1439   KnownBits Known;
1440   if (!TLI.SimplifyDemandedBits(Op, DemandedBits, DemandedElts, Known, TLO, 0,
1441                                 AssumeSingleUse))
1442     return false;
1443 
1444   // Revisit the node.
1445   AddToWorklist(Op.getNode());
1446 
1447   CommitTargetLoweringOpt(TLO);
1448   return true;
1449 }
1450 
1451 /// Check the specified vector node value to see if it can be simplified or
1452 /// if things it uses can be simplified as it only uses some of the elements.
1453 /// If so, return true.
SimplifyDemandedVectorElts(SDValue Op,const APInt & DemandedElts,bool AssumeSingleUse)1454 bool DAGCombiner::SimplifyDemandedVectorElts(SDValue Op,
1455                                              const APInt &DemandedElts,
1456                                              bool AssumeSingleUse) {
1457   TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
1458   APInt KnownUndef, KnownZero;
1459   if (!TLI.SimplifyDemandedVectorElts(Op, DemandedElts, KnownUndef, KnownZero,
1460                                       TLO, 0, AssumeSingleUse))
1461     return false;
1462 
1463   // Revisit the node.
1464   AddToWorklist(Op.getNode());
1465 
1466   CommitTargetLoweringOpt(TLO);
1467   return true;
1468 }
1469 
ReplaceLoadWithPromotedLoad(SDNode * Load,SDNode * ExtLoad)1470 void DAGCombiner::ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad) {
1471   SDLoc DL(Load);
1472   EVT VT = Load->getValueType(0);
1473   SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, VT, SDValue(ExtLoad, 0));
1474 
1475   LLVM_DEBUG(dbgs() << "\nReplacing.9 "; Load->dump(&DAG); dbgs() << "\nWith: ";
1476              Trunc.dump(&DAG); dbgs() << '\n');
1477 
1478   DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 0), Trunc);
1479   DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 1), SDValue(ExtLoad, 1));
1480 
1481   AddToWorklist(Trunc.getNode());
1482   recursivelyDeleteUnusedNodes(Load);
1483 }
1484 
PromoteOperand(SDValue Op,EVT PVT,bool & Replace)1485 SDValue DAGCombiner::PromoteOperand(SDValue Op, EVT PVT, bool &Replace) {
1486   Replace = false;
1487   SDLoc DL(Op);
1488   if (ISD::isUNINDEXEDLoad(Op.getNode())) {
1489     LoadSDNode *LD = cast<LoadSDNode>(Op);
1490     EVT MemVT = LD->getMemoryVT();
1491     ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(LD) ? ISD::EXTLOAD
1492                                                       : LD->getExtensionType();
1493     Replace = true;
1494     return DAG.getExtLoad(ExtType, DL, PVT,
1495                           LD->getChain(), LD->getBasePtr(),
1496                           MemVT, LD->getMemOperand());
1497   }
1498 
1499   unsigned Opc = Op.getOpcode();
1500   switch (Opc) {
1501   default: break;
1502   case ISD::AssertSext:
1503     if (SDValue Op0 = SExtPromoteOperand(Op.getOperand(0), PVT))
1504       return DAG.getNode(ISD::AssertSext, DL, PVT, Op0, Op.getOperand(1));
1505     break;
1506   case ISD::AssertZext:
1507     if (SDValue Op0 = ZExtPromoteOperand(Op.getOperand(0), PVT))
1508       return DAG.getNode(ISD::AssertZext, DL, PVT, Op0, Op.getOperand(1));
1509     break;
1510   case ISD::Constant: {
1511     unsigned ExtOpc =
1512       Op.getValueType().isByteSized() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
1513     return DAG.getNode(ExtOpc, DL, PVT, Op);
1514   }
1515   }
1516 
1517   if (!TLI.isOperationLegal(ISD::ANY_EXTEND, PVT))
1518     return SDValue();
1519   return DAG.getNode(ISD::ANY_EXTEND, DL, PVT, Op);
1520 }
1521 
SExtPromoteOperand(SDValue Op,EVT PVT)1522 SDValue DAGCombiner::SExtPromoteOperand(SDValue Op, EVT PVT) {
1523   if (!TLI.isOperationLegal(ISD::SIGN_EXTEND_INREG, PVT))
1524     return SDValue();
1525   EVT OldVT = Op.getValueType();
1526   SDLoc DL(Op);
1527   bool Replace = false;
1528   SDValue NewOp = PromoteOperand(Op, PVT, Replace);
1529   if (!NewOp.getNode())
1530     return SDValue();
1531   AddToWorklist(NewOp.getNode());
1532 
1533   if (Replace)
1534     ReplaceLoadWithPromotedLoad(Op.getNode(), NewOp.getNode());
1535   return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, NewOp.getValueType(), NewOp,
1536                      DAG.getValueType(OldVT));
1537 }
1538 
ZExtPromoteOperand(SDValue Op,EVT PVT)1539 SDValue DAGCombiner::ZExtPromoteOperand(SDValue Op, EVT PVT) {
1540   EVT OldVT = Op.getValueType();
1541   SDLoc DL(Op);
1542   bool Replace = false;
1543   SDValue NewOp = PromoteOperand(Op, PVT, Replace);
1544   if (!NewOp.getNode())
1545     return SDValue();
1546   AddToWorklist(NewOp.getNode());
1547 
1548   if (Replace)
1549     ReplaceLoadWithPromotedLoad(Op.getNode(), NewOp.getNode());
1550   return DAG.getZeroExtendInReg(NewOp, DL, OldVT);
1551 }
1552 
1553 /// Promote the specified integer binary operation if the target indicates it is
1554 /// beneficial. e.g. On x86, it's usually better to promote i16 operations to
1555 /// i32 since i16 instructions are longer.
PromoteIntBinOp(SDValue Op)1556 SDValue DAGCombiner::PromoteIntBinOp(SDValue Op) {
1557   if (!LegalOperations)
1558     return SDValue();
1559 
1560   EVT VT = Op.getValueType();
1561   if (VT.isVector() || !VT.isInteger())
1562     return SDValue();
1563 
1564   // If operation type is 'undesirable', e.g. i16 on x86, consider
1565   // promoting it.
1566   unsigned Opc = Op.getOpcode();
1567   if (TLI.isTypeDesirableForOp(Opc, VT))
1568     return SDValue();
1569 
1570   EVT PVT = VT;
1571   // Consult target whether it is a good idea to promote this operation and
1572   // what's the right type to promote it to.
1573   if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1574     assert(PVT != VT && "Don't know what type to promote to!");
1575 
1576     LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG));
1577 
1578     bool Replace0 = false;
1579     SDValue N0 = Op.getOperand(0);
1580     SDValue NN0 = PromoteOperand(N0, PVT, Replace0);
1581 
1582     bool Replace1 = false;
1583     SDValue N1 = Op.getOperand(1);
1584     SDValue NN1 = PromoteOperand(N1, PVT, Replace1);
1585     SDLoc DL(Op);
1586 
1587     SDValue RV =
1588         DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getNode(Opc, DL, PVT, NN0, NN1));
1589 
1590     // We are always replacing N0/N1's use in N and only need additional
1591     // replacements if there are additional uses.
1592     // Note: We are checking uses of the *nodes* (SDNode) rather than values
1593     //       (SDValue) here because the node may reference multiple values
1594     //       (for example, the chain value of a load node).
1595     Replace0 &= !N0->hasOneUse();
1596     Replace1 &= (N0 != N1) && !N1->hasOneUse();
1597 
1598     // Combine Op here so it is preserved past replacements.
1599     CombineTo(Op.getNode(), RV);
1600 
1601     // If operands have a use ordering, make sure we deal with
1602     // predecessor first.
1603     if (Replace0 && Replace1 && N0->isPredecessorOf(N1.getNode())) {
1604       std::swap(N0, N1);
1605       std::swap(NN0, NN1);
1606     }
1607 
1608     if (Replace0) {
1609       AddToWorklist(NN0.getNode());
1610       ReplaceLoadWithPromotedLoad(N0.getNode(), NN0.getNode());
1611     }
1612     if (Replace1) {
1613       AddToWorklist(NN1.getNode());
1614       ReplaceLoadWithPromotedLoad(N1.getNode(), NN1.getNode());
1615     }
1616     return Op;
1617   }
1618   return SDValue();
1619 }
1620 
1621 /// Promote the specified integer shift operation if the target indicates it is
1622 /// beneficial. e.g. On x86, it's usually better to promote i16 operations to
1623 /// i32 since i16 instructions are longer.
PromoteIntShiftOp(SDValue Op)1624 SDValue DAGCombiner::PromoteIntShiftOp(SDValue Op) {
1625   if (!LegalOperations)
1626     return SDValue();
1627 
1628   EVT VT = Op.getValueType();
1629   if (VT.isVector() || !VT.isInteger())
1630     return SDValue();
1631 
1632   // If operation type is 'undesirable', e.g. i16 on x86, consider
1633   // promoting it.
1634   unsigned Opc = Op.getOpcode();
1635   if (TLI.isTypeDesirableForOp(Opc, VT))
1636     return SDValue();
1637 
1638   EVT PVT = VT;
1639   // Consult target whether it is a good idea to promote this operation and
1640   // what's the right type to promote it to.
1641   if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1642     assert(PVT != VT && "Don't know what type to promote to!");
1643 
1644     LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG));
1645 
1646     bool Replace = false;
1647     SDValue N0 = Op.getOperand(0);
1648     if (Opc == ISD::SRA)
1649       N0 = SExtPromoteOperand(N0, PVT);
1650     else if (Opc == ISD::SRL)
1651       N0 = ZExtPromoteOperand(N0, PVT);
1652     else
1653       N0 = PromoteOperand(N0, PVT, Replace);
1654 
1655     if (!N0.getNode())
1656       return SDValue();
1657 
1658     SDLoc DL(Op);
1659     SDValue N1 = Op.getOperand(1);
1660     SDValue RV =
1661         DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getNode(Opc, DL, PVT, N0, N1));
1662 
1663     if (Replace)
1664       ReplaceLoadWithPromotedLoad(Op.getOperand(0).getNode(), N0.getNode());
1665 
1666     // Deal with Op being deleted.
1667     if (Op && Op.getOpcode() != ISD::DELETED_NODE)
1668       return RV;
1669   }
1670   return SDValue();
1671 }
1672 
PromoteExtend(SDValue Op)1673 SDValue DAGCombiner::PromoteExtend(SDValue Op) {
1674   if (!LegalOperations)
1675     return SDValue();
1676 
1677   EVT VT = Op.getValueType();
1678   if (VT.isVector() || !VT.isInteger())
1679     return SDValue();
1680 
1681   // If operation type is 'undesirable', e.g. i16 on x86, consider
1682   // promoting it.
1683   unsigned Opc = Op.getOpcode();
1684   if (TLI.isTypeDesirableForOp(Opc, VT))
1685     return SDValue();
1686 
1687   EVT PVT = VT;
1688   // Consult target whether it is a good idea to promote this operation and
1689   // what's the right type to promote it to.
1690   if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1691     assert(PVT != VT && "Don't know what type to promote to!");
1692     // fold (aext (aext x)) -> (aext x)
1693     // fold (aext (zext x)) -> (zext x)
1694     // fold (aext (sext x)) -> (sext x)
1695     LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG));
1696     return DAG.getNode(Op.getOpcode(), SDLoc(Op), VT, Op.getOperand(0));
1697   }
1698   return SDValue();
1699 }
1700 
PromoteLoad(SDValue Op)1701 bool DAGCombiner::PromoteLoad(SDValue Op) {
1702   if (!LegalOperations)
1703     return false;
1704 
1705   if (!ISD::isUNINDEXEDLoad(Op.getNode()))
1706     return false;
1707 
1708   EVT VT = Op.getValueType();
1709   if (VT.isVector() || !VT.isInteger())
1710     return false;
1711 
1712   // If operation type is 'undesirable', e.g. i16 on x86, consider
1713   // promoting it.
1714   unsigned Opc = Op.getOpcode();
1715   if (TLI.isTypeDesirableForOp(Opc, VT))
1716     return false;
1717 
1718   EVT PVT = VT;
1719   // Consult target whether it is a good idea to promote this operation and
1720   // what's the right type to promote it to.
1721   if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1722     assert(PVT != VT && "Don't know what type to promote to!");
1723 
1724     SDLoc DL(Op);
1725     SDNode *N = Op.getNode();
1726     LoadSDNode *LD = cast<LoadSDNode>(N);
1727     EVT MemVT = LD->getMemoryVT();
1728     ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(LD) ? ISD::EXTLOAD
1729                                                       : LD->getExtensionType();
1730     SDValue NewLD = DAG.getExtLoad(ExtType, DL, PVT,
1731                                    LD->getChain(), LD->getBasePtr(),
1732                                    MemVT, LD->getMemOperand());
1733     SDValue Result = DAG.getNode(ISD::TRUNCATE, DL, VT, NewLD);
1734 
1735     LLVM_DEBUG(dbgs() << "\nPromoting "; N->dump(&DAG); dbgs() << "\nTo: ";
1736                Result.dump(&DAG); dbgs() << '\n');
1737 
1738     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result);
1739     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), NewLD.getValue(1));
1740 
1741     AddToWorklist(Result.getNode());
1742     recursivelyDeleteUnusedNodes(N);
1743     return true;
1744   }
1745 
1746   return false;
1747 }
1748 
1749 /// Recursively delete a node which has no uses and any operands for
1750 /// which it is the only use.
1751 ///
1752 /// Note that this both deletes the nodes and removes them from the worklist.
1753 /// It also adds any nodes who have had a user deleted to the worklist as they
1754 /// may now have only one use and subject to other combines.
recursivelyDeleteUnusedNodes(SDNode * N)1755 bool DAGCombiner::recursivelyDeleteUnusedNodes(SDNode *N) {
1756   if (!N->use_empty())
1757     return false;
1758 
1759   SmallSetVector<SDNode *, 16> Nodes;
1760   Nodes.insert(N);
1761   do {
1762     N = Nodes.pop_back_val();
1763     if (!N)
1764       continue;
1765 
1766     if (N->use_empty()) {
1767       for (const SDValue &ChildN : N->op_values())
1768         Nodes.insert(ChildN.getNode());
1769 
1770       removeFromWorklist(N);
1771       DAG.DeleteNode(N);
1772     } else {
1773       AddToWorklist(N);
1774     }
1775   } while (!Nodes.empty());
1776   return true;
1777 }
1778 
1779 //===----------------------------------------------------------------------===//
1780 //  Main DAG Combiner implementation
1781 //===----------------------------------------------------------------------===//
1782 
Run(CombineLevel AtLevel)1783 void DAGCombiner::Run(CombineLevel AtLevel) {
1784   // set the instance variables, so that the various visit routines may use it.
1785   Level = AtLevel;
1786   LegalDAG = Level >= AfterLegalizeDAG;
1787   LegalOperations = Level >= AfterLegalizeVectorOps;
1788   LegalTypes = Level >= AfterLegalizeTypes;
1789 
1790   WorklistInserter AddNodes(*this);
1791 
1792   // Add all the dag nodes to the worklist.
1793   //
1794   // Note: All nodes are not added to PruningList here, this is because the only
1795   // nodes which can be deleted are those which have no uses and all other nodes
1796   // which would otherwise be added to the worklist by the first call to
1797   // getNextWorklistEntry are already present in it.
1798   for (SDNode &Node : DAG.allnodes())
1799     AddToWorklist(&Node, /* IsCandidateForPruning */ Node.use_empty());
1800 
1801   // Create a dummy node (which is not added to allnodes), that adds a reference
1802   // to the root node, preventing it from being deleted, and tracking any
1803   // changes of the root.
1804   HandleSDNode Dummy(DAG.getRoot());
1805 
1806   // While we have a valid worklist entry node, try to combine it.
1807   while (SDNode *N = getNextWorklistEntry()) {
1808     // If N has no uses, it is dead.  Make sure to revisit all N's operands once
1809     // N is deleted from the DAG, since they too may now be dead or may have a
1810     // reduced number of uses, allowing other xforms.
1811     if (recursivelyDeleteUnusedNodes(N))
1812       continue;
1813 
1814     WorklistRemover DeadNodes(*this);
1815 
1816     // If this combine is running after legalizing the DAG, re-legalize any
1817     // nodes pulled off the worklist.
1818     if (LegalDAG) {
1819       SmallSetVector<SDNode *, 16> UpdatedNodes;
1820       bool NIsValid = DAG.LegalizeOp(N, UpdatedNodes);
1821 
1822       for (SDNode *LN : UpdatedNodes)
1823         AddToWorklistWithUsers(LN);
1824 
1825       if (!NIsValid)
1826         continue;
1827     }
1828 
1829     LLVM_DEBUG(dbgs() << "\nCombining: "; N->dump(&DAG));
1830 
1831     // Add any operands of the new node which have not yet been combined to the
1832     // worklist as well. getNextWorklistEntry flags nodes that have been
1833     // combined before. Because the worklist uniques things already, this won't
1834     // repeatedly process the same operand.
1835     for (const SDValue &ChildN : N->op_values())
1836       AddToWorklist(ChildN.getNode(), /*IsCandidateForPruning=*/true,
1837                     /*SkipIfCombinedBefore=*/true);
1838 
1839     SDValue RV = combine(N);
1840 
1841     if (!RV.getNode())
1842       continue;
1843 
1844     ++NodesCombined;
1845 
1846     // Invalidate cached info.
1847     ChainsWithoutMergeableStores.clear();
1848 
1849     // If we get back the same node we passed in, rather than a new node or
1850     // zero, we know that the node must have defined multiple values and
1851     // CombineTo was used.  Since CombineTo takes care of the worklist
1852     // mechanics for us, we have no work to do in this case.
1853     if (RV.getNode() == N)
1854       continue;
1855 
1856     assert(N->getOpcode() != ISD::DELETED_NODE &&
1857            RV.getOpcode() != ISD::DELETED_NODE &&
1858            "Node was deleted but visit returned new node!");
1859 
1860     LLVM_DEBUG(dbgs() << " ... into: "; RV.dump(&DAG));
1861 
1862     if (N->getNumValues() == RV->getNumValues())
1863       DAG.ReplaceAllUsesWith(N, RV.getNode());
1864     else {
1865       assert(N->getValueType(0) == RV.getValueType() &&
1866              N->getNumValues() == 1 && "Type mismatch");
1867       DAG.ReplaceAllUsesWith(N, &RV);
1868     }
1869 
1870     // Push the new node and any users onto the worklist.  Omit this if the
1871     // new node is the EntryToken (e.g. if a store managed to get optimized
1872     // out), because re-visiting the EntryToken and its users will not uncover
1873     // any additional opportunities, but there may be a large number of such
1874     // users, potentially causing compile time explosion.
1875     if (RV.getOpcode() != ISD::EntryToken)
1876       AddToWorklistWithUsers(RV.getNode());
1877 
1878     // Finally, if the node is now dead, remove it from the graph.  The node
1879     // may not be dead if the replacement process recursively simplified to
1880     // something else needing this node. This will also take care of adding any
1881     // operands which have lost a user to the worklist.
1882     recursivelyDeleteUnusedNodes(N);
1883   }
1884 
1885   // If the root changed (e.g. it was a dead load, update the root).
1886   DAG.setRoot(Dummy.getValue());
1887   DAG.RemoveDeadNodes();
1888 }
1889 
visit(SDNode * N)1890 SDValue DAGCombiner::visit(SDNode *N) {
1891   // clang-format off
1892   switch (N->getOpcode()) {
1893   default: break;
1894   case ISD::TokenFactor:        return visitTokenFactor(N);
1895   case ISD::MERGE_VALUES:       return visitMERGE_VALUES(N);
1896   case ISD::ADD:                return visitADD(N);
1897   case ISD::PTRADD:             return visitPTRADD(N);
1898   case ISD::SUB:                return visitSUB(N);
1899   case ISD::SADDSAT:
1900   case ISD::UADDSAT:            return visitADDSAT(N);
1901   case ISD::SSUBSAT:
1902   case ISD::USUBSAT:            return visitSUBSAT(N);
1903   case ISD::ADDC:               return visitADDC(N);
1904   case ISD::SADDO:
1905   case ISD::UADDO:              return visitADDO(N);
1906   case ISD::SUBC:               return visitSUBC(N);
1907   case ISD::SSUBO:
1908   case ISD::USUBO:              return visitSUBO(N);
1909   case ISD::ADDE:               return visitADDE(N);
1910   case ISD::UADDO_CARRY:        return visitUADDO_CARRY(N);
1911   case ISD::SADDO_CARRY:        return visitSADDO_CARRY(N);
1912   case ISD::SUBE:               return visitSUBE(N);
1913   case ISD::USUBO_CARRY:        return visitUSUBO_CARRY(N);
1914   case ISD::SSUBO_CARRY:        return visitSSUBO_CARRY(N);
1915   case ISD::SMULFIX:
1916   case ISD::SMULFIXSAT:
1917   case ISD::UMULFIX:
1918   case ISD::UMULFIXSAT:         return visitMULFIX(N);
1919   case ISD::MUL:                return visitMUL<EmptyMatchContext>(N);
1920   case ISD::SDIV:               return visitSDIV(N);
1921   case ISD::UDIV:               return visitUDIV(N);
1922   case ISD::SREM:
1923   case ISD::UREM:               return visitREM(N);
1924   case ISD::MULHU:              return visitMULHU(N);
1925   case ISD::MULHS:              return visitMULHS(N);
1926   case ISD::AVGFLOORS:
1927   case ISD::AVGFLOORU:
1928   case ISD::AVGCEILS:
1929   case ISD::AVGCEILU:           return visitAVG(N);
1930   case ISD::ABDS:
1931   case ISD::ABDU:               return visitABD(N);
1932   case ISD::SMUL_LOHI:          return visitSMUL_LOHI(N);
1933   case ISD::UMUL_LOHI:          return visitUMUL_LOHI(N);
1934   case ISD::SMULO:
1935   case ISD::UMULO:              return visitMULO(N);
1936   case ISD::SMIN:
1937   case ISD::SMAX:
1938   case ISD::UMIN:
1939   case ISD::UMAX:               return visitIMINMAX(N);
1940   case ISD::AND:                return visitAND(N);
1941   case ISD::OR:                 return visitOR(N);
1942   case ISD::XOR:                return visitXOR(N);
1943   case ISD::SHL:                return visitSHL(N);
1944   case ISD::SRA:                return visitSRA(N);
1945   case ISD::SRL:                return visitSRL(N);
1946   case ISD::ROTR:
1947   case ISD::ROTL:               return visitRotate(N);
1948   case ISD::FSHL:
1949   case ISD::FSHR:               return visitFunnelShift(N);
1950   case ISD::SSHLSAT:
1951   case ISD::USHLSAT:            return visitSHLSAT(N);
1952   case ISD::ABS:                return visitABS(N);
1953   case ISD::BSWAP:              return visitBSWAP(N);
1954   case ISD::BITREVERSE:         return visitBITREVERSE(N);
1955   case ISD::CTLZ:               return visitCTLZ(N);
1956   case ISD::CTLZ_ZERO_UNDEF:    return visitCTLZ_ZERO_UNDEF(N);
1957   case ISD::CTTZ:               return visitCTTZ(N);
1958   case ISD::CTTZ_ZERO_UNDEF:    return visitCTTZ_ZERO_UNDEF(N);
1959   case ISD::CTPOP:              return visitCTPOP(N);
1960   case ISD::SELECT:             return visitSELECT(N);
1961   case ISD::VSELECT:            return visitVSELECT(N);
1962   case ISD::SELECT_CC:          return visitSELECT_CC(N);
1963   case ISD::SETCC:              return visitSETCC(N);
1964   case ISD::SETCCCARRY:         return visitSETCCCARRY(N);
1965   case ISD::SIGN_EXTEND:        return visitSIGN_EXTEND(N);
1966   case ISD::ZERO_EXTEND:        return visitZERO_EXTEND(N);
1967   case ISD::ANY_EXTEND:         return visitANY_EXTEND(N);
1968   case ISD::AssertSext:
1969   case ISD::AssertZext:         return visitAssertExt(N);
1970   case ISD::AssertAlign:        return visitAssertAlign(N);
1971   case ISD::SIGN_EXTEND_INREG:  return visitSIGN_EXTEND_INREG(N);
1972   case ISD::SIGN_EXTEND_VECTOR_INREG:
1973   case ISD::ZERO_EXTEND_VECTOR_INREG:
1974   case ISD::ANY_EXTEND_VECTOR_INREG: return visitEXTEND_VECTOR_INREG(N);
1975   case ISD::TRUNCATE:           return visitTRUNCATE(N);
1976   case ISD::TRUNCATE_USAT_U:    return visitTRUNCATE_USAT_U(N);
1977   case ISD::BITCAST:            return visitBITCAST(N);
1978   case ISD::BUILD_PAIR:         return visitBUILD_PAIR(N);
1979   case ISD::FADD:               return visitFADD(N);
1980   case ISD::STRICT_FADD:        return visitSTRICT_FADD(N);
1981   case ISD::FSUB:               return visitFSUB(N);
1982   case ISD::FMUL:               return visitFMUL(N);
1983   case ISD::FMA:                return visitFMA<EmptyMatchContext>(N);
1984   case ISD::FMAD:               return visitFMAD(N);
1985   case ISD::FDIV:               return visitFDIV(N);
1986   case ISD::FREM:               return visitFREM(N);
1987   case ISD::FSQRT:              return visitFSQRT(N);
1988   case ISD::FCOPYSIGN:          return visitFCOPYSIGN(N);
1989   case ISD::FPOW:               return visitFPOW(N);
1990   case ISD::SINT_TO_FP:         return visitSINT_TO_FP(N);
1991   case ISD::UINT_TO_FP:         return visitUINT_TO_FP(N);
1992   case ISD::FP_TO_SINT:         return visitFP_TO_SINT(N);
1993   case ISD::FP_TO_UINT:         return visitFP_TO_UINT(N);
1994   case ISD::LROUND:
1995   case ISD::LLROUND:
1996   case ISD::LRINT:
1997   case ISD::LLRINT:             return visitXROUND(N);
1998   case ISD::FP_ROUND:           return visitFP_ROUND(N);
1999   case ISD::FP_EXTEND:          return visitFP_EXTEND(N);
2000   case ISD::FNEG:               return visitFNEG(N);
2001   case ISD::FABS:               return visitFABS(N);
2002   case ISD::FFLOOR:             return visitFFLOOR(N);
2003   case ISD::FMINNUM:
2004   case ISD::FMAXNUM:
2005   case ISD::FMINIMUM:
2006   case ISD::FMAXIMUM:
2007   case ISD::FMINIMUMNUM:
2008   case ISD::FMAXIMUMNUM:       return visitFMinMax(N);
2009   case ISD::FCEIL:              return visitFCEIL(N);
2010   case ISD::FTRUNC:             return visitFTRUNC(N);
2011   case ISD::FFREXP:             return visitFFREXP(N);
2012   case ISD::BRCOND:             return visitBRCOND(N);
2013   case ISD::BR_CC:              return visitBR_CC(N);
2014   case ISD::LOAD:               return visitLOAD(N);
2015   case ISD::STORE:              return visitSTORE(N);
2016   case ISD::ATOMIC_STORE:       return visitATOMIC_STORE(N);
2017   case ISD::INSERT_VECTOR_ELT:  return visitINSERT_VECTOR_ELT(N);
2018   case ISD::EXTRACT_VECTOR_ELT: return visitEXTRACT_VECTOR_ELT(N);
2019   case ISD::BUILD_VECTOR:       return visitBUILD_VECTOR(N);
2020   case ISD::CONCAT_VECTORS:     return visitCONCAT_VECTORS(N);
2021   case ISD::EXTRACT_SUBVECTOR:  return visitEXTRACT_SUBVECTOR(N);
2022   case ISD::VECTOR_SHUFFLE:     return visitVECTOR_SHUFFLE(N);
2023   case ISD::SCALAR_TO_VECTOR:   return visitSCALAR_TO_VECTOR(N);
2024   case ISD::INSERT_SUBVECTOR:   return visitINSERT_SUBVECTOR(N);
2025   case ISD::MGATHER:            return visitMGATHER(N);
2026   case ISD::MLOAD:              return visitMLOAD(N);
2027   case ISD::MSCATTER:           return visitMSCATTER(N);
2028   case ISD::MSTORE:             return visitMSTORE(N);
2029   case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM: return visitMHISTOGRAM(N);
2030   case ISD::PARTIAL_REDUCE_SMLA:
2031   case ISD::PARTIAL_REDUCE_UMLA:
2032   case ISD::PARTIAL_REDUCE_SUMLA:
2033                                 return visitPARTIAL_REDUCE_MLA(N);
2034   case ISD::VECTOR_COMPRESS:    return visitVECTOR_COMPRESS(N);
2035   case ISD::LIFETIME_END:       return visitLIFETIME_END(N);
2036   case ISD::FP_TO_FP16:         return visitFP_TO_FP16(N);
2037   case ISD::FP16_TO_FP:         return visitFP16_TO_FP(N);
2038   case ISD::FP_TO_BF16:         return visitFP_TO_BF16(N);
2039   case ISD::BF16_TO_FP:         return visitBF16_TO_FP(N);
2040   case ISD::FREEZE:             return visitFREEZE(N);
2041   case ISD::GET_FPENV_MEM:      return visitGET_FPENV_MEM(N);
2042   case ISD::SET_FPENV_MEM:      return visitSET_FPENV_MEM(N);
2043   case ISD::FCANONICALIZE:      return visitFCANONICALIZE(N);
2044   case ISD::VECREDUCE_FADD:
2045   case ISD::VECREDUCE_FMUL:
2046   case ISD::VECREDUCE_ADD:
2047   case ISD::VECREDUCE_MUL:
2048   case ISD::VECREDUCE_AND:
2049   case ISD::VECREDUCE_OR:
2050   case ISD::VECREDUCE_XOR:
2051   case ISD::VECREDUCE_SMAX:
2052   case ISD::VECREDUCE_SMIN:
2053   case ISD::VECREDUCE_UMAX:
2054   case ISD::VECREDUCE_UMIN:
2055   case ISD::VECREDUCE_FMAX:
2056   case ISD::VECREDUCE_FMIN:
2057   case ISD::VECREDUCE_FMAXIMUM:
2058   case ISD::VECREDUCE_FMINIMUM:     return visitVECREDUCE(N);
2059 #define BEGIN_REGISTER_VP_SDNODE(SDOPC, ...) case ISD::SDOPC:
2060 #include "llvm/IR/VPIntrinsics.def"
2061     return visitVPOp(N);
2062   }
2063   // clang-format on
2064   return SDValue();
2065 }
2066 
combine(SDNode * N)2067 SDValue DAGCombiner::combine(SDNode *N) {
2068   if (!DebugCounter::shouldExecute(DAGCombineCounter))
2069     return SDValue();
2070 
2071   SDValue RV;
2072   if (!DisableGenericCombines)
2073     RV = visit(N);
2074 
2075   // If nothing happened, try a target-specific DAG combine.
2076   if (!RV.getNode()) {
2077     assert(N->getOpcode() != ISD::DELETED_NODE &&
2078            "Node was deleted but visit returned NULL!");
2079 
2080     if (N->getOpcode() >= ISD::BUILTIN_OP_END ||
2081         TLI.hasTargetDAGCombine((ISD::NodeType)N->getOpcode())) {
2082 
2083       // Expose the DAG combiner to the target combiner impls.
2084       TargetLowering::DAGCombinerInfo
2085         DagCombineInfo(DAG, Level, false, this);
2086 
2087       RV = TLI.PerformDAGCombine(N, DagCombineInfo);
2088     }
2089   }
2090 
2091   // If nothing happened still, try promoting the operation.
2092   if (!RV.getNode()) {
2093     switch (N->getOpcode()) {
2094     default: break;
2095     case ISD::ADD:
2096     case ISD::SUB:
2097     case ISD::MUL:
2098     case ISD::AND:
2099     case ISD::OR:
2100     case ISD::XOR:
2101       RV = PromoteIntBinOp(SDValue(N, 0));
2102       break;
2103     case ISD::SHL:
2104     case ISD::SRA:
2105     case ISD::SRL:
2106       RV = PromoteIntShiftOp(SDValue(N, 0));
2107       break;
2108     case ISD::SIGN_EXTEND:
2109     case ISD::ZERO_EXTEND:
2110     case ISD::ANY_EXTEND:
2111       RV = PromoteExtend(SDValue(N, 0));
2112       break;
2113     case ISD::LOAD:
2114       if (PromoteLoad(SDValue(N, 0)))
2115         RV = SDValue(N, 0);
2116       break;
2117     }
2118   }
2119 
2120   // If N is a commutative binary node, try to eliminate it if the commuted
2121   // version is already present in the DAG.
2122   if (!RV.getNode() && TLI.isCommutativeBinOp(N->getOpcode())) {
2123     SDValue N0 = N->getOperand(0);
2124     SDValue N1 = N->getOperand(1);
2125 
2126     // Constant operands are canonicalized to RHS.
2127     if (N0 != N1 && (isa<ConstantSDNode>(N0) || !isa<ConstantSDNode>(N1))) {
2128       SDValue Ops[] = {N1, N0};
2129       SDNode *CSENode = DAG.getNodeIfExists(N->getOpcode(), N->getVTList(), Ops,
2130                                             N->getFlags());
2131       if (CSENode)
2132         return SDValue(CSENode, 0);
2133     }
2134   }
2135 
2136   return RV;
2137 }
2138 
2139 /// Given a node, return its input chain if it has one, otherwise return a null
2140 /// sd operand.
getInputChainForNode(SDNode * N)2141 static SDValue getInputChainForNode(SDNode *N) {
2142   if (unsigned NumOps = N->getNumOperands()) {
2143     if (N->getOperand(0).getValueType() == MVT::Other)
2144       return N->getOperand(0);
2145     if (N->getOperand(NumOps-1).getValueType() == MVT::Other)
2146       return N->getOperand(NumOps-1);
2147     for (unsigned i = 1; i < NumOps-1; ++i)
2148       if (N->getOperand(i).getValueType() == MVT::Other)
2149         return N->getOperand(i);
2150   }
2151   return SDValue();
2152 }
2153 
visitFCANONICALIZE(SDNode * N)2154 SDValue DAGCombiner::visitFCANONICALIZE(SDNode *N) {
2155   SDValue Operand = N->getOperand(0);
2156   EVT VT = Operand.getValueType();
2157   SDLoc dl(N);
2158 
2159   // Canonicalize undef to quiet NaN.
2160   if (Operand.isUndef()) {
2161     APFloat CanonicalQNaN = APFloat::getQNaN(VT.getFltSemantics());
2162     return DAG.getConstantFP(CanonicalQNaN, dl, VT);
2163   }
2164   return SDValue();
2165 }
2166 
visitTokenFactor(SDNode * N)2167 SDValue DAGCombiner::visitTokenFactor(SDNode *N) {
2168   // If N has two operands, where one has an input chain equal to the other,
2169   // the 'other' chain is redundant.
2170   if (N->getNumOperands() == 2) {
2171     if (getInputChainForNode(N->getOperand(0).getNode()) == N->getOperand(1))
2172       return N->getOperand(0);
2173     if (getInputChainForNode(N->getOperand(1).getNode()) == N->getOperand(0))
2174       return N->getOperand(1);
2175   }
2176 
2177   // Don't simplify token factors if optnone.
2178   if (OptLevel == CodeGenOptLevel::None)
2179     return SDValue();
2180 
2181   // Don't simplify the token factor if the node itself has too many operands.
2182   if (N->getNumOperands() > TokenFactorInlineLimit)
2183     return SDValue();
2184 
2185   // If the sole user is a token factor, we should make sure we have a
2186   // chance to merge them together. This prevents TF chains from inhibiting
2187   // optimizations.
2188   if (N->hasOneUse() && N->user_begin()->getOpcode() == ISD::TokenFactor)
2189     AddToWorklist(*(N->user_begin()));
2190 
2191   SmallVector<SDNode *, 8> TFs;     // List of token factors to visit.
2192   SmallVector<SDValue, 8> Ops;      // Ops for replacing token factor.
2193   SmallPtrSet<SDNode*, 16> SeenOps;
2194   bool Changed = false;             // If we should replace this token factor.
2195 
2196   // Start out with this token factor.
2197   TFs.push_back(N);
2198 
2199   // Iterate through token factors.  The TFs grows when new token factors are
2200   // encountered.
2201   for (unsigned i = 0; i < TFs.size(); ++i) {
2202     // Limit number of nodes to inline, to avoid quadratic compile times.
2203     // We have to add the outstanding Token Factors to Ops, otherwise we might
2204     // drop Ops from the resulting Token Factors.
2205     if (Ops.size() > TokenFactorInlineLimit) {
2206       for (unsigned j = i; j < TFs.size(); j++)
2207         Ops.emplace_back(TFs[j], 0);
2208       // Drop unprocessed Token Factors from TFs, so we do not add them to the
2209       // combiner worklist later.
2210       TFs.resize(i);
2211       break;
2212     }
2213 
2214     SDNode *TF = TFs[i];
2215     // Check each of the operands.
2216     for (const SDValue &Op : TF->op_values()) {
2217       switch (Op.getOpcode()) {
2218       case ISD::EntryToken:
2219         // Entry tokens don't need to be added to the list. They are
2220         // redundant.
2221         Changed = true;
2222         break;
2223 
2224       case ISD::TokenFactor:
2225         if (Op.hasOneUse() && !is_contained(TFs, Op.getNode())) {
2226           // Queue up for processing.
2227           TFs.push_back(Op.getNode());
2228           Changed = true;
2229           break;
2230         }
2231         [[fallthrough]];
2232 
2233       default:
2234         // Only add if it isn't already in the list.
2235         if (SeenOps.insert(Op.getNode()).second)
2236           Ops.push_back(Op);
2237         else
2238           Changed = true;
2239         break;
2240       }
2241     }
2242   }
2243 
2244   // Re-visit inlined Token Factors, to clean them up in case they have been
2245   // removed. Skip the first Token Factor, as this is the current node.
2246   for (unsigned i = 1, e = TFs.size(); i < e; i++)
2247     AddToWorklist(TFs[i]);
2248 
2249   // Remove Nodes that are chained to another node in the list. Do so
2250   // by walking up chains breath-first stopping when we've seen
2251   // another operand. In general we must climb to the EntryNode, but we can exit
2252   // early if we find all remaining work is associated with just one operand as
2253   // no further pruning is possible.
2254 
2255   // List of nodes to search through and original Ops from which they originate.
2256   SmallVector<std::pair<SDNode *, unsigned>, 8> Worklist;
2257   SmallVector<unsigned, 8> OpWorkCount; // Count of work for each Op.
2258   SmallPtrSet<SDNode *, 16> SeenChains;
2259   bool DidPruneOps = false;
2260 
2261   unsigned NumLeftToConsider = 0;
2262   for (const SDValue &Op : Ops) {
2263     Worklist.push_back(std::make_pair(Op.getNode(), NumLeftToConsider++));
2264     OpWorkCount.push_back(1);
2265   }
2266 
2267   auto AddToWorklist = [&](unsigned CurIdx, SDNode *Op, unsigned OpNumber) {
2268     // If this is an Op, we can remove the op from the list. Remark any
2269     // search associated with it as from the current OpNumber.
2270     if (SeenOps.contains(Op)) {
2271       Changed = true;
2272       DidPruneOps = true;
2273       unsigned OrigOpNumber = 0;
2274       while (OrigOpNumber < Ops.size() && Ops[OrigOpNumber].getNode() != Op)
2275         OrigOpNumber++;
2276       assert((OrigOpNumber != Ops.size()) &&
2277              "expected to find TokenFactor Operand");
2278       // Re-mark worklist from OrigOpNumber to OpNumber
2279       for (unsigned i = CurIdx + 1; i < Worklist.size(); ++i) {
2280         if (Worklist[i].second == OrigOpNumber) {
2281           Worklist[i].second = OpNumber;
2282         }
2283       }
2284       OpWorkCount[OpNumber] += OpWorkCount[OrigOpNumber];
2285       OpWorkCount[OrigOpNumber] = 0;
2286       NumLeftToConsider--;
2287     }
2288     // Add if it's a new chain
2289     if (SeenChains.insert(Op).second) {
2290       OpWorkCount[OpNumber]++;
2291       Worklist.push_back(std::make_pair(Op, OpNumber));
2292     }
2293   };
2294 
2295   for (unsigned i = 0; i < Worklist.size() && i < 1024; ++i) {
2296     // We need at least be consider at least 2 Ops to prune.
2297     if (NumLeftToConsider <= 1)
2298       break;
2299     auto CurNode = Worklist[i].first;
2300     auto CurOpNumber = Worklist[i].second;
2301     assert((OpWorkCount[CurOpNumber] > 0) &&
2302            "Node should not appear in worklist");
2303     switch (CurNode->getOpcode()) {
2304     case ISD::EntryToken:
2305       // Hitting EntryToken is the only way for the search to terminate without
2306       // hitting
2307       // another operand's search. Prevent us from marking this operand
2308       // considered.
2309       NumLeftToConsider++;
2310       break;
2311     case ISD::TokenFactor:
2312       for (const SDValue &Op : CurNode->op_values())
2313         AddToWorklist(i, Op.getNode(), CurOpNumber);
2314       break;
2315     case ISD::LIFETIME_START:
2316     case ISD::LIFETIME_END:
2317     case ISD::CopyFromReg:
2318     case ISD::CopyToReg:
2319       AddToWorklist(i, CurNode->getOperand(0).getNode(), CurOpNumber);
2320       break;
2321     default:
2322       if (auto *MemNode = dyn_cast<MemSDNode>(CurNode))
2323         AddToWorklist(i, MemNode->getChain().getNode(), CurOpNumber);
2324       break;
2325     }
2326     OpWorkCount[CurOpNumber]--;
2327     if (OpWorkCount[CurOpNumber] == 0)
2328       NumLeftToConsider--;
2329   }
2330 
2331   // If we've changed things around then replace token factor.
2332   if (Changed) {
2333     SDValue Result;
2334     if (Ops.empty()) {
2335       // The entry token is the only possible outcome.
2336       Result = DAG.getEntryNode();
2337     } else {
2338       if (DidPruneOps) {
2339         SmallVector<SDValue, 8> PrunedOps;
2340         //
2341         for (const SDValue &Op : Ops) {
2342           if (SeenChains.count(Op.getNode()) == 0)
2343             PrunedOps.push_back(Op);
2344         }
2345         Result = DAG.getTokenFactor(SDLoc(N), PrunedOps);
2346       } else {
2347         Result = DAG.getTokenFactor(SDLoc(N), Ops);
2348       }
2349     }
2350     return Result;
2351   }
2352   return SDValue();
2353 }
2354 
2355 /// MERGE_VALUES can always be eliminated.
visitMERGE_VALUES(SDNode * N)2356 SDValue DAGCombiner::visitMERGE_VALUES(SDNode *N) {
2357   WorklistRemover DeadNodes(*this);
2358   // Replacing results may cause a different MERGE_VALUES to suddenly
2359   // be CSE'd with N, and carry its uses with it. Iterate until no
2360   // uses remain, to ensure that the node can be safely deleted.
2361   // First add the users of this node to the work list so that they
2362   // can be tried again once they have new operands.
2363   AddUsersToWorklist(N);
2364   do {
2365     // Do as a single replacement to avoid rewalking use lists.
2366     SmallVector<SDValue, 8> Ops(N->ops());
2367     DAG.ReplaceAllUsesWith(N, Ops.data());
2368   } while (!N->use_empty());
2369   deleteAndRecombine(N);
2370   return SDValue(N, 0);   // Return N so it doesn't get rechecked!
2371 }
2372 
2373 /// If \p N is a ConstantSDNode with isOpaque() == false return it casted to a
2374 /// ConstantSDNode pointer else nullptr.
getAsNonOpaqueConstant(SDValue N)2375 static ConstantSDNode *getAsNonOpaqueConstant(SDValue N) {
2376   ConstantSDNode *Const = dyn_cast<ConstantSDNode>(N);
2377   return Const != nullptr && !Const->isOpaque() ? Const : nullptr;
2378 }
2379 
2380 // isTruncateOf - If N is a truncate of some other value, return true, record
2381 // the value being truncated in Op and which of Op's bits are zero/one in Known.
2382 // This function computes KnownBits to avoid a duplicated call to
2383 // computeKnownBits in the caller.
isTruncateOf(SelectionDAG & DAG,SDValue N,SDValue & Op,KnownBits & Known)2384 static bool isTruncateOf(SelectionDAG &DAG, SDValue N, SDValue &Op,
2385                          KnownBits &Known) {
2386   if (N->getOpcode() == ISD::TRUNCATE) {
2387     Op = N->getOperand(0);
2388     Known = DAG.computeKnownBits(Op);
2389     if (N->getFlags().hasNoUnsignedWrap())
2390       Known.Zero.setBitsFrom(N.getScalarValueSizeInBits());
2391     return true;
2392   }
2393 
2394   if (N.getValueType().getScalarType() != MVT::i1 ||
2395       !sd_match(
2396           N, m_c_SetCC(m_Value(Op), m_Zero(), m_SpecificCondCode(ISD::SETNE))))
2397     return false;
2398 
2399   Known = DAG.computeKnownBits(Op);
2400   return (Known.Zero | 1).isAllOnes();
2401 }
2402 
2403 /// Return true if 'Use' is a load or a store that uses N as its base pointer
2404 /// and that N may be folded in the load / store addressing mode.
canFoldInAddressingMode(SDNode * N,SDNode * Use,SelectionDAG & DAG,const TargetLowering & TLI)2405 static bool canFoldInAddressingMode(SDNode *N, SDNode *Use, SelectionDAG &DAG,
2406                                     const TargetLowering &TLI) {
2407   EVT VT;
2408   unsigned AS;
2409 
2410   if (LoadSDNode *LD = dyn_cast<LoadSDNode>(Use)) {
2411     if (LD->isIndexed() || LD->getBasePtr().getNode() != N)
2412       return false;
2413     VT = LD->getMemoryVT();
2414     AS = LD->getAddressSpace();
2415   } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(Use)) {
2416     if (ST->isIndexed() || ST->getBasePtr().getNode() != N)
2417       return false;
2418     VT = ST->getMemoryVT();
2419     AS = ST->getAddressSpace();
2420   } else if (MaskedLoadSDNode *LD = dyn_cast<MaskedLoadSDNode>(Use)) {
2421     if (LD->isIndexed() || LD->getBasePtr().getNode() != N)
2422       return false;
2423     VT = LD->getMemoryVT();
2424     AS = LD->getAddressSpace();
2425   } else if (MaskedStoreSDNode *ST = dyn_cast<MaskedStoreSDNode>(Use)) {
2426     if (ST->isIndexed() || ST->getBasePtr().getNode() != N)
2427       return false;
2428     VT = ST->getMemoryVT();
2429     AS = ST->getAddressSpace();
2430   } else {
2431     return false;
2432   }
2433 
2434   TargetLowering::AddrMode AM;
2435   if (N->isAnyAdd()) {
2436     AM.HasBaseReg = true;
2437     ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(N->getOperand(1));
2438     if (Offset)
2439       // [reg +/- imm]
2440       AM.BaseOffs = Offset->getSExtValue();
2441     else
2442       // [reg +/- reg]
2443       AM.Scale = 1;
2444   } else if (N->getOpcode() == ISD::SUB) {
2445     AM.HasBaseReg = true;
2446     ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(N->getOperand(1));
2447     if (Offset)
2448       // [reg +/- imm]
2449       AM.BaseOffs = -Offset->getSExtValue();
2450     else
2451       // [reg +/- reg]
2452       AM.Scale = 1;
2453   } else {
2454     return false;
2455   }
2456 
2457   return TLI.isLegalAddressingMode(DAG.getDataLayout(), AM,
2458                                    VT.getTypeForEVT(*DAG.getContext()), AS);
2459 }
2460 
2461 /// This inverts a canonicalization in IR that replaces a variable select arm
2462 /// with an identity constant. Codegen improves if we re-use the variable
2463 /// operand rather than load a constant. This can also be converted into a
2464 /// masked vector operation if the target supports it.
foldSelectWithIdentityConstant(SDNode * N,SelectionDAG & DAG,bool ShouldCommuteOperands)2465 static SDValue foldSelectWithIdentityConstant(SDNode *N, SelectionDAG &DAG,
2466                                               bool ShouldCommuteOperands) {
2467   // Match a select as operand 1. The identity constant that we are looking for
2468   // is only valid as operand 1 of a non-commutative binop.
2469   SDValue N0 = N->getOperand(0);
2470   SDValue N1 = N->getOperand(1);
2471   if (ShouldCommuteOperands)
2472     std::swap(N0, N1);
2473 
2474   unsigned SelOpcode = N1.getOpcode();
2475   if ((SelOpcode != ISD::VSELECT && SelOpcode != ISD::SELECT) ||
2476       !N1.hasOneUse())
2477     return SDValue();
2478 
2479   // We can't hoist all instructions because of immediate UB (not speculatable).
2480   // For example div/rem by zero.
2481   if (!DAG.isSafeToSpeculativelyExecuteNode(N))
2482     return SDValue();
2483 
2484   unsigned Opcode = N->getOpcode();
2485   EVT VT = N->getValueType(0);
2486   SDValue Cond = N1.getOperand(0);
2487   SDValue TVal = N1.getOperand(1);
2488   SDValue FVal = N1.getOperand(2);
2489   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
2490 
2491   // This transform increases uses of N0, so freeze it to be safe.
2492   // binop N0, (vselect Cond, IDC, FVal) --> vselect Cond, N0, (binop N0, FVal)
2493   unsigned OpNo = ShouldCommuteOperands ? 0 : 1;
2494   if (isNeutralConstant(Opcode, N->getFlags(), TVal, OpNo) &&
2495       TLI.shouldFoldSelectWithIdentityConstant(Opcode, VT, SelOpcode, N0,
2496                                                FVal)) {
2497     SDValue F0 = DAG.getFreeze(N0);
2498     SDValue NewBO = DAG.getNode(Opcode, SDLoc(N), VT, F0, FVal, N->getFlags());
2499     return DAG.getSelect(SDLoc(N), VT, Cond, F0, NewBO);
2500   }
2501   // binop N0, (vselect Cond, TVal, IDC) --> vselect Cond, (binop N0, TVal), N0
2502   if (isNeutralConstant(Opcode, N->getFlags(), FVal, OpNo) &&
2503       TLI.shouldFoldSelectWithIdentityConstant(Opcode, VT, SelOpcode, N0,
2504                                                TVal)) {
2505     SDValue F0 = DAG.getFreeze(N0);
2506     SDValue NewBO = DAG.getNode(Opcode, SDLoc(N), VT, F0, TVal, N->getFlags());
2507     return DAG.getSelect(SDLoc(N), VT, Cond, NewBO, F0);
2508   }
2509 
2510   return SDValue();
2511 }
2512 
foldBinOpIntoSelect(SDNode * BO)2513 SDValue DAGCombiner::foldBinOpIntoSelect(SDNode *BO) {
2514   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
2515   assert(TLI.isBinOp(BO->getOpcode()) && BO->getNumValues() == 1 &&
2516          "Unexpected binary operator");
2517 
2518   if (SDValue Sel = foldSelectWithIdentityConstant(BO, DAG, false))
2519     return Sel;
2520 
2521   if (TLI.isCommutativeBinOp(BO->getOpcode()))
2522     if (SDValue Sel = foldSelectWithIdentityConstant(BO, DAG, true))
2523       return Sel;
2524 
2525   // Don't do this unless the old select is going away. We want to eliminate the
2526   // binary operator, not replace a binop with a select.
2527   // TODO: Handle ISD::SELECT_CC.
2528   unsigned SelOpNo = 0;
2529   SDValue Sel = BO->getOperand(0);
2530   auto BinOpcode = BO->getOpcode();
2531   if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse()) {
2532     SelOpNo = 1;
2533     Sel = BO->getOperand(1);
2534 
2535     // Peek through trunc to shift amount type.
2536     if ((BinOpcode == ISD::SHL || BinOpcode == ISD::SRA ||
2537          BinOpcode == ISD::SRL) && Sel.hasOneUse()) {
2538       // This is valid when the truncated bits of x are already zero.
2539       SDValue Op;
2540       KnownBits Known;
2541       if (isTruncateOf(DAG, Sel, Op, Known) &&
2542           Known.countMaxActiveBits() < Sel.getScalarValueSizeInBits())
2543         Sel = Op;
2544     }
2545   }
2546 
2547   if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse())
2548     return SDValue();
2549 
2550   SDValue CT = Sel.getOperand(1);
2551   if (!isConstantOrConstantVector(CT, true) &&
2552       !DAG.isConstantFPBuildVectorOrConstantFP(CT))
2553     return SDValue();
2554 
2555   SDValue CF = Sel.getOperand(2);
2556   if (!isConstantOrConstantVector(CF, true) &&
2557       !DAG.isConstantFPBuildVectorOrConstantFP(CF))
2558     return SDValue();
2559 
2560   // Bail out if any constants are opaque because we can't constant fold those.
2561   // The exception is "and" and "or" with either 0 or -1 in which case we can
2562   // propagate non constant operands into select. I.e.:
2563   // and (select Cond, 0, -1), X --> select Cond, 0, X
2564   // or X, (select Cond, -1, 0) --> select Cond, -1, X
2565   bool CanFoldNonConst =
2566       (BinOpcode == ISD::AND || BinOpcode == ISD::OR) &&
2567       ((isNullOrNullSplat(CT) && isAllOnesOrAllOnesSplat(CF)) ||
2568        (isNullOrNullSplat(CF) && isAllOnesOrAllOnesSplat(CT)));
2569 
2570   SDValue CBO = BO->getOperand(SelOpNo ^ 1);
2571   if (!CanFoldNonConst &&
2572       !isConstantOrConstantVector(CBO, true) &&
2573       !DAG.isConstantFPBuildVectorOrConstantFP(CBO))
2574     return SDValue();
2575 
2576   SDLoc DL(Sel);
2577   SDValue NewCT, NewCF;
2578   EVT VT = BO->getValueType(0);
2579 
2580   if (CanFoldNonConst) {
2581     // If CBO is an opaque constant, we can't rely on getNode to constant fold.
2582     if ((BinOpcode == ISD::AND && isNullOrNullSplat(CT)) ||
2583         (BinOpcode == ISD::OR && isAllOnesOrAllOnesSplat(CT)))
2584       NewCT = CT;
2585     else
2586       NewCT = CBO;
2587 
2588     if ((BinOpcode == ISD::AND && isNullOrNullSplat(CF)) ||
2589         (BinOpcode == ISD::OR && isAllOnesOrAllOnesSplat(CF)))
2590       NewCF = CF;
2591     else
2592       NewCF = CBO;
2593   } else {
2594     // We have a select-of-constants followed by a binary operator with a
2595     // constant. Eliminate the binop by pulling the constant math into the
2596     // select. Example: add (select Cond, CT, CF), CBO --> select Cond, CT +
2597     // CBO, CF + CBO
2598     NewCT = SelOpNo ? DAG.FoldConstantArithmetic(BinOpcode, DL, VT, {CBO, CT})
2599                     : DAG.FoldConstantArithmetic(BinOpcode, DL, VT, {CT, CBO});
2600     if (!NewCT)
2601       return SDValue();
2602 
2603     NewCF = SelOpNo ? DAG.FoldConstantArithmetic(BinOpcode, DL, VT, {CBO, CF})
2604                     : DAG.FoldConstantArithmetic(BinOpcode, DL, VT, {CF, CBO});
2605     if (!NewCF)
2606       return SDValue();
2607   }
2608 
2609   return DAG.getSelect(DL, VT, Sel.getOperand(0), NewCT, NewCF, BO->getFlags());
2610 }
2611 
foldAddSubBoolOfMaskedVal(SDNode * N,const SDLoc & DL,SelectionDAG & DAG)2612 static SDValue foldAddSubBoolOfMaskedVal(SDNode *N, const SDLoc &DL,
2613                                          SelectionDAG &DAG) {
2614   assert((N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::SUB) &&
2615          "Expecting add or sub");
2616 
2617   // Match a constant operand and a zext operand for the math instruction:
2618   // add Z, C
2619   // sub C, Z
2620   bool IsAdd = N->getOpcode() == ISD::ADD;
2621   SDValue C = IsAdd ? N->getOperand(1) : N->getOperand(0);
2622   SDValue Z = IsAdd ? N->getOperand(0) : N->getOperand(1);
2623   auto *CN = dyn_cast<ConstantSDNode>(C);
2624   if (!CN || Z.getOpcode() != ISD::ZERO_EXTEND)
2625     return SDValue();
2626 
2627   // Match the zext operand as a setcc of a boolean.
2628   if (Z.getOperand(0).getValueType() != MVT::i1)
2629     return SDValue();
2630 
2631   // Match the compare as: setcc (X & 1), 0, eq.
2632   if (!sd_match(Z.getOperand(0), m_SetCC(m_And(m_Value(), m_One()), m_Zero(),
2633                                          m_SpecificCondCode(ISD::SETEQ))))
2634     return SDValue();
2635 
2636   // We are adding/subtracting a constant and an inverted low bit. Turn that
2637   // into a subtract/add of the low bit with incremented/decremented constant:
2638   // add (zext i1 (seteq (X & 1), 0)), C --> sub C+1, (zext (X & 1))
2639   // sub C, (zext i1 (seteq (X & 1), 0)) --> add C-1, (zext (X & 1))
2640   EVT VT = C.getValueType();
2641   SDValue LowBit = DAG.getZExtOrTrunc(Z.getOperand(0).getOperand(0), DL, VT);
2642   SDValue C1 = IsAdd ? DAG.getConstant(CN->getAPIntValue() + 1, DL, VT)
2643                      : DAG.getConstant(CN->getAPIntValue() - 1, DL, VT);
2644   return DAG.getNode(IsAdd ? ISD::SUB : ISD::ADD, DL, VT, C1, LowBit);
2645 }
2646 
2647 // Attempt to form avgceil(A, B) from (A | B) - ((A ^ B) >> 1)
foldSubToAvg(SDNode * N,const SDLoc & DL)2648 SDValue DAGCombiner::foldSubToAvg(SDNode *N, const SDLoc &DL) {
2649   SDValue N0 = N->getOperand(0);
2650   EVT VT = N0.getValueType();
2651   SDValue A, B;
2652 
2653   if ((!LegalOperations || hasOperation(ISD::AVGCEILU, VT)) &&
2654       sd_match(N, m_Sub(m_Or(m_Value(A), m_Value(B)),
2655                         m_Srl(m_Xor(m_Deferred(A), m_Deferred(B)), m_One())))) {
2656     return DAG.getNode(ISD::AVGCEILU, DL, VT, A, B);
2657   }
2658   if ((!LegalOperations || hasOperation(ISD::AVGCEILS, VT)) &&
2659       sd_match(N, m_Sub(m_Or(m_Value(A), m_Value(B)),
2660                         m_Sra(m_Xor(m_Deferred(A), m_Deferred(B)), m_One())))) {
2661     return DAG.getNode(ISD::AVGCEILS, DL, VT, A, B);
2662   }
2663   return SDValue();
2664 }
2665 
2666 /// Try to fold a pointer arithmetic node.
2667 /// This needs to be done separately from normal addition, because pointer
2668 /// addition is not commutative.
visitPTRADD(SDNode * N)2669 SDValue DAGCombiner::visitPTRADD(SDNode *N) {
2670   SDValue N0 = N->getOperand(0);
2671   SDValue N1 = N->getOperand(1);
2672   EVT PtrVT = N0.getValueType();
2673   EVT IntVT = N1.getValueType();
2674   SDLoc DL(N);
2675 
2676   // This is already ensured by an assert in SelectionDAG::getNode(). Several
2677   // combines here depend on this assumption.
2678   assert(PtrVT == IntVT &&
2679          "PTRADD with different operand types is not supported");
2680 
2681   // fold (ptradd x, 0) -> x
2682   if (isNullConstant(N1))
2683     return N0;
2684 
2685   // fold (ptradd 0, x) -> x
2686   if (PtrVT == IntVT && isNullConstant(N0))
2687     return N1;
2688 
2689   if (N0.getOpcode() != ISD::PTRADD ||
2690       reassociationCanBreakAddressingModePattern(ISD::PTRADD, DL, N, N0, N1))
2691     return SDValue();
2692 
2693   SDValue X = N0.getOperand(0);
2694   SDValue Y = N0.getOperand(1);
2695   SDValue Z = N1;
2696   bool N0OneUse = N0.hasOneUse();
2697   bool YIsConstant = DAG.isConstantIntBuildVectorOrConstantInt(Y);
2698   bool ZIsConstant = DAG.isConstantIntBuildVectorOrConstantInt(Z);
2699 
2700   // (ptradd (ptradd x, y), z) -> (ptradd x, (add y, z)) if:
2701   //   * y is a constant and (ptradd x, y) has one use; or
2702   //   * y and z are both constants.
2703   if ((YIsConstant && N0OneUse) || (YIsConstant && ZIsConstant)) {
2704     // If both additions in the original were NUW, the new ones are as well.
2705     SDNodeFlags Flags =
2706         (N->getFlags() & N0->getFlags()) & SDNodeFlags::NoUnsignedWrap;
2707     SDValue Add = DAG.getNode(ISD::ADD, DL, IntVT, {Y, Z}, Flags);
2708     AddToWorklist(Add.getNode());
2709     return DAG.getMemBasePlusOffset(X, Add, DL, Flags);
2710   }
2711 
2712   // TODO: There is another possible fold here that was proven useful.
2713   // It would be this:
2714   //
2715   // (ptradd (ptradd x, y), z) -> (ptradd (ptradd x, z), y) if:
2716   //   * (ptradd x, y) has one use; and
2717   //   * y is a constant; and
2718   //   * z is not a constant.
2719   //
2720   // In some cases, specifically in AArch64's FEAT_CPA, it exposes the
2721   // opportunity to select more complex instructions such as SUBPT and
2722   // MSUBPT. However, a hypothetical corner case has been found that we could
2723   // not avoid. Consider this (pseudo-POSIX C):
2724   //
2725   // char *foo(char *x, int z) {return (x + LARGE_CONSTANT) + z;}
2726   // char *p = mmap(LARGE_CONSTANT);
2727   // char *q = foo(p, -LARGE_CONSTANT);
2728   //
2729   // Then x + LARGE_CONSTANT is one-past-the-end, so valid, and a
2730   // further + z takes it back to the start of the mapping, so valid,
2731   // regardless of the address mmap gave back. However, if mmap gives you an
2732   // address < LARGE_CONSTANT (ignoring high bits), x - LARGE_CONSTANT will
2733   // borrow from the high bits (with the subsequent + z carrying back into
2734   // the high bits to give you a well-defined pointer) and thus trip
2735   // FEAT_CPA's pointer corruption checks.
2736   //
2737   // We leave this fold as an opportunity for future work, addressing the
2738   // corner case for FEAT_CPA, as well as reconciling the solution with the
2739   // more general application of pointer arithmetic in other future targets.
2740   // For now each architecture that wants this fold must implement it in the
2741   // target-specific code (see e.g. SITargetLowering::performPtrAddCombine)
2742 
2743   return SDValue();
2744 }
2745 
2746 /// Try to fold a 'not' shifted sign-bit with add/sub with constant operand into
2747 /// a shift and add with a different constant.
foldAddSubOfSignBit(SDNode * N,const SDLoc & DL,SelectionDAG & DAG)2748 static SDValue foldAddSubOfSignBit(SDNode *N, const SDLoc &DL,
2749                                    SelectionDAG &DAG) {
2750   assert((N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::SUB) &&
2751          "Expecting add or sub");
2752 
2753   // We need a constant operand for the add/sub, and the other operand is a
2754   // logical shift right: add (srl), C or sub C, (srl).
2755   bool IsAdd = N->getOpcode() == ISD::ADD;
2756   SDValue ConstantOp = IsAdd ? N->getOperand(1) : N->getOperand(0);
2757   SDValue ShiftOp = IsAdd ? N->getOperand(0) : N->getOperand(1);
2758   if (!DAG.isConstantIntBuildVectorOrConstantInt(ConstantOp) ||
2759       ShiftOp.getOpcode() != ISD::SRL)
2760     return SDValue();
2761 
2762   // The shift must be of a 'not' value.
2763   SDValue Not = ShiftOp.getOperand(0);
2764   if (!Not.hasOneUse() || !isBitwiseNot(Not))
2765     return SDValue();
2766 
2767   // The shift must be moving the sign bit to the least-significant-bit.
2768   EVT VT = ShiftOp.getValueType();
2769   SDValue ShAmt = ShiftOp.getOperand(1);
2770   ConstantSDNode *ShAmtC = isConstOrConstSplat(ShAmt);
2771   if (!ShAmtC || ShAmtC->getAPIntValue() != (VT.getScalarSizeInBits() - 1))
2772     return SDValue();
2773 
2774   // Eliminate the 'not' by adjusting the shift and add/sub constant:
2775   // add (srl (not X), 31), C --> add (sra X, 31), (C + 1)
2776   // sub C, (srl (not X), 31) --> add (srl X, 31), (C - 1)
2777   if (SDValue NewC = DAG.FoldConstantArithmetic(
2778           IsAdd ? ISD::ADD : ISD::SUB, DL, VT,
2779           {ConstantOp, DAG.getConstant(1, DL, VT)})) {
2780     SDValue NewShift = DAG.getNode(IsAdd ? ISD::SRA : ISD::SRL, DL, VT,
2781                                    Not.getOperand(0), ShAmt);
2782     return DAG.getNode(ISD::ADD, DL, VT, NewShift, NewC);
2783   }
2784 
2785   return SDValue();
2786 }
2787 
2788 static bool
areBitwiseNotOfEachother(SDValue Op0,SDValue Op1)2789 areBitwiseNotOfEachother(SDValue Op0, SDValue Op1) {
2790   return (isBitwiseNot(Op0) && Op0.getOperand(0) == Op1) ||
2791          (isBitwiseNot(Op1) && Op1.getOperand(0) == Op0);
2792 }
2793 
2794 /// Try to fold a node that behaves like an ADD (note that N isn't necessarily
2795 /// an ISD::ADD here, it could for example be an ISD::OR if we know that there
2796 /// are no common bits set in the operands).
///
/// The folds below are attempted in source order and the first one that
/// applies returns immediately, so their relative order determines which
/// canonical form is produced. Folds that should be tried with both operand
/// orders live in visitADDLikeCommutative, which is called last with (N0, N1)
/// and then with (N1, N0).
///
/// Returns the replacement value, SDValue(N, 0) when SimplifyDemandedBits
/// reports a change, or an empty SDValue when no fold applies.
visitADDLike(SDNode * N)2797 SDValue DAGCombiner::visitADDLike(SDNode *N) {
2798   SDValue N0 = N->getOperand(0);
2799   SDValue N1 = N->getOperand(1);
2800   EVT VT = N0.getValueType();
2801   SDLoc DL(N);
2802 
2803   // fold (add x, undef) -> undef
2804   if (N0.isUndef())
2805     return N0;
2806   if (N1.isUndef())
2807     return N1;
2808 
2809   // fold (add c1, c2) -> c1+c2
2810   if (SDValue C = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N0, N1}))
2811     return C;
2812 
2813   // canonicalize constant to RHS
2814   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
2815       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
2816     return DAG.getNode(ISD::ADD, DL, VT, N1, N0);
2817 
  // fold (add x, (not x)) -> all-ones, in either operand order.
2818   if (areBitwiseNotOfEachother(N0, N1))
2819     return DAG.getConstant(APInt::getAllOnes(VT.getScalarSizeInBits()), DL, VT);
2820 
2821   // fold vector ops
2822   if (VT.isVector()) {
2823     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
2824       return FoldedVOp;
2825 
2826     // fold (add x, 0) -> x, vector edition
2827     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
2828       return N0;
2829   }
2830 
2831   // fold (add x, 0) -> x
2832   if (isNullConstant(N1))
2833     return N0;
2834 
2835   if (N0.getOpcode() == ISD::SUB) {
2836     SDValue N00 = N0.getOperand(0);
2837     SDValue N01 = N0.getOperand(1);
2838 
2839     // fold ((A-c1)+c2) -> (A+(c2-c1))
2840     if (SDValue Sub = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N1, N01}))
2841       return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), Sub);
2842 
2843     // fold ((c1-A)+c2) -> (c1+c2)-A
2844     if (SDValue Add = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N1, N00}))
2845       return DAG.getNode(ISD::SUB, DL, VT, Add, N0.getOperand(1));
2846   }
2847 
2848   // add (sext i1 X), 1 -> zext (not i1 X)
2849   // We don't transform this pattern:
2850   //   add (zext i1 X), -1 -> sext (not i1 X)
2851   // because most (?) targets generate better code for the zext form.
2852   if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.hasOneUse() &&
2853       isOneOrOneSplat(N1)) {
2854     SDValue X = N0.getOperand(0);
2855     if ((!LegalOperations ||
2856          (TLI.isOperationLegal(ISD::XOR, X.getValueType()) &&
2857           TLI.isOperationLegal(ISD::ZERO_EXTEND, VT))) &&
2858         X.getScalarValueSizeInBits() == 1) {
2859       SDValue Not = DAG.getNOT(DL, X, X.getValueType());
2860       return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Not);
2861     }
2862   }
2863 
2864   // Fold (add (or x, c0), c1) -> (add x, (c0 + c1))
2865   // iff (or x, c0) is equivalent to (add x, c0).
2866   // Fold (add (xor x, c0), c1) -> (add x, (c0 + c1))
2867   // iff (xor x, c0) is equivalent to (add x, c0).
2868   if (DAG.isADDLike(N0)) {
2869     SDValue N01 = N0.getOperand(1);
2870     if (SDValue Add = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N1, N01}))
2871       return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), Add);
2872   }
2873 
2874   if (SDValue NewSel = foldBinOpIntoSelect(N))
2875     return NewSel;
2876 
2877   // reassociate add
2878   if (!reassociationCanBreakAddressingModePattern(ISD::ADD, DL, N, N0, N1)) {
2879     if (SDValue RADD = reassociateOps(ISD::ADD, DL, N0, N1, N->getFlags()))
2880       return RADD;
2881 
2882     // Reassociate (add (or x, c), y) -> (add (add x, y), c) if (or x, c) is
2883     // equivalent to (add x, c).
2884     // Reassociate (add (xor x, c), y) -> (add (add x, y), c) if (xor x, c) is
2885     // equivalent to (add x, c).
2886     // Do this optimization only when adding c does not introduce instructions
2887     // for adding carries.
2888     auto ReassociateAddOr = [&](SDValue N0, SDValue N1) {
2889       if (DAG.isADDLike(N0) && N0.hasOneUse() &&
2890           isConstantOrConstantVector(N0.getOperand(1), /* NoOpaque */ true)) {
2891         // If N0's type does not split or is a sign mask, it does not introduce
2892         // add carry.
2893         auto TyActn = TLI.getTypeAction(*DAG.getContext(), N0.getValueType());
2894         bool NoAddCarry = TyActn == TargetLoweringBase::TypeLegal ||
2895                           TyActn == TargetLoweringBase::TypePromoteInteger ||
2896                           isMinSignedConstant(N0.getOperand(1));
2897         if (NoAddCarry)
2898           return DAG.getNode(
2899               ISD::ADD, DL, VT,
2900               DAG.getNode(ISD::ADD, DL, VT, N1, N0.getOperand(0)),
2901               N0.getOperand(1));
2902       }
2903       return SDValue();
2904     };
2905     if (SDValue Add = ReassociateAddOr(N0, N1))
2906       return Add;
2907     if (SDValue Add = ReassociateAddOr(N1, N0))
2908       return Add;
2909 
2910     // Fold add(vecreduce(x), vecreduce(y)) -> vecreduce(add(x, y))
2911     if (SDValue SD =
2912             reassociateReduction(ISD::VECREDUCE_ADD, ISD::ADD, DL, VT, N0, N1))
2913       return SD;
2914   }
2915 
  // Scratch values for the pattern matches below. Every fold binds (via
  // m_Value) each variable it later reads through m_Specific, so no value is
  // carried over from an earlier, failed match.
2916   SDValue A, B, C, D;
2917 
2918   // fold ((0-A) + B) -> B-A
2919   if (sd_match(N0, m_Neg(m_Value(A))))
2920     return DAG.getNode(ISD::SUB, DL, VT, N1, A);
2921 
2922   // fold (A + (0-B)) -> A-B
2923   if (sd_match(N1, m_Neg(m_Value(B))))
2924     return DAG.getNode(ISD::SUB, DL, VT, N0, B);
2925 
2926   // fold (A+(B-A)) -> B
2927   if (sd_match(N1, m_Sub(m_Value(B), m_Specific(N0))))
2928     return B;
2929 
2930   // fold ((B-A)+A) -> B
2931   if (sd_match(N0, m_Sub(m_Value(B), m_Specific(N1))))
2932     return B;
2933 
2934   // fold ((A-B)+(C-A)) -> (C-B)
2935   if (sd_match(N0, m_Sub(m_Value(A), m_Value(B))) &&
2936       sd_match(N1, m_Sub(m_Value(C), m_Specific(A))))
2937     return DAG.getNode(ISD::SUB, DL, VT, C, B);
2938 
2939   // fold ((A-B)+(B-C)) -> (A-C)
2940   if (sd_match(N0, m_Sub(m_Value(A), m_Value(B))) &&
2941       sd_match(N1, m_Sub(m_Specific(B), m_Value(C))))
2942     return DAG.getNode(ISD::SUB, DL, VT, A, C);
2943 
2944   // fold (A+(B-(A+C))) to (B-C)
2945   // fold (A+(B-(C+A))) to (B-C)
2946   if (sd_match(N1, m_Sub(m_Value(B), m_Add(m_Specific(N0), m_Value(C)))))
2947     return DAG.getNode(ISD::SUB, DL, VT, B, C);
2948 
2949   // fold (A+((B-A)+or-C)) to (B+or-C)
2950   if (sd_match(N1,
2951                m_AnyOf(m_Add(m_Sub(m_Value(B), m_Specific(N0)), m_Value(C)),
2952                        m_Sub(m_Sub(m_Value(B), m_Specific(N0)), m_Value(C)))))
2953     return DAG.getNode(N1.getOpcode(), DL, VT, B, C);
2954 
2955   // fold (A-B)+(C-D) to (A+C)-(B+D) when A or C is constant
2956   if (sd_match(N0, m_OneUse(m_Sub(m_Value(A), m_Value(B)))) &&
2957       sd_match(N1, m_OneUse(m_Sub(m_Value(C), m_Value(D)))) &&
2958       (isConstantOrConstantVector(A) || isConstantOrConstantVector(C)))
2959     return DAG.getNode(ISD::SUB, DL, VT,
2960                        DAG.getNode(ISD::ADD, SDLoc(N0), VT, A, C),
2961                        DAG.getNode(ISD::ADD, SDLoc(N1), VT, B, D));
2962 
2963   // fold (add (umax X, C), -C) --> (usubsat X, C)
2964   if (N0.getOpcode() == ISD::UMAX && hasOperation(ISD::USUBSAT, VT)) {
2965     auto MatchUSUBSAT = [](ConstantSDNode *Max, ConstantSDNode *Op) {
2966       return (!Max && !Op) ||
2967              (Max && Op && Max->getAPIntValue() == (-Op->getAPIntValue()));
2968     };
2969     if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchUSUBSAT,
2970                                   /*AllowUndefs*/ true))
2971       return DAG.getNode(ISD::USUBSAT, DL, VT, N0.getOperand(0),
2972                          N0.getOperand(1));
2973   }
2974 
  // NOTE(review): returning SDValue(N, 0) here presumably signals that
  // SimplifyDemandedBits has already updated the DAG itself; confirm against
  // the combiner's return-value convention.
2975   if (SimplifyDemandedBits(SDValue(N, 0)))
2976     return SDValue(N, 0);
2977 
2978   if (isOneOrOneSplat(N1)) {
2979     // fold (add (xor a, -1), 1) -> (sub 0, a)
2980     if (isBitwiseNot(N0))
2981       return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
2982                          N0.getOperand(0));
2983 
2984     // fold (add (add (xor a, -1), b), 1) -> (sub b, a)
2985     if (N0.getOpcode() == ISD::ADD) {
2986       SDValue A, Xor;
2987 
      // If both operands of the inner add are NOTs, operand 0 is used as the
      // NOT and operand 1 as the other addend.
2988       if (isBitwiseNot(N0.getOperand(0))) {
2989         A = N0.getOperand(1);
2990         Xor = N0.getOperand(0);
2991       } else if (isBitwiseNot(N0.getOperand(1))) {
2992         A = N0.getOperand(0);
2993         Xor = N0.getOperand(1);
2994       }
2995 
2996       if (Xor)
2997         return DAG.getNode(ISD::SUB, DL, VT, A, Xor.getOperand(0));
2998     }
2999 
3000     // Look for:
3001     //   add (add x, y), 1
3002     // And if the target does not like this form then turn into:
3003     //   sub y, (xor x, -1)
3004     if (!TLI.preferIncOfAddToSubOfNot(VT) && N0.getOpcode() == ISD::ADD &&
3005         N0.hasOneUse() &&
3006         // Limit this to after legalization if the add has wrap flags
3007         (Level >= AfterLegalizeDAG || (!N->getFlags().hasNoUnsignedWrap() &&
3008                                        !N->getFlags().hasNoSignedWrap()))) {
3009       SDValue Not = DAG.getNOT(DL, N0.getOperand(0), VT);
3010       return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(1), Not);
3011     }
3012   }
3013 
3014   // (x - y) + -1  ->  add (xor y, -1), x
3015   if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
3016       isAllOnesOrAllOnesSplat(N1, /*AllowUndefs=*/true)) {
3017     SDValue Not = DAG.getNOT(DL, N0.getOperand(1), VT);
3018     return DAG.getNode(ISD::ADD, DL, VT, Not, N0.getOperand(0));
3019   }
3020 
3021   // Fold add(mul(add(A, CA), CM), CB) -> add(mul(A, CM), CM*CA+CB).
3022   // This can help if the inner add has multiple uses.
  // The folded immediate is passed through APInt::getSExtValue(), which needs
  // it to fit in 64 bits; hence the scalar-size guard below.
3023   APInt CM, CA;
3024   if (ConstantSDNode *CB = dyn_cast<ConstantSDNode>(N1)) {
3025     if (VT.getScalarSizeInBits() <= 64) {
3026       if (sd_match(N0, m_OneUse(m_Mul(m_Add(m_Value(A), m_ConstInt(CA)),
3027                                       m_ConstInt(CM)))) &&
3028           TLI.isLegalAddImmediate(
3029               (CA * CM + CB->getAPIntValue()).getSExtValue())) {
3030         SDNodeFlags Flags;
3031         // If all the inputs are nuw, the outputs can be nuw. If all the inputs
3032         // are _also_ nsw the outputs can be too.
3033         if (N->getFlags().hasNoUnsignedWrap() &&
3034             N0->getFlags().hasNoUnsignedWrap() &&
3035             N0.getOperand(0)->getFlags().hasNoUnsignedWrap()) {
3036           Flags |= SDNodeFlags::NoUnsignedWrap;
3037           if (N->getFlags().hasNoSignedWrap() &&
3038               N0->getFlags().hasNoSignedWrap() &&
3039               N0.getOperand(0)->getFlags().hasNoSignedWrap())
3040             Flags |= SDNodeFlags::NoSignedWrap;
3041         }
3042         SDValue Mul = DAG.getNode(ISD::MUL, SDLoc(N1), VT, A,
3043                                   DAG.getConstant(CM, DL, VT), Flags);
3044         return DAG.getNode(
3045             ISD::ADD, DL, VT, Mul,
3046             DAG.getConstant(CA * CM + CB->getAPIntValue(), DL, VT), Flags);
3047       }
3048       // Also look in case there is an intermediate add.
3049       if (sd_match(N0, m_OneUse(m_Add(
3050                            m_OneUse(m_Mul(m_Add(m_Value(A), m_ConstInt(CA)),
3051                                           m_ConstInt(CM))),
3052                            m_Value(B)))) &&
3053           TLI.isLegalAddImmediate(
3054               (CA * CM + CB->getAPIntValue()).getSExtValue())) {
3055         SDNodeFlags Flags;
3056         // If all the inputs are nuw, the outputs can be nuw. If all the inputs
3057         // are _also_ nsw the outputs can be too.
        // The add may have matched its operands in either order: OMul is the
        // operand of N0 that is not B, i.e. the matched MUL.
3058         SDValue OMul =
3059             N0.getOperand(0) == B ? N0.getOperand(1) : N0.getOperand(0);
3060         if (N->getFlags().hasNoUnsignedWrap() &&
3061             N0->getFlags().hasNoUnsignedWrap() &&
3062             OMul->getFlags().hasNoUnsignedWrap() &&
3063             OMul.getOperand(0)->getFlags().hasNoUnsignedWrap()) {
3064           Flags |= SDNodeFlags::NoUnsignedWrap;
3065           if (N->getFlags().hasNoSignedWrap() &&
3066               N0->getFlags().hasNoSignedWrap() &&
3067               OMul->getFlags().hasNoSignedWrap() &&
3068               OMul.getOperand(0)->getFlags().hasNoSignedWrap())
3069             Flags |= SDNodeFlags::NoSignedWrap;
3070         }
3071         SDValue Mul = DAG.getNode(ISD::MUL, SDLoc(N1), VT, A,
3072                                   DAG.getConstant(CM, DL, VT), Flags);
3073         SDValue Add = DAG.getNode(ISD::ADD, SDLoc(N1), VT, Mul, B, Flags);
3074         return DAG.getNode(
3075             ISD::ADD, DL, VT, Add,
3076             DAG.getConstant(CA * CM + CB->getAPIntValue(), DL, VT), Flags);
3077       }
3078     }
3079   }
3080 
  // Order-sensitive folds, tried with both operand orders.
3081   if (SDValue Combined = visitADDLikeCommutative(N0, N1, N))
3082     return Combined;
3083 
3084   if (SDValue Combined = visitADDLikeCommutative(N1, N0, N))
3085     return Combined;
3086 
3087   return SDValue();
3088 }
3089 
3090 // Attempt to form avgfloor(A, B) from (A & B) + ((A ^ B) >> 1)
foldAddToAvg(SDNode * N,const SDLoc & DL)3091 SDValue DAGCombiner::foldAddToAvg(SDNode *N, const SDLoc &DL) {
3092   SDValue N0 = N->getOperand(0);
3093   EVT VT = N0.getValueType();
3094   SDValue A, B;
3095 
3096   if ((!LegalOperations || hasOperation(ISD::AVGFLOORU, VT)) &&
3097       sd_match(N, m_Add(m_And(m_Value(A), m_Value(B)),
3098                         m_Srl(m_Xor(m_Deferred(A), m_Deferred(B)), m_One())))) {
3099     return DAG.getNode(ISD::AVGFLOORU, DL, VT, A, B);
3100   }
3101   if ((!LegalOperations || hasOperation(ISD::AVGFLOORS, VT)) &&
3102       sd_match(N, m_Add(m_And(m_Value(A), m_Value(B)),
3103                         m_Sra(m_Xor(m_Deferred(A), m_Deferred(B)), m_One())))) {
3104     return DAG.getNode(ISD::AVGFLOORS, DL, VT, A, B);
3105   }
3106 
3107   return SDValue();
3108 }
3109 
visitADD(SDNode * N)3110 SDValue DAGCombiner::visitADD(SDNode *N) {
3111   SDValue N0 = N->getOperand(0);
3112   SDValue N1 = N->getOperand(1);
3113   EVT VT = N0.getValueType();
3114   SDLoc DL(N);
3115 
3116   if (SDValue Combined = visitADDLike(N))
3117     return Combined;
3118 
3119   if (SDValue V = foldAddSubBoolOfMaskedVal(N, DL, DAG))
3120     return V;
3121 
3122   if (SDValue V = foldAddSubOfSignBit(N, DL, DAG))
3123     return V;
3124 
3125   if (SDValue V = MatchRotate(N0, N1, SDLoc(N), /*FromAdd=*/true))
3126     return V;
3127 
3128   // Try to match AVGFLOOR fixedwidth pattern
3129   if (SDValue V = foldAddToAvg(N, DL))
3130     return V;
3131 
3132   // fold (a+b) -> (a|b) iff a and b share no bits.
3133   if ((!LegalOperations || TLI.isOperationLegal(ISD::OR, VT)) &&
3134       DAG.haveNoCommonBitsSet(N0, N1))
3135     return DAG.getNode(ISD::OR, DL, VT, N0, N1, SDNodeFlags::Disjoint);
3136 
3137   // Fold (add (vscale * C0), (vscale * C1)) to (vscale * (C0 + C1)).
3138   if (N0.getOpcode() == ISD::VSCALE && N1.getOpcode() == ISD::VSCALE) {
3139     const APInt &C0 = N0->getConstantOperandAPInt(0);
3140     const APInt &C1 = N1->getConstantOperandAPInt(0);
3141     return DAG.getVScale(DL, VT, C0 + C1);
3142   }
3143 
3144   // fold a+vscale(c1)+vscale(c2) -> a+vscale(c1+c2)
3145   if (N0.getOpcode() == ISD::ADD &&
3146       N0.getOperand(1).getOpcode() == ISD::VSCALE &&
3147       N1.getOpcode() == ISD::VSCALE) {
3148     const APInt &VS0 = N0.getOperand(1)->getConstantOperandAPInt(0);
3149     const APInt &VS1 = N1->getConstantOperandAPInt(0);
3150     SDValue VS = DAG.getVScale(DL, VT, VS0 + VS1);
3151     return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), VS);
3152   }
3153 
3154   // Fold (add step_vector(c1), step_vector(c2)  to step_vector(c1+c2))
3155   if (N0.getOpcode() == ISD::STEP_VECTOR &&
3156       N1.getOpcode() == ISD::STEP_VECTOR) {
3157     const APInt &C0 = N0->getConstantOperandAPInt(0);
3158     const APInt &C1 = N1->getConstantOperandAPInt(0);
3159     APInt NewStep = C0 + C1;
3160     return DAG.getStepVector(DL, VT, NewStep);
3161   }
3162 
3163   // Fold a + step_vector(c1) + step_vector(c2) to a + step_vector(c1+c2)
3164   if (N0.getOpcode() == ISD::ADD &&
3165       N0.getOperand(1).getOpcode() == ISD::STEP_VECTOR &&
3166       N1.getOpcode() == ISD::STEP_VECTOR) {
3167     const APInt &SV0 = N0.getOperand(1)->getConstantOperandAPInt(0);
3168     const APInt &SV1 = N1->getConstantOperandAPInt(0);
3169     APInt NewStep = SV0 + SV1;
3170     SDValue SV = DAG.getStepVector(DL, VT, NewStep);
3171     return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), SV);
3172   }
3173 
3174   return SDValue();
3175 }
3176 
visitADDSAT(SDNode * N)3177 SDValue DAGCombiner::visitADDSAT(SDNode *N) {
3178   unsigned Opcode = N->getOpcode();
3179   SDValue N0 = N->getOperand(0);
3180   SDValue N1 = N->getOperand(1);
3181   EVT VT = N0.getValueType();
3182   bool IsSigned = Opcode == ISD::SADDSAT;
3183   SDLoc DL(N);
3184 
3185   // fold (add_sat x, undef) -> -1
3186   if (N0.isUndef() || N1.isUndef())
3187     return DAG.getAllOnesConstant(DL, VT);
3188 
3189   // fold (add_sat c1, c2) -> c3
3190   if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
3191     return C;
3192 
3193   // canonicalize constant to RHS
3194   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
3195       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
3196     return DAG.getNode(Opcode, DL, VT, N1, N0);
3197 
3198   // fold vector ops
3199   if (VT.isVector()) {
3200     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
3201       return FoldedVOp;
3202 
3203     // fold (add_sat x, 0) -> x, vector edition
3204     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
3205       return N0;
3206   }
3207 
3208   // fold (add_sat x, 0) -> x
3209   if (isNullConstant(N1))
3210     return N0;
3211 
3212   // If it cannot overflow, transform into an add.
3213   if (DAG.willNotOverflowAdd(IsSigned, N0, N1))
3214     return DAG.getNode(ISD::ADD, DL, VT, N0, N1);
3215 
3216   return SDValue();
3217 }
3218 
getAsCarry(const TargetLowering & TLI,SDValue V,bool ForceCarryReconstruction=false)3219 static SDValue getAsCarry(const TargetLowering &TLI, SDValue V,
3220                           bool ForceCarryReconstruction = false) {
3221   bool Masked = false;
3222 
3223   // First, peel away TRUNCATE/ZERO_EXTEND/AND nodes due to legalization.
3224   while (true) {
3225     if (V.getOpcode() == ISD::TRUNCATE || V.getOpcode() == ISD::ZERO_EXTEND) {
3226       V = V.getOperand(0);
3227       continue;
3228     }
3229 
3230     if (V.getOpcode() == ISD::AND && isOneConstant(V.getOperand(1))) {
3231       if (ForceCarryReconstruction)
3232         return V;
3233 
3234       Masked = true;
3235       V = V.getOperand(0);
3236       continue;
3237     }
3238 
3239     if (ForceCarryReconstruction && V.getValueType() == MVT::i1)
3240       return V;
3241 
3242     break;
3243   }
3244 
3245   // If this is not a carry, return.
3246   if (V.getResNo() != 1)
3247     return SDValue();
3248 
3249   if (V.getOpcode() != ISD::UADDO_CARRY && V.getOpcode() != ISD::USUBO_CARRY &&
3250       V.getOpcode() != ISD::UADDO && V.getOpcode() != ISD::USUBO)
3251     return SDValue();
3252 
3253   EVT VT = V->getValueType(0);
3254   if (!TLI.isOperationLegalOrCustom(V.getOpcode(), VT))
3255     return SDValue();
3256 
3257   // If the result is masked, then no matter what kind of bool it is we can
3258   // return. If it isn't, then we need to make sure the bool type is either 0 or
3259   // 1 and not other values.
3260   if (Masked ||
3261       TLI.getBooleanContents(V.getValueType()) ==
3262           TargetLoweringBase::ZeroOrOneBooleanContent)
3263     return V;
3264 
3265   return SDValue();
3266 }
3267 
3268 /// Given the operands of an add/sub operation, see if the 2nd operand is a
3269 /// masked 0/1 whose source operand is actually known to be 0/-1. If so, invert
3270 /// the opcode and bypass the mask operation.
foldAddSubMasked1(bool IsAdd,SDValue N0,SDValue N1,SelectionDAG & DAG,const SDLoc & DL)3271 static SDValue foldAddSubMasked1(bool IsAdd, SDValue N0, SDValue N1,
3272                                  SelectionDAG &DAG, const SDLoc &DL) {
3273   if (N1.getOpcode() == ISD::ZERO_EXTEND)
3274     N1 = N1.getOperand(0);
3275 
3276   if (N1.getOpcode() != ISD::AND || !isOneOrOneSplat(N1->getOperand(1)))
3277     return SDValue();
3278 
3279   EVT VT = N0.getValueType();
3280   SDValue N10 = N1.getOperand(0);
3281   if (N10.getValueType() != VT && N10.getOpcode() == ISD::TRUNCATE)
3282     N10 = N10.getOperand(0);
3283 
3284   if (N10.getValueType() != VT)
3285     return SDValue();
3286 
3287   if (DAG.ComputeNumSignBits(N10) != VT.getScalarSizeInBits())
3288     return SDValue();
3289 
3290   // add N0, (and (AssertSext X, i1), 1) --> sub N0, X
3291   // sub N0, (and (AssertSext X, i1), 1) --> add N0, X
3292   return DAG.getNode(IsAdd ? ISD::SUB : ISD::ADD, DL, VT, N0, N10);
3293 }
3294 
3295 /// Helper for doing combines based on N0 and N1 being added to each other.
visitADDLikeCommutative(SDValue N0,SDValue N1,SDNode * LocReference)3296 SDValue DAGCombiner::visitADDLikeCommutative(SDValue N0, SDValue N1,
3297                                              SDNode *LocReference) {
3298   EVT VT = N0.getValueType();
3299   SDLoc DL(LocReference);
3300 
3301   // fold (add x, shl(0 - y, n)) -> sub(x, shl(y, n))
3302   SDValue Y, N;
3303   if (sd_match(N1, m_Shl(m_Neg(m_Value(Y)), m_Value(N))))
3304     return DAG.getNode(ISD::SUB, DL, VT, N0,
3305                        DAG.getNode(ISD::SHL, DL, VT, Y, N));
3306 
3307   if (SDValue V = foldAddSubMasked1(true, N0, N1, DAG, DL))
3308     return V;
3309 
3310   // Look for:
3311   //   add (add x, 1), y
3312   // And if the target does not like this form then turn into:
3313   //   sub y, (xor x, -1)
3314   if (!TLI.preferIncOfAddToSubOfNot(VT) && N0.getOpcode() == ISD::ADD &&
3315       N0.hasOneUse() && isOneOrOneSplat(N0.getOperand(1)) &&
3316       // Limit this to after legalization if the add has wrap flags
3317       (Level >= AfterLegalizeDAG || (!N0->getFlags().hasNoUnsignedWrap() &&
3318                                      !N0->getFlags().hasNoSignedWrap()))) {
3319     SDValue Not = DAG.getNOT(DL, N0.getOperand(0), VT);
3320     return DAG.getNode(ISD::SUB, DL, VT, N1, Not);
3321   }
3322 
3323   if (N0.getOpcode() == ISD::SUB && N0.hasOneUse()) {
3324     // Hoist one-use subtraction by non-opaque constant:
3325     //   (x - C) + y  ->  (x + y) - C
3326     // This is necessary because SUB(X,C) -> ADD(X,-C) doesn't work for vectors.
3327     if (isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true)) {
3328       SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), N1);
3329       return DAG.getNode(ISD::SUB, DL, VT, Add, N0.getOperand(1));
3330     }
3331     // Hoist one-use subtraction from non-opaque constant:
3332     //   (C - x) + y  ->  (y - x) + C
3333     if (isConstantOrConstantVector(N0.getOperand(0), /*NoOpaques=*/true)) {
3334       SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N1, N0.getOperand(1));
3335       return DAG.getNode(ISD::ADD, DL, VT, Sub, N0.getOperand(0));
3336     }
3337   }
3338 
3339   // add (mul x, C), x -> mul x, C+1
3340   if (N0.getOpcode() == ISD::MUL && N0.getOperand(0) == N1 &&
3341       isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true) &&
3342       N0.hasOneUse()) {
3343     SDValue NewC = DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(1),
3344                                DAG.getConstant(1, DL, VT));
3345     return DAG.getNode(ISD::MUL, DL, VT, N0.getOperand(0), NewC);
3346   }
3347 
3348   // If the target's bool is represented as 0/1, prefer to make this 'sub 0/1'
3349   // rather than 'add 0/-1' (the zext should get folded).
3350   // add (sext i1 Y), X --> sub X, (zext i1 Y)
3351   if (N0.getOpcode() == ISD::SIGN_EXTEND &&
3352       N0.getOperand(0).getScalarValueSizeInBits() == 1 &&
3353       TLI.getBooleanContents(VT) == TargetLowering::ZeroOrOneBooleanContent) {
3354     SDValue ZExt = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0));
3355     return DAG.getNode(ISD::SUB, DL, VT, N1, ZExt);
3356   }
3357 
3358   // add X, (sextinreg Y i1) -> sub X, (and Y 1)
3359   if (N1.getOpcode() == ISD::SIGN_EXTEND_INREG) {
3360     VTSDNode *TN = cast<VTSDNode>(N1.getOperand(1));
3361     if (TN->getVT() == MVT::i1) {
3362       SDValue ZExt = DAG.getNode(ISD::AND, DL, VT, N1.getOperand(0),
3363                                  DAG.getConstant(1, DL, VT));
3364       return DAG.getNode(ISD::SUB, DL, VT, N0, ZExt);
3365     }
3366   }
3367 
3368   // (add X, (uaddo_carry Y, 0, Carry)) -> (uaddo_carry X, Y, Carry)
3369   if (N1.getOpcode() == ISD::UADDO_CARRY && isNullConstant(N1.getOperand(1)) &&
3370       N1.getResNo() == 0)
3371     return DAG.getNode(ISD::UADDO_CARRY, DL, N1->getVTList(),
3372                        N0, N1.getOperand(0), N1.getOperand(2));
3373 
3374   // (add X, Carry) -> (uaddo_carry X, 0, Carry)
3375   if (TLI.isOperationLegalOrCustom(ISD::UADDO_CARRY, VT))
3376     if (SDValue Carry = getAsCarry(TLI, N1))
3377       return DAG.getNode(ISD::UADDO_CARRY, DL,
3378                          DAG.getVTList(VT, Carry.getValueType()), N0,
3379                          DAG.getConstant(0, DL, VT), Carry);
3380 
3381   return SDValue();
3382 }
3383 
visitADDC(SDNode * N)3384 SDValue DAGCombiner::visitADDC(SDNode *N) {
3385   SDValue N0 = N->getOperand(0);
3386   SDValue N1 = N->getOperand(1);
3387   EVT VT = N0.getValueType();
3388   SDLoc DL(N);
3389 
3390   // If the flag result is dead, turn this into an ADD.
3391   if (!N->hasAnyUseOfValue(1))
3392     return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
3393                      DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
3394 
3395   // canonicalize constant to RHS.
3396   ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
3397   ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
3398   if (N0C && !N1C)
3399     return DAG.getNode(ISD::ADDC, DL, N->getVTList(), N1, N0);
3400 
3401   // fold (addc x, 0) -> x + no carry out
3402   if (isNullConstant(N1))
3403     return CombineTo(N, N0, DAG.getNode(ISD::CARRY_FALSE,
3404                                         DL, MVT::Glue));
3405 
3406   // If it cannot overflow, transform into an add.
3407   if (DAG.computeOverflowForUnsignedAdd(N0, N1) == SelectionDAG::OFK_Never)
3408     return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
3409                      DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
3410 
3411   return SDValue();
3412 }
3413 
3414 /**
3415  * Flips a boolean if it is cheaper to compute. If the Force parameters is set,
3416  * then the flip also occurs if computing the inverse is the same cost.
3417  * This function returns an empty SDValue in case it cannot flip the boolean
3418  * without increasing the cost of the computation. If you want to flip a boolean
3419  * no matter what, use DAG.getLogicalNOT.
3420  */
extractBooleanFlip(SDValue V,SelectionDAG & DAG,const TargetLowering & TLI,bool Force)3421 static SDValue extractBooleanFlip(SDValue V, SelectionDAG &DAG,
3422                                   const TargetLowering &TLI,
3423                                   bool Force) {
3424   if (Force && isa<ConstantSDNode>(V))
3425     return DAG.getLogicalNOT(SDLoc(V), V, V.getValueType());
3426 
3427   if (V.getOpcode() != ISD::XOR)
3428     return SDValue();
3429 
3430   if (DAG.isBoolConstant(V.getOperand(1)) == true)
3431     return V.getOperand(0);
3432   if (Force && isConstOrConstSplat(V.getOperand(1), false))
3433     return DAG.getLogicalNOT(SDLoc(V), V, V.getValueType());
3434   return SDValue();
3435 }
3436 
visitADDO(SDNode * N)3437 SDValue DAGCombiner::visitADDO(SDNode *N) {
3438   SDValue N0 = N->getOperand(0);
3439   SDValue N1 = N->getOperand(1);
3440   EVT VT = N0.getValueType();
3441   bool IsSigned = (ISD::SADDO == N->getOpcode());
3442 
3443   EVT CarryVT = N->getValueType(1);
3444   SDLoc DL(N);
3445 
3446   // If the flag result is dead, turn this into an ADD.
3447   if (!N->hasAnyUseOfValue(1))
3448     return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
3449                      DAG.getUNDEF(CarryVT));
3450 
3451   // canonicalize constant to RHS.
3452   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
3453       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
3454     return DAG.getNode(N->getOpcode(), DL, N->getVTList(), N1, N0);
3455 
3456   // fold (addo x, 0) -> x + no carry out
3457   if (isNullOrNullSplat(N1))
3458     return CombineTo(N, N0, DAG.getConstant(0, DL, CarryVT));
3459 
3460   // If it cannot overflow, transform into an add.
3461   if (DAG.willNotOverflowAdd(IsSigned, N0, N1))
3462     return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
3463                      DAG.getConstant(0, DL, CarryVT));
3464 
3465   if (IsSigned) {
3466     // fold (saddo (xor a, -1), 1) -> (ssub 0, a).
3467     if (isBitwiseNot(N0) && isOneOrOneSplat(N1))
3468       return DAG.getNode(ISD::SSUBO, DL, N->getVTList(),
3469                          DAG.getConstant(0, DL, VT), N0.getOperand(0));
3470   } else {
3471     // fold (uaddo (xor a, -1), 1) -> (usub 0, a) and flip carry.
3472     if (isBitwiseNot(N0) && isOneOrOneSplat(N1)) {
3473       SDValue Sub = DAG.getNode(ISD::USUBO, DL, N->getVTList(),
3474                                 DAG.getConstant(0, DL, VT), N0.getOperand(0));
3475       return CombineTo(
3476           N, Sub, DAG.getLogicalNOT(DL, Sub.getValue(1), Sub->getValueType(1)));
3477     }
3478 
3479     if (SDValue Combined = visitUADDOLike(N0, N1, N))
3480       return Combined;
3481 
3482     if (SDValue Combined = visitUADDOLike(N1, N0, N))
3483       return Combined;
3484   }
3485 
3486   return SDValue();
3487 }
3488 
visitUADDOLike(SDValue N0,SDValue N1,SDNode * N)3489 SDValue DAGCombiner::visitUADDOLike(SDValue N0, SDValue N1, SDNode *N) {
3490   EVT VT = N0.getValueType();
3491   if (VT.isVector())
3492     return SDValue();
3493 
3494   // (uaddo X, (uaddo_carry Y, 0, Carry)) -> (uaddo_carry X, Y, Carry)
3495   // If Y + 1 cannot overflow.
3496   if (N1.getOpcode() == ISD::UADDO_CARRY && isNullConstant(N1.getOperand(1))) {
3497     SDValue Y = N1.getOperand(0);
3498     SDValue One = DAG.getConstant(1, SDLoc(N), Y.getValueType());
3499     if (DAG.computeOverflowForUnsignedAdd(Y, One) == SelectionDAG::OFK_Never)
3500       return DAG.getNode(ISD::UADDO_CARRY, SDLoc(N), N->getVTList(), N0, Y,
3501                          N1.getOperand(2));
3502   }
3503 
3504   // (uaddo X, Carry) -> (uaddo_carry X, 0, Carry)
3505   if (TLI.isOperationLegalOrCustom(ISD::UADDO_CARRY, VT))
3506     if (SDValue Carry = getAsCarry(TLI, N1))
3507       return DAG.getNode(ISD::UADDO_CARRY, SDLoc(N), N->getVTList(), N0,
3508                          DAG.getConstant(0, SDLoc(N), VT), Carry);
3509 
3510   return SDValue();
3511 }
3512 
visitADDE(SDNode * N)3513 SDValue DAGCombiner::visitADDE(SDNode *N) {
3514   SDValue N0 = N->getOperand(0);
3515   SDValue N1 = N->getOperand(1);
3516   SDValue CarryIn = N->getOperand(2);
3517 
3518   // canonicalize constant to RHS
3519   ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
3520   ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
3521   if (N0C && !N1C)
3522     return DAG.getNode(ISD::ADDE, SDLoc(N), N->getVTList(),
3523                        N1, N0, CarryIn);
3524 
3525   // fold (adde x, y, false) -> (addc x, y)
3526   if (CarryIn.getOpcode() == ISD::CARRY_FALSE)
3527     return DAG.getNode(ISD::ADDC, SDLoc(N), N->getVTList(), N0, N1);
3528 
3529   return SDValue();
3530 }
3531 
visitUADDO_CARRY(SDNode * N)3532 SDValue DAGCombiner::visitUADDO_CARRY(SDNode *N) {
3533   SDValue N0 = N->getOperand(0);
3534   SDValue N1 = N->getOperand(1);
3535   SDValue CarryIn = N->getOperand(2);
3536   SDLoc DL(N);
3537 
3538   // canonicalize constant to RHS
3539   ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
3540   ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
3541   if (N0C && !N1C)
3542     return DAG.getNode(ISD::UADDO_CARRY, DL, N->getVTList(), N1, N0, CarryIn);
3543 
3544   // fold (uaddo_carry x, y, false) -> (uaddo x, y)
3545   if (isNullConstant(CarryIn)) {
3546     if (!LegalOperations ||
3547         TLI.isOperationLegalOrCustom(ISD::UADDO, N->getValueType(0)))
3548       return DAG.getNode(ISD::UADDO, DL, N->getVTList(), N0, N1);
3549   }
3550 
3551   // fold (uaddo_carry 0, 0, X) -> (and (ext/trunc X), 1) and no carry.
3552   if (isNullConstant(N0) && isNullConstant(N1)) {
3553     EVT VT = N0.getValueType();
3554     EVT CarryVT = CarryIn.getValueType();
3555     SDValue CarryExt = DAG.getBoolExtOrTrunc(CarryIn, DL, VT, CarryVT);
3556     AddToWorklist(CarryExt.getNode());
3557     return CombineTo(N, DAG.getNode(ISD::AND, DL, VT, CarryExt,
3558                                     DAG.getConstant(1, DL, VT)),
3559                      DAG.getConstant(0, DL, CarryVT));
3560   }
3561 
3562   if (SDValue Combined = visitUADDO_CARRYLike(N0, N1, CarryIn, N))
3563     return Combined;
3564 
3565   if (SDValue Combined = visitUADDO_CARRYLike(N1, N0, CarryIn, N))
3566     return Combined;
3567 
3568   // We want to avoid useless duplication.
3569   // TODO: This is done automatically for binary operations. As UADDO_CARRY is
3570   // not a binary operation, this is not really possible to leverage this
3571   // existing mechanism for it. However, if more operations require the same
3572   // deduplication logic, then it may be worth generalize.
3573   SDValue Ops[] = {N1, N0, CarryIn};
3574   SDNode *CSENode =
3575       DAG.getNodeIfExists(ISD::UADDO_CARRY, N->getVTList(), Ops, N->getFlags());
3576   if (CSENode)
3577     return SDValue(CSENode, 0);
3578 
3579   return SDValue();
3580 }
3581 
3582 /**
3583  * If we are facing some sort of diamond carry propagation pattern try to
3584  * break it up to generate something like:
3585  *   (uaddo_carry X, 0, (uaddo_carry A, B, Z):Carry)
3586  *
3587  * The end result is usually an increase in operation required, but because the
3588  * carry is now linearized, other transforms can kick in and optimize the DAG.
3589  *
3590  * Patterns typically look something like
3591  *                (uaddo A, B)
3592  *                /          \
3593  *             Carry         Sum
3594  *               |             \
3595  *               | (uaddo_carry *, 0, Z)
3596  *               |       /
3597  *                \   Carry
3598  *                 |   /
3599  * (uaddo_carry X, *, *)
3600  *
3601  * But numerous variation exist. Our goal is to identify A, B, X and Z and
3602  * produce a combine with a single path for carry propagation.
3603  */
combineUADDO_CARRYDiamond(DAGCombiner & Combiner,SelectionDAG & DAG,SDValue X,SDValue Carry0,SDValue Carry1,SDNode * N)3604 static SDValue combineUADDO_CARRYDiamond(DAGCombiner &Combiner,
3605                                          SelectionDAG &DAG, SDValue X,
3606                                          SDValue Carry0, SDValue Carry1,
3607                                          SDNode *N) {
3608   if (Carry1.getResNo() != 1 || Carry0.getResNo() != 1)
3609     return SDValue();
3610   if (Carry1.getOpcode() != ISD::UADDO)
3611     return SDValue();
3612 
3613   SDValue Z;
3614 
3615   /**
3616    * First look for a suitable Z. It will present itself in the form of
3617    * (uaddo_carry Y, 0, Z) or its equivalent (uaddo Y, 1) for Z=true
3618    */
3619   if (Carry0.getOpcode() == ISD::UADDO_CARRY &&
3620       isNullConstant(Carry0.getOperand(1))) {
3621     Z = Carry0.getOperand(2);
3622   } else if (Carry0.getOpcode() == ISD::UADDO &&
3623              isOneConstant(Carry0.getOperand(1))) {
3624     EVT VT = Carry0->getValueType(1);
3625     Z = DAG.getConstant(1, SDLoc(Carry0.getOperand(1)), VT);
3626   } else {
3627     // We couldn't find a suitable Z.
3628     return SDValue();
3629   }
3630 
3631 
3632   auto cancelDiamond = [&](SDValue A,SDValue B) {
3633     SDLoc DL(N);
3634     SDValue NewY =
3635         DAG.getNode(ISD::UADDO_CARRY, DL, Carry0->getVTList(), A, B, Z);
3636     Combiner.AddToWorklist(NewY.getNode());
3637     return DAG.getNode(ISD::UADDO_CARRY, DL, N->getVTList(), X,
3638                        DAG.getConstant(0, DL, X.getValueType()),
3639                        NewY.getValue(1));
3640   };
3641 
3642   /**
3643    *         (uaddo A, B)
3644    *              |
3645    *             Sum
3646    *              |
3647    * (uaddo_carry *, 0, Z)
3648    */
3649   if (Carry0.getOperand(0) == Carry1.getValue(0)) {
3650     return cancelDiamond(Carry1.getOperand(0), Carry1.getOperand(1));
3651   }
3652 
3653   /**
3654    * (uaddo_carry A, 0, Z)
3655    *         |
3656    *        Sum
3657    *         |
3658    *  (uaddo *, B)
3659    */
3660   if (Carry1.getOperand(0) == Carry0.getValue(0)) {
3661     return cancelDiamond(Carry0.getOperand(0), Carry1.getOperand(1));
3662   }
3663 
3664   if (Carry1.getOperand(1) == Carry0.getValue(0)) {
3665     return cancelDiamond(Carry1.getOperand(0), Carry0.getOperand(0));
3666   }
3667 
3668   return SDValue();
3669 }
3670 
3671 // If we are facing some sort of diamond carry/borrow in/out pattern try to
3672 // match patterns like:
3673 //
3674 //          (uaddo A, B)            CarryIn
3675 //            |  \                     |
3676 //            |   \                    |
3677 //    PartialSum   PartialCarryOutX   /
3678 //            |        |             /
3679 //            |    ____|____________/
3680 //            |   /    |
3681 //     (uaddo *, *)    \________
3682 //       |  \                   \
3683 //       |   \                   |
3684 //       |    PartialCarryOutY   |
3685 //       |        \              |
3686 //       |         \            /
3687 //   AddCarrySum    |    ______/
3688 //                  |   /
3689 //   CarryOut = (or *, *)
3690 //
3691 // And generate UADDO_CARRY (or USUBO_CARRY) with two result values:
3692 //
3693 //    {AddCarrySum, CarryOut} = (uaddo_carry A, B, CarryIn)
3694 //
3695 // Our goal is to identify A, B, and CarryIn and produce UADDO_CARRY/USUBO_CARRY
3696 // with a single path for carry/borrow out propagation.
combineCarryDiamond(SelectionDAG & DAG,const TargetLowering & TLI,SDValue N0,SDValue N1,SDNode * N)3697 static SDValue combineCarryDiamond(SelectionDAG &DAG, const TargetLowering &TLI,
3698                                    SDValue N0, SDValue N1, SDNode *N) {
3699   SDValue Carry0 = getAsCarry(TLI, N0);
3700   if (!Carry0)
3701     return SDValue();
3702   SDValue Carry1 = getAsCarry(TLI, N1);
3703   if (!Carry1)
3704     return SDValue();
3705 
3706   unsigned Opcode = Carry0.getOpcode();
3707   if (Opcode != Carry1.getOpcode())
3708     return SDValue();
3709   if (Opcode != ISD::UADDO && Opcode != ISD::USUBO)
3710     return SDValue();
3711   // Guarantee identical type of CarryOut
3712   EVT CarryOutType = N->getValueType(0);
3713   if (CarryOutType != Carry0.getValue(1).getValueType() ||
3714       CarryOutType != Carry1.getValue(1).getValueType())
3715     return SDValue();
3716 
3717   // Canonicalize the add/sub of A and B (the top node in the above ASCII art)
3718   // as Carry0 and the add/sub of the carry in as Carry1 (the middle node).
3719   if (Carry1.getNode()->isOperandOf(Carry0.getNode()))
3720     std::swap(Carry0, Carry1);
3721 
3722   // Check if nodes are connected in expected way.
3723   if (Carry1.getOperand(0) != Carry0.getValue(0) &&
3724       Carry1.getOperand(1) != Carry0.getValue(0))
3725     return SDValue();
3726 
3727   // The carry in value must be on the righthand side for subtraction.
3728   unsigned CarryInOperandNum =
3729       Carry1.getOperand(0) == Carry0.getValue(0) ? 1 : 0;
3730   if (Opcode == ISD::USUBO && CarryInOperandNum != 1)
3731     return SDValue();
3732   SDValue CarryIn = Carry1.getOperand(CarryInOperandNum);
3733 
3734   unsigned NewOp = Opcode == ISD::UADDO ? ISD::UADDO_CARRY : ISD::USUBO_CARRY;
3735   if (!TLI.isOperationLegalOrCustom(NewOp, Carry0.getValue(0).getValueType()))
3736     return SDValue();
3737 
3738   // Verify that the carry/borrow in is plausibly a carry/borrow bit.
3739   CarryIn = getAsCarry(TLI, CarryIn, true);
3740   if (!CarryIn)
3741     return SDValue();
3742 
3743   SDLoc DL(N);
3744   CarryIn = DAG.getBoolExtOrTrunc(CarryIn, DL, Carry1->getValueType(1),
3745                                   Carry1->getValueType(0));
3746   SDValue Merged =
3747       DAG.getNode(NewOp, DL, Carry1->getVTList(), Carry0.getOperand(0),
3748                   Carry0.getOperand(1), CarryIn);
3749 
3750   // Please note that because we have proven that the result of the UADDO/USUBO
3751   // of A and B feeds into the UADDO/USUBO that does the carry/borrow in, we can
3752   // therefore prove that if the first UADDO/USUBO overflows, the second
3753   // UADDO/USUBO cannot. For example consider 8-bit numbers where 0xFF is the
3754   // maximum value.
3755   //
3756   //   0xFF + 0xFF == 0xFE with carry but 0xFE + 1 does not carry
3757   //   0x00 - 0xFF == 1 with a carry/borrow but 1 - 1 == 0 (no carry/borrow)
3758   //
3759   // This is important because it means that OR and XOR can be used to merge
3760   // carry flags; and that AND can return a constant zero.
3761   //
3762   // TODO: match other operations that can merge flags (ADD, etc)
3763   DAG.ReplaceAllUsesOfValueWith(Carry1.getValue(0), Merged.getValue(0));
3764   if (N->getOpcode() == ISD::AND)
3765     return DAG.getConstant(0, DL, CarryOutType);
3766   return Merged.getValue(1);
3767 }
3768 
visitUADDO_CARRYLike(SDValue N0,SDValue N1,SDValue CarryIn,SDNode * N)3769 SDValue DAGCombiner::visitUADDO_CARRYLike(SDValue N0, SDValue N1,
3770                                           SDValue CarryIn, SDNode *N) {
3771   // fold (uaddo_carry (xor a, -1), b, c) -> (usubo_carry b, a, !c) and flip
3772   // carry.
3773   if (isBitwiseNot(N0))
3774     if (SDValue NotC = extractBooleanFlip(CarryIn, DAG, TLI, true)) {
3775       SDLoc DL(N);
3776       SDValue Sub = DAG.getNode(ISD::USUBO_CARRY, DL, N->getVTList(), N1,
3777                                 N0.getOperand(0), NotC);
3778       return CombineTo(
3779           N, Sub, DAG.getLogicalNOT(DL, Sub.getValue(1), Sub->getValueType(1)));
3780     }
3781 
3782   // Iff the flag result is dead:
3783   // (uaddo_carry (add|uaddo X, Y), 0, Carry) -> (uaddo_carry X, Y, Carry)
3784   // Don't do this if the Carry comes from the uaddo. It won't remove the uaddo
3785   // or the dependency between the instructions.
3786   if ((N0.getOpcode() == ISD::ADD ||
3787        (N0.getOpcode() == ISD::UADDO && N0.getResNo() == 0 &&
3788         N0.getValue(1) != CarryIn)) &&
3789       isNullConstant(N1) && !N->hasAnyUseOfValue(1))
3790     return DAG.getNode(ISD::UADDO_CARRY, SDLoc(N), N->getVTList(),
3791                        N0.getOperand(0), N0.getOperand(1), CarryIn);
3792 
3793   /**
3794    * When one of the uaddo_carry argument is itself a carry, we may be facing
3795    * a diamond carry propagation. In which case we try to transform the DAG
3796    * to ensure linear carry propagation if that is possible.
3797    */
3798   if (auto Y = getAsCarry(TLI, N1)) {
3799     // Because both are carries, Y and Z can be swapped.
3800     if (auto R = combineUADDO_CARRYDiamond(*this, DAG, N0, Y, CarryIn, N))
3801       return R;
3802     if (auto R = combineUADDO_CARRYDiamond(*this, DAG, N0, CarryIn, Y, N))
3803       return R;
3804   }
3805 
3806   return SDValue();
3807 }
3808 
visitSADDO_CARRYLike(SDValue N0,SDValue N1,SDValue CarryIn,SDNode * N)3809 SDValue DAGCombiner::visitSADDO_CARRYLike(SDValue N0, SDValue N1,
3810                                           SDValue CarryIn, SDNode *N) {
3811   // fold (saddo_carry (xor a, -1), b, c) -> (ssubo_carry b, a, !c)
3812   if (isBitwiseNot(N0)) {
3813     if (SDValue NotC = extractBooleanFlip(CarryIn, DAG, TLI, true))
3814       return DAG.getNode(ISD::SSUBO_CARRY, SDLoc(N), N->getVTList(), N1,
3815                          N0.getOperand(0), NotC);
3816   }
3817 
3818   return SDValue();
3819 }
3820 
visitSADDO_CARRY(SDNode * N)3821 SDValue DAGCombiner::visitSADDO_CARRY(SDNode *N) {
3822   SDValue N0 = N->getOperand(0);
3823   SDValue N1 = N->getOperand(1);
3824   SDValue CarryIn = N->getOperand(2);
3825   SDLoc DL(N);
3826 
3827   // canonicalize constant to RHS
3828   ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
3829   ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
3830   if (N0C && !N1C)
3831     return DAG.getNode(ISD::SADDO_CARRY, DL, N->getVTList(), N1, N0, CarryIn);
3832 
3833   // fold (saddo_carry x, y, false) -> (saddo x, y)
3834   if (isNullConstant(CarryIn)) {
3835     if (!LegalOperations ||
3836         TLI.isOperationLegalOrCustom(ISD::SADDO, N->getValueType(0)))
3837       return DAG.getNode(ISD::SADDO, DL, N->getVTList(), N0, N1);
3838   }
3839 
3840   if (SDValue Combined = visitSADDO_CARRYLike(N0, N1, CarryIn, N))
3841     return Combined;
3842 
3843   if (SDValue Combined = visitSADDO_CARRYLike(N1, N0, CarryIn, N))
3844     return Combined;
3845 
3846   return SDValue();
3847 }
3848 
3849 // Attempt to create a USUBSAT(LHS, RHS) node with DstVT, performing a
3850 // clamp/truncation if necessary.
getTruncatedUSUBSAT(EVT DstVT,EVT SrcVT,SDValue LHS,SDValue RHS,SelectionDAG & DAG,const SDLoc & DL)3851 static SDValue getTruncatedUSUBSAT(EVT DstVT, EVT SrcVT, SDValue LHS,
3852                                    SDValue RHS, SelectionDAG &DAG,
3853                                    const SDLoc &DL) {
3854   assert(DstVT.getScalarSizeInBits() <= SrcVT.getScalarSizeInBits() &&
3855          "Illegal truncation");
3856 
3857   if (DstVT == SrcVT)
3858     return DAG.getNode(ISD::USUBSAT, DL, DstVT, LHS, RHS);
3859 
3860   // If the LHS is zero-extended then we can perform the USUBSAT as DstVT by
3861   // clamping RHS.
3862   APInt UpperBits = APInt::getBitsSetFrom(SrcVT.getScalarSizeInBits(),
3863                                           DstVT.getScalarSizeInBits());
3864   if (!DAG.MaskedValueIsZero(LHS, UpperBits))
3865     return SDValue();
3866 
3867   SDValue SatLimit =
3868       DAG.getConstant(APInt::getLowBitsSet(SrcVT.getScalarSizeInBits(),
3869                                            DstVT.getScalarSizeInBits()),
3870                       DL, SrcVT);
3871   RHS = DAG.getNode(ISD::UMIN, DL, SrcVT, RHS, SatLimit);
3872   RHS = DAG.getNode(ISD::TRUNCATE, DL, DstVT, RHS);
3873   LHS = DAG.getNode(ISD::TRUNCATE, DL, DstVT, LHS);
3874   return DAG.getNode(ISD::USUBSAT, DL, DstVT, LHS, RHS);
3875 }
3876 
3877 // Try to find umax(a,b) - b or a - umin(a,b) patterns that may be converted to
3878 // usubsat(a,b), optionally as a truncated type.
foldSubToUSubSat(EVT DstVT,SDNode * N,const SDLoc & DL)3879 SDValue DAGCombiner::foldSubToUSubSat(EVT DstVT, SDNode *N, const SDLoc &DL) {
3880   if (N->getOpcode() != ISD::SUB ||
3881       !(!LegalOperations || hasOperation(ISD::USUBSAT, DstVT)))
3882     return SDValue();
3883 
3884   EVT SubVT = N->getValueType(0);
3885   SDValue Op0 = N->getOperand(0);
3886   SDValue Op1 = N->getOperand(1);
3887 
3888   // Try to find umax(a,b) - b or a - umin(a,b) patterns
3889   // they may be converted to usubsat(a,b).
3890   if (Op0.getOpcode() == ISD::UMAX && Op0.hasOneUse()) {
3891     SDValue MaxLHS = Op0.getOperand(0);
3892     SDValue MaxRHS = Op0.getOperand(1);
3893     if (MaxLHS == Op1)
3894       return getTruncatedUSUBSAT(DstVT, SubVT, MaxRHS, Op1, DAG, DL);
3895     if (MaxRHS == Op1)
3896       return getTruncatedUSUBSAT(DstVT, SubVT, MaxLHS, Op1, DAG, DL);
3897   }
3898 
3899   if (Op1.getOpcode() == ISD::UMIN && Op1.hasOneUse()) {
3900     SDValue MinLHS = Op1.getOperand(0);
3901     SDValue MinRHS = Op1.getOperand(1);
3902     if (MinLHS == Op0)
3903       return getTruncatedUSUBSAT(DstVT, SubVT, Op0, MinRHS, DAG, DL);
3904     if (MinRHS == Op0)
3905       return getTruncatedUSUBSAT(DstVT, SubVT, Op0, MinLHS, DAG, DL);
3906   }
3907 
3908   // sub(a,trunc(umin(zext(a),b))) -> usubsat(a,trunc(umin(b,SatLimit)))
3909   if (Op1.getOpcode() == ISD::TRUNCATE &&
3910       Op1.getOperand(0).getOpcode() == ISD::UMIN &&
3911       Op1.getOperand(0).hasOneUse()) {
3912     SDValue MinLHS = Op1.getOperand(0).getOperand(0);
3913     SDValue MinRHS = Op1.getOperand(0).getOperand(1);
3914     if (MinLHS.getOpcode() == ISD::ZERO_EXTEND && MinLHS.getOperand(0) == Op0)
3915       return getTruncatedUSUBSAT(DstVT, MinLHS.getValueType(), MinLHS, MinRHS,
3916                                  DAG, DL);
3917     if (MinRHS.getOpcode() == ISD::ZERO_EXTEND && MinRHS.getOperand(0) == Op0)
3918       return getTruncatedUSUBSAT(DstVT, MinLHS.getValueType(), MinRHS, MinLHS,
3919                                  DAG, DL);
3920   }
3921 
3922   return SDValue();
3923 }
3924 
3925 // Refinement of DAG/Type Legalisation (promotion) when CTLZ is used for
3926 // counting leading ones. Broadly, it replaces the substraction with a left
3927 // shift.
3928 //
3929 // * DAG Legalisation Pattern:
3930 //
3931 //     (sub (ctlz (zeroextend (not Src)))
3932 //          BitWidthDiff)
3933 //
3934 //       if BitWidthDiff == BitWidth(Node) - BitWidth(Src)
3935 //       -->
3936 //
3937 //     (ctlz_zero_undef (not (shl (anyextend Src)
3938 //                                BitWidthDiff)))
3939 //
3940 // * Type Legalisation Pattern:
3941 //
3942 //     (sub (ctlz (and (xor Src XorMask)
3943 //                     AndMask))
3944 //          BitWidthDiff)
3945 //
3946 //       if AndMask has only trailing ones
3947 //       and MaskBitWidth(AndMask) == BitWidth(Node) - BitWidthDiff
3948 //       and XorMask has more trailing ones than AndMask
3949 //       -->
3950 //
3951 //     (ctlz_zero_undef (not (shl Src BitWidthDiff)))
3952 template <class MatchContextClass>
foldSubCtlzNot(SDNode * N,SelectionDAG & DAG)3953 static SDValue foldSubCtlzNot(SDNode *N, SelectionDAG &DAG) {
3954   const SDLoc DL(N);
3955   SDValue N0 = N->getOperand(0);
3956   EVT VT = N0.getValueType();
3957   unsigned BitWidth = VT.getScalarSizeInBits();
3958 
3959   MatchContextClass Matcher(DAG, DAG.getTargetLoweringInfo(), N);
3960 
3961   APInt AndMask;
3962   APInt XorMask;
3963   APInt BitWidthDiff;
3964 
3965   SDValue CtlzOp;
3966   SDValue Src;
3967 
3968   if (!sd_context_match(
3969           N, Matcher, m_Sub(m_Ctlz(m_Value(CtlzOp)), m_ConstInt(BitWidthDiff))))
3970     return SDValue();
3971 
3972   if (sd_context_match(CtlzOp, Matcher, m_ZExt(m_Not(m_Value(Src))))) {
3973     // DAG Legalisation Pattern:
3974     // (sub (ctlz (zero_extend (not Op)) BitWidthDiff))
3975     if ((BitWidth - Src.getValueType().getScalarSizeInBits()) != BitWidthDiff)
3976       return SDValue();
3977 
3978     Src = DAG.getNode(ISD::ANY_EXTEND, DL, VT, Src);
3979   } else if (sd_context_match(CtlzOp, Matcher,
3980                               m_And(m_Xor(m_Value(Src), m_ConstInt(XorMask)),
3981                                     m_ConstInt(AndMask)))) {
3982     // Type Legalisation Pattern:
3983     // (sub (ctlz (and (xor Op XorMask) AndMask)) BitWidthDiff)
3984     unsigned AndMaskWidth = BitWidth - BitWidthDiff.getZExtValue();
3985     if (!(AndMask.isMask(AndMaskWidth) && XorMask.countr_one() >= AndMaskWidth))
3986       return SDValue();
3987   } else
3988     return SDValue();
3989 
3990   SDValue ShiftConst = DAG.getShiftAmountConstant(BitWidthDiff, VT, DL);
3991   SDValue LShift = Matcher.getNode(ISD::SHL, DL, VT, Src, ShiftConst);
3992   SDValue Not =
3993       Matcher.getNode(ISD::XOR, DL, VT, LShift, DAG.getAllOnesConstant(DL, VT));
3994 
3995   return Matcher.getNode(ISD::CTLZ_ZERO_UNDEF, DL, VT, Not);
3996 }
3997 
3998 // Fold sub(x, mul(divrem(x,y)[0], y)) to divrem(x, y)[1]
foldRemainderIdiom(SDNode * N,SelectionDAG & DAG,const SDLoc & DL)3999 static SDValue foldRemainderIdiom(SDNode *N, SelectionDAG &DAG,
4000                                   const SDLoc &DL) {
4001   assert(N->getOpcode() == ISD::SUB && "Node must be a SUB");
4002   SDValue Sub0 = N->getOperand(0);
4003   SDValue Sub1 = N->getOperand(1);
4004 
4005   auto CheckAndFoldMulCase = [&](SDValue DivRem, SDValue MaybeY) -> SDValue {
4006     if ((DivRem.getOpcode() == ISD::SDIVREM ||
4007          DivRem.getOpcode() == ISD::UDIVREM) &&
4008         DivRem.getResNo() == 0 && DivRem.getOperand(0) == Sub0 &&
4009         DivRem.getOperand(1) == MaybeY) {
4010       return SDValue(DivRem.getNode(), 1);
4011     }
4012     return SDValue();
4013   };
4014 
4015   if (Sub1.getOpcode() == ISD::MUL) {
4016     // (sub x, (mul divrem(x,y)[0], y))
4017     SDValue Mul0 = Sub1.getOperand(0);
4018     SDValue Mul1 = Sub1.getOperand(1);
4019 
4020     if (SDValue Res = CheckAndFoldMulCase(Mul0, Mul1))
4021       return Res;
4022 
4023     if (SDValue Res = CheckAndFoldMulCase(Mul1, Mul0))
4024       return Res;
4025 
4026   } else if (Sub1.getOpcode() == ISD::SHL) {
4027     // Handle (sub x, (shl divrem(x,y)[0], C)) where y = 1 << C
4028     SDValue Shl0 = Sub1.getOperand(0);
4029     SDValue Shl1 = Sub1.getOperand(1);
4030     // Check if Shl0 is divrem(x, Y)[0]
4031     if ((Shl0.getOpcode() == ISD::SDIVREM ||
4032          Shl0.getOpcode() == ISD::UDIVREM) &&
4033         Shl0.getResNo() == 0 && Shl0.getOperand(0) == Sub0) {
4034 
4035       SDValue Divisor = Shl0.getOperand(1);
4036 
4037       ConstantSDNode *DivC = isConstOrConstSplat(Divisor);
4038       ConstantSDNode *ShC = isConstOrConstSplat(Shl1);
4039       if (!DivC || !ShC)
4040         return SDValue();
4041 
4042       if (DivC->getAPIntValue().isPowerOf2() &&
4043           DivC->getAPIntValue().logBase2() == ShC->getAPIntValue())
4044         return SDValue(Shl0.getNode(), 1);
4045     }
4046   }
4047   return SDValue();
4048 }
4049 
4050 // Since it may not be valid to emit a fold to zero for vector initializers
4051 // check if we can before folding.
tryFoldToZero(const SDLoc & DL,const TargetLowering & TLI,EVT VT,SelectionDAG & DAG,bool LegalOperations)4052 static SDValue tryFoldToZero(const SDLoc &DL, const TargetLowering &TLI, EVT VT,
4053                              SelectionDAG &DAG, bool LegalOperations) {
4054   if (!VT.isVector())
4055     return DAG.getConstant(0, DL, VT);
4056   if (!LegalOperations || TLI.isOperationLegal(ISD::BUILD_VECTOR, VT))
4057     return DAG.getConstant(0, DL, VT);
4058   return SDValue();
4059 }
4060 
visitSUB(SDNode * N)4061 SDValue DAGCombiner::visitSUB(SDNode *N) {
4062   SDValue N0 = N->getOperand(0);
4063   SDValue N1 = N->getOperand(1);
4064   EVT VT = N0.getValueType();
4065   unsigned BitWidth = VT.getScalarSizeInBits();
4066   SDLoc DL(N);
4067 
4068   auto PeekThroughFreeze = [](SDValue N) {
4069     if (N->getOpcode() == ISD::FREEZE && N.hasOneUse())
4070       return N->getOperand(0);
4071     return N;
4072   };
4073 
4074   if (SDValue V = foldSubCtlzNot<EmptyMatchContext>(N, DAG))
4075     return V;
4076 
4077   // fold (sub x, x) -> 0
4078   // FIXME: Refactor this and xor and other similar operations together.
4079   if (PeekThroughFreeze(N0) == PeekThroughFreeze(N1))
4080     return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
4081 
4082   // fold (sub c1, c2) -> c3
4083   if (SDValue C = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N0, N1}))
4084     return C;
4085 
4086   // fold vector ops
4087   if (VT.isVector()) {
4088     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4089       return FoldedVOp;
4090 
4091     // fold (sub x, 0) -> x, vector edition
4092     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
4093       return N0;
4094   }
4095 
4096   if (SDValue NewSel = foldBinOpIntoSelect(N))
4097     return NewSel;
4098 
4099   // fold (sub x, c) -> (add x, -c)
4100   if (ConstantSDNode *N1C = getAsNonOpaqueConstant(N1))
4101     return DAG.getNode(ISD::ADD, DL, VT, N0,
4102                        DAG.getConstant(-N1C->getAPIntValue(), DL, VT));
4103 
4104   if (isNullOrNullSplat(N0)) {
4105     // Right-shifting everything out but the sign bit followed by negation is
4106     // the same as flipping arithmetic/logical shift type without the negation:
4107     // -(X >>u 31) -> (X >>s 31)
4108     // -(X >>s 31) -> (X >>u 31)
4109     if (N1->getOpcode() == ISD::SRA || N1->getOpcode() == ISD::SRL) {
4110       ConstantSDNode *ShiftAmt = isConstOrConstSplat(N1.getOperand(1));
4111       if (ShiftAmt && ShiftAmt->getAPIntValue() == (BitWidth - 1)) {
4112         auto NewSh = N1->getOpcode() == ISD::SRA ? ISD::SRL : ISD::SRA;
4113         if (!LegalOperations || TLI.isOperationLegal(NewSh, VT))
4114           return DAG.getNode(NewSh, DL, VT, N1.getOperand(0), N1.getOperand(1));
4115       }
4116     }
4117 
4118     // 0 - X --> 0 if the sub is NUW.
4119     if (N->getFlags().hasNoUnsignedWrap())
4120       return N0;
4121 
4122     if (DAG.MaskedValueIsZero(N1, ~APInt::getSignMask(BitWidth))) {
4123       // N1 is either 0 or the minimum signed value. If the sub is NSW, then
4124       // N1 must be 0 because negating the minimum signed value is undefined.
4125       if (N->getFlags().hasNoSignedWrap())
4126         return N0;
4127 
4128       // 0 - X --> X if X is 0 or the minimum signed value.
4129       return N1;
4130     }
4131 
4132     // Convert 0 - abs(x).
4133     if (N1.getOpcode() == ISD::ABS && N1.hasOneUse() &&
4134         !TLI.isOperationLegalOrCustom(ISD::ABS, VT))
4135       if (SDValue Result = TLI.expandABS(N1.getNode(), DAG, true))
4136         return Result;
4137 
4138     // Similar to the previous rule, but this time targeting an expanded abs.
4139     // (sub 0, (max X, (sub 0, X))) --> (min X, (sub 0, X))
4140     // as well as
4141     // (sub 0, (min X, (sub 0, X))) --> (max X, (sub 0, X))
4142     // Note that these two are applicable to both signed and unsigned min/max.
4143     SDValue X;
4144     SDValue S0;
4145     auto NegPat = m_AllOf(m_Neg(m_Deferred(X)), m_Value(S0));
4146     if (sd_match(N1, m_OneUse(m_AnyOf(m_SMax(m_Value(X), NegPat),
4147                                       m_UMax(m_Value(X), NegPat),
4148                                       m_SMin(m_Value(X), NegPat),
4149                                       m_UMin(m_Value(X), NegPat))))) {
4150       unsigned NewOpc = ISD::getInverseMinMaxOpcode(N1->getOpcode());
4151       if (hasOperation(NewOpc, VT))
4152         return DAG.getNode(NewOpc, DL, VT, X, S0);
4153     }
4154 
4155     // Fold neg(splat(neg(x)) -> splat(x)
4156     if (VT.isVector()) {
4157       SDValue N1S = DAG.getSplatValue(N1, true);
4158       if (N1S && N1S.getOpcode() == ISD::SUB &&
4159           isNullConstant(N1S.getOperand(0)))
4160         return DAG.getSplat(VT, DL, N1S.getOperand(1));
4161     }
4162 
4163     // sub 0, (and x, 1)  -->  SIGN_EXTEND_INREG x, i1
4164     if (N1.getOpcode() == ISD::AND && N1.hasOneUse() &&
4165         isOneOrOneSplat(N1->getOperand(1))) {
4166       EVT ExtVT = EVT::getIntegerVT(*DAG.getContext(), 1);
4167       if (VT.isVector())
4168         ExtVT = EVT::getVectorVT(*DAG.getContext(), ExtVT,
4169                                  VT.getVectorElementCount());
4170       if (TLI.getOperationAction(ISD::SIGN_EXTEND_INREG, ExtVT) ==
4171           TargetLowering::Legal) {
4172         return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, N1->getOperand(0),
4173                            DAG.getValueType(ExtVT));
4174       }
4175     }
4176   }
4177 
4178   // Canonicalize (sub -1, x) -> ~x, i.e. (xor x, -1)
4179   if (isAllOnesOrAllOnesSplat(N0))
4180     return DAG.getNode(ISD::XOR, DL, VT, N1, N0);
4181 
4182   // fold (A - (0-B)) -> A+B
4183   if (N1.getOpcode() == ISD::SUB && isNullOrNullSplat(N1.getOperand(0)))
4184     return DAG.getNode(ISD::ADD, DL, VT, N0, N1.getOperand(1));
4185 
4186   // fold A-(A-B) -> B
4187   if (N1.getOpcode() == ISD::SUB && N0 == N1.getOperand(0))
4188     return N1.getOperand(1);
4189 
4190   // fold (A+B)-A -> B
4191   if (N0.getOpcode() == ISD::ADD && N0.getOperand(0) == N1)
4192     return N0.getOperand(1);
4193 
4194   // fold (A+B)-B -> A
4195   if (N0.getOpcode() == ISD::ADD && N0.getOperand(1) == N1)
4196     return N0.getOperand(0);
4197 
4198   // fold (A+C1)-C2 -> A+(C1-C2)
4199   if (N0.getOpcode() == ISD::ADD) {
4200     SDValue N01 = N0.getOperand(1);
4201     if (SDValue NewC = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N01, N1}))
4202       return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), NewC);
4203   }
4204 
4205   // fold C2-(A+C1) -> (C2-C1)-A
4206   if (N1.getOpcode() == ISD::ADD) {
4207     SDValue N11 = N1.getOperand(1);
4208     if (SDValue NewC = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N0, N11}))
4209       return DAG.getNode(ISD::SUB, DL, VT, NewC, N1.getOperand(0));
4210   }
4211 
4212   // fold (A-C1)-C2 -> A-(C1+C2)
4213   if (N0.getOpcode() == ISD::SUB) {
4214     SDValue N01 = N0.getOperand(1);
4215     if (SDValue NewC = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N01, N1}))
4216       return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), NewC);
4217   }
4218 
4219   // fold (c1-A)-c2 -> (c1-c2)-A
4220   if (N0.getOpcode() == ISD::SUB) {
4221     SDValue N00 = N0.getOperand(0);
4222     if (SDValue NewC = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N00, N1}))
4223       return DAG.getNode(ISD::SUB, DL, VT, NewC, N0.getOperand(1));
4224   }
4225 
4226   SDValue A, B, C;
4227 
4228   // fold ((A+(B+C))-B) -> A+C
4229   if (sd_match(N0, m_Add(m_Value(A), m_Add(m_Specific(N1), m_Value(C)))))
4230     return DAG.getNode(ISD::ADD, DL, VT, A, C);
4231 
4232   // fold ((A+(B-C))-B) -> A-C
4233   if (sd_match(N0, m_Add(m_Value(A), m_Sub(m_Specific(N1), m_Value(C)))))
4234     return DAG.getNode(ISD::SUB, DL, VT, A, C);
4235 
4236   // fold ((A-(B-C))-C) -> A-B
4237   if (sd_match(N0, m_Sub(m_Value(A), m_Sub(m_Value(B), m_Specific(N1)))))
4238     return DAG.getNode(ISD::SUB, DL, VT, A, B);
4239 
4240   // fold (A-(B-C)) -> A+(C-B)
4241   if (sd_match(N1, m_OneUse(m_Sub(m_Value(B), m_Value(C)))))
4242     return DAG.getNode(ISD::ADD, DL, VT, N0,
4243                        DAG.getNode(ISD::SUB, DL, VT, C, B));
4244 
4245   // A - (A & B)  ->  A & (~B)
4246   if (sd_match(N1, m_And(m_Specific(N0), m_Value(B))) &&
4247       (N1.hasOneUse() || isConstantOrConstantVector(B, /*NoOpaques=*/true)))
4248     return DAG.getNode(ISD::AND, DL, VT, N0, DAG.getNOT(DL, B, VT));
4249 
4250   // fold (A - (-B * C)) -> (A + (B * C))
4251   if (sd_match(N1, m_OneUse(m_Mul(m_Neg(m_Value(B)), m_Value(C)))))
4252     return DAG.getNode(ISD::ADD, DL, VT, N0,
4253                        DAG.getNode(ISD::MUL, DL, VT, B, C));
4254 
4255   // If either operand of a sub is undef, the result is undef
4256   if (N0.isUndef())
4257     return N0;
4258   if (N1.isUndef())
4259     return N1;
4260 
4261   if (SDValue V = foldAddSubBoolOfMaskedVal(N, DL, DAG))
4262     return V;
4263 
4264   if (SDValue V = foldAddSubOfSignBit(N, DL, DAG))
4265     return V;
4266 
4267   // Try to match AVGCEIL fixedwidth pattern
4268   if (SDValue V = foldSubToAvg(N, DL))
4269     return V;
4270 
4271   if (SDValue V = foldAddSubMasked1(false, N0, N1, DAG, DL))
4272     return V;
4273 
4274   if (SDValue V = foldSubToUSubSat(VT, N, DL))
4275     return V;
4276 
4277   if (SDValue V = foldRemainderIdiom(N, DAG, DL))
4278     return V;
4279 
4280   // (A - B) - 1  ->  add (xor B, -1), A
4281   if (sd_match(N, m_Sub(m_OneUse(m_Sub(m_Value(A), m_Value(B))),
4282                         m_One(/*AllowUndefs=*/true))))
4283     return DAG.getNode(ISD::ADD, DL, VT, A, DAG.getNOT(DL, B, VT));
4284 
4285   // Look for:
4286   //   sub y, (xor x, -1)
4287   // And if the target does not like this form then turn into:
4288   //   add (add x, y), 1
4289   if (TLI.preferIncOfAddToSubOfNot(VT) && N1.hasOneUse() && isBitwiseNot(N1)) {
4290     SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0, N1.getOperand(0));
4291     return DAG.getNode(ISD::ADD, DL, VT, Add, DAG.getConstant(1, DL, VT));
4292   }
4293 
4294   // Hoist one-use addition by non-opaque constant:
4295   //   (x + C) - y  ->  (x - y) + C
4296   if (!reassociationCanBreakAddressingModePattern(ISD::SUB, DL, N, N0, N1) &&
4297       N0.getOpcode() == ISD::ADD && N0.hasOneUse() &&
4298       isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true)) {
4299     SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), N1);
4300     return DAG.getNode(ISD::ADD, DL, VT, Sub, N0.getOperand(1));
4301   }
4302   // y - (x + C)  ->  (y - x) - C
4303   if (N1.getOpcode() == ISD::ADD && N1.hasOneUse() &&
4304       isConstantOrConstantVector(N1.getOperand(1), /*NoOpaques=*/true)) {
4305     SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, N1.getOperand(0));
4306     return DAG.getNode(ISD::SUB, DL, VT, Sub, N1.getOperand(1));
4307   }
4308   // (x - C) - y  ->  (x - y) - C
4309   // This is necessary because SUB(X,C) -> ADD(X,-C) doesn't work for vectors.
4310   if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
4311       isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true)) {
4312     SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), N1);
4313     return DAG.getNode(ISD::SUB, DL, VT, Sub, N0.getOperand(1));
4314   }
4315   // (C - x) - y  ->  C - (x + y)
4316   if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
4317       isConstantOrConstantVector(N0.getOperand(0), /*NoOpaques=*/true)) {
4318     SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(1), N1);
4319     return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), Add);
4320   }
4321 
4322   // If the target's bool is represented as 0/-1, prefer to make this 'add 0/-1'
4323   // rather than 'sub 0/1' (the sext should get folded).
4324   // sub X, (zext i1 Y) --> add X, (sext i1 Y)
4325   if (N1.getOpcode() == ISD::ZERO_EXTEND &&
4326       N1.getOperand(0).getScalarValueSizeInBits() == 1 &&
4327       TLI.getBooleanContents(VT) ==
4328           TargetLowering::ZeroOrNegativeOneBooleanContent) {
4329     SDValue SExt = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, N1.getOperand(0));
4330     return DAG.getNode(ISD::ADD, DL, VT, N0, SExt);
4331   }
4332 
4333   // fold B = sra (A, size(A)-1); sub (xor (A, B), B) -> (abs A)
4334   if ((!LegalOperations || hasOperation(ISD::ABS, VT)) &&
4335       sd_match(N1, m_Sra(m_Value(A), m_SpecificInt(BitWidth - 1))) &&
4336       sd_match(N0, m_Xor(m_Specific(A), m_Specific(N1))))
4337     return DAG.getNode(ISD::ABS, DL, VT, A);
4338 
4339   // If the relocation model supports it, consider symbol offsets.
4340   if (GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(N0))
4341     if (!LegalOperations && TLI.isOffsetFoldingLegal(GA)) {
4342       // fold (sub Sym+c1, Sym+c2) -> c1-c2
4343       if (GlobalAddressSDNode *GB = dyn_cast<GlobalAddressSDNode>(N1))
4344         if (GA->getGlobal() == GB->getGlobal())
4345           return DAG.getConstant((uint64_t)GA->getOffset() - GB->getOffset(),
4346                                  DL, VT);
4347     }
4348 
4349   // sub X, (sextinreg Y i1) -> add X, (and Y 1)
4350   if (N1.getOpcode() == ISD::SIGN_EXTEND_INREG) {
4351     VTSDNode *TN = cast<VTSDNode>(N1.getOperand(1));
4352     if (TN->getVT() == MVT::i1) {
4353       SDValue ZExt = DAG.getNode(ISD::AND, DL, VT, N1.getOperand(0),
4354                                  DAG.getConstant(1, DL, VT));
4355       return DAG.getNode(ISD::ADD, DL, VT, N0, ZExt);
4356     }
4357   }
4358 
4359   // canonicalize (sub X, (vscale * C)) to (add X, (vscale * -C))
4360   if (N1.getOpcode() == ISD::VSCALE && N1.hasOneUse()) {
4361     const APInt &IntVal = N1.getConstantOperandAPInt(0);
4362     return DAG.getNode(ISD::ADD, DL, VT, N0, DAG.getVScale(DL, VT, -IntVal));
4363   }
4364 
4365   // canonicalize (sub X, step_vector(C)) to (add X, step_vector(-C))
4366   if (N1.getOpcode() == ISD::STEP_VECTOR && N1.hasOneUse()) {
4367     APInt NewStep = -N1.getConstantOperandAPInt(0);
4368     return DAG.getNode(ISD::ADD, DL, VT, N0,
4369                        DAG.getStepVector(DL, VT, NewStep));
4370   }
4371 
4372   // Prefer an add for more folding potential and possibly better codegen:
4373   // sub N0, (lshr N10, width-1) --> add N0, (ashr N10, width-1)
4374   if (!LegalOperations && N1.getOpcode() == ISD::SRL && N1.hasOneUse()) {
4375     SDValue ShAmt = N1.getOperand(1);
4376     ConstantSDNode *ShAmtC = isConstOrConstSplat(ShAmt);
4377     if (ShAmtC && ShAmtC->getAPIntValue() == (BitWidth - 1)) {
4378       SDValue SRA = DAG.getNode(ISD::SRA, DL, VT, N1.getOperand(0), ShAmt);
4379       return DAG.getNode(ISD::ADD, DL, VT, N0, SRA);
4380     }
4381   }
4382 
4383   // As with the previous fold, prefer add for more folding potential.
4384   // Subtracting SMIN/0 is the same as adding SMIN/0:
4385   // N0 - (X << BW-1) --> N0 + (X << BW-1)
4386   if (N1.getOpcode() == ISD::SHL) {
4387     ConstantSDNode *ShlC = isConstOrConstSplat(N1.getOperand(1));
4388     if (ShlC && ShlC->getAPIntValue() == (BitWidth - 1))
4389       return DAG.getNode(ISD::ADD, DL, VT, N1, N0);
4390   }
4391 
4392   // (sub (usubo_carry X, 0, Carry), Y) -> (usubo_carry X, Y, Carry)
4393   if (N0.getOpcode() == ISD::USUBO_CARRY && isNullConstant(N0.getOperand(1)) &&
4394       N0.getResNo() == 0 && N0.hasOneUse())
4395     return DAG.getNode(ISD::USUBO_CARRY, DL, N0->getVTList(),
4396                        N0.getOperand(0), N1, N0.getOperand(2));
4397 
4398   if (TLI.isOperationLegalOrCustom(ISD::UADDO_CARRY, VT)) {
4399     // (sub Carry, X)  ->  (uaddo_carry (sub 0, X), 0, Carry)
4400     if (SDValue Carry = getAsCarry(TLI, N0)) {
4401       SDValue X = N1;
4402       SDValue Zero = DAG.getConstant(0, DL, VT);
4403       SDValue NegX = DAG.getNode(ISD::SUB, DL, VT, Zero, X);
4404       return DAG.getNode(ISD::UADDO_CARRY, DL,
4405                          DAG.getVTList(VT, Carry.getValueType()), NegX, Zero,
4406                          Carry);
4407     }
4408   }
4409 
4410   // If there's no chance of borrowing from adjacent bits, then sub is xor:
4411   // sub C0, X --> xor X, C0
4412   if (ConstantSDNode *C0 = isConstOrConstSplat(N0)) {
4413     if (!C0->isOpaque()) {
4414       const APInt &C0Val = C0->getAPIntValue();
4415       const APInt &MaybeOnes = ~DAG.computeKnownBits(N1).Zero;
4416       if ((C0Val - MaybeOnes) == (C0Val ^ MaybeOnes))
4417         return DAG.getNode(ISD::XOR, DL, VT, N1, N0);
4418     }
4419   }
4420 
4421   // smax(a,b) - smin(a,b) --> abds(a,b)
4422   if ((!LegalOperations || hasOperation(ISD::ABDS, VT)) &&
4423       sd_match(N0, m_SMaxLike(m_Value(A), m_Value(B))) &&
4424       sd_match(N1, m_SMinLike(m_Specific(A), m_Specific(B))))
4425     return DAG.getNode(ISD::ABDS, DL, VT, A, B);
4426 
4427   // smin(a,b) - smax(a,b) --> neg(abds(a,b))
4428   if (hasOperation(ISD::ABDS, VT) &&
4429       sd_match(N0, m_SMinLike(m_Value(A), m_Value(B))) &&
4430       sd_match(N1, m_SMaxLike(m_Specific(A), m_Specific(B))))
4431     return DAG.getNegative(DAG.getNode(ISD::ABDS, DL, VT, A, B), DL, VT);
4432 
4433   // umax(a,b) - umin(a,b) --> abdu(a,b)
4434   if ((!LegalOperations || hasOperation(ISD::ABDU, VT)) &&
4435       sd_match(N0, m_UMaxLike(m_Value(A), m_Value(B))) &&
4436       sd_match(N1, m_UMinLike(m_Specific(A), m_Specific(B))))
4437     return DAG.getNode(ISD::ABDU, DL, VT, A, B);
4438 
4439   // umin(a,b) - umax(a,b) --> neg(abdu(a,b))
4440   if (hasOperation(ISD::ABDU, VT) &&
4441       sd_match(N0, m_UMinLike(m_Value(A), m_Value(B))) &&
4442       sd_match(N1, m_UMaxLike(m_Specific(A), m_Specific(B))))
4443     return DAG.getNegative(DAG.getNode(ISD::ABDU, DL, VT, A, B), DL, VT);
4444 
4445   // (sub x, (select (ult x, y), 0, y)) -> (umin x, (sub x, y))
4446   // (sub x, (select (uge x, y), y, 0)) -> (umin x, (sub x, y))
4447   if (hasUMin(VT)) {
4448     SDValue Y;
4449     if (sd_match(N1, m_OneUse(m_Select(m_SetCC(m_Specific(N0), m_Value(Y),
4450                                                m_SpecificCondCode(ISD::SETULT)),
4451                                        m_Zero(), m_Deferred(Y)))) ||
4452         sd_match(N1, m_OneUse(m_Select(m_SetCC(m_Specific(N0), m_Value(Y),
4453                                                m_SpecificCondCode(ISD::SETUGE)),
4454                                        m_Deferred(Y), m_Zero()))))
4455       return DAG.getNode(ISD::UMIN, DL, VT, N0,
4456                          DAG.getNode(ISD::SUB, DL, VT, N0, Y));
4457   }
4458 
4459   return SDValue();
4460 }
4461 
visitSUBSAT(SDNode * N)4462 SDValue DAGCombiner::visitSUBSAT(SDNode *N) {
4463   unsigned Opcode = N->getOpcode();
4464   SDValue N0 = N->getOperand(0);
4465   SDValue N1 = N->getOperand(1);
4466   EVT VT = N0.getValueType();
4467   bool IsSigned = Opcode == ISD::SSUBSAT;
4468   SDLoc DL(N);
4469 
4470   // fold (sub_sat x, undef) -> 0
4471   if (N0.isUndef() || N1.isUndef())
4472     return DAG.getConstant(0, DL, VT);
4473 
4474   // fold (sub_sat x, x) -> 0
4475   if (N0 == N1)
4476     return DAG.getConstant(0, DL, VT);
4477 
4478   // fold (sub_sat c1, c2) -> c3
4479   if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
4480     return C;
4481 
4482   // fold vector ops
4483   if (VT.isVector()) {
4484     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4485       return FoldedVOp;
4486 
4487     // fold (sub_sat x, 0) -> x, vector edition
4488     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
4489       return N0;
4490   }
4491 
4492   // fold (sub_sat x, 0) -> x
4493   if (isNullConstant(N1))
4494     return N0;
4495 
4496   // If it cannot overflow, transform into an sub.
4497   if (DAG.willNotOverflowSub(IsSigned, N0, N1))
4498     return DAG.getNode(ISD::SUB, DL, VT, N0, N1);
4499 
4500   return SDValue();
4501 }
4502 
visitSUBC(SDNode * N)4503 SDValue DAGCombiner::visitSUBC(SDNode *N) {
4504   SDValue N0 = N->getOperand(0);
4505   SDValue N1 = N->getOperand(1);
4506   EVT VT = N0.getValueType();
4507   SDLoc DL(N);
4508 
4509   // If the flag result is dead, turn this into an SUB.
4510   if (!N->hasAnyUseOfValue(1))
4511     return CombineTo(N, DAG.getNode(ISD::SUB, DL, VT, N0, N1),
4512                      DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
4513 
4514   // fold (subc x, x) -> 0 + no borrow
4515   if (N0 == N1)
4516     return CombineTo(N, DAG.getConstant(0, DL, VT),
4517                      DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
4518 
4519   // fold (subc x, 0) -> x + no borrow
4520   if (isNullConstant(N1))
4521     return CombineTo(N, N0, DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
4522 
4523   // Canonicalize (sub -1, x) -> ~x, i.e. (xor x, -1) + no borrow
4524   if (isAllOnesConstant(N0))
4525     return CombineTo(N, DAG.getNode(ISD::XOR, DL, VT, N1, N0),
4526                      DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
4527 
4528   return SDValue();
4529 }
4530 
visitSUBO(SDNode * N)4531 SDValue DAGCombiner::visitSUBO(SDNode *N) {
4532   SDValue N0 = N->getOperand(0);
4533   SDValue N1 = N->getOperand(1);
4534   EVT VT = N0.getValueType();
4535   bool IsSigned = (ISD::SSUBO == N->getOpcode());
4536 
4537   EVT CarryVT = N->getValueType(1);
4538   SDLoc DL(N);
4539 
4540   // If the flag result is dead, turn this into an SUB.
4541   if (!N->hasAnyUseOfValue(1))
4542     return CombineTo(N, DAG.getNode(ISD::SUB, DL, VT, N0, N1),
4543                      DAG.getUNDEF(CarryVT));
4544 
4545   // fold (subo x, x) -> 0 + no borrow
4546   if (N0 == N1)
4547     return CombineTo(N, DAG.getConstant(0, DL, VT),
4548                      DAG.getConstant(0, DL, CarryVT));
4549 
4550   // fold (subox, c) -> (addo x, -c)
4551   if (ConstantSDNode *N1C = getAsNonOpaqueConstant(N1))
4552     if (IsSigned && !N1C->isMinSignedValue())
4553       return DAG.getNode(ISD::SADDO, DL, N->getVTList(), N0,
4554                          DAG.getConstant(-N1C->getAPIntValue(), DL, VT));
4555 
4556   // fold (subo x, 0) -> x + no borrow
4557   if (isNullOrNullSplat(N1))
4558     return CombineTo(N, N0, DAG.getConstant(0, DL, CarryVT));
4559 
4560   // If it cannot overflow, transform into an sub.
4561   if (DAG.willNotOverflowSub(IsSigned, N0, N1))
4562     return CombineTo(N, DAG.getNode(ISD::SUB, DL, VT, N0, N1),
4563                      DAG.getConstant(0, DL, CarryVT));
4564 
4565   // Canonicalize (usubo -1, x) -> ~x, i.e. (xor x, -1) + no borrow
4566   if (!IsSigned && isAllOnesOrAllOnesSplat(N0))
4567     return CombineTo(N, DAG.getNode(ISD::XOR, DL, VT, N1, N0),
4568                      DAG.getConstant(0, DL, CarryVT));
4569 
4570   return SDValue();
4571 }
4572 
visitSUBE(SDNode * N)4573 SDValue DAGCombiner::visitSUBE(SDNode *N) {
4574   SDValue N0 = N->getOperand(0);
4575   SDValue N1 = N->getOperand(1);
4576   SDValue CarryIn = N->getOperand(2);
4577 
4578   // fold (sube x, y, false) -> (subc x, y)
4579   if (CarryIn.getOpcode() == ISD::CARRY_FALSE)
4580     return DAG.getNode(ISD::SUBC, SDLoc(N), N->getVTList(), N0, N1);
4581 
4582   return SDValue();
4583 }
4584 
visitUSUBO_CARRY(SDNode * N)4585 SDValue DAGCombiner::visitUSUBO_CARRY(SDNode *N) {
4586   SDValue N0 = N->getOperand(0);
4587   SDValue N1 = N->getOperand(1);
4588   SDValue CarryIn = N->getOperand(2);
4589 
4590   // fold (usubo_carry x, y, false) -> (usubo x, y)
4591   if (isNullConstant(CarryIn)) {
4592     if (!LegalOperations ||
4593         TLI.isOperationLegalOrCustom(ISD::USUBO, N->getValueType(0)))
4594       return DAG.getNode(ISD::USUBO, SDLoc(N), N->getVTList(), N0, N1);
4595   }
4596 
4597   return SDValue();
4598 }
4599 
visitSSUBO_CARRY(SDNode * N)4600 SDValue DAGCombiner::visitSSUBO_CARRY(SDNode *N) {
4601   SDValue N0 = N->getOperand(0);
4602   SDValue N1 = N->getOperand(1);
4603   SDValue CarryIn = N->getOperand(2);
4604 
4605   // fold (ssubo_carry x, y, false) -> (ssubo x, y)
4606   if (isNullConstant(CarryIn)) {
4607     if (!LegalOperations ||
4608         TLI.isOperationLegalOrCustom(ISD::SSUBO, N->getValueType(0)))
4609       return DAG.getNode(ISD::SSUBO, SDLoc(N), N->getVTList(), N0, N1);
4610   }
4611 
4612   return SDValue();
4613 }
4614 
4615 // Notice that "mulfix" can be any of SMULFIX, SMULFIXSAT, UMULFIX and
4616 // UMULFIXSAT here.
visitMULFIX(SDNode * N)4617 SDValue DAGCombiner::visitMULFIX(SDNode *N) {
4618   SDValue N0 = N->getOperand(0);
4619   SDValue N1 = N->getOperand(1);
4620   SDValue Scale = N->getOperand(2);
4621   EVT VT = N0.getValueType();
4622 
4623   // fold (mulfix x, undef, scale) -> 0
4624   if (N0.isUndef() || N1.isUndef())
4625     return DAG.getConstant(0, SDLoc(N), VT);
4626 
4627   // Canonicalize constant to RHS (vector doesn't have to splat)
4628   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
4629      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
4630     return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N1, N0, Scale);
4631 
4632   // fold (mulfix x, 0, scale) -> 0
4633   if (isNullConstant(N1))
4634     return DAG.getConstant(0, SDLoc(N), VT);
4635 
4636   return SDValue();
4637 }
4638 
visitMUL(SDNode * N)4639 template <class MatchContextClass> SDValue DAGCombiner::visitMUL(SDNode *N) {
4640   SDValue N0 = N->getOperand(0);
4641   SDValue N1 = N->getOperand(1);
4642   EVT VT = N0.getValueType();
4643   unsigned BitWidth = VT.getScalarSizeInBits();
4644   SDLoc DL(N);
4645   bool UseVP = std::is_same_v<MatchContextClass, VPMatchContext>;
4646   MatchContextClass Matcher(DAG, TLI, N);
4647 
4648   // fold (mul x, undef) -> 0
4649   if (N0.isUndef() || N1.isUndef())
4650     return DAG.getConstant(0, DL, VT);
4651 
4652   // fold (mul c1, c2) -> c1*c2
4653   if (SDValue C = DAG.FoldConstantArithmetic(ISD::MUL, DL, VT, {N0, N1}))
4654     return C;
4655 
4656   // canonicalize constant to RHS (vector doesn't have to splat)
4657   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
4658       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
4659     return Matcher.getNode(ISD::MUL, DL, VT, N1, N0);
4660 
4661   bool N1IsConst = false;
4662   bool N1IsOpaqueConst = false;
4663   APInt ConstValue1;
4664 
4665   // fold vector ops
4666   if (VT.isVector()) {
4667     // TODO: Change this to use SimplifyVBinOp when it supports VP op.
4668     if (!UseVP)
4669       if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4670         return FoldedVOp;
4671 
4672     N1IsConst = ISD::isConstantSplatVector(N1.getNode(), ConstValue1);
4673     assert((!N1IsConst || ConstValue1.getBitWidth() == BitWidth) &&
4674            "Splat APInt should be element width");
4675   } else {
4676     N1IsConst = isa<ConstantSDNode>(N1);
4677     if (N1IsConst) {
4678       ConstValue1 = N1->getAsAPIntVal();
4679       N1IsOpaqueConst = cast<ConstantSDNode>(N1)->isOpaque();
4680     }
4681   }
4682 
4683   // fold (mul x, 0) -> 0
4684   if (N1IsConst && ConstValue1.isZero())
4685     return N1;
4686 
4687   // fold (mul x, 1) -> x
4688   if (N1IsConst && ConstValue1.isOne())
4689     return N0;
4690 
4691   if (!UseVP)
4692     if (SDValue NewSel = foldBinOpIntoSelect(N))
4693       return NewSel;
4694 
4695   // fold (mul x, -1) -> 0-x
4696   if (N1IsConst && ConstValue1.isAllOnes())
4697     return Matcher.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), N0);
4698 
4699   // fold (mul x, (1 << c)) -> x << c
4700   if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
4701       (!VT.isVector() || Level <= AfterLegalizeVectorOps)) {
4702     if (SDValue LogBase2 = BuildLogBase2(N1, DL)) {
4703       EVT ShiftVT = getShiftAmountTy(N0.getValueType());
4704       SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
4705       return Matcher.getNode(ISD::SHL, DL, VT, N0, Trunc);
4706     }
4707   }
4708 
4709   // fold (mul x, -(1 << c)) -> -(x << c) or (-x) << c
4710   if (N1IsConst && !N1IsOpaqueConst && ConstValue1.isNegatedPowerOf2()) {
4711     unsigned Log2Val = (-ConstValue1).logBase2();
4712 
4713     // FIXME: If the input is something that is easily negated (e.g. a
4714     // single-use add), we should put the negate there.
4715     return Matcher.getNode(
4716         ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
4717         Matcher.getNode(ISD::SHL, DL, VT, N0,
4718                         DAG.getShiftAmountConstant(Log2Val, VT, DL)));
4719   }
4720 
4721   // Attempt to reuse an existing umul_lohi/smul_lohi node, but only if the
4722   // hi result is in use in case we hit this mid-legalization.
4723   if (!UseVP) {
4724     for (unsigned LoHiOpc : {ISD::UMUL_LOHI, ISD::SMUL_LOHI}) {
4725       if (!LegalOperations || TLI.isOperationLegalOrCustom(LoHiOpc, VT)) {
4726         SDVTList LoHiVT = DAG.getVTList(VT, VT);
4727         // TODO: Can we match commutable operands with getNodeIfExists?
4728         if (SDNode *LoHi = DAG.getNodeIfExists(LoHiOpc, LoHiVT, {N0, N1}))
4729           if (LoHi->hasAnyUseOfValue(1))
4730             return SDValue(LoHi, 0);
4731         if (SDNode *LoHi = DAG.getNodeIfExists(LoHiOpc, LoHiVT, {N1, N0}))
4732           if (LoHi->hasAnyUseOfValue(1))
4733             return SDValue(LoHi, 0);
4734       }
4735     }
4736   }
4737 
4738   // Try to transform:
4739   // (1) multiply-by-(power-of-2 +/- 1) into shift and add/sub.
4740   // mul x, (2^N + 1) --> add (shl x, N), x
4741   // mul x, (2^N - 1) --> sub (shl x, N), x
4742   // Examples: x * 33 --> (x << 5) + x
4743   //           x * 15 --> (x << 4) - x
4744   //           x * -33 --> -((x << 5) + x)
4745   //           x * -15 --> -((x << 4) - x) ; this reduces --> x - (x << 4)
4746   // (2) multiply-by-(power-of-2 +/- power-of-2) into shifts and add/sub.
4747   // mul x, (2^N + 2^M) --> (add (shl x, N), (shl x, M))
4748   // mul x, (2^N - 2^M) --> (sub (shl x, N), (shl x, M))
4749   // Examples: x * 0x8800 --> (x << 15) + (x << 11)
4750   //           x * 0xf800 --> (x << 16) - (x << 11)
4751   //           x * -0x8800 --> -((x << 15) + (x << 11))
4752   //           x * -0xf800 --> -((x << 16) - (x << 11)) ; (x << 11) - (x << 16)
4753   if (!UseVP && N1IsConst &&
4754       TLI.decomposeMulByConstant(*DAG.getContext(), VT, N1)) {
4755     // TODO: We could handle more general decomposition of any constant by
4756     //       having the target set a limit on number of ops and making a
4757     //       callback to determine that sequence (similar to sqrt expansion).
4758     unsigned MathOp = ISD::DELETED_NODE;
4759     APInt MulC = ConstValue1.abs();
4760     // The constant `2` should be treated as (2^0 + 1).
4761     unsigned TZeros = MulC == 2 ? 0 : MulC.countr_zero();
4762     MulC.lshrInPlace(TZeros);
4763     if ((MulC - 1).isPowerOf2())
4764       MathOp = ISD::ADD;
4765     else if ((MulC + 1).isPowerOf2())
4766       MathOp = ISD::SUB;
4767 
4768     if (MathOp != ISD::DELETED_NODE) {
4769       unsigned ShAmt =
4770           MathOp == ISD::ADD ? (MulC - 1).logBase2() : (MulC + 1).logBase2();
4771       ShAmt += TZeros;
4772       assert(ShAmt < BitWidth &&
4773              "multiply-by-constant generated out of bounds shift");
4774       SDValue Shl =
4775           DAG.getNode(ISD::SHL, DL, VT, N0, DAG.getConstant(ShAmt, DL, VT));
4776       SDValue R =
4777           TZeros ? DAG.getNode(MathOp, DL, VT, Shl,
4778                                DAG.getNode(ISD::SHL, DL, VT, N0,
4779                                            DAG.getConstant(TZeros, DL, VT)))
4780                  : DAG.getNode(MathOp, DL, VT, Shl, N0);
4781       if (ConstValue1.isNegative())
4782         R = DAG.getNegative(R, DL, VT);
4783       return R;
4784     }
4785   }
4786 
4787   // (mul (shl X, c1), c2) -> (mul X, c2 << c1)
4788   if (sd_context_match(N0, Matcher, m_Opc(ISD::SHL))) {
4789     SDValue N01 = N0.getOperand(1);
4790     if (SDValue C3 = DAG.FoldConstantArithmetic(ISD::SHL, DL, VT, {N1, N01}))
4791       return DAG.getNode(ISD::MUL, DL, VT, N0.getOperand(0), C3);
4792   }
4793 
4794   // Change (mul (shl X, C), Y) -> (shl (mul X, Y), C) when the shift has one
4795   // use.
4796   {
4797     SDValue Sh, Y;
4798 
4799     // Check for both (mul (shl X, C), Y)  and  (mul Y, (shl X, C)).
4800     if (sd_context_match(N0, Matcher, m_OneUse(m_Opc(ISD::SHL))) &&
4801         isConstantOrConstantVector(N0.getOperand(1))) {
4802       Sh = N0; Y = N1;
4803     } else if (sd_context_match(N1, Matcher, m_OneUse(m_Opc(ISD::SHL))) &&
4804                isConstantOrConstantVector(N1.getOperand(1))) {
4805       Sh = N1; Y = N0;
4806     }
4807 
4808     if (Sh.getNode()) {
4809       SDValue Mul = Matcher.getNode(ISD::MUL, DL, VT, Sh.getOperand(0), Y);
4810       return Matcher.getNode(ISD::SHL, DL, VT, Mul, Sh.getOperand(1));
4811     }
4812   }
4813 
4814   // fold (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2)
4815   if (sd_context_match(N0, Matcher, m_Opc(ISD::ADD)) &&
4816       DAG.isConstantIntBuildVectorOrConstantInt(N1) &&
4817       DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1)) &&
4818       isMulAddWithConstProfitable(N, N0, N1))
4819     return Matcher.getNode(
4820         ISD::ADD, DL, VT,
4821         Matcher.getNode(ISD::MUL, SDLoc(N0), VT, N0.getOperand(0), N1),
4822         Matcher.getNode(ISD::MUL, SDLoc(N1), VT, N0.getOperand(1), N1));
4823 
4824   // Fold (mul (vscale * C0), C1) to (vscale * (C0 * C1)).
4825   ConstantSDNode *NC1 = isConstOrConstSplat(N1);
4826   if (!UseVP && N0.getOpcode() == ISD::VSCALE && NC1) {
4827     const APInt &C0 = N0.getConstantOperandAPInt(0);
4828     const APInt &C1 = NC1->getAPIntValue();
4829     return DAG.getVScale(DL, VT, C0 * C1);
4830   }
4831 
4832   // Fold (mul step_vector(C0), C1) to (step_vector(C0 * C1)).
4833   APInt MulVal;
4834   if (!UseVP && N0.getOpcode() == ISD::STEP_VECTOR &&
4835       ISD::isConstantSplatVector(N1.getNode(), MulVal)) {
4836     const APInt &C0 = N0.getConstantOperandAPInt(0);
4837     APInt NewStep = C0 * MulVal;
4838     return DAG.getStepVector(DL, VT, NewStep);
4839   }
4840 
4841   // Fold Y = sra (X, size(X)-1); mul (or (Y, 1), X) -> (abs X)
4842   SDValue X;
4843   if (!UseVP && (!LegalOperations || hasOperation(ISD::ABS, VT)) &&
4844       sd_context_match(
4845           N, Matcher,
4846           m_Mul(m_Or(m_Sra(m_Value(X), m_SpecificInt(BitWidth - 1)), m_One()),
4847                 m_Deferred(X)))) {
4848     return Matcher.getNode(ISD::ABS, DL, VT, X);
4849   }
4850 
4851   // Fold ((mul x, 0/undef) -> 0,
4852   //       (mul x, 1) -> x) -> x)
4853   // -> and(x, mask)
4854   // We can replace vectors with '0' and '1' factors with a clearing mask.
4855   if (VT.isFixedLengthVector()) {
4856     unsigned NumElts = VT.getVectorNumElements();
4857     SmallBitVector ClearMask;
4858     ClearMask.reserve(NumElts);
4859     auto IsClearMask = [&ClearMask](ConstantSDNode *V) {
4860       if (!V || V->isZero()) {
4861         ClearMask.push_back(true);
4862         return true;
4863       }
4864       ClearMask.push_back(false);
4865       return V->isOne();
4866     };
4867     if ((!LegalOperations || TLI.isOperationLegalOrCustom(ISD::AND, VT)) &&
4868         ISD::matchUnaryPredicate(N1, IsClearMask, /*AllowUndefs*/ true)) {
4869       assert(N1.getOpcode() == ISD::BUILD_VECTOR && "Unknown constant vector");
4870       EVT LegalSVT = N1.getOperand(0).getValueType();
4871       SDValue Zero = DAG.getConstant(0, DL, LegalSVT);
4872       SDValue AllOnes = DAG.getAllOnesConstant(DL, LegalSVT);
4873       SmallVector<SDValue, 16> Mask(NumElts, AllOnes);
4874       for (unsigned I = 0; I != NumElts; ++I)
4875         if (ClearMask[I])
4876           Mask[I] = Zero;
4877       return DAG.getNode(ISD::AND, DL, VT, N0, DAG.getBuildVector(VT, DL, Mask));
4878     }
4879   }
4880 
4881   // reassociate mul
4882   // TODO: Change reassociateOps to support vp ops.
4883   if (!UseVP)
4884     if (SDValue RMUL = reassociateOps(ISD::MUL, DL, N0, N1, N->getFlags()))
4885       return RMUL;
4886 
4887   // Fold mul(vecreduce(x), vecreduce(y)) -> vecreduce(mul(x, y))
4888   // TODO: Change reassociateReduction to support vp ops.
4889   if (!UseVP)
4890     if (SDValue SD =
4891             reassociateReduction(ISD::VECREDUCE_MUL, ISD::MUL, DL, VT, N0, N1))
4892       return SD;
4893 
4894   // Simplify the operands using demanded-bits information.
4895   if (SimplifyDemandedBits(SDValue(N, 0)))
4896     return SDValue(N, 0);
4897 
4898   return SDValue();
4899 }
4900 
4901 /// Return true if divmod libcall is available.
isDivRemLibcallAvailable(SDNode * Node,bool isSigned,const TargetLowering & TLI)4902 static bool isDivRemLibcallAvailable(SDNode *Node, bool isSigned,
4903                                      const TargetLowering &TLI) {
4904   RTLIB::Libcall LC;
4905   EVT NodeType = Node->getValueType(0);
4906   if (!NodeType.isSimple())
4907     return false;
4908   switch (NodeType.getSimpleVT().SimpleTy) {
4909   default: return false; // No libcall for vector types.
4910   case MVT::i8:   LC= isSigned ? RTLIB::SDIVREM_I8  : RTLIB::UDIVREM_I8;  break;
4911   case MVT::i16:  LC= isSigned ? RTLIB::SDIVREM_I16 : RTLIB::UDIVREM_I16; break;
4912   case MVT::i32:  LC= isSigned ? RTLIB::SDIVREM_I32 : RTLIB::UDIVREM_I32; break;
4913   case MVT::i64:  LC= isSigned ? RTLIB::SDIVREM_I64 : RTLIB::UDIVREM_I64; break;
4914   case MVT::i128: LC= isSigned ? RTLIB::SDIVREM_I128:RTLIB::UDIVREM_I128; break;
4915   }
4916 
4917   return TLI.getLibcallName(LC) != nullptr;
4918 }
4919 
4920 /// Issue divrem if both quotient and remainder are needed.
useDivRem(SDNode * Node)4921 SDValue DAGCombiner::useDivRem(SDNode *Node) {
4922   if (Node->use_empty())
4923     return SDValue(); // This is a dead node, leave it alone.
4924 
4925   unsigned Opcode = Node->getOpcode();
4926   bool isSigned = (Opcode == ISD::SDIV) || (Opcode == ISD::SREM);
4927   unsigned DivRemOpc = isSigned ? ISD::SDIVREM : ISD::UDIVREM;
4928 
4929   // DivMod lib calls can still work on non-legal types if using lib-calls.
4930   EVT VT = Node->getValueType(0);
4931   if (VT.isVector() || !VT.isInteger())
4932     return SDValue();
4933 
4934   if (!TLI.isTypeLegal(VT) && !TLI.isOperationCustom(DivRemOpc, VT))
4935     return SDValue();
4936 
4937   // If DIVREM is going to get expanded into a libcall,
4938   // but there is no libcall available, then don't combine.
4939   if (!TLI.isOperationLegalOrCustom(DivRemOpc, VT) &&
4940       !isDivRemLibcallAvailable(Node, isSigned, TLI))
4941     return SDValue();
4942 
4943   // If div is legal, it's better to do the normal expansion
4944   unsigned OtherOpcode = 0;
4945   if ((Opcode == ISD::SDIV) || (Opcode == ISD::UDIV)) {
4946     OtherOpcode = isSigned ? ISD::SREM : ISD::UREM;
4947     if (TLI.isOperationLegalOrCustom(Opcode, VT))
4948       return SDValue();
4949   } else {
4950     OtherOpcode = isSigned ? ISD::SDIV : ISD::UDIV;
4951     if (TLI.isOperationLegalOrCustom(OtherOpcode, VT))
4952       return SDValue();
4953   }
4954 
4955   SDValue Op0 = Node->getOperand(0);
4956   SDValue Op1 = Node->getOperand(1);
4957   SDValue combined;
4958   for (SDNode *User : Op0->users()) {
4959     if (User == Node || User->getOpcode() == ISD::DELETED_NODE ||
4960         User->use_empty())
4961       continue;
4962     // Convert the other matching node(s), too;
4963     // otherwise, the DIVREM may get target-legalized into something
4964     // target-specific that we won't be able to recognize.
4965     unsigned UserOpc = User->getOpcode();
4966     if ((UserOpc == Opcode || UserOpc == OtherOpcode || UserOpc == DivRemOpc) &&
4967         User->getOperand(0) == Op0 &&
4968         User->getOperand(1) == Op1) {
4969       if (!combined) {
4970         if (UserOpc == OtherOpcode) {
4971           SDVTList VTs = DAG.getVTList(VT, VT);
4972           combined = DAG.getNode(DivRemOpc, SDLoc(Node), VTs, Op0, Op1);
4973         } else if (UserOpc == DivRemOpc) {
4974           combined = SDValue(User, 0);
4975         } else {
4976           assert(UserOpc == Opcode);
4977           continue;
4978         }
4979       }
4980       if (UserOpc == ISD::SDIV || UserOpc == ISD::UDIV)
4981         CombineTo(User, combined);
4982       else if (UserOpc == ISD::SREM || UserOpc == ISD::UREM)
4983         CombineTo(User, combined.getValue(1));
4984     }
4985   }
4986   return combined;
4987 }
4988 
simplifyDivRem(SDNode * N,SelectionDAG & DAG)4989 static SDValue simplifyDivRem(SDNode *N, SelectionDAG &DAG) {
4990   SDValue N0 = N->getOperand(0);
4991   SDValue N1 = N->getOperand(1);
4992   EVT VT = N->getValueType(0);
4993   SDLoc DL(N);
4994 
4995   unsigned Opc = N->getOpcode();
4996   bool IsDiv = (ISD::SDIV == Opc) || (ISD::UDIV == Opc);
4997   ConstantSDNode *N1C = isConstOrConstSplat(N1);
4998 
4999   // X / undef -> undef
5000   // X % undef -> undef
5001   // X / 0 -> undef
5002   // X % 0 -> undef
5003   // NOTE: This includes vectors where any divisor element is zero/undef.
5004   if (DAG.isUndef(Opc, {N0, N1}))
5005     return DAG.getUNDEF(VT);
5006 
5007   // undef / X -> 0
5008   // undef % X -> 0
5009   if (N0.isUndef())
5010     return DAG.getConstant(0, DL, VT);
5011 
5012   // 0 / X -> 0
5013   // 0 % X -> 0
5014   ConstantSDNode *N0C = isConstOrConstSplat(N0);
5015   if (N0C && N0C->isZero())
5016     return N0;
5017 
5018   // X / X -> 1
5019   // X % X -> 0
5020   if (N0 == N1)
5021     return DAG.getConstant(IsDiv ? 1 : 0, DL, VT);
5022 
5023   // X / 1 -> X
5024   // X % 1 -> 0
5025   // If this is a boolean op (single-bit element type), we can't have
5026   // division-by-zero or remainder-by-zero, so assume the divisor is 1.
5027   // TODO: Similarly, if we're zero-extending a boolean divisor, then assume
5028   // it's a 1.
5029   if ((N1C && N1C->isOne()) || (VT.getScalarType() == MVT::i1))
5030     return IsDiv ? N0 : DAG.getConstant(0, DL, VT);
5031 
5032   return SDValue();
5033 }
5034 
visitSDIV(SDNode * N)5035 SDValue DAGCombiner::visitSDIV(SDNode *N) {
5036   SDValue N0 = N->getOperand(0);
5037   SDValue N1 = N->getOperand(1);
5038   EVT VT = N->getValueType(0);
5039   EVT CCVT = getSetCCResultType(VT);
5040   SDLoc DL(N);
5041 
5042   // fold (sdiv c1, c2) -> c1/c2
5043   if (SDValue C = DAG.FoldConstantArithmetic(ISD::SDIV, DL, VT, {N0, N1}))
5044     return C;
5045 
5046   // fold vector ops
5047   if (VT.isVector())
5048     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5049       return FoldedVOp;
5050 
5051   // fold (sdiv X, -1) -> 0-X
5052   ConstantSDNode *N1C = isConstOrConstSplat(N1);
5053   if (N1C && N1C->isAllOnes())
5054     return DAG.getNegative(N0, DL, VT);
5055 
5056   // fold (sdiv X, MIN_SIGNED) -> select(X == MIN_SIGNED, 1, 0)
5057   if (N1C && N1C->isMinSignedValue())
5058     return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ),
5059                          DAG.getConstant(1, DL, VT),
5060                          DAG.getConstant(0, DL, VT));
5061 
5062   if (SDValue V = simplifyDivRem(N, DAG))
5063     return V;
5064 
5065   if (SDValue NewSel = foldBinOpIntoSelect(N))
5066     return NewSel;
5067 
5068   // If we know the sign bits of both operands are zero, strength reduce to a
5069   // udiv instead.  Handles (X&15) /s 4 -> X&15 >> 2
5070   if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0))
5071     return DAG.getNode(ISD::UDIV, DL, N1.getValueType(), N0, N1);
5072 
5073   if (SDValue V = visitSDIVLike(N0, N1, N)) {
5074     // If the corresponding remainder node exists, update its users with
5075     // (Dividend - (Quotient * Divisor).
5076     if (SDNode *RemNode = DAG.getNodeIfExists(ISD::SREM, N->getVTList(),
5077                                               { N0, N1 })) {
5078       // If the sdiv has the exact flag we shouldn't propagate it to the
5079       // remainder node.
5080       if (!N->getFlags().hasExact()) {
5081         SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, V, N1);
5082         SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
5083         AddToWorklist(Mul.getNode());
5084         AddToWorklist(Sub.getNode());
5085         CombineTo(RemNode, Sub);
5086       }
5087     }
5088     return V;
5089   }
5090 
5091   // sdiv, srem -> sdivrem
5092   // If the divisor is constant, then return DIVREM only if isIntDivCheap() is
5093   // true.  Otherwise, we break the simplification logic in visitREM().
5094   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
5095   if (!N1C || TLI.isIntDivCheap(N->getValueType(0), Attr))
5096     if (SDValue DivRem = useDivRem(N))
5097         return DivRem;
5098 
5099   return SDValue();
5100 }
5101 
isDivisorPowerOfTwo(SDValue Divisor)5102 static bool isDivisorPowerOfTwo(SDValue Divisor) {
5103   // Helper for determining whether a value is a power-2 constant scalar or a
5104   // vector of such elements.
5105   auto IsPowerOfTwo = [](ConstantSDNode *C) {
5106     if (C->isZero() || C->isOpaque())
5107       return false;
5108     if (C->getAPIntValue().isPowerOf2())
5109       return true;
5110     if (C->getAPIntValue().isNegatedPowerOf2())
5111       return true;
5112     return false;
5113   };
5114 
5115   return ISD::matchUnaryPredicate(Divisor, IsPowerOfTwo);
5116 }
5117 
/// Try the division-by-constant expansions for a signed division N0 /s N1.
///
/// Shared by SDIV and SREM combining: \p N is the node being combined (an
/// SDIV, or an SREM when called from visitREM) and supplies the debug
/// location, the result type and the exact flag; \p N0 and \p N1 are the
/// dividend and divisor.
///
/// \returns a value computing N0 /s N1, or an empty SDValue if no expansion
/// applies.
SDValue DAGCombiner::visitSDIVLike(SDValue N0, SDValue N1, SDNode *N) {
  SDLoc DL(N);
  EVT VT = N->getValueType(0);
  EVT CCVT = getSetCCResultType(VT);
  unsigned BitWidth = VT.getScalarSizeInBits();

  // fold (sdiv X, pow2) -> simple ops after legalize
  // FIXME: We check for the exact bit here because the generic lowering gives
  // better results in that case. The target-specific lowering should learn how
  // to handle exact sdivs efficiently.
  if (!N->getFlags().hasExact() && isDivisorPowerOfTwo(N1)) {
    // Target-specific implementation of sdiv x, pow2.
    if (SDValue Res = BuildSDIVPow2(N))
      return Res;

    // Create constants that are functions of the shift amount value.
    // For a divisor of +/-2^k, C1 = cttz(divisor) = k and
    // Inexact = BitWidth - k. Give up unless these fold to constants.
    EVT ShiftAmtTy = getShiftAmountTy(N0.getValueType());
    SDValue Bits = DAG.getConstant(BitWidth, DL, ShiftAmtTy);
    SDValue C1 = DAG.getNode(ISD::CTTZ, DL, VT, N1);
    C1 = DAG.getZExtOrTrunc(C1, DL, ShiftAmtTy);
    SDValue Inexact = DAG.getNode(ISD::SUB, DL, ShiftAmtTy, Bits, C1);
    if (!isConstantOrConstantVector(Inexact))
      return SDValue();

    // Splat the sign bit into the register: all-ones if N0 < 0, else zero.
    SDValue Sign = DAG.getNode(ISD::SRA, DL, VT, N0,
                               DAG.getConstant(BitWidth - 1, DL, ShiftAmtTy));
    AddToWorklist(Sign.getNode());

    // Add (N0 < 0) ? |divisor| - 1 : 0 (the sign mask shifted right by
    // BitWidth - k), so the following arithmetic right shift by k rounds
    // toward zero rather than toward negative infinity.
    SDValue Srl = DAG.getNode(ISD::SRL, DL, VT, Sign, Inexact);
    AddToWorklist(Srl.getNode());
    SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0, Srl);
    AddToWorklist(Add.getNode());
    SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Add, C1);
    AddToWorklist(Sra.getNode());

    // Special case: (sdiv X, 1) -> X
    // Special Case: (sdiv X, -1) -> 0-X
    // For |divisor| == 1, k == 0 and the SRL above shifts by the full
    // BitWidth, which is out of range, so select N0 directly for those
    // elements; the negation below then handles -1.
    SDValue One = DAG.getConstant(1, DL, VT);
    SDValue AllOnes = DAG.getAllOnesConstant(DL, VT);
    SDValue IsOne = DAG.getSetCC(DL, CCVT, N1, One, ISD::SETEQ);
    SDValue IsAllOnes = DAG.getSetCC(DL, CCVT, N1, AllOnes, ISD::SETEQ);
    SDValue IsOneOrAllOnes = DAG.getNode(ISD::OR, DL, CCVT, IsOne, IsAllOnes);
    Sra = DAG.getSelect(DL, VT, IsOneOrAllOnes, N0, Sra);

    // If dividing by a positive value, we're done. Otherwise, the result must
    // be negated.
    SDValue Zero = DAG.getConstant(0, DL, VT);
    SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, Zero, Sra);

    // FIXME: Use SELECT_CC once we improve SELECT_CC constant-folding.
    SDValue IsNeg = DAG.getSetCC(DL, CCVT, N1, Zero, ISD::SETLT);
    SDValue Res = DAG.getSelect(DL, VT, IsNeg, Sub, Sra);
    return Res;
  }

  // If integer divide is expensive and we satisfy the requirements, emit an
  // alternate sequence.  Targets may check function attributes for size/speed
  // trade-offs.
  AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
  if (isConstantOrConstantVector(N1) &&
      !TLI.isIntDivCheap(N->getValueType(0), Attr))
    if (SDValue Op = BuildSDIV(N))
      return Op;

  return SDValue();
}
5186 
visitUDIV(SDNode * N)5187 SDValue DAGCombiner::visitUDIV(SDNode *N) {
5188   SDValue N0 = N->getOperand(0);
5189   SDValue N1 = N->getOperand(1);
5190   EVT VT = N->getValueType(0);
5191   EVT CCVT = getSetCCResultType(VT);
5192   SDLoc DL(N);
5193 
5194   // fold (udiv c1, c2) -> c1/c2
5195   if (SDValue C = DAG.FoldConstantArithmetic(ISD::UDIV, DL, VT, {N0, N1}))
5196     return C;
5197 
5198   // fold vector ops
5199   if (VT.isVector())
5200     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5201       return FoldedVOp;
5202 
5203   // fold (udiv X, -1) -> select(X == -1, 1, 0)
5204   ConstantSDNode *N1C = isConstOrConstSplat(N1);
5205   if (N1C && N1C->isAllOnes() && CCVT.isVector() == VT.isVector()) {
5206     return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ),
5207                          DAG.getConstant(1, DL, VT),
5208                          DAG.getConstant(0, DL, VT));
5209   }
5210 
5211   if (SDValue V = simplifyDivRem(N, DAG))
5212     return V;
5213 
5214   if (SDValue NewSel = foldBinOpIntoSelect(N))
5215     return NewSel;
5216 
5217   if (SDValue V = visitUDIVLike(N0, N1, N)) {
5218     // If the corresponding remainder node exists, update its users with
5219     // (Dividend - (Quotient * Divisor).
5220     if (SDNode *RemNode = DAG.getNodeIfExists(ISD::UREM, N->getVTList(),
5221                                               { N0, N1 })) {
5222       // If the udiv has the exact flag we shouldn't propagate it to the
5223       // remainder node.
5224       if (!N->getFlags().hasExact()) {
5225         SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, V, N1);
5226         SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
5227         AddToWorklist(Mul.getNode());
5228         AddToWorklist(Sub.getNode());
5229         CombineTo(RemNode, Sub);
5230       }
5231     }
5232     return V;
5233   }
5234 
5235   // sdiv, srem -> sdivrem
5236   // If the divisor is constant, then return DIVREM only if isIntDivCheap() is
5237   // true.  Otherwise, we break the simplification logic in visitREM().
5238   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
5239   if (!N1C || TLI.isIntDivCheap(N->getValueType(0), Attr))
5240     if (SDValue DivRem = useDivRem(N))
5241         return DivRem;
5242 
5243   // Simplify the operands using demanded-bits information.
5244   // We don't have demanded bits support for UDIV so this just enables constant
5245   // folding based on known bits.
5246   if (SimplifyDemandedBits(SDValue(N, 0)))
5247     return SDValue(N, 0);
5248 
5249   return SDValue();
5250 }
5251 
visitUDIVLike(SDValue N0,SDValue N1,SDNode * N)5252 SDValue DAGCombiner::visitUDIVLike(SDValue N0, SDValue N1, SDNode *N) {
5253   SDLoc DL(N);
5254   EVT VT = N->getValueType(0);
5255 
5256   // fold (udiv x, (1 << c)) -> x >>u c
5257   if (isConstantOrConstantVector(N1, /*NoOpaques*/ true)) {
5258     if (SDValue LogBase2 = BuildLogBase2(N1, DL)) {
5259       AddToWorklist(LogBase2.getNode());
5260 
5261       EVT ShiftVT = getShiftAmountTy(N0.getValueType());
5262       SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
5263       AddToWorklist(Trunc.getNode());
5264       return DAG.getNode(ISD::SRL, DL, VT, N0, Trunc);
5265     }
5266   }
5267 
5268   // fold (udiv x, (shl c, y)) -> x >>u (log2(c)+y) iff c is power of 2
5269   if (N1.getOpcode() == ISD::SHL) {
5270     SDValue N10 = N1.getOperand(0);
5271     if (isConstantOrConstantVector(N10, /*NoOpaques*/ true)) {
5272       if (SDValue LogBase2 = BuildLogBase2(N10, DL)) {
5273         AddToWorklist(LogBase2.getNode());
5274 
5275         EVT ADDVT = N1.getOperand(1).getValueType();
5276         SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ADDVT);
5277         AddToWorklist(Trunc.getNode());
5278         SDValue Add = DAG.getNode(ISD::ADD, DL, ADDVT, N1.getOperand(1), Trunc);
5279         AddToWorklist(Add.getNode());
5280         return DAG.getNode(ISD::SRL, DL, VT, N0, Add);
5281       }
5282     }
5283   }
5284 
5285   // fold (udiv x, c) -> alternate
5286   AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
5287   if (isConstantOrConstantVector(N1) &&
5288       !TLI.isIntDivCheap(N->getValueType(0), Attr))
5289     if (SDValue Op = BuildUDIV(N))
5290       return Op;
5291 
5292   return SDValue();
5293 }
5294 
buildOptimizedSREM(SDValue N0,SDValue N1,SDNode * N)5295 SDValue DAGCombiner::buildOptimizedSREM(SDValue N0, SDValue N1, SDNode *N) {
5296   if (!N->getFlags().hasExact() && isDivisorPowerOfTwo(N1) &&
5297       !DAG.doesNodeExist(ISD::SDIV, N->getVTList(), {N0, N1})) {
5298     // Target-specific implementation of srem x, pow2.
5299     if (SDValue Res = BuildSREMPow2(N))
5300       return Res;
5301   }
5302   return SDValue();
5303 }
5304 
/// Combine an ISD::SREM or ISD::UREM node.
///
/// Tries constant folding, the urem-by-all-ones fold, the shared div/rem
/// simplifications, folding into a select operand, sign-based narrowing
/// (srem -> urem) or power-of-two masking (urem -> and), the X - (X / C) * C
/// expansion built from the division-by-constant logic, and finally
/// (s|u)divrem formation.
///
/// \returns the replacement value, or an empty SDValue if no fold applied.
SDValue DAGCombiner::visitREM(SDNode *N) {
  unsigned Opcode = N->getOpcode();
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  EVT CCVT = getSetCCResultType(VT);

  bool isSigned = (Opcode == ISD::SREM);
  SDLoc DL(N);

  // fold (rem c1, c2) -> c1%c2
  if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
    return C;

  // fold (urem X, -1) -> select(FX == -1, 0, FX)
  // Freeze the numerator to avoid a miscompile with an undefined value: FX is
  // used twice, and both uses must observe the same value.
  if (!isSigned && llvm::isAllOnesOrAllOnesSplat(N1, /*AllowUndefs*/ false) &&
      CCVT.isVector() == VT.isVector()) {
    SDValue F0 = DAG.getFreeze(N0);
    SDValue EqualsNeg1 = DAG.getSetCC(DL, CCVT, F0, N1, ISD::SETEQ);
    return DAG.getSelect(DL, VT, EqualsNeg1, DAG.getConstant(0, DL, VT), F0);
  }

  if (SDValue V = simplifyDivRem(N, DAG))
    return V;

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  if (isSigned) {
    // If we know the sign bits of both operands are zero, strength reduce to a
    // urem instead.  Handles (X & 0x0FFFFFFF) %s 16 -> X&15
    if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0))
      return DAG.getNode(ISD::UREM, DL, VT, N0, N1);
  } else {
    if (DAG.isKnownToBeAPowerOfTwo(N1)) {
      // fold (urem x, pow2) -> (and x, pow2-1)
      SDValue NegOne = DAG.getAllOnesConstant(DL, VT);
      SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N1, NegOne);
      AddToWorklist(Add.getNode());
      return DAG.getNode(ISD::AND, DL, VT, N0, Add);
    }
    // fold (urem x, (shl pow2, y)) -> (and x, (add (shl pow2, y), -1))
    // fold (urem x, (lshr pow2, y)) -> (and x, (add (lshr pow2, y), -1))
    // The shifted value is a power of two or zero; presumably the zero case
    // is acceptable because it would be a remainder by zero.
    // TODO: We should sink the following into isKnownToBePowerOfTwo
    // using a OrZero parameter analogous to our handling in ValueTracking.
    if ((N1.getOpcode() == ISD::SHL || N1.getOpcode() == ISD::SRL) &&
        DAG.isKnownToBeAPowerOfTwo(N1.getOperand(0))) {
      SDValue NegOne = DAG.getAllOnesConstant(DL, VT);
      SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N1, NegOne);
      AddToWorklist(Add.getNode());
      return DAG.getNode(ISD::AND, DL, VT, N0, Add);
    }
  }

  AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();

  // If X/C can be simplified by the division-by-constant logic, lower
  // X%C to the equivalent of X-X/C*C.
  // Reuse the SDIVLike/UDIVLike combines - to avoid mangling nodes, the
  // speculative DIV must not cause a DIVREM conversion.  We guard against this
  // by skipping the simplification if isIntDivCheap().  When div is not cheap,
  // combine will not return a DIVREM.  Regardless, checking cheapness here
  // makes sense since the simplification results in fatter code.
  if (DAG.isKnownNeverZero(N1) && !TLI.isIntDivCheap(VT, Attr)) {
    if (isSigned) {
      // check if we can build faster implementation for srem
      if (SDValue OptimizedRem = buildOptimizedSREM(N0, N1, N))
        return OptimizedRem;
    }

    SDValue OptimizedDiv =
        isSigned ? visitSDIVLike(N0, N1, N) : visitUDIVLike(N0, N1, N);
    // Only use a quotient that is a new node, not N itself.
    if (OptimizedDiv.getNode() && OptimizedDiv.getNode() != N) {
      // If the equivalent Div node also exists, update its users.
      unsigned DivOpcode = isSigned ? ISD::SDIV : ISD::UDIV;
      if (SDNode *DivNode = DAG.getNodeIfExists(DivOpcode, N->getVTList(),
                                                { N0, N1 }))
        CombineTo(DivNode, OptimizedDiv);
      SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, OptimizedDiv, N1);
      SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
      AddToWorklist(OptimizedDiv.getNode());
      AddToWorklist(Mul.getNode());
      return Sub;
    }
  }

  // (s|u)div, (s|u)rem -> (s|u)divrem; result 1 of the DIVREM is the
  // remainder.
  if (SDValue DivRem = useDivRem(N))
    return DivRem.getValue(1);

  return SDValue();
}
5399 
visitMULHS(SDNode * N)5400 SDValue DAGCombiner::visitMULHS(SDNode *N) {
5401   SDValue N0 = N->getOperand(0);
5402   SDValue N1 = N->getOperand(1);
5403   EVT VT = N->getValueType(0);
5404   SDLoc DL(N);
5405 
5406   // fold (mulhs c1, c2)
5407   if (SDValue C = DAG.FoldConstantArithmetic(ISD::MULHS, DL, VT, {N0, N1}))
5408     return C;
5409 
5410   // canonicalize constant to RHS.
5411   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
5412       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
5413     return DAG.getNode(ISD::MULHS, DL, N->getVTList(), N1, N0);
5414 
5415   if (VT.isVector()) {
5416     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5417       return FoldedVOp;
5418 
5419     // fold (mulhs x, 0) -> 0
5420     // do not return N1, because undef node may exist.
5421     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
5422       return DAG.getConstant(0, DL, VT);
5423   }
5424 
5425   // fold (mulhs x, 0) -> 0
5426   if (isNullConstant(N1))
5427     return N1;
5428 
5429   // fold (mulhs x, 1) -> (sra x, size(x)-1)
5430   if (isOneConstant(N1))
5431     return DAG.getNode(
5432         ISD::SRA, DL, VT, N0,
5433         DAG.getShiftAmountConstant(N0.getScalarValueSizeInBits() - 1, VT, DL));
5434 
5435   // fold (mulhs x, undef) -> 0
5436   if (N0.isUndef() || N1.isUndef())
5437     return DAG.getConstant(0, DL, VT);
5438 
5439   // If the type twice as wide is legal, transform the mulhs to a wider multiply
5440   // plus a shift.
5441   if (!TLI.isOperationLegalOrCustom(ISD::MULHS, VT) && VT.isSimple() &&
5442       !VT.isVector()) {
5443     MVT Simple = VT.getSimpleVT();
5444     unsigned SimpleSize = Simple.getSizeInBits();
5445     EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
5446     if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
5447       N0 = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N0);
5448       N1 = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N1);
5449       N1 = DAG.getNode(ISD::MUL, DL, NewVT, N0, N1);
5450       N1 = DAG.getNode(ISD::SRL, DL, NewVT, N1,
5451                        DAG.getShiftAmountConstant(SimpleSize, NewVT, DL));
5452       return DAG.getNode(ISD::TRUNCATE, DL, VT, N1);
5453     }
5454   }
5455 
5456   return SDValue();
5457 }
5458 
visitMULHU(SDNode * N)5459 SDValue DAGCombiner::visitMULHU(SDNode *N) {
5460   SDValue N0 = N->getOperand(0);
5461   SDValue N1 = N->getOperand(1);
5462   EVT VT = N->getValueType(0);
5463   SDLoc DL(N);
5464 
5465   // fold (mulhu c1, c2)
5466   if (SDValue C = DAG.FoldConstantArithmetic(ISD::MULHU, DL, VT, {N0, N1}))
5467     return C;
5468 
5469   // canonicalize constant to RHS.
5470   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
5471       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
5472     return DAG.getNode(ISD::MULHU, DL, N->getVTList(), N1, N0);
5473 
5474   if (VT.isVector()) {
5475     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5476       return FoldedVOp;
5477 
5478     // fold (mulhu x, 0) -> 0
5479     // do not return N1, because undef node may exist.
5480     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
5481       return DAG.getConstant(0, DL, VT);
5482   }
5483 
5484   // fold (mulhu x, 0) -> 0
5485   if (isNullConstant(N1))
5486     return N1;
5487 
5488   // fold (mulhu x, 1) -> 0
5489   if (isOneConstant(N1))
5490     return DAG.getConstant(0, DL, VT);
5491 
5492   // fold (mulhu x, undef) -> 0
5493   if (N0.isUndef() || N1.isUndef())
5494     return DAG.getConstant(0, DL, VT);
5495 
5496   // fold (mulhu x, (1 << c)) -> x >> (bitwidth - c)
5497   if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
5498       hasOperation(ISD::SRL, VT)) {
5499     if (SDValue LogBase2 = BuildLogBase2(N1, DL)) {
5500       unsigned NumEltBits = VT.getScalarSizeInBits();
5501       SDValue SRLAmt = DAG.getNode(
5502           ISD::SUB, DL, VT, DAG.getConstant(NumEltBits, DL, VT), LogBase2);
5503       EVT ShiftVT = getShiftAmountTy(N0.getValueType());
5504       SDValue Trunc = DAG.getZExtOrTrunc(SRLAmt, DL, ShiftVT);
5505       return DAG.getNode(ISD::SRL, DL, VT, N0, Trunc);
5506     }
5507   }
5508 
5509   // If the type twice as wide is legal, transform the mulhu to a wider multiply
5510   // plus a shift.
5511   if (!TLI.isOperationLegalOrCustom(ISD::MULHU, VT) && VT.isSimple() &&
5512       !VT.isVector()) {
5513     MVT Simple = VT.getSimpleVT();
5514     unsigned SimpleSize = Simple.getSizeInBits();
5515     EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
5516     if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
5517       N0 = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N0);
5518       N1 = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N1);
5519       N1 = DAG.getNode(ISD::MUL, DL, NewVT, N0, N1);
5520       N1 = DAG.getNode(ISD::SRL, DL, NewVT, N1,
5521                        DAG.getShiftAmountConstant(SimpleSize, NewVT, DL));
5522       return DAG.getNode(ISD::TRUNCATE, DL, VT, N1);
5523     }
5524   }
5525 
5526   // Simplify the operands using demanded-bits information.
5527   // We don't have demanded bits support for MULHU so this just enables constant
5528   // folding based on known bits.
5529   if (SimplifyDemandedBits(SDValue(N, 0)))
5530     return SDValue(N, 0);
5531 
5532   return SDValue();
5533 }
5534 
/// Combine an averaging node: ISD::AVGFLOORS, ISD::AVGFLOORU, ISD::AVGCEILS or
/// ISD::AVGCEILU.
///
/// \returns the replacement value, or an empty SDValue if no fold applied.
SDValue DAGCombiner::visitAVG(SDNode *N) {
  unsigned Opcode = N->getOpcode();
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);
  bool IsSigned = Opcode == ISD::AVGCEILS || Opcode == ISD::AVGFLOORS;

  // fold (avg c1, c2)
  if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
    return C;

  // canonicalize constant to RHS.
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
    return DAG.getNode(Opcode, DL, N->getVTList(), N1, N0);

  if (VT.isVector())
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

  // fold (avg x, undef) -> x
  // (the undef operand may be chosen equal to x, and avg(x, x) == x).
  if (N0.isUndef())
    return N1;
  if (N1.isUndef())
    return N0;

  // fold (avg x, x) --> x
  // NOTE(review): only done once types are legal; the reason is not visible
  // in this function.
  if (N0 == N1 && Level >= AfterLegalizeTypes)
    return N0;

  // fold (avgfloor x, 0) -> x >> 1
  // (arithmetic shift for the signed form, logical for the unsigned form).
  SDValue X, Y;
  if (sd_match(N, m_c_BinOp(ISD::AVGFLOORS, m_Value(X), m_Zero())))
    return DAG.getNode(ISD::SRA, DL, VT, X,
                       DAG.getShiftAmountConstant(1, VT, DL));
  if (sd_match(N, m_c_BinOp(ISD::AVGFLOORU, m_Value(X), m_Zero())))
    return DAG.getNode(ISD::SRL, DL, VT, X,
                       DAG.getShiftAmountConstant(1, VT, DL));

  // fold avgu(zext(x), zext(y)) -> zext(avgu(x, y))
  // fold avgs(sext(x), sext(y)) -> sext(avgs(x, y))
  // The average of two values lies between them, so it is representable in
  // the narrow type; this requires the narrow op to be available.
  if (!IsSigned &&
      sd_match(N, m_BinOp(Opcode, m_ZExt(m_Value(X)), m_ZExt(m_Value(Y)))) &&
      X.getValueType() == Y.getValueType() &&
      hasOperation(Opcode, X.getValueType())) {
    SDValue AvgU = DAG.getNode(Opcode, DL, X.getValueType(), X, Y);
    return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, AvgU);
  }
  if (IsSigned &&
      sd_match(N, m_BinOp(Opcode, m_SExt(m_Value(X)), m_SExt(m_Value(Y)))) &&
      X.getValueType() == Y.getValueType() &&
      hasOperation(Opcode, X.getValueType())) {
    SDValue AvgS = DAG.getNode(Opcode, DL, X.getValueType(), X, Y);
    return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, AvgS);
  }

  // Fold avgflooru(x,y) -> avgceilu(x,y-1) iff y != 0
  // Fold avgflooru(x,y) -> avgceilu(x-1,y) iff x != 0
  // floor((x + y) / 2) == ceil((x + y - 1) / 2); the non-zero requirement
  // keeps the decrement from wrapping.
  // Check if avgflooru isn't legal/custom but avgceilu is.
  if (Opcode == ISD::AVGFLOORU && !hasOperation(ISD::AVGFLOORU, VT) &&
      (!LegalOperations || hasOperation(ISD::AVGCEILU, VT))) {
    if (DAG.isKnownNeverZero(N1))
      return DAG.getNode(
          ISD::AVGCEILU, DL, VT, N0,
          DAG.getNode(ISD::ADD, DL, VT, N1, DAG.getAllOnesConstant(DL, VT)));
    if (DAG.isKnownNeverZero(N0))
      return DAG.getNode(
          ISD::AVGCEILU, DL, VT, N1,
          DAG.getNode(ISD::ADD, DL, VT, N0, DAG.getAllOnesConstant(DL, VT)));
  }

  // Fold avgfloor((add nw x,y), 1) -> avgceil(x,y)
  // Fold avgfloor((add nw x,1), y) -> avgceil(x,y)
  // floor((x + y + 1) / 2) == ceil((x + y) / 2), valid only if the inner add
  // does not wrap in the matching signedness (checked via its flags below).
  if ((Opcode == ISD::AVGFLOORU && hasOperation(ISD::AVGCEILU, VT)) ||
      (Opcode == ISD::AVGFLOORS && hasOperation(ISD::AVGCEILS, VT))) {
    SDValue Add;
    if (sd_match(N,
                 m_c_BinOp(Opcode,
                           m_AllOf(m_Value(Add), m_Add(m_Value(X), m_Value(Y))),
                           m_One())) ||
        sd_match(N, m_c_BinOp(Opcode,
                              m_AllOf(m_Value(Add), m_Add(m_Value(X), m_One())),
                              m_Value(Y)))) {

      if (IsSigned && Add->getFlags().hasNoSignedWrap())
        return DAG.getNode(ISD::AVGCEILS, DL, VT, X, Y);

      if (!IsSigned && Add->getFlags().hasNoUnsignedWrap())
        return DAG.getNode(ISD::AVGCEILU, DL, VT, X, Y);
    }
  }

  // Fold avgfloors(x,y) -> avgflooru(x,y) if both x and y are non-negative
  if (Opcode == ISD::AVGFLOORS && hasOperation(ISD::AVGFLOORU, VT)) {
    if (DAG.SignBitIsZero(N0) && DAG.SignBitIsZero(N1))
      return DAG.getNode(ISD::AVGFLOORU, DL, VT, N0, N1);
  }

  return SDValue();
}
5636 
visitABD(SDNode * N)5637 SDValue DAGCombiner::visitABD(SDNode *N) {
5638   unsigned Opcode = N->getOpcode();
5639   SDValue N0 = N->getOperand(0);
5640   SDValue N1 = N->getOperand(1);
5641   EVT VT = N->getValueType(0);
5642   SDLoc DL(N);
5643 
5644   // fold (abd c1, c2)
5645   if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
5646     return C;
5647 
5648   // canonicalize constant to RHS.
5649   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
5650       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
5651     return DAG.getNode(Opcode, DL, N->getVTList(), N1, N0);
5652 
5653   if (VT.isVector())
5654     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5655       return FoldedVOp;
5656 
5657   // fold (abd x, undef) -> 0
5658   if (N0.isUndef() || N1.isUndef())
5659     return DAG.getConstant(0, DL, VT);
5660 
5661   // fold (abd x, x) -> 0
5662   if (N0 == N1)
5663     return DAG.getConstant(0, DL, VT);
5664 
5665   SDValue X;
5666 
5667   // fold (abds x, 0) -> abs x
5668   if (sd_match(N, m_c_BinOp(ISD::ABDS, m_Value(X), m_Zero())) &&
5669       (!LegalOperations || hasOperation(ISD::ABS, VT)))
5670     return DAG.getNode(ISD::ABS, DL, VT, X);
5671 
5672   // fold (abdu x, 0) -> x
5673   if (sd_match(N, m_c_BinOp(ISD::ABDU, m_Value(X), m_Zero())))
5674     return X;
5675 
5676   // fold (abds x, y) -> (abdu x, y) iff both args are known positive
5677   if (Opcode == ISD::ABDS && hasOperation(ISD::ABDU, VT) &&
5678       DAG.SignBitIsZero(N0) && DAG.SignBitIsZero(N1))
5679     return DAG.getNode(ISD::ABDU, DL, VT, N1, N0);
5680 
5681   return SDValue();
5682 }
5683 
5684 /// Perform optimizations common to nodes that compute two values. LoOp and HiOp
5685 /// give the opcodes for the two computations that are being performed. Return
5686 /// true if a simplification was made.
SimplifyNodeWithTwoResults(SDNode * N,unsigned LoOp,unsigned HiOp)5687 SDValue DAGCombiner::SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp,
5688                                                 unsigned HiOp) {
5689   // If the high half is not needed, just compute the low half.
5690   bool HiExists = N->hasAnyUseOfValue(1);
5691   if (!HiExists && (!LegalOperations ||
5692                     TLI.isOperationLegalOrCustom(LoOp, N->getValueType(0)))) {
5693     SDValue Res = DAG.getNode(LoOp, SDLoc(N), N->getValueType(0), N->ops());
5694     return CombineTo(N, Res, Res);
5695   }
5696 
5697   // If the low half is not needed, just compute the high half.
5698   bool LoExists = N->hasAnyUseOfValue(0);
5699   if (!LoExists && (!LegalOperations ||
5700                     TLI.isOperationLegalOrCustom(HiOp, N->getValueType(1)))) {
5701     SDValue Res = DAG.getNode(HiOp, SDLoc(N), N->getValueType(1), N->ops());
5702     return CombineTo(N, Res, Res);
5703   }
5704 
5705   // If both halves are used, return as it is.
5706   if (LoExists && HiExists)
5707     return SDValue();
5708 
5709   // If the two computed results can be simplified separately, separate them.
5710   if (LoExists) {
5711     SDValue Lo = DAG.getNode(LoOp, SDLoc(N), N->getValueType(0), N->ops());
5712     AddToWorklist(Lo.getNode());
5713     SDValue LoOpt = combine(Lo.getNode());
5714     if (LoOpt.getNode() && LoOpt.getNode() != Lo.getNode() &&
5715         (!LegalOperations ||
5716          TLI.isOperationLegalOrCustom(LoOpt.getOpcode(), LoOpt.getValueType())))
5717       return CombineTo(N, LoOpt, LoOpt);
5718   }
5719 
5720   if (HiExists) {
5721     SDValue Hi = DAG.getNode(HiOp, SDLoc(N), N->getValueType(1), N->ops());
5722     AddToWorklist(Hi.getNode());
5723     SDValue HiOpt = combine(Hi.getNode());
5724     if (HiOpt.getNode() && HiOpt != Hi &&
5725         (!LegalOperations ||
5726          TLI.isOperationLegalOrCustom(HiOpt.getOpcode(), HiOpt.getValueType())))
5727       return CombineTo(N, HiOpt, HiOpt);
5728   }
5729 
5730   return SDValue();
5731 }
5732 
visitSMUL_LOHI(SDNode * N)5733 SDValue DAGCombiner::visitSMUL_LOHI(SDNode *N) {
5734   if (SDValue Res = SimplifyNodeWithTwoResults(N, ISD::MUL, ISD::MULHS))
5735     return Res;
5736 
5737   SDValue N0 = N->getOperand(0);
5738   SDValue N1 = N->getOperand(1);
5739   EVT VT = N->getValueType(0);
5740   SDLoc DL(N);
5741 
5742   // Constant fold.
5743   if (isa<ConstantSDNode>(N0) && isa<ConstantSDNode>(N1))
5744     return DAG.getNode(ISD::SMUL_LOHI, DL, N->getVTList(), N0, N1);
5745 
5746   // canonicalize constant to RHS (vector doesn't have to splat)
5747   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
5748       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
5749     return DAG.getNode(ISD::SMUL_LOHI, DL, N->getVTList(), N1, N0);
5750 
5751   // If the type is twice as wide is legal, transform the mulhu to a wider
5752   // multiply plus a shift.
5753   if (VT.isSimple() && !VT.isVector()) {
5754     MVT Simple = VT.getSimpleVT();
5755     unsigned SimpleSize = Simple.getSizeInBits();
5756     EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
5757     if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
5758       SDValue Lo = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N0);
5759       SDValue Hi = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N1);
5760       Lo = DAG.getNode(ISD::MUL, DL, NewVT, Lo, Hi);
5761       // Compute the high part as N1.
5762       Hi = DAG.getNode(ISD::SRL, DL, NewVT, Lo,
5763                        DAG.getShiftAmountConstant(SimpleSize, NewVT, DL));
5764       Hi = DAG.getNode(ISD::TRUNCATE, DL, VT, Hi);
5765       // Compute the low part as N0.
5766       Lo = DAG.getNode(ISD::TRUNCATE, DL, VT, Lo);
5767       return CombineTo(N, Lo, Hi);
5768     }
5769   }
5770 
5771   return SDValue();
5772 }
5773 
visitUMUL_LOHI(SDNode * N)5774 SDValue DAGCombiner::visitUMUL_LOHI(SDNode *N) {
5775   if (SDValue Res = SimplifyNodeWithTwoResults(N, ISD::MUL, ISD::MULHU))
5776     return Res;
5777 
5778   SDValue N0 = N->getOperand(0);
5779   SDValue N1 = N->getOperand(1);
5780   EVT VT = N->getValueType(0);
5781   SDLoc DL(N);
5782 
5783   // Constant fold.
5784   if (isa<ConstantSDNode>(N0) && isa<ConstantSDNode>(N1))
5785     return DAG.getNode(ISD::UMUL_LOHI, DL, N->getVTList(), N0, N1);
5786 
5787   // canonicalize constant to RHS (vector doesn't have to splat)
5788   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
5789       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
5790     return DAG.getNode(ISD::UMUL_LOHI, DL, N->getVTList(), N1, N0);
5791 
5792   // (umul_lohi N0, 0) -> (0, 0)
5793   if (isNullConstant(N1)) {
5794     SDValue Zero = DAG.getConstant(0, DL, VT);
5795     return CombineTo(N, Zero, Zero);
5796   }
5797 
5798   // (umul_lohi N0, 1) -> (N0, 0)
5799   if (isOneConstant(N1)) {
5800     SDValue Zero = DAG.getConstant(0, DL, VT);
5801     return CombineTo(N, N0, Zero);
5802   }
5803 
5804   // If the type is twice as wide is legal, transform the mulhu to a wider
5805   // multiply plus a shift.
5806   if (VT.isSimple() && !VT.isVector()) {
5807     MVT Simple = VT.getSimpleVT();
5808     unsigned SimpleSize = Simple.getSizeInBits();
5809     EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
5810     if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
5811       SDValue Lo = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N0);
5812       SDValue Hi = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N1);
5813       Lo = DAG.getNode(ISD::MUL, DL, NewVT, Lo, Hi);
5814       // Compute the high part as N1.
5815       Hi = DAG.getNode(ISD::SRL, DL, NewVT, Lo,
5816                        DAG.getShiftAmountConstant(SimpleSize, NewVT, DL));
5817       Hi = DAG.getNode(ISD::TRUNCATE, DL, VT, Hi);
5818       // Compute the low part as N0.
5819       Lo = DAG.getNode(ISD::TRUNCATE, DL, VT, Lo);
5820       return CombineTo(N, Lo, Hi);
5821     }
5822   }
5823 
5824   return SDValue();
5825 }
5826 
visitMULO(SDNode * N)5827 SDValue DAGCombiner::visitMULO(SDNode *N) {
5828   SDValue N0 = N->getOperand(0);
5829   SDValue N1 = N->getOperand(1);
5830   EVT VT = N0.getValueType();
5831   bool IsSigned = (ISD::SMULO == N->getOpcode());
5832 
5833   EVT CarryVT = N->getValueType(1);
5834   SDLoc DL(N);
5835 
5836   ConstantSDNode *N0C = isConstOrConstSplat(N0);
5837   ConstantSDNode *N1C = isConstOrConstSplat(N1);
5838 
5839   // fold operation with constant operands.
5840   // TODO: Move this to FoldConstantArithmetic when it supports nodes with
5841   // multiple results.
5842   if (N0C && N1C) {
5843     bool Overflow;
5844     APInt Result =
5845         IsSigned ? N0C->getAPIntValue().smul_ov(N1C->getAPIntValue(), Overflow)
5846                  : N0C->getAPIntValue().umul_ov(N1C->getAPIntValue(), Overflow);
5847     return CombineTo(N, DAG.getConstant(Result, DL, VT),
5848                      DAG.getBoolConstant(Overflow, DL, CarryVT, CarryVT));
5849   }
5850 
5851   // canonicalize constant to RHS.
5852   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
5853       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
5854     return DAG.getNode(N->getOpcode(), DL, N->getVTList(), N1, N0);
5855 
5856   // fold (mulo x, 0) -> 0 + no carry out
5857   if (isNullOrNullSplat(N1))
5858     return CombineTo(N, DAG.getConstant(0, DL, VT),
5859                      DAG.getConstant(0, DL, CarryVT));
5860 
5861   // (mulo x, 2) -> (addo x, x)
5862   // FIXME: This needs a freeze.
5863   if (N1C && N1C->getAPIntValue() == 2 &&
5864       (!IsSigned || VT.getScalarSizeInBits() > 2))
5865     return DAG.getNode(IsSigned ? ISD::SADDO : ISD::UADDO, DL,
5866                        N->getVTList(), N0, N0);
5867 
5868   // A 1 bit SMULO overflows if both inputs are 1.
5869   if (IsSigned && VT.getScalarSizeInBits() == 1) {
5870     SDValue And = DAG.getNode(ISD::AND, DL, VT, N0, N1);
5871     SDValue Cmp = DAG.getSetCC(DL, CarryVT, And,
5872                                DAG.getConstant(0, DL, VT), ISD::SETNE);
5873     return CombineTo(N, And, Cmp);
5874   }
5875 
5876   // If it cannot overflow, transform into a mul.
5877   if (DAG.willNotOverflowMul(IsSigned, N0, N1))
5878     return CombineTo(N, DAG.getNode(ISD::MUL, DL, VT, N0, N1),
5879                      DAG.getConstant(0, DL, CarryVT));
5880   return SDValue();
5881 }
5882 
5883 // Function to calculate whether the Min/Max pair of SDNodes (potentially
5884 // swapped around) make a signed saturate pattern, clamping to between a signed
5885 // saturate of -2^(BW-1) and 2^(BW-1)-1, or an unsigned saturate of 0 and 2^BW.
5886 // Returns the node being clamped and the bitwidth of the clamp in BW. Should
5887 // work with both SMIN/SMAX nodes and setcc/select combo. The operands are the
5888 // same as SimplifySelectCC. N0<N1 ? N2 : N3.
isSaturatingMinMax(SDValue N0,SDValue N1,SDValue N2,SDValue N3,ISD::CondCode CC,unsigned & BW,bool & Unsigned,SelectionDAG & DAG)5889 static SDValue isSaturatingMinMax(SDValue N0, SDValue N1, SDValue N2,
5890                                   SDValue N3, ISD::CondCode CC, unsigned &BW,
5891                                   bool &Unsigned, SelectionDAG &DAG) {
5892   auto isSignedMinMax = [&](SDValue N0, SDValue N1, SDValue N2, SDValue N3,
5893                             ISD::CondCode CC) {
5894     // The compare and select operand should be the same or the select operands
5895     // should be truncated versions of the comparison.
5896     if (N0 != N2 && (N2.getOpcode() != ISD::TRUNCATE || N0 != N2.getOperand(0)))
5897       return 0;
5898     // The constants need to be the same or a truncated version of each other.
5899     ConstantSDNode *N1C = isConstOrConstSplat(peekThroughTruncates(N1));
5900     ConstantSDNode *N3C = isConstOrConstSplat(peekThroughTruncates(N3));
5901     if (!N1C || !N3C)
5902       return 0;
5903     const APInt &C1 = N1C->getAPIntValue().trunc(N1.getScalarValueSizeInBits());
5904     const APInt &C2 = N3C->getAPIntValue().trunc(N3.getScalarValueSizeInBits());
5905     if (C1.getBitWidth() < C2.getBitWidth() || C1 != C2.sext(C1.getBitWidth()))
5906       return 0;
5907     return CC == ISD::SETLT ? ISD::SMIN : (CC == ISD::SETGT ? ISD::SMAX : 0);
5908   };
5909 
5910   // Check the initial value is a SMIN/SMAX equivalent.
5911   unsigned Opcode0 = isSignedMinMax(N0, N1, N2, N3, CC);
5912   if (!Opcode0)
5913     return SDValue();
5914 
5915   // We could only need one range check, if the fptosi could never produce
5916   // the upper value.
5917   if (N0.getOpcode() == ISD::FP_TO_SINT && Opcode0 == ISD::SMAX) {
5918     if (isNullOrNullSplat(N3)) {
5919       EVT IntVT = N0.getValueType().getScalarType();
5920       EVT FPVT = N0.getOperand(0).getValueType().getScalarType();
5921       if (FPVT.isSimple()) {
5922         Type *InputTy = FPVT.getTypeForEVT(*DAG.getContext());
5923         const fltSemantics &Semantics = InputTy->getFltSemantics();
5924         uint32_t MinBitWidth =
5925           APFloatBase::semanticsIntSizeInBits(Semantics, /*isSigned*/ true);
5926         if (IntVT.getSizeInBits() >= MinBitWidth) {
5927           Unsigned = true;
5928           BW = PowerOf2Ceil(MinBitWidth);
5929           return N0;
5930         }
5931       }
5932     }
5933   }
5934 
5935   SDValue N00, N01, N02, N03;
5936   ISD::CondCode N0CC;
5937   switch (N0.getOpcode()) {
5938   case ISD::SMIN:
5939   case ISD::SMAX:
5940     N00 = N02 = N0.getOperand(0);
5941     N01 = N03 = N0.getOperand(1);
5942     N0CC = N0.getOpcode() == ISD::SMIN ? ISD::SETLT : ISD::SETGT;
5943     break;
5944   case ISD::SELECT_CC:
5945     N00 = N0.getOperand(0);
5946     N01 = N0.getOperand(1);
5947     N02 = N0.getOperand(2);
5948     N03 = N0.getOperand(3);
5949     N0CC = cast<CondCodeSDNode>(N0.getOperand(4))->get();
5950     break;
5951   case ISD::SELECT:
5952   case ISD::VSELECT:
5953     if (N0.getOperand(0).getOpcode() != ISD::SETCC)
5954       return SDValue();
5955     N00 = N0.getOperand(0).getOperand(0);
5956     N01 = N0.getOperand(0).getOperand(1);
5957     N02 = N0.getOperand(1);
5958     N03 = N0.getOperand(2);
5959     N0CC = cast<CondCodeSDNode>(N0.getOperand(0).getOperand(2))->get();
5960     break;
5961   default:
5962     return SDValue();
5963   }
5964 
5965   unsigned Opcode1 = isSignedMinMax(N00, N01, N02, N03, N0CC);
5966   if (!Opcode1 || Opcode0 == Opcode1)
5967     return SDValue();
5968 
5969   ConstantSDNode *MinCOp = isConstOrConstSplat(Opcode0 == ISD::SMIN ? N1 : N01);
5970   ConstantSDNode *MaxCOp = isConstOrConstSplat(Opcode0 == ISD::SMIN ? N01 : N1);
5971   if (!MinCOp || !MaxCOp || MinCOp->getValueType(0) != MaxCOp->getValueType(0))
5972     return SDValue();
5973 
5974   const APInt &MinC = MinCOp->getAPIntValue();
5975   const APInt &MaxC = MaxCOp->getAPIntValue();
5976   APInt MinCPlus1 = MinC + 1;
5977   if (-MaxC == MinCPlus1 && MinCPlus1.isPowerOf2()) {
5978     BW = MinCPlus1.exactLogBase2() + 1;
5979     Unsigned = false;
5980     return N02;
5981   }
5982 
5983   if (MaxC == 0 && MinCPlus1.isPowerOf2()) {
5984     BW = MinCPlus1.exactLogBase2();
5985     Unsigned = true;
5986     return N02;
5987   }
5988 
5989   return SDValue();
5990 }
5991 
PerformMinMaxFpToSatCombine(SDValue N0,SDValue N1,SDValue N2,SDValue N3,ISD::CondCode CC,SelectionDAG & DAG)5992 static SDValue PerformMinMaxFpToSatCombine(SDValue N0, SDValue N1, SDValue N2,
5993                                            SDValue N3, ISD::CondCode CC,
5994                                            SelectionDAG &DAG) {
5995   unsigned BW;
5996   bool Unsigned;
5997   SDValue Fp = isSaturatingMinMax(N0, N1, N2, N3, CC, BW, Unsigned, DAG);
5998   if (!Fp || Fp.getOpcode() != ISD::FP_TO_SINT)
5999     return SDValue();
6000   EVT FPVT = Fp.getOperand(0).getValueType();
6001   EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), BW);
6002   if (FPVT.isVector())
6003     NewVT = EVT::getVectorVT(*DAG.getContext(), NewVT,
6004                              FPVT.getVectorElementCount());
6005   unsigned NewOpc = Unsigned ? ISD::FP_TO_UINT_SAT : ISD::FP_TO_SINT_SAT;
6006   if (!DAG.getTargetLoweringInfo().shouldConvertFpToSat(NewOpc, FPVT, NewVT))
6007     return SDValue();
6008   SDLoc DL(Fp);
6009   SDValue Sat = DAG.getNode(NewOpc, DL, NewVT, Fp.getOperand(0),
6010                             DAG.getValueType(NewVT.getScalarType()));
6011   return DAG.getExtOrTrunc(!Unsigned, Sat, DL, N2->getValueType(0));
6012 }
6013 
PerformUMinFpToSatCombine(SDValue N0,SDValue N1,SDValue N2,SDValue N3,ISD::CondCode CC,SelectionDAG & DAG)6014 static SDValue PerformUMinFpToSatCombine(SDValue N0, SDValue N1, SDValue N2,
6015                                          SDValue N3, ISD::CondCode CC,
6016                                          SelectionDAG &DAG) {
6017   // We are looking for UMIN(FPTOUI(X), (2^n)-1), which may have come via a
6018   // select/vselect/select_cc. The two operands pairs for the select (N2/N3) may
6019   // be truncated versions of the setcc (N0/N1).
6020   if ((N0 != N2 &&
6021        (N2.getOpcode() != ISD::TRUNCATE || N0 != N2.getOperand(0))) ||
6022       N0.getOpcode() != ISD::FP_TO_UINT || CC != ISD::SETULT)
6023     return SDValue();
6024   ConstantSDNode *N1C = isConstOrConstSplat(N1);
6025   ConstantSDNode *N3C = isConstOrConstSplat(N3);
6026   if (!N1C || !N3C)
6027     return SDValue();
6028   const APInt &C1 = N1C->getAPIntValue();
6029   const APInt &C3 = N3C->getAPIntValue();
6030   if (!(C1 + 1).isPowerOf2() || C1.getBitWidth() < C3.getBitWidth() ||
6031       C1 != C3.zext(C1.getBitWidth()))
6032     return SDValue();
6033 
6034   unsigned BW = (C1 + 1).exactLogBase2();
6035   EVT FPVT = N0.getOperand(0).getValueType();
6036   EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), BW);
6037   if (FPVT.isVector())
6038     NewVT = EVT::getVectorVT(*DAG.getContext(), NewVT,
6039                              FPVT.getVectorElementCount());
6040   if (!DAG.getTargetLoweringInfo().shouldConvertFpToSat(ISD::FP_TO_UINT_SAT,
6041                                                         FPVT, NewVT))
6042     return SDValue();
6043 
6044   SDValue Sat =
6045       DAG.getNode(ISD::FP_TO_UINT_SAT, SDLoc(N0), NewVT, N0.getOperand(0),
6046                   DAG.getValueType(NewVT.getScalarType()));
6047   return DAG.getZExtOrTrunc(Sat, SDLoc(N0), N3.getValueType());
6048 }
6049 
visitIMINMAX(SDNode * N)6050 SDValue DAGCombiner::visitIMINMAX(SDNode *N) {
6051   SDValue N0 = N->getOperand(0);
6052   SDValue N1 = N->getOperand(1);
6053   EVT VT = N0.getValueType();
6054   unsigned Opcode = N->getOpcode();
6055   SDLoc DL(N);
6056 
6057   // fold operation with constant operands.
6058   if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
6059     return C;
6060 
6061   // If the operands are the same, this is a no-op.
6062   if (N0 == N1)
6063     return N0;
6064 
6065   // canonicalize constant to RHS
6066   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
6067       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
6068     return DAG.getNode(Opcode, DL, VT, N1, N0);
6069 
6070   // fold vector ops
6071   if (VT.isVector())
6072     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
6073       return FoldedVOp;
6074 
6075   // reassociate minmax
6076   if (SDValue RMINMAX = reassociateOps(Opcode, DL, N0, N1, N->getFlags()))
6077     return RMINMAX;
6078 
6079   // Is sign bits are zero, flip between UMIN/UMAX and SMIN/SMAX.
6080   // Only do this if:
6081   // 1. The current op isn't legal and the flipped is.
6082   // 2. The saturation pattern is broken by canonicalization in InstCombine.
6083   bool IsOpIllegal = !TLI.isOperationLegal(Opcode, VT);
6084   bool IsSatBroken = Opcode == ISD::UMIN && N0.getOpcode() == ISD::SMAX;
6085   if ((IsSatBroken || IsOpIllegal) && (N0.isUndef() || DAG.SignBitIsZero(N0)) &&
6086       (N1.isUndef() || DAG.SignBitIsZero(N1))) {
6087     unsigned AltOpcode;
6088     switch (Opcode) {
6089     case ISD::SMIN: AltOpcode = ISD::UMIN; break;
6090     case ISD::SMAX: AltOpcode = ISD::UMAX; break;
6091     case ISD::UMIN: AltOpcode = ISD::SMIN; break;
6092     case ISD::UMAX: AltOpcode = ISD::SMAX; break;
6093     default: llvm_unreachable("Unknown MINMAX opcode");
6094     }
6095     if ((IsSatBroken && IsOpIllegal) || TLI.isOperationLegal(AltOpcode, VT))
6096       return DAG.getNode(AltOpcode, DL, VT, N0, N1);
6097   }
6098 
6099   if (Opcode == ISD::SMIN || Opcode == ISD::SMAX)
6100     if (SDValue S = PerformMinMaxFpToSatCombine(
6101             N0, N1, N0, N1, Opcode == ISD::SMIN ? ISD::SETLT : ISD::SETGT, DAG))
6102       return S;
6103   if (Opcode == ISD::UMIN)
6104     if (SDValue S = PerformUMinFpToSatCombine(N0, N1, N0, N1, ISD::SETULT, DAG))
6105       return S;
6106 
6107   // Fold min/max(vecreduce(x), vecreduce(y)) -> vecreduce(min/max(x, y))
6108   auto ReductionOpcode = [](unsigned Opcode) {
6109     switch (Opcode) {
6110     case ISD::SMIN:
6111       return ISD::VECREDUCE_SMIN;
6112     case ISD::SMAX:
6113       return ISD::VECREDUCE_SMAX;
6114     case ISD::UMIN:
6115       return ISD::VECREDUCE_UMIN;
6116     case ISD::UMAX:
6117       return ISD::VECREDUCE_UMAX;
6118     default:
6119       llvm_unreachable("Unexpected opcode");
6120     }
6121   };
6122   if (SDValue SD = reassociateReduction(ReductionOpcode(Opcode), Opcode,
6123                                         SDLoc(N), VT, N0, N1))
6124     return SD;
6125 
6126   // Simplify the operands using demanded-bits information.
6127   if (SimplifyDemandedBits(SDValue(N, 0)))
6128     return SDValue(N, 0);
6129 
6130   return SDValue();
6131 }
6132 
6133 /// If this is a bitwise logic instruction and both operands have the same
6134 /// opcode (the "hand" opcode), try to sink that shared opcode below the
/// logic instruction: logic_op (hand X), (hand Y) --> hand (logic_op X, Y).
/// Both preconditions (N is a bitwise logic op; its operands share an opcode)
/// are asserted. Returns the replacement value, or an empty SDValue when no
/// transform applies or the use-count/legality checks reject it.
hoistLogicOpWithSameOpcodeHands(SDNode * N)6135 SDValue DAGCombiner::hoistLogicOpWithSameOpcodeHands(SDNode *N) {
6136   SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
6137   EVT VT = N0.getValueType();
6138   unsigned LogicOpcode = N->getOpcode();
6139   unsigned HandOpcode = N0.getOpcode();
6140   assert(ISD::isBitwiseLogicOp(LogicOpcode) && "Expected logic opcode");
6141   assert(HandOpcode == N1.getOpcode() && "Bad input!");
6142
6143   // Bail early if none of these transforms apply: every transform below reads
  // operand 0 of both hands.
6144   if (N0.getNumOperands() == 0)
6145     return SDValue();
6146
6147   // FIXME: We should check number of uses of the operands to not increase
6148   //        the instruction count for all transforms.
6149
  // X and Y are operand 0 of each hand. XVT is the type the new logic op is
  // built in for most transforms; it differs from VT for size-changing casts.
6150   // Handle size-changing casts (or sign_extend_inreg).
6151   SDValue X = N0.getOperand(0);
6152   SDValue Y = N1.getOperand(0);
6153   EVT XVT = X.getValueType();
6154   SDLoc DL(N);
  // Plain extends, vector in-register extends, or sign_extend_inreg where both
  // hands share the same type operand.
6155   if (ISD::isExtOpcode(HandOpcode) || ISD::isExtVecInRegOpcode(HandOpcode) ||
6156       (HandOpcode == ISD::SIGN_EXTEND_INREG &&
6157        N0.getOperand(1) == N1.getOperand(1))) {
6158     // If both operands have other uses, this transform would create extra
6159     // instructions without eliminating anything.
6160     if (!N0.hasOneUse() && !N1.hasOneUse())
6161       return SDValue();
6162     // We need matching integer source types.
6163     if (XVT != Y.getValueType())
6164       return SDValue();
6165     // Don't create an illegal op during or after legalization. Don't ever
6166     // create an unsupported vector op.
6167     if ((VT.isVector() || LegalOperations) &&
6168         !TLI.isOperationLegalOrCustom(LogicOpcode, XVT))
6169       return SDValue();
6170     // Avoid infinite looping with PromoteIntBinOp.
6171     // TODO: Should we apply desirable/legal constraints to all opcodes?
6172     if ((HandOpcode == ISD::ANY_EXTEND ||
6173          HandOpcode == ISD::ANY_EXTEND_VECTOR_INREG) &&
6174         LegalTypes && !TLI.isTypeDesirableForOp(LogicOpcode, XVT))
6175       return SDValue();
6176     // logic_op (hand_op X), (hand_op Y) --> hand_op (logic_op X, Y)
    // N's 'disjoint' flag is carried over only when the hand is an
    // ISD::isExtOpcode extend. For the other hands the new op gets no flags.
6177     SDNodeFlags LogicFlags;
6178     LogicFlags.setDisjoint(N->getFlags().hasDisjoint() &&
6179                            ISD::isExtOpcode(HandOpcode));
6180     SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y, LogicFlags);
    // sign_extend_inreg keeps its (shared) type operand.
6181     if (HandOpcode == ISD::SIGN_EXTEND_INREG)
6182       return DAG.getNode(HandOpcode, DL, VT, Logic, N0.getOperand(1));
6183     return DAG.getNode(HandOpcode, DL, VT, Logic);
6184   }
6185
6186   // logic_op (truncate x), (truncate y) --> truncate (logic_op x, y)
6187   if (HandOpcode == ISD::TRUNCATE) {
6188     // If both operands have other uses, this transform would create extra
6189     // instructions without eliminating anything.
6190     if (!N0.hasOneUse() && !N1.hasOneUse())
6191       return SDValue();
6192     // We need matching source types.
6193     if (XVT != Y.getValueType())
6194       return SDValue();
6195     // Don't create an illegal op during or after legalization.
6196     if (LegalOperations && !TLI.isOperationLegal(LogicOpcode, XVT))
6197       return SDValue();
6198     // Be extra careful sinking truncate. If it's free, there's no benefit in
6199     // widening a binop. Also, don't create a logic op on an illegal type.
6200     if (TLI.isZExtFree(VT, XVT) && TLI.isTruncateFree(XVT, VT))
6201       return SDValue();
6202     if (!TLI.isTypeLegal(XVT))
6203       return SDValue();
6204     SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
6205     return DAG.getNode(HandOpcode, DL, VT, Logic);
6206   }
6207
6208   // For binops SHL/SRL/SRA/AND:
6209   //   logic_op (OP x, z), (OP y, z) --> OP (logic_op x, y), z
  // Both hands must share the same second operand z.
6210   if ((HandOpcode == ISD::SHL || HandOpcode == ISD::SRL ||
6211        HandOpcode == ISD::SRA || HandOpcode == ISD::AND) &&
6212       N0.getOperand(1) == N1.getOperand(1)) {
6213     // If either operand has other uses, this transform is not an improvement.
6214     if (!N0.hasOneUse() || !N1.hasOneUse())
6215       return SDValue();
6216     SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
6217     return DAG.getNode(HandOpcode, DL, VT, Logic, N0.getOperand(1));
6218   }
6219
6220   // Unary ops: logic_op (bswap x), (bswap y) --> bswap (logic_op x, y)
6221   if (HandOpcode == ISD::BSWAP) {
6222     // If either operand has other uses, this transform is not an improvement.
6223     if (!N0.hasOneUse() || !N1.hasOneUse())
6224       return SDValue();
6225     SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
6226     return DAG.getNode(HandOpcode, DL, VT, Logic);
6227   }
6228
6229   // For funnel shifts FSHL/FSHR:
6230   // logic_op (OP x, x1, s), (OP y, y1, s) -->
6231   // --> OP (logic_op x, y), (logic_op x1, y1), s
  // Both funnel shifts must share the shift amount s. Two logic ops are built
  // in the result type VT, one per data operand.
6232   if ((HandOpcode == ISD::FSHL || HandOpcode == ISD::FSHR) &&
6233       N0.getOperand(2) == N1.getOperand(2)) {
6234     if (!N0.hasOneUse() || !N1.hasOneUse())
6235       return SDValue();
6236     SDValue X1 = N0.getOperand(1);
6237     SDValue Y1 = N1.getOperand(1);
6238     SDValue S = N0.getOperand(2);
6239     SDValue Logic0 = DAG.getNode(LogicOpcode, DL, VT, X, Y);
6240     SDValue Logic1 = DAG.getNode(LogicOpcode, DL, VT, X1, Y1);
6241     return DAG.getNode(HandOpcode, DL, VT, Logic0, Logic1, S);
6242   }
6243
6244   // Simplify xor/and/or (bitcast(A), bitcast(B)) -> bitcast(op (A,B))
6245   // Only perform this optimization up until type legalization, before
6246   // LegalizeVectorOps. LegalizeVectorOps promotes vector operations by
6247   // adding bitcasts. For example (xor v4i32) is promoted to (v2i64), and
6248   // we don't want to undo this promotion.
6249   // We also handle SCALAR_TO_VECTOR because xor/or/and operations are cheaper
6250   // on scalars.
6251   if ((HandOpcode == ISD::BITCAST || HandOpcode == ISD::SCALAR_TO_VECTOR) &&
6252        Level <= AfterLegalizeTypes) {
6253     // Input types must be integer and the same.
    // Also skip the case that would move the logic op from a legal vector
    // type onto an illegal scalar type.
6254     if (XVT.isInteger() && XVT == Y.getValueType() &&
6255         !(VT.isVector() && TLI.isTypeLegal(VT) &&
6256           !XVT.isVector() && !TLI.isTypeLegal(XVT))) {
6257       SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
6258       return DAG.getNode(HandOpcode, DL, VT, Logic);
6259     }
6260   }
6261
6262   // Xor/and/or are indifferent to the swizzle operation (shuffle of one value).
6263   // Simplify xor/and/or (shuff(A), shuff(B)) -> shuff(op (A,B))
6264   // If both shuffles use the same mask, and both shuffle within a single
6265   // vector, then it is worthwhile to move the swizzle after the operation.
6266   // The type-legalizer generates this pattern when loading illegal
6267   // vector types from memory. In many cases this allows additional shuffle
6268   // optimizations.
6269   // There are other cases where moving the shuffle after the xor/and/or
6270   // is profitable even if shuffles don't perform a swizzle.
6271   // If both shuffles use the same mask, and both shuffles have the same first
6272   // or second operand, then it might still be profitable to move the shuffle
6273   // after the xor/and/or operation.
6274   if (HandOpcode == ISD::VECTOR_SHUFFLE && Level < AfterLegalizeDAG) {
6275     auto *SVN0 = cast<ShuffleVectorSDNode>(N0);
6276     auto *SVN1 = cast<ShuffleVectorSDNode>(N1);
6277     assert(X.getValueType() == Y.getValueType() &&
6278            "Inputs to shuffles are not the same type");
6279
6280     // Check that both shuffles use the same mask. The masks are known to be of
6281     // the same length because the result vector type is the same.
6282     // Check also that shuffles have only one use to avoid introducing extra
6283     // instructions.
6284     if (!SVN0->hasOneUse() || !SVN1->hasOneUse() ||
6285         !SVN0->getMask().equals(SVN1->getMask()))
6286       return SDValue();
6287
6288     // Don't try to fold this node if it requires introducing a
6289     // build vector of all zeros that might be illegal at this stage.
    // For XOR the shared operand C cancels (C ^ C == 0), so it is replaced by
    // whatever tryFoldToZero yields. An empty result is rejected by the
    // ShOp.getNode() test below.
6290     SDValue ShOp = N0.getOperand(1);
6291     if (LogicOpcode == ISD::XOR && !ShOp.isUndef())
6292       ShOp = tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
6293
6294     // (logic_op (shuf (A, C), shuf (B, C))) --> shuf (logic_op (A, B), C)
6295     if (N0.getOperand(1) == N1.getOperand(1) && ShOp.getNode()) {
6296       SDValue Logic = DAG.getNode(LogicOpcode, DL, VT,
6297                                   N0.getOperand(0), N1.getOperand(0));
6298       return DAG.getVectorShuffle(VT, DL, Logic, ShOp, SVN0->getMask());
6299     }
6300
6301     // Don't try to fold this node if it requires introducing a
6302     // build vector of all zeros that might be illegal at this stage.
    // Same XOR handling as above, now for a shared first operand.
6303     ShOp = N0.getOperand(0);
6304     if (LogicOpcode == ISD::XOR && !ShOp.isUndef())
6305       ShOp = tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
6306
6307     // (logic_op (shuf (C, A), shuf (C, B))) --> shuf (C, logic_op (A, B))
6308     if (N0.getOperand(0) == N1.getOperand(0) && ShOp.getNode()) {
6309       SDValue Logic = DAG.getNode(LogicOpcode, DL, VT, N0.getOperand(1),
6310                                   N1.getOperand(1));
6311       return DAG.getVectorShuffle(VT, DL, ShOp, Logic, SVN0->getMask());
6312     }
6313   }
6314
6315   return SDValue();
6316 }
6317 
6318 /// Try to make (and/or setcc (LL, LR), setcc (RL, RR)) more efficient.
/// IsAnd selects AND (true) or OR (false) as the combining logic op, and N0/N1
/// are its operands. Both must be accepted by isSetCCEquivalent. DL is the
/// location of the final setcc. Returns the replacement value, or an empty
/// SDValue if no fold applies.
foldLogicOfSetCCs(bool IsAnd,SDValue N0,SDValue N1,const SDLoc & DL)6319 SDValue DAGCombiner::foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1,
6320                                        const SDLoc &DL) {
  // Decompose both operands: LL/LR and RL/RR receive the compared values and
  // N0CC/N1CC the condition-code nodes.
6321   SDValue LL, LR, RL, RR, N0CC, N1CC;
6322   if (!isSetCCEquivalent(N0, LL, LR, N0CC) ||
6323       !isSetCCEquivalent(N1, RL, RR, N1CC))
6324     return SDValue();
6325
6326   assert(N0.getValueType() == N1.getValueType() &&
6327          "Unexpected operand types for bitwise logic op");
6328   assert(LL.getValueType() == LR.getValueType() &&
6329          RL.getValueType() == RR.getValueType() &&
6330          "Unexpected operand types for setcc");
6331
6332   // If we're here post-legalization or the logic op type is not i1, the logic
6333   // op type must match a setcc result type. Also, all folds require new
6334   // operations on the left and right operands, so those types must match.
6335   EVT VT = N0.getValueType();
6336   EVT OpVT = LL.getValueType();
6337   if (LegalOperations || VT.getScalarType() != MVT::i1)
6338     if (VT != getSetCCResultType(OpVT))
6339       return SDValue();
6340   if (OpVT != RL.getValueType())
6341     return SDValue();
6342
6343   ISD::CondCode CC0 = cast<CondCodeSDNode>(N0CC)->get();
6344   ISD::CondCode CC1 = cast<CondCodeSDNode>(N1CC)->get();
6345   bool IsInteger = OpVT.isInteger();
  // Both compares test against the same value LR with the same predicate.
  // When LR is 0 or -1, the two compares can be merged through an OR/AND of
  // the other operands.
6346   if (LR == RR && CC0 == CC1 && IsInteger) {
6347     bool IsZero = isNullOrNullSplat(LR);
6348     bool IsNeg1 = isAllOnesOrAllOnesSplat(LR);
6349
6350     // All bits clear?
6351     bool AndEqZero = IsAnd && CC1 == ISD::SETEQ && IsZero;
6352     // All sign bits clear?
6353     bool AndGtNeg1 = IsAnd && CC1 == ISD::SETGT && IsNeg1;
6354     // Any bits set?
6355     bool OrNeZero = !IsAnd && CC1 == ISD::SETNE && IsZero;
6356     // Any sign bits set?
6357     bool OrLtZero = !IsAnd && CC1 == ISD::SETLT && IsZero;
6358
6359     // (and (seteq X,  0), (seteq Y,  0)) --> (seteq (or X, Y),  0)
6360     // (and (setgt X, -1), (setgt Y, -1)) --> (setgt (or X, Y), -1)
6361     // (or  (setne X,  0), (setne Y,  0)) --> (setne (or X, Y),  0)
6362     // (or  (setlt X,  0), (setlt Y,  0)) --> (setlt (or X, Y),  0)
6363     if (AndEqZero || AndGtNeg1 || OrNeZero || OrLtZero) {
6364       SDValue Or = DAG.getNode(ISD::OR, SDLoc(N0), OpVT, LL, RL);
6365       AddToWorklist(Or.getNode());
6366       return DAG.getSetCC(DL, VT, Or, LR, CC1);
6367     }
6368
6369     // All bits set?
6370     bool AndEqNeg1 = IsAnd && CC1 == ISD::SETEQ && IsNeg1;
6371     // All sign bits set?
6372     bool AndLtZero = IsAnd && CC1 == ISD::SETLT && IsZero;
6373     // Any bits clear?
6374     bool OrNeNeg1 = !IsAnd && CC1 == ISD::SETNE && IsNeg1;
6375     // Any sign bits clear?
6376     bool OrGtNeg1 = !IsAnd && CC1 == ISD::SETGT && IsNeg1;
6377
6378     // (and (seteq X, -1), (seteq Y, -1)) --> (seteq (and X, Y), -1)
6379     // (and (setlt X,  0), (setlt Y,  0)) --> (setlt (and X, Y),  0)
6380     // (or  (setne X, -1), (setne Y, -1)) --> (setne (and X, Y), -1)
6381     // (or  (setgt X, -1), (setgt Y, -1)) --> (setgt (and X, Y), -1)
6382     if (AndEqNeg1 || AndLtZero || OrNeNeg1 || OrGtNeg1) {
6383       SDValue And = DAG.getNode(ISD::AND, SDLoc(N0), OpVT, LL, RL);
6384       AddToWorklist(And.getNode());
6385       return DAG.getSetCC(DL, VT, And, LR, CC1);
6386     }
6387   }
6388
6389   // TODO: What is the 'or' equivalent of this fold?
6390   // (and (setne X, 0), (setne X, -1)) --> (setuge (add X, 1), 2)
  // Adding 1 maps -1 to 0 and 0 to 1, so X is neither 0 nor -1 exactly when
  // X + 1 >= 2 as an unsigned value. Only scalar constants are matched here
  // (isNullConstant / isAllOnesConstant), and 1-bit types are excluded.
6391   if (IsAnd && LL == RL && CC0 == CC1 && OpVT.getScalarSizeInBits() > 1 &&
6392       IsInteger && CC0 == ISD::SETNE &&
6393       ((isNullConstant(LR) && isAllOnesConstant(RR)) ||
6394        (isAllOnesConstant(LR) && isNullConstant(RR)))) {
6395     SDValue One = DAG.getConstant(1, DL, OpVT);
6396     SDValue Two = DAG.getConstant(2, DL, OpVT);
6397     SDValue Add = DAG.getNode(ISD::ADD, SDLoc(N0), OpVT, LL, One);
6398     AddToWorklist(Add.getNode());
6399     return DAG.getSetCC(DL, VT, Add, Two, ISD::SETUGE);
6400   }
6401
6402   // Try more general transforms if the predicates match and the only user of
6403   // the compares is the 'and' or 'or'.
6404   if (IsInteger && TLI.convertSetCCLogicToBitwiseLogic(OpVT) && CC0 == CC1 &&
6405       N0.hasOneUse() && N1.hasOneUse()) {
6406     // and (seteq A, B), (seteq C, D) --> seteq (or (xor A, B), (xor C, D)), 0
6407     // or  (setne A, B), (setne C, D) --> setne (or (xor A, B), (xor C, D)), 0
6408     if ((IsAnd && CC1 == ISD::SETEQ) || (!IsAnd && CC1 == ISD::SETNE)) {
6409       SDValue XorL = DAG.getNode(ISD::XOR, SDLoc(N0), OpVT, LL, LR);
6410       SDValue XorR = DAG.getNode(ISD::XOR, SDLoc(N1), OpVT, RL, RR);
6411       SDValue Or = DAG.getNode(ISD::OR, DL, OpVT, XorL, XorR);
6412       SDValue Zero = DAG.getConstant(0, DL, OpVT);
6413       return DAG.getSetCC(DL, VT, Or, Zero, CC1);
6414     }
6415
6416     // Turn compare of constants whose difference is 1 bit into add+and+setcc.
6417     if ((IsAnd && CC1 == ISD::SETNE) || (!IsAnd && CC1 == ISD::SETEQ)) {
6418       // Match a shared variable operand and 2 non-opaque constant operands.
6419       auto MatchDiffPow2 = [&](ConstantSDNode *C0, ConstantSDNode *C1) {
6420         // The difference of the constants must be a single bit.
6421         const APInt &CMax =
6422             APIntOps::umax(C0->getAPIntValue(), C1->getAPIntValue());
6423         const APInt &CMin =
6424             APIntOps::umin(C0->getAPIntValue(), C1->getAPIntValue());
6425         return !C0->isOpaque() && !C1->isOpaque() && (CMax - CMin).isPowerOf2();
6426       };
6427       if (LL == RL && ISD::matchBinaryPredicate(LR, RR, MatchDiffPow2)) {
6428         // and/or (setcc X, CMax, ne/eq), (setcc X, CMin, ne/eq) -->
6429         // setcc ((sub X, CMin), ~(CMax - CMin)), 0, ne/eq
        // Max/Min/Diff are built as UMAX/UMIN/SUB nodes over LR and RR, so
        // this also works for per-element vector constants.
6430         SDValue Max = DAG.getNode(ISD::UMAX, DL, OpVT, LR, RR);
6431         SDValue Min = DAG.getNode(ISD::UMIN, DL, OpVT, LR, RR);
6432         SDValue Offset = DAG.getNode(ISD::SUB, DL, OpVT, LL, Min);
6433         SDValue Diff = DAG.getNode(ISD::SUB, DL, OpVT, Max, Min);
6434         SDValue Mask = DAG.getNOT(DL, Diff, OpVT);
6435         SDValue And = DAG.getNode(ISD::AND, DL, OpVT, Offset, Mask);
6436         SDValue Zero = DAG.getConstant(0, DL, OpVT);
6437         return DAG.getSetCC(DL, VT, And, Zero, CC0);
6438       }
6439     }
6440   }
6441
6442   // Canonicalize equivalent operands to LL == RL.
6443   if (LL == RR && LR == RL) {
6444     CC1 = ISD::getSetCCSwappedOperands(CC1);
6445     std::swap(RL, RR);
6446   }
6447
6448   // (and (setcc X, Y, CC0), (setcc X, Y, CC1)) --> (setcc X, Y, NewCC)
6449   // (or  (setcc X, Y, CC0), (setcc X, Y, CC1)) --> (setcc X, Y, NewCC)
  // After legalization the merged predicate and SETCC itself must be legal.
6450   if (LL == RL && LR == RR) {
6451     ISD::CondCode NewCC = IsAnd ? ISD::getSetCCAndOperation(CC0, CC1, OpVT)
6452                                 : ISD::getSetCCOrOperation(CC0, CC1, OpVT);
6453     if (NewCC != ISD::SETCC_INVALID &&
6454         (!LegalOperations ||
6455          (TLI.isCondCodeLegal(NewCC, LL.getSimpleValueType()) &&
6456           TLI.isOperationLegal(ISD::SETCC, OpVT))))
6457       return DAG.getSetCC(DL, VT, LL, LR, NewCC);
6458   }
6459
6460   return SDValue();
6461 }
6462 
arebothOperandsNotSNan(SDValue Operand1,SDValue Operand2,SelectionDAG & DAG)6463 static bool arebothOperandsNotSNan(SDValue Operand1, SDValue Operand2,
6464                                    SelectionDAG &DAG) {
6465   return DAG.isKnownNeverSNaN(Operand2) && DAG.isKnownNeverSNaN(Operand1);
6466 }
6467 
arebothOperandsNotNan(SDValue Operand1,SDValue Operand2,SelectionDAG & DAG)6468 static bool arebothOperandsNotNan(SDValue Operand1, SDValue Operand2,
6469                                   SelectionDAG &DAG) {
6470   return DAG.isKnownNeverNaN(Operand2) && DAG.isKnownNeverNaN(Operand1);
6471 }
6472 
6473 // FIXME: use FMINIMUMNUM if possible, such as for RISC-V.
getMinMaxOpcodeForFP(SDValue Operand1,SDValue Operand2,ISD::CondCode CC,unsigned OrAndOpcode,SelectionDAG & DAG,bool isFMAXNUMFMINNUM_IEEE,bool isFMAXNUMFMINNUM)6474 static unsigned getMinMaxOpcodeForFP(SDValue Operand1, SDValue Operand2,
6475                                      ISD::CondCode CC, unsigned OrAndOpcode,
6476                                      SelectionDAG &DAG,
6477                                      bool isFMAXNUMFMINNUM_IEEE,
6478                                      bool isFMAXNUMFMINNUM) {
6479   // The optimization cannot be applied for all the predicates because
6480   // of the way FMINNUM/FMAXNUM and FMINNUM_IEEE/FMAXNUM_IEEE handle
6481   // NaNs. For FMINNUM_IEEE/FMAXNUM_IEEE, the optimization cannot be
6482   // applied at all if one of the operands is a signaling NaN.
6483 
6484   // It is safe to use FMINNUM_IEEE/FMAXNUM_IEEE if all the operands
6485   // are non NaN values.
6486   if (((CC == ISD::SETLT || CC == ISD::SETLE) && (OrAndOpcode == ISD::OR)) ||
6487       ((CC == ISD::SETGT || CC == ISD::SETGE) && (OrAndOpcode == ISD::AND)))
6488     return arebothOperandsNotNan(Operand1, Operand2, DAG) &&
6489                    isFMAXNUMFMINNUM_IEEE
6490                ? ISD::FMINNUM_IEEE
6491                : ISD::DELETED_NODE;
6492   else if (((CC == ISD::SETGT || CC == ISD::SETGE) &&
6493             (OrAndOpcode == ISD::OR)) ||
6494            ((CC == ISD::SETLT || CC == ISD::SETLE) &&
6495             (OrAndOpcode == ISD::AND)))
6496     return arebothOperandsNotNan(Operand1, Operand2, DAG) &&
6497                    isFMAXNUMFMINNUM_IEEE
6498                ? ISD::FMAXNUM_IEEE
6499                : ISD::DELETED_NODE;
6500   // Both FMINNUM/FMAXNUM and FMINNUM_IEEE/FMAXNUM_IEEE handle quiet
6501   // NaNs in the same way. But, FMINNUM/FMAXNUM and FMINNUM_IEEE/
6502   // FMAXNUM_IEEE handle signaling NaNs differently. If we cannot prove
6503   // that there are not any sNaNs, then the optimization is not valid
6504   // for FMINNUM_IEEE/FMAXNUM_IEEE. In the presence of sNaNs, we apply
6505   // the optimization using FMINNUM/FMAXNUM for the following cases. If
6506   // we can prove that we do not have any sNaNs, then we can do the
6507   // optimization using FMINNUM_IEEE/FMAXNUM_IEEE for the following
6508   // cases.
6509   else if (((CC == ISD::SETOLT || CC == ISD::SETOLE) &&
6510             (OrAndOpcode == ISD::OR)) ||
6511            ((CC == ISD::SETUGT || CC == ISD::SETUGE) &&
6512             (OrAndOpcode == ISD::AND)))
6513     return isFMAXNUMFMINNUM ? ISD::FMINNUM
6514                             : arebothOperandsNotSNan(Operand1, Operand2, DAG) &&
6515                                       isFMAXNUMFMINNUM_IEEE
6516                                   ? ISD::FMINNUM_IEEE
6517                                   : ISD::DELETED_NODE;
6518   else if (((CC == ISD::SETOGT || CC == ISD::SETOGE) &&
6519             (OrAndOpcode == ISD::OR)) ||
6520            ((CC == ISD::SETULT || CC == ISD::SETULE) &&
6521             (OrAndOpcode == ISD::AND)))
6522     return isFMAXNUMFMINNUM ? ISD::FMAXNUM
6523                             : arebothOperandsNotSNan(Operand1, Operand2, DAG) &&
6524                                       isFMAXNUMFMINNUM_IEEE
6525                                   ? ISD::FMAXNUM_IEEE
6526                                   : ISD::DELETED_NODE;
6527   return ISD::DELETED_NODE;
6528 }
6529 
foldAndOrOfSETCC(SDNode * LogicOp,SelectionDAG & DAG)6530 static SDValue foldAndOrOfSETCC(SDNode *LogicOp, SelectionDAG &DAG) {
6531   using AndOrSETCCFoldKind = TargetLowering::AndOrSETCCFoldKind;
6532   assert(
6533       (LogicOp->getOpcode() == ISD::AND || LogicOp->getOpcode() == ISD::OR) &&
6534       "Invalid Op to combine SETCC with");
6535 
6536   // TODO: Search past casts/truncates.
6537   SDValue LHS = LogicOp->getOperand(0);
6538   SDValue RHS = LogicOp->getOperand(1);
6539   if (LHS->getOpcode() != ISD::SETCC || RHS->getOpcode() != ISD::SETCC ||
6540       !LHS->hasOneUse() || !RHS->hasOneUse())
6541     return SDValue();
6542 
6543   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
6544   AndOrSETCCFoldKind TargetPreference = TLI.isDesirableToCombineLogicOpOfSETCC(
6545       LogicOp, LHS.getNode(), RHS.getNode());
6546 
6547   SDValue LHS0 = LHS->getOperand(0);
6548   SDValue RHS0 = RHS->getOperand(0);
6549   SDValue LHS1 = LHS->getOperand(1);
6550   SDValue RHS1 = RHS->getOperand(1);
6551   // TODO: We don't actually need a splat here, for vectors we just need the
6552   // invariants to hold for each element.
6553   auto *LHS1C = isConstOrConstSplat(LHS1);
6554   auto *RHS1C = isConstOrConstSplat(RHS1);
6555   ISD::CondCode CCL = cast<CondCodeSDNode>(LHS.getOperand(2))->get();
6556   ISD::CondCode CCR = cast<CondCodeSDNode>(RHS.getOperand(2))->get();
6557   EVT VT = LogicOp->getValueType(0);
6558   EVT OpVT = LHS0.getValueType();
6559   SDLoc DL(LogicOp);
6560 
6561   // Check if the operands of an and/or operation are comparisons and if they
6562   // compare against the same value. Replace the and/or-cmp-cmp sequence with
6563   // min/max cmp sequence. If LHS1 is equal to RHS1, then the or-cmp-cmp
6564   // sequence will be replaced with min-cmp sequence:
6565   // (LHS0 < LHS1) | (RHS0 < RHS1) -> min(LHS0, RHS0) < LHS1
6566   // and and-cmp-cmp will be replaced with max-cmp sequence:
6567   // (LHS0 < LHS1) & (RHS0 < RHS1) -> max(LHS0, RHS0) < LHS1
6568   // The optimization does not work for `==` or `!=` .
6569   // The two comparisons should have either the same predicate or the
6570   // predicate of one of the comparisons is the opposite of the other one.
6571   bool isFMAXNUMFMINNUM_IEEE = TLI.isOperationLegal(ISD::FMAXNUM_IEEE, OpVT) &&
6572                                TLI.isOperationLegal(ISD::FMINNUM_IEEE, OpVT);
6573   bool isFMAXNUMFMINNUM = TLI.isOperationLegalOrCustom(ISD::FMAXNUM, OpVT) &&
6574                           TLI.isOperationLegalOrCustom(ISD::FMINNUM, OpVT);
6575   if (((OpVT.isInteger() && TLI.isOperationLegal(ISD::UMAX, OpVT) &&
6576         TLI.isOperationLegal(ISD::SMAX, OpVT) &&
6577         TLI.isOperationLegal(ISD::UMIN, OpVT) &&
6578         TLI.isOperationLegal(ISD::SMIN, OpVT)) ||
6579        (OpVT.isFloatingPoint() &&
6580         (isFMAXNUMFMINNUM_IEEE || isFMAXNUMFMINNUM))) &&
6581       !ISD::isIntEqualitySetCC(CCL) && !ISD::isFPEqualitySetCC(CCL) &&
6582       CCL != ISD::SETFALSE && CCL != ISD::SETO && CCL != ISD::SETUO &&
6583       CCL != ISD::SETTRUE &&
6584       (CCL == CCR || CCL == ISD::getSetCCSwappedOperands(CCR))) {
6585 
6586     SDValue CommonValue, Operand1, Operand2;
6587     ISD::CondCode CC = ISD::SETCC_INVALID;
6588     if (CCL == CCR) {
6589       if (LHS0 == RHS0) {
6590         CommonValue = LHS0;
6591         Operand1 = LHS1;
6592         Operand2 = RHS1;
6593         CC = ISD::getSetCCSwappedOperands(CCL);
6594       } else if (LHS1 == RHS1) {
6595         CommonValue = LHS1;
6596         Operand1 = LHS0;
6597         Operand2 = RHS0;
6598         CC = CCL;
6599       }
6600     } else {
6601       assert(CCL == ISD::getSetCCSwappedOperands(CCR) && "Unexpected CC");
6602       if (LHS0 == RHS1) {
6603         CommonValue = LHS0;
6604         Operand1 = LHS1;
6605         Operand2 = RHS0;
6606         CC = CCR;
6607       } else if (RHS0 == LHS1) {
6608         CommonValue = LHS1;
6609         Operand1 = LHS0;
6610         Operand2 = RHS1;
6611         CC = CCL;
6612       }
6613     }
6614 
6615     // Don't do this transform for sign bit tests. Let foldLogicOfSetCCs
6616     // handle it using OR/AND.
6617     if (CC == ISD::SETLT && isNullOrNullSplat(CommonValue))
6618       CC = ISD::SETCC_INVALID;
6619     else if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(CommonValue))
6620       CC = ISD::SETCC_INVALID;
6621 
6622     if (CC != ISD::SETCC_INVALID) {
6623       unsigned NewOpcode = ISD::DELETED_NODE;
6624       bool IsSigned = isSignedIntSetCC(CC);
6625       if (OpVT.isInteger()) {
6626         bool IsLess = (CC == ISD::SETLE || CC == ISD::SETULE ||
6627                        CC == ISD::SETLT || CC == ISD::SETULT);
6628         bool IsOr = (LogicOp->getOpcode() == ISD::OR);
6629         if (IsLess == IsOr)
6630           NewOpcode = IsSigned ? ISD::SMIN : ISD::UMIN;
6631         else
6632           NewOpcode = IsSigned ? ISD::SMAX : ISD::UMAX;
6633       } else if (OpVT.isFloatingPoint())
6634         NewOpcode =
6635             getMinMaxOpcodeForFP(Operand1, Operand2, CC, LogicOp->getOpcode(),
6636                                  DAG, isFMAXNUMFMINNUM_IEEE, isFMAXNUMFMINNUM);
6637 
6638       if (NewOpcode != ISD::DELETED_NODE) {
6639         SDValue MinMaxValue =
6640             DAG.getNode(NewOpcode, DL, OpVT, Operand1, Operand2);
6641         return DAG.getSetCC(DL, VT, MinMaxValue, CommonValue, CC);
6642       }
6643     }
6644   }
6645 
6646   if (LHS0 == LHS1 && RHS0 == RHS1 && CCL == CCR &&
6647       LHS0.getValueType() == RHS0.getValueType() &&
6648       ((LogicOp->getOpcode() == ISD::AND && CCL == ISD::SETO) ||
6649        (LogicOp->getOpcode() == ISD::OR && CCL == ISD::SETUO)))
6650     return DAG.getSetCC(DL, VT, LHS0, RHS0, CCL);
6651 
6652   if (TargetPreference == AndOrSETCCFoldKind::None)
6653     return SDValue();
6654 
6655   if (CCL == CCR &&
6656       CCL == (LogicOp->getOpcode() == ISD::AND ? ISD::SETNE : ISD::SETEQ) &&
6657       LHS0 == RHS0 && LHS1C && RHS1C && OpVT.isInteger()) {
6658     const APInt &APLhs = LHS1C->getAPIntValue();
6659     const APInt &APRhs = RHS1C->getAPIntValue();
6660 
6661     // Preference is to use ISD::ABS or we already have an ISD::ABS (in which
6662     // case this is just a compare).
6663     if (APLhs == (-APRhs) &&
6664         ((TargetPreference & AndOrSETCCFoldKind::ABS) ||
6665          DAG.doesNodeExist(ISD::ABS, DAG.getVTList(OpVT), {LHS0}))) {
6666       const APInt &C = APLhs.isNegative() ? APRhs : APLhs;
6667       // (icmp eq A, C) | (icmp eq A, -C)
6668       //    -> (icmp eq Abs(A), C)
6669       // (icmp ne A, C) & (icmp ne A, -C)
6670       //    -> (icmp ne Abs(A), C)
6671       SDValue AbsOp = DAG.getNode(ISD::ABS, DL, OpVT, LHS0);
6672       return DAG.getNode(ISD::SETCC, DL, VT, AbsOp,
6673                          DAG.getConstant(C, DL, OpVT), LHS.getOperand(2));
6674     } else if (TargetPreference &
6675                (AndOrSETCCFoldKind::AddAnd | AndOrSETCCFoldKind::NotAnd)) {
6676 
6677       // AndOrSETCCFoldKind::AddAnd:
6678       // A == C0 | A == C1
6679       //  IF IsPow2(smax(C0, C1)-smin(C0, C1))
6680       //    -> ((A - smin(C0, C1)) & ~(smax(C0, C1)-smin(C0, C1))) == 0
6681       // A != C0 & A != C1
6682       //  IF IsPow2(smax(C0, C1)-smin(C0, C1))
6683       //    -> ((A - smin(C0, C1)) & ~(smax(C0, C1)-smin(C0, C1))) != 0
6684 
6685       // AndOrSETCCFoldKind::NotAnd:
6686       // A == C0 | A == C1
6687       //  IF smax(C0, C1) == -1 AND IsPow2(smax(C0, C1) - smin(C0, C1))
6688       //    -> ~A & smin(C0, C1) == 0
6689       // A != C0 & A != C1
6690       //  IF smax(C0, C1) == -1 AND IsPow2(smax(C0, C1) - smin(C0, C1))
6691       //    -> ~A & smin(C0, C1) != 0
6692 
6693       const APInt &MaxC = APIntOps::smax(APRhs, APLhs);
6694       const APInt &MinC = APIntOps::smin(APRhs, APLhs);
6695       APInt Dif = MaxC - MinC;
6696       if (!Dif.isZero() && Dif.isPowerOf2()) {
6697         if (MaxC.isAllOnes() &&
6698             (TargetPreference & AndOrSETCCFoldKind::NotAnd)) {
6699           SDValue NotOp = DAG.getNOT(DL, LHS0, OpVT);
6700           SDValue AndOp = DAG.getNode(ISD::AND, DL, OpVT, NotOp,
6701                                       DAG.getConstant(MinC, DL, OpVT));
6702           return DAG.getNode(ISD::SETCC, DL, VT, AndOp,
6703                              DAG.getConstant(0, DL, OpVT), LHS.getOperand(2));
6704         } else if (TargetPreference & AndOrSETCCFoldKind::AddAnd) {
6705 
6706           SDValue AddOp = DAG.getNode(ISD::ADD, DL, OpVT, LHS0,
6707                                       DAG.getConstant(-MinC, DL, OpVT));
6708           SDValue AndOp = DAG.getNode(ISD::AND, DL, OpVT, AddOp,
6709                                       DAG.getConstant(~Dif, DL, OpVT));
6710           return DAG.getNode(ISD::SETCC, DL, VT, AndOp,
6711                              DAG.getConstant(0, DL, OpVT), LHS.getOperand(2));
6712         }
6713       }
6714     }
6715   }
6716 
6717   return SDValue();
6718 }
6719 
6720 // Combine `(select c, (X & 1), 0)` -> `(and (zext c), X)`.
6721 // We canonicalize to the `select` form in the middle end, but the `and` form
6722 // gets better codegen and all tested targets (arm, x86, riscv)
combineSelectAsExtAnd(SDValue Cond,SDValue T,SDValue F,const SDLoc & DL,SelectionDAG & DAG)6723 static SDValue combineSelectAsExtAnd(SDValue Cond, SDValue T, SDValue F,
6724                                      const SDLoc &DL, SelectionDAG &DAG) {
6725   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
6726   if (!isNullConstant(F))
6727     return SDValue();
6728 
6729   EVT CondVT = Cond.getValueType();
6730   if (TLI.getBooleanContents(CondVT) !=
6731       TargetLoweringBase::ZeroOrOneBooleanContent)
6732     return SDValue();
6733 
6734   if (T.getOpcode() != ISD::AND)
6735     return SDValue();
6736 
6737   if (!isOneConstant(T.getOperand(1)))
6738     return SDValue();
6739 
6740   EVT OpVT = T.getValueType();
6741 
6742   SDValue CondMask =
6743       OpVT == CondVT ? Cond : DAG.getBoolExtOrTrunc(Cond, DL, OpVT, CondVT);
6744   return DAG.getNode(ISD::AND, DL, OpVT, CondMask, T.getOperand(0));
6745 }
6746 
6747 /// This contains all DAGCombine rules which reduce two values combined by
6748 /// an And operation to a single value. This makes them reusable in the context
6749 /// of visitSELECT(). Rules involving constants are not included as
6750 /// visitSELECT() already handles those cases.
visitANDLike(SDValue N0,SDValue N1,SDNode * N)6751 SDValue DAGCombiner::visitANDLike(SDValue N0, SDValue N1, SDNode *N) {
6752   EVT VT = N1.getValueType();
6753   SDLoc DL(N);
6754 
6755   // fold (and x, undef) -> 0
6756   if (N0.isUndef() || N1.isUndef())
6757     return DAG.getConstant(0, DL, VT);
6758 
6759   if (SDValue V = foldLogicOfSetCCs(true, N0, N1, DL))
6760     return V;
6761 
6762   // Canonicalize:
6763   //   and(x, add) -> and(add, x)
6764   if (N1.getOpcode() == ISD::ADD)
6765     std::swap(N0, N1);
6766 
6767   // TODO: Rewrite this to return a new 'AND' instead of using CombineTo.
6768   if (N0.getOpcode() == ISD::ADD && N1.getOpcode() == ISD::SRL &&
6769       VT.isScalarInteger() && VT.getSizeInBits() <= 64 && N0->hasOneUse()) {
6770     if (ConstantSDNode *ADDI = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
6771       if (ConstantSDNode *SRLI = dyn_cast<ConstantSDNode>(N1.getOperand(1))) {
6772         // Look for (and (add x, c1), (lshr y, c2)). If C1 wasn't a legal
6773         // immediate for an add, but it is legal if its top c2 bits are set,
6774         // transform the ADD so the immediate doesn't need to be materialized
6775         // in a register.
6776         APInt ADDC = ADDI->getAPIntValue();
6777         APInt SRLC = SRLI->getAPIntValue();
6778         if (ADDC.getSignificantBits() <= 64 && SRLC.ult(VT.getSizeInBits()) &&
6779             !TLI.isLegalAddImmediate(ADDC.getSExtValue())) {
6780           APInt Mask = APInt::getHighBitsSet(VT.getSizeInBits(),
6781                                              SRLC.getZExtValue());
6782           if (DAG.MaskedValueIsZero(N0.getOperand(1), Mask)) {
6783             ADDC |= Mask;
6784             if (TLI.isLegalAddImmediate(ADDC.getSExtValue())) {
6785               SDLoc DL0(N0);
6786               SDValue NewAdd =
6787                 DAG.getNode(ISD::ADD, DL0, VT,
6788                             N0.getOperand(0), DAG.getConstant(ADDC, DL, VT));
6789               CombineTo(N0.getNode(), NewAdd);
6790               // Return N so it doesn't get rechecked!
6791               return SDValue(N, 0);
6792             }
6793           }
6794         }
6795       }
6796     }
6797   }
6798 
6799   return SDValue();
6800 }
6801 
isAndLoadExtLoad(ConstantSDNode * AndC,LoadSDNode * LoadN,EVT LoadResultTy,EVT & ExtVT)6802 bool DAGCombiner::isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
6803                                    EVT LoadResultTy, EVT &ExtVT) {
6804   if (!AndC->getAPIntValue().isMask())
6805     return false;
6806 
6807   unsigned ActiveBits = AndC->getAPIntValue().countr_one();
6808 
6809   ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
6810   EVT LoadedVT = LoadN->getMemoryVT();
6811 
6812   if (ExtVT == LoadedVT &&
6813       (!LegalOperations ||
6814        TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT))) {
6815     // ZEXTLOAD will match without needing to change the size of the value being
6816     // loaded.
6817     return true;
6818   }
6819 
6820   // Do not change the width of a volatile or atomic loads.
6821   if (!LoadN->isSimple())
6822     return false;
6823 
6824   // Do not generate loads of non-round integer types since these can
6825   // be expensive (and would be wrong if the type is not byte sized).
6826   if (!LoadedVT.bitsGT(ExtVT) || !ExtVT.isRound())
6827     return false;
6828 
6829   if (LegalOperations &&
6830       !TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT))
6831     return false;
6832 
6833   if (!TLI.shouldReduceLoadWidth(LoadN, ISD::ZEXTLOAD, ExtVT, /*ByteOffset=*/0))
6834     return false;
6835 
6836   return true;
6837 }
6838 
isLegalNarrowLdSt(LSBaseSDNode * LDST,ISD::LoadExtType ExtType,EVT & MemVT,unsigned ShAmt)6839 bool DAGCombiner::isLegalNarrowLdSt(LSBaseSDNode *LDST,
6840                                     ISD::LoadExtType ExtType, EVT &MemVT,
6841                                     unsigned ShAmt) {
6842   if (!LDST)
6843     return false;
6844 
6845   // Only allow byte offsets.
6846   if (ShAmt % 8)
6847     return false;
6848   const unsigned ByteShAmt = ShAmt / 8;
6849 
6850   // Do not generate loads of non-round integer types since these can
6851   // be expensive (and would be wrong if the type is not byte sized).
6852   if (!MemVT.isRound())
6853     return false;
6854 
6855   // Don't change the width of a volatile or atomic loads.
6856   if (!LDST->isSimple())
6857     return false;
6858 
6859   EVT LdStMemVT = LDST->getMemoryVT();
6860 
6861   // Bail out when changing the scalable property, since we can't be sure that
6862   // we're actually narrowing here.
6863   if (LdStMemVT.isScalableVector() != MemVT.isScalableVector())
6864     return false;
6865 
6866   // Verify that we are actually reducing a load width here.
6867   if (LdStMemVT.bitsLT(MemVT))
6868     return false;
6869 
6870   // Ensure that this isn't going to produce an unsupported memory access.
6871   if (ShAmt) {
6872     const Align LDSTAlign = LDST->getAlign();
6873     const Align NarrowAlign = commonAlignment(LDSTAlign, ByteShAmt);
6874     if (!TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), MemVT,
6875                                 LDST->getAddressSpace(), NarrowAlign,
6876                                 LDST->getMemOperand()->getFlags()))
6877       return false;
6878   }
6879 
6880   // It's not possible to generate a constant of extended or untyped type.
6881   EVT PtrType = LDST->getBasePtr().getValueType();
6882   if (PtrType == MVT::Untyped || PtrType.isExtended())
6883     return false;
6884 
6885   if (isa<LoadSDNode>(LDST)) {
6886     LoadSDNode *Load = cast<LoadSDNode>(LDST);
6887     // Don't transform one with multiple uses, this would require adding a new
6888     // load.
6889     if (!SDValue(Load, 0).hasOneUse())
6890       return false;
6891 
6892     if (LegalOperations &&
6893         !TLI.isLoadExtLegal(ExtType, Load->getValueType(0), MemVT))
6894       return false;
6895 
6896     // For the transform to be legal, the load must produce only two values
6897     // (the value loaded and the chain).  Don't transform a pre-increment
6898     // load, for example, which produces an extra value.  Otherwise the
6899     // transformation is not equivalent, and the downstream logic to replace
6900     // uses gets things wrong.
6901     if (Load->getNumValues() > 2)
6902       return false;
6903 
6904     // If the load that we're shrinking is an extload and we're not just
6905     // discarding the extension we can't simply shrink the load. Bail.
6906     // TODO: It would be possible to merge the extensions in some cases.
6907     if (Load->getExtensionType() != ISD::NON_EXTLOAD &&
6908         Load->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits() + ShAmt)
6909       return false;
6910 
6911     if (!TLI.shouldReduceLoadWidth(Load, ExtType, MemVT, ByteShAmt))
6912       return false;
6913   } else {
6914     assert(isa<StoreSDNode>(LDST) && "It is not a Load nor a Store SDNode");
6915     StoreSDNode *Store = cast<StoreSDNode>(LDST);
6916     // Can't write outside the original store
6917     if (Store->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits() + ShAmt)
6918       return false;
6919 
6920     if (LegalOperations &&
6921         !TLI.isTruncStoreLegal(Store->getValue().getValueType(), MemVT))
6922       return false;
6923   }
6924   return true;
6925 }
6926 
SearchForAndLoads(SDNode * N,SmallVectorImpl<LoadSDNode * > & Loads,SmallPtrSetImpl<SDNode * > & NodesWithConsts,ConstantSDNode * Mask,SDNode * & NodeToMask)6927 bool DAGCombiner::SearchForAndLoads(SDNode *N,
6928                                     SmallVectorImpl<LoadSDNode*> &Loads,
6929                                     SmallPtrSetImpl<SDNode*> &NodesWithConsts,
6930                                     ConstantSDNode *Mask,
6931                                     SDNode *&NodeToMask) {
6932   // Recursively search for the operands, looking for loads which can be
6933   // narrowed.
6934   for (SDValue Op : N->op_values()) {
6935     if (Op.getValueType().isVector())
6936       return false;
6937 
6938     // Some constants may need fixing up later if they are too large.
6939     if (auto *C = dyn_cast<ConstantSDNode>(Op)) {
6940       assert(ISD::isBitwiseLogicOp(N->getOpcode()) &&
6941              "Expected bitwise logic operation");
6942       if (!C->getAPIntValue().isSubsetOf(Mask->getAPIntValue()))
6943         NodesWithConsts.insert(N);
6944       continue;
6945     }
6946 
6947     if (!Op.hasOneUse())
6948       return false;
6949 
6950     switch(Op.getOpcode()) {
6951     case ISD::LOAD: {
6952       auto *Load = cast<LoadSDNode>(Op);
6953       EVT ExtVT;
6954       if (isAndLoadExtLoad(Mask, Load, Load->getValueType(0), ExtVT) &&
6955           isLegalNarrowLdSt(Load, ISD::ZEXTLOAD, ExtVT)) {
6956 
6957         // ZEXTLOAD is already small enough.
6958         if (Load->getExtensionType() == ISD::ZEXTLOAD &&
6959             ExtVT.bitsGE(Load->getMemoryVT()))
6960           continue;
6961 
6962         // Use LE to convert equal sized loads to zext.
6963         if (ExtVT.bitsLE(Load->getMemoryVT()))
6964           Loads.push_back(Load);
6965 
6966         continue;
6967       }
6968       return false;
6969     }
6970     case ISD::ZERO_EXTEND:
6971     case ISD::AssertZext: {
6972       unsigned ActiveBits = Mask->getAPIntValue().countr_one();
6973       EVT ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
6974       EVT VT = Op.getOpcode() == ISD::AssertZext ?
6975         cast<VTSDNode>(Op.getOperand(1))->getVT() :
6976         Op.getOperand(0).getValueType();
6977 
6978       // We can accept extending nodes if the mask is wider or an equal
6979       // width to the original type.
6980       if (ExtVT.bitsGE(VT))
6981         continue;
6982       break;
6983     }
6984     case ISD::OR:
6985     case ISD::XOR:
6986     case ISD::AND:
6987       if (!SearchForAndLoads(Op.getNode(), Loads, NodesWithConsts, Mask,
6988                              NodeToMask))
6989         return false;
6990       continue;
6991     }
6992 
6993     // Allow one node which will masked along with any loads found.
6994     if (NodeToMask)
6995       return false;
6996 
6997     // Also ensure that the node to be masked only produces one data result.
6998     NodeToMask = Op.getNode();
6999     if (NodeToMask->getNumValues() > 1) {
7000       bool HasValue = false;
7001       for (unsigned i = 0, e = NodeToMask->getNumValues(); i < e; ++i) {
7002         MVT VT = SDValue(NodeToMask, i).getSimpleValueType();
7003         if (VT != MVT::Glue && VT != MVT::Other) {
7004           if (HasValue) {
7005             NodeToMask = nullptr;
7006             return false;
7007           }
7008           HasValue = true;
7009         }
7010       }
7011       assert(HasValue && "Node to be masked has no data result?");
7012     }
7013   }
7014   return true;
7015 }
7016 
BackwardsPropagateMask(SDNode * N)7017 bool DAGCombiner::BackwardsPropagateMask(SDNode *N) {
7018   auto *Mask = dyn_cast<ConstantSDNode>(N->getOperand(1));
7019   if (!Mask)
7020     return false;
7021 
7022   if (!Mask->getAPIntValue().isMask())
7023     return false;
7024 
7025   // No need to do anything if the and directly uses a load.
7026   if (isa<LoadSDNode>(N->getOperand(0)))
7027     return false;
7028 
7029   SmallVector<LoadSDNode*, 8> Loads;
7030   SmallPtrSet<SDNode*, 2> NodesWithConsts;
7031   SDNode *FixupNode = nullptr;
7032   if (SearchForAndLoads(N, Loads, NodesWithConsts, Mask, FixupNode)) {
7033     if (Loads.empty())
7034       return false;
7035 
7036     LLVM_DEBUG(dbgs() << "Backwards propagate AND: "; N->dump());
7037     SDValue MaskOp = N->getOperand(1);
7038 
7039     // If it exists, fixup the single node we allow in the tree that needs
7040     // masking.
7041     if (FixupNode) {
7042       LLVM_DEBUG(dbgs() << "First, need to fix up: "; FixupNode->dump());
7043       SDValue And = DAG.getNode(ISD::AND, SDLoc(FixupNode),
7044                                 FixupNode->getValueType(0),
7045                                 SDValue(FixupNode, 0), MaskOp);
7046       DAG.ReplaceAllUsesOfValueWith(SDValue(FixupNode, 0), And);
7047       if (And.getOpcode() == ISD ::AND)
7048         DAG.UpdateNodeOperands(And.getNode(), SDValue(FixupNode, 0), MaskOp);
7049     }
7050 
7051     // Narrow any constants that need it.
7052     for (auto *LogicN : NodesWithConsts) {
7053       SDValue Op0 = LogicN->getOperand(0);
7054       SDValue Op1 = LogicN->getOperand(1);
7055 
7056       // We only need to fix AND if both inputs are constants. And we only need
7057       // to fix one of the constants.
7058       if (LogicN->getOpcode() == ISD::AND &&
7059           (!isa<ConstantSDNode>(Op0) || !isa<ConstantSDNode>(Op1)))
7060         continue;
7061 
7062       if (isa<ConstantSDNode>(Op0) && LogicN->getOpcode() != ISD::AND)
7063         Op0 =
7064             DAG.getNode(ISD::AND, SDLoc(Op0), Op0.getValueType(), Op0, MaskOp);
7065 
7066       if (isa<ConstantSDNode>(Op1))
7067         Op1 =
7068             DAG.getNode(ISD::AND, SDLoc(Op1), Op1.getValueType(), Op1, MaskOp);
7069 
7070       if (isa<ConstantSDNode>(Op0) && !isa<ConstantSDNode>(Op1))
7071         std::swap(Op0, Op1);
7072 
7073       DAG.UpdateNodeOperands(LogicN, Op0, Op1);
7074     }
7075 
7076     // Create narrow loads.
7077     for (auto *Load : Loads) {
7078       LLVM_DEBUG(dbgs() << "Propagate AND back to: "; Load->dump());
7079       SDValue And = DAG.getNode(ISD::AND, SDLoc(Load), Load->getValueType(0),
7080                                 SDValue(Load, 0), MaskOp);
7081       DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 0), And);
7082       if (And.getOpcode() == ISD ::AND)
7083         And = SDValue(
7084             DAG.UpdateNodeOperands(And.getNode(), SDValue(Load, 0), MaskOp), 0);
7085       SDValue NewLoad = reduceLoadWidth(And.getNode());
7086       assert(NewLoad &&
7087              "Shouldn't be masking the load if it can't be narrowed");
7088       CombineTo(Load, NewLoad, NewLoad.getValue(1));
7089     }
7090     DAG.ReplaceAllUsesWith(N, N->getOperand(0).getNode());
7091     return true;
7092   }
7093   return false;
7094 }
7095 
7096 // Unfold
7097 //    x &  (-1 'logical shift' y)
7098 // To
7099 //    (x 'opposite logical shift' y) 'logical shift' y
7100 // if it is better for performance.
unfoldExtremeBitClearingToShifts(SDNode * N)7101 SDValue DAGCombiner::unfoldExtremeBitClearingToShifts(SDNode *N) {
7102   assert(N->getOpcode() == ISD::AND);
7103 
7104   SDValue N0 = N->getOperand(0);
7105   SDValue N1 = N->getOperand(1);
7106 
7107   // Do we actually prefer shifts over mask?
7108   if (!TLI.shouldFoldMaskToVariableShiftPair(N0))
7109     return SDValue();
7110 
7111   // Try to match  (-1 '[outer] logical shift' y)
7112   unsigned OuterShift;
7113   unsigned InnerShift; // The opposite direction to the OuterShift.
7114   SDValue Y;           // Shift amount.
7115   auto matchMask = [&OuterShift, &InnerShift, &Y](SDValue M) -> bool {
7116     if (!M.hasOneUse())
7117       return false;
7118     OuterShift = M->getOpcode();
7119     if (OuterShift == ISD::SHL)
7120       InnerShift = ISD::SRL;
7121     else if (OuterShift == ISD::SRL)
7122       InnerShift = ISD::SHL;
7123     else
7124       return false;
7125     if (!isAllOnesConstant(M->getOperand(0)))
7126       return false;
7127     Y = M->getOperand(1);
7128     return true;
7129   };
7130 
7131   SDValue X;
7132   if (matchMask(N1))
7133     X = N0;
7134   else if (matchMask(N0))
7135     X = N1;
7136   else
7137     return SDValue();
7138 
7139   SDLoc DL(N);
7140   EVT VT = N->getValueType(0);
7141 
7142   //     tmp = x   'opposite logical shift' y
7143   SDValue T0 = DAG.getNode(InnerShift, DL, VT, X, Y);
7144   //     ret = tmp 'logical shift' y
7145   SDValue T1 = DAG.getNode(OuterShift, DL, VT, T0, Y);
7146 
7147   return T1;
7148 }
7149 
7150 /// Try to replace shift/logic that tests if a bit is clear with mask + setcc.
7151 /// For a target with a bit test, this is expected to become test + set and save
7152 /// at least 1 instruction.
combineShiftAnd1ToBitTest(SDNode * And,SelectionDAG & DAG)7153 static SDValue combineShiftAnd1ToBitTest(SDNode *And, SelectionDAG &DAG) {
7154   assert(And->getOpcode() == ISD::AND && "Expected an 'and' op");
7155 
7156   // Look through an optional extension.
7157   SDValue And0 = And->getOperand(0), And1 = And->getOperand(1);
7158   if (And0.getOpcode() == ISD::ANY_EXTEND && And0.hasOneUse())
7159     And0 = And0.getOperand(0);
7160   if (!isOneConstant(And1) || !And0.hasOneUse())
7161     return SDValue();
7162 
7163   SDValue Src = And0;
7164 
7165   // Attempt to find a 'not' op.
7166   // TODO: Should we favor test+set even without the 'not' op?
7167   bool FoundNot = false;
7168   if (isBitwiseNot(Src)) {
7169     FoundNot = true;
7170     Src = Src.getOperand(0);
7171 
7172     // Look though an optional truncation. The source operand may not be the
7173     // same type as the original 'and', but that is ok because we are masking
7174     // off everything but the low bit.
7175     if (Src.getOpcode() == ISD::TRUNCATE && Src.hasOneUse())
7176       Src = Src.getOperand(0);
7177   }
7178 
7179   // Match a shift-right by constant.
7180   if (Src.getOpcode() != ISD::SRL || !Src.hasOneUse())
7181     return SDValue();
7182 
7183   // This is probably not worthwhile without a supported type.
7184   EVT SrcVT = Src.getValueType();
7185   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
7186   if (!TLI.isTypeLegal(SrcVT))
7187     return SDValue();
7188 
7189   // We might have looked through casts that make this transform invalid.
7190   unsigned BitWidth = SrcVT.getScalarSizeInBits();
7191   SDValue ShiftAmt = Src.getOperand(1);
7192   auto *ShiftAmtC = dyn_cast<ConstantSDNode>(ShiftAmt);
7193   if (!ShiftAmtC || !ShiftAmtC->getAPIntValue().ult(BitWidth))
7194     return SDValue();
7195 
7196   // Set source to shift source.
7197   Src = Src.getOperand(0);
7198 
7199   // Try again to find a 'not' op.
7200   // TODO: Should we favor test+set even with two 'not' ops?
7201   if (!FoundNot) {
7202     if (!isBitwiseNot(Src))
7203       return SDValue();
7204     Src = Src.getOperand(0);
7205   }
7206 
7207   if (!TLI.hasBitTest(Src, ShiftAmt))
7208     return SDValue();
7209 
7210   // Turn this into a bit-test pattern using mask op + setcc:
7211   // and (not (srl X, C)), 1 --> (and X, 1<<C) == 0
7212   // and (srl (not X), C)), 1 --> (and X, 1<<C) == 0
7213   SDLoc DL(And);
7214   SDValue X = DAG.getZExtOrTrunc(Src, DL, SrcVT);
7215   EVT CCVT =
7216       TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), SrcVT);
7217   SDValue Mask = DAG.getConstant(
7218       APInt::getOneBitSet(BitWidth, ShiftAmtC->getZExtValue()), DL, SrcVT);
7219   SDValue NewAnd = DAG.getNode(ISD::AND, DL, SrcVT, X, Mask);
7220   SDValue Zero = DAG.getConstant(0, DL, SrcVT);
7221   SDValue Setcc = DAG.getSetCC(DL, CCVT, NewAnd, Zero, ISD::SETEQ);
7222   return DAG.getZExtOrTrunc(Setcc, DL, And->getValueType(0));
7223 }
7224 
7225 /// For targets that support usubsat, match a bit-hack form of that operation
7226 /// that ends in 'and' and convert it.
foldAndToUsubsat(SDNode * N,SelectionDAG & DAG,const SDLoc & DL)7227 static SDValue foldAndToUsubsat(SDNode *N, SelectionDAG &DAG, const SDLoc &DL) {
7228   EVT VT = N->getValueType(0);
7229   unsigned BitWidth = VT.getScalarSizeInBits();
7230   APInt SignMask = APInt::getSignMask(BitWidth);
7231 
7232   // (i8 X ^ 128) & (i8 X s>> 7) --> usubsat X, 128
7233   // (i8 X + 128) & (i8 X s>> 7) --> usubsat X, 128
7234   // xor/add with SMIN (signmask) are logically equivalent.
7235   SDValue X;
7236   if (!sd_match(N, m_And(m_OneUse(m_Xor(m_Value(X), m_SpecificInt(SignMask))),
7237                          m_OneUse(m_Sra(m_Deferred(X),
7238                                         m_SpecificInt(BitWidth - 1))))) &&
7239       !sd_match(N, m_And(m_OneUse(m_Add(m_Value(X), m_SpecificInt(SignMask))),
7240                          m_OneUse(m_Sra(m_Deferred(X),
7241                                         m_SpecificInt(BitWidth - 1))))))
7242     return SDValue();
7243 
7244   return DAG.getNode(ISD::USUBSAT, DL, VT, X,
7245                      DAG.getConstant(SignMask, DL, VT));
7246 }
7247 
7248 /// Given a bitwise logic operation N with a matching bitwise logic operand,
7249 /// fold a pattern where 2 of the source operands are identically shifted
7250 /// values. For example:
7251 /// ((X0 << Y) | Z) | (X1 << Y) --> ((X0 | X1) << Y) | Z
foldLogicOfShifts(SDNode * N,SDValue LogicOp,SDValue ShiftOp,SelectionDAG & DAG)7252 static SDValue foldLogicOfShifts(SDNode *N, SDValue LogicOp, SDValue ShiftOp,
7253                                  SelectionDAG &DAG) {
7254   unsigned LogicOpcode = N->getOpcode();
7255   assert(ISD::isBitwiseLogicOp(LogicOpcode) &&
7256          "Expected bitwise logic operation");
7257 
7258   if (!LogicOp.hasOneUse() || !ShiftOp.hasOneUse())
7259     return SDValue();
7260 
7261   // Match another bitwise logic op and a shift.
7262   unsigned ShiftOpcode = ShiftOp.getOpcode();
7263   if (LogicOp.getOpcode() != LogicOpcode ||
7264       !(ShiftOpcode == ISD::SHL || ShiftOpcode == ISD::SRL ||
7265         ShiftOpcode == ISD::SRA))
7266     return SDValue();
7267 
7268   // Match another shift op inside the first logic operand. Handle both commuted
7269   // possibilities.
7270   // LOGIC (LOGIC (SH X0, Y), Z), (SH X1, Y) --> LOGIC (SH (LOGIC X0, X1), Y), Z
7271   // LOGIC (LOGIC Z, (SH X0, Y)), (SH X1, Y) --> LOGIC (SH (LOGIC X0, X1), Y), Z
7272   SDValue X1 = ShiftOp.getOperand(0);
7273   SDValue Y = ShiftOp.getOperand(1);
7274   SDValue X0, Z;
7275   if (LogicOp.getOperand(0).getOpcode() == ShiftOpcode &&
7276       LogicOp.getOperand(0).getOperand(1) == Y) {
7277     X0 = LogicOp.getOperand(0).getOperand(0);
7278     Z = LogicOp.getOperand(1);
7279   } else if (LogicOp.getOperand(1).getOpcode() == ShiftOpcode &&
7280              LogicOp.getOperand(1).getOperand(1) == Y) {
7281     X0 = LogicOp.getOperand(1).getOperand(0);
7282     Z = LogicOp.getOperand(0);
7283   } else {
7284     return SDValue();
7285   }
7286 
7287   EVT VT = N->getValueType(0);
7288   SDLoc DL(N);
7289   SDValue LogicX = DAG.getNode(LogicOpcode, DL, VT, X0, X1);
7290   SDValue NewShift = DAG.getNode(ShiftOpcode, DL, VT, LogicX, Y);
7291   return DAG.getNode(LogicOpcode, DL, VT, NewShift, Z);
7292 }
7293 
7294 /// Given a tree of logic operations with shape like
7295 /// (LOGIC (LOGIC (X, Y), LOGIC (Z, Y)))
7296 /// try to match and fold shift operations with the same shift amount.
7297 /// For example:
7298 /// LOGIC (LOGIC (SH X0, Y), Z), (LOGIC (SH X1, Y), W) -->
7299 /// --> LOGIC (SH (LOGIC X0, X1), Y), (LOGIC Z, W)
foldLogicTreeOfShifts(SDNode * N,SDValue LeftHand,SDValue RightHand,SelectionDAG & DAG)7300 static SDValue foldLogicTreeOfShifts(SDNode *N, SDValue LeftHand,
7301                                      SDValue RightHand, SelectionDAG &DAG) {
7302   unsigned LogicOpcode = N->getOpcode();
7303   assert(ISD::isBitwiseLogicOp(LogicOpcode) &&
7304          "Expected bitwise logic operation");
7305   if (LeftHand.getOpcode() != LogicOpcode ||
7306       RightHand.getOpcode() != LogicOpcode)
7307     return SDValue();
7308   if (!LeftHand.hasOneUse() || !RightHand.hasOneUse())
7309     return SDValue();
7310 
7311   // Try to match one of following patterns:
7312   // LOGIC (LOGIC (SH X0, Y), Z), (LOGIC (SH X1, Y), W)
7313   // LOGIC (LOGIC (SH X0, Y), Z), (LOGIC W, (SH X1, Y))
7314   // Note that foldLogicOfShifts will handle commuted versions of the left hand
7315   // itself.
7316   SDValue CombinedShifts, W;
7317   SDValue R0 = RightHand.getOperand(0);
7318   SDValue R1 = RightHand.getOperand(1);
7319   if ((CombinedShifts = foldLogicOfShifts(N, LeftHand, R0, DAG)))
7320     W = R1;
7321   else if ((CombinedShifts = foldLogicOfShifts(N, LeftHand, R1, DAG)))
7322     W = R0;
7323   else
7324     return SDValue();
7325 
7326   EVT VT = N->getValueType(0);
7327   SDLoc DL(N);
7328   return DAG.getNode(LogicOpcode, DL, VT, CombinedShifts, W);
7329 }
7330 
7331 /// Fold "masked merge" expressions like `(m & x) | (~m & y)` and its DeMorgan
7332 /// variant `(~m | x) & (m | y)` into the equivalent `((x ^ y) & m) ^ y)`
7333 /// pattern. This is typically a better representation for targets without a
7334 /// fused "and-not" operation.
foldMaskedMerge(SDNode * Node,SelectionDAG & DAG,const TargetLowering & TLI,const SDLoc & DL)7335 static SDValue foldMaskedMerge(SDNode *Node, SelectionDAG &DAG,
7336                                const TargetLowering &TLI, const SDLoc &DL) {
7337   // Note that masked-merge variants using XOR or ADD expressions are
7338   // normalized to OR by InstCombine so we only check for OR or AND.
7339   assert((Node->getOpcode() == ISD::OR || Node->getOpcode() == ISD::AND) &&
7340          "Must be called with ISD::OR or ISD::AND node");
7341 
7342   // If the target supports and-not, don't fold this.
7343   if (TLI.hasAndNot(SDValue(Node, 0)))
7344     return SDValue();
7345 
7346   SDValue M, X, Y;
7347 
7348   if (sd_match(Node,
7349                m_Or(m_OneUse(m_And(m_OneUse(m_Not(m_Value(M))), m_Value(Y))),
7350                     m_OneUse(m_And(m_Deferred(M), m_Value(X))))) ||
7351       sd_match(Node,
7352                m_And(m_OneUse(m_Or(m_OneUse(m_Not(m_Value(M))), m_Value(X))),
7353                      m_OneUse(m_Or(m_Deferred(M), m_Value(Y)))))) {
7354     EVT VT = M.getValueType();
7355     SDValue Xor = DAG.getNode(ISD::XOR, DL, VT, X, Y);
7356     SDValue And = DAG.getNode(ISD::AND, DL, VT, Xor, M);
7357     return DAG.getNode(ISD::XOR, DL, VT, And, Y);
7358   }
7359   return SDValue();
7360 }
7361 
visitAND(SDNode * N)7362 SDValue DAGCombiner::visitAND(SDNode *N) {
7363   SDValue N0 = N->getOperand(0);
7364   SDValue N1 = N->getOperand(1);
7365   EVT VT = N1.getValueType();
7366   SDLoc DL(N);
7367 
7368   // x & x --> x
7369   if (N0 == N1)
7370     return N0;
7371 
7372   // fold (and c1, c2) -> c1&c2
7373   if (SDValue C = DAG.FoldConstantArithmetic(ISD::AND, DL, VT, {N0, N1}))
7374     return C;
7375 
7376   // canonicalize constant to RHS
7377   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
7378       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
7379     return DAG.getNode(ISD::AND, DL, VT, N1, N0);
7380 
7381   if (areBitwiseNotOfEachother(N0, N1))
7382     return DAG.getConstant(APInt::getZero(VT.getScalarSizeInBits()), DL, VT);
7383 
7384   // fold vector ops
7385   if (VT.isVector()) {
7386     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
7387       return FoldedVOp;
7388 
7389     // fold (and x, 0) -> 0, vector edition
7390     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
7391       // do not return N1, because undef node may exist in N1
7392       return DAG.getConstant(APInt::getZero(N1.getScalarValueSizeInBits()), DL,
7393                              N1.getValueType());
7394 
7395     // fold (and x, -1) -> x, vector edition
7396     if (ISD::isConstantSplatVectorAllOnes(N1.getNode()))
7397       return N0;
7398 
7399     // fold (and (masked_load) (splat_vec (x, ...))) to zext_masked_load
7400     auto *MLoad = dyn_cast<MaskedLoadSDNode>(N0);
7401     ConstantSDNode *Splat = isConstOrConstSplat(N1, true, true);
7402     if (MLoad && MLoad->getExtensionType() == ISD::EXTLOAD && Splat) {
7403       EVT LoadVT = MLoad->getMemoryVT();
7404       EVT ExtVT = VT;
7405       if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, ExtVT, LoadVT)) {
7406         // For this AND to be a zero extension of the masked load the elements
7407         // of the BuildVec must mask the bottom bits of the extended element
7408         // type
7409         uint64_t ElementSize =
7410             LoadVT.getVectorElementType().getScalarSizeInBits();
7411         if (Splat->getAPIntValue().isMask(ElementSize)) {
7412           SDValue NewLoad = DAG.getMaskedLoad(
7413               ExtVT, DL, MLoad->getChain(), MLoad->getBasePtr(),
7414               MLoad->getOffset(), MLoad->getMask(), MLoad->getPassThru(),
7415               LoadVT, MLoad->getMemOperand(), MLoad->getAddressingMode(),
7416               ISD::ZEXTLOAD, MLoad->isExpandingLoad());
7417           bool LoadHasOtherUsers = !N0.hasOneUse();
7418           CombineTo(N, NewLoad);
7419           if (LoadHasOtherUsers)
7420             CombineTo(MLoad, NewLoad.getValue(0), NewLoad.getValue(1));
7421           return SDValue(N, 0);
7422         }
7423       }
7424     }
7425   }
7426 
7427   // fold (and x, -1) -> x
7428   if (isAllOnesConstant(N1))
7429     return N0;
7430 
7431   // if (and x, c) is known to be zero, return 0
7432   unsigned BitWidth = VT.getScalarSizeInBits();
7433   ConstantSDNode *N1C = isConstOrConstSplat(N1);
7434   if (N1C && DAG.MaskedValueIsZero(SDValue(N, 0), APInt::getAllOnes(BitWidth)))
7435     return DAG.getConstant(0, DL, VT);
7436 
7437   if (SDValue R = foldAndOrOfSETCC(N, DAG))
7438     return R;
7439 
7440   if (SDValue NewSel = foldBinOpIntoSelect(N))
7441     return NewSel;
7442 
7443   // reassociate and
7444   if (SDValue RAND = reassociateOps(ISD::AND, DL, N0, N1, N->getFlags()))
7445     return RAND;
7446 
7447   // Fold and(vecreduce(x), vecreduce(y)) -> vecreduce(and(x, y))
7448   if (SDValue SD =
7449           reassociateReduction(ISD::VECREDUCE_AND, ISD::AND, DL, VT, N0, N1))
7450     return SD;
7451 
7452   // fold (and (or x, C), D) -> D if (C & D) == D
7453   auto MatchSubset = [](ConstantSDNode *LHS, ConstantSDNode *RHS) {
7454     return RHS->getAPIntValue().isSubsetOf(LHS->getAPIntValue());
7455   };
7456   if (N0.getOpcode() == ISD::OR &&
7457       ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchSubset))
7458     return N1;
7459 
7460   if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) {
7461     SDValue N0Op0 = N0.getOperand(0);
7462     EVT SrcVT = N0Op0.getValueType();
7463     unsigned SrcBitWidth = SrcVT.getScalarSizeInBits();
7464     APInt Mask = ~N1C->getAPIntValue();
7465     Mask = Mask.trunc(SrcBitWidth);
7466 
7467     // fold (and (any_ext V), c) -> (zero_ext V) if 'and' only clears top bits.
7468     if (DAG.MaskedValueIsZero(N0Op0, Mask))
7469       return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0Op0);
7470 
7471     // fold (and (any_ext V), c) -> (zero_ext (and (trunc V), c)) if profitable.
7472     if (N1C->getAPIntValue().countLeadingZeros() >= (BitWidth - SrcBitWidth) &&
7473         TLI.isTruncateFree(VT, SrcVT) && TLI.isZExtFree(SrcVT, VT) &&
7474         TLI.isTypeDesirableForOp(ISD::AND, SrcVT) &&
7475         TLI.isNarrowingProfitable(N, VT, SrcVT))
7476       return DAG.getNode(ISD::ZERO_EXTEND, DL, VT,
7477                          DAG.getNode(ISD::AND, DL, SrcVT, N0Op0,
7478                                      DAG.getZExtOrTrunc(N1, DL, SrcVT)));
7479   }
7480 
7481   // fold (and (ext (and V, c1)), c2) -> (and (ext V), (and c1, (ext c2)))
7482   if (ISD::isExtOpcode(N0.getOpcode())) {
7483     unsigned ExtOpc = N0.getOpcode();
7484     SDValue N0Op0 = N0.getOperand(0);
7485     if (N0Op0.getOpcode() == ISD::AND &&
7486         (ExtOpc != ISD::ZERO_EXTEND || !TLI.isZExtFree(N0Op0, VT)) &&
7487         N0->hasOneUse() && N0Op0->hasOneUse()) {
7488       if (SDValue NewExt = DAG.FoldConstantArithmetic(ExtOpc, DL, VT,
7489                                                       {N0Op0.getOperand(1)})) {
7490         if (SDValue NewMask =
7491                 DAG.FoldConstantArithmetic(ISD::AND, DL, VT, {N1, NewExt})) {
7492           return DAG.getNode(ISD::AND, DL, VT,
7493                              DAG.getNode(ExtOpc, DL, VT, N0Op0.getOperand(0)),
7494                              NewMask);
7495         }
7496       }
7497     }
7498   }
7499 
7500   // similarly fold (and (X (load ([non_ext|any_ext|zero_ext] V))), c) ->
7501   // (X (load ([non_ext|zero_ext] V))) if 'and' only clears top bits which must
7502   // already be zero by virtue of the width of the base type of the load.
7503   //
7504   // the 'X' node here can either be nothing or an extract_vector_elt to catch
7505   // more cases.
7506   if ((N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
7507        N0.getValueSizeInBits() == N0.getOperand(0).getScalarValueSizeInBits() &&
7508        N0.getOperand(0).getOpcode() == ISD::LOAD &&
7509        N0.getOperand(0).getResNo() == 0) ||
7510       (N0.getOpcode() == ISD::LOAD && N0.getResNo() == 0)) {
7511     auto *Load =
7512         cast<LoadSDNode>((N0.getOpcode() == ISD::LOAD) ? N0 : N0.getOperand(0));
7513 
7514     // Get the constant (if applicable) the zero'th operand is being ANDed with.
7515     // This can be a pure constant or a vector splat, in which case we treat the
7516     // vector as a scalar and use the splat value.
7517     APInt Constant = APInt::getZero(1);
7518     if (const ConstantSDNode *C = isConstOrConstSplat(
7519             N1, /*AllowUndefs=*/false, /*AllowTruncation=*/true)) {
7520       Constant = C->getAPIntValue();
7521     } else if (BuildVectorSDNode *Vector = dyn_cast<BuildVectorSDNode>(N1)) {
7522       unsigned EltBitWidth = Vector->getValueType(0).getScalarSizeInBits();
7523       APInt SplatValue, SplatUndef;
7524       unsigned SplatBitSize;
7525       bool HasAnyUndefs;
7526       // Endianness should not matter here. Code below makes sure that we only
7527       // use the result if the SplatBitSize is a multiple of the vector element
7528       // size. And after that we AND all element sized parts of the splat
7529       // together. So the end result should be the same regardless of in which
7530       // order we do those operations.
7531       const bool IsBigEndian = false;
7532       bool IsSplat =
7533           Vector->isConstantSplat(SplatValue, SplatUndef, SplatBitSize,
7534                                   HasAnyUndefs, EltBitWidth, IsBigEndian);
7535 
7536       // Make sure that variable 'Constant' is only set if 'SplatBitSize' is a
7537       // multiple of 'BitWidth'. Otherwise, we could propagate a wrong value.
7538       if (IsSplat && (SplatBitSize % EltBitWidth) == 0) {
7539         // Undef bits can contribute to a possible optimisation if set, so
7540         // set them.
7541         SplatValue |= SplatUndef;
7542 
7543         // The splat value may be something like "0x00FFFFFF", which means 0 for
7544         // the first vector value and FF for the rest, repeating. We need a mask
7545         // that will apply equally to all members of the vector, so AND all the
7546         // lanes of the constant together.
7547         Constant = APInt::getAllOnes(EltBitWidth);
7548         for (unsigned i = 0, n = (SplatBitSize / EltBitWidth); i < n; ++i)
7549           Constant &= SplatValue.extractBits(EltBitWidth, i * EltBitWidth);
7550       }
7551     }
7552 
7553     // If we want to change an EXTLOAD to a ZEXTLOAD, ensure a ZEXTLOAD is
7554     // actually legal and isn't going to get expanded, else this is a false
7555     // optimisation.
7556     bool CanZextLoadProfitably = TLI.isLoadExtLegal(ISD::ZEXTLOAD,
7557                                                     Load->getValueType(0),
7558                                                     Load->getMemoryVT());
7559 
7560     // Resize the constant to the same size as the original memory access before
7561     // extension. If it is still the AllOnesValue then this AND is completely
7562     // unneeded.
7563     Constant = Constant.zextOrTrunc(Load->getMemoryVT().getScalarSizeInBits());
7564 
7565     bool B;
7566     switch (Load->getExtensionType()) {
7567     default: B = false; break;
7568     case ISD::EXTLOAD: B = CanZextLoadProfitably; break;
7569     case ISD::ZEXTLOAD:
7570     case ISD::NON_EXTLOAD: B = true; break;
7571     }
7572 
7573     if (B && Constant.isAllOnes()) {
7574       // If the load type was an EXTLOAD, convert to ZEXTLOAD in order to
7575       // preserve semantics once we get rid of the AND.
7576       SDValue NewLoad(Load, 0);
7577 
7578       // Fold the AND away. NewLoad may get replaced immediately.
7579       CombineTo(N, (N0.getNode() == Load) ? NewLoad : N0);
7580 
7581       if (Load->getExtensionType() == ISD::EXTLOAD) {
7582         NewLoad = DAG.getLoad(Load->getAddressingMode(), ISD::ZEXTLOAD,
7583                               Load->getValueType(0), SDLoc(Load),
7584                               Load->getChain(), Load->getBasePtr(),
7585                               Load->getOffset(), Load->getMemoryVT(),
7586                               Load->getMemOperand());
7587         // Replace uses of the EXTLOAD with the new ZEXTLOAD.
7588         if (Load->getNumValues() == 3) {
7589           // PRE/POST_INC loads have 3 values.
7590           SDValue To[] = { NewLoad.getValue(0), NewLoad.getValue(1),
7591                            NewLoad.getValue(2) };
7592           CombineTo(Load, To, 3, true);
7593         } else {
7594           CombineTo(Load, NewLoad.getValue(0), NewLoad.getValue(1));
7595         }
7596       }
7597 
7598       return SDValue(N, 0); // Return N so it doesn't get rechecked!
7599     }
7600   }
7601 
7602   // Try to convert a constant mask AND into a shuffle clear mask.
7603   if (VT.isVector())
7604     if (SDValue Shuffle = XformToShuffleWithZero(N))
7605       return Shuffle;
7606 
7607   if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
7608     return Combined;
7609 
7610   if (N0.getOpcode() == ISD::EXTRACT_SUBVECTOR && N0.hasOneUse() && N1C &&
7611       ISD::isExtOpcode(N0.getOperand(0).getOpcode())) {
7612     SDValue Ext = N0.getOperand(0);
7613     EVT ExtVT = Ext->getValueType(0);
7614     SDValue Extendee = Ext->getOperand(0);
7615 
7616     unsigned ScalarWidth = Extendee.getValueType().getScalarSizeInBits();
7617     if (N1C->getAPIntValue().isMask(ScalarWidth) &&
7618         (!LegalOperations || TLI.isOperationLegal(ISD::ZERO_EXTEND, ExtVT))) {
7619       //    (and (extract_subvector (zext|anyext|sext v) _) iN_mask)
7620       // => (extract_subvector (iN_zeroext v))
7621       SDValue ZeroExtExtendee =
7622           DAG.getNode(ISD::ZERO_EXTEND, DL, ExtVT, Extendee);
7623 
7624       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, ZeroExtExtendee,
7625                          N0.getOperand(1));
7626     }
7627   }
7628 
7629   // fold (and (masked_gather x)) -> (zext_masked_gather x)
7630   if (auto *GN0 = dyn_cast<MaskedGatherSDNode>(N0)) {
7631     EVT MemVT = GN0->getMemoryVT();
7632     EVT ScalarVT = MemVT.getScalarType();
7633 
7634     if (SDValue(GN0, 0).hasOneUse() &&
7635         isConstantSplatVectorMaskForType(N1.getNode(), ScalarVT) &&
7636         TLI.isVectorLoadExtDesirable(SDValue(SDValue(GN0, 0)))) {
7637       SDValue Ops[] = {GN0->getChain(),   GN0->getPassThru(), GN0->getMask(),
7638                        GN0->getBasePtr(), GN0->getIndex(),    GN0->getScale()};
7639 
7640       SDValue ZExtLoad = DAG.getMaskedGather(
7641           DAG.getVTList(VT, MVT::Other), MemVT, DL, Ops, GN0->getMemOperand(),
7642           GN0->getIndexType(), ISD::ZEXTLOAD);
7643 
7644       CombineTo(N, ZExtLoad);
7645       AddToWorklist(ZExtLoad.getNode());
7646       // Avoid recheck of N.
7647       return SDValue(N, 0);
7648     }
7649   }
7650 
7651   // fold (and (load x), 255) -> (zextload x, i8)
7652   // fold (and (extload x, i16), 255) -> (zextload x, i8)
7653   if (N1C && N0.getOpcode() == ISD::LOAD && !VT.isVector())
7654     if (SDValue Res = reduceLoadWidth(N))
7655       return Res;
7656 
7657   if (LegalTypes) {
7658     // Attempt to propagate the AND back up to the leaves which, if they're
7659     // loads, can be combined to narrow loads and the AND node can be removed.
7660     // Perform after legalization so that extend nodes will already be
7661     // combined into the loads.
7662     if (BackwardsPropagateMask(N))
7663       return SDValue(N, 0);
7664   }
7665 
7666   if (SDValue Combined = visitANDLike(N0, N1, N))
7667     return Combined;
7668 
7669   // Simplify: (and (op x...), (op y...))  -> (op (and x, y))
7670   if (N0.getOpcode() == N1.getOpcode())
7671     if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
7672       return V;
7673 
7674   if (SDValue R = foldLogicOfShifts(N, N0, N1, DAG))
7675     return R;
7676   if (SDValue R = foldLogicOfShifts(N, N1, N0, DAG))
7677     return R;
7678 
7679   // Fold (and X, (bswap (not Y))) -> (and X, (not (bswap Y)))
7680   // Fold (and X, (bitreverse (not Y))) -> (and X, (not (bitreverse Y)))
7681   SDValue X, Y, Z, NotY;
7682   for (unsigned Opc : {ISD::BSWAP, ISD::BITREVERSE})
7683     if (sd_match(N,
7684                  m_And(m_Value(X), m_OneUse(m_UnaryOp(Opc, m_Value(NotY))))) &&
7685         sd_match(NotY, m_Not(m_Value(Y))) &&
7686         (TLI.hasAndNot(SDValue(N, 0)) || NotY->hasOneUse()))
7687       return DAG.getNode(ISD::AND, DL, VT, X,
7688                          DAG.getNOT(DL, DAG.getNode(Opc, DL, VT, Y), VT));
7689 
7690   // Fold (and X, (rot (not Y), Z)) -> (and X, (not (rot Y, Z)))
7691   for (unsigned Opc : {ISD::ROTL, ISD::ROTR})
7692     if (sd_match(N, m_And(m_Value(X),
7693                           m_OneUse(m_BinOp(Opc, m_Value(NotY), m_Value(Z))))) &&
7694         sd_match(NotY, m_Not(m_Value(Y))) &&
7695         (TLI.hasAndNot(SDValue(N, 0)) || NotY->hasOneUse()))
7696       return DAG.getNode(ISD::AND, DL, VT, X,
7697                          DAG.getNOT(DL, DAG.getNode(Opc, DL, VT, Y, Z), VT));
7698 
7699   // Fold (and X, (add (not Y), Z)) -> (and X, (not (sub Y, Z)))
7700   // Fold (and X, (sub (not Y), Z)) -> (and X, (not (add Y, Z)))
7701   if (TLI.hasAndNot(SDValue(N, 0)))
7702     if (SDValue Folded = foldBitwiseOpWithNeg(N, DL, VT))
7703       return Folded;
7704 
7705   // Fold (and (srl X, C), 1) -> (srl X, BW-1) for signbit extraction
7706   // If we are shifting down an extended sign bit, see if we can simplify
7707   // this to shifting the MSB directly to expose further simplifications.
7708   // This pattern often appears after sext_inreg legalization.
7709   APInt Amt;
7710   if (sd_match(N, m_And(m_Srl(m_Value(X), m_ConstInt(Amt)), m_One())) &&
7711       Amt.ult(BitWidth - 1) && Amt.uge(BitWidth - DAG.ComputeNumSignBits(X)))
7712     return DAG.getNode(ISD::SRL, DL, VT, X,
7713                        DAG.getShiftAmountConstant(BitWidth - 1, VT, DL));
7714 
7715   // Masking the negated extension of a boolean is just the zero-extended
7716   // boolean:
7717   // and (sub 0, zext(bool X)), 1 --> zext(bool X)
7718   // and (sub 0, sext(bool X)), 1 --> zext(bool X)
7719   //
7720   // Note: the SimplifyDemandedBits fold below can make an information-losing
7721   // transform, and then we have no way to find this better fold.
7722   if (sd_match(N, m_And(m_Sub(m_Zero(), m_Value(X)), m_One()))) {
7723     if (X.getOpcode() == ISD::ZERO_EXTEND &&
7724         X.getOperand(0).getScalarValueSizeInBits() == 1)
7725       return X;
7726     if (X.getOpcode() == ISD::SIGN_EXTEND &&
7727         X.getOperand(0).getScalarValueSizeInBits() == 1)
7728       return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, X.getOperand(0));
7729   }
7730 
7731   // fold (and (sign_extend_inreg x, i16 to i32), 1) -> (and x, 1)
7732   // fold (and (sra)) -> (and (srl)) when possible.
7733   if (SimplifyDemandedBits(SDValue(N, 0)))
7734     return SDValue(N, 0);
7735 
7736   // fold (zext_inreg (extload x)) -> (zextload x)
7737   // fold (zext_inreg (sextload x)) -> (zextload x) iff load has one use
7738   if (ISD::isUNINDEXEDLoad(N0.getNode()) &&
7739       (ISD::isEXTLoad(N0.getNode()) ||
7740        (ISD::isSEXTLoad(N0.getNode()) && N0.hasOneUse()))) {
7741     auto *LN0 = cast<LoadSDNode>(N0);
7742     EVT MemVT = LN0->getMemoryVT();
7743     // If we zero all the possible extended bits, then we can turn this into
7744     // a zextload if we are running before legalize or the operation is legal.
7745     unsigned ExtBitSize = N1.getScalarValueSizeInBits();
7746     unsigned MemBitSize = MemVT.getScalarSizeInBits();
7747     APInt ExtBits = APInt::getHighBitsSet(ExtBitSize, ExtBitSize - MemBitSize);
7748     if (DAG.MaskedValueIsZero(N1, ExtBits) &&
7749         ((!LegalOperations && LN0->isSimple()) ||
7750          TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT))) {
7751       SDValue ExtLoad =
7752           DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(N0), VT, LN0->getChain(),
7753                          LN0->getBasePtr(), MemVT, LN0->getMemOperand());
7754       AddToWorklist(N);
7755       CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
7756       return SDValue(N, 0); // Return N so it doesn't get rechecked!
7757     }
7758   }
7759 
7760   // fold (and (or (srl N, 8), (shl N, 8)), 0xffff) -> (srl (bswap N), const)
7761   if (N1C && N1C->getAPIntValue() == 0xffff && N0.getOpcode() == ISD::OR) {
7762     if (SDValue BSwap = MatchBSwapHWordLow(N0.getNode(), N0.getOperand(0),
7763                                            N0.getOperand(1), false))
7764       return BSwap;
7765   }
7766 
7767   if (SDValue Shifts = unfoldExtremeBitClearingToShifts(N))
7768     return Shifts;
7769 
7770   if (SDValue V = combineShiftAnd1ToBitTest(N, DAG))
7771     return V;
7772 
7773   // Recognize the following pattern:
7774   //
7775   // AndVT = (and (sign_extend NarrowVT to AndVT) #bitmask)
7776   //
7777   // where bitmask is a mask that clears the upper bits of AndVT. The
7778   // number of bits in bitmask must be a power of two.
7779   auto IsAndZeroExtMask = [](SDValue LHS, SDValue RHS) {
7780     if (LHS->getOpcode() != ISD::SIGN_EXTEND)
7781       return false;
7782 
7783     auto *C = dyn_cast<ConstantSDNode>(RHS);
7784     if (!C)
7785       return false;
7786 
7787     if (!C->getAPIntValue().isMask(
7788             LHS.getOperand(0).getValueType().getFixedSizeInBits()))
7789       return false;
7790 
7791     return true;
7792   };
7793 
7794   // Replace (and (sign_extend ...) #bitmask) with (zero_extend ...).
7795   if (IsAndZeroExtMask(N0, N1))
7796     return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0));
7797 
7798   if (hasOperation(ISD::USUBSAT, VT))
7799     if (SDValue V = foldAndToUsubsat(N, DAG, DL))
7800       return V;
7801 
7802   // Postpone until legalization completed to avoid interference with bswap
7803   // folding
7804   if (LegalOperations || VT.isVector())
7805     if (SDValue R = foldLogicTreeOfShifts(N, N0, N1, DAG))
7806       return R;
7807 
7808   if (VT.isScalarInteger() && VT != MVT::i1)
7809     if (SDValue R = foldMaskedMerge(N, DAG, TLI, DL))
7810       return R;
7811 
7812   return SDValue();
7813 }
7814 
7815 /// Match (a >> 8) | (a << 8) as (bswap a) >> 16.
MatchBSwapHWordLow(SDNode * N,SDValue N0,SDValue N1,bool DemandHighBits)7816 SDValue DAGCombiner::MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
7817                                         bool DemandHighBits) {
7818   if (!LegalOperations)
7819     return SDValue();
7820 
7821   EVT VT = N->getValueType(0);
7822   if (VT != MVT::i64 && VT != MVT::i32 && VT != MVT::i16)
7823     return SDValue();
7824   if (!TLI.isOperationLegalOrCustom(ISD::BSWAP, VT))
7825     return SDValue();
7826 
7827   // Recognize (and (shl a, 8), 0xff00), (and (srl a, 8), 0xff)
7828   bool LookPassAnd0 = false;
7829   bool LookPassAnd1 = false;
7830   if (N0.getOpcode() == ISD::AND && N0.getOperand(0).getOpcode() == ISD::SRL)
7831     std::swap(N0, N1);
7832   if (N1.getOpcode() == ISD::AND && N1.getOperand(0).getOpcode() == ISD::SHL)
7833     std::swap(N0, N1);
7834   if (N0.getOpcode() == ISD::AND) {
7835     if (!N0->hasOneUse())
7836       return SDValue();
7837     ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
7838     // Also handle 0xffff since the LHS is guaranteed to have zeros there.
7839     // This is needed for X86.
7840     if (!N01C || (N01C->getZExtValue() != 0xFF00 &&
7841                   N01C->getZExtValue() != 0xFFFF))
7842       return SDValue();
7843     N0 = N0.getOperand(0);
7844     LookPassAnd0 = true;
7845   }
7846 
7847   if (N1.getOpcode() == ISD::AND) {
7848     if (!N1->hasOneUse())
7849       return SDValue();
7850     ConstantSDNode *N11C = dyn_cast<ConstantSDNode>(N1.getOperand(1));
7851     if (!N11C || N11C->getZExtValue() != 0xFF)
7852       return SDValue();
7853     N1 = N1.getOperand(0);
7854     LookPassAnd1 = true;
7855   }
7856 
7857   if (N0.getOpcode() == ISD::SRL && N1.getOpcode() == ISD::SHL)
7858     std::swap(N0, N1);
7859   if (N0.getOpcode() != ISD::SHL || N1.getOpcode() != ISD::SRL)
7860     return SDValue();
7861   if (!N0->hasOneUse() || !N1->hasOneUse())
7862     return SDValue();
7863 
7864   ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
7865   ConstantSDNode *N11C = dyn_cast<ConstantSDNode>(N1.getOperand(1));
7866   if (!N01C || !N11C)
7867     return SDValue();
7868   if (N01C->getZExtValue() != 8 || N11C->getZExtValue() != 8)
7869     return SDValue();
7870 
7871   // Look for (shl (and a, 0xff), 8), (srl (and a, 0xff00), 8)
7872   SDValue N00 = N0->getOperand(0);
7873   if (!LookPassAnd0 && N00.getOpcode() == ISD::AND) {
7874     if (!N00->hasOneUse())
7875       return SDValue();
7876     ConstantSDNode *N001C = dyn_cast<ConstantSDNode>(N00.getOperand(1));
7877     if (!N001C || N001C->getZExtValue() != 0xFF)
7878       return SDValue();
7879     N00 = N00.getOperand(0);
7880     LookPassAnd0 = true;
7881   }
7882 
7883   SDValue N10 = N1->getOperand(0);
7884   if (!LookPassAnd1 && N10.getOpcode() == ISD::AND) {
7885     if (!N10->hasOneUse())
7886       return SDValue();
7887     ConstantSDNode *N101C = dyn_cast<ConstantSDNode>(N10.getOperand(1));
7888     // Also allow 0xFFFF since the bits will be shifted out. This is needed
7889     // for X86.
7890     if (!N101C || (N101C->getZExtValue() != 0xFF00 &&
7891                    N101C->getZExtValue() != 0xFFFF))
7892       return SDValue();
7893     N10 = N10.getOperand(0);
7894     LookPassAnd1 = true;
7895   }
7896 
7897   if (N00 != N10)
7898     return SDValue();
7899 
7900   // Make sure everything beyond the low halfword gets set to zero since the SRL
7901   // 16 will clear the top bits.
7902   unsigned OpSizeInBits = VT.getSizeInBits();
7903   if (OpSizeInBits > 16) {
7904     // If the left-shift isn't masked out then the only way this is a bswap is
7905     // if all bits beyond the low 8 are 0. In that case the entire pattern
7906     // reduces to a left shift anyway: leave it for other parts of the combiner.
7907     if (DemandHighBits && !LookPassAnd0)
7908       return SDValue();
7909 
7910     // However, if the right shift isn't masked out then it might be because
7911     // it's not needed. See if we can spot that too. If the high bits aren't
7912     // demanded, we only need bits 23:16 to be zero. Otherwise, we need all
7913     // upper bits to be zero.
7914     if (!LookPassAnd1) {
7915       unsigned HighBit = DemandHighBits ? OpSizeInBits : 24;
7916       if (!DAG.MaskedValueIsZero(N10,
7917                                  APInt::getBitsSet(OpSizeInBits, 16, HighBit)))
7918         return SDValue();
7919     }
7920   }
7921 
7922   SDValue Res = DAG.getNode(ISD::BSWAP, SDLoc(N), VT, N00);
7923   if (OpSizeInBits > 16) {
7924     SDLoc DL(N);
7925     Res = DAG.getNode(ISD::SRL, DL, VT, Res,
7926                       DAG.getShiftAmountConstant(OpSizeInBits - 16, VT, DL));
7927   }
7928   return Res;
7929 }
7930 
7931 /// Return true if the specified node is an element that makes up a 32-bit
7932 /// packed halfword byteswap.
7933 /// ((x & 0x000000ff) << 8) |
7934 /// ((x & 0x0000ff00) >> 8) |
7935 /// ((x & 0x00ff0000) << 8) |
7936 /// ((x & 0xff000000) >> 8)
isBSwapHWordElement(SDValue N,MutableArrayRef<SDNode * > Parts)7937 static bool isBSwapHWordElement(SDValue N, MutableArrayRef<SDNode *> Parts) {
7938   if (!N->hasOneUse())
7939     return false;
7940 
7941   unsigned Opc = N.getOpcode();
7942   if (Opc != ISD::AND && Opc != ISD::SHL && Opc != ISD::SRL)
7943     return false;
7944 
7945   SDValue N0 = N.getOperand(0);
7946   unsigned Opc0 = N0.getOpcode();
7947   if (Opc0 != ISD::AND && Opc0 != ISD::SHL && Opc0 != ISD::SRL)
7948     return false;
7949 
7950   ConstantSDNode *N1C = nullptr;
7951   // SHL or SRL: look upstream for AND mask operand
7952   if (Opc == ISD::AND)
7953     N1C = dyn_cast<ConstantSDNode>(N.getOperand(1));
7954   else if (Opc0 == ISD::AND)
7955     N1C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
7956   if (!N1C)
7957     return false;
7958 
7959   unsigned MaskByteOffset;
7960   switch (N1C->getZExtValue()) {
7961   default:
7962     return false;
7963   case 0xFF:       MaskByteOffset = 0; break;
7964   case 0xFF00:     MaskByteOffset = 1; break;
7965   case 0xFFFF:
7966     // In case demanded bits didn't clear the bits that will be shifted out.
7967     // This is needed for X86.
7968     if (Opc == ISD::SRL || (Opc == ISD::AND && Opc0 == ISD::SHL)) {
7969       MaskByteOffset = 1;
7970       break;
7971     }
7972     return false;
7973   case 0xFF0000:   MaskByteOffset = 2; break;
7974   case 0xFF000000: MaskByteOffset = 3; break;
7975   }
7976 
7977   // Look for (x & 0xff) << 8 as well as ((x << 8) & 0xff00).
7978   if (Opc == ISD::AND) {
7979     if (MaskByteOffset == 0 || MaskByteOffset == 2) {
7980       // (x >> 8) & 0xff
7981       // (x >> 8) & 0xff0000
7982       if (Opc0 != ISD::SRL)
7983         return false;
7984       ConstantSDNode *C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
7985       if (!C || C->getZExtValue() != 8)
7986         return false;
7987     } else {
7988       // (x << 8) & 0xff00
7989       // (x << 8) & 0xff000000
7990       if (Opc0 != ISD::SHL)
7991         return false;
7992       ConstantSDNode *C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
7993       if (!C || C->getZExtValue() != 8)
7994         return false;
7995     }
7996   } else if (Opc == ISD::SHL) {
7997     // (x & 0xff) << 8
7998     // (x & 0xff0000) << 8
7999     if (MaskByteOffset != 0 && MaskByteOffset != 2)
8000       return false;
8001     ConstantSDNode *C = dyn_cast<ConstantSDNode>(N.getOperand(1));
8002     if (!C || C->getZExtValue() != 8)
8003       return false;
8004   } else { // Opc == ISD::SRL
8005     // (x & 0xff00) >> 8
8006     // (x & 0xff000000) >> 8
8007     if (MaskByteOffset != 1 && MaskByteOffset != 3)
8008       return false;
8009     ConstantSDNode *C = dyn_cast<ConstantSDNode>(N.getOperand(1));
8010     if (!C || C->getZExtValue() != 8)
8011       return false;
8012   }
8013 
8014   if (Parts[MaskByteOffset])
8015     return false;
8016 
8017   Parts[MaskByteOffset] = N0.getOperand(0).getNode();
8018   return true;
8019 }
8020 
8021 // Match 2 elements of a packed halfword bswap.
isBSwapHWordPair(SDValue N,MutableArrayRef<SDNode * > Parts)8022 static bool isBSwapHWordPair(SDValue N, MutableArrayRef<SDNode *> Parts) {
8023   if (N.getOpcode() == ISD::OR)
8024     return isBSwapHWordElement(N.getOperand(0), Parts) &&
8025            isBSwapHWordElement(N.getOperand(1), Parts);
8026 
8027   if (N.getOpcode() == ISD::SRL && N.getOperand(0).getOpcode() == ISD::BSWAP) {
8028     ConstantSDNode *C = isConstOrConstSplat(N.getOperand(1));
8029     if (!C || C->getAPIntValue() != 16)
8030       return false;
8031     Parts[0] = Parts[1] = N.getOperand(0).getOperand(0).getNode();
8032     return true;
8033   }
8034 
8035   return false;
8036 }
8037 
8038 // Match this pattern:
8039 //   (or (and (shl (A, 8)), 0xff00ff00), (and (srl (A, 8)), 0x00ff00ff))
8040 // And rewrite this to:
8041 //   (rotr (bswap A), 16)
matchBSwapHWordOrAndAnd(const TargetLowering & TLI,SelectionDAG & DAG,SDNode * N,SDValue N0,SDValue N1,EVT VT)8042 static SDValue matchBSwapHWordOrAndAnd(const TargetLowering &TLI,
8043                                        SelectionDAG &DAG, SDNode *N, SDValue N0,
8044                                        SDValue N1, EVT VT) {
8045   assert(N->getOpcode() == ISD::OR && VT == MVT::i32 &&
8046          "MatchBSwapHWordOrAndAnd: expecting i32");
8047   if (!TLI.isOperationLegalOrCustom(ISD::ROTR, VT))
8048     return SDValue();
8049   if (N0.getOpcode() != ISD::AND || N1.getOpcode() != ISD::AND)
8050     return SDValue();
8051   // TODO: this is too restrictive; lifting this restriction requires more tests
8052   if (!N0->hasOneUse() || !N1->hasOneUse())
8053     return SDValue();
8054   ConstantSDNode *Mask0 = isConstOrConstSplat(N0.getOperand(1));
8055   ConstantSDNode *Mask1 = isConstOrConstSplat(N1.getOperand(1));
8056   if (!Mask0 || !Mask1)
8057     return SDValue();
8058   if (Mask0->getAPIntValue() != 0xff00ff00 ||
8059       Mask1->getAPIntValue() != 0x00ff00ff)
8060     return SDValue();
8061   SDValue Shift0 = N0.getOperand(0);
8062   SDValue Shift1 = N1.getOperand(0);
8063   if (Shift0.getOpcode() != ISD::SHL || Shift1.getOpcode() != ISD::SRL)
8064     return SDValue();
8065   ConstantSDNode *ShiftAmt0 = isConstOrConstSplat(Shift0.getOperand(1));
8066   ConstantSDNode *ShiftAmt1 = isConstOrConstSplat(Shift1.getOperand(1));
8067   if (!ShiftAmt0 || !ShiftAmt1)
8068     return SDValue();
8069   if (ShiftAmt0->getAPIntValue() != 8 || ShiftAmt1->getAPIntValue() != 8)
8070     return SDValue();
8071   if (Shift0.getOperand(0) != Shift1.getOperand(0))
8072     return SDValue();
8073 
8074   SDLoc DL(N);
8075   SDValue BSwap = DAG.getNode(ISD::BSWAP, DL, VT, Shift0.getOperand(0));
8076   SDValue ShAmt = DAG.getShiftAmountConstant(16, VT, DL);
8077   return DAG.getNode(ISD::ROTR, DL, VT, BSwap, ShAmt);
8078 }
8079 
8080 /// Match a 32-bit packed halfword bswap. That is
8081 /// ((x & 0x000000ff) << 8) |
8082 /// ((x & 0x0000ff00) >> 8) |
8083 /// ((x & 0x00ff0000) << 8) |
8084 /// ((x & 0xff000000) >> 8)
8085 /// => (rotl (bswap x), 16)
MatchBSwapHWord(SDNode * N,SDValue N0,SDValue N1)8086 SDValue DAGCombiner::MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1) {
8087   if (!LegalOperations)
8088     return SDValue();
8089 
8090   EVT VT = N->getValueType(0);
8091   if (VT != MVT::i32)
8092     return SDValue();
8093   if (!TLI.isOperationLegalOrCustom(ISD::BSWAP, VT))
8094     return SDValue();
8095 
8096   if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N0, N1, VT))
8097     return BSwap;
8098 
8099   // Try again with commuted operands.
8100   if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N1, N0, VT))
8101     return BSwap;
8102 
8103 
8104   // Look for either
8105   // (or (bswaphpair), (bswaphpair))
8106   // (or (or (bswaphpair), (and)), (and))
8107   // (or (or (and), (bswaphpair)), (and))
8108   SDNode *Parts[4] = {};
8109 
8110   if (isBSwapHWordPair(N0, Parts)) {
8111     // (or (or (and), (and)), (or (and), (and)))
8112     if (!isBSwapHWordPair(N1, Parts))
8113       return SDValue();
8114   } else if (N0.getOpcode() == ISD::OR) {
8115     // (or (or (or (and), (and)), (and)), (and))
8116     if (!isBSwapHWordElement(N1, Parts))
8117       return SDValue();
8118     SDValue N00 = N0.getOperand(0);
8119     SDValue N01 = N0.getOperand(1);
8120     if (!(isBSwapHWordElement(N01, Parts) && isBSwapHWordPair(N00, Parts)) &&
8121         !(isBSwapHWordElement(N00, Parts) && isBSwapHWordPair(N01, Parts)))
8122       return SDValue();
8123   } else {
8124     return SDValue();
8125   }
8126 
8127   // Make sure the parts are all coming from the same node.
8128   if (Parts[0] != Parts[1] || Parts[0] != Parts[2] || Parts[0] != Parts[3])
8129     return SDValue();
8130 
8131   SDLoc DL(N);
8132   SDValue BSwap = DAG.getNode(ISD::BSWAP, DL, VT,
8133                               SDValue(Parts[0], 0));
8134 
8135   // Result of the bswap should be rotated by 16. If it's not legal, then
8136   // do  (x << 16) | (x >> 16).
8137   SDValue ShAmt = DAG.getShiftAmountConstant(16, VT, DL);
8138   if (TLI.isOperationLegalOrCustom(ISD::ROTL, VT))
8139     return DAG.getNode(ISD::ROTL, DL, VT, BSwap, ShAmt);
8140   if (TLI.isOperationLegalOrCustom(ISD::ROTR, VT))
8141     return DAG.getNode(ISD::ROTR, DL, VT, BSwap, ShAmt);
8142   return DAG.getNode(ISD::OR, DL, VT,
8143                      DAG.getNode(ISD::SHL, DL, VT, BSwap, ShAmt),
8144                      DAG.getNode(ISD::SRL, DL, VT, BSwap, ShAmt));
8145 }
8146 
8147 /// This contains all DAGCombine rules which reduce two values combined by
8148 /// an Or operation to a single value \see visitANDLike().
visitORLike(SDValue N0,SDValue N1,const SDLoc & DL)8149 SDValue DAGCombiner::visitORLike(SDValue N0, SDValue N1, const SDLoc &DL) {
8150   EVT VT = N1.getValueType();
8151 
8152   // fold (or x, undef) -> -1
8153   if (!LegalOperations && (N0.isUndef() || N1.isUndef()))
8154     return DAG.getAllOnesConstant(DL, VT);
8155 
8156   if (SDValue V = foldLogicOfSetCCs(false, N0, N1, DL))
8157     return V;
8158 
8159   // (or (and X, C1), (and Y, C2))  -> (and (or X, Y), C3) if possible.
8160   if (N0.getOpcode() == ISD::AND && N1.getOpcode() == ISD::AND &&
8161       // Don't increase # computations.
8162       (N0->hasOneUse() || N1->hasOneUse())) {
8163     // We can only do this xform if we know that bits from X that are set in C2
8164     // but not in C1 are already zero.  Likewise for Y.
8165     if (const ConstantSDNode *N0O1C =
8166         getAsNonOpaqueConstant(N0.getOperand(1))) {
8167       if (const ConstantSDNode *N1O1C =
8168           getAsNonOpaqueConstant(N1.getOperand(1))) {
8169         // We can only do this xform if we know that bits from X that are set in
8170         // C2 but not in C1 are already zero.  Likewise for Y.
8171         const APInt &LHSMask = N0O1C->getAPIntValue();
8172         const APInt &RHSMask = N1O1C->getAPIntValue();
8173 
8174         if (DAG.MaskedValueIsZero(N0.getOperand(0), RHSMask&~LHSMask) &&
8175             DAG.MaskedValueIsZero(N1.getOperand(0), LHSMask&~RHSMask)) {
8176           SDValue X = DAG.getNode(ISD::OR, SDLoc(N0), VT,
8177                                   N0.getOperand(0), N1.getOperand(0));
8178           return DAG.getNode(ISD::AND, DL, VT, X,
8179                              DAG.getConstant(LHSMask | RHSMask, DL, VT));
8180         }
8181       }
8182     }
8183   }
8184 
8185   // (or (and X, M), (and X, N)) -> (and X, (or M, N))
8186   if (N0.getOpcode() == ISD::AND &&
8187       N1.getOpcode() == ISD::AND &&
8188       N0.getOperand(0) == N1.getOperand(0) &&
8189       // Don't increase # computations.
8190       (N0->hasOneUse() || N1->hasOneUse())) {
8191     SDValue X = DAG.getNode(ISD::OR, SDLoc(N0), VT,
8192                             N0.getOperand(1), N1.getOperand(1));
8193     return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), X);
8194   }
8195 
8196   return SDValue();
8197 }
8198 
/// OR combines for which the commuted variant will be tried as well.
///
/// \p N is the OR node, and \p N0 / \p N1 are its operands in the order this
/// call should treat them. The caller is expected to invoke this a second
/// time with N0 and N1 swapped. Folds are tried in order; the first that
/// applies returns its replacement. Returns an empty SDValue if none apply.
static SDValue visitORCommutative(SelectionDAG &DAG, SDValue N0, SDValue N1,
                                  SDNode *N) {
  EVT VT = N0.getValueType();
  unsigned BW = VT.getScalarSizeInBits();
  SDLoc DL(N);

  // Look through a single ZERO_EXTEND or TRUNCATE so the folds below also
  // match when the AND / its partner sit behind a resize.
  auto peekThroughResize = [](SDValue V) {
    if (V->getOpcode() == ISD::ZERO_EXTEND || V->getOpcode() == ISD::TRUNCATE)
      return V->getOperand(0);
    return V;
  };

  SDValue N0Resized = peekThroughResize(N0);
  if (N0Resized.getOpcode() == ISD::AND) {
    SDValue N1Resized = peekThroughResize(N1);
    SDValue N00 = N0Resized.getOperand(0);
    SDValue N01 = N0Resized.getOperand(1);

    // fold or (and x, y), x --> x
    if (N00 == N1Resized || N01 == N1Resized)
      return N1;

    // fold (or (and X, (xor Y, -1)), Y) -> (or X, Y)
    // TODO: Set AllowUndefs = true.
    // The surviving AND operand is zext'ed/trunc'ed back to VT because it may
    // have come from behind a resize.
    if (SDValue NotOperand = getBitwiseNotOperand(N01, N00,
                                                  /* AllowUndefs */ false)) {
      if (peekThroughResize(NotOperand) == N1Resized)
        return DAG.getNode(ISD::OR, DL, VT, DAG.getZExtOrTrunc(N00, DL, VT),
                           N1);
    }

    // fold (or (and (xor Y, -1), X), Y) -> (or X, Y)
    if (SDValue NotOperand = getBitwiseNotOperand(N00, N01,
                                                  /* AllowUndefs */ false)) {
      if (peekThroughResize(NotOperand) == N1Resized)
        return DAG.getNode(ISD::OR, DL, VT, DAG.getZExtOrTrunc(N01, DL, VT),
                           N1);
    }
  }

  // Bound by the sd_match patterns below. Each successful match rebinds them,
  // so the order of the following folds matters.
  SDValue X, Y;

  // fold or (xor X, N1), N1 --> or X, N1
  if (sd_match(N0, m_Xor(m_Value(X), m_Specific(N1))))
    return DAG.getNode(ISD::OR, DL, VT, X, N1);

  // fold or (xor x, y), (x and/or y) --> or x, y
  if (sd_match(N0, m_Xor(m_Value(X), m_Value(Y))) &&
      (sd_match(N1, m_And(m_Specific(X), m_Specific(Y))) ||
       sd_match(N1, m_Or(m_Specific(X), m_Specific(Y)))))
    return DAG.getNode(ISD::OR, DL, VT, X, Y);

  if (SDValue R = foldLogicOfShifts(N, N0, N1, DAG))
    return R;

  // Shift amounts may differ only by a ZERO_EXTEND; look through it when
  // comparing them.
  auto peekThroughZext = [](SDValue V) {
    if (V->getOpcode() == ISD::ZERO_EXTEND)
      return V->getOperand(0);
    return V;
  };

  // (fshl X, ?, Y) | (shl X, Y) --> fshl X, ?, Y
  if (N0.getOpcode() == ISD::FSHL && N1.getOpcode() == ISD::SHL &&
      N0.getOperand(0) == N1.getOperand(0) &&
      peekThroughZext(N0.getOperand(2)) == peekThroughZext(N1.getOperand(1)))
    return N0;

  // (fshr ?, X, Y) | (srl X, Y) --> fshr ?, X, Y
  if (N0.getOpcode() == ISD::FSHR && N1.getOpcode() == ISD::SRL &&
      N0.getOperand(1) == N1.getOperand(0) &&
      peekThroughZext(N0.getOperand(2)) == peekThroughZext(N1.getOperand(1)))
    return N0;

  // Attempt to match a legalized build_pair-esque pattern:
  // or(shl(aext(Hi),BW/2),zext(Lo))
  // Lo must be exactly half the width of VT, and Hi must have Lo's type.
  SDValue Lo, Hi;
  if (sd_match(N0,
               m_OneUse(m_Shl(m_AnyExt(m_Value(Hi)), m_SpecificInt(BW / 2)))) &&
      sd_match(N1, m_ZExt(m_Value(Lo))) &&
      Lo.getScalarValueSizeInBits() == (BW / 2) &&
      Lo.getValueType() == Hi.getValueType()) {
    // Fold build_pair(not(Lo),not(Hi)) -> not(build_pair(Lo,Hi)).
    // Both NOTs must be single-use so they actually disappear.
    SDValue NotLo, NotHi;
    if (sd_match(Lo, m_OneUse(m_Not(m_Value(NotLo)))) &&
        sd_match(Hi, m_OneUse(m_Not(m_Value(NotHi))))) {
      Lo = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, NotLo);
      Hi = DAG.getNode(ISD::ANY_EXTEND, DL, VT, NotHi);
      Hi = DAG.getNode(ISD::SHL, DL, VT, Hi,
                       DAG.getShiftAmountConstant(BW / 2, VT, DL));
      return DAG.getNOT(DL, DAG.getNode(ISD::OR, DL, VT, Lo, Hi), VT);
    }
  }

  return SDValue();
}
8295 
/// Top-level DAG combine for an ISD::OR node \p N.
///
/// Tries a fixed sequence of folds and returns the first replacement produced.
/// Returns SDValue(N, 0) when SimplifyDemandedBits reports that it changed N,
/// and an empty SDValue when nothing applied. The order of the folds is
/// deliberate; for example, the halfword-bswap matchers run before the
/// generic shift-tree fold (see the comment near the end).
SDValue DAGCombiner::visitOR(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N1.getValueType();
  SDLoc DL(N);

  // x | x --> x
  if (N0 == N1)
    return N0;

  // fold (or c1, c2) -> c1|c2
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::OR, DL, VT, {N0, N1}))
    return C;

  // canonicalize constant to RHS
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
    return DAG.getNode(ISD::OR, DL, VT, N1, N0);

  // fold vector ops
  if (VT.isVector()) {
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

    // fold (or x, 0) -> x, vector edition
    if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
      return N0;

    // fold (or x, -1) -> -1, vector edition
    if (ISD::isConstantSplatVectorAllOnes(N1.getNode()))
      // do not return N1, because undef node may exist in N1
      return DAG.getAllOnesConstant(DL, N1.getValueType());

    // fold (or (shuf A, V_0, MA), (shuf B, V_0, MB)) -> (shuf A, B, Mask)
    // Do this only if the resulting type / shuffle is legal.
    auto *SV0 = dyn_cast<ShuffleVectorSDNode>(N0);
    auto *SV1 = dyn_cast<ShuffleVectorSDNode>(N1);
    if (SV0 && SV1 && TLI.isTypeLegal(VT)) {
      bool ZeroN00 = ISD::isBuildVectorAllZeros(N0.getOperand(0).getNode());
      bool ZeroN01 = ISD::isBuildVectorAllZeros(N0.getOperand(1).getNode());
      bool ZeroN10 = ISD::isBuildVectorAllZeros(N1.getOperand(0).getNode());
      bool ZeroN11 = ISD::isBuildVectorAllZeros(N1.getOperand(1).getNode());
      // Ensure both shuffles have a zero input.
      if ((ZeroN00 != ZeroN01) && (ZeroN10 != ZeroN11)) {
        assert((!ZeroN00 || !ZeroN01) && "Both inputs zero!");
        assert((!ZeroN10 || !ZeroN11) && "Both inputs zero!");
        bool CanFold = true;
        int NumElts = VT.getVectorNumElements();
        SmallVector<int, 4> Mask(NumElts, -1);

        for (int i = 0; i != NumElts; ++i) {
          int M0 = SV0->getMaskElt(i);
          int M1 = SV1->getMaskElt(i);

          // Determine if either index is pointing to a zero vector.
          // (M < NumElts) means the lane reads shuffle operand 0. The lane is
          // zero when that operand is the all-zeros one. Undef (M < 0) is
          // also treated as zero here.
          bool M0Zero = M0 < 0 || (ZeroN00 == (M0 < NumElts));
          bool M1Zero = M1 < 0 || (ZeroN10 == (M1 < NumElts));

          // If one element is zero and the otherside is undef, keep undef.
          // This also handles the case that both are undef.
          if ((M0Zero && M1 < 0) || (M1Zero && M0 < 0))
            continue;

          // Make sure only one of the elements is zero.
          if (M0Zero == M1Zero) {
            CanFold = false;
            break;
          }

          assert((M0 >= 0 || M1 >= 0) && "Undef index!");

          // We have a zero and non-zero element. If the non-zero came from
          // SV0 make the index a LHS index. If it came from SV1, make it
          // a RHS index. We need to mod by NumElts because we don't care
          // which operand it came from in the original shuffles.
          Mask[i] = M1Zero ? M0 % NumElts : (M1 % NumElts) + NumElts;
        }

        if (CanFold) {
          // Keep the non-zero operand of each original shuffle.
          SDValue NewLHS = ZeroN00 ? N0.getOperand(1) : N0.getOperand(0);
          SDValue NewRHS = ZeroN10 ? N1.getOperand(1) : N1.getOperand(0);
          SDValue LegalShuffle =
              TLI.buildLegalVectorShuffle(VT, DL, NewLHS, NewRHS, Mask, DAG);
          if (LegalShuffle)
            return LegalShuffle;
        }
      }
    }
  }

  // fold (or x, 0) -> x
  if (isNullConstant(N1))
    return N0;

  // fold (or x, -1) -> -1
  if (isAllOnesConstant(N1))
    return N1;

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  // fold (or x, c) -> c iff (x & ~c) == 0
  ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
  if (N1C && DAG.MaskedValueIsZero(N0, ~N1C->getAPIntValue()))
    return N1;

  if (SDValue R = foldAndOrOfSETCC(N, DAG))
    return R;

  if (SDValue Combined = visitORLike(N0, N1, DL))
    return Combined;

  if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
    return Combined;

  // Recognize halfword bswaps as (bswap + rotl 16) or (bswap + shl 16)
  if (SDValue BSwap = MatchBSwapHWord(N, N0, N1))
    return BSwap;
  if (SDValue BSwap = MatchBSwapHWordLow(N, N0, N1))
    return BSwap;

  // reassociate or
  if (SDValue ROR = reassociateOps(ISD::OR, DL, N0, N1, N->getFlags()))
    return ROR;

  // Fold or(vecreduce(x), vecreduce(y)) -> vecreduce(or(x, y))
  if (SDValue SD =
          reassociateReduction(ISD::VECREDUCE_OR, ISD::OR, DL, VT, N0, N1))
    return SD;

  // Canonicalize (or (and X, c1), c2) -> (and (or X, c2), c1|c2)
  // iff (c1 & c2) != 0 or c1/c2 are undef.
  // NOTE(review): the trailing `true` passed to matchBinaryPredicate is
  // presumably what admits undef elements; confirm against its declaration.
  auto MatchIntersect = [](ConstantSDNode *C1, ConstantSDNode *C2) {
    return !C1 || !C2 || C1->getAPIntValue().intersects(C2->getAPIntValue());
  };
  if (N0.getOpcode() == ISD::AND && N0->hasOneUse() &&
      ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchIntersect, true)) {
    if (SDValue COR = DAG.FoldConstantArithmetic(ISD::OR, SDLoc(N1), VT,
                                                 {N1, N0.getOperand(1)})) {
      SDValue IOR = DAG.getNode(ISD::OR, SDLoc(N0), VT, N0.getOperand(0), N1);
      AddToWorklist(IOR.getNode());
      // The folded constant is passed as the first AND operand here.
      return DAG.getNode(ISD::AND, DL, VT, COR, IOR);
    }
  }

  if (SDValue Combined = visitORCommutative(DAG, N0, N1, N))
    return Combined;
  if (SDValue Combined = visitORCommutative(DAG, N1, N0, N))
    return Combined;

  // Simplify: (or (op x...), (op y...))  -> (op (or x, y))
  if (N0.getOpcode() == N1.getOpcode())
    if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
      return V;

  // See if this is some rotate idiom.
  if (SDValue Rot = MatchRotate(N0, N1, DL, /*FromAdd=*/false))
    return Rot;

  if (SDValue Load = MatchLoadCombine(N))
    return Load;

  // Simplify the operands using demanded-bits information.
  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  // If OR can be rewritten into ADD, try combines based on ADD.
  if ((!LegalOperations || TLI.isOperationLegal(ISD::ADD, VT)) &&
      DAG.isADDLike(SDValue(N, 0)))
    if (SDValue Combined = visitADDLike(N))
      return Combined;

  // Postpone until legalization completed to avoid interference with bswap
  // folding
  if (LegalOperations || VT.isVector())
    if (SDValue R = foldLogicTreeOfShifts(N, N0, N1, DAG))
      return R;

  if (VT.isScalarInteger() && VT != MVT::i1)
    if (SDValue R = foldMaskedMerge(N, DAG, TLI, DL))
      return R;

  return SDValue();
}
8480 
stripConstantMask(const SelectionDAG & DAG,SDValue Op,SDValue & Mask)8481 static SDValue stripConstantMask(const SelectionDAG &DAG, SDValue Op,
8482                                  SDValue &Mask) {
8483   if (Op.getOpcode() == ISD::AND &&
8484       DAG.isConstantIntBuildVectorOrConstantInt(Op.getOperand(1))) {
8485     Mask = Op.getOperand(1);
8486     return Op.getOperand(0);
8487   }
8488   return Op;
8489 }
8490 
8491 /// Match "(X shl/srl V1) & V2" where V2 may not be present.
matchRotateHalf(const SelectionDAG & DAG,SDValue Op,SDValue & Shift,SDValue & Mask)8492 static bool matchRotateHalf(const SelectionDAG &DAG, SDValue Op, SDValue &Shift,
8493                             SDValue &Mask) {
8494   Op = stripConstantMask(DAG, Op, Mask);
8495   if (Op.getOpcode() == ISD::SRL || Op.getOpcode() == ISD::SHL) {
8496     Shift = Op;
8497     return true;
8498   }
8499   return false;
8500 }
8501 
/// Helper function for visitOR to extract the needed side of a rotate idiom
/// from a shl/srl/mul/udiv.  This is meant to handle cases where
/// InstCombine merged some outside op with one of the shifts from
/// the rotate pattern.
/// \returns An empty \c SDValue if the needed shift couldn't be extracted.
/// Otherwise, returns an expansion of \p ExtractFrom based on the following
/// patterns:
///
///   (or (add v v) (shrl v bitwidth-1)):
///     expands (add v v) -> (shl v 1)
///
///   (or (mul v c0) (shrl (mul v c1) c2)):
///     expands (mul v c0) -> (shl (mul v c1) c3)
///
///   (or (udiv v c0) (shl (udiv v c1) c2)):
///     expands (udiv v c0) -> (shrl (udiv v c1) c3)
///
///   (or (shl v c0) (shrl (shl v c1) c2)):
///     expands (shl v c0) -> (shl (shl v c1) c3)
///
///   (or (shrl v c0) (shl (shrl v c1) c2)):
///     expands (shrl v c0) -> (shrl (shrl v c1) c3)
///
/// Such that in all cases, c3+c2==bitwidth(op v c1).
///
/// \p OppShift is the shift already present on the other side of the OR and
/// \p ExtractFrom is the value to re-express. If ExtractFrom is (and ..., C)
/// with a constant C, the AND is stripped and C is written to \p Mask (via
/// stripConstantMask); otherwise \p Mask is left unchanged.
static SDValue extractShiftForRotate(SelectionDAG &DAG, SDValue OppShift,
                                     SDValue ExtractFrom, SDValue &Mask,
                                     const SDLoc &DL) {
  assert(OppShift && ExtractFrom && "Empty SDValue");
  if (OppShift.getOpcode() != ISD::SHL && OppShift.getOpcode() != ISD::SRL)
    return SDValue();

  ExtractFrom = stripConstantMask(DAG, ExtractFrom, Mask);

  // Value and Type of the shift.
  SDValue OppShiftLHS = OppShift.getOperand(0);
  EVT ShiftedVT = OppShiftLHS.getValueType();

  // Amount of the existing shift.
  ConstantSDNode *OppShiftCst = isConstOrConstSplat(OppShift.getOperand(1));

  // (add v v) -> (shl v 1)
  // TODO: Should this be a general DAG canonicalization?
  if (OppShift.getOpcode() == ISD::SRL && OppShiftCst &&
      ExtractFrom.getOpcode() == ISD::ADD &&
      ExtractFrom.getOperand(0) == ExtractFrom.getOperand(1) &&
      ExtractFrom.getOperand(0) == OppShiftLHS &&
      OppShiftCst->getAPIntValue() == ShiftedVT.getScalarSizeInBits() - 1)
    return DAG.getNode(ISD::SHL, DL, ShiftedVT, OppShiftLHS,
                       DAG.getShiftAmountConstant(1, ShiftedVT, DL));

  // Preconditions:
  //    (or (op0 v c0) (shiftl/r (op0 v c1) c2))
  //
  // Find opcode of the needed shift to be extracted from (op0 v c0).
  unsigned Opcode = ISD::DELETED_NODE;
  bool IsMulOrDiv = false;
  // Set Opcode and IsMulOrDiv if the extract opcode matches the needed shift
  // opcode or its arithmetic (mul or udiv) variant.
  auto SelectOpcode = [&](unsigned NeededShift, unsigned MulOrDivVariant) {
    IsMulOrDiv = ExtractFrom.getOpcode() == MulOrDivVariant;
    if (!IsMulOrDiv && ExtractFrom.getOpcode() != NeededShift)
      return false;
    Opcode = NeededShift;
    return true;
  };
  // op0 must be either the needed shift opcode or the mul/udiv equivalent
  // that the needed shift can be extracted from.
  // An SRL on the other side needs an SHL (or MUL) here, and an SHL needs an
  // SRL (or UDIV).
  if ((OppShift.getOpcode() != ISD::SRL || !SelectOpcode(ISD::SHL, ISD::MUL)) &&
      (OppShift.getOpcode() != ISD::SHL || !SelectOpcode(ISD::SRL, ISD::UDIV)))
    return SDValue();

  // op0 must be the same opcode on both sides, have the same LHS argument,
  // and produce the same value type.
  if (OppShiftLHS.getOpcode() != ExtractFrom.getOpcode() ||
      OppShiftLHS.getOperand(0) != ExtractFrom.getOperand(0) ||
      ShiftedVT != ExtractFrom.getValueType())
    return SDValue();

  // Constant mul/udiv/shift amount from the RHS of the shift's LHS op.
  ConstantSDNode *OppLHSCst = isConstOrConstSplat(OppShiftLHS.getOperand(1));
  // Constant mul/udiv/shift amount from the RHS of the ExtractFrom op.
  ConstantSDNode *ExtractFromCst =
      isConstOrConstSplat(ExtractFrom.getOperand(1));
  // TODO: We should be able to handle non-uniform constant vectors for these values
  // Check that we have constant values.
  // All three constants must also be non-zero.
  if (!OppShiftCst || !OppShiftCst->getAPIntValue() ||
      !OppLHSCst || !OppLHSCst->getAPIntValue() ||
      !ExtractFromCst || !ExtractFromCst->getAPIntValue())
    return SDValue();

  // Compute the shift amount we need to extract to complete the rotate.
  // This is c3 = bitwidth - c2 in the header's notation.
  const unsigned VTWidth = ShiftedVT.getScalarSizeInBits();
  if (OppShiftCst->getAPIntValue().ugt(VTWidth))
    return SDValue();
  APInt NeededShiftAmt = VTWidth - OppShiftCst->getAPIntValue();
  // Normalize the bitwidth of the two mul/udiv/shift constant operands.
  APInt ExtractFromAmt = ExtractFromCst->getAPIntValue();
  APInt OppLHSAmt = OppLHSCst->getAPIntValue();
  zeroExtendToMatch(ExtractFromAmt, OppLHSAmt);

  // Now try extract the needed shift from the ExtractFrom op and see if the
  // result matches up with the existing shift's LHS op.
  if (IsMulOrDiv) {
    // Op to extract from is a mul or udiv by a constant.
    // Check (header notation, c3 = bitwidth - c2):
    //     c0 / (1 << c3) == c1
    //     c0 % (1 << c3) == 0
    const APInt ExtractDiv = APInt::getOneBitSet(ExtractFromAmt.getBitWidth(),
                                                 NeededShiftAmt.getZExtValue());
    APInt ResultAmt;
    APInt Rem;
    APInt::udivrem(ExtractFromAmt, ExtractDiv, ResultAmt, Rem);
    if (Rem != 0 || ResultAmt != OppLHSAmt)
      return SDValue();
  } else {
    // Op to extract from is a shift by a constant.
    // Check (header notation, c3 = bitwidth - c2):
    //      c0 - c3 == c1
    if (OppLHSAmt != ExtractFromAmt - NeededShiftAmt.zextOrTrunc(
                                          ExtractFromAmt.getBitWidth()))
      return SDValue();
  }

  // Return the expanded shift op that should allow a rotate to be formed.
  EVT ShiftVT = OppShift.getOperand(1).getValueType();
  EVT ResVT = ExtractFrom.getValueType();
  SDValue NewShiftNode = DAG.getConstant(NeededShiftAmt, DL, ShiftVT);
  return DAG.getNode(Opcode, DL, ResVT, OppShiftLHS, NewShiftNode);
}
8631 
8632 // Return true if we can prove that, whenever Neg and Pos are both in the
8633 // range [0, EltSize), Neg == (Pos == 0 ? 0 : EltSize - Pos).  This means that
8634 // for two opposing shifts shift1 and shift2 and a value X with OpBits bits:
8635 //
8636 //     (or (shift1 X, Neg), (shift2 X, Pos))
8637 //
8638 // reduces to a rotate in direction shift2 by Pos or (equivalently) a rotate
8639 // in direction shift1 by Neg.  The range [0, EltSize) means that we only need
8640 // to consider shift amounts with defined behavior.
8641 //
8642 // The IsRotate flag should be set when the LHS of both shifts is the same.
8643 // Otherwise if matching a general funnel shift, it should be clear.
matchRotateSub(SDValue Pos,SDValue Neg,unsigned EltSize,SelectionDAG & DAG,bool IsRotate,bool FromAdd)8644 static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize,
8645                            SelectionDAG &DAG, bool IsRotate, bool FromAdd) {
8646   const auto &TLI = DAG.getTargetLoweringInfo();
8647   // If EltSize is a power of 2 then:
8648   //
8649   //  (a) (Pos == 0 ? 0 : EltSize - Pos) == (EltSize - Pos) & (EltSize - 1)
8650   //  (b) Neg == Neg & (EltSize - 1) whenever Neg is in [0, EltSize).
8651   //
8652   // So if EltSize is a power of 2 and Neg is (and Neg', EltSize-1), we check
8653   // for the stronger condition:
8654   //
8655   //     Neg & (EltSize - 1) == (EltSize - Pos) & (EltSize - 1)    [A]
8656   //
8657   // for all Neg and Pos.  Since Neg & (EltSize - 1) == Neg' & (EltSize - 1)
8658   // we can just replace Neg with Neg' for the rest of the function.
8659   //
8660   // In other cases we check for the even stronger condition:
8661   //
8662   //     Neg == EltSize - Pos                                    [B]
8663   //
8664   // for all Neg and Pos.  Note that the (or ...) then invokes undefined
8665   // behavior if Pos == 0 (and consequently Neg == EltSize).
8666   //
8667   // We could actually use [A] whenever EltSize is a power of 2, but the
8668   // only extra cases that it would match are those uninteresting ones
8669   // where Neg and Pos are never in range at the same time.  E.g. for
8670   // EltSize == 32, using [A] would allow a Neg of the form (sub 64, Pos)
8671   // as well as (sub 32, Pos), but:
8672   //
8673   //     (or (shift1 X, (sub 64, Pos)), (shift2 X, Pos))
8674   //
8675   // always invokes undefined behavior for 32-bit X.
8676   //
8677   // Below, Mask == EltSize - 1 when using [A] and is all-ones otherwise.
8678   // This allows us to peek through any operations that only affect Mask's
8679   // un-demanded bits.
8680   //
8681   // NOTE: We can only do this when matching operations which won't modify the
8682   // least Log2(EltSize) significant bits and not a general funnel shift.
8683   unsigned MaskLoBits = 0;
8684   if (IsRotate && !FromAdd && isPowerOf2_64(EltSize)) {
8685     unsigned Bits = Log2_64(EltSize);
8686     unsigned NegBits = Neg.getScalarValueSizeInBits();
8687     if (NegBits >= Bits) {
8688       APInt DemandedBits = APInt::getLowBitsSet(NegBits, Bits);
8689       if (SDValue Inner =
8690               TLI.SimplifyMultipleUseDemandedBits(Neg, DemandedBits, DAG)) {
8691         Neg = Inner;
8692         MaskLoBits = Bits;
8693       }
8694     }
8695   }
8696 
8697   // Check whether Neg has the form (sub NegC, NegOp1) for some NegC and NegOp1.
8698   if (Neg.getOpcode() != ISD::SUB)
8699     return false;
8700   ConstantSDNode *NegC = isConstOrConstSplat(Neg.getOperand(0));
8701   if (!NegC)
8702     return false;
8703   SDValue NegOp1 = Neg.getOperand(1);
8704 
8705   // On the RHS of [A], if Pos is the result of operation on Pos' that won't
8706   // affect Mask's demanded bits, just replace Pos with Pos'. These operations
8707   // are redundant for the purpose of the equality.
8708   if (MaskLoBits) {
8709     unsigned PosBits = Pos.getScalarValueSizeInBits();
8710     if (PosBits >= MaskLoBits) {
8711       APInt DemandedBits = APInt::getLowBitsSet(PosBits, MaskLoBits);
8712       if (SDValue Inner =
8713               TLI.SimplifyMultipleUseDemandedBits(Pos, DemandedBits, DAG)) {
8714         Pos = Inner;
8715       }
8716     }
8717   }
8718 
8719   // The condition we need is now:
8720   //
8721   //     (NegC - NegOp1) & Mask == (EltSize - Pos) & Mask
8722   //
8723   // If NegOp1 == Pos then we need:
8724   //
8725   //              EltSize & Mask == NegC & Mask
8726   //
8727   // (because "x & Mask" is a truncation and distributes through subtraction).
8728   //
8729   // We also need to account for a potential truncation of NegOp1 if the amount
8730   // has already been legalized to a shift amount type.
8731   APInt Width;
8732   if ((Pos == NegOp1) ||
8733       (NegOp1.getOpcode() == ISD::TRUNCATE && Pos == NegOp1.getOperand(0)))
8734     Width = NegC->getAPIntValue();
8735 
8736   // Check for cases where Pos has the form (add NegOp1, PosC) for some PosC.
8737   // Then the condition we want to prove becomes:
8738   //
8739   //     (NegC - NegOp1) & Mask == (EltSize - (NegOp1 + PosC)) & Mask
8740   //
8741   // which, again because "x & Mask" is a truncation, becomes:
8742   //
8743   //                NegC & Mask == (EltSize - PosC) & Mask
8744   //             EltSize & Mask == (NegC + PosC) & Mask
8745   else if (Pos.getOpcode() == ISD::ADD && Pos.getOperand(0) == NegOp1) {
8746     if (ConstantSDNode *PosC = isConstOrConstSplat(Pos.getOperand(1)))
8747       Width = PosC->getAPIntValue() + NegC->getAPIntValue();
8748     else
8749       return false;
8750   } else
8751     return false;
8752 
8753   // Now we just need to check that EltSize & Mask == Width & Mask.
8754   if (MaskLoBits)
8755     // EltSize & Mask is 0 since Mask is EltSize - 1.
8756     return Width.getLoBits(MaskLoBits) == 0;
8757   return Width == EltSize;
8758 }
8759 
8760 // A subroutine of MatchRotate used once we have found an OR of two opposite
8761 // shifts of Shifted.  If Neg == <operand size> - Pos then the OR reduces
8762 // to both (PosOpcode Shifted, Pos) and (NegOpcode Shifted, Neg), with the
8763 // former being preferred if supported.  InnerPos and InnerNeg are Pos and
8764 // Neg with outer conversions stripped away.
MatchRotatePosNeg(SDValue Shifted,SDValue Pos,SDValue Neg,SDValue InnerPos,SDValue InnerNeg,bool FromAdd,bool HasPos,unsigned PosOpcode,unsigned NegOpcode,const SDLoc & DL)8765 SDValue DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos,
8766                                        SDValue Neg, SDValue InnerPos,
8767                                        SDValue InnerNeg, bool FromAdd,
8768                                        bool HasPos, unsigned PosOpcode,
8769                                        unsigned NegOpcode, const SDLoc &DL) {
8770   // fold (or/add (shl x, (*ext y)),
8771   //              (srl x, (*ext (sub 32, y)))) ->
8772   //   (rotl x, y) or (rotr x, (sub 32, y))
8773   //
8774   // fold (or/add (shl x, (*ext (sub 32, y))),
8775   //              (srl x, (*ext y))) ->
8776   //   (rotr x, y) or (rotl x, (sub 32, y))
8777   EVT VT = Shifted.getValueType();
8778   if (matchRotateSub(InnerPos, InnerNeg, VT.getScalarSizeInBits(), DAG,
8779                      /*IsRotate*/ true, FromAdd))
8780     return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, Shifted,
8781                        HasPos ? Pos : Neg);
8782 
8783   return SDValue();
8784 }
8785 
8786 // A subroutine of MatchRotate used once we have found an OR of two opposite
8787 // shifts of N0 + N1.  If Neg == <operand size> - Pos then the OR reduces
8788 // to both (PosOpcode N0, N1, Pos) and (NegOpcode N0, N1, Neg), with the
8789 // former being preferred if supported.  InnerPos and InnerNeg are Pos and
8790 // Neg with outer conversions stripped away.
8791 // TODO: Merge with MatchRotatePosNeg.
MatchFunnelPosNeg(SDValue N0,SDValue N1,SDValue Pos,SDValue Neg,SDValue InnerPos,SDValue InnerNeg,bool FromAdd,bool HasPos,unsigned PosOpcode,unsigned NegOpcode,const SDLoc & DL)8792 SDValue DAGCombiner::MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos,
8793                                        SDValue Neg, SDValue InnerPos,
8794                                        SDValue InnerNeg, bool FromAdd,
8795                                        bool HasPos, unsigned PosOpcode,
8796                                        unsigned NegOpcode, const SDLoc &DL) {
8797   EVT VT = N0.getValueType();
8798   unsigned EltBits = VT.getScalarSizeInBits();
8799 
8800   // fold (or/add (shl x0, (*ext y)),
8801   //              (srl x1, (*ext (sub 32, y)))) ->
8802   //   (fshl x0, x1, y) or (fshr x0, x1, (sub 32, y))
8803   //
8804   // fold (or/add (shl x0, (*ext (sub 32, y))),
8805   //              (srl x1, (*ext y))) ->
8806   //   (fshr x0, x1, y) or (fshl x0, x1, (sub 32, y))
8807   if (matchRotateSub(InnerPos, InnerNeg, EltBits, DAG, /*IsRotate*/ N0 == N1,
8808                      FromAdd))
8809     return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, N0, N1,
8810                        HasPos ? Pos : Neg);
8811 
8812   // Matching the shift+xor cases, we can't easily use the xor'd shift amount
8813   // so for now just use the PosOpcode case if its legal.
8814   // TODO: When can we use the NegOpcode case?
8815   if (PosOpcode == ISD::FSHL && isPowerOf2_32(EltBits)) {
8816     SDValue X;
8817     // fold (or/add (shl x0, y), (srl (srl x1, 1), (xor y, 31)))
8818     //   -> (fshl x0, x1, y)
8819     if (sd_match(N1, m_Srl(m_Value(X), m_One())) &&
8820         sd_match(InnerNeg,
8821                  m_Xor(m_Specific(InnerPos), m_SpecificInt(EltBits - 1))) &&
8822         TLI.isOperationLegalOrCustom(ISD::FSHL, VT)) {
8823       return DAG.getNode(ISD::FSHL, DL, VT, N0, X, Pos);
8824     }
8825 
8826     // fold (or/add (shl (shl x0, 1), (xor y, 31)), (srl x1, y))
8827     //   -> (fshr x0, x1, y)
8828     if (sd_match(N0, m_Shl(m_Value(X), m_One())) &&
8829         sd_match(InnerPos,
8830                  m_Xor(m_Specific(InnerNeg), m_SpecificInt(EltBits - 1))) &&
8831         TLI.isOperationLegalOrCustom(ISD::FSHR, VT)) {
8832       return DAG.getNode(ISD::FSHR, DL, VT, X, N1, Neg);
8833     }
8834 
8835     // fold (or/add (shl (add x0, x0), (xor y, 31)), (srl x1, y))
8836     //   -> (fshr x0, x1, y)
8837     // TODO: Should add(x,x) -> shl(x,1) be a general DAG canonicalization?
8838     if (sd_match(N0, m_Add(m_Value(X), m_Deferred(X))) &&
8839         sd_match(InnerPos,
8840                  m_Xor(m_Specific(InnerNeg), m_SpecificInt(EltBits - 1))) &&
8841         TLI.isOperationLegalOrCustom(ISD::FSHR, VT)) {
8842       return DAG.getNode(ISD::FSHR, DL, VT, X, N1, Neg);
8843     }
8844   }
8845 
8846   return SDValue();
8847 }
8848 
8849 // MatchRotate - Handle an 'or' or 'add' of two operands.  If this is one of the
8850 // many idioms for rotate, and if the target supports rotation instructions,
8851 // generate a rot[lr]. This also matches funnel shift patterns, similar to
8852 // rotation but with different shifted sources.
MatchRotate(SDValue LHS,SDValue RHS,const SDLoc & DL,bool FromAdd)8853 SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL,
8854                                  bool FromAdd) {
8855   EVT VT = LHS.getValueType();
8856 
8857   // The target must have at least one rotate/funnel flavor.
8858   // We still try to match rotate by constant pre-legalization.
8859   // TODO: Support pre-legalization funnel-shift by constant.
8860   bool HasROTL = hasOperation(ISD::ROTL, VT);
8861   bool HasROTR = hasOperation(ISD::ROTR, VT);
8862   bool HasFSHL = hasOperation(ISD::FSHL, VT);
8863   bool HasFSHR = hasOperation(ISD::FSHR, VT);
8864 
8865   // If the type is going to be promoted and the target has enabled custom
8866   // lowering for rotate, allow matching rotate by non-constants. Only allow
8867   // this for scalar types.
8868   if (VT.isScalarInteger() && TLI.getTypeAction(*DAG.getContext(), VT) ==
8869                                   TargetLowering::TypePromoteInteger) {
8870     HasROTL |= TLI.getOperationAction(ISD::ROTL, VT) == TargetLowering::Custom;
8871     HasROTR |= TLI.getOperationAction(ISD::ROTR, VT) == TargetLowering::Custom;
8872   }
8873 
8874   if (LegalOperations && !HasROTL && !HasROTR && !HasFSHL && !HasFSHR)
8875     return SDValue();
8876 
8877   // Check for truncated rotate.
8878   if (LHS.getOpcode() == ISD::TRUNCATE && RHS.getOpcode() == ISD::TRUNCATE &&
8879       LHS.getOperand(0).getValueType() == RHS.getOperand(0).getValueType()) {
8880     assert(LHS.getValueType() == RHS.getValueType());
8881     if (SDValue Rot =
8882             MatchRotate(LHS.getOperand(0), RHS.getOperand(0), DL, FromAdd))
8883       return DAG.getNode(ISD::TRUNCATE, SDLoc(LHS), LHS.getValueType(), Rot);
8884   }
8885 
8886   // Match "(X shl/srl V1) & V2" where V2 may not be present.
8887   SDValue LHSShift;   // The shift.
8888   SDValue LHSMask;    // AND value if any.
8889   matchRotateHalf(DAG, LHS, LHSShift, LHSMask);
8890 
8891   SDValue RHSShift;   // The shift.
8892   SDValue RHSMask;    // AND value if any.
8893   matchRotateHalf(DAG, RHS, RHSShift, RHSMask);
8894 
8895   // If neither side matched a rotate half, bail
8896   if (!LHSShift && !RHSShift)
8897     return SDValue();
8898 
8899   // InstCombine may have combined a constant shl, srl, mul, or udiv with one
8900   // side of the rotate, so try to handle that here. In all cases we need to
8901   // pass the matched shift from the opposite side to compute the opcode and
8902   // needed shift amount to extract.  We still want to do this if both sides
8903   // matched a rotate half because one half may be a potential overshift that
8904   // can be broken down (ie if InstCombine merged two shl or srl ops into a
8905   // single one).
8906 
8907   // Have LHS side of the rotate, try to extract the needed shift from the RHS.
8908   if (LHSShift)
8909     if (SDValue NewRHSShift =
8910             extractShiftForRotate(DAG, LHSShift, RHS, RHSMask, DL))
8911       RHSShift = NewRHSShift;
8912   // Have RHS side of the rotate, try to extract the needed shift from the LHS.
8913   if (RHSShift)
8914     if (SDValue NewLHSShift =
8915             extractShiftForRotate(DAG, RHSShift, LHS, LHSMask, DL))
8916       LHSShift = NewLHSShift;
8917 
8918   // If a side is still missing, nothing else we can do.
8919   if (!RHSShift || !LHSShift)
8920     return SDValue();
8921 
8922   // At this point we've matched or extracted a shift op on each side.
8923 
8924   if (LHSShift.getOpcode() == RHSShift.getOpcode())
8925     return SDValue(); // Shifts must disagree.
8926 
8927   // Canonicalize shl to left side in a shl/srl pair.
8928   if (RHSShift.getOpcode() == ISD::SHL) {
8929     std::swap(LHS, RHS);
8930     std::swap(LHSShift, RHSShift);
8931     std::swap(LHSMask, RHSMask);
8932   }
8933 
8934   // Something has gone wrong - we've lost the shl/srl pair - bail.
8935   if (LHSShift.getOpcode() != ISD::SHL || RHSShift.getOpcode() != ISD::SRL)
8936     return SDValue();
8937 
8938   unsigned EltSizeInBits = VT.getScalarSizeInBits();
8939   SDValue LHSShiftArg = LHSShift.getOperand(0);
8940   SDValue LHSShiftAmt = LHSShift.getOperand(1);
8941   SDValue RHSShiftArg = RHSShift.getOperand(0);
8942   SDValue RHSShiftAmt = RHSShift.getOperand(1);
8943 
8944   auto MatchRotateSum = [EltSizeInBits](ConstantSDNode *LHS,
8945                                         ConstantSDNode *RHS) {
8946     return (LHS->getAPIntValue() + RHS->getAPIntValue()) == EltSizeInBits;
8947   };
8948 
8949   auto ApplyMasks = [&](SDValue Res) {
8950     // If there is an AND of either shifted operand, apply it to the result.
8951     if (LHSMask.getNode() || RHSMask.getNode()) {
8952       SDValue AllOnes = DAG.getAllOnesConstant(DL, VT);
8953       SDValue Mask = AllOnes;
8954 
8955       if (LHSMask.getNode()) {
8956         SDValue RHSBits = DAG.getNode(ISD::SRL, DL, VT, AllOnes, RHSShiftAmt);
8957         Mask = DAG.getNode(ISD::AND, DL, VT, Mask,
8958                            DAG.getNode(ISD::OR, DL, VT, LHSMask, RHSBits));
8959       }
8960       if (RHSMask.getNode()) {
8961         SDValue LHSBits = DAG.getNode(ISD::SHL, DL, VT, AllOnes, LHSShiftAmt);
8962         Mask = DAG.getNode(ISD::AND, DL, VT, Mask,
8963                            DAG.getNode(ISD::OR, DL, VT, RHSMask, LHSBits));
8964       }
8965 
8966       Res = DAG.getNode(ISD::AND, DL, VT, Res, Mask);
8967     }
8968 
8969     return Res;
8970   };
8971 
8972   // TODO: Support pre-legalization funnel-shift by constant.
8973   bool IsRotate = LHSShiftArg == RHSShiftArg;
8974   if (!IsRotate && !(HasFSHL || HasFSHR)) {
8975     if (TLI.isTypeLegal(VT) && LHS.hasOneUse() && RHS.hasOneUse() &&
8976         ISD::matchBinaryPredicate(LHSShiftAmt, RHSShiftAmt, MatchRotateSum)) {
8977       // Look for a disguised rotate by constant.
8978       // The common shifted operand X may be hidden inside another 'or'.
8979       SDValue X, Y;
8980       auto matchOr = [&X, &Y](SDValue Or, SDValue CommonOp) {
8981         if (!Or.hasOneUse() || Or.getOpcode() != ISD::OR)
8982           return false;
8983         if (CommonOp == Or.getOperand(0)) {
8984           X = CommonOp;
8985           Y = Or.getOperand(1);
8986           return true;
8987         }
8988         if (CommonOp == Or.getOperand(1)) {
8989           X = CommonOp;
8990           Y = Or.getOperand(0);
8991           return true;
8992         }
8993         return false;
8994       };
8995 
8996       SDValue Res;
8997       if (matchOr(LHSShiftArg, RHSShiftArg)) {
8998         // (shl (X | Y), C1) | (srl X, C2) --> (rotl X, C1) | (shl Y, C1)
8999         SDValue RotX = DAG.getNode(ISD::ROTL, DL, VT, X, LHSShiftAmt);
9000         SDValue ShlY = DAG.getNode(ISD::SHL, DL, VT, Y, LHSShiftAmt);
9001         Res = DAG.getNode(ISD::OR, DL, VT, RotX, ShlY);
9002       } else if (matchOr(RHSShiftArg, LHSShiftArg)) {
9003         // (shl X, C1) | (srl (X | Y), C2) --> (rotl X, C1) | (srl Y, C2)
9004         SDValue RotX = DAG.getNode(ISD::ROTL, DL, VT, X, LHSShiftAmt);
9005         SDValue SrlY = DAG.getNode(ISD::SRL, DL, VT, Y, RHSShiftAmt);
9006         Res = DAG.getNode(ISD::OR, DL, VT, RotX, SrlY);
9007       } else {
9008         return SDValue();
9009       }
9010 
9011       return ApplyMasks(Res);
9012     }
9013 
9014     return SDValue(); // Requires funnel shift support.
9015   }
9016 
9017   // fold (or/add (shl x, C1), (srl x, C2)) -> (rotl x, C1)
9018   // fold (or/add (shl x, C1), (srl x, C2)) -> (rotr x, C2)
9019   // fold (or/add (shl x, C1), (srl y, C2)) -> (fshl x, y, C1)
9020   // fold (or/add (shl x, C1), (srl y, C2)) -> (fshr x, y, C2)
9021   // iff C1+C2 == EltSizeInBits
9022   if (ISD::matchBinaryPredicate(LHSShiftAmt, RHSShiftAmt, MatchRotateSum)) {
9023     SDValue Res;
9024     if (IsRotate && (HasROTL || HasROTR || !(HasFSHL || HasFSHR))) {
9025       bool UseROTL = !LegalOperations || HasROTL;
9026       Res = DAG.getNode(UseROTL ? ISD::ROTL : ISD::ROTR, DL, VT, LHSShiftArg,
9027                         UseROTL ? LHSShiftAmt : RHSShiftAmt);
9028     } else {
9029       bool UseFSHL = !LegalOperations || HasFSHL;
9030       Res = DAG.getNode(UseFSHL ? ISD::FSHL : ISD::FSHR, DL, VT, LHSShiftArg,
9031                         RHSShiftArg, UseFSHL ? LHSShiftAmt : RHSShiftAmt);
9032     }
9033 
9034     return ApplyMasks(Res);
9035   }
9036 
9037   // Even pre-legalization, we can't easily rotate/funnel-shift by a variable
9038   // shift.
9039   if (!HasROTL && !HasROTR && !HasFSHL && !HasFSHR)
9040     return SDValue();
9041 
9042   // If there is a mask here, and we have a variable shift, we can't be sure
9043   // that we're masking out the right stuff.
9044   if (LHSMask.getNode() || RHSMask.getNode())
9045     return SDValue();
9046 
9047   // If the shift amount is sign/zext/any-extended just peel it off.
9048   SDValue LExtOp0 = LHSShiftAmt;
9049   SDValue RExtOp0 = RHSShiftAmt;
9050   if ((LHSShiftAmt.getOpcode() == ISD::SIGN_EXTEND ||
9051        LHSShiftAmt.getOpcode() == ISD::ZERO_EXTEND ||
9052        LHSShiftAmt.getOpcode() == ISD::ANY_EXTEND ||
9053        LHSShiftAmt.getOpcode() == ISD::TRUNCATE) &&
9054       (RHSShiftAmt.getOpcode() == ISD::SIGN_EXTEND ||
9055        RHSShiftAmt.getOpcode() == ISD::ZERO_EXTEND ||
9056        RHSShiftAmt.getOpcode() == ISD::ANY_EXTEND ||
9057        RHSShiftAmt.getOpcode() == ISD::TRUNCATE)) {
9058     LExtOp0 = LHSShiftAmt.getOperand(0);
9059     RExtOp0 = RHSShiftAmt.getOperand(0);
9060   }
9061 
9062   if (IsRotate && (HasROTL || HasROTR)) {
9063     if (SDValue TryL = MatchRotatePosNeg(LHSShiftArg, LHSShiftAmt, RHSShiftAmt,
9064                                          LExtOp0, RExtOp0, FromAdd, HasROTL,
9065                                          ISD::ROTL, ISD::ROTR, DL))
9066       return TryL;
9067 
9068     if (SDValue TryR = MatchRotatePosNeg(RHSShiftArg, RHSShiftAmt, LHSShiftAmt,
9069                                          RExtOp0, LExtOp0, FromAdd, HasROTR,
9070                                          ISD::ROTR, ISD::ROTL, DL))
9071       return TryR;
9072   }
9073 
9074   if (SDValue TryL = MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, LHSShiftAmt,
9075                                        RHSShiftAmt, LExtOp0, RExtOp0, FromAdd,
9076                                        HasFSHL, ISD::FSHL, ISD::FSHR, DL))
9077     return TryL;
9078 
9079   if (SDValue TryR = MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, RHSShiftAmt,
9080                                        LHSShiftAmt, RExtOp0, LExtOp0, FromAdd,
9081                                        HasFSHR, ISD::FSHR, ISD::FSHL, DL))
9082     return TryR;
9083 
9084   return SDValue();
9085 }
9086 
9087 /// Recursively traverses the expression calculating the origin of the requested
9088 /// byte of the given value. Returns std::nullopt if the provider can't be
9089 /// calculated.
9090 ///
9091 /// For all the values except the root of the expression, we verify that the
9092 /// value has exactly one use and if not then return std::nullopt. This way if
9093 /// the origin of the byte is returned it's guaranteed that the values which
9094 /// contribute to the byte are not used outside of this expression.
9095 
9096 /// However, there is a special case when dealing with vector loads -- we allow
9097 /// more than one use if the load is a vector type.  Since the values that
9098 /// contribute to the byte ultimately come from the ExtractVectorElements of the
9099 /// Load, we don't care if the Load has uses other than ExtractVectorElements,
9100 /// because those operations are independent from the pattern to be combined.
9101 /// For vector loads, we simply care that the ByteProviders are adjacent
9102 /// positions of the same vector, and their index matches the byte that is being
9103 /// provided. This is captured by the \p VectorIndex algorithm. \p VectorIndex
9104 /// is the index used in an ExtractVectorElement, and \p StartingIndex is the
9105 /// byte position we are trying to provide for the LoadCombine. If these do
9106 /// not match, then we can not combine the vector loads. \p Index uses the
9107 /// byte position we are trying to provide for and is matched against the
9108 /// shl and load size. The \p Index algorithm ensures the requested byte is
9109 /// provided for by the pattern, and the pattern does not over provide bytes.
9110 ///
9111 ///
9112 /// The supported LoadCombine pattern for vector loads is as follows
9113 ///                              or
9114 ///                          /        \
9115 ///                         or        shl
9116 ///                       /     \      |
9117 ///                     or      shl   zext
9118 ///                   /    \     |     |
9119 ///                 shl   zext  zext  EVE*
9120 ///                  |     |     |     |
9121 ///                 zext  EVE*  EVE*  LOAD
9122 ///                  |     |     |
9123 ///                 EVE*  LOAD  LOAD
9124 ///                  |
9125 ///                 LOAD
9126 ///
9127 /// *ExtractVectorElement
9128 using SDByteProvider = ByteProvider<SDNode *>;
9129 
9130 static std::optional<SDByteProvider>
calculateByteProvider(SDValue Op,unsigned Index,unsigned Depth,std::optional<uint64_t> VectorIndex,unsigned StartingIndex=0)9131 calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
9132                       std::optional<uint64_t> VectorIndex,
9133                       unsigned StartingIndex = 0) {
9134 
9135   // Typical i64 by i8 pattern requires recursion up to 8 calls depth
9136   if (Depth == 10)
9137     return std::nullopt;
9138 
9139   // Only allow multiple uses if the instruction is a vector load (in which
9140   // case we will use the load for every ExtractVectorElement)
9141   if (Depth && !Op.hasOneUse() &&
9142       (Op.getOpcode() != ISD::LOAD || !Op.getValueType().isVector()))
9143     return std::nullopt;
9144 
9145   // Fail to combine if we have encountered anything but a LOAD after handling
9146   // an ExtractVectorElement.
9147   if (Op.getOpcode() != ISD::LOAD && VectorIndex.has_value())
9148     return std::nullopt;
9149 
9150   unsigned BitWidth = Op.getScalarValueSizeInBits();
9151   if (BitWidth % 8 != 0)
9152     return std::nullopt;
9153   unsigned ByteWidth = BitWidth / 8;
9154   assert(Index < ByteWidth && "invalid index requested");
9155   (void) ByteWidth;
9156 
9157   switch (Op.getOpcode()) {
9158   case ISD::OR: {
9159     auto LHS =
9160         calculateByteProvider(Op->getOperand(0), Index, Depth + 1, VectorIndex);
9161     if (!LHS)
9162       return std::nullopt;
9163     auto RHS =
9164         calculateByteProvider(Op->getOperand(1), Index, Depth + 1, VectorIndex);
9165     if (!RHS)
9166       return std::nullopt;
9167 
9168     if (LHS->isConstantZero())
9169       return RHS;
9170     if (RHS->isConstantZero())
9171       return LHS;
9172     return std::nullopt;
9173   }
9174   case ISD::SHL: {
9175     auto ShiftOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
9176     if (!ShiftOp)
9177       return std::nullopt;
9178 
9179     uint64_t BitShift = ShiftOp->getZExtValue();
9180 
9181     if (BitShift % 8 != 0)
9182       return std::nullopt;
9183     uint64_t ByteShift = BitShift / 8;
9184 
9185     // If we are shifting by an amount greater than the index we are trying to
9186     // provide, then do not provide anything. Otherwise, subtract the index by
9187     // the amount we shifted by.
9188     return Index < ByteShift
9189                ? SDByteProvider::getConstantZero()
9190                : calculateByteProvider(Op->getOperand(0), Index - ByteShift,
9191                                        Depth + 1, VectorIndex, Index);
9192   }
9193   case ISD::ANY_EXTEND:
9194   case ISD::SIGN_EXTEND:
9195   case ISD::ZERO_EXTEND: {
9196     SDValue NarrowOp = Op->getOperand(0);
9197     unsigned NarrowBitWidth = NarrowOp.getScalarValueSizeInBits();
9198     if (NarrowBitWidth % 8 != 0)
9199       return std::nullopt;
9200     uint64_t NarrowByteWidth = NarrowBitWidth / 8;
9201 
9202     if (Index >= NarrowByteWidth)
9203       return Op.getOpcode() == ISD::ZERO_EXTEND
9204                  ? std::optional<SDByteProvider>(
9205                        SDByteProvider::getConstantZero())
9206                  : std::nullopt;
9207     return calculateByteProvider(NarrowOp, Index, Depth + 1, VectorIndex,
9208                                  StartingIndex);
9209   }
9210   case ISD::BSWAP:
9211     return calculateByteProvider(Op->getOperand(0), ByteWidth - Index - 1,
9212                                  Depth + 1, VectorIndex, StartingIndex);
9213   case ISD::EXTRACT_VECTOR_ELT: {
9214     auto OffsetOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
9215     if (!OffsetOp)
9216       return std::nullopt;
9217 
9218     VectorIndex = OffsetOp->getZExtValue();
9219 
9220     SDValue NarrowOp = Op->getOperand(0);
9221     unsigned NarrowBitWidth = NarrowOp.getScalarValueSizeInBits();
9222     if (NarrowBitWidth % 8 != 0)
9223       return std::nullopt;
9224     uint64_t NarrowByteWidth = NarrowBitWidth / 8;
9225     // EXTRACT_VECTOR_ELT can extend the element type to the width of the return
9226     // type, leaving the high bits undefined.
9227     if (Index >= NarrowByteWidth)
9228       return std::nullopt;
9229 
9230     // Check to see if the position of the element in the vector corresponds
9231     // with the byte we are trying to provide for. In the case of a vector of
9232     // i8, this simply means the VectorIndex == StartingIndex. For non i8 cases,
9233     // the element will provide a range of bytes. For example, if we have a
9234     // vector of i16s, each element provides two bytes (V[1] provides byte 2 and
9235     // 3).
9236     if (*VectorIndex * NarrowByteWidth > StartingIndex)
9237       return std::nullopt;
9238     if ((*VectorIndex + 1) * NarrowByteWidth <= StartingIndex)
9239       return std::nullopt;
9240 
9241     return calculateByteProvider(Op->getOperand(0), Index, Depth + 1,
9242                                  VectorIndex, StartingIndex);
9243   }
9244   case ISD::LOAD: {
9245     auto L = cast<LoadSDNode>(Op.getNode());
9246     if (!L->isSimple() || L->isIndexed())
9247       return std::nullopt;
9248 
9249     unsigned NarrowBitWidth = L->getMemoryVT().getScalarSizeInBits();
9250     if (NarrowBitWidth % 8 != 0)
9251       return std::nullopt;
9252     uint64_t NarrowByteWidth = NarrowBitWidth / 8;
9253 
9254     // If the width of the load does not reach byte we are trying to provide for
9255     // and it is not a ZEXTLOAD, then the load does not provide for the byte in
9256     // question
9257     if (Index >= NarrowByteWidth)
9258       return L->getExtensionType() == ISD::ZEXTLOAD
9259                  ? std::optional<SDByteProvider>(
9260                        SDByteProvider::getConstantZero())
9261                  : std::nullopt;
9262 
9263     unsigned BPVectorIndex = VectorIndex.value_or(0U);
9264     return SDByteProvider::getSrc(L, Index, BPVectorIndex);
9265   }
9266   }
9267 
9268   return std::nullopt;
9269 }
9270 
/// Memory offset of byte \p i of a \p BW-byte value stored little endian.
/// Byte i lives at offset i, so the width does not matter here.
static unsigned littleEndianByteAt(unsigned BW, unsigned i) {
  (void)BW;
  const unsigned Offset = i;
  return Offset;
}
9274 
/// Memory offset of byte \p i of a \p BW-byte value stored big endian: the
/// least significant byte sits at the last offset (BW - 1).
static unsigned bigEndianByteAt(unsigned BW, unsigned i) {
  const unsigned LastOffset = BW - 1;
  return LastOffset - i;
}
9278 
9279 // Check if the bytes offsets we are looking at match with either big or
9280 // little endian value loaded. Return true for big endian, false for little
9281 // endian, and std::nullopt if match failed.
isBigEndian(const ArrayRef<int64_t> ByteOffsets,int64_t FirstOffset)9282 static std::optional<bool> isBigEndian(const ArrayRef<int64_t> ByteOffsets,
9283                                        int64_t FirstOffset) {
9284   // The endian can be decided only when it is 2 bytes at least.
9285   unsigned Width = ByteOffsets.size();
9286   if (Width < 2)
9287     return std::nullopt;
9288 
9289   bool BigEndian = true, LittleEndian = true;
9290   for (unsigned i = 0; i < Width; i++) {
9291     int64_t CurrentByteOffset = ByteOffsets[i] - FirstOffset;
9292     LittleEndian &= CurrentByteOffset == littleEndianByteAt(Width, i);
9293     BigEndian &= CurrentByteOffset == bigEndianByteAt(Width, i);
9294     if (!BigEndian && !LittleEndian)
9295       return std::nullopt;
9296   }
9297 
9298   assert((BigEndian != LittleEndian) && "It should be either big endian or"
9299                                         "little endian");
9300   return BigEndian;
9301 }
9302 
9303 // Look through one layer of truncate or extend.
stripTruncAndExt(SDValue Value)9304 static SDValue stripTruncAndExt(SDValue Value) {
9305   switch (Value.getOpcode()) {
9306   case ISD::TRUNCATE:
9307   case ISD::ZERO_EXTEND:
9308   case ISD::SIGN_EXTEND:
9309   case ISD::ANY_EXTEND:
9310     return Value.getOperand(0);
9311   }
9312   return SDValue();
9313 }
9314 
9315 /// Match a pattern where a wide type scalar value is stored by several narrow
9316 /// stores. Fold it into a single store or a BSWAP and a store if the targets
9317 /// supports it.
9318 ///
9319 /// Assuming little endian target:
9320 ///  i8 *p = ...
9321 ///  i32 val = ...
9322 ///  p[0] = (val >> 0) & 0xFF;
9323 ///  p[1] = (val >> 8) & 0xFF;
9324 ///  p[2] = (val >> 16) & 0xFF;
9325 ///  p[3] = (val >> 24) & 0xFF;
9326 /// =>
9327 ///  *((i32)p) = val;
9328 ///
9329 ///  i8 *p = ...
9330 ///  i32 val = ...
9331 ///  p[0] = (val >> 24) & 0xFF;
9332 ///  p[1] = (val >> 16) & 0xFF;
9333 ///  p[2] = (val >> 8) & 0xFF;
9334 ///  p[3] = (val >> 0) & 0xFF;
9335 /// =>
9336 ///  *((i32)p) = BSWAP(val);
///
/// Returns the merged wide store on success (all uses of N are redirected to
/// it), or an empty SDValue when the pattern is not matched.
mergeTruncStores(StoreSDNode * N)9337 SDValue DAGCombiner::mergeTruncStores(StoreSDNode *N) {
9338   // The matching looks for "store (trunc x)" patterns that appear early but are
9339   // likely to be replaced by truncating store nodes during combining.
9340   // TODO: If there is evidence that running this later would help, this
9341   //       limitation could be removed. Legality checks may need to be added
9342   //       for the created store and optional bswap/rotate.
9343   if (LegalOperations || OptLevel == CodeGenOptLevel::None)
9344     return SDValue();
9345 
9346   // We only handle merging simple stores of 1-4 bytes.
9347   // TODO: Allow unordered atomics when wider type is legal (see D66309)
9348   EVT MemVT = N->getMemoryVT();
9349   if (!(MemVT == MVT::i8 || MemVT == MVT::i16 || MemVT == MVT::i32) ||
9350       !N->isSimple() || N->isIndexed())
9351     return SDValue();
9352 
9353   // Collect all of the stores in the chain, up to the maximum store width (i64).
  // Stores[0] is N; each following entry is the chain predecessor of the
  // previous one.
9354   SDValue Chain = N->getChain();
9355   SmallVector<StoreSDNode *, 8> Stores = {N};
9356   unsigned NarrowNumBits = MemVT.getScalarSizeInBits();
9357   unsigned MaxWideNumBits = 64;
9358   unsigned MaxStores = MaxWideNumBits / NarrowNumBits;
  // Walk up the chain for as long as the chain operand is itself a store; the
  // walk ends at the first chain node that is not a StoreSDNode.
9359   while (auto *Store = dyn_cast<StoreSDNode>(Chain)) {
9360     // All stores must be the same size to ensure that we are writing all of the
9361     // bytes in the wide value.
9362     // This store should have exactly one use as a chain operand for another
9363     // store in the merging set. If there are other chain uses, then the
9364     // transform may not be safe because order of loads/stores outside of this
9365     // set may not be preserved.
9366     // TODO: We could allow multiple sizes by tracking each stored byte.
9367     if (Store->getMemoryVT() != MemVT || !Store->isSimple() ||
9368         Store->isIndexed() || !Store->hasOneUse())
9369       return SDValue();
9370     Stores.push_back(Store);
9371     Chain = Store->getChain();
    // Give up once the collected stores no longer fit in MaxWideNumBits.
9372     if (MaxStores < Stores.size())
9373       return SDValue();
9374   }
  // Chain is now the incoming chain of the earliest collected store; the
  // merged store created below is attached to it.
9375   // There is no reason to continue if we do not have at least a pair of stores.
9376   if (Stores.size() < 2)
9377     return SDValue();
9378 
9379   // Handle simple types only.
9380   LLVMContext &Context = *DAG.getContext();
9381   unsigned NumStores = Stores.size();
9382   unsigned WideNumBits = NumStores * NarrowNumBits;
9383   EVT WideVT = EVT::getIntegerVT(Context, WideNumBits);
9384   if (WideVT != MVT::i16 && WideVT != MVT::i32 && WideVT != MVT::i64)
9385     return SDValue();
9386 
9387   // Check if all bytes of the source value that we are looking at are stored
9388   // to the same base address. Collect offsets from Base address into OffsetMap.
  // OffsetMap[K] is the byte offset (from Base) of the store that writes the
  // K-th NarrowNumBits-wide piece of SourceValue; INT64_MAX means "not seen".
9389   SDValue SourceValue;
9390   SmallVector<int64_t, 8> OffsetMap(NumStores, INT64_MAX);
  // Lowest byte offset from Base seen so far, and the store that writes it.
9391   int64_t FirstOffset = INT64_MAX;
9392   StoreSDNode *FirstStore = nullptr;
9393   std::optional<BaseIndexOffset> Base;
9394   for (auto *Store : Stores) {
9395     // All the stores store different parts of the CombinedValue. A truncate is
9396     // required to get the partial value.
9397     SDValue Trunc = Store->getValue();
9398     if (Trunc.getOpcode() != ISD::TRUNCATE)
9399       return SDValue();
9400     // Other than the first/last part, a shift operation is required to get the
9401     // offset.
    // Offset is measured in NarrowNumBits-wide pieces, not in bytes.
9402     int64_t Offset = 0;
9403     SDValue WideVal = Trunc.getOperand(0);
9404     if ((WideVal.getOpcode() == ISD::SRL || WideVal.getOpcode() == ISD::SRA) &&
9405         isa<ConstantSDNode>(WideVal.getOperand(1))) {
9406       // The shift amount must be a constant multiple of the narrow type.
9407       // It is translated to the offset address in the wide source value "y".
9408       //
9409       // x = srl y, ShiftAmtC
9410       // i8 z = trunc x
9411       // store z, ...
9412       uint64_t ShiftAmtC = WideVal.getConstantOperandVal(1);
9413       if (ShiftAmtC % NarrowNumBits != 0)
9414         return SDValue();
9415 
9416       // Make sure we aren't reading bits that are shifted in.
9417       if (ShiftAmtC > WideVal.getScalarValueSizeInBits() - NarrowNumBits)
9418         return SDValue();
9419 
9420       Offset = ShiftAmtC / NarrowNumBits;
9421       WideVal = WideVal.getOperand(0);
9422     }
9423 
9424     // Stores must share the same source value with different offsets.
9425     if (!SourceValue)
9426       SourceValue = WideVal;
9427     else if (SourceValue != WideVal) {
9428       // Truncate and extends can be stripped to see if the values are related.
9429       if (stripTruncAndExt(SourceValue) != WideVal &&
9430           stripTruncAndExt(WideVal) != SourceValue)
9431         return SDValue();
9432 
      // Keep the wider of the two related values as the merge source.
9433       if (WideVal.getScalarValueSizeInBits() >
9434           SourceValue.getScalarValueSizeInBits())
9435         SourceValue = WideVal;
9436 
9437       // Give up if the source value type is smaller than the store size.
9438       if (SourceValue.getScalarValueSizeInBits() < WideVT.getScalarSizeInBits())
9439         return SDValue();
9440     }
9441 
9442     // Stores must share the same base address.
    // The first store visited (N itself) defines Base, so its offset is 0.
9443     BaseIndexOffset Ptr = BaseIndexOffset::match(Store, DAG);
9444     int64_t ByteOffsetFromBase = 0;
9445     if (!Base)
9446       Base = Ptr;
9447     else if (!Base->equalBaseIndex(Ptr, DAG, ByteOffsetFromBase))
9448       return SDValue();
9449 
9450     // Remember the first store.
9451     if (ByteOffsetFromBase < FirstOffset) {
9452       FirstStore = Store;
9453       FirstOffset = ByteOffsetFromBase;
9454     }
9455     // Map the offset in the store and the offset in the combined value, and
9456     // early return if it has been set before.
9457     if (Offset < 0 || Offset >= NumStores || OffsetMap[Offset] != INT64_MAX)
9458       return SDValue();
9459     OffsetMap[Offset] = ByteOffsetFromBase;
9460   }
9461 
9462   assert(FirstOffset != INT64_MAX && "First byte offset must be set");
9463   assert(FirstStore && "First store must be set");
9464 
9465   // Check that a store of the wide type is both allowed and fast on the target
9466   const DataLayout &Layout = DAG.getDataLayout();
9467   unsigned Fast = 0;
9468   bool Allowed = TLI.allowsMemoryAccess(Context, Layout, WideVT,
9469                                         *FirstStore->getMemOperand(), &Fast);
9470   if (!Allowed || !Fast)
9471     return SDValue();
9472 
9473   // Check if the pieces of the value are going to the expected places in memory
9474   // to merge the stores.
  // Little-endian order: piece I must live at FirstOffset + I * (NarrowNumBits
  // / 8). Big-endian order: the same layout with the pieces reversed.
9475   auto checkOffsets = [&](bool MatchLittleEndian) {
9476     if (MatchLittleEndian) {
9477       for (unsigned i = 0; i != NumStores; ++i)
9478         if (OffsetMap[i] != i * (NarrowNumBits / 8) + FirstOffset)
9479           return false;
9480     } else { // MatchBigEndian by reversing loop counter.
9481       for (unsigned i = 0, j = NumStores - 1; i != NumStores; ++i, --j)
9482         if (OffsetMap[j] != i * (NarrowNumBits / 8) + FirstOffset)
9483           return false;
9484     }
9485     return true;
9486   };
9487 
9488   // Check if the offsets line up for the native data layout of this target.
9489   bool NeedBswap = false;
9490   bool NeedRotate = false;
9491   if (!checkOffsets(Layout.isLittleEndian())) {
9492     // Special-case: check if byte offsets line up for the opposite endian.
    // With byte-sized pieces a BSWAP reverses them; with exactly two pieces,
    // rotating by half the wide width swaps them.
9493     if (NarrowNumBits == 8 && checkOffsets(Layout.isBigEndian()))
9494       NeedBswap = true;
9495     else if (NumStores == 2 && checkOffsets(Layout.isBigEndian()))
9496       NeedRotate = true;
9497     else
9498       return SDValue();
9499   }
9500 
9501   SDLoc DL(N);
  // SourceValue may be wider than the merged store; keep only its low
  // WideNumBits bits.
9502   if (WideVT != SourceValue.getValueType()) {
9503     assert(SourceValue.getValueType().getScalarSizeInBits() > WideNumBits &&
9504            "Unexpected store value to merge");
9505     SourceValue = DAG.getNode(ISD::TRUNCATE, DL, WideVT, SourceValue);
9506   }
9507 
9508   // Before legalize we can introduce illegal bswaps/rotates which will be later
9509   // converted to an explicit bswap sequence. This way we end up with a single
9510   // store and byte shuffling instead of several stores and byte shuffling.
9511   if (NeedBswap) {
9512     SourceValue = DAG.getNode(ISD::BSWAP, DL, WideVT, SourceValue);
9513   } else if (NeedRotate) {
9514     assert(WideNumBits % 2 == 0 && "Unexpected type for rotate");
9515     SDValue RotAmt = DAG.getConstant(WideNumBits / 2, DL, WideVT);
9516     SourceValue = DAG.getNode(ISD::ROTR, DL, WideVT, SourceValue, RotAmt);
9517   }
9518 
  // The merged store goes to the lowest address (FirstStore's pointer) and
  // reuses FirstStore's pointer info and alignment.
9519   SDValue NewStore =
9520       DAG.getStore(Chain, DL, SourceValue, FirstStore->getBasePtr(),
9521                    FirstStore->getPointerInfo(), FirstStore->getAlign());
9522 
9523   // Rely on other DAG combine rules to remove the other individual stores.
9524   DAG.ReplaceAllUsesWith(N, NewStore.getNode());
9525   return NewStore;
9526 }
9527 
9528 /// Match a pattern where a wide type scalar value is loaded by several narrow
9529 /// loads and combined by shifts and ors. Fold it into a single load or a load
9530 /// and a BSWAP if the target supports it.
9531 ///
9532 /// Assuming little endian target:
9533 ///  i8 *a = ...
9534 ///  i32 val = a[0] | (a[1] << 8) | (a[2] << 16) | (a[3] << 24)
9535 /// =>
9536 ///  i32 val = *((i32 *)a)
9537 ///
9538 ///  i8 *a = ...
9539 ///  i32 val = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3]
9540 /// =>
9541 ///  i32 val = BSWAP(*((i32 *)a))
9542 ///
9543 /// TODO: This rule matches complex patterns with OR node roots and doesn't
9544 /// interact well with the worklist mechanism. When a part of the pattern is
9545 /// updated (e.g. one of the loads) its direct users are put into the worklist,
9546 /// but the root node of the pattern which triggers the load combine is not
9547 /// necessarily a direct user of the changed node. For example, once the address
9548 /// of t28 load is reassociated load combine won't be triggered:
9549 ///             t25: i32 = add t4, Constant:i32<2>
9550 ///           t26: i64 = sign_extend t25
9551 ///        t27: i64 = add t2, t26
9552 ///       t28: i8,ch = load<LD1[%tmp9]> t0, t27, undef:i64
9553 ///     t29: i32 = zero_extend t28
9554 ///   t32: i32 = shl t29, Constant:i8<8>
9555 /// t33: i32 = or t23, t32
9556 /// As a possible fix visitLoad can check if the load can be a part of a load
9557 /// combine pattern and add corresponding OR roots to the worklist.
///
/// Returns the replacement value (the new load, possibly followed by SHL and
/// BSWAP) on success, or an empty SDValue when the pattern is not matched.
MatchLoadCombine(SDNode * N)9558 SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
9559   assert(N->getOpcode() == ISD::OR &&
9560          "Can only match load combining against OR nodes");
9561 
9562   // Handles simple types only
9563   EVT VT = N->getValueType(0);
9564   if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
9565     return SDValue();
9566   unsigned ByteWidth = VT.getSizeInBits() / 8;
9567 
9568   bool IsBigEndianTarget = DAG.getDataLayout().isBigEndian();
  // Offset, in bytes from the start of its source load, of the byte supplied
  // by provider P. The conversion from P.DestOffset depends on the target's
  // endianness (bigEndianByteAt / littleEndianByteAt are defined elsewhere).
9569   auto MemoryByteOffset = [&](SDByteProvider P) {
9570     assert(P.hasSrc() && "Must be a memory byte provider");
9571     auto *Load = cast<LoadSDNode>(P.Src.value());
9572 
9573     unsigned LoadBitWidth = Load->getMemoryVT().getScalarSizeInBits();
9574 
9575     assert(LoadBitWidth % 8 == 0 &&
9576            "can only analyze providers for individual bytes not bit");
9577     unsigned LoadByteWidth = LoadBitWidth / 8;
9578     return IsBigEndianTarget ? bigEndianByteAt(LoadByteWidth, P.DestOffset)
9579                              : littleEndianByteAt(LoadByteWidth, P.DestOffset);
9580   };
9581 
9582   std::optional<BaseIndexOffset> Base;
9583   SDValue Chain;
9584 
9585   SmallPtrSet<LoadSDNode *, 8> Loads;
9586   std::optional<SDByteProvider> FirstByteProvider;
9587   int64_t FirstOffset = INT64_MAX;
9588 
9589   // Check if all the bytes of the OR we are looking at are loaded from the same
9590   // base address. Collect bytes offsets from Base address in ByteOffsets.
  // Bytes are visited from the most significant (ByteWidth - 1) down to 0.
  // ZeroExtendedBytes counts the most significant bytes known to be zero;
  // they must form a contiguous run at the top of the value.
9591   SmallVector<int64_t, 8> ByteOffsets(ByteWidth);
9592   unsigned ZeroExtendedBytes = 0;
9593   for (int i = ByteWidth - 1; i >= 0; --i) {
9594     auto P =
9595         calculateByteProvider(SDValue(N, 0), i, 0, /*VectorIndex*/ std::nullopt,
9596                               /*StartingIndex*/ i);
9597     if (!P)
9598       return SDValue();
9599 
9600     if (P->isConstantZero()) {
9601       // It's OK for the N most significant bytes to be 0, we can just
9602       // zero-extend the load.
9603       if (++ZeroExtendedBytes != (ByteWidth - static_cast<unsigned>(i)))
9604         return SDValue();
9605       continue;
9606     }
9607     assert(P->hasSrc() && "provenance should either be memory or zero");
9608     auto *L = cast<LoadSDNode>(P->Src.value());
9609 
9610     // All loads must share the same chain
9611     SDValue LChain = L->getChain();
9612     if (!Chain)
9613       Chain = LChain;
9614     else if (Chain != LChain)
9615       return SDValue();
9616 
9617     // Loads must share the same base address
9618     BaseIndexOffset Ptr = BaseIndexOffset::match(L, DAG);
9619     int64_t ByteOffsetFromBase = 0;
9620 
9621     // For vector loads, the expected load combine pattern will have an
9622     // ExtractElement for each index in the vector. While each of these
9623     // ExtractElements will be accessing the same base address as determined
9624     // by the load instruction, the actual bytes they interact with will differ
9625     // due to different ExtractElement indices. To accurately determine the
9626     // byte position of an ExtractElement, we offset the base load ptr with
9627     // the index multiplied by the byte size of each element in the vector.
9628     if (L->getMemoryVT().isVector()) {
9629       unsigned LoadWidthInBit = L->getMemoryVT().getScalarSizeInBits();
9630       if (LoadWidthInBit % 8 != 0)
9631         return SDValue();
9632       unsigned ByteOffsetFromVector = P->SrcOffset * LoadWidthInBit / 8;
9633       Ptr.addToOffset(ByteOffsetFromVector);
9634     }
9635 
9636     if (!Base)
9637       Base = Ptr;
9638 
9639     else if (!Base->equalBaseIndex(Ptr, DAG, ByteOffsetFromBase))
9640       return SDValue();
9641 
9642     // Calculate the offset of the current byte from the base address
9643     ByteOffsetFromBase += MemoryByteOffset(*P);
9644     ByteOffsets[i] = ByteOffsetFromBase;
9645 
9646     // Remember the first byte load
9647     if (ByteOffsetFromBase < FirstOffset) {
9648       FirstByteProvider = P;
9649       FirstOffset = ByteOffsetFromBase;
9650     }
9651 
9652     Loads.insert(L);
9653   }
9654 
9655   assert(!Loads.empty() && "All the bytes of the value must be loaded from "
9656          "memory, so there must be at least one load which produces the value");
9657   assert(Base && "Base address of the accessed memory location must be set");
9658   assert(FirstOffset != INT64_MAX && "First byte offset must be set");
9659 
9660   bool NeedsZext = ZeroExtendedBytes > 0;
9661 
  // Only the bytes that are not known to be zero are loaded from memory.
9662   EVT MemVT =
9663       EVT::getIntegerVT(*DAG.getContext(), (ByteWidth - ZeroExtendedBytes) * 8);
9664 
9665   if (!MemVT.isSimple())
9666     return SDValue();
9667 
9668   // Before legalize we can introduce too wide illegal loads which will be later
9669   // split into legal sized loads. This enables us to combine i64 load by i8
9670   // patterns to a couple of i32 loads on 32 bit targets.
9671   if (LegalOperations &&
9672       !TLI.isLoadExtLegal(NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD, VT,
9673                           MemVT))
9674     return SDValue();
9675 
9676   // Check if the bytes of the OR we are looking at match with either big or
9677   // little endian value load
  // The zero-extended (most significant) slots at the end of ByteOffsets were
  // never filled in, so they are dropped before checking the byte order.
9678   std::optional<bool> IsBigEndian = isBigEndian(
9679       ArrayRef(ByteOffsets).drop_back(ZeroExtendedBytes), FirstOffset);
9680   if (!IsBigEndian)
9681     return SDValue();
9682 
9683   assert(FirstByteProvider && "must be set");
9684 
9685   // Ensure that the first byte is loaded from zero offset of the first load.
9686   // So the combined value can be loaded from the first load address.
9687   if (MemoryByteOffset(*FirstByteProvider) != 0)
9688     return SDValue();
9689   auto *FirstLoad = cast<LoadSDNode>(FirstByteProvider->Src.value());
9690 
9691   // The node we are looking at matches with the pattern, check if we can
9692   // replace it with a single (possibly zero-extended) load and bswap + shift if
9693   // needed.
9694 
9695   // If the load needs byte swap check if the target supports it
9696   bool NeedsBswap = IsBigEndianTarget != *IsBigEndian;
9697 
9698   // Before legalize we can introduce illegal bswaps which will be later
9699   // converted to an explicit bswap sequence. This way we end up with a single
9700   // load and byte shuffling instead of several loads and byte shuffling.
9701   // We do not introduce illegal bswaps when zero-extending as this tends to
9702   // introduce too many arithmetic instructions.
9703   if (NeedsBswap && (LegalOperations || NeedsZext) &&
9704       !TLI.isOperationLegal(ISD::BSWAP, VT))
9705     return SDValue();
9706 
9707   // If we need to bswap and zero extend, we have to insert a shift. Check that
9708   // it is legal.
9709   if (NeedsBswap && NeedsZext && LegalOperations &&
9710       !TLI.isOperationLegal(ISD::SHL, VT))
9711     return SDValue();
9712 
9713   // Check that a load of the wide type is both allowed and fast on the target
9714   unsigned Fast = 0;
9715   bool Allowed =
9716       TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), MemVT,
9717                              *FirstLoad->getMemOperand(), &Fast);
9718   if (!Allowed || !Fast)
9719     return SDValue();
9720 
9721   SDValue NewLoad =
9722       DAG.getExtLoad(NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD, SDLoc(N), VT,
9723                      Chain, FirstLoad->getBasePtr(),
9724                      FirstLoad->getPointerInfo(), MemVT, FirstLoad->getAlign());
9725 
9726   // Transfer chain users from old loads to the new load.
9727   for (LoadSDNode *L : Loads)
9728     DAG.makeEquivalentMemoryOrdering(L, NewLoad);
9729 
9730   if (!NeedsBswap)
9731     return NewLoad;
9732 
  // After a zero-extending load the loaded bytes sit in the low part of VT.
  // Shifting them to the top first makes the BSWAP bring them, reversed, back
  // into the low part.
9733   SDValue ShiftedLoad =
9734       NeedsZext ? DAG.getNode(ISD::SHL, SDLoc(N), VT, NewLoad,
9735                               DAG.getShiftAmountConstant(ZeroExtendedBytes * 8,
9736                                                          VT, SDLoc(N)))
9737                 : NewLoad;
9738   return DAG.getNode(ISD::BSWAP, SDLoc(N), VT, ShiftedLoad);
9739 }
9740 
9741 // If the target has andn, bsl, or a similar bit-select instruction,
9742 // we want to unfold masked merge, with canonical pattern of:
9743 //   |        A  |  |B|
9744 //   ((x ^ y) & m) ^ y
9745 //    |  D  |
9746 // Into:
9747 //   (x & m) | (y & ~m)
9748 // If y is a constant, m is not a 'not', and the 'andn' does not work with
9749 // immediates, we unfold into a different pattern:
9750 //   ~(~x & m) & (m | y)
9751 // If x is a constant, m is a 'not', and the 'andn' does not work with
9752 // immediates, we unfold into a different pattern:
9753 //   (x | ~m) & ~(~m & ~y)
9754 // NOTE: we don't unfold the pattern if 'xor' is actually a 'not', because at
9755 //       the very least that breaks andnpd / andnps patterns, and because those
9756 //       patterns are simplified in IR and shouldn't be created in the DAG
unfoldMaskedMerge(SDNode * N)9757 SDValue DAGCombiner::unfoldMaskedMerge(SDNode *N) {
9758   assert(N->getOpcode() == ISD::XOR);
9759 
9760   // Don't touch 'not' (i.e. where y = -1).
9761   if (isAllOnesOrAllOnesSplat(N->getOperand(1)))
9762     return SDValue();
9763 
9764   EVT VT = N->getValueType(0);
9765 
9766   // There are 3 commutable operators in the pattern,
9767   // so we have to deal with 8 possible variants of the basic pattern.
9768   SDValue X, Y, M;
9769   auto matchAndXor = [&X, &Y, &M](SDValue And, unsigned XorIdx, SDValue Other) {
9770     if (And.getOpcode() != ISD::AND || !And.hasOneUse())
9771       return false;
9772     SDValue Xor = And.getOperand(XorIdx);
9773     if (Xor.getOpcode() != ISD::XOR || !Xor.hasOneUse())
9774       return false;
9775     SDValue Xor0 = Xor.getOperand(0);
9776     SDValue Xor1 = Xor.getOperand(1);
9777     // Don't touch 'not' (i.e. where y = -1).
9778     if (isAllOnesOrAllOnesSplat(Xor1))
9779       return false;
9780     if (Other == Xor0)
9781       std::swap(Xor0, Xor1);
9782     if (Other != Xor1)
9783       return false;
9784     X = Xor0;
9785     Y = Xor1;
9786     M = And.getOperand(XorIdx ? 0 : 1);
9787     return true;
9788   };
9789 
9790   SDValue N0 = N->getOperand(0);
9791   SDValue N1 = N->getOperand(1);
9792   if (!matchAndXor(N0, 0, N1) && !matchAndXor(N0, 1, N1) &&
9793       !matchAndXor(N1, 0, N0) && !matchAndXor(N1, 1, N0))
9794     return SDValue();
9795 
9796   // Don't do anything if the mask is constant. This should not be reachable.
9797   // InstCombine should have already unfolded this pattern, and DAGCombiner
9798   // probably shouldn't produce it, too.
9799   if (isa<ConstantSDNode>(M.getNode()))
9800     return SDValue();
9801 
9802   // We can transform if the target has AndNot
9803   if (!TLI.hasAndNot(M))
9804     return SDValue();
9805 
9806   SDLoc DL(N);
9807 
9808   // If Y is a constant, check that 'andn' works with immediates. Unless M is
9809   // a bitwise not that would already allow ANDN to be used.
9810   if (!TLI.hasAndNot(Y) && !isBitwiseNot(M)) {
9811     assert(TLI.hasAndNot(X) && "Only mask is a variable? Unreachable.");
9812     // If not, we need to do a bit more work to make sure andn is still used.
9813     SDValue NotX = DAG.getNOT(DL, X, VT);
9814     SDValue LHS = DAG.getNode(ISD::AND, DL, VT, NotX, M);
9815     SDValue NotLHS = DAG.getNOT(DL, LHS, VT);
9816     SDValue RHS = DAG.getNode(ISD::OR, DL, VT, M, Y);
9817     return DAG.getNode(ISD::AND, DL, VT, NotLHS, RHS);
9818   }
9819 
9820   // If X is a constant and M is a bitwise not, check that 'andn' works with
9821   // immediates.
9822   if (!TLI.hasAndNot(X) && isBitwiseNot(M)) {
9823     assert(TLI.hasAndNot(Y) && "Only mask is a variable? Unreachable.");
9824     // If not, we need to do a bit more work to make sure andn is still used.
9825     SDValue NotM = M.getOperand(0);
9826     SDValue LHS = DAG.getNode(ISD::OR, DL, VT, X, NotM);
9827     SDValue NotY = DAG.getNOT(DL, Y, VT);
9828     SDValue RHS = DAG.getNode(ISD::AND, DL, VT, NotM, NotY);
9829     SDValue NotRHS = DAG.getNOT(DL, RHS, VT);
9830     return DAG.getNode(ISD::AND, DL, VT, LHS, NotRHS);
9831   }
9832 
9833   SDValue LHS = DAG.getNode(ISD::AND, DL, VT, X, M);
9834   SDValue NotM = DAG.getNOT(DL, M, VT);
9835   SDValue RHS = DAG.getNode(ISD::AND, DL, VT, Y, NotM);
9836 
9837   return DAG.getNode(ISD::OR, DL, VT, LHS, RHS);
9838 }
9839 
// Combine an ISD::XOR node. The folds below are tried in order and the first
// one that applies wins. Returns the replacement value, SDValue(N, 0) when
// this routine has already updated the DAG itself, or an empty SDValue when
// no fold applies.
visitXOR(SDNode * N)9840 SDValue DAGCombiner::visitXOR(SDNode *N) {
9841   SDValue N0 = N->getOperand(0);
9842   SDValue N1 = N->getOperand(1);
9843   EVT VT = N0.getValueType();
9844   SDLoc DL(N);
9845 
9846   // fold (xor undef, undef) -> 0. This is a common idiom (misuse).
9847   if (N0.isUndef() && N1.isUndef())
9848     return DAG.getConstant(0, DL, VT);
9849 
9850   // fold (xor x, undef) -> undef
9851   if (N0.isUndef())
9852     return N0;
9853   if (N1.isUndef())
9854     return N1;
9855 
9856   // fold (xor c1, c2) -> c1^c2
9857   if (SDValue C = DAG.FoldConstantArithmetic(ISD::XOR, DL, VT, {N0, N1}))
9858     return C;
9859 
9860   // canonicalize constant to RHS
9861   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
9862       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
9863     return DAG.getNode(ISD::XOR, DL, VT, N1, N0);
9864 
9865   // fold vector ops
9866   if (VT.isVector()) {
9867     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
9868       return FoldedVOp;
9869 
9870     // fold (xor x, 0) -> x, vector edition
9871     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
9872       return N0;
9873   }
9874 
9875   // fold (xor x, 0) -> x
9876   if (isNullConstant(N1))
9877     return N0;
9878 
9879   if (SDValue NewSel = foldBinOpIntoSelect(N))
9880     return NewSel;
9881 
9882   // reassociate xor
9883   if (SDValue RXOR = reassociateOps(ISD::XOR, DL, N0, N1, N->getFlags()))
9884     return RXOR;
9885 
9886   // Fold xor(vecreduce(x), vecreduce(y)) -> vecreduce(xor(x, y))
9887   if (SDValue SD =
9888           reassociateReduction(ISD::VECREDUCE_XOR, ISD::XOR, DL, VT, N0, N1))
9889     return SD;
9890 
9891   // fold (a^b) -> (a|b) iff a and b share no bits.
  // The new OR is tagged Disjoint because no bit is set in both operands.
9892   if ((!LegalOperations || TLI.isOperationLegal(ISD::OR, VT)) &&
9893       DAG.haveNoCommonBitsSet(N0, N1))
9894     return DAG.getNode(ISD::OR, DL, VT, N0, N1, SDNodeFlags::Disjoint);
9895 
9896   // look for 'add-like' folds:
9897   // XOR(N0,MIN_SIGNED_VALUE) == ADD(N0,MIN_SIGNED_VALUE)
  // Both flip only the sign bit; the carry out of the ADD is discarded.
9898   if ((!LegalOperations || TLI.isOperationLegal(ISD::ADD, VT)) &&
9899       isMinSignedConstant(N1))
9900     if (SDValue Combined = visitADDLike(N))
9901       return Combined;
9902 
9903   // fold not (setcc x, y, cc) -> setcc x y !cc
9904   // Avoid breaking: and (not(setcc x, y, cc), z) -> andn for vec
9905   unsigned N0Opcode = N0.getOpcode();
9906   SDValue LHS, RHS, CC;
9907   if (TLI.isConstTrueVal(N1) &&
9908       isSetCCEquivalent(N0, LHS, RHS, CC, /*MatchStrict*/ true) &&
9909       !(VT.isVector() && TLI.hasAndNot(SDValue(N, 0)) && N->hasOneUse() &&
9910         N->use_begin()->getUser()->getOpcode() == ISD::AND)) {
9911     ISD::CondCode NotCC = ISD::getSetCCInverse(cast<CondCodeSDNode>(CC)->get(),
9912                                                LHS.getValueType());
    // After operation legalization, only use the inverted condition code if
    // the target supports it.
9913     if (!LegalOperations ||
9914         TLI.isCondCodeLegal(NotCC, LHS.getSimpleValueType())) {
9915       switch (N0Opcode) {
9916       default:
9917         llvm_unreachable("Unhandled SetCC Equivalent!");
9918       case ISD::SETCC:
9919         return DAG.getSetCC(SDLoc(N0), VT, LHS, RHS, NotCC);
9920       case ISD::SELECT_CC:
9921         return DAG.getSelectCC(SDLoc(N0), LHS, RHS, N0.getOperand(2),
9922                                N0.getOperand(3), NotCC);
9923       case ISD::STRICT_FSETCC:
9924       case ISD::STRICT_FSETCCS: {
9925         if (N0.hasOneUse()) {
9926           // FIXME Can we handle multiple uses? Could we token factor the chain
9927           // results from the new/old setcc?
9928           SDValue SetCC =
9929               DAG.getSetCC(SDLoc(N0), VT, LHS, RHS, NotCC,
9930                            N0.getOperand(0), N0Opcode == ISD::STRICT_FSETCCS);
          // Replace N with the new setcc, then move users of the old node's
          // chain result (value 1) over to the new node's chain result.
9931           CombineTo(N, SetCC);
9932           DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), SetCC.getValue(1));
9933           recursivelyDeleteUnusedNodes(N0.getNode());
9934           return SDValue(N, 0); // Return N so it doesn't get rechecked!
9935         }
9936         break;
9937       }
9938       }
9939     }
9940   }
9941 
9942   // fold (not (zext (setcc x, y))) -> (zext (not (setcc x, y)))
9943   if (isOneConstant(N1) && N0Opcode == ISD::ZERO_EXTEND && N0.hasOneUse() &&
9944       isSetCCEquivalent(N0.getOperand(0), LHS, RHS, CC)){
9945     SDValue V = N0.getOperand(0);
9946     SDLoc DL0(N0);
9947     V = DAG.getNode(ISD::XOR, DL0, V.getValueType(), V,
9948                     DAG.getConstant(1, DL0, V.getValueType()));
9949     AddToWorklist(V.getNode());
9950     return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, V);
9951   }
9952 
9953   // fold (not (or x, y)) -> (and (not x), (not y)) iff x or y are setcc
9954   // fold (not (and x, y)) -> (or (not x), (not y)) iff x or y are setcc
9955   if (isOneConstant(N1) && VT == MVT::i1 && N0.hasOneUse() &&
9956       (N0Opcode == ISD::OR || N0Opcode == ISD::AND)) {
9957     SDValue N00 = N0.getOperand(0), N01 = N0.getOperand(1);
9958     if (isOneUseSetCC(N01) || isOneUseSetCC(N00)) {
9959       unsigned NewOpcode = N0Opcode == ISD::AND ? ISD::OR : ISD::AND;
9960       N00 = DAG.getNode(ISD::XOR, SDLoc(N00), VT, N00, N1); // N00 = ~N00
9961       N01 = DAG.getNode(ISD::XOR, SDLoc(N01), VT, N01, N1); // N01 = ~N01
9962       AddToWorklist(N00.getNode()); AddToWorklist(N01.getNode());
9963       return DAG.getNode(NewOpcode, DL, VT, N00, N01);
9964     }
9965   }
9966   // fold (not (or x, y)) -> (and (not x), (not y)) iff x or y are constants
9967   // fold (not (and x, y)) -> (or (not x), (not y)) iff x or y are constants
9968   if (isAllOnesConstant(N1) && N0.hasOneUse() &&
9969       (N0Opcode == ISD::OR || N0Opcode == ISD::AND)) {
9970     SDValue N00 = N0.getOperand(0), N01 = N0.getOperand(1);
9971     if (isa<ConstantSDNode>(N01) || isa<ConstantSDNode>(N00)) {
9972       unsigned NewOpcode = N0Opcode == ISD::AND ? ISD::OR : ISD::AND;
9973       N00 = DAG.getNode(ISD::XOR, SDLoc(N00), VT, N00, N1); // N00 = ~N00
9974       N01 = DAG.getNode(ISD::XOR, SDLoc(N01), VT, N01, N1); // N01 = ~N01
9975       AddToWorklist(N00.getNode()); AddToWorklist(N01.getNode());
9976       return DAG.getNode(NewOpcode, DL, VT, N00, N01);
9977     }
9978   }
9979 
9980   // fold (not (neg x)) -> (add X, -1)
9981   // FIXME: This can be generalized to (not (sub Y, X)) -> (add X, ~Y) if
9982   // Y is a constant or the subtract has a single use.
  // ~(0 - X) == X - 1 in two's complement arithmetic.
9983   if (isAllOnesConstant(N1) && N0.getOpcode() == ISD::SUB &&
9984       isNullConstant(N0.getOperand(0))) {
9985     return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(1),
9986                        DAG.getAllOnesConstant(DL, VT));
9987   }
9988 
9989   // fold (not (add X, -1)) -> (neg X)
  // ~(X - 1) == -X in two's complement arithmetic.
9990   if (N0.getOpcode() == ISD::ADD && N0.hasOneUse() && isAllOnesConstant(N1) &&
9991       isAllOnesOrAllOnesSplat(N0.getOperand(1))) {
9992     return DAG.getNegative(N0.getOperand(0), DL, VT);
9993   }
9994 
9995   // fold (xor (and x, y), y) -> (and (not x), y)
  // Only the form with y as the second AND operand is matched here.
9996   if (N0Opcode == ISD::AND && N0.hasOneUse() && N0->getOperand(1) == N1) {
9997     SDValue X = N0.getOperand(0);
9998     SDValue NotX = DAG.getNOT(SDLoc(X), X, VT);
9999     AddToWorklist(NotX.getNode());
10000     return DAG.getNode(ISD::AND, DL, VT, NotX, N1);
10001   }
10002 
10003   // fold Y = sra (X, size(X)-1); xor (add (X, Y), Y) -> (abs X)
  // Y is 0 when X is non-negative and all-ones when X is negative, so
  // (X + Y) ^ Y computes the absolute value. Either xor operand may be the
  // add or the sra.
10004   if (!LegalOperations || hasOperation(ISD::ABS, VT)) {
10005     SDValue A = N0Opcode == ISD::ADD ? N0 : N1;
10006     SDValue S = N0Opcode == ISD::SRA ? N0 : N1;
10007     if (A.getOpcode() == ISD::ADD && S.getOpcode() == ISD::SRA) {
10008       SDValue A0 = A.getOperand(0), A1 = A.getOperand(1);
10009       SDValue S0 = S.getOperand(0);
10010       if ((A0 == S && A1 == S0) || (A1 == S && A0 == S0))
10011         if (ConstantSDNode *C = isConstOrConstSplat(S.getOperand(1)))
10012           if (C->getAPIntValue() == (VT.getScalarSizeInBits() - 1))
10013             return DAG.getNode(ISD::ABS, DL, VT, S0);
10014     }
10015   }
10016 
10017   // fold (xor x, x) -> 0
10018   if (N0 == N1)
10019     return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
10020 
10021   // fold (xor (shl 1, x), -1) -> (rotl ~1, x)
10022   // Here is a concrete example of this equivalence:
10023   // i16   x ==  14
10024   // i16 shl ==   1 << 14  == 16384 == 0b0100000000000000
10025   // i16 xor == ~(1 << 14) == 49151 == 0b1011111111111111
10026   //
10027   // =>
10028   //
10029   // i16     ~1      == 0b1111111111111110
10030   // i16 rol(~1, 14) == 0b1011111111111111
10031   //
10032   // Some additional tips to help conceptualize this transform:
10033   // - Try to see the operation as placing a single zero in a value of all ones.
10034   // - There exists no value for x which would allow the result to contain zero.
10035   // - Values of x larger than the bitwidth are undefined and do not require a
10036   //   consistent result.
10037   // - Pushing the zero left requires shifting one bits in from the right.
10038   // A rotate left of ~1 is a nice way of achieving the desired result.
10039   if (TLI.isOperationLegalOrCustom(ISD::ROTL, VT) && N0Opcode == ISD::SHL &&
10040       isAllOnesConstant(N1) && isOneConstant(N0.getOperand(0))) {
10041     return DAG.getNode(ISD::ROTL, DL, VT, DAG.getSignedConstant(~1, DL, VT),
10042                        N0.getOperand(1));
10043   }
10044 
10045   // Simplify: xor (op x...), (op y...)  -> (op (xor x, y))
10046   if (N0Opcode == N1.getOpcode())
10047     if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
10048       return V;
10049 
  // foldLogicOfShifts is tried with both operand orders.
10050   if (SDValue R = foldLogicOfShifts(N, N0, N1, DAG))
10051     return R;
10052   if (SDValue R = foldLogicOfShifts(N, N1, N0, DAG))
10053     return R;
10054   if (SDValue R = foldLogicTreeOfShifts(N, N0, N1, DAG))
10055     return R;
10056 
10057   // Unfold  ((x ^ y) & m) ^ y  into  (x & m) | (y & ~m)  if profitable
10058   if (SDValue MM = unfoldMaskedMerge(N))
10059     return MM;
10060 
10061   // Simplify the expression using non-local knowledge.
10062   if (SimplifyDemandedBits(SDValue(N, 0)))
10063     return SDValue(N, 0);
10064 
10065   if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
10066     return Combined;
10067 
10068   return SDValue();
10069 }
10070 
10071 /// If we have a shift-by-constant of a bitwise logic op that itself has a
10072 /// shift-by-constant operand with identical opcode, we may be able to convert
10073 /// that into 2 independent shifts followed by the logic op. This is a
10074 /// throughput improvement.
combineShiftOfShiftedLogic(SDNode * Shift,SelectionDAG & DAG)10075 static SDValue combineShiftOfShiftedLogic(SDNode *Shift, SelectionDAG &DAG) {
10076   // Match a one-use bitwise logic op.
10077   SDValue LogicOp = Shift->getOperand(0);
10078   if (!LogicOp.hasOneUse())
10079     return SDValue();
10080 
10081   unsigned LogicOpcode = LogicOp.getOpcode();
10082   if (LogicOpcode != ISD::AND && LogicOpcode != ISD::OR &&
10083       LogicOpcode != ISD::XOR)
10084     return SDValue();
10085 
10086   // Find a matching one-use shift by constant.
10087   unsigned ShiftOpcode = Shift->getOpcode();
10088   SDValue C1 = Shift->getOperand(1);
10089   ConstantSDNode *C1Node = isConstOrConstSplat(C1);
10090   assert(C1Node && "Expected a shift with constant operand");
10091   const APInt &C1Val = C1Node->getAPIntValue();
10092   auto matchFirstShift = [&](SDValue V, SDValue &ShiftOp,
10093                              const APInt *&ShiftAmtVal) {
10094     if (V.getOpcode() != ShiftOpcode || !V.hasOneUse())
10095       return false;
10096 
10097     ConstantSDNode *ShiftCNode = isConstOrConstSplat(V.getOperand(1));
10098     if (!ShiftCNode)
10099       return false;
10100 
10101     // Capture the shifted operand and shift amount value.
10102     ShiftOp = V.getOperand(0);
10103     ShiftAmtVal = &ShiftCNode->getAPIntValue();
10104 
10105     // Shift amount types do not have to match their operand type, so check that
10106     // the constants are the same width.
10107     if (ShiftAmtVal->getBitWidth() != C1Val.getBitWidth())
10108       return false;
10109 
10110     // The fold is not valid if the sum of the shift values doesn't fit in the
10111     // given shift amount type.
10112     bool Overflow = false;
10113     APInt NewShiftAmt = C1Val.uadd_ov(*ShiftAmtVal, Overflow);
10114     if (Overflow)
10115       return false;
10116 
10117     // The fold is not valid if the sum of the shift values exceeds bitwidth.
10118     if (NewShiftAmt.uge(V.getScalarValueSizeInBits()))
10119       return false;
10120 
10121     return true;
10122   };
10123 
10124   // Logic ops are commutative, so check each operand for a match.
10125   SDValue X, Y;
10126   const APInt *C0Val;
10127   if (matchFirstShift(LogicOp.getOperand(0), X, C0Val))
10128     Y = LogicOp.getOperand(1);
10129   else if (matchFirstShift(LogicOp.getOperand(1), X, C0Val))
10130     Y = LogicOp.getOperand(0);
10131   else
10132     return SDValue();
10133 
10134   // shift (logic (shift X, C0), Y), C1 -> logic (shift X, C0+C1), (shift Y, C1)
10135   SDLoc DL(Shift);
10136   EVT VT = Shift->getValueType(0);
10137   EVT ShiftAmtVT = Shift->getOperand(1).getValueType();
10138   SDValue ShiftSumC = DAG.getConstant(*C0Val + C1Val, DL, ShiftAmtVT);
10139   SDValue NewShift1 = DAG.getNode(ShiftOpcode, DL, VT, X, ShiftSumC);
10140   SDValue NewShift2 = DAG.getNode(ShiftOpcode, DL, VT, Y, C1);
10141   return DAG.getNode(LogicOpcode, DL, VT, NewShift1, NewShift2,
10142                      LogicOp->getFlags());
10143 }
10144 
10145 /// Handle transforms common to the three shifts, when the shift amount is a
10146 /// constant.
10147 /// We are looking for: (shift being one of shl/sra/srl)
10148 ///   shift (binop X, C0), C1
10149 /// And want to transform into:
10150 ///   binop (shift X, C1), (shift C0, C1)
visitShiftByConstant(SDNode * N)10151 SDValue DAGCombiner::visitShiftByConstant(SDNode *N) {
10152   assert(isConstOrConstSplat(N->getOperand(1)) && "Expected constant operand");
10153 
10154   // Do not turn a 'not' into a regular xor.
10155   if (isBitwiseNot(N->getOperand(0)))
10156     return SDValue();
10157 
10158   // The inner binop must be one-use, since we want to replace it.
10159   SDValue LHS = N->getOperand(0);
10160   if (!LHS.hasOneUse() || !TLI.isDesirableToCommuteWithShift(N, Level))
10161     return SDValue();
10162 
10163   // Fold shift(bitop(shift(x,c1),y), c2) -> bitop(shift(x,c1+c2),shift(y,c2)).
10164   if (SDValue R = combineShiftOfShiftedLogic(N, DAG))
10165     return R;
10166 
10167   // We want to pull some binops through shifts, so that we have (and (shift))
10168   // instead of (shift (and)), likewise for add, or, xor, etc.  This sort of
10169   // thing happens with address calculations, so it's important to canonicalize
10170   // it.
10171   switch (LHS.getOpcode()) {
10172   default:
10173     return SDValue();
10174   case ISD::OR:
10175   case ISD::XOR:
10176   case ISD::AND:
10177     break;
10178   case ISD::ADD:
10179     if (N->getOpcode() != ISD::SHL)
10180       return SDValue(); // only shl(add) not sr[al](add).
10181     break;
10182   }
10183 
10184   // FIXME: disable this unless the input to the binop is a shift by a constant
10185   // or is copy/select. Enable this in other cases when figure out it's exactly
10186   // profitable.
10187   SDValue BinOpLHSVal = LHS.getOperand(0);
10188   bool IsShiftByConstant = (BinOpLHSVal.getOpcode() == ISD::SHL ||
10189                             BinOpLHSVal.getOpcode() == ISD::SRA ||
10190                             BinOpLHSVal.getOpcode() == ISD::SRL) &&
10191                            isa<ConstantSDNode>(BinOpLHSVal.getOperand(1));
10192   bool IsCopyOrSelect = BinOpLHSVal.getOpcode() == ISD::CopyFromReg ||
10193                         BinOpLHSVal.getOpcode() == ISD::SELECT;
10194 
10195   if (!IsShiftByConstant && !IsCopyOrSelect)
10196     return SDValue();
10197 
10198   if (IsCopyOrSelect && N->hasOneUse())
10199     return SDValue();
10200 
10201   // Attempt to fold the constants, shifting the binop RHS by the shift amount.
10202   SDLoc DL(N);
10203   EVT VT = N->getValueType(0);
10204   if (SDValue NewRHS = DAG.FoldConstantArithmetic(
10205           N->getOpcode(), DL, VT, {LHS.getOperand(1), N->getOperand(1)})) {
10206     SDValue NewShift = DAG.getNode(N->getOpcode(), DL, VT, LHS.getOperand(0),
10207                                    N->getOperand(1));
10208     return DAG.getNode(LHS.getOpcode(), DL, VT, NewShift, NewRHS);
10209   }
10210 
10211   return SDValue();
10212 }
10213 
distributeTruncateThroughAnd(SDNode * N)10214 SDValue DAGCombiner::distributeTruncateThroughAnd(SDNode *N) {
10215   assert(N->getOpcode() == ISD::TRUNCATE);
10216   assert(N->getOperand(0).getOpcode() == ISD::AND);
10217 
10218   // (truncate:TruncVT (and N00, N01C)) -> (and (truncate:TruncVT N00), TruncC)
10219   EVT TruncVT = N->getValueType(0);
10220   if (N->hasOneUse() && N->getOperand(0).hasOneUse() &&
10221       TLI.isTypeDesirableForOp(ISD::AND, TruncVT)) {
10222     SDValue N01 = N->getOperand(0).getOperand(1);
10223     if (isConstantOrConstantVector(N01, /* NoOpaques */ true)) {
10224       SDLoc DL(N);
10225       SDValue N00 = N->getOperand(0).getOperand(0);
10226       SDValue Trunc00 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, N00);
10227       SDValue Trunc01 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, N01);
10228       AddToWorklist(Trunc00.getNode());
10229       AddToWorklist(Trunc01.getNode());
10230       return DAG.getNode(ISD::AND, DL, TruncVT, Trunc00, Trunc01);
10231     }
10232   }
10233 
10234   return SDValue();
10235 }
10236 
visitRotate(SDNode * N)10237 SDValue DAGCombiner::visitRotate(SDNode *N) {
10238   SDLoc dl(N);
10239   SDValue N0 = N->getOperand(0);
10240   SDValue N1 = N->getOperand(1);
10241   EVT VT = N->getValueType(0);
10242   unsigned Bitsize = VT.getScalarSizeInBits();
10243 
10244   // fold (rot x, 0) -> x
10245   if (isNullOrNullSplat(N1))
10246     return N0;
10247 
10248   // fold (rot x, c) -> x iff (c % BitSize) == 0
10249   if (isPowerOf2_32(Bitsize) && Bitsize > 1) {
10250     APInt ModuloMask(N1.getScalarValueSizeInBits(), Bitsize - 1);
10251     if (DAG.MaskedValueIsZero(N1, ModuloMask))
10252       return N0;
10253   }
10254 
10255   // fold (rot x, c) -> (rot x, c % BitSize)
10256   bool OutOfRange = false;
10257   auto MatchOutOfRange = [Bitsize, &OutOfRange](ConstantSDNode *C) {
10258     OutOfRange |= C->getAPIntValue().uge(Bitsize);
10259     return true;
10260   };
10261   if (ISD::matchUnaryPredicate(N1, MatchOutOfRange) && OutOfRange) {
10262     EVT AmtVT = N1.getValueType();
10263     SDValue Bits = DAG.getConstant(Bitsize, dl, AmtVT);
10264     if (SDValue Amt =
10265             DAG.FoldConstantArithmetic(ISD::UREM, dl, AmtVT, {N1, Bits}))
10266       return DAG.getNode(N->getOpcode(), dl, VT, N0, Amt);
10267   }
10268 
10269   // rot i16 X, 8 --> bswap X
10270   auto *RotAmtC = isConstOrConstSplat(N1);
10271   if (RotAmtC && RotAmtC->getAPIntValue() == 8 &&
10272       VT.getScalarSizeInBits() == 16 && hasOperation(ISD::BSWAP, VT))
10273     return DAG.getNode(ISD::BSWAP, dl, VT, N0);
10274 
10275   // Simplify the operands using demanded-bits information.
10276   if (SimplifyDemandedBits(SDValue(N, 0)))
10277     return SDValue(N, 0);
10278 
10279   // fold (rot* x, (trunc (and y, c))) -> (rot* x, (and (trunc y), (trunc c))).
10280   if (N1.getOpcode() == ISD::TRUNCATE &&
10281       N1.getOperand(0).getOpcode() == ISD::AND) {
10282     if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
10283       return DAG.getNode(N->getOpcode(), dl, VT, N0, NewOp1);
10284   }
10285 
10286   unsigned NextOp = N0.getOpcode();
10287 
10288   // fold (rot* (rot* x, c2), c1)
10289   //   -> (rot* x, ((c1 % bitsize) +- (c2 % bitsize) + bitsize) % bitsize)
10290   if (NextOp == ISD::ROTL || NextOp == ISD::ROTR) {
10291     bool C1 = DAG.isConstantIntBuildVectorOrConstantInt(N1);
10292     bool C2 = DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1));
10293     if (C1 && C2 && N1.getValueType() == N0.getOperand(1).getValueType()) {
10294       EVT ShiftVT = N1.getValueType();
10295       bool SameSide = (N->getOpcode() == NextOp);
10296       unsigned CombineOp = SameSide ? ISD::ADD : ISD::SUB;
10297       SDValue BitsizeC = DAG.getConstant(Bitsize, dl, ShiftVT);
10298       SDValue Norm1 = DAG.FoldConstantArithmetic(ISD::UREM, dl, ShiftVT,
10299                                                  {N1, BitsizeC});
10300       SDValue Norm2 = DAG.FoldConstantArithmetic(ISD::UREM, dl, ShiftVT,
10301                                                  {N0.getOperand(1), BitsizeC});
10302       if (Norm1 && Norm2)
10303         if (SDValue CombinedShift = DAG.FoldConstantArithmetic(
10304                 CombineOp, dl, ShiftVT, {Norm1, Norm2})) {
10305           CombinedShift = DAG.FoldConstantArithmetic(ISD::ADD, dl, ShiftVT,
10306                                                      {CombinedShift, BitsizeC});
10307           SDValue CombinedShiftNorm = DAG.FoldConstantArithmetic(
10308               ISD::UREM, dl, ShiftVT, {CombinedShift, BitsizeC});
10309           return DAG.getNode(N->getOpcode(), dl, VT, N0->getOperand(0),
10310                              CombinedShiftNorm);
10311         }
10312     }
10313   }
10314   return SDValue();
10315 }
10316 
/// Combine an ISD::SHL node. The folds below are tried in source order and the
/// first one that yields a replacement is returned; an empty SDValue means no
/// change. Several folds inspect the same operand shapes (e.g. SRL with and
/// without the exact flag), so their relative order is significant.
visitSHL(SDNode * N)10317 SDValue DAGCombiner::visitSHL(SDNode *N) {
10318   SDValue N0 = N->getOperand(0);
10319   SDValue N1 = N->getOperand(1);
  // Generic shift simplifications (SelectionDAG::simplifyShift, defined
  // elsewhere) get the first chance to produce a replacement.
10320   if (SDValue V = DAG.simplifyShift(N0, N1))
10321     return V;
10322 
10323   SDLoc DL(N);
10324   EVT VT = N0.getValueType();
10325   EVT ShiftVT = N1.getValueType();
10326   unsigned OpSizeInBits = VT.getScalarSizeInBits();
10327 
10328   // fold (shl c1, c2) -> c1<<c2
10329   if (SDValue C = DAG.FoldConstantArithmetic(ISD::SHL, DL, VT, {N0, N1}))
10330     return C;
10331 
10332   // fold vector ops
10333   if (VT.isVector()) {
10334     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
10335       return FoldedVOp;
10336 
10337     BuildVectorSDNode *N1CV = dyn_cast<BuildVectorSDNode>(N1);
10338     // If setcc produces all-one true value then:
10339     // (shl (and (setcc) N01CV) N1CV) -> (and (setcc) N01CV<<N1CV)
10340     if (N1CV && N1CV->isConstant()) {
10341       if (N0.getOpcode() == ISD::AND) {
10342         SDValue N00 = N0->getOperand(0);
10343         SDValue N01 = N0->getOperand(1);
10344         BuildVectorSDNode *N01CV = dyn_cast<BuildVectorSDNode>(N01);
10345 
10346         if (N01CV && N01CV->isConstant() && N00.getOpcode() == ISD::SETCC &&
10347             TLI.getBooleanContents(N00.getOperand(0).getValueType()) ==
10348                 TargetLowering::ZeroOrNegativeOneBooleanContent) {
10349           if (SDValue C =
10350                   DAG.FoldConstantArithmetic(ISD::SHL, DL, VT, {N01, N1}))
10351             return DAG.getNode(ISD::AND, DL, VT, N00, C);
10352         }
10353       }
10354     }
10355   }
10356 
10357   if (SDValue NewSel = foldBinOpIntoSelect(N))
10358     return NewSel;
10359 
10360   // if (shl x, c) is known to be zero, return 0
10361   if (DAG.MaskedValueIsZero(SDValue(N, 0), APInt::getAllOnes(OpSizeInBits)))
10362     return DAG.getConstant(0, DL, VT);
10363 
10364   // fold (shl x, (trunc (and y, c))) -> (shl x, (and (trunc y), (trunc c))).
10365   if (N1.getOpcode() == ISD::TRUNCATE &&
10366       N1.getOperand(0).getOpcode() == ISD::AND) {
10367     if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
10368       return DAG.getNode(ISD::SHL, DL, VT, N0, NewOp1);
10369   }
10370 
10371   // fold (shl (shl x, c1), c2) -> 0 or (shl x, (add c1, c2))
10372   if (N0.getOpcode() == ISD::SHL) {
    // The "1 /* Overflow Bit */" argument presumably widens both amounts to a
    // common width with one spare bit so that c1 + c2 cannot wrap (helper is
    // defined elsewhere -- confirm).
10373     auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS,
10374                                           ConstantSDNode *RHS) {
10375       APInt c1 = LHS->getAPIntValue();
10376       APInt c2 = RHS->getAPIntValue();
10377       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
10378       return (c1 + c2).uge(OpSizeInBits);
10379     };
10380     if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange))
10381       return DAG.getConstant(0, DL, VT);
10382 
10383     auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
10384                                        ConstantSDNode *RHS) {
10385       APInt c1 = LHS->getAPIntValue();
10386       APInt c2 = RHS->getAPIntValue();
10387       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
10388       return (c1 + c2).ult(OpSizeInBits);
10389     };
10390     if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) {
10391       SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1));
10392       return DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Sum);
10393     }
10394   }
10395 
10396   // fold (shl (ext (shl x, c1)), c2) -> (shl (ext x), (add c1, c2))
10397   // For this to be valid, the second form must not preserve any of the bits
10398   // that are shifted out by the inner shift in the first form.  This means
10399   // the outer shift size must be >= the number of bits added by the ext.
10400   // As a corollary, we don't care what kind of ext it is.
10401   if ((N0.getOpcode() == ISD::ZERO_EXTEND ||
10402        N0.getOpcode() == ISD::ANY_EXTEND ||
10403        N0.getOpcode() == ISD::SIGN_EXTEND) &&
10404       N0.getOperand(0).getOpcode() == ISD::SHL) {
10405     SDValue N0Op0 = N0.getOperand(0);
10406     SDValue InnerShiftAmt = N0Op0.getOperand(1);
10407     EVT InnerVT = N0Op0.getValueType();
10408     uint64_t InnerBitwidth = InnerVT.getScalarSizeInBits();
10409 
    // The inner and outer amounts may have different types (hence
    // AllowTypeMismatch below); the inner amount is converted to ShiftVT
    // before the add.
10410     auto MatchOutOfRange = [OpSizeInBits, InnerBitwidth](ConstantSDNode *LHS,
10411                                                          ConstantSDNode *RHS) {
10412       APInt c1 = LHS->getAPIntValue();
10413       APInt c2 = RHS->getAPIntValue();
10414       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
10415       return c2.uge(OpSizeInBits - InnerBitwidth) &&
10416              (c1 + c2).uge(OpSizeInBits);
10417     };
10418     if (ISD::matchBinaryPredicate(InnerShiftAmt, N1, MatchOutOfRange,
10419                                   /*AllowUndefs*/ false,
10420                                   /*AllowTypeMismatch*/ true))
10421       return DAG.getConstant(0, DL, VT);
10422 
10423     auto MatchInRange = [OpSizeInBits, InnerBitwidth](ConstantSDNode *LHS,
10424                                                       ConstantSDNode *RHS) {
10425       APInt c1 = LHS->getAPIntValue();
10426       APInt c2 = RHS->getAPIntValue();
10427       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
10428       return c2.uge(OpSizeInBits - InnerBitwidth) &&
10429              (c1 + c2).ult(OpSizeInBits);
10430     };
10431     if (ISD::matchBinaryPredicate(InnerShiftAmt, N1, MatchInRange,
10432                                   /*AllowUndefs*/ false,
10433                                   /*AllowTypeMismatch*/ true)) {
10434       SDValue Ext = DAG.getNode(N0.getOpcode(), DL, VT, N0Op0.getOperand(0));
10435       SDValue Sum = DAG.getZExtOrTrunc(InnerShiftAmt, DL, ShiftVT);
10436       Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, Sum, N1);
10437       return DAG.getNode(ISD::SHL, DL, VT, Ext, Sum);
10438     }
10439   }
10440 
10441   // fold (shl (zext (srl x, C)), C) -> (zext (shl (srl x, C), C))
10442   // Only fold this if the inner zext has no other uses to avoid increasing
10443   // the total number of instructions.
10444   if (N0.getOpcode() == ISD::ZERO_EXTEND && N0.hasOneUse() &&
10445       N0.getOperand(0).getOpcode() == ISD::SRL) {
10446     SDValue N0Op0 = N0.getOperand(0);
10447     SDValue InnerShiftAmt = N0Op0.getOperand(1);
10448 
    // Both amounts must be the same value, and below VT's element width.
10449     auto MatchEqual = [VT](ConstantSDNode *LHS, ConstantSDNode *RHS) {
10450       APInt c1 = LHS->getAPIntValue();
10451       APInt c2 = RHS->getAPIntValue();
10452       zeroExtendToMatch(c1, c2);
10453       return c1.ult(VT.getScalarSizeInBits()) && (c1 == c2);
10454     };
10455     if (ISD::matchBinaryPredicate(InnerShiftAmt, N1, MatchEqual,
10456                                   /*AllowUndefs*/ false,
10457                                   /*AllowTypeMismatch*/ true)) {
10458       EVT InnerShiftAmtVT = N0Op0.getOperand(1).getValueType();
10459       SDValue NewSHL = DAG.getZExtOrTrunc(N1, DL, InnerShiftAmtVT);
10460       NewSHL = DAG.getNode(ISD::SHL, DL, N0Op0.getValueType(), N0Op0, NewSHL);
10461       AddToWorklist(NewSHL.getNode());
10462       return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N0), VT, NewSHL);
10463     }
10464   }
10465 
10466   if (N0.getOpcode() == ISD::SRL || N0.getOpcode() == ISD::SRA) {
    // True when both constant amounts are below the bit width and LHS <= RHS.
10467     auto MatchShiftAmount = [OpSizeInBits](ConstantSDNode *LHS,
10468                                            ConstantSDNode *RHS) {
10469       const APInt &LHSC = LHS->getAPIntValue();
10470       const APInt &RHSC = RHS->getAPIntValue();
10471       return LHSC.ult(OpSizeInBits) && RHSC.ult(OpSizeInBits) &&
10472              LHSC.getZExtValue() <= RHSC.getZExtValue();
10473     };
10474 
10475     // fold (shl (sr[la] exact X,  C1), C2) -> (shl    X, (C2-C1)) if C1 <= C2
10476     // fold (shl (sr[la] exact X,  C1), C2) -> (sr[la] X, (C2-C1)) if C1 >= C2
10477     if (N0->getFlags().hasExact()) {
10478       if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchShiftAmount,
10479                                     /*AllowUndefs*/ false,
10480                                     /*AllowTypeMismatch*/ true)) {
10481         SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
10482         SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N1, N01);
10483         return DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Diff);
10484       }
10485       if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchShiftAmount,
10486                                     /*AllowUndefs*/ false,
10487                                     /*AllowTypeMismatch*/ true)) {
10488         SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
10489         SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N01, N1);
10490         return DAG.getNode(N0.getOpcode(), DL, VT, N0.getOperand(0), Diff);
10491       }
10492     }
10493 
10494     // fold (shl (srl x, c1), c2) -> (and (shl x, (sub c2, c1)), MASK) or
10495     //                               (and (srl x, (sub c1, c2)), MASK)
10496     // Only fold this if the inner shift has no other uses -- if it does,
10497     // folding this will increase the total number of instructions.
    // A multi-use inner shift is still accepted when its amount is the same
    // SDValue as N1.
10498     if (N0.getOpcode() == ISD::SRL &&
10499         (N0.getOperand(1) == N1 || N0.hasOneUse()) &&
10500         TLI.shouldFoldConstantShiftPairToMask(N, Level)) {
10501       if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchShiftAmount,
10502                                     /*AllowUndefs*/ false,
10503                                     /*AllowTypeMismatch*/ true)) {
10504         SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
10505         SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N01, N1);
10506         SDValue Mask = DAG.getAllOnesConstant(DL, VT);
10507         Mask = DAG.getNode(ISD::SHL, DL, VT, Mask, N01);
10508         Mask = DAG.getNode(ISD::SRL, DL, VT, Mask, Diff);
10509         SDValue Shift = DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), Diff);
10510         return DAG.getNode(ISD::AND, DL, VT, Shift, Mask);
10511       }
10512       if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchShiftAmount,
10513                                     /*AllowUndefs*/ false,
10514                                     /*AllowTypeMismatch*/ true)) {
10515         SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
10516         SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N1, N01);
10517         SDValue Mask = DAG.getAllOnesConstant(DL, VT);
10518         Mask = DAG.getNode(ISD::SHL, DL, VT, Mask, N1);
10519         SDValue Shift = DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Diff);
10520         return DAG.getNode(ISD::AND, DL, VT, Shift, Mask);
10521       }
10522     }
10523   }
10524 
10525   // fold (shl (sra x, c1), c1) -> (and x, (shl -1, c1))
10526   if (N0.getOpcode() == ISD::SRA && N1 == N0.getOperand(1) &&
10527       isConstantOrConstantVector(N1, /* No Opaques */ true)) {
10528     SDValue AllBits = DAG.getAllOnesConstant(DL, VT);
10529     SDValue HiBitsMask = DAG.getNode(ISD::SHL, DL, VT, AllBits, N1);
10530     return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), HiBitsMask);
10531   }
10532 
10533   // fold (shl (add x, c1), c2) -> (add (shl x, c2), c1 << c2)
10534   // fold (shl (or x, c1), c2) -> (or (shl x, c2), c1 << c2)
10535   // Variant of version done on multiply, except mul by a power of 2 is turned
10536   // into a shift.
10537   if ((N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::OR) &&
10538       TLI.isDesirableToCommuteWithShift(N, Level)) {
10539     SDValue N01 = N0.getOperand(1);
10540     if (SDValue Shl1 =
10541             DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N1), VT, {N01, N1})) {
10542       SDValue Shl0 = DAG.getNode(ISD::SHL, SDLoc(N0), VT, N0.getOperand(0), N1);
10543       AddToWorklist(Shl0.getNode());
10544       SDNodeFlags Flags;
10545       // Preserve the disjoint flag for Or.
10546       if (N0.getOpcode() == ISD::OR && N0->getFlags().hasDisjoint())
10547         Flags |= SDNodeFlags::Disjoint;
10548       return DAG.getNode(N0.getOpcode(), DL, VT, Shl0, Shl1, Flags);
10549     }
10550   }
10551 
10552   // fold (shl (sext (add_nsw x, c1)), c2) -> (add (shl (sext x), c2), c1 << c2)
10553   // TODO: Add zext/add_nuw variant with suitable test coverage
10554   // TODO: Should we limit this with isLegalAddImmediate?
10555   if (N0.getOpcode() == ISD::SIGN_EXTEND &&
10556       N0.getOperand(0).getOpcode() == ISD::ADD &&
10557       N0.getOperand(0)->getFlags().hasNoSignedWrap() &&
10558       TLI.isDesirableToCommuteWithShift(N, Level)) {
10559     SDValue Add = N0.getOperand(0);
    // Note: this DL (location of N0) shadows the function-level DL for the
    // rest of this block, so the new nodes use N0's location.
10560     SDLoc DL(N0);
10561     if (SDValue ExtC = DAG.FoldConstantArithmetic(N0.getOpcode(), DL, VT,
10562                                                   {Add.getOperand(1)})) {
10563       if (SDValue ShlC =
10564               DAG.FoldConstantArithmetic(ISD::SHL, DL, VT, {ExtC, N1})) {
10565         SDValue ExtX = DAG.getNode(N0.getOpcode(), DL, VT, Add.getOperand(0));
10566         SDValue ShlX = DAG.getNode(ISD::SHL, DL, VT, ExtX, N1);
10567         return DAG.getNode(ISD::ADD, DL, VT, ShlX, ShlC);
10568       }
10569     }
10570   }
10571 
10572   // fold (shl (mul x, c1), c2) -> (mul x, c1 << c2)
10573   if (N0.getOpcode() == ISD::MUL && N0->hasOneUse()) {
10574     SDValue N01 = N0.getOperand(1);
10575     if (SDValue Shl =
10576             DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N1), VT, {N01, N1}))
10577       return DAG.getNode(ISD::MUL, DL, VT, N0.getOperand(0), Shl);
10578   }
10579 
  // Remaining constant-amount combines shared with SRA/SRL.
10580   ConstantSDNode *N1C = isConstOrConstSplat(N1);
10581   if (N1C && !N1C->isOpaque())
10582     if (SDValue NewSHL = visitShiftByConstant(N))
10583       return NewSHL;
10584 
10585   // fold (shl X, cttz(Y)) -> (mul (Y & -Y), X) if cttz is unsupported on the
10586   // target.
  // NOTE(review): for plain CTTZ (defined at zero), the bit-width check
  // appears to ensure that cttz(0) is an out-of-range shift amount for VT, so
  // the mul form (which yields 0 when Y == 0) is acceptable; CTTZ_ZERO_UNDEF
  // needs no such check. Confirm against ISD::CTTZ semantics.
10587   if (((N1.getOpcode() == ISD::CTTZ &&
10588         VT.getScalarSizeInBits() <= ShiftVT.getScalarSizeInBits()) ||
10589        N1.getOpcode() == ISD::CTTZ_ZERO_UNDEF) &&
10590       N1.hasOneUse() && !TLI.isOperationLegalOrCustom(ISD::CTTZ, ShiftVT) &&
10591       TLI.isOperationLegalOrCustom(ISD::MUL, VT)) {
10592     SDValue Y = N1.getOperand(0);
10593     SDLoc DL(N);
10594     SDValue NegY = DAG.getNegative(Y, DL, ShiftVT);
10595     SDValue And =
10596         DAG.getZExtOrTrunc(DAG.getNode(ISD::AND, DL, ShiftVT, Y, NegY), DL, VT);
10597     return DAG.getNode(ISD::MUL, DL, VT, And, N0);
10598   }
10599 
10600   if (SimplifyDemandedBits(SDValue(N, 0)))
10601     return SDValue(N, 0);
10602 
10603   // Fold (shl (vscale * C0), C1) to (vscale * (C0 << C1)).
10604   if (N0.getOpcode() == ISD::VSCALE && N1C) {
10605     const APInt &C0 = N0.getConstantOperandAPInt(0);
10606     const APInt &C1 = N1C->getAPIntValue();
10607     return DAG.getVScale(DL, VT, C0 << C1);
10608   }
10609 
10610   // Fold (shl step_vector(C0), C1) to (step_vector(C0 << C1)).
10611   APInt ShlVal;
10612   if (N0.getOpcode() == ISD::STEP_VECTOR &&
10613       ISD::isConstantSplatVector(N1.getNode(), ShlVal)) {
10614     const APInt &C0 = N0.getConstantOperandAPInt(0);
10615     if (ShlVal.ult(C0.getBitWidth())) {
10616       APInt NewStep = C0 << ShlVal;
10617       return DAG.getStepVector(DL, VT, NewStep);
10618     }
10619   }
10620 
10621   return SDValue();
10622 }
10623 
10624 // Transform a right shift of a multiply into a multiply-high.
10625 // Examples:
10626 // (srl (mul (zext i32:$a to i64), (zext i32:$a to i64)), 32) -> (mulhu $a, $b)
10627 // (sra (mul (sext i32:$a to i64), (sext i32:$a to i64)), 32) -> (mulhs $a, $b)
combineShiftToMULH(SDNode * N,const SDLoc & DL,SelectionDAG & DAG,const TargetLowering & TLI)10628 static SDValue combineShiftToMULH(SDNode *N, const SDLoc &DL, SelectionDAG &DAG,
10629                                   const TargetLowering &TLI) {
10630   assert((N->getOpcode() == ISD::SRL || N->getOpcode() == ISD::SRA) &&
10631          "SRL or SRA node is required here!");
10632 
10633   // Check the shift amount. Proceed with the transformation if the shift
10634   // amount is constant.
10635   ConstantSDNode *ShiftAmtSrc = isConstOrConstSplat(N->getOperand(1));
10636   if (!ShiftAmtSrc)
10637     return SDValue();
10638 
10639   // The operation feeding into the shift must be a multiply.
10640   SDValue ShiftOperand = N->getOperand(0);
10641   if (ShiftOperand.getOpcode() != ISD::MUL)
10642     return SDValue();
10643 
10644   // Both operands must be equivalent extend nodes.
10645   SDValue LeftOp = ShiftOperand.getOperand(0);
10646   SDValue RightOp = ShiftOperand.getOperand(1);
10647 
10648   bool IsSignExt = LeftOp.getOpcode() == ISD::SIGN_EXTEND;
10649   bool IsZeroExt = LeftOp.getOpcode() == ISD::ZERO_EXTEND;
10650 
10651   if (!IsSignExt && !IsZeroExt)
10652     return SDValue();
10653 
10654   EVT NarrowVT = LeftOp.getOperand(0).getValueType();
10655   unsigned NarrowVTSize = NarrowVT.getScalarSizeInBits();
10656 
10657   // return true if U may use the lower bits of its operands
10658   auto UserOfLowerBits = [NarrowVTSize](SDNode *U) {
10659     if (U->getOpcode() != ISD::SRL && U->getOpcode() != ISD::SRA) {
10660       return true;
10661     }
10662     ConstantSDNode *UShiftAmtSrc = isConstOrConstSplat(U->getOperand(1));
10663     if (!UShiftAmtSrc) {
10664       return true;
10665     }
10666     unsigned UShiftAmt = UShiftAmtSrc->getZExtValue();
10667     return UShiftAmt < NarrowVTSize;
10668   };
10669 
10670   // If the lower part of the MUL is also used and MUL_LOHI is supported
10671   // do not introduce the MULH in favor of MUL_LOHI
10672   unsigned MulLoHiOp = IsSignExt ? ISD::SMUL_LOHI : ISD::UMUL_LOHI;
10673   if (!ShiftOperand.hasOneUse() &&
10674       TLI.isOperationLegalOrCustom(MulLoHiOp, NarrowVT) &&
10675       llvm::any_of(ShiftOperand->users(), UserOfLowerBits)) {
10676     return SDValue();
10677   }
10678 
10679   SDValue MulhRightOp;
10680   if (ConstantSDNode *Constant = isConstOrConstSplat(RightOp)) {
10681     unsigned ActiveBits = IsSignExt
10682                               ? Constant->getAPIntValue().getSignificantBits()
10683                               : Constant->getAPIntValue().getActiveBits();
10684     if (ActiveBits > NarrowVTSize)
10685       return SDValue();
10686     MulhRightOp = DAG.getConstant(
10687         Constant->getAPIntValue().trunc(NarrowVT.getScalarSizeInBits()), DL,
10688         NarrowVT);
10689   } else {
10690     if (LeftOp.getOpcode() != RightOp.getOpcode())
10691       return SDValue();
10692     // Check that the two extend nodes are the same type.
10693     if (NarrowVT != RightOp.getOperand(0).getValueType())
10694       return SDValue();
10695     MulhRightOp = RightOp.getOperand(0);
10696   }
10697 
10698   EVT WideVT = LeftOp.getValueType();
10699   // Proceed with the transformation if the wide types match.
10700   assert((WideVT == RightOp.getValueType()) &&
10701          "Cannot have a multiply node with two different operand types.");
10702 
10703   // Proceed with the transformation if the wide type is twice as large
10704   // as the narrow type.
10705   if (WideVT.getScalarSizeInBits() != 2 * NarrowVTSize)
10706     return SDValue();
10707 
10708   // Check the shift amount with the narrow type size.
10709   // Proceed with the transformation if the shift amount is the width
10710   // of the narrow type.
10711   unsigned ShiftAmt = ShiftAmtSrc->getZExtValue();
10712   if (ShiftAmt != NarrowVTSize)
10713     return SDValue();
10714 
10715   // If the operation feeding into the MUL is a sign extend (sext),
10716   // we use mulhs. Othewise, zero extends (zext) use mulhu.
10717   unsigned MulhOpcode = IsSignExt ? ISD::MULHS : ISD::MULHU;
10718 
10719   // Combine to mulh if mulh is legal/custom for the narrow type on the target
10720   // or if it is a vector type then we could transform to an acceptable type and
10721   // rely on legalization to split/combine the result.
10722   if (NarrowVT.isVector()) {
10723     EVT TransformVT = TLI.getTypeToTransformTo(*DAG.getContext(), NarrowVT);
10724     if (TransformVT.getVectorElementType() != NarrowVT.getVectorElementType() ||
10725         !TLI.isOperationLegalOrCustom(MulhOpcode, TransformVT))
10726       return SDValue();
10727   } else {
10728     if (!TLI.isOperationLegalOrCustom(MulhOpcode, NarrowVT))
10729       return SDValue();
10730   }
10731 
10732   SDValue Result =
10733       DAG.getNode(MulhOpcode, DL, NarrowVT, LeftOp.getOperand(0), MulhRightOp);
10734   bool IsSigned = N->getOpcode() == ISD::SRA;
10735   return DAG.getExtOrTrunc(IsSigned, Result, DL, WideVT);
10736 }
10737 
10738 // fold (bswap (logic_op(bswap(x),y))) -> logic_op(x,bswap(y))
10739 // This helper function accept SDNode with opcode ISD::BSWAP and ISD::BITREVERSE
foldBitOrderCrossLogicOp(SDNode * N,SelectionDAG & DAG)10740 static SDValue foldBitOrderCrossLogicOp(SDNode *N, SelectionDAG &DAG) {
10741   unsigned Opcode = N->getOpcode();
10742   if (Opcode != ISD::BSWAP && Opcode != ISD::BITREVERSE)
10743     return SDValue();
10744 
10745   SDValue N0 = N->getOperand(0);
10746   EVT VT = N->getValueType(0);
10747   SDLoc DL(N);
10748   SDValue X, Y;
10749 
10750   // If both operands are bswap/bitreverse, ignore the multiuse
10751   if (sd_match(N0, m_OneUse(m_BitwiseLogic(m_UnaryOp(Opcode, m_Value(X)),
10752                                            m_UnaryOp(Opcode, m_Value(Y))))))
10753     return DAG.getNode(N0.getOpcode(), DL, VT, X, Y);
10754 
10755   // Otherwise need to ensure logic_op and bswap/bitreverse(x) have one use.
10756   if (sd_match(N0, m_OneUse(m_BitwiseLogic(
10757                        m_OneUse(m_UnaryOp(Opcode, m_Value(X))), m_Value(Y))))) {
10758     SDValue NewBitReorder = DAG.getNode(Opcode, DL, VT, Y);
10759     return DAG.getNode(N0.getOpcode(), DL, VT, X, NewBitReorder);
10760   }
10761 
10762   return SDValue();
10763 }
10764 
visitSRA(SDNode * N)10765 SDValue DAGCombiner::visitSRA(SDNode *N) {
10766   SDValue N0 = N->getOperand(0);
10767   SDValue N1 = N->getOperand(1);
10768   if (SDValue V = DAG.simplifyShift(N0, N1))
10769     return V;
10770 
10771   SDLoc DL(N);
10772   EVT VT = N0.getValueType();
10773   unsigned OpSizeInBits = VT.getScalarSizeInBits();
10774 
10775   // fold (sra c1, c2) -> (sra c1, c2)
10776   if (SDValue C = DAG.FoldConstantArithmetic(ISD::SRA, DL, VT, {N0, N1}))
10777     return C;
10778 
10779   // Arithmetic shifting an all-sign-bit value is a no-op.
10780   // fold (sra 0, x) -> 0
10781   // fold (sra -1, x) -> -1
10782   if (DAG.ComputeNumSignBits(N0) == OpSizeInBits)
10783     return N0;
10784 
10785   // fold vector ops
10786   if (VT.isVector())
10787     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
10788       return FoldedVOp;
10789 
10790   if (SDValue NewSel = foldBinOpIntoSelect(N))
10791     return NewSel;
10792 
10793   ConstantSDNode *N1C = isConstOrConstSplat(N1);
10794 
10795   // fold (sra (sra x, c1), c2) -> (sra x, (add c1, c2))
10796   // clamp (add c1, c2) to max shift.
10797   if (N0.getOpcode() == ISD::SRA) {
10798     EVT ShiftVT = N1.getValueType();
10799     EVT ShiftSVT = ShiftVT.getScalarType();
10800     SmallVector<SDValue, 16> ShiftValues;
10801 
10802     auto SumOfShifts = [&](ConstantSDNode *LHS, ConstantSDNode *RHS) {
10803       APInt c1 = LHS->getAPIntValue();
10804       APInt c2 = RHS->getAPIntValue();
10805       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
10806       APInt Sum = c1 + c2;
10807       unsigned ShiftSum =
10808           Sum.uge(OpSizeInBits) ? (OpSizeInBits - 1) : Sum.getZExtValue();
10809       ShiftValues.push_back(DAG.getConstant(ShiftSum, DL, ShiftSVT));
10810       return true;
10811     };
10812     if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), SumOfShifts)) {
10813       SDValue ShiftValue;
10814       if (N1.getOpcode() == ISD::BUILD_VECTOR)
10815         ShiftValue = DAG.getBuildVector(ShiftVT, DL, ShiftValues);
10816       else if (N1.getOpcode() == ISD::SPLAT_VECTOR) {
10817         assert(ShiftValues.size() == 1 &&
10818                "Expected matchBinaryPredicate to return one element for "
10819                "SPLAT_VECTORs");
10820         ShiftValue = DAG.getSplatVector(ShiftVT, DL, ShiftValues[0]);
10821       } else
10822         ShiftValue = ShiftValues[0];
10823       return DAG.getNode(ISD::SRA, DL, VT, N0.getOperand(0), ShiftValue);
10824     }
10825   }
10826 
10827   // fold (sra (shl X, m), (sub result_size, n))
10828   // -> (sign_extend (trunc (shl X, (sub (sub result_size, n), m)))) for
10829   // result_size - n != m.
10830   // If truncate is free for the target sext(shl) is likely to result in better
10831   // code.
10832   if (N0.getOpcode() == ISD::SHL && N1C) {
10833     // Get the two constants of the shifts, CN0 = m, CN = n.
10834     const ConstantSDNode *N01C = isConstOrConstSplat(N0.getOperand(1));
10835     if (N01C) {
10836       LLVMContext &Ctx = *DAG.getContext();
10837       // Determine what the truncate's result bitsize and type would be.
10838       EVT TruncVT = EVT::getIntegerVT(Ctx, OpSizeInBits - N1C->getZExtValue());
10839 
10840       if (VT.isVector())
10841         TruncVT = EVT::getVectorVT(Ctx, TruncVT, VT.getVectorElementCount());
10842 
10843       // Determine the residual right-shift amount.
10844       int ShiftAmt = N1C->getZExtValue() - N01C->getZExtValue();
10845 
10846       // If the shift is not a no-op (in which case this should be just a sign
10847       // extend already), the truncated to type is legal, sign_extend is legal
10848       // on that type, and the truncate to that type is both legal and free,
10849       // perform the transform.
10850       if ((ShiftAmt > 0) &&
10851           TLI.isOperationLegalOrCustom(ISD::SIGN_EXTEND, TruncVT) &&
10852           TLI.isOperationLegalOrCustom(ISD::TRUNCATE, VT) &&
10853           TLI.isTruncateFree(VT, TruncVT)) {
10854         SDValue Amt = DAG.getShiftAmountConstant(ShiftAmt, VT, DL);
10855         SDValue Shift = DAG.getNode(ISD::SRL, DL, VT,
10856                                     N0.getOperand(0), Amt);
10857         SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, TruncVT,
10858                                     Shift);
10859         return DAG.getNode(ISD::SIGN_EXTEND, DL,
10860                            N->getValueType(0), Trunc);
10861       }
10862     }
10863   }
10864 
10865   // We convert trunc/ext to opposing shifts in IR, but casts may be cheaper.
10866   //   sra (add (shl X, N1C), AddC), N1C -->
10867   //   sext (add (trunc X to (width - N1C)), AddC')
10868   //   sra (sub AddC, (shl X, N1C)), N1C -->
10869   //   sext (sub AddC1',(trunc X to (width - N1C)))
10870   if ((N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::SUB) && N1C &&
10871       N0.hasOneUse()) {
10872     bool IsAdd = N0.getOpcode() == ISD::ADD;
10873     SDValue Shl = N0.getOperand(IsAdd ? 0 : 1);
10874     if (Shl.getOpcode() == ISD::SHL && Shl.getOperand(1) == N1 &&
10875         Shl.hasOneUse()) {
10876       // TODO: AddC does not need to be a splat.
10877       if (ConstantSDNode *AddC =
10878               isConstOrConstSplat(N0.getOperand(IsAdd ? 1 : 0))) {
10879         // Determine what the truncate's type would be and ask the target if
10880         // that is a free operation.
10881         LLVMContext &Ctx = *DAG.getContext();
10882         unsigned ShiftAmt = N1C->getZExtValue();
10883         EVT TruncVT = EVT::getIntegerVT(Ctx, OpSizeInBits - ShiftAmt);
10884         if (VT.isVector())
10885           TruncVT = EVT::getVectorVT(Ctx, TruncVT, VT.getVectorElementCount());
10886 
10887         // TODO: The simple type check probably belongs in the default hook
10888         //       implementation and/or target-specific overrides (because
10889         //       non-simple types likely require masking when legalized), but
10890         //       that restriction may conflict with other transforms.
10891         if (TruncVT.isSimple() && isTypeLegal(TruncVT) &&
10892             TLI.isTruncateFree(VT, TruncVT)) {
10893           SDValue Trunc = DAG.getZExtOrTrunc(Shl.getOperand(0), DL, TruncVT);
10894           SDValue ShiftC =
10895               DAG.getConstant(AddC->getAPIntValue().lshr(ShiftAmt).trunc(
10896                                   TruncVT.getScalarSizeInBits()),
10897                               DL, TruncVT);
10898           SDValue Add;
10899           if (IsAdd)
10900             Add = DAG.getNode(ISD::ADD, DL, TruncVT, Trunc, ShiftC);
10901           else
10902             Add = DAG.getNode(ISD::SUB, DL, TruncVT, ShiftC, Trunc);
10903           return DAG.getSExtOrTrunc(Add, DL, VT);
10904         }
10905       }
10906     }
10907   }
10908 
10909   // fold (sra x, (trunc (and y, c))) -> (sra x, (and (trunc y), (trunc c))).
10910   if (N1.getOpcode() == ISD::TRUNCATE &&
10911       N1.getOperand(0).getOpcode() == ISD::AND) {
10912     if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
10913       return DAG.getNode(ISD::SRA, DL, VT, N0, NewOp1);
10914   }
10915 
10916   // fold (sra (trunc (sra x, c1)), c2) -> (trunc (sra x, c1 + c2))
10917   // fold (sra (trunc (srl x, c1)), c2) -> (trunc (sra x, c1 + c2))
10918   //      if c1 is equal to the number of bits the trunc removes
10919   // TODO - support non-uniform vector shift amounts.
10920   if (N0.getOpcode() == ISD::TRUNCATE &&
10921       (N0.getOperand(0).getOpcode() == ISD::SRL ||
10922        N0.getOperand(0).getOpcode() == ISD::SRA) &&
10923       N0.getOperand(0).hasOneUse() &&
10924       N0.getOperand(0).getOperand(1).hasOneUse() && N1C) {
10925     SDValue N0Op0 = N0.getOperand(0);
10926     if (ConstantSDNode *LargeShift = isConstOrConstSplat(N0Op0.getOperand(1))) {
10927       EVT LargeVT = N0Op0.getValueType();
10928       unsigned TruncBits = LargeVT.getScalarSizeInBits() - OpSizeInBits;
10929       if (LargeShift->getAPIntValue() == TruncBits) {
10930         EVT LargeShiftVT = getShiftAmountTy(LargeVT);
10931         SDValue Amt = DAG.getZExtOrTrunc(N1, DL, LargeShiftVT);
10932         Amt = DAG.getNode(ISD::ADD, DL, LargeShiftVT, Amt,
10933                           DAG.getConstant(TruncBits, DL, LargeShiftVT));
10934         SDValue SRA =
10935             DAG.getNode(ISD::SRA, DL, LargeVT, N0Op0.getOperand(0), Amt);
10936         return DAG.getNode(ISD::TRUNCATE, DL, VT, SRA);
10937       }
10938     }
10939   }
10940 
10941   // Simplify, based on bits shifted out of the LHS.
10942   if (SimplifyDemandedBits(SDValue(N, 0)))
10943     return SDValue(N, 0);
10944 
10945   // If the sign bit is known to be zero, switch this to a SRL.
10946   if (DAG.SignBitIsZero(N0))
10947     return DAG.getNode(ISD::SRL, DL, VT, N0, N1);
10948 
10949   if (N1C && !N1C->isOpaque())
10950     if (SDValue NewSRA = visitShiftByConstant(N))
10951       return NewSRA;
10952 
10953   // Try to transform this shift into a multiply-high if
10954   // it matches the appropriate pattern detected in combineShiftToMULH.
10955   if (SDValue MULH = combineShiftToMULH(N, DL, DAG, TLI))
10956     return MULH;
10957 
10958   // Attempt to convert a sra of a load into a narrower sign-extending load.
10959   if (SDValue NarrowLoad = reduceLoadWidth(N))
10960     return NarrowLoad;
10961 
10962   if (SDValue AVG = foldShiftToAvg(N))
10963     return AVG;
10964 
10965   return SDValue();
10966 }
10967 
visitSRL(SDNode * N)10968 SDValue DAGCombiner::visitSRL(SDNode *N) {
10969   SDValue N0 = N->getOperand(0);
10970   SDValue N1 = N->getOperand(1);
10971   if (SDValue V = DAG.simplifyShift(N0, N1))
10972     return V;
10973 
10974   SDLoc DL(N);
10975   EVT VT = N0.getValueType();
10976   EVT ShiftVT = N1.getValueType();
10977   unsigned OpSizeInBits = VT.getScalarSizeInBits();
10978 
10979   // fold (srl c1, c2) -> c1 >>u c2
10980   if (SDValue C = DAG.FoldConstantArithmetic(ISD::SRL, DL, VT, {N0, N1}))
10981     return C;
10982 
10983   // fold vector ops
10984   if (VT.isVector())
10985     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
10986       return FoldedVOp;
10987 
10988   if (SDValue NewSel = foldBinOpIntoSelect(N))
10989     return NewSel;
10990 
10991   // if (srl x, c) is known to be zero, return 0
10992   ConstantSDNode *N1C = isConstOrConstSplat(N1);
10993   if (N1C &&
10994       DAG.MaskedValueIsZero(SDValue(N, 0), APInt::getAllOnes(OpSizeInBits)))
10995     return DAG.getConstant(0, DL, VT);
10996 
10997   // fold (srl (srl x, c1), c2) -> 0 or (srl x, (add c1, c2))
10998   if (N0.getOpcode() == ISD::SRL) {
10999     auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS,
11000                                           ConstantSDNode *RHS) {
11001       APInt c1 = LHS->getAPIntValue();
11002       APInt c2 = RHS->getAPIntValue();
11003       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
11004       return (c1 + c2).uge(OpSizeInBits);
11005     };
11006     if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange))
11007       return DAG.getConstant(0, DL, VT);
11008 
11009     auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
11010                                        ConstantSDNode *RHS) {
11011       APInt c1 = LHS->getAPIntValue();
11012       APInt c2 = RHS->getAPIntValue();
11013       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
11014       return (c1 + c2).ult(OpSizeInBits);
11015     };
11016     if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) {
11017       SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1));
11018       return DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), Sum);
11019     }
11020   }
11021 
11022   if (N1C && N0.getOpcode() == ISD::TRUNCATE &&
11023       N0.getOperand(0).getOpcode() == ISD::SRL) {
11024     SDValue InnerShift = N0.getOperand(0);
11025     // TODO - support non-uniform vector shift amounts.
11026     if (auto *N001C = isConstOrConstSplat(InnerShift.getOperand(1))) {
11027       uint64_t c1 = N001C->getZExtValue();
11028       uint64_t c2 = N1C->getZExtValue();
11029       EVT InnerShiftVT = InnerShift.getValueType();
11030       EVT ShiftAmtVT = InnerShift.getOperand(1).getValueType();
11031       uint64_t InnerShiftSize = InnerShiftVT.getScalarSizeInBits();
11032       // srl (trunc (srl x, c1)), c2 --> 0 or (trunc (srl x, (add c1, c2)))
11033       // This is only valid if the OpSizeInBits + c1 = size of inner shift.
11034       if (c1 + OpSizeInBits == InnerShiftSize) {
11035         if (c1 + c2 >= InnerShiftSize)
11036           return DAG.getConstant(0, DL, VT);
11037         SDValue NewShiftAmt = DAG.getConstant(c1 + c2, DL, ShiftAmtVT);
11038         SDValue NewShift = DAG.getNode(ISD::SRL, DL, InnerShiftVT,
11039                                        InnerShift.getOperand(0), NewShiftAmt);
11040         return DAG.getNode(ISD::TRUNCATE, DL, VT, NewShift);
11041       }
11042       // In the more general case, we can clear the high bits after the shift:
11043       // srl (trunc (srl x, c1)), c2 --> trunc (and (srl x, (c1+c2)), Mask)
11044       if (N0.hasOneUse() && InnerShift.hasOneUse() &&
11045           c1 + c2 < InnerShiftSize) {
11046         SDValue NewShiftAmt = DAG.getConstant(c1 + c2, DL, ShiftAmtVT);
11047         SDValue NewShift = DAG.getNode(ISD::SRL, DL, InnerShiftVT,
11048                                        InnerShift.getOperand(0), NewShiftAmt);
11049         SDValue Mask = DAG.getConstant(APInt::getLowBitsSet(InnerShiftSize,
11050                                                             OpSizeInBits - c2),
11051                                        DL, InnerShiftVT);
11052         SDValue And = DAG.getNode(ISD::AND, DL, InnerShiftVT, NewShift, Mask);
11053         return DAG.getNode(ISD::TRUNCATE, DL, VT, And);
11054       }
11055     }
11056   }
11057 
11058   // fold (srl (shl x, c1), c2) -> (and (shl x, (sub c1, c2), MASK) or
11059   //                               (and (srl x, (sub c2, c1), MASK)
11060   if (N0.getOpcode() == ISD::SHL &&
11061       (N0.getOperand(1) == N1 || N0->hasOneUse()) &&
11062       TLI.shouldFoldConstantShiftPairToMask(N, Level)) {
11063     auto MatchShiftAmount = [OpSizeInBits](ConstantSDNode *LHS,
11064                                            ConstantSDNode *RHS) {
11065       const APInt &LHSC = LHS->getAPIntValue();
11066       const APInt &RHSC = RHS->getAPIntValue();
11067       return LHSC.ult(OpSizeInBits) && RHSC.ult(OpSizeInBits) &&
11068              LHSC.getZExtValue() <= RHSC.getZExtValue();
11069     };
11070     if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchShiftAmount,
11071                                   /*AllowUndefs*/ false,
11072                                   /*AllowTypeMismatch*/ true)) {
11073       SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
11074       SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N01, N1);
11075       SDValue Mask = DAG.getAllOnesConstant(DL, VT);
11076       Mask = DAG.getNode(ISD::SRL, DL, VT, Mask, N01);
11077       Mask = DAG.getNode(ISD::SHL, DL, VT, Mask, Diff);
11078       SDValue Shift = DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Diff);
11079       return DAG.getNode(ISD::AND, DL, VT, Shift, Mask);
11080     }
11081     if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchShiftAmount,
11082                                   /*AllowUndefs*/ false,
11083                                   /*AllowTypeMismatch*/ true)) {
11084       SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
11085       SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N1, N01);
11086       SDValue Mask = DAG.getAllOnesConstant(DL, VT);
11087       Mask = DAG.getNode(ISD::SRL, DL, VT, Mask, N1);
11088       SDValue Shift = DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), Diff);
11089       return DAG.getNode(ISD::AND, DL, VT, Shift, Mask);
11090     }
11091   }
11092 
11093   // fold (srl (anyextend x), c) -> (and (anyextend (srl x, c)), mask)
11094   // TODO - support non-uniform vector shift amounts.
11095   if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) {
11096     // Shifting in all undef bits?
11097     EVT SmallVT = N0.getOperand(0).getValueType();
11098     unsigned BitSize = SmallVT.getScalarSizeInBits();
11099     if (N1C->getAPIntValue().uge(BitSize))
11100       return DAG.getUNDEF(VT);
11101 
11102     if (!LegalTypes || TLI.isTypeDesirableForOp(ISD::SRL, SmallVT)) {
11103       uint64_t ShiftAmt = N1C->getZExtValue();
11104       SDLoc DL0(N0);
11105       SDValue SmallShift =
11106           DAG.getNode(ISD::SRL, DL0, SmallVT, N0.getOperand(0),
11107                       DAG.getShiftAmountConstant(ShiftAmt, SmallVT, DL0));
11108       AddToWorklist(SmallShift.getNode());
11109       APInt Mask = APInt::getLowBitsSet(OpSizeInBits, OpSizeInBits - ShiftAmt);
11110       return DAG.getNode(ISD::AND, DL, VT,
11111                          DAG.getNode(ISD::ANY_EXTEND, DL, VT, SmallShift),
11112                          DAG.getConstant(Mask, DL, VT));
11113     }
11114   }
11115 
11116   // fold (srl (sra X, Y), 31) -> (srl X, 31).  This srl only looks at the sign
11117   // bit, which is unmodified by sra.
11118   if (N1C && N1C->getAPIntValue() == (OpSizeInBits - 1)) {
11119     if (N0.getOpcode() == ISD::SRA)
11120       return DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), N1);
11121   }
11122 
11123   // fold (srl (ctlz x), "5") -> x  iff x has one bit set (the low bit), and x has a power
11124   // of two bitwidth. The "5" represents (log2 (bitwidth x)).
11125   if (N1C && N0.getOpcode() == ISD::CTLZ &&
11126       isPowerOf2_32(OpSizeInBits) &&
11127       N1C->getAPIntValue() == Log2_32(OpSizeInBits)) {
11128     KnownBits Known = DAG.computeKnownBits(N0.getOperand(0));
11129 
11130     // If any of the input bits are KnownOne, then the input couldn't be all
11131     // zeros, thus the result of the srl will always be zero.
11132     if (Known.One.getBoolValue()) return DAG.getConstant(0, SDLoc(N0), VT);
11133 
11134     // If all of the bits input the to ctlz node are known to be zero, then
11135     // the result of the ctlz is "32" and the result of the shift is one.
11136     APInt UnknownBits = ~Known.Zero;
11137     if (UnknownBits == 0) return DAG.getConstant(1, SDLoc(N0), VT);
11138 
11139     // Otherwise, check to see if there is exactly one bit input to the ctlz.
11140     if (UnknownBits.isPowerOf2()) {
11141       // Okay, we know that only that the single bit specified by UnknownBits
11142       // could be set on input to the CTLZ node. If this bit is set, the SRL
11143       // will return 0, if it is clear, it returns 1. Change the CTLZ/SRL pair
11144       // to an SRL/XOR pair, which is likely to simplify more.
11145       unsigned ShAmt = UnknownBits.countr_zero();
11146       SDValue Op = N0.getOperand(0);
11147 
11148       if (ShAmt) {
11149         SDLoc DL(N0);
11150         Op = DAG.getNode(ISD::SRL, DL, VT, Op,
11151                          DAG.getShiftAmountConstant(ShAmt, VT, DL));
11152         AddToWorklist(Op.getNode());
11153       }
11154       return DAG.getNode(ISD::XOR, DL, VT, Op, DAG.getConstant(1, DL, VT));
11155     }
11156   }
11157 
11158   // fold (srl x, (trunc (and y, c))) -> (srl x, (and (trunc y), (trunc c))).
11159   if (N1.getOpcode() == ISD::TRUNCATE &&
11160       N1.getOperand(0).getOpcode() == ISD::AND) {
11161     if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
11162       return DAG.getNode(ISD::SRL, DL, VT, N0, NewOp1);
11163   }
11164 
11165   // fold (srl (logic_op x, (shl (zext y), c1)), c1)
11166   //   -> (logic_op (srl x, c1), (zext y))
11167   // c1 <= leadingzeros(zext(y))
11168   SDValue X, ZExtY;
11169   if (N1C && sd_match(N0, m_OneUse(m_BitwiseLogic(
11170                               m_Value(X),
11171                               m_OneUse(m_Shl(m_AllOf(m_Value(ZExtY),
11172                                                      m_Opc(ISD::ZERO_EXTEND)),
11173                                              m_Specific(N1))))))) {
11174     unsigned NumLeadingZeros = ZExtY.getScalarValueSizeInBits() -
11175                                ZExtY.getOperand(0).getScalarValueSizeInBits();
11176     if (N1C->getZExtValue() <= NumLeadingZeros)
11177       return DAG.getNode(N0.getOpcode(), SDLoc(N0), VT,
11178                          DAG.getNode(ISD::SRL, SDLoc(N0), VT, X, N1), ZExtY);
11179   }
11180 
11181   // fold operands of srl based on knowledge that the low bits are not
11182   // demanded.
11183   if (SimplifyDemandedBits(SDValue(N, 0)))
11184     return SDValue(N, 0);
11185 
11186   if (N1C && !N1C->isOpaque())
11187     if (SDValue NewSRL = visitShiftByConstant(N))
11188       return NewSRL;
11189 
11190   // Attempt to convert a srl of a load into a narrower zero-extending load.
11191   if (SDValue NarrowLoad = reduceLoadWidth(N))
11192     return NarrowLoad;
11193 
11194   // Here is a common situation. We want to optimize:
11195   //
11196   //   %a = ...
11197   //   %b = and i32 %a, 2
11198   //   %c = srl i32 %b, 1
11199   //   brcond i32 %c ...
11200   //
11201   // into
11202   //
11203   //   %a = ...
11204   //   %b = and %a, 2
11205   //   %c = setcc eq %b, 0
11206   //   brcond %c ...
11207   //
11208   // However when after the source operand of SRL is optimized into AND, the SRL
11209   // itself may not be optimized further. Look for it and add the BRCOND into
11210   // the worklist.
11211   //
11212   // The also tends to happen for binary operations when SimplifyDemandedBits
11213   // is involved.
11214   //
11215   // FIXME: This is unecessary if we process the DAG in topological order,
11216   // which we plan to do. This workaround can be removed once the DAG is
11217   // processed in topological order.
11218   if (N->hasOneUse()) {
11219     SDNode *User = *N->user_begin();
11220 
11221     // Look pass the truncate.
11222     if (User->getOpcode() == ISD::TRUNCATE && User->hasOneUse())
11223       User = *User->user_begin();
11224 
11225     if (User->getOpcode() == ISD::BRCOND || User->getOpcode() == ISD::AND ||
11226         User->getOpcode() == ISD::OR || User->getOpcode() == ISD::XOR)
11227       AddToWorklist(User);
11228   }
11229 
11230   // Try to transform this shift into a multiply-high if
11231   // it matches the appropriate pattern detected in combineShiftToMULH.
11232   if (SDValue MULH = combineShiftToMULH(N, DL, DAG, TLI))
11233     return MULH;
11234 
11235   if (SDValue AVG = foldShiftToAvg(N))
11236     return AVG;
11237 
11238   return SDValue();
11239 }
11240 
visitFunnelShift(SDNode * N)11241 SDValue DAGCombiner::visitFunnelShift(SDNode *N) {
11242   EVT VT = N->getValueType(0);
11243   SDValue N0 = N->getOperand(0);
11244   SDValue N1 = N->getOperand(1);
11245   SDValue N2 = N->getOperand(2);
11246   bool IsFSHL = N->getOpcode() == ISD::FSHL;
11247   unsigned BitWidth = VT.getScalarSizeInBits();
11248   SDLoc DL(N);
11249 
11250   // fold (fshl N0, N1, 0) -> N0
11251   // fold (fshr N0, N1, 0) -> N1
11252   if (isPowerOf2_32(BitWidth))
11253     if (DAG.MaskedValueIsZero(
11254             N2, APInt(N2.getScalarValueSizeInBits(), BitWidth - 1)))
11255       return IsFSHL ? N0 : N1;
11256 
11257   auto IsUndefOrZero = [](SDValue V) {
11258     return V.isUndef() || isNullOrNullSplat(V, /*AllowUndefs*/ true);
11259   };
11260 
11261   // TODO - support non-uniform vector shift amounts.
11262   if (ConstantSDNode *Cst = isConstOrConstSplat(N2)) {
11263     EVT ShAmtTy = N2.getValueType();
11264 
11265     // fold (fsh* N0, N1, c) -> (fsh* N0, N1, c % BitWidth)
11266     if (Cst->getAPIntValue().uge(BitWidth)) {
11267       uint64_t RotAmt = Cst->getAPIntValue().urem(BitWidth);
11268       return DAG.getNode(N->getOpcode(), DL, VT, N0, N1,
11269                          DAG.getConstant(RotAmt, DL, ShAmtTy));
11270     }
11271 
11272     unsigned ShAmt = Cst->getZExtValue();
11273     if (ShAmt == 0)
11274       return IsFSHL ? N0 : N1;
11275 
11276     // fold fshl(undef_or_zero, N1, C) -> lshr(N1, BW-C)
11277     // fold fshr(undef_or_zero, N1, C) -> lshr(N1, C)
11278     // fold fshl(N0, undef_or_zero, C) -> shl(N0, C)
11279     // fold fshr(N0, undef_or_zero, C) -> shl(N0, BW-C)
11280     if (IsUndefOrZero(N0))
11281       return DAG.getNode(
11282           ISD::SRL, DL, VT, N1,
11283           DAG.getConstant(IsFSHL ? BitWidth - ShAmt : ShAmt, DL, ShAmtTy));
11284     if (IsUndefOrZero(N1))
11285       return DAG.getNode(
11286           ISD::SHL, DL, VT, N0,
11287           DAG.getConstant(IsFSHL ? ShAmt : BitWidth - ShAmt, DL, ShAmtTy));
11288 
11289     // fold (fshl ld1, ld0, c) -> (ld0[ofs]) iff ld0 and ld1 are consecutive.
11290     // fold (fshr ld1, ld0, c) -> (ld0[ofs]) iff ld0 and ld1 are consecutive.
11291     // TODO - bigendian support once we have test coverage.
11292     // TODO - can we merge this with CombineConseutiveLoads/MatchLoadCombine?
11293     // TODO - permit LHS EXTLOAD if extensions are shifted out.
11294     if ((BitWidth % 8) == 0 && (ShAmt % 8) == 0 && !VT.isVector() &&
11295         !DAG.getDataLayout().isBigEndian()) {
11296       auto *LHS = dyn_cast<LoadSDNode>(N0);
11297       auto *RHS = dyn_cast<LoadSDNode>(N1);
11298       if (LHS && RHS && LHS->isSimple() && RHS->isSimple() &&
11299           LHS->getAddressSpace() == RHS->getAddressSpace() &&
11300           (LHS->hasNUsesOfValue(1, 0) || RHS->hasNUsesOfValue(1, 0)) &&
11301           ISD::isNON_EXTLoad(RHS) && ISD::isNON_EXTLoad(LHS)) {
11302         if (DAG.areNonVolatileConsecutiveLoads(LHS, RHS, BitWidth / 8, 1)) {
11303           SDLoc DL(RHS);
11304           uint64_t PtrOff =
11305               IsFSHL ? (((BitWidth - ShAmt) % BitWidth) / 8) : (ShAmt / 8);
11306           Align NewAlign = commonAlignment(RHS->getAlign(), PtrOff);
11307           unsigned Fast = 0;
11308           if (TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT,
11309                                      RHS->getAddressSpace(), NewAlign,
11310                                      RHS->getMemOperand()->getFlags(), &Fast) &&
11311               Fast) {
11312             SDValue NewPtr = DAG.getMemBasePlusOffset(
11313                 RHS->getBasePtr(), TypeSize::getFixed(PtrOff), DL);
11314             AddToWorklist(NewPtr.getNode());
11315             SDValue Load = DAG.getLoad(
11316                 VT, DL, RHS->getChain(), NewPtr,
11317                 RHS->getPointerInfo().getWithOffset(PtrOff), NewAlign,
11318                 RHS->getMemOperand()->getFlags(), RHS->getAAInfo());
11319             DAG.makeEquivalentMemoryOrdering(LHS, Load.getValue(1));
11320             DAG.makeEquivalentMemoryOrdering(RHS, Load.getValue(1));
11321             return Load;
11322           }
11323         }
11324       }
11325     }
11326   }
11327 
11328   // fold fshr(undef_or_zero, N1, N2) -> lshr(N1, N2)
11329   // fold fshl(N0, undef_or_zero, N2) -> shl(N0, N2)
11330   // iff We know the shift amount is in range.
11331   // TODO: when is it worth doing SUB(BW, N2) as well?
11332   if (isPowerOf2_32(BitWidth)) {
11333     APInt ModuloBits(N2.getScalarValueSizeInBits(), BitWidth - 1);
11334     if (IsUndefOrZero(N0) && !IsFSHL && DAG.MaskedValueIsZero(N2, ~ModuloBits))
11335       return DAG.getNode(ISD::SRL, DL, VT, N1, N2);
11336     if (IsUndefOrZero(N1) && IsFSHL && DAG.MaskedValueIsZero(N2, ~ModuloBits))
11337       return DAG.getNode(ISD::SHL, DL, VT, N0, N2);
11338   }
11339 
11340   // fold (fshl N0, N0, N2) -> (rotl N0, N2)
11341   // fold (fshr N0, N0, N2) -> (rotr N0, N2)
11342   // TODO: Investigate flipping this rotate if only one is legal.
11343   // If funnel shift is legal as well we might be better off avoiding
11344   // non-constant (BW - N2).
11345   unsigned RotOpc = IsFSHL ? ISD::ROTL : ISD::ROTR;
11346   if (N0 == N1 && hasOperation(RotOpc, VT))
11347     return DAG.getNode(RotOpc, DL, VT, N0, N2);
11348 
11349   // Simplify, based on bits shifted out of N0/N1.
11350   if (SimplifyDemandedBits(SDValue(N, 0)))
11351     return SDValue(N, 0);
11352 
11353   return SDValue();
11354 }
11355 
visitSHLSAT(SDNode * N)11356 SDValue DAGCombiner::visitSHLSAT(SDNode *N) {
11357   SDValue N0 = N->getOperand(0);
11358   SDValue N1 = N->getOperand(1);
11359   if (SDValue V = DAG.simplifyShift(N0, N1))
11360     return V;
11361 
11362   SDLoc DL(N);
11363   EVT VT = N0.getValueType();
11364 
11365   // fold (*shlsat c1, c2) -> c1<<c2
11366   if (SDValue C = DAG.FoldConstantArithmetic(N->getOpcode(), DL, VT, {N0, N1}))
11367     return C;
11368 
11369   ConstantSDNode *N1C = isConstOrConstSplat(N1);
11370 
11371   if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::SHL, VT)) {
11372     // fold (sshlsat x, c) -> (shl x, c)
11373     if (N->getOpcode() == ISD::SSHLSAT && N1C &&
11374         N1C->getAPIntValue().ult(DAG.ComputeNumSignBits(N0)))
11375       return DAG.getNode(ISD::SHL, DL, VT, N0, N1);
11376 
11377     // fold (ushlsat x, c) -> (shl x, c)
11378     if (N->getOpcode() == ISD::USHLSAT && N1C &&
11379         N1C->getAPIntValue().ule(
11380             DAG.computeKnownBits(N0).countMinLeadingZeros()))
11381       return DAG.getNode(ISD::SHL, DL, VT, N0, N1);
11382   }
11383 
11384   return SDValue();
11385 }
11386 
11387 // Given a ABS node, detect the following patterns:
11388 // (ABS (SUB (EXTEND a), (EXTEND b))).
11389 // (TRUNC (ABS (SUB (EXTEND a), (EXTEND b)))).
11390 // Generates UABD/SABD instruction.
foldABSToABD(SDNode * N,const SDLoc & DL)11391 SDValue DAGCombiner::foldABSToABD(SDNode *N, const SDLoc &DL) {
11392   EVT SrcVT = N->getValueType(0);
11393 
11394   if (N->getOpcode() == ISD::TRUNCATE)
11395     N = N->getOperand(0).getNode();
11396 
11397   EVT VT = N->getValueType(0);
11398   SDValue Op0, Op1;
11399 
11400   if (!sd_match(N, m_Abs(m_Sub(m_Value(Op0), m_Value(Op1)))))
11401     return SDValue();
11402 
11403   SDValue AbsOp0 = N->getOperand(0);
11404   unsigned Opc0 = Op0.getOpcode();
11405 
11406   // Check if the operands of the sub are (zero|sign)-extended, otherwise
11407   // fallback to ValueTracking.
11408   if (Opc0 != Op1.getOpcode() ||
11409       (Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND &&
11410        Opc0 != ISD::SIGN_EXTEND_INREG)) {
11411     // fold (abs (sub nsw x, y)) -> abds(x, y)
11412     // Don't fold this for unsupported types as we lose the NSW handling.
11413     if (hasOperation(ISD::ABDS, VT) && TLI.preferABDSToABSWithNSW(VT) &&
11414         (AbsOp0->getFlags().hasNoSignedWrap() ||
11415          DAG.willNotOverflowSub(/*IsSigned=*/true, Op0, Op1))) {
11416       SDValue ABD = DAG.getNode(ISD::ABDS, DL, VT, Op0, Op1);
11417       return DAG.getZExtOrTrunc(ABD, DL, SrcVT);
11418     }
11419     // fold (abs (sub x, y)) -> abdu(x, y)
11420     if (hasOperation(ISD::ABDU, VT) && DAG.SignBitIsZero(Op0) &&
11421         DAG.SignBitIsZero(Op1)) {
11422       SDValue ABD = DAG.getNode(ISD::ABDU, DL, VT, Op0, Op1);
11423       return DAG.getZExtOrTrunc(ABD, DL, SrcVT);
11424     }
11425     return SDValue();
11426   }
11427 
11428   EVT VT0, VT1;
11429   if (Opc0 == ISD::SIGN_EXTEND_INREG) {
11430     VT0 = cast<VTSDNode>(Op0.getOperand(1))->getVT();
11431     VT1 = cast<VTSDNode>(Op1.getOperand(1))->getVT();
11432   } else {
11433     VT0 = Op0.getOperand(0).getValueType();
11434     VT1 = Op1.getOperand(0).getValueType();
11435   }
11436   unsigned ABDOpcode = (Opc0 == ISD::ZERO_EXTEND) ? ISD::ABDU : ISD::ABDS;
11437 
11438   // fold abs(sext(x) - sext(y)) -> zext(abds(x, y))
11439   // fold abs(zext(x) - zext(y)) -> zext(abdu(x, y))
11440   EVT MaxVT = VT0.bitsGT(VT1) ? VT0 : VT1;
11441   if ((VT0 == MaxVT || Op0->hasOneUse()) &&
11442       (VT1 == MaxVT || Op1->hasOneUse()) &&
11443       (!LegalTypes || hasOperation(ABDOpcode, MaxVT))) {
11444     SDValue ABD = DAG.getNode(ABDOpcode, DL, MaxVT,
11445                               DAG.getNode(ISD::TRUNCATE, DL, MaxVT, Op0),
11446                               DAG.getNode(ISD::TRUNCATE, DL, MaxVT, Op1));
11447     ABD = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, ABD);
11448     return DAG.getZExtOrTrunc(ABD, DL, SrcVT);
11449   }
11450 
11451   // fold abs(sext(x) - sext(y)) -> abds(sext(x), sext(y))
11452   // fold abs(zext(x) - zext(y)) -> abdu(zext(x), zext(y))
11453   if (!LegalOperations || hasOperation(ABDOpcode, VT)) {
11454     SDValue ABD = DAG.getNode(ABDOpcode, DL, VT, Op0, Op1);
11455     return DAG.getZExtOrTrunc(ABD, DL, SrcVT);
11456   }
11457 
11458   return SDValue();
11459 }
11460 
visitABS(SDNode * N)11461 SDValue DAGCombiner::visitABS(SDNode *N) {
11462   SDValue N0 = N->getOperand(0);
11463   EVT VT = N->getValueType(0);
11464   SDLoc DL(N);
11465 
11466   // fold (abs c1) -> c2
11467   if (SDValue C = DAG.FoldConstantArithmetic(ISD::ABS, DL, VT, {N0}))
11468     return C;
11469   // fold (abs (abs x)) -> (abs x)
11470   if (N0.getOpcode() == ISD::ABS)
11471     return N0;
11472   // fold (abs x) -> x iff not-negative
11473   if (DAG.SignBitIsZero(N0))
11474     return N0;
11475 
11476   if (SDValue ABD = foldABSToABD(N, DL))
11477     return ABD;
11478 
11479   // fold (abs (sign_extend_inreg x)) -> (zero_extend (abs (truncate x)))
11480   // iff zero_extend/truncate are free.
11481   if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG) {
11482     EVT ExtVT = cast<VTSDNode>(N0.getOperand(1))->getVT();
11483     if (TLI.isTruncateFree(VT, ExtVT) && TLI.isZExtFree(ExtVT, VT) &&
11484         TLI.isTypeDesirableForOp(ISD::ABS, ExtVT) &&
11485         hasOperation(ISD::ABS, ExtVT)) {
11486       return DAG.getNode(
11487           ISD::ZERO_EXTEND, DL, VT,
11488           DAG.getNode(ISD::ABS, DL, ExtVT,
11489                       DAG.getNode(ISD::TRUNCATE, DL, ExtVT, N0.getOperand(0))));
11490     }
11491   }
11492 
11493   return SDValue();
11494 }
11495 
visitBSWAP(SDNode * N)11496 SDValue DAGCombiner::visitBSWAP(SDNode *N) {
11497   SDValue N0 = N->getOperand(0);
11498   EVT VT = N->getValueType(0);
11499   SDLoc DL(N);
11500 
11501   // fold (bswap c1) -> c2
11502   if (SDValue C = DAG.FoldConstantArithmetic(ISD::BSWAP, DL, VT, {N0}))
11503     return C;
11504   // fold (bswap (bswap x)) -> x
11505   if (N0.getOpcode() == ISD::BSWAP)
11506     return N0.getOperand(0);
11507 
11508   // Canonicalize bswap(bitreverse(x)) -> bitreverse(bswap(x)). If bitreverse
11509   // isn't supported, it will be expanded to bswap followed by a manual reversal
11510   // of bits in each byte. By placing bswaps before bitreverse, we can remove
11511   // the two bswaps if the bitreverse gets expanded.
11512   if (N0.getOpcode() == ISD::BITREVERSE && N0.hasOneUse()) {
11513     SDValue BSwap = DAG.getNode(ISD::BSWAP, DL, VT, N0.getOperand(0));
11514     return DAG.getNode(ISD::BITREVERSE, DL, VT, BSwap);
11515   }
11516 
11517   // fold (bswap shl(x,c)) -> (zext(bswap(trunc(shl(x,sub(c,bw/2))))))
11518   // iff x >= bw/2 (i.e. lower half is known zero)
11519   unsigned BW = VT.getScalarSizeInBits();
11520   if (BW >= 32 && N0.getOpcode() == ISD::SHL && N0.hasOneUse()) {
11521     auto *ShAmt = dyn_cast<ConstantSDNode>(N0.getOperand(1));
11522     EVT HalfVT = EVT::getIntegerVT(*DAG.getContext(), BW / 2);
11523     if (ShAmt && ShAmt->getAPIntValue().ult(BW) &&
11524         ShAmt->getZExtValue() >= (BW / 2) &&
11525         (ShAmt->getZExtValue() % 16) == 0 && TLI.isTypeLegal(HalfVT) &&
11526         TLI.isTruncateFree(VT, HalfVT) &&
11527         (!LegalOperations || hasOperation(ISD::BSWAP, HalfVT))) {
11528       SDValue Res = N0.getOperand(0);
11529       if (uint64_t NewShAmt = (ShAmt->getZExtValue() - (BW / 2)))
11530         Res = DAG.getNode(ISD::SHL, DL, VT, Res,
11531                           DAG.getShiftAmountConstant(NewShAmt, VT, DL));
11532       Res = DAG.getZExtOrTrunc(Res, DL, HalfVT);
11533       Res = DAG.getNode(ISD::BSWAP, DL, HalfVT, Res);
11534       return DAG.getZExtOrTrunc(Res, DL, VT);
11535     }
11536   }
11537 
11538   // Try to canonicalize bswap-of-logical-shift-by-8-bit-multiple as
11539   // inverse-shift-of-bswap:
11540   // bswap (X u<< C) --> (bswap X) u>> C
11541   // bswap (X u>> C) --> (bswap X) u<< C
11542   if ((N0.getOpcode() == ISD::SHL || N0.getOpcode() == ISD::SRL) &&
11543       N0.hasOneUse()) {
11544     auto *ShAmt = dyn_cast<ConstantSDNode>(N0.getOperand(1));
11545     if (ShAmt && ShAmt->getAPIntValue().ult(BW) &&
11546         ShAmt->getZExtValue() % 8 == 0) {
11547       SDValue NewSwap = DAG.getNode(ISD::BSWAP, DL, VT, N0.getOperand(0));
11548       unsigned InverseShift = N0.getOpcode() == ISD::SHL ? ISD::SRL : ISD::SHL;
11549       return DAG.getNode(InverseShift, DL, VT, NewSwap, N0.getOperand(1));
11550     }
11551   }
11552 
11553   if (SDValue V = foldBitOrderCrossLogicOp(N, DAG))
11554     return V;
11555 
11556   return SDValue();
11557 }
11558 
visitBITREVERSE(SDNode * N)11559 SDValue DAGCombiner::visitBITREVERSE(SDNode *N) {
11560   SDValue N0 = N->getOperand(0);
11561   EVT VT = N->getValueType(0);
11562   SDLoc DL(N);
11563 
11564   // fold (bitreverse c1) -> c2
11565   if (SDValue C = DAG.FoldConstantArithmetic(ISD::BITREVERSE, DL, VT, {N0}))
11566     return C;
11567 
11568   // fold (bitreverse (bitreverse x)) -> x
11569   if (N0.getOpcode() == ISD::BITREVERSE)
11570     return N0.getOperand(0);
11571 
11572   SDValue X, Y;
11573 
11574   // fold (bitreverse (lshr (bitreverse x), y)) -> (shl x, y)
11575   if ((!LegalOperations || TLI.isOperationLegal(ISD::SHL, VT)) &&
11576       sd_match(N, m_BitReverse(m_Srl(m_BitReverse(m_Value(X)), m_Value(Y)))))
11577     return DAG.getNode(ISD::SHL, DL, VT, X, Y);
11578 
11579   // fold (bitreverse (shl (bitreverse x), y)) -> (lshr x, y)
11580   if ((!LegalOperations || TLI.isOperationLegal(ISD::SRL, VT)) &&
11581       sd_match(N, m_BitReverse(m_Shl(m_BitReverse(m_Value(X)), m_Value(Y)))))
11582     return DAG.getNode(ISD::SRL, DL, VT, X, Y);
11583 
11584   return SDValue();
11585 }
11586 
visitCTLZ(SDNode * N)11587 SDValue DAGCombiner::visitCTLZ(SDNode *N) {
11588   SDValue N0 = N->getOperand(0);
11589   EVT VT = N->getValueType(0);
11590   SDLoc DL(N);
11591 
11592   // fold (ctlz c1) -> c2
11593   if (SDValue C = DAG.FoldConstantArithmetic(ISD::CTLZ, DL, VT, {N0}))
11594     return C;
11595 
11596   // If the value is known never to be zero, switch to the undef version.
11597   if (!LegalOperations || TLI.isOperationLegal(ISD::CTLZ_ZERO_UNDEF, VT))
11598     if (DAG.isKnownNeverZero(N0))
11599       return DAG.getNode(ISD::CTLZ_ZERO_UNDEF, DL, VT, N0);
11600 
11601   return SDValue();
11602 }
11603 
visitCTLZ_ZERO_UNDEF(SDNode * N)11604 SDValue DAGCombiner::visitCTLZ_ZERO_UNDEF(SDNode *N) {
11605   SDValue N0 = N->getOperand(0);
11606   EVT VT = N->getValueType(0);
11607   SDLoc DL(N);
11608 
11609   // fold (ctlz_zero_undef c1) -> c2
11610   if (SDValue C =
11611           DAG.FoldConstantArithmetic(ISD::CTLZ_ZERO_UNDEF, DL, VT, {N0}))
11612     return C;
11613   return SDValue();
11614 }
11615 
visitCTTZ(SDNode * N)11616 SDValue DAGCombiner::visitCTTZ(SDNode *N) {
11617   SDValue N0 = N->getOperand(0);
11618   EVT VT = N->getValueType(0);
11619   SDLoc DL(N);
11620 
11621   // fold (cttz c1) -> c2
11622   if (SDValue C = DAG.FoldConstantArithmetic(ISD::CTTZ, DL, VT, {N0}))
11623     return C;
11624 
11625   // If the value is known never to be zero, switch to the undef version.
11626   if (!LegalOperations || TLI.isOperationLegal(ISD::CTTZ_ZERO_UNDEF, VT))
11627     if (DAG.isKnownNeverZero(N0))
11628       return DAG.getNode(ISD::CTTZ_ZERO_UNDEF, DL, VT, N0);
11629 
11630   return SDValue();
11631 }
11632 
visitCTTZ_ZERO_UNDEF(SDNode * N)11633 SDValue DAGCombiner::visitCTTZ_ZERO_UNDEF(SDNode *N) {
11634   SDValue N0 = N->getOperand(0);
11635   EVT VT = N->getValueType(0);
11636   SDLoc DL(N);
11637 
11638   // fold (cttz_zero_undef c1) -> c2
11639   if (SDValue C =
11640           DAG.FoldConstantArithmetic(ISD::CTTZ_ZERO_UNDEF, DL, VT, {N0}))
11641     return C;
11642   return SDValue();
11643 }
11644 
visitCTPOP(SDNode * N)11645 SDValue DAGCombiner::visitCTPOP(SDNode *N) {
11646   SDValue N0 = N->getOperand(0);
11647   EVT VT = N->getValueType(0);
11648   unsigned NumBits = VT.getScalarSizeInBits();
11649   SDLoc DL(N);
11650 
11651   // fold (ctpop c1) -> c2
11652   if (SDValue C = DAG.FoldConstantArithmetic(ISD::CTPOP, DL, VT, {N0}))
11653     return C;
11654 
11655   // If the source is being shifted, but doesn't affect any active bits,
11656   // then we can call CTPOP on the shift source directly.
11657   if (N0.getOpcode() == ISD::SRL || N0.getOpcode() == ISD::SHL) {
11658     if (ConstantSDNode *AmtC = isConstOrConstSplat(N0.getOperand(1))) {
11659       const APInt &Amt = AmtC->getAPIntValue();
11660       if (Amt.ult(NumBits)) {
11661         KnownBits KnownSrc = DAG.computeKnownBits(N0.getOperand(0));
11662         if ((N0.getOpcode() == ISD::SRL &&
11663              Amt.ule(KnownSrc.countMinTrailingZeros())) ||
11664             (N0.getOpcode() == ISD::SHL &&
11665              Amt.ule(KnownSrc.countMinLeadingZeros()))) {
11666           return DAG.getNode(ISD::CTPOP, DL, VT, N0.getOperand(0));
11667         }
11668       }
11669     }
11670   }
11671 
11672   // If the upper bits are known to be zero, then see if its profitable to
11673   // only count the lower bits.
11674   if (VT.isScalarInteger() && NumBits > 8 && (NumBits & 1) == 0) {
11675     EVT HalfVT = EVT::getIntegerVT(*DAG.getContext(), NumBits / 2);
11676     if (hasOperation(ISD::CTPOP, HalfVT) &&
11677         TLI.isTypeDesirableForOp(ISD::CTPOP, HalfVT) &&
11678         TLI.isTruncateFree(N0, HalfVT) && TLI.isZExtFree(HalfVT, VT)) {
11679       APInt UpperBits = APInt::getHighBitsSet(NumBits, NumBits / 2);
11680       if (DAG.MaskedValueIsZero(N0, UpperBits)) {
11681         SDValue PopCnt = DAG.getNode(ISD::CTPOP, DL, HalfVT,
11682                                      DAG.getZExtOrTrunc(N0, DL, HalfVT));
11683         return DAG.getZExtOrTrunc(PopCnt, DL, VT);
11684       }
11685     }
11686   }
11687 
11688   return SDValue();
11689 }
11690 
isLegalToCombineMinNumMaxNum(SelectionDAG & DAG,SDValue LHS,SDValue RHS,const SDNodeFlags Flags,const TargetLowering & TLI)11691 static bool isLegalToCombineMinNumMaxNum(SelectionDAG &DAG, SDValue LHS,
11692                                          SDValue RHS, const SDNodeFlags Flags,
11693                                          const TargetLowering &TLI) {
11694   EVT VT = LHS.getValueType();
11695   if (!VT.isFloatingPoint())
11696     return false;
11697 
11698   const TargetOptions &Options = DAG.getTarget().Options;
11699 
11700   return (Flags.hasNoSignedZeros() || Options.NoSignedZerosFPMath) &&
11701          TLI.isProfitableToCombineMinNumMaxNum(VT) &&
11702          (Flags.hasNoNaNs() ||
11703           (DAG.isKnownNeverNaN(RHS) && DAG.isKnownNeverNaN(LHS)));
11704 }
11705 
combineMinNumMaxNumImpl(const SDLoc & DL,EVT VT,SDValue LHS,SDValue RHS,SDValue True,SDValue False,ISD::CondCode CC,const TargetLowering & TLI,SelectionDAG & DAG)11706 static SDValue combineMinNumMaxNumImpl(const SDLoc &DL, EVT VT, SDValue LHS,
11707                                        SDValue RHS, SDValue True, SDValue False,
11708                                        ISD::CondCode CC,
11709                                        const TargetLowering &TLI,
11710                                        SelectionDAG &DAG) {
11711   EVT TransformVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
11712   switch (CC) {
11713   case ISD::SETOLT:
11714   case ISD::SETOLE:
11715   case ISD::SETLT:
11716   case ISD::SETLE:
11717   case ISD::SETULT:
11718   case ISD::SETULE: {
11719     // Since it's known never nan to get here already, either fminnum or
11720     // fminnum_ieee are OK. Try the ieee version first, since it's fminnum is
11721     // expanded in terms of it.
11722     unsigned IEEEOpcode = (LHS == True) ? ISD::FMINNUM_IEEE : ISD::FMAXNUM_IEEE;
11723     if (TLI.isOperationLegalOrCustom(IEEEOpcode, VT))
11724       return DAG.getNode(IEEEOpcode, DL, VT, LHS, RHS);
11725 
11726     unsigned Opcode = (LHS == True) ? ISD::FMINNUM : ISD::FMAXNUM;
11727     if (TLI.isOperationLegalOrCustom(Opcode, TransformVT))
11728       return DAG.getNode(Opcode, DL, VT, LHS, RHS);
11729     return SDValue();
11730   }
11731   case ISD::SETOGT:
11732   case ISD::SETOGE:
11733   case ISD::SETGT:
11734   case ISD::SETGE:
11735   case ISD::SETUGT:
11736   case ISD::SETUGE: {
11737     unsigned IEEEOpcode = (LHS == True) ? ISD::FMAXNUM_IEEE : ISD::FMINNUM_IEEE;
11738     if (TLI.isOperationLegalOrCustom(IEEEOpcode, VT))
11739       return DAG.getNode(IEEEOpcode, DL, VT, LHS, RHS);
11740 
11741     unsigned Opcode = (LHS == True) ? ISD::FMAXNUM : ISD::FMINNUM;
11742     if (TLI.isOperationLegalOrCustom(Opcode, TransformVT))
11743       return DAG.getNode(Opcode, DL, VT, LHS, RHS);
11744     return SDValue();
11745   }
11746   default:
11747     return SDValue();
11748   }
11749 }
11750 
foldShiftToAvg(SDNode * N)11751 SDValue DAGCombiner::foldShiftToAvg(SDNode *N) {
11752   const unsigned Opcode = N->getOpcode();
11753 
11754   // Convert (sr[al] (add n[su]w x, y)) -> (avgfloor[su] x, y)
11755   if (Opcode != ISD::SRA && Opcode != ISD::SRL)
11756     return SDValue();
11757 
11758   unsigned FloorISD = 0;
11759   auto VT = N->getValueType(0);
11760   bool IsUnsigned = false;
11761 
11762   // Decide wether signed or unsigned.
11763   switch (Opcode) {
11764   case ISD::SRA:
11765     if (!hasOperation(ISD::AVGFLOORS, VT))
11766       return SDValue();
11767     FloorISD = ISD::AVGFLOORS;
11768     break;
11769   case ISD::SRL:
11770     IsUnsigned = true;
11771     if (!hasOperation(ISD::AVGFLOORU, VT))
11772       return SDValue();
11773     FloorISD = ISD::AVGFLOORU;
11774     break;
11775   default:
11776     return SDValue();
11777   }
11778 
11779   // Captured values.
11780   SDValue A, B, Add;
11781 
11782   // Match floor average as it is common to both floor/ceil avgs.
11783   if (!sd_match(N, m_BinOp(Opcode,
11784                            m_AllOf(m_Value(Add), m_Add(m_Value(A), m_Value(B))),
11785                            m_One())))
11786     return SDValue();
11787 
11788   // Can't optimize adds that may wrap.
11789   if (IsUnsigned && !Add->getFlags().hasNoUnsignedWrap())
11790     return SDValue();
11791 
11792   if (!IsUnsigned && !Add->getFlags().hasNoSignedWrap())
11793     return SDValue();
11794 
11795   return DAG.getNode(FloorISD, SDLoc(N), N->getValueType(0), {A, B});
11796 }
11797 
foldBitwiseOpWithNeg(SDNode * N,const SDLoc & DL,EVT VT)11798 SDValue DAGCombiner::foldBitwiseOpWithNeg(SDNode *N, const SDLoc &DL, EVT VT) {
11799   unsigned Opc = N->getOpcode();
11800   SDValue X, Y, Z;
11801   if (sd_match(
11802           N, m_BitwiseLogic(m_Value(X), m_Add(m_Not(m_Value(Y)), m_Value(Z)))))
11803     return DAG.getNode(Opc, DL, VT, X,
11804                        DAG.getNOT(DL, DAG.getNode(ISD::SUB, DL, VT, Y, Z), VT));
11805 
11806   if (sd_match(N, m_BitwiseLogic(m_Value(X), m_Sub(m_OneUse(m_Not(m_Value(Y))),
11807                                                    m_Value(Z)))))
11808     return DAG.getNode(Opc, DL, VT, X,
11809                        DAG.getNOT(DL, DAG.getNode(ISD::ADD, DL, VT, Y, Z), VT));
11810 
11811   return SDValue();
11812 }
11813 
11814 /// Generate Min/Max node
combineMinNumMaxNum(const SDLoc & DL,EVT VT,SDValue LHS,SDValue RHS,SDValue True,SDValue False,ISD::CondCode CC)11815 SDValue DAGCombiner::combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
11816                                          SDValue RHS, SDValue True,
11817                                          SDValue False, ISD::CondCode CC) {
11818   if ((LHS == True && RHS == False) || (LHS == False && RHS == True))
11819     return combineMinNumMaxNumImpl(DL, VT, LHS, RHS, True, False, CC, TLI, DAG);
11820 
11821   // If we can't directly match this, try to see if we can pull an fneg out of
11822   // the select.
11823   SDValue NegTrue = TLI.getCheaperOrNeutralNegatedExpression(
11824       True, DAG, LegalOperations, ForCodeSize);
11825   if (!NegTrue)
11826     return SDValue();
11827 
11828   HandleSDNode NegTrueHandle(NegTrue);
11829 
11830   // Try to unfold an fneg from the select if we are comparing the negated
11831   // constant.
11832   //
11833   // select (setcc x, K) (fneg x), -K -> fneg(minnum(x, K))
11834   //
11835   // TODO: Handle fabs
11836   if (LHS == NegTrue) {
11837     // If we can't directly match this, try to see if we can pull an fneg out of
11838     // the select.
11839     SDValue NegRHS = TLI.getCheaperOrNeutralNegatedExpression(
11840         RHS, DAG, LegalOperations, ForCodeSize);
11841     if (NegRHS) {
11842       HandleSDNode NegRHSHandle(NegRHS);
11843       if (NegRHS == False) {
11844         SDValue Combined = combineMinNumMaxNumImpl(DL, VT, LHS, RHS, NegTrue,
11845                                                    False, CC, TLI, DAG);
11846         if (Combined)
11847           return DAG.getNode(ISD::FNEG, DL, VT, Combined);
11848       }
11849     }
11850   }
11851 
11852   return SDValue();
11853 }
11854 
11855 /// If a (v)select has a condition value that is a sign-bit test, try to smear
11856 /// the condition operand sign-bit across the value width and use it as a mask.
foldSelectOfConstantsUsingSra(SDNode * N,const SDLoc & DL,SelectionDAG & DAG)11857 static SDValue foldSelectOfConstantsUsingSra(SDNode *N, const SDLoc &DL,
11858                                              SelectionDAG &DAG) {
11859   SDValue Cond = N->getOperand(0);
11860   SDValue C1 = N->getOperand(1);
11861   SDValue C2 = N->getOperand(2);
11862   if (!isConstantOrConstantVector(C1) || !isConstantOrConstantVector(C2))
11863     return SDValue();
11864 
11865   EVT VT = N->getValueType(0);
11866   if (Cond.getOpcode() != ISD::SETCC || !Cond.hasOneUse() ||
11867       VT != Cond.getOperand(0).getValueType())
11868     return SDValue();
11869 
11870   // The inverted-condition + commuted-select variants of these patterns are
11871   // canonicalized to these forms in IR.
11872   SDValue X = Cond.getOperand(0);
11873   SDValue CondC = Cond.getOperand(1);
11874   ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
11875   if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(CondC) &&
11876       isAllOnesOrAllOnesSplat(C2)) {
11877     // i32 X > -1 ? C1 : -1 --> (X >>s 31) | C1
11878     SDValue ShAmtC = DAG.getConstant(X.getScalarValueSizeInBits() - 1, DL, VT);
11879     SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, X, ShAmtC);
11880     return DAG.getNode(ISD::OR, DL, VT, Sra, C1);
11881   }
11882   if (CC == ISD::SETLT && isNullOrNullSplat(CondC) && isNullOrNullSplat(C2)) {
11883     // i8 X < 0 ? C1 : 0 --> (X >>s 7) & C1
11884     SDValue ShAmtC = DAG.getConstant(X.getScalarValueSizeInBits() - 1, DL, VT);
11885     SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, X, ShAmtC);
11886     return DAG.getNode(ISD::AND, DL, VT, Sra, C1);
11887   }
11888   return SDValue();
11889 }
11890 
shouldConvertSelectOfConstantsToMath(const SDValue & Cond,EVT VT,const TargetLowering & TLI)11891 static bool shouldConvertSelectOfConstantsToMath(const SDValue &Cond, EVT VT,
11892                                                  const TargetLowering &TLI) {
11893   if (!TLI.convertSelectOfConstantsToMath(VT))
11894     return false;
11895 
11896   if (Cond.getOpcode() != ISD::SETCC || !Cond->hasOneUse())
11897     return true;
11898   if (!TLI.isOperationLegalOrCustom(ISD::SELECT_CC, VT))
11899     return true;
11900 
11901   ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
11902   if (CC == ISD::SETLT && isNullOrNullSplat(Cond.getOperand(1)))
11903     return true;
11904   if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(Cond.getOperand(1)))
11905     return true;
11906 
11907   return false;
11908 }
11909 
foldSelectOfConstants(SDNode * N)11910 SDValue DAGCombiner::foldSelectOfConstants(SDNode *N) {
11911   SDValue Cond = N->getOperand(0);
11912   SDValue N1 = N->getOperand(1);
11913   SDValue N2 = N->getOperand(2);
11914   EVT VT = N->getValueType(0);
11915   EVT CondVT = Cond.getValueType();
11916   SDLoc DL(N);
11917 
11918   if (!VT.isInteger())
11919     return SDValue();
11920 
11921   auto *C1 = dyn_cast<ConstantSDNode>(N1);
11922   auto *C2 = dyn_cast<ConstantSDNode>(N2);
11923   if (!C1 || !C2)
11924     return SDValue();
11925 
11926   if (CondVT != MVT::i1 || LegalOperations) {
11927     // fold (select Cond, 0, 1) -> (xor Cond, 1)
11928     // We can't do this reliably if integer based booleans have different contents
11929     // to floating point based booleans. This is because we can't tell whether we
11930     // have an integer-based boolean or a floating-point-based boolean unless we
11931     // can find the SETCC that produced it and inspect its operands. This is
11932     // fairly easy if C is the SETCC node, but it can potentially be
11933     // undiscoverable (or not reasonably discoverable). For example, it could be
11934     // in another basic block or it could require searching a complicated
11935     // expression.
11936     if (CondVT.isInteger() &&
11937         TLI.getBooleanContents(/*isVec*/false, /*isFloat*/true) ==
11938             TargetLowering::ZeroOrOneBooleanContent &&
11939         TLI.getBooleanContents(/*isVec*/false, /*isFloat*/false) ==
11940             TargetLowering::ZeroOrOneBooleanContent &&
11941         C1->isZero() && C2->isOne()) {
11942       SDValue NotCond =
11943           DAG.getNode(ISD::XOR, DL, CondVT, Cond, DAG.getConstant(1, DL, CondVT));
11944       if (VT.bitsEq(CondVT))
11945         return NotCond;
11946       return DAG.getZExtOrTrunc(NotCond, DL, VT);
11947     }
11948 
11949     return SDValue();
11950   }
11951 
11952   // Only do this before legalization to avoid conflicting with target-specific
11953   // transforms in the other direction (create a select from a zext/sext). There
11954   // is also a target-independent combine here in DAGCombiner in the other
11955   // direction for (select Cond, -1, 0) when the condition is not i1.
11956   assert(CondVT == MVT::i1 && !LegalOperations);
11957 
11958   // select Cond, 1, 0 --> zext (Cond)
11959   if (C1->isOne() && C2->isZero())
11960     return DAG.getZExtOrTrunc(Cond, DL, VT);
11961 
11962   // select Cond, -1, 0 --> sext (Cond)
11963   if (C1->isAllOnes() && C2->isZero())
11964     return DAG.getSExtOrTrunc(Cond, DL, VT);
11965 
11966   // select Cond, 0, 1 --> zext (!Cond)
11967   if (C1->isZero() && C2->isOne()) {
11968     SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
11969     NotCond = DAG.getZExtOrTrunc(NotCond, DL, VT);
11970     return NotCond;
11971   }
11972 
11973   // select Cond, 0, -1 --> sext (!Cond)
11974   if (C1->isZero() && C2->isAllOnes()) {
11975     SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
11976     NotCond = DAG.getSExtOrTrunc(NotCond, DL, VT);
11977     return NotCond;
11978   }
11979 
11980   // Use a target hook because some targets may prefer to transform in the
11981   // other direction.
11982   if (!shouldConvertSelectOfConstantsToMath(Cond, VT, TLI))
11983     return SDValue();
11984 
11985   // For any constants that differ by 1, we can transform the select into
11986   // an extend and add.
11987   const APInt &C1Val = C1->getAPIntValue();
11988   const APInt &C2Val = C2->getAPIntValue();
11989 
11990   // select Cond, C1, C1-1 --> add (zext Cond), C1-1
11991   if (C1Val - 1 == C2Val) {
11992     Cond = DAG.getZExtOrTrunc(Cond, DL, VT);
11993     return DAG.getNode(ISD::ADD, DL, VT, Cond, N2);
11994   }
11995 
11996   // select Cond, C1, C1+1 --> add (sext Cond), C1+1
11997   if (C1Val + 1 == C2Val) {
11998     Cond = DAG.getSExtOrTrunc(Cond, DL, VT);
11999     return DAG.getNode(ISD::ADD, DL, VT, Cond, N2);
12000   }
12001 
12002   // select Cond, Pow2, 0 --> (zext Cond) << log2(Pow2)
12003   if (C1Val.isPowerOf2() && C2Val.isZero()) {
12004     Cond = DAG.getZExtOrTrunc(Cond, DL, VT);
12005     SDValue ShAmtC =
12006         DAG.getShiftAmountConstant(C1Val.exactLogBase2(), VT, DL);
12007     return DAG.getNode(ISD::SHL, DL, VT, Cond, ShAmtC);
12008   }
12009 
12010   // select Cond, -1, C --> or (sext Cond), C
12011   if (C1->isAllOnes()) {
12012     Cond = DAG.getSExtOrTrunc(Cond, DL, VT);
12013     return DAG.getNode(ISD::OR, DL, VT, Cond, N2);
12014   }
12015 
12016   // select Cond, C, -1 --> or (sext (not Cond)), C
12017   if (C2->isAllOnes()) {
12018     SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
12019     NotCond = DAG.getSExtOrTrunc(NotCond, DL, VT);
12020     return DAG.getNode(ISD::OR, DL, VT, NotCond, N1);
12021   }
12022 
12023   if (SDValue V = foldSelectOfConstantsUsingSra(N, DL, DAG))
12024     return V;
12025 
12026   return SDValue();
12027 }
12028 
12029 template <class MatchContextClass>
foldBoolSelectToLogic(SDNode * N,const SDLoc & DL,SelectionDAG & DAG)12030 static SDValue foldBoolSelectToLogic(SDNode *N, const SDLoc &DL,
12031                                      SelectionDAG &DAG) {
12032   assert((N->getOpcode() == ISD::SELECT || N->getOpcode() == ISD::VSELECT ||
12033           N->getOpcode() == ISD::VP_SELECT) &&
12034          "Expected a (v)(vp.)select");
12035   SDValue Cond = N->getOperand(0);
12036   SDValue T = N->getOperand(1), F = N->getOperand(2);
12037   EVT VT = N->getValueType(0);
12038   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
12039   MatchContextClass matcher(DAG, TLI, N);
12040 
12041   if (VT != Cond.getValueType() || VT.getScalarSizeInBits() != 1)
12042     return SDValue();
12043 
12044   // select Cond, Cond, F --> or Cond, freeze(F)
12045   // select Cond, 1, F    --> or Cond, freeze(F)
12046   if (Cond == T || isOneOrOneSplat(T, /* AllowUndefs */ true))
12047     return matcher.getNode(ISD::OR, DL, VT, Cond, DAG.getFreeze(F));
12048 
12049   // select Cond, T, Cond --> and Cond, freeze(T)
12050   // select Cond, T, 0    --> and Cond, freeze(T)
12051   if (Cond == F || isNullOrNullSplat(F, /* AllowUndefs */ true))
12052     return matcher.getNode(ISD::AND, DL, VT, Cond, DAG.getFreeze(T));
12053 
12054   // select Cond, T, 1 --> or (not Cond), freeze(T)
12055   if (isOneOrOneSplat(F, /* AllowUndefs */ true)) {
12056     SDValue NotCond =
12057         matcher.getNode(ISD::XOR, DL, VT, Cond, DAG.getAllOnesConstant(DL, VT));
12058     return matcher.getNode(ISD::OR, DL, VT, NotCond, DAG.getFreeze(T));
12059   }
12060 
12061   // select Cond, 0, F --> and (not Cond), freeze(F)
12062   if (isNullOrNullSplat(T, /* AllowUndefs */ true)) {
12063     SDValue NotCond =
12064         matcher.getNode(ISD::XOR, DL, VT, Cond, DAG.getAllOnesConstant(DL, VT));
12065     return matcher.getNode(ISD::AND, DL, VT, NotCond, DAG.getFreeze(F));
12066   }
12067 
12068   return SDValue();
12069 }
12070 
foldVSelectToSignBitSplatMask(SDNode * N,SelectionDAG & DAG)12071 static SDValue foldVSelectToSignBitSplatMask(SDNode *N, SelectionDAG &DAG) {
12072   SDValue N0 = N->getOperand(0);
12073   SDValue N1 = N->getOperand(1);
12074   SDValue N2 = N->getOperand(2);
12075   EVT VT = N->getValueType(0);
12076   unsigned EltSizeInBits = VT.getScalarSizeInBits();
12077 
12078   SDValue Cond0, Cond1;
12079   ISD::CondCode CC;
12080   if (!sd_match(N0, m_OneUse(m_SetCC(m_Value(Cond0), m_Value(Cond1),
12081                                      m_CondCode(CC)))) ||
12082       VT != Cond0.getValueType())
12083     return SDValue();
12084 
12085   // Match a signbit check of Cond0 as "Cond0 s<0". Swap select operands if the
12086   // compare is inverted from that pattern ("Cond0 s> -1").
12087   if (CC == ISD::SETLT && isNullOrNullSplat(Cond1))
12088     ; // This is the pattern we are looking for.
12089   else if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(Cond1))
12090     std::swap(N1, N2);
12091   else
12092     return SDValue();
12093 
12094   // (Cond0 s< 0) ? N1 : 0 --> (Cond0 s>> BW-1) & freeze(N1)
12095   if (isNullOrNullSplat(N2)) {
12096     SDLoc DL(N);
12097     SDValue ShiftAmt = DAG.getShiftAmountConstant(EltSizeInBits - 1, VT, DL);
12098     SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Cond0, ShiftAmt);
12099     return DAG.getNode(ISD::AND, DL, VT, Sra, DAG.getFreeze(N1));
12100   }
12101 
12102   // (Cond0 s< 0) ? -1 : N2 --> (Cond0 s>> BW-1) | freeze(N2)
12103   if (isAllOnesOrAllOnesSplat(N1)) {
12104     SDLoc DL(N);
12105     SDValue ShiftAmt = DAG.getShiftAmountConstant(EltSizeInBits - 1, VT, DL);
12106     SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Cond0, ShiftAmt);
12107     return DAG.getNode(ISD::OR, DL, VT, Sra, DAG.getFreeze(N2));
12108   }
12109 
12110   // If we have to invert the sign bit mask, only do that transform if the
12111   // target has a bitwise 'and not' instruction (the invert is free).
12112   // (Cond0 s< -0) ? 0 : N2 --> ~(Cond0 s>> BW-1) & freeze(N2)
12113   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
12114   if (isNullOrNullSplat(N1) && TLI.hasAndNot(N1)) {
12115     SDLoc DL(N);
12116     SDValue ShiftAmt = DAG.getShiftAmountConstant(EltSizeInBits - 1, VT, DL);
12117     SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Cond0, ShiftAmt);
12118     SDValue Not = DAG.getNOT(DL, Sra, VT);
12119     return DAG.getNode(ISD::AND, DL, VT, Not, DAG.getFreeze(N2));
12120   }
12121 
12122   // TODO: There's another pattern in this family, but it may require
12123   //       implementing hasOrNot() to check for profitability:
12124   //       (Cond0 s> -1) ? -1 : N2 --> ~(Cond0 s>> BW-1) | freeze(N2)
12125 
12126   return SDValue();
12127 }
12128 
12129 // Match SELECTs with absolute difference patterns.
12130 // (select (setcc a, b, set?gt), (sub a, b), (sub b, a)) --> (abd? a, b)
12131 // (select (setcc a, b, set?ge), (sub a, b), (sub b, a)) --> (abd? a, b)
12132 // (select (setcc a, b, set?lt), (sub b, a), (sub a, b)) --> (abd? a, b)
12133 // (select (setcc a, b, set?le), (sub b, a), (sub a, b)) --> (abd? a, b)
foldSelectToABD(SDValue LHS,SDValue RHS,SDValue True,SDValue False,ISD::CondCode CC,const SDLoc & DL)12134 SDValue DAGCombiner::foldSelectToABD(SDValue LHS, SDValue RHS, SDValue True,
12135                                      SDValue False, ISD::CondCode CC,
12136                                      const SDLoc &DL) {
12137   bool IsSigned = isSignedIntSetCC(CC);
12138   unsigned ABDOpc = IsSigned ? ISD::ABDS : ISD::ABDU;
12139   EVT VT = LHS.getValueType();
12140 
12141   if (LegalOperations && !hasOperation(ABDOpc, VT))
12142     return SDValue();
12143 
12144   switch (CC) {
12145   case ISD::SETGT:
12146   case ISD::SETGE:
12147   case ISD::SETUGT:
12148   case ISD::SETUGE:
12149     if (sd_match(True, m_Sub(m_Specific(LHS), m_Specific(RHS))) &&
12150         sd_match(False, m_Sub(m_Specific(RHS), m_Specific(LHS))))
12151       return DAG.getNode(ABDOpc, DL, VT, LHS, RHS);
12152     if (sd_match(True, m_Sub(m_Specific(RHS), m_Specific(LHS))) &&
12153         sd_match(False, m_Sub(m_Specific(LHS), m_Specific(RHS))) &&
12154         hasOperation(ABDOpc, VT))
12155       return DAG.getNegative(DAG.getNode(ABDOpc, DL, VT, LHS, RHS), DL, VT);
12156     break;
12157   case ISD::SETLT:
12158   case ISD::SETLE:
12159   case ISD::SETULT:
12160   case ISD::SETULE:
12161     if (sd_match(True, m_Sub(m_Specific(RHS), m_Specific(LHS))) &&
12162         sd_match(False, m_Sub(m_Specific(LHS), m_Specific(RHS))))
12163       return DAG.getNode(ABDOpc, DL, VT, LHS, RHS);
12164     if (sd_match(True, m_Sub(m_Specific(LHS), m_Specific(RHS))) &&
12165         sd_match(False, m_Sub(m_Specific(RHS), m_Specific(LHS))) &&
12166         hasOperation(ABDOpc, VT))
12167       return DAG.getNegative(DAG.getNode(ABDOpc, DL, VT, LHS, RHS), DL, VT);
12168     break;
12169   default:
12170     break;
12171   }
12172 
12173   return SDValue();
12174 }
12175 
// Combine an ISD::SELECT node (select N0, N1, N2).
//
// Folds are attempted in a fixed order and the first one that applies wins.
// Returns one of:
//  - the replacement value;
//  - SDValue(N, 0) when SimplifySelectOps updated N in place;
//  - an empty SDValue when no fold applies.
SDValue DAGCombiner::visitSELECT(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDValue N2 = N->getOperand(2);
  EVT VT = N->getValueType(0);
  EVT VT0 = N0.getValueType();
  SDLoc DL(N);
  SDNodeFlags Flags = N->getFlags();

  if (SDValue V = DAG.simplifySelect(N0, N1, N2))
    return V;

  if (SDValue V = foldBoolSelectToLogic<EmptyMatchContext>(N, DL, DAG))
    return V;

  // select (not Cond), N1, N2 -> select Cond, N2, N1
  if (SDValue F = extractBooleanFlip(N0, DAG, TLI, false))
    return DAG.getSelect(DL, VT, F, N2, N1, Flags);

  if (SDValue V = foldSelectOfConstants(N))
    return V;

  // If we can fold this based on the true/false value, do so.
  if (SimplifySelectOps(N, N1, N2))
    return SDValue(N, 0); // Don't revisit N.

  // Folds specific to a scalar i1 condition.
  if (VT0 == MVT::i1) {
    // The code in this block deals with the following 2 equivalences:
    //    select(C0|C1, x, y) <=> select(C0, x, select(C1, x, y))
    //    select(C0&C1, x, y) <=> select(C0, select(C1, x, y), y)
    // The target can specify its preferred form with the
    // shouldNormalizeToSelectSequence() callback. However we always transform
    // to the right if we find the inner select already exists in the DAG,
    // and we always transform to the left side if we know that we can further
    // optimize the combination of the conditions.
    bool normalizeToSequence =
        TLI.shouldNormalizeToSelectSequence(*DAG.getContext(), VT);
    // select (and Cond0, Cond1), X, Y
    //   -> select Cond0, (select Cond1, X, Y), Y
    if (N0->getOpcode() == ISD::AND && N0->hasOneUse()) {
      SDValue Cond0 = N0->getOperand(0);
      SDValue Cond1 = N0->getOperand(1);
      // A freshly created node has no uses; a non-empty use list means this
      // inner select already existed in the DAG.
      SDValue InnerSelect =
          DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond1, N1, N2, Flags);
      if (normalizeToSequence || !InnerSelect.use_empty())
        return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond0,
                           InnerSelect, N2, Flags);
      // Cleanup on failure.
      if (InnerSelect.use_empty())
        recursivelyDeleteUnusedNodes(InnerSelect.getNode());
    }
    // select (or Cond0, Cond1), X, Y -> select Cond0, X, (select Cond1, X, Y)
    if (N0->getOpcode() == ISD::OR && N0->hasOneUse()) {
      SDValue Cond0 = N0->getOperand(0);
      SDValue Cond1 = N0->getOperand(1);
      SDValue InnerSelect = DAG.getNode(ISD::SELECT, DL, N1.getValueType(),
                                        Cond1, N1, N2, Flags);
      if (normalizeToSequence || !InnerSelect.use_empty())
        return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond0, N1,
                           InnerSelect, Flags);
      // Cleanup on failure.
      if (InnerSelect.use_empty())
        recursivelyDeleteUnusedNodes(InnerSelect.getNode());
    }

    // select Cond0, (select Cond1, X, Y), Y -> select (and Cond0, Cond1), X, Y
    if (N1->getOpcode() == ISD::SELECT && N1->hasOneUse()) {
      SDValue N1_0 = N1->getOperand(0);
      SDValue N1_1 = N1->getOperand(1);
      SDValue N1_2 = N1->getOperand(2);
      if (N1_2 == N2 && N0.getValueType() == N1_0.getValueType()) {
        // Create the actual and node if we can generate good code for it.
        if (!normalizeToSequence) {
          SDValue And = DAG.getNode(ISD::AND, DL, N0.getValueType(), N0, N1_0);
          return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), And, N1_1,
                             N2, Flags);
        }
        // Otherwise see if we can optimize the "and" to a better pattern.
        if (SDValue Combined = visitANDLike(N0, N1_0, N)) {
          return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Combined, N1_1,
                             N2, Flags);
        }
      }
    }
    // select Cond0, X, (select Cond1, X, Y) -> select (or Cond0, Cond1), X, Y
    if (N2->getOpcode() == ISD::SELECT && N2->hasOneUse()) {
      SDValue N2_0 = N2->getOperand(0);
      SDValue N2_1 = N2->getOperand(1);
      SDValue N2_2 = N2->getOperand(2);
      if (N2_1 == N1 && N0.getValueType() == N2_0.getValueType()) {
        // Create the actual or node if we can generate good code for it.
        if (!normalizeToSequence) {
          SDValue Or = DAG.getNode(ISD::OR, DL, N0.getValueType(), N0, N2_0);
          return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Or, N1,
                             N2_2, Flags);
        }
        // Otherwise see if we can optimize to a better pattern.
        if (SDValue Combined = visitORLike(N0, N2_0, DL))
          return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Combined, N1,
                             N2_2, Flags);
      }
    }

    // select usubo(x, y).overflow, (sub y, x), (usubo x, y) -> abdu(x, y)
    // After operation legalization this requires ABDU to be legal for VT.
    if (N0.getOpcode() == ISD::USUBO && N0.getResNo() == 1 &&
        N2.getNode() == N0.getNode() && N2.getResNo() == 0 &&
        N1.getOpcode() == ISD::SUB && N2.getOperand(0) == N1.getOperand(1) &&
        N2.getOperand(1) == N1.getOperand(0) &&
        (!LegalOperations || TLI.isOperationLegal(ISD::ABDU, VT)))
      return DAG.getNode(ISD::ABDU, DL, VT, N0.getOperand(0), N0.getOperand(1));

    // select usubo(x, y).overflow, (usubo x, y), (sub y, x) -> neg (abdu x, y)
    // Same as above with the select arms swapped, so the result is negated.
    if (N0.getOpcode() == ISD::USUBO && N0.getResNo() == 1 &&
        N1.getNode() == N0.getNode() && N1.getResNo() == 0 &&
        N2.getOpcode() == ISD::SUB && N2.getOperand(0) == N1.getOperand(1) &&
        N2.getOperand(1) == N1.getOperand(0) &&
        (!LegalOperations || TLI.isOperationLegal(ISD::ABDU, VT)))
      return DAG.getNegative(
          DAG.getNode(ISD::ABDU, DL, VT, N0.getOperand(0), N0.getOperand(1)),
          DL, VT);
  }

  // Fold selects based on a setcc into other things, such as min/max/abs.
  if (N0.getOpcode() == ISD::SETCC) {
    SDValue Cond0 = N0.getOperand(0), Cond1 = N0.getOperand(1);
    ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();

    // select (fcmp lt x, y), x, y -> fminnum x, y
    // select (fcmp gt x, y), x, y -> fmaxnum x, y
    //
    // This is OK if we don't care what happens if either operand is a NaN.
    if (N0.hasOneUse() && isLegalToCombineMinNumMaxNum(DAG, N1, N2, Flags, TLI))
      if (SDValue FMinMax =
              combineMinNumMaxNum(DL, VT, Cond0, Cond1, N1, N2, CC))
        return FMinMax;

    // Use 'unsigned add with overflow' to optimize an unsigned saturating add.
    // This is conservatively limited to pre-legal-operations to give targets
    // a chance to reverse the transform if they want to do that. Also, it is
    // unlikely that the pattern would be formed late, so it's probably not
    // worth going through the other checks.
    if (!LegalOperations && TLI.isOperationLegalOrCustom(ISD::UADDO, VT) &&
        CC == ISD::SETUGT && N0.hasOneUse() && isAllOnesConstant(N1) &&
        N2.getOpcode() == ISD::ADD && Cond0 == N2.getOperand(0)) {
      auto *C = dyn_cast<ConstantSDNode>(N2.getOperand(1));
      auto *NotC = dyn_cast<ConstantSDNode>(Cond1);
      if (C && NotC && C->getAPIntValue() == ~NotC->getAPIntValue()) {
        // select (setcc Cond0, ~C, ugt), -1, (add Cond0, C) -->
        // uaddo Cond0, C; select uaddo.1, -1, uaddo.0
        //
        // The IR equivalent of this transform would have this form:
        //   %a = add %x, C
        //   %c = icmp ugt %x, ~C
        //   %r = select %c, -1, %a
        //   =>
        //   %u = call {iN,i1} llvm.uadd.with.overflow(%x, C)
        //   %u0 = extractvalue %u, 0
        //   %u1 = extractvalue %u, 1
        //   %r = select %u1, -1, %u0
        SDVTList VTs = DAG.getVTList(VT, VT0);
        SDValue UAO = DAG.getNode(ISD::UADDO, DL, VTs, Cond0, N2.getOperand(1));
        return DAG.getSelect(DL, VT, UAO.getValue(1), N1, UAO.getValue(0));
      }
    }

    // Prefer a single SELECT_CC node when the target supports it. Note that
    // this returns before the ABD and umin folds below get a chance to run.
    if (TLI.isOperationLegal(ISD::SELECT_CC, VT) ||
        (!LegalOperations &&
         TLI.isOperationLegalOrCustom(ISD::SELECT_CC, VT))) {
      // Any flags available in a select/setcc fold will be on the setcc as they
      // migrated from fcmp
      Flags = N0->getFlags();
      SDValue SelectNode = DAG.getNode(ISD::SELECT_CC, DL, VT, Cond0, Cond1, N1,
                                       N2, N0.getOperand(2));
      SelectNode->setFlags(Flags);
      return SelectNode;
    }

    if (SDValue ABD = foldSelectToABD(Cond0, Cond1, N1, N2, CC, DL))
      return ABD;

    if (SDValue NewSel = SimplifySelect(DL, N0, N1, N2))
      return NewSel;

    // (select (ugt x, C), (add x, ~C), x) -> (umin (add x, ~C), x)
    // (select (ult x, C), x, (add x, -C)) -> (umin x, (add x, -C))
    APInt C;
    if (sd_match(Cond1, m_ConstInt(C)) && hasUMin(VT)) {
      if (CC == ISD::SETUGT && Cond0 == N2 &&
          sd_match(N1, m_Add(m_Specific(N2), m_SpecificInt(~C)))) {
        // The resulting code relies on an unsigned wrap in ADD.
        // Recreating ADD to drop possible nuw/nsw flags.
        SDValue AddC = DAG.getConstant(~C, DL, VT);
        SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N2, AddC);
        return DAG.getNode(ISD::UMIN, DL, VT, Add, N2);
      }
      if (CC == ISD::SETULT && Cond0 == N1 &&
          sd_match(N2, m_Add(m_Specific(N1), m_SpecificInt(-C)))) {
        // Ditto.
        SDValue AddC = DAG.getConstant(-C, DL, VT);
        SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N1, AddC);
        return DAG.getNode(ISD::UMIN, DL, VT, N1, Add);
      }
    }
  }

  // Hoisting a select through binops is only attempted for scalar selects.
  if (!VT.isVector())
    if (SDValue BinOp = foldSelectOfBinops(N))
      return BinOp;

  if (SDValue R = combineSelectAsExtAnd(N0, N1, N2, DL, DAG))
    return R;

  return SDValue();
}
12390 
12391 // This function assumes all the vselect's arguments are CONCAT_VECTOR
12392 // nodes and that the condition is a BV of ConstantSDNodes (or undefs).
ConvertSelectToConcatVector(SDNode * N,SelectionDAG & DAG)12393 static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) {
12394   SDLoc DL(N);
12395   SDValue Cond = N->getOperand(0);
12396   SDValue LHS = N->getOperand(1);
12397   SDValue RHS = N->getOperand(2);
12398   EVT VT = N->getValueType(0);
12399   int NumElems = VT.getVectorNumElements();
12400   assert(LHS.getOpcode() == ISD::CONCAT_VECTORS &&
12401          RHS.getOpcode() == ISD::CONCAT_VECTORS &&
12402          Cond.getOpcode() == ISD::BUILD_VECTOR);
12403 
12404   // CONCAT_VECTOR can take an arbitrary number of arguments. We only care about
12405   // binary ones here.
12406   if (LHS->getNumOperands() != 2 || RHS->getNumOperands() != 2)
12407     return SDValue();
12408 
12409   // We're sure we have an even number of elements due to the
12410   // concat_vectors we have as arguments to vselect.
12411   // Skip BV elements until we find one that's not an UNDEF
12412   // After we find an UNDEF element, keep looping until we get to half the
12413   // length of the BV and see if all the non-undef nodes are the same.
12414   ConstantSDNode *BottomHalf = nullptr;
12415   for (int i = 0; i < NumElems / 2; ++i) {
12416     if (Cond->getOperand(i)->isUndef())
12417       continue;
12418 
12419     if (BottomHalf == nullptr)
12420       BottomHalf = cast<ConstantSDNode>(Cond.getOperand(i));
12421     else if (Cond->getOperand(i).getNode() != BottomHalf)
12422       return SDValue();
12423   }
12424 
12425   // Do the same for the second half of the BuildVector
12426   ConstantSDNode *TopHalf = nullptr;
12427   for (int i = NumElems / 2; i < NumElems; ++i) {
12428     if (Cond->getOperand(i)->isUndef())
12429       continue;
12430 
12431     if (TopHalf == nullptr)
12432       TopHalf = cast<ConstantSDNode>(Cond.getOperand(i));
12433     else if (Cond->getOperand(i).getNode() != TopHalf)
12434       return SDValue();
12435   }
12436 
12437   assert(TopHalf && BottomHalf &&
12438          "One half of the selector was all UNDEFs and the other was all the "
12439          "same value. This should have been addressed before this function.");
12440   return DAG.getNode(
12441       ISD::CONCAT_VECTORS, DL, VT,
12442       BottomHalf->isZero() ? RHS->getOperand(0) : LHS->getOperand(0),
12443       TopHalf->isZero() ? RHS->getOperand(1) : LHS->getOperand(1));
12444 }
12445 
refineUniformBase(SDValue & BasePtr,SDValue & Index,bool IndexIsScaled,SelectionDAG & DAG,const SDLoc & DL)12446 bool refineUniformBase(SDValue &BasePtr, SDValue &Index, bool IndexIsScaled,
12447                        SelectionDAG &DAG, const SDLoc &DL) {
12448 
12449   // Only perform the transformation when existing operands can be reused.
12450   if (IndexIsScaled)
12451     return false;
12452 
12453   if (!isNullConstant(BasePtr) && !Index.hasOneUse())
12454     return false;
12455 
12456   EVT VT = BasePtr.getValueType();
12457 
12458   if (SDValue SplatVal = DAG.getSplatValue(Index);
12459       SplatVal && !isNullConstant(SplatVal) &&
12460       SplatVal.getValueType() == VT) {
12461     BasePtr = DAG.getNode(ISD::ADD, DL, VT, BasePtr, SplatVal);
12462     Index = DAG.getSplat(Index.getValueType(), DL, DAG.getConstant(0, DL, VT));
12463     return true;
12464   }
12465 
12466   if (Index.getOpcode() != ISD::ADD)
12467     return false;
12468 
12469   if (SDValue SplatVal = DAG.getSplatValue(Index.getOperand(0));
12470       SplatVal && SplatVal.getValueType() == VT) {
12471     BasePtr = DAG.getNode(ISD::ADD, DL, VT, BasePtr, SplatVal);
12472     Index = Index.getOperand(1);
12473     return true;
12474   }
12475   if (SDValue SplatVal = DAG.getSplatValue(Index.getOperand(1));
12476       SplatVal && SplatVal.getValueType() == VT) {
12477     BasePtr = DAG.getNode(ISD::ADD, DL, VT, BasePtr, SplatVal);
12478     Index = Index.getOperand(0);
12479     return true;
12480   }
12481   return false;
12482 }
12483 
12484 // Fold sext/zext of index into index type.
refineIndexType(SDValue & Index,ISD::MemIndexType & IndexType,EVT DataVT,SelectionDAG & DAG)12485 bool refineIndexType(SDValue &Index, ISD::MemIndexType &IndexType, EVT DataVT,
12486                      SelectionDAG &DAG) {
12487   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
12488 
12489   // It's always safe to look through zero extends.
12490   if (Index.getOpcode() == ISD::ZERO_EXTEND) {
12491     if (TLI.shouldRemoveExtendFromGSIndex(Index, DataVT)) {
12492       IndexType = ISD::UNSIGNED_SCALED;
12493       Index = Index.getOperand(0);
12494       return true;
12495     }
12496     if (ISD::isIndexTypeSigned(IndexType)) {
12497       IndexType = ISD::UNSIGNED_SCALED;
12498       return true;
12499     }
12500   }
12501 
12502   // It's only safe to look through sign extends when Index is signed.
12503   if (Index.getOpcode() == ISD::SIGN_EXTEND &&
12504       ISD::isIndexTypeSigned(IndexType) &&
12505       TLI.shouldRemoveExtendFromGSIndex(Index, DataVT)) {
12506     Index = Index.getOperand(0);
12507     return true;
12508   }
12509 
12510   return false;
12511 }
12512 
visitVPSCATTER(SDNode * N)12513 SDValue DAGCombiner::visitVPSCATTER(SDNode *N) {
12514   VPScatterSDNode *MSC = cast<VPScatterSDNode>(N);
12515   SDValue Mask = MSC->getMask();
12516   SDValue Chain = MSC->getChain();
12517   SDValue Index = MSC->getIndex();
12518   SDValue Scale = MSC->getScale();
12519   SDValue StoreVal = MSC->getValue();
12520   SDValue BasePtr = MSC->getBasePtr();
12521   SDValue VL = MSC->getVectorLength();
12522   ISD::MemIndexType IndexType = MSC->getIndexType();
12523   SDLoc DL(N);
12524 
12525   // Zap scatters with a zero mask.
12526   if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
12527     return Chain;
12528 
12529   if (refineUniformBase(BasePtr, Index, MSC->isIndexScaled(), DAG, DL)) {
12530     SDValue Ops[] = {Chain, StoreVal, BasePtr, Index, Scale, Mask, VL};
12531     return DAG.getScatterVP(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
12532                             DL, Ops, MSC->getMemOperand(), IndexType);
12533   }
12534 
12535   if (refineIndexType(Index, IndexType, StoreVal.getValueType(), DAG)) {
12536     SDValue Ops[] = {Chain, StoreVal, BasePtr, Index, Scale, Mask, VL};
12537     return DAG.getScatterVP(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
12538                             DL, Ops, MSC->getMemOperand(), IndexType);
12539   }
12540 
12541   return SDValue();
12542 }
12543 
visitMSCATTER(SDNode * N)12544 SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
12545   MaskedScatterSDNode *MSC = cast<MaskedScatterSDNode>(N);
12546   SDValue Mask = MSC->getMask();
12547   SDValue Chain = MSC->getChain();
12548   SDValue Index = MSC->getIndex();
12549   SDValue Scale = MSC->getScale();
12550   SDValue StoreVal = MSC->getValue();
12551   SDValue BasePtr = MSC->getBasePtr();
12552   ISD::MemIndexType IndexType = MSC->getIndexType();
12553   SDLoc DL(N);
12554 
12555   // Zap scatters with a zero mask.
12556   if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
12557     return Chain;
12558 
12559   if (refineUniformBase(BasePtr, Index, MSC->isIndexScaled(), DAG, DL)) {
12560     SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
12561     return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
12562                                 DL, Ops, MSC->getMemOperand(), IndexType,
12563                                 MSC->isTruncatingStore());
12564   }
12565 
12566   if (refineIndexType(Index, IndexType, StoreVal.getValueType(), DAG)) {
12567     SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
12568     return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
12569                                 DL, Ops, MSC->getMemOperand(), IndexType,
12570                                 MSC->isTruncatingStore());
12571   }
12572 
12573   return SDValue();
12574 }
12575 
visitMSTORE(SDNode * N)12576 SDValue DAGCombiner::visitMSTORE(SDNode *N) {
12577   MaskedStoreSDNode *MST = cast<MaskedStoreSDNode>(N);
12578   SDValue Mask = MST->getMask();
12579   SDValue Chain = MST->getChain();
12580   SDValue Value = MST->getValue();
12581   SDValue Ptr = MST->getBasePtr();
12582 
12583   // Zap masked stores with a zero mask.
12584   if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
12585     return Chain;
12586 
12587   // Remove a masked store if base pointers and masks are equal.
12588   if (MaskedStoreSDNode *MST1 = dyn_cast<MaskedStoreSDNode>(Chain)) {
12589     if (MST->isUnindexed() && MST->isSimple() && MST1->isUnindexed() &&
12590         MST1->isSimple() && MST1->getBasePtr() == Ptr &&
12591         !MST->getBasePtr().isUndef() &&
12592         ((Mask == MST1->getMask() && MST->getMemoryVT().getStoreSize() ==
12593                                          MST1->getMemoryVT().getStoreSize()) ||
12594          ISD::isConstantSplatVectorAllOnes(Mask.getNode())) &&
12595         TypeSize::isKnownLE(MST1->getMemoryVT().getStoreSize(),
12596                             MST->getMemoryVT().getStoreSize())) {
12597       CombineTo(MST1, MST1->getChain());
12598       if (N->getOpcode() != ISD::DELETED_NODE)
12599         AddToWorklist(N);
12600       return SDValue(N, 0);
12601     }
12602   }
12603 
12604   // If this is a masked load with an all ones mask, we can use a unmasked load.
12605   // FIXME: Can we do this for indexed, compressing, or truncating stores?
12606   if (ISD::isConstantSplatVectorAllOnes(Mask.getNode()) && MST->isUnindexed() &&
12607       !MST->isCompressingStore() && !MST->isTruncatingStore())
12608     return DAG.getStore(MST->getChain(), SDLoc(N), MST->getValue(),
12609                         MST->getBasePtr(), MST->getPointerInfo(),
12610                         MST->getBaseAlign(), MST->getMemOperand()->getFlags(),
12611                         MST->getAAInfo());
12612 
12613   // Try transforming N to an indexed store.
12614   if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
12615     return SDValue(N, 0);
12616 
12617   if (MST->isTruncatingStore() && MST->isUnindexed() &&
12618       Value.getValueType().isInteger() &&
12619       (!isa<ConstantSDNode>(Value) ||
12620        !cast<ConstantSDNode>(Value)->isOpaque())) {
12621     APInt TruncDemandedBits =
12622         APInt::getLowBitsSet(Value.getScalarValueSizeInBits(),
12623                              MST->getMemoryVT().getScalarSizeInBits());
12624 
12625     // See if we can simplify the operation with
12626     // SimplifyDemandedBits, which only works if the value has a single use.
12627     if (SimplifyDemandedBits(Value, TruncDemandedBits)) {
12628       // Re-visit the store if anything changed and the store hasn't been merged
12629       // with another node (N is deleted) SimplifyDemandedBits will add Value's
12630       // node back to the worklist if necessary, but we also need to re-visit
12631       // the Store node itself.
12632       if (N->getOpcode() != ISD::DELETED_NODE)
12633         AddToWorklist(N);
12634       return SDValue(N, 0);
12635     }
12636   }
12637 
12638   // If this is a TRUNC followed by a masked store, fold this into a masked
12639   // truncating store.  We can do this even if this is already a masked
12640   // truncstore.
12641   // TODO: Try combine to masked compress store if possiable.
12642   if ((Value.getOpcode() == ISD::TRUNCATE) && Value->hasOneUse() &&
12643       MST->isUnindexed() && !MST->isCompressingStore() &&
12644       TLI.canCombineTruncStore(Value.getOperand(0).getValueType(),
12645                                MST->getMemoryVT(), LegalOperations)) {
12646     auto Mask = TLI.promoteTargetBoolean(DAG, MST->getMask(),
12647                                          Value.getOperand(0).getValueType());
12648     return DAG.getMaskedStore(Chain, SDLoc(N), Value.getOperand(0), Ptr,
12649                               MST->getOffset(), Mask, MST->getMemoryVT(),
12650                               MST->getMemOperand(), MST->getAddressingMode(),
12651                               /*IsTruncating=*/true);
12652   }
12653 
12654   return SDValue();
12655 }
12656 
visitVP_STRIDED_STORE(SDNode * N)12657 SDValue DAGCombiner::visitVP_STRIDED_STORE(SDNode *N) {
12658   auto *SST = cast<VPStridedStoreSDNode>(N);
12659   EVT EltVT = SST->getValue().getValueType().getVectorElementType();
12660   // Combine strided stores with unit-stride to a regular VP store.
12661   if (auto *CStride = dyn_cast<ConstantSDNode>(SST->getStride());
12662       CStride && CStride->getZExtValue() == EltVT.getStoreSize()) {
12663     return DAG.getStoreVP(SST->getChain(), SDLoc(N), SST->getValue(),
12664                           SST->getBasePtr(), SST->getOffset(), SST->getMask(),
12665                           SST->getVectorLength(), SST->getMemoryVT(),
12666                           SST->getMemOperand(), SST->getAddressingMode(),
12667                           SST->isTruncatingStore(), SST->isCompressingStore());
12668   }
12669   return SDValue();
12670 }
12671 
visitVECTOR_COMPRESS(SDNode * N)12672 SDValue DAGCombiner::visitVECTOR_COMPRESS(SDNode *N) {
12673   SDLoc DL(N);
12674   SDValue Vec = N->getOperand(0);
12675   SDValue Mask = N->getOperand(1);
12676   SDValue Passthru = N->getOperand(2);
12677   EVT VecVT = Vec.getValueType();
12678 
12679   bool HasPassthru = !Passthru.isUndef();
12680 
12681   APInt SplatVal;
12682   if (ISD::isConstantSplatVector(Mask.getNode(), SplatVal))
12683     return TLI.isConstTrueVal(Mask) ? Vec : Passthru;
12684 
12685   if (Vec.isUndef() || Mask.isUndef())
12686     return Passthru;
12687 
12688   // No need for potentially expensive compress if the mask is constant.
12689   if (ISD::isBuildVectorOfConstantSDNodes(Mask.getNode())) {
12690     SmallVector<SDValue, 16> Ops;
12691     EVT ScalarVT = VecVT.getVectorElementType();
12692     unsigned NumSelected = 0;
12693     unsigned NumElmts = VecVT.getVectorNumElements();
12694     for (unsigned I = 0; I < NumElmts; ++I) {
12695       SDValue MaskI = Mask.getOperand(I);
12696       // We treat undef mask entries as "false".
12697       if (MaskI.isUndef())
12698         continue;
12699 
12700       if (TLI.isConstTrueVal(MaskI)) {
12701         SDValue VecI = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT, Vec,
12702                                    DAG.getVectorIdxConstant(I, DL));
12703         Ops.push_back(VecI);
12704         NumSelected++;
12705       }
12706     }
12707     for (unsigned Rest = NumSelected; Rest < NumElmts; ++Rest) {
12708       SDValue Val =
12709           HasPassthru
12710               ? DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT, Passthru,
12711                             DAG.getVectorIdxConstant(Rest, DL))
12712               : DAG.getUNDEF(ScalarVT);
12713       Ops.push_back(Val);
12714     }
12715     return DAG.getBuildVector(VecVT, DL, Ops);
12716   }
12717 
12718   return SDValue();
12719 }
12720 
visitVPGATHER(SDNode * N)12721 SDValue DAGCombiner::visitVPGATHER(SDNode *N) {
12722   VPGatherSDNode *MGT = cast<VPGatherSDNode>(N);
12723   SDValue Mask = MGT->getMask();
12724   SDValue Chain = MGT->getChain();
12725   SDValue Index = MGT->getIndex();
12726   SDValue Scale = MGT->getScale();
12727   SDValue BasePtr = MGT->getBasePtr();
12728   SDValue VL = MGT->getVectorLength();
12729   ISD::MemIndexType IndexType = MGT->getIndexType();
12730   SDLoc DL(N);
12731 
12732   if (refineUniformBase(BasePtr, Index, MGT->isIndexScaled(), DAG, DL)) {
12733     SDValue Ops[] = {Chain, BasePtr, Index, Scale, Mask, VL};
12734     return DAG.getGatherVP(
12735         DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
12736         Ops, MGT->getMemOperand(), IndexType);
12737   }
12738 
12739   if (refineIndexType(Index, IndexType, N->getValueType(0), DAG)) {
12740     SDValue Ops[] = {Chain, BasePtr, Index, Scale, Mask, VL};
12741     return DAG.getGatherVP(
12742         DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
12743         Ops, MGT->getMemOperand(), IndexType);
12744   }
12745 
12746   return SDValue();
12747 }
12748 
visitMGATHER(SDNode * N)12749 SDValue DAGCombiner::visitMGATHER(SDNode *N) {
12750   MaskedGatherSDNode *MGT = cast<MaskedGatherSDNode>(N);
12751   SDValue Mask = MGT->getMask();
12752   SDValue Chain = MGT->getChain();
12753   SDValue Index = MGT->getIndex();
12754   SDValue Scale = MGT->getScale();
12755   SDValue PassThru = MGT->getPassThru();
12756   SDValue BasePtr = MGT->getBasePtr();
12757   ISD::MemIndexType IndexType = MGT->getIndexType();
12758   SDLoc DL(N);
12759 
12760   // Zap gathers with a zero mask.
12761   if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
12762     return CombineTo(N, PassThru, MGT->getChain());
12763 
12764   if (refineUniformBase(BasePtr, Index, MGT->isIndexScaled(), DAG, DL)) {
12765     SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
12766     return DAG.getMaskedGather(
12767         DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
12768         Ops, MGT->getMemOperand(), IndexType, MGT->getExtensionType());
12769   }
12770 
12771   if (refineIndexType(Index, IndexType, N->getValueType(0), DAG)) {
12772     SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
12773     return DAG.getMaskedGather(
12774         DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
12775         Ops, MGT->getMemOperand(), IndexType, MGT->getExtensionType());
12776   }
12777 
12778   return SDValue();
12779 }
12780 
visitMLOAD(SDNode * N)12781 SDValue DAGCombiner::visitMLOAD(SDNode *N) {
12782   MaskedLoadSDNode *MLD = cast<MaskedLoadSDNode>(N);
12783   SDValue Mask = MLD->getMask();
12784 
12785   // Zap masked loads with a zero mask.
12786   if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
12787     return CombineTo(N, MLD->getPassThru(), MLD->getChain());
12788 
12789   // If this is a masked load with an all ones mask, we can use a unmasked load.
12790   // FIXME: Can we do this for indexed, expanding, or extending loads?
12791   if (ISD::isConstantSplatVectorAllOnes(Mask.getNode()) && MLD->isUnindexed() &&
12792       !MLD->isExpandingLoad() && MLD->getExtensionType() == ISD::NON_EXTLOAD) {
12793     SDValue NewLd = DAG.getLoad(
12794         N->getValueType(0), SDLoc(N), MLD->getChain(), MLD->getBasePtr(),
12795         MLD->getPointerInfo(), MLD->getBaseAlign(),
12796         MLD->getMemOperand()->getFlags(), MLD->getAAInfo(), MLD->getRanges());
12797     return CombineTo(N, NewLd, NewLd.getValue(1));
12798   }
12799 
12800   // Try transforming N to an indexed load.
12801   if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
12802     return SDValue(N, 0);
12803 
12804   return SDValue();
12805 }
12806 
visitMHISTOGRAM(SDNode * N)12807 SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
12808   MaskedHistogramSDNode *HG = cast<MaskedHistogramSDNode>(N);
12809   SDValue Chain = HG->getChain();
12810   SDValue Inc = HG->getInc();
12811   SDValue Mask = HG->getMask();
12812   SDValue BasePtr = HG->getBasePtr();
12813   SDValue Index = HG->getIndex();
12814   SDLoc DL(HG);
12815 
12816   EVT MemVT = HG->getMemoryVT();
12817   MachineMemOperand *MMO = HG->getMemOperand();
12818   ISD::MemIndexType IndexType = HG->getIndexType();
12819 
12820   if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
12821     return Chain;
12822 
12823   SDValue Ops[] = {Chain,          Inc,           Mask, BasePtr, Index,
12824                    HG->getScale(), HG->getIntID()};
12825   if (refineUniformBase(BasePtr, Index, HG->isIndexScaled(), DAG, DL))
12826     return DAG.getMaskedHistogram(DAG.getVTList(MVT::Other), MemVT, DL, Ops,
12827                                   MMO, IndexType);
12828 
12829   EVT DataVT = Index.getValueType();
12830   if (refineIndexType(Index, IndexType, DataVT, DAG))
12831     return DAG.getMaskedHistogram(DAG.getVTList(MVT::Other), MemVT, DL, Ops,
12832                                   MMO, IndexType);
12833   return SDValue();
12834 }
12835 
visitPARTIAL_REDUCE_MLA(SDNode * N)12836 SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
12837   if (SDValue Res = foldPartialReduceMLAMulOp(N))
12838     return Res;
12839   if (SDValue Res = foldPartialReduceAdd(N))
12840     return Res;
12841   return SDValue();
12842 }
12843 
12844 // partial_reduce_*mla(acc, mul(ext(a), ext(b)), splat(1))
12845 // -> partial_reduce_*mla(acc, a, b)
12846 //
12847 // partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
12848 // -> partial_reduce_*mla(acc, x, C)
foldPartialReduceMLAMulOp(SDNode * N)12849 SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
12850   SDLoc DL(N);
12851   auto *Context = DAG.getContext();
12852   SDValue Acc = N->getOperand(0);
12853   SDValue Op1 = N->getOperand(1);
12854   SDValue Op2 = N->getOperand(2);
12855 
12856   APInt C;
12857   if (Op1->getOpcode() != ISD::MUL ||
12858       !ISD::isConstantSplatVector(Op2.getNode(), C) || !C.isOne())
12859     return SDValue();
12860 
12861   SDValue LHS = Op1->getOperand(0);
12862   SDValue RHS = Op1->getOperand(1);
12863   unsigned LHSOpcode = LHS->getOpcode();
12864   if (!ISD::isExtOpcode(LHSOpcode))
12865     return SDValue();
12866 
12867   SDValue LHSExtOp = LHS->getOperand(0);
12868   EVT LHSExtOpVT = LHSExtOp.getValueType();
12869 
12870   // partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
12871   // -> partial_reduce_*mla(acc, x, C)
12872   if (ISD::isConstantSplatVector(RHS.getNode(), C)) {
12873     // TODO: Make use of partial_reduce_sumla here
12874     APInt CTrunc = C.trunc(LHSExtOpVT.getScalarSizeInBits());
12875     unsigned LHSBits = LHS.getValueType().getScalarSizeInBits();
12876     if ((LHSOpcode != ISD::ZERO_EXTEND || CTrunc.zext(LHSBits) != C) &&
12877         (LHSOpcode != ISD::SIGN_EXTEND || CTrunc.sext(LHSBits) != C))
12878       return SDValue();
12879 
12880     unsigned NewOpcode = LHSOpcode == ISD::SIGN_EXTEND
12881                              ? ISD::PARTIAL_REDUCE_SMLA
12882                              : ISD::PARTIAL_REDUCE_UMLA;
12883 
12884     // Only perform these combines if the target supports folding
12885     // the extends into the operation.
12886     if (!TLI.isPartialReduceMLALegalOrCustom(
12887             NewOpcode, TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
12888             TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
12889       return SDValue();
12890 
12891     return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
12892                        DAG.getConstant(CTrunc, DL, LHSExtOpVT));
12893   }
12894 
12895   unsigned RHSOpcode = RHS->getOpcode();
12896   if (!ISD::isExtOpcode(RHSOpcode))
12897     return SDValue();
12898 
12899   SDValue RHSExtOp = RHS->getOperand(0);
12900   if (LHSExtOpVT != RHSExtOp.getValueType())
12901     return SDValue();
12902 
12903   unsigned NewOpc;
12904   if (LHSOpcode == ISD::SIGN_EXTEND && RHSOpcode == ISD::SIGN_EXTEND)
12905     NewOpc = ISD::PARTIAL_REDUCE_SMLA;
12906   else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::ZERO_EXTEND)
12907     NewOpc = ISD::PARTIAL_REDUCE_UMLA;
12908   else if (LHSOpcode == ISD::SIGN_EXTEND && RHSOpcode == ISD::ZERO_EXTEND)
12909     NewOpc = ISD::PARTIAL_REDUCE_SUMLA;
12910   else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::SIGN_EXTEND) {
12911     NewOpc = ISD::PARTIAL_REDUCE_SUMLA;
12912     std::swap(LHSExtOp, RHSExtOp);
12913   } else
12914     return SDValue();
12915   // For a 2-stage extend the signedness of both of the extends must match
12916   // If the mul has the same type, there is no outer extend, and thus we
12917   // can simply use the inner extends to pick the result node.
12918   // TODO: extend to handle nonneg zext as sext
12919   EVT AccElemVT = Acc.getValueType().getVectorElementType();
12920   if (Op1.getValueType().getVectorElementType() != AccElemVT &&
12921       NewOpc != N->getOpcode())
12922     return SDValue();
12923 
12924   // Only perform these combines if the target supports folding
12925   // the extends into the operation.
12926   if (!TLI.isPartialReduceMLALegalOrCustom(
12927           NewOpc, TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
12928           TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
12929     return SDValue();
12930 
12931   return DAG.getNode(NewOpc, DL, N->getValueType(0), Acc, LHSExtOp, RHSExtOp);
12932 }
12933 
12934 // partial.reduce.umla(acc, zext(op), splat(1))
12935 // -> partial.reduce.umla(acc, op, splat(trunc(1)))
12936 // partial.reduce.smla(acc, sext(op), splat(1))
12937 // -> partial.reduce.smla(acc, op, splat(trunc(1)))
12938 // partial.reduce.sumla(acc, sext(op), splat(1))
12939 // -> partial.reduce.smla(acc, op, splat(trunc(1)))
foldPartialReduceAdd(SDNode * N)12940 SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
12941   SDLoc DL(N);
12942   SDValue Acc = N->getOperand(0);
12943   SDValue Op1 = N->getOperand(1);
12944   SDValue Op2 = N->getOperand(2);
12945 
12946   APInt ConstantOne;
12947   if (!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
12948       !ConstantOne.isOne())
12949     return SDValue();
12950 
12951   unsigned Op1Opcode = Op1.getOpcode();
12952   if (!ISD::isExtOpcode(Op1Opcode))
12953     return SDValue();
12954 
12955   bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
12956   bool NodeIsSigned = N->getOpcode() != ISD::PARTIAL_REDUCE_UMLA;
12957   EVT AccElemVT = Acc.getValueType().getVectorElementType();
12958   if (Op1IsSigned != NodeIsSigned &&
12959       Op1.getValueType().getVectorElementType() != AccElemVT)
12960     return SDValue();
12961 
12962   unsigned NewOpcode =
12963       Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
12964 
12965   SDValue UnextOp1 = Op1.getOperand(0);
12966   EVT UnextOp1VT = UnextOp1.getValueType();
12967   auto *Context = DAG.getContext();
12968   if (!TLI.isPartialReduceMLALegalOrCustom(
12969           NewOpcode, TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
12970           TLI.getTypeToTransformTo(*Context, UnextOp1VT)))
12971     return SDValue();
12972 
12973   return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1,
12974                      DAG.getConstant(1, DL, UnextOp1VT));
12975 }
12976 
visitVP_STRIDED_LOAD(SDNode * N)12977 SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {
12978   auto *SLD = cast<VPStridedLoadSDNode>(N);
12979   EVT EltVT = SLD->getValueType(0).getVectorElementType();
12980   // Combine strided loads with unit-stride to a regular VP load.
12981   if (auto *CStride = dyn_cast<ConstantSDNode>(SLD->getStride());
12982       CStride && CStride->getZExtValue() == EltVT.getStoreSize()) {
12983     SDValue NewLd = DAG.getLoadVP(
12984         SLD->getAddressingMode(), SLD->getExtensionType(), SLD->getValueType(0),
12985         SDLoc(N), SLD->getChain(), SLD->getBasePtr(), SLD->getOffset(),
12986         SLD->getMask(), SLD->getVectorLength(), SLD->getMemoryVT(),
12987         SLD->getMemOperand(), SLD->isExpandingLoad());
12988     return CombineTo(N, NewLd, NewLd.getValue(1));
12989   }
12990   return SDValue();
12991 }
12992 
12993 /// A vector select of 2 constant vectors can be simplified to math/logic to
12994 /// avoid a variable select instruction and possibly avoid constant loads.
foldVSelectOfConstants(SDNode * N)12995 SDValue DAGCombiner::foldVSelectOfConstants(SDNode *N) {
12996   SDValue Cond = N->getOperand(0);
12997   SDValue N1 = N->getOperand(1);
12998   SDValue N2 = N->getOperand(2);
12999   EVT VT = N->getValueType(0);
13000   if (!Cond.hasOneUse() || Cond.getScalarValueSizeInBits() != 1 ||
13001       !shouldConvertSelectOfConstantsToMath(Cond, VT, TLI) ||
13002       !ISD::isBuildVectorOfConstantSDNodes(N1.getNode()) ||
13003       !ISD::isBuildVectorOfConstantSDNodes(N2.getNode()))
13004     return SDValue();
13005 
13006   // Check if we can use the condition value to increment/decrement a single
13007   // constant value. This simplifies a select to an add and removes a constant
13008   // load/materialization from the general case.
13009   bool AllAddOne = true;
13010   bool AllSubOne = true;
13011   unsigned Elts = VT.getVectorNumElements();
13012   for (unsigned i = 0; i != Elts; ++i) {
13013     SDValue N1Elt = N1.getOperand(i);
13014     SDValue N2Elt = N2.getOperand(i);
13015     if (N1Elt.isUndef())
13016       continue;
13017     // N2 should not contain undef values since it will be reused in the fold.
13018     if (N2Elt.isUndef() || N1Elt.getValueType() != N2Elt.getValueType()) {
13019       AllAddOne = false;
13020       AllSubOne = false;
13021       break;
13022     }
13023 
13024     const APInt &C1 = N1Elt->getAsAPIntVal();
13025     const APInt &C2 = N2Elt->getAsAPIntVal();
13026     if (C1 != C2 + 1)
13027       AllAddOne = false;
13028     if (C1 != C2 - 1)
13029       AllSubOne = false;
13030   }
13031 
13032   // Further simplifications for the extra-special cases where the constants are
13033   // all 0 or all -1 should be implemented as folds of these patterns.
13034   SDLoc DL(N);
13035   if (AllAddOne || AllSubOne) {
13036     // vselect <N x i1> Cond, C+1, C --> add (zext Cond), C
13037     // vselect <N x i1> Cond, C-1, C --> add (sext Cond), C
13038     auto ExtendOpcode = AllAddOne ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND;
13039     SDValue ExtendedCond = DAG.getNode(ExtendOpcode, DL, VT, Cond);
13040     return DAG.getNode(ISD::ADD, DL, VT, ExtendedCond, N2);
13041   }
13042 
13043   // select Cond, Pow2C, 0 --> (zext Cond) << log2(Pow2C)
13044   APInt Pow2C;
13045   if (ISD::isConstantSplatVector(N1.getNode(), Pow2C) && Pow2C.isPowerOf2() &&
13046       isNullOrNullSplat(N2)) {
13047     SDValue ZextCond = DAG.getZExtOrTrunc(Cond, DL, VT);
13048     SDValue ShAmtC = DAG.getConstant(Pow2C.exactLogBase2(), DL, VT);
13049     return DAG.getNode(ISD::SHL, DL, VT, ZextCond, ShAmtC);
13050   }
13051 
13052   if (SDValue V = foldSelectOfConstantsUsingSra(N, DL, DAG))
13053     return V;
13054 
13055   // The general case for select-of-constants:
13056   // vselect <N x i1> Cond, C1, C2 --> xor (and (sext Cond), (C1^C2)), C2
13057   // ...but that only makes sense if a vselect is slower than 2 logic ops, so
13058   // leave that to a machine-specific pass.
13059   return SDValue();
13060 }
13061 
visitVP_SELECT(SDNode * N)13062 SDValue DAGCombiner::visitVP_SELECT(SDNode *N) {
13063   SDValue N0 = N->getOperand(0);
13064   SDValue N1 = N->getOperand(1);
13065   SDValue N2 = N->getOperand(2);
13066   SDLoc DL(N);
13067 
13068   if (SDValue V = DAG.simplifySelect(N0, N1, N2))
13069     return V;
13070 
13071   if (SDValue V = foldBoolSelectToLogic<VPMatchContext>(N, DL, DAG))
13072     return V;
13073 
13074   return SDValue();
13075 }
13076 
/// Try to turn a vector select whose true and/or false operand is a constant
/// splat of all-zeros or all-ones into bitwise logic on the condition.
///
/// \p Cond is the vector condition; \p TVal / \p FVal are the values selected
/// where a lane of Cond is true / false.
///
/// The logic folds require all of the following:
///   - VT is legal;
///   - the condition elements are as wide as the select elements;
///   - every condition lane is a sign splat (all zeros or all ones), so Cond
///     can be used directly as a bit mask.
///
/// Returns the replacement value, or an empty SDValue if no fold applies.
static SDValue combineVSelectWithAllOnesOrZeros(SDValue Cond, SDValue TVal,
                                                SDValue FVal,
                                                const TargetLowering &TLI,
                                                SelectionDAG &DAG,
                                                const SDLoc &DL) {
  EVT VT = TVal.getValueType();
  if (!TLI.isTypeLegal(VT))
    return SDValue();

  EVT CondVT = Cond.getValueType();
  assert(CondVT.isVector() && "Vector select expects a vector selector!");

  bool IsTAllZero = ISD::isConstantSplatVectorAllZeros(TVal.getNode());
  bool IsTAllOne = ISD::isConstantSplatVectorAllOnes(TVal.getNode());
  bool IsFAllZero = ISD::isConstantSplatVectorAllZeros(FVal.getNode());
  bool IsFAllOne = ISD::isConstantSplatVectorAllOnes(FVal.getNode());

  // Neither operand is a 0 / -1 splat (no vselect(cond, 0/-1, X) or
  // vselect(cond, X, 0/-1)): nothing to do.
  if (!IsTAllZero && !IsTAllOne && !IsFAllZero && !IsFAllOne)
    return SDValue();

  // select Cond, 0, 0 -> 0 (an FP zero for floating-point vectors).
  if (IsTAllZero && IsFAllZero) {
    return VT.isFloatingPoint() ? DAG.getConstantFP(0.0, DL, VT)
                                : DAG.getConstant(0, DL, VT);
  }

  // Deliberately leave select(setgt lhs, -1), 1, -1 alone. That pattern is
  // expressible as "or (sra lhs, bitwidth - 1), 1", but this function bails
  // out instead of folding it.
  // NOTE(review): presumably so another combine / target lowering can
  // produce the sra+or form; confirm.
  APInt TValAPInt;
  if (Cond.getOpcode() == ISD::SETCC &&
      Cond.getOperand(2) == DAG.getCondCode(ISD::SETGT) &&
      Cond.getOperand(0).getValueType() == VT && VT.isSimple() &&
      ISD::isConstantSplatVector(TVal.getNode(), TValAPInt) &&
      TValAPInt.isOne() &&
      ISD::isConstantSplatVectorAllOnes(Cond.getOperand(1).getNode()) &&
      ISD::isConstantSplatVectorAllOnes(FVal.getNode())) {
    return SDValue();
  }

  // To use the condition operand as a bitwise mask, it must have elements that
  // are the same size as the select elements. i.e, the condition operand must
  // have already been promoted from the IR select condition type <N x i1>.
  // Don't check if the types themselves are equal because that excludes
  // vector floating-point selects.
  if (CondVT.getScalarSizeInBits() != VT.getScalarSizeInBits())
    return SDValue();

  // Cond value must be 'sign splat' to be converted to a logical op: every
  // bit of each lane equals its sign bit, i.e. each lane is 0 or -1.
  if (DAG.ComputeNumSignBits(Cond) != CondVT.getScalarSizeInBits())
    return SDValue();

  // Try inverting Cond and swapping T/F if it gives all-ones/all-zeros form.
  // This only applies when neither direct form (true = -1 or false = 0) is
  // available, Cond is a single-use SETCC, and its type is the target's setcc
  // result type for VT. In that case "T = 0" becomes "F = 0" and "F = -1"
  // becomes "T = -1".
  if (!IsTAllOne && !IsFAllZero && Cond.hasOneUse() &&
      Cond.getOpcode() == ISD::SETCC &&
      TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT) ==
          CondVT) {
    if (IsTAllZero || IsFAllOne) {
      SDValue CC = Cond.getOperand(2);
      ISD::CondCode InverseCC = ISD::getSetCCInverse(
          cast<CondCodeSDNode>(CC)->get(), Cond.getOperand(0).getValueType());
      Cond = DAG.getSetCC(DL, CondVT, Cond.getOperand(0), Cond.getOperand(1),
                          InverseCC);
      std::swap(TVal, FVal);
      std::swap(IsTAllOne, IsFAllOne);
      std::swap(IsTAllZero, IsFAllZero);
    }
  }

  assert(DAG.ComputeNumSignBits(Cond) == CondVT.getScalarSizeInBits() &&
         "Select condition no longer all-sign bits");

  // select Cond, -1, 0 -> bitcast Cond
  if (IsTAllOne && IsFAllZero)
    return DAG.getBitcast(VT, Cond);

  // select Cond, -1, x -> or Cond, x (computed in CondVT, bitcast back to VT)
  if (IsTAllOne) {
    SDValue X = DAG.getBitcast(CondVT, FVal);
    SDValue Or = DAG.getNode(ISD::OR, DL, CondVT, Cond, X);
    return DAG.getBitcast(VT, Or);
  }

  // select Cond, x, 0 -> and Cond, x (computed in CondVT, bitcast back to VT)
  if (IsFAllZero) {
    SDValue X = DAG.getBitcast(CondVT, TVal);
    SDValue And = DAG.getNode(ISD::AND, DL, CondVT, Cond, X);
    return DAG.getBitcast(VT, And);
  }

  // select Cond, 0, x -> and not(Cond), x
  // Only done when the NOT is free: Cond is already a bitwise NOT (through
  // bitcasts), or the target reports an and-not instruction for Cond.
  if (IsTAllZero &&
      (isBitwiseNot(peekThroughBitcasts(Cond)) || TLI.hasAndNot(Cond))) {
    SDValue X = DAG.getBitcast(CondVT, FVal);
    SDValue And =
        DAG.getNode(ISD::AND, DL, CondVT, DAG.getNOT(DL, Cond, CondVT), X);
    return DAG.getBitcast(VT, And);
  }

  return SDValue();
}
13177 
/// DAG combine for ISD::VSELECT. N0 is the vector condition, N1 the value
/// selected where a condition lane is true, and N2 the value selected where
/// it is false.
///
/// Folds are attempted in order:
///   - generic select simplification and boolean-select-to-logic;
///   - condition-flip canonicalization;
///   - select-of-add to masked add;
///   - integer abs canonicalization;
///   - fmin/fmax and fp-to-int saturation idioms;
///   - setcc widening and ABD;
///   - uaddsat/usubsat matching;
///   - constant-condition and concat folds;
///   - select-of-constants, sign-bit-splat-mask, demanded-elements
///     simplification;
///   - finally, all-ones/all-zeros logic folds.
///
/// Returns the replacement value, SDValue(N, 0) when SimplifySelectOps or
/// SimplifyDemandedVectorElts reported a change, or an empty SDValue.
SDValue DAGCombiner::visitVSELECT(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDValue N2 = N->getOperand(2);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);

  if (SDValue V = DAG.simplifySelect(N0, N1, N2))
    return V;

  if (SDValue V = foldBoolSelectToLogic<EmptyMatchContext>(N, DL, DAG))
    return V;

  // vselect (not Cond), N1, N2 -> vselect Cond, N2, N1
  // Skipped when the target considers N already canonical.
  if (!TLI.isTargetCanonicalSelect(N))
    if (SDValue F = extractBooleanFlip(N0, DAG, TLI, false))
      return DAG.getSelect(DL, VT, F, N2, N1);

  // select (sext m), (add X, C), X --> (add X, (and C, (sext m))))
  // Requires the condition lanes to be 0 / -1 (ZeroOrNegativeOne boolean
  // contents) and as wide as the add's elements, so the AND masks C per lane.
  if (N1.getOpcode() == ISD::ADD && N1.getOperand(0) == N2 && N1->hasOneUse() &&
      DAG.isConstantIntBuildVectorOrConstantInt(N1.getOperand(1)) &&
      N0.getScalarValueSizeInBits() == N1.getScalarValueSizeInBits() &&
      TLI.getBooleanContents(N0.getValueType()) ==
          TargetLowering::ZeroOrNegativeOneBooleanContent) {
    return DAG.getNode(
        ISD::ADD, DL, N1.getValueType(), N2,
        DAG.getNode(ISD::AND, DL, N0.getValueType(), N1.getOperand(1), N0));
  }

  // Canonicalize integer abs.
  // vselect (setg[te] X,  0),  X, -X ->
  // vselect (setgt    X, -1),  X, -X ->
  // vselect (setl[te] X,  0), -X,  X ->
  // Y = sra (X, size(X)-1); xor (add (X, Y), Y)
  if (N0.getOpcode() == ISD::SETCC) {
    SDValue LHS = N0.getOperand(0), RHS = N0.getOperand(1);
    ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
    bool isAbs = false;
    bool RHSIsAllZeros = ISD::isBuildVectorAllZeros(RHS.getNode());

    if (((RHSIsAllZeros && (CC == ISD::SETGT || CC == ISD::SETGE)) ||
         (ISD::isBuildVectorAllOnes(RHS.getNode()) && CC == ISD::SETGT)) &&
        N1 == LHS && N2.getOpcode() == ISD::SUB && N1 == N2.getOperand(1))
      isAbs = ISD::isBuildVectorAllZeros(N2.getOperand(0).getNode());
    else if ((RHSIsAllZeros && (CC == ISD::SETLT || CC == ISD::SETLE)) &&
             N2 == LHS && N1.getOpcode() == ISD::SUB && N2 == N1.getOperand(1))
      isAbs = ISD::isBuildVectorAllZeros(N1.getOperand(0).getNode());

    if (isAbs) {
      // Prefer a native ABS node; otherwise expand to the sra/add/xor idiom.
      if (TLI.isOperationLegalOrCustom(ISD::ABS, VT))
        return DAG.getNode(ISD::ABS, DL, VT, LHS);

      SDValue Shift = DAG.getNode(
          ISD::SRA, DL, VT, LHS,
          DAG.getShiftAmountConstant(VT.getScalarSizeInBits() - 1, VT, DL));
      SDValue Add = DAG.getNode(ISD::ADD, DL, VT, LHS, Shift);
      AddToWorklist(Shift.getNode());
      AddToWorklist(Add.getNode());
      return DAG.getNode(ISD::XOR, DL, VT, Add, Shift);
    }

    // vselect x, y (fcmp lt x, y) -> fminnum x, y
    // vselect x, y (fcmp gt x, y) -> fmaxnum x, y
    //
    // This is OK if we don't care about what happens if either operand is a
    // NaN.
    //
    if (N0.hasOneUse() &&
        isLegalToCombineMinNumMaxNum(DAG, LHS, RHS, N->getFlags(), TLI)) {
      if (SDValue FMinMax = combineMinNumMaxNum(DL, VT, LHS, RHS, N1, N2, CC))
        return FMinMax;
    }

    if (SDValue S = PerformMinMaxFpToSatCombine(LHS, RHS, N1, N2, CC, DAG))
      return S;
    if (SDValue S = PerformUMinFpToSatCombine(LHS, RHS, N1, N2, CC, DAG))
      return S;

    // If this select has a condition (setcc) with narrower operands than the
    // select, try to widen the compare to match the select width.
    // TODO: This should be extended to handle any constant.
    // TODO: This could be extended to handle non-loading patterns, but that
    //       requires thorough testing to avoid regressions.
    if (isNullOrNullSplat(RHS)) {
      EVT NarrowVT = LHS.getValueType();
      EVT WideVT = N1.getValueType().changeVectorElementTypeToInteger();
      EVT SetCCVT = getSetCCResultType(LHS.getValueType());
      unsigned SetCCWidth = SetCCVT.getScalarSizeInBits();
      unsigned WideWidth = WideVT.getScalarSizeInBits();
      bool IsSigned = isSignedIntSetCC(CC);
      auto LoadExtOpcode = IsSigned ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
      if (LHS.getOpcode() == ISD::LOAD && LHS.hasOneUse() &&
          SetCCWidth != 1 && SetCCWidth < WideWidth &&
          TLI.isLoadExtLegalOrCustom(LoadExtOpcode, WideVT, NarrowVT) &&
          TLI.isOperationLegalOrCustom(ISD::SETCC, WideVT)) {
        // Both compare operands can be widened for free. The LHS can use an
        // extended load, and the RHS is a constant:
        //   vselect (ext (setcc load(X), C)), N1, N2 -->
        //   vselect (setcc extload(X), C'), N1, N2
        auto ExtOpcode = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
        SDValue WideLHS = DAG.getNode(ExtOpcode, DL, WideVT, LHS);
        SDValue WideRHS = DAG.getNode(ExtOpcode, DL, WideVT, RHS);
        EVT WideSetCCVT = getSetCCResultType(WideVT);
        SDValue WideSetCC = DAG.getSetCC(DL, WideSetCCVT, WideLHS, WideRHS, CC);
        return DAG.getSelect(DL, N1.getValueType(), WideSetCC, N1, N2);
      }
    }

    if (SDValue ABD = foldSelectToABD(LHS, RHS, N1, N2, CC, DL))
      return ABD;

    // Match VSELECTs into add with unsigned saturation.
    if (hasOperation(ISD::UADDSAT, VT)) {
      // Check if one of the arms of the VSELECT is vector with all bits set.
      // If it's on the left side invert the predicate to simplify logic below.
      SDValue Other;
      ISD::CondCode SatCC = CC;
      if (ISD::isConstantSplatVectorAllOnes(N1.getNode())) {
        Other = N2;
        SatCC = ISD::getSetCCInverse(SatCC, VT.getScalarType());
      } else if (ISD::isConstantSplatVectorAllOnes(N2.getNode())) {
        Other = N1;
      }

      // From here on the select reads "SatCC ? Other : ~0".
      if (Other && Other.getOpcode() == ISD::ADD) {
        SDValue CondLHS = LHS, CondRHS = RHS;
        SDValue OpLHS = Other.getOperand(0), OpRHS = Other.getOperand(1);

        // Canonicalize condition operands.
        if (SatCC == ISD::SETUGE) {
          std::swap(CondLHS, CondRHS);
          SatCC = ISD::SETULE;
        }

        // We can test against either of the addition operands.
        // x <= x+y ? x+y : ~0 --> uaddsat x, y
        // x+y >= x ? x+y : ~0 --> uaddsat x, y
        if (SatCC == ISD::SETULE && Other == CondRHS &&
            (OpLHS == CondLHS || OpRHS == CondLHS))
          return DAG.getNode(ISD::UADDSAT, DL, VT, OpLHS, OpRHS);

        if (OpRHS.getOpcode() == CondRHS.getOpcode() &&
            (OpRHS.getOpcode() == ISD::BUILD_VECTOR ||
             OpRHS.getOpcode() == ISD::SPLAT_VECTOR) &&
            CondLHS == OpLHS) {
          // If the RHS is a constant we have to reverse the const
          // canonicalization.
          // x >= ~C ? x+C : ~0 --> uaddsat x, C
          auto MatchUADDSAT = [](ConstantSDNode *Op, ConstantSDNode *Cond) {
            return Cond->getAPIntValue() == ~Op->getAPIntValue();
          };
          if (SatCC == ISD::SETULE &&
              ISD::matchBinaryPredicate(OpRHS, CondRHS, MatchUADDSAT))
            return DAG.getNode(ISD::UADDSAT, DL, VT, OpLHS, OpRHS);
        }
      }
    }

    // Match VSELECTs into sub with unsigned saturation.
    if (hasOperation(ISD::USUBSAT, VT)) {
      // Check if one of the arms of the VSELECT is a zero vector. If it's on
      // the left side invert the predicate to simplify logic below.
      SDValue Other;
      ISD::CondCode SatCC = CC;
      if (ISD::isConstantSplatVectorAllZeros(N1.getNode())) {
        Other = N2;
        SatCC = ISD::getSetCCInverse(SatCC, VT.getScalarType());
      } else if (ISD::isConstantSplatVectorAllZeros(N2.getNode())) {
        Other = N1;
      }

      // From here on the select reads "SatCC ? Other : 0".

      // zext(x) >= y ? trunc(zext(x) - y) : 0
      // --> usubsat(trunc(zext(x)),trunc(umin(y,SatLimit)))
      // zext(x) >  y ? trunc(zext(x) - y) : 0
      // --> usubsat(trunc(zext(x)),trunc(umin(y,SatLimit)))
      if (Other && Other.getOpcode() == ISD::TRUNCATE &&
          Other.getOperand(0).getOpcode() == ISD::SUB &&
          (SatCC == ISD::SETUGE || SatCC == ISD::SETUGT)) {
        SDValue OpLHS = Other.getOperand(0).getOperand(0);
        SDValue OpRHS = Other.getOperand(0).getOperand(1);
        if (LHS == OpLHS && RHS == OpRHS && LHS.getOpcode() == ISD::ZERO_EXTEND)
          if (SDValue R = getTruncatedUSUBSAT(VT, LHS.getValueType(), LHS, RHS,
                                              DAG, DL))
            return R;
      }

      if (Other && Other.getNumOperands() == 2) {
        SDValue CondRHS = RHS;
        SDValue OpLHS = Other.getOperand(0), OpRHS = Other.getOperand(1);

        if (OpLHS == LHS) {
          // Look for a general sub with unsigned saturation first.
          // x >= y ? x-y : 0 --> usubsat x, y
          // x >  y ? x-y : 0 --> usubsat x, y
          if ((SatCC == ISD::SETUGE || SatCC == ISD::SETUGT) &&
              Other.getOpcode() == ISD::SUB && OpRHS == CondRHS)
            return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);

          if (OpRHS.getOpcode() == ISD::BUILD_VECTOR ||
              OpRHS.getOpcode() == ISD::SPLAT_VECTOR) {
            if (CondRHS.getOpcode() == ISD::BUILD_VECTOR ||
                CondRHS.getOpcode() == ISD::SPLAT_VECTOR) {
              // If the RHS is a constant we have to reverse the const
              // canonicalization.
              // x > C-1 ? x+-C : 0 --> usubsat x, C
              // Lanes where both constants are undef are accepted.
              auto MatchUSUBSAT = [](ConstantSDNode *Op, ConstantSDNode *Cond) {
                return (!Op && !Cond) ||
                       (Op && Cond &&
                        Cond->getAPIntValue() == (-Op->getAPIntValue() - 1));
              };
              if (SatCC == ISD::SETUGT && Other.getOpcode() == ISD::ADD &&
                  ISD::matchBinaryPredicate(OpRHS, CondRHS, MatchUSUBSAT,
                                            /*AllowUndefs*/ true)) {
                OpRHS = DAG.getNegative(OpRHS, DL, VT);
                return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
              }

              // Another special case: If C was a sign bit, the sub has been
              // canonicalized into a xor.
              // FIXME: Would it be better to use computeKnownBits to
              // determine whether it's safe to decanonicalize the xor?
              // x s< 0 ? x^C : 0 --> usubsat x, C
              APInt SplatValue;
              if (SatCC == ISD::SETLT && Other.getOpcode() == ISD::XOR &&
                  ISD::isConstantSplatVector(OpRHS.getNode(), SplatValue) &&
                  ISD::isConstantSplatVectorAllZeros(CondRHS.getNode()) &&
                  SplatValue.isSignMask()) {
                // Note that we have to rebuild the RHS constant here to
                // ensure we don't rely on particular values of undef lanes.
                OpRHS = DAG.getConstant(SplatValue, DL, VT);
                return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
              }
            }
          }
        }
      }
    }
  }

  if (SimplifySelectOps(N, N1, N2))
    return SDValue(N, 0);  // Don't revisit N.

  // Fold (vselect all_ones, N1, N2) -> N1
  if (ISD::isConstantSplatVectorAllOnes(N0.getNode()))
    return N1;
  // Fold (vselect all_zeros, N1, N2) -> N2
  if (ISD::isConstantSplatVectorAllZeros(N0.getNode()))
    return N2;

  // The ConvertSelectToConcatVector function is assuming both the above
  // checks for (vselect (build_vector all{ones,zeros}) ...) have been made
  // and addressed.
  if (N1.getOpcode() == ISD::CONCAT_VECTORS &&
      N2.getOpcode() == ISD::CONCAT_VECTORS &&
      ISD::isBuildVectorOfConstantSDNodes(N0.getNode())) {
    if (SDValue CV = ConvertSelectToConcatVector(N, DAG))
      return CV;
  }

  if (SDValue V = foldVSelectOfConstants(N))
    return V;

  if (hasOperation(ISD::SRA, VT))
    if (SDValue V = foldVSelectToSignBitSplatMask(N, DAG))
      return V;

  if (SimplifyDemandedVectorElts(SDValue(N, 0)))
    return SDValue(N, 0);

  if (SDValue V = combineVSelectWithAllOnesOrZeros(N0, N1, N2, TLI, DAG, DL))
    return V;

  return SDValue();
}
13452 
visitSELECT_CC(SDNode * N)13453 SDValue DAGCombiner::visitSELECT_CC(SDNode *N) {
13454   SDValue N0 = N->getOperand(0);
13455   SDValue N1 = N->getOperand(1);
13456   SDValue N2 = N->getOperand(2);
13457   SDValue N3 = N->getOperand(3);
13458   SDValue N4 = N->getOperand(4);
13459   ISD::CondCode CC = cast<CondCodeSDNode>(N4)->get();
13460   SDLoc DL(N);
13461 
13462   // fold select_cc lhs, rhs, x, x, cc -> x
13463   if (N2 == N3)
13464     return N2;
13465 
13466   // select_cc bool, 0, x, y, seteq -> select bool, y, x
13467   if (CC == ISD::SETEQ && !LegalTypes && N0.getValueType() == MVT::i1 &&
13468       isNullConstant(N1))
13469     return DAG.getSelect(DL, N2.getValueType(), N0, N3, N2);
13470 
13471   // Determine if the condition we're dealing with is constant
13472   if (SDValue SCC = SimplifySetCC(getSetCCResultType(N0.getValueType()), N0, N1,
13473                                   CC, DL, false)) {
13474     AddToWorklist(SCC.getNode());
13475 
13476     // cond always true -> true val
13477     // cond always false -> false val
13478     if (auto *SCCC = dyn_cast<ConstantSDNode>(SCC.getNode()))
13479       return SCCC->isZero() ? N3 : N2;
13480 
13481     // When the condition is UNDEF, just return the first operand. This is
13482     // coherent the DAG creation, no setcc node is created in this case
13483     if (SCC->isUndef())
13484       return N2;
13485 
13486     // Fold to a simpler select_cc
13487     if (SCC.getOpcode() == ISD::SETCC) {
13488       return DAG.getNode(ISD::SELECT_CC, DL, N2.getValueType(),
13489                          SCC.getOperand(0), SCC.getOperand(1), N2, N3,
13490                          SCC.getOperand(2), SCC->getFlags());
13491     }
13492   }
13493 
13494   // If we can fold this based on the true/false value, do so.
13495   if (SimplifySelectOps(N, N2, N3))
13496     return SDValue(N, 0); // Don't revisit N.
13497 
13498   // fold select_cc into other things, such as min/max/abs
13499   return SimplifySelectCC(DL, N0, N1, N2, N3, CC);
13500 }
13501 
visitSETCC(SDNode * N)13502 SDValue DAGCombiner::visitSETCC(SDNode *N) {
13503   // setcc is very commonly used as an argument to brcond. This pattern
13504   // also lend itself to numerous combines and, as a result, it is desired
13505   // we keep the argument to a brcond as a setcc as much as possible.
13506   bool PreferSetCC =
13507       N->hasOneUse() && N->user_begin()->getOpcode() == ISD::BRCOND;
13508 
13509   ISD::CondCode Cond = cast<CondCodeSDNode>(N->getOperand(2))->get();
13510   EVT VT = N->getValueType(0);
13511   SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
13512   SDLoc DL(N);
13513 
13514   if (SDValue Combined = SimplifySetCC(VT, N0, N1, Cond, DL, !PreferSetCC)) {
13515     // If we prefer to have a setcc, and we don't, we'll try our best to
13516     // recreate one using rebuildSetCC.
13517     if (PreferSetCC && Combined.getOpcode() != ISD::SETCC) {
13518       SDValue NewSetCC = rebuildSetCC(Combined);
13519 
13520       // We don't have anything interesting to combine to.
13521       if (NewSetCC.getNode() == N)
13522         return SDValue();
13523 
13524       if (NewSetCC)
13525         return NewSetCC;
13526     }
13527     return Combined;
13528   }
13529 
13530   // Optimize
13531   //    1) (icmp eq/ne (and X, C0), (shift X, C1))
13532   // or
13533   //    2) (icmp eq/ne X, (rotate X, C1))
13534   // If C0 is a mask or shifted mask and the shift amt (C1) isolates the
13535   // remaining bits (i.e something like `(x64 & UINT32_MAX) == (x64 >> 32)`)
13536   // Then:
13537   // If C1 is a power of 2, then the rotate and shift+and versions are
13538   // equivilent, so we can interchange them depending on target preference.
13539   // Otherwise, if we have the shift+and version we can interchange srl/shl
13540   // which inturn affects the constant C0. We can use this to get better
13541   // constants again determined by target preference.
13542   if (Cond == ISD::SETNE || Cond == ISD::SETEQ) {
13543     auto IsAndWithShift = [](SDValue A, SDValue B) {
13544       return A.getOpcode() == ISD::AND &&
13545              (B.getOpcode() == ISD::SRL || B.getOpcode() == ISD::SHL) &&
13546              A.getOperand(0) == B.getOperand(0);
13547     };
13548     auto IsRotateWithOp = [](SDValue A, SDValue B) {
13549       return (B.getOpcode() == ISD::ROTL || B.getOpcode() == ISD::ROTR) &&
13550              B.getOperand(0) == A;
13551     };
13552     SDValue AndOrOp = SDValue(), ShiftOrRotate = SDValue();
13553     bool IsRotate = false;
13554 
13555     // Find either shift+and or rotate pattern.
13556     if (IsAndWithShift(N0, N1)) {
13557       AndOrOp = N0;
13558       ShiftOrRotate = N1;
13559     } else if (IsAndWithShift(N1, N0)) {
13560       AndOrOp = N1;
13561       ShiftOrRotate = N0;
13562     } else if (IsRotateWithOp(N0, N1)) {
13563       IsRotate = true;
13564       AndOrOp = N0;
13565       ShiftOrRotate = N1;
13566     } else if (IsRotateWithOp(N1, N0)) {
13567       IsRotate = true;
13568       AndOrOp = N1;
13569       ShiftOrRotate = N0;
13570     }
13571 
13572     if (AndOrOp && ShiftOrRotate && ShiftOrRotate.hasOneUse() &&
13573         (IsRotate || AndOrOp.hasOneUse())) {
13574       EVT OpVT = N0.getValueType();
13575       // Get constant shift/rotate amount and possibly mask (if its shift+and
13576       // variant).
13577       auto GetAPIntValue = [](SDValue Op) -> std::optional<APInt> {
13578         ConstantSDNode *CNode = isConstOrConstSplat(Op, /*AllowUndefs*/ false,
13579                                                     /*AllowTrunc*/ false);
13580         if (CNode == nullptr)
13581           return std::nullopt;
13582         return CNode->getAPIntValue();
13583       };
13584       std::optional<APInt> AndCMask =
13585           IsRotate ? std::nullopt : GetAPIntValue(AndOrOp.getOperand(1));
13586       std::optional<APInt> ShiftCAmt =
13587           GetAPIntValue(ShiftOrRotate.getOperand(1));
13588       unsigned NumBits = OpVT.getScalarSizeInBits();
13589 
13590       // We found constants.
13591       if (ShiftCAmt && (IsRotate || AndCMask) && ShiftCAmt->ult(NumBits)) {
13592         unsigned ShiftOpc = ShiftOrRotate.getOpcode();
13593         // Check that the constants meet the constraints.
13594         bool CanTransform = IsRotate;
13595         if (!CanTransform) {
13596           // Check that mask and shift compliment eachother
13597           CanTransform = *ShiftCAmt == (~*AndCMask).popcount();
13598           // Check that we are comparing all bits
13599           CanTransform &= (*ShiftCAmt + AndCMask->popcount()) == NumBits;
13600           // Check that the and mask is correct for the shift
13601           CanTransform &=
13602               ShiftOpc == ISD::SHL ? (~*AndCMask).isMask() : AndCMask->isMask();
13603         }
13604 
13605         // See if target prefers another shift/rotate opcode.
13606         unsigned NewShiftOpc = TLI.preferedOpcodeForCmpEqPiecesOfOperand(
13607             OpVT, ShiftOpc, ShiftCAmt->isPowerOf2(), *ShiftCAmt, AndCMask);
13608         // Transform is valid and we have a new preference.
13609         if (CanTransform && NewShiftOpc != ShiftOpc) {
13610           SDValue NewShiftOrRotate =
13611               DAG.getNode(NewShiftOpc, DL, OpVT, ShiftOrRotate.getOperand(0),
13612                           ShiftOrRotate.getOperand(1));
13613           SDValue NewAndOrOp = SDValue();
13614 
13615           if (NewShiftOpc == ISD::SHL || NewShiftOpc == ISD::SRL) {
13616             APInt NewMask =
13617                 NewShiftOpc == ISD::SHL
13618                     ? APInt::getHighBitsSet(NumBits,
13619                                             NumBits - ShiftCAmt->getZExtValue())
13620                     : APInt::getLowBitsSet(NumBits,
13621                                            NumBits - ShiftCAmt->getZExtValue());
13622             NewAndOrOp =
13623                 DAG.getNode(ISD::AND, DL, OpVT, ShiftOrRotate.getOperand(0),
13624                             DAG.getConstant(NewMask, DL, OpVT));
13625           } else {
13626             NewAndOrOp = ShiftOrRotate.getOperand(0);
13627           }
13628 
13629           return DAG.getSetCC(DL, VT, NewAndOrOp, NewShiftOrRotate, Cond);
13630         }
13631       }
13632     }
13633   }
13634   return SDValue();
13635 }
13636 
visitSETCCCARRY(SDNode * N)13637 SDValue DAGCombiner::visitSETCCCARRY(SDNode *N) {
13638   SDValue LHS = N->getOperand(0);
13639   SDValue RHS = N->getOperand(1);
13640   SDValue Carry = N->getOperand(2);
13641   SDValue Cond = N->getOperand(3);
13642 
13643   // If Carry is false, fold to a regular SETCC.
13644   if (isNullConstant(Carry))
13645     return DAG.getNode(ISD::SETCC, SDLoc(N), N->getVTList(), LHS, RHS, Cond);
13646 
13647   return SDValue();
13648 }
13649 
13650 /// Check if N satisfies:
13651 ///   N is used once.
13652 ///   N is a Load.
13653 ///   The load is compatible with ExtOpcode. It means
13654 ///     If load has explicit zero/sign extension, ExpOpcode must have the same
13655 ///     extension.
13656 ///     Otherwise returns true.
isCompatibleLoad(SDValue N,unsigned ExtOpcode)13657 static bool isCompatibleLoad(SDValue N, unsigned ExtOpcode) {
13658   if (!N.hasOneUse())
13659     return false;
13660 
13661   if (!isa<LoadSDNode>(N))
13662     return false;
13663 
13664   LoadSDNode *Load = cast<LoadSDNode>(N);
13665   ISD::LoadExtType LoadExt = Load->getExtensionType();
13666   if (LoadExt == ISD::NON_EXTLOAD || LoadExt == ISD::EXTLOAD)
13667     return true;
13668 
13669   // Now LoadExt is either SEXTLOAD or ZEXTLOAD, ExtOpcode must have the same
13670   // extension.
13671   if ((LoadExt == ISD::SEXTLOAD && ExtOpcode != ISD::SIGN_EXTEND) ||
13672       (LoadExt == ISD::ZEXTLOAD && ExtOpcode != ISD::ZERO_EXTEND))
13673     return false;
13674 
13675   return true;
13676 }
13677 
13678 /// Fold
13679 ///   (sext (select c, load x, load y)) -> (select c, sextload x, sextload y)
13680 ///   (zext (select c, load x, load y)) -> (select c, zextload x, zextload y)
13681 ///   (aext (select c, load x, load y)) -> (select c, extload x, extload y)
13682 /// This function is called by the DAGCombiner when visiting sext/zext/aext
13683 /// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND).
tryToFoldExtendSelectLoad(SDNode * N,const TargetLowering & TLI,SelectionDAG & DAG,const SDLoc & DL,CombineLevel Level)13684 static SDValue tryToFoldExtendSelectLoad(SDNode *N, const TargetLowering &TLI,
13685                                          SelectionDAG &DAG, const SDLoc &DL,
13686                                          CombineLevel Level) {
13687   unsigned Opcode = N->getOpcode();
13688   SDValue N0 = N->getOperand(0);
13689   EVT VT = N->getValueType(0);
13690   assert((Opcode == ISD::SIGN_EXTEND || Opcode == ISD::ZERO_EXTEND ||
13691           Opcode == ISD::ANY_EXTEND) &&
13692          "Expected EXTEND dag node in input!");
13693 
13694   if (!(N0->getOpcode() == ISD::SELECT || N0->getOpcode() == ISD::VSELECT) ||
13695       !N0.hasOneUse())
13696     return SDValue();
13697 
13698   SDValue Op1 = N0->getOperand(1);
13699   SDValue Op2 = N0->getOperand(2);
13700   if (!isCompatibleLoad(Op1, Opcode) || !isCompatibleLoad(Op2, Opcode))
13701     return SDValue();
13702 
13703   auto ExtLoadOpcode = ISD::EXTLOAD;
13704   if (Opcode == ISD::SIGN_EXTEND)
13705     ExtLoadOpcode = ISD::SEXTLOAD;
13706   else if (Opcode == ISD::ZERO_EXTEND)
13707     ExtLoadOpcode = ISD::ZEXTLOAD;
13708 
13709   // Illegal VSELECT may ISel fail if happen after legalization (DAG
13710   // Combine2), so we should conservatively check the OperationAction.
13711   LoadSDNode *Load1 = cast<LoadSDNode>(Op1);
13712   LoadSDNode *Load2 = cast<LoadSDNode>(Op2);
13713   if (!TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load1->getMemoryVT()) ||
13714       !TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load2->getMemoryVT()) ||
13715       (N0->getOpcode() == ISD::VSELECT && Level >= AfterLegalizeTypes &&
13716        TLI.getOperationAction(ISD::VSELECT, VT) != TargetLowering::Legal))
13717     return SDValue();
13718 
13719   SDValue Ext1 = DAG.getNode(Opcode, DL, VT, Op1);
13720   SDValue Ext2 = DAG.getNode(Opcode, DL, VT, Op2);
13721   return DAG.getSelect(DL, VT, N0->getOperand(0), Ext1, Ext2);
13722 }
13723 
13724 /// Try to fold a sext/zext/aext dag node into a ConstantSDNode or
13725 /// a build_vector of constants.
13726 /// This function is called by the DAGCombiner when visiting sext/zext/aext
13727 /// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND).
13728 /// Vector extends are not folded if operations are legal; this is to
13729 /// avoid introducing illegal build_vector dag nodes.
tryToFoldExtendOfConstant(SDNode * N,const SDLoc & DL,const TargetLowering & TLI,SelectionDAG & DAG,bool LegalTypes)13730 static SDValue tryToFoldExtendOfConstant(SDNode *N, const SDLoc &DL,
13731                                          const TargetLowering &TLI,
13732                                          SelectionDAG &DAG, bool LegalTypes) {
13733   unsigned Opcode = N->getOpcode();
13734   SDValue N0 = N->getOperand(0);
13735   EVT VT = N->getValueType(0);
13736 
13737   assert((ISD::isExtOpcode(Opcode) || ISD::isExtVecInRegOpcode(Opcode)) &&
13738          "Expected EXTEND dag node in input!");
13739 
13740   // fold (sext c1) -> c1
13741   // fold (zext c1) -> c1
13742   // fold (aext c1) -> c1
13743   if (isa<ConstantSDNode>(N0))
13744     return DAG.getNode(Opcode, DL, VT, N0);
13745 
13746   // fold (sext (select cond, c1, c2)) -> (select cond, sext c1, sext c2)
13747   // fold (zext (select cond, c1, c2)) -> (select cond, zext c1, zext c2)
13748   // fold (aext (select cond, c1, c2)) -> (select cond, sext c1, sext c2)
13749   if (N0->getOpcode() == ISD::SELECT) {
13750     SDValue Op1 = N0->getOperand(1);
13751     SDValue Op2 = N0->getOperand(2);
13752     if (isa<ConstantSDNode>(Op1) && isa<ConstantSDNode>(Op2) &&
13753         (Opcode != ISD::ZERO_EXTEND || !TLI.isZExtFree(N0.getValueType(), VT))) {
13754       // For any_extend, choose sign extension of the constants to allow a
13755       // possible further transform to sign_extend_inreg.i.e.
13756       //
13757       // t1: i8 = select t0, Constant:i8<-1>, Constant:i8<0>
13758       // t2: i64 = any_extend t1
13759       // -->
13760       // t3: i64 = select t0, Constant:i64<-1>, Constant:i64<0>
13761       // -->
13762       // t4: i64 = sign_extend_inreg t3
13763       unsigned FoldOpc = Opcode;
13764       if (FoldOpc == ISD::ANY_EXTEND)
13765         FoldOpc = ISD::SIGN_EXTEND;
13766       return DAG.getSelect(DL, VT, N0->getOperand(0),
13767                            DAG.getNode(FoldOpc, DL, VT, Op1),
13768                            DAG.getNode(FoldOpc, DL, VT, Op2));
13769     }
13770   }
13771 
13772   // fold (sext (build_vector AllConstants) -> (build_vector AllConstants)
13773   // fold (zext (build_vector AllConstants) -> (build_vector AllConstants)
13774   // fold (aext (build_vector AllConstants) -> (build_vector AllConstants)
13775   EVT SVT = VT.getScalarType();
13776   if (!(VT.isVector() && (!LegalTypes || TLI.isTypeLegal(SVT)) &&
13777       ISD::isBuildVectorOfConstantSDNodes(N0.getNode())))
13778     return SDValue();
13779 
13780   // We can fold this node into a build_vector.
13781   unsigned VTBits = SVT.getSizeInBits();
13782   unsigned EVTBits = N0->getValueType(0).getScalarSizeInBits();
13783   SmallVector<SDValue, 8> Elts;
13784   unsigned NumElts = VT.getVectorNumElements();
13785 
13786   for (unsigned i = 0; i != NumElts; ++i) {
13787     SDValue Op = N0.getOperand(i);
13788     if (Op.isUndef()) {
13789       if (Opcode == ISD::ANY_EXTEND || Opcode == ISD::ANY_EXTEND_VECTOR_INREG)
13790         Elts.push_back(DAG.getUNDEF(SVT));
13791       else
13792         Elts.push_back(DAG.getConstant(0, DL, SVT));
13793       continue;
13794     }
13795 
13796     SDLoc DL(Op);
13797     // Get the constant value and if needed trunc it to the size of the type.
13798     // Nodes like build_vector might have constants wider than the scalar type.
13799     APInt C = Op->getAsAPIntVal().zextOrTrunc(EVTBits);
13800     if (Opcode == ISD::SIGN_EXTEND || Opcode == ISD::SIGN_EXTEND_VECTOR_INREG)
13801       Elts.push_back(DAG.getConstant(C.sext(VTBits), DL, SVT));
13802     else
13803       Elts.push_back(DAG.getConstant(C.zext(VTBits), DL, SVT));
13804   }
13805 
13806   return DAG.getBuildVector(VT, DL, Elts);
13807 }
13808 
13809 // ExtendUsesToFormExtLoad - Trying to extend uses of a load to enable this:
13810 // "fold ({s|z|a}ext (load x)) -> ({s|z|a}ext (truncate ({s|z|a}extload x)))"
13811 // transformation. Returns true if extension are possible and the above
13812 // mentioned transformation is profitable.
ExtendUsesToFormExtLoad(EVT VT,SDNode * N,SDValue N0,unsigned ExtOpc,SmallVectorImpl<SDNode * > & ExtendNodes,const TargetLowering & TLI)13813 static bool ExtendUsesToFormExtLoad(EVT VT, SDNode *N, SDValue N0,
13814                                     unsigned ExtOpc,
13815                                     SmallVectorImpl<SDNode *> &ExtendNodes,
13816                                     const TargetLowering &TLI) {
13817   bool HasCopyToRegUses = false;
13818   bool isTruncFree = TLI.isTruncateFree(VT, N0.getValueType());
13819   for (SDUse &Use : N0->uses()) {
13820     SDNode *User = Use.getUser();
13821     if (User == N)
13822       continue;
13823     if (Use.getResNo() != N0.getResNo())
13824       continue;
13825     // FIXME: Only extend SETCC N, N and SETCC N, c for now.
13826     if (ExtOpc != ISD::ANY_EXTEND && User->getOpcode() == ISD::SETCC) {
13827       ISD::CondCode CC = cast<CondCodeSDNode>(User->getOperand(2))->get();
13828       if (ExtOpc == ISD::ZERO_EXTEND && ISD::isSignedIntSetCC(CC))
13829         // Sign bits will be lost after a zext.
13830         return false;
13831       bool Add = false;
13832       for (unsigned i = 0; i != 2; ++i) {
13833         SDValue UseOp = User->getOperand(i);
13834         if (UseOp == N0)
13835           continue;
13836         if (!isa<ConstantSDNode>(UseOp))
13837           return false;
13838         Add = true;
13839       }
13840       if (Add)
13841         ExtendNodes.push_back(User);
13842       continue;
13843     }
13844     // If truncates aren't free and there are users we can't
13845     // extend, it isn't worthwhile.
13846     if (!isTruncFree)
13847       return false;
13848     // Remember if this value is live-out.
13849     if (User->getOpcode() == ISD::CopyToReg)
13850       HasCopyToRegUses = true;
13851   }
13852 
13853   if (HasCopyToRegUses) {
13854     bool BothLiveOut = false;
13855     for (SDUse &Use : N->uses()) {
13856       if (Use.getResNo() == 0 && Use.getUser()->getOpcode() == ISD::CopyToReg) {
13857         BothLiveOut = true;
13858         break;
13859       }
13860     }
13861     if (BothLiveOut)
13862       // Both unextended and extended values are live out. There had better be
13863       // a good reason for the transformation.
13864       return !ExtendNodes.empty();
13865   }
13866   return true;
13867 }
13868 
ExtendSetCCUses(const SmallVectorImpl<SDNode * > & SetCCs,SDValue OrigLoad,SDValue ExtLoad,ISD::NodeType ExtType)13869 void DAGCombiner::ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs,
13870                                   SDValue OrigLoad, SDValue ExtLoad,
13871                                   ISD::NodeType ExtType) {
13872   // Extend SetCC uses if necessary.
13873   SDLoc DL(ExtLoad);
13874   for (SDNode *SetCC : SetCCs) {
13875     SmallVector<SDValue, 4> Ops;
13876 
13877     for (unsigned j = 0; j != 2; ++j) {
13878       SDValue SOp = SetCC->getOperand(j);
13879       if (SOp == OrigLoad)
13880         Ops.push_back(ExtLoad);
13881       else
13882         Ops.push_back(DAG.getNode(ExtType, DL, ExtLoad->getValueType(0), SOp));
13883     }
13884 
13885     Ops.push_back(SetCC->getOperand(2));
13886     CombineTo(SetCC, DAG.getNode(ISD::SETCC, DL, SetCC->getValueType(0), Ops));
13887   }
13888 }
13889 
13890 // FIXME: Bring more similar combines here, common to sext/zext (maybe aext?).
CombineExtLoad(SDNode * N)13891 SDValue DAGCombiner::CombineExtLoad(SDNode *N) {
13892   SDValue N0 = N->getOperand(0);
13893   EVT DstVT = N->getValueType(0);
13894   EVT SrcVT = N0.getValueType();
13895 
13896   assert((N->getOpcode() == ISD::SIGN_EXTEND ||
13897           N->getOpcode() == ISD::ZERO_EXTEND) &&
13898          "Unexpected node type (not an extend)!");
13899 
13900   // fold (sext (load x)) to multiple smaller sextloads; same for zext.
13901   // For example, on a target with legal v4i32, but illegal v8i32, turn:
13902   //   (v8i32 (sext (v8i16 (load x))))
13903   // into:
13904   //   (v8i32 (concat_vectors (v4i32 (sextload x)),
13905   //                          (v4i32 (sextload (x + 16)))))
13906   // Where uses of the original load, i.e.:
13907   //   (v8i16 (load x))
13908   // are replaced with:
13909   //   (v8i16 (truncate
13910   //     (v8i32 (concat_vectors (v4i32 (sextload x)),
13911   //                            (v4i32 (sextload (x + 16)))))))
13912   //
13913   // This combine is only applicable to illegal, but splittable, vectors.
13914   // All legal types, and illegal non-vector types, are handled elsewhere.
13915   // This combine is controlled by TargetLowering::isVectorLoadExtDesirable.
13916   //
13917   if (N0->getOpcode() != ISD::LOAD)
13918     return SDValue();
13919 
13920   LoadSDNode *LN0 = cast<LoadSDNode>(N0);
13921 
13922   if (!ISD::isNON_EXTLoad(LN0) || !ISD::isUNINDEXEDLoad(LN0) ||
13923       !N0.hasOneUse() || !LN0->isSimple() ||
13924       !DstVT.isVector() || !DstVT.isPow2VectorType() ||
13925       !TLI.isVectorLoadExtDesirable(SDValue(N, 0)))
13926     return SDValue();
13927 
13928   SmallVector<SDNode *, 4> SetCCs;
13929   if (!ExtendUsesToFormExtLoad(DstVT, N, N0, N->getOpcode(), SetCCs, TLI))
13930     return SDValue();
13931 
13932   ISD::LoadExtType ExtType =
13933       N->getOpcode() == ISD::SIGN_EXTEND ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
13934 
13935   // Try to split the vector types to get down to legal types.
13936   EVT SplitSrcVT = SrcVT;
13937   EVT SplitDstVT = DstVT;
13938   while (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT) &&
13939          SplitSrcVT.getVectorNumElements() > 1) {
13940     SplitDstVT = DAG.GetSplitDestVTs(SplitDstVT).first;
13941     SplitSrcVT = DAG.GetSplitDestVTs(SplitSrcVT).first;
13942   }
13943 
13944   if (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT))
13945     return SDValue();
13946 
13947   assert(!DstVT.isScalableVector() && "Unexpected scalable vector type");
13948 
13949   SDLoc DL(N);
13950   const unsigned NumSplits =
13951       DstVT.getVectorNumElements() / SplitDstVT.getVectorNumElements();
13952   const unsigned Stride = SplitSrcVT.getStoreSize();
13953   SmallVector<SDValue, 4> Loads;
13954   SmallVector<SDValue, 4> Chains;
13955 
13956   SDValue BasePtr = LN0->getBasePtr();
13957   for (unsigned Idx = 0; Idx < NumSplits; Idx++) {
13958     const unsigned Offset = Idx * Stride;
13959 
13960     SDValue SplitLoad =
13961         DAG.getExtLoad(ExtType, SDLoc(LN0), SplitDstVT, LN0->getChain(),
13962                        BasePtr, LN0->getPointerInfo().getWithOffset(Offset),
13963                        SplitSrcVT, LN0->getBaseAlign(),
13964                        LN0->getMemOperand()->getFlags(), LN0->getAAInfo());
13965 
13966     BasePtr = DAG.getMemBasePlusOffset(BasePtr, TypeSize::getFixed(Stride), DL);
13967 
13968     Loads.push_back(SplitLoad.getValue(0));
13969     Chains.push_back(SplitLoad.getValue(1));
13970   }
13971 
13972   SDValue NewChain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Chains);
13973   SDValue NewValue = DAG.getNode(ISD::CONCAT_VECTORS, DL, DstVT, Loads);
13974 
13975   // Simplify TF.
13976   AddToWorklist(NewChain.getNode());
13977 
13978   CombineTo(N, NewValue);
13979 
13980   // Replace uses of the original load (before extension)
13981   // with a truncate of the concatenated sextloaded vectors.
13982   SDValue Trunc =
13983       DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), NewValue);
13984   ExtendSetCCUses(SetCCs, N0, NewValue, (ISD::NodeType)N->getOpcode());
13985   CombineTo(N0.getNode(), Trunc, NewChain);
13986   return SDValue(N, 0); // Return N so it doesn't get rechecked!
13987 }
13988 
13989 // fold (zext (and/or/xor (shl/shr (load x), cst), cst)) ->
13990 //      (and/or/xor (shl/shr (zextload x), (zext cst)), (zext cst))
CombineZExtLogicopShiftLoad(SDNode * N)13991 SDValue DAGCombiner::CombineZExtLogicopShiftLoad(SDNode *N) {
13992   assert(N->getOpcode() == ISD::ZERO_EXTEND);
13993   EVT VT = N->getValueType(0);
13994   EVT OrigVT = N->getOperand(0).getValueType();
13995   if (TLI.isZExtFree(OrigVT, VT))
13996     return SDValue();
13997 
13998   // and/or/xor
13999   SDValue N0 = N->getOperand(0);
14000   if (!ISD::isBitwiseLogicOp(N0.getOpcode()) ||
14001       N0.getOperand(1).getOpcode() != ISD::Constant ||
14002       (LegalOperations && !TLI.isOperationLegal(N0.getOpcode(), VT)))
14003     return SDValue();
14004 
14005   // shl/shr
14006   SDValue N1 = N0->getOperand(0);
14007   if (!(N1.getOpcode() == ISD::SHL || N1.getOpcode() == ISD::SRL) ||
14008       N1.getOperand(1).getOpcode() != ISD::Constant ||
14009       (LegalOperations && !TLI.isOperationLegal(N1.getOpcode(), VT)))
14010     return SDValue();
14011 
14012   // load
14013   if (!isa<LoadSDNode>(N1.getOperand(0)))
14014     return SDValue();
14015   LoadSDNode *Load = cast<LoadSDNode>(N1.getOperand(0));
14016   EVT MemVT = Load->getMemoryVT();
14017   if (!TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT) ||
14018       Load->getExtensionType() == ISD::SEXTLOAD || Load->isIndexed())
14019     return SDValue();
14020 
14021 
14022   // If the shift op is SHL, the logic op must be AND, otherwise the result
14023   // will be wrong.
14024   if (N1.getOpcode() == ISD::SHL && N0.getOpcode() != ISD::AND)
14025     return SDValue();
14026 
14027   if (!N0.hasOneUse() || !N1.hasOneUse())
14028     return SDValue();
14029 
14030   SmallVector<SDNode*, 4> SetCCs;
14031   if (!ExtendUsesToFormExtLoad(VT, N1.getNode(), N1.getOperand(0),
14032                                ISD::ZERO_EXTEND, SetCCs, TLI))
14033     return SDValue();
14034 
14035   // Actually do the transformation.
14036   SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(Load), VT,
14037                                    Load->getChain(), Load->getBasePtr(),
14038                                    Load->getMemoryVT(), Load->getMemOperand());
14039 
14040   SDLoc DL1(N1);
14041   SDValue Shift = DAG.getNode(N1.getOpcode(), DL1, VT, ExtLoad,
14042                               N1.getOperand(1));
14043 
14044   APInt Mask = N0.getConstantOperandAPInt(1).zext(VT.getSizeInBits());
14045   SDLoc DL0(N0);
14046   SDValue And = DAG.getNode(N0.getOpcode(), DL0, VT, Shift,
14047                             DAG.getConstant(Mask, DL0, VT));
14048 
14049   ExtendSetCCUses(SetCCs, N1.getOperand(0), ExtLoad, ISD::ZERO_EXTEND);
14050   CombineTo(N, And);
14051   if (SDValue(Load, 0).hasOneUse()) {
14052     DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 1), ExtLoad.getValue(1));
14053   } else {
14054     SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(Load),
14055                                 Load->getValueType(0), ExtLoad);
14056     CombineTo(Load, Trunc, ExtLoad.getValue(1));
14057   }
14058 
14059   // N0 is dead at this point.
14060   recursivelyDeleteUnusedNodes(N0.getNode());
14061 
14062   return SDValue(N,0); // Return N so it doesn't get rechecked!
14063 }
14064 
14065 /// If we're narrowing or widening the result of a vector select and the final
14066 /// size is the same size as a setcc (compare) feeding the select, then try to
14067 /// apply the cast operation to the select's operands because matching vector
14068 /// sizes for a select condition and other operands should be more efficient.
matchVSelectOpSizesWithSetCC(SDNode * Cast)14069 SDValue DAGCombiner::matchVSelectOpSizesWithSetCC(SDNode *Cast) {
14070   unsigned CastOpcode = Cast->getOpcode();
14071   assert((CastOpcode == ISD::SIGN_EXTEND || CastOpcode == ISD::ZERO_EXTEND ||
14072           CastOpcode == ISD::TRUNCATE || CastOpcode == ISD::FP_EXTEND ||
14073           CastOpcode == ISD::FP_ROUND) &&
14074          "Unexpected opcode for vector select narrowing/widening");
14075 
14076   // We only do this transform before legal ops because the pattern may be
14077   // obfuscated by target-specific operations after legalization. Do not create
14078   // an illegal select op, however, because that may be difficult to lower.
14079   EVT VT = Cast->getValueType(0);
14080   if (LegalOperations || !TLI.isOperationLegalOrCustom(ISD::VSELECT, VT))
14081     return SDValue();
14082 
14083   SDValue VSel = Cast->getOperand(0);
14084   if (VSel.getOpcode() != ISD::VSELECT || !VSel.hasOneUse() ||
14085       VSel.getOperand(0).getOpcode() != ISD::SETCC)
14086     return SDValue();
14087 
14088   // Does the setcc have the same vector size as the casted select?
14089   SDValue SetCC = VSel.getOperand(0);
14090   EVT SetCCVT = getSetCCResultType(SetCC.getOperand(0).getValueType());
14091   if (SetCCVT.getSizeInBits() != VT.getSizeInBits())
14092     return SDValue();
14093 
14094   // cast (vsel (setcc X), A, B) --> vsel (setcc X), (cast A), (cast B)
14095   SDValue A = VSel.getOperand(1);
14096   SDValue B = VSel.getOperand(2);
14097   SDValue CastA, CastB;
14098   SDLoc DL(Cast);
14099   if (CastOpcode == ISD::FP_ROUND) {
14100     // FP_ROUND (fptrunc) has an extra flag operand to pass along.
14101     CastA = DAG.getNode(CastOpcode, DL, VT, A, Cast->getOperand(1));
14102     CastB = DAG.getNode(CastOpcode, DL, VT, B, Cast->getOperand(1));
14103   } else {
14104     CastA = DAG.getNode(CastOpcode, DL, VT, A);
14105     CastB = DAG.getNode(CastOpcode, DL, VT, B);
14106   }
14107   return DAG.getNode(ISD::VSELECT, DL, VT, SetCC, CastA, CastB);
14108 }
14109 
14110 // fold ([s|z]ext ([s|z]extload x)) -> ([s|z]ext (truncate ([s|z]extload x)))
14111 // fold ([s|z]ext (     extload x)) -> ([s|z]ext (truncate ([s|z]extload x)))
tryToFoldExtOfExtload(SelectionDAG & DAG,DAGCombiner & Combiner,const TargetLowering & TLI,EVT VT,bool LegalOperations,SDNode * N,SDValue N0,ISD::LoadExtType ExtLoadType)14112 static SDValue tryToFoldExtOfExtload(SelectionDAG &DAG, DAGCombiner &Combiner,
14113                                      const TargetLowering &TLI, EVT VT,
14114                                      bool LegalOperations, SDNode *N,
14115                                      SDValue N0, ISD::LoadExtType ExtLoadType) {
14116   SDNode *N0Node = N0.getNode();
14117   bool isAExtLoad = (ExtLoadType == ISD::SEXTLOAD) ? ISD::isSEXTLoad(N0Node)
14118                                                    : ISD::isZEXTLoad(N0Node);
14119   if ((!isAExtLoad && !ISD::isEXTLoad(N0Node)) ||
14120       !ISD::isUNINDEXEDLoad(N0Node) || !N0.hasOneUse())
14121     return SDValue();
14122 
14123   LoadSDNode *LN0 = cast<LoadSDNode>(N0);
14124   EVT MemVT = LN0->getMemoryVT();
14125   if ((LegalOperations || !LN0->isSimple() ||
14126        VT.isVector()) &&
14127       !TLI.isLoadExtLegal(ExtLoadType, VT, MemVT))
14128     return SDValue();
14129 
14130   SDValue ExtLoad =
14131       DAG.getExtLoad(ExtLoadType, SDLoc(LN0), VT, LN0->getChain(),
14132                      LN0->getBasePtr(), MemVT, LN0->getMemOperand());
14133   Combiner.CombineTo(N, ExtLoad);
14134   DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
14135   if (LN0->use_empty())
14136     Combiner.recursivelyDeleteUnusedNodes(LN0);
14137   return SDValue(N, 0); // Return N so it doesn't get rechecked!
14138 }
14139 
14140 // fold ([s|z]ext (load x)) -> ([s|z]ext (truncate ([s|z]extload x)))
14141 // Only generate vector extloads when 1) they're legal, and 2) they are
14142 // deemed desirable by the target. NonNegZExt can be set to true if a zero
14143 // extend has the nonneg flag to allow use of sextload if profitable.
tryToFoldExtOfLoad(SelectionDAG & DAG,DAGCombiner & Combiner,const TargetLowering & TLI,EVT VT,bool LegalOperations,SDNode * N,SDValue N0,ISD::LoadExtType ExtLoadType,ISD::NodeType ExtOpc,bool NonNegZExt=false)14144 static SDValue tryToFoldExtOfLoad(SelectionDAG &DAG, DAGCombiner &Combiner,
14145                                   const TargetLowering &TLI, EVT VT,
14146                                   bool LegalOperations, SDNode *N, SDValue N0,
14147                                   ISD::LoadExtType ExtLoadType,
14148                                   ISD::NodeType ExtOpc,
14149                                   bool NonNegZExt = false) {
14150   if (!ISD::isNON_EXTLoad(N0.getNode()) || !ISD::isUNINDEXEDLoad(N0.getNode()))
14151     return {};
14152 
14153   // If this is zext nneg, see if it would make sense to treat it as a sext.
14154   if (NonNegZExt) {
14155     assert(ExtLoadType == ISD::ZEXTLOAD && ExtOpc == ISD::ZERO_EXTEND &&
14156            "Unexpected load type or opcode");
14157     for (SDNode *User : N0->users()) {
14158       if (User->getOpcode() == ISD::SETCC) {
14159         ISD::CondCode CC = cast<CondCodeSDNode>(User->getOperand(2))->get();
14160         if (ISD::isSignedIntSetCC(CC)) {
14161           ExtLoadType = ISD::SEXTLOAD;
14162           ExtOpc = ISD::SIGN_EXTEND;
14163           break;
14164         }
14165       }
14166     }
14167   }
14168 
14169   // TODO: isFixedLengthVector() should be removed and any negative effects on
14170   // code generation being the result of that target's implementation of
14171   // isVectorLoadExtDesirable().
14172   if ((LegalOperations || VT.isFixedLengthVector() ||
14173        !cast<LoadSDNode>(N0)->isSimple()) &&
14174       !TLI.isLoadExtLegal(ExtLoadType, VT, N0.getValueType()))
14175     return {};
14176 
14177   bool DoXform = true;
14178   SmallVector<SDNode *, 4> SetCCs;
14179   if (!N0.hasOneUse())
14180     DoXform = ExtendUsesToFormExtLoad(VT, N, N0, ExtOpc, SetCCs, TLI);
14181   if (VT.isVector())
14182     DoXform &= TLI.isVectorLoadExtDesirable(SDValue(N, 0));
14183   if (!DoXform)
14184     return {};
14185 
14186   LoadSDNode *LN0 = cast<LoadSDNode>(N0);
14187   SDValue ExtLoad = DAG.getExtLoad(ExtLoadType, SDLoc(LN0), VT, LN0->getChain(),
14188                                    LN0->getBasePtr(), N0.getValueType(),
14189                                    LN0->getMemOperand());
14190   Combiner.ExtendSetCCUses(SetCCs, N0, ExtLoad, ExtOpc);
14191   // If the load value is used only by N, replace it via CombineTo N.
14192   bool NoReplaceTrunc = SDValue(LN0, 0).hasOneUse();
14193   Combiner.CombineTo(N, ExtLoad);
14194   if (NoReplaceTrunc) {
14195     DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
14196     Combiner.recursivelyDeleteUnusedNodes(LN0);
14197   } else {
14198     SDValue Trunc =
14199         DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), ExtLoad);
14200     Combiner.CombineTo(LN0, Trunc, ExtLoad.getValue(1));
14201   }
14202   return SDValue(N, 0); // Return N so it doesn't get rechecked!
14203 }
14204 
14205 static SDValue
tryToFoldExtOfMaskedLoad(SelectionDAG & DAG,const TargetLowering & TLI,EVT VT,bool LegalOperations,SDNode * N,SDValue N0,ISD::LoadExtType ExtLoadType,ISD::NodeType ExtOpc)14206 tryToFoldExtOfMaskedLoad(SelectionDAG &DAG, const TargetLowering &TLI, EVT VT,
14207                          bool LegalOperations, SDNode *N, SDValue N0,
14208                          ISD::LoadExtType ExtLoadType, ISD::NodeType ExtOpc) {
14209   if (!N0.hasOneUse())
14210     return SDValue();
14211 
14212   MaskedLoadSDNode *Ld = dyn_cast<MaskedLoadSDNode>(N0);
14213   if (!Ld || Ld->getExtensionType() != ISD::NON_EXTLOAD)
14214     return SDValue();
14215 
14216   if ((LegalOperations || !cast<MaskedLoadSDNode>(N0)->isSimple()) &&
14217       !TLI.isLoadExtLegalOrCustom(ExtLoadType, VT, Ld->getValueType(0)))
14218     return SDValue();
14219 
14220   if (!TLI.isVectorLoadExtDesirable(SDValue(N, 0)))
14221     return SDValue();
14222 
14223   SDLoc dl(Ld);
14224   SDValue PassThru = DAG.getNode(ExtOpc, dl, VT, Ld->getPassThru());
14225   SDValue NewLoad = DAG.getMaskedLoad(
14226       VT, dl, Ld->getChain(), Ld->getBasePtr(), Ld->getOffset(), Ld->getMask(),
14227       PassThru, Ld->getMemoryVT(), Ld->getMemOperand(), Ld->getAddressingMode(),
14228       ExtLoadType, Ld->isExpandingLoad());
14229   DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), SDValue(NewLoad.getNode(), 1));
14230   return NewLoad;
14231 }
14232 
14233 // fold ([s|z]ext (atomic_load)) -> ([s|z]ext (truncate ([s|z]ext atomic_load)))
tryToFoldExtOfAtomicLoad(SelectionDAG & DAG,const TargetLowering & TLI,EVT VT,SDValue N0,ISD::LoadExtType ExtLoadType)14234 static SDValue tryToFoldExtOfAtomicLoad(SelectionDAG &DAG,
14235                                         const TargetLowering &TLI, EVT VT,
14236                                         SDValue N0,
14237                                         ISD::LoadExtType ExtLoadType) {
14238   auto *ALoad = dyn_cast<AtomicSDNode>(N0);
14239   if (!ALoad || ALoad->getOpcode() != ISD::ATOMIC_LOAD)
14240     return {};
14241   EVT MemoryVT = ALoad->getMemoryVT();
14242   if (!TLI.isAtomicLoadExtLegal(ExtLoadType, VT, MemoryVT))
14243     return {};
14244   // Can't fold into ALoad if it is already extending differently.
14245   ISD::LoadExtType ALoadExtTy = ALoad->getExtensionType();
14246   if ((ALoadExtTy == ISD::ZEXTLOAD && ExtLoadType == ISD::SEXTLOAD) ||
14247       (ALoadExtTy == ISD::SEXTLOAD && ExtLoadType == ISD::ZEXTLOAD))
14248     return {};
14249 
14250   EVT OrigVT = ALoad->getValueType(0);
14251   assert(OrigVT.getSizeInBits() < VT.getSizeInBits() && "VT should be wider.");
14252   auto *NewALoad = cast<AtomicSDNode>(DAG.getAtomicLoad(
14253       ExtLoadType, SDLoc(ALoad), MemoryVT, VT, ALoad->getChain(),
14254       ALoad->getBasePtr(), ALoad->getMemOperand()));
14255   DAG.ReplaceAllUsesOfValueWith(
14256       SDValue(ALoad, 0),
14257       DAG.getNode(ISD::TRUNCATE, SDLoc(ALoad), OrigVT, SDValue(NewALoad, 0)));
14258   // Update the chain uses.
14259   DAG.ReplaceAllUsesOfValueWith(SDValue(ALoad, 1), SDValue(NewALoad, 1));
14260   return SDValue(NewALoad, 0);
14261 }
14262 
foldExtendedSignBitTest(SDNode * N,SelectionDAG & DAG,bool LegalOperations)14263 static SDValue foldExtendedSignBitTest(SDNode *N, SelectionDAG &DAG,
14264                                        bool LegalOperations) {
14265   assert((N->getOpcode() == ISD::SIGN_EXTEND ||
14266           N->getOpcode() == ISD::ZERO_EXTEND) && "Expected sext or zext");
14267 
14268   SDValue SetCC = N->getOperand(0);
14269   if (LegalOperations || SetCC.getOpcode() != ISD::SETCC ||
14270       !SetCC.hasOneUse() || SetCC.getValueType() != MVT::i1)
14271     return SDValue();
14272 
14273   SDValue X = SetCC.getOperand(0);
14274   SDValue Ones = SetCC.getOperand(1);
14275   ISD::CondCode CC = cast<CondCodeSDNode>(SetCC.getOperand(2))->get();
14276   EVT VT = N->getValueType(0);
14277   EVT XVT = X.getValueType();
14278   // setge X, C is canonicalized to setgt, so we do not need to match that
14279   // pattern. The setlt sibling is folded in SimplifySelectCC() because it does
14280   // not require the 'not' op.
14281   if (CC == ISD::SETGT && isAllOnesConstant(Ones) && VT == XVT) {
14282     // Invert and smear/shift the sign bit:
14283     // sext i1 (setgt iN X, -1) --> sra (not X), (N - 1)
14284     // zext i1 (setgt iN X, -1) --> srl (not X), (N - 1)
14285     SDLoc DL(N);
14286     unsigned ShCt = VT.getSizeInBits() - 1;
14287     const TargetLowering &TLI = DAG.getTargetLoweringInfo();
14288     if (!TLI.shouldAvoidTransformToShift(VT, ShCt)) {
14289       SDValue NotX = DAG.getNOT(DL, X, VT);
14290       SDValue ShiftAmount = DAG.getConstant(ShCt, DL, VT);
14291       auto ShiftOpcode =
14292         N->getOpcode() == ISD::SIGN_EXTEND ? ISD::SRA : ISD::SRL;
14293       return DAG.getNode(ShiftOpcode, DL, VT, NotX, ShiftAmount);
14294     }
14295   }
14296   return SDValue();
14297 }
14298 
foldSextSetcc(SDNode * N)14299 SDValue DAGCombiner::foldSextSetcc(SDNode *N) {
14300   SDValue N0 = N->getOperand(0);
14301   if (N0.getOpcode() != ISD::SETCC)
14302     return SDValue();
14303 
14304   SDValue N00 = N0.getOperand(0);
14305   SDValue N01 = N0.getOperand(1);
14306   ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
14307   EVT VT = N->getValueType(0);
14308   EVT N00VT = N00.getValueType();
14309   SDLoc DL(N);
14310 
14311   // Propagate fast-math-flags.
14312   SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags());
14313 
14314   // On some architectures (such as SSE/NEON/etc) the SETCC result type is
14315   // the same size as the compared operands. Try to optimize sext(setcc())
14316   // if this is the case.
14317   if (VT.isVector() && !LegalOperations &&
14318       TLI.getBooleanContents(N00VT) ==
14319           TargetLowering::ZeroOrNegativeOneBooleanContent) {
14320     EVT SVT = getSetCCResultType(N00VT);
14321 
14322     // If we already have the desired type, don't change it.
14323     if (SVT != N0.getValueType()) {
14324       // We know that the # elements of the results is the same as the
14325       // # elements of the compare (and the # elements of the compare result
14326       // for that matter).  Check to see that they are the same size.  If so,
14327       // we know that the element size of the sext'd result matches the
14328       // element size of the compare operands.
14329       if (VT.getSizeInBits() == SVT.getSizeInBits())
14330         return DAG.getSetCC(DL, VT, N00, N01, CC);
14331 
14332       // If the desired elements are smaller or larger than the source
14333       // elements, we can use a matching integer vector type and then
14334       // truncate/sign extend.
14335       EVT MatchingVecType = N00VT.changeVectorElementTypeToInteger();
14336       if (SVT == MatchingVecType) {
14337         SDValue VsetCC = DAG.getSetCC(DL, MatchingVecType, N00, N01, CC);
14338         return DAG.getSExtOrTrunc(VsetCC, DL, VT);
14339       }
14340     }
14341 
14342     // Try to eliminate the sext of a setcc by zexting the compare operands.
14343     if (N0.hasOneUse() && TLI.isOperationLegalOrCustom(ISD::SETCC, VT) &&
14344         !TLI.isOperationLegalOrCustom(ISD::SETCC, SVT)) {
14345       bool IsSignedCmp = ISD::isSignedIntSetCC(CC);
14346       unsigned LoadOpcode = IsSignedCmp ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
14347       unsigned ExtOpcode = IsSignedCmp ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
14348 
14349       // We have an unsupported narrow vector compare op that would be legal
14350       // if extended to the destination type. See if the compare operands
14351       // can be freely extended to the destination type.
14352       auto IsFreeToExtend = [&](SDValue V) {
14353         if (isConstantOrConstantVector(V, /*NoOpaques*/ true))
14354           return true;
14355         // Match a simple, non-extended load that can be converted to a
14356         // legal {z/s}ext-load.
14357         // TODO: Allow widening of an existing {z/s}ext-load?
14358         if (!(ISD::isNON_EXTLoad(V.getNode()) &&
14359               ISD::isUNINDEXEDLoad(V.getNode()) &&
14360               cast<LoadSDNode>(V)->isSimple() &&
14361               TLI.isLoadExtLegal(LoadOpcode, VT, V.getValueType())))
14362           return false;
14363 
14364         // Non-chain users of this value must either be the setcc in this
14365         // sequence or extends that can be folded into the new {z/s}ext-load.
14366         for (SDUse &Use : V->uses()) {
14367           // Skip uses of the chain and the setcc.
14368           SDNode *User = Use.getUser();
14369           if (Use.getResNo() != 0 || User == N0.getNode())
14370             continue;
14371           // Extra users must have exactly the same cast we are about to create.
14372           // TODO: This restriction could be eased if ExtendUsesToFormExtLoad()
14373           //       is enhanced similarly.
14374           if (User->getOpcode() != ExtOpcode || User->getValueType(0) != VT)
14375             return false;
14376         }
14377         return true;
14378       };
14379 
14380       if (IsFreeToExtend(N00) && IsFreeToExtend(N01)) {
14381         SDValue Ext0 = DAG.getNode(ExtOpcode, DL, VT, N00);
14382         SDValue Ext1 = DAG.getNode(ExtOpcode, DL, VT, N01);
14383         return DAG.getSetCC(DL, VT, Ext0, Ext1, CC);
14384       }
14385     }
14386   }
14387 
14388   // sext(setcc x, y, cc) -> (select (setcc x, y, cc), T, 0)
14389   // Here, T can be 1 or -1, depending on the type of the setcc and
14390   // getBooleanContents().
14391   unsigned SetCCWidth = N0.getScalarValueSizeInBits();
14392 
14393   // To determine the "true" side of the select, we need to know the high bit
14394   // of the value returned by the setcc if it evaluates to true.
14395   // If the type of the setcc is i1, then the true case of the select is just
14396   // sext(i1 1), that is, -1.
14397   // If the type of the setcc is larger (say, i8) then the value of the high
14398   // bit depends on getBooleanContents(), so ask TLI for a real "true" value
14399   // of the appropriate width.
14400   SDValue ExtTrueVal = (SetCCWidth == 1)
14401                            ? DAG.getAllOnesConstant(DL, VT)
14402                            : DAG.getBoolConstant(true, DL, VT, N00VT);
14403   SDValue Zero = DAG.getConstant(0, DL, VT);
14404   if (SDValue SCC = SimplifySelectCC(DL, N00, N01, ExtTrueVal, Zero, CC, true))
14405     return SCC;
14406 
14407   if (!VT.isVector() && !shouldConvertSelectOfConstantsToMath(N0, VT, TLI)) {
14408     EVT SetCCVT = getSetCCResultType(N00VT);
14409     // Don't do this transform for i1 because there's a select transform
14410     // that would reverse it.
14411     // TODO: We should not do this transform at all without a target hook
14412     // because a sext is likely cheaper than a select?
14413     if (SetCCVT.getScalarSizeInBits() != 1 &&
14414         (!LegalOperations || TLI.isOperationLegal(ISD::SETCC, N00VT))) {
14415       SDValue SetCC = DAG.getSetCC(DL, SetCCVT, N00, N01, CC);
14416       return DAG.getSelect(DL, VT, SetCC, ExtTrueVal, Zero);
14417     }
14418   }
14419 
14420   return SDValue();
14421 }
14422 
visitSIGN_EXTEND(SDNode * N)14423 SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
14424   SDValue N0 = N->getOperand(0);
14425   EVT VT = N->getValueType(0);
14426   SDLoc DL(N);
14427 
14428   if (VT.isVector())
14429     if (SDValue FoldedVOp = SimplifyVCastOp(N, DL))
14430       return FoldedVOp;
14431 
14432   // sext(undef) = 0 because the top bit will all be the same.
14433   if (N0.isUndef())
14434     return DAG.getConstant(0, DL, VT);
14435 
14436   if (SDValue Res = tryToFoldExtendOfConstant(N, DL, TLI, DAG, LegalTypes))
14437     return Res;
14438 
14439   // fold (sext (sext x)) -> (sext x)
14440   // fold (sext (aext x)) -> (sext x)
14441   if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND)
14442     return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, N0.getOperand(0));
14443 
14444   // fold (sext (aext_extend_vector_inreg x)) -> (sext_extend_vector_inreg x)
14445   // fold (sext (sext_extend_vector_inreg x)) -> (sext_extend_vector_inreg x)
14446   if (N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG ||
14447       N0.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG)
14448     return DAG.getNode(ISD::SIGN_EXTEND_VECTOR_INREG, SDLoc(N), VT,
14449                        N0.getOperand(0));
14450 
14451   if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG) {
14452     SDValue N00 = N0.getOperand(0);
14453     EVT ExtVT = cast<VTSDNode>(N0->getOperand(1))->getVT();
14454     if (N00.getOpcode() == ISD::TRUNCATE || TLI.isTruncateFree(N00, ExtVT)) {
14455       // fold (sext (sext_inreg x)) -> (sext (trunc x))
14456       if ((!LegalTypes || TLI.isTypeLegal(ExtVT))) {
14457         SDValue T = DAG.getNode(ISD::TRUNCATE, DL, ExtVT, N00);
14458         return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, T);
14459       }
14460 
14461       // If the trunc wasn't legal, try to fold to (sext_inreg (anyext x))
14462       if (!LegalTypes || TLI.isTypeLegal(VT)) {
14463         SDValue ExtSrc = DAG.getAnyExtOrTrunc(N00, DL, VT);
14464         return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, ExtSrc,
14465                            N0->getOperand(1));
14466       }
14467     }
14468   }
14469 
14470   if (N0.getOpcode() == ISD::TRUNCATE) {
14471     // fold (sext (truncate (load x))) -> (sext (smaller load x))
14472     // fold (sext (truncate (srl (load x), c))) -> (sext (smaller load (x+c/n)))
14473     if (SDValue NarrowLoad = reduceLoadWidth(N0.getNode())) {
14474       SDNode *oye = N0.getOperand(0).getNode();
14475       if (NarrowLoad.getNode() != N0.getNode()) {
14476         CombineTo(N0.getNode(), NarrowLoad);
14477         // CombineTo deleted the truncate, if needed, but not what's under it.
14478         AddToWorklist(oye);
14479       }
14480       return SDValue(N, 0);   // Return N so it doesn't get rechecked!
14481     }
14482 
14483     // See if the value being truncated is already sign extended.  If so, just
14484     // eliminate the trunc/sext pair.
14485     SDValue Op = N0.getOperand(0);
14486     unsigned OpBits   = Op.getScalarValueSizeInBits();
14487     unsigned MidBits  = N0.getScalarValueSizeInBits();
14488     unsigned DestBits = VT.getScalarSizeInBits();
14489 
14490     if (N0->getFlags().hasNoSignedWrap() ||
14491         DAG.ComputeNumSignBits(Op) > OpBits - MidBits) {
14492       if (OpBits == DestBits) {
14493         // Op is i32, Mid is i8, and Dest is i32.  If Op has more than 24 sign
14494         // bits, it is already ready.
14495         return Op;
14496       }
14497 
14498       if (OpBits < DestBits) {
14499         // Op is i32, Mid is i8, and Dest is i64.  If Op has more than 24 sign
14500         // bits, just sext from i32.
14501         return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Op);
14502       }
14503 
14504       // Op is i64, Mid is i8, and Dest is i32.  If Op has more than 56 sign
14505       // bits, just truncate to i32.
14506       SDNodeFlags Flags;
14507       Flags.setNoSignedWrap(true);
14508       Flags.setNoUnsignedWrap(N0->getFlags().hasNoUnsignedWrap());
14509       return DAG.getNode(ISD::TRUNCATE, DL, VT, Op, Flags);
14510     }
14511 
14512     // fold (sext (truncate x)) -> (sextinreg x).
14513     if (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND_INREG,
14514                                                  N0.getValueType())) {
14515       if (OpBits < DestBits)
14516         Op = DAG.getNode(ISD::ANY_EXTEND, SDLoc(N0), VT, Op);
14517       else if (OpBits > DestBits)
14518         Op = DAG.getNode(ISD::TRUNCATE, SDLoc(N0), VT, Op);
14519       return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, Op,
14520                          DAG.getValueType(N0.getValueType()));
14521     }
14522   }
14523 
14524   // Try to simplify (sext (load x)).
14525   if (SDValue foldedExt =
14526           tryToFoldExtOfLoad(DAG, *this, TLI, VT, LegalOperations, N, N0,
14527                              ISD::SEXTLOAD, ISD::SIGN_EXTEND))
14528     return foldedExt;
14529 
14530   if (SDValue foldedExt =
14531           tryToFoldExtOfMaskedLoad(DAG, TLI, VT, LegalOperations, N, N0,
14532                                    ISD::SEXTLOAD, ISD::SIGN_EXTEND))
14533     return foldedExt;
14534 
14535   // fold (sext (load x)) to multiple smaller sextloads.
14536   // Only on illegal but splittable vectors.
14537   if (SDValue ExtLoad = CombineExtLoad(N))
14538     return ExtLoad;
14539 
14540   // Try to simplify (sext (sextload x)).
14541   if (SDValue foldedExt = tryToFoldExtOfExtload(
14542           DAG, *this, TLI, VT, LegalOperations, N, N0, ISD::SEXTLOAD))
14543     return foldedExt;
14544 
14545   // Try to simplify (sext (atomic_load x)).
14546   if (SDValue foldedExt =
14547           tryToFoldExtOfAtomicLoad(DAG, TLI, VT, N0, ISD::SEXTLOAD))
14548     return foldedExt;
14549 
14550   // fold (sext (and/or/xor (load x), cst)) ->
14551   //      (and/or/xor (sextload x), (sext cst))
14552   if (ISD::isBitwiseLogicOp(N0.getOpcode()) &&
14553       isa<LoadSDNode>(N0.getOperand(0)) &&
14554       N0.getOperand(1).getOpcode() == ISD::Constant &&
14555       (!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) {
14556     LoadSDNode *LN00 = cast<LoadSDNode>(N0.getOperand(0));
14557     EVT MemVT = LN00->getMemoryVT();
14558     if (TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, MemVT) &&
14559       LN00->getExtensionType() != ISD::ZEXTLOAD && LN00->isUnindexed()) {
14560       SmallVector<SDNode*, 4> SetCCs;
14561       bool DoXform = ExtendUsesToFormExtLoad(VT, N0.getNode(), N0.getOperand(0),
14562                                              ISD::SIGN_EXTEND, SetCCs, TLI);
14563       if (DoXform) {
14564         SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(LN00), VT,
14565                                          LN00->getChain(), LN00->getBasePtr(),
14566                                          LN00->getMemoryVT(),
14567                                          LN00->getMemOperand());
14568         APInt Mask = N0.getConstantOperandAPInt(1).sext(VT.getSizeInBits());
14569         SDValue And = DAG.getNode(N0.getOpcode(), DL, VT,
14570                                   ExtLoad, DAG.getConstant(Mask, DL, VT));
14571         ExtendSetCCUses(SetCCs, N0.getOperand(0), ExtLoad, ISD::SIGN_EXTEND);
14572         bool NoReplaceTruncAnd = !N0.hasOneUse();
14573         bool NoReplaceTrunc = SDValue(LN00, 0).hasOneUse();
14574         CombineTo(N, And);
14575         // If N0 has multiple uses, change other uses as well.
14576         if (NoReplaceTruncAnd) {
14577           SDValue TruncAnd =
14578               DAG.getNode(ISD::TRUNCATE, DL, N0.getValueType(), And);
14579           CombineTo(N0.getNode(), TruncAnd);
14580         }
14581         if (NoReplaceTrunc) {
14582           DAG.ReplaceAllUsesOfValueWith(SDValue(LN00, 1), ExtLoad.getValue(1));
14583         } else {
14584           SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(LN00),
14585                                       LN00->getValueType(0), ExtLoad);
14586           CombineTo(LN00, Trunc, ExtLoad.getValue(1));
14587         }
14588         return SDValue(N,0); // Return N so it doesn't get rechecked!
14589       }
14590     }
14591   }
14592 
14593   if (SDValue V = foldExtendedSignBitTest(N, DAG, LegalOperations))
14594     return V;
14595 
14596   if (SDValue V = foldSextSetcc(N))
14597     return V;
14598 
14599   // fold (sext x) -> (zext x) if the sign bit is known zero.
14600   if (!TLI.isSExtCheaperThanZExt(N0.getValueType(), VT) &&
14601       (!LegalOperations || TLI.isOperationLegal(ISD::ZERO_EXTEND, VT)) &&
14602       DAG.SignBitIsZero(N0))
14603     return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0, SDNodeFlags::NonNeg);
14604 
14605   if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
14606     return NewVSel;
14607 
14608   // Eliminate this sign extend by doing a negation in the destination type:
14609   // sext i32 (0 - (zext i8 X to i32)) to i64 --> 0 - (zext i8 X to i64)
14610   if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
14611       isNullOrNullSplat(N0.getOperand(0)) &&
14612       N0.getOperand(1).getOpcode() == ISD::ZERO_EXTEND &&
14613       TLI.isOperationLegalOrCustom(ISD::SUB, VT)) {
14614     SDValue Zext = DAG.getZExtOrTrunc(N0.getOperand(1).getOperand(0), DL, VT);
14615     return DAG.getNegative(Zext, DL, VT);
14616   }
14617   // Eliminate this sign extend by doing a decrement in the destination type:
14618   // sext i32 ((zext i8 X to i32) + (-1)) to i64 --> (zext i8 X to i64) + (-1)
14619   if (N0.getOpcode() == ISD::ADD && N0.hasOneUse() &&
14620       isAllOnesOrAllOnesSplat(N0.getOperand(1)) &&
14621       N0.getOperand(0).getOpcode() == ISD::ZERO_EXTEND &&
14622       TLI.isOperationLegalOrCustom(ISD::ADD, VT)) {
14623     SDValue Zext = DAG.getZExtOrTrunc(N0.getOperand(0).getOperand(0), DL, VT);
14624     return DAG.getNode(ISD::ADD, DL, VT, Zext, DAG.getAllOnesConstant(DL, VT));
14625   }
14626 
14627   // fold sext (not i1 X) -> add (zext i1 X), -1
14628   // TODO: This could be extended to handle bool vectors.
14629   if (N0.getValueType() == MVT::i1 && isBitwiseNot(N0) && N0.hasOneUse() &&
14630       (!LegalOperations || (TLI.isOperationLegal(ISD::ZERO_EXTEND, VT) &&
14631                             TLI.isOperationLegal(ISD::ADD, VT)))) {
14632     // If we can eliminate the 'not', the sext form should be better
14633     if (SDValue NewXor = visitXOR(N0.getNode())) {
14634       // Returning N0 is a form of in-visit replacement that may have
14635       // invalidated N0.
14636       if (NewXor.getNode() == N0.getNode()) {
14637         // Return SDValue here as the xor should have already been replaced in
14638         // this sext.
14639         return SDValue();
14640       }
14641 
14642       // Return a new sext with the new xor.
14643       return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, NewXor);
14644     }
14645 
14646     SDValue Zext = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0));
14647     return DAG.getNode(ISD::ADD, DL, VT, Zext, DAG.getAllOnesConstant(DL, VT));
14648   }
14649 
14650   if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG, DL, Level))
14651     return Res;
14652 
14653   return SDValue();
14654 }
14655 
14656 /// Given an extending node with a pop-count operand, if the target does not
14657 /// support a pop-count in the narrow source type but does support it in the
14658 /// destination type, widen the pop-count to the destination type.
widenCtPop(SDNode * Extend,SelectionDAG & DAG,const SDLoc & DL)14659 static SDValue widenCtPop(SDNode *Extend, SelectionDAG &DAG, const SDLoc &DL) {
14660   assert((Extend->getOpcode() == ISD::ZERO_EXTEND ||
14661           Extend->getOpcode() == ISD::ANY_EXTEND) &&
14662          "Expected extend op");
14663 
14664   SDValue CtPop = Extend->getOperand(0);
14665   if (CtPop.getOpcode() != ISD::CTPOP || !CtPop.hasOneUse())
14666     return SDValue();
14667 
14668   EVT VT = Extend->getValueType(0);
14669   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
14670   if (TLI.isOperationLegalOrCustom(ISD::CTPOP, CtPop.getValueType()) ||
14671       !TLI.isOperationLegalOrCustom(ISD::CTPOP, VT))
14672     return SDValue();
14673 
14674   // zext (ctpop X) --> ctpop (zext X)
14675   SDValue NewZext = DAG.getZExtOrTrunc(CtPop.getOperand(0), DL, VT);
14676   return DAG.getNode(ISD::CTPOP, DL, VT, NewZext);
14677 }
14678 
14679 // If we have (zext (abs X)) where X is a type that will be promoted by type
14680 // legalization, convert to (abs (sext X)). But don't extend past a legal type.
widenAbs(SDNode * Extend,SelectionDAG & DAG)14681 static SDValue widenAbs(SDNode *Extend, SelectionDAG &DAG) {
14682   assert(Extend->getOpcode() == ISD::ZERO_EXTEND && "Expected zero extend.");
14683 
14684   EVT VT = Extend->getValueType(0);
14685   if (VT.isVector())
14686     return SDValue();
14687 
14688   SDValue Abs = Extend->getOperand(0);
14689   if (Abs.getOpcode() != ISD::ABS || !Abs.hasOneUse())
14690     return SDValue();
14691 
14692   EVT AbsVT = Abs.getValueType();
14693   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
14694   if (TLI.getTypeAction(*DAG.getContext(), AbsVT) !=
14695       TargetLowering::TypePromoteInteger)
14696     return SDValue();
14697 
14698   EVT LegalVT = TLI.getTypeToTransformTo(*DAG.getContext(), AbsVT);
14699 
14700   SDValue SExt =
14701       DAG.getNode(ISD::SIGN_EXTEND, SDLoc(Abs), LegalVT, Abs.getOperand(0));
14702   SDValue NewAbs = DAG.getNode(ISD::ABS, SDLoc(Abs), LegalVT, SExt);
14703   return DAG.getZExtOrTrunc(NewAbs, SDLoc(Extend), VT);
14704 }
14705 
visitZERO_EXTEND(SDNode * N)14706 SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
14707   SDValue N0 = N->getOperand(0);
14708   EVT VT = N->getValueType(0);
14709   SDLoc DL(N);
14710 
14711   if (VT.isVector())
14712     if (SDValue FoldedVOp = SimplifyVCastOp(N, DL))
14713       return FoldedVOp;
14714 
14715   // zext(undef) = 0
14716   if (N0.isUndef())
14717     return DAG.getConstant(0, DL, VT);
14718 
14719   if (SDValue Res = tryToFoldExtendOfConstant(N, DL, TLI, DAG, LegalTypes))
14720     return Res;
14721 
14722   // fold (zext (zext x)) -> (zext x)
14723   // fold (zext (aext x)) -> (zext x)
14724   if (N0.getOpcode() == ISD::ZERO_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND) {
14725     SDNodeFlags Flags;
14726     if (N0.getOpcode() == ISD::ZERO_EXTEND)
14727       Flags.setNonNeg(N0->getFlags().hasNonNeg());
14728     return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0), Flags);
14729   }
14730 
14731   // fold (zext (aext_extend_vector_inreg x)) -> (zext_extend_vector_inreg x)
14732   // fold (zext (zext_extend_vector_inreg x)) -> (zext_extend_vector_inreg x)
14733   if (N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG ||
14734       N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG)
14735     return DAG.getNode(ISD::ZERO_EXTEND_VECTOR_INREG, DL, VT, N0.getOperand(0));
14736 
14737   // fold (zext (truncate x)) -> (zext x) or
14738   //      (zext (truncate x)) -> (truncate x)
14739   // This is valid when the truncated bits of x are already zero.
14740   SDValue Op;
14741   KnownBits Known;
14742   if (isTruncateOf(DAG, N0, Op, Known)) {
14743     APInt TruncatedBits =
14744       (Op.getScalarValueSizeInBits() == N0.getScalarValueSizeInBits()) ?
14745       APInt(Op.getScalarValueSizeInBits(), 0) :
14746       APInt::getBitsSet(Op.getScalarValueSizeInBits(),
14747                         N0.getScalarValueSizeInBits(),
14748                         std::min(Op.getScalarValueSizeInBits(),
14749                                  VT.getScalarSizeInBits()));
14750     if (TruncatedBits.isSubsetOf(Known.Zero)) {
14751       SDValue ZExtOrTrunc = DAG.getZExtOrTrunc(Op, DL, VT);
14752       DAG.salvageDebugInfo(*N0.getNode());
14753 
14754       return ZExtOrTrunc;
14755     }
14756   }
14757 
14758   // fold (zext (truncate x)) -> (and x, mask)
14759   if (N0.getOpcode() == ISD::TRUNCATE) {
14760     // fold (zext (truncate (load x))) -> (zext (smaller load x))
14761     // fold (zext (truncate (srl (load x), c))) -> (zext (smaller load (x+c/n)))
14762     if (SDValue NarrowLoad = reduceLoadWidth(N0.getNode())) {
14763       SDNode *oye = N0.getOperand(0).getNode();
14764       if (NarrowLoad.getNode() != N0.getNode()) {
14765         CombineTo(N0.getNode(), NarrowLoad);
14766         // CombineTo deleted the truncate, if needed, but not what's under it.
14767         AddToWorklist(oye);
14768       }
14769       return SDValue(N, 0); // Return N so it doesn't get rechecked!
14770     }
14771 
14772     EVT SrcVT = N0.getOperand(0).getValueType();
14773     EVT MinVT = N0.getValueType();
14774 
14775     if (N->getFlags().hasNonNeg()) {
14776       SDValue Op = N0.getOperand(0);
14777       unsigned OpBits = SrcVT.getScalarSizeInBits();
14778       unsigned MidBits = MinVT.getScalarSizeInBits();
14779       unsigned DestBits = VT.getScalarSizeInBits();
14780 
14781       if (N0->getFlags().hasNoSignedWrap() ||
14782           DAG.ComputeNumSignBits(Op) > OpBits - MidBits) {
14783         if (OpBits == DestBits) {
14784           // Op is i32, Mid is i8, and Dest is i32.  If Op has more than 24 sign
14785           // bits, it is already ready.
14786           return Op;
14787         }
14788 
14789         if (OpBits < DestBits) {
14790           // Op is i32, Mid is i8, and Dest is i64.  If Op has more than 24 sign
14791           // bits, just sext from i32.
14792           // FIXME: This can probably be ZERO_EXTEND nneg?
14793           return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Op);
14794         }
14795 
14796         // Op is i64, Mid is i8, and Dest is i32.  If Op has more than 56 sign
14797         // bits, just truncate to i32.
14798         SDNodeFlags Flags;
14799         Flags.setNoSignedWrap(true);
14800         Flags.setNoUnsignedWrap(true);
14801         return DAG.getNode(ISD::TRUNCATE, DL, VT, Op, Flags);
14802       }
14803     }
14804 
14805     // Try to mask before the extension to avoid having to generate a larger mask,
14806     // possibly over several sub-vectors.
14807     if (SrcVT.bitsLT(VT) && VT.isVector()) {
14808       if (!LegalOperations || (TLI.isOperationLegal(ISD::AND, SrcVT) &&
14809                                TLI.isOperationLegal(ISD::ZERO_EXTEND, VT))) {
14810         SDValue Op = N0.getOperand(0);
14811         Op = DAG.getZeroExtendInReg(Op, DL, MinVT);
14812         AddToWorklist(Op.getNode());
14813         SDValue ZExtOrTrunc = DAG.getZExtOrTrunc(Op, DL, VT);
14814         // Transfer the debug info; the new node is equivalent to N0.
14815         DAG.transferDbgValues(N0, ZExtOrTrunc);
14816         return ZExtOrTrunc;
14817       }
14818     }
14819 
14820     if (!LegalOperations || TLI.isOperationLegal(ISD::AND, VT)) {
14821       SDValue Op = DAG.getAnyExtOrTrunc(N0.getOperand(0), DL, VT);
14822       AddToWorklist(Op.getNode());
14823       SDValue And = DAG.getZeroExtendInReg(Op, DL, MinVT);
14824       // We may safely transfer the debug info describing the truncate node over
14825       // to the equivalent and operation.
14826       DAG.transferDbgValues(N0, And);
14827       return And;
14828     }
14829   }
14830 
14831   // Fold (zext (and (trunc x), cst)) -> (and x, cst),
14832   // if either of the casts is not free.
14833   if (N0.getOpcode() == ISD::AND &&
14834       N0.getOperand(0).getOpcode() == ISD::TRUNCATE &&
14835       N0.getOperand(1).getOpcode() == ISD::Constant &&
14836       (!TLI.isTruncateFree(N0.getOperand(0).getOperand(0), N0.getValueType()) ||
14837        !TLI.isZExtFree(N0.getValueType(), VT))) {
14838     SDValue X = N0.getOperand(0).getOperand(0);
14839     X = DAG.getAnyExtOrTrunc(X, SDLoc(X), VT);
14840     APInt Mask = N0.getConstantOperandAPInt(1).zext(VT.getSizeInBits());
14841     return DAG.getNode(ISD::AND, DL, VT,
14842                        X, DAG.getConstant(Mask, DL, VT));
14843   }
14844 
14845   // Try to simplify (zext (load x)).
14846   if (SDValue foldedExt = tryToFoldExtOfLoad(
14847           DAG, *this, TLI, VT, LegalOperations, N, N0, ISD::ZEXTLOAD,
14848           ISD::ZERO_EXTEND, N->getFlags().hasNonNeg()))
14849     return foldedExt;
14850 
14851   if (SDValue foldedExt =
14852           tryToFoldExtOfMaskedLoad(DAG, TLI, VT, LegalOperations, N, N0,
14853                                    ISD::ZEXTLOAD, ISD::ZERO_EXTEND))
14854     return foldedExt;
14855 
14856   // fold (zext (load x)) to multiple smaller zextloads.
14857   // Only on illegal but splittable vectors.
14858   if (SDValue ExtLoad = CombineExtLoad(N))
14859     return ExtLoad;
14860 
14861   // Try to simplify (zext (atomic_load x)).
14862   if (SDValue foldedExt =
14863           tryToFoldExtOfAtomicLoad(DAG, TLI, VT, N0, ISD::ZEXTLOAD))
14864     return foldedExt;
14865 
14866   // fold (zext (and/or/xor (load x), cst)) ->
14867   //      (and/or/xor (zextload x), (zext cst))
14868   // Unless (and (load x) cst) will match as a zextload already and has
14869   // additional users, or the zext is already free.
14870   if (ISD::isBitwiseLogicOp(N0.getOpcode()) && !TLI.isZExtFree(N0, VT) &&
14871       isa<LoadSDNode>(N0.getOperand(0)) &&
14872       N0.getOperand(1).getOpcode() == ISD::Constant &&
14873       (!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) {
14874     LoadSDNode *LN00 = cast<LoadSDNode>(N0.getOperand(0));
14875     EVT MemVT = LN00->getMemoryVT();
14876     if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT) &&
14877         LN00->getExtensionType() != ISD::SEXTLOAD && LN00->isUnindexed()) {
14878       bool DoXform = true;
14879       SmallVector<SDNode*, 4> SetCCs;
14880       if (!N0.hasOneUse()) {
14881         if (N0.getOpcode() == ISD::AND) {
14882           auto *AndC = cast<ConstantSDNode>(N0.getOperand(1));
14883           EVT LoadResultTy = AndC->getValueType(0);
14884           EVT ExtVT;
14885           if (isAndLoadExtLoad(AndC, LN00, LoadResultTy, ExtVT))
14886             DoXform = false;
14887         }
14888       }
14889       if (DoXform)
14890         DoXform = ExtendUsesToFormExtLoad(VT, N0.getNode(), N0.getOperand(0),
14891                                           ISD::ZERO_EXTEND, SetCCs, TLI);
14892       if (DoXform) {
14893         SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(LN00), VT,
14894                                          LN00->getChain(), LN00->getBasePtr(),
14895                                          LN00->getMemoryVT(),
14896                                          LN00->getMemOperand());
14897         APInt Mask = N0.getConstantOperandAPInt(1).zext(VT.getSizeInBits());
14898         SDValue And = DAG.getNode(N0.getOpcode(), DL, VT,
14899                                   ExtLoad, DAG.getConstant(Mask, DL, VT));
14900         ExtendSetCCUses(SetCCs, N0.getOperand(0), ExtLoad, ISD::ZERO_EXTEND);
14901         bool NoReplaceTruncAnd = !N0.hasOneUse();
14902         bool NoReplaceTrunc = SDValue(LN00, 0).hasOneUse();
14903         CombineTo(N, And);
14904         // If N0 has multiple uses, change other uses as well.
14905         if (NoReplaceTruncAnd) {
14906           SDValue TruncAnd =
14907               DAG.getNode(ISD::TRUNCATE, DL, N0.getValueType(), And);
14908           CombineTo(N0.getNode(), TruncAnd);
14909         }
14910         if (NoReplaceTrunc) {
14911           DAG.ReplaceAllUsesOfValueWith(SDValue(LN00, 1), ExtLoad.getValue(1));
14912         } else {
14913           SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(LN00),
14914                                       LN00->getValueType(0), ExtLoad);
14915           CombineTo(LN00, Trunc, ExtLoad.getValue(1));
14916         }
14917         return SDValue(N,0); // Return N so it doesn't get rechecked!
14918       }
14919     }
14920   }
14921 
14922   // fold (zext (and/or/xor (shl/shr (load x), cst), cst)) ->
14923   //      (and/or/xor (shl/shr (zextload x), (zext cst)), (zext cst))
14924   if (SDValue ZExtLoad = CombineZExtLogicopShiftLoad(N))
14925     return ZExtLoad;
14926 
14927   // Try to simplify (zext (zextload x)).
14928   if (SDValue foldedExt = tryToFoldExtOfExtload(
14929           DAG, *this, TLI, VT, LegalOperations, N, N0, ISD::ZEXTLOAD))
14930     return foldedExt;
14931 
14932   if (SDValue V = foldExtendedSignBitTest(N, DAG, LegalOperations))
14933     return V;
14934 
14935   if (N0.getOpcode() == ISD::SETCC) {
14936     // Propagate fast-math-flags.
14937     SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags());
14938 
14939     // Only do this before legalize for now.
14940     if (!LegalOperations && VT.isVector() &&
14941         N0.getValueType().getVectorElementType() == MVT::i1) {
14942       EVT N00VT = N0.getOperand(0).getValueType();
14943       if (getSetCCResultType(N00VT) == N0.getValueType())
14944         return SDValue();
14945 
14946       // We know that the # elements of the results is the same as the #
14947       // elements of the compare (and the # elements of the compare result for
14948       // that matter). Check to see that they are the same size. If so, we know
14949       // that the element size of the sext'd result matches the element size of
14950       // the compare operands.
14951       if (VT.getSizeInBits() == N00VT.getSizeInBits()) {
14952         // zext(setcc) -> zext_in_reg(vsetcc) for vectors.
14953         SDValue VSetCC = DAG.getNode(ISD::SETCC, DL, VT, N0.getOperand(0),
14954                                      N0.getOperand(1), N0.getOperand(2));
14955         return DAG.getZeroExtendInReg(VSetCC, DL, N0.getValueType());
14956       }
14957 
14958       // If the desired elements are smaller or larger than the source
14959       // elements we can use a matching integer vector type and then
14960       // truncate/any extend followed by zext_in_reg.
14961       EVT MatchingVectorType = N00VT.changeVectorElementTypeToInteger();
14962       SDValue VsetCC =
14963           DAG.getNode(ISD::SETCC, DL, MatchingVectorType, N0.getOperand(0),
14964                       N0.getOperand(1), N0.getOperand(2));
14965       return DAG.getZeroExtendInReg(DAG.getAnyExtOrTrunc(VsetCC, DL, VT), DL,
14966                                     N0.getValueType());
14967     }
14968 
14969     // zext(setcc x,y,cc) -> zext(select x, y, true, false, cc)
14970     EVT N0VT = N0.getValueType();
14971     EVT N00VT = N0.getOperand(0).getValueType();
14972     if (SDValue SCC = SimplifySelectCC(
14973             DL, N0.getOperand(0), N0.getOperand(1),
14974             DAG.getBoolConstant(true, DL, N0VT, N00VT),
14975             DAG.getBoolConstant(false, DL, N0VT, N00VT),
14976             cast<CondCodeSDNode>(N0.getOperand(2))->get(), true))
14977       return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, SCC);
14978   }
14979 
14980   // (zext (shl (zext x), cst)) -> (shl (zext x), cst)
14981   if ((N0.getOpcode() == ISD::SHL || N0.getOpcode() == ISD::SRL) &&
14982       !TLI.isZExtFree(N0, VT)) {
14983     SDValue ShVal = N0.getOperand(0);
14984     SDValue ShAmt = N0.getOperand(1);
14985     if (auto *ShAmtC = dyn_cast<ConstantSDNode>(ShAmt)) {
14986       if (ShVal.getOpcode() == ISD::ZERO_EXTEND && N0.hasOneUse()) {
14987         if (N0.getOpcode() == ISD::SHL) {
14988           // If the original shl may be shifting out bits, do not perform this
14989           // transformation.
14990           unsigned KnownZeroBits = ShVal.getValueSizeInBits() -
14991                                    ShVal.getOperand(0).getValueSizeInBits();
14992           if (ShAmtC->getAPIntValue().ugt(KnownZeroBits)) {
14993             // If the shift is too large, then see if we can deduce that the
14994             // shift is safe anyway.
14995 
14996             // Check if the bits being shifted out are known to be zero.
14997             KnownBits KnownShVal = DAG.computeKnownBits(ShVal);
14998             if (ShAmtC->getAPIntValue().ugt(KnownShVal.countMinLeadingZeros()))
14999               return SDValue();
15000           }
15001         }
15002 
15003         // Ensure that the shift amount is wide enough for the shifted value.
15004         if (Log2_32_Ceil(VT.getSizeInBits()) > ShAmt.getValueSizeInBits())
15005           ShAmt = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, ShAmt);
15006 
15007         return DAG.getNode(N0.getOpcode(), DL, VT,
15008                            DAG.getNode(ISD::ZERO_EXTEND, DL, VT, ShVal), ShAmt);
15009       }
15010     }
15011   }
15012 
15013   if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
15014     return NewVSel;
15015 
15016   if (SDValue NewCtPop = widenCtPop(N, DAG, DL))
15017     return NewCtPop;
15018 
15019   if (SDValue V = widenAbs(N, DAG))
15020     return V;
15021 
15022   if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG, DL, Level))
15023     return Res;
15024 
15025   // CSE zext nneg with sext if the zext is not free.
15026   if (N->getFlags().hasNonNeg() && !TLI.isZExtFree(N0.getValueType(), VT)) {
15027     SDNode *CSENode = DAG.getNodeIfExists(ISD::SIGN_EXTEND, N->getVTList(), N0);
15028     if (CSENode)
15029       return SDValue(CSENode, 0);
15030   }
15031 
15032   return SDValue();
15033 }
15034 
visitANY_EXTEND(SDNode * N)15035 SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) {
15036   SDValue N0 = N->getOperand(0);
15037   EVT VT = N->getValueType(0);
15038   SDLoc DL(N);
15039 
15040   // aext(undef) = undef
15041   if (N0.isUndef())
15042     return DAG.getUNDEF(VT);
15043 
15044   if (SDValue Res = tryToFoldExtendOfConstant(N, DL, TLI, DAG, LegalTypes))
15045     return Res;
15046 
15047   // fold (aext (aext x)) -> (aext x)
15048   // fold (aext (zext x)) -> (zext x)
15049   // fold (aext (sext x)) -> (sext x)
15050   if (N0.getOpcode() == ISD::ANY_EXTEND || N0.getOpcode() == ISD::ZERO_EXTEND ||
15051       N0.getOpcode() == ISD::SIGN_EXTEND) {
15052     SDNodeFlags Flags;
15053     if (N0.getOpcode() == ISD::ZERO_EXTEND)
15054       Flags.setNonNeg(N0->getFlags().hasNonNeg());
15055     return DAG.getNode(N0.getOpcode(), DL, VT, N0.getOperand(0), Flags);
15056   }
15057 
15058   // fold (aext (aext_extend_vector_inreg x)) -> (aext_extend_vector_inreg x)
15059   // fold (aext (zext_extend_vector_inreg x)) -> (zext_extend_vector_inreg x)
15060   // fold (aext (sext_extend_vector_inreg x)) -> (sext_extend_vector_inreg x)
15061   if (N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG ||
15062       N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG ||
15063       N0.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG)
15064     return DAG.getNode(N0.getOpcode(), DL, VT, N0.getOperand(0));
15065 
15066   // fold (aext (truncate (load x))) -> (aext (smaller load x))
15067   // fold (aext (truncate (srl (load x), c))) -> (aext (small load (x+c/n)))
15068   if (N0.getOpcode() == ISD::TRUNCATE) {
15069     if (SDValue NarrowLoad = reduceLoadWidth(N0.getNode())) {
15070       SDNode *oye = N0.getOperand(0).getNode();
15071       if (NarrowLoad.getNode() != N0.getNode()) {
15072         CombineTo(N0.getNode(), NarrowLoad);
15073         // CombineTo deleted the truncate, if needed, but not what's under it.
15074         AddToWorklist(oye);
15075       }
15076       return SDValue(N, 0);   // Return N so it doesn't get rechecked!
15077     }
15078   }
15079 
15080   // fold (aext (truncate x))
15081   if (N0.getOpcode() == ISD::TRUNCATE)
15082     return DAG.getAnyExtOrTrunc(N0.getOperand(0), DL, VT);
15083 
15084   // Fold (aext (and (trunc x), cst)) -> (and x, cst)
15085   // if the trunc is not free.
15086   if (N0.getOpcode() == ISD::AND &&
15087       N0.getOperand(0).getOpcode() == ISD::TRUNCATE &&
15088       N0.getOperand(1).getOpcode() == ISD::Constant &&
15089       !TLI.isTruncateFree(N0.getOperand(0).getOperand(0), N0.getValueType())) {
15090     SDValue X = DAG.getAnyExtOrTrunc(N0.getOperand(0).getOperand(0), DL, VT);
15091     SDValue Y = DAG.getNode(ISD::ANY_EXTEND, DL, VT, N0.getOperand(1));
15092     assert(isa<ConstantSDNode>(Y) && "Expected constant to be folded!");
15093     return DAG.getNode(ISD::AND, DL, VT, X, Y);
15094   }
15095 
15096   // fold (aext (load x)) -> (aext (truncate (extload x)))
15097   // None of the supported targets knows how to perform load and any_ext
15098   // on vectors in one instruction, so attempt to fold to zext instead.
15099   if (VT.isVector()) {
15100     // Try to simplify (zext (load x)).
15101     if (SDValue foldedExt =
15102             tryToFoldExtOfLoad(DAG, *this, TLI, VT, LegalOperations, N, N0,
15103                                ISD::ZEXTLOAD, ISD::ZERO_EXTEND))
15104       return foldedExt;
15105   } else if (ISD::isNON_EXTLoad(N0.getNode()) &&
15106              ISD::isUNINDEXEDLoad(N0.getNode()) &&
15107              TLI.isLoadExtLegal(ISD::EXTLOAD, VT, N0.getValueType())) {
15108     bool DoXform = true;
15109     SmallVector<SDNode *, 4> SetCCs;
15110     if (!N0.hasOneUse())
15111       DoXform =
15112           ExtendUsesToFormExtLoad(VT, N, N0, ISD::ANY_EXTEND, SetCCs, TLI);
15113     if (DoXform) {
15114       LoadSDNode *LN0 = cast<LoadSDNode>(N0);
15115       SDValue ExtLoad = DAG.getExtLoad(ISD::EXTLOAD, DL, VT, LN0->getChain(),
15116                                        LN0->getBasePtr(), N0.getValueType(),
15117                                        LN0->getMemOperand());
15118       ExtendSetCCUses(SetCCs, N0, ExtLoad, ISD::ANY_EXTEND);
15119       // If the load value is used only by N, replace it via CombineTo N.
15120       bool NoReplaceTrunc = N0.hasOneUse();
15121       CombineTo(N, ExtLoad);
15122       if (NoReplaceTrunc) {
15123         DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
15124         recursivelyDeleteUnusedNodes(LN0);
15125       } else {
15126         SDValue Trunc =
15127             DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), ExtLoad);
15128         CombineTo(LN0, Trunc, ExtLoad.getValue(1));
15129       }
15130       return SDValue(N, 0); // Return N so it doesn't get rechecked!
15131     }
15132   }
15133 
15134   // fold (aext (zextload x)) -> (aext (truncate (zextload x)))
15135   // fold (aext (sextload x)) -> (aext (truncate (sextload x)))
15136   // fold (aext ( extload x)) -> (aext (truncate (extload  x)))
15137   if (N0.getOpcode() == ISD::LOAD && !ISD::isNON_EXTLoad(N0.getNode()) &&
15138       ISD::isUNINDEXEDLoad(N0.getNode()) && N0.hasOneUse()) {
15139     LoadSDNode *LN0 = cast<LoadSDNode>(N0);
15140     ISD::LoadExtType ExtType = LN0->getExtensionType();
15141     EVT MemVT = LN0->getMemoryVT();
15142     if (!LegalOperations || TLI.isLoadExtLegal(ExtType, VT, MemVT)) {
15143       SDValue ExtLoad =
15144           DAG.getExtLoad(ExtType, DL, VT, LN0->getChain(), LN0->getBasePtr(),
15145                          MemVT, LN0->getMemOperand());
15146       CombineTo(N, ExtLoad);
15147       DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
15148       recursivelyDeleteUnusedNodes(LN0);
15149       return SDValue(N, 0);   // Return N so it doesn't get rechecked!
15150     }
15151   }
15152 
15153   if (N0.getOpcode() == ISD::SETCC) {
15154     // Propagate fast-math-flags.
15155     SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags());
15156 
15157     // For vectors:
15158     // aext(setcc) -> vsetcc
15159     // aext(setcc) -> truncate(vsetcc)
15160     // aext(setcc) -> aext(vsetcc)
15161     // Only do this before legalize for now.
15162     if (VT.isVector() && !LegalOperations) {
15163       EVT N00VT = N0.getOperand(0).getValueType();
15164       if (getSetCCResultType(N00VT) == N0.getValueType())
15165         return SDValue();
15166 
15167       // We know that the # elements of the results is the same as the
15168       // # elements of the compare (and the # elements of the compare result
15169       // for that matter).  Check to see that they are the same size.  If so,
15170       // we know that the element size of the sext'd result matches the
15171       // element size of the compare operands.
15172       if (VT.getSizeInBits() == N00VT.getSizeInBits())
15173         return DAG.getSetCC(DL, VT, N0.getOperand(0), N0.getOperand(1),
15174                             cast<CondCodeSDNode>(N0.getOperand(2))->get());
15175 
15176       // If the desired elements are smaller or larger than the source
15177       // elements we can use a matching integer vector type and then
15178       // truncate/any extend
15179       EVT MatchingVectorType = N00VT.changeVectorElementTypeToInteger();
15180       SDValue VsetCC = DAG.getSetCC(
15181           DL, MatchingVectorType, N0.getOperand(0), N0.getOperand(1),
15182           cast<CondCodeSDNode>(N0.getOperand(2))->get());
15183       return DAG.getAnyExtOrTrunc(VsetCC, DL, VT);
15184     }
15185 
15186     // aext(setcc x,y,cc) -> select_cc x, y, 1, 0, cc
15187     if (SDValue SCC = SimplifySelectCC(
15188             DL, N0.getOperand(0), N0.getOperand(1), DAG.getConstant(1, DL, VT),
15189             DAG.getConstant(0, DL, VT),
15190             cast<CondCodeSDNode>(N0.getOperand(2))->get(), true))
15191       return SCC;
15192   }
15193 
15194   if (SDValue NewCtPop = widenCtPop(N, DAG, DL))
15195     return NewCtPop;
15196 
15197   if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG, DL, Level))
15198     return Res;
15199 
15200   return SDValue();
15201 }
15202 
visitAssertExt(SDNode * N)15203 SDValue DAGCombiner::visitAssertExt(SDNode *N) {
15204   unsigned Opcode = N->getOpcode();
15205   SDValue N0 = N->getOperand(0);
15206   SDValue N1 = N->getOperand(1);
15207   EVT AssertVT = cast<VTSDNode>(N1)->getVT();
15208 
15209   // fold (assert?ext (assert?ext x, vt), vt) -> (assert?ext x, vt)
15210   if (N0.getOpcode() == Opcode &&
15211       AssertVT == cast<VTSDNode>(N0.getOperand(1))->getVT())
15212     return N0;
15213 
15214   if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() &&
15215       N0.getOperand(0).getOpcode() == Opcode) {
15216     // We have an assert, truncate, assert sandwich. Make one stronger assert
15217     // by asserting on the smallest asserted type to the larger source type.
15218     // This eliminates the later assert:
15219     // assert (trunc (assert X, i8) to iN), i1 --> trunc (assert X, i1) to iN
15220     // assert (trunc (assert X, i1) to iN), i8 --> trunc (assert X, i1) to iN
15221     SDLoc DL(N);
15222     SDValue BigA = N0.getOperand(0);
15223     EVT BigA_AssertVT = cast<VTSDNode>(BigA.getOperand(1))->getVT();
15224     EVT MinAssertVT = AssertVT.bitsLT(BigA_AssertVT) ? AssertVT : BigA_AssertVT;
15225     SDValue MinAssertVTVal = DAG.getValueType(MinAssertVT);
15226     SDValue NewAssert = DAG.getNode(Opcode, DL, BigA.getValueType(),
15227                                     BigA.getOperand(0), MinAssertVTVal);
15228     return DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), NewAssert);
15229   }
15230 
15231   // If we have (AssertZext (truncate (AssertSext X, iX)), iY) and Y is smaller
15232   // than X. Just move the AssertZext in front of the truncate and drop the
15233   // AssertSExt.
15234   if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() &&
15235       N0.getOperand(0).getOpcode() == ISD::AssertSext &&
15236       Opcode == ISD::AssertZext) {
15237     SDValue BigA = N0.getOperand(0);
15238     EVT BigA_AssertVT = cast<VTSDNode>(BigA.getOperand(1))->getVT();
15239     if (AssertVT.bitsLT(BigA_AssertVT)) {
15240       SDLoc DL(N);
15241       SDValue NewAssert = DAG.getNode(Opcode, DL, BigA.getValueType(),
15242                                       BigA.getOperand(0), N1);
15243       return DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), NewAssert);
15244     }
15245   }
15246 
15247   // If we have (AssertZext (and (AssertSext X, iX), M), iY) and Y is smaller
15248   // than X, and the And doesn't change the lower iX bits, we can move the
15249   // AssertZext in front of the And and drop the AssertSext.
15250   if (Opcode == ISD::AssertZext && N0.getOpcode() == ISD::AND &&
15251       N0.hasOneUse() && N0.getOperand(0).getOpcode() == ISD::AssertSext &&
15252       isa<ConstantSDNode>(N0.getOperand(1))) {
15253     SDValue BigA = N0.getOperand(0);
15254     EVT BigA_AssertVT = cast<VTSDNode>(BigA.getOperand(1))->getVT();
15255     const APInt &Mask = N0.getConstantOperandAPInt(1);
15256     if (AssertVT.bitsLT(BigA_AssertVT) &&
15257         Mask.countr_one() >= BigA_AssertVT.getScalarSizeInBits()) {
15258       SDLoc DL(N);
15259       SDValue NewAssert =
15260           DAG.getNode(Opcode, DL, N->getValueType(0), BigA.getOperand(0), N1);
15261       return DAG.getNode(ISD::AND, DL, N->getValueType(0), NewAssert,
15262                          N0.getOperand(1));
15263     }
15264   }
15265 
15266   return SDValue();
15267 }
15268 
visitAssertAlign(SDNode * N)15269 SDValue DAGCombiner::visitAssertAlign(SDNode *N) {
15270   SDLoc DL(N);
15271 
15272   Align AL = cast<AssertAlignSDNode>(N)->getAlign();
15273   SDValue N0 = N->getOperand(0);
15274 
15275   // Fold (assertalign (assertalign x, AL0), AL1) ->
15276   // (assertalign x, max(AL0, AL1))
15277   if (auto *AAN = dyn_cast<AssertAlignSDNode>(N0))
15278     return DAG.getAssertAlign(DL, N0.getOperand(0),
15279                               std::max(AL, AAN->getAlign()));
15280 
15281   // In rare cases, there are trivial arithmetic ops in source operands. Sink
15282   // this assert down to source operands so that those arithmetic ops could be
15283   // exposed to the DAG combining.
15284   switch (N0.getOpcode()) {
15285   default:
15286     break;
15287   case ISD::ADD:
15288   case ISD::PTRADD:
15289   case ISD::SUB: {
15290     unsigned AlignShift = Log2(AL);
15291     SDValue LHS = N0.getOperand(0);
15292     SDValue RHS = N0.getOperand(1);
15293     unsigned LHSAlignShift = DAG.computeKnownBits(LHS).countMinTrailingZeros();
15294     unsigned RHSAlignShift = DAG.computeKnownBits(RHS).countMinTrailingZeros();
15295     if (LHSAlignShift >= AlignShift || RHSAlignShift >= AlignShift) {
15296       if (LHSAlignShift < AlignShift)
15297         LHS = DAG.getAssertAlign(DL, LHS, AL);
15298       if (RHSAlignShift < AlignShift)
15299         RHS = DAG.getAssertAlign(DL, RHS, AL);
15300       return DAG.getNode(N0.getOpcode(), DL, N0.getValueType(), LHS, RHS);
15301     }
15302     break;
15303   }
15304   }
15305 
15306   return SDValue();
15307 }
15308 
15309 /// If the result of a load is shifted/masked/truncated to an effectively
15310 /// narrower type, try to transform the load to a narrower type and/or
15311 /// use an extending load.
reduceLoadWidth(SDNode * N)15312 SDValue DAGCombiner::reduceLoadWidth(SDNode *N) {
15313   unsigned Opc = N->getOpcode();
15314 
15315   ISD::LoadExtType ExtType = ISD::NON_EXTLOAD;
15316   SDValue N0 = N->getOperand(0);
15317   EVT VT = N->getValueType(0);
15318   EVT ExtVT = VT;
15319 
15320   // This transformation isn't valid for vector loads.
15321   if (VT.isVector())
15322     return SDValue();
15323 
15324   // The ShAmt variable is used to indicate that we've consumed a right
15325   // shift. I.e. we want to narrow the width of the load by skipping to load the
15326   // ShAmt least significant bits.
15327   unsigned ShAmt = 0;
15328   // A special case is when the least significant bits from the load are masked
15329   // away, but using an AND rather than a right shift. HasShiftedOffset is used
15330   // to indicate that the narrowed load should be left-shifted ShAmt bits to get
15331   // the result.
15332   unsigned ShiftedOffset = 0;
15333   // Special case: SIGN_EXTEND_INREG is basically truncating to ExtVT then
15334   // extended to VT.
15335   if (Opc == ISD::SIGN_EXTEND_INREG) {
15336     ExtType = ISD::SEXTLOAD;
15337     ExtVT = cast<VTSDNode>(N->getOperand(1))->getVT();
15338   } else if (Opc == ISD::SRL || Opc == ISD::SRA) {
15339     // Another special-case: SRL/SRA is basically zero/sign-extending a narrower
15340     // value, or it may be shifting a higher subword, half or byte into the
15341     // lowest bits.
15342 
15343     // Only handle shift with constant shift amount, and the shiftee must be a
15344     // load.
15345     auto *LN = dyn_cast<LoadSDNode>(N0);
15346     auto *N1C = dyn_cast<ConstantSDNode>(N->getOperand(1));
15347     if (!N1C || !LN)
15348       return SDValue();
15349     // If the shift amount is larger than the memory type then we're not
15350     // accessing any of the loaded bytes.
15351     ShAmt = N1C->getZExtValue();
15352     uint64_t MemoryWidth = LN->getMemoryVT().getScalarSizeInBits();
15353     if (MemoryWidth <= ShAmt)
15354       return SDValue();
15355     // Attempt to fold away the SRL by using ZEXTLOAD and SRA by using SEXTLOAD.
15356     ExtType = Opc == ISD::SRL ? ISD::ZEXTLOAD : ISD::SEXTLOAD;
15357     ExtVT = EVT::getIntegerVT(*DAG.getContext(), MemoryWidth - ShAmt);
15358     // If original load is a SEXTLOAD then we can't simply replace it by a
15359     // ZEXTLOAD (we could potentially replace it by a more narrow SEXTLOAD
15360     // followed by a ZEXT, but that is not handled at the moment). Similarly if
15361     // the original load is a ZEXTLOAD and we want to use a SEXTLOAD.
15362     if ((LN->getExtensionType() == ISD::SEXTLOAD ||
15363          LN->getExtensionType() == ISD::ZEXTLOAD) &&
15364         LN->getExtensionType() != ExtType)
15365       return SDValue();
15366   } else if (Opc == ISD::AND) {
15367     // An AND with a constant mask is the same as a truncate + zero-extend.
15368     auto AndC = dyn_cast<ConstantSDNode>(N->getOperand(1));
15369     if (!AndC)
15370       return SDValue();
15371 
15372     const APInt &Mask = AndC->getAPIntValue();
15373     unsigned ActiveBits = 0;
15374     if (Mask.isMask()) {
15375       ActiveBits = Mask.countr_one();
15376     } else if (Mask.isShiftedMask(ShAmt, ActiveBits)) {
15377       ShiftedOffset = ShAmt;
15378     } else {
15379       return SDValue();
15380     }
15381 
15382     ExtType = ISD::ZEXTLOAD;
15383     ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
15384   }
15385 
15386   // In case Opc==SRL we've already prepared ExtVT/ExtType/ShAmt based on doing
15387   // a right shift. Here we redo some of those checks, to possibly adjust the
15388   // ExtVT even further based on "a masking AND". We could also end up here for
15389   // other reasons (e.g. based on Opc==TRUNCATE) and that is why some checks
15390   // need to be done here as well.
15391   if (Opc == ISD::SRL || N0.getOpcode() == ISD::SRL) {
15392     SDValue SRL = Opc == ISD::SRL ? SDValue(N, 0) : N0;
15393     // Bail out when the SRL has more than one use. This is done for historical
15394     // (undocumented) reasons. Maybe intent was to guard the AND-masking below
15395     // check below? And maybe it could be non-profitable to do the transform in
15396     // case the SRL has multiple uses and we get here with Opc!=ISD::SRL?
15397     // FIXME: Can't we just skip this check for the Opc==ISD::SRL case.
15398     if (!SRL.hasOneUse())
15399       return SDValue();
15400 
15401     // Only handle shift with constant shift amount, and the shiftee must be a
15402     // load.
15403     auto *LN = dyn_cast<LoadSDNode>(SRL.getOperand(0));
15404     auto *SRL1C = dyn_cast<ConstantSDNode>(SRL.getOperand(1));
15405     if (!SRL1C || !LN)
15406       return SDValue();
15407 
15408     // If the shift amount is larger than the input type then we're not
15409     // accessing any of the loaded bytes.  If the load was a zextload/extload
15410     // then the result of the shift+trunc is zero/undef (handled elsewhere).
15411     ShAmt = SRL1C->getZExtValue();
15412     uint64_t MemoryWidth = LN->getMemoryVT().getSizeInBits();
15413     if (ShAmt >= MemoryWidth)
15414       return SDValue();
15415 
15416     // Because a SRL must be assumed to *need* to zero-extend the high bits
15417     // (as opposed to anyext the high bits), we can't combine the zextload
15418     // lowering of SRL and an sextload.
15419     if (LN->getExtensionType() == ISD::SEXTLOAD)
15420       return SDValue();
15421 
15422     // Avoid reading outside the memory accessed by the original load (could
15423     // happened if we only adjust the load base pointer by ShAmt). Instead we
15424     // try to narrow the load even further. The typical scenario here is:
15425     //   (i64 (truncate (i96 (srl (load x), 64)))) ->
15426     //     (i64 (truncate (i96 (zextload (load i32 + offset) from i32))))
15427     if (ExtVT.getScalarSizeInBits() > MemoryWidth - ShAmt) {
15428       // Don't replace sextload by zextload.
15429       if (ExtType == ISD::SEXTLOAD)
15430         return SDValue();
15431       // Narrow the load.
15432       ExtType = ISD::ZEXTLOAD;
15433       ExtVT = EVT::getIntegerVT(*DAG.getContext(), MemoryWidth - ShAmt);
15434     }
15435 
15436     // If the SRL is only used by a masking AND, we may be able to adjust
15437     // the ExtVT to make the AND redundant.
15438     SDNode *Mask = *(SRL->user_begin());
15439     if (SRL.hasOneUse() && Mask->getOpcode() == ISD::AND &&
15440         isa<ConstantSDNode>(Mask->getOperand(1))) {
15441       unsigned Offset, ActiveBits;
15442       const APInt& ShiftMask = Mask->getConstantOperandAPInt(1);
15443       if (ShiftMask.isMask()) {
15444         EVT MaskedVT =
15445             EVT::getIntegerVT(*DAG.getContext(), ShiftMask.countr_one());
15446         // If the mask is smaller, recompute the type.
15447         if ((ExtVT.getScalarSizeInBits() > MaskedVT.getScalarSizeInBits()) &&
15448             TLI.isLoadExtLegal(ExtType, SRL.getValueType(), MaskedVT))
15449           ExtVT = MaskedVT;
15450       } else if (ExtType == ISD::ZEXTLOAD &&
15451                  ShiftMask.isShiftedMask(Offset, ActiveBits) &&
15452                  (Offset + ShAmt) < VT.getScalarSizeInBits()) {
15453         EVT MaskedVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
15454         // If the mask is shifted we can use a narrower load and a shl to insert
15455         // the trailing zeros.
15456         if (((Offset + ActiveBits) <= ExtVT.getScalarSizeInBits()) &&
15457             TLI.isLoadExtLegal(ExtType, SRL.getValueType(), MaskedVT)) {
15458           ExtVT = MaskedVT;
15459           ShAmt = Offset + ShAmt;
15460           ShiftedOffset = Offset;
15461         }
15462       }
15463     }
15464 
15465     N0 = SRL.getOperand(0);
15466   }
15467 
15468   // If the load is shifted left (and the result isn't shifted back right), we
15469   // can fold a truncate through the shift. The typical scenario is that N
15470   // points at a TRUNCATE here so the attempted fold is:
15471   //   (truncate (shl (load x), c))) -> (shl (narrow load x), c)
15472   // ShLeftAmt will indicate how much a narrowed load should be shifted left.
15473   unsigned ShLeftAmt = 0;
15474   if (ShAmt == 0 && N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
15475       ExtVT == VT && TLI.isNarrowingProfitable(N, N0.getValueType(), VT)) {
15476     if (ConstantSDNode *N01 = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
15477       ShLeftAmt = N01->getZExtValue();
15478       N0 = N0.getOperand(0);
15479     }
15480   }
15481 
15482   // If we haven't found a load, we can't narrow it.
15483   if (!isa<LoadSDNode>(N0))
15484     return SDValue();
15485 
15486   LoadSDNode *LN0 = cast<LoadSDNode>(N0);
15487   // Reducing the width of a volatile load is illegal.  For atomics, we may be
15488   // able to reduce the width provided we never widen again. (see D66309)
15489   if (!LN0->isSimple() ||
15490       !isLegalNarrowLdSt(LN0, ExtType, ExtVT, ShAmt))
15491     return SDValue();
15492 
15493   auto AdjustBigEndianShift = [&](unsigned ShAmt) {
15494     unsigned LVTStoreBits =
15495         LN0->getMemoryVT().getStoreSizeInBits().getFixedValue();
15496     unsigned EVTStoreBits = ExtVT.getStoreSizeInBits().getFixedValue();
15497     return LVTStoreBits - EVTStoreBits - ShAmt;
15498   };
15499 
15500   // We need to adjust the pointer to the load by ShAmt bits in order to load
15501   // the correct bytes.
15502   unsigned PtrAdjustmentInBits =
15503       DAG.getDataLayout().isBigEndian() ? AdjustBigEndianShift(ShAmt) : ShAmt;
15504 
15505   uint64_t PtrOff = PtrAdjustmentInBits / 8;
15506   SDLoc DL(LN0);
15507   // The original load itself didn't wrap, so an offset within it doesn't.
15508   SDValue NewPtr =
15509       DAG.getMemBasePlusOffset(LN0->getBasePtr(), TypeSize::getFixed(PtrOff),
15510                                DL, SDNodeFlags::NoUnsignedWrap);
15511   AddToWorklist(NewPtr.getNode());
15512 
15513   SDValue Load;
15514   if (ExtType == ISD::NON_EXTLOAD) {
15515     const MDNode *OldRanges = LN0->getRanges();
15516     const MDNode *NewRanges = nullptr;
15517     // If LSBs are loaded and the truncated ConstantRange for the OldRanges
15518     // metadata is not the full-set for the new width then create a NewRanges
15519     // metadata for the truncated load
15520     if (ShAmt == 0 && OldRanges) {
15521       ConstantRange CR = getConstantRangeFromMetadata(*OldRanges);
15522       unsigned BitSize = VT.getScalarSizeInBits();
15523 
15524       // It is possible for an 8-bit extending load with 8-bit range
15525       // metadata to be narrowed to an 8-bit load.  This guard is necessary to
15526       // ensure that truncation is strictly smaller.
15527       if (CR.getBitWidth() > BitSize) {
15528         ConstantRange TruncatedCR = CR.truncate(BitSize);
15529         if (!TruncatedCR.isFullSet()) {
15530           Metadata *Bounds[2] = {
15531               ConstantAsMetadata::get(
15532                   ConstantInt::get(*DAG.getContext(), TruncatedCR.getLower())),
15533               ConstantAsMetadata::get(
15534                   ConstantInt::get(*DAG.getContext(), TruncatedCR.getUpper()))};
15535           NewRanges = MDNode::get(*DAG.getContext(), Bounds);
15536         }
15537       } else if (CR.getBitWidth() == BitSize)
15538         NewRanges = OldRanges;
15539     }
15540     Load = DAG.getLoad(VT, DL, LN0->getChain(), NewPtr,
15541                        LN0->getPointerInfo().getWithOffset(PtrOff),
15542                        LN0->getBaseAlign(), LN0->getMemOperand()->getFlags(),
15543                        LN0->getAAInfo(), NewRanges);
15544   } else
15545     Load = DAG.getExtLoad(ExtType, DL, VT, LN0->getChain(), NewPtr,
15546                           LN0->getPointerInfo().getWithOffset(PtrOff), ExtVT,
15547                           LN0->getBaseAlign(), LN0->getMemOperand()->getFlags(),
15548                           LN0->getAAInfo());
15549 
15550   // Replace the old load's chain with the new load's chain.
15551   WorklistRemover DeadNodes(*this);
15552   DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), Load.getValue(1));
15553 
15554   // Shift the result left, if we've swallowed a left shift.
15555   SDValue Result = Load;
15556   if (ShLeftAmt != 0) {
15557     // If the shift amount is as large as the result size (but, presumably,
15558     // no larger than the source) then the useful bits of the result are
15559     // zero; we can't simply return the shortened shift, because the result
15560     // of that operation is undefined.
15561     if (ShLeftAmt >= VT.getScalarSizeInBits())
15562       Result = DAG.getConstant(0, DL, VT);
15563     else
15564       Result = DAG.getNode(ISD::SHL, DL, VT, Result,
15565                            DAG.getShiftAmountConstant(ShLeftAmt, VT, DL));
15566   }
15567 
15568   if (ShiftedOffset != 0) {
15569     // We're using a shifted mask, so the load now has an offset. This means
15570     // that data has been loaded into the lower bytes than it would have been
15571     // before, so we need to shl the loaded data into the correct position in the
15572     // register.
15573     SDValue ShiftC = DAG.getConstant(ShiftedOffset, DL, VT);
15574     Result = DAG.getNode(ISD::SHL, DL, VT, Result, ShiftC);
15575     DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result);
15576   }
15577 
15578   // Return the new loaded value.
15579   return Result;
15580 }
15581 
visitSIGN_EXTEND_INREG(SDNode * N)15582 SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) {
15583   SDValue N0 = N->getOperand(0);
15584   SDValue N1 = N->getOperand(1);
15585   EVT VT = N->getValueType(0);
15586   EVT ExtVT = cast<VTSDNode>(N1)->getVT();
15587   unsigned VTBits = VT.getScalarSizeInBits();
15588   unsigned ExtVTBits = ExtVT.getScalarSizeInBits();
15589   SDLoc DL(N);
15590 
15591   // sext_vector_inreg(undef) = 0 because the top bit will all be the same.
15592   if (N0.isUndef())
15593     return DAG.getConstant(0, DL, VT);
15594 
15595   // fold (sext_in_reg c1) -> c1
15596   if (SDValue C =
15597           DAG.FoldConstantArithmetic(ISD::SIGN_EXTEND_INREG, DL, VT, {N0, N1}))
15598     return C;
15599 
15600   // If the input is already sign extended, just drop the extension.
15601   if (ExtVTBits >= DAG.ComputeMaxSignificantBits(N0))
15602     return N0;
15603 
15604   // fold (sext_in_reg (sext_in_reg x, VT2), VT1) -> (sext_in_reg x, minVT) pt2
15605   if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG &&
15606       ExtVT.bitsLT(cast<VTSDNode>(N0.getOperand(1))->getVT()))
15607     return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, N0.getOperand(0), N1);
15608 
15609   // fold (sext_in_reg (sext x)) -> (sext x)
15610   // fold (sext_in_reg (aext x)) -> (sext x)
15611   // if x is small enough or if we know that x has more than 1 sign bit and the
15612   // sign_extend_inreg is extending from one of them.
15613   if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND) {
15614     SDValue N00 = N0.getOperand(0);
15615     unsigned N00Bits = N00.getScalarValueSizeInBits();
15616     if ((N00Bits <= ExtVTBits ||
15617          DAG.ComputeMaxSignificantBits(N00) <= ExtVTBits) &&
15618         (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND, VT)))
15619       return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, N00);
15620   }
15621 
15622   // fold (sext_in_reg (*_extend_vector_inreg x)) -> (sext_vector_inreg x)
15623   // if x is small enough or if we know that x has more than 1 sign bit and the
15624   // sign_extend_inreg is extending from one of them.
15625   if (ISD::isExtVecInRegOpcode(N0.getOpcode())) {
15626     SDValue N00 = N0.getOperand(0);
15627     unsigned N00Bits = N00.getScalarValueSizeInBits();
15628     bool IsZext = N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG;
15629     if ((N00Bits == ExtVTBits ||
15630          (!IsZext && (N00Bits < ExtVTBits ||
15631                       DAG.ComputeMaxSignificantBits(N00) <= ExtVTBits))) &&
15632         (!LegalOperations ||
15633          TLI.isOperationLegal(ISD::SIGN_EXTEND_VECTOR_INREG, VT)))
15634       return DAG.getNode(ISD::SIGN_EXTEND_VECTOR_INREG, DL, VT, N00);
15635   }
15636 
15637   // fold (sext_in_reg (zext x)) -> (sext x)
15638   // iff we are extending the source sign bit.
15639   if (N0.getOpcode() == ISD::ZERO_EXTEND) {
15640     SDValue N00 = N0.getOperand(0);
15641     if (N00.getScalarValueSizeInBits() == ExtVTBits &&
15642         (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND, VT)))
15643       return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, N00);
15644   }
15645 
15646   // fold (sext_in_reg x) -> (zext_in_reg x) if the sign bit is known zero.
15647   if (DAG.MaskedValueIsZero(N0, APInt::getOneBitSet(VTBits, ExtVTBits - 1)))
15648     return DAG.getZeroExtendInReg(N0, DL, ExtVT);
15649 
15650   // fold operands of sext_in_reg based on knowledge that the top bits are not
15651   // demanded.
15652   if (SimplifyDemandedBits(SDValue(N, 0)))
15653     return SDValue(N, 0);
15654 
15655   // fold (sext_in_reg (load x)) -> (smaller sextload x)
15656   // fold (sext_in_reg (srl (load x), c)) -> (smaller sextload (x+c/evtbits))
15657   if (SDValue NarrowLoad = reduceLoadWidth(N))
15658     return NarrowLoad;
15659 
15660   // fold (sext_in_reg (srl X, 24), i8) -> (sra X, 24)
15661   // fold (sext_in_reg (srl X, 23), i8) -> (sra X, 23) iff possible.
15662   // We already fold "(sext_in_reg (srl X, 25), i8) -> srl X, 25" above.
15663   if (N0.getOpcode() == ISD::SRL) {
15664     if (auto *ShAmt = dyn_cast<ConstantSDNode>(N0.getOperand(1)))
15665       if (ShAmt->getAPIntValue().ule(VTBits - ExtVTBits)) {
15666         // We can turn this into an SRA iff the input to the SRL is already sign
15667         // extended enough.
15668         unsigned InSignBits = DAG.ComputeNumSignBits(N0.getOperand(0));
15669         if (((VTBits - ExtVTBits) - ShAmt->getZExtValue()) < InSignBits)
15670           return DAG.getNode(ISD::SRA, DL, VT, N0.getOperand(0),
15671                              N0.getOperand(1));
15672       }
15673   }
15674 
15675   // fold (sext_inreg (extload x)) -> (sextload x)
15676   // If sextload is not supported by target, we can only do the combine when
15677   // load has one use. Doing otherwise can block folding the extload with other
15678   // extends that the target does support.
15679   if (ISD::isEXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) &&
15680       ExtVT == cast<LoadSDNode>(N0)->getMemoryVT() &&
15681       ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple() &&
15682         N0.hasOneUse()) ||
15683        TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT))) {
15684     auto *LN0 = cast<LoadSDNode>(N0);
15685     SDValue ExtLoad =
15686         DAG.getExtLoad(ISD::SEXTLOAD, DL, VT, LN0->getChain(),
15687                        LN0->getBasePtr(), ExtVT, LN0->getMemOperand());
15688     CombineTo(N, ExtLoad);
15689     CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
15690     AddToWorklist(ExtLoad.getNode());
15691     return SDValue(N, 0); // Return N so it doesn't get rechecked!
15692   }
15693 
15694   // fold (sext_inreg (zextload x)) -> (sextload x) iff load has one use
15695   if (ISD::isZEXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) &&
15696       N0.hasOneUse() && ExtVT == cast<LoadSDNode>(N0)->getMemoryVT() &&
15697       ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple()) &&
15698        TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT))) {
15699     auto *LN0 = cast<LoadSDNode>(N0);
15700     SDValue ExtLoad =
15701         DAG.getExtLoad(ISD::SEXTLOAD, DL, VT, LN0->getChain(),
15702                        LN0->getBasePtr(), ExtVT, LN0->getMemOperand());
15703     CombineTo(N, ExtLoad);
15704     CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
15705     return SDValue(N, 0); // Return N so it doesn't get rechecked!
15706   }
15707 
15708   // fold (sext_inreg (masked_load x)) -> (sext_masked_load x)
15709   // ignore it if the masked load is already sign extended
15710   if (MaskedLoadSDNode *Ld = dyn_cast<MaskedLoadSDNode>(N0)) {
15711     if (ExtVT == Ld->getMemoryVT() && N0.hasOneUse() &&
15712         Ld->getExtensionType() != ISD::LoadExtType::NON_EXTLOAD &&
15713         TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT)) {
15714       SDValue ExtMaskedLoad = DAG.getMaskedLoad(
15715           VT, DL, Ld->getChain(), Ld->getBasePtr(), Ld->getOffset(),
15716           Ld->getMask(), Ld->getPassThru(), ExtVT, Ld->getMemOperand(),
15717           Ld->getAddressingMode(), ISD::SEXTLOAD, Ld->isExpandingLoad());
15718       CombineTo(N, ExtMaskedLoad);
15719       CombineTo(N0.getNode(), ExtMaskedLoad, ExtMaskedLoad.getValue(1));
15720       return SDValue(N, 0); // Return N so it doesn't get rechecked!
15721     }
15722   }
15723 
15724   // fold (sext_inreg (masked_gather x)) -> (sext_masked_gather x)
15725   if (auto *GN0 = dyn_cast<MaskedGatherSDNode>(N0)) {
15726     if (SDValue(GN0, 0).hasOneUse() && ExtVT == GN0->getMemoryVT() &&
15727         TLI.isVectorLoadExtDesirable(SDValue(SDValue(GN0, 0)))) {
15728       SDValue Ops[] = {GN0->getChain(),   GN0->getPassThru(), GN0->getMask(),
15729                        GN0->getBasePtr(), GN0->getIndex(),    GN0->getScale()};
15730 
15731       SDValue ExtLoad = DAG.getMaskedGather(
15732           DAG.getVTList(VT, MVT::Other), ExtVT, DL, Ops, GN0->getMemOperand(),
15733           GN0->getIndexType(), ISD::SEXTLOAD);
15734 
15735       CombineTo(N, ExtLoad);
15736       CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
15737       AddToWorklist(ExtLoad.getNode());
15738       return SDValue(N, 0); // Return N so it doesn't get rechecked!
15739     }
15740   }
15741 
15742   // Form (sext_inreg (bswap >> 16)) or (sext_inreg (rotl (bswap) 16))
15743   if (ExtVTBits <= 16 && N0.getOpcode() == ISD::OR) {
15744     if (SDValue BSwap = MatchBSwapHWordLow(N0.getNode(), N0.getOperand(0),
15745                                            N0.getOperand(1), false))
15746       return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, BSwap, N1);
15747   }
15748 
15749   // Fold (iM_signext_inreg
15750   //        (extract_subvector (zext|anyext|sext iN_v to _) _)
15751   //        from iN)
15752   //      -> (extract_subvector (signext iN_v to iM))
15753   if (N0.getOpcode() == ISD::EXTRACT_SUBVECTOR && N0.hasOneUse() &&
15754       ISD::isExtOpcode(N0.getOperand(0).getOpcode())) {
15755     SDValue InnerExt = N0.getOperand(0);
15756     EVT InnerExtVT = InnerExt->getValueType(0);
15757     SDValue Extendee = InnerExt->getOperand(0);
15758 
15759     if (ExtVTBits == Extendee.getValueType().getScalarSizeInBits() &&
15760         (!LegalOperations ||
15761          TLI.isOperationLegal(ISD::SIGN_EXTEND, InnerExtVT))) {
15762       SDValue SignExtExtendee =
15763           DAG.getNode(ISD::SIGN_EXTEND, DL, InnerExtVT, Extendee);
15764       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, SignExtExtendee,
15765                          N0.getOperand(1));
15766     }
15767   }
15768 
15769   return SDValue();
15770 }
15771 
foldExtendVectorInregToExtendOfSubvector(SDNode * N,const SDLoc & DL,const TargetLowering & TLI,SelectionDAG & DAG,bool LegalOperations)15772 static SDValue foldExtendVectorInregToExtendOfSubvector(
15773     SDNode *N, const SDLoc &DL, const TargetLowering &TLI, SelectionDAG &DAG,
15774     bool LegalOperations) {
15775   unsigned InregOpcode = N->getOpcode();
15776   unsigned Opcode = DAG.getOpcode_EXTEND(InregOpcode);
15777 
15778   SDValue Src = N->getOperand(0);
15779   EVT VT = N->getValueType(0);
15780   EVT SrcVT = EVT::getVectorVT(*DAG.getContext(),
15781                                Src.getValueType().getVectorElementType(),
15782                                VT.getVectorElementCount());
15783 
15784   assert(ISD::isExtVecInRegOpcode(InregOpcode) &&
15785          "Expected EXTEND_VECTOR_INREG dag node in input!");
15786 
15787   // Profitability check: our operand must be an one-use CONCAT_VECTORS.
15788   // FIXME: one-use check may be overly restrictive
15789   if (!Src.hasOneUse() || Src.getOpcode() != ISD::CONCAT_VECTORS)
15790     return SDValue();
15791 
15792   // Profitability check: we must be extending exactly one of it's operands.
15793   // FIXME: this is probably overly restrictive.
15794   Src = Src.getOperand(0);
15795   if (Src.getValueType() != SrcVT)
15796     return SDValue();
15797 
15798   if (LegalOperations && !TLI.isOperationLegal(Opcode, VT))
15799     return SDValue();
15800 
15801   return DAG.getNode(Opcode, DL, VT, Src);
15802 }
15803 
visitEXTEND_VECTOR_INREG(SDNode * N)15804 SDValue DAGCombiner::visitEXTEND_VECTOR_INREG(SDNode *N) {
15805   SDValue N0 = N->getOperand(0);
15806   EVT VT = N->getValueType(0);
15807   SDLoc DL(N);
15808 
15809   if (N0.isUndef()) {
15810     // aext_vector_inreg(undef) = undef because the top bits are undefined.
15811     // {s/z}ext_vector_inreg(undef) = 0 because the top bits must be the same.
15812     return N->getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG
15813                ? DAG.getUNDEF(VT)
15814                : DAG.getConstant(0, DL, VT);
15815   }
15816 
15817   if (SDValue Res = tryToFoldExtendOfConstant(N, DL, TLI, DAG, LegalTypes))
15818     return Res;
15819 
15820   if (SimplifyDemandedVectorElts(SDValue(N, 0)))
15821     return SDValue(N, 0);
15822 
15823   if (SDValue R = foldExtendVectorInregToExtendOfSubvector(N, DL, TLI, DAG,
15824                                                            LegalOperations))
15825     return R;
15826 
15827   return SDValue();
15828 }
15829 
visitTRUNCATE_USAT_U(SDNode * N)15830 SDValue DAGCombiner::visitTRUNCATE_USAT_U(SDNode *N) {
15831   EVT VT = N->getValueType(0);
15832   SDValue N0 = N->getOperand(0);
15833 
15834   SDValue FPVal;
15835   if (sd_match(N0, m_FPToUI(m_Value(FPVal))) &&
15836       DAG.getTargetLoweringInfo().shouldConvertFpToSat(
15837           ISD::FP_TO_UINT_SAT, FPVal.getValueType(), VT))
15838     return DAG.getNode(ISD::FP_TO_UINT_SAT, SDLoc(N0), VT, FPVal,
15839                        DAG.getValueType(VT.getScalarType()));
15840 
15841   return SDValue();
15842 }
15843 
15844 /// Detect patterns of truncation with unsigned saturation:
15845 ///
15846 /// (truncate (umin (x, unsigned_max_of_dest_type)) to dest_type).
15847 /// Return the source value x to be truncated or SDValue() if the pattern was
15848 /// not matched.
15849 ///
detectUSatUPattern(SDValue In,EVT VT)15850 static SDValue detectUSatUPattern(SDValue In, EVT VT) {
15851   unsigned NumDstBits = VT.getScalarSizeInBits();
15852   unsigned NumSrcBits = In.getScalarValueSizeInBits();
15853   // Saturation with truncation. We truncate from InVT to VT.
15854   assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
15855 
15856   SDValue Min;
15857   APInt UnsignedMax = APInt::getMaxValue(NumDstBits).zext(NumSrcBits);
15858   if (sd_match(In, m_UMin(m_Value(Min), m_SpecificInt(UnsignedMax))))
15859     return Min;
15860 
15861   return SDValue();
15862 }
15863 
15864 /// Detect patterns of truncation with signed saturation:
15865 /// (truncate (smin (smax (x, signed_min_of_dest_type),
15866 ///                  signed_max_of_dest_type)) to dest_type)
15867 /// or:
15868 /// (truncate (smax (smin (x, signed_max_of_dest_type),
15869 ///                  signed_min_of_dest_type)) to dest_type).
15870 ///
15871 /// Return the source value to be truncated or SDValue() if the pattern was not
15872 /// matched.
detectSSatSPattern(SDValue In,EVT VT)15873 static SDValue detectSSatSPattern(SDValue In, EVT VT) {
15874   unsigned NumDstBits = VT.getScalarSizeInBits();
15875   unsigned NumSrcBits = In.getScalarValueSizeInBits();
15876   // Saturation with truncation. We truncate from InVT to VT.
15877   assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
15878 
15879   SDValue Val;
15880   APInt SignedMax = APInt::getSignedMaxValue(NumDstBits).sext(NumSrcBits);
15881   APInt SignedMin = APInt::getSignedMinValue(NumDstBits).sext(NumSrcBits);
15882 
15883   if (sd_match(In, m_SMin(m_SMax(m_Value(Val), m_SpecificInt(SignedMin)),
15884                           m_SpecificInt(SignedMax))))
15885     return Val;
15886 
15887   if (sd_match(In, m_SMax(m_SMin(m_Value(Val), m_SpecificInt(SignedMax)),
15888                           m_SpecificInt(SignedMin))))
15889     return Val;
15890 
15891   return SDValue();
15892 }
15893 
15894 /// Detect patterns of truncation with unsigned saturation:
detectSSatUPattern(SDValue In,EVT VT,SelectionDAG & DAG,const SDLoc & DL)15895 static SDValue detectSSatUPattern(SDValue In, EVT VT, SelectionDAG &DAG,
15896                                   const SDLoc &DL) {
15897   unsigned NumDstBits = VT.getScalarSizeInBits();
15898   unsigned NumSrcBits = In.getScalarValueSizeInBits();
15899   // Saturation with truncation. We truncate from InVT to VT.
15900   assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
15901 
15902   SDValue Val;
15903   APInt UnsignedMax = APInt::getMaxValue(NumDstBits).zext(NumSrcBits);
15904   // Min == 0, Max is unsigned max of destination type.
15905   if (sd_match(In, m_SMax(m_SMin(m_Value(Val), m_SpecificInt(UnsignedMax)),
15906                           m_Zero())))
15907     return Val;
15908 
15909   if (sd_match(In, m_SMin(m_SMax(m_Value(Val), m_Zero()),
15910                           m_SpecificInt(UnsignedMax))))
15911     return Val;
15912 
15913   if (sd_match(In, m_UMin(m_SMax(m_Value(Val), m_Zero()),
15914                           m_SpecificInt(UnsignedMax))))
15915     return Val;
15916 
15917   return SDValue();
15918 }
15919 
foldToSaturated(SDNode * N,EVT & VT,SDValue & Src,EVT & SrcVT,SDLoc & DL,const TargetLowering & TLI,SelectionDAG & DAG)15920 static SDValue foldToSaturated(SDNode *N, EVT &VT, SDValue &Src, EVT &SrcVT,
15921                                SDLoc &DL, const TargetLowering &TLI,
15922                                SelectionDAG &DAG) {
15923   auto AllowedTruncateSat = [&](unsigned Opc, EVT SrcVT, EVT VT) -> bool {
15924     return (TLI.isOperationLegalOrCustom(Opc, SrcVT) &&
15925             TLI.isTypeDesirableForOp(Opc, VT));
15926   };
15927 
15928   if (Src.getOpcode() == ISD::SMIN || Src.getOpcode() == ISD::SMAX) {
15929     if (AllowedTruncateSat(ISD::TRUNCATE_SSAT_S, SrcVT, VT))
15930       if (SDValue SSatVal = detectSSatSPattern(Src, VT))
15931         return DAG.getNode(ISD::TRUNCATE_SSAT_S, DL, VT, SSatVal);
15932     if (AllowedTruncateSat(ISD::TRUNCATE_SSAT_U, SrcVT, VT))
15933       if (SDValue SSatVal = detectSSatUPattern(Src, VT, DAG, DL))
15934         return DAG.getNode(ISD::TRUNCATE_SSAT_U, DL, VT, SSatVal);
15935   } else if (Src.getOpcode() == ISD::UMIN) {
15936     if (AllowedTruncateSat(ISD::TRUNCATE_SSAT_U, SrcVT, VT))
15937       if (SDValue SSatVal = detectSSatUPattern(Src, VT, DAG, DL))
15938         return DAG.getNode(ISD::TRUNCATE_SSAT_U, DL, VT, SSatVal);
15939     if (AllowedTruncateSat(ISD::TRUNCATE_USAT_U, SrcVT, VT))
15940       if (SDValue USatVal = detectUSatUPattern(Src, VT))
15941         return DAG.getNode(ISD::TRUNCATE_USAT_U, DL, VT, USatVal);
15942   }
15943 
15944   return SDValue();
15945 }
15946 
visitTRUNCATE(SDNode * N)15947 SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
15948   SDValue N0 = N->getOperand(0);
15949   EVT VT = N->getValueType(0);
15950   EVT SrcVT = N0.getValueType();
15951   bool isLE = DAG.getDataLayout().isLittleEndian();
15952   SDLoc DL(N);
15953 
15954   // trunc(undef) = undef
15955   if (N0.isUndef())
15956     return DAG.getUNDEF(VT);
15957 
15958   // fold (truncate (truncate x)) -> (truncate x)
15959   if (N0.getOpcode() == ISD::TRUNCATE)
15960     return DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0));
15961 
15962   // fold saturated truncate
15963   if (SDValue SaturatedTR = foldToSaturated(N, VT, N0, SrcVT, DL, TLI, DAG))
15964     return SaturatedTR;
15965 
15966   // fold (truncate c1) -> c1
15967   if (SDValue C = DAG.FoldConstantArithmetic(ISD::TRUNCATE, DL, VT, {N0}))
15968     return C;
15969 
15970   // fold (truncate (ext x)) -> (ext x) or (truncate x) or x
15971   if (N0.getOpcode() == ISD::ZERO_EXTEND ||
15972       N0.getOpcode() == ISD::SIGN_EXTEND ||
15973       N0.getOpcode() == ISD::ANY_EXTEND) {
15974     // if the source is smaller than the dest, we still need an extend.
15975     if (N0.getOperand(0).getValueType().bitsLT(VT)) {
15976       SDNodeFlags Flags;
15977       if (N0.getOpcode() == ISD::ZERO_EXTEND)
15978         Flags.setNonNeg(N0->getFlags().hasNonNeg());
15979       return DAG.getNode(N0.getOpcode(), DL, VT, N0.getOperand(0), Flags);
15980     }
15981     // if the source is larger than the dest, than we just need the truncate.
15982     if (N0.getOperand(0).getValueType().bitsGT(VT))
15983       return DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0));
15984     // if the source and dest are the same type, we can drop both the extend
15985     // and the truncate.
15986     return N0.getOperand(0);
15987   }
15988 
15989   // Try to narrow a truncate-of-sext_in_reg to the destination type:
15990   // trunc (sign_ext_inreg X, iM) to iN --> sign_ext_inreg (trunc X to iN), iM
15991   if (!LegalTypes && N0.getOpcode() == ISD::SIGN_EXTEND_INREG &&
15992       N0.hasOneUse()) {
15993     SDValue X = N0.getOperand(0);
15994     SDValue ExtVal = N0.getOperand(1);
15995     EVT ExtVT = cast<VTSDNode>(ExtVal)->getVT();
15996     if (ExtVT.bitsLT(VT) && TLI.preferSextInRegOfTruncate(VT, SrcVT, ExtVT)) {
15997       SDValue TrX = DAG.getNode(ISD::TRUNCATE, DL, VT, X);
15998       return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, TrX, ExtVal);
15999     }
16000   }
16001 
16002   // If this is anyext(trunc), don't fold it, allow ourselves to be folded.
16003   if (N->hasOneUse() && (N->user_begin()->getOpcode() == ISD::ANY_EXTEND))
16004     return SDValue();
16005 
16006   // Fold extract-and-trunc into a narrow extract. For example:
16007   //   i64 x = EXTRACT_VECTOR_ELT(v2i64 val, i32 1)
16008   //   i32 y = TRUNCATE(i64 x)
16009   //        -- becomes --
16010   //   v16i8 b = BITCAST (v2i64 val)
16011   //   i8 x = EXTRACT_VECTOR_ELT(v16i8 b, i32 8)
16012   //
16013   // Note: We only run this optimization after type legalization (which often
16014   // creates this pattern) and before operation legalization after which
16015   // we need to be more careful about the vector instructions that we generate.
16016   if (LegalTypes && !LegalOperations && VT.isScalarInteger() && VT != MVT::i1 &&
16017       N0->hasOneUse()) {
16018     EVT TrTy = N->getValueType(0);
16019     SDValue Src = N0;
16020 
16021     // Check for cases where we shift down an upper element before truncation.
16022     int EltOffset = 0;
16023     if (Src.getOpcode() == ISD::SRL && Src.getOperand(0)->hasOneUse()) {
16024       if (auto ShAmt = DAG.getValidShiftAmount(Src)) {
16025         if ((*ShAmt % TrTy.getSizeInBits()) == 0) {
16026           Src = Src.getOperand(0);
16027           EltOffset = *ShAmt / TrTy.getSizeInBits();
16028         }
16029       }
16030     }
16031 
16032     if (Src.getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
16033       EVT VecTy = Src.getOperand(0).getValueType();
16034       EVT ExTy = Src.getValueType();
16035 
16036       auto EltCnt = VecTy.getVectorElementCount();
16037       unsigned SizeRatio = ExTy.getSizeInBits() / TrTy.getSizeInBits();
16038       auto NewEltCnt = EltCnt * SizeRatio;
16039 
16040       EVT NVT = EVT::getVectorVT(*DAG.getContext(), TrTy, NewEltCnt);
16041       assert(NVT.getSizeInBits() == VecTy.getSizeInBits() && "Invalid Size");
16042 
16043       SDValue EltNo = Src->getOperand(1);
16044       if (isa<ConstantSDNode>(EltNo) && isTypeLegal(NVT)) {
16045         int Elt = EltNo->getAsZExtVal();
16046         int Index = isLE ? (Elt * SizeRatio + EltOffset)
16047                          : (Elt * SizeRatio + (SizeRatio - 1) - EltOffset);
16048         return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, TrTy,
16049                            DAG.getBitcast(NVT, Src.getOperand(0)),
16050                            DAG.getVectorIdxConstant(Index, DL));
16051       }
16052     }
16053   }
16054 
16055   // trunc (select c, a, b) -> select c, (trunc a), (trunc b)
16056   if (N0.getOpcode() == ISD::SELECT && N0.hasOneUse() &&
16057       TLI.isTruncateFree(SrcVT, VT)) {
16058     if (!LegalOperations ||
16059         (TLI.isOperationLegal(ISD::SELECT, SrcVT) &&
16060          TLI.isNarrowingProfitable(N0.getNode(), SrcVT, VT))) {
16061       SDLoc SL(N0);
16062       SDValue Cond = N0.getOperand(0);
16063       SDValue TruncOp0 = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(1));
16064       SDValue TruncOp1 = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(2));
16065       return DAG.getNode(ISD::SELECT, DL, VT, Cond, TruncOp0, TruncOp1);
16066     }
16067   }
16068 
16069   // trunc (shl x, K) -> shl (trunc x), K => K < VT.getScalarSizeInBits()
16070   if (N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
16071       (!LegalOperations || TLI.isOperationLegal(ISD::SHL, VT)) &&
16072       TLI.isTypeDesirableForOp(ISD::SHL, VT)) {
16073     SDValue Amt = N0.getOperand(1);
16074     KnownBits Known = DAG.computeKnownBits(Amt);
16075     unsigned Size = VT.getScalarSizeInBits();
16076     if (Known.countMaxActiveBits() <= Log2_32(Size)) {
16077       EVT AmtVT = TLI.getShiftAmountTy(VT, DAG.getDataLayout());
16078       SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0));
16079       if (AmtVT != Amt.getValueType()) {
16080         Amt = DAG.getZExtOrTrunc(Amt, DL, AmtVT);
16081         AddToWorklist(Amt.getNode());
16082       }
16083       return DAG.getNode(ISD::SHL, DL, VT, Trunc, Amt);
16084     }
16085   }
16086 
16087   if (SDValue V = foldSubToUSubSat(VT, N0.getNode(), DL))
16088     return V;
16089 
16090   if (SDValue ABD = foldABSToABD(N, DL))
16091     return ABD;
16092 
16093   // Attempt to pre-truncate BUILD_VECTOR sources.
16094   if (N0.getOpcode() == ISD::BUILD_VECTOR && !LegalOperations &&
16095       N0.hasOneUse() &&
16096       TLI.isTruncateFree(SrcVT.getScalarType(), VT.getScalarType()) &&
16097       // Avoid creating illegal types if running after type legalizer.
16098       (!LegalTypes || TLI.isTypeLegal(VT.getScalarType()))) {
16099     EVT SVT = VT.getScalarType();
16100     SmallVector<SDValue, 8> TruncOps;
16101     for (const SDValue &Op : N0->op_values()) {
16102       SDValue TruncOp = DAG.getNode(ISD::TRUNCATE, DL, SVT, Op);
16103       TruncOps.push_back(TruncOp);
16104     }
16105     return DAG.getBuildVector(VT, DL, TruncOps);
16106   }
16107 
16108   // trunc (splat_vector x) -> splat_vector (trunc x)
16109   if (N0.getOpcode() == ISD::SPLAT_VECTOR &&
16110       (!LegalTypes || TLI.isTypeLegal(VT.getScalarType())) &&
16111       (!LegalOperations || TLI.isOperationLegal(ISD::SPLAT_VECTOR, VT))) {
16112     EVT SVT = VT.getScalarType();
16113     return DAG.getSplatVector(
16114         VT, DL, DAG.getNode(ISD::TRUNCATE, DL, SVT, N0->getOperand(0)));
16115   }
16116 
16117   // Fold a series of buildvector, bitcast, and truncate if possible.
16118   // For example fold
16119   //   (2xi32 trunc (bitcast ((4xi32)buildvector x, x, y, y) 2xi64)) to
16120   //   (2xi32 (buildvector x, y)).
16121   if (Level == AfterLegalizeVectorOps && VT.isVector() &&
16122       N0.getOpcode() == ISD::BITCAST && N0.hasOneUse() &&
16123       N0.getOperand(0).getOpcode() == ISD::BUILD_VECTOR &&
16124       N0.getOperand(0).hasOneUse()) {
16125     SDValue BuildVect = N0.getOperand(0);
16126     EVT BuildVectEltTy = BuildVect.getValueType().getVectorElementType();
16127     EVT TruncVecEltTy = VT.getVectorElementType();
16128 
16129     // Check that the element types match.
16130     if (BuildVectEltTy == TruncVecEltTy) {
16131       // Now we only need to compute the offset of the truncated elements.
16132       unsigned BuildVecNumElts =  BuildVect.getNumOperands();
16133       unsigned TruncVecNumElts = VT.getVectorNumElements();
16134       unsigned TruncEltOffset = BuildVecNumElts / TruncVecNumElts;
16135       unsigned FirstElt = isLE ? 0 : (TruncEltOffset - 1);
16136 
16137       assert((BuildVecNumElts % TruncVecNumElts) == 0 &&
16138              "Invalid number of elements");
16139 
16140       SmallVector<SDValue, 8> Opnds;
16141       for (unsigned i = FirstElt, e = BuildVecNumElts; i < e;
16142            i += TruncEltOffset)
16143         Opnds.push_back(BuildVect.getOperand(i));
16144 
16145       return DAG.getBuildVector(VT, DL, Opnds);
16146     }
16147   }
16148 
16149   // fold (truncate (load x)) -> (smaller load x)
16150   // fold (truncate (srl (load x), c)) -> (smaller load (x+c/evtbits))
16151   if (!LegalTypes || TLI.isTypeDesirableForOp(N0.getOpcode(), VT)) {
16152     if (SDValue Reduced = reduceLoadWidth(N))
16153       return Reduced;
16154 
16155     // Handle the case where the truncated result is at least as wide as the
16156     // loaded type.
16157     if (N0.hasOneUse() && ISD::isUNINDEXEDLoad(N0.getNode())) {
16158       auto *LN0 = cast<LoadSDNode>(N0);
16159       if (LN0->isSimple() && LN0->getMemoryVT().bitsLE(VT)) {
16160         SDValue NewLoad = DAG.getExtLoad(
16161             LN0->getExtensionType(), SDLoc(LN0), VT, LN0->getChain(),
16162             LN0->getBasePtr(), LN0->getMemoryVT(), LN0->getMemOperand());
16163         DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), NewLoad.getValue(1));
16164         return NewLoad;
16165       }
16166     }
16167   }
16168 
16169   // fold (trunc (concat ... x ...)) -> (concat ..., (trunc x), ...)),
16170   // where ... are all 'undef'.
16171   if (N0.getOpcode() == ISD::CONCAT_VECTORS && !LegalTypes) {
16172     SmallVector<EVT, 8> VTs;
16173     SDValue V;
16174     unsigned Idx = 0;
16175     unsigned NumDefs = 0;
16176 
16177     for (unsigned i = 0, e = N0.getNumOperands(); i != e; ++i) {
16178       SDValue X = N0.getOperand(i);
16179       if (!X.isUndef()) {
16180         V = X;
16181         Idx = i;
16182         NumDefs++;
16183       }
16184       // Stop if more than one members are non-undef.
16185       if (NumDefs > 1)
16186         break;
16187 
16188       VTs.push_back(EVT::getVectorVT(*DAG.getContext(),
16189                                      VT.getVectorElementType(),
16190                                      X.getValueType().getVectorElementCount()));
16191     }
16192 
16193     if (NumDefs == 0)
16194       return DAG.getUNDEF(VT);
16195 
16196     if (NumDefs == 1) {
16197       assert(V.getNode() && "The single defined operand is empty!");
16198       SmallVector<SDValue, 8> Opnds;
16199       for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
16200         if (i != Idx) {
16201           Opnds.push_back(DAG.getUNDEF(VTs[i]));
16202           continue;
16203         }
16204         SDValue NV = DAG.getNode(ISD::TRUNCATE, SDLoc(V), VTs[i], V);
16205         AddToWorklist(NV.getNode());
16206         Opnds.push_back(NV);
16207       }
16208       return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Opnds);
16209     }
16210   }
16211 
16212   // Fold truncate of a bitcast of a vector to an extract of the low vector
16213   // element.
16214   //
16215   // e.g. trunc (i64 (bitcast v2i32:x)) -> extract_vector_elt v2i32:x, idx
16216   if (N0.getOpcode() == ISD::BITCAST && !VT.isVector()) {
16217     SDValue VecSrc = N0.getOperand(0);
16218     EVT VecSrcVT = VecSrc.getValueType();
16219     if (VecSrcVT.isVector() && VecSrcVT.getScalarType() == VT &&
16220         (!LegalOperations ||
16221          TLI.isOperationLegal(ISD::EXTRACT_VECTOR_ELT, VecSrcVT))) {
16222       unsigned Idx = isLE ? 0 : VecSrcVT.getVectorNumElements() - 1;
16223       return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, VecSrc,
16224                          DAG.getVectorIdxConstant(Idx, DL));
16225     }
16226   }
16227 
16228   // Simplify the operands using demanded-bits information.
16229   if (SimplifyDemandedBits(SDValue(N, 0)))
16230     return SDValue(N, 0);
16231 
16232   // fold (truncate (extract_subvector(ext x))) ->
16233   //      (extract_subvector x)
16234   // TODO: This can be generalized to cover cases where the truncate and extract
16235   // do not fully cancel each other out.
16236   if (!LegalTypes && N0.getOpcode() == ISD::EXTRACT_SUBVECTOR) {
16237     SDValue N00 = N0.getOperand(0);
16238     if (N00.getOpcode() == ISD::SIGN_EXTEND ||
16239         N00.getOpcode() == ISD::ZERO_EXTEND ||
16240         N00.getOpcode() == ISD::ANY_EXTEND) {
16241       if (N00.getOperand(0)->getValueType(0).getVectorElementType() ==
16242           VT.getVectorElementType())
16243         return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N0->getOperand(0)), VT,
16244                            N00.getOperand(0), N0.getOperand(1));
16245     }
16246   }
16247 
16248   if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
16249     return NewVSel;
16250 
16251   // Narrow a suitable binary operation with a non-opaque constant operand by
16252   // moving it ahead of the truncate. This is limited to pre-legalization
16253   // because targets may prefer a wider type during later combines and invert
16254   // this transform.
16255   switch (N0.getOpcode()) {
16256   case ISD::ADD:
16257   case ISD::SUB:
16258   case ISD::MUL:
16259   case ISD::AND:
16260   case ISD::OR:
16261   case ISD::XOR:
16262     if (!LegalOperations && N0.hasOneUse() &&
16263         (isConstantOrConstantVector(N0.getOperand(0), true) ||
16264          isConstantOrConstantVector(N0.getOperand(1), true))) {
16265       // TODO: We already restricted this to pre-legalization, but for vectors
16266       // we are extra cautious to not create an unsupported operation.
16267       // Target-specific changes are likely needed to avoid regressions here.
16268       if (VT.isScalarInteger() || TLI.isOperationLegal(N0.getOpcode(), VT)) {
16269         SDValue NarrowL = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0));
16270         SDValue NarrowR = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(1));
16271         return DAG.getNode(N0.getOpcode(), DL, VT, NarrowL, NarrowR);
16272       }
16273     }
16274     break;
16275   case ISD::ADDE:
16276   case ISD::UADDO_CARRY:
16277     // (trunc adde(X, Y, Carry)) -> (adde trunc(X), trunc(Y), Carry)
16278     // (trunc uaddo_carry(X, Y, Carry)) ->
16279     //     (uaddo_carry trunc(X), trunc(Y), Carry)
16280     // When the adde's carry is not used.
16281     // We only do for uaddo_carry before legalize operation
16282     if (((!LegalOperations && N0.getOpcode() == ISD::UADDO_CARRY) ||
16283          TLI.isOperationLegal(N0.getOpcode(), VT)) &&
16284         N0.hasOneUse() && !N0->hasAnyUseOfValue(1)) {
16285       SDValue X = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0));
16286       SDValue Y = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(1));
16287       SDVTList VTs = DAG.getVTList(VT, N0->getValueType(1));
16288       return DAG.getNode(N0.getOpcode(), DL, VTs, X, Y, N0.getOperand(2));
16289     }
16290     break;
16291   case ISD::USUBSAT:
16292     // Truncate the USUBSAT only if LHS is a known zero-extension, its not
16293     // enough to know that the upper bits are zero we must ensure that we don't
16294     // introduce an extra truncate.
16295     if (!LegalOperations && N0.hasOneUse() &&
16296         N0.getOperand(0).getOpcode() == ISD::ZERO_EXTEND &&
16297         N0.getOperand(0).getOperand(0).getScalarValueSizeInBits() <=
16298             VT.getScalarSizeInBits() &&
16299         hasOperation(N0.getOpcode(), VT)) {
16300       return getTruncatedUSUBSAT(VT, SrcVT, N0.getOperand(0), N0.getOperand(1),
16301                                  DAG, DL);
16302     }
16303     break;
16304   }
16305 
16306   return SDValue();
16307 }
16308 
getBuildPairElt(SDNode * N,unsigned i)16309 static SDNode *getBuildPairElt(SDNode *N, unsigned i) {
16310   SDValue Elt = N->getOperand(i);
16311   if (Elt.getOpcode() != ISD::MERGE_VALUES)
16312     return Elt.getNode();
16313   return Elt.getOperand(Elt.getResNo()).getNode();
16314 }
16315 
16316 /// build_pair (load, load) -> load
16317 /// if load locations are consecutive.
CombineConsecutiveLoads(SDNode * N,EVT VT)16318 SDValue DAGCombiner::CombineConsecutiveLoads(SDNode *N, EVT VT) {
16319   assert(N->getOpcode() == ISD::BUILD_PAIR);
16320 
16321   auto *LD1 = dyn_cast<LoadSDNode>(getBuildPairElt(N, 0));
16322   auto *LD2 = dyn_cast<LoadSDNode>(getBuildPairElt(N, 1));
16323 
16324   // A BUILD_PAIR is always having the least significant part in elt 0 and the
16325   // most significant part in elt 1. So when combining into one large load, we
16326   // need to consider the endianness.
16327   if (DAG.getDataLayout().isBigEndian())
16328     std::swap(LD1, LD2);
16329 
16330   if (!LD1 || !LD2 || !ISD::isNON_EXTLoad(LD1) || !ISD::isNON_EXTLoad(LD2) ||
16331       !LD1->hasOneUse() || !LD2->hasOneUse() ||
16332       LD1->getAddressSpace() != LD2->getAddressSpace())
16333     return SDValue();
16334 
16335   unsigned LD1Fast = 0;
16336   EVT LD1VT = LD1->getValueType(0);
16337   unsigned LD1Bytes = LD1VT.getStoreSize();
16338   if ((!LegalOperations || TLI.isOperationLegal(ISD::LOAD, VT)) &&
16339       DAG.areNonVolatileConsecutiveLoads(LD2, LD1, LD1Bytes, 1) &&
16340       TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT,
16341                              *LD1->getMemOperand(), &LD1Fast) && LD1Fast)
16342     return DAG.getLoad(VT, SDLoc(N), LD1->getChain(), LD1->getBasePtr(),
16343                        LD1->getPointerInfo(), LD1->getAlign());
16344 
16345   return SDValue();
16346 }
16347 
getPPCf128HiElementSelector(const SelectionDAG & DAG)16348 static unsigned getPPCf128HiElementSelector(const SelectionDAG &DAG) {
16349   // On little-endian machines, bitcasting from ppcf128 to i128 does swap the Hi
16350   // and Lo parts; on big-endian machines it doesn't.
16351   return DAG.getDataLayout().isBigEndian() ? 1 : 0;
16352 }
16353 
foldBitcastedFPLogic(SDNode * N,SelectionDAG & DAG,const TargetLowering & TLI)16354 SDValue DAGCombiner::foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
16355                                           const TargetLowering &TLI) {
16356   // If this is not a bitcast to an FP type or if the target doesn't have
16357   // IEEE754-compliant FP logic, we're done.
16358   EVT VT = N->getValueType(0);
16359   SDValue N0 = N->getOperand(0);
16360   EVT SourceVT = N0.getValueType();
16361 
16362   if (!VT.isFloatingPoint())
16363     return SDValue();
16364 
16365   // TODO: Handle cases where the integer constant is a different scalar
16366   // bitwidth to the FP.
16367   if (VT.getScalarSizeInBits() != SourceVT.getScalarSizeInBits())
16368     return SDValue();
16369 
16370   unsigned FPOpcode;
16371   APInt SignMask;
16372   switch (N0.getOpcode()) {
16373   case ISD::AND:
16374     FPOpcode = ISD::FABS;
16375     SignMask = ~APInt::getSignMask(SourceVT.getScalarSizeInBits());
16376     break;
16377   case ISD::XOR:
16378     FPOpcode = ISD::FNEG;
16379     SignMask = APInt::getSignMask(SourceVT.getScalarSizeInBits());
16380     break;
16381   case ISD::OR:
16382     FPOpcode = ISD::FABS;
16383     SignMask = APInt::getSignMask(SourceVT.getScalarSizeInBits());
16384     break;
16385   default:
16386     return SDValue();
16387   }
16388 
16389   if (LegalOperations && !TLI.isOperationLegal(FPOpcode, VT))
16390     return SDValue();
16391 
16392   // This needs to be the inverse of logic in foldSignChangeInBitcast.
16393   // FIXME: I don't think looking for bitcast intrinsically makes sense, but
16394   // removing this would require more changes.
16395   auto IsBitCastOrFree = [&TLI, FPOpcode](SDValue Op, EVT VT) {
16396     if (sd_match(Op, m_BitCast(m_SpecificVT(VT))))
16397       return true;
16398 
16399     return FPOpcode == ISD::FABS ? TLI.isFAbsFree(VT) : TLI.isFNegFree(VT);
16400   };
16401 
16402   // Fold (bitcast int (and (bitcast fp X to int), 0x7fff...) to fp) -> fabs X
16403   // Fold (bitcast int (xor (bitcast fp X to int), 0x8000...) to fp) -> fneg X
16404   // Fold (bitcast int (or (bitcast fp X to int), 0x8000...) to fp) ->
16405   //   fneg (fabs X)
16406   SDValue LogicOp0 = N0.getOperand(0);
16407   ConstantSDNode *LogicOp1 = isConstOrConstSplat(N0.getOperand(1), true);
16408   if (LogicOp1 && LogicOp1->getAPIntValue() == SignMask &&
16409       IsBitCastOrFree(LogicOp0, VT)) {
16410     SDValue CastOp0 = DAG.getNode(ISD::BITCAST, SDLoc(N), VT, LogicOp0);
16411     SDValue FPOp = DAG.getNode(FPOpcode, SDLoc(N), VT, CastOp0);
16412     NumFPLogicOpsConv++;
16413     if (N0.getOpcode() == ISD::OR)
16414       return DAG.getNode(ISD::FNEG, SDLoc(N), VT, FPOp);
16415     return FPOp;
16416   }
16417 
16418   return SDValue();
16419 }
16420 
/// Combine a BITCAST node \p N. The folds below are tried in order and the
/// first one that applies wins:
///   - bitcast of undef -> undef;
///   - constant BUILD_VECTOR / scalar constant -> folded constant;
///   - bitcast of bitcast -> single bitcast;
///   - bitcast of a logic op on an illegal type -> logic op in VT;
///   - bitcast of a normal load -> load of VT;
///   - bitcast of sign-bit integer logic -> fabs/fneg (foldBitcastedFPLogic);
///   - bitcast of fneg/fabs/fcopysign -> integer sign-bit logic;
///   - bitcast of BUILD_PAIR of consecutive loads -> single wide load;
///   - bitcast of SCALAR_TO_VECTOR -> any_extend;
///   - bitcast of shuffle of bitcasts -> shuffle in VT.
/// Returns the replacement value, or an empty SDValue if nothing applies.
SDValue DAGCombiner::visitBITCAST(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);

  if (N0.isUndef())
    return DAG.getUNDEF(VT);

  // If the input is a BUILD_VECTOR with all constant elements, fold this now.
  // Only do this before legalize types, unless both types are integer and the
  // scalar type is legal. Only do this before legalize ops, since the target
  // may be depending on the bitcast.
  // First check to see if this is all constant.
  // TODO: Support FP bitcasts after legalize types.
  if (VT.isVector() &&
      (!LegalTypes ||
       (!LegalOperations && VT.isInteger() && N0.getValueType().isInteger() &&
        TLI.isTypeLegal(VT.getVectorElementType()))) &&
      N0.getOpcode() == ISD::BUILD_VECTOR && N0->hasOneUse() &&
      cast<BuildVectorSDNode>(N0)->isConstant())
    return DAG.FoldConstantBuildVector(cast<BuildVectorSDNode>(N0), SDLoc(N),
                                       VT.getVectorElementType());

  // If the input is a constant, let getNode fold it.
  if (isIntOrFPConstant(N0)) {
    // If we can't allow illegal operations, we need to check that this is just
    // a scalar fp -> int or int -> fp conversion and that the resulting
    // constant operation will be legal.
    if (!LegalOperations ||
        (isa<ConstantSDNode>(N0) && VT.isFloatingPoint() && !VT.isVector() &&
         TLI.isOperationLegal(ISD::ConstantFP, VT)) ||
        (isa<ConstantFPSDNode>(N0) && VT.isInteger() && !VT.isVector() &&
         TLI.isOperationLegal(ISD::Constant, VT))) {
      SDValue C = DAG.getBitcast(VT, N0);
      // getBitcast may hand back N itself when it cannot fold; only a
      // different node is progress.
      if (C.getNode() != N)
        return C;
    }
  }

  // (conv (conv x, t1), t2) -> (conv x, t2)
  if (N0.getOpcode() == ISD::BITCAST)
    return DAG.getBitcast(VT, N0.getOperand(0));

  // fold (conv (logicop (conv x), (c))) -> (logicop x, (conv c))
  // iff the current bitwise logicop type isn't legal
  if (ISD::isBitwiseLogicOp(N0.getOpcode()) && VT.isInteger() &&
      !TLI.isTypeLegal(N0.getOperand(0).getValueType())) {
    // An operand is "free" to bitcast to VT if it is already a bitcast from VT
    // or a single-use constant BUILD_VECTOR (which will simply be refolded).
    auto IsFreeBitcast = [VT](SDValue V) {
      return (V.getOpcode() == ISD::BITCAST &&
              V.getOperand(0).getValueType() == VT) ||
             (ISD::isBuildVectorOfConstantSDNodes(V.getNode()) &&
              V->hasOneUse());
    };
    if (IsFreeBitcast(N0.getOperand(0)) && IsFreeBitcast(N0.getOperand(1)))
      return DAG.getNode(N0.getOpcode(), SDLoc(N), VT,
                         DAG.getBitcast(VT, N0.getOperand(0)),
                         DAG.getBitcast(VT, N0.getOperand(1)));
  }

  // fold (conv (load x)) -> (load (conv*)x)
  // NOTE(review): no alignment comparison is visible here; whether the new
  // type is acceptable is delegated to TLI.isLoadBitCastBeneficial below.
  if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
      // Do not remove the cast if the types differ in endian layout.
      TLI.hasBigEndianPartOrdering(N0.getValueType(), DAG.getDataLayout()) ==
          TLI.hasBigEndianPartOrdering(VT, DAG.getDataLayout()) &&
      // If the load is volatile, we only want to change the load type if the
      // resulting load is legal. Otherwise we might increase the number of
      // memory accesses. We don't care if the original type was legal or not
      // as we assume software couldn't rely on the number of accesses of an
      // illegal type.
      ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple()) ||
       TLI.isOperationLegal(ISD::LOAD, VT))) {
    LoadSDNode *LN0 = cast<LoadSDNode>(N0);

    if (TLI.isLoadBitCastBeneficial(N0.getValueType(), VT, DAG,
                                    *LN0->getMemOperand())) {
      // If the range metadata type does not match the new memory
      // operation type, remove the range metadata.
      if (const MDNode *MD = LN0->getRanges()) {
        ConstantInt *Lower = mdconst::extract<ConstantInt>(MD->getOperand(0));
        if (Lower->getBitWidth() != VT.getScalarSizeInBits() ||
            !VT.isInteger()) {
          LN0->getMemOperand()->clearRanges();
        }
      }
      SDValue Load =
          DAG.getLoad(VT, SDLoc(N), LN0->getChain(), LN0->getBasePtr(),
                      LN0->getMemOperand());
      // Redirect users of the old load's chain (result 1) to the new load.
      DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), Load.getValue(1));
      return Load;
    }
  }

  if (SDValue V = foldBitcastedFPLogic(N, DAG, TLI))
    return V;

  // fold (bitconvert (fneg x)) -> (xor (bitconvert x), signbit)
  // fold (bitconvert (fabs x)) -> (and (bitconvert x), (not signbit))
  //
  // For ppc_fp128:
  // fold (bitcast (fneg x)) ->
  //     flipbit = signbit
  //     (xor (bitcast x) (build_pair flipbit, flipbit))
  //
  // fold (bitcast (fabs x)) ->
  //     flipbit = (and (extract_element (bitcast x), hi), signbit)
  //     (xor (bitcast x) (build_pair flipbit, flipbit))
  // where "hi" is the element index from getPPCf128HiElementSelector.
  // This often reduces constant pool loads.
  if (((N0.getOpcode() == ISD::FNEG && !TLI.isFNegFree(N0.getValueType())) ||
       (N0.getOpcode() == ISD::FABS && !TLI.isFAbsFree(N0.getValueType()))) &&
      N0->hasOneUse() && VT.isInteger() && !VT.isVector() &&
      !N0.getValueType().isVector()) {
    SDValue NewConv = DAG.getBitcast(VT, N0.getOperand(0));
    AddToWorklist(NewConv.getNode());

    SDLoc DL(N);
    if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) {
      assert(VT.getSizeInBits() == 128);
      // Sign bit of one i64 half of the 128-bit value.
      SDValue SignBit = DAG.getConstant(
          APInt::getSignMask(VT.getSizeInBits() / 2), SDLoc(N0), MVT::i64);
      SDValue FlipBit;
      if (N0.getOpcode() == ISD::FNEG) {
        FlipBit = SignBit;
        AddToWorklist(FlipBit.getNode());
      } else {
        assert(N0.getOpcode() == ISD::FABS);
        SDValue Hi =
            DAG.getNode(ISD::EXTRACT_ELEMENT, SDLoc(NewConv), MVT::i64, NewConv,
                        DAG.getIntPtrConstant(getPPCf128HiElementSelector(DAG),
                                              SDLoc(NewConv)));
        AddToWorklist(Hi.getNode());
        FlipBit = DAG.getNode(ISD::AND, SDLoc(N0), MVT::i64, Hi, SignBit);
        AddToWorklist(FlipBit.getNode());
      }
      SDValue FlipBits =
          DAG.getNode(ISD::BUILD_PAIR, SDLoc(N0), VT, FlipBit, FlipBit);
      AddToWorklist(FlipBits.getNode());
      return DAG.getNode(ISD::XOR, DL, VT, NewConv, FlipBits);
    }
    APInt SignBit = APInt::getSignMask(VT.getSizeInBits());
    if (N0.getOpcode() == ISD::FNEG)
      return DAG.getNode(ISD::XOR, DL, VT,
                         NewConv, DAG.getConstant(SignBit, DL, VT));
    assert(N0.getOpcode() == ISD::FABS);
    return DAG.getNode(ISD::AND, DL, VT,
                       NewConv, DAG.getConstant(~SignBit, DL, VT));
  }

  // fold (bitconvert (fcopysign cst, x)) ->
  //         (or (and (bitconvert x), sign), (and cst, (not sign)))
  // Note that we don't handle (copysign x, cst) because this can always be
  // folded to an fneg or fabs.
  //
  // For ppc_fp128:
  // fold (bitcast (fcopysign cst, x)) ->
  //     flipbit = (and (extract_element
  //                     (xor (bitcast cst), (bitcast x)), hi),
  //                    signbit)
  //     (xor (bitcast cst) (build_pair flipbit, flipbit))
  // where "hi" is the element index from getPPCf128HiElementSelector.
  if (N0.getOpcode() == ISD::FCOPYSIGN && N0->hasOneUse() &&
      isa<ConstantFPSDNode>(N0.getOperand(0)) && VT.isInteger() &&
      !VT.isVector()) {
    unsigned OrigXWidth = N0.getOperand(1).getValueSizeInBits();
    EVT IntXVT = EVT::getIntegerVT(*DAG.getContext(), OrigXWidth);
    if (isTypeLegal(IntXVT)) {
      SDValue X = DAG.getBitcast(IntXVT, N0.getOperand(1));
      AddToWorklist(X.getNode());

      // If X has a different width than the result/lhs, sext it or truncate it.
      unsigned VTWidth = VT.getSizeInBits();
      if (OrigXWidth < VTWidth) {
        X = DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, X);
        AddToWorklist(X.getNode());
      } else if (OrigXWidth > VTWidth) {
        // To get the sign bit in the right place, we have to shift it right
        // before truncating.
        SDLoc DL(X);
        X = DAG.getNode(ISD::SRL, DL,
                        X.getValueType(), X,
                        DAG.getConstant(OrigXWidth-VTWidth, DL,
                                        X.getValueType()));
        AddToWorklist(X.getNode());
        X = DAG.getNode(ISD::TRUNCATE, SDLoc(X), VT, X);
        AddToWorklist(X.getNode());
      }

      if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) {
        // NOTE: this path does not use the X built above; the inner X below
        // shadows it with a plain bitcast of the sign operand to VT.
        APInt SignBit = APInt::getSignMask(VT.getSizeInBits() / 2);
        SDValue Cst = DAG.getBitcast(VT, N0.getOperand(0));
        AddToWorklist(Cst.getNode());
        SDValue X = DAG.getBitcast(VT, N0.getOperand(1));
        AddToWorklist(X.getNode());
        SDValue XorResult = DAG.getNode(ISD::XOR, SDLoc(N0), VT, Cst, X);
        AddToWorklist(XorResult.getNode());
        SDValue XorResult64 = DAG.getNode(
            ISD::EXTRACT_ELEMENT, SDLoc(XorResult), MVT::i64, XorResult,
            DAG.getIntPtrConstant(getPPCf128HiElementSelector(DAG),
                                  SDLoc(XorResult)));
        AddToWorklist(XorResult64.getNode());
        SDValue FlipBit =
            DAG.getNode(ISD::AND, SDLoc(XorResult64), MVT::i64, XorResult64,
                        DAG.getConstant(SignBit, SDLoc(XorResult64), MVT::i64));
        AddToWorklist(FlipBit.getNode());
        SDValue FlipBits =
            DAG.getNode(ISD::BUILD_PAIR, SDLoc(N0), VT, FlipBit, FlipBit);
        AddToWorklist(FlipBits.getNode());
        return DAG.getNode(ISD::XOR, SDLoc(N), VT, Cst, FlipBits);
      }
      APInt SignBit = APInt::getSignMask(VT.getSizeInBits());
      X = DAG.getNode(ISD::AND, SDLoc(X), VT,
                      X, DAG.getConstant(SignBit, SDLoc(X), VT));
      AddToWorklist(X.getNode());

      SDValue Cst = DAG.getBitcast(VT, N0.getOperand(0));
      Cst = DAG.getNode(ISD::AND, SDLoc(Cst), VT,
                        Cst, DAG.getConstant(~SignBit, SDLoc(Cst), VT));
      AddToWorklist(Cst.getNode());

      return DAG.getNode(ISD::OR, SDLoc(N), VT, X, Cst);
    }
  }

  // bitconvert(build_pair(ld, ld)) -> ld iff load locations are consecutive.
  if (N0.getOpcode() == ISD::BUILD_PAIR)
    if (SDValue CombineLD = CombineConsecutiveLoads(N0.getNode(), VT))
      return CombineLD;

  // int_vt (bitcast (vec_vt (scalar_to_vector elt_vt:x)))
  //   => int_vt (any_extend elt_vt:x)
  if (N0.getOpcode() == ISD::SCALAR_TO_VECTOR && VT.isScalarInteger()) {
    SDValue SrcScalar = N0.getOperand(0);
    if (SrcScalar.getValueType().isScalarInteger())
      return DAG.getNode(ISD::ANY_EXTEND, SDLoc(N), VT, SrcScalar);
  }

  // Remove double bitcasts from shuffles - this is often a legacy of
  // XformToShuffleWithZero being used to combine bitmaskings (of
  // float vectors bitcast to integer vectors) into shuffles.
  // bitcast(shuffle(bitcast(s0),bitcast(s1))) -> shuffle(s0,s1)
  if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT) && VT.isVector() &&
      N0->getOpcode() == ISD::VECTOR_SHUFFLE && N0.hasOneUse() &&
      VT.getVectorNumElements() >= N0.getValueType().getVectorNumElements() &&
      !(VT.getVectorNumElements() % N0.getValueType().getVectorNumElements())) {
    ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N0);

    // If operands are a bitcast, peek through if it casts the original VT.
    // If operands are a constant, just bitcast back to original VT.
    auto PeekThroughBitcast = [&](SDValue Op) {
      if (Op.getOpcode() == ISD::BITCAST &&
          Op.getOperand(0).getValueType() == VT)
        return SDValue(Op.getOperand(0));
      if (Op.isUndef() || isAnyConstantBuildVector(Op))
        return DAG.getBitcast(VT, Op);
      return SDValue();
    };

    // FIXME: If either input vector is bitcast, try to convert the shuffle to
    // the result type of this bitcast. This would eliminate at least one
    // bitcast. See the transform in InstCombine.
    SDValue SV0 = PeekThroughBitcast(N0->getOperand(0));
    SDValue SV1 = PeekThroughBitcast(N0->getOperand(1));
    if (!(SV0 && SV1))
      return SDValue();

    // Each original mask element expands to MaskScale consecutive elements of
    // the narrower-element VT; undef (-1) entries stay undef.
    int MaskScale =
        VT.getVectorNumElements() / N0.getValueType().getVectorNumElements();
    SmallVector<int, 8> NewMask;
    for (int M : SVN->getMask())
      for (int i = 0; i != MaskScale; ++i)
        NewMask.push_back(M < 0 ? -1 : M * MaskScale + i);

    SDValue LegalShuffle =
        TLI.buildLegalVectorShuffle(VT, SDLoc(N), SV0, SV1, NewMask, DAG);
    if (LegalShuffle)
      return LegalShuffle;
  }

  return SDValue();
}
16699 
visitBUILD_PAIR(SDNode * N)16700 SDValue DAGCombiner::visitBUILD_PAIR(SDNode *N) {
16701   EVT VT = N->getValueType(0);
16702   return CombineConsecutiveLoads(N, VT);
16703 }
16704 
/// Combine a FREEZE node \p N.
///   - Drops the freeze when its operand is already known not to be
///     undef/poison.
///   - Otherwise tries to push the freeze through a single-use operand node
///     that cannot itself create undef/poison (ignoring flags): freeze the
///     maybe-poison operands of that node, then rebuild it without
///     poison-generating flags.
/// May mutate the DAG (ReplaceAllUsesOfValueWith) before returning.
SDValue DAGCombiner::visitFREEZE(SDNode *N) {
  SDValue N0 = N->getOperand(0);

  if (DAG.isGuaranteedNotToBeUndefOrPoison(N0, /*PoisonOnly*/ false))
    return N0;

  // We currently avoid folding freeze over SRA/SRL, due to the problems seen
  // with (freeze (assert ext)) blocking simplifications of SRA/SRL. See for
  // example https://reviews.llvm.org/D136529#4120959.
  if (N0.getOpcode() == ISD::SRA || N0.getOpcode() == ISD::SRL)
    return SDValue();

  // Fold freeze(op(x, ...)) -> op(freeze(x), ...).
  // Try to push freeze through instructions that propagate but don't produce
  // poison as far as possible. If an operand of freeze follows three
  // conditions 1) one-use, 2) does not produce poison, and 3) has all but one
  // guaranteed-non-poison operands (or is a BUILD_VECTOR or similar) then push
  // the freeze through to the operands that are not guaranteed non-poison.
  // NOTE: we will strip poison-generating flags, so ignore them here.
  if (DAG.canCreateUndefOrPoison(N0, /*PoisonOnly*/ false,
                                 /*ConsiderFlags*/ false) ||
      N0->getNumValues() != 1 || !N0->hasOneUse())
    return SDValue();

  // TODO: we should always allow multiple operands, however this increases the
  // likelihood of infinite loops due to the ReplaceAllUsesOfValueWith call
  // below causing later nodes that share frozen operands to fold again and no
  // longer being able to confirm other operands are not poison due to recursion
  // depth limits on isGuaranteedNotToBeUndefOrPoison.
  bool AllowMultipleMaybePoisonOperands =
      N0.getOpcode() == ISD::SELECT_CC || N0.getOpcode() == ISD::SETCC ||
      N0.getOpcode() == ISD::BUILD_VECTOR ||
      N0.getOpcode() == ISD::BUILD_PAIR ||
      N0.getOpcode() == ISD::VECTOR_SHUFFLE ||
      N0.getOpcode() == ISD::CONCAT_VECTORS || N0.getOpcode() == ISD::FMUL;

  // Avoid turning a BUILD_VECTOR that can be recognized as "all zeros", "all
  // ones" or "constant" into something that depends on FrozenUndef. We can
  // instead pick undef values to keep those properties, while at the same time
  // folding away the freeze.
  // If we implement a more general solution for folding away freeze(undef) in
  // the future, then this special handling can be removed.
  if (N0.getOpcode() == ISD::BUILD_VECTOR) {
    SDLoc DL(N0);
    EVT VT = N0.getValueType();
    if (llvm::ISD::isBuildVectorAllOnes(N0.getNode()) && VT.isInteger())
      return DAG.getAllOnesConstant(DL, VT);
    if (llvm::ISD::isBuildVectorOfConstantSDNodes(N0.getNode())) {
      // Replace undef lanes with zero so the result is a fully constant vector.
      SmallVector<SDValue, 8> NewVecC;
      for (const SDValue &Op : N0->op_values())
        NewVecC.push_back(
            Op.isUndef() ? DAG.getConstant(0, DL, Op.getValueType()) : Op);
      return DAG.getBuildVector(VT, DL, NewVecC);
    }
  }

  // Collect the distinct operands that may be undef/poison, remembering the
  // operand number of the first occurrence of each.
  SmallSet<SDValue, 8> MaybePoisonOperands;
  SmallVector<unsigned, 8> MaybePoisonOperandNumbers;
  for (auto [OpNo, Op] : enumerate(N0->ops())) {
    if (DAG.isGuaranteedNotToBeUndefOrPoison(Op, /*PoisonOnly*/ false,
                                             /*Depth*/ 1))
      continue;
    bool HadMaybePoisonOperands = !MaybePoisonOperands.empty();
    bool IsNewMaybePoisonOperand = MaybePoisonOperands.insert(Op).second;
    if (IsNewMaybePoisonOperand)
      MaybePoisonOperandNumbers.push_back(OpNo);
    if (!HadMaybePoisonOperands)
      continue;
    if (IsNewMaybePoisonOperand && !AllowMultipleMaybePoisonOperands) {
      // Multiple maybe-poison ops when not allowed - bail out.
      return SDValue();
    }
  }
  // NOTE: the whole op may be not guaranteed to not be undef or poison because
  // it could create undef or poison due to its poison-generating flags.
  // So not finding any maybe-poison operands is fine.

  for (unsigned OpNo : MaybePoisonOperandNumbers) {
    // N0 can mutate during iteration, so make sure to refetch the maybe poison
    // operands via the operand numbers. The typical scenario is that we have
    // something like this
    //   t262: i32 = freeze t181
    //   t150: i32 = ctlz_zero_undef t262
    //   t184: i32 = ctlz_zero_undef t181
    //   t268: i32 = select_cc t181, Constant:i32<0>, t184, t186, setne:ch
    // When freezing the t181 operand we get t262 back, and then the
    // ReplaceAllUsesOfValueWith call will not only replace t181 by t262, but
    // also recursively replace t184 by t150.
    SDValue MaybePoisonOperand = N->getOperand(0).getOperand(OpNo);
    // Don't replace every single UNDEF everywhere with frozen UNDEF, though.
    if (MaybePoisonOperand.isUndef())
      continue;
    // First, freeze each offending operand.
    SDValue FrozenMaybePoisonOperand = DAG.getFreeze(MaybePoisonOperand);
    // Then, change all other uses of unfrozen operand to use frozen operand.
    DAG.ReplaceAllUsesOfValueWith(MaybePoisonOperand, FrozenMaybePoisonOperand);
    if (FrozenMaybePoisonOperand.getOpcode() == ISD::FREEZE &&
        FrozenMaybePoisonOperand.getOperand(0) == FrozenMaybePoisonOperand) {
      // But, that also updated the use in the freeze we just created, thus
      // creating a cycle in a DAG. Let's undo that by mutating the freeze.
      DAG.UpdateNodeOperands(FrozenMaybePoisonOperand.getNode(),
                             MaybePoisonOperand);
    }

    // This node has been merged with another.
    if (N->getOpcode() == ISD::DELETED_NODE)
      return SDValue(N, 0);
  }

  assert(N->getOpcode() != ISD::DELETED_NODE && "Node was deleted!");

  // The whole node may have been updated, so the value we were holding
  // may no longer be valid. Re-fetch the operand we're `freeze`ing.
  N0 = N->getOperand(0);

  // Finally, recreate the node; its operands were updated to use
  // frozen operands, so we just need to use its "original" operands.
  SmallVector<SDValue> Ops(N0->ops());
  // TODO: ISD::UNDEF and ISD::POISON should get separate handling, but best
  // leave for a future patch.
  for (SDValue &Op : Ops) {
    if (Op.isUndef())
      Op = DAG.getFreeze(Op);
  }

  SDLoc DL(N0);

  // Special case handling for ShuffleVectorSDNode nodes.
  if (auto *SVN = dyn_cast<ShuffleVectorSDNode>(N0))
    return DAG.getVectorShuffle(N0.getValueType(), DL, Ops[0], Ops[1],
                                SVN->getMask());

  // NOTE: this strips poison generating flags.
  // Folding freeze(op(x, ...)) -> op(freeze(x), ...) does not require nnan,
  // ninf, nsz, or fast.
  // However, contract, reassoc, afn, and arcp should be preserved,
  // as these fast-math flags do not introduce poison values.
  SDNodeFlags SrcFlags = N0->getFlags();
  SDNodeFlags SafeFlags;
  SafeFlags.setAllowContract(SrcFlags.hasAllowContract());
  SafeFlags.setAllowReassociation(SrcFlags.hasAllowReassociation());
  SafeFlags.setApproximateFuncs(SrcFlags.hasApproximateFuncs());
  SafeFlags.setAllowReciprocal(SrcFlags.hasAllowReciprocal());
  return DAG.getNode(N0.getOpcode(), DL, N0->getVTList(), Ops, SafeFlags);
}
16850 
16851 // Returns true if floating point contraction is allowed on the FMUL-SDValue
16852 // `N`
isContractableFMUL(const TargetOptions & Options,SDValue N)16853 static bool isContractableFMUL(const TargetOptions &Options, SDValue N) {
16854   assert(N.getOpcode() == ISD::FMUL);
16855 
16856   return Options.AllowFPOpFusion == FPOpFusion::Fast ||
16857          N->getFlags().hasAllowContract();
16858 }
16859 
16860 // Returns true if `N` can assume no infinities involved in its computation.
hasNoInfs(const TargetOptions & Options,SDValue N)16861 static bool hasNoInfs(const TargetOptions &Options, SDValue N) {
16862   return Options.NoInfsFPMath || N->getFlags().hasNoInfs();
16863 }
16864 
16865 /// Try to perform FMA combining on a given FADD node.
16866 template <class MatchContextClass>
visitFADDForFMACombine(SDNode * N)16867 SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
16868   SDValue N0 = N->getOperand(0);
16869   SDValue N1 = N->getOperand(1);
16870   EVT VT = N->getValueType(0);
16871   SDLoc SL(N);
16872   MatchContextClass matcher(DAG, TLI, N);
16873   const TargetOptions &Options = DAG.getTarget().Options;
16874 
16875   bool UseVP = std::is_same_v<MatchContextClass, VPMatchContext>;
16876 
16877   // Floating-point multiply-add with intermediate rounding.
16878   // FIXME: Make isFMADLegal have specific behavior when using VPMatchContext.
16879   // FIXME: Add VP_FMAD opcode.
16880   bool HasFMAD = !UseVP && (LegalOperations && TLI.isFMADLegal(DAG, N));
16881 
16882   // Floating-point multiply-add without intermediate rounding.
16883   bool HasFMA =
16884       (!LegalOperations || matcher.isOperationLegalOrCustom(ISD::FMA, VT)) &&
16885       TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT);
16886 
16887   // No valid opcode, do not combine.
16888   if (!HasFMAD && !HasFMA)
16889     return SDValue();
16890 
16891   bool AllowFusionGlobally =
16892       Options.AllowFPOpFusion == FPOpFusion::Fast || HasFMAD;
16893   // If the addition is not contractable, do not combine.
16894   if (!AllowFusionGlobally && !N->getFlags().hasAllowContract())
16895     return SDValue();
16896 
16897   // Folding fadd (fmul x, y), (fmul x, y) -> fma x, y, (fmul x, y) is never
16898   // beneficial. It does not reduce latency. It increases register pressure. It
16899   // replaces an fadd with an fma which is a more complex instruction, so is
16900   // likely to have a larger encoding, use more functional units, etc.
16901   if (N0 == N1)
16902     return SDValue();
16903 
16904   if (TLI.generateFMAsInMachineCombiner(VT, OptLevel))
16905     return SDValue();
16906 
16907   // Always prefer FMAD to FMA for precision.
16908   unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
16909   bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
16910 
16911   auto isFusedOp = [&](SDValue N) {
16912     return matcher.match(N, ISD::FMA) || matcher.match(N, ISD::FMAD);
16913   };
16914 
16915   // Is the node an FMUL and contractable either due to global flags or
16916   // SDNodeFlags.
16917   auto isContractableFMUL = [AllowFusionGlobally, &matcher](SDValue N) {
16918     if (!matcher.match(N, ISD::FMUL))
16919       return false;
16920     return AllowFusionGlobally || N->getFlags().hasAllowContract();
16921   };
16922   // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
16923   // prefer to fold the multiply with fewer uses.
16924   if (Aggressive && isContractableFMUL(N0) && isContractableFMUL(N1)) {
16925     if (N0->use_size() > N1->use_size())
16926       std::swap(N0, N1);
16927   }
16928 
16929   // fold (fadd (fmul x, y), z) -> (fma x, y, z)
16930   if (isContractableFMUL(N0) && (Aggressive || N0->hasOneUse())) {
16931     return matcher.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0),
16932                            N0.getOperand(1), N1);
16933   }
16934 
16935   // fold (fadd x, (fmul y, z)) -> (fma y, z, x)
16936   // Note: Commutes FADD operands.
16937   if (isContractableFMUL(N1) && (Aggressive || N1->hasOneUse())) {
16938     return matcher.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(0),
16939                            N1.getOperand(1), N0);
16940   }
16941 
16942   // fadd (fma A, B, (fmul C, D)), E --> fma A, B, (fma C, D, E)
16943   // fadd E, (fma A, B, (fmul C, D)) --> fma A, B, (fma C, D, E)
16944   // This also works with nested fma instructions:
16945   // fadd (fma A, B, (fma (C, D, (fmul (E, F))))), G -->
16946   // fma A, B, (fma C, D, fma (E, F, G))
16947   // fadd (G, (fma A, B, (fma (C, D, (fmul (E, F)))))) -->
16948   // fma A, B, (fma C, D, fma (E, F, G)).
16949   // This requires reassociation because it changes the order of operations.
16950   bool CanReassociate =
16951       Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();
16952   if (CanReassociate) {
16953     SDValue FMA, E;
16954     if (isFusedOp(N0) && N0.hasOneUse()) {
16955       FMA = N0;
16956       E = N1;
16957     } else if (isFusedOp(N1) && N1.hasOneUse()) {
16958       FMA = N1;
16959       E = N0;
16960     }
16961 
16962     SDValue TmpFMA = FMA;
16963     while (E && isFusedOp(TmpFMA) && TmpFMA.hasOneUse()) {
16964       SDValue FMul = TmpFMA->getOperand(2);
16965       if (matcher.match(FMul, ISD::FMUL) && FMul.hasOneUse()) {
16966         SDValue C = FMul.getOperand(0);
16967         SDValue D = FMul.getOperand(1);
16968         SDValue CDE = matcher.getNode(PreferredFusedOpcode, SL, VT, C, D, E);
16969         DAG.ReplaceAllUsesOfValueWith(FMul, CDE);
16970         // Replacing the inner FMul could cause the outer FMA to be simplified
16971         // away.
16972         return FMA.getOpcode() == ISD::DELETED_NODE ? SDValue(N, 0) : FMA;
16973       }
16974 
16975       TmpFMA = TmpFMA->getOperand(2);
16976     }
16977   }
16978 
16979   // Look through FP_EXTEND nodes to do more combining.
16980 
16981   // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
16982   if (matcher.match(N0, ISD::FP_EXTEND)) {
16983     SDValue N00 = N0.getOperand(0);
16984     if (isContractableFMUL(N00) &&
16985         TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
16986                             N00.getValueType())) {
16987       return matcher.getNode(
16988           PreferredFusedOpcode, SL, VT,
16989           matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)),
16990           matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)), N1);
16991     }
16992   }
16993 
16994   // fold (fadd x, (fpext (fmul y, z))) -> (fma (fpext y), (fpext z), x)
16995   // Note: Commutes FADD operands.
16996   if (matcher.match(N1, ISD::FP_EXTEND)) {
16997     SDValue N10 = N1.getOperand(0);
16998     if (isContractableFMUL(N10) &&
16999         TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
17000                             N10.getValueType())) {
17001       return matcher.getNode(
17002           PreferredFusedOpcode, SL, VT,
17003           matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(0)),
17004           matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(1)), N0);
17005     }
17006   }
17007 
17008   // More folding opportunities when target permits.
17009   if (Aggressive) {
17010     // fold (fadd (fma x, y, (fpext (fmul u, v))), z)
17011     //   -> (fma x, y, (fma (fpext u), (fpext v), z))
17012     auto FoldFAddFMAFPExtFMul = [&](SDValue X, SDValue Y, SDValue U, SDValue V,
17013                                     SDValue Z) {
17014       return matcher.getNode(
17015           PreferredFusedOpcode, SL, VT, X, Y,
17016           matcher.getNode(PreferredFusedOpcode, SL, VT,
17017                           matcher.getNode(ISD::FP_EXTEND, SL, VT, U),
17018                           matcher.getNode(ISD::FP_EXTEND, SL, VT, V), Z));
17019     };
17020     if (isFusedOp(N0)) {
17021       SDValue N02 = N0.getOperand(2);
17022       if (matcher.match(N02, ISD::FP_EXTEND)) {
17023         SDValue N020 = N02.getOperand(0);
17024         if (isContractableFMUL(N020) &&
17025             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
17026                                 N020.getValueType())) {
17027           return FoldFAddFMAFPExtFMul(N0.getOperand(0), N0.getOperand(1),
17028                                       N020.getOperand(0), N020.getOperand(1),
17029                                       N1);
17030         }
17031       }
17032     }
17033 
17034     // fold (fadd (fpext (fma x, y, (fmul u, v))), z)
17035     //   -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z))
17036     // FIXME: This turns two single-precision and one double-precision
17037     // operation into two double-precision operations, which might not be
17038     // interesting for all targets, especially GPUs.
17039     auto FoldFAddFPExtFMAFMul = [&](SDValue X, SDValue Y, SDValue U, SDValue V,
17040                                     SDValue Z) {
17041       return matcher.getNode(
17042           PreferredFusedOpcode, SL, VT,
17043           matcher.getNode(ISD::FP_EXTEND, SL, VT, X),
17044           matcher.getNode(ISD::FP_EXTEND, SL, VT, Y),
17045           matcher.getNode(PreferredFusedOpcode, SL, VT,
17046                           matcher.getNode(ISD::FP_EXTEND, SL, VT, U),
17047                           matcher.getNode(ISD::FP_EXTEND, SL, VT, V), Z));
17048     };
17049     if (N0.getOpcode() == ISD::FP_EXTEND) {
17050       SDValue N00 = N0.getOperand(0);
17051       if (isFusedOp(N00)) {
17052         SDValue N002 = N00.getOperand(2);
17053         if (isContractableFMUL(N002) &&
17054             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
17055                                 N00.getValueType())) {
17056           return FoldFAddFPExtFMAFMul(N00.getOperand(0), N00.getOperand(1),
17057                                       N002.getOperand(0), N002.getOperand(1),
17058                                       N1);
17059         }
17060       }
17061     }
17062 
17063     // fold (fadd x, (fma y, z, (fpext (fmul u, v)))
17064     //   -> (fma y, z, (fma (fpext u), (fpext v), x))
17065     if (isFusedOp(N1)) {
17066       SDValue N12 = N1.getOperand(2);
17067       if (N12.getOpcode() == ISD::FP_EXTEND) {
17068         SDValue N120 = N12.getOperand(0);
17069         if (isContractableFMUL(N120) &&
17070             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
17071                                 N120.getValueType())) {
17072           return FoldFAddFMAFPExtFMul(N1.getOperand(0), N1.getOperand(1),
17073                                       N120.getOperand(0), N120.getOperand(1),
17074                                       N0);
17075         }
17076       }
17077     }
17078 
17079     // fold (fadd x, (fpext (fma y, z, (fmul u, v)))
17080     //   -> (fma (fpext y), (fpext z), (fma (fpext u), (fpext v), x))
17081     // FIXME: This turns two single-precision and one double-precision
17082     // operation into two double-precision operations, which might not be
17083     // interesting for all targets, especially GPUs.
17084     if (N1.getOpcode() == ISD::FP_EXTEND) {
17085       SDValue N10 = N1.getOperand(0);
17086       if (isFusedOp(N10)) {
17087         SDValue N102 = N10.getOperand(2);
17088         if (isContractableFMUL(N102) &&
17089             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
17090                                 N10.getValueType())) {
17091           return FoldFAddFPExtFMAFMul(N10.getOperand(0), N10.getOperand(1),
17092                                       N102.getOperand(0), N102.getOperand(1),
17093                                       N0);
17094         }
17095       }
17096     }
17097   }
17098 
17099   return SDValue();
17100 }
17101 
17102 /// Try to perform FMA combining on a given FSUB node.
17103 template <class MatchContextClass>
visitFSUBForFMACombine(SDNode * N)17104 SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
17105   SDValue N0 = N->getOperand(0);
17106   SDValue N1 = N->getOperand(1);
17107   EVT VT = N->getValueType(0);
17108   SDLoc SL(N);
17109   MatchContextClass matcher(DAG, TLI, N);
17110   const TargetOptions &Options = DAG.getTarget().Options;
17111 
17112   bool UseVP = std::is_same_v<MatchContextClass, VPMatchContext>;
17113 
17114   // Floating-point multiply-add with intermediate rounding.
17115   // FIXME: Make isFMADLegal have specific behavior when using VPMatchContext.
17116   // FIXME: Add VP_FMAD opcode.
17117   bool HasFMAD = !UseVP && (LegalOperations && TLI.isFMADLegal(DAG, N));
17118 
17119   // Floating-point multiply-add without intermediate rounding.
17120   bool HasFMA =
17121       (!LegalOperations || matcher.isOperationLegalOrCustom(ISD::FMA, VT)) &&
17122       TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT);
17123 
17124   // No valid opcode, do not combine.
17125   if (!HasFMAD && !HasFMA)
17126     return SDValue();
17127 
17128   const SDNodeFlags Flags = N->getFlags();
17129   bool AllowFusionGlobally =
17130       (Options.AllowFPOpFusion == FPOpFusion::Fast || HasFMAD);
17131 
17132   // If the subtraction is not contractable, do not combine.
17133   if (!AllowFusionGlobally && !N->getFlags().hasAllowContract())
17134     return SDValue();
17135 
17136   if (TLI.generateFMAsInMachineCombiner(VT, OptLevel))
17137     return SDValue();
17138 
17139   // Always prefer FMAD to FMA for precision.
17140   unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
17141   bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
17142   bool NoSignedZero = Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros();
17143 
17144   // Is the node an FMUL and contractable either due to global flags or
17145   // SDNodeFlags.
17146   auto isContractableFMUL = [AllowFusionGlobally, &matcher](SDValue N) {
17147     if (!matcher.match(N, ISD::FMUL))
17148       return false;
17149     return AllowFusionGlobally || N->getFlags().hasAllowContract();
17150   };
17151 
17152   // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z))
17153   auto tryToFoldXYSubZ = [&](SDValue XY, SDValue Z) {
17154     if (isContractableFMUL(XY) && (Aggressive || XY->hasOneUse())) {
17155       return matcher.getNode(PreferredFusedOpcode, SL, VT, XY.getOperand(0),
17156                              XY.getOperand(1),
17157                              matcher.getNode(ISD::FNEG, SL, VT, Z));
17158     }
17159     return SDValue();
17160   };
17161 
17162   // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x)
17163   // Note: Commutes FSUB operands.
17164   auto tryToFoldXSubYZ = [&](SDValue X, SDValue YZ) {
17165     if (isContractableFMUL(YZ) && (Aggressive || YZ->hasOneUse())) {
17166       return matcher.getNode(
17167           PreferredFusedOpcode, SL, VT,
17168           matcher.getNode(ISD::FNEG, SL, VT, YZ.getOperand(0)),
17169           YZ.getOperand(1), X);
17170     }
17171     return SDValue();
17172   };
17173 
17174   // If we have two choices trying to fold (fsub (fmul u, v), (fmul x, y)),
17175   // prefer to fold the multiply with fewer uses.
17176   if (isContractableFMUL(N0) && isContractableFMUL(N1) &&
17177       (N0->use_size() > N1->use_size())) {
17178     // fold (fsub (fmul a, b), (fmul c, d)) -> (fma (fneg c), d, (fmul a, b))
17179     if (SDValue V = tryToFoldXSubYZ(N0, N1))
17180       return V;
17181     // fold (fsub (fmul a, b), (fmul c, d)) -> (fma a, b, (fneg (fmul c, d)))
17182     if (SDValue V = tryToFoldXYSubZ(N0, N1))
17183       return V;
17184   } else {
17185     // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z))
17186     if (SDValue V = tryToFoldXYSubZ(N0, N1))
17187       return V;
17188     // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x)
17189     if (SDValue V = tryToFoldXSubYZ(N0, N1))
17190       return V;
17191   }
17192 
17193   // fold (fsub (fneg (fmul, x, y)), z) -> (fma (fneg x), y, (fneg z))
17194   if (matcher.match(N0, ISD::FNEG) && isContractableFMUL(N0.getOperand(0)) &&
17195       (Aggressive || (N0->hasOneUse() && N0.getOperand(0).hasOneUse()))) {
17196     SDValue N00 = N0.getOperand(0).getOperand(0);
17197     SDValue N01 = N0.getOperand(0).getOperand(1);
17198     return matcher.getNode(PreferredFusedOpcode, SL, VT,
17199                            matcher.getNode(ISD::FNEG, SL, VT, N00), N01,
17200                            matcher.getNode(ISD::FNEG, SL, VT, N1));
17201   }
17202 
17203   // Look through FP_EXTEND nodes to do more combining.
17204 
17205   // fold (fsub (fpext (fmul x, y)), z)
17206   //   -> (fma (fpext x), (fpext y), (fneg z))
17207   if (matcher.match(N0, ISD::FP_EXTEND)) {
17208     SDValue N00 = N0.getOperand(0);
17209     if (isContractableFMUL(N00) &&
17210         TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
17211                             N00.getValueType())) {
17212       return matcher.getNode(
17213           PreferredFusedOpcode, SL, VT,
17214           matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)),
17215           matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)),
17216           matcher.getNode(ISD::FNEG, SL, VT, N1));
17217     }
17218   }
17219 
17220   // fold (fsub x, (fpext (fmul y, z)))
17221   //   -> (fma (fneg (fpext y)), (fpext z), x)
17222   // Note: Commutes FSUB operands.
17223   if (matcher.match(N1, ISD::FP_EXTEND)) {
17224     SDValue N10 = N1.getOperand(0);
17225     if (isContractableFMUL(N10) &&
17226         TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
17227                             N10.getValueType())) {
17228       return matcher.getNode(
17229           PreferredFusedOpcode, SL, VT,
17230           matcher.getNode(
17231               ISD::FNEG, SL, VT,
17232               matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(0))),
17233           matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(1)), N0);
17234     }
17235   }
17236 
17237   // fold (fsub (fpext (fneg (fmul, x, y))), z)
17238   //   -> (fneg (fma (fpext x), (fpext y), z))
17239   // Note: This could be removed with appropriate canonicalization of the
17240   // input expression into (fneg (fadd (fpext (fmul, x, y)), z). However, the
17241   // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent
17242   // from implementing the canonicalization in visitFSUB.
17243   if (matcher.match(N0, ISD::FP_EXTEND)) {
17244     SDValue N00 = N0.getOperand(0);
17245     if (matcher.match(N00, ISD::FNEG)) {
17246       SDValue N000 = N00.getOperand(0);
17247       if (isContractableFMUL(N000) &&
17248           TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
17249                               N00.getValueType())) {
17250         return matcher.getNode(
17251             ISD::FNEG, SL, VT,
17252             matcher.getNode(
17253                 PreferredFusedOpcode, SL, VT,
17254                 matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(0)),
17255                 matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(1)),
17256                 N1));
17257       }
17258     }
17259   }
17260 
17261   // fold (fsub (fneg (fpext (fmul, x, y))), z)
17262   //   -> (fneg (fma (fpext x)), (fpext y), z)
17263   // Note: This could be removed with appropriate canonicalization of the
17264   // input expression into (fneg (fadd (fpext (fmul, x, y)), z). However, the
17265   // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent
17266   // from implementing the canonicalization in visitFSUB.
17267   if (matcher.match(N0, ISD::FNEG)) {
17268     SDValue N00 = N0.getOperand(0);
17269     if (matcher.match(N00, ISD::FP_EXTEND)) {
17270       SDValue N000 = N00.getOperand(0);
17271       if (isContractableFMUL(N000) &&
17272           TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
17273                               N000.getValueType())) {
17274         return matcher.getNode(
17275             ISD::FNEG, SL, VT,
17276             matcher.getNode(
17277                 PreferredFusedOpcode, SL, VT,
17278                 matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(0)),
17279                 matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(1)),
17280                 N1));
17281       }
17282     }
17283   }
17284 
17285   auto isContractableAndReassociableFMUL = [&isContractableFMUL](SDValue N) {
17286     return isContractableFMUL(N) && N->getFlags().hasAllowReassociation();
17287   };
17288 
17289   auto isFusedOp = [&](SDValue N) {
17290     return matcher.match(N, ISD::FMA) || matcher.match(N, ISD::FMAD);
17291   };
17292 
17293   // More folding opportunities when target permits.
17294   if (Aggressive && N->getFlags().hasAllowReassociation()) {
17295     bool CanFuse = N->getFlags().hasAllowContract();
17296     // fold (fsub (fma x, y, (fmul u, v)), z)
17297     //   -> (fma x, y (fma u, v, (fneg z)))
17298     if (CanFuse && isFusedOp(N0) &&
17299         isContractableAndReassociableFMUL(N0.getOperand(2)) &&
17300         N0->hasOneUse() && N0.getOperand(2)->hasOneUse()) {
17301       return matcher.getNode(
17302           PreferredFusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1),
17303           matcher.getNode(PreferredFusedOpcode, SL, VT,
17304                           N0.getOperand(2).getOperand(0),
17305                           N0.getOperand(2).getOperand(1),
17306                           matcher.getNode(ISD::FNEG, SL, VT, N1)));
17307     }
17308 
17309     // fold (fsub x, (fma y, z, (fmul u, v)))
17310     //   -> (fma (fneg y), z, (fma (fneg u), v, x))
17311     if (CanFuse && isFusedOp(N1) &&
17312         isContractableAndReassociableFMUL(N1.getOperand(2)) &&
17313         N1->hasOneUse() && NoSignedZero) {
17314       SDValue N20 = N1.getOperand(2).getOperand(0);
17315       SDValue N21 = N1.getOperand(2).getOperand(1);
17316       return matcher.getNode(
17317           PreferredFusedOpcode, SL, VT,
17318           matcher.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)),
17319           N1.getOperand(1),
17320           matcher.getNode(PreferredFusedOpcode, SL, VT,
17321                           matcher.getNode(ISD::FNEG, SL, VT, N20), N21, N0));
17322     }
17323 
17324     // fold (fsub (fma x, y, (fpext (fmul u, v))), z)
17325     //   -> (fma x, y (fma (fpext u), (fpext v), (fneg z)))
17326     if (isFusedOp(N0) && N0->hasOneUse()) {
17327       SDValue N02 = N0.getOperand(2);
17328       if (matcher.match(N02, ISD::FP_EXTEND)) {
17329         SDValue N020 = N02.getOperand(0);
17330         if (isContractableAndReassociableFMUL(N020) &&
17331             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
17332                                 N020.getValueType())) {
17333           return matcher.getNode(
17334               PreferredFusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1),
17335               matcher.getNode(
17336                   PreferredFusedOpcode, SL, VT,
17337                   matcher.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(0)),
17338                   matcher.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(1)),
17339                   matcher.getNode(ISD::FNEG, SL, VT, N1)));
17340         }
17341       }
17342     }
17343 
17344     // fold (fsub (fpext (fma x, y, (fmul u, v))), z)
17345     //   -> (fma (fpext x), (fpext y),
17346     //           (fma (fpext u), (fpext v), (fneg z)))
17347     // FIXME: This turns two single-precision and one double-precision
17348     // operation into two double-precision operations, which might not be
17349     // interesting for all targets, especially GPUs.
17350     if (matcher.match(N0, ISD::FP_EXTEND)) {
17351       SDValue N00 = N0.getOperand(0);
17352       if (isFusedOp(N00)) {
17353         SDValue N002 = N00.getOperand(2);
17354         if (isContractableAndReassociableFMUL(N002) &&
17355             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
17356                                 N00.getValueType())) {
17357           return matcher.getNode(
17358               PreferredFusedOpcode, SL, VT,
17359               matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)),
17360               matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)),
17361               matcher.getNode(
17362                   PreferredFusedOpcode, SL, VT,
17363                   matcher.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(0)),
17364                   matcher.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(1)),
17365                   matcher.getNode(ISD::FNEG, SL, VT, N1)));
17366         }
17367       }
17368     }
17369 
17370     // fold (fsub x, (fma y, z, (fpext (fmul u, v))))
17371     //   -> (fma (fneg y), z, (fma (fneg (fpext u)), (fpext v), x))
17372     if (isFusedOp(N1) && matcher.match(N1.getOperand(2), ISD::FP_EXTEND) &&
17373         N1->hasOneUse()) {
17374       SDValue N120 = N1.getOperand(2).getOperand(0);
17375       if (isContractableAndReassociableFMUL(N120) &&
17376           TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
17377                               N120.getValueType())) {
17378         SDValue N1200 = N120.getOperand(0);
17379         SDValue N1201 = N120.getOperand(1);
17380         return matcher.getNode(
17381             PreferredFusedOpcode, SL, VT,
17382             matcher.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)),
17383             N1.getOperand(1),
17384             matcher.getNode(
17385                 PreferredFusedOpcode, SL, VT,
17386                 matcher.getNode(ISD::FNEG, SL, VT,
17387                                 matcher.getNode(ISD::FP_EXTEND, SL, VT, N1200)),
17388                 matcher.getNode(ISD::FP_EXTEND, SL, VT, N1201), N0));
17389       }
17390     }
17391 
17392     // fold (fsub x, (fpext (fma y, z, (fmul u, v))))
17393     //   -> (fma (fneg (fpext y)), (fpext z),
17394     //           (fma (fneg (fpext u)), (fpext v), x))
17395     // FIXME: This turns two single-precision and one double-precision
17396     // operation into two double-precision operations, which might not be
17397     // interesting for all targets, especially GPUs.
17398     if (matcher.match(N1, ISD::FP_EXTEND) && isFusedOp(N1.getOperand(0))) {
17399       SDValue CvtSrc = N1.getOperand(0);
17400       SDValue N100 = CvtSrc.getOperand(0);
17401       SDValue N101 = CvtSrc.getOperand(1);
17402       SDValue N102 = CvtSrc.getOperand(2);
17403       if (isContractableAndReassociableFMUL(N102) &&
17404           TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
17405                               CvtSrc.getValueType())) {
17406         SDValue N1020 = N102.getOperand(0);
17407         SDValue N1021 = N102.getOperand(1);
17408         return matcher.getNode(
17409             PreferredFusedOpcode, SL, VT,
17410             matcher.getNode(ISD::FNEG, SL, VT,
17411                             matcher.getNode(ISD::FP_EXTEND, SL, VT, N100)),
17412             matcher.getNode(ISD::FP_EXTEND, SL, VT, N101),
17413             matcher.getNode(
17414                 PreferredFusedOpcode, SL, VT,
17415                 matcher.getNode(ISD::FNEG, SL, VT,
17416                                 matcher.getNode(ISD::FP_EXTEND, SL, VT, N1020)),
17417                 matcher.getNode(ISD::FP_EXTEND, SL, VT, N1021), N0));
17418       }
17419     }
17420   }
17421 
17422   return SDValue();
17423 }
17424 
17425 /// Try to perform FMA combining on a given FMUL node based on the distributive
17426 /// law x * (y + 1) = x * y + x and variants thereof (commuted versions,
17427 /// subtraction instead of addition).
visitFMULForFMADistributiveCombine(SDNode * N)17428 SDValue DAGCombiner::visitFMULForFMADistributiveCombine(SDNode *N) {
17429   SDValue N0 = N->getOperand(0);
17430   SDValue N1 = N->getOperand(1);
17431   EVT VT = N->getValueType(0);
17432   SDLoc SL(N);
17433 
17434   assert(N->getOpcode() == ISD::FMUL && "Expected FMUL Operation");
17435 
17436   const TargetOptions &Options = DAG.getTarget().Options;
17437 
17438   // The transforms below are incorrect when x == 0 and y == inf, because the
17439   // intermediate multiplication produces a nan.
17440   SDValue FAdd = N0.getOpcode() == ISD::FADD ? N0 : N1;
17441   if (!hasNoInfs(Options, FAdd))
17442     return SDValue();
17443 
17444   // Floating-point multiply-add without intermediate rounding.
17445   bool HasFMA =
17446       isContractableFMUL(Options, SDValue(N, 0)) &&
17447       (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT)) &&
17448       TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT);
17449 
17450   // Floating-point multiply-add with intermediate rounding. This can result
17451   // in a less precise result due to the changed rounding order.
17452   bool HasFMAD = LegalOperations && TLI.isFMADLegal(DAG, N);
17453 
17454   // No valid opcode, do not combine.
17455   if (!HasFMAD && !HasFMA)
17456     return SDValue();
17457 
17458   // Always prefer FMAD to FMA for precision.
17459   unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
17460   bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
17461 
17462   // fold (fmul (fadd x0, +1.0), y) -> (fma x0, y, y)
17463   // fold (fmul (fadd x0, -1.0), y) -> (fma x0, y, (fneg y))
17464   auto FuseFADD = [&](SDValue X, SDValue Y) {
17465     if (X.getOpcode() == ISD::FADD && (Aggressive || X->hasOneUse())) {
17466       if (auto *C = isConstOrConstSplatFP(X.getOperand(1), true)) {
17467         if (C->isExactlyValue(+1.0))
17468           return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
17469                              Y);
17470         if (C->isExactlyValue(-1.0))
17471           return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
17472                              DAG.getNode(ISD::FNEG, SL, VT, Y));
17473       }
17474     }
17475     return SDValue();
17476   };
17477 
17478   if (SDValue FMA = FuseFADD(N0, N1))
17479     return FMA;
17480   if (SDValue FMA = FuseFADD(N1, N0))
17481     return FMA;
17482 
17483   // fold (fmul (fsub +1.0, x1), y) -> (fma (fneg x1), y, y)
17484   // fold (fmul (fsub -1.0, x1), y) -> (fma (fneg x1), y, (fneg y))
17485   // fold (fmul (fsub x0, +1.0), y) -> (fma x0, y, (fneg y))
17486   // fold (fmul (fsub x0, -1.0), y) -> (fma x0, y, y)
17487   auto FuseFSUB = [&](SDValue X, SDValue Y) {
17488     if (X.getOpcode() == ISD::FSUB && (Aggressive || X->hasOneUse())) {
17489       if (auto *C0 = isConstOrConstSplatFP(X.getOperand(0), true)) {
17490         if (C0->isExactlyValue(+1.0))
17491           return DAG.getNode(PreferredFusedOpcode, SL, VT,
17492                              DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y,
17493                              Y);
17494         if (C0->isExactlyValue(-1.0))
17495           return DAG.getNode(PreferredFusedOpcode, SL, VT,
17496                              DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y,
17497                              DAG.getNode(ISD::FNEG, SL, VT, Y));
17498       }
17499       if (auto *C1 = isConstOrConstSplatFP(X.getOperand(1), true)) {
17500         if (C1->isExactlyValue(+1.0))
17501           return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
17502                              DAG.getNode(ISD::FNEG, SL, VT, Y));
17503         if (C1->isExactlyValue(-1.0))
17504           return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
17505                              Y);
17506       }
17507     }
17508     return SDValue();
17509   };
17510 
17511   if (SDValue FMA = FuseFSUB(N0, N1))
17512     return FMA;
17513   if (SDValue FMA = FuseFSUB(N1, N0))
17514     return FMA;
17515 
17516   return SDValue();
17517 }
17518 
visitVP_FADD(SDNode * N)17519 SDValue DAGCombiner::visitVP_FADD(SDNode *N) {
17520   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
17521 
17522   // FADD -> FMA combines:
17523   if (SDValue Fused = visitFADDForFMACombine<VPMatchContext>(N)) {
17524     if (Fused.getOpcode() != ISD::DELETED_NODE)
17525       AddToWorklist(Fused.getNode());
17526     return Fused;
17527   }
17528   return SDValue();
17529 }
17530 
visitFADD(SDNode * N)17531 SDValue DAGCombiner::visitFADD(SDNode *N) {
17532   SDValue N0 = N->getOperand(0);
17533   SDValue N1 = N->getOperand(1);
17534   bool N0CFP = DAG.isConstantFPBuildVectorOrConstantFP(N0);
17535   bool N1CFP = DAG.isConstantFPBuildVectorOrConstantFP(N1);
17536   EVT VT = N->getValueType(0);
17537   SDLoc DL(N);
17538   const TargetOptions &Options = DAG.getTarget().Options;
17539   SDNodeFlags Flags = N->getFlags();
17540   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
17541 
17542   if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
17543     return R;
17544 
17545   // fold (fadd c1, c2) -> c1 + c2
17546   if (SDValue C = DAG.FoldConstantArithmetic(ISD::FADD, DL, VT, {N0, N1}))
17547     return C;
17548 
17549   // canonicalize constant to RHS
17550   if (N0CFP && !N1CFP)
17551     return DAG.getNode(ISD::FADD, DL, VT, N1, N0);
17552 
17553   // fold vector ops
17554   if (VT.isVector())
17555     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
17556       return FoldedVOp;
17557 
17558   // N0 + -0.0 --> N0 (also allowed with +0.0 and fast-math)
17559   ConstantFPSDNode *N1C = isConstOrConstSplatFP(N1, true);
17560   if (N1C && N1C->isZero())
17561     if (N1C->isNegative() || Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros())
17562       return N0;
17563 
17564   if (SDValue NewSel = foldBinOpIntoSelect(N))
17565     return NewSel;
17566 
17567   // fold (fadd A, (fneg B)) -> (fsub A, B)
17568   if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT))
17569     if (SDValue NegN1 = TLI.getCheaperNegatedExpression(
17570             N1, DAG, LegalOperations, ForCodeSize))
17571       return DAG.getNode(ISD::FSUB, DL, VT, N0, NegN1);
17572 
17573   // fold (fadd (fneg A), B) -> (fsub B, A)
17574   if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT))
17575     if (SDValue NegN0 = TLI.getCheaperNegatedExpression(
17576             N0, DAG, LegalOperations, ForCodeSize))
17577       return DAG.getNode(ISD::FSUB, DL, VT, N1, NegN0);
17578 
17579   auto isFMulNegTwo = [](SDValue FMul) {
17580     if (!FMul.hasOneUse() || FMul.getOpcode() != ISD::FMUL)
17581       return false;
17582     auto *C = isConstOrConstSplatFP(FMul.getOperand(1), true);
17583     return C && C->isExactlyValue(-2.0);
17584   };
17585 
17586   // fadd (fmul B, -2.0), A --> fsub A, (fadd B, B)
17587   if (isFMulNegTwo(N0)) {
17588     SDValue B = N0.getOperand(0);
17589     SDValue Add = DAG.getNode(ISD::FADD, DL, VT, B, B);
17590     return DAG.getNode(ISD::FSUB, DL, VT, N1, Add);
17591   }
17592   // fadd A, (fmul B, -2.0) --> fsub A, (fadd B, B)
17593   if (isFMulNegTwo(N1)) {
17594     SDValue B = N1.getOperand(0);
17595     SDValue Add = DAG.getNode(ISD::FADD, DL, VT, B, B);
17596     return DAG.getNode(ISD::FSUB, DL, VT, N0, Add);
17597   }
17598 
17599   // No FP constant should be created after legalization as Instruction
17600   // Selection pass has a hard time dealing with FP constants.
17601   bool AllowNewConst = (Level < AfterLegalizeDAG);
17602 
17603   // If nnan is enabled, fold lots of things.
17604   if ((Options.NoNaNsFPMath || Flags.hasNoNaNs()) && AllowNewConst) {
17605     // If allowed, fold (fadd (fneg x), x) -> 0.0
17606     if (N0.getOpcode() == ISD::FNEG && N0.getOperand(0) == N1)
17607       return DAG.getConstantFP(0.0, DL, VT);
17608 
17609     // If allowed, fold (fadd x, (fneg x)) -> 0.0
17610     if (N1.getOpcode() == ISD::FNEG && N1.getOperand(0) == N0)
17611       return DAG.getConstantFP(0.0, DL, VT);
17612   }
17613 
17614   // If 'unsafe math' or reassoc and nsz, fold lots of things.
17615   // TODO: break out portions of the transformations below for which Unsafe is
17616   //       considered and which do not require both nsz and reassoc
17617   if (((Options.UnsafeFPMath && Options.NoSignedZerosFPMath) ||
17618        (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros())) &&
17619       AllowNewConst) {
17620     // fadd (fadd x, c1), c2 -> fadd x, c1 + c2
17621     if (N1CFP && N0.getOpcode() == ISD::FADD &&
17622         DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(1))) {
17623       SDValue NewC = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1), N1);
17624       return DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(0), NewC);
17625     }
17626 
17627     // We can fold chains of FADD's of the same value into multiplications.
17628     // This transform is not safe in general because we are reducing the number
17629     // of rounding steps.
17630     if (TLI.isOperationLegalOrCustom(ISD::FMUL, VT) && !N0CFP && !N1CFP) {
17631       if (N0.getOpcode() == ISD::FMUL) {
17632         bool CFP00 = DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(0));
17633         bool CFP01 = DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(1));
17634 
17635         // (fadd (fmul x, c), x) -> (fmul x, c+1)
17636         if (CFP01 && !CFP00 && N0.getOperand(0) == N1) {
17637           SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1),
17638                                        DAG.getConstantFP(1.0, DL, VT));
17639           return DAG.getNode(ISD::FMUL, DL, VT, N1, NewCFP);
17640         }
17641 
17642         // (fadd (fmul x, c), (fadd x, x)) -> (fmul x, c+2)
17643         if (CFP01 && !CFP00 && N1.getOpcode() == ISD::FADD &&
17644             N1.getOperand(0) == N1.getOperand(1) &&
17645             N0.getOperand(0) == N1.getOperand(0)) {
17646           SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1),
17647                                        DAG.getConstantFP(2.0, DL, VT));
17648           return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), NewCFP);
17649         }
17650       }
17651 
17652       if (N1.getOpcode() == ISD::FMUL) {
17653         bool CFP10 = DAG.isConstantFPBuildVectorOrConstantFP(N1.getOperand(0));
17654         bool CFP11 = DAG.isConstantFPBuildVectorOrConstantFP(N1.getOperand(1));
17655 
17656         // (fadd x, (fmul x, c)) -> (fmul x, c+1)
17657         if (CFP11 && !CFP10 && N1.getOperand(0) == N0) {
17658           SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N1.getOperand(1),
17659                                        DAG.getConstantFP(1.0, DL, VT));
17660           return DAG.getNode(ISD::FMUL, DL, VT, N0, NewCFP);
17661         }
17662 
17663         // (fadd (fadd x, x), (fmul x, c)) -> (fmul x, c+2)
17664         if (CFP11 && !CFP10 && N0.getOpcode() == ISD::FADD &&
17665             N0.getOperand(0) == N0.getOperand(1) &&
17666             N1.getOperand(0) == N0.getOperand(0)) {
17667           SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N1.getOperand(1),
17668                                        DAG.getConstantFP(2.0, DL, VT));
17669           return DAG.getNode(ISD::FMUL, DL, VT, N1.getOperand(0), NewCFP);
17670         }
17671       }
17672 
17673       if (N0.getOpcode() == ISD::FADD) {
17674         bool CFP00 = DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(0));
17675         // (fadd (fadd x, x), x) -> (fmul x, 3.0)
17676         if (!CFP00 && N0.getOperand(0) == N0.getOperand(1) &&
17677             (N0.getOperand(0) == N1)) {
17678           return DAG.getNode(ISD::FMUL, DL, VT, N1,
17679                              DAG.getConstantFP(3.0, DL, VT));
17680         }
17681       }
17682 
17683       if (N1.getOpcode() == ISD::FADD) {
17684         bool CFP10 = DAG.isConstantFPBuildVectorOrConstantFP(N1.getOperand(0));
17685         // (fadd x, (fadd x, x)) -> (fmul x, 3.0)
17686         if (!CFP10 && N1.getOperand(0) == N1.getOperand(1) &&
17687             N1.getOperand(0) == N0) {
17688           return DAG.getNode(ISD::FMUL, DL, VT, N0,
17689                              DAG.getConstantFP(3.0, DL, VT));
17690         }
17691       }
17692 
17693       // (fadd (fadd x, x), (fadd x, x)) -> (fmul x, 4.0)
17694       if (N0.getOpcode() == ISD::FADD && N1.getOpcode() == ISD::FADD &&
17695           N0.getOperand(0) == N0.getOperand(1) &&
17696           N1.getOperand(0) == N1.getOperand(1) &&
17697           N0.getOperand(0) == N1.getOperand(0)) {
17698         return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0),
17699                            DAG.getConstantFP(4.0, DL, VT));
17700       }
17701     }
17702   } // enable-unsafe-fp-math && AllowNewConst
17703 
17704   if (((Options.UnsafeFPMath && Options.NoSignedZerosFPMath) ||
17705        (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros()))) {
17706     // Fold fadd(vecreduce(x), vecreduce(y)) -> vecreduce(fadd(x, y))
17707     if (SDValue SD = reassociateReduction(ISD::VECREDUCE_FADD, ISD::FADD, DL,
17708                                           VT, N0, N1, Flags))
17709       return SD;
17710   }
17711 
17712   // FADD -> FMA combines:
17713   if (SDValue Fused = visitFADDForFMACombine<EmptyMatchContext>(N)) {
17714     if (Fused.getOpcode() != ISD::DELETED_NODE)
17715       AddToWorklist(Fused.getNode());
17716     return Fused;
17717   }
17718   return SDValue();
17719 }
17720 
visitSTRICT_FADD(SDNode * N)17721 SDValue DAGCombiner::visitSTRICT_FADD(SDNode *N) {
17722   SDValue Chain = N->getOperand(0);
17723   SDValue N0 = N->getOperand(1);
17724   SDValue N1 = N->getOperand(2);
17725   EVT VT = N->getValueType(0);
17726   EVT ChainVT = N->getValueType(1);
17727   SDLoc DL(N);
17728   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
17729 
17730   // fold (strict_fadd A, (fneg B)) -> (strict_fsub A, B)
17731   if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::STRICT_FSUB, VT))
17732     if (SDValue NegN1 = TLI.getCheaperNegatedExpression(
17733             N1, DAG, LegalOperations, ForCodeSize)) {
17734       return DAG.getNode(ISD::STRICT_FSUB, DL, DAG.getVTList(VT, ChainVT),
17735                          {Chain, N0, NegN1});
17736     }
17737 
17738   // fold (strict_fadd (fneg A), B) -> (strict_fsub B, A)
17739   if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::STRICT_FSUB, VT))
17740     if (SDValue NegN0 = TLI.getCheaperNegatedExpression(
17741             N0, DAG, LegalOperations, ForCodeSize)) {
17742       return DAG.getNode(ISD::STRICT_FSUB, DL, DAG.getVTList(VT, ChainVT),
17743                          {Chain, N1, NegN0});
17744     }
17745   return SDValue();
17746 }
17747 
visitFSUB(SDNode * N)17748 SDValue DAGCombiner::visitFSUB(SDNode *N) {
17749   SDValue N0 = N->getOperand(0);
17750   SDValue N1 = N->getOperand(1);
17751   ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0, true);
17752   ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1, true);
17753   EVT VT = N->getValueType(0);
17754   SDLoc DL(N);
17755   const TargetOptions &Options = DAG.getTarget().Options;
17756   const SDNodeFlags Flags = N->getFlags();
17757   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
17758 
17759   if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
17760     return R;
17761 
17762   // fold (fsub c1, c2) -> c1-c2
17763   if (SDValue C = DAG.FoldConstantArithmetic(ISD::FSUB, DL, VT, {N0, N1}))
17764     return C;
17765 
17766   // fold vector ops
17767   if (VT.isVector())
17768     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
17769       return FoldedVOp;
17770 
17771   if (SDValue NewSel = foldBinOpIntoSelect(N))
17772     return NewSel;
17773 
17774   // (fsub A, 0) -> A
17775   if (N1CFP && N1CFP->isZero()) {
17776     if (!N1CFP->isNegative() || Options.NoSignedZerosFPMath ||
17777         Flags.hasNoSignedZeros()) {
17778       return N0;
17779     }
17780   }
17781 
17782   if (N0 == N1) {
17783     // (fsub x, x) -> 0.0
17784     if (Options.NoNaNsFPMath || Flags.hasNoNaNs())
17785       return DAG.getConstantFP(0.0f, DL, VT);
17786   }
17787 
17788   // (fsub -0.0, N1) -> -N1
17789   if (N0CFP && N0CFP->isZero()) {
17790     if (N0CFP->isNegative() ||
17791         (Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros())) {
17792       // We cannot replace an FSUB(+-0.0,X) with FNEG(X) when denormals are
17793       // flushed to zero, unless all users treat denorms as zero (DAZ).
17794       // FIXME: This transform will change the sign of a NaN and the behavior
17795       // of a signaling NaN. It is only valid when a NoNaN flag is present.
17796       DenormalMode DenormMode = DAG.getDenormalMode(VT);
17797       if (DenormMode == DenormalMode::getIEEE()) {
17798         if (SDValue NegN1 =
17799                 TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize))
17800           return NegN1;
17801         if (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))
17802           return DAG.getNode(ISD::FNEG, DL, VT, N1);
17803       }
17804     }
17805   }
17806 
17807   if (((Options.UnsafeFPMath && Options.NoSignedZerosFPMath) ||
17808        (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros())) &&
17809       N1.getOpcode() == ISD::FADD) {
17810     // X - (X + Y) -> -Y
17811     if (N0 == N1->getOperand(0))
17812       return DAG.getNode(ISD::FNEG, DL, VT, N1->getOperand(1));
17813     // X - (Y + X) -> -Y
17814     if (N0 == N1->getOperand(1))
17815       return DAG.getNode(ISD::FNEG, DL, VT, N1->getOperand(0));
17816   }
17817 
17818   // fold (fsub A, (fneg B)) -> (fadd A, B)
17819   if (SDValue NegN1 =
17820           TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize))
17821     return DAG.getNode(ISD::FADD, DL, VT, N0, NegN1);
17822 
17823   // FSUB -> FMA combines:
17824   if (SDValue Fused = visitFSUBForFMACombine<EmptyMatchContext>(N)) {
17825     AddToWorklist(Fused.getNode());
17826     return Fused;
17827   }
17828 
17829   return SDValue();
17830 }
17831 
17832 // Transform IEEE Floats:
17833 //      (fmul C, (uitofp Pow2))
17834 //          -> (bitcast_to_FP (add (bitcast_to_INT C), Log2(Pow2) << mantissa))
17835 //      (fdiv C, (uitofp Pow2))
17836 //          -> (bitcast_to_FP (sub (bitcast_to_INT C), Log2(Pow2) << mantissa))
17837 //
17838 // The rationale is fmul/fdiv by a power of 2 is just change the exponent, so
17839 // there is no need for more than an add/sub.
17840 //
17841 // This is valid under the following circumstances:
17842 // 1) We are dealing with IEEE floats
17843 // 2) C is normal
17844 // 3) The fmul/fdiv add/sub will not go outside of min/max exponent bounds.
17845 // TODO: Much of this could also be used for generating `ldexp` on targets the
17846 // prefer it.
combineFMulOrFDivWithIntPow2(SDNode * N)17847 SDValue DAGCombiner::combineFMulOrFDivWithIntPow2(SDNode *N) {
17848   EVT VT = N->getValueType(0);
17849   if (!APFloat::isIEEELikeFP(VT.getFltSemantics()))
17850     return SDValue();
17851 
17852   SDValue ConstOp, Pow2Op;
17853 
17854   std::optional<int> Mantissa;
17855   auto GetConstAndPow2Ops = [&](unsigned ConstOpIdx) {
17856     if (ConstOpIdx == 1 && N->getOpcode() == ISD::FDIV)
17857       return false;
17858 
17859     ConstOp = peekThroughBitcasts(N->getOperand(ConstOpIdx));
17860     Pow2Op = N->getOperand(1 - ConstOpIdx);
17861     if (Pow2Op.getOpcode() != ISD::UINT_TO_FP &&
17862         (Pow2Op.getOpcode() != ISD::SINT_TO_FP ||
17863          !DAG.computeKnownBits(Pow2Op).isNonNegative()))
17864       return false;
17865 
17866     Pow2Op = Pow2Op.getOperand(0);
17867 
17868     // `Log2(Pow2Op) < Pow2Op.getScalarSizeInBits()`.
17869     // TODO: We could use knownbits to make this bound more precise.
17870     int MaxExpChange = Pow2Op.getValueType().getScalarSizeInBits();
17871 
17872     auto IsFPConstValid = [N, MaxExpChange, &Mantissa](ConstantFPSDNode *CFP) {
17873       if (CFP == nullptr)
17874         return false;
17875 
17876       const APFloat &APF = CFP->getValueAPF();
17877 
17878       // Make sure we have normal constant.
17879       if (!APF.isNormal())
17880         return false;
17881 
17882       // Make sure the floats exponent is within the bounds that this transform
17883       // produces bitwise equals value.
17884       int CurExp = ilogb(APF);
17885       // FMul by pow2 will only increase exponent.
17886       int MinExp =
17887           N->getOpcode() == ISD::FMUL ? CurExp : (CurExp - MaxExpChange);
17888       // FDiv by pow2 will only decrease exponent.
17889       int MaxExp =
17890           N->getOpcode() == ISD::FDIV ? CurExp : (CurExp + MaxExpChange);
17891       if (MinExp <= APFloat::semanticsMinExponent(APF.getSemantics()) ||
17892           MaxExp >= APFloat::semanticsMaxExponent(APF.getSemantics()))
17893         return false;
17894 
17895       // Finally make sure we actually know the mantissa for the float type.
17896       int ThisMantissa = APFloat::semanticsPrecision(APF.getSemantics()) - 1;
17897       if (!Mantissa)
17898         Mantissa = ThisMantissa;
17899 
17900       return *Mantissa == ThisMantissa && ThisMantissa > 0;
17901     };
17902 
17903     // TODO: We may be able to include undefs.
17904     return ISD::matchUnaryFpPredicate(ConstOp, IsFPConstValid);
17905   };
17906 
17907   if (!GetConstAndPow2Ops(0) && !GetConstAndPow2Ops(1))
17908     return SDValue();
17909 
17910   if (!TLI.optimizeFMulOrFDivAsShiftAddBitcast(N, ConstOp, Pow2Op))
17911     return SDValue();
17912 
17913   // Get log2 after all other checks have taken place. This is because
17914   // BuildLogBase2 may create a new node.
17915   SDLoc DL(N);
17916   // Get Log2 type with same bitwidth as the float type (VT).
17917   EVT NewIntVT = EVT::getIntegerVT(*DAG.getContext(), VT.getScalarSizeInBits());
17918   if (VT.isVector())
17919     NewIntVT = EVT::getVectorVT(*DAG.getContext(), NewIntVT,
17920                                 VT.getVectorElementCount());
17921 
17922   SDValue Log2 = BuildLogBase2(Pow2Op, DL, DAG.isKnownNeverZero(Pow2Op),
17923                                /*InexpensiveOnly*/ true, NewIntVT);
17924   if (!Log2)
17925     return SDValue();
17926 
17927   // Perform actual transform.
17928   SDValue MantissaShiftCnt =
17929       DAG.getShiftAmountConstant(*Mantissa, NewIntVT, DL);
17930   // TODO: Sometimes Log2 is of form `(X + C)`. `(X + C) << C1` should fold to
17931   // `(X << C1) + (C << C1)`, but that isn't always the case because of the
17932   // cast. We could implement that by handle here to handle the casts.
17933   SDValue Shift = DAG.getNode(ISD::SHL, DL, NewIntVT, Log2, MantissaShiftCnt);
17934   SDValue ResAsInt =
17935       DAG.getNode(N->getOpcode() == ISD::FMUL ? ISD::ADD : ISD::SUB, DL,
17936                   NewIntVT, DAG.getBitcast(NewIntVT, ConstOp), Shift);
17937   SDValue ResAsFP = DAG.getBitcast(VT, ResAsInt);
17938   return ResAsFP;
17939 }
17940 
visitFMUL(SDNode * N)17941 SDValue DAGCombiner::visitFMUL(SDNode *N) {
17942   SDValue N0 = N->getOperand(0);
17943   SDValue N1 = N->getOperand(1);
17944   ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1, true);
17945   EVT VT = N->getValueType(0);
17946   SDLoc DL(N);
17947   const TargetOptions &Options = DAG.getTarget().Options;
17948   const SDNodeFlags Flags = N->getFlags();
17949   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
17950 
17951   if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
17952     return R;
17953 
17954   // fold (fmul c1, c2) -> c1*c2
17955   if (SDValue C = DAG.FoldConstantArithmetic(ISD::FMUL, DL, VT, {N0, N1}))
17956     return C;
17957 
17958   // canonicalize constant to RHS
17959   if (DAG.isConstantFPBuildVectorOrConstantFP(N0) &&
17960      !DAG.isConstantFPBuildVectorOrConstantFP(N1))
17961     return DAG.getNode(ISD::FMUL, DL, VT, N1, N0);
17962 
17963   // fold vector ops
17964   if (VT.isVector())
17965     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
17966       return FoldedVOp;
17967 
17968   if (SDValue NewSel = foldBinOpIntoSelect(N))
17969     return NewSel;
17970 
17971   if (Options.UnsafeFPMath || Flags.hasAllowReassociation()) {
17972     // fmul (fmul X, C1), C2 -> fmul X, C1 * C2
17973     if (DAG.isConstantFPBuildVectorOrConstantFP(N1) &&
17974         N0.getOpcode() == ISD::FMUL) {
17975       SDValue N00 = N0.getOperand(0);
17976       SDValue N01 = N0.getOperand(1);
17977       // Avoid an infinite loop by making sure that N00 is not a constant
17978       // (the inner multiply has not been constant folded yet).
17979       if (DAG.isConstantFPBuildVectorOrConstantFP(N01) &&
17980           !DAG.isConstantFPBuildVectorOrConstantFP(N00)) {
17981         SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, N01, N1);
17982         return DAG.getNode(ISD::FMUL, DL, VT, N00, MulConsts);
17983       }
17984     }
17985 
17986     // Match a special-case: we convert X * 2.0 into fadd.
17987     // fmul (fadd X, X), C -> fmul X, 2.0 * C
17988     if (N0.getOpcode() == ISD::FADD && N0.hasOneUse() &&
17989         N0.getOperand(0) == N0.getOperand(1)) {
17990       const SDValue Two = DAG.getConstantFP(2.0, DL, VT);
17991       SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, Two, N1);
17992       return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), MulConsts);
17993     }
17994 
17995     // Fold fmul(vecreduce(x), vecreduce(y)) -> vecreduce(fmul(x, y))
17996     if (SDValue SD = reassociateReduction(ISD::VECREDUCE_FMUL, ISD::FMUL, DL,
17997                                           VT, N0, N1, Flags))
17998       return SD;
17999   }
18000 
18001   // fold (fmul X, 2.0) -> (fadd X, X)
18002   if (N1CFP && N1CFP->isExactlyValue(+2.0))
18003     return DAG.getNode(ISD::FADD, DL, VT, N0, N0);
18004 
18005   // fold (fmul X, -1.0) -> (fsub -0.0, X)
18006   if (N1CFP && N1CFP->isExactlyValue(-1.0)) {
18007     if (!LegalOperations || TLI.isOperationLegal(ISD::FSUB, VT)) {
18008       return DAG.getNode(ISD::FSUB, DL, VT,
18009                          DAG.getConstantFP(-0.0, DL, VT), N0, Flags);
18010     }
18011   }
18012 
18013   // -N0 * -N1 --> N0 * N1
18014   TargetLowering::NegatibleCost CostN0 =
18015       TargetLowering::NegatibleCost::Expensive;
18016   TargetLowering::NegatibleCost CostN1 =
18017       TargetLowering::NegatibleCost::Expensive;
18018   SDValue NegN0 =
18019       TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize, CostN0);
18020   if (NegN0) {
18021     HandleSDNode NegN0Handle(NegN0);
18022     SDValue NegN1 =
18023         TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize, CostN1);
18024     if (NegN1 && (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
18025                   CostN1 == TargetLowering::NegatibleCost::Cheaper))
18026       return DAG.getNode(ISD::FMUL, DL, VT, NegN0, NegN1);
18027   }
18028 
18029   // fold (fmul X, (select (fcmp X > 0.0), -1.0, 1.0)) -> (fneg (fabs X))
18030   // fold (fmul X, (select (fcmp X > 0.0), 1.0, -1.0)) -> (fabs X)
18031   if (Flags.hasNoNaNs() && Flags.hasNoSignedZeros() &&
18032       (N0.getOpcode() == ISD::SELECT || N1.getOpcode() == ISD::SELECT) &&
18033       TLI.isOperationLegal(ISD::FABS, VT)) {
18034     SDValue Select = N0, X = N1;
18035     if (Select.getOpcode() != ISD::SELECT)
18036       std::swap(Select, X);
18037 
18038     SDValue Cond = Select.getOperand(0);
18039     auto TrueOpnd  = dyn_cast<ConstantFPSDNode>(Select.getOperand(1));
18040     auto FalseOpnd = dyn_cast<ConstantFPSDNode>(Select.getOperand(2));
18041 
18042     if (TrueOpnd && FalseOpnd &&
18043         Cond.getOpcode() == ISD::SETCC && Cond.getOperand(0) == X &&
18044         isa<ConstantFPSDNode>(Cond.getOperand(1)) &&
18045         cast<ConstantFPSDNode>(Cond.getOperand(1))->isExactlyValue(0.0)) {
18046       ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
18047       switch (CC) {
18048       default: break;
18049       case ISD::SETOLT:
18050       case ISD::SETULT:
18051       case ISD::SETOLE:
18052       case ISD::SETULE:
18053       case ISD::SETLT:
18054       case ISD::SETLE:
18055         std::swap(TrueOpnd, FalseOpnd);
18056         [[fallthrough]];
18057       case ISD::SETOGT:
18058       case ISD::SETUGT:
18059       case ISD::SETOGE:
18060       case ISD::SETUGE:
18061       case ISD::SETGT:
18062       case ISD::SETGE:
18063         if (TrueOpnd->isExactlyValue(-1.0) && FalseOpnd->isExactlyValue(1.0) &&
18064             TLI.isOperationLegal(ISD::FNEG, VT))
18065           return DAG.getNode(ISD::FNEG, DL, VT,
18066                    DAG.getNode(ISD::FABS, DL, VT, X));
18067         if (TrueOpnd->isExactlyValue(1.0) && FalseOpnd->isExactlyValue(-1.0))
18068           return DAG.getNode(ISD::FABS, DL, VT, X);
18069 
18070         break;
18071       }
18072     }
18073   }
18074 
18075   // FMUL -> FMA combines:
18076   if (SDValue Fused = visitFMULForFMADistributiveCombine(N)) {
18077     AddToWorklist(Fused.getNode());
18078     return Fused;
18079   }
18080 
18081   // Don't do `combineFMulOrFDivWithIntPow2` until after FMUL -> FMA has been
18082   // able to run.
18083   if (SDValue R = combineFMulOrFDivWithIntPow2(N))
18084     return R;
18085 
18086   return SDValue();
18087 }
18088 
visitFMA(SDNode * N)18089 template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
18090   SDValue N0 = N->getOperand(0);
18091   SDValue N1 = N->getOperand(1);
18092   SDValue N2 = N->getOperand(2);
18093   ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(N0);
18094   ConstantFPSDNode *N1CFP = dyn_cast<ConstantFPSDNode>(N1);
18095   ConstantFPSDNode *N2CFP = dyn_cast<ConstantFPSDNode>(N2);
18096   EVT VT = N->getValueType(0);
18097   SDLoc DL(N);
18098   const TargetOptions &Options = DAG.getTarget().Options;
18099   // FMA nodes have flags that propagate to the created nodes.
18100   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
18101   MatchContextClass matcher(DAG, TLI, N);
18102 
18103   // Constant fold FMA.
18104   if (SDValue C =
18105           DAG.FoldConstantArithmetic(N->getOpcode(), DL, VT, {N0, N1, N2}))
18106     return C;
18107 
18108   // (-N0 * -N1) + N2 --> (N0 * N1) + N2
18109   TargetLowering::NegatibleCost CostN0 =
18110       TargetLowering::NegatibleCost::Expensive;
18111   TargetLowering::NegatibleCost CostN1 =
18112       TargetLowering::NegatibleCost::Expensive;
18113   SDValue NegN0 =
18114       TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize, CostN0);
18115   if (NegN0) {
18116     HandleSDNode NegN0Handle(NegN0);
18117     SDValue NegN1 =
18118         TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize, CostN1);
18119     if (NegN1 && (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
18120                   CostN1 == TargetLowering::NegatibleCost::Cheaper))
18121       return matcher.getNode(ISD::FMA, DL, VT, NegN0, NegN1, N2);
18122   }
18123 
18124   // FIXME: use fast math flags instead of Options.UnsafeFPMath
18125   // TODO: Finally migrate away from global TargetOptions.
18126   if ((Options.NoNaNsFPMath && Options.NoInfsFPMath) ||
18127       (N->getFlags().hasNoNaNs() && N->getFlags().hasNoInfs())) {
18128     if (Options.NoSignedZerosFPMath || N->getFlags().hasNoSignedZeros() ||
18129         (N2CFP && !N2CFP->isExactlyValue(-0.0))) {
18130       if (N0CFP && N0CFP->isZero())
18131         return N2;
18132       if (N1CFP && N1CFP->isZero())
18133         return N2;
18134     }
18135   }
18136 
18137   // FIXME: Support splat of constant.
18138   if (N0CFP && N0CFP->isExactlyValue(1.0))
18139     return matcher.getNode(ISD::FADD, DL, VT, N1, N2);
18140   if (N1CFP && N1CFP->isExactlyValue(1.0))
18141     return matcher.getNode(ISD::FADD, DL, VT, N0, N2);
18142 
18143   // Canonicalize (fma c, x, y) -> (fma x, c, y)
18144   if (DAG.isConstantFPBuildVectorOrConstantFP(N0) &&
18145      !DAG.isConstantFPBuildVectorOrConstantFP(N1))
18146     return matcher.getNode(ISD::FMA, DL, VT, N1, N0, N2);
18147 
18148   bool CanReassociate =
18149       Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();
18150   if (CanReassociate) {
18151     // (fma x, c1, (fmul x, c2)) -> (fmul x, c1+c2)
18152     if (matcher.match(N2, ISD::FMUL) && N0 == N2.getOperand(0) &&
18153         DAG.isConstantFPBuildVectorOrConstantFP(N1) &&
18154         DAG.isConstantFPBuildVectorOrConstantFP(N2.getOperand(1))) {
18155       return matcher.getNode(
18156           ISD::FMUL, DL, VT, N0,
18157           matcher.getNode(ISD::FADD, DL, VT, N1, N2.getOperand(1)));
18158     }
18159 
18160     // (fma (fmul x, c1), c2, y) -> (fma x, c1*c2, y)
18161     if (matcher.match(N0, ISD::FMUL) &&
18162         DAG.isConstantFPBuildVectorOrConstantFP(N1) &&
18163         DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(1))) {
18164       return matcher.getNode(
18165           ISD::FMA, DL, VT, N0.getOperand(0),
18166           matcher.getNode(ISD::FMUL, DL, VT, N1, N0.getOperand(1)), N2);
18167     }
18168   }
18169 
18170   // (fma x, -1, y) -> (fadd (fneg x), y)
18171   // FIXME: Support splat of constant.
18172   if (N1CFP) {
18173     if (N1CFP->isExactlyValue(1.0))
18174       return matcher.getNode(ISD::FADD, DL, VT, N0, N2);
18175 
18176     if (N1CFP->isExactlyValue(-1.0) &&
18177         (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))) {
18178       SDValue RHSNeg = matcher.getNode(ISD::FNEG, DL, VT, N0);
18179       AddToWorklist(RHSNeg.getNode());
18180       return matcher.getNode(ISD::FADD, DL, VT, N2, RHSNeg);
18181     }
18182 
18183     // fma (fneg x), K, y -> fma x -K, y
18184     if (matcher.match(N0, ISD::FNEG) &&
18185         (TLI.isOperationLegal(ISD::ConstantFP, VT) ||
18186          (N1.hasOneUse() &&
18187           !TLI.isFPImmLegal(N1CFP->getValueAPF(), VT, ForCodeSize)))) {
18188       return matcher.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
18189                              matcher.getNode(ISD::FNEG, DL, VT, N1), N2);
18190     }
18191   }
18192 
18193   // FIXME: Support splat of constant.
18194   if (CanReassociate) {
18195     // (fma x, c, x) -> (fmul x, (c+1))
18196     if (N1CFP && N0 == N2) {
18197       return matcher.getNode(ISD::FMUL, DL, VT, N0,
18198                              matcher.getNode(ISD::FADD, DL, VT, N1,
18199                                              DAG.getConstantFP(1.0, DL, VT)));
18200     }
18201 
18202     // (fma x, c, (fneg x)) -> (fmul x, (c-1))
18203     if (N1CFP && matcher.match(N2, ISD::FNEG) && N2.getOperand(0) == N0) {
18204       return matcher.getNode(ISD::FMUL, DL, VT, N0,
18205                              matcher.getNode(ISD::FADD, DL, VT, N1,
18206                                              DAG.getConstantFP(-1.0, DL, VT)));
18207     }
18208   }
18209 
18210   // fold ((fma (fneg X), Y, (fneg Z)) -> fneg (fma X, Y, Z))
18211   // fold ((fma X, (fneg Y), (fneg Z)) -> fneg (fma X, Y, Z))
18212   if (!TLI.isFNegFree(VT))
18213     if (SDValue Neg = TLI.getCheaperNegatedExpression(
18214             SDValue(N, 0), DAG, LegalOperations, ForCodeSize))
18215       return matcher.getNode(ISD::FNEG, DL, VT, Neg);
18216   return SDValue();
18217 }
18218 
visitFMAD(SDNode * N)18219 SDValue DAGCombiner::visitFMAD(SDNode *N) {
18220   SDValue N0 = N->getOperand(0);
18221   SDValue N1 = N->getOperand(1);
18222   SDValue N2 = N->getOperand(2);
18223   EVT VT = N->getValueType(0);
18224   SDLoc DL(N);
18225 
18226   // Constant fold FMAD.
18227   if (SDValue C = DAG.FoldConstantArithmetic(ISD::FMAD, DL, VT, {N0, N1, N2}))
18228     return C;
18229 
18230   return SDValue();
18231 }
18232 
18233 // Combine multiple FDIVs with the same divisor into multiple FMULs by the
18234 // reciprocal.
18235 // E.g., (a / D; b / D;) -> (recip = 1.0 / D; a * recip; b * recip)
18236 // Notice that this is not always beneficial. One reason is different targets
18237 // may have different costs for FDIV and FMUL, so sometimes the cost of two
18238 // FDIVs may be lower than the cost of one FDIV and two FMULs. Another reason
18239 // is the critical path is increased from "one FDIV" to "one FDIV + one FMUL".
//
// Preconditions:
//   - It runs only before legalization (LegalDAG is false).
//   - It runs only when global UnsafeFPMath is set or N has the 'arcp' flag.
//   - The target opts in through TLI.combineRepeatedFPDivisors(). That hook
//     gives the minimum number of divisor uses required; 0 disables the
//     transform.
// Each eligible FDIV user of N's divisor is rewritten with CombineTo. The
// function then returns SDValue(N, 0) to tell the caller the rewrite was
// already done. A null SDValue means nothing changed.
combineRepeatedFPDivisors(SDNode * N)18240 SDValue DAGCombiner::combineRepeatedFPDivisors(SDNode *N) {
18241   // TODO: Limit this transform based on optsize/minsize - it always creates at
18242   //       least 1 extra instruction. But the perf win may be substantial enough
18243   //       that only minsize should restrict this.
18244   bool UnsafeMath = DAG.getTarget().Options.UnsafeFPMath;
18245   const SDNodeFlags Flags = N->getFlags();
18246   if (LegalDAG || (!UnsafeMath && !Flags.hasAllowReciprocal()))
18247     return SDValue();
18248 
18249   // Skip if current node is a reciprocal/fneg-reciprocal.
        // (N is itself 1.0/D or -1.0/D. AllowUndefs=true means undef lanes of a
        // splat dividend do not prevent the match.)
18250   SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
18251   ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0, /* AllowUndefs */ true);
18252   if (N0CFP && (N0CFP->isExactlyValue(1.0) || N0CFP->isExactlyValue(-1.0)))
18253     return SDValue();
18254 
18255   // Exit early if the target does not want this transform or if there can't
18256   // possibly be enough uses of the divisor to make the transform worthwhile.
18257   unsigned MinUses = TLI.combineRepeatedFPDivisors();
18258 
18259   // For splat vectors, scale the number of uses by the splat factor. If we can
18260   // convert the division into a scalar op, that will likely be much faster.
18261   unsigned NumElts = 1;
18262   EVT VT = N->getValueType(0);
18263   if (VT.isVector() && DAG.isSplatValue(N1))
18264     NumElts = VT.getVectorMinNumElements();
18265 
        // Coarse filter: use_size() counts every use of the divisor node, not just
        // FDIVs. The exact count is checked again after the scan below.
18266   if (!MinUses || (N1->use_size() * NumElts) < MinUses)
18267     return SDValue();
18268 
18269   // Find all FDIV users of the same divisor.
18270   // Use a set because duplicates may be present in the user list.
18271   SetVector<SDNode *> Users;
18272   for (auto *U : N1->users()) {
18273     if (U->getOpcode() == ISD::FDIV && U->getOperand(1) == N1) {
18274       // Skip X/sqrt(X) that has not been simplified to sqrt(X) yet.
18275       if (U->getOperand(1).getOpcode() == ISD::FSQRT &&
18276           U->getOperand(0) == U->getOperand(1).getOperand(0) &&
18277           U->getFlags().hasAllowReassociation() &&
18278           U->getFlags().hasNoSignedZeros())
18279         continue;
18280 
18281       // This division is eligible for optimization only if global unsafe math
18282       // is enabled or if this division allows reciprocal formation.
18283       if (UnsafeMath || U->getFlags().hasAllowReciprocal())
18284         Users.insert(U);
18285     }
18286   }
18287 
18288   // Now that we have the actual number of divisor uses, make sure it meets
18289   // the minimum threshold specified by the target.
18290   if ((Users.size() * NumElts) < MinUses)
18291     return SDValue();
18292 
        // The shared reciprocal 1.0 / N1 is built with N's flags.
18293   SDLoc DL(N);
18294   SDValue FPOne = DAG.getConstantFP(1.0, DL, VT);
18295   SDValue Reciprocal = DAG.getNode(ISD::FDIV, DL, VT, FPOne, N1, Flags);
18296 
18297   // Dividend / Divisor -> Dividend * Reciprocal
        // NOTE(review): each replacement FMUL is given N's flags, not the flags
        // of the user U it replaces. If another user has weaker fast-math flags
        // than N, its replacement gains N's permissions. Confirm this is intended.
18298   for (auto *U : Users) {
18299     SDValue Dividend = U->getOperand(0);
18300     if (Dividend != FPOne) {
18301       SDValue NewNode = DAG.getNode(ISD::FMUL, SDLoc(U), VT, Dividend,
18302                                     Reciprocal, Flags);
18303       CombineTo(U, NewNode);
18304     } else if (U != Reciprocal.getNode()) {
18305       // In the absence of fast-math-flags, this user node is always the
18306       // same node as Reciprocal, but with FMF they may be different nodes.
18307       CombineTo(U, Reciprocal);
18308     }
18309   }
        // NOTE(review): N is skipped in the scan above when it is an X/sqrt(X)
        // pattern with reassoc+nsz. In that case N itself was not replaced, yet
        // SDValue(N, 0) still tells the caller it was handled. Confirm this is
        // intended.
18310   return SDValue(N, 0);  // N was replaced.
18311 }
18312 
// Combine an ISD::FDIV node (N0 / N1). The folds are tried in this order:
//   - generic FP binop simplification and constant folding;
//   - vector-op simplification and folding into a select operand;
//   - sharing one reciprocal among FDIVs with the same divisor;
//   - division by a constant rewritten as multiplication by its reciprocal;
//   - with 'arcp': reciprocal-sqrt and reciprocal estimates;
//   - with 'reassoc' and 'nsz' (or NoSignedZerosFPMath): X/sqrt(X) -> sqrt(X);
//   - dropping a negation from both operands when that is cheaper;
//   - combineFMulOrFDivWithIntPow2.
// Returns the replacement value, or a null SDValue if nothing applied.
visitFDIV(SDNode * N)18313 SDValue DAGCombiner::visitFDIV(SDNode *N) {
18314   SDValue N0 = N->getOperand(0);
18315   SDValue N1 = N->getOperand(1);
18316   EVT VT = N->getValueType(0);
18317   SDLoc DL(N);
18318   const TargetOptions &Options = DAG.getTarget().Options;
18319   SDNodeFlags Flags = N->getFlags();
        // Nodes built in this scope without explicit flags presumably pick up
        // N's flags (see SelectionDAG::FlagInserter).
18320   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
18321 
18322   if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
18323     return R;
18324 
18325   // fold (fdiv c1, c2) -> c1/c2
18326   if (SDValue C = DAG.FoldConstantArithmetic(ISD::FDIV, DL, VT, {N0, N1}))
18327     return C;
18328 
18329   // fold vector ops
18330   if (VT.isVector())
18331     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
18332       return FoldedVOp;
18333 
18334   if (SDValue NewSel = foldBinOpIntoSelect(N))
18335     return NewSel;
18336 
18337   if (SDValue V = combineRepeatedFPDivisors(N))
18338     return V;
18339 
18340   // fold (fdiv X, c2) -> (fmul X, 1/c2) if there is no loss in precision, or
18341   // the loss is acceptable with AllowReciprocal.
18342   if (auto *N1CFP = isConstOrConstSplatFP(N1, true)) {
18343     // Compute the reciprocal 1.0 / c2.
18344     const APFloat &N1APF = N1CFP->getValueAPF();
18345     APFloat Recip = APFloat::getOne(N1APF.getSemantics());
18346     APFloat::opStatus st = Recip.divide(N1APF, APFloat::rmNearestTiesToEven);
18347     // Only do the transform if the reciprocal is a legal fp immediate that
18348     // isn't too nasty (eg NaN, denormal, ...).
          // How the divide status is handled:
          //   - opOK: 1/c2 was computed without rounding, but a denormal
          //     reciprocal is still rejected.
          //   - opInexact: a rounded reciprocal is accepted only with 'arcp'.
          //   - Any other status (e.g. opDivByZero for c2 == 0) blocks the fold.
          // After operation legalization, the constant must also be one the
          // target can materialize.
18349     if (((st == APFloat::opOK && !Recip.isDenormal()) ||
18350          (st == APFloat::opInexact && Flags.hasAllowReciprocal())) &&
18351         (!LegalOperations ||
18352          // FIXME: custom lowering of ConstantFP might fail (see e.g. ARM
18353          // backend)... we should handle this gracefully after Legalize.
18354          // TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT) ||
18355          TLI.isOperationLegal(ISD::ConstantFP, VT) ||
18356          TLI.isFPImmLegal(Recip, VT, ForCodeSize)))
18357       return DAG.getNode(ISD::FMUL, DL, VT, N0,
18358                          DAG.getConstantFP(Recip, DL, VT));
18359   }
18360 
        // Estimate-based folds: they trade accuracy for speed and are only tried
        // when N has the 'arcp' (allow reciprocal) flag.
18361   if (Flags.hasAllowReciprocal()) {
18362     // If this FDIV is part of a reciprocal square root, it may be folded
18363     // into a target-specific square root estimate instruction.
          // X / sqrt(Z) -> X * rsqrt(Z)
18364     if (N1.getOpcode() == ISD::FSQRT) {
18365       if (SDValue RV = buildRsqrtEstimate(N1.getOperand(0), Flags))
18366         return DAG.getNode(ISD::FMUL, DL, VT, N0, RV);
          // X / fp_extend(sqrt(Z)) -> X * fp_extend(rsqrt(Z))
18367     } else if (N1.getOpcode() == ISD::FP_EXTEND &&
18368                N1.getOperand(0).getOpcode() == ISD::FSQRT) {
18369       if (SDValue RV =
18370               buildRsqrtEstimate(N1.getOperand(0).getOperand(0), Flags)) {
18371         RV = DAG.getNode(ISD::FP_EXTEND, SDLoc(N1), VT, RV);
18372         AddToWorklist(RV.getNode());
18373         return DAG.getNode(ISD::FMUL, DL, VT, N0, RV);
18374       }
          // X / fp_round(sqrt(Z)) -> X * fp_round(rsqrt(Z)), reusing the original
          // fp_round's flag operand.
18375     } else if (N1.getOpcode() == ISD::FP_ROUND &&
18376                N1.getOperand(0).getOpcode() == ISD::FSQRT) {
18377       if (SDValue RV =
18378               buildRsqrtEstimate(N1.getOperand(0).getOperand(0), Flags)) {
18379         RV = DAG.getNode(ISD::FP_ROUND, SDLoc(N1), VT, RV, N1.getOperand(1));
18380         AddToWorklist(RV.getNode());
18381         return DAG.getNode(ISD::FMUL, DL, VT, N0, RV);
18382       }
18383     } else if (N1.getOpcode() == ISD::FMUL) {
18384       // Look through an FMUL. Even though this won't remove the FDIV directly,
18385       // it's still worthwhile to get rid of the FSQRT if possible.
18386       SDValue Sqrt, Y;
18387       if (N1.getOperand(0).getOpcode() == ISD::FSQRT) {
18388         Sqrt = N1.getOperand(0);
18389         Y = N1.getOperand(1);
18390       } else if (N1.getOperand(1).getOpcode() == ISD::FSQRT) {
18391         Sqrt = N1.getOperand(1);
18392         Y = N1.getOperand(0);
18393       }
18394       if (Sqrt.getNode()) {
18395         // If the other multiply operand is known positive, pull it into the
18396         // sqrt. That will eliminate the division if we convert to an estimate.
18397         if (Flags.hasAllowReassociation() && N1.hasOneUse() &&
18398             N1->getFlags().hasAllowReassociation() && Sqrt.hasOneUse()) {
18399           SDValue A;
18400           if (Y.getOpcode() == ISD::FABS && Y.hasOneUse())
18401             A = Y.getOperand(0);
18402           else if (Y == Sqrt.getOperand(0))
18403             A = Y;
18404           if (A) {
18405             // X / (fabs(A) * sqrt(Z)) --> X / sqrt(A*A*Z) --> X * rsqrt(A*A*Z)
18406             // X / (A * sqrt(A))       --> X / sqrt(A*A*A) --> X * rsqrt(A*A*A)
18407             SDValue AA = DAG.getNode(ISD::FMUL, DL, VT, A, A);
18408             SDValue AAZ =
18409                 DAG.getNode(ISD::FMUL, DL, VT, AA, Sqrt.getOperand(0));
18410             if (SDValue Rsqrt = buildRsqrtEstimate(AAZ, Flags))
18411               return DAG.getNode(ISD::FMUL, DL, VT, N0, Rsqrt);
18412 
18413             // Estimate creation failed. Clean up speculatively created nodes.
18414             recursivelyDeleteUnusedNodes(AAZ.getNode());
18415           }
18416         }
18417 
18418         // We found a FSQRT, so try to make this fold:
18419         // X / (Y * sqrt(Z)) -> X * (rsqrt(Z) / Y)
18420         if (SDValue Rsqrt = buildRsqrtEstimate(Sqrt.getOperand(0), Flags)) {
18421           SDValue Div = DAG.getNode(ISD::FDIV, SDLoc(N1), VT, Rsqrt, Y);
18422           AddToWorklist(Div.getNode());
18423           return DAG.getNode(ISD::FMUL, DL, VT, N0, Div);
18424         }
18425       }
18426     }
18427 
18428     // Fold into a reciprocal estimate and multiply instead of a real divide.
          // This additionally requires 'ninf' on N or global NoInfsFPMath.
18429     if (Options.NoInfsFPMath || Flags.hasNoInfs())
18430       if (SDValue RV = BuildDivEstimate(N0, N1, Flags))
18431         return RV;
18432   }
18433 
18434   // Fold X/Sqrt(X) -> Sqrt(X)
18435   if ((Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros()) &&
18436       Flags.hasAllowReassociation())
18437     if (N1.getOpcode() == ISD::FSQRT && N0 == N1.getOperand(0))
18438       return N1;
18439 
18440   // (fdiv (fneg X), (fneg Y)) -> (fdiv X, Y)
        // Done only when both operands can be negated and at least one of the
        // negations is reported as cheaper than the original.
18441   TargetLowering::NegatibleCost CostN0 =
18442       TargetLowering::NegatibleCost::Expensive;
18443   TargetLowering::NegatibleCost CostN1 =
18444       TargetLowering::NegatibleCost::Expensive;
18445   SDValue NegN0 =
18446       TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize, CostN0);
18447   if (NegN0) {
          // Presumably the handle keeps NegN0 from being deleted as dead while
          // the negation of N1 is being built.
18448     HandleSDNode NegN0Handle(NegN0);
18449     SDValue NegN1 =
18450         TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize, CostN1);
18451     if (NegN1 && (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
18452                   CostN1 == TargetLowering::NegatibleCost::Cheaper))
18453       return DAG.getNode(ISD::FDIV, DL, VT, NegN0, NegN1);
18454   }
18455 
18456   if (SDValue R = combineFMulOrFDivWithIntPow2(N))
18457     return R;
18458 
18459   return SDValue();
18460 }
18461 
// Combine an ISD::FREM node (N0 rem N1). Tried in order: generic FP binop
// simplification, constant folding, and folding into a select operand. If
// none of those apply and FREM is not legal for VT, a divisor that is an
// integer power of two is expanded into FDIV/FTRUNC and a multiply-subtract.
// Returns the replacement value, or a null SDValue if nothing applied.
visitFREM(SDNode * N)18462 SDValue DAGCombiner::visitFREM(SDNode *N) {
18463   SDValue N0 = N->getOperand(0);
18464   SDValue N1 = N->getOperand(1);
18465   EVT VT = N->getValueType(0);
18466   SDNodeFlags Flags = N->getFlags();
18467   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
18468   SDLoc DL(N);
18469 
18470   if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
18471     return R;
18472 
18473   // fold (frem c1, c2) -> fmod(c1,c2)
18474   if (SDValue C = DAG.FoldConstantArithmetic(ISD::FREM, DL, VT, {N0, N1}))
18475     return C;
18476 
18477   if (SDValue NewSel = foldBinOpIntoSelect(N))
18478     return NewSel;
18479 
        // Lower frem N0, N1 => N0 - trunc(N0 / N1) * N1, provided N1 is an integer
        // power of 2. All of FMUL, FDIV and FTRUNC must be legal or custom for VT.
18480   // Lower frem N0, N1 => x - trunc(N0 / N1) * N1, providing N1 is an integer
18481   // power of 2.
18482   if (!TLI.isOperationLegal(ISD::FREM, VT) &&
18483       TLI.isOperationLegalOrCustom(ISD::FMUL, VT) &&
18484       TLI.isOperationLegalOrCustom(ISD::FDIV, VT) &&
18485       TLI.isOperationLegalOrCustom(ISD::FTRUNC, VT) &&
18486       DAG.isKnownToBeAPowerOfTwoFP(N1)) {
          // frem's result has the sign of N0, like C fmod. When the remainder is
          // zero and N0 is negative, N0 - trunc(N0/N1)*N1 evaluates to +0.0
          // instead of -0.0. So the sign is copied back from N0, unless 'nsz' is
          // set or N0 is known not to be (ordered) negative.
18487     bool NeedsCopySign =
18488         !Flags.hasNoSignedZeros() && !DAG.cannotBeOrderedNegativeFP(N0);
18489     SDValue Div = DAG.getNode(ISD::FDIV, DL, VT, N0, N1);
18490     SDValue Rnd = DAG.getNode(ISD::FTRUNC, DL, VT, Div);
18491     SDValue MLA;
          // Either fma(-Rnd, N1, N0) or N0 - Rnd * N1, depending on whether the
          // target reports FMA as faster than a separate FMUL and FADD.
18492     if (TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT)) {
18493       MLA = DAG.getNode(ISD::FMA, DL, VT, DAG.getNode(ISD::FNEG, DL, VT, Rnd),
18494                         N1, N0);
18495     } else {
18496       SDValue Mul = DAG.getNode(ISD::FMUL, DL, VT, Rnd, N1);
18497       MLA = DAG.getNode(ISD::FSUB, DL, VT, N0, Mul);
18498     }
18499     return NeedsCopySign ? DAG.getNode(ISD::FCOPYSIGN, DL, VT, MLA, N0) : MLA;
18500   }
18501 
18502   return SDValue();
18503 }
18504 
visitFSQRT(SDNode * N)18505 SDValue DAGCombiner::visitFSQRT(SDNode *N) {
18506   SDNodeFlags Flags = N->getFlags();
18507   const TargetOptions &Options = DAG.getTarget().Options;
18508 
18509   // Require 'ninf' flag since sqrt(+Inf) = +Inf, but the estimation goes as:
18510   // sqrt(+Inf) == rsqrt(+Inf) * +Inf = 0 * +Inf = NaN
18511   if (!Flags.hasApproximateFuncs() ||
18512       (!Options.NoInfsFPMath && !Flags.hasNoInfs()))
18513     return SDValue();
18514 
18515   SDValue N0 = N->getOperand(0);
18516   if (TLI.isFsqrtCheap(N0, DAG))
18517     return SDValue();
18518 
18519   // FSQRT nodes have flags that propagate to the created nodes.
18520   // TODO: If this is N0/sqrt(N0), and we reach this node before trying to
18521   //       transform the fdiv, we may produce a sub-optimal estimate sequence
18522   //       because the reciprocal calculation may not have to filter out a
18523   //       0.0 input.
18524   return buildSqrtEstimate(N0, Flags);
18525 }
18526 
18527 /// copysign(x, fp_extend(y)) -> copysign(x, y)
18528 /// copysign(x, fp_round(y)) -> copysign(x, y)
18529 /// Operands to the functions are the type of X and Y respectively.
CanCombineFCOPYSIGN_EXTEND_ROUND(EVT XTy,EVT YTy)18530 static inline bool CanCombineFCOPYSIGN_EXTEND_ROUND(EVT XTy, EVT YTy) {
18531   // Always fold no-op FP casts.
18532   if (XTy == YTy)
18533     return true;
18534 
18535   // Do not optimize out type conversion of f128 type yet.
18536   // For some targets like x86_64, configuration is changed to keep one f128
18537   // value in one SSE register, but instruction selection cannot handle
18538   // FCOPYSIGN on SSE registers yet.
18539   if (YTy == MVT::f128)
18540     return false;
18541 
18542   // Avoid mismatched vector operand types, for better instruction selection.
18543   return !YTy.isVector();
18544 }
18545 
CanCombineFCOPYSIGN_EXTEND_ROUND(SDNode * N)18546 static inline bool CanCombineFCOPYSIGN_EXTEND_ROUND(SDNode *N) {
18547   SDValue N1 = N->getOperand(1);
18548   if (N1.getOpcode() != ISD::FP_EXTEND &&
18549       N1.getOpcode() != ISD::FP_ROUND)
18550     return false;
18551   EVT N1VT = N1->getValueType(0);
18552   EVT N1Op0VT = N1->getOperand(0).getValueType();
18553   return CanCombineFCOPYSIGN_EXTEND_ROUND(N1VT, N1Op0VT);
18554 }
18555 
visitFCOPYSIGN(SDNode * N)18556 SDValue DAGCombiner::visitFCOPYSIGN(SDNode *N) {
18557   SDValue N0 = N->getOperand(0);
18558   SDValue N1 = N->getOperand(1);
18559   EVT VT = N->getValueType(0);
18560   SDLoc DL(N);
18561 
18562   // fold (fcopysign c1, c2) -> fcopysign(c1,c2)
18563   if (SDValue C = DAG.FoldConstantArithmetic(ISD::FCOPYSIGN, DL, VT, {N0, N1}))
18564     return C;
18565 
18566   // copysign(x, fp_extend(y)) -> copysign(x, y)
18567   // copysign(x, fp_round(y)) -> copysign(x, y)
18568   if (CanCombineFCOPYSIGN_EXTEND_ROUND(N))
18569     return DAG.getNode(ISD::FCOPYSIGN, DL, VT, N0, N1.getOperand(0));
18570 
18571   if (SimplifyDemandedBits(SDValue(N, 0)))
18572     return SDValue(N, 0);
18573 
18574   return SDValue();
18575 }
18576 
visitFPOW(SDNode * N)18577 SDValue DAGCombiner::visitFPOW(SDNode *N) {
18578   ConstantFPSDNode *ExponentC = isConstOrConstSplatFP(N->getOperand(1));
18579   if (!ExponentC)
18580     return SDValue();
18581   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
18582 
18583   // Try to convert x ** (1/3) into cube root.
18584   // TODO: Handle the various flavors of long double.
18585   // TODO: Since we're approximating, we don't need an exact 1/3 exponent.
18586   //       Some range near 1/3 should be fine.
18587   EVT VT = N->getValueType(0);
18588   if ((VT == MVT::f32 && ExponentC->getValueAPF().isExactlyValue(1.0f/3.0f)) ||
18589       (VT == MVT::f64 && ExponentC->getValueAPF().isExactlyValue(1.0/3.0))) {
18590     // pow(-0.0, 1/3) = +0.0; cbrt(-0.0) = -0.0.
18591     // pow(-inf, 1/3) = +inf; cbrt(-inf) = -inf.
18592     // pow(-val, 1/3) =  nan; cbrt(-val) = -num.
18593     // For regular numbers, rounding may cause the results to differ.
18594     // Therefore, we require { nsz ninf nnan afn } for this transform.
18595     // TODO: We could select out the special cases if we don't have nsz/ninf.
18596     SDNodeFlags Flags = N->getFlags();
18597     if (!Flags.hasNoSignedZeros() || !Flags.hasNoInfs() || !Flags.hasNoNaNs() ||
18598         !Flags.hasApproximateFuncs())
18599       return SDValue();
18600 
18601     // Do not create a cbrt() libcall if the target does not have it, and do not
18602     // turn a pow that has lowering support into a cbrt() libcall.
18603     if (!DAG.getLibInfo().has(LibFunc_cbrt) ||
18604         (!DAG.getTargetLoweringInfo().isOperationExpand(ISD::FPOW, VT) &&
18605          DAG.getTargetLoweringInfo().isOperationExpand(ISD::FCBRT, VT)))
18606       return SDValue();
18607 
18608     return DAG.getNode(ISD::FCBRT, SDLoc(N), VT, N->getOperand(0));
18609   }
18610 
18611   // Try to convert x ** (1/4) and x ** (3/4) into square roots.
18612   // x ** (1/2) is canonicalized to sqrt, so we do not bother with that case.
18613   // TODO: This could be extended (using a target hook) to handle smaller
18614   // power-of-2 fractional exponents.
18615   bool ExponentIs025 = ExponentC->getValueAPF().isExactlyValue(0.25);
18616   bool ExponentIs075 = ExponentC->getValueAPF().isExactlyValue(0.75);
18617   if (ExponentIs025 || ExponentIs075) {
18618     // pow(-0.0, 0.25) = +0.0; sqrt(sqrt(-0.0)) = -0.0.
18619     // pow(-inf, 0.25) = +inf; sqrt(sqrt(-inf)) =  NaN.
18620     // pow(-0.0, 0.75) = +0.0; sqrt(-0.0) * sqrt(sqrt(-0.0)) = +0.0.
18621     // pow(-inf, 0.75) = +inf; sqrt(-inf) * sqrt(sqrt(-inf)) =  NaN.
18622     // For regular numbers, rounding may cause the results to differ.
18623     // Therefore, we require { nsz ninf afn } for this transform.
18624     // TODO: We could select out the special cases if we don't have nsz/ninf.
18625     SDNodeFlags Flags = N->getFlags();
18626 
18627     // We only need no signed zeros for the 0.25 case.
18628     if ((!Flags.hasNoSignedZeros() && ExponentIs025) || !Flags.hasNoInfs() ||
18629         !Flags.hasApproximateFuncs())
18630       return SDValue();
18631 
18632     // Don't double the number of libcalls. We are trying to inline fast code.
18633     if (!DAG.getTargetLoweringInfo().isOperationLegalOrCustom(ISD::FSQRT, VT))
18634       return SDValue();
18635 
18636     // Assume that libcalls are the smallest code.
18637     // TODO: This restriction should probably be lifted for vectors.
18638     if (ForCodeSize)
18639       return SDValue();
18640 
18641     // pow(X, 0.25) --> sqrt(sqrt(X))
18642     SDLoc DL(N);
18643     SDValue Sqrt = DAG.getNode(ISD::FSQRT, DL, VT, N->getOperand(0));
18644     SDValue SqrtSqrt = DAG.getNode(ISD::FSQRT, DL, VT, Sqrt);
18645     if (ExponentIs025)
18646       return SqrtSqrt;
18647     // pow(X, 0.75) --> sqrt(X) * sqrt(sqrt(X))
18648     return DAG.getNode(ISD::FMUL, DL, VT, Sqrt, SqrtSqrt);
18649   }
18650 
18651   return SDValue();
18652 }
18653 
foldFPToIntToFP(SDNode * N,const SDLoc & DL,SelectionDAG & DAG,const TargetLowering & TLI)18654 static SDValue foldFPToIntToFP(SDNode *N, const SDLoc &DL, SelectionDAG &DAG,
18655                                const TargetLowering &TLI) {
18656   // We only do this if the target has legal ftrunc. Otherwise, we'd likely be
18657   // replacing casts with a libcall. We also must be allowed to ignore -0.0
18658   // because FTRUNC will return -0.0 for (-1.0, -0.0), but using integer
18659   // conversions would return +0.0.
18660   // FIXME: We should be able to use node-level FMF here.
18661   // TODO: If strict math, should we use FABS (+ range check for signed cast)?
18662   EVT VT = N->getValueType(0);
18663   if (!TLI.isOperationLegal(ISD::FTRUNC, VT) ||
18664       !DAG.getTarget().Options.NoSignedZerosFPMath)
18665     return SDValue();
18666 
18667   // fptosi/fptoui round towards zero, so converting from FP to integer and
18668   // back is the same as an 'ftrunc': [us]itofp (fpto[us]i X) --> ftrunc X
18669   SDValue N0 = N->getOperand(0);
18670   if (N->getOpcode() == ISD::SINT_TO_FP && N0.getOpcode() == ISD::FP_TO_SINT &&
18671       N0.getOperand(0).getValueType() == VT)
18672     return DAG.getNode(ISD::FTRUNC, DL, VT, N0.getOperand(0));
18673 
18674   if (N->getOpcode() == ISD::UINT_TO_FP && N0.getOpcode() == ISD::FP_TO_UINT &&
18675       N0.getOperand(0).getValueType() == VT)
18676     return DAG.getNode(ISD::FTRUNC, DL, VT, N0.getOperand(0));
18677 
18678   return SDValue();
18679 }
18680 
visitSINT_TO_FP(SDNode * N)18681 SDValue DAGCombiner::visitSINT_TO_FP(SDNode *N) {
18682   SDValue N0 = N->getOperand(0);
18683   EVT VT = N->getValueType(0);
18684   EVT OpVT = N0.getValueType();
18685   SDLoc DL(N);
18686 
18687   // [us]itofp(undef) = 0, because the result value is bounded.
18688   if (N0.isUndef())
18689     return DAG.getConstantFP(0.0, DL, VT);
18690 
18691   // fold (sint_to_fp c1) -> c1fp
18692   // ...but only if the target supports immediate floating-point values
18693   if ((!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT)))
18694     if (SDValue C = DAG.FoldConstantArithmetic(ISD::SINT_TO_FP, DL, VT, {N0}))
18695       return C;
18696 
18697   // If the input is a legal type, and SINT_TO_FP is not legal on this target,
18698   // but UINT_TO_FP is legal on this target, try to convert.
18699   if (!hasOperation(ISD::SINT_TO_FP, OpVT) &&
18700       hasOperation(ISD::UINT_TO_FP, OpVT)) {
18701     // If the sign bit is known to be zero, we can change this to UINT_TO_FP.
18702     if (DAG.SignBitIsZero(N0))
18703       return DAG.getNode(ISD::UINT_TO_FP, DL, VT, N0);
18704   }
18705 
18706   // The next optimizations are desirable only if SELECT_CC can be lowered.
18707   // fold (sint_to_fp (setcc x, y, cc)) -> (select (setcc x, y, cc), -1.0, 0.0)
18708   if (N0.getOpcode() == ISD::SETCC && N0.getValueType() == MVT::i1 &&
18709       !VT.isVector() &&
18710       (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT)))
18711     return DAG.getSelect(DL, VT, N0, DAG.getConstantFP(-1.0, DL, VT),
18712                          DAG.getConstantFP(0.0, DL, VT));
18713 
18714   // fold (sint_to_fp (zext (setcc x, y, cc))) ->
18715   //      (select (setcc x, y, cc), 1.0, 0.0)
18716   if (N0.getOpcode() == ISD::ZERO_EXTEND &&
18717       N0.getOperand(0).getOpcode() == ISD::SETCC && !VT.isVector() &&
18718       (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT)))
18719     return DAG.getSelect(DL, VT, N0.getOperand(0),
18720                          DAG.getConstantFP(1.0, DL, VT),
18721                          DAG.getConstantFP(0.0, DL, VT));
18722 
18723   if (SDValue FTrunc = foldFPToIntToFP(N, DL, DAG, TLI))
18724     return FTrunc;
18725 
18726   return SDValue();
18727 }
18728 
visitUINT_TO_FP(SDNode * N)18729 SDValue DAGCombiner::visitUINT_TO_FP(SDNode *N) {
18730   SDValue N0 = N->getOperand(0);
18731   EVT VT = N->getValueType(0);
18732   EVT OpVT = N0.getValueType();
18733   SDLoc DL(N);
18734 
18735   // [us]itofp(undef) = 0, because the result value is bounded.
18736   if (N0.isUndef())
18737     return DAG.getConstantFP(0.0, DL, VT);
18738 
18739   // fold (uint_to_fp c1) -> c1fp
18740   // ...but only if the target supports immediate floating-point values
18741   if ((!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT)))
18742     if (SDValue C = DAG.FoldConstantArithmetic(ISD::UINT_TO_FP, DL, VT, {N0}))
18743       return C;
18744 
18745   // If the input is a legal type, and UINT_TO_FP is not legal on this target,
18746   // but SINT_TO_FP is legal on this target, try to convert.
18747   if (!hasOperation(ISD::UINT_TO_FP, OpVT) &&
18748       hasOperation(ISD::SINT_TO_FP, OpVT)) {
18749     // If the sign bit is known to be zero, we can change this to SINT_TO_FP.
18750     if (DAG.SignBitIsZero(N0))
18751       return DAG.getNode(ISD::SINT_TO_FP, DL, VT, N0);
18752   }
18753 
18754   // fold (uint_to_fp (setcc x, y, cc)) -> (select (setcc x, y, cc), 1.0, 0.0)
18755   if (N0.getOpcode() == ISD::SETCC && !VT.isVector() &&
18756       (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT)))
18757     return DAG.getSelect(DL, VT, N0, DAG.getConstantFP(1.0, DL, VT),
18758                          DAG.getConstantFP(0.0, DL, VT));
18759 
18760   if (SDValue FTrunc = foldFPToIntToFP(N, DL, DAG, TLI))
18761     return FTrunc;
18762 
18763   return SDValue();
18764 }
18765 
18766 // Fold (fp_to_{s/u}int ({s/u}int_to_fpx)) -> zext x, sext x, trunc x, or x
FoldIntToFPToInt(SDNode * N,const SDLoc & DL,SelectionDAG & DAG)18767 static SDValue FoldIntToFPToInt(SDNode *N, const SDLoc &DL, SelectionDAG &DAG) {
18768   SDValue N0 = N->getOperand(0);
18769   EVT VT = N->getValueType(0);
18770 
18771   if (N0.getOpcode() != ISD::UINT_TO_FP && N0.getOpcode() != ISD::SINT_TO_FP)
18772     return SDValue();
18773 
18774   SDValue Src = N0.getOperand(0);
18775   EVT SrcVT = Src.getValueType();
18776   bool IsInputSigned = N0.getOpcode() == ISD::SINT_TO_FP;
18777   bool IsOutputSigned = N->getOpcode() == ISD::FP_TO_SINT;
18778 
18779   // We can safely assume the conversion won't overflow the output range,
18780   // because (for example) (uint8_t)18293.f is undefined behavior.
18781 
18782   // Since we can assume the conversion won't overflow, our decision as to
18783   // whether the input will fit in the float should depend on the minimum
18784   // of the input range and output range.
18785 
18786   // This means this is also safe for a signed input and unsigned output, since
18787   // a negative input would lead to undefined behavior.
18788   unsigned InputSize = (int)SrcVT.getScalarSizeInBits() - IsInputSigned;
18789   unsigned OutputSize = (int)VT.getScalarSizeInBits();
18790   unsigned ActualSize = std::min(InputSize, OutputSize);
18791   const fltSemantics &Sem = N0.getValueType().getFltSemantics();
18792 
18793   // We can only fold away the float conversion if the input range can be
18794   // represented exactly in the float range.
18795   if (APFloat::semanticsPrecision(Sem) >= ActualSize) {
18796     if (VT.getScalarSizeInBits() > SrcVT.getScalarSizeInBits()) {
18797       unsigned ExtOp =
18798           IsInputSigned && IsOutputSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
18799       return DAG.getNode(ExtOp, DL, VT, Src);
18800     }
18801     if (VT.getScalarSizeInBits() < SrcVT.getScalarSizeInBits())
18802       return DAG.getNode(ISD::TRUNCATE, DL, VT, Src);
18803     return DAG.getBitcast(VT, Src);
18804   }
18805   return SDValue();
18806 }
18807 
visitFP_TO_SINT(SDNode * N)18808 SDValue DAGCombiner::visitFP_TO_SINT(SDNode *N) {
18809   SDValue N0 = N->getOperand(0);
18810   EVT VT = N->getValueType(0);
18811   SDLoc DL(N);
18812 
18813   // fold (fp_to_sint undef) -> undef
18814   if (N0.isUndef())
18815     return DAG.getUNDEF(VT);
18816 
18817   // fold (fp_to_sint c1fp) -> c1
18818   if (SDValue C = DAG.FoldConstantArithmetic(ISD::FP_TO_SINT, DL, VT, {N0}))
18819     return C;
18820 
18821   return FoldIntToFPToInt(N, DL, DAG);
18822 }
18823 
visitFP_TO_UINT(SDNode * N)18824 SDValue DAGCombiner::visitFP_TO_UINT(SDNode *N) {
18825   SDValue N0 = N->getOperand(0);
18826   EVT VT = N->getValueType(0);
18827   SDLoc DL(N);
18828 
18829   // fold (fp_to_uint undef) -> undef
18830   if (N0.isUndef())
18831     return DAG.getUNDEF(VT);
18832 
18833   // fold (fp_to_uint c1fp) -> c1
18834   if (SDValue C = DAG.FoldConstantArithmetic(ISD::FP_TO_UINT, DL, VT, {N0}))
18835     return C;
18836 
18837   return FoldIntToFPToInt(N, DL, DAG);
18838 }
18839 
visitXROUND(SDNode * N)18840 SDValue DAGCombiner::visitXROUND(SDNode *N) {
18841   SDValue N0 = N->getOperand(0);
18842   EVT VT = N->getValueType(0);
18843 
18844   // fold (lrint|llrint undef) -> undef
18845   // fold (lround|llround undef) -> undef
18846   if (N0.isUndef())
18847     return DAG.getUNDEF(VT);
18848 
18849   // fold (lrint|llrint c1fp) -> c1
18850   // fold (lround|llround c1fp) -> c1
18851   if (SDValue C =
18852           DAG.FoldConstantArithmetic(N->getOpcode(), SDLoc(N), VT, {N0}))
18853     return C;
18854 
18855   return SDValue();
18856 }
18857 
// Combine an ISD::FP_ROUND node. Operand 1 (N1) is the rounding flag
// constant. A value of 1 marks a rounding known not to change the value (see
// the fp_round pair fold below).
// Folds, in order:
//   - constant folding;
//   - fp_round(fp_extend x) -> x when x already has the result type;
//   - fp_round(fp_round x) -> fp_round x;
//   - fp_round(copysign X, Y) -> copysign(fp_round X, Y);
//   - matchVSelectOpSizesWithSetCC.
// Returns the replacement value, or a null SDValue if nothing applied.
visitFP_ROUND(SDNode * N)18858 SDValue DAGCombiner::visitFP_ROUND(SDNode *N) {
18859   SDValue N0 = N->getOperand(0);
18860   SDValue N1 = N->getOperand(1);
18861   EVT VT = N->getValueType(0);
18862   SDLoc DL(N);
18863 
18864   // fold (fp_round c1fp) -> c1fp
18865   if (SDValue C = DAG.FoldConstantArithmetic(ISD::FP_ROUND, DL, VT, {N0, N1}))
18866     return C;
18867 
18868   // fold (fp_round (fp_extend x)) -> x
        // The widening is exact, so rounding back to x's own type is a no-op.
18869   if (N0.getOpcode() == ISD::FP_EXTEND && VT == N0.getOperand(0).getValueType())
18870     return N0.getOperand(0);
18871 
18872   // fold (fp_round (fp_round x)) -> (fp_round x)
18873   if (N0.getOpcode() == ISD::FP_ROUND) {
18874     const bool NIsTrunc = N->getConstantOperandVal(1) == 1;
18875     const bool N0IsTrunc = N0.getConstantOperandVal(1) == 1;
18876 
18877     // Avoid folding legal fp_rounds into non-legal ones.
          // Note: this return and the f80 one below leave the whole visitor, so
          // the folds after this block are not tried for this node.
18878     if (!hasOperation(ISD::FP_ROUND, VT))
18879       return SDValue();
18880 
18881     // Skip this folding if it results in an fp_round from f80 to f16.
18882     //
18883     // f80 to f16 always generates an expensive (and as yet, unimplemented)
18884     // libcall to __truncxfhf2 instead of selecting native f16 conversion
18885     // instructions from f32 or f64.  Moreover, the first (value-preserving)
18886     // fp_round from f80 to either f32 or f64 may become a NOP in platforms like
18887     // x86.
18888     if (N0.getOperand(0).getValueType() == MVT::f80 && VT == MVT::f16)
18889       return SDValue();
18890 
18891     // If the first fp_round isn't a value preserving truncation, it might
18892     // introduce a tie in the second fp_round, that wouldn't occur in the
18893     // single-step fp_round we want to fold to.
18894     // In other words, double rounding isn't the same as rounding.
18895     // Also, this is a value preserving truncation iff both fp_round's are.
18896     if (DAG.getTarget().Options.UnsafeFPMath || N0IsTrunc)
18897       return DAG.getNode(
18898           ISD::FP_ROUND, DL, VT, N0.getOperand(0),
18899           DAG.getIntPtrConstant(NIsTrunc && N0IsTrunc, DL, /*isTarget=*/true));
18900   }
18901 
18902   // fold (fp_round (copysign X, Y)) -> (copysign (fp_round X), Y)
18903   // Note: From a legality perspective, this is a two step transform.  First,
18904   // we duplicate the fp_round to the arguments of the copysign, then we
18905   // eliminate the fp_round on Y.  The second step requires an additional
18906   // predicate to match the implementation above.
18907   if (N0.getOpcode() == ISD::FCOPYSIGN && N0->hasOneUse() &&
18908       CanCombineFCOPYSIGN_EXTEND_ROUND(VT,
18909                                        N0.getValueType())) {
18910     SDValue Tmp = DAG.getNode(ISD::FP_ROUND, SDLoc(N0), VT,
18911                               N0.getOperand(0), N1);
18912     AddToWorklist(Tmp.getNode());
18913     return DAG.getNode(ISD::FCOPYSIGN, DL, VT, Tmp, N0.getOperand(1));
18914   }
18915 
18916   if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
18917     return NewVSel;
18918 
18919   return SDValue();
18920 }
18921 
18922 // Eliminate a floating-point widening of a narrowed value if the fast math
18923 // flags allow it.
eliminateFPCastPair(SDNode * N)18924 static SDValue eliminateFPCastPair(SDNode *N) {
18925   SDValue N0 = N->getOperand(0);
18926   EVT VT = N->getValueType(0);
18927 
18928   unsigned NarrowingOp;
18929   switch (N->getOpcode()) {
18930   case ISD::FP16_TO_FP:
18931     NarrowingOp = ISD::FP_TO_FP16;
18932     break;
18933   case ISD::BF16_TO_FP:
18934     NarrowingOp = ISD::FP_TO_BF16;
18935     break;
18936   case ISD::FP_EXTEND:
18937     NarrowingOp = ISD::FP_ROUND;
18938     break;
18939   default:
18940     llvm_unreachable("Expected widening FP cast");
18941   }
18942 
18943   if (N0.getOpcode() == NarrowingOp && N0.getOperand(0).getValueType() == VT) {
18944     const SDNodeFlags NarrowFlags = N0->getFlags();
18945     const SDNodeFlags WidenFlags = N->getFlags();
18946     // Narrowing can introduce inf and change the encoding of a nan, so the
18947     // widen must have the nnan and ninf flags to indicate that we don't need to
18948     // care about that. We are also removing a rounding step, and that requires
18949     // both the narrow and widen to allow contraction.
18950     if (WidenFlags.hasNoNaNs() && WidenFlags.hasNoInfs() &&
18951         NarrowFlags.hasAllowContract() && WidenFlags.hasAllowContract()) {
18952       return N0.getOperand(0);
18953     }
18954   }
18955 
18956   return SDValue();
18957 }
18958 
visitFP_EXTEND(SDNode * N)18959 SDValue DAGCombiner::visitFP_EXTEND(SDNode *N) {
18960   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
18961   SDValue N0 = N->getOperand(0);
18962   EVT VT = N->getValueType(0);
18963   SDLoc DL(N);
18964 
18965   if (VT.isVector())
18966     if (SDValue FoldedVOp = SimplifyVCastOp(N, DL))
18967       return FoldedVOp;
18968 
18969   // If this is fp_round(fpextend), don't fold it, allow ourselves to be folded.
18970   if (N->hasOneUse() && N->user_begin()->getOpcode() == ISD::FP_ROUND)
18971     return SDValue();
18972 
18973   // fold (fp_extend c1fp) -> c1fp
18974   if (SDValue C = DAG.FoldConstantArithmetic(ISD::FP_EXTEND, DL, VT, {N0}))
18975     return C;
18976 
18977   // fold (fp_extend (fp16_to_fp op)) -> (fp16_to_fp op)
18978   if (N0.getOpcode() == ISD::FP16_TO_FP &&
18979       TLI.getOperationAction(ISD::FP16_TO_FP, VT) == TargetLowering::Legal)
18980     return DAG.getNode(ISD::FP16_TO_FP, DL, VT, N0.getOperand(0));
18981 
18982   // Turn fp_extend(fp_round(X, 1)) -> x since the fp_round doesn't affect the
18983   // value of X.
18984   if (N0.getOpcode() == ISD::FP_ROUND && N0.getConstantOperandVal(1) == 1) {
18985     SDValue In = N0.getOperand(0);
18986     if (In.getValueType() == VT) return In;
18987     if (VT.bitsLT(In.getValueType()))
18988       return DAG.getNode(ISD::FP_ROUND, DL, VT, In, N0.getOperand(1));
18989     return DAG.getNode(ISD::FP_EXTEND, DL, VT, In);
18990   }
18991 
18992   // fold (fpext (load x)) -> (fpext (fptrunc (extload x)))
18993   if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
18994       TLI.isLoadExtLegalOrCustom(ISD::EXTLOAD, VT, N0.getValueType())) {
18995     LoadSDNode *LN0 = cast<LoadSDNode>(N0);
18996     SDValue ExtLoad = DAG.getExtLoad(ISD::EXTLOAD, DL, VT,
18997                                      LN0->getChain(),
18998                                      LN0->getBasePtr(), N0.getValueType(),
18999                                      LN0->getMemOperand());
19000     CombineTo(N, ExtLoad);
19001     CombineTo(
19002         N0.getNode(),
19003         DAG.getNode(ISD::FP_ROUND, SDLoc(N0), N0.getValueType(), ExtLoad,
19004                     DAG.getIntPtrConstant(1, SDLoc(N0), /*isTarget=*/true)),
19005         ExtLoad.getValue(1));
19006     return SDValue(N, 0);   // Return N so it doesn't get rechecked!
19007   }
19008 
19009   if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
19010     return NewVSel;
19011 
19012   if (SDValue CastEliminated = eliminateFPCastPair(N))
19013     return CastEliminated;
19014 
19015   return SDValue();
19016 }
19017 
visitFCEIL(SDNode * N)19018 SDValue DAGCombiner::visitFCEIL(SDNode *N) {
19019   SDValue N0 = N->getOperand(0);
19020   EVT VT = N->getValueType(0);
19021 
19022   // fold (fceil c1) -> fceil(c1)
19023   if (SDValue C = DAG.FoldConstantArithmetic(ISD::FCEIL, SDLoc(N), VT, {N0}))
19024     return C;
19025 
19026   return SDValue();
19027 }
19028 
visitFTRUNC(SDNode * N)19029 SDValue DAGCombiner::visitFTRUNC(SDNode *N) {
19030   SDValue N0 = N->getOperand(0);
19031   EVT VT = N->getValueType(0);
19032 
19033   // fold (ftrunc c1) -> ftrunc(c1)
19034   if (SDValue C = DAG.FoldConstantArithmetic(ISD::FTRUNC, SDLoc(N), VT, {N0}))
19035     return C;
19036 
19037   // fold ftrunc (known rounded int x) -> x
19038   // ftrunc is a part of fptosi/fptoui expansion on some targets, so this is
19039   // likely to be generated to extract integer from a rounded floating value.
19040   switch (N0.getOpcode()) {
19041   default: break;
19042   case ISD::FRINT:
19043   case ISD::FTRUNC:
19044   case ISD::FNEARBYINT:
19045   case ISD::FROUNDEVEN:
19046   case ISD::FFLOOR:
19047   case ISD::FCEIL:
19048     return N0;
19049   }
19050 
19051   return SDValue();
19052 }
19053 
visitFFREXP(SDNode * N)19054 SDValue DAGCombiner::visitFFREXP(SDNode *N) {
19055   SDValue N0 = N->getOperand(0);
19056 
19057   // fold (ffrexp c1) -> ffrexp(c1)
19058   if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
19059     return DAG.getNode(ISD::FFREXP, SDLoc(N), N->getVTList(), N0);
19060   return SDValue();
19061 }
19062 
visitFFLOOR(SDNode * N)19063 SDValue DAGCombiner::visitFFLOOR(SDNode *N) {
19064   SDValue N0 = N->getOperand(0);
19065   EVT VT = N->getValueType(0);
19066 
19067   // fold (ffloor c1) -> ffloor(c1)
19068   if (SDValue C = DAG.FoldConstantArithmetic(ISD::FFLOOR, SDLoc(N), VT, {N0}))
19069     return C;
19070 
19071   return SDValue();
19072 }
19073 
visitFNEG(SDNode * N)19074 SDValue DAGCombiner::visitFNEG(SDNode *N) {
19075   SDValue N0 = N->getOperand(0);
19076   EVT VT = N->getValueType(0);
19077   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
19078 
19079   // Constant fold FNEG.
19080   if (SDValue C = DAG.FoldConstantArithmetic(ISD::FNEG, SDLoc(N), VT, {N0}))
19081     return C;
19082 
19083   if (SDValue NegN0 =
19084           TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize))
19085     return NegN0;
19086 
19087   // -(X-Y) -> (Y-X) is unsafe because when X==Y, -0.0 != +0.0
19088   // FIXME: This is duplicated in getNegatibleCost, but getNegatibleCost doesn't
19089   // know it was called from a context with a nsz flag if the input fsub does
19090   // not.
19091   if (N0.getOpcode() == ISD::FSUB &&
19092       (DAG.getTarget().Options.NoSignedZerosFPMath ||
19093        N->getFlags().hasNoSignedZeros()) && N0.hasOneUse()) {
19094     return DAG.getNode(ISD::FSUB, SDLoc(N), VT, N0.getOperand(1),
19095                        N0.getOperand(0));
19096   }
19097 
19098   if (SimplifyDemandedBits(SDValue(N, 0)))
19099     return SDValue(N, 0);
19100 
19101   if (SDValue Cast = foldSignChangeInBitcast(N))
19102     return Cast;
19103 
19104   return SDValue();
19105 }
19106 
visitFMinMax(SDNode * N)19107 SDValue DAGCombiner::visitFMinMax(SDNode *N) {
19108   SDValue N0 = N->getOperand(0);
19109   SDValue N1 = N->getOperand(1);
19110   EVT VT = N->getValueType(0);
19111   const SDNodeFlags Flags = N->getFlags();
19112   unsigned Opc = N->getOpcode();
19113   bool PropagatesNaN = Opc == ISD::FMINIMUM || Opc == ISD::FMAXIMUM;
19114   bool IsMin = Opc == ISD::FMINNUM || Opc == ISD::FMINIMUM;
19115   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
19116 
19117   // Constant fold.
19118   if (SDValue C = DAG.FoldConstantArithmetic(Opc, SDLoc(N), VT, {N0, N1}))
19119     return C;
19120 
19121   // Canonicalize to constant on RHS.
19122   if (DAG.isConstantFPBuildVectorOrConstantFP(N0) &&
19123       !DAG.isConstantFPBuildVectorOrConstantFP(N1))
19124     return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N1, N0);
19125 
19126   if (const ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1)) {
19127     const APFloat &AF = N1CFP->getValueAPF();
19128 
19129     // minnum(X, nan) -> X
19130     // maxnum(X, nan) -> X
19131     // minimum(X, nan) -> nan
19132     // maximum(X, nan) -> nan
19133     if (AF.isNaN())
19134       return PropagatesNaN ? N->getOperand(1) : N->getOperand(0);
19135 
19136     // In the following folds, inf can be replaced with the largest finite
19137     // float, if the ninf flag is set.
19138     if (AF.isInfinity() || (Flags.hasNoInfs() && AF.isLargest())) {
19139       // minnum(X, -inf) -> -inf
19140       // maxnum(X, +inf) -> +inf
19141       // minimum(X, -inf) -> -inf if nnan
19142       // maximum(X, +inf) -> +inf if nnan
19143       if (IsMin == AF.isNegative() && (!PropagatesNaN || Flags.hasNoNaNs()))
19144         return N->getOperand(1);
19145 
19146       // minnum(X, +inf) -> X if nnan
19147       // maxnum(X, -inf) -> X if nnan
19148       // minimum(X, +inf) -> X
19149       // maximum(X, -inf) -> X
19150       if (IsMin != AF.isNegative() && (PropagatesNaN || Flags.hasNoNaNs()))
19151         return N->getOperand(0);
19152     }
19153   }
19154 
19155   if (SDValue SD = reassociateReduction(
19156           PropagatesNaN
19157               ? (IsMin ? ISD::VECREDUCE_FMINIMUM : ISD::VECREDUCE_FMAXIMUM)
19158               : (IsMin ? ISD::VECREDUCE_FMIN : ISD::VECREDUCE_FMAX),
19159           Opc, SDLoc(N), VT, N0, N1, Flags))
19160     return SD;
19161 
19162   return SDValue();
19163 }
19164 
visitFABS(SDNode * N)19165 SDValue DAGCombiner::visitFABS(SDNode *N) {
19166   SDValue N0 = N->getOperand(0);
19167   EVT VT = N->getValueType(0);
19168   SDLoc DL(N);
19169 
19170   // fold (fabs c1) -> fabs(c1)
19171   if (SDValue C = DAG.FoldConstantArithmetic(ISD::FABS, DL, VT, {N0}))
19172     return C;
19173 
19174   if (SimplifyDemandedBits(SDValue(N, 0)))
19175     return SDValue(N, 0);
19176 
19177   if (SDValue Cast = foldSignChangeInBitcast(N))
19178     return Cast;
19179 
19180   return SDValue();
19181 }
19182 
visitBRCOND(SDNode * N)19183 SDValue DAGCombiner::visitBRCOND(SDNode *N) {
19184   SDValue Chain = N->getOperand(0);
19185   SDValue N1 = N->getOperand(1);
19186   SDValue N2 = N->getOperand(2);
19187 
19188   // BRCOND(FREEZE(cond)) is equivalent to BRCOND(cond) (both are
19189   // nondeterministic jumps).
19190   if (N1->getOpcode() == ISD::FREEZE && N1.hasOneUse()) {
19191     return DAG.getNode(ISD::BRCOND, SDLoc(N), MVT::Other, Chain,
19192                        N1->getOperand(0), N2, N->getFlags());
19193   }
19194 
19195   // Variant of the previous fold where there is a SETCC in between:
19196   //   BRCOND(SETCC(FREEZE(X), CONST, Cond))
19197   // =>
19198   //   BRCOND(FREEZE(SETCC(X, CONST, Cond)))
19199   // =>
19200   //   BRCOND(SETCC(X, CONST, Cond))
19201   // This is correct if FREEZE(X) has one use and SETCC(FREEZE(X), CONST, Cond)
19202   // isn't equivalent to true or false.
19203   // For example, SETCC(FREEZE(X), -128, SETULT) cannot be folded to
19204   // FREEZE(SETCC(X, -128, SETULT)) because X can be poison.
19205   if (N1->getOpcode() == ISD::SETCC && N1.hasOneUse()) {
19206     SDValue S0 = N1->getOperand(0), S1 = N1->getOperand(1);
19207     ISD::CondCode Cond = cast<CondCodeSDNode>(N1->getOperand(2))->get();
19208     ConstantSDNode *S0C = dyn_cast<ConstantSDNode>(S0);
19209     ConstantSDNode *S1C = dyn_cast<ConstantSDNode>(S1);
19210     bool Updated = false;
19211 
19212     // Is 'X Cond C' always true or false?
19213     auto IsAlwaysTrueOrFalse = [](ISD::CondCode Cond, ConstantSDNode *C) {
19214       bool False = (Cond == ISD::SETULT && C->isZero()) ||
19215                    (Cond == ISD::SETLT && C->isMinSignedValue()) ||
19216                    (Cond == ISD::SETUGT && C->isAllOnes()) ||
19217                    (Cond == ISD::SETGT && C->isMaxSignedValue());
19218       bool True = (Cond == ISD::SETULE && C->isAllOnes()) ||
19219                   (Cond == ISD::SETLE && C->isMaxSignedValue()) ||
19220                   (Cond == ISD::SETUGE && C->isZero()) ||
19221                   (Cond == ISD::SETGE && C->isMinSignedValue());
19222       return True || False;
19223     };
19224 
19225     if (S0->getOpcode() == ISD::FREEZE && S0.hasOneUse() && S1C) {
19226       if (!IsAlwaysTrueOrFalse(Cond, S1C)) {
19227         S0 = S0->getOperand(0);
19228         Updated = true;
19229       }
19230     }
19231     if (S1->getOpcode() == ISD::FREEZE && S1.hasOneUse() && S0C) {
19232       if (!IsAlwaysTrueOrFalse(ISD::getSetCCSwappedOperands(Cond), S0C)) {
19233         S1 = S1->getOperand(0);
19234         Updated = true;
19235       }
19236     }
19237 
19238     if (Updated)
19239       return DAG.getNode(
19240           ISD::BRCOND, SDLoc(N), MVT::Other, Chain,
19241           DAG.getSetCC(SDLoc(N1), N1->getValueType(0), S0, S1, Cond), N2,
19242           N->getFlags());
19243   }
19244 
19245   // If N is a constant we could fold this into a fallthrough or unconditional
19246   // branch. However that doesn't happen very often in normal code, because
19247   // Instcombine/SimplifyCFG should have handled the available opportunities.
19248   // If we did this folding here, it would be necessary to update the
19249   // MachineBasicBlock CFG, which is awkward.
19250 
19251   // fold a brcond with a setcc condition into a BR_CC node if BR_CC is legal
19252   // on the target.
19253   if (N1.getOpcode() == ISD::SETCC &&
19254       TLI.isOperationLegalOrCustom(ISD::BR_CC,
19255                                    N1.getOperand(0).getValueType())) {
19256     return DAG.getNode(ISD::BR_CC, SDLoc(N), MVT::Other,
19257                        Chain, N1.getOperand(2),
19258                        N1.getOperand(0), N1.getOperand(1), N2);
19259   }
19260 
19261   if (N1.hasOneUse()) {
19262     // rebuildSetCC calls visitXor which may change the Chain when there is a
19263     // STRICT_FSETCC/STRICT_FSETCCS involved. Use a handle to track changes.
19264     HandleSDNode ChainHandle(Chain);
19265     if (SDValue NewN1 = rebuildSetCC(N1))
19266       return DAG.getNode(ISD::BRCOND, SDLoc(N), MVT::Other,
19267                          ChainHandle.getValue(), NewN1, N2, N->getFlags());
19268   }
19269 
19270   return SDValue();
19271 }
19272 
rebuildSetCC(SDValue N)19273 SDValue DAGCombiner::rebuildSetCC(SDValue N) {
19274   if (N.getOpcode() == ISD::SRL ||
19275       (N.getOpcode() == ISD::TRUNCATE &&
19276        (N.getOperand(0).hasOneUse() &&
19277         N.getOperand(0).getOpcode() == ISD::SRL))) {
19278     // Look pass the truncate.
19279     if (N.getOpcode() == ISD::TRUNCATE)
19280       N = N.getOperand(0);
19281 
19282     // Match this pattern so that we can generate simpler code:
19283     //
19284     //   %a = ...
19285     //   %b = and i32 %a, 2
19286     //   %c = srl i32 %b, 1
19287     //   brcond i32 %c ...
19288     //
19289     // into
19290     //
19291     //   %a = ...
19292     //   %b = and i32 %a, 2
19293     //   %c = setcc eq %b, 0
19294     //   brcond %c ...
19295     //
19296     // This applies only when the AND constant value has one bit set and the
19297     // SRL constant is equal to the log2 of the AND constant. The back-end is
19298     // smart enough to convert the result into a TEST/JMP sequence.
19299     SDValue Op0 = N.getOperand(0);
19300     SDValue Op1 = N.getOperand(1);
19301 
19302     if (Op0.getOpcode() == ISD::AND && Op1.getOpcode() == ISD::Constant) {
19303       SDValue AndOp1 = Op0.getOperand(1);
19304 
19305       if (AndOp1.getOpcode() == ISD::Constant) {
19306         const APInt &AndConst = AndOp1->getAsAPIntVal();
19307 
19308         if (AndConst.isPowerOf2() &&
19309             Op1->getAsAPIntVal() == AndConst.logBase2()) {
19310           SDLoc DL(N);
19311           return DAG.getSetCC(DL, getSetCCResultType(Op0.getValueType()),
19312                               Op0, DAG.getConstant(0, DL, Op0.getValueType()),
19313                               ISD::SETNE);
19314         }
19315       }
19316     }
19317   }
19318 
19319   // Transform (brcond (xor x, y)) -> (brcond (setcc, x, y, ne))
19320   // Transform (brcond (xor (xor x, y), -1)) -> (brcond (setcc, x, y, eq))
19321   if (N.getOpcode() == ISD::XOR) {
19322     // Because we may call this on a speculatively constructed
19323     // SimplifiedSetCC Node, we need to simplify this node first.
19324     // Ideally this should be folded into SimplifySetCC and not
19325     // here. For now, grab a handle to N so we don't lose it from
19326     // replacements interal to the visit.
19327     HandleSDNode XORHandle(N);
19328     while (N.getOpcode() == ISD::XOR) {
19329       SDValue Tmp = visitXOR(N.getNode());
19330       // No simplification done.
19331       if (!Tmp.getNode())
19332         break;
19333       // Returning N is form in-visit replacement that may invalidated
19334       // N. Grab value from Handle.
19335       if (Tmp.getNode() == N.getNode())
19336         N = XORHandle.getValue();
19337       else // Node simplified. Try simplifying again.
19338         N = Tmp;
19339     }
19340 
19341     if (N.getOpcode() != ISD::XOR)
19342       return N;
19343 
19344     SDValue Op0 = N->getOperand(0);
19345     SDValue Op1 = N->getOperand(1);
19346 
19347     if (Op0.getOpcode() != ISD::SETCC && Op1.getOpcode() != ISD::SETCC) {
19348       bool Equal = false;
19349       // (brcond (xor (xor x, y), -1)) -> (brcond (setcc x, y, eq))
19350       if (isBitwiseNot(N) && Op0.hasOneUse() && Op0.getOpcode() == ISD::XOR &&
19351           Op0.getValueType() == MVT::i1) {
19352         N = Op0;
19353         Op0 = N->getOperand(0);
19354         Op1 = N->getOperand(1);
19355         Equal = true;
19356       }
19357 
19358       EVT SetCCVT = N.getValueType();
19359       if (LegalTypes)
19360         SetCCVT = getSetCCResultType(SetCCVT);
19361       // Replace the uses of XOR with SETCC. Note, avoid this transformation if
19362       // it would introduce illegal operations post-legalization as this can
19363       // result in infinite looping between converting xor->setcc here, and
19364       // expanding setcc->xor in LegalizeSetCCCondCode if requested.
19365       const ISD::CondCode CC = Equal ? ISD::SETEQ : ISD::SETNE;
19366       if (!LegalOperations || TLI.isCondCodeLegal(CC, Op0.getSimpleValueType()))
19367         return DAG.getSetCC(SDLoc(N), SetCCVT, Op0, Op1, CC);
19368     }
19369   }
19370 
19371   return SDValue();
19372 }
19373 
19374 // Operand List for BR_CC: Chain, CondCC, CondLHS, CondRHS, DestBB.
19375 //
visitBR_CC(SDNode * N)19376 SDValue DAGCombiner::visitBR_CC(SDNode *N) {
19377   CondCodeSDNode *CC = cast<CondCodeSDNode>(N->getOperand(1));
19378   SDValue CondLHS = N->getOperand(2), CondRHS = N->getOperand(3);
19379 
19380   // If N is a constant we could fold this into a fallthrough or unconditional
19381   // branch. However that doesn't happen very often in normal code, because
19382   // Instcombine/SimplifyCFG should have handled the available opportunities.
19383   // If we did this folding here, it would be necessary to update the
19384   // MachineBasicBlock CFG, which is awkward.
19385 
19386   // Use SimplifySetCC to simplify SETCC's.
19387   SDValue Simp = SimplifySetCC(getSetCCResultType(CondLHS.getValueType()),
19388                                CondLHS, CondRHS, CC->get(), SDLoc(N),
19389                                false);
19390   if (Simp.getNode()) AddToWorklist(Simp.getNode());
19391 
19392   // fold to a simpler setcc
19393   if (Simp.getNode() && Simp.getOpcode() == ISD::SETCC)
19394     return DAG.getNode(ISD::BR_CC, SDLoc(N), MVT::Other,
19395                        N->getOperand(0), Simp.getOperand(2),
19396                        Simp.getOperand(0), Simp.getOperand(1),
19397                        N->getOperand(4));
19398 
19399   return SDValue();
19400 }
19401 
getCombineLoadStoreParts(SDNode * N,unsigned Inc,unsigned Dec,bool & IsLoad,bool & IsMasked,SDValue & Ptr,const TargetLowering & TLI)19402 static bool getCombineLoadStoreParts(SDNode *N, unsigned Inc, unsigned Dec,
19403                                      bool &IsLoad, bool &IsMasked, SDValue &Ptr,
19404                                      const TargetLowering &TLI) {
19405   if (LoadSDNode *LD = dyn_cast<LoadSDNode>(N)) {
19406     if (LD->isIndexed())
19407       return false;
19408     EVT VT = LD->getMemoryVT();
19409     if (!TLI.isIndexedLoadLegal(Inc, VT) && !TLI.isIndexedLoadLegal(Dec, VT))
19410       return false;
19411     Ptr = LD->getBasePtr();
19412   } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(N)) {
19413     if (ST->isIndexed())
19414       return false;
19415     EVT VT = ST->getMemoryVT();
19416     if (!TLI.isIndexedStoreLegal(Inc, VT) && !TLI.isIndexedStoreLegal(Dec, VT))
19417       return false;
19418     Ptr = ST->getBasePtr();
19419     IsLoad = false;
19420   } else if (MaskedLoadSDNode *LD = dyn_cast<MaskedLoadSDNode>(N)) {
19421     if (LD->isIndexed())
19422       return false;
19423     EVT VT = LD->getMemoryVT();
19424     if (!TLI.isIndexedMaskedLoadLegal(Inc, VT) &&
19425         !TLI.isIndexedMaskedLoadLegal(Dec, VT))
19426       return false;
19427     Ptr = LD->getBasePtr();
19428     IsMasked = true;
19429   } else if (MaskedStoreSDNode *ST = dyn_cast<MaskedStoreSDNode>(N)) {
19430     if (ST->isIndexed())
19431       return false;
19432     EVT VT = ST->getMemoryVT();
19433     if (!TLI.isIndexedMaskedStoreLegal(Inc, VT) &&
19434         !TLI.isIndexedMaskedStoreLegal(Dec, VT))
19435       return false;
19436     Ptr = ST->getBasePtr();
19437     IsLoad = false;
19438     IsMasked = true;
19439   } else {
19440     return false;
19441   }
19442   return true;
19443 }
19444 
/// Try turning a load/store into a pre-indexed load/store when the base
/// pointer is an add or subtract and it has other uses besides the load/store.
/// After the transformation, the new indexed load/store has effectively folded
/// the add/subtract in and all of its other uses are redirected to the
/// new load/store.
///
/// Does nothing until the combiner level reaches AfterLegalizeDAG. On success,
/// N is replaced by the indexed access and deleted. Ptr and any rewritten
/// sibling add/sub users of the base pointer are also replaced and deleted,
/// and the function returns true. Otherwise no replacement is made and it
/// returns false.
bool DAGCombiner::CombineToPreIndexedLoadStore(SDNode *N) {
  if (Level < AfterLegalizeDAG)
    return false;

  // Start from the plain-load case. getCombineLoadStoreParts clears IsLoad /
  // sets IsMasked as appropriate. It fails unless N is an unindexed access
  // the target can pre-increment or pre-decrement for this memory type.
  bool IsLoad = true;
  bool IsMasked = false;
  SDValue Ptr;
  if (!getCombineLoadStoreParts(N, ISD::PRE_INC, ISD::PRE_DEC, IsLoad, IsMasked,
                                Ptr, TLI))
    return false;

  // If the pointer is not an add/sub, or if it doesn't have multiple uses, bail
  // out.  There is no reason to make this a preinc/predec.
  if ((Ptr.getOpcode() != ISD::ADD && Ptr.getOpcode() != ISD::SUB) ||
      Ptr->hasOneUse())
    return false;

  // Ask the target to do addressing mode selection.
  SDValue BasePtr;
  SDValue Offset;
  ISD::MemIndexedMode AM = ISD::UNINDEXED;
  if (!TLI.getPreIndexedAddressParts(N, BasePtr, Offset, AM, DAG))
    return false;

  // Backends without true r+i pre-indexed forms may need to pass a
  // constant base with a variable offset so that constant coercion
  // will work with the patterns in canonical form.
  // From here until the swap-back below, BasePtr is the non-constant operand
  // and Offset the (possibly constant) one.
  bool Swapped = false;
  if (isa<ConstantSDNode>(BasePtr)) {
    std::swap(BasePtr, Offset);
    Swapped = true;
  }

  // Don't create a indexed load / store with zero offset.
  if (isNullConstant(Offset))
    return false;

  // Try turning it into a pre-indexed load / store except when:
  // 1) The new base ptr is a frame index.
  // 2) If N is a store and the new base ptr is either the same as or is a
  //    predecessor of the value being stored.
  // 3) Another use of old base ptr is a predecessor of N. If ptr is folded
  //    that would create a cycle.
  // 4) All uses are load / store ops that use it as old base ptr.

  // Check #1.  Preinc'ing a frame index would require copying the stack pointer
  // (plus the implicit offset) to a register to preinc anyway.
  // A RegisterSDNode base is rejected here as well.
  if (isa<FrameIndexSDNode>(BasePtr) || isa<RegisterSDNode>(BasePtr))
    return false;

  // Check #2.
  if (!IsLoad) {
    SDValue Val = IsMasked ? cast<MaskedStoreSDNode>(N)->getValue()
                           : cast<StoreSDNode>(N)->getValue();

    // Would require a copy.
    if (Val == BasePtr)
      return false;

    // Would create a cycle.
    if (Val == Ptr || Ptr->isPredecessorOf(Val.getNode()))
      return false;
  }

  // Caches for hasPredecessorHelper.
  // Seeded with N, so the helper answers "is X a predecessor of N?". The
  // caches are reused by the #3 check below.
  SmallPtrSet<const SDNode *, 32> Visited;
  SmallVector<const SDNode *, 16> Worklist;
  Worklist.push_back(N);

  // If the offset is a constant, there may be other adds of constants that
  // can be folded with this one. We should do this to avoid having to keep
  // a copy of the original base pointer.
  // OtherUses collects add/sub users of BasePtr (with a constant of Offset's
  // type as the other operand) that will be rewritten in terms of the new
  // base. Any unsuitable user empties the list and abandons that rewrite; the
  // indexed combine itself may still proceed.
  SmallVector<SDNode *, 16> OtherUses;
  unsigned MaxSteps = SelectionDAG::getHasPredecessorMaxSteps();
  if (isa<ConstantSDNode>(Offset))
    for (SDUse &Use : BasePtr->uses()) {
      // Skip the use that is Ptr and uses of other results from BasePtr's
      // node (important for nodes that return multiple results).
      if (Use.getUser() == Ptr.getNode() || Use != BasePtr)
        continue;

      // Users that are predecessors of N are left alone (not rewritten).
      if (SDNode::hasPredecessorHelper(Use.getUser(), Visited, Worklist,
                                       MaxSteps))
        continue;

      if (Use.getUser()->getOpcode() != ISD::ADD &&
          Use.getUser()->getOpcode() != ISD::SUB) {
        OtherUses.clear();
        break;
      }

      // The operand of the binary add/sub that is not BasePtr.
      SDValue Op1 = Use.getUser()->getOperand((Use.getOperandNo() + 1) & 1);
      if (!isa<ConstantSDNode>(Op1)) {
        OtherUses.clear();
        break;
      }

      // FIXME: In some cases, we can be smarter about this.
      if (Op1.getValueType() != Offset.getValueType()) {
        OtherUses.clear();
        break;
      }

      OtherUses.push_back(Use.getUser());
    }

  // Restore the target's operand order for building the indexed node.
  if (Swapped)
    std::swap(BasePtr, Offset);

  // Now check for #3 and #4.
  // RealUse becomes true only if some other user of Ptr cannot fold Ptr into
  // its own addressing mode.
  bool RealUse = false;

  for (SDNode *User : Ptr->users()) {
    if (User == N)
      continue;
    if (SDNode::hasPredecessorHelper(User, Visited, Worklist, MaxSteps))
      return false;

    // If Ptr may be folded in addressing mode of other use, then it's
    // not profitable to do this transformation.
    if (!canFoldInAddressingMode(Ptr.getNode(), User, DAG, TLI))
      RealUse = true;
  }

  if (!RealUse)
    return false;

  SDValue Result;
  if (!IsMasked) {
    if (IsLoad)
      Result = DAG.getIndexedLoad(SDValue(N, 0), SDLoc(N), BasePtr, Offset, AM);
    else
      Result =
          DAG.getIndexedStore(SDValue(N, 0), SDLoc(N), BasePtr, Offset, AM);
  } else {
    if (IsLoad)
      Result = DAG.getIndexedMaskedLoad(SDValue(N, 0), SDLoc(N), BasePtr,
                                        Offset, AM);
    else
      Result = DAG.getIndexedMaskedStore(SDValue(N, 0), SDLoc(N), BasePtr,
                                         Offset, AM);
  }
  ++PreIndexedNodes;
  ++NodesCombined;
  LLVM_DEBUG(dbgs() << "\nReplacing.4 "; N->dump(&DAG); dbgs() << "\nWith: ";
             Result.dump(&DAG); dbgs() << '\n');
  WorklistRemover DeadNodes(*this);
  // Result layout used below:
  //   loads:  value 0 = loaded value, 1 = updated pointer, 2 = chain;
  //   stores: value 0 = updated pointer, 1 = chain.
  if (IsLoad) {
    DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(0));
    DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Result.getValue(2));
  } else {
    DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(1));
  }

  // Finally, since the node is now dead, remove it from the graph.
  deleteAndRecombine(N);

  // Back to BasePtr = non-constant operand, Offset = constant, matching the
  // orientation OtherUses was collected with.
  if (Swapped)
    std::swap(BasePtr, Offset);

  // Replace other uses of BasePtr that can be updated to use Ptr
  for (SDNode *OtherUse : OtherUses) {
    unsigned OffsetIdx = 1;
    if (OtherUse->getOperand(OffsetIdx).getNode() == BasePtr.getNode())
      OffsetIdx = 0;
    assert(OtherUse->getOperand(!OffsetIdx).getNode() == BasePtr.getNode() &&
           "Expected BasePtr operand");

    // We need to replace ptr0 in the following expression:
    //   x0 * offset0 + y0 * ptr0 = t0
    // knowing that
    //   x1 * offset1 + y1 * ptr0 = t1 (the indexed load/store)
    //
    // where x0, x1, y0 and y1 in {-1, 1} are given by the types of the
    // indexed load/store and the expression that needs to be re-written.
    //
    // Therefore, we have:
    //   t0 = (x0 * offset0 - x1 * y0 * y1 *offset1) + (y0 * y1) * t1

    auto *CN = cast<ConstantSDNode>(OtherUse->getOperand(OffsetIdx));
    const APInt &Offset0 = CN->getAPIntValue();
    const APInt &Offset1 = Offset->getAsAPIntVal();
    int X0 = (OtherUse->getOpcode() == ISD::SUB && OffsetIdx == 1) ? -1 : 1;
    int Y0 = (OtherUse->getOpcode() == ISD::SUB && OffsetIdx == 0) ? -1 : 1;
    int X1 = (AM == ISD::PRE_DEC && !Swapped) ? -1 : 1;
    int Y1 = (AM == ISD::PRE_DEC && Swapped) ? -1 : 1;

    // The new node is (CNV op t1): SUB realizes the -(t1) term when y0*y1 < 0.
    unsigned Opcode = (Y0 * Y1 < 0) ? ISD::SUB : ISD::ADD;

    APInt CNV = Offset0;
    if (X0 < 0) CNV = -CNV;
    if (X1 * Y0 * Y1 < 0) CNV = CNV + Offset1;
    else CNV = CNV - Offset1;

    SDLoc DL(OtherUse);

    // We can now generate the new expression.
    SDValue NewOp1 = DAG.getConstant(CNV, DL, CN->getValueType(0));
    SDValue NewOp2 = Result.getValue(IsLoad ? 1 : 0);

    SDValue NewUse =
        DAG.getNode(Opcode, DL, OtherUse->getValueType(0), NewOp1, NewOp2);
    DAG.ReplaceAllUsesOfValueWith(SDValue(OtherUse, 0), NewUse);
    deleteAndRecombine(OtherUse);
  }

  // Replace the uses of Ptr with uses of the updated base value.
  DAG.ReplaceAllUsesOfValueWith(Ptr, Result.getValue(IsLoad ? 1 : 0));
  deleteAndRecombine(Ptr.getNode());
  AddToWorklist(Result.getNode());

  return true;
}
19663 
shouldCombineToPostInc(SDNode * N,SDValue Ptr,SDNode * PtrUse,SDValue & BasePtr,SDValue & Offset,ISD::MemIndexedMode & AM,SelectionDAG & DAG,const TargetLowering & TLI)19664 static bool shouldCombineToPostInc(SDNode *N, SDValue Ptr, SDNode *PtrUse,
19665                                    SDValue &BasePtr, SDValue &Offset,
19666                                    ISD::MemIndexedMode &AM,
19667                                    SelectionDAG &DAG,
19668                                    const TargetLowering &TLI) {
19669   if (PtrUse == N ||
19670       (PtrUse->getOpcode() != ISD::ADD && PtrUse->getOpcode() != ISD::SUB))
19671     return false;
19672 
19673   if (!TLI.getPostIndexedAddressParts(N, PtrUse, BasePtr, Offset, AM, DAG))
19674     return false;
19675 
19676   // Don't create a indexed load / store with zero offset.
19677   if (isNullConstant(Offset))
19678     return false;
19679 
19680   if (isa<FrameIndexSDNode>(BasePtr) || isa<RegisterSDNode>(BasePtr))
19681     return false;
19682 
19683   SmallPtrSet<const SDNode *, 32> Visited;
19684   unsigned MaxSteps = SelectionDAG::getHasPredecessorMaxSteps();
19685   for (SDNode *User : BasePtr->users()) {
19686     if (User == Ptr.getNode())
19687       continue;
19688 
19689     // No if there's a later user which could perform the index instead.
19690     if (isa<MemSDNode>(User)) {
19691       bool IsLoad = true;
19692       bool IsMasked = false;
19693       SDValue OtherPtr;
19694       if (getCombineLoadStoreParts(User, ISD::POST_INC, ISD::POST_DEC, IsLoad,
19695                                    IsMasked, OtherPtr, TLI)) {
19696         SmallVector<const SDNode *, 2> Worklist;
19697         Worklist.push_back(User);
19698         if (SDNode::hasPredecessorHelper(N, Visited, Worklist, MaxSteps))
19699           return false;
19700       }
19701     }
19702 
19703     // If all the uses are load / store addresses, then don't do the
19704     // transformation.
19705     if (User->getOpcode() == ISD::ADD || User->getOpcode() == ISD::SUB) {
19706       for (SDNode *UserUser : User->users())
19707         if (canFoldInAddressingMode(User, UserUser, DAG, TLI))
19708           return false;
19709     }
19710   }
19711   return true;
19712 }
19713 
/// Look for a user of \p N's address that can be folded into \p N to form a
/// post-indexed load/store.
///
/// On entry nothing is required of the reference parameters; on return
/// \p IsLoad, \p IsMasked and \p Ptr describe N as reported by
/// getCombineLoadStoreParts, and \p BasePtr, \p Offset and \p AM are whatever
/// shouldCombineToPostInc last wrote for the returned user (presumably via the
/// target's getPostIndexedAddressParts hook).
///
/// \returns the address user to fold, or nullptr if none qualifies.
static SDNode *getPostIndexedLoadStoreOp(SDNode *N, bool &IsLoad,
                                         bool &IsMasked, SDValue &Ptr,
                                         SDValue &BasePtr, SDValue &Offset,
                                         ISD::MemIndexedMode &AM,
                                         SelectionDAG &DAG,
                                         const TargetLowering &TLI) {
  // Give up if getCombineLoadStoreParts rejects N, or if N is the only user of
  // its address (then there is no other node to fold in).
  if (!getCombineLoadStoreParts(N, ISD::POST_INC, ISD::POST_DEC, IsLoad,
                                IsMasked, Ptr, TLI) ||
      Ptr->hasOneUse())
    return nullptr;

  // Try turning it into a post-indexed load / store except when
  // 1) All uses are load / store ops that use it as base ptr (and
  //    it may be folded as addressing mode).
  // 2) Op must be independent of N, i.e. Op is neither a predecessor
  //    nor a successor of N. Otherwise, if Op is folded that would
  //    create a cycle.
  unsigned MaxSteps = SelectionDAG::getHasPredecessorMaxSteps();
  for (SDNode *Op : Ptr->users()) {
    // Check for #1.
    if (!shouldCombineToPostInc(N, Ptr, Op, BasePtr, Offset, AM, DAG, TLI))
      continue;

    // Check for #2.
    SmallPtrSet<const SDNode *, 32> Visited;
    SmallVector<const SDNode *, 8> Worklist;
    // Ptr is predecessor to both N and Op.
    Visited.insert(Ptr.getNode());
    Worklist.push_back(N);
    Worklist.push_back(Op);
    // Both queries deliberately share Visited/Worklist. Presumably this lets
    // the second search reuse the first one's walk instead of starting over;
    // see SDNode::hasPredecessorHelper.
    if (!SDNode::hasPredecessorHelper(N, Visited, Worklist, MaxSteps) &&
        !SDNode::hasPredecessorHelper(Op, Visited, Worklist, MaxSteps))
      return Op;
  }
  return nullptr;
}
19750 
/// Try to combine a load/store with an add/sub of the base pointer node into a
/// post-indexed load/store. The transformation folds the add/subtract into the
/// new indexed load/store, and all uses of the add/sub are redirected to the
/// updated-pointer result of the new load/store.
///
/// Only runs once Level >= AfterLegalizeDAG. Returns true if \p N was
/// replaced.
bool DAGCombiner::CombineToPostIndexedLoadStore(SDNode *N) {
  if (Level < AfterLegalizeDAG)
    return false;

  bool IsLoad = true;
  bool IsMasked = false;
  SDValue Ptr;
  SDValue BasePtr;
  SDValue Offset;
  ISD::MemIndexedMode AM = ISD::UNINDEXED;
  // Op is the pointer add/sub that will be folded into the indexed access.
  SDNode *Op = getPostIndexedLoadStoreOp(N, IsLoad, IsMasked, Ptr, BasePtr,
                                         Offset, AM, DAG, TLI);
  if (!Op)
    return false;

  SDValue Result;
  if (!IsMasked)
    Result = IsLoad ? DAG.getIndexedLoad(SDValue(N, 0), SDLoc(N), BasePtr,
                                         Offset, AM)
                    : DAG.getIndexedStore(SDValue(N, 0), SDLoc(N),
                                          BasePtr, Offset, AM);
  else
    Result = IsLoad ? DAG.getIndexedMaskedLoad(SDValue(N, 0), SDLoc(N),
                                               BasePtr, Offset, AM)
                    : DAG.getIndexedMaskedStore(SDValue(N, 0), SDLoc(N),
                                                BasePtr, Offset, AM);
  ++PostIndexedNodes;
  ++NodesCombined;
  LLVM_DEBUG(dbgs() << "\nReplacing.5 "; N->dump(&DAG); dbgs() << "\nWith: ";
             Result.dump(&DAG); dbgs() << '\n');
  WorklistRemover DeadNodes(*this);
  // Result value layout, as used below:
  //   indexed load:  0 = loaded value, 1 = updated pointer, 2 = chain
  //   indexed store: 0 = updated pointer, 1 = chain
  if (IsLoad) {
    DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(0));
    DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Result.getValue(2));
  } else {
    DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(1));
  }

  // Finally, since the node is now dead, remove it from the graph.
  deleteAndRecombine(N);

  // Replace the uses of Op with uses of the updated base value.
  DAG.ReplaceAllUsesOfValueWith(SDValue(Op, 0),
                                Result.getValue(IsLoad ? 1 : 0));
  deleteAndRecombine(Op);
  return true;
}
19802 
19803 /// Return the base-pointer arithmetic from an indexed \p LD.
SplitIndexingFromLoad(LoadSDNode * LD)19804 SDValue DAGCombiner::SplitIndexingFromLoad(LoadSDNode *LD) {
19805   ISD::MemIndexedMode AM = LD->getAddressingMode();
19806   assert(AM != ISD::UNINDEXED);
19807   SDValue BP = LD->getOperand(1);
19808   SDValue Inc = LD->getOperand(2);
19809 
19810   // Some backends use TargetConstants for load offsets, but don't expect
19811   // TargetConstants in general ADD nodes. We can convert these constants into
19812   // regular Constants (if the constant is not opaque).
19813   assert((Inc.getOpcode() != ISD::TargetConstant ||
19814           !cast<ConstantSDNode>(Inc)->isOpaque()) &&
19815          "Cannot split out indexing using opaque target constants");
19816   if (Inc.getOpcode() == ISD::TargetConstant) {
19817     ConstantSDNode *ConstInc = cast<ConstantSDNode>(Inc);
19818     Inc = DAG.getConstant(*ConstInc->getConstantIntValue(), SDLoc(Inc),
19819                           ConstInc->getValueType(0));
19820   }
19821 
19822   unsigned Opc =
19823       (AM == ISD::PRE_INC || AM == ISD::POST_INC ? ISD::ADD : ISD::SUB);
19824   return DAG.getNode(Opc, SDLoc(LD), BP.getSimpleValueType(), BP, Inc);
19825 }
19826 
numVectorEltsOrZero(EVT T)19827 static inline ElementCount numVectorEltsOrZero(EVT T) {
19828   return T.isVector() ? T.getVectorElementCount() : ElementCount::getFixed(0);
19829 }
19830 
getTruncatedStoreValue(StoreSDNode * ST,SDValue & Val)19831 bool DAGCombiner::getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val) {
19832   EVT STType = Val.getValueType();
19833   EVT STMemType = ST->getMemoryVT();
19834   if (STType == STMemType)
19835     return true;
19836   if (isTypeLegal(STMemType))
19837     return false; // fail.
19838   if (STType.isFloatingPoint() && STMemType.isFloatingPoint() &&
19839       TLI.isOperationLegal(ISD::FTRUNC, STMemType)) {
19840     Val = DAG.getNode(ISD::FTRUNC, SDLoc(ST), STMemType, Val);
19841     return true;
19842   }
19843   if (numVectorEltsOrZero(STType) == numVectorEltsOrZero(STMemType) &&
19844       STType.isInteger() && STMemType.isInteger()) {
19845     Val = DAG.getNode(ISD::TRUNCATE, SDLoc(ST), STMemType, Val);
19846     return true;
19847   }
19848   if (STType.getSizeInBits() == STMemType.getSizeInBits()) {
19849     Val = DAG.getBitcast(STMemType, Val);
19850     return true;
19851   }
19852   return false; // fail.
19853 }
19854 
extendLoadedValueToExtension(LoadSDNode * LD,SDValue & Val)19855 bool DAGCombiner::extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val) {
19856   EVT LDMemType = LD->getMemoryVT();
19857   EVT LDType = LD->getValueType(0);
19858   assert(Val.getValueType() == LDMemType &&
19859          "Attempting to extend value of non-matching type");
19860   if (LDType == LDMemType)
19861     return true;
19862   if (LDMemType.isInteger() && LDType.isInteger()) {
19863     switch (LD->getExtensionType()) {
19864     case ISD::NON_EXTLOAD:
19865       Val = DAG.getBitcast(LDType, Val);
19866       return true;
19867     case ISD::EXTLOAD:
19868       Val = DAG.getNode(ISD::ANY_EXTEND, SDLoc(LD), LDType, Val);
19869       return true;
19870     case ISD::SEXTLOAD:
19871       Val = DAG.getNode(ISD::SIGN_EXTEND, SDLoc(LD), LDType, Val);
19872       return true;
19873     case ISD::ZEXTLOAD:
19874       Val = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(LD), LDType, Val);
19875       return true;
19876     }
19877   }
19878   return false;
19879 }
19880 
getUniqueStoreFeeding(LoadSDNode * LD,int64_t & Offset)19881 StoreSDNode *DAGCombiner::getUniqueStoreFeeding(LoadSDNode *LD,
19882                                                 int64_t &Offset) {
19883   SDValue Chain = LD->getOperand(0);
19884 
19885   // Look through CALLSEQ_START.
19886   if (Chain.getOpcode() == ISD::CALLSEQ_START)
19887     Chain = Chain->getOperand(0);
19888 
19889   StoreSDNode *ST = nullptr;
19890   SmallVector<SDValue, 8> Aliases;
19891   if (Chain.getOpcode() == ISD::TokenFactor) {
19892     // Look for unique store within the TokenFactor.
19893     for (SDValue Op : Chain->ops()) {
19894       StoreSDNode *Store = dyn_cast<StoreSDNode>(Op.getNode());
19895       if (!Store)
19896         continue;
19897       BaseIndexOffset BasePtrLD = BaseIndexOffset::match(LD, DAG);
19898       BaseIndexOffset BasePtrST = BaseIndexOffset::match(Store, DAG);
19899       if (!BasePtrST.equalBaseIndex(BasePtrLD, DAG, Offset))
19900         continue;
19901       // Make sure the store is not aliased with any nodes in TokenFactor.
19902       GatherAllAliases(Store, Chain, Aliases);
19903       if (Aliases.empty() ||
19904           (Aliases.size() == 1 && Aliases.front().getNode() == Store))
19905         ST = Store;
19906       break;
19907     }
19908   } else {
19909     StoreSDNode *Store = dyn_cast<StoreSDNode>(Chain.getNode());
19910     if (Store) {
19911       BaseIndexOffset BasePtrLD = BaseIndexOffset::match(LD, DAG);
19912       BaseIndexOffset BasePtrST = BaseIndexOffset::match(Store, DAG);
19913       if (BasePtrST.equalBaseIndex(BasePtrLD, DAG, Offset))
19914         ST = Store;
19915     }
19916   }
19917 
19918   return ST;
19919 }
19920 
/// Try to replace the load \p LD with the value of the store that feeds it
/// (found by getUniqueStoreFeeding).
///
/// This is only attempted when optimizing, when both accesses are simple and
/// in the same address space, when both agree on scalability, and when the
/// stored bytes cover every loaded byte. It returns the result of CombineTo on
/// success, or an empty SDValue otherwise.
SDValue DAGCombiner::ForwardStoreValueToDirectLoad(LoadSDNode *LD) {
  if (OptLevel == CodeGenOptLevel::None || !LD->isSimple())
    return SDValue();
  SDValue Chain = LD->getOperand(0);
  int64_t Offset;

  StoreSDNode *ST = getUniqueStoreFeeding(LD, Offset);
  // TODO: Relax this restriction for unordered atomics (see D66309)
  if (!ST || !ST->isSimple() || ST->getAddressSpace() != LD->getAddressSpace())
    return SDValue();

  EVT LDType = LD->getValueType(0);
  EVT LDMemType = LD->getMemoryVT();
  EVT STMemType = ST->getMemoryVT();
  EVT STType = ST->getValue().getValueType();

  // There are two cases to consider here:
  //  1. The store is fixed width and the load is scalable. In this case we
  //     don't know at compile time if the store completely envelops the load
  //     so we abandon the optimisation.
  //  2. The store is scalable and the load is fixed width. We could
  //     potentially support a limited number of cases here, but there has been
  //     no cost-benefit analysis to prove it's worth it.
  bool LdStScalable = LDMemType.isScalableVT();
  if (LdStScalable != STMemType.isScalableVT())
    return SDValue();

  // If we are dealing with scalable vectors on a big endian platform the
  // calculation of offsets below becomes trickier, since we do not know at
  // compile time the absolute size of the vector. Until we've done more
  // analysis on big-endian platforms it seems better to bail out for now.
  if (LdStScalable && DAG.getDataLayout().isBigEndian())
    return SDValue();

  // Normalize for Endianness. After this, Offset=0 will denote that the least
  // significant bit in the loaded value maps to the least significant bit in
  // the stored value. With Offset=n (for n > 0) the loaded value starts at the
  // n:th least significant byte of the stored value.
  int64_t OrigOffset = Offset;
  if (DAG.getDataLayout().isBigEndian())
    Offset = ((int64_t)STMemType.getStoreSizeInBits().getFixedValue() -
              (int64_t)LDMemType.getStoreSizeInBits().getFixedValue()) /
                 8 -
             Offset;

  // Check that the stored value covers all bits that are loaded.
  bool STCoversLD;

  TypeSize LdMemSize = LDMemType.getSizeInBits();
  TypeSize StMemSize = STMemType.getSizeInBits();
  if (LdStScalable)
    STCoversLD = (Offset == 0) && LdMemSize == StMemSize;
  else
    STCoversLD = (Offset >= 0) && (Offset * 8 + LdMemSize.getFixedValue() <=
                                   StMemSize.getFixedValue());

  // Replace LD's value with Val and its chain with Chain. For an indexed load,
  // the updated-pointer result is rebuilt as explicit arithmetic.
  auto ReplaceLd = [&](LoadSDNode *LD, SDValue Val, SDValue Chain) -> SDValue {
    if (LD->isIndexed()) {
      // Cannot handle opaque target constants and we must respect the user's
      // request not to split indexes from loads.
      if (!canSplitIdx(LD))
        return SDValue();
      SDValue Idx = SplitIndexingFromLoad(LD);
      SDValue Ops[] = {Val, Idx, Chain};
      return CombineTo(LD, Ops, 3);
    }
    return CombineTo(LD, Val, Chain);
  };

  if (!STCoversLD)
    return SDValue();

  // Memory as copy space (potentially masked).
  if (Offset == 0 && LDType == STType && STMemType == LDMemType) {
    // Simple case: Direct non-truncating forwarding
    if (LDType.getSizeInBits() == LdMemSize)
      return ReplaceLd(LD, ST->getValue(), Chain);
    // Can we model the truncate and extension with an and mask?
    if (STType.isInteger() && LDMemType.isInteger() && !STType.isVector() &&
        !LDMemType.isVector() && LD->getExtensionType() != ISD::SEXTLOAD) {
      // Mask to size of LDMemType (StMemSize equals LdMemSize here, since
      // STMemType == LDMemType).
      auto Mask =
          DAG.getConstant(APInt::getLowBitsSet(STType.getFixedSizeInBits(),
                                               StMemSize.getFixedValue()),
                          SDLoc(ST), STType);
      auto Val = DAG.getNode(ISD::AND, SDLoc(LD), LDType, ST->getValue(), Mask);
      return ReplaceLd(LD, Val, Chain);
    }
  }

  // Handle some cases for big-endian that would be Offset 0 and handled for
  // little-endian: a narrower load from the store's own address
  // (OrigOffset == 0) got a positive normalized Offset above. Shift the stored
  // value right so it can be treated as Offset 0.
  SDValue Val = ST->getValue();
  if (DAG.getDataLayout().isBigEndian() && Offset > 0 && OrigOffset == 0) {
    if (STType.isInteger() && !STType.isVector() && LDType.isInteger() &&
        !LDType.isVector() && isTypeLegal(STType) &&
        TLI.isOperationLegal(ISD::SRL, STType)) {
      Val = DAG.getNode(ISD::SRL, SDLoc(LD), STType, Val,
                        DAG.getConstant(Offset * 8, SDLoc(LD), STType));
      Offset = 0;
    }
  }

  // TODO: Deal with nonzero offset.
  if (LD->getBasePtr().isUndef() || Offset != 0)
    return SDValue();
  // Model necessary truncations / extensions.
  // Truncate Value To Stored Memory Size.
  // do/while(false) lets any failing step `break` to the shared cleanup below.
  do {
    if (!getTruncatedStoreValue(ST, Val))
      break;
    if (!isTypeLegal(LDMemType))
      break;
    if (STMemType != LDMemType) {
      // TODO: Support vectors? This requires extract_subvector/bitcast.
      if (!STMemType.isVector() && !LDMemType.isVector() &&
          STMemType.isInteger() && LDMemType.isInteger())
        Val = DAG.getNode(ISD::TRUNCATE, SDLoc(LD), LDMemType, Val);
      else
        break;
    }
    if (!extendLoadedValueToExtension(LD, Val))
      break;
    return ReplaceLd(LD, Val, Chain);
  } while (false);

  // On failure, cleanup dead nodes we may have created.
  if (Val->use_empty())
    deleteAndRecombine(Val.getNode());
  return SDValue();
}
20052 
/// DAG combine for LOAD nodes. It tries the following, in order:
///  - delete a simple load whose results are unused;
///  - forward the value of a feeding store (ForwardStoreValueToDirectLoad);
///  - raise the load's alignment from DAG.InferPtrAlign;
///  - rechain an unindexed load onto FindBetterChain's result;
///  - form a pre- or post-indexed load;
///  - slice the load into narrower loads (SliceUpLoad).
SDValue DAGCombiner::visitLOAD(SDNode *N) {
  LoadSDNode *LD  = cast<LoadSDNode>(N);
  SDValue Chain = LD->getChain();
  SDValue Ptr   = LD->getBasePtr();

  // If the load is simple (LD->isSimple()) and there are no uses of the loaded
  // value (and the updated indexed value in case of indexed loads), change
  // uses of the chain value into uses of the chain input (i.e. delete the
  // dead load).
  // TODO: Allow this for unordered atomics (see D66309)
  if (LD->isSimple()) {
    if (N->getValueType(1) == MVT::Other) {
      // Unindexed loads.
      if (!N->hasAnyUseOfValue(0)) {
        // It's not safe to use the two value CombineTo variant here. e.g.
        // v1, chain2 = load chain1, loc
        // v2, chain3 = load chain2, loc
        // v3         = add v2, c
        // Now we replace use of chain2 with chain1.  This makes the second load
        // isomorphic to the one we are deleting, and thus makes this load live.
        LLVM_DEBUG(dbgs() << "\nReplacing.6 "; N->dump(&DAG);
                   dbgs() << "\nWith chain: "; Chain.dump(&DAG);
                   dbgs() << "\n");
        WorklistRemover DeadNodes(*this);
        DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Chain);
        AddUsersToWorklist(Chain.getNode());
        if (N->use_empty())
          deleteAndRecombine(N);

        return SDValue(N, 0);   // Return N so it doesn't get rechecked!
      }
    } else {
      // Indexed loads.
      assert(N->getValueType(2) == MVT::Other && "Malformed indexed loads?");

      // If this load has an opaque TargetConstant offset, then we cannot split
      // the indexing into an add/sub directly (that TargetConstant may not be
      // valid for a different type of node, and we cannot convert an opaque
      // target constant into a regular constant).
      bool CanSplitIdx = canSplitIdx(LD);

      // The loaded value is unused, and the updated pointer is either unused or
      // can be rebuilt as explicit arithmetic: replace value 0 with undef,
      // value 1 with the split index (or undef), and value 2 with the input
      // chain.
      if (!N->hasAnyUseOfValue(0) && (CanSplitIdx || !N->hasAnyUseOfValue(1))) {
        SDValue Undef = DAG.getUNDEF(N->getValueType(0));
        SDValue Index;
        if (N->hasAnyUseOfValue(1) && CanSplitIdx) {
          Index = SplitIndexingFromLoad(LD);
          // Try to fold the base pointer arithmetic into subsequent loads and
          // stores.
          AddUsersToWorklist(N);
        } else
          Index = DAG.getUNDEF(N->getValueType(1));
        LLVM_DEBUG(dbgs() << "\nReplacing.7 "; N->dump(&DAG);
                   dbgs() << "\nWith: "; Undef.dump(&DAG);
                   dbgs() << " and 2 other values\n");
        WorklistRemover DeadNodes(*this);
        DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Undef);
        DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Index);
        DAG.ReplaceAllUsesOfValueWith(SDValue(N, 2), Chain);
        deleteAndRecombine(N);
        return SDValue(N, 0);   // Return N so it doesn't get rechecked!
      }
    }
  }

  // If this load is directly stored, replace the load value with the stored
  // value.
  if (auto V = ForwardStoreValueToDirectLoad(LD))
    return V;

  // Try to infer better alignment information than the load already has.
  if (OptLevel != CodeGenOptLevel::None && LD->isUnindexed() &&
      !LD->isAtomic()) {
    if (MaybeAlign Alignment = DAG.InferPtrAlign(Ptr)) {
      if (*Alignment > LD->getAlign() &&
          isAligned(*Alignment, LD->getSrcValueOffset())) {
        SDValue NewLoad = DAG.getExtLoad(
            LD->getExtensionType(), SDLoc(N), LD->getValueType(0), Chain, Ptr,
            LD->getPointerInfo(), LD->getMemoryVT(), *Alignment,
            LD->getMemOperand()->getFlags(), LD->getAAInfo());
        // NewLoad will always be N as we are only refining the alignment
        assert(NewLoad.getNode() == N);
        (void)NewLoad;
      }
    }
  }

  if (LD->isUnindexed()) {
    // Walk up chain skipping non-aliasing memory nodes.
    SDValue BetterChain = FindBetterChain(LD, Chain);

    // If there is a better chain.
    if (Chain != BetterChain) {
      SDValue ReplLoad;

      // Replace the chain to void dependency.
      if (LD->getExtensionType() == ISD::NON_EXTLOAD) {
        ReplLoad = DAG.getLoad(N->getValueType(0), SDLoc(LD),
                               BetterChain, Ptr, LD->getMemOperand());
      } else {
        ReplLoad = DAG.getExtLoad(LD->getExtensionType(), SDLoc(LD),
                                  LD->getValueType(0),
                                  BetterChain, Ptr, LD->getMemoryVT(),
                                  LD->getMemOperand());
      }

      // Create token factor to keep old chain connected.
      SDValue Token = DAG.getNode(ISD::TokenFactor, SDLoc(N),
                                  MVT::Other, Chain, ReplLoad.getValue(1));

      // Replace uses with load result and token factor
      return CombineTo(N, ReplLoad.getValue(0), Token);
    }
  }

  // Try transforming N to an indexed load.
  if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
    return SDValue(N, 0);

  // Try to slice up N to more direct loads if the slices are mapped to
  // different register banks or pairing can take place.
  if (SliceUpLoad(N))
    return SDValue(N, 0);

  return SDValue();
}
20177 
20178 namespace {
20179 
/// Helper structure used to slice a load into smaller loads.
20181 /// Basically a slice is obtained from the following sequence:
20182 /// Origin = load Ty1, Base
20183 /// Shift = srl Ty1 Origin, CstTy Amount
20184 /// Inst = trunc Shift to Ty2
20185 ///
20186 /// Then, it will be rewritten into:
20187 /// Slice = load SliceTy, Base + SliceOffset
20188 /// [Inst = zext Slice to Ty2], only if SliceTy <> Ty2
20189 ///
20190 /// SliceTy is deduced from the number of bits that are actually used to
20191 /// build Inst.
20192 struct LoadedSlice {
20193   /// Helper structure used to compute the cost of a slice.
20194   struct Cost {
20195     /// Are we optimizing for code size.
20196     bool ForCodeSize = false;
20197 
20198     /// Various cost.
20199     unsigned Loads = 0;
20200     unsigned Truncates = 0;
20201     unsigned CrossRegisterBanksCopies = 0;
20202     unsigned ZExts = 0;
20203     unsigned Shift = 0;
20204 
Cost__anon666e37104411::LoadedSlice::Cost20205     explicit Cost(bool ForCodeSize) : ForCodeSize(ForCodeSize) {}
20206 
20207     /// Get the cost of one isolated slice.
Cost__anon666e37104411::LoadedSlice::Cost20208     Cost(const LoadedSlice &LS, bool ForCodeSize)
20209         : ForCodeSize(ForCodeSize), Loads(1) {
20210       EVT TruncType = LS.Inst->getValueType(0);
20211       EVT LoadedType = LS.getLoadedType();
20212       if (TruncType != LoadedType &&
20213           !LS.DAG->getTargetLoweringInfo().isZExtFree(LoadedType, TruncType))
20214         ZExts = 1;
20215     }
20216 
20217     /// Account for slicing gain in the current cost.
20218     /// Slicing provide a few gains like removing a shift or a
20219     /// truncate. This method allows to grow the cost of the original
20220     /// load with the gain from this slice.
addSliceGain__anon666e37104411::LoadedSlice::Cost20221     void addSliceGain(const LoadedSlice &LS) {
20222       // Each slice saves a truncate.
20223       const TargetLowering &TLI = LS.DAG->getTargetLoweringInfo();
20224       if (!TLI.isTruncateFree(LS.Inst->getOperand(0), LS.Inst->getValueType(0)))
20225         ++Truncates;
20226       // If there is a shift amount, this slice gets rid of it.
20227       if (LS.Shift)
20228         ++Shift;
20229       // If this slice can merge a cross register bank copy, account for it.
20230       if (LS.canMergeExpensiveCrossRegisterBankCopy())
20231         ++CrossRegisterBanksCopies;
20232     }
20233 
operator +=__anon666e37104411::LoadedSlice::Cost20234     Cost &operator+=(const Cost &RHS) {
20235       Loads += RHS.Loads;
20236       Truncates += RHS.Truncates;
20237       CrossRegisterBanksCopies += RHS.CrossRegisterBanksCopies;
20238       ZExts += RHS.ZExts;
20239       Shift += RHS.Shift;
20240       return *this;
20241     }
20242 
operator ==__anon666e37104411::LoadedSlice::Cost20243     bool operator==(const Cost &RHS) const {
20244       return Loads == RHS.Loads && Truncates == RHS.Truncates &&
20245              CrossRegisterBanksCopies == RHS.CrossRegisterBanksCopies &&
20246              ZExts == RHS.ZExts && Shift == RHS.Shift;
20247     }
20248 
operator !=__anon666e37104411::LoadedSlice::Cost20249     bool operator!=(const Cost &RHS) const { return !(*this == RHS); }
20250 
operator <__anon666e37104411::LoadedSlice::Cost20251     bool operator<(const Cost &RHS) const {
20252       // Assume cross register banks copies are as expensive as loads.
20253       // FIXME: Do we want some more target hooks?
20254       unsigned ExpensiveOpsLHS = Loads + CrossRegisterBanksCopies;
20255       unsigned ExpensiveOpsRHS = RHS.Loads + RHS.CrossRegisterBanksCopies;
20256       // Unless we are optimizing for code size, consider the
20257       // expensive operation first.
20258       if (!ForCodeSize && ExpensiveOpsLHS != ExpensiveOpsRHS)
20259         return ExpensiveOpsLHS < ExpensiveOpsRHS;
20260       return (Truncates + ZExts + Shift + ExpensiveOpsLHS) <
20261              (RHS.Truncates + RHS.ZExts + RHS.Shift + ExpensiveOpsRHS);
20262     }
20263 
operator >__anon666e37104411::LoadedSlice::Cost20264     bool operator>(const Cost &RHS) const { return RHS < *this; }
20265 
operator <=__anon666e37104411::LoadedSlice::Cost20266     bool operator<=(const Cost &RHS) const { return !(RHS < *this); }
20267 
operator >=__anon666e37104411::LoadedSlice::Cost20268     bool operator>=(const Cost &RHS) const { return !(*this < RHS); }
20269   };
20270 
20271   // The last instruction that represent the slice. This should be a
20272   // truncate instruction.
20273   SDNode *Inst;
20274 
20275   // The original load instruction.
20276   LoadSDNode *Origin;
20277 
20278   // The right shift amount in bits from the original load.
20279   unsigned Shift;
20280 
20281   // The DAG from which Origin came from.
20282   // This is used to get some contextual information about legal types, etc.
20283   SelectionDAG *DAG;
20284 
LoadedSlice__anon666e37104411::LoadedSlice20285   LoadedSlice(SDNode *Inst = nullptr, LoadSDNode *Origin = nullptr,
20286               unsigned Shift = 0, SelectionDAG *DAG = nullptr)
20287       : Inst(Inst), Origin(Origin), Shift(Shift), DAG(DAG) {}
20288 
20289   /// Get the bits used in a chunk of bits \p BitWidth large.
20290   /// \return Result is \p BitWidth and has used bits set to 1 and
20291   ///         not used bits set to 0.
getUsedBits__anon666e37104411::LoadedSlice20292   APInt getUsedBits() const {
20293     // Reproduce the trunc(lshr) sequence:
20294     // - Start from the truncated value.
20295     // - Zero extend to the desired bit width.
20296     // - Shift left.
20297     assert(Origin && "No original load to compare against.");
20298     unsigned BitWidth = Origin->getValueSizeInBits(0);
20299     assert(Inst && "This slice is not bound to an instruction");
20300     assert(Inst->getValueSizeInBits(0) <= BitWidth &&
20301            "Extracted slice is bigger than the whole type!");
20302     APInt UsedBits(Inst->getValueSizeInBits(0), 0);
20303     UsedBits.setAllBits();
20304     UsedBits = UsedBits.zext(BitWidth);
20305     UsedBits <<= Shift;
20306     return UsedBits;
20307   }
20308 
20309   /// Get the size of the slice to be loaded in bytes.
getLoadedSize__anon666e37104411::LoadedSlice20310   unsigned getLoadedSize() const {
20311     unsigned SliceSize = getUsedBits().popcount();
20312     assert(!(SliceSize & 0x7) && "Size is not a multiple of a byte.");
20313     return SliceSize / 8;
20314   }
20315 
20316   /// Get the type that will be loaded for this slice.
20317   /// Note: This may not be the final type for the slice.
getLoadedType__anon666e37104411::LoadedSlice20318   EVT getLoadedType() const {
20319     assert(DAG && "Missing context");
20320     LLVMContext &Ctxt = *DAG->getContext();
20321     return EVT::getIntegerVT(Ctxt, getLoadedSize() * 8);
20322   }
20323 
20324   /// Get the alignment of the load used for this slice.
getAlign__anon666e37104411::LoadedSlice20325   Align getAlign() const {
20326     Align Alignment = Origin->getAlign();
20327     uint64_t Offset = getOffsetFromBase();
20328     if (Offset != 0)
20329       Alignment = commonAlignment(Alignment, Alignment.value() + Offset);
20330     return Alignment;
20331   }
20332 
20333   /// Check if this slice can be rewritten with legal operations.
isLegal__anon666e37104411::LoadedSlice20334   bool isLegal() const {
20335     // An invalid slice is not legal.
20336     if (!Origin || !Inst || !DAG)
20337       return false;
20338 
20339     // Offsets are for indexed load only, we do not handle that.
20340     if (!Origin->getOffset().isUndef())
20341       return false;
20342 
20343     const TargetLowering &TLI = DAG->getTargetLoweringInfo();
20344 
20345     // Check that the type is legal.
20346     EVT SliceType = getLoadedType();
20347     if (!TLI.isTypeLegal(SliceType))
20348       return false;
20349 
20350     // Check that the load is legal for this type.
20351     if (!TLI.isOperationLegal(ISD::LOAD, SliceType))
20352       return false;
20353 
20354     // Check that the offset can be computed.
20355     // 1. Check its type.
20356     EVT PtrType = Origin->getBasePtr().getValueType();
20357     if (PtrType == MVT::Untyped || PtrType.isExtended())
20358       return false;
20359 
20360     // 2. Check that it fits in the immediate.
20361     if (!TLI.isLegalAddImmediate(getOffsetFromBase()))
20362       return false;
20363 
20364     // 3. Check that the computation is legal.
20365     if (!TLI.isOperationLegal(ISD::ADD, PtrType))
20366       return false;
20367 
20368     // Check that the zext is legal if it needs one.
20369     EVT TruncateType = Inst->getValueType(0);
20370     if (TruncateType != SliceType &&
20371         !TLI.isOperationLegal(ISD::ZERO_EXTEND, TruncateType))
20372       return false;
20373 
20374     return true;
20375   }
20376 
20377   /// Get the offset in bytes of this slice in the original chunk of
20378   /// bits.
20379   /// \pre DAG != nullptr.
getOffsetFromBase__anon666e37104411::LoadedSlice20380   uint64_t getOffsetFromBase() const {
20381     assert(DAG && "Missing context.");
20382     bool IsBigEndian = DAG->getDataLayout().isBigEndian();
20383     assert(!(Shift & 0x7) && "Shifts not aligned on Bytes are not supported.");
20384     uint64_t Offset = Shift / 8;
20385     unsigned TySizeInBytes = Origin->getValueSizeInBits(0) / 8;
20386     assert(!(Origin->getValueSizeInBits(0) & 0x7) &&
20387            "The size of the original loaded type is not a multiple of a"
20388            " byte.");
20389     // If Offset is bigger than TySizeInBytes, it means we are loading all
20390     // zeros. This should have been optimized before in the process.
20391     assert(TySizeInBytes > Offset &&
20392            "Invalid shift amount for given loaded size");
20393     if (IsBigEndian)
20394       Offset = TySizeInBytes - Offset - getLoadedSize();
20395     return Offset;
20396   }
20397 
20398   /// Generate the sequence of instructions to load the slice
20399   /// represented by this object and redirect the uses of this slice to
20400   /// this new sequence of instructions.
20401   /// \pre this->Inst && this->Origin are valid Instructions and this
20402   /// object passed the legal check: LoadedSlice::isLegal returned true.
20403   /// \return The last instruction of the sequence used to load the slice.
loadSlice__anon666e37104411::LoadedSlice20404   SDValue loadSlice() const {
20405     assert(Inst && Origin && "Unable to replace a non-existing slice.");
20406     const SDValue &OldBaseAddr = Origin->getBasePtr();
20407     SDValue BaseAddr = OldBaseAddr;
20408     // Get the offset in that chunk of bytes w.r.t. the endianness.
20409     int64_t Offset = static_cast<int64_t>(getOffsetFromBase());
20410     assert(Offset >= 0 && "Offset too big to fit in int64_t!");
20411     if (Offset) {
20412       // BaseAddr = BaseAddr + Offset.
20413       EVT ArithType = BaseAddr.getValueType();
20414       SDLoc DL(Origin);
20415       BaseAddr = DAG->getNode(ISD::ADD, DL, ArithType, BaseAddr,
20416                               DAG->getConstant(Offset, DL, ArithType));
20417     }
20418 
20419     // Create the type of the loaded slice according to its size.
20420     EVT SliceType = getLoadedType();
20421 
20422     // Create the load for the slice.
20423     SDValue LastInst =
20424         DAG->getLoad(SliceType, SDLoc(Origin), Origin->getChain(), BaseAddr,
20425                      Origin->getPointerInfo().getWithOffset(Offset), getAlign(),
20426                      Origin->getMemOperand()->getFlags());
20427     // If the final type is not the same as the loaded type, this means that
20428     // we have to pad with zero. Create a zero extend for that.
20429     EVT FinalType = Inst->getValueType(0);
20430     if (SliceType != FinalType)
20431       LastInst =
20432           DAG->getNode(ISD::ZERO_EXTEND, SDLoc(LastInst), FinalType, LastInst);
20433     return LastInst;
20434   }
20435 
20436   /// Check if this slice can be merged with an expensive cross register
20437   /// bank copy. E.g.,
20438   /// i = load i32
20439   /// f = bitcast i32 i to float
canMergeExpensiveCrossRegisterBankCopy__anon666e37104411::LoadedSlice20440   bool canMergeExpensiveCrossRegisterBankCopy() const {
20441     if (!Inst || !Inst->hasOneUse())
20442       return false;
20443     SDNode *User = *Inst->user_begin();
20444     if (User->getOpcode() != ISD::BITCAST)
20445       return false;
20446     assert(DAG && "Missing context");
20447     const TargetLowering &TLI = DAG->getTargetLoweringInfo();
20448     EVT ResVT = User->getValueType(0);
20449     const TargetRegisterClass *ResRC =
20450         TLI.getRegClassFor(ResVT.getSimpleVT(), User->isDivergent());
20451     const TargetRegisterClass *ArgRC =
20452         TLI.getRegClassFor(User->getOperand(0).getValueType().getSimpleVT(),
20453                            User->getOperand(0)->isDivergent());
20454     if (ArgRC == ResRC || !TLI.isOperationLegal(ISD::LOAD, ResVT))
20455       return false;
20456 
20457     // At this point, we know that we perform a cross-register-bank copy.
20458     // Check if it is expensive.
20459     const TargetRegisterInfo *TRI = DAG->getSubtarget().getRegisterInfo();
20460     // Assume bitcasts are cheap, unless both register classes do not
20461     // explicitly share a common sub class.
20462     if (!TRI || TRI->getCommonSubClass(ArgRC, ResRC))
20463       return false;
20464 
20465     // Check if it will be merged with the load.
20466     // 1. Check the alignment / fast memory access constraint.
20467     unsigned IsFast = 0;
20468     if (!TLI.allowsMemoryAccess(*DAG->getContext(), DAG->getDataLayout(), ResVT,
20469                                 Origin->getAddressSpace(), getAlign(),
20470                                 Origin->getMemOperand()->getFlags(), &IsFast) ||
20471         !IsFast)
20472       return false;
20473 
20474     // 2. Check that the load is a legal operation for that type.
20475     if (!TLI.isOperationLegal(ISD::LOAD, ResVT))
20476       return false;
20477 
20478     // 3. Check that we do not have a zext in the way.
20479     if (Inst->getValueType(0) != getLoadedType())
20480       return false;
20481 
20482     return true;
20483   }
20484 };
20485 
20486 } // end anonymous namespace
20487 
20488 /// Check that all bits set in \p UsedBits form a dense region, i.e.,
20489 /// \p UsedBits looks like 0..0 1..1 0..0.
areUsedBitsDense(const APInt & UsedBits)20490 static bool areUsedBitsDense(const APInt &UsedBits) {
20491   // If all the bits are one, this is dense!
20492   if (UsedBits.isAllOnes())
20493     return true;
20494 
20495   // Get rid of the unused bits on the right.
20496   APInt NarrowedUsedBits = UsedBits.lshr(UsedBits.countr_zero());
20497   // Get rid of the unused bits on the left.
20498   if (NarrowedUsedBits.countl_zero())
20499     NarrowedUsedBits = NarrowedUsedBits.trunc(NarrowedUsedBits.getActiveBits());
20500   // Check that the chunk of bits is completely used.
20501   return NarrowedUsedBits.isAllOnes();
20502 }
20503 
20504 /// Check whether or not \p First and \p Second are next to each other
20505 /// in memory. This means that there is no hole between the bits loaded
20506 /// by \p First and the bits loaded by \p Second.
areSlicesNextToEachOther(const LoadedSlice & First,const LoadedSlice & Second)20507 static bool areSlicesNextToEachOther(const LoadedSlice &First,
20508                                      const LoadedSlice &Second) {
20509   assert(First.Origin == Second.Origin && First.Origin &&
20510          "Unable to match different memory origins.");
20511   APInt UsedBits = First.getUsedBits();
20512   assert((UsedBits & Second.getUsedBits()) == 0 &&
20513          "Slices are not supposed to overlap.");
20514   UsedBits |= Second.getUsedBits();
20515   return areUsedBitsDense(UsedBits);
20516 }
20517 
20518 /// Adjust the \p GlobalLSCost according to the target
20519 /// paring capabilities and the layout of the slices.
20520 /// \pre \p GlobalLSCost should account for at least as many loads as
20521 /// there is in the slices in \p LoadedSlices.
adjustCostForPairing(SmallVectorImpl<LoadedSlice> & LoadedSlices,LoadedSlice::Cost & GlobalLSCost)20522 static void adjustCostForPairing(SmallVectorImpl<LoadedSlice> &LoadedSlices,
20523                                  LoadedSlice::Cost &GlobalLSCost) {
20524   unsigned NumberOfSlices = LoadedSlices.size();
20525   // If there is less than 2 elements, no pairing is possible.
20526   if (NumberOfSlices < 2)
20527     return;
20528 
20529   // Sort the slices so that elements that are likely to be next to each
20530   // other in memory are next to each other in the list.
20531   llvm::sort(LoadedSlices, [](const LoadedSlice &LHS, const LoadedSlice &RHS) {
20532     assert(LHS.Origin == RHS.Origin && "Different bases not implemented.");
20533     return LHS.getOffsetFromBase() < RHS.getOffsetFromBase();
20534   });
20535   const TargetLowering &TLI = LoadedSlices[0].DAG->getTargetLoweringInfo();
20536   // First (resp. Second) is the first (resp. Second) potentially candidate
20537   // to be placed in a paired load.
20538   const LoadedSlice *First = nullptr;
20539   const LoadedSlice *Second = nullptr;
20540   for (unsigned CurrSlice = 0; CurrSlice < NumberOfSlices; ++CurrSlice,
20541                 // Set the beginning of the pair.
20542                                                            First = Second) {
20543     Second = &LoadedSlices[CurrSlice];
20544 
20545     // If First is NULL, it means we start a new pair.
20546     // Get to the next slice.
20547     if (!First)
20548       continue;
20549 
20550     EVT LoadedType = First->getLoadedType();
20551 
20552     // If the types of the slices are different, we cannot pair them.
20553     if (LoadedType != Second->getLoadedType())
20554       continue;
20555 
20556     // Check if the target supplies paired loads for this type.
20557     Align RequiredAlignment;
20558     if (!TLI.hasPairedLoad(LoadedType, RequiredAlignment)) {
20559       // move to the next pair, this type is hopeless.
20560       Second = nullptr;
20561       continue;
20562     }
20563     // Check if we meet the alignment requirement.
20564     if (First->getAlign() < RequiredAlignment)
20565       continue;
20566 
20567     // Check that both loads are next to each other in memory.
20568     if (!areSlicesNextToEachOther(*First, *Second))
20569       continue;
20570 
20571     assert(GlobalLSCost.Loads > 0 && "We save more loads than we created!");
20572     --GlobalLSCost.Loads;
20573     // Move to the next pair.
20574     Second = nullptr;
20575   }
20576 }
20577 
20578 /// Check the profitability of all involved LoadedSlice.
20579 /// Currently, it is considered profitable if there is exactly two
20580 /// involved slices (1) which are (2) next to each other in memory, and
20581 /// whose cost (\see LoadedSlice::Cost) is smaller than the original load (3).
20582 ///
20583 /// Note: The order of the elements in \p LoadedSlices may be modified, but not
20584 /// the elements themselves.
20585 ///
20586 /// FIXME: When the cost model will be mature enough, we can relax
20587 /// constraints (1) and (2).
isSlicingProfitable(SmallVectorImpl<LoadedSlice> & LoadedSlices,const APInt & UsedBits,bool ForCodeSize)20588 static bool isSlicingProfitable(SmallVectorImpl<LoadedSlice> &LoadedSlices,
20589                                 const APInt &UsedBits, bool ForCodeSize) {
20590   unsigned NumberOfSlices = LoadedSlices.size();
20591   if (StressLoadSlicing)
20592     return NumberOfSlices > 1;
20593 
20594   // Check (1).
20595   if (NumberOfSlices != 2)
20596     return false;
20597 
20598   // Check (2).
20599   if (!areUsedBitsDense(UsedBits))
20600     return false;
20601 
20602   // Check (3).
20603   LoadedSlice::Cost OrigCost(ForCodeSize), GlobalSlicingCost(ForCodeSize);
20604   // The original code has one big load.
20605   OrigCost.Loads = 1;
20606   for (unsigned CurrSlice = 0; CurrSlice < NumberOfSlices; ++CurrSlice) {
20607     const LoadedSlice &LS = LoadedSlices[CurrSlice];
20608     // Accumulate the cost of all the slices.
20609     LoadedSlice::Cost SliceCost(LS, ForCodeSize);
20610     GlobalSlicingCost += SliceCost;
20611 
20612     // Account as cost in the original configuration the gain obtained
20613     // with the current slices.
20614     OrigCost.addSliceGain(LS);
20615   }
20616 
20617   // If the target supports paired load, adjust the cost accordingly.
20618   adjustCostForPairing(LoadedSlices, GlobalSlicingCost);
20619   return OrigCost > GlobalSlicingCost;
20620 }
20621 
20622 /// If the given load, \p LI, is used only by trunc or trunc(lshr)
20623 /// operations, split it in the various pieces being extracted.
20624 ///
20625 /// This sort of thing is introduced by SROA.
20626 /// This slicing takes care not to insert overlapping loads.
20627 /// \pre LI is a simple load (i.e., not an atomic or volatile load).
SliceUpLoad(SDNode * N)20628 bool DAGCombiner::SliceUpLoad(SDNode *N) {
20629   if (Level < AfterLegalizeDAG)
20630     return false;
20631 
20632   LoadSDNode *LD = cast<LoadSDNode>(N);
20633   if (!LD->isSimple() || !ISD::isNormalLoad(LD) ||
20634       !LD->getValueType(0).isInteger())
20635     return false;
20636 
20637   // The algorithm to split up a load of a scalable vector into individual
20638   // elements currently requires knowing the length of the loaded type,
20639   // so will need adjusting to work on scalable vectors.
20640   if (LD->getValueType(0).isScalableVector())
20641     return false;
20642 
20643   // Keep track of already used bits to detect overlapping values.
20644   // In that case, we will just abort the transformation.
20645   APInt UsedBits(LD->getValueSizeInBits(0), 0);
20646 
20647   SmallVector<LoadedSlice, 4> LoadedSlices;
20648 
20649   // Check if this load is used as several smaller chunks of bits.
20650   // Basically, look for uses in trunc or trunc(lshr) and record a new chain
20651   // of computation for each trunc.
20652   for (SDUse &U : LD->uses()) {
20653     // Skip the uses of the chain.
20654     if (U.getResNo() != 0)
20655       continue;
20656 
20657     SDNode *User = U.getUser();
20658     unsigned Shift = 0;
20659 
20660     // Check if this is a trunc(lshr).
20661     if (User->getOpcode() == ISD::SRL && User->hasOneUse() &&
20662         isa<ConstantSDNode>(User->getOperand(1))) {
20663       Shift = User->getConstantOperandVal(1);
20664       User = *User->user_begin();
20665     }
20666 
20667     // At this point, User is a Truncate, iff we encountered, trunc or
20668     // trunc(lshr).
20669     if (User->getOpcode() != ISD::TRUNCATE)
20670       return false;
20671 
20672     // The width of the type must be a power of 2 and greater than 8-bits.
20673     // Otherwise the load cannot be represented in LLVM IR.
20674     // Moreover, if we shifted with a non-8-bits multiple, the slice
20675     // will be across several bytes. We do not support that.
20676     unsigned Width = User->getValueSizeInBits(0);
20677     if (Width < 8 || !isPowerOf2_32(Width) || (Shift & 0x7))
20678       return false;
20679 
20680     // Build the slice for this chain of computations.
20681     LoadedSlice LS(User, LD, Shift, &DAG);
20682     APInt CurrentUsedBits = LS.getUsedBits();
20683 
20684     // Check if this slice overlaps with another.
20685     if ((CurrentUsedBits & UsedBits) != 0)
20686       return false;
20687     // Update the bits used globally.
20688     UsedBits |= CurrentUsedBits;
20689 
20690     // Check if the new slice would be legal.
20691     if (!LS.isLegal())
20692       return false;
20693 
20694     // Record the slice.
20695     LoadedSlices.push_back(LS);
20696   }
20697 
20698   // Abort slicing if it does not seem to be profitable.
20699   if (!isSlicingProfitable(LoadedSlices, UsedBits, ForCodeSize))
20700     return false;
20701 
20702   ++SlicedLoads;
20703 
20704   // Rewrite each chain to use an independent load.
20705   // By construction, each chain can be represented by a unique load.
20706 
20707   // Prepare the argument for the new token factor for all the slices.
20708   SmallVector<SDValue, 8> ArgChains;
20709   for (const LoadedSlice &LS : LoadedSlices) {
20710     SDValue SliceInst = LS.loadSlice();
20711     CombineTo(LS.Inst, SliceInst, true);
20712     if (SliceInst.getOpcode() != ISD::LOAD)
20713       SliceInst = SliceInst.getOperand(0);
20714     assert(SliceInst->getOpcode() == ISD::LOAD &&
20715            "It takes more than a zext to get to the loaded slice!!");
20716     ArgChains.push_back(SliceInst.getValue(1));
20717   }
20718 
20719   SDValue Chain = DAG.getNode(ISD::TokenFactor, SDLoc(LD), MVT::Other,
20720                               ArgChains);
20721   DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Chain);
20722   AddToWorklist(Chain.getNode());
20723   return true;
20724 }
20725 
20726 /// Check to see if V is (and load (ptr), imm), where the load is having
20727 /// specific bytes cleared out.  If so, return the byte size being masked out
20728 /// and the shift amount.
20729 static std::pair<unsigned, unsigned>
CheckForMaskedLoad(SDValue V,SDValue Ptr,SDValue Chain)20730 CheckForMaskedLoad(SDValue V, SDValue Ptr, SDValue Chain) {
20731   std::pair<unsigned, unsigned> Result(0, 0);
20732 
20733   // Check for the structure we're looking for.
20734   if (V->getOpcode() != ISD::AND ||
20735       !isa<ConstantSDNode>(V->getOperand(1)) ||
20736       !ISD::isNormalLoad(V->getOperand(0).getNode()))
20737     return Result;
20738 
20739   // Check the chain and pointer.
20740   LoadSDNode *LD = cast<LoadSDNode>(V->getOperand(0));
20741   if (LD->getBasePtr() != Ptr) return Result;  // Not from same pointer.
20742 
20743   // This only handles simple types.
20744   if (V.getValueType() != MVT::i16 &&
20745       V.getValueType() != MVT::i32 &&
20746       V.getValueType() != MVT::i64)
20747     return Result;
20748 
20749   // Check the constant mask.  Invert it so that the bits being masked out are
20750   // 0 and the bits being kept are 1.  Use getSExtValue so that leading bits
20751   // follow the sign bit for uniformity.
20752   uint64_t NotMask = ~cast<ConstantSDNode>(V->getOperand(1))->getSExtValue();
20753   unsigned NotMaskLZ = llvm::countl_zero(NotMask);
20754   if (NotMaskLZ & 7) return Result;  // Must be multiple of a byte.
20755   unsigned NotMaskTZ = llvm::countr_zero(NotMask);
20756   if (NotMaskTZ & 7) return Result;  // Must be multiple of a byte.
20757   if (NotMaskLZ == 64) return Result;  // All zero mask.
20758 
20759   // See if we have a continuous run of bits.  If so, we have 0*1+0*
20760   if (llvm::countr_one(NotMask >> NotMaskTZ) + NotMaskTZ + NotMaskLZ != 64)
20761     return Result;
20762 
20763   // Adjust NotMaskLZ down to be from the actual size of the int instead of i64.
20764   if (V.getValueType() != MVT::i64 && NotMaskLZ)
20765     NotMaskLZ -= 64-V.getValueSizeInBits();
20766 
20767   unsigned MaskedBytes = (V.getValueSizeInBits()-NotMaskLZ-NotMaskTZ)/8;
20768   switch (MaskedBytes) {
20769   case 1:
20770   case 2:
20771   case 4: break;
20772   default: return Result; // All one mask, or 5-byte mask.
20773   }
20774 
20775   // Verify that the first bit starts at a multiple of mask so that the access
20776   // is aligned the same as the access width.
20777   if (NotMaskTZ && NotMaskTZ/8 % MaskedBytes) return Result;
20778 
20779   // For narrowing to be valid, it must be the case that the load the
20780   // immediately preceding memory operation before the store.
20781   if (LD == Chain.getNode())
20782     ; // ok.
20783   else if (Chain->getOpcode() == ISD::TokenFactor &&
20784            SDValue(LD, 1).hasOneUse()) {
20785     // LD has only 1 chain use so they are no indirect dependencies.
20786     if (!LD->isOperandOf(Chain.getNode()))
20787       return Result;
20788   } else
20789     return Result; // Fail.
20790 
20791   Result.first = MaskedBytes;
20792   Result.second = NotMaskTZ/8;
20793   return Result;
20794 }
20795 
20796 /// Check to see if IVal is something that provides a value as specified by
20797 /// MaskInfo. If so, replace the specified store with a narrower store of
20798 /// truncated IVal.
20799 static SDValue
ShrinkLoadReplaceStoreWithStore(const std::pair<unsigned,unsigned> & MaskInfo,SDValue IVal,StoreSDNode * St,DAGCombiner * DC)20800 ShrinkLoadReplaceStoreWithStore(const std::pair<unsigned, unsigned> &MaskInfo,
20801                                 SDValue IVal, StoreSDNode *St,
20802                                 DAGCombiner *DC) {
20803   unsigned NumBytes = MaskInfo.first;
20804   unsigned ByteShift = MaskInfo.second;
20805   SelectionDAG &DAG = DC->getDAG();
20806 
20807   // Check to see if IVal is all zeros in the part being masked in by the 'or'
20808   // that uses this.  If not, this is not a replacement.
20809   APInt Mask = ~APInt::getBitsSet(IVal.getValueSizeInBits(),
20810                                   ByteShift*8, (ByteShift+NumBytes)*8);
20811   if (!DAG.MaskedValueIsZero(IVal, Mask)) return SDValue();
20812 
20813   // Check that it is legal on the target to do this.  It is legal if the new
20814   // VT we're shrinking to (i8/i16/i32) is legal or we're still before type
20815   // legalization. If the source type is legal, but the store type isn't, see
20816   // if we can use a truncating store.
20817   MVT VT = MVT::getIntegerVT(NumBytes * 8);
20818   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
20819   bool UseTruncStore;
20820   if (DC->isTypeLegal(VT))
20821     UseTruncStore = false;
20822   else if (TLI.isTypeLegal(IVal.getValueType()) &&
20823            TLI.isTruncStoreLegal(IVal.getValueType(), VT))
20824     UseTruncStore = true;
20825   else
20826     return SDValue();
20827 
20828   // Can't do this for indexed stores.
20829   if (St->isIndexed())
20830     return SDValue();
20831 
20832   // Check that the target doesn't think this is a bad idea.
20833   if (St->getMemOperand() &&
20834       !TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT,
20835                               *St->getMemOperand()))
20836     return SDValue();
20837 
20838   // Okay, we can do this!  Replace the 'St' store with a store of IVal that is
20839   // shifted by ByteShift and truncated down to NumBytes.
20840   if (ByteShift) {
20841     SDLoc DL(IVal);
20842     IVal = DAG.getNode(
20843         ISD::SRL, DL, IVal.getValueType(), IVal,
20844         DAG.getShiftAmountConstant(ByteShift * 8, IVal.getValueType(), DL));
20845   }
20846 
20847   // Figure out the offset for the store and the alignment of the access.
20848   unsigned StOffset;
20849   if (DAG.getDataLayout().isLittleEndian())
20850     StOffset = ByteShift;
20851   else
20852     StOffset = IVal.getValueType().getStoreSize() - ByteShift - NumBytes;
20853 
20854   SDValue Ptr = St->getBasePtr();
20855   if (StOffset) {
20856     SDLoc DL(IVal);
20857     Ptr = DAG.getMemBasePlusOffset(Ptr, TypeSize::getFixed(StOffset), DL);
20858   }
20859 
20860   ++OpsNarrowed;
20861   if (UseTruncStore)
20862     return DAG.getTruncStore(St->getChain(), SDLoc(St), IVal, Ptr,
20863                              St->getPointerInfo().getWithOffset(StOffset), VT,
20864                              St->getBaseAlign());
20865 
20866   // Truncate down to the new size.
20867   IVal = DAG.getNode(ISD::TRUNCATE, SDLoc(IVal), VT, IVal);
20868 
20869   return DAG.getStore(St->getChain(), SDLoc(St), IVal, Ptr,
20870                       St->getPointerInfo().getWithOffset(StOffset),
20871                       St->getBaseAlign());
20872 }
20873 
20874 /// Look for sequence of load / op / store where op is one of 'or', 'xor', and
20875 /// 'and' of immediates. If 'op' is only touching some of the loaded bits, try
20876 /// narrowing the load and store if it would end up being a win for performance
20877 /// or code size.
ReduceLoadOpStoreWidth(SDNode * N)20878 SDValue DAGCombiner::ReduceLoadOpStoreWidth(SDNode *N) {
20879   StoreSDNode *ST  = cast<StoreSDNode>(N);
20880   if (!ST->isSimple())
20881     return SDValue();
20882 
20883   SDValue Chain = ST->getChain();
20884   SDValue Value = ST->getValue();
20885   SDValue Ptr   = ST->getBasePtr();
20886   EVT VT = Value.getValueType();
20887 
20888   if (ST->isTruncatingStore() || VT.isVector())
20889     return SDValue();
20890 
20891   unsigned Opc = Value.getOpcode();
20892 
20893   if ((Opc != ISD::OR && Opc != ISD::XOR && Opc != ISD::AND) ||
20894       !Value.hasOneUse())
20895     return SDValue();
20896 
20897   // If this is "store (or X, Y), P" and X is "(and (load P), cst)", where cst
20898   // is a byte mask indicating a consecutive number of bytes, check to see if
20899   // Y is known to provide just those bytes.  If so, we try to replace the
20900   // load + replace + store sequence with a single (narrower) store, which makes
20901   // the load dead.
20902   if (Opc == ISD::OR && EnableShrinkLoadReplaceStoreWithStore) {
20903     std::pair<unsigned, unsigned> MaskedLoad;
20904     MaskedLoad = CheckForMaskedLoad(Value.getOperand(0), Ptr, Chain);
20905     if (MaskedLoad.first)
20906       if (SDValue NewST = ShrinkLoadReplaceStoreWithStore(MaskedLoad,
20907                                                   Value.getOperand(1), ST,this))
20908         return NewST;
20909 
20910     // Or is commutative, so try swapping X and Y.
20911     MaskedLoad = CheckForMaskedLoad(Value.getOperand(1), Ptr, Chain);
20912     if (MaskedLoad.first)
20913       if (SDValue NewST = ShrinkLoadReplaceStoreWithStore(MaskedLoad,
20914                                                   Value.getOperand(0), ST,this))
20915         return NewST;
20916   }
20917 
20918   if (!EnableReduceLoadOpStoreWidth)
20919     return SDValue();
20920 
20921   if (Value.getOperand(1).getOpcode() != ISD::Constant)
20922     return SDValue();
20923 
20924   SDValue N0 = Value.getOperand(0);
20925   if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
20926       Chain == SDValue(N0.getNode(), 1)) {
20927     LoadSDNode *LD = cast<LoadSDNode>(N0);
20928     if (LD->getBasePtr() != Ptr ||
20929         LD->getPointerInfo().getAddrSpace() !=
20930         ST->getPointerInfo().getAddrSpace())
20931       return SDValue();
20932 
20933     // Find the type NewVT to narrow the load / op / store to.
20934     SDValue N1 = Value.getOperand(1);
20935     unsigned BitWidth = N1.getValueSizeInBits();
20936     APInt Imm = N1->getAsAPIntVal();
20937     if (Opc == ISD::AND)
20938       Imm.flipAllBits();
20939     if (Imm == 0 || Imm.isAllOnes())
20940       return SDValue();
20941     // Find least/most significant bit that need to be part of the narrowed
20942     // operation. We assume target will need to address/access full bytes, so
20943     // we make sure to align LSB and MSB at byte boundaries.
20944     unsigned BitsPerByteMask = 7u;
20945     unsigned LSB = Imm.countr_zero() & ~BitsPerByteMask;
20946     unsigned MSB = (Imm.getActiveBits() - 1) | BitsPerByteMask;
20947     unsigned NewBW = NextPowerOf2(MSB - LSB);
20948     EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), NewBW);
20949     // The narrowing should be profitable, the load/store operation should be
20950     // legal (or custom) and the store size should be equal to the NewVT width.
20951     while (NewBW < BitWidth &&
20952            (NewVT.getStoreSizeInBits() != NewBW ||
20953             !TLI.isOperationLegalOrCustom(Opc, NewVT) ||
20954             (!ReduceLoadOpStoreWidthForceNarrowingProfitable &&
20955              !TLI.isNarrowingProfitable(N, VT, NewVT)))) {
20956       NewBW = NextPowerOf2(NewBW);
20957       NewVT = EVT::getIntegerVT(*DAG.getContext(), NewBW);
20958     }
20959     if (NewBW >= BitWidth)
20960       return SDValue();
20961 
20962     // If we come this far NewVT/NewBW reflect a power-of-2 sized type that is
20963     // large enough to cover all bits that should be modified. This type might
20964     // however be larger than really needed (such as i32 while we actually only
20965     // need to modify one byte). Now we need to find our how to align the memory
20966     // accesses to satisfy preferred alignments as well as avoiding to access
20967     // memory outside the store size of the orignal access.
20968 
20969     unsigned VTStoreSize = VT.getStoreSizeInBits().getFixedValue();
20970 
20971     // Let ShAmt denote amount of bits to skip, counted from the least
20972     // significant bits of Imm. And let PtrOff how much the pointer needs to be
20973     // offsetted (in bytes) for the new access.
20974     unsigned ShAmt = 0;
20975     uint64_t PtrOff = 0;
20976     for (; ShAmt + NewBW <= VTStoreSize; ShAmt += 8) {
20977       // Make sure the range [ShAmt, ShAmt+NewBW) cover both LSB and MSB.
20978       if (ShAmt > LSB)
20979         return SDValue();
20980       if (ShAmt + NewBW < MSB)
20981         continue;
20982 
20983       // Calculate PtrOff.
20984       unsigned PtrAdjustmentInBits = DAG.getDataLayout().isBigEndian()
20985                                          ? VTStoreSize - NewBW - ShAmt
20986                                          : ShAmt;
20987       PtrOff = PtrAdjustmentInBits / 8;
20988 
20989       // Now check if narrow access is allowed and fast, considering alignments.
20990       unsigned IsFast = 0;
20991       Align NewAlign = commonAlignment(LD->getAlign(), PtrOff);
20992       if (TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), NewVT,
20993                                  LD->getAddressSpace(), NewAlign,
20994                                  LD->getMemOperand()->getFlags(), &IsFast) &&
20995           IsFast)
20996         break;
20997     }
20998     // If loop above did not find any accepted ShAmt we need to exit here.
20999     if (ShAmt + NewBW > VTStoreSize)
21000       return SDValue();
21001 
21002     APInt NewImm = Imm.lshr(ShAmt).trunc(NewBW);
21003     if (Opc == ISD::AND)
21004       NewImm.flipAllBits();
21005     Align NewAlign = commonAlignment(LD->getAlign(), PtrOff);
21006     SDValue NewPtr =
21007         DAG.getMemBasePlusOffset(Ptr, TypeSize::getFixed(PtrOff), SDLoc(LD));
21008     SDValue NewLD =
21009         DAG.getLoad(NewVT, SDLoc(N0), LD->getChain(), NewPtr,
21010                     LD->getPointerInfo().getWithOffset(PtrOff), NewAlign,
21011                     LD->getMemOperand()->getFlags(), LD->getAAInfo());
21012     SDValue NewVal = DAG.getNode(Opc, SDLoc(Value), NewVT, NewLD,
21013                                  DAG.getConstant(NewImm, SDLoc(Value), NewVT));
21014     SDValue NewST =
21015         DAG.getStore(Chain, SDLoc(N), NewVal, NewPtr,
21016                      ST->getPointerInfo().getWithOffset(PtrOff), NewAlign);
21017 
21018     AddToWorklist(NewPtr.getNode());
21019     AddToWorklist(NewLD.getNode());
21020     AddToWorklist(NewVal.getNode());
21021     WorklistRemover DeadNodes(*this);
21022     DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), NewLD.getValue(1));
21023     ++OpsNarrowed;
21024     return NewST;
21025   }
21026 
21027   return SDValue();
21028 }
21029 
21030 /// For a given floating point load / store pair, if the load value isn't used
21031 /// by any other operations, then consider transforming the pair to integer
21032 /// load / store operations if the target deems the transformation profitable.
TransformFPLoadStorePair(SDNode * N)21033 SDValue DAGCombiner::TransformFPLoadStorePair(SDNode *N) {
21034   StoreSDNode *ST  = cast<StoreSDNode>(N);
21035   SDValue Value = ST->getValue();
21036   if (ISD::isNormalStore(ST) && ISD::isNormalLoad(Value.getNode()) &&
21037       Value.hasOneUse()) {
21038     LoadSDNode *LD = cast<LoadSDNode>(Value);
21039     EVT VT = LD->getMemoryVT();
21040     if (!VT.isSimple() || !VT.isFloatingPoint() || VT != ST->getMemoryVT() ||
21041         LD->isNonTemporal() || ST->isNonTemporal() ||
21042         LD->getPointerInfo().getAddrSpace() != 0 ||
21043         ST->getPointerInfo().getAddrSpace() != 0)
21044       return SDValue();
21045 
21046     TypeSize VTSize = VT.getSizeInBits();
21047 
21048     // We don't know the size of scalable types at compile time so we cannot
21049     // create an integer of the equivalent size.
21050     if (VTSize.isScalable())
21051       return SDValue();
21052 
21053     unsigned FastLD = 0, FastST = 0;
21054     EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), VTSize.getFixedValue());
21055     if (!TLI.isOperationLegal(ISD::LOAD, IntVT) ||
21056         !TLI.isOperationLegal(ISD::STORE, IntVT) ||
21057         !TLI.isDesirableToTransformToIntegerOp(ISD::LOAD, VT) ||
21058         !TLI.isDesirableToTransformToIntegerOp(ISD::STORE, VT) ||
21059         !TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), IntVT,
21060                                 *LD->getMemOperand(), &FastLD) ||
21061         !TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), IntVT,
21062                                 *ST->getMemOperand(), &FastST) ||
21063         !FastLD || !FastST)
21064       return SDValue();
21065 
21066     SDValue NewLD = DAG.getLoad(IntVT, SDLoc(Value), LD->getChain(),
21067                                 LD->getBasePtr(), LD->getMemOperand());
21068 
21069     SDValue NewST = DAG.getStore(ST->getChain(), SDLoc(N), NewLD,
21070                                  ST->getBasePtr(), ST->getMemOperand());
21071 
21072     AddToWorklist(NewLD.getNode());
21073     AddToWorklist(NewST.getNode());
21074     WorklistRemover DeadNodes(*this);
21075     DAG.ReplaceAllUsesOfValueWith(Value.getValue(1), NewLD.getValue(1));
21076     ++LdStFP2Int;
21077     return NewST;
21078   }
21079 
21080   return SDValue();
21081 }
21082 
21083 // This is a helper function for visitMUL to check the profitability
21084 // of folding (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2).
21085 // MulNode is the original multiply, AddNode is (add x, c1),
21086 // and ConstNode is c2.
21087 //
21088 // If the (add x, c1) has multiple uses, we could increase
21089 // the number of adds if we make this transformation.
21090 // It would only be worth doing this if we can remove a
21091 // multiply in the process. Check for that here.
21092 // To illustrate:
21093 //     (A + c1) * c3
21094 //     (A + c2) * c3
21095 // We're checking for cases where we have common "c3 * A" expressions.
isMulAddWithConstProfitable(SDNode * MulNode,SDValue AddNode,SDValue ConstNode)21096 bool DAGCombiner::isMulAddWithConstProfitable(SDNode *MulNode, SDValue AddNode,
21097                                               SDValue ConstNode) {
21098   // If the add only has one use, and the target thinks the folding is
21099   // profitable or does not lead to worse code, this would be OK to do.
21100   if (AddNode->hasOneUse() &&
21101       TLI.isMulAddWithConstProfitable(AddNode, ConstNode))
21102     return true;
21103 
21104   // Walk all the users of the constant with which we're multiplying.
21105   for (SDNode *User : ConstNode->users()) {
21106     if (User == MulNode) // This use is the one we're on right now. Skip it.
21107       continue;
21108 
21109     if (User->getOpcode() == ISD::MUL) { // We have another multiply use.
21110       SDNode *OtherOp;
21111       SDNode *MulVar = AddNode.getOperand(0).getNode();
21112 
21113       // OtherOp is what we're multiplying against the constant.
21114       if (User->getOperand(0) == ConstNode)
21115         OtherOp = User->getOperand(1).getNode();
21116       else
21117         OtherOp = User->getOperand(0).getNode();
21118 
21119       // Check to see if multiply is with the same operand of our "add".
21120       //
21121       //     ConstNode  = CONST
21122       //     User = ConstNode * A  <-- visiting User. OtherOp is A.
21123       //     ...
21124       //     AddNode  = (A + c1)  <-- MulVar is A.
21125       //         = AddNode * ConstNode   <-- current visiting instruction.
21126       //
21127       // If we make this transformation, we will have a common
21128       // multiply (ConstNode * A) that we can save.
21129       if (OtherOp == MulVar)
21130         return true;
21131 
21132       // Now check to see if a future expansion will give us a common
21133       // multiply.
21134       //
21135       //     ConstNode  = CONST
21136       //     AddNode    = (A + c1)
21137       //     ...   = AddNode * ConstNode <-- current visiting instruction.
21138       //     ...
21139       //     OtherOp = (A + c2)
21140       //     User    = OtherOp * ConstNode <-- visiting User.
21141       //
21142       // If we make this transformation, we will have a common
21143       // multiply (CONST * A) after we also do the same transformation
21144       // to the "t2" instruction.
21145       if (OtherOp->getOpcode() == ISD::ADD &&
21146           DAG.isConstantIntBuildVectorOrConstantInt(OtherOp->getOperand(1)) &&
21147           OtherOp->getOperand(0).getNode() == MulVar)
21148         return true;
21149     }
21150   }
21151 
21152   // Didn't find a case where this would be profitable.
21153   return false;
21154 }
21155 
getMergeStoreChains(SmallVectorImpl<MemOpLink> & StoreNodes,unsigned NumStores)21156 SDValue DAGCombiner::getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes,
21157                                          unsigned NumStores) {
21158   SmallVector<SDValue, 8> Chains;
21159   SmallPtrSet<const SDNode *, 8> Visited;
21160   SDLoc StoreDL(StoreNodes[0].MemNode);
21161 
21162   for (unsigned i = 0; i < NumStores; ++i) {
21163     Visited.insert(StoreNodes[i].MemNode);
21164   }
21165 
21166   // don't include nodes that are children or repeated nodes.
21167   for (unsigned i = 0; i < NumStores; ++i) {
21168     if (Visited.insert(StoreNodes[i].MemNode->getChain().getNode()).second)
21169       Chains.push_back(StoreNodes[i].MemNode->getChain());
21170   }
21171 
21172   assert(!Chains.empty() && "Chain should have generated a chain");
21173   return DAG.getTokenFactor(StoreDL, Chains);
21174 }
21175 
hasSameUnderlyingObj(ArrayRef<MemOpLink> StoreNodes)21176 bool DAGCombiner::hasSameUnderlyingObj(ArrayRef<MemOpLink> StoreNodes) {
21177   const Value *UnderlyingObj = nullptr;
21178   for (const auto &MemOp : StoreNodes) {
21179     const MachineMemOperand *MMO = MemOp.MemNode->getMemOperand();
21180     // Pseudo value like stack frame has its own frame index and size, should
21181     // not use the first store's frame index for other frames.
21182     if (MMO->getPseudoValue())
21183       return false;
21184 
21185     if (!MMO->getValue())
21186       return false;
21187 
21188     const Value *Obj = getUnderlyingObject(MMO->getValue());
21189 
21190     if (UnderlyingObj && UnderlyingObj != Obj)
21191       return false;
21192 
21193     if (!UnderlyingObj)
21194       UnderlyingObj = Obj;
21195   }
21196 
21197   return true;
21198 }
21199 
// Merge the first NumStores candidates in StoreNodes (each storing a MemVT
// value) into one wider store, then replace every merged store with it via
// CombineTo.
//
// Parameters:
//   IsConstantSrc - the stored values are constants.
//   UseVector     - build the merged value as a vector (BUILD_VECTOR, or
//                   CONCAT_VECTORS when MemVT is itself a vector). Otherwise
//                   the constants are packed into a single integer.
//   UseTrunc      - emit the packed integer constant as a truncating store of
//                   the constant widened to its type-legalized type. Must not
//                   be combined with UseVector.
//
// Returns true if the stores were merged. Returns false, without replacing any
// store, in any of these cases:
//   - fewer than two stores are given;
//   - the stores' MachineMemOperand flags differ;
//   - a stored constant cannot be converted to the merged form.
mergeStoresOfConstantsOrVecElts(SmallVectorImpl<MemOpLink> & StoreNodes,EVT MemVT,unsigned NumStores,bool IsConstantSrc,bool UseVector,bool UseTrunc)21200 bool DAGCombiner::mergeStoresOfConstantsOrVecElts(
21201     SmallVectorImpl<MemOpLink> &StoreNodes, EVT MemVT, unsigned NumStores,
21202     bool IsConstantSrc, bool UseVector, bool UseTrunc) {
21203   // Make sure we have something to merge.
21204   if (NumStores < 2)
21205     return false;
21206
21207   assert((!UseTrunc || !UseVector) &&
21208          "This optimization cannot emit a vector truncating store");
21209
21210   // Debug location for the new nodes: that of the first candidate store.
21211   SDLoc DL(StoreNodes[0].MemNode);
21212
  // Each candidate occupies ElementSizeBits (the store size of MemVT). The
  // merged value spans NumStores of them.
  // NOTE(review): the implicit TypeSize -> unsigned conversion below assumes
  // MemVT is not a scalable vector -- confirm against callers.
21213   TypeSize ElementSizeBits = MemVT.getStoreSizeInBits();
21214   unsigned SizeInBits = NumStores * ElementSizeBits;
21215   unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
21216
  // All candidates must carry identical memory-operand flags. Their AA
  // metadata is combined so that the merged store conservatively covers all
  // of them.
21217   std::optional<MachineMemOperand::Flags> Flags;
21218   AAMDNodes AAInfo;
21219   for (unsigned I = 0; I != NumStores; ++I) {
21220     StoreSDNode *St = cast<StoreSDNode>(StoreNodes[I].MemNode);
21221     if (!Flags) {
21222       Flags = St->getMemOperand()->getFlags();
21223       AAInfo = St->getAAInfo();
21224       continue;
21225     }
21226     // Skip merging if there's an inconsistent flag.
21227     if (Flags != St->getMemOperand()->getFlags())
21228       return false;
21229     // Concatenate AA metadata.
21230     AAInfo = AAInfo.concat(St->getAAInfo());
21231   }
21232
  // Type of the merged value: a vector of MemVT's scalar type, or one integer
  // wide enough for all the stores.
21233   EVT StoreTy;
21234   if (UseVector) {
21235     unsigned Elts = NumStores * NumMemElts;
21236     // Get the type for the merged vector store.
21237     StoreTy = EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts);
21238   } else
21239     StoreTy = EVT::getIntegerVT(*DAG.getContext(), SizeInBits);
21240
21241   SDValue StoredVal;
21242   if (UseVector) {
21243     if (IsConstantSrc) {
      // Vector of constants: coerce each stored constant to MemVT and
      // assemble them in candidate order.
21244       SmallVector<SDValue, 8> BuildVector;
21245       for (unsigned I = 0; I != NumStores; ++I) {
21246         StoreSDNode *St = cast<StoreSDNode>(StoreNodes[I].MemNode);
21247         SDValue Val = St->getValue();
21248         // If constant is of the wrong type, convert it now.  This comes up
21249         // when one of our stores was truncating.
21250         if (MemVT != Val.getValueType()) {
21251           Val = peekThroughBitcasts(Val);
21252           // Deal with constants of wrong size.
21253           if (ElementSizeBits != Val.getValueSizeInBits()) {
21254             auto *C = dyn_cast<ConstantSDNode>(Val);
21255             if (!C)
21256               // Not clear how to truncate FP values.
21257               // TODO: Handle truncation of build_vector constants
21258               return false;
21259
            // Re-create the integer constant at the element width as an
            // integer of MemVT's size, so it can be bitcast to MemVT below.
21260             EVT IntMemVT =
21261                 EVT::getIntegerVT(*DAG.getContext(), MemVT.getSizeInBits());
21262             Val = DAG.getConstant(C->getAPIntValue()
21263                                       .zextOrTrunc(Val.getValueSizeInBits())
21264                                       .zextOrTrunc(ElementSizeBits),
21265                                   SDLoc(C), IntMemVT);
21266           }
21267           // Bitcast the (now correctly sized) value to MemVT.
21268           Val = DAG.getBitcast(MemVT, Val);
21269         }
21270         BuildVector.push_back(Val);
21271       }
21272       StoredVal = DAG.getNode(MemVT.isVector() ? ISD::CONCAT_VECTORS
21273                                                : ISD::BUILD_VECTOR,
21274                               DL, StoreTy, BuildVector);
21275     } else {
      // Vector of extracted elements/subvectors: recast each one to MemVT
      // where needed and reassemble them in candidate order.
21276       SmallVector<SDValue, 8> Ops;
21277       for (unsigned i = 0; i < NumStores; ++i) {
21278         StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
21279         SDValue Val = peekThroughBitcasts(St->getValue());
21280         // All operands of BUILD_VECTOR / CONCAT_VECTOR must be of
21281         // type MemVT. If the underlying value is not the correct
21282         // type, but it is an extraction of an appropriate vector we
21283         // can recast Val to be of the correct type. This may require
21284         // converting between EXTRACT_VECTOR_ELT and
21285         // EXTRACT_SUBVECTOR.
21286         if ((MemVT != Val.getValueType()) &&
21287             (Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT ||
21288              Val.getOpcode() == ISD::EXTRACT_SUBVECTOR)) {
21289           EVT MemVTScalarTy = MemVT.getScalarType();
21290           // We may need to add a bitcast here to get types to line up.
21291           if (MemVTScalarTy != Val.getValueType().getScalarType()) {
21292             Val = DAG.getBitcast(MemVT, Val);
21293           } else if (MemVT.isVector() &&
21294                      Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
21295             Val = DAG.getNode(ISD::BUILD_VECTOR, DL, MemVT, Val);
21296           } else {
            // Same scalar type: re-extract from the source vector with the
            // opcode matching MemVT (subvector for vectors, element for
            // scalars), reusing the original index.
21297             unsigned OpC = MemVT.isVector() ? ISD::EXTRACT_SUBVECTOR
21298                                             : ISD::EXTRACT_VECTOR_ELT;
21299             SDValue Vec = Val.getOperand(0);
21300             SDValue Idx = Val.getOperand(1);
21301             Val = DAG.getNode(OpC, SDLoc(Val), MemVT, Vec, Idx);
21302           }
21303         }
21304         Ops.push_back(Val);
21305       }
21306
21307       // Build the extracted vector elements back into a vector.
21308       StoredVal = DAG.getNode(MemVT.isVector() ? ISD::CONCAT_VECTORS
21309                                                : ISD::BUILD_VECTOR,
21310                               DL, StoreTy, Ops);
21311     }
21312   } else {
21313     // We should always use a vector store when merging extracted vector
21314     // elements, so this path implies a store of constants.
21315     assert(IsConstantSrc && "Merged vector elements should use vector store");
21316
21317     APInt StoreInt(SizeInBits, 0);
21318
21319     // Construct a single integer constant which is made of the smaller
21320     // constant inputs.
    // Each step shifts the accumulated value left by one element and ORs in
    // the next constant, so the last constant processed lands in the low
    // bits. On little-endian targets the stores are walked from last to
    // first, which puts StoreNodes[0]'s constant in the low bits.
21321     bool IsLE = DAG.getDataLayout().isLittleEndian();
21322     for (unsigned i = 0; i < NumStores; ++i) {
21323       unsigned Idx = IsLE ? (NumStores - 1 - i) : i;
21324       StoreSDNode *St  = cast<StoreSDNode>(StoreNodes[Idx].MemNode);
21325
21326       SDValue Val = St->getValue();
21327       Val = peekThroughBitcasts(Val);
21328       StoreInt <<= ElementSizeBits;
21329       if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val)) {
21330         StoreInt |= C->getAPIntValue()
21331                         .zextOrTrunc(ElementSizeBits)
21332                         .zextOrTrunc(SizeInBits);
21333       } else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(Val)) {
21334         StoreInt |= C->getValueAPF()
21335                         .bitcastToAPInt()
21336                         .zextOrTrunc(ElementSizeBits)
21337                         .zextOrTrunc(SizeInBits);
21338         // If fp truncation is necessary give up for now.
21339         if (MemVT.getSizeInBits() != ElementSizeBits)
21340           return false;
21341       } else if (ISD::isBuildVectorOfConstantSDNodes(Val.getNode()) ||
21342                  ISD::isBuildVectorOfConstantFPSDNodes(Val.getNode())) {
21343         // Not yet handled
21344         return false;
21345       } else {
21346         llvm_unreachable("Invalid constant element type");
21347       }
21348     }
21349
21350     // Materialize the packed constant in the merged integer type.
21351     StoredVal = DAG.getConstant(StoreInt, DL, StoreTy);
21352   }
21353
  // The merged store is chained on a TokenFactor of the candidates' incoming
  // chains, and its address is that of StoreNodes[0].
21354   LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
21355   SDValue NewChain = getMergeStoreChains(StoreNodes, NumStores);
21356   bool CanReusePtrInfo = hasSameUnderlyingObj(StoreNodes);
21357
21358   // Use a truncating store when UseTrunc is set (needed for legality).
21359   // If the first store's pointer info cannot be reused, keep only its
21360   // address space, because the widened store cannot be described by
21361   // pointer info that refers to the narrow memory object.
21362   // Otherwise reuse the first store's pointer info as-is.
21363   SDValue NewStore;
21364   if (!UseTrunc) {
21365     NewStore = DAG.getStore(
21366         NewChain, DL, StoredVal, FirstInChain->getBasePtr(),
21367         CanReusePtrInfo
21368             ? FirstInChain->getPointerInfo()
21369             : MachinePointerInfo(FirstInChain->getPointerInfo().getAddrSpace()),
21370         FirstInChain->getAlign(), *Flags, AAInfo);
21371   } else { // Must be realized as a trunc store
    // Widen the packed constant to the type StoredVal's type is legalized to,
    // then store it truncated back to StoredVal's type.
21372     EVT LegalizedStoredValTy =
21373         TLI.getTypeToTransformTo(*DAG.getContext(), StoredVal.getValueType());
21374     unsigned LegalizedStoreSize = LegalizedStoredValTy.getSizeInBits();
21375     ConstantSDNode *C = cast<ConstantSDNode>(StoredVal);
21376     SDValue ExtendedStoreVal =
21377         DAG.getConstant(C->getAPIntValue().zextOrTrunc(LegalizedStoreSize), DL,
21378                         LegalizedStoredValTy);
21379     NewStore = DAG.getTruncStore(
21380         NewChain, DL, ExtendedStoreVal, FirstInChain->getBasePtr(),
21381         CanReusePtrInfo
21382             ? FirstInChain->getPointerInfo()
21383             : MachinePointerInfo(FirstInChain->getPointerInfo().getAddrSpace()),
21384         StoredVal.getValueType() /*TVT*/, FirstInChain->getAlign(), *Flags,
21385         AAInfo);
21386   }
21387
21388   // Replace all merged stores with the new store.
21389   for (unsigned i = 0; i < NumStores; ++i)
21390     CombineTo(StoreNodes[i].MemNode, NewStore);
21391
  // Queue the new chain node for further combining.
21392   AddToWorklist(NewChain.getNode());
21393   return true;
21394 }
21395 
// Collect stores that may be merged with St and append them to StoreNodes as
// MemOpLinks. A candidate must satisfy all of the following:
//   - it is a store that uses the search root (or a load hanging off the root)
//     through its chain operand;
//   - it is simple and non-indexed;
//   - it agrees with St on non-temporality and on target MMO flags;
//   - its stored value comes from the same kind of source as St's (load,
//     constant, or vector extract);
//   - it shares St's base/index according to
//     BaseIndexOffset::equalBaseIndex, which also supplies the recorded offset.
// Stores already over the dependence-check retry limit for this root are
// skipped.
//
// Returns the root chain node that was searched (St's chain, or the chain of
// the load St is chained to). Returns nullptr in any of these cases:
//   - St's address has no usable base;
//   - St's value is a load that fails the checks below;
//   - the root is already recorded in ChainsWithoutMergeableStores.
21396 SDNode *
getStoreMergeCandidates(StoreSDNode * St,SmallVectorImpl<MemOpLink> & StoreNodes)21397 DAGCombiner::getStoreMergeCandidates(StoreSDNode *St,
21398                                      SmallVectorImpl<MemOpLink> &StoreNodes) {
21399   // This holds the base pointer, index, and the offset in bytes from the base
21400   // pointer. We must have a base and an offset. Do not handle stores to undef
21401   // base pointers.
21402   BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);
21403   if (!BasePtr.getBase().getNode() || BasePtr.getBase().isUndef())
21404     return nullptr;
21405
21406   SDValue Val = peekThroughBitcasts(St->getValue());
21407   StoreSource StoreSrc = getStoreSource(Val);
21408   assert(StoreSrc != StoreSource::Unknown && "Expected known source for store");
21409
21410   // Match on loadbaseptr if relevant.
21411   EVT MemVT = St->getMemoryVT();
21412   BaseIndexOffset LBasePtr;
21413   EVT LoadVT;
21414   if (StoreSrc == StoreSource::Load) {
21415     auto *Ld = cast<LoadSDNode>(Val);
21416     LBasePtr = BaseIndexOffset::match(Ld, DAG);
21417     LoadVT = Ld->getMemoryVT();
21418     // Load and store should be the same type.
21419     if (MemVT != LoadVT)
21420       return nullptr;
21421     // Loads must only have one use.
21422     if (!Ld->hasNUsesOfValue(1, 0))
21423       return nullptr;
21424     // The memory operands must not be volatile/indexed/atomic.
21425     // TODO: May be able to relax for unordered atomics (see D66309)
21426     if (!Ld->isSimple() || Ld->isIndexed())
21427       return nullptr;
21428   }
  // Returns true if Other may be merged with St. On success, Ptr holds Other's
  // decomposed address and Offset the offset reported by
  // BasePtr.equalBaseIndex for it.
21429   auto CandidateMatch = [&](StoreSDNode *Other, BaseIndexOffset &Ptr,
21430                             int64_t &Offset) -> bool {
21431     // The memory operands must not be volatile/indexed/atomic.
21432     // TODO: May be able to relax for unordered atomics (see D66309)
21433     if (!Other->isSimple() || Other->isIndexed())
21434       return false;
21435     // Don't mix temporal stores with non-temporal stores.
21436     if (St->isNonTemporal() != Other->isNonTemporal())
21437       return false;
21438     if (!TLI.areTwoSDNodeTargetMMOFlagsMergeable(*St, *Other))
21439       return false;
21440     SDValue OtherBC = peekThroughBitcasts(Other->getValue());
21441     // Allow merging constants of different types as integers.
21442     bool NoTypeMatch = (MemVT.isInteger()) ? !MemVT.bitsEq(Other->getMemoryVT())
21443                                            : Other->getMemoryVT() != MemVT;
21444     switch (StoreSrc) {
21445     case StoreSource::Load: {
21446       if (NoTypeMatch)
21447         return false;
21448       // The Load's Base Ptr must also match.
21449       auto *OtherLd = dyn_cast<LoadSDNode>(OtherBC);
21450       if (!OtherLd)
21451         return false;
21452       BaseIndexOffset LPtr = BaseIndexOffset::match(OtherLd, DAG);
21453       if (LoadVT != OtherLd->getMemoryVT())
21454         return false;
21455       // Loads must only have one use.
21456       if (!OtherLd->hasNUsesOfValue(1, 0))
21457         return false;
21458       // The memory operands must not be volatile/indexed/atomic.
21459       // TODO: May be able to relax for unordered atomics (see D66309)
21460       if (!OtherLd->isSimple() || OtherLd->isIndexed())
21461         return false;
21462       // Don't mix temporal loads with non-temporal loads.
21463       if (cast<LoadSDNode>(Val)->isNonTemporal() != OtherLd->isNonTemporal())
21464         return false;
21465       if (!TLI.areTwoSDNodeTargetMMOFlagsMergeable(*cast<LoadSDNode>(Val),
21466                                                    *OtherLd))
21467         return false;
21468       if (!(LBasePtr.equalBaseIndex(LPtr, DAG)))
21469         return false;
21470       break;
21471     }
21472     case StoreSource::Constant:
21473       if (NoTypeMatch)
21474         return false;
21475       if (getStoreSource(OtherBC) != StoreSource::Constant)
21476         return false;
21477       break;
21478     case StoreSource::Extract:
21479       // Do not merge truncated stores here.
21480       if (Other->isTruncatingStore())
21481         return false;
21482       if (!MemVT.bitsEq(OtherBC.getValueType()))
21483         return false;
21484       if (OtherBC.getOpcode() != ISD::EXTRACT_VECTOR_ELT &&
21485           OtherBC.getOpcode() != ISD::EXTRACT_SUBVECTOR)
21486         return false;
21487       break;
21488     default:
21489       llvm_unreachable("Unhandled store source for merging");
21490     }
21491     Ptr = BaseIndexOffset::match(Other, DAG);
21492     return (BasePtr.equalBaseIndex(Ptr, DAG, Offset));
21493   };
21494
21495   // We are looking for a root node which is an ancestor to all mergeable
21496   // stores. We search up through a load, to our root and then down
21497   // through all children. For instance we will find Store{1,2,3} if
21498   // St is Store1, Store2, or Store3, where the root is not a load,
21499   // which is always true for nonvolatile ops. TODO: Expand
21500   // the search to find all valid candidates through multiple layers of loads.
21501   //
21502   // Root
21503   // |-------|-------|
21504   // Load    Load    Store3
21505   // |       |
21506   // Store1   Store2
21507   //
21508   // FIXME: We should be able to climb and
21509   // descend TokenFactors to find candidates as well.
21510
21511   SDNode *RootNode = St->getChain().getNode();
21512   // Bail out if we already analyzed this root node and found nothing.
21513   if (ChainsWithoutMergeableStores.contains(RootNode))
21514     return nullptr;
21515
21516   // Check if the pair of StoreNode and the RootNode already bail out many
21517   // times which is over the limit in dependence check.
21518   auto OverLimitInDependenceCheck = [&](SDNode *StoreNode,
21519                                         SDNode *RootNode) -> bool {
21520     auto RootCount = StoreRootCountMap.find(StoreNode);
21521     return RootCount != StoreRootCountMap.end() &&
21522            RootCount->second.first == RootNode &&
21523            RootCount->second.second > StoreMergeDependenceLimit;
21524   };
21525
  // Append Use's user to StoreNodes if Use is the user's chain operand
  // (operand 0), the user is a store matching St, and the store is not over
  // the dependence-check limit for RootNode.
21526   auto TryToAddCandidate = [&](SDUse &Use) {
21527     // This must be a chain use.
21528     if (Use.getOperandNo() != 0)
21529       return;
21530     if (auto *OtherStore = dyn_cast<StoreSDNode>(Use.getUser())) {
21531       BaseIndexOffset Ptr;
21532       int64_t PtrDiff;
21533       if (CandidateMatch(OtherStore, Ptr, PtrDiff) &&
21534           !OverLimitInDependenceCheck(OtherStore, RootNode))
21535         StoreNodes.push_back(MemOpLink(OtherStore, PtrDiff));
21536     }
21537   };
21538
  // Only the root's uses count toward MaxSearchNodes; the uses of loads
  // visited below them are not counted.
21539   unsigned NumNodesExplored = 0;
21540   const unsigned MaxSearchNodes = 1024;
21541   if (auto *Ldn = dyn_cast<LoadSDNode>(RootNode)) {
21542     RootNode = Ldn->getChain().getNode();
21543     // Bail out if we already analyzed this root node and found nothing.
21544     if (ChainsWithoutMergeableStores.contains(RootNode))
21545       return nullptr;
21546     for (auto I = RootNode->use_begin(), E = RootNode->use_end();
21547          I != E && NumNodesExplored < MaxSearchNodes; ++I, ++NumNodesExplored) {
21548       SDNode *User = I->getUser();
21549       if (I->getOperandNo() == 0 && isa<LoadSDNode>(User)) { // walk down chain
21550         for (SDUse &U2 : User->uses())
21551           TryToAddCandidate(U2);
21552       }
21553       // Check stores that depend on the root (e.g. Store 3 in the chart above).
21554       if (I->getOperandNo() == 0 && isa<StoreSDNode>(User)) {
21555         TryToAddCandidate(*I);
21556       }
21557     }
21558   } else {
21559     for (auto I = RootNode->use_begin(), E = RootNode->use_end();
21560          I != E && NumNodesExplored < MaxSearchNodes; ++I, ++NumNodesExplored)
21561       TryToAddCandidate(*I);
21562   }
21563
21564   return RootNode;
21565 }
21566 
21567 // We need to check that merging these stores does not cause a loop in the
21568 // DAG. Any store candidate may depend on another candidate indirectly through
21569 // its operands. Check in parallel by searching up from operands of candidates.
// Returns true if none of the first NumStores candidates is found as a
// predecessor of the others' operands. Returns false if a dependence is found
// or if the search exceeded its node budget. In the budget case, the
// (store, RootNode) pair's bail-out count in StoreRootCountMap is incremented.
checkMergeStoreCandidatesForDependencies(SmallVectorImpl<MemOpLink> & StoreNodes,unsigned NumStores,SDNode * RootNode)21570 bool DAGCombiner::checkMergeStoreCandidatesForDependencies(
21571     SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores,
21572     SDNode *RootNode) {
21573   // FIXME: We should be able to truncate a full search of
21574   // predecessors by doing a BFS and keeping track of the originating
21575   // store for each worklist node, in a similar way to
21576   // TokenFactor simplification.
21577
21578   SmallPtrSet<const SDNode *, 32> Visited;
21579   SmallVector<const SDNode *, 8> Worklist;
21580
21581   // RootNode is a predecessor to all candidates so we need not search
21582   // past it. Add RootNode (peeking through TokenFactors). Do not count
21583   // these towards size check.
21584
  // Pre-populate Visited with RootNode and everything reachable from it
  // through TokenFactor operands, so the search below treats them as seen.
21585   Worklist.push_back(RootNode);
21586   while (!Worklist.empty()) {
21587     auto N = Worklist.pop_back_val();
21588     if (!Visited.insert(N).second)
21589       continue; // Already present in Visited.
21590     if (N->getOpcode() == ISD::TokenFactor) {
21591       for (SDValue Op : N->ops())
21592         Worklist.push_back(Op.getNode());
21593     }
21594   }
21595
21596   // Don't count pruning nodes towards max.
21597   unsigned int Max = 1024 + Visited.size();
21598   // Search Ops of store candidates.
21599   for (unsigned i = 0; i < NumStores; ++i) {
21600     SDNode *N = StoreNodes[i].MemNode;
21601     // Of the 4 Store Operands:
21602     //   * Chain (Op 0) -> We have already considered these
21603     //                     in candidate selection, but only by following the
21604     //                     chain dependencies. We could still have a chain
21605     //                     dependency to a load, that has a non-chain dep to
21606     //                     another load, that depends on a store, etc. So it is
21607     //                     possible to have dependencies that consist of a mix
21608     //                     of chain and non-chain deps, and we need to include
21609     //                     chain operands in the analysis here.
21610     //   * Value (Op 1) -> Cycles may happen (e.g. through load chains)
21611     //   * Address (Op 2) -> Merged addresses may only vary by a fixed constant,
21612     //                       but aren't necessarily from the same base node, so
21613     //                       cycles possible (e.g. via indexed store).
21614     //   * (Op 3) -> Represents the pre or post-indexing offset (or undef for
21615     //               non-indexed stores). Not constant on all targets (e.g. ARM)
21616     //               and so can participate in a cycle.
21617     for (const SDValue &Op : N->op_values())
21618       Worklist.push_back(Op.getNode());
21619   }
21620   // Search through DAG. We can stop early if we find a store node.
  // Visited and Worklist are shared across these calls, so each search
  // continues from the state left by the previous one.
21621   for (unsigned i = 0; i < NumStores; ++i)
21622     if (SDNode::hasPredecessorHelper(StoreNodes[i].MemNode, Visited, Worklist,
21623                                      Max)) {
21624       // If the search bailed out, record the StoreNode and RootNode in the
21625       // StoreRootCountMap. If we have seen the pair many times over a limit,
21626       // we won't add the StoreNode into StoreNodes set again.
21627       if (Visited.size() >= Max) {
21628         auto &RootCount = StoreRootCountMap[StoreNodes[i].MemNode];
21629         if (RootCount.first == RootNode)
21630           RootCount.second++;
21631         else
21632           RootCount = {RootNode, 1};
21633       }
21634       return false;
21635     }
21636   return true;
21637 }
21638 
hasCallInLdStChain(StoreSDNode * St,LoadSDNode * Ld)21639 bool DAGCombiner::hasCallInLdStChain(StoreSDNode *St, LoadSDNode *Ld) {
21640   SmallPtrSet<const SDNode *, 32> Visited;
21641   SmallVector<std::pair<const SDNode *, bool>, 8> Worklist;
21642   Worklist.emplace_back(St->getChain().getNode(), false);
21643 
21644   while (!Worklist.empty()) {
21645     auto [Node, FoundCall] = Worklist.pop_back_val();
21646     if (!Visited.insert(Node).second || Node->getNumOperands() == 0)
21647       continue;
21648 
21649     switch (Node->getOpcode()) {
21650     case ISD::CALLSEQ_END:
21651       Worklist.emplace_back(Node->getOperand(0).getNode(), true);
21652       break;
21653     case ISD::TokenFactor:
21654       for (SDValue Op : Node->ops())
21655         Worklist.emplace_back(Op.getNode(), FoundCall);
21656       break;
21657     case ISD::LOAD:
21658       if (Node == Ld)
21659         return FoundCall;
21660       [[fallthrough]];
21661     default:
21662       assert(Node->getOperand(0).getValueType() == MVT::Other &&
21663              "Invalid chain type");
21664       Worklist.emplace_back(Node->getOperand(0).getNode(), FoundCall);
21665       break;
21666     }
21667   }
21668   return false;
21669 }
21670 
21671 unsigned
getConsecutiveStores(SmallVectorImpl<MemOpLink> & StoreNodes,int64_t ElementSizeBytes) const21672 DAGCombiner::getConsecutiveStores(SmallVectorImpl<MemOpLink> &StoreNodes,
21673                                   int64_t ElementSizeBytes) const {
21674   while (true) {
21675     // Find a store past the width of the first store.
21676     size_t StartIdx = 0;
21677     while ((StartIdx + 1 < StoreNodes.size()) &&
21678            StoreNodes[StartIdx].OffsetFromBase + ElementSizeBytes !=
21679               StoreNodes[StartIdx + 1].OffsetFromBase)
21680       ++StartIdx;
21681 
21682     // Bail if we don't have enough candidates to merge.
21683     if (StartIdx + 1 >= StoreNodes.size())
21684       return 0;
21685 
21686     // Trim stores that overlapped with the first store.
21687     if (StartIdx)
21688       StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + StartIdx);
21689 
21690     // Scan the memory operations on the chain and find the first
21691     // non-consecutive store memory address.
21692     unsigned NumConsecutiveStores = 1;
21693     int64_t StartAddress = StoreNodes[0].OffsetFromBase;
21694     // Check that the addresses are consecutive starting from the second
21695     // element in the list of stores.
21696     for (unsigned i = 1, e = StoreNodes.size(); i < e; ++i) {
21697       int64_t CurrAddress = StoreNodes[i].OffsetFromBase;
21698       if (CurrAddress - StartAddress != (ElementSizeBytes * i))
21699         break;
21700       NumConsecutiveStores = i + 1;
21701     }
21702     if (NumConsecutiveStores > 1)
21703       return NumConsecutiveStores;
21704 
21705     // There are no consecutive stores at the start of the list.
21706     // Remove the first store and try again.
21707     StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + 1);
21708   }
21709 }
21710 
tryStoreMergeOfConstants(SmallVectorImpl<MemOpLink> & StoreNodes,unsigned NumConsecutiveStores,EVT MemVT,SDNode * RootNode,bool AllowVectors)21711 bool DAGCombiner::tryStoreMergeOfConstants(
21712     SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumConsecutiveStores,
21713     EVT MemVT, SDNode *RootNode, bool AllowVectors) {
21714   LLVMContext &Context = *DAG.getContext();
21715   const DataLayout &DL = DAG.getDataLayout();
21716   int64_t ElementSizeBytes = MemVT.getStoreSize();
21717   unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
21718   bool MadeChange = false;
21719 
21720   // Store the constants into memory as one consecutive store.
21721   while (NumConsecutiveStores >= 2) {
21722     LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
21723     unsigned FirstStoreAS = FirstInChain->getAddressSpace();
21724     Align FirstStoreAlign = FirstInChain->getAlign();
21725     unsigned LastLegalType = 1;
21726     unsigned LastLegalVectorType = 1;
21727     bool LastIntegerTrunc = false;
21728     bool NonZero = false;
21729     unsigned FirstZeroAfterNonZero = NumConsecutiveStores;
21730     for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
21731       StoreSDNode *ST = cast<StoreSDNode>(StoreNodes[i].MemNode);
21732       SDValue StoredVal = ST->getValue();
21733       bool IsElementZero = false;
21734       if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(StoredVal))
21735         IsElementZero = C->isZero();
21736       else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(StoredVal))
21737         IsElementZero = C->getConstantFPValue()->isNullValue();
21738       else if (ISD::isBuildVectorAllZeros(StoredVal.getNode()))
21739         IsElementZero = true;
21740       if (IsElementZero) {
21741         if (NonZero && FirstZeroAfterNonZero == NumConsecutiveStores)
21742           FirstZeroAfterNonZero = i;
21743       }
21744       NonZero |= !IsElementZero;
21745 
21746       // Find a legal type for the constant store.
21747       unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8;
21748       EVT StoreTy = EVT::getIntegerVT(Context, SizeInBits);
21749       unsigned IsFast = 0;
21750 
21751       // Break early when size is too large to be legal.
21752       if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits)
21753         break;
21754 
21755       if (TLI.isTypeLegal(StoreTy) &&
21756           TLI.canMergeStoresTo(FirstStoreAS, StoreTy,
21757                                DAG.getMachineFunction()) &&
21758           TLI.allowsMemoryAccess(Context, DL, StoreTy,
21759                                  *FirstInChain->getMemOperand(), &IsFast) &&
21760           IsFast) {
21761         LastIntegerTrunc = false;
21762         LastLegalType = i + 1;
21763         // Or check whether a truncstore is legal.
21764       } else if (TLI.getTypeAction(Context, StoreTy) ==
21765                  TargetLowering::TypePromoteInteger) {
21766         EVT LegalizedStoredValTy =
21767             TLI.getTypeToTransformTo(Context, StoredVal.getValueType());
21768         if (TLI.isTruncStoreLegal(LegalizedStoredValTy, StoreTy) &&
21769             TLI.canMergeStoresTo(FirstStoreAS, LegalizedStoredValTy,
21770                                  DAG.getMachineFunction()) &&
21771             TLI.allowsMemoryAccess(Context, DL, StoreTy,
21772                                    *FirstInChain->getMemOperand(), &IsFast) &&
21773             IsFast) {
21774           LastIntegerTrunc = true;
21775           LastLegalType = i + 1;
21776         }
21777       }
21778 
21779       // We only use vectors if the target allows it and the function is not
21780       // marked with the noimplicitfloat attribute.
21781       if (TLI.storeOfVectorConstantIsCheap(!NonZero, MemVT, i + 1, FirstStoreAS) &&
21782           AllowVectors) {
21783         // Find a legal type for the vector store.
21784         unsigned Elts = (i + 1) * NumMemElts;
21785         EVT Ty = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
21786         if (TLI.isTypeLegal(Ty) && TLI.isTypeLegal(MemVT) &&
21787             TLI.canMergeStoresTo(FirstStoreAS, Ty, DAG.getMachineFunction()) &&
21788             TLI.allowsMemoryAccess(Context, DL, Ty,
21789                                    *FirstInChain->getMemOperand(), &IsFast) &&
21790             IsFast)
21791           LastLegalVectorType = i + 1;
21792       }
21793     }
21794 
21795     bool UseVector = (LastLegalVectorType > LastLegalType) && AllowVectors;
21796     unsigned NumElem = (UseVector) ? LastLegalVectorType : LastLegalType;
21797     bool UseTrunc = LastIntegerTrunc && !UseVector;
21798 
21799     // Check if we found a legal integer type that creates a meaningful
21800     // merge.
21801     if (NumElem < 2) {
21802       // We know that candidate stores are in order and of correct
21803       // shape. While there is no mergeable sequence from the
21804       // beginning one may start later in the sequence. The only
21805       // reason a merge of size N could have failed where another of
21806       // the same size would not have, is if the alignment has
21807       // improved or we've dropped a non-zero value. Drop as many
21808       // candidates as we can here.
21809       unsigned NumSkip = 1;
21810       while ((NumSkip < NumConsecutiveStores) &&
21811              (NumSkip < FirstZeroAfterNonZero) &&
21812              (StoreNodes[NumSkip].MemNode->getAlign() <= FirstStoreAlign))
21813         NumSkip++;
21814 
21815       StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip);
21816       NumConsecutiveStores -= NumSkip;
21817       continue;
21818     }
21819 
21820     // Check that we can merge these candidates without causing a cycle.
21821     if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumElem,
21822                                                   RootNode)) {
21823       StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
21824       NumConsecutiveStores -= NumElem;
21825       continue;
21826     }
21827 
21828     MadeChange |= mergeStoresOfConstantsOrVecElts(StoreNodes, MemVT, NumElem,
21829                                                   /*IsConstantSrc*/ true,
21830                                                   UseVector, UseTrunc);
21831 
21832     // Remove merged stores for next iteration.
21833     StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
21834     NumConsecutiveStores -= NumElem;
21835   }
21836   return MadeChange;
21837 }
21838 
tryStoreMergeOfExtracts(SmallVectorImpl<MemOpLink> & StoreNodes,unsigned NumConsecutiveStores,EVT MemVT,SDNode * RootNode)21839 bool DAGCombiner::tryStoreMergeOfExtracts(
21840     SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumConsecutiveStores,
21841     EVT MemVT, SDNode *RootNode) {
21842   LLVMContext &Context = *DAG.getContext();
21843   const DataLayout &DL = DAG.getDataLayout();
21844   unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
21845   bool MadeChange = false;
21846 
21847   // Loop on Consecutive Stores on success.
21848   while (NumConsecutiveStores >= 2) {
21849     LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
21850     unsigned FirstStoreAS = FirstInChain->getAddressSpace();
21851     Align FirstStoreAlign = FirstInChain->getAlign();
21852     unsigned NumStoresToMerge = 1;
21853     for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
21854       // Find a legal type for the vector store.
21855       unsigned Elts = (i + 1) * NumMemElts;
21856       EVT Ty = EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts);
21857       unsigned IsFast = 0;
21858 
21859       // Break early when size is too large to be legal.
21860       if (Ty.getSizeInBits() > MaximumLegalStoreInBits)
21861         break;
21862 
21863       if (TLI.isTypeLegal(Ty) &&
21864           TLI.canMergeStoresTo(FirstStoreAS, Ty, DAG.getMachineFunction()) &&
21865           TLI.allowsMemoryAccess(Context, DL, Ty,
21866                                  *FirstInChain->getMemOperand(), &IsFast) &&
21867           IsFast)
21868         NumStoresToMerge = i + 1;
21869     }
21870 
21871     // Check if we found a legal integer type creating a meaningful
21872     // merge.
21873     if (NumStoresToMerge < 2) {
21874       // We know that candidate stores are in order and of correct
21875       // shape. While there is no mergeable sequence from the
21876       // beginning one may start later in the sequence. The only
21877       // reason a merge of size N could have failed where another of
21878       // the same size would not have, is if the alignment has
21879       // improved. Drop as many candidates as we can here.
21880       unsigned NumSkip = 1;
21881       while ((NumSkip < NumConsecutiveStores) &&
21882              (StoreNodes[NumSkip].MemNode->getAlign() <= FirstStoreAlign))
21883         NumSkip++;
21884 
21885       StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip);
21886       NumConsecutiveStores -= NumSkip;
21887       continue;
21888     }
21889 
21890     // Check that we can merge these candidates without causing a cycle.
21891     if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumStoresToMerge,
21892                                                   RootNode)) {
21893       StoreNodes.erase(StoreNodes.begin(),
21894                        StoreNodes.begin() + NumStoresToMerge);
21895       NumConsecutiveStores -= NumStoresToMerge;
21896       continue;
21897     }
21898 
21899     MadeChange |= mergeStoresOfConstantsOrVecElts(
21900         StoreNodes, MemVT, NumStoresToMerge, /*IsConstantSrc*/ false,
21901         /*UseVector*/ true, /*UseTrunc*/ false);
21902 
21903     StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumStoresToMerge);
21904     NumConsecutiveStores -= NumStoresToMerge;
21905   }
21906   return MadeChange;
21907 }
21908 
tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> & StoreNodes,unsigned NumConsecutiveStores,EVT MemVT,SDNode * RootNode,bool AllowVectors,bool IsNonTemporalStore,bool IsNonTemporalLoad)21909 bool DAGCombiner::tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes,
21910                                        unsigned NumConsecutiveStores, EVT MemVT,
21911                                        SDNode *RootNode, bool AllowVectors,
21912                                        bool IsNonTemporalStore,
21913                                        bool IsNonTemporalLoad) {
21914   LLVMContext &Context = *DAG.getContext();
21915   const DataLayout &DL = DAG.getDataLayout();
21916   int64_t ElementSizeBytes = MemVT.getStoreSize();
21917   unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
21918   bool MadeChange = false;
21919 
21920   // Look for load nodes which are used by the stored values.
21921   SmallVector<MemOpLink, 8> LoadNodes;
21922 
21923   // Find acceptable loads. Loads need to have the same chain (token factor),
21924   // must not be zext, volatile, indexed, and they must be consecutive.
21925   BaseIndexOffset LdBasePtr;
21926 
21927   for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
21928     StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
21929     SDValue Val = peekThroughBitcasts(St->getValue());
21930     LoadSDNode *Ld = cast<LoadSDNode>(Val);
21931 
21932     BaseIndexOffset LdPtr = BaseIndexOffset::match(Ld, DAG);
21933     // If this is not the first ptr that we check.
21934     int64_t LdOffset = 0;
21935     if (LdBasePtr.getBase().getNode()) {
21936       // The base ptr must be the same.
21937       if (!LdBasePtr.equalBaseIndex(LdPtr, DAG, LdOffset))
21938         break;
21939     } else {
21940       // Check that all other base pointers are the same as this one.
21941       LdBasePtr = LdPtr;
21942     }
21943 
21944     // We found a potential memory operand to merge.
21945     LoadNodes.push_back(MemOpLink(Ld, LdOffset));
21946   }
21947 
21948   while (NumConsecutiveStores >= 2 && LoadNodes.size() >= 2) {
21949     Align RequiredAlignment;
21950     bool NeedRotate = false;
21951     if (LoadNodes.size() == 2) {
21952       // If we have load/store pair instructions and we only have two values,
21953       // don't bother merging.
21954       if (TLI.hasPairedLoad(MemVT, RequiredAlignment) &&
21955           StoreNodes[0].MemNode->getAlign() >= RequiredAlignment) {
21956         StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + 2);
21957         LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + 2);
21958         break;
21959       }
21960       // If the loads are reversed, see if we can rotate the halves into place.
21961       int64_t Offset0 = LoadNodes[0].OffsetFromBase;
21962       int64_t Offset1 = LoadNodes[1].OffsetFromBase;
21963       EVT PairVT = EVT::getIntegerVT(Context, ElementSizeBytes * 8 * 2);
21964       if (Offset0 - Offset1 == ElementSizeBytes &&
21965           (hasOperation(ISD::ROTL, PairVT) ||
21966            hasOperation(ISD::ROTR, PairVT))) {
21967         std::swap(LoadNodes[0], LoadNodes[1]);
21968         NeedRotate = true;
21969       }
21970     }
21971     LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
21972     unsigned FirstStoreAS = FirstInChain->getAddressSpace();
21973     Align FirstStoreAlign = FirstInChain->getAlign();
21974     LoadSDNode *FirstLoad = cast<LoadSDNode>(LoadNodes[0].MemNode);
21975 
21976     // Scan the memory operations on the chain and find the first
21977     // non-consecutive load memory address. These variables hold the index in
21978     // the store node array.
21979 
21980     unsigned LastConsecutiveLoad = 1;
21981 
21982     // This variable refers to the size and not index in the array.
21983     unsigned LastLegalVectorType = 1;
21984     unsigned LastLegalIntegerType = 1;
21985     bool isDereferenceable = true;
21986     bool DoIntegerTruncate = false;
21987     int64_t StartAddress = LoadNodes[0].OffsetFromBase;
21988     SDValue LoadChain = FirstLoad->getChain();
21989     for (unsigned i = 1; i < LoadNodes.size(); ++i) {
21990       // All loads must share the same chain.
21991       if (LoadNodes[i].MemNode->getChain() != LoadChain)
21992         break;
21993 
21994       int64_t CurrAddress = LoadNodes[i].OffsetFromBase;
21995       if (CurrAddress - StartAddress != (ElementSizeBytes * i))
21996         break;
21997       LastConsecutiveLoad = i;
21998 
21999       if (isDereferenceable && !LoadNodes[i].MemNode->isDereferenceable())
22000         isDereferenceable = false;
22001 
22002       // Find a legal type for the vector store.
22003       unsigned Elts = (i + 1) * NumMemElts;
22004       EVT StoreTy = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
22005 
22006       // Break early when size is too large to be legal.
22007       if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits)
22008         break;
22009 
22010       unsigned IsFastSt = 0;
22011       unsigned IsFastLd = 0;
22012       // Don't try vector types if we need a rotate. We may still fail the
22013       // legality checks for the integer type, but we can't handle the rotate
22014       // case with vectors.
22015       // FIXME: We could use a shuffle in place of the rotate.
22016       if (!NeedRotate && TLI.isTypeLegal(StoreTy) &&
22017           TLI.canMergeStoresTo(FirstStoreAS, StoreTy,
22018                                DAG.getMachineFunction()) &&
22019           TLI.allowsMemoryAccess(Context, DL, StoreTy,
22020                                  *FirstInChain->getMemOperand(), &IsFastSt) &&
22021           IsFastSt &&
22022           TLI.allowsMemoryAccess(Context, DL, StoreTy,
22023                                  *FirstLoad->getMemOperand(), &IsFastLd) &&
22024           IsFastLd) {
22025         LastLegalVectorType = i + 1;
22026       }
22027 
22028       // Find a legal type for the integer store.
22029       unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8;
22030       StoreTy = EVT::getIntegerVT(Context, SizeInBits);
22031       if (TLI.isTypeLegal(StoreTy) &&
22032           TLI.canMergeStoresTo(FirstStoreAS, StoreTy,
22033                                DAG.getMachineFunction()) &&
22034           TLI.allowsMemoryAccess(Context, DL, StoreTy,
22035                                  *FirstInChain->getMemOperand(), &IsFastSt) &&
22036           IsFastSt &&
22037           TLI.allowsMemoryAccess(Context, DL, StoreTy,
22038                                  *FirstLoad->getMemOperand(), &IsFastLd) &&
22039           IsFastLd) {
22040         LastLegalIntegerType = i + 1;
22041         DoIntegerTruncate = false;
22042         // Or check whether a truncstore and extload is legal.
22043       } else if (TLI.getTypeAction(Context, StoreTy) ==
22044                  TargetLowering::TypePromoteInteger) {
22045         EVT LegalizedStoredValTy = TLI.getTypeToTransformTo(Context, StoreTy);
22046         if (TLI.isTruncStoreLegal(LegalizedStoredValTy, StoreTy) &&
22047             TLI.canMergeStoresTo(FirstStoreAS, LegalizedStoredValTy,
22048                                  DAG.getMachineFunction()) &&
22049             TLI.isLoadExtLegal(ISD::ZEXTLOAD, LegalizedStoredValTy, StoreTy) &&
22050             TLI.isLoadExtLegal(ISD::SEXTLOAD, LegalizedStoredValTy, StoreTy) &&
22051             TLI.isLoadExtLegal(ISD::EXTLOAD, LegalizedStoredValTy, StoreTy) &&
22052             TLI.allowsMemoryAccess(Context, DL, StoreTy,
22053                                    *FirstInChain->getMemOperand(), &IsFastSt) &&
22054             IsFastSt &&
22055             TLI.allowsMemoryAccess(Context, DL, StoreTy,
22056                                    *FirstLoad->getMemOperand(), &IsFastLd) &&
22057             IsFastLd) {
22058           LastLegalIntegerType = i + 1;
22059           DoIntegerTruncate = true;
22060         }
22061       }
22062     }
22063 
22064     // Only use vector types if the vector type is larger than the integer
22065     // type. If they are the same, use integers.
22066     bool UseVectorTy =
22067         LastLegalVectorType > LastLegalIntegerType && AllowVectors;
22068     unsigned LastLegalType =
22069         std::max(LastLegalVectorType, LastLegalIntegerType);
22070 
22071     // We add +1 here because the LastXXX variables refer to location while
22072     // the NumElem refers to array/index size.
22073     unsigned NumElem = std::min(NumConsecutiveStores, LastConsecutiveLoad + 1);
22074     NumElem = std::min(LastLegalType, NumElem);
22075     Align FirstLoadAlign = FirstLoad->getAlign();
22076 
22077     if (NumElem < 2) {
22078       // We know that candidate stores are in order and of correct
22079       // shape. While there is no mergeable sequence from the
22080       // beginning one may start later in the sequence. The only
22081       // reason a merge of size N could have failed where another of
22082       // the same size would not have is if the alignment or either
22083       // the load or store has improved. Drop as many candidates as we
22084       // can here.
22085       unsigned NumSkip = 1;
22086       while ((NumSkip < LoadNodes.size()) &&
22087              (LoadNodes[NumSkip].MemNode->getAlign() <= FirstLoadAlign) &&
22088              (StoreNodes[NumSkip].MemNode->getAlign() <= FirstStoreAlign))
22089         NumSkip++;
22090       StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip);
22091       LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumSkip);
22092       NumConsecutiveStores -= NumSkip;
22093       continue;
22094     }
22095 
22096     // Check that we can merge these candidates without causing a cycle.
22097     if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumElem,
22098                                                   RootNode)) {
22099       StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
22100       LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem);
22101       NumConsecutiveStores -= NumElem;
22102       continue;
22103     }
22104 
22105     // Find if it is better to use vectors or integers to load and store
22106     // to memory.
22107     EVT JointMemOpVT;
22108     if (UseVectorTy) {
22109       // Find a legal type for the vector store.
22110       unsigned Elts = NumElem * NumMemElts;
22111       JointMemOpVT = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
22112     } else {
22113       unsigned SizeInBits = NumElem * ElementSizeBytes * 8;
22114       JointMemOpVT = EVT::getIntegerVT(Context, SizeInBits);
22115     }
22116 
22117     // Check if there is a call in the load/store chain.
22118     if (!TLI.shouldMergeStoreOfLoadsOverCall(MemVT, JointMemOpVT) &&
22119         hasCallInLdStChain(cast<StoreSDNode>(StoreNodes[0].MemNode),
22120                            cast<LoadSDNode>(LoadNodes[0].MemNode))) {
22121       StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
22122       LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem);
22123       NumConsecutiveStores -= NumElem;
22124       continue;
22125     }
22126 
22127     SDLoc LoadDL(LoadNodes[0].MemNode);
22128     SDLoc StoreDL(StoreNodes[0].MemNode);
22129 
22130     // The merged loads are required to have the same incoming chain, so
22131     // using the first's chain is acceptable.
22132 
22133     SDValue NewStoreChain = getMergeStoreChains(StoreNodes, NumElem);
22134     bool CanReusePtrInfo = hasSameUnderlyingObj(StoreNodes);
22135     AddToWorklist(NewStoreChain.getNode());
22136 
22137     MachineMemOperand::Flags LdMMOFlags =
22138         isDereferenceable ? MachineMemOperand::MODereferenceable
22139                           : MachineMemOperand::MONone;
22140     if (IsNonTemporalLoad)
22141       LdMMOFlags |= MachineMemOperand::MONonTemporal;
22142 
22143     LdMMOFlags |= TLI.getTargetMMOFlags(*FirstLoad);
22144 
22145     MachineMemOperand::Flags StMMOFlags = IsNonTemporalStore
22146                                               ? MachineMemOperand::MONonTemporal
22147                                               : MachineMemOperand::MONone;
22148 
22149     StMMOFlags |= TLI.getTargetMMOFlags(*StoreNodes[0].MemNode);
22150 
22151     SDValue NewLoad, NewStore;
22152     if (UseVectorTy || !DoIntegerTruncate) {
22153       NewLoad = DAG.getLoad(
22154           JointMemOpVT, LoadDL, FirstLoad->getChain(), FirstLoad->getBasePtr(),
22155           FirstLoad->getPointerInfo(), FirstLoadAlign, LdMMOFlags);
22156       SDValue StoreOp = NewLoad;
22157       if (NeedRotate) {
22158         unsigned LoadWidth = ElementSizeBytes * 8 * 2;
22159         assert(JointMemOpVT == EVT::getIntegerVT(Context, LoadWidth) &&
22160                "Unexpected type for rotate-able load pair");
22161         SDValue RotAmt =
22162             DAG.getShiftAmountConstant(LoadWidth / 2, JointMemOpVT, LoadDL);
22163         // Target can convert to the identical ROTR if it does not have ROTL.
22164         StoreOp = DAG.getNode(ISD::ROTL, LoadDL, JointMemOpVT, NewLoad, RotAmt);
22165       }
22166       NewStore = DAG.getStore(
22167           NewStoreChain, StoreDL, StoreOp, FirstInChain->getBasePtr(),
22168           CanReusePtrInfo ? FirstInChain->getPointerInfo()
22169                           : MachinePointerInfo(FirstStoreAS),
22170           FirstStoreAlign, StMMOFlags);
22171     } else { // This must be the truncstore/extload case
22172       EVT ExtendedTy =
22173           TLI.getTypeToTransformTo(*DAG.getContext(), JointMemOpVT);
22174       NewLoad = DAG.getExtLoad(ISD::EXTLOAD, LoadDL, ExtendedTy,
22175                                FirstLoad->getChain(), FirstLoad->getBasePtr(),
22176                                FirstLoad->getPointerInfo(), JointMemOpVT,
22177                                FirstLoadAlign, LdMMOFlags);
22178       NewStore = DAG.getTruncStore(
22179           NewStoreChain, StoreDL, NewLoad, FirstInChain->getBasePtr(),
22180           CanReusePtrInfo ? FirstInChain->getPointerInfo()
22181                           : MachinePointerInfo(FirstStoreAS),
22182           JointMemOpVT, FirstInChain->getAlign(),
22183           FirstInChain->getMemOperand()->getFlags());
22184     }
22185 
22186     // Transfer chain users from old loads to the new load.
22187     for (unsigned i = 0; i < NumElem; ++i) {
22188       LoadSDNode *Ld = cast<LoadSDNode>(LoadNodes[i].MemNode);
22189       DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1),
22190                                     SDValue(NewLoad.getNode(), 1));
22191     }
22192 
22193     // Replace all stores with the new store. Recursively remove corresponding
22194     // values if they are no longer used.
22195     for (unsigned i = 0; i < NumElem; ++i) {
22196       SDValue Val = StoreNodes[i].MemNode->getOperand(1);
22197       CombineTo(StoreNodes[i].MemNode, NewStore);
22198       if (Val->use_empty())
22199         recursivelyDeleteUnusedNodes(Val.getNode());
22200     }
22201 
22202     MadeChange = true;
22203     StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
22204     LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem);
22205     NumConsecutiveStores -= NumElem;
22206   }
22207   return MadeChange;
22208 }
22209 
mergeConsecutiveStores(StoreSDNode * St)22210 bool DAGCombiner::mergeConsecutiveStores(StoreSDNode *St) {
22211   if (OptLevel == CodeGenOptLevel::None || !EnableStoreMerging)
22212     return false;
22213 
22214   // TODO: Extend this function to merge stores of scalable vectors.
22215   // (i.e. two <vscale x 8 x i8> stores can be merged to one <vscale x 16 x i8>
22216   // store since we know <vscale x 16 x i8> is exactly twice as large as
22217   // <vscale x 8 x i8>). Until then, bail out for scalable vectors.
22218   EVT MemVT = St->getMemoryVT();
22219   if (MemVT.isScalableVT())
22220     return false;
22221   if (!MemVT.isSimple() || MemVT.getSizeInBits() * 2 > MaximumLegalStoreInBits)
22222     return false;
22223 
22224   // This function cannot currently deal with non-byte-sized memory sizes.
22225   int64_t ElementSizeBytes = MemVT.getStoreSize();
22226   if (ElementSizeBytes * 8 != (int64_t)MemVT.getSizeInBits())
22227     return false;
22228 
22229   // Do not bother looking at stored values that are not constants, loads, or
22230   // extracted vector elements.
22231   SDValue StoredVal = peekThroughBitcasts(St->getValue());
22232   const StoreSource StoreSrc = getStoreSource(StoredVal);
22233   if (StoreSrc == StoreSource::Unknown)
22234     return false;
22235 
22236   SmallVector<MemOpLink, 8> StoreNodes;
22237   // Find potential store merge candidates by searching through chain sub-DAG
22238   SDNode *RootNode = getStoreMergeCandidates(St, StoreNodes);
22239 
22240   // Check if there is anything to merge.
22241   if (StoreNodes.size() < 2)
22242     return false;
22243 
22244   // Sort the memory operands according to their distance from the
22245   // base pointer.
22246   llvm::sort(StoreNodes, [](MemOpLink LHS, MemOpLink RHS) {
22247     return LHS.OffsetFromBase < RHS.OffsetFromBase;
22248   });
22249 
22250   bool AllowVectors = !DAG.getMachineFunction().getFunction().hasFnAttribute(
22251       Attribute::NoImplicitFloat);
22252   bool IsNonTemporalStore = St->isNonTemporal();
22253   bool IsNonTemporalLoad = StoreSrc == StoreSource::Load &&
22254                            cast<LoadSDNode>(StoredVal)->isNonTemporal();
22255 
22256   // Store Merge attempts to merge the lowest stores. This generally
22257   // works out as if successful, as the remaining stores are checked
22258   // after the first collection of stores is merged. However, in the
22259   // case that a non-mergeable store is found first, e.g., {p[-2],
22260   // p[0], p[1], p[2], p[3]}, we would fail and miss the subsequent
22261   // mergeable cases. To prevent this, we prune such stores from the
22262   // front of StoreNodes here.
22263   bool MadeChange = false;
22264   while (StoreNodes.size() > 1) {
22265     unsigned NumConsecutiveStores =
22266         getConsecutiveStores(StoreNodes, ElementSizeBytes);
22267     // There are no more stores in the list to examine.
22268     if (NumConsecutiveStores == 0)
22269       return MadeChange;
22270 
22271     // We have at least 2 consecutive stores. Try to merge them.
22272     assert(NumConsecutiveStores >= 2 && "Expected at least 2 stores");
22273     switch (StoreSrc) {
22274     case StoreSource::Constant:
22275       MadeChange |= tryStoreMergeOfConstants(StoreNodes, NumConsecutiveStores,
22276                                              MemVT, RootNode, AllowVectors);
22277       break;
22278 
22279     case StoreSource::Extract:
22280       MadeChange |= tryStoreMergeOfExtracts(StoreNodes, NumConsecutiveStores,
22281                                             MemVT, RootNode);
22282       break;
22283 
22284     case StoreSource::Load:
22285       MadeChange |= tryStoreMergeOfLoads(StoreNodes, NumConsecutiveStores,
22286                                          MemVT, RootNode, AllowVectors,
22287                                          IsNonTemporalStore, IsNonTemporalLoad);
22288       break;
22289 
22290     default:
22291       llvm_unreachable("Unhandled store source type");
22292     }
22293   }
22294 
22295   // Remember if we failed to optimize, to save compile time.
22296   if (!MadeChange)
22297     ChainsWithoutMergeableStores.insert(RootNode);
22298 
22299   return MadeChange;
22300 }
22301 
replaceStoreChain(StoreSDNode * ST,SDValue BetterChain)22302 SDValue DAGCombiner::replaceStoreChain(StoreSDNode *ST, SDValue BetterChain) {
22303   SDLoc SL(ST);
22304   SDValue ReplStore;
22305 
22306   // Replace the chain to avoid dependency.
22307   if (ST->isTruncatingStore()) {
22308     ReplStore = DAG.getTruncStore(BetterChain, SL, ST->getValue(),
22309                                   ST->getBasePtr(), ST->getMemoryVT(),
22310                                   ST->getMemOperand());
22311   } else {
22312     ReplStore = DAG.getStore(BetterChain, SL, ST->getValue(), ST->getBasePtr(),
22313                              ST->getMemOperand());
22314   }
22315 
22316   // Create token to keep both nodes around.
22317   SDValue Token = DAG.getNode(ISD::TokenFactor, SL,
22318                               MVT::Other, ST->getChain(), ReplStore);
22319 
22320   // Make sure the new and old chains are cleaned up.
22321   AddToWorklist(Token.getNode());
22322 
22323   // Don't add users to work list.
22324   return CombineTo(ST, Token, false);
22325 }
22326 
replaceStoreOfFPConstant(StoreSDNode * ST)22327 SDValue DAGCombiner::replaceStoreOfFPConstant(StoreSDNode *ST) {
22328   SDValue Value = ST->getValue();
22329   if (Value.getOpcode() == ISD::TargetConstantFP)
22330     return SDValue();
22331 
22332   if (!ISD::isNormalStore(ST))
22333     return SDValue();
22334 
22335   SDLoc DL(ST);
22336 
22337   SDValue Chain = ST->getChain();
22338   SDValue Ptr = ST->getBasePtr();
22339 
22340   const ConstantFPSDNode *CFP = cast<ConstantFPSDNode>(Value);
22341 
22342   // NOTE: If the original store is volatile, this transform must not increase
22343   // the number of stores.  For example, on x86-32 an f64 can be stored in one
22344   // processor operation but an i64 (which is not legal) requires two.  So the
22345   // transform should not be done in this case.
22346 
22347   SDValue Tmp;
22348   switch (CFP->getSimpleValueType(0).SimpleTy) {
22349   default:
22350     llvm_unreachable("Unknown FP type");
22351   case MVT::f16:    // We don't do this for these yet.
22352   case MVT::bf16:
22353   case MVT::f80:
22354   case MVT::f128:
22355   case MVT::ppcf128:
22356     return SDValue();
22357   case MVT::f32:
22358     if ((isTypeLegal(MVT::i32) && !LegalOperations && ST->isSimple()) ||
22359         TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i32)) {
22360       Tmp = DAG.getConstant((uint32_t)CFP->getValueAPF().
22361                             bitcastToAPInt().getZExtValue(), SDLoc(CFP),
22362                             MVT::i32);
22363       return DAG.getStore(Chain, DL, Tmp, Ptr, ST->getMemOperand());
22364     }
22365 
22366     return SDValue();
22367   case MVT::f64:
22368     if ((TLI.isTypeLegal(MVT::i64) && !LegalOperations &&
22369          ST->isSimple()) ||
22370         TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i64)) {
22371       Tmp = DAG.getConstant(CFP->getValueAPF().bitcastToAPInt().
22372                             getZExtValue(), SDLoc(CFP), MVT::i64);
22373       return DAG.getStore(Chain, DL, Tmp,
22374                           Ptr, ST->getMemOperand());
22375     }
22376 
22377     if (ST->isSimple() && TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i32) &&
22378         !TLI.isFPImmLegal(CFP->getValueAPF(), MVT::f64)) {
22379       // Many FP stores are not made apparent until after legalize, e.g. for
22380       // argument passing.  Since this is so common, custom legalize the
22381       // 64-bit integer store into two 32-bit stores.
22382       uint64_t Val = CFP->getValueAPF().bitcastToAPInt().getZExtValue();
22383       SDValue Lo = DAG.getConstant(Val & 0xFFFFFFFF, SDLoc(CFP), MVT::i32);
22384       SDValue Hi = DAG.getConstant(Val >> 32, SDLoc(CFP), MVT::i32);
22385       if (DAG.getDataLayout().isBigEndian())
22386         std::swap(Lo, Hi);
22387 
22388       MachineMemOperand::Flags MMOFlags = ST->getMemOperand()->getFlags();
22389       AAMDNodes AAInfo = ST->getAAInfo();
22390 
22391       SDValue St0 = DAG.getStore(Chain, DL, Lo, Ptr, ST->getPointerInfo(),
22392                                  ST->getBaseAlign(), MMOFlags, AAInfo);
22393       Ptr = DAG.getMemBasePlusOffset(Ptr, TypeSize::getFixed(4), DL);
22394       SDValue St1 = DAG.getStore(Chain, DL, Hi, Ptr,
22395                                  ST->getPointerInfo().getWithOffset(4),
22396                                  ST->getBaseAlign(), MMOFlags, AAInfo);
22397       return DAG.getNode(ISD::TokenFactor, DL, MVT::Other,
22398                          St0, St1);
22399     }
22400 
22401     return SDValue();
22402   }
22403 }
22404 
22405 // (store (insert_vector_elt (load p), x, i), p) -> (store x, p+offset)
22406 //
22407 // If a store of a load with an element inserted into it has no other
22408 // uses in between the chain, then we can consider the vector store
22409 // dead and replace it with just the single scalar element store.
replaceStoreOfInsertLoad(StoreSDNode * ST)22410 SDValue DAGCombiner::replaceStoreOfInsertLoad(StoreSDNode *ST) {
22411   SDLoc DL(ST);
22412   SDValue Value = ST->getValue();
22413   SDValue Ptr = ST->getBasePtr();
22414   SDValue Chain = ST->getChain();
22415   if (Value.getOpcode() != ISD::INSERT_VECTOR_ELT || !Value.hasOneUse())
22416     return SDValue();
22417 
22418   SDValue Elt = Value.getOperand(1);
22419   SDValue Idx = Value.getOperand(2);
22420 
22421   // If the element isn't byte sized or is implicitly truncated then we can't
22422   // compute an offset.
22423   EVT EltVT = Elt.getValueType();
22424   if (!EltVT.isByteSized() ||
22425       EltVT != Value.getOperand(0).getValueType().getVectorElementType())
22426     return SDValue();
22427 
22428   auto *Ld = dyn_cast<LoadSDNode>(Value.getOperand(0));
22429   if (!Ld || Ld->getBasePtr() != Ptr ||
22430       ST->getMemoryVT() != Ld->getMemoryVT() || !ST->isSimple() ||
22431       !ISD::isNormalStore(ST) ||
22432       Ld->getAddressSpace() != ST->getAddressSpace() ||
22433       !Chain.reachesChainWithoutSideEffects(SDValue(Ld, 1)))
22434     return SDValue();
22435 
22436   unsigned IsFast;
22437   if (!TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(),
22438                               Elt.getValueType(), ST->getAddressSpace(),
22439                               ST->getAlign(), ST->getMemOperand()->getFlags(),
22440                               &IsFast) ||
22441       !IsFast)
22442     return SDValue();
22443 
22444   MachinePointerInfo PointerInfo(ST->getAddressSpace());
22445 
22446   // If the offset is a known constant then try to recover the pointer
22447   // info
22448   SDValue NewPtr;
22449   if (auto *CIdx = dyn_cast<ConstantSDNode>(Idx)) {
22450     unsigned COffset = CIdx->getSExtValue() * EltVT.getSizeInBits() / 8;
22451     NewPtr = DAG.getMemBasePlusOffset(Ptr, TypeSize::getFixed(COffset), DL);
22452     PointerInfo = ST->getPointerInfo().getWithOffset(COffset);
22453   } else {
22454     NewPtr = TLI.getVectorElementPointer(DAG, Ptr, Value.getValueType(), Idx);
22455   }
22456 
22457   return DAG.getStore(Chain, DL, Elt, NewPtr, PointerInfo, ST->getAlign(),
22458                       ST->getMemOperand()->getFlags());
22459 }
22460 
visitATOMIC_STORE(SDNode * N)22461 SDValue DAGCombiner::visitATOMIC_STORE(SDNode *N) {
22462   AtomicSDNode *ST = cast<AtomicSDNode>(N);
22463   SDValue Val = ST->getVal();
22464   EVT VT = Val.getValueType();
22465   EVT MemVT = ST->getMemoryVT();
22466 
22467   if (MemVT.bitsLT(VT)) { // Is truncating store
22468     APInt TruncDemandedBits = APInt::getLowBitsSet(VT.getScalarSizeInBits(),
22469                                                    MemVT.getScalarSizeInBits());
22470     // See if we can simplify the operation with SimplifyDemandedBits, which
22471     // only works if the value has a single use.
22472     if (SimplifyDemandedBits(Val, TruncDemandedBits))
22473       return SDValue(N, 0);
22474   }
22475 
22476   return SDValue();
22477 }
22478 
/// Combine a STORE node. The folds below are tried in order and the first one
/// that applies returns. Some rewrite N in place or delete an earlier store
/// via CombineTo and then return SDValue(N, 0). A store proven dead returns
/// its incoming chain. An empty SDValue means nothing applied.
SDValue DAGCombiner::visitSTORE(SDNode *N) {
  StoreSDNode *ST  = cast<StoreSDNode>(N);
  SDValue Chain = ST->getChain();
  SDValue Value = ST->getValue();
  SDValue Ptr   = ST->getBasePtr();

  // If this is a store of a bitcast, store the bitcast's input directly
  // (reusing the same memory operand) when the target reports that as
  // beneficial and the new store type is acceptable (see below).
  if (Value.getOpcode() == ISD::BITCAST && !ST->isTruncatingStore() &&
      ST->isUnindexed()) {
    EVT SVT = Value.getOperand(0).getValueType();
    // If the store is volatile, we only want to change the store type if the
    // resulting store is legal. Otherwise we might increase the number of
    // memory accesses. We don't care if the original type was legal or not
    // as we assume software couldn't rely on the number of accesses of an
    // illegal type.
    // TODO: May be able to relax for unordered atomics (see D66309)
    if (((!LegalOperations && ST->isSimple()) ||
         TLI.isOperationLegal(ISD::STORE, SVT)) &&
        TLI.isStoreBitCastBeneficial(Value.getValueType(), SVT,
                                     DAG, *ST->getMemOperand())) {
      return DAG.getStore(Chain, SDLoc(N), Value.getOperand(0), Ptr,
                          ST->getMemOperand());
    }
  }

  // Turn 'store undef, Ptr' -> nothing.
  if (Value.isUndef() && ST->isUnindexed() && !ST->isVolatile())
    return Chain;

  // Try to infer better alignment information than the store already has.
  if (OptLevel != CodeGenOptLevel::None && ST->isUnindexed() &&
      !ST->isAtomic()) {
    if (MaybeAlign Alignment = DAG.InferPtrAlign(Ptr)) {
      if (*Alignment > ST->getAlign() &&
          isAligned(*Alignment, ST->getSrcValueOffset())) {
        SDValue NewStore =
            DAG.getTruncStore(Chain, SDLoc(N), Value, Ptr, ST->getPointerInfo(),
                              ST->getMemoryVT(), *Alignment,
                              ST->getMemOperand()->getFlags(), ST->getAAInfo());
        // NewStore will always be N as we are only refining the alignment
        assert(NewStore.getNode() == N);
        (void)NewStore;
      }
    }
  }

  // Try transforming a pair floating point load / store ops to integer
  // load / store ops.
  if (SDValue NewST = TransformFPLoadStorePair(N))
    return NewST;

  // Try transforming several stores into STORE (BSWAP).
  if (SDValue Store = mergeTruncStores(ST))
    return Store;

  if (ST->isUnindexed()) {
    // Walk up chain skipping non-aliasing memory nodes, on this store and any
    // adjacent stores.
    if (findBetterNeighborChains(ST)) {
      // replaceStoreChain uses CombineTo, which handled all of the worklist
      // manipulation. Return the original node to not do anything else.
      return SDValue(ST, 0);
    }
    // The chain may have been updated above; re-read it.
    Chain = ST->getChain();
  }

  // FIXME: is there such a thing as a truncating indexed store?
  if (ST->isTruncatingStore() && ST->isUnindexed() &&
      Value.getValueType().isInteger() &&
      (!isa<ConstantSDNode>(Value) ||
       !cast<ConstantSDNode>(Value)->isOpaque())) {
    // Convert a truncating store of an extension (from exactly the memory
    // type) into a standard store of the unextended value.
    if ((Value.getOpcode() == ISD::ZERO_EXTEND ||
         Value.getOpcode() == ISD::SIGN_EXTEND ||
         Value.getOpcode() == ISD::ANY_EXTEND) &&
        Value.getOperand(0).getValueType() == ST->getMemoryVT() &&
        TLI.isOperationLegalOrCustom(ISD::STORE, ST->getMemoryVT()))
      return DAG.getStore(Chain, SDLoc(N), Value.getOperand(0), Ptr,
                          ST->getMemOperand());

    // Only the low memory-width bits of each scalar are written.
    APInt TruncDemandedBits =
        APInt::getLowBitsSet(Value.getScalarValueSizeInBits(),
                             ST->getMemoryVT().getScalarSizeInBits());

    // See if we can simplify the operation with SimplifyDemandedBits, which
    // only works if the value has a single use.
    AddToWorklist(Value.getNode());
    if (SimplifyDemandedBits(Value, TruncDemandedBits)) {
      // Re-visit the store if anything changed and the store hasn't been
      // merged with another node (i.e. N is not deleted). SimplifyDemandedBits
      // adds Value's node back to the worklist if necessary, but we also need
      // to re-visit the store node itself.
      if (N->getOpcode() != ISD::DELETED_NODE)
        AddToWorklist(N);
      return SDValue(N, 0);
    }

    // Otherwise, see if we can simplify the input to this truncstore with
    // knowledge that only the low bits are being used.  For example:
    // "truncstore (or (shl x, 8), y), i8"  -> "truncstore y, i8"
    if (SDValue Shorter =
            TLI.SimplifyMultipleUseDemandedBits(Value, TruncDemandedBits, DAG))
      return DAG.getTruncStore(Chain, SDLoc(N), Shorter, Ptr, ST->getMemoryVT(),
                               ST->getMemOperand());

    // If we're storing a truncated constant, see if we can simplify it by
    // clearing the bits that are never written.
    // TODO: Move this to targetShrinkDemandedConstant?
    if (auto *Cst = dyn_cast<ConstantSDNode>(Value))
      if (!Cst->isOpaque()) {
        const APInt &CValue = Cst->getAPIntValue();
        APInt NewVal = CValue & TruncDemandedBits;
        if (NewVal != CValue) {
          SDValue Shorter =
              DAG.getConstant(NewVal, SDLoc(N), Value.getValueType());
          return DAG.getTruncStore(Chain, SDLoc(N), Shorter, Ptr,
                                   ST->getMemoryVT(), ST->getMemOperand());
        }
      }
  }

  // If this is a load followed by a store to the same location, then the store
  // is dead/noop. Peek through any truncates if canCombineTruncStore failed.
  // TODO: Add big-endian truncate support with test coverage.
  // TODO: Can relax for unordered atomics (see D66309)
  SDValue TruncVal = DAG.getDataLayout().isLittleEndian()
                         ? peekThroughTruncates(Value)
                         : Value;
  if (auto *Ld = dyn_cast<LoadSDNode>(TruncVal)) {
    if (Ld->getBasePtr() == Ptr && ST->getMemoryVT() == Ld->getMemoryVT() &&
        ST->isUnindexed() && ST->isSimple() &&
        Ld->getAddressSpace() == ST->getAddressSpace() &&
        // There can't be any side effects between the load and store, such as
        // a call or store.
        Chain.reachesChainWithoutSideEffects(SDValue(Ld, 1))) {
      // The store is dead, remove it.
      return Chain;
    }
  }

  // Try scalarizing vector stores of loads where we only change one element
  if (SDValue NewST = replaceStoreOfInsertLoad(ST))
    return NewST;

  // Look at the store this one is directly chained to (if any).
  // TODO: Can relax for unordered atomics (see D66309)
  if (StoreSDNode *ST1 = dyn_cast<StoreSDNode>(Chain)) {
    if (ST->isUnindexed() && ST->isSimple() &&
        ST1->isUnindexed() && ST1->isSimple()) {
      if (OptLevel != CodeGenOptLevel::None && ST1->getBasePtr() == Ptr &&
          ST1->getValue() == Value && ST->getMemoryVT() == ST1->getMemoryVT() &&
          ST->getAddressSpace() == ST1->getAddressSpace()) {
        // If this is a store followed by a store with the same value to the
        // same location, then the store is dead/noop.
        return Chain;
      }

      if (OptLevel != CodeGenOptLevel::None && ST1->hasOneUse() &&
          !ST1->getBasePtr().isUndef() &&
          ST->getAddressSpace() == ST1->getAddressSpace()) {
        // If we consider two stores and one smaller in size is a scalable
        // vector type and another one a bigger size store with a fixed type,
        // then we could not allow the scalable store removal because we don't
        // know its final size in the end.
        if (ST->getMemoryVT().isScalableVector() ||
            ST1->getMemoryVT().isScalableVector()) {
          // Same base pointer and the earlier store is known to be no larger:
          // it is completely overwritten by this one.
          if (ST1->getBasePtr() == Ptr &&
              TypeSize::isKnownLE(ST1->getMemoryVT().getStoreSize(),
                                  ST->getMemoryVT().getStoreSize())) {
            CombineTo(ST1, ST1->getChain());
            return SDValue(N, 0);
          }
        } else {
          const BaseIndexOffset STBase = BaseIndexOffset::match(ST, DAG);
          const BaseIndexOffset ChainBase = BaseIndexOffset::match(ST1, DAG);
          // If the preceding store writes a subset of this store's location
          // and no other node is chained to it, that preceding store is dead
          // and can be dropped. Do not remove stores to undef as they may be
          // used as data sinks.
          if (STBase.contains(DAG, ST->getMemoryVT().getFixedSizeInBits(),
                              ChainBase,
                              ST1->getMemoryVT().getFixedSizeInBits())) {
            CombineTo(ST1, ST1->getChain());
            return SDValue(N, 0);
          }
        }
      }
    }
  }

  // If this is an FP_ROUND or TRUNC followed by a store, fold this into a
  // truncating store.  We can do this even if this is already a truncstore.
  if ((Value.getOpcode() == ISD::FP_ROUND ||
       Value.getOpcode() == ISD::TRUNCATE) &&
      Value->hasOneUse() && ST->isUnindexed() &&
      TLI.canCombineTruncStore(Value.getOperand(0).getValueType(),
                               ST->getMemoryVT(), LegalOperations)) {
    return DAG.getTruncStore(Chain, SDLoc(N), Value.getOperand(0),
                             Ptr, ST->getMemoryVT(), ST->getMemOperand());
  }

  // Always perform this optimization before types are legal. If the target
  // prefers, also try this after legalization to catch stores that were created
  // by intrinsics or other nodes.
  if (!LegalTypes || (TLI.mergeStoresAfterLegalization(ST->getMemoryVT()))) {
    while (true) {
      // There can be multiple store sequences on the same chain.
      // Keep trying to merge store sequences until we are unable to do so
      // or until we merge the last store on the chain.
      bool Changed = mergeConsecutiveStores(ST);
      if (!Changed) break;
      // Return N as merge only uses CombineTo and no worklist clean
      // up is necessary.
      if (N->getOpcode() == ISD::DELETED_NODE || !isa<StoreSDNode>(N))
        return SDValue(N, 0);
    }
  }

  // Try transforming N to an indexed store.
  if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
    return SDValue(N, 0);

  // Turn 'store float 1.0, Ptr' into a store of the same bits as an integer
  // (see replaceStoreOfFPConstant).
  //
  // Make sure to do this only after attempting to merge stores in order to
  //  avoid changing the types of some subset of stores due to visit order,
  //  preventing their merging.
  if (isa<ConstantFPSDNode>(ST->getValue())) {
    if (SDValue NewSt = replaceStoreOfFPConstant(ST))
      return NewSt;
  }

  // Try splitting a store of two OR-merged halves into two narrower stores.
  if (SDValue NewSt = splitMergedValStore(ST))
    return NewSt;

  return ReduceLoadOpStoreWidth(N);
}
22715 
visitLIFETIME_END(SDNode * N)22716 SDValue DAGCombiner::visitLIFETIME_END(SDNode *N) {
22717   const auto *LifetimeEnd = cast<LifetimeSDNode>(N);
22718   if (!LifetimeEnd->hasOffset())
22719     return SDValue();
22720 
22721   const BaseIndexOffset LifetimeEndBase(N->getOperand(1), SDValue(),
22722                                         LifetimeEnd->getOffset(), false);
22723 
22724   // We walk up the chains to find stores.
22725   SmallVector<SDValue, 8> Chains = {N->getOperand(0)};
22726   while (!Chains.empty()) {
22727     SDValue Chain = Chains.pop_back_val();
22728     if (!Chain.hasOneUse())
22729       continue;
22730     switch (Chain.getOpcode()) {
22731     case ISD::TokenFactor:
22732       for (unsigned Nops = Chain.getNumOperands(); Nops;)
22733         Chains.push_back(Chain.getOperand(--Nops));
22734       break;
22735     case ISD::LIFETIME_START:
22736     case ISD::LIFETIME_END:
22737       // We can forward past any lifetime start/end that can be proven not to
22738       // alias the node.
22739       if (!mayAlias(Chain.getNode(), N))
22740         Chains.push_back(Chain.getOperand(0));
22741       break;
22742     case ISD::STORE: {
22743       StoreSDNode *ST = dyn_cast<StoreSDNode>(Chain);
22744       // TODO: Can relax for unordered atomics (see D66309)
22745       if (!ST->isSimple() || ST->isIndexed())
22746         continue;
22747       const TypeSize StoreSize = ST->getMemoryVT().getStoreSize();
22748       // The bounds of a scalable store are not known until runtime, so this
22749       // store cannot be elided.
22750       if (StoreSize.isScalable())
22751         continue;
22752       const BaseIndexOffset StoreBase = BaseIndexOffset::match(ST, DAG);
22753       // If we store purely within object bounds just before its lifetime ends,
22754       // we can remove the store.
22755       if (LifetimeEndBase.contains(DAG, LifetimeEnd->getSize() * 8, StoreBase,
22756                                    StoreSize.getFixedValue() * 8)) {
22757         LLVM_DEBUG(dbgs() << "\nRemoving store:"; StoreBase.dump();
22758                    dbgs() << "\nwithin LIFETIME_END of : ";
22759                    LifetimeEndBase.dump(); dbgs() << "\n");
22760         CombineTo(ST, ST->getChain());
22761         return SDValue(N, 0);
22762       }
22763     }
22764     }
22765   }
22766   return SDValue();
22767 }
22768 
22769 /// For the instruction sequence of store below, F and I values
22770 /// are bundled together as an i64 value before being stored into memory.
22771 /// Sometimes it is more efficent to generate separate stores for F and I,
22772 /// which can remove the bitwise instructions or sink them to colder places.
22773 ///
22774 ///   (store (or (zext (bitcast F to i32) to i64),
22775 ///              (shl (zext I to i64), 32)), addr)  -->
22776 ///   (store F, addr) and (store I, addr+4)
22777 ///
22778 /// Similarly, splitting for other merged store can also be beneficial, like:
22779 /// For pair of {i32, i32}, i64 store --> two i32 stores.
22780 /// For pair of {i32, i16}, i64 store --> two i32 stores.
22781 /// For pair of {i16, i16}, i32 store --> two i16 stores.
22782 /// For pair of {i16, i8},  i32 store --> two i16 stores.
22783 /// For pair of {i8, i8},   i16 store --> two i8 stores.
22784 ///
22785 /// We allow each target to determine specifically which kind of splitting is
22786 /// supported.
22787 ///
22788 /// The store patterns are commonly seen from the simple code snippet below
22789 /// if only std::make_pair(...) is sroa transformed before inlined into hoo.
22790 ///   void goo(const std::pair<int, float> &);
22791 ///   hoo() {
22792 ///     ...
22793 ///     goo(std::make_pair(tmp, ftmp));
22794 ///     ...
22795 ///   }
22796 ///
splitMergedValStore(StoreSDNode * ST)22797 SDValue DAGCombiner::splitMergedValStore(StoreSDNode *ST) {
22798   if (OptLevel == CodeGenOptLevel::None)
22799     return SDValue();
22800 
22801   // Can't change the number of memory accesses for a volatile store or break
22802   // atomicity for an atomic one.
22803   if (!ST->isSimple())
22804     return SDValue();
22805 
22806   SDValue Val = ST->getValue();
22807   SDLoc DL(ST);
22808 
22809   // Match OR operand.
22810   if (!Val.getValueType().isScalarInteger() || Val.getOpcode() != ISD::OR)
22811     return SDValue();
22812 
22813   // Match SHL operand and get Lower and Higher parts of Val.
22814   SDValue Op1 = Val.getOperand(0);
22815   SDValue Op2 = Val.getOperand(1);
22816   SDValue Lo, Hi;
22817   if (Op1.getOpcode() != ISD::SHL) {
22818     std::swap(Op1, Op2);
22819     if (Op1.getOpcode() != ISD::SHL)
22820       return SDValue();
22821   }
22822   Lo = Op2;
22823   Hi = Op1.getOperand(0);
22824   if (!Op1.hasOneUse())
22825     return SDValue();
22826 
22827   // Match shift amount to HalfValBitSize.
22828   unsigned HalfValBitSize = Val.getValueSizeInBits() / 2;
22829   ConstantSDNode *ShAmt = dyn_cast<ConstantSDNode>(Op1.getOperand(1));
22830   if (!ShAmt || ShAmt->getAPIntValue() != HalfValBitSize)
22831     return SDValue();
22832 
22833   // Lo and Hi are zero-extended from int with size less equal than 32
22834   // to i64.
22835   if (Lo.getOpcode() != ISD::ZERO_EXTEND || !Lo.hasOneUse() ||
22836       !Lo.getOperand(0).getValueType().isScalarInteger() ||
22837       Lo.getOperand(0).getValueSizeInBits() > HalfValBitSize ||
22838       Hi.getOpcode() != ISD::ZERO_EXTEND || !Hi.hasOneUse() ||
22839       !Hi.getOperand(0).getValueType().isScalarInteger() ||
22840       Hi.getOperand(0).getValueSizeInBits() > HalfValBitSize)
22841     return SDValue();
22842 
22843   // Use the EVT of low and high parts before bitcast as the input
22844   // of target query.
22845   EVT LowTy = (Lo.getOperand(0).getOpcode() == ISD::BITCAST)
22846                   ? Lo.getOperand(0).getValueType()
22847                   : Lo.getValueType();
22848   EVT HighTy = (Hi.getOperand(0).getOpcode() == ISD::BITCAST)
22849                    ? Hi.getOperand(0).getValueType()
22850                    : Hi.getValueType();
22851   if (!TLI.isMultiStoresCheaperThanBitsMerge(LowTy, HighTy))
22852     return SDValue();
22853 
22854   // Start to split store.
22855   MachineMemOperand::Flags MMOFlags = ST->getMemOperand()->getFlags();
22856   AAMDNodes AAInfo = ST->getAAInfo();
22857 
22858   // Change the sizes of Lo and Hi's value types to HalfValBitSize.
22859   EVT VT = EVT::getIntegerVT(*DAG.getContext(), HalfValBitSize);
22860   Lo = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Lo.getOperand(0));
22861   Hi = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Hi.getOperand(0));
22862 
22863   SDValue Chain = ST->getChain();
22864   SDValue Ptr = ST->getBasePtr();
22865   // Lower value store.
22866   SDValue St0 = DAG.getStore(Chain, DL, Lo, Ptr, ST->getPointerInfo(),
22867                              ST->getBaseAlign(), MMOFlags, AAInfo);
22868   Ptr =
22869       DAG.getMemBasePlusOffset(Ptr, TypeSize::getFixed(HalfValBitSize / 8), DL);
22870   // Higher value store.
22871   SDValue St1 = DAG.getStore(
22872       St0, DL, Hi, Ptr, ST->getPointerInfo().getWithOffset(HalfValBitSize / 8),
22873       ST->getBaseAlign(), MMOFlags, AAInfo);
22874   return St1;
22875 }
22876 
22877 // Merge an insertion into an existing shuffle:
22878 // (insert_vector_elt (vector_shuffle X, Y, Mask),
22879 //                   .(extract_vector_elt X, N), InsIndex)
22880 //   --> (vector_shuffle X, Y, NewMask)
22881 //  and variations where shuffle operands may be CONCAT_VECTORS.
mergeEltWithShuffle(SDValue & X,SDValue & Y,ArrayRef<int> Mask,SmallVectorImpl<int> & NewMask,SDValue Elt,unsigned InsIndex)22882 static bool mergeEltWithShuffle(SDValue &X, SDValue &Y, ArrayRef<int> Mask,
22883                                 SmallVectorImpl<int> &NewMask, SDValue Elt,
22884                                 unsigned InsIndex) {
22885   if (Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
22886       !isa<ConstantSDNode>(Elt.getOperand(1)))
22887     return false;
22888 
22889   // Vec's operand 0 is using indices from 0 to N-1 and
22890   // operand 1 from N to 2N - 1, where N is the number of
22891   // elements in the vectors.
22892   SDValue InsertVal0 = Elt.getOperand(0);
22893   int ElementOffset = -1;
22894 
22895   // We explore the inputs of the shuffle in order to see if we find the
22896   // source of the extract_vector_elt. If so, we can use it to modify the
22897   // shuffle rather than perform an insert_vector_elt.
22898   SmallVector<std::pair<int, SDValue>, 8> ArgWorkList;
22899   ArgWorkList.emplace_back(Mask.size(), Y);
22900   ArgWorkList.emplace_back(0, X);
22901 
22902   while (!ArgWorkList.empty()) {
22903     int ArgOffset;
22904     SDValue ArgVal;
22905     std::tie(ArgOffset, ArgVal) = ArgWorkList.pop_back_val();
22906 
22907     if (ArgVal == InsertVal0) {
22908       ElementOffset = ArgOffset;
22909       break;
22910     }
22911 
22912     // Peek through concat_vector.
22913     if (ArgVal.getOpcode() == ISD::CONCAT_VECTORS) {
22914       int CurrentArgOffset =
22915           ArgOffset + ArgVal.getValueType().getVectorNumElements();
22916       int Step = ArgVal.getOperand(0).getValueType().getVectorNumElements();
22917       for (SDValue Op : reverse(ArgVal->ops())) {
22918         CurrentArgOffset -= Step;
22919         ArgWorkList.emplace_back(CurrentArgOffset, Op);
22920       }
22921 
22922       // Make sure we went through all the elements and did not screw up index
22923       // computation.
22924       assert(CurrentArgOffset == ArgOffset);
22925     }
22926   }
22927 
22928   // If we failed to find a match, see if we can replace an UNDEF shuffle
22929   // operand.
22930   if (ElementOffset == -1) {
22931     if (!Y.isUndef() || InsertVal0.getValueType() != Y.getValueType())
22932       return false;
22933     ElementOffset = Mask.size();
22934     Y = InsertVal0;
22935   }
22936 
22937   NewMask.assign(Mask.begin(), Mask.end());
22938   NewMask[InsIndex] = ElementOffset + Elt.getConstantOperandVal(1);
22939   assert(NewMask[InsIndex] < (int)(2 * Mask.size()) && NewMask[InsIndex] >= 0 &&
22940          "NewMask[InsIndex] is out of bound");
22941   return true;
22942 }
22943 
22944 // Merge an insertion into an existing shuffle:
22945 // (insert_vector_elt (vector_shuffle X, Y), (extract_vector_elt X, N),
22946 // InsIndex)
22947 //   --> (vector_shuffle X, Y) and variations where shuffle operands may be
22948 //   CONCAT_VECTORS.
mergeInsertEltWithShuffle(SDNode * N,unsigned InsIndex)22949 SDValue DAGCombiner::mergeInsertEltWithShuffle(SDNode *N, unsigned InsIndex) {
22950   assert(N->getOpcode() == ISD::INSERT_VECTOR_ELT &&
22951          "Expected extract_vector_elt");
22952   SDValue InsertVal = N->getOperand(1);
22953   SDValue Vec = N->getOperand(0);
22954 
22955   auto *SVN = dyn_cast<ShuffleVectorSDNode>(Vec);
22956   if (!SVN || !Vec.hasOneUse())
22957     return SDValue();
22958 
22959   ArrayRef<int> Mask = SVN->getMask();
22960   SDValue X = Vec.getOperand(0);
22961   SDValue Y = Vec.getOperand(1);
22962 
22963   SmallVector<int, 16> NewMask(Mask);
22964   if (mergeEltWithShuffle(X, Y, Mask, NewMask, InsertVal, InsIndex)) {
22965     SDValue LegalShuffle = TLI.buildLegalVectorShuffle(
22966         Vec.getValueType(), SDLoc(N), X, Y, NewMask, DAG);
22967     if (LegalShuffle)
22968       return LegalShuffle;
22969   }
22970 
22971   return SDValue();
22972 }
22973 
22974 // Convert a disguised subvector insertion into a shuffle:
22975 // insert_vector_elt V, (bitcast X from vector type), IdxC -->
22976 // bitcast(shuffle (bitcast V), (extended X), Mask)
22977 // Note: We do not use an insert_subvector node because that requires a
22978 // legal subvector type.
combineInsertEltToShuffle(SDNode * N,unsigned InsIndex)22979 SDValue DAGCombiner::combineInsertEltToShuffle(SDNode *N, unsigned InsIndex) {
22980   assert(N->getOpcode() == ISD::INSERT_VECTOR_ELT &&
22981          "Expected extract_vector_elt");
22982   SDValue InsertVal = N->getOperand(1);
22983 
22984   if (InsertVal.getOpcode() != ISD::BITCAST || !InsertVal.hasOneUse() ||
22985       !InsertVal.getOperand(0).getValueType().isVector())
22986     return SDValue();
22987 
22988   SDValue SubVec = InsertVal.getOperand(0);
22989   SDValue DestVec = N->getOperand(0);
22990   EVT SubVecVT = SubVec.getValueType();
22991   EVT VT = DestVec.getValueType();
22992   unsigned NumSrcElts = SubVecVT.getVectorNumElements();
22993   // If the source only has a single vector element, the cost of creating adding
22994   // it to a vector is likely to exceed the cost of a insert_vector_elt.
22995   if (NumSrcElts == 1)
22996     return SDValue();
22997   unsigned ExtendRatio = VT.getSizeInBits() / SubVecVT.getSizeInBits();
22998   unsigned NumMaskVals = ExtendRatio * NumSrcElts;
22999 
23000   // Step 1: Create a shuffle mask that implements this insert operation. The
23001   // vector that we are inserting into will be operand 0 of the shuffle, so
23002   // those elements are just 'i'. The inserted subvector is in the first
23003   // positions of operand 1 of the shuffle. Example:
23004   // insert v4i32 V, (v2i16 X), 2 --> shuffle v8i16 V', X', {0,1,2,3,8,9,6,7}
23005   SmallVector<int, 16> Mask(NumMaskVals);
23006   for (unsigned i = 0; i != NumMaskVals; ++i) {
23007     if (i / NumSrcElts == InsIndex)
23008       Mask[i] = (i % NumSrcElts) + NumMaskVals;
23009     else
23010       Mask[i] = i;
23011   }
23012 
23013   // Bail out if the target can not handle the shuffle we want to create.
23014   EVT SubVecEltVT = SubVecVT.getVectorElementType();
23015   EVT ShufVT = EVT::getVectorVT(*DAG.getContext(), SubVecEltVT, NumMaskVals);
23016   if (!TLI.isShuffleMaskLegal(Mask, ShufVT))
23017     return SDValue();
23018 
23019   // Step 2: Create a wide vector from the inserted source vector by appending
23020   // undefined elements. This is the same size as our destination vector.
23021   SDLoc DL(N);
23022   SmallVector<SDValue, 8> ConcatOps(ExtendRatio, DAG.getUNDEF(SubVecVT));
23023   ConcatOps[0] = SubVec;
23024   SDValue PaddedSubV = DAG.getNode(ISD::CONCAT_VECTORS, DL, ShufVT, ConcatOps);
23025 
23026   // Step 3: Shuffle in the padded subvector.
23027   SDValue DestVecBC = DAG.getBitcast(ShufVT, DestVec);
23028   SDValue Shuf = DAG.getVectorShuffle(ShufVT, DL, DestVecBC, PaddedSubV, Mask);
23029   AddToWorklist(PaddedSubV.getNode());
23030   AddToWorklist(DestVecBC.getNode());
23031   AddToWorklist(Shuf.getNode());
23032   return DAG.getBitcast(VT, Shuf);
23033 }
23034 
23035 // Combine insert(shuffle(load, <u,0,1,2>), load, 0) into a single load if
23036 // possible and the new load will be quick. We use more loads but less shuffles
23037 // and inserts.
combineInsertEltToLoad(SDNode * N,unsigned InsIndex)23038 SDValue DAGCombiner::combineInsertEltToLoad(SDNode *N, unsigned InsIndex) {
23039   EVT VT = N->getValueType(0);
23040 
23041   // InsIndex is expected to be the first of last lane.
23042   if (!VT.isFixedLengthVector() ||
23043       (InsIndex != 0 && InsIndex != VT.getVectorNumElements() - 1))
23044     return SDValue();
23045 
23046   // Look for a shuffle with the mask u,0,1,2,3,4,5,6 or 1,2,3,4,5,6,7,u
23047   // depending on the InsIndex.
23048   auto *Shuffle = dyn_cast<ShuffleVectorSDNode>(N->getOperand(0));
23049   SDValue Scalar = N->getOperand(1);
23050   if (!Shuffle || !all_of(enumerate(Shuffle->getMask()), [&](auto P) {
23051         return InsIndex == P.index() || P.value() < 0 ||
23052                (InsIndex == 0 && P.value() == (int)P.index() - 1) ||
23053                (InsIndex == VT.getVectorNumElements() - 1 &&
23054                 P.value() == (int)P.index() + 1);
23055       }))
23056     return SDValue();
23057 
23058   // We optionally skip over an extend so long as both loads are extended in the
23059   // same way from the same type.
23060   unsigned Extend = 0;
23061   if (Scalar.getOpcode() == ISD::ZERO_EXTEND ||
23062       Scalar.getOpcode() == ISD::SIGN_EXTEND ||
23063       Scalar.getOpcode() == ISD::ANY_EXTEND) {
23064     Extend = Scalar.getOpcode();
23065     Scalar = Scalar.getOperand(0);
23066   }
23067 
23068   auto *ScalarLoad = dyn_cast<LoadSDNode>(Scalar);
23069   if (!ScalarLoad)
23070     return SDValue();
23071 
23072   SDValue Vec = Shuffle->getOperand(0);
23073   if (Extend) {
23074     if (Vec.getOpcode() != Extend)
23075       return SDValue();
23076     Vec = Vec.getOperand(0);
23077   }
23078   auto *VecLoad = dyn_cast<LoadSDNode>(Vec);
23079   if (!VecLoad || Vec.getValueType().getScalarType() != Scalar.getValueType())
23080     return SDValue();
23081 
23082   int EltSize = ScalarLoad->getValueType(0).getScalarSizeInBits();
23083   if (EltSize == 0 || EltSize % 8 != 0 || !ScalarLoad->isSimple() ||
23084       !VecLoad->isSimple() || VecLoad->getExtensionType() != ISD::NON_EXTLOAD ||
23085       ScalarLoad->getExtensionType() != ISD::NON_EXTLOAD ||
23086       ScalarLoad->getAddressSpace() != VecLoad->getAddressSpace())
23087     return SDValue();
23088 
23089   // Check that the offset between the pointers to produce a single continuous
23090   // load.
23091   if (InsIndex == 0) {
23092     if (!DAG.areNonVolatileConsecutiveLoads(ScalarLoad, VecLoad, EltSize / 8,
23093                                             -1))
23094       return SDValue();
23095   } else {
23096     if (!DAG.areNonVolatileConsecutiveLoads(
23097             VecLoad, ScalarLoad, VT.getVectorNumElements() * EltSize / 8, -1))
23098       return SDValue();
23099   }
23100 
23101   // And that the new unaligned load will be fast.
23102   unsigned IsFast = 0;
23103   Align NewAlign = commonAlignment(VecLoad->getAlign(), EltSize / 8);
23104   if (!TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(),
23105                               Vec.getValueType(), VecLoad->getAddressSpace(),
23106                               NewAlign, VecLoad->getMemOperand()->getFlags(),
23107                               &IsFast) ||
23108       !IsFast)
23109     return SDValue();
23110 
23111   // Calculate the new Ptr and create the new load.
23112   SDLoc DL(N);
23113   SDValue Ptr = ScalarLoad->getBasePtr();
23114   if (InsIndex != 0)
23115     Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), VecLoad->getBasePtr(),
23116                       DAG.getConstant(EltSize / 8, DL, Ptr.getValueType()));
23117   MachinePointerInfo PtrInfo =
23118       InsIndex == 0 ? ScalarLoad->getPointerInfo()
23119                     : VecLoad->getPointerInfo().getWithOffset(EltSize / 8);
23120 
23121   SDValue Load = DAG.getLoad(VecLoad->getValueType(0), DL,
23122                              ScalarLoad->getChain(), Ptr, PtrInfo, NewAlign);
23123   DAG.makeEquivalentMemoryOrdering(ScalarLoad, Load.getValue(1));
23124   DAG.makeEquivalentMemoryOrdering(VecLoad, Load.getValue(1));
23125   return Extend ? DAG.getNode(Extend, DL, VT, Load) : Load;
23126 }
23127 
/// Combine an ISD::INSERT_VECTOR_ELT node
///   N = (insert_vector_elt InVec, InVal, EltNo).
///
/// \returns the replacement value for \p N, or a null SDValue if no fold
/// applies. Folds are attempted roughly in this order:
///  1. A constant EltNo past the end of a fixed-length vector -> UNDEF.
///  2. InVal is (extract_vector_elt InVec, EltNo) -> InVec.
///  3. Variable EltNo: splat InVal if InVec is undef and the target's
///     shouldSplatInsEltVarIndex() agrees; otherwise give up.
///  4. Scalable vectors give up; everything below needs a fixed-length vector
///     and a constant index.
///  5. The single-element special case, reordering of nested single-use
///     inserts, and the shuffle/load helpers (mergeInsertEltWithShuffle,
///     combineInsertEltToShuffle, combineInsertEltToLoad).
///  6. If BUILD_VECTOR is usable, walk the chain of operands behind InVec and
///     collapse it into a BUILD_VECTOR, a shuffle, or an AND/OR with a
///     constant mask.
SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) {
  SDValue InVec = N->getOperand(0);
  SDValue InVal = N->getOperand(1);
  SDValue EltNo = N->getOperand(2);
  SDLoc DL(N);

  EVT VT = InVec.getValueType();
  auto *IndexC = dyn_cast<ConstantSDNode>(EltNo);

  // Insert into out-of-bounds element is undefined.
  if (IndexC && VT.isFixedLengthVector() &&
      IndexC->getZExtValue() >= VT.getVectorNumElements())
    return DAG.getUNDEF(VT);

  // Remove redundant insertions:
  // (insert_vector_elt x (extract_vector_elt x idx) idx) -> x
  if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
      InVec == InVal.getOperand(0) && EltNo == InVal.getOperand(1))
    return InVec;

  if (!IndexC) {
    // If this is variable insert to undef vector, it might be better to splat:
    // inselt undef, InVal, EltNo --> build_vector < InVal, InVal, ... >
    if (InVec.isUndef() && TLI.shouldSplatInsEltVarIndex(VT))
      return DAG.getSplat(VT, DL, InVal);
    return SDValue();
  }

  if (VT.isScalableVector())
    return SDValue();

  unsigned NumElts = VT.getVectorNumElements();

  // We must know which element is being inserted for folds below here.
  // The out-of-bounds check above guarantees Elt < NumElts at this point.
  unsigned Elt = IndexC->getZExtValue();

  // Handle <1 x ???> vector insertion special cases.
  if (NumElts == 1) {
    // insert_vector_elt(x, extract_vector_elt(y, 0), 0) -> y
    if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
        InVal.getOperand(0).getValueType() == VT &&
        isNullConstant(InVal.getOperand(1)))
      return InVal.getOperand(0);
  }

  // Canonicalize insert_vector_elt dag nodes.
  // Example:
  // (insert_vector_elt (insert_vector_elt A, Idx0), Idx1)
  // -> (insert_vector_elt (insert_vector_elt A, Idx1), Idx0)
  //
  // Do this only if the child insert_vector node has one use; also
  // do this only if indices are both constants and Idx1 < Idx0.
  // Repeated application moves smaller indices toward the inner end of a
  // chain. The strict '<' means equal indices are never swapped; swapping
  // them would change which value wins.
  if (InVec.getOpcode() == ISD::INSERT_VECTOR_ELT && InVec.hasOneUse()
      && isa<ConstantSDNode>(InVec.getOperand(2))) {
    unsigned OtherElt = InVec.getConstantOperandVal(2);
    if (Elt < OtherElt) {
      // Swap nodes.
      SDValue NewOp = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VT,
                                  InVec.getOperand(0), InVal, EltNo);
      AddToWorklist(NewOp.getNode());
      return DAG.getNode(ISD::INSERT_VECTOR_ELT, SDLoc(InVec.getNode()),
                         VT, NewOp, InVec.getOperand(1), InVec.getOperand(2));
    }
  }

  // Try the shuffle- and load-based rewrites (helpers defined elsewhere in
  // this file).
  if (SDValue Shuf = mergeInsertEltWithShuffle(N, Elt))
    return Shuf;

  if (SDValue Shuf = combineInsertEltToShuffle(N, Elt))
    return Shuf;

  if (SDValue Shuf = combineInsertEltToLoad(N, Elt))
    return Shuf;

  // Attempt to convert an insert_vector_elt chain into a legal build_vector.
  if (!LegalOperations || TLI.isOperationLegal(ISD::BUILD_VECTOR, VT)) {
    // Single-element (<1 x ???>) vector: InVal is the whole result, so there
    // is no need to recurse up the chain.
    if (NumElts == 1)
      return DAG.getBuildVector(VT, DL, {InVal});

    // If we haven't already collected the element, insert into the op list.
    // The chain is walked from the outermost insert inwards, so the first
    // value recorded for a lane is the one visible in N. Later (inner) values
    // for the same lane are ignored. For integer vectors, MaxEltVT tracks the
    // widest operand type seen so far so all operands can be unified to it.
    EVT MaxEltVT = InVal.getValueType();
    auto AddBuildVectorOp = [&](SmallVectorImpl<SDValue> &Ops, SDValue Elt,
                                unsigned Idx) {
      if (!Ops[Idx]) {
        Ops[Idx] = Elt;
        if (VT.isInteger()) {
          EVT EltVT = Elt.getValueType();
          MaxEltVT = MaxEltVT.bitsGE(EltVT) ? MaxEltVT : EltVT;
        }
      }
    };

    // Ensure all the operands are the same value type, fill any missing
    // operands with UNDEF and create the BUILD_VECTOR.
    // For integer vectors each present operand is any-extended or truncated
    // to MaxEltVT. With FreezeUndef, all missing lanes share one
    // FREEZE(UNDEF) node instead of a plain UNDEF.
    auto CanonicalizeBuildVector = [&](SmallVectorImpl<SDValue> &Ops,
                                       bool FreezeUndef = false) {
      assert(Ops.size() == NumElts && "Unexpected vector size");
      SDValue UndefOp = FreezeUndef ? DAG.getFreeze(DAG.getUNDEF(MaxEltVT))
                                    : DAG.getUNDEF(MaxEltVT);
      for (SDValue &Op : Ops) {
        if (Op)
          Op = VT.isInteger() ? DAG.getAnyExtOrTrunc(Op, DL, MaxEltVT) : Op;
        else
          Op = UndefOp;
      }
      return DAG.getBuildVector(VT, DL, Ops);
    };

    // One slot per lane; a null SDValue means "not yet known". N's own value
    // occupies lane Elt.
    SmallVector<SDValue, 8> Ops(NumElts, SDValue());
    Ops[Elt] = InVal;

    // Recurse up an INSERT_VECTOR_ELT chain to build a BUILD_VECTOR.
    // Each iteration either returns, advances CurVec via 'continue', or
    // reaches the final 'break'.
    for (SDValue CurVec = InVec; CurVec;) {
      // UNDEF - build new BUILD_VECTOR from already inserted operands.
      if (CurVec.isUndef())
        return CanonicalizeBuildVector(Ops);

      // FREEZE(UNDEF) - build new BUILD_VECTOR from already inserted operands.
      if (ISD::isFreezeUndef(CurVec.getNode()) && CurVec.hasOneUse())
        return CanonicalizeBuildVector(Ops, /*FreezeUndef=*/true);

      // BUILD_VECTOR - insert unused operands and build new BUILD_VECTOR.
      if (CurVec.getOpcode() == ISD::BUILD_VECTOR && CurVec.hasOneUse()) {
        for (unsigned I = 0; I != NumElts; ++I)
          AddBuildVectorOp(Ops, CurVec.getOperand(I), I);
        return CanonicalizeBuildVector(Ops);
      }

      // SCALAR_TO_VECTOR - insert unused scalar and build new BUILD_VECTOR.
      // Only lane 0 comes from the scalar; other still-unset lanes become UNDEF.
      if (CurVec.getOpcode() == ISD::SCALAR_TO_VECTOR && CurVec.hasOneUse()) {
        AddBuildVectorOp(Ops, CurVec.getOperand(0), 0);
        return CanonicalizeBuildVector(Ops);
      }

      // INSERT_VECTOR_ELT - insert operand and continue up the chain.
      // Only single-use inserts with an in-range constant index are followed.
      if (CurVec.getOpcode() == ISD::INSERT_VECTOR_ELT && CurVec.hasOneUse())
        if (auto *CurIdx = dyn_cast<ConstantSDNode>(CurVec.getOperand(2)))
          if (CurIdx->getAPIntValue().ult(NumElts)) {
            unsigned Idx = CurIdx->getZExtValue();
            AddBuildVectorOp(Ops, CurVec.getOperand(1), Idx);

            // Found entire BUILD_VECTOR.
            if (all_of(Ops, [](SDValue Op) { return !!Op; }))
              return CanonicalizeBuildVector(Ops);

            CurVec = CurVec->getOperand(0);
            continue;
          }

      // VECTOR_SHUFFLE - if all the operands match the shuffle's sources,
      // update the shuffle mask (and second operand if we started with unary
      // shuffle) and create a new legal shuffle.
      // If any lane cannot be merged, or no legal shuffle can be built, fall
      // through to the AND/OR check below.
      if (CurVec.getOpcode() == ISD::VECTOR_SHUFFLE && CurVec.hasOneUse()) {
        auto *SVN = cast<ShuffleVectorSDNode>(CurVec);
        SDValue LHS = SVN->getOperand(0);
        SDValue RHS = SVN->getOperand(1);
        SmallVector<int, 16> Mask(SVN->getMask());
        bool Merged = true;
        for (auto I : enumerate(Ops)) {
          SDValue &Op = I.value();
          if (Op) {
            SmallVector<int, 16> NewMask;
            if (!mergeEltWithShuffle(LHS, RHS, Mask, NewMask, Op, I.index())) {
              Merged = false;
              break;
            }
            Mask = std::move(NewMask);
          }
        }
        if (Merged)
          if (SDValue NewShuffle =
                  TLI.buildLegalVectorShuffle(VT, DL, LHS, RHS, Mask, DAG))
            return NewShuffle;
      }

      if (!LegalOperations) {
        bool IsNull = llvm::isNullConstant(InVal);
        // We can convert to AND/OR mask if all insertions are zero or -1
        // respectively.
        // Every collected lane must be InVal itself, and at least two lanes
        // must be. NOTE(review): the ">= 2" is presumably so a lone insert is
        // not turned into a logic op plus a constant vector; confirm.
        if ((IsNull || llvm::isAllOnesConstant(InVal)) &&
            all_of(Ops, [InVal](SDValue Op) { return !Op || Op == InVal; }) &&
            count_if(Ops, [InVal](SDValue Op) { return Op == InVal; }) >= 2) {
          SDValue Zero = DAG.getConstant(0, DL, MaxEltVT);
          SDValue AllOnes = DAG.getAllOnesConstant(DL, MaxEltVT);
          SmallVector<SDValue, 8> Mask(NumElts);

          // Build the mask and return the corresponding DAG node.
          // Lanes collected in Ops get TrueVal and all others get FalseVal.
          // The logic op is applied to CurVec, the vector where the walk
          // stopped.
          auto BuildMaskAndNode = [&](SDValue TrueVal, SDValue FalseVal,
                                      unsigned MaskOpcode) {
            for (unsigned I = 0; I != NumElts; ++I)
              Mask[I] = Ops[I] ? TrueVal : FalseVal;
            return DAG.getNode(MaskOpcode, DL, VT, CurVec,
                               DAG.getBuildVector(VT, DL, Mask));
          };

          // Inserted zeros: AND with a mask that is 0 in the inserted lanes
          // and all-ones elsewhere.
          if (IsNull)
            return BuildMaskAndNode(Zero, AllOnes, ISD::AND);

          // Inserted all-ones: OR with a mask that is -1 in the inserted lanes
          // and 0 elsewhere.
          return BuildMaskAndNode(AllOnes, Zero, ISD::OR);
        }
      }

      // Failed to find a match in the chain - bail.
      break;
    }

    // See if we can fill in the missing constant elements as zeros.
    // TODO: Should we do this for any constant?
    // InVec supplies every lane of N except Elt. For the lanes still unset in
    // Ops, check whether InVec is known to be zero there. If so, materialize
    // those lanes as zero constants.
    APInt DemandedZeroElts = APInt::getZero(NumElts);
    for (unsigned I = 0; I != NumElts; ++I)
      if (!Ops[I])
        DemandedZeroElts.setBit(I);

    if (DAG.MaskedVectorIsZero(InVec, DemandedZeroElts)) {
      SDValue Zero = VT.isInteger() ? DAG.getConstant(0, DL, MaxEltVT)
                                    : DAG.getConstantFP(0, DL, MaxEltVT);
      for (unsigned I = 0; I != NumElts; ++I)
        if (!Ops[I])
          Ops[I] = Zero;

      return CanonicalizeBuildVector(Ops);
    }
  }

  return SDValue();
}
23357 
23358 /// Transform a vector binary operation into a scalar binary operation by moving
23359 /// the math/logic after an extract element of a vector.
scalarizeExtractedBinOp(SDNode * ExtElt,SelectionDAG & DAG,const SDLoc & DL,bool LegalTypes)23360 static SDValue scalarizeExtractedBinOp(SDNode *ExtElt, SelectionDAG &DAG,
23361                                        const SDLoc &DL, bool LegalTypes) {
23362   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
23363   SDValue Vec = ExtElt->getOperand(0);
23364   SDValue Index = ExtElt->getOperand(1);
23365   auto *IndexC = dyn_cast<ConstantSDNode>(Index);
23366   unsigned Opc = Vec.getOpcode();
23367   if (!IndexC || !Vec.hasOneUse() || (!TLI.isBinOp(Opc) && Opc != ISD::SETCC) ||
23368       Vec->getNumValues() != 1)
23369     return SDValue();
23370 
23371   // Targets may want to avoid this to prevent an expensive register transfer.
23372   if (!TLI.shouldScalarizeBinop(Vec))
23373     return SDValue();
23374 
23375   EVT ResVT = ExtElt->getValueType(0);
23376   if (Opc == ISD::SETCC &&
23377       (ResVT != Vec.getValueType().getVectorElementType() || LegalTypes))
23378     return SDValue();
23379 
23380   // Extracting an element of a vector constant is constant-folded, so this
23381   // transform is just replacing a vector op with a scalar op while moving the
23382   // extract.
23383   SDValue Op0 = Vec.getOperand(0);
23384   SDValue Op1 = Vec.getOperand(1);
23385   APInt SplatVal;
23386   if (!isAnyConstantBuildVector(Op0, true) &&
23387       !ISD::isConstantSplatVector(Op0.getNode(), SplatVal) &&
23388       !isAnyConstantBuildVector(Op1, true) &&
23389       !ISD::isConstantSplatVector(Op1.getNode(), SplatVal))
23390     return SDValue();
23391 
23392   // extractelt (op X, C), IndexC --> op (extractelt X, IndexC), C'
23393   // extractelt (op C, X), IndexC --> op C', (extractelt X, IndexC)
23394   if (Opc == ISD::SETCC) {
23395     EVT OpVT = Op0.getValueType().getVectorElementType();
23396     Op0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, OpVT, Op0, Index);
23397     Op1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, OpVT, Op1, Index);
23398     SDValue NewVal = DAG.getSetCC(
23399         DL, ResVT, Op0, Op1, cast<CondCodeSDNode>(Vec->getOperand(2))->get());
23400     // We may need to sign- or zero-extend the result to match the same
23401     // behaviour as the vector version of SETCC.
23402     unsigned VecBoolContents = TLI.getBooleanContents(Vec.getValueType());
23403     if (ResVT != MVT::i1 &&
23404         VecBoolContents != TargetLowering::UndefinedBooleanContent &&
23405         VecBoolContents != TLI.getBooleanContents(ResVT)) {
23406       if (VecBoolContents == TargetLowering::ZeroOrNegativeOneBooleanContent)
23407         NewVal = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, ResVT, NewVal,
23408                              DAG.getValueType(MVT::i1));
23409       else
23410         NewVal = DAG.getZeroExtendInReg(NewVal, DL, MVT::i1);
23411     }
23412     return NewVal;
23413   }
23414   Op0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT, Op0, Index);
23415   Op1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT, Op1, Index);
23416   return DAG.getNode(Opc, DL, ResVT, Op0, Op1);
23417 }
23418 
23419 // Given a ISD::EXTRACT_VECTOR_ELT, which is a glorified bit sequence extract,
23420 // recursively analyse all of it's users. and try to model themselves as
23421 // bit sequence extractions. If all of them agree on the new, narrower element
23422 // type, and all of them can be modelled as ISD::EXTRACT_VECTOR_ELT's of that
23423 // new element type, do so now.
23424 // This is mainly useful to recover from legalization that scalarized
23425 // the vector as wide elements, but tries to rebuild it with narrower elements.
23426 //
23427 // Some more nodes could be modelled if that helps cover interesting patterns.
refineExtractVectorEltIntoMultipleNarrowExtractVectorElts(SDNode * N)23428 bool DAGCombiner::refineExtractVectorEltIntoMultipleNarrowExtractVectorElts(
23429     SDNode *N) {
23430   // We perform this optimization post type-legalization because
23431   // the type-legalizer often scalarizes integer-promoted vectors.
23432   // Performing this optimization before may cause legalizaton cycles.
23433   if (Level != AfterLegalizeVectorOps && Level != AfterLegalizeTypes)
23434     return false;
23435 
23436   // TODO: Add support for big-endian.
23437   if (DAG.getDataLayout().isBigEndian())
23438     return false;
23439 
23440   SDValue VecOp = N->getOperand(0);
23441   EVT VecVT = VecOp.getValueType();
23442   assert(!VecVT.isScalableVector() && "Only for fixed vectors.");
23443 
23444   // We must start with a constant extraction index.
23445   auto *IndexC = dyn_cast<ConstantSDNode>(N->getOperand(1));
23446   if (!IndexC)
23447     return false;
23448 
23449   assert(IndexC->getZExtValue() < VecVT.getVectorNumElements() &&
23450          "Original ISD::EXTRACT_VECTOR_ELT is undefinend?");
23451 
23452   // TODO: deal with the case of implicit anyext of the extraction.
23453   unsigned VecEltBitWidth = VecVT.getScalarSizeInBits();
23454   EVT ScalarVT = N->getValueType(0);
23455   if (VecVT.getScalarType() != ScalarVT)
23456     return false;
23457 
23458   // TODO: deal with the cases other than everything being integer-typed.
23459   if (!ScalarVT.isScalarInteger())
23460     return false;
23461 
23462   struct Entry {
23463     SDNode *Producer;
23464 
23465     // Which bits of VecOp does it contain?
23466     unsigned BitPos;
23467     int NumBits;
23468     // NOTE: the actual width of \p Producer may be wider than NumBits!
23469 
23470     Entry(Entry &&) = default;
23471     Entry(SDNode *Producer_, unsigned BitPos_, int NumBits_)
23472         : Producer(Producer_), BitPos(BitPos_), NumBits(NumBits_) {}
23473 
23474     Entry() = delete;
23475     Entry(const Entry &) = delete;
23476     Entry &operator=(const Entry &) = delete;
23477     Entry &operator=(Entry &&) = delete;
23478   };
23479   SmallVector<Entry, 32> Worklist;
23480   SmallVector<Entry, 32> Leafs;
23481 
23482   // We start at the "root" ISD::EXTRACT_VECTOR_ELT.
23483   Worklist.emplace_back(N, /*BitPos=*/VecEltBitWidth * IndexC->getZExtValue(),
23484                         /*NumBits=*/VecEltBitWidth);
23485 
23486   while (!Worklist.empty()) {
23487     Entry E = Worklist.pop_back_val();
23488     // Does the node not even use any of the VecOp bits?
23489     if (!(E.NumBits > 0 && E.BitPos < VecVT.getSizeInBits() &&
23490           E.BitPos + E.NumBits <= VecVT.getSizeInBits()))
23491       return false; // Let's allow the other combines clean this up first.
23492     // Did we fail to model any of the users of the Producer?
23493     bool ProducerIsLeaf = false;
23494     // Look at each user of this Producer.
23495     for (SDNode *User : E.Producer->users()) {
23496       switch (User->getOpcode()) {
23497       // TODO: support ISD::BITCAST
23498       // TODO: support ISD::ANY_EXTEND
23499       // TODO: support ISD::ZERO_EXTEND
23500       // TODO: support ISD::SIGN_EXTEND
23501       case ISD::TRUNCATE:
23502         // Truncation simply means we keep position, but extract less bits.
23503         Worklist.emplace_back(User, E.BitPos,
23504                               /*NumBits=*/User->getValueSizeInBits(0));
23505         break;
23506       // TODO: support ISD::SRA
23507       // TODO: support ISD::SHL
23508       case ISD::SRL:
23509         // We should be shifting the Producer by a constant amount.
23510         if (auto *ShAmtC = dyn_cast<ConstantSDNode>(User->getOperand(1));
23511             User->getOperand(0).getNode() == E.Producer && ShAmtC) {
23512           // Logical right-shift means that we start extraction later,
23513           // but stop it at the same position we did previously.
23514           unsigned ShAmt = ShAmtC->getZExtValue();
23515           Worklist.emplace_back(User, E.BitPos + ShAmt, E.NumBits - ShAmt);
23516           break;
23517         }
23518         [[fallthrough]];
23519       default:
23520         // We can not model this user of the Producer.
23521         // Which means the current Producer will be a ISD::EXTRACT_VECTOR_ELT.
23522         ProducerIsLeaf = true;
23523         // Profitability check: all users that we can not model
23524         //                      must be ISD::BUILD_VECTOR's.
23525         if (User->getOpcode() != ISD::BUILD_VECTOR)
23526           return false;
23527         break;
23528       }
23529     }
23530     if (ProducerIsLeaf)
23531       Leafs.emplace_back(std::move(E));
23532   }
23533 
23534   unsigned NewVecEltBitWidth = Leafs.front().NumBits;
23535 
23536   // If we are still at the same element granularity, give up,
23537   if (NewVecEltBitWidth == VecEltBitWidth)
23538     return false;
23539 
23540   // The vector width must be a multiple of the new element width.
23541   if (VecVT.getSizeInBits() % NewVecEltBitWidth != 0)
23542     return false;
23543 
23544   // All leafs must agree on the new element width.
23545   // All leafs must not expect any "padding" bits ontop of that width.
23546   // All leafs must start extraction from multiple of that width.
23547   if (!all_of(Leafs, [NewVecEltBitWidth](const Entry &E) {
23548         return (unsigned)E.NumBits == NewVecEltBitWidth &&
23549                E.Producer->getValueSizeInBits(0) == NewVecEltBitWidth &&
23550                E.BitPos % NewVecEltBitWidth == 0;
23551       }))
23552     return false;
23553 
23554   EVT NewScalarVT = EVT::getIntegerVT(*DAG.getContext(), NewVecEltBitWidth);
23555   EVT NewVecVT = EVT::getVectorVT(*DAG.getContext(), NewScalarVT,
23556                                   VecVT.getSizeInBits() / NewVecEltBitWidth);
23557 
23558   if (LegalTypes &&
23559       !(TLI.isTypeLegal(NewScalarVT) && TLI.isTypeLegal(NewVecVT)))
23560     return false;
23561 
23562   if (LegalOperations &&
23563       !(TLI.isOperationLegalOrCustom(ISD::BITCAST, NewVecVT) &&
23564         TLI.isOperationLegalOrCustom(ISD::EXTRACT_VECTOR_ELT, NewVecVT)))
23565     return false;
23566 
23567   SDValue NewVecOp = DAG.getBitcast(NewVecVT, VecOp);
23568   for (const Entry &E : Leafs) {
23569     SDLoc DL(E.Producer);
23570     unsigned NewIndex = E.BitPos / NewVecEltBitWidth;
23571     assert(NewIndex < NewVecVT.getVectorNumElements() &&
23572            "Creating out-of-bounds ISD::EXTRACT_VECTOR_ELT?");
23573     SDValue V = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, NewScalarVT, NewVecOp,
23574                             DAG.getVectorIdxConstant(NewIndex, DL));
23575     CombineTo(E.Producer, V);
23576   }
23577 
23578   return true;
23579 }
23580 
visitEXTRACT_VECTOR_ELT(SDNode * N)23581 SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
23582   SDValue VecOp = N->getOperand(0);
23583   SDValue Index = N->getOperand(1);
23584   EVT ScalarVT = N->getValueType(0);
23585   EVT VecVT = VecOp.getValueType();
23586   if (VecOp.isUndef())
23587     return DAG.getUNDEF(ScalarVT);
23588 
23589   // extract_vector_elt (insert_vector_elt vec, val, idx), idx) -> val
23590   //
23591   // This only really matters if the index is non-constant since other combines
23592   // on the constant elements already work.
23593   SDLoc DL(N);
23594   if (VecOp.getOpcode() == ISD::INSERT_VECTOR_ELT &&
23595       Index == VecOp.getOperand(2)) {
23596     SDValue Elt = VecOp.getOperand(1);
23597     AddUsersToWorklist(VecOp.getNode());
23598     return VecVT.isInteger() ? DAG.getAnyExtOrTrunc(Elt, DL, ScalarVT) : Elt;
23599   }
23600 
23601   // (vextract (scalar_to_vector val, 0) -> val
23602   if (VecOp.getOpcode() == ISD::SCALAR_TO_VECTOR) {
23603     // Only 0'th element of SCALAR_TO_VECTOR is defined.
23604     if (DAG.isKnownNeverZero(Index))
23605       return DAG.getUNDEF(ScalarVT);
23606 
23607     // Check if the result type doesn't match the inserted element type.
23608     // The inserted element and extracted element may have mismatched bitwidth.
23609     // As a result, EXTRACT_VECTOR_ELT may extend or truncate the extracted vector.
23610     SDValue InOp = VecOp.getOperand(0);
23611     if (InOp.getValueType() != ScalarVT) {
23612       assert(InOp.getValueType().isInteger() && ScalarVT.isInteger());
23613       if (InOp.getValueType().bitsGT(ScalarVT))
23614         return DAG.getNode(ISD::TRUNCATE, DL, ScalarVT, InOp);
23615       return DAG.getNode(ISD::ANY_EXTEND, DL, ScalarVT, InOp);
23616     }
23617     return InOp;
23618   }
23619 
23620   // extract_vector_elt of out-of-bounds element -> UNDEF
23621   auto *IndexC = dyn_cast<ConstantSDNode>(Index);
23622   if (IndexC && VecVT.isFixedLengthVector() &&
23623       IndexC->getAPIntValue().uge(VecVT.getVectorNumElements()))
23624     return DAG.getUNDEF(ScalarVT);
23625 
23626   // extract_vector_elt (build_vector x, y), 1 -> y
23627   if (((IndexC && VecOp.getOpcode() == ISD::BUILD_VECTOR) ||
23628        VecOp.getOpcode() == ISD::SPLAT_VECTOR) &&
23629       TLI.isTypeLegal(VecVT)) {
23630     assert((VecOp.getOpcode() != ISD::BUILD_VECTOR ||
23631             VecVT.isFixedLengthVector()) &&
23632            "BUILD_VECTOR used for scalable vectors");
23633     unsigned IndexVal =
23634         VecOp.getOpcode() == ISD::BUILD_VECTOR ? IndexC->getZExtValue() : 0;
23635     SDValue Elt = VecOp.getOperand(IndexVal);
23636     EVT InEltVT = Elt.getValueType();
23637 
23638     if (VecOp.hasOneUse() || TLI.aggressivelyPreferBuildVectorSources(VecVT) ||
23639         isNullConstant(Elt)) {
23640       // Sometimes build_vector's scalar input types do not match result type.
23641       if (ScalarVT == InEltVT)
23642         return Elt;
23643 
23644       // TODO: It may be useful to truncate if free if the build_vector
23645       // implicitly converts.
23646     }
23647   }
23648 
23649   if (SDValue BO = scalarizeExtractedBinOp(N, DAG, DL, LegalTypes))
23650     return BO;
23651 
23652   if (VecVT.isScalableVector())
23653     return SDValue();
23654 
23655   // All the code from this point onwards assumes fixed width vectors, but it's
23656   // possible that some of the combinations could be made to work for scalable
23657   // vectors too.
23658   unsigned NumElts = VecVT.getVectorNumElements();
23659   unsigned VecEltBitWidth = VecVT.getScalarSizeInBits();
23660 
23661   // See if the extracted element is constant, in which case fold it if its
23662   // a legal fp immediate.
23663   if (IndexC && ScalarVT.isFloatingPoint()) {
23664     APInt EltMask = APInt::getOneBitSet(NumElts, IndexC->getZExtValue());
23665     KnownBits KnownElt = DAG.computeKnownBits(VecOp, EltMask);
23666     if (KnownElt.isConstant()) {
23667       APFloat CstFP =
23668           APFloat(ScalarVT.getFltSemantics(), KnownElt.getConstant());
23669       if (TLI.isFPImmLegal(CstFP, ScalarVT))
23670         return DAG.getConstantFP(CstFP, DL, ScalarVT);
23671     }
23672   }
23673 
23674   // TODO: These transforms should not require the 'hasOneUse' restriction, but
23675   // there are regressions on multiple targets without it. We can end up with a
23676   // mess of scalar and vector code if we reduce only part of the DAG to scalar.
23677   if (IndexC && VecOp.getOpcode() == ISD::BITCAST && VecVT.isInteger() &&
23678       VecOp.hasOneUse()) {
23679     // The vector index of the LSBs of the source depend on the endian-ness.
23680     bool IsLE = DAG.getDataLayout().isLittleEndian();
23681     unsigned ExtractIndex = IndexC->getZExtValue();
23682     // extract_elt (v2i32 (bitcast i64:x)), BCTruncElt -> i32 (trunc i64:x)
23683     unsigned BCTruncElt = IsLE ? 0 : NumElts - 1;
23684     SDValue BCSrc = VecOp.getOperand(0);
23685     if (ExtractIndex == BCTruncElt && BCSrc.getValueType().isScalarInteger())
23686       return DAG.getAnyExtOrTrunc(BCSrc, DL, ScalarVT);
23687 
23688     // TODO: Add support for SCALAR_TO_VECTOR implicit truncation.
23689     if (LegalTypes && BCSrc.getValueType().isInteger() &&
23690         BCSrc.getOpcode() == ISD::SCALAR_TO_VECTOR &&
23691         BCSrc.getScalarValueSizeInBits() ==
23692             BCSrc.getOperand(0).getScalarValueSizeInBits()) {
23693       // ext_elt (bitcast (scalar_to_vec i64 X to v2i64) to v4i32), TruncElt -->
23694       // trunc i64 X to i32
23695       SDValue X = BCSrc.getOperand(0);
23696       EVT XVT = X.getValueType();
23697       assert(XVT.isScalarInteger() && ScalarVT.isScalarInteger() &&
23698              "Extract element and scalar to vector can't change element type "
23699              "from FP to integer.");
23700       unsigned XBitWidth = X.getValueSizeInBits();
23701       unsigned Scale = XBitWidth / VecEltBitWidth;
23702       BCTruncElt = IsLE ? 0 : Scale - 1;
23703 
23704       // An extract element return value type can be wider than its vector
23705       // operand element type. In that case, the high bits are undefined, so
23706       // it's possible that we may need to extend rather than truncate.
23707       if (ExtractIndex < Scale && XBitWidth > VecEltBitWidth) {
23708         assert(XBitWidth % VecEltBitWidth == 0 &&
23709                "Scalar bitwidth must be a multiple of vector element bitwidth");
23710 
23711         if (ExtractIndex != BCTruncElt) {
23712           unsigned ShiftIndex =
23713               IsLE ? ExtractIndex : (Scale - 1) - ExtractIndex;
23714           X = DAG.getNode(
23715               ISD::SRL, DL, XVT, X,
23716               DAG.getShiftAmountConstant(ShiftIndex * VecEltBitWidth, XVT, DL));
23717         }
23718 
23719         return DAG.getAnyExtOrTrunc(X, DL, ScalarVT);
23720       }
23721     }
23722   }
23723 
23724   // Transform: (EXTRACT_VECTOR_ELT( VECTOR_SHUFFLE )) -> EXTRACT_VECTOR_ELT.
23725   // We only perform this optimization before the op legalization phase because
23726   // we may introduce new vector instructions which are not backed by TD
23727   // patterns. For example on AVX, extracting elements from a wide vector
23728   // without using extract_subvector. However, if we can find an underlying
23729   // scalar value, then we can always use that.
23730   if (IndexC && VecOp.getOpcode() == ISD::VECTOR_SHUFFLE) {
23731     auto *Shuf = cast<ShuffleVectorSDNode>(VecOp);
23732     // Find the new index to extract from.
23733     int OrigElt = Shuf->getMaskElt(IndexC->getZExtValue());
23734 
23735     // Extracting an undef index is undef.
23736     if (OrigElt == -1)
23737       return DAG.getUNDEF(ScalarVT);
23738 
23739     // Select the right vector half to extract from.
23740     SDValue SVInVec;
23741     if (OrigElt < (int)NumElts) {
23742       SVInVec = VecOp.getOperand(0);
23743     } else {
23744       SVInVec = VecOp.getOperand(1);
23745       OrigElt -= NumElts;
23746     }
23747 
23748     if (SVInVec.getOpcode() == ISD::BUILD_VECTOR) {
23749       // TODO: Check if shuffle mask is legal?
23750       if (LegalOperations && TLI.isOperationLegal(ISD::VECTOR_SHUFFLE, VecVT) &&
23751           !VecOp.hasOneUse())
23752         return SDValue();
23753 
23754       SDValue InOp = SVInVec.getOperand(OrigElt);
23755       if (InOp.getValueType() != ScalarVT) {
23756         assert(InOp.getValueType().isInteger() && ScalarVT.isInteger());
23757         InOp = DAG.getSExtOrTrunc(InOp, DL, ScalarVT);
23758       }
23759 
23760       return InOp;
23761     }
23762 
23763     // FIXME: We should handle recursing on other vector shuffles and
23764     // scalar_to_vector here as well.
23765 
23766     if (!LegalOperations ||
23767         // FIXME: Should really be just isOperationLegalOrCustom.
23768         TLI.isOperationLegal(ISD::EXTRACT_VECTOR_ELT, VecVT) ||
23769         TLI.isOperationExpand(ISD::VECTOR_SHUFFLE, VecVT)) {
23770       return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT, SVInVec,
23771                          DAG.getVectorIdxConstant(OrigElt, DL));
23772     }
23773   }
23774 
23775   // If only EXTRACT_VECTOR_ELT nodes use the source vector we can
23776   // simplify it based on the (valid) extraction indices.
23777   if (llvm::all_of(VecOp->users(), [&](SDNode *Use) {
23778         return Use->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
23779                Use->getOperand(0) == VecOp &&
23780                isa<ConstantSDNode>(Use->getOperand(1));
23781       })) {
23782     APInt DemandedElts = APInt::getZero(NumElts);
23783     for (SDNode *User : VecOp->users()) {
23784       auto *CstElt = cast<ConstantSDNode>(User->getOperand(1));
23785       if (CstElt->getAPIntValue().ult(NumElts))
23786         DemandedElts.setBit(CstElt->getZExtValue());
23787     }
23788     if (SimplifyDemandedVectorElts(VecOp, DemandedElts, true)) {
23789       // We simplified the vector operand of this extract element. If this
23790       // extract is not dead, visit it again so it is folded properly.
23791       if (N->getOpcode() != ISD::DELETED_NODE)
23792         AddToWorklist(N);
23793       return SDValue(N, 0);
23794     }
23795     APInt DemandedBits = APInt::getAllOnes(VecEltBitWidth);
23796     if (SimplifyDemandedBits(VecOp, DemandedBits, DemandedElts, true)) {
23797       // We simplified the vector operand of this extract element. If this
23798       // extract is not dead, visit it again so it is folded properly.
23799       if (N->getOpcode() != ISD::DELETED_NODE)
23800         AddToWorklist(N);
23801       return SDValue(N, 0);
23802     }
23803   }
23804 
23805   if (refineExtractVectorEltIntoMultipleNarrowExtractVectorElts(N))
23806     return SDValue(N, 0);
23807 
23808   // Everything under here is trying to match an extract of a loaded value.
23809   // If the result of load has to be truncated, then it's not necessarily
23810   // profitable.
23811   bool BCNumEltsChanged = false;
23812   EVT ExtVT = VecVT.getVectorElementType();
23813   EVT LVT = ExtVT;
23814   if (ScalarVT.bitsLT(LVT) && !TLI.isTruncateFree(LVT, ScalarVT))
23815     return SDValue();
23816 
23817   if (VecOp.getOpcode() == ISD::BITCAST) {
23818     // Don't duplicate a load with other uses.
23819     if (!VecOp.hasOneUse())
23820       return SDValue();
23821 
23822     EVT BCVT = VecOp.getOperand(0).getValueType();
23823     if (!BCVT.isVector() || ExtVT.bitsGT(BCVT.getVectorElementType()))
23824       return SDValue();
23825     if (NumElts != BCVT.getVectorNumElements())
23826       BCNumEltsChanged = true;
23827     VecOp = VecOp.getOperand(0);
23828     ExtVT = BCVT.getVectorElementType();
23829   }
23830 
23831   // extract (vector load $addr), i --> load $addr + i * size
23832   if (!LegalOperations && !IndexC && VecOp.hasOneUse() &&
23833       ISD::isNormalLoad(VecOp.getNode()) &&
23834       !Index->hasPredecessor(VecOp.getNode())) {
23835     auto *VecLoad = dyn_cast<LoadSDNode>(VecOp);
23836     if (VecLoad && VecLoad->isSimple()) {
23837       if (SDValue Scalarized = TLI.scalarizeExtractedVectorLoad(
23838               ScalarVT, SDLoc(N), VecVT, Index, VecLoad, DAG)) {
23839         ++OpsNarrowed;
23840         return Scalarized;
23841       }
23842     }
23843   }
23844 
23845   // Perform only after legalization to ensure build_vector / vector_shuffle
23846   // optimizations have already been done.
23847   if (!LegalOperations || !IndexC)
23848     return SDValue();
23849 
23850   // (vextract (v4f32 load $addr), c) -> (f32 load $addr+c*size)
23851   // (vextract (v4f32 s2v (f32 load $addr)), c) -> (f32 load $addr+c*size)
23852   // (vextract (v4f32 shuffle (load $addr), <1,u,u,u>), 0) -> (f32 load $addr)
23853   int Elt = IndexC->getZExtValue();
23854   LoadSDNode *LN0 = nullptr;
23855   if (ISD::isNormalLoad(VecOp.getNode())) {
23856     LN0 = cast<LoadSDNode>(VecOp);
23857   } else if (VecOp.getOpcode() == ISD::SCALAR_TO_VECTOR &&
23858              VecOp.getOperand(0).getValueType() == ExtVT &&
23859              ISD::isNormalLoad(VecOp.getOperand(0).getNode())) {
23860     // Don't duplicate a load with other uses.
23861     if (!VecOp.hasOneUse())
23862       return SDValue();
23863 
23864     LN0 = cast<LoadSDNode>(VecOp.getOperand(0));
23865   }
23866   if (auto *Shuf = dyn_cast<ShuffleVectorSDNode>(VecOp)) {
23867     // (vextract (vector_shuffle (load $addr), v2, <1, u, u, u>), 1)
23868     // =>
23869     // (load $addr+1*size)
23870 
23871     // Don't duplicate a load with other uses.
23872     if (!VecOp.hasOneUse())
23873       return SDValue();
23874 
23875     // If the bit convert changed the number of elements, it is unsafe
23876     // to examine the mask.
23877     if (BCNumEltsChanged)
23878       return SDValue();
23879 
23880     // Select the input vector, guarding against out of range extract vector.
23881     int Idx = (Elt > (int)NumElts) ? -1 : Shuf->getMaskElt(Elt);
23882     VecOp = (Idx < (int)NumElts) ? VecOp.getOperand(0) : VecOp.getOperand(1);
23883 
23884     if (VecOp.getOpcode() == ISD::BITCAST) {
23885       // Don't duplicate a load with other uses.
23886       if (!VecOp.hasOneUse())
23887         return SDValue();
23888 
23889       VecOp = VecOp.getOperand(0);
23890     }
23891     if (ISD::isNormalLoad(VecOp.getNode())) {
23892       LN0 = cast<LoadSDNode>(VecOp);
23893       Elt = (Idx < (int)NumElts) ? Idx : Idx - (int)NumElts;
23894       Index = DAG.getConstant(Elt, DL, Index.getValueType());
23895     }
23896   } else if (VecOp.getOpcode() == ISD::CONCAT_VECTORS && !BCNumEltsChanged &&
23897              VecVT.getVectorElementType() == ScalarVT &&
23898              (!LegalTypes ||
23899               TLI.isTypeLegal(
23900                   VecOp.getOperand(0).getValueType().getVectorElementType()))) {
23901     // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 0
23902     //      -> extract_vector_elt a, 0
23903     // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 1
23904     //      -> extract_vector_elt a, 1
23905     // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 2
23906     //      -> extract_vector_elt b, 0
23907     // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 3
23908     //      -> extract_vector_elt b, 1
23909     EVT ConcatVT = VecOp.getOperand(0).getValueType();
23910     unsigned ConcatNumElts = ConcatVT.getVectorNumElements();
23911     SDValue NewIdx = DAG.getConstant(Elt % ConcatNumElts, DL,
23912                                      Index.getValueType());
23913 
23914     SDValue ConcatOp = VecOp.getOperand(Elt / ConcatNumElts);
23915     SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL,
23916                               ConcatVT.getVectorElementType(),
23917                               ConcatOp, NewIdx);
23918     return DAG.getNode(ISD::BITCAST, DL, ScalarVT, Elt);
23919   }
23920 
23921   // Make sure we found a non-volatile load and the extractelement is
23922   // the only use.
23923   if (!LN0 || !LN0->hasNUsesOfValue(1,0) || !LN0->isSimple())
23924     return SDValue();
23925 
23926   // If Idx was -1 above, Elt is going to be -1, so just return undef.
23927   if (Elt == -1)
23928     return DAG.getUNDEF(LVT);
23929 
23930   if (SDValue Scalarized =
23931           TLI.scalarizeExtractedVectorLoad(LVT, DL, VecVT, Index, LN0, DAG)) {
23932     ++OpsNarrowed;
23933     return Scalarized;
23934   }
23935 
23936   return SDValue();
23937 }
23938 
23939 // Simplify (build_vec (ext )) to (bitcast (build_vec ))
reduceBuildVecExtToExtBuildVec(SDNode * N)23940 SDValue DAGCombiner::reduceBuildVecExtToExtBuildVec(SDNode *N) {
23941   // We perform this optimization post type-legalization because
23942   // the type-legalizer often scalarizes integer-promoted vectors.
23943   // Performing this optimization before may create bit-casts which
23944   // will be type-legalized to complex code sequences.
23945   // We perform this optimization only before the operation legalizer because we
23946   // may introduce illegal operations.
23947   if (Level != AfterLegalizeVectorOps && Level != AfterLegalizeTypes)
23948     return SDValue();
23949 
23950   unsigned NumInScalars = N->getNumOperands();
23951   SDLoc DL(N);
23952   EVT VT = N->getValueType(0);
23953 
23954   // Check to see if this is a BUILD_VECTOR of a bunch of values
23955   // which come from any_extend or zero_extend nodes. If so, we can create
23956   // a new BUILD_VECTOR using bit-casts which may enable other BUILD_VECTOR
23957   // optimizations. We do not handle sign-extend because we can't fill the sign
23958   // using shuffles.
23959   EVT SourceType = MVT::Other;
23960   bool AllAnyExt = true;
23961 
23962   for (unsigned i = 0; i != NumInScalars; ++i) {
23963     SDValue In = N->getOperand(i);
23964     // Ignore undef inputs.
23965     if (In.isUndef()) continue;
23966 
23967     bool AnyExt  = In.getOpcode() == ISD::ANY_EXTEND;
23968     bool ZeroExt = In.getOpcode() == ISD::ZERO_EXTEND;
23969 
23970     // Abort if the element is not an extension.
23971     if (!ZeroExt && !AnyExt) {
23972       SourceType = MVT::Other;
23973       break;
23974     }
23975 
23976     // The input is a ZeroExt or AnyExt. Check the original type.
23977     EVT InTy = In.getOperand(0).getValueType();
23978 
23979     // Check that all of the widened source types are the same.
23980     if (SourceType == MVT::Other)
23981       // First time.
23982       SourceType = InTy;
23983     else if (InTy != SourceType) {
23984       // Multiple income types. Abort.
23985       SourceType = MVT::Other;
23986       break;
23987     }
23988 
23989     // Check if all of the extends are ANY_EXTENDs.
23990     AllAnyExt &= AnyExt;
23991   }
23992 
23993   // In order to have valid types, all of the inputs must be extended from the
23994   // same source type and all of the inputs must be any or zero extend.
23995   // Scalar sizes must be a power of two.
23996   EVT OutScalarTy = VT.getScalarType();
23997   bool ValidTypes =
23998       SourceType != MVT::Other &&
23999       llvm::has_single_bit<uint32_t>(OutScalarTy.getSizeInBits()) &&
24000       llvm::has_single_bit<uint32_t>(SourceType.getSizeInBits());
24001 
24002   // Create a new simpler BUILD_VECTOR sequence which other optimizations can
24003   // turn into a single shuffle instruction.
24004   if (!ValidTypes)
24005     return SDValue();
24006 
24007   // If we already have a splat buildvector, then don't fold it if it means
24008   // introducing zeros.
24009   if (!AllAnyExt && DAG.isSplatValue(SDValue(N, 0), /*AllowUndefs*/ true))
24010     return SDValue();
24011 
24012   bool isLE = DAG.getDataLayout().isLittleEndian();
24013   unsigned ElemRatio = OutScalarTy.getSizeInBits()/SourceType.getSizeInBits();
24014   assert(ElemRatio > 1 && "Invalid element size ratio");
24015   SDValue Filler = AllAnyExt ? DAG.getUNDEF(SourceType):
24016                                DAG.getConstant(0, DL, SourceType);
24017 
24018   unsigned NewBVElems = ElemRatio * VT.getVectorNumElements();
24019   SmallVector<SDValue, 8> Ops(NewBVElems, Filler);
24020 
24021   // Populate the new build_vector
24022   for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i) {
24023     SDValue Cast = N->getOperand(i);
24024     assert((Cast.getOpcode() == ISD::ANY_EXTEND ||
24025             Cast.getOpcode() == ISD::ZERO_EXTEND ||
24026             Cast.isUndef()) && "Invalid cast opcode");
24027     SDValue In;
24028     if (Cast.isUndef())
24029       In = DAG.getUNDEF(SourceType);
24030     else
24031       In = Cast->getOperand(0);
24032     unsigned Index = isLE ? (i * ElemRatio) :
24033                             (i * ElemRatio + (ElemRatio - 1));
24034 
24035     assert(Index < Ops.size() && "Invalid index");
24036     Ops[Index] = In;
24037   }
24038 
24039   // The type of the new BUILD_VECTOR node.
24040   EVT VecVT = EVT::getVectorVT(*DAG.getContext(), SourceType, NewBVElems);
24041   assert(VecVT.getSizeInBits() == VT.getSizeInBits() &&
24042          "Invalid vector size");
24043   // Check if the new vector type is legal.
24044   if (!isTypeLegal(VecVT) ||
24045       (!TLI.isOperationLegal(ISD::BUILD_VECTOR, VecVT) &&
24046        TLI.isOperationLegal(ISD::BUILD_VECTOR, VT)))
24047     return SDValue();
24048 
24049   // Make the new BUILD_VECTOR.
24050   SDValue BV = DAG.getBuildVector(VecVT, DL, Ops);
24051 
24052   // The new BUILD_VECTOR node has the potential to be further optimized.
24053   AddToWorklist(BV.getNode());
24054   // Bitcast to the desired type.
24055   return DAG.getBitcast(VT, BV);
24056 }
24057 
24058 // Simplify (build_vec (trunc $1)
24059 //                     (trunc (srl $1 half-width))
24060 //                     (trunc (srl $1 (2 * half-width))))
24061 // to (bitcast $1)
reduceBuildVecTruncToBitCast(SDNode * N)24062 SDValue DAGCombiner::reduceBuildVecTruncToBitCast(SDNode *N) {
24063   assert(N->getOpcode() == ISD::BUILD_VECTOR && "Expected build vector");
24064 
24065   EVT VT = N->getValueType(0);
24066 
24067   // Don't run this before LegalizeTypes if VT is legal.
24068   // Targets may have other preferences.
24069   if (Level < AfterLegalizeTypes && TLI.isTypeLegal(VT))
24070     return SDValue();
24071 
24072   // Only for little endian
24073   if (!DAG.getDataLayout().isLittleEndian())
24074     return SDValue();
24075 
24076   EVT OutScalarTy = VT.getScalarType();
24077   uint64_t ScalarTypeBitsize = OutScalarTy.getSizeInBits();
24078 
24079   // Only for power of two types to be sure that bitcast works well
24080   if (!isPowerOf2_64(ScalarTypeBitsize))
24081     return SDValue();
24082 
24083   unsigned NumInScalars = N->getNumOperands();
24084 
24085   // Look through bitcasts
24086   auto PeekThroughBitcast = [](SDValue Op) {
24087     if (Op.getOpcode() == ISD::BITCAST)
24088       return Op.getOperand(0);
24089     return Op;
24090   };
24091 
24092   // The source value where all the parts are extracted.
24093   SDValue Src;
24094   for (unsigned i = 0; i != NumInScalars; ++i) {
24095     SDValue In = PeekThroughBitcast(N->getOperand(i));
24096     // Ignore undef inputs.
24097     if (In.isUndef()) continue;
24098 
24099     if (In.getOpcode() != ISD::TRUNCATE)
24100       return SDValue();
24101 
24102     In = PeekThroughBitcast(In.getOperand(0));
24103 
24104     if (In.getOpcode() != ISD::SRL) {
24105       // For now only build_vec without shuffling, handle shifts here in the
24106       // future.
24107       if (i != 0)
24108         return SDValue();
24109 
24110       Src = In;
24111     } else {
24112       // In is SRL
24113       SDValue part = PeekThroughBitcast(In.getOperand(0));
24114 
24115       if (!Src) {
24116         Src = part;
24117       } else if (Src != part) {
24118         // Vector parts do not stem from the same variable
24119         return SDValue();
24120       }
24121 
24122       SDValue ShiftAmtVal = In.getOperand(1);
24123       if (!isa<ConstantSDNode>(ShiftAmtVal))
24124         return SDValue();
24125 
24126       uint64_t ShiftAmt = In.getConstantOperandVal(1);
24127 
24128       // The extracted value is not extracted at the right position
24129       if (ShiftAmt != i * ScalarTypeBitsize)
24130         return SDValue();
24131     }
24132   }
24133 
24134   // Only cast if the size is the same
24135   if (!Src || Src.getValueType().getSizeInBits() != VT.getSizeInBits())
24136     return SDValue();
24137 
24138   return DAG.getBitcast(VT, Src);
24139 }
24140 
createBuildVecShuffle(const SDLoc & DL,SDNode * N,ArrayRef<int> VectorMask,SDValue VecIn1,SDValue VecIn2,unsigned LeftIdx,bool DidSplitVec)24141 SDValue DAGCombiner::createBuildVecShuffle(const SDLoc &DL, SDNode *N,
24142                                            ArrayRef<int> VectorMask,
24143                                            SDValue VecIn1, SDValue VecIn2,
24144                                            unsigned LeftIdx, bool DidSplitVec) {
24145   EVT VT = N->getValueType(0);
24146   EVT InVT1 = VecIn1.getValueType();
24147   EVT InVT2 = VecIn2.getNode() ? VecIn2.getValueType() : InVT1;
24148 
24149   unsigned NumElems = VT.getVectorNumElements();
24150   unsigned ShuffleNumElems = NumElems;
24151 
24152   // If we artificially split a vector in two already, then the offsets in the
24153   // operands will all be based off of VecIn1, even those in VecIn2.
24154   unsigned Vec2Offset = DidSplitVec ? 0 : InVT1.getVectorNumElements();
24155 
24156   uint64_t VTSize = VT.getFixedSizeInBits();
24157   uint64_t InVT1Size = InVT1.getFixedSizeInBits();
24158   uint64_t InVT2Size = InVT2.getFixedSizeInBits();
24159 
24160   assert(InVT2Size <= InVT1Size &&
24161          "Inputs must be sorted to be in non-increasing vector size order.");
24162 
24163   // We can't generate a shuffle node with mismatched input and output types.
24164   // Try to make the types match the type of the output.
24165   if (InVT1 != VT || InVT2 != VT) {
24166     if ((VTSize % InVT1Size == 0) && InVT1 == InVT2) {
24167       // If the output vector length is a multiple of both input lengths,
24168       // we can concatenate them and pad the rest with undefs.
24169       unsigned NumConcats = VTSize / InVT1Size;
24170       assert(NumConcats >= 2 && "Concat needs at least two inputs!");
24171       SmallVector<SDValue, 2> ConcatOps(NumConcats, DAG.getUNDEF(InVT1));
24172       ConcatOps[0] = VecIn1;
24173       ConcatOps[1] = VecIn2 ? VecIn2 : DAG.getUNDEF(InVT1);
24174       VecIn1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps);
24175       VecIn2 = SDValue();
24176     } else if (InVT1Size == VTSize * 2) {
24177       if (!TLI.isExtractSubvectorCheap(VT, InVT1, NumElems))
24178         return SDValue();
24179 
24180       if (!VecIn2.getNode()) {
24181         // If we only have one input vector, and it's twice the size of the
24182         // output, split it in two.
24183         VecIn2 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, VecIn1,
24184                              DAG.getVectorIdxConstant(NumElems, DL));
24185         VecIn1 = DAG.getExtractSubvector(DL, VT, VecIn1, 0);
24186         // Since we now have shorter input vectors, adjust the offset of the
24187         // second vector's start.
24188         Vec2Offset = NumElems;
24189       } else {
24190         assert(InVT2Size <= InVT1Size &&
24191                "Second input is not going to be larger than the first one.");
24192 
24193         // VecIn1 is wider than the output, and we have another, possibly
24194         // smaller input. Pad the smaller input with undefs, shuffle at the
24195         // input vector width, and extract the output.
24196         // The shuffle type is different than VT, so check legality again.
24197         if (LegalOperations &&
24198             !TLI.isOperationLegal(ISD::VECTOR_SHUFFLE, InVT1))
24199           return SDValue();
24200 
24201         // Legalizing INSERT_SUBVECTOR is tricky - you basically have to
24202         // lower it back into a BUILD_VECTOR. So if the inserted type is
24203         // illegal, don't even try.
24204         if (InVT1 != InVT2) {
24205           if (!TLI.isTypeLegal(InVT2))
24206             return SDValue();
24207           VecIn2 = DAG.getInsertSubvector(DL, DAG.getUNDEF(InVT1), VecIn2, 0);
24208         }
24209         ShuffleNumElems = NumElems * 2;
24210       }
24211     } else if (InVT2Size * 2 == VTSize && InVT1Size == VTSize) {
24212       SmallVector<SDValue, 2> ConcatOps(2, DAG.getUNDEF(InVT2));
24213       ConcatOps[0] = VecIn2;
24214       VecIn2 = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps);
24215     } else if (InVT1Size / VTSize > 1 && InVT1Size % VTSize == 0) {
24216       if (!TLI.isExtractSubvectorCheap(VT, InVT1, NumElems) ||
24217           !TLI.isTypeLegal(InVT1) || !TLI.isTypeLegal(InVT2))
24218         return SDValue();
24219       // If dest vector has less than two elements, then use shuffle and extract
24220       // from larger regs will cost even more.
24221       if (VT.getVectorNumElements() <= 2 || !VecIn2.getNode())
24222         return SDValue();
24223       assert(InVT2Size <= InVT1Size &&
24224              "Second input is not going to be larger than the first one.");
24225 
24226       // VecIn1 is wider than the output, and we have another, possibly
24227       // smaller input. Pad the smaller input with undefs, shuffle at the
24228       // input vector width, and extract the output.
24229       // The shuffle type is different than VT, so check legality again.
24230       if (LegalOperations && !TLI.isOperationLegal(ISD::VECTOR_SHUFFLE, InVT1))
24231         return SDValue();
24232 
24233       if (InVT1 != InVT2) {
24234         VecIn2 = DAG.getInsertSubvector(DL, DAG.getUNDEF(InVT1), VecIn2, 0);
24235       }
24236       ShuffleNumElems = InVT1Size / VTSize * NumElems;
24237     } else {
24238       // TODO: Support cases where the length mismatch isn't exactly by a
24239       // factor of 2.
24240       // TODO: Move this check upwards, so that if we have bad type
24241       // mismatches, we don't create any DAG nodes.
24242       return SDValue();
24243     }
24244   }
24245 
24246   // Initialize mask to undef.
24247   SmallVector<int, 8> Mask(ShuffleNumElems, -1);
24248 
24249   // Only need to run up to the number of elements actually used, not the
24250   // total number of elements in the shuffle - if we are shuffling a wider
24251   // vector, the high lanes should be set to undef.
24252   for (unsigned i = 0; i != NumElems; ++i) {
24253     if (VectorMask[i] <= 0)
24254       continue;
24255 
24256     unsigned ExtIndex = N->getOperand(i).getConstantOperandVal(1);
24257     if (VectorMask[i] == (int)LeftIdx) {
24258       Mask[i] = ExtIndex;
24259     } else if (VectorMask[i] == (int)LeftIdx + 1) {
24260       Mask[i] = Vec2Offset + ExtIndex;
24261     }
24262   }
24263 
24264   // The type the input vectors may have changed above.
24265   InVT1 = VecIn1.getValueType();
24266 
24267   // If we already have a VecIn2, it should have the same type as VecIn1.
24268   // If we don't, get an undef/zero vector of the appropriate type.
24269   VecIn2 = VecIn2.getNode() ? VecIn2 : DAG.getUNDEF(InVT1);
24270   assert(InVT1 == VecIn2.getValueType() && "Unexpected second input type.");
24271 
24272   SDValue Shuffle = DAG.getVectorShuffle(InVT1, DL, VecIn1, VecIn2, Mask);
24273   if (ShuffleNumElems > NumElems)
24274     Shuffle = DAG.getExtractSubvector(DL, VT, Shuffle, 0);
24275 
24276   return Shuffle;
24277 }
24278 
reduceBuildVecToShuffleWithZero(SDNode * BV,SelectionDAG & DAG)24279 static SDValue reduceBuildVecToShuffleWithZero(SDNode *BV, SelectionDAG &DAG) {
24280   assert(BV->getOpcode() == ISD::BUILD_VECTOR && "Expected build vector");
24281 
24282   // First, determine where the build vector is not undef.
24283   // TODO: We could extend this to handle zero elements as well as undefs.
24284   int NumBVOps = BV->getNumOperands();
24285   int ZextElt = -1;
24286   for (int i = 0; i != NumBVOps; ++i) {
24287     SDValue Op = BV->getOperand(i);
24288     if (Op.isUndef())
24289       continue;
24290     if (ZextElt == -1)
24291       ZextElt = i;
24292     else
24293       return SDValue();
24294   }
24295   // Bail out if there's no non-undef element.
24296   if (ZextElt == -1)
24297     return SDValue();
24298 
24299   // The build vector contains some number of undef elements and exactly
24300   // one other element. That other element must be a zero-extended scalar
24301   // extracted from a vector at a constant index to turn this into a shuffle.
24302   // Also, require that the build vector does not implicitly truncate/extend
24303   // its elements.
24304   // TODO: This could be enhanced to allow ANY_EXTEND as well as ZERO_EXTEND.
24305   EVT VT = BV->getValueType(0);
24306   SDValue Zext = BV->getOperand(ZextElt);
24307   if (Zext.getOpcode() != ISD::ZERO_EXTEND || !Zext.hasOneUse() ||
24308       Zext.getOperand(0).getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
24309       !isa<ConstantSDNode>(Zext.getOperand(0).getOperand(1)) ||
24310       Zext.getValueSizeInBits() != VT.getScalarSizeInBits())
24311     return SDValue();
24312 
24313   // The zero-extend must be a multiple of the source size, and we must be
24314   // building a vector of the same size as the source of the extract element.
24315   SDValue Extract = Zext.getOperand(0);
24316   unsigned DestSize = Zext.getValueSizeInBits();
24317   unsigned SrcSize = Extract.getValueSizeInBits();
24318   if (DestSize % SrcSize != 0 ||
24319       Extract.getOperand(0).getValueSizeInBits() != VT.getSizeInBits())
24320     return SDValue();
24321 
24322   // Create a shuffle mask that will combine the extracted element with zeros
24323   // and undefs.
24324   int ZextRatio = DestSize / SrcSize;
24325   int NumMaskElts = NumBVOps * ZextRatio;
24326   SmallVector<int, 32> ShufMask(NumMaskElts, -1);
24327   for (int i = 0; i != NumMaskElts; ++i) {
24328     if (i / ZextRatio == ZextElt) {
24329       // The low bits of the (potentially translated) extracted element map to
24330       // the source vector. The high bits map to zero. We will use a zero vector
24331       // as the 2nd source operand of the shuffle, so use the 1st element of
24332       // that vector (mask value is number-of-elements) for the high bits.
24333       int Low = DAG.getDataLayout().isBigEndian() ? (ZextRatio - 1) : 0;
24334       ShufMask[i] = (i % ZextRatio == Low) ? Extract.getConstantOperandVal(1)
24335                                            : NumMaskElts;
24336     }
24337 
24338     // Undef elements of the build vector remain undef because we initialize
24339     // the shuffle mask with -1.
24340   }
24341 
24342   // buildvec undef, ..., (zext (extractelt V, IndexC)), undef... -->
24343   // bitcast (shuffle V, ZeroVec, VectorMask)
24344   SDLoc DL(BV);
24345   EVT VecVT = Extract.getOperand(0).getValueType();
24346   SDValue ZeroVec = DAG.getConstant(0, DL, VecVT);
24347   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
24348   SDValue Shuf = TLI.buildLegalVectorShuffle(VecVT, DL, Extract.getOperand(0),
24349                                              ZeroVec, ShufMask, DAG);
24350   if (!Shuf)
24351     return SDValue();
24352   return DAG.getBitcast(VT, Shuf);
24353 }
24354 
// FIXME: promote to STLExtras.
/// Return the position of the first element of \p Range that compares equal
/// to \p Val. If no element matches, return -1 cast to the range's
/// iterator difference type.
template <typename R, typename T>
static auto getFirstIndexOf(R &&Range, const T &Val) {
  auto First = Range.begin();
  auto Last = Range.end();
  using IndexTy = decltype(std::distance(First, Last));
  auto Found = std::find(First, Last, Val);
  return Found == Last ? static_cast<IndexTy>(-1)
                       : std::distance(First, Found);
}
24363 
24364 // Check to see if this is a BUILD_VECTOR of a bunch of EXTRACT_VECTOR_ELT
24365 // operations. If the types of the vectors we're extracting from allow it,
24366 // turn this into a vector_shuffle node.
reduceBuildVecToShuffle(SDNode * N)24367 SDValue DAGCombiner::reduceBuildVecToShuffle(SDNode *N) {
24368   SDLoc DL(N);
24369   EVT VT = N->getValueType(0);
24370 
24371   // Only type-legal BUILD_VECTOR nodes are converted to shuffle nodes.
24372   if (!isTypeLegal(VT))
24373     return SDValue();
24374 
24375   if (SDValue V = reduceBuildVecToShuffleWithZero(N, DAG))
24376     return V;
24377 
24378   // May only combine to shuffle after legalize if shuffle is legal.
24379   if (LegalOperations && !TLI.isOperationLegal(ISD::VECTOR_SHUFFLE, VT))
24380     return SDValue();
24381 
24382   bool UsesZeroVector = false;
24383   unsigned NumElems = N->getNumOperands();
24384 
24385   // Record, for each element of the newly built vector, which input vector
24386   // that element comes from. -1 stands for undef, 0 for the zero vector,
24387   // and positive values for the input vectors.
24388   // VectorMask maps each element to its vector number, and VecIn maps vector
24389   // numbers to their initial SDValues.
24390 
24391   SmallVector<int, 8> VectorMask(NumElems, -1);
24392   SmallVector<SDValue, 8> VecIn;
24393   VecIn.push_back(SDValue());
24394 
24395   // If we have a single extract_element with a constant index, track the index
24396   // value.
24397   unsigned OneConstExtractIndex = ~0u;
24398 
24399   // Count the number of extract_vector_elt sources (i.e. non-constant or undef)
24400   unsigned NumExtracts = 0;
24401 
24402   for (unsigned i = 0; i != NumElems; ++i) {
24403     SDValue Op = N->getOperand(i);
24404 
24405     if (Op.isUndef())
24406       continue;
24407 
24408     // See if we can use a blend with a zero vector.
24409     // TODO: Should we generalize this to a blend with an arbitrary constant
24410     // vector?
24411     if (isNullConstant(Op) || isNullFPConstant(Op)) {
24412       UsesZeroVector = true;
24413       VectorMask[i] = 0;
24414       continue;
24415     }
24416 
24417     // Not an undef or zero. If the input is something other than an
24418     // EXTRACT_VECTOR_ELT with an in-range constant index, bail out.
24419     if (Op.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
24420       return SDValue();
24421 
24422     SDValue ExtractedFromVec = Op.getOperand(0);
24423     if (ExtractedFromVec.getValueType().isScalableVector())
24424       return SDValue();
24425     auto *ExtractIdx = dyn_cast<ConstantSDNode>(Op.getOperand(1));
24426     if (!ExtractIdx)
24427       return SDValue();
24428 
24429     if (ExtractIdx->getAsAPIntVal().uge(
24430             ExtractedFromVec.getValueType().getVectorNumElements()))
24431       return SDValue();
24432 
24433     // All inputs must have the same element type as the output.
24434     if (VT.getVectorElementType() !=
24435         ExtractedFromVec.getValueType().getVectorElementType())
24436       return SDValue();
24437 
24438     OneConstExtractIndex = ExtractIdx->getZExtValue();
24439     ++NumExtracts;
24440 
24441     // Have we seen this input vector before?
24442     // The vectors are expected to be tiny (usually 1 or 2 elements), so using
24443     // a map back from SDValues to numbers isn't worth it.
24444     int Idx = getFirstIndexOf(VecIn, ExtractedFromVec);
24445     if (Idx == -1) { // A new source vector?
24446       Idx = VecIn.size();
24447       VecIn.push_back(ExtractedFromVec);
24448     }
24449 
24450     VectorMask[i] = Idx;
24451   }
24452 
24453   // If we didn't find at least one input vector, bail out.
24454   if (VecIn.size() < 2)
24455     return SDValue();
24456 
24457   // If all the Operands of BUILD_VECTOR extract from same
24458   // vector, then split the vector efficiently based on the maximum
24459   // vector access index and adjust the VectorMask and
24460   // VecIn accordingly.
24461   bool DidSplitVec = false;
24462   if (VecIn.size() == 2) {
24463     // If we only found a single constant indexed extract_vector_elt feeding the
24464     // build_vector, do not produce a more complicated shuffle if the extract is
24465     // cheap with other constant/undef elements. Skip broadcast patterns with
24466     // multiple uses in the build_vector.
24467 
24468     // TODO: This should be more aggressive about skipping the shuffle
24469     // formation, particularly if VecIn[1].hasOneUse(), and regardless of the
24470     // index.
24471     if (NumExtracts == 1 &&
24472         TLI.isOperationLegalOrCustom(ISD::EXTRACT_VECTOR_ELT, VT) &&
24473         TLI.isTypeLegal(VT.getVectorElementType()) &&
24474         TLI.isExtractVecEltCheap(VT, OneConstExtractIndex))
24475       return SDValue();
24476 
24477     unsigned MaxIndex = 0;
24478     unsigned NearestPow2 = 0;
24479     SDValue Vec = VecIn.back();
24480     EVT InVT = Vec.getValueType();
24481     SmallVector<unsigned, 8> IndexVec(NumElems, 0);
24482 
24483     for (unsigned i = 0; i < NumElems; i++) {
24484       if (VectorMask[i] <= 0)
24485         continue;
24486       unsigned Index = N->getOperand(i).getConstantOperandVal(1);
24487       IndexVec[i] = Index;
24488       MaxIndex = std::max(MaxIndex, Index);
24489     }
24490 
24491     NearestPow2 = PowerOf2Ceil(MaxIndex);
24492     if (InVT.isSimple() && NearestPow2 > 2 && MaxIndex < NearestPow2 &&
24493         NumElems * 2 < NearestPow2) {
24494       unsigned SplitSize = NearestPow2 / 2;
24495       EVT SplitVT = EVT::getVectorVT(*DAG.getContext(),
24496                                      InVT.getVectorElementType(), SplitSize);
24497       if (TLI.isTypeLegal(SplitVT) &&
24498           SplitSize + SplitVT.getVectorNumElements() <=
24499               InVT.getVectorNumElements()) {
24500         SDValue VecIn2 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SplitVT, Vec,
24501                                      DAG.getVectorIdxConstant(SplitSize, DL));
24502         SDValue VecIn1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SplitVT, Vec,
24503                                      DAG.getVectorIdxConstant(0, DL));
24504         VecIn.pop_back();
24505         VecIn.push_back(VecIn1);
24506         VecIn.push_back(VecIn2);
24507         DidSplitVec = true;
24508 
24509         for (unsigned i = 0; i < NumElems; i++) {
24510           if (VectorMask[i] <= 0)
24511             continue;
24512           VectorMask[i] = (IndexVec[i] < SplitSize) ? 1 : 2;
24513         }
24514       }
24515     }
24516   }
24517 
24518   // Sort input vectors by decreasing vector element count,
24519   // while preserving the relative order of equally-sized vectors.
24520   // Note that we keep the first "implicit zero vector as-is.
24521   SmallVector<SDValue, 8> SortedVecIn(VecIn);
24522   llvm::stable_sort(MutableArrayRef<SDValue>(SortedVecIn).drop_front(),
24523                     [](const SDValue &a, const SDValue &b) {
24524                       return a.getValueType().getVectorNumElements() >
24525                              b.getValueType().getVectorNumElements();
24526                     });
24527 
24528   // We now also need to rebuild the VectorMask, because it referenced element
24529   // order in VecIn, and we just sorted them.
24530   for (int &SourceVectorIndex : VectorMask) {
24531     if (SourceVectorIndex <= 0)
24532       continue;
24533     unsigned Idx = getFirstIndexOf(SortedVecIn, VecIn[SourceVectorIndex]);
24534     assert(Idx > 0 && Idx < SortedVecIn.size() &&
24535            VecIn[SourceVectorIndex] == SortedVecIn[Idx] && "Remapping failure");
24536     SourceVectorIndex = Idx;
24537   }
24538 
24539   VecIn = std::move(SortedVecIn);
24540 
24541   // TODO: Should this fire if some of the input vectors has illegal type (like
24542   // it does now), or should we let legalization run its course first?
24543 
24544   // Shuffle phase:
24545   // Take pairs of vectors, and shuffle them so that the result has elements
24546   // from these vectors in the correct places.
24547   // For example, given:
24548   // t10: i32 = extract_vector_elt t1, Constant:i64<0>
24549   // t11: i32 = extract_vector_elt t2, Constant:i64<0>
24550   // t12: i32 = extract_vector_elt t3, Constant:i64<0>
24551   // t13: i32 = extract_vector_elt t1, Constant:i64<1>
24552   // t14: v4i32 = BUILD_VECTOR t10, t11, t12, t13
24553   // We will generate:
24554   // t20: v4i32 = vector_shuffle<0,4,u,1> t1, t2
24555   // t21: v4i32 = vector_shuffle<u,u,0,u> t3, undef
24556   SmallVector<SDValue, 4> Shuffles;
24557   for (unsigned In = 0, Len = (VecIn.size() / 2); In < Len; ++In) {
24558     unsigned LeftIdx = 2 * In + 1;
24559     SDValue VecLeft = VecIn[LeftIdx];
24560     SDValue VecRight =
24561         (LeftIdx + 1) < VecIn.size() ? VecIn[LeftIdx + 1] : SDValue();
24562 
24563     if (SDValue Shuffle = createBuildVecShuffle(DL, N, VectorMask, VecLeft,
24564                                                 VecRight, LeftIdx, DidSplitVec))
24565       Shuffles.push_back(Shuffle);
24566     else
24567       return SDValue();
24568   }
24569 
24570   // If we need the zero vector as an "ingredient" in the blend tree, add it
24571   // to the list of shuffles.
24572   if (UsesZeroVector)
24573     Shuffles.push_back(VT.isInteger() ? DAG.getConstant(0, DL, VT)
24574                                       : DAG.getConstantFP(0.0, DL, VT));
24575 
24576   // If we only have one shuffle, we're done.
24577   if (Shuffles.size() == 1)
24578     return Shuffles[0];
24579 
24580   // Update the vector mask to point to the post-shuffle vectors.
24581   for (int &Vec : VectorMask)
24582     if (Vec == 0)
24583       Vec = Shuffles.size() - 1;
24584     else
24585       Vec = (Vec - 1) / 2;
24586 
24587   // More than one shuffle. Generate a binary tree of blends, e.g. if from
24588   // the previous step we got the set of shuffles t10, t11, t12, t13, we will
24589   // generate:
24590   // t10: v8i32 = vector_shuffle<0,8,u,u,u,u,u,u> t1, t2
24591   // t11: v8i32 = vector_shuffle<u,u,0,8,u,u,u,u> t3, t4
24592   // t12: v8i32 = vector_shuffle<u,u,u,u,0,8,u,u> t5, t6
24593   // t13: v8i32 = vector_shuffle<u,u,u,u,u,u,0,8> t7, t8
24594   // t20: v8i32 = vector_shuffle<0,1,10,11,u,u,u,u> t10, t11
24595   // t21: v8i32 = vector_shuffle<u,u,u,u,4,5,14,15> t12, t13
24596   // t30: v8i32 = vector_shuffle<0,1,2,3,12,13,14,15> t20, t21
24597 
24598   // Make sure the initial size of the shuffle list is even.
24599   if (Shuffles.size() % 2)
24600     Shuffles.push_back(DAG.getUNDEF(VT));
24601 
24602   for (unsigned CurSize = Shuffles.size(); CurSize > 1; CurSize /= 2) {
24603     if (CurSize % 2) {
24604       Shuffles[CurSize] = DAG.getUNDEF(VT);
24605       CurSize++;
24606     }
24607     for (unsigned In = 0, Len = CurSize / 2; In < Len; ++In) {
24608       int Left = 2 * In;
24609       int Right = 2 * In + 1;
24610       SmallVector<int, 8> Mask(NumElems, -1);
24611       SDValue L = Shuffles[Left];
24612       ArrayRef<int> LMask;
24613       bool IsLeftShuffle = L.getOpcode() == ISD::VECTOR_SHUFFLE &&
24614                            L.use_empty() && L.getOperand(1).isUndef() &&
24615                            L.getOperand(0).getValueType() == L.getValueType();
24616       if (IsLeftShuffle) {
24617         LMask = cast<ShuffleVectorSDNode>(L.getNode())->getMask();
24618         L = L.getOperand(0);
24619       }
24620       SDValue R = Shuffles[Right];
24621       ArrayRef<int> RMask;
24622       bool IsRightShuffle = R.getOpcode() == ISD::VECTOR_SHUFFLE &&
24623                             R.use_empty() && R.getOperand(1).isUndef() &&
24624                             R.getOperand(0).getValueType() == R.getValueType();
24625       if (IsRightShuffle) {
24626         RMask = cast<ShuffleVectorSDNode>(R.getNode())->getMask();
24627         R = R.getOperand(0);
24628       }
24629       for (unsigned I = 0; I != NumElems; ++I) {
24630         if (VectorMask[I] == Left) {
24631           Mask[I] = I;
24632           if (IsLeftShuffle)
24633             Mask[I] = LMask[I];
24634           VectorMask[I] = In;
24635         } else if (VectorMask[I] == Right) {
24636           Mask[I] = I + NumElems;
24637           if (IsRightShuffle)
24638             Mask[I] = RMask[I] + NumElems;
24639           VectorMask[I] = In;
24640         }
24641       }
24642 
24643       Shuffles[In] = DAG.getVectorShuffle(VT, DL, L, R, Mask);
24644     }
24645   }
24646   return Shuffles[0];
24647 }
24648 
24649 // Try to turn a build vector of zero extends of extract vector elts into a
24650 // a vector zero extend and possibly an extract subvector.
24651 // TODO: Support sign extend?
24652 // TODO: Allow undef elements?
convertBuildVecZextToZext(SDNode * N)24653 SDValue DAGCombiner::convertBuildVecZextToZext(SDNode *N) {
24654   if (LegalOperations)
24655     return SDValue();
24656 
24657   EVT VT = N->getValueType(0);
24658 
24659   bool FoundZeroExtend = false;
24660   SDValue Op0 = N->getOperand(0);
24661   auto checkElem = [&](SDValue Op) -> int64_t {
24662     unsigned Opc = Op.getOpcode();
24663     FoundZeroExtend |= (Opc == ISD::ZERO_EXTEND);
24664     if ((Opc == ISD::ZERO_EXTEND || Opc == ISD::ANY_EXTEND) &&
24665         Op.getOperand(0).getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
24666         Op0.getOperand(0).getOperand(0) == Op.getOperand(0).getOperand(0))
24667       if (auto *C = dyn_cast<ConstantSDNode>(Op.getOperand(0).getOperand(1)))
24668         return C->getZExtValue();
24669     return -1;
24670   };
24671 
24672   // Make sure the first element matches
24673   // (zext (extract_vector_elt X, C))
24674   // Offset must be a constant multiple of the
24675   // known-minimum vector length of the result type.
24676   int64_t Offset = checkElem(Op0);
24677   if (Offset < 0 || (Offset % VT.getVectorNumElements()) != 0)
24678     return SDValue();
24679 
24680   unsigned NumElems = N->getNumOperands();
24681   SDValue In = Op0.getOperand(0).getOperand(0);
24682   EVT InSVT = In.getValueType().getScalarType();
24683   EVT InVT = EVT::getVectorVT(*DAG.getContext(), InSVT, NumElems);
24684 
24685   // Don't create an illegal input type after type legalization.
24686   if (LegalTypes && !TLI.isTypeLegal(InVT))
24687     return SDValue();
24688 
24689   // Ensure all the elements come from the same vector and are adjacent.
24690   for (unsigned i = 1; i != NumElems; ++i) {
24691     if ((Offset + i) != checkElem(N->getOperand(i)))
24692       return SDValue();
24693   }
24694 
24695   SDLoc DL(N);
24696   In = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, InVT, In,
24697                    Op0.getOperand(0).getOperand(1));
24698   return DAG.getNode(FoundZeroExtend ? ISD::ZERO_EXTEND : ISD::ANY_EXTEND, DL,
24699                      VT, In);
24700 }
24701 
24702 // If this is a very simple BUILD_VECTOR with first element being a ZERO_EXTEND,
24703 // and all other elements being constant zero's, granularize the BUILD_VECTOR's
24704 // element width, absorbing the ZERO_EXTEND, turning it into a constant zero op.
24705 // This patten can appear during legalization.
24706 //
24707 // NOTE: This can be generalized to allow more than a single
24708 //       non-constant-zero op, UNDEF's, and to be KnownBits-based,
convertBuildVecZextToBuildVecWithZeros(SDNode * N)24709 SDValue DAGCombiner::convertBuildVecZextToBuildVecWithZeros(SDNode *N) {
24710   // Don't run this after legalization. Targets may have other preferences.
24711   if (Level >= AfterLegalizeDAG)
24712     return SDValue();
24713 
24714   // FIXME: support big-endian.
24715   if (DAG.getDataLayout().isBigEndian())
24716     return SDValue();
24717 
24718   EVT VT = N->getValueType(0);
24719   EVT OpVT = N->getOperand(0).getValueType();
24720   assert(!VT.isScalableVector() && "Encountered scalable BUILD_VECTOR?");
24721 
24722   EVT OpIntVT = EVT::getIntegerVT(*DAG.getContext(), OpVT.getSizeInBits());
24723 
24724   if (!TLI.isTypeLegal(OpIntVT) ||
24725       (LegalOperations && !TLI.isOperationLegalOrCustom(ISD::BITCAST, OpIntVT)))
24726     return SDValue();
24727 
24728   unsigned EltBitwidth = VT.getScalarSizeInBits();
24729   // NOTE: the actual width of operands may be wider than that!
24730 
24731   // Analyze all operands of this BUILD_VECTOR. What is the largest number of
24732   // active bits they all have? We'll want to truncate them all to that width.
24733   unsigned ActiveBits = 0;
24734   APInt KnownZeroOps(VT.getVectorNumElements(), 0);
24735   for (auto I : enumerate(N->ops())) {
24736     SDValue Op = I.value();
24737     // FIXME: support UNDEF elements?
24738     if (auto *Cst = dyn_cast<ConstantSDNode>(Op)) {
24739       unsigned OpActiveBits =
24740           Cst->getAPIntValue().trunc(EltBitwidth).getActiveBits();
24741       if (OpActiveBits == 0) {
24742         KnownZeroOps.setBit(I.index());
24743         continue;
24744       }
24745       // Profitability check: don't allow non-zero constant operands.
24746       return SDValue();
24747     }
24748     // Profitability check: there must only be a single non-zero operand,
24749     // and it must be the first operand of the BUILD_VECTOR.
24750     if (I.index() != 0)
24751       return SDValue();
24752     // The operand must be a zero-extension itself.
24753     // FIXME: this could be generalized to known leading zeros check.
24754     if (Op.getOpcode() != ISD::ZERO_EXTEND)
24755       return SDValue();
24756     unsigned CurrActiveBits =
24757         Op.getOperand(0).getValueSizeInBits().getFixedValue();
24758     assert(!ActiveBits && "Already encountered non-constant-zero operand?");
24759     ActiveBits = CurrActiveBits;
24760     // We want to at least halve the element size.
24761     if (2 * ActiveBits > EltBitwidth)
24762       return SDValue();
24763   }
24764 
24765   // This BUILD_VECTOR must have at least one non-constant-zero operand.
24766   if (ActiveBits == 0)
24767     return SDValue();
24768 
24769   // We have EltBitwidth bits, the *minimal* chunk size is ActiveBits,
24770   // into how many chunks can we split our element width?
24771   EVT NewScalarIntVT, NewIntVT;
24772   std::optional<unsigned> Factor;
24773   // We can split the element into at least two chunks, but not into more
24774   // than |_ EltBitwidth / ActiveBits _| chunks. Find a largest split factor
24775   // for which the element width is a multiple of it,
24776   // and the resulting types/operations on that chunk width are legal.
24777   assert(2 * ActiveBits <= EltBitwidth &&
24778          "We know that half or less bits of the element are active.");
24779   for (unsigned Scale = EltBitwidth / ActiveBits; Scale >= 2; --Scale) {
24780     if (EltBitwidth % Scale != 0)
24781       continue;
24782     unsigned ChunkBitwidth = EltBitwidth / Scale;
24783     assert(ChunkBitwidth >= ActiveBits && "As per starting point.");
24784     NewScalarIntVT = EVT::getIntegerVT(*DAG.getContext(), ChunkBitwidth);
24785     NewIntVT = EVT::getVectorVT(*DAG.getContext(), NewScalarIntVT,
24786                                 Scale * N->getNumOperands());
24787     if (!TLI.isTypeLegal(NewScalarIntVT) || !TLI.isTypeLegal(NewIntVT) ||
24788         (LegalOperations &&
24789          !(TLI.isOperationLegalOrCustom(ISD::TRUNCATE, NewScalarIntVT) &&
24790            TLI.isOperationLegalOrCustom(ISD::BUILD_VECTOR, NewIntVT))))
24791       continue;
24792     Factor = Scale;
24793     break;
24794   }
24795   if (!Factor)
24796     return SDValue();
24797 
24798   SDLoc DL(N);
24799   SDValue ZeroOp = DAG.getConstant(0, DL, NewScalarIntVT);
24800 
24801   // Recreate the BUILD_VECTOR, with elements now being Factor times smaller.
24802   SmallVector<SDValue, 16> NewOps;
24803   NewOps.reserve(NewIntVT.getVectorNumElements());
24804   for (auto I : enumerate(N->ops())) {
24805     SDValue Op = I.value();
24806     assert(!Op.isUndef() && "FIXME: after allowing UNDEF's, handle them here.");
24807     unsigned SrcOpIdx = I.index();
24808     if (KnownZeroOps[SrcOpIdx]) {
24809       NewOps.append(*Factor, ZeroOp);
24810       continue;
24811     }
24812     Op = DAG.getBitcast(OpIntVT, Op);
24813     Op = DAG.getNode(ISD::TRUNCATE, DL, NewScalarIntVT, Op);
24814     NewOps.emplace_back(Op);
24815     NewOps.append(*Factor - 1, ZeroOp);
24816   }
24817   assert(NewOps.size() == NewIntVT.getVectorNumElements());
24818   SDValue NewBV = DAG.getBuildVector(NewIntVT, DL, NewOps);
24819   NewBV = DAG.getBitcast(VT, NewBV);
24820   return NewBV;
24821 }
24822 
visitBUILD_VECTOR(SDNode * N)24823 SDValue DAGCombiner::visitBUILD_VECTOR(SDNode *N) {
24824   EVT VT = N->getValueType(0);
24825 
24826   // A vector built entirely of undefs is undef.
24827   if (ISD::allOperandsUndef(N))
24828     return DAG.getUNDEF(VT);
24829 
24830   // If this is a splat of a bitcast from another vector, change to a
24831   // concat_vector.
24832   // For example:
24833   //   (build_vector (i64 (bitcast (v2i32 X))), (i64 (bitcast (v2i32 X)))) ->
24834   //     (v2i64 (bitcast (concat_vectors (v2i32 X), (v2i32 X))))
24835   //
24836   // If X is a build_vector itself, the concat can become a larger build_vector.
24837   // TODO: Maybe this is useful for non-splat too?
24838   if (!LegalOperations) {
24839     SDValue Splat = cast<BuildVectorSDNode>(N)->getSplatValue();
24840     // Only change build_vector to a concat_vector if the splat value type is
24841     // same as the vector element type.
24842     if (Splat && Splat.getValueType() == VT.getVectorElementType()) {
24843       Splat = peekThroughBitcasts(Splat);
24844       EVT SrcVT = Splat.getValueType();
24845       if (SrcVT.isVector()) {
24846         unsigned NumElts = N->getNumOperands() * SrcVT.getVectorNumElements();
24847         EVT NewVT = EVT::getVectorVT(*DAG.getContext(),
24848                                      SrcVT.getVectorElementType(), NumElts);
24849         if (!LegalTypes || TLI.isTypeLegal(NewVT)) {
24850           SmallVector<SDValue, 8> Ops(N->getNumOperands(), Splat);
24851           SDValue Concat =
24852               DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), NewVT, Ops);
24853           return DAG.getBitcast(VT, Concat);
24854         }
24855       }
24856     }
24857   }
24858 
24859   // Check if we can express BUILD VECTOR via subvector extract.
24860   if (!LegalTypes && (N->getNumOperands() > 1)) {
24861     SDValue Op0 = N->getOperand(0);
24862     auto checkElem = [&](SDValue Op) -> uint64_t {
24863       if ((Op.getOpcode() == ISD::EXTRACT_VECTOR_ELT) &&
24864           (Op0.getOperand(0) == Op.getOperand(0)))
24865         if (auto CNode = dyn_cast<ConstantSDNode>(Op.getOperand(1)))
24866           return CNode->getZExtValue();
24867       return -1;
24868     };
24869 
24870     int Offset = checkElem(Op0);
24871     for (unsigned i = 0; i < N->getNumOperands(); ++i) {
24872       if (Offset + i != checkElem(N->getOperand(i))) {
24873         Offset = -1;
24874         break;
24875       }
24876     }
24877 
24878     if ((Offset == 0) &&
24879         (Op0.getOperand(0).getValueType() == N->getValueType(0)))
24880       return Op0.getOperand(0);
24881     if ((Offset != -1) &&
24882         ((Offset % N->getValueType(0).getVectorNumElements()) ==
24883          0)) // IDX must be multiple of output size.
24884       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), N->getValueType(0),
24885                          Op0.getOperand(0), Op0.getOperand(1));
24886   }
24887 
24888   if (SDValue V = convertBuildVecZextToZext(N))
24889     return V;
24890 
24891   if (SDValue V = convertBuildVecZextToBuildVecWithZeros(N))
24892     return V;
24893 
24894   if (SDValue V = reduceBuildVecExtToExtBuildVec(N))
24895     return V;
24896 
24897   if (SDValue V = reduceBuildVecTruncToBitCast(N))
24898     return V;
24899 
24900   if (SDValue V = reduceBuildVecToShuffle(N))
24901     return V;
24902 
24903   // A splat of a single element is a SPLAT_VECTOR if supported on the target.
24904   // Do this late as some of the above may replace the splat.
24905   if (TLI.getOperationAction(ISD::SPLAT_VECTOR, VT) != TargetLowering::Expand)
24906     if (SDValue V = cast<BuildVectorSDNode>(N)->getSplatValue()) {
24907       assert(!V.isUndef() && "Splat of undef should have been handled earlier");
24908       return DAG.getNode(ISD::SPLAT_VECTOR, SDLoc(N), VT, V);
24909     }
24910 
24911   return SDValue();
24912 }
24913 
combineConcatVectorOfScalars(SDNode * N,SelectionDAG & DAG)24914 static SDValue combineConcatVectorOfScalars(SDNode *N, SelectionDAG &DAG) {
24915   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
24916   EVT OpVT = N->getOperand(0).getValueType();
24917 
24918   // If the operands are legal vectors, leave them alone.
24919   if (TLI.isTypeLegal(OpVT) || OpVT.isScalableVector())
24920     return SDValue();
24921 
24922   SDLoc DL(N);
24923   EVT VT = N->getValueType(0);
24924   SmallVector<SDValue, 8> Ops;
24925   EVT SVT = EVT::getIntegerVT(*DAG.getContext(), OpVT.getSizeInBits());
24926 
24927   // Keep track of what we encounter.
24928   EVT AnyFPVT;
24929 
24930   for (const SDValue &Op : N->ops()) {
24931     if (ISD::BITCAST == Op.getOpcode() &&
24932         !Op.getOperand(0).getValueType().isVector())
24933       Ops.push_back(Op.getOperand(0));
24934     else if (Op.isUndef())
24935       Ops.push_back(DAG.getNode(ISD::UNDEF, DL, SVT));
24936     else
24937       return SDValue();
24938 
24939     // Note whether we encounter an integer or floating point scalar.
24940     // If it's neither, bail out, it could be something weird like x86mmx.
24941     EVT LastOpVT = Ops.back().getValueType();
24942     if (LastOpVT.isFloatingPoint())
24943       AnyFPVT = LastOpVT;
24944     else if (!LastOpVT.isInteger())
24945       return SDValue();
24946   }
24947 
24948   // If any of the operands is a floating point scalar bitcast to a vector,
24949   // use floating point types throughout, and bitcast everything.
24950   // Replace UNDEFs by another scalar UNDEF node, of the final desired type.
24951   if (AnyFPVT != EVT()) {
24952     SVT = AnyFPVT;
24953     for (SDValue &Op : Ops) {
24954       if (Op.getValueType() == SVT)
24955         continue;
24956       if (Op.isUndef())
24957         Op = DAG.getNode(ISD::UNDEF, DL, SVT);
24958       else
24959         Op = DAG.getBitcast(SVT, Op);
24960     }
24961   }
24962 
24963   EVT VecVT = EVT::getVectorVT(*DAG.getContext(), SVT,
24964                                VT.getSizeInBits() / SVT.getSizeInBits());
24965   return DAG.getBitcast(VT, DAG.getBuildVector(VecVT, DL, Ops));
24966 }
24967 
24968 // Attempt to merge nested concat_vectors/undefs.
24969 // Fold concat_vectors(concat_vectors(x,y,z,w),u,u,concat_vectors(a,b,c,d))
24970 //  --> concat_vectors(x,y,z,w,u,u,u,u,u,u,u,u,a,b,c,d)
combineConcatVectorOfConcatVectors(SDNode * N,SelectionDAG & DAG)24971 static SDValue combineConcatVectorOfConcatVectors(SDNode *N,
24972                                                   SelectionDAG &DAG) {
24973   EVT VT = N->getValueType(0);
24974 
24975   // Ensure we're concatenating UNDEF and CONCAT_VECTORS nodes of similar types.
24976   EVT SubVT;
24977   SDValue FirstConcat;
24978   for (const SDValue &Op : N->ops()) {
24979     if (Op.isUndef())
24980       continue;
24981     if (Op.getOpcode() != ISD::CONCAT_VECTORS)
24982       return SDValue();
24983     if (!FirstConcat) {
24984       SubVT = Op.getOperand(0).getValueType();
24985       if (!DAG.getTargetLoweringInfo().isTypeLegal(SubVT))
24986         return SDValue();
24987       FirstConcat = Op;
24988       continue;
24989     }
24990     if (SubVT != Op.getOperand(0).getValueType())
24991       return SDValue();
24992   }
24993   assert(FirstConcat && "Concat of all-undefs found");
24994 
24995   SmallVector<SDValue> ConcatOps;
24996   for (const SDValue &Op : N->ops()) {
24997     if (Op.isUndef()) {
24998       ConcatOps.append(FirstConcat->getNumOperands(), DAG.getUNDEF(SubVT));
24999       continue;
25000     }
25001     ConcatOps.append(Op->op_begin(), Op->op_end());
25002   }
25003   return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, ConcatOps);
25004 }
25005 
25006 // Check to see if this is a CONCAT_VECTORS of a bunch of EXTRACT_SUBVECTOR
25007 // operations. If so, and if the EXTRACT_SUBVECTOR vector inputs come from at
25008 // most two distinct vectors the same size as the result, attempt to turn this
25009 // into a legal shuffle.
combineConcatVectorOfExtracts(SDNode * N,SelectionDAG & DAG)25010 static SDValue combineConcatVectorOfExtracts(SDNode *N, SelectionDAG &DAG) {
25011   EVT VT = N->getValueType(0);
25012   EVT OpVT = N->getOperand(0).getValueType();
25013 
25014   // We currently can't generate an appropriate shuffle for a scalable vector.
25015   if (VT.isScalableVector())
25016     return SDValue();
25017 
25018   int NumElts = VT.getVectorNumElements();
25019   int NumOpElts = OpVT.getVectorNumElements();
25020 
25021   SDValue SV0 = DAG.getUNDEF(VT), SV1 = DAG.getUNDEF(VT);
25022   SmallVector<int, 8> Mask;
25023 
25024   for (SDValue Op : N->ops()) {
25025     Op = peekThroughBitcasts(Op);
25026 
25027     // UNDEF nodes convert to UNDEF shuffle mask values.
25028     if (Op.isUndef()) {
25029       Mask.append((unsigned)NumOpElts, -1);
25030       continue;
25031     }
25032 
25033     if (Op.getOpcode() != ISD::EXTRACT_SUBVECTOR)
25034       return SDValue();
25035 
25036     // What vector are we extracting the subvector from and at what index?
25037     SDValue ExtVec = Op.getOperand(0);
25038     int ExtIdx = Op.getConstantOperandVal(1);
25039 
25040     // We want the EVT of the original extraction to correctly scale the
25041     // extraction index.
25042     EVT ExtVT = ExtVec.getValueType();
25043     ExtVec = peekThroughBitcasts(ExtVec);
25044 
25045     // UNDEF nodes convert to UNDEF shuffle mask values.
25046     if (ExtVec.isUndef()) {
25047       Mask.append((unsigned)NumOpElts, -1);
25048       continue;
25049     }
25050 
25051     // Ensure that we are extracting a subvector from a vector the same
25052     // size as the result.
25053     if (ExtVT.getSizeInBits() != VT.getSizeInBits())
25054       return SDValue();
25055 
25056     // Scale the subvector index to account for any bitcast.
25057     int NumExtElts = ExtVT.getVectorNumElements();
25058     if (0 == (NumExtElts % NumElts))
25059       ExtIdx /= (NumExtElts / NumElts);
25060     else if (0 == (NumElts % NumExtElts))
25061       ExtIdx *= (NumElts / NumExtElts);
25062     else
25063       return SDValue();
25064 
25065     // At most we can reference 2 inputs in the final shuffle.
25066     if (SV0.isUndef() || SV0 == ExtVec) {
25067       SV0 = ExtVec;
25068       for (int i = 0; i != NumOpElts; ++i)
25069         Mask.push_back(i + ExtIdx);
25070     } else if (SV1.isUndef() || SV1 == ExtVec) {
25071       SV1 = ExtVec;
25072       for (int i = 0; i != NumOpElts; ++i)
25073         Mask.push_back(i + ExtIdx + NumElts);
25074     } else {
25075       return SDValue();
25076     }
25077   }
25078 
25079   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
25080   return TLI.buildLegalVectorShuffle(VT, SDLoc(N), DAG.getBitcast(VT, SV0),
25081                                      DAG.getBitcast(VT, SV1), Mask, DAG);
25082 }
25083 
combineConcatVectorOfCasts(SDNode * N,SelectionDAG & DAG)25084 static SDValue combineConcatVectorOfCasts(SDNode *N, SelectionDAG &DAG) {
25085   unsigned CastOpcode = N->getOperand(0).getOpcode();
25086   switch (CastOpcode) {
25087   case ISD::SINT_TO_FP:
25088   case ISD::UINT_TO_FP:
25089   case ISD::FP_TO_SINT:
25090   case ISD::FP_TO_UINT:
25091     // TODO: Allow more opcodes?
25092     //  case ISD::BITCAST:
25093     //  case ISD::TRUNCATE:
25094     //  case ISD::ZERO_EXTEND:
25095     //  case ISD::SIGN_EXTEND:
25096     //  case ISD::FP_EXTEND:
25097     break;
25098   default:
25099     return SDValue();
25100   }
25101 
25102   EVT SrcVT = N->getOperand(0).getOperand(0).getValueType();
25103   if (!SrcVT.isVector())
25104     return SDValue();
25105 
25106   // All operands of the concat must be the same kind of cast from the same
25107   // source type.
25108   SmallVector<SDValue, 4> SrcOps;
25109   for (SDValue Op : N->ops()) {
25110     if (Op.getOpcode() != CastOpcode || !Op.hasOneUse() ||
25111         Op.getOperand(0).getValueType() != SrcVT)
25112       return SDValue();
25113     SrcOps.push_back(Op.getOperand(0));
25114   }
25115 
25116   // The wider cast must be supported by the target. This is unusual because
25117   // the operation support type parameter depends on the opcode. In addition,
25118   // check the other type in the cast to make sure this is really legal.
25119   EVT VT = N->getValueType(0);
25120   EVT SrcEltVT = SrcVT.getVectorElementType();
25121   ElementCount NumElts = SrcVT.getVectorElementCount() * N->getNumOperands();
25122   EVT ConcatSrcVT = EVT::getVectorVT(*DAG.getContext(), SrcEltVT, NumElts);
25123   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
25124   switch (CastOpcode) {
25125   case ISD::SINT_TO_FP:
25126   case ISD::UINT_TO_FP:
25127     if (!TLI.isOperationLegalOrCustom(CastOpcode, ConcatSrcVT) ||
25128         !TLI.isTypeLegal(VT))
25129       return SDValue();
25130     break;
25131   case ISD::FP_TO_SINT:
25132   case ISD::FP_TO_UINT:
25133     if (!TLI.isOperationLegalOrCustom(CastOpcode, VT) ||
25134         !TLI.isTypeLegal(ConcatSrcVT))
25135       return SDValue();
25136     break;
25137   default:
25138     llvm_unreachable("Unexpected cast opcode");
25139   }
25140 
25141   // concat (cast X), (cast Y)... -> cast (concat X, Y...)
25142   SDLoc DL(N);
25143   SDValue NewConcat = DAG.getNode(ISD::CONCAT_VECTORS, DL, ConcatSrcVT, SrcOps);
25144   return DAG.getNode(CastOpcode, DL, VT, NewConcat);
25145 }
25146 
25147 // See if this is a simple CONCAT_VECTORS with no UNDEF operands, and if one of
25148 // the operands is a SHUFFLE_VECTOR, and all other operands are also operands
25149 // to that SHUFFLE_VECTOR, create wider SHUFFLE_VECTOR.
combineConcatVectorOfShuffleAndItsOperands(SDNode * N,SelectionDAG & DAG,const TargetLowering & TLI,bool LegalTypes,bool LegalOperations)25150 static SDValue combineConcatVectorOfShuffleAndItsOperands(
25151     SDNode *N, SelectionDAG &DAG, const TargetLowering &TLI, bool LegalTypes,
25152     bool LegalOperations) {
25153   EVT VT = N->getValueType(0);
25154   EVT OpVT = N->getOperand(0).getValueType();
25155   if (VT.isScalableVector())
25156     return SDValue();
25157 
25158   // For now, only allow simple 2-operand concatenations.
25159   if (N->getNumOperands() != 2)
25160     return SDValue();
25161 
25162   // Don't create illegal types/shuffles when not allowed to.
25163   if ((LegalTypes && !TLI.isTypeLegal(VT)) ||
25164       (LegalOperations &&
25165        !TLI.isOperationLegalOrCustom(ISD::VECTOR_SHUFFLE, VT)))
25166     return SDValue();
25167 
25168   // Analyze all of the operands of the CONCAT_VECTORS. Out of all of them,
25169   // we want to find one that is: (1) a SHUFFLE_VECTOR (2) only used by us,
25170   // and (3) all operands of CONCAT_VECTORS must be either that SHUFFLE_VECTOR,
25171   // or one of the operands of that SHUFFLE_VECTOR (but not UNDEF!).
25172   // (4) and for now, the SHUFFLE_VECTOR must be unary.
25173   ShuffleVectorSDNode *SVN = nullptr;
25174   for (SDValue Op : N->ops()) {
25175     if (auto *CurSVN = dyn_cast<ShuffleVectorSDNode>(Op);
25176         CurSVN && CurSVN->getOperand(1).isUndef() && N->isOnlyUserOf(CurSVN) &&
25177         all_of(N->ops(), [CurSVN](SDValue Op) {
25178           // FIXME: can we allow UNDEF operands?
25179           return !Op.isUndef() &&
25180                  (Op.getNode() == CurSVN || is_contained(CurSVN->ops(), Op));
25181         })) {
25182       SVN = CurSVN;
25183       break;
25184     }
25185   }
25186   if (!SVN)
25187     return SDValue();
25188 
25189   // We are going to pad the shuffle operands, so any indice, that was picking
25190   // from the second operand, must be adjusted.
25191   SmallVector<int, 16> AdjustedMask(SVN->getMask());
25192   assert(SVN->getOperand(1).isUndef() && "Expected unary shuffle!");
25193 
25194   // Identity masks for the operands of the (padded) shuffle.
25195   SmallVector<int, 32> IdentityMask(2 * OpVT.getVectorNumElements());
25196   MutableArrayRef<int> FirstShufOpIdentityMask =
25197       MutableArrayRef<int>(IdentityMask)
25198           .take_front(OpVT.getVectorNumElements());
25199   MutableArrayRef<int> SecondShufOpIdentityMask =
25200       MutableArrayRef<int>(IdentityMask).take_back(OpVT.getVectorNumElements());
25201   std::iota(FirstShufOpIdentityMask.begin(), FirstShufOpIdentityMask.end(), 0);
25202   std::iota(SecondShufOpIdentityMask.begin(), SecondShufOpIdentityMask.end(),
25203             VT.getVectorNumElements());
25204 
25205   // New combined shuffle mask.
25206   SmallVector<int, 32> Mask;
25207   Mask.reserve(VT.getVectorNumElements());
25208   for (SDValue Op : N->ops()) {
25209     assert(!Op.isUndef() && "Not expecting to concatenate UNDEF.");
25210     if (Op.getNode() == SVN) {
25211       append_range(Mask, AdjustedMask);
25212       continue;
25213     }
25214     if (Op == SVN->getOperand(0)) {
25215       append_range(Mask, FirstShufOpIdentityMask);
25216       continue;
25217     }
25218     if (Op == SVN->getOperand(1)) {
25219       append_range(Mask, SecondShufOpIdentityMask);
25220       continue;
25221     }
25222     llvm_unreachable("Unexpected operand!");
25223   }
25224 
25225   // Don't create illegal shuffle masks.
25226   if (!TLI.isShuffleMaskLegal(Mask, VT))
25227     return SDValue();
25228 
25229   // Pad the shuffle operands with UNDEF.
25230   SDLoc dl(N);
25231   std::array<SDValue, 2> ShufOps;
25232   for (auto I : zip(SVN->ops(), ShufOps)) {
25233     SDValue ShufOp = std::get<0>(I);
25234     SDValue &NewShufOp = std::get<1>(I);
25235     if (ShufOp.isUndef())
25236       NewShufOp = DAG.getUNDEF(VT);
25237     else {
25238       SmallVector<SDValue, 2> ShufOpParts(N->getNumOperands(),
25239                                           DAG.getUNDEF(OpVT));
25240       ShufOpParts[0] = ShufOp;
25241       NewShufOp = DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, ShufOpParts);
25242     }
25243   }
25244   // Finally, create the new wide shuffle.
25245   return DAG.getVectorShuffle(VT, dl, ShufOps[0], ShufOps[1], Mask);
25246 }
25247 
/// Combine a CONCAT_VECTORS node. The folds below are tried in order and the
/// first one that produces a value wins:
///  - a single-operand concat is replaced by that operand;
///  - a concat whose operands are all undef becomes UNDEF;
///  - concat(X, undef, ..., undef) is re-expressed either as a wider concat
///    (when X is itself a one-use concat) or as a bitcast SCALAR_TO_VECTOR
///    (when X is, possibly through bitcasts, a scalar);
///  - a concat of only BUILD_VECTOR/undef operands becomes one BUILD_VECTOR;
///  - several helper combines (bitcast scalars, and - only up to
///    AfterLegalizeVectorOps with a legal result type - nested concats and
///    extracts, then casts and shuffle-plus-operands);
///  - a concat that reassembles one vector from its in-order
///    EXTRACT_SUBVECTOR pieces is replaced by that vector.
/// Returns an empty SDValue when no fold applies. Note that some bail-outs in
/// the "all but the first operand undef" section return early and therefore
/// also skip every later fold for this node.
SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) {
  // If we only have one input vector, we don't need to do any concatenation.
  if (N->getNumOperands() == 1)
    return N->getOperand(0);

  // Check if all of the operands are undefs.
  EVT VT = N->getValueType(0);
  if (ISD::allOperandsUndef(N))
    return DAG.getUNDEF(VT);

  // Optimize concat_vectors where all but the first of the vectors are undef.
  if (all_of(drop_begin(N->ops()),
             [](const SDValue &Op) { return Op.isUndef(); })) {
    SDValue In = N->getOperand(0);
    assert(In.getValueType().isVector() && "Must concat vectors");

    // If the input is a concat_vectors, just make a larger concat by padding
    // with smaller undefs.
    //
    // Legalizing in AArch64TargetLowering::LowerCONCAT_VECTORS() and combining
    // here could cause an infinite loop. That legalizing happens when LegalDAG
    // is true and input of AArch64TargetLowering::LowerCONCAT_VECTORS() is
    // scalable.
    if (In.getOpcode() == ISD::CONCAT_VECTORS && In.hasOneUse() &&
        !(LegalDAG && In.getValueType().isScalableVector())) {
      // Flatten: keep In's operands and pad with undefs of the same (small)
      // type until the operand count covers the whole result type.
      unsigned NumOps = N->getNumOperands() * In.getNumOperands();
      SmallVector<SDValue, 4> Ops(In->ops());
      Ops.resize(NumOps, DAG.getUNDEF(Ops[0].getValueType()));
      return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops);
    }

    SDValue Scalar = peekThroughOneUseBitcasts(In);

    // concat_vectors(scalar_to_vector(scalar), undef) ->
    //     scalar_to_vector(scalar)
    // Only look through the scalar_to_vector when its operand already has the
    // element type (i.e. no implicit conversion is hidden in the node).
    if (!LegalOperations && Scalar.getOpcode() == ISD::SCALAR_TO_VECTOR &&
         Scalar.hasOneUse()) {
      EVT SVT = Scalar.getValueType().getVectorElementType();
      if (SVT == Scalar.getOperand(0).getValueType())
        Scalar = Scalar.getOperand(0);
    }

    // concat_vectors(scalar, undef) -> scalar_to_vector(scalar)
    if (!Scalar.getValueType().isVector() && In.hasOneUse()) {
      // If the bitcast type isn't legal, it might be a trunc of a legal type;
      // look through the trunc so we can still do the transform:
      //   concat_vectors(trunc(scalar), undef) -> scalar_to_vector(scalar)
      if (Scalar->getOpcode() == ISD::TRUNCATE &&
          !TLI.isTypeLegal(Scalar.getValueType()) &&
          TLI.isTypeLegal(Scalar->getOperand(0).getValueType()))
        Scalar = Scalar->getOperand(0);

      EVT SclTy = Scalar.getValueType();

      // From here on, every bail-out returns from the whole combine, so none
      // of the later folds are attempted for this node.
      if (!SclTy.isFloatingPoint() && !SclTy.isInteger())
        return SDValue();

      // Bail out if the vector size is not a multiple of the scalar size.
      if (VT.getSizeInBits() % SclTy.getSizeInBits())
        return SDValue();

      // Number of scalar-sized lanes the result type can be viewed as; a
      // single lane would not be a vector.
      unsigned VNTNumElms = VT.getSizeInBits() / SclTy.getSizeInBits();
      if (VNTNumElms < 2)
        return SDValue();

      EVT NVT = EVT::getVectorVT(*DAG.getContext(), SclTy, VNTNumElms);
      if (!TLI.isTypeLegal(NVT) || !TLI.isTypeLegal(Scalar.getValueType()))
        return SDValue();

      // Build the vector in the scalar's lane type, then reinterpret it as VT.
      SDValue Res = DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), NVT, Scalar);
      return DAG.getBitcast(VT, Res);
    }
  }

  // Fold any combination of BUILD_VECTOR or UNDEF nodes into one BUILD_VECTOR.
  // We have already tested above for an UNDEF only concatenation.
  // fold (concat_vectors (BUILD_VECTOR A, B, ...), (BUILD_VECTOR C, D, ...))
  // -> (BUILD_VECTOR A, B, ..., C, D, ...)
  auto IsBuildVectorOrUndef = [](const SDValue &Op) {
    return Op.isUndef() || ISD::BUILD_VECTOR == Op.getOpcode();
  };
  if (llvm::all_of(N->ops(), IsBuildVectorOrUndef)) {
    SmallVector<SDValue, 8> Opnds;
    EVT SVT = VT.getScalarType();

    EVT MinVT = SVT;
    if (!SVT.isFloatingPoint()) {
      // If the BUILD_VECTORs are built from integers, their operands may have
      // different types. Find the smallest operand type and truncate all
      // operands to it so the combined BUILD_VECTOR is uniform.
      bool FoundMinVT = false;
      for (const SDValue &Op : N->ops())
        if (ISD::BUILD_VECTOR == Op.getOpcode()) {
          EVT OpSVT = Op.getOperand(0).getValueType();
          MinVT = (!FoundMinVT || OpSVT.bitsLE(MinVT)) ? OpSVT : MinVT;
          FoundMinVT = true;
        }
      // At least one operand is a BUILD_VECTOR here, because the all-undef
      // case was already folded above.
      assert(FoundMinVT && "Concat vector type mismatch");
    }

    for (const SDValue &Op : N->ops()) {
      EVT OpVT = Op.getValueType();
      unsigned NumElts = OpVT.getVectorNumElements();

      // An undef operand contributes NumElts undef scalars.
      if (Op.isUndef())
        Opnds.append(NumElts, DAG.getUNDEF(MinVT));

      if (ISD::BUILD_VECTOR == Op.getOpcode()) {
        if (SVT.isFloatingPoint()) {
          assert(SVT == OpVT.getScalarType() && "Concat vector type mismatch");
          Opnds.append(Op->op_begin(), Op->op_begin() + NumElts);
        } else {
          for (unsigned i = 0; i != NumElts; ++i)
            Opnds.push_back(
                DAG.getNode(ISD::TRUNCATE, SDLoc(N), MinVT, Op.getOperand(i)));
        }
      }
    }

    assert(VT.getVectorNumElements() == Opnds.size() &&
           "Concat vector type mismatch");
    return DAG.getBuildVector(VT, SDLoc(N), Opnds);
  }

  // Fold CONCAT_VECTORS of only bitcast scalars (or undef) to BUILD_VECTOR.
  // FIXME: Add support for concat_vectors(bitcast(vec0),bitcast(vec1),...).
  if (SDValue V = combineConcatVectorOfScalars(N, DAG))
    return V;

  // The shuffle-forming folds are limited to the stages up to and including
  // vector-op legalization, and to result types the target already supports.
  if (Level <= AfterLegalizeVectorOps && TLI.isTypeLegal(VT)) {
    // Fold CONCAT_VECTORS of CONCAT_VECTORS (or undef) to VECTOR_SHUFFLE.
    if (SDValue V = combineConcatVectorOfConcatVectors(N, DAG))
      return V;

    // Fold CONCAT_VECTORS of EXTRACT_SUBVECTOR (or undef) to VECTOR_SHUFFLE.
    if (SDValue V = combineConcatVectorOfExtracts(N, DAG))
      return V;
  }

  if (SDValue V = combineConcatVectorOfCasts(N, DAG))
    return V;

  if (SDValue V = combineConcatVectorOfShuffleAndItsOperands(
          N, DAG, TLI, LegalTypes, LegalOperations))
    return V;

  // Type legalization of vectors and DAG canonicalization of SHUFFLE_VECTOR
  // nodes often generate nop CONCAT_VECTOR nodes. Scan the CONCAT_VECTOR
  // operands and look for CONCAT operations that place the incoming vectors
  // at the exact same location.
  //
  // For scalable vectors, EXTRACT_SUBVECTOR indexes are implicitly scaled.
  SDValue SingleSource = SDValue();
  unsigned PartNumElem =
      N->getOperand(0).getValueType().getVectorMinNumElements();

  for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i) {
    SDValue Op = N->getOperand(i);

    // Undef pieces are compatible with any source.
    if (Op.isUndef())
      continue;

    // Check if this is the identity extract:
    if (Op.getOpcode() != ISD::EXTRACT_SUBVECTOR)
      return SDValue();

    // Find the single incoming vector for the extract_subvector.
    if (SingleSource.getNode()) {
      if (Op.getOperand(0) != SingleSource)
        return SDValue();
    } else {
      SingleSource = Op.getOperand(0);

      // Check the source type is the same as the type of the result.
      // If not, this concat may extend the vector, so we can not
      // optimize it away.
      if (SingleSource.getValueType() != N->getValueType(0))
        return SDValue();
    }

    // Check that we are reading from the identity index: operand i must be
    // the i-th PartNumElem-sized piece of the source.
    unsigned IdentityIndex = i * PartNumElem;
    if (Op.getConstantOperandAPInt(1) != IdentityIndex)
      return SDValue();
  }

  // Every non-undef operand was the matching in-place piece of SingleSource.
  if (SingleSource.getNode())
    return SingleSource;

  return SDValue();
}
25438 
25439 // Helper that peeks through INSERT_SUBVECTOR/CONCAT_VECTORS to find
25440 // if the subvector can be sourced for free.
getSubVectorSrc(SDValue V,unsigned Index,EVT SubVT)25441 static SDValue getSubVectorSrc(SDValue V, unsigned Index, EVT SubVT) {
25442   if (V.getOpcode() == ISD::INSERT_SUBVECTOR &&
25443       V.getOperand(1).getValueType() == SubVT &&
25444       V.getConstantOperandAPInt(2) == Index) {
25445     return V.getOperand(1);
25446   }
25447   if (V.getOpcode() == ISD::CONCAT_VECTORS &&
25448       V.getOperand(0).getValueType() == SubVT &&
25449       (Index % SubVT.getVectorMinNumElements()) == 0) {
25450     uint64_t SubIdx = Index / SubVT.getVectorMinNumElements();
25451     return V.getOperand(SubIdx);
25452   }
25453   return SDValue();
25454 }
25455 
narrowInsertExtractVectorBinOp(EVT SubVT,SDValue BinOp,unsigned Index,const SDLoc & DL,SelectionDAG & DAG,bool LegalOperations)25456 static SDValue narrowInsertExtractVectorBinOp(EVT SubVT, SDValue BinOp,
25457                                               unsigned Index, const SDLoc &DL,
25458                                               SelectionDAG &DAG,
25459                                               bool LegalOperations) {
25460   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
25461   unsigned BinOpcode = BinOp.getOpcode();
25462   if (!TLI.isBinOp(BinOpcode) || BinOp->getNumValues() != 1)
25463     return SDValue();
25464 
25465   EVT VecVT = BinOp.getValueType();
25466   SDValue Bop0 = BinOp.getOperand(0), Bop1 = BinOp.getOperand(1);
25467   if (VecVT != Bop0.getValueType() || VecVT != Bop1.getValueType())
25468     return SDValue();
25469   if (!TLI.isOperationLegalOrCustom(BinOpcode, SubVT, LegalOperations))
25470     return SDValue();
25471 
25472   SDValue Sub0 = getSubVectorSrc(Bop0, Index, SubVT);
25473   SDValue Sub1 = getSubVectorSrc(Bop1, Index, SubVT);
25474 
25475   // TODO: We could handle the case where only 1 operand is being inserted by
25476   //       creating an extract of the other operand, but that requires checking
25477   //       number of uses and/or costs.
25478   if (!Sub0 || !Sub1)
25479     return SDValue();
25480 
25481   // We are inserting both operands of the wide binop only to extract back
25482   // to the narrow vector size. Eliminate all of the insert/extract:
25483   // ext (binop (ins ?, X, Index), (ins ?, Y, Index)), Index --> binop X, Y
25484   return DAG.getNode(BinOpcode, DL, SubVT, Sub0, Sub1, BinOp->getFlags());
25485 }
25486 
25487 /// If we are extracting a subvector produced by a wide binary operator try
25488 /// to use a narrow binary operator and/or avoid concatenation and extraction.
narrowExtractedVectorBinOp(EVT VT,SDValue Src,unsigned Index,const SDLoc & DL,SelectionDAG & DAG,bool LegalOperations)25489 static SDValue narrowExtractedVectorBinOp(EVT VT, SDValue Src, unsigned Index,
25490                                           const SDLoc &DL, SelectionDAG &DAG,
25491                                           bool LegalOperations) {
25492   // TODO: Refactor with the caller (visitEXTRACT_SUBVECTOR), so we can share
25493   // some of these bailouts with other transforms.
25494 
25495   if (SDValue V = narrowInsertExtractVectorBinOp(VT, Src, Index, DL, DAG,
25496                                                  LegalOperations))
25497     return V;
25498 
25499   // We are looking for an optionally bitcasted wide vector binary operator
25500   // feeding an extract subvector.
25501   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
25502   SDValue BinOp = peekThroughBitcasts(Src);
25503   unsigned BOpcode = BinOp.getOpcode();
25504   if (!TLI.isBinOp(BOpcode) || BinOp->getNumValues() != 1)
25505     return SDValue();
25506 
25507   // Exclude the fake form of fneg (fsub -0.0, x) because that is likely to be
25508   // reduced to the unary fneg when it is visited, and we probably want to deal
25509   // with fneg in a target-specific way.
25510   if (BOpcode == ISD::FSUB) {
25511     auto *C = isConstOrConstSplatFP(BinOp.getOperand(0), /*AllowUndefs*/ true);
25512     if (C && C->getValueAPF().isNegZero())
25513       return SDValue();
25514   }
25515 
25516   // The binop must be a vector type, so we can extract some fraction of it.
25517   EVT WideBVT = BinOp.getValueType();
25518   // The optimisations below currently assume we are dealing with fixed length
25519   // vectors. It is possible to add support for scalable vectors, but at the
25520   // moment we've done no analysis to prove whether they are profitable or not.
25521   if (!WideBVT.isFixedLengthVector())
25522     return SDValue();
25523 
25524   assert((Index % VT.getVectorNumElements()) == 0 &&
25525          "Extract index is not a multiple of the vector length.");
25526 
25527   // Bail out if this is not a proper multiple width extraction.
25528   unsigned WideWidth = WideBVT.getSizeInBits();
25529   unsigned NarrowWidth = VT.getSizeInBits();
25530   if (WideWidth % NarrowWidth != 0)
25531     return SDValue();
25532 
25533   // Bail out if we are extracting a fraction of a single operation. This can
25534   // occur because we potentially looked through a bitcast of the binop.
25535   unsigned NarrowingRatio = WideWidth / NarrowWidth;
25536   unsigned WideNumElts = WideBVT.getVectorNumElements();
25537   if (WideNumElts % NarrowingRatio != 0)
25538     return SDValue();
25539 
25540   // Bail out if the target does not support a narrower version of the binop.
25541   EVT NarrowBVT = EVT::getVectorVT(*DAG.getContext(), WideBVT.getScalarType(),
25542                                    WideNumElts / NarrowingRatio);
25543   if (!TLI.isOperationLegalOrCustomOrPromote(BOpcode, NarrowBVT,
25544                                              LegalOperations))
25545     return SDValue();
25546 
25547   // If extraction is cheap, we don't need to look at the binop operands
25548   // for concat ops. The narrow binop alone makes this transform profitable.
25549   // We can't just reuse the original extract index operand because we may have
25550   // bitcasted.
25551   unsigned ConcatOpNum = Index / VT.getVectorNumElements();
25552   unsigned ExtBOIdx = ConcatOpNum * NarrowBVT.getVectorNumElements();
25553   if (TLI.isExtractSubvectorCheap(NarrowBVT, WideBVT, ExtBOIdx) &&
25554       BinOp.hasOneUse() && Src->hasOneUse()) {
25555     // extract (binop B0, B1), N --> binop (extract B0, N), (extract B1, N)
25556     SDValue NewExtIndex = DAG.getVectorIdxConstant(ExtBOIdx, DL);
25557     SDValue X = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
25558                             BinOp.getOperand(0), NewExtIndex);
25559     SDValue Y = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
25560                             BinOp.getOperand(1), NewExtIndex);
25561     SDValue NarrowBinOp =
25562         DAG.getNode(BOpcode, DL, NarrowBVT, X, Y, BinOp->getFlags());
25563     return DAG.getBitcast(VT, NarrowBinOp);
25564   }
25565 
25566   // Only handle the case where we are doubling and then halving. A larger ratio
25567   // may require more than two narrow binops to replace the wide binop.
25568   if (NarrowingRatio != 2)
25569     return SDValue();
25570 
25571   // TODO: The motivating case for this transform is an x86 AVX1 target. That
25572   // target has temptingly almost legal versions of bitwise logic ops in 256-bit
25573   // flavors, but no other 256-bit integer support. This could be extended to
25574   // handle any binop, but that may require fixing/adding other folds to avoid
25575   // codegen regressions.
25576   if (BOpcode != ISD::AND && BOpcode != ISD::OR && BOpcode != ISD::XOR)
25577     return SDValue();
25578 
25579   // We need at least one concatenation operation of a binop operand to make
25580   // this transform worthwhile. The concat must double the input vector sizes.
25581   auto GetSubVector = [ConcatOpNum](SDValue V) -> SDValue {
25582     if (V.getOpcode() == ISD::CONCAT_VECTORS && V.getNumOperands() == 2)
25583       return V.getOperand(ConcatOpNum);
25584     return SDValue();
25585   };
25586   SDValue SubVecL = GetSubVector(peekThroughBitcasts(BinOp.getOperand(0)));
25587   SDValue SubVecR = GetSubVector(peekThroughBitcasts(BinOp.getOperand(1)));
25588 
25589   if (SubVecL || SubVecR) {
25590     // If a binop operand was not the result of a concat, we must extract a
25591     // half-sized operand for our new narrow binop:
25592     // extract (binop (concat X1, X2), (concat Y1, Y2)), N --> binop XN, YN
25593     // extract (binop (concat X1, X2), Y), N --> binop XN, (extract Y, IndexC)
25594     // extract (binop X, (concat Y1, Y2)), N --> binop (extract X, IndexC), YN
25595     SDValue IndexC = DAG.getVectorIdxConstant(ExtBOIdx, DL);
25596     SDValue X = SubVecL ? DAG.getBitcast(NarrowBVT, SubVecL)
25597                         : DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
25598                                       BinOp.getOperand(0), IndexC);
25599 
25600     SDValue Y = SubVecR ? DAG.getBitcast(NarrowBVT, SubVecR)
25601                         : DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
25602                                       BinOp.getOperand(1), IndexC);
25603 
25604     SDValue NarrowBinOp = DAG.getNode(BOpcode, DL, NarrowBVT, X, Y);
25605     return DAG.getBitcast(VT, NarrowBinOp);
25606   }
25607 
25608   return SDValue();
25609 }
25610 
25611 /// If we are extracting a subvector from a wide vector load, convert to a
25612 /// narrow load to eliminate the extraction:
25613 /// (extract_subvector (load wide vector)) --> (load narrow vector)
narrowExtractedVectorLoad(EVT VT,SDValue Src,unsigned Index,const SDLoc & DL,SelectionDAG & DAG)25614 static SDValue narrowExtractedVectorLoad(EVT VT, SDValue Src, unsigned Index,
25615                                          const SDLoc &DL, SelectionDAG &DAG) {
25616   // TODO: Add support for big-endian. The offset calculation must be adjusted.
25617   if (DAG.getDataLayout().isBigEndian())
25618     return SDValue();
25619 
25620   auto *Ld = dyn_cast<LoadSDNode>(Src);
25621   if (!Ld || !ISD::isNormalLoad(Ld) || !Ld->isSimple())
25622     return SDValue();
25623 
25624   // We can only create byte sized loads.
25625   if (!VT.isByteSized())
25626     return SDValue();
25627 
25628   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
25629   if (!TLI.isOperationLegalOrCustomOrPromote(ISD::LOAD, VT))
25630     return SDValue();
25631 
25632   unsigned NumElts = VT.getVectorMinNumElements();
25633   // A fixed length vector being extracted from a scalable vector
25634   // may not be any *smaller* than the scalable one.
25635   if (Index == 0 && NumElts >= Ld->getValueType(0).getVectorMinNumElements())
25636     return SDValue();
25637 
25638   // The definition of EXTRACT_SUBVECTOR states that the index must be a
25639   // multiple of the minimum number of elements in the result type.
25640   assert(Index % NumElts == 0 && "The extract subvector index is not a "
25641                                  "multiple of the result's element count");
25642 
25643   // It's fine to use TypeSize here as we know the offset will not be negative.
25644   TypeSize Offset = VT.getStoreSize() * (Index / NumElts);
25645   std::optional<unsigned> ByteOffset;
25646   if (Offset.isFixed())
25647     ByteOffset = Offset.getFixedValue();
25648 
25649   if (!TLI.shouldReduceLoadWidth(Ld, Ld->getExtensionType(), VT, ByteOffset))
25650     return SDValue();
25651 
25652   // The narrow load will be offset from the base address of the old load if
25653   // we are extracting from something besides index 0 (little-endian).
25654   // TODO: Use "BaseIndexOffset" to make this more effective.
25655   SDValue NewAddr = DAG.getMemBasePlusOffset(Ld->getBasePtr(), Offset, DL);
25656 
25657   MachineFunction &MF = DAG.getMachineFunction();
25658   MachineMemOperand *MMO;
25659   if (Offset.isScalable()) {
25660     MachinePointerInfo MPI =
25661         MachinePointerInfo(Ld->getPointerInfo().getAddrSpace());
25662     MMO = MF.getMachineMemOperand(Ld->getMemOperand(), MPI, VT.getStoreSize());
25663   } else
25664     MMO = MF.getMachineMemOperand(Ld->getMemOperand(), Offset.getFixedValue(),
25665                                   VT.getStoreSize());
25666 
25667   SDValue NewLd = DAG.getLoad(VT, DL, Ld->getChain(), NewAddr, MMO);
25668   DAG.makeEquivalentMemoryOrdering(Ld, NewLd);
25669   return NewLd;
25670 }
25671 
25672 /// Given  EXTRACT_SUBVECTOR(VECTOR_SHUFFLE(Op0, Op1, Mask)),
25673 /// try to produce  VECTOR_SHUFFLE(EXTRACT_SUBVECTOR(Op?, ?),
25674 ///                                EXTRACT_SUBVECTOR(Op?, ?),
25675 ///                                Mask'))
25676 /// iff it is legal and profitable to do so. Notably, the trimmed mask
25677 /// (containing only the elements that are extracted)
25678 /// must reference at most two subvectors.
foldExtractSubvectorFromShuffleVector(EVT NarrowVT,SDValue Src,unsigned Index,const SDLoc & DL,SelectionDAG & DAG,bool LegalOperations)25679 static SDValue foldExtractSubvectorFromShuffleVector(EVT NarrowVT, SDValue Src,
25680                                                      unsigned Index,
25681                                                      const SDLoc &DL,
25682                                                      SelectionDAG &DAG,
25683                                                      bool LegalOperations) {
25684   // Only deal with non-scalable vectors.
25685   EVT WideVT = Src.getValueType();
25686   if (!NarrowVT.isFixedLengthVector() || !WideVT.isFixedLengthVector())
25687     return SDValue();
25688 
25689   // The operand must be a shufflevector.
25690   auto *WideShuffleVector = dyn_cast<ShuffleVectorSDNode>(Src);
25691   if (!WideShuffleVector)
25692     return SDValue();
25693 
25694   // The old shuffleneeds to go away.
25695   if (!WideShuffleVector->hasOneUse())
25696     return SDValue();
25697 
25698   // And the narrow shufflevector that we'll form must be legal.
25699   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
25700   if (LegalOperations &&
25701       !TLI.isOperationLegalOrCustom(ISD::VECTOR_SHUFFLE, NarrowVT))
25702     return SDValue();
25703 
25704   int NumEltsExtracted = NarrowVT.getVectorNumElements();
25705   assert((Index % NumEltsExtracted) == 0 &&
25706          "Extract index is not a multiple of the output vector length.");
25707 
25708   int WideNumElts = WideVT.getVectorNumElements();
25709 
25710   SmallVector<int, 16> NewMask;
25711   NewMask.reserve(NumEltsExtracted);
25712   SmallSetVector<std::pair<SDValue /*Op*/, int /*SubvectorIndex*/>, 2>
25713       DemandedSubvectors;
25714 
25715   // Try to decode the wide mask into narrow mask from at most two subvectors.
25716   for (int M : WideShuffleVector->getMask().slice(Index, NumEltsExtracted)) {
25717     assert((M >= -1) && (M < (2 * WideNumElts)) &&
25718            "Out-of-bounds shuffle mask?");
25719 
25720     if (M < 0) {
25721       // Does not depend on operands, does not require adjustment.
25722       NewMask.emplace_back(M);
25723       continue;
25724     }
25725 
25726     // From which operand of the shuffle does this shuffle mask element pick?
25727     int WideShufOpIdx = M / WideNumElts;
25728     // Which element of that operand is picked?
25729     int OpEltIdx = M % WideNumElts;
25730 
25731     assert((OpEltIdx + WideShufOpIdx * WideNumElts) == M &&
25732            "Shuffle mask vector decomposition failure.");
25733 
25734     // And which NumEltsExtracted-sized subvector of that operand is that?
25735     int OpSubvecIdx = OpEltIdx / NumEltsExtracted;
25736     // And which element within that subvector of that operand is that?
25737     int OpEltIdxInSubvec = OpEltIdx % NumEltsExtracted;
25738 
25739     assert((OpEltIdxInSubvec + OpSubvecIdx * NumEltsExtracted) == OpEltIdx &&
25740            "Shuffle mask subvector decomposition failure.");
25741 
25742     assert((OpEltIdxInSubvec + OpSubvecIdx * NumEltsExtracted +
25743             WideShufOpIdx * WideNumElts) == M &&
25744            "Shuffle mask full decomposition failure.");
25745 
25746     SDValue Op = WideShuffleVector->getOperand(WideShufOpIdx);
25747 
25748     if (Op.isUndef()) {
25749       // Picking from an undef operand. Let's adjust mask instead.
25750       NewMask.emplace_back(-1);
25751       continue;
25752     }
25753 
25754     const std::pair<SDValue, int> DemandedSubvector =
25755         std::make_pair(Op, OpSubvecIdx);
25756 
25757     if (DemandedSubvectors.insert(DemandedSubvector)) {
25758       if (DemandedSubvectors.size() > 2)
25759         return SDValue(); // We can't handle more than two subvectors.
25760       // How many elements into the WideVT does this subvector start?
25761       int Index = NumEltsExtracted * OpSubvecIdx;
25762       // Bail out if the extraction isn't going to be cheap.
25763       if (!TLI.isExtractSubvectorCheap(NarrowVT, WideVT, Index))
25764         return SDValue();
25765     }
25766 
25767     // Ok, but from which operand of the new shuffle will this element pick?
25768     int NewOpIdx =
25769         getFirstIndexOf(DemandedSubvectors.getArrayRef(), DemandedSubvector);
25770     assert((NewOpIdx == 0 || NewOpIdx == 1) && "Unexpected operand index.");
25771 
25772     int AdjM = OpEltIdxInSubvec + NewOpIdx * NumEltsExtracted;
25773     NewMask.emplace_back(AdjM);
25774   }
25775   assert(NewMask.size() == (unsigned)NumEltsExtracted && "Produced bad mask.");
25776   assert(DemandedSubvectors.size() <= 2 &&
25777          "Should have ended up demanding at most two subvectors.");
25778 
25779   // Did we discover that the shuffle does not actually depend on operands?
25780   if (DemandedSubvectors.empty())
25781     return DAG.getUNDEF(NarrowVT);
25782 
25783   // Profitability check: only deal with extractions from the first subvector
25784   // unless the mask becomes an identity mask.
25785   if (!ShuffleVectorInst::isIdentityMask(NewMask, NewMask.size()) ||
25786       any_of(NewMask, [](int M) { return M < 0; }))
25787     for (auto &DemandedSubvector : DemandedSubvectors)
25788       if (DemandedSubvector.second != 0)
25789         return SDValue();
25790 
25791   // We still perform the exact same EXTRACT_SUBVECTOR,  just on different
25792   // operand[s]/index[es], so there is no point in checking for it's legality.
25793 
25794   // Do not turn a legal shuffle into an illegal one.
25795   if (TLI.isShuffleMaskLegal(WideShuffleVector->getMask(), WideVT) &&
25796       !TLI.isShuffleMaskLegal(NewMask, NarrowVT))
25797     return SDValue();
25798 
25799   SmallVector<SDValue, 2> NewOps;
25800   for (const std::pair<SDValue /*Op*/, int /*SubvectorIndex*/>
25801            &DemandedSubvector : DemandedSubvectors) {
25802     // How many elements into the WideVT does this subvector start?
25803     int Index = NumEltsExtracted * DemandedSubvector.second;
25804     SDValue IndexC = DAG.getVectorIdxConstant(Index, DL);
25805     NewOps.emplace_back(DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowVT,
25806                                     DemandedSubvector.first, IndexC));
25807   }
25808   assert((NewOps.size() == 1 || NewOps.size() == 2) &&
25809          "Should end up with either one or two ops");
25810 
25811   // If we ended up with only one operand, pad with an undef.
25812   if (NewOps.size() == 1)
25813     NewOps.emplace_back(DAG.getUNDEF(NarrowVT));
25814 
25815   return DAG.getVectorShuffle(NarrowVT, DL, NewOps[0], NewOps[1], NewMask);
25816 }
25817 
visitEXTRACT_SUBVECTOR(SDNode * N)25818 SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode *N) {
25819   EVT NVT = N->getValueType(0);
25820   SDValue V = N->getOperand(0);
25821   uint64_t ExtIdx = N->getConstantOperandVal(1);
25822   SDLoc DL(N);
25823 
25824   // Extract from UNDEF is UNDEF.
25825   if (V.isUndef())
25826     return DAG.getUNDEF(NVT);
25827 
25828   if (SDValue NarrowLoad = narrowExtractedVectorLoad(NVT, V, ExtIdx, DL, DAG))
25829     return NarrowLoad;
25830 
25831   // Combine an extract of an extract into a single extract_subvector.
25832   // ext (ext X, C), 0 --> ext X, C
25833   if (ExtIdx == 0 && V.getOpcode() == ISD::EXTRACT_SUBVECTOR && V.hasOneUse()) {
25834     if (TLI.isExtractSubvectorCheap(NVT, V.getOperand(0).getValueType(),
25835                                     V.getConstantOperandVal(1)) &&
25836         TLI.isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, NVT)) {
25837       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NVT, V.getOperand(0),
25838                          V.getOperand(1));
25839     }
25840   }
25841 
25842   // ty1 extract_vector(ty2 splat(V))) -> ty1 splat(V)
25843   if (V.getOpcode() == ISD::SPLAT_VECTOR)
25844     if (DAG.isConstantValueOfAnyType(V.getOperand(0)) || V.hasOneUse())
25845       if (!LegalOperations || TLI.isOperationLegal(ISD::SPLAT_VECTOR, NVT))
25846         return DAG.getSplatVector(NVT, DL, V.getOperand(0));
25847 
25848   // extract_subvector(insert_subvector(x,y,c1),c2)
25849   //  --> extract_subvector(y,c2-c1)
25850   // iff we're just extracting from the inserted subvector.
25851   if (V.getOpcode() == ISD::INSERT_SUBVECTOR) {
25852     SDValue InsSub = V.getOperand(1);
25853     EVT InsSubVT = InsSub.getValueType();
25854     unsigned NumInsElts = InsSubVT.getVectorMinNumElements();
25855     unsigned InsIdx = V.getConstantOperandVal(2);
25856     unsigned NumSubElts = NVT.getVectorMinNumElements();
25857     if (InsIdx <= ExtIdx && (ExtIdx + NumSubElts) <= (InsIdx + NumInsElts) &&
25858         TLI.isExtractSubvectorCheap(NVT, InsSubVT, ExtIdx - InsIdx) &&
25859         InsSubVT.isFixedLengthVector() && NVT.isFixedLengthVector() &&
25860         V.getValueType().isFixedLengthVector())
25861       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NVT, InsSub,
25862                          DAG.getVectorIdxConstant(ExtIdx - InsIdx, DL));
25863   }
25864 
25865   // Try to move vector bitcast after extract_subv by scaling extraction index:
25866   // extract_subv (bitcast X), Index --> bitcast (extract_subv X, Index')
25867   if (V.getOpcode() == ISD::BITCAST &&
25868       V.getOperand(0).getValueType().isVector() &&
25869       (!LegalOperations || TLI.isOperationLegal(ISD::BITCAST, NVT))) {
25870     SDValue SrcOp = V.getOperand(0);
25871     EVT SrcVT = SrcOp.getValueType();
25872     unsigned SrcNumElts = SrcVT.getVectorMinNumElements();
25873     unsigned DestNumElts = V.getValueType().getVectorMinNumElements();
25874     if ((SrcNumElts % DestNumElts) == 0) {
25875       unsigned SrcDestRatio = SrcNumElts / DestNumElts;
25876       ElementCount NewExtEC = NVT.getVectorElementCount() * SrcDestRatio;
25877       EVT NewExtVT =
25878           EVT::getVectorVT(*DAG.getContext(), SrcVT.getScalarType(), NewExtEC);
25879       if (TLI.isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, NewExtVT)) {
25880         SDValue NewIndex = DAG.getVectorIdxConstant(ExtIdx * SrcDestRatio, DL);
25881         SDValue NewExtract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NewExtVT,
25882                                          V.getOperand(0), NewIndex);
25883         return DAG.getBitcast(NVT, NewExtract);
25884       }
25885     }
25886     if ((DestNumElts % SrcNumElts) == 0) {
25887       unsigned DestSrcRatio = DestNumElts / SrcNumElts;
25888       if (NVT.getVectorElementCount().isKnownMultipleOf(DestSrcRatio)) {
25889         ElementCount NewExtEC =
25890             NVT.getVectorElementCount().divideCoefficientBy(DestSrcRatio);
25891         EVT ScalarVT = SrcVT.getScalarType();
25892         if ((ExtIdx % DestSrcRatio) == 0) {
25893           unsigned IndexValScaled = ExtIdx / DestSrcRatio;
25894           EVT NewExtVT =
25895               EVT::getVectorVT(*DAG.getContext(), ScalarVT, NewExtEC);
25896           if (TLI.isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, NewExtVT)) {
25897             SDValue NewIndex = DAG.getVectorIdxConstant(IndexValScaled, DL);
25898             SDValue NewExtract =
25899                 DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NewExtVT,
25900                             V.getOperand(0), NewIndex);
25901             return DAG.getBitcast(NVT, NewExtract);
25902           }
25903           if (NewExtEC.isScalar() &&
25904               TLI.isOperationLegalOrCustom(ISD::EXTRACT_VECTOR_ELT, ScalarVT)) {
25905             SDValue NewIndex = DAG.getVectorIdxConstant(IndexValScaled, DL);
25906             SDValue NewExtract =
25907                 DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT,
25908                             V.getOperand(0), NewIndex);
25909             return DAG.getBitcast(NVT, NewExtract);
25910           }
25911         }
25912       }
25913     }
25914   }
25915 
25916   if (V.getOpcode() == ISD::CONCAT_VECTORS) {
25917     unsigned ExtNumElts = NVT.getVectorMinNumElements();
25918     EVT ConcatSrcVT = V.getOperand(0).getValueType();
25919     assert(ConcatSrcVT.getVectorElementType() == NVT.getVectorElementType() &&
25920            "Concat and extract subvector do not change element type");
25921     assert((ExtIdx % ExtNumElts) == 0 &&
25922            "Extract index is not a multiple of the input vector length.");
25923 
25924     unsigned ConcatSrcNumElts = ConcatSrcVT.getVectorMinNumElements();
25925     unsigned ConcatOpIdx = ExtIdx / ConcatSrcNumElts;
25926 
25927     // If the concatenated source types match this extract, it's a direct
25928     // simplification:
25929     // extract_subvec (concat V1, V2, ...), i --> Vi
25930     if (NVT.getVectorElementCount() == ConcatSrcVT.getVectorElementCount())
25931       return V.getOperand(ConcatOpIdx);
25932 
25933     // If the concatenated source vectors are a multiple length of this extract,
25934     // then extract a fraction of one of those source vectors directly from a
25935     // concat operand. Example:
25936     //   v2i8 extract_subvec (v16i8 concat (v8i8 X), (v8i8 Y), 14 -->
25937     //   v2i8 extract_subvec v8i8 Y, 6
25938     if (NVT.isFixedLengthVector() && ConcatSrcVT.isFixedLengthVector() &&
25939         ConcatSrcNumElts % ExtNumElts == 0) {
25940       unsigned NewExtIdx = ExtIdx - ConcatOpIdx * ConcatSrcNumElts;
25941       assert(NewExtIdx + ExtNumElts <= ConcatSrcNumElts &&
25942              "Trying to extract from >1 concat operand?");
25943       assert(NewExtIdx % ExtNumElts == 0 &&
25944              "Extract index is not a multiple of the input vector length.");
25945       SDValue NewIndexC = DAG.getVectorIdxConstant(NewExtIdx, DL);
25946       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NVT,
25947                          V.getOperand(ConcatOpIdx), NewIndexC);
25948     }
25949   }
25950 
25951   if (SDValue Shuffle = foldExtractSubvectorFromShuffleVector(
25952           NVT, V, ExtIdx, DL, DAG, LegalOperations))
25953     return Shuffle;
25954 
25955   if (SDValue NarrowBOp =
25956           narrowExtractedVectorBinOp(NVT, V, ExtIdx, DL, DAG, LegalOperations))
25957     return NarrowBOp;
25958 
25959   V = peekThroughBitcasts(V);
25960 
25961   // If the input is a build vector. Try to make a smaller build vector.
25962   if (V.getOpcode() == ISD::BUILD_VECTOR) {
25963     EVT InVT = V.getValueType();
25964     unsigned ExtractSize = NVT.getSizeInBits();
25965     unsigned EltSize = InVT.getScalarSizeInBits();
25966     // Only do this if we won't split any elements.
25967     if (ExtractSize % EltSize == 0) {
25968       unsigned NumElems = ExtractSize / EltSize;
25969       EVT EltVT = InVT.getVectorElementType();
25970       EVT ExtractVT =
25971           NumElems == 1 ? EltVT
25972                         : EVT::getVectorVT(*DAG.getContext(), EltVT, NumElems);
25973       if ((Level < AfterLegalizeDAG ||
25974            (NumElems == 1 ||
25975             TLI.isOperationLegal(ISD::BUILD_VECTOR, ExtractVT))) &&
25976           (!LegalTypes || TLI.isTypeLegal(ExtractVT))) {
25977         unsigned IdxVal = (ExtIdx * NVT.getScalarSizeInBits()) / EltSize;
25978 
25979         if (NumElems == 1) {
25980           SDValue Src = V->getOperand(IdxVal);
25981           if (EltVT != Src.getValueType())
25982             Src = DAG.getNode(ISD::TRUNCATE, DL, EltVT, Src);
25983           return DAG.getBitcast(NVT, Src);
25984         }
25985 
25986         // Extract the pieces from the original build_vector.
25987         SDValue BuildVec =
25988             DAG.getBuildVector(ExtractVT, DL, V->ops().slice(IdxVal, NumElems));
25989         return DAG.getBitcast(NVT, BuildVec);
25990       }
25991     }
25992   }
25993 
25994   if (V.getOpcode() == ISD::INSERT_SUBVECTOR) {
25995     // Handle only simple case where vector being inserted and vector
25996     // being extracted are of same size.
25997     EVT SmallVT = V.getOperand(1).getValueType();
25998     if (NVT.bitsEq(SmallVT)) {
25999       // Combine:
26000       //    (extract_subvec (insert_subvec V1, V2, InsIdx), ExtIdx)
26001       // Into:
26002       //    indices are equal or bit offsets are equal => V1
26003       //    otherwise => (extract_subvec V1, ExtIdx)
26004       uint64_t InsIdx = V.getConstantOperandVal(2);
26005       if (InsIdx * SmallVT.getScalarSizeInBits() ==
26006           ExtIdx * NVT.getScalarSizeInBits()) {
26007         if (!LegalOperations || TLI.isOperationLegal(ISD::BITCAST, NVT))
26008           return DAG.getBitcast(NVT, V.getOperand(1));
26009       } else {
26010         return DAG.getNode(
26011             ISD::EXTRACT_SUBVECTOR, DL, NVT,
26012             DAG.getBitcast(N->getOperand(0).getValueType(), V.getOperand(0)),
26013             N->getOperand(1));
26014       }
26015     }
26016   }
26017 
26018   // If only EXTRACT_SUBVECTOR nodes use the source vector we can
26019   // simplify it based on the (valid) extractions.
26020   if (!V.getValueType().isScalableVector() &&
26021       llvm::all_of(V->users(), [&](SDNode *Use) {
26022         return Use->getOpcode() == ISD::EXTRACT_SUBVECTOR &&
26023                Use->getOperand(0) == V;
26024       })) {
26025     unsigned NumElts = V.getValueType().getVectorNumElements();
26026     APInt DemandedElts = APInt::getZero(NumElts);
26027     for (SDNode *User : V->users()) {
26028       unsigned ExtIdx = User->getConstantOperandVal(1);
26029       unsigned NumSubElts = User->getValueType(0).getVectorNumElements();
26030       DemandedElts.setBits(ExtIdx, ExtIdx + NumSubElts);
26031     }
26032     if (SimplifyDemandedVectorElts(V, DemandedElts, /*AssumeSingleUse=*/true)) {
26033       // We simplified the vector operand of this extract subvector. If this
26034       // extract is not dead, visit it again so it is folded properly.
26035       if (N->getOpcode() != ISD::DELETED_NODE)
26036         AddToWorklist(N);
26037       return SDValue(N, 0);
26038     }
26039   } else {
26040     if (SimplifyDemandedVectorElts(SDValue(N, 0)))
26041       return SDValue(N, 0);
26042   }
26043 
26044   return SDValue();
26045 }
26046 
26047 /// Try to convert a wide shuffle of concatenated vectors into 2 narrow shuffles
26048 /// followed by concatenation. Narrow vector ops may have better performance
26049 /// than wide ops, and this can unlock further narrowing of other vector ops.
26050 /// Targets can invert this transform later if it is not profitable.
foldShuffleOfConcatUndefs(ShuffleVectorSDNode * Shuf,SelectionDAG & DAG)26051 static SDValue foldShuffleOfConcatUndefs(ShuffleVectorSDNode *Shuf,
26052                                          SelectionDAG &DAG) {
26053   SDValue N0 = Shuf->getOperand(0), N1 = Shuf->getOperand(1);
26054   if (N0.getOpcode() != ISD::CONCAT_VECTORS || N0.getNumOperands() != 2 ||
26055       N1.getOpcode() != ISD::CONCAT_VECTORS || N1.getNumOperands() != 2 ||
26056       !N0.getOperand(1).isUndef() || !N1.getOperand(1).isUndef())
26057     return SDValue();
26058 
26059   // Split the wide shuffle mask into halves. Any mask element that is accessing
26060   // operand 1 is offset down to account for narrowing of the vectors.
26061   ArrayRef<int> Mask = Shuf->getMask();
26062   EVT VT = Shuf->getValueType(0);
26063   unsigned NumElts = VT.getVectorNumElements();
26064   unsigned HalfNumElts = NumElts / 2;
26065   SmallVector<int, 16> Mask0(HalfNumElts, -1);
26066   SmallVector<int, 16> Mask1(HalfNumElts, -1);
26067   for (unsigned i = 0; i != NumElts; ++i) {
26068     if (Mask[i] == -1)
26069       continue;
26070     // If we reference the upper (undef) subvector then the element is undef.
26071     if ((Mask[i] % NumElts) >= HalfNumElts)
26072       continue;
26073     int M = Mask[i] < (int)NumElts ? Mask[i] : Mask[i] - (int)HalfNumElts;
26074     if (i < HalfNumElts)
26075       Mask0[i] = M;
26076     else
26077       Mask1[i - HalfNumElts] = M;
26078   }
26079 
26080   // Ask the target if this is a valid transform.
26081   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
26082   EVT HalfVT = EVT::getVectorVT(*DAG.getContext(), VT.getScalarType(),
26083                                 HalfNumElts);
26084   if (!TLI.isShuffleMaskLegal(Mask0, HalfVT) ||
26085       !TLI.isShuffleMaskLegal(Mask1, HalfVT))
26086     return SDValue();
26087 
26088   // shuffle (concat X, undef), (concat Y, undef), Mask -->
26089   // concat (shuffle X, Y, Mask0), (shuffle X, Y, Mask1)
26090   SDValue X = N0.getOperand(0), Y = N1.getOperand(0);
26091   SDLoc DL(Shuf);
26092   SDValue Shuf0 = DAG.getVectorShuffle(HalfVT, DL, X, Y, Mask0);
26093   SDValue Shuf1 = DAG.getVectorShuffle(HalfVT, DL, X, Y, Mask1);
26094   return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Shuf0, Shuf1);
26095 }
26096 
26097 // Tries to turn a shuffle of two CONCAT_VECTORS into a single concat,
26098 // or turn a shuffle of a single concat into simpler shuffle then concat.
partitionShuffleOfConcats(SDNode * N,SelectionDAG & DAG)26099 static SDValue partitionShuffleOfConcats(SDNode *N, SelectionDAG &DAG) {
26100   EVT VT = N->getValueType(0);
26101   unsigned NumElts = VT.getVectorNumElements();
26102 
26103   SDValue N0 = N->getOperand(0);
26104   SDValue N1 = N->getOperand(1);
26105   ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N);
26106   ArrayRef<int> Mask = SVN->getMask();
26107 
26108   SmallVector<SDValue, 4> Ops;
26109   EVT ConcatVT = N0.getOperand(0).getValueType();
26110   unsigned NumElemsPerConcat = ConcatVT.getVectorNumElements();
26111   unsigned NumConcats = NumElts / NumElemsPerConcat;
26112 
26113   auto IsUndefMaskElt = [](int i) { return i == -1; };
26114 
26115   // Special case: shuffle(concat(A,B)) can be more efficiently represented
26116   // as concat(shuffle(A,B),UNDEF) if the shuffle doesn't set any of the high
26117   // half vector elements.
26118   if (NumElemsPerConcat * 2 == NumElts && N1.isUndef() &&
26119       llvm::all_of(Mask.slice(NumElemsPerConcat, NumElemsPerConcat),
26120                    IsUndefMaskElt)) {
26121     N0 = DAG.getVectorShuffle(ConcatVT, SDLoc(N), N0.getOperand(0),
26122                               N0.getOperand(1),
26123                               Mask.slice(0, NumElemsPerConcat));
26124     N1 = DAG.getUNDEF(ConcatVT);
26125     return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, N0, N1);
26126   }
26127 
26128   // Look at every vector that's inserted. We're looking for exact
26129   // subvector-sized copies from a concatenated vector
26130   for (unsigned I = 0; I != NumConcats; ++I) {
26131     unsigned Begin = I * NumElemsPerConcat;
26132     ArrayRef<int> SubMask = Mask.slice(Begin, NumElemsPerConcat);
26133 
26134     // Make sure we're dealing with a copy.
26135     if (llvm::all_of(SubMask, IsUndefMaskElt)) {
26136       Ops.push_back(DAG.getUNDEF(ConcatVT));
26137       continue;
26138     }
26139 
26140     int OpIdx = -1;
26141     for (int i = 0; i != (int)NumElemsPerConcat; ++i) {
26142       if (IsUndefMaskElt(SubMask[i]))
26143         continue;
26144       if ((SubMask[i] % (int)NumElemsPerConcat) != i)
26145         return SDValue();
26146       int EltOpIdx = SubMask[i] / NumElemsPerConcat;
26147       if (0 <= OpIdx && EltOpIdx != OpIdx)
26148         return SDValue();
26149       OpIdx = EltOpIdx;
26150     }
26151     assert(0 <= OpIdx && "Unknown concat_vectors op");
26152 
26153     if (OpIdx < (int)N0.getNumOperands())
26154       Ops.push_back(N0.getOperand(OpIdx));
26155     else
26156       Ops.push_back(N1.getOperand(OpIdx - N0.getNumOperands()));
26157   }
26158 
26159   return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops);
26160 }
26161 
26162 // Attempt to combine a shuffle of 2 inputs of 'scalar sources' -
26163 // BUILD_VECTOR or SCALAR_TO_VECTOR into a single BUILD_VECTOR.
26164 //
26165 // SHUFFLE(BUILD_VECTOR(), BUILD_VECTOR()) -> BUILD_VECTOR() is always
26166 // a simplification in some sense, but it isn't appropriate in general: some
26167 // BUILD_VECTORs are substantially cheaper than others. The general case
26168 // of a BUILD_VECTOR requires inserting each element individually (or
26169 // performing the equivalent in a temporary stack variable). A BUILD_VECTOR of
26170 // all constants is a single constant pool load.  A BUILD_VECTOR where each
26171 // element is identical is a splat.  A BUILD_VECTOR where most of the operands
26172 // are undef lowers to a small number of element insertions.
26173 //
26174 // To deal with this, we currently use a bunch of mostly arbitrary heuristics.
26175 // We don't fold shuffles where one side is a non-zero constant, and we don't
26176 // fold shuffles if the resulting (non-splat) BUILD_VECTOR would have duplicate
26177 // non-constant operands. This seems to work out reasonably well in practice.
combineShuffleOfScalars(ShuffleVectorSDNode * SVN,SelectionDAG & DAG,const TargetLowering & TLI)26178 static SDValue combineShuffleOfScalars(ShuffleVectorSDNode *SVN,
26179                                        SelectionDAG &DAG,
26180                                        const TargetLowering &TLI) {
26181   EVT VT = SVN->getValueType(0);
26182   unsigned NumElts = VT.getVectorNumElements();
26183   SDValue N0 = SVN->getOperand(0);
26184   SDValue N1 = SVN->getOperand(1);
26185 
26186   if (!N0->hasOneUse())
26187     return SDValue();
26188 
26189   // If only one of N1,N2 is constant, bail out if it is not ALL_ZEROS as
26190   // discussed above.
26191   if (!N1.isUndef()) {
26192     if (!N1->hasOneUse())
26193       return SDValue();
26194 
26195     bool N0AnyConst = isAnyConstantBuildVector(N0);
26196     bool N1AnyConst = isAnyConstantBuildVector(N1);
26197     if (N0AnyConst && !N1AnyConst && !ISD::isBuildVectorAllZeros(N0.getNode()))
26198       return SDValue();
26199     if (!N0AnyConst && N1AnyConst && !ISD::isBuildVectorAllZeros(N1.getNode()))
26200       return SDValue();
26201   }
26202 
26203   // If both inputs are splats of the same value then we can safely merge this
26204   // to a single BUILD_VECTOR with undef elements based on the shuffle mask.
26205   bool IsSplat = false;
26206   auto *BV0 = dyn_cast<BuildVectorSDNode>(N0);
26207   auto *BV1 = dyn_cast<BuildVectorSDNode>(N1);
26208   if (BV0 && BV1)
26209     if (SDValue Splat0 = BV0->getSplatValue())
26210       IsSplat = (Splat0 == BV1->getSplatValue());
26211 
26212   SmallVector<SDValue, 8> Ops;
26213   SmallSet<SDValue, 16> DuplicateOps;
26214   for (int M : SVN->getMask()) {
26215     SDValue Op = DAG.getUNDEF(VT.getScalarType());
26216     if (M >= 0) {
26217       int Idx = M < (int)NumElts ? M : M - NumElts;
26218       SDValue &S = (M < (int)NumElts ? N0 : N1);
26219       if (S.getOpcode() == ISD::BUILD_VECTOR) {
26220         Op = S.getOperand(Idx);
26221       } else if (S.getOpcode() == ISD::SCALAR_TO_VECTOR) {
26222         SDValue Op0 = S.getOperand(0);
26223         Op = Idx == 0 ? Op0 : DAG.getUNDEF(Op0.getValueType());
26224       } else {
26225         // Operand can't be combined - bail out.
26226         return SDValue();
26227       }
26228     }
26229 
26230     // Don't duplicate a non-constant BUILD_VECTOR operand unless we're
26231     // generating a splat; semantically, this is fine, but it's likely to
26232     // generate low-quality code if the target can't reconstruct an appropriate
26233     // shuffle.
26234     if (!Op.isUndef() && !isIntOrFPConstant(Op))
26235       if (!IsSplat && !DuplicateOps.insert(Op).second)
26236         return SDValue();
26237 
26238     Ops.push_back(Op);
26239   }
26240 
26241   // BUILD_VECTOR requires all inputs to be of the same type, find the
26242   // maximum type and extend them all.
26243   EVT SVT = VT.getScalarType();
26244   if (SVT.isInteger())
26245     for (SDValue &Op : Ops)
26246       SVT = (SVT.bitsLT(Op.getValueType()) ? Op.getValueType() : SVT);
26247   if (SVT != VT.getScalarType())
26248     for (SDValue &Op : Ops)
26249       Op = Op.isUndef() ? DAG.getUNDEF(SVT)
26250                         : (TLI.isZExtFree(Op.getValueType(), SVT)
26251                                ? DAG.getZExtOrTrunc(Op, SDLoc(SVN), SVT)
26252                                : DAG.getSExtOrTrunc(Op, SDLoc(SVN), SVT));
26253   return DAG.getBuildVector(VT, SDLoc(SVN), Ops);
26254 }
26255 
26256 // Match shuffles that can be converted to *_vector_extend_in_reg.
26257 // This is often generated during legalization.
26258 // e.g. v4i32 <0,u,1,u> -> (v2i64 any_vector_extend_in_reg(v4i32 src)),
26259 // and returns the EVT to which the extension should be performed.
26260 // NOTE: this assumes that the src is the first operand of the shuffle.
canCombineShuffleToExtendVectorInreg(unsigned Opcode,EVT VT,std::function<bool (unsigned)> Match,SelectionDAG & DAG,const TargetLowering & TLI,bool LegalTypes,bool LegalOperations)26261 static std::optional<EVT> canCombineShuffleToExtendVectorInreg(
26262     unsigned Opcode, EVT VT, std::function<bool(unsigned)> Match,
26263     SelectionDAG &DAG, const TargetLowering &TLI, bool LegalTypes,
26264     bool LegalOperations) {
26265   bool IsBigEndian = DAG.getDataLayout().isBigEndian();
26266 
26267   // TODO Add support for big-endian when we have a test case.
26268   if (!VT.isInteger() || IsBigEndian)
26269     return std::nullopt;
26270 
26271   unsigned NumElts = VT.getVectorNumElements();
26272   unsigned EltSizeInBits = VT.getScalarSizeInBits();
26273 
26274   // Attempt to match a '*_extend_vector_inreg' shuffle, we just search for
26275   // power-of-2 extensions as they are the most likely.
26276   // FIXME: should try Scale == NumElts case too,
26277   for (unsigned Scale = 2; Scale < NumElts; Scale *= 2) {
26278     // The vector width must be a multiple of Scale.
26279     if (NumElts % Scale != 0)
26280       continue;
26281 
26282     EVT OutSVT = EVT::getIntegerVT(*DAG.getContext(), EltSizeInBits * Scale);
26283     EVT OutVT = EVT::getVectorVT(*DAG.getContext(), OutSVT, NumElts / Scale);
26284 
26285     if ((LegalTypes && !TLI.isTypeLegal(OutVT)) ||
26286         (LegalOperations && !TLI.isOperationLegalOrCustom(Opcode, OutVT)))
26287       continue;
26288 
26289     if (Match(Scale))
26290       return OutVT;
26291   }
26292 
26293   return std::nullopt;
26294 }
26295 
26296 // Match shuffles that can be converted to any_vector_extend_in_reg.
26297 // This is often generated during legalization.
26298 // e.g. v4i32 <0,u,1,u> -> (v2i64 any_vector_extend_in_reg(v4i32 src))
combineShuffleToAnyExtendVectorInreg(ShuffleVectorSDNode * SVN,SelectionDAG & DAG,const TargetLowering & TLI,bool LegalOperations)26299 static SDValue combineShuffleToAnyExtendVectorInreg(ShuffleVectorSDNode *SVN,
26300                                                     SelectionDAG &DAG,
26301                                                     const TargetLowering &TLI,
26302                                                     bool LegalOperations) {
26303   EVT VT = SVN->getValueType(0);
26304   bool IsBigEndian = DAG.getDataLayout().isBigEndian();
26305 
26306   // TODO Add support for big-endian when we have a test case.
26307   if (!VT.isInteger() || IsBigEndian)
26308     return SDValue();
26309 
26310   // shuffle<0,-1,1,-1> == (v2i64 anyextend_vector_inreg(v4i32))
26311   auto isAnyExtend = [NumElts = VT.getVectorNumElements(),
26312                       Mask = SVN->getMask()](unsigned Scale) {
26313     for (unsigned i = 0; i != NumElts; ++i) {
26314       if (Mask[i] < 0)
26315         continue;
26316       if ((i % Scale) == 0 && Mask[i] == (int)(i / Scale))
26317         continue;
26318       return false;
26319     }
26320     return true;
26321   };
26322 
26323   unsigned Opcode = ISD::ANY_EXTEND_VECTOR_INREG;
26324   SDValue N0 = SVN->getOperand(0);
26325   // Never create an illegal type. Only create unsupported operations if we
26326   // are pre-legalization.
26327   std::optional<EVT> OutVT = canCombineShuffleToExtendVectorInreg(
26328       Opcode, VT, isAnyExtend, DAG, TLI, /*LegalTypes=*/true, LegalOperations);
26329   if (!OutVT)
26330     return SDValue();
26331   return DAG.getBitcast(VT, DAG.getNode(Opcode, SDLoc(SVN), *OutVT, N0));
26332 }
26333 
26334 // Match shuffles that can be converted to zero_extend_vector_inreg.
26335 // This is often generated during legalization.
26336 // e.g. v4i32 <0,z,1,u> -> (v2i64 zero_extend_vector_inreg(v4i32 src))
combineShuffleToZeroExtendVectorInReg(ShuffleVectorSDNode * SVN,SelectionDAG & DAG,const TargetLowering & TLI,bool LegalOperations)26337 static SDValue combineShuffleToZeroExtendVectorInReg(ShuffleVectorSDNode *SVN,
26338                                                      SelectionDAG &DAG,
26339                                                      const TargetLowering &TLI,
26340                                                      bool LegalOperations) {
26341   bool LegalTypes = true;
26342   EVT VT = SVN->getValueType(0);
26343   assert(!VT.isScalableVector() && "Encountered scalable shuffle?");
26344   unsigned NumElts = VT.getVectorNumElements();
26345   unsigned EltSizeInBits = VT.getScalarSizeInBits();
26346 
26347   // TODO: add support for big-endian when we have a test case.
26348   bool IsBigEndian = DAG.getDataLayout().isBigEndian();
26349   if (!VT.isInteger() || IsBigEndian)
26350     return SDValue();
26351 
26352   SmallVector<int, 16> Mask(SVN->getMask());
26353   auto ForEachDecomposedIndice = [NumElts, &Mask](auto Fn) {
26354     for (int &Indice : Mask) {
26355       if (Indice < 0)
26356         continue;
26357       int OpIdx = (unsigned)Indice < NumElts ? 0 : 1;
26358       int OpEltIdx = (unsigned)Indice < NumElts ? Indice : Indice - NumElts;
26359       Fn(Indice, OpIdx, OpEltIdx);
26360     }
26361   };
26362 
26363   // Which elements of which operand does this shuffle demand?
26364   std::array<APInt, 2> OpsDemandedElts;
26365   for (APInt &OpDemandedElts : OpsDemandedElts)
26366     OpDemandedElts = APInt::getZero(NumElts);
26367   ForEachDecomposedIndice(
26368       [&OpsDemandedElts](int &Indice, int OpIdx, int OpEltIdx) {
26369         OpsDemandedElts[OpIdx].setBit(OpEltIdx);
26370       });
26371 
26372   // Element-wise(!), which of these demanded elements are know to be zero?
26373   std::array<APInt, 2> OpsKnownZeroElts;
26374   for (auto I : zip(SVN->ops(), OpsDemandedElts, OpsKnownZeroElts))
26375     std::get<2>(I) =
26376         DAG.computeVectorKnownZeroElements(std::get<0>(I), std::get<1>(I));
26377 
26378   // Manifest zeroable element knowledge in the shuffle mask.
26379   // NOTE: we don't have 'zeroable' sentinel value in generic DAG,
26380   //       this is a local invention, but it won't leak into DAG.
26381   // FIXME: should we not manifest them, but just check when matching?
26382   bool HadZeroableElts = false;
26383   ForEachDecomposedIndice([&OpsKnownZeroElts, &HadZeroableElts](
26384                               int &Indice, int OpIdx, int OpEltIdx) {
26385     if (OpsKnownZeroElts[OpIdx][OpEltIdx]) {
26386       Indice = -2; // Zeroable element.
26387       HadZeroableElts = true;
26388     }
26389   });
26390 
26391   // Don't proceed unless we've refined at least one zeroable mask indice.
26392   // If we didn't, then we are still trying to match the same shuffle mask
26393   // we previously tried to match as ISD::ANY_EXTEND_VECTOR_INREG,
26394   // and evidently failed. Proceeding will lead to endless combine loops.
26395   if (!HadZeroableElts)
26396     return SDValue();
26397 
26398   // The shuffle may be more fine-grained than we want. Widen elements first.
26399   // FIXME: should we do this before manifesting zeroable shuffle mask indices?
26400   SmallVector<int, 16> ScaledMask;
26401   getShuffleMaskWithWidestElts(Mask, ScaledMask);
26402   assert(Mask.size() >= ScaledMask.size() &&
26403          Mask.size() % ScaledMask.size() == 0 && "Unexpected mask widening.");
26404   int Prescale = Mask.size() / ScaledMask.size();
26405 
26406   NumElts = ScaledMask.size();
26407   EltSizeInBits *= Prescale;
26408 
26409   EVT PrescaledVT = EVT::getVectorVT(
26410       *DAG.getContext(), EVT::getIntegerVT(*DAG.getContext(), EltSizeInBits),
26411       NumElts);
26412 
26413   if (LegalTypes && !TLI.isTypeLegal(PrescaledVT) && TLI.isTypeLegal(VT))
26414     return SDValue();
26415 
26416   // For example,
26417   // shuffle<0,z,1,-1> == (v2i64 zero_extend_vector_inreg(v4i32))
26418   // But not shuffle<z,z,1,-1> and not shuffle<0,z,z,-1> ! (for same types)
26419   auto isZeroExtend = [NumElts, &ScaledMask](unsigned Scale) {
26420     assert(Scale >= 2 && Scale <= NumElts && NumElts % Scale == 0 &&
26421            "Unexpected mask scaling factor.");
26422     ArrayRef<int> Mask = ScaledMask;
26423     for (unsigned SrcElt = 0, NumSrcElts = NumElts / Scale;
26424          SrcElt != NumSrcElts; ++SrcElt) {
26425       // Analyze the shuffle mask in Scale-sized chunks.
26426       ArrayRef<int> MaskChunk = Mask.take_front(Scale);
26427       assert(MaskChunk.size() == Scale && "Unexpected mask size.");
26428       Mask = Mask.drop_front(MaskChunk.size());
26429       // The first indice in this chunk must be SrcElt, but not zero!
26430       // FIXME: undef should be fine, but that results in more-defined result.
26431       if (int FirstIndice = MaskChunk[0]; (unsigned)FirstIndice != SrcElt)
26432         return false;
26433       // The rest of the indices in this chunk must be zeros.
26434       // FIXME: undef should be fine, but that results in more-defined result.
26435       if (!all_of(MaskChunk.drop_front(1),
26436                   [](int Indice) { return Indice == -2; }))
26437         return false;
26438     }
26439     assert(Mask.empty() && "Did not process the whole mask?");
26440     return true;
26441   };
26442 
26443   unsigned Opcode = ISD::ZERO_EXTEND_VECTOR_INREG;
26444   for (bool Commuted : {false, true}) {
26445     SDValue Op = SVN->getOperand(!Commuted ? 0 : 1);
26446     if (Commuted)
26447       ShuffleVectorSDNode::commuteMask(ScaledMask);
26448     std::optional<EVT> OutVT = canCombineShuffleToExtendVectorInreg(
26449         Opcode, PrescaledVT, isZeroExtend, DAG, TLI, LegalTypes,
26450         LegalOperations);
26451     if (OutVT)
26452       return DAG.getBitcast(VT, DAG.getNode(Opcode, SDLoc(SVN), *OutVT,
26453                                             DAG.getBitcast(PrescaledVT, Op)));
26454   }
26455   return SDValue();
26456 }
26457 
26458 // Detect 'truncate_vector_inreg' style shuffles that pack the lower parts of
26459 // each source element of a large type into the lowest elements of a smaller
26460 // destination type. This is often generated during legalization.
26461 // If the source node itself was a '*_extend_vector_inreg' node then we should
26462 // then be able to remove it.
combineTruncationShuffle(ShuffleVectorSDNode * SVN,SelectionDAG & DAG)26463 static SDValue combineTruncationShuffle(ShuffleVectorSDNode *SVN,
26464                                         SelectionDAG &DAG) {
26465   EVT VT = SVN->getValueType(0);
26466   bool IsBigEndian = DAG.getDataLayout().isBigEndian();
26467 
26468   // TODO Add support for big-endian when we have a test case.
26469   if (!VT.isInteger() || IsBigEndian)
26470     return SDValue();
26471 
26472   SDValue N0 = peekThroughBitcasts(SVN->getOperand(0));
26473 
26474   unsigned Opcode = N0.getOpcode();
26475   if (!ISD::isExtVecInRegOpcode(Opcode))
26476     return SDValue();
26477 
26478   SDValue N00 = N0.getOperand(0);
26479   ArrayRef<int> Mask = SVN->getMask();
26480   unsigned NumElts = VT.getVectorNumElements();
26481   unsigned EltSizeInBits = VT.getScalarSizeInBits();
26482   unsigned ExtSrcSizeInBits = N00.getScalarValueSizeInBits();
26483   unsigned ExtDstSizeInBits = N0.getScalarValueSizeInBits();
26484 
26485   if (ExtDstSizeInBits % ExtSrcSizeInBits != 0)
26486     return SDValue();
26487   unsigned ExtScale = ExtDstSizeInBits / ExtSrcSizeInBits;
26488 
26489   // (v4i32 truncate_vector_inreg(v2i64)) == shuffle<0,2-1,-1>
26490   // (v8i16 truncate_vector_inreg(v4i32)) == shuffle<0,2,4,6,-1,-1,-1,-1>
26491   // (v8i16 truncate_vector_inreg(v2i64)) == shuffle<0,4,-1,-1,-1,-1,-1,-1>
26492   auto isTruncate = [&Mask, &NumElts](unsigned Scale) {
26493     for (unsigned i = 0; i != NumElts; ++i) {
26494       if (Mask[i] < 0)
26495         continue;
26496       if ((i * Scale) < NumElts && Mask[i] == (int)(i * Scale))
26497         continue;
26498       return false;
26499     }
26500     return true;
26501   };
26502 
26503   // At the moment we just handle the case where we've truncated back to the
26504   // same size as before the extension.
26505   // TODO: handle more extension/truncation cases as cases arise.
26506   if (EltSizeInBits != ExtSrcSizeInBits)
26507     return SDValue();
26508 
26509   // We can remove *extend_vector_inreg only if the truncation happens at
26510   // the same scale as the extension.
26511   if (isTruncate(ExtScale))
26512     return DAG.getBitcast(VT, N00);
26513 
26514   return SDValue();
26515 }
26516 
26517 // Combine shuffles of splat-shuffles of the form:
26518 // shuffle (shuffle V, undef, splat-mask), undef, M
26519 // If splat-mask contains undef elements, we need to be careful about
26520 // introducing undef's in the folded mask which are not the result of composing
26521 // the masks of the shuffles.
combineShuffleOfSplatVal(ShuffleVectorSDNode * Shuf,SelectionDAG & DAG)26522 static SDValue combineShuffleOfSplatVal(ShuffleVectorSDNode *Shuf,
26523                                         SelectionDAG &DAG) {
26524   EVT VT = Shuf->getValueType(0);
26525   unsigned NumElts = VT.getVectorNumElements();
26526 
26527   if (!Shuf->getOperand(1).isUndef())
26528     return SDValue();
26529 
26530   // See if this unary non-splat shuffle actually *is* a splat shuffle,
26531   // in disguise, with all demanded elements being identical.
26532   // FIXME: this can be done per-operand.
26533   if (!Shuf->isSplat()) {
26534     APInt DemandedElts(NumElts, 0);
26535     for (int Idx : Shuf->getMask()) {
26536       if (Idx < 0)
26537         continue; // Ignore sentinel indices.
26538       assert((unsigned)Idx < NumElts && "Out-of-bounds shuffle indice?");
26539       DemandedElts.setBit(Idx);
26540     }
26541     assert(DemandedElts.popcount() > 1 && "Is a splat shuffle already?");
26542     APInt UndefElts;
26543     if (DAG.isSplatValue(Shuf->getOperand(0), DemandedElts, UndefElts)) {
26544       // Even if all demanded elements are splat, some of them could be undef.
26545       // Which lowest demanded element is *not* known-undef?
26546       std::optional<unsigned> MinNonUndefIdx;
26547       for (int Idx : Shuf->getMask()) {
26548         if (Idx < 0 || UndefElts[Idx])
26549           continue; // Ignore sentinel indices, and undef elements.
26550         MinNonUndefIdx = std::min<unsigned>(Idx, MinNonUndefIdx.value_or(~0U));
26551       }
26552       if (!MinNonUndefIdx)
26553         return DAG.getUNDEF(VT); // All undef - result is undef.
26554       assert(*MinNonUndefIdx < NumElts && "Expected valid element index.");
26555       SmallVector<int, 8> SplatMask(Shuf->getMask());
26556       for (int &Idx : SplatMask) {
26557         if (Idx < 0)
26558           continue; // Passthrough sentinel indices.
26559         // Otherwise, just pick the lowest demanded non-undef element.
26560         // Or sentinel undef, if we know we'd pick a known-undef element.
26561         Idx = UndefElts[Idx] ? -1 : *MinNonUndefIdx;
26562       }
26563       assert(SplatMask != Shuf->getMask() && "Expected mask to change!");
26564       return DAG.getVectorShuffle(VT, SDLoc(Shuf), Shuf->getOperand(0),
26565                                   Shuf->getOperand(1), SplatMask);
26566     }
26567   }
26568 
26569   // If the inner operand is a known splat with no undefs, just return that directly.
26570   // TODO: Create DemandedElts mask from Shuf's mask.
26571   // TODO: Allow undef elements and merge with the shuffle code below.
26572   if (DAG.isSplatValue(Shuf->getOperand(0), /*AllowUndefs*/ false))
26573     return Shuf->getOperand(0);
26574 
26575   auto *Splat = dyn_cast<ShuffleVectorSDNode>(Shuf->getOperand(0));
26576   if (!Splat || !Splat->isSplat())
26577     return SDValue();
26578 
26579   ArrayRef<int> ShufMask = Shuf->getMask();
26580   ArrayRef<int> SplatMask = Splat->getMask();
26581   assert(ShufMask.size() == SplatMask.size() && "Mask length mismatch");
26582 
26583   // Prefer simplifying to the splat-shuffle, if possible. This is legal if
26584   // every undef mask element in the splat-shuffle has a corresponding undef
26585   // element in the user-shuffle's mask or if the composition of mask elements
26586   // would result in undef.
26587   // Examples for (shuffle (shuffle v, undef, SplatMask), undef, UserMask):
26588   // * UserMask=[0,2,u,u], SplatMask=[2,u,2,u] -> [2,2,u,u]
26589   //   In this case it is not legal to simplify to the splat-shuffle because we
26590   //   may be exposing the users of the shuffle an undef element at index 1
26591   //   which was not there before the combine.
26592   // * UserMask=[0,u,2,u], SplatMask=[2,u,2,u] -> [2,u,2,u]
26593   //   In this case the composition of masks yields SplatMask, so it's ok to
26594   //   simplify to the splat-shuffle.
26595   // * UserMask=[3,u,2,u], SplatMask=[2,u,2,u] -> [u,u,2,u]
26596   //   In this case the composed mask includes all undef elements of SplatMask
26597   //   and in addition sets element zero to undef. It is safe to simplify to
26598   //   the splat-shuffle.
26599   auto CanSimplifyToExistingSplat = [](ArrayRef<int> UserMask,
26600                                        ArrayRef<int> SplatMask) {
26601     for (unsigned i = 0, e = UserMask.size(); i != e; ++i)
26602       if (UserMask[i] != -1 && SplatMask[i] == -1 &&
26603           SplatMask[UserMask[i]] != -1)
26604         return false;
26605     return true;
26606   };
26607   if (CanSimplifyToExistingSplat(ShufMask, SplatMask))
26608     return Shuf->getOperand(0);
26609 
26610   // Create a new shuffle with a mask that is composed of the two shuffles'
26611   // masks.
26612   SmallVector<int, 32> NewMask;
26613   for (int Idx : ShufMask)
26614     NewMask.push_back(Idx == -1 ? -1 : SplatMask[Idx]);
26615 
26616   return DAG.getVectorShuffle(Splat->getValueType(0), SDLoc(Splat),
26617                               Splat->getOperand(0), Splat->getOperand(1),
26618                               NewMask);
26619 }
26620 
26621 // Combine shuffles of bitcasts into a shuffle of the bitcast type, providing
26622 // the mask can be treated as a larger type.
combineShuffleOfBitcast(ShuffleVectorSDNode * SVN,SelectionDAG & DAG,const TargetLowering & TLI,bool LegalOperations)26623 static SDValue combineShuffleOfBitcast(ShuffleVectorSDNode *SVN,
26624                                        SelectionDAG &DAG,
26625                                        const TargetLowering &TLI,
26626                                        bool LegalOperations) {
26627   SDValue Op0 = SVN->getOperand(0);
26628   SDValue Op1 = SVN->getOperand(1);
26629   EVT VT = SVN->getValueType(0);
26630   if (Op0.getOpcode() != ISD::BITCAST)
26631     return SDValue();
26632   EVT InVT = Op0.getOperand(0).getValueType();
26633   if (!InVT.isVector() ||
26634       (!Op1.isUndef() && (Op1.getOpcode() != ISD::BITCAST ||
26635                           Op1.getOperand(0).getValueType() != InVT)))
26636     return SDValue();
26637   if (isAnyConstantBuildVector(Op0.getOperand(0)) &&
26638       (Op1.isUndef() || isAnyConstantBuildVector(Op1.getOperand(0))))
26639     return SDValue();
26640 
26641   int VTLanes = VT.getVectorNumElements();
26642   int InLanes = InVT.getVectorNumElements();
26643   if (VTLanes <= InLanes || VTLanes % InLanes != 0 ||
26644       (LegalOperations &&
26645        !TLI.isOperationLegalOrCustom(ISD::VECTOR_SHUFFLE, InVT)))
26646     return SDValue();
26647   int Factor = VTLanes / InLanes;
26648 
26649   // Check that each group of lanes in the mask are either undef or make a valid
26650   // mask for the wider lane type.
26651   ArrayRef<int> Mask = SVN->getMask();
26652   SmallVector<int> NewMask;
26653   if (!widenShuffleMaskElts(Factor, Mask, NewMask))
26654     return SDValue();
26655 
26656   if (!TLI.isShuffleMaskLegal(NewMask, InVT))
26657     return SDValue();
26658 
26659   // Create the new shuffle with the new mask and bitcast it back to the
26660   // original type.
26661   SDLoc DL(SVN);
26662   Op0 = Op0.getOperand(0);
26663   Op1 = Op1.isUndef() ? DAG.getUNDEF(InVT) : Op1.getOperand(0);
26664   SDValue NewShuf = DAG.getVectorShuffle(InVT, DL, Op0, Op1, NewMask);
26665   return DAG.getBitcast(VT, NewShuf);
26666 }
26667 
26668 /// Combine shuffle of shuffle of the form:
26669 /// shuf (shuf X, undef, InnerMask), undef, OuterMask --> splat X
formSplatFromShuffles(ShuffleVectorSDNode * OuterShuf,SelectionDAG & DAG)26670 static SDValue formSplatFromShuffles(ShuffleVectorSDNode *OuterShuf,
26671                                      SelectionDAG &DAG) {
26672   if (!OuterShuf->getOperand(1).isUndef())
26673     return SDValue();
26674   auto *InnerShuf = dyn_cast<ShuffleVectorSDNode>(OuterShuf->getOperand(0));
26675   if (!InnerShuf || !InnerShuf->getOperand(1).isUndef())
26676     return SDValue();
26677 
26678   ArrayRef<int> OuterMask = OuterShuf->getMask();
26679   ArrayRef<int> InnerMask = InnerShuf->getMask();
26680   unsigned NumElts = OuterMask.size();
26681   assert(NumElts == InnerMask.size() && "Mask length mismatch");
26682   SmallVector<int, 32> CombinedMask(NumElts, -1);
26683   int SplatIndex = -1;
26684   for (unsigned i = 0; i != NumElts; ++i) {
26685     // Undef lanes remain undef.
26686     int OuterMaskElt = OuterMask[i];
26687     if (OuterMaskElt == -1)
26688       continue;
26689 
26690     // Peek through the shuffle masks to get the underlying source element.
26691     int InnerMaskElt = InnerMask[OuterMaskElt];
26692     if (InnerMaskElt == -1)
26693       continue;
26694 
26695     // Initialize the splatted element.
26696     if (SplatIndex == -1)
26697       SplatIndex = InnerMaskElt;
26698 
26699     // Non-matching index - this is not a splat.
26700     if (SplatIndex != InnerMaskElt)
26701       return SDValue();
26702 
26703     CombinedMask[i] = InnerMaskElt;
26704   }
26705   assert((all_of(CombinedMask, [](int M) { return M == -1; }) ||
26706           getSplatIndex(CombinedMask) != -1) &&
26707          "Expected a splat mask");
26708 
26709   // TODO: The transform may be a win even if the mask is not legal.
26710   EVT VT = OuterShuf->getValueType(0);
26711   assert(VT == InnerShuf->getValueType(0) && "Expected matching shuffle types");
26712   if (!DAG.getTargetLoweringInfo().isShuffleMaskLegal(CombinedMask, VT))
26713     return SDValue();
26714 
26715   return DAG.getVectorShuffle(VT, SDLoc(OuterShuf), InnerShuf->getOperand(0),
26716                               InnerShuf->getOperand(1), CombinedMask);
26717 }
26718 
26719 /// If the shuffle mask is taking exactly one element from the first vector
26720 /// operand and passing through all other elements from the second vector
26721 /// operand, return the index of the mask element that is choosing an element
26722 /// from the first operand. Otherwise, return -1.
getShuffleMaskIndexOfOneElementFromOp0IntoOp1(ArrayRef<int> Mask)26723 static int getShuffleMaskIndexOfOneElementFromOp0IntoOp1(ArrayRef<int> Mask) {
26724   int MaskSize = Mask.size();
26725   int EltFromOp0 = -1;
26726   // TODO: This does not match if there are undef elements in the shuffle mask.
26727   // Should we ignore undefs in the shuffle mask instead? The trade-off is
26728   // removing an instruction (a shuffle), but losing the knowledge that some
26729   // vector lanes are not needed.
26730   for (int i = 0; i != MaskSize; ++i) {
26731     if (Mask[i] >= 0 && Mask[i] < MaskSize) {
26732       // We're looking for a shuffle of exactly one element from operand 0.
26733       if (EltFromOp0 != -1)
26734         return -1;
26735       EltFromOp0 = i;
26736     } else if (Mask[i] != i + MaskSize) {
26737       // Nothing from operand 1 can change lanes.
26738       return -1;
26739     }
26740   }
26741   return EltFromOp0;
26742 }
26743 
26744 /// If a shuffle inserts exactly one element from a source vector operand into
26745 /// another vector operand and we can access the specified element as a scalar,
26746 /// then we can eliminate the shuffle.
replaceShuffleOfInsert(ShuffleVectorSDNode * Shuf)26747 SDValue DAGCombiner::replaceShuffleOfInsert(ShuffleVectorSDNode *Shuf) {
26748   // First, check if we are taking one element of a vector and shuffling that
26749   // element into another vector.
26750   ArrayRef<int> Mask = Shuf->getMask();
26751   SmallVector<int, 16> CommutedMask(Mask);
26752   SDValue Op0 = Shuf->getOperand(0);
26753   SDValue Op1 = Shuf->getOperand(1);
26754   int ShufOp0Index = getShuffleMaskIndexOfOneElementFromOp0IntoOp1(Mask);
26755   if (ShufOp0Index == -1) {
26756     // Commute mask and check again.
26757     ShuffleVectorSDNode::commuteMask(CommutedMask);
26758     ShufOp0Index = getShuffleMaskIndexOfOneElementFromOp0IntoOp1(CommutedMask);
26759     if (ShufOp0Index == -1)
26760       return SDValue();
26761     // Commute operands to match the commuted shuffle mask.
26762     std::swap(Op0, Op1);
26763     Mask = CommutedMask;
26764   }
26765 
26766   // The shuffle inserts exactly one element from operand 0 into operand 1.
26767   // Now see if we can access that element as a scalar via a real insert element
26768   // instruction.
26769   // TODO: We can try harder to locate the element as a scalar. Examples: it
26770   // could be an operand of BUILD_VECTOR, or a constant.
26771   assert(Mask[ShufOp0Index] >= 0 && Mask[ShufOp0Index] < (int)Mask.size() &&
26772          "Shuffle mask value must be from operand 0");
26773 
26774   SDValue Elt;
26775   if (sd_match(Op0, m_InsertElt(m_Value(), m_Value(Elt),
26776                                 m_SpecificInt(Mask[ShufOp0Index])))) {
26777     // There's an existing insertelement with constant insertion index, so we
26778     // don't need to check the legality/profitability of a replacement operation
26779     // that differs at most in the constant value. The target should be able to
26780     // lower any of those in a similar way. If not, legalization will expand
26781     // this to a scalar-to-vector plus shuffle.
26782     //
26783     // Note that the shuffle may move the scalar from the position that the
26784     // insert element used. Therefore, our new insert element occurs at the
26785     // shuffle's mask index value, not the insert's index value.
26786     //
26787     // shuffle (insertelt v1, x, C), v2, mask --> insertelt v2, x, C'
26788     SDValue NewInsIndex = DAG.getVectorIdxConstant(ShufOp0Index, SDLoc(Shuf));
26789     return DAG.getNode(ISD::INSERT_VECTOR_ELT, SDLoc(Shuf), Op0.getValueType(),
26790                        Op1, Elt, NewInsIndex);
26791   }
26792 
26793   if (!hasOperation(ISD::INSERT_VECTOR_ELT, Op0.getValueType()))
26794     return SDValue();
26795 
26796   if (sd_match(Op0, m_UnaryOp(ISD::SCALAR_TO_VECTOR, m_Value(Elt))) &&
26797       Mask[ShufOp0Index] == 0) {
26798     SDValue NewInsIndex = DAG.getVectorIdxConstant(ShufOp0Index, SDLoc(Shuf));
26799     return DAG.getNode(ISD::INSERT_VECTOR_ELT, SDLoc(Shuf), Op0.getValueType(),
26800                        Op1, Elt, NewInsIndex);
26801   }
26802 
26803   return SDValue();
26804 }
26805 
26806 /// If we have a unary shuffle of a shuffle, see if it can be folded away
26807 /// completely. This has the potential to lose undef knowledge because the first
26808 /// shuffle may not have an undef mask element where the second one does. So
26809 /// only call this after doing simplifications based on demanded elements.
simplifyShuffleOfShuffle(ShuffleVectorSDNode * Shuf)26810 static SDValue simplifyShuffleOfShuffle(ShuffleVectorSDNode *Shuf) {
26811   // shuf (shuf0 X, Y, Mask0), undef, Mask
26812   auto *Shuf0 = dyn_cast<ShuffleVectorSDNode>(Shuf->getOperand(0));
26813   if (!Shuf0 || !Shuf->getOperand(1).isUndef())
26814     return SDValue();
26815 
26816   ArrayRef<int> Mask = Shuf->getMask();
26817   ArrayRef<int> Mask0 = Shuf0->getMask();
26818   for (int i = 0, e = (int)Mask.size(); i != e; ++i) {
26819     // Ignore undef elements.
26820     if (Mask[i] == -1)
26821       continue;
26822     assert(Mask[i] >= 0 && Mask[i] < e && "Unexpected shuffle mask value");
26823 
26824     // Is the element of the shuffle operand chosen by this shuffle the same as
26825     // the element chosen by the shuffle operand itself?
26826     if (Mask0[Mask[i]] != Mask0[i])
26827       return SDValue();
26828   }
26829   // Every element of this shuffle is identical to the result of the previous
26830   // shuffle, so we can replace this value.
26831   return Shuf->getOperand(0);
26832 }
26833 
visitVECTOR_SHUFFLE(SDNode * N)26834 SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
26835   EVT VT = N->getValueType(0);
26836   unsigned NumElts = VT.getVectorNumElements();
26837 
26838   SDValue N0 = N->getOperand(0);
26839   SDValue N1 = N->getOperand(1);
26840 
26841   assert(N0.getValueType() == VT && "Vector shuffle must be normalized in DAG");
26842 
26843   // Canonicalize shuffle undef, undef -> undef
26844   if (N0.isUndef() && N1.isUndef())
26845     return DAG.getUNDEF(VT);
26846 
26847   ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N);
26848 
26849   // Canonicalize shuffle v, v -> v, undef
26850   if (N0 == N1)
26851     return DAG.getVectorShuffle(VT, SDLoc(N), N0, DAG.getUNDEF(VT),
26852                                 createUnaryMask(SVN->getMask(), NumElts));
26853 
26854   // Canonicalize shuffle undef, v -> v, undef.  Commute the shuffle mask.
26855   if (N0.isUndef())
26856     return DAG.getCommutedVectorShuffle(*SVN);
26857 
26858   // Remove references to rhs if it is undef
26859   if (N1.isUndef()) {
26860     bool Changed = false;
26861     SmallVector<int, 8> NewMask;
26862     for (unsigned i = 0; i != NumElts; ++i) {
26863       int Idx = SVN->getMaskElt(i);
26864       if (Idx >= (int)NumElts) {
26865         Idx = -1;
26866         Changed = true;
26867       }
26868       NewMask.push_back(Idx);
26869     }
26870     if (Changed)
26871       return DAG.getVectorShuffle(VT, SDLoc(N), N0, N1, NewMask);
26872   }
26873 
26874   if (SDValue InsElt = replaceShuffleOfInsert(SVN))
26875     return InsElt;
26876 
26877   // A shuffle of a single vector that is a splatted value can always be folded.
26878   if (SDValue V = combineShuffleOfSplatVal(SVN, DAG))
26879     return V;
26880 
26881   if (SDValue V = formSplatFromShuffles(SVN, DAG))
26882     return V;
26883 
26884   // If it is a splat, check if the argument vector is another splat or a
26885   // build_vector.
26886   if (SVN->isSplat() && SVN->getSplatIndex() < (int)NumElts) {
26887     int SplatIndex = SVN->getSplatIndex();
26888     if (N0.hasOneUse() && TLI.isExtractVecEltCheap(VT, SplatIndex) &&
26889         TLI.isBinOp(N0.getOpcode()) && N0->getNumValues() == 1) {
26890       // splat (vector_bo L, R), Index -->
26891       // splat (scalar_bo (extelt L, Index), (extelt R, Index))
26892       SDValue L = N0.getOperand(0), R = N0.getOperand(1);
26893       SDLoc DL(N);
26894       EVT EltVT = VT.getScalarType();
26895       SDValue Index = DAG.getVectorIdxConstant(SplatIndex, DL);
26896       SDValue ExtL = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, L, Index);
26897       SDValue ExtR = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, R, Index);
26898       SDValue NewBO =
26899           DAG.getNode(N0.getOpcode(), DL, EltVT, ExtL, ExtR, N0->getFlags());
26900       SDValue Insert = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VT, NewBO);
26901       SmallVector<int, 16> ZeroMask(VT.getVectorNumElements(), 0);
26902       return DAG.getVectorShuffle(VT, DL, Insert, DAG.getUNDEF(VT), ZeroMask);
26903     }
26904 
26905     // splat(scalar_to_vector(x), 0) -> build_vector(x,...,x)
26906     // splat(insert_vector_elt(v, x, c), c) -> build_vector(x,...,x)
26907     if ((!LegalOperations || TLI.isOperationLegal(ISD::BUILD_VECTOR, VT)) &&
26908         N0.hasOneUse()) {
26909       if (N0.getOpcode() == ISD::SCALAR_TO_VECTOR && SplatIndex == 0)
26910         return DAG.getSplatBuildVector(VT, SDLoc(N), N0.getOperand(0));
26911 
26912       if (N0.getOpcode() == ISD::INSERT_VECTOR_ELT)
26913         if (auto *Idx = dyn_cast<ConstantSDNode>(N0.getOperand(2)))
26914           if (Idx->getAPIntValue() == SplatIndex)
26915             return DAG.getSplatBuildVector(VT, SDLoc(N), N0.getOperand(1));
26916 
26917       // Look through a bitcast if LE and splatting lane 0, through to a
26918       // scalar_to_vector or a build_vector.
26919       if (N0.getOpcode() == ISD::BITCAST && N0.getOperand(0).hasOneUse() &&
26920           SplatIndex == 0 && DAG.getDataLayout().isLittleEndian() &&
26921           (N0.getOperand(0).getOpcode() == ISD::SCALAR_TO_VECTOR ||
26922            N0.getOperand(0).getOpcode() == ISD::BUILD_VECTOR)) {
26923         EVT N00VT = N0.getOperand(0).getValueType();
26924         if (VT.getScalarSizeInBits() <= N00VT.getScalarSizeInBits() &&
26925             VT.isInteger() && N00VT.isInteger()) {
26926           EVT InVT =
26927               TLI.getTypeToTransformTo(*DAG.getContext(), VT.getScalarType());
26928           SDValue Op = DAG.getZExtOrTrunc(N0.getOperand(0).getOperand(0),
26929                                           SDLoc(N), InVT);
26930           return DAG.getSplatBuildVector(VT, SDLoc(N), Op);
26931         }
26932       }
26933     }
26934 
26935     // If this is a bit convert that changes the element type of the vector but
26936     // not the number of vector elements, look through it.  Be careful not to
26937     // look though conversions that change things like v4f32 to v2f64.
26938     SDNode *V = N0.getNode();
26939     if (V->getOpcode() == ISD::BITCAST) {
26940       SDValue ConvInput = V->getOperand(0);
26941       if (ConvInput.getValueType().isVector() &&
26942           ConvInput.getValueType().getVectorNumElements() == NumElts)
26943         V = ConvInput.getNode();
26944     }
26945 
26946     if (V->getOpcode() == ISD::BUILD_VECTOR) {
26947       assert(V->getNumOperands() == NumElts &&
26948              "BUILD_VECTOR has wrong number of operands");
26949       SDValue Base;
26950       bool AllSame = true;
26951       for (unsigned i = 0; i != NumElts; ++i) {
26952         if (!V->getOperand(i).isUndef()) {
26953           Base = V->getOperand(i);
26954           break;
26955         }
26956       }
26957       // Splat of <u, u, u, u>, return <u, u, u, u>
26958       if (!Base.getNode())
26959         return N0;
26960       for (unsigned i = 0; i != NumElts; ++i) {
26961         if (V->getOperand(i) != Base) {
26962           AllSame = false;
26963           break;
26964         }
26965       }
26966       // Splat of <x, x, x, x>, return <x, x, x, x>
26967       if (AllSame)
26968         return N0;
26969 
26970       // Canonicalize any other splat as a build_vector, but avoid defining any
26971       // undefined elements in the mask.
26972       SDValue Splatted = V->getOperand(SplatIndex);
26973       SmallVector<SDValue, 8> Ops(NumElts, Splatted);
26974       EVT EltVT = Splatted.getValueType();
26975 
26976       for (unsigned i = 0; i != NumElts; ++i) {
26977         if (SVN->getMaskElt(i) < 0)
26978           Ops[i] = DAG.getUNDEF(EltVT);
26979       }
26980 
26981       SDValue NewBV = DAG.getBuildVector(V->getValueType(0), SDLoc(N), Ops);
26982 
26983       // We may have jumped through bitcasts, so the type of the
26984       // BUILD_VECTOR may not match the type of the shuffle.
26985       if (V->getValueType(0) != VT)
26986         NewBV = DAG.getBitcast(VT, NewBV);
26987       return NewBV;
26988     }
26989   }
26990 
26991   // Simplify source operands based on shuffle mask.
26992   if (SimplifyDemandedVectorElts(SDValue(N, 0)))
26993     return SDValue(N, 0);
26994 
26995   // This is intentionally placed after demanded elements simplification because
26996   // it could eliminate knowledge of undef elements created by this shuffle.
26997   if (SDValue ShufOp = simplifyShuffleOfShuffle(SVN))
26998     return ShufOp;
26999 
27000   // Match shuffles that can be converted to any_vector_extend_in_reg.
27001   if (SDValue V =
27002           combineShuffleToAnyExtendVectorInreg(SVN, DAG, TLI, LegalOperations))
27003     return V;
27004 
27005   // Combine "truncate_vector_in_reg" style shuffles.
27006   if (SDValue V = combineTruncationShuffle(SVN, DAG))
27007     return V;
27008 
27009   if (N0.getOpcode() == ISD::CONCAT_VECTORS &&
27010       Level < AfterLegalizeVectorOps &&
27011       (N1.isUndef() ||
27012       (N1.getOpcode() == ISD::CONCAT_VECTORS &&
27013        N0.getOperand(0).getValueType() == N1.getOperand(0).getValueType()))) {
27014     if (SDValue V = partitionShuffleOfConcats(N, DAG))
27015       return V;
27016   }
27017 
27018   // A shuffle of a concat of the same narrow vector can be reduced to use
27019   // only low-half elements of a concat with undef:
27020   // shuf (concat X, X), undef, Mask --> shuf (concat X, undef), undef, Mask'
27021   if (N0.getOpcode() == ISD::CONCAT_VECTORS && N1.isUndef() &&
27022       N0.getNumOperands() == 2 &&
27023       N0.getOperand(0) == N0.getOperand(1)) {
27024     int HalfNumElts = (int)NumElts / 2;
27025     SmallVector<int, 8> NewMask;
27026     for (unsigned i = 0; i != NumElts; ++i) {
27027       int Idx = SVN->getMaskElt(i);
27028       if (Idx >= HalfNumElts) {
27029         assert(Idx < (int)NumElts && "Shuffle mask chooses undef op");
27030         Idx -= HalfNumElts;
27031       }
27032       NewMask.push_back(Idx);
27033     }
27034     if (TLI.isShuffleMaskLegal(NewMask, VT)) {
27035       SDValue UndefVec = DAG.getUNDEF(N0.getOperand(0).getValueType());
27036       SDValue NewCat = DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT,
27037                                    N0.getOperand(0), UndefVec);
27038       return DAG.getVectorShuffle(VT, SDLoc(N), NewCat, N1, NewMask);
27039     }
27040   }
27041 
27042   // See if we can replace a shuffle with an insert_subvector.
27043   // e.g. v2i32 into v8i32:
27044   // shuffle(lhs,concat(rhs0,rhs1,rhs2,rhs3),0,1,2,3,10,11,6,7).
27045   // --> insert_subvector(lhs,rhs1,4).
27046   if (Level < AfterLegalizeVectorOps && TLI.isTypeLegal(VT) &&
27047       TLI.isOperationLegalOrCustom(ISD::INSERT_SUBVECTOR, VT)) {
27048     auto ShuffleToInsert = [&](SDValue LHS, SDValue RHS, ArrayRef<int> Mask) {
27049       // Ensure RHS subvectors are legal.
27050       assert(RHS.getOpcode() == ISD::CONCAT_VECTORS && "Can't find subvectors");
27051       EVT SubVT = RHS.getOperand(0).getValueType();
27052       int NumSubVecs = RHS.getNumOperands();
27053       int NumSubElts = SubVT.getVectorNumElements();
27054       assert((NumElts % NumSubElts) == 0 && "Subvector mismatch");
27055       if (!TLI.isTypeLegal(SubVT))
27056         return SDValue();
27057 
27058       // Don't bother if we have an unary shuffle (matches undef + LHS elts).
27059       if (all_of(Mask, [NumElts](int M) { return M < (int)NumElts; }))
27060         return SDValue();
27061 
27062       // Search [NumSubElts] spans for RHS sequence.
27063       // TODO: Can we avoid nested loops to increase performance?
27064       SmallVector<int> InsertionMask(NumElts);
27065       for (int SubVec = 0; SubVec != NumSubVecs; ++SubVec) {
27066         for (int SubIdx = 0; SubIdx != (int)NumElts; SubIdx += NumSubElts) {
27067           // Reset mask to identity.
27068           std::iota(InsertionMask.begin(), InsertionMask.end(), 0);
27069 
27070           // Add subvector insertion.
27071           std::iota(InsertionMask.begin() + SubIdx,
27072                     InsertionMask.begin() + SubIdx + NumSubElts,
27073                     NumElts + (SubVec * NumSubElts));
27074 
27075           // See if the shuffle mask matches the reference insertion mask.
27076           bool MatchingShuffle = true;
27077           for (int i = 0; i != (int)NumElts; ++i) {
27078             int ExpectIdx = InsertionMask[i];
27079             int ActualIdx = Mask[i];
27080             if (0 <= ActualIdx && ExpectIdx != ActualIdx) {
27081               MatchingShuffle = false;
27082               break;
27083             }
27084           }
27085 
27086           if (MatchingShuffle)
27087             return DAG.getInsertSubvector(SDLoc(N), LHS, RHS.getOperand(SubVec),
27088                                           SubIdx);
27089         }
27090       }
27091       return SDValue();
27092     };
27093     ArrayRef<int> Mask = SVN->getMask();
27094     if (N1.getOpcode() == ISD::CONCAT_VECTORS)
27095       if (SDValue InsertN1 = ShuffleToInsert(N0, N1, Mask))
27096         return InsertN1;
27097     if (N0.getOpcode() == ISD::CONCAT_VECTORS) {
27098       SmallVector<int> CommuteMask(Mask);
27099       ShuffleVectorSDNode::commuteMask(CommuteMask);
27100       if (SDValue InsertN0 = ShuffleToInsert(N1, N0, CommuteMask))
27101         return InsertN0;
27102     }
27103   }
27104 
27105   // If we're not performing a select/blend shuffle, see if we can convert the
27106   // shuffle into a AND node, with all the out-of-lane elements are known zero.
27107   if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) {
27108     bool IsInLaneMask = true;
27109     ArrayRef<int> Mask = SVN->getMask();
27110     SmallVector<int, 16> ClearMask(NumElts, -1);
27111     APInt DemandedLHS = APInt::getZero(NumElts);
27112     APInt DemandedRHS = APInt::getZero(NumElts);
27113     for (int I = 0; I != (int)NumElts; ++I) {
27114       int M = Mask[I];
27115       if (M < 0)
27116         continue;
27117       ClearMask[I] = M == I ? I : (I + NumElts);
27118       IsInLaneMask &= (M == I) || (M == (int)(I + NumElts));
27119       if (M != I) {
27120         APInt &Demanded = M < (int)NumElts ? DemandedLHS : DemandedRHS;
27121         Demanded.setBit(M % NumElts);
27122       }
27123     }
27124     // TODO: Should we try to mask with N1 as well?
27125     if (!IsInLaneMask && (!DemandedLHS.isZero() || !DemandedRHS.isZero()) &&
27126         (DemandedLHS.isZero() || DAG.MaskedVectorIsZero(N0, DemandedLHS)) &&
27127         (DemandedRHS.isZero() || DAG.MaskedVectorIsZero(N1, DemandedRHS))) {
27128       SDLoc DL(N);
27129       EVT IntVT = VT.changeVectorElementTypeToInteger();
27130       EVT IntSVT = VT.getVectorElementType().changeTypeToInteger();
27131       // Transform the type to a legal type so that the buildvector constant
27132       // elements are not illegal. Make sure that the result is larger than the
27133       // original type, incase the value is split into two (eg i64->i32).
27134       if (!TLI.isTypeLegal(IntSVT) && LegalTypes)
27135         IntSVT = TLI.getTypeToTransformTo(*DAG.getContext(), IntSVT);
27136       if (IntSVT.getSizeInBits() >= IntVT.getScalarSizeInBits()) {
27137         SDValue ZeroElt = DAG.getConstant(0, DL, IntSVT);
27138         SDValue AllOnesElt = DAG.getAllOnesConstant(DL, IntSVT);
27139         SmallVector<SDValue, 16> AndMask(NumElts, DAG.getUNDEF(IntSVT));
27140         for (int I = 0; I != (int)NumElts; ++I)
27141           if (0 <= Mask[I])
27142             AndMask[I] = Mask[I] == I ? AllOnesElt : ZeroElt;
27143 
27144         // See if a clear mask is legal instead of going via
27145         // XformToShuffleWithZero which loses UNDEF mask elements.
27146         if (TLI.isVectorClearMaskLegal(ClearMask, IntVT))
27147           return DAG.getBitcast(
27148               VT, DAG.getVectorShuffle(IntVT, DL, DAG.getBitcast(IntVT, N0),
27149                                       DAG.getConstant(0, DL, IntVT), ClearMask));
27150 
27151         if (TLI.isOperationLegalOrCustom(ISD::AND, IntVT))
27152           return DAG.getBitcast(
27153               VT, DAG.getNode(ISD::AND, DL, IntVT, DAG.getBitcast(IntVT, N0),
27154                               DAG.getBuildVector(IntVT, DL, AndMask)));
27155       }
27156     }
27157   }
27158 
27159   // Attempt to combine a shuffle of 2 inputs of 'scalar sources' -
27160   // BUILD_VECTOR or SCALAR_TO_VECTOR into a single BUILD_VECTOR.
27161   if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT))
27162     if (SDValue Res = combineShuffleOfScalars(SVN, DAG, TLI))
27163       return Res;
27164 
27165   // If this shuffle only has a single input that is a bitcasted shuffle,
27166   // attempt to merge the 2 shuffles and suitably bitcast the inputs/output
27167   // back to their original types.
27168   if (N0.getOpcode() == ISD::BITCAST && N0.hasOneUse() &&
27169       N1.isUndef() && Level < AfterLegalizeVectorOps &&
27170       TLI.isTypeLegal(VT)) {
27171 
27172     SDValue BC0 = peekThroughOneUseBitcasts(N0);
27173     if (BC0.getOpcode() == ISD::VECTOR_SHUFFLE && BC0.hasOneUse()) {
27174       EVT SVT = VT.getScalarType();
27175       EVT InnerVT = BC0->getValueType(0);
27176       EVT InnerSVT = InnerVT.getScalarType();
27177 
27178       // Determine which shuffle works with the smaller scalar type.
27179       EVT ScaleVT = SVT.bitsLT(InnerSVT) ? VT : InnerVT;
27180       EVT ScaleSVT = ScaleVT.getScalarType();
27181 
27182       if (TLI.isTypeLegal(ScaleVT) &&
27183           0 == (InnerSVT.getSizeInBits() % ScaleSVT.getSizeInBits()) &&
27184           0 == (SVT.getSizeInBits() % ScaleSVT.getSizeInBits())) {
27185         int InnerScale = InnerSVT.getSizeInBits() / ScaleSVT.getSizeInBits();
27186         int OuterScale = SVT.getSizeInBits() / ScaleSVT.getSizeInBits();
27187 
27188         // Scale the shuffle masks to the smaller scalar type.
27189         ShuffleVectorSDNode *InnerSVN = cast<ShuffleVectorSDNode>(BC0);
27190         SmallVector<int, 8> InnerMask;
27191         SmallVector<int, 8> OuterMask;
27192         narrowShuffleMaskElts(InnerScale, InnerSVN->getMask(), InnerMask);
27193         narrowShuffleMaskElts(OuterScale, SVN->getMask(), OuterMask);
27194 
27195         // Merge the shuffle masks.
27196         SmallVector<int, 8> NewMask;
27197         for (int M : OuterMask)
27198           NewMask.push_back(M < 0 ? -1 : InnerMask[M]);
27199 
27200         // Test for shuffle mask legality over both commutations.
27201         SDValue SV0 = BC0->getOperand(0);
27202         SDValue SV1 = BC0->getOperand(1);
27203         bool LegalMask = TLI.isShuffleMaskLegal(NewMask, ScaleVT);
27204         if (!LegalMask) {
27205           std::swap(SV0, SV1);
27206           ShuffleVectorSDNode::commuteMask(NewMask);
27207           LegalMask = TLI.isShuffleMaskLegal(NewMask, ScaleVT);
27208         }
27209 
27210         if (LegalMask) {
27211           SV0 = DAG.getBitcast(ScaleVT, SV0);
27212           SV1 = DAG.getBitcast(ScaleVT, SV1);
27213           return DAG.getBitcast(
27214               VT, DAG.getVectorShuffle(ScaleVT, SDLoc(N), SV0, SV1, NewMask));
27215         }
27216       }
27217     }
27218   }
27219 
27220   // Match shuffles of bitcasts, so long as the mask can be treated as the
27221   // larger type.
27222   if (SDValue V = combineShuffleOfBitcast(SVN, DAG, TLI, LegalOperations))
27223     return V;
27224 
27225   // Compute the combined shuffle mask for a shuffle with SV0 as the first
27226   // operand, and SV1 as the second operand.
27227   // i.e. Merge SVN(OtherSVN, N1) -> shuffle(SV0, SV1, Mask) iff Commute = false
27228   //      Merge SVN(N1, OtherSVN) -> shuffle(SV0, SV1, Mask') iff Commute = true
27229   auto MergeInnerShuffle =
27230       [NumElts, &VT](bool Commute, ShuffleVectorSDNode *SVN,
27231                      ShuffleVectorSDNode *OtherSVN, SDValue N1,
27232                      const TargetLowering &TLI, SDValue &SV0, SDValue &SV1,
27233                      SmallVectorImpl<int> &Mask) -> bool {
27234     // Don't try to fold splats; they're likely to simplify somehow, or they
27235     // might be free.
27236     if (OtherSVN->isSplat())
27237       return false;
27238 
27239     SV0 = SV1 = SDValue();
27240     Mask.clear();
27241 
27242     for (unsigned i = 0; i != NumElts; ++i) {
27243       int Idx = SVN->getMaskElt(i);
27244       if (Idx < 0) {
27245         // Propagate Undef.
27246         Mask.push_back(Idx);
27247         continue;
27248       }
27249 
27250       if (Commute)
27251         Idx = (Idx < (int)NumElts) ? (Idx + NumElts) : (Idx - NumElts);
27252 
27253       SDValue CurrentVec;
27254       if (Idx < (int)NumElts) {
27255         // This shuffle index refers to the inner shuffle N0. Lookup the inner
27256         // shuffle mask to identify which vector is actually referenced.
27257         Idx = OtherSVN->getMaskElt(Idx);
27258         if (Idx < 0) {
27259           // Propagate Undef.
27260           Mask.push_back(Idx);
27261           continue;
27262         }
27263         CurrentVec = (Idx < (int)NumElts) ? OtherSVN->getOperand(0)
27264                                           : OtherSVN->getOperand(1);
27265       } else {
27266         // This shuffle index references an element within N1.
27267         CurrentVec = N1;
27268       }
27269 
27270       // Simple case where 'CurrentVec' is UNDEF.
27271       if (CurrentVec.isUndef()) {
27272         Mask.push_back(-1);
27273         continue;
27274       }
27275 
27276       // Canonicalize the shuffle index. We don't know yet if CurrentVec
27277       // will be the first or second operand of the combined shuffle.
27278       Idx = Idx % NumElts;
27279       if (!SV0.getNode() || SV0 == CurrentVec) {
27280         // Ok. CurrentVec is the left hand side.
27281         // Update the mask accordingly.
27282         SV0 = CurrentVec;
27283         Mask.push_back(Idx);
27284         continue;
27285       }
27286       if (!SV1.getNode() || SV1 == CurrentVec) {
27287         // Ok. CurrentVec is the right hand side.
27288         // Update the mask accordingly.
27289         SV1 = CurrentVec;
27290         Mask.push_back(Idx + NumElts);
27291         continue;
27292       }
27293 
27294       // Last chance - see if the vector is another shuffle and if it
27295       // uses one of the existing candidate shuffle ops.
27296       if (auto *CurrentSVN = dyn_cast<ShuffleVectorSDNode>(CurrentVec)) {
27297         int InnerIdx = CurrentSVN->getMaskElt(Idx);
27298         if (InnerIdx < 0) {
27299           Mask.push_back(-1);
27300           continue;
27301         }
27302         SDValue InnerVec = (InnerIdx < (int)NumElts)
27303                                ? CurrentSVN->getOperand(0)
27304                                : CurrentSVN->getOperand(1);
27305         if (InnerVec.isUndef()) {
27306           Mask.push_back(-1);
27307           continue;
27308         }
27309         InnerIdx %= NumElts;
27310         if (InnerVec == SV0) {
27311           Mask.push_back(InnerIdx);
27312           continue;
27313         }
27314         if (InnerVec == SV1) {
27315           Mask.push_back(InnerIdx + NumElts);
27316           continue;
27317         }
27318       }
27319 
27320       // Bail out if we cannot convert the shuffle pair into a single shuffle.
27321       return false;
27322     }
27323 
27324     if (llvm::all_of(Mask, [](int M) { return M < 0; }))
27325       return true;
27326 
27327     // Avoid introducing shuffles with illegal mask.
27328     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2)
27329     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2)
27330     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2)
27331     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, A, M2)
27332     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, A, M2)
27333     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, B, M2)
27334     if (TLI.isShuffleMaskLegal(Mask, VT))
27335       return true;
27336 
27337     std::swap(SV0, SV1);
27338     ShuffleVectorSDNode::commuteMask(Mask);
27339     return TLI.isShuffleMaskLegal(Mask, VT);
27340   };
27341 
27342   if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) {
27343     // Canonicalize shuffles according to rules:
27344     //  shuffle(A, shuffle(A, B)) -> shuffle(shuffle(A,B), A)
27345     //  shuffle(B, shuffle(A, B)) -> shuffle(shuffle(A,B), B)
27346     //  shuffle(B, shuffle(A, Undef)) -> shuffle(shuffle(A, Undef), B)
27347     if (N1.getOpcode() == ISD::VECTOR_SHUFFLE &&
27348         N0.getOpcode() != ISD::VECTOR_SHUFFLE) {
27349       // The incoming shuffle must be of the same type as the result of the
27350       // current shuffle.
27351       assert(N1->getOperand(0).getValueType() == VT &&
27352              "Shuffle types don't match");
27353 
27354       SDValue SV0 = N1->getOperand(0);
27355       SDValue SV1 = N1->getOperand(1);
27356       bool HasSameOp0 = N0 == SV0;
27357       bool IsSV1Undef = SV1.isUndef();
27358       if (HasSameOp0 || IsSV1Undef || N0 == SV1)
27359         // Commute the operands of this shuffle so merging below will trigger.
27360         return DAG.getCommutedVectorShuffle(*SVN);
27361     }
27362 
27363     // Canonicalize splat shuffles to the RHS to improve merging below.
27364     //  shuffle(splat(A,u), shuffle(C,D)) -> shuffle'(shuffle(C,D), splat(A,u))
27365     if (N0.getOpcode() == ISD::VECTOR_SHUFFLE &&
27366         N1.getOpcode() == ISD::VECTOR_SHUFFLE &&
27367         cast<ShuffleVectorSDNode>(N0)->isSplat() &&
27368         !cast<ShuffleVectorSDNode>(N1)->isSplat()) {
27369       return DAG.getCommutedVectorShuffle(*SVN);
27370     }
27371 
27372     // Try to fold according to rules:
27373     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2)
27374     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2)
27375     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2)
27376     // Don't try to fold shuffles with illegal type.
27377     // Only fold if this shuffle is the only user of the other shuffle.
27378     // Try matching shuffle(C,shuffle(A,B)) commutted patterns as well.
27379     for (int i = 0; i != 2; ++i) {
27380       if (N->getOperand(i).getOpcode() == ISD::VECTOR_SHUFFLE &&
27381           N->isOnlyUserOf(N->getOperand(i).getNode())) {
27382         // The incoming shuffle must be of the same type as the result of the
27383         // current shuffle.
27384         auto *OtherSV = cast<ShuffleVectorSDNode>(N->getOperand(i));
27385         assert(OtherSV->getOperand(0).getValueType() == VT &&
27386                "Shuffle types don't match");
27387 
27388         SDValue SV0, SV1;
27389         SmallVector<int, 4> Mask;
27390         if (MergeInnerShuffle(i != 0, SVN, OtherSV, N->getOperand(1 - i), TLI,
27391                               SV0, SV1, Mask)) {
27392           // Check if all indices in Mask are Undef. In case, propagate Undef.
27393           if (llvm::all_of(Mask, [](int M) { return M < 0; }))
27394             return DAG.getUNDEF(VT);
27395 
27396           return DAG.getVectorShuffle(VT, SDLoc(N),
27397                                       SV0 ? SV0 : DAG.getUNDEF(VT),
27398                                       SV1 ? SV1 : DAG.getUNDEF(VT), Mask);
27399         }
27400       }
27401     }
27402 
27403     // Merge shuffles through binops if we are able to merge it with at least
27404     // one other shuffles.
27405     // shuffle(bop(shuffle(x,y),shuffle(z,w)),undef)
27406     // shuffle(bop(shuffle(x,y),shuffle(z,w)),bop(shuffle(a,b),shuffle(c,d)))
27407     unsigned SrcOpcode = N0.getOpcode();
27408     if (TLI.isBinOp(SrcOpcode) && N->isOnlyUserOf(N0.getNode()) &&
27409         (N1.isUndef() ||
27410          (SrcOpcode == N1.getOpcode() && N->isOnlyUserOf(N1.getNode())))) {
27411       // Get binop source ops, or just pass on the undef.
27412       SDValue Op00 = N0.getOperand(0);
27413       SDValue Op01 = N0.getOperand(1);
27414       SDValue Op10 = N1.isUndef() ? N1 : N1.getOperand(0);
27415       SDValue Op11 = N1.isUndef() ? N1 : N1.getOperand(1);
27416       // TODO: We might be able to relax the VT check but we don't currently
27417       // have any isBinOp() that has different result/ops VTs so play safe until
27418       // we have test coverage.
27419       if (Op00.getValueType() == VT && Op10.getValueType() == VT &&
27420           Op01.getValueType() == VT && Op11.getValueType() == VT &&
27421           (Op00.getOpcode() == ISD::VECTOR_SHUFFLE ||
27422            Op10.getOpcode() == ISD::VECTOR_SHUFFLE ||
27423            Op01.getOpcode() == ISD::VECTOR_SHUFFLE ||
27424            Op11.getOpcode() == ISD::VECTOR_SHUFFLE)) {
27425         auto CanMergeInnerShuffle = [&](SDValue &SV0, SDValue &SV1,
27426                                         SmallVectorImpl<int> &Mask, bool LeftOp,
27427                                         bool Commute) {
27428           SDValue InnerN = Commute ? N1 : N0;
27429           SDValue Op0 = LeftOp ? Op00 : Op01;
27430           SDValue Op1 = LeftOp ? Op10 : Op11;
27431           if (Commute)
27432             std::swap(Op0, Op1);
27433           // Only accept the merged shuffle if we don't introduce undef elements,
27434           // or the inner shuffle already contained undef elements.
27435           auto *SVN0 = dyn_cast<ShuffleVectorSDNode>(Op0);
27436           return SVN0 && InnerN->isOnlyUserOf(SVN0) &&
27437                  MergeInnerShuffle(Commute, SVN, SVN0, Op1, TLI, SV0, SV1,
27438                                    Mask) &&
27439                  (llvm::any_of(SVN0->getMask(), [](int M) { return M < 0; }) ||
27440                   llvm::none_of(Mask, [](int M) { return M < 0; }));
27441         };
27442 
27443         // Ensure we don't increase the number of shuffles - we must merge a
27444         // shuffle from at least one of the LHS and RHS ops.
27445         bool MergedLeft = false;
27446         SDValue LeftSV0, LeftSV1;
27447         SmallVector<int, 4> LeftMask;
27448         if (CanMergeInnerShuffle(LeftSV0, LeftSV1, LeftMask, true, false) ||
27449             CanMergeInnerShuffle(LeftSV0, LeftSV1, LeftMask, true, true)) {
27450           MergedLeft = true;
27451         } else {
27452           LeftMask.assign(SVN->getMask().begin(), SVN->getMask().end());
27453           LeftSV0 = Op00, LeftSV1 = Op10;
27454         }
27455 
27456         bool MergedRight = false;
27457         SDValue RightSV0, RightSV1;
27458         SmallVector<int, 4> RightMask;
27459         if (CanMergeInnerShuffle(RightSV0, RightSV1, RightMask, false, false) ||
27460             CanMergeInnerShuffle(RightSV0, RightSV1, RightMask, false, true)) {
27461           MergedRight = true;
27462         } else {
27463           RightMask.assign(SVN->getMask().begin(), SVN->getMask().end());
27464           RightSV0 = Op01, RightSV1 = Op11;
27465         }
27466 
27467         if (MergedLeft || MergedRight) {
27468           SDLoc DL(N);
27469           SDValue LHS = DAG.getVectorShuffle(
27470               VT, DL, LeftSV0 ? LeftSV0 : DAG.getUNDEF(VT),
27471               LeftSV1 ? LeftSV1 : DAG.getUNDEF(VT), LeftMask);
27472           SDValue RHS = DAG.getVectorShuffle(
27473               VT, DL, RightSV0 ? RightSV0 : DAG.getUNDEF(VT),
27474               RightSV1 ? RightSV1 : DAG.getUNDEF(VT), RightMask);
27475           return DAG.getNode(SrcOpcode, DL, VT, LHS, RHS);
27476         }
27477       }
27478     }
27479   }
27480 
27481   if (SDValue V = foldShuffleOfConcatUndefs(SVN, DAG))
27482     return V;
27483 
27484   // Match shuffles that can be converted to ISD::ZERO_EXTEND_VECTOR_INREG.
27485   // Perform this really late, because it could eliminate knowledge
27486   // of undef elements created by this shuffle.
27487   if (Level < AfterLegalizeTypes)
27488     if (SDValue V = combineShuffleToZeroExtendVectorInReg(SVN, DAG, TLI,
27489                                                           LegalOperations))
27490       return V;
27491 
27492   return SDValue();
27493 }
27494 
visitSCALAR_TO_VECTOR(SDNode * N)27495 SDValue DAGCombiner::visitSCALAR_TO_VECTOR(SDNode *N) {
27496   EVT VT = N->getValueType(0);
27497   if (!VT.isFixedLengthVector())
27498     return SDValue();
27499 
27500   // Try to convert a scalar binop with an extracted vector element to a vector
27501   // binop. This is intended to reduce potentially expensive register moves.
27502   // TODO: Check if both operands are extracted.
27503   // TODO: How to prefer scalar/vector ops with multiple uses of the extact?
27504   // TODO: Generalize this, so it can be called from visitINSERT_VECTOR_ELT().
27505   SDValue Scalar = N->getOperand(0);
27506   unsigned Opcode = Scalar.getOpcode();
27507   EVT VecEltVT = VT.getScalarType();
27508   if (Scalar.hasOneUse() && Scalar->getNumValues() == 1 &&
27509       TLI.isBinOp(Opcode) && Scalar.getValueType() == VecEltVT &&
27510       Scalar.getOperand(0).getValueType() == VecEltVT &&
27511       Scalar.getOperand(1).getValueType() == VecEltVT &&
27512       Scalar->isOnlyUserOf(Scalar.getOperand(0).getNode()) &&
27513       Scalar->isOnlyUserOf(Scalar.getOperand(1).getNode()) &&
27514       DAG.isSafeToSpeculativelyExecute(Opcode) && hasOperation(Opcode, VT)) {
27515     // Match an extract element and get a shuffle mask equivalent.
27516     SmallVector<int, 8> ShufMask(VT.getVectorNumElements(), -1);
27517 
27518     for (int i : {0, 1}) {
27519       // s2v (bo (extelt V, Idx), C) --> shuffle (bo V, C'), {Idx, -1, -1...}
27520       // s2v (bo C, (extelt V, Idx)) --> shuffle (bo C', V), {Idx, -1, -1...}
27521       SDValue EE = Scalar.getOperand(i);
27522       auto *C = dyn_cast<ConstantSDNode>(Scalar.getOperand(i ? 0 : 1));
27523       if (C && EE.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
27524           EE.getOperand(0).getValueType() == VT &&
27525           isa<ConstantSDNode>(EE.getOperand(1))) {
27526         // Mask = {ExtractIndex, undef, undef....}
27527         ShufMask[0] = EE.getConstantOperandVal(1);
27528         // Make sure the shuffle is legal if we are crossing lanes.
27529         if (TLI.isShuffleMaskLegal(ShufMask, VT)) {
27530           SDLoc DL(N);
27531           SDValue V[] = {EE.getOperand(0),
27532                          DAG.getConstant(C->getAPIntValue(), DL, VT)};
27533           SDValue VecBO = DAG.getNode(Opcode, DL, VT, V[i], V[1 - i]);
27534           return DAG.getVectorShuffle(VT, DL, VecBO, DAG.getUNDEF(VT),
27535                                       ShufMask);
27536         }
27537       }
27538     }
27539   }
27540 
27541   // Replace a SCALAR_TO_VECTOR(EXTRACT_VECTOR_ELT(V,C0)) pattern
27542   // with a VECTOR_SHUFFLE and possible truncate.
27543   if (Opcode != ISD::EXTRACT_VECTOR_ELT ||
27544       !Scalar.getOperand(0).getValueType().isFixedLengthVector())
27545     return SDValue();
27546 
27547   // If we have an implicit truncate, truncate here if it is legal.
27548   if (VecEltVT != Scalar.getValueType() &&
27549       Scalar.getValueType().isScalarInteger() && isTypeLegal(VecEltVT)) {
27550     SDValue Val = DAG.getNode(ISD::TRUNCATE, SDLoc(Scalar), VecEltVT, Scalar);
27551     return DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), VT, Val);
27552   }
27553 
27554   auto *ExtIndexC = dyn_cast<ConstantSDNode>(Scalar.getOperand(1));
27555   if (!ExtIndexC)
27556     return SDValue();
27557 
27558   SDValue SrcVec = Scalar.getOperand(0);
27559   EVT SrcVT = SrcVec.getValueType();
27560   unsigned SrcNumElts = SrcVT.getVectorNumElements();
27561   unsigned VTNumElts = VT.getVectorNumElements();
27562   if (VecEltVT == SrcVT.getScalarType() && VTNumElts <= SrcNumElts) {
27563     // Create a shuffle equivalent for scalar-to-vector: {ExtIndex, -1, -1, ...}
27564     SmallVector<int, 8> Mask(SrcNumElts, -1);
27565     Mask[0] = ExtIndexC->getZExtValue();
27566     SDValue LegalShuffle = TLI.buildLegalVectorShuffle(
27567         SrcVT, SDLoc(N), SrcVec, DAG.getUNDEF(SrcVT), Mask, DAG);
27568     if (!LegalShuffle)
27569       return SDValue();
27570 
27571     // If the initial vector is the same size, the shuffle is the result.
27572     if (VT == SrcVT)
27573       return LegalShuffle;
27574 
27575     // If not, shorten the shuffled vector.
27576     if (VTNumElts != SrcNumElts) {
27577       SDValue ZeroIdx = DAG.getVectorIdxConstant(0, SDLoc(N));
27578       EVT SubVT = EVT::getVectorVT(*DAG.getContext(),
27579                                    SrcVT.getVectorElementType(), VTNumElts);
27580       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), SubVT, LegalShuffle,
27581                          ZeroIdx);
27582     }
27583   }
27584 
27585   return SDValue();
27586 }
27587 
visitINSERT_SUBVECTOR(SDNode * N)27588 SDValue DAGCombiner::visitINSERT_SUBVECTOR(SDNode *N) {
27589   EVT VT = N->getValueType(0);
27590   SDValue N0 = N->getOperand(0);
27591   SDValue N1 = N->getOperand(1);
27592   SDValue N2 = N->getOperand(2);
27593   uint64_t InsIdx = N->getConstantOperandVal(2);
27594 
27595   // If inserting an UNDEF, just return the original vector.
27596   if (N1.isUndef())
27597     return N0;
27598 
27599   // If this is an insert of an extracted vector into an undef vector, we can
27600   // just use the input to the extract if the types match, and can simplify
27601   // in some cases even if they don't.
27602   if (N0.isUndef() && N1.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
27603       N1.getOperand(1) == N2) {
27604     EVT SrcVT = N1.getOperand(0).getValueType();
27605     if (SrcVT == VT)
27606       return N1.getOperand(0);
27607     // TODO: To remove the zero check, need to adjust the offset to
27608     // a multiple of the new src type.
27609     if (isNullConstant(N2)) {
27610       if (VT.knownBitsGE(SrcVT) &&
27611           !(VT.isFixedLengthVector() && SrcVT.isScalableVector()))
27612         return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N),
27613                            VT, N0, N1.getOperand(0), N2);
27614       else if (VT.knownBitsLE(SrcVT) &&
27615                !(VT.isScalableVector() && SrcVT.isFixedLengthVector()))
27616         return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N),
27617                            VT, N1.getOperand(0), N2);
27618     }
27619   }
27620 
27621   // Handle case where we've ended up inserting back into the source vector
27622   // we extracted the subvector from.
27623   // insert_subvector(N0, extract_subvector(N0, N2), N2) --> N0
27624   if (N1.getOpcode() == ISD::EXTRACT_SUBVECTOR && N1.getOperand(0) == N0 &&
27625       N1.getOperand(1) == N2)
27626     return N0;
27627 
27628   // Simplify scalar inserts into an undef vector:
27629   // insert_subvector undef, (splat X), N2 -> splat X
27630   if (N0.isUndef() && N1.getOpcode() == ISD::SPLAT_VECTOR)
27631     if (DAG.isConstantValueOfAnyType(N1.getOperand(0)) || N1.hasOneUse())
27632       return DAG.getNode(ISD::SPLAT_VECTOR, SDLoc(N), VT, N1.getOperand(0));
27633 
27634   // insert_subvector (splat X), (splat X), N2 -> splat X
27635   if (N0.getOpcode() == ISD::SPLAT_VECTOR && N0.getOpcode() == N1.getOpcode() &&
27636       N0.getOperand(0) == N1.getOperand(0))
27637     return N0;
27638 
27639   // If we are inserting a bitcast value into an undef, with the same
27640   // number of elements, just use the bitcast input of the extract.
27641   // i.e. INSERT_SUBVECTOR UNDEF (BITCAST N1) N2 ->
27642   //        BITCAST (INSERT_SUBVECTOR UNDEF N1 N2)
27643   if (N0.isUndef() && N1.getOpcode() == ISD::BITCAST &&
27644       N1.getOperand(0).getOpcode() == ISD::EXTRACT_SUBVECTOR &&
27645       N1.getOperand(0).getOperand(1) == N2 &&
27646       N1.getOperand(0).getOperand(0).getValueType().getVectorElementCount() ==
27647           VT.getVectorElementCount() &&
27648       N1.getOperand(0).getOperand(0).getValueType().getSizeInBits() ==
27649           VT.getSizeInBits()) {
27650     return DAG.getBitcast(VT, N1.getOperand(0).getOperand(0));
27651   }
27652 
27653   // If both N1 and N2 are bitcast values on which insert_subvector
27654   // would makes sense, pull the bitcast through.
27655   // i.e. INSERT_SUBVECTOR (BITCAST N0) (BITCAST N1) N2 ->
27656   //        BITCAST (INSERT_SUBVECTOR N0 N1 N2)
27657   if (N0.getOpcode() == ISD::BITCAST && N1.getOpcode() == ISD::BITCAST) {
27658     SDValue CN0 = N0.getOperand(0);
27659     SDValue CN1 = N1.getOperand(0);
27660     EVT CN0VT = CN0.getValueType();
27661     EVT CN1VT = CN1.getValueType();
27662     if (CN0VT.isVector() && CN1VT.isVector() &&
27663         CN0VT.getVectorElementType() == CN1VT.getVectorElementType() &&
27664         CN0VT.getVectorElementCount() == VT.getVectorElementCount()) {
27665       SDValue NewINSERT = DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N),
27666                                       CN0.getValueType(), CN0, CN1, N2);
27667       return DAG.getBitcast(VT, NewINSERT);
27668     }
27669   }
27670 
27671   // Combine INSERT_SUBVECTORs where we are inserting to the same index.
27672   // INSERT_SUBVECTOR( INSERT_SUBVECTOR( Vec, SubOld, Idx ), SubNew, Idx )
27673   // --> INSERT_SUBVECTOR( Vec, SubNew, Idx )
27674   if (N0.getOpcode() == ISD::INSERT_SUBVECTOR &&
27675       N0.getOperand(1).getValueType() == N1.getValueType() &&
27676       N0.getOperand(2) == N2)
27677     return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT, N0.getOperand(0),
27678                        N1, N2);
27679 
27680   // Eliminate an intermediate insert into an undef vector:
27681   // insert_subvector undef, (insert_subvector undef, X, 0), 0 -->
27682   // insert_subvector undef, X, 0
27683   if (N0.isUndef() && N1.getOpcode() == ISD::INSERT_SUBVECTOR &&
27684       N1.getOperand(0).isUndef() && isNullConstant(N1.getOperand(2)) &&
27685       isNullConstant(N2))
27686     return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT, N0,
27687                        N1.getOperand(1), N2);
27688 
27689   // Push subvector bitcasts to the output, adjusting the index as we go.
27690   // insert_subvector(bitcast(v), bitcast(s), c1)
27691   // -> bitcast(insert_subvector(v, s, c2))
27692   if ((N0.isUndef() || N0.getOpcode() == ISD::BITCAST) &&
27693       N1.getOpcode() == ISD::BITCAST) {
27694     SDValue N0Src = peekThroughBitcasts(N0);
27695     SDValue N1Src = peekThroughBitcasts(N1);
27696     EVT N0SrcSVT = N0Src.getValueType().getScalarType();
27697     EVT N1SrcSVT = N1Src.getValueType().getScalarType();
27698     if ((N0.isUndef() || N0SrcSVT == N1SrcSVT) &&
27699         N0Src.getValueType().isVector() && N1Src.getValueType().isVector()) {
27700       EVT NewVT;
27701       SDLoc DL(N);
27702       SDValue NewIdx;
27703       LLVMContext &Ctx = *DAG.getContext();
27704       ElementCount NumElts = VT.getVectorElementCount();
27705       unsigned EltSizeInBits = VT.getScalarSizeInBits();
27706       if ((EltSizeInBits % N1SrcSVT.getSizeInBits()) == 0) {
27707         unsigned Scale = EltSizeInBits / N1SrcSVT.getSizeInBits();
27708         NewVT = EVT::getVectorVT(Ctx, N1SrcSVT, NumElts * Scale);
27709         NewIdx = DAG.getVectorIdxConstant(InsIdx * Scale, DL);
27710       } else if ((N1SrcSVT.getSizeInBits() % EltSizeInBits) == 0) {
27711         unsigned Scale = N1SrcSVT.getSizeInBits() / EltSizeInBits;
27712         if (NumElts.isKnownMultipleOf(Scale) && (InsIdx % Scale) == 0) {
27713           NewVT = EVT::getVectorVT(Ctx, N1SrcSVT,
27714                                    NumElts.divideCoefficientBy(Scale));
27715           NewIdx = DAG.getVectorIdxConstant(InsIdx / Scale, DL);
27716         }
27717       }
27718       if (NewIdx && hasOperation(ISD::INSERT_SUBVECTOR, NewVT)) {
27719         SDValue Res = DAG.getBitcast(NewVT, N0Src);
27720         Res = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, NewVT, Res, N1Src, NewIdx);
27721         return DAG.getBitcast(VT, Res);
27722       }
27723     }
27724   }
27725 
27726   // Canonicalize insert_subvector dag nodes.
27727   // Example:
27728   // (insert_subvector (insert_subvector A, Idx0), Idx1)
27729   // -> (insert_subvector (insert_subvector A, Idx1), Idx0)
27730   if (N0.getOpcode() == ISD::INSERT_SUBVECTOR && N0.hasOneUse() &&
27731       N1.getValueType() == N0.getOperand(1).getValueType()) {
27732     unsigned OtherIdx = N0.getConstantOperandVal(2);
27733     if (InsIdx < OtherIdx) {
27734       // Swap nodes.
27735       SDValue NewOp = DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT,
27736                                   N0.getOperand(0), N1, N2);
27737       AddToWorklist(NewOp.getNode());
27738       return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N0.getNode()),
27739                          VT, NewOp, N0.getOperand(1), N0.getOperand(2));
27740     }
27741   }
27742 
27743   // If the input vector is a concatenation, and the insert replaces
27744   // one of the pieces, we can optimize into a single concat_vectors.
27745   if (N0.getOpcode() == ISD::CONCAT_VECTORS && N0.hasOneUse() &&
27746       N0.getOperand(0).getValueType() == N1.getValueType() &&
27747       N0.getOperand(0).getValueType().isScalableVector() ==
27748           N1.getValueType().isScalableVector()) {
27749     unsigned Factor = N1.getValueType().getVectorMinNumElements();
27750     SmallVector<SDValue, 8> Ops(N0->ops());
27751     Ops[InsIdx / Factor] = N1;
27752     return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops);
27753   }
27754 
27755   // Simplify source operands based on insertion.
27756   if (SimplifyDemandedVectorElts(SDValue(N, 0)))
27757     return SDValue(N, 0);
27758 
27759   return SDValue();
27760 }
27761 
visitFP_TO_FP16(SDNode * N)27762 SDValue DAGCombiner::visitFP_TO_FP16(SDNode *N) {
27763   SDValue N0 = N->getOperand(0);
27764 
27765   // fold (fp_to_fp16 (fp16_to_fp op)) -> op
27766   if (N0->getOpcode() == ISD::FP16_TO_FP)
27767     return N0->getOperand(0);
27768 
27769   return SDValue();
27770 }
27771 
visitFP16_TO_FP(SDNode * N)27772 SDValue DAGCombiner::visitFP16_TO_FP(SDNode *N) {
27773   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
27774   auto Op = N->getOpcode();
27775   assert((Op == ISD::FP16_TO_FP || Op == ISD::BF16_TO_FP) &&
27776          "opcode should be FP16_TO_FP or BF16_TO_FP.");
27777   SDValue N0 = N->getOperand(0);
27778 
27779   // fold fp16_to_fp(op & 0xffff) -> fp16_to_fp(op) or
27780   // fold bf16_to_fp(op & 0xffff) -> bf16_to_fp(op)
27781   if (!TLI.shouldKeepZExtForFP16Conv() && N0->getOpcode() == ISD::AND) {
27782     ConstantSDNode *AndConst = getAsNonOpaqueConstant(N0.getOperand(1));
27783     if (AndConst && AndConst->getAPIntValue() == 0xffff) {
27784       return DAG.getNode(Op, SDLoc(N), N->getValueType(0), N0.getOperand(0));
27785     }
27786   }
27787 
27788   if (SDValue CastEliminated = eliminateFPCastPair(N))
27789     return CastEliminated;
27790 
27791   // Sometimes constants manage to survive very late in the pipeline, e.g.,
27792   // because they are wrapped inside the <1 x f16> type. Try one last time to
27793   // get rid of them.
27794   SDValue Folded = DAG.FoldConstantArithmetic(N->getOpcode(), SDLoc(N),
27795                                               N->getValueType(0), {N0});
27796   return Folded;
27797 }
27798 
visitFP_TO_BF16(SDNode * N)27799 SDValue DAGCombiner::visitFP_TO_BF16(SDNode *N) {
27800   SDValue N0 = N->getOperand(0);
27801 
27802   // fold (fp_to_bf16 (bf16_to_fp op)) -> op
27803   if (N0->getOpcode() == ISD::BF16_TO_FP)
27804     return N0->getOperand(0);
27805 
27806   return SDValue();
27807 }
27808 
visitBF16_TO_FP(SDNode * N)27809 SDValue DAGCombiner::visitBF16_TO_FP(SDNode *N) {
27810   // fold bf16_to_fp(op & 0xffff) -> bf16_to_fp(op)
27811   return visitFP16_TO_FP(N);
27812 }
27813 
visitVECREDUCE(SDNode * N)27814 SDValue DAGCombiner::visitVECREDUCE(SDNode *N) {
27815   SDValue N0 = N->getOperand(0);
27816   EVT VT = N0.getValueType();
27817   unsigned Opcode = N->getOpcode();
27818 
27819   // VECREDUCE over 1-element vector is just an extract.
27820   if (VT.getVectorElementCount().isScalar()) {
27821     SDLoc dl(N);
27822     SDValue Res =
27823         DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT.getVectorElementType(), N0,
27824                     DAG.getVectorIdxConstant(0, dl));
27825     if (Res.getValueType() != N->getValueType(0))
27826       Res = DAG.getNode(ISD::ANY_EXTEND, dl, N->getValueType(0), Res);
27827     return Res;
27828   }
27829 
27830   // On an boolean vector an and/or reduction is the same as a umin/umax
27831   // reduction. Convert them if the latter is legal while the former isn't.
27832   if (Opcode == ISD::VECREDUCE_AND || Opcode == ISD::VECREDUCE_OR) {
27833     unsigned NewOpcode = Opcode == ISD::VECREDUCE_AND
27834         ? ISD::VECREDUCE_UMIN : ISD::VECREDUCE_UMAX;
27835     if (!TLI.isOperationLegalOrCustom(Opcode, VT) &&
27836         TLI.isOperationLegalOrCustom(NewOpcode, VT) &&
27837         DAG.ComputeNumSignBits(N0) == VT.getScalarSizeInBits())
27838       return DAG.getNode(NewOpcode, SDLoc(N), N->getValueType(0), N0);
27839   }
27840 
27841   // vecreduce_or(insert_subvector(zero or undef, val)) -> vecreduce_or(val)
27842   // vecreduce_and(insert_subvector(ones or undef, val)) -> vecreduce_and(val)
27843   if (N0.getOpcode() == ISD::INSERT_SUBVECTOR &&
27844       TLI.isTypeLegal(N0.getOperand(1).getValueType())) {
27845     SDValue Vec = N0.getOperand(0);
27846     SDValue Subvec = N0.getOperand(1);
27847     if ((Opcode == ISD::VECREDUCE_OR &&
27848          (N0.getOperand(0).isUndef() || isNullOrNullSplat(Vec))) ||
27849         (Opcode == ISD::VECREDUCE_AND &&
27850          (N0.getOperand(0).isUndef() || isAllOnesOrAllOnesSplat(Vec))))
27851       return DAG.getNode(Opcode, SDLoc(N), N->getValueType(0), Subvec);
27852   }
27853 
27854   // vecreduce_or(sext(x)) -> sext(vecreduce_or(x))
27855   // Same for zext and anyext, and for and/or/xor reductions.
27856   if ((Opcode == ISD::VECREDUCE_OR || Opcode == ISD::VECREDUCE_AND ||
27857        Opcode == ISD::VECREDUCE_XOR) &&
27858       (N0.getOpcode() == ISD::SIGN_EXTEND ||
27859        N0.getOpcode() == ISD::ZERO_EXTEND ||
27860        N0.getOpcode() == ISD::ANY_EXTEND) &&
27861       TLI.isOperationLegalOrCustom(Opcode, N0.getOperand(0).getValueType())) {
27862     SDValue Red = DAG.getNode(Opcode, SDLoc(N),
27863                               N0.getOperand(0).getValueType().getScalarType(),
27864                               N0.getOperand(0));
27865     return DAG.getNode(N0.getOpcode(), SDLoc(N), N->getValueType(0), Red);
27866   }
27867   return SDValue();
27868 }
27869 
visitVP_FSUB(SDNode * N)27870 SDValue DAGCombiner::visitVP_FSUB(SDNode *N) {
27871   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
27872 
27873   // FSUB -> FMA combines:
27874   if (SDValue Fused = visitFSUBForFMACombine<VPMatchContext>(N)) {
27875     AddToWorklist(Fused.getNode());
27876     return Fused;
27877   }
27878   return SDValue();
27879 }
27880 
visitVPOp(SDNode * N)27881 SDValue DAGCombiner::visitVPOp(SDNode *N) {
27882 
27883   if (N->getOpcode() == ISD::VP_GATHER)
27884     if (SDValue SD = visitVPGATHER(N))
27885       return SD;
27886 
27887   if (N->getOpcode() == ISD::VP_SCATTER)
27888     if (SDValue SD = visitVPSCATTER(N))
27889       return SD;
27890 
27891   if (N->getOpcode() == ISD::EXPERIMENTAL_VP_STRIDED_LOAD)
27892     if (SDValue SD = visitVP_STRIDED_LOAD(N))
27893       return SD;
27894 
27895   if (N->getOpcode() == ISD::EXPERIMENTAL_VP_STRIDED_STORE)
27896     if (SDValue SD = visitVP_STRIDED_STORE(N))
27897       return SD;
27898 
27899   // VP operations in which all vector elements are disabled - either by
27900   // determining that the mask is all false or that the EVL is 0 - can be
27901   // eliminated.
27902   bool AreAllEltsDisabled = false;
27903   if (auto EVLIdx = ISD::getVPExplicitVectorLengthIdx(N->getOpcode()))
27904     AreAllEltsDisabled |= isNullConstant(N->getOperand(*EVLIdx));
27905   if (auto MaskIdx = ISD::getVPMaskIdx(N->getOpcode()))
27906     AreAllEltsDisabled |=
27907         ISD::isConstantSplatVectorAllZeros(N->getOperand(*MaskIdx).getNode());
27908 
27909   // This is the only generic VP combine we support for now.
27910   if (!AreAllEltsDisabled) {
27911     switch (N->getOpcode()) {
27912     case ISD::VP_FADD:
27913       return visitVP_FADD(N);
27914     case ISD::VP_FSUB:
27915       return visitVP_FSUB(N);
27916     case ISD::VP_FMA:
27917       return visitFMA<VPMatchContext>(N);
27918     case ISD::VP_SELECT:
27919       return visitVP_SELECT(N);
27920     case ISD::VP_MUL:
27921       return visitMUL<VPMatchContext>(N);
27922     case ISD::VP_SUB:
27923       return foldSubCtlzNot<VPMatchContext>(N, DAG);
27924     default:
27925       break;
27926     }
27927     return SDValue();
27928   }
27929 
27930   // Binary operations can be replaced by UNDEF.
27931   if (ISD::isVPBinaryOp(N->getOpcode()))
27932     return DAG.getUNDEF(N->getValueType(0));
27933 
27934   // VP Memory operations can be replaced by either the chain (stores) or the
27935   // chain + undef (loads).
27936   if (const auto *MemSD = dyn_cast<MemSDNode>(N)) {
27937     if (MemSD->writeMem())
27938       return MemSD->getChain();
27939     return CombineTo(N, DAG.getUNDEF(N->getValueType(0)), MemSD->getChain());
27940   }
27941 
27942   // Reduction operations return the start operand when no elements are active.
27943   if (ISD::isVPReduction(N->getOpcode()))
27944     return N->getOperand(0);
27945 
27946   return SDValue();
27947 }
27948 
visitGET_FPENV_MEM(SDNode * N)27949 SDValue DAGCombiner::visitGET_FPENV_MEM(SDNode *N) {
27950   SDValue Chain = N->getOperand(0);
27951   SDValue Ptr = N->getOperand(1);
27952   EVT MemVT = cast<FPStateAccessSDNode>(N)->getMemoryVT();
27953 
27954   // Check if the memory, where FP state is written to, is used only in a single
27955   // load operation.
27956   LoadSDNode *LdNode = nullptr;
27957   for (auto *U : Ptr->users()) {
27958     if (U == N)
27959       continue;
27960     if (auto *Ld = dyn_cast<LoadSDNode>(U)) {
27961       if (LdNode && LdNode != Ld)
27962         return SDValue();
27963       LdNode = Ld;
27964       continue;
27965     }
27966     return SDValue();
27967   }
27968   if (!LdNode || !LdNode->isSimple() || LdNode->isIndexed() ||
27969       !LdNode->getOffset().isUndef() || LdNode->getMemoryVT() != MemVT ||
27970       !LdNode->getChain().reachesChainWithoutSideEffects(SDValue(N, 0)))
27971     return SDValue();
27972 
27973   // Check if the loaded value is used only in a store operation.
27974   StoreSDNode *StNode = nullptr;
27975   for (SDUse &U : LdNode->uses()) {
27976     if (U.getResNo() == 0) {
27977       if (auto *St = dyn_cast<StoreSDNode>(U.getUser())) {
27978         if (StNode)
27979           return SDValue();
27980         StNode = St;
27981       } else {
27982         return SDValue();
27983       }
27984     }
27985   }
27986   if (!StNode || !StNode->isSimple() || StNode->isIndexed() ||
27987       !StNode->getOffset().isUndef() || StNode->getMemoryVT() != MemVT ||
27988       !StNode->getChain().reachesChainWithoutSideEffects(SDValue(LdNode, 1)))
27989     return SDValue();
27990 
27991   // Create new node GET_FPENV_MEM, which uses the store address to write FP
27992   // environment.
27993   SDValue Res = DAG.getGetFPEnv(Chain, SDLoc(N), StNode->getBasePtr(), MemVT,
27994                                 StNode->getMemOperand());
27995   CombineTo(StNode, Res, false);
27996   return Res;
27997 }
27998 
visitSET_FPENV_MEM(SDNode * N)27999 SDValue DAGCombiner::visitSET_FPENV_MEM(SDNode *N) {
28000   SDValue Chain = N->getOperand(0);
28001   SDValue Ptr = N->getOperand(1);
28002   EVT MemVT = cast<FPStateAccessSDNode>(N)->getMemoryVT();
28003 
28004   // Check if the address of FP state is used also in a store operation only.
28005   StoreSDNode *StNode = nullptr;
28006   for (auto *U : Ptr->users()) {
28007     if (U == N)
28008       continue;
28009     if (auto *St = dyn_cast<StoreSDNode>(U)) {
28010       if (StNode && StNode != St)
28011         return SDValue();
28012       StNode = St;
28013       continue;
28014     }
28015     return SDValue();
28016   }
28017   if (!StNode || !StNode->isSimple() || StNode->isIndexed() ||
28018       !StNode->getOffset().isUndef() || StNode->getMemoryVT() != MemVT ||
28019       !Chain.reachesChainWithoutSideEffects(SDValue(StNode, 0)))
28020     return SDValue();
28021 
28022   // Check if the stored value is loaded from some location and the loaded
28023   // value is used only in the store operation.
28024   SDValue StValue = StNode->getValue();
28025   auto *LdNode = dyn_cast<LoadSDNode>(StValue);
28026   if (!LdNode || !LdNode->isSimple() || LdNode->isIndexed() ||
28027       !LdNode->getOffset().isUndef() || LdNode->getMemoryVT() != MemVT ||
28028       !StNode->getChain().reachesChainWithoutSideEffects(SDValue(LdNode, 1)))
28029     return SDValue();
28030 
28031   // Create new node SET_FPENV_MEM, which uses the load address to read FP
28032   // environment.
28033   SDValue Res =
28034       DAG.getSetFPEnv(LdNode->getChain(), SDLoc(N), LdNode->getBasePtr(), MemVT,
28035                       LdNode->getMemOperand());
28036   return Res;
28037 }
28038 
28039 /// Returns a vector_shuffle if it able to transform an AND to a vector_shuffle
28040 /// with the destination vector and a zero vector.
28041 /// e.g. AND V, <0xffffffff, 0, 0xffffffff, 0>. ==>
28042 ///      vector_shuffle V, Zero, <0, 4, 2, 4>
XformToShuffleWithZero(SDNode * N)28043 SDValue DAGCombiner::XformToShuffleWithZero(SDNode *N) {
28044   assert(N->getOpcode() == ISD::AND && "Unexpected opcode!");
28045 
28046   EVT VT = N->getValueType(0);
28047   SDValue LHS = N->getOperand(0);
28048   SDValue RHS = peekThroughBitcasts(N->getOperand(1));
28049   SDLoc DL(N);
28050 
28051   // Make sure we're not running after operation legalization where it
28052   // may have custom lowered the vector shuffles.
28053   if (LegalOperations)
28054     return SDValue();
28055 
28056   if (RHS.getOpcode() != ISD::BUILD_VECTOR)
28057     return SDValue();
28058 
28059   EVT RVT = RHS.getValueType();
28060   unsigned NumElts = RHS.getNumOperands();
28061 
28062   // Attempt to create a valid clear mask, splitting the mask into
28063   // sub elements and checking to see if each is
28064   // all zeros or all ones - suitable for shuffle masking.
28065   auto BuildClearMask = [&](int Split) {
28066     int NumSubElts = NumElts * Split;
28067     int NumSubBits = RVT.getScalarSizeInBits() / Split;
28068 
28069     SmallVector<int, 8> Indices;
28070     for (int i = 0; i != NumSubElts; ++i) {
28071       int EltIdx = i / Split;
28072       int SubIdx = i % Split;
28073       SDValue Elt = RHS.getOperand(EltIdx);
28074       // X & undef --> 0 (not undef). So this lane must be converted to choose
28075       // from the zero constant vector (same as if the element had all 0-bits).
28076       if (Elt.isUndef()) {
28077         Indices.push_back(i + NumSubElts);
28078         continue;
28079       }
28080 
28081       std::optional<APInt> Bits = Elt->bitcastToAPInt();
28082       if (!Bits)
28083         return SDValue();
28084 
28085       // Extract the sub element from the constant bit mask.
28086       if (DAG.getDataLayout().isBigEndian())
28087         *Bits =
28088             Bits->extractBits(NumSubBits, (Split - SubIdx - 1) * NumSubBits);
28089       else
28090         *Bits = Bits->extractBits(NumSubBits, SubIdx * NumSubBits);
28091 
28092       if (Bits->isAllOnes())
28093         Indices.push_back(i);
28094       else if (*Bits == 0)
28095         Indices.push_back(i + NumSubElts);
28096       else
28097         return SDValue();
28098     }
28099 
28100     // Let's see if the target supports this vector_shuffle.
28101     EVT ClearSVT = EVT::getIntegerVT(*DAG.getContext(), NumSubBits);
28102     EVT ClearVT = EVT::getVectorVT(*DAG.getContext(), ClearSVT, NumSubElts);
28103     if (!TLI.isVectorClearMaskLegal(Indices, ClearVT))
28104       return SDValue();
28105 
28106     SDValue Zero = DAG.getConstant(0, DL, ClearVT);
28107     return DAG.getBitcast(VT, DAG.getVectorShuffle(ClearVT, DL,
28108                                                    DAG.getBitcast(ClearVT, LHS),
28109                                                    Zero, Indices));
28110   };
28111 
28112   // Determine maximum split level (byte level masking).
28113   int MaxSplit = 1;
28114   if (RVT.getScalarSizeInBits() % 8 == 0)
28115     MaxSplit = RVT.getScalarSizeInBits() / 8;
28116 
28117   for (int Split = 1; Split <= MaxSplit; ++Split)
28118     if (RVT.getScalarSizeInBits() % Split == 0)
28119       if (SDValue S = BuildClearMask(Split))
28120         return S;
28121 
28122   return SDValue();
28123 }
28124 
28125 /// If a vector binop is performed on splat values, it may be profitable to
28126 /// extract, scalarize, and insert/splat.
scalarizeBinOpOfSplats(SDNode * N,SelectionDAG & DAG,const SDLoc & DL,bool LegalTypes)28127 static SDValue scalarizeBinOpOfSplats(SDNode *N, SelectionDAG &DAG,
28128                                       const SDLoc &DL, bool LegalTypes) {
28129   SDValue N0 = N->getOperand(0);
28130   SDValue N1 = N->getOperand(1);
28131   unsigned Opcode = N->getOpcode();
28132   EVT VT = N->getValueType(0);
28133   EVT EltVT = VT.getVectorElementType();
28134   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
28135 
28136   // TODO: Remove/replace the extract cost check? If the elements are available
28137   //       as scalars, then there may be no extract cost. Should we ask if
28138   //       inserting a scalar back into a vector is cheap instead?
28139   int Index0, Index1;
28140   SDValue Src0 = DAG.getSplatSourceVector(N0, Index0);
28141   SDValue Src1 = DAG.getSplatSourceVector(N1, Index1);
28142   // Extract element from splat_vector should be free.
28143   // TODO: use DAG.isSplatValue instead?
28144   bool IsBothSplatVector = N0.getOpcode() == ISD::SPLAT_VECTOR &&
28145                            N1.getOpcode() == ISD::SPLAT_VECTOR;
28146   if (!Src0 || !Src1 || Index0 != Index1 ||
28147       Src0.getValueType().getVectorElementType() != EltVT ||
28148       Src1.getValueType().getVectorElementType() != EltVT ||
28149       !(IsBothSplatVector || TLI.isExtractVecEltCheap(VT, Index0)) ||
28150       // If before type legalization, allow scalar types that will eventually be
28151       // made legal.
28152       !TLI.isOperationLegalOrCustom(
28153           Opcode, LegalTypes
28154                       ? EltVT
28155                       : TLI.getTypeToTransformTo(*DAG.getContext(), EltVT)))
28156     return SDValue();
28157 
28158   // FIXME: Type legalization can't handle illegal MULHS/MULHU.
28159   if ((Opcode == ISD::MULHS || Opcode == ISD::MULHU) && !TLI.isTypeLegal(EltVT))
28160     return SDValue();
28161 
28162   if (N0.getOpcode() == ISD::BUILD_VECTOR && N0.getOpcode() == N1.getOpcode()) {
28163     // All but one element should have an undef input, which will fold to a
28164     // constant or undef. Avoid splatting which would over-define potentially
28165     // undefined elements.
28166 
28167     // bo (build_vec ..undef, X, undef...), (build_vec ..undef, Y, undef...) -->
28168     //   build_vec ..undef, (bo X, Y), undef...
28169     SmallVector<SDValue, 16> EltsX, EltsY, EltsResult;
28170     DAG.ExtractVectorElements(Src0, EltsX);
28171     DAG.ExtractVectorElements(Src1, EltsY);
28172 
28173     for (auto [X, Y] : zip(EltsX, EltsY))
28174       EltsResult.push_back(DAG.getNode(Opcode, DL, EltVT, X, Y, N->getFlags()));
28175     return DAG.getBuildVector(VT, DL, EltsResult);
28176   }
28177 
28178   SDValue IndexC = DAG.getVectorIdxConstant(Index0, DL);
28179   SDValue X = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Src0, IndexC);
28180   SDValue Y = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Src1, IndexC);
28181   SDValue ScalarBO = DAG.getNode(Opcode, DL, EltVT, X, Y, N->getFlags());
28182 
28183   // bo (splat X, Index), (splat Y, Index) --> splat (bo X, Y), Index
28184   return DAG.getSplat(VT, DL, ScalarBO);
28185 }
28186 
28187 /// Visit a vector cast operation, like FP_EXTEND.
SimplifyVCastOp(SDNode * N,const SDLoc & DL)28188 SDValue DAGCombiner::SimplifyVCastOp(SDNode *N, const SDLoc &DL) {
28189   EVT VT = N->getValueType(0);
28190   assert(VT.isVector() && "SimplifyVCastOp only works on vectors!");
28191   EVT EltVT = VT.getVectorElementType();
28192   unsigned Opcode = N->getOpcode();
28193 
28194   SDValue N0 = N->getOperand(0);
28195   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
28196 
28197   // TODO: promote operation might be also good here?
28198   int Index0;
28199   SDValue Src0 = DAG.getSplatSourceVector(N0, Index0);
28200   if (Src0 &&
28201       (N0.getOpcode() == ISD::SPLAT_VECTOR ||
28202        TLI.isExtractVecEltCheap(VT, Index0)) &&
28203       TLI.isOperationLegalOrCustom(Opcode, EltVT) &&
28204       TLI.preferScalarizeSplat(N)) {
28205     EVT SrcVT = N0.getValueType();
28206     EVT SrcEltVT = SrcVT.getVectorElementType();
28207     SDValue IndexC = DAG.getVectorIdxConstant(Index0, DL);
28208     SDValue Elt =
28209         DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, SrcEltVT, Src0, IndexC);
28210     SDValue ScalarBO = DAG.getNode(Opcode, DL, EltVT, Elt, N->getFlags());
28211     if (VT.isScalableVector())
28212       return DAG.getSplatVector(VT, DL, ScalarBO);
28213     SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), ScalarBO);
28214     return DAG.getBuildVector(VT, DL, Ops);
28215   }
28216 
28217   return SDValue();
28218 }
28219 
28220 /// Visit a binary vector operation, like ADD.
SimplifyVBinOp(SDNode * N,const SDLoc & DL)28221 SDValue DAGCombiner::SimplifyVBinOp(SDNode *N, const SDLoc &DL) {
28222   EVT VT = N->getValueType(0);
28223   assert(VT.isVector() && "SimplifyVBinOp only works on vectors!");
28224 
28225   SDValue LHS = N->getOperand(0);
28226   SDValue RHS = N->getOperand(1);
28227   unsigned Opcode = N->getOpcode();
28228   SDNodeFlags Flags = N->getFlags();
28229 
28230   // Move unary shuffles with identical masks after a vector binop:
28231   // VBinOp (shuffle A, Undef, Mask), (shuffle B, Undef, Mask))
28232   //   --> shuffle (VBinOp A, B), Undef, Mask
28233   // This does not require type legality checks because we are creating the
28234   // same types of operations that are in the original sequence. We do have to
28235   // restrict ops like integer div that have immediate UB (eg, div-by-zero)
28236   // though. This code is adapted from the identical transform in instcombine.
28237   if (DAG.isSafeToSpeculativelyExecute(Opcode)) {
28238     auto *Shuf0 = dyn_cast<ShuffleVectorSDNode>(LHS);
28239     auto *Shuf1 = dyn_cast<ShuffleVectorSDNode>(RHS);
28240     if (Shuf0 && Shuf1 && Shuf0->getMask().equals(Shuf1->getMask()) &&
28241         LHS.getOperand(1).isUndef() && RHS.getOperand(1).isUndef() &&
28242         (LHS.hasOneUse() || RHS.hasOneUse() || LHS == RHS)) {
28243       SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, LHS.getOperand(0),
28244                                      RHS.getOperand(0), Flags);
28245       SDValue UndefV = LHS.getOperand(1);
28246       return DAG.getVectorShuffle(VT, DL, NewBinOp, UndefV, Shuf0->getMask());
28247     }
28248 
28249     // Try to sink a splat shuffle after a binop with a uniform constant.
28250     // This is limited to cases where neither the shuffle nor the constant have
28251     // undefined elements because that could be poison-unsafe or inhibit
28252     // demanded elements analysis. It is further limited to not change a splat
28253     // of an inserted scalar because that may be optimized better by
28254     // load-folding or other target-specific behaviors.
28255     if (isConstOrConstSplat(RHS) && Shuf0 && all_equal(Shuf0->getMask()) &&
28256         Shuf0->hasOneUse() && Shuf0->getOperand(1).isUndef() &&
28257         Shuf0->getOperand(0).getOpcode() != ISD::INSERT_VECTOR_ELT) {
28258       // binop (splat X), (splat C) --> splat (binop X, C)
28259       SDValue X = Shuf0->getOperand(0);
28260       SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, X, RHS, Flags);
28261       return DAG.getVectorShuffle(VT, DL, NewBinOp, DAG.getUNDEF(VT),
28262                                   Shuf0->getMask());
28263     }
28264     if (isConstOrConstSplat(LHS) && Shuf1 && all_equal(Shuf1->getMask()) &&
28265         Shuf1->hasOneUse() && Shuf1->getOperand(1).isUndef() &&
28266         Shuf1->getOperand(0).getOpcode() != ISD::INSERT_VECTOR_ELT) {
28267       // binop (splat C), (splat X) --> splat (binop C, X)
28268       SDValue X = Shuf1->getOperand(0);
28269       SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, LHS, X, Flags);
28270       return DAG.getVectorShuffle(VT, DL, NewBinOp, DAG.getUNDEF(VT),
28271                                   Shuf1->getMask());
28272     }
28273   }
28274 
28275   // The following pattern is likely to emerge with vector reduction ops. Moving
28276   // the binary operation ahead of insertion may allow using a narrower vector
28277   // instruction that has better performance than the wide version of the op:
28278   // VBinOp (ins undef, X, Z), (ins undef, Y, Z) --> ins VecC, (VBinOp X, Y), Z
28279   if (LHS.getOpcode() == ISD::INSERT_SUBVECTOR && LHS.getOperand(0).isUndef() &&
28280       RHS.getOpcode() == ISD::INSERT_SUBVECTOR && RHS.getOperand(0).isUndef() &&
28281       LHS.getOperand(2) == RHS.getOperand(2) &&
28282       (LHS.hasOneUse() || RHS.hasOneUse())) {
28283     SDValue X = LHS.getOperand(1);
28284     SDValue Y = RHS.getOperand(1);
28285     SDValue Z = LHS.getOperand(2);
28286     EVT NarrowVT = X.getValueType();
28287     if (NarrowVT == Y.getValueType() &&
28288         TLI.isOperationLegalOrCustomOrPromote(Opcode, NarrowVT,
28289                                               LegalOperations)) {
28290       // (binop undef, undef) may not return undef, so compute that result.
28291       SDValue VecC =
28292           DAG.getNode(Opcode, DL, VT, DAG.getUNDEF(VT), DAG.getUNDEF(VT));
28293       SDValue NarrowBO = DAG.getNode(Opcode, DL, NarrowVT, X, Y);
28294       return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, VecC, NarrowBO, Z);
28295     }
28296   }
28297 
28298   // Make sure all but the first op are undef or constant.
28299   auto ConcatWithConstantOrUndef = [](SDValue Concat) {
28300     return Concat.getOpcode() == ISD::CONCAT_VECTORS &&
28301            all_of(drop_begin(Concat->ops()), [](const SDValue &Op) {
28302              return Op.isUndef() ||
28303                     ISD::isBuildVectorOfConstantSDNodes(Op.getNode());
28304            });
28305   };
28306 
28307   // The following pattern is likely to emerge with vector reduction ops. Moving
28308   // the binary operation ahead of the concat may allow using a narrower vector
28309   // instruction that has better performance than the wide version of the op:
28310   // VBinOp (concat X, undef/constant), (concat Y, undef/constant) -->
28311   //   concat (VBinOp X, Y), VecC
28312   if (ConcatWithConstantOrUndef(LHS) && ConcatWithConstantOrUndef(RHS) &&
28313       (LHS.hasOneUse() || RHS.hasOneUse())) {
28314     EVT NarrowVT = LHS.getOperand(0).getValueType();
28315     if (NarrowVT == RHS.getOperand(0).getValueType() &&
28316         TLI.isOperationLegalOrCustomOrPromote(Opcode, NarrowVT)) {
28317       unsigned NumOperands = LHS.getNumOperands();
28318       SmallVector<SDValue, 4> ConcatOps;
28319       for (unsigned i = 0; i != NumOperands; ++i) {
28320         // This constant fold for operands 1 and up.
28321         ConcatOps.push_back(DAG.getNode(Opcode, DL, NarrowVT, LHS.getOperand(i),
28322                                         RHS.getOperand(i)));
28323       }
28324 
28325       return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps);
28326     }
28327   }
28328 
28329   if (SDValue V = scalarizeBinOpOfSplats(N, DAG, DL, LegalTypes))
28330     return V;
28331 
28332   return SDValue();
28333 }
28334 
SimplifySelect(const SDLoc & DL,SDValue N0,SDValue N1,SDValue N2)28335 SDValue DAGCombiner::SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1,
28336                                     SDValue N2) {
28337   assert(N0.getOpcode() == ISD::SETCC &&
28338          "First argument must be a SetCC node!");
28339 
28340   SDValue SCC = SimplifySelectCC(DL, N0.getOperand(0), N0.getOperand(1), N1, N2,
28341                                  cast<CondCodeSDNode>(N0.getOperand(2))->get());
28342 
28343   // If we got a simplified select_cc node back from SimplifySelectCC, then
28344   // break it down into a new SETCC node, and a new SELECT node, and then return
28345   // the SELECT node, since we were called with a SELECT node.
28346   if (SCC.getNode()) {
28347     // Check to see if we got a select_cc back (to turn into setcc/select).
28348     // Otherwise, just return whatever node we got back, like fabs.
28349     if (SCC.getOpcode() == ISD::SELECT_CC) {
28350       const SDNodeFlags Flags = N0->getFlags();
28351       SDValue SETCC = DAG.getNode(ISD::SETCC, SDLoc(N0),
28352                                   N0.getValueType(),
28353                                   SCC.getOperand(0), SCC.getOperand(1),
28354                                   SCC.getOperand(4), Flags);
28355       AddToWorklist(SETCC.getNode());
28356       return DAG.getSelect(SDLoc(SCC), SCC.getValueType(), SETCC,
28357                            SCC.getOperand(2), SCC.getOperand(3), Flags);
28358     }
28359 
28360     return SCC;
28361   }
28362   return SDValue();
28363 }
28364 
28365 /// Given a SELECT or a SELECT_CC node, where LHS and RHS are the two values
28366 /// being selected between, see if we can simplify the select.  Callers of this
28367 /// should assume that TheSelect is deleted if this returns true.  As such, they
28368 /// should return the appropriate thing (e.g. the node) back to the top-level of
28369 /// the DAG combiner loop to avoid it being looked at.
SimplifySelectOps(SDNode * TheSelect,SDValue LHS,SDValue RHS)28370 bool DAGCombiner::SimplifySelectOps(SDNode *TheSelect, SDValue LHS,
28371                                     SDValue RHS) {
28372   // fold (select (setcc x, [+-]0.0, *lt), NaN, (fsqrt x))
28373   // The select + setcc is redundant, because fsqrt returns NaN for X < 0.
28374   if (const ConstantFPSDNode *NaN = isConstOrConstSplatFP(LHS)) {
28375     if (NaN->isNaN() && RHS.getOpcode() == ISD::FSQRT) {
28376       // We have: (select (setcc ?, ?, ?), NaN, (fsqrt ?))
28377       SDValue Sqrt = RHS;
28378       ISD::CondCode CC;
28379       SDValue CmpLHS;
28380       const ConstantFPSDNode *Zero = nullptr;
28381 
28382       if (TheSelect->getOpcode() == ISD::SELECT_CC) {
28383         CC = cast<CondCodeSDNode>(TheSelect->getOperand(4))->get();
28384         CmpLHS = TheSelect->getOperand(0);
28385         Zero = isConstOrConstSplatFP(TheSelect->getOperand(1));
28386       } else {
28387         // SELECT or VSELECT
28388         SDValue Cmp = TheSelect->getOperand(0);
28389         if (Cmp.getOpcode() == ISD::SETCC) {
28390           CC = cast<CondCodeSDNode>(Cmp.getOperand(2))->get();
28391           CmpLHS = Cmp.getOperand(0);
28392           Zero = isConstOrConstSplatFP(Cmp.getOperand(1));
28393         }
28394       }
28395       if (Zero && Zero->isZero() &&
28396           Sqrt.getOperand(0) == CmpLHS && (CC == ISD::SETOLT ||
28397           CC == ISD::SETULT || CC == ISD::SETLT)) {
28398         // We have: (select (setcc x, [+-]0.0, *lt), NaN, (fsqrt x))
28399         CombineTo(TheSelect, Sqrt);
28400         return true;
28401       }
28402     }
28403   }
28404   // Cannot simplify select with vector condition
28405   if (TheSelect->getOperand(0).getValueType().isVector()) return false;
28406 
28407   // If this is a select from two identical things, try to pull the operation
28408   // through the select.
28409   if (LHS.getOpcode() != RHS.getOpcode() ||
28410       !LHS.hasOneUse() || !RHS.hasOneUse())
28411     return false;
28412 
28413   // If this is a load and the token chain is identical, replace the select
28414   // of two loads with a load through a select of the address to load from.
28415   // This triggers in things like "select bool X, 10.0, 123.0" after the FP
28416   // constants have been dropped into the constant pool.
28417   if (LHS.getOpcode() == ISD::LOAD) {
28418     LoadSDNode *LLD = cast<LoadSDNode>(LHS);
28419     LoadSDNode *RLD = cast<LoadSDNode>(RHS);
28420 
28421     // Token chains must be identical.
28422     if (LHS.getOperand(0) != RHS.getOperand(0) ||
28423         // Do not let this transformation reduce the number of volatile loads.
28424         // Be conservative for atomics for the moment
28425         // TODO: This does appear to be legal for unordered atomics (see D66309)
28426         !LLD->isSimple() || !RLD->isSimple() ||
28427         // FIXME: If either is a pre/post inc/dec load,
28428         // we'd need to split out the address adjustment.
28429         LLD->isIndexed() || RLD->isIndexed() ||
28430         // If this is an EXTLOAD, the VT's must match.
28431         LLD->getMemoryVT() != RLD->getMemoryVT() ||
28432         // If this is an EXTLOAD, the kind of extension must match.
28433         (LLD->getExtensionType() != RLD->getExtensionType() &&
28434          // The only exception is if one of the extensions is anyext.
28435          LLD->getExtensionType() != ISD::EXTLOAD &&
28436          RLD->getExtensionType() != ISD::EXTLOAD) ||
28437         // FIXME: this discards src value information.  This is
28438         // over-conservative. It would be beneficial to be able to remember
28439         // both potential memory locations.  Since we are discarding
28440         // src value info, don't do the transformation if the memory
28441         // locations are not in the default address space.
28442         LLD->getPointerInfo().getAddrSpace() != 0 ||
28443         RLD->getPointerInfo().getAddrSpace() != 0 ||
28444         // We can't produce a CMOV of a TargetFrameIndex since we won't
28445         // generate the address generation required.
28446         LLD->getBasePtr().getOpcode() == ISD::TargetFrameIndex ||
28447         RLD->getBasePtr().getOpcode() == ISD::TargetFrameIndex ||
28448         !TLI.isOperationLegalOrCustom(TheSelect->getOpcode(),
28449                                       LLD->getBasePtr().getValueType()))
28450       return false;
28451 
28452     // The loads must not depend on one another.
28453     if (LLD->isPredecessorOf(RLD) || RLD->isPredecessorOf(LLD))
28454       return false;
28455 
28456     // Check that the select condition doesn't reach either load.  If so,
28457     // folding this will induce a cycle into the DAG.  If not, this is safe to
28458     // xform, so create a select of the addresses.
28459 
28460     SmallPtrSet<const SDNode *, 32> Visited;
28461     SmallVector<const SDNode *, 16> Worklist;
28462 
28463     // Always fail if LLD and RLD are not independent. TheSelect is a
28464     // predecessor to all Nodes in question so we need not search past it.
28465 
28466     Visited.insert(TheSelect);
28467     Worklist.push_back(LLD);
28468     Worklist.push_back(RLD);
28469 
28470     if (SDNode::hasPredecessorHelper(LLD, Visited, Worklist) ||
28471         SDNode::hasPredecessorHelper(RLD, Visited, Worklist))
28472       return false;
28473 
28474     SDValue Addr;
28475     if (TheSelect->getOpcode() == ISD::SELECT) {
28476       // We cannot do this optimization if any pair of {RLD, LLD} is a
28477       // predecessor to {RLD, LLD, CondNode}. As we've already compared the
28478       // Loads, we only need to check if CondNode is a successor to one of the
28479       // loads. We can further avoid this if there's no use of their chain
28480       // value.
28481       SDNode *CondNode = TheSelect->getOperand(0).getNode();
28482       Worklist.push_back(CondNode);
28483 
28484       if ((LLD->hasAnyUseOfValue(1) &&
28485            SDNode::hasPredecessorHelper(LLD, Visited, Worklist)) ||
28486           (RLD->hasAnyUseOfValue(1) &&
28487            SDNode::hasPredecessorHelper(RLD, Visited, Worklist)))
28488         return false;
28489 
28490       Addr = DAG.getSelect(SDLoc(TheSelect),
28491                            LLD->getBasePtr().getValueType(),
28492                            TheSelect->getOperand(0), LLD->getBasePtr(),
28493                            RLD->getBasePtr());
28494     } else {  // Otherwise SELECT_CC
28495       // We cannot do this optimization if any pair of {RLD, LLD} is a
28496       // predecessor to {RLD, LLD, CondLHS, CondRHS}. As we've already compared
28497       // the Loads, we only need to check if CondLHS/CondRHS is a successor to
28498       // one of the loads. We can further avoid this if there's no use of their
28499       // chain value.
28500 
28501       SDNode *CondLHS = TheSelect->getOperand(0).getNode();
28502       SDNode *CondRHS = TheSelect->getOperand(1).getNode();
28503       Worklist.push_back(CondLHS);
28504       Worklist.push_back(CondRHS);
28505 
28506       if ((LLD->hasAnyUseOfValue(1) &&
28507            SDNode::hasPredecessorHelper(LLD, Visited, Worklist)) ||
28508           (RLD->hasAnyUseOfValue(1) &&
28509            SDNode::hasPredecessorHelper(RLD, Visited, Worklist)))
28510         return false;
28511 
28512       Addr = DAG.getNode(ISD::SELECT_CC, SDLoc(TheSelect),
28513                          LLD->getBasePtr().getValueType(),
28514                          TheSelect->getOperand(0),
28515                          TheSelect->getOperand(1),
28516                          LLD->getBasePtr(), RLD->getBasePtr(),
28517                          TheSelect->getOperand(4));
28518     }
28519 
28520     SDValue Load;
28521     // It is safe to replace the two loads if they have different alignments,
28522     // but the new load must be the minimum (most restrictive) alignment of the
28523     // inputs.
28524     Align Alignment = std::min(LLD->getAlign(), RLD->getAlign());
28525     MachineMemOperand::Flags MMOFlags = LLD->getMemOperand()->getFlags();
28526     if (!RLD->isInvariant())
28527       MMOFlags &= ~MachineMemOperand::MOInvariant;
28528     if (!RLD->isDereferenceable())
28529       MMOFlags &= ~MachineMemOperand::MODereferenceable;
28530     if (LLD->getExtensionType() == ISD::NON_EXTLOAD) {
28531       // FIXME: Discards pointer and AA info.
28532       Load = DAG.getLoad(TheSelect->getValueType(0), SDLoc(TheSelect),
28533                          LLD->getChain(), Addr, MachinePointerInfo(), Alignment,
28534                          MMOFlags);
28535     } else {
28536       // FIXME: Discards pointer and AA info.
28537       Load = DAG.getExtLoad(
28538           LLD->getExtensionType() == ISD::EXTLOAD ? RLD->getExtensionType()
28539                                                   : LLD->getExtensionType(),
28540           SDLoc(TheSelect), TheSelect->getValueType(0), LLD->getChain(), Addr,
28541           MachinePointerInfo(), LLD->getMemoryVT(), Alignment, MMOFlags);
28542     }
28543 
28544     // Users of the select now use the result of the load.
28545     CombineTo(TheSelect, Load);
28546 
28547     // Users of the old loads now use the new load's chain.  We know the
28548     // old-load value is dead now.
28549     CombineTo(LHS.getNode(), Load.getValue(0), Load.getValue(1));
28550     CombineTo(RHS.getNode(), Load.getValue(0), Load.getValue(1));
28551     return true;
28552   }
28553 
28554   return false;
28555 }
28556 
28557 /// Try to fold an expression of the form (N0 cond N1) ? N2 : N3 to a shift and
28558 /// bitwise 'and'.
foldSelectCCToShiftAnd(const SDLoc & DL,SDValue N0,SDValue N1,SDValue N2,SDValue N3,ISD::CondCode CC)28559 SDValue DAGCombiner::foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0,
28560                                             SDValue N1, SDValue N2, SDValue N3,
28561                                             ISD::CondCode CC) {
28562   // If this is a select where the false operand is zero and the compare is a
28563   // check of the sign bit, see if we can perform the "gzip trick":
28564   // select_cc setlt X, 0, A, 0 -> and (sra X, size(X)-1), A
28565   // select_cc setgt X, 0, A, 0 -> and (not (sra X, size(X)-1)), A
28566   EVT XType = N0.getValueType();
28567   EVT AType = N2.getValueType();
28568   if (!isNullConstant(N3) || !XType.bitsGE(AType))
28569     return SDValue();
28570 
28571   // If the comparison is testing for a positive value, we have to invert
28572   // the sign bit mask, so only do that transform if the target has a bitwise
28573   // 'and not' instruction (the invert is free).
28574   if (CC == ISD::SETGT && TLI.hasAndNot(N2)) {
28575     // (X > -1) ? A : 0
28576     // (X >  0) ? X : 0 <-- This is canonical signed max.
28577     if (!(isAllOnesConstant(N1) || (isNullConstant(N1) && N0 == N2)))
28578       return SDValue();
28579   } else if (CC == ISD::SETLT) {
28580     // (X <  0) ? A : 0
28581     // (X <  1) ? X : 0 <-- This is un-canonicalized signed min.
28582     if (!(isNullConstant(N1) || (isOneConstant(N1) && N0 == N2)))
28583       return SDValue();
28584   } else {
28585     return SDValue();
28586   }
28587 
28588   // and (sra X, size(X)-1), A -> "and (srl X, C2), A" iff A is a single-bit
28589   // constant.
28590   auto *N2C = dyn_cast<ConstantSDNode>(N2.getNode());
28591   if (N2C && ((N2C->getAPIntValue() & (N2C->getAPIntValue() - 1)) == 0)) {
28592     unsigned ShCt = XType.getSizeInBits() - N2C->getAPIntValue().logBase2() - 1;
28593     if (!TLI.shouldAvoidTransformToShift(XType, ShCt)) {
28594       SDValue ShiftAmt = DAG.getShiftAmountConstant(ShCt, XType, DL);
28595       SDValue Shift = DAG.getNode(ISD::SRL, DL, XType, N0, ShiftAmt);
28596       AddToWorklist(Shift.getNode());
28597 
28598       if (XType.bitsGT(AType)) {
28599         Shift = DAG.getNode(ISD::TRUNCATE, DL, AType, Shift);
28600         AddToWorklist(Shift.getNode());
28601       }
28602 
28603       if (CC == ISD::SETGT)
28604         Shift = DAG.getNOT(DL, Shift, AType);
28605 
28606       return DAG.getNode(ISD::AND, DL, AType, Shift, N2);
28607     }
28608   }
28609 
28610   unsigned ShCt = XType.getSizeInBits() - 1;
28611   if (TLI.shouldAvoidTransformToShift(XType, ShCt))
28612     return SDValue();
28613 
28614   SDValue ShiftAmt = DAG.getShiftAmountConstant(ShCt, XType, DL);
28615   SDValue Shift = DAG.getNode(ISD::SRA, DL, XType, N0, ShiftAmt);
28616   AddToWorklist(Shift.getNode());
28617 
28618   if (XType.bitsGT(AType)) {
28619     Shift = DAG.getNode(ISD::TRUNCATE, DL, AType, Shift);
28620     AddToWorklist(Shift.getNode());
28621   }
28622 
28623   if (CC == ISD::SETGT)
28624     Shift = DAG.getNOT(DL, Shift, AType);
28625 
28626   return DAG.getNode(ISD::AND, DL, AType, Shift, N2);
28627 }
28628 
28629 // Fold select(cc, binop(), binop()) -> binop(select(), select()) etc.
foldSelectOfBinops(SDNode * N)28630 SDValue DAGCombiner::foldSelectOfBinops(SDNode *N) {
28631   SDValue N0 = N->getOperand(0);
28632   SDValue N1 = N->getOperand(1);
28633   SDValue N2 = N->getOperand(2);
28634   SDLoc DL(N);
28635 
28636   unsigned BinOpc = N1.getOpcode();
28637   if (!TLI.isBinOp(BinOpc) || (N2.getOpcode() != BinOpc) ||
28638       (N1.getResNo() != N2.getResNo()))
28639     return SDValue();
28640 
28641   // The use checks are intentionally on SDNode because we may be dealing
28642   // with opcodes that produce more than one SDValue.
28643   // TODO: Do we really need to check N0 (the condition operand of the select)?
28644   //       But removing that clause could cause an infinite loop...
28645   if (!N0->hasOneUse() || !N1->hasOneUse() || !N2->hasOneUse())
28646     return SDValue();
28647 
28648   // Binops may include opcodes that return multiple values, so all values
28649   // must be created/propagated from the newly created binops below.
28650   SDVTList OpVTs = N1->getVTList();
28651 
28652   // Fold select(cond, binop(x, y), binop(z, y))
28653   //  --> binop(select(cond, x, z), y)
28654   if (N1.getOperand(1) == N2.getOperand(1)) {
28655     SDValue N10 = N1.getOperand(0);
28656     SDValue N20 = N2.getOperand(0);
28657     SDValue NewSel = DAG.getSelect(DL, N10.getValueType(), N0, N10, N20);
28658     SDNodeFlags Flags = N1->getFlags() & N2->getFlags();
28659     SDValue NewBinOp =
28660         DAG.getNode(BinOpc, DL, OpVTs, {NewSel, N1.getOperand(1)}, Flags);
28661     return SDValue(NewBinOp.getNode(), N1.getResNo());
28662   }
28663 
28664   // Fold select(cond, binop(x, y), binop(x, z))
28665   //  --> binop(x, select(cond, y, z))
28666   if (N1.getOperand(0) == N2.getOperand(0)) {
28667     SDValue N11 = N1.getOperand(1);
28668     SDValue N21 = N2.getOperand(1);
28669     // Second op VT might be different (e.g. shift amount type)
28670     if (N11.getValueType() == N21.getValueType()) {
28671       SDValue NewSel = DAG.getSelect(DL, N11.getValueType(), N0, N11, N21);
28672       SDNodeFlags Flags = N1->getFlags() & N2->getFlags();
28673       SDValue NewBinOp =
28674           DAG.getNode(BinOpc, DL, OpVTs, {N1.getOperand(0), NewSel}, Flags);
28675       return SDValue(NewBinOp.getNode(), N1.getResNo());
28676     }
28677   }
28678 
28679   // TODO: Handle isCommutativeBinOp patterns as well?
28680   return SDValue();
28681 }
28682 
28683 // Transform (fneg/fabs (bitconvert x)) to avoid loading constant pool values.
foldSignChangeInBitcast(SDNode * N)28684 SDValue DAGCombiner::foldSignChangeInBitcast(SDNode *N) {
28685   SDValue N0 = N->getOperand(0);
28686   EVT VT = N->getValueType(0);
28687   bool IsFabs = N->getOpcode() == ISD::FABS;
28688   bool IsFree = IsFabs ? TLI.isFAbsFree(VT) : TLI.isFNegFree(VT);
28689 
28690   if (IsFree || N0.getOpcode() != ISD::BITCAST || !N0.hasOneUse())
28691     return SDValue();
28692 
28693   SDValue Int = N0.getOperand(0);
28694   EVT IntVT = Int.getValueType();
28695 
28696   // The operand to cast should be integer.
28697   if (!IntVT.isInteger() || IntVT.isVector())
28698     return SDValue();
28699 
28700   // (fneg (bitconvert x)) -> (bitconvert (xor x sign))
28701   // (fabs (bitconvert x)) -> (bitconvert (and x ~sign))
28702   APInt SignMask;
28703   if (N0.getValueType().isVector()) {
28704     // For vector, create a sign mask (0x80...) or its inverse (for fabs,
28705     // 0x7f...) per element and splat it.
28706     SignMask = APInt::getSignMask(N0.getScalarValueSizeInBits());
28707     if (IsFabs)
28708       SignMask = ~SignMask;
28709     SignMask = APInt::getSplat(IntVT.getSizeInBits(), SignMask);
28710   } else {
28711     // For scalar, just use the sign mask (0x80... or the inverse, 0x7f...)
28712     SignMask = APInt::getSignMask(IntVT.getSizeInBits());
28713     if (IsFabs)
28714       SignMask = ~SignMask;
28715   }
28716   SDLoc DL(N0);
28717   Int = DAG.getNode(IsFabs ? ISD::AND : ISD::XOR, DL, IntVT, Int,
28718                     DAG.getConstant(SignMask, DL, IntVT));
28719   AddToWorklist(Int.getNode());
28720   return DAG.getBitcast(VT, Int);
28721 }
28722 
28723 /// Turn "(a cond b) ? 1.0f : 2.0f" into "load (tmp + ((a cond b) ? 0 : 4)"
28724 /// where "tmp" is a constant pool entry containing an array with 1.0 and 2.0
28725 /// in it. This may be a win when the constant is not otherwise available
28726 /// because it replaces two constant pool loads with one.
convertSelectOfFPConstantsToLoadOffset(const SDLoc & DL,SDValue N0,SDValue N1,SDValue N2,SDValue N3,ISD::CondCode CC)28727 SDValue DAGCombiner::convertSelectOfFPConstantsToLoadOffset(
28728     const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2, SDValue N3,
28729     ISD::CondCode CC) {
28730   if (!TLI.reduceSelectOfFPConstantLoads(N0.getValueType()))
28731     return SDValue();
28732 
28733   // If we are before legalize types, we want the other legalization to happen
28734   // first (for example, to avoid messing with soft float).
28735   auto *TV = dyn_cast<ConstantFPSDNode>(N2);
28736   auto *FV = dyn_cast<ConstantFPSDNode>(N3);
28737   EVT VT = N2.getValueType();
28738   if (!TV || !FV || !TLI.isTypeLegal(VT))
28739     return SDValue();
28740 
28741   // If a constant can be materialized without loads, this does not make sense.
28742   if (TLI.getOperationAction(ISD::ConstantFP, VT) == TargetLowering::Legal ||
28743       TLI.isFPImmLegal(TV->getValueAPF(), TV->getValueType(0), ForCodeSize) ||
28744       TLI.isFPImmLegal(FV->getValueAPF(), FV->getValueType(0), ForCodeSize))
28745     return SDValue();
28746 
28747   // If both constants have multiple uses, then we won't need to do an extra
28748   // load. The values are likely around in registers for other users.
28749   if (!TV->hasOneUse() && !FV->hasOneUse())
28750     return SDValue();
28751 
28752   Constant *Elts[] = { const_cast<ConstantFP*>(FV->getConstantFPValue()),
28753                        const_cast<ConstantFP*>(TV->getConstantFPValue()) };
28754   Type *FPTy = Elts[0]->getType();
28755   const DataLayout &TD = DAG.getDataLayout();
28756 
28757   // Create a ConstantArray of the two constants.
28758   Constant *CA = ConstantArray::get(ArrayType::get(FPTy, 2), Elts);
28759   SDValue CPIdx = DAG.getConstantPool(CA, TLI.getPointerTy(DAG.getDataLayout()),
28760                                       TD.getPrefTypeAlign(FPTy));
28761   Align Alignment = cast<ConstantPoolSDNode>(CPIdx)->getAlign();
28762 
28763   // Get offsets to the 0 and 1 elements of the array, so we can select between
28764   // them.
28765   SDValue Zero = DAG.getIntPtrConstant(0, DL);
28766   unsigned EltSize = (unsigned)TD.getTypeAllocSize(Elts[0]->getType());
28767   SDValue One = DAG.getIntPtrConstant(EltSize, SDLoc(FV));
28768   SDValue Cond =
28769       DAG.getSetCC(DL, getSetCCResultType(N0.getValueType()), N0, N1, CC);
28770   AddToWorklist(Cond.getNode());
28771   SDValue CstOffset = DAG.getSelect(DL, Zero.getValueType(), Cond, One, Zero);
28772   AddToWorklist(CstOffset.getNode());
28773   CPIdx = DAG.getNode(ISD::ADD, DL, CPIdx.getValueType(), CPIdx, CstOffset);
28774   AddToWorklist(CPIdx.getNode());
28775   return DAG.getLoad(TV->getValueType(0), DL, DAG.getEntryNode(), CPIdx,
28776                      MachinePointerInfo::getConstantPool(
28777                          DAG.getMachineFunction()), Alignment);
28778 }
28779 
28780 /// Simplify an expression of the form (N0 cond N1) ? N2 : N3
28781 /// where 'cond' is the comparison specified by CC.
SimplifySelectCC(const SDLoc & DL,SDValue N0,SDValue N1,SDValue N2,SDValue N3,ISD::CondCode CC,bool NotExtCompare)28782 SDValue DAGCombiner::SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
28783                                       SDValue N2, SDValue N3, ISD::CondCode CC,
28784                                       bool NotExtCompare) {
28785   // (x ? y : y) -> y.
28786   if (N2 == N3) return N2;
28787 
28788   EVT CmpOpVT = N0.getValueType();
28789   EVT CmpResVT = getSetCCResultType(CmpOpVT);
28790   EVT VT = N2.getValueType();
28791   auto *N1C = dyn_cast<ConstantSDNode>(N1.getNode());
28792   auto *N2C = dyn_cast<ConstantSDNode>(N2.getNode());
28793   auto *N3C = dyn_cast<ConstantSDNode>(N3.getNode());
28794 
28795   // Determine if the condition we're dealing with is constant.
28796   if (SDValue SCC = DAG.FoldSetCC(CmpResVT, N0, N1, CC, DL)) {
28797     AddToWorklist(SCC.getNode());
28798     if (auto *SCCC = dyn_cast<ConstantSDNode>(SCC)) {
28799       // fold select_cc true, x, y -> x
28800       // fold select_cc false, x, y -> y
28801       return !(SCCC->isZero()) ? N2 : N3;
28802     }
28803   }
28804 
28805   if (SDValue V =
28806           convertSelectOfFPConstantsToLoadOffset(DL, N0, N1, N2, N3, CC))
28807     return V;
28808 
28809   if (SDValue V = foldSelectCCToShiftAnd(DL, N0, N1, N2, N3, CC))
28810     return V;
28811 
28812   // fold (select_cc seteq (and x, y), 0, 0, A) -> (and (sra (shl x)) A)
28813   // where y is has a single bit set.
28814   // A plaintext description would be, we can turn the SELECT_CC into an AND
28815   // when the condition can be materialized as an all-ones register.  Any
28816   // single bit-test can be materialized as an all-ones register with
28817   // shift-left and shift-right-arith.
28818   if (CC == ISD::SETEQ && N0->getOpcode() == ISD::AND &&
28819       N0->getValueType(0) == VT && isNullConstant(N1) && isNullConstant(N2)) {
28820     SDValue AndLHS = N0->getOperand(0);
28821     auto *ConstAndRHS = dyn_cast<ConstantSDNode>(N0->getOperand(1));
28822     if (ConstAndRHS && ConstAndRHS->getAPIntValue().popcount() == 1) {
28823       // Shift the tested bit over the sign bit.
28824       const APInt &AndMask = ConstAndRHS->getAPIntValue();
28825       if (TLI.shouldFoldSelectWithSingleBitTest(VT, AndMask)) {
28826         unsigned ShCt = AndMask.getBitWidth() - 1;
28827         SDValue ShlAmt = DAG.getShiftAmountConstant(AndMask.countl_zero(), VT,
28828                                                     SDLoc(AndLHS));
28829         SDValue Shl = DAG.getNode(ISD::SHL, SDLoc(N0), VT, AndLHS, ShlAmt);
28830 
28831         // Now arithmetic right shift it all the way over, so the result is
28832         // either all-ones, or zero.
28833         SDValue ShrAmt = DAG.getShiftAmountConstant(ShCt, VT, SDLoc(Shl));
28834         SDValue Shr = DAG.getNode(ISD::SRA, SDLoc(N0), VT, Shl, ShrAmt);
28835 
28836         return DAG.getNode(ISD::AND, DL, VT, Shr, N3);
28837       }
28838     }
28839   }
28840 
28841   // fold select C, 16, 0 -> shl C, 4
28842   bool Fold = N2C && isNullConstant(N3) && N2C->getAPIntValue().isPowerOf2();
28843   bool Swap = N3C && isNullConstant(N2) && N3C->getAPIntValue().isPowerOf2();
28844 
28845   if ((Fold || Swap) &&
28846       TLI.getBooleanContents(CmpOpVT) ==
28847           TargetLowering::ZeroOrOneBooleanContent &&
28848       (!LegalOperations || TLI.isOperationLegal(ISD::SETCC, CmpOpVT)) &&
28849       TLI.convertSelectOfConstantsToMath(VT)) {
28850 
28851     if (Swap) {
28852       CC = ISD::getSetCCInverse(CC, CmpOpVT);
28853       std::swap(N2C, N3C);
28854     }
28855 
28856     // If the caller doesn't want us to simplify this into a zext of a compare,
28857     // don't do it.
28858     if (NotExtCompare && N2C->isOne())
28859       return SDValue();
28860 
28861     SDValue Temp, SCC;
28862     // zext (setcc n0, n1)
28863     if (LegalTypes) {
28864       SCC = DAG.getSetCC(DL, CmpResVT, N0, N1, CC);
28865       Temp = DAG.getZExtOrTrunc(SCC, SDLoc(N2), VT);
28866     } else {
28867       SCC = DAG.getSetCC(SDLoc(N0), MVT::i1, N0, N1, CC);
28868       Temp = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N2), VT, SCC);
28869     }
28870 
28871     AddToWorklist(SCC.getNode());
28872     AddToWorklist(Temp.getNode());
28873 
28874     if (N2C->isOne())
28875       return Temp;
28876 
28877     unsigned ShCt = N2C->getAPIntValue().logBase2();
28878     if (TLI.shouldAvoidTransformToShift(VT, ShCt))
28879       return SDValue();
28880 
28881     // shl setcc result by log2 n2c
28882     return DAG.getNode(
28883         ISD::SHL, DL, N2.getValueType(), Temp,
28884         DAG.getShiftAmountConstant(ShCt, N2.getValueType(), SDLoc(Temp)));
28885   }
28886 
28887   // select_cc seteq X, 0, sizeof(X), ctlz(X) -> ctlz(X)
28888   // select_cc seteq X, 0, sizeof(X), ctlz_zero_undef(X) -> ctlz(X)
28889   // select_cc seteq X, 0, sizeof(X), cttz(X) -> cttz(X)
28890   // select_cc seteq X, 0, sizeof(X), cttz_zero_undef(X) -> cttz(X)
28891   // select_cc setne X, 0, ctlz(X), sizeof(X) -> ctlz(X)
28892   // select_cc setne X, 0, ctlz_zero_undef(X), sizeof(X) -> ctlz(X)
28893   // select_cc setne X, 0, cttz(X), sizeof(X) -> cttz(X)
28894   // select_cc setne X, 0, cttz_zero_undef(X), sizeof(X) -> cttz(X)
28895   if (N1C && N1C->isZero() && (CC == ISD::SETEQ || CC == ISD::SETNE)) {
28896     SDValue ValueOnZero = N2;
28897     SDValue Count = N3;
28898     // If the condition is NE instead of E, swap the operands.
28899     if (CC == ISD::SETNE)
28900       std::swap(ValueOnZero, Count);
28901     // Check if the value on zero is a constant equal to the bits in the type.
28902     if (auto *ValueOnZeroC = dyn_cast<ConstantSDNode>(ValueOnZero)) {
28903       if (ValueOnZeroC->getAPIntValue() == VT.getSizeInBits()) {
28904         // If the other operand is cttz/cttz_zero_undef of N0, and cttz is
28905         // legal, combine to just cttz.
28906         if ((Count.getOpcode() == ISD::CTTZ ||
28907              Count.getOpcode() == ISD::CTTZ_ZERO_UNDEF) &&
28908             N0 == Count.getOperand(0) &&
28909             (!LegalOperations || TLI.isOperationLegal(ISD::CTTZ, VT)))
28910           return DAG.getNode(ISD::CTTZ, DL, VT, N0);
28911         // If the other operand is ctlz/ctlz_zero_undef of N0, and ctlz is
28912         // legal, combine to just ctlz.
28913         if ((Count.getOpcode() == ISD::CTLZ ||
28914              Count.getOpcode() == ISD::CTLZ_ZERO_UNDEF) &&
28915             N0 == Count.getOperand(0) &&
28916             (!LegalOperations || TLI.isOperationLegal(ISD::CTLZ, VT)))
28917           return DAG.getNode(ISD::CTLZ, DL, VT, N0);
28918       }
28919     }
28920   }
28921 
28922   // Fold select_cc setgt X, -1, C, ~C -> xor (ashr X, BW-1), C
28923   // Fold select_cc setlt X, 0, C, ~C -> xor (ashr X, BW-1), ~C
28924   if (!NotExtCompare && N1C && N2C && N3C &&
28925       N2C->getAPIntValue() == ~N3C->getAPIntValue() &&
28926       ((N1C->isAllOnes() && CC == ISD::SETGT) ||
28927        (N1C->isZero() && CC == ISD::SETLT)) &&
28928       !TLI.shouldAvoidTransformToShift(VT, CmpOpVT.getScalarSizeInBits() - 1)) {
28929     SDValue ASR = DAG.getNode(
28930         ISD::SRA, DL, CmpOpVT, N0,
28931         DAG.getConstant(CmpOpVT.getScalarSizeInBits() - 1, DL, CmpOpVT));
28932     return DAG.getNode(ISD::XOR, DL, VT, DAG.getSExtOrTrunc(ASR, DL, VT),
28933                        DAG.getSExtOrTrunc(CC == ISD::SETLT ? N3 : N2, DL, VT));
28934   }
28935 
28936   if (SDValue S = PerformMinMaxFpToSatCombine(N0, N1, N2, N3, CC, DAG))
28937     return S;
28938   if (SDValue S = PerformUMinFpToSatCombine(N0, N1, N2, N3, CC, DAG))
28939     return S;
28940   if (SDValue ABD = foldSelectToABD(N0, N1, N2, N3, CC, DL))
28941     return ABD;
28942 
28943   return SDValue();
28944 }
28945 
28946 /// This is a stub for TargetLowering::SimplifySetCC.
SimplifySetCC(EVT VT,SDValue N0,SDValue N1,ISD::CondCode Cond,const SDLoc & DL,bool foldBooleans)28947 SDValue DAGCombiner::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
28948                                    ISD::CondCode Cond, const SDLoc &DL,
28949                                    bool foldBooleans) {
28950   TargetLowering::DAGCombinerInfo
28951     DagCombineInfo(DAG, Level, false, this);
28952   return TLI.SimplifySetCC(VT, N0, N1, Cond, foldBooleans, DagCombineInfo, DL);
28953 }
28954 
28955 /// Given an ISD::SDIV node expressing a divide by constant, return
28956 /// a DAG expression to select that will generate the same value by multiplying
28957 /// by a magic number.
28958 /// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
BuildSDIV(SDNode * N)28959 SDValue DAGCombiner::BuildSDIV(SDNode *N) {
28960   // when optimising for minimum size, we don't want to expand a div to a mul
28961   // and a shift.
28962   if (DAG.getMachineFunction().getFunction().hasMinSize())
28963     return SDValue();
28964 
28965   SmallVector<SDNode *, 8> Built;
28966   if (SDValue S = TLI.BuildSDIV(N, DAG, LegalOperations, LegalTypes, Built)) {
28967     for (SDNode *N : Built)
28968       AddToWorklist(N);
28969     return S;
28970   }
28971 
28972   return SDValue();
28973 }
28974 
28975 /// Given an ISD::SDIV node expressing a divide by constant power of 2, return a
28976 /// DAG expression that will generate the same value by right shifting.
BuildSDIVPow2(SDNode * N)28977 SDValue DAGCombiner::BuildSDIVPow2(SDNode *N) {
28978   ConstantSDNode *C = isConstOrConstSplat(N->getOperand(1));
28979   if (!C)
28980     return SDValue();
28981 
28982   // Avoid division by zero.
28983   if (C->isZero())
28984     return SDValue();
28985 
28986   SmallVector<SDNode *, 8> Built;
28987   if (SDValue S = TLI.BuildSDIVPow2(N, C->getAPIntValue(), DAG, Built)) {
28988     for (SDNode *N : Built)
28989       AddToWorklist(N);
28990     return S;
28991   }
28992 
28993   return SDValue();
28994 }
28995 
28996 /// Given an ISD::UDIV node expressing a divide by constant, return a DAG
28997 /// expression that will generate the same value by multiplying by a magic
28998 /// number.
28999 /// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
BuildUDIV(SDNode * N)29000 SDValue DAGCombiner::BuildUDIV(SDNode *N) {
29001   // when optimising for minimum size, we don't want to expand a div to a mul
29002   // and a shift.
29003   if (DAG.getMachineFunction().getFunction().hasMinSize())
29004     return SDValue();
29005 
29006   SmallVector<SDNode *, 8> Built;
29007   if (SDValue S = TLI.BuildUDIV(N, DAG, LegalOperations, LegalTypes, Built)) {
29008     for (SDNode *N : Built)
29009       AddToWorklist(N);
29010     return S;
29011   }
29012 
29013   return SDValue();
29014 }
29015 
29016 /// Given an ISD::SREM node expressing a remainder by constant power of 2,
29017 /// return a DAG expression that will generate the same value.
BuildSREMPow2(SDNode * N)29018 SDValue DAGCombiner::BuildSREMPow2(SDNode *N) {
29019   ConstantSDNode *C = isConstOrConstSplat(N->getOperand(1));
29020   if (!C)
29021     return SDValue();
29022 
29023   // Avoid division by zero.
29024   if (C->isZero())
29025     return SDValue();
29026 
29027   SmallVector<SDNode *, 8> Built;
29028   if (SDValue S = TLI.BuildSREMPow2(N, C->getAPIntValue(), DAG, Built)) {
29029     for (SDNode *N : Built)
29030       AddToWorklist(N);
29031     return S;
29032   }
29033 
29034   return SDValue();
29035 }
29036 
29037 // This is basically just a port of takeLog2 from InstCombineMulDivRem.cpp
29038 //
29039 // Returns the node that represents `Log2(Op)`. This may create a new node. If
29040 // we are unable to compute `Log2(Op)` its return `SDValue()`.
29041 //
29042 // All nodes will be created at `DL` and the output will be of type `VT`.
29043 //
29044 // This will only return `Log2(Op)` if we can prove `Op` is non-zero. Set
29045 // `AssumeNonZero` if this function should simply assume (not require proving
29046 // `Op` is non-zero).
takeInexpensiveLog2(SelectionDAG & DAG,const SDLoc & DL,EVT VT,SDValue Op,unsigned Depth,bool AssumeNonZero)29047 static SDValue takeInexpensiveLog2(SelectionDAG &DAG, const SDLoc &DL, EVT VT,
29048                                    SDValue Op, unsigned Depth,
29049                                    bool AssumeNonZero) {
29050   assert(VT.isInteger() && "Only integer types are supported!");
29051 
29052   auto PeekThroughCastsAndTrunc = [](SDValue V) {
29053     while (true) {
29054       switch (V.getOpcode()) {
29055       case ISD::TRUNCATE:
29056       case ISD::ZERO_EXTEND:
29057         V = V.getOperand(0);
29058         break;
29059       default:
29060         return V;
29061       }
29062     }
29063   };
29064 
29065   if (VT.isScalableVector())
29066     return SDValue();
29067 
29068   Op = PeekThroughCastsAndTrunc(Op);
29069 
29070   // Helper for determining whether a value is a power-2 constant scalar or a
29071   // vector of such elements.
29072   SmallVector<APInt> Pow2Constants;
29073   auto IsPowerOfTwo = [&Pow2Constants](ConstantSDNode *C) {
29074     if (C->isZero() || C->isOpaque())
29075       return false;
29076     // TODO: We may also be able to support negative powers of 2 here.
29077     if (C->getAPIntValue().isPowerOf2()) {
29078       Pow2Constants.emplace_back(C->getAPIntValue());
29079       return true;
29080     }
29081     return false;
29082   };
29083 
29084   if (ISD::matchUnaryPredicate(Op, IsPowerOfTwo)) {
29085     if (!VT.isVector())
29086       return DAG.getConstant(Pow2Constants.back().logBase2(), DL, VT);
29087     // We need to create a build vector
29088     if (Op.getOpcode() == ISD::SPLAT_VECTOR)
29089       return DAG.getSplat(VT, DL,
29090                           DAG.getConstant(Pow2Constants.back().logBase2(), DL,
29091                                           VT.getScalarType()));
29092     SmallVector<SDValue> Log2Ops;
29093     for (const APInt &Pow2 : Pow2Constants)
29094       Log2Ops.emplace_back(
29095           DAG.getConstant(Pow2.logBase2(), DL, VT.getScalarType()));
29096     return DAG.getBuildVector(VT, DL, Log2Ops);
29097   }
29098 
29099   if (Depth >= DAG.MaxRecursionDepth)
29100     return SDValue();
29101 
29102   auto CastToVT = [&](EVT NewVT, SDValue ToCast) {
29103     // Peek through zero extend. We can't peek through truncates since this
29104     // function is called on a shift amount. We must ensure that all of the bits
29105     // above the original shift amount are zeroed by this function.
29106     while (ToCast.getOpcode() == ISD::ZERO_EXTEND)
29107       ToCast = ToCast.getOperand(0);
29108     EVT CurVT = ToCast.getValueType();
29109     if (NewVT == CurVT)
29110       return ToCast;
29111 
29112     if (NewVT.getSizeInBits() == CurVT.getSizeInBits())
29113       return DAG.getBitcast(NewVT, ToCast);
29114 
29115     return DAG.getZExtOrTrunc(ToCast, DL, NewVT);
29116   };
29117 
29118   // log2(X << Y) -> log2(X) + Y
29119   if (Op.getOpcode() == ISD::SHL) {
29120     // 1 << Y and X nuw/nsw << Y are all non-zero.
29121     if (AssumeNonZero || Op->getFlags().hasNoUnsignedWrap() ||
29122         Op->getFlags().hasNoSignedWrap() || isOneConstant(Op.getOperand(0)))
29123       if (SDValue LogX = takeInexpensiveLog2(DAG, DL, VT, Op.getOperand(0),
29124                                              Depth + 1, AssumeNonZero))
29125         return DAG.getNode(ISD::ADD, DL, VT, LogX,
29126                            CastToVT(VT, Op.getOperand(1)));
29127   }
29128 
29129   // c ? X : Y -> c ? Log2(X) : Log2(Y)
29130   if ((Op.getOpcode() == ISD::SELECT || Op.getOpcode() == ISD::VSELECT) &&
29131       Op.hasOneUse()) {
29132     if (SDValue LogX = takeInexpensiveLog2(DAG, DL, VT, Op.getOperand(1),
29133                                            Depth + 1, AssumeNonZero))
29134       if (SDValue LogY = takeInexpensiveLog2(DAG, DL, VT, Op.getOperand(2),
29135                                              Depth + 1, AssumeNonZero))
29136         return DAG.getSelect(DL, VT, Op.getOperand(0), LogX, LogY);
29137   }
29138 
29139   // log2(umin(X, Y)) -> umin(log2(X), log2(Y))
29140   // log2(umax(X, Y)) -> umax(log2(X), log2(Y))
29141   if ((Op.getOpcode() == ISD::UMIN || Op.getOpcode() == ISD::UMAX) &&
29142       Op.hasOneUse()) {
29143     // Use AssumeNonZero as false here. Otherwise we can hit case where
29144     // log2(umax(X, Y)) != umax(log2(X), log2(Y)) (because overflow).
29145     if (SDValue LogX =
29146             takeInexpensiveLog2(DAG, DL, VT, Op.getOperand(0), Depth + 1,
29147                                 /*AssumeNonZero*/ false))
29148       if (SDValue LogY =
29149               takeInexpensiveLog2(DAG, DL, VT, Op.getOperand(1), Depth + 1,
29150                                   /*AssumeNonZero*/ false))
29151         return DAG.getNode(Op.getOpcode(), DL, VT, LogX, LogY);
29152   }
29153 
29154   return SDValue();
29155 }
29156 
29157 /// Determines the LogBase2 value for a non-null input value using the
29158 /// transform: LogBase2(V) = (EltBits - 1) - ctlz(V).
BuildLogBase2(SDValue V,const SDLoc & DL,bool KnownNonZero,bool InexpensiveOnly,std::optional<EVT> OutVT)29159 SDValue DAGCombiner::BuildLogBase2(SDValue V, const SDLoc &DL,
29160                                    bool KnownNonZero, bool InexpensiveOnly,
29161                                    std::optional<EVT> OutVT) {
29162   EVT VT = OutVT ? *OutVT : V.getValueType();
29163   SDValue InexpensiveLogBase2 =
29164       takeInexpensiveLog2(DAG, DL, VT, V, /*Depth*/ 0, KnownNonZero);
29165   if (InexpensiveLogBase2 || InexpensiveOnly || !DAG.isKnownToBeAPowerOfTwo(V))
29166     return InexpensiveLogBase2;
29167 
29168   SDValue Ctlz = DAG.getNode(ISD::CTLZ, DL, VT, V);
29169   SDValue Base = DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, VT);
29170   SDValue LogBase2 = DAG.getNode(ISD::SUB, DL, VT, Base, Ctlz);
29171   return LogBase2;
29172 }
29173 
29174 /// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
29175 /// For the reciprocal, we need to find the zero of the function:
29176 ///   F(X) = 1/X - A [which has a zero at X = 1/A]
29177 ///     =>
29178 ///   X_{i+1} = X_i (2 - A X_i) = X_i + X_i (1 - A X_i) [this second form
29179 ///     does not require additional intermediate precision]
29180 /// For the last iteration, put numerator N into it to gain more precision:
29181 ///   Result = N X_i + X_i (N - N A X_i)
BuildDivEstimate(SDValue N,SDValue Op,SDNodeFlags Flags)29182 SDValue DAGCombiner::BuildDivEstimate(SDValue N, SDValue Op,
29183                                       SDNodeFlags Flags) {
29184   if (LegalDAG)
29185     return SDValue();
29186 
29187   // TODO: Handle extended types?
29188   EVT VT = Op.getValueType();
29189   if (VT.getScalarType() != MVT::f16 && VT.getScalarType() != MVT::f32 &&
29190       VT.getScalarType() != MVT::f64)
29191     return SDValue();
29192 
29193   // If estimates are explicitly disabled for this function, we're done.
29194   MachineFunction &MF = DAG.getMachineFunction();
29195   int Enabled = TLI.getRecipEstimateDivEnabled(VT, MF);
29196   if (Enabled == TLI.ReciprocalEstimate::Disabled)
29197     return SDValue();
29198 
29199   // Estimates may be explicitly enabled for this type with a custom number of
29200   // refinement steps.
29201   int Iterations = TLI.getDivRefinementSteps(VT, MF);
29202   if (SDValue Est = TLI.getRecipEstimate(Op, DAG, Enabled, Iterations)) {
29203     AddToWorklist(Est.getNode());
29204 
29205     SDLoc DL(Op);
29206     if (Iterations) {
29207       SDValue FPOne = DAG.getConstantFP(1.0, DL, VT);
29208 
29209       // Newton iterations: Est = Est + Est (N - Arg * Est)
29210       // If this is the last iteration, also multiply by the numerator.
29211       for (int i = 0; i < Iterations; ++i) {
29212         SDValue MulEst = Est;
29213 
29214         if (i == Iterations - 1) {
29215           MulEst = DAG.getNode(ISD::FMUL, DL, VT, N, Est, Flags);
29216           AddToWorklist(MulEst.getNode());
29217         }
29218 
29219         SDValue NewEst = DAG.getNode(ISD::FMUL, DL, VT, Op, MulEst, Flags);
29220         AddToWorklist(NewEst.getNode());
29221 
29222         NewEst = DAG.getNode(ISD::FSUB, DL, VT,
29223                              (i == Iterations - 1 ? N : FPOne), NewEst, Flags);
29224         AddToWorklist(NewEst.getNode());
29225 
29226         NewEst = DAG.getNode(ISD::FMUL, DL, VT, Est, NewEst, Flags);
29227         AddToWorklist(NewEst.getNode());
29228 
29229         Est = DAG.getNode(ISD::FADD, DL, VT, MulEst, NewEst, Flags);
29230         AddToWorklist(Est.getNode());
29231       }
29232     } else {
29233       // If no iterations are available, multiply with N.
29234       Est = DAG.getNode(ISD::FMUL, DL, VT, Est, N, Flags);
29235       AddToWorklist(Est.getNode());
29236     }
29237 
29238     return Est;
29239   }
29240 
29241   return SDValue();
29242 }
29243 
29244 /// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
29245 /// For the reciprocal sqrt, we need to find the zero of the function:
29246 ///   F(X) = 1/X^2 - A [which has a zero at X = 1/sqrt(A)]
29247 ///     =>
29248 ///   X_{i+1} = X_i (1.5 - A X_i^2 / 2)
29249 /// As a result, we precompute A/2 prior to the iteration loop.
buildSqrtNROneConst(SDValue Arg,SDValue Est,unsigned Iterations,SDNodeFlags Flags,bool Reciprocal)29250 SDValue DAGCombiner::buildSqrtNROneConst(SDValue Arg, SDValue Est,
29251                                          unsigned Iterations,
29252                                          SDNodeFlags Flags, bool Reciprocal) {
29253   EVT VT = Arg.getValueType();
29254   SDLoc DL(Arg);
29255   SDValue ThreeHalves = DAG.getConstantFP(1.5, DL, VT);
29256 
29257   // We now need 0.5 * Arg which we can write as (1.5 * Arg - Arg) so that
29258   // this entire sequence requires only one FP constant.
29259   SDValue HalfArg = DAG.getNode(ISD::FMUL, DL, VT, ThreeHalves, Arg, Flags);
29260   HalfArg = DAG.getNode(ISD::FSUB, DL, VT, HalfArg, Arg, Flags);
29261 
29262   // Newton iterations: Est = Est * (1.5 - HalfArg * Est * Est)
29263   for (unsigned i = 0; i < Iterations; ++i) {
29264     SDValue NewEst = DAG.getNode(ISD::FMUL, DL, VT, Est, Est, Flags);
29265     NewEst = DAG.getNode(ISD::FMUL, DL, VT, HalfArg, NewEst, Flags);
29266     NewEst = DAG.getNode(ISD::FSUB, DL, VT, ThreeHalves, NewEst, Flags);
29267     Est = DAG.getNode(ISD::FMUL, DL, VT, Est, NewEst, Flags);
29268   }
29269 
29270   // If non-reciprocal square root is requested, multiply the result by Arg.
29271   if (!Reciprocal)
29272     Est = DAG.getNode(ISD::FMUL, DL, VT, Est, Arg, Flags);
29273 
29274   return Est;
29275 }
29276 
29277 /// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
29278 /// For the reciprocal sqrt, we need to find the zero of the function:
29279 ///   F(X) = 1/X^2 - A [which has a zero at X = 1/sqrt(A)]
29280 ///     =>
29281 ///   X_{i+1} = (-0.5 * X_i) * (A * X_i * X_i + (-3.0))
buildSqrtNRTwoConst(SDValue Arg,SDValue Est,unsigned Iterations,SDNodeFlags Flags,bool Reciprocal)29282 SDValue DAGCombiner::buildSqrtNRTwoConst(SDValue Arg, SDValue Est,
29283                                          unsigned Iterations,
29284                                          SDNodeFlags Flags, bool Reciprocal) {
29285   EVT VT = Arg.getValueType();
29286   SDLoc DL(Arg);
29287   SDValue MinusThree = DAG.getConstantFP(-3.0, DL, VT);
29288   SDValue MinusHalf = DAG.getConstantFP(-0.5, DL, VT);
29289 
29290   // This routine must enter the loop below to work correctly
29291   // when (Reciprocal == false).
29292   assert(Iterations > 0);
29293 
29294   // Newton iterations for reciprocal square root:
29295   // E = (E * -0.5) * ((A * E) * E + -3.0)
29296   for (unsigned i = 0; i < Iterations; ++i) {
29297     SDValue AE = DAG.getNode(ISD::FMUL, DL, VT, Arg, Est, Flags);
29298     SDValue AEE = DAG.getNode(ISD::FMUL, DL, VT, AE, Est, Flags);
29299     SDValue RHS = DAG.getNode(ISD::FADD, DL, VT, AEE, MinusThree, Flags);
29300 
29301     // When calculating a square root at the last iteration build:
29302     // S = ((A * E) * -0.5) * ((A * E) * E + -3.0)
29303     // (notice a common subexpression)
29304     SDValue LHS;
29305     if (Reciprocal || (i + 1) < Iterations) {
29306       // RSQRT: LHS = (E * -0.5)
29307       LHS = DAG.getNode(ISD::FMUL, DL, VT, Est, MinusHalf, Flags);
29308     } else {
29309       // SQRT: LHS = (A * E) * -0.5
29310       LHS = DAG.getNode(ISD::FMUL, DL, VT, AE, MinusHalf, Flags);
29311     }
29312 
29313     Est = DAG.getNode(ISD::FMUL, DL, VT, LHS, RHS, Flags);
29314   }
29315 
29316   return Est;
29317 }
29318 
29319 /// Build code to calculate either rsqrt(Op) or sqrt(Op). In the latter case
29320 /// Op*rsqrt(Op) is actually computed, so additional postprocessing is needed if
29321 /// Op can be zero.
buildSqrtEstimateImpl(SDValue Op,SDNodeFlags Flags,bool Reciprocal)29322 SDValue DAGCombiner::buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags,
29323                                            bool Reciprocal) {
29324   if (LegalDAG)
29325     return SDValue();
29326 
29327   // TODO: Handle extended types?
29328   EVT VT = Op.getValueType();
29329   if (VT.getScalarType() != MVT::f16 && VT.getScalarType() != MVT::f32 &&
29330       VT.getScalarType() != MVT::f64)
29331     return SDValue();
29332 
29333   // If estimates are explicitly disabled for this function, we're done.
29334   MachineFunction &MF = DAG.getMachineFunction();
29335   int Enabled = TLI.getRecipEstimateSqrtEnabled(VT, MF);
29336   if (Enabled == TLI.ReciprocalEstimate::Disabled)
29337     return SDValue();
29338 
29339   // Estimates may be explicitly enabled for this type with a custom number of
29340   // refinement steps.
29341   int Iterations = TLI.getSqrtRefinementSteps(VT, MF);
29342 
29343   bool UseOneConstNR = false;
29344   if (SDValue Est =
29345       TLI.getSqrtEstimate(Op, DAG, Enabled, Iterations, UseOneConstNR,
29346                           Reciprocal)) {
29347     AddToWorklist(Est.getNode());
29348 
29349     if (Iterations > 0)
29350       Est = UseOneConstNR
29351             ? buildSqrtNROneConst(Op, Est, Iterations, Flags, Reciprocal)
29352             : buildSqrtNRTwoConst(Op, Est, Iterations, Flags, Reciprocal);
29353     if (!Reciprocal) {
29354       SDLoc DL(Op);
29355       // Try the target specific test first.
29356       SDValue Test = TLI.getSqrtInputTest(Op, DAG, DAG.getDenormalMode(VT));
29357 
29358       // The estimate is now completely wrong if the input was exactly 0.0 or
29359       // possibly a denormal. Force the answer to 0.0 or value provided by
29360       // target for those cases.
29361       Est = DAG.getSelect(DL, VT, Test,
29362                           TLI.getSqrtResultForDenormInput(Op, DAG), Est);
29363     }
29364     return Est;
29365   }
29366 
29367   return SDValue();
29368 }
29369 
buildRsqrtEstimate(SDValue Op,SDNodeFlags Flags)29370 SDValue DAGCombiner::buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags) {
29371   return buildSqrtEstimateImpl(Op, Flags, true);
29372 }
29373 
buildSqrtEstimate(SDValue Op,SDNodeFlags Flags)29374 SDValue DAGCombiner::buildSqrtEstimate(SDValue Op, SDNodeFlags Flags) {
29375   return buildSqrtEstimateImpl(Op, Flags, false);
29376 }
29377 
29378 /// Return true if there is any possibility that the two addresses overlap.
mayAlias(SDNode * Op0,SDNode * Op1) const29379 bool DAGCombiner::mayAlias(SDNode *Op0, SDNode *Op1) const {
29380 
29381   struct MemUseCharacteristics {
29382     bool IsVolatile;
29383     bool IsAtomic;
29384     SDValue BasePtr;
29385     int64_t Offset;
29386     LocationSize NumBytes;
29387     MachineMemOperand *MMO;
29388   };
29389 
29390   auto getCharacteristics = [](SDNode *N) -> MemUseCharacteristics {
29391     if (const auto *LSN = dyn_cast<LSBaseSDNode>(N)) {
29392       int64_t Offset = 0;
29393       if (auto *C = dyn_cast<ConstantSDNode>(LSN->getOffset()))
29394         Offset = (LSN->getAddressingMode() == ISD::PRE_INC) ? C->getSExtValue()
29395                  : (LSN->getAddressingMode() == ISD::PRE_DEC)
29396                      ? -1 * C->getSExtValue()
29397                      : 0;
29398       TypeSize Size = LSN->getMemoryVT().getStoreSize();
29399       return {LSN->isVolatile(),           LSN->isAtomic(),
29400               LSN->getBasePtr(),           Offset /*base offset*/,
29401               LocationSize::precise(Size), LSN->getMemOperand()};
29402     }
29403     if (const auto *LN = cast<LifetimeSDNode>(N))
29404       return {false /*isVolatile*/,
29405               /*isAtomic*/ false,
29406               LN->getOperand(1),
29407               (LN->hasOffset()) ? LN->getOffset() : 0,
29408               (LN->hasOffset()) ? LocationSize::precise(LN->getSize())
29409                                 : LocationSize::beforeOrAfterPointer(),
29410               (MachineMemOperand *)nullptr};
29411     // Default.
29412     return {false /*isvolatile*/,
29413             /*isAtomic*/ false,
29414             SDValue(),
29415             (int64_t)0 /*offset*/,
29416             LocationSize::beforeOrAfterPointer() /*size*/,
29417             (MachineMemOperand *)nullptr};
29418   };
29419 
29420   MemUseCharacteristics MUC0 = getCharacteristics(Op0),
29421                         MUC1 = getCharacteristics(Op1);
29422 
29423   // If they are to the same address, then they must be aliases.
29424   if (MUC0.BasePtr.getNode() && MUC0.BasePtr == MUC1.BasePtr &&
29425       MUC0.Offset == MUC1.Offset)
29426     return true;
29427 
29428   // If they are both volatile then they cannot be reordered.
29429   if (MUC0.IsVolatile && MUC1.IsVolatile)
29430     return true;
29431 
29432   // Be conservative about atomics for the moment
29433   // TODO: This is way overconservative for unordered atomics (see D66309)
29434   if (MUC0.IsAtomic && MUC1.IsAtomic)
29435     return true;
29436 
29437   if (MUC0.MMO && MUC1.MMO) {
29438     if ((MUC0.MMO->isInvariant() && MUC1.MMO->isStore()) ||
29439         (MUC1.MMO->isInvariant() && MUC0.MMO->isStore()))
29440       return false;
29441   }
29442 
29443   // If NumBytes is scalable and offset is not 0, conservatively return may
29444   // alias
29445   if ((MUC0.NumBytes.hasValue() && MUC0.NumBytes.isScalable() &&
29446        MUC0.Offset != 0) ||
29447       (MUC1.NumBytes.hasValue() && MUC1.NumBytes.isScalable() &&
29448        MUC1.Offset != 0))
29449     return true;
29450   // Try to prove that there is aliasing, or that there is no aliasing. Either
29451   // way, we can return now. If nothing can be proved, proceed with more tests.
29452   bool IsAlias;
29453   if (BaseIndexOffset::computeAliasing(Op0, MUC0.NumBytes, Op1, MUC1.NumBytes,
29454                                        DAG, IsAlias))
29455     return IsAlias;
29456 
29457   // The following all rely on MMO0 and MMO1 being valid. Fail conservatively if
29458   // either are not known.
29459   if (!MUC0.MMO || !MUC1.MMO)
29460     return true;
29461 
29462   // If one operation reads from invariant memory, and the other may store, they
29463   // cannot alias. These should really be checking the equivalent of mayWrite,
29464   // but it only matters for memory nodes other than load /store.
29465   if ((MUC0.MMO->isInvariant() && MUC1.MMO->isStore()) ||
29466       (MUC1.MMO->isInvariant() && MUC0.MMO->isStore()))
29467     return false;
29468 
29469   // If we know required SrcValue1 and SrcValue2 have relatively large
29470   // alignment compared to the size and offset of the access, we may be able
29471   // to prove they do not alias. This check is conservative for now to catch
29472   // cases created by splitting vector types, it only works when the offsets are
29473   // multiples of the size of the data.
29474   int64_t SrcValOffset0 = MUC0.MMO->getOffset();
29475   int64_t SrcValOffset1 = MUC1.MMO->getOffset();
29476   Align OrigAlignment0 = MUC0.MMO->getBaseAlign();
29477   Align OrigAlignment1 = MUC1.MMO->getBaseAlign();
29478   LocationSize Size0 = MUC0.NumBytes;
29479   LocationSize Size1 = MUC1.NumBytes;
29480 
29481   if (OrigAlignment0 == OrigAlignment1 && SrcValOffset0 != SrcValOffset1 &&
29482       Size0.hasValue() && Size1.hasValue() && !Size0.isScalable() &&
29483       !Size1.isScalable() && Size0 == Size1 &&
29484       OrigAlignment0 > Size0.getValue().getKnownMinValue() &&
29485       SrcValOffset0 % Size0.getValue().getKnownMinValue() == 0 &&
29486       SrcValOffset1 % Size1.getValue().getKnownMinValue() == 0) {
29487     int64_t OffAlign0 = SrcValOffset0 % OrigAlignment0.value();
29488     int64_t OffAlign1 = SrcValOffset1 % OrigAlignment1.value();
29489 
29490     // There is no overlap between these relatively aligned accesses of
29491     // similar size. Return no alias.
29492     if ((OffAlign0 + static_cast<int64_t>(
29493                          Size0.getValue().getKnownMinValue())) <= OffAlign1 ||
29494         (OffAlign1 + static_cast<int64_t>(
29495                          Size1.getValue().getKnownMinValue())) <= OffAlign0)
29496       return false;
29497   }
29498 
29499   bool UseAA = CombinerGlobalAA.getNumOccurrences() > 0
29500                    ? CombinerGlobalAA
29501                    : DAG.getSubtarget().useAA();
29502 #ifndef NDEBUG
29503   if (CombinerAAOnlyFunc.getNumOccurrences() &&
29504       CombinerAAOnlyFunc != DAG.getMachineFunction().getName())
29505     UseAA = false;
29506 #endif
29507 
29508   if (UseAA && BatchAA && MUC0.MMO->getValue() && MUC1.MMO->getValue() &&
29509       Size0.hasValue() && Size1.hasValue() &&
29510       // Can't represent a scalable size + fixed offset in LocationSize
29511       (!Size0.isScalable() || SrcValOffset0 == 0) &&
29512       (!Size1.isScalable() || SrcValOffset1 == 0)) {
29513     // Use alias analysis information.
29514     int64_t MinOffset = std::min(SrcValOffset0, SrcValOffset1);
29515     int64_t Overlap0 =
29516         Size0.getValue().getKnownMinValue() + SrcValOffset0 - MinOffset;
29517     int64_t Overlap1 =
29518         Size1.getValue().getKnownMinValue() + SrcValOffset1 - MinOffset;
29519     LocationSize Loc0 =
29520         Size0.isScalable() ? Size0 : LocationSize::precise(Overlap0);
29521     LocationSize Loc1 =
29522         Size1.isScalable() ? Size1 : LocationSize::precise(Overlap1);
29523     if (BatchAA->isNoAlias(
29524             MemoryLocation(MUC0.MMO->getValue(), Loc0,
29525                            UseTBAA ? MUC0.MMO->getAAInfo() : AAMDNodes()),
29526             MemoryLocation(MUC1.MMO->getValue(), Loc1,
29527                            UseTBAA ? MUC1.MMO->getAAInfo() : AAMDNodes())))
29528       return false;
29529   }
29530 
29531   // Otherwise we have to assume they alias.
29532   return true;
29533 }
29534 
29535 /// Walk up chain skipping non-aliasing memory nodes,
29536 /// looking for aliasing nodes and adding them to the Aliases vector.
GatherAllAliases(SDNode * N,SDValue OriginalChain,SmallVectorImpl<SDValue> & Aliases)29537 void DAGCombiner::GatherAllAliases(SDNode *N, SDValue OriginalChain,
29538                                    SmallVectorImpl<SDValue> &Aliases) {
29539   SmallVector<SDValue, 8> Chains;     // List of chains to visit.
29540   SmallPtrSet<SDNode *, 16> Visited;  // Visited node set.
29541 
29542   // Get alias information for node.
29543   // TODO: relax aliasing for unordered atomics (see D66309)
29544   const bool IsLoad = isa<LoadSDNode>(N) && cast<LoadSDNode>(N)->isSimple();
29545 
29546   // Starting off.
29547   Chains.push_back(OriginalChain);
29548   unsigned Depth = 0;
29549 
29550   // Attempt to improve chain by a single step
29551   auto ImproveChain = [&](SDValue &C) -> bool {
29552     switch (C.getOpcode()) {
29553     case ISD::EntryToken:
29554       // No need to mark EntryToken.
29555       C = SDValue();
29556       return true;
29557     case ISD::LOAD:
29558     case ISD::STORE: {
29559       // Get alias information for C.
29560       // TODO: Relax aliasing for unordered atomics (see D66309)
29561       bool IsOpLoad = isa<LoadSDNode>(C.getNode()) &&
29562                       cast<LSBaseSDNode>(C.getNode())->isSimple();
29563       if ((IsLoad && IsOpLoad) || !mayAlias(N, C.getNode())) {
29564         // Look further up the chain.
29565         C = C.getOperand(0);
29566         return true;
29567       }
29568       // Alias, so stop here.
29569       return false;
29570     }
29571 
29572     case ISD::CopyFromReg:
29573       // Always forward past CopyFromReg.
29574       C = C.getOperand(0);
29575       return true;
29576 
29577     case ISD::LIFETIME_START:
29578     case ISD::LIFETIME_END: {
29579       // We can forward past any lifetime start/end that can be proven not to
29580       // alias the memory access.
29581       if (!mayAlias(N, C.getNode())) {
29582         // Look further up the chain.
29583         C = C.getOperand(0);
29584         return true;
29585       }
29586       return false;
29587     }
29588     default:
29589       return false;
29590     }
29591   };
29592 
29593   // Look at each chain and determine if it is an alias.  If so, add it to the
29594   // aliases list.  If not, then continue up the chain looking for the next
29595   // candidate.
29596   while (!Chains.empty()) {
29597     SDValue Chain = Chains.pop_back_val();
29598 
29599     // Don't bother if we've seen Chain before.
29600     if (!Visited.insert(Chain.getNode()).second)
29601       continue;
29602 
29603     // For TokenFactor nodes, look at each operand and only continue up the
29604     // chain until we reach the depth limit.
29605     //
29606     // FIXME: The depth check could be made to return the last non-aliasing
29607     // chain we found before we hit a tokenfactor rather than the original
29608     // chain.
29609     if (Depth > TLI.getGatherAllAliasesMaxDepth()) {
29610       Aliases.clear();
29611       Aliases.push_back(OriginalChain);
29612       return;
29613     }
29614 
29615     if (Chain.getOpcode() == ISD::TokenFactor) {
29616       // We have to check each of the operands of the token factor for "small"
29617       // token factors, so we queue them up.  Adding the operands to the queue
29618       // (stack) in reverse order maintains the original order and increases the
29619       // likelihood that getNode will find a matching token factor (CSE.)
29620       if (Chain.getNumOperands() > 16) {
29621         Aliases.push_back(Chain);
29622         continue;
29623       }
29624       for (unsigned n = Chain.getNumOperands(); n;)
29625         Chains.push_back(Chain.getOperand(--n));
29626       ++Depth;
29627       continue;
29628     }
29629     // Everything else
29630     if (ImproveChain(Chain)) {
29631       // Updated Chain Found, Consider new chain if one exists.
29632       if (Chain.getNode())
29633         Chains.push_back(Chain);
29634       ++Depth;
29635       continue;
29636     }
29637     // No Improved Chain Possible, treat as Alias.
29638     Aliases.push_back(Chain);
29639   }
29640 }
29641 
29642 /// Walk up chain skipping non-aliasing memory nodes, looking for a better chain
29643 /// (aliasing node.)
FindBetterChain(SDNode * N,SDValue OldChain)29644 SDValue DAGCombiner::FindBetterChain(SDNode *N, SDValue OldChain) {
29645   if (OptLevel == CodeGenOptLevel::None)
29646     return OldChain;
29647 
29648   // Ops for replacing token factor.
29649   SmallVector<SDValue, 8> Aliases;
29650 
29651   // Accumulate all the aliases to this node.
29652   GatherAllAliases(N, OldChain, Aliases);
29653 
29654   // If no operands then chain to entry token.
29655   if (Aliases.empty())
29656     return DAG.getEntryNode();
29657 
29658   // If a single operand then chain to it.  We don't need to revisit it.
29659   if (Aliases.size() == 1)
29660     return Aliases[0];
29661 
29662   // Construct a custom tailored token factor.
29663   return DAG.getTokenFactor(SDLoc(N), Aliases);
29664 }
29665 
29666 // This function tries to collect a bunch of potentially interesting
29667 // nodes to improve the chains of, all at once. This might seem
29668 // redundant, as this function gets called when visiting every store
29669 // node, so why not let the work be done on each store as it's visited?
29670 //
29671 // I believe this is mainly important because mergeConsecutiveStores
29672 // is unable to deal with merging stores of different sizes, so unless
29673 // we improve the chains of all the potential candidates up-front
29674 // before running mergeConsecutiveStores, it might only see some of
29675 // the nodes that will eventually be candidates, and then not be able
29676 // to go from a partially-merged state to the desired final
29677 // fully-merged state.
29678 
parallelizeChainedStores(StoreSDNode * St)29679 bool DAGCombiner::parallelizeChainedStores(StoreSDNode *St) {
29680   SmallVector<StoreSDNode *, 8> ChainedStores;
29681   StoreSDNode *STChain = St;
29682   // Intervals records which offsets from BaseIndex have been covered. In
29683   // the common case, every store writes to the immediately previous address
29684   // space and thus merged with the previous interval at insertion time.
29685 
29686   using IMap = llvm::IntervalMap<int64_t, std::monostate, 8,
29687                                  IntervalMapHalfOpenInfo<int64_t>>;
29688   IMap::Allocator A;
29689   IMap Intervals(A);
29690 
29691   // This holds the base pointer, index, and the offset in bytes from the base
29692   // pointer.
29693   const BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);
29694 
29695   // We must have a base and an offset.
29696   if (!BasePtr.getBase().getNode())
29697     return false;
29698 
29699   // Do not handle stores to undef base pointers.
29700   if (BasePtr.getBase().isUndef())
29701     return false;
29702 
29703   // Do not handle stores to opaque types
29704   if (St->getMemoryVT().isZeroSized())
29705     return false;
29706 
29707   // BaseIndexOffset assumes that offsets are fixed-size, which
29708   // is not valid for scalable vectors where the offsets are
29709   // scaled by `vscale`, so bail out early.
29710   if (St->getMemoryVT().isScalableVT())
29711     return false;
29712 
29713   // Add ST's interval.
29714   Intervals.insert(0, (St->getMemoryVT().getSizeInBits() + 7) / 8,
29715                    std::monostate{});
29716 
29717   while (StoreSDNode *Chain = dyn_cast<StoreSDNode>(STChain->getChain())) {
29718     if (Chain->getMemoryVT().isScalableVector())
29719       return false;
29720 
29721     // If the chain has more than one use, then we can't reorder the mem ops.
29722     if (!SDValue(Chain, 0)->hasOneUse())
29723       break;
29724     // TODO: Relax for unordered atomics (see D66309)
29725     if (!Chain->isSimple() || Chain->isIndexed())
29726       break;
29727 
29728     // Find the base pointer and offset for this memory node.
29729     const BaseIndexOffset Ptr = BaseIndexOffset::match(Chain, DAG);
29730     // Check that the base pointer is the same as the original one.
29731     int64_t Offset;
29732     if (!BasePtr.equalBaseIndex(Ptr, DAG, Offset))
29733       break;
29734     int64_t Length = (Chain->getMemoryVT().getSizeInBits() + 7) / 8;
29735     // Make sure we don't overlap with other intervals by checking the ones to
29736     // the left or right before inserting.
29737     auto I = Intervals.find(Offset);
29738     // If there's a next interval, we should end before it.
29739     if (I != Intervals.end() && I.start() < (Offset + Length))
29740       break;
29741     // If there's a previous interval, we should start after it.
29742     if (I != Intervals.begin() && (--I).stop() <= Offset)
29743       break;
29744     Intervals.insert(Offset, Offset + Length, std::monostate{});
29745 
29746     ChainedStores.push_back(Chain);
29747     STChain = Chain;
29748   }
29749 
29750   // If we didn't find a chained store, exit.
29751   if (ChainedStores.empty())
29752     return false;
29753 
29754   // Improve all chained stores (St and ChainedStores members) starting from
29755   // where the store chain ended and return single TokenFactor.
29756   SDValue NewChain = STChain->getChain();
29757   SmallVector<SDValue, 8> TFOps;
29758   for (unsigned I = ChainedStores.size(); I;) {
29759     StoreSDNode *S = ChainedStores[--I];
29760     SDValue BetterChain = FindBetterChain(S, NewChain);
29761     S = cast<StoreSDNode>(DAG.UpdateNodeOperands(
29762         S, BetterChain, S->getOperand(1), S->getOperand(2), S->getOperand(3)));
29763     TFOps.push_back(SDValue(S, 0));
29764     ChainedStores[I] = S;
29765   }
29766 
29767   // Improve St's chain. Use a new node to avoid creating a loop from CombineTo.
29768   SDValue BetterChain = FindBetterChain(St, NewChain);
29769   SDValue NewST;
29770   if (St->isTruncatingStore())
29771     NewST = DAG.getTruncStore(BetterChain, SDLoc(St), St->getValue(),
29772                               St->getBasePtr(), St->getMemoryVT(),
29773                               St->getMemOperand());
29774   else
29775     NewST = DAG.getStore(BetterChain, SDLoc(St), St->getValue(),
29776                          St->getBasePtr(), St->getMemOperand());
29777 
29778   TFOps.push_back(NewST);
29779 
29780   // If we improved every element of TFOps, then we've lost the dependence on
29781   // NewChain to successors of St and we need to add it back to TFOps. Do so at
29782   // the beginning to keep relative order consistent with FindBetterChains.
29783   auto hasImprovedChain = [&](SDValue ST) -> bool {
29784     return ST->getOperand(0) != NewChain;
29785   };
29786   bool AddNewChain = llvm::all_of(TFOps, hasImprovedChain);
29787   if (AddNewChain)
29788     TFOps.insert(TFOps.begin(), NewChain);
29789 
29790   SDValue TF = DAG.getTokenFactor(SDLoc(STChain), TFOps);
29791   CombineTo(St, TF);
29792 
29793   // Add TF and its operands to the worklist.
29794   AddToWorklist(TF.getNode());
29795   for (const SDValue &Op : TF->ops())
29796     AddToWorklist(Op.getNode());
29797   AddToWorklist(STChain);
29798   return true;
29799 }
29800 
findBetterNeighborChains(StoreSDNode * St)29801 bool DAGCombiner::findBetterNeighborChains(StoreSDNode *St) {
29802   if (OptLevel == CodeGenOptLevel::None)
29803     return false;
29804 
29805   const BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);
29806 
29807   // We must have a base and an offset.
29808   if (!BasePtr.getBase().getNode())
29809     return false;
29810 
29811   // Do not handle stores to undef base pointers.
29812   if (BasePtr.getBase().isUndef())
29813     return false;
29814 
29815   // Directly improve a chain of disjoint stores starting at St.
29816   if (parallelizeChainedStores(St))
29817     return true;
29818 
29819   // Improve St's Chain..
29820   SDValue BetterChain = FindBetterChain(St, St->getChain());
29821   if (St->getChain() != BetterChain) {
29822     replaceStoreChain(St, BetterChain);
29823     return true;
29824   }
29825   return false;
29826 }
29827 
29828 /// This is the entry point for the file.
Combine(CombineLevel Level,BatchAAResults * BatchAA,CodeGenOptLevel OptLevel)29829 void SelectionDAG::Combine(CombineLevel Level, BatchAAResults *BatchAA,
29830                            CodeGenOptLevel OptLevel) {
29831   /// This is the main entry point to this class.
29832   DAGCombiner(*this, BatchAA, OptLevel).Run(Level);
29833 }
29834