1 //===- DAGCombiner.cpp - Implement a DAG node combiner --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass combines dag nodes to form fewer, simpler DAG nodes. It can be run
10 // both before and after the DAG is legalized.
11 //
12 // This pass is not a substitute for the LLVM IR instcombine pass. This pass is
13 // primarily intended to handle simplification opportunities that are implicit
14 // in the LLVM IR and exposed by the various codegen lowering phases.
15 //
16 //===----------------------------------------------------------------------===//
17
18 #include "llvm/ADT/APFloat.h"
19 #include "llvm/ADT/APInt.h"
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/DenseMap.h"
22 #include "llvm/ADT/IntervalMap.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/SetVector.h"
25 #include "llvm/ADT/SmallBitVector.h"
26 #include "llvm/ADT/SmallPtrSet.h"
27 #include "llvm/ADT/SmallSet.h"
28 #include "llvm/ADT/SmallVector.h"
29 #include "llvm/ADT/Statistic.h"
30 #include "llvm/Analysis/AliasAnalysis.h"
31 #include "llvm/Analysis/MemoryLocation.h"
32 #include "llvm/Analysis/TargetLibraryInfo.h"
33 #include "llvm/Analysis/ValueTracking.h"
34 #include "llvm/Analysis/VectorUtils.h"
35 #include "llvm/CodeGen/ByteProvider.h"
36 #include "llvm/CodeGen/DAGCombine.h"
37 #include "llvm/CodeGen/ISDOpcodes.h"
38 #include "llvm/CodeGen/MachineFunction.h"
39 #include "llvm/CodeGen/MachineMemOperand.h"
40 #include "llvm/CodeGen/SDPatternMatch.h"
41 #include "llvm/CodeGen/SelectionDAG.h"
42 #include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
43 #include "llvm/CodeGen/SelectionDAGNodes.h"
44 #include "llvm/CodeGen/SelectionDAGTargetInfo.h"
45 #include "llvm/CodeGen/TargetLowering.h"
46 #include "llvm/CodeGen/TargetRegisterInfo.h"
47 #include "llvm/CodeGen/TargetSubtargetInfo.h"
48 #include "llvm/CodeGen/ValueTypes.h"
49 #include "llvm/CodeGenTypes/MachineValueType.h"
50 #include "llvm/IR/Attributes.h"
51 #include "llvm/IR/Constant.h"
52 #include "llvm/IR/DataLayout.h"
53 #include "llvm/IR/DerivedTypes.h"
54 #include "llvm/IR/Function.h"
55 #include "llvm/IR/Metadata.h"
56 #include "llvm/Support/Casting.h"
57 #include "llvm/Support/CodeGen.h"
58 #include "llvm/Support/CommandLine.h"
59 #include "llvm/Support/Compiler.h"
60 #include "llvm/Support/Debug.h"
61 #include "llvm/Support/DebugCounter.h"
62 #include "llvm/Support/ErrorHandling.h"
63 #include "llvm/Support/KnownBits.h"
64 #include "llvm/Support/MathExtras.h"
65 #include "llvm/Support/raw_ostream.h"
66 #include "llvm/Target/TargetMachine.h"
67 #include "llvm/Target/TargetOptions.h"
68 #include <algorithm>
69 #include <cassert>
70 #include <cstdint>
71 #include <functional>
72 #include <iterator>
73 #include <optional>
74 #include <string>
75 #include <tuple>
76 #include <utility>
77 #include <variant>
78
79 #include "MatchContext.h"
80
81 using namespace llvm;
82 using namespace llvm::SDPatternMatch;
83
84 #define DEBUG_TYPE "dagcombine"
85
86 STATISTIC(NodesCombined , "Number of dag nodes combined");
87 STATISTIC(PreIndexedNodes , "Number of pre-indexed nodes created");
88 STATISTIC(PostIndexedNodes, "Number of post-indexed nodes created");
89 STATISTIC(OpsNarrowed , "Number of load/op/store narrowed");
90 STATISTIC(LdStFP2Int , "Number of fp load/store pairs transformed to int");
91 STATISTIC(SlicedLoads, "Number of load sliced");
92 STATISTIC(NumFPLogicOpsConv, "Number of logic ops converted to fp ops");
93
94 DEBUG_COUNTER(DAGCombineCounter, "dagcombine",
95 "Controls whether a DAG combine is performed for a node");
96
97 static cl::opt<bool>
98 CombinerGlobalAA("combiner-global-alias-analysis", cl::Hidden,
99 cl::desc("Enable DAG combiner's use of IR alias analysis"));
100
101 static cl::opt<bool>
102 UseTBAA("combiner-use-tbaa", cl::Hidden, cl::init(true),
103 cl::desc("Enable DAG combiner's use of TBAA"));
104
105 #ifndef NDEBUG
106 static cl::opt<std::string>
107 CombinerAAOnlyFunc("combiner-aa-only-func", cl::Hidden,
108 cl::desc("Only use DAG-combiner alias analysis in this"
109 " function"));
110 #endif
111
112 /// Hidden option to stress test load slicing, i.e., when this option
113 /// is enabled, load slicing bypasses most of its profitability guards.
114 static cl::opt<bool>
115 StressLoadSlicing("combiner-stress-load-slicing", cl::Hidden,
116 cl::desc("Bypass the profitability model of load slicing"),
117 cl::init(false));
118
119 static cl::opt<bool>
120 MaySplitLoadIndex("combiner-split-load-index", cl::Hidden, cl::init(true),
121 cl::desc("DAG combiner may split indexing from loads"));
122
123 static cl::opt<bool>
124 EnableStoreMerging("combiner-store-merging", cl::Hidden, cl::init(true),
125 cl::desc("DAG combiner enable merging multiple stores "
126 "into a wider store"));
127
128 static cl::opt<unsigned> TokenFactorInlineLimit(
129 "combiner-tokenfactor-inline-limit", cl::Hidden, cl::init(2048),
130 cl::desc("Limit the number of operands to inline for Token Factors"));
131
132 static cl::opt<unsigned> StoreMergeDependenceLimit(
133 "combiner-store-merge-dependence-limit", cl::Hidden, cl::init(10),
134 cl::desc("Limit the number of times for the same StoreNode and RootNode "
135 "to bail out in store merging dependence check"));
136
137 static cl::opt<bool> EnableReduceLoadOpStoreWidth(
138 "combiner-reduce-load-op-store-width", cl::Hidden, cl::init(true),
139 cl::desc("DAG combiner enable reducing the width of load/op/store "
140 "sequence"));
141 static cl::opt<bool> ReduceLoadOpStoreWidthForceNarrowingProfitable(
142 "combiner-reduce-load-op-store-width-force-narrowing-profitable",
143 cl::Hidden, cl::init(false),
144 cl::desc("DAG combiner force override the narrowing profitable check when "
145 "reducing the width of load/op/store sequences"));
146
147 static cl::opt<bool> EnableShrinkLoadReplaceStoreWithStore(
148 "combiner-shrink-load-replace-store-with-store", cl::Hidden, cl::init(true),
149 cl::desc("DAG combiner enable load/<replace bytes>/store with "
150 "a narrower store"));
151
152 static cl::opt<bool> DisableCombines("combiner-disabled", cl::Hidden,
153 cl::init(false),
154 cl::desc("Disable the DAG combiner"));
155
156 namespace {
157
158 class DAGCombiner {
159 SelectionDAG &DAG;
160 const TargetLowering &TLI;
161 const SelectionDAGTargetInfo *STI;
162 CombineLevel Level = BeforeLegalizeTypes;
163 CodeGenOptLevel OptLevel;
164 bool LegalDAG = false;
165 bool LegalOperations = false;
166 bool LegalTypes = false;
167 bool ForCodeSize;
168 bool DisableGenericCombines;
169
170 /// Worklist of all of the nodes that need to be simplified.
171 ///
172 /// This must behave as a stack -- new nodes to process are pushed onto the
173 /// back and when processing we pop off of the back.
174 ///
175 /// The worklist will not contain duplicates but may contain null entries
176 /// due to nodes being deleted from the underlying DAG. For fast lookup and
177 /// deduplication, the index of the node in this vector is stored in the
178 /// node in SDNode::CombinerWorklistIndex.
179 SmallVector<SDNode *, 64> Worklist;
180
/// This records all nodes attempted to be added to the worklist since we
/// considered a new worklist entry. As we do not add duplicate nodes
/// to the worklist, this is different from the tail of the worklist.
184 SmallSetVector<SDNode *, 32> PruningList;
185
186 /// Map from candidate StoreNode to the pair of RootNode and count.
187 /// The count is used to track how many times we have seen the StoreNode
188 /// with the same RootNode bail out in dependence check. If we have seen
189 /// the bail out for the same pair many times over a limit, we won't
190 /// consider the StoreNode with the same RootNode as store merging
191 /// candidate again.
192 DenseMap<SDNode *, std::pair<SDNode *, unsigned>> StoreRootCountMap;
193
194 // BatchAA - Used for DAG load/store alias analysis.
195 BatchAAResults *BatchAA;
196
197 /// This caches all chains that have already been processed in
198 /// DAGCombiner::getStoreMergeCandidates() and found to have no mergeable
199 /// stores candidates.
200 SmallPtrSet<SDNode *, 4> ChainsWithoutMergeableStores;
201
202 /// When an instruction is simplified, add all users of the instruction to
203 /// the work lists because they might get more simplified now.
AddUsersToWorklist(SDNode * N)204 void AddUsersToWorklist(SDNode *N) {
205 for (SDNode *Node : N->users())
206 AddToWorklist(Node);
207 }
208
209 /// Convenient shorthand to add a node and all of its user to the worklist.
AddToWorklistWithUsers(SDNode * N)210 void AddToWorklistWithUsers(SDNode *N) {
211 AddUsersToWorklist(N);
212 AddToWorklist(N);
213 }
214
215 // Prune potentially dangling nodes. This is called after
216 // any visit to a node, but should also be called during a visit after any
217 // failed combine which may have created a DAG node.
clearAddedDanglingWorklistEntries()218 void clearAddedDanglingWorklistEntries() {
219 // Check any nodes added to the worklist to see if they are prunable.
220 while (!PruningList.empty()) {
221 auto *N = PruningList.pop_back_val();
222 if (N->use_empty())
223 recursivelyDeleteUnusedNodes(N);
224 }
225 }
226
getNextWorklistEntry()227 SDNode *getNextWorklistEntry() {
228 // Before we do any work, remove nodes that are not in use.
229 clearAddedDanglingWorklistEntries();
230 SDNode *N = nullptr;
231 // The Worklist holds the SDNodes in order, but it may contain null
232 // entries.
233 while (!N && !Worklist.empty()) {
234 N = Worklist.pop_back_val();
235 }
236
237 if (N) {
238 assert(N->getCombinerWorklistIndex() >= 0 &&
239 "Found a worklist entry without a corresponding map entry!");
240 // Set to -2 to indicate that we combined the node.
241 N->setCombinerWorklistIndex(-2);
242 }
243 return N;
244 }
245
246 /// Call the node-specific routine that folds each particular type of node.
247 SDValue visit(SDNode *N);
248
249 public:
/// Build a combiner for \p D.
///
/// \param D       DAG to combine; also the source of the target lowering
///                info and the subtarget's SelectionDAG info.
/// \param BatchAA Stored as-is in the BatchAA member.
/// \param OL      Optimization level, stored and passed to
///                disableGenericCombines().
///
/// Also computes MaximumLegalStoreInBits: the largest known-minimum bit size
/// among all legal value types other than MVT::Other.
DAGCombiner(SelectionDAG &D, BatchAAResults *BatchAA, CodeGenOptLevel OL)
    : DAG(D), TLI(D.getTargetLoweringInfo()),
      STI(D.getSubtarget().getSelectionDAGInfo()), OptLevel(OL),
      ForCodeSize(D.shouldOptForSize()),
      DisableGenericCombines(
          DisableCombines || (STI && STI->disableGenericCombines(OptLevel))),
      BatchAA(BatchAA), MaximumLegalStoreInBits(0) {
  // We use the minimum store size here, since that's all we can guarantee
  // for the scalable vector types.
  for (MVT VT : MVT::all_valuetypes()) {
    if (!EVT(VT).isSimple() || VT == MVT::Other ||
        !TLI.isTypeLegal(EVT(VT)))
      continue;

    const auto MinBits = VT.getSizeInBits().getKnownMinValue();
    if (MinBits >= MaximumLegalStoreInBits)
      MaximumLegalStoreInBits = MinBits;
  }
}
267
ConsiderForPruning(SDNode * N)268 void ConsiderForPruning(SDNode *N) {
269 // Mark this for potential pruning.
270 PruningList.insert(N);
271 }
272
273 /// Add to the worklist making sure its instance is at the back (next to be
274 /// processed.)
AddToWorklist(SDNode * N,bool IsCandidateForPruning=true,bool SkipIfCombinedBefore=false)275 void AddToWorklist(SDNode *N, bool IsCandidateForPruning = true,
276 bool SkipIfCombinedBefore = false) {
277 assert(N->getOpcode() != ISD::DELETED_NODE &&
278 "Deleted Node added to Worklist");
279
280 // Skip handle nodes as they can't usefully be combined and confuse the
281 // zero-use deletion strategy.
282 if (N->getOpcode() == ISD::HANDLENODE)
283 return;
284
285 if (SkipIfCombinedBefore && N->getCombinerWorklistIndex() == -2)
286 return;
287
288 if (IsCandidateForPruning)
289 ConsiderForPruning(N);
290
291 if (N->getCombinerWorklistIndex() < 0) {
292 N->setCombinerWorklistIndex(Worklist.size());
293 Worklist.push_back(N);
294 }
295 }
296
297 /// Remove all instances of N from the worklist.
removeFromWorklist(SDNode * N)298 void removeFromWorklist(SDNode *N) {
299 PruningList.remove(N);
300 StoreRootCountMap.erase(N);
301
302 int WorklistIndex = N->getCombinerWorklistIndex();
303 // If not in the worklist, the index might be -1 or -2 (was combined
304 // before). As the node gets deleted anyway, there's no need to update
305 // the index.
306 if (WorklistIndex < 0)
307 return; // Not in the worklist.
308
309 // Null out the entry rather than erasing it to avoid a linear operation.
310 Worklist[WorklistIndex] = nullptr;
311 N->setCombinerWorklistIndex(-1);
312 }
313
314 void deleteAndRecombine(SDNode *N);
315 bool recursivelyDeleteUnusedNodes(SDNode *N);
316
317 /// Replaces all uses of the results of one DAG node with new values.
318 SDValue CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
319 bool AddTo = true);
320
321 /// Replaces all uses of the results of one DAG node with new values.
CombineTo(SDNode * N,SDValue Res,bool AddTo=true)322 SDValue CombineTo(SDNode *N, SDValue Res, bool AddTo = true) {
323 return CombineTo(N, &Res, 1, AddTo);
324 }
325
326 /// Replaces all uses of the results of one DAG node with new values.
CombineTo(SDNode * N,SDValue Res0,SDValue Res1,bool AddTo=true)327 SDValue CombineTo(SDNode *N, SDValue Res0, SDValue Res1,
328 bool AddTo = true) {
329 SDValue To[] = { Res0, Res1 };
330 return CombineTo(N, To, 2, AddTo);
331 }
332
333 void CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO);
334
335 private:
336 unsigned MaximumLegalStoreInBits;
337
338 /// Check the specified integer node value to see if it can be simplified or
339 /// if things it uses can be simplified by bit propagation.
340 /// If so, return true.
SimplifyDemandedBits(SDValue Op)341 bool SimplifyDemandedBits(SDValue Op) {
342 unsigned BitWidth = Op.getScalarValueSizeInBits();
343 APInt DemandedBits = APInt::getAllOnes(BitWidth);
344 return SimplifyDemandedBits(Op, DemandedBits);
345 }
346
SimplifyDemandedBits(SDValue Op,const APInt & DemandedBits)347 bool SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits) {
348 EVT VT = Op.getValueType();
349 APInt DemandedElts = VT.isFixedLengthVector()
350 ? APInt::getAllOnes(VT.getVectorNumElements())
351 : APInt(1, 1);
352 return SimplifyDemandedBits(Op, DemandedBits, DemandedElts, false);
353 }
354
355 /// Check the specified vector node value to see if it can be simplified or
356 /// if things it uses can be simplified as it only uses some of the
357 /// elements. If so, return true.
SimplifyDemandedVectorElts(SDValue Op)358 bool SimplifyDemandedVectorElts(SDValue Op) {
359 // TODO: For now just pretend it cannot be simplified.
360 if (Op.getValueType().isScalableVector())
361 return false;
362
363 unsigned NumElts = Op.getValueType().getVectorNumElements();
364 APInt DemandedElts = APInt::getAllOnes(NumElts);
365 return SimplifyDemandedVectorElts(Op, DemandedElts);
366 }
367
368 bool SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
369 const APInt &DemandedElts,
370 bool AssumeSingleUse = false);
371 bool SimplifyDemandedVectorElts(SDValue Op, const APInt &DemandedElts,
372 bool AssumeSingleUse = false);
373
374 bool CombineToPreIndexedLoadStore(SDNode *N);
375 bool CombineToPostIndexedLoadStore(SDNode *N);
376 SDValue SplitIndexingFromLoad(LoadSDNode *LD);
377 bool SliceUpLoad(SDNode *N);
378
// Looks up the chain to find a unique (unaliased) store feeding the passed
// load. If no such store is found, returns a nullptr.
// Note: This will look past a CALLSEQ_START if the load is chained to it so
// that it can find stack stores for byval params.
383 StoreSDNode *getUniqueStoreFeeding(LoadSDNode *LD, int64_t &Offset);
384 // Scalars have size 0 to distinguish from singleton vectors.
385 SDValue ForwardStoreValueToDirectLoad(LoadSDNode *LD);
386 bool getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val);
387 bool extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val);
388
389 void ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad);
390 SDValue PromoteOperand(SDValue Op, EVT PVT, bool &Replace);
391 SDValue SExtPromoteOperand(SDValue Op, EVT PVT);
392 SDValue ZExtPromoteOperand(SDValue Op, EVT PVT);
393 SDValue PromoteIntBinOp(SDValue Op);
394 SDValue PromoteIntShiftOp(SDValue Op);
395 SDValue PromoteExtend(SDValue Op);
396 bool PromoteLoad(SDValue Op);
397
398 SDValue foldShiftToAvg(SDNode *N);
399 // Fold `a bitwiseop (~b +/- c)` -> `a bitwiseop ~(b -/+ c)`
400 SDValue foldBitwiseOpWithNeg(SDNode *N, const SDLoc &DL, EVT VT);
401
402 SDValue combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
403 SDValue RHS, SDValue True, SDValue False,
404 ISD::CondCode CC);
405
406 /// Call the node-specific routine that knows how to fold each
407 /// particular type of node. If that doesn't do anything, try the
408 /// target-specific DAG combines.
409 SDValue combine(SDNode *N);
410
411 // Visitation implementation - Implement dag node combining for different
412 // node types. The semantics are as follows:
413 // Return Value:
414 // SDValue.getNode() == 0 - No change was made
415 // SDValue.getNode() == N - N was replaced, is dead and has been handled.
416 // otherwise - N should be replaced by the returned Operand.
417 //
418 SDValue visitTokenFactor(SDNode *N);
419 SDValue visitMERGE_VALUES(SDNode *N);
420 SDValue visitADD(SDNode *N);
421 SDValue visitADDLike(SDNode *N);
422 SDValue visitADDLikeCommutative(SDValue N0, SDValue N1,
423 SDNode *LocReference);
424 SDValue visitPTRADD(SDNode *N);
425 SDValue visitSUB(SDNode *N);
426 SDValue visitADDSAT(SDNode *N);
427 SDValue visitSUBSAT(SDNode *N);
428 SDValue visitADDC(SDNode *N);
429 SDValue visitADDO(SDNode *N);
430 SDValue visitUADDOLike(SDValue N0, SDValue N1, SDNode *N);
431 SDValue visitSUBC(SDNode *N);
432 SDValue visitSUBO(SDNode *N);
433 SDValue visitADDE(SDNode *N);
434 SDValue visitUADDO_CARRY(SDNode *N);
435 SDValue visitSADDO_CARRY(SDNode *N);
436 SDValue visitUADDO_CARRYLike(SDValue N0, SDValue N1, SDValue CarryIn,
437 SDNode *N);
438 SDValue visitSADDO_CARRYLike(SDValue N0, SDValue N1, SDValue CarryIn,
439 SDNode *N);
440 SDValue visitSUBE(SDNode *N);
441 SDValue visitUSUBO_CARRY(SDNode *N);
442 SDValue visitSSUBO_CARRY(SDNode *N);
443 template <class MatchContextClass> SDValue visitMUL(SDNode *N);
444 SDValue visitMULFIX(SDNode *N);
445 SDValue useDivRem(SDNode *N);
446 SDValue visitSDIV(SDNode *N);
447 SDValue visitSDIVLike(SDValue N0, SDValue N1, SDNode *N);
448 SDValue visitUDIV(SDNode *N);
449 SDValue visitUDIVLike(SDValue N0, SDValue N1, SDNode *N);
450 SDValue visitREM(SDNode *N);
451 SDValue visitMULHU(SDNode *N);
452 SDValue visitMULHS(SDNode *N);
453 SDValue visitAVG(SDNode *N);
454 SDValue visitABD(SDNode *N);
455 SDValue visitSMUL_LOHI(SDNode *N);
456 SDValue visitUMUL_LOHI(SDNode *N);
457 SDValue visitMULO(SDNode *N);
458 SDValue visitIMINMAX(SDNode *N);
459 SDValue visitAND(SDNode *N);
460 SDValue visitANDLike(SDValue N0, SDValue N1, SDNode *N);
461 SDValue visitOR(SDNode *N);
462 SDValue visitORLike(SDValue N0, SDValue N1, const SDLoc &DL);
463 SDValue visitXOR(SDNode *N);
464 SDValue SimplifyVCastOp(SDNode *N, const SDLoc &DL);
465 SDValue SimplifyVBinOp(SDNode *N, const SDLoc &DL);
466 SDValue visitSHL(SDNode *N);
467 SDValue visitSRA(SDNode *N);
468 SDValue visitSRL(SDNode *N);
469 SDValue visitFunnelShift(SDNode *N);
470 SDValue visitSHLSAT(SDNode *N);
471 SDValue visitRotate(SDNode *N);
472 SDValue visitABS(SDNode *N);
473 SDValue visitBSWAP(SDNode *N);
474 SDValue visitBITREVERSE(SDNode *N);
475 SDValue visitCTLZ(SDNode *N);
476 SDValue visitCTLZ_ZERO_UNDEF(SDNode *N);
477 SDValue visitCTTZ(SDNode *N);
478 SDValue visitCTTZ_ZERO_UNDEF(SDNode *N);
479 SDValue visitCTPOP(SDNode *N);
480 SDValue visitSELECT(SDNode *N);
481 SDValue visitVSELECT(SDNode *N);
482 SDValue visitVP_SELECT(SDNode *N);
483 SDValue visitSELECT_CC(SDNode *N);
484 SDValue visitSETCC(SDNode *N);
485 SDValue visitSETCCCARRY(SDNode *N);
486 SDValue visitSIGN_EXTEND(SDNode *N);
487 SDValue visitZERO_EXTEND(SDNode *N);
488 SDValue visitANY_EXTEND(SDNode *N);
489 SDValue visitAssertExt(SDNode *N);
490 SDValue visitAssertAlign(SDNode *N);
491 SDValue visitSIGN_EXTEND_INREG(SDNode *N);
492 SDValue visitEXTEND_VECTOR_INREG(SDNode *N);
493 SDValue visitTRUNCATE(SDNode *N);
494 SDValue visitTRUNCATE_USAT_U(SDNode *N);
495 SDValue visitBITCAST(SDNode *N);
496 SDValue visitFREEZE(SDNode *N);
497 SDValue visitBUILD_PAIR(SDNode *N);
498 SDValue visitFADD(SDNode *N);
499 SDValue visitVP_FADD(SDNode *N);
500 SDValue visitVP_FSUB(SDNode *N);
501 SDValue visitSTRICT_FADD(SDNode *N);
502 SDValue visitFSUB(SDNode *N);
503 SDValue visitFMUL(SDNode *N);
504 template <class MatchContextClass> SDValue visitFMA(SDNode *N);
505 SDValue visitFMAD(SDNode *N);
506 SDValue visitFDIV(SDNode *N);
507 SDValue visitFREM(SDNode *N);
508 SDValue visitFSQRT(SDNode *N);
509 SDValue visitFCOPYSIGN(SDNode *N);
510 SDValue visitFPOW(SDNode *N);
511 SDValue visitFCANONICALIZE(SDNode *N);
512 SDValue visitSINT_TO_FP(SDNode *N);
513 SDValue visitUINT_TO_FP(SDNode *N);
514 SDValue visitFP_TO_SINT(SDNode *N);
515 SDValue visitFP_TO_UINT(SDNode *N);
516 SDValue visitXROUND(SDNode *N);
517 SDValue visitFP_ROUND(SDNode *N);
518 SDValue visitFP_EXTEND(SDNode *N);
519 SDValue visitFNEG(SDNode *N);
520 SDValue visitFABS(SDNode *N);
521 SDValue visitFCEIL(SDNode *N);
522 SDValue visitFTRUNC(SDNode *N);
523 SDValue visitFFREXP(SDNode *N);
524 SDValue visitFFLOOR(SDNode *N);
525 SDValue visitFMinMax(SDNode *N);
526 SDValue visitBRCOND(SDNode *N);
527 SDValue visitBR_CC(SDNode *N);
528 SDValue visitLOAD(SDNode *N);
529
530 SDValue replaceStoreChain(StoreSDNode *ST, SDValue BetterChain);
531 SDValue replaceStoreOfFPConstant(StoreSDNode *ST);
532 SDValue replaceStoreOfInsertLoad(StoreSDNode *ST);
533
534 bool refineExtractVectorEltIntoMultipleNarrowExtractVectorElts(SDNode *N);
535
536 SDValue visitSTORE(SDNode *N);
537 SDValue visitATOMIC_STORE(SDNode *N);
538 SDValue visitLIFETIME_END(SDNode *N);
539 SDValue visitINSERT_VECTOR_ELT(SDNode *N);
540 SDValue visitEXTRACT_VECTOR_ELT(SDNode *N);
541 SDValue visitBUILD_VECTOR(SDNode *N);
542 SDValue visitCONCAT_VECTORS(SDNode *N);
543 SDValue visitEXTRACT_SUBVECTOR(SDNode *N);
544 SDValue visitVECTOR_SHUFFLE(SDNode *N);
545 SDValue visitSCALAR_TO_VECTOR(SDNode *N);
546 SDValue visitINSERT_SUBVECTOR(SDNode *N);
547 SDValue visitVECTOR_COMPRESS(SDNode *N);
548 SDValue visitMLOAD(SDNode *N);
549 SDValue visitMSTORE(SDNode *N);
550 SDValue visitMGATHER(SDNode *N);
551 SDValue visitMSCATTER(SDNode *N);
552 SDValue visitMHISTOGRAM(SDNode *N);
553 SDValue visitPARTIAL_REDUCE_MLA(SDNode *N);
554 SDValue visitVPGATHER(SDNode *N);
555 SDValue visitVPSCATTER(SDNode *N);
556 SDValue visitVP_STRIDED_LOAD(SDNode *N);
557 SDValue visitVP_STRIDED_STORE(SDNode *N);
558 SDValue visitFP_TO_FP16(SDNode *N);
559 SDValue visitFP16_TO_FP(SDNode *N);
560 SDValue visitFP_TO_BF16(SDNode *N);
561 SDValue visitBF16_TO_FP(SDNode *N);
562 SDValue visitVECREDUCE(SDNode *N);
563 SDValue visitVPOp(SDNode *N);
564 SDValue visitGET_FPENV_MEM(SDNode *N);
565 SDValue visitSET_FPENV_MEM(SDNode *N);
566
567 template <class MatchContextClass>
568 SDValue visitFADDForFMACombine(SDNode *N);
569 template <class MatchContextClass>
570 SDValue visitFSUBForFMACombine(SDNode *N);
571 SDValue visitFMULForFMADistributiveCombine(SDNode *N);
572
573 SDValue XformToShuffleWithZero(SDNode *N);
574 bool reassociationCanBreakAddressingModePattern(unsigned Opc,
575 const SDLoc &DL,
576 SDNode *N,
577 SDValue N0,
578 SDValue N1);
579 SDValue reassociateOpsCommutative(unsigned Opc, const SDLoc &DL, SDValue N0,
580 SDValue N1, SDNodeFlags Flags);
581 SDValue reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
582 SDValue N1, SDNodeFlags Flags);
583 SDValue reassociateReduction(unsigned RedOpc, unsigned Opc, const SDLoc &DL,
584 EVT VT, SDValue N0, SDValue N1,
585 SDNodeFlags Flags = SDNodeFlags());
586
587 SDValue visitShiftByConstant(SDNode *N);
588
589 SDValue foldSelectOfConstants(SDNode *N);
590 SDValue foldVSelectOfConstants(SDNode *N);
591 SDValue foldBinOpIntoSelect(SDNode *BO);
592 bool SimplifySelectOps(SDNode *SELECT, SDValue LHS, SDValue RHS);
593 SDValue hoistLogicOpWithSameOpcodeHands(SDNode *N);
594 SDValue SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2);
595 SDValue SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
596 SDValue N2, SDValue N3, ISD::CondCode CC,
597 bool NotExtCompare = false);
598 SDValue convertSelectOfFPConstantsToLoadOffset(
599 const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2, SDValue N3,
600 ISD::CondCode CC);
601 SDValue foldSignChangeInBitcast(SDNode *N);
602 SDValue foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0, SDValue N1,
603 SDValue N2, SDValue N3, ISD::CondCode CC);
604 SDValue foldSelectOfBinops(SDNode *N);
605 SDValue foldSextSetcc(SDNode *N);
606 SDValue foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1,
607 const SDLoc &DL);
608 SDValue foldSubToUSubSat(EVT DstVT, SDNode *N, const SDLoc &DL);
609 SDValue foldABSToABD(SDNode *N, const SDLoc &DL);
610 SDValue foldSelectToABD(SDValue LHS, SDValue RHS, SDValue True,
611 SDValue False, ISD::CondCode CC, const SDLoc &DL);
612 SDValue unfoldMaskedMerge(SDNode *N);
613 SDValue unfoldExtremeBitClearingToShifts(SDNode *N);
614 SDValue SimplifySetCC(EVT VT, SDValue N0, SDValue N1, ISD::CondCode Cond,
615 const SDLoc &DL, bool foldBooleans);
616 SDValue rebuildSetCC(SDValue N);
617
618 bool isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS,
619 SDValue &CC, bool MatchStrict = false) const;
620 bool isOneUseSetCC(SDValue N) const;
621
622 SDValue foldAddToAvg(SDNode *N, const SDLoc &DL);
623 SDValue foldSubToAvg(SDNode *N, const SDLoc &DL);
624
625 SDValue SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp,
626 unsigned HiOp);
627 SDValue CombineConsecutiveLoads(SDNode *N, EVT VT);
628 SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
629 const TargetLowering &TLI);
630 SDValue foldPartialReduceMLAMulOp(SDNode *N);
631 SDValue foldPartialReduceAdd(SDNode *N);
632
633 SDValue CombineExtLoad(SDNode *N);
634 SDValue CombineZExtLogicopShiftLoad(SDNode *N);
635 SDValue combineRepeatedFPDivisors(SDNode *N);
636 SDValue combineFMulOrFDivWithIntPow2(SDNode *N);
637 SDValue replaceShuffleOfInsert(ShuffleVectorSDNode *Shuf);
638 SDValue mergeInsertEltWithShuffle(SDNode *N, unsigned InsIndex);
639 SDValue combineInsertEltToShuffle(SDNode *N, unsigned InsIndex);
640 SDValue combineInsertEltToLoad(SDNode *N, unsigned InsIndex);
641 SDValue BuildSDIV(SDNode *N);
642 SDValue BuildSDIVPow2(SDNode *N);
643 SDValue BuildUDIV(SDNode *N);
644 SDValue BuildSREMPow2(SDNode *N);
645 SDValue buildOptimizedSREM(SDValue N0, SDValue N1, SDNode *N);
646 SDValue BuildLogBase2(SDValue V, const SDLoc &DL,
647 bool KnownNeverZero = false,
648 bool InexpensiveOnly = false,
649 std::optional<EVT> OutVT = std::nullopt);
650 SDValue BuildDivEstimate(SDValue N, SDValue Op, SDNodeFlags Flags);
651 SDValue buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags);
652 SDValue buildSqrtEstimate(SDValue Op, SDNodeFlags Flags);
653 SDValue buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags, bool Recip);
654 SDValue buildSqrtNROneConst(SDValue Arg, SDValue Est, unsigned Iterations,
655 SDNodeFlags Flags, bool Reciprocal);
656 SDValue buildSqrtNRTwoConst(SDValue Arg, SDValue Est, unsigned Iterations,
657 SDNodeFlags Flags, bool Reciprocal);
658 SDValue MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
659 bool DemandHighBits = true);
660 SDValue MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1);
661 SDValue MatchRotatePosNeg(SDValue Shifted, SDValue Pos, SDValue Neg,
662 SDValue InnerPos, SDValue InnerNeg, bool FromAdd,
663 bool HasPos, unsigned PosOpcode,
664 unsigned NegOpcode, const SDLoc &DL);
665 SDValue MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos, SDValue Neg,
666 SDValue InnerPos, SDValue InnerNeg, bool FromAdd,
667 bool HasPos, unsigned PosOpcode,
668 unsigned NegOpcode, const SDLoc &DL);
669 SDValue MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL,
670 bool FromAdd);
671 SDValue MatchLoadCombine(SDNode *N);
672 SDValue mergeTruncStores(StoreSDNode *N);
673 SDValue reduceLoadWidth(SDNode *N);
674 SDValue ReduceLoadOpStoreWidth(SDNode *N);
675 SDValue splitMergedValStore(StoreSDNode *ST);
676 SDValue TransformFPLoadStorePair(SDNode *N);
677 SDValue convertBuildVecZextToZext(SDNode *N);
678 SDValue convertBuildVecZextToBuildVecWithZeros(SDNode *N);
679 SDValue reduceBuildVecExtToExtBuildVec(SDNode *N);
680 SDValue reduceBuildVecTruncToBitCast(SDNode *N);
681 SDValue reduceBuildVecToShuffle(SDNode *N);
682 SDValue createBuildVecShuffle(const SDLoc &DL, SDNode *N,
683 ArrayRef<int> VectorMask, SDValue VecIn1,
684 SDValue VecIn2, unsigned LeftIdx,
685 bool DidSplitVec);
686 SDValue matchVSelectOpSizesWithSetCC(SDNode *Cast);
687
688 /// Walk up chain skipping non-aliasing memory nodes,
689 /// looking for aliasing nodes and adding them to the Aliases vector.
690 void GatherAllAliases(SDNode *N, SDValue OriginalChain,
691 SmallVectorImpl<SDValue> &Aliases);
692
693 /// Return true if there is any possibility that the two addresses overlap.
694 bool mayAlias(SDNode *Op0, SDNode *Op1) const;
695
696 /// Walk up chain skipping non-aliasing memory nodes, looking for a better
697 /// chain (aliasing node.)
698 SDValue FindBetterChain(SDNode *N, SDValue Chain);
699
700 /// Try to replace a store and any possibly adjacent stores on
701 /// consecutive chains with better chains. Return true only if St is
702 /// replaced.
703 ///
704 /// Notice that other chains may still be replaced even if the function
705 /// returns false.
706 bool findBetterNeighborChains(StoreSDNode *St);
707
708 // Helper for findBetterNeighborChains. Walk up store chain add additional
709 // chained stores that do not overlap and can be parallelized.
710 bool parallelizeChainedStores(StoreSDNode *St);
711
712 /// Holds a pointer to an LSBaseSDNode as well as information on where it
713 /// is located in a sequence of memory operations connected by a chain.
714 struct MemOpLink {
715 // Ptr to the mem node.
716 LSBaseSDNode *MemNode;
717
718 // Offset from the base ptr.
719 int64_t OffsetFromBase;
720
MemOpLink__anon666e37100111::DAGCombiner::MemOpLink721 MemOpLink(LSBaseSDNode *N, int64_t Offset)
722 : MemNode(N), OffsetFromBase(Offset) {}
723 };
724
725 // Classify the origin of a stored value.
726 enum class StoreSource { Unknown, Constant, Extract, Load };
getStoreSource(SDValue StoreVal)727 StoreSource getStoreSource(SDValue StoreVal) {
728 switch (StoreVal.getOpcode()) {
729 case ISD::Constant:
730 case ISD::ConstantFP:
731 return StoreSource::Constant;
732 case ISD::BUILD_VECTOR:
733 if (ISD::isBuildVectorOfConstantSDNodes(StoreVal.getNode()) ||
734 ISD::isBuildVectorOfConstantFPSDNodes(StoreVal.getNode()))
735 return StoreSource::Constant;
736 return StoreSource::Unknown;
737 case ISD::EXTRACT_VECTOR_ELT:
738 case ISD::EXTRACT_SUBVECTOR:
739 return StoreSource::Extract;
740 case ISD::LOAD:
741 return StoreSource::Load;
742 default:
743 return StoreSource::Unknown;
744 }
745 }
746
747 /// This is a helper function for visitMUL to check the profitability
748 /// of folding (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2).
749 /// MulNode is the original multiply, AddNode is (add x, c1),
750 /// and ConstNode is c2.
751 bool isMulAddWithConstProfitable(SDNode *MulNode, SDValue AddNode,
752 SDValue ConstNode);
753
754 /// This is a helper function for visitAND and visitZERO_EXTEND. Returns
755 /// true if the (and (load x) c) pattern matches an extload. ExtVT returns
756 /// the type of the loaded value to be extended.
757 bool isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
758 EVT LoadResultTy, EVT &ExtVT);
759
760 /// Helper function to calculate whether the given Load/Store can have its
761 /// width reduced to ExtVT.
762 bool isLegalNarrowLdSt(LSBaseSDNode *LDSTN, ISD::LoadExtType ExtType,
763 EVT &MemVT, unsigned ShAmt = 0);
764
765 /// Used by BackwardsPropagateMask to find suitable loads.
766 bool SearchForAndLoads(SDNode *N, SmallVectorImpl<LoadSDNode*> &Loads,
767 SmallPtrSetImpl<SDNode*> &NodesWithConsts,
768 ConstantSDNode *Mask, SDNode *&NodeToMask);
769 /// Attempt to propagate a given AND node back to load leaves so that they
770 /// can be combined into narrow loads.
771 bool BackwardsPropagateMask(SDNode *N);
772
773 /// Helper function for mergeConsecutiveStores which merges the component
774 /// store chains.
775 SDValue getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes,
776 unsigned NumStores);
777
778 /// Helper function for mergeConsecutiveStores which checks if all the store
779 /// nodes have the same underlying object. We can still reuse the first
780 /// store's pointer info if all the stores are from the same object.
781 bool hasSameUnderlyingObj(ArrayRef<MemOpLink> StoreNodes);
782
783 /// This is a helper function for mergeConsecutiveStores. When the source
784 /// elements of the consecutive stores are all constants or all extracted
785 /// vector elements, try to merge them into one larger store introducing
786 /// bitcasts if necessary. \return True if a merged store was created.
787 bool mergeStoresOfConstantsOrVecElts(SmallVectorImpl<MemOpLink> &StoreNodes,
788 EVT MemVT, unsigned NumStores,
789 bool IsConstantSrc, bool UseVector,
790 bool UseTrunc);
791
792 /// This is a helper function for mergeConsecutiveStores. Stores that
793 /// potentially may be merged with St are placed in StoreNodes. On success,
794 /// returns a chain predecessor to all store candidates.
795 SDNode *getStoreMergeCandidates(StoreSDNode *St,
796 SmallVectorImpl<MemOpLink> &StoreNodes);
797
798 /// Helper function for mergeConsecutiveStores. Checks if candidate stores
799 /// have indirect dependency through their operands. RootNode is the
800 /// predecessor to all stores calculated by getStoreMergeCandidates and is
801 /// used to prune the dependency check. \return True if safe to merge.
802 bool checkMergeStoreCandidatesForDependencies(
803 SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores,
804 SDNode *RootNode);
805
806 /// Helper function for tryStoreMergeOfLoads. Checks if the load/store
807 /// chain has a call in it. \return True if a call is found.
808 bool hasCallInLdStChain(StoreSDNode *St, LoadSDNode *Ld);
809
810 /// This is a helper function for mergeConsecutiveStores. Given a list of
811 /// store candidates, find the first N that are consecutive in memory.
812 /// Returns 0 if there are not at least 2 consecutive stores to try merging.
813 unsigned getConsecutiveStores(SmallVectorImpl<MemOpLink> &StoreNodes,
814 int64_t ElementSizeBytes) const;
815
816 /// This is a helper function for mergeConsecutiveStores. It is used for
817 /// store chains that are composed entirely of constant values.
818 bool tryStoreMergeOfConstants(SmallVectorImpl<MemOpLink> &StoreNodes,
819 unsigned NumConsecutiveStores,
820 EVT MemVT, SDNode *Root, bool AllowVectors);
821
822 /// This is a helper function for mergeConsecutiveStores. It is used for
823 /// store chains that are composed entirely of extracted vector elements.
824 /// When extracting multiple vector elements, try to store them in one
825 /// vector store rather than a sequence of scalar stores.
826 bool tryStoreMergeOfExtracts(SmallVectorImpl<MemOpLink> &StoreNodes,
827 unsigned NumConsecutiveStores, EVT MemVT,
828 SDNode *Root);
829
830 /// This is a helper function for mergeConsecutiveStores. It is used for
831 /// store chains that are composed entirely of loaded values.
832 bool tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes,
833 unsigned NumConsecutiveStores, EVT MemVT,
834 SDNode *Root, bool AllowVectors,
835 bool IsNonTemporalStore, bool IsNonTemporalLoad);
836
837 /// Merge consecutive store operations into a wide store.
838 /// This optimization uses wide integers or vectors when possible.
839 /// \return true if stores were merged.
840 bool mergeConsecutiveStores(StoreSDNode *St);
841
842 /// Try to transform a truncation where C is a constant:
843 /// (trunc (and X, C)) -> (and (trunc X), (trunc C))
844 ///
845 /// \p N needs to be a truncation and its first operand an AND. Other
846 /// requirements are checked by the function (e.g. that trunc is
847 /// single-use) and if missed an empty SDValue is returned.
848 SDValue distributeTruncateThroughAnd(SDNode *N);
849
850 /// Helper function to determine whether the target supports operation
851 /// given by \p Opcode for type \p VT, that is, whether the operation
852 /// is legal or custom before legalizing operations, and whether is
853 /// legal (but not custom) after legalization.
hasOperation(unsigned Opcode,EVT VT)854 bool hasOperation(unsigned Opcode, EVT VT) {
855 return TLI.isOperationLegalOrCustom(Opcode, VT, LegalOperations);
856 }
857
hasUMin(EVT VT) const858 bool hasUMin(EVT VT) const {
859 auto LK = TLI.getTypeConversion(*DAG.getContext(), VT);
860 return (LK.first == TargetLoweringBase::TypeLegal ||
861 LK.first == TargetLoweringBase::TypePromoteInteger) &&
862 TLI.isOperationLegal(ISD::UMIN, LK.second);
863 }
864
865 public:
866 /// Runs the dag combiner on all nodes in the work list
867 void Run(CombineLevel AtLevel);
868
getDAG() const869 SelectionDAG &getDAG() const { return DAG; }
870
871 /// Convenience wrapper around TargetLowering::getShiftAmountTy.
getShiftAmountTy(EVT LHSTy)872 EVT getShiftAmountTy(EVT LHSTy) {
873 return TLI.getShiftAmountTy(LHSTy, DAG.getDataLayout());
874 }
875
876 /// This method returns true if we are running before type legalization or
877 /// if the specified VT is legal.
isTypeLegal(const EVT & VT)878 bool isTypeLegal(const EVT &VT) {
879 if (!LegalTypes) return true;
880 return TLI.isTypeLegal(VT);
881 }
882
883 /// Convenience wrapper around TargetLowering::getSetCCResultType
getSetCCResultType(EVT VT) const884 EVT getSetCCResultType(EVT VT) const {
885 return TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
886 }
887
888 void ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs,
889 SDValue OrigLoad, SDValue ExtLoad,
890 ISD::NodeType ExtType);
891 };
892
893 /// This class is a DAGUpdateListener that removes any deleted
894 /// nodes from the worklist.
895 class WorklistRemover : public SelectionDAG::DAGUpdateListener {
896 DAGCombiner &DC;
897
898 public:
WorklistRemover(DAGCombiner & dc)899 explicit WorklistRemover(DAGCombiner &dc)
900 : SelectionDAG::DAGUpdateListener(dc.getDAG()), DC(dc) {}
901
NodeDeleted(SDNode * N,SDNode * E)902 void NodeDeleted(SDNode *N, SDNode *E) override {
903 DC.removeFromWorklist(N);
904 }
905 };
906
907 class WorklistInserter : public SelectionDAG::DAGUpdateListener {
908 DAGCombiner &DC;
909
910 public:
WorklistInserter(DAGCombiner & dc)911 explicit WorklistInserter(DAGCombiner &dc)
912 : SelectionDAG::DAGUpdateListener(dc.getDAG()), DC(dc) {}
913
914 // FIXME: Ideally we could add N to the worklist, but this causes exponential
915 // compile time costs in large DAGs, e.g. Halide.
NodeInserted(SDNode * N)916 void NodeInserted(SDNode *N) override { DC.ConsiderForPruning(N); }
917 };
918
919 } // end anonymous namespace
920
921 //===----------------------------------------------------------------------===//
922 // TargetLowering::DAGCombinerInfo implementation
923 //===----------------------------------------------------------------------===//
924
AddToWorklist(SDNode * N)925 void TargetLowering::DAGCombinerInfo::AddToWorklist(SDNode *N) {
926 ((DAGCombiner*)DC)->AddToWorklist(N);
927 }
928
929 SDValue TargetLowering::DAGCombinerInfo::
CombineTo(SDNode * N,ArrayRef<SDValue> To,bool AddTo)930 CombineTo(SDNode *N, ArrayRef<SDValue> To, bool AddTo) {
931 return ((DAGCombiner*)DC)->CombineTo(N, &To[0], To.size(), AddTo);
932 }
933
934 SDValue TargetLowering::DAGCombinerInfo::
CombineTo(SDNode * N,SDValue Res,bool AddTo)935 CombineTo(SDNode *N, SDValue Res, bool AddTo) {
936 return ((DAGCombiner*)DC)->CombineTo(N, Res, AddTo);
937 }
938
939 SDValue TargetLowering::DAGCombinerInfo::
CombineTo(SDNode * N,SDValue Res0,SDValue Res1,bool AddTo)940 CombineTo(SDNode *N, SDValue Res0, SDValue Res1, bool AddTo) {
941 return ((DAGCombiner*)DC)->CombineTo(N, Res0, Res1, AddTo);
942 }
943
944 bool TargetLowering::DAGCombinerInfo::
recursivelyDeleteUnusedNodes(SDNode * N)945 recursivelyDeleteUnusedNodes(SDNode *N) {
946 return ((DAGCombiner*)DC)->recursivelyDeleteUnusedNodes(N);
947 }
948
949 void TargetLowering::DAGCombinerInfo::
CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt & TLO)950 CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) {
951 return ((DAGCombiner*)DC)->CommitTargetLoweringOpt(TLO);
952 }
953
954 //===----------------------------------------------------------------------===//
955 // Helper Functions
956 //===----------------------------------------------------------------------===//
957
deleteAndRecombine(SDNode * N)958 void DAGCombiner::deleteAndRecombine(SDNode *N) {
959 removeFromWorklist(N);
960
961 // If the operands of this node are only used by the node, they will now be
962 // dead. Make sure to re-visit them and recursively delete dead nodes.
963 for (const SDValue &Op : N->ops())
964 // For an operand generating multiple values, one of the values may
965 // become dead allowing further simplification (e.g. split index
966 // arithmetic from an indexed load).
967 if (Op->hasOneUse() || Op->getNumValues() > 1)
968 AddToWorklist(Op.getNode());
969
970 DAG.DeleteNode(N);
971 }
972
973 // APInts must be the same size for most operations, this helper
974 // function zero extends the shorter of the pair so that they match.
975 // We provide an Offset so that we can create bitwidths that won't overflow.
zeroExtendToMatch(APInt & LHS,APInt & RHS,unsigned Offset=0)976 static void zeroExtendToMatch(APInt &LHS, APInt &RHS, unsigned Offset = 0) {
977 unsigned Bits = Offset + std::max(LHS.getBitWidth(), RHS.getBitWidth());
978 LHS = LHS.zext(Bits);
979 RHS = RHS.zext(Bits);
980 }
981
982 // Return true if this node is a setcc, or is a select_cc
983 // that selects between the target values used for true and false, making it
984 // equivalent to a setcc. Also, set the incoming LHS, RHS, and CC references to
985 // the appropriate nodes based on the type of node we are checking. This
986 // simplifies life a bit for the callers.
isSetCCEquivalent(SDValue N,SDValue & LHS,SDValue & RHS,SDValue & CC,bool MatchStrict) const987 bool DAGCombiner::isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS,
988 SDValue &CC, bool MatchStrict) const {
989 if (N.getOpcode() == ISD::SETCC) {
990 LHS = N.getOperand(0);
991 RHS = N.getOperand(1);
992 CC = N.getOperand(2);
993 return true;
994 }
995
996 if (MatchStrict &&
997 (N.getOpcode() == ISD::STRICT_FSETCC ||
998 N.getOpcode() == ISD::STRICT_FSETCCS)) {
999 LHS = N.getOperand(1);
1000 RHS = N.getOperand(2);
1001 CC = N.getOperand(3);
1002 return true;
1003 }
1004
1005 if (N.getOpcode() != ISD::SELECT_CC || !TLI.isConstTrueVal(N.getOperand(2)) ||
1006 !TLI.isConstFalseVal(N.getOperand(3)))
1007 return false;
1008
1009 if (TLI.getBooleanContents(N.getValueType()) ==
1010 TargetLowering::UndefinedBooleanContent)
1011 return false;
1012
1013 LHS = N.getOperand(0);
1014 RHS = N.getOperand(1);
1015 CC = N.getOperand(4);
1016 return true;
1017 }
1018
1019 /// Return true if this is a SetCC-equivalent operation with only one use.
1020 /// If this is true, it allows the users to invert the operation for free when
1021 /// it is profitable to do so.
isOneUseSetCC(SDValue N) const1022 bool DAGCombiner::isOneUseSetCC(SDValue N) const {
1023 SDValue N0, N1, N2;
1024 if (isSetCCEquivalent(N, N0, N1, N2) && N->hasOneUse())
1025 return true;
1026 return false;
1027 }
1028
isConstantSplatVectorMaskForType(SDNode * N,EVT ScalarTy)1029 static bool isConstantSplatVectorMaskForType(SDNode *N, EVT ScalarTy) {
1030 if (!ScalarTy.isSimple())
1031 return false;
1032
1033 uint64_t MaskForTy = 0ULL;
1034 switch (ScalarTy.getSimpleVT().SimpleTy) {
1035 case MVT::i8:
1036 MaskForTy = 0xFFULL;
1037 break;
1038 case MVT::i16:
1039 MaskForTy = 0xFFFFULL;
1040 break;
1041 case MVT::i32:
1042 MaskForTy = 0xFFFFFFFFULL;
1043 break;
1044 default:
1045 return false;
1046 break;
1047 }
1048
1049 APInt Val;
1050 if (ISD::isConstantSplatVector(N, Val))
1051 return Val.getLimitedValue() == MaskForTy;
1052
1053 return false;
1054 }
1055
1056 // Determines if it is a constant integer or a splat/build vector of constant
1057 // integers (and undefs).
1058 // Do not permit build vector implicit truncation.
isConstantOrConstantVector(SDValue N,bool NoOpaques=false)1059 static bool isConstantOrConstantVector(SDValue N, bool NoOpaques = false) {
1060 if (ConstantSDNode *Const = dyn_cast<ConstantSDNode>(N))
1061 return !(Const->isOpaque() && NoOpaques);
1062 if (N.getOpcode() != ISD::BUILD_VECTOR && N.getOpcode() != ISD::SPLAT_VECTOR)
1063 return false;
1064 unsigned BitWidth = N.getScalarValueSizeInBits();
1065 for (const SDValue &Op : N->op_values()) {
1066 if (Op.isUndef())
1067 continue;
1068 ConstantSDNode *Const = dyn_cast<ConstantSDNode>(Op);
1069 if (!Const || Const->getAPIntValue().getBitWidth() != BitWidth ||
1070 (Const->isOpaque() && NoOpaques))
1071 return false;
1072 }
1073 return true;
1074 }
1075
1076 // Determines if a BUILD_VECTOR is composed of all-constants possibly mixed with
1077 // undef's.
isAnyConstantBuildVector(SDValue V,bool NoOpaques=false)1078 static bool isAnyConstantBuildVector(SDValue V, bool NoOpaques = false) {
1079 if (V.getOpcode() != ISD::BUILD_VECTOR)
1080 return false;
1081 return isConstantOrConstantVector(V, NoOpaques) ||
1082 ISD::isBuildVectorOfConstantFPSDNodes(V.getNode());
1083 }
1084
1085 // Determine if this an indexed load with an opaque target constant index.
canSplitIdx(LoadSDNode * LD)1086 static bool canSplitIdx(LoadSDNode *LD) {
1087 return MaySplitLoadIndex &&
1088 (LD->getOperand(2).getOpcode() != ISD::TargetConstant ||
1089 !cast<ConstantSDNode>(LD->getOperand(2))->isOpaque());
1090 }
1091
reassociationCanBreakAddressingModePattern(unsigned Opc,const SDLoc & DL,SDNode * N,SDValue N0,SDValue N1)1092 bool DAGCombiner::reassociationCanBreakAddressingModePattern(unsigned Opc,
1093 const SDLoc &DL,
1094 SDNode *N,
1095 SDValue N0,
1096 SDValue N1) {
1097 // Currently this only tries to ensure we don't undo the GEP splits done by
1098 // CodeGenPrepare when shouldConsiderGEPOffsetSplit is true. To ensure this,
1099 // we check if the following transformation would be problematic:
1100 // (load/store (add, (add, x, offset1), offset2)) ->
1101 // (load/store (add, x, offset1+offset2)).
1102
1103 // (load/store (add, (add, x, y), offset2)) ->
1104 // (load/store (add, (add, x, offset2), y)).
1105
1106 if (!N0.isAnyAdd())
1107 return false;
1108
1109 // Check for vscale addressing modes.
1110 // (load/store (add/sub (add x, y), vscale))
1111 // (load/store (add/sub (add x, y), (lsl vscale, C)))
1112 // (load/store (add/sub (add x, y), (mul vscale, C)))
1113 if ((N1.getOpcode() == ISD::VSCALE ||
1114 ((N1.getOpcode() == ISD::SHL || N1.getOpcode() == ISD::MUL) &&
1115 N1.getOperand(0).getOpcode() == ISD::VSCALE &&
1116 isa<ConstantSDNode>(N1.getOperand(1)))) &&
1117 N1.getValueType().getFixedSizeInBits() <= 64) {
1118 int64_t ScalableOffset = N1.getOpcode() == ISD::VSCALE
1119 ? N1.getConstantOperandVal(0)
1120 : (N1.getOperand(0).getConstantOperandVal(0) *
1121 (N1.getOpcode() == ISD::SHL
1122 ? (1LL << N1.getConstantOperandVal(1))
1123 : N1.getConstantOperandVal(1)));
1124 if (Opc == ISD::SUB)
1125 ScalableOffset = -ScalableOffset;
1126 if (all_of(N->users(), [&](SDNode *Node) {
1127 if (auto *LoadStore = dyn_cast<MemSDNode>(Node);
1128 LoadStore && LoadStore->getBasePtr().getNode() == N) {
1129 TargetLoweringBase::AddrMode AM;
1130 AM.HasBaseReg = true;
1131 AM.ScalableOffset = ScalableOffset;
1132 EVT VT = LoadStore->getMemoryVT();
1133 unsigned AS = LoadStore->getAddressSpace();
1134 Type *AccessTy = VT.getTypeForEVT(*DAG.getContext());
1135 return TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy,
1136 AS);
1137 }
1138 return false;
1139 }))
1140 return true;
1141 }
1142
1143 if (Opc != ISD::ADD && Opc != ISD::PTRADD)
1144 return false;
1145
1146 auto *C2 = dyn_cast<ConstantSDNode>(N1);
1147 if (!C2)
1148 return false;
1149
1150 const APInt &C2APIntVal = C2->getAPIntValue();
1151 if (C2APIntVal.getSignificantBits() > 64)
1152 return false;
1153
1154 if (auto *C1 = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
1155 if (N0.hasOneUse())
1156 return false;
1157
1158 const APInt &C1APIntVal = C1->getAPIntValue();
1159 const APInt CombinedValueIntVal = C1APIntVal + C2APIntVal;
1160 if (CombinedValueIntVal.getSignificantBits() > 64)
1161 return false;
1162 const int64_t CombinedValue = CombinedValueIntVal.getSExtValue();
1163
1164 for (SDNode *Node : N->users()) {
1165 if (auto *LoadStore = dyn_cast<MemSDNode>(Node)) {
1166 // Is x[offset2] already not a legal addressing mode? If so then
1167 // reassociating the constants breaks nothing (we test offset2 because
1168 // that's the one we hope to fold into the load or store).
1169 TargetLoweringBase::AddrMode AM;
1170 AM.HasBaseReg = true;
1171 AM.BaseOffs = C2APIntVal.getSExtValue();
1172 EVT VT = LoadStore->getMemoryVT();
1173 unsigned AS = LoadStore->getAddressSpace();
1174 Type *AccessTy = VT.getTypeForEVT(*DAG.getContext());
1175 if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
1176 continue;
1177
1178 // Would x[offset1+offset2] still be a legal addressing mode?
1179 AM.BaseOffs = CombinedValue;
1180 if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
1181 return true;
1182 }
1183 }
1184 } else {
1185 if (auto *GA = dyn_cast<GlobalAddressSDNode>(N0.getOperand(1)))
1186 if (GA->getOpcode() == ISD::GlobalAddress && TLI.isOffsetFoldingLegal(GA))
1187 return false;
1188
1189 for (SDNode *Node : N->users()) {
1190 auto *LoadStore = dyn_cast<MemSDNode>(Node);
1191 if (!LoadStore)
1192 return false;
1193
1194 // Is x[offset2] a legal addressing mode? If so then
1195 // reassociating the constants breaks address pattern
1196 TargetLoweringBase::AddrMode AM;
1197 AM.HasBaseReg = true;
1198 AM.BaseOffs = C2APIntVal.getSExtValue();
1199 EVT VT = LoadStore->getMemoryVT();
1200 unsigned AS = LoadStore->getAddressSpace();
1201 Type *AccessTy = VT.getTypeForEVT(*DAG.getContext());
1202 if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
1203 return false;
1204 }
1205 return true;
1206 }
1207
1208 return false;
1209 }
1210
1211 /// Helper for DAGCombiner::reassociateOps. Try to reassociate (Opc N0, N1) if
1212 /// \p N0 is the same kind of operation as \p Opc.
reassociateOpsCommutative(unsigned Opc,const SDLoc & DL,SDValue N0,SDValue N1,SDNodeFlags Flags)1213 SDValue DAGCombiner::reassociateOpsCommutative(unsigned Opc, const SDLoc &DL,
1214 SDValue N0, SDValue N1,
1215 SDNodeFlags Flags) {
1216 EVT VT = N0.getValueType();
1217
1218 if (N0.getOpcode() != Opc)
1219 return SDValue();
1220
1221 SDValue N00 = N0.getOperand(0);
1222 SDValue N01 = N0.getOperand(1);
1223
1224 if (DAG.isConstantIntBuildVectorOrConstantInt(N01)) {
1225 SDNodeFlags NewFlags;
1226 if (N0.getOpcode() == ISD::ADD && N0->getFlags().hasNoUnsignedWrap() &&
1227 Flags.hasNoUnsignedWrap())
1228 NewFlags |= SDNodeFlags::NoUnsignedWrap;
1229
1230 if (DAG.isConstantIntBuildVectorOrConstantInt(N1)) {
1231 // Reassociate: (op (op x, c1), c2) -> (op x, (op c1, c2))
1232 if (SDValue OpNode = DAG.FoldConstantArithmetic(Opc, DL, VT, {N01, N1})) {
1233 NewFlags.setDisjoint(Flags.hasDisjoint() &&
1234 N0->getFlags().hasDisjoint());
1235 return DAG.getNode(Opc, DL, VT, N00, OpNode, NewFlags);
1236 }
1237 return SDValue();
1238 }
1239 if (TLI.isReassocProfitable(DAG, N0, N1)) {
1240 // Reassociate: (op (op x, c1), y) -> (op (op x, y), c1)
1241 // iff (op x, c1) has one use
1242 SDValue OpNode = DAG.getNode(Opc, SDLoc(N0), VT, N00, N1, NewFlags);
1243 return DAG.getNode(Opc, DL, VT, OpNode, N01, NewFlags);
1244 }
1245 }
1246
1247 // Check for repeated operand logic simplifications.
1248 if (Opc == ISD::AND || Opc == ISD::OR) {
1249 // (N00 & N01) & N00 --> N00 & N01
1250 // (N00 & N01) & N01 --> N00 & N01
1251 // (N00 | N01) | N00 --> N00 | N01
1252 // (N00 | N01) | N01 --> N00 | N01
1253 if (N1 == N00 || N1 == N01)
1254 return N0;
1255 }
1256 if (Opc == ISD::XOR) {
1257 // (N00 ^ N01) ^ N00 --> N01
1258 if (N1 == N00)
1259 return N01;
1260 // (N00 ^ N01) ^ N01 --> N00
1261 if (N1 == N01)
1262 return N00;
1263 }
1264
1265 if (TLI.isReassocProfitable(DAG, N0, N1)) {
1266 if (N1 != N01) {
1267 // Reassociate if (op N00, N1) already exist
1268 if (SDNode *NE = DAG.getNodeIfExists(Opc, DAG.getVTList(VT), {N00, N1})) {
1269 // if Op (Op N00, N1), N01 already exist
1270 // we need to stop reassciate to avoid dead loop
1271 if (!DAG.doesNodeExist(Opc, DAG.getVTList(VT), {SDValue(NE, 0), N01}))
1272 return DAG.getNode(Opc, DL, VT, SDValue(NE, 0), N01);
1273 }
1274 }
1275
1276 if (N1 != N00) {
1277 // Reassociate if (op N01, N1) already exist
1278 if (SDNode *NE = DAG.getNodeIfExists(Opc, DAG.getVTList(VT), {N01, N1})) {
1279 // if Op (Op N01, N1), N00 already exist
1280 // we need to stop reassciate to avoid dead loop
1281 if (!DAG.doesNodeExist(Opc, DAG.getVTList(VT), {SDValue(NE, 0), N00}))
1282 return DAG.getNode(Opc, DL, VT, SDValue(NE, 0), N00);
1283 }
1284 }
1285
1286 // Reassociate the operands from (OR/AND (OR/AND(N00, N001)), N1) to (OR/AND
1287 // (OR/AND(N00, N1)), N01) when N00 and N1 are comparisons with the same
1288 // predicate or to (OR/AND (OR/AND(N1, N01)), N00) when N01 and N1 are
1289 // comparisons with the same predicate. This enables optimizations as the
1290 // following one:
1291 // CMP(A,C)||CMP(B,C) => CMP(MIN/MAX(A,B), C)
1292 // CMP(A,C)&&CMP(B,C) => CMP(MIN/MAX(A,B), C)
1293 if (Opc == ISD::AND || Opc == ISD::OR) {
1294 if (N1->getOpcode() == ISD::SETCC && N00->getOpcode() == ISD::SETCC &&
1295 N01->getOpcode() == ISD::SETCC) {
1296 ISD::CondCode CC1 = cast<CondCodeSDNode>(N1.getOperand(2))->get();
1297 ISD::CondCode CC00 = cast<CondCodeSDNode>(N00.getOperand(2))->get();
1298 ISD::CondCode CC01 = cast<CondCodeSDNode>(N01.getOperand(2))->get();
1299 if (CC1 == CC00 && CC1 != CC01) {
1300 SDValue OpNode = DAG.getNode(Opc, SDLoc(N0), VT, N00, N1, Flags);
1301 return DAG.getNode(Opc, DL, VT, OpNode, N01, Flags);
1302 }
1303 if (CC1 == CC01 && CC1 != CC00) {
1304 SDValue OpNode = DAG.getNode(Opc, SDLoc(N0), VT, N01, N1, Flags);
1305 return DAG.getNode(Opc, DL, VT, OpNode, N00, Flags);
1306 }
1307 }
1308 }
1309 }
1310
1311 return SDValue();
1312 }
1313
1314 /// Try to reassociate commutative (Opc N0, N1) if either \p N0 or \p N1 is the
1315 /// same kind of operation as \p Opc.
reassociateOps(unsigned Opc,const SDLoc & DL,SDValue N0,SDValue N1,SDNodeFlags Flags)1316 SDValue DAGCombiner::reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
1317 SDValue N1, SDNodeFlags Flags) {
1318 assert(TLI.isCommutativeBinOp(Opc) && "Operation not commutative.");
1319
1320 // Floating-point reassociation is not allowed without loose FP math.
1321 if (N0.getValueType().isFloatingPoint() ||
1322 N1.getValueType().isFloatingPoint())
1323 if (!Flags.hasAllowReassociation() || !Flags.hasNoSignedZeros())
1324 return SDValue();
1325
1326 if (SDValue Combined = reassociateOpsCommutative(Opc, DL, N0, N1, Flags))
1327 return Combined;
1328 if (SDValue Combined = reassociateOpsCommutative(Opc, DL, N1, N0, Flags))
1329 return Combined;
1330 return SDValue();
1331 }
1332
1333 // Try to fold Opc(vecreduce(x), vecreduce(y)) -> vecreduce(Opc(x, y))
1334 // Note that we only expect Flags to be passed from FP operations. For integer
1335 // operations they need to be dropped.
reassociateReduction(unsigned RedOpc,unsigned Opc,const SDLoc & DL,EVT VT,SDValue N0,SDValue N1,SDNodeFlags Flags)1336 SDValue DAGCombiner::reassociateReduction(unsigned RedOpc, unsigned Opc,
1337 const SDLoc &DL, EVT VT, SDValue N0,
1338 SDValue N1, SDNodeFlags Flags) {
1339 if (N0.getOpcode() == RedOpc && N1.getOpcode() == RedOpc &&
1340 N0.getOperand(0).getValueType() == N1.getOperand(0).getValueType() &&
1341 N0->hasOneUse() && N1->hasOneUse() &&
1342 TLI.isOperationLegalOrCustom(Opc, N0.getOperand(0).getValueType()) &&
1343 TLI.shouldReassociateReduction(RedOpc, N0.getOperand(0).getValueType())) {
1344 SelectionDAG::FlagInserter FlagsInserter(DAG, Flags);
1345 return DAG.getNode(RedOpc, DL, VT,
1346 DAG.getNode(Opc, DL, N0.getOperand(0).getValueType(),
1347 N0.getOperand(0), N1.getOperand(0)));
1348 }
1349
1350 // Reassociate op(op(vecreduce(a), b), op(vecreduce(c), d)) into
1351 // op(vecreduce(op(a, c)), op(b, d)), to combine the reductions into a
1352 // single node.
1353 SDValue A, B, C, D, RedA, RedB;
1354 if (sd_match(N0, m_OneUse(m_c_BinOp(
1355 Opc,
1356 m_AllOf(m_OneUse(m_UnaryOp(RedOpc, m_Value(A))),
1357 m_Value(RedA)),
1358 m_Value(B)))) &&
1359 sd_match(N1, m_OneUse(m_c_BinOp(
1360 Opc,
1361 m_AllOf(m_OneUse(m_UnaryOp(RedOpc, m_Value(C))),
1362 m_Value(RedB)),
1363 m_Value(D)))) &&
1364 !sd_match(B, m_UnaryOp(RedOpc, m_Value())) &&
1365 !sd_match(D, m_UnaryOp(RedOpc, m_Value())) &&
1366 A.getValueType() == C.getValueType() &&
1367 hasOperation(Opc, A.getValueType()) &&
1368 TLI.shouldReassociateReduction(RedOpc, VT)) {
1369 if ((Opc == ISD::FADD || Opc == ISD::FMUL) &&
1370 (!N0->getFlags().hasAllowReassociation() ||
1371 !N1->getFlags().hasAllowReassociation() ||
1372 !RedA->getFlags().hasAllowReassociation() ||
1373 !RedB->getFlags().hasAllowReassociation()))
1374 return SDValue();
1375 SelectionDAG::FlagInserter FlagsInserter(
1376 DAG, Flags & N0->getFlags() & N1->getFlags() & RedA->getFlags() &
1377 RedB->getFlags());
1378 SDValue Op = DAG.getNode(Opc, DL, A.getValueType(), A, C);
1379 SDValue Red = DAG.getNode(RedOpc, DL, VT, Op);
1380 SDValue Op2 = DAG.getNode(Opc, DL, VT, B, D);
1381 return DAG.getNode(Opc, DL, VT, Red, Op2);
1382 }
1383 return SDValue();
1384 }
1385
CombineTo(SDNode * N,const SDValue * To,unsigned NumTo,bool AddTo)1386 SDValue DAGCombiner::CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
1387 bool AddTo) {
1388 assert(N->getNumValues() == NumTo && "Broken CombineTo call!");
1389 ++NodesCombined;
1390 LLVM_DEBUG(dbgs() << "\nReplacing.1 "; N->dump(&DAG); dbgs() << "\nWith: ";
1391 To[0].dump(&DAG);
1392 dbgs() << " and " << NumTo - 1 << " other values\n");
1393 for (unsigned i = 0, e = NumTo; i != e; ++i)
1394 assert((!To[i].getNode() ||
1395 N->getValueType(i) == To[i].getValueType()) &&
1396 "Cannot combine value to value of different type!");
1397
1398 WorklistRemover DeadNodes(*this);
1399 DAG.ReplaceAllUsesWith(N, To);
1400 if (AddTo) {
1401 // Push the new nodes and any users onto the worklist
1402 for (unsigned i = 0, e = NumTo; i != e; ++i) {
1403 if (To[i].getNode())
1404 AddToWorklistWithUsers(To[i].getNode());
1405 }
1406 }
1407
1408 // Finally, if the node is now dead, remove it from the graph. The node
1409 // may not be dead if the replacement process recursively simplified to
1410 // something else needing this node.
1411 if (N->use_empty())
1412 deleteAndRecombine(N);
1413 return SDValue(N, 0);
1414 }
1415
1416 void DAGCombiner::
CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt & TLO)1417 CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) {
1418 // Replace the old value with the new one.
1419 ++NodesCombined;
1420 LLVM_DEBUG(dbgs() << "\nReplacing.2 "; TLO.Old.dump(&DAG);
1421 dbgs() << "\nWith: "; TLO.New.dump(&DAG); dbgs() << '\n');
1422
1423 // Replace all uses.
1424 DAG.ReplaceAllUsesOfValueWith(TLO.Old, TLO.New);
1425
1426 // Push the new node and any (possibly new) users onto the worklist.
1427 AddToWorklistWithUsers(TLO.New.getNode());
1428
1429 // Finally, if the node is now dead, remove it from the graph.
1430 recursivelyDeleteUnusedNodes(TLO.Old.getNode());
1431 }
1432
1433 /// Check the specified integer node value to see if it can be simplified or if
1434 /// things it uses can be simplified by bit propagation. If so, return true.
SimplifyDemandedBits(SDValue Op,const APInt & DemandedBits,const APInt & DemandedElts,bool AssumeSingleUse)1435 bool DAGCombiner::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
1436 const APInt &DemandedElts,
1437 bool AssumeSingleUse) {
1438 TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
1439 KnownBits Known;
1440 if (!TLI.SimplifyDemandedBits(Op, DemandedBits, DemandedElts, Known, TLO, 0,
1441 AssumeSingleUse))
1442 return false;
1443
1444 // Revisit the node.
1445 AddToWorklist(Op.getNode());
1446
1447 CommitTargetLoweringOpt(TLO);
1448 return true;
1449 }
1450
1451 /// Check the specified vector node value to see if it can be simplified or
1452 /// if things it uses can be simplified as it only uses some of the elements.
1453 /// If so, return true.
SimplifyDemandedVectorElts(SDValue Op,const APInt & DemandedElts,bool AssumeSingleUse)1454 bool DAGCombiner::SimplifyDemandedVectorElts(SDValue Op,
1455 const APInt &DemandedElts,
1456 bool AssumeSingleUse) {
1457 TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
1458 APInt KnownUndef, KnownZero;
1459 if (!TLI.SimplifyDemandedVectorElts(Op, DemandedElts, KnownUndef, KnownZero,
1460 TLO, 0, AssumeSingleUse))
1461 return false;
1462
1463 // Revisit the node.
1464 AddToWorklist(Op.getNode());
1465
1466 CommitTargetLoweringOpt(TLO);
1467 return true;
1468 }
1469
ReplaceLoadWithPromotedLoad(SDNode * Load,SDNode * ExtLoad)1470 void DAGCombiner::ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad) {
1471 SDLoc DL(Load);
1472 EVT VT = Load->getValueType(0);
1473 SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, VT, SDValue(ExtLoad, 0));
1474
1475 LLVM_DEBUG(dbgs() << "\nReplacing.9 "; Load->dump(&DAG); dbgs() << "\nWith: ";
1476 Trunc.dump(&DAG); dbgs() << '\n');
1477
1478 DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 0), Trunc);
1479 DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 1), SDValue(ExtLoad, 1));
1480
1481 AddToWorklist(Trunc.getNode());
1482 recursivelyDeleteUnusedNodes(Load);
1483 }
1484
PromoteOperand(SDValue Op,EVT PVT,bool & Replace)1485 SDValue DAGCombiner::PromoteOperand(SDValue Op, EVT PVT, bool &Replace) {
1486 Replace = false;
1487 SDLoc DL(Op);
1488 if (ISD::isUNINDEXEDLoad(Op.getNode())) {
1489 LoadSDNode *LD = cast<LoadSDNode>(Op);
1490 EVT MemVT = LD->getMemoryVT();
1491 ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(LD) ? ISD::EXTLOAD
1492 : LD->getExtensionType();
1493 Replace = true;
1494 return DAG.getExtLoad(ExtType, DL, PVT,
1495 LD->getChain(), LD->getBasePtr(),
1496 MemVT, LD->getMemOperand());
1497 }
1498
1499 unsigned Opc = Op.getOpcode();
1500 switch (Opc) {
1501 default: break;
1502 case ISD::AssertSext:
1503 if (SDValue Op0 = SExtPromoteOperand(Op.getOperand(0), PVT))
1504 return DAG.getNode(ISD::AssertSext, DL, PVT, Op0, Op.getOperand(1));
1505 break;
1506 case ISD::AssertZext:
1507 if (SDValue Op0 = ZExtPromoteOperand(Op.getOperand(0), PVT))
1508 return DAG.getNode(ISD::AssertZext, DL, PVT, Op0, Op.getOperand(1));
1509 break;
1510 case ISD::Constant: {
1511 unsigned ExtOpc =
1512 Op.getValueType().isByteSized() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
1513 return DAG.getNode(ExtOpc, DL, PVT, Op);
1514 }
1515 }
1516
1517 if (!TLI.isOperationLegal(ISD::ANY_EXTEND, PVT))
1518 return SDValue();
1519 return DAG.getNode(ISD::ANY_EXTEND, DL, PVT, Op);
1520 }
1521
SExtPromoteOperand(SDValue Op,EVT PVT)1522 SDValue DAGCombiner::SExtPromoteOperand(SDValue Op, EVT PVT) {
1523 if (!TLI.isOperationLegal(ISD::SIGN_EXTEND_INREG, PVT))
1524 return SDValue();
1525 EVT OldVT = Op.getValueType();
1526 SDLoc DL(Op);
1527 bool Replace = false;
1528 SDValue NewOp = PromoteOperand(Op, PVT, Replace);
1529 if (!NewOp.getNode())
1530 return SDValue();
1531 AddToWorklist(NewOp.getNode());
1532
1533 if (Replace)
1534 ReplaceLoadWithPromotedLoad(Op.getNode(), NewOp.getNode());
1535 return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, NewOp.getValueType(), NewOp,
1536 DAG.getValueType(OldVT));
1537 }
1538
ZExtPromoteOperand(SDValue Op,EVT PVT)1539 SDValue DAGCombiner::ZExtPromoteOperand(SDValue Op, EVT PVT) {
1540 EVT OldVT = Op.getValueType();
1541 SDLoc DL(Op);
1542 bool Replace = false;
1543 SDValue NewOp = PromoteOperand(Op, PVT, Replace);
1544 if (!NewOp.getNode())
1545 return SDValue();
1546 AddToWorklist(NewOp.getNode());
1547
1548 if (Replace)
1549 ReplaceLoadWithPromotedLoad(Op.getNode(), NewOp.getNode());
1550 return DAG.getZeroExtendInReg(NewOp, DL, OldVT);
1551 }
1552
1553 /// Promote the specified integer binary operation if the target indicates it is
1554 /// beneficial. e.g. On x86, it's usually better to promote i16 operations to
1555 /// i32 since i16 instructions are longer.
PromoteIntBinOp(SDValue Op)1556 SDValue DAGCombiner::PromoteIntBinOp(SDValue Op) {
1557 if (!LegalOperations)
1558 return SDValue();
1559
1560 EVT VT = Op.getValueType();
1561 if (VT.isVector() || !VT.isInteger())
1562 return SDValue();
1563
1564 // If operation type is 'undesirable', e.g. i16 on x86, consider
1565 // promoting it.
1566 unsigned Opc = Op.getOpcode();
1567 if (TLI.isTypeDesirableForOp(Opc, VT))
1568 return SDValue();
1569
1570 EVT PVT = VT;
1571 // Consult target whether it is a good idea to promote this operation and
1572 // what's the right type to promote it to.
1573 if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1574 assert(PVT != VT && "Don't know what type to promote to!");
1575
1576 LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG));
1577
1578 bool Replace0 = false;
1579 SDValue N0 = Op.getOperand(0);
1580 SDValue NN0 = PromoteOperand(N0, PVT, Replace0);
1581
1582 bool Replace1 = false;
1583 SDValue N1 = Op.getOperand(1);
1584 SDValue NN1 = PromoteOperand(N1, PVT, Replace1);
1585 SDLoc DL(Op);
1586
1587 SDValue RV =
1588 DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getNode(Opc, DL, PVT, NN0, NN1));
1589
1590 // We are always replacing N0/N1's use in N and only need additional
1591 // replacements if there are additional uses.
1592 // Note: We are checking uses of the *nodes* (SDNode) rather than values
1593 // (SDValue) here because the node may reference multiple values
1594 // (for example, the chain value of a load node).
1595 Replace0 &= !N0->hasOneUse();
1596 Replace1 &= (N0 != N1) && !N1->hasOneUse();
1597
1598 // Combine Op here so it is preserved past replacements.
1599 CombineTo(Op.getNode(), RV);
1600
1601 // If operands have a use ordering, make sure we deal with
1602 // predecessor first.
1603 if (Replace0 && Replace1 && N0->isPredecessorOf(N1.getNode())) {
1604 std::swap(N0, N1);
1605 std::swap(NN0, NN1);
1606 }
1607
1608 if (Replace0) {
1609 AddToWorklist(NN0.getNode());
1610 ReplaceLoadWithPromotedLoad(N0.getNode(), NN0.getNode());
1611 }
1612 if (Replace1) {
1613 AddToWorklist(NN1.getNode());
1614 ReplaceLoadWithPromotedLoad(N1.getNode(), NN1.getNode());
1615 }
1616 return Op;
1617 }
1618 return SDValue();
1619 }
1620
1621 /// Promote the specified integer shift operation if the target indicates it is
1622 /// beneficial. e.g. On x86, it's usually better to promote i16 operations to
1623 /// i32 since i16 instructions are longer.
PromoteIntShiftOp(SDValue Op)1624 SDValue DAGCombiner::PromoteIntShiftOp(SDValue Op) {
1625 if (!LegalOperations)
1626 return SDValue();
1627
1628 EVT VT = Op.getValueType();
1629 if (VT.isVector() || !VT.isInteger())
1630 return SDValue();
1631
1632 // If operation type is 'undesirable', e.g. i16 on x86, consider
1633 // promoting it.
1634 unsigned Opc = Op.getOpcode();
1635 if (TLI.isTypeDesirableForOp(Opc, VT))
1636 return SDValue();
1637
1638 EVT PVT = VT;
1639 // Consult target whether it is a good idea to promote this operation and
1640 // what's the right type to promote it to.
1641 if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1642 assert(PVT != VT && "Don't know what type to promote to!");
1643
1644 LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG));
1645
1646 bool Replace = false;
1647 SDValue N0 = Op.getOperand(0);
1648 if (Opc == ISD::SRA)
1649 N0 = SExtPromoteOperand(N0, PVT);
1650 else if (Opc == ISD::SRL)
1651 N0 = ZExtPromoteOperand(N0, PVT);
1652 else
1653 N0 = PromoteOperand(N0, PVT, Replace);
1654
1655 if (!N0.getNode())
1656 return SDValue();
1657
1658 SDLoc DL(Op);
1659 SDValue N1 = Op.getOperand(1);
1660 SDValue RV =
1661 DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getNode(Opc, DL, PVT, N0, N1));
1662
1663 if (Replace)
1664 ReplaceLoadWithPromotedLoad(Op.getOperand(0).getNode(), N0.getNode());
1665
1666 // Deal with Op being deleted.
1667 if (Op && Op.getOpcode() != ISD::DELETED_NODE)
1668 return RV;
1669 }
1670 return SDValue();
1671 }
1672
PromoteExtend(SDValue Op)1673 SDValue DAGCombiner::PromoteExtend(SDValue Op) {
1674 if (!LegalOperations)
1675 return SDValue();
1676
1677 EVT VT = Op.getValueType();
1678 if (VT.isVector() || !VT.isInteger())
1679 return SDValue();
1680
1681 // If operation type is 'undesirable', e.g. i16 on x86, consider
1682 // promoting it.
1683 unsigned Opc = Op.getOpcode();
1684 if (TLI.isTypeDesirableForOp(Opc, VT))
1685 return SDValue();
1686
1687 EVT PVT = VT;
1688 // Consult target whether it is a good idea to promote this operation and
1689 // what's the right type to promote it to.
1690 if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1691 assert(PVT != VT && "Don't know what type to promote to!");
1692 // fold (aext (aext x)) -> (aext x)
1693 // fold (aext (zext x)) -> (zext x)
1694 // fold (aext (sext x)) -> (sext x)
1695 LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG));
1696 return DAG.getNode(Op.getOpcode(), SDLoc(Op), VT, Op.getOperand(0));
1697 }
1698 return SDValue();
1699 }
1700
PromoteLoad(SDValue Op)1701 bool DAGCombiner::PromoteLoad(SDValue Op) {
1702 if (!LegalOperations)
1703 return false;
1704
1705 if (!ISD::isUNINDEXEDLoad(Op.getNode()))
1706 return false;
1707
1708 EVT VT = Op.getValueType();
1709 if (VT.isVector() || !VT.isInteger())
1710 return false;
1711
1712 // If operation type is 'undesirable', e.g. i16 on x86, consider
1713 // promoting it.
1714 unsigned Opc = Op.getOpcode();
1715 if (TLI.isTypeDesirableForOp(Opc, VT))
1716 return false;
1717
1718 EVT PVT = VT;
1719 // Consult target whether it is a good idea to promote this operation and
1720 // what's the right type to promote it to.
1721 if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1722 assert(PVT != VT && "Don't know what type to promote to!");
1723
1724 SDLoc DL(Op);
1725 SDNode *N = Op.getNode();
1726 LoadSDNode *LD = cast<LoadSDNode>(N);
1727 EVT MemVT = LD->getMemoryVT();
1728 ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(LD) ? ISD::EXTLOAD
1729 : LD->getExtensionType();
1730 SDValue NewLD = DAG.getExtLoad(ExtType, DL, PVT,
1731 LD->getChain(), LD->getBasePtr(),
1732 MemVT, LD->getMemOperand());
1733 SDValue Result = DAG.getNode(ISD::TRUNCATE, DL, VT, NewLD);
1734
1735 LLVM_DEBUG(dbgs() << "\nPromoting "; N->dump(&DAG); dbgs() << "\nTo: ";
1736 Result.dump(&DAG); dbgs() << '\n');
1737
1738 DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result);
1739 DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), NewLD.getValue(1));
1740
1741 AddToWorklist(Result.getNode());
1742 recursivelyDeleteUnusedNodes(N);
1743 return true;
1744 }
1745
1746 return false;
1747 }
1748
1749 /// Recursively delete a node which has no uses and any operands for
1750 /// which it is the only use.
1751 ///
1752 /// Note that this both deletes the nodes and removes them from the worklist.
1753 /// It also adds any nodes who have had a user deleted to the worklist as they
1754 /// may now have only one use and subject to other combines.
recursivelyDeleteUnusedNodes(SDNode * N)1755 bool DAGCombiner::recursivelyDeleteUnusedNodes(SDNode *N) {
1756 if (!N->use_empty())
1757 return false;
1758
1759 SmallSetVector<SDNode *, 16> Nodes;
1760 Nodes.insert(N);
1761 do {
1762 N = Nodes.pop_back_val();
1763 if (!N)
1764 continue;
1765
1766 if (N->use_empty()) {
1767 for (const SDValue &ChildN : N->op_values())
1768 Nodes.insert(ChildN.getNode());
1769
1770 removeFromWorklist(N);
1771 DAG.DeleteNode(N);
1772 } else {
1773 AddToWorklist(N);
1774 }
1775 } while (!Nodes.empty());
1776 return true;
1777 }
1778
1779 //===----------------------------------------------------------------------===//
1780 // Main DAG Combiner implementation
1781 //===----------------------------------------------------------------------===//
1782
Run(CombineLevel AtLevel)1783 void DAGCombiner::Run(CombineLevel AtLevel) {
1784 // set the instance variables, so that the various visit routines may use it.
1785 Level = AtLevel;
1786 LegalDAG = Level >= AfterLegalizeDAG;
1787 LegalOperations = Level >= AfterLegalizeVectorOps;
1788 LegalTypes = Level >= AfterLegalizeTypes;
1789
1790 WorklistInserter AddNodes(*this);
1791
1792 // Add all the dag nodes to the worklist.
1793 //
1794 // Note: All nodes are not added to PruningList here, this is because the only
1795 // nodes which can be deleted are those which have no uses and all other nodes
1796 // which would otherwise be added to the worklist by the first call to
1797 // getNextWorklistEntry are already present in it.
1798 for (SDNode &Node : DAG.allnodes())
1799 AddToWorklist(&Node, /* IsCandidateForPruning */ Node.use_empty());
1800
1801 // Create a dummy node (which is not added to allnodes), that adds a reference
1802 // to the root node, preventing it from being deleted, and tracking any
1803 // changes of the root.
1804 HandleSDNode Dummy(DAG.getRoot());
1805
1806 // While we have a valid worklist entry node, try to combine it.
1807 while (SDNode *N = getNextWorklistEntry()) {
1808 // If N has no uses, it is dead. Make sure to revisit all N's operands once
1809 // N is deleted from the DAG, since they too may now be dead or may have a
1810 // reduced number of uses, allowing other xforms.
1811 if (recursivelyDeleteUnusedNodes(N))
1812 continue;
1813
1814 WorklistRemover DeadNodes(*this);
1815
1816 // If this combine is running after legalizing the DAG, re-legalize any
1817 // nodes pulled off the worklist.
1818 if (LegalDAG) {
1819 SmallSetVector<SDNode *, 16> UpdatedNodes;
1820 bool NIsValid = DAG.LegalizeOp(N, UpdatedNodes);
1821
1822 for (SDNode *LN : UpdatedNodes)
1823 AddToWorklistWithUsers(LN);
1824
1825 if (!NIsValid)
1826 continue;
1827 }
1828
1829 LLVM_DEBUG(dbgs() << "\nCombining: "; N->dump(&DAG));
1830
1831 // Add any operands of the new node which have not yet been combined to the
1832 // worklist as well. getNextWorklistEntry flags nodes that have been
1833 // combined before. Because the worklist uniques things already, this won't
1834 // repeatedly process the same operand.
1835 for (const SDValue &ChildN : N->op_values())
1836 AddToWorklist(ChildN.getNode(), /*IsCandidateForPruning=*/true,
1837 /*SkipIfCombinedBefore=*/true);
1838
1839 SDValue RV = combine(N);
1840
1841 if (!RV.getNode())
1842 continue;
1843
1844 ++NodesCombined;
1845
1846 // Invalidate cached info.
1847 ChainsWithoutMergeableStores.clear();
1848
1849 // If we get back the same node we passed in, rather than a new node or
1850 // zero, we know that the node must have defined multiple values and
1851 // CombineTo was used. Since CombineTo takes care of the worklist
1852 // mechanics for us, we have no work to do in this case.
1853 if (RV.getNode() == N)
1854 continue;
1855
1856 assert(N->getOpcode() != ISD::DELETED_NODE &&
1857 RV.getOpcode() != ISD::DELETED_NODE &&
1858 "Node was deleted but visit returned new node!");
1859
1860 LLVM_DEBUG(dbgs() << " ... into: "; RV.dump(&DAG));
1861
1862 if (N->getNumValues() == RV->getNumValues())
1863 DAG.ReplaceAllUsesWith(N, RV.getNode());
1864 else {
1865 assert(N->getValueType(0) == RV.getValueType() &&
1866 N->getNumValues() == 1 && "Type mismatch");
1867 DAG.ReplaceAllUsesWith(N, &RV);
1868 }
1869
1870 // Push the new node and any users onto the worklist. Omit this if the
1871 // new node is the EntryToken (e.g. if a store managed to get optimized
1872 // out), because re-visiting the EntryToken and its users will not uncover
1873 // any additional opportunities, but there may be a large number of such
1874 // users, potentially causing compile time explosion.
1875 if (RV.getOpcode() != ISD::EntryToken)
1876 AddToWorklistWithUsers(RV.getNode());
1877
1878 // Finally, if the node is now dead, remove it from the graph. The node
1879 // may not be dead if the replacement process recursively simplified to
1880 // something else needing this node. This will also take care of adding any
1881 // operands which have lost a user to the worklist.
1882 recursivelyDeleteUnusedNodes(N);
1883 }
1884
1885 // If the root changed (e.g. it was a dead load, update the root).
1886 DAG.setRoot(Dummy.getValue());
1887 DAG.RemoveDeadNodes();
1888 }
1889
visit(SDNode * N)1890 SDValue DAGCombiner::visit(SDNode *N) {
1891 // clang-format off
1892 switch (N->getOpcode()) {
1893 default: break;
1894 case ISD::TokenFactor: return visitTokenFactor(N);
1895 case ISD::MERGE_VALUES: return visitMERGE_VALUES(N);
1896 case ISD::ADD: return visitADD(N);
1897 case ISD::PTRADD: return visitPTRADD(N);
1898 case ISD::SUB: return visitSUB(N);
1899 case ISD::SADDSAT:
1900 case ISD::UADDSAT: return visitADDSAT(N);
1901 case ISD::SSUBSAT:
1902 case ISD::USUBSAT: return visitSUBSAT(N);
1903 case ISD::ADDC: return visitADDC(N);
1904 case ISD::SADDO:
1905 case ISD::UADDO: return visitADDO(N);
1906 case ISD::SUBC: return visitSUBC(N);
1907 case ISD::SSUBO:
1908 case ISD::USUBO: return visitSUBO(N);
1909 case ISD::ADDE: return visitADDE(N);
1910 case ISD::UADDO_CARRY: return visitUADDO_CARRY(N);
1911 case ISD::SADDO_CARRY: return visitSADDO_CARRY(N);
1912 case ISD::SUBE: return visitSUBE(N);
1913 case ISD::USUBO_CARRY: return visitUSUBO_CARRY(N);
1914 case ISD::SSUBO_CARRY: return visitSSUBO_CARRY(N);
1915 case ISD::SMULFIX:
1916 case ISD::SMULFIXSAT:
1917 case ISD::UMULFIX:
1918 case ISD::UMULFIXSAT: return visitMULFIX(N);
1919 case ISD::MUL: return visitMUL<EmptyMatchContext>(N);
1920 case ISD::SDIV: return visitSDIV(N);
1921 case ISD::UDIV: return visitUDIV(N);
1922 case ISD::SREM:
1923 case ISD::UREM: return visitREM(N);
1924 case ISD::MULHU: return visitMULHU(N);
1925 case ISD::MULHS: return visitMULHS(N);
1926 case ISD::AVGFLOORS:
1927 case ISD::AVGFLOORU:
1928 case ISD::AVGCEILS:
1929 case ISD::AVGCEILU: return visitAVG(N);
1930 case ISD::ABDS:
1931 case ISD::ABDU: return visitABD(N);
1932 case ISD::SMUL_LOHI: return visitSMUL_LOHI(N);
1933 case ISD::UMUL_LOHI: return visitUMUL_LOHI(N);
1934 case ISD::SMULO:
1935 case ISD::UMULO: return visitMULO(N);
1936 case ISD::SMIN:
1937 case ISD::SMAX:
1938 case ISD::UMIN:
1939 case ISD::UMAX: return visitIMINMAX(N);
1940 case ISD::AND: return visitAND(N);
1941 case ISD::OR: return visitOR(N);
1942 case ISD::XOR: return visitXOR(N);
1943 case ISD::SHL: return visitSHL(N);
1944 case ISD::SRA: return visitSRA(N);
1945 case ISD::SRL: return visitSRL(N);
1946 case ISD::ROTR:
1947 case ISD::ROTL: return visitRotate(N);
1948 case ISD::FSHL:
1949 case ISD::FSHR: return visitFunnelShift(N);
1950 case ISD::SSHLSAT:
1951 case ISD::USHLSAT: return visitSHLSAT(N);
1952 case ISD::ABS: return visitABS(N);
1953 case ISD::BSWAP: return visitBSWAP(N);
1954 case ISD::BITREVERSE: return visitBITREVERSE(N);
1955 case ISD::CTLZ: return visitCTLZ(N);
1956 case ISD::CTLZ_ZERO_UNDEF: return visitCTLZ_ZERO_UNDEF(N);
1957 case ISD::CTTZ: return visitCTTZ(N);
1958 case ISD::CTTZ_ZERO_UNDEF: return visitCTTZ_ZERO_UNDEF(N);
1959 case ISD::CTPOP: return visitCTPOP(N);
1960 case ISD::SELECT: return visitSELECT(N);
1961 case ISD::VSELECT: return visitVSELECT(N);
1962 case ISD::SELECT_CC: return visitSELECT_CC(N);
1963 case ISD::SETCC: return visitSETCC(N);
1964 case ISD::SETCCCARRY: return visitSETCCCARRY(N);
1965 case ISD::SIGN_EXTEND: return visitSIGN_EXTEND(N);
1966 case ISD::ZERO_EXTEND: return visitZERO_EXTEND(N);
1967 case ISD::ANY_EXTEND: return visitANY_EXTEND(N);
1968 case ISD::AssertSext:
1969 case ISD::AssertZext: return visitAssertExt(N);
1970 case ISD::AssertAlign: return visitAssertAlign(N);
1971 case ISD::SIGN_EXTEND_INREG: return visitSIGN_EXTEND_INREG(N);
1972 case ISD::SIGN_EXTEND_VECTOR_INREG:
1973 case ISD::ZERO_EXTEND_VECTOR_INREG:
1974 case ISD::ANY_EXTEND_VECTOR_INREG: return visitEXTEND_VECTOR_INREG(N);
1975 case ISD::TRUNCATE: return visitTRUNCATE(N);
1976 case ISD::TRUNCATE_USAT_U: return visitTRUNCATE_USAT_U(N);
1977 case ISD::BITCAST: return visitBITCAST(N);
1978 case ISD::BUILD_PAIR: return visitBUILD_PAIR(N);
1979 case ISD::FADD: return visitFADD(N);
1980 case ISD::STRICT_FADD: return visitSTRICT_FADD(N);
1981 case ISD::FSUB: return visitFSUB(N);
1982 case ISD::FMUL: return visitFMUL(N);
1983 case ISD::FMA: return visitFMA<EmptyMatchContext>(N);
1984 case ISD::FMAD: return visitFMAD(N);
1985 case ISD::FDIV: return visitFDIV(N);
1986 case ISD::FREM: return visitFREM(N);
1987 case ISD::FSQRT: return visitFSQRT(N);
1988 case ISD::FCOPYSIGN: return visitFCOPYSIGN(N);
1989 case ISD::FPOW: return visitFPOW(N);
1990 case ISD::SINT_TO_FP: return visitSINT_TO_FP(N);
1991 case ISD::UINT_TO_FP: return visitUINT_TO_FP(N);
1992 case ISD::FP_TO_SINT: return visitFP_TO_SINT(N);
1993 case ISD::FP_TO_UINT: return visitFP_TO_UINT(N);
1994 case ISD::LROUND:
1995 case ISD::LLROUND:
1996 case ISD::LRINT:
1997 case ISD::LLRINT: return visitXROUND(N);
1998 case ISD::FP_ROUND: return visitFP_ROUND(N);
1999 case ISD::FP_EXTEND: return visitFP_EXTEND(N);
2000 case ISD::FNEG: return visitFNEG(N);
2001 case ISD::FABS: return visitFABS(N);
2002 case ISD::FFLOOR: return visitFFLOOR(N);
2003 case ISD::FMINNUM:
2004 case ISD::FMAXNUM:
2005 case ISD::FMINIMUM:
2006 case ISD::FMAXIMUM:
2007 case ISD::FMINIMUMNUM:
2008 case ISD::FMAXIMUMNUM: return visitFMinMax(N);
2009 case ISD::FCEIL: return visitFCEIL(N);
2010 case ISD::FTRUNC: return visitFTRUNC(N);
2011 case ISD::FFREXP: return visitFFREXP(N);
2012 case ISD::BRCOND: return visitBRCOND(N);
2013 case ISD::BR_CC: return visitBR_CC(N);
2014 case ISD::LOAD: return visitLOAD(N);
2015 case ISD::STORE: return visitSTORE(N);
2016 case ISD::ATOMIC_STORE: return visitATOMIC_STORE(N);
2017 case ISD::INSERT_VECTOR_ELT: return visitINSERT_VECTOR_ELT(N);
2018 case ISD::EXTRACT_VECTOR_ELT: return visitEXTRACT_VECTOR_ELT(N);
2019 case ISD::BUILD_VECTOR: return visitBUILD_VECTOR(N);
2020 case ISD::CONCAT_VECTORS: return visitCONCAT_VECTORS(N);
2021 case ISD::EXTRACT_SUBVECTOR: return visitEXTRACT_SUBVECTOR(N);
2022 case ISD::VECTOR_SHUFFLE: return visitVECTOR_SHUFFLE(N);
2023 case ISD::SCALAR_TO_VECTOR: return visitSCALAR_TO_VECTOR(N);
2024 case ISD::INSERT_SUBVECTOR: return visitINSERT_SUBVECTOR(N);
2025 case ISD::MGATHER: return visitMGATHER(N);
2026 case ISD::MLOAD: return visitMLOAD(N);
2027 case ISD::MSCATTER: return visitMSCATTER(N);
2028 case ISD::MSTORE: return visitMSTORE(N);
2029 case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM: return visitMHISTOGRAM(N);
2030 case ISD::PARTIAL_REDUCE_SMLA:
2031 case ISD::PARTIAL_REDUCE_UMLA:
2032 case ISD::PARTIAL_REDUCE_SUMLA:
2033 return visitPARTIAL_REDUCE_MLA(N);
2034 case ISD::VECTOR_COMPRESS: return visitVECTOR_COMPRESS(N);
2035 case ISD::LIFETIME_END: return visitLIFETIME_END(N);
2036 case ISD::FP_TO_FP16: return visitFP_TO_FP16(N);
2037 case ISD::FP16_TO_FP: return visitFP16_TO_FP(N);
2038 case ISD::FP_TO_BF16: return visitFP_TO_BF16(N);
2039 case ISD::BF16_TO_FP: return visitBF16_TO_FP(N);
2040 case ISD::FREEZE: return visitFREEZE(N);
2041 case ISD::GET_FPENV_MEM: return visitGET_FPENV_MEM(N);
2042 case ISD::SET_FPENV_MEM: return visitSET_FPENV_MEM(N);
2043 case ISD::FCANONICALIZE: return visitFCANONICALIZE(N);
2044 case ISD::VECREDUCE_FADD:
2045 case ISD::VECREDUCE_FMUL:
2046 case ISD::VECREDUCE_ADD:
2047 case ISD::VECREDUCE_MUL:
2048 case ISD::VECREDUCE_AND:
2049 case ISD::VECREDUCE_OR:
2050 case ISD::VECREDUCE_XOR:
2051 case ISD::VECREDUCE_SMAX:
2052 case ISD::VECREDUCE_SMIN:
2053 case ISD::VECREDUCE_UMAX:
2054 case ISD::VECREDUCE_UMIN:
2055 case ISD::VECREDUCE_FMAX:
2056 case ISD::VECREDUCE_FMIN:
2057 case ISD::VECREDUCE_FMAXIMUM:
2058 case ISD::VECREDUCE_FMINIMUM: return visitVECREDUCE(N);
2059 #define BEGIN_REGISTER_VP_SDNODE(SDOPC, ...) case ISD::SDOPC:
2060 #include "llvm/IR/VPIntrinsics.def"
2061 return visitVPOp(N);
2062 }
2063 // clang-format on
2064 return SDValue();
2065 }
2066
combine(SDNode * N)2067 SDValue DAGCombiner::combine(SDNode *N) {
2068 if (!DebugCounter::shouldExecute(DAGCombineCounter))
2069 return SDValue();
2070
2071 SDValue RV;
2072 if (!DisableGenericCombines)
2073 RV = visit(N);
2074
2075 // If nothing happened, try a target-specific DAG combine.
2076 if (!RV.getNode()) {
2077 assert(N->getOpcode() != ISD::DELETED_NODE &&
2078 "Node was deleted but visit returned NULL!");
2079
2080 if (N->getOpcode() >= ISD::BUILTIN_OP_END ||
2081 TLI.hasTargetDAGCombine((ISD::NodeType)N->getOpcode())) {
2082
2083 // Expose the DAG combiner to the target combiner impls.
2084 TargetLowering::DAGCombinerInfo
2085 DagCombineInfo(DAG, Level, false, this);
2086
2087 RV = TLI.PerformDAGCombine(N, DagCombineInfo);
2088 }
2089 }
2090
2091 // If nothing happened still, try promoting the operation.
2092 if (!RV.getNode()) {
2093 switch (N->getOpcode()) {
2094 default: break;
2095 case ISD::ADD:
2096 case ISD::SUB:
2097 case ISD::MUL:
2098 case ISD::AND:
2099 case ISD::OR:
2100 case ISD::XOR:
2101 RV = PromoteIntBinOp(SDValue(N, 0));
2102 break;
2103 case ISD::SHL:
2104 case ISD::SRA:
2105 case ISD::SRL:
2106 RV = PromoteIntShiftOp(SDValue(N, 0));
2107 break;
2108 case ISD::SIGN_EXTEND:
2109 case ISD::ZERO_EXTEND:
2110 case ISD::ANY_EXTEND:
2111 RV = PromoteExtend(SDValue(N, 0));
2112 break;
2113 case ISD::LOAD:
2114 if (PromoteLoad(SDValue(N, 0)))
2115 RV = SDValue(N, 0);
2116 break;
2117 }
2118 }
2119
2120 // If N is a commutative binary node, try to eliminate it if the commuted
2121 // version is already present in the DAG.
2122 if (!RV.getNode() && TLI.isCommutativeBinOp(N->getOpcode())) {
2123 SDValue N0 = N->getOperand(0);
2124 SDValue N1 = N->getOperand(1);
2125
2126 // Constant operands are canonicalized to RHS.
2127 if (N0 != N1 && (isa<ConstantSDNode>(N0) || !isa<ConstantSDNode>(N1))) {
2128 SDValue Ops[] = {N1, N0};
2129 SDNode *CSENode = DAG.getNodeIfExists(N->getOpcode(), N->getVTList(), Ops,
2130 N->getFlags());
2131 if (CSENode)
2132 return SDValue(CSENode, 0);
2133 }
2134 }
2135
2136 return RV;
2137 }
2138
2139 /// Given a node, return its input chain if it has one, otherwise return a null
2140 /// sd operand.
getInputChainForNode(SDNode * N)2141 static SDValue getInputChainForNode(SDNode *N) {
2142 if (unsigned NumOps = N->getNumOperands()) {
2143 if (N->getOperand(0).getValueType() == MVT::Other)
2144 return N->getOperand(0);
2145 if (N->getOperand(NumOps-1).getValueType() == MVT::Other)
2146 return N->getOperand(NumOps-1);
2147 for (unsigned i = 1; i < NumOps-1; ++i)
2148 if (N->getOperand(i).getValueType() == MVT::Other)
2149 return N->getOperand(i);
2150 }
2151 return SDValue();
2152 }
2153
visitFCANONICALIZE(SDNode * N)2154 SDValue DAGCombiner::visitFCANONICALIZE(SDNode *N) {
2155 SDValue Operand = N->getOperand(0);
2156 EVT VT = Operand.getValueType();
2157 SDLoc dl(N);
2158
2159 // Canonicalize undef to quiet NaN.
2160 if (Operand.isUndef()) {
2161 APFloat CanonicalQNaN = APFloat::getQNaN(VT.getFltSemantics());
2162 return DAG.getConstantFP(CanonicalQNaN, dl, VT);
2163 }
2164 return SDValue();
2165 }
2166
/// Simplify a TokenFactor node.
///
/// In order, this:
///  - returns one operand directly when the node has two operands and one of
///    them already has the other as its input chain;
///  - inlines the operands of single-use TokenFactor operands, dropping
///    EntryToken and duplicate operands along the way;
///  - prunes operands that are reached while walking up the chain of another
///    operand.
///
/// Returns the replacement chain value, or a null SDValue if nothing changed
/// (also when optimizing at CodeGenOptLevel::None or when the node has more
/// than TokenFactorInlineLimit operands).
SDValue DAGCombiner::visitTokenFactor(SDNode *N) {
  // If N has two operands, where one has an input chain equal to the other,
  // the 'other' chain is redundant.
  if (N->getNumOperands() == 2) {
    if (getInputChainForNode(N->getOperand(0).getNode()) == N->getOperand(1))
      return N->getOperand(0);
    if (getInputChainForNode(N->getOperand(1).getNode()) == N->getOperand(0))
      return N->getOperand(1);
  }

  // Don't simplify token factors if optnone.
  if (OptLevel == CodeGenOptLevel::None)
    return SDValue();

  // Don't simplify the token factor if the node itself has too many operands.
  if (N->getNumOperands() > TokenFactorInlineLimit)
    return SDValue();

  // If the sole user is a token factor, we should make sure we have a
  // chance to merge them together. This prevents TF chains from inhibiting
  // optimizations.
  if (N->hasOneUse() && N->user_begin()->getOpcode() == ISD::TokenFactor)
    AddToWorklist(*(N->user_begin()));

  SmallVector<SDNode *, 8> TFs;     // List of token factors to visit.
  SmallVector<SDValue, 8> Ops;      // Ops for replacing token factor.
  SmallPtrSet<SDNode*, 16> SeenOps; // Nodes already present in Ops.
  bool Changed = false;             // If we should replace this token factor.

  // Start out with this token factor.
  TFs.push_back(N);

  // Iterate through token factors. TFs grows when new token factors are
  // encountered, so the size is re-read on every iteration.
  for (unsigned i = 0; i < TFs.size(); ++i) {
    // Limit number of nodes to inline, to avoid quadratic compile times.
    // We have to add the outstanding Token Factors to Ops, otherwise we might
    // drop Ops from the resulting Token Factors.
    if (Ops.size() > TokenFactorInlineLimit) {
      for (unsigned j = i; j < TFs.size(); j++)
        Ops.emplace_back(TFs[j], 0);
      // Drop unprocessed Token Factors from TFs, so we do not add them to the
      // combiner worklist later.
      TFs.resize(i);
      break;
    }

    SDNode *TF = TFs[i];
    // Check each of the operands.
    for (const SDValue &Op : TF->op_values()) {
      switch (Op.getOpcode()) {
      case ISD::EntryToken:
        // Entry tokens don't need to be added to the list. They are
        // redundant.
        Changed = true;
        break;

      case ISD::TokenFactor:
        // Only single-use token factors that are not queued yet are inlined;
        // all others are kept as ordinary operands via the fallthrough.
        if (Op.hasOneUse() && !is_contained(TFs, Op.getNode())) {
          // Queue up for processing.
          TFs.push_back(Op.getNode());
          Changed = true;
          break;
        }
        [[fallthrough]];

      default:
        // Only add if it isn't already in the list.
        if (SeenOps.insert(Op.getNode()).second)
          Ops.push_back(Op);
        else
          Changed = true;
        break;
      }
    }
  }

  // Re-visit inlined Token Factors, to clean them up in case they have been
  // removed. Skip the first Token Factor, as this is the current node.
  for (unsigned i = 1, e = TFs.size(); i < e; i++)
    AddToWorklist(TFs[i]);

  // Remove Nodes that are chained to another node in the list. Do so
  // by walking up chains breadth-first stopping when we've seen
  // another operand. In general we must climb to the EntryNode, but we can exit
  // early if we find all remaining work is associated with just one operand as
  // no further pruning is possible.

  // List of nodes to search through and original Ops from which they originate.
  SmallVector<std::pair<SDNode *, unsigned>, 8> Worklist;
  SmallVector<unsigned, 8> OpWorkCount; // Count of work for each Op.
  SmallPtrSet<SDNode *, 16> SeenChains; // Chain nodes reached by the walk.
  bool DidPruneOps = false;

  // Seed the search with one worklist entry per operand; the entry's second
  // member is the index of the operand in Ops.
  unsigned NumLeftToConsider = 0;
  for (const SDValue &Op : Ops) {
    Worklist.push_back(std::make_pair(Op.getNode(), NumLeftToConsider++));
    OpWorkCount.push_back(1);
  }

  // Note: this local lambda shadows the DAGCombiner::AddToWorklist member for
  // the remainder of the function.
  auto AddToWorklist = [&](unsigned CurIdx, SDNode *Op, unsigned OpNumber) {
    // If this is an Op, we can remove the op from the list. Remark any
    // search associated with it as from the current OpNumber.
    if (SeenOps.contains(Op)) {
      Changed = true;
      DidPruneOps = true;
      // Linear search for the index of Op within Ops.
      unsigned OrigOpNumber = 0;
      while (OrigOpNumber < Ops.size() && Ops[OrigOpNumber].getNode() != Op)
        OrigOpNumber++;
      assert((OrigOpNumber != Ops.size()) &&
             "expected to find TokenFactor Operand");
      // Re-mark worklist from OrigOpNumber to OpNumber
      for (unsigned i = CurIdx + 1; i < Worklist.size(); ++i) {
        if (Worklist[i].second == OrigOpNumber) {
          Worklist[i].second = OpNumber;
        }
      }
      // Transfer the pending work of the pruned operand to the current one.
      OpWorkCount[OpNumber] += OpWorkCount[OrigOpNumber];
      OpWorkCount[OrigOpNumber] = 0;
      NumLeftToConsider--;
    }
    // Add if it's a new chain
    if (SeenChains.insert(Op).second) {
      OpWorkCount[OpNumber]++;
      Worklist.push_back(std::make_pair(Op, OpNumber));
    }
  };

  // Walk the worklist; at most the first 1024 entries are examined.
  for (unsigned i = 0; i < Worklist.size() && i < 1024; ++i) {
    // We need to be considering at least 2 Ops to prune.
    if (NumLeftToConsider <= 1)
      break;
    auto CurNode = Worklist[i].first;
    auto CurOpNumber = Worklist[i].second;
    assert((OpWorkCount[CurOpNumber] > 0) &&
           "Node should not appear in worklist");
    switch (CurNode->getOpcode()) {
    case ISD::EntryToken:
      // Hitting EntryToken is the only way for the search to terminate without
      // hitting another operand's search. Prevent us from marking this operand
      // considered.
      NumLeftToConsider++;
      break;
    case ISD::TokenFactor:
      // Continue the walk through every incoming chain.
      for (const SDValue &Op : CurNode->op_values())
        AddToWorklist(i, Op.getNode(), CurOpNumber);
      break;
    case ISD::LIFETIME_START:
    case ISD::LIFETIME_END:
    case ISD::CopyFromReg:
    case ISD::CopyToReg:
      // For these nodes the walk continues through operand 0.
      AddToWorklist(i, CurNode->getOperand(0).getNode(), CurOpNumber);
      break;
    default:
      // Memory nodes are followed through their chain; for any other node
      // the walk stops here.
      if (auto *MemNode = dyn_cast<MemSDNode>(CurNode))
        AddToWorklist(i, MemNode->getChain().getNode(), CurOpNumber);
      break;
    }
    // This worklist entry is done; once an operand has no work left it no
    // longer needs to be considered.
    OpWorkCount[CurOpNumber]--;
    if (OpWorkCount[CurOpNumber] == 0)
      NumLeftToConsider--;
  }

  // If we've changed things around then replace token factor.
  if (Changed) {
    SDValue Result;
    if (Ops.empty()) {
      // The entry token is the only possible outcome.
      Result = DAG.getEntryNode();
    } else {
      if (DidPruneOps) {
        SmallVector<SDValue, 8> PrunedOps;
        // Keep only the operands that were not reached while walking up the
        // chain of another operand.
        for (const SDValue &Op : Ops) {
          if (SeenChains.count(Op.getNode()) == 0)
            PrunedOps.push_back(Op);
        }
        Result = DAG.getTokenFactor(SDLoc(N), PrunedOps);
      } else {
        Result = DAG.getTokenFactor(SDLoc(N), Ops);
      }
    }
    return Result;
  }
  return SDValue();
}
2354
2355 /// MERGE_VALUES can always be eliminated.
visitMERGE_VALUES(SDNode * N)2356 SDValue DAGCombiner::visitMERGE_VALUES(SDNode *N) {
2357 WorklistRemover DeadNodes(*this);
2358 // Replacing results may cause a different MERGE_VALUES to suddenly
2359 // be CSE'd with N, and carry its uses with it. Iterate until no
2360 // uses remain, to ensure that the node can be safely deleted.
2361 // First add the users of this node to the work list so that they
2362 // can be tried again once they have new operands.
2363 AddUsersToWorklist(N);
2364 do {
2365 // Do as a single replacement to avoid rewalking use lists.
2366 SmallVector<SDValue, 8> Ops(N->ops());
2367 DAG.ReplaceAllUsesWith(N, Ops.data());
2368 } while (!N->use_empty());
2369 deleteAndRecombine(N);
2370 return SDValue(N, 0); // Return N so it doesn't get rechecked!
2371 }
2372
2373 /// If \p N is a ConstantSDNode with isOpaque() == false return it casted to a
2374 /// ConstantSDNode pointer else nullptr.
getAsNonOpaqueConstant(SDValue N)2375 static ConstantSDNode *getAsNonOpaqueConstant(SDValue N) {
2376 ConstantSDNode *Const = dyn_cast<ConstantSDNode>(N);
2377 return Const != nullptr && !Const->isOpaque() ? Const : nullptr;
2378 }
2379
2380 // isTruncateOf - If N is a truncate of some other value, return true, record
2381 // the value being truncated in Op and which of Op's bits are zero/one in Known.
2382 // This function computes KnownBits to avoid a duplicated call to
2383 // computeKnownBits in the caller.
isTruncateOf(SelectionDAG & DAG,SDValue N,SDValue & Op,KnownBits & Known)2384 static bool isTruncateOf(SelectionDAG &DAG, SDValue N, SDValue &Op,
2385 KnownBits &Known) {
2386 if (N->getOpcode() == ISD::TRUNCATE) {
2387 Op = N->getOperand(0);
2388 Known = DAG.computeKnownBits(Op);
2389 if (N->getFlags().hasNoUnsignedWrap())
2390 Known.Zero.setBitsFrom(N.getScalarValueSizeInBits());
2391 return true;
2392 }
2393
2394 if (N.getValueType().getScalarType() != MVT::i1 ||
2395 !sd_match(
2396 N, m_c_SetCC(m_Value(Op), m_Zero(), m_SpecificCondCode(ISD::SETNE))))
2397 return false;
2398
2399 Known = DAG.computeKnownBits(Op);
2400 return (Known.Zero | 1).isAllOnes();
2401 }
2402
2403 /// Return true if 'Use' is a load or a store that uses N as its base pointer
2404 /// and that N may be folded in the load / store addressing mode.
canFoldInAddressingMode(SDNode * N,SDNode * Use,SelectionDAG & DAG,const TargetLowering & TLI)2405 static bool canFoldInAddressingMode(SDNode *N, SDNode *Use, SelectionDAG &DAG,
2406 const TargetLowering &TLI) {
2407 EVT VT;
2408 unsigned AS;
2409
2410 if (LoadSDNode *LD = dyn_cast<LoadSDNode>(Use)) {
2411 if (LD->isIndexed() || LD->getBasePtr().getNode() != N)
2412 return false;
2413 VT = LD->getMemoryVT();
2414 AS = LD->getAddressSpace();
2415 } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(Use)) {
2416 if (ST->isIndexed() || ST->getBasePtr().getNode() != N)
2417 return false;
2418 VT = ST->getMemoryVT();
2419 AS = ST->getAddressSpace();
2420 } else if (MaskedLoadSDNode *LD = dyn_cast<MaskedLoadSDNode>(Use)) {
2421 if (LD->isIndexed() || LD->getBasePtr().getNode() != N)
2422 return false;
2423 VT = LD->getMemoryVT();
2424 AS = LD->getAddressSpace();
2425 } else if (MaskedStoreSDNode *ST = dyn_cast<MaskedStoreSDNode>(Use)) {
2426 if (ST->isIndexed() || ST->getBasePtr().getNode() != N)
2427 return false;
2428 VT = ST->getMemoryVT();
2429 AS = ST->getAddressSpace();
2430 } else {
2431 return false;
2432 }
2433
2434 TargetLowering::AddrMode AM;
2435 if (N->isAnyAdd()) {
2436 AM.HasBaseReg = true;
2437 ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(N->getOperand(1));
2438 if (Offset)
2439 // [reg +/- imm]
2440 AM.BaseOffs = Offset->getSExtValue();
2441 else
2442 // [reg +/- reg]
2443 AM.Scale = 1;
2444 } else if (N->getOpcode() == ISD::SUB) {
2445 AM.HasBaseReg = true;
2446 ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(N->getOperand(1));
2447 if (Offset)
2448 // [reg +/- imm]
2449 AM.BaseOffs = -Offset->getSExtValue();
2450 else
2451 // [reg +/- reg]
2452 AM.Scale = 1;
2453 } else {
2454 return false;
2455 }
2456
2457 return TLI.isLegalAddressingMode(DAG.getDataLayout(), AM,
2458 VT.getTypeForEVT(*DAG.getContext()), AS);
2459 }
2460
2461 /// This inverts a canonicalization in IR that replaces a variable select arm
2462 /// with an identity constant. Codegen improves if we re-use the variable
2463 /// operand rather than load a constant. This can also be converted into a
2464 /// masked vector operation if the target supports it.
foldSelectWithIdentityConstant(SDNode * N,SelectionDAG & DAG,bool ShouldCommuteOperands)2465 static SDValue foldSelectWithIdentityConstant(SDNode *N, SelectionDAG &DAG,
2466 bool ShouldCommuteOperands) {
2467 // Match a select as operand 1. The identity constant that we are looking for
2468 // is only valid as operand 1 of a non-commutative binop.
2469 SDValue N0 = N->getOperand(0);
2470 SDValue N1 = N->getOperand(1);
2471 if (ShouldCommuteOperands)
2472 std::swap(N0, N1);
2473
2474 unsigned SelOpcode = N1.getOpcode();
2475 if ((SelOpcode != ISD::VSELECT && SelOpcode != ISD::SELECT) ||
2476 !N1.hasOneUse())
2477 return SDValue();
2478
2479 // We can't hoist all instructions because of immediate UB (not speculatable).
2480 // For example div/rem by zero.
2481 if (!DAG.isSafeToSpeculativelyExecuteNode(N))
2482 return SDValue();
2483
2484 unsigned Opcode = N->getOpcode();
2485 EVT VT = N->getValueType(0);
2486 SDValue Cond = N1.getOperand(0);
2487 SDValue TVal = N1.getOperand(1);
2488 SDValue FVal = N1.getOperand(2);
2489 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
2490
2491 // This transform increases uses of N0, so freeze it to be safe.
2492 // binop N0, (vselect Cond, IDC, FVal) --> vselect Cond, N0, (binop N0, FVal)
2493 unsigned OpNo = ShouldCommuteOperands ? 0 : 1;
2494 if (isNeutralConstant(Opcode, N->getFlags(), TVal, OpNo) &&
2495 TLI.shouldFoldSelectWithIdentityConstant(Opcode, VT, SelOpcode, N0,
2496 FVal)) {
2497 SDValue F0 = DAG.getFreeze(N0);
2498 SDValue NewBO = DAG.getNode(Opcode, SDLoc(N), VT, F0, FVal, N->getFlags());
2499 return DAG.getSelect(SDLoc(N), VT, Cond, F0, NewBO);
2500 }
2501 // binop N0, (vselect Cond, TVal, IDC) --> vselect Cond, (binop N0, TVal), N0
2502 if (isNeutralConstant(Opcode, N->getFlags(), FVal, OpNo) &&
2503 TLI.shouldFoldSelectWithIdentityConstant(Opcode, VT, SelOpcode, N0,
2504 TVal)) {
2505 SDValue F0 = DAG.getFreeze(N0);
2506 SDValue NewBO = DAG.getNode(Opcode, SDLoc(N), VT, F0, TVal, N->getFlags());
2507 return DAG.getSelect(SDLoc(N), VT, Cond, NewBO, F0);
2508 }
2509
2510 return SDValue();
2511 }
2512
foldBinOpIntoSelect(SDNode * BO)2513 SDValue DAGCombiner::foldBinOpIntoSelect(SDNode *BO) {
2514 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
2515 assert(TLI.isBinOp(BO->getOpcode()) && BO->getNumValues() == 1 &&
2516 "Unexpected binary operator");
2517
2518 if (SDValue Sel = foldSelectWithIdentityConstant(BO, DAG, false))
2519 return Sel;
2520
2521 if (TLI.isCommutativeBinOp(BO->getOpcode()))
2522 if (SDValue Sel = foldSelectWithIdentityConstant(BO, DAG, true))
2523 return Sel;
2524
2525 // Don't do this unless the old select is going away. We want to eliminate the
2526 // binary operator, not replace a binop with a select.
2527 // TODO: Handle ISD::SELECT_CC.
2528 unsigned SelOpNo = 0;
2529 SDValue Sel = BO->getOperand(0);
2530 auto BinOpcode = BO->getOpcode();
2531 if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse()) {
2532 SelOpNo = 1;
2533 Sel = BO->getOperand(1);
2534
2535 // Peek through trunc to shift amount type.
2536 if ((BinOpcode == ISD::SHL || BinOpcode == ISD::SRA ||
2537 BinOpcode == ISD::SRL) && Sel.hasOneUse()) {
2538 // This is valid when the truncated bits of x are already zero.
2539 SDValue Op;
2540 KnownBits Known;
2541 if (isTruncateOf(DAG, Sel, Op, Known) &&
2542 Known.countMaxActiveBits() < Sel.getScalarValueSizeInBits())
2543 Sel = Op;
2544 }
2545 }
2546
2547 if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse())
2548 return SDValue();
2549
2550 SDValue CT = Sel.getOperand(1);
2551 if (!isConstantOrConstantVector(CT, true) &&
2552 !DAG.isConstantFPBuildVectorOrConstantFP(CT))
2553 return SDValue();
2554
2555 SDValue CF = Sel.getOperand(2);
2556 if (!isConstantOrConstantVector(CF, true) &&
2557 !DAG.isConstantFPBuildVectorOrConstantFP(CF))
2558 return SDValue();
2559
2560 // Bail out if any constants are opaque because we can't constant fold those.
2561 // The exception is "and" and "or" with either 0 or -1 in which case we can
2562 // propagate non constant operands into select. I.e.:
2563 // and (select Cond, 0, -1), X --> select Cond, 0, X
2564 // or X, (select Cond, -1, 0) --> select Cond, -1, X
2565 bool CanFoldNonConst =
2566 (BinOpcode == ISD::AND || BinOpcode == ISD::OR) &&
2567 ((isNullOrNullSplat(CT) && isAllOnesOrAllOnesSplat(CF)) ||
2568 (isNullOrNullSplat(CF) && isAllOnesOrAllOnesSplat(CT)));
2569
2570 SDValue CBO = BO->getOperand(SelOpNo ^ 1);
2571 if (!CanFoldNonConst &&
2572 !isConstantOrConstantVector(CBO, true) &&
2573 !DAG.isConstantFPBuildVectorOrConstantFP(CBO))
2574 return SDValue();
2575
2576 SDLoc DL(Sel);
2577 SDValue NewCT, NewCF;
2578 EVT VT = BO->getValueType(0);
2579
2580 if (CanFoldNonConst) {
2581 // If CBO is an opaque constant, we can't rely on getNode to constant fold.
2582 if ((BinOpcode == ISD::AND && isNullOrNullSplat(CT)) ||
2583 (BinOpcode == ISD::OR && isAllOnesOrAllOnesSplat(CT)))
2584 NewCT = CT;
2585 else
2586 NewCT = CBO;
2587
2588 if ((BinOpcode == ISD::AND && isNullOrNullSplat(CF)) ||
2589 (BinOpcode == ISD::OR && isAllOnesOrAllOnesSplat(CF)))
2590 NewCF = CF;
2591 else
2592 NewCF = CBO;
2593 } else {
2594 // We have a select-of-constants followed by a binary operator with a
2595 // constant. Eliminate the binop by pulling the constant math into the
2596 // select. Example: add (select Cond, CT, CF), CBO --> select Cond, CT +
2597 // CBO, CF + CBO
2598 NewCT = SelOpNo ? DAG.FoldConstantArithmetic(BinOpcode, DL, VT, {CBO, CT})
2599 : DAG.FoldConstantArithmetic(BinOpcode, DL, VT, {CT, CBO});
2600 if (!NewCT)
2601 return SDValue();
2602
2603 NewCF = SelOpNo ? DAG.FoldConstantArithmetic(BinOpcode, DL, VT, {CBO, CF})
2604 : DAG.FoldConstantArithmetic(BinOpcode, DL, VT, {CF, CBO});
2605 if (!NewCF)
2606 return SDValue();
2607 }
2608
2609 return DAG.getSelect(DL, VT, Sel.getOperand(0), NewCT, NewCF, BO->getFlags());
2610 }
2611
foldAddSubBoolOfMaskedVal(SDNode * N,const SDLoc & DL,SelectionDAG & DAG)2612 static SDValue foldAddSubBoolOfMaskedVal(SDNode *N, const SDLoc &DL,
2613 SelectionDAG &DAG) {
2614 assert((N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::SUB) &&
2615 "Expecting add or sub");
2616
2617 // Match a constant operand and a zext operand for the math instruction:
2618 // add Z, C
2619 // sub C, Z
2620 bool IsAdd = N->getOpcode() == ISD::ADD;
2621 SDValue C = IsAdd ? N->getOperand(1) : N->getOperand(0);
2622 SDValue Z = IsAdd ? N->getOperand(0) : N->getOperand(1);
2623 auto *CN = dyn_cast<ConstantSDNode>(C);
2624 if (!CN || Z.getOpcode() != ISD::ZERO_EXTEND)
2625 return SDValue();
2626
2627 // Match the zext operand as a setcc of a boolean.
2628 if (Z.getOperand(0).getValueType() != MVT::i1)
2629 return SDValue();
2630
2631 // Match the compare as: setcc (X & 1), 0, eq.
2632 if (!sd_match(Z.getOperand(0), m_SetCC(m_And(m_Value(), m_One()), m_Zero(),
2633 m_SpecificCondCode(ISD::SETEQ))))
2634 return SDValue();
2635
2636 // We are adding/subtracting a constant and an inverted low bit. Turn that
2637 // into a subtract/add of the low bit with incremented/decremented constant:
2638 // add (zext i1 (seteq (X & 1), 0)), C --> sub C+1, (zext (X & 1))
2639 // sub C, (zext i1 (seteq (X & 1), 0)) --> add C-1, (zext (X & 1))
2640 EVT VT = C.getValueType();
2641 SDValue LowBit = DAG.getZExtOrTrunc(Z.getOperand(0).getOperand(0), DL, VT);
2642 SDValue C1 = IsAdd ? DAG.getConstant(CN->getAPIntValue() + 1, DL, VT)
2643 : DAG.getConstant(CN->getAPIntValue() - 1, DL, VT);
2644 return DAG.getNode(IsAdd ? ISD::SUB : ISD::ADD, DL, VT, C1, LowBit);
2645 }
2646
2647 // Attempt to form avgceil(A, B) from (A | B) - ((A ^ B) >> 1)
foldSubToAvg(SDNode * N,const SDLoc & DL)2648 SDValue DAGCombiner::foldSubToAvg(SDNode *N, const SDLoc &DL) {
2649 SDValue N0 = N->getOperand(0);
2650 EVT VT = N0.getValueType();
2651 SDValue A, B;
2652
2653 if ((!LegalOperations || hasOperation(ISD::AVGCEILU, VT)) &&
2654 sd_match(N, m_Sub(m_Or(m_Value(A), m_Value(B)),
2655 m_Srl(m_Xor(m_Deferred(A), m_Deferred(B)), m_One())))) {
2656 return DAG.getNode(ISD::AVGCEILU, DL, VT, A, B);
2657 }
2658 if ((!LegalOperations || hasOperation(ISD::AVGCEILS, VT)) &&
2659 sd_match(N, m_Sub(m_Or(m_Value(A), m_Value(B)),
2660 m_Sra(m_Xor(m_Deferred(A), m_Deferred(B)), m_One())))) {
2661 return DAG.getNode(ISD::AVGCEILS, DL, VT, A, B);
2662 }
2663 return SDValue();
2664 }
2665
2666 /// Try to fold a pointer arithmetic node.
2667 /// This needs to be done separately from normal addition, because pointer
2668 /// addition is not commutative.
visitPTRADD(SDNode * N)2669 SDValue DAGCombiner::visitPTRADD(SDNode *N) {
2670 SDValue N0 = N->getOperand(0);
2671 SDValue N1 = N->getOperand(1);
2672 EVT PtrVT = N0.getValueType();
2673 EVT IntVT = N1.getValueType();
2674 SDLoc DL(N);
2675
2676 // This is already ensured by an assert in SelectionDAG::getNode(). Several
2677 // combines here depend on this assumption.
2678 assert(PtrVT == IntVT &&
2679 "PTRADD with different operand types is not supported");
2680
2681 // fold (ptradd x, 0) -> x
2682 if (isNullConstant(N1))
2683 return N0;
2684
2685 // fold (ptradd 0, x) -> x
2686 if (PtrVT == IntVT && isNullConstant(N0))
2687 return N1;
2688
2689 if (N0.getOpcode() != ISD::PTRADD ||
2690 reassociationCanBreakAddressingModePattern(ISD::PTRADD, DL, N, N0, N1))
2691 return SDValue();
2692
2693 SDValue X = N0.getOperand(0);
2694 SDValue Y = N0.getOperand(1);
2695 SDValue Z = N1;
2696 bool N0OneUse = N0.hasOneUse();
2697 bool YIsConstant = DAG.isConstantIntBuildVectorOrConstantInt(Y);
2698 bool ZIsConstant = DAG.isConstantIntBuildVectorOrConstantInt(Z);
2699
2700 // (ptradd (ptradd x, y), z) -> (ptradd x, (add y, z)) if:
2701 // * y is a constant and (ptradd x, y) has one use; or
2702 // * y and z are both constants.
2703 if ((YIsConstant && N0OneUse) || (YIsConstant && ZIsConstant)) {
2704 // If both additions in the original were NUW, the new ones are as well.
2705 SDNodeFlags Flags =
2706 (N->getFlags() & N0->getFlags()) & SDNodeFlags::NoUnsignedWrap;
2707 SDValue Add = DAG.getNode(ISD::ADD, DL, IntVT, {Y, Z}, Flags);
2708 AddToWorklist(Add.getNode());
2709 return DAG.getMemBasePlusOffset(X, Add, DL, Flags);
2710 }
2711
2712 // TODO: There is another possible fold here that was proven useful.
2713 // It would be this:
2714 //
2715 // (ptradd (ptradd x, y), z) -> (ptradd (ptradd x, z), y) if:
2716 // * (ptradd x, y) has one use; and
2717 // * y is a constant; and
2718 // * z is not a constant.
2719 //
2720 // In some cases, specifically in AArch64's FEAT_CPA, it exposes the
2721 // opportunity to select more complex instructions such as SUBPT and
2722 // MSUBPT. However, a hypothetical corner case has been found that we could
2723 // not avoid. Consider this (pseudo-POSIX C):
2724 //
2725 // char *foo(char *x, int z) {return (x + LARGE_CONSTANT) + z;}
2726 // char *p = mmap(LARGE_CONSTANT);
2727 // char *q = foo(p, -LARGE_CONSTANT);
2728 //
2729 // Then x + LARGE_CONSTANT is one-past-the-end, so valid, and a
2730 // further + z takes it back to the start of the mapping, so valid,
2731 // regardless of the address mmap gave back. However, if mmap gives you an
2732 // address < LARGE_CONSTANT (ignoring high bits), x - LARGE_CONSTANT will
2733 // borrow from the high bits (with the subsequent + z carrying back into
2734 // the high bits to give you a well-defined pointer) and thus trip
2735 // FEAT_CPA's pointer corruption checks.
2736 //
2737 // We leave this fold as an opportunity for future work, addressing the
2738 // corner case for FEAT_CPA, as well as reconciling the solution with the
2739 // more general application of pointer arithmetic in other future targets.
2740 // For now each architecture that wants this fold must implement it in the
2741 // target-specific code (see e.g. SITargetLowering::performPtrAddCombine)
2742
2743 return SDValue();
2744 }
2745
2746 /// Try to fold a 'not' shifted sign-bit with add/sub with constant operand into
2747 /// a shift and add with a different constant.
foldAddSubOfSignBit(SDNode * N,const SDLoc & DL,SelectionDAG & DAG)2748 static SDValue foldAddSubOfSignBit(SDNode *N, const SDLoc &DL,
2749 SelectionDAG &DAG) {
2750 assert((N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::SUB) &&
2751 "Expecting add or sub");
2752
2753 // We need a constant operand for the add/sub, and the other operand is a
2754 // logical shift right: add (srl), C or sub C, (srl).
2755 bool IsAdd = N->getOpcode() == ISD::ADD;
2756 SDValue ConstantOp = IsAdd ? N->getOperand(1) : N->getOperand(0);
2757 SDValue ShiftOp = IsAdd ? N->getOperand(0) : N->getOperand(1);
2758 if (!DAG.isConstantIntBuildVectorOrConstantInt(ConstantOp) ||
2759 ShiftOp.getOpcode() != ISD::SRL)
2760 return SDValue();
2761
2762 // The shift must be of a 'not' value.
2763 SDValue Not = ShiftOp.getOperand(0);
2764 if (!Not.hasOneUse() || !isBitwiseNot(Not))
2765 return SDValue();
2766
2767 // The shift must be moving the sign bit to the least-significant-bit.
2768 EVT VT = ShiftOp.getValueType();
2769 SDValue ShAmt = ShiftOp.getOperand(1);
2770 ConstantSDNode *ShAmtC = isConstOrConstSplat(ShAmt);
2771 if (!ShAmtC || ShAmtC->getAPIntValue() != (VT.getScalarSizeInBits() - 1))
2772 return SDValue();
2773
2774 // Eliminate the 'not' by adjusting the shift and add/sub constant:
2775 // add (srl (not X), 31), C --> add (sra X, 31), (C + 1)
2776 // sub C, (srl (not X), 31) --> add (srl X, 31), (C - 1)
2777 if (SDValue NewC = DAG.FoldConstantArithmetic(
2778 IsAdd ? ISD::ADD : ISD::SUB, DL, VT,
2779 {ConstantOp, DAG.getConstant(1, DL, VT)})) {
2780 SDValue NewShift = DAG.getNode(IsAdd ? ISD::SRA : ISD::SRL, DL, VT,
2781 Not.getOperand(0), ShAmt);
2782 return DAG.getNode(ISD::ADD, DL, VT, NewShift, NewC);
2783 }
2784
2785 return SDValue();
2786 }
2787
2788 static bool
areBitwiseNotOfEachother(SDValue Op0,SDValue Op1)2789 areBitwiseNotOfEachother(SDValue Op0, SDValue Op1) {
2790 return (isBitwiseNot(Op0) && Op0.getOperand(0) == Op1) ||
2791 (isBitwiseNot(Op1) && Op1.getOperand(0) == Op0);
2792 }
2793
/// Try to fold a node that behaves like an ADD (note that N isn't necessarily
/// an ISD::ADD here, it could for example be an ISD::OR if we know that there
/// are no common bits set in the operands).
///
/// The folds are tried strictly in the order written below and the first one
/// that applies wins. Returns the replacement value, SDValue(N, 0) when
/// SimplifyDemandedBits reports that it changed something, or an empty
/// SDValue when no fold applies.
SDValue DAGCombiner::visitADDLike(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  SDLoc DL(N);

  // fold (add x, undef) -> undef
  if (N0.isUndef())
    return N0;
  if (N1.isUndef())
    return N1;

  // fold (add c1, c2) -> c1+c2
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N0, N1}))
    return C;

  // canonicalize constant to RHS
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
    return DAG.getNode(ISD::ADD, DL, VT, N1, N0);

  // fold (add x, (not x)) -> all-ones, in either operand order
  if (areBitwiseNotOfEachother(N0, N1))
    return DAG.getConstant(APInt::getAllOnes(VT.getScalarSizeInBits()), DL, VT);

  // fold vector ops
  if (VT.isVector()) {
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

    // fold (add x, 0) -> x, vector edition
    if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
      return N0;
  }

  // fold (add x, 0) -> x
  if (isNullConstant(N1))
    return N0;

  if (N0.getOpcode() == ISD::SUB) {
    SDValue N00 = N0.getOperand(0);
    SDValue N01 = N0.getOperand(1);

    // fold ((A-c1)+c2) -> (A+(c2-c1))
    if (SDValue Sub = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N1, N01}))
      return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), Sub);

    // fold ((c1-A)+c2) -> (c1+c2)-A
    if (SDValue Add = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N1, N00}))
      return DAG.getNode(ISD::SUB, DL, VT, Add, N0.getOperand(1));
  }

  // add (sext i1 X), 1 -> zext (not i1 X)
  // We don't transform this pattern:
  //   add (zext i1 X), -1 -> sext (not i1 X)
  // because most (?) targets generate better code for the zext form.
  if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.hasOneUse() &&
      isOneOrOneSplat(N1)) {
    SDValue X = N0.getOperand(0);
    // When operations must be legal, both the XOR on X's type and the
    // ZERO_EXTEND to VT have to be legal.
    if ((!LegalOperations ||
         (TLI.isOperationLegal(ISD::XOR, X.getValueType()) &&
          TLI.isOperationLegal(ISD::ZERO_EXTEND, VT))) &&
        X.getScalarValueSizeInBits() == 1) {
      SDValue Not = DAG.getNOT(DL, X, X.getValueType());
      return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Not);
    }
  }

  // Fold (add (or x, c0), c1) -> (add x, (c0 + c1))
  // iff (or x, c0) is equivalent to (add x, c0).
  // Fold (add (xor x, c0), c1) -> (add x, (c0 + c1))
  // iff (xor x, c0) is equivalent to (add x, c0).
  if (DAG.isADDLike(N0)) {
    SDValue N01 = N0.getOperand(1);
    if (SDValue Add = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N1, N01}))
      return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), Add);
  }

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  // reassociate add
  if (!reassociationCanBreakAddressingModePattern(ISD::ADD, DL, N, N0, N1)) {
    if (SDValue RADD = reassociateOps(ISD::ADD, DL, N0, N1, N->getFlags()))
      return RADD;

    // Reassociate (add (or x, c), y) -> (add add(x, y), c)) if (or x, c) is
    // equivalent to (add x, c).
    // Reassociate (add (xor x, c), y) -> (add add(x, y), c)) if (xor x, c) is
    // equivalent to (add x, c).
    // Do this optimization only when adding c does not introduce instructions
    // for adding carries.
    // Note: the lambda's N0/N1 parameters shadow the outer operands so that it
    // can be invoked for both operand orders.
    auto ReassociateAddOr = [&](SDValue N0, SDValue N1) {
      if (DAG.isADDLike(N0) && N0.hasOneUse() &&
          isConstantOrConstantVector(N0.getOperand(1), /* NoOpaque */ true)) {
        // If N0's type does not split or is a sign mask, it does not introduce
        // add carry.
        auto TyActn = TLI.getTypeAction(*DAG.getContext(), N0.getValueType());
        bool NoAddCarry = TyActn == TargetLoweringBase::TypeLegal ||
                          TyActn == TargetLoweringBase::TypePromoteInteger ||
                          isMinSignedConstant(N0.getOperand(1));
        if (NoAddCarry)
          return DAG.getNode(
              ISD::ADD, DL, VT,
              DAG.getNode(ISD::ADD, DL, VT, N1, N0.getOperand(0)),
              N0.getOperand(1));
      }
      return SDValue();
    };
    if (SDValue Add = ReassociateAddOr(N0, N1))
      return Add;
    if (SDValue Add = ReassociateAddOr(N1, N0))
      return Add;

    // Fold add(vecreduce(x), vecreduce(y)) -> vecreduce(add(x, y))
    if (SDValue SD =
            reassociateReduction(ISD::VECREDUCE_ADD, ISD::ADD, DL, VT, N0, N1))
      return SD;
  }

  // Scratch values bound by the sd_match patterns below.
  SDValue A, B, C, D;

  // fold ((0-A) + B) -> B-A
  if (sd_match(N0, m_Neg(m_Value(A))))
    return DAG.getNode(ISD::SUB, DL, VT, N1, A);

  // fold (A + (0-B)) -> A-B
  if (sd_match(N1, m_Neg(m_Value(B))))
    return DAG.getNode(ISD::SUB, DL, VT, N0, B);

  // fold (A+(B-A)) -> B
  if (sd_match(N1, m_Sub(m_Value(B), m_Specific(N0))))
    return B;

  // fold ((B-A)+A) -> B
  if (sd_match(N0, m_Sub(m_Value(B), m_Specific(N1))))
    return B;

  // fold ((A-B)+(C-A)) -> (C-B)
  if (sd_match(N0, m_Sub(m_Value(A), m_Value(B))) &&
      sd_match(N1, m_Sub(m_Value(C), m_Specific(A))))
    return DAG.getNode(ISD::SUB, DL, VT, C, B);

  // fold ((A-B)+(B-C)) -> (A-C)
  if (sd_match(N0, m_Sub(m_Value(A), m_Value(B))) &&
      sd_match(N1, m_Sub(m_Specific(B), m_Value(C))))
    return DAG.getNode(ISD::SUB, DL, VT, A, C);

  // fold (A+(B-(A+C))) to (B-C)
  // fold (A+(B-(C+A))) to (B-C)
  if (sd_match(N1, m_Sub(m_Value(B), m_Add(m_Specific(N0), m_Value(C)))))
    return DAG.getNode(ISD::SUB, DL, VT, B, C);

  // fold (A+((B-A)+or-C)) to (B+or-C)
  // The opcode of N1 (ADD or SUB) is reused for the replacement node.
  if (sd_match(N1,
               m_AnyOf(m_Add(m_Sub(m_Value(B), m_Specific(N0)), m_Value(C)),
                       m_Sub(m_Sub(m_Value(B), m_Specific(N0)), m_Value(C)))))
    return DAG.getNode(N1.getOpcode(), DL, VT, B, C);

  // fold (A-B)+(C-D) to (A+C)-(B+D) when A or C is constant
  if (sd_match(N0, m_OneUse(m_Sub(m_Value(A), m_Value(B)))) &&
      sd_match(N1, m_OneUse(m_Sub(m_Value(C), m_Value(D)))) &&
      (isConstantOrConstantVector(A) || isConstantOrConstantVector(C)))
    return DAG.getNode(ISD::SUB, DL, VT,
                       DAG.getNode(ISD::ADD, SDLoc(N0), VT, A, C),
                       DAG.getNode(ISD::ADD, SDLoc(N1), VT, B, D));

  // fold (add (umax X, C), -C) --> (usubsat X, C)
  if (N0.getOpcode() == ISD::UMAX && hasOperation(ISD::USUBSAT, VT)) {
    // Accept an element pair when both constants are absent, or when the umax
    // constant equals the negation of the added constant.
    auto MatchUSUBSAT = [](ConstantSDNode *Max, ConstantSDNode *Op) {
      return (!Max && !Op) ||
             (Max && Op && Max->getAPIntValue() == (-Op->getAPIntValue()));
    };
    if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchUSUBSAT,
                                  /*AllowUndefs*/ true))
      return DAG.getNode(ISD::USUBSAT, DL, VT, N0.getOperand(0),
                         N0.getOperand(1));
  }

  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  if (isOneOrOneSplat(N1)) {
    // fold (add (xor a, -1), 1) -> (sub 0, a)
    if (isBitwiseNot(N0))
      return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
                         N0.getOperand(0));

    // fold (add (add (xor a, -1), b), 1) -> (sub b, a)
    if (N0.getOpcode() == ISD::ADD) {
      // Xor stays empty unless one of the inner add's operands is a 'not'; A
      // is then the other operand of that inner add.
      SDValue A, Xor;

      if (isBitwiseNot(N0.getOperand(0))) {
        A = N0.getOperand(1);
        Xor = N0.getOperand(0);
      } else if (isBitwiseNot(N0.getOperand(1))) {
        A = N0.getOperand(0);
        Xor = N0.getOperand(1);
      }

      if (Xor)
        return DAG.getNode(ISD::SUB, DL, VT, A, Xor.getOperand(0));
    }

    // Look for:
    //   add (add x, y), 1
    // And if the target does not like this form then turn into:
    //   sub y, (xor x, -1)
    if (!TLI.preferIncOfAddToSubOfNot(VT) && N0.getOpcode() == ISD::ADD &&
        N0.hasOneUse() &&
        // Limit this to after legalization if the add has wrap flags
        (Level >= AfterLegalizeDAG || (!N->getFlags().hasNoUnsignedWrap() &&
                                       !N->getFlags().hasNoSignedWrap()))) {
      SDValue Not = DAG.getNOT(DL, N0.getOperand(0), VT);
      return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(1), Not);
    }
  }

  // (x - y) + -1  ->  add (xor y, -1), x
  if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
      isAllOnesOrAllOnesSplat(N1, /*AllowUndefs=*/true)) {
    SDValue Not = DAG.getNOT(DL, N0.getOperand(1), VT);
    return DAG.getNode(ISD::ADD, DL, VT, Not, N0.getOperand(0));
  }

  // Fold add(mul(add(A, CA), CM), CB) -> add(mul(A, CM), CM*CA+CB).
  // This can help if the inner add has multiple uses.
  // Only attempted for a scalar ConstantSDNode RHS and when the new immediate
  // is a legal add immediate for the target.
  APInt CM, CA;
  if (ConstantSDNode *CB = dyn_cast<ConstantSDNode>(N1)) {
    // NOTE(review): the 64-bit limit is presumably there because the combined
    // immediate is passed through getSExtValue() below -- confirm.
    if (VT.getScalarSizeInBits() <= 64) {
      if (sd_match(N0, m_OneUse(m_Mul(m_Add(m_Value(A), m_ConstInt(CA)),
                                      m_ConstInt(CM)))) &&
          TLI.isLegalAddImmediate(
              (CA * CM + CB->getAPIntValue()).getSExtValue())) {
        SDNodeFlags Flags;
        // If all the inputs are nuw, the outputs can be nuw. If all the input
        // are _also_ nsw the outputs can be too.
        if (N->getFlags().hasNoUnsignedWrap() &&
            N0->getFlags().hasNoUnsignedWrap() &&
            N0.getOperand(0)->getFlags().hasNoUnsignedWrap()) {
          Flags |= SDNodeFlags::NoUnsignedWrap;
          if (N->getFlags().hasNoSignedWrap() &&
              N0->getFlags().hasNoSignedWrap() &&
              N0.getOperand(0)->getFlags().hasNoSignedWrap())
            Flags |= SDNodeFlags::NoSignedWrap;
        }
        SDValue Mul = DAG.getNode(ISD::MUL, SDLoc(N1), VT, A,
                                  DAG.getConstant(CM, DL, VT), Flags);
        return DAG.getNode(
            ISD::ADD, DL, VT, Mul,
            DAG.getConstant(CA * CM + CB->getAPIntValue(), DL, VT), Flags);
      }
      // Also look in case there is an intermediate add.
      if (sd_match(N0, m_OneUse(m_Add(
                           m_OneUse(m_Mul(m_Add(m_Value(A), m_ConstInt(CA)),
                                          m_ConstInt(CM))),
                           m_Value(B)))) &&
          TLI.isLegalAddImmediate(
              (CA * CM + CB->getAPIntValue()).getSExtValue())) {
        SDNodeFlags Flags;
        // If all the inputs are nuw, the outputs can be nuw. If all the input
        // are _also_ nsw the outputs can be too.
        // OMul is whichever operand of the intermediate add is not B, i.e. the
        // matched multiply.
        SDValue OMul =
            N0.getOperand(0) == B ? N0.getOperand(1) : N0.getOperand(0);
        if (N->getFlags().hasNoUnsignedWrap() &&
            N0->getFlags().hasNoUnsignedWrap() &&
            OMul->getFlags().hasNoUnsignedWrap() &&
            OMul.getOperand(0)->getFlags().hasNoUnsignedWrap()) {
          Flags |= SDNodeFlags::NoUnsignedWrap;
          if (N->getFlags().hasNoSignedWrap() &&
              N0->getFlags().hasNoSignedWrap() &&
              OMul->getFlags().hasNoSignedWrap() &&
              OMul.getOperand(0)->getFlags().hasNoSignedWrap())
            Flags |= SDNodeFlags::NoSignedWrap;
        }
        SDValue Mul = DAG.getNode(ISD::MUL, SDLoc(N1), VT, A,
                                  DAG.getConstant(CM, DL, VT), Flags);
        SDValue Add = DAG.getNode(ISD::ADD, SDLoc(N1), VT, Mul, B, Flags);
        return DAG.getNode(
            ISD::ADD, DL, VT, Add,
            DAG.getConstant(CA * CM + CB->getAPIntValue(), DL, VT), Flags);
      }
    }
  }

  // Finally try the operand-order-sensitive helper in both operand orders.
  if (SDValue Combined = visitADDLikeCommutative(N0, N1, N))
    return Combined;

  if (SDValue Combined = visitADDLikeCommutative(N1, N0, N))
    return Combined;

  return SDValue();
}
3089
3090 // Attempt to form avgfloor(A, B) from (A & B) + ((A ^ B) >> 1)
foldAddToAvg(SDNode * N,const SDLoc & DL)3091 SDValue DAGCombiner::foldAddToAvg(SDNode *N, const SDLoc &DL) {
3092 SDValue N0 = N->getOperand(0);
3093 EVT VT = N0.getValueType();
3094 SDValue A, B;
3095
3096 if ((!LegalOperations || hasOperation(ISD::AVGFLOORU, VT)) &&
3097 sd_match(N, m_Add(m_And(m_Value(A), m_Value(B)),
3098 m_Srl(m_Xor(m_Deferred(A), m_Deferred(B)), m_One())))) {
3099 return DAG.getNode(ISD::AVGFLOORU, DL, VT, A, B);
3100 }
3101 if ((!LegalOperations || hasOperation(ISD::AVGFLOORS, VT)) &&
3102 sd_match(N, m_Add(m_And(m_Value(A), m_Value(B)),
3103 m_Sra(m_Xor(m_Deferred(A), m_Deferred(B)), m_One())))) {
3104 return DAG.getNode(ISD::AVGFLOORS, DL, VT, A, B);
3105 }
3106
3107 return SDValue();
3108 }
3109
visitADD(SDNode * N)3110 SDValue DAGCombiner::visitADD(SDNode *N) {
3111 SDValue N0 = N->getOperand(0);
3112 SDValue N1 = N->getOperand(1);
3113 EVT VT = N0.getValueType();
3114 SDLoc DL(N);
3115
3116 if (SDValue Combined = visitADDLike(N))
3117 return Combined;
3118
3119 if (SDValue V = foldAddSubBoolOfMaskedVal(N, DL, DAG))
3120 return V;
3121
3122 if (SDValue V = foldAddSubOfSignBit(N, DL, DAG))
3123 return V;
3124
3125 if (SDValue V = MatchRotate(N0, N1, SDLoc(N), /*FromAdd=*/true))
3126 return V;
3127
3128 // Try to match AVGFLOOR fixedwidth pattern
3129 if (SDValue V = foldAddToAvg(N, DL))
3130 return V;
3131
3132 // fold (a+b) -> (a|b) iff a and b share no bits.
3133 if ((!LegalOperations || TLI.isOperationLegal(ISD::OR, VT)) &&
3134 DAG.haveNoCommonBitsSet(N0, N1))
3135 return DAG.getNode(ISD::OR, DL, VT, N0, N1, SDNodeFlags::Disjoint);
3136
3137 // Fold (add (vscale * C0), (vscale * C1)) to (vscale * (C0 + C1)).
3138 if (N0.getOpcode() == ISD::VSCALE && N1.getOpcode() == ISD::VSCALE) {
3139 const APInt &C0 = N0->getConstantOperandAPInt(0);
3140 const APInt &C1 = N1->getConstantOperandAPInt(0);
3141 return DAG.getVScale(DL, VT, C0 + C1);
3142 }
3143
3144 // fold a+vscale(c1)+vscale(c2) -> a+vscale(c1+c2)
3145 if (N0.getOpcode() == ISD::ADD &&
3146 N0.getOperand(1).getOpcode() == ISD::VSCALE &&
3147 N1.getOpcode() == ISD::VSCALE) {
3148 const APInt &VS0 = N0.getOperand(1)->getConstantOperandAPInt(0);
3149 const APInt &VS1 = N1->getConstantOperandAPInt(0);
3150 SDValue VS = DAG.getVScale(DL, VT, VS0 + VS1);
3151 return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), VS);
3152 }
3153
3154 // Fold (add step_vector(c1), step_vector(c2) to step_vector(c1+c2))
3155 if (N0.getOpcode() == ISD::STEP_VECTOR &&
3156 N1.getOpcode() == ISD::STEP_VECTOR) {
3157 const APInt &C0 = N0->getConstantOperandAPInt(0);
3158 const APInt &C1 = N1->getConstantOperandAPInt(0);
3159 APInt NewStep = C0 + C1;
3160 return DAG.getStepVector(DL, VT, NewStep);
3161 }
3162
3163 // Fold a + step_vector(c1) + step_vector(c2) to a + step_vector(c1+c2)
3164 if (N0.getOpcode() == ISD::ADD &&
3165 N0.getOperand(1).getOpcode() == ISD::STEP_VECTOR &&
3166 N1.getOpcode() == ISD::STEP_VECTOR) {
3167 const APInt &SV0 = N0.getOperand(1)->getConstantOperandAPInt(0);
3168 const APInt &SV1 = N1->getConstantOperandAPInt(0);
3169 APInt NewStep = SV0 + SV1;
3170 SDValue SV = DAG.getStepVector(DL, VT, NewStep);
3171 return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), SV);
3172 }
3173
3174 return SDValue();
3175 }
3176
visitADDSAT(SDNode * N)3177 SDValue DAGCombiner::visitADDSAT(SDNode *N) {
3178 unsigned Opcode = N->getOpcode();
3179 SDValue N0 = N->getOperand(0);
3180 SDValue N1 = N->getOperand(1);
3181 EVT VT = N0.getValueType();
3182 bool IsSigned = Opcode == ISD::SADDSAT;
3183 SDLoc DL(N);
3184
3185 // fold (add_sat x, undef) -> -1
3186 if (N0.isUndef() || N1.isUndef())
3187 return DAG.getAllOnesConstant(DL, VT);
3188
3189 // fold (add_sat c1, c2) -> c3
3190 if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
3191 return C;
3192
3193 // canonicalize constant to RHS
3194 if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
3195 !DAG.isConstantIntBuildVectorOrConstantInt(N1))
3196 return DAG.getNode(Opcode, DL, VT, N1, N0);
3197
3198 // fold vector ops
3199 if (VT.isVector()) {
3200 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
3201 return FoldedVOp;
3202
3203 // fold (add_sat x, 0) -> x, vector edition
3204 if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
3205 return N0;
3206 }
3207
3208 // fold (add_sat x, 0) -> x
3209 if (isNullConstant(N1))
3210 return N0;
3211
3212 // If it cannot overflow, transform into an add.
3213 if (DAG.willNotOverflowAdd(IsSigned, N0, N1))
3214 return DAG.getNode(ISD::ADD, DL, VT, N0, N1);
3215
3216 return SDValue();
3217 }
3218
getAsCarry(const TargetLowering & TLI,SDValue V,bool ForceCarryReconstruction=false)3219 static SDValue getAsCarry(const TargetLowering &TLI, SDValue V,
3220 bool ForceCarryReconstruction = false) {
3221 bool Masked = false;
3222
3223 // First, peel away TRUNCATE/ZERO_EXTEND/AND nodes due to legalization.
3224 while (true) {
3225 if (V.getOpcode() == ISD::TRUNCATE || V.getOpcode() == ISD::ZERO_EXTEND) {
3226 V = V.getOperand(0);
3227 continue;
3228 }
3229
3230 if (V.getOpcode() == ISD::AND && isOneConstant(V.getOperand(1))) {
3231 if (ForceCarryReconstruction)
3232 return V;
3233
3234 Masked = true;
3235 V = V.getOperand(0);
3236 continue;
3237 }
3238
3239 if (ForceCarryReconstruction && V.getValueType() == MVT::i1)
3240 return V;
3241
3242 break;
3243 }
3244
3245 // If this is not a carry, return.
3246 if (V.getResNo() != 1)
3247 return SDValue();
3248
3249 if (V.getOpcode() != ISD::UADDO_CARRY && V.getOpcode() != ISD::USUBO_CARRY &&
3250 V.getOpcode() != ISD::UADDO && V.getOpcode() != ISD::USUBO)
3251 return SDValue();
3252
3253 EVT VT = V->getValueType(0);
3254 if (!TLI.isOperationLegalOrCustom(V.getOpcode(), VT))
3255 return SDValue();
3256
3257 // If the result is masked, then no matter what kind of bool it is we can
3258 // return. If it isn't, then we need to make sure the bool type is either 0 or
3259 // 1 and not other values.
3260 if (Masked ||
3261 TLI.getBooleanContents(V.getValueType()) ==
3262 TargetLoweringBase::ZeroOrOneBooleanContent)
3263 return V;
3264
3265 return SDValue();
3266 }
3267
3268 /// Given the operands of an add/sub operation, see if the 2nd operand is a
3269 /// masked 0/1 whose source operand is actually known to be 0/-1. If so, invert
3270 /// the opcode and bypass the mask operation.
foldAddSubMasked1(bool IsAdd,SDValue N0,SDValue N1,SelectionDAG & DAG,const SDLoc & DL)3271 static SDValue foldAddSubMasked1(bool IsAdd, SDValue N0, SDValue N1,
3272 SelectionDAG &DAG, const SDLoc &DL) {
3273 if (N1.getOpcode() == ISD::ZERO_EXTEND)
3274 N1 = N1.getOperand(0);
3275
3276 if (N1.getOpcode() != ISD::AND || !isOneOrOneSplat(N1->getOperand(1)))
3277 return SDValue();
3278
3279 EVT VT = N0.getValueType();
3280 SDValue N10 = N1.getOperand(0);
3281 if (N10.getValueType() != VT && N10.getOpcode() == ISD::TRUNCATE)
3282 N10 = N10.getOperand(0);
3283
3284 if (N10.getValueType() != VT)
3285 return SDValue();
3286
3287 if (DAG.ComputeNumSignBits(N10) != VT.getScalarSizeInBits())
3288 return SDValue();
3289
3290 // add N0, (and (AssertSext X, i1), 1) --> sub N0, X
3291 // sub N0, (and (AssertSext X, i1), 1) --> add N0, X
3292 return DAG.getNode(IsAdd ? ISD::SUB : ISD::ADD, DL, VT, N0, N10);
3293 }
3294
3295 /// Helper for doing combines based on N0 and N1 being added to each other.
visitADDLikeCommutative(SDValue N0,SDValue N1,SDNode * LocReference)3296 SDValue DAGCombiner::visitADDLikeCommutative(SDValue N0, SDValue N1,
3297 SDNode *LocReference) {
3298 EVT VT = N0.getValueType();
3299 SDLoc DL(LocReference);
3300
3301 // fold (add x, shl(0 - y, n)) -> sub(x, shl(y, n))
3302 SDValue Y, N;
3303 if (sd_match(N1, m_Shl(m_Neg(m_Value(Y)), m_Value(N))))
3304 return DAG.getNode(ISD::SUB, DL, VT, N0,
3305 DAG.getNode(ISD::SHL, DL, VT, Y, N));
3306
3307 if (SDValue V = foldAddSubMasked1(true, N0, N1, DAG, DL))
3308 return V;
3309
3310 // Look for:
3311 // add (add x, 1), y
3312 // And if the target does not like this form then turn into:
3313 // sub y, (xor x, -1)
3314 if (!TLI.preferIncOfAddToSubOfNot(VT) && N0.getOpcode() == ISD::ADD &&
3315 N0.hasOneUse() && isOneOrOneSplat(N0.getOperand(1)) &&
3316 // Limit this to after legalization if the add has wrap flags
3317 (Level >= AfterLegalizeDAG || (!N0->getFlags().hasNoUnsignedWrap() &&
3318 !N0->getFlags().hasNoSignedWrap()))) {
3319 SDValue Not = DAG.getNOT(DL, N0.getOperand(0), VT);
3320 return DAG.getNode(ISD::SUB, DL, VT, N1, Not);
3321 }
3322
3323 if (N0.getOpcode() == ISD::SUB && N0.hasOneUse()) {
3324 // Hoist one-use subtraction by non-opaque constant:
3325 // (x - C) + y -> (x + y) - C
3326 // This is necessary because SUB(X,C) -> ADD(X,-C) doesn't work for vectors.
3327 if (isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true)) {
3328 SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), N1);
3329 return DAG.getNode(ISD::SUB, DL, VT, Add, N0.getOperand(1));
3330 }
3331 // Hoist one-use subtraction from non-opaque constant:
3332 // (C - x) + y -> (y - x) + C
3333 if (isConstantOrConstantVector(N0.getOperand(0), /*NoOpaques=*/true)) {
3334 SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N1, N0.getOperand(1));
3335 return DAG.getNode(ISD::ADD, DL, VT, Sub, N0.getOperand(0));
3336 }
3337 }
3338
3339 // add (mul x, C), x -> mul x, C+1
3340 if (N0.getOpcode() == ISD::MUL && N0.getOperand(0) == N1 &&
3341 isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true) &&
3342 N0.hasOneUse()) {
3343 SDValue NewC = DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(1),
3344 DAG.getConstant(1, DL, VT));
3345 return DAG.getNode(ISD::MUL, DL, VT, N0.getOperand(0), NewC);
3346 }
3347
3348 // If the target's bool is represented as 0/1, prefer to make this 'sub 0/1'
3349 // rather than 'add 0/-1' (the zext should get folded).
3350 // add (sext i1 Y), X --> sub X, (zext i1 Y)
3351 if (N0.getOpcode() == ISD::SIGN_EXTEND &&
3352 N0.getOperand(0).getScalarValueSizeInBits() == 1 &&
3353 TLI.getBooleanContents(VT) == TargetLowering::ZeroOrOneBooleanContent) {
3354 SDValue ZExt = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0));
3355 return DAG.getNode(ISD::SUB, DL, VT, N1, ZExt);
3356 }
3357
3358 // add X, (sextinreg Y i1) -> sub X, (and Y 1)
3359 if (N1.getOpcode() == ISD::SIGN_EXTEND_INREG) {
3360 VTSDNode *TN = cast<VTSDNode>(N1.getOperand(1));
3361 if (TN->getVT() == MVT::i1) {
3362 SDValue ZExt = DAG.getNode(ISD::AND, DL, VT, N1.getOperand(0),
3363 DAG.getConstant(1, DL, VT));
3364 return DAG.getNode(ISD::SUB, DL, VT, N0, ZExt);
3365 }
3366 }
3367
3368 // (add X, (uaddo_carry Y, 0, Carry)) -> (uaddo_carry X, Y, Carry)
3369 if (N1.getOpcode() == ISD::UADDO_CARRY && isNullConstant(N1.getOperand(1)) &&
3370 N1.getResNo() == 0)
3371 return DAG.getNode(ISD::UADDO_CARRY, DL, N1->getVTList(),
3372 N0, N1.getOperand(0), N1.getOperand(2));
3373
3374 // (add X, Carry) -> (uaddo_carry X, 0, Carry)
3375 if (TLI.isOperationLegalOrCustom(ISD::UADDO_CARRY, VT))
3376 if (SDValue Carry = getAsCarry(TLI, N1))
3377 return DAG.getNode(ISD::UADDO_CARRY, DL,
3378 DAG.getVTList(VT, Carry.getValueType()), N0,
3379 DAG.getConstant(0, DL, VT), Carry);
3380
3381 return SDValue();
3382 }
3383
visitADDC(SDNode * N)3384 SDValue DAGCombiner::visitADDC(SDNode *N) {
3385 SDValue N0 = N->getOperand(0);
3386 SDValue N1 = N->getOperand(1);
3387 EVT VT = N0.getValueType();
3388 SDLoc DL(N);
3389
3390 // If the flag result is dead, turn this into an ADD.
3391 if (!N->hasAnyUseOfValue(1))
3392 return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
3393 DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
3394
3395 // canonicalize constant to RHS.
3396 ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
3397 ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
3398 if (N0C && !N1C)
3399 return DAG.getNode(ISD::ADDC, DL, N->getVTList(), N1, N0);
3400
3401 // fold (addc x, 0) -> x + no carry out
3402 if (isNullConstant(N1))
3403 return CombineTo(N, N0, DAG.getNode(ISD::CARRY_FALSE,
3404 DL, MVT::Glue));
3405
3406 // If it cannot overflow, transform into an add.
3407 if (DAG.computeOverflowForUnsignedAdd(N0, N1) == SelectionDAG::OFK_Never)
3408 return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
3409 DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
3410
3411 return SDValue();
3412 }
3413
3414 /**
3415 * Flips a boolean if it is cheaper to compute. If the Force parameters is set,
3416 * then the flip also occurs if computing the inverse is the same cost.
3417 * This function returns an empty SDValue in case it cannot flip the boolean
3418 * without increasing the cost of the computation. If you want to flip a boolean
3419 * no matter what, use DAG.getLogicalNOT.
3420 */
extractBooleanFlip(SDValue V,SelectionDAG & DAG,const TargetLowering & TLI,bool Force)3421 static SDValue extractBooleanFlip(SDValue V, SelectionDAG &DAG,
3422 const TargetLowering &TLI,
3423 bool Force) {
3424 if (Force && isa<ConstantSDNode>(V))
3425 return DAG.getLogicalNOT(SDLoc(V), V, V.getValueType());
3426
3427 if (V.getOpcode() != ISD::XOR)
3428 return SDValue();
3429
3430 if (DAG.isBoolConstant(V.getOperand(1)) == true)
3431 return V.getOperand(0);
3432 if (Force && isConstOrConstSplat(V.getOperand(1), false))
3433 return DAG.getLogicalNOT(SDLoc(V), V, V.getValueType());
3434 return SDValue();
3435 }
3436
visitADDO(SDNode * N)3437 SDValue DAGCombiner::visitADDO(SDNode *N) {
3438 SDValue N0 = N->getOperand(0);
3439 SDValue N1 = N->getOperand(1);
3440 EVT VT = N0.getValueType();
3441 bool IsSigned = (ISD::SADDO == N->getOpcode());
3442
3443 EVT CarryVT = N->getValueType(1);
3444 SDLoc DL(N);
3445
3446 // If the flag result is dead, turn this into an ADD.
3447 if (!N->hasAnyUseOfValue(1))
3448 return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
3449 DAG.getUNDEF(CarryVT));
3450
3451 // canonicalize constant to RHS.
3452 if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
3453 !DAG.isConstantIntBuildVectorOrConstantInt(N1))
3454 return DAG.getNode(N->getOpcode(), DL, N->getVTList(), N1, N0);
3455
3456 // fold (addo x, 0) -> x + no carry out
3457 if (isNullOrNullSplat(N1))
3458 return CombineTo(N, N0, DAG.getConstant(0, DL, CarryVT));
3459
3460 // If it cannot overflow, transform into an add.
3461 if (DAG.willNotOverflowAdd(IsSigned, N0, N1))
3462 return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
3463 DAG.getConstant(0, DL, CarryVT));
3464
3465 if (IsSigned) {
3466 // fold (saddo (xor a, -1), 1) -> (ssub 0, a).
3467 if (isBitwiseNot(N0) && isOneOrOneSplat(N1))
3468 return DAG.getNode(ISD::SSUBO, DL, N->getVTList(),
3469 DAG.getConstant(0, DL, VT), N0.getOperand(0));
3470 } else {
3471 // fold (uaddo (xor a, -1), 1) -> (usub 0, a) and flip carry.
3472 if (isBitwiseNot(N0) && isOneOrOneSplat(N1)) {
3473 SDValue Sub = DAG.getNode(ISD::USUBO, DL, N->getVTList(),
3474 DAG.getConstant(0, DL, VT), N0.getOperand(0));
3475 return CombineTo(
3476 N, Sub, DAG.getLogicalNOT(DL, Sub.getValue(1), Sub->getValueType(1)));
3477 }
3478
3479 if (SDValue Combined = visitUADDOLike(N0, N1, N))
3480 return Combined;
3481
3482 if (SDValue Combined = visitUADDOLike(N1, N0, N))
3483 return Combined;
3484 }
3485
3486 return SDValue();
3487 }
3488
visitUADDOLike(SDValue N0,SDValue N1,SDNode * N)3489 SDValue DAGCombiner::visitUADDOLike(SDValue N0, SDValue N1, SDNode *N) {
3490 EVT VT = N0.getValueType();
3491 if (VT.isVector())
3492 return SDValue();
3493
3494 // (uaddo X, (uaddo_carry Y, 0, Carry)) -> (uaddo_carry X, Y, Carry)
3495 // If Y + 1 cannot overflow.
3496 if (N1.getOpcode() == ISD::UADDO_CARRY && isNullConstant(N1.getOperand(1))) {
3497 SDValue Y = N1.getOperand(0);
3498 SDValue One = DAG.getConstant(1, SDLoc(N), Y.getValueType());
3499 if (DAG.computeOverflowForUnsignedAdd(Y, One) == SelectionDAG::OFK_Never)
3500 return DAG.getNode(ISD::UADDO_CARRY, SDLoc(N), N->getVTList(), N0, Y,
3501 N1.getOperand(2));
3502 }
3503
3504 // (uaddo X, Carry) -> (uaddo_carry X, 0, Carry)
3505 if (TLI.isOperationLegalOrCustom(ISD::UADDO_CARRY, VT))
3506 if (SDValue Carry = getAsCarry(TLI, N1))
3507 return DAG.getNode(ISD::UADDO_CARRY, SDLoc(N), N->getVTList(), N0,
3508 DAG.getConstant(0, SDLoc(N), VT), Carry);
3509
3510 return SDValue();
3511 }
3512
visitADDE(SDNode * N)3513 SDValue DAGCombiner::visitADDE(SDNode *N) {
3514 SDValue N0 = N->getOperand(0);
3515 SDValue N1 = N->getOperand(1);
3516 SDValue CarryIn = N->getOperand(2);
3517
3518 // canonicalize constant to RHS
3519 ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
3520 ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
3521 if (N0C && !N1C)
3522 return DAG.getNode(ISD::ADDE, SDLoc(N), N->getVTList(),
3523 N1, N0, CarryIn);
3524
3525 // fold (adde x, y, false) -> (addc x, y)
3526 if (CarryIn.getOpcode() == ISD::CARRY_FALSE)
3527 return DAG.getNode(ISD::ADDC, SDLoc(N), N->getVTList(), N0, N1);
3528
3529 return SDValue();
3530 }
3531
visitUADDO_CARRY(SDNode * N)3532 SDValue DAGCombiner::visitUADDO_CARRY(SDNode *N) {
3533 SDValue N0 = N->getOperand(0);
3534 SDValue N1 = N->getOperand(1);
3535 SDValue CarryIn = N->getOperand(2);
3536 SDLoc DL(N);
3537
3538 // canonicalize constant to RHS
3539 ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
3540 ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
3541 if (N0C && !N1C)
3542 return DAG.getNode(ISD::UADDO_CARRY, DL, N->getVTList(), N1, N0, CarryIn);
3543
3544 // fold (uaddo_carry x, y, false) -> (uaddo x, y)
3545 if (isNullConstant(CarryIn)) {
3546 if (!LegalOperations ||
3547 TLI.isOperationLegalOrCustom(ISD::UADDO, N->getValueType(0)))
3548 return DAG.getNode(ISD::UADDO, DL, N->getVTList(), N0, N1);
3549 }
3550
3551 // fold (uaddo_carry 0, 0, X) -> (and (ext/trunc X), 1) and no carry.
3552 if (isNullConstant(N0) && isNullConstant(N1)) {
3553 EVT VT = N0.getValueType();
3554 EVT CarryVT = CarryIn.getValueType();
3555 SDValue CarryExt = DAG.getBoolExtOrTrunc(CarryIn, DL, VT, CarryVT);
3556 AddToWorklist(CarryExt.getNode());
3557 return CombineTo(N, DAG.getNode(ISD::AND, DL, VT, CarryExt,
3558 DAG.getConstant(1, DL, VT)),
3559 DAG.getConstant(0, DL, CarryVT));
3560 }
3561
3562 if (SDValue Combined = visitUADDO_CARRYLike(N0, N1, CarryIn, N))
3563 return Combined;
3564
3565 if (SDValue Combined = visitUADDO_CARRYLike(N1, N0, CarryIn, N))
3566 return Combined;
3567
3568 // We want to avoid useless duplication.
3569 // TODO: This is done automatically for binary operations. As UADDO_CARRY is
3570 // not a binary operation, this is not really possible to leverage this
3571 // existing mechanism for it. However, if more operations require the same
3572 // deduplication logic, then it may be worth generalize.
3573 SDValue Ops[] = {N1, N0, CarryIn};
3574 SDNode *CSENode =
3575 DAG.getNodeIfExists(ISD::UADDO_CARRY, N->getVTList(), Ops, N->getFlags());
3576 if (CSENode)
3577 return SDValue(CSENode, 0);
3578
3579 return SDValue();
3580 }
3581
3582 /**
3583 * If we are facing some sort of diamond carry propagation pattern try to
3584 * break it up to generate something like:
3585 * (uaddo_carry X, 0, (uaddo_carry A, B, Z):Carry)
3586 *
 * The end result is usually an increase in operations required, but because the
3588 * carry is now linearized, other transforms can kick in and optimize the DAG.
3589 *
3590 * Patterns typically look something like
3591 * (uaddo A, B)
3592 * / \
3593 * Carry Sum
3594 * | \
3595 * | (uaddo_carry *, 0, Z)
3596 * | /
3597 * \ Carry
3598 * | /
3599 * (uaddo_carry X, *, *)
3600 *
 * But numerous variations exist. Our goal is to identify A, B, X and Z and
3602 * produce a combine with a single path for carry propagation.
3603 */
combineUADDO_CARRYDiamond(DAGCombiner & Combiner,SelectionDAG & DAG,SDValue X,SDValue Carry0,SDValue Carry1,SDNode * N)3604 static SDValue combineUADDO_CARRYDiamond(DAGCombiner &Combiner,
3605 SelectionDAG &DAG, SDValue X,
3606 SDValue Carry0, SDValue Carry1,
3607 SDNode *N) {
3608 if (Carry1.getResNo() != 1 || Carry0.getResNo() != 1)
3609 return SDValue();
3610 if (Carry1.getOpcode() != ISD::UADDO)
3611 return SDValue();
3612
3613 SDValue Z;
3614
3615 /**
3616 * First look for a suitable Z. It will present itself in the form of
3617 * (uaddo_carry Y, 0, Z) or its equivalent (uaddo Y, 1) for Z=true
3618 */
3619 if (Carry0.getOpcode() == ISD::UADDO_CARRY &&
3620 isNullConstant(Carry0.getOperand(1))) {
3621 Z = Carry0.getOperand(2);
3622 } else if (Carry0.getOpcode() == ISD::UADDO &&
3623 isOneConstant(Carry0.getOperand(1))) {
3624 EVT VT = Carry0->getValueType(1);
3625 Z = DAG.getConstant(1, SDLoc(Carry0.getOperand(1)), VT);
3626 } else {
3627 // We couldn't find a suitable Z.
3628 return SDValue();
3629 }
3630
3631
3632 auto cancelDiamond = [&](SDValue A,SDValue B) {
3633 SDLoc DL(N);
3634 SDValue NewY =
3635 DAG.getNode(ISD::UADDO_CARRY, DL, Carry0->getVTList(), A, B, Z);
3636 Combiner.AddToWorklist(NewY.getNode());
3637 return DAG.getNode(ISD::UADDO_CARRY, DL, N->getVTList(), X,
3638 DAG.getConstant(0, DL, X.getValueType()),
3639 NewY.getValue(1));
3640 };
3641
3642 /**
3643 * (uaddo A, B)
3644 * |
3645 * Sum
3646 * |
3647 * (uaddo_carry *, 0, Z)
3648 */
3649 if (Carry0.getOperand(0) == Carry1.getValue(0)) {
3650 return cancelDiamond(Carry1.getOperand(0), Carry1.getOperand(1));
3651 }
3652
3653 /**
3654 * (uaddo_carry A, 0, Z)
3655 * |
3656 * Sum
3657 * |
3658 * (uaddo *, B)
3659 */
3660 if (Carry1.getOperand(0) == Carry0.getValue(0)) {
3661 return cancelDiamond(Carry0.getOperand(0), Carry1.getOperand(1));
3662 }
3663
3664 if (Carry1.getOperand(1) == Carry0.getValue(0)) {
3665 return cancelDiamond(Carry1.getOperand(0), Carry0.getOperand(0));
3666 }
3667
3668 return SDValue();
3669 }
3670
3671 // If we are facing some sort of diamond carry/borrow in/out pattern try to
3672 // match patterns like:
3673 //
3674 // (uaddo A, B) CarryIn
3675 // | \ |
3676 // | \ |
3677 // PartialSum PartialCarryOutX /
3678 // | | /
3679 // | ____|____________/
3680 // | / |
3681 // (uaddo *, *) \________
3682 // | \ \
3683 // | \ |
3684 // | PartialCarryOutY |
3685 // | \ |
3686 // | \ /
3687 // AddCarrySum | ______/
3688 // | /
3689 // CarryOut = (or *, *)
3690 //
3691 // And generate UADDO_CARRY (or USUBO_CARRY) with two result values:
3692 //
3693 // {AddCarrySum, CarryOut} = (uaddo_carry A, B, CarryIn)
3694 //
3695 // Our goal is to identify A, B, and CarryIn and produce UADDO_CARRY/USUBO_CARRY
3696 // with a single path for carry/borrow out propagation.
combineCarryDiamond(SelectionDAG & DAG,const TargetLowering & TLI,SDValue N0,SDValue N1,SDNode * N)3697 static SDValue combineCarryDiamond(SelectionDAG &DAG, const TargetLowering &TLI,
3698 SDValue N0, SDValue N1, SDNode *N) {
3699 SDValue Carry0 = getAsCarry(TLI, N0);
3700 if (!Carry0)
3701 return SDValue();
3702 SDValue Carry1 = getAsCarry(TLI, N1);
3703 if (!Carry1)
3704 return SDValue();
3705
3706 unsigned Opcode = Carry0.getOpcode();
3707 if (Opcode != Carry1.getOpcode())
3708 return SDValue();
3709 if (Opcode != ISD::UADDO && Opcode != ISD::USUBO)
3710 return SDValue();
3711 // Guarantee identical type of CarryOut
3712 EVT CarryOutType = N->getValueType(0);
3713 if (CarryOutType != Carry0.getValue(1).getValueType() ||
3714 CarryOutType != Carry1.getValue(1).getValueType())
3715 return SDValue();
3716
3717 // Canonicalize the add/sub of A and B (the top node in the above ASCII art)
3718 // as Carry0 and the add/sub of the carry in as Carry1 (the middle node).
3719 if (Carry1.getNode()->isOperandOf(Carry0.getNode()))
3720 std::swap(Carry0, Carry1);
3721
3722 // Check if nodes are connected in expected way.
3723 if (Carry1.getOperand(0) != Carry0.getValue(0) &&
3724 Carry1.getOperand(1) != Carry0.getValue(0))
3725 return SDValue();
3726
3727 // The carry in value must be on the righthand side for subtraction.
3728 unsigned CarryInOperandNum =
3729 Carry1.getOperand(0) == Carry0.getValue(0) ? 1 : 0;
3730 if (Opcode == ISD::USUBO && CarryInOperandNum != 1)
3731 return SDValue();
3732 SDValue CarryIn = Carry1.getOperand(CarryInOperandNum);
3733
3734 unsigned NewOp = Opcode == ISD::UADDO ? ISD::UADDO_CARRY : ISD::USUBO_CARRY;
3735 if (!TLI.isOperationLegalOrCustom(NewOp, Carry0.getValue(0).getValueType()))
3736 return SDValue();
3737
3738 // Verify that the carry/borrow in is plausibly a carry/borrow bit.
3739 CarryIn = getAsCarry(TLI, CarryIn, true);
3740 if (!CarryIn)
3741 return SDValue();
3742
3743 SDLoc DL(N);
3744 CarryIn = DAG.getBoolExtOrTrunc(CarryIn, DL, Carry1->getValueType(1),
3745 Carry1->getValueType(0));
3746 SDValue Merged =
3747 DAG.getNode(NewOp, DL, Carry1->getVTList(), Carry0.getOperand(0),
3748 Carry0.getOperand(1), CarryIn);
3749
3750 // Please note that because we have proven that the result of the UADDO/USUBO
3751 // of A and B feeds into the UADDO/USUBO that does the carry/borrow in, we can
3752 // therefore prove that if the first UADDO/USUBO overflows, the second
3753 // UADDO/USUBO cannot. For example consider 8-bit numbers where 0xFF is the
3754 // maximum value.
3755 //
3756 // 0xFF + 0xFF == 0xFE with carry but 0xFE + 1 does not carry
3757 // 0x00 - 0xFF == 1 with a carry/borrow but 1 - 1 == 0 (no carry/borrow)
3758 //
3759 // This is important because it means that OR and XOR can be used to merge
3760 // carry flags; and that AND can return a constant zero.
3761 //
3762 // TODO: match other operations that can merge flags (ADD, etc)
3763 DAG.ReplaceAllUsesOfValueWith(Carry1.getValue(0), Merged.getValue(0));
3764 if (N->getOpcode() == ISD::AND)
3765 return DAG.getConstant(0, DL, CarryOutType);
3766 return Merged.getValue(1);
3767 }
3768
visitUADDO_CARRYLike(SDValue N0,SDValue N1,SDValue CarryIn,SDNode * N)3769 SDValue DAGCombiner::visitUADDO_CARRYLike(SDValue N0, SDValue N1,
3770 SDValue CarryIn, SDNode *N) {
3771 // fold (uaddo_carry (xor a, -1), b, c) -> (usubo_carry b, a, !c) and flip
3772 // carry.
3773 if (isBitwiseNot(N0))
3774 if (SDValue NotC = extractBooleanFlip(CarryIn, DAG, TLI, true)) {
3775 SDLoc DL(N);
3776 SDValue Sub = DAG.getNode(ISD::USUBO_CARRY, DL, N->getVTList(), N1,
3777 N0.getOperand(0), NotC);
3778 return CombineTo(
3779 N, Sub, DAG.getLogicalNOT(DL, Sub.getValue(1), Sub->getValueType(1)));
3780 }
3781
3782 // Iff the flag result is dead:
3783 // (uaddo_carry (add|uaddo X, Y), 0, Carry) -> (uaddo_carry X, Y, Carry)
3784 // Don't do this if the Carry comes from the uaddo. It won't remove the uaddo
3785 // or the dependency between the instructions.
3786 if ((N0.getOpcode() == ISD::ADD ||
3787 (N0.getOpcode() == ISD::UADDO && N0.getResNo() == 0 &&
3788 N0.getValue(1) != CarryIn)) &&
3789 isNullConstant(N1) && !N->hasAnyUseOfValue(1))
3790 return DAG.getNode(ISD::UADDO_CARRY, SDLoc(N), N->getVTList(),
3791 N0.getOperand(0), N0.getOperand(1), CarryIn);
3792
3793 /**
3794 * When one of the uaddo_carry argument is itself a carry, we may be facing
3795 * a diamond carry propagation. In which case we try to transform the DAG
3796 * to ensure linear carry propagation if that is possible.
3797 */
3798 if (auto Y = getAsCarry(TLI, N1)) {
3799 // Because both are carries, Y and Z can be swapped.
3800 if (auto R = combineUADDO_CARRYDiamond(*this, DAG, N0, Y, CarryIn, N))
3801 return R;
3802 if (auto R = combineUADDO_CARRYDiamond(*this, DAG, N0, CarryIn, Y, N))
3803 return R;
3804 }
3805
3806 return SDValue();
3807 }
3808
visitSADDO_CARRYLike(SDValue N0,SDValue N1,SDValue CarryIn,SDNode * N)3809 SDValue DAGCombiner::visitSADDO_CARRYLike(SDValue N0, SDValue N1,
3810 SDValue CarryIn, SDNode *N) {
3811 // fold (saddo_carry (xor a, -1), b, c) -> (ssubo_carry b, a, !c)
3812 if (isBitwiseNot(N0)) {
3813 if (SDValue NotC = extractBooleanFlip(CarryIn, DAG, TLI, true))
3814 return DAG.getNode(ISD::SSUBO_CARRY, SDLoc(N), N->getVTList(), N1,
3815 N0.getOperand(0), NotC);
3816 }
3817
3818 return SDValue();
3819 }
3820
visitSADDO_CARRY(SDNode * N)3821 SDValue DAGCombiner::visitSADDO_CARRY(SDNode *N) {
3822 SDValue N0 = N->getOperand(0);
3823 SDValue N1 = N->getOperand(1);
3824 SDValue CarryIn = N->getOperand(2);
3825 SDLoc DL(N);
3826
3827 // canonicalize constant to RHS
3828 ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
3829 ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
3830 if (N0C && !N1C)
3831 return DAG.getNode(ISD::SADDO_CARRY, DL, N->getVTList(), N1, N0, CarryIn);
3832
3833 // fold (saddo_carry x, y, false) -> (saddo x, y)
3834 if (isNullConstant(CarryIn)) {
3835 if (!LegalOperations ||
3836 TLI.isOperationLegalOrCustom(ISD::SADDO, N->getValueType(0)))
3837 return DAG.getNode(ISD::SADDO, DL, N->getVTList(), N0, N1);
3838 }
3839
3840 if (SDValue Combined = visitSADDO_CARRYLike(N0, N1, CarryIn, N))
3841 return Combined;
3842
3843 if (SDValue Combined = visitSADDO_CARRYLike(N1, N0, CarryIn, N))
3844 return Combined;
3845
3846 return SDValue();
3847 }
3848
3849 // Attempt to create a USUBSAT(LHS, RHS) node with DstVT, performing a
3850 // clamp/truncation if necessary.
getTruncatedUSUBSAT(EVT DstVT,EVT SrcVT,SDValue LHS,SDValue RHS,SelectionDAG & DAG,const SDLoc & DL)3851 static SDValue getTruncatedUSUBSAT(EVT DstVT, EVT SrcVT, SDValue LHS,
3852 SDValue RHS, SelectionDAG &DAG,
3853 const SDLoc &DL) {
3854 assert(DstVT.getScalarSizeInBits() <= SrcVT.getScalarSizeInBits() &&
3855 "Illegal truncation");
3856
3857 if (DstVT == SrcVT)
3858 return DAG.getNode(ISD::USUBSAT, DL, DstVT, LHS, RHS);
3859
3860 // If the LHS is zero-extended then we can perform the USUBSAT as DstVT by
3861 // clamping RHS.
3862 APInt UpperBits = APInt::getBitsSetFrom(SrcVT.getScalarSizeInBits(),
3863 DstVT.getScalarSizeInBits());
3864 if (!DAG.MaskedValueIsZero(LHS, UpperBits))
3865 return SDValue();
3866
3867 SDValue SatLimit =
3868 DAG.getConstant(APInt::getLowBitsSet(SrcVT.getScalarSizeInBits(),
3869 DstVT.getScalarSizeInBits()),
3870 DL, SrcVT);
3871 RHS = DAG.getNode(ISD::UMIN, DL, SrcVT, RHS, SatLimit);
3872 RHS = DAG.getNode(ISD::TRUNCATE, DL, DstVT, RHS);
3873 LHS = DAG.getNode(ISD::TRUNCATE, DL, DstVT, LHS);
3874 return DAG.getNode(ISD::USUBSAT, DL, DstVT, LHS, RHS);
3875 }
3876
3877 // Try to find umax(a,b) - b or a - umin(a,b) patterns that may be converted to
3878 // usubsat(a,b), optionally as a truncated type.
foldSubToUSubSat(EVT DstVT,SDNode * N,const SDLoc & DL)3879 SDValue DAGCombiner::foldSubToUSubSat(EVT DstVT, SDNode *N, const SDLoc &DL) {
3880 if (N->getOpcode() != ISD::SUB ||
3881 !(!LegalOperations || hasOperation(ISD::USUBSAT, DstVT)))
3882 return SDValue();
3883
3884 EVT SubVT = N->getValueType(0);
3885 SDValue Op0 = N->getOperand(0);
3886 SDValue Op1 = N->getOperand(1);
3887
3888 // Try to find umax(a,b) - b or a - umin(a,b) patterns
3889 // they may be converted to usubsat(a,b).
3890 if (Op0.getOpcode() == ISD::UMAX && Op0.hasOneUse()) {
3891 SDValue MaxLHS = Op0.getOperand(0);
3892 SDValue MaxRHS = Op0.getOperand(1);
3893 if (MaxLHS == Op1)
3894 return getTruncatedUSUBSAT(DstVT, SubVT, MaxRHS, Op1, DAG, DL);
3895 if (MaxRHS == Op1)
3896 return getTruncatedUSUBSAT(DstVT, SubVT, MaxLHS, Op1, DAG, DL);
3897 }
3898
3899 if (Op1.getOpcode() == ISD::UMIN && Op1.hasOneUse()) {
3900 SDValue MinLHS = Op1.getOperand(0);
3901 SDValue MinRHS = Op1.getOperand(1);
3902 if (MinLHS == Op0)
3903 return getTruncatedUSUBSAT(DstVT, SubVT, Op0, MinRHS, DAG, DL);
3904 if (MinRHS == Op0)
3905 return getTruncatedUSUBSAT(DstVT, SubVT, Op0, MinLHS, DAG, DL);
3906 }
3907
3908 // sub(a,trunc(umin(zext(a),b))) -> usubsat(a,trunc(umin(b,SatLimit)))
3909 if (Op1.getOpcode() == ISD::TRUNCATE &&
3910 Op1.getOperand(0).getOpcode() == ISD::UMIN &&
3911 Op1.getOperand(0).hasOneUse()) {
3912 SDValue MinLHS = Op1.getOperand(0).getOperand(0);
3913 SDValue MinRHS = Op1.getOperand(0).getOperand(1);
3914 if (MinLHS.getOpcode() == ISD::ZERO_EXTEND && MinLHS.getOperand(0) == Op0)
3915 return getTruncatedUSUBSAT(DstVT, MinLHS.getValueType(), MinLHS, MinRHS,
3916 DAG, DL);
3917 if (MinRHS.getOpcode() == ISD::ZERO_EXTEND && MinRHS.getOperand(0) == Op0)
3918 return getTruncatedUSUBSAT(DstVT, MinLHS.getValueType(), MinRHS, MinLHS,
3919 DAG, DL);
3920 }
3921
3922 return SDValue();
3923 }
3924
3925 // Refinement of DAG/Type Legalisation (promotion) when CTLZ is used for
// counting leading ones. Broadly, it replaces the subtraction with a left
3927 // shift.
3928 //
3929 // * DAG Legalisation Pattern:
3930 //
3931 // (sub (ctlz (zeroextend (not Src)))
3932 // BitWidthDiff)
3933 //
3934 // if BitWidthDiff == BitWidth(Node) - BitWidth(Src)
3935 // -->
3936 //
3937 // (ctlz_zero_undef (not (shl (anyextend Src)
3938 // BitWidthDiff)))
3939 //
3940 // * Type Legalisation Pattern:
3941 //
3942 // (sub (ctlz (and (xor Src XorMask)
3943 // AndMask))
3944 // BitWidthDiff)
3945 //
3946 // if AndMask has only trailing ones
3947 // and MaskBitWidth(AndMask) == BitWidth(Node) - BitWidthDiff
3948 // and XorMask has more trailing ones than AndMask
3949 // -->
3950 //
3951 // (ctlz_zero_undef (not (shl Src BitWidthDiff)))
3952 template <class MatchContextClass>
foldSubCtlzNot(SDNode * N,SelectionDAG & DAG)3953 static SDValue foldSubCtlzNot(SDNode *N, SelectionDAG &DAG) {
3954 const SDLoc DL(N);
3955 SDValue N0 = N->getOperand(0);
3956 EVT VT = N0.getValueType();
3957 unsigned BitWidth = VT.getScalarSizeInBits();
3958
3959 MatchContextClass Matcher(DAG, DAG.getTargetLoweringInfo(), N);
3960
3961 APInt AndMask;
3962 APInt XorMask;
3963 APInt BitWidthDiff;
3964
3965 SDValue CtlzOp;
3966 SDValue Src;
3967
3968 if (!sd_context_match(
3969 N, Matcher, m_Sub(m_Ctlz(m_Value(CtlzOp)), m_ConstInt(BitWidthDiff))))
3970 return SDValue();
3971
3972 if (sd_context_match(CtlzOp, Matcher, m_ZExt(m_Not(m_Value(Src))))) {
3973 // DAG Legalisation Pattern:
3974 // (sub (ctlz (zero_extend (not Op)) BitWidthDiff))
3975 if ((BitWidth - Src.getValueType().getScalarSizeInBits()) != BitWidthDiff)
3976 return SDValue();
3977
3978 Src = DAG.getNode(ISD::ANY_EXTEND, DL, VT, Src);
3979 } else if (sd_context_match(CtlzOp, Matcher,
3980 m_And(m_Xor(m_Value(Src), m_ConstInt(XorMask)),
3981 m_ConstInt(AndMask)))) {
3982 // Type Legalisation Pattern:
3983 // (sub (ctlz (and (xor Op XorMask) AndMask)) BitWidthDiff)
3984 unsigned AndMaskWidth = BitWidth - BitWidthDiff.getZExtValue();
3985 if (!(AndMask.isMask(AndMaskWidth) && XorMask.countr_one() >= AndMaskWidth))
3986 return SDValue();
3987 } else
3988 return SDValue();
3989
3990 SDValue ShiftConst = DAG.getShiftAmountConstant(BitWidthDiff, VT, DL);
3991 SDValue LShift = Matcher.getNode(ISD::SHL, DL, VT, Src, ShiftConst);
3992 SDValue Not =
3993 Matcher.getNode(ISD::XOR, DL, VT, LShift, DAG.getAllOnesConstant(DL, VT));
3994
3995 return Matcher.getNode(ISD::CTLZ_ZERO_UNDEF, DL, VT, Not);
3996 }
3997
3998 // Fold sub(x, mul(divrem(x,y)[0], y)) to divrem(x, y)[1]
foldRemainderIdiom(SDNode * N,SelectionDAG & DAG,const SDLoc & DL)3999 static SDValue foldRemainderIdiom(SDNode *N, SelectionDAG &DAG,
4000 const SDLoc &DL) {
4001 assert(N->getOpcode() == ISD::SUB && "Node must be a SUB");
4002 SDValue Sub0 = N->getOperand(0);
4003 SDValue Sub1 = N->getOperand(1);
4004
4005 auto CheckAndFoldMulCase = [&](SDValue DivRem, SDValue MaybeY) -> SDValue {
4006 if ((DivRem.getOpcode() == ISD::SDIVREM ||
4007 DivRem.getOpcode() == ISD::UDIVREM) &&
4008 DivRem.getResNo() == 0 && DivRem.getOperand(0) == Sub0 &&
4009 DivRem.getOperand(1) == MaybeY) {
4010 return SDValue(DivRem.getNode(), 1);
4011 }
4012 return SDValue();
4013 };
4014
4015 if (Sub1.getOpcode() == ISD::MUL) {
4016 // (sub x, (mul divrem(x,y)[0], y))
4017 SDValue Mul0 = Sub1.getOperand(0);
4018 SDValue Mul1 = Sub1.getOperand(1);
4019
4020 if (SDValue Res = CheckAndFoldMulCase(Mul0, Mul1))
4021 return Res;
4022
4023 if (SDValue Res = CheckAndFoldMulCase(Mul1, Mul0))
4024 return Res;
4025
4026 } else if (Sub1.getOpcode() == ISD::SHL) {
4027 // Handle (sub x, (shl divrem(x,y)[0], C)) where y = 1 << C
4028 SDValue Shl0 = Sub1.getOperand(0);
4029 SDValue Shl1 = Sub1.getOperand(1);
4030 // Check if Shl0 is divrem(x, Y)[0]
4031 if ((Shl0.getOpcode() == ISD::SDIVREM ||
4032 Shl0.getOpcode() == ISD::UDIVREM) &&
4033 Shl0.getResNo() == 0 && Shl0.getOperand(0) == Sub0) {
4034
4035 SDValue Divisor = Shl0.getOperand(1);
4036
4037 ConstantSDNode *DivC = isConstOrConstSplat(Divisor);
4038 ConstantSDNode *ShC = isConstOrConstSplat(Shl1);
4039 if (!DivC || !ShC)
4040 return SDValue();
4041
4042 if (DivC->getAPIntValue().isPowerOf2() &&
4043 DivC->getAPIntValue().logBase2() == ShC->getAPIntValue())
4044 return SDValue(Shl0.getNode(), 1);
4045 }
4046 }
4047 return SDValue();
4048 }
4049
4050 // Since it may not be valid to emit a fold to zero for vector initializers
4051 // check if we can before folding.
tryFoldToZero(const SDLoc & DL,const TargetLowering & TLI,EVT VT,SelectionDAG & DAG,bool LegalOperations)4052 static SDValue tryFoldToZero(const SDLoc &DL, const TargetLowering &TLI, EVT VT,
4053 SelectionDAG &DAG, bool LegalOperations) {
4054 if (!VT.isVector())
4055 return DAG.getConstant(0, DL, VT);
4056 if (!LegalOperations || TLI.isOperationLegal(ISD::BUILD_VECTOR, VT))
4057 return DAG.getConstant(0, DL, VT);
4058 return SDValue();
4059 }
4060
/// Try to simplify the subtraction node \p N (computing N0 - N1).
///
/// The folds below are attempted in order and the first one that applies
/// wins, so their relative order matters. Returns the replacement value, or an
/// empty SDValue if nothing applied.
SDValue DAGCombiner::visitSUB(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  unsigned BitWidth = VT.getScalarSizeInBits();
  SDLoc DL(N);

  // Strip a one-use FREEZE wrapper so that the operand comparison below sees
  // the underlying value.
  auto PeekThroughFreeze = [](SDValue N) {
    if (N->getOpcode() == ISD::FREEZE && N.hasOneUse())
      return N->getOperand(0);
    return N;
  };

  if (SDValue V = foldSubCtlzNot<EmptyMatchContext>(N, DAG))
    return V;

  // fold (sub x, x) -> 0
  // FIXME: Refactor this and xor and other similar operations together.
  if (PeekThroughFreeze(N0) == PeekThroughFreeze(N1))
    return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);

  // fold (sub c1, c2) -> c3
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N0, N1}))
    return C;

  // fold vector ops
  if (VT.isVector()) {
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

    // fold (sub x, 0) -> x, vector edition
    if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
      return N0;
  }

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  // fold (sub x, c) -> (add x, -c)
  if (ConstantSDNode *N1C = getAsNonOpaqueConstant(N1))
    return DAG.getNode(ISD::ADD, DL, VT, N0,
                       DAG.getConstant(-N1C->getAPIntValue(), DL, VT));

  // Everything in this block handles a negation: N0 is zero (or a zero
  // splat), so the node computes 0 - N1.
  if (isNullOrNullSplat(N0)) {
    // Right-shifting everything out but the sign bit followed by negation is
    // the same as flipping arithmetic/logical shift type without the negation:
    // -(X >>u 31) -> (X >>s 31)
    // -(X >>s 31) -> (X >>u 31)
    if (N1->getOpcode() == ISD::SRA || N1->getOpcode() == ISD::SRL) {
      ConstantSDNode *ShiftAmt = isConstOrConstSplat(N1.getOperand(1));
      if (ShiftAmt && ShiftAmt->getAPIntValue() == (BitWidth - 1)) {
        auto NewSh = N1->getOpcode() == ISD::SRA ? ISD::SRL : ISD::SRA;
        if (!LegalOperations || TLI.isOperationLegal(NewSh, VT))
          return DAG.getNode(NewSh, DL, VT, N1.getOperand(0), N1.getOperand(1));
      }
    }

    // 0 - X --> 0 if the sub is NUW.
    if (N->getFlags().hasNoUnsignedWrap())
      return N0;

    if (DAG.MaskedValueIsZero(N1, ~APInt::getSignMask(BitWidth))) {
      // N1 is either 0 or the minimum signed value. If the sub is NSW, then
      // N1 must be 0 because negating the minimum signed value is undefined.
      if (N->getFlags().hasNoSignedWrap())
        return N0;

      // 0 - X --> X if X is 0 or the minimum signed value.
      return N1;
    }

    // Convert 0 - abs(x).
    // Only attempted when ABS is neither legal nor custom for VT; the
    // expansion is requested in its negated form (third argument true).
    if (N1.getOpcode() == ISD::ABS && N1.hasOneUse() &&
        !TLI.isOperationLegalOrCustom(ISD::ABS, VT))
      if (SDValue Result = TLI.expandABS(N1.getNode(), DAG, true))
        return Result;

    // Similar to the previous rule, but this time targeting an expanded abs.
    // (sub 0, (max X, (sub 0, X))) --> (min X, (sub 0, X))
    // as well as
    // (sub 0, (min X, (sub 0, X))) --> (max X, (sub 0, X))
    // Note that these two are applicable to both signed and unsigned min/max.
    SDValue X;
    SDValue S0;
    // NegPat matches (sub 0, X) for the already-bound X and captures that
    // negation node itself in S0 so it can be reused in the result.
    auto NegPat = m_AllOf(m_Neg(m_Deferred(X)), m_Value(S0));
    if (sd_match(N1, m_OneUse(m_AnyOf(m_SMax(m_Value(X), NegPat),
                                      m_UMax(m_Value(X), NegPat),
                                      m_SMin(m_Value(X), NegPat),
                                      m_UMin(m_Value(X), NegPat))))) {
      unsigned NewOpc = ISD::getInverseMinMaxOpcode(N1->getOpcode());
      if (hasOperation(NewOpc, VT))
        return DAG.getNode(NewOpc, DL, VT, X, S0);
    }

    // Fold neg(splat(neg(x))) -> splat(x)
    if (VT.isVector()) {
      SDValue N1S = DAG.getSplatValue(N1, true);
      if (N1S && N1S.getOpcode() == ISD::SUB &&
          isNullConstant(N1S.getOperand(0)))
        return DAG.getSplat(VT, DL, N1S.getOperand(1));
    }

    // sub 0, (and x, 1)  -->  SIGN_EXTEND_INREG x, i1
    if (N1.getOpcode() == ISD::AND && N1.hasOneUse() &&
        isOneOrOneSplat(N1->getOperand(1))) {
      // The extension source type is i1, or a vector of i1 with the same
      // element count as VT.
      EVT ExtVT = EVT::getIntegerVT(*DAG.getContext(), 1);
      if (VT.isVector())
        ExtVT = EVT::getVectorVT(*DAG.getContext(), ExtVT,
                                 VT.getVectorElementCount());
      if (TLI.getOperationAction(ISD::SIGN_EXTEND_INREG, ExtVT) ==
          TargetLowering::Legal) {
        return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, N1->getOperand(0),
                           DAG.getValueType(ExtVT));
      }
    }
  }

  // Canonicalize (sub -1, x) -> ~x, i.e. (xor x, -1)
  if (isAllOnesOrAllOnesSplat(N0))
    return DAG.getNode(ISD::XOR, DL, VT, N1, N0);

  // fold (A - (0-B)) -> A+B
  if (N1.getOpcode() == ISD::SUB && isNullOrNullSplat(N1.getOperand(0)))
    return DAG.getNode(ISD::ADD, DL, VT, N0, N1.getOperand(1));

  // fold A-(A-B) -> B
  if (N1.getOpcode() == ISD::SUB && N0 == N1.getOperand(0))
    return N1.getOperand(1);

  // fold (A+B)-A -> B
  if (N0.getOpcode() == ISD::ADD && N0.getOperand(0) == N1)
    return N0.getOperand(1);

  // fold (A+B)-B -> A
  if (N0.getOpcode() == ISD::ADD && N0.getOperand(1) == N1)
    return N0.getOperand(0);

  // fold (A+C1)-C2 -> A+(C1-C2)
  if (N0.getOpcode() == ISD::ADD) {
    SDValue N01 = N0.getOperand(1);
    if (SDValue NewC = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N01, N1}))
      return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), NewC);
  }

  // fold C2-(A+C1) -> (C2-C1)-A
  if (N1.getOpcode() == ISD::ADD) {
    SDValue N11 = N1.getOperand(1);
    if (SDValue NewC = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N0, N11}))
      return DAG.getNode(ISD::SUB, DL, VT, NewC, N1.getOperand(0));
  }

  // fold (A-C1)-C2 -> A-(C1+C2)
  if (N0.getOpcode() == ISD::SUB) {
    SDValue N01 = N0.getOperand(1);
    if (SDValue NewC = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N01, N1}))
      return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), NewC);
  }

  // fold (c1-A)-c2 -> (c1-c2)-A
  if (N0.getOpcode() == ISD::SUB) {
    SDValue N00 = N0.getOperand(0);
    if (SDValue NewC = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N00, N1}))
      return DAG.getNode(ISD::SUB, DL, VT, NewC, N0.getOperand(1));
  }

  // Capture slots shared by the pattern matches in the rest of the function.
  SDValue A, B, C;

  // fold ((A+(B+C))-B) -> A+C
  if (sd_match(N0, m_Add(m_Value(A), m_Add(m_Specific(N1), m_Value(C)))))
    return DAG.getNode(ISD::ADD, DL, VT, A, C);

  // fold ((A+(B-C))-B) -> A-C
  if (sd_match(N0, m_Add(m_Value(A), m_Sub(m_Specific(N1), m_Value(C)))))
    return DAG.getNode(ISD::SUB, DL, VT, A, C);

  // fold ((A-(B-C))-C) -> A-B
  if (sd_match(N0, m_Sub(m_Value(A), m_Sub(m_Value(B), m_Specific(N1)))))
    return DAG.getNode(ISD::SUB, DL, VT, A, B);

  // fold (A-(B-C)) -> A+(C-B)
  if (sd_match(N1, m_OneUse(m_Sub(m_Value(B), m_Value(C)))))
    return DAG.getNode(ISD::ADD, DL, VT, N0,
                       DAG.getNode(ISD::SUB, DL, VT, C, B));

  // A - (A & B)  ->  A & (~B)
  // Done when the AND has no other user, or when B is a non-opaque constant
  // (or constant vector).
  if (sd_match(N1, m_And(m_Specific(N0), m_Value(B))) &&
      (N1.hasOneUse() || isConstantOrConstantVector(B, /*NoOpaques=*/true)))
    return DAG.getNode(ISD::AND, DL, VT, N0, DAG.getNOT(DL, B, VT));

  // fold (A - (-B * C)) -> (A + (B * C))
  if (sd_match(N1, m_OneUse(m_Mul(m_Neg(m_Value(B)), m_Value(C)))))
    return DAG.getNode(ISD::ADD, DL, VT, N0,
                       DAG.getNode(ISD::MUL, DL, VT, B, C));

  // If either operand of a sub is undef, the result is undef
  if (N0.isUndef())
    return N0;
  if (N1.isUndef())
    return N1;

  if (SDValue V = foldAddSubBoolOfMaskedVal(N, DL, DAG))
    return V;

  if (SDValue V = foldAddSubOfSignBit(N, DL, DAG))
    return V;

  // Try to match AVGCEIL fixedwidth pattern
  if (SDValue V = foldSubToAvg(N, DL))
    return V;

  if (SDValue V = foldAddSubMasked1(false, N0, N1, DAG, DL))
    return V;

  if (SDValue V = foldSubToUSubSat(VT, N, DL))
    return V;

  if (SDValue V = foldRemainderIdiom(N, DAG, DL))
    return V;

  // (A - B) - 1  ->  add (xor B, -1), A
  if (sd_match(N, m_Sub(m_OneUse(m_Sub(m_Value(A), m_Value(B))),
                        m_One(/*AllowUndefs=*/true))))
    return DAG.getNode(ISD::ADD, DL, VT, A, DAG.getNOT(DL, B, VT));

  // Look for:
  //   sub y, (xor x, -1)
  // And if the target does not like this form then turn into:
  //   add (add x, y), 1
  if (TLI.preferIncOfAddToSubOfNot(VT) && N1.hasOneUse() && isBitwiseNot(N1)) {
    SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0, N1.getOperand(0));
    return DAG.getNode(ISD::ADD, DL, VT, Add, DAG.getConstant(1, DL, VT));
  }

  // Hoist one-use addition by non-opaque constant:
  //   (x + C) - y  ->  (x - y) + C
  if (!reassociationCanBreakAddressingModePattern(ISD::SUB, DL, N, N0, N1) &&
      N0.getOpcode() == ISD::ADD && N0.hasOneUse() &&
      isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true)) {
    SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), N1);
    return DAG.getNode(ISD::ADD, DL, VT, Sub, N0.getOperand(1));
  }
  //   y - (x + C)  ->  (y - x) - C
  if (N1.getOpcode() == ISD::ADD && N1.hasOneUse() &&
      isConstantOrConstantVector(N1.getOperand(1), /*NoOpaques=*/true)) {
    SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, N1.getOperand(0));
    return DAG.getNode(ISD::SUB, DL, VT, Sub, N1.getOperand(1));
  }
  // (x - C) - y  ->  (x - y) - C
  // This is necessary because SUB(X,C) -> ADD(X,-C) doesn't work for vectors.
  if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
      isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true)) {
    SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), N1);
    return DAG.getNode(ISD::SUB, DL, VT, Sub, N0.getOperand(1));
  }
  // (C - x) - y  ->  C - (x + y)
  if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
      isConstantOrConstantVector(N0.getOperand(0), /*NoOpaques=*/true)) {
    SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(1), N1);
    return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), Add);
  }

  // If the target's bool is represented as 0/-1, prefer to make this 'add 0/-1'
  // rather than 'sub 0/1' (the sext should get folded).
  // sub X, (zext i1 Y) --> add X, (sext i1 Y)
  if (N1.getOpcode() == ISD::ZERO_EXTEND &&
      N1.getOperand(0).getScalarValueSizeInBits() == 1 &&
      TLI.getBooleanContents(VT) ==
          TargetLowering::ZeroOrNegativeOneBooleanContent) {
    SDValue SExt = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, N1.getOperand(0));
    return DAG.getNode(ISD::ADD, DL, VT, N0, SExt);
  }

  // fold B = sra (A, size(A)-1); sub (xor (A, B), B) -> (abs A)
  if ((!LegalOperations || hasOperation(ISD::ABS, VT)) &&
      sd_match(N1, m_Sra(m_Value(A), m_SpecificInt(BitWidth - 1))) &&
      sd_match(N0, m_Xor(m_Specific(A), m_Specific(N1))))
    return DAG.getNode(ISD::ABS, DL, VT, A);

  // If the relocation model supports it, consider symbol offsets.
  if (GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(N0))
    if (!LegalOperations && TLI.isOffsetFoldingLegal(GA)) {
      // fold (sub Sym+c1, Sym+c2) -> c1-c2
      // The difference is computed in uint64_t (note the cast on the first
      // offset) — presumably so that it wraps instead of overflowing a signed
      // type.
      if (GlobalAddressSDNode *GB = dyn_cast<GlobalAddressSDNode>(N1))
        if (GA->getGlobal() == GB->getGlobal())
          return DAG.getConstant((uint64_t)GA->getOffset() - GB->getOffset(),
                                 DL, VT);
    }

  // sub X, (sextinreg Y i1) -> add X, (and Y 1)
  if (N1.getOpcode() == ISD::SIGN_EXTEND_INREG) {
    VTSDNode *TN = cast<VTSDNode>(N1.getOperand(1));
    if (TN->getVT() == MVT::i1) {
      SDValue ZExt = DAG.getNode(ISD::AND, DL, VT, N1.getOperand(0),
                                 DAG.getConstant(1, DL, VT));
      return DAG.getNode(ISD::ADD, DL, VT, N0, ZExt);
    }
  }

  // canonicalize (sub X, (vscale * C)) to (add X, (vscale * -C))
  if (N1.getOpcode() == ISD::VSCALE && N1.hasOneUse()) {
    const APInt &IntVal = N1.getConstantOperandAPInt(0);
    return DAG.getNode(ISD::ADD, DL, VT, N0, DAG.getVScale(DL, VT, -IntVal));
  }

  // canonicalize (sub X, step_vector(C)) to (add X, step_vector(-C))
  if (N1.getOpcode() == ISD::STEP_VECTOR && N1.hasOneUse()) {
    APInt NewStep = -N1.getConstantOperandAPInt(0);
    return DAG.getNode(ISD::ADD, DL, VT, N0,
                       DAG.getStepVector(DL, VT, NewStep));
  }

  // Prefer an add for more folding potential and possibly better codegen:
  // sub N0, (lshr N10, width-1) --> add N0, (ashr N10, width-1)
  if (!LegalOperations && N1.getOpcode() == ISD::SRL && N1.hasOneUse()) {
    SDValue ShAmt = N1.getOperand(1);
    ConstantSDNode *ShAmtC = isConstOrConstSplat(ShAmt);
    if (ShAmtC && ShAmtC->getAPIntValue() == (BitWidth - 1)) {
      SDValue SRA = DAG.getNode(ISD::SRA, DL, VT, N1.getOperand(0), ShAmt);
      return DAG.getNode(ISD::ADD, DL, VT, N0, SRA);
    }
  }

  // As with the previous fold, prefer add for more folding potential.
  // Subtracting SMIN/0 is the same as adding SMIN/0:
  // N0 - (X << BW-1) --> N0 + (X << BW-1)
  if (N1.getOpcode() == ISD::SHL) {
    ConstantSDNode *ShlC = isConstOrConstSplat(N1.getOperand(1));
    if (ShlC && ShlC->getAPIntValue() == (BitWidth - 1))
      return DAG.getNode(ISD::ADD, DL, VT, N1, N0);
  }

  // (sub (usubo_carry X, 0, Carry), Y) -> (usubo_carry X, Y, Carry)
  // Only the value result (#0) of a single-use usubo_carry qualifies.
  if (N0.getOpcode() == ISD::USUBO_CARRY && isNullConstant(N0.getOperand(1)) &&
      N0.getResNo() == 0 && N0.hasOneUse())
    return DAG.getNode(ISD::USUBO_CARRY, DL, N0->getVTList(),
                       N0.getOperand(0), N1, N0.getOperand(2));

  if (TLI.isOperationLegalOrCustom(ISD::UADDO_CARRY, VT)) {
    // (sub Carry, X)  ->  (uaddo_carry (sub 0, X), 0, Carry)
    if (SDValue Carry = getAsCarry(TLI, N0)) {
      SDValue X = N1;
      SDValue Zero = DAG.getConstant(0, DL, VT);
      SDValue NegX = DAG.getNode(ISD::SUB, DL, VT, Zero, X);
      return DAG.getNode(ISD::UADDO_CARRY, DL,
                         DAG.getVTList(VT, Carry.getValueType()), NegX, Zero,
                         Carry);
    }
  }

  // If there's no chance of borrowing from adjacent bits, then sub is xor:
  // sub C0, X --> xor X, C0
  if (ConstantSDNode *C0 = isConstOrConstSplat(N0)) {
    if (!C0->isOpaque()) {
      const APInt &C0Val = C0->getAPIntValue();
      // MaybeOnes has a 1 in every bit of N1 that is not known to be zero.
      const APInt &MaybeOnes = ~DAG.computeKnownBits(N1).Zero;
      // The fold is taken when subtracting and xor-ing that mask from C0
      // give the same value.
      if ((C0Val - MaybeOnes) == (C0Val ^ MaybeOnes))
        return DAG.getNode(ISD::XOR, DL, VT, N1, N0);
    }
  }

  // smax(a,b) - smin(a,b) --> abds(a,b)
  if ((!LegalOperations || hasOperation(ISD::ABDS, VT)) &&
      sd_match(N0, m_SMaxLike(m_Value(A), m_Value(B))) &&
      sd_match(N1, m_SMinLike(m_Specific(A), m_Specific(B))))
    return DAG.getNode(ISD::ABDS, DL, VT, A, B);

  // smin(a,b) - smax(a,b) --> neg(abds(a,b))
  // Unlike the non-negated form above, this requires hasOperation(ABDS)
  // regardless of LegalOperations.
  if (hasOperation(ISD::ABDS, VT) &&
      sd_match(N0, m_SMinLike(m_Value(A), m_Value(B))) &&
      sd_match(N1, m_SMaxLike(m_Specific(A), m_Specific(B))))
    return DAG.getNegative(DAG.getNode(ISD::ABDS, DL, VT, A, B), DL, VT);

  // umax(a,b) - umin(a,b) --> abdu(a,b)
  if ((!LegalOperations || hasOperation(ISD::ABDU, VT)) &&
      sd_match(N0, m_UMaxLike(m_Value(A), m_Value(B))) &&
      sd_match(N1, m_UMinLike(m_Specific(A), m_Specific(B))))
    return DAG.getNode(ISD::ABDU, DL, VT, A, B);

  // umin(a,b) - umax(a,b) --> neg(abdu(a,b))
  if (hasOperation(ISD::ABDU, VT) &&
      sd_match(N0, m_UMinLike(m_Value(A), m_Value(B))) &&
      sd_match(N1, m_UMaxLike(m_Specific(A), m_Specific(B))))
    return DAG.getNegative(DAG.getNode(ISD::ABDU, DL, VT, A, B), DL, VT);

  // (sub x, (select (ult x, y), 0, y)) -> (umin x, (sub x, y))
  // (sub x, (select (uge x, y), y, 0)) -> (umin x, (sub x, y))
  if (hasUMin(VT)) {
    SDValue Y;
    if (sd_match(N1, m_OneUse(m_Select(m_SetCC(m_Specific(N0), m_Value(Y),
                                               m_SpecificCondCode(ISD::SETULT)),
                                       m_Zero(), m_Deferred(Y)))) ||
        sd_match(N1, m_OneUse(m_Select(m_SetCC(m_Specific(N0), m_Value(Y),
                                               m_SpecificCondCode(ISD::SETUGE)),
                                       m_Deferred(Y), m_Zero()))))
      return DAG.getNode(ISD::UMIN, DL, VT, N0,
                         DAG.getNode(ISD::SUB, DL, VT, N0, Y));
  }

  // No fold applied.
  return SDValue();
}
4461
visitSUBSAT(SDNode * N)4462 SDValue DAGCombiner::visitSUBSAT(SDNode *N) {
4463 unsigned Opcode = N->getOpcode();
4464 SDValue N0 = N->getOperand(0);
4465 SDValue N1 = N->getOperand(1);
4466 EVT VT = N0.getValueType();
4467 bool IsSigned = Opcode == ISD::SSUBSAT;
4468 SDLoc DL(N);
4469
4470 // fold (sub_sat x, undef) -> 0
4471 if (N0.isUndef() || N1.isUndef())
4472 return DAG.getConstant(0, DL, VT);
4473
4474 // fold (sub_sat x, x) -> 0
4475 if (N0 == N1)
4476 return DAG.getConstant(0, DL, VT);
4477
4478 // fold (sub_sat c1, c2) -> c3
4479 if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
4480 return C;
4481
4482 // fold vector ops
4483 if (VT.isVector()) {
4484 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4485 return FoldedVOp;
4486
4487 // fold (sub_sat x, 0) -> x, vector edition
4488 if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
4489 return N0;
4490 }
4491
4492 // fold (sub_sat x, 0) -> x
4493 if (isNullConstant(N1))
4494 return N0;
4495
4496 // If it cannot overflow, transform into an sub.
4497 if (DAG.willNotOverflowSub(IsSigned, N0, N1))
4498 return DAG.getNode(ISD::SUB, DL, VT, N0, N1);
4499
4500 return SDValue();
4501 }
4502
visitSUBC(SDNode * N)4503 SDValue DAGCombiner::visitSUBC(SDNode *N) {
4504 SDValue N0 = N->getOperand(0);
4505 SDValue N1 = N->getOperand(1);
4506 EVT VT = N0.getValueType();
4507 SDLoc DL(N);
4508
4509 // If the flag result is dead, turn this into an SUB.
4510 if (!N->hasAnyUseOfValue(1))
4511 return CombineTo(N, DAG.getNode(ISD::SUB, DL, VT, N0, N1),
4512 DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
4513
4514 // fold (subc x, x) -> 0 + no borrow
4515 if (N0 == N1)
4516 return CombineTo(N, DAG.getConstant(0, DL, VT),
4517 DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
4518
4519 // fold (subc x, 0) -> x + no borrow
4520 if (isNullConstant(N1))
4521 return CombineTo(N, N0, DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
4522
4523 // Canonicalize (sub -1, x) -> ~x, i.e. (xor x, -1) + no borrow
4524 if (isAllOnesConstant(N0))
4525 return CombineTo(N, DAG.getNode(ISD::XOR, DL, VT, N1, N0),
4526 DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
4527
4528 return SDValue();
4529 }
4530
visitSUBO(SDNode * N)4531 SDValue DAGCombiner::visitSUBO(SDNode *N) {
4532 SDValue N0 = N->getOperand(0);
4533 SDValue N1 = N->getOperand(1);
4534 EVT VT = N0.getValueType();
4535 bool IsSigned = (ISD::SSUBO == N->getOpcode());
4536
4537 EVT CarryVT = N->getValueType(1);
4538 SDLoc DL(N);
4539
4540 // If the flag result is dead, turn this into an SUB.
4541 if (!N->hasAnyUseOfValue(1))
4542 return CombineTo(N, DAG.getNode(ISD::SUB, DL, VT, N0, N1),
4543 DAG.getUNDEF(CarryVT));
4544
4545 // fold (subo x, x) -> 0 + no borrow
4546 if (N0 == N1)
4547 return CombineTo(N, DAG.getConstant(0, DL, VT),
4548 DAG.getConstant(0, DL, CarryVT));
4549
4550 // fold (subox, c) -> (addo x, -c)
4551 if (ConstantSDNode *N1C = getAsNonOpaqueConstant(N1))
4552 if (IsSigned && !N1C->isMinSignedValue())
4553 return DAG.getNode(ISD::SADDO, DL, N->getVTList(), N0,
4554 DAG.getConstant(-N1C->getAPIntValue(), DL, VT));
4555
4556 // fold (subo x, 0) -> x + no borrow
4557 if (isNullOrNullSplat(N1))
4558 return CombineTo(N, N0, DAG.getConstant(0, DL, CarryVT));
4559
4560 // If it cannot overflow, transform into an sub.
4561 if (DAG.willNotOverflowSub(IsSigned, N0, N1))
4562 return CombineTo(N, DAG.getNode(ISD::SUB, DL, VT, N0, N1),
4563 DAG.getConstant(0, DL, CarryVT));
4564
4565 // Canonicalize (usubo -1, x) -> ~x, i.e. (xor x, -1) + no borrow
4566 if (!IsSigned && isAllOnesOrAllOnesSplat(N0))
4567 return CombineTo(N, DAG.getNode(ISD::XOR, DL, VT, N1, N0),
4568 DAG.getConstant(0, DL, CarryVT));
4569
4570 return SDValue();
4571 }
4572
visitSUBE(SDNode * N)4573 SDValue DAGCombiner::visitSUBE(SDNode *N) {
4574 SDValue N0 = N->getOperand(0);
4575 SDValue N1 = N->getOperand(1);
4576 SDValue CarryIn = N->getOperand(2);
4577
4578 // fold (sube x, y, false) -> (subc x, y)
4579 if (CarryIn.getOpcode() == ISD::CARRY_FALSE)
4580 return DAG.getNode(ISD::SUBC, SDLoc(N), N->getVTList(), N0, N1);
4581
4582 return SDValue();
4583 }
4584
visitUSUBO_CARRY(SDNode * N)4585 SDValue DAGCombiner::visitUSUBO_CARRY(SDNode *N) {
4586 SDValue N0 = N->getOperand(0);
4587 SDValue N1 = N->getOperand(1);
4588 SDValue CarryIn = N->getOperand(2);
4589
4590 // fold (usubo_carry x, y, false) -> (usubo x, y)
4591 if (isNullConstant(CarryIn)) {
4592 if (!LegalOperations ||
4593 TLI.isOperationLegalOrCustom(ISD::USUBO, N->getValueType(0)))
4594 return DAG.getNode(ISD::USUBO, SDLoc(N), N->getVTList(), N0, N1);
4595 }
4596
4597 return SDValue();
4598 }
4599
visitSSUBO_CARRY(SDNode * N)4600 SDValue DAGCombiner::visitSSUBO_CARRY(SDNode *N) {
4601 SDValue N0 = N->getOperand(0);
4602 SDValue N1 = N->getOperand(1);
4603 SDValue CarryIn = N->getOperand(2);
4604
4605 // fold (ssubo_carry x, y, false) -> (ssubo x, y)
4606 if (isNullConstant(CarryIn)) {
4607 if (!LegalOperations ||
4608 TLI.isOperationLegalOrCustom(ISD::SSUBO, N->getValueType(0)))
4609 return DAG.getNode(ISD::SSUBO, SDLoc(N), N->getVTList(), N0, N1);
4610 }
4611
4612 return SDValue();
4613 }
4614
4615 // Notice that "mulfix" can be any of SMULFIX, SMULFIXSAT, UMULFIX and
4616 // UMULFIXSAT here.
visitMULFIX(SDNode * N)4617 SDValue DAGCombiner::visitMULFIX(SDNode *N) {
4618 SDValue N0 = N->getOperand(0);
4619 SDValue N1 = N->getOperand(1);
4620 SDValue Scale = N->getOperand(2);
4621 EVT VT = N0.getValueType();
4622
4623 // fold (mulfix x, undef, scale) -> 0
4624 if (N0.isUndef() || N1.isUndef())
4625 return DAG.getConstant(0, SDLoc(N), VT);
4626
4627 // Canonicalize constant to RHS (vector doesn't have to splat)
4628 if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
4629 !DAG.isConstantIntBuildVectorOrConstantInt(N1))
4630 return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N1, N0, Scale);
4631
4632 // fold (mulfix x, 0, scale) -> 0
4633 if (isNullConstant(N1))
4634 return DAG.getConstant(0, SDLoc(N), VT);
4635
4636 return SDValue();
4637 }
4638
/// Try to simplify the multiplication node \p N (computing N0 * N1).
///
/// \tparam MatchContextClass selects how nodes are matched and built through
/// `Matcher`; when it is VPMatchContext (UseVP), several folds below are
/// skipped.
///
/// The folds are attempted in order and the first that applies wins. Returns
/// the replacement value, SDValue(N, 0) if N was simplified in place by
/// SimplifyDemandedBits, or an empty SDValue if nothing applied.
template <class MatchContextClass> SDValue DAGCombiner::visitMUL(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  unsigned BitWidth = VT.getScalarSizeInBits();
  SDLoc DL(N);
  bool UseVP = std::is_same_v<MatchContextClass, VPMatchContext>;
  MatchContextClass Matcher(DAG, TLI, N);

  // fold (mul x, undef) -> 0
  if (N0.isUndef() || N1.isUndef())
    return DAG.getConstant(0, DL, VT);

  // fold (mul c1, c2) -> c1*c2
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::MUL, DL, VT, {N0, N1}))
    return C;

  // canonicalize constant to RHS (vector doesn't have to splat)
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
    return Matcher.getNode(ISD::MUL, DL, VT, N1, N0);

  // Description of N1 when it is a constant: for vectors a constant splat,
  // for scalars a ConstantSDNode. ConstValue1 holds the value;
  // N1IsOpaqueConst is only ever set on the scalar path.
  bool N1IsConst = false;
  bool N1IsOpaqueConst = false;
  APInt ConstValue1;

  // fold vector ops
  if (VT.isVector()) {
    // TODO: Change this to use SimplifyVBinOp when it supports VP op.
    if (!UseVP)
      if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
        return FoldedVOp;

    N1IsConst = ISD::isConstantSplatVector(N1.getNode(), ConstValue1);
    assert((!N1IsConst || ConstValue1.getBitWidth() == BitWidth) &&
           "Splat APInt should be element width");
  } else {
    N1IsConst = isa<ConstantSDNode>(N1);
    if (N1IsConst) {
      ConstValue1 = N1->getAsAPIntVal();
      N1IsOpaqueConst = cast<ConstantSDNode>(N1)->isOpaque();
    }
  }

  // fold (mul x, 0) -> 0
  if (N1IsConst && ConstValue1.isZero())
    return N1;

  // fold (mul x, 1) -> x
  if (N1IsConst && ConstValue1.isOne())
    return N0;

  if (!UseVP)
    if (SDValue NewSel = foldBinOpIntoSelect(N))
      return NewSel;

  // fold (mul x, -1) -> 0-x
  if (N1IsConst && ConstValue1.isAllOnes())
    return Matcher.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), N0);

  // fold (mul x, (1 << c)) -> x << c
  // For vector types this is only attempted while Level is at or before
  // AfterLegalizeVectorOps.
  if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
      (!VT.isVector() || Level <= AfterLegalizeVectorOps)) {
    if (SDValue LogBase2 = BuildLogBase2(N1, DL)) {
      // Convert the log2 value to the shift-amount type expected for N0.
      EVT ShiftVT = getShiftAmountTy(N0.getValueType());
      SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
      return Matcher.getNode(ISD::SHL, DL, VT, N0, Trunc);
    }
  }

  // fold (mul x, -(1 << c)) -> -(x << c) or (-x) << c
  if (N1IsConst && !N1IsOpaqueConst && ConstValue1.isNegatedPowerOf2()) {
    unsigned Log2Val = (-ConstValue1).logBase2();

    // FIXME: If the input is something that is easily negated (e.g. a
    // single-use add), we should put the negate there.
    return Matcher.getNode(
        ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
        Matcher.getNode(ISD::SHL, DL, VT, N0,
                        DAG.getShiftAmountConstant(Log2Val, VT, DL)));
  }

  // Attempt to reuse an existing umul_lohi/smul_lohi node, but only if the
  // hi result is in use in case we hit this mid-legalization.
  if (!UseVP) {
    for (unsigned LoHiOpc : {ISD::UMUL_LOHI, ISD::SMUL_LOHI}) {
      if (!LegalOperations || TLI.isOperationLegalOrCustom(LoHiOpc, VT)) {
        SDVTList LoHiVT = DAG.getVTList(VT, VT);
        // TODO: Can we match commutable operands with getNodeIfExists?
        // Both operand orders are looked up explicitly.
        if (SDNode *LoHi = DAG.getNodeIfExists(LoHiOpc, LoHiVT, {N0, N1}))
          if (LoHi->hasAnyUseOfValue(1))
            return SDValue(LoHi, 0);
        if (SDNode *LoHi = DAG.getNodeIfExists(LoHiOpc, LoHiVT, {N1, N0}))
          if (LoHi->hasAnyUseOfValue(1))
            return SDValue(LoHi, 0);
      }
    }
  }

  // Try to transform:
  // (1) multiply-by-(power-of-2 +/- 1) into shift and add/sub.
  // mul x, (2^N + 1) --> add (shl x, N), x
  // mul x, (2^N - 1) --> sub (shl x, N), x
  // Examples: x * 33 --> (x << 5) + x
  //           x * 15 --> (x << 4) - x
  //           x * -33 --> -((x << 5) + x)
  //           x * -15 --> -((x << 4) - x) ; this reduces --> x - (x << 4)
  // (2) multiply-by-(power-of-2 +/- power-of-2) into shifts and add/sub.
  // mul x, (2^N + 2^M) --> (add (shl x, N), (shl x, M))
  // mul x, (2^N - 2^M) --> (sub (shl x, N), (shl x, M))
  // Examples: x * 0x8800 --> (x << 15) + (x << 11)
  //           x * 0xf800 --> (x << 16) - (x << 11)
  //           x * -0x8800 --> -((x << 15) + (x << 11))
  //           x * -0xf800 --> -((x << 16) - (x << 11)) ; (x << 11) - (x << 16)
  if (!UseVP && N1IsConst &&
      TLI.decomposeMulByConstant(*DAG.getContext(), VT, N1)) {
    // TODO: We could handle more general decomposition of any constant by
    //       having the target set a limit on number of ops and making a
    //       callback to determine that sequence (similar to sqrt expansion).
    // MathOp stays DELETED_NODE when the constant fits neither pattern.
    unsigned MathOp = ISD::DELETED_NODE;
    APInt MulC = ConstValue1.abs();
    // The constant `2` should be treated as (2^0 + 1).
    unsigned TZeros = MulC == 2 ? 0 : MulC.countr_zero();
    // Strip the trailing zeros; they are re-applied as shifts below.
    MulC.lshrInPlace(TZeros);
    if ((MulC - 1).isPowerOf2())
      MathOp = ISD::ADD;
    else if ((MulC + 1).isPowerOf2())
      MathOp = ISD::SUB;

    if (MathOp != ISD::DELETED_NODE) {
      unsigned ShAmt =
          MathOp == ISD::ADD ? (MulC - 1).logBase2() : (MulC + 1).logBase2();
      ShAmt += TZeros;
      assert(ShAmt < BitWidth &&
             "multiply-by-constant generated out of bounds shift");
      SDValue Shl =
          DAG.getNode(ISD::SHL, DL, VT, N0, DAG.getConstant(ShAmt, DL, VT));
      // With trailing zeros the second term is (x << TZeros); otherwise it is
      // x itself.
      SDValue R =
          TZeros ? DAG.getNode(MathOp, DL, VT, Shl,
                               DAG.getNode(ISD::SHL, DL, VT, N0,
                                           DAG.getConstant(TZeros, DL, VT)))
                 : DAG.getNode(MathOp, DL, VT, Shl, N0);
      // The decomposition used |C|; negate the result for a negative C.
      if (ConstValue1.isNegative())
        R = DAG.getNegative(R, DL, VT);
      return R;
    }
  }

  // (mul (shl X, c1), c2) -> (mul X, c2 << c1)
  if (sd_context_match(N0, Matcher, m_Opc(ISD::SHL))) {
    SDValue N01 = N0.getOperand(1);
    if (SDValue C3 = DAG.FoldConstantArithmetic(ISD::SHL, DL, VT, {N1, N01}))
      return DAG.getNode(ISD::MUL, DL, VT, N0.getOperand(0), C3);
  }

  // Change (mul (shl X, C), Y) -> (shl (mul X, Y), C) when the shift has one
  // use.
  {
    SDValue Sh, Y;

    // Check for both (mul (shl X, C), Y)  and  (mul Y, (shl X, C)).
    if (sd_context_match(N0, Matcher, m_OneUse(m_Opc(ISD::SHL))) &&
        isConstantOrConstantVector(N0.getOperand(1))) {
      Sh = N0; Y = N1;
    } else if (sd_context_match(N1, Matcher, m_OneUse(m_Opc(ISD::SHL))) &&
               isConstantOrConstantVector(N1.getOperand(1))) {
      Sh = N1; Y = N0;
    }

    // Sh is only set if one of the two shapes above matched.
    if (Sh.getNode()) {
      SDValue Mul = Matcher.getNode(ISD::MUL, DL, VT, Sh.getOperand(0), Y);
      return Matcher.getNode(ISD::SHL, DL, VT, Mul, Sh.getOperand(1));
    }
  }

  // fold (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2)
  if (sd_context_match(N0, Matcher, m_Opc(ISD::ADD)) &&
      DAG.isConstantIntBuildVectorOrConstantInt(N1) &&
      DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1)) &&
      isMulAddWithConstProfitable(N, N0, N1))
    return Matcher.getNode(
        ISD::ADD, DL, VT,
        Matcher.getNode(ISD::MUL, SDLoc(N0), VT, N0.getOperand(0), N1),
        Matcher.getNode(ISD::MUL, SDLoc(N1), VT, N0.getOperand(1), N1));

  // Fold (mul (vscale * C0), C1) to (vscale * (C0 * C1)).
  ConstantSDNode *NC1 = isConstOrConstSplat(N1);
  if (!UseVP && N0.getOpcode() == ISD::VSCALE && NC1) {
    const APInt &C0 = N0.getConstantOperandAPInt(0);
    const APInt &C1 = NC1->getAPIntValue();
    return DAG.getVScale(DL, VT, C0 * C1);
  }

  // Fold (mul step_vector(C0), C1) to (step_vector(C0 * C1)).
  APInt MulVal;
  if (!UseVP && N0.getOpcode() == ISD::STEP_VECTOR &&
      ISD::isConstantSplatVector(N1.getNode(), MulVal)) {
    const APInt &C0 = N0.getConstantOperandAPInt(0);
    APInt NewStep = C0 * MulVal;
    return DAG.getStepVector(DL, VT, NewStep);
  }

  // Fold Y = sra (X, size(X)-1); mul (or (Y, 1), X) -> (abs X)
  SDValue X;
  if (!UseVP && (!LegalOperations || hasOperation(ISD::ABS, VT)) &&
      sd_context_match(
          N, Matcher,
          m_Mul(m_Or(m_Sra(m_Value(X), m_SpecificInt(BitWidth - 1)), m_One()),
                m_Deferred(X)))) {
    return Matcher.getNode(ISD::ABS, DL, VT, X);
  }

  // Element-wise, (mul x, 0/undef) -> 0 and (mul x, 1) -> x, so a multiply by
  // a constant vector containing only such factors becomes and(x, mask).
  // We can replace vectors with '0' and '1' factors with a clearing mask.
  if (VT.isFixedLengthVector()) {
    unsigned NumElts = VT.getVectorNumElements();
    SmallBitVector ClearMask;
    ClearMask.reserve(NumElts);
    // Predicate applied to each element of N1: records in ClearMask whether
    // the element is null/zero (lane must be cleared) and accepts the element
    // only if it is null, zero or one.
    auto IsClearMask = [&ClearMask](ConstantSDNode *V) {
      if (!V || V->isZero()) {
        ClearMask.push_back(true);
        return true;
      }
      ClearMask.push_back(false);
      return V->isOne();
    };
    if ((!LegalOperations || TLI.isOperationLegalOrCustom(ISD::AND, VT)) &&
        ISD::matchUnaryPredicate(N1, IsClearMask, /*AllowUndefs*/ true)) {
      assert(N1.getOpcode() == ISD::BUILD_VECTOR && "Unknown constant vector");
      // Build the mask from the type of N1's first operand rather than from
      // VT's element type.
      EVT LegalSVT = N1.getOperand(0).getValueType();
      SDValue Zero = DAG.getConstant(0, DL, LegalSVT);
      SDValue AllOnes = DAG.getAllOnesConstant(DL, LegalSVT);
      // Start with all-ones (keep) in every lane, then zero the cleared lanes.
      SmallVector<SDValue, 16> Mask(NumElts, AllOnes);
      for (unsigned I = 0; I != NumElts; ++I)
        if (ClearMask[I])
          Mask[I] = Zero;
      return DAG.getNode(ISD::AND, DL, VT, N0, DAG.getBuildVector(VT, DL, Mask));
    }
  }

  // reassociate mul
  // TODO: Change reassociateOps to support vp ops.
  if (!UseVP)
    if (SDValue RMUL = reassociateOps(ISD::MUL, DL, N0, N1, N->getFlags()))
      return RMUL;

  // Fold mul(vecreduce(x), vecreduce(y)) -> vecreduce(mul(x, y))
  // TODO: Change reassociateReduction to support vp ops.
  if (!UseVP)
    if (SDValue SD =
            reassociateReduction(ISD::VECREDUCE_MUL, ISD::MUL, DL, VT, N0, N1))
      return SD;

  // Simplify the operands using demanded-bits information.
  // On success, N itself is returned rather than a new value.
  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  // No fold applied.
  return SDValue();
}
4900
4901 /// Return true if divmod libcall is available.
isDivRemLibcallAvailable(SDNode * Node,bool isSigned,const TargetLowering & TLI)4902 static bool isDivRemLibcallAvailable(SDNode *Node, bool isSigned,
4903 const TargetLowering &TLI) {
4904 RTLIB::Libcall LC;
4905 EVT NodeType = Node->getValueType(0);
4906 if (!NodeType.isSimple())
4907 return false;
4908 switch (NodeType.getSimpleVT().SimpleTy) {
4909 default: return false; // No libcall for vector types.
4910 case MVT::i8: LC= isSigned ? RTLIB::SDIVREM_I8 : RTLIB::UDIVREM_I8; break;
4911 case MVT::i16: LC= isSigned ? RTLIB::SDIVREM_I16 : RTLIB::UDIVREM_I16; break;
4912 case MVT::i32: LC= isSigned ? RTLIB::SDIVREM_I32 : RTLIB::UDIVREM_I32; break;
4913 case MVT::i64: LC= isSigned ? RTLIB::SDIVREM_I64 : RTLIB::UDIVREM_I64; break;
4914 case MVT::i128: LC= isSigned ? RTLIB::SDIVREM_I128:RTLIB::UDIVREM_I128; break;
4915 }
4916
4917 return TLI.getLibcallName(LC) != nullptr;
4918 }
4919
/// Issue divrem if both quotient and remainder are needed.
///
/// \p Node is a scalar integer SDIV/UDIV/SREM/UREM. The users of its first
/// operand are scanned for the complementary operation (and for an already
/// existing DIVREM) on the same operand pair; matching DIV/REM users are
/// rewritten to use the results of a single [SU]DIVREM node.
///
/// \returns value 0 of the DIVREM node that was created or reused, or an empty
///          SDValue if the combine is not applicable or no complementary user
///          was found. \p Node itself is not replaced here; the caller picks
///          the result it needs from the returned value.
SDValue DAGCombiner::useDivRem(SDNode *Node) {
  if (Node->use_empty())
    return SDValue(); // This is a dead node, leave it alone.

  unsigned Opcode = Node->getOpcode();
  bool isSigned = (Opcode == ISD::SDIV) || (Opcode == ISD::SREM);
  unsigned DivRemOpc = isSigned ? ISD::SDIVREM : ISD::UDIVREM;

  // DivMod lib calls can still work on non-legal types if using lib-calls.
  EVT VT = Node->getValueType(0);
  // Only scalar integer types are handled.
  if (VT.isVector() || !VT.isInteger())
    return SDValue();

  // An illegal type is only acceptable if the target custom-lowers DIVREM
  // for it.
  if (!TLI.isTypeLegal(VT) && !TLI.isOperationCustom(DivRemOpc, VT))
    return SDValue();

  // If DIVREM is going to get expanded into a libcall,
  // but there is no libcall available, then don't combine.
  if (!TLI.isOperationLegalOrCustom(DivRemOpc, VT) &&
      !isDivRemLibcallAvailable(Node, isSigned, TLI))
    return SDValue();

  // If div is legal, it's better to do the normal expansion
  // OtherOpcode is the complementary operation: the REM for a DIV node and
  // the DIV for a REM node. In both branches it is the DIV opcode whose
  // legality is tested.
  unsigned OtherOpcode = 0;
  if ((Opcode == ISD::SDIV) || (Opcode == ISD::UDIV)) {
    OtherOpcode = isSigned ? ISD::SREM : ISD::UREM;
    if (TLI.isOperationLegalOrCustom(Opcode, VT))
      return SDValue();
  } else {
    OtherOpcode = isSigned ? ISD::SDIV : ISD::UDIV;
    if (TLI.isOperationLegalOrCustom(OtherOpcode, VT))
      return SDValue();
  }

  SDValue Op0 = Node->getOperand(0);
  SDValue Op1 = Node->getOperand(1);
  // Lazily-created (or reused) DIVREM value; stays empty until a
  // complementary or DIVREM user of Op0 is encountered.
  SDValue combined;
  for (SDNode *User : Op0->users()) {
    // Skip the node being combined as well as deleted or dead users.
    if (User == Node || User->getOpcode() == ISD::DELETED_NODE ||
        User->use_empty())
      continue;
    // Convert the other matching node(s), too;
    // otherwise, the DIVREM may get target-legalized into something
    // target-specific that we won't be able to recognize.
    unsigned UserOpc = User->getOpcode();
    if ((UserOpc == Opcode || UserOpc == OtherOpcode || UserOpc == DivRemOpc) &&
        User->getOperand(0) == Op0 &&
        User->getOperand(1) == Op1) {
      if (!combined) {
        if (UserOpc == OtherOpcode) {
          // Found the complementary op: build a two-result DIVREM node.
          SDVTList VTs = DAG.getVTList(VT, VT);
          combined = DAG.getNode(DivRemOpc, SDLoc(Node), VTs, Op0, Op1);
        } else if (UserOpc == DivRemOpc) {
          // An equivalent DIVREM already exists: reuse it.
          combined = SDValue(User, 0);
        } else {
          // A duplicate of Node's own opcode gives no reason to form a
          // DIVREM on its own; leave it alone while nothing has been formed.
          assert(UserOpc == Opcode);
          continue;
        }
      }
      // Redirect the user: DIV users take result 0 (quotient), REM users take
      // result 1 (remainder). DIVREM users are left untouched.
      if (UserOpc == ISD::SDIV || UserOpc == ISD::UDIV)
        CombineTo(User, combined);
      else if (UserOpc == ISD::SREM || UserOpc == ISD::UREM)
        CombineTo(User, combined.getValue(1));
    }
  }
  return combined;
}
4988
simplifyDivRem(SDNode * N,SelectionDAG & DAG)4989 static SDValue simplifyDivRem(SDNode *N, SelectionDAG &DAG) {
4990 SDValue N0 = N->getOperand(0);
4991 SDValue N1 = N->getOperand(1);
4992 EVT VT = N->getValueType(0);
4993 SDLoc DL(N);
4994
4995 unsigned Opc = N->getOpcode();
4996 bool IsDiv = (ISD::SDIV == Opc) || (ISD::UDIV == Opc);
4997 ConstantSDNode *N1C = isConstOrConstSplat(N1);
4998
4999 // X / undef -> undef
5000 // X % undef -> undef
5001 // X / 0 -> undef
5002 // X % 0 -> undef
5003 // NOTE: This includes vectors where any divisor element is zero/undef.
5004 if (DAG.isUndef(Opc, {N0, N1}))
5005 return DAG.getUNDEF(VT);
5006
5007 // undef / X -> 0
5008 // undef % X -> 0
5009 if (N0.isUndef())
5010 return DAG.getConstant(0, DL, VT);
5011
5012 // 0 / X -> 0
5013 // 0 % X -> 0
5014 ConstantSDNode *N0C = isConstOrConstSplat(N0);
5015 if (N0C && N0C->isZero())
5016 return N0;
5017
5018 // X / X -> 1
5019 // X % X -> 0
5020 if (N0 == N1)
5021 return DAG.getConstant(IsDiv ? 1 : 0, DL, VT);
5022
5023 // X / 1 -> X
5024 // X % 1 -> 0
5025 // If this is a boolean op (single-bit element type), we can't have
5026 // division-by-zero or remainder-by-zero, so assume the divisor is 1.
5027 // TODO: Similarly, if we're zero-extending a boolean divisor, then assume
5028 // it's a 1.
5029 if ((N1C && N1C->isOne()) || (VT.getScalarType() == MVT::i1))
5030 return IsDiv ? N0 : DAG.getConstant(0, DL, VT);
5031
5032 return SDValue();
5033 }
5034
visitSDIV(SDNode * N)5035 SDValue DAGCombiner::visitSDIV(SDNode *N) {
5036 SDValue N0 = N->getOperand(0);
5037 SDValue N1 = N->getOperand(1);
5038 EVT VT = N->getValueType(0);
5039 EVT CCVT = getSetCCResultType(VT);
5040 SDLoc DL(N);
5041
5042 // fold (sdiv c1, c2) -> c1/c2
5043 if (SDValue C = DAG.FoldConstantArithmetic(ISD::SDIV, DL, VT, {N0, N1}))
5044 return C;
5045
5046 // fold vector ops
5047 if (VT.isVector())
5048 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5049 return FoldedVOp;
5050
5051 // fold (sdiv X, -1) -> 0-X
5052 ConstantSDNode *N1C = isConstOrConstSplat(N1);
5053 if (N1C && N1C->isAllOnes())
5054 return DAG.getNegative(N0, DL, VT);
5055
5056 // fold (sdiv X, MIN_SIGNED) -> select(X == MIN_SIGNED, 1, 0)
5057 if (N1C && N1C->isMinSignedValue())
5058 return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ),
5059 DAG.getConstant(1, DL, VT),
5060 DAG.getConstant(0, DL, VT));
5061
5062 if (SDValue V = simplifyDivRem(N, DAG))
5063 return V;
5064
5065 if (SDValue NewSel = foldBinOpIntoSelect(N))
5066 return NewSel;
5067
5068 // If we know the sign bits of both operands are zero, strength reduce to a
5069 // udiv instead. Handles (X&15) /s 4 -> X&15 >> 2
5070 if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0))
5071 return DAG.getNode(ISD::UDIV, DL, N1.getValueType(), N0, N1);
5072
5073 if (SDValue V = visitSDIVLike(N0, N1, N)) {
5074 // If the corresponding remainder node exists, update its users with
5075 // (Dividend - (Quotient * Divisor).
5076 if (SDNode *RemNode = DAG.getNodeIfExists(ISD::SREM, N->getVTList(),
5077 { N0, N1 })) {
5078 // If the sdiv has the exact flag we shouldn't propagate it to the
5079 // remainder node.
5080 if (!N->getFlags().hasExact()) {
5081 SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, V, N1);
5082 SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
5083 AddToWorklist(Mul.getNode());
5084 AddToWorklist(Sub.getNode());
5085 CombineTo(RemNode, Sub);
5086 }
5087 }
5088 return V;
5089 }
5090
5091 // sdiv, srem -> sdivrem
5092 // If the divisor is constant, then return DIVREM only if isIntDivCheap() is
5093 // true. Otherwise, we break the simplification logic in visitREM().
5094 AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
5095 if (!N1C || TLI.isIntDivCheap(N->getValueType(0), Attr))
5096 if (SDValue DivRem = useDivRem(N))
5097 return DivRem;
5098
5099 return SDValue();
5100 }
5101
isDivisorPowerOfTwo(SDValue Divisor)5102 static bool isDivisorPowerOfTwo(SDValue Divisor) {
5103 // Helper for determining whether a value is a power-2 constant scalar or a
5104 // vector of such elements.
5105 auto IsPowerOfTwo = [](ConstantSDNode *C) {
5106 if (C->isZero() || C->isOpaque())
5107 return false;
5108 if (C->getAPIntValue().isPowerOf2())
5109 return true;
5110 if (C->getAPIntValue().isNegatedPowerOf2())
5111 return true;
5112 return false;
5113 };
5114
5115 return ISD::matchUnaryPredicate(Divisor, IsPowerOfTwo);
5116 }
5117
/// Try to replace a signed division of \p N0 by \p N1 with a cheaper sequence.
///
/// \param N0 Dividend.
/// \param N1 Divisor.
/// \param N  Node supplying the debug location, result type and flags. It is
///           also handed to the target hooks BuildSDIVPow2() and BuildSDIV().
/// \returns the replacement quotient, or an empty SDValue if no expansion was
///          produced.
SDValue DAGCombiner::visitSDIVLike(SDValue N0, SDValue N1, SDNode *N) {
  SDLoc DL(N);
  EVT VT = N->getValueType(0);
  EVT CCVT = getSetCCResultType(VT);
  unsigned BitWidth = VT.getScalarSizeInBits();

  // fold (sdiv X, pow2) -> simple ops after legalize
  // FIXME: We check for the exact bit here because the generic lowering gives
  // better results in that case. The target-specific lowering should learn how
  // to handle exact sdivs efficiently.
  if (!N->getFlags().hasExact() && isDivisorPowerOfTwo(N1)) {
    // Target-specific implementation of sdiv x, pow2.
    if (SDValue Res = BuildSDIVPow2(N))
      return Res;

    // Create constants that are functions of the shift amount value.
    // C1 is the count of trailing zeros of the divisor (the shift amount) and
    // Inexact is BitWidth - C1. Bail out unless Inexact ends up as a
    // constant or constant vector.
    EVT ShiftAmtTy = getShiftAmountTy(N0.getValueType());
    SDValue Bits = DAG.getConstant(BitWidth, DL, ShiftAmtTy);
    SDValue C1 = DAG.getNode(ISD::CTTZ, DL, VT, N1);
    C1 = DAG.getZExtOrTrunc(C1, DL, ShiftAmtTy);
    SDValue Inexact = DAG.getNode(ISD::SUB, DL, ShiftAmtTy, Bits, C1);
    if (!isConstantOrConstantVector(Inexact))
      return SDValue();

    // Splat the sign bit into the register
    SDValue Sign = DAG.getNode(ISD::SRA, DL, VT, N0,
                               DAG.getConstant(BitWidth - 1, DL, ShiftAmtTy));
    AddToWorklist(Sign.getNode());

    // Add (N0 < 0) ? abs2 - 1 : 0;
    // Shifting the splatted sign right (logically) by BitWidth - C1 leaves
    // the low C1 bits set for a negative dividend and zero otherwise; adding
    // that bias before the arithmetic shift by C1 gives the quotient for a
    // positive divisor.
    SDValue Srl = DAG.getNode(ISD::SRL, DL, VT, Sign, Inexact);
    AddToWorklist(Srl.getNode());
    SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0, Srl);
    AddToWorklist(Add.getNode());
    SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Add, C1);
    AddToWorklist(Sra.getNode());

    // Special case: (sdiv X, 1) -> X
    // Special Case: (sdiv X, -1) -> 0-X
    // For a divisor of 1 or -1 select N0 itself; the -1 case is then negated
    // by the IsNeg select below.
    SDValue One = DAG.getConstant(1, DL, VT);
    SDValue AllOnes = DAG.getAllOnesConstant(DL, VT);
    SDValue IsOne = DAG.getSetCC(DL, CCVT, N1, One, ISD::SETEQ);
    SDValue IsAllOnes = DAG.getSetCC(DL, CCVT, N1, AllOnes, ISD::SETEQ);
    SDValue IsOneOrAllOnes = DAG.getNode(ISD::OR, DL, CCVT, IsOne, IsAllOnes);
    Sra = DAG.getSelect(DL, VT, IsOneOrAllOnes, N0, Sra);

    // If dividing by a positive value, we're done. Otherwise, the result must
    // be negated.
    SDValue Zero = DAG.getConstant(0, DL, VT);
    SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, Zero, Sra);

    // FIXME: Use SELECT_CC once we improve SELECT_CC constant-folding.
    SDValue IsNeg = DAG.getSetCC(DL, CCVT, N1, Zero, ISD::SETLT);
    SDValue Res = DAG.getSelect(DL, VT, IsNeg, Sub, Sra);
    return Res;
  }

  // If integer divide is expensive and we satisfy the requirements, emit an
  // alternate sequence. Targets may check function attributes for size/speed
  // trade-offs.
  AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
  if (isConstantOrConstantVector(N1) &&
      !TLI.isIntDivCheap(N->getValueType(0), Attr))
    if (SDValue Op = BuildSDIV(N))
      return Op;

  return SDValue();
}
5186
visitUDIV(SDNode * N)5187 SDValue DAGCombiner::visitUDIV(SDNode *N) {
5188 SDValue N0 = N->getOperand(0);
5189 SDValue N1 = N->getOperand(1);
5190 EVT VT = N->getValueType(0);
5191 EVT CCVT = getSetCCResultType(VT);
5192 SDLoc DL(N);
5193
5194 // fold (udiv c1, c2) -> c1/c2
5195 if (SDValue C = DAG.FoldConstantArithmetic(ISD::UDIV, DL, VT, {N0, N1}))
5196 return C;
5197
5198 // fold vector ops
5199 if (VT.isVector())
5200 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5201 return FoldedVOp;
5202
5203 // fold (udiv X, -1) -> select(X == -1, 1, 0)
5204 ConstantSDNode *N1C = isConstOrConstSplat(N1);
5205 if (N1C && N1C->isAllOnes() && CCVT.isVector() == VT.isVector()) {
5206 return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ),
5207 DAG.getConstant(1, DL, VT),
5208 DAG.getConstant(0, DL, VT));
5209 }
5210
5211 if (SDValue V = simplifyDivRem(N, DAG))
5212 return V;
5213
5214 if (SDValue NewSel = foldBinOpIntoSelect(N))
5215 return NewSel;
5216
5217 if (SDValue V = visitUDIVLike(N0, N1, N)) {
5218 // If the corresponding remainder node exists, update its users with
5219 // (Dividend - (Quotient * Divisor).
5220 if (SDNode *RemNode = DAG.getNodeIfExists(ISD::UREM, N->getVTList(),
5221 { N0, N1 })) {
5222 // If the udiv has the exact flag we shouldn't propagate it to the
5223 // remainder node.
5224 if (!N->getFlags().hasExact()) {
5225 SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, V, N1);
5226 SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
5227 AddToWorklist(Mul.getNode());
5228 AddToWorklist(Sub.getNode());
5229 CombineTo(RemNode, Sub);
5230 }
5231 }
5232 return V;
5233 }
5234
5235 // sdiv, srem -> sdivrem
5236 // If the divisor is constant, then return DIVREM only if isIntDivCheap() is
5237 // true. Otherwise, we break the simplification logic in visitREM().
5238 AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
5239 if (!N1C || TLI.isIntDivCheap(N->getValueType(0), Attr))
5240 if (SDValue DivRem = useDivRem(N))
5241 return DivRem;
5242
5243 // Simplify the operands using demanded-bits information.
5244 // We don't have demanded bits support for UDIV so this just enables constant
5245 // folding based on known bits.
5246 if (SimplifyDemandedBits(SDValue(N, 0)))
5247 return SDValue(N, 0);
5248
5249 return SDValue();
5250 }
5251
visitUDIVLike(SDValue N0,SDValue N1,SDNode * N)5252 SDValue DAGCombiner::visitUDIVLike(SDValue N0, SDValue N1, SDNode *N) {
5253 SDLoc DL(N);
5254 EVT VT = N->getValueType(0);
5255
5256 // fold (udiv x, (1 << c)) -> x >>u c
5257 if (isConstantOrConstantVector(N1, /*NoOpaques*/ true)) {
5258 if (SDValue LogBase2 = BuildLogBase2(N1, DL)) {
5259 AddToWorklist(LogBase2.getNode());
5260
5261 EVT ShiftVT = getShiftAmountTy(N0.getValueType());
5262 SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
5263 AddToWorklist(Trunc.getNode());
5264 return DAG.getNode(ISD::SRL, DL, VT, N0, Trunc);
5265 }
5266 }
5267
5268 // fold (udiv x, (shl c, y)) -> x >>u (log2(c)+y) iff c is power of 2
5269 if (N1.getOpcode() == ISD::SHL) {
5270 SDValue N10 = N1.getOperand(0);
5271 if (isConstantOrConstantVector(N10, /*NoOpaques*/ true)) {
5272 if (SDValue LogBase2 = BuildLogBase2(N10, DL)) {
5273 AddToWorklist(LogBase2.getNode());
5274
5275 EVT ADDVT = N1.getOperand(1).getValueType();
5276 SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ADDVT);
5277 AddToWorklist(Trunc.getNode());
5278 SDValue Add = DAG.getNode(ISD::ADD, DL, ADDVT, N1.getOperand(1), Trunc);
5279 AddToWorklist(Add.getNode());
5280 return DAG.getNode(ISD::SRL, DL, VT, N0, Add);
5281 }
5282 }
5283 }
5284
5285 // fold (udiv x, c) -> alternate
5286 AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
5287 if (isConstantOrConstantVector(N1) &&
5288 !TLI.isIntDivCheap(N->getValueType(0), Attr))
5289 if (SDValue Op = BuildUDIV(N))
5290 return Op;
5291
5292 return SDValue();
5293 }
5294
buildOptimizedSREM(SDValue N0,SDValue N1,SDNode * N)5295 SDValue DAGCombiner::buildOptimizedSREM(SDValue N0, SDValue N1, SDNode *N) {
5296 if (!N->getFlags().hasExact() && isDivisorPowerOfTwo(N1) &&
5297 !DAG.doesNodeExist(ISD::SDIV, N->getVTList(), {N0, N1})) {
5298 // Target-specific implementation of srem x, pow2.
5299 if (SDValue Res = BuildSREMPow2(N))
5300 return Res;
5301 }
5302 return SDValue();
5303 }
5304
// handles ISD::SREM and ISD::UREM
/// Combine a remainder node. ISD::SREM selects the signed path; any other
/// opcode reaching here is treated as unsigned.
///
/// \returns the replacement value, or an empty SDValue if nothing applied.
SDValue DAGCombiner::visitREM(SDNode *N) {
  unsigned Opcode = N->getOpcode();
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  EVT CCVT = getSetCCResultType(VT);

  bool isSigned = (Opcode == ISD::SREM);
  SDLoc DL(N);

  // fold (rem c1, c2) -> c1%c2
  if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
    return C;

  // fold (urem X, -1) -> select(FX == -1, 0, FX)
  // Freeze the numerator to avoid a miscompile with an undefined value.
  // The fold is skipped when the setcc result type and VT disagree on being
  // vector types.
  if (!isSigned && llvm::isAllOnesOrAllOnesSplat(N1, /*AllowUndefs*/ false) &&
      CCVT.isVector() == VT.isVector()) {
    SDValue F0 = DAG.getFreeze(N0);
    SDValue EqualsNeg1 = DAG.getSetCC(DL, CCVT, F0, N1, ISD::SETEQ);
    return DAG.getSelect(DL, VT, EqualsNeg1, DAG.getConstant(0, DL, VT), F0);
  }

  // Operand-only folds shared with the division combines.
  if (SDValue V = simplifyDivRem(N, DAG))
    return V;

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  if (isSigned) {
    // If we know the sign bits of both operands are zero, strength reduce to a
    // urem instead. Handles (X & 0x0FFFFFFF) %s 16 -> X&15
    if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0))
      return DAG.getNode(ISD::UREM, DL, VT, N0, N1);
  } else {
    if (DAG.isKnownToBeAPowerOfTwo(N1)) {
      // fold (urem x, pow2) -> (and x, pow2-1)
      // pow2-1 is formed as N1 + (-1).
      SDValue NegOne = DAG.getAllOnesConstant(DL, VT);
      SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N1, NegOne);
      AddToWorklist(Add.getNode());
      return DAG.getNode(ISD::AND, DL, VT, N0, Add);
    }
    // fold (urem x, (shl pow2, y)) -> (and x, (add (shl pow2, y), -1))
    // fold (urem x, (lshr pow2, y)) -> (and x, (add (lshr pow2, y), -1))
    // TODO: We should sink the following into isKnownToBePowerOfTwo
    // using a OrZero parameter analogous to our handling in ValueTracking.
    if ((N1.getOpcode() == ISD::SHL || N1.getOpcode() == ISD::SRL) &&
        DAG.isKnownToBeAPowerOfTwo(N1.getOperand(0))) {
      SDValue NegOne = DAG.getAllOnesConstant(DL, VT);
      SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N1, NegOne);
      AddToWorklist(Add.getNode());
      return DAG.getNode(ISD::AND, DL, VT, N0, Add);
    }
  }

  AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();

  // If X/C can be simplified by the division-by-constant logic, lower
  // X%C to the equivalent of X-X/C*C.
  // Reuse the SDIVLike/UDIVLike combines - to avoid mangling nodes, the
  // speculative DIV must not cause a DIVREM conversion. We guard against this
  // by skipping the simplification if isIntDivCheap(). When div is not cheap,
  // combine will not return a DIVREM. Regardless, checking cheapness here
  // makes sense since the simplification results in fatter code.
  if (DAG.isKnownNeverZero(N1) && !TLI.isIntDivCheap(VT, Attr)) {
    if (isSigned) {
      // check if we can build faster implementation for srem
      if (SDValue OptimizedRem = buildOptimizedSREM(N0, N1, N))
        return OptimizedRem;
    }

    // Note that N (the REM node) is what gets passed to the DIV-like helpers.
    SDValue OptimizedDiv =
        isSigned ? visitSDIVLike(N0, N1, N) : visitUDIVLike(N0, N1, N);
    if (OptimizedDiv.getNode() && OptimizedDiv.getNode() != N) {
      // If the equivalent Div node also exists, update its users.
      unsigned DivOpcode = isSigned ? ISD::SDIV : ISD::UDIV;
      if (SDNode *DivNode = DAG.getNodeIfExists(DivOpcode, N->getVTList(),
                                                { N0, N1 }))
        CombineTo(DivNode, OptimizedDiv);
      // Remainder = N0 - OptimizedDiv * N1.
      SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, OptimizedDiv, N1);
      SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
      AddToWorklist(OptimizedDiv.getNode());
      AddToWorklist(Mul.getNode());
      return Sub;
    }
  }

  // sdiv, srem -> sdivrem
  // useDivRem() hands back result 0 of the DIVREM node; the remainder is
  // result 1.
  if (SDValue DivRem = useDivRem(N))
    return DivRem.getValue(1);

  return SDValue();
}
5399
visitMULHS(SDNode * N)5400 SDValue DAGCombiner::visitMULHS(SDNode *N) {
5401 SDValue N0 = N->getOperand(0);
5402 SDValue N1 = N->getOperand(1);
5403 EVT VT = N->getValueType(0);
5404 SDLoc DL(N);
5405
5406 // fold (mulhs c1, c2)
5407 if (SDValue C = DAG.FoldConstantArithmetic(ISD::MULHS, DL, VT, {N0, N1}))
5408 return C;
5409
5410 // canonicalize constant to RHS.
5411 if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
5412 !DAG.isConstantIntBuildVectorOrConstantInt(N1))
5413 return DAG.getNode(ISD::MULHS, DL, N->getVTList(), N1, N0);
5414
5415 if (VT.isVector()) {
5416 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5417 return FoldedVOp;
5418
5419 // fold (mulhs x, 0) -> 0
5420 // do not return N1, because undef node may exist.
5421 if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
5422 return DAG.getConstant(0, DL, VT);
5423 }
5424
5425 // fold (mulhs x, 0) -> 0
5426 if (isNullConstant(N1))
5427 return N1;
5428
5429 // fold (mulhs x, 1) -> (sra x, size(x)-1)
5430 if (isOneConstant(N1))
5431 return DAG.getNode(
5432 ISD::SRA, DL, VT, N0,
5433 DAG.getShiftAmountConstant(N0.getScalarValueSizeInBits() - 1, VT, DL));
5434
5435 // fold (mulhs x, undef) -> 0
5436 if (N0.isUndef() || N1.isUndef())
5437 return DAG.getConstant(0, DL, VT);
5438
5439 // If the type twice as wide is legal, transform the mulhs to a wider multiply
5440 // plus a shift.
5441 if (!TLI.isOperationLegalOrCustom(ISD::MULHS, VT) && VT.isSimple() &&
5442 !VT.isVector()) {
5443 MVT Simple = VT.getSimpleVT();
5444 unsigned SimpleSize = Simple.getSizeInBits();
5445 EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
5446 if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
5447 N0 = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N0);
5448 N1 = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N1);
5449 N1 = DAG.getNode(ISD::MUL, DL, NewVT, N0, N1);
5450 N1 = DAG.getNode(ISD::SRL, DL, NewVT, N1,
5451 DAG.getShiftAmountConstant(SimpleSize, NewVT, DL));
5452 return DAG.getNode(ISD::TRUNCATE, DL, VT, N1);
5453 }
5454 }
5455
5456 return SDValue();
5457 }
5458
visitMULHU(SDNode * N)5459 SDValue DAGCombiner::visitMULHU(SDNode *N) {
5460 SDValue N0 = N->getOperand(0);
5461 SDValue N1 = N->getOperand(1);
5462 EVT VT = N->getValueType(0);
5463 SDLoc DL(N);
5464
5465 // fold (mulhu c1, c2)
5466 if (SDValue C = DAG.FoldConstantArithmetic(ISD::MULHU, DL, VT, {N0, N1}))
5467 return C;
5468
5469 // canonicalize constant to RHS.
5470 if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
5471 !DAG.isConstantIntBuildVectorOrConstantInt(N1))
5472 return DAG.getNode(ISD::MULHU, DL, N->getVTList(), N1, N0);
5473
5474 if (VT.isVector()) {
5475 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5476 return FoldedVOp;
5477
5478 // fold (mulhu x, 0) -> 0
5479 // do not return N1, because undef node may exist.
5480 if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
5481 return DAG.getConstant(0, DL, VT);
5482 }
5483
5484 // fold (mulhu x, 0) -> 0
5485 if (isNullConstant(N1))
5486 return N1;
5487
5488 // fold (mulhu x, 1) -> 0
5489 if (isOneConstant(N1))
5490 return DAG.getConstant(0, DL, VT);
5491
5492 // fold (mulhu x, undef) -> 0
5493 if (N0.isUndef() || N1.isUndef())
5494 return DAG.getConstant(0, DL, VT);
5495
5496 // fold (mulhu x, (1 << c)) -> x >> (bitwidth - c)
5497 if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
5498 hasOperation(ISD::SRL, VT)) {
5499 if (SDValue LogBase2 = BuildLogBase2(N1, DL)) {
5500 unsigned NumEltBits = VT.getScalarSizeInBits();
5501 SDValue SRLAmt = DAG.getNode(
5502 ISD::SUB, DL, VT, DAG.getConstant(NumEltBits, DL, VT), LogBase2);
5503 EVT ShiftVT = getShiftAmountTy(N0.getValueType());
5504 SDValue Trunc = DAG.getZExtOrTrunc(SRLAmt, DL, ShiftVT);
5505 return DAG.getNode(ISD::SRL, DL, VT, N0, Trunc);
5506 }
5507 }
5508
5509 // If the type twice as wide is legal, transform the mulhu to a wider multiply
5510 // plus a shift.
5511 if (!TLI.isOperationLegalOrCustom(ISD::MULHU, VT) && VT.isSimple() &&
5512 !VT.isVector()) {
5513 MVT Simple = VT.getSimpleVT();
5514 unsigned SimpleSize = Simple.getSizeInBits();
5515 EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
5516 if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
5517 N0 = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N0);
5518 N1 = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N1);
5519 N1 = DAG.getNode(ISD::MUL, DL, NewVT, N0, N1);
5520 N1 = DAG.getNode(ISD::SRL, DL, NewVT, N1,
5521 DAG.getShiftAmountConstant(SimpleSize, NewVT, DL));
5522 return DAG.getNode(ISD::TRUNCATE, DL, VT, N1);
5523 }
5524 }
5525
5526 // Simplify the operands using demanded-bits information.
5527 // We don't have demanded bits support for MULHU so this just enables constant
5528 // folding based on known bits.
5529 if (SimplifyDemandedBits(SDValue(N, 0)))
5530 return SDValue(N, 0);
5531
5532 return SDValue();
5533 }
5534
/// Combine an averaging node. The body distinguishes the signed opcodes
/// (ISD::AVGCEILS, ISD::AVGFLOORS) from the remaining ones and has dedicated
/// folds for ISD::AVGFLOORS and ISD::AVGFLOORU.
///
/// \returns the replacement value, or an empty SDValue if nothing applied.
SDValue DAGCombiner::visitAVG(SDNode *N) {
  unsigned Opcode = N->getOpcode();
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);
  bool IsSigned = Opcode == ISD::AVGCEILS || Opcode == ISD::AVGFLOORS;

  // fold (avg c1, c2)
  if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
    return C;

  // canonicalize constant to RHS.
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
    return DAG.getNode(Opcode, DL, N->getVTList(), N1, N0);

  if (VT.isVector())
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

  // fold (avg x, undef) -> x
  if (N0.isUndef())
    return N1;
  if (N1.isUndef())
    return N0;

  // fold (avg x, x) --> x
  // NOTE(review): this fold is restricted to Level >= AfterLegalizeTypes; the
  // reason for the restriction is not visible in this function.
  if (N0 == N1 && Level >= AfterLegalizeTypes)
    return N0;

  // fold (avgfloor x, 0) -> x >> 1
  // The signed form uses an arithmetic shift, the unsigned form a logical
  // one. m_c_BinOp matches the zero on either side.
  SDValue X, Y;
  if (sd_match(N, m_c_BinOp(ISD::AVGFLOORS, m_Value(X), m_Zero())))
    return DAG.getNode(ISD::SRA, DL, VT, X,
                       DAG.getShiftAmountConstant(1, VT, DL));
  if (sd_match(N, m_c_BinOp(ISD::AVGFLOORU, m_Value(X), m_Zero())))
    return DAG.getNode(ISD::SRL, DL, VT, X,
                       DAG.getShiftAmountConstant(1, VT, DL));

  // fold avgu(zext(x), zext(y)) -> zext(avgu(x, y))
  // fold avgs(sext(x), sext(y)) -> sext(avgs(x, y))
  // Both sources must share a type on which the same opcode is available.
  if (!IsSigned &&
      sd_match(N, m_BinOp(Opcode, m_ZExt(m_Value(X)), m_ZExt(m_Value(Y)))) &&
      X.getValueType() == Y.getValueType() &&
      hasOperation(Opcode, X.getValueType())) {
    SDValue AvgU = DAG.getNode(Opcode, DL, X.getValueType(), X, Y);
    return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, AvgU);
  }
  if (IsSigned &&
      sd_match(N, m_BinOp(Opcode, m_SExt(m_Value(X)), m_SExt(m_Value(Y)))) &&
      X.getValueType() == Y.getValueType() &&
      hasOperation(Opcode, X.getValueType())) {
    SDValue AvgS = DAG.getNode(Opcode, DL, X.getValueType(), X, Y);
    return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, AvgS);
  }

  // Fold avgflooru(x,y) -> avgceilu(x,y-1) iff y != 0
  // Fold avgflooru(x,y) -> avgceilu(x-1,y) iff x != 0
  // Check if avgflooru isn't legal/custom but avgceilu is.
  // The decrement is emitted as an ADD with an all-ones constant.
  if (Opcode == ISD::AVGFLOORU && !hasOperation(ISD::AVGFLOORU, VT) &&
      (!LegalOperations || hasOperation(ISD::AVGCEILU, VT))) {
    if (DAG.isKnownNeverZero(N1))
      return DAG.getNode(
          ISD::AVGCEILU, DL, VT, N0,
          DAG.getNode(ISD::ADD, DL, VT, N1, DAG.getAllOnesConstant(DL, VT)));
    if (DAG.isKnownNeverZero(N0))
      return DAG.getNode(
          ISD::AVGCEILU, DL, VT, N1,
          DAG.getNode(ISD::ADD, DL, VT, N0, DAG.getAllOnesConstant(DL, VT)));
  }

  // Fold avgfloor((add nw x,y), 1) -> avgceil(x,y)
  // Fold avgfloor((add nw x,1), y) -> avgceil(x,y)
  if ((Opcode == ISD::AVGFLOORU && hasOperation(ISD::AVGCEILU, VT)) ||
      (Opcode == ISD::AVGFLOORS && hasOperation(ISD::AVGCEILS, VT))) {
    // Add captures the matched ADD node so its wrap flags can be inspected.
    SDValue Add;
    if (sd_match(N,
                 m_c_BinOp(Opcode,
                           m_AllOf(m_Value(Add), m_Add(m_Value(X), m_Value(Y))),
                           m_One())) ||
        sd_match(N, m_c_BinOp(Opcode,
                              m_AllOf(m_Value(Add), m_Add(m_Value(X), m_One())),
                              m_Value(Y)))) {

      // The signed fold requires the no-signed-wrap flag on the ADD, the
      // unsigned fold the no-unsigned-wrap flag.
      if (IsSigned && Add->getFlags().hasNoSignedWrap())
        return DAG.getNode(ISD::AVGCEILS, DL, VT, X, Y);

      if (!IsSigned && Add->getFlags().hasNoUnsignedWrap())
        return DAG.getNode(ISD::AVGCEILU, DL, VT, X, Y);
    }
  }

  // Fold avgfloors(x,y) -> avgflooru(x,y) if both x and y are non-negative
  if (Opcode == ISD::AVGFLOORS && hasOperation(ISD::AVGFLOORU, VT)) {
    if (DAG.SignBitIsZero(N0) && DAG.SignBitIsZero(N1))
      return DAG.getNode(ISD::AVGFLOORU, DL, VT, N0, N1);
  }

  return SDValue();
}
5636
visitABD(SDNode * N)5637 SDValue DAGCombiner::visitABD(SDNode *N) {
5638 unsigned Opcode = N->getOpcode();
5639 SDValue N0 = N->getOperand(0);
5640 SDValue N1 = N->getOperand(1);
5641 EVT VT = N->getValueType(0);
5642 SDLoc DL(N);
5643
5644 // fold (abd c1, c2)
5645 if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
5646 return C;
5647
5648 // canonicalize constant to RHS.
5649 if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
5650 !DAG.isConstantIntBuildVectorOrConstantInt(N1))
5651 return DAG.getNode(Opcode, DL, N->getVTList(), N1, N0);
5652
5653 if (VT.isVector())
5654 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5655 return FoldedVOp;
5656
5657 // fold (abd x, undef) -> 0
5658 if (N0.isUndef() || N1.isUndef())
5659 return DAG.getConstant(0, DL, VT);
5660
5661 // fold (abd x, x) -> 0
5662 if (N0 == N1)
5663 return DAG.getConstant(0, DL, VT);
5664
5665 SDValue X;
5666
5667 // fold (abds x, 0) -> abs x
5668 if (sd_match(N, m_c_BinOp(ISD::ABDS, m_Value(X), m_Zero())) &&
5669 (!LegalOperations || hasOperation(ISD::ABS, VT)))
5670 return DAG.getNode(ISD::ABS, DL, VT, X);
5671
5672 // fold (abdu x, 0) -> x
5673 if (sd_match(N, m_c_BinOp(ISD::ABDU, m_Value(X), m_Zero())))
5674 return X;
5675
5676 // fold (abds x, y) -> (abdu x, y) iff both args are known positive
5677 if (Opcode == ISD::ABDS && hasOperation(ISD::ABDU, VT) &&
5678 DAG.SignBitIsZero(N0) && DAG.SignBitIsZero(N1))
5679 return DAG.getNode(ISD::ABDU, DL, VT, N1, N0);
5680
5681 return SDValue();
5682 }
5683
/// Perform optimizations common to nodes that compute two values. LoOp and HiOp
/// give the opcodes for the two computations that are being performed.
///
/// \param N    Two-result node; result 0 is the low half, result 1 the high
///             half.
/// \param LoOp Single-result opcode that computes only result 0.
/// \param HiOp Single-result opcode that computes only result 1.
/// \returns the value produced by CombineTo() if a simplification was made,
///          otherwise an empty SDValue.
SDValue DAGCombiner::SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp,
                                                unsigned HiOp) {
  // If the high half is not needed, just compute the low half.
  bool HiExists = N->hasAnyUseOfValue(1);
  if (!HiExists && (!LegalOperations ||
                    TLI.isOperationLegalOrCustom(LoOp, N->getValueType(0)))) {
    SDValue Res = DAG.getNode(LoOp, SDLoc(N), N->getValueType(0), N->ops());
    // Both results are replaced with Res; result 1 has no uses here.
    return CombineTo(N, Res, Res);
  }

  // If the low half is not needed, just compute the high half.
  bool LoExists = N->hasAnyUseOfValue(0);
  if (!LoExists && (!LegalOperations ||
                    TLI.isOperationLegalOrCustom(HiOp, N->getValueType(1)))) {
    SDValue Res = DAG.getNode(HiOp, SDLoc(N), N->getValueType(1), N->ops());
    // Both results are replaced with Res; result 0 has no uses here.
    return CombineTo(N, Res, Res);
  }

  // If both halves are used, return as it is.
  if (LoExists && HiExists)
    return SDValue();

  // If the two computed results can be simplified separately, separate them.
  // Reaching this point means at most one half is used and the corresponding
  // single-result opcode failed the legality check above. Build it
  // speculatively and keep the result only if combine() turns it into a
  // different node that is acceptable at this stage.
  if (LoExists) {
    SDValue Lo = DAG.getNode(LoOp, SDLoc(N), N->getValueType(0), N->ops());
    AddToWorklist(Lo.getNode());
    SDValue LoOpt = combine(Lo.getNode());
    if (LoOpt.getNode() && LoOpt.getNode() != Lo.getNode() &&
        (!LegalOperations ||
         TLI.isOperationLegalOrCustom(LoOpt.getOpcode(), LoOpt.getValueType())))
      return CombineTo(N, LoOpt, LoOpt);
  }

  if (HiExists) {
    SDValue Hi = DAG.getNode(HiOp, SDLoc(N), N->getValueType(1), N->ops());
    AddToWorklist(Hi.getNode());
    SDValue HiOpt = combine(Hi.getNode());
    if (HiOpt.getNode() && HiOpt != Hi &&
        (!LegalOperations ||
         TLI.isOperationLegalOrCustom(HiOpt.getOpcode(), HiOpt.getValueType())))
      return CombineTo(N, HiOpt, HiOpt);
  }

  return SDValue();
}
5732
visitSMUL_LOHI(SDNode * N)5733 SDValue DAGCombiner::visitSMUL_LOHI(SDNode *N) {
5734 if (SDValue Res = SimplifyNodeWithTwoResults(N, ISD::MUL, ISD::MULHS))
5735 return Res;
5736
5737 SDValue N0 = N->getOperand(0);
5738 SDValue N1 = N->getOperand(1);
5739 EVT VT = N->getValueType(0);
5740 SDLoc DL(N);
5741
5742 // Constant fold.
5743 if (isa<ConstantSDNode>(N0) && isa<ConstantSDNode>(N1))
5744 return DAG.getNode(ISD::SMUL_LOHI, DL, N->getVTList(), N0, N1);
5745
5746 // canonicalize constant to RHS (vector doesn't have to splat)
5747 if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
5748 !DAG.isConstantIntBuildVectorOrConstantInt(N1))
5749 return DAG.getNode(ISD::SMUL_LOHI, DL, N->getVTList(), N1, N0);
5750
5751 // If the type is twice as wide is legal, transform the mulhu to a wider
5752 // multiply plus a shift.
5753 if (VT.isSimple() && !VT.isVector()) {
5754 MVT Simple = VT.getSimpleVT();
5755 unsigned SimpleSize = Simple.getSizeInBits();
5756 EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
5757 if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
5758 SDValue Lo = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N0);
5759 SDValue Hi = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N1);
5760 Lo = DAG.getNode(ISD::MUL, DL, NewVT, Lo, Hi);
5761 // Compute the high part as N1.
5762 Hi = DAG.getNode(ISD::SRL, DL, NewVT, Lo,
5763 DAG.getShiftAmountConstant(SimpleSize, NewVT, DL));
5764 Hi = DAG.getNode(ISD::TRUNCATE, DL, VT, Hi);
5765 // Compute the low part as N0.
5766 Lo = DAG.getNode(ISD::TRUNCATE, DL, VT, Lo);
5767 return CombineTo(N, Lo, Hi);
5768 }
5769 }
5770
5771 return SDValue();
5772 }
5773
visitUMUL_LOHI(SDNode * N)5774 SDValue DAGCombiner::visitUMUL_LOHI(SDNode *N) {
5775 if (SDValue Res = SimplifyNodeWithTwoResults(N, ISD::MUL, ISD::MULHU))
5776 return Res;
5777
5778 SDValue N0 = N->getOperand(0);
5779 SDValue N1 = N->getOperand(1);
5780 EVT VT = N->getValueType(0);
5781 SDLoc DL(N);
5782
5783 // Constant fold.
5784 if (isa<ConstantSDNode>(N0) && isa<ConstantSDNode>(N1))
5785 return DAG.getNode(ISD::UMUL_LOHI, DL, N->getVTList(), N0, N1);
5786
5787 // canonicalize constant to RHS (vector doesn't have to splat)
5788 if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
5789 !DAG.isConstantIntBuildVectorOrConstantInt(N1))
5790 return DAG.getNode(ISD::UMUL_LOHI, DL, N->getVTList(), N1, N0);
5791
5792 // (umul_lohi N0, 0) -> (0, 0)
5793 if (isNullConstant(N1)) {
5794 SDValue Zero = DAG.getConstant(0, DL, VT);
5795 return CombineTo(N, Zero, Zero);
5796 }
5797
5798 // (umul_lohi N0, 1) -> (N0, 0)
5799 if (isOneConstant(N1)) {
5800 SDValue Zero = DAG.getConstant(0, DL, VT);
5801 return CombineTo(N, N0, Zero);
5802 }
5803
5804 // If the type is twice as wide is legal, transform the mulhu to a wider
5805 // multiply plus a shift.
5806 if (VT.isSimple() && !VT.isVector()) {
5807 MVT Simple = VT.getSimpleVT();
5808 unsigned SimpleSize = Simple.getSizeInBits();
5809 EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
5810 if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
5811 SDValue Lo = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N0);
5812 SDValue Hi = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N1);
5813 Lo = DAG.getNode(ISD::MUL, DL, NewVT, Lo, Hi);
5814 // Compute the high part as N1.
5815 Hi = DAG.getNode(ISD::SRL, DL, NewVT, Lo,
5816 DAG.getShiftAmountConstant(SimpleSize, NewVT, DL));
5817 Hi = DAG.getNode(ISD::TRUNCATE, DL, VT, Hi);
5818 // Compute the low part as N0.
5819 Lo = DAG.getNode(ISD::TRUNCATE, DL, VT, Lo);
5820 return CombineTo(N, Lo, Hi);
5821 }
5822 }
5823
5824 return SDValue();
5825 }
5826
visitMULO(SDNode * N)5827 SDValue DAGCombiner::visitMULO(SDNode *N) {
5828 SDValue N0 = N->getOperand(0);
5829 SDValue N1 = N->getOperand(1);
5830 EVT VT = N0.getValueType();
5831 bool IsSigned = (ISD::SMULO == N->getOpcode());
5832
5833 EVT CarryVT = N->getValueType(1);
5834 SDLoc DL(N);
5835
5836 ConstantSDNode *N0C = isConstOrConstSplat(N0);
5837 ConstantSDNode *N1C = isConstOrConstSplat(N1);
5838
5839 // fold operation with constant operands.
5840 // TODO: Move this to FoldConstantArithmetic when it supports nodes with
5841 // multiple results.
5842 if (N0C && N1C) {
5843 bool Overflow;
5844 APInt Result =
5845 IsSigned ? N0C->getAPIntValue().smul_ov(N1C->getAPIntValue(), Overflow)
5846 : N0C->getAPIntValue().umul_ov(N1C->getAPIntValue(), Overflow);
5847 return CombineTo(N, DAG.getConstant(Result, DL, VT),
5848 DAG.getBoolConstant(Overflow, DL, CarryVT, CarryVT));
5849 }
5850
5851 // canonicalize constant to RHS.
5852 if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
5853 !DAG.isConstantIntBuildVectorOrConstantInt(N1))
5854 return DAG.getNode(N->getOpcode(), DL, N->getVTList(), N1, N0);
5855
5856 // fold (mulo x, 0) -> 0 + no carry out
5857 if (isNullOrNullSplat(N1))
5858 return CombineTo(N, DAG.getConstant(0, DL, VT),
5859 DAG.getConstant(0, DL, CarryVT));
5860
5861 // (mulo x, 2) -> (addo x, x)
5862 // FIXME: This needs a freeze.
5863 if (N1C && N1C->getAPIntValue() == 2 &&
5864 (!IsSigned || VT.getScalarSizeInBits() > 2))
5865 return DAG.getNode(IsSigned ? ISD::SADDO : ISD::UADDO, DL,
5866 N->getVTList(), N0, N0);
5867
5868 // A 1 bit SMULO overflows if both inputs are 1.
5869 if (IsSigned && VT.getScalarSizeInBits() == 1) {
5870 SDValue And = DAG.getNode(ISD::AND, DL, VT, N0, N1);
5871 SDValue Cmp = DAG.getSetCC(DL, CarryVT, And,
5872 DAG.getConstant(0, DL, VT), ISD::SETNE);
5873 return CombineTo(N, And, Cmp);
5874 }
5875
5876 // If it cannot overflow, transform into a mul.
5877 if (DAG.willNotOverflowMul(IsSigned, N0, N1))
5878 return CombineTo(N, DAG.getNode(ISD::MUL, DL, VT, N0, N1),
5879 DAG.getConstant(0, DL, CarryVT));
5880 return SDValue();
5881 }
5882
5883 // Function to calculate whether the Min/Max pair of SDNodes (potentially
5884 // swapped around) make a signed saturate pattern, clamping to between a signed
5885 // saturate of -2^(BW-1) and 2^(BW-1)-1, or an unsigned saturate of 0 and 2^BW.
5886 // Returns the node being clamped and the bitwidth of the clamp in BW. Should
5887 // work with both SMIN/SMAX nodes and setcc/select combo. The operands are the
5888 // same as SimplifySelectCC. N0<N1 ? N2 : N3.
isSaturatingMinMax(SDValue N0,SDValue N1,SDValue N2,SDValue N3,ISD::CondCode CC,unsigned & BW,bool & Unsigned,SelectionDAG & DAG)5889 static SDValue isSaturatingMinMax(SDValue N0, SDValue N1, SDValue N2,
5890 SDValue N3, ISD::CondCode CC, unsigned &BW,
5891 bool &Unsigned, SelectionDAG &DAG) {
5892 auto isSignedMinMax = [&](SDValue N0, SDValue N1, SDValue N2, SDValue N3,
5893 ISD::CondCode CC) {
5894 // The compare and select operand should be the same or the select operands
5895 // should be truncated versions of the comparison.
5896 if (N0 != N2 && (N2.getOpcode() != ISD::TRUNCATE || N0 != N2.getOperand(0)))
5897 return 0;
5898 // The constants need to be the same or a truncated version of each other.
5899 ConstantSDNode *N1C = isConstOrConstSplat(peekThroughTruncates(N1));
5900 ConstantSDNode *N3C = isConstOrConstSplat(peekThroughTruncates(N3));
5901 if (!N1C || !N3C)
5902 return 0;
5903 const APInt &C1 = N1C->getAPIntValue().trunc(N1.getScalarValueSizeInBits());
5904 const APInt &C2 = N3C->getAPIntValue().trunc(N3.getScalarValueSizeInBits());
5905 if (C1.getBitWidth() < C2.getBitWidth() || C1 != C2.sext(C1.getBitWidth()))
5906 return 0;
5907 return CC == ISD::SETLT ? ISD::SMIN : (CC == ISD::SETGT ? ISD::SMAX : 0);
5908 };
5909
5910 // Check the initial value is a SMIN/SMAX equivalent.
5911 unsigned Opcode0 = isSignedMinMax(N0, N1, N2, N3, CC);
5912 if (!Opcode0)
5913 return SDValue();
5914
5915 // We could only need one range check, if the fptosi could never produce
5916 // the upper value.
5917 if (N0.getOpcode() == ISD::FP_TO_SINT && Opcode0 == ISD::SMAX) {
5918 if (isNullOrNullSplat(N3)) {
5919 EVT IntVT = N0.getValueType().getScalarType();
5920 EVT FPVT = N0.getOperand(0).getValueType().getScalarType();
5921 if (FPVT.isSimple()) {
5922 Type *InputTy = FPVT.getTypeForEVT(*DAG.getContext());
5923 const fltSemantics &Semantics = InputTy->getFltSemantics();
5924 uint32_t MinBitWidth =
5925 APFloatBase::semanticsIntSizeInBits(Semantics, /*isSigned*/ true);
5926 if (IntVT.getSizeInBits() >= MinBitWidth) {
5927 Unsigned = true;
5928 BW = PowerOf2Ceil(MinBitWidth);
5929 return N0;
5930 }
5931 }
5932 }
5933 }
5934
5935 SDValue N00, N01, N02, N03;
5936 ISD::CondCode N0CC;
5937 switch (N0.getOpcode()) {
5938 case ISD::SMIN:
5939 case ISD::SMAX:
5940 N00 = N02 = N0.getOperand(0);
5941 N01 = N03 = N0.getOperand(1);
5942 N0CC = N0.getOpcode() == ISD::SMIN ? ISD::SETLT : ISD::SETGT;
5943 break;
5944 case ISD::SELECT_CC:
5945 N00 = N0.getOperand(0);
5946 N01 = N0.getOperand(1);
5947 N02 = N0.getOperand(2);
5948 N03 = N0.getOperand(3);
5949 N0CC = cast<CondCodeSDNode>(N0.getOperand(4))->get();
5950 break;
5951 case ISD::SELECT:
5952 case ISD::VSELECT:
5953 if (N0.getOperand(0).getOpcode() != ISD::SETCC)
5954 return SDValue();
5955 N00 = N0.getOperand(0).getOperand(0);
5956 N01 = N0.getOperand(0).getOperand(1);
5957 N02 = N0.getOperand(1);
5958 N03 = N0.getOperand(2);
5959 N0CC = cast<CondCodeSDNode>(N0.getOperand(0).getOperand(2))->get();
5960 break;
5961 default:
5962 return SDValue();
5963 }
5964
5965 unsigned Opcode1 = isSignedMinMax(N00, N01, N02, N03, N0CC);
5966 if (!Opcode1 || Opcode0 == Opcode1)
5967 return SDValue();
5968
5969 ConstantSDNode *MinCOp = isConstOrConstSplat(Opcode0 == ISD::SMIN ? N1 : N01);
5970 ConstantSDNode *MaxCOp = isConstOrConstSplat(Opcode0 == ISD::SMIN ? N01 : N1);
5971 if (!MinCOp || !MaxCOp || MinCOp->getValueType(0) != MaxCOp->getValueType(0))
5972 return SDValue();
5973
5974 const APInt &MinC = MinCOp->getAPIntValue();
5975 const APInt &MaxC = MaxCOp->getAPIntValue();
5976 APInt MinCPlus1 = MinC + 1;
5977 if (-MaxC == MinCPlus1 && MinCPlus1.isPowerOf2()) {
5978 BW = MinCPlus1.exactLogBase2() + 1;
5979 Unsigned = false;
5980 return N02;
5981 }
5982
5983 if (MaxC == 0 && MinCPlus1.isPowerOf2()) {
5984 BW = MinCPlus1.exactLogBase2();
5985 Unsigned = true;
5986 return N02;
5987 }
5988
5989 return SDValue();
5990 }
5991
PerformMinMaxFpToSatCombine(SDValue N0,SDValue N1,SDValue N2,SDValue N3,ISD::CondCode CC,SelectionDAG & DAG)5992 static SDValue PerformMinMaxFpToSatCombine(SDValue N0, SDValue N1, SDValue N2,
5993 SDValue N3, ISD::CondCode CC,
5994 SelectionDAG &DAG) {
5995 unsigned BW;
5996 bool Unsigned;
5997 SDValue Fp = isSaturatingMinMax(N0, N1, N2, N3, CC, BW, Unsigned, DAG);
5998 if (!Fp || Fp.getOpcode() != ISD::FP_TO_SINT)
5999 return SDValue();
6000 EVT FPVT = Fp.getOperand(0).getValueType();
6001 EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), BW);
6002 if (FPVT.isVector())
6003 NewVT = EVT::getVectorVT(*DAG.getContext(), NewVT,
6004 FPVT.getVectorElementCount());
6005 unsigned NewOpc = Unsigned ? ISD::FP_TO_UINT_SAT : ISD::FP_TO_SINT_SAT;
6006 if (!DAG.getTargetLoweringInfo().shouldConvertFpToSat(NewOpc, FPVT, NewVT))
6007 return SDValue();
6008 SDLoc DL(Fp);
6009 SDValue Sat = DAG.getNode(NewOpc, DL, NewVT, Fp.getOperand(0),
6010 DAG.getValueType(NewVT.getScalarType()));
6011 return DAG.getExtOrTrunc(!Unsigned, Sat, DL, N2->getValueType(0));
6012 }
6013
PerformUMinFpToSatCombine(SDValue N0,SDValue N1,SDValue N2,SDValue N3,ISD::CondCode CC,SelectionDAG & DAG)6014 static SDValue PerformUMinFpToSatCombine(SDValue N0, SDValue N1, SDValue N2,
6015 SDValue N3, ISD::CondCode CC,
6016 SelectionDAG &DAG) {
6017 // We are looking for UMIN(FPTOUI(X), (2^n)-1), which may have come via a
6018 // select/vselect/select_cc. The two operands pairs for the select (N2/N3) may
6019 // be truncated versions of the setcc (N0/N1).
6020 if ((N0 != N2 &&
6021 (N2.getOpcode() != ISD::TRUNCATE || N0 != N2.getOperand(0))) ||
6022 N0.getOpcode() != ISD::FP_TO_UINT || CC != ISD::SETULT)
6023 return SDValue();
6024 ConstantSDNode *N1C = isConstOrConstSplat(N1);
6025 ConstantSDNode *N3C = isConstOrConstSplat(N3);
6026 if (!N1C || !N3C)
6027 return SDValue();
6028 const APInt &C1 = N1C->getAPIntValue();
6029 const APInt &C3 = N3C->getAPIntValue();
6030 if (!(C1 + 1).isPowerOf2() || C1.getBitWidth() < C3.getBitWidth() ||
6031 C1 != C3.zext(C1.getBitWidth()))
6032 return SDValue();
6033
6034 unsigned BW = (C1 + 1).exactLogBase2();
6035 EVT FPVT = N0.getOperand(0).getValueType();
6036 EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), BW);
6037 if (FPVT.isVector())
6038 NewVT = EVT::getVectorVT(*DAG.getContext(), NewVT,
6039 FPVT.getVectorElementCount());
6040 if (!DAG.getTargetLoweringInfo().shouldConvertFpToSat(ISD::FP_TO_UINT_SAT,
6041 FPVT, NewVT))
6042 return SDValue();
6043
6044 SDValue Sat =
6045 DAG.getNode(ISD::FP_TO_UINT_SAT, SDLoc(N0), NewVT, N0.getOperand(0),
6046 DAG.getValueType(NewVT.getScalarType()));
6047 return DAG.getZExtOrTrunc(Sat, SDLoc(N0), N3.getValueType());
6048 }
6049
visitIMINMAX(SDNode * N)6050 SDValue DAGCombiner::visitIMINMAX(SDNode *N) {
6051 SDValue N0 = N->getOperand(0);
6052 SDValue N1 = N->getOperand(1);
6053 EVT VT = N0.getValueType();
6054 unsigned Opcode = N->getOpcode();
6055 SDLoc DL(N);
6056
6057 // fold operation with constant operands.
6058 if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
6059 return C;
6060
6061 // If the operands are the same, this is a no-op.
6062 if (N0 == N1)
6063 return N0;
6064
6065 // canonicalize constant to RHS
6066 if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
6067 !DAG.isConstantIntBuildVectorOrConstantInt(N1))
6068 return DAG.getNode(Opcode, DL, VT, N1, N0);
6069
6070 // fold vector ops
6071 if (VT.isVector())
6072 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
6073 return FoldedVOp;
6074
6075 // reassociate minmax
6076 if (SDValue RMINMAX = reassociateOps(Opcode, DL, N0, N1, N->getFlags()))
6077 return RMINMAX;
6078
6079 // Is sign bits are zero, flip between UMIN/UMAX and SMIN/SMAX.
6080 // Only do this if:
6081 // 1. The current op isn't legal and the flipped is.
6082 // 2. The saturation pattern is broken by canonicalization in InstCombine.
6083 bool IsOpIllegal = !TLI.isOperationLegal(Opcode, VT);
6084 bool IsSatBroken = Opcode == ISD::UMIN && N0.getOpcode() == ISD::SMAX;
6085 if ((IsSatBroken || IsOpIllegal) && (N0.isUndef() || DAG.SignBitIsZero(N0)) &&
6086 (N1.isUndef() || DAG.SignBitIsZero(N1))) {
6087 unsigned AltOpcode;
6088 switch (Opcode) {
6089 case ISD::SMIN: AltOpcode = ISD::UMIN; break;
6090 case ISD::SMAX: AltOpcode = ISD::UMAX; break;
6091 case ISD::UMIN: AltOpcode = ISD::SMIN; break;
6092 case ISD::UMAX: AltOpcode = ISD::SMAX; break;
6093 default: llvm_unreachable("Unknown MINMAX opcode");
6094 }
6095 if ((IsSatBroken && IsOpIllegal) || TLI.isOperationLegal(AltOpcode, VT))
6096 return DAG.getNode(AltOpcode, DL, VT, N0, N1);
6097 }
6098
6099 if (Opcode == ISD::SMIN || Opcode == ISD::SMAX)
6100 if (SDValue S = PerformMinMaxFpToSatCombine(
6101 N0, N1, N0, N1, Opcode == ISD::SMIN ? ISD::SETLT : ISD::SETGT, DAG))
6102 return S;
6103 if (Opcode == ISD::UMIN)
6104 if (SDValue S = PerformUMinFpToSatCombine(N0, N1, N0, N1, ISD::SETULT, DAG))
6105 return S;
6106
6107 // Fold min/max(vecreduce(x), vecreduce(y)) -> vecreduce(min/max(x, y))
6108 auto ReductionOpcode = [](unsigned Opcode) {
6109 switch (Opcode) {
6110 case ISD::SMIN:
6111 return ISD::VECREDUCE_SMIN;
6112 case ISD::SMAX:
6113 return ISD::VECREDUCE_SMAX;
6114 case ISD::UMIN:
6115 return ISD::VECREDUCE_UMIN;
6116 case ISD::UMAX:
6117 return ISD::VECREDUCE_UMAX;
6118 default:
6119 llvm_unreachable("Unexpected opcode");
6120 }
6121 };
6122 if (SDValue SD = reassociateReduction(ReductionOpcode(Opcode), Opcode,
6123 SDLoc(N), VT, N0, N1))
6124 return SD;
6125
6126 // Simplify the operands using demanded-bits information.
6127 if (SimplifyDemandedBits(SDValue(N, 0)))
6128 return SDValue(N, 0);
6129
6130 return SDValue();
6131 }
6132
6133 /// If this is a bitwise logic instruction and both operands have the same
6134 /// opcode, try to sink the other opcode after the logic instruction.
hoistLogicOpWithSameOpcodeHands(SDNode * N)6135 SDValue DAGCombiner::hoistLogicOpWithSameOpcodeHands(SDNode *N) {
6136 SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
6137 EVT VT = N0.getValueType();
6138 unsigned LogicOpcode = N->getOpcode();
6139 unsigned HandOpcode = N0.getOpcode();
6140 assert(ISD::isBitwiseLogicOp(LogicOpcode) && "Expected logic opcode");
6141 assert(HandOpcode == N1.getOpcode() && "Bad input!");
6142
6143 // Bail early if none of these transforms apply.
6144 if (N0.getNumOperands() == 0)
6145 return SDValue();
6146
6147 // FIXME: We should check number of uses of the operands to not increase
6148 // the instruction count for all transforms.
6149
6150 // Handle size-changing casts (or sign_extend_inreg).
6151 SDValue X = N0.getOperand(0);
6152 SDValue Y = N1.getOperand(0);
6153 EVT XVT = X.getValueType();
6154 SDLoc DL(N);
6155 if (ISD::isExtOpcode(HandOpcode) || ISD::isExtVecInRegOpcode(HandOpcode) ||
6156 (HandOpcode == ISD::SIGN_EXTEND_INREG &&
6157 N0.getOperand(1) == N1.getOperand(1))) {
6158 // If both operands have other uses, this transform would create extra
6159 // instructions without eliminating anything.
6160 if (!N0.hasOneUse() && !N1.hasOneUse())
6161 return SDValue();
6162 // We need matching integer source types.
6163 if (XVT != Y.getValueType())
6164 return SDValue();
6165 // Don't create an illegal op during or after legalization. Don't ever
6166 // create an unsupported vector op.
6167 if ((VT.isVector() || LegalOperations) &&
6168 !TLI.isOperationLegalOrCustom(LogicOpcode, XVT))
6169 return SDValue();
6170 // Avoid infinite looping with PromoteIntBinOp.
6171 // TODO: Should we apply desirable/legal constraints to all opcodes?
6172 if ((HandOpcode == ISD::ANY_EXTEND ||
6173 HandOpcode == ISD::ANY_EXTEND_VECTOR_INREG) &&
6174 LegalTypes && !TLI.isTypeDesirableForOp(LogicOpcode, XVT))
6175 return SDValue();
6176 // logic_op (hand_op X), (hand_op Y) --> hand_op (logic_op X, Y)
6177 SDNodeFlags LogicFlags;
6178 LogicFlags.setDisjoint(N->getFlags().hasDisjoint() &&
6179 ISD::isExtOpcode(HandOpcode));
6180 SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y, LogicFlags);
6181 if (HandOpcode == ISD::SIGN_EXTEND_INREG)
6182 return DAG.getNode(HandOpcode, DL, VT, Logic, N0.getOperand(1));
6183 return DAG.getNode(HandOpcode, DL, VT, Logic);
6184 }
6185
6186 // logic_op (truncate x), (truncate y) --> truncate (logic_op x, y)
6187 if (HandOpcode == ISD::TRUNCATE) {
6188 // If both operands have other uses, this transform would create extra
6189 // instructions without eliminating anything.
6190 if (!N0.hasOneUse() && !N1.hasOneUse())
6191 return SDValue();
6192 // We need matching source types.
6193 if (XVT != Y.getValueType())
6194 return SDValue();
6195 // Don't create an illegal op during or after legalization.
6196 if (LegalOperations && !TLI.isOperationLegal(LogicOpcode, XVT))
6197 return SDValue();
6198 // Be extra careful sinking truncate. If it's free, there's no benefit in
6199 // widening a binop. Also, don't create a logic op on an illegal type.
6200 if (TLI.isZExtFree(VT, XVT) && TLI.isTruncateFree(XVT, VT))
6201 return SDValue();
6202 if (!TLI.isTypeLegal(XVT))
6203 return SDValue();
6204 SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
6205 return DAG.getNode(HandOpcode, DL, VT, Logic);
6206 }
6207
6208 // For binops SHL/SRL/SRA/AND:
6209 // logic_op (OP x, z), (OP y, z) --> OP (logic_op x, y), z
6210 if ((HandOpcode == ISD::SHL || HandOpcode == ISD::SRL ||
6211 HandOpcode == ISD::SRA || HandOpcode == ISD::AND) &&
6212 N0.getOperand(1) == N1.getOperand(1)) {
6213 // If either operand has other uses, this transform is not an improvement.
6214 if (!N0.hasOneUse() || !N1.hasOneUse())
6215 return SDValue();
6216 SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
6217 return DAG.getNode(HandOpcode, DL, VT, Logic, N0.getOperand(1));
6218 }
6219
6220 // Unary ops: logic_op (bswap x), (bswap y) --> bswap (logic_op x, y)
6221 if (HandOpcode == ISD::BSWAP) {
6222 // If either operand has other uses, this transform is not an improvement.
6223 if (!N0.hasOneUse() || !N1.hasOneUse())
6224 return SDValue();
6225 SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
6226 return DAG.getNode(HandOpcode, DL, VT, Logic);
6227 }
6228
6229 // For funnel shifts FSHL/FSHR:
6230 // logic_op (OP x, x1, s), (OP y, y1, s) -->
6231 // --> OP (logic_op x, y), (logic_op, x1, y1), s
6232 if ((HandOpcode == ISD::FSHL || HandOpcode == ISD::FSHR) &&
6233 N0.getOperand(2) == N1.getOperand(2)) {
6234 if (!N0.hasOneUse() || !N1.hasOneUse())
6235 return SDValue();
6236 SDValue X1 = N0.getOperand(1);
6237 SDValue Y1 = N1.getOperand(1);
6238 SDValue S = N0.getOperand(2);
6239 SDValue Logic0 = DAG.getNode(LogicOpcode, DL, VT, X, Y);
6240 SDValue Logic1 = DAG.getNode(LogicOpcode, DL, VT, X1, Y1);
6241 return DAG.getNode(HandOpcode, DL, VT, Logic0, Logic1, S);
6242 }
6243
6244 // Simplify xor/and/or (bitcast(A), bitcast(B)) -> bitcast(op (A,B))
6245 // Only perform this optimization up until type legalization, before
6246 // LegalizeVectorOprs. LegalizeVectorOprs promotes vector operations by
6247 // adding bitcasts. For example (xor v4i32) is promoted to (v2i64), and
6248 // we don't want to undo this promotion.
6249 // We also handle SCALAR_TO_VECTOR because xor/or/and operations are cheaper
6250 // on scalars.
6251 if ((HandOpcode == ISD::BITCAST || HandOpcode == ISD::SCALAR_TO_VECTOR) &&
6252 Level <= AfterLegalizeTypes) {
6253 // Input types must be integer and the same.
6254 if (XVT.isInteger() && XVT == Y.getValueType() &&
6255 !(VT.isVector() && TLI.isTypeLegal(VT) &&
6256 !XVT.isVector() && !TLI.isTypeLegal(XVT))) {
6257 SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
6258 return DAG.getNode(HandOpcode, DL, VT, Logic);
6259 }
6260 }
6261
6262 // Xor/and/or are indifferent to the swizzle operation (shuffle of one value).
6263 // Simplify xor/and/or (shuff(A), shuff(B)) -> shuff(op (A,B))
6264 // If both shuffles use the same mask, and both shuffle within a single
6265 // vector, then it is worthwhile to move the swizzle after the operation.
6266 // The type-legalizer generates this pattern when loading illegal
6267 // vector types from memory. In many cases this allows additional shuffle
6268 // optimizations.
6269 // There are other cases where moving the shuffle after the xor/and/or
6270 // is profitable even if shuffles don't perform a swizzle.
6271 // If both shuffles use the same mask, and both shuffles have the same first
6272 // or second operand, then it might still be profitable to move the shuffle
6273 // after the xor/and/or operation.
6274 if (HandOpcode == ISD::VECTOR_SHUFFLE && Level < AfterLegalizeDAG) {
6275 auto *SVN0 = cast<ShuffleVectorSDNode>(N0);
6276 auto *SVN1 = cast<ShuffleVectorSDNode>(N1);
6277 assert(X.getValueType() == Y.getValueType() &&
6278 "Inputs to shuffles are not the same type");
6279
6280 // Check that both shuffles use the same mask. The masks are known to be of
6281 // the same length because the result vector type is the same.
6282 // Check also that shuffles have only one use to avoid introducing extra
6283 // instructions.
6284 if (!SVN0->hasOneUse() || !SVN1->hasOneUse() ||
6285 !SVN0->getMask().equals(SVN1->getMask()))
6286 return SDValue();
6287
6288 // Don't try to fold this node if it requires introducing a
6289 // build vector of all zeros that might be illegal at this stage.
6290 SDValue ShOp = N0.getOperand(1);
6291 if (LogicOpcode == ISD::XOR && !ShOp.isUndef())
6292 ShOp = tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
6293
6294 // (logic_op (shuf (A, C), shuf (B, C))) --> shuf (logic_op (A, B), C)
6295 if (N0.getOperand(1) == N1.getOperand(1) && ShOp.getNode()) {
6296 SDValue Logic = DAG.getNode(LogicOpcode, DL, VT,
6297 N0.getOperand(0), N1.getOperand(0));
6298 return DAG.getVectorShuffle(VT, DL, Logic, ShOp, SVN0->getMask());
6299 }
6300
6301 // Don't try to fold this node if it requires introducing a
6302 // build vector of all zeros that might be illegal at this stage.
6303 ShOp = N0.getOperand(0);
6304 if (LogicOpcode == ISD::XOR && !ShOp.isUndef())
6305 ShOp = tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
6306
6307 // (logic_op (shuf (C, A), shuf (C, B))) --> shuf (C, logic_op (A, B))
6308 if (N0.getOperand(0) == N1.getOperand(0) && ShOp.getNode()) {
6309 SDValue Logic = DAG.getNode(LogicOpcode, DL, VT, N0.getOperand(1),
6310 N1.getOperand(1));
6311 return DAG.getVectorShuffle(VT, DL, ShOp, Logic, SVN0->getMask());
6312 }
6313 }
6314
6315 return SDValue();
6316 }
6317
6318 /// Try to make (and/or setcc (LL, LR), setcc (RL, RR)) more efficient.
foldLogicOfSetCCs(bool IsAnd,SDValue N0,SDValue N1,const SDLoc & DL)6319 SDValue DAGCombiner::foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1,
6320 const SDLoc &DL) {
6321 SDValue LL, LR, RL, RR, N0CC, N1CC;
6322 if (!isSetCCEquivalent(N0, LL, LR, N0CC) ||
6323 !isSetCCEquivalent(N1, RL, RR, N1CC))
6324 return SDValue();
6325
6326 assert(N0.getValueType() == N1.getValueType() &&
6327 "Unexpected operand types for bitwise logic op");
6328 assert(LL.getValueType() == LR.getValueType() &&
6329 RL.getValueType() == RR.getValueType() &&
6330 "Unexpected operand types for setcc");
6331
6332 // If we're here post-legalization or the logic op type is not i1, the logic
6333 // op type must match a setcc result type. Also, all folds require new
6334 // operations on the left and right operands, so those types must match.
6335 EVT VT = N0.getValueType();
6336 EVT OpVT = LL.getValueType();
6337 if (LegalOperations || VT.getScalarType() != MVT::i1)
6338 if (VT != getSetCCResultType(OpVT))
6339 return SDValue();
6340 if (OpVT != RL.getValueType())
6341 return SDValue();
6342
6343 ISD::CondCode CC0 = cast<CondCodeSDNode>(N0CC)->get();
6344 ISD::CondCode CC1 = cast<CondCodeSDNode>(N1CC)->get();
6345 bool IsInteger = OpVT.isInteger();
6346 if (LR == RR && CC0 == CC1 && IsInteger) {
6347 bool IsZero = isNullOrNullSplat(LR);
6348 bool IsNeg1 = isAllOnesOrAllOnesSplat(LR);
6349
6350 // All bits clear?
6351 bool AndEqZero = IsAnd && CC1 == ISD::SETEQ && IsZero;
6352 // All sign bits clear?
6353 bool AndGtNeg1 = IsAnd && CC1 == ISD::SETGT && IsNeg1;
6354 // Any bits set?
6355 bool OrNeZero = !IsAnd && CC1 == ISD::SETNE && IsZero;
6356 // Any sign bits set?
6357 bool OrLtZero = !IsAnd && CC1 == ISD::SETLT && IsZero;
6358
6359 // (and (seteq X, 0), (seteq Y, 0)) --> (seteq (or X, Y), 0)
6360 // (and (setgt X, -1), (setgt Y, -1)) --> (setgt (or X, Y), -1)
6361 // (or (setne X, 0), (setne Y, 0)) --> (setne (or X, Y), 0)
6362 // (or (setlt X, 0), (setlt Y, 0)) --> (setlt (or X, Y), 0)
6363 if (AndEqZero || AndGtNeg1 || OrNeZero || OrLtZero) {
6364 SDValue Or = DAG.getNode(ISD::OR, SDLoc(N0), OpVT, LL, RL);
6365 AddToWorklist(Or.getNode());
6366 return DAG.getSetCC(DL, VT, Or, LR, CC1);
6367 }
6368
6369 // All bits set?
6370 bool AndEqNeg1 = IsAnd && CC1 == ISD::SETEQ && IsNeg1;
6371 // All sign bits set?
6372 bool AndLtZero = IsAnd && CC1 == ISD::SETLT && IsZero;
6373 // Any bits clear?
6374 bool OrNeNeg1 = !IsAnd && CC1 == ISD::SETNE && IsNeg1;
6375 // Any sign bits clear?
6376 bool OrGtNeg1 = !IsAnd && CC1 == ISD::SETGT && IsNeg1;
6377
6378 // (and (seteq X, -1), (seteq Y, -1)) --> (seteq (and X, Y), -1)
6379 // (and (setlt X, 0), (setlt Y, 0)) --> (setlt (and X, Y), 0)
6380 // (or (setne X, -1), (setne Y, -1)) --> (setne (and X, Y), -1)
6381 // (or (setgt X, -1), (setgt Y -1)) --> (setgt (and X, Y), -1)
6382 if (AndEqNeg1 || AndLtZero || OrNeNeg1 || OrGtNeg1) {
6383 SDValue And = DAG.getNode(ISD::AND, SDLoc(N0), OpVT, LL, RL);
6384 AddToWorklist(And.getNode());
6385 return DAG.getSetCC(DL, VT, And, LR, CC1);
6386 }
6387 }
6388
6389 // TODO: What is the 'or' equivalent of this fold?
6390 // (and (setne X, 0), (setne X, -1)) --> (setuge (add X, 1), 2)
6391 if (IsAnd && LL == RL && CC0 == CC1 && OpVT.getScalarSizeInBits() > 1 &&
6392 IsInteger && CC0 == ISD::SETNE &&
6393 ((isNullConstant(LR) && isAllOnesConstant(RR)) ||
6394 (isAllOnesConstant(LR) && isNullConstant(RR)))) {
6395 SDValue One = DAG.getConstant(1, DL, OpVT);
6396 SDValue Two = DAG.getConstant(2, DL, OpVT);
6397 SDValue Add = DAG.getNode(ISD::ADD, SDLoc(N0), OpVT, LL, One);
6398 AddToWorklist(Add.getNode());
6399 return DAG.getSetCC(DL, VT, Add, Two, ISD::SETUGE);
6400 }
6401
6402 // Try more general transforms if the predicates match and the only user of
6403 // the compares is the 'and' or 'or'.
6404 if (IsInteger && TLI.convertSetCCLogicToBitwiseLogic(OpVT) && CC0 == CC1 &&
6405 N0.hasOneUse() && N1.hasOneUse()) {
6406 // and (seteq A, B), (seteq C, D) --> seteq (or (xor A, B), (xor C, D)), 0
6407 // or (setne A, B), (setne C, D) --> setne (or (xor A, B), (xor C, D)), 0
6408 if ((IsAnd && CC1 == ISD::SETEQ) || (!IsAnd && CC1 == ISD::SETNE)) {
6409 SDValue XorL = DAG.getNode(ISD::XOR, SDLoc(N0), OpVT, LL, LR);
6410 SDValue XorR = DAG.getNode(ISD::XOR, SDLoc(N1), OpVT, RL, RR);
6411 SDValue Or = DAG.getNode(ISD::OR, DL, OpVT, XorL, XorR);
6412 SDValue Zero = DAG.getConstant(0, DL, OpVT);
6413 return DAG.getSetCC(DL, VT, Or, Zero, CC1);
6414 }
6415
6416 // Turn compare of constants whose difference is 1 bit into add+and+setcc.
6417 if ((IsAnd && CC1 == ISD::SETNE) || (!IsAnd && CC1 == ISD::SETEQ)) {
6418 // Match a shared variable operand and 2 non-opaque constant operands.
6419 auto MatchDiffPow2 = [&](ConstantSDNode *C0, ConstantSDNode *C1) {
6420 // The difference of the constants must be a single bit.
6421 const APInt &CMax =
6422 APIntOps::umax(C0->getAPIntValue(), C1->getAPIntValue());
6423 const APInt &CMin =
6424 APIntOps::umin(C0->getAPIntValue(), C1->getAPIntValue());
6425 return !C0->isOpaque() && !C1->isOpaque() && (CMax - CMin).isPowerOf2();
6426 };
6427 if (LL == RL && ISD::matchBinaryPredicate(LR, RR, MatchDiffPow2)) {
6428 // and/or (setcc X, CMax, ne), (setcc X, CMin, ne/eq) -->
6429 // setcc ((sub X, CMin), ~(CMax - CMin)), 0, ne/eq
6430 SDValue Max = DAG.getNode(ISD::UMAX, DL, OpVT, LR, RR);
6431 SDValue Min = DAG.getNode(ISD::UMIN, DL, OpVT, LR, RR);
6432 SDValue Offset = DAG.getNode(ISD::SUB, DL, OpVT, LL, Min);
6433 SDValue Diff = DAG.getNode(ISD::SUB, DL, OpVT, Max, Min);
6434 SDValue Mask = DAG.getNOT(DL, Diff, OpVT);
6435 SDValue And = DAG.getNode(ISD::AND, DL, OpVT, Offset, Mask);
6436 SDValue Zero = DAG.getConstant(0, DL, OpVT);
6437 return DAG.getSetCC(DL, VT, And, Zero, CC0);
6438 }
6439 }
6440 }
6441
6442 // Canonicalize equivalent operands to LL == RL.
6443 if (LL == RR && LR == RL) {
6444 CC1 = ISD::getSetCCSwappedOperands(CC1);
6445 std::swap(RL, RR);
6446 }
6447
6448 // (and (setcc X, Y, CC0), (setcc X, Y, CC1)) --> (setcc X, Y, NewCC)
6449 // (or (setcc X, Y, CC0), (setcc X, Y, CC1)) --> (setcc X, Y, NewCC)
6450 if (LL == RL && LR == RR) {
6451 ISD::CondCode NewCC = IsAnd ? ISD::getSetCCAndOperation(CC0, CC1, OpVT)
6452 : ISD::getSetCCOrOperation(CC0, CC1, OpVT);
6453 if (NewCC != ISD::SETCC_INVALID &&
6454 (!LegalOperations ||
6455 (TLI.isCondCodeLegal(NewCC, LL.getSimpleValueType()) &&
6456 TLI.isOperationLegal(ISD::SETCC, OpVT))))
6457 return DAG.getSetCC(DL, VT, LL, LR, NewCC);
6458 }
6459
6460 return SDValue();
6461 }
6462
arebothOperandsNotSNan(SDValue Operand1,SDValue Operand2,SelectionDAG & DAG)6463 static bool arebothOperandsNotSNan(SDValue Operand1, SDValue Operand2,
6464 SelectionDAG &DAG) {
6465 return DAG.isKnownNeverSNaN(Operand2) && DAG.isKnownNeverSNaN(Operand1);
6466 }
6467
arebothOperandsNotNan(SDValue Operand1,SDValue Operand2,SelectionDAG & DAG)6468 static bool arebothOperandsNotNan(SDValue Operand1, SDValue Operand2,
6469 SelectionDAG &DAG) {
6470 return DAG.isKnownNeverNaN(Operand2) && DAG.isKnownNeverNaN(Operand1);
6471 }
6472
6473 // FIXME: use FMINIMUMNUM if possible, such as for RISC-V.
getMinMaxOpcodeForFP(SDValue Operand1,SDValue Operand2,ISD::CondCode CC,unsigned OrAndOpcode,SelectionDAG & DAG,bool isFMAXNUMFMINNUM_IEEE,bool isFMAXNUMFMINNUM)6474 static unsigned getMinMaxOpcodeForFP(SDValue Operand1, SDValue Operand2,
6475 ISD::CondCode CC, unsigned OrAndOpcode,
6476 SelectionDAG &DAG,
6477 bool isFMAXNUMFMINNUM_IEEE,
6478 bool isFMAXNUMFMINNUM) {
6479 // The optimization cannot be applied for all the predicates because
6480 // of the way FMINNUM/FMAXNUM and FMINNUM_IEEE/FMAXNUM_IEEE handle
6481 // NaNs. For FMINNUM_IEEE/FMAXNUM_IEEE, the optimization cannot be
6482 // applied at all if one of the operands is a signaling NaN.
6483
6484 // It is safe to use FMINNUM_IEEE/FMAXNUM_IEEE if all the operands
6485 // are non NaN values.
6486 if (((CC == ISD::SETLT || CC == ISD::SETLE) && (OrAndOpcode == ISD::OR)) ||
6487 ((CC == ISD::SETGT || CC == ISD::SETGE) && (OrAndOpcode == ISD::AND)))
6488 return arebothOperandsNotNan(Operand1, Operand2, DAG) &&
6489 isFMAXNUMFMINNUM_IEEE
6490 ? ISD::FMINNUM_IEEE
6491 : ISD::DELETED_NODE;
6492 else if (((CC == ISD::SETGT || CC == ISD::SETGE) &&
6493 (OrAndOpcode == ISD::OR)) ||
6494 ((CC == ISD::SETLT || CC == ISD::SETLE) &&
6495 (OrAndOpcode == ISD::AND)))
6496 return arebothOperandsNotNan(Operand1, Operand2, DAG) &&
6497 isFMAXNUMFMINNUM_IEEE
6498 ? ISD::FMAXNUM_IEEE
6499 : ISD::DELETED_NODE;
6500 // Both FMINNUM/FMAXNUM and FMINNUM_IEEE/FMAXNUM_IEEE handle quiet
6501 // NaNs in the same way. But, FMINNUM/FMAXNUM and FMINNUM_IEEE/
6502 // FMAXNUM_IEEE handle signaling NaNs differently. If we cannot prove
6503 // that there are not any sNaNs, then the optimization is not valid
6504 // for FMINNUM_IEEE/FMAXNUM_IEEE. In the presence of sNaNs, we apply
6505 // the optimization using FMINNUM/FMAXNUM for the following cases. If
6506 // we can prove that we do not have any sNaNs, then we can do the
6507 // optimization using FMINNUM_IEEE/FMAXNUM_IEEE for the following
6508 // cases.
6509 else if (((CC == ISD::SETOLT || CC == ISD::SETOLE) &&
6510 (OrAndOpcode == ISD::OR)) ||
6511 ((CC == ISD::SETUGT || CC == ISD::SETUGE) &&
6512 (OrAndOpcode == ISD::AND)))
6513 return isFMAXNUMFMINNUM ? ISD::FMINNUM
6514 : arebothOperandsNotSNan(Operand1, Operand2, DAG) &&
6515 isFMAXNUMFMINNUM_IEEE
6516 ? ISD::FMINNUM_IEEE
6517 : ISD::DELETED_NODE;
6518 else if (((CC == ISD::SETOGT || CC == ISD::SETOGE) &&
6519 (OrAndOpcode == ISD::OR)) ||
6520 ((CC == ISD::SETULT || CC == ISD::SETULE) &&
6521 (OrAndOpcode == ISD::AND)))
6522 return isFMAXNUMFMINNUM ? ISD::FMAXNUM
6523 : arebothOperandsNotSNan(Operand1, Operand2, DAG) &&
6524 isFMAXNUMFMINNUM_IEEE
6525 ? ISD::FMAXNUM_IEEE
6526 : ISD::DELETED_NODE;
6527 return ISD::DELETED_NODE;
6528 }
6529
/// Tries to fold an AND/OR node whose two operands are both single-use SETCC
/// nodes into a single SETCC.
///
/// Three families of folds are attempted, in this order:
///  1. two compares sharing one operand become one compare of a min/max;
///  2. and(X ord X, Y ord Y) / or(X uno X, Y uno Y) become one compare of X, Y;
///  3. target-preference-gated integer eq/ne-against-constant folds using
///     ABS, not+and, or add+and.
///
/// \param LogicOp The ISD::AND or ISD::OR node (asserted).
/// \returns The replacement value, or an empty SDValue if nothing matched.
static SDValue foldAndOrOfSETCC(SDNode *LogicOp, SelectionDAG &DAG) {
  using AndOrSETCCFoldKind = TargetLowering::AndOrSETCCFoldKind;
  assert(
      (LogicOp->getOpcode() == ISD::AND || LogicOp->getOpcode() == ISD::OR) &&
      "Invalid Op to combine SETCC with");

  // TODO: Search past casts/truncates.
  SDValue LHS = LogicOp->getOperand(0);
  SDValue RHS = LogicOp->getOperand(1);
  // Both operands must be SETCC nodes with no other users.
  if (LHS->getOpcode() != ISD::SETCC || RHS->getOpcode() != ISD::SETCC ||
      !LHS->hasOneUse() || !RHS->hasOneUse())
    return SDValue();

  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
  // Only consulted by the constant-compare folds near the end of the function.
  AndOrSETCCFoldKind TargetPreference = TLI.isDesirableToCombineLogicOpOfSETCC(
      LogicOp, LHS.getNode(), RHS.getNode());

  SDValue LHS0 = LHS->getOperand(0);
  SDValue RHS0 = RHS->getOperand(0);
  SDValue LHS1 = LHS->getOperand(1);
  SDValue RHS1 = RHS->getOperand(1);
  // TODO: We don't actually need a splat here, for vectors we just need the
  // invariants to hold for each element.
  auto *LHS1C = isConstOrConstSplat(LHS1);
  auto *RHS1C = isConstOrConstSplat(RHS1);
  ISD::CondCode CCL = cast<CondCodeSDNode>(LHS.getOperand(2))->get();
  ISD::CondCode CCR = cast<CondCodeSDNode>(RHS.getOperand(2))->get();
  // VT is the type of the logic op result; OpVT is the type being compared.
  EVT VT = LogicOp->getValueType(0);
  EVT OpVT = LHS0.getValueType();
  SDLoc DL(LogicOp);

  // Check if the operands of an and/or operation are comparisons and if they
  // compare against the same value. Replace the and/or-cmp-cmp sequence with
  // min/max cmp sequence. If LHS1 is equal to RHS1, then the or-cmp-cmp
  // sequence will be replaced with min-cmp sequence:
  // (LHS0 < LHS1) | (RHS0 < RHS1) -> min(LHS0, RHS0) < LHS1
  // and and-cmp-cmp will be replaced with max-cmp sequence:
  // (LHS0 < LHS1) & (RHS0 < RHS1) -> max(LHS0, RHS0) < LHS1
  // The optimization does not work for `==` or `!=` .
  // The two comparisons should have either the same predicate or the
  // predicate of one of the comparisons is the opposite of the other one.
  bool isFMAXNUMFMINNUM_IEEE = TLI.isOperationLegal(ISD::FMAXNUM_IEEE, OpVT) &&
                               TLI.isOperationLegal(ISD::FMINNUM_IEEE, OpVT);
  bool isFMAXNUMFMINNUM = TLI.isOperationLegalOrCustom(ISD::FMAXNUM, OpVT) &&
                          TLI.isOperationLegalOrCustom(ISD::FMINNUM, OpVT);
  if (((OpVT.isInteger() && TLI.isOperationLegal(ISD::UMAX, OpVT) &&
        TLI.isOperationLegal(ISD::SMAX, OpVT) &&
        TLI.isOperationLegal(ISD::UMIN, OpVT) &&
        TLI.isOperationLegal(ISD::SMIN, OpVT)) ||
       (OpVT.isFloatingPoint() &&
        (isFMAXNUMFMINNUM_IEEE || isFMAXNUMFMINNUM))) &&
      !ISD::isIntEqualitySetCC(CCL) && !ISD::isFPEqualitySetCC(CCL) &&
      CCL != ISD::SETFALSE && CCL != ISD::SETO && CCL != ISD::SETUO &&
      CCL != ISD::SETTRUE &&
      (CCL == CCR || CCL == ISD::getSetCCSwappedOperands(CCR))) {

    // Find the operand shared by both compares (CommonValue) and the two
    // operands that differ (Operand1/Operand2). CC is the predicate for
    // "minmax(Operand1, Operand2) CC CommonValue"; it stays SETCC_INVALID when
    // the compares share no operand in a usable position.
    SDValue CommonValue, Operand1, Operand2;
    ISD::CondCode CC = ISD::SETCC_INVALID;
    if (CCL == CCR) {
      if (LHS0 == RHS0) {
        // Shared value is on the left of both compares, so swap the predicate
        // to put it on the right.
        CommonValue = LHS0;
        Operand1 = LHS1;
        Operand2 = RHS1;
        CC = ISD::getSetCCSwappedOperands(CCL);
      } else if (LHS1 == RHS1) {
        CommonValue = LHS1;
        Operand1 = LHS0;
        Operand2 = RHS0;
        CC = CCL;
      }
    } else {
      assert(CCL == ISD::getSetCCSwappedOperands(CCR) && "Unexpected CC");
      if (LHS0 == RHS1) {
        CommonValue = LHS0;
        Operand1 = LHS1;
        Operand2 = RHS0;
        CC = CCR;
      } else if (RHS0 == LHS1) {
        CommonValue = LHS1;
        Operand1 = LHS0;
        Operand2 = RHS1;
        CC = CCL;
      }
    }

    // Don't do this transform for sign bit tests. Let foldLogicOfSetCCs
    // handle it using OR/AND.
    if (CC == ISD::SETLT && isNullOrNullSplat(CommonValue))
      CC = ISD::SETCC_INVALID;
    else if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(CommonValue))
      CC = ISD::SETCC_INVALID;

    if (CC != ISD::SETCC_INVALID) {
      unsigned NewOpcode = ISD::DELETED_NODE;
      bool IsSigned = isSignedIntSetCC(CC);
      if (OpVT.isInteger()) {
        bool IsLess = (CC == ISD::SETLE || CC == ISD::SETULE ||
                       CC == ISD::SETLT || CC == ISD::SETULT);
        bool IsOr = (LogicOp->getOpcode() == ISD::OR);
        // less+or and greater+and select a min; the other two pairings select
        // a max.
        if (IsLess == IsOr)
          NewOpcode = IsSigned ? ISD::SMIN : ISD::UMIN;
        else
          NewOpcode = IsSigned ? ISD::SMAX : ISD::UMAX;
      } else if (OpVT.isFloatingPoint())
        NewOpcode =
            getMinMaxOpcodeForFP(Operand1, Operand2, CC, LogicOp->getOpcode(),
                                 DAG, isFMAXNUMFMINNUM_IEEE, isFMAXNUMFMINNUM);

      // DELETED_NODE is the "no suitable opcode" sentinel.
      if (NewOpcode != ISD::DELETED_NODE) {
        SDValue MinMaxValue =
            DAG.getNode(NewOpcode, DL, OpVT, Operand1, Operand2);
        return DAG.getSetCC(DL, VT, MinMaxValue, CommonValue, CC);
      }
    }
  }

  // (and (setcc X, X, SETO), (setcc Y, Y, SETO)) -> (setcc X, Y, SETO)
  // (or (setcc X, X, SETUO), (setcc Y, Y, SETUO)) -> (setcc X, Y, SETUO)
  if (LHS0 == LHS1 && RHS0 == RHS1 && CCL == CCR &&
      LHS0.getValueType() == RHS0.getValueType() &&
      ((LogicOp->getOpcode() == ISD::AND && CCL == ISD::SETO) ||
       (LogicOp->getOpcode() == ISD::OR && CCL == ISD::SETUO)))
    return DAG.getSetCC(DL, VT, LHS0, RHS0, CCL);

  // Everything below is opt-in per target.
  if (TargetPreference == AndOrSETCCFoldKind::None)
    return SDValue();

  // Match (A != C0) & (A != C1) or (A == C0) | (A == C1) on integers with
  // constant (or constant-splat) C0/C1.
  if (CCL == CCR &&
      CCL == (LogicOp->getOpcode() == ISD::AND ? ISD::SETNE : ISD::SETEQ) &&
      LHS0 == RHS0 && LHS1C && RHS1C && OpVT.isInteger()) {
    const APInt &APLhs = LHS1C->getAPIntValue();
    const APInt &APRhs = RHS1C->getAPIntValue();

    // Preference is to use ISD::ABS or we already have an ISD::ABS (in which
    // case this is just a compare).
    if (APLhs == (-APRhs) &&
        ((TargetPreference & AndOrSETCCFoldKind::ABS) ||
         DAG.doesNodeExist(ISD::ABS, DAG.getVTList(OpVT), {LHS0}))) {
      // Compare against whichever of the two constants is not negative.
      const APInt &C = APLhs.isNegative() ? APRhs : APLhs;
      // (icmp eq A, C) | (icmp eq A, -C)
      // -> (icmp eq Abs(A), C)
      // (icmp ne A, C) & (icmp ne A, -C)
      // -> (icmp ne Abs(A), C)
      SDValue AbsOp = DAG.getNode(ISD::ABS, DL, OpVT, LHS0);
      return DAG.getNode(ISD::SETCC, DL, VT, AbsOp,
                         DAG.getConstant(C, DL, OpVT), LHS.getOperand(2));
    } else if (TargetPreference &
               (AndOrSETCCFoldKind::AddAnd | AndOrSETCCFoldKind::NotAnd)) {

      // AndOrSETCCFoldKind::AddAnd:
      // A == C0 | A == C1
      // IF IsPow2(smax(C0, C1)-smin(C0, C1))
      // -> ((A - smin(C0, C1)) & ~(smax(C0, C1)-smin(C0, C1))) == 0
      // A != C0 & A != C1
      // IF IsPow2(smax(C0, C1)-smin(C0, C1))
      // -> ((A - smin(C0, C1)) & ~(smax(C0, C1)-smin(C0, C1))) != 0

      // AndOrSETCCFoldKind::NotAnd:
      // A == C0 | A == C1
      // IF smax(C0, C1) == -1 AND IsPow2(smax(C0, C1) - smin(C0, C1))
      // -> ~A & smin(C0, C1) == 0
      // A != C0 & A != C1
      // IF smax(C0, C1) == -1 AND IsPow2(smax(C0, C1) - smin(C0, C1))
      // -> ~A & smin(C0, C1) != 0

      const APInt &MaxC = APIntOps::smax(APRhs, APLhs);
      const APInt &MinC = APIntOps::smin(APRhs, APLhs);
      APInt Dif = MaxC - MinC;
      if (!Dif.isZero() && Dif.isPowerOf2()) {
        // The NotAnd form is tried first; AddAnd is the fallback.
        if (MaxC.isAllOnes() &&
            (TargetPreference & AndOrSETCCFoldKind::NotAnd)) {
          SDValue NotOp = DAG.getNOT(DL, LHS0, OpVT);
          SDValue AndOp = DAG.getNode(ISD::AND, DL, OpVT, NotOp,
                                      DAG.getConstant(MinC, DL, OpVT));
          // Reuse the original condition-code operand (eq or ne).
          return DAG.getNode(ISD::SETCC, DL, VT, AndOp,
                             DAG.getConstant(0, DL, OpVT), LHS.getOperand(2));
        } else if (TargetPreference & AndOrSETCCFoldKind::AddAnd) {

          // A - MinC is emitted as A + (-MinC).
          SDValue AddOp = DAG.getNode(ISD::ADD, DL, OpVT, LHS0,
                                      DAG.getConstant(-MinC, DL, OpVT));
          SDValue AndOp = DAG.getNode(ISD::AND, DL, OpVT, AddOp,
                                      DAG.getConstant(~Dif, DL, OpVT));
          return DAG.getNode(ISD::SETCC, DL, VT, AndOp,
                             DAG.getConstant(0, DL, OpVT), LHS.getOperand(2));
        }
      }
    }
  }

  return SDValue();
}
6719
6720 // Combine `(select c, (X & 1), 0)` -> `(and (zext c), X)`.
6721 // We canonicalize to the `select` form in the middle end, but the `and` form
6722 // gets better codegen and all tested targets (arm, x86, riscv)
combineSelectAsExtAnd(SDValue Cond,SDValue T,SDValue F,const SDLoc & DL,SelectionDAG & DAG)6723 static SDValue combineSelectAsExtAnd(SDValue Cond, SDValue T, SDValue F,
6724 const SDLoc &DL, SelectionDAG &DAG) {
6725 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
6726 if (!isNullConstant(F))
6727 return SDValue();
6728
6729 EVT CondVT = Cond.getValueType();
6730 if (TLI.getBooleanContents(CondVT) !=
6731 TargetLoweringBase::ZeroOrOneBooleanContent)
6732 return SDValue();
6733
6734 if (T.getOpcode() != ISD::AND)
6735 return SDValue();
6736
6737 if (!isOneConstant(T.getOperand(1)))
6738 return SDValue();
6739
6740 EVT OpVT = T.getValueType();
6741
6742 SDValue CondMask =
6743 OpVT == CondVT ? Cond : DAG.getBoolExtOrTrunc(Cond, DL, OpVT, CondVT);
6744 return DAG.getNode(ISD::AND, DL, OpVT, CondMask, T.getOperand(0));
6745 }
6746
6747 /// This contains all DAGCombine rules which reduce two values combined by
6748 /// an And operation to a single value. This makes them reusable in the context
6749 /// of visitSELECT(). Rules involving constants are not included as
6750 /// visitSELECT() already handles those cases.
visitANDLike(SDValue N0,SDValue N1,SDNode * N)6751 SDValue DAGCombiner::visitANDLike(SDValue N0, SDValue N1, SDNode *N) {
6752 EVT VT = N1.getValueType();
6753 SDLoc DL(N);
6754
6755 // fold (and x, undef) -> 0
6756 if (N0.isUndef() || N1.isUndef())
6757 return DAG.getConstant(0, DL, VT);
6758
6759 if (SDValue V = foldLogicOfSetCCs(true, N0, N1, DL))
6760 return V;
6761
6762 // Canonicalize:
6763 // and(x, add) -> and(add, x)
6764 if (N1.getOpcode() == ISD::ADD)
6765 std::swap(N0, N1);
6766
6767 // TODO: Rewrite this to return a new 'AND' instead of using CombineTo.
6768 if (N0.getOpcode() == ISD::ADD && N1.getOpcode() == ISD::SRL &&
6769 VT.isScalarInteger() && VT.getSizeInBits() <= 64 && N0->hasOneUse()) {
6770 if (ConstantSDNode *ADDI = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
6771 if (ConstantSDNode *SRLI = dyn_cast<ConstantSDNode>(N1.getOperand(1))) {
6772 // Look for (and (add x, c1), (lshr y, c2)). If C1 wasn't a legal
6773 // immediate for an add, but it is legal if its top c2 bits are set,
6774 // transform the ADD so the immediate doesn't need to be materialized
6775 // in a register.
6776 APInt ADDC = ADDI->getAPIntValue();
6777 APInt SRLC = SRLI->getAPIntValue();
6778 if (ADDC.getSignificantBits() <= 64 && SRLC.ult(VT.getSizeInBits()) &&
6779 !TLI.isLegalAddImmediate(ADDC.getSExtValue())) {
6780 APInt Mask = APInt::getHighBitsSet(VT.getSizeInBits(),
6781 SRLC.getZExtValue());
6782 if (DAG.MaskedValueIsZero(N0.getOperand(1), Mask)) {
6783 ADDC |= Mask;
6784 if (TLI.isLegalAddImmediate(ADDC.getSExtValue())) {
6785 SDLoc DL0(N0);
6786 SDValue NewAdd =
6787 DAG.getNode(ISD::ADD, DL0, VT,
6788 N0.getOperand(0), DAG.getConstant(ADDC, DL, VT));
6789 CombineTo(N0.getNode(), NewAdd);
6790 // Return N so it doesn't get rechecked!
6791 return SDValue(N, 0);
6792 }
6793 }
6794 }
6795 }
6796 }
6797 }
6798
6799 return SDValue();
6800 }
6801
isAndLoadExtLoad(ConstantSDNode * AndC,LoadSDNode * LoadN,EVT LoadResultTy,EVT & ExtVT)6802 bool DAGCombiner::isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
6803 EVT LoadResultTy, EVT &ExtVT) {
6804 if (!AndC->getAPIntValue().isMask())
6805 return false;
6806
6807 unsigned ActiveBits = AndC->getAPIntValue().countr_one();
6808
6809 ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
6810 EVT LoadedVT = LoadN->getMemoryVT();
6811
6812 if (ExtVT == LoadedVT &&
6813 (!LegalOperations ||
6814 TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT))) {
6815 // ZEXTLOAD will match without needing to change the size of the value being
6816 // loaded.
6817 return true;
6818 }
6819
6820 // Do not change the width of a volatile or atomic loads.
6821 if (!LoadN->isSimple())
6822 return false;
6823
6824 // Do not generate loads of non-round integer types since these can
6825 // be expensive (and would be wrong if the type is not byte sized).
6826 if (!LoadedVT.bitsGT(ExtVT) || !ExtVT.isRound())
6827 return false;
6828
6829 if (LegalOperations &&
6830 !TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT))
6831 return false;
6832
6833 if (!TLI.shouldReduceLoadWidth(LoadN, ISD::ZEXTLOAD, ExtVT, /*ByteOffset=*/0))
6834 return false;
6835
6836 return true;
6837 }
6838
isLegalNarrowLdSt(LSBaseSDNode * LDST,ISD::LoadExtType ExtType,EVT & MemVT,unsigned ShAmt)6839 bool DAGCombiner::isLegalNarrowLdSt(LSBaseSDNode *LDST,
6840 ISD::LoadExtType ExtType, EVT &MemVT,
6841 unsigned ShAmt) {
6842 if (!LDST)
6843 return false;
6844
6845 // Only allow byte offsets.
6846 if (ShAmt % 8)
6847 return false;
6848 const unsigned ByteShAmt = ShAmt / 8;
6849
6850 // Do not generate loads of non-round integer types since these can
6851 // be expensive (and would be wrong if the type is not byte sized).
6852 if (!MemVT.isRound())
6853 return false;
6854
6855 // Don't change the width of a volatile or atomic loads.
6856 if (!LDST->isSimple())
6857 return false;
6858
6859 EVT LdStMemVT = LDST->getMemoryVT();
6860
6861 // Bail out when changing the scalable property, since we can't be sure that
6862 // we're actually narrowing here.
6863 if (LdStMemVT.isScalableVector() != MemVT.isScalableVector())
6864 return false;
6865
6866 // Verify that we are actually reducing a load width here.
6867 if (LdStMemVT.bitsLT(MemVT))
6868 return false;
6869
6870 // Ensure that this isn't going to produce an unsupported memory access.
6871 if (ShAmt) {
6872 const Align LDSTAlign = LDST->getAlign();
6873 const Align NarrowAlign = commonAlignment(LDSTAlign, ByteShAmt);
6874 if (!TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), MemVT,
6875 LDST->getAddressSpace(), NarrowAlign,
6876 LDST->getMemOperand()->getFlags()))
6877 return false;
6878 }
6879
6880 // It's not possible to generate a constant of extended or untyped type.
6881 EVT PtrType = LDST->getBasePtr().getValueType();
6882 if (PtrType == MVT::Untyped || PtrType.isExtended())
6883 return false;
6884
6885 if (isa<LoadSDNode>(LDST)) {
6886 LoadSDNode *Load = cast<LoadSDNode>(LDST);
6887 // Don't transform one with multiple uses, this would require adding a new
6888 // load.
6889 if (!SDValue(Load, 0).hasOneUse())
6890 return false;
6891
6892 if (LegalOperations &&
6893 !TLI.isLoadExtLegal(ExtType, Load->getValueType(0), MemVT))
6894 return false;
6895
6896 // For the transform to be legal, the load must produce only two values
6897 // (the value loaded and the chain). Don't transform a pre-increment
6898 // load, for example, which produces an extra value. Otherwise the
6899 // transformation is not equivalent, and the downstream logic to replace
6900 // uses gets things wrong.
6901 if (Load->getNumValues() > 2)
6902 return false;
6903
6904 // If the load that we're shrinking is an extload and we're not just
6905 // discarding the extension we can't simply shrink the load. Bail.
6906 // TODO: It would be possible to merge the extensions in some cases.
6907 if (Load->getExtensionType() != ISD::NON_EXTLOAD &&
6908 Load->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits() + ShAmt)
6909 return false;
6910
6911 if (!TLI.shouldReduceLoadWidth(Load, ExtType, MemVT, ByteShAmt))
6912 return false;
6913 } else {
6914 assert(isa<StoreSDNode>(LDST) && "It is not a Load nor a Store SDNode");
6915 StoreSDNode *Store = cast<StoreSDNode>(LDST);
6916 // Can't write outside the original store
6917 if (Store->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits() + ShAmt)
6918 return false;
6919
6920 if (LegalOperations &&
6921 !TLI.isTruncStoreLegal(Store->getValue().getValueType(), MemVT))
6922 return false;
6923 }
6924 return true;
6925 }
6926
/// Walks the operands of \p N (recursing through OR/XOR/AND) to collect the
/// loads that could be narrowed if the AND with \p Mask were pushed back onto
/// them.
///
/// \param N The node whose operands are examined.
/// \param [out] Loads Loads that should be narrowed.
/// \param [out] NodesWithConsts Nodes that have a constant operand with bits
///        set outside \p Mask.
/// \param Mask The AND mask being propagated.
/// \param [in,out] NodeToMask At most one non-load, non-logic operand that
///        would need an explicit mask; reset to nullptr when the candidate
///        produces more than one data result.
/// \returns false as soon as any operand makes the propagation impossible.
bool DAGCombiner::SearchForAndLoads(SDNode *N,
                                    SmallVectorImpl<LoadSDNode*> &Loads,
                                    SmallPtrSetImpl<SDNode*> &NodesWithConsts,
                                    ConstantSDNode *Mask,
                                    SDNode *&NodeToMask) {
  // Recursively search for the operands, looking for loads which can be
  // narrowed.
  for (SDValue Op : N->op_values()) {
    // Vector operands are not handled.
    if (Op.getValueType().isVector())
      return false;

    // Some constants may need fixing up later if they are too large.
    if (auto *C = dyn_cast<ConstantSDNode>(Op)) {
      assert(ISD::isBitwiseLogicOp(N->getOpcode()) &&
             "Expected bitwise logic operation");
      if (!C->getAPIntValue().isSubsetOf(Mask->getAPIntValue()))
        NodesWithConsts.insert(N);
      continue;
    }

    // Every non-constant operand must have a single use.
    if (!Op.hasOneUse())
      return false;

    // Within the switch: `continue` accepts the operand, `return false`
    // rejects the whole search, and `break` (or an unlisted opcode) falls
    // through to the single-NodeToMask handling below.
    switch(Op.getOpcode()) {
    case ISD::LOAD: {
      auto *Load = cast<LoadSDNode>(Op);
      EVT ExtVT;
      if (isAndLoadExtLoad(Mask, Load, Load->getValueType(0), ExtVT) &&
          isLegalNarrowLdSt(Load, ISD::ZEXTLOAD, ExtVT)) {

        // ZEXTLOAD is already small enough.
        if (Load->getExtensionType() == ISD::ZEXTLOAD &&
            ExtVT.bitsGE(Load->getMemoryVT()))
          continue;

        // Use LE to convert equal sized loads to zext.
        if (ExtVT.bitsLE(Load->getMemoryVT()))
          Loads.push_back(Load);

        continue;
      }
      return false;
    }
    case ISD::ZERO_EXTEND:
    case ISD::AssertZext: {
      unsigned ActiveBits = Mask->getAPIntValue().countr_one();
      EVT ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
      // For AssertZext the narrow type is carried in operand 1; for
      // ZERO_EXTEND it is the type of the extended operand.
      EVT VT = Op.getOpcode() == ISD::AssertZext ?
        cast<VTSDNode>(Op.getOperand(1))->getVT() :
        Op.getOperand(0).getValueType();

      // We can accept extending nodes if the mask is wider or an equal
      // width to the original type.
      if (ExtVT.bitsGE(VT))
        continue;
      break;
    }
    case ISD::OR:
    case ISD::XOR:
    case ISD::AND:
      if (!SearchForAndLoads(Op.getNode(), Loads, NodesWithConsts, Mask,
                             NodeToMask))
        return false;
      continue;
    }

    // Allow one node which will be masked along with any loads found.
    if (NodeToMask)
      return false;

    // Also ensure that the node to be masked only produces one data result.
    NodeToMask = Op.getNode();
    if (NodeToMask->getNumValues() > 1) {
      bool HasValue = false;
      for (unsigned i = 0, e = NodeToMask->getNumValues(); i < e; ++i) {
        MVT VT = SDValue(NodeToMask, i).getSimpleValueType();
        // Glue and chain (Other) results do not count as data results.
        if (VT != MVT::Glue && VT != MVT::Other) {
          if (HasValue) {
            NodeToMask = nullptr;
            return false;
          }
          HasValue = true;
        }
      }
      assert(HasValue && "Node to be masked has no data result?");
    }
  }
  return true;
}
7016
/// Given a node \p N whose second operand is a constant low-bit mask, tries to
/// push that mask back through the logic tree feeding \p N onto the loads at
/// its leaves so they can be narrowed, then replaces \p N with its first
/// operand.
///
/// \returns true if the DAG was rewritten; false if the second operand is not
/// a constant mask, the first operand is itself a load, the search fails, or
/// no narrowable load was found.
bool DAGCombiner::BackwardsPropagateMask(SDNode *N) {
  auto *Mask = dyn_cast<ConstantSDNode>(N->getOperand(1));
  if (!Mask)
    return false;

  if (!Mask->getAPIntValue().isMask())
    return false;

  // No need to do anything if the and directly uses a load.
  if (isa<LoadSDNode>(N->getOperand(0)))
    return false;

  SmallVector<LoadSDNode*, 8> Loads;
  SmallPtrSet<SDNode*, 2> NodesWithConsts;
  SDNode *FixupNode = nullptr;
  if (SearchForAndLoads(N, Loads, NodesWithConsts, Mask, FixupNode)) {
    if (Loads.empty())
      return false;

    LLVM_DEBUG(dbgs() << "Backwards propagate AND: "; N->dump());
    SDValue MaskOp = N->getOperand(1);

    // If it exists, fixup the single node we allow in the tree that needs
    // masking.
    if (FixupNode) {
      LLVM_DEBUG(dbgs() << "First, need to fix up: "; FixupNode->dump());
      SDValue And = DAG.getNode(ISD::AND, SDLoc(FixupNode),
                                FixupNode->getValueType(0),
                                SDValue(FixupNode, 0), MaskOp);
      DAG.ReplaceAllUsesOfValueWith(SDValue(FixupNode, 0), And);
      // The replacement above also rewrote the new AND's own use of FixupNode
      // (if the result really is an AND node), so restore its operands.
      if (And.getOpcode() == ISD ::AND)
        DAG.UpdateNodeOperands(And.getNode(), SDValue(FixupNode, 0), MaskOp);
    }

    // Narrow any constants that need it.
    for (auto *LogicN : NodesWithConsts) {
      SDValue Op0 = LogicN->getOperand(0);
      SDValue Op1 = LogicN->getOperand(1);

      // We only need to fix AND if both inputs are constants. And we only need
      // to fix one of the constants.
      if (LogicN->getOpcode() == ISD::AND &&
          (!isa<ConstantSDNode>(Op0) || !isa<ConstantSDNode>(Op1)))
        continue;

      if (isa<ConstantSDNode>(Op0) && LogicN->getOpcode() != ISD::AND)
        Op0 =
            DAG.getNode(ISD::AND, SDLoc(Op0), Op0.getValueType(), Op0, MaskOp);

      if (isa<ConstantSDNode>(Op1))
        Op1 =
            DAG.getNode(ISD::AND, SDLoc(Op1), Op1.getValueType(), Op1, MaskOp);

      // Keep a remaining constant as the second operand.
      if (isa<ConstantSDNode>(Op0) && !isa<ConstantSDNode>(Op1))
        std::swap(Op0, Op1);

      DAG.UpdateNodeOperands(LogicN, Op0, Op1);
    }

    // Create narrow loads.
    for (auto *Load : Loads) {
      LLVM_DEBUG(dbgs() << "Propagate AND back to: "; Load->dump());
      SDValue And = DAG.getNode(ISD::AND, SDLoc(Load), Load->getValueType(0),
                                SDValue(Load, 0), MaskOp);
      DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 0), And);
      // As above: point the new AND back at the load after the replacement.
      if (And.getOpcode() == ISD ::AND)
        And = SDValue(
            DAG.UpdateNodeOperands(And.getNode(), SDValue(Load, 0), MaskOp), 0);
      SDValue NewLoad = reduceLoadWidth(And.getNode());
      assert(NewLoad &&
             "Shouldn't be masking the load if it can't be narrowed");
      // Replace both the loaded value and the chain result of the old load.
      CombineTo(Load, NewLoad, NewLoad.getValue(1));
    }
    // The mask now lives on the leaves, so N itself is redundant.
    DAG.ReplaceAllUsesWith(N, N->getOperand(0).getNode());
    return true;
  }
  return false;
}
7095
7096 // Unfold
7097 // x & (-1 'logical shift' y)
7098 // To
7099 // (x 'opposite logical shift' y) 'logical shift' y
7100 // if it is better for performance.
unfoldExtremeBitClearingToShifts(SDNode * N)7101 SDValue DAGCombiner::unfoldExtremeBitClearingToShifts(SDNode *N) {
7102 assert(N->getOpcode() == ISD::AND);
7103
7104 SDValue N0 = N->getOperand(0);
7105 SDValue N1 = N->getOperand(1);
7106
7107 // Do we actually prefer shifts over mask?
7108 if (!TLI.shouldFoldMaskToVariableShiftPair(N0))
7109 return SDValue();
7110
7111 // Try to match (-1 '[outer] logical shift' y)
7112 unsigned OuterShift;
7113 unsigned InnerShift; // The opposite direction to the OuterShift.
7114 SDValue Y; // Shift amount.
7115 auto matchMask = [&OuterShift, &InnerShift, &Y](SDValue M) -> bool {
7116 if (!M.hasOneUse())
7117 return false;
7118 OuterShift = M->getOpcode();
7119 if (OuterShift == ISD::SHL)
7120 InnerShift = ISD::SRL;
7121 else if (OuterShift == ISD::SRL)
7122 InnerShift = ISD::SHL;
7123 else
7124 return false;
7125 if (!isAllOnesConstant(M->getOperand(0)))
7126 return false;
7127 Y = M->getOperand(1);
7128 return true;
7129 };
7130
7131 SDValue X;
7132 if (matchMask(N1))
7133 X = N0;
7134 else if (matchMask(N0))
7135 X = N1;
7136 else
7137 return SDValue();
7138
7139 SDLoc DL(N);
7140 EVT VT = N->getValueType(0);
7141
7142 // tmp = x 'opposite logical shift' y
7143 SDValue T0 = DAG.getNode(InnerShift, DL, VT, X, Y);
7144 // ret = tmp 'logical shift' y
7145 SDValue T1 = DAG.getNode(OuterShift, DL, VT, T0, Y);
7146
7147 return T1;
7148 }
7149
7150 /// Try to replace shift/logic that tests if a bit is clear with mask + setcc.
7151 /// For a target with a bit test, this is expected to become test + set and save
7152 /// at least 1 instruction.
combineShiftAnd1ToBitTest(SDNode * And,SelectionDAG & DAG)7153 static SDValue combineShiftAnd1ToBitTest(SDNode *And, SelectionDAG &DAG) {
7154 assert(And->getOpcode() == ISD::AND && "Expected an 'and' op");
7155
7156 // Look through an optional extension.
7157 SDValue And0 = And->getOperand(0), And1 = And->getOperand(1);
7158 if (And0.getOpcode() == ISD::ANY_EXTEND && And0.hasOneUse())
7159 And0 = And0.getOperand(0);
7160 if (!isOneConstant(And1) || !And0.hasOneUse())
7161 return SDValue();
7162
7163 SDValue Src = And0;
7164
7165 // Attempt to find a 'not' op.
7166 // TODO: Should we favor test+set even without the 'not' op?
7167 bool FoundNot = false;
7168 if (isBitwiseNot(Src)) {
7169 FoundNot = true;
7170 Src = Src.getOperand(0);
7171
7172 // Look though an optional truncation. The source operand may not be the
7173 // same type as the original 'and', but that is ok because we are masking
7174 // off everything but the low bit.
7175 if (Src.getOpcode() == ISD::TRUNCATE && Src.hasOneUse())
7176 Src = Src.getOperand(0);
7177 }
7178
7179 // Match a shift-right by constant.
7180 if (Src.getOpcode() != ISD::SRL || !Src.hasOneUse())
7181 return SDValue();
7182
7183 // This is probably not worthwhile without a supported type.
7184 EVT SrcVT = Src.getValueType();
7185 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
7186 if (!TLI.isTypeLegal(SrcVT))
7187 return SDValue();
7188
7189 // We might have looked through casts that make this transform invalid.
7190 unsigned BitWidth = SrcVT.getScalarSizeInBits();
7191 SDValue ShiftAmt = Src.getOperand(1);
7192 auto *ShiftAmtC = dyn_cast<ConstantSDNode>(ShiftAmt);
7193 if (!ShiftAmtC || !ShiftAmtC->getAPIntValue().ult(BitWidth))
7194 return SDValue();
7195
7196 // Set source to shift source.
7197 Src = Src.getOperand(0);
7198
7199 // Try again to find a 'not' op.
7200 // TODO: Should we favor test+set even with two 'not' ops?
7201 if (!FoundNot) {
7202 if (!isBitwiseNot(Src))
7203 return SDValue();
7204 Src = Src.getOperand(0);
7205 }
7206
7207 if (!TLI.hasBitTest(Src, ShiftAmt))
7208 return SDValue();
7209
7210 // Turn this into a bit-test pattern using mask op + setcc:
7211 // and (not (srl X, C)), 1 --> (and X, 1<<C) == 0
7212 // and (srl (not X), C)), 1 --> (and X, 1<<C) == 0
7213 SDLoc DL(And);
7214 SDValue X = DAG.getZExtOrTrunc(Src, DL, SrcVT);
7215 EVT CCVT =
7216 TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), SrcVT);
7217 SDValue Mask = DAG.getConstant(
7218 APInt::getOneBitSet(BitWidth, ShiftAmtC->getZExtValue()), DL, SrcVT);
7219 SDValue NewAnd = DAG.getNode(ISD::AND, DL, SrcVT, X, Mask);
7220 SDValue Zero = DAG.getConstant(0, DL, SrcVT);
7221 SDValue Setcc = DAG.getSetCC(DL, CCVT, NewAnd, Zero, ISD::SETEQ);
7222 return DAG.getZExtOrTrunc(Setcc, DL, And->getValueType(0));
7223 }
7224
7225 /// For targets that support usubsat, match a bit-hack form of that operation
7226 /// that ends in 'and' and convert it.
foldAndToUsubsat(SDNode * N,SelectionDAG & DAG,const SDLoc & DL)7227 static SDValue foldAndToUsubsat(SDNode *N, SelectionDAG &DAG, const SDLoc &DL) {
7228 EVT VT = N->getValueType(0);
7229 unsigned BitWidth = VT.getScalarSizeInBits();
7230 APInt SignMask = APInt::getSignMask(BitWidth);
7231
7232 // (i8 X ^ 128) & (i8 X s>> 7) --> usubsat X, 128
7233 // (i8 X + 128) & (i8 X s>> 7) --> usubsat X, 128
7234 // xor/add with SMIN (signmask) are logically equivalent.
7235 SDValue X;
7236 if (!sd_match(N, m_And(m_OneUse(m_Xor(m_Value(X), m_SpecificInt(SignMask))),
7237 m_OneUse(m_Sra(m_Deferred(X),
7238 m_SpecificInt(BitWidth - 1))))) &&
7239 !sd_match(N, m_And(m_OneUse(m_Add(m_Value(X), m_SpecificInt(SignMask))),
7240 m_OneUse(m_Sra(m_Deferred(X),
7241 m_SpecificInt(BitWidth - 1))))))
7242 return SDValue();
7243
7244 return DAG.getNode(ISD::USUBSAT, DL, VT, X,
7245 DAG.getConstant(SignMask, DL, VT));
7246 }
7247
7248 /// Given a bitwise logic operation N with a matching bitwise logic operand,
7249 /// fold a pattern where 2 of the source operands are identically shifted
7250 /// values. For example:
7251 /// ((X0 << Y) | Z) | (X1 << Y) --> ((X0 | X1) << Y) | Z
foldLogicOfShifts(SDNode * N,SDValue LogicOp,SDValue ShiftOp,SelectionDAG & DAG)7252 static SDValue foldLogicOfShifts(SDNode *N, SDValue LogicOp, SDValue ShiftOp,
7253 SelectionDAG &DAG) {
7254 unsigned LogicOpcode = N->getOpcode();
7255 assert(ISD::isBitwiseLogicOp(LogicOpcode) &&
7256 "Expected bitwise logic operation");
7257
7258 if (!LogicOp.hasOneUse() || !ShiftOp.hasOneUse())
7259 return SDValue();
7260
7261 // Match another bitwise logic op and a shift.
7262 unsigned ShiftOpcode = ShiftOp.getOpcode();
7263 if (LogicOp.getOpcode() != LogicOpcode ||
7264 !(ShiftOpcode == ISD::SHL || ShiftOpcode == ISD::SRL ||
7265 ShiftOpcode == ISD::SRA))
7266 return SDValue();
7267
7268 // Match another shift op inside the first logic operand. Handle both commuted
7269 // possibilities.
7270 // LOGIC (LOGIC (SH X0, Y), Z), (SH X1, Y) --> LOGIC (SH (LOGIC X0, X1), Y), Z
7271 // LOGIC (LOGIC Z, (SH X0, Y)), (SH X1, Y) --> LOGIC (SH (LOGIC X0, X1), Y), Z
7272 SDValue X1 = ShiftOp.getOperand(0);
7273 SDValue Y = ShiftOp.getOperand(1);
7274 SDValue X0, Z;
7275 if (LogicOp.getOperand(0).getOpcode() == ShiftOpcode &&
7276 LogicOp.getOperand(0).getOperand(1) == Y) {
7277 X0 = LogicOp.getOperand(0).getOperand(0);
7278 Z = LogicOp.getOperand(1);
7279 } else if (LogicOp.getOperand(1).getOpcode() == ShiftOpcode &&
7280 LogicOp.getOperand(1).getOperand(1) == Y) {
7281 X0 = LogicOp.getOperand(1).getOperand(0);
7282 Z = LogicOp.getOperand(0);
7283 } else {
7284 return SDValue();
7285 }
7286
7287 EVT VT = N->getValueType(0);
7288 SDLoc DL(N);
7289 SDValue LogicX = DAG.getNode(LogicOpcode, DL, VT, X0, X1);
7290 SDValue NewShift = DAG.getNode(ShiftOpcode, DL, VT, LogicX, Y);
7291 return DAG.getNode(LogicOpcode, DL, VT, NewShift, Z);
7292 }
7293
7294 /// Given a tree of logic operations with shape like
7295 /// (LOGIC (LOGIC (X, Y), LOGIC (Z, Y)))
7296 /// try to match and fold shift operations with the same shift amount.
7297 /// For example:
7298 /// LOGIC (LOGIC (SH X0, Y), Z), (LOGIC (SH X1, Y), W) -->
7299 /// --> LOGIC (SH (LOGIC X0, X1), Y), (LOGIC Z, W)
foldLogicTreeOfShifts(SDNode * N,SDValue LeftHand,SDValue RightHand,SelectionDAG & DAG)7300 static SDValue foldLogicTreeOfShifts(SDNode *N, SDValue LeftHand,
7301 SDValue RightHand, SelectionDAG &DAG) {
7302 unsigned LogicOpcode = N->getOpcode();
7303 assert(ISD::isBitwiseLogicOp(LogicOpcode) &&
7304 "Expected bitwise logic operation");
7305 if (LeftHand.getOpcode() != LogicOpcode ||
7306 RightHand.getOpcode() != LogicOpcode)
7307 return SDValue();
7308 if (!LeftHand.hasOneUse() || !RightHand.hasOneUse())
7309 return SDValue();
7310
7311 // Try to match one of following patterns:
7312 // LOGIC (LOGIC (SH X0, Y), Z), (LOGIC (SH X1, Y), W)
7313 // LOGIC (LOGIC (SH X0, Y), Z), (LOGIC W, (SH X1, Y))
7314 // Note that foldLogicOfShifts will handle commuted versions of the left hand
7315 // itself.
7316 SDValue CombinedShifts, W;
7317 SDValue R0 = RightHand.getOperand(0);
7318 SDValue R1 = RightHand.getOperand(1);
7319 if ((CombinedShifts = foldLogicOfShifts(N, LeftHand, R0, DAG)))
7320 W = R1;
7321 else if ((CombinedShifts = foldLogicOfShifts(N, LeftHand, R1, DAG)))
7322 W = R0;
7323 else
7324 return SDValue();
7325
7326 EVT VT = N->getValueType(0);
7327 SDLoc DL(N);
7328 return DAG.getNode(LogicOpcode, DL, VT, CombinedShifts, W);
7329 }
7330
7331 /// Fold "masked merge" expressions like `(m & x) | (~m & y)` and its DeMorgan
7332 /// variant `(~m | x) & (m | y)` into the equivalent `((x ^ y) & m) ^ y)`
7333 /// pattern. This is typically a better representation for targets without a
7334 /// fused "and-not" operation.
foldMaskedMerge(SDNode * Node,SelectionDAG & DAG,const TargetLowering & TLI,const SDLoc & DL)7335 static SDValue foldMaskedMerge(SDNode *Node, SelectionDAG &DAG,
7336 const TargetLowering &TLI, const SDLoc &DL) {
7337 // Note that masked-merge variants using XOR or ADD expressions are
7338 // normalized to OR by InstCombine so we only check for OR or AND.
7339 assert((Node->getOpcode() == ISD::OR || Node->getOpcode() == ISD::AND) &&
7340 "Must be called with ISD::OR or ISD::AND node");
7341
7342 // If the target supports and-not, don't fold this.
7343 if (TLI.hasAndNot(SDValue(Node, 0)))
7344 return SDValue();
7345
7346 SDValue M, X, Y;
7347
7348 if (sd_match(Node,
7349 m_Or(m_OneUse(m_And(m_OneUse(m_Not(m_Value(M))), m_Value(Y))),
7350 m_OneUse(m_And(m_Deferred(M), m_Value(X))))) ||
7351 sd_match(Node,
7352 m_And(m_OneUse(m_Or(m_OneUse(m_Not(m_Value(M))), m_Value(X))),
7353 m_OneUse(m_Or(m_Deferred(M), m_Value(Y)))))) {
7354 EVT VT = M.getValueType();
7355 SDValue Xor = DAG.getNode(ISD::XOR, DL, VT, X, Y);
7356 SDValue And = DAG.getNode(ISD::AND, DL, VT, Xor, M);
7357 return DAG.getNode(ISD::XOR, DL, VT, And, Y);
7358 }
7359 return SDValue();
7360 }
7361
/// Combine an ISD::AND node.
///
/// Tries a long sequence of folds on \p N in the order written below; the
/// first one that applies determines the result. The function returns
///  - the replacement value when a fold builds (or selects) one,
///  - SDValue(N, 0) when the DAG was already updated in place (after a
///    CombineTo call, or when BackwardsPropagateMask / SimplifyDemandedBits
///    report a change), and
///  - an empty SDValue when no fold applies.
SDValue DAGCombiner::visitAND(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N1.getValueType();
  SDLoc DL(N);

  // x & x --> x
  if (N0 == N1)
    return N0;

  // fold (and c1, c2) -> c1&c2
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::AND, DL, VT, {N0, N1}))
    return C;

  // canonicalize constant to RHS
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
    return DAG.getNode(ISD::AND, DL, VT, N1, N0);

  // Operands that are bitwise-not of each other AND to zero (the matching
  // itself is done by the helper, which is defined elsewhere).
  if (areBitwiseNotOfEachother(N0, N1))
    return DAG.getConstant(APInt::getZero(VT.getScalarSizeInBits()), DL, VT);

  // fold vector ops
  if (VT.isVector()) {
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

    // fold (and x, 0) -> 0, vector edition
    if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
      // do not return N1, because undef node may exist in N1
      return DAG.getConstant(APInt::getZero(N1.getScalarValueSizeInBits()), DL,
                             N1.getValueType());

    // fold (and x, -1) -> x, vector edition
    if (ISD::isConstantSplatVectorAllOnes(N1.getNode()))
      return N0;

    // fold (and (masked_load) (splat_vec (x, ...))) to zext_masked_load
    auto *MLoad = dyn_cast<MaskedLoadSDNode>(N0);
    ConstantSDNode *Splat = isConstOrConstSplat(N1, true, true);
    if (MLoad && MLoad->getExtensionType() == ISD::EXTLOAD && Splat) {
      EVT LoadVT = MLoad->getMemoryVT();
      EVT ExtVT = VT;
      if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, ExtVT, LoadVT)) {
        // For this AND to be a zero extension of the masked load the elements
        // of the BuildVec must mask the bottom bits of the extended element
        // type
        uint64_t ElementSize =
            LoadVT.getVectorElementType().getScalarSizeInBits();
        if (Splat->getAPIntValue().isMask(ElementSize)) {
          SDValue NewLoad = DAG.getMaskedLoad(
              ExtVT, DL, MLoad->getChain(), MLoad->getBasePtr(),
              MLoad->getOffset(), MLoad->getMask(), MLoad->getPassThru(),
              LoadVT, MLoad->getMemOperand(), MLoad->getAddressingMode(),
              ISD::ZEXTLOAD, MLoad->isExpandingLoad());
          // Sample the use count of the old load before N is replaced.
          bool LoadHasOtherUsers = !N0.hasOneUse();
          CombineTo(N, NewLoad);
          // Redirect the remaining users (value and chain) to the new load.
          if (LoadHasOtherUsers)
            CombineTo(MLoad, NewLoad.getValue(0), NewLoad.getValue(1));
          return SDValue(N, 0);
        }
      }
    }
  }

  // fold (and x, -1) -> x
  if (isAllOnesConstant(N1))
    return N0;

  // if (and x, c) is known to be zero, return 0
  unsigned BitWidth = VT.getScalarSizeInBits();
  ConstantSDNode *N1C = isConstOrConstSplat(N1);
  if (N1C && DAG.MaskedValueIsZero(SDValue(N, 0), APInt::getAllOnes(BitWidth)))
    return DAG.getConstant(0, DL, VT);

  if (SDValue R = foldAndOrOfSETCC(N, DAG))
    return R;

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  // reassociate and
  if (SDValue RAND = reassociateOps(ISD::AND, DL, N0, N1, N->getFlags()))
    return RAND;

  // Fold and(vecreduce(x), vecreduce(y)) -> vecreduce(and(x, y))
  if (SDValue SD =
          reassociateReduction(ISD::VECREDUCE_AND, ISD::AND, DL, VT, N0, N1))
    return SD;

  // fold (and (or x, C), D) -> D if (C & D) == D
  auto MatchSubset = [](ConstantSDNode *LHS, ConstantSDNode *RHS) {
    return RHS->getAPIntValue().isSubsetOf(LHS->getAPIntValue());
  };
  if (N0.getOpcode() == ISD::OR &&
      ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchSubset))
    return N1;

  if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) {
    SDValue N0Op0 = N0.getOperand(0);
    EVT SrcVT = N0Op0.getValueType();
    unsigned SrcBitWidth = SrcVT.getScalarSizeInBits();
    // Mask = the bits the AND clears, restricted to the width of the source.
    APInt Mask = ~N1C->getAPIntValue();
    Mask = Mask.trunc(SrcBitWidth);

    // fold (and (any_ext V), c) -> (zero_ext V) if 'and' only clears top bits.
    if (DAG.MaskedValueIsZero(N0Op0, Mask))
      return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0Op0);

    // fold (and (any_ext V), c) -> (zero_ext (and (trunc V), c)) if profitable.
    if (N1C->getAPIntValue().countLeadingZeros() >= (BitWidth - SrcBitWidth) &&
        TLI.isTruncateFree(VT, SrcVT) && TLI.isZExtFree(SrcVT, VT) &&
        TLI.isTypeDesirableForOp(ISD::AND, SrcVT) &&
        TLI.isNarrowingProfitable(N, VT, SrcVT))
      return DAG.getNode(ISD::ZERO_EXTEND, DL, VT,
                         DAG.getNode(ISD::AND, DL, SrcVT, N0Op0,
                                     DAG.getZExtOrTrunc(N1, DL, SrcVT)));
  }

  // fold (and (ext (and V, c1)), c2) -> (and (ext V), (and c1, (ext c2)))
  if (ISD::isExtOpcode(N0.getOpcode())) {
    unsigned ExtOpc = N0.getOpcode();
    SDValue N0Op0 = N0.getOperand(0);
    if (N0Op0.getOpcode() == ISD::AND &&
        (ExtOpc != ISD::ZERO_EXTEND || !TLI.isZExtFree(N0Op0, VT)) &&
        N0->hasOneUse() && N0Op0->hasOneUse()) {
      // Only proceed when both the extended inner mask and the merged mask
      // constant-fold.
      if (SDValue NewExt = DAG.FoldConstantArithmetic(ExtOpc, DL, VT,
                                                      {N0Op0.getOperand(1)})) {
        if (SDValue NewMask =
                DAG.FoldConstantArithmetic(ISD::AND, DL, VT, {N1, NewExt})) {
          return DAG.getNode(ISD::AND, DL, VT,
                             DAG.getNode(ExtOpc, DL, VT, N0Op0.getOperand(0)),
                             NewMask);
        }
      }
    }
  }

  // similarly fold (and (X (load ([non_ext|any_ext|zero_ext] V))), c) ->
  // (X (load ([non_ext|zero_ext] V))) if 'and' only clears top bits which must
  // already be zero by virtue of the width of the base type of the load.
  //
  // the 'X' node here can either be nothing or an extract_vector_elt to catch
  // more cases.
  if ((N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
       N0.getValueSizeInBits() == N0.getOperand(0).getScalarValueSizeInBits() &&
       N0.getOperand(0).getOpcode() == ISD::LOAD &&
       N0.getOperand(0).getResNo() == 0) ||
      (N0.getOpcode() == ISD::LOAD && N0.getResNo() == 0)) {
    auto *Load =
        cast<LoadSDNode>((N0.getOpcode() == ISD::LOAD) ? N0 : N0.getOperand(0));

    // Get the constant (if applicable) the zero'th operand is being ANDed with.
    // This can be a pure constant or a vector splat, in which case we treat the
    // vector as a scalar and use the splat value.
    APInt Constant = APInt::getZero(1);
    if (const ConstantSDNode *C = isConstOrConstSplat(
            N1, /*AllowUndefs=*/false, /*AllowTruncation=*/true)) {
      Constant = C->getAPIntValue();
    } else if (BuildVectorSDNode *Vector = dyn_cast<BuildVectorSDNode>(N1)) {
      unsigned EltBitWidth = Vector->getValueType(0).getScalarSizeInBits();
      APInt SplatValue, SplatUndef;
      unsigned SplatBitSize;
      bool HasAnyUndefs;
      // Endianness should not matter here. Code below makes sure that we only
      // use the result if the SplatBitSize is a multiple of the vector element
      // size. And after that we AND all element sized parts of the splat
      // together. So the end result should be the same regardless of in which
      // order we do those operations.
      const bool IsBigEndian = false;
      bool IsSplat =
          Vector->isConstantSplat(SplatValue, SplatUndef, SplatBitSize,
                                  HasAnyUndefs, EltBitWidth, IsBigEndian);

      // Make sure that variable 'Constant' is only set if 'SplatBitSize' is a
      // multiple of 'BitWidth'. Otherwise, we could propagate a wrong value.
      if (IsSplat && (SplatBitSize % EltBitWidth) == 0) {
        // Undef bits can contribute to a possible optimisation if set, so
        // set them.
        SplatValue |= SplatUndef;

        // The splat value may be something like "0x00FFFFFF", which means 0 for
        // the first vector value and FF for the rest, repeating. We need a mask
        // that will apply equally to all members of the vector, so AND all the
        // lanes of the constant together.
        Constant = APInt::getAllOnes(EltBitWidth);
        for (unsigned i = 0, n = (SplatBitSize / EltBitWidth); i < n; ++i)
          Constant &= SplatValue.extractBits(EltBitWidth, i * EltBitWidth);
      }
    }

    // If we want to change an EXTLOAD to a ZEXTLOAD, ensure a ZEXTLOAD is
    // actually legal and isn't going to get expanded, else this is a false
    // optimisation.
    bool CanZextLoadProfitably = TLI.isLoadExtLegal(ISD::ZEXTLOAD,
                                                    Load->getValueType(0),
                                                    Load->getMemoryVT());

    // Resize the constant to the same size as the original memory access before
    // extension. If it is still the AllOnesValue then this AND is completely
    // unneeded.
    Constant = Constant.zextOrTrunc(Load->getMemoryVT().getScalarSizeInBits());

    // B: whether this kind of load is eligible for dropping the AND. Zero-ext
    // and non-extending loads always are, any-ext loads only if a ZEXTLOAD is
    // legal, every other extension kind never is.
    bool B;
    switch (Load->getExtensionType()) {
    default: B = false; break;
    case ISD::EXTLOAD: B = CanZextLoadProfitably; break;
    case ISD::ZEXTLOAD:
    case ISD::NON_EXTLOAD: B = true; break;
    }

    if (B && Constant.isAllOnes()) {
      // If the load type was an EXTLOAD, convert to ZEXTLOAD in order to
      // preserve semantics once we get rid of the AND.
      SDValue NewLoad(Load, 0);

      // Fold the AND away. NewLoad may get replaced immediately.
      CombineTo(N, (N0.getNode() == Load) ? NewLoad : N0);

      if (Load->getExtensionType() == ISD::EXTLOAD) {
        NewLoad = DAG.getLoad(Load->getAddressingMode(), ISD::ZEXTLOAD,
                              Load->getValueType(0), SDLoc(Load),
                              Load->getChain(), Load->getBasePtr(),
                              Load->getOffset(), Load->getMemoryVT(),
                              Load->getMemOperand());
        // Replace uses of the EXTLOAD with the new ZEXTLOAD.
        if (Load->getNumValues() == 3) {
          // PRE/POST_INC loads have 3 values.
          SDValue To[] = { NewLoad.getValue(0), NewLoad.getValue(1),
                           NewLoad.getValue(2) };
          CombineTo(Load, To, 3, true);
        } else {
          CombineTo(Load, NewLoad.getValue(0), NewLoad.getValue(1));
        }
      }

      return SDValue(N, 0); // Return N so it doesn't get rechecked!
    }
  }

  // Try to convert a constant mask AND into a shuffle clear mask.
  if (VT.isVector())
    if (SDValue Shuffle = XformToShuffleWithZero(N))
      return Shuffle;

  if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
    return Combined;

  if (N0.getOpcode() == ISD::EXTRACT_SUBVECTOR && N0.hasOneUse() && N1C &&
      ISD::isExtOpcode(N0.getOperand(0).getOpcode())) {
    SDValue Ext = N0.getOperand(0);
    EVT ExtVT = Ext->getValueType(0);
    SDValue Extendee = Ext->getOperand(0);

    unsigned ScalarWidth = Extendee.getValueType().getScalarSizeInBits();
    if (N1C->getAPIntValue().isMask(ScalarWidth) &&
        (!LegalOperations || TLI.isOperationLegal(ISD::ZERO_EXTEND, ExtVT))) {
      //    (and (extract_subvector (zext|anyext|sext v) _) iN_mask)
      // => (extract_subvector (iN_zeroext v))
      SDValue ZeroExtExtendee =
          DAG.getNode(ISD::ZERO_EXTEND, DL, ExtVT, Extendee);

      return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, ZeroExtExtendee,
                         N0.getOperand(1));
    }
  }

  // fold (and (masked_gather x)) -> (zext_masked_gather x)
  if (auto *GN0 = dyn_cast<MaskedGatherSDNode>(N0)) {
    EVT MemVT = GN0->getMemoryVT();
    EVT ScalarVT = MemVT.getScalarType();

    if (SDValue(GN0, 0).hasOneUse() &&
        isConstantSplatVectorMaskForType(N1.getNode(), ScalarVT) &&
        TLI.isVectorLoadExtDesirable(SDValue(SDValue(GN0, 0)))) {
      SDValue Ops[] = {GN0->getChain(),   GN0->getPassThru(), GN0->getMask(),
                       GN0->getBasePtr(), GN0->getIndex(),    GN0->getScale()};

      SDValue ZExtLoad = DAG.getMaskedGather(
          DAG.getVTList(VT, MVT::Other), MemVT, DL, Ops, GN0->getMemOperand(),
          GN0->getIndexType(), ISD::ZEXTLOAD);

      CombineTo(N, ZExtLoad);
      AddToWorklist(ZExtLoad.getNode());
      // Avoid recheck of N.
      return SDValue(N, 0);
    }
  }

  // fold (and (load x), 255) -> (zextload x, i8)
  // fold (and (extload x, i16), 255) -> (zextload x, i8)
  if (N1C && N0.getOpcode() == ISD::LOAD && !VT.isVector())
    if (SDValue Res = reduceLoadWidth(N))
      return Res;

  if (LegalTypes) {
    // Attempt to propagate the AND back up to the leaves which, if they're
    // loads, can be combined to narrow loads and the AND node can be removed.
    // Perform after legalization so that extend nodes will already be
    // combined into the loads.
    if (BackwardsPropagateMask(N))
      return SDValue(N, 0);
  }

  if (SDValue Combined = visitANDLike(N0, N1, N))
    return Combined;

  // Simplify: (and (op x...), (op y...))  -> (op (and x, y))
  if (N0.getOpcode() == N1.getOpcode())
    if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
      return V;

  // Try the logic-of-shifts fold with the operands in both orders.
  if (SDValue R = foldLogicOfShifts(N, N0, N1, DAG))
    return R;
  if (SDValue R = foldLogicOfShifts(N, N1, N0, DAG))
    return R;

  // Fold (and X, (bswap (not Y))) -> (and X, (not (bswap Y)))
  // Fold (and X, (bitreverse (not Y))) -> (and X, (not (bitreverse Y)))
  SDValue X, Y, Z, NotY;
  for (unsigned Opc : {ISD::BSWAP, ISD::BITREVERSE})
    if (sd_match(N,
                 m_And(m_Value(X), m_OneUse(m_UnaryOp(Opc, m_Value(NotY))))) &&
        sd_match(NotY, m_Not(m_Value(Y))) &&
        (TLI.hasAndNot(SDValue(N, 0)) || NotY->hasOneUse()))
      return DAG.getNode(ISD::AND, DL, VT, X,
                         DAG.getNOT(DL, DAG.getNode(Opc, DL, VT, Y), VT));

  // Fold (and X, (rot (not Y), Z)) -> (and X, (not (rot Y, Z)))
  for (unsigned Opc : {ISD::ROTL, ISD::ROTR})
    if (sd_match(N, m_And(m_Value(X),
                          m_OneUse(m_BinOp(Opc, m_Value(NotY), m_Value(Z))))) &&
        sd_match(NotY, m_Not(m_Value(Y))) &&
        (TLI.hasAndNot(SDValue(N, 0)) || NotY->hasOneUse()))
      return DAG.getNode(ISD::AND, DL, VT, X,
                         DAG.getNOT(DL, DAG.getNode(Opc, DL, VT, Y, Z), VT));

  // Fold (and X, (add (not Y), Z)) -> (and X, (not (sub Y, Z)))
  // Fold (and X, (sub (not Y), Z)) -> (and X, (not (add Y, Z)))
  if (TLI.hasAndNot(SDValue(N, 0)))
    if (SDValue Folded = foldBitwiseOpWithNeg(N, DL, VT))
      return Folded;

  // Fold (and (srl X, C), 1) -> (srl X, BW-1) for signbit extraction
  // If we are shifting down an extended sign bit, see if we can simplify
  // this to shifting the MSB directly to expose further simplifications.
  // This pattern often appears after sext_inreg legalization.
  APInt Amt;
  if (sd_match(N, m_And(m_Srl(m_Value(X), m_ConstInt(Amt)), m_One())) &&
      Amt.ult(BitWidth - 1) && Amt.uge(BitWidth - DAG.ComputeNumSignBits(X)))
    return DAG.getNode(ISD::SRL, DL, VT, X,
                       DAG.getShiftAmountConstant(BitWidth - 1, VT, DL));

  // Masking the negated extension of a boolean is just the zero-extended
  // boolean:
  // and (sub 0, zext(bool X)), 1 --> zext(bool X)
  // and (sub 0, sext(bool X)), 1 --> zext(bool X)
  //
  // Note: the SimplifyDemandedBits fold below can make an information-losing
  // transform, and then we have no way to find this better fold.
  if (sd_match(N, m_And(m_Sub(m_Zero(), m_Value(X)), m_One()))) {
    if (X.getOpcode() == ISD::ZERO_EXTEND &&
        X.getOperand(0).getScalarValueSizeInBits() == 1)
      return X;
    if (X.getOpcode() == ISD::SIGN_EXTEND &&
        X.getOperand(0).getScalarValueSizeInBits() == 1)
      return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, X.getOperand(0));
  }

  // fold (and (sign_extend_inreg x, i16 to i32), 1) -> (and x, 1)
  // fold (and (sra)) -> (and (srl)) when possible.
  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  // fold (zext_inreg (extload x)) -> (zextload x)
  // fold (zext_inreg (sextload x)) -> (zextload x) iff load has one use
  if (ISD::isUNINDEXEDLoad(N0.getNode()) &&
      (ISD::isEXTLoad(N0.getNode()) ||
       (ISD::isSEXTLoad(N0.getNode()) && N0.hasOneUse()))) {
    auto *LN0 = cast<LoadSDNode>(N0);
    EVT MemVT = LN0->getMemoryVT();
    // If we zero all the possible extended bits, then we can turn this into
    // a zextload if we are running before legalize or the operation is legal.
    unsigned ExtBitSize = N1.getScalarValueSizeInBits();
    unsigned MemBitSize = MemVT.getScalarSizeInBits();
    APInt ExtBits = APInt::getHighBitsSet(ExtBitSize, ExtBitSize - MemBitSize);
    if (DAG.MaskedValueIsZero(N1, ExtBits) &&
        ((!LegalOperations && LN0->isSimple()) ||
         TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT))) {
      SDValue ExtLoad =
          DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(N0), VT, LN0->getChain(),
                         LN0->getBasePtr(), MemVT, LN0->getMemOperand());
      AddToWorklist(N);
      CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
      return SDValue(N, 0); // Return N so it doesn't get rechecked!
    }
  }

  // fold (and (or (srl N, 8), (shl N, 8)), 0xffff) -> (srl (bswap N), const)
  if (N1C && N1C->getAPIntValue() == 0xffff && N0.getOpcode() == ISD::OR) {
    if (SDValue BSwap = MatchBSwapHWordLow(N0.getNode(), N0.getOperand(0),
                                           N0.getOperand(1), false))
      return BSwap;
  }

  if (SDValue Shifts = unfoldExtremeBitClearingToShifts(N))
    return Shifts;

  if (SDValue V = combineShiftAnd1ToBitTest(N, DAG))
    return V;

  // Recognize the following pattern:
  //
  // AndVT = (and (sign_extend NarrowVT to AndVT) #bitmask)
  //
  // where bitmask is a mask that clears the upper bits of AndVT. The
  // number of bits in bitmask must be a power of two.
  auto IsAndZeroExtMask = [](SDValue LHS, SDValue RHS) {
    if (LHS->getOpcode() != ISD::SIGN_EXTEND)
      return false;

    // Only a plain scalar constant is accepted here (no splat handling).
    auto *C = dyn_cast<ConstantSDNode>(RHS);
    if (!C)
      return false;

    if (!C->getAPIntValue().isMask(
            LHS.getOperand(0).getValueType().getFixedSizeInBits()))
      return false;

    return true;
  };

  // Replace (and (sign_extend ...) #bitmask) with (zero_extend ...).
  if (IsAndZeroExtMask(N0, N1))
    return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0));

  if (hasOperation(ISD::USUBSAT, VT))
    if (SDValue V = foldAndToUsubsat(N, DAG, DL))
      return V;

  // Postpone until legalization completed to avoid interference with bswap
  // folding
  if (LegalOperations || VT.isVector())
    if (SDValue R = foldLogicTreeOfShifts(N, N0, N1, DAG))
      return R;

  // The masked-merge fold is only attempted for scalar integers wider than i1.
  if (VT.isScalarInteger() && VT != MVT::i1)
    if (SDValue R = foldMaskedMerge(N, DAG, TLI, DL))
      return R;

  return SDValue();
}
7814
7815 /// Match (a >> 8) | (a << 8) as (bswap a) >> 16.
MatchBSwapHWordLow(SDNode * N,SDValue N0,SDValue N1,bool DemandHighBits)7816 SDValue DAGCombiner::MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
7817 bool DemandHighBits) {
7818 if (!LegalOperations)
7819 return SDValue();
7820
7821 EVT VT = N->getValueType(0);
7822 if (VT != MVT::i64 && VT != MVT::i32 && VT != MVT::i16)
7823 return SDValue();
7824 if (!TLI.isOperationLegalOrCustom(ISD::BSWAP, VT))
7825 return SDValue();
7826
7827 // Recognize (and (shl a, 8), 0xff00), (and (srl a, 8), 0xff)
7828 bool LookPassAnd0 = false;
7829 bool LookPassAnd1 = false;
7830 if (N0.getOpcode() == ISD::AND && N0.getOperand(0).getOpcode() == ISD::SRL)
7831 std::swap(N0, N1);
7832 if (N1.getOpcode() == ISD::AND && N1.getOperand(0).getOpcode() == ISD::SHL)
7833 std::swap(N0, N1);
7834 if (N0.getOpcode() == ISD::AND) {
7835 if (!N0->hasOneUse())
7836 return SDValue();
7837 ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
7838 // Also handle 0xffff since the LHS is guaranteed to have zeros there.
7839 // This is needed for X86.
7840 if (!N01C || (N01C->getZExtValue() != 0xFF00 &&
7841 N01C->getZExtValue() != 0xFFFF))
7842 return SDValue();
7843 N0 = N0.getOperand(0);
7844 LookPassAnd0 = true;
7845 }
7846
7847 if (N1.getOpcode() == ISD::AND) {
7848 if (!N1->hasOneUse())
7849 return SDValue();
7850 ConstantSDNode *N11C = dyn_cast<ConstantSDNode>(N1.getOperand(1));
7851 if (!N11C || N11C->getZExtValue() != 0xFF)
7852 return SDValue();
7853 N1 = N1.getOperand(0);
7854 LookPassAnd1 = true;
7855 }
7856
7857 if (N0.getOpcode() == ISD::SRL && N1.getOpcode() == ISD::SHL)
7858 std::swap(N0, N1);
7859 if (N0.getOpcode() != ISD::SHL || N1.getOpcode() != ISD::SRL)
7860 return SDValue();
7861 if (!N0->hasOneUse() || !N1->hasOneUse())
7862 return SDValue();
7863
7864 ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
7865 ConstantSDNode *N11C = dyn_cast<ConstantSDNode>(N1.getOperand(1));
7866 if (!N01C || !N11C)
7867 return SDValue();
7868 if (N01C->getZExtValue() != 8 || N11C->getZExtValue() != 8)
7869 return SDValue();
7870
7871 // Look for (shl (and a, 0xff), 8), (srl (and a, 0xff00), 8)
7872 SDValue N00 = N0->getOperand(0);
7873 if (!LookPassAnd0 && N00.getOpcode() == ISD::AND) {
7874 if (!N00->hasOneUse())
7875 return SDValue();
7876 ConstantSDNode *N001C = dyn_cast<ConstantSDNode>(N00.getOperand(1));
7877 if (!N001C || N001C->getZExtValue() != 0xFF)
7878 return SDValue();
7879 N00 = N00.getOperand(0);
7880 LookPassAnd0 = true;
7881 }
7882
7883 SDValue N10 = N1->getOperand(0);
7884 if (!LookPassAnd1 && N10.getOpcode() == ISD::AND) {
7885 if (!N10->hasOneUse())
7886 return SDValue();
7887 ConstantSDNode *N101C = dyn_cast<ConstantSDNode>(N10.getOperand(1));
7888 // Also allow 0xFFFF since the bits will be shifted out. This is needed
7889 // for X86.
7890 if (!N101C || (N101C->getZExtValue() != 0xFF00 &&
7891 N101C->getZExtValue() != 0xFFFF))
7892 return SDValue();
7893 N10 = N10.getOperand(0);
7894 LookPassAnd1 = true;
7895 }
7896
7897 if (N00 != N10)
7898 return SDValue();
7899
7900 // Make sure everything beyond the low halfword gets set to zero since the SRL
7901 // 16 will clear the top bits.
7902 unsigned OpSizeInBits = VT.getSizeInBits();
7903 if (OpSizeInBits > 16) {
7904 // If the left-shift isn't masked out then the only way this is a bswap is
7905 // if all bits beyond the low 8 are 0. In that case the entire pattern
7906 // reduces to a left shift anyway: leave it for other parts of the combiner.
7907 if (DemandHighBits && !LookPassAnd0)
7908 return SDValue();
7909
7910 // However, if the right shift isn't masked out then it might be because
7911 // it's not needed. See if we can spot that too. If the high bits aren't
7912 // demanded, we only need bits 23:16 to be zero. Otherwise, we need all
7913 // upper bits to be zero.
7914 if (!LookPassAnd1) {
7915 unsigned HighBit = DemandHighBits ? OpSizeInBits : 24;
7916 if (!DAG.MaskedValueIsZero(N10,
7917 APInt::getBitsSet(OpSizeInBits, 16, HighBit)))
7918 return SDValue();
7919 }
7920 }
7921
7922 SDValue Res = DAG.getNode(ISD::BSWAP, SDLoc(N), VT, N00);
7923 if (OpSizeInBits > 16) {
7924 SDLoc DL(N);
7925 Res = DAG.getNode(ISD::SRL, DL, VT, Res,
7926 DAG.getShiftAmountConstant(OpSizeInBits - 16, VT, DL));
7927 }
7928 return Res;
7929 }
7930
7931 /// Return true if the specified node is an element that makes up a 32-bit
7932 /// packed halfword byteswap.
7933 /// ((x & 0x000000ff) << 8) |
7934 /// ((x & 0x0000ff00) >> 8) |
7935 /// ((x & 0x00ff0000) << 8) |
7936 /// ((x & 0xff000000) >> 8)
isBSwapHWordElement(SDValue N,MutableArrayRef<SDNode * > Parts)7937 static bool isBSwapHWordElement(SDValue N, MutableArrayRef<SDNode *> Parts) {
7938 if (!N->hasOneUse())
7939 return false;
7940
7941 unsigned Opc = N.getOpcode();
7942 if (Opc != ISD::AND && Opc != ISD::SHL && Opc != ISD::SRL)
7943 return false;
7944
7945 SDValue N0 = N.getOperand(0);
7946 unsigned Opc0 = N0.getOpcode();
7947 if (Opc0 != ISD::AND && Opc0 != ISD::SHL && Opc0 != ISD::SRL)
7948 return false;
7949
7950 ConstantSDNode *N1C = nullptr;
7951 // SHL or SRL: look upstream for AND mask operand
7952 if (Opc == ISD::AND)
7953 N1C = dyn_cast<ConstantSDNode>(N.getOperand(1));
7954 else if (Opc0 == ISD::AND)
7955 N1C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
7956 if (!N1C)
7957 return false;
7958
7959 unsigned MaskByteOffset;
7960 switch (N1C->getZExtValue()) {
7961 default:
7962 return false;
7963 case 0xFF: MaskByteOffset = 0; break;
7964 case 0xFF00: MaskByteOffset = 1; break;
7965 case 0xFFFF:
7966 // In case demanded bits didn't clear the bits that will be shifted out.
7967 // This is needed for X86.
7968 if (Opc == ISD::SRL || (Opc == ISD::AND && Opc0 == ISD::SHL)) {
7969 MaskByteOffset = 1;
7970 break;
7971 }
7972 return false;
7973 case 0xFF0000: MaskByteOffset = 2; break;
7974 case 0xFF000000: MaskByteOffset = 3; break;
7975 }
7976
7977 // Look for (x & 0xff) << 8 as well as ((x << 8) & 0xff00).
7978 if (Opc == ISD::AND) {
7979 if (MaskByteOffset == 0 || MaskByteOffset == 2) {
7980 // (x >> 8) & 0xff
7981 // (x >> 8) & 0xff0000
7982 if (Opc0 != ISD::SRL)
7983 return false;
7984 ConstantSDNode *C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
7985 if (!C || C->getZExtValue() != 8)
7986 return false;
7987 } else {
7988 // (x << 8) & 0xff00
7989 // (x << 8) & 0xff000000
7990 if (Opc0 != ISD::SHL)
7991 return false;
7992 ConstantSDNode *C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
7993 if (!C || C->getZExtValue() != 8)
7994 return false;
7995 }
7996 } else if (Opc == ISD::SHL) {
7997 // (x & 0xff) << 8
7998 // (x & 0xff0000) << 8
7999 if (MaskByteOffset != 0 && MaskByteOffset != 2)
8000 return false;
8001 ConstantSDNode *C = dyn_cast<ConstantSDNode>(N.getOperand(1));
8002 if (!C || C->getZExtValue() != 8)
8003 return false;
8004 } else { // Opc == ISD::SRL
8005 // (x & 0xff00) >> 8
8006 // (x & 0xff000000) >> 8
8007 if (MaskByteOffset != 1 && MaskByteOffset != 3)
8008 return false;
8009 ConstantSDNode *C = dyn_cast<ConstantSDNode>(N.getOperand(1));
8010 if (!C || C->getZExtValue() != 8)
8011 return false;
8012 }
8013
8014 if (Parts[MaskByteOffset])
8015 return false;
8016
8017 Parts[MaskByteOffset] = N0.getOperand(0).getNode();
8018 return true;
8019 }
8020
8021 // Match 2 elements of a packed halfword bswap.
isBSwapHWordPair(SDValue N,MutableArrayRef<SDNode * > Parts)8022 static bool isBSwapHWordPair(SDValue N, MutableArrayRef<SDNode *> Parts) {
8023 if (N.getOpcode() == ISD::OR)
8024 return isBSwapHWordElement(N.getOperand(0), Parts) &&
8025 isBSwapHWordElement(N.getOperand(1), Parts);
8026
8027 if (N.getOpcode() == ISD::SRL && N.getOperand(0).getOpcode() == ISD::BSWAP) {
8028 ConstantSDNode *C = isConstOrConstSplat(N.getOperand(1));
8029 if (!C || C->getAPIntValue() != 16)
8030 return false;
8031 Parts[0] = Parts[1] = N.getOperand(0).getOperand(0).getNode();
8032 return true;
8033 }
8034
8035 return false;
8036 }
8037
8038 // Match this pattern:
8039 // (or (and (shl (A, 8)), 0xff00ff00), (and (srl (A, 8)), 0x00ff00ff))
8040 // And rewrite this to:
8041 // (rotr (bswap A), 16)
matchBSwapHWordOrAndAnd(const TargetLowering & TLI,SelectionDAG & DAG,SDNode * N,SDValue N0,SDValue N1,EVT VT)8042 static SDValue matchBSwapHWordOrAndAnd(const TargetLowering &TLI,
8043 SelectionDAG &DAG, SDNode *N, SDValue N0,
8044 SDValue N1, EVT VT) {
8045 assert(N->getOpcode() == ISD::OR && VT == MVT::i32 &&
8046 "MatchBSwapHWordOrAndAnd: expecting i32");
8047 if (!TLI.isOperationLegalOrCustom(ISD::ROTR, VT))
8048 return SDValue();
8049 if (N0.getOpcode() != ISD::AND || N1.getOpcode() != ISD::AND)
8050 return SDValue();
8051 // TODO: this is too restrictive; lifting this restriction requires more tests
8052 if (!N0->hasOneUse() || !N1->hasOneUse())
8053 return SDValue();
8054 ConstantSDNode *Mask0 = isConstOrConstSplat(N0.getOperand(1));
8055 ConstantSDNode *Mask1 = isConstOrConstSplat(N1.getOperand(1));
8056 if (!Mask0 || !Mask1)
8057 return SDValue();
8058 if (Mask0->getAPIntValue() != 0xff00ff00 ||
8059 Mask1->getAPIntValue() != 0x00ff00ff)
8060 return SDValue();
8061 SDValue Shift0 = N0.getOperand(0);
8062 SDValue Shift1 = N1.getOperand(0);
8063 if (Shift0.getOpcode() != ISD::SHL || Shift1.getOpcode() != ISD::SRL)
8064 return SDValue();
8065 ConstantSDNode *ShiftAmt0 = isConstOrConstSplat(Shift0.getOperand(1));
8066 ConstantSDNode *ShiftAmt1 = isConstOrConstSplat(Shift1.getOperand(1));
8067 if (!ShiftAmt0 || !ShiftAmt1)
8068 return SDValue();
8069 if (ShiftAmt0->getAPIntValue() != 8 || ShiftAmt1->getAPIntValue() != 8)
8070 return SDValue();
8071 if (Shift0.getOperand(0) != Shift1.getOperand(0))
8072 return SDValue();
8073
8074 SDLoc DL(N);
8075 SDValue BSwap = DAG.getNode(ISD::BSWAP, DL, VT, Shift0.getOperand(0));
8076 SDValue ShAmt = DAG.getShiftAmountConstant(16, VT, DL);
8077 return DAG.getNode(ISD::ROTR, DL, VT, BSwap, ShAmt);
8078 }
8079
8080 /// Match a 32-bit packed halfword bswap. That is
8081 /// ((x & 0x000000ff) << 8) |
8082 /// ((x & 0x0000ff00) >> 8) |
8083 /// ((x & 0x00ff0000) << 8) |
8084 /// ((x & 0xff000000) >> 8)
8085 /// => (rotl (bswap x), 16)
MatchBSwapHWord(SDNode * N,SDValue N0,SDValue N1)8086 SDValue DAGCombiner::MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1) {
8087 if (!LegalOperations)
8088 return SDValue();
8089
8090 EVT VT = N->getValueType(0);
8091 if (VT != MVT::i32)
8092 return SDValue();
8093 if (!TLI.isOperationLegalOrCustom(ISD::BSWAP, VT))
8094 return SDValue();
8095
8096 if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N0, N1, VT))
8097 return BSwap;
8098
8099 // Try again with commuted operands.
8100 if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N1, N0, VT))
8101 return BSwap;
8102
8103
8104 // Look for either
8105 // (or (bswaphpair), (bswaphpair))
8106 // (or (or (bswaphpair), (and)), (and))
8107 // (or (or (and), (bswaphpair)), (and))
8108 SDNode *Parts[4] = {};
8109
8110 if (isBSwapHWordPair(N0, Parts)) {
8111 // (or (or (and), (and)), (or (and), (and)))
8112 if (!isBSwapHWordPair(N1, Parts))
8113 return SDValue();
8114 } else if (N0.getOpcode() == ISD::OR) {
8115 // (or (or (or (and), (and)), (and)), (and))
8116 if (!isBSwapHWordElement(N1, Parts))
8117 return SDValue();
8118 SDValue N00 = N0.getOperand(0);
8119 SDValue N01 = N0.getOperand(1);
8120 if (!(isBSwapHWordElement(N01, Parts) && isBSwapHWordPair(N00, Parts)) &&
8121 !(isBSwapHWordElement(N00, Parts) && isBSwapHWordPair(N01, Parts)))
8122 return SDValue();
8123 } else {
8124 return SDValue();
8125 }
8126
8127 // Make sure the parts are all coming from the same node.
8128 if (Parts[0] != Parts[1] || Parts[0] != Parts[2] || Parts[0] != Parts[3])
8129 return SDValue();
8130
8131 SDLoc DL(N);
8132 SDValue BSwap = DAG.getNode(ISD::BSWAP, DL, VT,
8133 SDValue(Parts[0], 0));
8134
8135 // Result of the bswap should be rotated by 16. If it's not legal, then
8136 // do (x << 16) | (x >> 16).
8137 SDValue ShAmt = DAG.getShiftAmountConstant(16, VT, DL);
8138 if (TLI.isOperationLegalOrCustom(ISD::ROTL, VT))
8139 return DAG.getNode(ISD::ROTL, DL, VT, BSwap, ShAmt);
8140 if (TLI.isOperationLegalOrCustom(ISD::ROTR, VT))
8141 return DAG.getNode(ISD::ROTR, DL, VT, BSwap, ShAmt);
8142 return DAG.getNode(ISD::OR, DL, VT,
8143 DAG.getNode(ISD::SHL, DL, VT, BSwap, ShAmt),
8144 DAG.getNode(ISD::SRL, DL, VT, BSwap, ShAmt));
8145 }
8146
8147 /// This contains all DAGCombine rules which reduce two values combined by
8148 /// an Or operation to a single value \see visitANDLike().
visitORLike(SDValue N0,SDValue N1,const SDLoc & DL)8149 SDValue DAGCombiner::visitORLike(SDValue N0, SDValue N1, const SDLoc &DL) {
8150 EVT VT = N1.getValueType();
8151
8152 // fold (or x, undef) -> -1
8153 if (!LegalOperations && (N0.isUndef() || N1.isUndef()))
8154 return DAG.getAllOnesConstant(DL, VT);
8155
8156 if (SDValue V = foldLogicOfSetCCs(false, N0, N1, DL))
8157 return V;
8158
8159 // (or (and X, C1), (and Y, C2)) -> (and (or X, Y), C3) if possible.
8160 if (N0.getOpcode() == ISD::AND && N1.getOpcode() == ISD::AND &&
8161 // Don't increase # computations.
8162 (N0->hasOneUse() || N1->hasOneUse())) {
8163 // We can only do this xform if we know that bits from X that are set in C2
8164 // but not in C1 are already zero. Likewise for Y.
8165 if (const ConstantSDNode *N0O1C =
8166 getAsNonOpaqueConstant(N0.getOperand(1))) {
8167 if (const ConstantSDNode *N1O1C =
8168 getAsNonOpaqueConstant(N1.getOperand(1))) {
8169 // We can only do this xform if we know that bits from X that are set in
8170 // C2 but not in C1 are already zero. Likewise for Y.
8171 const APInt &LHSMask = N0O1C->getAPIntValue();
8172 const APInt &RHSMask = N1O1C->getAPIntValue();
8173
8174 if (DAG.MaskedValueIsZero(N0.getOperand(0), RHSMask&~LHSMask) &&
8175 DAG.MaskedValueIsZero(N1.getOperand(0), LHSMask&~RHSMask)) {
8176 SDValue X = DAG.getNode(ISD::OR, SDLoc(N0), VT,
8177 N0.getOperand(0), N1.getOperand(0));
8178 return DAG.getNode(ISD::AND, DL, VT, X,
8179 DAG.getConstant(LHSMask | RHSMask, DL, VT));
8180 }
8181 }
8182 }
8183 }
8184
8185 // (or (and X, M), (and X, N)) -> (and X, (or M, N))
8186 if (N0.getOpcode() == ISD::AND &&
8187 N1.getOpcode() == ISD::AND &&
8188 N0.getOperand(0) == N1.getOperand(0) &&
8189 // Don't increase # computations.
8190 (N0->hasOneUse() || N1->hasOneUse())) {
8191 SDValue X = DAG.getNode(ISD::OR, SDLoc(N0), VT,
8192 N0.getOperand(1), N1.getOperand(1));
8193 return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), X);
8194 }
8195
8196 return SDValue();
8197 }
8198
8199 /// OR combines for which the commuted variant will be tried as well.
visitORCommutative(SelectionDAG & DAG,SDValue N0,SDValue N1,SDNode * N)8200 static SDValue visitORCommutative(SelectionDAG &DAG, SDValue N0, SDValue N1,
8201 SDNode *N) {
8202 EVT VT = N0.getValueType();
8203 unsigned BW = VT.getScalarSizeInBits();
8204 SDLoc DL(N);
8205
8206 auto peekThroughResize = [](SDValue V) {
8207 if (V->getOpcode() == ISD::ZERO_EXTEND || V->getOpcode() == ISD::TRUNCATE)
8208 return V->getOperand(0);
8209 return V;
8210 };
8211
8212 SDValue N0Resized = peekThroughResize(N0);
8213 if (N0Resized.getOpcode() == ISD::AND) {
8214 SDValue N1Resized = peekThroughResize(N1);
8215 SDValue N00 = N0Resized.getOperand(0);
8216 SDValue N01 = N0Resized.getOperand(1);
8217
8218 // fold or (and x, y), x --> x
8219 if (N00 == N1Resized || N01 == N1Resized)
8220 return N1;
8221
8222 // fold (or (and X, (xor Y, -1)), Y) -> (or X, Y)
8223 // TODO: Set AllowUndefs = true.
8224 if (SDValue NotOperand = getBitwiseNotOperand(N01, N00,
8225 /* AllowUndefs */ false)) {
8226 if (peekThroughResize(NotOperand) == N1Resized)
8227 return DAG.getNode(ISD::OR, DL, VT, DAG.getZExtOrTrunc(N00, DL, VT),
8228 N1);
8229 }
8230
8231 // fold (or (and (xor Y, -1), X), Y) -> (or X, Y)
8232 if (SDValue NotOperand = getBitwiseNotOperand(N00, N01,
8233 /* AllowUndefs */ false)) {
8234 if (peekThroughResize(NotOperand) == N1Resized)
8235 return DAG.getNode(ISD::OR, DL, VT, DAG.getZExtOrTrunc(N01, DL, VT),
8236 N1);
8237 }
8238 }
8239
8240 SDValue X, Y;
8241
8242 // fold or (xor X, N1), N1 --> or X, N1
8243 if (sd_match(N0, m_Xor(m_Value(X), m_Specific(N1))))
8244 return DAG.getNode(ISD::OR, DL, VT, X, N1);
8245
8246 // fold or (xor x, y), (x and/or y) --> or x, y
8247 if (sd_match(N0, m_Xor(m_Value(X), m_Value(Y))) &&
8248 (sd_match(N1, m_And(m_Specific(X), m_Specific(Y))) ||
8249 sd_match(N1, m_Or(m_Specific(X), m_Specific(Y)))))
8250 return DAG.getNode(ISD::OR, DL, VT, X, Y);
8251
8252 if (SDValue R = foldLogicOfShifts(N, N0, N1, DAG))
8253 return R;
8254
8255 auto peekThroughZext = [](SDValue V) {
8256 if (V->getOpcode() == ISD::ZERO_EXTEND)
8257 return V->getOperand(0);
8258 return V;
8259 };
8260
8261 // (fshl X, ?, Y) | (shl X, Y) --> fshl X, ?, Y
8262 if (N0.getOpcode() == ISD::FSHL && N1.getOpcode() == ISD::SHL &&
8263 N0.getOperand(0) == N1.getOperand(0) &&
8264 peekThroughZext(N0.getOperand(2)) == peekThroughZext(N1.getOperand(1)))
8265 return N0;
8266
8267 // (fshr ?, X, Y) | (srl X, Y) --> fshr ?, X, Y
8268 if (N0.getOpcode() == ISD::FSHR && N1.getOpcode() == ISD::SRL &&
8269 N0.getOperand(1) == N1.getOperand(0) &&
8270 peekThroughZext(N0.getOperand(2)) == peekThroughZext(N1.getOperand(1)))
8271 return N0;
8272
8273 // Attempt to match a legalized build_pair-esque pattern:
8274 // or(shl(aext(Hi),BW/2),zext(Lo))
8275 SDValue Lo, Hi;
8276 if (sd_match(N0,
8277 m_OneUse(m_Shl(m_AnyExt(m_Value(Hi)), m_SpecificInt(BW / 2)))) &&
8278 sd_match(N1, m_ZExt(m_Value(Lo))) &&
8279 Lo.getScalarValueSizeInBits() == (BW / 2) &&
8280 Lo.getValueType() == Hi.getValueType()) {
8281 // Fold build_pair(not(Lo),not(Hi)) -> not(build_pair(Lo,Hi)).
8282 SDValue NotLo, NotHi;
8283 if (sd_match(Lo, m_OneUse(m_Not(m_Value(NotLo)))) &&
8284 sd_match(Hi, m_OneUse(m_Not(m_Value(NotHi))))) {
8285 Lo = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, NotLo);
8286 Hi = DAG.getNode(ISD::ANY_EXTEND, DL, VT, NotHi);
8287 Hi = DAG.getNode(ISD::SHL, DL, VT, Hi,
8288 DAG.getShiftAmountConstant(BW / 2, VT, DL));
8289 return DAG.getNOT(DL, DAG.getNode(ISD::OR, DL, VT, Lo, Hi), VT);
8290 }
8291 }
8292
8293 return SDValue();
8294 }
8295
visitOR(SDNode * N)8296 SDValue DAGCombiner::visitOR(SDNode *N) {
8297 SDValue N0 = N->getOperand(0);
8298 SDValue N1 = N->getOperand(1);
8299 EVT VT = N1.getValueType();
8300 SDLoc DL(N);
8301
8302 // x | x --> x
8303 if (N0 == N1)
8304 return N0;
8305
8306 // fold (or c1, c2) -> c1|c2
8307 if (SDValue C = DAG.FoldConstantArithmetic(ISD::OR, DL, VT, {N0, N1}))
8308 return C;
8309
8310 // canonicalize constant to RHS
8311 if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
8312 !DAG.isConstantIntBuildVectorOrConstantInt(N1))
8313 return DAG.getNode(ISD::OR, DL, VT, N1, N0);
8314
8315 // fold vector ops
8316 if (VT.isVector()) {
8317 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
8318 return FoldedVOp;
8319
8320 // fold (or x, 0) -> x, vector edition
8321 if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
8322 return N0;
8323
8324 // fold (or x, -1) -> -1, vector edition
8325 if (ISD::isConstantSplatVectorAllOnes(N1.getNode()))
8326 // do not return N1, because undef node may exist in N1
8327 return DAG.getAllOnesConstant(DL, N1.getValueType());
8328
8329 // fold (or (shuf A, V_0, MA), (shuf B, V_0, MB)) -> (shuf A, B, Mask)
8330 // Do this only if the resulting type / shuffle is legal.
8331 auto *SV0 = dyn_cast<ShuffleVectorSDNode>(N0);
8332 auto *SV1 = dyn_cast<ShuffleVectorSDNode>(N1);
8333 if (SV0 && SV1 && TLI.isTypeLegal(VT)) {
8334 bool ZeroN00 = ISD::isBuildVectorAllZeros(N0.getOperand(0).getNode());
8335 bool ZeroN01 = ISD::isBuildVectorAllZeros(N0.getOperand(1).getNode());
8336 bool ZeroN10 = ISD::isBuildVectorAllZeros(N1.getOperand(0).getNode());
8337 bool ZeroN11 = ISD::isBuildVectorAllZeros(N1.getOperand(1).getNode());
8338 // Ensure both shuffles have a zero input.
8339 if ((ZeroN00 != ZeroN01) && (ZeroN10 != ZeroN11)) {
8340 assert((!ZeroN00 || !ZeroN01) && "Both inputs zero!");
8341 assert((!ZeroN10 || !ZeroN11) && "Both inputs zero!");
8342 bool CanFold = true;
8343 int NumElts = VT.getVectorNumElements();
8344 SmallVector<int, 4> Mask(NumElts, -1);
8345
8346 for (int i = 0; i != NumElts; ++i) {
8347 int M0 = SV0->getMaskElt(i);
8348 int M1 = SV1->getMaskElt(i);
8349
8350 // Determine if either index is pointing to a zero vector.
8351 bool M0Zero = M0 < 0 || (ZeroN00 == (M0 < NumElts));
8352 bool M1Zero = M1 < 0 || (ZeroN10 == (M1 < NumElts));
8353
8354 // If one element is zero and the otherside is undef, keep undef.
8355 // This also handles the case that both are undef.
8356 if ((M0Zero && M1 < 0) || (M1Zero && M0 < 0))
8357 continue;
8358
8359 // Make sure only one of the elements is zero.
8360 if (M0Zero == M1Zero) {
8361 CanFold = false;
8362 break;
8363 }
8364
8365 assert((M0 >= 0 || M1 >= 0) && "Undef index!");
8366
8367 // We have a zero and non-zero element. If the non-zero came from
8368 // SV0 make the index a LHS index. If it came from SV1, make it
8369 // a RHS index. We need to mod by NumElts because we don't care
8370 // which operand it came from in the original shuffles.
8371 Mask[i] = M1Zero ? M0 % NumElts : (M1 % NumElts) + NumElts;
8372 }
8373
8374 if (CanFold) {
8375 SDValue NewLHS = ZeroN00 ? N0.getOperand(1) : N0.getOperand(0);
8376 SDValue NewRHS = ZeroN10 ? N1.getOperand(1) : N1.getOperand(0);
8377 SDValue LegalShuffle =
8378 TLI.buildLegalVectorShuffle(VT, DL, NewLHS, NewRHS, Mask, DAG);
8379 if (LegalShuffle)
8380 return LegalShuffle;
8381 }
8382 }
8383 }
8384 }
8385
8386 // fold (or x, 0) -> x
8387 if (isNullConstant(N1))
8388 return N0;
8389
8390 // fold (or x, -1) -> -1
8391 if (isAllOnesConstant(N1))
8392 return N1;
8393
8394 if (SDValue NewSel = foldBinOpIntoSelect(N))
8395 return NewSel;
8396
8397 // fold (or x, c) -> c iff (x & ~c) == 0
8398 ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
8399 if (N1C && DAG.MaskedValueIsZero(N0, ~N1C->getAPIntValue()))
8400 return N1;
8401
8402 if (SDValue R = foldAndOrOfSETCC(N, DAG))
8403 return R;
8404
8405 if (SDValue Combined = visitORLike(N0, N1, DL))
8406 return Combined;
8407
8408 if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
8409 return Combined;
8410
8411 // Recognize halfword bswaps as (bswap + rotl 16) or (bswap + shl 16)
8412 if (SDValue BSwap = MatchBSwapHWord(N, N0, N1))
8413 return BSwap;
8414 if (SDValue BSwap = MatchBSwapHWordLow(N, N0, N1))
8415 return BSwap;
8416
8417 // reassociate or
8418 if (SDValue ROR = reassociateOps(ISD::OR, DL, N0, N1, N->getFlags()))
8419 return ROR;
8420
8421 // Fold or(vecreduce(x), vecreduce(y)) -> vecreduce(or(x, y))
8422 if (SDValue SD =
8423 reassociateReduction(ISD::VECREDUCE_OR, ISD::OR, DL, VT, N0, N1))
8424 return SD;
8425
8426 // Canonicalize (or (and X, c1), c2) -> (and (or X, c2), c1|c2)
8427 // iff (c1 & c2) != 0 or c1/c2 are undef.
8428 auto MatchIntersect = [](ConstantSDNode *C1, ConstantSDNode *C2) {
8429 return !C1 || !C2 || C1->getAPIntValue().intersects(C2->getAPIntValue());
8430 };
8431 if (N0.getOpcode() == ISD::AND && N0->hasOneUse() &&
8432 ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchIntersect, true)) {
8433 if (SDValue COR = DAG.FoldConstantArithmetic(ISD::OR, SDLoc(N1), VT,
8434 {N1, N0.getOperand(1)})) {
8435 SDValue IOR = DAG.getNode(ISD::OR, SDLoc(N0), VT, N0.getOperand(0), N1);
8436 AddToWorklist(IOR.getNode());
8437 return DAG.getNode(ISD::AND, DL, VT, COR, IOR);
8438 }
8439 }
8440
8441 if (SDValue Combined = visitORCommutative(DAG, N0, N1, N))
8442 return Combined;
8443 if (SDValue Combined = visitORCommutative(DAG, N1, N0, N))
8444 return Combined;
8445
8446 // Simplify: (or (op x...), (op y...)) -> (op (or x, y))
8447 if (N0.getOpcode() == N1.getOpcode())
8448 if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
8449 return V;
8450
8451 // See if this is some rotate idiom.
8452 if (SDValue Rot = MatchRotate(N0, N1, DL, /*FromAdd=*/false))
8453 return Rot;
8454
8455 if (SDValue Load = MatchLoadCombine(N))
8456 return Load;
8457
8458 // Simplify the operands using demanded-bits information.
8459 if (SimplifyDemandedBits(SDValue(N, 0)))
8460 return SDValue(N, 0);
8461
8462 // If OR can be rewritten into ADD, try combines based on ADD.
8463 if ((!LegalOperations || TLI.isOperationLegal(ISD::ADD, VT)) &&
8464 DAG.isADDLike(SDValue(N, 0)))
8465 if (SDValue Combined = visitADDLike(N))
8466 return Combined;
8467
8468 // Postpone until legalization completed to avoid interference with bswap
8469 // folding
8470 if (LegalOperations || VT.isVector())
8471 if (SDValue R = foldLogicTreeOfShifts(N, N0, N1, DAG))
8472 return R;
8473
8474 if (VT.isScalarInteger() && VT != MVT::i1)
8475 if (SDValue R = foldMaskedMerge(N, DAG, TLI, DL))
8476 return R;
8477
8478 return SDValue();
8479 }
8480
stripConstantMask(const SelectionDAG & DAG,SDValue Op,SDValue & Mask)8481 static SDValue stripConstantMask(const SelectionDAG &DAG, SDValue Op,
8482 SDValue &Mask) {
8483 if (Op.getOpcode() == ISD::AND &&
8484 DAG.isConstantIntBuildVectorOrConstantInt(Op.getOperand(1))) {
8485 Mask = Op.getOperand(1);
8486 return Op.getOperand(0);
8487 }
8488 return Op;
8489 }
8490
8491 /// Match "(X shl/srl V1) & V2" where V2 may not be present.
matchRotateHalf(const SelectionDAG & DAG,SDValue Op,SDValue & Shift,SDValue & Mask)8492 static bool matchRotateHalf(const SelectionDAG &DAG, SDValue Op, SDValue &Shift,
8493 SDValue &Mask) {
8494 Op = stripConstantMask(DAG, Op, Mask);
8495 if (Op.getOpcode() == ISD::SRL || Op.getOpcode() == ISD::SHL) {
8496 Shift = Op;
8497 return true;
8498 }
8499 return false;
8500 }
8501
8502 /// Helper function for visitOR to extract the needed side of a rotate idiom
8503 /// from a shl/srl/mul/udiv. This is meant to handle cases where
8504 /// InstCombine merged some outside op with one of the shifts from
8505 /// the rotate pattern.
8506 /// \returns An empty \c SDValue if the needed shift couldn't be extracted.
8507 /// Otherwise, returns an expansion of \p ExtractFrom based on the following
8508 /// patterns:
8509 ///
8510 /// (or (add v v) (shrl v bitwidth-1)):
8511 /// expands (add v v) -> (shl v 1)
8512 ///
8513 /// (or (mul v c0) (shrl (mul v c1) c2)):
8514 /// expands (mul v c0) -> (shl (mul v c1) c3)
8515 ///
8516 /// (or (udiv v c0) (shl (udiv v c1) c2)):
8517 /// expands (udiv v c0) -> (shrl (udiv v c1) c3)
8518 ///
8519 /// (or (shl v c0) (shrl (shl v c1) c2)):
8520 /// expands (shl v c0) -> (shl (shl v c1) c3)
8521 ///
8522 /// (or (shrl v c0) (shl (shrl v c1) c2)):
8523 /// expands (shrl v c0) -> (shrl (shrl v c1) c3)
8524 ///
8525 /// Such that in all cases, c3+c2==bitwidth(op v c1).
extractShiftForRotate(SelectionDAG & DAG,SDValue OppShift,SDValue ExtractFrom,SDValue & Mask,const SDLoc & DL)8526 static SDValue extractShiftForRotate(SelectionDAG &DAG, SDValue OppShift,
8527 SDValue ExtractFrom, SDValue &Mask,
8528 const SDLoc &DL) {
8529 assert(OppShift && ExtractFrom && "Empty SDValue");
8530 if (OppShift.getOpcode() != ISD::SHL && OppShift.getOpcode() != ISD::SRL)
8531 return SDValue();
8532
8533 ExtractFrom = stripConstantMask(DAG, ExtractFrom, Mask);
8534
8535 // Value and Type of the shift.
8536 SDValue OppShiftLHS = OppShift.getOperand(0);
8537 EVT ShiftedVT = OppShiftLHS.getValueType();
8538
8539 // Amount of the existing shift.
8540 ConstantSDNode *OppShiftCst = isConstOrConstSplat(OppShift.getOperand(1));
8541
8542 // (add v v) -> (shl v 1)
8543 // TODO: Should this be a general DAG canonicalization?
8544 if (OppShift.getOpcode() == ISD::SRL && OppShiftCst &&
8545 ExtractFrom.getOpcode() == ISD::ADD &&
8546 ExtractFrom.getOperand(0) == ExtractFrom.getOperand(1) &&
8547 ExtractFrom.getOperand(0) == OppShiftLHS &&
8548 OppShiftCst->getAPIntValue() == ShiftedVT.getScalarSizeInBits() - 1)
8549 return DAG.getNode(ISD::SHL, DL, ShiftedVT, OppShiftLHS,
8550 DAG.getShiftAmountConstant(1, ShiftedVT, DL));
8551
8552 // Preconditions:
8553 // (or (op0 v c0) (shiftl/r (op0 v c1) c2))
8554 //
8555 // Find opcode of the needed shift to be extracted from (op0 v c0).
8556 unsigned Opcode = ISD::DELETED_NODE;
8557 bool IsMulOrDiv = false;
8558 // Set Opcode and IsMulOrDiv if the extract opcode matches the needed shift
8559 // opcode or its arithmetic (mul or udiv) variant.
8560 auto SelectOpcode = [&](unsigned NeededShift, unsigned MulOrDivVariant) {
8561 IsMulOrDiv = ExtractFrom.getOpcode() == MulOrDivVariant;
8562 if (!IsMulOrDiv && ExtractFrom.getOpcode() != NeededShift)
8563 return false;
8564 Opcode = NeededShift;
8565 return true;
8566 };
8567 // op0 must be either the needed shift opcode or the mul/udiv equivalent
8568 // that the needed shift can be extracted from.
8569 if ((OppShift.getOpcode() != ISD::SRL || !SelectOpcode(ISD::SHL, ISD::MUL)) &&
8570 (OppShift.getOpcode() != ISD::SHL || !SelectOpcode(ISD::SRL, ISD::UDIV)))
8571 return SDValue();
8572
8573 // op0 must be the same opcode on both sides, have the same LHS argument,
8574 // and produce the same value type.
8575 if (OppShiftLHS.getOpcode() != ExtractFrom.getOpcode() ||
8576 OppShiftLHS.getOperand(0) != ExtractFrom.getOperand(0) ||
8577 ShiftedVT != ExtractFrom.getValueType())
8578 return SDValue();
8579
8580 // Constant mul/udiv/shift amount from the RHS of the shift's LHS op.
8581 ConstantSDNode *OppLHSCst = isConstOrConstSplat(OppShiftLHS.getOperand(1));
8582 // Constant mul/udiv/shift amount from the RHS of the ExtractFrom op.
8583 ConstantSDNode *ExtractFromCst =
8584 isConstOrConstSplat(ExtractFrom.getOperand(1));
8585 // TODO: We should be able to handle non-uniform constant vectors for these values
8586 // Check that we have constant values.
8587 if (!OppShiftCst || !OppShiftCst->getAPIntValue() ||
8588 !OppLHSCst || !OppLHSCst->getAPIntValue() ||
8589 !ExtractFromCst || !ExtractFromCst->getAPIntValue())
8590 return SDValue();
8591
8592 // Compute the shift amount we need to extract to complete the rotate.
8593 const unsigned VTWidth = ShiftedVT.getScalarSizeInBits();
8594 if (OppShiftCst->getAPIntValue().ugt(VTWidth))
8595 return SDValue();
8596 APInt NeededShiftAmt = VTWidth - OppShiftCst->getAPIntValue();
8597 // Normalize the bitwidth of the two mul/udiv/shift constant operands.
8598 APInt ExtractFromAmt = ExtractFromCst->getAPIntValue();
8599 APInt OppLHSAmt = OppLHSCst->getAPIntValue();
8600 zeroExtendToMatch(ExtractFromAmt, OppLHSAmt);
8601
8602 // Now try extract the needed shift from the ExtractFrom op and see if the
8603 // result matches up with the existing shift's LHS op.
8604 if (IsMulOrDiv) {
8605 // Op to extract from is a mul or udiv by a constant.
8606 // Check:
8607 // c2 / (1 << (bitwidth(op0 v c0) - c1)) == c0
8608 // c2 % (1 << (bitwidth(op0 v c0) - c1)) == 0
8609 const APInt ExtractDiv = APInt::getOneBitSet(ExtractFromAmt.getBitWidth(),
8610 NeededShiftAmt.getZExtValue());
8611 APInt ResultAmt;
8612 APInt Rem;
8613 APInt::udivrem(ExtractFromAmt, ExtractDiv, ResultAmt, Rem);
8614 if (Rem != 0 || ResultAmt != OppLHSAmt)
8615 return SDValue();
8616 } else {
8617 // Op to extract from is a shift by a constant.
8618 // Check:
8619 // c2 - (bitwidth(op0 v c0) - c1) == c0
8620 if (OppLHSAmt != ExtractFromAmt - NeededShiftAmt.zextOrTrunc(
8621 ExtractFromAmt.getBitWidth()))
8622 return SDValue();
8623 }
8624
8625 // Return the expanded shift op that should allow a rotate to be formed.
8626 EVT ShiftVT = OppShift.getOperand(1).getValueType();
8627 EVT ResVT = ExtractFrom.getValueType();
8628 SDValue NewShiftNode = DAG.getConstant(NeededShiftAmt, DL, ShiftVT);
8629 return DAG.getNode(Opcode, DL, ResVT, OppShiftLHS, NewShiftNode);
8630 }
8631
8632 // Return true if we can prove that, whenever Neg and Pos are both in the
8633 // range [0, EltSize), Neg == (Pos == 0 ? 0 : EltSize - Pos). This means that
8634 // for two opposing shifts shift1 and shift2 and a value X with OpBits bits:
8635 //
8636 // (or (shift1 X, Neg), (shift2 X, Pos))
8637 //
8638 // reduces to a rotate in direction shift2 by Pos or (equivalently) a rotate
8639 // in direction shift1 by Neg. The range [0, EltSize) means that we only need
8640 // to consider shift amounts with defined behavior.
8641 //
8642 // The IsRotate flag should be set when the LHS of both shifts is the same.
8643 // Otherwise if matching a general funnel shift, it should be clear.
matchRotateSub(SDValue Pos,SDValue Neg,unsigned EltSize,SelectionDAG & DAG,bool IsRotate,bool FromAdd)8644 static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize,
8645 SelectionDAG &DAG, bool IsRotate, bool FromAdd) {
8646 const auto &TLI = DAG.getTargetLoweringInfo();
8647 // If EltSize is a power of 2 then:
8648 //
8649 // (a) (Pos == 0 ? 0 : EltSize - Pos) == (EltSize - Pos) & (EltSize - 1)
8650 // (b) Neg == Neg & (EltSize - 1) whenever Neg is in [0, EltSize).
8651 //
8652 // So if EltSize is a power of 2 and Neg is (and Neg', EltSize-1), we check
8653 // for the stronger condition:
8654 //
8655 // Neg & (EltSize - 1) == (EltSize - Pos) & (EltSize - 1) [A]
8656 //
8657 // for all Neg and Pos. Since Neg & (EltSize - 1) == Neg' & (EltSize - 1)
8658 // we can just replace Neg with Neg' for the rest of the function.
8659 //
8660 // In other cases we check for the even stronger condition:
8661 //
8662 // Neg == EltSize - Pos [B]
8663 //
8664 // for all Neg and Pos. Note that the (or ...) then invokes undefined
8665 // behavior if Pos == 0 (and consequently Neg == EltSize).
8666 //
8667 // We could actually use [A] whenever EltSize is a power of 2, but the
8668 // only extra cases that it would match are those uninteresting ones
8669 // where Neg and Pos are never in range at the same time. E.g. for
8670 // EltSize == 32, using [A] would allow a Neg of the form (sub 64, Pos)
8671 // as well as (sub 32, Pos), but:
8672 //
8673 // (or (shift1 X, (sub 64, Pos)), (shift2 X, Pos))
8674 //
8675 // always invokes undefined behavior for 32-bit X.
8676 //
8677 // Below, Mask == EltSize - 1 when using [A] and is all-ones otherwise.
8678 // This allows us to peek through any operations that only affect Mask's
8679 // un-demanded bits.
8680 //
8681 // NOTE: We can only do this when matching operations which won't modify the
8682 // least Log2(EltSize) significant bits and not a general funnel shift.
8683 unsigned MaskLoBits = 0;
8684 if (IsRotate && !FromAdd && isPowerOf2_64(EltSize)) {
8685 unsigned Bits = Log2_64(EltSize);
8686 unsigned NegBits = Neg.getScalarValueSizeInBits();
8687 if (NegBits >= Bits) {
8688 APInt DemandedBits = APInt::getLowBitsSet(NegBits, Bits);
8689 if (SDValue Inner =
8690 TLI.SimplifyMultipleUseDemandedBits(Neg, DemandedBits, DAG)) {
8691 Neg = Inner;
8692 MaskLoBits = Bits;
8693 }
8694 }
8695 }
8696
8697 // Check whether Neg has the form (sub NegC, NegOp1) for some NegC and NegOp1.
8698 if (Neg.getOpcode() != ISD::SUB)
8699 return false;
8700 ConstantSDNode *NegC = isConstOrConstSplat(Neg.getOperand(0));
8701 if (!NegC)
8702 return false;
8703 SDValue NegOp1 = Neg.getOperand(1);
8704
8705 // On the RHS of [A], if Pos is the result of operation on Pos' that won't
8706 // affect Mask's demanded bits, just replace Pos with Pos'. These operations
8707 // are redundant for the purpose of the equality.
8708 if (MaskLoBits) {
8709 unsigned PosBits = Pos.getScalarValueSizeInBits();
8710 if (PosBits >= MaskLoBits) {
8711 APInt DemandedBits = APInt::getLowBitsSet(PosBits, MaskLoBits);
8712 if (SDValue Inner =
8713 TLI.SimplifyMultipleUseDemandedBits(Pos, DemandedBits, DAG)) {
8714 Pos = Inner;
8715 }
8716 }
8717 }
8718
8719 // The condition we need is now:
8720 //
8721 // (NegC - NegOp1) & Mask == (EltSize - Pos) & Mask
8722 //
8723 // If NegOp1 == Pos then we need:
8724 //
8725 // EltSize & Mask == NegC & Mask
8726 //
8727 // (because "x & Mask" is a truncation and distributes through subtraction).
8728 //
8729 // We also need to account for a potential truncation of NegOp1 if the amount
8730 // has already been legalized to a shift amount type.
8731 APInt Width;
8732 if ((Pos == NegOp1) ||
8733 (NegOp1.getOpcode() == ISD::TRUNCATE && Pos == NegOp1.getOperand(0)))
8734 Width = NegC->getAPIntValue();
8735
8736 // Check for cases where Pos has the form (add NegOp1, PosC) for some PosC.
8737 // Then the condition we want to prove becomes:
8738 //
8739 // (NegC - NegOp1) & Mask == (EltSize - (NegOp1 + PosC)) & Mask
8740 //
8741 // which, again because "x & Mask" is a truncation, becomes:
8742 //
8743 // NegC & Mask == (EltSize - PosC) & Mask
8744 // EltSize & Mask == (NegC + PosC) & Mask
8745 else if (Pos.getOpcode() == ISD::ADD && Pos.getOperand(0) == NegOp1) {
8746 if (ConstantSDNode *PosC = isConstOrConstSplat(Pos.getOperand(1)))
8747 Width = PosC->getAPIntValue() + NegC->getAPIntValue();
8748 else
8749 return false;
8750 } else
8751 return false;
8752
8753 // Now we just need to check that EltSize & Mask == Width & Mask.
8754 if (MaskLoBits)
8755 // EltSize & Mask is 0 since Mask is EltSize - 1.
8756 return Width.getLoBits(MaskLoBits) == 0;
8757 return Width == EltSize;
8758 }
8759
8760 // A subroutine of MatchRotate used once we have found an OR of two opposite
8761 // shifts of Shifted. If Neg == <operand size> - Pos then the OR reduces
8762 // to both (PosOpcode Shifted, Pos) and (NegOpcode Shifted, Neg), with the
8763 // former being preferred if supported. InnerPos and InnerNeg are Pos and
8764 // Neg with outer conversions stripped away.
MatchRotatePosNeg(SDValue Shifted,SDValue Pos,SDValue Neg,SDValue InnerPos,SDValue InnerNeg,bool FromAdd,bool HasPos,unsigned PosOpcode,unsigned NegOpcode,const SDLoc & DL)8765 SDValue DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos,
8766 SDValue Neg, SDValue InnerPos,
8767 SDValue InnerNeg, bool FromAdd,
8768 bool HasPos, unsigned PosOpcode,
8769 unsigned NegOpcode, const SDLoc &DL) {
8770 // fold (or/add (shl x, (*ext y)),
8771 // (srl x, (*ext (sub 32, y)))) ->
8772 // (rotl x, y) or (rotr x, (sub 32, y))
8773 //
8774 // fold (or/add (shl x, (*ext (sub 32, y))),
8775 // (srl x, (*ext y))) ->
8776 // (rotr x, y) or (rotl x, (sub 32, y))
8777 EVT VT = Shifted.getValueType();
8778 if (matchRotateSub(InnerPos, InnerNeg, VT.getScalarSizeInBits(), DAG,
8779 /*IsRotate*/ true, FromAdd))
8780 return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, Shifted,
8781 HasPos ? Pos : Neg);
8782
8783 return SDValue();
8784 }
8785
8786 // A subroutine of MatchRotate used once we have found an OR of two opposite
8787 // shifts of N0 + N1. If Neg == <operand size> - Pos then the OR reduces
8788 // to both (PosOpcode N0, N1, Pos) and (NegOpcode N0, N1, Neg), with the
8789 // former being preferred if supported. InnerPos and InnerNeg are Pos and
8790 // Neg with outer conversions stripped away.
8791 // TODO: Merge with MatchRotatePosNeg.
MatchFunnelPosNeg(SDValue N0,SDValue N1,SDValue Pos,SDValue Neg,SDValue InnerPos,SDValue InnerNeg,bool FromAdd,bool HasPos,unsigned PosOpcode,unsigned NegOpcode,const SDLoc & DL)8792 SDValue DAGCombiner::MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos,
8793 SDValue Neg, SDValue InnerPos,
8794 SDValue InnerNeg, bool FromAdd,
8795 bool HasPos, unsigned PosOpcode,
8796 unsigned NegOpcode, const SDLoc &DL) {
8797 EVT VT = N0.getValueType();
8798 unsigned EltBits = VT.getScalarSizeInBits();
8799
8800 // fold (or/add (shl x0, (*ext y)),
8801 // (srl x1, (*ext (sub 32, y)))) ->
8802 // (fshl x0, x1, y) or (fshr x0, x1, (sub 32, y))
8803 //
8804 // fold (or/add (shl x0, (*ext (sub 32, y))),
8805 // (srl x1, (*ext y))) ->
8806 // (fshr x0, x1, y) or (fshl x0, x1, (sub 32, y))
8807 if (matchRotateSub(InnerPos, InnerNeg, EltBits, DAG, /*IsRotate*/ N0 == N1,
8808 FromAdd))
8809 return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, N0, N1,
8810 HasPos ? Pos : Neg);
8811
8812 // Matching the shift+xor cases, we can't easily use the xor'd shift amount
8813 // so for now just use the PosOpcode case if its legal.
8814 // TODO: When can we use the NegOpcode case?
8815 if (PosOpcode == ISD::FSHL && isPowerOf2_32(EltBits)) {
8816 SDValue X;
8817 // fold (or/add (shl x0, y), (srl (srl x1, 1), (xor y, 31)))
8818 // -> (fshl x0, x1, y)
8819 if (sd_match(N1, m_Srl(m_Value(X), m_One())) &&
8820 sd_match(InnerNeg,
8821 m_Xor(m_Specific(InnerPos), m_SpecificInt(EltBits - 1))) &&
8822 TLI.isOperationLegalOrCustom(ISD::FSHL, VT)) {
8823 return DAG.getNode(ISD::FSHL, DL, VT, N0, X, Pos);
8824 }
8825
8826 // fold (or/add (shl (shl x0, 1), (xor y, 31)), (srl x1, y))
8827 // -> (fshr x0, x1, y)
8828 if (sd_match(N0, m_Shl(m_Value(X), m_One())) &&
8829 sd_match(InnerPos,
8830 m_Xor(m_Specific(InnerNeg), m_SpecificInt(EltBits - 1))) &&
8831 TLI.isOperationLegalOrCustom(ISD::FSHR, VT)) {
8832 return DAG.getNode(ISD::FSHR, DL, VT, X, N1, Neg);
8833 }
8834
8835 // fold (or/add (shl (add x0, x0), (xor y, 31)), (srl x1, y))
8836 // -> (fshr x0, x1, y)
8837 // TODO: Should add(x,x) -> shl(x,1) be a general DAG canonicalization?
8838 if (sd_match(N0, m_Add(m_Value(X), m_Deferred(X))) &&
8839 sd_match(InnerPos,
8840 m_Xor(m_Specific(InnerNeg), m_SpecificInt(EltBits - 1))) &&
8841 TLI.isOperationLegalOrCustom(ISD::FSHR, VT)) {
8842 return DAG.getNode(ISD::FSHR, DL, VT, X, N1, Neg);
8843 }
8844 }
8845
8846 return SDValue();
8847 }
8848
8849 // MatchRotate - Handle an 'or' or 'add' of two operands. If this is one of the
8850 // many idioms for rotate, and if the target supports rotation instructions,
8851 // generate a rot[lr]. This also matches funnel shift patterns, similar to
8852 // rotation but with different shifted sources.
MatchRotate(SDValue LHS,SDValue RHS,const SDLoc & DL,bool FromAdd)8853 SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL,
8854 bool FromAdd) {
8855 EVT VT = LHS.getValueType();
8856
8857 // The target must have at least one rotate/funnel flavor.
8858 // We still try to match rotate by constant pre-legalization.
8859 // TODO: Support pre-legalization funnel-shift by constant.
8860 bool HasROTL = hasOperation(ISD::ROTL, VT);
8861 bool HasROTR = hasOperation(ISD::ROTR, VT);
8862 bool HasFSHL = hasOperation(ISD::FSHL, VT);
8863 bool HasFSHR = hasOperation(ISD::FSHR, VT);
8864
8865 // If the type is going to be promoted and the target has enabled custom
8866 // lowering for rotate, allow matching rotate by non-constants. Only allow
8867 // this for scalar types.
8868 if (VT.isScalarInteger() && TLI.getTypeAction(*DAG.getContext(), VT) ==
8869 TargetLowering::TypePromoteInteger) {
8870 HasROTL |= TLI.getOperationAction(ISD::ROTL, VT) == TargetLowering::Custom;
8871 HasROTR |= TLI.getOperationAction(ISD::ROTR, VT) == TargetLowering::Custom;
8872 }
8873
8874 if (LegalOperations && !HasROTL && !HasROTR && !HasFSHL && !HasFSHR)
8875 return SDValue();
8876
8877 // Check for truncated rotate.
8878 if (LHS.getOpcode() == ISD::TRUNCATE && RHS.getOpcode() == ISD::TRUNCATE &&
8879 LHS.getOperand(0).getValueType() == RHS.getOperand(0).getValueType()) {
8880 assert(LHS.getValueType() == RHS.getValueType());
8881 if (SDValue Rot =
8882 MatchRotate(LHS.getOperand(0), RHS.getOperand(0), DL, FromAdd))
8883 return DAG.getNode(ISD::TRUNCATE, SDLoc(LHS), LHS.getValueType(), Rot);
8884 }
8885
8886 // Match "(X shl/srl V1) & V2" where V2 may not be present.
8887 SDValue LHSShift; // The shift.
8888 SDValue LHSMask; // AND value if any.
8889 matchRotateHalf(DAG, LHS, LHSShift, LHSMask);
8890
8891 SDValue RHSShift; // The shift.
8892 SDValue RHSMask; // AND value if any.
8893 matchRotateHalf(DAG, RHS, RHSShift, RHSMask);
8894
8895 // If neither side matched a rotate half, bail
8896 if (!LHSShift && !RHSShift)
8897 return SDValue();
8898
8899 // InstCombine may have combined a constant shl, srl, mul, or udiv with one
8900 // side of the rotate, so try to handle that here. In all cases we need to
8901 // pass the matched shift from the opposite side to compute the opcode and
8902 // needed shift amount to extract. We still want to do this if both sides
8903 // matched a rotate half because one half may be a potential overshift that
8904 // can be broken down (ie if InstCombine merged two shl or srl ops into a
8905 // single one).
8906
8907 // Have LHS side of the rotate, try to extract the needed shift from the RHS.
8908 if (LHSShift)
8909 if (SDValue NewRHSShift =
8910 extractShiftForRotate(DAG, LHSShift, RHS, RHSMask, DL))
8911 RHSShift = NewRHSShift;
8912 // Have RHS side of the rotate, try to extract the needed shift from the LHS.
8913 if (RHSShift)
8914 if (SDValue NewLHSShift =
8915 extractShiftForRotate(DAG, RHSShift, LHS, LHSMask, DL))
8916 LHSShift = NewLHSShift;
8917
8918 // If a side is still missing, nothing else we can do.
8919 if (!RHSShift || !LHSShift)
8920 return SDValue();
8921
8922 // At this point we've matched or extracted a shift op on each side.
8923
8924 if (LHSShift.getOpcode() == RHSShift.getOpcode())
8925 return SDValue(); // Shifts must disagree.
8926
8927 // Canonicalize shl to left side in a shl/srl pair.
8928 if (RHSShift.getOpcode() == ISD::SHL) {
8929 std::swap(LHS, RHS);
8930 std::swap(LHSShift, RHSShift);
8931 std::swap(LHSMask, RHSMask);
8932 }
8933
8934 // Something has gone wrong - we've lost the shl/srl pair - bail.
8935 if (LHSShift.getOpcode() != ISD::SHL || RHSShift.getOpcode() != ISD::SRL)
8936 return SDValue();
8937
8938 unsigned EltSizeInBits = VT.getScalarSizeInBits();
8939 SDValue LHSShiftArg = LHSShift.getOperand(0);
8940 SDValue LHSShiftAmt = LHSShift.getOperand(1);
8941 SDValue RHSShiftArg = RHSShift.getOperand(0);
8942 SDValue RHSShiftAmt = RHSShift.getOperand(1);
8943
8944 auto MatchRotateSum = [EltSizeInBits](ConstantSDNode *LHS,
8945 ConstantSDNode *RHS) {
8946 return (LHS->getAPIntValue() + RHS->getAPIntValue()) == EltSizeInBits;
8947 };
8948
8949 auto ApplyMasks = [&](SDValue Res) {
8950 // If there is an AND of either shifted operand, apply it to the result.
8951 if (LHSMask.getNode() || RHSMask.getNode()) {
8952 SDValue AllOnes = DAG.getAllOnesConstant(DL, VT);
8953 SDValue Mask = AllOnes;
8954
8955 if (LHSMask.getNode()) {
8956 SDValue RHSBits = DAG.getNode(ISD::SRL, DL, VT, AllOnes, RHSShiftAmt);
8957 Mask = DAG.getNode(ISD::AND, DL, VT, Mask,
8958 DAG.getNode(ISD::OR, DL, VT, LHSMask, RHSBits));
8959 }
8960 if (RHSMask.getNode()) {
8961 SDValue LHSBits = DAG.getNode(ISD::SHL, DL, VT, AllOnes, LHSShiftAmt);
8962 Mask = DAG.getNode(ISD::AND, DL, VT, Mask,
8963 DAG.getNode(ISD::OR, DL, VT, RHSMask, LHSBits));
8964 }
8965
8966 Res = DAG.getNode(ISD::AND, DL, VT, Res, Mask);
8967 }
8968
8969 return Res;
8970 };
8971
8972 // TODO: Support pre-legalization funnel-shift by constant.
8973 bool IsRotate = LHSShiftArg == RHSShiftArg;
8974 if (!IsRotate && !(HasFSHL || HasFSHR)) {
8975 if (TLI.isTypeLegal(VT) && LHS.hasOneUse() && RHS.hasOneUse() &&
8976 ISD::matchBinaryPredicate(LHSShiftAmt, RHSShiftAmt, MatchRotateSum)) {
8977 // Look for a disguised rotate by constant.
8978 // The common shifted operand X may be hidden inside another 'or'.
8979 SDValue X, Y;
8980 auto matchOr = [&X, &Y](SDValue Or, SDValue CommonOp) {
8981 if (!Or.hasOneUse() || Or.getOpcode() != ISD::OR)
8982 return false;
8983 if (CommonOp == Or.getOperand(0)) {
8984 X = CommonOp;
8985 Y = Or.getOperand(1);
8986 return true;
8987 }
8988 if (CommonOp == Or.getOperand(1)) {
8989 X = CommonOp;
8990 Y = Or.getOperand(0);
8991 return true;
8992 }
8993 return false;
8994 };
8995
8996 SDValue Res;
8997 if (matchOr(LHSShiftArg, RHSShiftArg)) {
8998 // (shl (X | Y), C1) | (srl X, C2) --> (rotl X, C1) | (shl Y, C1)
8999 SDValue RotX = DAG.getNode(ISD::ROTL, DL, VT, X, LHSShiftAmt);
9000 SDValue ShlY = DAG.getNode(ISD::SHL, DL, VT, Y, LHSShiftAmt);
9001 Res = DAG.getNode(ISD::OR, DL, VT, RotX, ShlY);
9002 } else if (matchOr(RHSShiftArg, LHSShiftArg)) {
9003 // (shl X, C1) | (srl (X | Y), C2) --> (rotl X, C1) | (srl Y, C2)
9004 SDValue RotX = DAG.getNode(ISD::ROTL, DL, VT, X, LHSShiftAmt);
9005 SDValue SrlY = DAG.getNode(ISD::SRL, DL, VT, Y, RHSShiftAmt);
9006 Res = DAG.getNode(ISD::OR, DL, VT, RotX, SrlY);
9007 } else {
9008 return SDValue();
9009 }
9010
9011 return ApplyMasks(Res);
9012 }
9013
9014 return SDValue(); // Requires funnel shift support.
9015 }
9016
9017 // fold (or/add (shl x, C1), (srl x, C2)) -> (rotl x, C1)
9018 // fold (or/add (shl x, C1), (srl x, C2)) -> (rotr x, C2)
9019 // fold (or/add (shl x, C1), (srl y, C2)) -> (fshl x, y, C1)
9020 // fold (or/add (shl x, C1), (srl y, C2)) -> (fshr x, y, C2)
9021 // iff C1+C2 == EltSizeInBits
9022 if (ISD::matchBinaryPredicate(LHSShiftAmt, RHSShiftAmt, MatchRotateSum)) {
9023 SDValue Res;
9024 if (IsRotate && (HasROTL || HasROTR || !(HasFSHL || HasFSHR))) {
9025 bool UseROTL = !LegalOperations || HasROTL;
9026 Res = DAG.getNode(UseROTL ? ISD::ROTL : ISD::ROTR, DL, VT, LHSShiftArg,
9027 UseROTL ? LHSShiftAmt : RHSShiftAmt);
9028 } else {
9029 bool UseFSHL = !LegalOperations || HasFSHL;
9030 Res = DAG.getNode(UseFSHL ? ISD::FSHL : ISD::FSHR, DL, VT, LHSShiftArg,
9031 RHSShiftArg, UseFSHL ? LHSShiftAmt : RHSShiftAmt);
9032 }
9033
9034 return ApplyMasks(Res);
9035 }
9036
9037 // Even pre-legalization, we can't easily rotate/funnel-shift by a variable
9038 // shift.
9039 if (!HasROTL && !HasROTR && !HasFSHL && !HasFSHR)
9040 return SDValue();
9041
9042 // If there is a mask here, and we have a variable shift, we can't be sure
9043 // that we're masking out the right stuff.
9044 if (LHSMask.getNode() || RHSMask.getNode())
9045 return SDValue();
9046
9047 // If the shift amount is sign/zext/any-extended just peel it off.
9048 SDValue LExtOp0 = LHSShiftAmt;
9049 SDValue RExtOp0 = RHSShiftAmt;
9050 if ((LHSShiftAmt.getOpcode() == ISD::SIGN_EXTEND ||
9051 LHSShiftAmt.getOpcode() == ISD::ZERO_EXTEND ||
9052 LHSShiftAmt.getOpcode() == ISD::ANY_EXTEND ||
9053 LHSShiftAmt.getOpcode() == ISD::TRUNCATE) &&
9054 (RHSShiftAmt.getOpcode() == ISD::SIGN_EXTEND ||
9055 RHSShiftAmt.getOpcode() == ISD::ZERO_EXTEND ||
9056 RHSShiftAmt.getOpcode() == ISD::ANY_EXTEND ||
9057 RHSShiftAmt.getOpcode() == ISD::TRUNCATE)) {
9058 LExtOp0 = LHSShiftAmt.getOperand(0);
9059 RExtOp0 = RHSShiftAmt.getOperand(0);
9060 }
9061
9062 if (IsRotate && (HasROTL || HasROTR)) {
9063 if (SDValue TryL = MatchRotatePosNeg(LHSShiftArg, LHSShiftAmt, RHSShiftAmt,
9064 LExtOp0, RExtOp0, FromAdd, HasROTL,
9065 ISD::ROTL, ISD::ROTR, DL))
9066 return TryL;
9067
9068 if (SDValue TryR = MatchRotatePosNeg(RHSShiftArg, RHSShiftAmt, LHSShiftAmt,
9069 RExtOp0, LExtOp0, FromAdd, HasROTR,
9070 ISD::ROTR, ISD::ROTL, DL))
9071 return TryR;
9072 }
9073
9074 if (SDValue TryL = MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, LHSShiftAmt,
9075 RHSShiftAmt, LExtOp0, RExtOp0, FromAdd,
9076 HasFSHL, ISD::FSHL, ISD::FSHR, DL))
9077 return TryL;
9078
9079 if (SDValue TryR = MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, RHSShiftAmt,
9080 LHSShiftAmt, RExtOp0, LExtOp0, FromAdd,
9081 HasFSHR, ISD::FSHR, ISD::FSHL, DL))
9082 return TryR;
9083
9084 return SDValue();
9085 }
9086
9087 /// Recursively traverses the expression calculating the origin of the requested
9088 /// byte of the given value. Returns std::nullopt if the provider can't be
9089 /// calculated.
9090 ///
9091 /// For all the values except the root of the expression, we verify that the
9092 /// value has exactly one use and if not then return std::nullopt. This way if
9093 /// the origin of the byte is returned it's guaranteed that the values which
9094 /// contribute to the byte are not used outside of this expression.
9095
9096 /// However, there is a special case when dealing with vector loads -- we allow
9097 /// more than one use if the load is a vector type. Since the values that
9098 /// contribute to the byte ultimately come from the ExtractVectorElements of the
9099 /// Load, we don't care if the Load has uses other than ExtractVectorElements,
9100 /// because those operations are independent from the pattern to be combined.
9101 /// For vector loads, we simply care that the ByteProviders are adjacent
9102 /// positions of the same vector, and their index matches the byte that is being
9103 /// provided. This is captured by the \p VectorIndex algorithm. \p VectorIndex
9104 /// is the index used in an ExtractVectorElement, and \p StartingIndex is the
9105 /// byte position we are trying to provide for the LoadCombine. If these do
9106 /// not match, then we can not combine the vector loads. \p Index uses the
9107 /// byte position we are trying to provide for and is matched against the
9108 /// shl and load size. The \p Index algorithm ensures the requested byte is
9109 /// provided for by the pattern, and the pattern does not over provide bytes.
9110 ///
9111 ///
9112 /// The supported LoadCombine pattern for vector loads is as follows
9113 /// or
9114 /// / \
9115 /// or shl
9116 /// / \ |
9117 /// or shl zext
9118 /// / \ | |
9119 /// shl zext zext EVE*
9120 /// | | | |
9121 /// zext EVE* EVE* LOAD
9122 /// | | |
9123 /// EVE* LOAD LOAD
9124 /// |
9125 /// LOAD
9126 ///
9127 /// *ExtractVectorElement
9128 using SDByteProvider = ByteProvider<SDNode *>;
9129
9130 static std::optional<SDByteProvider>
calculateByteProvider(SDValue Op,unsigned Index,unsigned Depth,std::optional<uint64_t> VectorIndex,unsigned StartingIndex=0)9131 calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
9132 std::optional<uint64_t> VectorIndex,
9133 unsigned StartingIndex = 0) {
9134
9135 // Typical i64 by i8 pattern requires recursion up to 8 calls depth
9136 if (Depth == 10)
9137 return std::nullopt;
9138
9139 // Only allow multiple uses if the instruction is a vector load (in which
9140 // case we will use the load for every ExtractVectorElement)
9141 if (Depth && !Op.hasOneUse() &&
9142 (Op.getOpcode() != ISD::LOAD || !Op.getValueType().isVector()))
9143 return std::nullopt;
9144
9145 // Fail to combine if we have encountered anything but a LOAD after handling
9146 // an ExtractVectorElement.
9147 if (Op.getOpcode() != ISD::LOAD && VectorIndex.has_value())
9148 return std::nullopt;
9149
9150 unsigned BitWidth = Op.getScalarValueSizeInBits();
9151 if (BitWidth % 8 != 0)
9152 return std::nullopt;
9153 unsigned ByteWidth = BitWidth / 8;
9154 assert(Index < ByteWidth && "invalid index requested");
9155 (void) ByteWidth;
9156
9157 switch (Op.getOpcode()) {
9158 case ISD::OR: {
9159 auto LHS =
9160 calculateByteProvider(Op->getOperand(0), Index, Depth + 1, VectorIndex);
9161 if (!LHS)
9162 return std::nullopt;
9163 auto RHS =
9164 calculateByteProvider(Op->getOperand(1), Index, Depth + 1, VectorIndex);
9165 if (!RHS)
9166 return std::nullopt;
9167
9168 if (LHS->isConstantZero())
9169 return RHS;
9170 if (RHS->isConstantZero())
9171 return LHS;
9172 return std::nullopt;
9173 }
9174 case ISD::SHL: {
9175 auto ShiftOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
9176 if (!ShiftOp)
9177 return std::nullopt;
9178
9179 uint64_t BitShift = ShiftOp->getZExtValue();
9180
9181 if (BitShift % 8 != 0)
9182 return std::nullopt;
9183 uint64_t ByteShift = BitShift / 8;
9184
9185 // If we are shifting by an amount greater than the index we are trying to
9186 // provide, then do not provide anything. Otherwise, subtract the index by
9187 // the amount we shifted by.
9188 return Index < ByteShift
9189 ? SDByteProvider::getConstantZero()
9190 : calculateByteProvider(Op->getOperand(0), Index - ByteShift,
9191 Depth + 1, VectorIndex, Index);
9192 }
9193 case ISD::ANY_EXTEND:
9194 case ISD::SIGN_EXTEND:
9195 case ISD::ZERO_EXTEND: {
9196 SDValue NarrowOp = Op->getOperand(0);
9197 unsigned NarrowBitWidth = NarrowOp.getScalarValueSizeInBits();
9198 if (NarrowBitWidth % 8 != 0)
9199 return std::nullopt;
9200 uint64_t NarrowByteWidth = NarrowBitWidth / 8;
9201
9202 if (Index >= NarrowByteWidth)
9203 return Op.getOpcode() == ISD::ZERO_EXTEND
9204 ? std::optional<SDByteProvider>(
9205 SDByteProvider::getConstantZero())
9206 : std::nullopt;
9207 return calculateByteProvider(NarrowOp, Index, Depth + 1, VectorIndex,
9208 StartingIndex);
9209 }
9210 case ISD::BSWAP:
9211 return calculateByteProvider(Op->getOperand(0), ByteWidth - Index - 1,
9212 Depth + 1, VectorIndex, StartingIndex);
9213 case ISD::EXTRACT_VECTOR_ELT: {
9214 auto OffsetOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
9215 if (!OffsetOp)
9216 return std::nullopt;
9217
9218 VectorIndex = OffsetOp->getZExtValue();
9219
9220 SDValue NarrowOp = Op->getOperand(0);
9221 unsigned NarrowBitWidth = NarrowOp.getScalarValueSizeInBits();
9222 if (NarrowBitWidth % 8 != 0)
9223 return std::nullopt;
9224 uint64_t NarrowByteWidth = NarrowBitWidth / 8;
9225 // EXTRACT_VECTOR_ELT can extend the element type to the width of the return
9226 // type, leaving the high bits undefined.
9227 if (Index >= NarrowByteWidth)
9228 return std::nullopt;
9229
9230 // Check to see if the position of the element in the vector corresponds
9231 // with the byte we are trying to provide for. In the case of a vector of
9232 // i8, this simply means the VectorIndex == StartingIndex. For non i8 cases,
9233 // the element will provide a range of bytes. For example, if we have a
9234 // vector of i16s, each element provides two bytes (V[1] provides byte 2 and
9235 // 3).
9236 if (*VectorIndex * NarrowByteWidth > StartingIndex)
9237 return std::nullopt;
9238 if ((*VectorIndex + 1) * NarrowByteWidth <= StartingIndex)
9239 return std::nullopt;
9240
9241 return calculateByteProvider(Op->getOperand(0), Index, Depth + 1,
9242 VectorIndex, StartingIndex);
9243 }
9244 case ISD::LOAD: {
9245 auto L = cast<LoadSDNode>(Op.getNode());
9246 if (!L->isSimple() || L->isIndexed())
9247 return std::nullopt;
9248
9249 unsigned NarrowBitWidth = L->getMemoryVT().getScalarSizeInBits();
9250 if (NarrowBitWidth % 8 != 0)
9251 return std::nullopt;
9252 uint64_t NarrowByteWidth = NarrowBitWidth / 8;
9253
9254 // If the width of the load does not reach byte we are trying to provide for
9255 // and it is not a ZEXTLOAD, then the load does not provide for the byte in
9256 // question
9257 if (Index >= NarrowByteWidth)
9258 return L->getExtensionType() == ISD::ZEXTLOAD
9259 ? std::optional<SDByteProvider>(
9260 SDByteProvider::getConstantZero())
9261 : std::nullopt;
9262
9263 unsigned BPVectorIndex = VectorIndex.value_or(0U);
9264 return SDByteProvider::getSrc(L, Index, BPVectorIndex);
9265 }
9266 }
9267
9268 return std::nullopt;
9269 }
9270
/// Returns \p i unchanged. \p BW is accepted so the signature mirrors
/// bigEndianByteAt, but it does not influence the result.
static unsigned littleEndianByteAt(unsigned BW, unsigned i) {
  (void)BW;
  const unsigned Position = i;
  return Position;
}
9274
/// Returns the index \p i mirrored within a width of \p BW, i.e.
/// BW - i - 1 (computed in unsigned arithmetic).
static unsigned bigEndianByteAt(unsigned BW, unsigned i) {
  const unsigned LastIndex = BW - 1;
  return LastIndex - i;
}
9278
9279 // Check if the bytes offsets we are looking at match with either big or
9280 // little endian value loaded. Return true for big endian, false for little
9281 // endian, and std::nullopt if match failed.
isBigEndian(const ArrayRef<int64_t> ByteOffsets,int64_t FirstOffset)9282 static std::optional<bool> isBigEndian(const ArrayRef<int64_t> ByteOffsets,
9283 int64_t FirstOffset) {
9284 // The endian can be decided only when it is 2 bytes at least.
9285 unsigned Width = ByteOffsets.size();
9286 if (Width < 2)
9287 return std::nullopt;
9288
9289 bool BigEndian = true, LittleEndian = true;
9290 for (unsigned i = 0; i < Width; i++) {
9291 int64_t CurrentByteOffset = ByteOffsets[i] - FirstOffset;
9292 LittleEndian &= CurrentByteOffset == littleEndianByteAt(Width, i);
9293 BigEndian &= CurrentByteOffset == bigEndianByteAt(Width, i);
9294 if (!BigEndian && !LittleEndian)
9295 return std::nullopt;
9296 }
9297
9298 assert((BigEndian != LittleEndian) && "It should be either big endian or"
9299 "little endian");
9300 return BigEndian;
9301 }
9302
9303 // Look through one layer of truncate or extend.
stripTruncAndExt(SDValue Value)9304 static SDValue stripTruncAndExt(SDValue Value) {
9305 switch (Value.getOpcode()) {
9306 case ISD::TRUNCATE:
9307 case ISD::ZERO_EXTEND:
9308 case ISD::SIGN_EXTEND:
9309 case ISD::ANY_EXTEND:
9310 return Value.getOperand(0);
9311 }
9312 return SDValue();
9313 }
9314
9315 /// Match a pattern where a wide type scalar value is stored by several narrow
9316 /// stores. Fold it into a single store or a BSWAP and a store if the targets
9317 /// supports it.
9318 ///
9319 /// Assuming little endian target:
9320 /// i8 *p = ...
9321 /// i32 val = ...
9322 /// p[0] = (val >> 0) & 0xFF;
9323 /// p[1] = (val >> 8) & 0xFF;
9324 /// p[2] = (val >> 16) & 0xFF;
9325 /// p[3] = (val >> 24) & 0xFF;
9326 /// =>
9327 /// *((i32)p) = val;
9328 ///
9329 /// i8 *p = ...
9330 /// i32 val = ...
9331 /// p[0] = (val >> 24) & 0xFF;
9332 /// p[1] = (val >> 16) & 0xFF;
9333 /// p[2] = (val >> 8) & 0xFF;
9334 /// p[3] = (val >> 0) & 0xFF;
9335 /// =>
9336 /// *((i32)p) = BSWAP(val);
mergeTruncStores(StoreSDNode * N)9337 SDValue DAGCombiner::mergeTruncStores(StoreSDNode *N) {
9338 // The matching looks for "store (trunc x)" patterns that appear early but are
9339 // likely to be replaced by truncating store nodes during combining.
9340 // TODO: If there is evidence that running this later would help, this
9341 // limitation could be removed. Legality checks may need to be added
9342 // for the created store and optional bswap/rotate.
9343 if (LegalOperations || OptLevel == CodeGenOptLevel::None)
9344 return SDValue();
9345
9346 // We only handle merging simple stores of 1-4 bytes.
9347 // TODO: Allow unordered atomics when wider type is legal (see D66309)
9348 EVT MemVT = N->getMemoryVT();
9349 if (!(MemVT == MVT::i8 || MemVT == MVT::i16 || MemVT == MVT::i32) ||
9350 !N->isSimple() || N->isIndexed())
9351 return SDValue();
9352
9353 // Collect all of the stores in the chain, upto the maximum store width (i64).
9354 SDValue Chain = N->getChain();
9355 SmallVector<StoreSDNode *, 8> Stores = {N};
9356 unsigned NarrowNumBits = MemVT.getScalarSizeInBits();
9357 unsigned MaxWideNumBits = 64;
9358 unsigned MaxStores = MaxWideNumBits / NarrowNumBits;
9359 while (auto *Store = dyn_cast<StoreSDNode>(Chain)) {
9360 // All stores must be the same size to ensure that we are writing all of the
9361 // bytes in the wide value.
9362 // This store should have exactly one use as a chain operand for another
9363 // store in the merging set. If there are other chain uses, then the
9364 // transform may not be safe because order of loads/stores outside of this
9365 // set may not be preserved.
9366 // TODO: We could allow multiple sizes by tracking each stored byte.
9367 if (Store->getMemoryVT() != MemVT || !Store->isSimple() ||
9368 Store->isIndexed() || !Store->hasOneUse())
9369 return SDValue();
9370 Stores.push_back(Store);
9371 Chain = Store->getChain();
9372 if (MaxStores < Stores.size())
9373 return SDValue();
9374 }
9375 // There is no reason to continue if we do not have at least a pair of stores.
9376 if (Stores.size() < 2)
9377 return SDValue();
9378
9379 // Handle simple types only.
9380 LLVMContext &Context = *DAG.getContext();
9381 unsigned NumStores = Stores.size();
9382 unsigned WideNumBits = NumStores * NarrowNumBits;
9383 EVT WideVT = EVT::getIntegerVT(Context, WideNumBits);
9384 if (WideVT != MVT::i16 && WideVT != MVT::i32 && WideVT != MVT::i64)
9385 return SDValue();
9386
9387 // Check if all bytes of the source value that we are looking at are stored
9388 // to the same base address. Collect offsets from Base address into OffsetMap.
9389 SDValue SourceValue;
9390 SmallVector<int64_t, 8> OffsetMap(NumStores, INT64_MAX);
9391 int64_t FirstOffset = INT64_MAX;
9392 StoreSDNode *FirstStore = nullptr;
9393 std::optional<BaseIndexOffset> Base;
9394 for (auto *Store : Stores) {
9395 // All the stores store different parts of the CombinedValue. A truncate is
9396 // required to get the partial value.
9397 SDValue Trunc = Store->getValue();
9398 if (Trunc.getOpcode() != ISD::TRUNCATE)
9399 return SDValue();
9400 // Other than the first/last part, a shift operation is required to get the
9401 // offset.
9402 int64_t Offset = 0;
9403 SDValue WideVal = Trunc.getOperand(0);
9404 if ((WideVal.getOpcode() == ISD::SRL || WideVal.getOpcode() == ISD::SRA) &&
9405 isa<ConstantSDNode>(WideVal.getOperand(1))) {
9406 // The shift amount must be a constant multiple of the narrow type.
9407 // It is translated to the offset address in the wide source value "y".
9408 //
9409 // x = srl y, ShiftAmtC
9410 // i8 z = trunc x
9411 // store z, ...
9412 uint64_t ShiftAmtC = WideVal.getConstantOperandVal(1);
9413 if (ShiftAmtC % NarrowNumBits != 0)
9414 return SDValue();
9415
9416 // Make sure we aren't reading bits that are shifted in.
9417 if (ShiftAmtC > WideVal.getScalarValueSizeInBits() - NarrowNumBits)
9418 return SDValue();
9419
9420 Offset = ShiftAmtC / NarrowNumBits;
9421 WideVal = WideVal.getOperand(0);
9422 }
9423
9424 // Stores must share the same source value with different offsets.
9425 if (!SourceValue)
9426 SourceValue = WideVal;
9427 else if (SourceValue != WideVal) {
9428 // Truncate and extends can be stripped to see if the values are related.
9429 if (stripTruncAndExt(SourceValue) != WideVal &&
9430 stripTruncAndExt(WideVal) != SourceValue)
9431 return SDValue();
9432
9433 if (WideVal.getScalarValueSizeInBits() >
9434 SourceValue.getScalarValueSizeInBits())
9435 SourceValue = WideVal;
9436
9437 // Give up if the source value type is smaller than the store size.
9438 if (SourceValue.getScalarValueSizeInBits() < WideVT.getScalarSizeInBits())
9439 return SDValue();
9440 }
9441
9442 // Stores must share the same base address.
9443 BaseIndexOffset Ptr = BaseIndexOffset::match(Store, DAG);
9444 int64_t ByteOffsetFromBase = 0;
9445 if (!Base)
9446 Base = Ptr;
9447 else if (!Base->equalBaseIndex(Ptr, DAG, ByteOffsetFromBase))
9448 return SDValue();
9449
9450 // Remember the first store.
9451 if (ByteOffsetFromBase < FirstOffset) {
9452 FirstStore = Store;
9453 FirstOffset = ByteOffsetFromBase;
9454 }
9455 // Map the offset in the store and the offset in the combined value, and
9456 // early return if it has been set before.
9457 if (Offset < 0 || Offset >= NumStores || OffsetMap[Offset] != INT64_MAX)
9458 return SDValue();
9459 OffsetMap[Offset] = ByteOffsetFromBase;
9460 }
9461
9462 assert(FirstOffset != INT64_MAX && "First byte offset must be set");
9463 assert(FirstStore && "First store must be set");
9464
9465 // Check that a store of the wide type is both allowed and fast on the target
9466 const DataLayout &Layout = DAG.getDataLayout();
9467 unsigned Fast = 0;
9468 bool Allowed = TLI.allowsMemoryAccess(Context, Layout, WideVT,
9469 *FirstStore->getMemOperand(), &Fast);
9470 if (!Allowed || !Fast)
9471 return SDValue();
9472
9473 // Check if the pieces of the value are going to the expected places in memory
9474 // to merge the stores.
9475 auto checkOffsets = [&](bool MatchLittleEndian) {
9476 if (MatchLittleEndian) {
9477 for (unsigned i = 0; i != NumStores; ++i)
9478 if (OffsetMap[i] != i * (NarrowNumBits / 8) + FirstOffset)
9479 return false;
9480 } else { // MatchBigEndian by reversing loop counter.
9481 for (unsigned i = 0, j = NumStores - 1; i != NumStores; ++i, --j)
9482 if (OffsetMap[j] != i * (NarrowNumBits / 8) + FirstOffset)
9483 return false;
9484 }
9485 return true;
9486 };
9487
9488 // Check if the offsets line up for the native data layout of this target.
9489 bool NeedBswap = false;
9490 bool NeedRotate = false;
9491 if (!checkOffsets(Layout.isLittleEndian())) {
9492 // Special-case: check if byte offsets line up for the opposite endian.
9493 if (NarrowNumBits == 8 && checkOffsets(Layout.isBigEndian()))
9494 NeedBswap = true;
9495 else if (NumStores == 2 && checkOffsets(Layout.isBigEndian()))
9496 NeedRotate = true;
9497 else
9498 return SDValue();
9499 }
9500
9501 SDLoc DL(N);
9502 if (WideVT != SourceValue.getValueType()) {
9503 assert(SourceValue.getValueType().getScalarSizeInBits() > WideNumBits &&
9504 "Unexpected store value to merge");
9505 SourceValue = DAG.getNode(ISD::TRUNCATE, DL, WideVT, SourceValue);
9506 }
9507
9508 // Before legalize we can introduce illegal bswaps/rotates which will be later
9509 // converted to an explicit bswap sequence. This way we end up with a single
9510 // store and byte shuffling instead of several stores and byte shuffling.
9511 if (NeedBswap) {
9512 SourceValue = DAG.getNode(ISD::BSWAP, DL, WideVT, SourceValue);
9513 } else if (NeedRotate) {
9514 assert(WideNumBits % 2 == 0 && "Unexpected type for rotate");
9515 SDValue RotAmt = DAG.getConstant(WideNumBits / 2, DL, WideVT);
9516 SourceValue = DAG.getNode(ISD::ROTR, DL, WideVT, SourceValue, RotAmt);
9517 }
9518
9519 SDValue NewStore =
9520 DAG.getStore(Chain, DL, SourceValue, FirstStore->getBasePtr(),
9521 FirstStore->getPointerInfo(), FirstStore->getAlign());
9522
9523 // Rely on other DAG combine rules to remove the other individual stores.
9524 DAG.ReplaceAllUsesWith(N, NewStore.getNode());
9525 return NewStore;
9526 }
9527
9528 /// Match a pattern where a wide type scalar value is loaded by several narrow
9529 /// loads and combined by shifts and ors. Fold it into a single load or a load
9530 /// and a BSWAP if the targets supports it.
9531 ///
9532 /// Assuming little endian target:
9533 /// i8 *a = ...
9534 /// i32 val = a[0] | (a[1] << 8) | (a[2] << 16) | (a[3] << 24)
9535 /// =>
9536 /// i32 val = *((i32)a)
9537 ///
9538 /// i8 *a = ...
9539 /// i32 val = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3]
9540 /// =>
9541 /// i32 val = BSWAP(*((i32)a))
9542 ///
9543 /// TODO: This rule matches complex patterns with OR node roots and doesn't
9544 /// interact well with the worklist mechanism. When a part of the pattern is
9545 /// updated (e.g. one of the loads) its direct users are put into the worklist,
9546 /// but the root node of the pattern which triggers the load combine is not
9547 /// necessarily a direct user of the changed node. For example, once the address
9548 /// of t28 load is reassociated load combine won't be triggered:
9549 /// t25: i32 = add t4, Constant:i32<2>
9550 /// t26: i64 = sign_extend t25
9551 /// t27: i64 = add t2, t26
9552 /// t28: i8,ch = load<LD1[%tmp9]> t0, t27, undef:i64
9553 /// t29: i32 = zero_extend t28
9554 /// t32: i32 = shl t29, Constant:i8<8>
9555 /// t33: i32 = or t23, t32
9556 /// As a possible fix visitLoad can check if the load can be a part of a load
9557 /// combine pattern and add corresponding OR roots to the worklist.
MatchLoadCombine(SDNode * N)9558 SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
9559 assert(N->getOpcode() == ISD::OR &&
9560 "Can only match load combining against OR nodes");
9561
9562 // Handles simple types only
9563 EVT VT = N->getValueType(0);
9564 if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
9565 return SDValue();
9566 unsigned ByteWidth = VT.getSizeInBits() / 8;
9567
9568 bool IsBigEndianTarget = DAG.getDataLayout().isBigEndian();
9569 auto MemoryByteOffset = [&](SDByteProvider P) {
9570 assert(P.hasSrc() && "Must be a memory byte provider");
9571 auto *Load = cast<LoadSDNode>(P.Src.value());
9572
9573 unsigned LoadBitWidth = Load->getMemoryVT().getScalarSizeInBits();
9574
9575 assert(LoadBitWidth % 8 == 0 &&
9576 "can only analyze providers for individual bytes not bit");
9577 unsigned LoadByteWidth = LoadBitWidth / 8;
9578 return IsBigEndianTarget ? bigEndianByteAt(LoadByteWidth, P.DestOffset)
9579 : littleEndianByteAt(LoadByteWidth, P.DestOffset);
9580 };
9581
9582 std::optional<BaseIndexOffset> Base;
9583 SDValue Chain;
9584
9585 SmallPtrSet<LoadSDNode *, 8> Loads;
9586 std::optional<SDByteProvider> FirstByteProvider;
9587 int64_t FirstOffset = INT64_MAX;
9588
9589 // Check if all the bytes of the OR we are looking at are loaded from the same
9590 // base address. Collect bytes offsets from Base address in ByteOffsets.
9591 SmallVector<int64_t, 8> ByteOffsets(ByteWidth);
9592 unsigned ZeroExtendedBytes = 0;
9593 for (int i = ByteWidth - 1; i >= 0; --i) {
9594 auto P =
9595 calculateByteProvider(SDValue(N, 0), i, 0, /*VectorIndex*/ std::nullopt,
9596 /*StartingIndex*/ i);
9597 if (!P)
9598 return SDValue();
9599
9600 if (P->isConstantZero()) {
9601 // It's OK for the N most significant bytes to be 0, we can just
9602 // zero-extend the load.
9603 if (++ZeroExtendedBytes != (ByteWidth - static_cast<unsigned>(i)))
9604 return SDValue();
9605 continue;
9606 }
9607 assert(P->hasSrc() && "provenance should either be memory or zero");
9608 auto *L = cast<LoadSDNode>(P->Src.value());
9609
9610 // All loads must share the same chain
9611 SDValue LChain = L->getChain();
9612 if (!Chain)
9613 Chain = LChain;
9614 else if (Chain != LChain)
9615 return SDValue();
9616
9617 // Loads must share the same base address
9618 BaseIndexOffset Ptr = BaseIndexOffset::match(L, DAG);
9619 int64_t ByteOffsetFromBase = 0;
9620
9621 // For vector loads, the expected load combine pattern will have an
9622 // ExtractElement for each index in the vector. While each of these
9623 // ExtractElements will be accessing the same base address as determined
9624 // by the load instruction, the actual bytes they interact with will differ
9625 // due to different ExtractElement indices. To accurately determine the
9626 // byte position of an ExtractElement, we offset the base load ptr with
9627 // the index multiplied by the byte size of each element in the vector.
9628 if (L->getMemoryVT().isVector()) {
9629 unsigned LoadWidthInBit = L->getMemoryVT().getScalarSizeInBits();
9630 if (LoadWidthInBit % 8 != 0)
9631 return SDValue();
9632 unsigned ByteOffsetFromVector = P->SrcOffset * LoadWidthInBit / 8;
9633 Ptr.addToOffset(ByteOffsetFromVector);
9634 }
9635
9636 if (!Base)
9637 Base = Ptr;
9638
9639 else if (!Base->equalBaseIndex(Ptr, DAG, ByteOffsetFromBase))
9640 return SDValue();
9641
9642 // Calculate the offset of the current byte from the base address
9643 ByteOffsetFromBase += MemoryByteOffset(*P);
9644 ByteOffsets[i] = ByteOffsetFromBase;
9645
9646 // Remember the first byte load
9647 if (ByteOffsetFromBase < FirstOffset) {
9648 FirstByteProvider = P;
9649 FirstOffset = ByteOffsetFromBase;
9650 }
9651
9652 Loads.insert(L);
9653 }
9654
9655 assert(!Loads.empty() && "All the bytes of the value must be loaded from "
9656 "memory, so there must be at least one load which produces the value");
9657 assert(Base && "Base address of the accessed memory location must be set");
9658 assert(FirstOffset != INT64_MAX && "First byte offset must be set");
9659
9660 bool NeedsZext = ZeroExtendedBytes > 0;
9661
9662 EVT MemVT =
9663 EVT::getIntegerVT(*DAG.getContext(), (ByteWidth - ZeroExtendedBytes) * 8);
9664
9665 if (!MemVT.isSimple())
9666 return SDValue();
9667
9668 // Before legalize we can introduce too wide illegal loads which will be later
9669 // split into legal sized loads. This enables us to combine i64 load by i8
9670 // patterns to a couple of i32 loads on 32 bit targets.
9671 if (LegalOperations &&
9672 !TLI.isLoadExtLegal(NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD, VT,
9673 MemVT))
9674 return SDValue();
9675
9676 // Check if the bytes of the OR we are looking at match with either big or
9677 // little endian value load
9678 std::optional<bool> IsBigEndian = isBigEndian(
9679 ArrayRef(ByteOffsets).drop_back(ZeroExtendedBytes), FirstOffset);
9680 if (!IsBigEndian)
9681 return SDValue();
9682
9683 assert(FirstByteProvider && "must be set");
9684
9685 // Ensure that the first byte is loaded from zero offset of the first load.
9686 // So the combined value can be loaded from the first load address.
9687 if (MemoryByteOffset(*FirstByteProvider) != 0)
9688 return SDValue();
9689 auto *FirstLoad = cast<LoadSDNode>(FirstByteProvider->Src.value());
9690
9691 // The node we are looking at matches with the pattern, check if we can
9692 // replace it with a single (possibly zero-extended) load and bswap + shift if
9693 // needed.
9694
9695 // If the load needs byte swap check if the target supports it
9696 bool NeedsBswap = IsBigEndianTarget != *IsBigEndian;
9697
9698 // Before legalize we can introduce illegal bswaps which will be later
9699 // converted to an explicit bswap sequence. This way we end up with a single
9700 // load and byte shuffling instead of several loads and byte shuffling.
9701 // We do not introduce illegal bswaps when zero-extending as this tends to
9702 // introduce too many arithmetic instructions.
9703 if (NeedsBswap && (LegalOperations || NeedsZext) &&
9704 !TLI.isOperationLegal(ISD::BSWAP, VT))
9705 return SDValue();
9706
9707 // If we need to bswap and zero extend, we have to insert a shift. Check that
9708 // it is legal.
9709 if (NeedsBswap && NeedsZext && LegalOperations &&
9710 !TLI.isOperationLegal(ISD::SHL, VT))
9711 return SDValue();
9712
9713 // Check that a load of the wide type is both allowed and fast on the target
9714 unsigned Fast = 0;
9715 bool Allowed =
9716 TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), MemVT,
9717 *FirstLoad->getMemOperand(), &Fast);
9718 if (!Allowed || !Fast)
9719 return SDValue();
9720
9721 SDValue NewLoad =
9722 DAG.getExtLoad(NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD, SDLoc(N), VT,
9723 Chain, FirstLoad->getBasePtr(),
9724 FirstLoad->getPointerInfo(), MemVT, FirstLoad->getAlign());
9725
9726 // Transfer chain users from old loads to the new load.
9727 for (LoadSDNode *L : Loads)
9728 DAG.makeEquivalentMemoryOrdering(L, NewLoad);
9729
9730 if (!NeedsBswap)
9731 return NewLoad;
9732
9733 SDValue ShiftedLoad =
9734 NeedsZext ? DAG.getNode(ISD::SHL, SDLoc(N), VT, NewLoad,
9735 DAG.getShiftAmountConstant(ZeroExtendedBytes * 8,
9736 VT, SDLoc(N)))
9737 : NewLoad;
9738 return DAG.getNode(ISD::BSWAP, SDLoc(N), VT, ShiftedLoad);
9739 }
9740
9741 // If the target has andn, bsl, or a similar bit-select instruction,
9742 // we want to unfold masked merge, with canonical pattern of:
9743 // | A | |B|
9744 // ((x ^ y) & m) ^ y
9745 // | D |
9746 // Into:
9747 // (x & m) | (y & ~m)
9748 // If y is a constant, m is not a 'not', and the 'andn' does not work with
9749 // immediates, we unfold into a different pattern:
9750 // ~(~x & m) & (m | y)
9751 // If x is a constant, m is a 'not', and the 'andn' does not work with
9752 // immediates, we unfold into a different pattern:
9753 // (x | ~m) & ~(~m & ~y)
9754 // NOTE: we don't unfold the pattern if 'xor' is actually a 'not', because at
9755 // the very least that breaks andnpd / andnps patterns, and because those
9756 // patterns are simplified in IR and shouldn't be created in the DAG
unfoldMaskedMerge(SDNode * N)9757 SDValue DAGCombiner::unfoldMaskedMerge(SDNode *N) {
9758 assert(N->getOpcode() == ISD::XOR);
9759
9760 // Don't touch 'not' (i.e. where y = -1).
9761 if (isAllOnesOrAllOnesSplat(N->getOperand(1)))
9762 return SDValue();
9763
9764 EVT VT = N->getValueType(0);
9765
9766 // There are 3 commutable operators in the pattern,
9767 // so we have to deal with 8 possible variants of the basic pattern.
9768 SDValue X, Y, M;
9769 auto matchAndXor = [&X, &Y, &M](SDValue And, unsigned XorIdx, SDValue Other) {
9770 if (And.getOpcode() != ISD::AND || !And.hasOneUse())
9771 return false;
9772 SDValue Xor = And.getOperand(XorIdx);
9773 if (Xor.getOpcode() != ISD::XOR || !Xor.hasOneUse())
9774 return false;
9775 SDValue Xor0 = Xor.getOperand(0);
9776 SDValue Xor1 = Xor.getOperand(1);
9777 // Don't touch 'not' (i.e. where y = -1).
9778 if (isAllOnesOrAllOnesSplat(Xor1))
9779 return false;
9780 if (Other == Xor0)
9781 std::swap(Xor0, Xor1);
9782 if (Other != Xor1)
9783 return false;
9784 X = Xor0;
9785 Y = Xor1;
9786 M = And.getOperand(XorIdx ? 0 : 1);
9787 return true;
9788 };
9789
9790 SDValue N0 = N->getOperand(0);
9791 SDValue N1 = N->getOperand(1);
9792 if (!matchAndXor(N0, 0, N1) && !matchAndXor(N0, 1, N1) &&
9793 !matchAndXor(N1, 0, N0) && !matchAndXor(N1, 1, N0))
9794 return SDValue();
9795
9796 // Don't do anything if the mask is constant. This should not be reachable.
9797 // InstCombine should have already unfolded this pattern, and DAGCombiner
9798 // probably shouldn't produce it, too.
9799 if (isa<ConstantSDNode>(M.getNode()))
9800 return SDValue();
9801
9802 // We can transform if the target has AndNot
9803 if (!TLI.hasAndNot(M))
9804 return SDValue();
9805
9806 SDLoc DL(N);
9807
9808 // If Y is a constant, check that 'andn' works with immediates. Unless M is
9809 // a bitwise not that would already allow ANDN to be used.
9810 if (!TLI.hasAndNot(Y) && !isBitwiseNot(M)) {
9811 assert(TLI.hasAndNot(X) && "Only mask is a variable? Unreachable.");
9812 // If not, we need to do a bit more work to make sure andn is still used.
9813 SDValue NotX = DAG.getNOT(DL, X, VT);
9814 SDValue LHS = DAG.getNode(ISD::AND, DL, VT, NotX, M);
9815 SDValue NotLHS = DAG.getNOT(DL, LHS, VT);
9816 SDValue RHS = DAG.getNode(ISD::OR, DL, VT, M, Y);
9817 return DAG.getNode(ISD::AND, DL, VT, NotLHS, RHS);
9818 }
9819
9820 // If X is a constant and M is a bitwise not, check that 'andn' works with
9821 // immediates.
9822 if (!TLI.hasAndNot(X) && isBitwiseNot(M)) {
9823 assert(TLI.hasAndNot(Y) && "Only mask is a variable? Unreachable.");
9824 // If not, we need to do a bit more work to make sure andn is still used.
9825 SDValue NotM = M.getOperand(0);
9826 SDValue LHS = DAG.getNode(ISD::OR, DL, VT, X, NotM);
9827 SDValue NotY = DAG.getNOT(DL, Y, VT);
9828 SDValue RHS = DAG.getNode(ISD::AND, DL, VT, NotM, NotY);
9829 SDValue NotRHS = DAG.getNOT(DL, RHS, VT);
9830 return DAG.getNode(ISD::AND, DL, VT, LHS, NotRHS);
9831 }
9832
9833 SDValue LHS = DAG.getNode(ISD::AND, DL, VT, X, M);
9834 SDValue NotM = DAG.getNOT(DL, M, VT);
9835 SDValue RHS = DAG.getNode(ISD::AND, DL, VT, Y, NotM);
9836
9837 return DAG.getNode(ISD::OR, DL, VT, LHS, RHS);
9838 }
9839
/// Combine step for an XOR node \p N.
///
/// The folds below are attempted strictly in the order written; the first one
/// that applies decides the result. Returns the replacement value, or
/// SDValue(N, 0) when N was already updated in place (via CombineTo or
/// SimplifyDemandedBits), or an empty SDValue when no fold applies.
SDValue DAGCombiner::visitXOR(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N0.getValueType();
  SDLoc DL(N);

  // fold (xor undef, undef) -> 0. This is a common idiom (misuse).
  if (N0.isUndef() && N1.isUndef())
    return DAG.getConstant(0, DL, VT);

  // fold (xor x, undef) -> undef
  if (N0.isUndef())
    return N0;
  if (N1.isUndef())
    return N1;

  // fold (xor c1, c2) -> c1^c2
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::XOR, DL, VT, {N0, N1}))
    return C;

  // canonicalize constant to RHS
  // (the constant-matching folds further down only inspect N1).
  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
    return DAG.getNode(ISD::XOR, DL, VT, N1, N0);

  // fold vector ops
  if (VT.isVector()) {
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

    // fold (xor x, 0) -> x, vector edition
    if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
      return N0;
  }

  // fold (xor x, 0) -> x
  if (isNullConstant(N1))
    return N0;

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  // reassociate xor
  if (SDValue RXOR = reassociateOps(ISD::XOR, DL, N0, N1, N->getFlags()))
    return RXOR;

  // Fold xor(vecreduce(x), vecreduce(y)) -> vecreduce(xor(x, y))
  if (SDValue SD =
          reassociateReduction(ISD::VECREDUCE_XOR, ISD::XOR, DL, VT, N0, N1))
    return SD;

  // fold (a^b) -> (a|b) iff a and b share no bits.
  // The resulting OR is tagged with the Disjoint flag.
  if ((!LegalOperations || TLI.isOperationLegal(ISD::OR, VT)) &&
      DAG.haveNoCommonBitsSet(N0, N1))
    return DAG.getNode(ISD::OR, DL, VT, N0, N1, SDNodeFlags::Disjoint);

  // look for 'add-like' folds:
  // XOR(N0,MIN_SIGNED_VALUE) == ADD(N0,MIN_SIGNED_VALUE)
  if ((!LegalOperations || TLI.isOperationLegal(ISD::ADD, VT)) &&
      isMinSignedConstant(N1))
    if (SDValue Combined = visitADDLike(N))
      return Combined;

  // fold not (setcc x, y, cc) -> setcc x y !cc
  // Avoid breaking: and (not(setcc x, y, cc), z) -> andn for vec
  unsigned N0Opcode = N0.getOpcode();
  SDValue LHS, RHS, CC;
  if (TLI.isConstTrueVal(N1) &&
      isSetCCEquivalent(N0, LHS, RHS, CC, /*MatchStrict*/ true) &&
      !(VT.isVector() && TLI.hasAndNot(SDValue(N, 0)) && N->hasOneUse() &&
        N->use_begin()->getUser()->getOpcode() == ISD::AND)) {
    ISD::CondCode NotCC = ISD::getSetCCInverse(cast<CondCodeSDNode>(CC)->get(),
                                               LHS.getValueType());
    // Once operations are legalized, only use the inverted condition code if
    // the target reports it as legal for the compared type.
    if (!LegalOperations ||
        TLI.isCondCodeLegal(NotCC, LHS.getSimpleValueType())) {
      switch (N0Opcode) {
      default:
        llvm_unreachable("Unhandled SetCC Equivalent!");
      case ISD::SETCC:
        return DAG.getSetCC(SDLoc(N0), VT, LHS, RHS, NotCC);
      case ISD::SELECT_CC:
        return DAG.getSelectCC(SDLoc(N0), LHS, RHS, N0.getOperand(2),
                               N0.getOperand(3), NotCC);
      case ISD::STRICT_FSETCC:
      case ISD::STRICT_FSETCCS: {
        if (N0.hasOneUse()) {
          // FIXME Can we handle multiple uses? Could we token factor the chain
          // results from the new/old setcc?
          SDValue SetCC =
              DAG.getSetCC(SDLoc(N0), VT, LHS, RHS, NotCC,
                           N0.getOperand(0), N0Opcode == ISD::STRICT_FSETCCS);
          CombineTo(N, SetCC);
          // Redirect users of the old strict setcc's second result (value 1)
          // to the corresponding result of the new node.
          DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), SetCC.getValue(1));
          recursivelyDeleteUnusedNodes(N0.getNode());
          return SDValue(N, 0); // Return N so it doesn't get rechecked!
        }
        // Multiple uses: leave the strict setcc alone and fall through to the
        // remaining folds.
        break;
      }
      }
    }
  }

  // fold (not (zext (setcc x, y))) -> (zext (not (setcc x, y)))
  if (isOneConstant(N1) && N0Opcode == ISD::ZERO_EXTEND && N0.hasOneUse() &&
      isSetCCEquivalent(N0.getOperand(0), LHS, RHS, CC)){
    SDValue V = N0.getOperand(0);
    SDLoc DL0(N0);
    // Build the xor-with-1 in the narrow (pre-extension) type.
    V = DAG.getNode(ISD::XOR, DL0, V.getValueType(), V,
                    DAG.getConstant(1, DL0, V.getValueType()));
    AddToWorklist(V.getNode());
    return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, V);
  }

  // fold (not (or x, y)) -> (and (not x), (not y)) iff x or y are setcc
  // fold (not (and x, y)) -> (or (not x), (not y)) iff x or y are setcc
  // (i1 only, where xor with 1 is a 'not'.)
  if (isOneConstant(N1) && VT == MVT::i1 && N0.hasOneUse() &&
      (N0Opcode == ISD::OR || N0Opcode == ISD::AND)) {
    SDValue N00 = N0.getOperand(0), N01 = N0.getOperand(1);
    if (isOneUseSetCC(N01) || isOneUseSetCC(N00)) {
      unsigned NewOpcode = N0Opcode == ISD::AND ? ISD::OR : ISD::AND;
      N00 = DAG.getNode(ISD::XOR, SDLoc(N00), VT, N00, N1); // N00 = ~N00
      N01 = DAG.getNode(ISD::XOR, SDLoc(N01), VT, N01, N1); // N01 = ~N01
      AddToWorklist(N00.getNode()); AddToWorklist(N01.getNode());
      return DAG.getNode(NewOpcode, DL, VT, N00, N01);
    }
  }
  // fold (not (or x, y)) -> (and (not x), (not y)) iff x or y are constants
  // fold (not (and x, y)) -> (or (not x), (not y)) iff x or y are constants
  if (isAllOnesConstant(N1) && N0.hasOneUse() &&
      (N0Opcode == ISD::OR || N0Opcode == ISD::AND)) {
    SDValue N00 = N0.getOperand(0), N01 = N0.getOperand(1);
    if (isa<ConstantSDNode>(N01) || isa<ConstantSDNode>(N00)) {
      unsigned NewOpcode = N0Opcode == ISD::AND ? ISD::OR : ISD::AND;
      N00 = DAG.getNode(ISD::XOR, SDLoc(N00), VT, N00, N1); // N00 = ~N00
      N01 = DAG.getNode(ISD::XOR, SDLoc(N01), VT, N01, N1); // N01 = ~N01
      AddToWorklist(N00.getNode()); AddToWorklist(N01.getNode());
      return DAG.getNode(NewOpcode, DL, VT, N00, N01);
    }
  }

  // fold (not (neg x)) -> (add X, -1)
  // FIXME: This can be generalized to (not (sub Y, X)) -> (add X, ~Y) if
  // Y is a constant or the subtract has a single use.
  if (isAllOnesConstant(N1) && N0.getOpcode() == ISD::SUB &&
      isNullConstant(N0.getOperand(0))) {
    return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(1),
                       DAG.getAllOnesConstant(DL, VT));
  }

  // fold (not (add X, -1)) -> (neg X)
  if (N0.getOpcode() == ISD::ADD && N0.hasOneUse() && isAllOnesConstant(N1) &&
      isAllOnesOrAllOnesSplat(N0.getOperand(1))) {
    return DAG.getNegative(N0.getOperand(0), DL, VT);
  }

  // fold (xor (and x, y), y) -> (and (not x), y)
  if (N0Opcode == ISD::AND && N0.hasOneUse() && N0->getOperand(1) == N1) {
    SDValue X = N0.getOperand(0);
    SDValue NotX = DAG.getNOT(SDLoc(X), X, VT);
    AddToWorklist(NotX.getNode());
    return DAG.getNode(ISD::AND, DL, VT, NotX, N1);
  }

  // fold Y = sra (X, size(X)-1); xor (add (X, Y), Y) -> (abs X)
  if (!LegalOperations || hasOperation(ISD::ABS, VT)) {
    // The ADD and the SRA may appear as either operand of the xor.
    SDValue A = N0Opcode == ISD::ADD ? N0 : N1;
    SDValue S = N0Opcode == ISD::SRA ? N0 : N1;
    if (A.getOpcode() == ISD::ADD && S.getOpcode() == ISD::SRA) {
      SDValue A0 = A.getOperand(0), A1 = A.getOperand(1);
      SDValue S0 = S.getOperand(0);
      // The add must combine the SRA with the SRA's own input (either order),
      // and the shift amount must be the scalar bit width minus one.
      if ((A0 == S && A1 == S0) || (A1 == S && A0 == S0))
        if (ConstantSDNode *C = isConstOrConstSplat(S.getOperand(1)))
          if (C->getAPIntValue() == (VT.getScalarSizeInBits() - 1))
            return DAG.getNode(ISD::ABS, DL, VT, S0);
    }
  }

  // fold (xor x, x) -> 0
  if (N0 == N1)
    return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);

  // fold (xor (shl 1, x), -1) -> (rotl ~1, x)
  // Here is a concrete example of this equivalence:
  // i16   x ==  14
  // i16 shl ==   1 << 14  == 16384 == 0b0100000000000000
  // i16 xor == ~(1 << 14) == 49151 == 0b1011111111111111
  //
  // =>
  //
  // i16     ~1      == 0b1111111111111110
  // i16 rol(~1, 14) == 0b1011111111111111
  //
  // Some additional tips to help conceptualize this transform:
  // - Try to see the operation as placing a single zero in a value of all ones.
  // - There exists no value for x which would allow the result to contain zero.
  // - Values of x larger than the bitwidth are undefined and do not require a
  //   consistent result.
  // - Pushing the zero left requires shifting one bits in from the right.
  // A rotate left of ~1 is a nice way of achieving the desired result.
  if (TLI.isOperationLegalOrCustom(ISD::ROTL, VT) && N0Opcode == ISD::SHL &&
      isAllOnesConstant(N1) && isOneConstant(N0.getOperand(0))) {
    return DAG.getNode(ISD::ROTL, DL, VT, DAG.getSignedConstant(~1, DL, VT),
                       N0.getOperand(1));
  }

  // Simplify: xor (op x...), (op y...)  -> (op (xor x, y))
  if (N0Opcode == N1.getOpcode())
    if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
      return V;

  // Shift-related logic folds; foldLogicOfShifts is tried with the operands in
  // both orders.
  if (SDValue R = foldLogicOfShifts(N, N0, N1, DAG))
    return R;
  if (SDValue R = foldLogicOfShifts(N, N1, N0, DAG))
    return R;
  if (SDValue R = foldLogicTreeOfShifts(N, N0, N1, DAG))
    return R;

  // Unfold  ((x ^ y) & m) ^ y  into  (x & m) | (y & ~m)  if profitable
  if (SDValue MM = unfoldMaskedMerge(N))
    return MM;

  // Simplify the expression using non-local knowledge.
  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
    return Combined;

  // No fold applied.
  return SDValue();
}
10070
10071 /// If we have a shift-by-constant of a bitwise logic op that itself has a
10072 /// shift-by-constant operand with identical opcode, we may be able to convert
10073 /// that into 2 independent shifts followed by the logic op. This is a
10074 /// throughput improvement.
combineShiftOfShiftedLogic(SDNode * Shift,SelectionDAG & DAG)10075 static SDValue combineShiftOfShiftedLogic(SDNode *Shift, SelectionDAG &DAG) {
10076 // Match a one-use bitwise logic op.
10077 SDValue LogicOp = Shift->getOperand(0);
10078 if (!LogicOp.hasOneUse())
10079 return SDValue();
10080
10081 unsigned LogicOpcode = LogicOp.getOpcode();
10082 if (LogicOpcode != ISD::AND && LogicOpcode != ISD::OR &&
10083 LogicOpcode != ISD::XOR)
10084 return SDValue();
10085
10086 // Find a matching one-use shift by constant.
10087 unsigned ShiftOpcode = Shift->getOpcode();
10088 SDValue C1 = Shift->getOperand(1);
10089 ConstantSDNode *C1Node = isConstOrConstSplat(C1);
10090 assert(C1Node && "Expected a shift with constant operand");
10091 const APInt &C1Val = C1Node->getAPIntValue();
10092 auto matchFirstShift = [&](SDValue V, SDValue &ShiftOp,
10093 const APInt *&ShiftAmtVal) {
10094 if (V.getOpcode() != ShiftOpcode || !V.hasOneUse())
10095 return false;
10096
10097 ConstantSDNode *ShiftCNode = isConstOrConstSplat(V.getOperand(1));
10098 if (!ShiftCNode)
10099 return false;
10100
10101 // Capture the shifted operand and shift amount value.
10102 ShiftOp = V.getOperand(0);
10103 ShiftAmtVal = &ShiftCNode->getAPIntValue();
10104
10105 // Shift amount types do not have to match their operand type, so check that
10106 // the constants are the same width.
10107 if (ShiftAmtVal->getBitWidth() != C1Val.getBitWidth())
10108 return false;
10109
10110 // The fold is not valid if the sum of the shift values doesn't fit in the
10111 // given shift amount type.
10112 bool Overflow = false;
10113 APInt NewShiftAmt = C1Val.uadd_ov(*ShiftAmtVal, Overflow);
10114 if (Overflow)
10115 return false;
10116
10117 // The fold is not valid if the sum of the shift values exceeds bitwidth.
10118 if (NewShiftAmt.uge(V.getScalarValueSizeInBits()))
10119 return false;
10120
10121 return true;
10122 };
10123
10124 // Logic ops are commutative, so check each operand for a match.
10125 SDValue X, Y;
10126 const APInt *C0Val;
10127 if (matchFirstShift(LogicOp.getOperand(0), X, C0Val))
10128 Y = LogicOp.getOperand(1);
10129 else if (matchFirstShift(LogicOp.getOperand(1), X, C0Val))
10130 Y = LogicOp.getOperand(0);
10131 else
10132 return SDValue();
10133
10134 // shift (logic (shift X, C0), Y), C1 -> logic (shift X, C0+C1), (shift Y, C1)
10135 SDLoc DL(Shift);
10136 EVT VT = Shift->getValueType(0);
10137 EVT ShiftAmtVT = Shift->getOperand(1).getValueType();
10138 SDValue ShiftSumC = DAG.getConstant(*C0Val + C1Val, DL, ShiftAmtVT);
10139 SDValue NewShift1 = DAG.getNode(ShiftOpcode, DL, VT, X, ShiftSumC);
10140 SDValue NewShift2 = DAG.getNode(ShiftOpcode, DL, VT, Y, C1);
10141 return DAG.getNode(LogicOpcode, DL, VT, NewShift1, NewShift2,
10142 LogicOp->getFlags());
10143 }
10144
10145 /// Handle transforms common to the three shifts, when the shift amount is a
10146 /// constant.
10147 /// We are looking for: (shift being one of shl/sra/srl)
10148 /// shift (binop X, C0), C1
10149 /// And want to transform into:
10150 /// binop (shift X, C1), (shift C0, C1)
visitShiftByConstant(SDNode * N)10151 SDValue DAGCombiner::visitShiftByConstant(SDNode *N) {
10152 assert(isConstOrConstSplat(N->getOperand(1)) && "Expected constant operand");
10153
10154 // Do not turn a 'not' into a regular xor.
10155 if (isBitwiseNot(N->getOperand(0)))
10156 return SDValue();
10157
10158 // The inner binop must be one-use, since we want to replace it.
10159 SDValue LHS = N->getOperand(0);
10160 if (!LHS.hasOneUse() || !TLI.isDesirableToCommuteWithShift(N, Level))
10161 return SDValue();
10162
10163 // Fold shift(bitop(shift(x,c1),y), c2) -> bitop(shift(x,c1+c2),shift(y,c2)).
10164 if (SDValue R = combineShiftOfShiftedLogic(N, DAG))
10165 return R;
10166
10167 // We want to pull some binops through shifts, so that we have (and (shift))
10168 // instead of (shift (and)), likewise for add, or, xor, etc. This sort of
10169 // thing happens with address calculations, so it's important to canonicalize
10170 // it.
10171 switch (LHS.getOpcode()) {
10172 default:
10173 return SDValue();
10174 case ISD::OR:
10175 case ISD::XOR:
10176 case ISD::AND:
10177 break;
10178 case ISD::ADD:
10179 if (N->getOpcode() != ISD::SHL)
10180 return SDValue(); // only shl(add) not sr[al](add).
10181 break;
10182 }
10183
10184 // FIXME: disable this unless the input to the binop is a shift by a constant
10185 // or is copy/select. Enable this in other cases when figure out it's exactly
10186 // profitable.
10187 SDValue BinOpLHSVal = LHS.getOperand(0);
10188 bool IsShiftByConstant = (BinOpLHSVal.getOpcode() == ISD::SHL ||
10189 BinOpLHSVal.getOpcode() == ISD::SRA ||
10190 BinOpLHSVal.getOpcode() == ISD::SRL) &&
10191 isa<ConstantSDNode>(BinOpLHSVal.getOperand(1));
10192 bool IsCopyOrSelect = BinOpLHSVal.getOpcode() == ISD::CopyFromReg ||
10193 BinOpLHSVal.getOpcode() == ISD::SELECT;
10194
10195 if (!IsShiftByConstant && !IsCopyOrSelect)
10196 return SDValue();
10197
10198 if (IsCopyOrSelect && N->hasOneUse())
10199 return SDValue();
10200
10201 // Attempt to fold the constants, shifting the binop RHS by the shift amount.
10202 SDLoc DL(N);
10203 EVT VT = N->getValueType(0);
10204 if (SDValue NewRHS = DAG.FoldConstantArithmetic(
10205 N->getOpcode(), DL, VT, {LHS.getOperand(1), N->getOperand(1)})) {
10206 SDValue NewShift = DAG.getNode(N->getOpcode(), DL, VT, LHS.getOperand(0),
10207 N->getOperand(1));
10208 return DAG.getNode(LHS.getOpcode(), DL, VT, NewShift, NewRHS);
10209 }
10210
10211 return SDValue();
10212 }
10213
distributeTruncateThroughAnd(SDNode * N)10214 SDValue DAGCombiner::distributeTruncateThroughAnd(SDNode *N) {
10215 assert(N->getOpcode() == ISD::TRUNCATE);
10216 assert(N->getOperand(0).getOpcode() == ISD::AND);
10217
10218 // (truncate:TruncVT (and N00, N01C)) -> (and (truncate:TruncVT N00), TruncC)
10219 EVT TruncVT = N->getValueType(0);
10220 if (N->hasOneUse() && N->getOperand(0).hasOneUse() &&
10221 TLI.isTypeDesirableForOp(ISD::AND, TruncVT)) {
10222 SDValue N01 = N->getOperand(0).getOperand(1);
10223 if (isConstantOrConstantVector(N01, /* NoOpaques */ true)) {
10224 SDLoc DL(N);
10225 SDValue N00 = N->getOperand(0).getOperand(0);
10226 SDValue Trunc00 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, N00);
10227 SDValue Trunc01 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, N01);
10228 AddToWorklist(Trunc00.getNode());
10229 AddToWorklist(Trunc01.getNode());
10230 return DAG.getNode(ISD::AND, DL, TruncVT, Trunc00, Trunc01);
10231 }
10232 }
10233
10234 return SDValue();
10235 }
10236
visitRotate(SDNode * N)10237 SDValue DAGCombiner::visitRotate(SDNode *N) {
10238 SDLoc dl(N);
10239 SDValue N0 = N->getOperand(0);
10240 SDValue N1 = N->getOperand(1);
10241 EVT VT = N->getValueType(0);
10242 unsigned Bitsize = VT.getScalarSizeInBits();
10243
10244 // fold (rot x, 0) -> x
10245 if (isNullOrNullSplat(N1))
10246 return N0;
10247
10248 // fold (rot x, c) -> x iff (c % BitSize) == 0
10249 if (isPowerOf2_32(Bitsize) && Bitsize > 1) {
10250 APInt ModuloMask(N1.getScalarValueSizeInBits(), Bitsize - 1);
10251 if (DAG.MaskedValueIsZero(N1, ModuloMask))
10252 return N0;
10253 }
10254
10255 // fold (rot x, c) -> (rot x, c % BitSize)
10256 bool OutOfRange = false;
10257 auto MatchOutOfRange = [Bitsize, &OutOfRange](ConstantSDNode *C) {
10258 OutOfRange |= C->getAPIntValue().uge(Bitsize);
10259 return true;
10260 };
10261 if (ISD::matchUnaryPredicate(N1, MatchOutOfRange) && OutOfRange) {
10262 EVT AmtVT = N1.getValueType();
10263 SDValue Bits = DAG.getConstant(Bitsize, dl, AmtVT);
10264 if (SDValue Amt =
10265 DAG.FoldConstantArithmetic(ISD::UREM, dl, AmtVT, {N1, Bits}))
10266 return DAG.getNode(N->getOpcode(), dl, VT, N0, Amt);
10267 }
10268
10269 // rot i16 X, 8 --> bswap X
10270 auto *RotAmtC = isConstOrConstSplat(N1);
10271 if (RotAmtC && RotAmtC->getAPIntValue() == 8 &&
10272 VT.getScalarSizeInBits() == 16 && hasOperation(ISD::BSWAP, VT))
10273 return DAG.getNode(ISD::BSWAP, dl, VT, N0);
10274
10275 // Simplify the operands using demanded-bits information.
10276 if (SimplifyDemandedBits(SDValue(N, 0)))
10277 return SDValue(N, 0);
10278
10279 // fold (rot* x, (trunc (and y, c))) -> (rot* x, (and (trunc y), (trunc c))).
10280 if (N1.getOpcode() == ISD::TRUNCATE &&
10281 N1.getOperand(0).getOpcode() == ISD::AND) {
10282 if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
10283 return DAG.getNode(N->getOpcode(), dl, VT, N0, NewOp1);
10284 }
10285
10286 unsigned NextOp = N0.getOpcode();
10287
10288 // fold (rot* (rot* x, c2), c1)
10289 // -> (rot* x, ((c1 % bitsize) +- (c2 % bitsize) + bitsize) % bitsize)
10290 if (NextOp == ISD::ROTL || NextOp == ISD::ROTR) {
10291 bool C1 = DAG.isConstantIntBuildVectorOrConstantInt(N1);
10292 bool C2 = DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1));
10293 if (C1 && C2 && N1.getValueType() == N0.getOperand(1).getValueType()) {
10294 EVT ShiftVT = N1.getValueType();
10295 bool SameSide = (N->getOpcode() == NextOp);
10296 unsigned CombineOp = SameSide ? ISD::ADD : ISD::SUB;
10297 SDValue BitsizeC = DAG.getConstant(Bitsize, dl, ShiftVT);
10298 SDValue Norm1 = DAG.FoldConstantArithmetic(ISD::UREM, dl, ShiftVT,
10299 {N1, BitsizeC});
10300 SDValue Norm2 = DAG.FoldConstantArithmetic(ISD::UREM, dl, ShiftVT,
10301 {N0.getOperand(1), BitsizeC});
10302 if (Norm1 && Norm2)
10303 if (SDValue CombinedShift = DAG.FoldConstantArithmetic(
10304 CombineOp, dl, ShiftVT, {Norm1, Norm2})) {
10305 CombinedShift = DAG.FoldConstantArithmetic(ISD::ADD, dl, ShiftVT,
10306 {CombinedShift, BitsizeC});
10307 SDValue CombinedShiftNorm = DAG.FoldConstantArithmetic(
10308 ISD::UREM, dl, ShiftVT, {CombinedShift, BitsizeC});
10309 return DAG.getNode(N->getOpcode(), dl, VT, N0->getOperand(0),
10310 CombinedShiftNorm);
10311 }
10312 }
10313 }
10314 return SDValue();
10315 }
10316
visitSHL(SDNode * N)10317 SDValue DAGCombiner::visitSHL(SDNode *N) {
10318 SDValue N0 = N->getOperand(0);
10319 SDValue N1 = N->getOperand(1);
10320 if (SDValue V = DAG.simplifyShift(N0, N1))
10321 return V;
10322
10323 SDLoc DL(N);
10324 EVT VT = N0.getValueType();
10325 EVT ShiftVT = N1.getValueType();
10326 unsigned OpSizeInBits = VT.getScalarSizeInBits();
10327
10328 // fold (shl c1, c2) -> c1<<c2
10329 if (SDValue C = DAG.FoldConstantArithmetic(ISD::SHL, DL, VT, {N0, N1}))
10330 return C;
10331
10332 // fold vector ops
10333 if (VT.isVector()) {
10334 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
10335 return FoldedVOp;
10336
10337 BuildVectorSDNode *N1CV = dyn_cast<BuildVectorSDNode>(N1);
10338 // If setcc produces all-one true value then:
10339 // (shl (and (setcc) N01CV) N1CV) -> (and (setcc) N01CV<<N1CV)
10340 if (N1CV && N1CV->isConstant()) {
10341 if (N0.getOpcode() == ISD::AND) {
10342 SDValue N00 = N0->getOperand(0);
10343 SDValue N01 = N0->getOperand(1);
10344 BuildVectorSDNode *N01CV = dyn_cast<BuildVectorSDNode>(N01);
10345
10346 if (N01CV && N01CV->isConstant() && N00.getOpcode() == ISD::SETCC &&
10347 TLI.getBooleanContents(N00.getOperand(0).getValueType()) ==
10348 TargetLowering::ZeroOrNegativeOneBooleanContent) {
10349 if (SDValue C =
10350 DAG.FoldConstantArithmetic(ISD::SHL, DL, VT, {N01, N1}))
10351 return DAG.getNode(ISD::AND, DL, VT, N00, C);
10352 }
10353 }
10354 }
10355 }
10356
10357 if (SDValue NewSel = foldBinOpIntoSelect(N))
10358 return NewSel;
10359
10360 // if (shl x, c) is known to be zero, return 0
10361 if (DAG.MaskedValueIsZero(SDValue(N, 0), APInt::getAllOnes(OpSizeInBits)))
10362 return DAG.getConstant(0, DL, VT);
10363
10364 // fold (shl x, (trunc (and y, c))) -> (shl x, (and (trunc y), (trunc c))).
10365 if (N1.getOpcode() == ISD::TRUNCATE &&
10366 N1.getOperand(0).getOpcode() == ISD::AND) {
10367 if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
10368 return DAG.getNode(ISD::SHL, DL, VT, N0, NewOp1);
10369 }
10370
10371 // fold (shl (shl x, c1), c2) -> 0 or (shl x, (add c1, c2))
10372 if (N0.getOpcode() == ISD::SHL) {
10373 auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS,
10374 ConstantSDNode *RHS) {
10375 APInt c1 = LHS->getAPIntValue();
10376 APInt c2 = RHS->getAPIntValue();
10377 zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
10378 return (c1 + c2).uge(OpSizeInBits);
10379 };
10380 if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange))
10381 return DAG.getConstant(0, DL, VT);
10382
10383 auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
10384 ConstantSDNode *RHS) {
10385 APInt c1 = LHS->getAPIntValue();
10386 APInt c2 = RHS->getAPIntValue();
10387 zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
10388 return (c1 + c2).ult(OpSizeInBits);
10389 };
10390 if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) {
10391 SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1));
10392 return DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Sum);
10393 }
10394 }
10395
10396 // fold (shl (ext (shl x, c1)), c2) -> (shl (ext x), (add c1, c2))
10397 // For this to be valid, the second form must not preserve any of the bits
10398 // that are shifted out by the inner shift in the first form. This means
10399 // the outer shift size must be >= the number of bits added by the ext.
10400 // As a corollary, we don't care what kind of ext it is.
10401 if ((N0.getOpcode() == ISD::ZERO_EXTEND ||
10402 N0.getOpcode() == ISD::ANY_EXTEND ||
10403 N0.getOpcode() == ISD::SIGN_EXTEND) &&
10404 N0.getOperand(0).getOpcode() == ISD::SHL) {
10405 SDValue N0Op0 = N0.getOperand(0);
10406 SDValue InnerShiftAmt = N0Op0.getOperand(1);
10407 EVT InnerVT = N0Op0.getValueType();
10408 uint64_t InnerBitwidth = InnerVT.getScalarSizeInBits();
10409
10410 auto MatchOutOfRange = [OpSizeInBits, InnerBitwidth](ConstantSDNode *LHS,
10411 ConstantSDNode *RHS) {
10412 APInt c1 = LHS->getAPIntValue();
10413 APInt c2 = RHS->getAPIntValue();
10414 zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
10415 return c2.uge(OpSizeInBits - InnerBitwidth) &&
10416 (c1 + c2).uge(OpSizeInBits);
10417 };
10418 if (ISD::matchBinaryPredicate(InnerShiftAmt, N1, MatchOutOfRange,
10419 /*AllowUndefs*/ false,
10420 /*AllowTypeMismatch*/ true))
10421 return DAG.getConstant(0, DL, VT);
10422
10423 auto MatchInRange = [OpSizeInBits, InnerBitwidth](ConstantSDNode *LHS,
10424 ConstantSDNode *RHS) {
10425 APInt c1 = LHS->getAPIntValue();
10426 APInt c2 = RHS->getAPIntValue();
10427 zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
10428 return c2.uge(OpSizeInBits - InnerBitwidth) &&
10429 (c1 + c2).ult(OpSizeInBits);
10430 };
10431 if (ISD::matchBinaryPredicate(InnerShiftAmt, N1, MatchInRange,
10432 /*AllowUndefs*/ false,
10433 /*AllowTypeMismatch*/ true)) {
10434 SDValue Ext = DAG.getNode(N0.getOpcode(), DL, VT, N0Op0.getOperand(0));
10435 SDValue Sum = DAG.getZExtOrTrunc(InnerShiftAmt, DL, ShiftVT);
10436 Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, Sum, N1);
10437 return DAG.getNode(ISD::SHL, DL, VT, Ext, Sum);
10438 }
10439 }
10440
10441 // fold (shl (zext (srl x, C)), C) -> (zext (shl (srl x, C), C))
10442 // Only fold this if the inner zext has no other uses to avoid increasing
10443 // the total number of instructions.
10444 if (N0.getOpcode() == ISD::ZERO_EXTEND && N0.hasOneUse() &&
10445 N0.getOperand(0).getOpcode() == ISD::SRL) {
10446 SDValue N0Op0 = N0.getOperand(0);
10447 SDValue InnerShiftAmt = N0Op0.getOperand(1);
10448
10449 auto MatchEqual = [VT](ConstantSDNode *LHS, ConstantSDNode *RHS) {
10450 APInt c1 = LHS->getAPIntValue();
10451 APInt c2 = RHS->getAPIntValue();
10452 zeroExtendToMatch(c1, c2);
10453 return c1.ult(VT.getScalarSizeInBits()) && (c1 == c2);
10454 };
10455 if (ISD::matchBinaryPredicate(InnerShiftAmt, N1, MatchEqual,
10456 /*AllowUndefs*/ false,
10457 /*AllowTypeMismatch*/ true)) {
10458 EVT InnerShiftAmtVT = N0Op0.getOperand(1).getValueType();
10459 SDValue NewSHL = DAG.getZExtOrTrunc(N1, DL, InnerShiftAmtVT);
10460 NewSHL = DAG.getNode(ISD::SHL, DL, N0Op0.getValueType(), N0Op0, NewSHL);
10461 AddToWorklist(NewSHL.getNode());
10462 return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N0), VT, NewSHL);
10463 }
10464 }
10465
10466 if (N0.getOpcode() == ISD::SRL || N0.getOpcode() == ISD::SRA) {
10467 auto MatchShiftAmount = [OpSizeInBits](ConstantSDNode *LHS,
10468 ConstantSDNode *RHS) {
10469 const APInt &LHSC = LHS->getAPIntValue();
10470 const APInt &RHSC = RHS->getAPIntValue();
10471 return LHSC.ult(OpSizeInBits) && RHSC.ult(OpSizeInBits) &&
10472 LHSC.getZExtValue() <= RHSC.getZExtValue();
10473 };
10474
10475 // fold (shl (sr[la] exact X, C1), C2) -> (shl X, (C2-C1)) if C1 <= C2
10476 // fold (shl (sr[la] exact X, C1), C2) -> (sr[la] X, (C2-C1)) if C1 >= C2
10477 if (N0->getFlags().hasExact()) {
10478 if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchShiftAmount,
10479 /*AllowUndefs*/ false,
10480 /*AllowTypeMismatch*/ true)) {
10481 SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
10482 SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N1, N01);
10483 return DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Diff);
10484 }
10485 if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchShiftAmount,
10486 /*AllowUndefs*/ false,
10487 /*AllowTypeMismatch*/ true)) {
10488 SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
10489 SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N01, N1);
10490 return DAG.getNode(N0.getOpcode(), DL, VT, N0.getOperand(0), Diff);
10491 }
10492 }
10493
10494 // fold (shl (srl x, c1), c2) -> (and (shl x, (sub c2, c1), MASK) or
10495 // (and (srl x, (sub c1, c2), MASK)
10496 // Only fold this if the inner shift has no other uses -- if it does,
10497 // folding this will increase the total number of instructions.
10498 if (N0.getOpcode() == ISD::SRL &&
10499 (N0.getOperand(1) == N1 || N0.hasOneUse()) &&
10500 TLI.shouldFoldConstantShiftPairToMask(N, Level)) {
10501 if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchShiftAmount,
10502 /*AllowUndefs*/ false,
10503 /*AllowTypeMismatch*/ true)) {
10504 SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
10505 SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N01, N1);
10506 SDValue Mask = DAG.getAllOnesConstant(DL, VT);
10507 Mask = DAG.getNode(ISD::SHL, DL, VT, Mask, N01);
10508 Mask = DAG.getNode(ISD::SRL, DL, VT, Mask, Diff);
10509 SDValue Shift = DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), Diff);
10510 return DAG.getNode(ISD::AND, DL, VT, Shift, Mask);
10511 }
10512 if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchShiftAmount,
10513 /*AllowUndefs*/ false,
10514 /*AllowTypeMismatch*/ true)) {
10515 SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
10516 SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N1, N01);
10517 SDValue Mask = DAG.getAllOnesConstant(DL, VT);
10518 Mask = DAG.getNode(ISD::SHL, DL, VT, Mask, N1);
10519 SDValue Shift = DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Diff);
10520 return DAG.getNode(ISD::AND, DL, VT, Shift, Mask);
10521 }
10522 }
10523 }
10524
10525 // fold (shl (sra x, c1), c1) -> (and x, (shl -1, c1))
10526 if (N0.getOpcode() == ISD::SRA && N1 == N0.getOperand(1) &&
10527 isConstantOrConstantVector(N1, /* No Opaques */ true)) {
10528 SDValue AllBits = DAG.getAllOnesConstant(DL, VT);
10529 SDValue HiBitsMask = DAG.getNode(ISD::SHL, DL, VT, AllBits, N1);
10530 return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), HiBitsMask);
10531 }
10532
10533 // fold (shl (add x, c1), c2) -> (add (shl x, c2), c1 << c2)
10534 // fold (shl (or x, c1), c2) -> (or (shl x, c2), c1 << c2)
10535 // Variant of version done on multiply, except mul by a power of 2 is turned
10536 // into a shift.
10537 if ((N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::OR) &&
10538 TLI.isDesirableToCommuteWithShift(N, Level)) {
10539 SDValue N01 = N0.getOperand(1);
10540 if (SDValue Shl1 =
10541 DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N1), VT, {N01, N1})) {
10542 SDValue Shl0 = DAG.getNode(ISD::SHL, SDLoc(N0), VT, N0.getOperand(0), N1);
10543 AddToWorklist(Shl0.getNode());
10544 SDNodeFlags Flags;
10545 // Preserve the disjoint flag for Or.
10546 if (N0.getOpcode() == ISD::OR && N0->getFlags().hasDisjoint())
10547 Flags |= SDNodeFlags::Disjoint;
10548 return DAG.getNode(N0.getOpcode(), DL, VT, Shl0, Shl1, Flags);
10549 }
10550 }
10551
10552 // fold (shl (sext (add_nsw x, c1)), c2) -> (add (shl (sext x), c2), c1 << c2)
10553 // TODO: Add zext/add_nuw variant with suitable test coverage
10554 // TODO: Should we limit this with isLegalAddImmediate?
10555 if (N0.getOpcode() == ISD::SIGN_EXTEND &&
10556 N0.getOperand(0).getOpcode() == ISD::ADD &&
10557 N0.getOperand(0)->getFlags().hasNoSignedWrap() &&
10558 TLI.isDesirableToCommuteWithShift(N, Level)) {
10559 SDValue Add = N0.getOperand(0);
10560 SDLoc DL(N0);
10561 if (SDValue ExtC = DAG.FoldConstantArithmetic(N0.getOpcode(), DL, VT,
10562 {Add.getOperand(1)})) {
10563 if (SDValue ShlC =
10564 DAG.FoldConstantArithmetic(ISD::SHL, DL, VT, {ExtC, N1})) {
10565 SDValue ExtX = DAG.getNode(N0.getOpcode(), DL, VT, Add.getOperand(0));
10566 SDValue ShlX = DAG.getNode(ISD::SHL, DL, VT, ExtX, N1);
10567 return DAG.getNode(ISD::ADD, DL, VT, ShlX, ShlC);
10568 }
10569 }
10570 }
10571
10572 // fold (shl (mul x, c1), c2) -> (mul x, c1 << c2)
10573 if (N0.getOpcode() == ISD::MUL && N0->hasOneUse()) {
10574 SDValue N01 = N0.getOperand(1);
10575 if (SDValue Shl =
10576 DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N1), VT, {N01, N1}))
10577 return DAG.getNode(ISD::MUL, DL, VT, N0.getOperand(0), Shl);
10578 }
10579
10580 ConstantSDNode *N1C = isConstOrConstSplat(N1);
10581 if (N1C && !N1C->isOpaque())
10582 if (SDValue NewSHL = visitShiftByConstant(N))
10583 return NewSHL;
10584
10585 // fold (shl X, cttz(Y)) -> (mul (Y & -Y), X) if cttz is unsupported on the
10586 // target.
10587 if (((N1.getOpcode() == ISD::CTTZ &&
10588 VT.getScalarSizeInBits() <= ShiftVT.getScalarSizeInBits()) ||
10589 N1.getOpcode() == ISD::CTTZ_ZERO_UNDEF) &&
10590 N1.hasOneUse() && !TLI.isOperationLegalOrCustom(ISD::CTTZ, ShiftVT) &&
10591 TLI.isOperationLegalOrCustom(ISD::MUL, VT)) {
10592 SDValue Y = N1.getOperand(0);
10593 SDLoc DL(N);
10594 SDValue NegY = DAG.getNegative(Y, DL, ShiftVT);
10595 SDValue And =
10596 DAG.getZExtOrTrunc(DAG.getNode(ISD::AND, DL, ShiftVT, Y, NegY), DL, VT);
10597 return DAG.getNode(ISD::MUL, DL, VT, And, N0);
10598 }
10599
10600 if (SimplifyDemandedBits(SDValue(N, 0)))
10601 return SDValue(N, 0);
10602
10603 // Fold (shl (vscale * C0), C1) to (vscale * (C0 << C1)).
10604 if (N0.getOpcode() == ISD::VSCALE && N1C) {
10605 const APInt &C0 = N0.getConstantOperandAPInt(0);
10606 const APInt &C1 = N1C->getAPIntValue();
10607 return DAG.getVScale(DL, VT, C0 << C1);
10608 }
10609
10610 // Fold (shl step_vector(C0), C1) to (step_vector(C0 << C1)).
10611 APInt ShlVal;
10612 if (N0.getOpcode() == ISD::STEP_VECTOR &&
10613 ISD::isConstantSplatVector(N1.getNode(), ShlVal)) {
10614 const APInt &C0 = N0.getConstantOperandAPInt(0);
10615 if (ShlVal.ult(C0.getBitWidth())) {
10616 APInt NewStep = C0 << ShlVal;
10617 return DAG.getStepVector(DL, VT, NewStep);
10618 }
10619 }
10620
10621 return SDValue();
10622 }
10623
10624 // Transform a right shift of a multiply into a multiply-high.
10625 // Examples:
10626 // (srl (mul (zext i32:$a to i64), (zext i32:$a to i64)), 32) -> (mulhu $a, $b)
10627 // (sra (mul (sext i32:$a to i64), (sext i32:$a to i64)), 32) -> (mulhs $a, $b)
combineShiftToMULH(SDNode * N,const SDLoc & DL,SelectionDAG & DAG,const TargetLowering & TLI)10628 static SDValue combineShiftToMULH(SDNode *N, const SDLoc &DL, SelectionDAG &DAG,
10629 const TargetLowering &TLI) {
10630 assert((N->getOpcode() == ISD::SRL || N->getOpcode() == ISD::SRA) &&
10631 "SRL or SRA node is required here!");
10632
10633 // Check the shift amount. Proceed with the transformation if the shift
10634 // amount is constant.
10635 ConstantSDNode *ShiftAmtSrc = isConstOrConstSplat(N->getOperand(1));
10636 if (!ShiftAmtSrc)
10637 return SDValue();
10638
10639 // The operation feeding into the shift must be a multiply.
10640 SDValue ShiftOperand = N->getOperand(0);
10641 if (ShiftOperand.getOpcode() != ISD::MUL)
10642 return SDValue();
10643
10644 // Both operands must be equivalent extend nodes.
10645 SDValue LeftOp = ShiftOperand.getOperand(0);
10646 SDValue RightOp = ShiftOperand.getOperand(1);
10647
10648 bool IsSignExt = LeftOp.getOpcode() == ISD::SIGN_EXTEND;
10649 bool IsZeroExt = LeftOp.getOpcode() == ISD::ZERO_EXTEND;
10650
10651 if (!IsSignExt && !IsZeroExt)
10652 return SDValue();
10653
10654 EVT NarrowVT = LeftOp.getOperand(0).getValueType();
10655 unsigned NarrowVTSize = NarrowVT.getScalarSizeInBits();
10656
10657 // return true if U may use the lower bits of its operands
10658 auto UserOfLowerBits = [NarrowVTSize](SDNode *U) {
10659 if (U->getOpcode() != ISD::SRL && U->getOpcode() != ISD::SRA) {
10660 return true;
10661 }
10662 ConstantSDNode *UShiftAmtSrc = isConstOrConstSplat(U->getOperand(1));
10663 if (!UShiftAmtSrc) {
10664 return true;
10665 }
10666 unsigned UShiftAmt = UShiftAmtSrc->getZExtValue();
10667 return UShiftAmt < NarrowVTSize;
10668 };
10669
10670 // If the lower part of the MUL is also used and MUL_LOHI is supported
10671 // do not introduce the MULH in favor of MUL_LOHI
10672 unsigned MulLoHiOp = IsSignExt ? ISD::SMUL_LOHI : ISD::UMUL_LOHI;
10673 if (!ShiftOperand.hasOneUse() &&
10674 TLI.isOperationLegalOrCustom(MulLoHiOp, NarrowVT) &&
10675 llvm::any_of(ShiftOperand->users(), UserOfLowerBits)) {
10676 return SDValue();
10677 }
10678
10679 SDValue MulhRightOp;
10680 if (ConstantSDNode *Constant = isConstOrConstSplat(RightOp)) {
10681 unsigned ActiveBits = IsSignExt
10682 ? Constant->getAPIntValue().getSignificantBits()
10683 : Constant->getAPIntValue().getActiveBits();
10684 if (ActiveBits > NarrowVTSize)
10685 return SDValue();
10686 MulhRightOp = DAG.getConstant(
10687 Constant->getAPIntValue().trunc(NarrowVT.getScalarSizeInBits()), DL,
10688 NarrowVT);
10689 } else {
10690 if (LeftOp.getOpcode() != RightOp.getOpcode())
10691 return SDValue();
10692 // Check that the two extend nodes are the same type.
10693 if (NarrowVT != RightOp.getOperand(0).getValueType())
10694 return SDValue();
10695 MulhRightOp = RightOp.getOperand(0);
10696 }
10697
10698 EVT WideVT = LeftOp.getValueType();
10699 // Proceed with the transformation if the wide types match.
10700 assert((WideVT == RightOp.getValueType()) &&
10701 "Cannot have a multiply node with two different operand types.");
10702
10703 // Proceed with the transformation if the wide type is twice as large
10704 // as the narrow type.
10705 if (WideVT.getScalarSizeInBits() != 2 * NarrowVTSize)
10706 return SDValue();
10707
10708 // Check the shift amount with the narrow type size.
10709 // Proceed with the transformation if the shift amount is the width
10710 // of the narrow type.
10711 unsigned ShiftAmt = ShiftAmtSrc->getZExtValue();
10712 if (ShiftAmt != NarrowVTSize)
10713 return SDValue();
10714
10715 // If the operation feeding into the MUL is a sign extend (sext),
10716 // we use mulhs. Othewise, zero extends (zext) use mulhu.
10717 unsigned MulhOpcode = IsSignExt ? ISD::MULHS : ISD::MULHU;
10718
10719 // Combine to mulh if mulh is legal/custom for the narrow type on the target
10720 // or if it is a vector type then we could transform to an acceptable type and
10721 // rely on legalization to split/combine the result.
10722 if (NarrowVT.isVector()) {
10723 EVT TransformVT = TLI.getTypeToTransformTo(*DAG.getContext(), NarrowVT);
10724 if (TransformVT.getVectorElementType() != NarrowVT.getVectorElementType() ||
10725 !TLI.isOperationLegalOrCustom(MulhOpcode, TransformVT))
10726 return SDValue();
10727 } else {
10728 if (!TLI.isOperationLegalOrCustom(MulhOpcode, NarrowVT))
10729 return SDValue();
10730 }
10731
10732 SDValue Result =
10733 DAG.getNode(MulhOpcode, DL, NarrowVT, LeftOp.getOperand(0), MulhRightOp);
10734 bool IsSigned = N->getOpcode() == ISD::SRA;
10735 return DAG.getExtOrTrunc(IsSigned, Result, DL, WideVT);
10736 }
10737
10738 // fold (bswap (logic_op(bswap(x),y))) -> logic_op(x,bswap(y))
10739 // This helper function accept SDNode with opcode ISD::BSWAP and ISD::BITREVERSE
foldBitOrderCrossLogicOp(SDNode * N,SelectionDAG & DAG)10740 static SDValue foldBitOrderCrossLogicOp(SDNode *N, SelectionDAG &DAG) {
10741 unsigned Opcode = N->getOpcode();
10742 if (Opcode != ISD::BSWAP && Opcode != ISD::BITREVERSE)
10743 return SDValue();
10744
10745 SDValue N0 = N->getOperand(0);
10746 EVT VT = N->getValueType(0);
10747 SDLoc DL(N);
10748 SDValue X, Y;
10749
10750 // If both operands are bswap/bitreverse, ignore the multiuse
10751 if (sd_match(N0, m_OneUse(m_BitwiseLogic(m_UnaryOp(Opcode, m_Value(X)),
10752 m_UnaryOp(Opcode, m_Value(Y))))))
10753 return DAG.getNode(N0.getOpcode(), DL, VT, X, Y);
10754
10755 // Otherwise need to ensure logic_op and bswap/bitreverse(x) have one use.
10756 if (sd_match(N0, m_OneUse(m_BitwiseLogic(
10757 m_OneUse(m_UnaryOp(Opcode, m_Value(X))), m_Value(Y))))) {
10758 SDValue NewBitReorder = DAG.getNode(Opcode, DL, VT, Y);
10759 return DAG.getNode(N0.getOpcode(), DL, VT, X, NewBitReorder);
10760 }
10761
10762 return SDValue();
10763 }
10764
visitSRA(SDNode * N)10765 SDValue DAGCombiner::visitSRA(SDNode *N) {
10766 SDValue N0 = N->getOperand(0);
10767 SDValue N1 = N->getOperand(1);
10768 if (SDValue V = DAG.simplifyShift(N0, N1))
10769 return V;
10770
10771 SDLoc DL(N);
10772 EVT VT = N0.getValueType();
10773 unsigned OpSizeInBits = VT.getScalarSizeInBits();
10774
10775 // fold (sra c1, c2) -> (sra c1, c2)
10776 if (SDValue C = DAG.FoldConstantArithmetic(ISD::SRA, DL, VT, {N0, N1}))
10777 return C;
10778
10779 // Arithmetic shifting an all-sign-bit value is a no-op.
10780 // fold (sra 0, x) -> 0
10781 // fold (sra -1, x) -> -1
10782 if (DAG.ComputeNumSignBits(N0) == OpSizeInBits)
10783 return N0;
10784
10785 // fold vector ops
10786 if (VT.isVector())
10787 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
10788 return FoldedVOp;
10789
10790 if (SDValue NewSel = foldBinOpIntoSelect(N))
10791 return NewSel;
10792
10793 ConstantSDNode *N1C = isConstOrConstSplat(N1);
10794
10795 // fold (sra (sra x, c1), c2) -> (sra x, (add c1, c2))
10796 // clamp (add c1, c2) to max shift.
10797 if (N0.getOpcode() == ISD::SRA) {
10798 EVT ShiftVT = N1.getValueType();
10799 EVT ShiftSVT = ShiftVT.getScalarType();
10800 SmallVector<SDValue, 16> ShiftValues;
10801
10802 auto SumOfShifts = [&](ConstantSDNode *LHS, ConstantSDNode *RHS) {
10803 APInt c1 = LHS->getAPIntValue();
10804 APInt c2 = RHS->getAPIntValue();
10805 zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
10806 APInt Sum = c1 + c2;
10807 unsigned ShiftSum =
10808 Sum.uge(OpSizeInBits) ? (OpSizeInBits - 1) : Sum.getZExtValue();
10809 ShiftValues.push_back(DAG.getConstant(ShiftSum, DL, ShiftSVT));
10810 return true;
10811 };
10812 if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), SumOfShifts)) {
10813 SDValue ShiftValue;
10814 if (N1.getOpcode() == ISD::BUILD_VECTOR)
10815 ShiftValue = DAG.getBuildVector(ShiftVT, DL, ShiftValues);
10816 else if (N1.getOpcode() == ISD::SPLAT_VECTOR) {
10817 assert(ShiftValues.size() == 1 &&
10818 "Expected matchBinaryPredicate to return one element for "
10819 "SPLAT_VECTORs");
10820 ShiftValue = DAG.getSplatVector(ShiftVT, DL, ShiftValues[0]);
10821 } else
10822 ShiftValue = ShiftValues[0];
10823 return DAG.getNode(ISD::SRA, DL, VT, N0.getOperand(0), ShiftValue);
10824 }
10825 }
10826
10827 // fold (sra (shl X, m), (sub result_size, n))
10828 // -> (sign_extend (trunc (shl X, (sub (sub result_size, n), m)))) for
10829 // result_size - n != m.
10830 // If truncate is free for the target sext(shl) is likely to result in better
10831 // code.
10832 if (N0.getOpcode() == ISD::SHL && N1C) {
10833 // Get the two constants of the shifts, CN0 = m, CN = n.
10834 const ConstantSDNode *N01C = isConstOrConstSplat(N0.getOperand(1));
10835 if (N01C) {
10836 LLVMContext &Ctx = *DAG.getContext();
10837 // Determine what the truncate's result bitsize and type would be.
10838 EVT TruncVT = EVT::getIntegerVT(Ctx, OpSizeInBits - N1C->getZExtValue());
10839
10840 if (VT.isVector())
10841 TruncVT = EVT::getVectorVT(Ctx, TruncVT, VT.getVectorElementCount());
10842
10843 // Determine the residual right-shift amount.
10844 int ShiftAmt = N1C->getZExtValue() - N01C->getZExtValue();
10845
10846 // If the shift is not a no-op (in which case this should be just a sign
10847 // extend already), the truncated to type is legal, sign_extend is legal
10848 // on that type, and the truncate to that type is both legal and free,
10849 // perform the transform.
10850 if ((ShiftAmt > 0) &&
10851 TLI.isOperationLegalOrCustom(ISD::SIGN_EXTEND, TruncVT) &&
10852 TLI.isOperationLegalOrCustom(ISD::TRUNCATE, VT) &&
10853 TLI.isTruncateFree(VT, TruncVT)) {
10854 SDValue Amt = DAG.getShiftAmountConstant(ShiftAmt, VT, DL);
10855 SDValue Shift = DAG.getNode(ISD::SRL, DL, VT,
10856 N0.getOperand(0), Amt);
10857 SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, TruncVT,
10858 Shift);
10859 return DAG.getNode(ISD::SIGN_EXTEND, DL,
10860 N->getValueType(0), Trunc);
10861 }
10862 }
10863 }
10864
10865 // We convert trunc/ext to opposing shifts in IR, but casts may be cheaper.
10866 // sra (add (shl X, N1C), AddC), N1C -->
10867 // sext (add (trunc X to (width - N1C)), AddC')
10868 // sra (sub AddC, (shl X, N1C)), N1C -->
10869 // sext (sub AddC1',(trunc X to (width - N1C)))
10870 if ((N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::SUB) && N1C &&
10871 N0.hasOneUse()) {
10872 bool IsAdd = N0.getOpcode() == ISD::ADD;
10873 SDValue Shl = N0.getOperand(IsAdd ? 0 : 1);
10874 if (Shl.getOpcode() == ISD::SHL && Shl.getOperand(1) == N1 &&
10875 Shl.hasOneUse()) {
10876 // TODO: AddC does not need to be a splat.
10877 if (ConstantSDNode *AddC =
10878 isConstOrConstSplat(N0.getOperand(IsAdd ? 1 : 0))) {
10879 // Determine what the truncate's type would be and ask the target if
10880 // that is a free operation.
10881 LLVMContext &Ctx = *DAG.getContext();
10882 unsigned ShiftAmt = N1C->getZExtValue();
10883 EVT TruncVT = EVT::getIntegerVT(Ctx, OpSizeInBits - ShiftAmt);
10884 if (VT.isVector())
10885 TruncVT = EVT::getVectorVT(Ctx, TruncVT, VT.getVectorElementCount());
10886
10887 // TODO: The simple type check probably belongs in the default hook
10888 // implementation and/or target-specific overrides (because
10889 // non-simple types likely require masking when legalized), but
10890 // that restriction may conflict with other transforms.
10891 if (TruncVT.isSimple() && isTypeLegal(TruncVT) &&
10892 TLI.isTruncateFree(VT, TruncVT)) {
10893 SDValue Trunc = DAG.getZExtOrTrunc(Shl.getOperand(0), DL, TruncVT);
10894 SDValue ShiftC =
10895 DAG.getConstant(AddC->getAPIntValue().lshr(ShiftAmt).trunc(
10896 TruncVT.getScalarSizeInBits()),
10897 DL, TruncVT);
10898 SDValue Add;
10899 if (IsAdd)
10900 Add = DAG.getNode(ISD::ADD, DL, TruncVT, Trunc, ShiftC);
10901 else
10902 Add = DAG.getNode(ISD::SUB, DL, TruncVT, ShiftC, Trunc);
10903 return DAG.getSExtOrTrunc(Add, DL, VT);
10904 }
10905 }
10906 }
10907 }
10908
10909 // fold (sra x, (trunc (and y, c))) -> (sra x, (and (trunc y), (trunc c))).
10910 if (N1.getOpcode() == ISD::TRUNCATE &&
10911 N1.getOperand(0).getOpcode() == ISD::AND) {
10912 if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
10913 return DAG.getNode(ISD::SRA, DL, VT, N0, NewOp1);
10914 }
10915
10916 // fold (sra (trunc (sra x, c1)), c2) -> (trunc (sra x, c1 + c2))
10917 // fold (sra (trunc (srl x, c1)), c2) -> (trunc (sra x, c1 + c2))
10918 // if c1 is equal to the number of bits the trunc removes
10919 // TODO - support non-uniform vector shift amounts.
10920 if (N0.getOpcode() == ISD::TRUNCATE &&
10921 (N0.getOperand(0).getOpcode() == ISD::SRL ||
10922 N0.getOperand(0).getOpcode() == ISD::SRA) &&
10923 N0.getOperand(0).hasOneUse() &&
10924 N0.getOperand(0).getOperand(1).hasOneUse() && N1C) {
10925 SDValue N0Op0 = N0.getOperand(0);
10926 if (ConstantSDNode *LargeShift = isConstOrConstSplat(N0Op0.getOperand(1))) {
10927 EVT LargeVT = N0Op0.getValueType();
10928 unsigned TruncBits = LargeVT.getScalarSizeInBits() - OpSizeInBits;
10929 if (LargeShift->getAPIntValue() == TruncBits) {
10930 EVT LargeShiftVT = getShiftAmountTy(LargeVT);
10931 SDValue Amt = DAG.getZExtOrTrunc(N1, DL, LargeShiftVT);
10932 Amt = DAG.getNode(ISD::ADD, DL, LargeShiftVT, Amt,
10933 DAG.getConstant(TruncBits, DL, LargeShiftVT));
10934 SDValue SRA =
10935 DAG.getNode(ISD::SRA, DL, LargeVT, N0Op0.getOperand(0), Amt);
10936 return DAG.getNode(ISD::TRUNCATE, DL, VT, SRA);
10937 }
10938 }
10939 }
10940
10941 // Simplify, based on bits shifted out of the LHS.
10942 if (SimplifyDemandedBits(SDValue(N, 0)))
10943 return SDValue(N, 0);
10944
10945 // If the sign bit is known to be zero, switch this to a SRL.
10946 if (DAG.SignBitIsZero(N0))
10947 return DAG.getNode(ISD::SRL, DL, VT, N0, N1);
10948
10949 if (N1C && !N1C->isOpaque())
10950 if (SDValue NewSRA = visitShiftByConstant(N))
10951 return NewSRA;
10952
10953 // Try to transform this shift into a multiply-high if
10954 // it matches the appropriate pattern detected in combineShiftToMULH.
10955 if (SDValue MULH = combineShiftToMULH(N, DL, DAG, TLI))
10956 return MULH;
10957
10958 // Attempt to convert a sra of a load into a narrower sign-extending load.
10959 if (SDValue NarrowLoad = reduceLoadWidth(N))
10960 return NarrowLoad;
10961
10962 if (SDValue AVG = foldShiftToAvg(N))
10963 return AVG;
10964
10965 return SDValue();
10966 }
10967
visitSRL(SDNode * N)10968 SDValue DAGCombiner::visitSRL(SDNode *N) {
10969 SDValue N0 = N->getOperand(0);
10970 SDValue N1 = N->getOperand(1);
10971 if (SDValue V = DAG.simplifyShift(N0, N1))
10972 return V;
10973
10974 SDLoc DL(N);
10975 EVT VT = N0.getValueType();
10976 EVT ShiftVT = N1.getValueType();
10977 unsigned OpSizeInBits = VT.getScalarSizeInBits();
10978
10979 // fold (srl c1, c2) -> c1 >>u c2
10980 if (SDValue C = DAG.FoldConstantArithmetic(ISD::SRL, DL, VT, {N0, N1}))
10981 return C;
10982
10983 // fold vector ops
10984 if (VT.isVector())
10985 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
10986 return FoldedVOp;
10987
10988 if (SDValue NewSel = foldBinOpIntoSelect(N))
10989 return NewSel;
10990
10991 // if (srl x, c) is known to be zero, return 0
10992 ConstantSDNode *N1C = isConstOrConstSplat(N1);
10993 if (N1C &&
10994 DAG.MaskedValueIsZero(SDValue(N, 0), APInt::getAllOnes(OpSizeInBits)))
10995 return DAG.getConstant(0, DL, VT);
10996
10997 // fold (srl (srl x, c1), c2) -> 0 or (srl x, (add c1, c2))
10998 if (N0.getOpcode() == ISD::SRL) {
10999 auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS,
11000 ConstantSDNode *RHS) {
11001 APInt c1 = LHS->getAPIntValue();
11002 APInt c2 = RHS->getAPIntValue();
11003 zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
11004 return (c1 + c2).uge(OpSizeInBits);
11005 };
11006 if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange))
11007 return DAG.getConstant(0, DL, VT);
11008
11009 auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
11010 ConstantSDNode *RHS) {
11011 APInt c1 = LHS->getAPIntValue();
11012 APInt c2 = RHS->getAPIntValue();
11013 zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
11014 return (c1 + c2).ult(OpSizeInBits);
11015 };
11016 if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) {
11017 SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1));
11018 return DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), Sum);
11019 }
11020 }
11021
11022 if (N1C && N0.getOpcode() == ISD::TRUNCATE &&
11023 N0.getOperand(0).getOpcode() == ISD::SRL) {
11024 SDValue InnerShift = N0.getOperand(0);
11025 // TODO - support non-uniform vector shift amounts.
11026 if (auto *N001C = isConstOrConstSplat(InnerShift.getOperand(1))) {
11027 uint64_t c1 = N001C->getZExtValue();
11028 uint64_t c2 = N1C->getZExtValue();
11029 EVT InnerShiftVT = InnerShift.getValueType();
11030 EVT ShiftAmtVT = InnerShift.getOperand(1).getValueType();
11031 uint64_t InnerShiftSize = InnerShiftVT.getScalarSizeInBits();
11032 // srl (trunc (srl x, c1)), c2 --> 0 or (trunc (srl x, (add c1, c2)))
11033 // This is only valid if the OpSizeInBits + c1 = size of inner shift.
11034 if (c1 + OpSizeInBits == InnerShiftSize) {
11035 if (c1 + c2 >= InnerShiftSize)
11036 return DAG.getConstant(0, DL, VT);
11037 SDValue NewShiftAmt = DAG.getConstant(c1 + c2, DL, ShiftAmtVT);
11038 SDValue NewShift = DAG.getNode(ISD::SRL, DL, InnerShiftVT,
11039 InnerShift.getOperand(0), NewShiftAmt);
11040 return DAG.getNode(ISD::TRUNCATE, DL, VT, NewShift);
11041 }
11042 // In the more general case, we can clear the high bits after the shift:
11043 // srl (trunc (srl x, c1)), c2 --> trunc (and (srl x, (c1+c2)), Mask)
11044 if (N0.hasOneUse() && InnerShift.hasOneUse() &&
11045 c1 + c2 < InnerShiftSize) {
11046 SDValue NewShiftAmt = DAG.getConstant(c1 + c2, DL, ShiftAmtVT);
11047 SDValue NewShift = DAG.getNode(ISD::SRL, DL, InnerShiftVT,
11048 InnerShift.getOperand(0), NewShiftAmt);
11049 SDValue Mask = DAG.getConstant(APInt::getLowBitsSet(InnerShiftSize,
11050 OpSizeInBits - c2),
11051 DL, InnerShiftVT);
11052 SDValue And = DAG.getNode(ISD::AND, DL, InnerShiftVT, NewShift, Mask);
11053 return DAG.getNode(ISD::TRUNCATE, DL, VT, And);
11054 }
11055 }
11056 }
11057
11058 // fold (srl (shl x, c1), c2) -> (and (shl x, (sub c1, c2), MASK) or
11059 // (and (srl x, (sub c2, c1), MASK)
11060 if (N0.getOpcode() == ISD::SHL &&
11061 (N0.getOperand(1) == N1 || N0->hasOneUse()) &&
11062 TLI.shouldFoldConstantShiftPairToMask(N, Level)) {
11063 auto MatchShiftAmount = [OpSizeInBits](ConstantSDNode *LHS,
11064 ConstantSDNode *RHS) {
11065 const APInt &LHSC = LHS->getAPIntValue();
11066 const APInt &RHSC = RHS->getAPIntValue();
11067 return LHSC.ult(OpSizeInBits) && RHSC.ult(OpSizeInBits) &&
11068 LHSC.getZExtValue() <= RHSC.getZExtValue();
11069 };
11070 if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchShiftAmount,
11071 /*AllowUndefs*/ false,
11072 /*AllowTypeMismatch*/ true)) {
11073 SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
11074 SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N01, N1);
11075 SDValue Mask = DAG.getAllOnesConstant(DL, VT);
11076 Mask = DAG.getNode(ISD::SRL, DL, VT, Mask, N01);
11077 Mask = DAG.getNode(ISD::SHL, DL, VT, Mask, Diff);
11078 SDValue Shift = DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Diff);
11079 return DAG.getNode(ISD::AND, DL, VT, Shift, Mask);
11080 }
11081 if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchShiftAmount,
11082 /*AllowUndefs*/ false,
11083 /*AllowTypeMismatch*/ true)) {
11084 SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
11085 SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N1, N01);
11086 SDValue Mask = DAG.getAllOnesConstant(DL, VT);
11087 Mask = DAG.getNode(ISD::SRL, DL, VT, Mask, N1);
11088 SDValue Shift = DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), Diff);
11089 return DAG.getNode(ISD::AND, DL, VT, Shift, Mask);
11090 }
11091 }
11092
11093 // fold (srl (anyextend x), c) -> (and (anyextend (srl x, c)), mask)
11094 // TODO - support non-uniform vector shift amounts.
11095 if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) {
11096 // Shifting in all undef bits?
11097 EVT SmallVT = N0.getOperand(0).getValueType();
11098 unsigned BitSize = SmallVT.getScalarSizeInBits();
11099 if (N1C->getAPIntValue().uge(BitSize))
11100 return DAG.getUNDEF(VT);
11101
11102 if (!LegalTypes || TLI.isTypeDesirableForOp(ISD::SRL, SmallVT)) {
11103 uint64_t ShiftAmt = N1C->getZExtValue();
11104 SDLoc DL0(N0);
11105 SDValue SmallShift =
11106 DAG.getNode(ISD::SRL, DL0, SmallVT, N0.getOperand(0),
11107 DAG.getShiftAmountConstant(ShiftAmt, SmallVT, DL0));
11108 AddToWorklist(SmallShift.getNode());
11109 APInt Mask = APInt::getLowBitsSet(OpSizeInBits, OpSizeInBits - ShiftAmt);
11110 return DAG.getNode(ISD::AND, DL, VT,
11111 DAG.getNode(ISD::ANY_EXTEND, DL, VT, SmallShift),
11112 DAG.getConstant(Mask, DL, VT));
11113 }
11114 }
11115
11116 // fold (srl (sra X, Y), 31) -> (srl X, 31). This srl only looks at the sign
11117 // bit, which is unmodified by sra.
11118 if (N1C && N1C->getAPIntValue() == (OpSizeInBits - 1)) {
11119 if (N0.getOpcode() == ISD::SRA)
11120 return DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), N1);
11121 }
11122
11123 // fold (srl (ctlz x), "5") -> x iff x has one bit set (the low bit), and x has a power
11124 // of two bitwidth. The "5" represents (log2 (bitwidth x)).
11125 if (N1C && N0.getOpcode() == ISD::CTLZ &&
11126 isPowerOf2_32(OpSizeInBits) &&
11127 N1C->getAPIntValue() == Log2_32(OpSizeInBits)) {
11128 KnownBits Known = DAG.computeKnownBits(N0.getOperand(0));
11129
11130 // If any of the input bits are KnownOne, then the input couldn't be all
11131 // zeros, thus the result of the srl will always be zero.
11132 if (Known.One.getBoolValue()) return DAG.getConstant(0, SDLoc(N0), VT);
11133
11134 // If all of the bits input the to ctlz node are known to be zero, then
11135 // the result of the ctlz is "32" and the result of the shift is one.
11136 APInt UnknownBits = ~Known.Zero;
11137 if (UnknownBits == 0) return DAG.getConstant(1, SDLoc(N0), VT);
11138
11139 // Otherwise, check to see if there is exactly one bit input to the ctlz.
11140 if (UnknownBits.isPowerOf2()) {
11141 // Okay, we know that only that the single bit specified by UnknownBits
11142 // could be set on input to the CTLZ node. If this bit is set, the SRL
11143 // will return 0, if it is clear, it returns 1. Change the CTLZ/SRL pair
11144 // to an SRL/XOR pair, which is likely to simplify more.
11145 unsigned ShAmt = UnknownBits.countr_zero();
11146 SDValue Op = N0.getOperand(0);
11147
11148 if (ShAmt) {
11149 SDLoc DL(N0);
11150 Op = DAG.getNode(ISD::SRL, DL, VT, Op,
11151 DAG.getShiftAmountConstant(ShAmt, VT, DL));
11152 AddToWorklist(Op.getNode());
11153 }
11154 return DAG.getNode(ISD::XOR, DL, VT, Op, DAG.getConstant(1, DL, VT));
11155 }
11156 }
11157
11158 // fold (srl x, (trunc (and y, c))) -> (srl x, (and (trunc y), (trunc c))).
11159 if (N1.getOpcode() == ISD::TRUNCATE &&
11160 N1.getOperand(0).getOpcode() == ISD::AND) {
11161 if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
11162 return DAG.getNode(ISD::SRL, DL, VT, N0, NewOp1);
11163 }
11164
11165 // fold (srl (logic_op x, (shl (zext y), c1)), c1)
11166 // -> (logic_op (srl x, c1), (zext y))
11167 // c1 <= leadingzeros(zext(y))
11168 SDValue X, ZExtY;
11169 if (N1C && sd_match(N0, m_OneUse(m_BitwiseLogic(
11170 m_Value(X),
11171 m_OneUse(m_Shl(m_AllOf(m_Value(ZExtY),
11172 m_Opc(ISD::ZERO_EXTEND)),
11173 m_Specific(N1))))))) {
11174 unsigned NumLeadingZeros = ZExtY.getScalarValueSizeInBits() -
11175 ZExtY.getOperand(0).getScalarValueSizeInBits();
11176 if (N1C->getZExtValue() <= NumLeadingZeros)
11177 return DAG.getNode(N0.getOpcode(), SDLoc(N0), VT,
11178 DAG.getNode(ISD::SRL, SDLoc(N0), VT, X, N1), ZExtY);
11179 }
11180
11181 // fold operands of srl based on knowledge that the low bits are not
11182 // demanded.
11183 if (SimplifyDemandedBits(SDValue(N, 0)))
11184 return SDValue(N, 0);
11185
11186 if (N1C && !N1C->isOpaque())
11187 if (SDValue NewSRL = visitShiftByConstant(N))
11188 return NewSRL;
11189
11190 // Attempt to convert a srl of a load into a narrower zero-extending load.
11191 if (SDValue NarrowLoad = reduceLoadWidth(N))
11192 return NarrowLoad;
11193
11194 // Here is a common situation. We want to optimize:
11195 //
11196 // %a = ...
11197 // %b = and i32 %a, 2
11198 // %c = srl i32 %b, 1
11199 // brcond i32 %c ...
11200 //
11201 // into
11202 //
11203 // %a = ...
11204 // %b = and %a, 2
11205 // %c = setcc eq %b, 0
11206 // brcond %c ...
11207 //
11208 // However when after the source operand of SRL is optimized into AND, the SRL
11209 // itself may not be optimized further. Look for it and add the BRCOND into
11210 // the worklist.
11211 //
11212 // The also tends to happen for binary operations when SimplifyDemandedBits
11213 // is involved.
11214 //
11215 // FIXME: This is unecessary if we process the DAG in topological order,
11216 // which we plan to do. This workaround can be removed once the DAG is
11217 // processed in topological order.
11218 if (N->hasOneUse()) {
11219 SDNode *User = *N->user_begin();
11220
11221 // Look pass the truncate.
11222 if (User->getOpcode() == ISD::TRUNCATE && User->hasOneUse())
11223 User = *User->user_begin();
11224
11225 if (User->getOpcode() == ISD::BRCOND || User->getOpcode() == ISD::AND ||
11226 User->getOpcode() == ISD::OR || User->getOpcode() == ISD::XOR)
11227 AddToWorklist(User);
11228 }
11229
11230 // Try to transform this shift into a multiply-high if
11231 // it matches the appropriate pattern detected in combineShiftToMULH.
11232 if (SDValue MULH = combineShiftToMULH(N, DL, DAG, TLI))
11233 return MULH;
11234
11235 if (SDValue AVG = foldShiftToAvg(N))
11236 return AVG;
11237
11238 return SDValue();
11239 }
11240
/// Combine an ISD::FSHL / ISD::FSHR funnel-shift node.
///
/// Operands 0 and 1 (N0, N1) are the two values being funnelled and operand 2
/// (N2) is the shift amount. The folds are attempted in the order they appear
/// below and the first one that applies wins.
///
/// \returns the replacement value; SDValue(N, 0) when SimplifyDemandedBits
/// reports that it changed something; or an empty SDValue if no fold applied.
SDValue DAGCombiner::visitFunnelShift(SDNode *N) {
  EVT VT = N->getValueType(0);
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDValue N2 = N->getOperand(2);
  bool IsFSHL = N->getOpcode() == ISD::FSHL;
  unsigned BitWidth = VT.getScalarSizeInBits();
  SDLoc DL(N);

  // fold (fshl N0, N1, 0) -> N0
  // fold (fshr N0, N1, 0) -> N1
  // For a power-of-2 width only the bits selected by (BitWidth - 1) matter
  // once the amount is reduced modulo BitWidth; if they are all known zero the
  // shift is a no-op even when N2 is not a constant.
  if (isPowerOf2_32(BitWidth))
    if (DAG.MaskedValueIsZero(
            N2, APInt(N2.getScalarValueSizeInBits(), BitWidth - 1)))
      return IsFSHL ? N0 : N1;

  // True for an undef value or a zero scalar/splat (undef lanes allowed).
  auto IsUndefOrZero = [](SDValue V) {
    return V.isUndef() || isNullOrNullSplat(V, /*AllowUndefs*/ true);
  };

  // Folds that need a uniform constant shift amount.
  // TODO - support non-uniform vector shift amounts.
  if (ConstantSDNode *Cst = isConstOrConstSplat(N2)) {
    EVT ShAmtTy = N2.getValueType();

    // fold (fsh* N0, N1, c) -> (fsh* N0, N1, c % BitWidth)
    if (Cst->getAPIntValue().uge(BitWidth)) {
      uint64_t RotAmt = Cst->getAPIntValue().urem(BitWidth);
      return DAG.getNode(N->getOpcode(), DL, VT, N0, N1,
                         DAG.getConstant(RotAmt, DL, ShAmtTy));
    }

    // From here on the amount is known to be below BitWidth.
    unsigned ShAmt = Cst->getZExtValue();
    if (ShAmt == 0)
      return IsFSHL ? N0 : N1;

    // fold fshl(undef_or_zero, N1, C) -> lshr(N1, BW-C)
    // fold fshr(undef_or_zero, N1, C) -> lshr(N1, C)
    // fold fshl(N0, undef_or_zero, C) -> shl(N0, C)
    // fold fshr(N0, undef_or_zero, C) -> shl(N0, BW-C)
    if (IsUndefOrZero(N0))
      return DAG.getNode(
          ISD::SRL, DL, VT, N1,
          DAG.getConstant(IsFSHL ? BitWidth - ShAmt : ShAmt, DL, ShAmtTy));
    if (IsUndefOrZero(N1))
      return DAG.getNode(
          ISD::SHL, DL, VT, N0,
          DAG.getConstant(IsFSHL ? ShAmt : BitWidth - ShAmt, DL, ShAmtTy));

    // fold (fshl ld1, ld0, c) -> (ld0[ofs]) iff ld0 and ld1 are consecutive.
    // fold (fshr ld1, ld0, c) -> (ld0[ofs]) iff ld0 and ld1 are consecutive.
    // TODO - bigendian support once we have test coverage.
    // TODO - can we merge this with CombineConsecutiveLoads/MatchLoadCombine?
    // TODO - permit LHS EXTLOAD if extensions are shifted out.
    // Only whole-byte shifts of byte-sized scalars on little-endian targets
    // are handled.
    if ((BitWidth % 8) == 0 && (ShAmt % 8) == 0 && !VT.isVector() &&
        !DAG.getDataLayout().isBigEndian()) {
      auto *LHS = dyn_cast<LoadSDNode>(N0);
      auto *RHS = dyn_cast<LoadSDNode>(N1);
      // Both operands must be simple, non-extending loads from the same
      // address space, and at least one of them must have a single use of its
      // loaded value.
      if (LHS && RHS && LHS->isSimple() && RHS->isSimple() &&
          LHS->getAddressSpace() == RHS->getAddressSpace() &&
          (LHS->hasNUsesOfValue(1, 0) || RHS->hasNUsesOfValue(1, 0)) &&
          ISD::isNON_EXTLoad(RHS) && ISD::isNON_EXTLoad(LHS)) {
        if (DAG.areNonVolatileConsecutiveLoads(LHS, RHS, BitWidth / 8, 1)) {
          // Deliberately shadows the outer DL: the new pointer and load take
          // RHS's location.
          SDLoc DL(RHS);
          // Byte offset of the combined value relative to RHS's base pointer.
          uint64_t PtrOff =
              IsFSHL ? (((BitWidth - ShAmt) % BitWidth) / 8) : (ShAmt / 8);
          Align NewAlign = commonAlignment(RHS->getAlign(), PtrOff);
          unsigned Fast = 0;
          // Only form the offset load if the target both allows it and
          // reports it as fast.
          if (TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT,
                                     RHS->getAddressSpace(), NewAlign,
                                     RHS->getMemOperand()->getFlags(), &Fast) &&
              Fast) {
            SDValue NewPtr = DAG.getMemBasePlusOffset(
                RHS->getBasePtr(), TypeSize::getFixed(PtrOff), DL);
            AddToWorklist(NewPtr.getNode());
            SDValue Load = DAG.getLoad(
                VT, DL, RHS->getChain(), NewPtr,
                RHS->getPointerInfo().getWithOffset(PtrOff), NewAlign,
                RHS->getMemOperand()->getFlags(), RHS->getAAInfo());
            // Tie the new load's chain result (value 1) to both original
            // loads.
            // NOTE(review): presumably this keeps users of the old chains
            // ordered after the new load - confirm against
            // SelectionDAG::makeEquivalentMemoryOrdering.
            DAG.makeEquivalentMemoryOrdering(LHS, Load.getValue(1));
            DAG.makeEquivalentMemoryOrdering(RHS, Load.getValue(1));
            return Load;
          }
        }
      }
    }
  }

  // fold fshr(undef_or_zero, N1, N2) -> lshr(N1, N2)
  // fold fshl(N0, undef_or_zero, N2) -> shl(N0, N2)
  // iff We know the shift amount is in range.
  // "In range" is established by proving every bit of N2 outside
  // (BitWidth - 1) is zero.
  // TODO: when is it worth doing SUB(BW, N2) as well?
  if (isPowerOf2_32(BitWidth)) {
    APInt ModuloBits(N2.getScalarValueSizeInBits(), BitWidth - 1);
    if (IsUndefOrZero(N0) && !IsFSHL && DAG.MaskedValueIsZero(N2, ~ModuloBits))
      return DAG.getNode(ISD::SRL, DL, VT, N1, N2);
    if (IsUndefOrZero(N1) && IsFSHL && DAG.MaskedValueIsZero(N2, ~ModuloBits))
      return DAG.getNode(ISD::SHL, DL, VT, N0, N2);
  }

  // fold (fshl N0, N0, N2) -> (rotl N0, N2)
  // fold (fshr N0, N0, N2) -> (rotr N0, N2)
  // TODO: Investigate flipping this rotate if only one is legal.
  // If funnel shift is legal as well we might be better off avoiding
  // non-constant (BW - N2).
  unsigned RotOpc = IsFSHL ? ISD::ROTL : ISD::ROTR;
  if (N0 == N1 && hasOperation(RotOpc, VT))
    return DAG.getNode(RotOpc, DL, VT, N0, N2);

  // Simplify, based on bits shifted out of N0/N1.
  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  return SDValue();
}
11355
visitSHLSAT(SDNode * N)11356 SDValue DAGCombiner::visitSHLSAT(SDNode *N) {
11357 SDValue N0 = N->getOperand(0);
11358 SDValue N1 = N->getOperand(1);
11359 if (SDValue V = DAG.simplifyShift(N0, N1))
11360 return V;
11361
11362 SDLoc DL(N);
11363 EVT VT = N0.getValueType();
11364
11365 // fold (*shlsat c1, c2) -> c1<<c2
11366 if (SDValue C = DAG.FoldConstantArithmetic(N->getOpcode(), DL, VT, {N0, N1}))
11367 return C;
11368
11369 ConstantSDNode *N1C = isConstOrConstSplat(N1);
11370
11371 if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::SHL, VT)) {
11372 // fold (sshlsat x, c) -> (shl x, c)
11373 if (N->getOpcode() == ISD::SSHLSAT && N1C &&
11374 N1C->getAPIntValue().ult(DAG.ComputeNumSignBits(N0)))
11375 return DAG.getNode(ISD::SHL, DL, VT, N0, N1);
11376
11377 // fold (ushlsat x, c) -> (shl x, c)
11378 if (N->getOpcode() == ISD::USHLSAT && N1C &&
11379 N1C->getAPIntValue().ule(
11380 DAG.computeKnownBits(N0).countMinLeadingZeros()))
11381 return DAG.getNode(ISD::SHL, DL, VT, N0, N1);
11382 }
11383
11384 return SDValue();
11385 }
11386
11387 // Given a ABS node, detect the following patterns:
11388 // (ABS (SUB (EXTEND a), (EXTEND b))).
11389 // (TRUNC (ABS (SUB (EXTEND a), (EXTEND b)))).
11390 // Generates UABD/SABD instruction.
foldABSToABD(SDNode * N,const SDLoc & DL)11391 SDValue DAGCombiner::foldABSToABD(SDNode *N, const SDLoc &DL) {
11392 EVT SrcVT = N->getValueType(0);
11393
11394 if (N->getOpcode() == ISD::TRUNCATE)
11395 N = N->getOperand(0).getNode();
11396
11397 EVT VT = N->getValueType(0);
11398 SDValue Op0, Op1;
11399
11400 if (!sd_match(N, m_Abs(m_Sub(m_Value(Op0), m_Value(Op1)))))
11401 return SDValue();
11402
11403 SDValue AbsOp0 = N->getOperand(0);
11404 unsigned Opc0 = Op0.getOpcode();
11405
11406 // Check if the operands of the sub are (zero|sign)-extended, otherwise
11407 // fallback to ValueTracking.
11408 if (Opc0 != Op1.getOpcode() ||
11409 (Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND &&
11410 Opc0 != ISD::SIGN_EXTEND_INREG)) {
11411 // fold (abs (sub nsw x, y)) -> abds(x, y)
11412 // Don't fold this for unsupported types as we lose the NSW handling.
11413 if (hasOperation(ISD::ABDS, VT) && TLI.preferABDSToABSWithNSW(VT) &&
11414 (AbsOp0->getFlags().hasNoSignedWrap() ||
11415 DAG.willNotOverflowSub(/*IsSigned=*/true, Op0, Op1))) {
11416 SDValue ABD = DAG.getNode(ISD::ABDS, DL, VT, Op0, Op1);
11417 return DAG.getZExtOrTrunc(ABD, DL, SrcVT);
11418 }
11419 // fold (abs (sub x, y)) -> abdu(x, y)
11420 if (hasOperation(ISD::ABDU, VT) && DAG.SignBitIsZero(Op0) &&
11421 DAG.SignBitIsZero(Op1)) {
11422 SDValue ABD = DAG.getNode(ISD::ABDU, DL, VT, Op0, Op1);
11423 return DAG.getZExtOrTrunc(ABD, DL, SrcVT);
11424 }
11425 return SDValue();
11426 }
11427
11428 EVT VT0, VT1;
11429 if (Opc0 == ISD::SIGN_EXTEND_INREG) {
11430 VT0 = cast<VTSDNode>(Op0.getOperand(1))->getVT();
11431 VT1 = cast<VTSDNode>(Op1.getOperand(1))->getVT();
11432 } else {
11433 VT0 = Op0.getOperand(0).getValueType();
11434 VT1 = Op1.getOperand(0).getValueType();
11435 }
11436 unsigned ABDOpcode = (Opc0 == ISD::ZERO_EXTEND) ? ISD::ABDU : ISD::ABDS;
11437
11438 // fold abs(sext(x) - sext(y)) -> zext(abds(x, y))
11439 // fold abs(zext(x) - zext(y)) -> zext(abdu(x, y))
11440 EVT MaxVT = VT0.bitsGT(VT1) ? VT0 : VT1;
11441 if ((VT0 == MaxVT || Op0->hasOneUse()) &&
11442 (VT1 == MaxVT || Op1->hasOneUse()) &&
11443 (!LegalTypes || hasOperation(ABDOpcode, MaxVT))) {
11444 SDValue ABD = DAG.getNode(ABDOpcode, DL, MaxVT,
11445 DAG.getNode(ISD::TRUNCATE, DL, MaxVT, Op0),
11446 DAG.getNode(ISD::TRUNCATE, DL, MaxVT, Op1));
11447 ABD = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, ABD);
11448 return DAG.getZExtOrTrunc(ABD, DL, SrcVT);
11449 }
11450
11451 // fold abs(sext(x) - sext(y)) -> abds(sext(x), sext(y))
11452 // fold abs(zext(x) - zext(y)) -> abdu(zext(x), zext(y))
11453 if (!LegalOperations || hasOperation(ABDOpcode, VT)) {
11454 SDValue ABD = DAG.getNode(ABDOpcode, DL, VT, Op0, Op1);
11455 return DAG.getZExtOrTrunc(ABD, DL, SrcVT);
11456 }
11457
11458 return SDValue();
11459 }
11460
visitABS(SDNode * N)11461 SDValue DAGCombiner::visitABS(SDNode *N) {
11462 SDValue N0 = N->getOperand(0);
11463 EVT VT = N->getValueType(0);
11464 SDLoc DL(N);
11465
11466 // fold (abs c1) -> c2
11467 if (SDValue C = DAG.FoldConstantArithmetic(ISD::ABS, DL, VT, {N0}))
11468 return C;
11469 // fold (abs (abs x)) -> (abs x)
11470 if (N0.getOpcode() == ISD::ABS)
11471 return N0;
11472 // fold (abs x) -> x iff not-negative
11473 if (DAG.SignBitIsZero(N0))
11474 return N0;
11475
11476 if (SDValue ABD = foldABSToABD(N, DL))
11477 return ABD;
11478
11479 // fold (abs (sign_extend_inreg x)) -> (zero_extend (abs (truncate x)))
11480 // iff zero_extend/truncate are free.
11481 if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG) {
11482 EVT ExtVT = cast<VTSDNode>(N0.getOperand(1))->getVT();
11483 if (TLI.isTruncateFree(VT, ExtVT) && TLI.isZExtFree(ExtVT, VT) &&
11484 TLI.isTypeDesirableForOp(ISD::ABS, ExtVT) &&
11485 hasOperation(ISD::ABS, ExtVT)) {
11486 return DAG.getNode(
11487 ISD::ZERO_EXTEND, DL, VT,
11488 DAG.getNode(ISD::ABS, DL, ExtVT,
11489 DAG.getNode(ISD::TRUNCATE, DL, ExtVT, N0.getOperand(0))));
11490 }
11491 }
11492
11493 return SDValue();
11494 }
11495
visitBSWAP(SDNode * N)11496 SDValue DAGCombiner::visitBSWAP(SDNode *N) {
11497 SDValue N0 = N->getOperand(0);
11498 EVT VT = N->getValueType(0);
11499 SDLoc DL(N);
11500
11501 // fold (bswap c1) -> c2
11502 if (SDValue C = DAG.FoldConstantArithmetic(ISD::BSWAP, DL, VT, {N0}))
11503 return C;
11504 // fold (bswap (bswap x)) -> x
11505 if (N0.getOpcode() == ISD::BSWAP)
11506 return N0.getOperand(0);
11507
11508 // Canonicalize bswap(bitreverse(x)) -> bitreverse(bswap(x)). If bitreverse
11509 // isn't supported, it will be expanded to bswap followed by a manual reversal
11510 // of bits in each byte. By placing bswaps before bitreverse, we can remove
11511 // the two bswaps if the bitreverse gets expanded.
11512 if (N0.getOpcode() == ISD::BITREVERSE && N0.hasOneUse()) {
11513 SDValue BSwap = DAG.getNode(ISD::BSWAP, DL, VT, N0.getOperand(0));
11514 return DAG.getNode(ISD::BITREVERSE, DL, VT, BSwap);
11515 }
11516
11517 // fold (bswap shl(x,c)) -> (zext(bswap(trunc(shl(x,sub(c,bw/2))))))
11518 // iff x >= bw/2 (i.e. lower half is known zero)
11519 unsigned BW = VT.getScalarSizeInBits();
11520 if (BW >= 32 && N0.getOpcode() == ISD::SHL && N0.hasOneUse()) {
11521 auto *ShAmt = dyn_cast<ConstantSDNode>(N0.getOperand(1));
11522 EVT HalfVT = EVT::getIntegerVT(*DAG.getContext(), BW / 2);
11523 if (ShAmt && ShAmt->getAPIntValue().ult(BW) &&
11524 ShAmt->getZExtValue() >= (BW / 2) &&
11525 (ShAmt->getZExtValue() % 16) == 0 && TLI.isTypeLegal(HalfVT) &&
11526 TLI.isTruncateFree(VT, HalfVT) &&
11527 (!LegalOperations || hasOperation(ISD::BSWAP, HalfVT))) {
11528 SDValue Res = N0.getOperand(0);
11529 if (uint64_t NewShAmt = (ShAmt->getZExtValue() - (BW / 2)))
11530 Res = DAG.getNode(ISD::SHL, DL, VT, Res,
11531 DAG.getShiftAmountConstant(NewShAmt, VT, DL));
11532 Res = DAG.getZExtOrTrunc(Res, DL, HalfVT);
11533 Res = DAG.getNode(ISD::BSWAP, DL, HalfVT, Res);
11534 return DAG.getZExtOrTrunc(Res, DL, VT);
11535 }
11536 }
11537
11538 // Try to canonicalize bswap-of-logical-shift-by-8-bit-multiple as
11539 // inverse-shift-of-bswap:
11540 // bswap (X u<< C) --> (bswap X) u>> C
11541 // bswap (X u>> C) --> (bswap X) u<< C
11542 if ((N0.getOpcode() == ISD::SHL || N0.getOpcode() == ISD::SRL) &&
11543 N0.hasOneUse()) {
11544 auto *ShAmt = dyn_cast<ConstantSDNode>(N0.getOperand(1));
11545 if (ShAmt && ShAmt->getAPIntValue().ult(BW) &&
11546 ShAmt->getZExtValue() % 8 == 0) {
11547 SDValue NewSwap = DAG.getNode(ISD::BSWAP, DL, VT, N0.getOperand(0));
11548 unsigned InverseShift = N0.getOpcode() == ISD::SHL ? ISD::SRL : ISD::SHL;
11549 return DAG.getNode(InverseShift, DL, VT, NewSwap, N0.getOperand(1));
11550 }
11551 }
11552
11553 if (SDValue V = foldBitOrderCrossLogicOp(N, DAG))
11554 return V;
11555
11556 return SDValue();
11557 }
11558
visitBITREVERSE(SDNode * N)11559 SDValue DAGCombiner::visitBITREVERSE(SDNode *N) {
11560 SDValue N0 = N->getOperand(0);
11561 EVT VT = N->getValueType(0);
11562 SDLoc DL(N);
11563
11564 // fold (bitreverse c1) -> c2
11565 if (SDValue C = DAG.FoldConstantArithmetic(ISD::BITREVERSE, DL, VT, {N0}))
11566 return C;
11567
11568 // fold (bitreverse (bitreverse x)) -> x
11569 if (N0.getOpcode() == ISD::BITREVERSE)
11570 return N0.getOperand(0);
11571
11572 SDValue X, Y;
11573
11574 // fold (bitreverse (lshr (bitreverse x), y)) -> (shl x, y)
11575 if ((!LegalOperations || TLI.isOperationLegal(ISD::SHL, VT)) &&
11576 sd_match(N, m_BitReverse(m_Srl(m_BitReverse(m_Value(X)), m_Value(Y)))))
11577 return DAG.getNode(ISD::SHL, DL, VT, X, Y);
11578
11579 // fold (bitreverse (shl (bitreverse x), y)) -> (lshr x, y)
11580 if ((!LegalOperations || TLI.isOperationLegal(ISD::SRL, VT)) &&
11581 sd_match(N, m_BitReverse(m_Shl(m_BitReverse(m_Value(X)), m_Value(Y)))))
11582 return DAG.getNode(ISD::SRL, DL, VT, X, Y);
11583
11584 return SDValue();
11585 }
11586
visitCTLZ(SDNode * N)11587 SDValue DAGCombiner::visitCTLZ(SDNode *N) {
11588 SDValue N0 = N->getOperand(0);
11589 EVT VT = N->getValueType(0);
11590 SDLoc DL(N);
11591
11592 // fold (ctlz c1) -> c2
11593 if (SDValue C = DAG.FoldConstantArithmetic(ISD::CTLZ, DL, VT, {N0}))
11594 return C;
11595
11596 // If the value is known never to be zero, switch to the undef version.
11597 if (!LegalOperations || TLI.isOperationLegal(ISD::CTLZ_ZERO_UNDEF, VT))
11598 if (DAG.isKnownNeverZero(N0))
11599 return DAG.getNode(ISD::CTLZ_ZERO_UNDEF, DL, VT, N0);
11600
11601 return SDValue();
11602 }
11603
visitCTLZ_ZERO_UNDEF(SDNode * N)11604 SDValue DAGCombiner::visitCTLZ_ZERO_UNDEF(SDNode *N) {
11605 SDValue N0 = N->getOperand(0);
11606 EVT VT = N->getValueType(0);
11607 SDLoc DL(N);
11608
11609 // fold (ctlz_zero_undef c1) -> c2
11610 if (SDValue C =
11611 DAG.FoldConstantArithmetic(ISD::CTLZ_ZERO_UNDEF, DL, VT, {N0}))
11612 return C;
11613 return SDValue();
11614 }
11615
visitCTTZ(SDNode * N)11616 SDValue DAGCombiner::visitCTTZ(SDNode *N) {
11617 SDValue N0 = N->getOperand(0);
11618 EVT VT = N->getValueType(0);
11619 SDLoc DL(N);
11620
11621 // fold (cttz c1) -> c2
11622 if (SDValue C = DAG.FoldConstantArithmetic(ISD::CTTZ, DL, VT, {N0}))
11623 return C;
11624
11625 // If the value is known never to be zero, switch to the undef version.
11626 if (!LegalOperations || TLI.isOperationLegal(ISD::CTTZ_ZERO_UNDEF, VT))
11627 if (DAG.isKnownNeverZero(N0))
11628 return DAG.getNode(ISD::CTTZ_ZERO_UNDEF, DL, VT, N0);
11629
11630 return SDValue();
11631 }
11632
visitCTTZ_ZERO_UNDEF(SDNode * N)11633 SDValue DAGCombiner::visitCTTZ_ZERO_UNDEF(SDNode *N) {
11634 SDValue N0 = N->getOperand(0);
11635 EVT VT = N->getValueType(0);
11636 SDLoc DL(N);
11637
11638 // fold (cttz_zero_undef c1) -> c2
11639 if (SDValue C =
11640 DAG.FoldConstantArithmetic(ISD::CTTZ_ZERO_UNDEF, DL, VT, {N0}))
11641 return C;
11642 return SDValue();
11643 }
11644
visitCTPOP(SDNode * N)11645 SDValue DAGCombiner::visitCTPOP(SDNode *N) {
11646 SDValue N0 = N->getOperand(0);
11647 EVT VT = N->getValueType(0);
11648 unsigned NumBits = VT.getScalarSizeInBits();
11649 SDLoc DL(N);
11650
11651 // fold (ctpop c1) -> c2
11652 if (SDValue C = DAG.FoldConstantArithmetic(ISD::CTPOP, DL, VT, {N0}))
11653 return C;
11654
11655 // If the source is being shifted, but doesn't affect any active bits,
11656 // then we can call CTPOP on the shift source directly.
11657 if (N0.getOpcode() == ISD::SRL || N0.getOpcode() == ISD::SHL) {
11658 if (ConstantSDNode *AmtC = isConstOrConstSplat(N0.getOperand(1))) {
11659 const APInt &Amt = AmtC->getAPIntValue();
11660 if (Amt.ult(NumBits)) {
11661 KnownBits KnownSrc = DAG.computeKnownBits(N0.getOperand(0));
11662 if ((N0.getOpcode() == ISD::SRL &&
11663 Amt.ule(KnownSrc.countMinTrailingZeros())) ||
11664 (N0.getOpcode() == ISD::SHL &&
11665 Amt.ule(KnownSrc.countMinLeadingZeros()))) {
11666 return DAG.getNode(ISD::CTPOP, DL, VT, N0.getOperand(0));
11667 }
11668 }
11669 }
11670 }
11671
11672 // If the upper bits are known to be zero, then see if its profitable to
11673 // only count the lower bits.
11674 if (VT.isScalarInteger() && NumBits > 8 && (NumBits & 1) == 0) {
11675 EVT HalfVT = EVT::getIntegerVT(*DAG.getContext(), NumBits / 2);
11676 if (hasOperation(ISD::CTPOP, HalfVT) &&
11677 TLI.isTypeDesirableForOp(ISD::CTPOP, HalfVT) &&
11678 TLI.isTruncateFree(N0, HalfVT) && TLI.isZExtFree(HalfVT, VT)) {
11679 APInt UpperBits = APInt::getHighBitsSet(NumBits, NumBits / 2);
11680 if (DAG.MaskedValueIsZero(N0, UpperBits)) {
11681 SDValue PopCnt = DAG.getNode(ISD::CTPOP, DL, HalfVT,
11682 DAG.getZExtOrTrunc(N0, DL, HalfVT));
11683 return DAG.getZExtOrTrunc(PopCnt, DL, VT);
11684 }
11685 }
11686 }
11687
11688 return SDValue();
11689 }
11690
isLegalToCombineMinNumMaxNum(SelectionDAG & DAG,SDValue LHS,SDValue RHS,const SDNodeFlags Flags,const TargetLowering & TLI)11691 static bool isLegalToCombineMinNumMaxNum(SelectionDAG &DAG, SDValue LHS,
11692 SDValue RHS, const SDNodeFlags Flags,
11693 const TargetLowering &TLI) {
11694 EVT VT = LHS.getValueType();
11695 if (!VT.isFloatingPoint())
11696 return false;
11697
11698 const TargetOptions &Options = DAG.getTarget().Options;
11699
11700 return (Flags.hasNoSignedZeros() || Options.NoSignedZerosFPMath) &&
11701 TLI.isProfitableToCombineMinNumMaxNum(VT) &&
11702 (Flags.hasNoNaNs() ||
11703 (DAG.isKnownNeverNaN(RHS) && DAG.isKnownNeverNaN(LHS)));
11704 }
11705
combineMinNumMaxNumImpl(const SDLoc & DL,EVT VT,SDValue LHS,SDValue RHS,SDValue True,SDValue False,ISD::CondCode CC,const TargetLowering & TLI,SelectionDAG & DAG)11706 static SDValue combineMinNumMaxNumImpl(const SDLoc &DL, EVT VT, SDValue LHS,
11707 SDValue RHS, SDValue True, SDValue False,
11708 ISD::CondCode CC,
11709 const TargetLowering &TLI,
11710 SelectionDAG &DAG) {
11711 EVT TransformVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
11712 switch (CC) {
11713 case ISD::SETOLT:
11714 case ISD::SETOLE:
11715 case ISD::SETLT:
11716 case ISD::SETLE:
11717 case ISD::SETULT:
11718 case ISD::SETULE: {
11719 // Since it's known never nan to get here already, either fminnum or
11720 // fminnum_ieee are OK. Try the ieee version first, since it's fminnum is
11721 // expanded in terms of it.
11722 unsigned IEEEOpcode = (LHS == True) ? ISD::FMINNUM_IEEE : ISD::FMAXNUM_IEEE;
11723 if (TLI.isOperationLegalOrCustom(IEEEOpcode, VT))
11724 return DAG.getNode(IEEEOpcode, DL, VT, LHS, RHS);
11725
11726 unsigned Opcode = (LHS == True) ? ISD::FMINNUM : ISD::FMAXNUM;
11727 if (TLI.isOperationLegalOrCustom(Opcode, TransformVT))
11728 return DAG.getNode(Opcode, DL, VT, LHS, RHS);
11729 return SDValue();
11730 }
11731 case ISD::SETOGT:
11732 case ISD::SETOGE:
11733 case ISD::SETGT:
11734 case ISD::SETGE:
11735 case ISD::SETUGT:
11736 case ISD::SETUGE: {
11737 unsigned IEEEOpcode = (LHS == True) ? ISD::FMAXNUM_IEEE : ISD::FMINNUM_IEEE;
11738 if (TLI.isOperationLegalOrCustom(IEEEOpcode, VT))
11739 return DAG.getNode(IEEEOpcode, DL, VT, LHS, RHS);
11740
11741 unsigned Opcode = (LHS == True) ? ISD::FMAXNUM : ISD::FMINNUM;
11742 if (TLI.isOperationLegalOrCustom(Opcode, TransformVT))
11743 return DAG.getNode(Opcode, DL, VT, LHS, RHS);
11744 return SDValue();
11745 }
11746 default:
11747 return SDValue();
11748 }
11749 }
11750
foldShiftToAvg(SDNode * N)11751 SDValue DAGCombiner::foldShiftToAvg(SDNode *N) {
11752 const unsigned Opcode = N->getOpcode();
11753
11754 // Convert (sr[al] (add n[su]w x, y)) -> (avgfloor[su] x, y)
11755 if (Opcode != ISD::SRA && Opcode != ISD::SRL)
11756 return SDValue();
11757
11758 unsigned FloorISD = 0;
11759 auto VT = N->getValueType(0);
11760 bool IsUnsigned = false;
11761
11762 // Decide wether signed or unsigned.
11763 switch (Opcode) {
11764 case ISD::SRA:
11765 if (!hasOperation(ISD::AVGFLOORS, VT))
11766 return SDValue();
11767 FloorISD = ISD::AVGFLOORS;
11768 break;
11769 case ISD::SRL:
11770 IsUnsigned = true;
11771 if (!hasOperation(ISD::AVGFLOORU, VT))
11772 return SDValue();
11773 FloorISD = ISD::AVGFLOORU;
11774 break;
11775 default:
11776 return SDValue();
11777 }
11778
11779 // Captured values.
11780 SDValue A, B, Add;
11781
11782 // Match floor average as it is common to both floor/ceil avgs.
11783 if (!sd_match(N, m_BinOp(Opcode,
11784 m_AllOf(m_Value(Add), m_Add(m_Value(A), m_Value(B))),
11785 m_One())))
11786 return SDValue();
11787
11788 // Can't optimize adds that may wrap.
11789 if (IsUnsigned && !Add->getFlags().hasNoUnsignedWrap())
11790 return SDValue();
11791
11792 if (!IsUnsigned && !Add->getFlags().hasNoSignedWrap())
11793 return SDValue();
11794
11795 return DAG.getNode(FloorISD, SDLoc(N), N->getValueType(0), {A, B});
11796 }
11797
foldBitwiseOpWithNeg(SDNode * N,const SDLoc & DL,EVT VT)11798 SDValue DAGCombiner::foldBitwiseOpWithNeg(SDNode *N, const SDLoc &DL, EVT VT) {
11799 unsigned Opc = N->getOpcode();
11800 SDValue X, Y, Z;
11801 if (sd_match(
11802 N, m_BitwiseLogic(m_Value(X), m_Add(m_Not(m_Value(Y)), m_Value(Z)))))
11803 return DAG.getNode(Opc, DL, VT, X,
11804 DAG.getNOT(DL, DAG.getNode(ISD::SUB, DL, VT, Y, Z), VT));
11805
11806 if (sd_match(N, m_BitwiseLogic(m_Value(X), m_Sub(m_OneUse(m_Not(m_Value(Y))),
11807 m_Value(Z)))))
11808 return DAG.getNode(Opc, DL, VT, X,
11809 DAG.getNOT(DL, DAG.getNode(ISD::ADD, DL, VT, Y, Z), VT));
11810
11811 return SDValue();
11812 }
11813
/// Generate Min/Max node
///
/// Tries to express "select (setcc LHS, RHS, CC), True, False" as a
/// floating-point min/max node.
///
/// \param LHS,RHS     operands of the comparison.
/// \param True,False  the two arms of the select.
/// \param CC          the comparison predicate.
/// \returns the min/max node (wrapped in an FNEG for the negated form), or an
/// empty SDValue if the select does not have one of the recognised shapes.
SDValue DAGCombiner::combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
                                         SDValue RHS, SDValue True,
                                         SDValue False, ISD::CondCode CC) {
  // Direct form: the select arms are exactly the compared values, in either
  // order.
  if ((LHS == True && RHS == False) || (LHS == False && RHS == True))
    return combineMinNumMaxNumImpl(DL, VT, LHS, RHS, True, False, CC, TLI, DAG);

  // If we can't directly match this, try to see if we can pull an fneg out of
  // the select.
  SDValue NegTrue = TLI.getCheaperOrNeutralNegatedExpression(
      True, DAG, LegalOperations, ForCodeSize);
  if (!NegTrue)
    return SDValue();

  // NOTE(review): presumably the handle keeps NegTrue valid while the calls
  // below create further nodes - confirm against HandleSDNode's documentation.
  HandleSDNode NegTrueHandle(NegTrue);

  // Try to unfold an fneg from the select if we are comparing the negated
  // constant.
  //
  // select (setcc x, K) (fneg x), -K -> fneg(minnum(x, K))
  //
  // TODO: Handle fabs
  if (LHS == NegTrue) {
    // The true arm is the negation of LHS; the fold additionally needs the
    // false arm to be the negation of RHS.
    SDValue NegRHS = TLI.getCheaperOrNeutralNegatedExpression(
        RHS, DAG, LegalOperations, ForCodeSize);
    if (NegRHS) {
      HandleSDNode NegRHSHandle(NegRHS);
      if (NegRHS == False) {
        // Build the min/max as if the true arm were NegTrue (which equals
        // LHS), then negate the whole result.
        SDValue Combined = combineMinNumMaxNumImpl(DL, VT, LHS, RHS, NegTrue,
                                                   False, CC, TLI, DAG);
        if (Combined)
          return DAG.getNode(ISD::FNEG, DL, VT, Combined);
      }
    }
  }

  return SDValue();
}
11854
11855 /// If a (v)select has a condition value that is a sign-bit test, try to smear
11856 /// the condition operand sign-bit across the value width and use it as a mask.
foldSelectOfConstantsUsingSra(SDNode * N,const SDLoc & DL,SelectionDAG & DAG)11857 static SDValue foldSelectOfConstantsUsingSra(SDNode *N, const SDLoc &DL,
11858 SelectionDAG &DAG) {
11859 SDValue Cond = N->getOperand(0);
11860 SDValue C1 = N->getOperand(1);
11861 SDValue C2 = N->getOperand(2);
11862 if (!isConstantOrConstantVector(C1) || !isConstantOrConstantVector(C2))
11863 return SDValue();
11864
11865 EVT VT = N->getValueType(0);
11866 if (Cond.getOpcode() != ISD::SETCC || !Cond.hasOneUse() ||
11867 VT != Cond.getOperand(0).getValueType())
11868 return SDValue();
11869
11870 // The inverted-condition + commuted-select variants of these patterns are
11871 // canonicalized to these forms in IR.
11872 SDValue X = Cond.getOperand(0);
11873 SDValue CondC = Cond.getOperand(1);
11874 ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
11875 if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(CondC) &&
11876 isAllOnesOrAllOnesSplat(C2)) {
11877 // i32 X > -1 ? C1 : -1 --> (X >>s 31) | C1
11878 SDValue ShAmtC = DAG.getConstant(X.getScalarValueSizeInBits() - 1, DL, VT);
11879 SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, X, ShAmtC);
11880 return DAG.getNode(ISD::OR, DL, VT, Sra, C1);
11881 }
11882 if (CC == ISD::SETLT && isNullOrNullSplat(CondC) && isNullOrNullSplat(C2)) {
11883 // i8 X < 0 ? C1 : 0 --> (X >>s 7) & C1
11884 SDValue ShAmtC = DAG.getConstant(X.getScalarValueSizeInBits() - 1, DL, VT);
11885 SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, X, ShAmtC);
11886 return DAG.getNode(ISD::AND, DL, VT, Sra, C1);
11887 }
11888 return SDValue();
11889 }
11890
shouldConvertSelectOfConstantsToMath(const SDValue & Cond,EVT VT,const TargetLowering & TLI)11891 static bool shouldConvertSelectOfConstantsToMath(const SDValue &Cond, EVT VT,
11892 const TargetLowering &TLI) {
11893 if (!TLI.convertSelectOfConstantsToMath(VT))
11894 return false;
11895
11896 if (Cond.getOpcode() != ISD::SETCC || !Cond->hasOneUse())
11897 return true;
11898 if (!TLI.isOperationLegalOrCustom(ISD::SELECT_CC, VT))
11899 return true;
11900
11901 ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
11902 if (CC == ISD::SETLT && isNullOrNullSplat(Cond.getOperand(1)))
11903 return true;
11904 if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(Cond.getOperand(1)))
11905 return true;
11906
11907 return false;
11908 }
11909
foldSelectOfConstants(SDNode * N)11910 SDValue DAGCombiner::foldSelectOfConstants(SDNode *N) {
11911 SDValue Cond = N->getOperand(0);
11912 SDValue N1 = N->getOperand(1);
11913 SDValue N2 = N->getOperand(2);
11914 EVT VT = N->getValueType(0);
11915 EVT CondVT = Cond.getValueType();
11916 SDLoc DL(N);
11917
11918 if (!VT.isInteger())
11919 return SDValue();
11920
11921 auto *C1 = dyn_cast<ConstantSDNode>(N1);
11922 auto *C2 = dyn_cast<ConstantSDNode>(N2);
11923 if (!C1 || !C2)
11924 return SDValue();
11925
11926 if (CondVT != MVT::i1 || LegalOperations) {
11927 // fold (select Cond, 0, 1) -> (xor Cond, 1)
11928 // We can't do this reliably if integer based booleans have different contents
11929 // to floating point based booleans. This is because we can't tell whether we
11930 // have an integer-based boolean or a floating-point-based boolean unless we
11931 // can find the SETCC that produced it and inspect its operands. This is
11932 // fairly easy if C is the SETCC node, but it can potentially be
11933 // undiscoverable (or not reasonably discoverable). For example, it could be
11934 // in another basic block or it could require searching a complicated
11935 // expression.
11936 if (CondVT.isInteger() &&
11937 TLI.getBooleanContents(/*isVec*/false, /*isFloat*/true) ==
11938 TargetLowering::ZeroOrOneBooleanContent &&
11939 TLI.getBooleanContents(/*isVec*/false, /*isFloat*/false) ==
11940 TargetLowering::ZeroOrOneBooleanContent &&
11941 C1->isZero() && C2->isOne()) {
11942 SDValue NotCond =
11943 DAG.getNode(ISD::XOR, DL, CondVT, Cond, DAG.getConstant(1, DL, CondVT));
11944 if (VT.bitsEq(CondVT))
11945 return NotCond;
11946 return DAG.getZExtOrTrunc(NotCond, DL, VT);
11947 }
11948
11949 return SDValue();
11950 }
11951
11952 // Only do this before legalization to avoid conflicting with target-specific
11953 // transforms in the other direction (create a select from a zext/sext). There
11954 // is also a target-independent combine here in DAGCombiner in the other
11955 // direction for (select Cond, -1, 0) when the condition is not i1.
11956 assert(CondVT == MVT::i1 && !LegalOperations);
11957
11958 // select Cond, 1, 0 --> zext (Cond)
11959 if (C1->isOne() && C2->isZero())
11960 return DAG.getZExtOrTrunc(Cond, DL, VT);
11961
11962 // select Cond, -1, 0 --> sext (Cond)
11963 if (C1->isAllOnes() && C2->isZero())
11964 return DAG.getSExtOrTrunc(Cond, DL, VT);
11965
11966 // select Cond, 0, 1 --> zext (!Cond)
11967 if (C1->isZero() && C2->isOne()) {
11968 SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
11969 NotCond = DAG.getZExtOrTrunc(NotCond, DL, VT);
11970 return NotCond;
11971 }
11972
11973 // select Cond, 0, -1 --> sext (!Cond)
11974 if (C1->isZero() && C2->isAllOnes()) {
11975 SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
11976 NotCond = DAG.getSExtOrTrunc(NotCond, DL, VT);
11977 return NotCond;
11978 }
11979
11980 // Use a target hook because some targets may prefer to transform in the
11981 // other direction.
11982 if (!shouldConvertSelectOfConstantsToMath(Cond, VT, TLI))
11983 return SDValue();
11984
11985 // For any constants that differ by 1, we can transform the select into
11986 // an extend and add.
11987 const APInt &C1Val = C1->getAPIntValue();
11988 const APInt &C2Val = C2->getAPIntValue();
11989
11990 // select Cond, C1, C1-1 --> add (zext Cond), C1-1
11991 if (C1Val - 1 == C2Val) {
11992 Cond = DAG.getZExtOrTrunc(Cond, DL, VT);
11993 return DAG.getNode(ISD::ADD, DL, VT, Cond, N2);
11994 }
11995
11996 // select Cond, C1, C1+1 --> add (sext Cond), C1+1
11997 if (C1Val + 1 == C2Val) {
11998 Cond = DAG.getSExtOrTrunc(Cond, DL, VT);
11999 return DAG.getNode(ISD::ADD, DL, VT, Cond, N2);
12000 }
12001
12002 // select Cond, Pow2, 0 --> (zext Cond) << log2(Pow2)
12003 if (C1Val.isPowerOf2() && C2Val.isZero()) {
12004 Cond = DAG.getZExtOrTrunc(Cond, DL, VT);
12005 SDValue ShAmtC =
12006 DAG.getShiftAmountConstant(C1Val.exactLogBase2(), VT, DL);
12007 return DAG.getNode(ISD::SHL, DL, VT, Cond, ShAmtC);
12008 }
12009
12010 // select Cond, -1, C --> or (sext Cond), C
12011 if (C1->isAllOnes()) {
12012 Cond = DAG.getSExtOrTrunc(Cond, DL, VT);
12013 return DAG.getNode(ISD::OR, DL, VT, Cond, N2);
12014 }
12015
12016 // select Cond, C, -1 --> or (sext (not Cond)), C
12017 if (C2->isAllOnes()) {
12018 SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
12019 NotCond = DAG.getSExtOrTrunc(NotCond, DL, VT);
12020 return DAG.getNode(ISD::OR, DL, VT, NotCond, N1);
12021 }
12022
12023 if (SDValue V = foldSelectOfConstantsUsingSra(N, DL, DAG))
12024 return V;
12025
12026 return SDValue();
12027 }
12028
12029 template <class MatchContextClass>
foldBoolSelectToLogic(SDNode * N,const SDLoc & DL,SelectionDAG & DAG)12030 static SDValue foldBoolSelectToLogic(SDNode *N, const SDLoc &DL,
12031 SelectionDAG &DAG) {
12032 assert((N->getOpcode() == ISD::SELECT || N->getOpcode() == ISD::VSELECT ||
12033 N->getOpcode() == ISD::VP_SELECT) &&
12034 "Expected a (v)(vp.)select");
12035 SDValue Cond = N->getOperand(0);
12036 SDValue T = N->getOperand(1), F = N->getOperand(2);
12037 EVT VT = N->getValueType(0);
12038 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
12039 MatchContextClass matcher(DAG, TLI, N);
12040
12041 if (VT != Cond.getValueType() || VT.getScalarSizeInBits() != 1)
12042 return SDValue();
12043
12044 // select Cond, Cond, F --> or Cond, freeze(F)
12045 // select Cond, 1, F --> or Cond, freeze(F)
12046 if (Cond == T || isOneOrOneSplat(T, /* AllowUndefs */ true))
12047 return matcher.getNode(ISD::OR, DL, VT, Cond, DAG.getFreeze(F));
12048
12049 // select Cond, T, Cond --> and Cond, freeze(T)
12050 // select Cond, T, 0 --> and Cond, freeze(T)
12051 if (Cond == F || isNullOrNullSplat(F, /* AllowUndefs */ true))
12052 return matcher.getNode(ISD::AND, DL, VT, Cond, DAG.getFreeze(T));
12053
12054 // select Cond, T, 1 --> or (not Cond), freeze(T)
12055 if (isOneOrOneSplat(F, /* AllowUndefs */ true)) {
12056 SDValue NotCond =
12057 matcher.getNode(ISD::XOR, DL, VT, Cond, DAG.getAllOnesConstant(DL, VT));
12058 return matcher.getNode(ISD::OR, DL, VT, NotCond, DAG.getFreeze(T));
12059 }
12060
12061 // select Cond, 0, F --> and (not Cond), freeze(F)
12062 if (isNullOrNullSplat(T, /* AllowUndefs */ true)) {
12063 SDValue NotCond =
12064 matcher.getNode(ISD::XOR, DL, VT, Cond, DAG.getAllOnesConstant(DL, VT));
12065 return matcher.getNode(ISD::AND, DL, VT, NotCond, DAG.getFreeze(F));
12066 }
12067
12068 return SDValue();
12069 }
12070
foldVSelectToSignBitSplatMask(SDNode * N,SelectionDAG & DAG)12071 static SDValue foldVSelectToSignBitSplatMask(SDNode *N, SelectionDAG &DAG) {
12072 SDValue N0 = N->getOperand(0);
12073 SDValue N1 = N->getOperand(1);
12074 SDValue N2 = N->getOperand(2);
12075 EVT VT = N->getValueType(0);
12076 unsigned EltSizeInBits = VT.getScalarSizeInBits();
12077
12078 SDValue Cond0, Cond1;
12079 ISD::CondCode CC;
12080 if (!sd_match(N0, m_OneUse(m_SetCC(m_Value(Cond0), m_Value(Cond1),
12081 m_CondCode(CC)))) ||
12082 VT != Cond0.getValueType())
12083 return SDValue();
12084
12085 // Match a signbit check of Cond0 as "Cond0 s<0". Swap select operands if the
12086 // compare is inverted from that pattern ("Cond0 s> -1").
12087 if (CC == ISD::SETLT && isNullOrNullSplat(Cond1))
12088 ; // This is the pattern we are looking for.
12089 else if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(Cond1))
12090 std::swap(N1, N2);
12091 else
12092 return SDValue();
12093
12094 // (Cond0 s< 0) ? N1 : 0 --> (Cond0 s>> BW-1) & freeze(N1)
12095 if (isNullOrNullSplat(N2)) {
12096 SDLoc DL(N);
12097 SDValue ShiftAmt = DAG.getShiftAmountConstant(EltSizeInBits - 1, VT, DL);
12098 SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Cond0, ShiftAmt);
12099 return DAG.getNode(ISD::AND, DL, VT, Sra, DAG.getFreeze(N1));
12100 }
12101
12102 // (Cond0 s< 0) ? -1 : N2 --> (Cond0 s>> BW-1) | freeze(N2)
12103 if (isAllOnesOrAllOnesSplat(N1)) {
12104 SDLoc DL(N);
12105 SDValue ShiftAmt = DAG.getShiftAmountConstant(EltSizeInBits - 1, VT, DL);
12106 SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Cond0, ShiftAmt);
12107 return DAG.getNode(ISD::OR, DL, VT, Sra, DAG.getFreeze(N2));
12108 }
12109
12110 // If we have to invert the sign bit mask, only do that transform if the
12111 // target has a bitwise 'and not' instruction (the invert is free).
12112 // (Cond0 s< -0) ? 0 : N2 --> ~(Cond0 s>> BW-1) & freeze(N2)
12113 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
12114 if (isNullOrNullSplat(N1) && TLI.hasAndNot(N1)) {
12115 SDLoc DL(N);
12116 SDValue ShiftAmt = DAG.getShiftAmountConstant(EltSizeInBits - 1, VT, DL);
12117 SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Cond0, ShiftAmt);
12118 SDValue Not = DAG.getNOT(DL, Sra, VT);
12119 return DAG.getNode(ISD::AND, DL, VT, Not, DAG.getFreeze(N2));
12120 }
12121
12122 // TODO: There's another pattern in this family, but it may require
12123 // implementing hasOrNot() to check for profitability:
12124 // (Cond0 s> -1) ? -1 : N2 --> ~(Cond0 s>> BW-1) | freeze(N2)
12125
12126 return SDValue();
12127 }
12128
12129 // Match SELECTs with absolute difference patterns.
12130 // (select (setcc a, b, set?gt), (sub a, b), (sub b, a)) --> (abd? a, b)
12131 // (select (setcc a, b, set?ge), (sub a, b), (sub b, a)) --> (abd? a, b)
12132 // (select (setcc a, b, set?lt), (sub b, a), (sub a, b)) --> (abd? a, b)
12133 // (select (setcc a, b, set?le), (sub b, a), (sub a, b)) --> (abd? a, b)
foldSelectToABD(SDValue LHS,SDValue RHS,SDValue True,SDValue False,ISD::CondCode CC,const SDLoc & DL)12134 SDValue DAGCombiner::foldSelectToABD(SDValue LHS, SDValue RHS, SDValue True,
12135 SDValue False, ISD::CondCode CC,
12136 const SDLoc &DL) {
12137 bool IsSigned = isSignedIntSetCC(CC);
12138 unsigned ABDOpc = IsSigned ? ISD::ABDS : ISD::ABDU;
12139 EVT VT = LHS.getValueType();
12140
12141 if (LegalOperations && !hasOperation(ABDOpc, VT))
12142 return SDValue();
12143
12144 switch (CC) {
12145 case ISD::SETGT:
12146 case ISD::SETGE:
12147 case ISD::SETUGT:
12148 case ISD::SETUGE:
12149 if (sd_match(True, m_Sub(m_Specific(LHS), m_Specific(RHS))) &&
12150 sd_match(False, m_Sub(m_Specific(RHS), m_Specific(LHS))))
12151 return DAG.getNode(ABDOpc, DL, VT, LHS, RHS);
12152 if (sd_match(True, m_Sub(m_Specific(RHS), m_Specific(LHS))) &&
12153 sd_match(False, m_Sub(m_Specific(LHS), m_Specific(RHS))) &&
12154 hasOperation(ABDOpc, VT))
12155 return DAG.getNegative(DAG.getNode(ABDOpc, DL, VT, LHS, RHS), DL, VT);
12156 break;
12157 case ISD::SETLT:
12158 case ISD::SETLE:
12159 case ISD::SETULT:
12160 case ISD::SETULE:
12161 if (sd_match(True, m_Sub(m_Specific(RHS), m_Specific(LHS))) &&
12162 sd_match(False, m_Sub(m_Specific(LHS), m_Specific(RHS))))
12163 return DAG.getNode(ABDOpc, DL, VT, LHS, RHS);
12164 if (sd_match(True, m_Sub(m_Specific(LHS), m_Specific(RHS))) &&
12165 sd_match(False, m_Sub(m_Specific(RHS), m_Specific(LHS))) &&
12166 hasOperation(ABDOpc, VT))
12167 return DAG.getNegative(DAG.getNode(ABDOpc, DL, VT, LHS, RHS), DL, VT);
12168 break;
12169 default:
12170 break;
12171 }
12172
12173 return SDValue();
12174 }
12175
/// Try a sequence of folds on the select node \p N, whose operands are
/// (condition, true value, false value).
///
/// \returns the replacement value of the first fold that applies,
/// SDValue(N, 0) when SimplifySelectOps reports success, or an empty SDValue
/// when nothing applied.
SDValue DAGCombiner::visitSELECT(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDValue N2 = N->getOperand(2);
  EVT VT = N->getValueType(0);
  EVT VT0 = N0.getValueType();
  SDLoc DL(N);
  SDNodeFlags Flags = N->getFlags();

  if (SDValue V = DAG.simplifySelect(N0, N1, N2))
    return V;

  // Boolean selects can become plain and/or logic.
  if (SDValue V = foldBoolSelectToLogic<EmptyMatchContext>(N, DL, DAG))
    return V;

  // select (not Cond), N1, N2 -> select Cond, N2, N1
  if (SDValue F = extractBooleanFlip(N0, DAG, TLI, false))
    return DAG.getSelect(DL, VT, F, N2, N1, Flags);

  if (SDValue V = foldSelectOfConstants(N))
    return V;

  // If we can fold this based on the true/false value, do so.
  if (SimplifySelectOps(N, N1, N2))
    return SDValue(N, 0); // Don't revisit N.

  if (VT0 == MVT::i1) {
    // The code in this block deals with the following 2 equivalences:
    //    select(C0|C1, x, y) <=> select(C0, x, select(C1, x, y))
    //    select(C0&C1, x, y) <=> select(C0, select(C1, x, y), y)
    // The target can specify its preferred form with the
    // shouldNormalizeToSelectSequence() callback. However, we always transform
    // to the right-hand form if the inner select already exists in the DAG,
    // and we always transform to the left-hand form if we know that we can
    // further optimize the combination of the conditions.
    bool normalizeToSequence =
        TLI.shouldNormalizeToSelectSequence(*DAG.getContext(), VT);
    // select (and Cond0, Cond1), X, Y
    //   -> select Cond0, (select Cond1, X, Y), Y
    if (N0->getOpcode() == ISD::AND && N0->hasOneUse()) {
      SDValue Cond0 = N0->getOperand(0);
      SDValue Cond1 = N0->getOperand(1);
      SDValue InnerSelect =
          DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond1, N1, N2, Flags);
      // A non-empty use list on the node returned by getNode is taken to mean
      // the inner select already existed in the DAG.
      if (normalizeToSequence || !InnerSelect.use_empty())
        return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond0,
                           InnerSelect, N2, Flags);
      // Cleanup on failure.
      if (InnerSelect.use_empty())
        recursivelyDeleteUnusedNodes(InnerSelect.getNode());
    }
    // select (or Cond0, Cond1), X, Y -> select Cond0, X, (select Cond1, X, Y)
    if (N0->getOpcode() == ISD::OR && N0->hasOneUse()) {
      SDValue Cond0 = N0->getOperand(0);
      SDValue Cond1 = N0->getOperand(1);
      SDValue InnerSelect = DAG.getNode(ISD::SELECT, DL, N1.getValueType(),
                                        Cond1, N1, N2, Flags);
      if (normalizeToSequence || !InnerSelect.use_empty())
        return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond0, N1,
                           InnerSelect, Flags);
      // Cleanup on failure.
      if (InnerSelect.use_empty())
        recursivelyDeleteUnusedNodes(InnerSelect.getNode());
    }

    // select Cond0, (select Cond1, X, Y), Y -> select (and Cond0, Cond1), X, Y
    if (N1->getOpcode() == ISD::SELECT && N1->hasOneUse()) {
      SDValue N1_0 = N1->getOperand(0);
      SDValue N1_1 = N1->getOperand(1);
      SDValue N1_2 = N1->getOperand(2);
      if (N1_2 == N2 && N0.getValueType() == N1_0.getValueType()) {
        // Create the actual and node if we can generate good code for it.
        if (!normalizeToSequence) {
          SDValue And = DAG.getNode(ISD::AND, DL, N0.getValueType(), N0, N1_0);
          return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), And, N1_1,
                             N2, Flags);
        }
        // Otherwise see if we can optimize the "and" to a better pattern.
        if (SDValue Combined = visitANDLike(N0, N1_0, N)) {
          return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Combined, N1_1,
                             N2, Flags);
        }
      }
    }
    // select Cond0, X, (select Cond1, X, Y) -> select (or Cond0, Cond1), X, Y
    if (N2->getOpcode() == ISD::SELECT && N2->hasOneUse()) {
      SDValue N2_0 = N2->getOperand(0);
      SDValue N2_1 = N2->getOperand(1);
      SDValue N2_2 = N2->getOperand(2);
      if (N2_1 == N1 && N0.getValueType() == N2_0.getValueType()) {
        // Create the actual or node if we can generate good code for it.
        if (!normalizeToSequence) {
          SDValue Or = DAG.getNode(ISD::OR, DL, N0.getValueType(), N0, N2_0);
          return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Or, N1,
                             N2_2, Flags);
        }
        // Otherwise see if we can optimize to a better pattern.
        if (SDValue Combined = visitORLike(N0, N2_0, DL))
          return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Combined, N1,
                             N2_2, Flags);
      }
    }

    // select usubo(x, y).overflow, (sub y, x), (usubo x, y) -> abdu(x, y)
    // N0 is result 1 (overflow) and N2 is result 0 of the same USUBO node.
    if (N0.getOpcode() == ISD::USUBO && N0.getResNo() == 1 &&
        N2.getNode() == N0.getNode() && N2.getResNo() == 0 &&
        N1.getOpcode() == ISD::SUB && N2.getOperand(0) == N1.getOperand(1) &&
        N2.getOperand(1) == N1.getOperand(0) &&
        (!LegalOperations || TLI.isOperationLegal(ISD::ABDU, VT)))
      return DAG.getNode(ISD::ABDU, DL, VT, N0.getOperand(0), N0.getOperand(1));

    // select usubo(x, y).overflow, (usubo x, y), (sub y, x) -> neg (abdu x, y)
    // Same as above with the select arms exchanged.
    if (N0.getOpcode() == ISD::USUBO && N0.getResNo() == 1 &&
        N1.getNode() == N0.getNode() && N1.getResNo() == 0 &&
        N2.getOpcode() == ISD::SUB && N2.getOperand(0) == N1.getOperand(1) &&
        N2.getOperand(1) == N1.getOperand(0) &&
        (!LegalOperations || TLI.isOperationLegal(ISD::ABDU, VT)))
      return DAG.getNegative(
          DAG.getNode(ISD::ABDU, DL, VT, N0.getOperand(0), N0.getOperand(1)),
          DL, VT);
  }

  // Fold selects based on a setcc into other things, such as min/max/abs.
  if (N0.getOpcode() == ISD::SETCC) {
    SDValue Cond0 = N0.getOperand(0), Cond1 = N0.getOperand(1);
    ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();

    // select (fcmp lt x, y), x, y -> fminnum x, y
    // select (fcmp gt x, y), x, y -> fmaxnum x, y
    //
    // This is OK if we don't care what happens if either operand is a NaN.
    if (N0.hasOneUse() && isLegalToCombineMinNumMaxNum(DAG, N1, N2, Flags, TLI))
      if (SDValue FMinMax =
              combineMinNumMaxNum(DL, VT, Cond0, Cond1, N1, N2, CC))
        return FMinMax;

    // Use 'unsigned add with overflow' to optimize an unsigned saturating add.
    // This is conservatively limited to pre-legal-operations to give targets
    // a chance to reverse the transform if they want to do that. Also, it is
    // unlikely that the pattern would be formed late, so it's probably not
    // worth going through the other checks.
    if (!LegalOperations && TLI.isOperationLegalOrCustom(ISD::UADDO, VT) &&
        CC == ISD::SETUGT && N0.hasOneUse() && isAllOnesConstant(N1) &&
        N2.getOpcode() == ISD::ADD && Cond0 == N2.getOperand(0)) {
      auto *C = dyn_cast<ConstantSDNode>(N2.getOperand(1));
      auto *NotC = dyn_cast<ConstantSDNode>(Cond1);
      if (C && NotC && C->getAPIntValue() == ~NotC->getAPIntValue()) {
        // select (setcc Cond0, ~C, ugt), -1, (add Cond0, C) -->
        // uaddo Cond0, C; select uaddo.1, -1, uaddo.0
        //
        // The IR equivalent of this transform would have this form:
        //   %a = add %x, C
        //   %c = icmp ugt %x, ~C
        //   %r = select %c, -1, %a
        //   =>
        //   %u = call {iN,i1} llvm.uadd.with.overflow(%x, C)
        //   %u0 = extractvalue %u, 0
        //   %u1 = extractvalue %u, 1
        //   %r = select %u1, -1, %u0
        SDVTList VTs = DAG.getVTList(VT, VT0);
        SDValue UAO = DAG.getNode(ISD::UADDO, DL, VTs, Cond0, N2.getOperand(1));
        return DAG.getSelect(DL, VT, UAO.getValue(1), N1, UAO.getValue(0));
      }
    }

    // Merge the setcc into the select as a SELECT_CC node when the target
    // accepts that operation for VT.
    if (TLI.isOperationLegal(ISD::SELECT_CC, VT) ||
        (!LegalOperations &&
         TLI.isOperationLegalOrCustom(ISD::SELECT_CC, VT))) {
      // Any flags available in a select/setcc fold will be on the setcc as they
      // migrated from fcmp
      Flags = N0->getFlags();
      SDValue SelectNode = DAG.getNode(ISD::SELECT_CC, DL, VT, Cond0, Cond1, N1,
                                       N2, N0.getOperand(2));
      SelectNode->setFlags(Flags);
      return SelectNode;
    }

    if (SDValue ABD = foldSelectToABD(Cond0, Cond1, N1, N2, CC, DL))
      return ABD;

    if (SDValue NewSel = SimplifySelect(DL, N0, N1, N2))
      return NewSel;

    // (select (ugt x, C), (add x, ~C), x) -> (umin (add x, ~C), x)
    // (select (ult x, C), x, (add x, -C)) -> (umin x, (add x, -C))
    APInt C;
    if (sd_match(Cond1, m_ConstInt(C)) && hasUMin(VT)) {
      if (CC == ISD::SETUGT && Cond0 == N2 &&
          sd_match(N1, m_Add(m_Specific(N2), m_SpecificInt(~C)))) {
        // The resulting code relies on an unsigned wrap in ADD.
        // Recreating ADD to drop possible nuw/nsw flags.
        SDValue AddC = DAG.getConstant(~C, DL, VT);
        SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N2, AddC);
        return DAG.getNode(ISD::UMIN, DL, VT, Add, N2);
      }
      if (CC == ISD::SETULT && Cond0 == N1 &&
          sd_match(N2, m_Add(m_Specific(N1), m_SpecificInt(-C)))) {
        // Ditto.
        SDValue AddC = DAG.getConstant(-C, DL, VT);
        SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N1, AddC);
        return DAG.getNode(ISD::UMIN, DL, VT, N1, Add);
      }
    }
  }

  // The binop fold is only attempted for non-vector result types.
  if (!VT.isVector())
    if (SDValue BinOp = foldSelectOfBinops(N))
      return BinOp;

  if (SDValue R = combineSelectAsExtAnd(N0, N1, N2, DL, DAG))
    return R;

  return SDValue();
}
12390
12391 // This function assumes all the vselect's arguments are CONCAT_VECTOR
12392 // nodes and that the condition is a BV of ConstantSDNodes (or undefs).
ConvertSelectToConcatVector(SDNode * N,SelectionDAG & DAG)12393 static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) {
12394 SDLoc DL(N);
12395 SDValue Cond = N->getOperand(0);
12396 SDValue LHS = N->getOperand(1);
12397 SDValue RHS = N->getOperand(2);
12398 EVT VT = N->getValueType(0);
12399 int NumElems = VT.getVectorNumElements();
12400 assert(LHS.getOpcode() == ISD::CONCAT_VECTORS &&
12401 RHS.getOpcode() == ISD::CONCAT_VECTORS &&
12402 Cond.getOpcode() == ISD::BUILD_VECTOR);
12403
12404 // CONCAT_VECTOR can take an arbitrary number of arguments. We only care about
12405 // binary ones here.
12406 if (LHS->getNumOperands() != 2 || RHS->getNumOperands() != 2)
12407 return SDValue();
12408
12409 // We're sure we have an even number of elements due to the
12410 // concat_vectors we have as arguments to vselect.
12411 // Skip BV elements until we find one that's not an UNDEF
12412 // After we find an UNDEF element, keep looping until we get to half the
12413 // length of the BV and see if all the non-undef nodes are the same.
12414 ConstantSDNode *BottomHalf = nullptr;
12415 for (int i = 0; i < NumElems / 2; ++i) {
12416 if (Cond->getOperand(i)->isUndef())
12417 continue;
12418
12419 if (BottomHalf == nullptr)
12420 BottomHalf = cast<ConstantSDNode>(Cond.getOperand(i));
12421 else if (Cond->getOperand(i).getNode() != BottomHalf)
12422 return SDValue();
12423 }
12424
12425 // Do the same for the second half of the BuildVector
12426 ConstantSDNode *TopHalf = nullptr;
12427 for (int i = NumElems / 2; i < NumElems; ++i) {
12428 if (Cond->getOperand(i)->isUndef())
12429 continue;
12430
12431 if (TopHalf == nullptr)
12432 TopHalf = cast<ConstantSDNode>(Cond.getOperand(i));
12433 else if (Cond->getOperand(i).getNode() != TopHalf)
12434 return SDValue();
12435 }
12436
12437 assert(TopHalf && BottomHalf &&
12438 "One half of the selector was all UNDEFs and the other was all the "
12439 "same value. This should have been addressed before this function.");
12440 return DAG.getNode(
12441 ISD::CONCAT_VECTORS, DL, VT,
12442 BottomHalf->isZero() ? RHS->getOperand(0) : LHS->getOperand(0),
12443 TopHalf->isZero() ? RHS->getOperand(1) : LHS->getOperand(1));
12444 }
12445
refineUniformBase(SDValue & BasePtr,SDValue & Index,bool IndexIsScaled,SelectionDAG & DAG,const SDLoc & DL)12446 bool refineUniformBase(SDValue &BasePtr, SDValue &Index, bool IndexIsScaled,
12447 SelectionDAG &DAG, const SDLoc &DL) {
12448
12449 // Only perform the transformation when existing operands can be reused.
12450 if (IndexIsScaled)
12451 return false;
12452
12453 if (!isNullConstant(BasePtr) && !Index.hasOneUse())
12454 return false;
12455
12456 EVT VT = BasePtr.getValueType();
12457
12458 if (SDValue SplatVal = DAG.getSplatValue(Index);
12459 SplatVal && !isNullConstant(SplatVal) &&
12460 SplatVal.getValueType() == VT) {
12461 BasePtr = DAG.getNode(ISD::ADD, DL, VT, BasePtr, SplatVal);
12462 Index = DAG.getSplat(Index.getValueType(), DL, DAG.getConstant(0, DL, VT));
12463 return true;
12464 }
12465
12466 if (Index.getOpcode() != ISD::ADD)
12467 return false;
12468
12469 if (SDValue SplatVal = DAG.getSplatValue(Index.getOperand(0));
12470 SplatVal && SplatVal.getValueType() == VT) {
12471 BasePtr = DAG.getNode(ISD::ADD, DL, VT, BasePtr, SplatVal);
12472 Index = Index.getOperand(1);
12473 return true;
12474 }
12475 if (SDValue SplatVal = DAG.getSplatValue(Index.getOperand(1));
12476 SplatVal && SplatVal.getValueType() == VT) {
12477 BasePtr = DAG.getNode(ISD::ADD, DL, VT, BasePtr, SplatVal);
12478 Index = Index.getOperand(0);
12479 return true;
12480 }
12481 return false;
12482 }
12483
12484 // Fold sext/zext of index into index type.
refineIndexType(SDValue & Index,ISD::MemIndexType & IndexType,EVT DataVT,SelectionDAG & DAG)12485 bool refineIndexType(SDValue &Index, ISD::MemIndexType &IndexType, EVT DataVT,
12486 SelectionDAG &DAG) {
12487 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
12488
12489 // It's always safe to look through zero extends.
12490 if (Index.getOpcode() == ISD::ZERO_EXTEND) {
12491 if (TLI.shouldRemoveExtendFromGSIndex(Index, DataVT)) {
12492 IndexType = ISD::UNSIGNED_SCALED;
12493 Index = Index.getOperand(0);
12494 return true;
12495 }
12496 if (ISD::isIndexTypeSigned(IndexType)) {
12497 IndexType = ISD::UNSIGNED_SCALED;
12498 return true;
12499 }
12500 }
12501
12502 // It's only safe to look through sign extends when Index is signed.
12503 if (Index.getOpcode() == ISD::SIGN_EXTEND &&
12504 ISD::isIndexTypeSigned(IndexType) &&
12505 TLI.shouldRemoveExtendFromGSIndex(Index, DataVT)) {
12506 Index = Index.getOperand(0);
12507 return true;
12508 }
12509
12510 return false;
12511 }
12512
visitVPSCATTER(SDNode * N)12513 SDValue DAGCombiner::visitVPSCATTER(SDNode *N) {
12514 VPScatterSDNode *MSC = cast<VPScatterSDNode>(N);
12515 SDValue Mask = MSC->getMask();
12516 SDValue Chain = MSC->getChain();
12517 SDValue Index = MSC->getIndex();
12518 SDValue Scale = MSC->getScale();
12519 SDValue StoreVal = MSC->getValue();
12520 SDValue BasePtr = MSC->getBasePtr();
12521 SDValue VL = MSC->getVectorLength();
12522 ISD::MemIndexType IndexType = MSC->getIndexType();
12523 SDLoc DL(N);
12524
12525 // Zap scatters with a zero mask.
12526 if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
12527 return Chain;
12528
12529 if (refineUniformBase(BasePtr, Index, MSC->isIndexScaled(), DAG, DL)) {
12530 SDValue Ops[] = {Chain, StoreVal, BasePtr, Index, Scale, Mask, VL};
12531 return DAG.getScatterVP(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
12532 DL, Ops, MSC->getMemOperand(), IndexType);
12533 }
12534
12535 if (refineIndexType(Index, IndexType, StoreVal.getValueType(), DAG)) {
12536 SDValue Ops[] = {Chain, StoreVal, BasePtr, Index, Scale, Mask, VL};
12537 return DAG.getScatterVP(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
12538 DL, Ops, MSC->getMemOperand(), IndexType);
12539 }
12540
12541 return SDValue();
12542 }
12543
visitMSCATTER(SDNode * N)12544 SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
12545 MaskedScatterSDNode *MSC = cast<MaskedScatterSDNode>(N);
12546 SDValue Mask = MSC->getMask();
12547 SDValue Chain = MSC->getChain();
12548 SDValue Index = MSC->getIndex();
12549 SDValue Scale = MSC->getScale();
12550 SDValue StoreVal = MSC->getValue();
12551 SDValue BasePtr = MSC->getBasePtr();
12552 ISD::MemIndexType IndexType = MSC->getIndexType();
12553 SDLoc DL(N);
12554
12555 // Zap scatters with a zero mask.
12556 if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
12557 return Chain;
12558
12559 if (refineUniformBase(BasePtr, Index, MSC->isIndexScaled(), DAG, DL)) {
12560 SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
12561 return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
12562 DL, Ops, MSC->getMemOperand(), IndexType,
12563 MSC->isTruncatingStore());
12564 }
12565
12566 if (refineIndexType(Index, IndexType, StoreVal.getValueType(), DAG)) {
12567 SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
12568 return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
12569 DL, Ops, MSC->getMemOperand(), IndexType,
12570 MSC->isTruncatingStore());
12571 }
12572
12573 return SDValue();
12574 }
12575
/// Try a sequence of folds on the masked store \p N.
///
/// \returns the replacement value (the chain, a plain store or a new masked
/// store), SDValue(N, 0) when a fold handled \p N without producing a
/// replacement value, or an empty SDValue when nothing applied.
SDValue DAGCombiner::visitMSTORE(SDNode *N) {
  MaskedStoreSDNode *MST = cast<MaskedStoreSDNode>(N);
  SDValue Mask = MST->getMask();
  SDValue Chain = MST->getChain();
  SDValue Value = MST->getValue();
  SDValue Ptr = MST->getBasePtr();

  // Zap masked stores with a zero mask.
  if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
    return Chain;

  // Remove the preceding masked store on the chain (MST1) if base pointers and
  // masks are equal. Both stores must be unindexed and simple, the base must
  // not be undef, and MST1 must not store more bytes than this store. The
  // masks match when they are identical with equal store sizes, or when this
  // store's mask is all ones.
  if (MaskedStoreSDNode *MST1 = dyn_cast<MaskedStoreSDNode>(Chain)) {
    if (MST->isUnindexed() && MST->isSimple() && MST1->isUnindexed() &&
        MST1->isSimple() && MST1->getBasePtr() == Ptr &&
        !MST->getBasePtr().isUndef() &&
        ((Mask == MST1->getMask() && MST->getMemoryVT().getStoreSize() ==
                                         MST1->getMemoryVT().getStoreSize()) ||
         ISD::isConstantSplatVectorAllOnes(Mask.getNode())) &&
        TypeSize::isKnownLE(MST1->getMemoryVT().getStoreSize(),
                            MST->getMemoryVT().getStoreSize())) {
      // Replace MST1 with its own input chain, then revisit N unless it was
      // deleted in the process.
      CombineTo(MST1, MST1->getChain());
      if (N->getOpcode() != ISD::DELETED_NODE)
        AddToWorklist(N);
      return SDValue(N, 0);
    }
  }

  // If this is a masked store with an all ones mask, we can use an unmasked
  // store.
  // FIXME: Can we do this for indexed, compressing, or truncating stores?
  if (ISD::isConstantSplatVectorAllOnes(Mask.getNode()) && MST->isUnindexed() &&
      !MST->isCompressingStore() && !MST->isTruncatingStore())
    return DAG.getStore(MST->getChain(), SDLoc(N), MST->getValue(),
                        MST->getBasePtr(), MST->getPointerInfo(),
                        MST->getBaseAlign(), MST->getMemOperand()->getFlags(),
                        MST->getAAInfo());

  // Try transforming N to an indexed store.
  if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
    return SDValue(N, 0);

  // For a truncating store of a non-opaque integer value only the low
  // memory-type bits of each element are demanded.
  if (MST->isTruncatingStore() && MST->isUnindexed() &&
      Value.getValueType().isInteger() &&
      (!isa<ConstantSDNode>(Value) ||
       !cast<ConstantSDNode>(Value)->isOpaque())) {
    APInt TruncDemandedBits =
        APInt::getLowBitsSet(Value.getScalarValueSizeInBits(),
                             MST->getMemoryVT().getScalarSizeInBits());

    // See if we can simplify the operation with
    // SimplifyDemandedBits, which only works if the value has a single use.
    if (SimplifyDemandedBits(Value, TruncDemandedBits)) {
      // Re-visit the store if anything changed and the store hasn't been merged
      // with another node (N is deleted). SimplifyDemandedBits will add Value's
      // node back to the worklist if necessary, but we also need to re-visit
      // the Store node itself.
      if (N->getOpcode() != ISD::DELETED_NODE)
        AddToWorklist(N);
      return SDValue(N, 0);
    }
  }

  // If this is a TRUNC followed by a masked store, fold this into a masked
  // truncating store.  We can do this even if this is already a masked
  // truncstore.
  // TODO: Try to combine to a masked compress store if possible.
  if ((Value.getOpcode() == ISD::TRUNCATE) && Value->hasOneUse() &&
      MST->isUnindexed() && !MST->isCompressingStore() &&
      TLI.canCombineTruncStore(Value.getOperand(0).getValueType(),
                               MST->getMemoryVT(), LegalOperations)) {
    // The mask is promoted for the type of the untruncated value, which is
    // what the new store takes as its stored operand.
    auto Mask = TLI.promoteTargetBoolean(DAG, MST->getMask(),
                                         Value.getOperand(0).getValueType());
    return DAG.getMaskedStore(Chain, SDLoc(N), Value.getOperand(0), Ptr,
                              MST->getOffset(), Mask, MST->getMemoryVT(),
                              MST->getMemOperand(), MST->getAddressingMode(),
                              /*IsTruncating=*/true);
  }

  return SDValue();
}
12656
visitVP_STRIDED_STORE(SDNode * N)12657 SDValue DAGCombiner::visitVP_STRIDED_STORE(SDNode *N) {
12658 auto *SST = cast<VPStridedStoreSDNode>(N);
12659 EVT EltVT = SST->getValue().getValueType().getVectorElementType();
12660 // Combine strided stores with unit-stride to a regular VP store.
12661 if (auto *CStride = dyn_cast<ConstantSDNode>(SST->getStride());
12662 CStride && CStride->getZExtValue() == EltVT.getStoreSize()) {
12663 return DAG.getStoreVP(SST->getChain(), SDLoc(N), SST->getValue(),
12664 SST->getBasePtr(), SST->getOffset(), SST->getMask(),
12665 SST->getVectorLength(), SST->getMemoryVT(),
12666 SST->getMemOperand(), SST->getAddressingMode(),
12667 SST->isTruncatingStore(), SST->isCompressingStore());
12668 }
12669 return SDValue();
12670 }
12671
visitVECTOR_COMPRESS(SDNode * N)12672 SDValue DAGCombiner::visitVECTOR_COMPRESS(SDNode *N) {
12673 SDLoc DL(N);
12674 SDValue Vec = N->getOperand(0);
12675 SDValue Mask = N->getOperand(1);
12676 SDValue Passthru = N->getOperand(2);
12677 EVT VecVT = Vec.getValueType();
12678
12679 bool HasPassthru = !Passthru.isUndef();
12680
12681 APInt SplatVal;
12682 if (ISD::isConstantSplatVector(Mask.getNode(), SplatVal))
12683 return TLI.isConstTrueVal(Mask) ? Vec : Passthru;
12684
12685 if (Vec.isUndef() || Mask.isUndef())
12686 return Passthru;
12687
12688 // No need for potentially expensive compress if the mask is constant.
12689 if (ISD::isBuildVectorOfConstantSDNodes(Mask.getNode())) {
12690 SmallVector<SDValue, 16> Ops;
12691 EVT ScalarVT = VecVT.getVectorElementType();
12692 unsigned NumSelected = 0;
12693 unsigned NumElmts = VecVT.getVectorNumElements();
12694 for (unsigned I = 0; I < NumElmts; ++I) {
12695 SDValue MaskI = Mask.getOperand(I);
12696 // We treat undef mask entries as "false".
12697 if (MaskI.isUndef())
12698 continue;
12699
12700 if (TLI.isConstTrueVal(MaskI)) {
12701 SDValue VecI = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT, Vec,
12702 DAG.getVectorIdxConstant(I, DL));
12703 Ops.push_back(VecI);
12704 NumSelected++;
12705 }
12706 }
12707 for (unsigned Rest = NumSelected; Rest < NumElmts; ++Rest) {
12708 SDValue Val =
12709 HasPassthru
12710 ? DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT, Passthru,
12711 DAG.getVectorIdxConstant(Rest, DL))
12712 : DAG.getUNDEF(ScalarVT);
12713 Ops.push_back(Val);
12714 }
12715 return DAG.getBuildVector(VecVT, DL, Ops);
12716 }
12717
12718 return SDValue();
12719 }
12720
visitVPGATHER(SDNode * N)12721 SDValue DAGCombiner::visitVPGATHER(SDNode *N) {
12722 VPGatherSDNode *MGT = cast<VPGatherSDNode>(N);
12723 SDValue Mask = MGT->getMask();
12724 SDValue Chain = MGT->getChain();
12725 SDValue Index = MGT->getIndex();
12726 SDValue Scale = MGT->getScale();
12727 SDValue BasePtr = MGT->getBasePtr();
12728 SDValue VL = MGT->getVectorLength();
12729 ISD::MemIndexType IndexType = MGT->getIndexType();
12730 SDLoc DL(N);
12731
12732 if (refineUniformBase(BasePtr, Index, MGT->isIndexScaled(), DAG, DL)) {
12733 SDValue Ops[] = {Chain, BasePtr, Index, Scale, Mask, VL};
12734 return DAG.getGatherVP(
12735 DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
12736 Ops, MGT->getMemOperand(), IndexType);
12737 }
12738
12739 if (refineIndexType(Index, IndexType, N->getValueType(0), DAG)) {
12740 SDValue Ops[] = {Chain, BasePtr, Index, Scale, Mask, VL};
12741 return DAG.getGatherVP(
12742 DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
12743 Ops, MGT->getMemOperand(), IndexType);
12744 }
12745
12746 return SDValue();
12747 }
12748
visitMGATHER(SDNode * N)12749 SDValue DAGCombiner::visitMGATHER(SDNode *N) {
12750 MaskedGatherSDNode *MGT = cast<MaskedGatherSDNode>(N);
12751 SDValue Mask = MGT->getMask();
12752 SDValue Chain = MGT->getChain();
12753 SDValue Index = MGT->getIndex();
12754 SDValue Scale = MGT->getScale();
12755 SDValue PassThru = MGT->getPassThru();
12756 SDValue BasePtr = MGT->getBasePtr();
12757 ISD::MemIndexType IndexType = MGT->getIndexType();
12758 SDLoc DL(N);
12759
12760 // Zap gathers with a zero mask.
12761 if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
12762 return CombineTo(N, PassThru, MGT->getChain());
12763
12764 if (refineUniformBase(BasePtr, Index, MGT->isIndexScaled(), DAG, DL)) {
12765 SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
12766 return DAG.getMaskedGather(
12767 DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
12768 Ops, MGT->getMemOperand(), IndexType, MGT->getExtensionType());
12769 }
12770
12771 if (refineIndexType(Index, IndexType, N->getValueType(0), DAG)) {
12772 SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
12773 return DAG.getMaskedGather(
12774 DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
12775 Ops, MGT->getMemOperand(), IndexType, MGT->getExtensionType());
12776 }
12777
12778 return SDValue();
12779 }
12780
visitMLOAD(SDNode * N)12781 SDValue DAGCombiner::visitMLOAD(SDNode *N) {
12782 MaskedLoadSDNode *MLD = cast<MaskedLoadSDNode>(N);
12783 SDValue Mask = MLD->getMask();
12784
12785 // Zap masked loads with a zero mask.
12786 if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
12787 return CombineTo(N, MLD->getPassThru(), MLD->getChain());
12788
12789 // If this is a masked load with an all ones mask, we can use a unmasked load.
12790 // FIXME: Can we do this for indexed, expanding, or extending loads?
12791 if (ISD::isConstantSplatVectorAllOnes(Mask.getNode()) && MLD->isUnindexed() &&
12792 !MLD->isExpandingLoad() && MLD->getExtensionType() == ISD::NON_EXTLOAD) {
12793 SDValue NewLd = DAG.getLoad(
12794 N->getValueType(0), SDLoc(N), MLD->getChain(), MLD->getBasePtr(),
12795 MLD->getPointerInfo(), MLD->getBaseAlign(),
12796 MLD->getMemOperand()->getFlags(), MLD->getAAInfo(), MLD->getRanges());
12797 return CombineTo(N, NewLd, NewLd.getValue(1));
12798 }
12799
12800 // Try transforming N to an indexed load.
12801 if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
12802 return SDValue(N, 0);
12803
12804 return SDValue();
12805 }
12806
visitMHISTOGRAM(SDNode * N)12807 SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
12808 MaskedHistogramSDNode *HG = cast<MaskedHistogramSDNode>(N);
12809 SDValue Chain = HG->getChain();
12810 SDValue Inc = HG->getInc();
12811 SDValue Mask = HG->getMask();
12812 SDValue BasePtr = HG->getBasePtr();
12813 SDValue Index = HG->getIndex();
12814 SDLoc DL(HG);
12815
12816 EVT MemVT = HG->getMemoryVT();
12817 MachineMemOperand *MMO = HG->getMemOperand();
12818 ISD::MemIndexType IndexType = HG->getIndexType();
12819
12820 if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
12821 return Chain;
12822
12823 SDValue Ops[] = {Chain, Inc, Mask, BasePtr, Index,
12824 HG->getScale(), HG->getIntID()};
12825 if (refineUniformBase(BasePtr, Index, HG->isIndexScaled(), DAG, DL))
12826 return DAG.getMaskedHistogram(DAG.getVTList(MVT::Other), MemVT, DL, Ops,
12827 MMO, IndexType);
12828
12829 EVT DataVT = Index.getValueType();
12830 if (refineIndexType(Index, IndexType, DataVT, DAG))
12831 return DAG.getMaskedHistogram(DAG.getVTList(MVT::Other), MemVT, DL, Ops,
12832 MMO, IndexType);
12833 return SDValue();
12834 }
12835
visitPARTIAL_REDUCE_MLA(SDNode * N)12836 SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
12837 if (SDValue Res = foldPartialReduceMLAMulOp(N))
12838 return Res;
12839 if (SDValue Res = foldPartialReduceAdd(N))
12840 return Res;
12841 return SDValue();
12842 }
12843
12844 // partial_reduce_*mla(acc, mul(ext(a), ext(b)), splat(1))
12845 // -> partial_reduce_*mla(acc, a, b)
12846 //
12847 // partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
12848 // -> partial_reduce_*mla(acc, x, C)
foldPartialReduceMLAMulOp(SDNode * N)12849 SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
12850 SDLoc DL(N);
12851 auto *Context = DAG.getContext();
12852 SDValue Acc = N->getOperand(0);
12853 SDValue Op1 = N->getOperand(1);
12854 SDValue Op2 = N->getOperand(2);
12855
12856 APInt C;
12857 if (Op1->getOpcode() != ISD::MUL ||
12858 !ISD::isConstantSplatVector(Op2.getNode(), C) || !C.isOne())
12859 return SDValue();
12860
12861 SDValue LHS = Op1->getOperand(0);
12862 SDValue RHS = Op1->getOperand(1);
12863 unsigned LHSOpcode = LHS->getOpcode();
12864 if (!ISD::isExtOpcode(LHSOpcode))
12865 return SDValue();
12866
12867 SDValue LHSExtOp = LHS->getOperand(0);
12868 EVT LHSExtOpVT = LHSExtOp.getValueType();
12869
12870 // partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
12871 // -> partial_reduce_*mla(acc, x, C)
12872 if (ISD::isConstantSplatVector(RHS.getNode(), C)) {
12873 // TODO: Make use of partial_reduce_sumla here
12874 APInt CTrunc = C.trunc(LHSExtOpVT.getScalarSizeInBits());
12875 unsigned LHSBits = LHS.getValueType().getScalarSizeInBits();
12876 if ((LHSOpcode != ISD::ZERO_EXTEND || CTrunc.zext(LHSBits) != C) &&
12877 (LHSOpcode != ISD::SIGN_EXTEND || CTrunc.sext(LHSBits) != C))
12878 return SDValue();
12879
12880 unsigned NewOpcode = LHSOpcode == ISD::SIGN_EXTEND
12881 ? ISD::PARTIAL_REDUCE_SMLA
12882 : ISD::PARTIAL_REDUCE_UMLA;
12883
12884 // Only perform these combines if the target supports folding
12885 // the extends into the operation.
12886 if (!TLI.isPartialReduceMLALegalOrCustom(
12887 NewOpcode, TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
12888 TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
12889 return SDValue();
12890
12891 return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
12892 DAG.getConstant(CTrunc, DL, LHSExtOpVT));
12893 }
12894
12895 unsigned RHSOpcode = RHS->getOpcode();
12896 if (!ISD::isExtOpcode(RHSOpcode))
12897 return SDValue();
12898
12899 SDValue RHSExtOp = RHS->getOperand(0);
12900 if (LHSExtOpVT != RHSExtOp.getValueType())
12901 return SDValue();
12902
12903 unsigned NewOpc;
12904 if (LHSOpcode == ISD::SIGN_EXTEND && RHSOpcode == ISD::SIGN_EXTEND)
12905 NewOpc = ISD::PARTIAL_REDUCE_SMLA;
12906 else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::ZERO_EXTEND)
12907 NewOpc = ISD::PARTIAL_REDUCE_UMLA;
12908 else if (LHSOpcode == ISD::SIGN_EXTEND && RHSOpcode == ISD::ZERO_EXTEND)
12909 NewOpc = ISD::PARTIAL_REDUCE_SUMLA;
12910 else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::SIGN_EXTEND) {
12911 NewOpc = ISD::PARTIAL_REDUCE_SUMLA;
12912 std::swap(LHSExtOp, RHSExtOp);
12913 } else
12914 return SDValue();
12915 // For a 2-stage extend the signedness of both of the extends must match
12916 // If the mul has the same type, there is no outer extend, and thus we
12917 // can simply use the inner extends to pick the result node.
12918 // TODO: extend to handle nonneg zext as sext
12919 EVT AccElemVT = Acc.getValueType().getVectorElementType();
12920 if (Op1.getValueType().getVectorElementType() != AccElemVT &&
12921 NewOpc != N->getOpcode())
12922 return SDValue();
12923
12924 // Only perform these combines if the target supports folding
12925 // the extends into the operation.
12926 if (!TLI.isPartialReduceMLALegalOrCustom(
12927 NewOpc, TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
12928 TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
12929 return SDValue();
12930
12931 return DAG.getNode(NewOpc, DL, N->getValueType(0), Acc, LHSExtOp, RHSExtOp);
12932 }
12933
12934 // partial.reduce.umla(acc, zext(op), splat(1))
12935 // -> partial.reduce.umla(acc, op, splat(trunc(1)))
12936 // partial.reduce.smla(acc, sext(op), splat(1))
12937 // -> partial.reduce.smla(acc, op, splat(trunc(1)))
12938 // partial.reduce.sumla(acc, sext(op), splat(1))
12939 // -> partial.reduce.smla(acc, op, splat(trunc(1)))
foldPartialReduceAdd(SDNode * N)12940 SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
12941 SDLoc DL(N);
12942 SDValue Acc = N->getOperand(0);
12943 SDValue Op1 = N->getOperand(1);
12944 SDValue Op2 = N->getOperand(2);
12945
12946 APInt ConstantOne;
12947 if (!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
12948 !ConstantOne.isOne())
12949 return SDValue();
12950
12951 unsigned Op1Opcode = Op1.getOpcode();
12952 if (!ISD::isExtOpcode(Op1Opcode))
12953 return SDValue();
12954
12955 bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
12956 bool NodeIsSigned = N->getOpcode() != ISD::PARTIAL_REDUCE_UMLA;
12957 EVT AccElemVT = Acc.getValueType().getVectorElementType();
12958 if (Op1IsSigned != NodeIsSigned &&
12959 Op1.getValueType().getVectorElementType() != AccElemVT)
12960 return SDValue();
12961
12962 unsigned NewOpcode =
12963 Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
12964
12965 SDValue UnextOp1 = Op1.getOperand(0);
12966 EVT UnextOp1VT = UnextOp1.getValueType();
12967 auto *Context = DAG.getContext();
12968 if (!TLI.isPartialReduceMLALegalOrCustom(
12969 NewOpcode, TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
12970 TLI.getTypeToTransformTo(*Context, UnextOp1VT)))
12971 return SDValue();
12972
12973 return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1,
12974 DAG.getConstant(1, DL, UnextOp1VT));
12975 }
12976
visitVP_STRIDED_LOAD(SDNode * N)12977 SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {
12978 auto *SLD = cast<VPStridedLoadSDNode>(N);
12979 EVT EltVT = SLD->getValueType(0).getVectorElementType();
12980 // Combine strided loads with unit-stride to a regular VP load.
12981 if (auto *CStride = dyn_cast<ConstantSDNode>(SLD->getStride());
12982 CStride && CStride->getZExtValue() == EltVT.getStoreSize()) {
12983 SDValue NewLd = DAG.getLoadVP(
12984 SLD->getAddressingMode(), SLD->getExtensionType(), SLD->getValueType(0),
12985 SDLoc(N), SLD->getChain(), SLD->getBasePtr(), SLD->getOffset(),
12986 SLD->getMask(), SLD->getVectorLength(), SLD->getMemoryVT(),
12987 SLD->getMemOperand(), SLD->isExpandingLoad());
12988 return CombineTo(N, NewLd, NewLd.getValue(1));
12989 }
12990 return SDValue();
12991 }
12992
12993 /// A vector select of 2 constant vectors can be simplified to math/logic to
12994 /// avoid a variable select instruction and possibly avoid constant loads.
foldVSelectOfConstants(SDNode * N)12995 SDValue DAGCombiner::foldVSelectOfConstants(SDNode *N) {
12996 SDValue Cond = N->getOperand(0);
12997 SDValue N1 = N->getOperand(1);
12998 SDValue N2 = N->getOperand(2);
12999 EVT VT = N->getValueType(0);
13000 if (!Cond.hasOneUse() || Cond.getScalarValueSizeInBits() != 1 ||
13001 !shouldConvertSelectOfConstantsToMath(Cond, VT, TLI) ||
13002 !ISD::isBuildVectorOfConstantSDNodes(N1.getNode()) ||
13003 !ISD::isBuildVectorOfConstantSDNodes(N2.getNode()))
13004 return SDValue();
13005
13006 // Check if we can use the condition value to increment/decrement a single
13007 // constant value. This simplifies a select to an add and removes a constant
13008 // load/materialization from the general case.
13009 bool AllAddOne = true;
13010 bool AllSubOne = true;
13011 unsigned Elts = VT.getVectorNumElements();
13012 for (unsigned i = 0; i != Elts; ++i) {
13013 SDValue N1Elt = N1.getOperand(i);
13014 SDValue N2Elt = N2.getOperand(i);
13015 if (N1Elt.isUndef())
13016 continue;
13017 // N2 should not contain undef values since it will be reused in the fold.
13018 if (N2Elt.isUndef() || N1Elt.getValueType() != N2Elt.getValueType()) {
13019 AllAddOne = false;
13020 AllSubOne = false;
13021 break;
13022 }
13023
13024 const APInt &C1 = N1Elt->getAsAPIntVal();
13025 const APInt &C2 = N2Elt->getAsAPIntVal();
13026 if (C1 != C2 + 1)
13027 AllAddOne = false;
13028 if (C1 != C2 - 1)
13029 AllSubOne = false;
13030 }
13031
13032 // Further simplifications for the extra-special cases where the constants are
13033 // all 0 or all -1 should be implemented as folds of these patterns.
13034 SDLoc DL(N);
13035 if (AllAddOne || AllSubOne) {
13036 // vselect <N x i1> Cond, C+1, C --> add (zext Cond), C
13037 // vselect <N x i1> Cond, C-1, C --> add (sext Cond), C
13038 auto ExtendOpcode = AllAddOne ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND;
13039 SDValue ExtendedCond = DAG.getNode(ExtendOpcode, DL, VT, Cond);
13040 return DAG.getNode(ISD::ADD, DL, VT, ExtendedCond, N2);
13041 }
13042
13043 // select Cond, Pow2C, 0 --> (zext Cond) << log2(Pow2C)
13044 APInt Pow2C;
13045 if (ISD::isConstantSplatVector(N1.getNode(), Pow2C) && Pow2C.isPowerOf2() &&
13046 isNullOrNullSplat(N2)) {
13047 SDValue ZextCond = DAG.getZExtOrTrunc(Cond, DL, VT);
13048 SDValue ShAmtC = DAG.getConstant(Pow2C.exactLogBase2(), DL, VT);
13049 return DAG.getNode(ISD::SHL, DL, VT, ZextCond, ShAmtC);
13050 }
13051
13052 if (SDValue V = foldSelectOfConstantsUsingSra(N, DL, DAG))
13053 return V;
13054
13055 // The general case for select-of-constants:
13056 // vselect <N x i1> Cond, C1, C2 --> xor (and (sext Cond), (C1^C2)), C2
13057 // ...but that only makes sense if a vselect is slower than 2 logic ops, so
13058 // leave that to a machine-specific pass.
13059 return SDValue();
13060 }
13061
visitVP_SELECT(SDNode * N)13062 SDValue DAGCombiner::visitVP_SELECT(SDNode *N) {
13063 SDValue N0 = N->getOperand(0);
13064 SDValue N1 = N->getOperand(1);
13065 SDValue N2 = N->getOperand(2);
13066 SDLoc DL(N);
13067
13068 if (SDValue V = DAG.simplifySelect(N0, N1, N2))
13069 return V;
13070
13071 if (SDValue V = foldBoolSelectToLogic<VPMatchContext>(N, DL, DAG))
13072 return V;
13073
13074 return SDValue();
13075 }
13076
combineVSelectWithAllOnesOrZeros(SDValue Cond,SDValue TVal,SDValue FVal,const TargetLowering & TLI,SelectionDAG & DAG,const SDLoc & DL)13077 static SDValue combineVSelectWithAllOnesOrZeros(SDValue Cond, SDValue TVal,
13078 SDValue FVal,
13079 const TargetLowering &TLI,
13080 SelectionDAG &DAG,
13081 const SDLoc &DL) {
13082 EVT VT = TVal.getValueType();
13083 if (!TLI.isTypeLegal(VT))
13084 return SDValue();
13085
13086 EVT CondVT = Cond.getValueType();
13087 assert(CondVT.isVector() && "Vector select expects a vector selector!");
13088
13089 bool IsTAllZero = ISD::isConstantSplatVectorAllZeros(TVal.getNode());
13090 bool IsTAllOne = ISD::isConstantSplatVectorAllOnes(TVal.getNode());
13091 bool IsFAllZero = ISD::isConstantSplatVectorAllZeros(FVal.getNode());
13092 bool IsFAllOne = ISD::isConstantSplatVectorAllOnes(FVal.getNode());
13093
13094 // no vselect(cond, 0/-1, X) or vselect(cond, X, 0/-1), return
13095 if (!IsTAllZero && !IsTAllOne && !IsFAllZero && !IsFAllOne)
13096 return SDValue();
13097
13098 // select Cond, 0, 0 → 0
13099 if (IsTAllZero && IsFAllZero) {
13100 return VT.isFloatingPoint() ? DAG.getConstantFP(0.0, DL, VT)
13101 : DAG.getConstant(0, DL, VT);
13102 }
13103
13104 // check select(setgt lhs, -1), 1, -1 --> or (sra lhs, bitwidth - 1), 1
13105 APInt TValAPInt;
13106 if (Cond.getOpcode() == ISD::SETCC &&
13107 Cond.getOperand(2) == DAG.getCondCode(ISD::SETGT) &&
13108 Cond.getOperand(0).getValueType() == VT && VT.isSimple() &&
13109 ISD::isConstantSplatVector(TVal.getNode(), TValAPInt) &&
13110 TValAPInt.isOne() &&
13111 ISD::isConstantSplatVectorAllOnes(Cond.getOperand(1).getNode()) &&
13112 ISD::isConstantSplatVectorAllOnes(FVal.getNode())) {
13113 return SDValue();
13114 }
13115
13116 // To use the condition operand as a bitwise mask, it must have elements that
13117 // are the same size as the select elements. i.e, the condition operand must
13118 // have already been promoted from the IR select condition type <N x i1>.
13119 // Don't check if the types themselves are equal because that excludes
13120 // vector floating-point selects.
13121 if (CondVT.getScalarSizeInBits() != VT.getScalarSizeInBits())
13122 return SDValue();
13123
13124 // Cond value must be 'sign splat' to be converted to a logical op.
13125 if (DAG.ComputeNumSignBits(Cond) != CondVT.getScalarSizeInBits())
13126 return SDValue();
13127
13128 // Try inverting Cond and swapping T/F if it gives all-ones/all-zeros form
13129 if (!IsTAllOne && !IsFAllZero && Cond.hasOneUse() &&
13130 Cond.getOpcode() == ISD::SETCC &&
13131 TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT) ==
13132 CondVT) {
13133 if (IsTAllZero || IsFAllOne) {
13134 SDValue CC = Cond.getOperand(2);
13135 ISD::CondCode InverseCC = ISD::getSetCCInverse(
13136 cast<CondCodeSDNode>(CC)->get(), Cond.getOperand(0).getValueType());
13137 Cond = DAG.getSetCC(DL, CondVT, Cond.getOperand(0), Cond.getOperand(1),
13138 InverseCC);
13139 std::swap(TVal, FVal);
13140 std::swap(IsTAllOne, IsFAllOne);
13141 std::swap(IsTAllZero, IsFAllZero);
13142 }
13143 }
13144
13145 assert(DAG.ComputeNumSignBits(Cond) == CondVT.getScalarSizeInBits() &&
13146 "Select condition no longer all-sign bits");
13147
13148 // select Cond, -1, 0 → bitcast Cond
13149 if (IsTAllOne && IsFAllZero)
13150 return DAG.getBitcast(VT, Cond);
13151
13152 // select Cond, -1, x → or Cond, x
13153 if (IsTAllOne) {
13154 SDValue X = DAG.getBitcast(CondVT, FVal);
13155 SDValue Or = DAG.getNode(ISD::OR, DL, CondVT, Cond, X);
13156 return DAG.getBitcast(VT, Or);
13157 }
13158
13159 // select Cond, x, 0 → and Cond, x
13160 if (IsFAllZero) {
13161 SDValue X = DAG.getBitcast(CondVT, TVal);
13162 SDValue And = DAG.getNode(ISD::AND, DL, CondVT, Cond, X);
13163 return DAG.getBitcast(VT, And);
13164 }
13165
13166 // select Cond, 0, x -> and not(Cond), x
13167 if (IsTAllZero &&
13168 (isBitwiseNot(peekThroughBitcasts(Cond)) || TLI.hasAndNot(Cond))) {
13169 SDValue X = DAG.getBitcast(CondVT, FVal);
13170 SDValue And =
13171 DAG.getNode(ISD::AND, DL, CondVT, DAG.getNOT(DL, Cond, CondVT), X);
13172 return DAG.getBitcast(VT, And);
13173 }
13174
13175 return SDValue();
13176 }
13177
visitVSELECT(SDNode * N)13178 SDValue DAGCombiner::visitVSELECT(SDNode *N) {
13179 SDValue N0 = N->getOperand(0);
13180 SDValue N1 = N->getOperand(1);
13181 SDValue N2 = N->getOperand(2);
13182 EVT VT = N->getValueType(0);
13183 SDLoc DL(N);
13184
13185 if (SDValue V = DAG.simplifySelect(N0, N1, N2))
13186 return V;
13187
13188 if (SDValue V = foldBoolSelectToLogic<EmptyMatchContext>(N, DL, DAG))
13189 return V;
13190
13191 // vselect (not Cond), N1, N2 -> vselect Cond, N2, N1
13192 if (!TLI.isTargetCanonicalSelect(N))
13193 if (SDValue F = extractBooleanFlip(N0, DAG, TLI, false))
13194 return DAG.getSelect(DL, VT, F, N2, N1);
13195
13196 // select (sext m), (add X, C), X --> (add X, (and C, (sext m))))
13197 if (N1.getOpcode() == ISD::ADD && N1.getOperand(0) == N2 && N1->hasOneUse() &&
13198 DAG.isConstantIntBuildVectorOrConstantInt(N1.getOperand(1)) &&
13199 N0.getScalarValueSizeInBits() == N1.getScalarValueSizeInBits() &&
13200 TLI.getBooleanContents(N0.getValueType()) ==
13201 TargetLowering::ZeroOrNegativeOneBooleanContent) {
13202 return DAG.getNode(
13203 ISD::ADD, DL, N1.getValueType(), N2,
13204 DAG.getNode(ISD::AND, DL, N0.getValueType(), N1.getOperand(1), N0));
13205 }
13206
13207 // Canonicalize integer abs.
13208 // vselect (setg[te] X, 0), X, -X ->
13209 // vselect (setgt X, -1), X, -X ->
13210 // vselect (setl[te] X, 0), -X, X ->
13211 // Y = sra (X, size(X)-1); xor (add (X, Y), Y)
13212 if (N0.getOpcode() == ISD::SETCC) {
13213 SDValue LHS = N0.getOperand(0), RHS = N0.getOperand(1);
13214 ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
13215 bool isAbs = false;
13216 bool RHSIsAllZeros = ISD::isBuildVectorAllZeros(RHS.getNode());
13217
13218 if (((RHSIsAllZeros && (CC == ISD::SETGT || CC == ISD::SETGE)) ||
13219 (ISD::isBuildVectorAllOnes(RHS.getNode()) && CC == ISD::SETGT)) &&
13220 N1 == LHS && N2.getOpcode() == ISD::SUB && N1 == N2.getOperand(1))
13221 isAbs = ISD::isBuildVectorAllZeros(N2.getOperand(0).getNode());
13222 else if ((RHSIsAllZeros && (CC == ISD::SETLT || CC == ISD::SETLE)) &&
13223 N2 == LHS && N1.getOpcode() == ISD::SUB && N2 == N1.getOperand(1))
13224 isAbs = ISD::isBuildVectorAllZeros(N1.getOperand(0).getNode());
13225
13226 if (isAbs) {
13227 if (TLI.isOperationLegalOrCustom(ISD::ABS, VT))
13228 return DAG.getNode(ISD::ABS, DL, VT, LHS);
13229
13230 SDValue Shift = DAG.getNode(
13231 ISD::SRA, DL, VT, LHS,
13232 DAG.getShiftAmountConstant(VT.getScalarSizeInBits() - 1, VT, DL));
13233 SDValue Add = DAG.getNode(ISD::ADD, DL, VT, LHS, Shift);
13234 AddToWorklist(Shift.getNode());
13235 AddToWorklist(Add.getNode());
13236 return DAG.getNode(ISD::XOR, DL, VT, Add, Shift);
13237 }
13238
13239 // vselect x, y (fcmp lt x, y) -> fminnum x, y
13240 // vselect x, y (fcmp gt x, y) -> fmaxnum x, y
13241 //
13242 // This is OK if we don't care about what happens if either operand is a
13243 // NaN.
13244 //
13245 if (N0.hasOneUse() &&
13246 isLegalToCombineMinNumMaxNum(DAG, LHS, RHS, N->getFlags(), TLI)) {
13247 if (SDValue FMinMax = combineMinNumMaxNum(DL, VT, LHS, RHS, N1, N2, CC))
13248 return FMinMax;
13249 }
13250
13251 if (SDValue S = PerformMinMaxFpToSatCombine(LHS, RHS, N1, N2, CC, DAG))
13252 return S;
13253 if (SDValue S = PerformUMinFpToSatCombine(LHS, RHS, N1, N2, CC, DAG))
13254 return S;
13255
13256 // If this select has a condition (setcc) with narrower operands than the
13257 // select, try to widen the compare to match the select width.
13258 // TODO: This should be extended to handle any constant.
13259 // TODO: This could be extended to handle non-loading patterns, but that
13260 // requires thorough testing to avoid regressions.
13261 if (isNullOrNullSplat(RHS)) {
13262 EVT NarrowVT = LHS.getValueType();
13263 EVT WideVT = N1.getValueType().changeVectorElementTypeToInteger();
13264 EVT SetCCVT = getSetCCResultType(LHS.getValueType());
13265 unsigned SetCCWidth = SetCCVT.getScalarSizeInBits();
13266 unsigned WideWidth = WideVT.getScalarSizeInBits();
13267 bool IsSigned = isSignedIntSetCC(CC);
13268 auto LoadExtOpcode = IsSigned ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
13269 if (LHS.getOpcode() == ISD::LOAD && LHS.hasOneUse() &&
13270 SetCCWidth != 1 && SetCCWidth < WideWidth &&
13271 TLI.isLoadExtLegalOrCustom(LoadExtOpcode, WideVT, NarrowVT) &&
13272 TLI.isOperationLegalOrCustom(ISD::SETCC, WideVT)) {
13273 // Both compare operands can be widened for free. The LHS can use an
13274 // extended load, and the RHS is a constant:
13275 // vselect (ext (setcc load(X), C)), N1, N2 -->
13276 // vselect (setcc extload(X), C'), N1, N2
13277 auto ExtOpcode = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
13278 SDValue WideLHS = DAG.getNode(ExtOpcode, DL, WideVT, LHS);
13279 SDValue WideRHS = DAG.getNode(ExtOpcode, DL, WideVT, RHS);
13280 EVT WideSetCCVT = getSetCCResultType(WideVT);
13281 SDValue WideSetCC = DAG.getSetCC(DL, WideSetCCVT, WideLHS, WideRHS, CC);
13282 return DAG.getSelect(DL, N1.getValueType(), WideSetCC, N1, N2);
13283 }
13284 }
13285
13286 if (SDValue ABD = foldSelectToABD(LHS, RHS, N1, N2, CC, DL))
13287 return ABD;
13288
13289 // Match VSELECTs into add with unsigned saturation.
13290 if (hasOperation(ISD::UADDSAT, VT)) {
13291 // Check if one of the arms of the VSELECT is vector with all bits set.
13292 // If it's on the left side invert the predicate to simplify logic below.
13293 SDValue Other;
13294 ISD::CondCode SatCC = CC;
13295 if (ISD::isConstantSplatVectorAllOnes(N1.getNode())) {
13296 Other = N2;
13297 SatCC = ISD::getSetCCInverse(SatCC, VT.getScalarType());
13298 } else if (ISD::isConstantSplatVectorAllOnes(N2.getNode())) {
13299 Other = N1;
13300 }
13301
13302 if (Other && Other.getOpcode() == ISD::ADD) {
13303 SDValue CondLHS = LHS, CondRHS = RHS;
13304 SDValue OpLHS = Other.getOperand(0), OpRHS = Other.getOperand(1);
13305
13306 // Canonicalize condition operands.
13307 if (SatCC == ISD::SETUGE) {
13308 std::swap(CondLHS, CondRHS);
13309 SatCC = ISD::SETULE;
13310 }
13311
13312 // We can test against either of the addition operands.
13313 // x <= x+y ? x+y : ~0 --> uaddsat x, y
13314 // x+y >= x ? x+y : ~0 --> uaddsat x, y
13315 if (SatCC == ISD::SETULE && Other == CondRHS &&
13316 (OpLHS == CondLHS || OpRHS == CondLHS))
13317 return DAG.getNode(ISD::UADDSAT, DL, VT, OpLHS, OpRHS);
13318
13319 if (OpRHS.getOpcode() == CondRHS.getOpcode() &&
13320 (OpRHS.getOpcode() == ISD::BUILD_VECTOR ||
13321 OpRHS.getOpcode() == ISD::SPLAT_VECTOR) &&
13322 CondLHS == OpLHS) {
13323 // If the RHS is a constant we have to reverse the const
13324 // canonicalization.
13325 // x >= ~C ? x+C : ~0 --> uaddsat x, C
13326 auto MatchUADDSAT = [](ConstantSDNode *Op, ConstantSDNode *Cond) {
13327 return Cond->getAPIntValue() == ~Op->getAPIntValue();
13328 };
13329 if (SatCC == ISD::SETULE &&
13330 ISD::matchBinaryPredicate(OpRHS, CondRHS, MatchUADDSAT))
13331 return DAG.getNode(ISD::UADDSAT, DL, VT, OpLHS, OpRHS);
13332 }
13333 }
13334 }
13335
13336 // Match VSELECTs into sub with unsigned saturation.
13337 if (hasOperation(ISD::USUBSAT, VT)) {
13338 // Check if one of the arms of the VSELECT is a zero vector. If it's on
13339 // the left side invert the predicate to simplify logic below.
13340 SDValue Other;
13341 ISD::CondCode SatCC = CC;
13342 if (ISD::isConstantSplatVectorAllZeros(N1.getNode())) {
13343 Other = N2;
13344 SatCC = ISD::getSetCCInverse(SatCC, VT.getScalarType());
13345 } else if (ISD::isConstantSplatVectorAllZeros(N2.getNode())) {
13346 Other = N1;
13347 }
13348
13349 // zext(x) >= y ? trunc(zext(x) - y) : 0
13350 // --> usubsat(trunc(zext(x)),trunc(umin(y,SatLimit)))
13351 // zext(x) > y ? trunc(zext(x) - y) : 0
13352 // --> usubsat(trunc(zext(x)),trunc(umin(y,SatLimit)))
13353 if (Other && Other.getOpcode() == ISD::TRUNCATE &&
13354 Other.getOperand(0).getOpcode() == ISD::SUB &&
13355 (SatCC == ISD::SETUGE || SatCC == ISD::SETUGT)) {
13356 SDValue OpLHS = Other.getOperand(0).getOperand(0);
13357 SDValue OpRHS = Other.getOperand(0).getOperand(1);
13358 if (LHS == OpLHS && RHS == OpRHS && LHS.getOpcode() == ISD::ZERO_EXTEND)
13359 if (SDValue R = getTruncatedUSUBSAT(VT, LHS.getValueType(), LHS, RHS,
13360 DAG, DL))
13361 return R;
13362 }
13363
13364 if (Other && Other.getNumOperands() == 2) {
13365 SDValue CondRHS = RHS;
13366 SDValue OpLHS = Other.getOperand(0), OpRHS = Other.getOperand(1);
13367
13368 if (OpLHS == LHS) {
13369 // Look for a general sub with unsigned saturation first.
13370 // x >= y ? x-y : 0 --> usubsat x, y
13371 // x > y ? x-y : 0 --> usubsat x, y
13372 if ((SatCC == ISD::SETUGE || SatCC == ISD::SETUGT) &&
13373 Other.getOpcode() == ISD::SUB && OpRHS == CondRHS)
13374 return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
13375
13376 if (OpRHS.getOpcode() == ISD::BUILD_VECTOR ||
13377 OpRHS.getOpcode() == ISD::SPLAT_VECTOR) {
13378 if (CondRHS.getOpcode() == ISD::BUILD_VECTOR ||
13379 CondRHS.getOpcode() == ISD::SPLAT_VECTOR) {
13380 // If the RHS is a constant we have to reverse the const
13381 // canonicalization.
13382 // x > C-1 ? x+-C : 0 --> usubsat x, C
13383 auto MatchUSUBSAT = [](ConstantSDNode *Op, ConstantSDNode *Cond) {
13384 return (!Op && !Cond) ||
13385 (Op && Cond &&
13386 Cond->getAPIntValue() == (-Op->getAPIntValue() - 1));
13387 };
13388 if (SatCC == ISD::SETUGT && Other.getOpcode() == ISD::ADD &&
13389 ISD::matchBinaryPredicate(OpRHS, CondRHS, MatchUSUBSAT,
13390 /*AllowUndefs*/ true)) {
13391 OpRHS = DAG.getNegative(OpRHS, DL, VT);
13392 return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
13393 }
13394
13395 // Another special case: If C was a sign bit, the sub has been
13396 // canonicalized into a xor.
13397 // FIXME: Would it be better to use computeKnownBits to
13398 // determine whether it's safe to decanonicalize the xor?
13399 // x s< 0 ? x^C : 0 --> usubsat x, C
13400 APInt SplatValue;
13401 if (SatCC == ISD::SETLT && Other.getOpcode() == ISD::XOR &&
13402 ISD::isConstantSplatVector(OpRHS.getNode(), SplatValue) &&
13403 ISD::isConstantSplatVectorAllZeros(CondRHS.getNode()) &&
13404 SplatValue.isSignMask()) {
13405 // Note that we have to rebuild the RHS constant here to
13406 // ensure we don't rely on particular values of undef lanes.
13407 OpRHS = DAG.getConstant(SplatValue, DL, VT);
13408 return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
13409 }
13410 }
13411 }
13412 }
13413 }
13414 }
13415 }
13416
13417 if (SimplifySelectOps(N, N1, N2))
13418 return SDValue(N, 0); // Don't revisit N.
13419
13420 // Fold (vselect all_ones, N1, N2) -> N1
13421 if (ISD::isConstantSplatVectorAllOnes(N0.getNode()))
13422 return N1;
13423 // Fold (vselect all_zeros, N1, N2) -> N2
13424 if (ISD::isConstantSplatVectorAllZeros(N0.getNode()))
13425 return N2;
13426
13427 // The ConvertSelectToConcatVector function is assuming both the above
13428 // checks for (vselect (build_vector all{ones,zeros) ...) have been made
13429 // and addressed.
13430 if (N1.getOpcode() == ISD::CONCAT_VECTORS &&
13431 N2.getOpcode() == ISD::CONCAT_VECTORS &&
13432 ISD::isBuildVectorOfConstantSDNodes(N0.getNode())) {
13433 if (SDValue CV = ConvertSelectToConcatVector(N, DAG))
13434 return CV;
13435 }
13436
13437 if (SDValue V = foldVSelectOfConstants(N))
13438 return V;
13439
13440 if (hasOperation(ISD::SRA, VT))
13441 if (SDValue V = foldVSelectToSignBitSplatMask(N, DAG))
13442 return V;
13443
13444 if (SimplifyDemandedVectorElts(SDValue(N, 0)))
13445 return SDValue(N, 0);
13446
13447 if (SDValue V = combineVSelectWithAllOnesOrZeros(N0, N1, N2, TLI, DAG, DL))
13448 return V;
13449
13450 return SDValue();
13451 }
13452
visitSELECT_CC(SDNode * N)13453 SDValue DAGCombiner::visitSELECT_CC(SDNode *N) {
13454 SDValue N0 = N->getOperand(0);
13455 SDValue N1 = N->getOperand(1);
13456 SDValue N2 = N->getOperand(2);
13457 SDValue N3 = N->getOperand(3);
13458 SDValue N4 = N->getOperand(4);
13459 ISD::CondCode CC = cast<CondCodeSDNode>(N4)->get();
13460 SDLoc DL(N);
13461
13462 // fold select_cc lhs, rhs, x, x, cc -> x
13463 if (N2 == N3)
13464 return N2;
13465
13466 // select_cc bool, 0, x, y, seteq -> select bool, y, x
13467 if (CC == ISD::SETEQ && !LegalTypes && N0.getValueType() == MVT::i1 &&
13468 isNullConstant(N1))
13469 return DAG.getSelect(DL, N2.getValueType(), N0, N3, N2);
13470
13471 // Determine if the condition we're dealing with is constant
13472 if (SDValue SCC = SimplifySetCC(getSetCCResultType(N0.getValueType()), N0, N1,
13473 CC, DL, false)) {
13474 AddToWorklist(SCC.getNode());
13475
13476 // cond always true -> true val
13477 // cond always false -> false val
13478 if (auto *SCCC = dyn_cast<ConstantSDNode>(SCC.getNode()))
13479 return SCCC->isZero() ? N3 : N2;
13480
13481 // When the condition is UNDEF, just return the first operand. This is
13482 // coherent the DAG creation, no setcc node is created in this case
13483 if (SCC->isUndef())
13484 return N2;
13485
13486 // Fold to a simpler select_cc
13487 if (SCC.getOpcode() == ISD::SETCC) {
13488 return DAG.getNode(ISD::SELECT_CC, DL, N2.getValueType(),
13489 SCC.getOperand(0), SCC.getOperand(1), N2, N3,
13490 SCC.getOperand(2), SCC->getFlags());
13491 }
13492 }
13493
13494 // If we can fold this based on the true/false value, do so.
13495 if (SimplifySelectOps(N, N2, N3))
13496 return SDValue(N, 0); // Don't revisit N.
13497
13498 // fold select_cc into other things, such as min/max/abs
13499 return SimplifySelectCC(DL, N0, N1, N2, N3, CC);
13500 }
13501
visitSETCC(SDNode * N)13502 SDValue DAGCombiner::visitSETCC(SDNode *N) {
13503 // setcc is very commonly used as an argument to brcond. This pattern
13504 // also lend itself to numerous combines and, as a result, it is desired
13505 // we keep the argument to a brcond as a setcc as much as possible.
13506 bool PreferSetCC =
13507 N->hasOneUse() && N->user_begin()->getOpcode() == ISD::BRCOND;
13508
13509 ISD::CondCode Cond = cast<CondCodeSDNode>(N->getOperand(2))->get();
13510 EVT VT = N->getValueType(0);
13511 SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
13512 SDLoc DL(N);
13513
13514 if (SDValue Combined = SimplifySetCC(VT, N0, N1, Cond, DL, !PreferSetCC)) {
13515 // If we prefer to have a setcc, and we don't, we'll try our best to
13516 // recreate one using rebuildSetCC.
13517 if (PreferSetCC && Combined.getOpcode() != ISD::SETCC) {
13518 SDValue NewSetCC = rebuildSetCC(Combined);
13519
13520 // We don't have anything interesting to combine to.
13521 if (NewSetCC.getNode() == N)
13522 return SDValue();
13523
13524 if (NewSetCC)
13525 return NewSetCC;
13526 }
13527 return Combined;
13528 }
13529
13530 // Optimize
13531 // 1) (icmp eq/ne (and X, C0), (shift X, C1))
13532 // or
13533 // 2) (icmp eq/ne X, (rotate X, C1))
13534 // If C0 is a mask or shifted mask and the shift amt (C1) isolates the
13535 // remaining bits (i.e something like `(x64 & UINT32_MAX) == (x64 >> 32)`)
13536 // Then:
13537 // If C1 is a power of 2, then the rotate and shift+and versions are
13538 // equivilent, so we can interchange them depending on target preference.
13539 // Otherwise, if we have the shift+and version we can interchange srl/shl
13540 // which inturn affects the constant C0. We can use this to get better
13541 // constants again determined by target preference.
13542 if (Cond == ISD::SETNE || Cond == ISD::SETEQ) {
13543 auto IsAndWithShift = [](SDValue A, SDValue B) {
13544 return A.getOpcode() == ISD::AND &&
13545 (B.getOpcode() == ISD::SRL || B.getOpcode() == ISD::SHL) &&
13546 A.getOperand(0) == B.getOperand(0);
13547 };
13548 auto IsRotateWithOp = [](SDValue A, SDValue B) {
13549 return (B.getOpcode() == ISD::ROTL || B.getOpcode() == ISD::ROTR) &&
13550 B.getOperand(0) == A;
13551 };
13552 SDValue AndOrOp = SDValue(), ShiftOrRotate = SDValue();
13553 bool IsRotate = false;
13554
13555 // Find either shift+and or rotate pattern.
13556 if (IsAndWithShift(N0, N1)) {
13557 AndOrOp = N0;
13558 ShiftOrRotate = N1;
13559 } else if (IsAndWithShift(N1, N0)) {
13560 AndOrOp = N1;
13561 ShiftOrRotate = N0;
13562 } else if (IsRotateWithOp(N0, N1)) {
13563 IsRotate = true;
13564 AndOrOp = N0;
13565 ShiftOrRotate = N1;
13566 } else if (IsRotateWithOp(N1, N0)) {
13567 IsRotate = true;
13568 AndOrOp = N1;
13569 ShiftOrRotate = N0;
13570 }
13571
13572 if (AndOrOp && ShiftOrRotate && ShiftOrRotate.hasOneUse() &&
13573 (IsRotate || AndOrOp.hasOneUse())) {
13574 EVT OpVT = N0.getValueType();
13575 // Get constant shift/rotate amount and possibly mask (if its shift+and
13576 // variant).
13577 auto GetAPIntValue = [](SDValue Op) -> std::optional<APInt> {
13578 ConstantSDNode *CNode = isConstOrConstSplat(Op, /*AllowUndefs*/ false,
13579 /*AllowTrunc*/ false);
13580 if (CNode == nullptr)
13581 return std::nullopt;
13582 return CNode->getAPIntValue();
13583 };
13584 std::optional<APInt> AndCMask =
13585 IsRotate ? std::nullopt : GetAPIntValue(AndOrOp.getOperand(1));
13586 std::optional<APInt> ShiftCAmt =
13587 GetAPIntValue(ShiftOrRotate.getOperand(1));
13588 unsigned NumBits = OpVT.getScalarSizeInBits();
13589
13590 // We found constants.
13591 if (ShiftCAmt && (IsRotate || AndCMask) && ShiftCAmt->ult(NumBits)) {
13592 unsigned ShiftOpc = ShiftOrRotate.getOpcode();
13593 // Check that the constants meet the constraints.
13594 bool CanTransform = IsRotate;
13595 if (!CanTransform) {
13596 // Check that mask and shift compliment eachother
13597 CanTransform = *ShiftCAmt == (~*AndCMask).popcount();
13598 // Check that we are comparing all bits
13599 CanTransform &= (*ShiftCAmt + AndCMask->popcount()) == NumBits;
13600 // Check that the and mask is correct for the shift
13601 CanTransform &=
13602 ShiftOpc == ISD::SHL ? (~*AndCMask).isMask() : AndCMask->isMask();
13603 }
13604
13605 // See if target prefers another shift/rotate opcode.
13606 unsigned NewShiftOpc = TLI.preferedOpcodeForCmpEqPiecesOfOperand(
13607 OpVT, ShiftOpc, ShiftCAmt->isPowerOf2(), *ShiftCAmt, AndCMask);
13608 // Transform is valid and we have a new preference.
13609 if (CanTransform && NewShiftOpc != ShiftOpc) {
13610 SDValue NewShiftOrRotate =
13611 DAG.getNode(NewShiftOpc, DL, OpVT, ShiftOrRotate.getOperand(0),
13612 ShiftOrRotate.getOperand(1));
13613 SDValue NewAndOrOp = SDValue();
13614
13615 if (NewShiftOpc == ISD::SHL || NewShiftOpc == ISD::SRL) {
13616 APInt NewMask =
13617 NewShiftOpc == ISD::SHL
13618 ? APInt::getHighBitsSet(NumBits,
13619 NumBits - ShiftCAmt->getZExtValue())
13620 : APInt::getLowBitsSet(NumBits,
13621 NumBits - ShiftCAmt->getZExtValue());
13622 NewAndOrOp =
13623 DAG.getNode(ISD::AND, DL, OpVT, ShiftOrRotate.getOperand(0),
13624 DAG.getConstant(NewMask, DL, OpVT));
13625 } else {
13626 NewAndOrOp = ShiftOrRotate.getOperand(0);
13627 }
13628
13629 return DAG.getSetCC(DL, VT, NewAndOrOp, NewShiftOrRotate, Cond);
13630 }
13631 }
13632 }
13633 }
13634 return SDValue();
13635 }
13636
visitSETCCCARRY(SDNode * N)13637 SDValue DAGCombiner::visitSETCCCARRY(SDNode *N) {
13638 SDValue LHS = N->getOperand(0);
13639 SDValue RHS = N->getOperand(1);
13640 SDValue Carry = N->getOperand(2);
13641 SDValue Cond = N->getOperand(3);
13642
13643 // If Carry is false, fold to a regular SETCC.
13644 if (isNullConstant(Carry))
13645 return DAG.getNode(ISD::SETCC, SDLoc(N), N->getVTList(), LHS, RHS, Cond);
13646
13647 return SDValue();
13648 }
13649
13650 /// Check if N satisfies:
13651 /// N is used once.
13652 /// N is a Load.
13653 /// The load is compatible with ExtOpcode. It means
13654 /// If load has explicit zero/sign extension, ExpOpcode must have the same
13655 /// extension.
13656 /// Otherwise returns true.
isCompatibleLoad(SDValue N,unsigned ExtOpcode)13657 static bool isCompatibleLoad(SDValue N, unsigned ExtOpcode) {
13658 if (!N.hasOneUse())
13659 return false;
13660
13661 if (!isa<LoadSDNode>(N))
13662 return false;
13663
13664 LoadSDNode *Load = cast<LoadSDNode>(N);
13665 ISD::LoadExtType LoadExt = Load->getExtensionType();
13666 if (LoadExt == ISD::NON_EXTLOAD || LoadExt == ISD::EXTLOAD)
13667 return true;
13668
13669 // Now LoadExt is either SEXTLOAD or ZEXTLOAD, ExtOpcode must have the same
13670 // extension.
13671 if ((LoadExt == ISD::SEXTLOAD && ExtOpcode != ISD::SIGN_EXTEND) ||
13672 (LoadExt == ISD::ZEXTLOAD && ExtOpcode != ISD::ZERO_EXTEND))
13673 return false;
13674
13675 return true;
13676 }
13677
13678 /// Fold
13679 /// (sext (select c, load x, load y)) -> (select c, sextload x, sextload y)
13680 /// (zext (select c, load x, load y)) -> (select c, zextload x, zextload y)
13681 /// (aext (select c, load x, load y)) -> (select c, extload x, extload y)
13682 /// This function is called by the DAGCombiner when visiting sext/zext/aext
13683 /// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND).
tryToFoldExtendSelectLoad(SDNode * N,const TargetLowering & TLI,SelectionDAG & DAG,const SDLoc & DL,CombineLevel Level)13684 static SDValue tryToFoldExtendSelectLoad(SDNode *N, const TargetLowering &TLI,
13685 SelectionDAG &DAG, const SDLoc &DL,
13686 CombineLevel Level) {
13687 unsigned Opcode = N->getOpcode();
13688 SDValue N0 = N->getOperand(0);
13689 EVT VT = N->getValueType(0);
13690 assert((Opcode == ISD::SIGN_EXTEND || Opcode == ISD::ZERO_EXTEND ||
13691 Opcode == ISD::ANY_EXTEND) &&
13692 "Expected EXTEND dag node in input!");
13693
13694 if (!(N0->getOpcode() == ISD::SELECT || N0->getOpcode() == ISD::VSELECT) ||
13695 !N0.hasOneUse())
13696 return SDValue();
13697
13698 SDValue Op1 = N0->getOperand(1);
13699 SDValue Op2 = N0->getOperand(2);
13700 if (!isCompatibleLoad(Op1, Opcode) || !isCompatibleLoad(Op2, Opcode))
13701 return SDValue();
13702
13703 auto ExtLoadOpcode = ISD::EXTLOAD;
13704 if (Opcode == ISD::SIGN_EXTEND)
13705 ExtLoadOpcode = ISD::SEXTLOAD;
13706 else if (Opcode == ISD::ZERO_EXTEND)
13707 ExtLoadOpcode = ISD::ZEXTLOAD;
13708
13709 // Illegal VSELECT may ISel fail if happen after legalization (DAG
13710 // Combine2), so we should conservatively check the OperationAction.
13711 LoadSDNode *Load1 = cast<LoadSDNode>(Op1);
13712 LoadSDNode *Load2 = cast<LoadSDNode>(Op2);
13713 if (!TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load1->getMemoryVT()) ||
13714 !TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load2->getMemoryVT()) ||
13715 (N0->getOpcode() == ISD::VSELECT && Level >= AfterLegalizeTypes &&
13716 TLI.getOperationAction(ISD::VSELECT, VT) != TargetLowering::Legal))
13717 return SDValue();
13718
13719 SDValue Ext1 = DAG.getNode(Opcode, DL, VT, Op1);
13720 SDValue Ext2 = DAG.getNode(Opcode, DL, VT, Op2);
13721 return DAG.getSelect(DL, VT, N0->getOperand(0), Ext1, Ext2);
13722 }
13723
13724 /// Try to fold a sext/zext/aext dag node into a ConstantSDNode or
13725 /// a build_vector of constants.
13726 /// This function is called by the DAGCombiner when visiting sext/zext/aext
13727 /// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND).
13728 /// Vector extends are not folded if operations are legal; this is to
13729 /// avoid introducing illegal build_vector dag nodes.
tryToFoldExtendOfConstant(SDNode * N,const SDLoc & DL,const TargetLowering & TLI,SelectionDAG & DAG,bool LegalTypes)13730 static SDValue tryToFoldExtendOfConstant(SDNode *N, const SDLoc &DL,
13731 const TargetLowering &TLI,
13732 SelectionDAG &DAG, bool LegalTypes) {
13733 unsigned Opcode = N->getOpcode();
13734 SDValue N0 = N->getOperand(0);
13735 EVT VT = N->getValueType(0);
13736
13737 assert((ISD::isExtOpcode(Opcode) || ISD::isExtVecInRegOpcode(Opcode)) &&
13738 "Expected EXTEND dag node in input!");
13739
13740 // fold (sext c1) -> c1
13741 // fold (zext c1) -> c1
13742 // fold (aext c1) -> c1
13743 if (isa<ConstantSDNode>(N0))
13744 return DAG.getNode(Opcode, DL, VT, N0);
13745
13746 // fold (sext (select cond, c1, c2)) -> (select cond, sext c1, sext c2)
13747 // fold (zext (select cond, c1, c2)) -> (select cond, zext c1, zext c2)
13748 // fold (aext (select cond, c1, c2)) -> (select cond, sext c1, sext c2)
13749 if (N0->getOpcode() == ISD::SELECT) {
13750 SDValue Op1 = N0->getOperand(1);
13751 SDValue Op2 = N0->getOperand(2);
13752 if (isa<ConstantSDNode>(Op1) && isa<ConstantSDNode>(Op2) &&
13753 (Opcode != ISD::ZERO_EXTEND || !TLI.isZExtFree(N0.getValueType(), VT))) {
13754 // For any_extend, choose sign extension of the constants to allow a
13755 // possible further transform to sign_extend_inreg.i.e.
13756 //
13757 // t1: i8 = select t0, Constant:i8<-1>, Constant:i8<0>
13758 // t2: i64 = any_extend t1
13759 // -->
13760 // t3: i64 = select t0, Constant:i64<-1>, Constant:i64<0>
13761 // -->
13762 // t4: i64 = sign_extend_inreg t3
13763 unsigned FoldOpc = Opcode;
13764 if (FoldOpc == ISD::ANY_EXTEND)
13765 FoldOpc = ISD::SIGN_EXTEND;
13766 return DAG.getSelect(DL, VT, N0->getOperand(0),
13767 DAG.getNode(FoldOpc, DL, VT, Op1),
13768 DAG.getNode(FoldOpc, DL, VT, Op2));
13769 }
13770 }
13771
13772 // fold (sext (build_vector AllConstants) -> (build_vector AllConstants)
13773 // fold (zext (build_vector AllConstants) -> (build_vector AllConstants)
13774 // fold (aext (build_vector AllConstants) -> (build_vector AllConstants)
13775 EVT SVT = VT.getScalarType();
13776 if (!(VT.isVector() && (!LegalTypes || TLI.isTypeLegal(SVT)) &&
13777 ISD::isBuildVectorOfConstantSDNodes(N0.getNode())))
13778 return SDValue();
13779
13780 // We can fold this node into a build_vector.
13781 unsigned VTBits = SVT.getSizeInBits();
13782 unsigned EVTBits = N0->getValueType(0).getScalarSizeInBits();
13783 SmallVector<SDValue, 8> Elts;
13784 unsigned NumElts = VT.getVectorNumElements();
13785
13786 for (unsigned i = 0; i != NumElts; ++i) {
13787 SDValue Op = N0.getOperand(i);
13788 if (Op.isUndef()) {
13789 if (Opcode == ISD::ANY_EXTEND || Opcode == ISD::ANY_EXTEND_VECTOR_INREG)
13790 Elts.push_back(DAG.getUNDEF(SVT));
13791 else
13792 Elts.push_back(DAG.getConstant(0, DL, SVT));
13793 continue;
13794 }
13795
13796 SDLoc DL(Op);
13797 // Get the constant value and if needed trunc it to the size of the type.
13798 // Nodes like build_vector might have constants wider than the scalar type.
13799 APInt C = Op->getAsAPIntVal().zextOrTrunc(EVTBits);
13800 if (Opcode == ISD::SIGN_EXTEND || Opcode == ISD::SIGN_EXTEND_VECTOR_INREG)
13801 Elts.push_back(DAG.getConstant(C.sext(VTBits), DL, SVT));
13802 else
13803 Elts.push_back(DAG.getConstant(C.zext(VTBits), DL, SVT));
13804 }
13805
13806 return DAG.getBuildVector(VT, DL, Elts);
13807 }
13808
13809 // ExtendUsesToFormExtLoad - Trying to extend uses of a load to enable this:
13810 // "fold ({s|z|a}ext (load x)) -> ({s|z|a}ext (truncate ({s|z|a}extload x)))"
13811 // transformation. Returns true if extension are possible and the above
13812 // mentioned transformation is profitable.
ExtendUsesToFormExtLoad(EVT VT,SDNode * N,SDValue N0,unsigned ExtOpc,SmallVectorImpl<SDNode * > & ExtendNodes,const TargetLowering & TLI)13813 static bool ExtendUsesToFormExtLoad(EVT VT, SDNode *N, SDValue N0,
13814 unsigned ExtOpc,
13815 SmallVectorImpl<SDNode *> &ExtendNodes,
13816 const TargetLowering &TLI) {
13817 bool HasCopyToRegUses = false;
13818 bool isTruncFree = TLI.isTruncateFree(VT, N0.getValueType());
13819 for (SDUse &Use : N0->uses()) {
13820 SDNode *User = Use.getUser();
13821 if (User == N)
13822 continue;
13823 if (Use.getResNo() != N0.getResNo())
13824 continue;
13825 // FIXME: Only extend SETCC N, N and SETCC N, c for now.
13826 if (ExtOpc != ISD::ANY_EXTEND && User->getOpcode() == ISD::SETCC) {
13827 ISD::CondCode CC = cast<CondCodeSDNode>(User->getOperand(2))->get();
13828 if (ExtOpc == ISD::ZERO_EXTEND && ISD::isSignedIntSetCC(CC))
13829 // Sign bits will be lost after a zext.
13830 return false;
13831 bool Add = false;
13832 for (unsigned i = 0; i != 2; ++i) {
13833 SDValue UseOp = User->getOperand(i);
13834 if (UseOp == N0)
13835 continue;
13836 if (!isa<ConstantSDNode>(UseOp))
13837 return false;
13838 Add = true;
13839 }
13840 if (Add)
13841 ExtendNodes.push_back(User);
13842 continue;
13843 }
13844 // If truncates aren't free and there are users we can't
13845 // extend, it isn't worthwhile.
13846 if (!isTruncFree)
13847 return false;
13848 // Remember if this value is live-out.
13849 if (User->getOpcode() == ISD::CopyToReg)
13850 HasCopyToRegUses = true;
13851 }
13852
13853 if (HasCopyToRegUses) {
13854 bool BothLiveOut = false;
13855 for (SDUse &Use : N->uses()) {
13856 if (Use.getResNo() == 0 && Use.getUser()->getOpcode() == ISD::CopyToReg) {
13857 BothLiveOut = true;
13858 break;
13859 }
13860 }
13861 if (BothLiveOut)
13862 // Both unextended and extended values are live out. There had better be
13863 // a good reason for the transformation.
13864 return !ExtendNodes.empty();
13865 }
13866 return true;
13867 }
13868
ExtendSetCCUses(const SmallVectorImpl<SDNode * > & SetCCs,SDValue OrigLoad,SDValue ExtLoad,ISD::NodeType ExtType)13869 void DAGCombiner::ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs,
13870 SDValue OrigLoad, SDValue ExtLoad,
13871 ISD::NodeType ExtType) {
13872 // Extend SetCC uses if necessary.
13873 SDLoc DL(ExtLoad);
13874 for (SDNode *SetCC : SetCCs) {
13875 SmallVector<SDValue, 4> Ops;
13876
13877 for (unsigned j = 0; j != 2; ++j) {
13878 SDValue SOp = SetCC->getOperand(j);
13879 if (SOp == OrigLoad)
13880 Ops.push_back(ExtLoad);
13881 else
13882 Ops.push_back(DAG.getNode(ExtType, DL, ExtLoad->getValueType(0), SOp));
13883 }
13884
13885 Ops.push_back(SetCC->getOperand(2));
13886 CombineTo(SetCC, DAG.getNode(ISD::SETCC, DL, SetCC->getValueType(0), Ops));
13887 }
13888 }
13889
13890 // FIXME: Bring more similar combines here, common to sext/zext (maybe aext?).
CombineExtLoad(SDNode * N)13891 SDValue DAGCombiner::CombineExtLoad(SDNode *N) {
13892 SDValue N0 = N->getOperand(0);
13893 EVT DstVT = N->getValueType(0);
13894 EVT SrcVT = N0.getValueType();
13895
13896 assert((N->getOpcode() == ISD::SIGN_EXTEND ||
13897 N->getOpcode() == ISD::ZERO_EXTEND) &&
13898 "Unexpected node type (not an extend)!");
13899
13900 // fold (sext (load x)) to multiple smaller sextloads; same for zext.
13901 // For example, on a target with legal v4i32, but illegal v8i32, turn:
13902 // (v8i32 (sext (v8i16 (load x))))
13903 // into:
13904 // (v8i32 (concat_vectors (v4i32 (sextload x)),
13905 // (v4i32 (sextload (x + 16)))))
13906 // Where uses of the original load, i.e.:
13907 // (v8i16 (load x))
13908 // are replaced with:
13909 // (v8i16 (truncate
13910 // (v8i32 (concat_vectors (v4i32 (sextload x)),
13911 // (v4i32 (sextload (x + 16)))))))
13912 //
13913 // This combine is only applicable to illegal, but splittable, vectors.
13914 // All legal types, and illegal non-vector types, are handled elsewhere.
13915 // This combine is controlled by TargetLowering::isVectorLoadExtDesirable.
13916 //
13917 if (N0->getOpcode() != ISD::LOAD)
13918 return SDValue();
13919
13920 LoadSDNode *LN0 = cast<LoadSDNode>(N0);
13921
13922 if (!ISD::isNON_EXTLoad(LN0) || !ISD::isUNINDEXEDLoad(LN0) ||
13923 !N0.hasOneUse() || !LN0->isSimple() ||
13924 !DstVT.isVector() || !DstVT.isPow2VectorType() ||
13925 !TLI.isVectorLoadExtDesirable(SDValue(N, 0)))
13926 return SDValue();
13927
13928 SmallVector<SDNode *, 4> SetCCs;
13929 if (!ExtendUsesToFormExtLoad(DstVT, N, N0, N->getOpcode(), SetCCs, TLI))
13930 return SDValue();
13931
13932 ISD::LoadExtType ExtType =
13933 N->getOpcode() == ISD::SIGN_EXTEND ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
13934
13935 // Try to split the vector types to get down to legal types.
13936 EVT SplitSrcVT = SrcVT;
13937 EVT SplitDstVT = DstVT;
13938 while (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT) &&
13939 SplitSrcVT.getVectorNumElements() > 1) {
13940 SplitDstVT = DAG.GetSplitDestVTs(SplitDstVT).first;
13941 SplitSrcVT = DAG.GetSplitDestVTs(SplitSrcVT).first;
13942 }
13943
13944 if (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT))
13945 return SDValue();
13946
13947 assert(!DstVT.isScalableVector() && "Unexpected scalable vector type");
13948
13949 SDLoc DL(N);
13950 const unsigned NumSplits =
13951 DstVT.getVectorNumElements() / SplitDstVT.getVectorNumElements();
13952 const unsigned Stride = SplitSrcVT.getStoreSize();
13953 SmallVector<SDValue, 4> Loads;
13954 SmallVector<SDValue, 4> Chains;
13955
13956 SDValue BasePtr = LN0->getBasePtr();
13957 for (unsigned Idx = 0; Idx < NumSplits; Idx++) {
13958 const unsigned Offset = Idx * Stride;
13959
13960 SDValue SplitLoad =
13961 DAG.getExtLoad(ExtType, SDLoc(LN0), SplitDstVT, LN0->getChain(),
13962 BasePtr, LN0->getPointerInfo().getWithOffset(Offset),
13963 SplitSrcVT, LN0->getBaseAlign(),
13964 LN0->getMemOperand()->getFlags(), LN0->getAAInfo());
13965
13966 BasePtr = DAG.getMemBasePlusOffset(BasePtr, TypeSize::getFixed(Stride), DL);
13967
13968 Loads.push_back(SplitLoad.getValue(0));
13969 Chains.push_back(SplitLoad.getValue(1));
13970 }
13971
13972 SDValue NewChain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Chains);
13973 SDValue NewValue = DAG.getNode(ISD::CONCAT_VECTORS, DL, DstVT, Loads);
13974
13975 // Simplify TF.
13976 AddToWorklist(NewChain.getNode());
13977
13978 CombineTo(N, NewValue);
13979
13980 // Replace uses of the original load (before extension)
13981 // with a truncate of the concatenated sextloaded vectors.
13982 SDValue Trunc =
13983 DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), NewValue);
13984 ExtendSetCCUses(SetCCs, N0, NewValue, (ISD::NodeType)N->getOpcode());
13985 CombineTo(N0.getNode(), Trunc, NewChain);
13986 return SDValue(N, 0); // Return N so it doesn't get rechecked!
13987 }
13988
13989 // fold (zext (and/or/xor (shl/shr (load x), cst), cst)) ->
13990 // (and/or/xor (shl/shr (zextload x), (zext cst)), (zext cst))
CombineZExtLogicopShiftLoad(SDNode * N)13991 SDValue DAGCombiner::CombineZExtLogicopShiftLoad(SDNode *N) {
13992 assert(N->getOpcode() == ISD::ZERO_EXTEND);
13993 EVT VT = N->getValueType(0);
13994 EVT OrigVT = N->getOperand(0).getValueType();
13995 if (TLI.isZExtFree(OrigVT, VT))
13996 return SDValue();
13997
13998 // and/or/xor
13999 SDValue N0 = N->getOperand(0);
14000 if (!ISD::isBitwiseLogicOp(N0.getOpcode()) ||
14001 N0.getOperand(1).getOpcode() != ISD::Constant ||
14002 (LegalOperations && !TLI.isOperationLegal(N0.getOpcode(), VT)))
14003 return SDValue();
14004
14005 // shl/shr
14006 SDValue N1 = N0->getOperand(0);
14007 if (!(N1.getOpcode() == ISD::SHL || N1.getOpcode() == ISD::SRL) ||
14008 N1.getOperand(1).getOpcode() != ISD::Constant ||
14009 (LegalOperations && !TLI.isOperationLegal(N1.getOpcode(), VT)))
14010 return SDValue();
14011
14012 // load
14013 if (!isa<LoadSDNode>(N1.getOperand(0)))
14014 return SDValue();
14015 LoadSDNode *Load = cast<LoadSDNode>(N1.getOperand(0));
14016 EVT MemVT = Load->getMemoryVT();
14017 if (!TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT) ||
14018 Load->getExtensionType() == ISD::SEXTLOAD || Load->isIndexed())
14019 return SDValue();
14020
14021
14022 // If the shift op is SHL, the logic op must be AND, otherwise the result
14023 // will be wrong.
14024 if (N1.getOpcode() == ISD::SHL && N0.getOpcode() != ISD::AND)
14025 return SDValue();
14026
14027 if (!N0.hasOneUse() || !N1.hasOneUse())
14028 return SDValue();
14029
14030 SmallVector<SDNode*, 4> SetCCs;
14031 if (!ExtendUsesToFormExtLoad(VT, N1.getNode(), N1.getOperand(0),
14032 ISD::ZERO_EXTEND, SetCCs, TLI))
14033 return SDValue();
14034
14035 // Actually do the transformation.
14036 SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(Load), VT,
14037 Load->getChain(), Load->getBasePtr(),
14038 Load->getMemoryVT(), Load->getMemOperand());
14039
14040 SDLoc DL1(N1);
14041 SDValue Shift = DAG.getNode(N1.getOpcode(), DL1, VT, ExtLoad,
14042 N1.getOperand(1));
14043
14044 APInt Mask = N0.getConstantOperandAPInt(1).zext(VT.getSizeInBits());
14045 SDLoc DL0(N0);
14046 SDValue And = DAG.getNode(N0.getOpcode(), DL0, VT, Shift,
14047 DAG.getConstant(Mask, DL0, VT));
14048
14049 ExtendSetCCUses(SetCCs, N1.getOperand(0), ExtLoad, ISD::ZERO_EXTEND);
14050 CombineTo(N, And);
14051 if (SDValue(Load, 0).hasOneUse()) {
14052 DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 1), ExtLoad.getValue(1));
14053 } else {
14054 SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(Load),
14055 Load->getValueType(0), ExtLoad);
14056 CombineTo(Load, Trunc, ExtLoad.getValue(1));
14057 }
14058
14059 // N0 is dead at this point.
14060 recursivelyDeleteUnusedNodes(N0.getNode());
14061
14062 return SDValue(N,0); // Return N so it doesn't get rechecked!
14063 }
14064
14065 /// If we're narrowing or widening the result of a vector select and the final
14066 /// size is the same size as a setcc (compare) feeding the select, then try to
14067 /// apply the cast operation to the select's operands because matching vector
14068 /// sizes for a select condition and other operands should be more efficient.
matchVSelectOpSizesWithSetCC(SDNode * Cast)14069 SDValue DAGCombiner::matchVSelectOpSizesWithSetCC(SDNode *Cast) {
14070 unsigned CastOpcode = Cast->getOpcode();
14071 assert((CastOpcode == ISD::SIGN_EXTEND || CastOpcode == ISD::ZERO_EXTEND ||
14072 CastOpcode == ISD::TRUNCATE || CastOpcode == ISD::FP_EXTEND ||
14073 CastOpcode == ISD::FP_ROUND) &&
14074 "Unexpected opcode for vector select narrowing/widening");
14075
14076 // We only do this transform before legal ops because the pattern may be
14077 // obfuscated by target-specific operations after legalization. Do not create
14078 // an illegal select op, however, because that may be difficult to lower.
14079 EVT VT = Cast->getValueType(0);
14080 if (LegalOperations || !TLI.isOperationLegalOrCustom(ISD::VSELECT, VT))
14081 return SDValue();
14082
14083 SDValue VSel = Cast->getOperand(0);
14084 if (VSel.getOpcode() != ISD::VSELECT || !VSel.hasOneUse() ||
14085 VSel.getOperand(0).getOpcode() != ISD::SETCC)
14086 return SDValue();
14087
14088 // Does the setcc have the same vector size as the casted select?
14089 SDValue SetCC = VSel.getOperand(0);
14090 EVT SetCCVT = getSetCCResultType(SetCC.getOperand(0).getValueType());
14091 if (SetCCVT.getSizeInBits() != VT.getSizeInBits())
14092 return SDValue();
14093
14094 // cast (vsel (setcc X), A, B) --> vsel (setcc X), (cast A), (cast B)
14095 SDValue A = VSel.getOperand(1);
14096 SDValue B = VSel.getOperand(2);
14097 SDValue CastA, CastB;
14098 SDLoc DL(Cast);
14099 if (CastOpcode == ISD::FP_ROUND) {
14100 // FP_ROUND (fptrunc) has an extra flag operand to pass along.
14101 CastA = DAG.getNode(CastOpcode, DL, VT, A, Cast->getOperand(1));
14102 CastB = DAG.getNode(CastOpcode, DL, VT, B, Cast->getOperand(1));
14103 } else {
14104 CastA = DAG.getNode(CastOpcode, DL, VT, A);
14105 CastB = DAG.getNode(CastOpcode, DL, VT, B);
14106 }
14107 return DAG.getNode(ISD::VSELECT, DL, VT, SetCC, CastA, CastB);
14108 }
14109
14110 // fold ([s|z]ext ([s|z]extload x)) -> ([s|z]ext (truncate ([s|z]extload x)))
14111 // fold ([s|z]ext ( extload x)) -> ([s|z]ext (truncate ([s|z]extload x)))
tryToFoldExtOfExtload(SelectionDAG & DAG,DAGCombiner & Combiner,const TargetLowering & TLI,EVT VT,bool LegalOperations,SDNode * N,SDValue N0,ISD::LoadExtType ExtLoadType)14112 static SDValue tryToFoldExtOfExtload(SelectionDAG &DAG, DAGCombiner &Combiner,
14113 const TargetLowering &TLI, EVT VT,
14114 bool LegalOperations, SDNode *N,
14115 SDValue N0, ISD::LoadExtType ExtLoadType) {
14116 SDNode *N0Node = N0.getNode();
14117 bool isAExtLoad = (ExtLoadType == ISD::SEXTLOAD) ? ISD::isSEXTLoad(N0Node)
14118 : ISD::isZEXTLoad(N0Node);
14119 if ((!isAExtLoad && !ISD::isEXTLoad(N0Node)) ||
14120 !ISD::isUNINDEXEDLoad(N0Node) || !N0.hasOneUse())
14121 return SDValue();
14122
14123 LoadSDNode *LN0 = cast<LoadSDNode>(N0);
14124 EVT MemVT = LN0->getMemoryVT();
14125 if ((LegalOperations || !LN0->isSimple() ||
14126 VT.isVector()) &&
14127 !TLI.isLoadExtLegal(ExtLoadType, VT, MemVT))
14128 return SDValue();
14129
14130 SDValue ExtLoad =
14131 DAG.getExtLoad(ExtLoadType, SDLoc(LN0), VT, LN0->getChain(),
14132 LN0->getBasePtr(), MemVT, LN0->getMemOperand());
14133 Combiner.CombineTo(N, ExtLoad);
14134 DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
14135 if (LN0->use_empty())
14136 Combiner.recursivelyDeleteUnusedNodes(LN0);
14137 return SDValue(N, 0); // Return N so it doesn't get rechecked!
14138 }
14139
14140 // fold ([s|z]ext (load x)) -> ([s|z]ext (truncate ([s|z]extload x)))
14141 // Only generate vector extloads when 1) they're legal, and 2) they are
14142 // deemed desirable by the target. NonNegZExt can be set to true if a zero
14143 // extend has the nonneg flag to allow use of sextload if profitable.
tryToFoldExtOfLoad(SelectionDAG & DAG,DAGCombiner & Combiner,const TargetLowering & TLI,EVT VT,bool LegalOperations,SDNode * N,SDValue N0,ISD::LoadExtType ExtLoadType,ISD::NodeType ExtOpc,bool NonNegZExt=false)14144 static SDValue tryToFoldExtOfLoad(SelectionDAG &DAG, DAGCombiner &Combiner,
14145 const TargetLowering &TLI, EVT VT,
14146 bool LegalOperations, SDNode *N, SDValue N0,
14147 ISD::LoadExtType ExtLoadType,
14148 ISD::NodeType ExtOpc,
14149 bool NonNegZExt = false) {
14150 if (!ISD::isNON_EXTLoad(N0.getNode()) || !ISD::isUNINDEXEDLoad(N0.getNode()))
14151 return {};
14152
14153 // If this is zext nneg, see if it would make sense to treat it as a sext.
14154 if (NonNegZExt) {
14155 assert(ExtLoadType == ISD::ZEXTLOAD && ExtOpc == ISD::ZERO_EXTEND &&
14156 "Unexpected load type or opcode");
14157 for (SDNode *User : N0->users()) {
14158 if (User->getOpcode() == ISD::SETCC) {
14159 ISD::CondCode CC = cast<CondCodeSDNode>(User->getOperand(2))->get();
14160 if (ISD::isSignedIntSetCC(CC)) {
14161 ExtLoadType = ISD::SEXTLOAD;
14162 ExtOpc = ISD::SIGN_EXTEND;
14163 break;
14164 }
14165 }
14166 }
14167 }
14168
14169 // TODO: isFixedLengthVector() should be removed and any negative effects on
14170 // code generation being the result of that target's implementation of
14171 // isVectorLoadExtDesirable().
14172 if ((LegalOperations || VT.isFixedLengthVector() ||
14173 !cast<LoadSDNode>(N0)->isSimple()) &&
14174 !TLI.isLoadExtLegal(ExtLoadType, VT, N0.getValueType()))
14175 return {};
14176
14177 bool DoXform = true;
14178 SmallVector<SDNode *, 4> SetCCs;
14179 if (!N0.hasOneUse())
14180 DoXform = ExtendUsesToFormExtLoad(VT, N, N0, ExtOpc, SetCCs, TLI);
14181 if (VT.isVector())
14182 DoXform &= TLI.isVectorLoadExtDesirable(SDValue(N, 0));
14183 if (!DoXform)
14184 return {};
14185
14186 LoadSDNode *LN0 = cast<LoadSDNode>(N0);
14187 SDValue ExtLoad = DAG.getExtLoad(ExtLoadType, SDLoc(LN0), VT, LN0->getChain(),
14188 LN0->getBasePtr(), N0.getValueType(),
14189 LN0->getMemOperand());
14190 Combiner.ExtendSetCCUses(SetCCs, N0, ExtLoad, ExtOpc);
14191 // If the load value is used only by N, replace it via CombineTo N.
14192 bool NoReplaceTrunc = SDValue(LN0, 0).hasOneUse();
14193 Combiner.CombineTo(N, ExtLoad);
14194 if (NoReplaceTrunc) {
14195 DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
14196 Combiner.recursivelyDeleteUnusedNodes(LN0);
14197 } else {
14198 SDValue Trunc =
14199 DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), ExtLoad);
14200 Combiner.CombineTo(LN0, Trunc, ExtLoad.getValue(1));
14201 }
14202 return SDValue(N, 0); // Return N so it doesn't get rechecked!
14203 }
14204
14205 static SDValue
tryToFoldExtOfMaskedLoad(SelectionDAG & DAG,const TargetLowering & TLI,EVT VT,bool LegalOperations,SDNode * N,SDValue N0,ISD::LoadExtType ExtLoadType,ISD::NodeType ExtOpc)14206 tryToFoldExtOfMaskedLoad(SelectionDAG &DAG, const TargetLowering &TLI, EVT VT,
14207 bool LegalOperations, SDNode *N, SDValue N0,
14208 ISD::LoadExtType ExtLoadType, ISD::NodeType ExtOpc) {
14209 if (!N0.hasOneUse())
14210 return SDValue();
14211
14212 MaskedLoadSDNode *Ld = dyn_cast<MaskedLoadSDNode>(N0);
14213 if (!Ld || Ld->getExtensionType() != ISD::NON_EXTLOAD)
14214 return SDValue();
14215
14216 if ((LegalOperations || !cast<MaskedLoadSDNode>(N0)->isSimple()) &&
14217 !TLI.isLoadExtLegalOrCustom(ExtLoadType, VT, Ld->getValueType(0)))
14218 return SDValue();
14219
14220 if (!TLI.isVectorLoadExtDesirable(SDValue(N, 0)))
14221 return SDValue();
14222
14223 SDLoc dl(Ld);
14224 SDValue PassThru = DAG.getNode(ExtOpc, dl, VT, Ld->getPassThru());
14225 SDValue NewLoad = DAG.getMaskedLoad(
14226 VT, dl, Ld->getChain(), Ld->getBasePtr(), Ld->getOffset(), Ld->getMask(),
14227 PassThru, Ld->getMemoryVT(), Ld->getMemOperand(), Ld->getAddressingMode(),
14228 ExtLoadType, Ld->isExpandingLoad());
14229 DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), SDValue(NewLoad.getNode(), 1));
14230 return NewLoad;
14231 }
14232
14233 // fold ([s|z]ext (atomic_load)) -> ([s|z]ext (truncate ([s|z]ext atomic_load)))
tryToFoldExtOfAtomicLoad(SelectionDAG & DAG,const TargetLowering & TLI,EVT VT,SDValue N0,ISD::LoadExtType ExtLoadType)14234 static SDValue tryToFoldExtOfAtomicLoad(SelectionDAG &DAG,
14235 const TargetLowering &TLI, EVT VT,
14236 SDValue N0,
14237 ISD::LoadExtType ExtLoadType) {
14238 auto *ALoad = dyn_cast<AtomicSDNode>(N0);
14239 if (!ALoad || ALoad->getOpcode() != ISD::ATOMIC_LOAD)
14240 return {};
14241 EVT MemoryVT = ALoad->getMemoryVT();
14242 if (!TLI.isAtomicLoadExtLegal(ExtLoadType, VT, MemoryVT))
14243 return {};
14244 // Can't fold into ALoad if it is already extending differently.
14245 ISD::LoadExtType ALoadExtTy = ALoad->getExtensionType();
14246 if ((ALoadExtTy == ISD::ZEXTLOAD && ExtLoadType == ISD::SEXTLOAD) ||
14247 (ALoadExtTy == ISD::SEXTLOAD && ExtLoadType == ISD::ZEXTLOAD))
14248 return {};
14249
14250 EVT OrigVT = ALoad->getValueType(0);
14251 assert(OrigVT.getSizeInBits() < VT.getSizeInBits() && "VT should be wider.");
14252 auto *NewALoad = cast<AtomicSDNode>(DAG.getAtomicLoad(
14253 ExtLoadType, SDLoc(ALoad), MemoryVT, VT, ALoad->getChain(),
14254 ALoad->getBasePtr(), ALoad->getMemOperand()));
14255 DAG.ReplaceAllUsesOfValueWith(
14256 SDValue(ALoad, 0),
14257 DAG.getNode(ISD::TRUNCATE, SDLoc(ALoad), OrigVT, SDValue(NewALoad, 0)));
14258 // Update the chain uses.
14259 DAG.ReplaceAllUsesOfValueWith(SDValue(ALoad, 1), SDValue(NewALoad, 1));
14260 return SDValue(NewALoad, 0);
14261 }
14262
foldExtendedSignBitTest(SDNode * N,SelectionDAG & DAG,bool LegalOperations)14263 static SDValue foldExtendedSignBitTest(SDNode *N, SelectionDAG &DAG,
14264 bool LegalOperations) {
14265 assert((N->getOpcode() == ISD::SIGN_EXTEND ||
14266 N->getOpcode() == ISD::ZERO_EXTEND) && "Expected sext or zext");
14267
14268 SDValue SetCC = N->getOperand(0);
14269 if (LegalOperations || SetCC.getOpcode() != ISD::SETCC ||
14270 !SetCC.hasOneUse() || SetCC.getValueType() != MVT::i1)
14271 return SDValue();
14272
14273 SDValue X = SetCC.getOperand(0);
14274 SDValue Ones = SetCC.getOperand(1);
14275 ISD::CondCode CC = cast<CondCodeSDNode>(SetCC.getOperand(2))->get();
14276 EVT VT = N->getValueType(0);
14277 EVT XVT = X.getValueType();
14278 // setge X, C is canonicalized to setgt, so we do not need to match that
14279 // pattern. The setlt sibling is folded in SimplifySelectCC() because it does
14280 // not require the 'not' op.
14281 if (CC == ISD::SETGT && isAllOnesConstant(Ones) && VT == XVT) {
14282 // Invert and smear/shift the sign bit:
14283 // sext i1 (setgt iN X, -1) --> sra (not X), (N - 1)
14284 // zext i1 (setgt iN X, -1) --> srl (not X), (N - 1)
14285 SDLoc DL(N);
14286 unsigned ShCt = VT.getSizeInBits() - 1;
14287 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
14288 if (!TLI.shouldAvoidTransformToShift(VT, ShCt)) {
14289 SDValue NotX = DAG.getNOT(DL, X, VT);
14290 SDValue ShiftAmount = DAG.getConstant(ShCt, DL, VT);
14291 auto ShiftOpcode =
14292 N->getOpcode() == ISD::SIGN_EXTEND ? ISD::SRA : ISD::SRL;
14293 return DAG.getNode(ShiftOpcode, DL, VT, NotX, ShiftAmount);
14294 }
14295 }
14296 return SDValue();
14297 }
14298
foldSextSetcc(SDNode * N)14299 SDValue DAGCombiner::foldSextSetcc(SDNode *N) {
14300 SDValue N0 = N->getOperand(0);
14301 if (N0.getOpcode() != ISD::SETCC)
14302 return SDValue();
14303
14304 SDValue N00 = N0.getOperand(0);
14305 SDValue N01 = N0.getOperand(1);
14306 ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
14307 EVT VT = N->getValueType(0);
14308 EVT N00VT = N00.getValueType();
14309 SDLoc DL(N);
14310
14311 // Propagate fast-math-flags.
14312 SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags());
14313
14314 // On some architectures (such as SSE/NEON/etc) the SETCC result type is
14315 // the same size as the compared operands. Try to optimize sext(setcc())
14316 // if this is the case.
14317 if (VT.isVector() && !LegalOperations &&
14318 TLI.getBooleanContents(N00VT) ==
14319 TargetLowering::ZeroOrNegativeOneBooleanContent) {
14320 EVT SVT = getSetCCResultType(N00VT);
14321
14322 // If we already have the desired type, don't change it.
14323 if (SVT != N0.getValueType()) {
14324 // We know that the # elements of the results is the same as the
14325 // # elements of the compare (and the # elements of the compare result
14326 // for that matter). Check to see that they are the same size. If so,
14327 // we know that the element size of the sext'd result matches the
14328 // element size of the compare operands.
14329 if (VT.getSizeInBits() == SVT.getSizeInBits())
14330 return DAG.getSetCC(DL, VT, N00, N01, CC);
14331
14332 // If the desired elements are smaller or larger than the source
14333 // elements, we can use a matching integer vector type and then
14334 // truncate/sign extend.
14335 EVT MatchingVecType = N00VT.changeVectorElementTypeToInteger();
14336 if (SVT == MatchingVecType) {
14337 SDValue VsetCC = DAG.getSetCC(DL, MatchingVecType, N00, N01, CC);
14338 return DAG.getSExtOrTrunc(VsetCC, DL, VT);
14339 }
14340 }
14341
14342 // Try to eliminate the sext of a setcc by zexting the compare operands.
14343 if (N0.hasOneUse() && TLI.isOperationLegalOrCustom(ISD::SETCC, VT) &&
14344 !TLI.isOperationLegalOrCustom(ISD::SETCC, SVT)) {
14345 bool IsSignedCmp = ISD::isSignedIntSetCC(CC);
14346 unsigned LoadOpcode = IsSignedCmp ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
14347 unsigned ExtOpcode = IsSignedCmp ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
14348
14349 // We have an unsupported narrow vector compare op that would be legal
14350 // if extended to the destination type. See if the compare operands
14351 // can be freely extended to the destination type.
14352 auto IsFreeToExtend = [&](SDValue V) {
14353 if (isConstantOrConstantVector(V, /*NoOpaques*/ true))
14354 return true;
14355 // Match a simple, non-extended load that can be converted to a
14356 // legal {z/s}ext-load.
14357 // TODO: Allow widening of an existing {z/s}ext-load?
14358 if (!(ISD::isNON_EXTLoad(V.getNode()) &&
14359 ISD::isUNINDEXEDLoad(V.getNode()) &&
14360 cast<LoadSDNode>(V)->isSimple() &&
14361 TLI.isLoadExtLegal(LoadOpcode, VT, V.getValueType())))
14362 return false;
14363
14364 // Non-chain users of this value must either be the setcc in this
14365 // sequence or extends that can be folded into the new {z/s}ext-load.
14366 for (SDUse &Use : V->uses()) {
14367 // Skip uses of the chain and the setcc.
14368 SDNode *User = Use.getUser();
14369 if (Use.getResNo() != 0 || User == N0.getNode())
14370 continue;
14371 // Extra users must have exactly the same cast we are about to create.
14372 // TODO: This restriction could be eased if ExtendUsesToFormExtLoad()
14373 // is enhanced similarly.
14374 if (User->getOpcode() != ExtOpcode || User->getValueType(0) != VT)
14375 return false;
14376 }
14377 return true;
14378 };
14379
14380 if (IsFreeToExtend(N00) && IsFreeToExtend(N01)) {
14381 SDValue Ext0 = DAG.getNode(ExtOpcode, DL, VT, N00);
14382 SDValue Ext1 = DAG.getNode(ExtOpcode, DL, VT, N01);
14383 return DAG.getSetCC(DL, VT, Ext0, Ext1, CC);
14384 }
14385 }
14386 }
14387
14388 // sext(setcc x, y, cc) -> (select (setcc x, y, cc), T, 0)
14389 // Here, T can be 1 or -1, depending on the type of the setcc and
14390 // getBooleanContents().
14391 unsigned SetCCWidth = N0.getScalarValueSizeInBits();
14392
14393 // To determine the "true" side of the select, we need to know the high bit
14394 // of the value returned by the setcc if it evaluates to true.
14395 // If the type of the setcc is i1, then the true case of the select is just
14396 // sext(i1 1), that is, -1.
14397 // If the type of the setcc is larger (say, i8) then the value of the high
14398 // bit depends on getBooleanContents(), so ask TLI for a real "true" value
14399 // of the appropriate width.
14400 SDValue ExtTrueVal = (SetCCWidth == 1)
14401 ? DAG.getAllOnesConstant(DL, VT)
14402 : DAG.getBoolConstant(true, DL, VT, N00VT);
14403 SDValue Zero = DAG.getConstant(0, DL, VT);
14404 if (SDValue SCC = SimplifySelectCC(DL, N00, N01, ExtTrueVal, Zero, CC, true))
14405 return SCC;
14406
14407 if (!VT.isVector() && !shouldConvertSelectOfConstantsToMath(N0, VT, TLI)) {
14408 EVT SetCCVT = getSetCCResultType(N00VT);
14409 // Don't do this transform for i1 because there's a select transform
14410 // that would reverse it.
14411 // TODO: We should not do this transform at all without a target hook
14412 // because a sext is likely cheaper than a select?
14413 if (SetCCVT.getScalarSizeInBits() != 1 &&
14414 (!LegalOperations || TLI.isOperationLegal(ISD::SETCC, N00VT))) {
14415 SDValue SetCC = DAG.getSetCC(DL, SetCCVT, N00, N01, CC);
14416 return DAG.getSelect(DL, VT, SetCC, ExtTrueVal, Zero);
14417 }
14418 }
14419
14420 return SDValue();
14421 }
14422
visitSIGN_EXTEND(SDNode * N)14423 SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
14424 SDValue N0 = N->getOperand(0);
14425 EVT VT = N->getValueType(0);
14426 SDLoc DL(N);
14427
14428 if (VT.isVector())
14429 if (SDValue FoldedVOp = SimplifyVCastOp(N, DL))
14430 return FoldedVOp;
14431
14432 // sext(undef) = 0 because the top bit will all be the same.
14433 if (N0.isUndef())
14434 return DAG.getConstant(0, DL, VT);
14435
14436 if (SDValue Res = tryToFoldExtendOfConstant(N, DL, TLI, DAG, LegalTypes))
14437 return Res;
14438
14439 // fold (sext (sext x)) -> (sext x)
14440 // fold (sext (aext x)) -> (sext x)
14441 if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND)
14442 return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, N0.getOperand(0));
14443
14444 // fold (sext (aext_extend_vector_inreg x)) -> (sext_extend_vector_inreg x)
14445 // fold (sext (sext_extend_vector_inreg x)) -> (sext_extend_vector_inreg x)
14446 if (N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG ||
14447 N0.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG)
14448 return DAG.getNode(ISD::SIGN_EXTEND_VECTOR_INREG, SDLoc(N), VT,
14449 N0.getOperand(0));
14450
14451 if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG) {
14452 SDValue N00 = N0.getOperand(0);
14453 EVT ExtVT = cast<VTSDNode>(N0->getOperand(1))->getVT();
14454 if (N00.getOpcode() == ISD::TRUNCATE || TLI.isTruncateFree(N00, ExtVT)) {
14455 // fold (sext (sext_inreg x)) -> (sext (trunc x))
14456 if ((!LegalTypes || TLI.isTypeLegal(ExtVT))) {
14457 SDValue T = DAG.getNode(ISD::TRUNCATE, DL, ExtVT, N00);
14458 return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, T);
14459 }
14460
14461 // If the trunc wasn't legal, try to fold to (sext_inreg (anyext x))
14462 if (!LegalTypes || TLI.isTypeLegal(VT)) {
14463 SDValue ExtSrc = DAG.getAnyExtOrTrunc(N00, DL, VT);
14464 return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, ExtSrc,
14465 N0->getOperand(1));
14466 }
14467 }
14468 }
14469
14470 if (N0.getOpcode() == ISD::TRUNCATE) {
14471 // fold (sext (truncate (load x))) -> (sext (smaller load x))
14472 // fold (sext (truncate (srl (load x), c))) -> (sext (smaller load (x+c/n)))
14473 if (SDValue NarrowLoad = reduceLoadWidth(N0.getNode())) {
14474 SDNode *oye = N0.getOperand(0).getNode();
14475 if (NarrowLoad.getNode() != N0.getNode()) {
14476 CombineTo(N0.getNode(), NarrowLoad);
14477 // CombineTo deleted the truncate, if needed, but not what's under it.
14478 AddToWorklist(oye);
14479 }
14480 return SDValue(N, 0); // Return N so it doesn't get rechecked!
14481 }
14482
14483 // See if the value being truncated is already sign extended. If so, just
14484 // eliminate the trunc/sext pair.
14485 SDValue Op = N0.getOperand(0);
14486 unsigned OpBits = Op.getScalarValueSizeInBits();
14487 unsigned MidBits = N0.getScalarValueSizeInBits();
14488 unsigned DestBits = VT.getScalarSizeInBits();
14489
14490 if (N0->getFlags().hasNoSignedWrap() ||
14491 DAG.ComputeNumSignBits(Op) > OpBits - MidBits) {
14492 if (OpBits == DestBits) {
14493 // Op is i32, Mid is i8, and Dest is i32. If Op has more than 24 sign
14494 // bits, it is already ready.
14495 return Op;
14496 }
14497
14498 if (OpBits < DestBits) {
14499 // Op is i32, Mid is i8, and Dest is i64. If Op has more than 24 sign
14500 // bits, just sext from i32.
14501 return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Op);
14502 }
14503
14504 // Op is i64, Mid is i8, and Dest is i32. If Op has more than 56 sign
14505 // bits, just truncate to i32.
14506 SDNodeFlags Flags;
14507 Flags.setNoSignedWrap(true);
14508 Flags.setNoUnsignedWrap(N0->getFlags().hasNoUnsignedWrap());
14509 return DAG.getNode(ISD::TRUNCATE, DL, VT, Op, Flags);
14510 }
14511
14512 // fold (sext (truncate x)) -> (sextinreg x).
14513 if (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND_INREG,
14514 N0.getValueType())) {
14515 if (OpBits < DestBits)
14516 Op = DAG.getNode(ISD::ANY_EXTEND, SDLoc(N0), VT, Op);
14517 else if (OpBits > DestBits)
14518 Op = DAG.getNode(ISD::TRUNCATE, SDLoc(N0), VT, Op);
14519 return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, Op,
14520 DAG.getValueType(N0.getValueType()));
14521 }
14522 }
14523
14524 // Try to simplify (sext (load x)).
14525 if (SDValue foldedExt =
14526 tryToFoldExtOfLoad(DAG, *this, TLI, VT, LegalOperations, N, N0,
14527 ISD::SEXTLOAD, ISD::SIGN_EXTEND))
14528 return foldedExt;
14529
14530 if (SDValue foldedExt =
14531 tryToFoldExtOfMaskedLoad(DAG, TLI, VT, LegalOperations, N, N0,
14532 ISD::SEXTLOAD, ISD::SIGN_EXTEND))
14533 return foldedExt;
14534
14535 // fold (sext (load x)) to multiple smaller sextloads.
14536 // Only on illegal but splittable vectors.
14537 if (SDValue ExtLoad = CombineExtLoad(N))
14538 return ExtLoad;
14539
14540 // Try to simplify (sext (sextload x)).
14541 if (SDValue foldedExt = tryToFoldExtOfExtload(
14542 DAG, *this, TLI, VT, LegalOperations, N, N0, ISD::SEXTLOAD))
14543 return foldedExt;
14544
14545 // Try to simplify (sext (atomic_load x)).
14546 if (SDValue foldedExt =
14547 tryToFoldExtOfAtomicLoad(DAG, TLI, VT, N0, ISD::SEXTLOAD))
14548 return foldedExt;
14549
14550 // fold (sext (and/or/xor (load x), cst)) ->
14551 // (and/or/xor (sextload x), (sext cst))
14552 if (ISD::isBitwiseLogicOp(N0.getOpcode()) &&
14553 isa<LoadSDNode>(N0.getOperand(0)) &&
14554 N0.getOperand(1).getOpcode() == ISD::Constant &&
14555 (!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) {
14556 LoadSDNode *LN00 = cast<LoadSDNode>(N0.getOperand(0));
14557 EVT MemVT = LN00->getMemoryVT();
14558 if (TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, MemVT) &&
14559 LN00->getExtensionType() != ISD::ZEXTLOAD && LN00->isUnindexed()) {
14560 SmallVector<SDNode*, 4> SetCCs;
14561 bool DoXform = ExtendUsesToFormExtLoad(VT, N0.getNode(), N0.getOperand(0),
14562 ISD::SIGN_EXTEND, SetCCs, TLI);
14563 if (DoXform) {
14564 SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(LN00), VT,
14565 LN00->getChain(), LN00->getBasePtr(),
14566 LN00->getMemoryVT(),
14567 LN00->getMemOperand());
14568 APInt Mask = N0.getConstantOperandAPInt(1).sext(VT.getSizeInBits());
14569 SDValue And = DAG.getNode(N0.getOpcode(), DL, VT,
14570 ExtLoad, DAG.getConstant(Mask, DL, VT));
14571 ExtendSetCCUses(SetCCs, N0.getOperand(0), ExtLoad, ISD::SIGN_EXTEND);
14572 bool NoReplaceTruncAnd = !N0.hasOneUse();
14573 bool NoReplaceTrunc = SDValue(LN00, 0).hasOneUse();
14574 CombineTo(N, And);
14575 // If N0 has multiple uses, change other uses as well.
14576 if (NoReplaceTruncAnd) {
14577 SDValue TruncAnd =
14578 DAG.getNode(ISD::TRUNCATE, DL, N0.getValueType(), And);
14579 CombineTo(N0.getNode(), TruncAnd);
14580 }
14581 if (NoReplaceTrunc) {
14582 DAG.ReplaceAllUsesOfValueWith(SDValue(LN00, 1), ExtLoad.getValue(1));
14583 } else {
14584 SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(LN00),
14585 LN00->getValueType(0), ExtLoad);
14586 CombineTo(LN00, Trunc, ExtLoad.getValue(1));
14587 }
14588 return SDValue(N,0); // Return N so it doesn't get rechecked!
14589 }
14590 }
14591 }
14592
14593 if (SDValue V = foldExtendedSignBitTest(N, DAG, LegalOperations))
14594 return V;
14595
14596 if (SDValue V = foldSextSetcc(N))
14597 return V;
14598
14599 // fold (sext x) -> (zext x) if the sign bit is known zero.
14600 if (!TLI.isSExtCheaperThanZExt(N0.getValueType(), VT) &&
14601 (!LegalOperations || TLI.isOperationLegal(ISD::ZERO_EXTEND, VT)) &&
14602 DAG.SignBitIsZero(N0))
14603 return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0, SDNodeFlags::NonNeg);
14604
14605 if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
14606 return NewVSel;
14607
14608 // Eliminate this sign extend by doing a negation in the destination type:
14609 // sext i32 (0 - (zext i8 X to i32)) to i64 --> 0 - (zext i8 X to i64)
14610 if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
14611 isNullOrNullSplat(N0.getOperand(0)) &&
14612 N0.getOperand(1).getOpcode() == ISD::ZERO_EXTEND &&
14613 TLI.isOperationLegalOrCustom(ISD::SUB, VT)) {
14614 SDValue Zext = DAG.getZExtOrTrunc(N0.getOperand(1).getOperand(0), DL, VT);
14615 return DAG.getNegative(Zext, DL, VT);
14616 }
14617 // Eliminate this sign extend by doing a decrement in the destination type:
14618 // sext i32 ((zext i8 X to i32) + (-1)) to i64 --> (zext i8 X to i64) + (-1)
14619 if (N0.getOpcode() == ISD::ADD && N0.hasOneUse() &&
14620 isAllOnesOrAllOnesSplat(N0.getOperand(1)) &&
14621 N0.getOperand(0).getOpcode() == ISD::ZERO_EXTEND &&
14622 TLI.isOperationLegalOrCustom(ISD::ADD, VT)) {
14623 SDValue Zext = DAG.getZExtOrTrunc(N0.getOperand(0).getOperand(0), DL, VT);
14624 return DAG.getNode(ISD::ADD, DL, VT, Zext, DAG.getAllOnesConstant(DL, VT));
14625 }
14626
14627 // fold sext (not i1 X) -> add (zext i1 X), -1
14628 // TODO: This could be extended to handle bool vectors.
14629 if (N0.getValueType() == MVT::i1 && isBitwiseNot(N0) && N0.hasOneUse() &&
14630 (!LegalOperations || (TLI.isOperationLegal(ISD::ZERO_EXTEND, VT) &&
14631 TLI.isOperationLegal(ISD::ADD, VT)))) {
14632 // If we can eliminate the 'not', the sext form should be better
14633 if (SDValue NewXor = visitXOR(N0.getNode())) {
14634 // Returning N0 is a form of in-visit replacement that may have
14635 // invalidated N0.
14636 if (NewXor.getNode() == N0.getNode()) {
14637 // Return SDValue here as the xor should have already been replaced in
14638 // this sext.
14639 return SDValue();
14640 }
14641
14642 // Return a new sext with the new xor.
14643 return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, NewXor);
14644 }
14645
14646 SDValue Zext = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0));
14647 return DAG.getNode(ISD::ADD, DL, VT, Zext, DAG.getAllOnesConstant(DL, VT));
14648 }
14649
14650 if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG, DL, Level))
14651 return Res;
14652
14653 return SDValue();
14654 }
14655
14656 /// Given an extending node with a pop-count operand, if the target does not
14657 /// support a pop-count in the narrow source type but does support it in the
14658 /// destination type, widen the pop-count to the destination type.
widenCtPop(SDNode * Extend,SelectionDAG & DAG,const SDLoc & DL)14659 static SDValue widenCtPop(SDNode *Extend, SelectionDAG &DAG, const SDLoc &DL) {
14660 assert((Extend->getOpcode() == ISD::ZERO_EXTEND ||
14661 Extend->getOpcode() == ISD::ANY_EXTEND) &&
14662 "Expected extend op");
14663
14664 SDValue CtPop = Extend->getOperand(0);
14665 if (CtPop.getOpcode() != ISD::CTPOP || !CtPop.hasOneUse())
14666 return SDValue();
14667
14668 EVT VT = Extend->getValueType(0);
14669 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
14670 if (TLI.isOperationLegalOrCustom(ISD::CTPOP, CtPop.getValueType()) ||
14671 !TLI.isOperationLegalOrCustom(ISD::CTPOP, VT))
14672 return SDValue();
14673
14674 // zext (ctpop X) --> ctpop (zext X)
14675 SDValue NewZext = DAG.getZExtOrTrunc(CtPop.getOperand(0), DL, VT);
14676 return DAG.getNode(ISD::CTPOP, DL, VT, NewZext);
14677 }
14678
14679 // If we have (zext (abs X)) where X is a type that will be promoted by type
14680 // legalization, convert to (abs (sext X)). But don't extend past a legal type.
widenAbs(SDNode * Extend,SelectionDAG & DAG)14681 static SDValue widenAbs(SDNode *Extend, SelectionDAG &DAG) {
14682 assert(Extend->getOpcode() == ISD::ZERO_EXTEND && "Expected zero extend.");
14683
14684 EVT VT = Extend->getValueType(0);
14685 if (VT.isVector())
14686 return SDValue();
14687
14688 SDValue Abs = Extend->getOperand(0);
14689 if (Abs.getOpcode() != ISD::ABS || !Abs.hasOneUse())
14690 return SDValue();
14691
14692 EVT AbsVT = Abs.getValueType();
14693 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
14694 if (TLI.getTypeAction(*DAG.getContext(), AbsVT) !=
14695 TargetLowering::TypePromoteInteger)
14696 return SDValue();
14697
14698 EVT LegalVT = TLI.getTypeToTransformTo(*DAG.getContext(), AbsVT);
14699
14700 SDValue SExt =
14701 DAG.getNode(ISD::SIGN_EXTEND, SDLoc(Abs), LegalVT, Abs.getOperand(0));
14702 SDValue NewAbs = DAG.getNode(ISD::ABS, SDLoc(Abs), LegalVT, SExt);
14703 return DAG.getZExtOrTrunc(NewAbs, SDLoc(Extend), VT);
14704 }
14705
visitZERO_EXTEND(SDNode * N)14706 SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
14707 SDValue N0 = N->getOperand(0);
14708 EVT VT = N->getValueType(0);
14709 SDLoc DL(N);
14710
14711 if (VT.isVector())
14712 if (SDValue FoldedVOp = SimplifyVCastOp(N, DL))
14713 return FoldedVOp;
14714
14715 // zext(undef) = 0
14716 if (N0.isUndef())
14717 return DAG.getConstant(0, DL, VT);
14718
14719 if (SDValue Res = tryToFoldExtendOfConstant(N, DL, TLI, DAG, LegalTypes))
14720 return Res;
14721
14722 // fold (zext (zext x)) -> (zext x)
14723 // fold (zext (aext x)) -> (zext x)
14724 if (N0.getOpcode() == ISD::ZERO_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND) {
14725 SDNodeFlags Flags;
14726 if (N0.getOpcode() == ISD::ZERO_EXTEND)
14727 Flags.setNonNeg(N0->getFlags().hasNonNeg());
14728 return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0), Flags);
14729 }
14730
14731 // fold (zext (aext_extend_vector_inreg x)) -> (zext_extend_vector_inreg x)
14732 // fold (zext (zext_extend_vector_inreg x)) -> (zext_extend_vector_inreg x)
14733 if (N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG ||
14734 N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG)
14735 return DAG.getNode(ISD::ZERO_EXTEND_VECTOR_INREG, DL, VT, N0.getOperand(0));
14736
14737 // fold (zext (truncate x)) -> (zext x) or
14738 // (zext (truncate x)) -> (truncate x)
14739 // This is valid when the truncated bits of x are already zero.
14740 SDValue Op;
14741 KnownBits Known;
14742 if (isTruncateOf(DAG, N0, Op, Known)) {
14743 APInt TruncatedBits =
14744 (Op.getScalarValueSizeInBits() == N0.getScalarValueSizeInBits()) ?
14745 APInt(Op.getScalarValueSizeInBits(), 0) :
14746 APInt::getBitsSet(Op.getScalarValueSizeInBits(),
14747 N0.getScalarValueSizeInBits(),
14748 std::min(Op.getScalarValueSizeInBits(),
14749 VT.getScalarSizeInBits()));
14750 if (TruncatedBits.isSubsetOf(Known.Zero)) {
14751 SDValue ZExtOrTrunc = DAG.getZExtOrTrunc(Op, DL, VT);
14752 DAG.salvageDebugInfo(*N0.getNode());
14753
14754 return ZExtOrTrunc;
14755 }
14756 }
14757
14758 // fold (zext (truncate x)) -> (and x, mask)
14759 if (N0.getOpcode() == ISD::TRUNCATE) {
14760 // fold (zext (truncate (load x))) -> (zext (smaller load x))
14761 // fold (zext (truncate (srl (load x), c))) -> (zext (smaller load (x+c/n)))
14762 if (SDValue NarrowLoad = reduceLoadWidth(N0.getNode())) {
14763 SDNode *oye = N0.getOperand(0).getNode();
14764 if (NarrowLoad.getNode() != N0.getNode()) {
14765 CombineTo(N0.getNode(), NarrowLoad);
14766 // CombineTo deleted the truncate, if needed, but not what's under it.
14767 AddToWorklist(oye);
14768 }
14769 return SDValue(N, 0); // Return N so it doesn't get rechecked!
14770 }
14771
14772 EVT SrcVT = N0.getOperand(0).getValueType();
14773 EVT MinVT = N0.getValueType();
14774
14775 if (N->getFlags().hasNonNeg()) {
14776 SDValue Op = N0.getOperand(0);
14777 unsigned OpBits = SrcVT.getScalarSizeInBits();
14778 unsigned MidBits = MinVT.getScalarSizeInBits();
14779 unsigned DestBits = VT.getScalarSizeInBits();
14780
14781 if (N0->getFlags().hasNoSignedWrap() ||
14782 DAG.ComputeNumSignBits(Op) > OpBits - MidBits) {
14783 if (OpBits == DestBits) {
14784 // Op is i32, Mid is i8, and Dest is i32. If Op has more than 24 sign
14785 // bits, it is already ready.
14786 return Op;
14787 }
14788
14789 if (OpBits < DestBits) {
14790 // Op is i32, Mid is i8, and Dest is i64. If Op has more than 24 sign
14791 // bits, just sext from i32.
14792 // FIXME: This can probably be ZERO_EXTEND nneg?
14793 return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Op);
14794 }
14795
14796 // Op is i64, Mid is i8, and Dest is i32. If Op has more than 56 sign
14797 // bits, just truncate to i32.
14798 SDNodeFlags Flags;
14799 Flags.setNoSignedWrap(true);
14800 Flags.setNoUnsignedWrap(true);
14801 return DAG.getNode(ISD::TRUNCATE, DL, VT, Op, Flags);
14802 }
14803 }
14804
14805 // Try to mask before the extension to avoid having to generate a larger mask,
14806 // possibly over several sub-vectors.
14807 if (SrcVT.bitsLT(VT) && VT.isVector()) {
14808 if (!LegalOperations || (TLI.isOperationLegal(ISD::AND, SrcVT) &&
14809 TLI.isOperationLegal(ISD::ZERO_EXTEND, VT))) {
14810 SDValue Op = N0.getOperand(0);
14811 Op = DAG.getZeroExtendInReg(Op, DL, MinVT);
14812 AddToWorklist(Op.getNode());
14813 SDValue ZExtOrTrunc = DAG.getZExtOrTrunc(Op, DL, VT);
14814 // Transfer the debug info; the new node is equivalent to N0.
14815 DAG.transferDbgValues(N0, ZExtOrTrunc);
14816 return ZExtOrTrunc;
14817 }
14818 }
14819
14820 if (!LegalOperations || TLI.isOperationLegal(ISD::AND, VT)) {
14821 SDValue Op = DAG.getAnyExtOrTrunc(N0.getOperand(0), DL, VT);
14822 AddToWorklist(Op.getNode());
14823 SDValue And = DAG.getZeroExtendInReg(Op, DL, MinVT);
14824 // We may safely transfer the debug info describing the truncate node over
14825 // to the equivalent and operation.
14826 DAG.transferDbgValues(N0, And);
14827 return And;
14828 }
14829 }
14830
14831 // Fold (zext (and (trunc x), cst)) -> (and x, cst),
14832 // if either of the casts is not free.
14833 if (N0.getOpcode() == ISD::AND &&
14834 N0.getOperand(0).getOpcode() == ISD::TRUNCATE &&
14835 N0.getOperand(1).getOpcode() == ISD::Constant &&
14836 (!TLI.isTruncateFree(N0.getOperand(0).getOperand(0), N0.getValueType()) ||
14837 !TLI.isZExtFree(N0.getValueType(), VT))) {
14838 SDValue X = N0.getOperand(0).getOperand(0);
14839 X = DAG.getAnyExtOrTrunc(X, SDLoc(X), VT);
14840 APInt Mask = N0.getConstantOperandAPInt(1).zext(VT.getSizeInBits());
14841 return DAG.getNode(ISD::AND, DL, VT,
14842 X, DAG.getConstant(Mask, DL, VT));
14843 }
14844
14845 // Try to simplify (zext (load x)).
14846 if (SDValue foldedExt = tryToFoldExtOfLoad(
14847 DAG, *this, TLI, VT, LegalOperations, N, N0, ISD::ZEXTLOAD,
14848 ISD::ZERO_EXTEND, N->getFlags().hasNonNeg()))
14849 return foldedExt;
14850
14851 if (SDValue foldedExt =
14852 tryToFoldExtOfMaskedLoad(DAG, TLI, VT, LegalOperations, N, N0,
14853 ISD::ZEXTLOAD, ISD::ZERO_EXTEND))
14854 return foldedExt;
14855
14856 // fold (zext (load x)) to multiple smaller zextloads.
14857 // Only on illegal but splittable vectors.
14858 if (SDValue ExtLoad = CombineExtLoad(N))
14859 return ExtLoad;
14860
14861 // Try to simplify (zext (atomic_load x)).
14862 if (SDValue foldedExt =
14863 tryToFoldExtOfAtomicLoad(DAG, TLI, VT, N0, ISD::ZEXTLOAD))
14864 return foldedExt;
14865
14866 // fold (zext (and/or/xor (load x), cst)) ->
14867 // (and/or/xor (zextload x), (zext cst))
14868 // Unless (and (load x) cst) will match as a zextload already and has
14869 // additional users, or the zext is already free.
14870 if (ISD::isBitwiseLogicOp(N0.getOpcode()) && !TLI.isZExtFree(N0, VT) &&
14871 isa<LoadSDNode>(N0.getOperand(0)) &&
14872 N0.getOperand(1).getOpcode() == ISD::Constant &&
14873 (!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) {
14874 LoadSDNode *LN00 = cast<LoadSDNode>(N0.getOperand(0));
14875 EVT MemVT = LN00->getMemoryVT();
14876 if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT) &&
14877 LN00->getExtensionType() != ISD::SEXTLOAD && LN00->isUnindexed()) {
14878 bool DoXform = true;
14879 SmallVector<SDNode*, 4> SetCCs;
14880 if (!N0.hasOneUse()) {
14881 if (N0.getOpcode() == ISD::AND) {
14882 auto *AndC = cast<ConstantSDNode>(N0.getOperand(1));
14883 EVT LoadResultTy = AndC->getValueType(0);
14884 EVT ExtVT;
14885 if (isAndLoadExtLoad(AndC, LN00, LoadResultTy, ExtVT))
14886 DoXform = false;
14887 }
14888 }
14889 if (DoXform)
14890 DoXform = ExtendUsesToFormExtLoad(VT, N0.getNode(), N0.getOperand(0),
14891 ISD::ZERO_EXTEND, SetCCs, TLI);
14892 if (DoXform) {
14893 SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(LN00), VT,
14894 LN00->getChain(), LN00->getBasePtr(),
14895 LN00->getMemoryVT(),
14896 LN00->getMemOperand());
14897 APInt Mask = N0.getConstantOperandAPInt(1).zext(VT.getSizeInBits());
14898 SDValue And = DAG.getNode(N0.getOpcode(), DL, VT,
14899 ExtLoad, DAG.getConstant(Mask, DL, VT));
14900 ExtendSetCCUses(SetCCs, N0.getOperand(0), ExtLoad, ISD::ZERO_EXTEND);
14901 bool NoReplaceTruncAnd = !N0.hasOneUse();
14902 bool NoReplaceTrunc = SDValue(LN00, 0).hasOneUse();
14903 CombineTo(N, And);
14904 // If N0 has multiple uses, change other uses as well.
14905 if (NoReplaceTruncAnd) {
14906 SDValue TruncAnd =
14907 DAG.getNode(ISD::TRUNCATE, DL, N0.getValueType(), And);
14908 CombineTo(N0.getNode(), TruncAnd);
14909 }
14910 if (NoReplaceTrunc) {
14911 DAG.ReplaceAllUsesOfValueWith(SDValue(LN00, 1), ExtLoad.getValue(1));
14912 } else {
14913 SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(LN00),
14914 LN00->getValueType(0), ExtLoad);
14915 CombineTo(LN00, Trunc, ExtLoad.getValue(1));
14916 }
14917 return SDValue(N,0); // Return N so it doesn't get rechecked!
14918 }
14919 }
14920 }
14921
14922 // fold (zext (and/or/xor (shl/shr (load x), cst), cst)) ->
14923 // (and/or/xor (shl/shr (zextload x), (zext cst)), (zext cst))
14924 if (SDValue ZExtLoad = CombineZExtLogicopShiftLoad(N))
14925 return ZExtLoad;
14926
14927 // Try to simplify (zext (zextload x)).
14928 if (SDValue foldedExt = tryToFoldExtOfExtload(
14929 DAG, *this, TLI, VT, LegalOperations, N, N0, ISD::ZEXTLOAD))
14930 return foldedExt;
14931
14932 if (SDValue V = foldExtendedSignBitTest(N, DAG, LegalOperations))
14933 return V;
14934
14935 if (N0.getOpcode() == ISD::SETCC) {
14936 // Propagate fast-math-flags.
14937 SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags());
14938
14939 // Only do this before legalize for now.
14940 if (!LegalOperations && VT.isVector() &&
14941 N0.getValueType().getVectorElementType() == MVT::i1) {
14942 EVT N00VT = N0.getOperand(0).getValueType();
14943 if (getSetCCResultType(N00VT) == N0.getValueType())
14944 return SDValue();
14945
14946 // We know that the # elements of the results is the same as the #
14947 // elements of the compare (and the # elements of the compare result for
14948 // that matter). Check to see that they are the same size. If so, we know
14949 // that the element size of the sext'd result matches the element size of
14950 // the compare operands.
14951 if (VT.getSizeInBits() == N00VT.getSizeInBits()) {
14952 // zext(setcc) -> zext_in_reg(vsetcc) for vectors.
14953 SDValue VSetCC = DAG.getNode(ISD::SETCC, DL, VT, N0.getOperand(0),
14954 N0.getOperand(1), N0.getOperand(2));
14955 return DAG.getZeroExtendInReg(VSetCC, DL, N0.getValueType());
14956 }
14957
14958 // If the desired elements are smaller or larger than the source
14959 // elements we can use a matching integer vector type and then
14960 // truncate/any extend followed by zext_in_reg.
14961 EVT MatchingVectorType = N00VT.changeVectorElementTypeToInteger();
14962 SDValue VsetCC =
14963 DAG.getNode(ISD::SETCC, DL, MatchingVectorType, N0.getOperand(0),
14964 N0.getOperand(1), N0.getOperand(2));
14965 return DAG.getZeroExtendInReg(DAG.getAnyExtOrTrunc(VsetCC, DL, VT), DL,
14966 N0.getValueType());
14967 }
14968
14969 // zext(setcc x,y,cc) -> zext(select x, y, true, false, cc)
14970 EVT N0VT = N0.getValueType();
14971 EVT N00VT = N0.getOperand(0).getValueType();
14972 if (SDValue SCC = SimplifySelectCC(
14973 DL, N0.getOperand(0), N0.getOperand(1),
14974 DAG.getBoolConstant(true, DL, N0VT, N00VT),
14975 DAG.getBoolConstant(false, DL, N0VT, N00VT),
14976 cast<CondCodeSDNode>(N0.getOperand(2))->get(), true))
14977 return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, SCC);
14978 }
14979
14980 // (zext (shl (zext x), cst)) -> (shl (zext x), cst)
14981 if ((N0.getOpcode() == ISD::SHL || N0.getOpcode() == ISD::SRL) &&
14982 !TLI.isZExtFree(N0, VT)) {
14983 SDValue ShVal = N0.getOperand(0);
14984 SDValue ShAmt = N0.getOperand(1);
14985 if (auto *ShAmtC = dyn_cast<ConstantSDNode>(ShAmt)) {
14986 if (ShVal.getOpcode() == ISD::ZERO_EXTEND && N0.hasOneUse()) {
14987 if (N0.getOpcode() == ISD::SHL) {
14988 // If the original shl may be shifting out bits, do not perform this
14989 // transformation.
14990 unsigned KnownZeroBits = ShVal.getValueSizeInBits() -
14991 ShVal.getOperand(0).getValueSizeInBits();
14992 if (ShAmtC->getAPIntValue().ugt(KnownZeroBits)) {
14993 // If the shift is too large, then see if we can deduce that the
14994 // shift is safe anyway.
14995
14996 // Check if the bits being shifted out are known to be zero.
14997 KnownBits KnownShVal = DAG.computeKnownBits(ShVal);
14998 if (ShAmtC->getAPIntValue().ugt(KnownShVal.countMinLeadingZeros()))
14999 return SDValue();
15000 }
15001 }
15002
15003 // Ensure that the shift amount is wide enough for the shifted value.
15004 if (Log2_32_Ceil(VT.getSizeInBits()) > ShAmt.getValueSizeInBits())
15005 ShAmt = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, ShAmt);
15006
15007 return DAG.getNode(N0.getOpcode(), DL, VT,
15008 DAG.getNode(ISD::ZERO_EXTEND, DL, VT, ShVal), ShAmt);
15009 }
15010 }
15011 }
15012
15013 if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
15014 return NewVSel;
15015
15016 if (SDValue NewCtPop = widenCtPop(N, DAG, DL))
15017 return NewCtPop;
15018
15019 if (SDValue V = widenAbs(N, DAG))
15020 return V;
15021
15022 if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG, DL, Level))
15023 return Res;
15024
15025 // CSE zext nneg with sext if the zext is not free.
15026 if (N->getFlags().hasNonNeg() && !TLI.isZExtFree(N0.getValueType(), VT)) {
15027 SDNode *CSENode = DAG.getNodeIfExists(ISD::SIGN_EXTEND, N->getVTList(), N0);
15028 if (CSENode)
15029 return SDValue(CSENode, 0);
15030 }
15031
15032 return SDValue();
15033 }
15034
visitANY_EXTEND(SDNode * N)15035 SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) {
15036 SDValue N0 = N->getOperand(0);
15037 EVT VT = N->getValueType(0);
15038 SDLoc DL(N);
15039
15040 // aext(undef) = undef
15041 if (N0.isUndef())
15042 return DAG.getUNDEF(VT);
15043
15044 if (SDValue Res = tryToFoldExtendOfConstant(N, DL, TLI, DAG, LegalTypes))
15045 return Res;
15046
15047 // fold (aext (aext x)) -> (aext x)
15048 // fold (aext (zext x)) -> (zext x)
15049 // fold (aext (sext x)) -> (sext x)
15050 if (N0.getOpcode() == ISD::ANY_EXTEND || N0.getOpcode() == ISD::ZERO_EXTEND ||
15051 N0.getOpcode() == ISD::SIGN_EXTEND) {
15052 SDNodeFlags Flags;
15053 if (N0.getOpcode() == ISD::ZERO_EXTEND)
15054 Flags.setNonNeg(N0->getFlags().hasNonNeg());
15055 return DAG.getNode(N0.getOpcode(), DL, VT, N0.getOperand(0), Flags);
15056 }
15057
15058 // fold (aext (aext_extend_vector_inreg x)) -> (aext_extend_vector_inreg x)
15059 // fold (aext (zext_extend_vector_inreg x)) -> (zext_extend_vector_inreg x)
15060 // fold (aext (sext_extend_vector_inreg x)) -> (sext_extend_vector_inreg x)
15061 if (N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG ||
15062 N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG ||
15063 N0.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG)
15064 return DAG.getNode(N0.getOpcode(), DL, VT, N0.getOperand(0));
15065
15066 // fold (aext (truncate (load x))) -> (aext (smaller load x))
15067 // fold (aext (truncate (srl (load x), c))) -> (aext (small load (x+c/n)))
15068 if (N0.getOpcode() == ISD::TRUNCATE) {
15069 if (SDValue NarrowLoad = reduceLoadWidth(N0.getNode())) {
15070 SDNode *oye = N0.getOperand(0).getNode();
15071 if (NarrowLoad.getNode() != N0.getNode()) {
15072 CombineTo(N0.getNode(), NarrowLoad);
15073 // CombineTo deleted the truncate, if needed, but not what's under it.
15074 AddToWorklist(oye);
15075 }
15076 return SDValue(N, 0); // Return N so it doesn't get rechecked!
15077 }
15078 }
15079
15080 // fold (aext (truncate x))
15081 if (N0.getOpcode() == ISD::TRUNCATE)
15082 return DAG.getAnyExtOrTrunc(N0.getOperand(0), DL, VT);
15083
15084 // Fold (aext (and (trunc x), cst)) -> (and x, cst)
15085 // if the trunc is not free.
15086 if (N0.getOpcode() == ISD::AND &&
15087 N0.getOperand(0).getOpcode() == ISD::TRUNCATE &&
15088 N0.getOperand(1).getOpcode() == ISD::Constant &&
15089 !TLI.isTruncateFree(N0.getOperand(0).getOperand(0), N0.getValueType())) {
15090 SDValue X = DAG.getAnyExtOrTrunc(N0.getOperand(0).getOperand(0), DL, VT);
15091 SDValue Y = DAG.getNode(ISD::ANY_EXTEND, DL, VT, N0.getOperand(1));
15092 assert(isa<ConstantSDNode>(Y) && "Expected constant to be folded!");
15093 return DAG.getNode(ISD::AND, DL, VT, X, Y);
15094 }
15095
15096 // fold (aext (load x)) -> (aext (truncate (extload x)))
15097 // None of the supported targets knows how to perform load and any_ext
15098 // on vectors in one instruction, so attempt to fold to zext instead.
15099 if (VT.isVector()) {
15100 // Try to simplify (zext (load x)).
15101 if (SDValue foldedExt =
15102 tryToFoldExtOfLoad(DAG, *this, TLI, VT, LegalOperations, N, N0,
15103 ISD::ZEXTLOAD, ISD::ZERO_EXTEND))
15104 return foldedExt;
15105 } else if (ISD::isNON_EXTLoad(N0.getNode()) &&
15106 ISD::isUNINDEXEDLoad(N0.getNode()) &&
15107 TLI.isLoadExtLegal(ISD::EXTLOAD, VT, N0.getValueType())) {
15108 bool DoXform = true;
15109 SmallVector<SDNode *, 4> SetCCs;
15110 if (!N0.hasOneUse())
15111 DoXform =
15112 ExtendUsesToFormExtLoad(VT, N, N0, ISD::ANY_EXTEND, SetCCs, TLI);
15113 if (DoXform) {
15114 LoadSDNode *LN0 = cast<LoadSDNode>(N0);
15115 SDValue ExtLoad = DAG.getExtLoad(ISD::EXTLOAD, DL, VT, LN0->getChain(),
15116 LN0->getBasePtr(), N0.getValueType(),
15117 LN0->getMemOperand());
15118 ExtendSetCCUses(SetCCs, N0, ExtLoad, ISD::ANY_EXTEND);
15119 // If the load value is used only by N, replace it via CombineTo N.
15120 bool NoReplaceTrunc = N0.hasOneUse();
15121 CombineTo(N, ExtLoad);
15122 if (NoReplaceTrunc) {
15123 DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
15124 recursivelyDeleteUnusedNodes(LN0);
15125 } else {
15126 SDValue Trunc =
15127 DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), ExtLoad);
15128 CombineTo(LN0, Trunc, ExtLoad.getValue(1));
15129 }
15130 return SDValue(N, 0); // Return N so it doesn't get rechecked!
15131 }
15132 }
15133
15134 // fold (aext (zextload x)) -> (aext (truncate (zextload x)))
15135 // fold (aext (sextload x)) -> (aext (truncate (sextload x)))
15136 // fold (aext ( extload x)) -> (aext (truncate (extload x)))
15137 if (N0.getOpcode() == ISD::LOAD && !ISD::isNON_EXTLoad(N0.getNode()) &&
15138 ISD::isUNINDEXEDLoad(N0.getNode()) && N0.hasOneUse()) {
15139 LoadSDNode *LN0 = cast<LoadSDNode>(N0);
15140 ISD::LoadExtType ExtType = LN0->getExtensionType();
15141 EVT MemVT = LN0->getMemoryVT();
15142 if (!LegalOperations || TLI.isLoadExtLegal(ExtType, VT, MemVT)) {
15143 SDValue ExtLoad =
15144 DAG.getExtLoad(ExtType, DL, VT, LN0->getChain(), LN0->getBasePtr(),
15145 MemVT, LN0->getMemOperand());
15146 CombineTo(N, ExtLoad);
15147 DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
15148 recursivelyDeleteUnusedNodes(LN0);
15149 return SDValue(N, 0); // Return N so it doesn't get rechecked!
15150 }
15151 }
15152
15153 if (N0.getOpcode() == ISD::SETCC) {
15154 // Propagate fast-math-flags.
15155 SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags());
15156
15157 // For vectors:
15158 // aext(setcc) -> vsetcc
15159 // aext(setcc) -> truncate(vsetcc)
15160 // aext(setcc) -> aext(vsetcc)
15161 // Only do this before legalize for now.
15162 if (VT.isVector() && !LegalOperations) {
15163 EVT N00VT = N0.getOperand(0).getValueType();
15164 if (getSetCCResultType(N00VT) == N0.getValueType())
15165 return SDValue();
15166
15167 // We know that the # elements of the results is the same as the
15168 // # elements of the compare (and the # elements of the compare result
15169 // for that matter). Check to see that they are the same size. If so,
15170 // we know that the element size of the sext'd result matches the
15171 // element size of the compare operands.
15172 if (VT.getSizeInBits() == N00VT.getSizeInBits())
15173 return DAG.getSetCC(DL, VT, N0.getOperand(0), N0.getOperand(1),
15174 cast<CondCodeSDNode>(N0.getOperand(2))->get());
15175
15176 // If the desired elements are smaller or larger than the source
15177 // elements we can use a matching integer vector type and then
15178 // truncate/any extend
15179 EVT MatchingVectorType = N00VT.changeVectorElementTypeToInteger();
15180 SDValue VsetCC = DAG.getSetCC(
15181 DL, MatchingVectorType, N0.getOperand(0), N0.getOperand(1),
15182 cast<CondCodeSDNode>(N0.getOperand(2))->get());
15183 return DAG.getAnyExtOrTrunc(VsetCC, DL, VT);
15184 }
15185
15186 // aext(setcc x,y,cc) -> select_cc x, y, 1, 0, cc
15187 if (SDValue SCC = SimplifySelectCC(
15188 DL, N0.getOperand(0), N0.getOperand(1), DAG.getConstant(1, DL, VT),
15189 DAG.getConstant(0, DL, VT),
15190 cast<CondCodeSDNode>(N0.getOperand(2))->get(), true))
15191 return SCC;
15192 }
15193
15194 if (SDValue NewCtPop = widenCtPop(N, DAG, DL))
15195 return NewCtPop;
15196
15197 if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG, DL, Level))
15198 return Res;
15199
15200 return SDValue();
15201 }
15202
visitAssertExt(SDNode * N)15203 SDValue DAGCombiner::visitAssertExt(SDNode *N) {
15204 unsigned Opcode = N->getOpcode();
15205 SDValue N0 = N->getOperand(0);
15206 SDValue N1 = N->getOperand(1);
15207 EVT AssertVT = cast<VTSDNode>(N1)->getVT();
15208
15209 // fold (assert?ext (assert?ext x, vt), vt) -> (assert?ext x, vt)
15210 if (N0.getOpcode() == Opcode &&
15211 AssertVT == cast<VTSDNode>(N0.getOperand(1))->getVT())
15212 return N0;
15213
15214 if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() &&
15215 N0.getOperand(0).getOpcode() == Opcode) {
15216 // We have an assert, truncate, assert sandwich. Make one stronger assert
15217 // by asserting on the smallest asserted type to the larger source type.
15218 // This eliminates the later assert:
15219 // assert (trunc (assert X, i8) to iN), i1 --> trunc (assert X, i1) to iN
15220 // assert (trunc (assert X, i1) to iN), i8 --> trunc (assert X, i1) to iN
15221 SDLoc DL(N);
15222 SDValue BigA = N0.getOperand(0);
15223 EVT BigA_AssertVT = cast<VTSDNode>(BigA.getOperand(1))->getVT();
15224 EVT MinAssertVT = AssertVT.bitsLT(BigA_AssertVT) ? AssertVT : BigA_AssertVT;
15225 SDValue MinAssertVTVal = DAG.getValueType(MinAssertVT);
15226 SDValue NewAssert = DAG.getNode(Opcode, DL, BigA.getValueType(),
15227 BigA.getOperand(0), MinAssertVTVal);
15228 return DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), NewAssert);
15229 }
15230
15231 // If we have (AssertZext (truncate (AssertSext X, iX)), iY) and Y is smaller
15232 // than X. Just move the AssertZext in front of the truncate and drop the
15233 // AssertSExt.
15234 if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() &&
15235 N0.getOperand(0).getOpcode() == ISD::AssertSext &&
15236 Opcode == ISD::AssertZext) {
15237 SDValue BigA = N0.getOperand(0);
15238 EVT BigA_AssertVT = cast<VTSDNode>(BigA.getOperand(1))->getVT();
15239 if (AssertVT.bitsLT(BigA_AssertVT)) {
15240 SDLoc DL(N);
15241 SDValue NewAssert = DAG.getNode(Opcode, DL, BigA.getValueType(),
15242 BigA.getOperand(0), N1);
15243 return DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), NewAssert);
15244 }
15245 }
15246
15247 // If we have (AssertZext (and (AssertSext X, iX), M), iY) and Y is smaller
15248 // than X, and the And doesn't change the lower iX bits, we can move the
15249 // AssertZext in front of the And and drop the AssertSext.
15250 if (Opcode == ISD::AssertZext && N0.getOpcode() == ISD::AND &&
15251 N0.hasOneUse() && N0.getOperand(0).getOpcode() == ISD::AssertSext &&
15252 isa<ConstantSDNode>(N0.getOperand(1))) {
15253 SDValue BigA = N0.getOperand(0);
15254 EVT BigA_AssertVT = cast<VTSDNode>(BigA.getOperand(1))->getVT();
15255 const APInt &Mask = N0.getConstantOperandAPInt(1);
15256 if (AssertVT.bitsLT(BigA_AssertVT) &&
15257 Mask.countr_one() >= BigA_AssertVT.getScalarSizeInBits()) {
15258 SDLoc DL(N);
15259 SDValue NewAssert =
15260 DAG.getNode(Opcode, DL, N->getValueType(0), BigA.getOperand(0), N1);
15261 return DAG.getNode(ISD::AND, DL, N->getValueType(0), NewAssert,
15262 N0.getOperand(1));
15263 }
15264 }
15265
15266 return SDValue();
15267 }
15268
visitAssertAlign(SDNode * N)15269 SDValue DAGCombiner::visitAssertAlign(SDNode *N) {
15270 SDLoc DL(N);
15271
15272 Align AL = cast<AssertAlignSDNode>(N)->getAlign();
15273 SDValue N0 = N->getOperand(0);
15274
15275 // Fold (assertalign (assertalign x, AL0), AL1) ->
15276 // (assertalign x, max(AL0, AL1))
15277 if (auto *AAN = dyn_cast<AssertAlignSDNode>(N0))
15278 return DAG.getAssertAlign(DL, N0.getOperand(0),
15279 std::max(AL, AAN->getAlign()));
15280
15281 // In rare cases, there are trivial arithmetic ops in source operands. Sink
15282 // this assert down to source operands so that those arithmetic ops could be
15283 // exposed to the DAG combining.
15284 switch (N0.getOpcode()) {
15285 default:
15286 break;
15287 case ISD::ADD:
15288 case ISD::PTRADD:
15289 case ISD::SUB: {
15290 unsigned AlignShift = Log2(AL);
15291 SDValue LHS = N0.getOperand(0);
15292 SDValue RHS = N0.getOperand(1);
15293 unsigned LHSAlignShift = DAG.computeKnownBits(LHS).countMinTrailingZeros();
15294 unsigned RHSAlignShift = DAG.computeKnownBits(RHS).countMinTrailingZeros();
15295 if (LHSAlignShift >= AlignShift || RHSAlignShift >= AlignShift) {
15296 if (LHSAlignShift < AlignShift)
15297 LHS = DAG.getAssertAlign(DL, LHS, AL);
15298 if (RHSAlignShift < AlignShift)
15299 RHS = DAG.getAssertAlign(DL, RHS, AL);
15300 return DAG.getNode(N0.getOpcode(), DL, N0.getValueType(), LHS, RHS);
15301 }
15302 break;
15303 }
15304 }
15305
15306 return SDValue();
15307 }
15308
15309 /// If the result of a load is shifted/masked/truncated to an effectively
15310 /// narrower type, try to transform the load to a narrower type and/or
15311 /// use an extending load.
reduceLoadWidth(SDNode * N)15312 SDValue DAGCombiner::reduceLoadWidth(SDNode *N) {
15313 unsigned Opc = N->getOpcode();
15314
15315 ISD::LoadExtType ExtType = ISD::NON_EXTLOAD;
15316 SDValue N0 = N->getOperand(0);
15317 EVT VT = N->getValueType(0);
15318 EVT ExtVT = VT;
15319
15320 // This transformation isn't valid for vector loads.
15321 if (VT.isVector())
15322 return SDValue();
15323
15324 // The ShAmt variable is used to indicate that we've consumed a right
15325 // shift. I.e. we want to narrow the width of the load by skipping to load the
15326 // ShAmt least significant bits.
15327 unsigned ShAmt = 0;
15328 // A special case is when the least significant bits from the load are masked
15329 // away, but using an AND rather than a right shift. HasShiftedOffset is used
15330 // to indicate that the narrowed load should be left-shifted ShAmt bits to get
15331 // the result.
15332 unsigned ShiftedOffset = 0;
15333 // Special case: SIGN_EXTEND_INREG is basically truncating to ExtVT then
15334 // extended to VT.
15335 if (Opc == ISD::SIGN_EXTEND_INREG) {
15336 ExtType = ISD::SEXTLOAD;
15337 ExtVT = cast<VTSDNode>(N->getOperand(1))->getVT();
15338 } else if (Opc == ISD::SRL || Opc == ISD::SRA) {
15339 // Another special-case: SRL/SRA is basically zero/sign-extending a narrower
15340 // value, or it may be shifting a higher subword, half or byte into the
15341 // lowest bits.
15342
15343 // Only handle shift with constant shift amount, and the shiftee must be a
15344 // load.
15345 auto *LN = dyn_cast<LoadSDNode>(N0);
15346 auto *N1C = dyn_cast<ConstantSDNode>(N->getOperand(1));
15347 if (!N1C || !LN)
15348 return SDValue();
15349 // If the shift amount is larger than the memory type then we're not
15350 // accessing any of the loaded bytes.
15351 ShAmt = N1C->getZExtValue();
15352 uint64_t MemoryWidth = LN->getMemoryVT().getScalarSizeInBits();
15353 if (MemoryWidth <= ShAmt)
15354 return SDValue();
15355 // Attempt to fold away the SRL by using ZEXTLOAD and SRA by using SEXTLOAD.
15356 ExtType = Opc == ISD::SRL ? ISD::ZEXTLOAD : ISD::SEXTLOAD;
15357 ExtVT = EVT::getIntegerVT(*DAG.getContext(), MemoryWidth - ShAmt);
15358 // If original load is a SEXTLOAD then we can't simply replace it by a
15359 // ZEXTLOAD (we could potentially replace it by a more narrow SEXTLOAD
15360 // followed by a ZEXT, but that is not handled at the moment). Similarly if
15361 // the original load is a ZEXTLOAD and we want to use a SEXTLOAD.
15362 if ((LN->getExtensionType() == ISD::SEXTLOAD ||
15363 LN->getExtensionType() == ISD::ZEXTLOAD) &&
15364 LN->getExtensionType() != ExtType)
15365 return SDValue();
15366 } else if (Opc == ISD::AND) {
15367 // An AND with a constant mask is the same as a truncate + zero-extend.
15368 auto AndC = dyn_cast<ConstantSDNode>(N->getOperand(1));
15369 if (!AndC)
15370 return SDValue();
15371
15372 const APInt &Mask = AndC->getAPIntValue();
15373 unsigned ActiveBits = 0;
15374 if (Mask.isMask()) {
15375 ActiveBits = Mask.countr_one();
15376 } else if (Mask.isShiftedMask(ShAmt, ActiveBits)) {
15377 ShiftedOffset = ShAmt;
15378 } else {
15379 return SDValue();
15380 }
15381
15382 ExtType = ISD::ZEXTLOAD;
15383 ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
15384 }
15385
15386 // In case Opc==SRL we've already prepared ExtVT/ExtType/ShAmt based on doing
15387 // a right shift. Here we redo some of those checks, to possibly adjust the
15388 // ExtVT even further based on "a masking AND". We could also end up here for
15389 // other reasons (e.g. based on Opc==TRUNCATE) and that is why some checks
15390 // need to be done here as well.
15391 if (Opc == ISD::SRL || N0.getOpcode() == ISD::SRL) {
15392 SDValue SRL = Opc == ISD::SRL ? SDValue(N, 0) : N0;
15393 // Bail out when the SRL has more than one use. This is done for historical
15394 // (undocumented) reasons. Maybe intent was to guard the AND-masking below
15395 // check below? And maybe it could be non-profitable to do the transform in
15396 // case the SRL has multiple uses and we get here with Opc!=ISD::SRL?
15397 // FIXME: Can't we just skip this check for the Opc==ISD::SRL case.
15398 if (!SRL.hasOneUse())
15399 return SDValue();
15400
15401 // Only handle shift with constant shift amount, and the shiftee must be a
15402 // load.
15403 auto *LN = dyn_cast<LoadSDNode>(SRL.getOperand(0));
15404 auto *SRL1C = dyn_cast<ConstantSDNode>(SRL.getOperand(1));
15405 if (!SRL1C || !LN)
15406 return SDValue();
15407
15408 // If the shift amount is larger than the input type then we're not
15409 // accessing any of the loaded bytes. If the load was a zextload/extload
15410 // then the result of the shift+trunc is zero/undef (handled elsewhere).
15411 ShAmt = SRL1C->getZExtValue();
15412 uint64_t MemoryWidth = LN->getMemoryVT().getSizeInBits();
15413 if (ShAmt >= MemoryWidth)
15414 return SDValue();
15415
15416 // Because a SRL must be assumed to *need* to zero-extend the high bits
15417 // (as opposed to anyext the high bits), we can't combine the zextload
15418 // lowering of SRL and an sextload.
15419 if (LN->getExtensionType() == ISD::SEXTLOAD)
15420 return SDValue();
15421
15422 // Avoid reading outside the memory accessed by the original load (could
15423 // happened if we only adjust the load base pointer by ShAmt). Instead we
15424 // try to narrow the load even further. The typical scenario here is:
15425 // (i64 (truncate (i96 (srl (load x), 64)))) ->
15426 // (i64 (truncate (i96 (zextload (load i32 + offset) from i32))))
15427 if (ExtVT.getScalarSizeInBits() > MemoryWidth - ShAmt) {
15428 // Don't replace sextload by zextload.
15429 if (ExtType == ISD::SEXTLOAD)
15430 return SDValue();
15431 // Narrow the load.
15432 ExtType = ISD::ZEXTLOAD;
15433 ExtVT = EVT::getIntegerVT(*DAG.getContext(), MemoryWidth - ShAmt);
15434 }
15435
15436 // If the SRL is only used by a masking AND, we may be able to adjust
15437 // the ExtVT to make the AND redundant.
15438 SDNode *Mask = *(SRL->user_begin());
15439 if (SRL.hasOneUse() && Mask->getOpcode() == ISD::AND &&
15440 isa<ConstantSDNode>(Mask->getOperand(1))) {
15441 unsigned Offset, ActiveBits;
15442 const APInt& ShiftMask = Mask->getConstantOperandAPInt(1);
15443 if (ShiftMask.isMask()) {
15444 EVT MaskedVT =
15445 EVT::getIntegerVT(*DAG.getContext(), ShiftMask.countr_one());
15446 // If the mask is smaller, recompute the type.
15447 if ((ExtVT.getScalarSizeInBits() > MaskedVT.getScalarSizeInBits()) &&
15448 TLI.isLoadExtLegal(ExtType, SRL.getValueType(), MaskedVT))
15449 ExtVT = MaskedVT;
15450 } else if (ExtType == ISD::ZEXTLOAD &&
15451 ShiftMask.isShiftedMask(Offset, ActiveBits) &&
15452 (Offset + ShAmt) < VT.getScalarSizeInBits()) {
15453 EVT MaskedVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
15454 // If the mask is shifted we can use a narrower load and a shl to insert
15455 // the trailing zeros.
15456 if (((Offset + ActiveBits) <= ExtVT.getScalarSizeInBits()) &&
15457 TLI.isLoadExtLegal(ExtType, SRL.getValueType(), MaskedVT)) {
15458 ExtVT = MaskedVT;
15459 ShAmt = Offset + ShAmt;
15460 ShiftedOffset = Offset;
15461 }
15462 }
15463 }
15464
15465 N0 = SRL.getOperand(0);
15466 }
15467
15468 // If the load is shifted left (and the result isn't shifted back right), we
15469 // can fold a truncate through the shift. The typical scenario is that N
15470 // points at a TRUNCATE here so the attempted fold is:
15471 // (truncate (shl (load x), c))) -> (shl (narrow load x), c)
15472 // ShLeftAmt will indicate how much a narrowed load should be shifted left.
15473 unsigned ShLeftAmt = 0;
15474 if (ShAmt == 0 && N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
15475 ExtVT == VT && TLI.isNarrowingProfitable(N, N0.getValueType(), VT)) {
15476 if (ConstantSDNode *N01 = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
15477 ShLeftAmt = N01->getZExtValue();
15478 N0 = N0.getOperand(0);
15479 }
15480 }
15481
15482 // If we haven't found a load, we can't narrow it.
15483 if (!isa<LoadSDNode>(N0))
15484 return SDValue();
15485
15486 LoadSDNode *LN0 = cast<LoadSDNode>(N0);
15487 // Reducing the width of a volatile load is illegal. For atomics, we may be
15488 // able to reduce the width provided we never widen again. (see D66309)
15489 if (!LN0->isSimple() ||
15490 !isLegalNarrowLdSt(LN0, ExtType, ExtVT, ShAmt))
15491 return SDValue();
15492
15493 auto AdjustBigEndianShift = [&](unsigned ShAmt) {
15494 unsigned LVTStoreBits =
15495 LN0->getMemoryVT().getStoreSizeInBits().getFixedValue();
15496 unsigned EVTStoreBits = ExtVT.getStoreSizeInBits().getFixedValue();
15497 return LVTStoreBits - EVTStoreBits - ShAmt;
15498 };
15499
15500 // We need to adjust the pointer to the load by ShAmt bits in order to load
15501 // the correct bytes.
15502 unsigned PtrAdjustmentInBits =
15503 DAG.getDataLayout().isBigEndian() ? AdjustBigEndianShift(ShAmt) : ShAmt;
15504
15505 uint64_t PtrOff = PtrAdjustmentInBits / 8;
15506 SDLoc DL(LN0);
15507 // The original load itself didn't wrap, so an offset within it doesn't.
15508 SDValue NewPtr =
15509 DAG.getMemBasePlusOffset(LN0->getBasePtr(), TypeSize::getFixed(PtrOff),
15510 DL, SDNodeFlags::NoUnsignedWrap);
15511 AddToWorklist(NewPtr.getNode());
15512
15513 SDValue Load;
15514 if (ExtType == ISD::NON_EXTLOAD) {
15515 const MDNode *OldRanges = LN0->getRanges();
15516 const MDNode *NewRanges = nullptr;
15517 // If LSBs are loaded and the truncated ConstantRange for the OldRanges
15518 // metadata is not the full-set for the new width then create a NewRanges
15519 // metadata for the truncated load
15520 if (ShAmt == 0 && OldRanges) {
15521 ConstantRange CR = getConstantRangeFromMetadata(*OldRanges);
15522 unsigned BitSize = VT.getScalarSizeInBits();
15523
15524 // It is possible for an 8-bit extending load with 8-bit range
15525 // metadata to be narrowed to an 8-bit load. This guard is necessary to
15526 // ensure that truncation is strictly smaller.
15527 if (CR.getBitWidth() > BitSize) {
15528 ConstantRange TruncatedCR = CR.truncate(BitSize);
15529 if (!TruncatedCR.isFullSet()) {
15530 Metadata *Bounds[2] = {
15531 ConstantAsMetadata::get(
15532 ConstantInt::get(*DAG.getContext(), TruncatedCR.getLower())),
15533 ConstantAsMetadata::get(
15534 ConstantInt::get(*DAG.getContext(), TruncatedCR.getUpper()))};
15535 NewRanges = MDNode::get(*DAG.getContext(), Bounds);
15536 }
15537 } else if (CR.getBitWidth() == BitSize)
15538 NewRanges = OldRanges;
15539 }
15540 Load = DAG.getLoad(VT, DL, LN0->getChain(), NewPtr,
15541 LN0->getPointerInfo().getWithOffset(PtrOff),
15542 LN0->getBaseAlign(), LN0->getMemOperand()->getFlags(),
15543 LN0->getAAInfo(), NewRanges);
15544 } else
15545 Load = DAG.getExtLoad(ExtType, DL, VT, LN0->getChain(), NewPtr,
15546 LN0->getPointerInfo().getWithOffset(PtrOff), ExtVT,
15547 LN0->getBaseAlign(), LN0->getMemOperand()->getFlags(),
15548 LN0->getAAInfo());
15549
15550 // Replace the old load's chain with the new load's chain.
15551 WorklistRemover DeadNodes(*this);
15552 DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), Load.getValue(1));
15553
15554 // Shift the result left, if we've swallowed a left shift.
15555 SDValue Result = Load;
15556 if (ShLeftAmt != 0) {
15557 // If the shift amount is as large as the result size (but, presumably,
15558 // no larger than the source) then the useful bits of the result are
15559 // zero; we can't simply return the shortened shift, because the result
15560 // of that operation is undefined.
15561 if (ShLeftAmt >= VT.getScalarSizeInBits())
15562 Result = DAG.getConstant(0, DL, VT);
15563 else
15564 Result = DAG.getNode(ISD::SHL, DL, VT, Result,
15565 DAG.getShiftAmountConstant(ShLeftAmt, VT, DL));
15566 }
15567
15568 if (ShiftedOffset != 0) {
15569 // We're using a shifted mask, so the load now has an offset. This means
15570 // that data has been loaded into the lower bytes than it would have been
15571 // before, so we need to shl the loaded data into the correct position in the
15572 // register.
15573 SDValue ShiftC = DAG.getConstant(ShiftedOffset, DL, VT);
15574 Result = DAG.getNode(ISD::SHL, DL, VT, Result, ShiftC);
15575 DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result);
15576 }
15577
15578 // Return the new loaded value.
15579 return Result;
15580 }
15581
visitSIGN_EXTEND_INREG(SDNode * N)15582 SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) {
15583 SDValue N0 = N->getOperand(0);
15584 SDValue N1 = N->getOperand(1);
15585 EVT VT = N->getValueType(0);
15586 EVT ExtVT = cast<VTSDNode>(N1)->getVT();
15587 unsigned VTBits = VT.getScalarSizeInBits();
15588 unsigned ExtVTBits = ExtVT.getScalarSizeInBits();
15589 SDLoc DL(N);
15590
15591 // sext_vector_inreg(undef) = 0 because the top bit will all be the same.
15592 if (N0.isUndef())
15593 return DAG.getConstant(0, DL, VT);
15594
15595 // fold (sext_in_reg c1) -> c1
15596 if (SDValue C =
15597 DAG.FoldConstantArithmetic(ISD::SIGN_EXTEND_INREG, DL, VT, {N0, N1}))
15598 return C;
15599
15600 // If the input is already sign extended, just drop the extension.
15601 if (ExtVTBits >= DAG.ComputeMaxSignificantBits(N0))
15602 return N0;
15603
15604 // fold (sext_in_reg (sext_in_reg x, VT2), VT1) -> (sext_in_reg x, minVT) pt2
15605 if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG &&
15606 ExtVT.bitsLT(cast<VTSDNode>(N0.getOperand(1))->getVT()))
15607 return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, N0.getOperand(0), N1);
15608
15609 // fold (sext_in_reg (sext x)) -> (sext x)
15610 // fold (sext_in_reg (aext x)) -> (sext x)
15611 // if x is small enough or if we know that x has more than 1 sign bit and the
15612 // sign_extend_inreg is extending from one of them.
15613 if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND) {
15614 SDValue N00 = N0.getOperand(0);
15615 unsigned N00Bits = N00.getScalarValueSizeInBits();
15616 if ((N00Bits <= ExtVTBits ||
15617 DAG.ComputeMaxSignificantBits(N00) <= ExtVTBits) &&
15618 (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND, VT)))
15619 return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, N00);
15620 }
15621
15622 // fold (sext_in_reg (*_extend_vector_inreg x)) -> (sext_vector_inreg x)
15623 // if x is small enough or if we know that x has more than 1 sign bit and the
15624 // sign_extend_inreg is extending from one of them.
15625 if (ISD::isExtVecInRegOpcode(N0.getOpcode())) {
15626 SDValue N00 = N0.getOperand(0);
15627 unsigned N00Bits = N00.getScalarValueSizeInBits();
15628 bool IsZext = N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG;
15629 if ((N00Bits == ExtVTBits ||
15630 (!IsZext && (N00Bits < ExtVTBits ||
15631 DAG.ComputeMaxSignificantBits(N00) <= ExtVTBits))) &&
15632 (!LegalOperations ||
15633 TLI.isOperationLegal(ISD::SIGN_EXTEND_VECTOR_INREG, VT)))
15634 return DAG.getNode(ISD::SIGN_EXTEND_VECTOR_INREG, DL, VT, N00);
15635 }
15636
15637 // fold (sext_in_reg (zext x)) -> (sext x)
15638 // iff we are extending the source sign bit.
15639 if (N0.getOpcode() == ISD::ZERO_EXTEND) {
15640 SDValue N00 = N0.getOperand(0);
15641 if (N00.getScalarValueSizeInBits() == ExtVTBits &&
15642 (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND, VT)))
15643 return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, N00);
15644 }
15645
15646 // fold (sext_in_reg x) -> (zext_in_reg x) if the sign bit is known zero.
15647 if (DAG.MaskedValueIsZero(N0, APInt::getOneBitSet(VTBits, ExtVTBits - 1)))
15648 return DAG.getZeroExtendInReg(N0, DL, ExtVT);
15649
15650 // fold operands of sext_in_reg based on knowledge that the top bits are not
15651 // demanded.
15652 if (SimplifyDemandedBits(SDValue(N, 0)))
15653 return SDValue(N, 0);
15654
15655 // fold (sext_in_reg (load x)) -> (smaller sextload x)
15656 // fold (sext_in_reg (srl (load x), c)) -> (smaller sextload (x+c/evtbits))
15657 if (SDValue NarrowLoad = reduceLoadWidth(N))
15658 return NarrowLoad;
15659
15660 // fold (sext_in_reg (srl X, 24), i8) -> (sra X, 24)
15661 // fold (sext_in_reg (srl X, 23), i8) -> (sra X, 23) iff possible.
15662 // We already fold "(sext_in_reg (srl X, 25), i8) -> srl X, 25" above.
15663 if (N0.getOpcode() == ISD::SRL) {
15664 if (auto *ShAmt = dyn_cast<ConstantSDNode>(N0.getOperand(1)))
15665 if (ShAmt->getAPIntValue().ule(VTBits - ExtVTBits)) {
15666 // We can turn this into an SRA iff the input to the SRL is already sign
15667 // extended enough.
15668 unsigned InSignBits = DAG.ComputeNumSignBits(N0.getOperand(0));
15669 if (((VTBits - ExtVTBits) - ShAmt->getZExtValue()) < InSignBits)
15670 return DAG.getNode(ISD::SRA, DL, VT, N0.getOperand(0),
15671 N0.getOperand(1));
15672 }
15673 }
15674
15675 // fold (sext_inreg (extload x)) -> (sextload x)
15676 // If sextload is not supported by target, we can only do the combine when
15677 // load has one use. Doing otherwise can block folding the extload with other
15678 // extends that the target does support.
15679 if (ISD::isEXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) &&
15680 ExtVT == cast<LoadSDNode>(N0)->getMemoryVT() &&
15681 ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple() &&
15682 N0.hasOneUse()) ||
15683 TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT))) {
15684 auto *LN0 = cast<LoadSDNode>(N0);
15685 SDValue ExtLoad =
15686 DAG.getExtLoad(ISD::SEXTLOAD, DL, VT, LN0->getChain(),
15687 LN0->getBasePtr(), ExtVT, LN0->getMemOperand());
15688 CombineTo(N, ExtLoad);
15689 CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
15690 AddToWorklist(ExtLoad.getNode());
15691 return SDValue(N, 0); // Return N so it doesn't get rechecked!
15692 }
15693
15694 // fold (sext_inreg (zextload x)) -> (sextload x) iff load has one use
15695 if (ISD::isZEXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) &&
15696 N0.hasOneUse() && ExtVT == cast<LoadSDNode>(N0)->getMemoryVT() &&
15697 ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple()) &&
15698 TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT))) {
15699 auto *LN0 = cast<LoadSDNode>(N0);
15700 SDValue ExtLoad =
15701 DAG.getExtLoad(ISD::SEXTLOAD, DL, VT, LN0->getChain(),
15702 LN0->getBasePtr(), ExtVT, LN0->getMemOperand());
15703 CombineTo(N, ExtLoad);
15704 CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
15705 return SDValue(N, 0); // Return N so it doesn't get rechecked!
15706 }
15707
15708 // fold (sext_inreg (masked_load x)) -> (sext_masked_load x)
15709 // ignore it if the masked load is already sign extended
15710 if (MaskedLoadSDNode *Ld = dyn_cast<MaskedLoadSDNode>(N0)) {
15711 if (ExtVT == Ld->getMemoryVT() && N0.hasOneUse() &&
15712 Ld->getExtensionType() != ISD::LoadExtType::NON_EXTLOAD &&
15713 TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT)) {
15714 SDValue ExtMaskedLoad = DAG.getMaskedLoad(
15715 VT, DL, Ld->getChain(), Ld->getBasePtr(), Ld->getOffset(),
15716 Ld->getMask(), Ld->getPassThru(), ExtVT, Ld->getMemOperand(),
15717 Ld->getAddressingMode(), ISD::SEXTLOAD, Ld->isExpandingLoad());
15718 CombineTo(N, ExtMaskedLoad);
15719 CombineTo(N0.getNode(), ExtMaskedLoad, ExtMaskedLoad.getValue(1));
15720 return SDValue(N, 0); // Return N so it doesn't get rechecked!
15721 }
15722 }
15723
15724 // fold (sext_inreg (masked_gather x)) -> (sext_masked_gather x)
15725 if (auto *GN0 = dyn_cast<MaskedGatherSDNode>(N0)) {
15726 if (SDValue(GN0, 0).hasOneUse() && ExtVT == GN0->getMemoryVT() &&
15727 TLI.isVectorLoadExtDesirable(SDValue(SDValue(GN0, 0)))) {
15728 SDValue Ops[] = {GN0->getChain(), GN0->getPassThru(), GN0->getMask(),
15729 GN0->getBasePtr(), GN0->getIndex(), GN0->getScale()};
15730
15731 SDValue ExtLoad = DAG.getMaskedGather(
15732 DAG.getVTList(VT, MVT::Other), ExtVT, DL, Ops, GN0->getMemOperand(),
15733 GN0->getIndexType(), ISD::SEXTLOAD);
15734
15735 CombineTo(N, ExtLoad);
15736 CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
15737 AddToWorklist(ExtLoad.getNode());
15738 return SDValue(N, 0); // Return N so it doesn't get rechecked!
15739 }
15740 }
15741
15742 // Form (sext_inreg (bswap >> 16)) or (sext_inreg (rotl (bswap) 16))
15743 if (ExtVTBits <= 16 && N0.getOpcode() == ISD::OR) {
15744 if (SDValue BSwap = MatchBSwapHWordLow(N0.getNode(), N0.getOperand(0),
15745 N0.getOperand(1), false))
15746 return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, BSwap, N1);
15747 }
15748
15749 // Fold (iM_signext_inreg
15750 // (extract_subvector (zext|anyext|sext iN_v to _) _)
15751 // from iN)
15752 // -> (extract_subvector (signext iN_v to iM))
15753 if (N0.getOpcode() == ISD::EXTRACT_SUBVECTOR && N0.hasOneUse() &&
15754 ISD::isExtOpcode(N0.getOperand(0).getOpcode())) {
15755 SDValue InnerExt = N0.getOperand(0);
15756 EVT InnerExtVT = InnerExt->getValueType(0);
15757 SDValue Extendee = InnerExt->getOperand(0);
15758
15759 if (ExtVTBits == Extendee.getValueType().getScalarSizeInBits() &&
15760 (!LegalOperations ||
15761 TLI.isOperationLegal(ISD::SIGN_EXTEND, InnerExtVT))) {
15762 SDValue SignExtExtendee =
15763 DAG.getNode(ISD::SIGN_EXTEND, DL, InnerExtVT, Extendee);
15764 return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, SignExtExtendee,
15765 N0.getOperand(1));
15766 }
15767 }
15768
15769 return SDValue();
15770 }
15771
foldExtendVectorInregToExtendOfSubvector(SDNode * N,const SDLoc & DL,const TargetLowering & TLI,SelectionDAG & DAG,bool LegalOperations)15772 static SDValue foldExtendVectorInregToExtendOfSubvector(
15773 SDNode *N, const SDLoc &DL, const TargetLowering &TLI, SelectionDAG &DAG,
15774 bool LegalOperations) {
15775 unsigned InregOpcode = N->getOpcode();
15776 unsigned Opcode = DAG.getOpcode_EXTEND(InregOpcode);
15777
15778 SDValue Src = N->getOperand(0);
15779 EVT VT = N->getValueType(0);
15780 EVT SrcVT = EVT::getVectorVT(*DAG.getContext(),
15781 Src.getValueType().getVectorElementType(),
15782 VT.getVectorElementCount());
15783
15784 assert(ISD::isExtVecInRegOpcode(InregOpcode) &&
15785 "Expected EXTEND_VECTOR_INREG dag node in input!");
15786
15787 // Profitability check: our operand must be an one-use CONCAT_VECTORS.
15788 // FIXME: one-use check may be overly restrictive
15789 if (!Src.hasOneUse() || Src.getOpcode() != ISD::CONCAT_VECTORS)
15790 return SDValue();
15791
15792 // Profitability check: we must be extending exactly one of it's operands.
15793 // FIXME: this is probably overly restrictive.
15794 Src = Src.getOperand(0);
15795 if (Src.getValueType() != SrcVT)
15796 return SDValue();
15797
15798 if (LegalOperations && !TLI.isOperationLegal(Opcode, VT))
15799 return SDValue();
15800
15801 return DAG.getNode(Opcode, DL, VT, Src);
15802 }
15803
visitEXTEND_VECTOR_INREG(SDNode * N)15804 SDValue DAGCombiner::visitEXTEND_VECTOR_INREG(SDNode *N) {
15805 SDValue N0 = N->getOperand(0);
15806 EVT VT = N->getValueType(0);
15807 SDLoc DL(N);
15808
15809 if (N0.isUndef()) {
15810 // aext_vector_inreg(undef) = undef because the top bits are undefined.
15811 // {s/z}ext_vector_inreg(undef) = 0 because the top bits must be the same.
15812 return N->getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG
15813 ? DAG.getUNDEF(VT)
15814 : DAG.getConstant(0, DL, VT);
15815 }
15816
15817 if (SDValue Res = tryToFoldExtendOfConstant(N, DL, TLI, DAG, LegalTypes))
15818 return Res;
15819
15820 if (SimplifyDemandedVectorElts(SDValue(N, 0)))
15821 return SDValue(N, 0);
15822
15823 if (SDValue R = foldExtendVectorInregToExtendOfSubvector(N, DL, TLI, DAG,
15824 LegalOperations))
15825 return R;
15826
15827 return SDValue();
15828 }
15829
visitTRUNCATE_USAT_U(SDNode * N)15830 SDValue DAGCombiner::visitTRUNCATE_USAT_U(SDNode *N) {
15831 EVT VT = N->getValueType(0);
15832 SDValue N0 = N->getOperand(0);
15833
15834 SDValue FPVal;
15835 if (sd_match(N0, m_FPToUI(m_Value(FPVal))) &&
15836 DAG.getTargetLoweringInfo().shouldConvertFpToSat(
15837 ISD::FP_TO_UINT_SAT, FPVal.getValueType(), VT))
15838 return DAG.getNode(ISD::FP_TO_UINT_SAT, SDLoc(N0), VT, FPVal,
15839 DAG.getValueType(VT.getScalarType()));
15840
15841 return SDValue();
15842 }
15843
15844 /// Detect patterns of truncation with unsigned saturation:
15845 ///
15846 /// (truncate (umin (x, unsigned_max_of_dest_type)) to dest_type).
15847 /// Return the source value x to be truncated or SDValue() if the pattern was
15848 /// not matched.
15849 ///
detectUSatUPattern(SDValue In,EVT VT)15850 static SDValue detectUSatUPattern(SDValue In, EVT VT) {
15851 unsigned NumDstBits = VT.getScalarSizeInBits();
15852 unsigned NumSrcBits = In.getScalarValueSizeInBits();
15853 // Saturation with truncation. We truncate from InVT to VT.
15854 assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
15855
15856 SDValue Min;
15857 APInt UnsignedMax = APInt::getMaxValue(NumDstBits).zext(NumSrcBits);
15858 if (sd_match(In, m_UMin(m_Value(Min), m_SpecificInt(UnsignedMax))))
15859 return Min;
15860
15861 return SDValue();
15862 }
15863
15864 /// Detect patterns of truncation with signed saturation:
15865 /// (truncate (smin (smax (x, signed_min_of_dest_type),
15866 /// signed_max_of_dest_type)) to dest_type)
15867 /// or:
15868 /// (truncate (smax (smin (x, signed_max_of_dest_type),
15869 /// signed_min_of_dest_type)) to dest_type).
15870 ///
15871 /// Return the source value to be truncated or SDValue() if the pattern was not
15872 /// matched.
detectSSatSPattern(SDValue In,EVT VT)15873 static SDValue detectSSatSPattern(SDValue In, EVT VT) {
15874 unsigned NumDstBits = VT.getScalarSizeInBits();
15875 unsigned NumSrcBits = In.getScalarValueSizeInBits();
15876 // Saturation with truncation. We truncate from InVT to VT.
15877 assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
15878
15879 SDValue Val;
15880 APInt SignedMax = APInt::getSignedMaxValue(NumDstBits).sext(NumSrcBits);
15881 APInt SignedMin = APInt::getSignedMinValue(NumDstBits).sext(NumSrcBits);
15882
15883 if (sd_match(In, m_SMin(m_SMax(m_Value(Val), m_SpecificInt(SignedMin)),
15884 m_SpecificInt(SignedMax))))
15885 return Val;
15886
15887 if (sd_match(In, m_SMax(m_SMin(m_Value(Val), m_SpecificInt(SignedMax)),
15888 m_SpecificInt(SignedMin))))
15889 return Val;
15890
15891 return SDValue();
15892 }
15893
15894 /// Detect patterns of truncation with unsigned saturation:
detectSSatUPattern(SDValue In,EVT VT,SelectionDAG & DAG,const SDLoc & DL)15895 static SDValue detectSSatUPattern(SDValue In, EVT VT, SelectionDAG &DAG,
15896 const SDLoc &DL) {
15897 unsigned NumDstBits = VT.getScalarSizeInBits();
15898 unsigned NumSrcBits = In.getScalarValueSizeInBits();
15899 // Saturation with truncation. We truncate from InVT to VT.
15900 assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
15901
15902 SDValue Val;
15903 APInt UnsignedMax = APInt::getMaxValue(NumDstBits).zext(NumSrcBits);
15904 // Min == 0, Max is unsigned max of destination type.
15905 if (sd_match(In, m_SMax(m_SMin(m_Value(Val), m_SpecificInt(UnsignedMax)),
15906 m_Zero())))
15907 return Val;
15908
15909 if (sd_match(In, m_SMin(m_SMax(m_Value(Val), m_Zero()),
15910 m_SpecificInt(UnsignedMax))))
15911 return Val;
15912
15913 if (sd_match(In, m_UMin(m_SMax(m_Value(Val), m_Zero()),
15914 m_SpecificInt(UnsignedMax))))
15915 return Val;
15916
15917 return SDValue();
15918 }
15919
foldToSaturated(SDNode * N,EVT & VT,SDValue & Src,EVT & SrcVT,SDLoc & DL,const TargetLowering & TLI,SelectionDAG & DAG)15920 static SDValue foldToSaturated(SDNode *N, EVT &VT, SDValue &Src, EVT &SrcVT,
15921 SDLoc &DL, const TargetLowering &TLI,
15922 SelectionDAG &DAG) {
15923 auto AllowedTruncateSat = [&](unsigned Opc, EVT SrcVT, EVT VT) -> bool {
15924 return (TLI.isOperationLegalOrCustom(Opc, SrcVT) &&
15925 TLI.isTypeDesirableForOp(Opc, VT));
15926 };
15927
15928 if (Src.getOpcode() == ISD::SMIN || Src.getOpcode() == ISD::SMAX) {
15929 if (AllowedTruncateSat(ISD::TRUNCATE_SSAT_S, SrcVT, VT))
15930 if (SDValue SSatVal = detectSSatSPattern(Src, VT))
15931 return DAG.getNode(ISD::TRUNCATE_SSAT_S, DL, VT, SSatVal);
15932 if (AllowedTruncateSat(ISD::TRUNCATE_SSAT_U, SrcVT, VT))
15933 if (SDValue SSatVal = detectSSatUPattern(Src, VT, DAG, DL))
15934 return DAG.getNode(ISD::TRUNCATE_SSAT_U, DL, VT, SSatVal);
15935 } else if (Src.getOpcode() == ISD::UMIN) {
15936 if (AllowedTruncateSat(ISD::TRUNCATE_SSAT_U, SrcVT, VT))
15937 if (SDValue SSatVal = detectSSatUPattern(Src, VT, DAG, DL))
15938 return DAG.getNode(ISD::TRUNCATE_SSAT_U, DL, VT, SSatVal);
15939 if (AllowedTruncateSat(ISD::TRUNCATE_USAT_U, SrcVT, VT))
15940 if (SDValue USatVal = detectUSatUPattern(Src, VT))
15941 return DAG.getNode(ISD::TRUNCATE_USAT_U, DL, VT, USatVal);
15942 }
15943
15944 return SDValue();
15945 }
15946
/// Combine an ISD::TRUNCATE node.
///
/// The folds below are attempted in order and the first one that applies
/// wins. Returns the replacement value, SDValue(N, 0) when
/// SimplifyDemandedBits reported a change, or an empty SDValue when no fold
/// applied.
SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  EVT VT = N->getValueType(0);
  EVT SrcVT = N0.getValueType();
  bool isLE = DAG.getDataLayout().isLittleEndian();
  SDLoc DL(N);

  // trunc(undef) = undef
  if (N0.isUndef())
    return DAG.getUNDEF(VT);

  // fold (truncate (truncate x)) -> (truncate x)
  if (N0.getOpcode() == ISD::TRUNCATE)
    return DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0));

  // fold saturated truncate
  if (SDValue SaturatedTR = foldToSaturated(N, VT, N0, SrcVT, DL, TLI, DAG))
    return SaturatedTR;

  // fold (truncate c1) -> c1
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::TRUNCATE, DL, VT, {N0}))
    return C;

  // fold (truncate (ext x)) -> (ext x) or (truncate x) or x
  if (N0.getOpcode() == ISD::ZERO_EXTEND ||
      N0.getOpcode() == ISD::SIGN_EXTEND ||
      N0.getOpcode() == ISD::ANY_EXTEND) {
    // if the source is smaller than the dest, we still need an extend.
    if (N0.getOperand(0).getValueType().bitsLT(VT)) {
      // Carry the nonneg flag over, but only for a zero-extend.
      SDNodeFlags Flags;
      if (N0.getOpcode() == ISD::ZERO_EXTEND)
        Flags.setNonNeg(N0->getFlags().hasNonNeg());
      return DAG.getNode(N0.getOpcode(), DL, VT, N0.getOperand(0), Flags);
    }
    // if the source is larger than the dest, then we just need the truncate.
    if (N0.getOperand(0).getValueType().bitsGT(VT))
      return DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0));
    // if the source and dest are the same type, we can drop both the extend
    // and the truncate.
    return N0.getOperand(0);
  }

  // Try to narrow a truncate-of-sext_in_reg to the destination type:
  // trunc (sign_ext_inreg X, iM) to iN --> sign_ext_inreg (trunc X to iN), iM
  if (!LegalTypes && N0.getOpcode() == ISD::SIGN_EXTEND_INREG &&
      N0.hasOneUse()) {
    SDValue X = N0.getOperand(0);
    SDValue ExtVal = N0.getOperand(1);
    EVT ExtVT = cast<VTSDNode>(ExtVal)->getVT();
    // Only valid when the extended-from type is narrower than the result, and
    // only done when the target prefers this form.
    if (ExtVT.bitsLT(VT) && TLI.preferSextInRegOfTruncate(VT, SrcVT, ExtVT)) {
      SDValue TrX = DAG.getNode(ISD::TRUNCATE, DL, VT, X);
      return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, TrX, ExtVal);
    }
  }

  // If this is anyext(trunc), don't fold it, allow ourselves to be folded.
  if (N->hasOneUse() && (N->user_begin()->getOpcode() == ISD::ANY_EXTEND))
    return SDValue();

  // Fold extract-and-trunc into a narrow extract. For example:
  //   i64 x = EXTRACT_VECTOR_ELT(v2i64 val, i32 1)
  //   i32 y = TRUNCATE(i64 x)
  //        -- becomes --
  //   v16i8 b = BITCAST (v2i64 val)
  //   i8 x = EXTRACT_VECTOR_ELT(v16i8 b, i32 8)
  //
  // Note: We only run this optimization after type legalization (which often
  // creates this pattern) and before operation legalization after which
  // we need to be more careful about the vector instructions that we generate.
  if (LegalTypes && !LegalOperations && VT.isScalarInteger() && VT != MVT::i1 &&
      N0->hasOneUse()) {
    EVT TrTy = N->getValueType(0);
    SDValue Src = N0;

    // Check for cases where we shift down an upper element before truncation.
    // EltOffset counts how many TrTy-sized pieces the shift skips.
    int EltOffset = 0;
    if (Src.getOpcode() == ISD::SRL && Src.getOperand(0)->hasOneUse()) {
      if (auto ShAmt = DAG.getValidShiftAmount(Src)) {
        if ((*ShAmt % TrTy.getSizeInBits()) == 0) {
          Src = Src.getOperand(0);
          EltOffset = *ShAmt / TrTy.getSizeInBits();
        }
      }
    }

    if (Src.getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
      EVT VecTy = Src.getOperand(0).getValueType();
      EVT ExTy = Src.getValueType();

      // Reinterpret the vector as SizeRatio times as many TrTy elements.
      auto EltCnt = VecTy.getVectorElementCount();
      unsigned SizeRatio = ExTy.getSizeInBits() / TrTy.getSizeInBits();
      auto NewEltCnt = EltCnt * SizeRatio;

      EVT NVT = EVT::getVectorVT(*DAG.getContext(), TrTy, NewEltCnt);
      assert(NVT.getSizeInBits() == VecTy.getSizeInBits() && "Invalid Size");

      SDValue EltNo = Src->getOperand(1);
      if (isa<ConstantSDNode>(EltNo) && isTypeLegal(NVT)) {
        int Elt = EltNo->getAsZExtVal();
        // Little-endian: the wanted piece is EltOffset pieces after the first
        // piece of element Elt. Big-endian: it is EltOffset pieces before the
        // last piece of element Elt.
        int Index = isLE ? (Elt * SizeRatio + EltOffset)
                         : (Elt * SizeRatio + (SizeRatio - 1) - EltOffset);
        return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, TrTy,
                           DAG.getBitcast(NVT, Src.getOperand(0)),
                           DAG.getVectorIdxConstant(Index, DL));
      }
    }
  }

  // trunc (select c, a, b) -> select c, (trunc a), (trunc b)
  if (N0.getOpcode() == ISD::SELECT && N0.hasOneUse() &&
      TLI.isTruncateFree(SrcVT, VT)) {
    if (!LegalOperations ||
        (TLI.isOperationLegal(ISD::SELECT, SrcVT) &&
         TLI.isNarrowingProfitable(N0.getNode(), SrcVT, VT))) {
      SDLoc SL(N0);
      SDValue Cond = N0.getOperand(0);
      SDValue TruncOp0 = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(1));
      SDValue TruncOp1 = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(2));
      return DAG.getNode(ISD::SELECT, DL, VT, Cond, TruncOp0, TruncOp1);
    }
  }

  // trunc (shl x, K) -> shl (trunc x), K => K < VT.getScalarSizeInBits()
  if (N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
      (!LegalOperations || TLI.isOperationLegal(ISD::SHL, VT)) &&
      TLI.isTypeDesirableForOp(ISD::SHL, VT)) {
    SDValue Amt = N0.getOperand(1);
    KnownBits Known = DAG.computeKnownBits(Amt);
    unsigned Size = VT.getScalarSizeInBits();
    // Bound the shift amount through its known bits: it may have at most
    // Log2_32(Size) active bits.
    if (Known.countMaxActiveBits() <= Log2_32(Size)) {
      EVT AmtVT = TLI.getShiftAmountTy(VT, DAG.getDataLayout());
      SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0));
      // Convert the amount to the shift-amount type used for the narrow shift.
      if (AmtVT != Amt.getValueType()) {
        Amt = DAG.getZExtOrTrunc(Amt, DL, AmtVT);
        AddToWorklist(Amt.getNode());
      }
      return DAG.getNode(ISD::SHL, DL, VT, Trunc, Amt);
    }
  }

  if (SDValue V = foldSubToUSubSat(VT, N0.getNode(), DL))
    return V;

  if (SDValue ABD = foldABSToABD(N, DL))
    return ABD;

  // Attempt to pre-truncate BUILD_VECTOR sources.
  if (N0.getOpcode() == ISD::BUILD_VECTOR && !LegalOperations &&
      N0.hasOneUse() &&
      TLI.isTruncateFree(SrcVT.getScalarType(), VT.getScalarType()) &&
      // Avoid creating illegal types if running after type legalizer.
      (!LegalTypes || TLI.isTypeLegal(VT.getScalarType()))) {
    EVT SVT = VT.getScalarType();
    SmallVector<SDValue, 8> TruncOps;
    for (const SDValue &Op : N0->op_values()) {
      SDValue TruncOp = DAG.getNode(ISD::TRUNCATE, DL, SVT, Op);
      TruncOps.push_back(TruncOp);
    }
    return DAG.getBuildVector(VT, DL, TruncOps);
  }

  // trunc (splat_vector x) -> splat_vector (trunc x)
  if (N0.getOpcode() == ISD::SPLAT_VECTOR &&
      (!LegalTypes || TLI.isTypeLegal(VT.getScalarType())) &&
      (!LegalOperations || TLI.isOperationLegal(ISD::SPLAT_VECTOR, VT))) {
    EVT SVT = VT.getScalarType();
    return DAG.getSplatVector(
        VT, DL, DAG.getNode(ISD::TRUNCATE, DL, SVT, N0->getOperand(0)));
  }

  // Fold a series of buildvector, bitcast, and truncate if possible.
  // For example fold
  //   (2xi32 trunc (bitcast ((4xi32)buildvector x, x, y, y) 2xi64)) to
  //   (2xi32 (buildvector x, y)).
  if (Level == AfterLegalizeVectorOps && VT.isVector() &&
      N0.getOpcode() == ISD::BITCAST && N0.hasOneUse() &&
      N0.getOperand(0).getOpcode() == ISD::BUILD_VECTOR &&
      N0.getOperand(0).hasOneUse()) {
    SDValue BuildVect = N0.getOperand(0);
    EVT BuildVectEltTy = BuildVect.getValueType().getVectorElementType();
    EVT TruncVecEltTy = VT.getVectorElementType();

    // Check that the element types match.
    if (BuildVectEltTy == TruncVecEltTy) {
      // Now we only need to compute the offset of the truncated elements.
      unsigned BuildVecNumElts = BuildVect.getNumOperands();
      unsigned TruncVecNumElts = VT.getVectorNumElements();
      unsigned TruncEltOffset = BuildVecNumElts / TruncVecNumElts;
      // Within each group of TruncEltOffset operands, keep the first one on
      // little-endian and the last one on big-endian.
      unsigned FirstElt = isLE ? 0 : (TruncEltOffset - 1);

      assert((BuildVecNumElts % TruncVecNumElts) == 0 &&
             "Invalid number of elements");

      SmallVector<SDValue, 8> Opnds;
      for (unsigned i = FirstElt, e = BuildVecNumElts; i < e;
           i += TruncEltOffset)
        Opnds.push_back(BuildVect.getOperand(i));

      return DAG.getBuildVector(VT, DL, Opnds);
    }
  }

  // fold (truncate (load x)) -> (smaller load x)
  // fold (truncate (srl (load x), c)) -> (smaller load (x+c/evtbits))
  if (!LegalTypes || TLI.isTypeDesirableForOp(N0.getOpcode(), VT)) {
    if (SDValue Reduced = reduceLoadWidth(N))
      return Reduced;

    // Handle the case where the truncated result is at least as wide as the
    // loaded type.
    if (N0.hasOneUse() && ISD::isUNINDEXEDLoad(N0.getNode())) {
      auto *LN0 = cast<LoadSDNode>(N0);
      if (LN0->isSimple() && LN0->getMemoryVT().bitsLE(VT)) {
        // Same extension kind and memory type, but producing VT directly.
        SDValue NewLoad = DAG.getExtLoad(
            LN0->getExtensionType(), SDLoc(LN0), VT, LN0->getChain(),
            LN0->getBasePtr(), LN0->getMemoryVT(), LN0->getMemOperand());
        // Move users of the old load's second result (its chain) to the new
        // load.
        DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), NewLoad.getValue(1));
        return NewLoad;
      }
    }
  }

  // fold (trunc (concat ... x ...)) -> (concat ..., (trunc x), ...)),
  // where ... are all 'undef'.
  if (N0.getOpcode() == ISD::CONCAT_VECTORS && !LegalTypes) {
    // VTs[i] is the truncated type for concat operand i: the result element
    // type with that operand's element count.
    SmallVector<EVT, 8> VTs;
    SDValue V;
    unsigned Idx = 0;
    unsigned NumDefs = 0;

    for (unsigned i = 0, e = N0.getNumOperands(); i != e; ++i) {
      SDValue X = N0.getOperand(i);
      if (!X.isUndef()) {
        V = X;
        Idx = i;
        NumDefs++;
      }
      // Stop if more than one members are non-undef.
      if (NumDefs > 1)
        break;

      VTs.push_back(EVT::getVectorVT(*DAG.getContext(),
                                     VT.getVectorElementType(),
                                     X.getValueType().getVectorElementCount()));
    }

    // Every operand is undef, so the truncate is undef as well.
    if (NumDefs == 0)
      return DAG.getUNDEF(VT);

    if (NumDefs == 1) {
      assert(V.getNode() && "The single defined operand is empty!");
      SmallVector<SDValue, 8> Opnds;
      for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
        if (i != Idx) {
          Opnds.push_back(DAG.getUNDEF(VTs[i]));
          continue;
        }
        SDValue NV = DAG.getNode(ISD::TRUNCATE, SDLoc(V), VTs[i], V);
        AddToWorklist(NV.getNode());
        Opnds.push_back(NV);
      }
      return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Opnds);
    }
  }

  // Fold truncate of a bitcast of a vector to an extract of the low vector
  // element.
  //
  // e.g. trunc (i64 (bitcast v2i32:x)) -> extract_vector_elt v2i32:x, idx
  if (N0.getOpcode() == ISD::BITCAST && !VT.isVector()) {
    SDValue VecSrc = N0.getOperand(0);
    EVT VecSrcVT = VecSrc.getValueType();
    if (VecSrcVT.isVector() && VecSrcVT.getScalarType() == VT &&
        (!LegalOperations ||
         TLI.isOperationLegal(ISD::EXTRACT_VECTOR_ELT, VecSrcVT))) {
      // Element 0 on little-endian, the last element on big-endian.
      unsigned Idx = isLE ? 0 : VecSrcVT.getVectorNumElements() - 1;
      return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, VecSrc,
                         DAG.getVectorIdxConstant(Idx, DL));
    }
  }

  // Simplify the operands using demanded-bits information.
  if (SimplifyDemandedBits(SDValue(N, 0)))
    return SDValue(N, 0);

  // fold (truncate (extract_subvector(ext x))) ->
  //      (extract_subvector x)
  // TODO: This can be generalized to cover cases where the truncate and extract
  // do not fully cancel each other out.
  if (!LegalTypes && N0.getOpcode() == ISD::EXTRACT_SUBVECTOR) {
    SDValue N00 = N0.getOperand(0);
    if (N00.getOpcode() == ISD::SIGN_EXTEND ||
        N00.getOpcode() == ISD::ZERO_EXTEND ||
        N00.getOpcode() == ISD::ANY_EXTEND) {
      // Only when the pre-extend element type equals the result element type.
      if (N00.getOperand(0)->getValueType(0).getVectorElementType() ==
          VT.getVectorElementType())
        return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N0->getOperand(0)), VT,
                           N00.getOperand(0), N0.getOperand(1));
    }
  }

  if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
    return NewVSel;

  // Narrow a suitable binary operation with a non-opaque constant operand by
  // moving it ahead of the truncate. This is limited to pre-legalization
  // because targets may prefer a wider type during later combines and invert
  // this transform.
  switch (N0.getOpcode()) {
  case ISD::ADD:
  case ISD::SUB:
  case ISD::MUL:
  case ISD::AND:
  case ISD::OR:
  case ISD::XOR:
    if (!LegalOperations && N0.hasOneUse() &&
        (isConstantOrConstantVector(N0.getOperand(0), true) ||
         isConstantOrConstantVector(N0.getOperand(1), true))) {
      // TODO: We already restricted this to pre-legalization, but for vectors
      // we are extra cautious to not create an unsupported operation.
      // Target-specific changes are likely needed to avoid regressions here.
      if (VT.isScalarInteger() || TLI.isOperationLegal(N0.getOpcode(), VT)) {
        SDValue NarrowL = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0));
        SDValue NarrowR = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(1));
        return DAG.getNode(N0.getOpcode(), DL, VT, NarrowL, NarrowR);
      }
    }
    break;
  case ISD::ADDE:
  case ISD::UADDO_CARRY:
    // (trunc adde(X, Y, Carry)) -> (adde trunc(X), trunc(Y), Carry)
    // (trunc uaddo_carry(X, Y, Carry)) ->
    //     (uaddo_carry trunc(X), trunc(Y), Carry)
    // When the adde's carry is not used.
    // We only do this for uaddo_carry before operation legalization.
    if (((!LegalOperations && N0.getOpcode() == ISD::UADDO_CARRY) ||
         TLI.isOperationLegal(N0.getOpcode(), VT)) &&
        N0.hasOneUse() && !N0->hasAnyUseOfValue(1)) {
      SDValue X = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0));
      SDValue Y = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(1));
      // Keep the type of the original node's second result.
      SDVTList VTs = DAG.getVTList(VT, N0->getValueType(1));
      return DAG.getNode(N0.getOpcode(), DL, VTs, X, Y, N0.getOperand(2));
    }
    break;
  case ISD::USUBSAT:
    // Truncate the USUBSAT only if LHS is a known zero-extension. It's not
    // enough to know that the upper bits are zero; we must ensure that we
    // don't introduce an extra truncate.
    if (!LegalOperations && N0.hasOneUse() &&
        N0.getOperand(0).getOpcode() == ISD::ZERO_EXTEND &&
        N0.getOperand(0).getOperand(0).getScalarValueSizeInBits() <=
            VT.getScalarSizeInBits() &&
        hasOperation(N0.getOpcode(), VT)) {
      return getTruncatedUSUBSAT(VT, SrcVT, N0.getOperand(0), N0.getOperand(1),
                                 DAG, DL);
    }
    break;
  }

  return SDValue();
}
16308
getBuildPairElt(SDNode * N,unsigned i)16309 static SDNode *getBuildPairElt(SDNode *N, unsigned i) {
16310 SDValue Elt = N->getOperand(i);
16311 if (Elt.getOpcode() != ISD::MERGE_VALUES)
16312 return Elt.getNode();
16313 return Elt.getOperand(Elt.getResNo()).getNode();
16314 }
16315
16316 /// build_pair (load, load) -> load
16317 /// if load locations are consecutive.
CombineConsecutiveLoads(SDNode * N,EVT VT)16318 SDValue DAGCombiner::CombineConsecutiveLoads(SDNode *N, EVT VT) {
16319 assert(N->getOpcode() == ISD::BUILD_PAIR);
16320
16321 auto *LD1 = dyn_cast<LoadSDNode>(getBuildPairElt(N, 0));
16322 auto *LD2 = dyn_cast<LoadSDNode>(getBuildPairElt(N, 1));
16323
16324 // A BUILD_PAIR is always having the least significant part in elt 0 and the
16325 // most significant part in elt 1. So when combining into one large load, we
16326 // need to consider the endianness.
16327 if (DAG.getDataLayout().isBigEndian())
16328 std::swap(LD1, LD2);
16329
16330 if (!LD1 || !LD2 || !ISD::isNON_EXTLoad(LD1) || !ISD::isNON_EXTLoad(LD2) ||
16331 !LD1->hasOneUse() || !LD2->hasOneUse() ||
16332 LD1->getAddressSpace() != LD2->getAddressSpace())
16333 return SDValue();
16334
16335 unsigned LD1Fast = 0;
16336 EVT LD1VT = LD1->getValueType(0);
16337 unsigned LD1Bytes = LD1VT.getStoreSize();
16338 if ((!LegalOperations || TLI.isOperationLegal(ISD::LOAD, VT)) &&
16339 DAG.areNonVolatileConsecutiveLoads(LD2, LD1, LD1Bytes, 1) &&
16340 TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT,
16341 *LD1->getMemOperand(), &LD1Fast) && LD1Fast)
16342 return DAG.getLoad(VT, SDLoc(N), LD1->getChain(), LD1->getBasePtr(),
16343 LD1->getPointerInfo(), LD1->getAlign());
16344
16345 return SDValue();
16346 }
16347
getPPCf128HiElementSelector(const SelectionDAG & DAG)16348 static unsigned getPPCf128HiElementSelector(const SelectionDAG &DAG) {
16349 // On little-endian machines, bitcasting from ppcf128 to i128 does swap the Hi
16350 // and Lo parts; on big-endian machines it doesn't.
16351 return DAG.getDataLayout().isBigEndian() ? 1 : 0;
16352 }
16353
foldBitcastedFPLogic(SDNode * N,SelectionDAG & DAG,const TargetLowering & TLI)16354 SDValue DAGCombiner::foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
16355 const TargetLowering &TLI) {
16356 // If this is not a bitcast to an FP type or if the target doesn't have
16357 // IEEE754-compliant FP logic, we're done.
16358 EVT VT = N->getValueType(0);
16359 SDValue N0 = N->getOperand(0);
16360 EVT SourceVT = N0.getValueType();
16361
16362 if (!VT.isFloatingPoint())
16363 return SDValue();
16364
16365 // TODO: Handle cases where the integer constant is a different scalar
16366 // bitwidth to the FP.
16367 if (VT.getScalarSizeInBits() != SourceVT.getScalarSizeInBits())
16368 return SDValue();
16369
16370 unsigned FPOpcode;
16371 APInt SignMask;
16372 switch (N0.getOpcode()) {
16373 case ISD::AND:
16374 FPOpcode = ISD::FABS;
16375 SignMask = ~APInt::getSignMask(SourceVT.getScalarSizeInBits());
16376 break;
16377 case ISD::XOR:
16378 FPOpcode = ISD::FNEG;
16379 SignMask = APInt::getSignMask(SourceVT.getScalarSizeInBits());
16380 break;
16381 case ISD::OR:
16382 FPOpcode = ISD::FABS;
16383 SignMask = APInt::getSignMask(SourceVT.getScalarSizeInBits());
16384 break;
16385 default:
16386 return SDValue();
16387 }
16388
16389 if (LegalOperations && !TLI.isOperationLegal(FPOpcode, VT))
16390 return SDValue();
16391
16392 // This needs to be the inverse of logic in foldSignChangeInBitcast.
16393 // FIXME: I don't think looking for bitcast intrinsically makes sense, but
16394 // removing this would require more changes.
16395 auto IsBitCastOrFree = [&TLI, FPOpcode](SDValue Op, EVT VT) {
16396 if (sd_match(Op, m_BitCast(m_SpecificVT(VT))))
16397 return true;
16398
16399 return FPOpcode == ISD::FABS ? TLI.isFAbsFree(VT) : TLI.isFNegFree(VT);
16400 };
16401
16402 // Fold (bitcast int (and (bitcast fp X to int), 0x7fff...) to fp) -> fabs X
16403 // Fold (bitcast int (xor (bitcast fp X to int), 0x8000...) to fp) -> fneg X
16404 // Fold (bitcast int (or (bitcast fp X to int), 0x8000...) to fp) ->
16405 // fneg (fabs X)
16406 SDValue LogicOp0 = N0.getOperand(0);
16407 ConstantSDNode *LogicOp1 = isConstOrConstSplat(N0.getOperand(1), true);
16408 if (LogicOp1 && LogicOp1->getAPIntValue() == SignMask &&
16409 IsBitCastOrFree(LogicOp0, VT)) {
16410 SDValue CastOp0 = DAG.getNode(ISD::BITCAST, SDLoc(N), VT, LogicOp0);
16411 SDValue FPOp = DAG.getNode(FPOpcode, SDLoc(N), VT, CastOp0);
16412 NumFPLogicOpsConv++;
16413 if (N0.getOpcode() == ISD::OR)
16414 return DAG.getNode(ISD::FNEG, SDLoc(N), VT, FPOp);
16415 return FPOp;
16416 }
16417
16418 return SDValue();
16419 }
16420
visitBITCAST(SDNode * N)16421 SDValue DAGCombiner::visitBITCAST(SDNode *N) {
16422 SDValue N0 = N->getOperand(0);
16423 EVT VT = N->getValueType(0);
16424
16425 if (N0.isUndef())
16426 return DAG.getUNDEF(VT);
16427
16428 // If the input is a BUILD_VECTOR with all constant elements, fold this now.
16429 // Only do this before legalize types, unless both types are integer and the
16430 // scalar type is legal. Only do this before legalize ops, since the target
16431 // maybe depending on the bitcast.
16432 // First check to see if this is all constant.
16433 // TODO: Support FP bitcasts after legalize types.
16434 if (VT.isVector() &&
16435 (!LegalTypes ||
16436 (!LegalOperations && VT.isInteger() && N0.getValueType().isInteger() &&
16437 TLI.isTypeLegal(VT.getVectorElementType()))) &&
16438 N0.getOpcode() == ISD::BUILD_VECTOR && N0->hasOneUse() &&
16439 cast<BuildVectorSDNode>(N0)->isConstant())
16440 return DAG.FoldConstantBuildVector(cast<BuildVectorSDNode>(N0), SDLoc(N),
16441 VT.getVectorElementType());
16442
16443 // If the input is a constant, let getNode fold it.
16444 if (isIntOrFPConstant(N0)) {
16445 // If we can't allow illegal operations, we need to check that this is just
16446 // a fp -> int or int -> conversion and that the resulting operation will
16447 // be legal.
16448 if (!LegalOperations ||
16449 (isa<ConstantSDNode>(N0) && VT.isFloatingPoint() && !VT.isVector() &&
16450 TLI.isOperationLegal(ISD::ConstantFP, VT)) ||
16451 (isa<ConstantFPSDNode>(N0) && VT.isInteger() && !VT.isVector() &&
16452 TLI.isOperationLegal(ISD::Constant, VT))) {
16453 SDValue C = DAG.getBitcast(VT, N0);
16454 if (C.getNode() != N)
16455 return C;
16456 }
16457 }
16458
16459 // (conv (conv x, t1), t2) -> (conv x, t2)
16460 if (N0.getOpcode() == ISD::BITCAST)
16461 return DAG.getBitcast(VT, N0.getOperand(0));
16462
16463 // fold (conv (logicop (conv x), (c))) -> (logicop x, (conv c))
16464 // iff the current bitwise logicop type isn't legal
16465 if (ISD::isBitwiseLogicOp(N0.getOpcode()) && VT.isInteger() &&
16466 !TLI.isTypeLegal(N0.getOperand(0).getValueType())) {
16467 auto IsFreeBitcast = [VT](SDValue V) {
16468 return (V.getOpcode() == ISD::BITCAST &&
16469 V.getOperand(0).getValueType() == VT) ||
16470 (ISD::isBuildVectorOfConstantSDNodes(V.getNode()) &&
16471 V->hasOneUse());
16472 };
16473 if (IsFreeBitcast(N0.getOperand(0)) && IsFreeBitcast(N0.getOperand(1)))
16474 return DAG.getNode(N0.getOpcode(), SDLoc(N), VT,
16475 DAG.getBitcast(VT, N0.getOperand(0)),
16476 DAG.getBitcast(VT, N0.getOperand(1)));
16477 }
16478
16479 // fold (conv (load x)) -> (load (conv*)x)
16480 // If the resultant load doesn't need a higher alignment than the original!
16481 if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
16482 // Do not remove the cast if the types differ in endian layout.
16483 TLI.hasBigEndianPartOrdering(N0.getValueType(), DAG.getDataLayout()) ==
16484 TLI.hasBigEndianPartOrdering(VT, DAG.getDataLayout()) &&
16485 // If the load is volatile, we only want to change the load type if the
16486 // resulting load is legal. Otherwise we might increase the number of
16487 // memory accesses. We don't care if the original type was legal or not
16488 // as we assume software couldn't rely on the number of accesses of an
16489 // illegal type.
16490 ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple()) ||
16491 TLI.isOperationLegal(ISD::LOAD, VT))) {
16492 LoadSDNode *LN0 = cast<LoadSDNode>(N0);
16493
16494 if (TLI.isLoadBitCastBeneficial(N0.getValueType(), VT, DAG,
16495 *LN0->getMemOperand())) {
16496 // If the range metadata type does not match the new memory
16497 // operation type, remove the range metadata.
16498 if (const MDNode *MD = LN0->getRanges()) {
16499 ConstantInt *Lower = mdconst::extract<ConstantInt>(MD->getOperand(0));
16500 if (Lower->getBitWidth() != VT.getScalarSizeInBits() ||
16501 !VT.isInteger()) {
16502 LN0->getMemOperand()->clearRanges();
16503 }
16504 }
16505 SDValue Load =
16506 DAG.getLoad(VT, SDLoc(N), LN0->getChain(), LN0->getBasePtr(),
16507 LN0->getMemOperand());
16508 DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), Load.getValue(1));
16509 return Load;
16510 }
16511 }
16512
16513 if (SDValue V = foldBitcastedFPLogic(N, DAG, TLI))
16514 return V;
16515
16516 // fold (bitconvert (fneg x)) -> (xor (bitconvert x), signbit)
16517 // fold (bitconvert (fabs x)) -> (and (bitconvert x), (not signbit))
16518 //
16519 // For ppc_fp128:
16520 // fold (bitcast (fneg x)) ->
16521 // flipbit = signbit
16522 // (xor (bitcast x) (build_pair flipbit, flipbit))
16523 //
16524 // fold (bitcast (fabs x)) ->
16525 // flipbit = (and (extract_element (bitcast x), 0), signbit)
16526 // (xor (bitcast x) (build_pair flipbit, flipbit))
16527 // This often reduces constant pool loads.
16528 if (((N0.getOpcode() == ISD::FNEG && !TLI.isFNegFree(N0.getValueType())) ||
16529 (N0.getOpcode() == ISD::FABS && !TLI.isFAbsFree(N0.getValueType()))) &&
16530 N0->hasOneUse() && VT.isInteger() && !VT.isVector() &&
16531 !N0.getValueType().isVector()) {
16532 SDValue NewConv = DAG.getBitcast(VT, N0.getOperand(0));
16533 AddToWorklist(NewConv.getNode());
16534
16535 SDLoc DL(N);
16536 if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) {
16537 assert(VT.getSizeInBits() == 128);
16538 SDValue SignBit = DAG.getConstant(
16539 APInt::getSignMask(VT.getSizeInBits() / 2), SDLoc(N0), MVT::i64);
16540 SDValue FlipBit;
16541 if (N0.getOpcode() == ISD::FNEG) {
16542 FlipBit = SignBit;
16543 AddToWorklist(FlipBit.getNode());
16544 } else {
16545 assert(N0.getOpcode() == ISD::FABS);
16546 SDValue Hi =
16547 DAG.getNode(ISD::EXTRACT_ELEMENT, SDLoc(NewConv), MVT::i64, NewConv,
16548 DAG.getIntPtrConstant(getPPCf128HiElementSelector(DAG),
16549 SDLoc(NewConv)));
16550 AddToWorklist(Hi.getNode());
16551 FlipBit = DAG.getNode(ISD::AND, SDLoc(N0), MVT::i64, Hi, SignBit);
16552 AddToWorklist(FlipBit.getNode());
16553 }
16554 SDValue FlipBits =
16555 DAG.getNode(ISD::BUILD_PAIR, SDLoc(N0), VT, FlipBit, FlipBit);
16556 AddToWorklist(FlipBits.getNode());
16557 return DAG.getNode(ISD::XOR, DL, VT, NewConv, FlipBits);
16558 }
16559 APInt SignBit = APInt::getSignMask(VT.getSizeInBits());
16560 if (N0.getOpcode() == ISD::FNEG)
16561 return DAG.getNode(ISD::XOR, DL, VT,
16562 NewConv, DAG.getConstant(SignBit, DL, VT));
16563 assert(N0.getOpcode() == ISD::FABS);
16564 return DAG.getNode(ISD::AND, DL, VT,
16565 NewConv, DAG.getConstant(~SignBit, DL, VT));
16566 }
16567
16568 // fold (bitconvert (fcopysign cst, x)) ->
16569 // (or (and (bitconvert x), sign), (and cst, (not sign)))
16570 // Note that we don't handle (copysign x, cst) because this can always be
16571 // folded to an fneg or fabs.
16572 //
16573 // For ppc_fp128:
16574 // fold (bitcast (fcopysign cst, x)) ->
16575 // flipbit = (and (extract_element
16576 // (xor (bitcast cst), (bitcast x)), 0),
16577 // signbit)
16578 // (xor (bitcast cst) (build_pair flipbit, flipbit))
16579 if (N0.getOpcode() == ISD::FCOPYSIGN && N0->hasOneUse() &&
16580 isa<ConstantFPSDNode>(N0.getOperand(0)) && VT.isInteger() &&
16581 !VT.isVector()) {
16582 unsigned OrigXWidth = N0.getOperand(1).getValueSizeInBits();
16583 EVT IntXVT = EVT::getIntegerVT(*DAG.getContext(), OrigXWidth);
16584 if (isTypeLegal(IntXVT)) {
16585 SDValue X = DAG.getBitcast(IntXVT, N0.getOperand(1));
16586 AddToWorklist(X.getNode());
16587
16588 // If X has a different width than the result/lhs, sext it or truncate it.
16589 unsigned VTWidth = VT.getSizeInBits();
16590 if (OrigXWidth < VTWidth) {
16591 X = DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, X);
16592 AddToWorklist(X.getNode());
16593 } else if (OrigXWidth > VTWidth) {
16594 // To get the sign bit in the right place, we have to shift it right
16595 // before truncating.
16596 SDLoc DL(X);
16597 X = DAG.getNode(ISD::SRL, DL,
16598 X.getValueType(), X,
16599 DAG.getConstant(OrigXWidth-VTWidth, DL,
16600 X.getValueType()));
16601 AddToWorklist(X.getNode());
16602 X = DAG.getNode(ISD::TRUNCATE, SDLoc(X), VT, X);
16603 AddToWorklist(X.getNode());
16604 }
16605
16606 if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) {
16607 APInt SignBit = APInt::getSignMask(VT.getSizeInBits() / 2);
16608 SDValue Cst = DAG.getBitcast(VT, N0.getOperand(0));
16609 AddToWorklist(Cst.getNode());
16610 SDValue X = DAG.getBitcast(VT, N0.getOperand(1));
16611 AddToWorklist(X.getNode());
16612 SDValue XorResult = DAG.getNode(ISD::XOR, SDLoc(N0), VT, Cst, X);
16613 AddToWorklist(XorResult.getNode());
16614 SDValue XorResult64 = DAG.getNode(
16615 ISD::EXTRACT_ELEMENT, SDLoc(XorResult), MVT::i64, XorResult,
16616 DAG.getIntPtrConstant(getPPCf128HiElementSelector(DAG),
16617 SDLoc(XorResult)));
16618 AddToWorklist(XorResult64.getNode());
16619 SDValue FlipBit =
16620 DAG.getNode(ISD::AND, SDLoc(XorResult64), MVT::i64, XorResult64,
16621 DAG.getConstant(SignBit, SDLoc(XorResult64), MVT::i64));
16622 AddToWorklist(FlipBit.getNode());
16623 SDValue FlipBits =
16624 DAG.getNode(ISD::BUILD_PAIR, SDLoc(N0), VT, FlipBit, FlipBit);
16625 AddToWorklist(FlipBits.getNode());
16626 return DAG.getNode(ISD::XOR, SDLoc(N), VT, Cst, FlipBits);
16627 }
16628 APInt SignBit = APInt::getSignMask(VT.getSizeInBits());
16629 X = DAG.getNode(ISD::AND, SDLoc(X), VT,
16630 X, DAG.getConstant(SignBit, SDLoc(X), VT));
16631 AddToWorklist(X.getNode());
16632
16633 SDValue Cst = DAG.getBitcast(VT, N0.getOperand(0));
16634 Cst = DAG.getNode(ISD::AND, SDLoc(Cst), VT,
16635 Cst, DAG.getConstant(~SignBit, SDLoc(Cst), VT));
16636 AddToWorklist(Cst.getNode());
16637
16638 return DAG.getNode(ISD::OR, SDLoc(N), VT, X, Cst);
16639 }
16640 }
16641
16642 // bitconvert(build_pair(ld, ld)) -> ld iff load locations are consecutive.
16643 if (N0.getOpcode() == ISD::BUILD_PAIR)
16644 if (SDValue CombineLD = CombineConsecutiveLoads(N0.getNode(), VT))
16645 return CombineLD;
16646
16647 // int_vt (bitcast (vec_vt (scalar_to_vector elt_vt:x)))
16648 // => int_vt (any_extend elt_vt:x)
16649 if (N0.getOpcode() == ISD::SCALAR_TO_VECTOR && VT.isScalarInteger()) {
16650 SDValue SrcScalar = N0.getOperand(0);
16651 if (SrcScalar.getValueType().isScalarInteger())
16652 return DAG.getNode(ISD::ANY_EXTEND, SDLoc(N), VT, SrcScalar);
16653 }
16654
16655 // Remove double bitcasts from shuffles - this is often a legacy of
16656 // XformToShuffleWithZero being used to combine bitmaskings (of
16657 // float vectors bitcast to integer vectors) into shuffles.
16658 // bitcast(shuffle(bitcast(s0),bitcast(s1))) -> shuffle(s0,s1)
16659 if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT) && VT.isVector() &&
16660 N0->getOpcode() == ISD::VECTOR_SHUFFLE && N0.hasOneUse() &&
16661 VT.getVectorNumElements() >= N0.getValueType().getVectorNumElements() &&
16662 !(VT.getVectorNumElements() % N0.getValueType().getVectorNumElements())) {
16663 ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N0);
16664
16665 // If operands are a bitcast, peek through if it casts the original VT.
16666 // If operands are a constant, just bitcast back to original VT.
16667 auto PeekThroughBitcast = [&](SDValue Op) {
16668 if (Op.getOpcode() == ISD::BITCAST &&
16669 Op.getOperand(0).getValueType() == VT)
16670 return SDValue(Op.getOperand(0));
16671 if (Op.isUndef() || isAnyConstantBuildVector(Op))
16672 return DAG.getBitcast(VT, Op);
16673 return SDValue();
16674 };
16675
16676 // FIXME: If either input vector is bitcast, try to convert the shuffle to
16677 // the result type of this bitcast. This would eliminate at least one
16678 // bitcast. See the transform in InstCombine.
16679 SDValue SV0 = PeekThroughBitcast(N0->getOperand(0));
16680 SDValue SV1 = PeekThroughBitcast(N0->getOperand(1));
16681 if (!(SV0 && SV1))
16682 return SDValue();
16683
16684 int MaskScale =
16685 VT.getVectorNumElements() / N0.getValueType().getVectorNumElements();
16686 SmallVector<int, 8> NewMask;
16687 for (int M : SVN->getMask())
16688 for (int i = 0; i != MaskScale; ++i)
16689 NewMask.push_back(M < 0 ? -1 : M * MaskScale + i);
16690
16691 SDValue LegalShuffle =
16692 TLI.buildLegalVectorShuffle(VT, SDLoc(N), SV0, SV1, NewMask, DAG);
16693 if (LegalShuffle)
16694 return LegalShuffle;
16695 }
16696
16697 return SDValue();
16698 }
16699
visitBUILD_PAIR(SDNode * N)16700 SDValue DAGCombiner::visitBUILD_PAIR(SDNode *N) {
16701 EVT VT = N->getValueType(0);
16702 return CombineConsecutiveLoads(N, VT);
16703 }
16704
visitFREEZE(SDNode * N)16705 SDValue DAGCombiner::visitFREEZE(SDNode *N) {
16706 SDValue N0 = N->getOperand(0);
16707
16708 if (DAG.isGuaranteedNotToBeUndefOrPoison(N0, /*PoisonOnly*/ false))
16709 return N0;
16710
16711 // We currently avoid folding freeze over SRA/SRL, due to the problems seen
16712 // with (freeze (assert ext)) blocking simplifications of SRA/SRL. See for
16713 // example https://reviews.llvm.org/D136529#4120959.
16714 if (N0.getOpcode() == ISD::SRA || N0.getOpcode() == ISD::SRL)
16715 return SDValue();
16716
16717 // Fold freeze(op(x, ...)) -> op(freeze(x), ...).
16718 // Try to push freeze through instructions that propagate but don't produce
16719 // poison as far as possible. If an operand of freeze follows three
16720 // conditions 1) one-use, 2) does not produce poison, and 3) has all but one
16721 // guaranteed-non-poison operands (or is a BUILD_VECTOR or similar) then push
16722 // the freeze through to the operands that are not guaranteed non-poison.
16723 // NOTE: we will strip poison-generating flags, so ignore them here.
16724 if (DAG.canCreateUndefOrPoison(N0, /*PoisonOnly*/ false,
16725 /*ConsiderFlags*/ false) ||
16726 N0->getNumValues() != 1 || !N0->hasOneUse())
16727 return SDValue();
16728
16729 // TOOD: we should always allow multiple operands, however this increases the
16730 // likelihood of infinite loops due to the ReplaceAllUsesOfValueWith call
16731 // below causing later nodes that share frozen operands to fold again and no
16732 // longer being able to confirm other operands are not poison due to recursion
16733 // depth limits on isGuaranteedNotToBeUndefOrPoison.
16734 bool AllowMultipleMaybePoisonOperands =
16735 N0.getOpcode() == ISD::SELECT_CC || N0.getOpcode() == ISD::SETCC ||
16736 N0.getOpcode() == ISD::BUILD_VECTOR ||
16737 N0.getOpcode() == ISD::BUILD_PAIR ||
16738 N0.getOpcode() == ISD::VECTOR_SHUFFLE ||
16739 N0.getOpcode() == ISD::CONCAT_VECTORS || N0.getOpcode() == ISD::FMUL;
16740
16741 // Avoid turning a BUILD_VECTOR that can be recognized as "all zeros", "all
16742 // ones" or "constant" into something that depends on FrozenUndef. We can
16743 // instead pick undef values to keep those properties, while at the same time
16744 // folding away the freeze.
16745 // If we implement a more general solution for folding away freeze(undef) in
16746 // the future, then this special handling can be removed.
16747 if (N0.getOpcode() == ISD::BUILD_VECTOR) {
16748 SDLoc DL(N0);
16749 EVT VT = N0.getValueType();
16750 if (llvm::ISD::isBuildVectorAllOnes(N0.getNode()) && VT.isInteger())
16751 return DAG.getAllOnesConstant(DL, VT);
16752 if (llvm::ISD::isBuildVectorOfConstantSDNodes(N0.getNode())) {
16753 SmallVector<SDValue, 8> NewVecC;
16754 for (const SDValue &Op : N0->op_values())
16755 NewVecC.push_back(
16756 Op.isUndef() ? DAG.getConstant(0, DL, Op.getValueType()) : Op);
16757 return DAG.getBuildVector(VT, DL, NewVecC);
16758 }
16759 }
16760
16761 SmallSet<SDValue, 8> MaybePoisonOperands;
16762 SmallVector<unsigned, 8> MaybePoisonOperandNumbers;
16763 for (auto [OpNo, Op] : enumerate(N0->ops())) {
16764 if (DAG.isGuaranteedNotToBeUndefOrPoison(Op, /*PoisonOnly*/ false,
16765 /*Depth*/ 1))
16766 continue;
16767 bool HadMaybePoisonOperands = !MaybePoisonOperands.empty();
16768 bool IsNewMaybePoisonOperand = MaybePoisonOperands.insert(Op).second;
16769 if (IsNewMaybePoisonOperand)
16770 MaybePoisonOperandNumbers.push_back(OpNo);
16771 if (!HadMaybePoisonOperands)
16772 continue;
16773 if (IsNewMaybePoisonOperand && !AllowMultipleMaybePoisonOperands) {
16774 // Multiple maybe-poison ops when not allowed - bail out.
16775 return SDValue();
16776 }
16777 }
16778 // NOTE: the whole op may be not guaranteed to not be undef or poison because
16779 // it could create undef or poison due to it's poison-generating flags.
16780 // So not finding any maybe-poison operands is fine.
16781
16782 for (unsigned OpNo : MaybePoisonOperandNumbers) {
16783 // N0 can mutate during iteration, so make sure to refetch the maybe poison
16784 // operands via the operand numbers. The typical scenario is that we have
16785 // something like this
16786 // t262: i32 = freeze t181
16787 // t150: i32 = ctlz_zero_undef t262
16788 // t184: i32 = ctlz_zero_undef t181
16789 // t268: i32 = select_cc t181, Constant:i32<0>, t184, t186, setne:ch
16790 // When freezing the t181 operand we get t262 back, and then the
16791 // ReplaceAllUsesOfValueWith call will not only replace t181 by t262, but
16792 // also recursively replace t184 by t150.
16793 SDValue MaybePoisonOperand = N->getOperand(0).getOperand(OpNo);
16794 // Don't replace every single UNDEF everywhere with frozen UNDEF, though.
16795 if (MaybePoisonOperand.isUndef())
16796 continue;
16797 // First, freeze each offending operand.
16798 SDValue FrozenMaybePoisonOperand = DAG.getFreeze(MaybePoisonOperand);
16799 // Then, change all other uses of unfrozen operand to use frozen operand.
16800 DAG.ReplaceAllUsesOfValueWith(MaybePoisonOperand, FrozenMaybePoisonOperand);
16801 if (FrozenMaybePoisonOperand.getOpcode() == ISD::FREEZE &&
16802 FrozenMaybePoisonOperand.getOperand(0) == FrozenMaybePoisonOperand) {
16803 // But, that also updated the use in the freeze we just created, thus
16804 // creating a cycle in a DAG. Let's undo that by mutating the freeze.
16805 DAG.UpdateNodeOperands(FrozenMaybePoisonOperand.getNode(),
16806 MaybePoisonOperand);
16807 }
16808
16809 // This node has been merged with another.
16810 if (N->getOpcode() == ISD::DELETED_NODE)
16811 return SDValue(N, 0);
16812 }
16813
16814 assert(N->getOpcode() != ISD::DELETED_NODE && "Node was deleted!");
16815
16816 // The whole node may have been updated, so the value we were holding
16817 // may no longer be valid. Re-fetch the operand we're `freeze`ing.
16818 N0 = N->getOperand(0);
16819
16820 // Finally, recreate the node, it's operands were updated to use
16821 // frozen operands, so we just need to use it's "original" operands.
16822 SmallVector<SDValue> Ops(N0->ops());
16823 // TODO: ISD::UNDEF and ISD::POISON should get separate handling, but best
16824 // leave for a future patch.
16825 for (SDValue &Op : Ops) {
16826 if (Op.isUndef())
16827 Op = DAG.getFreeze(Op);
16828 }
16829
16830 SDLoc DL(N0);
16831
16832 // Special case handling for ShuffleVectorSDNode nodes.
16833 if (auto *SVN = dyn_cast<ShuffleVectorSDNode>(N0))
16834 return DAG.getVectorShuffle(N0.getValueType(), DL, Ops[0], Ops[1],
16835 SVN->getMask());
16836
16837 // NOTE: this strips poison generating flags.
16838 // Folding freeze(op(x, ...)) -> op(freeze(x), ...) does not require nnan,
16839 // ninf, nsz, or fast.
16840 // However, contract, reassoc, afn, and arcp should be preserved,
16841 // as these fast-math flags do not introduce poison values.
16842 SDNodeFlags SrcFlags = N0->getFlags();
16843 SDNodeFlags SafeFlags;
16844 SafeFlags.setAllowContract(SrcFlags.hasAllowContract());
16845 SafeFlags.setAllowReassociation(SrcFlags.hasAllowReassociation());
16846 SafeFlags.setApproximateFuncs(SrcFlags.hasApproximateFuncs());
16847 SafeFlags.setAllowReciprocal(SrcFlags.hasAllowReciprocal());
16848 return DAG.getNode(N0.getOpcode(), DL, N0->getVTList(), Ops, SafeFlags);
16849 }
16850
16851 // Returns true if floating point contraction is allowed on the FMUL-SDValue
16852 // `N`
isContractableFMUL(const TargetOptions & Options,SDValue N)16853 static bool isContractableFMUL(const TargetOptions &Options, SDValue N) {
16854 assert(N.getOpcode() == ISD::FMUL);
16855
16856 return Options.AllowFPOpFusion == FPOpFusion::Fast ||
16857 N->getFlags().hasAllowContract();
16858 }
16859
16860 // Returns true if `N` can assume no infinities involved in its computation.
hasNoInfs(const TargetOptions & Options,SDValue N)16861 static bool hasNoInfs(const TargetOptions &Options, SDValue N) {
16862 return Options.NoInfsFPMath || N->getFlags().hasNoInfs();
16863 }
16864
16865 /// Try to perform FMA combining on a given FADD node.
16866 template <class MatchContextClass>
visitFADDForFMACombine(SDNode * N)16867 SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
16868 SDValue N0 = N->getOperand(0);
16869 SDValue N1 = N->getOperand(1);
16870 EVT VT = N->getValueType(0);
16871 SDLoc SL(N);
16872 MatchContextClass matcher(DAG, TLI, N);
16873 const TargetOptions &Options = DAG.getTarget().Options;
16874
16875 bool UseVP = std::is_same_v<MatchContextClass, VPMatchContext>;
16876
16877 // Floating-point multiply-add with intermediate rounding.
16878 // FIXME: Make isFMADLegal have specific behavior when using VPMatchContext.
16879 // FIXME: Add VP_FMAD opcode.
16880 bool HasFMAD = !UseVP && (LegalOperations && TLI.isFMADLegal(DAG, N));
16881
16882 // Floating-point multiply-add without intermediate rounding.
16883 bool HasFMA =
16884 (!LegalOperations || matcher.isOperationLegalOrCustom(ISD::FMA, VT)) &&
16885 TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT);
16886
16887 // No valid opcode, do not combine.
16888 if (!HasFMAD && !HasFMA)
16889 return SDValue();
16890
16891 bool AllowFusionGlobally =
16892 Options.AllowFPOpFusion == FPOpFusion::Fast || HasFMAD;
16893 // If the addition is not contractable, do not combine.
16894 if (!AllowFusionGlobally && !N->getFlags().hasAllowContract())
16895 return SDValue();
16896
16897 // Folding fadd (fmul x, y), (fmul x, y) -> fma x, y, (fmul x, y) is never
16898 // beneficial. It does not reduce latency. It increases register pressure. It
16899 // replaces an fadd with an fma which is a more complex instruction, so is
16900 // likely to have a larger encoding, use more functional units, etc.
16901 if (N0 == N1)
16902 return SDValue();
16903
16904 if (TLI.generateFMAsInMachineCombiner(VT, OptLevel))
16905 return SDValue();
16906
16907 // Always prefer FMAD to FMA for precision.
16908 unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
16909 bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
16910
16911 auto isFusedOp = [&](SDValue N) {
16912 return matcher.match(N, ISD::FMA) || matcher.match(N, ISD::FMAD);
16913 };
16914
16915 // Is the node an FMUL and contractable either due to global flags or
16916 // SDNodeFlags.
16917 auto isContractableFMUL = [AllowFusionGlobally, &matcher](SDValue N) {
16918 if (!matcher.match(N, ISD::FMUL))
16919 return false;
16920 return AllowFusionGlobally || N->getFlags().hasAllowContract();
16921 };
16922 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
16923 // prefer to fold the multiply with fewer uses.
16924 if (Aggressive && isContractableFMUL(N0) && isContractableFMUL(N1)) {
16925 if (N0->use_size() > N1->use_size())
16926 std::swap(N0, N1);
16927 }
16928
16929 // fold (fadd (fmul x, y), z) -> (fma x, y, z)
16930 if (isContractableFMUL(N0) && (Aggressive || N0->hasOneUse())) {
16931 return matcher.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0),
16932 N0.getOperand(1), N1);
16933 }
16934
16935 // fold (fadd x, (fmul y, z)) -> (fma y, z, x)
16936 // Note: Commutes FADD operands.
16937 if (isContractableFMUL(N1) && (Aggressive || N1->hasOneUse())) {
16938 return matcher.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(0),
16939 N1.getOperand(1), N0);
16940 }
16941
16942 // fadd (fma A, B, (fmul C, D)), E --> fma A, B, (fma C, D, E)
16943 // fadd E, (fma A, B, (fmul C, D)) --> fma A, B, (fma C, D, E)
16944 // This also works with nested fma instructions:
16945 // fadd (fma A, B, (fma (C, D, (fmul (E, F))))), G -->
16946 // fma A, B, (fma C, D, fma (E, F, G))
16947 // fadd (G, (fma A, B, (fma (C, D, (fmul (E, F)))))) -->
16948 // fma A, B, (fma C, D, fma (E, F, G)).
16949 // This requires reassociation because it changes the order of operations.
16950 bool CanReassociate =
16951 Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();
16952 if (CanReassociate) {
16953 SDValue FMA, E;
16954 if (isFusedOp(N0) && N0.hasOneUse()) {
16955 FMA = N0;
16956 E = N1;
16957 } else if (isFusedOp(N1) && N1.hasOneUse()) {
16958 FMA = N1;
16959 E = N0;
16960 }
16961
16962 SDValue TmpFMA = FMA;
16963 while (E && isFusedOp(TmpFMA) && TmpFMA.hasOneUse()) {
16964 SDValue FMul = TmpFMA->getOperand(2);
16965 if (matcher.match(FMul, ISD::FMUL) && FMul.hasOneUse()) {
16966 SDValue C = FMul.getOperand(0);
16967 SDValue D = FMul.getOperand(1);
16968 SDValue CDE = matcher.getNode(PreferredFusedOpcode, SL, VT, C, D, E);
16969 DAG.ReplaceAllUsesOfValueWith(FMul, CDE);
16970 // Replacing the inner FMul could cause the outer FMA to be simplified
16971 // away.
16972 return FMA.getOpcode() == ISD::DELETED_NODE ? SDValue(N, 0) : FMA;
16973 }
16974
16975 TmpFMA = TmpFMA->getOperand(2);
16976 }
16977 }
16978
16979 // Look through FP_EXTEND nodes to do more combining.
16980
16981 // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
16982 if (matcher.match(N0, ISD::FP_EXTEND)) {
16983 SDValue N00 = N0.getOperand(0);
16984 if (isContractableFMUL(N00) &&
16985 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
16986 N00.getValueType())) {
16987 return matcher.getNode(
16988 PreferredFusedOpcode, SL, VT,
16989 matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)),
16990 matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)), N1);
16991 }
16992 }
16993
16994 // fold (fadd x, (fpext (fmul y, z))) -> (fma (fpext y), (fpext z), x)
16995 // Note: Commutes FADD operands.
16996 if (matcher.match(N1, ISD::FP_EXTEND)) {
16997 SDValue N10 = N1.getOperand(0);
16998 if (isContractableFMUL(N10) &&
16999 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
17000 N10.getValueType())) {
17001 return matcher.getNode(
17002 PreferredFusedOpcode, SL, VT,
17003 matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(0)),
17004 matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(1)), N0);
17005 }
17006 }
17007
17008 // More folding opportunities when target permits.
17009 if (Aggressive) {
17010 // fold (fadd (fma x, y, (fpext (fmul u, v))), z)
17011 // -> (fma x, y, (fma (fpext u), (fpext v), z))
17012 auto FoldFAddFMAFPExtFMul = [&](SDValue X, SDValue Y, SDValue U, SDValue V,
17013 SDValue Z) {
17014 return matcher.getNode(
17015 PreferredFusedOpcode, SL, VT, X, Y,
17016 matcher.getNode(PreferredFusedOpcode, SL, VT,
17017 matcher.getNode(ISD::FP_EXTEND, SL, VT, U),
17018 matcher.getNode(ISD::FP_EXTEND, SL, VT, V), Z));
17019 };
17020 if (isFusedOp(N0)) {
17021 SDValue N02 = N0.getOperand(2);
17022 if (matcher.match(N02, ISD::FP_EXTEND)) {
17023 SDValue N020 = N02.getOperand(0);
17024 if (isContractableFMUL(N020) &&
17025 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
17026 N020.getValueType())) {
17027 return FoldFAddFMAFPExtFMul(N0.getOperand(0), N0.getOperand(1),
17028 N020.getOperand(0), N020.getOperand(1),
17029 N1);
17030 }
17031 }
17032 }
17033
17034 // fold (fadd (fpext (fma x, y, (fmul u, v))), z)
17035 // -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z))
17036 // FIXME: This turns two single-precision and one double-precision
17037 // operation into two double-precision operations, which might not be
17038 // interesting for all targets, especially GPUs.
17039 auto FoldFAddFPExtFMAFMul = [&](SDValue X, SDValue Y, SDValue U, SDValue V,
17040 SDValue Z) {
17041 return matcher.getNode(
17042 PreferredFusedOpcode, SL, VT,
17043 matcher.getNode(ISD::FP_EXTEND, SL, VT, X),
17044 matcher.getNode(ISD::FP_EXTEND, SL, VT, Y),
17045 matcher.getNode(PreferredFusedOpcode, SL, VT,
17046 matcher.getNode(ISD::FP_EXTEND, SL, VT, U),
17047 matcher.getNode(ISD::FP_EXTEND, SL, VT, V), Z));
17048 };
17049 if (N0.getOpcode() == ISD::FP_EXTEND) {
17050 SDValue N00 = N0.getOperand(0);
17051 if (isFusedOp(N00)) {
17052 SDValue N002 = N00.getOperand(2);
17053 if (isContractableFMUL(N002) &&
17054 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
17055 N00.getValueType())) {
17056 return FoldFAddFPExtFMAFMul(N00.getOperand(0), N00.getOperand(1),
17057 N002.getOperand(0), N002.getOperand(1),
17058 N1);
17059 }
17060 }
17061 }
17062
17063 // fold (fadd x, (fma y, z, (fpext (fmul u, v)))
17064 // -> (fma y, z, (fma (fpext u), (fpext v), x))
17065 if (isFusedOp(N1)) {
17066 SDValue N12 = N1.getOperand(2);
17067 if (N12.getOpcode() == ISD::FP_EXTEND) {
17068 SDValue N120 = N12.getOperand(0);
17069 if (isContractableFMUL(N120) &&
17070 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
17071 N120.getValueType())) {
17072 return FoldFAddFMAFPExtFMul(N1.getOperand(0), N1.getOperand(1),
17073 N120.getOperand(0), N120.getOperand(1),
17074 N0);
17075 }
17076 }
17077 }
17078
17079 // fold (fadd x, (fpext (fma y, z, (fmul u, v)))
17080 // -> (fma (fpext y), (fpext z), (fma (fpext u), (fpext v), x))
17081 // FIXME: This turns two single-precision and one double-precision
17082 // operation into two double-precision operations, which might not be
17083 // interesting for all targets, especially GPUs.
17084 if (N1.getOpcode() == ISD::FP_EXTEND) {
17085 SDValue N10 = N1.getOperand(0);
17086 if (isFusedOp(N10)) {
17087 SDValue N102 = N10.getOperand(2);
17088 if (isContractableFMUL(N102) &&
17089 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
17090 N10.getValueType())) {
17091 return FoldFAddFPExtFMAFMul(N10.getOperand(0), N10.getOperand(1),
17092 N102.getOperand(0), N102.getOperand(1),
17093 N0);
17094 }
17095 }
17096 }
17097 }
17098
17099 return SDValue();
17100 }
17101
17102 /// Try to perform FMA combining on a given FSUB node.
17103 template <class MatchContextClass>
visitFSUBForFMACombine(SDNode * N)17104 SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
17105 SDValue N0 = N->getOperand(0);
17106 SDValue N1 = N->getOperand(1);
17107 EVT VT = N->getValueType(0);
17108 SDLoc SL(N);
17109 MatchContextClass matcher(DAG, TLI, N);
17110 const TargetOptions &Options = DAG.getTarget().Options;
17111
17112 bool UseVP = std::is_same_v<MatchContextClass, VPMatchContext>;
17113
17114 // Floating-point multiply-add with intermediate rounding.
17115 // FIXME: Make isFMADLegal have specific behavior when using VPMatchContext.
17116 // FIXME: Add VP_FMAD opcode.
17117 bool HasFMAD = !UseVP && (LegalOperations && TLI.isFMADLegal(DAG, N));
17118
17119 // Floating-point multiply-add without intermediate rounding.
17120 bool HasFMA =
17121 (!LegalOperations || matcher.isOperationLegalOrCustom(ISD::FMA, VT)) &&
17122 TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT);
17123
17124 // No valid opcode, do not combine.
17125 if (!HasFMAD && !HasFMA)
17126 return SDValue();
17127
17128 const SDNodeFlags Flags = N->getFlags();
17129 bool AllowFusionGlobally =
17130 (Options.AllowFPOpFusion == FPOpFusion::Fast || HasFMAD);
17131
17132 // If the subtraction is not contractable, do not combine.
17133 if (!AllowFusionGlobally && !N->getFlags().hasAllowContract())
17134 return SDValue();
17135
17136 if (TLI.generateFMAsInMachineCombiner(VT, OptLevel))
17137 return SDValue();
17138
17139 // Always prefer FMAD to FMA for precision.
17140 unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
17141 bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
17142 bool NoSignedZero = Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros();
17143
17144 // Is the node an FMUL and contractable either due to global flags or
17145 // SDNodeFlags.
17146 auto isContractableFMUL = [AllowFusionGlobally, &matcher](SDValue N) {
17147 if (!matcher.match(N, ISD::FMUL))
17148 return false;
17149 return AllowFusionGlobally || N->getFlags().hasAllowContract();
17150 };
17151
17152 // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z))
17153 auto tryToFoldXYSubZ = [&](SDValue XY, SDValue Z) {
17154 if (isContractableFMUL(XY) && (Aggressive || XY->hasOneUse())) {
17155 return matcher.getNode(PreferredFusedOpcode, SL, VT, XY.getOperand(0),
17156 XY.getOperand(1),
17157 matcher.getNode(ISD::FNEG, SL, VT, Z));
17158 }
17159 return SDValue();
17160 };
17161
17162 // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x)
17163 // Note: Commutes FSUB operands.
17164 auto tryToFoldXSubYZ = [&](SDValue X, SDValue YZ) {
17165 if (isContractableFMUL(YZ) && (Aggressive || YZ->hasOneUse())) {
17166 return matcher.getNode(
17167 PreferredFusedOpcode, SL, VT,
17168 matcher.getNode(ISD::FNEG, SL, VT, YZ.getOperand(0)),
17169 YZ.getOperand(1), X);
17170 }
17171 return SDValue();
17172 };
17173
17174 // If we have two choices trying to fold (fsub (fmul u, v), (fmul x, y)),
17175 // prefer to fold the multiply with fewer uses.
17176 if (isContractableFMUL(N0) && isContractableFMUL(N1) &&
17177 (N0->use_size() > N1->use_size())) {
17178 // fold (fsub (fmul a, b), (fmul c, d)) -> (fma (fneg c), d, (fmul a, b))
17179 if (SDValue V = tryToFoldXSubYZ(N0, N1))
17180 return V;
17181 // fold (fsub (fmul a, b), (fmul c, d)) -> (fma a, b, (fneg (fmul c, d)))
17182 if (SDValue V = tryToFoldXYSubZ(N0, N1))
17183 return V;
17184 } else {
17185 // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z))
17186 if (SDValue V = tryToFoldXYSubZ(N0, N1))
17187 return V;
17188 // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x)
17189 if (SDValue V = tryToFoldXSubYZ(N0, N1))
17190 return V;
17191 }
17192
17193 // fold (fsub (fneg (fmul, x, y)), z) -> (fma (fneg x), y, (fneg z))
17194 if (matcher.match(N0, ISD::FNEG) && isContractableFMUL(N0.getOperand(0)) &&
17195 (Aggressive || (N0->hasOneUse() && N0.getOperand(0).hasOneUse()))) {
17196 SDValue N00 = N0.getOperand(0).getOperand(0);
17197 SDValue N01 = N0.getOperand(0).getOperand(1);
17198 return matcher.getNode(PreferredFusedOpcode, SL, VT,
17199 matcher.getNode(ISD::FNEG, SL, VT, N00), N01,
17200 matcher.getNode(ISD::FNEG, SL, VT, N1));
17201 }
17202
17203 // Look through FP_EXTEND nodes to do more combining.
17204
17205 // fold (fsub (fpext (fmul x, y)), z)
17206 // -> (fma (fpext x), (fpext y), (fneg z))
17207 if (matcher.match(N0, ISD::FP_EXTEND)) {
17208 SDValue N00 = N0.getOperand(0);
17209 if (isContractableFMUL(N00) &&
17210 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
17211 N00.getValueType())) {
17212 return matcher.getNode(
17213 PreferredFusedOpcode, SL, VT,
17214 matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)),
17215 matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)),
17216 matcher.getNode(ISD::FNEG, SL, VT, N1));
17217 }
17218 }
17219
17220 // fold (fsub x, (fpext (fmul y, z)))
17221 // -> (fma (fneg (fpext y)), (fpext z), x)
17222 // Note: Commutes FSUB operands.
17223 if (matcher.match(N1, ISD::FP_EXTEND)) {
17224 SDValue N10 = N1.getOperand(0);
17225 if (isContractableFMUL(N10) &&
17226 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
17227 N10.getValueType())) {
17228 return matcher.getNode(
17229 PreferredFusedOpcode, SL, VT,
17230 matcher.getNode(
17231 ISD::FNEG, SL, VT,
17232 matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(0))),
17233 matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(1)), N0);
17234 }
17235 }
17236
17237 // fold (fsub (fpext (fneg (fmul, x, y))), z)
17238 // -> (fneg (fma (fpext x), (fpext y), z))
17239 // Note: This could be removed with appropriate canonicalization of the
17240 // input expression into (fneg (fadd (fpext (fmul, x, y)), z). However, the
17241 // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent
17242 // from implementing the canonicalization in visitFSUB.
17243 if (matcher.match(N0, ISD::FP_EXTEND)) {
17244 SDValue N00 = N0.getOperand(0);
17245 if (matcher.match(N00, ISD::FNEG)) {
17246 SDValue N000 = N00.getOperand(0);
17247 if (isContractableFMUL(N000) &&
17248 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
17249 N00.getValueType())) {
17250 return matcher.getNode(
17251 ISD::FNEG, SL, VT,
17252 matcher.getNode(
17253 PreferredFusedOpcode, SL, VT,
17254 matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(0)),
17255 matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(1)),
17256 N1));
17257 }
17258 }
17259 }
17260
17261 // fold (fsub (fneg (fpext (fmul, x, y))), z)
17262 // -> (fneg (fma (fpext x)), (fpext y), z)
17263 // Note: This could be removed with appropriate canonicalization of the
17264 // input expression into (fneg (fadd (fpext (fmul, x, y)), z). However, the
17265 // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent
17266 // from implementing the canonicalization in visitFSUB.
17267 if (matcher.match(N0, ISD::FNEG)) {
17268 SDValue N00 = N0.getOperand(0);
17269 if (matcher.match(N00, ISD::FP_EXTEND)) {
17270 SDValue N000 = N00.getOperand(0);
17271 if (isContractableFMUL(N000) &&
17272 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
17273 N000.getValueType())) {
17274 return matcher.getNode(
17275 ISD::FNEG, SL, VT,
17276 matcher.getNode(
17277 PreferredFusedOpcode, SL, VT,
17278 matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(0)),
17279 matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(1)),
17280 N1));
17281 }
17282 }
17283 }
17284
17285 auto isContractableAndReassociableFMUL = [&isContractableFMUL](SDValue N) {
17286 return isContractableFMUL(N) && N->getFlags().hasAllowReassociation();
17287 };
17288
17289 auto isFusedOp = [&](SDValue N) {
17290 return matcher.match(N, ISD::FMA) || matcher.match(N, ISD::FMAD);
17291 };
17292
17293 // More folding opportunities when target permits.
17294 if (Aggressive && N->getFlags().hasAllowReassociation()) {
17295 bool CanFuse = N->getFlags().hasAllowContract();
17296 // fold (fsub (fma x, y, (fmul u, v)), z)
17297 // -> (fma x, y (fma u, v, (fneg z)))
17298 if (CanFuse && isFusedOp(N0) &&
17299 isContractableAndReassociableFMUL(N0.getOperand(2)) &&
17300 N0->hasOneUse() && N0.getOperand(2)->hasOneUse()) {
17301 return matcher.getNode(
17302 PreferredFusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1),
17303 matcher.getNode(PreferredFusedOpcode, SL, VT,
17304 N0.getOperand(2).getOperand(0),
17305 N0.getOperand(2).getOperand(1),
17306 matcher.getNode(ISD::FNEG, SL, VT, N1)));
17307 }
17308
17309 // fold (fsub x, (fma y, z, (fmul u, v)))
17310 // -> (fma (fneg y), z, (fma (fneg u), v, x))
17311 if (CanFuse && isFusedOp(N1) &&
17312 isContractableAndReassociableFMUL(N1.getOperand(2)) &&
17313 N1->hasOneUse() && NoSignedZero) {
17314 SDValue N20 = N1.getOperand(2).getOperand(0);
17315 SDValue N21 = N1.getOperand(2).getOperand(1);
17316 return matcher.getNode(
17317 PreferredFusedOpcode, SL, VT,
17318 matcher.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)),
17319 N1.getOperand(1),
17320 matcher.getNode(PreferredFusedOpcode, SL, VT,
17321 matcher.getNode(ISD::FNEG, SL, VT, N20), N21, N0));
17322 }
17323
17324 // fold (fsub (fma x, y, (fpext (fmul u, v))), z)
17325 // -> (fma x, y (fma (fpext u), (fpext v), (fneg z)))
17326 if (isFusedOp(N0) && N0->hasOneUse()) {
17327 SDValue N02 = N0.getOperand(2);
17328 if (matcher.match(N02, ISD::FP_EXTEND)) {
17329 SDValue N020 = N02.getOperand(0);
17330 if (isContractableAndReassociableFMUL(N020) &&
17331 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
17332 N020.getValueType())) {
17333 return matcher.getNode(
17334 PreferredFusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1),
17335 matcher.getNode(
17336 PreferredFusedOpcode, SL, VT,
17337 matcher.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(0)),
17338 matcher.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(1)),
17339 matcher.getNode(ISD::FNEG, SL, VT, N1)));
17340 }
17341 }
17342 }
17343
17344 // fold (fsub (fpext (fma x, y, (fmul u, v))), z)
17345 // -> (fma (fpext x), (fpext y),
17346 // (fma (fpext u), (fpext v), (fneg z)))
17347 // FIXME: This turns two single-precision and one double-precision
17348 // operation into two double-precision operations, which might not be
17349 // interesting for all targets, especially GPUs.
17350 if (matcher.match(N0, ISD::FP_EXTEND)) {
17351 SDValue N00 = N0.getOperand(0);
17352 if (isFusedOp(N00)) {
17353 SDValue N002 = N00.getOperand(2);
17354 if (isContractableAndReassociableFMUL(N002) &&
17355 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
17356 N00.getValueType())) {
17357 return matcher.getNode(
17358 PreferredFusedOpcode, SL, VT,
17359 matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)),
17360 matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)),
17361 matcher.getNode(
17362 PreferredFusedOpcode, SL, VT,
17363 matcher.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(0)),
17364 matcher.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(1)),
17365 matcher.getNode(ISD::FNEG, SL, VT, N1)));
17366 }
17367 }
17368 }
17369
17370 // fold (fsub x, (fma y, z, (fpext (fmul u, v))))
17371 // -> (fma (fneg y), z, (fma (fneg (fpext u)), (fpext v), x))
17372 if (isFusedOp(N1) && matcher.match(N1.getOperand(2), ISD::FP_EXTEND) &&
17373 N1->hasOneUse()) {
17374 SDValue N120 = N1.getOperand(2).getOperand(0);
17375 if (isContractableAndReassociableFMUL(N120) &&
17376 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
17377 N120.getValueType())) {
17378 SDValue N1200 = N120.getOperand(0);
17379 SDValue N1201 = N120.getOperand(1);
17380 return matcher.getNode(
17381 PreferredFusedOpcode, SL, VT,
17382 matcher.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)),
17383 N1.getOperand(1),
17384 matcher.getNode(
17385 PreferredFusedOpcode, SL, VT,
17386 matcher.getNode(ISD::FNEG, SL, VT,
17387 matcher.getNode(ISD::FP_EXTEND, SL, VT, N1200)),
17388 matcher.getNode(ISD::FP_EXTEND, SL, VT, N1201), N0));
17389 }
17390 }
17391
17392 // fold (fsub x, (fpext (fma y, z, (fmul u, v))))
17393 // -> (fma (fneg (fpext y)), (fpext z),
17394 // (fma (fneg (fpext u)), (fpext v), x))
17395 // FIXME: This turns two single-precision and one double-precision
17396 // operation into two double-precision operations, which might not be
17397 // interesting for all targets, especially GPUs.
17398 if (matcher.match(N1, ISD::FP_EXTEND) && isFusedOp(N1.getOperand(0))) {
17399 SDValue CvtSrc = N1.getOperand(0);
17400 SDValue N100 = CvtSrc.getOperand(0);
17401 SDValue N101 = CvtSrc.getOperand(1);
17402 SDValue N102 = CvtSrc.getOperand(2);
17403 if (isContractableAndReassociableFMUL(N102) &&
17404 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
17405 CvtSrc.getValueType())) {
17406 SDValue N1020 = N102.getOperand(0);
17407 SDValue N1021 = N102.getOperand(1);
17408 return matcher.getNode(
17409 PreferredFusedOpcode, SL, VT,
17410 matcher.getNode(ISD::FNEG, SL, VT,
17411 matcher.getNode(ISD::FP_EXTEND, SL, VT, N100)),
17412 matcher.getNode(ISD::FP_EXTEND, SL, VT, N101),
17413 matcher.getNode(
17414 PreferredFusedOpcode, SL, VT,
17415 matcher.getNode(ISD::FNEG, SL, VT,
17416 matcher.getNode(ISD::FP_EXTEND, SL, VT, N1020)),
17417 matcher.getNode(ISD::FP_EXTEND, SL, VT, N1021), N0));
17418 }
17419 }
17420 }
17421
17422 return SDValue();
17423 }
17424
17425 /// Try to perform FMA combining on a given FMUL node based on the distributive
17426 /// law x * (y + 1) = x * y + x and variants thereof (commuted versions,
17427 /// subtraction instead of addition).
visitFMULForFMADistributiveCombine(SDNode * N)17428 SDValue DAGCombiner::visitFMULForFMADistributiveCombine(SDNode *N) {
17429 SDValue N0 = N->getOperand(0);
17430 SDValue N1 = N->getOperand(1);
17431 EVT VT = N->getValueType(0);
17432 SDLoc SL(N);
17433
17434 assert(N->getOpcode() == ISD::FMUL && "Expected FMUL Operation");
17435
17436 const TargetOptions &Options = DAG.getTarget().Options;
17437
17438 // The transforms below are incorrect when x == 0 and y == inf, because the
17439 // intermediate multiplication produces a nan.
17440 SDValue FAdd = N0.getOpcode() == ISD::FADD ? N0 : N1;
17441 if (!hasNoInfs(Options, FAdd))
17442 return SDValue();
17443
17444 // Floating-point multiply-add without intermediate rounding.
17445 bool HasFMA =
17446 isContractableFMUL(Options, SDValue(N, 0)) &&
17447 (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT)) &&
17448 TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT);
17449
17450 // Floating-point multiply-add with intermediate rounding. This can result
17451 // in a less precise result due to the changed rounding order.
17452 bool HasFMAD = LegalOperations && TLI.isFMADLegal(DAG, N);
17453
17454 // No valid opcode, do not combine.
17455 if (!HasFMAD && !HasFMA)
17456 return SDValue();
17457
17458 // Always prefer FMAD to FMA for precision.
17459 unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
17460 bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
17461
17462 // fold (fmul (fadd x0, +1.0), y) -> (fma x0, y, y)
17463 // fold (fmul (fadd x0, -1.0), y) -> (fma x0, y, (fneg y))
17464 auto FuseFADD = [&](SDValue X, SDValue Y) {
17465 if (X.getOpcode() == ISD::FADD && (Aggressive || X->hasOneUse())) {
17466 if (auto *C = isConstOrConstSplatFP(X.getOperand(1), true)) {
17467 if (C->isExactlyValue(+1.0))
17468 return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
17469 Y);
17470 if (C->isExactlyValue(-1.0))
17471 return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
17472 DAG.getNode(ISD::FNEG, SL, VT, Y));
17473 }
17474 }
17475 return SDValue();
17476 };
17477
17478 if (SDValue FMA = FuseFADD(N0, N1))
17479 return FMA;
17480 if (SDValue FMA = FuseFADD(N1, N0))
17481 return FMA;
17482
17483 // fold (fmul (fsub +1.0, x1), y) -> (fma (fneg x1), y, y)
17484 // fold (fmul (fsub -1.0, x1), y) -> (fma (fneg x1), y, (fneg y))
17485 // fold (fmul (fsub x0, +1.0), y) -> (fma x0, y, (fneg y))
17486 // fold (fmul (fsub x0, -1.0), y) -> (fma x0, y, y)
17487 auto FuseFSUB = [&](SDValue X, SDValue Y) {
17488 if (X.getOpcode() == ISD::FSUB && (Aggressive || X->hasOneUse())) {
17489 if (auto *C0 = isConstOrConstSplatFP(X.getOperand(0), true)) {
17490 if (C0->isExactlyValue(+1.0))
17491 return DAG.getNode(PreferredFusedOpcode, SL, VT,
17492 DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y,
17493 Y);
17494 if (C0->isExactlyValue(-1.0))
17495 return DAG.getNode(PreferredFusedOpcode, SL, VT,
17496 DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y,
17497 DAG.getNode(ISD::FNEG, SL, VT, Y));
17498 }
17499 if (auto *C1 = isConstOrConstSplatFP(X.getOperand(1), true)) {
17500 if (C1->isExactlyValue(+1.0))
17501 return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
17502 DAG.getNode(ISD::FNEG, SL, VT, Y));
17503 if (C1->isExactlyValue(-1.0))
17504 return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
17505 Y);
17506 }
17507 }
17508 return SDValue();
17509 };
17510
17511 if (SDValue FMA = FuseFSUB(N0, N1))
17512 return FMA;
17513 if (SDValue FMA = FuseFSUB(N1, N0))
17514 return FMA;
17515
17516 return SDValue();
17517 }
17518
visitVP_FADD(SDNode * N)17519 SDValue DAGCombiner::visitVP_FADD(SDNode *N) {
17520 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
17521
17522 // FADD -> FMA combines:
17523 if (SDValue Fused = visitFADDForFMACombine<VPMatchContext>(N)) {
17524 if (Fused.getOpcode() != ISD::DELETED_NODE)
17525 AddToWorklist(Fused.getNode());
17526 return Fused;
17527 }
17528 return SDValue();
17529 }
17530
/// Combine an ISD::FADD node.
///
/// The folds below are attempted strictly in source order and the first one
/// that applies is returned: generic FP simplification, constant folding,
/// moving a constant operand to the RHS, vector binop simplification,
/// removal of an added zero, folding into a select, turning an add of a
/// negated value into FSUB, the (fmul B, -2.0) rewrite, the nnan-only and the
/// reassociation-only algebraic folds, reduction reassociation and finally
/// FADD -> FMA fusion.
///
/// \returns the replacement value, or an empty SDValue when no fold applied.
SDValue DAGCombiner::visitFADD(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  // These are true for FP constants and constant FP build vectors alike.
  bool N0CFP = DAG.isConstantFPBuildVectorOrConstantFP(N0);
  bool N1CFP = DAG.isConstantFPBuildVectorOrConstantFP(N1);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);
  const TargetOptions &Options = DAG.getTarget().Options;
  SDNodeFlags Flags = N->getFlags();
  // NOTE(review): presumably makes nodes created in this scope inherit N's
  // flags -- confirm against SelectionDAG::FlagInserter.
  SelectionDAG::FlagInserter FlagsInserter(DAG, N);

  if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
    return R;

  // fold (fadd c1, c2) -> c1 + c2
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::FADD, DL, VT, {N0, N1}))
    return C;

  // canonicalize constant to RHS
  if (N0CFP && !N1CFP)
    return DAG.getNode(ISD::FADD, DL, VT, N1, N0);

  // fold vector ops
  if (VT.isVector())
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

  // N0 + -0.0 --> N0 (also allowed with +0.0 and fast-math)
  // Adding +0.0 is only dropped when signed zeros may be ignored, either
  // globally or through the node's nsz flag.
  ConstantFPSDNode *N1C = isConstOrConstSplatFP(N1, true);
  if (N1C && N1C->isZero())
    if (N1C->isNegative() || Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros())
      return N0;

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  // fold (fadd A, (fneg B)) -> (fsub A, B)
  // Only done when the target reports a cheaper negated form of N1.
  if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT))
    if (SDValue NegN1 = TLI.getCheaperNegatedExpression(
            N1, DAG, LegalOperations, ForCodeSize))
      return DAG.getNode(ISD::FSUB, DL, VT, N0, NegN1);

  // fold (fadd (fneg A), B) -> (fsub B, A)
  if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT))
    if (SDValue NegN0 = TLI.getCheaperNegatedExpression(
            N0, DAG, LegalOperations, ForCodeSize))
      return DAG.getNode(ISD::FSUB, DL, VT, N1, NegN0);

  // Matches a single-use (fmul X, -2.0); the -2.0 must be operand 1 and may
  // be a scalar constant or a constant splat.
  auto isFMulNegTwo = [](SDValue FMul) {
    if (!FMul.hasOneUse() || FMul.getOpcode() != ISD::FMUL)
      return false;
    auto *C = isConstOrConstSplatFP(FMul.getOperand(1), true);
    return C && C->isExactlyValue(-2.0);
  };

  // fadd (fmul B, -2.0), A --> fsub A, (fadd B, B)
  if (isFMulNegTwo(N0)) {
    SDValue B = N0.getOperand(0);
    SDValue Add = DAG.getNode(ISD::FADD, DL, VT, B, B);
    return DAG.getNode(ISD::FSUB, DL, VT, N1, Add);
  }
  // fadd A, (fmul B, -2.0) --> fsub A, (fadd B, B)
  if (isFMulNegTwo(N1)) {
    SDValue B = N1.getOperand(0);
    SDValue Add = DAG.getNode(ISD::FADD, DL, VT, B, B);
    return DAG.getNode(ISD::FSUB, DL, VT, N0, Add);
  }

  // No FP constant should be created after legalization as Instruction
  // Selection pass has a hard time dealing with FP constants.
  bool AllowNewConst = (Level < AfterLegalizeDAG);

  // If nnan is enabled, fold lots of things.
  if ((Options.NoNaNsFPMath || Flags.hasNoNaNs()) && AllowNewConst) {
    // If allowed, fold (fadd (fneg x), x) -> 0.0
    if (N0.getOpcode() == ISD::FNEG && N0.getOperand(0) == N1)
      return DAG.getConstantFP(0.0, DL, VT);

    // If allowed, fold (fadd x, (fneg x)) -> 0.0
    if (N1.getOpcode() == ISD::FNEG && N1.getOperand(0) == N0)
      return DAG.getConstantFP(0.0, DL, VT);
  }

  // If 'unsafe math' or reassoc and nsz, fold lots of things.
  // TODO: break out portions of the transformations below for which Unsafe is
  // considered and which do not require both nsz and reassoc
  if (((Options.UnsafeFPMath && Options.NoSignedZerosFPMath) ||
       (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros())) &&
      AllowNewConst) {
    // fadd (fadd x, c1), c2 -> fadd x, c1 + c2
    if (N1CFP && N0.getOpcode() == ISD::FADD &&
        DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(1))) {
      SDValue NewC = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1), N1);
      return DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(0), NewC);
    }

    // We can fold chains of FADD's of the same value into multiplications.
    // This transform is not safe in general because we are reducing the number
    // of rounding steps.
    // Only attempted when neither operand of N is itself a constant and FMUL
    // is legal or custom for VT.
    if (TLI.isOperationLegalOrCustom(ISD::FMUL, VT) && !N0CFP && !N1CFP) {
      if (N0.getOpcode() == ISD::FMUL) {
        bool CFP00 = DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(0));
        bool CFP01 = DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(1));

        // (fadd (fmul x, c), x) -> (fmul x, c+1)
        if (CFP01 && !CFP00 && N0.getOperand(0) == N1) {
          SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1),
                                       DAG.getConstantFP(1.0, DL, VT));
          return DAG.getNode(ISD::FMUL, DL, VT, N1, NewCFP);
        }

        // (fadd (fmul x, c), (fadd x, x)) -> (fmul x, c+2)
        if (CFP01 && !CFP00 && N1.getOpcode() == ISD::FADD &&
            N1.getOperand(0) == N1.getOperand(1) &&
            N0.getOperand(0) == N1.getOperand(0)) {
          SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1),
                                       DAG.getConstantFP(2.0, DL, VT));
          return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), NewCFP);
        }
      }

      if (N1.getOpcode() == ISD::FMUL) {
        bool CFP10 = DAG.isConstantFPBuildVectorOrConstantFP(N1.getOperand(0));
        bool CFP11 = DAG.isConstantFPBuildVectorOrConstantFP(N1.getOperand(1));

        // (fadd x, (fmul x, c)) -> (fmul x, c+1)
        if (CFP11 && !CFP10 && N1.getOperand(0) == N0) {
          SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N1.getOperand(1),
                                       DAG.getConstantFP(1.0, DL, VT));
          return DAG.getNode(ISD::FMUL, DL, VT, N0, NewCFP);
        }

        // (fadd (fadd x, x), (fmul x, c)) -> (fmul x, c+2)
        if (CFP11 && !CFP10 && N0.getOpcode() == ISD::FADD &&
            N0.getOperand(0) == N0.getOperand(1) &&
            N1.getOperand(0) == N0.getOperand(0)) {
          SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N1.getOperand(1),
                                       DAG.getConstantFP(2.0, DL, VT));
          return DAG.getNode(ISD::FMUL, DL, VT, N1.getOperand(0), NewCFP);
        }
      }

      if (N0.getOpcode() == ISD::FADD) {
        bool CFP00 = DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(0));
        // (fadd (fadd x, x), x) -> (fmul x, 3.0)
        if (!CFP00 && N0.getOperand(0) == N0.getOperand(1) &&
            (N0.getOperand(0) == N1)) {
          return DAG.getNode(ISD::FMUL, DL, VT, N1,
                             DAG.getConstantFP(3.0, DL, VT));
        }
      }

      if (N1.getOpcode() == ISD::FADD) {
        bool CFP10 = DAG.isConstantFPBuildVectorOrConstantFP(N1.getOperand(0));
        // (fadd x, (fadd x, x)) -> (fmul x, 3.0)
        if (!CFP10 && N1.getOperand(0) == N1.getOperand(1) &&
            N1.getOperand(0) == N0) {
          return DAG.getNode(ISD::FMUL, DL, VT, N0,
                             DAG.getConstantFP(3.0, DL, VT));
        }
      }

      // (fadd (fadd x, x), (fadd x, x)) -> (fmul x, 4.0)
      if (N0.getOpcode() == ISD::FADD && N1.getOpcode() == ISD::FADD &&
          N0.getOperand(0) == N0.getOperand(1) &&
          N1.getOperand(0) == N1.getOperand(1) &&
          N0.getOperand(0) == N1.getOperand(0)) {
        return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0),
                           DAG.getConstantFP(4.0, DL, VT));
      }
    }
  } // ((unsafe-fp-math && nsz option) || (reassoc && nsz flags)) && AllowNewConst

  // Same math preconditions as the block above, but without the
  // AllowNewConst requirement.
  if (((Options.UnsafeFPMath && Options.NoSignedZerosFPMath) ||
       (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros()))) {
    // Fold fadd(vecreduce(x), vecreduce(y)) -> vecreduce(fadd(x, y))
    if (SDValue SD = reassociateReduction(ISD::VECREDUCE_FADD, ISD::FADD, DL,
                                          VT, N0, N1, Flags))
      return SD;
  }

  // FADD -> FMA combines:
  if (SDValue Fused = visitFADDForFMACombine<EmptyMatchContext>(N)) {
    // A result whose node is already marked deleted is not queued.
    if (Fused.getOpcode() != ISD::DELETED_NODE)
      AddToWorklist(Fused.getNode());
    return Fused;
  }
  return SDValue();
}
17720
visitSTRICT_FADD(SDNode * N)17721 SDValue DAGCombiner::visitSTRICT_FADD(SDNode *N) {
17722 SDValue Chain = N->getOperand(0);
17723 SDValue N0 = N->getOperand(1);
17724 SDValue N1 = N->getOperand(2);
17725 EVT VT = N->getValueType(0);
17726 EVT ChainVT = N->getValueType(1);
17727 SDLoc DL(N);
17728 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
17729
17730 // fold (strict_fadd A, (fneg B)) -> (strict_fsub A, B)
17731 if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::STRICT_FSUB, VT))
17732 if (SDValue NegN1 = TLI.getCheaperNegatedExpression(
17733 N1, DAG, LegalOperations, ForCodeSize)) {
17734 return DAG.getNode(ISD::STRICT_FSUB, DL, DAG.getVTList(VT, ChainVT),
17735 {Chain, N0, NegN1});
17736 }
17737
17738 // fold (strict_fadd (fneg A), B) -> (strict_fsub B, A)
17739 if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::STRICT_FSUB, VT))
17740 if (SDValue NegN0 = TLI.getCheaperNegatedExpression(
17741 N0, DAG, LegalOperations, ForCodeSize)) {
17742 return DAG.getNode(ISD::STRICT_FSUB, DL, DAG.getVTList(VT, ChainVT),
17743 {Chain, N1, NegN0});
17744 }
17745 return SDValue();
17746 }
17747
visitFSUB(SDNode * N)17748 SDValue DAGCombiner::visitFSUB(SDNode *N) {
17749 SDValue N0 = N->getOperand(0);
17750 SDValue N1 = N->getOperand(1);
17751 ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0, true);
17752 ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1, true);
17753 EVT VT = N->getValueType(0);
17754 SDLoc DL(N);
17755 const TargetOptions &Options = DAG.getTarget().Options;
17756 const SDNodeFlags Flags = N->getFlags();
17757 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
17758
17759 if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
17760 return R;
17761
17762 // fold (fsub c1, c2) -> c1-c2
17763 if (SDValue C = DAG.FoldConstantArithmetic(ISD::FSUB, DL, VT, {N0, N1}))
17764 return C;
17765
17766 // fold vector ops
17767 if (VT.isVector())
17768 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
17769 return FoldedVOp;
17770
17771 if (SDValue NewSel = foldBinOpIntoSelect(N))
17772 return NewSel;
17773
17774 // (fsub A, 0) -> A
17775 if (N1CFP && N1CFP->isZero()) {
17776 if (!N1CFP->isNegative() || Options.NoSignedZerosFPMath ||
17777 Flags.hasNoSignedZeros()) {
17778 return N0;
17779 }
17780 }
17781
17782 if (N0 == N1) {
17783 // (fsub x, x) -> 0.0
17784 if (Options.NoNaNsFPMath || Flags.hasNoNaNs())
17785 return DAG.getConstantFP(0.0f, DL, VT);
17786 }
17787
17788 // (fsub -0.0, N1) -> -N1
17789 if (N0CFP && N0CFP->isZero()) {
17790 if (N0CFP->isNegative() ||
17791 (Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros())) {
17792 // We cannot replace an FSUB(+-0.0,X) with FNEG(X) when denormals are
17793 // flushed to zero, unless all users treat denorms as zero (DAZ).
17794 // FIXME: This transform will change the sign of a NaN and the behavior
17795 // of a signaling NaN. It is only valid when a NoNaN flag is present.
17796 DenormalMode DenormMode = DAG.getDenormalMode(VT);
17797 if (DenormMode == DenormalMode::getIEEE()) {
17798 if (SDValue NegN1 =
17799 TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize))
17800 return NegN1;
17801 if (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))
17802 return DAG.getNode(ISD::FNEG, DL, VT, N1);
17803 }
17804 }
17805 }
17806
17807 if (((Options.UnsafeFPMath && Options.NoSignedZerosFPMath) ||
17808 (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros())) &&
17809 N1.getOpcode() == ISD::FADD) {
17810 // X - (X + Y) -> -Y
17811 if (N0 == N1->getOperand(0))
17812 return DAG.getNode(ISD::FNEG, DL, VT, N1->getOperand(1));
17813 // X - (Y + X) -> -Y
17814 if (N0 == N1->getOperand(1))
17815 return DAG.getNode(ISD::FNEG, DL, VT, N1->getOperand(0));
17816 }
17817
17818 // fold (fsub A, (fneg B)) -> (fadd A, B)
17819 if (SDValue NegN1 =
17820 TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize))
17821 return DAG.getNode(ISD::FADD, DL, VT, N0, NegN1);
17822
17823 // FSUB -> FMA combines:
17824 if (SDValue Fused = visitFSUBForFMACombine<EmptyMatchContext>(N)) {
17825 AddToWorklist(Fused.getNode());
17826 return Fused;
17827 }
17828
17829 return SDValue();
17830 }
17831
// Transform IEEE Floats:
//      (fmul C, (uitofp Pow2))
//          -> (bitcast_to_FP (add (bitcast_to_INT C), Log2(Pow2) << mantissa))
//      (fdiv C, (uitofp Pow2))
//          -> (bitcast_to_FP (sub (bitcast_to_INT C), Log2(Pow2) << mantissa))
//
// The rationale is that an fmul/fdiv by a power of 2 only changes the
// exponent, so there is no need for more than an add/sub.
//
// This is valid under the following circumstances:
// 1) We are dealing with IEEE floats
// 2) C is normal
// 3) The fmul/fdiv add/sub will not go outside of min/max exponent bounds.
// TODO: Much of this could also be used for generating `ldexp` on targets that
// prefer it.
//
// Returns the bitcast result, or an empty SDValue when the pattern does not
// match, the target declines the transform, or no log2 could be built.
SDValue DAGCombiner::combineFMulOrFDivWithIntPow2(SDNode *N) {
  EVT VT = N->getValueType(0);
  if (!APFloat::isIEEELikeFP(VT.getFltSemantics()))
    return SDValue();

  // Outputs of GetConstAndPow2Ops below (written through the by-reference
  // capture): the FP constant operand and the integer power-of-two source.
  SDValue ConstOp, Pow2Op;

  // Number of mantissa bits of the constant; assigned only by IsFPConstValid.
  std::optional<int> Mantissa;
  // Tries to interpret operand `ConstOpIdx` of N as the constant and the
  // other operand as the int-to-fp conversion. Returns true on a match.
  auto GetConstAndPow2Ops = [&](unsigned ConstOpIdx) {
    // For FDIV the constant is only accepted as operand 0 (the dividend).
    if (ConstOpIdx == 1 && N->getOpcode() == ISD::FDIV)
      return false;

    ConstOp = peekThroughBitcasts(N->getOperand(ConstOpIdx));
    Pow2Op = N->getOperand(1 - ConstOpIdx);
    // Accept UINT_TO_FP, or SINT_TO_FP whose known bits report non-negative.
    if (Pow2Op.getOpcode() != ISD::UINT_TO_FP &&
        (Pow2Op.getOpcode() != ISD::SINT_TO_FP ||
         !DAG.computeKnownBits(Pow2Op).isNonNegative()))
      return false;

    // From here on Pow2Op is the integer value feeding the conversion.
    Pow2Op = Pow2Op.getOperand(0);

    // `Log2(Pow2Op) < Pow2Op.getScalarSizeInBits()`.
    // TODO: We could use knownbits to make this bound more precise.
    int MaxExpChange = Pow2Op.getValueType().getScalarSizeInBits();

    // Per-constant check: normal value, exponent stays strictly inside the
    // format's bounds after the largest possible change, and the same
    // non-zero mantissa width as every constant checked before it.
    auto IsFPConstValid = [N, MaxExpChange, &Mantissa](ConstantFPSDNode *CFP) {
      if (CFP == nullptr)
        return false;

      const APFloat &APF = CFP->getValueAPF();

      // Make sure we have normal constant.
      if (!APF.isNormal())
        return false;

      // Make sure the float's exponent is within the bounds for which this
      // transform produces a bitwise-equal value.
      int CurExp = ilogb(APF);
      // FMul by pow2 will only increase exponent.
      int MinExp =
          N->getOpcode() == ISD::FMUL ? CurExp : (CurExp - MaxExpChange);
      // FDiv by pow2 will only decrease exponent.
      int MaxExp =
          N->getOpcode() == ISD::FDIV ? CurExp : (CurExp + MaxExpChange);
      if (MinExp <= APFloat::semanticsMinExponent(APF.getSemantics()) ||
          MaxExp >= APFloat::semanticsMaxExponent(APF.getSemantics()))
        return false;

      // Finally make sure we actually know the mantissa for the float type.
      // The first constant seen fixes Mantissa; later ones must agree.
      int ThisMantissa = APFloat::semanticsPrecision(APF.getSemantics()) - 1;
      if (!Mantissa)
        Mantissa = ThisMantissa;

      return *Mantissa == ThisMantissa && ThisMantissa > 0;
    };

    // TODO: We may be able to include undefs.
    return ISD::matchUnaryFpPredicate(ConstOp, IsFPConstValid);
  };

  // Try the constant as operand 0 first; operand 1 is only tried if that
  // fails.
  if (!GetConstAndPow2Ops(0) && !GetConstAndPow2Ops(1))
    return SDValue();

  // Let the target veto the transform.
  if (!TLI.optimizeFMulOrFDivAsShiftAddBitcast(N, ConstOp, Pow2Op))
    return SDValue();

  // Get log2 after all other checks have taken place. This is because
  // BuildLogBase2 may create a new node.
  SDLoc DL(N);
  // Get Log2 type with same bitwidth as the float type (VT).
  EVT NewIntVT = EVT::getIntegerVT(*DAG.getContext(), VT.getScalarSizeInBits());
  if (VT.isVector())
    NewIntVT = EVT::getVectorVT(*DAG.getContext(), NewIntVT,
                                VT.getVectorElementCount());

  SDValue Log2 = BuildLogBase2(Pow2Op, DL, DAG.isKnownNeverZero(Pow2Op),
                               /*InexpensiveOnly*/ true, NewIntVT);
  if (!Log2)
    return SDValue();

  // Perform actual transform.
  // Shifting Log2 left by the mantissa width lines it up with the exponent
  // field of the bitcast constant.
  SDValue MantissaShiftCnt =
      DAG.getShiftAmountConstant(*Mantissa, NewIntVT, DL);
  // TODO: Sometimes Log2 is of form `(X + C)`. `(X + C) << C1` should fold to
  // `(X << C1) + (C << C1)`, but that isn't always the case because of the
  // cast. We could implement that here by handling the casts.
  SDValue Shift = DAG.getNode(ISD::SHL, DL, NewIntVT, Log2, MantissaShiftCnt);
  // FMUL adds the shifted log2 to the constant's bits; FDIV subtracts it.
  SDValue ResAsInt =
      DAG.getNode(N->getOpcode() == ISD::FMUL ? ISD::ADD : ISD::SUB, DL,
                  NewIntVT, DAG.getBitcast(NewIntVT, ConstOp), Shift);
  SDValue ResAsFP = DAG.getBitcast(VT, ResAsInt);
  return ResAsFP;
}
17940
/// Combine an ISD::FMUL node.
///
/// The folds below are attempted strictly in source order and the first one
/// that applies is returned: generic FP simplification, constant folding,
/// moving a constant operand to the RHS, vector binop simplification,
/// folding into a select, reassociation-only folds, multiplication by 2.0
/// and -1.0, cancelling a double negation, the select(+/-1.0) -> fabs
/// patterns, FMUL -> FMA distribution and finally the integer power-of-two
/// rewrite.
///
/// \returns the replacement value, or an empty SDValue when no fold applied.
SDValue DAGCombiner::visitFMUL(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1, true);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);
  const TargetOptions &Options = DAG.getTarget().Options;
  const SDNodeFlags Flags = N->getFlags();
  // NOTE(review): presumably makes nodes created in this scope inherit N's
  // flags -- confirm against SelectionDAG::FlagInserter.
  SelectionDAG::FlagInserter FlagsInserter(DAG, N);

  if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
    return R;

  // fold (fmul c1, c2) -> c1*c2
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::FMUL, DL, VT, {N0, N1}))
    return C;

  // canonicalize constant to RHS
  if (DAG.isConstantFPBuildVectorOrConstantFP(N0) &&
      !DAG.isConstantFPBuildVectorOrConstantFP(N1))
    return DAG.getNode(ISD::FMUL, DL, VT, N1, N0);

  // fold vector ops
  if (VT.isVector())
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  // Folds that change the association of multiplications; allowed by the
  // global unsafe-math option or the node's reassoc flag.
  if (Options.UnsafeFPMath || Flags.hasAllowReassociation()) {
    // fmul (fmul X, C1), C2 -> fmul X, C1 * C2
    if (DAG.isConstantFPBuildVectorOrConstantFP(N1) &&
        N0.getOpcode() == ISD::FMUL) {
      SDValue N00 = N0.getOperand(0);
      SDValue N01 = N0.getOperand(1);
      // Avoid an infinite loop by making sure that N00 is not a constant
      // (the inner multiply has not been constant folded yet).
      if (DAG.isConstantFPBuildVectorOrConstantFP(N01) &&
          !DAG.isConstantFPBuildVectorOrConstantFP(N00)) {
        SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, N01, N1);
        return DAG.getNode(ISD::FMUL, DL, VT, N00, MulConsts);
      }
    }

    // Match a special-case: we convert X * 2.0 into fadd.
    // fmul (fadd X, X), C -> fmul X, 2.0 * C
    // Only when the inner fadd has no other user.
    if (N0.getOpcode() == ISD::FADD && N0.hasOneUse() &&
        N0.getOperand(0) == N0.getOperand(1)) {
      const SDValue Two = DAG.getConstantFP(2.0, DL, VT);
      SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, Two, N1);
      return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), MulConsts);
    }

    // Fold fmul(vecreduce(x), vecreduce(y)) -> vecreduce(fmul(x, y))
    if (SDValue SD = reassociateReduction(ISD::VECREDUCE_FMUL, ISD::FMUL, DL,
                                          VT, N0, N1, Flags))
      return SD;
  }

  // fold (fmul X, 2.0) -> (fadd X, X)
  if (N1CFP && N1CFP->isExactlyValue(+2.0))
    return DAG.getNode(ISD::FADD, DL, VT, N0, N0);

  // fold (fmul X, -1.0) -> (fsub -0.0, X)
  if (N1CFP && N1CFP->isExactlyValue(-1.0)) {
    if (!LegalOperations || TLI.isOperationLegal(ISD::FSUB, VT)) {
      return DAG.getNode(ISD::FSUB, DL, VT,
                         DAG.getConstantFP(-0.0, DL, VT), N0, Flags);
    }
  }

  // -N0 * -N1 --> N0 * N1
  // Only done when both operands can be negated and at least one of the
  // negated forms is reported as cheaper.
  TargetLowering::NegatibleCost CostN0 =
      TargetLowering::NegatibleCost::Expensive;
  TargetLowering::NegatibleCost CostN1 =
      TargetLowering::NegatibleCost::Expensive;
  SDValue NegN0 =
      TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize, CostN0);
  if (NegN0) {
    // NOTE(review): the handle presumably keeps NegN0 valid while the
    // negation of N1 is being built -- confirm against HandleSDNode.
    HandleSDNode NegN0Handle(NegN0);
    SDValue NegN1 =
        TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize, CostN1);
    if (NegN1 && (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
                  CostN1 == TargetLowering::NegatibleCost::Cheaper))
      return DAG.getNode(ISD::FMUL, DL, VT, NegN0, NegN1);
  }

  // fold (fmul X, (select (fcmp X > 0.0), -1.0, 1.0)) -> (fneg (fabs X))
  // fold (fmul X, (select (fcmp X > 0.0), 1.0, -1.0)) -> (fabs X)
  // Requires the nnan and nsz flags on this node and a legal FABS.
  if (Flags.hasNoNaNs() && Flags.hasNoSignedZeros() &&
      (N0.getOpcode() == ISD::SELECT || N1.getOpcode() == ISD::SELECT) &&
      TLI.isOperationLegal(ISD::FABS, VT)) {
    // Normalise so that `Select` is the select operand and `X` the other one.
    SDValue Select = N0, X = N1;
    if (Select.getOpcode() != ISD::SELECT)
      std::swap(Select, X);

    SDValue Cond = Select.getOperand(0);
    auto TrueOpnd = dyn_cast<ConstantFPSDNode>(Select.getOperand(1));
    auto FalseOpnd = dyn_cast<ConstantFPSDNode>(Select.getOperand(2));

    // The condition must compare X itself against the FP constant 0.0, and
    // both select arms must be FP constants.
    if (TrueOpnd && FalseOpnd &&
        Cond.getOpcode() == ISD::SETCC && Cond.getOperand(0) == X &&
        isa<ConstantFPSDNode>(Cond.getOperand(1)) &&
        cast<ConstantFPSDNode>(Cond.getOperand(1))->isExactlyValue(0.0)) {
      ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
      switch (CC) {
      default: break;
      // For the "less than" family the select arms play the opposite roles,
      // so swap them and share the "greater than" handling below.
      case ISD::SETOLT:
      case ISD::SETULT:
      case ISD::SETOLE:
      case ISD::SETULE:
      case ISD::SETLT:
      case ISD::SETLE:
        std::swap(TrueOpnd, FalseOpnd);
        [[fallthrough]];
      case ISD::SETOGT:
      case ISD::SETUGT:
      case ISD::SETOGE:
      case ISD::SETUGE:
      case ISD::SETGT:
      case ISD::SETGE:
        // The negated form additionally needs a legal FNEG.
        if (TrueOpnd->isExactlyValue(-1.0) && FalseOpnd->isExactlyValue(1.0) &&
            TLI.isOperationLegal(ISD::FNEG, VT))
          return DAG.getNode(ISD::FNEG, DL, VT,
                             DAG.getNode(ISD::FABS, DL, VT, X));
        if (TrueOpnd->isExactlyValue(1.0) && FalseOpnd->isExactlyValue(-1.0))
          return DAG.getNode(ISD::FABS, DL, VT, X);

        break;
      }
    }
  }

  // FMUL -> FMA combines:
  if (SDValue Fused = visitFMULForFMADistributiveCombine(N)) {
    AddToWorklist(Fused.getNode());
    return Fused;
  }

  // Don't do `combineFMulOrFDivWithIntPow2` until after FMUL -> FMA has been
  // able to run.
  if (SDValue R = combineFMulOrFDivWithIntPow2(N))
    return R;

  return SDValue();
}
18088
visitFMA(SDNode * N)18089 template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
18090 SDValue N0 = N->getOperand(0);
18091 SDValue N1 = N->getOperand(1);
18092 SDValue N2 = N->getOperand(2);
18093 ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(N0);
18094 ConstantFPSDNode *N1CFP = dyn_cast<ConstantFPSDNode>(N1);
18095 ConstantFPSDNode *N2CFP = dyn_cast<ConstantFPSDNode>(N2);
18096 EVT VT = N->getValueType(0);
18097 SDLoc DL(N);
18098 const TargetOptions &Options = DAG.getTarget().Options;
18099 // FMA nodes have flags that propagate to the created nodes.
18100 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
18101 MatchContextClass matcher(DAG, TLI, N);
18102
18103 // Constant fold FMA.
18104 if (SDValue C =
18105 DAG.FoldConstantArithmetic(N->getOpcode(), DL, VT, {N0, N1, N2}))
18106 return C;
18107
18108 // (-N0 * -N1) + N2 --> (N0 * N1) + N2
18109 TargetLowering::NegatibleCost CostN0 =
18110 TargetLowering::NegatibleCost::Expensive;
18111 TargetLowering::NegatibleCost CostN1 =
18112 TargetLowering::NegatibleCost::Expensive;
18113 SDValue NegN0 =
18114 TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize, CostN0);
18115 if (NegN0) {
18116 HandleSDNode NegN0Handle(NegN0);
18117 SDValue NegN1 =
18118 TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize, CostN1);
18119 if (NegN1 && (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
18120 CostN1 == TargetLowering::NegatibleCost::Cheaper))
18121 return matcher.getNode(ISD::FMA, DL, VT, NegN0, NegN1, N2);
18122 }
18123
18124 // FIXME: use fast math flags instead of Options.UnsafeFPMath
18125 // TODO: Finally migrate away from global TargetOptions.
18126 if ((Options.NoNaNsFPMath && Options.NoInfsFPMath) ||
18127 (N->getFlags().hasNoNaNs() && N->getFlags().hasNoInfs())) {
18128 if (Options.NoSignedZerosFPMath || N->getFlags().hasNoSignedZeros() ||
18129 (N2CFP && !N2CFP->isExactlyValue(-0.0))) {
18130 if (N0CFP && N0CFP->isZero())
18131 return N2;
18132 if (N1CFP && N1CFP->isZero())
18133 return N2;
18134 }
18135 }
18136
18137 // FIXME: Support splat of constant.
18138 if (N0CFP && N0CFP->isExactlyValue(1.0))
18139 return matcher.getNode(ISD::FADD, DL, VT, N1, N2);
18140 if (N1CFP && N1CFP->isExactlyValue(1.0))
18141 return matcher.getNode(ISD::FADD, DL, VT, N0, N2);
18142
18143 // Canonicalize (fma c, x, y) -> (fma x, c, y)
18144 if (DAG.isConstantFPBuildVectorOrConstantFP(N0) &&
18145 !DAG.isConstantFPBuildVectorOrConstantFP(N1))
18146 return matcher.getNode(ISD::FMA, DL, VT, N1, N0, N2);
18147
18148 bool CanReassociate =
18149 Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();
18150 if (CanReassociate) {
18151 // (fma x, c1, (fmul x, c2)) -> (fmul x, c1+c2)
18152 if (matcher.match(N2, ISD::FMUL) && N0 == N2.getOperand(0) &&
18153 DAG.isConstantFPBuildVectorOrConstantFP(N1) &&
18154 DAG.isConstantFPBuildVectorOrConstantFP(N2.getOperand(1))) {
18155 return matcher.getNode(
18156 ISD::FMUL, DL, VT, N0,
18157 matcher.getNode(ISD::FADD, DL, VT, N1, N2.getOperand(1)));
18158 }
18159
18160 // (fma (fmul x, c1), c2, y) -> (fma x, c1*c2, y)
18161 if (matcher.match(N0, ISD::FMUL) &&
18162 DAG.isConstantFPBuildVectorOrConstantFP(N1) &&
18163 DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(1))) {
18164 return matcher.getNode(
18165 ISD::FMA, DL, VT, N0.getOperand(0),
18166 matcher.getNode(ISD::FMUL, DL, VT, N1, N0.getOperand(1)), N2);
18167 }
18168 }
18169
18170 // (fma x, -1, y) -> (fadd (fneg x), y)
18171 // FIXME: Support splat of constant.
18172 if (N1CFP) {
18173 if (N1CFP->isExactlyValue(1.0))
18174 return matcher.getNode(ISD::FADD, DL, VT, N0, N2);
18175
18176 if (N1CFP->isExactlyValue(-1.0) &&
18177 (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))) {
18178 SDValue RHSNeg = matcher.getNode(ISD::FNEG, DL, VT, N0);
18179 AddToWorklist(RHSNeg.getNode());
18180 return matcher.getNode(ISD::FADD, DL, VT, N2, RHSNeg);
18181 }
18182
18183 // fma (fneg x), K, y -> fma x -K, y
18184 if (matcher.match(N0, ISD::FNEG) &&
18185 (TLI.isOperationLegal(ISD::ConstantFP, VT) ||
18186 (N1.hasOneUse() &&
18187 !TLI.isFPImmLegal(N1CFP->getValueAPF(), VT, ForCodeSize)))) {
18188 return matcher.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
18189 matcher.getNode(ISD::FNEG, DL, VT, N1), N2);
18190 }
18191 }
18192
18193 // FIXME: Support splat of constant.
18194 if (CanReassociate) {
18195 // (fma x, c, x) -> (fmul x, (c+1))
18196 if (N1CFP && N0 == N2) {
18197 return matcher.getNode(ISD::FMUL, DL, VT, N0,
18198 matcher.getNode(ISD::FADD, DL, VT, N1,
18199 DAG.getConstantFP(1.0, DL, VT)));
18200 }
18201
18202 // (fma x, c, (fneg x)) -> (fmul x, (c-1))
18203 if (N1CFP && matcher.match(N2, ISD::FNEG) && N2.getOperand(0) == N0) {
18204 return matcher.getNode(ISD::FMUL, DL, VT, N0,
18205 matcher.getNode(ISD::FADD, DL, VT, N1,
18206 DAG.getConstantFP(-1.0, DL, VT)));
18207 }
18208 }
18209
18210 // fold ((fma (fneg X), Y, (fneg Z)) -> fneg (fma X, Y, Z))
18211 // fold ((fma X, (fneg Y), (fneg Z)) -> fneg (fma X, Y, Z))
18212 if (!TLI.isFNegFree(VT))
18213 if (SDValue Neg = TLI.getCheaperNegatedExpression(
18214 SDValue(N, 0), DAG, LegalOperations, ForCodeSize))
18215 return matcher.getNode(ISD::FNEG, DL, VT, Neg);
18216 return SDValue();
18217 }
18218
visitFMAD(SDNode * N)18219 SDValue DAGCombiner::visitFMAD(SDNode *N) {
18220 SDValue N0 = N->getOperand(0);
18221 SDValue N1 = N->getOperand(1);
18222 SDValue N2 = N->getOperand(2);
18223 EVT VT = N->getValueType(0);
18224 SDLoc DL(N);
18225
18226 // Constant fold FMAD.
18227 if (SDValue C = DAG.FoldConstantArithmetic(ISD::FMAD, DL, VT, {N0, N1, N2}))
18228 return C;
18229
18230 return SDValue();
18231 }
18232
18233 // Combine multiple FDIVs with the same divisor into multiple FMULs by the
18234 // reciprocal.
18235 // E.g., (a / D; b / D;) -> (recip = 1.0 / D; a * recip; b * recip)
18236 // Notice that this is not always beneficial. One reason is different targets
18237 // may have different costs for FDIV and FMUL, so sometimes the cost of two
18238 // FDIVs may be lower than the cost of one FDIV and two FMULs. Another reason
18239 // is the critical path is increased from "one FDIV" to "one FDIV + one FMUL".
combineRepeatedFPDivisors(SDNode * N)18240 SDValue DAGCombiner::combineRepeatedFPDivisors(SDNode *N) {
18241 // TODO: Limit this transform based on optsize/minsize - it always creates at
18242 // least 1 extra instruction. But the perf win may be substantial enough
18243 // that only minsize should restrict this.
18244 bool UnsafeMath = DAG.getTarget().Options.UnsafeFPMath;
18245 const SDNodeFlags Flags = N->getFlags();
18246 if (LegalDAG || (!UnsafeMath && !Flags.hasAllowReciprocal()))
18247 return SDValue();
18248
18249 // Skip if current node is a reciprocal/fneg-reciprocal.
18250 SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
18251 ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0, /* AllowUndefs */ true);
18252 if (N0CFP && (N0CFP->isExactlyValue(1.0) || N0CFP->isExactlyValue(-1.0)))
18253 return SDValue();
18254
18255 // Exit early if the target does not want this transform or if there can't
18256 // possibly be enough uses of the divisor to make the transform worthwhile.
18257 unsigned MinUses = TLI.combineRepeatedFPDivisors();
18258
18259 // For splat vectors, scale the number of uses by the splat factor. If we can
18260 // convert the division into a scalar op, that will likely be much faster.
18261 unsigned NumElts = 1;
18262 EVT VT = N->getValueType(0);
18263 if (VT.isVector() && DAG.isSplatValue(N1))
18264 NumElts = VT.getVectorMinNumElements();
18265
18266 if (!MinUses || (N1->use_size() * NumElts) < MinUses)
18267 return SDValue();
18268
18269 // Find all FDIV users of the same divisor.
18270 // Use a set because duplicates may be present in the user list.
18271 SetVector<SDNode *> Users;
18272 for (auto *U : N1->users()) {
18273 if (U->getOpcode() == ISD::FDIV && U->getOperand(1) == N1) {
18274 // Skip X/sqrt(X) that has not been simplified to sqrt(X) yet.
18275 if (U->getOperand(1).getOpcode() == ISD::FSQRT &&
18276 U->getOperand(0) == U->getOperand(1).getOperand(0) &&
18277 U->getFlags().hasAllowReassociation() &&
18278 U->getFlags().hasNoSignedZeros())
18279 continue;
18280
18281 // This division is eligible for optimization only if global unsafe math
18282 // is enabled or if this division allows reciprocal formation.
18283 if (UnsafeMath || U->getFlags().hasAllowReciprocal())
18284 Users.insert(U);
18285 }
18286 }
18287
18288 // Now that we have the actual number of divisor uses, make sure it meets
18289 // the minimum threshold specified by the target.
18290 if ((Users.size() * NumElts) < MinUses)
18291 return SDValue();
18292
18293 SDLoc DL(N);
18294 SDValue FPOne = DAG.getConstantFP(1.0, DL, VT);
18295 SDValue Reciprocal = DAG.getNode(ISD::FDIV, DL, VT, FPOne, N1, Flags);
18296
18297 // Dividend / Divisor -> Dividend * Reciprocal
18298 for (auto *U : Users) {
18299 SDValue Dividend = U->getOperand(0);
18300 if (Dividend != FPOne) {
18301 SDValue NewNode = DAG.getNode(ISD::FMUL, SDLoc(U), VT, Dividend,
18302 Reciprocal, Flags);
18303 CombineTo(U, NewNode);
18304 } else if (U != Reciprocal.getNode()) {
18305 // In the absence of fast-math-flags, this user node is always the
18306 // same node as Reciprocal, but with FMF they may be different nodes.
18307 CombineTo(U, Reciprocal);
18308 }
18309 }
18310 return SDValue(N, 0); // N was replaced.
18311 }
18312
/// Combine an ISD::FDIV node `(fdiv N0, N1)`.
///
/// Folds are tried in the order written below; the first one that succeeds
/// determines the result. Broadly: generic FP simplification, constant
/// folding, vector/select folds, sharing a reciprocal between divisions by
/// the same divisor, division by a constant turned into a multiplication,
/// reciprocal (square root) estimates when the node allows reciprocals,
/// X/sqrt(X), cancelling a negation on both operands, and finally
/// combineFMulOrFDivWithIntPow2.
///
/// \param N the FDIV node to simplify.
/// \returns the replacement value, or a null SDValue if no fold applies.
SDValue DAGCombiner::visitFDIV(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  EVT VT = N->getValueType(0);
  SDLoc DL(N);
  const TargetOptions &Options = DAG.getTarget().Options;
  SDNodeFlags Flags = N->getFlags();
  SelectionDAG::FlagInserter FlagsInserter(DAG, N);

  if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
    return R;

  // fold (fdiv c1, c2) -> c1/c2
  if (SDValue C = DAG.FoldConstantArithmetic(ISD::FDIV, DL, VT, {N0, N1}))
    return C;

  // fold vector ops
  if (VT.isVector())
    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
      return FoldedVOp;

  if (SDValue NewSel = foldBinOpIntoSelect(N))
    return NewSel;

  // Share one reciprocal between several divisions by the same divisor.
  if (SDValue V = combineRepeatedFPDivisors(N))
    return V;

  // fold (fdiv X, c2) -> (fmul X, 1/c2) if there is no loss in precision, or
  // the loss is acceptable with AllowReciprocal.
  if (auto *N1CFP = isConstOrConstSplatFP(N1, true)) {
    // Compute the reciprocal 1.0 / c2.
    const APFloat &N1APF = N1CFP->getValueAPF();
    APFloat Recip = APFloat::getOne(N1APF.getSemantics());
    APFloat::opStatus st = Recip.divide(N1APF, APFloat::rmNearestTiesToEven);
    // Only do the transform if the reciprocal is a legal fp immediate that
    // isn't too nasty (eg NaN, denormal, ...).
    // An exact, non-denormal reciprocal is always acceptable; an inexact one
    // is accepted only when this node carries the allow-reciprocal flag.
    if (((st == APFloat::opOK && !Recip.isDenormal()) ||
         (st == APFloat::opInexact && Flags.hasAllowReciprocal())) &&
        (!LegalOperations ||
         // FIXME: custom lowering of ConstantFP might fail (see e.g. ARM
         // backend)... we should handle this gracefully after Legalize.
         // TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT) ||
         TLI.isOperationLegal(ISD::ConstantFP, VT) ||
         TLI.isFPImmLegal(Recip, VT, ForCodeSize)))
      return DAG.getNode(ISD::FMUL, DL, VT, N0,
                         DAG.getConstantFP(Recip, DL, VT));
  }

  // Everything in this block replaces the division by an estimate and so
  // requires the allow-reciprocal flag on this node.
  if (Flags.hasAllowReciprocal()) {
    // If this FDIV is part of a reciprocal square root, it may be folded
    // into a target-specific square root estimate instruction.
    if (N1.getOpcode() == ISD::FSQRT) {
      // X / sqrt(Z) -> X * rsqrt(Z)
      if (SDValue RV = buildRsqrtEstimate(N1.getOperand(0), Flags))
        return DAG.getNode(ISD::FMUL, DL, VT, N0, RV);
    } else if (N1.getOpcode() == ISD::FP_EXTEND &&
               N1.getOperand(0).getOpcode() == ISD::FSQRT) {
      // X / fp_extend(sqrt(Z)) -> X * fp_extend(rsqrt(Z))
      if (SDValue RV =
              buildRsqrtEstimate(N1.getOperand(0).getOperand(0), Flags)) {
        RV = DAG.getNode(ISD::FP_EXTEND, SDLoc(N1), VT, RV);
        AddToWorklist(RV.getNode());
        return DAG.getNode(ISD::FMUL, DL, VT, N0, RV);
      }
    } else if (N1.getOpcode() == ISD::FP_ROUND &&
               N1.getOperand(0).getOpcode() == ISD::FSQRT) {
      // X / fp_round(sqrt(Z)) -> X * fp_round(rsqrt(Z)); the second operand
      // of the original FP_ROUND is reused for the new one.
      if (SDValue RV =
              buildRsqrtEstimate(N1.getOperand(0).getOperand(0), Flags)) {
        RV = DAG.getNode(ISD::FP_ROUND, SDLoc(N1), VT, RV, N1.getOperand(1));
        AddToWorklist(RV.getNode());
        return DAG.getNode(ISD::FMUL, DL, VT, N0, RV);
      }
    } else if (N1.getOpcode() == ISD::FMUL) {
      // Look through an FMUL. Even though this won't remove the FDIV directly,
      // it's still worthwhile to get rid of the FSQRT if possible.
      // Sqrt stays null when neither FMUL operand is an FSQRT; Y is then the
      // other FMUL operand.
      SDValue Sqrt, Y;
      if (N1.getOperand(0).getOpcode() == ISD::FSQRT) {
        Sqrt = N1.getOperand(0);
        Y = N1.getOperand(1);
      } else if (N1.getOperand(1).getOpcode() == ISD::FSQRT) {
        Sqrt = N1.getOperand(1);
        Y = N1.getOperand(0);
      }
      if (Sqrt.getNode()) {
        // If the other multiply operand is known positive, pull it into the
        // sqrt. That will eliminate the division if we convert to an estimate.
        // This needs reassociation on both this FDIV and the FMUL, and the
        // FMUL and FSQRT must have no other users.
        if (Flags.hasAllowReassociation() && N1.hasOneUse() &&
            N1->getFlags().hasAllowReassociation() && Sqrt.hasOneUse()) {
          // A is the value whose square gets folded under the square root;
          // it stays null when Y matches neither accepted form.
          SDValue A;
          if (Y.getOpcode() == ISD::FABS && Y.hasOneUse())
            A = Y.getOperand(0);
          else if (Y == Sqrt.getOperand(0))
            A = Y;
          if (A) {
            // X / (fabs(A) * sqrt(Z)) --> X / sqrt(A*A*Z) --> X * rsqrt(A*A*Z)
            // X / (A * sqrt(A))       --> X / sqrt(A*A*A) --> X * rsqrt(A*A*A)
            SDValue AA = DAG.getNode(ISD::FMUL, DL, VT, A, A);
            SDValue AAZ =
                DAG.getNode(ISD::FMUL, DL, VT, AA, Sqrt.getOperand(0));
            if (SDValue Rsqrt = buildRsqrtEstimate(AAZ, Flags))
              return DAG.getNode(ISD::FMUL, DL, VT, N0, Rsqrt);

            // Estimate creation failed. Clean up speculatively created nodes.
            recursivelyDeleteUnusedNodes(AAZ.getNode());
          }
        }

        // We found a FSQRT, so try to make this fold:
        // X / (Y * sqrt(Z)) -> X * (rsqrt(Z) / Y)
        if (SDValue Rsqrt = buildRsqrtEstimate(Sqrt.getOperand(0), Flags)) {
          SDValue Div = DAG.getNode(ISD::FDIV, SDLoc(N1), VT, Rsqrt, Y);
          AddToWorklist(Div.getNode());
          return DAG.getNode(ISD::FMUL, DL, VT, N0, Div);
        }
      }
    }

    // Fold into a reciprocal estimate and multiply instead of a real divide.
    // Only attempted when infinities may be ignored (globally or per node).
    if (Options.NoInfsFPMath || Flags.hasNoInfs())
      if (SDValue RV = BuildDivEstimate(N0, N1, Flags))
        return RV;
  }

  // Fold X/Sqrt(X) -> Sqrt(X)
  // Requires signed zeros to be ignorable and reassociation to be allowed.
  if ((Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros()) &&
      Flags.hasAllowReassociation())
    if (N1.getOpcode() == ISD::FSQRT && N0 == N1.getOperand(0))
      return N1;

  // (fdiv (fneg X), (fneg Y)) -> (fdiv X, Y)
  // Performed only when negating at least one of the operands is reported as
  // cheaper than the original expression.
  TargetLowering::NegatibleCost CostN0 =
      TargetLowering::NegatibleCost::Expensive;
  TargetLowering::NegatibleCost CostN1 =
      TargetLowering::NegatibleCost::Expensive;
  SDValue NegN0 =
      TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize, CostN0);
  if (NegN0) {
    HandleSDNode NegN0Handle(NegN0);
    SDValue NegN1 =
        TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize, CostN1);
    if (NegN1 && (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
                  CostN1 == TargetLowering::NegatibleCost::Cheaper))
      return DAG.getNode(ISD::FDIV, DL, VT, NegN0, NegN1);
  }

  if (SDValue R = combineFMulOrFDivWithIntPow2(N))
    return R;

  return SDValue();
}
18461
visitFREM(SDNode * N)18462 SDValue DAGCombiner::visitFREM(SDNode *N) {
18463 SDValue N0 = N->getOperand(0);
18464 SDValue N1 = N->getOperand(1);
18465 EVT VT = N->getValueType(0);
18466 SDNodeFlags Flags = N->getFlags();
18467 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
18468 SDLoc DL(N);
18469
18470 if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
18471 return R;
18472
18473 // fold (frem c1, c2) -> fmod(c1,c2)
18474 if (SDValue C = DAG.FoldConstantArithmetic(ISD::FREM, DL, VT, {N0, N1}))
18475 return C;
18476
18477 if (SDValue NewSel = foldBinOpIntoSelect(N))
18478 return NewSel;
18479
18480 // Lower frem N0, N1 => x - trunc(N0 / N1) * N1, providing N1 is an integer
18481 // power of 2.
18482 if (!TLI.isOperationLegal(ISD::FREM, VT) &&
18483 TLI.isOperationLegalOrCustom(ISD::FMUL, VT) &&
18484 TLI.isOperationLegalOrCustom(ISD::FDIV, VT) &&
18485 TLI.isOperationLegalOrCustom(ISD::FTRUNC, VT) &&
18486 DAG.isKnownToBeAPowerOfTwoFP(N1)) {
18487 bool NeedsCopySign =
18488 !Flags.hasNoSignedZeros() && !DAG.cannotBeOrderedNegativeFP(N0);
18489 SDValue Div = DAG.getNode(ISD::FDIV, DL, VT, N0, N1);
18490 SDValue Rnd = DAG.getNode(ISD::FTRUNC, DL, VT, Div);
18491 SDValue MLA;
18492 if (TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT)) {
18493 MLA = DAG.getNode(ISD::FMA, DL, VT, DAG.getNode(ISD::FNEG, DL, VT, Rnd),
18494 N1, N0);
18495 } else {
18496 SDValue Mul = DAG.getNode(ISD::FMUL, DL, VT, Rnd, N1);
18497 MLA = DAG.getNode(ISD::FSUB, DL, VT, N0, Mul);
18498 }
18499 return NeedsCopySign ? DAG.getNode(ISD::FCOPYSIGN, DL, VT, MLA, N0) : MLA;
18500 }
18501
18502 return SDValue();
18503 }
18504
visitFSQRT(SDNode * N)18505 SDValue DAGCombiner::visitFSQRT(SDNode *N) {
18506 SDNodeFlags Flags = N->getFlags();
18507 const TargetOptions &Options = DAG.getTarget().Options;
18508
18509 // Require 'ninf' flag since sqrt(+Inf) = +Inf, but the estimation goes as:
18510 // sqrt(+Inf) == rsqrt(+Inf) * +Inf = 0 * +Inf = NaN
18511 if (!Flags.hasApproximateFuncs() ||
18512 (!Options.NoInfsFPMath && !Flags.hasNoInfs()))
18513 return SDValue();
18514
18515 SDValue N0 = N->getOperand(0);
18516 if (TLI.isFsqrtCheap(N0, DAG))
18517 return SDValue();
18518
18519 // FSQRT nodes have flags that propagate to the created nodes.
18520 // TODO: If this is N0/sqrt(N0), and we reach this node before trying to
18521 // transform the fdiv, we may produce a sub-optimal estimate sequence
18522 // because the reciprocal calculation may not have to filter out a
18523 // 0.0 input.
18524 return buildSqrtEstimate(N0, Flags);
18525 }
18526
18527 /// copysign(x, fp_extend(y)) -> copysign(x, y)
18528 /// copysign(x, fp_round(y)) -> copysign(x, y)
18529 /// Operands to the functions are the type of X and Y respectively.
CanCombineFCOPYSIGN_EXTEND_ROUND(EVT XTy,EVT YTy)18530 static inline bool CanCombineFCOPYSIGN_EXTEND_ROUND(EVT XTy, EVT YTy) {
18531 // Always fold no-op FP casts.
18532 if (XTy == YTy)
18533 return true;
18534
18535 // Do not optimize out type conversion of f128 type yet.
18536 // For some targets like x86_64, configuration is changed to keep one f128
18537 // value in one SSE register, but instruction selection cannot handle
18538 // FCOPYSIGN on SSE registers yet.
18539 if (YTy == MVT::f128)
18540 return false;
18541
18542 // Avoid mismatched vector operand types, for better instruction selection.
18543 return !YTy.isVector();
18544 }
18545
CanCombineFCOPYSIGN_EXTEND_ROUND(SDNode * N)18546 static inline bool CanCombineFCOPYSIGN_EXTEND_ROUND(SDNode *N) {
18547 SDValue N1 = N->getOperand(1);
18548 if (N1.getOpcode() != ISD::FP_EXTEND &&
18549 N1.getOpcode() != ISD::FP_ROUND)
18550 return false;
18551 EVT N1VT = N1->getValueType(0);
18552 EVT N1Op0VT = N1->getOperand(0).getValueType();
18553 return CanCombineFCOPYSIGN_EXTEND_ROUND(N1VT, N1Op0VT);
18554 }
18555
visitFCOPYSIGN(SDNode * N)18556 SDValue DAGCombiner::visitFCOPYSIGN(SDNode *N) {
18557 SDValue N0 = N->getOperand(0);
18558 SDValue N1 = N->getOperand(1);
18559 EVT VT = N->getValueType(0);
18560 SDLoc DL(N);
18561
18562 // fold (fcopysign c1, c2) -> fcopysign(c1,c2)
18563 if (SDValue C = DAG.FoldConstantArithmetic(ISD::FCOPYSIGN, DL, VT, {N0, N1}))
18564 return C;
18565
18566 // copysign(x, fp_extend(y)) -> copysign(x, y)
18567 // copysign(x, fp_round(y)) -> copysign(x, y)
18568 if (CanCombineFCOPYSIGN_EXTEND_ROUND(N))
18569 return DAG.getNode(ISD::FCOPYSIGN, DL, VT, N0, N1.getOperand(0));
18570
18571 if (SimplifyDemandedBits(SDValue(N, 0)))
18572 return SDValue(N, 0);
18573
18574 return SDValue();
18575 }
18576
visitFPOW(SDNode * N)18577 SDValue DAGCombiner::visitFPOW(SDNode *N) {
18578 ConstantFPSDNode *ExponentC = isConstOrConstSplatFP(N->getOperand(1));
18579 if (!ExponentC)
18580 return SDValue();
18581 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
18582
18583 // Try to convert x ** (1/3) into cube root.
18584 // TODO: Handle the various flavors of long double.
18585 // TODO: Since we're approximating, we don't need an exact 1/3 exponent.
18586 // Some range near 1/3 should be fine.
18587 EVT VT = N->getValueType(0);
18588 if ((VT == MVT::f32 && ExponentC->getValueAPF().isExactlyValue(1.0f/3.0f)) ||
18589 (VT == MVT::f64 && ExponentC->getValueAPF().isExactlyValue(1.0/3.0))) {
18590 // pow(-0.0, 1/3) = +0.0; cbrt(-0.0) = -0.0.
18591 // pow(-inf, 1/3) = +inf; cbrt(-inf) = -inf.
18592 // pow(-val, 1/3) = nan; cbrt(-val) = -num.
18593 // For regular numbers, rounding may cause the results to differ.
18594 // Therefore, we require { nsz ninf nnan afn } for this transform.
18595 // TODO: We could select out the special cases if we don't have nsz/ninf.
18596 SDNodeFlags Flags = N->getFlags();
18597 if (!Flags.hasNoSignedZeros() || !Flags.hasNoInfs() || !Flags.hasNoNaNs() ||
18598 !Flags.hasApproximateFuncs())
18599 return SDValue();
18600
18601 // Do not create a cbrt() libcall if the target does not have it, and do not
18602 // turn a pow that has lowering support into a cbrt() libcall.
18603 if (!DAG.getLibInfo().has(LibFunc_cbrt) ||
18604 (!DAG.getTargetLoweringInfo().isOperationExpand(ISD::FPOW, VT) &&
18605 DAG.getTargetLoweringInfo().isOperationExpand(ISD::FCBRT, VT)))
18606 return SDValue();
18607
18608 return DAG.getNode(ISD::FCBRT, SDLoc(N), VT, N->getOperand(0));
18609 }
18610
18611 // Try to convert x ** (1/4) and x ** (3/4) into square roots.
18612 // x ** (1/2) is canonicalized to sqrt, so we do not bother with that case.
18613 // TODO: This could be extended (using a target hook) to handle smaller
18614 // power-of-2 fractional exponents.
18615 bool ExponentIs025 = ExponentC->getValueAPF().isExactlyValue(0.25);
18616 bool ExponentIs075 = ExponentC->getValueAPF().isExactlyValue(0.75);
18617 if (ExponentIs025 || ExponentIs075) {
18618 // pow(-0.0, 0.25) = +0.0; sqrt(sqrt(-0.0)) = -0.0.
18619 // pow(-inf, 0.25) = +inf; sqrt(sqrt(-inf)) = NaN.
18620 // pow(-0.0, 0.75) = +0.0; sqrt(-0.0) * sqrt(sqrt(-0.0)) = +0.0.
18621 // pow(-inf, 0.75) = +inf; sqrt(-inf) * sqrt(sqrt(-inf)) = NaN.
18622 // For regular numbers, rounding may cause the results to differ.
18623 // Therefore, we require { nsz ninf afn } for this transform.
18624 // TODO: We could select out the special cases if we don't have nsz/ninf.
18625 SDNodeFlags Flags = N->getFlags();
18626
18627 // We only need no signed zeros for the 0.25 case.
18628 if ((!Flags.hasNoSignedZeros() && ExponentIs025) || !Flags.hasNoInfs() ||
18629 !Flags.hasApproximateFuncs())
18630 return SDValue();
18631
18632 // Don't double the number of libcalls. We are trying to inline fast code.
18633 if (!DAG.getTargetLoweringInfo().isOperationLegalOrCustom(ISD::FSQRT, VT))
18634 return SDValue();
18635
18636 // Assume that libcalls are the smallest code.
18637 // TODO: This restriction should probably be lifted for vectors.
18638 if (ForCodeSize)
18639 return SDValue();
18640
18641 // pow(X, 0.25) --> sqrt(sqrt(X))
18642 SDLoc DL(N);
18643 SDValue Sqrt = DAG.getNode(ISD::FSQRT, DL, VT, N->getOperand(0));
18644 SDValue SqrtSqrt = DAG.getNode(ISD::FSQRT, DL, VT, Sqrt);
18645 if (ExponentIs025)
18646 return SqrtSqrt;
18647 // pow(X, 0.75) --> sqrt(X) * sqrt(sqrt(X))
18648 return DAG.getNode(ISD::FMUL, DL, VT, Sqrt, SqrtSqrt);
18649 }
18650
18651 return SDValue();
18652 }
18653
foldFPToIntToFP(SDNode * N,const SDLoc & DL,SelectionDAG & DAG,const TargetLowering & TLI)18654 static SDValue foldFPToIntToFP(SDNode *N, const SDLoc &DL, SelectionDAG &DAG,
18655 const TargetLowering &TLI) {
18656 // We only do this if the target has legal ftrunc. Otherwise, we'd likely be
18657 // replacing casts with a libcall. We also must be allowed to ignore -0.0
18658 // because FTRUNC will return -0.0 for (-1.0, -0.0), but using integer
18659 // conversions would return +0.0.
18660 // FIXME: We should be able to use node-level FMF here.
18661 // TODO: If strict math, should we use FABS (+ range check for signed cast)?
18662 EVT VT = N->getValueType(0);
18663 if (!TLI.isOperationLegal(ISD::FTRUNC, VT) ||
18664 !DAG.getTarget().Options.NoSignedZerosFPMath)
18665 return SDValue();
18666
18667 // fptosi/fptoui round towards zero, so converting from FP to integer and
18668 // back is the same as an 'ftrunc': [us]itofp (fpto[us]i X) --> ftrunc X
18669 SDValue N0 = N->getOperand(0);
18670 if (N->getOpcode() == ISD::SINT_TO_FP && N0.getOpcode() == ISD::FP_TO_SINT &&
18671 N0.getOperand(0).getValueType() == VT)
18672 return DAG.getNode(ISD::FTRUNC, DL, VT, N0.getOperand(0));
18673
18674 if (N->getOpcode() == ISD::UINT_TO_FP && N0.getOpcode() == ISD::FP_TO_UINT &&
18675 N0.getOperand(0).getValueType() == VT)
18676 return DAG.getNode(ISD::FTRUNC, DL, VT, N0.getOperand(0));
18677
18678 return SDValue();
18679 }
18680
visitSINT_TO_FP(SDNode * N)18681 SDValue DAGCombiner::visitSINT_TO_FP(SDNode *N) {
18682 SDValue N0 = N->getOperand(0);
18683 EVT VT = N->getValueType(0);
18684 EVT OpVT = N0.getValueType();
18685 SDLoc DL(N);
18686
18687 // [us]itofp(undef) = 0, because the result value is bounded.
18688 if (N0.isUndef())
18689 return DAG.getConstantFP(0.0, DL, VT);
18690
18691 // fold (sint_to_fp c1) -> c1fp
18692 // ...but only if the target supports immediate floating-point values
18693 if ((!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT)))
18694 if (SDValue C = DAG.FoldConstantArithmetic(ISD::SINT_TO_FP, DL, VT, {N0}))
18695 return C;
18696
18697 // If the input is a legal type, and SINT_TO_FP is not legal on this target,
18698 // but UINT_TO_FP is legal on this target, try to convert.
18699 if (!hasOperation(ISD::SINT_TO_FP, OpVT) &&
18700 hasOperation(ISD::UINT_TO_FP, OpVT)) {
18701 // If the sign bit is known to be zero, we can change this to UINT_TO_FP.
18702 if (DAG.SignBitIsZero(N0))
18703 return DAG.getNode(ISD::UINT_TO_FP, DL, VT, N0);
18704 }
18705
18706 // The next optimizations are desirable only if SELECT_CC can be lowered.
18707 // fold (sint_to_fp (setcc x, y, cc)) -> (select (setcc x, y, cc), -1.0, 0.0)
18708 if (N0.getOpcode() == ISD::SETCC && N0.getValueType() == MVT::i1 &&
18709 !VT.isVector() &&
18710 (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT)))
18711 return DAG.getSelect(DL, VT, N0, DAG.getConstantFP(-1.0, DL, VT),
18712 DAG.getConstantFP(0.0, DL, VT));
18713
18714 // fold (sint_to_fp (zext (setcc x, y, cc))) ->
18715 // (select (setcc x, y, cc), 1.0, 0.0)
18716 if (N0.getOpcode() == ISD::ZERO_EXTEND &&
18717 N0.getOperand(0).getOpcode() == ISD::SETCC && !VT.isVector() &&
18718 (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT)))
18719 return DAG.getSelect(DL, VT, N0.getOperand(0),
18720 DAG.getConstantFP(1.0, DL, VT),
18721 DAG.getConstantFP(0.0, DL, VT));
18722
18723 if (SDValue FTrunc = foldFPToIntToFP(N, DL, DAG, TLI))
18724 return FTrunc;
18725
18726 return SDValue();
18727 }
18728
visitUINT_TO_FP(SDNode * N)18729 SDValue DAGCombiner::visitUINT_TO_FP(SDNode *N) {
18730 SDValue N0 = N->getOperand(0);
18731 EVT VT = N->getValueType(0);
18732 EVT OpVT = N0.getValueType();
18733 SDLoc DL(N);
18734
18735 // [us]itofp(undef) = 0, because the result value is bounded.
18736 if (N0.isUndef())
18737 return DAG.getConstantFP(0.0, DL, VT);
18738
18739 // fold (uint_to_fp c1) -> c1fp
18740 // ...but only if the target supports immediate floating-point values
18741 if ((!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT)))
18742 if (SDValue C = DAG.FoldConstantArithmetic(ISD::UINT_TO_FP, DL, VT, {N0}))
18743 return C;
18744
18745 // If the input is a legal type, and UINT_TO_FP is not legal on this target,
18746 // but SINT_TO_FP is legal on this target, try to convert.
18747 if (!hasOperation(ISD::UINT_TO_FP, OpVT) &&
18748 hasOperation(ISD::SINT_TO_FP, OpVT)) {
18749 // If the sign bit is known to be zero, we can change this to SINT_TO_FP.
18750 if (DAG.SignBitIsZero(N0))
18751 return DAG.getNode(ISD::SINT_TO_FP, DL, VT, N0);
18752 }
18753
18754 // fold (uint_to_fp (setcc x, y, cc)) -> (select (setcc x, y, cc), 1.0, 0.0)
18755 if (N0.getOpcode() == ISD::SETCC && !VT.isVector() &&
18756 (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT)))
18757 return DAG.getSelect(DL, VT, N0, DAG.getConstantFP(1.0, DL, VT),
18758 DAG.getConstantFP(0.0, DL, VT));
18759
18760 if (SDValue FTrunc = foldFPToIntToFP(N, DL, DAG, TLI))
18761 return FTrunc;
18762
18763 return SDValue();
18764 }
18765
18766 // Fold (fp_to_{s/u}int ({s/u}int_to_fpx)) -> zext x, sext x, trunc x, or x
FoldIntToFPToInt(SDNode * N,const SDLoc & DL,SelectionDAG & DAG)18767 static SDValue FoldIntToFPToInt(SDNode *N, const SDLoc &DL, SelectionDAG &DAG) {
18768 SDValue N0 = N->getOperand(0);
18769 EVT VT = N->getValueType(0);
18770
18771 if (N0.getOpcode() != ISD::UINT_TO_FP && N0.getOpcode() != ISD::SINT_TO_FP)
18772 return SDValue();
18773
18774 SDValue Src = N0.getOperand(0);
18775 EVT SrcVT = Src.getValueType();
18776 bool IsInputSigned = N0.getOpcode() == ISD::SINT_TO_FP;
18777 bool IsOutputSigned = N->getOpcode() == ISD::FP_TO_SINT;
18778
18779 // We can safely assume the conversion won't overflow the output range,
18780 // because (for example) (uint8_t)18293.f is undefined behavior.
18781
18782 // Since we can assume the conversion won't overflow, our decision as to
18783 // whether the input will fit in the float should depend on the minimum
18784 // of the input range and output range.
18785
18786 // This means this is also safe for a signed input and unsigned output, since
18787 // a negative input would lead to undefined behavior.
18788 unsigned InputSize = (int)SrcVT.getScalarSizeInBits() - IsInputSigned;
18789 unsigned OutputSize = (int)VT.getScalarSizeInBits();
18790 unsigned ActualSize = std::min(InputSize, OutputSize);
18791 const fltSemantics &Sem = N0.getValueType().getFltSemantics();
18792
18793 // We can only fold away the float conversion if the input range can be
18794 // represented exactly in the float range.
18795 if (APFloat::semanticsPrecision(Sem) >= ActualSize) {
18796 if (VT.getScalarSizeInBits() > SrcVT.getScalarSizeInBits()) {
18797 unsigned ExtOp =
18798 IsInputSigned && IsOutputSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
18799 return DAG.getNode(ExtOp, DL, VT, Src);
18800 }
18801 if (VT.getScalarSizeInBits() < SrcVT.getScalarSizeInBits())
18802 return DAG.getNode(ISD::TRUNCATE, DL, VT, Src);
18803 return DAG.getBitcast(VT, Src);
18804 }
18805 return SDValue();
18806 }
18807
visitFP_TO_SINT(SDNode * N)18808 SDValue DAGCombiner::visitFP_TO_SINT(SDNode *N) {
18809 SDValue N0 = N->getOperand(0);
18810 EVT VT = N->getValueType(0);
18811 SDLoc DL(N);
18812
18813 // fold (fp_to_sint undef) -> undef
18814 if (N0.isUndef())
18815 return DAG.getUNDEF(VT);
18816
18817 // fold (fp_to_sint c1fp) -> c1
18818 if (SDValue C = DAG.FoldConstantArithmetic(ISD::FP_TO_SINT, DL, VT, {N0}))
18819 return C;
18820
18821 return FoldIntToFPToInt(N, DL, DAG);
18822 }
18823
visitFP_TO_UINT(SDNode * N)18824 SDValue DAGCombiner::visitFP_TO_UINT(SDNode *N) {
18825 SDValue N0 = N->getOperand(0);
18826 EVT VT = N->getValueType(0);
18827 SDLoc DL(N);
18828
18829 // fold (fp_to_uint undef) -> undef
18830 if (N0.isUndef())
18831 return DAG.getUNDEF(VT);
18832
18833 // fold (fp_to_uint c1fp) -> c1
18834 if (SDValue C = DAG.FoldConstantArithmetic(ISD::FP_TO_UINT, DL, VT, {N0}))
18835 return C;
18836
18837 return FoldIntToFPToInt(N, DL, DAG);
18838 }
18839
visitXROUND(SDNode * N)18840 SDValue DAGCombiner::visitXROUND(SDNode *N) {
18841 SDValue N0 = N->getOperand(0);
18842 EVT VT = N->getValueType(0);
18843
18844 // fold (lrint|llrint undef) -> undef
18845 // fold (lround|llround undef) -> undef
18846 if (N0.isUndef())
18847 return DAG.getUNDEF(VT);
18848
18849 // fold (lrint|llrint c1fp) -> c1
18850 // fold (lround|llround c1fp) -> c1
18851 if (SDValue C =
18852 DAG.FoldConstantArithmetic(N->getOpcode(), SDLoc(N), VT, {N0}))
18853 return C;
18854
18855 return SDValue();
18856 }
18857
visitFP_ROUND(SDNode * N)18858 SDValue DAGCombiner::visitFP_ROUND(SDNode *N) {
18859 SDValue N0 = N->getOperand(0);
18860 SDValue N1 = N->getOperand(1);
18861 EVT VT = N->getValueType(0);
18862 SDLoc DL(N);
18863
18864 // fold (fp_round c1fp) -> c1fp
18865 if (SDValue C = DAG.FoldConstantArithmetic(ISD::FP_ROUND, DL, VT, {N0, N1}))
18866 return C;
18867
18868 // fold (fp_round (fp_extend x)) -> x
18869 if (N0.getOpcode() == ISD::FP_EXTEND && VT == N0.getOperand(0).getValueType())
18870 return N0.getOperand(0);
18871
18872 // fold (fp_round (fp_round x)) -> (fp_round x)
18873 if (N0.getOpcode() == ISD::FP_ROUND) {
18874 const bool NIsTrunc = N->getConstantOperandVal(1) == 1;
18875 const bool N0IsTrunc = N0.getConstantOperandVal(1) == 1;
18876
18877 // Avoid folding legal fp_rounds into non-legal ones.
18878 if (!hasOperation(ISD::FP_ROUND, VT))
18879 return SDValue();
18880
18881 // Skip this folding if it results in an fp_round from f80 to f16.
18882 //
18883 // f80 to f16 always generates an expensive (and as yet, unimplemented)
18884 // libcall to __truncxfhf2 instead of selecting native f16 conversion
18885 // instructions from f32 or f64. Moreover, the first (value-preserving)
18886 // fp_round from f80 to either f32 or f64 may become a NOP in platforms like
18887 // x86.
18888 if (N0.getOperand(0).getValueType() == MVT::f80 && VT == MVT::f16)
18889 return SDValue();
18890
18891 // If the first fp_round isn't a value preserving truncation, it might
18892 // introduce a tie in the second fp_round, that wouldn't occur in the
18893 // single-step fp_round we want to fold to.
18894 // In other words, double rounding isn't the same as rounding.
18895 // Also, this is a value preserving truncation iff both fp_round's are.
18896 if (DAG.getTarget().Options.UnsafeFPMath || N0IsTrunc)
18897 return DAG.getNode(
18898 ISD::FP_ROUND, DL, VT, N0.getOperand(0),
18899 DAG.getIntPtrConstant(NIsTrunc && N0IsTrunc, DL, /*isTarget=*/true));
18900 }
18901
18902 // fold (fp_round (copysign X, Y)) -> (copysign (fp_round X), Y)
18903 // Note: From a legality perspective, this is a two step transform. First,
18904 // we duplicate the fp_round to the arguments of the copysign, then we
18905 // eliminate the fp_round on Y. The second step requires an additional
18906 // predicate to match the implementation above.
18907 if (N0.getOpcode() == ISD::FCOPYSIGN && N0->hasOneUse() &&
18908 CanCombineFCOPYSIGN_EXTEND_ROUND(VT,
18909 N0.getValueType())) {
18910 SDValue Tmp = DAG.getNode(ISD::FP_ROUND, SDLoc(N0), VT,
18911 N0.getOperand(0), N1);
18912 AddToWorklist(Tmp.getNode());
18913 return DAG.getNode(ISD::FCOPYSIGN, DL, VT, Tmp, N0.getOperand(1));
18914 }
18915
18916 if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
18917 return NewVSel;
18918
18919 return SDValue();
18920 }
18921
18922 // Eliminate a floating-point widening of a narrowed value if the fast math
18923 // flags allow it.
eliminateFPCastPair(SDNode * N)18924 static SDValue eliminateFPCastPair(SDNode *N) {
18925 SDValue N0 = N->getOperand(0);
18926 EVT VT = N->getValueType(0);
18927
18928 unsigned NarrowingOp;
18929 switch (N->getOpcode()) {
18930 case ISD::FP16_TO_FP:
18931 NarrowingOp = ISD::FP_TO_FP16;
18932 break;
18933 case ISD::BF16_TO_FP:
18934 NarrowingOp = ISD::FP_TO_BF16;
18935 break;
18936 case ISD::FP_EXTEND:
18937 NarrowingOp = ISD::FP_ROUND;
18938 break;
18939 default:
18940 llvm_unreachable("Expected widening FP cast");
18941 }
18942
18943 if (N0.getOpcode() == NarrowingOp && N0.getOperand(0).getValueType() == VT) {
18944 const SDNodeFlags NarrowFlags = N0->getFlags();
18945 const SDNodeFlags WidenFlags = N->getFlags();
18946 // Narrowing can introduce inf and change the encoding of a nan, so the
18947 // widen must have the nnan and ninf flags to indicate that we don't need to
18948 // care about that. We are also removing a rounding step, and that requires
18949 // both the narrow and widen to allow contraction.
18950 if (WidenFlags.hasNoNaNs() && WidenFlags.hasNoInfs() &&
18951 NarrowFlags.hasAllowContract() && WidenFlags.hasAllowContract()) {
18952 return N0.getOperand(0);
18953 }
18954 }
18955
18956 return SDValue();
18957 }
18958
visitFP_EXTEND(SDNode * N)18959 SDValue DAGCombiner::visitFP_EXTEND(SDNode *N) {
18960 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
18961 SDValue N0 = N->getOperand(0);
18962 EVT VT = N->getValueType(0);
18963 SDLoc DL(N);
18964
18965 if (VT.isVector())
18966 if (SDValue FoldedVOp = SimplifyVCastOp(N, DL))
18967 return FoldedVOp;
18968
18969 // If this is fp_round(fpextend), don't fold it, allow ourselves to be folded.
18970 if (N->hasOneUse() && N->user_begin()->getOpcode() == ISD::FP_ROUND)
18971 return SDValue();
18972
18973 // fold (fp_extend c1fp) -> c1fp
18974 if (SDValue C = DAG.FoldConstantArithmetic(ISD::FP_EXTEND, DL, VT, {N0}))
18975 return C;
18976
18977 // fold (fp_extend (fp16_to_fp op)) -> (fp16_to_fp op)
18978 if (N0.getOpcode() == ISD::FP16_TO_FP &&
18979 TLI.getOperationAction(ISD::FP16_TO_FP, VT) == TargetLowering::Legal)
18980 return DAG.getNode(ISD::FP16_TO_FP, DL, VT, N0.getOperand(0));
18981
18982 // Turn fp_extend(fp_round(X, 1)) -> x since the fp_round doesn't affect the
18983 // value of X.
18984 if (N0.getOpcode() == ISD::FP_ROUND && N0.getConstantOperandVal(1) == 1) {
18985 SDValue In = N0.getOperand(0);
18986 if (In.getValueType() == VT) return In;
18987 if (VT.bitsLT(In.getValueType()))
18988 return DAG.getNode(ISD::FP_ROUND, DL, VT, In, N0.getOperand(1));
18989 return DAG.getNode(ISD::FP_EXTEND, DL, VT, In);
18990 }
18991
18992 // fold (fpext (load x)) -> (fpext (fptrunc (extload x)))
18993 if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
18994 TLI.isLoadExtLegalOrCustom(ISD::EXTLOAD, VT, N0.getValueType())) {
18995 LoadSDNode *LN0 = cast<LoadSDNode>(N0);
18996 SDValue ExtLoad = DAG.getExtLoad(ISD::EXTLOAD, DL, VT,
18997 LN0->getChain(),
18998 LN0->getBasePtr(), N0.getValueType(),
18999 LN0->getMemOperand());
19000 CombineTo(N, ExtLoad);
19001 CombineTo(
19002 N0.getNode(),
19003 DAG.getNode(ISD::FP_ROUND, SDLoc(N0), N0.getValueType(), ExtLoad,
19004 DAG.getIntPtrConstant(1, SDLoc(N0), /*isTarget=*/true)),
19005 ExtLoad.getValue(1));
19006 return SDValue(N, 0); // Return N so it doesn't get rechecked!
19007 }
19008
19009 if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
19010 return NewVSel;
19011
19012 if (SDValue CastEliminated = eliminateFPCastPair(N))
19013 return CastEliminated;
19014
19015 return SDValue();
19016 }
19017
visitFCEIL(SDNode * N)19018 SDValue DAGCombiner::visitFCEIL(SDNode *N) {
19019 SDValue N0 = N->getOperand(0);
19020 EVT VT = N->getValueType(0);
19021
19022 // fold (fceil c1) -> fceil(c1)
19023 if (SDValue C = DAG.FoldConstantArithmetic(ISD::FCEIL, SDLoc(N), VT, {N0}))
19024 return C;
19025
19026 return SDValue();
19027 }
19028
visitFTRUNC(SDNode * N)19029 SDValue DAGCombiner::visitFTRUNC(SDNode *N) {
19030 SDValue N0 = N->getOperand(0);
19031 EVT VT = N->getValueType(0);
19032
19033 // fold (ftrunc c1) -> ftrunc(c1)
19034 if (SDValue C = DAG.FoldConstantArithmetic(ISD::FTRUNC, SDLoc(N), VT, {N0}))
19035 return C;
19036
19037 // fold ftrunc (known rounded int x) -> x
19038 // ftrunc is a part of fptosi/fptoui expansion on some targets, so this is
19039 // likely to be generated to extract integer from a rounded floating value.
19040 switch (N0.getOpcode()) {
19041 default: break;
19042 case ISD::FRINT:
19043 case ISD::FTRUNC:
19044 case ISD::FNEARBYINT:
19045 case ISD::FROUNDEVEN:
19046 case ISD::FFLOOR:
19047 case ISD::FCEIL:
19048 return N0;
19049 }
19050
19051 return SDValue();
19052 }
19053
visitFFREXP(SDNode * N)19054 SDValue DAGCombiner::visitFFREXP(SDNode *N) {
19055 SDValue N0 = N->getOperand(0);
19056
19057 // fold (ffrexp c1) -> ffrexp(c1)
19058 if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
19059 return DAG.getNode(ISD::FFREXP, SDLoc(N), N->getVTList(), N0);
19060 return SDValue();
19061 }
19062
visitFFLOOR(SDNode * N)19063 SDValue DAGCombiner::visitFFLOOR(SDNode *N) {
19064 SDValue N0 = N->getOperand(0);
19065 EVT VT = N->getValueType(0);
19066
19067 // fold (ffloor c1) -> ffloor(c1)
19068 if (SDValue C = DAG.FoldConstantArithmetic(ISD::FFLOOR, SDLoc(N), VT, {N0}))
19069 return C;
19070
19071 return SDValue();
19072 }
19073
visitFNEG(SDNode * N)19074 SDValue DAGCombiner::visitFNEG(SDNode *N) {
19075 SDValue N0 = N->getOperand(0);
19076 EVT VT = N->getValueType(0);
19077 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
19078
19079 // Constant fold FNEG.
19080 if (SDValue C = DAG.FoldConstantArithmetic(ISD::FNEG, SDLoc(N), VT, {N0}))
19081 return C;
19082
19083 if (SDValue NegN0 =
19084 TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize))
19085 return NegN0;
19086
19087 // -(X-Y) -> (Y-X) is unsafe because when X==Y, -0.0 != +0.0
19088 // FIXME: This is duplicated in getNegatibleCost, but getNegatibleCost doesn't
19089 // know it was called from a context with a nsz flag if the input fsub does
19090 // not.
19091 if (N0.getOpcode() == ISD::FSUB &&
19092 (DAG.getTarget().Options.NoSignedZerosFPMath ||
19093 N->getFlags().hasNoSignedZeros()) && N0.hasOneUse()) {
19094 return DAG.getNode(ISD::FSUB, SDLoc(N), VT, N0.getOperand(1),
19095 N0.getOperand(0));
19096 }
19097
19098 if (SimplifyDemandedBits(SDValue(N, 0)))
19099 return SDValue(N, 0);
19100
19101 if (SDValue Cast = foldSignChangeInBitcast(N))
19102 return Cast;
19103
19104 return SDValue();
19105 }
19106
visitFMinMax(SDNode * N)19107 SDValue DAGCombiner::visitFMinMax(SDNode *N) {
19108 SDValue N0 = N->getOperand(0);
19109 SDValue N1 = N->getOperand(1);
19110 EVT VT = N->getValueType(0);
19111 const SDNodeFlags Flags = N->getFlags();
19112 unsigned Opc = N->getOpcode();
19113 bool PropagatesNaN = Opc == ISD::FMINIMUM || Opc == ISD::FMAXIMUM;
19114 bool IsMin = Opc == ISD::FMINNUM || Opc == ISD::FMINIMUM;
19115 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
19116
19117 // Constant fold.
19118 if (SDValue C = DAG.FoldConstantArithmetic(Opc, SDLoc(N), VT, {N0, N1}))
19119 return C;
19120
19121 // Canonicalize to constant on RHS.
19122 if (DAG.isConstantFPBuildVectorOrConstantFP(N0) &&
19123 !DAG.isConstantFPBuildVectorOrConstantFP(N1))
19124 return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N1, N0);
19125
19126 if (const ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1)) {
19127 const APFloat &AF = N1CFP->getValueAPF();
19128
19129 // minnum(X, nan) -> X
19130 // maxnum(X, nan) -> X
19131 // minimum(X, nan) -> nan
19132 // maximum(X, nan) -> nan
19133 if (AF.isNaN())
19134 return PropagatesNaN ? N->getOperand(1) : N->getOperand(0);
19135
19136 // In the following folds, inf can be replaced with the largest finite
19137 // float, if the ninf flag is set.
19138 if (AF.isInfinity() || (Flags.hasNoInfs() && AF.isLargest())) {
19139 // minnum(X, -inf) -> -inf
19140 // maxnum(X, +inf) -> +inf
19141 // minimum(X, -inf) -> -inf if nnan
19142 // maximum(X, +inf) -> +inf if nnan
19143 if (IsMin == AF.isNegative() && (!PropagatesNaN || Flags.hasNoNaNs()))
19144 return N->getOperand(1);
19145
19146 // minnum(X, +inf) -> X if nnan
19147 // maxnum(X, -inf) -> X if nnan
19148 // minimum(X, +inf) -> X
19149 // maximum(X, -inf) -> X
19150 if (IsMin != AF.isNegative() && (PropagatesNaN || Flags.hasNoNaNs()))
19151 return N->getOperand(0);
19152 }
19153 }
19154
19155 if (SDValue SD = reassociateReduction(
19156 PropagatesNaN
19157 ? (IsMin ? ISD::VECREDUCE_FMINIMUM : ISD::VECREDUCE_FMAXIMUM)
19158 : (IsMin ? ISD::VECREDUCE_FMIN : ISD::VECREDUCE_FMAX),
19159 Opc, SDLoc(N), VT, N0, N1, Flags))
19160 return SD;
19161
19162 return SDValue();
19163 }
19164
visitFABS(SDNode * N)19165 SDValue DAGCombiner::visitFABS(SDNode *N) {
19166 SDValue N0 = N->getOperand(0);
19167 EVT VT = N->getValueType(0);
19168 SDLoc DL(N);
19169
19170 // fold (fabs c1) -> fabs(c1)
19171 if (SDValue C = DAG.FoldConstantArithmetic(ISD::FABS, DL, VT, {N0}))
19172 return C;
19173
19174 if (SimplifyDemandedBits(SDValue(N, 0)))
19175 return SDValue(N, 0);
19176
19177 if (SDValue Cast = foldSignChangeInBitcast(N))
19178 return Cast;
19179
19180 return SDValue();
19181 }
19182
visitBRCOND(SDNode * N)19183 SDValue DAGCombiner::visitBRCOND(SDNode *N) {
19184 SDValue Chain = N->getOperand(0);
19185 SDValue N1 = N->getOperand(1);
19186 SDValue N2 = N->getOperand(2);
19187
19188 // BRCOND(FREEZE(cond)) is equivalent to BRCOND(cond) (both are
19189 // nondeterministic jumps).
19190 if (N1->getOpcode() == ISD::FREEZE && N1.hasOneUse()) {
19191 return DAG.getNode(ISD::BRCOND, SDLoc(N), MVT::Other, Chain,
19192 N1->getOperand(0), N2, N->getFlags());
19193 }
19194
19195 // Variant of the previous fold where there is a SETCC in between:
19196 // BRCOND(SETCC(FREEZE(X), CONST, Cond))
19197 // =>
19198 // BRCOND(FREEZE(SETCC(X, CONST, Cond)))
19199 // =>
19200 // BRCOND(SETCC(X, CONST, Cond))
19201 // This is correct if FREEZE(X) has one use and SETCC(FREEZE(X), CONST, Cond)
19202 // isn't equivalent to true or false.
19203 // For example, SETCC(FREEZE(X), -128, SETULT) cannot be folded to
19204 // FREEZE(SETCC(X, -128, SETULT)) because X can be poison.
19205 if (N1->getOpcode() == ISD::SETCC && N1.hasOneUse()) {
19206 SDValue S0 = N1->getOperand(0), S1 = N1->getOperand(1);
19207 ISD::CondCode Cond = cast<CondCodeSDNode>(N1->getOperand(2))->get();
19208 ConstantSDNode *S0C = dyn_cast<ConstantSDNode>(S0);
19209 ConstantSDNode *S1C = dyn_cast<ConstantSDNode>(S1);
19210 bool Updated = false;
19211
19212 // Is 'X Cond C' always true or false?
19213 auto IsAlwaysTrueOrFalse = [](ISD::CondCode Cond, ConstantSDNode *C) {
19214 bool False = (Cond == ISD::SETULT && C->isZero()) ||
19215 (Cond == ISD::SETLT && C->isMinSignedValue()) ||
19216 (Cond == ISD::SETUGT && C->isAllOnes()) ||
19217 (Cond == ISD::SETGT && C->isMaxSignedValue());
19218 bool True = (Cond == ISD::SETULE && C->isAllOnes()) ||
19219 (Cond == ISD::SETLE && C->isMaxSignedValue()) ||
19220 (Cond == ISD::SETUGE && C->isZero()) ||
19221 (Cond == ISD::SETGE && C->isMinSignedValue());
19222 return True || False;
19223 };
19224
19225 if (S0->getOpcode() == ISD::FREEZE && S0.hasOneUse() && S1C) {
19226 if (!IsAlwaysTrueOrFalse(Cond, S1C)) {
19227 S0 = S0->getOperand(0);
19228 Updated = true;
19229 }
19230 }
19231 if (S1->getOpcode() == ISD::FREEZE && S1.hasOneUse() && S0C) {
19232 if (!IsAlwaysTrueOrFalse(ISD::getSetCCSwappedOperands(Cond), S0C)) {
19233 S1 = S1->getOperand(0);
19234 Updated = true;
19235 }
19236 }
19237
19238 if (Updated)
19239 return DAG.getNode(
19240 ISD::BRCOND, SDLoc(N), MVT::Other, Chain,
19241 DAG.getSetCC(SDLoc(N1), N1->getValueType(0), S0, S1, Cond), N2,
19242 N->getFlags());
19243 }
19244
19245 // If N is a constant we could fold this into a fallthrough or unconditional
19246 // branch. However that doesn't happen very often in normal code, because
19247 // Instcombine/SimplifyCFG should have handled the available opportunities.
19248 // If we did this folding here, it would be necessary to update the
19249 // MachineBasicBlock CFG, which is awkward.
19250
19251 // fold a brcond with a setcc condition into a BR_CC node if BR_CC is legal
19252 // on the target.
19253 if (N1.getOpcode() == ISD::SETCC &&
19254 TLI.isOperationLegalOrCustom(ISD::BR_CC,
19255 N1.getOperand(0).getValueType())) {
19256 return DAG.getNode(ISD::BR_CC, SDLoc(N), MVT::Other,
19257 Chain, N1.getOperand(2),
19258 N1.getOperand(0), N1.getOperand(1), N2);
19259 }
19260
19261 if (N1.hasOneUse()) {
19262 // rebuildSetCC calls visitXor which may change the Chain when there is a
19263 // STRICT_FSETCC/STRICT_FSETCCS involved. Use a handle to track changes.
19264 HandleSDNode ChainHandle(Chain);
19265 if (SDValue NewN1 = rebuildSetCC(N1))
19266 return DAG.getNode(ISD::BRCOND, SDLoc(N), MVT::Other,
19267 ChainHandle.getValue(), NewN1, N2, N->getFlags());
19268 }
19269
19270 return SDValue();
19271 }
19272
/// Try to re-express the branch condition \p N as a SETCC (or as a simpler
/// value).
///
/// Two shapes are recognised: a single-bit test written as
/// (srl (and X, Pow2), log2(Pow2)), optionally behind a TRUNCATE, and an XOR,
/// which is first simplified through visitXOR. Returns the replacement
/// condition, or a null SDValue when neither shape matches.
SDValue DAGCombiner::rebuildSetCC(SDValue N) {
  if (N.getOpcode() == ISD::SRL ||
      (N.getOpcode() == ISD::TRUNCATE &&
       (N.getOperand(0).hasOneUse() &&
        N.getOperand(0).getOpcode() == ISD::SRL))) {
    // Look past the truncate.
    if (N.getOpcode() == ISD::TRUNCATE)
      N = N.getOperand(0);

    // Match this pattern so that we can generate simpler code:
    //
    //   %a = ...
    //   %b = and i32 %a, 2
    //   %c = srl i32 %b, 1
    //   brcond i32 %c ...
    //
    // into
    //
    //   %a = ...
    //   %b = and i32 %a, 2
    //   %c = setcc ne %b, 0
    //   brcond %c ...
    //
    // This applies only when the AND constant value has one bit set and the
    // SRL constant is equal to the log2 of the AND constant. The back-end is
    // smart enough to convert the result into a TEST/JMP sequence.
    SDValue Op0 = N.getOperand(0);
    SDValue Op1 = N.getOperand(1);

    if (Op0.getOpcode() == ISD::AND && Op1.getOpcode() == ISD::Constant) {
      SDValue AndOp1 = Op0.getOperand(1);

      if (AndOp1.getOpcode() == ISD::Constant) {
        const APInt &AndConst = AndOp1->getAsAPIntVal();

        // The shifted value is non-zero exactly when the masked value is
        // non-zero, hence the SETNE against zero.
        if (AndConst.isPowerOf2() &&
            Op1->getAsAPIntVal() == AndConst.logBase2()) {
          SDLoc DL(N);
          return DAG.getSetCC(DL, getSetCCResultType(Op0.getValueType()),
                              Op0, DAG.getConstant(0, DL, Op0.getValueType()),
                              ISD::SETNE);
        }
      }
    }
  }

  // Transform (brcond (xor x, y)) -> (brcond (setcc, x, y, ne))
  // Transform (brcond (xor (xor x, y), -1)) -> (brcond (setcc, x, y, eq))
  if (N.getOpcode() == ISD::XOR) {
    // Because we may call this on a speculatively constructed
    // SimplifiedSetCC Node, we need to simplify this node first.
    // Ideally this should be folded into SimplifySetCC and not
    // here. For now, grab a handle to N so we don't lose it from
    // replacements internal to the visit.
    HandleSDNode XORHandle(N);
    while (N.getOpcode() == ISD::XOR) {
      SDValue Tmp = visitXOR(N.getNode());
      // No simplification done.
      if (!Tmp.getNode())
        break;
      // visitXOR returning N itself signals an in-visit replacement that may
      // have invalidated N. Grab the current value from the handle instead.
      if (Tmp.getNode() == N.getNode())
        N = XORHandle.getValue();
      else // Node simplified. Try simplifying again.
        N = Tmp;
    }

    // The XOR simplified into something else entirely; use that as the new
    // condition.
    if (N.getOpcode() != ISD::XOR)
      return N;

    SDValue Op0 = N->getOperand(0);
    SDValue Op1 = N->getOperand(1);

    // Only rewrite when neither XOR operand is itself a SETCC.
    if (Op0.getOpcode() != ISD::SETCC && Op1.getOpcode() != ISD::SETCC) {
      bool Equal = false;
      // (brcond (xor (xor x, y), -1)) -> (brcond (setcc x, y, eq))
      if (isBitwiseNot(N) && Op0.hasOneUse() && Op0.getOpcode() == ISD::XOR &&
          Op0.getValueType() == MVT::i1) {
        N = Op0;
        Op0 = N->getOperand(0);
        Op1 = N->getOperand(1);
        Equal = true;
      }

      EVT SetCCVT = N.getValueType();
      if (LegalTypes)
        SetCCVT = getSetCCResultType(SetCCVT);
      // Replace the uses of XOR with SETCC. Note, avoid this transformation if
      // it would introduce illegal operations post-legalization as this can
      // result in infinite looping between converting xor->setcc here, and
      // expanding setcc->xor in LegalizeSetCCCondCode if requested.
      const ISD::CondCode CC = Equal ? ISD::SETEQ : ISD::SETNE;
      if (!LegalOperations || TLI.isCondCodeLegal(CC, Op0.getSimpleValueType()))
        return DAG.getSetCC(SDLoc(N), SetCCVT, Op0, Op1, CC);
    }
  }

  return SDValue();
}
19373
19374 // Operand List for BR_CC: Chain, CondCC, CondLHS, CondRHS, DestBB.
19375 //
visitBR_CC(SDNode * N)19376 SDValue DAGCombiner::visitBR_CC(SDNode *N) {
19377 CondCodeSDNode *CC = cast<CondCodeSDNode>(N->getOperand(1));
19378 SDValue CondLHS = N->getOperand(2), CondRHS = N->getOperand(3);
19379
19380 // If N is a constant we could fold this into a fallthrough or unconditional
19381 // branch. However that doesn't happen very often in normal code, because
19382 // Instcombine/SimplifyCFG should have handled the available opportunities.
19383 // If we did this folding here, it would be necessary to update the
19384 // MachineBasicBlock CFG, which is awkward.
19385
19386 // Use SimplifySetCC to simplify SETCC's.
19387 SDValue Simp = SimplifySetCC(getSetCCResultType(CondLHS.getValueType()),
19388 CondLHS, CondRHS, CC->get(), SDLoc(N),
19389 false);
19390 if (Simp.getNode()) AddToWorklist(Simp.getNode());
19391
19392 // fold to a simpler setcc
19393 if (Simp.getNode() && Simp.getOpcode() == ISD::SETCC)
19394 return DAG.getNode(ISD::BR_CC, SDLoc(N), MVT::Other,
19395 N->getOperand(0), Simp.getOperand(2),
19396 Simp.getOperand(0), Simp.getOperand(1),
19397 N->getOperand(4));
19398
19399 return SDValue();
19400 }
19401
getCombineLoadStoreParts(SDNode * N,unsigned Inc,unsigned Dec,bool & IsLoad,bool & IsMasked,SDValue & Ptr,const TargetLowering & TLI)19402 static bool getCombineLoadStoreParts(SDNode *N, unsigned Inc, unsigned Dec,
19403 bool &IsLoad, bool &IsMasked, SDValue &Ptr,
19404 const TargetLowering &TLI) {
19405 if (LoadSDNode *LD = dyn_cast<LoadSDNode>(N)) {
19406 if (LD->isIndexed())
19407 return false;
19408 EVT VT = LD->getMemoryVT();
19409 if (!TLI.isIndexedLoadLegal(Inc, VT) && !TLI.isIndexedLoadLegal(Dec, VT))
19410 return false;
19411 Ptr = LD->getBasePtr();
19412 } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(N)) {
19413 if (ST->isIndexed())
19414 return false;
19415 EVT VT = ST->getMemoryVT();
19416 if (!TLI.isIndexedStoreLegal(Inc, VT) && !TLI.isIndexedStoreLegal(Dec, VT))
19417 return false;
19418 Ptr = ST->getBasePtr();
19419 IsLoad = false;
19420 } else if (MaskedLoadSDNode *LD = dyn_cast<MaskedLoadSDNode>(N)) {
19421 if (LD->isIndexed())
19422 return false;
19423 EVT VT = LD->getMemoryVT();
19424 if (!TLI.isIndexedMaskedLoadLegal(Inc, VT) &&
19425 !TLI.isIndexedMaskedLoadLegal(Dec, VT))
19426 return false;
19427 Ptr = LD->getBasePtr();
19428 IsMasked = true;
19429 } else if (MaskedStoreSDNode *ST = dyn_cast<MaskedStoreSDNode>(N)) {
19430 if (ST->isIndexed())
19431 return false;
19432 EVT VT = ST->getMemoryVT();
19433 if (!TLI.isIndexedMaskedStoreLegal(Inc, VT) &&
19434 !TLI.isIndexedMaskedStoreLegal(Dec, VT))
19435 return false;
19436 Ptr = ST->getBasePtr();
19437 IsLoad = false;
19438 IsMasked = true;
19439 } else {
19440 return false;
19441 }
19442 return true;
19443 }
19444
/// Try turning a load/store into a pre-indexed load/store when the base
/// pointer is an add or subtract and it has other uses besides the load/store.
/// After the transformation, the new indexed load/store has effectively folded
/// the add/subtract in and all of its other uses are redirected to the
/// new load/store.
///
/// Only runs once the combiner level is at least AfterLegalizeDAG. Returns
/// true when \p N was replaced (N is deleted in that case) and false when the
/// DAG was left untouched.
bool DAGCombiner::CombineToPreIndexedLoadStore(SDNode *N) {
  if (Level < AfterLegalizeDAG)
    return false;

  // getCombineLoadStoreParts only clears IsLoad / sets IsMasked, so both must
  // start from these defaults.
  bool IsLoad = true;
  bool IsMasked = false;
  SDValue Ptr;
  if (!getCombineLoadStoreParts(N, ISD::PRE_INC, ISD::PRE_DEC, IsLoad, IsMasked,
                                Ptr, TLI))
    return false;

  // If the pointer is not an add/sub, or if it doesn't have multiple uses, bail
  // out.  There is no reason to make this a preinc/predec.
  if ((Ptr.getOpcode() != ISD::ADD && Ptr.getOpcode() != ISD::SUB) ||
      Ptr->hasOneUse())
    return false;

  // Ask the target to do addressing mode selection.
  SDValue BasePtr;
  SDValue Offset;
  ISD::MemIndexedMode AM = ISD::UNINDEXED;
  if (!TLI.getPreIndexedAddressParts(N, BasePtr, Offset, AM, DAG))
    return false;

  // Backends without true r+i pre-indexed forms may need to pass a
  // constant base with a variable offset so that constant coercion
  // will work with the patterns in canonical form.
  // While Swapped is set, BasePtr/Offset hold the exchanged roles; they are
  // swapped back before the indexed node is created.
  bool Swapped = false;
  if (isa<ConstantSDNode>(BasePtr)) {
    std::swap(BasePtr, Offset);
    Swapped = true;
  }

  // Don't create an indexed load / store with zero offset.
  if (isNullConstant(Offset))
    return false;

  // Try turning it into a pre-indexed load / store except when:
  // 1) The new base ptr is a frame index.
  // 2) If N is a store and the new base ptr is either the same as or is a
  //    predecessor of the value being stored.
  // 3) Another use of old base ptr is a predecessor of N. If ptr is folded
  //    that would create a cycle.
  // 4) All uses are load / store ops that use it as old base ptr.

  // Check #1.  Preinc'ing a frame index would require copying the stack pointer
  // (plus the implicit offset) to a register to preinc anyway.
  // A RegisterSDNode base is rejected by the same test.
  if (isa<FrameIndexSDNode>(BasePtr) || isa<RegisterSDNode>(BasePtr))
    return false;

  // Check #2.
  if (!IsLoad) {
    SDValue Val = IsMasked ? cast<MaskedStoreSDNode>(N)->getValue()
                           : cast<StoreSDNode>(N)->getValue();

    // Would require a copy.
    if (Val == BasePtr)
      return false;

    // Would create a cycle.
    if (Val == Ptr || Ptr->isPredecessorOf(Val.getNode()))
      return false;
  }

  // Caches for hasPredecessorHelper.
  // Both are shared by every predecessor query below, all of which search
  // from N.
  SmallPtrSet<const SDNode *, 32> Visited;
  SmallVector<const SDNode *, 16> Worklist;
  Worklist.push_back(N);

  // If the offset is a constant, there may be other adds of constants that
  // can be folded with this one. We should do this to avoid having to keep
  // a copy of the original base pointer.
  SmallVector<SDNode *, 16> OtherUses;
  unsigned MaxSteps = SelectionDAG::getHasPredecessorMaxSteps();
  if (isa<ConstantSDNode>(Offset))
    for (SDUse &Use : BasePtr->uses()) {
      // Skip the use that is Ptr and uses of other results from BasePtr's
      // node (important for nodes that return multiple results).
      if (Use.getUser() == Ptr.getNode() || Use != BasePtr)
        continue;

      // Skip users for which the predecessor query against the worklist
      // seeded with N succeeds.
      if (SDNode::hasPredecessorHelper(Use.getUser(), Visited, Worklist,
                                       MaxSteps))
        continue;

      // Any remaining user that is not an add/sub of a same-typed constant
      // makes the whole rewrite of other uses impossible: drop everything
      // collected so far and stop looking.
      if (Use.getUser()->getOpcode() != ISD::ADD &&
          Use.getUser()->getOpcode() != ISD::SUB) {
        OtherUses.clear();
        break;
      }

      // The operand of the binary user that is not this use of BasePtr.
      SDValue Op1 = Use.getUser()->getOperand((Use.getOperandNo() + 1) & 1);
      if (!isa<ConstantSDNode>(Op1)) {
        OtherUses.clear();
        break;
      }

      // FIXME: In some cases, we can be smarter about this.
      if (Op1.getValueType() != Offset.getValueType()) {
        OtherUses.clear();
        break;
      }

      OtherUses.push_back(Use.getUser());
    }

  // Restore the roles reported by the target before building the new node.
  if (Swapped)
    std::swap(BasePtr, Offset);

  // Now check for #3 and #4.
  bool RealUse = false;

  for (SDNode *User : Ptr->users()) {
    if (User == N)
      continue;
    if (SDNode::hasPredecessorHelper(User, Visited, Worklist, MaxSteps))
      return false;

    // If Ptr may be folded in addressing mode of other use, then it's
    // not profitable to do this transformation.
    if (!canFoldInAddressingMode(Ptr.getNode(), User, DAG, TLI))
      RealUse = true;
  }

  // Every other user could fold Ptr into its own addressing mode (#4).
  if (!RealUse)
    return false;

  SDValue Result;
  if (!IsMasked) {
    if (IsLoad)
      Result = DAG.getIndexedLoad(SDValue(N, 0), SDLoc(N), BasePtr, Offset, AM);
    else
      Result =
          DAG.getIndexedStore(SDValue(N, 0), SDLoc(N), BasePtr, Offset, AM);
  } else {
    if (IsLoad)
      Result = DAG.getIndexedMaskedLoad(SDValue(N, 0), SDLoc(N), BasePtr,
                                        Offset, AM);
    else
      Result = DAG.getIndexedMaskedStore(SDValue(N, 0), SDLoc(N), BasePtr,
                                         Offset, AM);
  }
  ++PreIndexedNodes;
  ++NodesCombined;
  LLVM_DEBUG(dbgs() << "\nReplacing.4 "; N->dump(&DAG); dbgs() << "\nWith: ";
             Result.dump(&DAG); dbgs() << '\n');
  WorklistRemover DeadNodes(*this);
  // Map the old results onto the indexed node. The indexed node has one extra
  // result, the updated pointer, which is consumed further down as
  // Result.getValue(IsLoad ? 1 : 0).
  if (IsLoad) {
    DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(0));
    DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Result.getValue(2));
  } else {
    DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(1));
  }

  // Finally, since the node is now dead, remove it from the graph.
  deleteAndRecombine(N);

  // Swap again so that, as during the collection of OtherUses, BasePtr names
  // the non-constant operand and Offset the constant one.
  if (Swapped)
    std::swap(BasePtr, Offset);

  // Replace other uses of BasePtr that can be updated to use Ptr
  for (SDNode *OtherUse : OtherUses) {
    // Find which operand of OtherUse is the constant offset; the other one
    // must be BasePtr.
    unsigned OffsetIdx = 1;
    if (OtherUse->getOperand(OffsetIdx).getNode() == BasePtr.getNode())
      OffsetIdx = 0;
    assert(OtherUse->getOperand(!OffsetIdx).getNode() == BasePtr.getNode() &&
           "Expected BasePtr operand");

    // We need to replace ptr0 in the following expression:
    //   x0 * offset0 + y0 * ptr0 = t0
    // knowing that
    //   x1 * offset1 + y1 * ptr0 = t1 (the indexed load/store)
    //
    // where x0, x1, y0 and y1 in {-1, 1} are given by the types of the
    // indexed load/store and the expression that needs to be re-written.
    //
    // Therefore, we have:
    //   t0 = (x0 * offset0 - x1 * y0 * y1 *offset1) + (y0 * y1) * t1

    auto *CN = cast<ConstantSDNode>(OtherUse->getOperand(OffsetIdx));
    const APInt &Offset0 = CN->getAPIntValue();
    const APInt &Offset1 = Offset->getAsAPIntVal();
    // A SUB negates whichever operand sits in position 1; PRE_DEC negates the
    // offset, or the pointer when the target reported the roles swapped.
    int X0 = (OtherUse->getOpcode() == ISD::SUB && OffsetIdx == 1) ? -1 : 1;
    int Y0 = (OtherUse->getOpcode() == ISD::SUB && OffsetIdx == 0) ? -1 : 1;
    int X1 = (AM == ISD::PRE_DEC && !Swapped) ? -1 : 1;
    int Y1 = (AM == ISD::PRE_DEC && Swapped) ? -1 : 1;

    unsigned Opcode = (Y0 * Y1 < 0) ? ISD::SUB : ISD::ADD;

    // CNV = x0 * offset0 - x1 * y0 * y1 * offset1
    APInt CNV = Offset0;
    if (X0 < 0) CNV = -CNV;
    if (X1 * Y0 * Y1 < 0) CNV = CNV + Offset1;
    else CNV = CNV - Offset1;

    SDLoc DL(OtherUse);

    // We can now generate the new expression.
    SDValue NewOp1 = DAG.getConstant(CNV, DL, CN->getValueType(0));
    SDValue NewOp2 = Result.getValue(IsLoad ? 1 : 0);

    SDValue NewUse =
        DAG.getNode(Opcode, DL, OtherUse->getValueType(0), NewOp1, NewOp2);
    DAG.ReplaceAllUsesOfValueWith(SDValue(OtherUse, 0), NewUse);
    deleteAndRecombine(OtherUse);
  }

  // Replace the uses of Ptr with uses of the updated base value.
  DAG.ReplaceAllUsesOfValueWith(Ptr, Result.getValue(IsLoad ? 1 : 0));
  deleteAndRecombine(Ptr.getNode());
  AddToWorklist(Result.getNode());

  return true;
}
19663
shouldCombineToPostInc(SDNode * N,SDValue Ptr,SDNode * PtrUse,SDValue & BasePtr,SDValue & Offset,ISD::MemIndexedMode & AM,SelectionDAG & DAG,const TargetLowering & TLI)19664 static bool shouldCombineToPostInc(SDNode *N, SDValue Ptr, SDNode *PtrUse,
19665 SDValue &BasePtr, SDValue &Offset,
19666 ISD::MemIndexedMode &AM,
19667 SelectionDAG &DAG,
19668 const TargetLowering &TLI) {
19669 if (PtrUse == N ||
19670 (PtrUse->getOpcode() != ISD::ADD && PtrUse->getOpcode() != ISD::SUB))
19671 return false;
19672
19673 if (!TLI.getPostIndexedAddressParts(N, PtrUse, BasePtr, Offset, AM, DAG))
19674 return false;
19675
19676 // Don't create a indexed load / store with zero offset.
19677 if (isNullConstant(Offset))
19678 return false;
19679
19680 if (isa<FrameIndexSDNode>(BasePtr) || isa<RegisterSDNode>(BasePtr))
19681 return false;
19682
19683 SmallPtrSet<const SDNode *, 32> Visited;
19684 unsigned MaxSteps = SelectionDAG::getHasPredecessorMaxSteps();
19685 for (SDNode *User : BasePtr->users()) {
19686 if (User == Ptr.getNode())
19687 continue;
19688
19689 // No if there's a later user which could perform the index instead.
19690 if (isa<MemSDNode>(User)) {
19691 bool IsLoad = true;
19692 bool IsMasked = false;
19693 SDValue OtherPtr;
19694 if (getCombineLoadStoreParts(User, ISD::POST_INC, ISD::POST_DEC, IsLoad,
19695 IsMasked, OtherPtr, TLI)) {
19696 SmallVector<const SDNode *, 2> Worklist;
19697 Worklist.push_back(User);
19698 if (SDNode::hasPredecessorHelper(N, Visited, Worklist, MaxSteps))
19699 return false;
19700 }
19701 }
19702
19703 // If all the uses are load / store addresses, then don't do the
19704 // transformation.
19705 if (User->getOpcode() == ISD::ADD || User->getOpcode() == ISD::SUB) {
19706 for (SDNode *UserUser : User->users())
19707 if (canFoldInAddressingMode(User, UserUser, DAG, TLI))
19708 return false;
19709 }
19710 }
19711 return true;
19712 }
19713
getPostIndexedLoadStoreOp(SDNode * N,bool & IsLoad,bool & IsMasked,SDValue & Ptr,SDValue & BasePtr,SDValue & Offset,ISD::MemIndexedMode & AM,SelectionDAG & DAG,const TargetLowering & TLI)19714 static SDNode *getPostIndexedLoadStoreOp(SDNode *N, bool &IsLoad,
19715 bool &IsMasked, SDValue &Ptr,
19716 SDValue &BasePtr, SDValue &Offset,
19717 ISD::MemIndexedMode &AM,
19718 SelectionDAG &DAG,
19719 const TargetLowering &TLI) {
19720 if (!getCombineLoadStoreParts(N, ISD::POST_INC, ISD::POST_DEC, IsLoad,
19721 IsMasked, Ptr, TLI) ||
19722 Ptr->hasOneUse())
19723 return nullptr;
19724
19725 // Try turning it into a post-indexed load / store except when
19726 // 1) All uses are load / store ops that use it as base ptr (and
19727 // it may be folded as addressing mmode).
19728 // 2) Op must be independent of N, i.e. Op is neither a predecessor
19729 // nor a successor of N. Otherwise, if Op is folded that would
19730 // create a cycle.
19731 unsigned MaxSteps = SelectionDAG::getHasPredecessorMaxSteps();
19732 for (SDNode *Op : Ptr->users()) {
19733 // Check for #1.
19734 if (!shouldCombineToPostInc(N, Ptr, Op, BasePtr, Offset, AM, DAG, TLI))
19735 continue;
19736
19737 // Check for #2.
19738 SmallPtrSet<const SDNode *, 32> Visited;
19739 SmallVector<const SDNode *, 8> Worklist;
19740 // Ptr is predecessor to both N and Op.
19741 Visited.insert(Ptr.getNode());
19742 Worklist.push_back(N);
19743 Worklist.push_back(Op);
19744 if (!SDNode::hasPredecessorHelper(N, Visited, Worklist, MaxSteps) &&
19745 !SDNode::hasPredecessorHelper(Op, Visited, Worklist, MaxSteps))
19746 return Op;
19747 }
19748 return nullptr;
19749 }
19750
19751 /// Try to combine a load/store with a add/sub of the base pointer node into a
19752 /// post-indexed load/store. The transformation folded the add/subtract into the
19753 /// new indexed load/store effectively and all of its uses are redirected to the
19754 /// new load/store.
CombineToPostIndexedLoadStore(SDNode * N)19755 bool DAGCombiner::CombineToPostIndexedLoadStore(SDNode *N) {
19756 if (Level < AfterLegalizeDAG)
19757 return false;
19758
19759 bool IsLoad = true;
19760 bool IsMasked = false;
19761 SDValue Ptr;
19762 SDValue BasePtr;
19763 SDValue Offset;
19764 ISD::MemIndexedMode AM = ISD::UNINDEXED;
19765 SDNode *Op = getPostIndexedLoadStoreOp(N, IsLoad, IsMasked, Ptr, BasePtr,
19766 Offset, AM, DAG, TLI);
19767 if (!Op)
19768 return false;
19769
19770 SDValue Result;
19771 if (!IsMasked)
19772 Result = IsLoad ? DAG.getIndexedLoad(SDValue(N, 0), SDLoc(N), BasePtr,
19773 Offset, AM)
19774 : DAG.getIndexedStore(SDValue(N, 0), SDLoc(N),
19775 BasePtr, Offset, AM);
19776 else
19777 Result = IsLoad ? DAG.getIndexedMaskedLoad(SDValue(N, 0), SDLoc(N),
19778 BasePtr, Offset, AM)
19779 : DAG.getIndexedMaskedStore(SDValue(N, 0), SDLoc(N),
19780 BasePtr, Offset, AM);
19781 ++PostIndexedNodes;
19782 ++NodesCombined;
19783 LLVM_DEBUG(dbgs() << "\nReplacing.5 "; N->dump(&DAG); dbgs() << "\nWith: ";
19784 Result.dump(&DAG); dbgs() << '\n');
19785 WorklistRemover DeadNodes(*this);
19786 if (IsLoad) {
19787 DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(0));
19788 DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Result.getValue(2));
19789 } else {
19790 DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(1));
19791 }
19792
19793 // Finally, since the node is now dead, remove it from the graph.
19794 deleteAndRecombine(N);
19795
19796 // Replace the uses of Use with uses of the updated base value.
19797 DAG.ReplaceAllUsesOfValueWith(SDValue(Op, 0),
19798 Result.getValue(IsLoad ? 1 : 0));
19799 deleteAndRecombine(Op);
19800 return true;
19801 }
19802
19803 /// Return the base-pointer arithmetic from an indexed \p LD.
SplitIndexingFromLoad(LoadSDNode * LD)19804 SDValue DAGCombiner::SplitIndexingFromLoad(LoadSDNode *LD) {
19805 ISD::MemIndexedMode AM = LD->getAddressingMode();
19806 assert(AM != ISD::UNINDEXED);
19807 SDValue BP = LD->getOperand(1);
19808 SDValue Inc = LD->getOperand(2);
19809
19810 // Some backends use TargetConstants for load offsets, but don't expect
19811 // TargetConstants in general ADD nodes. We can convert these constants into
19812 // regular Constants (if the constant is not opaque).
19813 assert((Inc.getOpcode() != ISD::TargetConstant ||
19814 !cast<ConstantSDNode>(Inc)->isOpaque()) &&
19815 "Cannot split out indexing using opaque target constants");
19816 if (Inc.getOpcode() == ISD::TargetConstant) {
19817 ConstantSDNode *ConstInc = cast<ConstantSDNode>(Inc);
19818 Inc = DAG.getConstant(*ConstInc->getConstantIntValue(), SDLoc(Inc),
19819 ConstInc->getValueType(0));
19820 }
19821
19822 unsigned Opc =
19823 (AM == ISD::PRE_INC || AM == ISD::POST_INC ? ISD::ADD : ISD::SUB);
19824 return DAG.getNode(Opc, SDLoc(LD), BP.getSimpleValueType(), BP, Inc);
19825 }
19826
numVectorEltsOrZero(EVT T)19827 static inline ElementCount numVectorEltsOrZero(EVT T) {
19828 return T.isVector() ? T.getVectorElementCount() : ElementCount::getFixed(0);
19829 }
19830
getTruncatedStoreValue(StoreSDNode * ST,SDValue & Val)19831 bool DAGCombiner::getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val) {
19832 EVT STType = Val.getValueType();
19833 EVT STMemType = ST->getMemoryVT();
19834 if (STType == STMemType)
19835 return true;
19836 if (isTypeLegal(STMemType))
19837 return false; // fail.
19838 if (STType.isFloatingPoint() && STMemType.isFloatingPoint() &&
19839 TLI.isOperationLegal(ISD::FTRUNC, STMemType)) {
19840 Val = DAG.getNode(ISD::FTRUNC, SDLoc(ST), STMemType, Val);
19841 return true;
19842 }
19843 if (numVectorEltsOrZero(STType) == numVectorEltsOrZero(STMemType) &&
19844 STType.isInteger() && STMemType.isInteger()) {
19845 Val = DAG.getNode(ISD::TRUNCATE, SDLoc(ST), STMemType, Val);
19846 return true;
19847 }
19848 if (STType.getSizeInBits() == STMemType.getSizeInBits()) {
19849 Val = DAG.getBitcast(STMemType, Val);
19850 return true;
19851 }
19852 return false; // fail.
19853 }
19854
extendLoadedValueToExtension(LoadSDNode * LD,SDValue & Val)19855 bool DAGCombiner::extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val) {
19856 EVT LDMemType = LD->getMemoryVT();
19857 EVT LDType = LD->getValueType(0);
19858 assert(Val.getValueType() == LDMemType &&
19859 "Attempting to extend value of non-matching type");
19860 if (LDType == LDMemType)
19861 return true;
19862 if (LDMemType.isInteger() && LDType.isInteger()) {
19863 switch (LD->getExtensionType()) {
19864 case ISD::NON_EXTLOAD:
19865 Val = DAG.getBitcast(LDType, Val);
19866 return true;
19867 case ISD::EXTLOAD:
19868 Val = DAG.getNode(ISD::ANY_EXTEND, SDLoc(LD), LDType, Val);
19869 return true;
19870 case ISD::SEXTLOAD:
19871 Val = DAG.getNode(ISD::SIGN_EXTEND, SDLoc(LD), LDType, Val);
19872 return true;
19873 case ISD::ZEXTLOAD:
19874 Val = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(LD), LDType, Val);
19875 return true;
19876 }
19877 }
19878 return false;
19879 }
19880
getUniqueStoreFeeding(LoadSDNode * LD,int64_t & Offset)19881 StoreSDNode *DAGCombiner::getUniqueStoreFeeding(LoadSDNode *LD,
19882 int64_t &Offset) {
19883 SDValue Chain = LD->getOperand(0);
19884
19885 // Look through CALLSEQ_START.
19886 if (Chain.getOpcode() == ISD::CALLSEQ_START)
19887 Chain = Chain->getOperand(0);
19888
19889 StoreSDNode *ST = nullptr;
19890 SmallVector<SDValue, 8> Aliases;
19891 if (Chain.getOpcode() == ISD::TokenFactor) {
19892 // Look for unique store within the TokenFactor.
19893 for (SDValue Op : Chain->ops()) {
19894 StoreSDNode *Store = dyn_cast<StoreSDNode>(Op.getNode());
19895 if (!Store)
19896 continue;
19897 BaseIndexOffset BasePtrLD = BaseIndexOffset::match(LD, DAG);
19898 BaseIndexOffset BasePtrST = BaseIndexOffset::match(Store, DAG);
19899 if (!BasePtrST.equalBaseIndex(BasePtrLD, DAG, Offset))
19900 continue;
19901 // Make sure the store is not aliased with any nodes in TokenFactor.
19902 GatherAllAliases(Store, Chain, Aliases);
19903 if (Aliases.empty() ||
19904 (Aliases.size() == 1 && Aliases.front().getNode() == Store))
19905 ST = Store;
19906 break;
19907 }
19908 } else {
19909 StoreSDNode *Store = dyn_cast<StoreSDNode>(Chain.getNode());
19910 if (Store) {
19911 BaseIndexOffset BasePtrLD = BaseIndexOffset::match(LD, DAG);
19912 BaseIndexOffset BasePtrST = BaseIndexOffset::match(Store, DAG);
19913 if (BasePtrST.equalBaseIndex(BasePtrLD, DAG, Offset))
19914 ST = Store;
19915 }
19916 }
19917
19918 return ST;
19919 }
19920
/// Try to replace the value loaded by \p LD with the value written by the
/// store that getUniqueStoreFeeding() finds for it.
///
/// Forwarding is attempted only when optimizing, for simple loads and stores
/// in the same address space whose memory types agree on scalability, and only
/// when the stored bits cover every loaded bit. After endianness
/// normalization only offset 0 is handled (a big-endian case is reduced to
/// offset 0 with a right shift of the stored value).
///
/// \returns the result of CombineTo() when the load was replaced, or an empty
/// SDValue when nothing was done.
SDValue DAGCombiner::ForwardStoreValueToDirectLoad(LoadSDNode *LD) {
  if (OptLevel == CodeGenOptLevel::None || !LD->isSimple())
    return SDValue();
  SDValue Chain = LD->getOperand(0);
  // Passed by reference to getUniqueStoreFeeding(); only read below once a
  // store has been found.
  int64_t Offset;

  StoreSDNode *ST = getUniqueStoreFeeding(LD, Offset);
  // TODO: Relax this restriction for unordered atomics (see D66309)
  if (!ST || !ST->isSimple() || ST->getAddressSpace() != LD->getAddressSpace())
    return SDValue();

  EVT LDType = LD->getValueType(0);
  EVT LDMemType = LD->getMemoryVT();
  EVT STMemType = ST->getMemoryVT();
  EVT STType = ST->getValue().getValueType();

  // There are two cases to consider here:
  //  1. The store is fixed width and the load is scalable. In this case we
  //     don't know at compile time if the store completely envelops the load
  //     so we abandon the optimisation.
  //  2. The store is scalable and the load is fixed width. We could
  //     potentially support a limited number of cases here, but there has been
  //     no cost-benefit analysis to prove it's worth it.
  bool LdStScalable = LDMemType.isScalableVT();
  if (LdStScalable != STMemType.isScalableVT())
    return SDValue();

  // If we are dealing with scalable vectors on a big endian platform the
  // calculation of offsets below becomes trickier, since we do not know at
  // compile time the absolute size of the vector. Until we've done more
  // analysis on big-endian platforms it seems better to bail out for now.
  if (LdStScalable && DAG.getDataLayout().isBigEndian())
    return SDValue();

  // Normalize for endianness. After this, Offset=0 denotes that the least
  // significant bit in the loaded value maps to the least significant bit in
  // the stored value. With Offset=n (for n > 0) the loaded value starts at the
  // n:th least significant byte of the stored value.
  // OrigOffset keeps the un-normalized offset for the big-endian special case
  // further down.
  int64_t OrigOffset = Offset;
  if (DAG.getDataLayout().isBigEndian())
    Offset = ((int64_t)STMemType.getStoreSizeInBits().getFixedValue() -
              (int64_t)LDMemType.getStoreSizeInBits().getFixedValue()) /
                 8 -
             Offset;

  // Check that the stored value covers all bits that are loaded.
  bool STCoversLD;

  TypeSize LdMemSize = LDMemType.getSizeInBits();
  TypeSize StMemSize = STMemType.getSizeInBits();
  if (LdStScalable)
    // Scalable types: only an exact, same-size overlap is accepted.
    STCoversLD = (Offset == 0) && LdMemSize == StMemSize;
  else
    // Fixed-size types: the loaded bit range must lie inside the stored one.
    STCoversLD = (Offset >= 0) && (Offset * 8 + LdMemSize.getFixedValue() <=
                                   StMemSize.getFixedValue());

  // Replace the results of LD with Val and Chain. For an indexed load the
  // pointer update is split out into its own node first; if that is not
  // allowed, an empty SDValue is returned and nothing is replaced.
  auto ReplaceLd = [&](LoadSDNode *LD, SDValue Val, SDValue Chain) -> SDValue {
    if (LD->isIndexed()) {
      // Cannot handle opaque target constants and we must respect the user's
      // request not to split indexes from loads.
      if (!canSplitIdx(LD))
        return SDValue();
      SDValue Idx = SplitIndexingFromLoad(LD);
      SDValue Ops[] = {Val, Idx, Chain};
      return CombineTo(LD, Ops, 3);
    }
    return CombineTo(LD, Val, Chain);
  };

  if (!STCoversLD)
    return SDValue();

  // Memory as copy space (potentially masked).
  if (Offset == 0 && LDType == STType && STMemType == LDMemType) {
    // Simple case: Direct non-truncating forwarding
    if (LDType.getSizeInBits() == LdMemSize)
      return ReplaceLd(LD, ST->getValue(), Chain);
    // Can we model the truncate and extension with an and mask?
    // (Sign-extending loads are excluded from this path.)
    if (STType.isInteger() && LDMemType.isInteger() && !STType.isVector() &&
        !LDMemType.isVector() && LD->getExtensionType() != ISD::SEXTLOAD) {
      // Mask to size of LDMemType
      auto Mask =
          DAG.getConstant(APInt::getLowBitsSet(STType.getFixedSizeInBits(),
                                               StMemSize.getFixedValue()),
                          SDLoc(ST), STType);
      auto Val = DAG.getNode(ISD::AND, SDLoc(LD), LDType, ST->getValue(), Mask);
      return ReplaceLd(LD, Val, Chain);
    }
  }

  // Handle some cases for big-endian that would be Offset 0 and handled for
  // little-endian.
  SDValue Val = ST->getValue();
  if (DAG.getDataLayout().isBigEndian() && Offset > 0 && OrigOffset == 0) {
    if (STType.isInteger() && !STType.isVector() && LDType.isInteger() &&
        !LDType.isVector() && isTypeLegal(STType) &&
        TLI.isOperationLegal(ISD::SRL, STType)) {
      // Shift the loaded bytes down to the least significant end so the rest
      // of the function can treat this as an offset-0 forward.
      Val = DAG.getNode(ISD::SRL, SDLoc(LD), STType, Val,
                        DAG.getConstant(Offset * 8, SDLoc(LD), STType));
      Offset = 0;
    }
  }

  // TODO: Deal with nonzero offset.
  if (LD->getBasePtr().isUndef() || Offset != 0)
    return SDValue();
  // Model necessary truncations / extensions.
  // Truncate Value To Stored Memory Size.
  // The do/while(false) lets each failed step `break` to the cleanup below,
  // while reaching ReplaceLd returns its result directly.
  do {
    if (!getTruncatedStoreValue(ST, Val))
      break;
    if (!isTypeLegal(LDMemType))
      break;
    if (STMemType != LDMemType) {
      // TODO: Support vectors? This requires extract_subvector/bitcast.
      if (!STMemType.isVector() && !LDMemType.isVector() &&
          STMemType.isInteger() && LDMemType.isInteger())
        Val = DAG.getNode(ISD::TRUNCATE, SDLoc(LD), LDMemType, Val);
      else
        break;
    }
    if (!extendLoadedValueToExtension(LD, Val))
      break;
    return ReplaceLd(LD, Val, Chain);
  } while (false);

  // On failure, cleanup dead nodes we may have created.
  if (Val->use_empty())
    deleteAndRecombine(Val.getNode());
  return SDValue();
}
20052
/// Combine a load node \p N. The following are tried, in this order, and the
/// first one that applies ends the visit:
///  1. Delete a simple load whose loaded value is unused (for indexed loads
///     the updated pointer must also be unused or splittable into its own
///     node).
///  2. Forward the value of a store that feeds the load directly.
///  3. Refine the alignment recorded on the load (in place; does not end the
///     visit).
///  4. Move an unindexed load onto a better chain found by FindBetterChain().
///  5. Turn the load into a pre- or post-indexed load.
///  6. Slice the load into narrower loads.
///
/// \returns SDValue(N, 0) or the value produced by the successful combine when
/// something changed, and an empty SDValue otherwise.
SDValue DAGCombiner::visitLOAD(SDNode *N) {
  LoadSDNode *LD = cast<LoadSDNode>(N);
  SDValue Chain = LD->getChain();
  SDValue Ptr = LD->getBasePtr();

  // If load is not volatile and there are no uses of the loaded value (and
  // the updated indexed value in case of indexed loads), change uses of the
  // chain value into uses of the chain input (i.e. delete the dead load).
  // TODO: Allow this for unordered atomics (see D66309)
  if (LD->isSimple()) {
    // Result 1 is the chain for unindexed loads; indexed loads have their
    // chain as result 2 (asserted below).
    if (N->getValueType(1) == MVT::Other) {
      // Unindexed loads.
      if (!N->hasAnyUseOfValue(0)) {
        // It's not safe to use the two value CombineTo variant here. e.g.
        // v1, chain2 = load chain1, loc
        // v2, chain3 = load chain2, loc
        // v3         = add v2, c
        // Now we replace use of chain2 with chain1.  This makes the second load
        // isomorphic to the one we are deleting, and thus makes this load live.
        LLVM_DEBUG(dbgs() << "\nReplacing.6 "; N->dump(&DAG);
                   dbgs() << "\nWith chain: "; Chain.dump(&DAG);
                   dbgs() << "\n");
        WorklistRemover DeadNodes(*this);
        DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Chain);
        AddUsersToWorklist(Chain.getNode());
        // Only delete N if the replacement above really left it unused (see
        // the example in the comment above).
        if (N->use_empty())
          deleteAndRecombine(N);

        return SDValue(N, 0); // Return N so it doesn't get rechecked!
      }
    } else {
      // Indexed loads.
      assert(N->getValueType(2) == MVT::Other && "Malformed indexed loads?");

      // If this load has an opaque TargetConstant offset, then we cannot split
      // the indexing into an add/sub directly (that TargetConstant may not be
      // valid for a different type of node, and we cannot convert an opaque
      // target constant into a regular constant).
      bool CanSplitIdx = canSplitIdx(LD);

      // The loaded value must be dead, and the updated pointer must either be
      // dead too or be expressible as a separate add/sub.
      if (!N->hasAnyUseOfValue(0) && (CanSplitIdx || !N->hasAnyUseOfValue(1))) {
        SDValue Undef = DAG.getUNDEF(N->getValueType(0));
        SDValue Index;
        if (N->hasAnyUseOfValue(1) && CanSplitIdx) {
          Index = SplitIndexingFromLoad(LD);
          // Try to fold the base pointer arithmetic into subsequent loads and
          // stores.
          AddUsersToWorklist(N);
        } else
          Index = DAG.getUNDEF(N->getValueType(1));
        LLVM_DEBUG(dbgs() << "\nReplacing.7 "; N->dump(&DAG);
                   dbgs() << "\nWith: "; Undef.dump(&DAG);
                   dbgs() << " and 2 other values\n");
        WorklistRemover DeadNodes(*this);
        DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Undef);
        DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Index);
        DAG.ReplaceAllUsesOfValueWith(SDValue(N, 2), Chain);
        deleteAndRecombine(N);
        return SDValue(N, 0); // Return N so it doesn't get rechecked!
      }
    }
  }

  // If this load is directly stored, replace the load value with the stored
  // value.
  if (auto V = ForwardStoreValueToDirectLoad(LD))
    return V;

  // Try to infer better alignment information than the load already has.
  if (OptLevel != CodeGenOptLevel::None && LD->isUnindexed() &&
      !LD->isAtomic()) {
    if (MaybeAlign Alignment = DAG.InferPtrAlign(Ptr)) {
      // Only use the inferred alignment when it is an improvement and the
      // source value offset is itself a multiple of it.
      if (*Alignment > LD->getAlign() &&
          isAligned(*Alignment, LD->getSrcValueOffset())) {
        SDValue NewLoad = DAG.getExtLoad(
            LD->getExtensionType(), SDLoc(N), LD->getValueType(0), Chain, Ptr,
            LD->getPointerInfo(), LD->getMemoryVT(), *Alignment,
            LD->getMemOperand()->getFlags(), LD->getAAInfo());
        // NewLoad will always be N as we are only refining the alignment
        assert(NewLoad.getNode() == N);
        (void)NewLoad;
      }
    }
  }

  if (LD->isUnindexed()) {
    // Walk up chain skipping non-aliasing memory nodes.
    SDValue BetterChain = FindBetterChain(LD, Chain);

    // If there is a better chain.
    if (Chain != BetterChain) {
      SDValue ReplLoad;

      // Replace the chain to avoid the dependency.
      if (LD->getExtensionType() == ISD::NON_EXTLOAD) {
        ReplLoad = DAG.getLoad(N->getValueType(0), SDLoc(LD),
                               BetterChain, Ptr, LD->getMemOperand());
      } else {
        ReplLoad = DAG.getExtLoad(LD->getExtensionType(), SDLoc(LD),
                                  LD->getValueType(0),
                                  BetterChain, Ptr, LD->getMemoryVT(),
                                  LD->getMemOperand());
      }

      // Create token factor to keep old chain connected.
      SDValue Token = DAG.getNode(ISD::TokenFactor, SDLoc(N),
                                  MVT::Other, Chain, ReplLoad.getValue(1));

      // Replace uses with load result and token factor
      return CombineTo(N, ReplLoad.getValue(0), Token);
    }
  }

  // Try transforming N to an indexed load.
  if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
    return SDValue(N, 0);

  // Try to slice up N to more direct loads if the slices are mapped to
  // different register banks or pairing can take place.
  if (SliceUpLoad(N))
    return SDValue(N, 0);

  return SDValue();
}
20177
20178 namespace {
20179
20180 /// Helper structure used to slice a load in smaller loads.
20181 /// Basically a slice is obtained from the following sequence:
20182 /// Origin = load Ty1, Base
20183 /// Shift = srl Ty1 Origin, CstTy Amount
20184 /// Inst = trunc Shift to Ty2
20185 ///
20186 /// Then, it will be rewritten into:
20187 /// Slice = load SliceTy, Base + SliceOffset
20188 /// [Inst = zext Slice to Ty2], only if SliceTy <> Ty2
20189 ///
20190 /// SliceTy is deduced from the number of bits that are actually used to
20191 /// build Inst.
20192 struct LoadedSlice {
20193 /// Helper structure used to compute the cost of a slice.
20194 struct Cost {
20195 /// Are we optimizing for code size.
20196 bool ForCodeSize = false;
20197
20198 /// Various cost.
20199 unsigned Loads = 0;
20200 unsigned Truncates = 0;
20201 unsigned CrossRegisterBanksCopies = 0;
20202 unsigned ZExts = 0;
20203 unsigned Shift = 0;
20204
Cost__anon666e37104411::LoadedSlice::Cost20205 explicit Cost(bool ForCodeSize) : ForCodeSize(ForCodeSize) {}
20206
20207 /// Get the cost of one isolated slice.
Cost__anon666e37104411::LoadedSlice::Cost20208 Cost(const LoadedSlice &LS, bool ForCodeSize)
20209 : ForCodeSize(ForCodeSize), Loads(1) {
20210 EVT TruncType = LS.Inst->getValueType(0);
20211 EVT LoadedType = LS.getLoadedType();
20212 if (TruncType != LoadedType &&
20213 !LS.DAG->getTargetLoweringInfo().isZExtFree(LoadedType, TruncType))
20214 ZExts = 1;
20215 }
20216
20217 /// Account for slicing gain in the current cost.
20218 /// Slicing provide a few gains like removing a shift or a
20219 /// truncate. This method allows to grow the cost of the original
20220 /// load with the gain from this slice.
addSliceGain__anon666e37104411::LoadedSlice::Cost20221 void addSliceGain(const LoadedSlice &LS) {
20222 // Each slice saves a truncate.
20223 const TargetLowering &TLI = LS.DAG->getTargetLoweringInfo();
20224 if (!TLI.isTruncateFree(LS.Inst->getOperand(0), LS.Inst->getValueType(0)))
20225 ++Truncates;
20226 // If there is a shift amount, this slice gets rid of it.
20227 if (LS.Shift)
20228 ++Shift;
20229 // If this slice can merge a cross register bank copy, account for it.
20230 if (LS.canMergeExpensiveCrossRegisterBankCopy())
20231 ++CrossRegisterBanksCopies;
20232 }
20233
operator +=__anon666e37104411::LoadedSlice::Cost20234 Cost &operator+=(const Cost &RHS) {
20235 Loads += RHS.Loads;
20236 Truncates += RHS.Truncates;
20237 CrossRegisterBanksCopies += RHS.CrossRegisterBanksCopies;
20238 ZExts += RHS.ZExts;
20239 Shift += RHS.Shift;
20240 return *this;
20241 }
20242
operator ==__anon666e37104411::LoadedSlice::Cost20243 bool operator==(const Cost &RHS) const {
20244 return Loads == RHS.Loads && Truncates == RHS.Truncates &&
20245 CrossRegisterBanksCopies == RHS.CrossRegisterBanksCopies &&
20246 ZExts == RHS.ZExts && Shift == RHS.Shift;
20247 }
20248
operator !=__anon666e37104411::LoadedSlice::Cost20249 bool operator!=(const Cost &RHS) const { return !(*this == RHS); }
20250
operator <__anon666e37104411::LoadedSlice::Cost20251 bool operator<(const Cost &RHS) const {
20252 // Assume cross register banks copies are as expensive as loads.
20253 // FIXME: Do we want some more target hooks?
20254 unsigned ExpensiveOpsLHS = Loads + CrossRegisterBanksCopies;
20255 unsigned ExpensiveOpsRHS = RHS.Loads + RHS.CrossRegisterBanksCopies;
20256 // Unless we are optimizing for code size, consider the
20257 // expensive operation first.
20258 if (!ForCodeSize && ExpensiveOpsLHS != ExpensiveOpsRHS)
20259 return ExpensiveOpsLHS < ExpensiveOpsRHS;
20260 return (Truncates + ZExts + Shift + ExpensiveOpsLHS) <
20261 (RHS.Truncates + RHS.ZExts + RHS.Shift + ExpensiveOpsRHS);
20262 }
20263
operator >__anon666e37104411::LoadedSlice::Cost20264 bool operator>(const Cost &RHS) const { return RHS < *this; }
20265
operator <=__anon666e37104411::LoadedSlice::Cost20266 bool operator<=(const Cost &RHS) const { return !(RHS < *this); }
20267
operator >=__anon666e37104411::LoadedSlice::Cost20268 bool operator>=(const Cost &RHS) const { return !(*this < RHS); }
20269 };
20270
20271 // The last instruction that represent the slice. This should be a
20272 // truncate instruction.
20273 SDNode *Inst;
20274
20275 // The original load instruction.
20276 LoadSDNode *Origin;
20277
20278 // The right shift amount in bits from the original load.
20279 unsigned Shift;
20280
20281 // The DAG from which Origin came from.
20282 // This is used to get some contextual information about legal types, etc.
20283 SelectionDAG *DAG;
20284
LoadedSlice__anon666e37104411::LoadedSlice20285 LoadedSlice(SDNode *Inst = nullptr, LoadSDNode *Origin = nullptr,
20286 unsigned Shift = 0, SelectionDAG *DAG = nullptr)
20287 : Inst(Inst), Origin(Origin), Shift(Shift), DAG(DAG) {}
20288
20289 /// Get the bits used in a chunk of bits \p BitWidth large.
20290 /// \return Result is \p BitWidth and has used bits set to 1 and
20291 /// not used bits set to 0.
getUsedBits__anon666e37104411::LoadedSlice20292 APInt getUsedBits() const {
20293 // Reproduce the trunc(lshr) sequence:
20294 // - Start from the truncated value.
20295 // - Zero extend to the desired bit width.
20296 // - Shift left.
20297 assert(Origin && "No original load to compare against.");
20298 unsigned BitWidth = Origin->getValueSizeInBits(0);
20299 assert(Inst && "This slice is not bound to an instruction");
20300 assert(Inst->getValueSizeInBits(0) <= BitWidth &&
20301 "Extracted slice is bigger than the whole type!");
20302 APInt UsedBits(Inst->getValueSizeInBits(0), 0);
20303 UsedBits.setAllBits();
20304 UsedBits = UsedBits.zext(BitWidth);
20305 UsedBits <<= Shift;
20306 return UsedBits;
20307 }
20308
20309 /// Get the size of the slice to be loaded in bytes.
getLoadedSize__anon666e37104411::LoadedSlice20310 unsigned getLoadedSize() const {
20311 unsigned SliceSize = getUsedBits().popcount();
20312 assert(!(SliceSize & 0x7) && "Size is not a multiple of a byte.");
20313 return SliceSize / 8;
20314 }
20315
20316 /// Get the type that will be loaded for this slice.
20317 /// Note: This may not be the final type for the slice.
getLoadedType__anon666e37104411::LoadedSlice20318 EVT getLoadedType() const {
20319 assert(DAG && "Missing context");
20320 LLVMContext &Ctxt = *DAG->getContext();
20321 return EVT::getIntegerVT(Ctxt, getLoadedSize() * 8);
20322 }
20323
20324 /// Get the alignment of the load used for this slice.
getAlign__anon666e37104411::LoadedSlice20325 Align getAlign() const {
20326 Align Alignment = Origin->getAlign();
20327 uint64_t Offset = getOffsetFromBase();
20328 if (Offset != 0)
20329 Alignment = commonAlignment(Alignment, Alignment.value() + Offset);
20330 return Alignment;
20331 }
20332
20333 /// Check if this slice can be rewritten with legal operations.
isLegal__anon666e37104411::LoadedSlice20334 bool isLegal() const {
20335 // An invalid slice is not legal.
20336 if (!Origin || !Inst || !DAG)
20337 return false;
20338
20339 // Offsets are for indexed load only, we do not handle that.
20340 if (!Origin->getOffset().isUndef())
20341 return false;
20342
20343 const TargetLowering &TLI = DAG->getTargetLoweringInfo();
20344
20345 // Check that the type is legal.
20346 EVT SliceType = getLoadedType();
20347 if (!TLI.isTypeLegal(SliceType))
20348 return false;
20349
20350 // Check that the load is legal for this type.
20351 if (!TLI.isOperationLegal(ISD::LOAD, SliceType))
20352 return false;
20353
20354 // Check that the offset can be computed.
20355 // 1. Check its type.
20356 EVT PtrType = Origin->getBasePtr().getValueType();
20357 if (PtrType == MVT::Untyped || PtrType.isExtended())
20358 return false;
20359
20360 // 2. Check that it fits in the immediate.
20361 if (!TLI.isLegalAddImmediate(getOffsetFromBase()))
20362 return false;
20363
20364 // 3. Check that the computation is legal.
20365 if (!TLI.isOperationLegal(ISD::ADD, PtrType))
20366 return false;
20367
20368 // Check that the zext is legal if it needs one.
20369 EVT TruncateType = Inst->getValueType(0);
20370 if (TruncateType != SliceType &&
20371 !TLI.isOperationLegal(ISD::ZERO_EXTEND, TruncateType))
20372 return false;
20373
20374 return true;
20375 }
20376
20377 /// Get the offset in bytes of this slice in the original chunk of
20378 /// bits.
20379 /// \pre DAG != nullptr.
getOffsetFromBase__anon666e37104411::LoadedSlice20380 uint64_t getOffsetFromBase() const {
20381 assert(DAG && "Missing context.");
20382 bool IsBigEndian = DAG->getDataLayout().isBigEndian();
20383 assert(!(Shift & 0x7) && "Shifts not aligned on Bytes are not supported.");
20384 uint64_t Offset = Shift / 8;
20385 unsigned TySizeInBytes = Origin->getValueSizeInBits(0) / 8;
20386 assert(!(Origin->getValueSizeInBits(0) & 0x7) &&
20387 "The size of the original loaded type is not a multiple of a"
20388 " byte.");
20389 // If Offset is bigger than TySizeInBytes, it means we are loading all
20390 // zeros. This should have been optimized before in the process.
20391 assert(TySizeInBytes > Offset &&
20392 "Invalid shift amount for given loaded size");
20393 if (IsBigEndian)
20394 Offset = TySizeInBytes - Offset - getLoadedSize();
20395 return Offset;
20396 }
20397
20398 /// Generate the sequence of instructions to load the slice
20399 /// represented by this object and redirect the uses of this slice to
20400 /// this new sequence of instructions.
20401 /// \pre this->Inst && this->Origin are valid Instructions and this
20402 /// object passed the legal check: LoadedSlice::isLegal returned true.
20403 /// \return The last instruction of the sequence used to load the slice.
loadSlice__anon666e37104411::LoadedSlice20404 SDValue loadSlice() const {
20405 assert(Inst && Origin && "Unable to replace a non-existing slice.");
20406 const SDValue &OldBaseAddr = Origin->getBasePtr();
20407 SDValue BaseAddr = OldBaseAddr;
20408 // Get the offset in that chunk of bytes w.r.t. the endianness.
20409 int64_t Offset = static_cast<int64_t>(getOffsetFromBase());
20410 assert(Offset >= 0 && "Offset too big to fit in int64_t!");
20411 if (Offset) {
20412 // BaseAddr = BaseAddr + Offset.
20413 EVT ArithType = BaseAddr.getValueType();
20414 SDLoc DL(Origin);
20415 BaseAddr = DAG->getNode(ISD::ADD, DL, ArithType, BaseAddr,
20416 DAG->getConstant(Offset, DL, ArithType));
20417 }
20418
20419 // Create the type of the loaded slice according to its size.
20420 EVT SliceType = getLoadedType();
20421
20422 // Create the load for the slice.
20423 SDValue LastInst =
20424 DAG->getLoad(SliceType, SDLoc(Origin), Origin->getChain(), BaseAddr,
20425 Origin->getPointerInfo().getWithOffset(Offset), getAlign(),
20426 Origin->getMemOperand()->getFlags());
20427 // If the final type is not the same as the loaded type, this means that
20428 // we have to pad with zero. Create a zero extend for that.
20429 EVT FinalType = Inst->getValueType(0);
20430 if (SliceType != FinalType)
20431 LastInst =
20432 DAG->getNode(ISD::ZERO_EXTEND, SDLoc(LastInst), FinalType, LastInst);
20433 return LastInst;
20434 }
20435
20436 /// Check if this slice can be merged with an expensive cross register
20437 /// bank copy. E.g.,
20438 /// i = load i32
20439 /// f = bitcast i32 i to float
canMergeExpensiveCrossRegisterBankCopy__anon666e37104411::LoadedSlice20440 bool canMergeExpensiveCrossRegisterBankCopy() const {
20441 if (!Inst || !Inst->hasOneUse())
20442 return false;
20443 SDNode *User = *Inst->user_begin();
20444 if (User->getOpcode() != ISD::BITCAST)
20445 return false;
20446 assert(DAG && "Missing context");
20447 const TargetLowering &TLI = DAG->getTargetLoweringInfo();
20448 EVT ResVT = User->getValueType(0);
20449 const TargetRegisterClass *ResRC =
20450 TLI.getRegClassFor(ResVT.getSimpleVT(), User->isDivergent());
20451 const TargetRegisterClass *ArgRC =
20452 TLI.getRegClassFor(User->getOperand(0).getValueType().getSimpleVT(),
20453 User->getOperand(0)->isDivergent());
20454 if (ArgRC == ResRC || !TLI.isOperationLegal(ISD::LOAD, ResVT))
20455 return false;
20456
20457 // At this point, we know that we perform a cross-register-bank copy.
20458 // Check if it is expensive.
20459 const TargetRegisterInfo *TRI = DAG->getSubtarget().getRegisterInfo();
20460 // Assume bitcasts are cheap, unless both register classes do not
20461 // explicitly share a common sub class.
20462 if (!TRI || TRI->getCommonSubClass(ArgRC, ResRC))
20463 return false;
20464
20465 // Check if it will be merged with the load.
20466 // 1. Check the alignment / fast memory access constraint.
20467 unsigned IsFast = 0;
20468 if (!TLI.allowsMemoryAccess(*DAG->getContext(), DAG->getDataLayout(), ResVT,
20469 Origin->getAddressSpace(), getAlign(),
20470 Origin->getMemOperand()->getFlags(), &IsFast) ||
20471 !IsFast)
20472 return false;
20473
20474 // 2. Check that the load is a legal operation for that type.
20475 if (!TLI.isOperationLegal(ISD::LOAD, ResVT))
20476 return false;
20477
20478 // 3. Check that we do not have a zext in the way.
20479 if (Inst->getValueType(0) != getLoadedType())
20480 return false;
20481
20482 return true;
20483 }
20484 };
20485
20486 } // end anonymous namespace
20487
20488 /// Check that all bits set in \p UsedBits form a dense region, i.e.,
20489 /// \p UsedBits looks like 0..0 1..1 0..0.
areUsedBitsDense(const APInt & UsedBits)20490 static bool areUsedBitsDense(const APInt &UsedBits) {
20491 // If all the bits are one, this is dense!
20492 if (UsedBits.isAllOnes())
20493 return true;
20494
20495 // Get rid of the unused bits on the right.
20496 APInt NarrowedUsedBits = UsedBits.lshr(UsedBits.countr_zero());
20497 // Get rid of the unused bits on the left.
20498 if (NarrowedUsedBits.countl_zero())
20499 NarrowedUsedBits = NarrowedUsedBits.trunc(NarrowedUsedBits.getActiveBits());
20500 // Check that the chunk of bits is completely used.
20501 return NarrowedUsedBits.isAllOnes();
20502 }
20503
20504 /// Check whether or not \p First and \p Second are next to each other
20505 /// in memory. This means that there is no hole between the bits loaded
20506 /// by \p First and the bits loaded by \p Second.
areSlicesNextToEachOther(const LoadedSlice & First,const LoadedSlice & Second)20507 static bool areSlicesNextToEachOther(const LoadedSlice &First,
20508 const LoadedSlice &Second) {
20509 assert(First.Origin == Second.Origin && First.Origin &&
20510 "Unable to match different memory origins.");
20511 APInt UsedBits = First.getUsedBits();
20512 assert((UsedBits & Second.getUsedBits()) == 0 &&
20513 "Slices are not supposed to overlap.");
20514 UsedBits |= Second.getUsedBits();
20515 return areUsedBitsDense(UsedBits);
20516 }
20517
20518 /// Adjust the \p GlobalLSCost according to the target
20519 /// paring capabilities and the layout of the slices.
20520 /// \pre \p GlobalLSCost should account for at least as many loads as
20521 /// there is in the slices in \p LoadedSlices.
adjustCostForPairing(SmallVectorImpl<LoadedSlice> & LoadedSlices,LoadedSlice::Cost & GlobalLSCost)20522 static void adjustCostForPairing(SmallVectorImpl<LoadedSlice> &LoadedSlices,
20523 LoadedSlice::Cost &GlobalLSCost) {
20524 unsigned NumberOfSlices = LoadedSlices.size();
20525 // If there is less than 2 elements, no pairing is possible.
20526 if (NumberOfSlices < 2)
20527 return;
20528
20529 // Sort the slices so that elements that are likely to be next to each
20530 // other in memory are next to each other in the list.
20531 llvm::sort(LoadedSlices, [](const LoadedSlice &LHS, const LoadedSlice &RHS) {
20532 assert(LHS.Origin == RHS.Origin && "Different bases not implemented.");
20533 return LHS.getOffsetFromBase() < RHS.getOffsetFromBase();
20534 });
20535 const TargetLowering &TLI = LoadedSlices[0].DAG->getTargetLoweringInfo();
20536 // First (resp. Second) is the first (resp. Second) potentially candidate
20537 // to be placed in a paired load.
20538 const LoadedSlice *First = nullptr;
20539 const LoadedSlice *Second = nullptr;
20540 for (unsigned CurrSlice = 0; CurrSlice < NumberOfSlices; ++CurrSlice,
20541 // Set the beginning of the pair.
20542 First = Second) {
20543 Second = &LoadedSlices[CurrSlice];
20544
20545 // If First is NULL, it means we start a new pair.
20546 // Get to the next slice.
20547 if (!First)
20548 continue;
20549
20550 EVT LoadedType = First->getLoadedType();
20551
20552 // If the types of the slices are different, we cannot pair them.
20553 if (LoadedType != Second->getLoadedType())
20554 continue;
20555
20556 // Check if the target supplies paired loads for this type.
20557 Align RequiredAlignment;
20558 if (!TLI.hasPairedLoad(LoadedType, RequiredAlignment)) {
20559 // move to the next pair, this type is hopeless.
20560 Second = nullptr;
20561 continue;
20562 }
20563 // Check if we meet the alignment requirement.
20564 if (First->getAlign() < RequiredAlignment)
20565 continue;
20566
20567 // Check that both loads are next to each other in memory.
20568 if (!areSlicesNextToEachOther(*First, *Second))
20569 continue;
20570
20571 assert(GlobalLSCost.Loads > 0 && "We save more loads than we created!");
20572 --GlobalLSCost.Loads;
20573 // Move to the next pair.
20574 Second = nullptr;
20575 }
20576 }
20577
20578 /// Check the profitability of all involved LoadedSlice.
20579 /// Currently, it is considered profitable if there is exactly two
20580 /// involved slices (1) which are (2) next to each other in memory, and
20581 /// whose cost (\see LoadedSlice::Cost) is smaller than the original load (3).
20582 ///
20583 /// Note: The order of the elements in \p LoadedSlices may be modified, but not
20584 /// the elements themselves.
20585 ///
20586 /// FIXME: When the cost model will be mature enough, we can relax
20587 /// constraints (1) and (2).
isSlicingProfitable(SmallVectorImpl<LoadedSlice> & LoadedSlices,const APInt & UsedBits,bool ForCodeSize)20588 static bool isSlicingProfitable(SmallVectorImpl<LoadedSlice> &LoadedSlices,
20589 const APInt &UsedBits, bool ForCodeSize) {
20590 unsigned NumberOfSlices = LoadedSlices.size();
20591 if (StressLoadSlicing)
20592 return NumberOfSlices > 1;
20593
20594 // Check (1).
20595 if (NumberOfSlices != 2)
20596 return false;
20597
20598 // Check (2).
20599 if (!areUsedBitsDense(UsedBits))
20600 return false;
20601
20602 // Check (3).
20603 LoadedSlice::Cost OrigCost(ForCodeSize), GlobalSlicingCost(ForCodeSize);
20604 // The original code has one big load.
20605 OrigCost.Loads = 1;
20606 for (unsigned CurrSlice = 0; CurrSlice < NumberOfSlices; ++CurrSlice) {
20607 const LoadedSlice &LS = LoadedSlices[CurrSlice];
20608 // Accumulate the cost of all the slices.
20609 LoadedSlice::Cost SliceCost(LS, ForCodeSize);
20610 GlobalSlicingCost += SliceCost;
20611
20612 // Account as cost in the original configuration the gain obtained
20613 // with the current slices.
20614 OrigCost.addSliceGain(LS);
20615 }
20616
20617 // If the target supports paired load, adjust the cost accordingly.
20618 adjustCostForPairing(LoadedSlices, GlobalSlicingCost);
20619 return OrigCost > GlobalSlicingCost;
20620 }
20621
20622 /// If the given load, \p LI, is used only by trunc or trunc(lshr)
20623 /// operations, split it in the various pieces being extracted.
20624 ///
20625 /// This sort of thing is introduced by SROA.
20626 /// This slicing takes care not to insert overlapping loads.
20627 /// \pre LI is a simple load (i.e., not an atomic or volatile load).
SliceUpLoad(SDNode * N)20628 bool DAGCombiner::SliceUpLoad(SDNode *N) {
20629 if (Level < AfterLegalizeDAG)
20630 return false;
20631
20632 LoadSDNode *LD = cast<LoadSDNode>(N);
20633 if (!LD->isSimple() || !ISD::isNormalLoad(LD) ||
20634 !LD->getValueType(0).isInteger())
20635 return false;
20636
20637 // The algorithm to split up a load of a scalable vector into individual
20638 // elements currently requires knowing the length of the loaded type,
20639 // so will need adjusting to work on scalable vectors.
20640 if (LD->getValueType(0).isScalableVector())
20641 return false;
20642
20643 // Keep track of already used bits to detect overlapping values.
20644 // In that case, we will just abort the transformation.
20645 APInt UsedBits(LD->getValueSizeInBits(0), 0);
20646
20647 SmallVector<LoadedSlice, 4> LoadedSlices;
20648
20649 // Check if this load is used as several smaller chunks of bits.
20650 // Basically, look for uses in trunc or trunc(lshr) and record a new chain
20651 // of computation for each trunc.
20652 for (SDUse &U : LD->uses()) {
20653 // Skip the uses of the chain.
20654 if (U.getResNo() != 0)
20655 continue;
20656
20657 SDNode *User = U.getUser();
20658 unsigned Shift = 0;
20659
20660 // Check if this is a trunc(lshr).
20661 if (User->getOpcode() == ISD::SRL && User->hasOneUse() &&
20662 isa<ConstantSDNode>(User->getOperand(1))) {
20663 Shift = User->getConstantOperandVal(1);
20664 User = *User->user_begin();
20665 }
20666
20667 // At this point, User is a Truncate, iff we encountered, trunc or
20668 // trunc(lshr).
20669 if (User->getOpcode() != ISD::TRUNCATE)
20670 return false;
20671
20672 // The width of the type must be a power of 2 and greater than 8-bits.
20673 // Otherwise the load cannot be represented in LLVM IR.
20674 // Moreover, if we shifted with a non-8-bits multiple, the slice
20675 // will be across several bytes. We do not support that.
20676 unsigned Width = User->getValueSizeInBits(0);
20677 if (Width < 8 || !isPowerOf2_32(Width) || (Shift & 0x7))
20678 return false;
20679
20680 // Build the slice for this chain of computations.
20681 LoadedSlice LS(User, LD, Shift, &DAG);
20682 APInt CurrentUsedBits = LS.getUsedBits();
20683
20684 // Check if this slice overlaps with another.
20685 if ((CurrentUsedBits & UsedBits) != 0)
20686 return false;
20687 // Update the bits used globally.
20688 UsedBits |= CurrentUsedBits;
20689
20690 // Check if the new slice would be legal.
20691 if (!LS.isLegal())
20692 return false;
20693
20694 // Record the slice.
20695 LoadedSlices.push_back(LS);
20696 }
20697
20698 // Abort slicing if it does not seem to be profitable.
20699 if (!isSlicingProfitable(LoadedSlices, UsedBits, ForCodeSize))
20700 return false;
20701
20702 ++SlicedLoads;
20703
20704 // Rewrite each chain to use an independent load.
20705 // By construction, each chain can be represented by a unique load.
20706
20707 // Prepare the argument for the new token factor for all the slices.
20708 SmallVector<SDValue, 8> ArgChains;
20709 for (const LoadedSlice &LS : LoadedSlices) {
20710 SDValue SliceInst = LS.loadSlice();
20711 CombineTo(LS.Inst, SliceInst, true);
20712 if (SliceInst.getOpcode() != ISD::LOAD)
20713 SliceInst = SliceInst.getOperand(0);
20714 assert(SliceInst->getOpcode() == ISD::LOAD &&
20715 "It takes more than a zext to get to the loaded slice!!");
20716 ArgChains.push_back(SliceInst.getValue(1));
20717 }
20718
20719 SDValue Chain = DAG.getNode(ISD::TokenFactor, SDLoc(LD), MVT::Other,
20720 ArgChains);
20721 DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Chain);
20722 AddToWorklist(Chain.getNode());
20723 return true;
20724 }
20725
20726 /// Check to see if V is (and load (ptr), imm), where the load is having
20727 /// specific bytes cleared out. If so, return the byte size being masked out
20728 /// and the shift amount.
20729 static std::pair<unsigned, unsigned>
CheckForMaskedLoad(SDValue V,SDValue Ptr,SDValue Chain)20730 CheckForMaskedLoad(SDValue V, SDValue Ptr, SDValue Chain) {
20731 std::pair<unsigned, unsigned> Result(0, 0);
20732
20733 // Check for the structure we're looking for.
20734 if (V->getOpcode() != ISD::AND ||
20735 !isa<ConstantSDNode>(V->getOperand(1)) ||
20736 !ISD::isNormalLoad(V->getOperand(0).getNode()))
20737 return Result;
20738
20739 // Check the chain and pointer.
20740 LoadSDNode *LD = cast<LoadSDNode>(V->getOperand(0));
20741 if (LD->getBasePtr() != Ptr) return Result; // Not from same pointer.
20742
20743 // This only handles simple types.
20744 if (V.getValueType() != MVT::i16 &&
20745 V.getValueType() != MVT::i32 &&
20746 V.getValueType() != MVT::i64)
20747 return Result;
20748
20749 // Check the constant mask. Invert it so that the bits being masked out are
20750 // 0 and the bits being kept are 1. Use getSExtValue so that leading bits
20751 // follow the sign bit for uniformity.
20752 uint64_t NotMask = ~cast<ConstantSDNode>(V->getOperand(1))->getSExtValue();
20753 unsigned NotMaskLZ = llvm::countl_zero(NotMask);
20754 if (NotMaskLZ & 7) return Result; // Must be multiple of a byte.
20755 unsigned NotMaskTZ = llvm::countr_zero(NotMask);
20756 if (NotMaskTZ & 7) return Result; // Must be multiple of a byte.
20757 if (NotMaskLZ == 64) return Result; // All zero mask.
20758
20759 // See if we have a continuous run of bits. If so, we have 0*1+0*
20760 if (llvm::countr_one(NotMask >> NotMaskTZ) + NotMaskTZ + NotMaskLZ != 64)
20761 return Result;
20762
20763 // Adjust NotMaskLZ down to be from the actual size of the int instead of i64.
20764 if (V.getValueType() != MVT::i64 && NotMaskLZ)
20765 NotMaskLZ -= 64-V.getValueSizeInBits();
20766
20767 unsigned MaskedBytes = (V.getValueSizeInBits()-NotMaskLZ-NotMaskTZ)/8;
20768 switch (MaskedBytes) {
20769 case 1:
20770 case 2:
20771 case 4: break;
20772 default: return Result; // All one mask, or 5-byte mask.
20773 }
20774
20775 // Verify that the first bit starts at a multiple of mask so that the access
20776 // is aligned the same as the access width.
20777 if (NotMaskTZ && NotMaskTZ/8 % MaskedBytes) return Result;
20778
20779 // For narrowing to be valid, it must be the case that the load the
20780 // immediately preceding memory operation before the store.
20781 if (LD == Chain.getNode())
20782 ; // ok.
20783 else if (Chain->getOpcode() == ISD::TokenFactor &&
20784 SDValue(LD, 1).hasOneUse()) {
20785 // LD has only 1 chain use so they are no indirect dependencies.
20786 if (!LD->isOperandOf(Chain.getNode()))
20787 return Result;
20788 } else
20789 return Result; // Fail.
20790
20791 Result.first = MaskedBytes;
20792 Result.second = NotMaskTZ/8;
20793 return Result;
20794 }
20795
20796 /// Check to see if IVal is something that provides a value as specified by
20797 /// MaskInfo. If so, replace the specified store with a narrower store of
20798 /// truncated IVal.
20799 static SDValue
ShrinkLoadReplaceStoreWithStore(const std::pair<unsigned,unsigned> & MaskInfo,SDValue IVal,StoreSDNode * St,DAGCombiner * DC)20800 ShrinkLoadReplaceStoreWithStore(const std::pair<unsigned, unsigned> &MaskInfo,
20801 SDValue IVal, StoreSDNode *St,
20802 DAGCombiner *DC) {
20803 unsigned NumBytes = MaskInfo.first;
20804 unsigned ByteShift = MaskInfo.second;
20805 SelectionDAG &DAG = DC->getDAG();
20806
20807 // Check to see if IVal is all zeros in the part being masked in by the 'or'
20808 // that uses this. If not, this is not a replacement.
20809 APInt Mask = ~APInt::getBitsSet(IVal.getValueSizeInBits(),
20810 ByteShift*8, (ByteShift+NumBytes)*8);
20811 if (!DAG.MaskedValueIsZero(IVal, Mask)) return SDValue();
20812
20813 // Check that it is legal on the target to do this. It is legal if the new
20814 // VT we're shrinking to (i8/i16/i32) is legal or we're still before type
20815 // legalization. If the source type is legal, but the store type isn't, see
20816 // if we can use a truncating store.
20817 MVT VT = MVT::getIntegerVT(NumBytes * 8);
20818 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
20819 bool UseTruncStore;
20820 if (DC->isTypeLegal(VT))
20821 UseTruncStore = false;
20822 else if (TLI.isTypeLegal(IVal.getValueType()) &&
20823 TLI.isTruncStoreLegal(IVal.getValueType(), VT))
20824 UseTruncStore = true;
20825 else
20826 return SDValue();
20827
20828 // Can't do this for indexed stores.
20829 if (St->isIndexed())
20830 return SDValue();
20831
20832 // Check that the target doesn't think this is a bad idea.
20833 if (St->getMemOperand() &&
20834 !TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT,
20835 *St->getMemOperand()))
20836 return SDValue();
20837
20838 // Okay, we can do this! Replace the 'St' store with a store of IVal that is
20839 // shifted by ByteShift and truncated down to NumBytes.
20840 if (ByteShift) {
20841 SDLoc DL(IVal);
20842 IVal = DAG.getNode(
20843 ISD::SRL, DL, IVal.getValueType(), IVal,
20844 DAG.getShiftAmountConstant(ByteShift * 8, IVal.getValueType(), DL));
20845 }
20846
20847 // Figure out the offset for the store and the alignment of the access.
20848 unsigned StOffset;
20849 if (DAG.getDataLayout().isLittleEndian())
20850 StOffset = ByteShift;
20851 else
20852 StOffset = IVal.getValueType().getStoreSize() - ByteShift - NumBytes;
20853
20854 SDValue Ptr = St->getBasePtr();
20855 if (StOffset) {
20856 SDLoc DL(IVal);
20857 Ptr = DAG.getMemBasePlusOffset(Ptr, TypeSize::getFixed(StOffset), DL);
20858 }
20859
20860 ++OpsNarrowed;
20861 if (UseTruncStore)
20862 return DAG.getTruncStore(St->getChain(), SDLoc(St), IVal, Ptr,
20863 St->getPointerInfo().getWithOffset(StOffset), VT,
20864 St->getBaseAlign());
20865
20866 // Truncate down to the new size.
20867 IVal = DAG.getNode(ISD::TRUNCATE, SDLoc(IVal), VT, IVal);
20868
20869 return DAG.getStore(St->getChain(), SDLoc(St), IVal, Ptr,
20870 St->getPointerInfo().getWithOffset(StOffset),
20871 St->getBaseAlign());
20872 }
20873
20874 /// Look for sequence of load / op / store where op is one of 'or', 'xor', and
20875 /// 'and' of immediates. If 'op' is only touching some of the loaded bits, try
20876 /// narrowing the load and store if it would end up being a win for performance
20877 /// or code size.
ReduceLoadOpStoreWidth(SDNode * N)20878 SDValue DAGCombiner::ReduceLoadOpStoreWidth(SDNode *N) {
20879 StoreSDNode *ST = cast<StoreSDNode>(N);
20880 if (!ST->isSimple())
20881 return SDValue();
20882
20883 SDValue Chain = ST->getChain();
20884 SDValue Value = ST->getValue();
20885 SDValue Ptr = ST->getBasePtr();
20886 EVT VT = Value.getValueType();
20887
20888 if (ST->isTruncatingStore() || VT.isVector())
20889 return SDValue();
20890
20891 unsigned Opc = Value.getOpcode();
20892
20893 if ((Opc != ISD::OR && Opc != ISD::XOR && Opc != ISD::AND) ||
20894 !Value.hasOneUse())
20895 return SDValue();
20896
20897 // If this is "store (or X, Y), P" and X is "(and (load P), cst)", where cst
20898 // is a byte mask indicating a consecutive number of bytes, check to see if
20899 // Y is known to provide just those bytes. If so, we try to replace the
20900 // load + replace + store sequence with a single (narrower) store, which makes
20901 // the load dead.
20902 if (Opc == ISD::OR && EnableShrinkLoadReplaceStoreWithStore) {
20903 std::pair<unsigned, unsigned> MaskedLoad;
20904 MaskedLoad = CheckForMaskedLoad(Value.getOperand(0), Ptr, Chain);
20905 if (MaskedLoad.first)
20906 if (SDValue NewST = ShrinkLoadReplaceStoreWithStore(MaskedLoad,
20907 Value.getOperand(1), ST,this))
20908 return NewST;
20909
20910 // Or is commutative, so try swapping X and Y.
20911 MaskedLoad = CheckForMaskedLoad(Value.getOperand(1), Ptr, Chain);
20912 if (MaskedLoad.first)
20913 if (SDValue NewST = ShrinkLoadReplaceStoreWithStore(MaskedLoad,
20914 Value.getOperand(0), ST,this))
20915 return NewST;
20916 }
20917
20918 if (!EnableReduceLoadOpStoreWidth)
20919 return SDValue();
20920
20921 if (Value.getOperand(1).getOpcode() != ISD::Constant)
20922 return SDValue();
20923
20924 SDValue N0 = Value.getOperand(0);
20925 if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
20926 Chain == SDValue(N0.getNode(), 1)) {
20927 LoadSDNode *LD = cast<LoadSDNode>(N0);
20928 if (LD->getBasePtr() != Ptr ||
20929 LD->getPointerInfo().getAddrSpace() !=
20930 ST->getPointerInfo().getAddrSpace())
20931 return SDValue();
20932
20933 // Find the type NewVT to narrow the load / op / store to.
20934 SDValue N1 = Value.getOperand(1);
20935 unsigned BitWidth = N1.getValueSizeInBits();
20936 APInt Imm = N1->getAsAPIntVal();
20937 if (Opc == ISD::AND)
20938 Imm.flipAllBits();
20939 if (Imm == 0 || Imm.isAllOnes())
20940 return SDValue();
20941 // Find least/most significant bit that need to be part of the narrowed
20942 // operation. We assume target will need to address/access full bytes, so
20943 // we make sure to align LSB and MSB at byte boundaries.
20944 unsigned BitsPerByteMask = 7u;
20945 unsigned LSB = Imm.countr_zero() & ~BitsPerByteMask;
20946 unsigned MSB = (Imm.getActiveBits() - 1) | BitsPerByteMask;
20947 unsigned NewBW = NextPowerOf2(MSB - LSB);
20948 EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), NewBW);
20949 // The narrowing should be profitable, the load/store operation should be
20950 // legal (or custom) and the store size should be equal to the NewVT width.
20951 while (NewBW < BitWidth &&
20952 (NewVT.getStoreSizeInBits() != NewBW ||
20953 !TLI.isOperationLegalOrCustom(Opc, NewVT) ||
20954 (!ReduceLoadOpStoreWidthForceNarrowingProfitable &&
20955 !TLI.isNarrowingProfitable(N, VT, NewVT)))) {
20956 NewBW = NextPowerOf2(NewBW);
20957 NewVT = EVT::getIntegerVT(*DAG.getContext(), NewBW);
20958 }
20959 if (NewBW >= BitWidth)
20960 return SDValue();
20961
20962 // If we come this far NewVT/NewBW reflect a power-of-2 sized type that is
20963 // large enough to cover all bits that should be modified. This type might
20964 // however be larger than really needed (such as i32 while we actually only
20965 // need to modify one byte). Now we need to find our how to align the memory
20966 // accesses to satisfy preferred alignments as well as avoiding to access
20967 // memory outside the store size of the orignal access.
20968
20969 unsigned VTStoreSize = VT.getStoreSizeInBits().getFixedValue();
20970
20971 // Let ShAmt denote amount of bits to skip, counted from the least
20972 // significant bits of Imm. And let PtrOff how much the pointer needs to be
20973 // offsetted (in bytes) for the new access.
20974 unsigned ShAmt = 0;
20975 uint64_t PtrOff = 0;
20976 for (; ShAmt + NewBW <= VTStoreSize; ShAmt += 8) {
20977 // Make sure the range [ShAmt, ShAmt+NewBW) cover both LSB and MSB.
20978 if (ShAmt > LSB)
20979 return SDValue();
20980 if (ShAmt + NewBW < MSB)
20981 continue;
20982
20983 // Calculate PtrOff.
20984 unsigned PtrAdjustmentInBits = DAG.getDataLayout().isBigEndian()
20985 ? VTStoreSize - NewBW - ShAmt
20986 : ShAmt;
20987 PtrOff = PtrAdjustmentInBits / 8;
20988
20989 // Now check if narrow access is allowed and fast, considering alignments.
20990 unsigned IsFast = 0;
20991 Align NewAlign = commonAlignment(LD->getAlign(), PtrOff);
20992 if (TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), NewVT,
20993 LD->getAddressSpace(), NewAlign,
20994 LD->getMemOperand()->getFlags(), &IsFast) &&
20995 IsFast)
20996 break;
20997 }
20998 // If loop above did not find any accepted ShAmt we need to exit here.
20999 if (ShAmt + NewBW > VTStoreSize)
21000 return SDValue();
21001
21002 APInt NewImm = Imm.lshr(ShAmt).trunc(NewBW);
21003 if (Opc == ISD::AND)
21004 NewImm.flipAllBits();
21005 Align NewAlign = commonAlignment(LD->getAlign(), PtrOff);
21006 SDValue NewPtr =
21007 DAG.getMemBasePlusOffset(Ptr, TypeSize::getFixed(PtrOff), SDLoc(LD));
21008 SDValue NewLD =
21009 DAG.getLoad(NewVT, SDLoc(N0), LD->getChain(), NewPtr,
21010 LD->getPointerInfo().getWithOffset(PtrOff), NewAlign,
21011 LD->getMemOperand()->getFlags(), LD->getAAInfo());
21012 SDValue NewVal = DAG.getNode(Opc, SDLoc(Value), NewVT, NewLD,
21013 DAG.getConstant(NewImm, SDLoc(Value), NewVT));
21014 SDValue NewST =
21015 DAG.getStore(Chain, SDLoc(N), NewVal, NewPtr,
21016 ST->getPointerInfo().getWithOffset(PtrOff), NewAlign);
21017
21018 AddToWorklist(NewPtr.getNode());
21019 AddToWorklist(NewLD.getNode());
21020 AddToWorklist(NewVal.getNode());
21021 WorklistRemover DeadNodes(*this);
21022 DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), NewLD.getValue(1));
21023 ++OpsNarrowed;
21024 return NewST;
21025 }
21026
21027 return SDValue();
21028 }
21029
21030 /// For a given floating point load / store pair, if the load value isn't used
21031 /// by any other operations, then consider transforming the pair to integer
21032 /// load / store operations if the target deems the transformation profitable.
TransformFPLoadStorePair(SDNode * N)21033 SDValue DAGCombiner::TransformFPLoadStorePair(SDNode *N) {
21034 StoreSDNode *ST = cast<StoreSDNode>(N);
21035 SDValue Value = ST->getValue();
21036 if (ISD::isNormalStore(ST) && ISD::isNormalLoad(Value.getNode()) &&
21037 Value.hasOneUse()) {
21038 LoadSDNode *LD = cast<LoadSDNode>(Value);
21039 EVT VT = LD->getMemoryVT();
21040 if (!VT.isSimple() || !VT.isFloatingPoint() || VT != ST->getMemoryVT() ||
21041 LD->isNonTemporal() || ST->isNonTemporal() ||
21042 LD->getPointerInfo().getAddrSpace() != 0 ||
21043 ST->getPointerInfo().getAddrSpace() != 0)
21044 return SDValue();
21045
21046 TypeSize VTSize = VT.getSizeInBits();
21047
21048 // We don't know the size of scalable types at compile time so we cannot
21049 // create an integer of the equivalent size.
21050 if (VTSize.isScalable())
21051 return SDValue();
21052
21053 unsigned FastLD = 0, FastST = 0;
21054 EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), VTSize.getFixedValue());
21055 if (!TLI.isOperationLegal(ISD::LOAD, IntVT) ||
21056 !TLI.isOperationLegal(ISD::STORE, IntVT) ||
21057 !TLI.isDesirableToTransformToIntegerOp(ISD::LOAD, VT) ||
21058 !TLI.isDesirableToTransformToIntegerOp(ISD::STORE, VT) ||
21059 !TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), IntVT,
21060 *LD->getMemOperand(), &FastLD) ||
21061 !TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), IntVT,
21062 *ST->getMemOperand(), &FastST) ||
21063 !FastLD || !FastST)
21064 return SDValue();
21065
21066 SDValue NewLD = DAG.getLoad(IntVT, SDLoc(Value), LD->getChain(),
21067 LD->getBasePtr(), LD->getMemOperand());
21068
21069 SDValue NewST = DAG.getStore(ST->getChain(), SDLoc(N), NewLD,
21070 ST->getBasePtr(), ST->getMemOperand());
21071
21072 AddToWorklist(NewLD.getNode());
21073 AddToWorklist(NewST.getNode());
21074 WorklistRemover DeadNodes(*this);
21075 DAG.ReplaceAllUsesOfValueWith(Value.getValue(1), NewLD.getValue(1));
21076 ++LdStFP2Int;
21077 return NewST;
21078 }
21079
21080 return SDValue();
21081 }
21082
21083 // This is a helper function for visitMUL to check the profitability
21084 // of folding (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2).
21085 // MulNode is the original multiply, AddNode is (add x, c1),
21086 // and ConstNode is c2.
21087 //
21088 // If the (add x, c1) has multiple uses, we could increase
21089 // the number of adds if we make this transformation.
21090 // It would only be worth doing this if we can remove a
21091 // multiply in the process. Check for that here.
21092 // To illustrate:
21093 // (A + c1) * c3
21094 // (A + c2) * c3
21095 // We're checking for cases where we have common "c3 * A" expressions.
isMulAddWithConstProfitable(SDNode * MulNode,SDValue AddNode,SDValue ConstNode)21096 bool DAGCombiner::isMulAddWithConstProfitable(SDNode *MulNode, SDValue AddNode,
21097 SDValue ConstNode) {
21098 // If the add only has one use, and the target thinks the folding is
21099 // profitable or does not lead to worse code, this would be OK to do.
21100 if (AddNode->hasOneUse() &&
21101 TLI.isMulAddWithConstProfitable(AddNode, ConstNode))
21102 return true;
21103
21104 // Walk all the users of the constant with which we're multiplying.
21105 for (SDNode *User : ConstNode->users()) {
21106 if (User == MulNode) // This use is the one we're on right now. Skip it.
21107 continue;
21108
21109 if (User->getOpcode() == ISD::MUL) { // We have another multiply use.
21110 SDNode *OtherOp;
21111 SDNode *MulVar = AddNode.getOperand(0).getNode();
21112
21113 // OtherOp is what we're multiplying against the constant.
21114 if (User->getOperand(0) == ConstNode)
21115 OtherOp = User->getOperand(1).getNode();
21116 else
21117 OtherOp = User->getOperand(0).getNode();
21118
21119 // Check to see if multiply is with the same operand of our "add".
21120 //
21121 // ConstNode = CONST
21122 // User = ConstNode * A <-- visiting User. OtherOp is A.
21123 // ...
21124 // AddNode = (A + c1) <-- MulVar is A.
21125 // = AddNode * ConstNode <-- current visiting instruction.
21126 //
21127 // If we make this transformation, we will have a common
21128 // multiply (ConstNode * A) that we can save.
21129 if (OtherOp == MulVar)
21130 return true;
21131
21132 // Now check to see if a future expansion will give us a common
21133 // multiply.
21134 //
21135 // ConstNode = CONST
21136 // AddNode = (A + c1)
21137 // ... = AddNode * ConstNode <-- current visiting instruction.
21138 // ...
21139 // OtherOp = (A + c2)
21140 // User = OtherOp * ConstNode <-- visiting User.
21141 //
21142 // If we make this transformation, we will have a common
21143 // multiply (CONST * A) after we also do the same transformation
21144 // to the "t2" instruction.
21145 if (OtherOp->getOpcode() == ISD::ADD &&
21146 DAG.isConstantIntBuildVectorOrConstantInt(OtherOp->getOperand(1)) &&
21147 OtherOp->getOperand(0).getNode() == MulVar)
21148 return true;
21149 }
21150 }
21151
21152 // Didn't find a case where this would be profitable.
21153 return false;
21154 }
21155
getMergeStoreChains(SmallVectorImpl<MemOpLink> & StoreNodes,unsigned NumStores)21156 SDValue DAGCombiner::getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes,
21157 unsigned NumStores) {
21158 SmallVector<SDValue, 8> Chains;
21159 SmallPtrSet<const SDNode *, 8> Visited;
21160 SDLoc StoreDL(StoreNodes[0].MemNode);
21161
21162 for (unsigned i = 0; i < NumStores; ++i) {
21163 Visited.insert(StoreNodes[i].MemNode);
21164 }
21165
21166 // don't include nodes that are children or repeated nodes.
21167 for (unsigned i = 0; i < NumStores; ++i) {
21168 if (Visited.insert(StoreNodes[i].MemNode->getChain().getNode()).second)
21169 Chains.push_back(StoreNodes[i].MemNode->getChain());
21170 }
21171
21172 assert(!Chains.empty() && "Chain should have generated a chain");
21173 return DAG.getTokenFactor(StoreDL, Chains);
21174 }
21175
hasSameUnderlyingObj(ArrayRef<MemOpLink> StoreNodes)21176 bool DAGCombiner::hasSameUnderlyingObj(ArrayRef<MemOpLink> StoreNodes) {
21177 const Value *UnderlyingObj = nullptr;
21178 for (const auto &MemOp : StoreNodes) {
21179 const MachineMemOperand *MMO = MemOp.MemNode->getMemOperand();
21180 // Pseudo value like stack frame has its own frame index and size, should
21181 // not use the first store's frame index for other frames.
21182 if (MMO->getPseudoValue())
21183 return false;
21184
21185 if (!MMO->getValue())
21186 return false;
21187
21188 const Value *Obj = getUnderlyingObject(MMO->getValue());
21189
21190 if (UnderlyingObj && UnderlyingObj != Obj)
21191 return false;
21192
21193 if (!UnderlyingObj)
21194 UnderlyingObj = Obj;
21195 }
21196
21197 return true;
21198 }
21199
mergeStoresOfConstantsOrVecElts(SmallVectorImpl<MemOpLink> & StoreNodes,EVT MemVT,unsigned NumStores,bool IsConstantSrc,bool UseVector,bool UseTrunc)21200 bool DAGCombiner::mergeStoresOfConstantsOrVecElts(
21201 SmallVectorImpl<MemOpLink> &StoreNodes, EVT MemVT, unsigned NumStores,
21202 bool IsConstantSrc, bool UseVector, bool UseTrunc) {
21203 // Make sure we have something to merge.
21204 if (NumStores < 2)
21205 return false;
21206
21207 assert((!UseTrunc || !UseVector) &&
21208 "This optimization cannot emit a vector truncating store");
21209
21210 // The latest Node in the DAG.
21211 SDLoc DL(StoreNodes[0].MemNode);
21212
21213 TypeSize ElementSizeBits = MemVT.getStoreSizeInBits();
21214 unsigned SizeInBits = NumStores * ElementSizeBits;
21215 unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
21216
21217 std::optional<MachineMemOperand::Flags> Flags;
21218 AAMDNodes AAInfo;
21219 for (unsigned I = 0; I != NumStores; ++I) {
21220 StoreSDNode *St = cast<StoreSDNode>(StoreNodes[I].MemNode);
21221 if (!Flags) {
21222 Flags = St->getMemOperand()->getFlags();
21223 AAInfo = St->getAAInfo();
21224 continue;
21225 }
21226 // Skip merging if there's an inconsistent flag.
21227 if (Flags != St->getMemOperand()->getFlags())
21228 return false;
21229 // Concatenate AA metadata.
21230 AAInfo = AAInfo.concat(St->getAAInfo());
21231 }
21232
21233 EVT StoreTy;
21234 if (UseVector) {
21235 unsigned Elts = NumStores * NumMemElts;
21236 // Get the type for the merged vector store.
21237 StoreTy = EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts);
21238 } else
21239 StoreTy = EVT::getIntegerVT(*DAG.getContext(), SizeInBits);
21240
21241 SDValue StoredVal;
21242 if (UseVector) {
21243 if (IsConstantSrc) {
21244 SmallVector<SDValue, 8> BuildVector;
21245 for (unsigned I = 0; I != NumStores; ++I) {
21246 StoreSDNode *St = cast<StoreSDNode>(StoreNodes[I].MemNode);
21247 SDValue Val = St->getValue();
21248 // If constant is of the wrong type, convert it now. This comes up
21249 // when one of our stores was truncating.
21250 if (MemVT != Val.getValueType()) {
21251 Val = peekThroughBitcasts(Val);
21252 // Deal with constants of wrong size.
21253 if (ElementSizeBits != Val.getValueSizeInBits()) {
21254 auto *C = dyn_cast<ConstantSDNode>(Val);
21255 if (!C)
21256 // Not clear how to truncate FP values.
21257 // TODO: Handle truncation of build_vector constants
21258 return false;
21259
21260 EVT IntMemVT =
21261 EVT::getIntegerVT(*DAG.getContext(), MemVT.getSizeInBits());
21262 Val = DAG.getConstant(C->getAPIntValue()
21263 .zextOrTrunc(Val.getValueSizeInBits())
21264 .zextOrTrunc(ElementSizeBits),
21265 SDLoc(C), IntMemVT);
21266 }
21267 // Make sure correctly size type is the correct type.
21268 Val = DAG.getBitcast(MemVT, Val);
21269 }
21270 BuildVector.push_back(Val);
21271 }
21272 StoredVal = DAG.getNode(MemVT.isVector() ? ISD::CONCAT_VECTORS
21273 : ISD::BUILD_VECTOR,
21274 DL, StoreTy, BuildVector);
21275 } else {
21276 SmallVector<SDValue, 8> Ops;
21277 for (unsigned i = 0; i < NumStores; ++i) {
21278 StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
21279 SDValue Val = peekThroughBitcasts(St->getValue());
21280 // All operands of BUILD_VECTOR / CONCAT_VECTOR must be of
21281 // type MemVT. If the underlying value is not the correct
21282 // type, but it is an extraction of an appropriate vector we
21283 // can recast Val to be of the correct type. This may require
21284 // converting between EXTRACT_VECTOR_ELT and
21285 // EXTRACT_SUBVECTOR.
21286 if ((MemVT != Val.getValueType()) &&
21287 (Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT ||
21288 Val.getOpcode() == ISD::EXTRACT_SUBVECTOR)) {
21289 EVT MemVTScalarTy = MemVT.getScalarType();
21290 // We may need to add a bitcast here to get types to line up.
21291 if (MemVTScalarTy != Val.getValueType().getScalarType()) {
21292 Val = DAG.getBitcast(MemVT, Val);
21293 } else if (MemVT.isVector() &&
21294 Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
21295 Val = DAG.getNode(ISD::BUILD_VECTOR, DL, MemVT, Val);
21296 } else {
21297 unsigned OpC = MemVT.isVector() ? ISD::EXTRACT_SUBVECTOR
21298 : ISD::EXTRACT_VECTOR_ELT;
21299 SDValue Vec = Val.getOperand(0);
21300 SDValue Idx = Val.getOperand(1);
21301 Val = DAG.getNode(OpC, SDLoc(Val), MemVT, Vec, Idx);
21302 }
21303 }
21304 Ops.push_back(Val);
21305 }
21306
21307 // Build the extracted vector elements back into a vector.
21308 StoredVal = DAG.getNode(MemVT.isVector() ? ISD::CONCAT_VECTORS
21309 : ISD::BUILD_VECTOR,
21310 DL, StoreTy, Ops);
21311 }
21312 } else {
21313 // We should always use a vector store when merging extracted vector
21314 // elements, so this path implies a store of constants.
21315 assert(IsConstantSrc && "Merged vector elements should use vector store");
21316
21317 APInt StoreInt(SizeInBits, 0);
21318
21319 // Construct a single integer constant which is made of the smaller
21320 // constant inputs.
21321 bool IsLE = DAG.getDataLayout().isLittleEndian();
21322 for (unsigned i = 0; i < NumStores; ++i) {
21323 unsigned Idx = IsLE ? (NumStores - 1 - i) : i;
21324 StoreSDNode *St = cast<StoreSDNode>(StoreNodes[Idx].MemNode);
21325
21326 SDValue Val = St->getValue();
21327 Val = peekThroughBitcasts(Val);
21328 StoreInt <<= ElementSizeBits;
21329 if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val)) {
21330 StoreInt |= C->getAPIntValue()
21331 .zextOrTrunc(ElementSizeBits)
21332 .zextOrTrunc(SizeInBits);
21333 } else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(Val)) {
21334 StoreInt |= C->getValueAPF()
21335 .bitcastToAPInt()
21336 .zextOrTrunc(ElementSizeBits)
21337 .zextOrTrunc(SizeInBits);
21338 // If fp truncation is necessary give up for now.
21339 if (MemVT.getSizeInBits() != ElementSizeBits)
21340 return false;
21341 } else if (ISD::isBuildVectorOfConstantSDNodes(Val.getNode()) ||
21342 ISD::isBuildVectorOfConstantFPSDNodes(Val.getNode())) {
21343 // Not yet handled
21344 return false;
21345 } else {
21346 llvm_unreachable("Invalid constant element type");
21347 }
21348 }
21349
21350 // Create the new Load and Store operations.
21351 StoredVal = DAG.getConstant(StoreInt, DL, StoreTy);
21352 }
21353
21354 LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
21355 SDValue NewChain = getMergeStoreChains(StoreNodes, NumStores);
21356 bool CanReusePtrInfo = hasSameUnderlyingObj(StoreNodes);
21357
21358 // make sure we use trunc store if it's necessary to be legal.
21359 // When generate the new widen store, if the first store's pointer info can
21360 // not be reused, discard the pointer info except the address space because
21361 // now the widen store can not be represented by the original pointer info
21362 // which is for the narrow memory object.
21363 SDValue NewStore;
21364 if (!UseTrunc) {
21365 NewStore = DAG.getStore(
21366 NewChain, DL, StoredVal, FirstInChain->getBasePtr(),
21367 CanReusePtrInfo
21368 ? FirstInChain->getPointerInfo()
21369 : MachinePointerInfo(FirstInChain->getPointerInfo().getAddrSpace()),
21370 FirstInChain->getAlign(), *Flags, AAInfo);
21371 } else { // Must be realized as a trunc store
21372 EVT LegalizedStoredValTy =
21373 TLI.getTypeToTransformTo(*DAG.getContext(), StoredVal.getValueType());
21374 unsigned LegalizedStoreSize = LegalizedStoredValTy.getSizeInBits();
21375 ConstantSDNode *C = cast<ConstantSDNode>(StoredVal);
21376 SDValue ExtendedStoreVal =
21377 DAG.getConstant(C->getAPIntValue().zextOrTrunc(LegalizedStoreSize), DL,
21378 LegalizedStoredValTy);
21379 NewStore = DAG.getTruncStore(
21380 NewChain, DL, ExtendedStoreVal, FirstInChain->getBasePtr(),
21381 CanReusePtrInfo
21382 ? FirstInChain->getPointerInfo()
21383 : MachinePointerInfo(FirstInChain->getPointerInfo().getAddrSpace()),
21384 StoredVal.getValueType() /*TVT*/, FirstInChain->getAlign(), *Flags,
21385 AAInfo);
21386 }
21387
21388 // Replace all merged stores with the new store.
21389 for (unsigned i = 0; i < NumStores; ++i)
21390 CombineTo(StoreNodes[i].MemNode, NewStore);
21391
21392 AddToWorklist(NewChain.getNode());
21393 return true;
21394 }
21395
21396 SDNode *
getStoreMergeCandidates(StoreSDNode * St,SmallVectorImpl<MemOpLink> & StoreNodes)21397 DAGCombiner::getStoreMergeCandidates(StoreSDNode *St,
21398 SmallVectorImpl<MemOpLink> &StoreNodes) {
21399 // This holds the base pointer, index, and the offset in bytes from the base
21400 // pointer. We must have a base and an offset. Do not handle stores to undef
21401 // base pointers.
21402 BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);
21403 if (!BasePtr.getBase().getNode() || BasePtr.getBase().isUndef())
21404 return nullptr;
21405
21406 SDValue Val = peekThroughBitcasts(St->getValue());
21407 StoreSource StoreSrc = getStoreSource(Val);
21408 assert(StoreSrc != StoreSource::Unknown && "Expected known source for store");
21409
21410 // Match on loadbaseptr if relevant.
21411 EVT MemVT = St->getMemoryVT();
21412 BaseIndexOffset LBasePtr;
21413 EVT LoadVT;
21414 if (StoreSrc == StoreSource::Load) {
21415 auto *Ld = cast<LoadSDNode>(Val);
21416 LBasePtr = BaseIndexOffset::match(Ld, DAG);
21417 LoadVT = Ld->getMemoryVT();
21418 // Load and store should be the same type.
21419 if (MemVT != LoadVT)
21420 return nullptr;
21421 // Loads must only have one use.
21422 if (!Ld->hasNUsesOfValue(1, 0))
21423 return nullptr;
21424 // The memory operands must not be volatile/indexed/atomic.
21425 // TODO: May be able to relax for unordered atomics (see D66309)
21426 if (!Ld->isSimple() || Ld->isIndexed())
21427 return nullptr;
21428 }
21429 auto CandidateMatch = [&](StoreSDNode *Other, BaseIndexOffset &Ptr,
21430 int64_t &Offset) -> bool {
21431 // The memory operands must not be volatile/indexed/atomic.
21432 // TODO: May be able to relax for unordered atomics (see D66309)
21433 if (!Other->isSimple() || Other->isIndexed())
21434 return false;
21435 // Don't mix temporal stores with non-temporal stores.
21436 if (St->isNonTemporal() != Other->isNonTemporal())
21437 return false;
21438 if (!TLI.areTwoSDNodeTargetMMOFlagsMergeable(*St, *Other))
21439 return false;
21440 SDValue OtherBC = peekThroughBitcasts(Other->getValue());
21441 // Allow merging constants of different types as integers.
21442 bool NoTypeMatch = (MemVT.isInteger()) ? !MemVT.bitsEq(Other->getMemoryVT())
21443 : Other->getMemoryVT() != MemVT;
21444 switch (StoreSrc) {
21445 case StoreSource::Load: {
21446 if (NoTypeMatch)
21447 return false;
21448 // The Load's Base Ptr must also match.
21449 auto *OtherLd = dyn_cast<LoadSDNode>(OtherBC);
21450 if (!OtherLd)
21451 return false;
21452 BaseIndexOffset LPtr = BaseIndexOffset::match(OtherLd, DAG);
21453 if (LoadVT != OtherLd->getMemoryVT())
21454 return false;
21455 // Loads must only have one use.
21456 if (!OtherLd->hasNUsesOfValue(1, 0))
21457 return false;
21458 // The memory operands must not be volatile/indexed/atomic.
21459 // TODO: May be able to relax for unordered atomics (see D66309)
21460 if (!OtherLd->isSimple() || OtherLd->isIndexed())
21461 return false;
21462 // Don't mix temporal loads with non-temporal loads.
21463 if (cast<LoadSDNode>(Val)->isNonTemporal() != OtherLd->isNonTemporal())
21464 return false;
21465 if (!TLI.areTwoSDNodeTargetMMOFlagsMergeable(*cast<LoadSDNode>(Val),
21466 *OtherLd))
21467 return false;
21468 if (!(LBasePtr.equalBaseIndex(LPtr, DAG)))
21469 return false;
21470 break;
21471 }
21472 case StoreSource::Constant:
21473 if (NoTypeMatch)
21474 return false;
21475 if (getStoreSource(OtherBC) != StoreSource::Constant)
21476 return false;
21477 break;
21478 case StoreSource::Extract:
21479 // Do not merge truncated stores here.
21480 if (Other->isTruncatingStore())
21481 return false;
21482 if (!MemVT.bitsEq(OtherBC.getValueType()))
21483 return false;
21484 if (OtherBC.getOpcode() != ISD::EXTRACT_VECTOR_ELT &&
21485 OtherBC.getOpcode() != ISD::EXTRACT_SUBVECTOR)
21486 return false;
21487 break;
21488 default:
21489 llvm_unreachable("Unhandled store source for merging");
21490 }
21491 Ptr = BaseIndexOffset::match(Other, DAG);
21492 return (BasePtr.equalBaseIndex(Ptr, DAG, Offset));
21493 };
21494
21495 // We are looking for a root node which is an ancestor to all mergable
21496 // stores. We search up through a load, to our root and then down
21497 // through all children. For instance we will find Store{1,2,3} if
21498 // St is Store1, Store2. or Store3 where the root is not a load
21499 // which always true for nonvolatile ops. TODO: Expand
21500 // the search to find all valid candidates through multiple layers of loads.
21501 //
21502 // Root
21503 // |-------|-------|
21504 // Load Load Store3
21505 // | |
21506 // Store1 Store2
21507 //
21508 // FIXME: We should be able to climb and
21509 // descend TokenFactors to find candidates as well.
21510
21511 SDNode *RootNode = St->getChain().getNode();
21512 // Bail out if we already analyzed this root node and found nothing.
21513 if (ChainsWithoutMergeableStores.contains(RootNode))
21514 return nullptr;
21515
21516 // Check if the pair of StoreNode and the RootNode already bail out many
21517 // times which is over the limit in dependence check.
21518 auto OverLimitInDependenceCheck = [&](SDNode *StoreNode,
21519 SDNode *RootNode) -> bool {
21520 auto RootCount = StoreRootCountMap.find(StoreNode);
21521 return RootCount != StoreRootCountMap.end() &&
21522 RootCount->second.first == RootNode &&
21523 RootCount->second.second > StoreMergeDependenceLimit;
21524 };
21525
21526 auto TryToAddCandidate = [&](SDUse &Use) {
21527 // This must be a chain use.
21528 if (Use.getOperandNo() != 0)
21529 return;
21530 if (auto *OtherStore = dyn_cast<StoreSDNode>(Use.getUser())) {
21531 BaseIndexOffset Ptr;
21532 int64_t PtrDiff;
21533 if (CandidateMatch(OtherStore, Ptr, PtrDiff) &&
21534 !OverLimitInDependenceCheck(OtherStore, RootNode))
21535 StoreNodes.push_back(MemOpLink(OtherStore, PtrDiff));
21536 }
21537 };
21538
21539 unsigned NumNodesExplored = 0;
21540 const unsigned MaxSearchNodes = 1024;
21541 if (auto *Ldn = dyn_cast<LoadSDNode>(RootNode)) {
21542 RootNode = Ldn->getChain().getNode();
21543 // Bail out if we already analyzed this root node and found nothing.
21544 if (ChainsWithoutMergeableStores.contains(RootNode))
21545 return nullptr;
21546 for (auto I = RootNode->use_begin(), E = RootNode->use_end();
21547 I != E && NumNodesExplored < MaxSearchNodes; ++I, ++NumNodesExplored) {
21548 SDNode *User = I->getUser();
21549 if (I->getOperandNo() == 0 && isa<LoadSDNode>(User)) { // walk down chain
21550 for (SDUse &U2 : User->uses())
21551 TryToAddCandidate(U2);
21552 }
21553 // Check stores that depend on the root (e.g. Store 3 in the chart above).
21554 if (I->getOperandNo() == 0 && isa<StoreSDNode>(User)) {
21555 TryToAddCandidate(*I);
21556 }
21557 }
21558 } else {
21559 for (auto I = RootNode->use_begin(), E = RootNode->use_end();
21560 I != E && NumNodesExplored < MaxSearchNodes; ++I, ++NumNodesExplored)
21561 TryToAddCandidate(*I);
21562 }
21563
21564 return RootNode;
21565 }
21566
21567 // We need to check that merging these stores does not cause a loop in the
21568 // DAG. Any store candidate may depend on another candidate indirectly through
21569 // its operands. Check in parallel by searching up from operands of candidates.
checkMergeStoreCandidatesForDependencies(SmallVectorImpl<MemOpLink> & StoreNodes,unsigned NumStores,SDNode * RootNode)21570 bool DAGCombiner::checkMergeStoreCandidatesForDependencies(
21571 SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores,
21572 SDNode *RootNode) {
21573 // FIXME: We should be able to truncate a full search of
21574 // predecessors by doing a BFS and keeping tabs the originating
21575 // stores from which worklist nodes come from in a similar way to
21576 // TokenFactor simplfication.
21577
21578 SmallPtrSet<const SDNode *, 32> Visited;
21579 SmallVector<const SDNode *, 8> Worklist;
21580
21581 // RootNode is a predecessor to all candidates so we need not search
21582 // past it. Add RootNode (peeking through TokenFactors). Do not count
21583 // these towards size check.
21584
21585 Worklist.push_back(RootNode);
21586 while (!Worklist.empty()) {
21587 auto N = Worklist.pop_back_val();
21588 if (!Visited.insert(N).second)
21589 continue; // Already present in Visited.
21590 if (N->getOpcode() == ISD::TokenFactor) {
21591 for (SDValue Op : N->ops())
21592 Worklist.push_back(Op.getNode());
21593 }
21594 }
21595
21596 // Don't count pruning nodes towards max.
21597 unsigned int Max = 1024 + Visited.size();
21598 // Search Ops of store candidates.
21599 for (unsigned i = 0; i < NumStores; ++i) {
21600 SDNode *N = StoreNodes[i].MemNode;
21601 // Of the 4 Store Operands:
21602 // * Chain (Op 0) -> We have already considered these
21603 // in candidate selection, but only by following the
21604 // chain dependencies. We could still have a chain
21605 // dependency to a load, that has a non-chain dep to
21606 // another load, that depends on a store, etc. So it is
21607 // possible to have dependencies that consist of a mix
21608 // of chain and non-chain deps, and we need to include
21609 // chain operands in the analysis here..
21610 // * Value (Op 1) -> Cycles may happen (e.g. through load chains)
21611 // * Address (Op 2) -> Merged addresses may only vary by a fixed constant,
21612 // but aren't necessarily fromt the same base node, so
21613 // cycles possible (e.g. via indexed store).
21614 // * (Op 3) -> Represents the pre or post-indexing offset (or undef for
21615 // non-indexed stores). Not constant on all targets (e.g. ARM)
21616 // and so can participate in a cycle.
21617 for (const SDValue &Op : N->op_values())
21618 Worklist.push_back(Op.getNode());
21619 }
21620 // Search through DAG. We can stop early if we find a store node.
21621 for (unsigned i = 0; i < NumStores; ++i)
21622 if (SDNode::hasPredecessorHelper(StoreNodes[i].MemNode, Visited, Worklist,
21623 Max)) {
21624 // If the searching bail out, record the StoreNode and RootNode in the
21625 // StoreRootCountMap. If we have seen the pair many times over a limit,
21626 // we won't add the StoreNode into StoreNodes set again.
21627 if (Visited.size() >= Max) {
21628 auto &RootCount = StoreRootCountMap[StoreNodes[i].MemNode];
21629 if (RootCount.first == RootNode)
21630 RootCount.second++;
21631 else
21632 RootCount = {RootNode, 1};
21633 }
21634 return false;
21635 }
21636 return true;
21637 }
21638
hasCallInLdStChain(StoreSDNode * St,LoadSDNode * Ld)21639 bool DAGCombiner::hasCallInLdStChain(StoreSDNode *St, LoadSDNode *Ld) {
21640 SmallPtrSet<const SDNode *, 32> Visited;
21641 SmallVector<std::pair<const SDNode *, bool>, 8> Worklist;
21642 Worklist.emplace_back(St->getChain().getNode(), false);
21643
21644 while (!Worklist.empty()) {
21645 auto [Node, FoundCall] = Worklist.pop_back_val();
21646 if (!Visited.insert(Node).second || Node->getNumOperands() == 0)
21647 continue;
21648
21649 switch (Node->getOpcode()) {
21650 case ISD::CALLSEQ_END:
21651 Worklist.emplace_back(Node->getOperand(0).getNode(), true);
21652 break;
21653 case ISD::TokenFactor:
21654 for (SDValue Op : Node->ops())
21655 Worklist.emplace_back(Op.getNode(), FoundCall);
21656 break;
21657 case ISD::LOAD:
21658 if (Node == Ld)
21659 return FoundCall;
21660 [[fallthrough]];
21661 default:
21662 assert(Node->getOperand(0).getValueType() == MVT::Other &&
21663 "Invalid chain type");
21664 Worklist.emplace_back(Node->getOperand(0).getNode(), FoundCall);
21665 break;
21666 }
21667 }
21668 return false;
21669 }
21670
21671 unsigned
getConsecutiveStores(SmallVectorImpl<MemOpLink> & StoreNodes,int64_t ElementSizeBytes) const21672 DAGCombiner::getConsecutiveStores(SmallVectorImpl<MemOpLink> &StoreNodes,
21673 int64_t ElementSizeBytes) const {
21674 while (true) {
21675 // Find a store past the width of the first store.
21676 size_t StartIdx = 0;
21677 while ((StartIdx + 1 < StoreNodes.size()) &&
21678 StoreNodes[StartIdx].OffsetFromBase + ElementSizeBytes !=
21679 StoreNodes[StartIdx + 1].OffsetFromBase)
21680 ++StartIdx;
21681
21682 // Bail if we don't have enough candidates to merge.
21683 if (StartIdx + 1 >= StoreNodes.size())
21684 return 0;
21685
21686 // Trim stores that overlapped with the first store.
21687 if (StartIdx)
21688 StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + StartIdx);
21689
21690 // Scan the memory operations on the chain and find the first
21691 // non-consecutive store memory address.
21692 unsigned NumConsecutiveStores = 1;
21693 int64_t StartAddress = StoreNodes[0].OffsetFromBase;
21694 // Check that the addresses are consecutive starting from the second
21695 // element in the list of stores.
21696 for (unsigned i = 1, e = StoreNodes.size(); i < e; ++i) {
21697 int64_t CurrAddress = StoreNodes[i].OffsetFromBase;
21698 if (CurrAddress - StartAddress != (ElementSizeBytes * i))
21699 break;
21700 NumConsecutiveStores = i + 1;
21701 }
21702 if (NumConsecutiveStores > 1)
21703 return NumConsecutiveStores;
21704
21705 // There are no consecutive stores at the start of the list.
21706 // Remove the first store and try again.
21707 StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + 1);
21708 }
21709 }
21710
tryStoreMergeOfConstants(SmallVectorImpl<MemOpLink> & StoreNodes,unsigned NumConsecutiveStores,EVT MemVT,SDNode * RootNode,bool AllowVectors)21711 bool DAGCombiner::tryStoreMergeOfConstants(
21712 SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumConsecutiveStores,
21713 EVT MemVT, SDNode *RootNode, bool AllowVectors) {
21714 LLVMContext &Context = *DAG.getContext();
21715 const DataLayout &DL = DAG.getDataLayout();
21716 int64_t ElementSizeBytes = MemVT.getStoreSize();
21717 unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
21718 bool MadeChange = false;
21719
21720 // Store the constants into memory as one consecutive store.
21721 while (NumConsecutiveStores >= 2) {
21722 LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
21723 unsigned FirstStoreAS = FirstInChain->getAddressSpace();
21724 Align FirstStoreAlign = FirstInChain->getAlign();
21725 unsigned LastLegalType = 1;
21726 unsigned LastLegalVectorType = 1;
21727 bool LastIntegerTrunc = false;
21728 bool NonZero = false;
21729 unsigned FirstZeroAfterNonZero = NumConsecutiveStores;
21730 for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
21731 StoreSDNode *ST = cast<StoreSDNode>(StoreNodes[i].MemNode);
21732 SDValue StoredVal = ST->getValue();
21733 bool IsElementZero = false;
21734 if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(StoredVal))
21735 IsElementZero = C->isZero();
21736 else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(StoredVal))
21737 IsElementZero = C->getConstantFPValue()->isNullValue();
21738 else if (ISD::isBuildVectorAllZeros(StoredVal.getNode()))
21739 IsElementZero = true;
21740 if (IsElementZero) {
21741 if (NonZero && FirstZeroAfterNonZero == NumConsecutiveStores)
21742 FirstZeroAfterNonZero = i;
21743 }
21744 NonZero |= !IsElementZero;
21745
21746 // Find a legal type for the constant store.
21747 unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8;
21748 EVT StoreTy = EVT::getIntegerVT(Context, SizeInBits);
21749 unsigned IsFast = 0;
21750
21751 // Break early when size is too large to be legal.
21752 if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits)
21753 break;
21754
21755 if (TLI.isTypeLegal(StoreTy) &&
21756 TLI.canMergeStoresTo(FirstStoreAS, StoreTy,
21757 DAG.getMachineFunction()) &&
21758 TLI.allowsMemoryAccess(Context, DL, StoreTy,
21759 *FirstInChain->getMemOperand(), &IsFast) &&
21760 IsFast) {
21761 LastIntegerTrunc = false;
21762 LastLegalType = i + 1;
21763 // Or check whether a truncstore is legal.
21764 } else if (TLI.getTypeAction(Context, StoreTy) ==
21765 TargetLowering::TypePromoteInteger) {
21766 EVT LegalizedStoredValTy =
21767 TLI.getTypeToTransformTo(Context, StoredVal.getValueType());
21768 if (TLI.isTruncStoreLegal(LegalizedStoredValTy, StoreTy) &&
21769 TLI.canMergeStoresTo(FirstStoreAS, LegalizedStoredValTy,
21770 DAG.getMachineFunction()) &&
21771 TLI.allowsMemoryAccess(Context, DL, StoreTy,
21772 *FirstInChain->getMemOperand(), &IsFast) &&
21773 IsFast) {
21774 LastIntegerTrunc = true;
21775 LastLegalType = i + 1;
21776 }
21777 }
21778
21779 // We only use vectors if the target allows it and the function is not
21780 // marked with the noimplicitfloat attribute.
21781 if (TLI.storeOfVectorConstantIsCheap(!NonZero, MemVT, i + 1, FirstStoreAS) &&
21782 AllowVectors) {
21783 // Find a legal type for the vector store.
21784 unsigned Elts = (i + 1) * NumMemElts;
21785 EVT Ty = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
21786 if (TLI.isTypeLegal(Ty) && TLI.isTypeLegal(MemVT) &&
21787 TLI.canMergeStoresTo(FirstStoreAS, Ty, DAG.getMachineFunction()) &&
21788 TLI.allowsMemoryAccess(Context, DL, Ty,
21789 *FirstInChain->getMemOperand(), &IsFast) &&
21790 IsFast)
21791 LastLegalVectorType = i + 1;
21792 }
21793 }
21794
21795 bool UseVector = (LastLegalVectorType > LastLegalType) && AllowVectors;
21796 unsigned NumElem = (UseVector) ? LastLegalVectorType : LastLegalType;
21797 bool UseTrunc = LastIntegerTrunc && !UseVector;
21798
21799 // Check if we found a legal integer type that creates a meaningful
21800 // merge.
21801 if (NumElem < 2) {
21802 // We know that candidate stores are in order and of correct
21803 // shape. While there is no mergeable sequence from the
21804 // beginning one may start later in the sequence. The only
21805 // reason a merge of size N could have failed where another of
21806 // the same size would not have, is if the alignment has
21807 // improved or we've dropped a non-zero value. Drop as many
21808 // candidates as we can here.
21809 unsigned NumSkip = 1;
21810 while ((NumSkip < NumConsecutiveStores) &&
21811 (NumSkip < FirstZeroAfterNonZero) &&
21812 (StoreNodes[NumSkip].MemNode->getAlign() <= FirstStoreAlign))
21813 NumSkip++;
21814
21815 StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip);
21816 NumConsecutiveStores -= NumSkip;
21817 continue;
21818 }
21819
21820 // Check that we can merge these candidates without causing a cycle.
21821 if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumElem,
21822 RootNode)) {
21823 StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
21824 NumConsecutiveStores -= NumElem;
21825 continue;
21826 }
21827
21828 MadeChange |= mergeStoresOfConstantsOrVecElts(StoreNodes, MemVT, NumElem,
21829 /*IsConstantSrc*/ true,
21830 UseVector, UseTrunc);
21831
21832 // Remove merged stores for next iteration.
21833 StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
21834 NumConsecutiveStores -= NumElem;
21835 }
21836 return MadeChange;
21837 }
21838
tryStoreMergeOfExtracts(SmallVectorImpl<MemOpLink> & StoreNodes,unsigned NumConsecutiveStores,EVT MemVT,SDNode * RootNode)21839 bool DAGCombiner::tryStoreMergeOfExtracts(
21840 SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumConsecutiveStores,
21841 EVT MemVT, SDNode *RootNode) {
21842 LLVMContext &Context = *DAG.getContext();
21843 const DataLayout &DL = DAG.getDataLayout();
21844 unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
21845 bool MadeChange = false;
21846
21847 // Loop on Consecutive Stores on success.
21848 while (NumConsecutiveStores >= 2) {
21849 LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
21850 unsigned FirstStoreAS = FirstInChain->getAddressSpace();
21851 Align FirstStoreAlign = FirstInChain->getAlign();
21852 unsigned NumStoresToMerge = 1;
21853 for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
21854 // Find a legal type for the vector store.
21855 unsigned Elts = (i + 1) * NumMemElts;
21856 EVT Ty = EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts);
21857 unsigned IsFast = 0;
21858
21859 // Break early when size is too large to be legal.
21860 if (Ty.getSizeInBits() > MaximumLegalStoreInBits)
21861 break;
21862
21863 if (TLI.isTypeLegal(Ty) &&
21864 TLI.canMergeStoresTo(FirstStoreAS, Ty, DAG.getMachineFunction()) &&
21865 TLI.allowsMemoryAccess(Context, DL, Ty,
21866 *FirstInChain->getMemOperand(), &IsFast) &&
21867 IsFast)
21868 NumStoresToMerge = i + 1;
21869 }
21870
21871 // Check if we found a legal integer type creating a meaningful
21872 // merge.
21873 if (NumStoresToMerge < 2) {
21874 // We know that candidate stores are in order and of correct
21875 // shape. While there is no mergeable sequence from the
21876 // beginning one may start later in the sequence. The only
21877 // reason a merge of size N could have failed where another of
21878 // the same size would not have, is if the alignment has
21879 // improved. Drop as many candidates as we can here.
21880 unsigned NumSkip = 1;
21881 while ((NumSkip < NumConsecutiveStores) &&
21882 (StoreNodes[NumSkip].MemNode->getAlign() <= FirstStoreAlign))
21883 NumSkip++;
21884
21885 StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip);
21886 NumConsecutiveStores -= NumSkip;
21887 continue;
21888 }
21889
21890 // Check that we can merge these candidates without causing a cycle.
21891 if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumStoresToMerge,
21892 RootNode)) {
21893 StoreNodes.erase(StoreNodes.begin(),
21894 StoreNodes.begin() + NumStoresToMerge);
21895 NumConsecutiveStores -= NumStoresToMerge;
21896 continue;
21897 }
21898
21899 MadeChange |= mergeStoresOfConstantsOrVecElts(
21900 StoreNodes, MemVT, NumStoresToMerge, /*IsConstantSrc*/ false,
21901 /*UseVector*/ true, /*UseTrunc*/ false);
21902
21903 StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumStoresToMerge);
21904 NumConsecutiveStores -= NumStoresToMerge;
21905 }
21906 return MadeChange;
21907 }
21908
tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> & StoreNodes,unsigned NumConsecutiveStores,EVT MemVT,SDNode * RootNode,bool AllowVectors,bool IsNonTemporalStore,bool IsNonTemporalLoad)21909 bool DAGCombiner::tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes,
21910 unsigned NumConsecutiveStores, EVT MemVT,
21911 SDNode *RootNode, bool AllowVectors,
21912 bool IsNonTemporalStore,
21913 bool IsNonTemporalLoad) {
21914 LLVMContext &Context = *DAG.getContext();
21915 const DataLayout &DL = DAG.getDataLayout();
21916 int64_t ElementSizeBytes = MemVT.getStoreSize();
21917 unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
21918 bool MadeChange = false;
21919
21920 // Look for load nodes which are used by the stored values.
21921 SmallVector<MemOpLink, 8> LoadNodes;
21922
21923 // Find acceptable loads. Loads need to have the same chain (token factor),
21924 // must not be zext, volatile, indexed, and they must be consecutive.
21925 BaseIndexOffset LdBasePtr;
21926
21927 for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
21928 StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
21929 SDValue Val = peekThroughBitcasts(St->getValue());
21930 LoadSDNode *Ld = cast<LoadSDNode>(Val);
21931
21932 BaseIndexOffset LdPtr = BaseIndexOffset::match(Ld, DAG);
21933 // If this is not the first ptr that we check.
21934 int64_t LdOffset = 0;
21935 if (LdBasePtr.getBase().getNode()) {
21936 // The base ptr must be the same.
21937 if (!LdBasePtr.equalBaseIndex(LdPtr, DAG, LdOffset))
21938 break;
21939 } else {
21940 // Check that all other base pointers are the same as this one.
21941 LdBasePtr = LdPtr;
21942 }
21943
21944 // We found a potential memory operand to merge.
21945 LoadNodes.push_back(MemOpLink(Ld, LdOffset));
21946 }
21947
21948 while (NumConsecutiveStores >= 2 && LoadNodes.size() >= 2) {
21949 Align RequiredAlignment;
21950 bool NeedRotate = false;
21951 if (LoadNodes.size() == 2) {
21952 // If we have load/store pair instructions and we only have two values,
21953 // don't bother merging.
21954 if (TLI.hasPairedLoad(MemVT, RequiredAlignment) &&
21955 StoreNodes[0].MemNode->getAlign() >= RequiredAlignment) {
21956 StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + 2);
21957 LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + 2);
21958 break;
21959 }
21960 // If the loads are reversed, see if we can rotate the halves into place.
21961 int64_t Offset0 = LoadNodes[0].OffsetFromBase;
21962 int64_t Offset1 = LoadNodes[1].OffsetFromBase;
21963 EVT PairVT = EVT::getIntegerVT(Context, ElementSizeBytes * 8 * 2);
21964 if (Offset0 - Offset1 == ElementSizeBytes &&
21965 (hasOperation(ISD::ROTL, PairVT) ||
21966 hasOperation(ISD::ROTR, PairVT))) {
21967 std::swap(LoadNodes[0], LoadNodes[1]);
21968 NeedRotate = true;
21969 }
21970 }
21971 LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
21972 unsigned FirstStoreAS = FirstInChain->getAddressSpace();
21973 Align FirstStoreAlign = FirstInChain->getAlign();
21974 LoadSDNode *FirstLoad = cast<LoadSDNode>(LoadNodes[0].MemNode);
21975
21976 // Scan the memory operations on the chain and find the first
21977 // non-consecutive load memory address. These variables hold the index in
21978 // the store node array.
21979
21980 unsigned LastConsecutiveLoad = 1;
21981
21982 // This variable refers to the size and not index in the array.
21983 unsigned LastLegalVectorType = 1;
21984 unsigned LastLegalIntegerType = 1;
21985 bool isDereferenceable = true;
21986 bool DoIntegerTruncate = false;
21987 int64_t StartAddress = LoadNodes[0].OffsetFromBase;
21988 SDValue LoadChain = FirstLoad->getChain();
21989 for (unsigned i = 1; i < LoadNodes.size(); ++i) {
21990 // All loads must share the same chain.
21991 if (LoadNodes[i].MemNode->getChain() != LoadChain)
21992 break;
21993
21994 int64_t CurrAddress = LoadNodes[i].OffsetFromBase;
21995 if (CurrAddress - StartAddress != (ElementSizeBytes * i))
21996 break;
21997 LastConsecutiveLoad = i;
21998
21999 if (isDereferenceable && !LoadNodes[i].MemNode->isDereferenceable())
22000 isDereferenceable = false;
22001
22002 // Find a legal type for the vector store.
22003 unsigned Elts = (i + 1) * NumMemElts;
22004 EVT StoreTy = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
22005
22006 // Break early when size is too large to be legal.
22007 if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits)
22008 break;
22009
22010 unsigned IsFastSt = 0;
22011 unsigned IsFastLd = 0;
22012 // Don't try vector types if we need a rotate. We may still fail the
22013 // legality checks for the integer type, but we can't handle the rotate
22014 // case with vectors.
22015 // FIXME: We could use a shuffle in place of the rotate.
22016 if (!NeedRotate && TLI.isTypeLegal(StoreTy) &&
22017 TLI.canMergeStoresTo(FirstStoreAS, StoreTy,
22018 DAG.getMachineFunction()) &&
22019 TLI.allowsMemoryAccess(Context, DL, StoreTy,
22020 *FirstInChain->getMemOperand(), &IsFastSt) &&
22021 IsFastSt &&
22022 TLI.allowsMemoryAccess(Context, DL, StoreTy,
22023 *FirstLoad->getMemOperand(), &IsFastLd) &&
22024 IsFastLd) {
22025 LastLegalVectorType = i + 1;
22026 }
22027
22028 // Find a legal type for the integer store.
22029 unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8;
22030 StoreTy = EVT::getIntegerVT(Context, SizeInBits);
22031 if (TLI.isTypeLegal(StoreTy) &&
22032 TLI.canMergeStoresTo(FirstStoreAS, StoreTy,
22033 DAG.getMachineFunction()) &&
22034 TLI.allowsMemoryAccess(Context, DL, StoreTy,
22035 *FirstInChain->getMemOperand(), &IsFastSt) &&
22036 IsFastSt &&
22037 TLI.allowsMemoryAccess(Context, DL, StoreTy,
22038 *FirstLoad->getMemOperand(), &IsFastLd) &&
22039 IsFastLd) {
22040 LastLegalIntegerType = i + 1;
22041 DoIntegerTruncate = false;
22042 // Or check whether a truncstore and extload is legal.
22043 } else if (TLI.getTypeAction(Context, StoreTy) ==
22044 TargetLowering::TypePromoteInteger) {
22045 EVT LegalizedStoredValTy = TLI.getTypeToTransformTo(Context, StoreTy);
22046 if (TLI.isTruncStoreLegal(LegalizedStoredValTy, StoreTy) &&
22047 TLI.canMergeStoresTo(FirstStoreAS, LegalizedStoredValTy,
22048 DAG.getMachineFunction()) &&
22049 TLI.isLoadExtLegal(ISD::ZEXTLOAD, LegalizedStoredValTy, StoreTy) &&
22050 TLI.isLoadExtLegal(ISD::SEXTLOAD, LegalizedStoredValTy, StoreTy) &&
22051 TLI.isLoadExtLegal(ISD::EXTLOAD, LegalizedStoredValTy, StoreTy) &&
22052 TLI.allowsMemoryAccess(Context, DL, StoreTy,
22053 *FirstInChain->getMemOperand(), &IsFastSt) &&
22054 IsFastSt &&
22055 TLI.allowsMemoryAccess(Context, DL, StoreTy,
22056 *FirstLoad->getMemOperand(), &IsFastLd) &&
22057 IsFastLd) {
22058 LastLegalIntegerType = i + 1;
22059 DoIntegerTruncate = true;
22060 }
22061 }
22062 }
22063
22064 // Only use vector types if the vector type is larger than the integer
22065 // type. If they are the same, use integers.
22066 bool UseVectorTy =
22067 LastLegalVectorType > LastLegalIntegerType && AllowVectors;
22068 unsigned LastLegalType =
22069 std::max(LastLegalVectorType, LastLegalIntegerType);
22070
22071 // We add +1 here because the LastXXX variables refer to location while
22072 // the NumElem refers to array/index size.
22073 unsigned NumElem = std::min(NumConsecutiveStores, LastConsecutiveLoad + 1);
22074 NumElem = std::min(LastLegalType, NumElem);
22075 Align FirstLoadAlign = FirstLoad->getAlign();
22076
22077 if (NumElem < 2) {
22078 // We know that candidate stores are in order and of correct
22079 // shape. While there is no mergeable sequence from the
22080 // beginning one may start later in the sequence. The only
22081 // reason a merge of size N could have failed where another of
22082 // the same size would not have is if the alignment or either
22083 // the load or store has improved. Drop as many candidates as we
22084 // can here.
22085 unsigned NumSkip = 1;
22086 while ((NumSkip < LoadNodes.size()) &&
22087 (LoadNodes[NumSkip].MemNode->getAlign() <= FirstLoadAlign) &&
22088 (StoreNodes[NumSkip].MemNode->getAlign() <= FirstStoreAlign))
22089 NumSkip++;
22090 StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip);
22091 LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumSkip);
22092 NumConsecutiveStores -= NumSkip;
22093 continue;
22094 }
22095
22096 // Check that we can merge these candidates without causing a cycle.
22097 if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumElem,
22098 RootNode)) {
22099 StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
22100 LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem);
22101 NumConsecutiveStores -= NumElem;
22102 continue;
22103 }
22104
22105 // Find if it is better to use vectors or integers to load and store
22106 // to memory.
22107 EVT JointMemOpVT;
22108 if (UseVectorTy) {
22109 // Find a legal type for the vector store.
22110 unsigned Elts = NumElem * NumMemElts;
22111 JointMemOpVT = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
22112 } else {
22113 unsigned SizeInBits = NumElem * ElementSizeBytes * 8;
22114 JointMemOpVT = EVT::getIntegerVT(Context, SizeInBits);
22115 }
22116
22117 // Check if there is a call in the load/store chain.
22118 if (!TLI.shouldMergeStoreOfLoadsOverCall(MemVT, JointMemOpVT) &&
22119 hasCallInLdStChain(cast<StoreSDNode>(StoreNodes[0].MemNode),
22120 cast<LoadSDNode>(LoadNodes[0].MemNode))) {
22121 StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
22122 LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem);
22123 NumConsecutiveStores -= NumElem;
22124 continue;
22125 }
22126
22127 SDLoc LoadDL(LoadNodes[0].MemNode);
22128 SDLoc StoreDL(StoreNodes[0].MemNode);
22129
22130 // The merged loads are required to have the same incoming chain, so
22131 // using the first's chain is acceptable.
22132
22133 SDValue NewStoreChain = getMergeStoreChains(StoreNodes, NumElem);
22134 bool CanReusePtrInfo = hasSameUnderlyingObj(StoreNodes);
22135 AddToWorklist(NewStoreChain.getNode());
22136
22137 MachineMemOperand::Flags LdMMOFlags =
22138 isDereferenceable ? MachineMemOperand::MODereferenceable
22139 : MachineMemOperand::MONone;
22140 if (IsNonTemporalLoad)
22141 LdMMOFlags |= MachineMemOperand::MONonTemporal;
22142
22143 LdMMOFlags |= TLI.getTargetMMOFlags(*FirstLoad);
22144
22145 MachineMemOperand::Flags StMMOFlags = IsNonTemporalStore
22146 ? MachineMemOperand::MONonTemporal
22147 : MachineMemOperand::MONone;
22148
22149 StMMOFlags |= TLI.getTargetMMOFlags(*StoreNodes[0].MemNode);
22150
22151 SDValue NewLoad, NewStore;
22152 if (UseVectorTy || !DoIntegerTruncate) {
22153 NewLoad = DAG.getLoad(
22154 JointMemOpVT, LoadDL, FirstLoad->getChain(), FirstLoad->getBasePtr(),
22155 FirstLoad->getPointerInfo(), FirstLoadAlign, LdMMOFlags);
22156 SDValue StoreOp = NewLoad;
22157 if (NeedRotate) {
22158 unsigned LoadWidth = ElementSizeBytes * 8 * 2;
22159 assert(JointMemOpVT == EVT::getIntegerVT(Context, LoadWidth) &&
22160 "Unexpected type for rotate-able load pair");
22161 SDValue RotAmt =
22162 DAG.getShiftAmountConstant(LoadWidth / 2, JointMemOpVT, LoadDL);
22163 // Target can convert to the identical ROTR if it does not have ROTL.
22164 StoreOp = DAG.getNode(ISD::ROTL, LoadDL, JointMemOpVT, NewLoad, RotAmt);
22165 }
22166 NewStore = DAG.getStore(
22167 NewStoreChain, StoreDL, StoreOp, FirstInChain->getBasePtr(),
22168 CanReusePtrInfo ? FirstInChain->getPointerInfo()
22169 : MachinePointerInfo(FirstStoreAS),
22170 FirstStoreAlign, StMMOFlags);
22171 } else { // This must be the truncstore/extload case
22172 EVT ExtendedTy =
22173 TLI.getTypeToTransformTo(*DAG.getContext(), JointMemOpVT);
22174 NewLoad = DAG.getExtLoad(ISD::EXTLOAD, LoadDL, ExtendedTy,
22175 FirstLoad->getChain(), FirstLoad->getBasePtr(),
22176 FirstLoad->getPointerInfo(), JointMemOpVT,
22177 FirstLoadAlign, LdMMOFlags);
22178 NewStore = DAG.getTruncStore(
22179 NewStoreChain, StoreDL, NewLoad, FirstInChain->getBasePtr(),
22180 CanReusePtrInfo ? FirstInChain->getPointerInfo()
22181 : MachinePointerInfo(FirstStoreAS),
22182 JointMemOpVT, FirstInChain->getAlign(),
22183 FirstInChain->getMemOperand()->getFlags());
22184 }
22185
22186 // Transfer chain users from old loads to the new load.
22187 for (unsigned i = 0; i < NumElem; ++i) {
22188 LoadSDNode *Ld = cast<LoadSDNode>(LoadNodes[i].MemNode);
22189 DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1),
22190 SDValue(NewLoad.getNode(), 1));
22191 }
22192
22193 // Replace all stores with the new store. Recursively remove corresponding
22194 // values if they are no longer used.
22195 for (unsigned i = 0; i < NumElem; ++i) {
22196 SDValue Val = StoreNodes[i].MemNode->getOperand(1);
22197 CombineTo(StoreNodes[i].MemNode, NewStore);
22198 if (Val->use_empty())
22199 recursivelyDeleteUnusedNodes(Val.getNode());
22200 }
22201
22202 MadeChange = true;
22203 StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
22204 LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem);
22205 NumConsecutiveStores -= NumElem;
22206 }
22207 return MadeChange;
22208 }
22209
mergeConsecutiveStores(StoreSDNode * St)22210 bool DAGCombiner::mergeConsecutiveStores(StoreSDNode *St) {
22211 if (OptLevel == CodeGenOptLevel::None || !EnableStoreMerging)
22212 return false;
22213
22214 // TODO: Extend this function to merge stores of scalable vectors.
22215 // (i.e. two <vscale x 8 x i8> stores can be merged to one <vscale x 16 x i8>
22216 // store since we know <vscale x 16 x i8> is exactly twice as large as
22217 // <vscale x 8 x i8>). Until then, bail out for scalable vectors.
22218 EVT MemVT = St->getMemoryVT();
22219 if (MemVT.isScalableVT())
22220 return false;
22221 if (!MemVT.isSimple() || MemVT.getSizeInBits() * 2 > MaximumLegalStoreInBits)
22222 return false;
22223
22224 // This function cannot currently deal with non-byte-sized memory sizes.
22225 int64_t ElementSizeBytes = MemVT.getStoreSize();
22226 if (ElementSizeBytes * 8 != (int64_t)MemVT.getSizeInBits())
22227 return false;
22228
22229 // Do not bother looking at stored values that are not constants, loads, or
22230 // extracted vector elements.
22231 SDValue StoredVal = peekThroughBitcasts(St->getValue());
22232 const StoreSource StoreSrc = getStoreSource(StoredVal);
22233 if (StoreSrc == StoreSource::Unknown)
22234 return false;
22235
22236 SmallVector<MemOpLink, 8> StoreNodes;
22237 // Find potential store merge candidates by searching through chain sub-DAG
22238 SDNode *RootNode = getStoreMergeCandidates(St, StoreNodes);
22239
22240 // Check if there is anything to merge.
22241 if (StoreNodes.size() < 2)
22242 return false;
22243
22244 // Sort the memory operands according to their distance from the
22245 // base pointer.
22246 llvm::sort(StoreNodes, [](MemOpLink LHS, MemOpLink RHS) {
22247 return LHS.OffsetFromBase < RHS.OffsetFromBase;
22248 });
22249
22250 bool AllowVectors = !DAG.getMachineFunction().getFunction().hasFnAttribute(
22251 Attribute::NoImplicitFloat);
22252 bool IsNonTemporalStore = St->isNonTemporal();
22253 bool IsNonTemporalLoad = StoreSrc == StoreSource::Load &&
22254 cast<LoadSDNode>(StoredVal)->isNonTemporal();
22255
22256 // Store Merge attempts to merge the lowest stores. This generally
22257 // works out as if successful, as the remaining stores are checked
22258 // after the first collection of stores is merged. However, in the
22259 // case that a non-mergeable store is found first, e.g., {p[-2],
22260 // p[0], p[1], p[2], p[3]}, we would fail and miss the subsequent
22261 // mergeable cases. To prevent this, we prune such stores from the
22262 // front of StoreNodes here.
22263 bool MadeChange = false;
22264 while (StoreNodes.size() > 1) {
22265 unsigned NumConsecutiveStores =
22266 getConsecutiveStores(StoreNodes, ElementSizeBytes);
22267 // There are no more stores in the list to examine.
22268 if (NumConsecutiveStores == 0)
22269 return MadeChange;
22270
22271 // We have at least 2 consecutive stores. Try to merge them.
22272 assert(NumConsecutiveStores >= 2 && "Expected at least 2 stores");
22273 switch (StoreSrc) {
22274 case StoreSource::Constant:
22275 MadeChange |= tryStoreMergeOfConstants(StoreNodes, NumConsecutiveStores,
22276 MemVT, RootNode, AllowVectors);
22277 break;
22278
22279 case StoreSource::Extract:
22280 MadeChange |= tryStoreMergeOfExtracts(StoreNodes, NumConsecutiveStores,
22281 MemVT, RootNode);
22282 break;
22283
22284 case StoreSource::Load:
22285 MadeChange |= tryStoreMergeOfLoads(StoreNodes, NumConsecutiveStores,
22286 MemVT, RootNode, AllowVectors,
22287 IsNonTemporalStore, IsNonTemporalLoad);
22288 break;
22289
22290 default:
22291 llvm_unreachable("Unhandled store source type");
22292 }
22293 }
22294
22295 // Remember if we failed to optimize, to save compile time.
22296 if (!MadeChange)
22297 ChainsWithoutMergeableStores.insert(RootNode);
22298
22299 return MadeChange;
22300 }
22301
replaceStoreChain(StoreSDNode * ST,SDValue BetterChain)22302 SDValue DAGCombiner::replaceStoreChain(StoreSDNode *ST, SDValue BetterChain) {
22303 SDLoc SL(ST);
22304 SDValue ReplStore;
22305
22306 // Replace the chain to avoid dependency.
22307 if (ST->isTruncatingStore()) {
22308 ReplStore = DAG.getTruncStore(BetterChain, SL, ST->getValue(),
22309 ST->getBasePtr(), ST->getMemoryVT(),
22310 ST->getMemOperand());
22311 } else {
22312 ReplStore = DAG.getStore(BetterChain, SL, ST->getValue(), ST->getBasePtr(),
22313 ST->getMemOperand());
22314 }
22315
22316 // Create token to keep both nodes around.
22317 SDValue Token = DAG.getNode(ISD::TokenFactor, SL,
22318 MVT::Other, ST->getChain(), ReplStore);
22319
22320 // Make sure the new and old chains are cleaned up.
22321 AddToWorklist(Token.getNode());
22322
22323 // Don't add users to work list.
22324 return CombineTo(ST, Token, false);
22325 }
22326
replaceStoreOfFPConstant(StoreSDNode * ST)22327 SDValue DAGCombiner::replaceStoreOfFPConstant(StoreSDNode *ST) {
22328 SDValue Value = ST->getValue();
22329 if (Value.getOpcode() == ISD::TargetConstantFP)
22330 return SDValue();
22331
22332 if (!ISD::isNormalStore(ST))
22333 return SDValue();
22334
22335 SDLoc DL(ST);
22336
22337 SDValue Chain = ST->getChain();
22338 SDValue Ptr = ST->getBasePtr();
22339
22340 const ConstantFPSDNode *CFP = cast<ConstantFPSDNode>(Value);
22341
22342 // NOTE: If the original store is volatile, this transform must not increase
22343 // the number of stores. For example, on x86-32 an f64 can be stored in one
22344 // processor operation but an i64 (which is not legal) requires two. So the
22345 // transform should not be done in this case.
22346
22347 SDValue Tmp;
22348 switch (CFP->getSimpleValueType(0).SimpleTy) {
22349 default:
22350 llvm_unreachable("Unknown FP type");
22351 case MVT::f16: // We don't do this for these yet.
22352 case MVT::bf16:
22353 case MVT::f80:
22354 case MVT::f128:
22355 case MVT::ppcf128:
22356 return SDValue();
22357 case MVT::f32:
22358 if ((isTypeLegal(MVT::i32) && !LegalOperations && ST->isSimple()) ||
22359 TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i32)) {
22360 Tmp = DAG.getConstant((uint32_t)CFP->getValueAPF().
22361 bitcastToAPInt().getZExtValue(), SDLoc(CFP),
22362 MVT::i32);
22363 return DAG.getStore(Chain, DL, Tmp, Ptr, ST->getMemOperand());
22364 }
22365
22366 return SDValue();
22367 case MVT::f64:
22368 if ((TLI.isTypeLegal(MVT::i64) && !LegalOperations &&
22369 ST->isSimple()) ||
22370 TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i64)) {
22371 Tmp = DAG.getConstant(CFP->getValueAPF().bitcastToAPInt().
22372 getZExtValue(), SDLoc(CFP), MVT::i64);
22373 return DAG.getStore(Chain, DL, Tmp,
22374 Ptr, ST->getMemOperand());
22375 }
22376
22377 if (ST->isSimple() && TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i32) &&
22378 !TLI.isFPImmLegal(CFP->getValueAPF(), MVT::f64)) {
22379 // Many FP stores are not made apparent until after legalize, e.g. for
22380 // argument passing. Since this is so common, custom legalize the
22381 // 64-bit integer store into two 32-bit stores.
22382 uint64_t Val = CFP->getValueAPF().bitcastToAPInt().getZExtValue();
22383 SDValue Lo = DAG.getConstant(Val & 0xFFFFFFFF, SDLoc(CFP), MVT::i32);
22384 SDValue Hi = DAG.getConstant(Val >> 32, SDLoc(CFP), MVT::i32);
22385 if (DAG.getDataLayout().isBigEndian())
22386 std::swap(Lo, Hi);
22387
22388 MachineMemOperand::Flags MMOFlags = ST->getMemOperand()->getFlags();
22389 AAMDNodes AAInfo = ST->getAAInfo();
22390
22391 SDValue St0 = DAG.getStore(Chain, DL, Lo, Ptr, ST->getPointerInfo(),
22392 ST->getBaseAlign(), MMOFlags, AAInfo);
22393 Ptr = DAG.getMemBasePlusOffset(Ptr, TypeSize::getFixed(4), DL);
22394 SDValue St1 = DAG.getStore(Chain, DL, Hi, Ptr,
22395 ST->getPointerInfo().getWithOffset(4),
22396 ST->getBaseAlign(), MMOFlags, AAInfo);
22397 return DAG.getNode(ISD::TokenFactor, DL, MVT::Other,
22398 St0, St1);
22399 }
22400
22401 return SDValue();
22402 }
22403 }
22404
22405 // (store (insert_vector_elt (load p), x, i), p) -> (store x, p+offset)
22406 //
22407 // If a store of a load with an element inserted into it has no other
22408 // uses in between the chain, then we can consider the vector store
22409 // dead and replace it with just the single scalar element store.
replaceStoreOfInsertLoad(StoreSDNode * ST)22410 SDValue DAGCombiner::replaceStoreOfInsertLoad(StoreSDNode *ST) {
22411 SDLoc DL(ST);
22412 SDValue Value = ST->getValue();
22413 SDValue Ptr = ST->getBasePtr();
22414 SDValue Chain = ST->getChain();
22415 if (Value.getOpcode() != ISD::INSERT_VECTOR_ELT || !Value.hasOneUse())
22416 return SDValue();
22417
22418 SDValue Elt = Value.getOperand(1);
22419 SDValue Idx = Value.getOperand(2);
22420
22421 // If the element isn't byte sized or is implicitly truncated then we can't
22422 // compute an offset.
22423 EVT EltVT = Elt.getValueType();
22424 if (!EltVT.isByteSized() ||
22425 EltVT != Value.getOperand(0).getValueType().getVectorElementType())
22426 return SDValue();
22427
22428 auto *Ld = dyn_cast<LoadSDNode>(Value.getOperand(0));
22429 if (!Ld || Ld->getBasePtr() != Ptr ||
22430 ST->getMemoryVT() != Ld->getMemoryVT() || !ST->isSimple() ||
22431 !ISD::isNormalStore(ST) ||
22432 Ld->getAddressSpace() != ST->getAddressSpace() ||
22433 !Chain.reachesChainWithoutSideEffects(SDValue(Ld, 1)))
22434 return SDValue();
22435
22436 unsigned IsFast;
22437 if (!TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(),
22438 Elt.getValueType(), ST->getAddressSpace(),
22439 ST->getAlign(), ST->getMemOperand()->getFlags(),
22440 &IsFast) ||
22441 !IsFast)
22442 return SDValue();
22443
22444 MachinePointerInfo PointerInfo(ST->getAddressSpace());
22445
22446 // If the offset is a known constant then try to recover the pointer
22447 // info
22448 SDValue NewPtr;
22449 if (auto *CIdx = dyn_cast<ConstantSDNode>(Idx)) {
22450 unsigned COffset = CIdx->getSExtValue() * EltVT.getSizeInBits() / 8;
22451 NewPtr = DAG.getMemBasePlusOffset(Ptr, TypeSize::getFixed(COffset), DL);
22452 PointerInfo = ST->getPointerInfo().getWithOffset(COffset);
22453 } else {
22454 NewPtr = TLI.getVectorElementPointer(DAG, Ptr, Value.getValueType(), Idx);
22455 }
22456
22457 return DAG.getStore(Chain, DL, Elt, NewPtr, PointerInfo, ST->getAlign(),
22458 ST->getMemOperand()->getFlags());
22459 }
22460
visitATOMIC_STORE(SDNode * N)22461 SDValue DAGCombiner::visitATOMIC_STORE(SDNode *N) {
22462 AtomicSDNode *ST = cast<AtomicSDNode>(N);
22463 SDValue Val = ST->getVal();
22464 EVT VT = Val.getValueType();
22465 EVT MemVT = ST->getMemoryVT();
22466
22467 if (MemVT.bitsLT(VT)) { // Is truncating store
22468 APInt TruncDemandedBits = APInt::getLowBitsSet(VT.getScalarSizeInBits(),
22469 MemVT.getScalarSizeInBits());
22470 // See if we can simplify the operation with SimplifyDemandedBits, which
22471 // only works if the value has a single use.
22472 if (SimplifyDemandedBits(Val, TruncDemandedBits))
22473 return SDValue(N, 0);
22474 }
22475
22476 return SDValue();
22477 }
22478
/// Combine a STORE node.
///
/// Runs a fixed sequence of store simplifications on \p N and returns as soon
/// as one of them applies. The order of the attempts is significant: several
/// of them mutate the DAG in place, and the FP-constant rewrite is
/// deliberately tried only after store merging (see the comment at that
/// step).
///
/// \returns a replacement value (a new store, a TokenFactor, or just the
/// incoming chain when the store is dead), SDValue(N, 0) when \p N or other
/// nodes were updated in place, or otherwise whatever
/// ReduceLoadOpStoreWidth(N) returns.
SDValue DAGCombiner::visitSTORE(SDNode *N) {
  StoreSDNode *ST = cast<StoreSDNode>(N);
  SDValue Chain = ST->getChain();
  SDValue Value = ST->getValue();
  SDValue Ptr = ST->getBasePtr();

  // If this is a store of a bit convert, store the input value if the
  // resultant store does not need a higher alignment than the original.
  if (Value.getOpcode() == ISD::BITCAST && !ST->isTruncatingStore() &&
      ST->isUnindexed()) {
    EVT SVT = Value.getOperand(0).getValueType();
    // If the store is volatile, we only want to change the store type if the
    // resulting store is legal. Otherwise we might increase the number of
    // memory accesses. We don't care if the original type was legal or not
    // as we assume software couldn't rely on the number of accesses of an
    // illegal type.
    // TODO: May be able to relax for unordered atomics (see D66309)
    if (((!LegalOperations && ST->isSimple()) ||
         TLI.isOperationLegal(ISD::STORE, SVT)) &&
        TLI.isStoreBitCastBeneficial(Value.getValueType(), SVT,
                                     DAG, *ST->getMemOperand())) {
      return DAG.getStore(Chain, SDLoc(N), Value.getOperand(0), Ptr,
                          ST->getMemOperand());
    }
  }

  // Turn 'store undef, Ptr' -> nothing.
  if (Value.isUndef() && ST->isUnindexed() && !ST->isVolatile())
    return Chain;

  // Try to infer better alignment information than the store already has.
  if (OptLevel != CodeGenOptLevel::None && ST->isUnindexed() &&
      !ST->isAtomic()) {
    if (MaybeAlign Alignment = DAG.InferPtrAlign(Ptr)) {
      // Only refine when the inferred alignment is strictly better and the
      // store's source-value offset is itself a multiple of it.
      if (*Alignment > ST->getAlign() &&
          isAligned(*Alignment, ST->getSrcValueOffset())) {
        SDValue NewStore =
            DAG.getTruncStore(Chain, SDLoc(N), Value, Ptr, ST->getPointerInfo(),
                              ST->getMemoryVT(), *Alignment,
                              ST->getMemOperand()->getFlags(), ST->getAAInfo());
        // NewStore will always be N as we are only refining the alignment
        assert(NewStore.getNode() == N);
        (void)NewStore;
      }
    }
  }

  // Try transforming a pair floating point load / store ops to integer
  // load / store ops.
  if (SDValue NewST = TransformFPLoadStorePair(N))
    return NewST;

  // Try transforming several stores into STORE (BSWAP).
  if (SDValue Store = mergeTruncStores(ST))
    return Store;

  if (ST->isUnindexed()) {
    // Walk up chain skipping non-aliasing memory nodes, on this store and any
    // adjacent stores.
    if (findBetterNeighborChains(ST)) {
      // replaceStoreChain uses CombineTo, which handled all of the worklist
      // manipulation. Return the original node to not do anything else.
      return SDValue(ST, 0);
    }
    // Re-read the chain from the store before continuing with the remaining
    // combines.
    Chain = ST->getChain();
  }

  // FIXME: is there such a thing as a truncating indexed store?
  // Opaque constants are excluded from the truncating-store simplifications
  // below.
  if (ST->isTruncatingStore() && ST->isUnindexed() &&
      Value.getValueType().isInteger() &&
      (!isa<ConstantSDNode>(Value) ||
       !cast<ConstantSDNode>(Value)->isOpaque())) {
    // Convert a truncating store of an extension into a standard store.
    if ((Value.getOpcode() == ISD::ZERO_EXTEND ||
         Value.getOpcode() == ISD::SIGN_EXTEND ||
         Value.getOpcode() == ISD::ANY_EXTEND) &&
        Value.getOperand(0).getValueType() == ST->getMemoryVT() &&
        TLI.isOperationLegalOrCustom(ISD::STORE, ST->getMemoryVT()))
      return DAG.getStore(Chain, SDLoc(N), Value.getOperand(0), Ptr,
                          ST->getMemOperand());

    // Only the low bits that survive the truncation to the memory type are
    // demanded from the stored value.
    APInt TruncDemandedBits =
        APInt::getLowBitsSet(Value.getScalarValueSizeInBits(),
                             ST->getMemoryVT().getScalarSizeInBits());

    // See if we can simplify the operation with SimplifyDemandedBits, which
    // only works if the value has a single use.
    AddToWorklist(Value.getNode());
    if (SimplifyDemandedBits(Value, TruncDemandedBits)) {
      // Re-visit the store if anything changed and the store hasn't been merged
      // with another node (N is deleted) SimplifyDemandedBits will add Value's
      // node back to the worklist if necessary, but we also need to re-visit
      // the Store node itself.
      if (N->getOpcode() != ISD::DELETED_NODE)
        AddToWorklist(N);
      return SDValue(N, 0);
    }

    // Otherwise, see if we can simplify the input to this truncstore with
    // knowledge that only the low bits are being used. For example:
    // "truncstore (or (shl x, 8), y), i8" -> "truncstore y, i8"
    if (SDValue Shorter =
            TLI.SimplifyMultipleUseDemandedBits(Value, TruncDemandedBits, DAG))
      return DAG.getTruncStore(Chain, SDLoc(N), Shorter, Ptr, ST->getMemoryVT(),
                               ST->getMemOperand());

    // If we're storing a truncated constant, see if we can simplify it.
    // TODO: Move this to targetShrinkDemandedConstant?
    if (auto *Cst = dyn_cast<ConstantSDNode>(Value))
      if (!Cst->isOpaque()) {
        const APInt &CValue = Cst->getAPIntValue();
        // Clear the bits the truncating store would discard anyway.
        APInt NewVal = CValue & TruncDemandedBits;
        if (NewVal != CValue) {
          SDValue Shorter =
              DAG.getConstant(NewVal, SDLoc(N), Value.getValueType());
          return DAG.getTruncStore(Chain, SDLoc(N), Shorter, Ptr,
                                   ST->getMemoryVT(), ST->getMemOperand());
        }
      }
  }

  // If this is a load followed by a store to the same location, then the store
  // is dead/noop. Peek through any truncates if canCombineTruncStore failed.
  // TODO: Add big-endian truncate support with test coverage.
  // TODO: Can relax for unordered atomics (see D66309)
  SDValue TruncVal = DAG.getDataLayout().isLittleEndian()
                         ? peekThroughTruncates(Value)
                         : Value;
  if (auto *Ld = dyn_cast<LoadSDNode>(TruncVal)) {
    if (Ld->getBasePtr() == Ptr && ST->getMemoryVT() == Ld->getMemoryVT() &&
        ST->isUnindexed() && ST->isSimple() &&
        Ld->getAddressSpace() == ST->getAddressSpace() &&
        // There can't be any side effects between the load and store, such as
        // a call or store.
        Chain.reachesChainWithoutSideEffects(SDValue(Ld, 1))) {
      // The store is dead, remove it.
      return Chain;
    }
  }

  // Try scalarizing vector stores of loads where we only change one element
  if (SDValue NewST = replaceStoreOfInsertLoad(ST))
    return NewST;

  // Look at the store that immediately precedes this one on the chain.
  // TODO: Can relax for unordered atomics (see D66309)
  if (StoreSDNode *ST1 = dyn_cast<StoreSDNode>(Chain)) {
    if (ST->isUnindexed() && ST->isSimple() &&
        ST1->isUnindexed() && ST1->isSimple()) {
      if (OptLevel != CodeGenOptLevel::None && ST1->getBasePtr() == Ptr &&
          ST1->getValue() == Value && ST->getMemoryVT() == ST1->getMemoryVT() &&
          ST->getAddressSpace() == ST1->getAddressSpace()) {
        // If this is a store followed by a store with the same value to the
        // same location, then the store is dead/noop.
        return Chain;
      }

      if (OptLevel != CodeGenOptLevel::None && ST1->hasOneUse() &&
          !ST1->getBasePtr().isUndef() &&
          ST->getAddressSpace() == ST1->getAddressSpace()) {
        // If we consider two stores and one smaller in size is a scalable
        // vector type and another one a bigger size store with a fixed type,
        // then we could not allow the scalable store removal because we don't
        // know its final size in the end.
        if (ST->getMemoryVT().isScalableVector() ||
            ST1->getMemoryVT().isScalableVector()) {
          // With a scalable type involved, only drop the earlier store when
          // it has the same base pointer and is known to be no larger than
          // this one.
          if (ST1->getBasePtr() == Ptr &&
              TypeSize::isKnownLE(ST1->getMemoryVT().getStoreSize(),
                                  ST->getMemoryVT().getStoreSize())) {
            CombineTo(ST1, ST1->getChain());
            return SDValue(N, 0);
          }
        } else {
          const BaseIndexOffset STBase = BaseIndexOffset::match(ST, DAG);
          const BaseIndexOffset ChainBase = BaseIndexOffset::match(ST1, DAG);
          // If the preceding store writes to a subset of the location written
          // by this store, and no other node is chained to that preceding
          // store, we can effectively drop the preceding store. Do not remove
          // stores to undef as they may be used as data sinks.
          if (STBase.contains(DAG, ST->getMemoryVT().getFixedSizeInBits(),
                              ChainBase,
                              ST1->getMemoryVT().getFixedSizeInBits())) {
            CombineTo(ST1, ST1->getChain());
            return SDValue(N, 0);
          }
        }
      }
    }
  }

  // If this is an FP_ROUND or TRUNC followed by a store, fold this into a
  // truncating store. We can do this even if this is already a truncstore.
  if ((Value.getOpcode() == ISD::FP_ROUND ||
       Value.getOpcode() == ISD::TRUNCATE) &&
      Value->hasOneUse() && ST->isUnindexed() &&
      TLI.canCombineTruncStore(Value.getOperand(0).getValueType(),
                               ST->getMemoryVT(), LegalOperations)) {
    return DAG.getTruncStore(Chain, SDLoc(N), Value.getOperand(0),
                             Ptr, ST->getMemoryVT(), ST->getMemOperand());
  }

  // Always perform this optimization before types are legal. If the target
  // prefers, also try this after legalization to catch stores that were created
  // by intrinsics or other nodes.
  if (!LegalTypes || (TLI.mergeStoresAfterLegalization(ST->getMemoryVT()))) {
    while (true) {
      // There can be multiple store sequences on the same chain.
      // Keep trying to merge store sequences until we are unable to do so
      // or until we merge the last store on the chain.
      bool Changed = mergeConsecutiveStores(ST);
      if (!Changed) break;
      // Return N as merge only uses CombineTo and no worklist clean
      // up is necessary.
      if (N->getOpcode() == ISD::DELETED_NODE || !isa<StoreSDNode>(N))
        return SDValue(N, 0);
    }
  }

  // Try transforming N to an indexed store.
  if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
    return SDValue(N, 0);

  // Turn 'store float 1.0, Ptr' -> 'store int 0x12345678, Ptr'
  //
  // Make sure to do this only after attempting to merge stores in order to
  // avoid changing the types of some subset of stores due to visit order,
  // preventing their merging.
  if (isa<ConstantFPSDNode>(ST->getValue())) {
    if (SDValue NewSt = replaceStoreOfFPConstant(ST))
      return NewSt;
  }

  // Try splitting a store of a value merged from two halves into two stores.
  if (SDValue NewSt = splitMergedValStore(ST))
    return NewSt;

  // Last resort: hand the node to ReduceLoadOpStoreWidth.
  return ReduceLoadOpStoreWidth(N);
}
22715
visitLIFETIME_END(SDNode * N)22716 SDValue DAGCombiner::visitLIFETIME_END(SDNode *N) {
22717 const auto *LifetimeEnd = cast<LifetimeSDNode>(N);
22718 if (!LifetimeEnd->hasOffset())
22719 return SDValue();
22720
22721 const BaseIndexOffset LifetimeEndBase(N->getOperand(1), SDValue(),
22722 LifetimeEnd->getOffset(), false);
22723
22724 // We walk up the chains to find stores.
22725 SmallVector<SDValue, 8> Chains = {N->getOperand(0)};
22726 while (!Chains.empty()) {
22727 SDValue Chain = Chains.pop_back_val();
22728 if (!Chain.hasOneUse())
22729 continue;
22730 switch (Chain.getOpcode()) {
22731 case ISD::TokenFactor:
22732 for (unsigned Nops = Chain.getNumOperands(); Nops;)
22733 Chains.push_back(Chain.getOperand(--Nops));
22734 break;
22735 case ISD::LIFETIME_START:
22736 case ISD::LIFETIME_END:
22737 // We can forward past any lifetime start/end that can be proven not to
22738 // alias the node.
22739 if (!mayAlias(Chain.getNode(), N))
22740 Chains.push_back(Chain.getOperand(0));
22741 break;
22742 case ISD::STORE: {
22743 StoreSDNode *ST = dyn_cast<StoreSDNode>(Chain);
22744 // TODO: Can relax for unordered atomics (see D66309)
22745 if (!ST->isSimple() || ST->isIndexed())
22746 continue;
22747 const TypeSize StoreSize = ST->getMemoryVT().getStoreSize();
22748 // The bounds of a scalable store are not known until runtime, so this
22749 // store cannot be elided.
22750 if (StoreSize.isScalable())
22751 continue;
22752 const BaseIndexOffset StoreBase = BaseIndexOffset::match(ST, DAG);
22753 // If we store purely within object bounds just before its lifetime ends,
22754 // we can remove the store.
22755 if (LifetimeEndBase.contains(DAG, LifetimeEnd->getSize() * 8, StoreBase,
22756 StoreSize.getFixedValue() * 8)) {
22757 LLVM_DEBUG(dbgs() << "\nRemoving store:"; StoreBase.dump();
22758 dbgs() << "\nwithin LIFETIME_END of : ";
22759 LifetimeEndBase.dump(); dbgs() << "\n");
22760 CombineTo(ST, ST->getChain());
22761 return SDValue(N, 0);
22762 }
22763 }
22764 }
22765 }
22766 return SDValue();
22767 }
22768
22769 /// For the instruction sequence of store below, F and I values
22770 /// are bundled together as an i64 value before being stored into memory.
/// Sometimes it is more efficient to generate separate stores for F and I,
22772 /// which can remove the bitwise instructions or sink them to colder places.
22773 ///
22774 /// (store (or (zext (bitcast F to i32) to i64),
22775 /// (shl (zext I to i64), 32)), addr) -->
22776 /// (store F, addr) and (store I, addr+4)
22777 ///
22778 /// Similarly, splitting for other merged store can also be beneficial, like:
22779 /// For pair of {i32, i32}, i64 store --> two i32 stores.
22780 /// For pair of {i32, i16}, i64 store --> two i32 stores.
22781 /// For pair of {i16, i16}, i32 store --> two i16 stores.
22782 /// For pair of {i16, i8}, i32 store --> two i16 stores.
22783 /// For pair of {i8, i8}, i16 store --> two i8 stores.
22784 ///
22785 /// We allow each target to determine specifically which kind of splitting is
22786 /// supported.
22787 ///
22788 /// The store patterns are commonly seen from the simple code snippet below
22789 /// if only std::make_pair(...) is sroa transformed before inlined into hoo.
22790 /// void goo(const std::pair<int, float> &);
22791 /// hoo() {
22792 /// ...
22793 /// goo(std::make_pair(tmp, ftmp));
22794 /// ...
22795 /// }
22796 ///
splitMergedValStore(StoreSDNode * ST)22797 SDValue DAGCombiner::splitMergedValStore(StoreSDNode *ST) {
22798 if (OptLevel == CodeGenOptLevel::None)
22799 return SDValue();
22800
22801 // Can't change the number of memory accesses for a volatile store or break
22802 // atomicity for an atomic one.
22803 if (!ST->isSimple())
22804 return SDValue();
22805
22806 SDValue Val = ST->getValue();
22807 SDLoc DL(ST);
22808
22809 // Match OR operand.
22810 if (!Val.getValueType().isScalarInteger() || Val.getOpcode() != ISD::OR)
22811 return SDValue();
22812
22813 // Match SHL operand and get Lower and Higher parts of Val.
22814 SDValue Op1 = Val.getOperand(0);
22815 SDValue Op2 = Val.getOperand(1);
22816 SDValue Lo, Hi;
22817 if (Op1.getOpcode() != ISD::SHL) {
22818 std::swap(Op1, Op2);
22819 if (Op1.getOpcode() != ISD::SHL)
22820 return SDValue();
22821 }
22822 Lo = Op2;
22823 Hi = Op1.getOperand(0);
22824 if (!Op1.hasOneUse())
22825 return SDValue();
22826
22827 // Match shift amount to HalfValBitSize.
22828 unsigned HalfValBitSize = Val.getValueSizeInBits() / 2;
22829 ConstantSDNode *ShAmt = dyn_cast<ConstantSDNode>(Op1.getOperand(1));
22830 if (!ShAmt || ShAmt->getAPIntValue() != HalfValBitSize)
22831 return SDValue();
22832
22833 // Lo and Hi are zero-extended from int with size less equal than 32
22834 // to i64.
22835 if (Lo.getOpcode() != ISD::ZERO_EXTEND || !Lo.hasOneUse() ||
22836 !Lo.getOperand(0).getValueType().isScalarInteger() ||
22837 Lo.getOperand(0).getValueSizeInBits() > HalfValBitSize ||
22838 Hi.getOpcode() != ISD::ZERO_EXTEND || !Hi.hasOneUse() ||
22839 !Hi.getOperand(0).getValueType().isScalarInteger() ||
22840 Hi.getOperand(0).getValueSizeInBits() > HalfValBitSize)
22841 return SDValue();
22842
22843 // Use the EVT of low and high parts before bitcast as the input
22844 // of target query.
22845 EVT LowTy = (Lo.getOperand(0).getOpcode() == ISD::BITCAST)
22846 ? Lo.getOperand(0).getValueType()
22847 : Lo.getValueType();
22848 EVT HighTy = (Hi.getOperand(0).getOpcode() == ISD::BITCAST)
22849 ? Hi.getOperand(0).getValueType()
22850 : Hi.getValueType();
22851 if (!TLI.isMultiStoresCheaperThanBitsMerge(LowTy, HighTy))
22852 return SDValue();
22853
22854 // Start to split store.
22855 MachineMemOperand::Flags MMOFlags = ST->getMemOperand()->getFlags();
22856 AAMDNodes AAInfo = ST->getAAInfo();
22857
22858 // Change the sizes of Lo and Hi's value types to HalfValBitSize.
22859 EVT VT = EVT::getIntegerVT(*DAG.getContext(), HalfValBitSize);
22860 Lo = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Lo.getOperand(0));
22861 Hi = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Hi.getOperand(0));
22862
22863 SDValue Chain = ST->getChain();
22864 SDValue Ptr = ST->getBasePtr();
22865 // Lower value store.
22866 SDValue St0 = DAG.getStore(Chain, DL, Lo, Ptr, ST->getPointerInfo(),
22867 ST->getBaseAlign(), MMOFlags, AAInfo);
22868 Ptr =
22869 DAG.getMemBasePlusOffset(Ptr, TypeSize::getFixed(HalfValBitSize / 8), DL);
22870 // Higher value store.
22871 SDValue St1 = DAG.getStore(
22872 St0, DL, Hi, Ptr, ST->getPointerInfo().getWithOffset(HalfValBitSize / 8),
22873 ST->getBaseAlign(), MMOFlags, AAInfo);
22874 return St1;
22875 }
22876
22877 // Merge an insertion into an existing shuffle:
22878 // (insert_vector_elt (vector_shuffle X, Y, Mask),
22879 // .(extract_vector_elt X, N), InsIndex)
22880 // --> (vector_shuffle X, Y, NewMask)
22881 // and variations where shuffle operands may be CONCAT_VECTORS.
mergeEltWithShuffle(SDValue & X,SDValue & Y,ArrayRef<int> Mask,SmallVectorImpl<int> & NewMask,SDValue Elt,unsigned InsIndex)22882 static bool mergeEltWithShuffle(SDValue &X, SDValue &Y, ArrayRef<int> Mask,
22883 SmallVectorImpl<int> &NewMask, SDValue Elt,
22884 unsigned InsIndex) {
22885 if (Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
22886 !isa<ConstantSDNode>(Elt.getOperand(1)))
22887 return false;
22888
22889 // Vec's operand 0 is using indices from 0 to N-1 and
22890 // operand 1 from N to 2N - 1, where N is the number of
22891 // elements in the vectors.
22892 SDValue InsertVal0 = Elt.getOperand(0);
22893 int ElementOffset = -1;
22894
22895 // We explore the inputs of the shuffle in order to see if we find the
22896 // source of the extract_vector_elt. If so, we can use it to modify the
22897 // shuffle rather than perform an insert_vector_elt.
22898 SmallVector<std::pair<int, SDValue>, 8> ArgWorkList;
22899 ArgWorkList.emplace_back(Mask.size(), Y);
22900 ArgWorkList.emplace_back(0, X);
22901
22902 while (!ArgWorkList.empty()) {
22903 int ArgOffset;
22904 SDValue ArgVal;
22905 std::tie(ArgOffset, ArgVal) = ArgWorkList.pop_back_val();
22906
22907 if (ArgVal == InsertVal0) {
22908 ElementOffset = ArgOffset;
22909 break;
22910 }
22911
22912 // Peek through concat_vector.
22913 if (ArgVal.getOpcode() == ISD::CONCAT_VECTORS) {
22914 int CurrentArgOffset =
22915 ArgOffset + ArgVal.getValueType().getVectorNumElements();
22916 int Step = ArgVal.getOperand(0).getValueType().getVectorNumElements();
22917 for (SDValue Op : reverse(ArgVal->ops())) {
22918 CurrentArgOffset -= Step;
22919 ArgWorkList.emplace_back(CurrentArgOffset, Op);
22920 }
22921
22922 // Make sure we went through all the elements and did not screw up index
22923 // computation.
22924 assert(CurrentArgOffset == ArgOffset);
22925 }
22926 }
22927
22928 // If we failed to find a match, see if we can replace an UNDEF shuffle
22929 // operand.
22930 if (ElementOffset == -1) {
22931 if (!Y.isUndef() || InsertVal0.getValueType() != Y.getValueType())
22932 return false;
22933 ElementOffset = Mask.size();
22934 Y = InsertVal0;
22935 }
22936
22937 NewMask.assign(Mask.begin(), Mask.end());
22938 NewMask[InsIndex] = ElementOffset + Elt.getConstantOperandVal(1);
22939 assert(NewMask[InsIndex] < (int)(2 * Mask.size()) && NewMask[InsIndex] >= 0 &&
22940 "NewMask[InsIndex] is out of bound");
22941 return true;
22942 }
22943
22944 // Merge an insertion into an existing shuffle:
22945 // (insert_vector_elt (vector_shuffle X, Y), (extract_vector_elt X, N),
22946 // InsIndex)
22947 // --> (vector_shuffle X, Y) and variations where shuffle operands may be
22948 // CONCAT_VECTORS.
mergeInsertEltWithShuffle(SDNode * N,unsigned InsIndex)22949 SDValue DAGCombiner::mergeInsertEltWithShuffle(SDNode *N, unsigned InsIndex) {
22950 assert(N->getOpcode() == ISD::INSERT_VECTOR_ELT &&
22951 "Expected extract_vector_elt");
22952 SDValue InsertVal = N->getOperand(1);
22953 SDValue Vec = N->getOperand(0);
22954
22955 auto *SVN = dyn_cast<ShuffleVectorSDNode>(Vec);
22956 if (!SVN || !Vec.hasOneUse())
22957 return SDValue();
22958
22959 ArrayRef<int> Mask = SVN->getMask();
22960 SDValue X = Vec.getOperand(0);
22961 SDValue Y = Vec.getOperand(1);
22962
22963 SmallVector<int, 16> NewMask(Mask);
22964 if (mergeEltWithShuffle(X, Y, Mask, NewMask, InsertVal, InsIndex)) {
22965 SDValue LegalShuffle = TLI.buildLegalVectorShuffle(
22966 Vec.getValueType(), SDLoc(N), X, Y, NewMask, DAG);
22967 if (LegalShuffle)
22968 return LegalShuffle;
22969 }
22970
22971 return SDValue();
22972 }
22973
22974 // Convert a disguised subvector insertion into a shuffle:
22975 // insert_vector_elt V, (bitcast X from vector type), IdxC -->
22976 // bitcast(shuffle (bitcast V), (extended X), Mask)
22977 // Note: We do not use an insert_subvector node because that requires a
22978 // legal subvector type.
combineInsertEltToShuffle(SDNode * N,unsigned InsIndex)22979 SDValue DAGCombiner::combineInsertEltToShuffle(SDNode *N, unsigned InsIndex) {
22980 assert(N->getOpcode() == ISD::INSERT_VECTOR_ELT &&
22981 "Expected extract_vector_elt");
22982 SDValue InsertVal = N->getOperand(1);
22983
22984 if (InsertVal.getOpcode() != ISD::BITCAST || !InsertVal.hasOneUse() ||
22985 !InsertVal.getOperand(0).getValueType().isVector())
22986 return SDValue();
22987
22988 SDValue SubVec = InsertVal.getOperand(0);
22989 SDValue DestVec = N->getOperand(0);
22990 EVT SubVecVT = SubVec.getValueType();
22991 EVT VT = DestVec.getValueType();
22992 unsigned NumSrcElts = SubVecVT.getVectorNumElements();
22993 // If the source only has a single vector element, the cost of creating adding
22994 // it to a vector is likely to exceed the cost of a insert_vector_elt.
22995 if (NumSrcElts == 1)
22996 return SDValue();
22997 unsigned ExtendRatio = VT.getSizeInBits() / SubVecVT.getSizeInBits();
22998 unsigned NumMaskVals = ExtendRatio * NumSrcElts;
22999
23000 // Step 1: Create a shuffle mask that implements this insert operation. The
23001 // vector that we are inserting into will be operand 0 of the shuffle, so
23002 // those elements are just 'i'. The inserted subvector is in the first
23003 // positions of operand 1 of the shuffle. Example:
23004 // insert v4i32 V, (v2i16 X), 2 --> shuffle v8i16 V', X', {0,1,2,3,8,9,6,7}
23005 SmallVector<int, 16> Mask(NumMaskVals);
23006 for (unsigned i = 0; i != NumMaskVals; ++i) {
23007 if (i / NumSrcElts == InsIndex)
23008 Mask[i] = (i % NumSrcElts) + NumMaskVals;
23009 else
23010 Mask[i] = i;
23011 }
23012
23013 // Bail out if the target can not handle the shuffle we want to create.
23014 EVT SubVecEltVT = SubVecVT.getVectorElementType();
23015 EVT ShufVT = EVT::getVectorVT(*DAG.getContext(), SubVecEltVT, NumMaskVals);
23016 if (!TLI.isShuffleMaskLegal(Mask, ShufVT))
23017 return SDValue();
23018
23019 // Step 2: Create a wide vector from the inserted source vector by appending
23020 // undefined elements. This is the same size as our destination vector.
23021 SDLoc DL(N);
23022 SmallVector<SDValue, 8> ConcatOps(ExtendRatio, DAG.getUNDEF(SubVecVT));
23023 ConcatOps[0] = SubVec;
23024 SDValue PaddedSubV = DAG.getNode(ISD::CONCAT_VECTORS, DL, ShufVT, ConcatOps);
23025
23026 // Step 3: Shuffle in the padded subvector.
23027 SDValue DestVecBC = DAG.getBitcast(ShufVT, DestVec);
23028 SDValue Shuf = DAG.getVectorShuffle(ShufVT, DL, DestVecBC, PaddedSubV, Mask);
23029 AddToWorklist(PaddedSubV.getNode());
23030 AddToWorklist(DestVecBC.getNode());
23031 AddToWorklist(Shuf.getNode());
23032 return DAG.getBitcast(VT, Shuf);
23033 }
23034
23035 // Combine insert(shuffle(load, <u,0,1,2>), load, 0) into a single load if
23036 // possible and the new load will be quick. We use more loads but less shuffles
23037 // and inserts.
combineInsertEltToLoad(SDNode * N,unsigned InsIndex)23038 SDValue DAGCombiner::combineInsertEltToLoad(SDNode *N, unsigned InsIndex) {
23039 EVT VT = N->getValueType(0);
23040
23041 // InsIndex is expected to be the first of last lane.
23042 if (!VT.isFixedLengthVector() ||
23043 (InsIndex != 0 && InsIndex != VT.getVectorNumElements() - 1))
23044 return SDValue();
23045
23046 // Look for a shuffle with the mask u,0,1,2,3,4,5,6 or 1,2,3,4,5,6,7,u
23047 // depending on the InsIndex.
23048 auto *Shuffle = dyn_cast<ShuffleVectorSDNode>(N->getOperand(0));
23049 SDValue Scalar = N->getOperand(1);
23050 if (!Shuffle || !all_of(enumerate(Shuffle->getMask()), [&](auto P) {
23051 return InsIndex == P.index() || P.value() < 0 ||
23052 (InsIndex == 0 && P.value() == (int)P.index() - 1) ||
23053 (InsIndex == VT.getVectorNumElements() - 1 &&
23054 P.value() == (int)P.index() + 1);
23055 }))
23056 return SDValue();
23057
23058 // We optionally skip over an extend so long as both loads are extended in the
23059 // same way from the same type.
23060 unsigned Extend = 0;
23061 if (Scalar.getOpcode() == ISD::ZERO_EXTEND ||
23062 Scalar.getOpcode() == ISD::SIGN_EXTEND ||
23063 Scalar.getOpcode() == ISD::ANY_EXTEND) {
23064 Extend = Scalar.getOpcode();
23065 Scalar = Scalar.getOperand(0);
23066 }
23067
23068 auto *ScalarLoad = dyn_cast<LoadSDNode>(Scalar);
23069 if (!ScalarLoad)
23070 return SDValue();
23071
23072 SDValue Vec = Shuffle->getOperand(0);
23073 if (Extend) {
23074 if (Vec.getOpcode() != Extend)
23075 return SDValue();
23076 Vec = Vec.getOperand(0);
23077 }
23078 auto *VecLoad = dyn_cast<LoadSDNode>(Vec);
23079 if (!VecLoad || Vec.getValueType().getScalarType() != Scalar.getValueType())
23080 return SDValue();
23081
23082 int EltSize = ScalarLoad->getValueType(0).getScalarSizeInBits();
23083 if (EltSize == 0 || EltSize % 8 != 0 || !ScalarLoad->isSimple() ||
23084 !VecLoad->isSimple() || VecLoad->getExtensionType() != ISD::NON_EXTLOAD ||
23085 ScalarLoad->getExtensionType() != ISD::NON_EXTLOAD ||
23086 ScalarLoad->getAddressSpace() != VecLoad->getAddressSpace())
23087 return SDValue();
23088
23089 // Check that the offset between the pointers to produce a single continuous
23090 // load.
23091 if (InsIndex == 0) {
23092 if (!DAG.areNonVolatileConsecutiveLoads(ScalarLoad, VecLoad, EltSize / 8,
23093 -1))
23094 return SDValue();
23095 } else {
23096 if (!DAG.areNonVolatileConsecutiveLoads(
23097 VecLoad, ScalarLoad, VT.getVectorNumElements() * EltSize / 8, -1))
23098 return SDValue();
23099 }
23100
23101 // And that the new unaligned load will be fast.
23102 unsigned IsFast = 0;
23103 Align NewAlign = commonAlignment(VecLoad->getAlign(), EltSize / 8);
23104 if (!TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(),
23105 Vec.getValueType(), VecLoad->getAddressSpace(),
23106 NewAlign, VecLoad->getMemOperand()->getFlags(),
23107 &IsFast) ||
23108 !IsFast)
23109 return SDValue();
23110
23111 // Calculate the new Ptr and create the new load.
23112 SDLoc DL(N);
23113 SDValue Ptr = ScalarLoad->getBasePtr();
23114 if (InsIndex != 0)
23115 Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), VecLoad->getBasePtr(),
23116 DAG.getConstant(EltSize / 8, DL, Ptr.getValueType()));
23117 MachinePointerInfo PtrInfo =
23118 InsIndex == 0 ? ScalarLoad->getPointerInfo()
23119 : VecLoad->getPointerInfo().getWithOffset(EltSize / 8);
23120
23121 SDValue Load = DAG.getLoad(VecLoad->getValueType(0), DL,
23122 ScalarLoad->getChain(), Ptr, PtrInfo, NewAlign);
23123 DAG.makeEquivalentMemoryOrdering(ScalarLoad, Load.getValue(1));
23124 DAG.makeEquivalentMemoryOrdering(VecLoad, Load.getValue(1));
23125 return Extend ? DAG.getNode(Extend, DL, VT, Load) : Load;
23126 }
23127
visitINSERT_VECTOR_ELT(SDNode * N)23128 SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) {
23129 SDValue InVec = N->getOperand(0);
23130 SDValue InVal = N->getOperand(1);
23131 SDValue EltNo = N->getOperand(2);
23132 SDLoc DL(N);
23133
23134 EVT VT = InVec.getValueType();
23135 auto *IndexC = dyn_cast<ConstantSDNode>(EltNo);
23136
23137 // Insert into out-of-bounds element is undefined.
23138 if (IndexC && VT.isFixedLengthVector() &&
23139 IndexC->getZExtValue() >= VT.getVectorNumElements())
23140 return DAG.getUNDEF(VT);
23141
23142 // Remove redundant insertions:
23143 // (insert_vector_elt x (extract_vector_elt x idx) idx) -> x
23144 if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
23145 InVec == InVal.getOperand(0) && EltNo == InVal.getOperand(1))
23146 return InVec;
23147
23148 if (!IndexC) {
23149 // If this is variable insert to undef vector, it might be better to splat:
23150 // inselt undef, InVal, EltNo --> build_vector < InVal, InVal, ... >
23151 if (InVec.isUndef() && TLI.shouldSplatInsEltVarIndex(VT))
23152 return DAG.getSplat(VT, DL, InVal);
23153 return SDValue();
23154 }
23155
23156 if (VT.isScalableVector())
23157 return SDValue();
23158
23159 unsigned NumElts = VT.getVectorNumElements();
23160
23161 // We must know which element is being inserted for folds below here.
23162 unsigned Elt = IndexC->getZExtValue();
23163
23164 // Handle <1 x ???> vector insertion special cases.
23165 if (NumElts == 1) {
23166 // insert_vector_elt(x, extract_vector_elt(y, 0), 0) -> y
23167 if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
23168 InVal.getOperand(0).getValueType() == VT &&
23169 isNullConstant(InVal.getOperand(1)))
23170 return InVal.getOperand(0);
23171 }
23172
23173 // Canonicalize insert_vector_elt dag nodes.
23174 // Example:
23175 // (insert_vector_elt (insert_vector_elt A, Idx0), Idx1)
23176 // -> (insert_vector_elt (insert_vector_elt A, Idx1), Idx0)
23177 //
23178 // Do this only if the child insert_vector node has one use; also
23179 // do this only if indices are both constants and Idx1 < Idx0.
23180 if (InVec.getOpcode() == ISD::INSERT_VECTOR_ELT && InVec.hasOneUse()
23181 && isa<ConstantSDNode>(InVec.getOperand(2))) {
23182 unsigned OtherElt = InVec.getConstantOperandVal(2);
23183 if (Elt < OtherElt) {
23184 // Swap nodes.
23185 SDValue NewOp = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VT,
23186 InVec.getOperand(0), InVal, EltNo);
23187 AddToWorklist(NewOp.getNode());
23188 return DAG.getNode(ISD::INSERT_VECTOR_ELT, SDLoc(InVec.getNode()),
23189 VT, NewOp, InVec.getOperand(1), InVec.getOperand(2));
23190 }
23191 }
23192
23193 if (SDValue Shuf = mergeInsertEltWithShuffle(N, Elt))
23194 return Shuf;
23195
23196 if (SDValue Shuf = combineInsertEltToShuffle(N, Elt))
23197 return Shuf;
23198
23199 if (SDValue Shuf = combineInsertEltToLoad(N, Elt))
23200 return Shuf;
23201
23202 // Attempt to convert an insert_vector_elt chain into a legal build_vector.
23203 if (!LegalOperations || TLI.isOperationLegal(ISD::BUILD_VECTOR, VT)) {
23204 // vXi1 vector - we don't need to recurse.
23205 if (NumElts == 1)
23206 return DAG.getBuildVector(VT, DL, {InVal});
23207
23208 // If we haven't already collected the element, insert into the op list.
23209 EVT MaxEltVT = InVal.getValueType();
23210 auto AddBuildVectorOp = [&](SmallVectorImpl<SDValue> &Ops, SDValue Elt,
23211 unsigned Idx) {
23212 if (!Ops[Idx]) {
23213 Ops[Idx] = Elt;
23214 if (VT.isInteger()) {
23215 EVT EltVT = Elt.getValueType();
23216 MaxEltVT = MaxEltVT.bitsGE(EltVT) ? MaxEltVT : EltVT;
23217 }
23218 }
23219 };
23220
23221 // Ensure all the operands are the same value type, fill any missing
23222 // operands with UNDEF and create the BUILD_VECTOR.
23223 auto CanonicalizeBuildVector = [&](SmallVectorImpl<SDValue> &Ops,
23224 bool FreezeUndef = false) {
23225 assert(Ops.size() == NumElts && "Unexpected vector size");
23226 SDValue UndefOp = FreezeUndef ? DAG.getFreeze(DAG.getUNDEF(MaxEltVT))
23227 : DAG.getUNDEF(MaxEltVT);
23228 for (SDValue &Op : Ops) {
23229 if (Op)
23230 Op = VT.isInteger() ? DAG.getAnyExtOrTrunc(Op, DL, MaxEltVT) : Op;
23231 else
23232 Op = UndefOp;
23233 }
23234 return DAG.getBuildVector(VT, DL, Ops);
23235 };
23236
23237 SmallVector<SDValue, 8> Ops(NumElts, SDValue());
23238 Ops[Elt] = InVal;
23239
23240 // Recurse up a INSERT_VECTOR_ELT chain to build a BUILD_VECTOR.
23241 for (SDValue CurVec = InVec; CurVec;) {
23242 // UNDEF - build new BUILD_VECTOR from already inserted operands.
23243 if (CurVec.isUndef())
23244 return CanonicalizeBuildVector(Ops);
23245
23246 // FREEZE(UNDEF) - build new BUILD_VECTOR from already inserted operands.
23247 if (ISD::isFreezeUndef(CurVec.getNode()) && CurVec.hasOneUse())
23248 return CanonicalizeBuildVector(Ops, /*FreezeUndef=*/true);
23249
23250 // BUILD_VECTOR - insert unused operands and build new BUILD_VECTOR.
23251 if (CurVec.getOpcode() == ISD::BUILD_VECTOR && CurVec.hasOneUse()) {
23252 for (unsigned I = 0; I != NumElts; ++I)
23253 AddBuildVectorOp(Ops, CurVec.getOperand(I), I);
23254 return CanonicalizeBuildVector(Ops);
23255 }
23256
23257 // SCALAR_TO_VECTOR - insert unused scalar and build new BUILD_VECTOR.
23258 if (CurVec.getOpcode() == ISD::SCALAR_TO_VECTOR && CurVec.hasOneUse()) {
23259 AddBuildVectorOp(Ops, CurVec.getOperand(0), 0);
23260 return CanonicalizeBuildVector(Ops);
23261 }
23262
23263 // INSERT_VECTOR_ELT - insert operand and continue up the chain.
23264 if (CurVec.getOpcode() == ISD::INSERT_VECTOR_ELT && CurVec.hasOneUse())
23265 if (auto *CurIdx = dyn_cast<ConstantSDNode>(CurVec.getOperand(2)))
23266 if (CurIdx->getAPIntValue().ult(NumElts)) {
23267 unsigned Idx = CurIdx->getZExtValue();
23268 AddBuildVectorOp(Ops, CurVec.getOperand(1), Idx);
23269
23270 // Found entire BUILD_VECTOR.
23271 if (all_of(Ops, [](SDValue Op) { return !!Op; }))
23272 return CanonicalizeBuildVector(Ops);
23273
23274 CurVec = CurVec->getOperand(0);
23275 continue;
23276 }
23277
23278 // VECTOR_SHUFFLE - if all the operands match the shuffle's sources,
23279 // update the shuffle mask (and second operand if we started with unary
23280 // shuffle) and create a new legal shuffle.
23281 if (CurVec.getOpcode() == ISD::VECTOR_SHUFFLE && CurVec.hasOneUse()) {
23282 auto *SVN = cast<ShuffleVectorSDNode>(CurVec);
23283 SDValue LHS = SVN->getOperand(0);
23284 SDValue RHS = SVN->getOperand(1);
23285 SmallVector<int, 16> Mask(SVN->getMask());
23286 bool Merged = true;
23287 for (auto I : enumerate(Ops)) {
23288 SDValue &Op = I.value();
23289 if (Op) {
23290 SmallVector<int, 16> NewMask;
23291 if (!mergeEltWithShuffle(LHS, RHS, Mask, NewMask, Op, I.index())) {
23292 Merged = false;
23293 break;
23294 }
23295 Mask = std::move(NewMask);
23296 }
23297 }
23298 if (Merged)
23299 if (SDValue NewShuffle =
23300 TLI.buildLegalVectorShuffle(VT, DL, LHS, RHS, Mask, DAG))
23301 return NewShuffle;
23302 }
23303
23304 if (!LegalOperations) {
23305 bool IsNull = llvm::isNullConstant(InVal);
23306 // We can convert to AND/OR mask if all insertions are zero or -1
23307 // respectively.
23308 if ((IsNull || llvm::isAllOnesConstant(InVal)) &&
23309 all_of(Ops, [InVal](SDValue Op) { return !Op || Op == InVal; }) &&
23310 count_if(Ops, [InVal](SDValue Op) { return Op == InVal; }) >= 2) {
23311 SDValue Zero = DAG.getConstant(0, DL, MaxEltVT);
23312 SDValue AllOnes = DAG.getAllOnesConstant(DL, MaxEltVT);
23313 SmallVector<SDValue, 8> Mask(NumElts);
23314
23315 // Build the mask and return the corresponding DAG node.
23316 auto BuildMaskAndNode = [&](SDValue TrueVal, SDValue FalseVal,
23317 unsigned MaskOpcode) {
23318 for (unsigned I = 0; I != NumElts; ++I)
23319 Mask[I] = Ops[I] ? TrueVal : FalseVal;
23320 return DAG.getNode(MaskOpcode, DL, VT, CurVec,
23321 DAG.getBuildVector(VT, DL, Mask));
23322 };
23323
23324 // If all elements are zero, we can use AND with all ones.
23325 if (IsNull)
23326 return BuildMaskAndNode(Zero, AllOnes, ISD::AND);
23327
23328 // If all elements are -1, we can use OR with zero.
23329 return BuildMaskAndNode(AllOnes, Zero, ISD::OR);
23330 }
23331 }
23332
23333 // Failed to find a match in the chain - bail.
23334 break;
23335 }
23336
23337 // See if we can fill in the missing constant elements as zeros.
23338 // TODO: Should we do this for any constant?
23339 APInt DemandedZeroElts = APInt::getZero(NumElts);
23340 for (unsigned I = 0; I != NumElts; ++I)
23341 if (!Ops[I])
23342 DemandedZeroElts.setBit(I);
23343
23344 if (DAG.MaskedVectorIsZero(InVec, DemandedZeroElts)) {
23345 SDValue Zero = VT.isInteger() ? DAG.getConstant(0, DL, MaxEltVT)
23346 : DAG.getConstantFP(0, DL, MaxEltVT);
23347 for (unsigned I = 0; I != NumElts; ++I)
23348 if (!Ops[I])
23349 Ops[I] = Zero;
23350
23351 return CanonicalizeBuildVector(Ops);
23352 }
23353 }
23354
23355 return SDValue();
23356 }
23357
23358 /// Transform a vector binary operation into a scalar binary operation by moving
23359 /// the math/logic after an extract element of a vector.
scalarizeExtractedBinOp(SDNode * ExtElt,SelectionDAG & DAG,const SDLoc & DL,bool LegalTypes)23360 static SDValue scalarizeExtractedBinOp(SDNode *ExtElt, SelectionDAG &DAG,
23361 const SDLoc &DL, bool LegalTypes) {
23362 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
23363 SDValue Vec = ExtElt->getOperand(0);
23364 SDValue Index = ExtElt->getOperand(1);
23365 auto *IndexC = dyn_cast<ConstantSDNode>(Index);
23366 unsigned Opc = Vec.getOpcode();
23367 if (!IndexC || !Vec.hasOneUse() || (!TLI.isBinOp(Opc) && Opc != ISD::SETCC) ||
23368 Vec->getNumValues() != 1)
23369 return SDValue();
23370
23371 // Targets may want to avoid this to prevent an expensive register transfer.
23372 if (!TLI.shouldScalarizeBinop(Vec))
23373 return SDValue();
23374
23375 EVT ResVT = ExtElt->getValueType(0);
23376 if (Opc == ISD::SETCC &&
23377 (ResVT != Vec.getValueType().getVectorElementType() || LegalTypes))
23378 return SDValue();
23379
23380 // Extracting an element of a vector constant is constant-folded, so this
23381 // transform is just replacing a vector op with a scalar op while moving the
23382 // extract.
23383 SDValue Op0 = Vec.getOperand(0);
23384 SDValue Op1 = Vec.getOperand(1);
23385 APInt SplatVal;
23386 if (!isAnyConstantBuildVector(Op0, true) &&
23387 !ISD::isConstantSplatVector(Op0.getNode(), SplatVal) &&
23388 !isAnyConstantBuildVector(Op1, true) &&
23389 !ISD::isConstantSplatVector(Op1.getNode(), SplatVal))
23390 return SDValue();
23391
23392 // extractelt (op X, C), IndexC --> op (extractelt X, IndexC), C'
23393 // extractelt (op C, X), IndexC --> op C', (extractelt X, IndexC)
23394 if (Opc == ISD::SETCC) {
23395 EVT OpVT = Op0.getValueType().getVectorElementType();
23396 Op0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, OpVT, Op0, Index);
23397 Op1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, OpVT, Op1, Index);
23398 SDValue NewVal = DAG.getSetCC(
23399 DL, ResVT, Op0, Op1, cast<CondCodeSDNode>(Vec->getOperand(2))->get());
23400 // We may need to sign- or zero-extend the result to match the same
23401 // behaviour as the vector version of SETCC.
23402 unsigned VecBoolContents = TLI.getBooleanContents(Vec.getValueType());
23403 if (ResVT != MVT::i1 &&
23404 VecBoolContents != TargetLowering::UndefinedBooleanContent &&
23405 VecBoolContents != TLI.getBooleanContents(ResVT)) {
23406 if (VecBoolContents == TargetLowering::ZeroOrNegativeOneBooleanContent)
23407 NewVal = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, ResVT, NewVal,
23408 DAG.getValueType(MVT::i1));
23409 else
23410 NewVal = DAG.getZeroExtendInReg(NewVal, DL, MVT::i1);
23411 }
23412 return NewVal;
23413 }
23414 Op0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT, Op0, Index);
23415 Op1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT, Op1, Index);
23416 return DAG.getNode(Opc, DL, ResVT, Op0, Op1);
23417 }
23418
23419 // Given a ISD::EXTRACT_VECTOR_ELT, which is a glorified bit sequence extract,
23420 // recursively analyse all of it's users. and try to model themselves as
23421 // bit sequence extractions. If all of them agree on the new, narrower element
23422 // type, and all of them can be modelled as ISD::EXTRACT_VECTOR_ELT's of that
23423 // new element type, do so now.
23424 // This is mainly useful to recover from legalization that scalarized
23425 // the vector as wide elements, but tries to rebuild it with narrower elements.
23426 //
23427 // Some more nodes could be modelled if that helps cover interesting patterns.
refineExtractVectorEltIntoMultipleNarrowExtractVectorElts(SDNode * N)23428 bool DAGCombiner::refineExtractVectorEltIntoMultipleNarrowExtractVectorElts(
23429 SDNode *N) {
23430 // We perform this optimization post type-legalization because
23431 // the type-legalizer often scalarizes integer-promoted vectors.
23432 // Performing this optimization before may cause legalizaton cycles.
23433 if (Level != AfterLegalizeVectorOps && Level != AfterLegalizeTypes)
23434 return false;
23435
23436 // TODO: Add support for big-endian.
23437 if (DAG.getDataLayout().isBigEndian())
23438 return false;
23439
23440 SDValue VecOp = N->getOperand(0);
23441 EVT VecVT = VecOp.getValueType();
23442 assert(!VecVT.isScalableVector() && "Only for fixed vectors.");
23443
23444 // We must start with a constant extraction index.
23445 auto *IndexC = dyn_cast<ConstantSDNode>(N->getOperand(1));
23446 if (!IndexC)
23447 return false;
23448
23449 assert(IndexC->getZExtValue() < VecVT.getVectorNumElements() &&
23450 "Original ISD::EXTRACT_VECTOR_ELT is undefinend?");
23451
23452 // TODO: deal with the case of implicit anyext of the extraction.
23453 unsigned VecEltBitWidth = VecVT.getScalarSizeInBits();
23454 EVT ScalarVT = N->getValueType(0);
23455 if (VecVT.getScalarType() != ScalarVT)
23456 return false;
23457
23458 // TODO: deal with the cases other than everything being integer-typed.
23459 if (!ScalarVT.isScalarInteger())
23460 return false;
23461
23462 struct Entry {
23463 SDNode *Producer;
23464
23465 // Which bits of VecOp does it contain?
23466 unsigned BitPos;
23467 int NumBits;
23468 // NOTE: the actual width of \p Producer may be wider than NumBits!
23469
23470 Entry(Entry &&) = default;
23471 Entry(SDNode *Producer_, unsigned BitPos_, int NumBits_)
23472 : Producer(Producer_), BitPos(BitPos_), NumBits(NumBits_) {}
23473
23474 Entry() = delete;
23475 Entry(const Entry &) = delete;
23476 Entry &operator=(const Entry &) = delete;
23477 Entry &operator=(Entry &&) = delete;
23478 };
23479 SmallVector<Entry, 32> Worklist;
23480 SmallVector<Entry, 32> Leafs;
23481
23482 // We start at the "root" ISD::EXTRACT_VECTOR_ELT.
23483 Worklist.emplace_back(N, /*BitPos=*/VecEltBitWidth * IndexC->getZExtValue(),
23484 /*NumBits=*/VecEltBitWidth);
23485
23486 while (!Worklist.empty()) {
23487 Entry E = Worklist.pop_back_val();
23488 // Does the node not even use any of the VecOp bits?
23489 if (!(E.NumBits > 0 && E.BitPos < VecVT.getSizeInBits() &&
23490 E.BitPos + E.NumBits <= VecVT.getSizeInBits()))
23491 return false; // Let's allow the other combines clean this up first.
23492 // Did we fail to model any of the users of the Producer?
23493 bool ProducerIsLeaf = false;
23494 // Look at each user of this Producer.
23495 for (SDNode *User : E.Producer->users()) {
23496 switch (User->getOpcode()) {
23497 // TODO: support ISD::BITCAST
23498 // TODO: support ISD::ANY_EXTEND
23499 // TODO: support ISD::ZERO_EXTEND
23500 // TODO: support ISD::SIGN_EXTEND
23501 case ISD::TRUNCATE:
23502 // Truncation simply means we keep position, but extract less bits.
23503 Worklist.emplace_back(User, E.BitPos,
23504 /*NumBits=*/User->getValueSizeInBits(0));
23505 break;
23506 // TODO: support ISD::SRA
23507 // TODO: support ISD::SHL
23508 case ISD::SRL:
23509 // We should be shifting the Producer by a constant amount.
23510 if (auto *ShAmtC = dyn_cast<ConstantSDNode>(User->getOperand(1));
23511 User->getOperand(0).getNode() == E.Producer && ShAmtC) {
23512 // Logical right-shift means that we start extraction later,
23513 // but stop it at the same position we did previously.
23514 unsigned ShAmt = ShAmtC->getZExtValue();
23515 Worklist.emplace_back(User, E.BitPos + ShAmt, E.NumBits - ShAmt);
23516 break;
23517 }
23518 [[fallthrough]];
23519 default:
23520 // We can not model this user of the Producer.
23521 // Which means the current Producer will be a ISD::EXTRACT_VECTOR_ELT.
23522 ProducerIsLeaf = true;
23523 // Profitability check: all users that we can not model
23524 // must be ISD::BUILD_VECTOR's.
23525 if (User->getOpcode() != ISD::BUILD_VECTOR)
23526 return false;
23527 break;
23528 }
23529 }
23530 if (ProducerIsLeaf)
23531 Leafs.emplace_back(std::move(E));
23532 }
23533
23534 unsigned NewVecEltBitWidth = Leafs.front().NumBits;
23535
23536 // If we are still at the same element granularity, give up,
23537 if (NewVecEltBitWidth == VecEltBitWidth)
23538 return false;
23539
23540 // The vector width must be a multiple of the new element width.
23541 if (VecVT.getSizeInBits() % NewVecEltBitWidth != 0)
23542 return false;
23543
23544 // All leafs must agree on the new element width.
23545 // All leafs must not expect any "padding" bits ontop of that width.
23546 // All leafs must start extraction from multiple of that width.
23547 if (!all_of(Leafs, [NewVecEltBitWidth](const Entry &E) {
23548 return (unsigned)E.NumBits == NewVecEltBitWidth &&
23549 E.Producer->getValueSizeInBits(0) == NewVecEltBitWidth &&
23550 E.BitPos % NewVecEltBitWidth == 0;
23551 }))
23552 return false;
23553
23554 EVT NewScalarVT = EVT::getIntegerVT(*DAG.getContext(), NewVecEltBitWidth);
23555 EVT NewVecVT = EVT::getVectorVT(*DAG.getContext(), NewScalarVT,
23556 VecVT.getSizeInBits() / NewVecEltBitWidth);
23557
23558 if (LegalTypes &&
23559 !(TLI.isTypeLegal(NewScalarVT) && TLI.isTypeLegal(NewVecVT)))
23560 return false;
23561
23562 if (LegalOperations &&
23563 !(TLI.isOperationLegalOrCustom(ISD::BITCAST, NewVecVT) &&
23564 TLI.isOperationLegalOrCustom(ISD::EXTRACT_VECTOR_ELT, NewVecVT)))
23565 return false;
23566
23567 SDValue NewVecOp = DAG.getBitcast(NewVecVT, VecOp);
23568 for (const Entry &E : Leafs) {
23569 SDLoc DL(E.Producer);
23570 unsigned NewIndex = E.BitPos / NewVecEltBitWidth;
23571 assert(NewIndex < NewVecVT.getVectorNumElements() &&
23572 "Creating out-of-bounds ISD::EXTRACT_VECTOR_ELT?");
23573 SDValue V = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, NewScalarVT, NewVecOp,
23574 DAG.getVectorIdxConstant(NewIndex, DL));
23575 CombineTo(E.Producer, V);
23576 }
23577
23578 return true;
23579 }
23580
visitEXTRACT_VECTOR_ELT(SDNode * N)23581 SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
23582 SDValue VecOp = N->getOperand(0);
23583 SDValue Index = N->getOperand(1);
23584 EVT ScalarVT = N->getValueType(0);
23585 EVT VecVT = VecOp.getValueType();
23586 if (VecOp.isUndef())
23587 return DAG.getUNDEF(ScalarVT);
23588
23589 // extract_vector_elt (insert_vector_elt vec, val, idx), idx) -> val
23590 //
23591 // This only really matters if the index is non-constant since other combines
23592 // on the constant elements already work.
23593 SDLoc DL(N);
23594 if (VecOp.getOpcode() == ISD::INSERT_VECTOR_ELT &&
23595 Index == VecOp.getOperand(2)) {
23596 SDValue Elt = VecOp.getOperand(1);
23597 AddUsersToWorklist(VecOp.getNode());
23598 return VecVT.isInteger() ? DAG.getAnyExtOrTrunc(Elt, DL, ScalarVT) : Elt;
23599 }
23600
23601 // (vextract (scalar_to_vector val, 0) -> val
23602 if (VecOp.getOpcode() == ISD::SCALAR_TO_VECTOR) {
23603 // Only 0'th element of SCALAR_TO_VECTOR is defined.
23604 if (DAG.isKnownNeverZero(Index))
23605 return DAG.getUNDEF(ScalarVT);
23606
23607 // Check if the result type doesn't match the inserted element type.
23608 // The inserted element and extracted element may have mismatched bitwidth.
23609 // As a result, EXTRACT_VECTOR_ELT may extend or truncate the extracted vector.
23610 SDValue InOp = VecOp.getOperand(0);
23611 if (InOp.getValueType() != ScalarVT) {
23612 assert(InOp.getValueType().isInteger() && ScalarVT.isInteger());
23613 if (InOp.getValueType().bitsGT(ScalarVT))
23614 return DAG.getNode(ISD::TRUNCATE, DL, ScalarVT, InOp);
23615 return DAG.getNode(ISD::ANY_EXTEND, DL, ScalarVT, InOp);
23616 }
23617 return InOp;
23618 }
23619
23620 // extract_vector_elt of out-of-bounds element -> UNDEF
23621 auto *IndexC = dyn_cast<ConstantSDNode>(Index);
23622 if (IndexC && VecVT.isFixedLengthVector() &&
23623 IndexC->getAPIntValue().uge(VecVT.getVectorNumElements()))
23624 return DAG.getUNDEF(ScalarVT);
23625
23626 // extract_vector_elt (build_vector x, y), 1 -> y
23627 if (((IndexC && VecOp.getOpcode() == ISD::BUILD_VECTOR) ||
23628 VecOp.getOpcode() == ISD::SPLAT_VECTOR) &&
23629 TLI.isTypeLegal(VecVT)) {
23630 assert((VecOp.getOpcode() != ISD::BUILD_VECTOR ||
23631 VecVT.isFixedLengthVector()) &&
23632 "BUILD_VECTOR used for scalable vectors");
23633 unsigned IndexVal =
23634 VecOp.getOpcode() == ISD::BUILD_VECTOR ? IndexC->getZExtValue() : 0;
23635 SDValue Elt = VecOp.getOperand(IndexVal);
23636 EVT InEltVT = Elt.getValueType();
23637
23638 if (VecOp.hasOneUse() || TLI.aggressivelyPreferBuildVectorSources(VecVT) ||
23639 isNullConstant(Elt)) {
23640 // Sometimes build_vector's scalar input types do not match result type.
23641 if (ScalarVT == InEltVT)
23642 return Elt;
23643
23644 // TODO: It may be useful to truncate if free if the build_vector
23645 // implicitly converts.
23646 }
23647 }
23648
23649 if (SDValue BO = scalarizeExtractedBinOp(N, DAG, DL, LegalTypes))
23650 return BO;
23651
23652 if (VecVT.isScalableVector())
23653 return SDValue();
23654
23655 // All the code from this point onwards assumes fixed width vectors, but it's
23656 // possible that some of the combinations could be made to work for scalable
23657 // vectors too.
23658 unsigned NumElts = VecVT.getVectorNumElements();
23659 unsigned VecEltBitWidth = VecVT.getScalarSizeInBits();
23660
23661 // See if the extracted element is constant, in which case fold it if its
23662 // a legal fp immediate.
23663 if (IndexC && ScalarVT.isFloatingPoint()) {
23664 APInt EltMask = APInt::getOneBitSet(NumElts, IndexC->getZExtValue());
23665 KnownBits KnownElt = DAG.computeKnownBits(VecOp, EltMask);
23666 if (KnownElt.isConstant()) {
23667 APFloat CstFP =
23668 APFloat(ScalarVT.getFltSemantics(), KnownElt.getConstant());
23669 if (TLI.isFPImmLegal(CstFP, ScalarVT))
23670 return DAG.getConstantFP(CstFP, DL, ScalarVT);
23671 }
23672 }
23673
23674 // TODO: These transforms should not require the 'hasOneUse' restriction, but
23675 // there are regressions on multiple targets without it. We can end up with a
23676 // mess of scalar and vector code if we reduce only part of the DAG to scalar.
23677 if (IndexC && VecOp.getOpcode() == ISD::BITCAST && VecVT.isInteger() &&
23678 VecOp.hasOneUse()) {
23679 // The vector index of the LSBs of the source depend on the endian-ness.
23680 bool IsLE = DAG.getDataLayout().isLittleEndian();
23681 unsigned ExtractIndex = IndexC->getZExtValue();
23682 // extract_elt (v2i32 (bitcast i64:x)), BCTruncElt -> i32 (trunc i64:x)
23683 unsigned BCTruncElt = IsLE ? 0 : NumElts - 1;
23684 SDValue BCSrc = VecOp.getOperand(0);
23685 if (ExtractIndex == BCTruncElt && BCSrc.getValueType().isScalarInteger())
23686 return DAG.getAnyExtOrTrunc(BCSrc, DL, ScalarVT);
23687
23688 // TODO: Add support for SCALAR_TO_VECTOR implicit truncation.
23689 if (LegalTypes && BCSrc.getValueType().isInteger() &&
23690 BCSrc.getOpcode() == ISD::SCALAR_TO_VECTOR &&
23691 BCSrc.getScalarValueSizeInBits() ==
23692 BCSrc.getOperand(0).getScalarValueSizeInBits()) {
23693 // ext_elt (bitcast (scalar_to_vec i64 X to v2i64) to v4i32), TruncElt -->
23694 // trunc i64 X to i32
23695 SDValue X = BCSrc.getOperand(0);
23696 EVT XVT = X.getValueType();
23697 assert(XVT.isScalarInteger() && ScalarVT.isScalarInteger() &&
23698 "Extract element and scalar to vector can't change element type "
23699 "from FP to integer.");
23700 unsigned XBitWidth = X.getValueSizeInBits();
23701 unsigned Scale = XBitWidth / VecEltBitWidth;
23702 BCTruncElt = IsLE ? 0 : Scale - 1;
23703
23704 // An extract element return value type can be wider than its vector
23705 // operand element type. In that case, the high bits are undefined, so
23706 // it's possible that we may need to extend rather than truncate.
23707 if (ExtractIndex < Scale && XBitWidth > VecEltBitWidth) {
23708 assert(XBitWidth % VecEltBitWidth == 0 &&
23709 "Scalar bitwidth must be a multiple of vector element bitwidth");
23710
23711 if (ExtractIndex != BCTruncElt) {
23712 unsigned ShiftIndex =
23713 IsLE ? ExtractIndex : (Scale - 1) - ExtractIndex;
23714 X = DAG.getNode(
23715 ISD::SRL, DL, XVT, X,
23716 DAG.getShiftAmountConstant(ShiftIndex * VecEltBitWidth, XVT, DL));
23717 }
23718
23719 return DAG.getAnyExtOrTrunc(X, DL, ScalarVT);
23720 }
23721 }
23722 }
23723
23724 // Transform: (EXTRACT_VECTOR_ELT( VECTOR_SHUFFLE )) -> EXTRACT_VECTOR_ELT.
23725 // We only perform this optimization before the op legalization phase because
23726 // we may introduce new vector instructions which are not backed by TD
23727 // patterns. For example on AVX, extracting elements from a wide vector
23728 // without using extract_subvector. However, if we can find an underlying
23729 // scalar value, then we can always use that.
23730 if (IndexC && VecOp.getOpcode() == ISD::VECTOR_SHUFFLE) {
23731 auto *Shuf = cast<ShuffleVectorSDNode>(VecOp);
23732 // Find the new index to extract from.
23733 int OrigElt = Shuf->getMaskElt(IndexC->getZExtValue());
23734
23735 // Extracting an undef index is undef.
23736 if (OrigElt == -1)
23737 return DAG.getUNDEF(ScalarVT);
23738
23739 // Select the right vector half to extract from.
23740 SDValue SVInVec;
23741 if (OrigElt < (int)NumElts) {
23742 SVInVec = VecOp.getOperand(0);
23743 } else {
23744 SVInVec = VecOp.getOperand(1);
23745 OrigElt -= NumElts;
23746 }
23747
23748 if (SVInVec.getOpcode() == ISD::BUILD_VECTOR) {
23749 // TODO: Check if shuffle mask is legal?
23750 if (LegalOperations && TLI.isOperationLegal(ISD::VECTOR_SHUFFLE, VecVT) &&
23751 !VecOp.hasOneUse())
23752 return SDValue();
23753
23754 SDValue InOp = SVInVec.getOperand(OrigElt);
23755 if (InOp.getValueType() != ScalarVT) {
23756 assert(InOp.getValueType().isInteger() && ScalarVT.isInteger());
23757 InOp = DAG.getSExtOrTrunc(InOp, DL, ScalarVT);
23758 }
23759
23760 return InOp;
23761 }
23762
23763 // FIXME: We should handle recursing on other vector shuffles and
23764 // scalar_to_vector here as well.
23765
23766 if (!LegalOperations ||
23767 // FIXME: Should really be just isOperationLegalOrCustom.
23768 TLI.isOperationLegal(ISD::EXTRACT_VECTOR_ELT, VecVT) ||
23769 TLI.isOperationExpand(ISD::VECTOR_SHUFFLE, VecVT)) {
23770 return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT, SVInVec,
23771 DAG.getVectorIdxConstant(OrigElt, DL));
23772 }
23773 }
23774
23775 // If only EXTRACT_VECTOR_ELT nodes use the source vector we can
23776 // simplify it based on the (valid) extraction indices.
23777 if (llvm::all_of(VecOp->users(), [&](SDNode *Use) {
23778 return Use->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
23779 Use->getOperand(0) == VecOp &&
23780 isa<ConstantSDNode>(Use->getOperand(1));
23781 })) {
23782 APInt DemandedElts = APInt::getZero(NumElts);
23783 for (SDNode *User : VecOp->users()) {
23784 auto *CstElt = cast<ConstantSDNode>(User->getOperand(1));
23785 if (CstElt->getAPIntValue().ult(NumElts))
23786 DemandedElts.setBit(CstElt->getZExtValue());
23787 }
23788 if (SimplifyDemandedVectorElts(VecOp, DemandedElts, true)) {
23789 // We simplified the vector operand of this extract element. If this
23790 // extract is not dead, visit it again so it is folded properly.
23791 if (N->getOpcode() != ISD::DELETED_NODE)
23792 AddToWorklist(N);
23793 return SDValue(N, 0);
23794 }
23795 APInt DemandedBits = APInt::getAllOnes(VecEltBitWidth);
23796 if (SimplifyDemandedBits(VecOp, DemandedBits, DemandedElts, true)) {
23797 // We simplified the vector operand of this extract element. If this
23798 // extract is not dead, visit it again so it is folded properly.
23799 if (N->getOpcode() != ISD::DELETED_NODE)
23800 AddToWorklist(N);
23801 return SDValue(N, 0);
23802 }
23803 }
23804
23805 if (refineExtractVectorEltIntoMultipleNarrowExtractVectorElts(N))
23806 return SDValue(N, 0);
23807
23808 // Everything under here is trying to match an extract of a loaded value.
23809 // If the result of load has to be truncated, then it's not necessarily
23810 // profitable.
23811 bool BCNumEltsChanged = false;
23812 EVT ExtVT = VecVT.getVectorElementType();
23813 EVT LVT = ExtVT;
23814 if (ScalarVT.bitsLT(LVT) && !TLI.isTruncateFree(LVT, ScalarVT))
23815 return SDValue();
23816
23817 if (VecOp.getOpcode() == ISD::BITCAST) {
23818 // Don't duplicate a load with other uses.
23819 if (!VecOp.hasOneUse())
23820 return SDValue();
23821
23822 EVT BCVT = VecOp.getOperand(0).getValueType();
23823 if (!BCVT.isVector() || ExtVT.bitsGT(BCVT.getVectorElementType()))
23824 return SDValue();
23825 if (NumElts != BCVT.getVectorNumElements())
23826 BCNumEltsChanged = true;
23827 VecOp = VecOp.getOperand(0);
23828 ExtVT = BCVT.getVectorElementType();
23829 }
23830
23831 // extract (vector load $addr), i --> load $addr + i * size
23832 if (!LegalOperations && !IndexC && VecOp.hasOneUse() &&
23833 ISD::isNormalLoad(VecOp.getNode()) &&
23834 !Index->hasPredecessor(VecOp.getNode())) {
23835 auto *VecLoad = dyn_cast<LoadSDNode>(VecOp);
23836 if (VecLoad && VecLoad->isSimple()) {
23837 if (SDValue Scalarized = TLI.scalarizeExtractedVectorLoad(
23838 ScalarVT, SDLoc(N), VecVT, Index, VecLoad, DAG)) {
23839 ++OpsNarrowed;
23840 return Scalarized;
23841 }
23842 }
23843 }
23844
23845 // Perform only after legalization to ensure build_vector / vector_shuffle
23846 // optimizations have already been done.
23847 if (!LegalOperations || !IndexC)
23848 return SDValue();
23849
23850 // (vextract (v4f32 load $addr), c) -> (f32 load $addr+c*size)
23851 // (vextract (v4f32 s2v (f32 load $addr)), c) -> (f32 load $addr+c*size)
23852 // (vextract (v4f32 shuffle (load $addr), <1,u,u,u>), 0) -> (f32 load $addr)
23853 int Elt = IndexC->getZExtValue();
23854 LoadSDNode *LN0 = nullptr;
23855 if (ISD::isNormalLoad(VecOp.getNode())) {
23856 LN0 = cast<LoadSDNode>(VecOp);
23857 } else if (VecOp.getOpcode() == ISD::SCALAR_TO_VECTOR &&
23858 VecOp.getOperand(0).getValueType() == ExtVT &&
23859 ISD::isNormalLoad(VecOp.getOperand(0).getNode())) {
23860 // Don't duplicate a load with other uses.
23861 if (!VecOp.hasOneUse())
23862 return SDValue();
23863
23864 LN0 = cast<LoadSDNode>(VecOp.getOperand(0));
23865 }
23866 if (auto *Shuf = dyn_cast<ShuffleVectorSDNode>(VecOp)) {
23867 // (vextract (vector_shuffle (load $addr), v2, <1, u, u, u>), 1)
23868 // =>
23869 // (load $addr+1*size)
23870
23871 // Don't duplicate a load with other uses.
23872 if (!VecOp.hasOneUse())
23873 return SDValue();
23874
23875 // If the bit convert changed the number of elements, it is unsafe
23876 // to examine the mask.
23877 if (BCNumEltsChanged)
23878 return SDValue();
23879
23880 // Select the input vector, guarding against out of range extract vector.
23881 int Idx = (Elt > (int)NumElts) ? -1 : Shuf->getMaskElt(Elt);
23882 VecOp = (Idx < (int)NumElts) ? VecOp.getOperand(0) : VecOp.getOperand(1);
23883
23884 if (VecOp.getOpcode() == ISD::BITCAST) {
23885 // Don't duplicate a load with other uses.
23886 if (!VecOp.hasOneUse())
23887 return SDValue();
23888
23889 VecOp = VecOp.getOperand(0);
23890 }
23891 if (ISD::isNormalLoad(VecOp.getNode())) {
23892 LN0 = cast<LoadSDNode>(VecOp);
23893 Elt = (Idx < (int)NumElts) ? Idx : Idx - (int)NumElts;
23894 Index = DAG.getConstant(Elt, DL, Index.getValueType());
23895 }
23896 } else if (VecOp.getOpcode() == ISD::CONCAT_VECTORS && !BCNumEltsChanged &&
23897 VecVT.getVectorElementType() == ScalarVT &&
23898 (!LegalTypes ||
23899 TLI.isTypeLegal(
23900 VecOp.getOperand(0).getValueType().getVectorElementType()))) {
23901 // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 0
23902 // -> extract_vector_elt a, 0
23903 // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 1
23904 // -> extract_vector_elt a, 1
23905 // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 2
23906 // -> extract_vector_elt b, 0
23907 // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 3
23908 // -> extract_vector_elt b, 1
23909 EVT ConcatVT = VecOp.getOperand(0).getValueType();
23910 unsigned ConcatNumElts = ConcatVT.getVectorNumElements();
23911 SDValue NewIdx = DAG.getConstant(Elt % ConcatNumElts, DL,
23912 Index.getValueType());
23913
23914 SDValue ConcatOp = VecOp.getOperand(Elt / ConcatNumElts);
23915 SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL,
23916 ConcatVT.getVectorElementType(),
23917 ConcatOp, NewIdx);
23918 return DAG.getNode(ISD::BITCAST, DL, ScalarVT, Elt);
23919 }
23920
23921 // Make sure we found a non-volatile load and the extractelement is
23922 // the only use.
23923 if (!LN0 || !LN0->hasNUsesOfValue(1,0) || !LN0->isSimple())
23924 return SDValue();
23925
23926 // If Idx was -1 above, Elt is going to be -1, so just return undef.
23927 if (Elt == -1)
23928 return DAG.getUNDEF(LVT);
23929
23930 if (SDValue Scalarized =
23931 TLI.scalarizeExtractedVectorLoad(LVT, DL, VecVT, Index, LN0, DAG)) {
23932 ++OpsNarrowed;
23933 return Scalarized;
23934 }
23935
23936 return SDValue();
23937 }
23938
23939 // Simplify (build_vec (ext )) to (bitcast (build_vec ))
reduceBuildVecExtToExtBuildVec(SDNode * N)23940 SDValue DAGCombiner::reduceBuildVecExtToExtBuildVec(SDNode *N) {
23941 // We perform this optimization post type-legalization because
23942 // the type-legalizer often scalarizes integer-promoted vectors.
23943 // Performing this optimization before may create bit-casts which
23944 // will be type-legalized to complex code sequences.
23945 // We perform this optimization only before the operation legalizer because we
23946 // may introduce illegal operations.
23947 if (Level != AfterLegalizeVectorOps && Level != AfterLegalizeTypes)
23948 return SDValue();
23949
23950 unsigned NumInScalars = N->getNumOperands();
23951 SDLoc DL(N);
23952 EVT VT = N->getValueType(0);
23953
23954 // Check to see if this is a BUILD_VECTOR of a bunch of values
23955 // which come from any_extend or zero_extend nodes. If so, we can create
23956 // a new BUILD_VECTOR using bit-casts which may enable other BUILD_VECTOR
23957 // optimizations. We do not handle sign-extend because we can't fill the sign
23958 // using shuffles.
23959 EVT SourceType = MVT::Other;
23960 bool AllAnyExt = true;
23961
23962 for (unsigned i = 0; i != NumInScalars; ++i) {
23963 SDValue In = N->getOperand(i);
23964 // Ignore undef inputs.
23965 if (In.isUndef()) continue;
23966
23967 bool AnyExt = In.getOpcode() == ISD::ANY_EXTEND;
23968 bool ZeroExt = In.getOpcode() == ISD::ZERO_EXTEND;
23969
23970 // Abort if the element is not an extension.
23971 if (!ZeroExt && !AnyExt) {
23972 SourceType = MVT::Other;
23973 break;
23974 }
23975
23976 // The input is a ZeroExt or AnyExt. Check the original type.
23977 EVT InTy = In.getOperand(0).getValueType();
23978
23979 // Check that all of the widened source types are the same.
23980 if (SourceType == MVT::Other)
23981 // First time.
23982 SourceType = InTy;
23983 else if (InTy != SourceType) {
23984 // Multiple income types. Abort.
23985 SourceType = MVT::Other;
23986 break;
23987 }
23988
23989 // Check if all of the extends are ANY_EXTENDs.
23990 AllAnyExt &= AnyExt;
23991 }
23992
23993 // In order to have valid types, all of the inputs must be extended from the
23994 // same source type and all of the inputs must be any or zero extend.
23995 // Scalar sizes must be a power of two.
23996 EVT OutScalarTy = VT.getScalarType();
23997 bool ValidTypes =
23998 SourceType != MVT::Other &&
23999 llvm::has_single_bit<uint32_t>(OutScalarTy.getSizeInBits()) &&
24000 llvm::has_single_bit<uint32_t>(SourceType.getSizeInBits());
24001
24002 // Create a new simpler BUILD_VECTOR sequence which other optimizations can
24003 // turn into a single shuffle instruction.
24004 if (!ValidTypes)
24005 return SDValue();
24006
24007 // If we already have a splat buildvector, then don't fold it if it means
24008 // introducing zeros.
24009 if (!AllAnyExt && DAG.isSplatValue(SDValue(N, 0), /*AllowUndefs*/ true))
24010 return SDValue();
24011
24012 bool isLE = DAG.getDataLayout().isLittleEndian();
24013 unsigned ElemRatio = OutScalarTy.getSizeInBits()/SourceType.getSizeInBits();
24014 assert(ElemRatio > 1 && "Invalid element size ratio");
24015 SDValue Filler = AllAnyExt ? DAG.getUNDEF(SourceType):
24016 DAG.getConstant(0, DL, SourceType);
24017
24018 unsigned NewBVElems = ElemRatio * VT.getVectorNumElements();
24019 SmallVector<SDValue, 8> Ops(NewBVElems, Filler);
24020
24021 // Populate the new build_vector
24022 for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i) {
24023 SDValue Cast = N->getOperand(i);
24024 assert((Cast.getOpcode() == ISD::ANY_EXTEND ||
24025 Cast.getOpcode() == ISD::ZERO_EXTEND ||
24026 Cast.isUndef()) && "Invalid cast opcode");
24027 SDValue In;
24028 if (Cast.isUndef())
24029 In = DAG.getUNDEF(SourceType);
24030 else
24031 In = Cast->getOperand(0);
24032 unsigned Index = isLE ? (i * ElemRatio) :
24033 (i * ElemRatio + (ElemRatio - 1));
24034
24035 assert(Index < Ops.size() && "Invalid index");
24036 Ops[Index] = In;
24037 }
24038
24039 // The type of the new BUILD_VECTOR node.
24040 EVT VecVT = EVT::getVectorVT(*DAG.getContext(), SourceType, NewBVElems);
24041 assert(VecVT.getSizeInBits() == VT.getSizeInBits() &&
24042 "Invalid vector size");
24043 // Check if the new vector type is legal.
24044 if (!isTypeLegal(VecVT) ||
24045 (!TLI.isOperationLegal(ISD::BUILD_VECTOR, VecVT) &&
24046 TLI.isOperationLegal(ISD::BUILD_VECTOR, VT)))
24047 return SDValue();
24048
24049 // Make the new BUILD_VECTOR.
24050 SDValue BV = DAG.getBuildVector(VecVT, DL, Ops);
24051
24052 // The new BUILD_VECTOR node has the potential to be further optimized.
24053 AddToWorklist(BV.getNode());
24054 // Bitcast to the desired type.
24055 return DAG.getBitcast(VT, BV);
24056 }
24057
24058 // Simplify (build_vec (trunc $1)
24059 // (trunc (srl $1 half-width))
24060 // (trunc (srl $1 (2 * half-width))))
24061 // to (bitcast $1)
reduceBuildVecTruncToBitCast(SDNode * N)24062 SDValue DAGCombiner::reduceBuildVecTruncToBitCast(SDNode *N) {
24063 assert(N->getOpcode() == ISD::BUILD_VECTOR && "Expected build vector");
24064
24065 EVT VT = N->getValueType(0);
24066
24067 // Don't run this before LegalizeTypes if VT is legal.
24068 // Targets may have other preferences.
24069 if (Level < AfterLegalizeTypes && TLI.isTypeLegal(VT))
24070 return SDValue();
24071
24072 // Only for little endian
24073 if (!DAG.getDataLayout().isLittleEndian())
24074 return SDValue();
24075
24076 EVT OutScalarTy = VT.getScalarType();
24077 uint64_t ScalarTypeBitsize = OutScalarTy.getSizeInBits();
24078
24079 // Only for power of two types to be sure that bitcast works well
24080 if (!isPowerOf2_64(ScalarTypeBitsize))
24081 return SDValue();
24082
24083 unsigned NumInScalars = N->getNumOperands();
24084
24085 // Look through bitcasts
24086 auto PeekThroughBitcast = [](SDValue Op) {
24087 if (Op.getOpcode() == ISD::BITCAST)
24088 return Op.getOperand(0);
24089 return Op;
24090 };
24091
24092 // The source value where all the parts are extracted.
24093 SDValue Src;
24094 for (unsigned i = 0; i != NumInScalars; ++i) {
24095 SDValue In = PeekThroughBitcast(N->getOperand(i));
24096 // Ignore undef inputs.
24097 if (In.isUndef()) continue;
24098
24099 if (In.getOpcode() != ISD::TRUNCATE)
24100 return SDValue();
24101
24102 In = PeekThroughBitcast(In.getOperand(0));
24103
24104 if (In.getOpcode() != ISD::SRL) {
24105 // For now only build_vec without shuffling, handle shifts here in the
24106 // future.
24107 if (i != 0)
24108 return SDValue();
24109
24110 Src = In;
24111 } else {
24112 // In is SRL
24113 SDValue part = PeekThroughBitcast(In.getOperand(0));
24114
24115 if (!Src) {
24116 Src = part;
24117 } else if (Src != part) {
24118 // Vector parts do not stem from the same variable
24119 return SDValue();
24120 }
24121
24122 SDValue ShiftAmtVal = In.getOperand(1);
24123 if (!isa<ConstantSDNode>(ShiftAmtVal))
24124 return SDValue();
24125
24126 uint64_t ShiftAmt = In.getConstantOperandVal(1);
24127
24128 // The extracted value is not extracted at the right position
24129 if (ShiftAmt != i * ScalarTypeBitsize)
24130 return SDValue();
24131 }
24132 }
24133
24134 // Only cast if the size is the same
24135 if (!Src || Src.getValueType().getSizeInBits() != VT.getSizeInBits())
24136 return SDValue();
24137
24138 return DAG.getBitcast(VT, Src);
24139 }
24140
createBuildVecShuffle(const SDLoc & DL,SDNode * N,ArrayRef<int> VectorMask,SDValue VecIn1,SDValue VecIn2,unsigned LeftIdx,bool DidSplitVec)24141 SDValue DAGCombiner::createBuildVecShuffle(const SDLoc &DL, SDNode *N,
24142 ArrayRef<int> VectorMask,
24143 SDValue VecIn1, SDValue VecIn2,
24144 unsigned LeftIdx, bool DidSplitVec) {
24145 EVT VT = N->getValueType(0);
24146 EVT InVT1 = VecIn1.getValueType();
24147 EVT InVT2 = VecIn2.getNode() ? VecIn2.getValueType() : InVT1;
24148
24149 unsigned NumElems = VT.getVectorNumElements();
24150 unsigned ShuffleNumElems = NumElems;
24151
24152 // If we artificially split a vector in two already, then the offsets in the
24153 // operands will all be based off of VecIn1, even those in VecIn2.
24154 unsigned Vec2Offset = DidSplitVec ? 0 : InVT1.getVectorNumElements();
24155
24156 uint64_t VTSize = VT.getFixedSizeInBits();
24157 uint64_t InVT1Size = InVT1.getFixedSizeInBits();
24158 uint64_t InVT2Size = InVT2.getFixedSizeInBits();
24159
24160 assert(InVT2Size <= InVT1Size &&
24161 "Inputs must be sorted to be in non-increasing vector size order.");
24162
24163 // We can't generate a shuffle node with mismatched input and output types.
24164 // Try to make the types match the type of the output.
24165 if (InVT1 != VT || InVT2 != VT) {
24166 if ((VTSize % InVT1Size == 0) && InVT1 == InVT2) {
24167 // If the output vector length is a multiple of both input lengths,
24168 // we can concatenate them and pad the rest with undefs.
24169 unsigned NumConcats = VTSize / InVT1Size;
24170 assert(NumConcats >= 2 && "Concat needs at least two inputs!");
24171 SmallVector<SDValue, 2> ConcatOps(NumConcats, DAG.getUNDEF(InVT1));
24172 ConcatOps[0] = VecIn1;
24173 ConcatOps[1] = VecIn2 ? VecIn2 : DAG.getUNDEF(InVT1);
24174 VecIn1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps);
24175 VecIn2 = SDValue();
24176 } else if (InVT1Size == VTSize * 2) {
24177 if (!TLI.isExtractSubvectorCheap(VT, InVT1, NumElems))
24178 return SDValue();
24179
24180 if (!VecIn2.getNode()) {
24181 // If we only have one input vector, and it's twice the size of the
24182 // output, split it in two.
24183 VecIn2 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, VecIn1,
24184 DAG.getVectorIdxConstant(NumElems, DL));
24185 VecIn1 = DAG.getExtractSubvector(DL, VT, VecIn1, 0);
24186 // Since we now have shorter input vectors, adjust the offset of the
24187 // second vector's start.
24188 Vec2Offset = NumElems;
24189 } else {
24190 assert(InVT2Size <= InVT1Size &&
24191 "Second input is not going to be larger than the first one.");
24192
24193 // VecIn1 is wider than the output, and we have another, possibly
24194 // smaller input. Pad the smaller input with undefs, shuffle at the
24195 // input vector width, and extract the output.
24196 // The shuffle type is different than VT, so check legality again.
24197 if (LegalOperations &&
24198 !TLI.isOperationLegal(ISD::VECTOR_SHUFFLE, InVT1))
24199 return SDValue();
24200
24201 // Legalizing INSERT_SUBVECTOR is tricky - you basically have to
24202 // lower it back into a BUILD_VECTOR. So if the inserted type is
24203 // illegal, don't even try.
24204 if (InVT1 != InVT2) {
24205 if (!TLI.isTypeLegal(InVT2))
24206 return SDValue();
24207 VecIn2 = DAG.getInsertSubvector(DL, DAG.getUNDEF(InVT1), VecIn2, 0);
24208 }
24209 ShuffleNumElems = NumElems * 2;
24210 }
24211 } else if (InVT2Size * 2 == VTSize && InVT1Size == VTSize) {
24212 SmallVector<SDValue, 2> ConcatOps(2, DAG.getUNDEF(InVT2));
24213 ConcatOps[0] = VecIn2;
24214 VecIn2 = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps);
24215 } else if (InVT1Size / VTSize > 1 && InVT1Size % VTSize == 0) {
24216 if (!TLI.isExtractSubvectorCheap(VT, InVT1, NumElems) ||
24217 !TLI.isTypeLegal(InVT1) || !TLI.isTypeLegal(InVT2))
24218 return SDValue();
24219 // If dest vector has less than two elements, then use shuffle and extract
24220 // from larger regs will cost even more.
24221 if (VT.getVectorNumElements() <= 2 || !VecIn2.getNode())
24222 return SDValue();
24223 assert(InVT2Size <= InVT1Size &&
24224 "Second input is not going to be larger than the first one.");
24225
24226 // VecIn1 is wider than the output, and we have another, possibly
24227 // smaller input. Pad the smaller input with undefs, shuffle at the
24228 // input vector width, and extract the output.
24229 // The shuffle type is different than VT, so check legality again.
24230 if (LegalOperations && !TLI.isOperationLegal(ISD::VECTOR_SHUFFLE, InVT1))
24231 return SDValue();
24232
24233 if (InVT1 != InVT2) {
24234 VecIn2 = DAG.getInsertSubvector(DL, DAG.getUNDEF(InVT1), VecIn2, 0);
24235 }
24236 ShuffleNumElems = InVT1Size / VTSize * NumElems;
24237 } else {
24238 // TODO: Support cases where the length mismatch isn't exactly by a
24239 // factor of 2.
24240 // TODO: Move this check upwards, so that if we have bad type
24241 // mismatches, we don't create any DAG nodes.
24242 return SDValue();
24243 }
24244 }
24245
24246 // Initialize mask to undef.
24247 SmallVector<int, 8> Mask(ShuffleNumElems, -1);
24248
24249 // Only need to run up to the number of elements actually used, not the
24250 // total number of elements in the shuffle - if we are shuffling a wider
24251 // vector, the high lanes should be set to undef.
24252 for (unsigned i = 0; i != NumElems; ++i) {
24253 if (VectorMask[i] <= 0)
24254 continue;
24255
24256 unsigned ExtIndex = N->getOperand(i).getConstantOperandVal(1);
24257 if (VectorMask[i] == (int)LeftIdx) {
24258 Mask[i] = ExtIndex;
24259 } else if (VectorMask[i] == (int)LeftIdx + 1) {
24260 Mask[i] = Vec2Offset + ExtIndex;
24261 }
24262 }
24263
24264 // The type the input vectors may have changed above.
24265 InVT1 = VecIn1.getValueType();
24266
24267 // If we already have a VecIn2, it should have the same type as VecIn1.
24268 // If we don't, get an undef/zero vector of the appropriate type.
24269 VecIn2 = VecIn2.getNode() ? VecIn2 : DAG.getUNDEF(InVT1);
24270 assert(InVT1 == VecIn2.getValueType() && "Unexpected second input type.");
24271
24272 SDValue Shuffle = DAG.getVectorShuffle(InVT1, DL, VecIn1, VecIn2, Mask);
24273 if (ShuffleNumElems > NumElems)
24274 Shuffle = DAG.getExtractSubvector(DL, VT, Shuffle, 0);
24275
24276 return Shuffle;
24277 }
24278
reduceBuildVecToShuffleWithZero(SDNode * BV,SelectionDAG & DAG)24279 static SDValue reduceBuildVecToShuffleWithZero(SDNode *BV, SelectionDAG &DAG) {
24280 assert(BV->getOpcode() == ISD::BUILD_VECTOR && "Expected build vector");
24281
24282 // First, determine where the build vector is not undef.
24283 // TODO: We could extend this to handle zero elements as well as undefs.
24284 int NumBVOps = BV->getNumOperands();
24285 int ZextElt = -1;
24286 for (int i = 0; i != NumBVOps; ++i) {
24287 SDValue Op = BV->getOperand(i);
24288 if (Op.isUndef())
24289 continue;
24290 if (ZextElt == -1)
24291 ZextElt = i;
24292 else
24293 return SDValue();
24294 }
24295 // Bail out if there's no non-undef element.
24296 if (ZextElt == -1)
24297 return SDValue();
24298
24299 // The build vector contains some number of undef elements and exactly
24300 // one other element. That other element must be a zero-extended scalar
24301 // extracted from a vector at a constant index to turn this into a shuffle.
24302 // Also, require that the build vector does not implicitly truncate/extend
24303 // its elements.
24304 // TODO: This could be enhanced to allow ANY_EXTEND as well as ZERO_EXTEND.
24305 EVT VT = BV->getValueType(0);
24306 SDValue Zext = BV->getOperand(ZextElt);
24307 if (Zext.getOpcode() != ISD::ZERO_EXTEND || !Zext.hasOneUse() ||
24308 Zext.getOperand(0).getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
24309 !isa<ConstantSDNode>(Zext.getOperand(0).getOperand(1)) ||
24310 Zext.getValueSizeInBits() != VT.getScalarSizeInBits())
24311 return SDValue();
24312
24313 // The zero-extend must be a multiple of the source size, and we must be
24314 // building a vector of the same size as the source of the extract element.
24315 SDValue Extract = Zext.getOperand(0);
24316 unsigned DestSize = Zext.getValueSizeInBits();
24317 unsigned SrcSize = Extract.getValueSizeInBits();
24318 if (DestSize % SrcSize != 0 ||
24319 Extract.getOperand(0).getValueSizeInBits() != VT.getSizeInBits())
24320 return SDValue();
24321
24322 // Create a shuffle mask that will combine the extracted element with zeros
24323 // and undefs.
24324 int ZextRatio = DestSize / SrcSize;
24325 int NumMaskElts = NumBVOps * ZextRatio;
24326 SmallVector<int, 32> ShufMask(NumMaskElts, -1);
24327 for (int i = 0; i != NumMaskElts; ++i) {
24328 if (i / ZextRatio == ZextElt) {
24329 // The low bits of the (potentially translated) extracted element map to
24330 // the source vector. The high bits map to zero. We will use a zero vector
24331 // as the 2nd source operand of the shuffle, so use the 1st element of
24332 // that vector (mask value is number-of-elements) for the high bits.
24333 int Low = DAG.getDataLayout().isBigEndian() ? (ZextRatio - 1) : 0;
24334 ShufMask[i] = (i % ZextRatio == Low) ? Extract.getConstantOperandVal(1)
24335 : NumMaskElts;
24336 }
24337
24338 // Undef elements of the build vector remain undef because we initialize
24339 // the shuffle mask with -1.
24340 }
24341
24342 // buildvec undef, ..., (zext (extractelt V, IndexC)), undef... -->
24343 // bitcast (shuffle V, ZeroVec, VectorMask)
24344 SDLoc DL(BV);
24345 EVT VecVT = Extract.getOperand(0).getValueType();
24346 SDValue ZeroVec = DAG.getConstant(0, DL, VecVT);
24347 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
24348 SDValue Shuf = TLI.buildLegalVectorShuffle(VecVT, DL, Extract.getOperand(0),
24349 ZeroVec, ShufMask, DAG);
24350 if (!Shuf)
24351 return SDValue();
24352 return DAG.getBitcast(VT, Shuf);
24353 }
24354
// FIXME: promote to STLExtras.
/// Return the position of the first element of \p Range that compares equal
/// to \p Val, measured from Range.begin(), or -1 (in the iterator difference
/// type) if there is no such element.
template <typename R, typename T>
static auto getFirstIndexOf(R &&Range, const T &Val) {
  const auto Pos = find(Range, Val);
  using IndexTy = decltype(std::distance(Range.begin(), Pos));
  return Pos == Range.end() ? static_cast<IndexTy>(-1)
                            : std::distance(Range.begin(), Pos);
}
24363
// Check to see if this is a BUILD_VECTOR of a bunch of EXTRACT_VECTOR_ELT
// operations. If the types of the vectors we're extracting from allow it,
// turn this into a vector_shuffle node.
//
// Operands that are undef or constant zero (integer or FP) are tolerated; any
// other operand must be an EXTRACT_VECTOR_ELT with an in-range constant index
// from a fixed-length vector whose element type matches the result's element
// type. Returns the replacement value, or a null SDValue when no shuffle was
// formed.
SDValue DAGCombiner::reduceBuildVecToShuffle(SDNode *N) {
  SDLoc DL(N);
  EVT VT = N->getValueType(0);

  // Only type-legal BUILD_VECTOR nodes are converted to shuffle nodes.
  if (!isTypeLegal(VT))
    return SDValue();

  // Give the specialised "shuffle with zero vector" form the first chance.
  if (SDValue V = reduceBuildVecToShuffleWithZero(N, DAG))
    return V;

  // May only combine to shuffle after legalize if shuffle is legal.
  if (LegalOperations && !TLI.isOperationLegal(ISD::VECTOR_SHUFFLE, VT))
    return SDValue();

  bool UsesZeroVector = false;
  unsigned NumElems = N->getNumOperands();

  // Record, for each element of the newly built vector, which input vector
  // that element comes from. -1 stands for undef, 0 for the zero vector,
  // and positive values for the input vectors.
  // VectorMask maps each element to its vector number, and VecIn maps vector
  // numbers to their initial SDValues.

  SmallVector<int, 8> VectorMask(NumElems, -1);
  SmallVector<SDValue, 8> VecIn;
  // Slot 0 is a placeholder for the implicit zero vector, so real input
  // vectors are numbered starting from 1.
  VecIn.push_back(SDValue());

  // If we have a single extract_element with a constant index, track the index
  // value. This is overwritten for every extract seen, so it is only
  // meaningful when NumExtracts ends up being 1.
  unsigned OneConstExtractIndex = ~0u;

  // Count the number of extract_vector_elt sources, i.e. operands that are
  // neither undef nor constant zero.
  unsigned NumExtracts = 0;

  for (unsigned i = 0; i != NumElems; ++i) {
    SDValue Op = N->getOperand(i);

    // Undef elements keep the initial -1 in VectorMask.
    if (Op.isUndef())
      continue;

    // See if we can use a blend with a zero vector.
    // TODO: Should we generalize this to a blend with an arbitrary constant
    // vector?
    if (isNullConstant(Op) || isNullFPConstant(Op)) {
      UsesZeroVector = true;
      VectorMask[i] = 0;
      continue;
    }

    // Not an undef or zero. If the input is something other than an
    // EXTRACT_VECTOR_ELT with an in-range constant index, bail out.
    if (Op.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
      return SDValue();

    SDValue ExtractedFromVec = Op.getOperand(0);
    // Scalable source vectors are not handled.
    if (ExtractedFromVec.getValueType().isScalableVector())
      return SDValue();
    auto *ExtractIdx = dyn_cast<ConstantSDNode>(Op.getOperand(1));
    if (!ExtractIdx)
      return SDValue();

    // Reject indices at or beyond the source vector's element count.
    if (ExtractIdx->getAsAPIntVal().uge(
            ExtractedFromVec.getValueType().getVectorNumElements()))
      return SDValue();

    // All inputs must have the same element type as the output.
    if (VT.getVectorElementType() !=
        ExtractedFromVec.getValueType().getVectorElementType())
      return SDValue();

    OneConstExtractIndex = ExtractIdx->getZExtValue();
    ++NumExtracts;

    // Have we seen this input vector before?
    // The vectors are expected to be tiny (usually 1 or 2 elements), so using
    // a map back from SDValues to numbers isn't worth it.
    int Idx = getFirstIndexOf(VecIn, ExtractedFromVec);
    if (Idx == -1) { // A new source vector?
      Idx = VecIn.size();
      VecIn.push_back(ExtractedFromVec);
    }

    VectorMask[i] = Idx;
  }

  // If we didn't find at least one input vector, bail out.
  // (VecIn always holds the zero-vector placeholder, hence the "< 2".)
  if (VecIn.size() < 2)
    return SDValue();

  // If all the operands of the BUILD_VECTOR extract from the same
  // vector, then split the vector efficiently based on the maximum
  // vector access index and adjust the VectorMask and
  // VecIn accordingly.
  bool DidSplitVec = false;
  if (VecIn.size() == 2) {
    // If we only found a single constant indexed extract_vector_elt feeding the
    // build_vector, do not produce a more complicated shuffle if the extract is
    // cheap with other constant/undef elements. Skip broadcast patterns with
    // multiple uses in the build_vector.

    // TODO: This should be more aggressive about skipping the shuffle
    // formation, particularly if VecIn[1].hasOneUse(), and regardless of the
    // index.
    if (NumExtracts == 1 &&
        TLI.isOperationLegalOrCustom(ISD::EXTRACT_VECTOR_ELT, VT) &&
        TLI.isTypeLegal(VT.getVectorElementType()) &&
        TLI.isExtractVecEltCheap(VT, OneConstExtractIndex))
      return SDValue();

    unsigned MaxIndex = 0;
    unsigned NearestPow2 = 0;
    SDValue Vec = VecIn.back();
    EVT InVT = Vec.getValueType();
    // Extract index used by each build_vector element; entries for undef and
    // zero elements stay 0 and are never consulted.
    SmallVector<unsigned, 8> IndexVec(NumElems, 0);

    for (unsigned i = 0; i < NumElems; i++) {
      if (VectorMask[i] <= 0)
        continue;
      unsigned Index = N->getOperand(i).getConstantOperandVal(1);
      IndexVec[i] = Index;
      MaxIndex = std::max(MaxIndex, Index);
    }

    NearestPow2 = PowerOf2Ceil(MaxIndex);
    if (InVT.isSimple() && NearestPow2 > 2 && MaxIndex < NearestPow2 &&
        NumElems * 2 < NearestPow2) {
      unsigned SplitSize = NearestPow2 / 2;
      EVT SplitVT = EVT::getVectorVT(*DAG.getContext(),
                                     InVT.getVectorElementType(), SplitSize);
      // Both halves must be legal and must lie entirely inside the source
      // vector.
      if (TLI.isTypeLegal(SplitVT) &&
          SplitSize + SplitVT.getVectorNumElements() <=
              InVT.getVectorNumElements()) {
        SDValue VecIn2 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SplitVT, Vec,
                                     DAG.getVectorIdxConstant(SplitSize, DL));
        SDValue VecIn1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SplitVT, Vec,
                                     DAG.getVectorIdxConstant(0, DL));
        // Replace the single source vector with its low and high halves.
        VecIn.pop_back();
        VecIn.push_back(VecIn1);
        VecIn.push_back(VecIn2);
        DidSplitVec = true;

        // Redirect each extracted element to the half that contains it.
        for (unsigned i = 0; i < NumElems; i++) {
          if (VectorMask[i] <= 0)
            continue;
          VectorMask[i] = (IndexVec[i] < SplitSize) ? 1 : 2;
        }
      }
    }
  }

  // Sort input vectors by decreasing vector element count,
  // while preserving the relative order of equally-sized vectors.
  // Note that we keep the first "implicit" zero vector as-is.
  SmallVector<SDValue, 8> SortedVecIn(VecIn);
  llvm::stable_sort(MutableArrayRef<SDValue>(SortedVecIn).drop_front(),
                    [](const SDValue &a, const SDValue &b) {
                      return a.getValueType().getVectorNumElements() >
                             b.getValueType().getVectorNumElements();
                    });

  // We now also need to rebuild the VectorMask, because it referenced element
  // order in VecIn, and we just sorted them.
  for (int &SourceVectorIndex : VectorMask) {
    // Undef (-1) and zero-vector (0) entries are unaffected by the sort.
    if (SourceVectorIndex <= 0)
      continue;
    unsigned Idx = getFirstIndexOf(SortedVecIn, VecIn[SourceVectorIndex]);
    assert(Idx > 0 && Idx < SortedVecIn.size() &&
           VecIn[SourceVectorIndex] == SortedVecIn[Idx] && "Remapping failure");
    SourceVectorIndex = Idx;
  }

  VecIn = std::move(SortedVecIn);

  // TODO: Should this fire if some of the input vectors has illegal type (like
  // it does now), or should we let legalization run its course first?

  // Shuffle phase:
  // Take pairs of vectors, and shuffle them so that the result has elements
  // from these vectors in the correct places.
  // For example, given:
  // t10: i32 = extract_vector_elt t1, Constant:i64<0>
  // t11: i32 = extract_vector_elt t2, Constant:i64<0>
  // t12: i32 = extract_vector_elt t3, Constant:i64<0>
  // t13: i32 = extract_vector_elt t1, Constant:i64<1>
  // t14: v4i32 = BUILD_VECTOR t10, t11, t12, t13
  // We will generate:
  // t20: v4i32 = vector_shuffle<0,4,u,1> t1, t2
  // t21: v4i32 = vector_shuffle<u,u,0,u> t3, undef
  SmallVector<SDValue, 4> Shuffles;
  for (unsigned In = 0, Len = (VecIn.size() / 2); In < Len; ++In) {
    // Real inputs start at index 1, so pair number In covers VecIn[2*In+1]
    // and, if present, VecIn[2*In+2].
    unsigned LeftIdx = 2 * In + 1;
    SDValue VecLeft = VecIn[LeftIdx];
    SDValue VecRight =
        (LeftIdx + 1) < VecIn.size() ? VecIn[LeftIdx + 1] : SDValue();

    if (SDValue Shuffle = createBuildVecShuffle(DL, N, VectorMask, VecLeft,
                                                VecRight, LeftIdx, DidSplitVec))
      Shuffles.push_back(Shuffle);
    else
      return SDValue();
  }

  // If we need the zero vector as an "ingredient" in the blend tree, add it
  // to the list of shuffles.
  if (UsesZeroVector)
    Shuffles.push_back(VT.isInteger() ? DAG.getConstant(0, DL, VT)
                                      : DAG.getConstantFP(0.0, DL, VT));

  // If we only have one shuffle, we're done.
  if (Shuffles.size() == 1)
    return Shuffles[0];

  // Update the vector mask to point to the post-shuffle vectors.
  // The zero vector, when used, was appended last; input vector V (V >= 1)
  // went into shuffle number (V - 1) / 2. Undef entries (-1) stay -1.
  for (int &Vec : VectorMask)
    if (Vec == 0)
      Vec = Shuffles.size() - 1;
    else
      Vec = (Vec - 1) / 2;

  // More than one shuffle. Generate a binary tree of blends, e.g. if from
  // the previous step we got the set of shuffles t10, t11, t12, t13, we will
  // generate:
  // t10: v8i32 = vector_shuffle<0,8,u,u,u,u,u,u> t1, t2
  // t11: v8i32 = vector_shuffle<u,u,0,8,u,u,u,u> t3, t4
  // t12: v8i32 = vector_shuffle<u,u,u,u,0,8,u,u> t5, t6
  // t13: v8i32 = vector_shuffle<u,u,u,u,u,u,0,8> t7, t8
  // t20: v8i32 = vector_shuffle<0,1,10,11,u,u,u,u> t10, t11
  // t21: v8i32 = vector_shuffle<u,u,u,u,4,5,14,15> t12, t13
  // t30: v8i32 = vector_shuffle<0,1,2,3,12,13,14,15> t20, t21

  // Make sure the initial size of the shuffle list is even.
  if (Shuffles.size() % 2)
    Shuffles.push_back(DAG.getUNDEF(VT));

  // Each iteration blends adjacent pairs, writing the results back into the
  // front of Shuffles and halving the number of live entries.
  for (unsigned CurSize = Shuffles.size(); CurSize > 1; CurSize /= 2) {
    if (CurSize % 2) {
      // Pad an odd level with undef so every entry has a partner. The list is
      // never shrunk, so index CurSize is still within the storage.
      Shuffles[CurSize] = DAG.getUNDEF(VT);
      CurSize++;
    }
    for (unsigned In = 0, Len = CurSize / 2; In < Len; ++In) {
      int Left = 2 * In;
      int Right = 2 * In + 1;
      SmallVector<int, 8> Mask(NumElems, -1);
      SDValue L = Shuffles[Left];
      ArrayRef<int> LMask;
      // If the left input is itself an unused unary shuffle of a same-typed
      // vector, look through it and fold its mask into the blend.
      bool IsLeftShuffle = L.getOpcode() == ISD::VECTOR_SHUFFLE &&
                           L.use_empty() && L.getOperand(1).isUndef() &&
                           L.getOperand(0).getValueType() == L.getValueType();
      if (IsLeftShuffle) {
        LMask = cast<ShuffleVectorSDNode>(L.getNode())->getMask();
        L = L.getOperand(0);
      }
      SDValue R = Shuffles[Right];
      ArrayRef<int> RMask;
      // Same look-through for the right input.
      bool IsRightShuffle = R.getOpcode() == ISD::VECTOR_SHUFFLE &&
                            R.use_empty() && R.getOperand(1).isUndef() &&
                            R.getOperand(0).getValueType() == R.getValueType();
      if (IsRightShuffle) {
        RMask = cast<ShuffleVectorSDNode>(R.getNode())->getMask();
        R = R.getOperand(0);
      }
      for (unsigned I = 0; I != NumElems; ++I) {
        if (VectorMask[I] == Left) {
          // Lane I comes from the left input; keep it in place (or use the
          // folded shuffle's lane) and retarget it to the merged entry.
          Mask[I] = I;
          if (IsLeftShuffle)
            Mask[I] = LMask[I];
          VectorMask[I] = In;
        } else if (VectorMask[I] == Right) {
          // Lane I comes from the right input, i.e. the second shuffle
          // operand, hence the NumElems offset.
          Mask[I] = I + NumElems;
          if (IsRightShuffle)
            Mask[I] = RMask[I] + NumElems;
          VectorMask[I] = In;
        }
      }

      Shuffles[In] = DAG.getVectorShuffle(VT, DL, L, R, Mask);
    }
  }
  return Shuffles[0];
}
24648
24649 // Try to turn a build vector of zero extends of extract vector elts into a
24650 // a vector zero extend and possibly an extract subvector.
24651 // TODO: Support sign extend?
24652 // TODO: Allow undef elements?
convertBuildVecZextToZext(SDNode * N)24653 SDValue DAGCombiner::convertBuildVecZextToZext(SDNode *N) {
24654 if (LegalOperations)
24655 return SDValue();
24656
24657 EVT VT = N->getValueType(0);
24658
24659 bool FoundZeroExtend = false;
24660 SDValue Op0 = N->getOperand(0);
24661 auto checkElem = [&](SDValue Op) -> int64_t {
24662 unsigned Opc = Op.getOpcode();
24663 FoundZeroExtend |= (Opc == ISD::ZERO_EXTEND);
24664 if ((Opc == ISD::ZERO_EXTEND || Opc == ISD::ANY_EXTEND) &&
24665 Op.getOperand(0).getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
24666 Op0.getOperand(0).getOperand(0) == Op.getOperand(0).getOperand(0))
24667 if (auto *C = dyn_cast<ConstantSDNode>(Op.getOperand(0).getOperand(1)))
24668 return C->getZExtValue();
24669 return -1;
24670 };
24671
24672 // Make sure the first element matches
24673 // (zext (extract_vector_elt X, C))
24674 // Offset must be a constant multiple of the
24675 // known-minimum vector length of the result type.
24676 int64_t Offset = checkElem(Op0);
24677 if (Offset < 0 || (Offset % VT.getVectorNumElements()) != 0)
24678 return SDValue();
24679
24680 unsigned NumElems = N->getNumOperands();
24681 SDValue In = Op0.getOperand(0).getOperand(0);
24682 EVT InSVT = In.getValueType().getScalarType();
24683 EVT InVT = EVT::getVectorVT(*DAG.getContext(), InSVT, NumElems);
24684
24685 // Don't create an illegal input type after type legalization.
24686 if (LegalTypes && !TLI.isTypeLegal(InVT))
24687 return SDValue();
24688
24689 // Ensure all the elements come from the same vector and are adjacent.
24690 for (unsigned i = 1; i != NumElems; ++i) {
24691 if ((Offset + i) != checkElem(N->getOperand(i)))
24692 return SDValue();
24693 }
24694
24695 SDLoc DL(N);
24696 In = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, InVT, In,
24697 Op0.getOperand(0).getOperand(1));
24698 return DAG.getNode(FoundZeroExtend ? ISD::ZERO_EXTEND : ISD::ANY_EXTEND, DL,
24699 VT, In);
24700 }
24701
// If this is a very simple BUILD_VECTOR with first element being a ZERO_EXTEND,
// and all other elements being constant zero's, granularize the BUILD_VECTOR's
// element width, absorbing the ZERO_EXTEND, turning it into a constant zero op.
// This pattern can appear during legalization.
//
// Returns the rebuilt vector bitcast back to the original type, or a null
// SDValue when the pattern does not match or no legal narrower type exists.
//
// NOTE: This can be generalized to allow more than a single
//       non-constant-zero op, UNDEF's, and to be KnownBits-based.
SDValue DAGCombiner::convertBuildVecZextToBuildVecWithZeros(SDNode *N) {
  // Don't run this after legalization. Targets may have other preferences.
  if (Level >= AfterLegalizeDAG)
    return SDValue();

  // FIXME: support big-endian.
  if (DAG.getDataLayout().isBigEndian())
    return SDValue();

  EVT VT = N->getValueType(0);
  EVT OpVT = N->getOperand(0).getValueType();
  assert(!VT.isScalableVector() && "Encountered scalable BUILD_VECTOR?");

  // Integer type with the same bit width as the operands; the non-zero operand
  // is bitcast to it before being truncated.
  EVT OpIntVT = EVT::getIntegerVT(*DAG.getContext(), OpVT.getSizeInBits());

  if (!TLI.isTypeLegal(OpIntVT) ||
      (LegalOperations && !TLI.isOperationLegalOrCustom(ISD::BITCAST, OpIntVT)))
    return SDValue();

  unsigned EltBitwidth = VT.getScalarSizeInBits();
  // NOTE: the actual width of operands may be wider than that!

  // Analyze all operands of this BUILD_VECTOR. What is the largest number of
  // active bits they all have? We'll want to truncate them all to that width.
  unsigned ActiveBits = 0;
  // One bit per BUILD_VECTOR operand, set when that operand is a constant
  // whose low EltBitwidth bits are all zero.
  APInt KnownZeroOps(VT.getVectorNumElements(), 0);
  for (auto I : enumerate(N->ops())) {
    SDValue Op = I.value();
    // FIXME: support UNDEF elements?
    if (auto *Cst = dyn_cast<ConstantSDNode>(Op)) {
      // Only the bits that fit into the vector element matter.
      unsigned OpActiveBits =
          Cst->getAPIntValue().trunc(EltBitwidth).getActiveBits();
      if (OpActiveBits == 0) {
        KnownZeroOps.setBit(I.index());
        continue;
      }
      // Profitability check: don't allow non-zero constant operands.
      return SDValue();
    }
    // Profitability check: there must only be a single non-zero operand,
    // and it must be the first operand of the BUILD_VECTOR.
    if (I.index() != 0)
      return SDValue();
    // The operand must be a zero-extension itself.
    // FIXME: this could be generalized to known leading zeros check.
    if (Op.getOpcode() != ISD::ZERO_EXTEND)
      return SDValue();
    // The pre-extension width bounds the number of possibly non-zero bits.
    unsigned CurrActiveBits =
        Op.getOperand(0).getValueSizeInBits().getFixedValue();
    assert(!ActiveBits && "Already encountered non-constant-zero operand?");
    ActiveBits = CurrActiveBits;
    // We want to at least halve the element size.
    if (2 * ActiveBits > EltBitwidth)
      return SDValue();
  }

  // This BUILD_VECTOR must have at least one non-constant-zero operand.
  if (ActiveBits == 0)
    return SDValue();

  // We have EltBitwidth bits, the *minimal* chunk size is ActiveBits,
  // into how many chunks can we split our element width?
  EVT NewScalarIntVT, NewIntVT;
  std::optional<unsigned> Factor;
  // We can split the element into at least two chunks, but not into more
  // than |_ EltBitwidth / ActiveBits _| chunks. Find a largest split factor
  // for which the element width is a multiple of it,
  // and the resulting types/operations on that chunk width are legal.
  assert(2 * ActiveBits <= EltBitwidth &&
         "We know that half or less bits of the element are active.");
  for (unsigned Scale = EltBitwidth / ActiveBits; Scale >= 2; --Scale) {
    // The element must divide evenly into Scale chunks.
    if (EltBitwidth % Scale != 0)
      continue;
    unsigned ChunkBitwidth = EltBitwidth / Scale;
    assert(ChunkBitwidth >= ActiveBits && "As per starting point.");
    NewScalarIntVT = EVT::getIntegerVT(*DAG.getContext(), ChunkBitwidth);
    NewIntVT = EVT::getVectorVT(*DAG.getContext(), NewScalarIntVT,
                                Scale * N->getNumOperands());
    if (!TLI.isTypeLegal(NewScalarIntVT) || !TLI.isTypeLegal(NewIntVT) ||
        (LegalOperations &&
         !(TLI.isOperationLegalOrCustom(ISD::TRUNCATE, NewScalarIntVT) &&
           TLI.isOperationLegalOrCustom(ISD::BUILD_VECTOR, NewIntVT))))
      continue;
    // Scale counts down, so the first acceptable value is the largest one.
    Factor = Scale;
    break;
  }
  if (!Factor)
    return SDValue();

  SDLoc DL(N);
  SDValue ZeroOp = DAG.getConstant(0, DL, NewScalarIntVT);

  // Recreate the BUILD_VECTOR, with elements now being Factor times smaller.
  SmallVector<SDValue, 16> NewOps;
  NewOps.reserve(NewIntVT.getVectorNumElements());
  for (auto I : enumerate(N->ops())) {
    SDValue Op = I.value();
    assert(!Op.isUndef() && "FIXME: after allowing UNDEF's, handle them here.");
    unsigned SrcOpIdx = I.index();
    // A known-zero element becomes Factor zero chunks.
    if (KnownZeroOps[SrcOpIdx]) {
      NewOps.append(*Factor, ZeroOp);
      continue;
    }
    // The non-zero element becomes its truncated value followed by
    // Factor - 1 zero chunks. NOTE(review): placing the value chunk first
    // presumably relies on the little-endian layout enforced above - confirm.
    Op = DAG.getBitcast(OpIntVT, Op);
    Op = DAG.getNode(ISD::TRUNCATE, DL, NewScalarIntVT, Op);
    NewOps.emplace_back(Op);
    NewOps.append(*Factor - 1, ZeroOp);
  }
  assert(NewOps.size() == NewIntVT.getVectorNumElements());
  SDValue NewBV = DAG.getBuildVector(NewIntVT, DL, NewOps);
  NewBV = DAG.getBitcast(VT, NewBV);
  return NewBV;
}
24822
visitBUILD_VECTOR(SDNode * N)24823 SDValue DAGCombiner::visitBUILD_VECTOR(SDNode *N) {
24824 EVT VT = N->getValueType(0);
24825
24826 // A vector built entirely of undefs is undef.
24827 if (ISD::allOperandsUndef(N))
24828 return DAG.getUNDEF(VT);
24829
24830 // If this is a splat of a bitcast from another vector, change to a
24831 // concat_vector.
24832 // For example:
24833 // (build_vector (i64 (bitcast (v2i32 X))), (i64 (bitcast (v2i32 X)))) ->
24834 // (v2i64 (bitcast (concat_vectors (v2i32 X), (v2i32 X))))
24835 //
24836 // If X is a build_vector itself, the concat can become a larger build_vector.
24837 // TODO: Maybe this is useful for non-splat too?
24838 if (!LegalOperations) {
24839 SDValue Splat = cast<BuildVectorSDNode>(N)->getSplatValue();
24840 // Only change build_vector to a concat_vector if the splat value type is
24841 // same as the vector element type.
24842 if (Splat && Splat.getValueType() == VT.getVectorElementType()) {
24843 Splat = peekThroughBitcasts(Splat);
24844 EVT SrcVT = Splat.getValueType();
24845 if (SrcVT.isVector()) {
24846 unsigned NumElts = N->getNumOperands() * SrcVT.getVectorNumElements();
24847 EVT NewVT = EVT::getVectorVT(*DAG.getContext(),
24848 SrcVT.getVectorElementType(), NumElts);
24849 if (!LegalTypes || TLI.isTypeLegal(NewVT)) {
24850 SmallVector<SDValue, 8> Ops(N->getNumOperands(), Splat);
24851 SDValue Concat =
24852 DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), NewVT, Ops);
24853 return DAG.getBitcast(VT, Concat);
24854 }
24855 }
24856 }
24857 }
24858
24859 // Check if we can express BUILD VECTOR via subvector extract.
24860 if (!LegalTypes && (N->getNumOperands() > 1)) {
24861 SDValue Op0 = N->getOperand(0);
24862 auto checkElem = [&](SDValue Op) -> uint64_t {
24863 if ((Op.getOpcode() == ISD::EXTRACT_VECTOR_ELT) &&
24864 (Op0.getOperand(0) == Op.getOperand(0)))
24865 if (auto CNode = dyn_cast<ConstantSDNode>(Op.getOperand(1)))
24866 return CNode->getZExtValue();
24867 return -1;
24868 };
24869
24870 int Offset = checkElem(Op0);
24871 for (unsigned i = 0; i < N->getNumOperands(); ++i) {
24872 if (Offset + i != checkElem(N->getOperand(i))) {
24873 Offset = -1;
24874 break;
24875 }
24876 }
24877
24878 if ((Offset == 0) &&
24879 (Op0.getOperand(0).getValueType() == N->getValueType(0)))
24880 return Op0.getOperand(0);
24881 if ((Offset != -1) &&
24882 ((Offset % N->getValueType(0).getVectorNumElements()) ==
24883 0)) // IDX must be multiple of output size.
24884 return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), N->getValueType(0),
24885 Op0.getOperand(0), Op0.getOperand(1));
24886 }
24887
24888 if (SDValue V = convertBuildVecZextToZext(N))
24889 return V;
24890
24891 if (SDValue V = convertBuildVecZextToBuildVecWithZeros(N))
24892 return V;
24893
24894 if (SDValue V = reduceBuildVecExtToExtBuildVec(N))
24895 return V;
24896
24897 if (SDValue V = reduceBuildVecTruncToBitCast(N))
24898 return V;
24899
24900 if (SDValue V = reduceBuildVecToShuffle(N))
24901 return V;
24902
24903 // A splat of a single element is a SPLAT_VECTOR if supported on the target.
24904 // Do this late as some of the above may replace the splat.
24905 if (TLI.getOperationAction(ISD::SPLAT_VECTOR, VT) != TargetLowering::Expand)
24906 if (SDValue V = cast<BuildVectorSDNode>(N)->getSplatValue()) {
24907 assert(!V.isUndef() && "Splat of undef should have been handled earlier");
24908 return DAG.getNode(ISD::SPLAT_VECTOR, SDLoc(N), VT, V);
24909 }
24910
24911 return SDValue();
24912 }
24913
combineConcatVectorOfScalars(SDNode * N,SelectionDAG & DAG)24914 static SDValue combineConcatVectorOfScalars(SDNode *N, SelectionDAG &DAG) {
24915 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
24916 EVT OpVT = N->getOperand(0).getValueType();
24917
24918 // If the operands are legal vectors, leave them alone.
24919 if (TLI.isTypeLegal(OpVT) || OpVT.isScalableVector())
24920 return SDValue();
24921
24922 SDLoc DL(N);
24923 EVT VT = N->getValueType(0);
24924 SmallVector<SDValue, 8> Ops;
24925 EVT SVT = EVT::getIntegerVT(*DAG.getContext(), OpVT.getSizeInBits());
24926
24927 // Keep track of what we encounter.
24928 EVT AnyFPVT;
24929
24930 for (const SDValue &Op : N->ops()) {
24931 if (ISD::BITCAST == Op.getOpcode() &&
24932 !Op.getOperand(0).getValueType().isVector())
24933 Ops.push_back(Op.getOperand(0));
24934 else if (Op.isUndef())
24935 Ops.push_back(DAG.getNode(ISD::UNDEF, DL, SVT));
24936 else
24937 return SDValue();
24938
24939 // Note whether we encounter an integer or floating point scalar.
24940 // If it's neither, bail out, it could be something weird like x86mmx.
24941 EVT LastOpVT = Ops.back().getValueType();
24942 if (LastOpVT.isFloatingPoint())
24943 AnyFPVT = LastOpVT;
24944 else if (!LastOpVT.isInteger())
24945 return SDValue();
24946 }
24947
24948 // If any of the operands is a floating point scalar bitcast to a vector,
24949 // use floating point types throughout, and bitcast everything.
24950 // Replace UNDEFs by another scalar UNDEF node, of the final desired type.
24951 if (AnyFPVT != EVT()) {
24952 SVT = AnyFPVT;
24953 for (SDValue &Op : Ops) {
24954 if (Op.getValueType() == SVT)
24955 continue;
24956 if (Op.isUndef())
24957 Op = DAG.getNode(ISD::UNDEF, DL, SVT);
24958 else
24959 Op = DAG.getBitcast(SVT, Op);
24960 }
24961 }
24962
24963 EVT VecVT = EVT::getVectorVT(*DAG.getContext(), SVT,
24964 VT.getSizeInBits() / SVT.getSizeInBits());
24965 return DAG.getBitcast(VT, DAG.getBuildVector(VecVT, DL, Ops));
24966 }
24967
24968 // Attempt to merge nested concat_vectors/undefs.
24969 // Fold concat_vectors(concat_vectors(x,y,z,w),u,u,concat_vectors(a,b,c,d))
24970 // --> concat_vectors(x,y,z,w,u,u,u,u,u,u,u,u,a,b,c,d)
combineConcatVectorOfConcatVectors(SDNode * N,SelectionDAG & DAG)24971 static SDValue combineConcatVectorOfConcatVectors(SDNode *N,
24972 SelectionDAG &DAG) {
24973 EVT VT = N->getValueType(0);
24974
24975 // Ensure we're concatenating UNDEF and CONCAT_VECTORS nodes of similar types.
24976 EVT SubVT;
24977 SDValue FirstConcat;
24978 for (const SDValue &Op : N->ops()) {
24979 if (Op.isUndef())
24980 continue;
24981 if (Op.getOpcode() != ISD::CONCAT_VECTORS)
24982 return SDValue();
24983 if (!FirstConcat) {
24984 SubVT = Op.getOperand(0).getValueType();
24985 if (!DAG.getTargetLoweringInfo().isTypeLegal(SubVT))
24986 return SDValue();
24987 FirstConcat = Op;
24988 continue;
24989 }
24990 if (SubVT != Op.getOperand(0).getValueType())
24991 return SDValue();
24992 }
24993 assert(FirstConcat && "Concat of all-undefs found");
24994
24995 SmallVector<SDValue> ConcatOps;
24996 for (const SDValue &Op : N->ops()) {
24997 if (Op.isUndef()) {
24998 ConcatOps.append(FirstConcat->getNumOperands(), DAG.getUNDEF(SubVT));
24999 continue;
25000 }
25001 ConcatOps.append(Op->op_begin(), Op->op_end());
25002 }
25003 return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, ConcatOps);
25004 }
25005
25006 // Check to see if this is a CONCAT_VECTORS of a bunch of EXTRACT_SUBVECTOR
25007 // operations. If so, and if the EXTRACT_SUBVECTOR vector inputs come from at
25008 // most two distinct vectors the same size as the result, attempt to turn this
25009 // into a legal shuffle.
combineConcatVectorOfExtracts(SDNode * N,SelectionDAG & DAG)25010 static SDValue combineConcatVectorOfExtracts(SDNode *N, SelectionDAG &DAG) {
25011 EVT VT = N->getValueType(0);
25012 EVT OpVT = N->getOperand(0).getValueType();
25013
25014 // We currently can't generate an appropriate shuffle for a scalable vector.
25015 if (VT.isScalableVector())
25016 return SDValue();
25017
25018 int NumElts = VT.getVectorNumElements();
25019 int NumOpElts = OpVT.getVectorNumElements();
25020
25021 SDValue SV0 = DAG.getUNDEF(VT), SV1 = DAG.getUNDEF(VT);
25022 SmallVector<int, 8> Mask;
25023
25024 for (SDValue Op : N->ops()) {
25025 Op = peekThroughBitcasts(Op);
25026
25027 // UNDEF nodes convert to UNDEF shuffle mask values.
25028 if (Op.isUndef()) {
25029 Mask.append((unsigned)NumOpElts, -1);
25030 continue;
25031 }
25032
25033 if (Op.getOpcode() != ISD::EXTRACT_SUBVECTOR)
25034 return SDValue();
25035
25036 // What vector are we extracting the subvector from and at what index?
25037 SDValue ExtVec = Op.getOperand(0);
25038 int ExtIdx = Op.getConstantOperandVal(1);
25039
25040 // We want the EVT of the original extraction to correctly scale the
25041 // extraction index.
25042 EVT ExtVT = ExtVec.getValueType();
25043 ExtVec = peekThroughBitcasts(ExtVec);
25044
25045 // UNDEF nodes convert to UNDEF shuffle mask values.
25046 if (ExtVec.isUndef()) {
25047 Mask.append((unsigned)NumOpElts, -1);
25048 continue;
25049 }
25050
25051 // Ensure that we are extracting a subvector from a vector the same
25052 // size as the result.
25053 if (ExtVT.getSizeInBits() != VT.getSizeInBits())
25054 return SDValue();
25055
25056 // Scale the subvector index to account for any bitcast.
25057 int NumExtElts = ExtVT.getVectorNumElements();
25058 if (0 == (NumExtElts % NumElts))
25059 ExtIdx /= (NumExtElts / NumElts);
25060 else if (0 == (NumElts % NumExtElts))
25061 ExtIdx *= (NumElts / NumExtElts);
25062 else
25063 return SDValue();
25064
25065 // At most we can reference 2 inputs in the final shuffle.
25066 if (SV0.isUndef() || SV0 == ExtVec) {
25067 SV0 = ExtVec;
25068 for (int i = 0; i != NumOpElts; ++i)
25069 Mask.push_back(i + ExtIdx);
25070 } else if (SV1.isUndef() || SV1 == ExtVec) {
25071 SV1 = ExtVec;
25072 for (int i = 0; i != NumOpElts; ++i)
25073 Mask.push_back(i + ExtIdx + NumElts);
25074 } else {
25075 return SDValue();
25076 }
25077 }
25078
25079 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
25080 return TLI.buildLegalVectorShuffle(VT, SDLoc(N), DAG.getBitcast(VT, SV0),
25081 DAG.getBitcast(VT, SV1), Mask, DAG);
25082 }
25083
combineConcatVectorOfCasts(SDNode * N,SelectionDAG & DAG)25084 static SDValue combineConcatVectorOfCasts(SDNode *N, SelectionDAG &DAG) {
25085 unsigned CastOpcode = N->getOperand(0).getOpcode();
25086 switch (CastOpcode) {
25087 case ISD::SINT_TO_FP:
25088 case ISD::UINT_TO_FP:
25089 case ISD::FP_TO_SINT:
25090 case ISD::FP_TO_UINT:
25091 // TODO: Allow more opcodes?
25092 // case ISD::BITCAST:
25093 // case ISD::TRUNCATE:
25094 // case ISD::ZERO_EXTEND:
25095 // case ISD::SIGN_EXTEND:
25096 // case ISD::FP_EXTEND:
25097 break;
25098 default:
25099 return SDValue();
25100 }
25101
25102 EVT SrcVT = N->getOperand(0).getOperand(0).getValueType();
25103 if (!SrcVT.isVector())
25104 return SDValue();
25105
25106 // All operands of the concat must be the same kind of cast from the same
25107 // source type.
25108 SmallVector<SDValue, 4> SrcOps;
25109 for (SDValue Op : N->ops()) {
25110 if (Op.getOpcode() != CastOpcode || !Op.hasOneUse() ||
25111 Op.getOperand(0).getValueType() != SrcVT)
25112 return SDValue();
25113 SrcOps.push_back(Op.getOperand(0));
25114 }
25115
25116 // The wider cast must be supported by the target. This is unusual because
25117 // the operation support type parameter depends on the opcode. In addition,
25118 // check the other type in the cast to make sure this is really legal.
25119 EVT VT = N->getValueType(0);
25120 EVT SrcEltVT = SrcVT.getVectorElementType();
25121 ElementCount NumElts = SrcVT.getVectorElementCount() * N->getNumOperands();
25122 EVT ConcatSrcVT = EVT::getVectorVT(*DAG.getContext(), SrcEltVT, NumElts);
25123 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
25124 switch (CastOpcode) {
25125 case ISD::SINT_TO_FP:
25126 case ISD::UINT_TO_FP:
25127 if (!TLI.isOperationLegalOrCustom(CastOpcode, ConcatSrcVT) ||
25128 !TLI.isTypeLegal(VT))
25129 return SDValue();
25130 break;
25131 case ISD::FP_TO_SINT:
25132 case ISD::FP_TO_UINT:
25133 if (!TLI.isOperationLegalOrCustom(CastOpcode, VT) ||
25134 !TLI.isTypeLegal(ConcatSrcVT))
25135 return SDValue();
25136 break;
25137 default:
25138 llvm_unreachable("Unexpected cast opcode");
25139 }
25140
25141 // concat (cast X), (cast Y)... -> cast (concat X, Y...)
25142 SDLoc DL(N);
25143 SDValue NewConcat = DAG.getNode(ISD::CONCAT_VECTORS, DL, ConcatSrcVT, SrcOps);
25144 return DAG.getNode(CastOpcode, DL, VT, NewConcat);
25145 }
25146
25147 // See if this is a simple CONCAT_VECTORS with no UNDEF operands, and if one of
25148 // the operands is a SHUFFLE_VECTOR, and all other operands are also operands
25149 // to that SHUFFLE_VECTOR, create wider SHUFFLE_VECTOR.
combineConcatVectorOfShuffleAndItsOperands(SDNode * N,SelectionDAG & DAG,const TargetLowering & TLI,bool LegalTypes,bool LegalOperations)25150 static SDValue combineConcatVectorOfShuffleAndItsOperands(
25151 SDNode *N, SelectionDAG &DAG, const TargetLowering &TLI, bool LegalTypes,
25152 bool LegalOperations) {
25153 EVT VT = N->getValueType(0);
25154 EVT OpVT = N->getOperand(0).getValueType();
25155 if (VT.isScalableVector())
25156 return SDValue();
25157
25158 // For now, only allow simple 2-operand concatenations.
25159 if (N->getNumOperands() != 2)
25160 return SDValue();
25161
25162 // Don't create illegal types/shuffles when not allowed to.
25163 if ((LegalTypes && !TLI.isTypeLegal(VT)) ||
25164 (LegalOperations &&
25165 !TLI.isOperationLegalOrCustom(ISD::VECTOR_SHUFFLE, VT)))
25166 return SDValue();
25167
25168 // Analyze all of the operands of the CONCAT_VECTORS. Out of all of them,
25169 // we want to find one that is: (1) a SHUFFLE_VECTOR (2) only used by us,
25170 // and (3) all operands of CONCAT_VECTORS must be either that SHUFFLE_VECTOR,
25171 // or one of the operands of that SHUFFLE_VECTOR (but not UNDEF!).
25172 // (4) and for now, the SHUFFLE_VECTOR must be unary.
25173 ShuffleVectorSDNode *SVN = nullptr;
25174 for (SDValue Op : N->ops()) {
25175 if (auto *CurSVN = dyn_cast<ShuffleVectorSDNode>(Op);
25176 CurSVN && CurSVN->getOperand(1).isUndef() && N->isOnlyUserOf(CurSVN) &&
25177 all_of(N->ops(), [CurSVN](SDValue Op) {
25178 // FIXME: can we allow UNDEF operands?
25179 return !Op.isUndef() &&
25180 (Op.getNode() == CurSVN || is_contained(CurSVN->ops(), Op));
25181 })) {
25182 SVN = CurSVN;
25183 break;
25184 }
25185 }
25186 if (!SVN)
25187 return SDValue();
25188
25189 // We are going to pad the shuffle operands, so any indice, that was picking
25190 // from the second operand, must be adjusted.
25191 SmallVector<int, 16> AdjustedMask(SVN->getMask());
25192 assert(SVN->getOperand(1).isUndef() && "Expected unary shuffle!");
25193
25194 // Identity masks for the operands of the (padded) shuffle.
25195 SmallVector<int, 32> IdentityMask(2 * OpVT.getVectorNumElements());
25196 MutableArrayRef<int> FirstShufOpIdentityMask =
25197 MutableArrayRef<int>(IdentityMask)
25198 .take_front(OpVT.getVectorNumElements());
25199 MutableArrayRef<int> SecondShufOpIdentityMask =
25200 MutableArrayRef<int>(IdentityMask).take_back(OpVT.getVectorNumElements());
25201 std::iota(FirstShufOpIdentityMask.begin(), FirstShufOpIdentityMask.end(), 0);
25202 std::iota(SecondShufOpIdentityMask.begin(), SecondShufOpIdentityMask.end(),
25203 VT.getVectorNumElements());
25204
25205 // New combined shuffle mask.
25206 SmallVector<int, 32> Mask;
25207 Mask.reserve(VT.getVectorNumElements());
25208 for (SDValue Op : N->ops()) {
25209 assert(!Op.isUndef() && "Not expecting to concatenate UNDEF.");
25210 if (Op.getNode() == SVN) {
25211 append_range(Mask, AdjustedMask);
25212 continue;
25213 }
25214 if (Op == SVN->getOperand(0)) {
25215 append_range(Mask, FirstShufOpIdentityMask);
25216 continue;
25217 }
25218 if (Op == SVN->getOperand(1)) {
25219 append_range(Mask, SecondShufOpIdentityMask);
25220 continue;
25221 }
25222 llvm_unreachable("Unexpected operand!");
25223 }
25224
25225 // Don't create illegal shuffle masks.
25226 if (!TLI.isShuffleMaskLegal(Mask, VT))
25227 return SDValue();
25228
25229 // Pad the shuffle operands with UNDEF.
25230 SDLoc dl(N);
25231 std::array<SDValue, 2> ShufOps;
25232 for (auto I : zip(SVN->ops(), ShufOps)) {
25233 SDValue ShufOp = std::get<0>(I);
25234 SDValue &NewShufOp = std::get<1>(I);
25235 if (ShufOp.isUndef())
25236 NewShufOp = DAG.getUNDEF(VT);
25237 else {
25238 SmallVector<SDValue, 2> ShufOpParts(N->getNumOperands(),
25239 DAG.getUNDEF(OpVT));
25240 ShufOpParts[0] = ShufOp;
25241 NewShufOp = DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, ShufOpParts);
25242 }
25243 }
25244 // Finally, create the new wide shuffle.
25245 return DAG.getVectorShuffle(VT, dl, ShufOps[0], ShufOps[1], Mask);
25246 }
25247
visitCONCAT_VECTORS(SDNode * N)25248 SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) {
25249 // If we only have one input vector, we don't need to do any concatenation.
25250 if (N->getNumOperands() == 1)
25251 return N->getOperand(0);
25252
25253 // Check if all of the operands are undefs.
25254 EVT VT = N->getValueType(0);
25255 if (ISD::allOperandsUndef(N))
25256 return DAG.getUNDEF(VT);
25257
25258 // Optimize concat_vectors where all but the first of the vectors are undef.
25259 if (all_of(drop_begin(N->ops()),
25260 [](const SDValue &Op) { return Op.isUndef(); })) {
25261 SDValue In = N->getOperand(0);
25262 assert(In.getValueType().isVector() && "Must concat vectors");
25263
25264 // If the input is a concat_vectors, just make a larger concat by padding
25265 // with smaller undefs.
25266 //
25267 // Legalizing in AArch64TargetLowering::LowerCONCAT_VECTORS() and combining
25268 // here could cause an infinite loop. That legalizing happens when LegalDAG
25269 // is true and input of AArch64TargetLowering::LowerCONCAT_VECTORS() is
25270 // scalable.
25271 if (In.getOpcode() == ISD::CONCAT_VECTORS && In.hasOneUse() &&
25272 !(LegalDAG && In.getValueType().isScalableVector())) {
25273 unsigned NumOps = N->getNumOperands() * In.getNumOperands();
25274 SmallVector<SDValue, 4> Ops(In->ops());
25275 Ops.resize(NumOps, DAG.getUNDEF(Ops[0].getValueType()));
25276 return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops);
25277 }
25278
25279 SDValue Scalar = peekThroughOneUseBitcasts(In);
25280
25281 // concat_vectors(scalar_to_vector(scalar), undef) ->
25282 // scalar_to_vector(scalar)
25283 if (!LegalOperations && Scalar.getOpcode() == ISD::SCALAR_TO_VECTOR &&
25284 Scalar.hasOneUse()) {
25285 EVT SVT = Scalar.getValueType().getVectorElementType();
25286 if (SVT == Scalar.getOperand(0).getValueType())
25287 Scalar = Scalar.getOperand(0);
25288 }
25289
25290 // concat_vectors(scalar, undef) -> scalar_to_vector(scalar)
25291 if (!Scalar.getValueType().isVector() && In.hasOneUse()) {
25292 // If the bitcast type isn't legal, it might be a trunc of a legal type;
25293 // look through the trunc so we can still do the transform:
25294 // concat_vectors(trunc(scalar), undef) -> scalar_to_vector(scalar)
25295 if (Scalar->getOpcode() == ISD::TRUNCATE &&
25296 !TLI.isTypeLegal(Scalar.getValueType()) &&
25297 TLI.isTypeLegal(Scalar->getOperand(0).getValueType()))
25298 Scalar = Scalar->getOperand(0);
25299
25300 EVT SclTy = Scalar.getValueType();
25301
25302 if (!SclTy.isFloatingPoint() && !SclTy.isInteger())
25303 return SDValue();
25304
25305 // Bail out if the vector size is not a multiple of the scalar size.
25306 if (VT.getSizeInBits() % SclTy.getSizeInBits())
25307 return SDValue();
25308
25309 unsigned VNTNumElms = VT.getSizeInBits() / SclTy.getSizeInBits();
25310 if (VNTNumElms < 2)
25311 return SDValue();
25312
25313 EVT NVT = EVT::getVectorVT(*DAG.getContext(), SclTy, VNTNumElms);
25314 if (!TLI.isTypeLegal(NVT) || !TLI.isTypeLegal(Scalar.getValueType()))
25315 return SDValue();
25316
25317 SDValue Res = DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), NVT, Scalar);
25318 return DAG.getBitcast(VT, Res);
25319 }
25320 }
25321
25322 // Fold any combination of BUILD_VECTOR or UNDEF nodes into one BUILD_VECTOR.
25323 // We have already tested above for an UNDEF only concatenation.
25324 // fold (concat_vectors (BUILD_VECTOR A, B, ...), (BUILD_VECTOR C, D, ...))
25325 // -> (BUILD_VECTOR A, B, ..., C, D, ...)
25326 auto IsBuildVectorOrUndef = [](const SDValue &Op) {
25327 return Op.isUndef() || ISD::BUILD_VECTOR == Op.getOpcode();
25328 };
25329 if (llvm::all_of(N->ops(), IsBuildVectorOrUndef)) {
25330 SmallVector<SDValue, 8> Opnds;
25331 EVT SVT = VT.getScalarType();
25332
25333 EVT MinVT = SVT;
25334 if (!SVT.isFloatingPoint()) {
25335 // If BUILD_VECTOR are from built from integer, they may have different
25336 // operand types. Get the smallest type and truncate all operands to it.
25337 bool FoundMinVT = false;
25338 for (const SDValue &Op : N->ops())
25339 if (ISD::BUILD_VECTOR == Op.getOpcode()) {
25340 EVT OpSVT = Op.getOperand(0).getValueType();
25341 MinVT = (!FoundMinVT || OpSVT.bitsLE(MinVT)) ? OpSVT : MinVT;
25342 FoundMinVT = true;
25343 }
25344 assert(FoundMinVT && "Concat vector type mismatch");
25345 }
25346
25347 for (const SDValue &Op : N->ops()) {
25348 EVT OpVT = Op.getValueType();
25349 unsigned NumElts = OpVT.getVectorNumElements();
25350
25351 if (Op.isUndef())
25352 Opnds.append(NumElts, DAG.getUNDEF(MinVT));
25353
25354 if (ISD::BUILD_VECTOR == Op.getOpcode()) {
25355 if (SVT.isFloatingPoint()) {
25356 assert(SVT == OpVT.getScalarType() && "Concat vector type mismatch");
25357 Opnds.append(Op->op_begin(), Op->op_begin() + NumElts);
25358 } else {
25359 for (unsigned i = 0; i != NumElts; ++i)
25360 Opnds.push_back(
25361 DAG.getNode(ISD::TRUNCATE, SDLoc(N), MinVT, Op.getOperand(i)));
25362 }
25363 }
25364 }
25365
25366 assert(VT.getVectorNumElements() == Opnds.size() &&
25367 "Concat vector type mismatch");
25368 return DAG.getBuildVector(VT, SDLoc(N), Opnds);
25369 }
25370
25371 // Fold CONCAT_VECTORS of only bitcast scalars (or undef) to BUILD_VECTOR.
25372 // FIXME: Add support for concat_vectors(bitcast(vec0),bitcast(vec1),...).
25373 if (SDValue V = combineConcatVectorOfScalars(N, DAG))
25374 return V;
25375
25376 if (Level <= AfterLegalizeVectorOps && TLI.isTypeLegal(VT)) {
25377 // Fold CONCAT_VECTORS of CONCAT_VECTORS (or undef) to VECTOR_SHUFFLE.
25378 if (SDValue V = combineConcatVectorOfConcatVectors(N, DAG))
25379 return V;
25380
25381 // Fold CONCAT_VECTORS of EXTRACT_SUBVECTOR (or undef) to VECTOR_SHUFFLE.
25382 if (SDValue V = combineConcatVectorOfExtracts(N, DAG))
25383 return V;
25384 }
25385
25386 if (SDValue V = combineConcatVectorOfCasts(N, DAG))
25387 return V;
25388
25389 if (SDValue V = combineConcatVectorOfShuffleAndItsOperands(
25390 N, DAG, TLI, LegalTypes, LegalOperations))
25391 return V;
25392
25393 // Type legalization of vectors and DAG canonicalization of SHUFFLE_VECTOR
25394 // nodes often generate nop CONCAT_VECTOR nodes. Scan the CONCAT_VECTOR
25395 // operands and look for a CONCAT operations that place the incoming vectors
25396 // at the exact same location.
25397 //
25398 // For scalable vectors, EXTRACT_SUBVECTOR indexes are implicitly scaled.
25399 SDValue SingleSource = SDValue();
25400 unsigned PartNumElem =
25401 N->getOperand(0).getValueType().getVectorMinNumElements();
25402
25403 for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i) {
25404 SDValue Op = N->getOperand(i);
25405
25406 if (Op.isUndef())
25407 continue;
25408
25409 // Check if this is the identity extract:
25410 if (Op.getOpcode() != ISD::EXTRACT_SUBVECTOR)
25411 return SDValue();
25412
25413 // Find the single incoming vector for the extract_subvector.
25414 if (SingleSource.getNode()) {
25415 if (Op.getOperand(0) != SingleSource)
25416 return SDValue();
25417 } else {
25418 SingleSource = Op.getOperand(0);
25419
25420 // Check the source type is the same as the type of the result.
25421 // If not, this concat may extend the vector, so we can not
25422 // optimize it away.
25423 if (SingleSource.getValueType() != N->getValueType(0))
25424 return SDValue();
25425 }
25426
25427 // Check that we are reading from the identity index.
25428 unsigned IdentityIndex = i * PartNumElem;
25429 if (Op.getConstantOperandAPInt(1) != IdentityIndex)
25430 return SDValue();
25431 }
25432
25433 if (SingleSource.getNode())
25434 return SingleSource;
25435
25436 return SDValue();
25437 }
25438
25439 // Helper that peeks through INSERT_SUBVECTOR/CONCAT_VECTORS to find
25440 // if the subvector can be sourced for free.
getSubVectorSrc(SDValue V,unsigned Index,EVT SubVT)25441 static SDValue getSubVectorSrc(SDValue V, unsigned Index, EVT SubVT) {
25442 if (V.getOpcode() == ISD::INSERT_SUBVECTOR &&
25443 V.getOperand(1).getValueType() == SubVT &&
25444 V.getConstantOperandAPInt(2) == Index) {
25445 return V.getOperand(1);
25446 }
25447 if (V.getOpcode() == ISD::CONCAT_VECTORS &&
25448 V.getOperand(0).getValueType() == SubVT &&
25449 (Index % SubVT.getVectorMinNumElements()) == 0) {
25450 uint64_t SubIdx = Index / SubVT.getVectorMinNumElements();
25451 return V.getOperand(SubIdx);
25452 }
25453 return SDValue();
25454 }
25455
narrowInsertExtractVectorBinOp(EVT SubVT,SDValue BinOp,unsigned Index,const SDLoc & DL,SelectionDAG & DAG,bool LegalOperations)25456 static SDValue narrowInsertExtractVectorBinOp(EVT SubVT, SDValue BinOp,
25457 unsigned Index, const SDLoc &DL,
25458 SelectionDAG &DAG,
25459 bool LegalOperations) {
25460 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
25461 unsigned BinOpcode = BinOp.getOpcode();
25462 if (!TLI.isBinOp(BinOpcode) || BinOp->getNumValues() != 1)
25463 return SDValue();
25464
25465 EVT VecVT = BinOp.getValueType();
25466 SDValue Bop0 = BinOp.getOperand(0), Bop1 = BinOp.getOperand(1);
25467 if (VecVT != Bop0.getValueType() || VecVT != Bop1.getValueType())
25468 return SDValue();
25469 if (!TLI.isOperationLegalOrCustom(BinOpcode, SubVT, LegalOperations))
25470 return SDValue();
25471
25472 SDValue Sub0 = getSubVectorSrc(Bop0, Index, SubVT);
25473 SDValue Sub1 = getSubVectorSrc(Bop1, Index, SubVT);
25474
25475 // TODO: We could handle the case where only 1 operand is being inserted by
25476 // creating an extract of the other operand, but that requires checking
25477 // number of uses and/or costs.
25478 if (!Sub0 || !Sub1)
25479 return SDValue();
25480
25481 // We are inserting both operands of the wide binop only to extract back
25482 // to the narrow vector size. Eliminate all of the insert/extract:
25483 // ext (binop (ins ?, X, Index), (ins ?, Y, Index)), Index --> binop X, Y
25484 return DAG.getNode(BinOpcode, DL, SubVT, Sub0, Sub1, BinOp->getFlags());
25485 }
25486
25487 /// If we are extracting a subvector produced by a wide binary operator try
25488 /// to use a narrow binary operator and/or avoid concatenation and extraction.
narrowExtractedVectorBinOp(EVT VT,SDValue Src,unsigned Index,const SDLoc & DL,SelectionDAG & DAG,bool LegalOperations)25489 static SDValue narrowExtractedVectorBinOp(EVT VT, SDValue Src, unsigned Index,
25490 const SDLoc &DL, SelectionDAG &DAG,
25491 bool LegalOperations) {
25492 // TODO: Refactor with the caller (visitEXTRACT_SUBVECTOR), so we can share
25493 // some of these bailouts with other transforms.
25494
25495 if (SDValue V = narrowInsertExtractVectorBinOp(VT, Src, Index, DL, DAG,
25496 LegalOperations))
25497 return V;
25498
25499 // We are looking for an optionally bitcasted wide vector binary operator
25500 // feeding an extract subvector.
25501 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
25502 SDValue BinOp = peekThroughBitcasts(Src);
25503 unsigned BOpcode = BinOp.getOpcode();
25504 if (!TLI.isBinOp(BOpcode) || BinOp->getNumValues() != 1)
25505 return SDValue();
25506
25507 // Exclude the fake form of fneg (fsub -0.0, x) because that is likely to be
25508 // reduced to the unary fneg when it is visited, and we probably want to deal
25509 // with fneg in a target-specific way.
25510 if (BOpcode == ISD::FSUB) {
25511 auto *C = isConstOrConstSplatFP(BinOp.getOperand(0), /*AllowUndefs*/ true);
25512 if (C && C->getValueAPF().isNegZero())
25513 return SDValue();
25514 }
25515
25516 // The binop must be a vector type, so we can extract some fraction of it.
25517 EVT WideBVT = BinOp.getValueType();
25518 // The optimisations below currently assume we are dealing with fixed length
25519 // vectors. It is possible to add support for scalable vectors, but at the
25520 // moment we've done no analysis to prove whether they are profitable or not.
25521 if (!WideBVT.isFixedLengthVector())
25522 return SDValue();
25523
25524 assert((Index % VT.getVectorNumElements()) == 0 &&
25525 "Extract index is not a multiple of the vector length.");
25526
25527 // Bail out if this is not a proper multiple width extraction.
25528 unsigned WideWidth = WideBVT.getSizeInBits();
25529 unsigned NarrowWidth = VT.getSizeInBits();
25530 if (WideWidth % NarrowWidth != 0)
25531 return SDValue();
25532
25533 // Bail out if we are extracting a fraction of a single operation. This can
25534 // occur because we potentially looked through a bitcast of the binop.
25535 unsigned NarrowingRatio = WideWidth / NarrowWidth;
25536 unsigned WideNumElts = WideBVT.getVectorNumElements();
25537 if (WideNumElts % NarrowingRatio != 0)
25538 return SDValue();
25539
25540 // Bail out if the target does not support a narrower version of the binop.
25541 EVT NarrowBVT = EVT::getVectorVT(*DAG.getContext(), WideBVT.getScalarType(),
25542 WideNumElts / NarrowingRatio);
25543 if (!TLI.isOperationLegalOrCustomOrPromote(BOpcode, NarrowBVT,
25544 LegalOperations))
25545 return SDValue();
25546
25547 // If extraction is cheap, we don't need to look at the binop operands
25548 // for concat ops. The narrow binop alone makes this transform profitable.
25549 // We can't just reuse the original extract index operand because we may have
25550 // bitcasted.
25551 unsigned ConcatOpNum = Index / VT.getVectorNumElements();
25552 unsigned ExtBOIdx = ConcatOpNum * NarrowBVT.getVectorNumElements();
25553 if (TLI.isExtractSubvectorCheap(NarrowBVT, WideBVT, ExtBOIdx) &&
25554 BinOp.hasOneUse() && Src->hasOneUse()) {
25555 // extract (binop B0, B1), N --> binop (extract B0, N), (extract B1, N)
25556 SDValue NewExtIndex = DAG.getVectorIdxConstant(ExtBOIdx, DL);
25557 SDValue X = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
25558 BinOp.getOperand(0), NewExtIndex);
25559 SDValue Y = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
25560 BinOp.getOperand(1), NewExtIndex);
25561 SDValue NarrowBinOp =
25562 DAG.getNode(BOpcode, DL, NarrowBVT, X, Y, BinOp->getFlags());
25563 return DAG.getBitcast(VT, NarrowBinOp);
25564 }
25565
25566 // Only handle the case where we are doubling and then halving. A larger ratio
25567 // may require more than two narrow binops to replace the wide binop.
25568 if (NarrowingRatio != 2)
25569 return SDValue();
25570
25571 // TODO: The motivating case for this transform is an x86 AVX1 target. That
25572 // target has temptingly almost legal versions of bitwise logic ops in 256-bit
25573 // flavors, but no other 256-bit integer support. This could be extended to
25574 // handle any binop, but that may require fixing/adding other folds to avoid
25575 // codegen regressions.
25576 if (BOpcode != ISD::AND && BOpcode != ISD::OR && BOpcode != ISD::XOR)
25577 return SDValue();
25578
25579 // We need at least one concatenation operation of a binop operand to make
25580 // this transform worthwhile. The concat must double the input vector sizes.
25581 auto GetSubVector = [ConcatOpNum](SDValue V) -> SDValue {
25582 if (V.getOpcode() == ISD::CONCAT_VECTORS && V.getNumOperands() == 2)
25583 return V.getOperand(ConcatOpNum);
25584 return SDValue();
25585 };
25586 SDValue SubVecL = GetSubVector(peekThroughBitcasts(BinOp.getOperand(0)));
25587 SDValue SubVecR = GetSubVector(peekThroughBitcasts(BinOp.getOperand(1)));
25588
25589 if (SubVecL || SubVecR) {
25590 // If a binop operand was not the result of a concat, we must extract a
25591 // half-sized operand for our new narrow binop:
25592 // extract (binop (concat X1, X2), (concat Y1, Y2)), N --> binop XN, YN
25593 // extract (binop (concat X1, X2), Y), N --> binop XN, (extract Y, IndexC)
25594 // extract (binop X, (concat Y1, Y2)), N --> binop (extract X, IndexC), YN
25595 SDValue IndexC = DAG.getVectorIdxConstant(ExtBOIdx, DL);
25596 SDValue X = SubVecL ? DAG.getBitcast(NarrowBVT, SubVecL)
25597 : DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
25598 BinOp.getOperand(0), IndexC);
25599
25600 SDValue Y = SubVecR ? DAG.getBitcast(NarrowBVT, SubVecR)
25601 : DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
25602 BinOp.getOperand(1), IndexC);
25603
25604 SDValue NarrowBinOp = DAG.getNode(BOpcode, DL, NarrowBVT, X, Y);
25605 return DAG.getBitcast(VT, NarrowBinOp);
25606 }
25607
25608 return SDValue();
25609 }
25610
25611 /// If we are extracting a subvector from a wide vector load, convert to a
25612 /// narrow load to eliminate the extraction:
25613 /// (extract_subvector (load wide vector)) --> (load narrow vector)
narrowExtractedVectorLoad(EVT VT,SDValue Src,unsigned Index,const SDLoc & DL,SelectionDAG & DAG)25614 static SDValue narrowExtractedVectorLoad(EVT VT, SDValue Src, unsigned Index,
25615 const SDLoc &DL, SelectionDAG &DAG) {
25616 // TODO: Add support for big-endian. The offset calculation must be adjusted.
25617 if (DAG.getDataLayout().isBigEndian())
25618 return SDValue();
25619
25620 auto *Ld = dyn_cast<LoadSDNode>(Src);
25621 if (!Ld || !ISD::isNormalLoad(Ld) || !Ld->isSimple())
25622 return SDValue();
25623
25624 // We can only create byte sized loads.
25625 if (!VT.isByteSized())
25626 return SDValue();
25627
25628 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
25629 if (!TLI.isOperationLegalOrCustomOrPromote(ISD::LOAD, VT))
25630 return SDValue();
25631
25632 unsigned NumElts = VT.getVectorMinNumElements();
25633 // A fixed length vector being extracted from a scalable vector
25634 // may not be any *smaller* than the scalable one.
25635 if (Index == 0 && NumElts >= Ld->getValueType(0).getVectorMinNumElements())
25636 return SDValue();
25637
25638 // The definition of EXTRACT_SUBVECTOR states that the index must be a
25639 // multiple of the minimum number of elements in the result type.
25640 assert(Index % NumElts == 0 && "The extract subvector index is not a "
25641 "multiple of the result's element count");
25642
25643 // It's fine to use TypeSize here as we know the offset will not be negative.
25644 TypeSize Offset = VT.getStoreSize() * (Index / NumElts);
25645 std::optional<unsigned> ByteOffset;
25646 if (Offset.isFixed())
25647 ByteOffset = Offset.getFixedValue();
25648
25649 if (!TLI.shouldReduceLoadWidth(Ld, Ld->getExtensionType(), VT, ByteOffset))
25650 return SDValue();
25651
25652 // The narrow load will be offset from the base address of the old load if
25653 // we are extracting from something besides index 0 (little-endian).
25654 // TODO: Use "BaseIndexOffset" to make this more effective.
25655 SDValue NewAddr = DAG.getMemBasePlusOffset(Ld->getBasePtr(), Offset, DL);
25656
25657 MachineFunction &MF = DAG.getMachineFunction();
25658 MachineMemOperand *MMO;
25659 if (Offset.isScalable()) {
25660 MachinePointerInfo MPI =
25661 MachinePointerInfo(Ld->getPointerInfo().getAddrSpace());
25662 MMO = MF.getMachineMemOperand(Ld->getMemOperand(), MPI, VT.getStoreSize());
25663 } else
25664 MMO = MF.getMachineMemOperand(Ld->getMemOperand(), Offset.getFixedValue(),
25665 VT.getStoreSize());
25666
25667 SDValue NewLd = DAG.getLoad(VT, DL, Ld->getChain(), NewAddr, MMO);
25668 DAG.makeEquivalentMemoryOrdering(Ld, NewLd);
25669 return NewLd;
25670 }
25671
25672 /// Given EXTRACT_SUBVECTOR(VECTOR_SHUFFLE(Op0, Op1, Mask)),
25673 /// try to produce VECTOR_SHUFFLE(EXTRACT_SUBVECTOR(Op?, ?),
25674 /// EXTRACT_SUBVECTOR(Op?, ?),
25675 /// Mask'))
25676 /// iff it is legal and profitable to do so. Notably, the trimmed mask
25677 /// (containing only the elements that are extracted)
25678 /// must reference at most two subvectors.
foldExtractSubvectorFromShuffleVector(EVT NarrowVT,SDValue Src,unsigned Index,const SDLoc & DL,SelectionDAG & DAG,bool LegalOperations)25679 static SDValue foldExtractSubvectorFromShuffleVector(EVT NarrowVT, SDValue Src,
25680 unsigned Index,
25681 const SDLoc &DL,
25682 SelectionDAG &DAG,
25683 bool LegalOperations) {
25684 // Only deal with non-scalable vectors.
25685 EVT WideVT = Src.getValueType();
25686 if (!NarrowVT.isFixedLengthVector() || !WideVT.isFixedLengthVector())
25687 return SDValue();
25688
25689 // The operand must be a shufflevector.
25690 auto *WideShuffleVector = dyn_cast<ShuffleVectorSDNode>(Src);
25691 if (!WideShuffleVector)
25692 return SDValue();
25693
25694 // The old shuffleneeds to go away.
25695 if (!WideShuffleVector->hasOneUse())
25696 return SDValue();
25697
25698 // And the narrow shufflevector that we'll form must be legal.
25699 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
25700 if (LegalOperations &&
25701 !TLI.isOperationLegalOrCustom(ISD::VECTOR_SHUFFLE, NarrowVT))
25702 return SDValue();
25703
25704 int NumEltsExtracted = NarrowVT.getVectorNumElements();
25705 assert((Index % NumEltsExtracted) == 0 &&
25706 "Extract index is not a multiple of the output vector length.");
25707
25708 int WideNumElts = WideVT.getVectorNumElements();
25709
25710 SmallVector<int, 16> NewMask;
25711 NewMask.reserve(NumEltsExtracted);
25712 SmallSetVector<std::pair<SDValue /*Op*/, int /*SubvectorIndex*/>, 2>
25713 DemandedSubvectors;
25714
25715 // Try to decode the wide mask into narrow mask from at most two subvectors.
25716 for (int M : WideShuffleVector->getMask().slice(Index, NumEltsExtracted)) {
25717 assert((M >= -1) && (M < (2 * WideNumElts)) &&
25718 "Out-of-bounds shuffle mask?");
25719
25720 if (M < 0) {
25721 // Does not depend on operands, does not require adjustment.
25722 NewMask.emplace_back(M);
25723 continue;
25724 }
25725
25726 // From which operand of the shuffle does this shuffle mask element pick?
25727 int WideShufOpIdx = M / WideNumElts;
25728 // Which element of that operand is picked?
25729 int OpEltIdx = M % WideNumElts;
25730
25731 assert((OpEltIdx + WideShufOpIdx * WideNumElts) == M &&
25732 "Shuffle mask vector decomposition failure.");
25733
25734 // And which NumEltsExtracted-sized subvector of that operand is that?
25735 int OpSubvecIdx = OpEltIdx / NumEltsExtracted;
25736 // And which element within that subvector of that operand is that?
25737 int OpEltIdxInSubvec = OpEltIdx % NumEltsExtracted;
25738
25739 assert((OpEltIdxInSubvec + OpSubvecIdx * NumEltsExtracted) == OpEltIdx &&
25740 "Shuffle mask subvector decomposition failure.");
25741
25742 assert((OpEltIdxInSubvec + OpSubvecIdx * NumEltsExtracted +
25743 WideShufOpIdx * WideNumElts) == M &&
25744 "Shuffle mask full decomposition failure.");
25745
25746 SDValue Op = WideShuffleVector->getOperand(WideShufOpIdx);
25747
25748 if (Op.isUndef()) {
25749 // Picking from an undef operand. Let's adjust mask instead.
25750 NewMask.emplace_back(-1);
25751 continue;
25752 }
25753
25754 const std::pair<SDValue, int> DemandedSubvector =
25755 std::make_pair(Op, OpSubvecIdx);
25756
25757 if (DemandedSubvectors.insert(DemandedSubvector)) {
25758 if (DemandedSubvectors.size() > 2)
25759 return SDValue(); // We can't handle more than two subvectors.
25760 // How many elements into the WideVT does this subvector start?
25761 int Index = NumEltsExtracted * OpSubvecIdx;
25762 // Bail out if the extraction isn't going to be cheap.
25763 if (!TLI.isExtractSubvectorCheap(NarrowVT, WideVT, Index))
25764 return SDValue();
25765 }
25766
25767 // Ok, but from which operand of the new shuffle will this element pick?
25768 int NewOpIdx =
25769 getFirstIndexOf(DemandedSubvectors.getArrayRef(), DemandedSubvector);
25770 assert((NewOpIdx == 0 || NewOpIdx == 1) && "Unexpected operand index.");
25771
25772 int AdjM = OpEltIdxInSubvec + NewOpIdx * NumEltsExtracted;
25773 NewMask.emplace_back(AdjM);
25774 }
25775 assert(NewMask.size() == (unsigned)NumEltsExtracted && "Produced bad mask.");
25776 assert(DemandedSubvectors.size() <= 2 &&
25777 "Should have ended up demanding at most two subvectors.");
25778
25779 // Did we discover that the shuffle does not actually depend on operands?
25780 if (DemandedSubvectors.empty())
25781 return DAG.getUNDEF(NarrowVT);
25782
25783 // Profitability check: only deal with extractions from the first subvector
25784 // unless the mask becomes an identity mask.
25785 if (!ShuffleVectorInst::isIdentityMask(NewMask, NewMask.size()) ||
25786 any_of(NewMask, [](int M) { return M < 0; }))
25787 for (auto &DemandedSubvector : DemandedSubvectors)
25788 if (DemandedSubvector.second != 0)
25789 return SDValue();
25790
25791 // We still perform the exact same EXTRACT_SUBVECTOR, just on different
25792 // operand[s]/index[es], so there is no point in checking for it's legality.
25793
25794 // Do not turn a legal shuffle into an illegal one.
25795 if (TLI.isShuffleMaskLegal(WideShuffleVector->getMask(), WideVT) &&
25796 !TLI.isShuffleMaskLegal(NewMask, NarrowVT))
25797 return SDValue();
25798
25799 SmallVector<SDValue, 2> NewOps;
25800 for (const std::pair<SDValue /*Op*/, int /*SubvectorIndex*/>
25801 &DemandedSubvector : DemandedSubvectors) {
25802 // How many elements into the WideVT does this subvector start?
25803 int Index = NumEltsExtracted * DemandedSubvector.second;
25804 SDValue IndexC = DAG.getVectorIdxConstant(Index, DL);
25805 NewOps.emplace_back(DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowVT,
25806 DemandedSubvector.first, IndexC));
25807 }
25808 assert((NewOps.size() == 1 || NewOps.size() == 2) &&
25809 "Should end up with either one or two ops");
25810
25811 // If we ended up with only one operand, pad with an undef.
25812 if (NewOps.size() == 1)
25813 NewOps.emplace_back(DAG.getUNDEF(NarrowVT));
25814
25815 return DAG.getVectorShuffle(NarrowVT, DL, NewOps[0], NewOps[1], NewMask);
25816 }
25817
visitEXTRACT_SUBVECTOR(SDNode * N)25818 SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode *N) {
25819 EVT NVT = N->getValueType(0);
25820 SDValue V = N->getOperand(0);
25821 uint64_t ExtIdx = N->getConstantOperandVal(1);
25822 SDLoc DL(N);
25823
25824 // Extract from UNDEF is UNDEF.
25825 if (V.isUndef())
25826 return DAG.getUNDEF(NVT);
25827
25828 if (SDValue NarrowLoad = narrowExtractedVectorLoad(NVT, V, ExtIdx, DL, DAG))
25829 return NarrowLoad;
25830
25831 // Combine an extract of an extract into a single extract_subvector.
25832 // ext (ext X, C), 0 --> ext X, C
25833 if (ExtIdx == 0 && V.getOpcode() == ISD::EXTRACT_SUBVECTOR && V.hasOneUse()) {
25834 if (TLI.isExtractSubvectorCheap(NVT, V.getOperand(0).getValueType(),
25835 V.getConstantOperandVal(1)) &&
25836 TLI.isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, NVT)) {
25837 return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NVT, V.getOperand(0),
25838 V.getOperand(1));
25839 }
25840 }
25841
25842 // ty1 extract_vector(ty2 splat(V))) -> ty1 splat(V)
25843 if (V.getOpcode() == ISD::SPLAT_VECTOR)
25844 if (DAG.isConstantValueOfAnyType(V.getOperand(0)) || V.hasOneUse())
25845 if (!LegalOperations || TLI.isOperationLegal(ISD::SPLAT_VECTOR, NVT))
25846 return DAG.getSplatVector(NVT, DL, V.getOperand(0));
25847
25848 // extract_subvector(insert_subvector(x,y,c1),c2)
25849 // --> extract_subvector(y,c2-c1)
25850 // iff we're just extracting from the inserted subvector.
25851 if (V.getOpcode() == ISD::INSERT_SUBVECTOR) {
25852 SDValue InsSub = V.getOperand(1);
25853 EVT InsSubVT = InsSub.getValueType();
25854 unsigned NumInsElts = InsSubVT.getVectorMinNumElements();
25855 unsigned InsIdx = V.getConstantOperandVal(2);
25856 unsigned NumSubElts = NVT.getVectorMinNumElements();
25857 if (InsIdx <= ExtIdx && (ExtIdx + NumSubElts) <= (InsIdx + NumInsElts) &&
25858 TLI.isExtractSubvectorCheap(NVT, InsSubVT, ExtIdx - InsIdx) &&
25859 InsSubVT.isFixedLengthVector() && NVT.isFixedLengthVector() &&
25860 V.getValueType().isFixedLengthVector())
25861 return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NVT, InsSub,
25862 DAG.getVectorIdxConstant(ExtIdx - InsIdx, DL));
25863 }
25864
25865 // Try to move vector bitcast after extract_subv by scaling extraction index:
25866 // extract_subv (bitcast X), Index --> bitcast (extract_subv X, Index')
25867 if (V.getOpcode() == ISD::BITCAST &&
25868 V.getOperand(0).getValueType().isVector() &&
25869 (!LegalOperations || TLI.isOperationLegal(ISD::BITCAST, NVT))) {
25870 SDValue SrcOp = V.getOperand(0);
25871 EVT SrcVT = SrcOp.getValueType();
25872 unsigned SrcNumElts = SrcVT.getVectorMinNumElements();
25873 unsigned DestNumElts = V.getValueType().getVectorMinNumElements();
25874 if ((SrcNumElts % DestNumElts) == 0) {
25875 unsigned SrcDestRatio = SrcNumElts / DestNumElts;
25876 ElementCount NewExtEC = NVT.getVectorElementCount() * SrcDestRatio;
25877 EVT NewExtVT =
25878 EVT::getVectorVT(*DAG.getContext(), SrcVT.getScalarType(), NewExtEC);
25879 if (TLI.isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, NewExtVT)) {
25880 SDValue NewIndex = DAG.getVectorIdxConstant(ExtIdx * SrcDestRatio, DL);
25881 SDValue NewExtract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NewExtVT,
25882 V.getOperand(0), NewIndex);
25883 return DAG.getBitcast(NVT, NewExtract);
25884 }
25885 }
25886 if ((DestNumElts % SrcNumElts) == 0) {
25887 unsigned DestSrcRatio = DestNumElts / SrcNumElts;
25888 if (NVT.getVectorElementCount().isKnownMultipleOf(DestSrcRatio)) {
25889 ElementCount NewExtEC =
25890 NVT.getVectorElementCount().divideCoefficientBy(DestSrcRatio);
25891 EVT ScalarVT = SrcVT.getScalarType();
25892 if ((ExtIdx % DestSrcRatio) == 0) {
25893 unsigned IndexValScaled = ExtIdx / DestSrcRatio;
25894 EVT NewExtVT =
25895 EVT::getVectorVT(*DAG.getContext(), ScalarVT, NewExtEC);
25896 if (TLI.isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, NewExtVT)) {
25897 SDValue NewIndex = DAG.getVectorIdxConstant(IndexValScaled, DL);
25898 SDValue NewExtract =
25899 DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NewExtVT,
25900 V.getOperand(0), NewIndex);
25901 return DAG.getBitcast(NVT, NewExtract);
25902 }
25903 if (NewExtEC.isScalar() &&
25904 TLI.isOperationLegalOrCustom(ISD::EXTRACT_VECTOR_ELT, ScalarVT)) {
25905 SDValue NewIndex = DAG.getVectorIdxConstant(IndexValScaled, DL);
25906 SDValue NewExtract =
25907 DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT,
25908 V.getOperand(0), NewIndex);
25909 return DAG.getBitcast(NVT, NewExtract);
25910 }
25911 }
25912 }
25913 }
25914 }
25915
25916 if (V.getOpcode() == ISD::CONCAT_VECTORS) {
25917 unsigned ExtNumElts = NVT.getVectorMinNumElements();
25918 EVT ConcatSrcVT = V.getOperand(0).getValueType();
25919 assert(ConcatSrcVT.getVectorElementType() == NVT.getVectorElementType() &&
25920 "Concat and extract subvector do not change element type");
25921 assert((ExtIdx % ExtNumElts) == 0 &&
25922 "Extract index is not a multiple of the input vector length.");
25923
25924 unsigned ConcatSrcNumElts = ConcatSrcVT.getVectorMinNumElements();
25925 unsigned ConcatOpIdx = ExtIdx / ConcatSrcNumElts;
25926
25927 // If the concatenated source types match this extract, it's a direct
25928 // simplification:
25929 // extract_subvec (concat V1, V2, ...), i --> Vi
25930 if (NVT.getVectorElementCount() == ConcatSrcVT.getVectorElementCount())
25931 return V.getOperand(ConcatOpIdx);
25932
25933 // If the concatenated source vectors are a multiple length of this extract,
25934 // then extract a fraction of one of those source vectors directly from a
25935 // concat operand. Example:
25936 // v2i8 extract_subvec (v16i8 concat (v8i8 X), (v8i8 Y), 14 -->
25937 // v2i8 extract_subvec v8i8 Y, 6
25938 if (NVT.isFixedLengthVector() && ConcatSrcVT.isFixedLengthVector() &&
25939 ConcatSrcNumElts % ExtNumElts == 0) {
25940 unsigned NewExtIdx = ExtIdx - ConcatOpIdx * ConcatSrcNumElts;
25941 assert(NewExtIdx + ExtNumElts <= ConcatSrcNumElts &&
25942 "Trying to extract from >1 concat operand?");
25943 assert(NewExtIdx % ExtNumElts == 0 &&
25944 "Extract index is not a multiple of the input vector length.");
25945 SDValue NewIndexC = DAG.getVectorIdxConstant(NewExtIdx, DL);
25946 return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NVT,
25947 V.getOperand(ConcatOpIdx), NewIndexC);
25948 }
25949 }
25950
25951 if (SDValue Shuffle = foldExtractSubvectorFromShuffleVector(
25952 NVT, V, ExtIdx, DL, DAG, LegalOperations))
25953 return Shuffle;
25954
25955 if (SDValue NarrowBOp =
25956 narrowExtractedVectorBinOp(NVT, V, ExtIdx, DL, DAG, LegalOperations))
25957 return NarrowBOp;
25958
25959 V = peekThroughBitcasts(V);
25960
25961 // If the input is a build vector. Try to make a smaller build vector.
25962 if (V.getOpcode() == ISD::BUILD_VECTOR) {
25963 EVT InVT = V.getValueType();
25964 unsigned ExtractSize = NVT.getSizeInBits();
25965 unsigned EltSize = InVT.getScalarSizeInBits();
25966 // Only do this if we won't split any elements.
25967 if (ExtractSize % EltSize == 0) {
25968 unsigned NumElems = ExtractSize / EltSize;
25969 EVT EltVT = InVT.getVectorElementType();
25970 EVT ExtractVT =
25971 NumElems == 1 ? EltVT
25972 : EVT::getVectorVT(*DAG.getContext(), EltVT, NumElems);
25973 if ((Level < AfterLegalizeDAG ||
25974 (NumElems == 1 ||
25975 TLI.isOperationLegal(ISD::BUILD_VECTOR, ExtractVT))) &&
25976 (!LegalTypes || TLI.isTypeLegal(ExtractVT))) {
25977 unsigned IdxVal = (ExtIdx * NVT.getScalarSizeInBits()) / EltSize;
25978
25979 if (NumElems == 1) {
25980 SDValue Src = V->getOperand(IdxVal);
25981 if (EltVT != Src.getValueType())
25982 Src = DAG.getNode(ISD::TRUNCATE, DL, EltVT, Src);
25983 return DAG.getBitcast(NVT, Src);
25984 }
25985
25986 // Extract the pieces from the original build_vector.
25987 SDValue BuildVec =
25988 DAG.getBuildVector(ExtractVT, DL, V->ops().slice(IdxVal, NumElems));
25989 return DAG.getBitcast(NVT, BuildVec);
25990 }
25991 }
25992 }
25993
25994 if (V.getOpcode() == ISD::INSERT_SUBVECTOR) {
25995 // Handle only simple case where vector being inserted and vector
25996 // being extracted are of same size.
25997 EVT SmallVT = V.getOperand(1).getValueType();
25998 if (NVT.bitsEq(SmallVT)) {
25999 // Combine:
26000 // (extract_subvec (insert_subvec V1, V2, InsIdx), ExtIdx)
26001 // Into:
26002 // indices are equal or bit offsets are equal => V1
26003 // otherwise => (extract_subvec V1, ExtIdx)
26004 uint64_t InsIdx = V.getConstantOperandVal(2);
26005 if (InsIdx * SmallVT.getScalarSizeInBits() ==
26006 ExtIdx * NVT.getScalarSizeInBits()) {
26007 if (!LegalOperations || TLI.isOperationLegal(ISD::BITCAST, NVT))
26008 return DAG.getBitcast(NVT, V.getOperand(1));
26009 } else {
26010 return DAG.getNode(
26011 ISD::EXTRACT_SUBVECTOR, DL, NVT,
26012 DAG.getBitcast(N->getOperand(0).getValueType(), V.getOperand(0)),
26013 N->getOperand(1));
26014 }
26015 }
26016 }
26017
26018 // If only EXTRACT_SUBVECTOR nodes use the source vector we can
26019 // simplify it based on the (valid) extractions.
26020 if (!V.getValueType().isScalableVector() &&
26021 llvm::all_of(V->users(), [&](SDNode *Use) {
26022 return Use->getOpcode() == ISD::EXTRACT_SUBVECTOR &&
26023 Use->getOperand(0) == V;
26024 })) {
26025 unsigned NumElts = V.getValueType().getVectorNumElements();
26026 APInt DemandedElts = APInt::getZero(NumElts);
26027 for (SDNode *User : V->users()) {
26028 unsigned ExtIdx = User->getConstantOperandVal(1);
26029 unsigned NumSubElts = User->getValueType(0).getVectorNumElements();
26030 DemandedElts.setBits(ExtIdx, ExtIdx + NumSubElts);
26031 }
26032 if (SimplifyDemandedVectorElts(V, DemandedElts, /*AssumeSingleUse=*/true)) {
26033 // We simplified the vector operand of this extract subvector. If this
26034 // extract is not dead, visit it again so it is folded properly.
26035 if (N->getOpcode() != ISD::DELETED_NODE)
26036 AddToWorklist(N);
26037 return SDValue(N, 0);
26038 }
26039 } else {
26040 if (SimplifyDemandedVectorElts(SDValue(N, 0)))
26041 return SDValue(N, 0);
26042 }
26043
26044 return SDValue();
26045 }
26046
26047 /// Try to convert a wide shuffle of concatenated vectors into 2 narrow shuffles
26048 /// followed by concatenation. Narrow vector ops may have better performance
26049 /// than wide ops, and this can unlock further narrowing of other vector ops.
26050 /// Targets can invert this transform later if it is not profitable.
foldShuffleOfConcatUndefs(ShuffleVectorSDNode * Shuf,SelectionDAG & DAG)26051 static SDValue foldShuffleOfConcatUndefs(ShuffleVectorSDNode *Shuf,
26052 SelectionDAG &DAG) {
26053 SDValue N0 = Shuf->getOperand(0), N1 = Shuf->getOperand(1);
26054 if (N0.getOpcode() != ISD::CONCAT_VECTORS || N0.getNumOperands() != 2 ||
26055 N1.getOpcode() != ISD::CONCAT_VECTORS || N1.getNumOperands() != 2 ||
26056 !N0.getOperand(1).isUndef() || !N1.getOperand(1).isUndef())
26057 return SDValue();
26058
26059 // Split the wide shuffle mask into halves. Any mask element that is accessing
26060 // operand 1 is offset down to account for narrowing of the vectors.
26061 ArrayRef<int> Mask = Shuf->getMask();
26062 EVT VT = Shuf->getValueType(0);
26063 unsigned NumElts = VT.getVectorNumElements();
26064 unsigned HalfNumElts = NumElts / 2;
26065 SmallVector<int, 16> Mask0(HalfNumElts, -1);
26066 SmallVector<int, 16> Mask1(HalfNumElts, -1);
26067 for (unsigned i = 0; i != NumElts; ++i) {
26068 if (Mask[i] == -1)
26069 continue;
26070 // If we reference the upper (undef) subvector then the element is undef.
26071 if ((Mask[i] % NumElts) >= HalfNumElts)
26072 continue;
26073 int M = Mask[i] < (int)NumElts ? Mask[i] : Mask[i] - (int)HalfNumElts;
26074 if (i < HalfNumElts)
26075 Mask0[i] = M;
26076 else
26077 Mask1[i - HalfNumElts] = M;
26078 }
26079
26080 // Ask the target if this is a valid transform.
26081 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
26082 EVT HalfVT = EVT::getVectorVT(*DAG.getContext(), VT.getScalarType(),
26083 HalfNumElts);
26084 if (!TLI.isShuffleMaskLegal(Mask0, HalfVT) ||
26085 !TLI.isShuffleMaskLegal(Mask1, HalfVT))
26086 return SDValue();
26087
26088 // shuffle (concat X, undef), (concat Y, undef), Mask -->
26089 // concat (shuffle X, Y, Mask0), (shuffle X, Y, Mask1)
26090 SDValue X = N0.getOperand(0), Y = N1.getOperand(0);
26091 SDLoc DL(Shuf);
26092 SDValue Shuf0 = DAG.getVectorShuffle(HalfVT, DL, X, Y, Mask0);
26093 SDValue Shuf1 = DAG.getVectorShuffle(HalfVT, DL, X, Y, Mask1);
26094 return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Shuf0, Shuf1);
26095 }
26096
26097 // Tries to turn a shuffle of two CONCAT_VECTORS into a single concat,
26098 // or turn a shuffle of a single concat into simpler shuffle then concat.
partitionShuffleOfConcats(SDNode * N,SelectionDAG & DAG)26099 static SDValue partitionShuffleOfConcats(SDNode *N, SelectionDAG &DAG) {
26100 EVT VT = N->getValueType(0);
26101 unsigned NumElts = VT.getVectorNumElements();
26102
26103 SDValue N0 = N->getOperand(0);
26104 SDValue N1 = N->getOperand(1);
26105 ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N);
26106 ArrayRef<int> Mask = SVN->getMask();
26107
26108 SmallVector<SDValue, 4> Ops;
26109 EVT ConcatVT = N0.getOperand(0).getValueType();
26110 unsigned NumElemsPerConcat = ConcatVT.getVectorNumElements();
26111 unsigned NumConcats = NumElts / NumElemsPerConcat;
26112
26113 auto IsUndefMaskElt = [](int i) { return i == -1; };
26114
26115 // Special case: shuffle(concat(A,B)) can be more efficiently represented
26116 // as concat(shuffle(A,B),UNDEF) if the shuffle doesn't set any of the high
26117 // half vector elements.
26118 if (NumElemsPerConcat * 2 == NumElts && N1.isUndef() &&
26119 llvm::all_of(Mask.slice(NumElemsPerConcat, NumElemsPerConcat),
26120 IsUndefMaskElt)) {
26121 N0 = DAG.getVectorShuffle(ConcatVT, SDLoc(N), N0.getOperand(0),
26122 N0.getOperand(1),
26123 Mask.slice(0, NumElemsPerConcat));
26124 N1 = DAG.getUNDEF(ConcatVT);
26125 return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, N0, N1);
26126 }
26127
26128 // Look at every vector that's inserted. We're looking for exact
26129 // subvector-sized copies from a concatenated vector
26130 for (unsigned I = 0; I != NumConcats; ++I) {
26131 unsigned Begin = I * NumElemsPerConcat;
26132 ArrayRef<int> SubMask = Mask.slice(Begin, NumElemsPerConcat);
26133
26134 // Make sure we're dealing with a copy.
26135 if (llvm::all_of(SubMask, IsUndefMaskElt)) {
26136 Ops.push_back(DAG.getUNDEF(ConcatVT));
26137 continue;
26138 }
26139
26140 int OpIdx = -1;
26141 for (int i = 0; i != (int)NumElemsPerConcat; ++i) {
26142 if (IsUndefMaskElt(SubMask[i]))
26143 continue;
26144 if ((SubMask[i] % (int)NumElemsPerConcat) != i)
26145 return SDValue();
26146 int EltOpIdx = SubMask[i] / NumElemsPerConcat;
26147 if (0 <= OpIdx && EltOpIdx != OpIdx)
26148 return SDValue();
26149 OpIdx = EltOpIdx;
26150 }
26151 assert(0 <= OpIdx && "Unknown concat_vectors op");
26152
26153 if (OpIdx < (int)N0.getNumOperands())
26154 Ops.push_back(N0.getOperand(OpIdx));
26155 else
26156 Ops.push_back(N1.getOperand(OpIdx - N0.getNumOperands()));
26157 }
26158
26159 return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops);
26160 }
26161
26162 // Attempt to combine a shuffle of 2 inputs of 'scalar sources' -
26163 // BUILD_VECTOR or SCALAR_TO_VECTOR into a single BUILD_VECTOR.
26164 //
26165 // SHUFFLE(BUILD_VECTOR(), BUILD_VECTOR()) -> BUILD_VECTOR() is always
26166 // a simplification in some sense, but it isn't appropriate in general: some
26167 // BUILD_VECTORs are substantially cheaper than others. The general case
26168 // of a BUILD_VECTOR requires inserting each element individually (or
26169 // performing the equivalent in a temporary stack variable). A BUILD_VECTOR of
26170 // all constants is a single constant pool load. A BUILD_VECTOR where each
26171 // element is identical is a splat. A BUILD_VECTOR where most of the operands
26172 // are undef lowers to a small number of element insertions.
26173 //
26174 // To deal with this, we currently use a bunch of mostly arbitrary heuristics.
26175 // We don't fold shuffles where one side is a non-zero constant, and we don't
26176 // fold shuffles if the resulting (non-splat) BUILD_VECTOR would have duplicate
26177 // non-constant operands. This seems to work out reasonably well in practice.
combineShuffleOfScalars(ShuffleVectorSDNode * SVN,SelectionDAG & DAG,const TargetLowering & TLI)26178 static SDValue combineShuffleOfScalars(ShuffleVectorSDNode *SVN,
26179 SelectionDAG &DAG,
26180 const TargetLowering &TLI) {
26181 EVT VT = SVN->getValueType(0);
26182 unsigned NumElts = VT.getVectorNumElements();
26183 SDValue N0 = SVN->getOperand(0);
26184 SDValue N1 = SVN->getOperand(1);
26185
26186 if (!N0->hasOneUse())
26187 return SDValue();
26188
26189 // If only one of N1,N2 is constant, bail out if it is not ALL_ZEROS as
26190 // discussed above.
26191 if (!N1.isUndef()) {
26192 if (!N1->hasOneUse())
26193 return SDValue();
26194
26195 bool N0AnyConst = isAnyConstantBuildVector(N0);
26196 bool N1AnyConst = isAnyConstantBuildVector(N1);
26197 if (N0AnyConst && !N1AnyConst && !ISD::isBuildVectorAllZeros(N0.getNode()))
26198 return SDValue();
26199 if (!N0AnyConst && N1AnyConst && !ISD::isBuildVectorAllZeros(N1.getNode()))
26200 return SDValue();
26201 }
26202
26203 // If both inputs are splats of the same value then we can safely merge this
26204 // to a single BUILD_VECTOR with undef elements based on the shuffle mask.
26205 bool IsSplat = false;
26206 auto *BV0 = dyn_cast<BuildVectorSDNode>(N0);
26207 auto *BV1 = dyn_cast<BuildVectorSDNode>(N1);
26208 if (BV0 && BV1)
26209 if (SDValue Splat0 = BV0->getSplatValue())
26210 IsSplat = (Splat0 == BV1->getSplatValue());
26211
26212 SmallVector<SDValue, 8> Ops;
26213 SmallSet<SDValue, 16> DuplicateOps;
26214 for (int M : SVN->getMask()) {
26215 SDValue Op = DAG.getUNDEF(VT.getScalarType());
26216 if (M >= 0) {
26217 int Idx = M < (int)NumElts ? M : M - NumElts;
26218 SDValue &S = (M < (int)NumElts ? N0 : N1);
26219 if (S.getOpcode() == ISD::BUILD_VECTOR) {
26220 Op = S.getOperand(Idx);
26221 } else if (S.getOpcode() == ISD::SCALAR_TO_VECTOR) {
26222 SDValue Op0 = S.getOperand(0);
26223 Op = Idx == 0 ? Op0 : DAG.getUNDEF(Op0.getValueType());
26224 } else {
26225 // Operand can't be combined - bail out.
26226 return SDValue();
26227 }
26228 }
26229
26230 // Don't duplicate a non-constant BUILD_VECTOR operand unless we're
26231 // generating a splat; semantically, this is fine, but it's likely to
26232 // generate low-quality code if the target can't reconstruct an appropriate
26233 // shuffle.
26234 if (!Op.isUndef() && !isIntOrFPConstant(Op))
26235 if (!IsSplat && !DuplicateOps.insert(Op).second)
26236 return SDValue();
26237
26238 Ops.push_back(Op);
26239 }
26240
26241 // BUILD_VECTOR requires all inputs to be of the same type, find the
26242 // maximum type and extend them all.
26243 EVT SVT = VT.getScalarType();
26244 if (SVT.isInteger())
26245 for (SDValue &Op : Ops)
26246 SVT = (SVT.bitsLT(Op.getValueType()) ? Op.getValueType() : SVT);
26247 if (SVT != VT.getScalarType())
26248 for (SDValue &Op : Ops)
26249 Op = Op.isUndef() ? DAG.getUNDEF(SVT)
26250 : (TLI.isZExtFree(Op.getValueType(), SVT)
26251 ? DAG.getZExtOrTrunc(Op, SDLoc(SVN), SVT)
26252 : DAG.getSExtOrTrunc(Op, SDLoc(SVN), SVT));
26253 return DAG.getBuildVector(VT, SDLoc(SVN), Ops);
26254 }
26255
26256 // Match shuffles that can be converted to *_vector_extend_in_reg.
26257 // This is often generated during legalization.
26258 // e.g. v4i32 <0,u,1,u> -> (v2i64 any_vector_extend_in_reg(v4i32 src)),
26259 // and returns the EVT to which the extension should be performed.
26260 // NOTE: this assumes that the src is the first operand of the shuffle.
canCombineShuffleToExtendVectorInreg(unsigned Opcode,EVT VT,std::function<bool (unsigned)> Match,SelectionDAG & DAG,const TargetLowering & TLI,bool LegalTypes,bool LegalOperations)26261 static std::optional<EVT> canCombineShuffleToExtendVectorInreg(
26262 unsigned Opcode, EVT VT, std::function<bool(unsigned)> Match,
26263 SelectionDAG &DAG, const TargetLowering &TLI, bool LegalTypes,
26264 bool LegalOperations) {
26265 bool IsBigEndian = DAG.getDataLayout().isBigEndian();
26266
26267 // TODO Add support for big-endian when we have a test case.
26268 if (!VT.isInteger() || IsBigEndian)
26269 return std::nullopt;
26270
26271 unsigned NumElts = VT.getVectorNumElements();
26272 unsigned EltSizeInBits = VT.getScalarSizeInBits();
26273
26274 // Attempt to match a '*_extend_vector_inreg' shuffle, we just search for
26275 // power-of-2 extensions as they are the most likely.
26276 // FIXME: should try Scale == NumElts case too,
26277 for (unsigned Scale = 2; Scale < NumElts; Scale *= 2) {
26278 // The vector width must be a multiple of Scale.
26279 if (NumElts % Scale != 0)
26280 continue;
26281
26282 EVT OutSVT = EVT::getIntegerVT(*DAG.getContext(), EltSizeInBits * Scale);
26283 EVT OutVT = EVT::getVectorVT(*DAG.getContext(), OutSVT, NumElts / Scale);
26284
26285 if ((LegalTypes && !TLI.isTypeLegal(OutVT)) ||
26286 (LegalOperations && !TLI.isOperationLegalOrCustom(Opcode, OutVT)))
26287 continue;
26288
26289 if (Match(Scale))
26290 return OutVT;
26291 }
26292
26293 return std::nullopt;
26294 }
26295
26296 // Match shuffles that can be converted to any_vector_extend_in_reg.
26297 // This is often generated during legalization.
26298 // e.g. v4i32 <0,u,1,u> -> (v2i64 any_vector_extend_in_reg(v4i32 src))
combineShuffleToAnyExtendVectorInreg(ShuffleVectorSDNode * SVN,SelectionDAG & DAG,const TargetLowering & TLI,bool LegalOperations)26299 static SDValue combineShuffleToAnyExtendVectorInreg(ShuffleVectorSDNode *SVN,
26300 SelectionDAG &DAG,
26301 const TargetLowering &TLI,
26302 bool LegalOperations) {
26303 EVT VT = SVN->getValueType(0);
26304 bool IsBigEndian = DAG.getDataLayout().isBigEndian();
26305
26306 // TODO Add support for big-endian when we have a test case.
26307 if (!VT.isInteger() || IsBigEndian)
26308 return SDValue();
26309
26310 // shuffle<0,-1,1,-1> == (v2i64 anyextend_vector_inreg(v4i32))
26311 auto isAnyExtend = [NumElts = VT.getVectorNumElements(),
26312 Mask = SVN->getMask()](unsigned Scale) {
26313 for (unsigned i = 0; i != NumElts; ++i) {
26314 if (Mask[i] < 0)
26315 continue;
26316 if ((i % Scale) == 0 && Mask[i] == (int)(i / Scale))
26317 continue;
26318 return false;
26319 }
26320 return true;
26321 };
26322
26323 unsigned Opcode = ISD::ANY_EXTEND_VECTOR_INREG;
26324 SDValue N0 = SVN->getOperand(0);
26325 // Never create an illegal type. Only create unsupported operations if we
26326 // are pre-legalization.
26327 std::optional<EVT> OutVT = canCombineShuffleToExtendVectorInreg(
26328 Opcode, VT, isAnyExtend, DAG, TLI, /*LegalTypes=*/true, LegalOperations);
26329 if (!OutVT)
26330 return SDValue();
26331 return DAG.getBitcast(VT, DAG.getNode(Opcode, SDLoc(SVN), *OutVT, N0));
26332 }
26333
// Match shuffles that can be converted to zero_extend_vector_inreg.
// This is often generated during legalization.
// e.g. v4i32 <0,z,1,u> -> (v2i64 zero_extend_vector_inreg(v4i32 src))
//
// Here 'z' denotes a lane that selects an element which is known to be zero.
// Both shuffle operands are tried as the extension source (the second one via
// a commuted mask). Returns the extension bitcast back to the shuffle's type,
// or an empty SDValue if nothing matched.
static SDValue combineShuffleToZeroExtendVectorInReg(ShuffleVectorSDNode *SVN,
                                                     SelectionDAG &DAG,
                                                     const TargetLowering &TLI,
                                                     bool LegalOperations) {
  // This combine always restricts itself to legal types.
  bool LegalTypes = true;
  EVT VT = SVN->getValueType(0);
  assert(!VT.isScalableVector() && "Encountered scalable shuffle?");
  unsigned NumElts = VT.getVectorNumElements();
  unsigned EltSizeInBits = VT.getScalarSizeInBits();

  // TODO: add support for big-endian when we have a test case.
  bool IsBigEndian = DAG.getDataLayout().isBigEndian();
  if (!VT.isInteger() || IsBigEndian)
    return SDValue();

  // Work on a private, mutable copy of the mask. The helper below visits every
  // non-negative mask element, passing it by reference (so the callback may
  // rewrite it) together with the operand it selects from (0 or 1) and the
  // element index within that operand.
  SmallVector<int, 16> Mask(SVN->getMask());
  auto ForEachDecomposedIndice = [NumElts, &Mask](auto Fn) {
    for (int &Indice : Mask) {
      if (Indice < 0)
        continue;
      int OpIdx = (unsigned)Indice < NumElts ? 0 : 1;
      int OpEltIdx = (unsigned)Indice < NumElts ? Indice : Indice - NumElts;
      Fn(Indice, OpIdx, OpEltIdx);
    }
  };

  // Which elements of which operand does this shuffle demand?
  std::array<APInt, 2> OpsDemandedElts;
  for (APInt &OpDemandedElts : OpsDemandedElts)
    OpDemandedElts = APInt::getZero(NumElts);
  ForEachDecomposedIndice(
      [&OpsDemandedElts](int &Indice, int OpIdx, int OpEltIdx) {
        OpsDemandedElts[OpIdx].setBit(OpEltIdx);
      });

  // Element-wise(!), which of these demanded elements are known to be zero?
  std::array<APInt, 2> OpsKnownZeroElts;
  for (auto I : zip(SVN->ops(), OpsDemandedElts, OpsKnownZeroElts))
    std::get<2>(I) =
        DAG.computeVectorKnownZeroElements(std::get<0>(I), std::get<1>(I));

  // Manifest zeroable element knowledge in the shuffle mask.
  // NOTE: we don't have 'zeroable' sentinel value in generic DAG,
  //       this is a local invention, but it won't leak into DAG.
  // FIXME: should we not manifest them, but just check when matching?
  bool HadZeroableElts = false;
  ForEachDecomposedIndice([&OpsKnownZeroElts, &HadZeroableElts](
                              int &Indice, int OpIdx, int OpEltIdx) {
    if (OpsKnownZeroElts[OpIdx][OpEltIdx]) {
      Indice = -2; // Zeroable element.
      HadZeroableElts = true;
    }
  });

  // Don't proceed unless we've refined at least one zeroable mask indice.
  // If we didn't, then we are still trying to match the same shuffle mask
  // we previously tried to match as ISD::ANY_EXTEND_VECTOR_INREG,
  // and evidently failed. Proceeding will lead to endless combine loops.
  if (!HadZeroableElts)
    return SDValue();

  // The shuffle may be more fine-grained than we want. Widen elements first.
  // FIXME: should we do this before manifesting zeroable shuffle mask indices?
  SmallVector<int, 16> ScaledMask;
  getShuffleMaskWithWidestElts(Mask, ScaledMask);
  assert(Mask.size() >= ScaledMask.size() &&
         Mask.size() % ScaledMask.size() == 0 && "Unexpected mask widening.");
  int Prescale = Mask.size() / ScaledMask.size();

  // From here on, reason about the widened shuffle: fewer, wider elements.
  NumElts = ScaledMask.size();
  EltSizeInBits *= Prescale;

  EVT PrescaledVT = EVT::getVectorVT(
      *DAG.getContext(), EVT::getIntegerVT(*DAG.getContext(), EltSizeInBits),
      NumElts);

  // Give up if widening produced an illegal type although the original
  // shuffle type was legal.
  if (LegalTypes && !TLI.isTypeLegal(PrescaledVT) && TLI.isTypeLegal(VT))
    return SDValue();

  // For example,
  // shuffle<0,z,1,-1> == (v2i64 zero_extend_vector_inreg(v4i32))
  // But not shuffle<z,z,1,-1> and not shuffle<0,z,z,-1> ! (for same types)
  //
  // ScaledMask is captured by reference, so this predicate observes the
  // commuted mask on the second iteration of the loop below.
  auto isZeroExtend = [NumElts, &ScaledMask](unsigned Scale) {
    assert(Scale >= 2 && Scale <= NumElts && NumElts % Scale == 0 &&
           "Unexpected mask scaling factor.");
    ArrayRef<int> Mask = ScaledMask;
    for (unsigned SrcElt = 0, NumSrcElts = NumElts / Scale;
         SrcElt != NumSrcElts; ++SrcElt) {
      // Analyze the shuffle mask in Scale-sized chunks.
      ArrayRef<int> MaskChunk = Mask.take_front(Scale);
      assert(MaskChunk.size() == Scale && "Unexpected mask size.");
      Mask = Mask.drop_front(MaskChunk.size());
      // The first indice in this chunk must be SrcElt, but not zero!
      // (A negative sentinel cast to unsigned never equals SrcElt, so undef
      // and zeroable first indices are both rejected here.)
      // FIXME: undef should be fine, but that results in more-defined result.
      if (int FirstIndice = MaskChunk[0]; (unsigned)FirstIndice != SrcElt)
        return false;
      // The rest of the indices in this chunk must be zeros.
      // FIXME: undef should be fine, but that results in more-defined result.
      if (!all_of(MaskChunk.drop_front(1),
                  [](int Indice) { return Indice == -2; }))
        return false;
    }
    assert(Mask.empty() && "Did not process the whole mask?");
    return true;
  };

  // Try operand 0 as the extension source first; then commute the scaled mask
  // in place and try operand 1.
  unsigned Opcode = ISD::ZERO_EXTEND_VECTOR_INREG;
  for (bool Commuted : {false, true}) {
    SDValue Op = SVN->getOperand(!Commuted ? 0 : 1);
    if (Commuted)
      ShuffleVectorSDNode::commuteMask(ScaledMask);
    std::optional<EVT> OutVT = canCombineShuffleToExtendVectorInreg(
        Opcode, PrescaledVT, isZeroExtend, DAG, TLI, LegalTypes,
        LegalOperations);
    if (OutVT)
      return DAG.getBitcast(VT, DAG.getNode(Opcode, SDLoc(SVN), *OutVT,
                                            DAG.getBitcast(PrescaledVT, Op)));
  }
  return SDValue();
}
26457
26458 // Detect 'truncate_vector_inreg' style shuffles that pack the lower parts of
26459 // each source element of a large type into the lowest elements of a smaller
26460 // destination type. This is often generated during legalization.
26461 // If the source node itself was a '*_extend_vector_inreg' node then we should
26462 // then be able to remove it.
combineTruncationShuffle(ShuffleVectorSDNode * SVN,SelectionDAG & DAG)26463 static SDValue combineTruncationShuffle(ShuffleVectorSDNode *SVN,
26464 SelectionDAG &DAG) {
26465 EVT VT = SVN->getValueType(0);
26466 bool IsBigEndian = DAG.getDataLayout().isBigEndian();
26467
26468 // TODO Add support for big-endian when we have a test case.
26469 if (!VT.isInteger() || IsBigEndian)
26470 return SDValue();
26471
26472 SDValue N0 = peekThroughBitcasts(SVN->getOperand(0));
26473
26474 unsigned Opcode = N0.getOpcode();
26475 if (!ISD::isExtVecInRegOpcode(Opcode))
26476 return SDValue();
26477
26478 SDValue N00 = N0.getOperand(0);
26479 ArrayRef<int> Mask = SVN->getMask();
26480 unsigned NumElts = VT.getVectorNumElements();
26481 unsigned EltSizeInBits = VT.getScalarSizeInBits();
26482 unsigned ExtSrcSizeInBits = N00.getScalarValueSizeInBits();
26483 unsigned ExtDstSizeInBits = N0.getScalarValueSizeInBits();
26484
26485 if (ExtDstSizeInBits % ExtSrcSizeInBits != 0)
26486 return SDValue();
26487 unsigned ExtScale = ExtDstSizeInBits / ExtSrcSizeInBits;
26488
26489 // (v4i32 truncate_vector_inreg(v2i64)) == shuffle<0,2-1,-1>
26490 // (v8i16 truncate_vector_inreg(v4i32)) == shuffle<0,2,4,6,-1,-1,-1,-1>
26491 // (v8i16 truncate_vector_inreg(v2i64)) == shuffle<0,4,-1,-1,-1,-1,-1,-1>
26492 auto isTruncate = [&Mask, &NumElts](unsigned Scale) {
26493 for (unsigned i = 0; i != NumElts; ++i) {
26494 if (Mask[i] < 0)
26495 continue;
26496 if ((i * Scale) < NumElts && Mask[i] == (int)(i * Scale))
26497 continue;
26498 return false;
26499 }
26500 return true;
26501 };
26502
26503 // At the moment we just handle the case where we've truncated back to the
26504 // same size as before the extension.
26505 // TODO: handle more extension/truncation cases as cases arise.
26506 if (EltSizeInBits != ExtSrcSizeInBits)
26507 return SDValue();
26508
26509 // We can remove *extend_vector_inreg only if the truncation happens at
26510 // the same scale as the extension.
26511 if (isTruncate(ExtScale))
26512 return DAG.getBitcast(VT, N00);
26513
26514 return SDValue();
26515 }
26516
// Combine shuffles of splat-shuffles of the form:
// shuffle (shuffle V, undef, splat-mask), undef, M
// If splat-mask contains undef elements, we need to be careful about
// introducing undef's in the folded mask which are not the result of composing
// the masks of the shuffles.
//
// Only unary shuffles (second operand undef) are handled. Depending on what
// matches, the result is a rewritten splat shuffle of operand 0, an undef
// value, operand 0 itself, or a new shuffle of the inner splat-shuffle's
// operands with the composed mask. Returns an empty SDValue otherwise.
static SDValue combineShuffleOfSplatVal(ShuffleVectorSDNode *Shuf,
                                        SelectionDAG &DAG) {
  EVT VT = Shuf->getValueType(0);
  unsigned NumElts = VT.getVectorNumElements();

  if (!Shuf->getOperand(1).isUndef())
    return SDValue();

  // See if this unary non-splat shuffle actually *is* a splat shuffle,
  // in disguise, with all demanded elements being identical.
  // FIXME: this can be done per-operand.
  if (!Shuf->isSplat()) {
    // Collect the set of operand-0 elements that the mask actually reads.
    APInt DemandedElts(NumElts, 0);
    for (int Idx : Shuf->getMask()) {
      if (Idx < 0)
        continue; // Ignore sentinel indices.
      assert((unsigned)Idx < NumElts && "Out-of-bounds shuffle indice?");
      DemandedElts.setBit(Idx);
    }
    assert(DemandedElts.popcount() > 1 && "Is a splat shuffle already?");
    APInt UndefElts;
    if (DAG.isSplatValue(Shuf->getOperand(0), DemandedElts, UndefElts)) {
      // Even if all demanded elements are splat, some of them could be undef.
      // Which lowest demanded element is *not* known-undef?
      std::optional<unsigned> MinNonUndefIdx;
      for (int Idx : Shuf->getMask()) {
        if (Idx < 0 || UndefElts[Idx])
          continue; // Ignore sentinel indices, and undef elements.
        MinNonUndefIdx = std::min<unsigned>(Idx, MinNonUndefIdx.value_or(~0U));
      }
      if (!MinNonUndefIdx)
        return DAG.getUNDEF(VT); // All undef - result is undef.
      assert(*MinNonUndefIdx < NumElts && "Expected valid element index.");
      // Rewrite the mask so that every defined lane reads the same element.
      SmallVector<int, 8> SplatMask(Shuf->getMask());
      for (int &Idx : SplatMask) {
        if (Idx < 0)
          continue; // Passthrough sentinel indices.
        // Otherwise, just pick the lowest demanded non-undef element.
        // Or sentinel undef, if we know we'd pick a known-undef element.
        Idx = UndefElts[Idx] ? -1 : *MinNonUndefIdx;
      }
      assert(SplatMask != Shuf->getMask() && "Expected mask to change!");
      return DAG.getVectorShuffle(VT, SDLoc(Shuf), Shuf->getOperand(0),
                                  Shuf->getOperand(1), SplatMask);
    }
  }

  // If the inner operand is a known splat with no undefs, just return that
  // directly.
  // TODO: Create DemandedElts mask from Shuf's mask.
  // TODO: Allow undef elements and merge with the shuffle code below.
  if (DAG.isSplatValue(Shuf->getOperand(0), /*AllowUndefs*/ false))
    return Shuf->getOperand(0);

  // The remaining folds require operand 0 to itself be a splat shuffle.
  auto *Splat = dyn_cast<ShuffleVectorSDNode>(Shuf->getOperand(0));
  if (!Splat || !Splat->isSplat())
    return SDValue();

  ArrayRef<int> ShufMask = Shuf->getMask();
  ArrayRef<int> SplatMask = Splat->getMask();
  assert(ShufMask.size() == SplatMask.size() && "Mask length mismatch");

  // Prefer simplifying to the splat-shuffle, if possible. This is legal if
  // every undef mask element in the splat-shuffle has a corresponding undef
  // element in the user-shuffle's mask or if the composition of mask elements
  // would result in undef.
  // Examples for (shuffle (shuffle v, undef, SplatMask), undef, UserMask):
  //   * UserMask=[0,2,u,u], SplatMask=[2,u,2,u] -> [2,2,u,u]
  //     In this case it is not legal to simplify to the splat-shuffle because we
  //     may be exposing the users of the shuffle an undef element at index 1
  //     which was not there before the combine.
  //   * UserMask=[0,u,2,u], SplatMask=[2,u,2,u] -> [2,u,2,u]
  //     In this case the composition of masks yields SplatMask, so it's ok to
  //     simplify to the splat-shuffle.
  //   * UserMask=[3,u,2,u], SplatMask=[2,u,2,u] -> [u,u,2,u]
  //     In this case the composed mask includes all undef elements of SplatMask
  //     and in addition sets element zero to undef. It is safe to simplify to
  //     the splat-shuffle.
  auto CanSimplifyToExistingSplat = [](ArrayRef<int> UserMask,
                                       ArrayRef<int> SplatMask) {
    // Reject if some lane is undef in the splat-shuffle while the composed
    // mask would define that lane.
    for (unsigned i = 0, e = UserMask.size(); i != e; ++i)
      if (UserMask[i] != -1 && SplatMask[i] == -1 &&
          SplatMask[UserMask[i]] != -1)
        return false;
    return true;
  };
  if (CanSimplifyToExistingSplat(ShufMask, SplatMask))
    return Shuf->getOperand(0);

  // Create a new shuffle with a mask that is composed of the two shuffles'
  // masks.
  SmallVector<int, 32> NewMask;
  for (int Idx : ShufMask)
    NewMask.push_back(Idx == -1 ? -1 : SplatMask[Idx]);

  return DAG.getVectorShuffle(Splat->getValueType(0), SDLoc(Splat),
                              Splat->getOperand(0), Splat->getOperand(1),
                              NewMask);
}
26620
26621 // Combine shuffles of bitcasts into a shuffle of the bitcast type, providing
26622 // the mask can be treated as a larger type.
combineShuffleOfBitcast(ShuffleVectorSDNode * SVN,SelectionDAG & DAG,const TargetLowering & TLI,bool LegalOperations)26623 static SDValue combineShuffleOfBitcast(ShuffleVectorSDNode *SVN,
26624 SelectionDAG &DAG,
26625 const TargetLowering &TLI,
26626 bool LegalOperations) {
26627 SDValue Op0 = SVN->getOperand(0);
26628 SDValue Op1 = SVN->getOperand(1);
26629 EVT VT = SVN->getValueType(0);
26630 if (Op0.getOpcode() != ISD::BITCAST)
26631 return SDValue();
26632 EVT InVT = Op0.getOperand(0).getValueType();
26633 if (!InVT.isVector() ||
26634 (!Op1.isUndef() && (Op1.getOpcode() != ISD::BITCAST ||
26635 Op1.getOperand(0).getValueType() != InVT)))
26636 return SDValue();
26637 if (isAnyConstantBuildVector(Op0.getOperand(0)) &&
26638 (Op1.isUndef() || isAnyConstantBuildVector(Op1.getOperand(0))))
26639 return SDValue();
26640
26641 int VTLanes = VT.getVectorNumElements();
26642 int InLanes = InVT.getVectorNumElements();
26643 if (VTLanes <= InLanes || VTLanes % InLanes != 0 ||
26644 (LegalOperations &&
26645 !TLI.isOperationLegalOrCustom(ISD::VECTOR_SHUFFLE, InVT)))
26646 return SDValue();
26647 int Factor = VTLanes / InLanes;
26648
26649 // Check that each group of lanes in the mask are either undef or make a valid
26650 // mask for the wider lane type.
26651 ArrayRef<int> Mask = SVN->getMask();
26652 SmallVector<int> NewMask;
26653 if (!widenShuffleMaskElts(Factor, Mask, NewMask))
26654 return SDValue();
26655
26656 if (!TLI.isShuffleMaskLegal(NewMask, InVT))
26657 return SDValue();
26658
26659 // Create the new shuffle with the new mask and bitcast it back to the
26660 // original type.
26661 SDLoc DL(SVN);
26662 Op0 = Op0.getOperand(0);
26663 Op1 = Op1.isUndef() ? DAG.getUNDEF(InVT) : Op1.getOperand(0);
26664 SDValue NewShuf = DAG.getVectorShuffle(InVT, DL, Op0, Op1, NewMask);
26665 return DAG.getBitcast(VT, NewShuf);
26666 }
26667
26668 /// Combine shuffle of shuffle of the form:
26669 /// shuf (shuf X, undef, InnerMask), undef, OuterMask --> splat X
formSplatFromShuffles(ShuffleVectorSDNode * OuterShuf,SelectionDAG & DAG)26670 static SDValue formSplatFromShuffles(ShuffleVectorSDNode *OuterShuf,
26671 SelectionDAG &DAG) {
26672 if (!OuterShuf->getOperand(1).isUndef())
26673 return SDValue();
26674 auto *InnerShuf = dyn_cast<ShuffleVectorSDNode>(OuterShuf->getOperand(0));
26675 if (!InnerShuf || !InnerShuf->getOperand(1).isUndef())
26676 return SDValue();
26677
26678 ArrayRef<int> OuterMask = OuterShuf->getMask();
26679 ArrayRef<int> InnerMask = InnerShuf->getMask();
26680 unsigned NumElts = OuterMask.size();
26681 assert(NumElts == InnerMask.size() && "Mask length mismatch");
26682 SmallVector<int, 32> CombinedMask(NumElts, -1);
26683 int SplatIndex = -1;
26684 for (unsigned i = 0; i != NumElts; ++i) {
26685 // Undef lanes remain undef.
26686 int OuterMaskElt = OuterMask[i];
26687 if (OuterMaskElt == -1)
26688 continue;
26689
26690 // Peek through the shuffle masks to get the underlying source element.
26691 int InnerMaskElt = InnerMask[OuterMaskElt];
26692 if (InnerMaskElt == -1)
26693 continue;
26694
26695 // Initialize the splatted element.
26696 if (SplatIndex == -1)
26697 SplatIndex = InnerMaskElt;
26698
26699 // Non-matching index - this is not a splat.
26700 if (SplatIndex != InnerMaskElt)
26701 return SDValue();
26702
26703 CombinedMask[i] = InnerMaskElt;
26704 }
26705 assert((all_of(CombinedMask, [](int M) { return M == -1; }) ||
26706 getSplatIndex(CombinedMask) != -1) &&
26707 "Expected a splat mask");
26708
26709 // TODO: The transform may be a win even if the mask is not legal.
26710 EVT VT = OuterShuf->getValueType(0);
26711 assert(VT == InnerShuf->getValueType(0) && "Expected matching shuffle types");
26712 if (!DAG.getTargetLoweringInfo().isShuffleMaskLegal(CombinedMask, VT))
26713 return SDValue();
26714
26715 return DAG.getVectorShuffle(VT, SDLoc(OuterShuf), InnerShuf->getOperand(0),
26716 InnerShuf->getOperand(1), CombinedMask);
26717 }
26718
26719 /// If the shuffle mask is taking exactly one element from the first vector
26720 /// operand and passing through all other elements from the second vector
26721 /// operand, return the index of the mask element that is choosing an element
26722 /// from the first operand. Otherwise, return -1.
getShuffleMaskIndexOfOneElementFromOp0IntoOp1(ArrayRef<int> Mask)26723 static int getShuffleMaskIndexOfOneElementFromOp0IntoOp1(ArrayRef<int> Mask) {
26724 int MaskSize = Mask.size();
26725 int EltFromOp0 = -1;
26726 // TODO: This does not match if there are undef elements in the shuffle mask.
26727 // Should we ignore undefs in the shuffle mask instead? The trade-off is
26728 // removing an instruction (a shuffle), but losing the knowledge that some
26729 // vector lanes are not needed.
26730 for (int i = 0; i != MaskSize; ++i) {
26731 if (Mask[i] >= 0 && Mask[i] < MaskSize) {
26732 // We're looking for a shuffle of exactly one element from operand 0.
26733 if (EltFromOp0 != -1)
26734 return -1;
26735 EltFromOp0 = i;
26736 } else if (Mask[i] != i + MaskSize) {
26737 // Nothing from operand 1 can change lanes.
26738 return -1;
26739 }
26740 }
26741 return EltFromOp0;
26742 }
26743
26744 /// If a shuffle inserts exactly one element from a source vector operand into
26745 /// another vector operand and we can access the specified element as a scalar,
26746 /// then we can eliminate the shuffle.
replaceShuffleOfInsert(ShuffleVectorSDNode * Shuf)26747 SDValue DAGCombiner::replaceShuffleOfInsert(ShuffleVectorSDNode *Shuf) {
26748 // First, check if we are taking one element of a vector and shuffling that
26749 // element into another vector.
26750 ArrayRef<int> Mask = Shuf->getMask();
26751 SmallVector<int, 16> CommutedMask(Mask);
26752 SDValue Op0 = Shuf->getOperand(0);
26753 SDValue Op1 = Shuf->getOperand(1);
26754 int ShufOp0Index = getShuffleMaskIndexOfOneElementFromOp0IntoOp1(Mask);
26755 if (ShufOp0Index == -1) {
26756 // Commute mask and check again.
26757 ShuffleVectorSDNode::commuteMask(CommutedMask);
26758 ShufOp0Index = getShuffleMaskIndexOfOneElementFromOp0IntoOp1(CommutedMask);
26759 if (ShufOp0Index == -1)
26760 return SDValue();
26761 // Commute operands to match the commuted shuffle mask.
26762 std::swap(Op0, Op1);
26763 Mask = CommutedMask;
26764 }
26765
26766 // The shuffle inserts exactly one element from operand 0 into operand 1.
26767 // Now see if we can access that element as a scalar via a real insert element
26768 // instruction.
26769 // TODO: We can try harder to locate the element as a scalar. Examples: it
26770 // could be an operand of BUILD_VECTOR, or a constant.
26771 assert(Mask[ShufOp0Index] >= 0 && Mask[ShufOp0Index] < (int)Mask.size() &&
26772 "Shuffle mask value must be from operand 0");
26773
26774 SDValue Elt;
26775 if (sd_match(Op0, m_InsertElt(m_Value(), m_Value(Elt),
26776 m_SpecificInt(Mask[ShufOp0Index])))) {
26777 // There's an existing insertelement with constant insertion index, so we
26778 // don't need to check the legality/profitability of a replacement operation
26779 // that differs at most in the constant value. The target should be able to
26780 // lower any of those in a similar way. If not, legalization will expand
26781 // this to a scalar-to-vector plus shuffle.
26782 //
26783 // Note that the shuffle may move the scalar from the position that the
26784 // insert element used. Therefore, our new insert element occurs at the
26785 // shuffle's mask index value, not the insert's index value.
26786 //
26787 // shuffle (insertelt v1, x, C), v2, mask --> insertelt v2, x, C'
26788 SDValue NewInsIndex = DAG.getVectorIdxConstant(ShufOp0Index, SDLoc(Shuf));
26789 return DAG.getNode(ISD::INSERT_VECTOR_ELT, SDLoc(Shuf), Op0.getValueType(),
26790 Op1, Elt, NewInsIndex);
26791 }
26792
26793 if (!hasOperation(ISD::INSERT_VECTOR_ELT, Op0.getValueType()))
26794 return SDValue();
26795
26796 if (sd_match(Op0, m_UnaryOp(ISD::SCALAR_TO_VECTOR, m_Value(Elt))) &&
26797 Mask[ShufOp0Index] == 0) {
26798 SDValue NewInsIndex = DAG.getVectorIdxConstant(ShufOp0Index, SDLoc(Shuf));
26799 return DAG.getNode(ISD::INSERT_VECTOR_ELT, SDLoc(Shuf), Op0.getValueType(),
26800 Op1, Elt, NewInsIndex);
26801 }
26802
26803 return SDValue();
26804 }
26805
26806 /// If we have a unary shuffle of a shuffle, see if it can be folded away
26807 /// completely. This has the potential to lose undef knowledge because the first
26808 /// shuffle may not have an undef mask element where the second one does. So
26809 /// only call this after doing simplifications based on demanded elements.
simplifyShuffleOfShuffle(ShuffleVectorSDNode * Shuf)26810 static SDValue simplifyShuffleOfShuffle(ShuffleVectorSDNode *Shuf) {
26811 // shuf (shuf0 X, Y, Mask0), undef, Mask
26812 auto *Shuf0 = dyn_cast<ShuffleVectorSDNode>(Shuf->getOperand(0));
26813 if (!Shuf0 || !Shuf->getOperand(1).isUndef())
26814 return SDValue();
26815
26816 ArrayRef<int> Mask = Shuf->getMask();
26817 ArrayRef<int> Mask0 = Shuf0->getMask();
26818 for (int i = 0, e = (int)Mask.size(); i != e; ++i) {
26819 // Ignore undef elements.
26820 if (Mask[i] == -1)
26821 continue;
26822 assert(Mask[i] >= 0 && Mask[i] < e && "Unexpected shuffle mask value");
26823
26824 // Is the element of the shuffle operand chosen by this shuffle the same as
26825 // the element chosen by the shuffle operand itself?
26826 if (Mask0[Mask[i]] != Mask0[i])
26827 return SDValue();
26828 }
26829 // Every element of this shuffle is identical to the result of the previous
26830 // shuffle, so we can replace this value.
26831 return Shuf->getOperand(0);
26832 }
26833
visitVECTOR_SHUFFLE(SDNode * N)26834 SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
26835 EVT VT = N->getValueType(0);
26836 unsigned NumElts = VT.getVectorNumElements();
26837
26838 SDValue N0 = N->getOperand(0);
26839 SDValue N1 = N->getOperand(1);
26840
26841 assert(N0.getValueType() == VT && "Vector shuffle must be normalized in DAG");
26842
26843 // Canonicalize shuffle undef, undef -> undef
26844 if (N0.isUndef() && N1.isUndef())
26845 return DAG.getUNDEF(VT);
26846
26847 ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N);
26848
26849 // Canonicalize shuffle v, v -> v, undef
26850 if (N0 == N1)
26851 return DAG.getVectorShuffle(VT, SDLoc(N), N0, DAG.getUNDEF(VT),
26852 createUnaryMask(SVN->getMask(), NumElts));
26853
26854 // Canonicalize shuffle undef, v -> v, undef. Commute the shuffle mask.
26855 if (N0.isUndef())
26856 return DAG.getCommutedVectorShuffle(*SVN);
26857
26858 // Remove references to rhs if it is undef
26859 if (N1.isUndef()) {
26860 bool Changed = false;
26861 SmallVector<int, 8> NewMask;
26862 for (unsigned i = 0; i != NumElts; ++i) {
26863 int Idx = SVN->getMaskElt(i);
26864 if (Idx >= (int)NumElts) {
26865 Idx = -1;
26866 Changed = true;
26867 }
26868 NewMask.push_back(Idx);
26869 }
26870 if (Changed)
26871 return DAG.getVectorShuffle(VT, SDLoc(N), N0, N1, NewMask);
26872 }
26873
26874 if (SDValue InsElt = replaceShuffleOfInsert(SVN))
26875 return InsElt;
26876
26877 // A shuffle of a single vector that is a splatted value can always be folded.
26878 if (SDValue V = combineShuffleOfSplatVal(SVN, DAG))
26879 return V;
26880
26881 if (SDValue V = formSplatFromShuffles(SVN, DAG))
26882 return V;
26883
26884 // If it is a splat, check if the argument vector is another splat or a
26885 // build_vector.
26886 if (SVN->isSplat() && SVN->getSplatIndex() < (int)NumElts) {
26887 int SplatIndex = SVN->getSplatIndex();
26888 if (N0.hasOneUse() && TLI.isExtractVecEltCheap(VT, SplatIndex) &&
26889 TLI.isBinOp(N0.getOpcode()) && N0->getNumValues() == 1) {
26890 // splat (vector_bo L, R), Index -->
26891 // splat (scalar_bo (extelt L, Index), (extelt R, Index))
26892 SDValue L = N0.getOperand(0), R = N0.getOperand(1);
26893 SDLoc DL(N);
26894 EVT EltVT = VT.getScalarType();
26895 SDValue Index = DAG.getVectorIdxConstant(SplatIndex, DL);
26896 SDValue ExtL = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, L, Index);
26897 SDValue ExtR = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, R, Index);
26898 SDValue NewBO =
26899 DAG.getNode(N0.getOpcode(), DL, EltVT, ExtL, ExtR, N0->getFlags());
26900 SDValue Insert = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VT, NewBO);
26901 SmallVector<int, 16> ZeroMask(VT.getVectorNumElements(), 0);
26902 return DAG.getVectorShuffle(VT, DL, Insert, DAG.getUNDEF(VT), ZeroMask);
26903 }
26904
26905 // splat(scalar_to_vector(x), 0) -> build_vector(x,...,x)
26906 // splat(insert_vector_elt(v, x, c), c) -> build_vector(x,...,x)
26907 if ((!LegalOperations || TLI.isOperationLegal(ISD::BUILD_VECTOR, VT)) &&
26908 N0.hasOneUse()) {
26909 if (N0.getOpcode() == ISD::SCALAR_TO_VECTOR && SplatIndex == 0)
26910 return DAG.getSplatBuildVector(VT, SDLoc(N), N0.getOperand(0));
26911
26912 if (N0.getOpcode() == ISD::INSERT_VECTOR_ELT)
26913 if (auto *Idx = dyn_cast<ConstantSDNode>(N0.getOperand(2)))
26914 if (Idx->getAPIntValue() == SplatIndex)
26915 return DAG.getSplatBuildVector(VT, SDLoc(N), N0.getOperand(1));
26916
26917 // Look through a bitcast if LE and splatting lane 0, through to a
26918 // scalar_to_vector or a build_vector.
26919 if (N0.getOpcode() == ISD::BITCAST && N0.getOperand(0).hasOneUse() &&
26920 SplatIndex == 0 && DAG.getDataLayout().isLittleEndian() &&
26921 (N0.getOperand(0).getOpcode() == ISD::SCALAR_TO_VECTOR ||
26922 N0.getOperand(0).getOpcode() == ISD::BUILD_VECTOR)) {
26923 EVT N00VT = N0.getOperand(0).getValueType();
26924 if (VT.getScalarSizeInBits() <= N00VT.getScalarSizeInBits() &&
26925 VT.isInteger() && N00VT.isInteger()) {
26926 EVT InVT =
26927 TLI.getTypeToTransformTo(*DAG.getContext(), VT.getScalarType());
26928 SDValue Op = DAG.getZExtOrTrunc(N0.getOperand(0).getOperand(0),
26929 SDLoc(N), InVT);
26930 return DAG.getSplatBuildVector(VT, SDLoc(N), Op);
26931 }
26932 }
26933 }
26934
26935 // If this is a bit convert that changes the element type of the vector but
26936 // not the number of vector elements, look through it. Be careful not to
    // look through conversions that change things like v4f32 to v2f64.
26938 SDNode *V = N0.getNode();
26939 if (V->getOpcode() == ISD::BITCAST) {
26940 SDValue ConvInput = V->getOperand(0);
26941 if (ConvInput.getValueType().isVector() &&
26942 ConvInput.getValueType().getVectorNumElements() == NumElts)
26943 V = ConvInput.getNode();
26944 }
26945
26946 if (V->getOpcode() == ISD::BUILD_VECTOR) {
26947 assert(V->getNumOperands() == NumElts &&
26948 "BUILD_VECTOR has wrong number of operands");
26949 SDValue Base;
26950 bool AllSame = true;
26951 for (unsigned i = 0; i != NumElts; ++i) {
26952 if (!V->getOperand(i).isUndef()) {
26953 Base = V->getOperand(i);
26954 break;
26955 }
26956 }
26957 // Splat of <u, u, u, u>, return <u, u, u, u>
26958 if (!Base.getNode())
26959 return N0;
26960 for (unsigned i = 0; i != NumElts; ++i) {
26961 if (V->getOperand(i) != Base) {
26962 AllSame = false;
26963 break;
26964 }
26965 }
26966 // Splat of <x, x, x, x>, return <x, x, x, x>
26967 if (AllSame)
26968 return N0;
26969
26970 // Canonicalize any other splat as a build_vector, but avoid defining any
26971 // undefined elements in the mask.
26972 SDValue Splatted = V->getOperand(SplatIndex);
26973 SmallVector<SDValue, 8> Ops(NumElts, Splatted);
26974 EVT EltVT = Splatted.getValueType();
26975
26976 for (unsigned i = 0; i != NumElts; ++i) {
26977 if (SVN->getMaskElt(i) < 0)
26978 Ops[i] = DAG.getUNDEF(EltVT);
26979 }
26980
26981 SDValue NewBV = DAG.getBuildVector(V->getValueType(0), SDLoc(N), Ops);
26982
26983 // We may have jumped through bitcasts, so the type of the
26984 // BUILD_VECTOR may not match the type of the shuffle.
26985 if (V->getValueType(0) != VT)
26986 NewBV = DAG.getBitcast(VT, NewBV);
26987 return NewBV;
26988 }
26989 }
26990
26991 // Simplify source operands based on shuffle mask.
26992 if (SimplifyDemandedVectorElts(SDValue(N, 0)))
26993 return SDValue(N, 0);
26994
26995 // This is intentionally placed after demanded elements simplification because
26996 // it could eliminate knowledge of undef elements created by this shuffle.
26997 if (SDValue ShufOp = simplifyShuffleOfShuffle(SVN))
26998 return ShufOp;
26999
27000 // Match shuffles that can be converted to any_vector_extend_in_reg.
27001 if (SDValue V =
27002 combineShuffleToAnyExtendVectorInreg(SVN, DAG, TLI, LegalOperations))
27003 return V;
27004
27005 // Combine "truncate_vector_in_reg" style shuffles.
27006 if (SDValue V = combineTruncationShuffle(SVN, DAG))
27007 return V;
27008
27009 if (N0.getOpcode() == ISD::CONCAT_VECTORS &&
27010 Level < AfterLegalizeVectorOps &&
27011 (N1.isUndef() ||
27012 (N1.getOpcode() == ISD::CONCAT_VECTORS &&
27013 N0.getOperand(0).getValueType() == N1.getOperand(0).getValueType()))) {
27014 if (SDValue V = partitionShuffleOfConcats(N, DAG))
27015 return V;
27016 }
27017
27018 // A shuffle of a concat of the same narrow vector can be reduced to use
27019 // only low-half elements of a concat with undef:
27020 // shuf (concat X, X), undef, Mask --> shuf (concat X, undef), undef, Mask'
27021 if (N0.getOpcode() == ISD::CONCAT_VECTORS && N1.isUndef() &&
27022 N0.getNumOperands() == 2 &&
27023 N0.getOperand(0) == N0.getOperand(1)) {
27024 int HalfNumElts = (int)NumElts / 2;
27025 SmallVector<int, 8> NewMask;
27026 for (unsigned i = 0; i != NumElts; ++i) {
27027 int Idx = SVN->getMaskElt(i);
27028 if (Idx >= HalfNumElts) {
27029 assert(Idx < (int)NumElts && "Shuffle mask chooses undef op");
27030 Idx -= HalfNumElts;
27031 }
27032 NewMask.push_back(Idx);
27033 }
27034 if (TLI.isShuffleMaskLegal(NewMask, VT)) {
27035 SDValue UndefVec = DAG.getUNDEF(N0.getOperand(0).getValueType());
27036 SDValue NewCat = DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT,
27037 N0.getOperand(0), UndefVec);
27038 return DAG.getVectorShuffle(VT, SDLoc(N), NewCat, N1, NewMask);
27039 }
27040 }
27041
27042 // See if we can replace a shuffle with an insert_subvector.
27043 // e.g. v2i32 into v8i32:
27044 // shuffle(lhs,concat(rhs0,rhs1,rhs2,rhs3),0,1,2,3,10,11,6,7).
27045 // --> insert_subvector(lhs,rhs1,4).
27046 if (Level < AfterLegalizeVectorOps && TLI.isTypeLegal(VT) &&
27047 TLI.isOperationLegalOrCustom(ISD::INSERT_SUBVECTOR, VT)) {
27048 auto ShuffleToInsert = [&](SDValue LHS, SDValue RHS, ArrayRef<int> Mask) {
27049 // Ensure RHS subvectors are legal.
27050 assert(RHS.getOpcode() == ISD::CONCAT_VECTORS && "Can't find subvectors");
27051 EVT SubVT = RHS.getOperand(0).getValueType();
27052 int NumSubVecs = RHS.getNumOperands();
27053 int NumSubElts = SubVT.getVectorNumElements();
27054 assert((NumElts % NumSubElts) == 0 && "Subvector mismatch");
27055 if (!TLI.isTypeLegal(SubVT))
27056 return SDValue();
27057
27058 // Don't bother if we have an unary shuffle (matches undef + LHS elts).
27059 if (all_of(Mask, [NumElts](int M) { return M < (int)NumElts; }))
27060 return SDValue();
27061
27062 // Search [NumSubElts] spans for RHS sequence.
27063 // TODO: Can we avoid nested loops to increase performance?
27064 SmallVector<int> InsertionMask(NumElts);
27065 for (int SubVec = 0; SubVec != NumSubVecs; ++SubVec) {
27066 for (int SubIdx = 0; SubIdx != (int)NumElts; SubIdx += NumSubElts) {
27067 // Reset mask to identity.
27068 std::iota(InsertionMask.begin(), InsertionMask.end(), 0);
27069
27070 // Add subvector insertion.
27071 std::iota(InsertionMask.begin() + SubIdx,
27072 InsertionMask.begin() + SubIdx + NumSubElts,
27073 NumElts + (SubVec * NumSubElts));
27074
27075 // See if the shuffle mask matches the reference insertion mask.
27076 bool MatchingShuffle = true;
27077 for (int i = 0; i != (int)NumElts; ++i) {
27078 int ExpectIdx = InsertionMask[i];
27079 int ActualIdx = Mask[i];
27080 if (0 <= ActualIdx && ExpectIdx != ActualIdx) {
27081 MatchingShuffle = false;
27082 break;
27083 }
27084 }
27085
27086 if (MatchingShuffle)
27087 return DAG.getInsertSubvector(SDLoc(N), LHS, RHS.getOperand(SubVec),
27088 SubIdx);
27089 }
27090 }
27091 return SDValue();
27092 };
27093 ArrayRef<int> Mask = SVN->getMask();
27094 if (N1.getOpcode() == ISD::CONCAT_VECTORS)
27095 if (SDValue InsertN1 = ShuffleToInsert(N0, N1, Mask))
27096 return InsertN1;
27097 if (N0.getOpcode() == ISD::CONCAT_VECTORS) {
27098 SmallVector<int> CommuteMask(Mask);
27099 ShuffleVectorSDNode::commuteMask(CommuteMask);
27100 if (SDValue InsertN0 = ShuffleToInsert(N1, N0, CommuteMask))
27101 return InsertN0;
27102 }
27103 }
27104
27105 // If we're not performing a select/blend shuffle, see if we can convert the
  // shuffle into an AND node, with all the out-of-lane elements known zero.
27107 if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) {
27108 bool IsInLaneMask = true;
27109 ArrayRef<int> Mask = SVN->getMask();
27110 SmallVector<int, 16> ClearMask(NumElts, -1);
27111 APInt DemandedLHS = APInt::getZero(NumElts);
27112 APInt DemandedRHS = APInt::getZero(NumElts);
27113 for (int I = 0; I != (int)NumElts; ++I) {
27114 int M = Mask[I];
27115 if (M < 0)
27116 continue;
27117 ClearMask[I] = M == I ? I : (I + NumElts);
27118 IsInLaneMask &= (M == I) || (M == (int)(I + NumElts));
27119 if (M != I) {
27120 APInt &Demanded = M < (int)NumElts ? DemandedLHS : DemandedRHS;
27121 Demanded.setBit(M % NumElts);
27122 }
27123 }
27124 // TODO: Should we try to mask with N1 as well?
27125 if (!IsInLaneMask && (!DemandedLHS.isZero() || !DemandedRHS.isZero()) &&
27126 (DemandedLHS.isZero() || DAG.MaskedVectorIsZero(N0, DemandedLHS)) &&
27127 (DemandedRHS.isZero() || DAG.MaskedVectorIsZero(N1, DemandedRHS))) {
27128 SDLoc DL(N);
27129 EVT IntVT = VT.changeVectorElementTypeToInteger();
27130 EVT IntSVT = VT.getVectorElementType().changeTypeToInteger();
27131 // Transform the type to a legal type so that the buildvector constant
27132 // elements are not illegal. Make sure that the result is larger than the
      // original type, in case the value is split into two (e.g. i64->i32).
27134 if (!TLI.isTypeLegal(IntSVT) && LegalTypes)
27135 IntSVT = TLI.getTypeToTransformTo(*DAG.getContext(), IntSVT);
27136 if (IntSVT.getSizeInBits() >= IntVT.getScalarSizeInBits()) {
27137 SDValue ZeroElt = DAG.getConstant(0, DL, IntSVT);
27138 SDValue AllOnesElt = DAG.getAllOnesConstant(DL, IntSVT);
27139 SmallVector<SDValue, 16> AndMask(NumElts, DAG.getUNDEF(IntSVT));
27140 for (int I = 0; I != (int)NumElts; ++I)
27141 if (0 <= Mask[I])
27142 AndMask[I] = Mask[I] == I ? AllOnesElt : ZeroElt;
27143
27144 // See if a clear mask is legal instead of going via
27145 // XformToShuffleWithZero which loses UNDEF mask elements.
27146 if (TLI.isVectorClearMaskLegal(ClearMask, IntVT))
27147 return DAG.getBitcast(
27148 VT, DAG.getVectorShuffle(IntVT, DL, DAG.getBitcast(IntVT, N0),
27149 DAG.getConstant(0, DL, IntVT), ClearMask));
27150
27151 if (TLI.isOperationLegalOrCustom(ISD::AND, IntVT))
27152 return DAG.getBitcast(
27153 VT, DAG.getNode(ISD::AND, DL, IntVT, DAG.getBitcast(IntVT, N0),
27154 DAG.getBuildVector(IntVT, DL, AndMask)));
27155 }
27156 }
27157 }
27158
27159 // Attempt to combine a shuffle of 2 inputs of 'scalar sources' -
27160 // BUILD_VECTOR or SCALAR_TO_VECTOR into a single BUILD_VECTOR.
27161 if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT))
27162 if (SDValue Res = combineShuffleOfScalars(SVN, DAG, TLI))
27163 return Res;
27164
27165 // If this shuffle only has a single input that is a bitcasted shuffle,
27166 // attempt to merge the 2 shuffles and suitably bitcast the inputs/output
27167 // back to their original types.
27168 if (N0.getOpcode() == ISD::BITCAST && N0.hasOneUse() &&
27169 N1.isUndef() && Level < AfterLegalizeVectorOps &&
27170 TLI.isTypeLegal(VT)) {
27171
27172 SDValue BC0 = peekThroughOneUseBitcasts(N0);
27173 if (BC0.getOpcode() == ISD::VECTOR_SHUFFLE && BC0.hasOneUse()) {
27174 EVT SVT = VT.getScalarType();
27175 EVT InnerVT = BC0->getValueType(0);
27176 EVT InnerSVT = InnerVT.getScalarType();
27177
27178 // Determine which shuffle works with the smaller scalar type.
27179 EVT ScaleVT = SVT.bitsLT(InnerSVT) ? VT : InnerVT;
27180 EVT ScaleSVT = ScaleVT.getScalarType();
27181
27182 if (TLI.isTypeLegal(ScaleVT) &&
27183 0 == (InnerSVT.getSizeInBits() % ScaleSVT.getSizeInBits()) &&
27184 0 == (SVT.getSizeInBits() % ScaleSVT.getSizeInBits())) {
27185 int InnerScale = InnerSVT.getSizeInBits() / ScaleSVT.getSizeInBits();
27186 int OuterScale = SVT.getSizeInBits() / ScaleSVT.getSizeInBits();
27187
27188 // Scale the shuffle masks to the smaller scalar type.
27189 ShuffleVectorSDNode *InnerSVN = cast<ShuffleVectorSDNode>(BC0);
27190 SmallVector<int, 8> InnerMask;
27191 SmallVector<int, 8> OuterMask;
27192 narrowShuffleMaskElts(InnerScale, InnerSVN->getMask(), InnerMask);
27193 narrowShuffleMaskElts(OuterScale, SVN->getMask(), OuterMask);
27194
27195 // Merge the shuffle masks.
27196 SmallVector<int, 8> NewMask;
27197 for (int M : OuterMask)
27198 NewMask.push_back(M < 0 ? -1 : InnerMask[M]);
27199
27200 // Test for shuffle mask legality over both commutations.
27201 SDValue SV0 = BC0->getOperand(0);
27202 SDValue SV1 = BC0->getOperand(1);
27203 bool LegalMask = TLI.isShuffleMaskLegal(NewMask, ScaleVT);
27204 if (!LegalMask) {
27205 std::swap(SV0, SV1);
27206 ShuffleVectorSDNode::commuteMask(NewMask);
27207 LegalMask = TLI.isShuffleMaskLegal(NewMask, ScaleVT);
27208 }
27209
27210 if (LegalMask) {
27211 SV0 = DAG.getBitcast(ScaleVT, SV0);
27212 SV1 = DAG.getBitcast(ScaleVT, SV1);
27213 return DAG.getBitcast(
27214 VT, DAG.getVectorShuffle(ScaleVT, SDLoc(N), SV0, SV1, NewMask));
27215 }
27216 }
27217 }
27218 }
27219
27220 // Match shuffles of bitcasts, so long as the mask can be treated as the
27221 // larger type.
27222 if (SDValue V = combineShuffleOfBitcast(SVN, DAG, TLI, LegalOperations))
27223 return V;
27224
27225 // Compute the combined shuffle mask for a shuffle with SV0 as the first
27226 // operand, and SV1 as the second operand.
27227 // i.e. Merge SVN(OtherSVN, N1) -> shuffle(SV0, SV1, Mask) iff Commute = false
27228 // Merge SVN(N1, OtherSVN) -> shuffle(SV0, SV1, Mask') iff Commute = true
27229 auto MergeInnerShuffle =
27230 [NumElts, &VT](bool Commute, ShuffleVectorSDNode *SVN,
27231 ShuffleVectorSDNode *OtherSVN, SDValue N1,
27232 const TargetLowering &TLI, SDValue &SV0, SDValue &SV1,
27233 SmallVectorImpl<int> &Mask) -> bool {
27234 // Don't try to fold splats; they're likely to simplify somehow, or they
27235 // might be free.
27236 if (OtherSVN->isSplat())
27237 return false;
27238
27239 SV0 = SV1 = SDValue();
27240 Mask.clear();
27241
27242 for (unsigned i = 0; i != NumElts; ++i) {
27243 int Idx = SVN->getMaskElt(i);
27244 if (Idx < 0) {
27245 // Propagate Undef.
27246 Mask.push_back(Idx);
27247 continue;
27248 }
27249
27250 if (Commute)
27251 Idx = (Idx < (int)NumElts) ? (Idx + NumElts) : (Idx - NumElts);
27252
27253 SDValue CurrentVec;
27254 if (Idx < (int)NumElts) {
27255 // This shuffle index refers to the inner shuffle N0. Lookup the inner
27256 // shuffle mask to identify which vector is actually referenced.
27257 Idx = OtherSVN->getMaskElt(Idx);
27258 if (Idx < 0) {
27259 // Propagate Undef.
27260 Mask.push_back(Idx);
27261 continue;
27262 }
27263 CurrentVec = (Idx < (int)NumElts) ? OtherSVN->getOperand(0)
27264 : OtherSVN->getOperand(1);
27265 } else {
27266 // This shuffle index references an element within N1.
27267 CurrentVec = N1;
27268 }
27269
27270 // Simple case where 'CurrentVec' is UNDEF.
27271 if (CurrentVec.isUndef()) {
27272 Mask.push_back(-1);
27273 continue;
27274 }
27275
27276 // Canonicalize the shuffle index. We don't know yet if CurrentVec
27277 // will be the first or second operand of the combined shuffle.
27278 Idx = Idx % NumElts;
27279 if (!SV0.getNode() || SV0 == CurrentVec) {
27280 // Ok. CurrentVec is the left hand side.
27281 // Update the mask accordingly.
27282 SV0 = CurrentVec;
27283 Mask.push_back(Idx);
27284 continue;
27285 }
27286 if (!SV1.getNode() || SV1 == CurrentVec) {
27287 // Ok. CurrentVec is the right hand side.
27288 // Update the mask accordingly.
27289 SV1 = CurrentVec;
27290 Mask.push_back(Idx + NumElts);
27291 continue;
27292 }
27293
27294 // Last chance - see if the vector is another shuffle and if it
27295 // uses one of the existing candidate shuffle ops.
27296 if (auto *CurrentSVN = dyn_cast<ShuffleVectorSDNode>(CurrentVec)) {
27297 int InnerIdx = CurrentSVN->getMaskElt(Idx);
27298 if (InnerIdx < 0) {
27299 Mask.push_back(-1);
27300 continue;
27301 }
27302 SDValue InnerVec = (InnerIdx < (int)NumElts)
27303 ? CurrentSVN->getOperand(0)
27304 : CurrentSVN->getOperand(1);
27305 if (InnerVec.isUndef()) {
27306 Mask.push_back(-1);
27307 continue;
27308 }
27309 InnerIdx %= NumElts;
27310 if (InnerVec == SV0) {
27311 Mask.push_back(InnerIdx);
27312 continue;
27313 }
27314 if (InnerVec == SV1) {
27315 Mask.push_back(InnerIdx + NumElts);
27316 continue;
27317 }
27318 }
27319
27320 // Bail out if we cannot convert the shuffle pair into a single shuffle.
27321 return false;
27322 }
27323
27324 if (llvm::all_of(Mask, [](int M) { return M < 0; }))
27325 return true;
27326
27327 // Avoid introducing shuffles with illegal mask.
27328 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2)
27329 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2)
27330 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2)
27331 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, A, M2)
27332 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, A, M2)
27333 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, B, M2)
27334 if (TLI.isShuffleMaskLegal(Mask, VT))
27335 return true;
27336
27337 std::swap(SV0, SV1);
27338 ShuffleVectorSDNode::commuteMask(Mask);
27339 return TLI.isShuffleMaskLegal(Mask, VT);
27340 };
27341
27342 if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) {
27343 // Canonicalize shuffles according to rules:
27344 // shuffle(A, shuffle(A, B)) -> shuffle(shuffle(A,B), A)
27345 // shuffle(B, shuffle(A, B)) -> shuffle(shuffle(A,B), B)
27346 // shuffle(B, shuffle(A, Undef)) -> shuffle(shuffle(A, Undef), B)
27347 if (N1.getOpcode() == ISD::VECTOR_SHUFFLE &&
27348 N0.getOpcode() != ISD::VECTOR_SHUFFLE) {
27349 // The incoming shuffle must be of the same type as the result of the
27350 // current shuffle.
27351 assert(N1->getOperand(0).getValueType() == VT &&
27352 "Shuffle types don't match");
27353
27354 SDValue SV0 = N1->getOperand(0);
27355 SDValue SV1 = N1->getOperand(1);
27356 bool HasSameOp0 = N0 == SV0;
27357 bool IsSV1Undef = SV1.isUndef();
27358 if (HasSameOp0 || IsSV1Undef || N0 == SV1)
27359 // Commute the operands of this shuffle so merging below will trigger.
27360 return DAG.getCommutedVectorShuffle(*SVN);
27361 }
27362
27363 // Canonicalize splat shuffles to the RHS to improve merging below.
27364 // shuffle(splat(A,u), shuffle(C,D)) -> shuffle'(shuffle(C,D), splat(A,u))
27365 if (N0.getOpcode() == ISD::VECTOR_SHUFFLE &&
27366 N1.getOpcode() == ISD::VECTOR_SHUFFLE &&
27367 cast<ShuffleVectorSDNode>(N0)->isSplat() &&
27368 !cast<ShuffleVectorSDNode>(N1)->isSplat()) {
27369 return DAG.getCommutedVectorShuffle(*SVN);
27370 }
27371
27372 // Try to fold according to rules:
27373 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2)
27374 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2)
27375 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2)
27376 // Don't try to fold shuffles with illegal type.
27377 // Only fold if this shuffle is the only user of the other shuffle.
    // Try matching shuffle(C,shuffle(A,B)) commuted patterns as well.
27379 for (int i = 0; i != 2; ++i) {
27380 if (N->getOperand(i).getOpcode() == ISD::VECTOR_SHUFFLE &&
27381 N->isOnlyUserOf(N->getOperand(i).getNode())) {
27382 // The incoming shuffle must be of the same type as the result of the
27383 // current shuffle.
27384 auto *OtherSV = cast<ShuffleVectorSDNode>(N->getOperand(i));
27385 assert(OtherSV->getOperand(0).getValueType() == VT &&
27386 "Shuffle types don't match");
27387
27388 SDValue SV0, SV1;
27389 SmallVector<int, 4> Mask;
27390 if (MergeInnerShuffle(i != 0, SVN, OtherSV, N->getOperand(1 - i), TLI,
27391 SV0, SV1, Mask)) {
27392 // Check if all indices in Mask are Undef. In case, propagate Undef.
27393 if (llvm::all_of(Mask, [](int M) { return M < 0; }))
27394 return DAG.getUNDEF(VT);
27395
27396 return DAG.getVectorShuffle(VT, SDLoc(N),
27397 SV0 ? SV0 : DAG.getUNDEF(VT),
27398 SV1 ? SV1 : DAG.getUNDEF(VT), Mask);
27399 }
27400 }
27401 }
27402
27403 // Merge shuffles through binops if we are able to merge it with at least
27404 // one other shuffles.
27405 // shuffle(bop(shuffle(x,y),shuffle(z,w)),undef)
27406 // shuffle(bop(shuffle(x,y),shuffle(z,w)),bop(shuffle(a,b),shuffle(c,d)))
27407 unsigned SrcOpcode = N0.getOpcode();
27408 if (TLI.isBinOp(SrcOpcode) && N->isOnlyUserOf(N0.getNode()) &&
27409 (N1.isUndef() ||
27410 (SrcOpcode == N1.getOpcode() && N->isOnlyUserOf(N1.getNode())))) {
27411 // Get binop source ops, or just pass on the undef.
27412 SDValue Op00 = N0.getOperand(0);
27413 SDValue Op01 = N0.getOperand(1);
27414 SDValue Op10 = N1.isUndef() ? N1 : N1.getOperand(0);
27415 SDValue Op11 = N1.isUndef() ? N1 : N1.getOperand(1);
27416 // TODO: We might be able to relax the VT check but we don't currently
27417 // have any isBinOp() that has different result/ops VTs so play safe until
27418 // we have test coverage.
27419 if (Op00.getValueType() == VT && Op10.getValueType() == VT &&
27420 Op01.getValueType() == VT && Op11.getValueType() == VT &&
27421 (Op00.getOpcode() == ISD::VECTOR_SHUFFLE ||
27422 Op10.getOpcode() == ISD::VECTOR_SHUFFLE ||
27423 Op01.getOpcode() == ISD::VECTOR_SHUFFLE ||
27424 Op11.getOpcode() == ISD::VECTOR_SHUFFLE)) {
27425 auto CanMergeInnerShuffle = [&](SDValue &SV0, SDValue &SV1,
27426 SmallVectorImpl<int> &Mask, bool LeftOp,
27427 bool Commute) {
27428 SDValue InnerN = Commute ? N1 : N0;
27429 SDValue Op0 = LeftOp ? Op00 : Op01;
27430 SDValue Op1 = LeftOp ? Op10 : Op11;
27431 if (Commute)
27432 std::swap(Op0, Op1);
27433 // Only accept the merged shuffle if we don't introduce undef elements,
27434 // or the inner shuffle already contained undef elements.
27435 auto *SVN0 = dyn_cast<ShuffleVectorSDNode>(Op0);
27436 return SVN0 && InnerN->isOnlyUserOf(SVN0) &&
27437 MergeInnerShuffle(Commute, SVN, SVN0, Op1, TLI, SV0, SV1,
27438 Mask) &&
27439 (llvm::any_of(SVN0->getMask(), [](int M) { return M < 0; }) ||
27440 llvm::none_of(Mask, [](int M) { return M < 0; }));
27441 };
27442
27443 // Ensure we don't increase the number of shuffles - we must merge a
27444 // shuffle from at least one of the LHS and RHS ops.
27445 bool MergedLeft = false;
27446 SDValue LeftSV0, LeftSV1;
27447 SmallVector<int, 4> LeftMask;
27448 if (CanMergeInnerShuffle(LeftSV0, LeftSV1, LeftMask, true, false) ||
27449 CanMergeInnerShuffle(LeftSV0, LeftSV1, LeftMask, true, true)) {
27450 MergedLeft = true;
27451 } else {
27452 LeftMask.assign(SVN->getMask().begin(), SVN->getMask().end());
27453 LeftSV0 = Op00, LeftSV1 = Op10;
27454 }
27455
27456 bool MergedRight = false;
27457 SDValue RightSV0, RightSV1;
27458 SmallVector<int, 4> RightMask;
27459 if (CanMergeInnerShuffle(RightSV0, RightSV1, RightMask, false, false) ||
27460 CanMergeInnerShuffle(RightSV0, RightSV1, RightMask, false, true)) {
27461 MergedRight = true;
27462 } else {
27463 RightMask.assign(SVN->getMask().begin(), SVN->getMask().end());
27464 RightSV0 = Op01, RightSV1 = Op11;
27465 }
27466
27467 if (MergedLeft || MergedRight) {
27468 SDLoc DL(N);
27469 SDValue LHS = DAG.getVectorShuffle(
27470 VT, DL, LeftSV0 ? LeftSV0 : DAG.getUNDEF(VT),
27471 LeftSV1 ? LeftSV1 : DAG.getUNDEF(VT), LeftMask);
27472 SDValue RHS = DAG.getVectorShuffle(
27473 VT, DL, RightSV0 ? RightSV0 : DAG.getUNDEF(VT),
27474 RightSV1 ? RightSV1 : DAG.getUNDEF(VT), RightMask);
27475 return DAG.getNode(SrcOpcode, DL, VT, LHS, RHS);
27476 }
27477 }
27478 }
27479 }
27480
27481 if (SDValue V = foldShuffleOfConcatUndefs(SVN, DAG))
27482 return V;
27483
27484 // Match shuffles that can be converted to ISD::ZERO_EXTEND_VECTOR_INREG.
27485 // Perform this really late, because it could eliminate knowledge
27486 // of undef elements created by this shuffle.
27487 if (Level < AfterLegalizeTypes)
27488 if (SDValue V = combineShuffleToZeroExtendVectorInReg(SVN, DAG, TLI,
27489 LegalOperations))
27490 return V;
27491
27492 return SDValue();
27493 }
27494
visitSCALAR_TO_VECTOR(SDNode * N)27495 SDValue DAGCombiner::visitSCALAR_TO_VECTOR(SDNode *N) {
27496 EVT VT = N->getValueType(0);
27497 if (!VT.isFixedLengthVector())
27498 return SDValue();
27499
27500 // Try to convert a scalar binop with an extracted vector element to a vector
27501 // binop. This is intended to reduce potentially expensive register moves.
27502 // TODO: Check if both operands are extracted.
27503 // TODO: How to prefer scalar/vector ops with multiple uses of the extact?
27504 // TODO: Generalize this, so it can be called from visitINSERT_VECTOR_ELT().
27505 SDValue Scalar = N->getOperand(0);
27506 unsigned Opcode = Scalar.getOpcode();
27507 EVT VecEltVT = VT.getScalarType();
27508 if (Scalar.hasOneUse() && Scalar->getNumValues() == 1 &&
27509 TLI.isBinOp(Opcode) && Scalar.getValueType() == VecEltVT &&
27510 Scalar.getOperand(0).getValueType() == VecEltVT &&
27511 Scalar.getOperand(1).getValueType() == VecEltVT &&
27512 Scalar->isOnlyUserOf(Scalar.getOperand(0).getNode()) &&
27513 Scalar->isOnlyUserOf(Scalar.getOperand(1).getNode()) &&
27514 DAG.isSafeToSpeculativelyExecute(Opcode) && hasOperation(Opcode, VT)) {
27515 // Match an extract element and get a shuffle mask equivalent.
27516 SmallVector<int, 8> ShufMask(VT.getVectorNumElements(), -1);
27517
27518 for (int i : {0, 1}) {
27519 // s2v (bo (extelt V, Idx), C) --> shuffle (bo V, C'), {Idx, -1, -1...}
27520 // s2v (bo C, (extelt V, Idx)) --> shuffle (bo C', V), {Idx, -1, -1...}
27521 SDValue EE = Scalar.getOperand(i);
27522 auto *C = dyn_cast<ConstantSDNode>(Scalar.getOperand(i ? 0 : 1));
27523 if (C && EE.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
27524 EE.getOperand(0).getValueType() == VT &&
27525 isa<ConstantSDNode>(EE.getOperand(1))) {
27526 // Mask = {ExtractIndex, undef, undef....}
27527 ShufMask[0] = EE.getConstantOperandVal(1);
27528 // Make sure the shuffle is legal if we are crossing lanes.
27529 if (TLI.isShuffleMaskLegal(ShufMask, VT)) {
27530 SDLoc DL(N);
27531 SDValue V[] = {EE.getOperand(0),
27532 DAG.getConstant(C->getAPIntValue(), DL, VT)};
27533 SDValue VecBO = DAG.getNode(Opcode, DL, VT, V[i], V[1 - i]);
27534 return DAG.getVectorShuffle(VT, DL, VecBO, DAG.getUNDEF(VT),
27535 ShufMask);
27536 }
27537 }
27538 }
27539 }
27540
27541 // Replace a SCALAR_TO_VECTOR(EXTRACT_VECTOR_ELT(V,C0)) pattern
27542 // with a VECTOR_SHUFFLE and possible truncate.
27543 if (Opcode != ISD::EXTRACT_VECTOR_ELT ||
27544 !Scalar.getOperand(0).getValueType().isFixedLengthVector())
27545 return SDValue();
27546
27547 // If we have an implicit truncate, truncate here if it is legal.
27548 if (VecEltVT != Scalar.getValueType() &&
27549 Scalar.getValueType().isScalarInteger() && isTypeLegal(VecEltVT)) {
27550 SDValue Val = DAG.getNode(ISD::TRUNCATE, SDLoc(Scalar), VecEltVT, Scalar);
27551 return DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), VT, Val);
27552 }
27553
27554 auto *ExtIndexC = dyn_cast<ConstantSDNode>(Scalar.getOperand(1));
27555 if (!ExtIndexC)
27556 return SDValue();
27557
27558 SDValue SrcVec = Scalar.getOperand(0);
27559 EVT SrcVT = SrcVec.getValueType();
27560 unsigned SrcNumElts = SrcVT.getVectorNumElements();
27561 unsigned VTNumElts = VT.getVectorNumElements();
27562 if (VecEltVT == SrcVT.getScalarType() && VTNumElts <= SrcNumElts) {
27563 // Create a shuffle equivalent for scalar-to-vector: {ExtIndex, -1, -1, ...}
27564 SmallVector<int, 8> Mask(SrcNumElts, -1);
27565 Mask[0] = ExtIndexC->getZExtValue();
27566 SDValue LegalShuffle = TLI.buildLegalVectorShuffle(
27567 SrcVT, SDLoc(N), SrcVec, DAG.getUNDEF(SrcVT), Mask, DAG);
27568 if (!LegalShuffle)
27569 return SDValue();
27570
27571 // If the initial vector is the same size, the shuffle is the result.
27572 if (VT == SrcVT)
27573 return LegalShuffle;
27574
27575 // If not, shorten the shuffled vector.
27576 if (VTNumElts != SrcNumElts) {
27577 SDValue ZeroIdx = DAG.getVectorIdxConstant(0, SDLoc(N));
27578 EVT SubVT = EVT::getVectorVT(*DAG.getContext(),
27579 SrcVT.getVectorElementType(), VTNumElts);
27580 return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), SubVT, LegalShuffle,
27581 ZeroIdx);
27582 }
27583 }
27584
27585 return SDValue();
27586 }
27587
/// Combine an INSERT_SUBVECTOR node. Operand 0 (N0) is the vector being
/// inserted into, operand 1 (N1) is the inserted subvector, and operand 2 (N2)
/// is the insertion index, which is read here as a constant (InsIdx).
///
/// The folds below are tried in order and the first one that matches wins.
/// Returns the replacement value, SDValue(N, 0) when
/// SimplifyDemandedVectorElts returns true, or an empty SDValue when no fold
/// applies.
SDValue DAGCombiner::visitINSERT_SUBVECTOR(SDNode *N) {
  EVT VT = N->getValueType(0);
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  SDValue N2 = N->getOperand(2);
  uint64_t InsIdx = N->getConstantOperandVal(2);

  // If inserting an UNDEF, just return the original vector.
  if (N1.isUndef())
    return N0;

  // If this is an insert of an extracted vector into an undef vector, we can
  // just use the input to the extract if the types match, and can simplify
  // in some cases even if they don't.
  if (N0.isUndef() && N1.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
      N1.getOperand(1) == N2) {
    EVT SrcVT = N1.getOperand(0).getValueType();
    if (SrcVT == VT)
      return N1.getOperand(0);
    // TODO: To remove the zero check, need to adjust the offset to
    // a multiple of the new src type.
    if (isNullConstant(N2)) {
      // The extract source is no wider than the result: insert it directly,
      // unless that would insert a scalable vector into a fixed-length one.
      if (VT.knownBitsGE(SrcVT) &&
          !(VT.isFixedLengthVector() && SrcVT.isScalableVector()))
        return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N),
                           VT, N0, N1.getOperand(0), N2);
      // The extract source is no narrower than the result: extract the result
      // type directly, unless that would extract a scalable vector from a
      // fixed-length one.
      else if (VT.knownBitsLE(SrcVT) &&
               !(VT.isScalableVector() && SrcVT.isFixedLengthVector()))
        return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N),
                           VT, N1.getOperand(0), N2);
    }
  }

  // Handle case where we've ended up inserting back into the source vector
  // we extracted the subvector from.
  // insert_subvector(N0, extract_subvector(N0, N2), N2) --> N0
  if (N1.getOpcode() == ISD::EXTRACT_SUBVECTOR && N1.getOperand(0) == N0 &&
      N1.getOperand(1) == N2)
    return N0;

  // Simplify splat inserts into an undef vector:
  // insert_subvector undef, (splat X), N2 -> splat X
  // Only done when X is a constant or the inner splat has a single use.
  if (N0.isUndef() && N1.getOpcode() == ISD::SPLAT_VECTOR)
    if (DAG.isConstantValueOfAnyType(N1.getOperand(0)) || N1.hasOneUse())
      return DAG.getNode(ISD::SPLAT_VECTOR, SDLoc(N), VT, N1.getOperand(0));

  // insert_subvector (splat X), (splat X), N2 -> splat X
  if (N0.getOpcode() == ISD::SPLAT_VECTOR && N0.getOpcode() == N1.getOpcode() &&
      N0.getOperand(0) == N1.getOperand(0))
    return N0;

  // If we are inserting a bitcast of an extracted subvector into an undef,
  // and the extract source has the same element count and total size as the
  // result, just bitcast the input of the extract.
  // i.e. INSERT_SUBVECTOR UNDEF (BITCAST (EXTRACT_SUBVECTOR Src N2)) N2 ->
  //        BITCAST Src
  if (N0.isUndef() && N1.getOpcode() == ISD::BITCAST &&
      N1.getOperand(0).getOpcode() == ISD::EXTRACT_SUBVECTOR &&
      N1.getOperand(0).getOperand(1) == N2 &&
      N1.getOperand(0).getOperand(0).getValueType().getVectorElementCount() ==
          VT.getVectorElementCount() &&
      N1.getOperand(0).getOperand(0).getValueType().getSizeInBits() ==
          VT.getSizeInBits()) {
    return DAG.getBitcast(VT, N1.getOperand(0).getOperand(0));
  }

  // If both N0 and N1 are bitcast values on which insert_subvector
  // would make sense, pull the bitcast through.
  // i.e. INSERT_SUBVECTOR (BITCAST N0) (BITCAST N1) N2 ->
  //        BITCAST (INSERT_SUBVECTOR N0 N1 N2)
  if (N0.getOpcode() == ISD::BITCAST && N1.getOpcode() == ISD::BITCAST) {
    SDValue CN0 = N0.getOperand(0);
    SDValue CN1 = N1.getOperand(0);
    EVT CN0VT = CN0.getValueType();
    EVT CN1VT = CN1.getValueType();
    // The index N2 is reused unchanged, so the bitcast sources must share an
    // element type and CN0 must have the same element count as the result.
    if (CN0VT.isVector() && CN1VT.isVector() &&
        CN0VT.getVectorElementType() == CN1VT.getVectorElementType() &&
        CN0VT.getVectorElementCount() == VT.getVectorElementCount()) {
      SDValue NewINSERT = DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N),
                                      CN0.getValueType(), CN0, CN1, N2);
      return DAG.getBitcast(VT, NewINSERT);
    }
  }

  // Combine INSERT_SUBVECTORs where we are inserting to the same index.
  // INSERT_SUBVECTOR( INSERT_SUBVECTOR( Vec, SubOld, Idx ), SubNew, Idx )
  // --> INSERT_SUBVECTOR( Vec, SubNew, Idx )
  if (N0.getOpcode() == ISD::INSERT_SUBVECTOR &&
      N0.getOperand(1).getValueType() == N1.getValueType() &&
      N0.getOperand(2) == N2)
    return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT, N0.getOperand(0),
                       N1, N2);

  // Eliminate an intermediate insert into an undef vector:
  // insert_subvector undef, (insert_subvector undef, X, 0), 0 -->
  //   insert_subvector undef, X, 0
  if (N0.isUndef() && N1.getOpcode() == ISD::INSERT_SUBVECTOR &&
      N1.getOperand(0).isUndef() && isNullConstant(N1.getOperand(2)) &&
      isNullConstant(N2))
    return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT, N0,
                       N1.getOperand(1), N2);

  // Push subvector bitcasts to the output, adjusting the index as we go.
  // insert_subvector(bitcast(v), bitcast(s), c1)
  // -> bitcast(insert_subvector(v, s, c2))
  if ((N0.isUndef() || N0.getOpcode() == ISD::BITCAST) &&
      N1.getOpcode() == ISD::BITCAST) {
    SDValue N0Src = peekThroughBitcasts(N0);
    SDValue N1Src = peekThroughBitcasts(N1);
    EVT N0SrcSVT = N0Src.getValueType().getScalarType();
    EVT N1SrcSVT = N1Src.getValueType().getScalarType();
    if ((N0.isUndef() || N0SrcSVT == N1SrcSVT) &&
        N0Src.getValueType().isVector() && N1Src.getValueType().isVector()) {
      EVT NewVT;
      SDLoc DL(N);
      // NewIdx stays empty unless the index can be rescaled exactly; that is
      // what gates the rewrite below.
      SDValue NewIdx;
      LLVMContext &Ctx = *DAG.getContext();
      ElementCount NumElts = VT.getVectorElementCount();
      unsigned EltSizeInBits = VT.getScalarSizeInBits();
      if ((EltSizeInBits % N1SrcSVT.getSizeInBits()) == 0) {
        // Source elements are narrower (or equal): each result element covers
        // Scale source elements, so scale the count and the index up.
        unsigned Scale = EltSizeInBits / N1SrcSVT.getSizeInBits();
        NewVT = EVT::getVectorVT(Ctx, N1SrcSVT, NumElts * Scale);
        NewIdx = DAG.getVectorIdxConstant(InsIdx * Scale, DL);
      } else if ((N1SrcSVT.getSizeInBits() % EltSizeInBits) == 0) {
        // Source elements are wider: only valid when both the element count
        // and the index divide evenly by Scale.
        unsigned Scale = N1SrcSVT.getSizeInBits() / EltSizeInBits;
        if (NumElts.isKnownMultipleOf(Scale) && (InsIdx % Scale) == 0) {
          NewVT = EVT::getVectorVT(Ctx, N1SrcSVT,
                                   NumElts.divideCoefficientBy(Scale));
          NewIdx = DAG.getVectorIdxConstant(InsIdx / Scale, DL);
        }
      }
      if (NewIdx && hasOperation(ISD::INSERT_SUBVECTOR, NewVT)) {
        SDValue Res = DAG.getBitcast(NewVT, N0Src);
        Res = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, NewVT, Res, N1Src, NewIdx);
        return DAG.getBitcast(VT, Res);
      }
    }
  }

  // Canonicalize insert_subvector dag nodes so that, for a single-use inner
  // insert of the same subvector type, the insert with the smaller index is
  // the inner one.
  // Example:
  // (insert_subvector (insert_subvector A, Idx0), Idx1)
  // -> (insert_subvector (insert_subvector A, Idx1), Idx0)
  if (N0.getOpcode() == ISD::INSERT_SUBVECTOR && N0.hasOneUse() &&
      N1.getValueType() == N0.getOperand(1).getValueType()) {
    unsigned OtherIdx = N0.getConstantOperandVal(2);
    if (InsIdx < OtherIdx) {
      // Swap nodes.
      SDValue NewOp = DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT,
                                  N0.getOperand(0), N1, N2);
      AddToWorklist(NewOp.getNode());
      return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N0.getNode()),
                         VT, NewOp, N0.getOperand(1), N0.getOperand(2));
    }
  }

  // If the input vector is a concatenation, and the insert replaces
  // one of the pieces, we can optimize into a single concat_vectors.
  if (N0.getOpcode() == ISD::CONCAT_VECTORS && N0.hasOneUse() &&
      N0.getOperand(0).getValueType() == N1.getValueType() &&
      N0.getOperand(0).getValueType().isScalableVector() ==
          N1.getValueType().isScalableVector()) {
    // Each concat operand holds Factor (minimum) elements, so the piece being
    // replaced is operand InsIdx / Factor.
    unsigned Factor = N1.getValueType().getVectorMinNumElements();
    SmallVector<SDValue, 8> Ops(N0->ops());
    Ops[InsIdx / Factor] = N1;
    return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops);
  }

  // Simplify source operands based on insertion.
  if (SimplifyDemandedVectorElts(SDValue(N, 0)))
    return SDValue(N, 0);

  return SDValue();
}
27761
visitFP_TO_FP16(SDNode * N)27762 SDValue DAGCombiner::visitFP_TO_FP16(SDNode *N) {
27763 SDValue N0 = N->getOperand(0);
27764
27765 // fold (fp_to_fp16 (fp16_to_fp op)) -> op
27766 if (N0->getOpcode() == ISD::FP16_TO_FP)
27767 return N0->getOperand(0);
27768
27769 return SDValue();
27770 }
27771
visitFP16_TO_FP(SDNode * N)27772 SDValue DAGCombiner::visitFP16_TO_FP(SDNode *N) {
27773 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
27774 auto Op = N->getOpcode();
27775 assert((Op == ISD::FP16_TO_FP || Op == ISD::BF16_TO_FP) &&
27776 "opcode should be FP16_TO_FP or BF16_TO_FP.");
27777 SDValue N0 = N->getOperand(0);
27778
27779 // fold fp16_to_fp(op & 0xffff) -> fp16_to_fp(op) or
27780 // fold bf16_to_fp(op & 0xffff) -> bf16_to_fp(op)
27781 if (!TLI.shouldKeepZExtForFP16Conv() && N0->getOpcode() == ISD::AND) {
27782 ConstantSDNode *AndConst = getAsNonOpaqueConstant(N0.getOperand(1));
27783 if (AndConst && AndConst->getAPIntValue() == 0xffff) {
27784 return DAG.getNode(Op, SDLoc(N), N->getValueType(0), N0.getOperand(0));
27785 }
27786 }
27787
27788 if (SDValue CastEliminated = eliminateFPCastPair(N))
27789 return CastEliminated;
27790
27791 // Sometimes constants manage to survive very late in the pipeline, e.g.,
27792 // because they are wrapped inside the <1 x f16> type. Try one last time to
27793 // get rid of them.
27794 SDValue Folded = DAG.FoldConstantArithmetic(N->getOpcode(), SDLoc(N),
27795 N->getValueType(0), {N0});
27796 return Folded;
27797 }
27798
visitFP_TO_BF16(SDNode * N)27799 SDValue DAGCombiner::visitFP_TO_BF16(SDNode *N) {
27800 SDValue N0 = N->getOperand(0);
27801
27802 // fold (fp_to_bf16 (bf16_to_fp op)) -> op
27803 if (N0->getOpcode() == ISD::BF16_TO_FP)
27804 return N0->getOperand(0);
27805
27806 return SDValue();
27807 }
27808
visitBF16_TO_FP(SDNode * N)27809 SDValue DAGCombiner::visitBF16_TO_FP(SDNode *N) {
27810 // fold bf16_to_fp(op & 0xffff) -> bf16_to_fp(op)
27811 return visitFP16_TO_FP(N);
27812 }
27813
visitVECREDUCE(SDNode * N)27814 SDValue DAGCombiner::visitVECREDUCE(SDNode *N) {
27815 SDValue N0 = N->getOperand(0);
27816 EVT VT = N0.getValueType();
27817 unsigned Opcode = N->getOpcode();
27818
27819 // VECREDUCE over 1-element vector is just an extract.
27820 if (VT.getVectorElementCount().isScalar()) {
27821 SDLoc dl(N);
27822 SDValue Res =
27823 DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT.getVectorElementType(), N0,
27824 DAG.getVectorIdxConstant(0, dl));
27825 if (Res.getValueType() != N->getValueType(0))
27826 Res = DAG.getNode(ISD::ANY_EXTEND, dl, N->getValueType(0), Res);
27827 return Res;
27828 }
27829
27830 // On an boolean vector an and/or reduction is the same as a umin/umax
27831 // reduction. Convert them if the latter is legal while the former isn't.
27832 if (Opcode == ISD::VECREDUCE_AND || Opcode == ISD::VECREDUCE_OR) {
27833 unsigned NewOpcode = Opcode == ISD::VECREDUCE_AND
27834 ? ISD::VECREDUCE_UMIN : ISD::VECREDUCE_UMAX;
27835 if (!TLI.isOperationLegalOrCustom(Opcode, VT) &&
27836 TLI.isOperationLegalOrCustom(NewOpcode, VT) &&
27837 DAG.ComputeNumSignBits(N0) == VT.getScalarSizeInBits())
27838 return DAG.getNode(NewOpcode, SDLoc(N), N->getValueType(0), N0);
27839 }
27840
27841 // vecreduce_or(insert_subvector(zero or undef, val)) -> vecreduce_or(val)
27842 // vecreduce_and(insert_subvector(ones or undef, val)) -> vecreduce_and(val)
27843 if (N0.getOpcode() == ISD::INSERT_SUBVECTOR &&
27844 TLI.isTypeLegal(N0.getOperand(1).getValueType())) {
27845 SDValue Vec = N0.getOperand(0);
27846 SDValue Subvec = N0.getOperand(1);
27847 if ((Opcode == ISD::VECREDUCE_OR &&
27848 (N0.getOperand(0).isUndef() || isNullOrNullSplat(Vec))) ||
27849 (Opcode == ISD::VECREDUCE_AND &&
27850 (N0.getOperand(0).isUndef() || isAllOnesOrAllOnesSplat(Vec))))
27851 return DAG.getNode(Opcode, SDLoc(N), N->getValueType(0), Subvec);
27852 }
27853
27854 // vecreduce_or(sext(x)) -> sext(vecreduce_or(x))
27855 // Same for zext and anyext, and for and/or/xor reductions.
27856 if ((Opcode == ISD::VECREDUCE_OR || Opcode == ISD::VECREDUCE_AND ||
27857 Opcode == ISD::VECREDUCE_XOR) &&
27858 (N0.getOpcode() == ISD::SIGN_EXTEND ||
27859 N0.getOpcode() == ISD::ZERO_EXTEND ||
27860 N0.getOpcode() == ISD::ANY_EXTEND) &&
27861 TLI.isOperationLegalOrCustom(Opcode, N0.getOperand(0).getValueType())) {
27862 SDValue Red = DAG.getNode(Opcode, SDLoc(N),
27863 N0.getOperand(0).getValueType().getScalarType(),
27864 N0.getOperand(0));
27865 return DAG.getNode(N0.getOpcode(), SDLoc(N), N->getValueType(0), Red);
27866 }
27867 return SDValue();
27868 }
27869
visitVP_FSUB(SDNode * N)27870 SDValue DAGCombiner::visitVP_FSUB(SDNode *N) {
27871 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
27872
27873 // FSUB -> FMA combines:
27874 if (SDValue Fused = visitFSUBForFMACombine<VPMatchContext>(N)) {
27875 AddToWorklist(Fused.getNode());
27876 return Fused;
27877 }
27878 return SDValue();
27879 }
27880
visitVPOp(SDNode * N)27881 SDValue DAGCombiner::visitVPOp(SDNode *N) {
27882
27883 if (N->getOpcode() == ISD::VP_GATHER)
27884 if (SDValue SD = visitVPGATHER(N))
27885 return SD;
27886
27887 if (N->getOpcode() == ISD::VP_SCATTER)
27888 if (SDValue SD = visitVPSCATTER(N))
27889 return SD;
27890
27891 if (N->getOpcode() == ISD::EXPERIMENTAL_VP_STRIDED_LOAD)
27892 if (SDValue SD = visitVP_STRIDED_LOAD(N))
27893 return SD;
27894
27895 if (N->getOpcode() == ISD::EXPERIMENTAL_VP_STRIDED_STORE)
27896 if (SDValue SD = visitVP_STRIDED_STORE(N))
27897 return SD;
27898
27899 // VP operations in which all vector elements are disabled - either by
27900 // determining that the mask is all false or that the EVL is 0 - can be
27901 // eliminated.
27902 bool AreAllEltsDisabled = false;
27903 if (auto EVLIdx = ISD::getVPExplicitVectorLengthIdx(N->getOpcode()))
27904 AreAllEltsDisabled |= isNullConstant(N->getOperand(*EVLIdx));
27905 if (auto MaskIdx = ISD::getVPMaskIdx(N->getOpcode()))
27906 AreAllEltsDisabled |=
27907 ISD::isConstantSplatVectorAllZeros(N->getOperand(*MaskIdx).getNode());
27908
27909 // This is the only generic VP combine we support for now.
27910 if (!AreAllEltsDisabled) {
27911 switch (N->getOpcode()) {
27912 case ISD::VP_FADD:
27913 return visitVP_FADD(N);
27914 case ISD::VP_FSUB:
27915 return visitVP_FSUB(N);
27916 case ISD::VP_FMA:
27917 return visitFMA<VPMatchContext>(N);
27918 case ISD::VP_SELECT:
27919 return visitVP_SELECT(N);
27920 case ISD::VP_MUL:
27921 return visitMUL<VPMatchContext>(N);
27922 case ISD::VP_SUB:
27923 return foldSubCtlzNot<VPMatchContext>(N, DAG);
27924 default:
27925 break;
27926 }
27927 return SDValue();
27928 }
27929
27930 // Binary operations can be replaced by UNDEF.
27931 if (ISD::isVPBinaryOp(N->getOpcode()))
27932 return DAG.getUNDEF(N->getValueType(0));
27933
27934 // VP Memory operations can be replaced by either the chain (stores) or the
27935 // chain + undef (loads).
27936 if (const auto *MemSD = dyn_cast<MemSDNode>(N)) {
27937 if (MemSD->writeMem())
27938 return MemSD->getChain();
27939 return CombineTo(N, DAG.getUNDEF(N->getValueType(0)), MemSD->getChain());
27940 }
27941
27942 // Reduction operations return the start operand when no elements are active.
27943 if (ISD::isVPReduction(N->getOpcode()))
27944 return N->getOperand(0);
27945
27946 return SDValue();
27947 }
27948
visitGET_FPENV_MEM(SDNode * N)27949 SDValue DAGCombiner::visitGET_FPENV_MEM(SDNode *N) {
27950 SDValue Chain = N->getOperand(0);
27951 SDValue Ptr = N->getOperand(1);
27952 EVT MemVT = cast<FPStateAccessSDNode>(N)->getMemoryVT();
27953
27954 // Check if the memory, where FP state is written to, is used only in a single
27955 // load operation.
27956 LoadSDNode *LdNode = nullptr;
27957 for (auto *U : Ptr->users()) {
27958 if (U == N)
27959 continue;
27960 if (auto *Ld = dyn_cast<LoadSDNode>(U)) {
27961 if (LdNode && LdNode != Ld)
27962 return SDValue();
27963 LdNode = Ld;
27964 continue;
27965 }
27966 return SDValue();
27967 }
27968 if (!LdNode || !LdNode->isSimple() || LdNode->isIndexed() ||
27969 !LdNode->getOffset().isUndef() || LdNode->getMemoryVT() != MemVT ||
27970 !LdNode->getChain().reachesChainWithoutSideEffects(SDValue(N, 0)))
27971 return SDValue();
27972
27973 // Check if the loaded value is used only in a store operation.
27974 StoreSDNode *StNode = nullptr;
27975 for (SDUse &U : LdNode->uses()) {
27976 if (U.getResNo() == 0) {
27977 if (auto *St = dyn_cast<StoreSDNode>(U.getUser())) {
27978 if (StNode)
27979 return SDValue();
27980 StNode = St;
27981 } else {
27982 return SDValue();
27983 }
27984 }
27985 }
27986 if (!StNode || !StNode->isSimple() || StNode->isIndexed() ||
27987 !StNode->getOffset().isUndef() || StNode->getMemoryVT() != MemVT ||
27988 !StNode->getChain().reachesChainWithoutSideEffects(SDValue(LdNode, 1)))
27989 return SDValue();
27990
27991 // Create new node GET_FPENV_MEM, which uses the store address to write FP
27992 // environment.
27993 SDValue Res = DAG.getGetFPEnv(Chain, SDLoc(N), StNode->getBasePtr(), MemVT,
27994 StNode->getMemOperand());
27995 CombineTo(StNode, Res, false);
27996 return Res;
27997 }
27998
visitSET_FPENV_MEM(SDNode * N)27999 SDValue DAGCombiner::visitSET_FPENV_MEM(SDNode *N) {
28000 SDValue Chain = N->getOperand(0);
28001 SDValue Ptr = N->getOperand(1);
28002 EVT MemVT = cast<FPStateAccessSDNode>(N)->getMemoryVT();
28003
28004 // Check if the address of FP state is used also in a store operation only.
28005 StoreSDNode *StNode = nullptr;
28006 for (auto *U : Ptr->users()) {
28007 if (U == N)
28008 continue;
28009 if (auto *St = dyn_cast<StoreSDNode>(U)) {
28010 if (StNode && StNode != St)
28011 return SDValue();
28012 StNode = St;
28013 continue;
28014 }
28015 return SDValue();
28016 }
28017 if (!StNode || !StNode->isSimple() || StNode->isIndexed() ||
28018 !StNode->getOffset().isUndef() || StNode->getMemoryVT() != MemVT ||
28019 !Chain.reachesChainWithoutSideEffects(SDValue(StNode, 0)))
28020 return SDValue();
28021
28022 // Check if the stored value is loaded from some location and the loaded
28023 // value is used only in the store operation.
28024 SDValue StValue = StNode->getValue();
28025 auto *LdNode = dyn_cast<LoadSDNode>(StValue);
28026 if (!LdNode || !LdNode->isSimple() || LdNode->isIndexed() ||
28027 !LdNode->getOffset().isUndef() || LdNode->getMemoryVT() != MemVT ||
28028 !StNode->getChain().reachesChainWithoutSideEffects(SDValue(LdNode, 1)))
28029 return SDValue();
28030
28031 // Create new node SET_FPENV_MEM, which uses the load address to read FP
28032 // environment.
28033 SDValue Res =
28034 DAG.getSetFPEnv(LdNode->getChain(), SDLoc(N), LdNode->getBasePtr(), MemVT,
28035 LdNode->getMemOperand());
28036 return Res;
28037 }
28038
28039 /// Returns a vector_shuffle if it able to transform an AND to a vector_shuffle
28040 /// with the destination vector and a zero vector.
28041 /// e.g. AND V, <0xffffffff, 0, 0xffffffff, 0>. ==>
28042 /// vector_shuffle V, Zero, <0, 4, 2, 4>
XformToShuffleWithZero(SDNode * N)28043 SDValue DAGCombiner::XformToShuffleWithZero(SDNode *N) {
28044 assert(N->getOpcode() == ISD::AND && "Unexpected opcode!");
28045
28046 EVT VT = N->getValueType(0);
28047 SDValue LHS = N->getOperand(0);
28048 SDValue RHS = peekThroughBitcasts(N->getOperand(1));
28049 SDLoc DL(N);
28050
28051 // Make sure we're not running after operation legalization where it
28052 // may have custom lowered the vector shuffles.
28053 if (LegalOperations)
28054 return SDValue();
28055
28056 if (RHS.getOpcode() != ISD::BUILD_VECTOR)
28057 return SDValue();
28058
28059 EVT RVT = RHS.getValueType();
28060 unsigned NumElts = RHS.getNumOperands();
28061
28062 // Attempt to create a valid clear mask, splitting the mask into
28063 // sub elements and checking to see if each is
28064 // all zeros or all ones - suitable for shuffle masking.
28065 auto BuildClearMask = [&](int Split) {
28066 int NumSubElts = NumElts * Split;
28067 int NumSubBits = RVT.getScalarSizeInBits() / Split;
28068
28069 SmallVector<int, 8> Indices;
28070 for (int i = 0; i != NumSubElts; ++i) {
28071 int EltIdx = i / Split;
28072 int SubIdx = i % Split;
28073 SDValue Elt = RHS.getOperand(EltIdx);
28074 // X & undef --> 0 (not undef). So this lane must be converted to choose
28075 // from the zero constant vector (same as if the element had all 0-bits).
28076 if (Elt.isUndef()) {
28077 Indices.push_back(i + NumSubElts);
28078 continue;
28079 }
28080
28081 std::optional<APInt> Bits = Elt->bitcastToAPInt();
28082 if (!Bits)
28083 return SDValue();
28084
28085 // Extract the sub element from the constant bit mask.
28086 if (DAG.getDataLayout().isBigEndian())
28087 *Bits =
28088 Bits->extractBits(NumSubBits, (Split - SubIdx - 1) * NumSubBits);
28089 else
28090 *Bits = Bits->extractBits(NumSubBits, SubIdx * NumSubBits);
28091
28092 if (Bits->isAllOnes())
28093 Indices.push_back(i);
28094 else if (*Bits == 0)
28095 Indices.push_back(i + NumSubElts);
28096 else
28097 return SDValue();
28098 }
28099
28100 // Let's see if the target supports this vector_shuffle.
28101 EVT ClearSVT = EVT::getIntegerVT(*DAG.getContext(), NumSubBits);
28102 EVT ClearVT = EVT::getVectorVT(*DAG.getContext(), ClearSVT, NumSubElts);
28103 if (!TLI.isVectorClearMaskLegal(Indices, ClearVT))
28104 return SDValue();
28105
28106 SDValue Zero = DAG.getConstant(0, DL, ClearVT);
28107 return DAG.getBitcast(VT, DAG.getVectorShuffle(ClearVT, DL,
28108 DAG.getBitcast(ClearVT, LHS),
28109 Zero, Indices));
28110 };
28111
28112 // Determine maximum split level (byte level masking).
28113 int MaxSplit = 1;
28114 if (RVT.getScalarSizeInBits() % 8 == 0)
28115 MaxSplit = RVT.getScalarSizeInBits() / 8;
28116
28117 for (int Split = 1; Split <= MaxSplit; ++Split)
28118 if (RVT.getScalarSizeInBits() % Split == 0)
28119 if (SDValue S = BuildClearMask(Split))
28120 return S;
28121
28122 return SDValue();
28123 }
28124
28125 /// If a vector binop is performed on splat values, it may be profitable to
28126 /// extract, scalarize, and insert/splat.
scalarizeBinOpOfSplats(SDNode * N,SelectionDAG & DAG,const SDLoc & DL,bool LegalTypes)28127 static SDValue scalarizeBinOpOfSplats(SDNode *N, SelectionDAG &DAG,
28128 const SDLoc &DL, bool LegalTypes) {
28129 SDValue N0 = N->getOperand(0);
28130 SDValue N1 = N->getOperand(1);
28131 unsigned Opcode = N->getOpcode();
28132 EVT VT = N->getValueType(0);
28133 EVT EltVT = VT.getVectorElementType();
28134 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
28135
28136 // TODO: Remove/replace the extract cost check? If the elements are available
28137 // as scalars, then there may be no extract cost. Should we ask if
28138 // inserting a scalar back into a vector is cheap instead?
28139 int Index0, Index1;
28140 SDValue Src0 = DAG.getSplatSourceVector(N0, Index0);
28141 SDValue Src1 = DAG.getSplatSourceVector(N1, Index1);
28142 // Extract element from splat_vector should be free.
28143 // TODO: use DAG.isSplatValue instead?
28144 bool IsBothSplatVector = N0.getOpcode() == ISD::SPLAT_VECTOR &&
28145 N1.getOpcode() == ISD::SPLAT_VECTOR;
28146 if (!Src0 || !Src1 || Index0 != Index1 ||
28147 Src0.getValueType().getVectorElementType() != EltVT ||
28148 Src1.getValueType().getVectorElementType() != EltVT ||
28149 !(IsBothSplatVector || TLI.isExtractVecEltCheap(VT, Index0)) ||
28150 // If before type legalization, allow scalar types that will eventually be
28151 // made legal.
28152 !TLI.isOperationLegalOrCustom(
28153 Opcode, LegalTypes
28154 ? EltVT
28155 : TLI.getTypeToTransformTo(*DAG.getContext(), EltVT)))
28156 return SDValue();
28157
28158 // FIXME: Type legalization can't handle illegal MULHS/MULHU.
28159 if ((Opcode == ISD::MULHS || Opcode == ISD::MULHU) && !TLI.isTypeLegal(EltVT))
28160 return SDValue();
28161
28162 if (N0.getOpcode() == ISD::BUILD_VECTOR && N0.getOpcode() == N1.getOpcode()) {
28163 // All but one element should have an undef input, which will fold to a
28164 // constant or undef. Avoid splatting which would over-define potentially
28165 // undefined elements.
28166
28167 // bo (build_vec ..undef, X, undef...), (build_vec ..undef, Y, undef...) -->
28168 // build_vec ..undef, (bo X, Y), undef...
28169 SmallVector<SDValue, 16> EltsX, EltsY, EltsResult;
28170 DAG.ExtractVectorElements(Src0, EltsX);
28171 DAG.ExtractVectorElements(Src1, EltsY);
28172
28173 for (auto [X, Y] : zip(EltsX, EltsY))
28174 EltsResult.push_back(DAG.getNode(Opcode, DL, EltVT, X, Y, N->getFlags()));
28175 return DAG.getBuildVector(VT, DL, EltsResult);
28176 }
28177
28178 SDValue IndexC = DAG.getVectorIdxConstant(Index0, DL);
28179 SDValue X = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Src0, IndexC);
28180 SDValue Y = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Src1, IndexC);
28181 SDValue ScalarBO = DAG.getNode(Opcode, DL, EltVT, X, Y, N->getFlags());
28182
28183 // bo (splat X, Index), (splat Y, Index) --> splat (bo X, Y), Index
28184 return DAG.getSplat(VT, DL, ScalarBO);
28185 }
28186
28187 /// Visit a vector cast operation, like FP_EXTEND.
SimplifyVCastOp(SDNode * N,const SDLoc & DL)28188 SDValue DAGCombiner::SimplifyVCastOp(SDNode *N, const SDLoc &DL) {
28189 EVT VT = N->getValueType(0);
28190 assert(VT.isVector() && "SimplifyVCastOp only works on vectors!");
28191 EVT EltVT = VT.getVectorElementType();
28192 unsigned Opcode = N->getOpcode();
28193
28194 SDValue N0 = N->getOperand(0);
28195 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
28196
28197 // TODO: promote operation might be also good here?
28198 int Index0;
28199 SDValue Src0 = DAG.getSplatSourceVector(N0, Index0);
28200 if (Src0 &&
28201 (N0.getOpcode() == ISD::SPLAT_VECTOR ||
28202 TLI.isExtractVecEltCheap(VT, Index0)) &&
28203 TLI.isOperationLegalOrCustom(Opcode, EltVT) &&
28204 TLI.preferScalarizeSplat(N)) {
28205 EVT SrcVT = N0.getValueType();
28206 EVT SrcEltVT = SrcVT.getVectorElementType();
28207 SDValue IndexC = DAG.getVectorIdxConstant(Index0, DL);
28208 SDValue Elt =
28209 DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, SrcEltVT, Src0, IndexC);
28210 SDValue ScalarBO = DAG.getNode(Opcode, DL, EltVT, Elt, N->getFlags());
28211 if (VT.isScalableVector())
28212 return DAG.getSplatVector(VT, DL, ScalarBO);
28213 SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), ScalarBO);
28214 return DAG.getBuildVector(VT, DL, Ops);
28215 }
28216
28217 return SDValue();
28218 }
28219
28220 /// Visit a binary vector operation, like ADD.
SimplifyVBinOp(SDNode * N,const SDLoc & DL)28221 SDValue DAGCombiner::SimplifyVBinOp(SDNode *N, const SDLoc &DL) {
28222 EVT VT = N->getValueType(0);
28223 assert(VT.isVector() && "SimplifyVBinOp only works on vectors!");
28224
28225 SDValue LHS = N->getOperand(0);
28226 SDValue RHS = N->getOperand(1);
28227 unsigned Opcode = N->getOpcode();
28228 SDNodeFlags Flags = N->getFlags();
28229
28230 // Move unary shuffles with identical masks after a vector binop:
28231 // VBinOp (shuffle A, Undef, Mask), (shuffle B, Undef, Mask))
28232 // --> shuffle (VBinOp A, B), Undef, Mask
28233 // This does not require type legality checks because we are creating the
28234 // same types of operations that are in the original sequence. We do have to
28235 // restrict ops like integer div that have immediate UB (eg, div-by-zero)
28236 // though. This code is adapted from the identical transform in instcombine.
28237 if (DAG.isSafeToSpeculativelyExecute(Opcode)) {
28238 auto *Shuf0 = dyn_cast<ShuffleVectorSDNode>(LHS);
28239 auto *Shuf1 = dyn_cast<ShuffleVectorSDNode>(RHS);
28240 if (Shuf0 && Shuf1 && Shuf0->getMask().equals(Shuf1->getMask()) &&
28241 LHS.getOperand(1).isUndef() && RHS.getOperand(1).isUndef() &&
28242 (LHS.hasOneUse() || RHS.hasOneUse() || LHS == RHS)) {
28243 SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, LHS.getOperand(0),
28244 RHS.getOperand(0), Flags);
28245 SDValue UndefV = LHS.getOperand(1);
28246 return DAG.getVectorShuffle(VT, DL, NewBinOp, UndefV, Shuf0->getMask());
28247 }
28248
28249 // Try to sink a splat shuffle after a binop with a uniform constant.
28250 // This is limited to cases where neither the shuffle nor the constant have
28251 // undefined elements because that could be poison-unsafe or inhibit
28252 // demanded elements analysis. It is further limited to not change a splat
28253 // of an inserted scalar because that may be optimized better by
28254 // load-folding or other target-specific behaviors.
28255 if (isConstOrConstSplat(RHS) && Shuf0 && all_equal(Shuf0->getMask()) &&
28256 Shuf0->hasOneUse() && Shuf0->getOperand(1).isUndef() &&
28257 Shuf0->getOperand(0).getOpcode() != ISD::INSERT_VECTOR_ELT) {
28258 // binop (splat X), (splat C) --> splat (binop X, C)
28259 SDValue X = Shuf0->getOperand(0);
28260 SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, X, RHS, Flags);
28261 return DAG.getVectorShuffle(VT, DL, NewBinOp, DAG.getUNDEF(VT),
28262 Shuf0->getMask());
28263 }
28264 if (isConstOrConstSplat(LHS) && Shuf1 && all_equal(Shuf1->getMask()) &&
28265 Shuf1->hasOneUse() && Shuf1->getOperand(1).isUndef() &&
28266 Shuf1->getOperand(0).getOpcode() != ISD::INSERT_VECTOR_ELT) {
28267 // binop (splat C), (splat X) --> splat (binop C, X)
28268 SDValue X = Shuf1->getOperand(0);
28269 SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, LHS, X, Flags);
28270 return DAG.getVectorShuffle(VT, DL, NewBinOp, DAG.getUNDEF(VT),
28271 Shuf1->getMask());
28272 }
28273 }
28274
28275 // The following pattern is likely to emerge with vector reduction ops. Moving
28276 // the binary operation ahead of insertion may allow using a narrower vector
28277 // instruction that has better performance than the wide version of the op:
28278 // VBinOp (ins undef, X, Z), (ins undef, Y, Z) --> ins VecC, (VBinOp X, Y), Z
28279 if (LHS.getOpcode() == ISD::INSERT_SUBVECTOR && LHS.getOperand(0).isUndef() &&
28280 RHS.getOpcode() == ISD::INSERT_SUBVECTOR && RHS.getOperand(0).isUndef() &&
28281 LHS.getOperand(2) == RHS.getOperand(2) &&
28282 (LHS.hasOneUse() || RHS.hasOneUse())) {
28283 SDValue X = LHS.getOperand(1);
28284 SDValue Y = RHS.getOperand(1);
28285 SDValue Z = LHS.getOperand(2);
28286 EVT NarrowVT = X.getValueType();
28287 if (NarrowVT == Y.getValueType() &&
28288 TLI.isOperationLegalOrCustomOrPromote(Opcode, NarrowVT,
28289 LegalOperations)) {
28290 // (binop undef, undef) may not return undef, so compute that result.
28291 SDValue VecC =
28292 DAG.getNode(Opcode, DL, VT, DAG.getUNDEF(VT), DAG.getUNDEF(VT));
28293 SDValue NarrowBO = DAG.getNode(Opcode, DL, NarrowVT, X, Y);
28294 return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, VecC, NarrowBO, Z);
28295 }
28296 }
28297
28298 // Make sure all but the first op are undef or constant.
28299 auto ConcatWithConstantOrUndef = [](SDValue Concat) {
28300 return Concat.getOpcode() == ISD::CONCAT_VECTORS &&
28301 all_of(drop_begin(Concat->ops()), [](const SDValue &Op) {
28302 return Op.isUndef() ||
28303 ISD::isBuildVectorOfConstantSDNodes(Op.getNode());
28304 });
28305 };
28306
28307 // The following pattern is likely to emerge with vector reduction ops. Moving
28308 // the binary operation ahead of the concat may allow using a narrower vector
28309 // instruction that has better performance than the wide version of the op:
28310 // VBinOp (concat X, undef/constant), (concat Y, undef/constant) -->
28311 // concat (VBinOp X, Y), VecC
28312 if (ConcatWithConstantOrUndef(LHS) && ConcatWithConstantOrUndef(RHS) &&
28313 (LHS.hasOneUse() || RHS.hasOneUse())) {
28314 EVT NarrowVT = LHS.getOperand(0).getValueType();
28315 if (NarrowVT == RHS.getOperand(0).getValueType() &&
28316 TLI.isOperationLegalOrCustomOrPromote(Opcode, NarrowVT)) {
28317 unsigned NumOperands = LHS.getNumOperands();
28318 SmallVector<SDValue, 4> ConcatOps;
28319 for (unsigned i = 0; i != NumOperands; ++i) {
28320 // This constant fold for operands 1 and up.
28321 ConcatOps.push_back(DAG.getNode(Opcode, DL, NarrowVT, LHS.getOperand(i),
28322 RHS.getOperand(i)));
28323 }
28324
28325 return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps);
28326 }
28327 }
28328
28329 if (SDValue V = scalarizeBinOpOfSplats(N, DAG, DL, LegalTypes))
28330 return V;
28331
28332 return SDValue();
28333 }
28334
SimplifySelect(const SDLoc & DL,SDValue N0,SDValue N1,SDValue N2)28335 SDValue DAGCombiner::SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1,
28336 SDValue N2) {
28337 assert(N0.getOpcode() == ISD::SETCC &&
28338 "First argument must be a SetCC node!");
28339
28340 SDValue SCC = SimplifySelectCC(DL, N0.getOperand(0), N0.getOperand(1), N1, N2,
28341 cast<CondCodeSDNode>(N0.getOperand(2))->get());
28342
28343 // If we got a simplified select_cc node back from SimplifySelectCC, then
28344 // break it down into a new SETCC node, and a new SELECT node, and then return
28345 // the SELECT node, since we were called with a SELECT node.
28346 if (SCC.getNode()) {
28347 // Check to see if we got a select_cc back (to turn into setcc/select).
28348 // Otherwise, just return whatever node we got back, like fabs.
28349 if (SCC.getOpcode() == ISD::SELECT_CC) {
28350 const SDNodeFlags Flags = N0->getFlags();
28351 SDValue SETCC = DAG.getNode(ISD::SETCC, SDLoc(N0),
28352 N0.getValueType(),
28353 SCC.getOperand(0), SCC.getOperand(1),
28354 SCC.getOperand(4), Flags);
28355 AddToWorklist(SETCC.getNode());
28356 return DAG.getSelect(SDLoc(SCC), SCC.getValueType(), SETCC,
28357 SCC.getOperand(2), SCC.getOperand(3), Flags);
28358 }
28359
28360 return SCC;
28361 }
28362 return SDValue();
28363 }
28364
/// Given a SELECT or a SELECT_CC node, where LHS and RHS are the two values
/// being selected between, see if we can simplify the select. Callers of this
/// should assume that TheSelect is deleted if this returns true. As such, they
/// should return the appropriate thing (e.g. the node) back to the top-level of
/// the DAG combiner loop to avoid it being looked at.
///
/// Two folds are attempted, in this order:
///  1. (select (setcc x, [+-]0.0, *lt), NaN, (fsqrt x)) --> (fsqrt x)
///  2. A select between two compatible loads on the same chain becomes a
///     single load through a select of the two addresses.
///
/// \param TheSelect the SELECT, VSELECT or SELECT_CC node being simplified.
/// \param LHS the value selected when the condition is true.
/// \param RHS the value selected when the condition is false.
/// \returns true if TheSelect was replaced (via CombineTo), false otherwise.
bool DAGCombiner::SimplifySelectOps(SDNode *TheSelect, SDValue LHS,
                                    SDValue RHS) {
  // fold (select (setcc x, [+-]0.0, *lt), NaN, (fsqrt x))
  // The select + setcc is redundant, because fsqrt returns NaN for X < 0.
  if (const ConstantFPSDNode *NaN = isConstOrConstSplatFP(LHS)) {
    if (NaN->isNaN() && RHS.getOpcode() == ISD::FSQRT) {
      // We have: (select (setcc ?, ?, ?), NaN, (fsqrt ?))
      SDValue Sqrt = RHS;
      // CC and CmpLHS are only assigned on the paths that also assign Zero.
      // The check below tests Zero first, so CC is never read uninitialized.
      ISD::CondCode CC;
      SDValue CmpLHS;
      const ConstantFPSDNode *Zero = nullptr;

      if (TheSelect->getOpcode() == ISD::SELECT_CC) {
        // SELECT_CC carries the comparison inline: operands 0/1 are the
        // compared values and operand 4 is the condition code.
        CC = cast<CondCodeSDNode>(TheSelect->getOperand(4))->get();
        CmpLHS = TheSelect->getOperand(0);
        Zero = isConstOrConstSplatFP(TheSelect->getOperand(1));
      } else {
        // SELECT or VSELECT
        SDValue Cmp = TheSelect->getOperand(0);
        if (Cmp.getOpcode() == ISD::SETCC) {
          CC = cast<CondCodeSDNode>(Cmp.getOperand(2))->get();
          CmpLHS = Cmp.getOperand(0);
          Zero = isConstOrConstSplatFP(Cmp.getOperand(1));
        }
      }
      // The compare must be a less-than against zero of the very value that
      // is fed to the sqrt.
      if (Zero && Zero->isZero() &&
          Sqrt.getOperand(0) == CmpLHS && (CC == ISD::SETOLT ||
          CC == ISD::SETULT || CC == ISD::SETLT)) {
        // We have: (select (setcc x, [+-]0.0, *lt), NaN, (fsqrt x))
        CombineTo(TheSelect, Sqrt);
        return true;
      }
    }
  }
  // Cannot simplify select with vector condition
  if (TheSelect->getOperand(0).getValueType().isVector()) return false;

  // If this is a select from two identical things, try to pull the operation
  // through the select.
  if (LHS.getOpcode() != RHS.getOpcode() ||
      !LHS.hasOneUse() || !RHS.hasOneUse())
    return false;

  // If this is a load and the token chain is identical, replace the select
  // of two loads with a load through a select of the address to load from.
  // This triggers in things like "select bool X, 10.0, 123.0" after the FP
  // constants have been dropped into the constant pool.
  if (LHS.getOpcode() == ISD::LOAD) {
    LoadSDNode *LLD = cast<LoadSDNode>(LHS);
    LoadSDNode *RLD = cast<LoadSDNode>(RHS);

    // Token chains must be identical.
    if (LHS.getOperand(0) != RHS.getOperand(0) ||
        // Do not let this transformation reduce the number of volatile loads.
        // Be conservative for atomics for the moment
        // TODO: This does appear to be legal for unordered atomics (see D66309)
        !LLD->isSimple() || !RLD->isSimple() ||
        // FIXME: If either is a pre/post inc/dec load,
        // we'd need to split out the address adjustment.
        LLD->isIndexed() || RLD->isIndexed() ||
        // If this is an EXTLOAD, the VT's must match.
        LLD->getMemoryVT() != RLD->getMemoryVT() ||
        // If this is an EXTLOAD, the kind of extension must match.
        (LLD->getExtensionType() != RLD->getExtensionType() &&
         // The only exception is if one of the extensions is anyext.
         LLD->getExtensionType() != ISD::EXTLOAD &&
         RLD->getExtensionType() != ISD::EXTLOAD) ||
        // FIXME: this discards src value information.  This is
        // over-conservative. It would be beneficial to be able to remember
        // both potential memory locations.  Since we are discarding
        // src value info, don't do the transformation if the memory
        // locations are not in the default address space.
        LLD->getPointerInfo().getAddrSpace() != 0 ||
        RLD->getPointerInfo().getAddrSpace() != 0 ||
        // We can't produce a CMOV of a TargetFrameIndex since we won't
        // generate the address generation required.
        LLD->getBasePtr().getOpcode() == ISD::TargetFrameIndex ||
        RLD->getBasePtr().getOpcode() == ISD::TargetFrameIndex ||
        // The select of the two addresses must itself be supported for the
        // pointer type.
        !TLI.isOperationLegalOrCustom(TheSelect->getOpcode(),
                                      LLD->getBasePtr().getValueType()))
      return false;

    // The loads must not depend on one another.
    if (LLD->isPredecessorOf(RLD) || RLD->isPredecessorOf(LLD))
      return false;

    // Check that the select condition doesn't reach either load.  If so,
    // folding this will induce a cycle into the DAG.  If not, this is safe to
    // xform, so create a select of the addresses.

    // NOTE: Visited and Worklist are shared by every hasPredecessorHelper call
    // below, so the order of these statements matters.
    SmallPtrSet<const SDNode *, 32> Visited;
    SmallVector<const SDNode *, 16> Worklist;

    // Always fail if LLD and RLD are not independent. TheSelect is a
    // predecessor to all Nodes in question so we need not search past it.

    Visited.insert(TheSelect);
    Worklist.push_back(LLD);
    Worklist.push_back(RLD);

    if (SDNode::hasPredecessorHelper(LLD, Visited, Worklist) ||
        SDNode::hasPredecessorHelper(RLD, Visited, Worklist))
      return false;

    // Address the merged load reads from: a select between the two base
    // pointers, mirroring the kind of select being simplified.
    SDValue Addr;
    if (TheSelect->getOpcode() == ISD::SELECT) {
      // We cannot do this optimization if any pair of {RLD, LLD} is a
      // predecessor to {RLD, LLD, CondNode}. As we've already compared the
      // Loads, we only need to check if CondNode is a successor to one of the
      // loads. We can further avoid this if there's no use of their chain
      // value.
      SDNode *CondNode = TheSelect->getOperand(0).getNode();
      Worklist.push_back(CondNode);

      if ((LLD->hasAnyUseOfValue(1) &&
           SDNode::hasPredecessorHelper(LLD, Visited, Worklist)) ||
          (RLD->hasAnyUseOfValue(1) &&
           SDNode::hasPredecessorHelper(RLD, Visited, Worklist)))
        return false;

      Addr = DAG.getSelect(SDLoc(TheSelect),
                           LLD->getBasePtr().getValueType(),
                           TheSelect->getOperand(0), LLD->getBasePtr(),
                           RLD->getBasePtr());
    } else { // Otherwise SELECT_CC
      // We cannot do this optimization if any pair of {RLD, LLD} is a
      // predecessor to {RLD, LLD, CondLHS, CondRHS}. As we've already compared
      // the Loads, we only need to check if CondLHS/CondRHS is a successor to
      // one of the loads. We can further avoid this if there's no use of their
      // chain value.

      SDNode *CondLHS = TheSelect->getOperand(0).getNode();
      SDNode *CondRHS = TheSelect->getOperand(1).getNode();
      Worklist.push_back(CondLHS);
      Worklist.push_back(CondRHS);

      if ((LLD->hasAnyUseOfValue(1) &&
           SDNode::hasPredecessorHelper(LLD, Visited, Worklist)) ||
          (RLD->hasAnyUseOfValue(1) &&
           SDNode::hasPredecessorHelper(RLD, Visited, Worklist)))
        return false;

      Addr = DAG.getNode(ISD::SELECT_CC, SDLoc(TheSelect),
                         LLD->getBasePtr().getValueType(),
                         TheSelect->getOperand(0),
                         TheSelect->getOperand(1),
                         LLD->getBasePtr(), RLD->getBasePtr(),
                         TheSelect->getOperand(4));
    }

    SDValue Load;
    // It is safe to replace the two loads if they have different alignments,
    // but the new load must be the minimum (most restrictive) alignment of the
    // inputs.
    Align Alignment = std::min(LLD->getAlign(), RLD->getAlign());
    // Start from LLD's memory-operand flags, then drop the invariant and
    // dereferenceable flags unless RLD has them as well.
    MachineMemOperand::Flags MMOFlags = LLD->getMemOperand()->getFlags();
    if (!RLD->isInvariant())
      MMOFlags &= ~MachineMemOperand::MOInvariant;
    if (!RLD->isDereferenceable())
      MMOFlags &= ~MachineMemOperand::MODereferenceable;
    if (LLD->getExtensionType() == ISD::NON_EXTLOAD) {
      // FIXME: Discards pointer and AA info.
      Load = DAG.getLoad(TheSelect->getValueType(0), SDLoc(TheSelect),
                         LLD->getChain(), Addr, MachinePointerInfo(), Alignment,
                         MMOFlags);
    } else {
      // FIXME: Discards pointer and AA info.
      // If LLD is an any-extending load, use RLD's extension kind instead;
      // otherwise keep LLD's.
      Load = DAG.getExtLoad(
          LLD->getExtensionType() == ISD::EXTLOAD ? RLD->getExtensionType()
                                                  : LLD->getExtensionType(),
          SDLoc(TheSelect), TheSelect->getValueType(0), LLD->getChain(), Addr,
          MachinePointerInfo(), LLD->getMemoryVT(), Alignment, MMOFlags);
    }

    // Users of the select now use the result of the load.
    CombineTo(TheSelect, Load);

    // Users of the old loads now use the new load's chain.  We know the
    // old-load value is dead now.
    CombineTo(LHS.getNode(), Load.getValue(0), Load.getValue(1));
    CombineTo(RHS.getNode(), Load.getValue(0), Load.getValue(1));
    return true;
  }

  return false;
}
28556
28557 /// Try to fold an expression of the form (N0 cond N1) ? N2 : N3 to a shift and
28558 /// bitwise 'and'.
foldSelectCCToShiftAnd(const SDLoc & DL,SDValue N0,SDValue N1,SDValue N2,SDValue N3,ISD::CondCode CC)28559 SDValue DAGCombiner::foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0,
28560 SDValue N1, SDValue N2, SDValue N3,
28561 ISD::CondCode CC) {
28562 // If this is a select where the false operand is zero and the compare is a
28563 // check of the sign bit, see if we can perform the "gzip trick":
28564 // select_cc setlt X, 0, A, 0 -> and (sra X, size(X)-1), A
28565 // select_cc setgt X, 0, A, 0 -> and (not (sra X, size(X)-1)), A
28566 EVT XType = N0.getValueType();
28567 EVT AType = N2.getValueType();
28568 if (!isNullConstant(N3) || !XType.bitsGE(AType))
28569 return SDValue();
28570
28571 // If the comparison is testing for a positive value, we have to invert
28572 // the sign bit mask, so only do that transform if the target has a bitwise
28573 // 'and not' instruction (the invert is free).
28574 if (CC == ISD::SETGT && TLI.hasAndNot(N2)) {
28575 // (X > -1) ? A : 0
28576 // (X > 0) ? X : 0 <-- This is canonical signed max.
28577 if (!(isAllOnesConstant(N1) || (isNullConstant(N1) && N0 == N2)))
28578 return SDValue();
28579 } else if (CC == ISD::SETLT) {
28580 // (X < 0) ? A : 0
28581 // (X < 1) ? X : 0 <-- This is un-canonicalized signed min.
28582 if (!(isNullConstant(N1) || (isOneConstant(N1) && N0 == N2)))
28583 return SDValue();
28584 } else {
28585 return SDValue();
28586 }
28587
28588 // and (sra X, size(X)-1), A -> "and (srl X, C2), A" iff A is a single-bit
28589 // constant.
28590 auto *N2C = dyn_cast<ConstantSDNode>(N2.getNode());
28591 if (N2C && ((N2C->getAPIntValue() & (N2C->getAPIntValue() - 1)) == 0)) {
28592 unsigned ShCt = XType.getSizeInBits() - N2C->getAPIntValue().logBase2() - 1;
28593 if (!TLI.shouldAvoidTransformToShift(XType, ShCt)) {
28594 SDValue ShiftAmt = DAG.getShiftAmountConstant(ShCt, XType, DL);
28595 SDValue Shift = DAG.getNode(ISD::SRL, DL, XType, N0, ShiftAmt);
28596 AddToWorklist(Shift.getNode());
28597
28598 if (XType.bitsGT(AType)) {
28599 Shift = DAG.getNode(ISD::TRUNCATE, DL, AType, Shift);
28600 AddToWorklist(Shift.getNode());
28601 }
28602
28603 if (CC == ISD::SETGT)
28604 Shift = DAG.getNOT(DL, Shift, AType);
28605
28606 return DAG.getNode(ISD::AND, DL, AType, Shift, N2);
28607 }
28608 }
28609
28610 unsigned ShCt = XType.getSizeInBits() - 1;
28611 if (TLI.shouldAvoidTransformToShift(XType, ShCt))
28612 return SDValue();
28613
28614 SDValue ShiftAmt = DAG.getShiftAmountConstant(ShCt, XType, DL);
28615 SDValue Shift = DAG.getNode(ISD::SRA, DL, XType, N0, ShiftAmt);
28616 AddToWorklist(Shift.getNode());
28617
28618 if (XType.bitsGT(AType)) {
28619 Shift = DAG.getNode(ISD::TRUNCATE, DL, AType, Shift);
28620 AddToWorklist(Shift.getNode());
28621 }
28622
28623 if (CC == ISD::SETGT)
28624 Shift = DAG.getNOT(DL, Shift, AType);
28625
28626 return DAG.getNode(ISD::AND, DL, AType, Shift, N2);
28627 }
28628
28629 // Fold select(cc, binop(), binop()) -> binop(select(), select()) etc.
foldSelectOfBinops(SDNode * N)28630 SDValue DAGCombiner::foldSelectOfBinops(SDNode *N) {
28631 SDValue N0 = N->getOperand(0);
28632 SDValue N1 = N->getOperand(1);
28633 SDValue N2 = N->getOperand(2);
28634 SDLoc DL(N);
28635
28636 unsigned BinOpc = N1.getOpcode();
28637 if (!TLI.isBinOp(BinOpc) || (N2.getOpcode() != BinOpc) ||
28638 (N1.getResNo() != N2.getResNo()))
28639 return SDValue();
28640
28641 // The use checks are intentionally on SDNode because we may be dealing
28642 // with opcodes that produce more than one SDValue.
28643 // TODO: Do we really need to check N0 (the condition operand of the select)?
28644 // But removing that clause could cause an infinite loop...
28645 if (!N0->hasOneUse() || !N1->hasOneUse() || !N2->hasOneUse())
28646 return SDValue();
28647
28648 // Binops may include opcodes that return multiple values, so all values
28649 // must be created/propagated from the newly created binops below.
28650 SDVTList OpVTs = N1->getVTList();
28651
28652 // Fold select(cond, binop(x, y), binop(z, y))
28653 // --> binop(select(cond, x, z), y)
28654 if (N1.getOperand(1) == N2.getOperand(1)) {
28655 SDValue N10 = N1.getOperand(0);
28656 SDValue N20 = N2.getOperand(0);
28657 SDValue NewSel = DAG.getSelect(DL, N10.getValueType(), N0, N10, N20);
28658 SDNodeFlags Flags = N1->getFlags() & N2->getFlags();
28659 SDValue NewBinOp =
28660 DAG.getNode(BinOpc, DL, OpVTs, {NewSel, N1.getOperand(1)}, Flags);
28661 return SDValue(NewBinOp.getNode(), N1.getResNo());
28662 }
28663
28664 // Fold select(cond, binop(x, y), binop(x, z))
28665 // --> binop(x, select(cond, y, z))
28666 if (N1.getOperand(0) == N2.getOperand(0)) {
28667 SDValue N11 = N1.getOperand(1);
28668 SDValue N21 = N2.getOperand(1);
28669 // Second op VT might be different (e.g. shift amount type)
28670 if (N11.getValueType() == N21.getValueType()) {
28671 SDValue NewSel = DAG.getSelect(DL, N11.getValueType(), N0, N11, N21);
28672 SDNodeFlags Flags = N1->getFlags() & N2->getFlags();
28673 SDValue NewBinOp =
28674 DAG.getNode(BinOpc, DL, OpVTs, {N1.getOperand(0), NewSel}, Flags);
28675 return SDValue(NewBinOp.getNode(), N1.getResNo());
28676 }
28677 }
28678
28679 // TODO: Handle isCommutativeBinOp patterns as well?
28680 return SDValue();
28681 }
28682
28683 // Transform (fneg/fabs (bitconvert x)) to avoid loading constant pool values.
foldSignChangeInBitcast(SDNode * N)28684 SDValue DAGCombiner::foldSignChangeInBitcast(SDNode *N) {
28685 SDValue N0 = N->getOperand(0);
28686 EVT VT = N->getValueType(0);
28687 bool IsFabs = N->getOpcode() == ISD::FABS;
28688 bool IsFree = IsFabs ? TLI.isFAbsFree(VT) : TLI.isFNegFree(VT);
28689
28690 if (IsFree || N0.getOpcode() != ISD::BITCAST || !N0.hasOneUse())
28691 return SDValue();
28692
28693 SDValue Int = N0.getOperand(0);
28694 EVT IntVT = Int.getValueType();
28695
28696 // The operand to cast should be integer.
28697 if (!IntVT.isInteger() || IntVT.isVector())
28698 return SDValue();
28699
28700 // (fneg (bitconvert x)) -> (bitconvert (xor x sign))
28701 // (fabs (bitconvert x)) -> (bitconvert (and x ~sign))
28702 APInt SignMask;
28703 if (N0.getValueType().isVector()) {
28704 // For vector, create a sign mask (0x80...) or its inverse (for fabs,
28705 // 0x7f...) per element and splat it.
28706 SignMask = APInt::getSignMask(N0.getScalarValueSizeInBits());
28707 if (IsFabs)
28708 SignMask = ~SignMask;
28709 SignMask = APInt::getSplat(IntVT.getSizeInBits(), SignMask);
28710 } else {
28711 // For scalar, just use the sign mask (0x80... or the inverse, 0x7f...)
28712 SignMask = APInt::getSignMask(IntVT.getSizeInBits());
28713 if (IsFabs)
28714 SignMask = ~SignMask;
28715 }
28716 SDLoc DL(N0);
28717 Int = DAG.getNode(IsFabs ? ISD::AND : ISD::XOR, DL, IntVT, Int,
28718 DAG.getConstant(SignMask, DL, IntVT));
28719 AddToWorklist(Int.getNode());
28720 return DAG.getBitcast(VT, Int);
28721 }
28722
28723 /// Turn "(a cond b) ? 1.0f : 2.0f" into "load (tmp + ((a cond b) ? 0 : 4)"
28724 /// where "tmp" is a constant pool entry containing an array with 1.0 and 2.0
28725 /// in it. This may be a win when the constant is not otherwise available
28726 /// because it replaces two constant pool loads with one.
convertSelectOfFPConstantsToLoadOffset(const SDLoc & DL,SDValue N0,SDValue N1,SDValue N2,SDValue N3,ISD::CondCode CC)28727 SDValue DAGCombiner::convertSelectOfFPConstantsToLoadOffset(
28728 const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2, SDValue N3,
28729 ISD::CondCode CC) {
28730 if (!TLI.reduceSelectOfFPConstantLoads(N0.getValueType()))
28731 return SDValue();
28732
28733 // If we are before legalize types, we want the other legalization to happen
28734 // first (for example, to avoid messing with soft float).
28735 auto *TV = dyn_cast<ConstantFPSDNode>(N2);
28736 auto *FV = dyn_cast<ConstantFPSDNode>(N3);
28737 EVT VT = N2.getValueType();
28738 if (!TV || !FV || !TLI.isTypeLegal(VT))
28739 return SDValue();
28740
28741 // If a constant can be materialized without loads, this does not make sense.
28742 if (TLI.getOperationAction(ISD::ConstantFP, VT) == TargetLowering::Legal ||
28743 TLI.isFPImmLegal(TV->getValueAPF(), TV->getValueType(0), ForCodeSize) ||
28744 TLI.isFPImmLegal(FV->getValueAPF(), FV->getValueType(0), ForCodeSize))
28745 return SDValue();
28746
28747 // If both constants have multiple uses, then we won't need to do an extra
28748 // load. The values are likely around in registers for other users.
28749 if (!TV->hasOneUse() && !FV->hasOneUse())
28750 return SDValue();
28751
28752 Constant *Elts[] = { const_cast<ConstantFP*>(FV->getConstantFPValue()),
28753 const_cast<ConstantFP*>(TV->getConstantFPValue()) };
28754 Type *FPTy = Elts[0]->getType();
28755 const DataLayout &TD = DAG.getDataLayout();
28756
28757 // Create a ConstantArray of the two constants.
28758 Constant *CA = ConstantArray::get(ArrayType::get(FPTy, 2), Elts);
28759 SDValue CPIdx = DAG.getConstantPool(CA, TLI.getPointerTy(DAG.getDataLayout()),
28760 TD.getPrefTypeAlign(FPTy));
28761 Align Alignment = cast<ConstantPoolSDNode>(CPIdx)->getAlign();
28762
28763 // Get offsets to the 0 and 1 elements of the array, so we can select between
28764 // them.
28765 SDValue Zero = DAG.getIntPtrConstant(0, DL);
28766 unsigned EltSize = (unsigned)TD.getTypeAllocSize(Elts[0]->getType());
28767 SDValue One = DAG.getIntPtrConstant(EltSize, SDLoc(FV));
28768 SDValue Cond =
28769 DAG.getSetCC(DL, getSetCCResultType(N0.getValueType()), N0, N1, CC);
28770 AddToWorklist(Cond.getNode());
28771 SDValue CstOffset = DAG.getSelect(DL, Zero.getValueType(), Cond, One, Zero);
28772 AddToWorklist(CstOffset.getNode());
28773 CPIdx = DAG.getNode(ISD::ADD, DL, CPIdx.getValueType(), CPIdx, CstOffset);
28774 AddToWorklist(CPIdx.getNode());
28775 return DAG.getLoad(TV->getValueType(0), DL, DAG.getEntryNode(), CPIdx,
28776 MachinePointerInfo::getConstantPool(
28777 DAG.getMachineFunction()), Alignment);
28778 }
28779
28780 /// Simplify an expression of the form (N0 cond N1) ? N2 : N3
28781 /// where 'cond' is the comparison specified by CC.
SimplifySelectCC(const SDLoc & DL,SDValue N0,SDValue N1,SDValue N2,SDValue N3,ISD::CondCode CC,bool NotExtCompare)28782 SDValue DAGCombiner::SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
28783 SDValue N2, SDValue N3, ISD::CondCode CC,
28784 bool NotExtCompare) {
28785 // (x ? y : y) -> y.
28786 if (N2 == N3) return N2;
28787
28788 EVT CmpOpVT = N0.getValueType();
28789 EVT CmpResVT = getSetCCResultType(CmpOpVT);
28790 EVT VT = N2.getValueType();
28791 auto *N1C = dyn_cast<ConstantSDNode>(N1.getNode());
28792 auto *N2C = dyn_cast<ConstantSDNode>(N2.getNode());
28793 auto *N3C = dyn_cast<ConstantSDNode>(N3.getNode());
28794
28795 // Determine if the condition we're dealing with is constant.
28796 if (SDValue SCC = DAG.FoldSetCC(CmpResVT, N0, N1, CC, DL)) {
28797 AddToWorklist(SCC.getNode());
28798 if (auto *SCCC = dyn_cast<ConstantSDNode>(SCC)) {
28799 // fold select_cc true, x, y -> x
28800 // fold select_cc false, x, y -> y
28801 return !(SCCC->isZero()) ? N2 : N3;
28802 }
28803 }
28804
28805 if (SDValue V =
28806 convertSelectOfFPConstantsToLoadOffset(DL, N0, N1, N2, N3, CC))
28807 return V;
28808
28809 if (SDValue V = foldSelectCCToShiftAnd(DL, N0, N1, N2, N3, CC))
28810 return V;
28811
28812 // fold (select_cc seteq (and x, y), 0, 0, A) -> (and (sra (shl x)) A)
  // where y has a single bit set.
28814 // A plaintext description would be, we can turn the SELECT_CC into an AND
28815 // when the condition can be materialized as an all-ones register. Any
28816 // single bit-test can be materialized as an all-ones register with
28817 // shift-left and shift-right-arith.
28818 if (CC == ISD::SETEQ && N0->getOpcode() == ISD::AND &&
28819 N0->getValueType(0) == VT && isNullConstant(N1) && isNullConstant(N2)) {
28820 SDValue AndLHS = N0->getOperand(0);
28821 auto *ConstAndRHS = dyn_cast<ConstantSDNode>(N0->getOperand(1));
28822 if (ConstAndRHS && ConstAndRHS->getAPIntValue().popcount() == 1) {
28823 // Shift the tested bit over the sign bit.
28824 const APInt &AndMask = ConstAndRHS->getAPIntValue();
28825 if (TLI.shouldFoldSelectWithSingleBitTest(VT, AndMask)) {
28826 unsigned ShCt = AndMask.getBitWidth() - 1;
28827 SDValue ShlAmt = DAG.getShiftAmountConstant(AndMask.countl_zero(), VT,
28828 SDLoc(AndLHS));
28829 SDValue Shl = DAG.getNode(ISD::SHL, SDLoc(N0), VT, AndLHS, ShlAmt);
28830
28831 // Now arithmetic right shift it all the way over, so the result is
28832 // either all-ones, or zero.
28833 SDValue ShrAmt = DAG.getShiftAmountConstant(ShCt, VT, SDLoc(Shl));
28834 SDValue Shr = DAG.getNode(ISD::SRA, SDLoc(N0), VT, Shl, ShrAmt);
28835
28836 return DAG.getNode(ISD::AND, DL, VT, Shr, N3);
28837 }
28838 }
28839 }
28840
28841 // fold select C, 16, 0 -> shl C, 4
28842 bool Fold = N2C && isNullConstant(N3) && N2C->getAPIntValue().isPowerOf2();
28843 bool Swap = N3C && isNullConstant(N2) && N3C->getAPIntValue().isPowerOf2();
28844
28845 if ((Fold || Swap) &&
28846 TLI.getBooleanContents(CmpOpVT) ==
28847 TargetLowering::ZeroOrOneBooleanContent &&
28848 (!LegalOperations || TLI.isOperationLegal(ISD::SETCC, CmpOpVT)) &&
28849 TLI.convertSelectOfConstantsToMath(VT)) {
28850
28851 if (Swap) {
28852 CC = ISD::getSetCCInverse(CC, CmpOpVT);
28853 std::swap(N2C, N3C);
28854 }
28855
28856 // If the caller doesn't want us to simplify this into a zext of a compare,
28857 // don't do it.
28858 if (NotExtCompare && N2C->isOne())
28859 return SDValue();
28860
28861 SDValue Temp, SCC;
28862 // zext (setcc n0, n1)
28863 if (LegalTypes) {
28864 SCC = DAG.getSetCC(DL, CmpResVT, N0, N1, CC);
28865 Temp = DAG.getZExtOrTrunc(SCC, SDLoc(N2), VT);
28866 } else {
28867 SCC = DAG.getSetCC(SDLoc(N0), MVT::i1, N0, N1, CC);
28868 Temp = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N2), VT, SCC);
28869 }
28870
28871 AddToWorklist(SCC.getNode());
28872 AddToWorklist(Temp.getNode());
28873
28874 if (N2C->isOne())
28875 return Temp;
28876
28877 unsigned ShCt = N2C->getAPIntValue().logBase2();
28878 if (TLI.shouldAvoidTransformToShift(VT, ShCt))
28879 return SDValue();
28880
28881 // shl setcc result by log2 n2c
28882 return DAG.getNode(
28883 ISD::SHL, DL, N2.getValueType(), Temp,
28884 DAG.getShiftAmountConstant(ShCt, N2.getValueType(), SDLoc(Temp)));
28885 }
28886
28887 // select_cc seteq X, 0, sizeof(X), ctlz(X) -> ctlz(X)
28888 // select_cc seteq X, 0, sizeof(X), ctlz_zero_undef(X) -> ctlz(X)
28889 // select_cc seteq X, 0, sizeof(X), cttz(X) -> cttz(X)
28890 // select_cc seteq X, 0, sizeof(X), cttz_zero_undef(X) -> cttz(X)
28891 // select_cc setne X, 0, ctlz(X), sizeof(X) -> ctlz(X)
28892 // select_cc setne X, 0, ctlz_zero_undef(X), sizeof(X) -> ctlz(X)
28893 // select_cc setne X, 0, cttz(X), sizeof(X) -> cttz(X)
28894 // select_cc setne X, 0, cttz_zero_undef(X), sizeof(X) -> cttz(X)
28895 if (N1C && N1C->isZero() && (CC == ISD::SETEQ || CC == ISD::SETNE)) {
28896 SDValue ValueOnZero = N2;
28897 SDValue Count = N3;
28898 // If the condition is NE instead of E, swap the operands.
28899 if (CC == ISD::SETNE)
28900 std::swap(ValueOnZero, Count);
28901 // Check if the value on zero is a constant equal to the bits in the type.
28902 if (auto *ValueOnZeroC = dyn_cast<ConstantSDNode>(ValueOnZero)) {
28903 if (ValueOnZeroC->getAPIntValue() == VT.getSizeInBits()) {
28904 // If the other operand is cttz/cttz_zero_undef of N0, and cttz is
28905 // legal, combine to just cttz.
28906 if ((Count.getOpcode() == ISD::CTTZ ||
28907 Count.getOpcode() == ISD::CTTZ_ZERO_UNDEF) &&
28908 N0 == Count.getOperand(0) &&
28909 (!LegalOperations || TLI.isOperationLegal(ISD::CTTZ, VT)))
28910 return DAG.getNode(ISD::CTTZ, DL, VT, N0);
28911 // If the other operand is ctlz/ctlz_zero_undef of N0, and ctlz is
28912 // legal, combine to just ctlz.
28913 if ((Count.getOpcode() == ISD::CTLZ ||
28914 Count.getOpcode() == ISD::CTLZ_ZERO_UNDEF) &&
28915 N0 == Count.getOperand(0) &&
28916 (!LegalOperations || TLI.isOperationLegal(ISD::CTLZ, VT)))
28917 return DAG.getNode(ISD::CTLZ, DL, VT, N0);
28918 }
28919 }
28920 }
28921
28922 // Fold select_cc setgt X, -1, C, ~C -> xor (ashr X, BW-1), C
28923 // Fold select_cc setlt X, 0, C, ~C -> xor (ashr X, BW-1), ~C
28924 if (!NotExtCompare && N1C && N2C && N3C &&
28925 N2C->getAPIntValue() == ~N3C->getAPIntValue() &&
28926 ((N1C->isAllOnes() && CC == ISD::SETGT) ||
28927 (N1C->isZero() && CC == ISD::SETLT)) &&
28928 !TLI.shouldAvoidTransformToShift(VT, CmpOpVT.getScalarSizeInBits() - 1)) {
28929 SDValue ASR = DAG.getNode(
28930 ISD::SRA, DL, CmpOpVT, N0,
28931 DAG.getConstant(CmpOpVT.getScalarSizeInBits() - 1, DL, CmpOpVT));
28932 return DAG.getNode(ISD::XOR, DL, VT, DAG.getSExtOrTrunc(ASR, DL, VT),
28933 DAG.getSExtOrTrunc(CC == ISD::SETLT ? N3 : N2, DL, VT));
28934 }
28935
28936 if (SDValue S = PerformMinMaxFpToSatCombine(N0, N1, N2, N3, CC, DAG))
28937 return S;
28938 if (SDValue S = PerformUMinFpToSatCombine(N0, N1, N2, N3, CC, DAG))
28939 return S;
28940 if (SDValue ABD = foldSelectToABD(N0, N1, N2, N3, CC, DL))
28941 return ABD;
28942
28943 return SDValue();
28944 }
28945
28946 /// This is a stub for TargetLowering::SimplifySetCC.
SimplifySetCC(EVT VT,SDValue N0,SDValue N1,ISD::CondCode Cond,const SDLoc & DL,bool foldBooleans)28947 SDValue DAGCombiner::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
28948 ISD::CondCode Cond, const SDLoc &DL,
28949 bool foldBooleans) {
28950 TargetLowering::DAGCombinerInfo
28951 DagCombineInfo(DAG, Level, false, this);
28952 return TLI.SimplifySetCC(VT, N0, N1, Cond, foldBooleans, DagCombineInfo, DL);
28953 }
28954
28955 /// Given an ISD::SDIV node expressing a divide by constant, return
28956 /// a DAG expression to select that will generate the same value by multiplying
28957 /// by a magic number.
28958 /// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
BuildSDIV(SDNode * N)28959 SDValue DAGCombiner::BuildSDIV(SDNode *N) {
28960 // when optimising for minimum size, we don't want to expand a div to a mul
28961 // and a shift.
28962 if (DAG.getMachineFunction().getFunction().hasMinSize())
28963 return SDValue();
28964
28965 SmallVector<SDNode *, 8> Built;
28966 if (SDValue S = TLI.BuildSDIV(N, DAG, LegalOperations, LegalTypes, Built)) {
28967 for (SDNode *N : Built)
28968 AddToWorklist(N);
28969 return S;
28970 }
28971
28972 return SDValue();
28973 }
28974
28975 /// Given an ISD::SDIV node expressing a divide by constant power of 2, return a
28976 /// DAG expression that will generate the same value by right shifting.
BuildSDIVPow2(SDNode * N)28977 SDValue DAGCombiner::BuildSDIVPow2(SDNode *N) {
28978 ConstantSDNode *C = isConstOrConstSplat(N->getOperand(1));
28979 if (!C)
28980 return SDValue();
28981
28982 // Avoid division by zero.
28983 if (C->isZero())
28984 return SDValue();
28985
28986 SmallVector<SDNode *, 8> Built;
28987 if (SDValue S = TLI.BuildSDIVPow2(N, C->getAPIntValue(), DAG, Built)) {
28988 for (SDNode *N : Built)
28989 AddToWorklist(N);
28990 return S;
28991 }
28992
28993 return SDValue();
28994 }
28995
28996 /// Given an ISD::UDIV node expressing a divide by constant, return a DAG
28997 /// expression that will generate the same value by multiplying by a magic
28998 /// number.
28999 /// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
BuildUDIV(SDNode * N)29000 SDValue DAGCombiner::BuildUDIV(SDNode *N) {
29001 // when optimising for minimum size, we don't want to expand a div to a mul
29002 // and a shift.
29003 if (DAG.getMachineFunction().getFunction().hasMinSize())
29004 return SDValue();
29005
29006 SmallVector<SDNode *, 8> Built;
29007 if (SDValue S = TLI.BuildUDIV(N, DAG, LegalOperations, LegalTypes, Built)) {
29008 for (SDNode *N : Built)
29009 AddToWorklist(N);
29010 return S;
29011 }
29012
29013 return SDValue();
29014 }
29015
29016 /// Given an ISD::SREM node expressing a remainder by constant power of 2,
29017 /// return a DAG expression that will generate the same value.
BuildSREMPow2(SDNode * N)29018 SDValue DAGCombiner::BuildSREMPow2(SDNode *N) {
29019 ConstantSDNode *C = isConstOrConstSplat(N->getOperand(1));
29020 if (!C)
29021 return SDValue();
29022
29023 // Avoid division by zero.
29024 if (C->isZero())
29025 return SDValue();
29026
29027 SmallVector<SDNode *, 8> Built;
29028 if (SDValue S = TLI.BuildSREMPow2(N, C->getAPIntValue(), DAG, Built)) {
29029 for (SDNode *N : Built)
29030 AddToWorklist(N);
29031 return S;
29032 }
29033
29034 return SDValue();
29035 }
29036
29037 // This is basically just a port of takeLog2 from InstCombineMulDivRem.cpp
29038 //
29039 // Returns the node that represents `Log2(Op)`. This may create a new node. If
29040 // we are unable to compute `Log2(Op)` its return `SDValue()`.
29041 //
29042 // All nodes will be created at `DL` and the output will be of type `VT`.
29043 //
29044 // This will only return `Log2(Op)` if we can prove `Op` is non-zero. Set
29045 // `AssumeNonZero` if this function should simply assume (not require proving
29046 // `Op` is non-zero).
takeInexpensiveLog2(SelectionDAG & DAG,const SDLoc & DL,EVT VT,SDValue Op,unsigned Depth,bool AssumeNonZero)29047 static SDValue takeInexpensiveLog2(SelectionDAG &DAG, const SDLoc &DL, EVT VT,
29048 SDValue Op, unsigned Depth,
29049 bool AssumeNonZero) {
29050 assert(VT.isInteger() && "Only integer types are supported!");
29051
29052 auto PeekThroughCastsAndTrunc = [](SDValue V) {
29053 while (true) {
29054 switch (V.getOpcode()) {
29055 case ISD::TRUNCATE:
29056 case ISD::ZERO_EXTEND:
29057 V = V.getOperand(0);
29058 break;
29059 default:
29060 return V;
29061 }
29062 }
29063 };
29064
29065 if (VT.isScalableVector())
29066 return SDValue();
29067
29068 Op = PeekThroughCastsAndTrunc(Op);
29069
29070 // Helper for determining whether a value is a power-2 constant scalar or a
29071 // vector of such elements.
29072 SmallVector<APInt> Pow2Constants;
29073 auto IsPowerOfTwo = [&Pow2Constants](ConstantSDNode *C) {
29074 if (C->isZero() || C->isOpaque())
29075 return false;
29076 // TODO: We may also be able to support negative powers of 2 here.
29077 if (C->getAPIntValue().isPowerOf2()) {
29078 Pow2Constants.emplace_back(C->getAPIntValue());
29079 return true;
29080 }
29081 return false;
29082 };
29083
29084 if (ISD::matchUnaryPredicate(Op, IsPowerOfTwo)) {
29085 if (!VT.isVector())
29086 return DAG.getConstant(Pow2Constants.back().logBase2(), DL, VT);
29087 // We need to create a build vector
29088 if (Op.getOpcode() == ISD::SPLAT_VECTOR)
29089 return DAG.getSplat(VT, DL,
29090 DAG.getConstant(Pow2Constants.back().logBase2(), DL,
29091 VT.getScalarType()));
29092 SmallVector<SDValue> Log2Ops;
29093 for (const APInt &Pow2 : Pow2Constants)
29094 Log2Ops.emplace_back(
29095 DAG.getConstant(Pow2.logBase2(), DL, VT.getScalarType()));
29096 return DAG.getBuildVector(VT, DL, Log2Ops);
29097 }
29098
29099 if (Depth >= DAG.MaxRecursionDepth)
29100 return SDValue();
29101
29102 auto CastToVT = [&](EVT NewVT, SDValue ToCast) {
29103 // Peek through zero extend. We can't peek through truncates since this
29104 // function is called on a shift amount. We must ensure that all of the bits
29105 // above the original shift amount are zeroed by this function.
29106 while (ToCast.getOpcode() == ISD::ZERO_EXTEND)
29107 ToCast = ToCast.getOperand(0);
29108 EVT CurVT = ToCast.getValueType();
29109 if (NewVT == CurVT)
29110 return ToCast;
29111
29112 if (NewVT.getSizeInBits() == CurVT.getSizeInBits())
29113 return DAG.getBitcast(NewVT, ToCast);
29114
29115 return DAG.getZExtOrTrunc(ToCast, DL, NewVT);
29116 };
29117
29118 // log2(X << Y) -> log2(X) + Y
29119 if (Op.getOpcode() == ISD::SHL) {
29120 // 1 << Y and X nuw/nsw << Y are all non-zero.
29121 if (AssumeNonZero || Op->getFlags().hasNoUnsignedWrap() ||
29122 Op->getFlags().hasNoSignedWrap() || isOneConstant(Op.getOperand(0)))
29123 if (SDValue LogX = takeInexpensiveLog2(DAG, DL, VT, Op.getOperand(0),
29124 Depth + 1, AssumeNonZero))
29125 return DAG.getNode(ISD::ADD, DL, VT, LogX,
29126 CastToVT(VT, Op.getOperand(1)));
29127 }
29128
29129 // c ? X : Y -> c ? Log2(X) : Log2(Y)
29130 if ((Op.getOpcode() == ISD::SELECT || Op.getOpcode() == ISD::VSELECT) &&
29131 Op.hasOneUse()) {
29132 if (SDValue LogX = takeInexpensiveLog2(DAG, DL, VT, Op.getOperand(1),
29133 Depth + 1, AssumeNonZero))
29134 if (SDValue LogY = takeInexpensiveLog2(DAG, DL, VT, Op.getOperand(2),
29135 Depth + 1, AssumeNonZero))
29136 return DAG.getSelect(DL, VT, Op.getOperand(0), LogX, LogY);
29137 }
29138
29139 // log2(umin(X, Y)) -> umin(log2(X), log2(Y))
29140 // log2(umax(X, Y)) -> umax(log2(X), log2(Y))
29141 if ((Op.getOpcode() == ISD::UMIN || Op.getOpcode() == ISD::UMAX) &&
29142 Op.hasOneUse()) {
29143 // Use AssumeNonZero as false here. Otherwise we can hit case where
29144 // log2(umax(X, Y)) != umax(log2(X), log2(Y)) (because overflow).
29145 if (SDValue LogX =
29146 takeInexpensiveLog2(DAG, DL, VT, Op.getOperand(0), Depth + 1,
29147 /*AssumeNonZero*/ false))
29148 if (SDValue LogY =
29149 takeInexpensiveLog2(DAG, DL, VT, Op.getOperand(1), Depth + 1,
29150 /*AssumeNonZero*/ false))
29151 return DAG.getNode(Op.getOpcode(), DL, VT, LogX, LogY);
29152 }
29153
29154 return SDValue();
29155 }
29156
29157 /// Determines the LogBase2 value for a non-null input value using the
29158 /// transform: LogBase2(V) = (EltBits - 1) - ctlz(V).
BuildLogBase2(SDValue V,const SDLoc & DL,bool KnownNonZero,bool InexpensiveOnly,std::optional<EVT> OutVT)29159 SDValue DAGCombiner::BuildLogBase2(SDValue V, const SDLoc &DL,
29160 bool KnownNonZero, bool InexpensiveOnly,
29161 std::optional<EVT> OutVT) {
29162 EVT VT = OutVT ? *OutVT : V.getValueType();
29163 SDValue InexpensiveLogBase2 =
29164 takeInexpensiveLog2(DAG, DL, VT, V, /*Depth*/ 0, KnownNonZero);
29165 if (InexpensiveLogBase2 || InexpensiveOnly || !DAG.isKnownToBeAPowerOfTwo(V))
29166 return InexpensiveLogBase2;
29167
29168 SDValue Ctlz = DAG.getNode(ISD::CTLZ, DL, VT, V);
29169 SDValue Base = DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, VT);
29170 SDValue LogBase2 = DAG.getNode(ISD::SUB, DL, VT, Base, Ctlz);
29171 return LogBase2;
29172 }
29173
29174 /// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
29175 /// For the reciprocal, we need to find the zero of the function:
29176 /// F(X) = 1/X - A [which has a zero at X = 1/A]
29177 /// =>
29178 /// X_{i+1} = X_i (2 - A X_i) = X_i + X_i (1 - A X_i) [this second form
29179 /// does not require additional intermediate precision]
29180 /// For the last iteration, put numerator N into it to gain more precision:
29181 /// Result = N X_i + X_i (N - N A X_i)
BuildDivEstimate(SDValue N,SDValue Op,SDNodeFlags Flags)29182 SDValue DAGCombiner::BuildDivEstimate(SDValue N, SDValue Op,
29183 SDNodeFlags Flags) {
29184 if (LegalDAG)
29185 return SDValue();
29186
29187 // TODO: Handle extended types?
29188 EVT VT = Op.getValueType();
29189 if (VT.getScalarType() != MVT::f16 && VT.getScalarType() != MVT::f32 &&
29190 VT.getScalarType() != MVT::f64)
29191 return SDValue();
29192
29193 // If estimates are explicitly disabled for this function, we're done.
29194 MachineFunction &MF = DAG.getMachineFunction();
29195 int Enabled = TLI.getRecipEstimateDivEnabled(VT, MF);
29196 if (Enabled == TLI.ReciprocalEstimate::Disabled)
29197 return SDValue();
29198
29199 // Estimates may be explicitly enabled for this type with a custom number of
29200 // refinement steps.
29201 int Iterations = TLI.getDivRefinementSteps(VT, MF);
29202 if (SDValue Est = TLI.getRecipEstimate(Op, DAG, Enabled, Iterations)) {
29203 AddToWorklist(Est.getNode());
29204
29205 SDLoc DL(Op);
29206 if (Iterations) {
29207 SDValue FPOne = DAG.getConstantFP(1.0, DL, VT);
29208
29209 // Newton iterations: Est = Est + Est (N - Arg * Est)
29210 // If this is the last iteration, also multiply by the numerator.
29211 for (int i = 0; i < Iterations; ++i) {
29212 SDValue MulEst = Est;
29213
29214 if (i == Iterations - 1) {
29215 MulEst = DAG.getNode(ISD::FMUL, DL, VT, N, Est, Flags);
29216 AddToWorklist(MulEst.getNode());
29217 }
29218
29219 SDValue NewEst = DAG.getNode(ISD::FMUL, DL, VT, Op, MulEst, Flags);
29220 AddToWorklist(NewEst.getNode());
29221
29222 NewEst = DAG.getNode(ISD::FSUB, DL, VT,
29223 (i == Iterations - 1 ? N : FPOne), NewEst, Flags);
29224 AddToWorklist(NewEst.getNode());
29225
29226 NewEst = DAG.getNode(ISD::FMUL, DL, VT, Est, NewEst, Flags);
29227 AddToWorklist(NewEst.getNode());
29228
29229 Est = DAG.getNode(ISD::FADD, DL, VT, MulEst, NewEst, Flags);
29230 AddToWorklist(Est.getNode());
29231 }
29232 } else {
29233 // If no iterations are available, multiply with N.
29234 Est = DAG.getNode(ISD::FMUL, DL, VT, Est, N, Flags);
29235 AddToWorklist(Est.getNode());
29236 }
29237
29238 return Est;
29239 }
29240
29241 return SDValue();
29242 }
29243
29244 /// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
29245 /// For the reciprocal sqrt, we need to find the zero of the function:
29246 /// F(X) = 1/X^2 - A [which has a zero at X = 1/sqrt(A)]
29247 /// =>
29248 /// X_{i+1} = X_i (1.5 - A X_i^2 / 2)
29249 /// As a result, we precompute A/2 prior to the iteration loop.
buildSqrtNROneConst(SDValue Arg,SDValue Est,unsigned Iterations,SDNodeFlags Flags,bool Reciprocal)29250 SDValue DAGCombiner::buildSqrtNROneConst(SDValue Arg, SDValue Est,
29251 unsigned Iterations,
29252 SDNodeFlags Flags, bool Reciprocal) {
29253 EVT VT = Arg.getValueType();
29254 SDLoc DL(Arg);
29255 SDValue ThreeHalves = DAG.getConstantFP(1.5, DL, VT);
29256
29257 // We now need 0.5 * Arg which we can write as (1.5 * Arg - Arg) so that
29258 // this entire sequence requires only one FP constant.
29259 SDValue HalfArg = DAG.getNode(ISD::FMUL, DL, VT, ThreeHalves, Arg, Flags);
29260 HalfArg = DAG.getNode(ISD::FSUB, DL, VT, HalfArg, Arg, Flags);
29261
29262 // Newton iterations: Est = Est * (1.5 - HalfArg * Est * Est)
29263 for (unsigned i = 0; i < Iterations; ++i) {
29264 SDValue NewEst = DAG.getNode(ISD::FMUL, DL, VT, Est, Est, Flags);
29265 NewEst = DAG.getNode(ISD::FMUL, DL, VT, HalfArg, NewEst, Flags);
29266 NewEst = DAG.getNode(ISD::FSUB, DL, VT, ThreeHalves, NewEst, Flags);
29267 Est = DAG.getNode(ISD::FMUL, DL, VT, Est, NewEst, Flags);
29268 }
29269
29270 // If non-reciprocal square root is requested, multiply the result by Arg.
29271 if (!Reciprocal)
29272 Est = DAG.getNode(ISD::FMUL, DL, VT, Est, Arg, Flags);
29273
29274 return Est;
29275 }
29276
29277 /// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
29278 /// For the reciprocal sqrt, we need to find the zero of the function:
29279 /// F(X) = 1/X^2 - A [which has a zero at X = 1/sqrt(A)]
29280 /// =>
29281 /// X_{i+1} = (-0.5 * X_i) * (A * X_i * X_i + (-3.0))
buildSqrtNRTwoConst(SDValue Arg,SDValue Est,unsigned Iterations,SDNodeFlags Flags,bool Reciprocal)29282 SDValue DAGCombiner::buildSqrtNRTwoConst(SDValue Arg, SDValue Est,
29283 unsigned Iterations,
29284 SDNodeFlags Flags, bool Reciprocal) {
29285 EVT VT = Arg.getValueType();
29286 SDLoc DL(Arg);
29287 SDValue MinusThree = DAG.getConstantFP(-3.0, DL, VT);
29288 SDValue MinusHalf = DAG.getConstantFP(-0.5, DL, VT);
29289
29290 // This routine must enter the loop below to work correctly
29291 // when (Reciprocal == false).
29292 assert(Iterations > 0);
29293
29294 // Newton iterations for reciprocal square root:
29295 // E = (E * -0.5) * ((A * E) * E + -3.0)
29296 for (unsigned i = 0; i < Iterations; ++i) {
29297 SDValue AE = DAG.getNode(ISD::FMUL, DL, VT, Arg, Est, Flags);
29298 SDValue AEE = DAG.getNode(ISD::FMUL, DL, VT, AE, Est, Flags);
29299 SDValue RHS = DAG.getNode(ISD::FADD, DL, VT, AEE, MinusThree, Flags);
29300
29301 // When calculating a square root at the last iteration build:
29302 // S = ((A * E) * -0.5) * ((A * E) * E + -3.0)
29303 // (notice a common subexpression)
29304 SDValue LHS;
29305 if (Reciprocal || (i + 1) < Iterations) {
29306 // RSQRT: LHS = (E * -0.5)
29307 LHS = DAG.getNode(ISD::FMUL, DL, VT, Est, MinusHalf, Flags);
29308 } else {
29309 // SQRT: LHS = (A * E) * -0.5
29310 LHS = DAG.getNode(ISD::FMUL, DL, VT, AE, MinusHalf, Flags);
29311 }
29312
29313 Est = DAG.getNode(ISD::FMUL, DL, VT, LHS, RHS, Flags);
29314 }
29315
29316 return Est;
29317 }
29318
29319 /// Build code to calculate either rsqrt(Op) or sqrt(Op). In the latter case
29320 /// Op*rsqrt(Op) is actually computed, so additional postprocessing is needed if
29321 /// Op can be zero.
buildSqrtEstimateImpl(SDValue Op,SDNodeFlags Flags,bool Reciprocal)29322 SDValue DAGCombiner::buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags,
29323 bool Reciprocal) {
29324 if (LegalDAG)
29325 return SDValue();
29326
29327 // TODO: Handle extended types?
29328 EVT VT = Op.getValueType();
29329 if (VT.getScalarType() != MVT::f16 && VT.getScalarType() != MVT::f32 &&
29330 VT.getScalarType() != MVT::f64)
29331 return SDValue();
29332
29333 // If estimates are explicitly disabled for this function, we're done.
29334 MachineFunction &MF = DAG.getMachineFunction();
29335 int Enabled = TLI.getRecipEstimateSqrtEnabled(VT, MF);
29336 if (Enabled == TLI.ReciprocalEstimate::Disabled)
29337 return SDValue();
29338
29339 // Estimates may be explicitly enabled for this type with a custom number of
29340 // refinement steps.
29341 int Iterations = TLI.getSqrtRefinementSteps(VT, MF);
29342
29343 bool UseOneConstNR = false;
29344 if (SDValue Est =
29345 TLI.getSqrtEstimate(Op, DAG, Enabled, Iterations, UseOneConstNR,
29346 Reciprocal)) {
29347 AddToWorklist(Est.getNode());
29348
29349 if (Iterations > 0)
29350 Est = UseOneConstNR
29351 ? buildSqrtNROneConst(Op, Est, Iterations, Flags, Reciprocal)
29352 : buildSqrtNRTwoConst(Op, Est, Iterations, Flags, Reciprocal);
29353 if (!Reciprocal) {
29354 SDLoc DL(Op);
29355 // Try the target specific test first.
29356 SDValue Test = TLI.getSqrtInputTest(Op, DAG, DAG.getDenormalMode(VT));
29357
29358 // The estimate is now completely wrong if the input was exactly 0.0 or
29359 // possibly a denormal. Force the answer to 0.0 or value provided by
29360 // target for those cases.
29361 Est = DAG.getSelect(DL, VT, Test,
29362 TLI.getSqrtResultForDenormInput(Op, DAG), Est);
29363 }
29364 return Est;
29365 }
29366
29367 return SDValue();
29368 }
29369
buildRsqrtEstimate(SDValue Op,SDNodeFlags Flags)29370 SDValue DAGCombiner::buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags) {
29371 return buildSqrtEstimateImpl(Op, Flags, true);
29372 }
29373
buildSqrtEstimate(SDValue Op,SDNodeFlags Flags)29374 SDValue DAGCombiner::buildSqrtEstimate(SDValue Op, SDNodeFlags Flags) {
29375 return buildSqrtEstimateImpl(Op, Flags, false);
29376 }
29377
29378 /// Return true if there is any possibility that the two addresses overlap.
mayAlias(SDNode * Op0,SDNode * Op1) const29379 bool DAGCombiner::mayAlias(SDNode *Op0, SDNode *Op1) const {
29380
29381 struct MemUseCharacteristics {
29382 bool IsVolatile;
29383 bool IsAtomic;
29384 SDValue BasePtr;
29385 int64_t Offset;
29386 LocationSize NumBytes;
29387 MachineMemOperand *MMO;
29388 };
29389
29390 auto getCharacteristics = [](SDNode *N) -> MemUseCharacteristics {
29391 if (const auto *LSN = dyn_cast<LSBaseSDNode>(N)) {
29392 int64_t Offset = 0;
29393 if (auto *C = dyn_cast<ConstantSDNode>(LSN->getOffset()))
29394 Offset = (LSN->getAddressingMode() == ISD::PRE_INC) ? C->getSExtValue()
29395 : (LSN->getAddressingMode() == ISD::PRE_DEC)
29396 ? -1 * C->getSExtValue()
29397 : 0;
29398 TypeSize Size = LSN->getMemoryVT().getStoreSize();
29399 return {LSN->isVolatile(), LSN->isAtomic(),
29400 LSN->getBasePtr(), Offset /*base offset*/,
29401 LocationSize::precise(Size), LSN->getMemOperand()};
29402 }
29403 if (const auto *LN = cast<LifetimeSDNode>(N))
29404 return {false /*isVolatile*/,
29405 /*isAtomic*/ false,
29406 LN->getOperand(1),
29407 (LN->hasOffset()) ? LN->getOffset() : 0,
29408 (LN->hasOffset()) ? LocationSize::precise(LN->getSize())
29409 : LocationSize::beforeOrAfterPointer(),
29410 (MachineMemOperand *)nullptr};
29411 // Default.
29412 return {false /*isvolatile*/,
29413 /*isAtomic*/ false,
29414 SDValue(),
29415 (int64_t)0 /*offset*/,
29416 LocationSize::beforeOrAfterPointer() /*size*/,
29417 (MachineMemOperand *)nullptr};
29418 };
29419
29420 MemUseCharacteristics MUC0 = getCharacteristics(Op0),
29421 MUC1 = getCharacteristics(Op1);
29422
29423 // If they are to the same address, then they must be aliases.
29424 if (MUC0.BasePtr.getNode() && MUC0.BasePtr == MUC1.BasePtr &&
29425 MUC0.Offset == MUC1.Offset)
29426 return true;
29427
29428 // If they are both volatile then they cannot be reordered.
29429 if (MUC0.IsVolatile && MUC1.IsVolatile)
29430 return true;
29431
29432 // Be conservative about atomics for the moment
29433 // TODO: This is way overconservative for unordered atomics (see D66309)
29434 if (MUC0.IsAtomic && MUC1.IsAtomic)
29435 return true;
29436
29437 if (MUC0.MMO && MUC1.MMO) {
29438 if ((MUC0.MMO->isInvariant() && MUC1.MMO->isStore()) ||
29439 (MUC1.MMO->isInvariant() && MUC0.MMO->isStore()))
29440 return false;
29441 }
29442
29443 // If NumBytes is scalable and offset is not 0, conservatively return may
29444 // alias
29445 if ((MUC0.NumBytes.hasValue() && MUC0.NumBytes.isScalable() &&
29446 MUC0.Offset != 0) ||
29447 (MUC1.NumBytes.hasValue() && MUC1.NumBytes.isScalable() &&
29448 MUC1.Offset != 0))
29449 return true;
29450 // Try to prove that there is aliasing, or that there is no aliasing. Either
29451 // way, we can return now. If nothing can be proved, proceed with more tests.
29452 bool IsAlias;
29453 if (BaseIndexOffset::computeAliasing(Op0, MUC0.NumBytes, Op1, MUC1.NumBytes,
29454 DAG, IsAlias))
29455 return IsAlias;
29456
29457 // The following all rely on MMO0 and MMO1 being valid. Fail conservatively if
29458 // either are not known.
29459 if (!MUC0.MMO || !MUC1.MMO)
29460 return true;
29461
29462 // If one operation reads from invariant memory, and the other may store, they
29463 // cannot alias. These should really be checking the equivalent of mayWrite,
29464 // but it only matters for memory nodes other than load /store.
29465 if ((MUC0.MMO->isInvariant() && MUC1.MMO->isStore()) ||
29466 (MUC1.MMO->isInvariant() && MUC0.MMO->isStore()))
29467 return false;
29468
29469 // If we know required SrcValue1 and SrcValue2 have relatively large
29470 // alignment compared to the size and offset of the access, we may be able
29471 // to prove they do not alias. This check is conservative for now to catch
29472 // cases created by splitting vector types, it only works when the offsets are
29473 // multiples of the size of the data.
29474 int64_t SrcValOffset0 = MUC0.MMO->getOffset();
29475 int64_t SrcValOffset1 = MUC1.MMO->getOffset();
29476 Align OrigAlignment0 = MUC0.MMO->getBaseAlign();
29477 Align OrigAlignment1 = MUC1.MMO->getBaseAlign();
29478 LocationSize Size0 = MUC0.NumBytes;
29479 LocationSize Size1 = MUC1.NumBytes;
29480
29481 if (OrigAlignment0 == OrigAlignment1 && SrcValOffset0 != SrcValOffset1 &&
29482 Size0.hasValue() && Size1.hasValue() && !Size0.isScalable() &&
29483 !Size1.isScalable() && Size0 == Size1 &&
29484 OrigAlignment0 > Size0.getValue().getKnownMinValue() &&
29485 SrcValOffset0 % Size0.getValue().getKnownMinValue() == 0 &&
29486 SrcValOffset1 % Size1.getValue().getKnownMinValue() == 0) {
29487 int64_t OffAlign0 = SrcValOffset0 % OrigAlignment0.value();
29488 int64_t OffAlign1 = SrcValOffset1 % OrigAlignment1.value();
29489
29490 // There is no overlap between these relatively aligned accesses of
29491 // similar size. Return no alias.
29492 if ((OffAlign0 + static_cast<int64_t>(
29493 Size0.getValue().getKnownMinValue())) <= OffAlign1 ||
29494 (OffAlign1 + static_cast<int64_t>(
29495 Size1.getValue().getKnownMinValue())) <= OffAlign0)
29496 return false;
29497 }
29498
29499 bool UseAA = CombinerGlobalAA.getNumOccurrences() > 0
29500 ? CombinerGlobalAA
29501 : DAG.getSubtarget().useAA();
29502 #ifndef NDEBUG
29503 if (CombinerAAOnlyFunc.getNumOccurrences() &&
29504 CombinerAAOnlyFunc != DAG.getMachineFunction().getName())
29505 UseAA = false;
29506 #endif
29507
29508 if (UseAA && BatchAA && MUC0.MMO->getValue() && MUC1.MMO->getValue() &&
29509 Size0.hasValue() && Size1.hasValue() &&
29510 // Can't represent a scalable size + fixed offset in LocationSize
29511 (!Size0.isScalable() || SrcValOffset0 == 0) &&
29512 (!Size1.isScalable() || SrcValOffset1 == 0)) {
29513 // Use alias analysis information.
29514 int64_t MinOffset = std::min(SrcValOffset0, SrcValOffset1);
29515 int64_t Overlap0 =
29516 Size0.getValue().getKnownMinValue() + SrcValOffset0 - MinOffset;
29517 int64_t Overlap1 =
29518 Size1.getValue().getKnownMinValue() + SrcValOffset1 - MinOffset;
29519 LocationSize Loc0 =
29520 Size0.isScalable() ? Size0 : LocationSize::precise(Overlap0);
29521 LocationSize Loc1 =
29522 Size1.isScalable() ? Size1 : LocationSize::precise(Overlap1);
29523 if (BatchAA->isNoAlias(
29524 MemoryLocation(MUC0.MMO->getValue(), Loc0,
29525 UseTBAA ? MUC0.MMO->getAAInfo() : AAMDNodes()),
29526 MemoryLocation(MUC1.MMO->getValue(), Loc1,
29527 UseTBAA ? MUC1.MMO->getAAInfo() : AAMDNodes())))
29528 return false;
29529 }
29530
29531 // Otherwise we have to assume they alias.
29532 return true;
29533 }
29534
29535 /// Walk up chain skipping non-aliasing memory nodes,
29536 /// looking for aliasing nodes and adding them to the Aliases vector.
GatherAllAliases(SDNode * N,SDValue OriginalChain,SmallVectorImpl<SDValue> & Aliases)29537 void DAGCombiner::GatherAllAliases(SDNode *N, SDValue OriginalChain,
29538 SmallVectorImpl<SDValue> &Aliases) {
29539 SmallVector<SDValue, 8> Chains; // List of chains to visit.
29540 SmallPtrSet<SDNode *, 16> Visited; // Visited node set.
29541
29542 // Get alias information for node.
29543 // TODO: relax aliasing for unordered atomics (see D66309)
29544 const bool IsLoad = isa<LoadSDNode>(N) && cast<LoadSDNode>(N)->isSimple();
29545
29546 // Starting off.
29547 Chains.push_back(OriginalChain);
29548 unsigned Depth = 0;
29549
29550 // Attempt to improve chain by a single step
29551 auto ImproveChain = [&](SDValue &C) -> bool {
29552 switch (C.getOpcode()) {
29553 case ISD::EntryToken:
29554 // No need to mark EntryToken.
29555 C = SDValue();
29556 return true;
29557 case ISD::LOAD:
29558 case ISD::STORE: {
29559 // Get alias information for C.
29560 // TODO: Relax aliasing for unordered atomics (see D66309)
29561 bool IsOpLoad = isa<LoadSDNode>(C.getNode()) &&
29562 cast<LSBaseSDNode>(C.getNode())->isSimple();
29563 if ((IsLoad && IsOpLoad) || !mayAlias(N, C.getNode())) {
29564 // Look further up the chain.
29565 C = C.getOperand(0);
29566 return true;
29567 }
29568 // Alias, so stop here.
29569 return false;
29570 }
29571
29572 case ISD::CopyFromReg:
29573 // Always forward past CopyFromReg.
29574 C = C.getOperand(0);
29575 return true;
29576
29577 case ISD::LIFETIME_START:
29578 case ISD::LIFETIME_END: {
29579 // We can forward past any lifetime start/end that can be proven not to
29580 // alias the memory access.
29581 if (!mayAlias(N, C.getNode())) {
29582 // Look further up the chain.
29583 C = C.getOperand(0);
29584 return true;
29585 }
29586 return false;
29587 }
29588 default:
29589 return false;
29590 }
29591 };
29592
29593 // Look at each chain and determine if it is an alias. If so, add it to the
29594 // aliases list. If not, then continue up the chain looking for the next
29595 // candidate.
29596 while (!Chains.empty()) {
29597 SDValue Chain = Chains.pop_back_val();
29598
29599 // Don't bother if we've seen Chain before.
29600 if (!Visited.insert(Chain.getNode()).second)
29601 continue;
29602
29603 // For TokenFactor nodes, look at each operand and only continue up the
29604 // chain until we reach the depth limit.
29605 //
29606 // FIXME: The depth check could be made to return the last non-aliasing
29607 // chain we found before we hit a tokenfactor rather than the original
29608 // chain.
29609 if (Depth > TLI.getGatherAllAliasesMaxDepth()) {
29610 Aliases.clear();
29611 Aliases.push_back(OriginalChain);
29612 return;
29613 }
29614
29615 if (Chain.getOpcode() == ISD::TokenFactor) {
29616 // We have to check each of the operands of the token factor for "small"
29617 // token factors, so we queue them up. Adding the operands to the queue
29618 // (stack) in reverse order maintains the original order and increases the
29619 // likelihood that getNode will find a matching token factor (CSE.)
29620 if (Chain.getNumOperands() > 16) {
29621 Aliases.push_back(Chain);
29622 continue;
29623 }
29624 for (unsigned n = Chain.getNumOperands(); n;)
29625 Chains.push_back(Chain.getOperand(--n));
29626 ++Depth;
29627 continue;
29628 }
29629 // Everything else
29630 if (ImproveChain(Chain)) {
29631 // Updated Chain Found, Consider new chain if one exists.
29632 if (Chain.getNode())
29633 Chains.push_back(Chain);
29634 ++Depth;
29635 continue;
29636 }
29637 // No Improved Chain Possible, treat as Alias.
29638 Aliases.push_back(Chain);
29639 }
29640 }
29641
29642 /// Walk up chain skipping non-aliasing memory nodes, looking for a better chain
29643 /// (aliasing node.)
FindBetterChain(SDNode * N,SDValue OldChain)29644 SDValue DAGCombiner::FindBetterChain(SDNode *N, SDValue OldChain) {
29645 if (OptLevel == CodeGenOptLevel::None)
29646 return OldChain;
29647
29648 // Ops for replacing token factor.
29649 SmallVector<SDValue, 8> Aliases;
29650
29651 // Accumulate all the aliases to this node.
29652 GatherAllAliases(N, OldChain, Aliases);
29653
29654 // If no operands then chain to entry token.
29655 if (Aliases.empty())
29656 return DAG.getEntryNode();
29657
29658 // If a single operand then chain to it. We don't need to revisit it.
29659 if (Aliases.size() == 1)
29660 return Aliases[0];
29661
29662 // Construct a custom tailored token factor.
29663 return DAG.getTokenFactor(SDLoc(N), Aliases);
29664 }
29665
29666 // This function tries to collect a bunch of potentially interesting
29667 // nodes to improve the chains of, all at once. This might seem
29668 // redundant, as this function gets called when visiting every store
29669 // node, so why not let the work be done on each store as it's visited?
29670 //
29671 // I believe this is mainly important because mergeConsecutiveStores
29672 // is unable to deal with merging stores of different sizes, so unless
29673 // we improve the chains of all the potential candidates up-front
29674 // before running mergeConsecutiveStores, it might only see some of
29675 // the nodes that will eventually be candidates, and then not be able
29676 // to go from a partially-merged state to the desired final
29677 // fully-merged state.
29678
parallelizeChainedStores(StoreSDNode * St)29679 bool DAGCombiner::parallelizeChainedStores(StoreSDNode *St) {
29680 SmallVector<StoreSDNode *, 8> ChainedStores;
29681 StoreSDNode *STChain = St;
29682 // Intervals records which offsets from BaseIndex have been covered. In
29683 // the common case, every store writes to the immediately previous address
29684 // space and thus merged with the previous interval at insertion time.
29685
29686 using IMap = llvm::IntervalMap<int64_t, std::monostate, 8,
29687 IntervalMapHalfOpenInfo<int64_t>>;
29688 IMap::Allocator A;
29689 IMap Intervals(A);
29690
29691 // This holds the base pointer, index, and the offset in bytes from the base
29692 // pointer.
29693 const BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);
29694
29695 // We must have a base and an offset.
29696 if (!BasePtr.getBase().getNode())
29697 return false;
29698
29699 // Do not handle stores to undef base pointers.
29700 if (BasePtr.getBase().isUndef())
29701 return false;
29702
29703 // Do not handle stores to opaque types
29704 if (St->getMemoryVT().isZeroSized())
29705 return false;
29706
29707 // BaseIndexOffset assumes that offsets are fixed-size, which
29708 // is not valid for scalable vectors where the offsets are
29709 // scaled by `vscale`, so bail out early.
29710 if (St->getMemoryVT().isScalableVT())
29711 return false;
29712
29713 // Add ST's interval.
29714 Intervals.insert(0, (St->getMemoryVT().getSizeInBits() + 7) / 8,
29715 std::monostate{});
29716
29717 while (StoreSDNode *Chain = dyn_cast<StoreSDNode>(STChain->getChain())) {
29718 if (Chain->getMemoryVT().isScalableVector())
29719 return false;
29720
29721 // If the chain has more than one use, then we can't reorder the mem ops.
29722 if (!SDValue(Chain, 0)->hasOneUse())
29723 break;
29724 // TODO: Relax for unordered atomics (see D66309)
29725 if (!Chain->isSimple() || Chain->isIndexed())
29726 break;
29727
29728 // Find the base pointer and offset for this memory node.
29729 const BaseIndexOffset Ptr = BaseIndexOffset::match(Chain, DAG);
29730 // Check that the base pointer is the same as the original one.
29731 int64_t Offset;
29732 if (!BasePtr.equalBaseIndex(Ptr, DAG, Offset))
29733 break;
29734 int64_t Length = (Chain->getMemoryVT().getSizeInBits() + 7) / 8;
29735 // Make sure we don't overlap with other intervals by checking the ones to
29736 // the left or right before inserting.
29737 auto I = Intervals.find(Offset);
29738 // If there's a next interval, we should end before it.
29739 if (I != Intervals.end() && I.start() < (Offset + Length))
29740 break;
29741 // If there's a previous interval, we should start after it.
29742 if (I != Intervals.begin() && (--I).stop() <= Offset)
29743 break;
29744 Intervals.insert(Offset, Offset + Length, std::monostate{});
29745
29746 ChainedStores.push_back(Chain);
29747 STChain = Chain;
29748 }
29749
29750 // If we didn't find a chained store, exit.
29751 if (ChainedStores.empty())
29752 return false;
29753
29754 // Improve all chained stores (St and ChainedStores members) starting from
29755 // where the store chain ended and return single TokenFactor.
29756 SDValue NewChain = STChain->getChain();
29757 SmallVector<SDValue, 8> TFOps;
29758 for (unsigned I = ChainedStores.size(); I;) {
29759 StoreSDNode *S = ChainedStores[--I];
29760 SDValue BetterChain = FindBetterChain(S, NewChain);
29761 S = cast<StoreSDNode>(DAG.UpdateNodeOperands(
29762 S, BetterChain, S->getOperand(1), S->getOperand(2), S->getOperand(3)));
29763 TFOps.push_back(SDValue(S, 0));
29764 ChainedStores[I] = S;
29765 }
29766
29767 // Improve St's chain. Use a new node to avoid creating a loop from CombineTo.
29768 SDValue BetterChain = FindBetterChain(St, NewChain);
29769 SDValue NewST;
29770 if (St->isTruncatingStore())
29771 NewST = DAG.getTruncStore(BetterChain, SDLoc(St), St->getValue(),
29772 St->getBasePtr(), St->getMemoryVT(),
29773 St->getMemOperand());
29774 else
29775 NewST = DAG.getStore(BetterChain, SDLoc(St), St->getValue(),
29776 St->getBasePtr(), St->getMemOperand());
29777
29778 TFOps.push_back(NewST);
29779
29780 // If we improved every element of TFOps, then we've lost the dependence on
29781 // NewChain to successors of St and we need to add it back to TFOps. Do so at
29782 // the beginning to keep relative order consistent with FindBetterChains.
29783 auto hasImprovedChain = [&](SDValue ST) -> bool {
29784 return ST->getOperand(0) != NewChain;
29785 };
29786 bool AddNewChain = llvm::all_of(TFOps, hasImprovedChain);
29787 if (AddNewChain)
29788 TFOps.insert(TFOps.begin(), NewChain);
29789
29790 SDValue TF = DAG.getTokenFactor(SDLoc(STChain), TFOps);
29791 CombineTo(St, TF);
29792
29793 // Add TF and its operands to the worklist.
29794 AddToWorklist(TF.getNode());
29795 for (const SDValue &Op : TF->ops())
29796 AddToWorklist(Op.getNode());
29797 AddToWorklist(STChain);
29798 return true;
29799 }
29800
findBetterNeighborChains(StoreSDNode * St)29801 bool DAGCombiner::findBetterNeighborChains(StoreSDNode *St) {
29802 if (OptLevel == CodeGenOptLevel::None)
29803 return false;
29804
29805 const BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);
29806
29807 // We must have a base and an offset.
29808 if (!BasePtr.getBase().getNode())
29809 return false;
29810
29811 // Do not handle stores to undef base pointers.
29812 if (BasePtr.getBase().isUndef())
29813 return false;
29814
29815 // Directly improve a chain of disjoint stores starting at St.
29816 if (parallelizeChainedStores(St))
29817 return true;
29818
29819 // Improve St's Chain..
29820 SDValue BetterChain = FindBetterChain(St, St->getChain());
29821 if (St->getChain() != BetterChain) {
29822 replaceStoreChain(St, BetterChain);
29823 return true;
29824 }
29825 return false;
29826 }
29827
29828 /// This is the entry point for the file.
Combine(CombineLevel Level,BatchAAResults * BatchAA,CodeGenOptLevel OptLevel)29829 void SelectionDAG::Combine(CombineLevel Level, BatchAAResults *BatchAA,
29830 CodeGenOptLevel OptLevel) {
29831 /// This is the main entry point to this class.
29832 DAGCombiner(*this, BatchAA, OptLevel).Run(Level);
29833 }
29834