xref: /freebsd/contrib/llvm-project/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp (revision 770cf0a5f02dc8983a89c6568d741fbc25baa999)
1 //===- ComplexDeinterleavingPass.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Identification:
10 // This step is responsible for finding the patterns that can be lowered to
11 // complex instructions, and building a graph to represent the complex
12 // structures. Starting from the "Converging Shuffle" (a shuffle that
13 // reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
14 // operands are evaluated and identified as "Composite Nodes" (collections of
15 // instructions that can potentially be lowered to a single complex
16 // instruction). This is performed by checking the real and imaginary components
17 // and tracking the data flow for each component while following the operand
18 // pairs. Validity of each node is expected to be done upon creation, and any
19 // validation errors should halt traversal and prevent further graph
20 // construction.
21 // Instead of relying on Shuffle operations, vector interleaving and
22 // deinterleaving can be represented by vector.interleave2 and
23 // vector.deinterleave2 intrinsics. Scalable vectors can be represented only by
24 // these intrinsics, whereas, fixed-width vectors are recognized for both
25 // shufflevector instruction and intrinsics.
26 //
27 // Replacement:
28 // This step traverses the graph built up by identification, delegating to the
29 // target to validate and generate the correct intrinsics, and plumbs them
30 // together connecting each end of the new intrinsics graph to the existing
31 // use-def chain. This step is assumed to finish successfully, as all
32 // information is expected to be correct by this point.
33 //
34 //
35 // Internal data structure:
36 // ComplexDeinterleavingGraph:
37 // Keeps references to all the valid CompositeNodes formed as part of the
38 // transformation, and every Instruction contained within said nodes. It also
39 // holds onto a reference to the root Instruction, and the root node that should
40 // replace it.
41 //
42 // ComplexDeinterleavingCompositeNode:
43 // A CompositeNode represents a single transformation point; each node should
44 // transform into a single complex instruction (ignoring vector splitting, which
45 // would generate more instructions per node). They are identified in a
46 // depth-first manner, traversing and identifying the operands of each
47 // instruction in the order they appear in the IR.
// Each node maintains a reference to its Real and Imaginary instructions,
// as well as any additional instructions that make up the identified operation
// (internal instructions should only have uses within their containing node).
// A node also contains the rotation and operation type that it represents.
// Operands contains pointers to other CompositeNodes, acting as the edges in
// the graph. ReplacementNode is the transformed Value* that has been emitted
// to the IR.
55 //
// Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
// ReplacementNode fields of that Node are relevant, and ReplacementNode
// should be pre-populated.
59 //
60 //===----------------------------------------------------------------------===//
61 
62 #include "llvm/CodeGen/ComplexDeinterleavingPass.h"
63 #include "llvm/ADT/MapVector.h"
64 #include "llvm/ADT/Statistic.h"
65 #include "llvm/Analysis/TargetLibraryInfo.h"
66 #include "llvm/Analysis/TargetTransformInfo.h"
67 #include "llvm/CodeGen/TargetLowering.h"
68 #include "llvm/CodeGen/TargetSubtargetInfo.h"
69 #include "llvm/IR/IRBuilder.h"
70 #include "llvm/IR/PatternMatch.h"
71 #include "llvm/InitializePasses.h"
72 #include "llvm/Target/TargetMachine.h"
73 #include "llvm/Transforms/Utils/Local.h"
74 #include <algorithm>
75 
76 using namespace llvm;
77 using namespace PatternMatch;
78 
79 #define DEBUG_TYPE "complex-deinterleaving"
80 
81 STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");
82 
83 static cl::opt<bool> ComplexDeinterleavingEnabled(
84     "enable-complex-deinterleaving",
85     cl::desc("Enable generation of complex instructions"), cl::init(true),
86     cl::Hidden);
87 
88 /// Checks the given mask, and determines whether said mask is interleaving.
89 ///
90 /// To be interleaving, a mask must alternate between `i` and `i + (Length /
91 /// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
92 /// 4x vector interleaving mask would be <0, 2, 1, 3>).
93 static bool isInterleavingMask(ArrayRef<int> Mask);
94 
95 /// Checks the given mask, and determines whether said mask is deinterleaving.
96 ///
97 /// To be deinterleaving, a mask must increment in steps of 2, and either start
98 /// with 0 or 1.
99 /// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
100 /// <1, 3, 5, 7>).
101 static bool isDeinterleavingMask(ArrayRef<int> Mask);
102 
103 /// Returns true if the operation is a negation of V, and it works for both
104 /// integers and floats.
105 static bool isNeg(Value *V);
106 
107 /// Returns the operand for negation operation.
108 static Value *getNegOperand(Value *V);
109 
110 namespace {
111 template <typename T, typename IterT>
112 std::optional<T> findCommonBetweenCollections(IterT A, IterT B) {
113   auto Common = llvm::find_if(A, [B](T I) { return llvm::is_contained(B, I); });
114   if (Common != A.end())
115     return std::make_optional(*Common);
116   return std::nullopt;
117 }
118 
119 class ComplexDeinterleavingLegacyPass : public FunctionPass {
120 public:
121   static char ID;
122 
123   ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
124       : FunctionPass(ID), TM(TM) {
125     initializeComplexDeinterleavingLegacyPassPass(
126         *PassRegistry::getPassRegistry());
127   }
128 
129   StringRef getPassName() const override {
130     return "Complex Deinterleaving Pass";
131   }
132 
133   bool runOnFunction(Function &F) override;
134   void getAnalysisUsage(AnalysisUsage &AU) const override {
135     AU.addRequired<TargetLibraryInfoWrapperPass>();
136     AU.setPreservesCFG();
137   }
138 
139 private:
140   const TargetMachine *TM;
141 };
142 
143 class ComplexDeinterleavingGraph;
144 struct ComplexDeinterleavingCompositeNode {
145 
146   ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
147                                      Value *R, Value *I)
148       : Operation(Op), Real(R), Imag(I) {}
149 
150 private:
151   friend class ComplexDeinterleavingGraph;
152   using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
153   using RawNodePtr = ComplexDeinterleavingCompositeNode *;
154   bool OperandsValid = true;
155 
156 public:
157   ComplexDeinterleavingOperation Operation;
158   Value *Real;
159   Value *Imag;
160 
161   // This two members are required exclusively for generating
162   // ComplexDeinterleavingOperation::Symmetric operations.
163   unsigned Opcode;
164   std::optional<FastMathFlags> Flags;
165 
166   ComplexDeinterleavingRotation Rotation =
167       ComplexDeinterleavingRotation::Rotation_0;
168   SmallVector<RawNodePtr> Operands;
169   Value *ReplacementNode = nullptr;
170 
171   void addOperand(NodePtr Node) {
172     if (!Node || !Node.get())
173       OperandsValid = false;
174     Operands.push_back(Node.get());
175   }
176 
177   void dump() { dump(dbgs()); }
178   void dump(raw_ostream &OS) {
179     auto PrintValue = [&](Value *V) {
180       if (V) {
181         OS << "\"";
182         V->print(OS, true);
183         OS << "\"\n";
184       } else
185         OS << "nullptr\n";
186     };
187     auto PrintNodeRef = [&](RawNodePtr Ptr) {
188       if (Ptr)
189         OS << Ptr << "\n";
190       else
191         OS << "nullptr\n";
192     };
193 
194     OS << "- CompositeNode: " << this << "\n";
195     OS << "  Real: ";
196     PrintValue(Real);
197     OS << "  Imag: ";
198     PrintValue(Imag);
199     OS << "  ReplacementNode: ";
200     PrintValue(ReplacementNode);
201     OS << "  Operation: " << (int)Operation << "\n";
202     OS << "  Rotation: " << ((int)Rotation * 90) << "\n";
203     OS << "  Operands: \n";
204     for (const auto &Op : Operands) {
205       OS << "    - ";
206       PrintNodeRef(Op);
207     }
208   }
209 
210   bool areOperandsValid() { return OperandsValid; }
211 };
212 
213 class ComplexDeinterleavingGraph {
214 public:
215   struct Product {
216     Value *Multiplier;
217     Value *Multiplicand;
218     bool IsPositive;
219   };
220 
221   using Addend = std::pair<Value *, bool>;
222   using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
223   using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;
224 
225   // Helper struct for holding info about potential partial multiplication
226   // candidates
227   struct PartialMulCandidate {
228     Value *Common;
229     NodePtr Node;
230     unsigned RealIdx;
231     unsigned ImagIdx;
232     bool IsNodeInverted;
233   };
234 
235   explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
236                                       const TargetLibraryInfo *TLI)
237       : TL(TL), TLI(TLI) {}
238 
239 private:
240   const TargetLowering *TL = nullptr;
241   const TargetLibraryInfo *TLI = nullptr;
242   SmallVector<NodePtr> CompositeNodes;
243   DenseMap<std::pair<Value *, Value *>, NodePtr> CachedResult;
244 
245   SmallPtrSet<Instruction *, 16> FinalInstructions;
246 
247   /// Root instructions are instructions from which complex computation starts
248   std::map<Instruction *, NodePtr> RootToNode;
249 
250   /// Topologically sorted root instructions
251   SmallVector<Instruction *, 1> OrderedRoots;
252 
253   /// When examining a basic block for complex deinterleaving, if it is a simple
254   /// one-block loop, then the only incoming block is 'Incoming' and the
255   /// 'BackEdge' block is the block itself."
256   BasicBlock *BackEdge = nullptr;
257   BasicBlock *Incoming = nullptr;
258 
259   /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction
260   /// %OutsideUser as it is shown in the IR:
261   ///
262   /// vector.body:
263   ///   %PHInode = phi <vector type> [ zeroinitializer, %entry ],
264   ///                                [ %ReductionOp, %vector.body ]
265   ///   ...
266   ///   %ReductionOp = fadd i64 ...
267   ///   ...
268   ///   br i1 %condition, label %vector.body, %middle.block
269   ///
270   /// middle.block:
271   ///   %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp)
272   ///
273   /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding
274   /// `llvm.vector.reduce.fadd` when unroll factor isn't one.
275   MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;
276 
277   /// In the process of detecting a reduction, we consider a pair of
278   /// %ReductionOP, which we refer to as real and imag (or vice versa), and
279   /// traverse the use-tree to detect complex operations. As this is a reduction
280   /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds
281   /// to the %ReductionOPs that we suspect to be complex.
282   /// RealPHI and ImagPHI are used by the identifyPHINode method.
283   PHINode *RealPHI = nullptr;
284   PHINode *ImagPHI = nullptr;
285 
286   /// Set this flag to true if RealPHI and ImagPHI were reached during reduction
287   /// detection.
288   bool PHIsFound = false;
289 
290   /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode.
291   /// The new PHINode corresponds to a vector of deinterleaved complex numbers.
292   /// This mapping is populated during
293   /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then
294   /// used in the ComplexDeinterleavingOperation::ReductionOperation node
295   /// replacement process.
296   std::map<PHINode *, PHINode *> OldToNewPHI;
297 
298   NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
299                                Value *R, Value *I) {
300     assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
301              Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
302             (R && I)) &&
303            "Reduction related nodes must have Real and Imaginary parts");
304     return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R,
305                                                                 I);
306   }
307 
308   NodePtr submitCompositeNode(NodePtr Node) {
309     CompositeNodes.push_back(Node);
310     if (Node->Real)
311       CachedResult[{Node->Real, Node->Imag}] = Node;
312     return Node;
313   }
314 
315   /// Identifies a complex partial multiply pattern and its rotation, based on
316   /// the following patterns
317   ///
318   ///  0:  r: cr + ar * br
319   ///      i: ci + ar * bi
320   /// 90:  r: cr - ai * bi
321   ///      i: ci + ai * br
322   /// 180: r: cr - ar * br
323   ///      i: ci - ar * bi
324   /// 270: r: cr + ai * bi
325   ///      i: ci - ai * br
326   NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag);
327 
328   /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
329   /// is partially known from identifyPartialMul, filling in the other half of
330   /// the complex pair.
331   NodePtr
332   identifyNodeWithImplicitAdd(Instruction *I, Instruction *J,
333                               std::pair<Value *, Value *> &CommonOperandI);
334 
335   /// Identifies a complex add pattern and its rotation, based on the following
336   /// patterns.
337   ///
338   /// 90:  r: ar - bi
339   ///      i: ai + br
340   /// 270: r: ar + bi
341   ///      i: ai - br
342   NodePtr identifyAdd(Instruction *Real, Instruction *Imag);
343   NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag);
344   NodePtr identifyPartialReduction(Value *R, Value *I);
345   NodePtr identifyDotProduct(Value *Inst);
346 
347   NodePtr identifyNode(Value *R, Value *I);
348 
349   /// Determine if a sum of complex numbers can be formed from \p RealAddends
350   /// and \p ImagAddens. If \p Accumulator is not null, add the result to it.
351   /// Return nullptr if it is not possible to construct a complex number.
352   /// \p Flags are needed to generate symmetric Add and Sub operations.
353   NodePtr identifyAdditions(std::list<Addend> &RealAddends,
354                             std::list<Addend> &ImagAddends,
355                             std::optional<FastMathFlags> Flags,
356                             NodePtr Accumulator);
357 
358   /// Extract one addend that have both real and imaginary parts positive.
359   NodePtr extractPositiveAddend(std::list<Addend> &RealAddends,
360                                 std::list<Addend> &ImagAddends);
361 
362   /// Determine if sum of multiplications of complex numbers can be formed from
363   /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result
364   /// to it. Return nullptr if it is not possible to construct a complex number.
365   NodePtr identifyMultiplications(std::vector<Product> &RealMuls,
366                                   std::vector<Product> &ImagMuls,
367                                   NodePtr Accumulator);
368 
369   /// Go through pairs of multiplication (one Real and one Imag) and find all
370   /// possible candidates for partial multiplication and put them into \p
371   /// Candidates. Returns true if all Product has pair with common operand
372   bool collectPartialMuls(const std::vector<Product> &RealMuls,
373                           const std::vector<Product> &ImagMuls,
374                           std::vector<PartialMulCandidate> &Candidates);
375 
376   /// If the code is compiled with -Ofast or expressions have `reassoc` flag,
377   /// the order of complex computation operations may be significantly altered,
378   /// and the real and imaginary parts may not be executed in parallel. This
379   /// function takes this into consideration and employs a more general approach
380   /// to identify complex computations. Initially, it gathers all the addends
381   /// and multiplicands and then constructs a complex expression from them.
382   NodePtr identifyReassocNodes(Instruction *I, Instruction *J);
383 
384   NodePtr identifyRoot(Instruction *I);
385 
386   /// Identifies the Deinterleave operation applied to a vector containing
387   /// complex numbers. There are two ways to represent the Deinterleave
388   /// operation:
389   /// * Using two shufflevectors with even indices for /pReal instruction and
390   /// odd indices for /pImag instructions (only for fixed-width vectors)
391   /// * Using two extractvalue instructions applied to `vector.deinterleave2`
392   /// intrinsic (for both fixed and scalable vectors)
393   NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag);
394 
395   /// identifying the operation that represents a complex number repeated in a
396   /// Splat vector. There are two possible types of splats: ConstantExpr with
397   /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an
398   /// initialization mask with all values set to zero.
399   NodePtr identifySplat(Value *Real, Value *Imag);
400 
401   NodePtr identifyPHINode(Instruction *Real, Instruction *Imag);
402 
403   /// Identifies SelectInsts in a loop that has reduction with predication masks
404   /// and/or predicated tail folding
405   NodePtr identifySelectNode(Instruction *Real, Instruction *Imag);
406 
407   Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node);
408 
409   /// Complete IR modifications after producing new reduction operation:
410   /// * Populate the PHINode generated for
411   /// ComplexDeinterleavingOperation::ReductionPHI
412   /// * Deinterleave the final value outside of the loop and repurpose original
413   /// reduction users
414   void processReductionOperation(Value *OperationReplacement, RawNodePtr Node);
415   void processReductionSingle(Value *OperationReplacement, RawNodePtr Node);
416 
417 public:
418   void dump() { dump(dbgs()); }
419   void dump(raw_ostream &OS) {
420     for (const auto &Node : CompositeNodes)
421       Node->dump(OS);
422   }
423 
424   /// Returns false if the deinterleaving operation should be cancelled for the
425   /// current graph.
426   bool identifyNodes(Instruction *RootI);
427 
428   /// In case \pB is one-block loop, this function seeks potential reductions
429   /// and populates ReductionInfo. Returns true if any reductions were
430   /// identified.
431   bool collectPotentialReductions(BasicBlock *B);
432 
433   void identifyReductionNodes();
434 
435   /// Check that every instruction, from the roots to the leaves, has internal
436   /// uses.
437   bool checkNodes();
438 
439   /// Perform the actual replacement of the underlying instruction graph.
440   void replaceNodes();
441 };
442 
443 class ComplexDeinterleaving {
444 public:
445   ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli)
446       : TL(tl), TLI(tli) {}
447   bool runOnFunction(Function &F);
448 
449 private:
450   bool evaluateBasicBlock(BasicBlock *B);
451 
452   const TargetLowering *TL = nullptr;
453   const TargetLibraryInfo *TLI = nullptr;
454 };
455 
456 } // namespace
457 
458 char ComplexDeinterleavingLegacyPass::ID = 0;
459 
460 INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
461                       "Complex Deinterleaving", false, false)
462 INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
463                     "Complex Deinterleaving", false, false)
464 
465 PreservedAnalyses ComplexDeinterleavingPass::run(Function &F,
466                                                  FunctionAnalysisManager &AM) {
467   const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering();
468   auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F);
469   if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F))
470     return PreservedAnalyses::all();
471 
472   PreservedAnalyses PA;
473   PA.preserve<FunctionAnalysisManagerModuleProxy>();
474   return PA;
475 }
476 
477 FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) {
478   return new ComplexDeinterleavingLegacyPass(TM);
479 }
480 
481 bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {
482   const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
483   auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
484   return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);
485 }
486 
487 bool ComplexDeinterleaving::runOnFunction(Function &F) {
488   if (!ComplexDeinterleavingEnabled) {
489     LLVM_DEBUG(
490         dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
491     return false;
492   }
493 
494   if (!TL->isComplexDeinterleavingSupported()) {
495     LLVM_DEBUG(
496         dbgs() << "Complex deinterleaving has been disabled, target does "
497                   "not support lowering of complex number operations.\n");
498     return false;
499   }
500 
501   bool Changed = false;
502   for (auto &B : F)
503     Changed |= evaluateBasicBlock(&B);
504 
505   return Changed;
506 }
507 
508 static bool isInterleavingMask(ArrayRef<int> Mask) {
509   // If the size is not even, it's not an interleaving mask
510   if ((Mask.size() & 1))
511     return false;
512 
513   int HalfNumElements = Mask.size() / 2;
514   for (int Idx = 0; Idx < HalfNumElements; ++Idx) {
515     int MaskIdx = Idx * 2;
516     if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
517       return false;
518   }
519 
520   return true;
521 }
522 
523 static bool isDeinterleavingMask(ArrayRef<int> Mask) {
524   int Offset = Mask[0];
525   int HalfNumElements = Mask.size() / 2;
526 
527   for (int Idx = 1; Idx < HalfNumElements; ++Idx) {
528     if (Mask[Idx] != (Idx * 2) + Offset)
529       return false;
530   }
531 
532   return true;
533 }
534 
535 bool isNeg(Value *V) {
536   return match(V, m_FNeg(m_Value())) || match(V, m_Neg(m_Value()));
537 }
538 
539 Value *getNegOperand(Value *V) {
540   assert(isNeg(V));
541   auto *I = cast<Instruction>(V);
542   if (I->getOpcode() == Instruction::FNeg)
543     return I->getOperand(0);
544 
545   return I->getOperand(1);
546 }
547 
548 bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
549   ComplexDeinterleavingGraph Graph(TL, TLI);
550   if (Graph.collectPotentialReductions(B))
551     Graph.identifyReductionNodes();
552 
553   for (auto &I : *B)
554     Graph.identifyNodes(&I);
555 
556   if (Graph.checkNodes()) {
557     Graph.replaceNodes();
558     return true;
559   }
560 
561   return false;
562 }
563 
564 ComplexDeinterleavingGraph::NodePtr
565 ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
566     Instruction *Real, Instruction *Imag,
567     std::pair<Value *, Value *> &PartialMatch) {
568   LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
569                     << "\n");
570 
571   if (!Real->hasOneUse() || !Imag->hasOneUse()) {
572     LLVM_DEBUG(dbgs() << "  - Mul operand has multiple uses.\n");
573     return nullptr;
574   }
575 
576   if ((Real->getOpcode() != Instruction::FMul &&
577        Real->getOpcode() != Instruction::Mul) ||
578       (Imag->getOpcode() != Instruction::FMul &&
579        Imag->getOpcode() != Instruction::Mul)) {
580     LLVM_DEBUG(
581         dbgs() << "  - Real or imaginary instruction is not fmul or mul\n");
582     return nullptr;
583   }
584 
585   Value *R0 = Real->getOperand(0);
586   Value *R1 = Real->getOperand(1);
587   Value *I0 = Imag->getOperand(0);
588   Value *I1 = Imag->getOperand(1);
589 
590   // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
591   // rotations and use the operand.
592   unsigned Negs = 0;
593   Value *Op;
594   if (match(R0, m_Neg(m_Value(Op)))) {
595     Negs |= 1;
596     R0 = Op;
597   } else if (match(R1, m_Neg(m_Value(Op)))) {
598     Negs |= 1;
599     R1 = Op;
600   }
601 
602   if (isNeg(I0)) {
603     Negs |= 2;
604     Negs ^= 1;
605     I0 = Op;
606   } else if (match(I1, m_Neg(m_Value(Op)))) {
607     Negs |= 2;
608     Negs ^= 1;
609     I1 = Op;
610   }
611 
612   ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs;
613 
614   Value *CommonOperand;
615   Value *UncommonRealOp;
616   Value *UncommonImagOp;
617 
618   if (R0 == I0 || R0 == I1) {
619     CommonOperand = R0;
620     UncommonRealOp = R1;
621   } else if (R1 == I0 || R1 == I1) {
622     CommonOperand = R1;
623     UncommonRealOp = R0;
624   } else {
625     LLVM_DEBUG(dbgs() << "  - No equal operand\n");
626     return nullptr;
627   }
628 
629   UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
630   if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
631       Rotation == ComplexDeinterleavingRotation::Rotation_270)
632     std::swap(UncommonRealOp, UncommonImagOp);
633 
634   // Between identifyPartialMul and here we need to have found a complete valid
635   // pair from the CommonOperand of each part.
636   if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
637       Rotation == ComplexDeinterleavingRotation::Rotation_180)
638     PartialMatch.first = CommonOperand;
639   else
640     PartialMatch.second = CommonOperand;
641 
642   if (!PartialMatch.first || !PartialMatch.second) {
643     LLVM_DEBUG(dbgs() << "  - Incomplete partial match\n");
644     return nullptr;
645   }
646 
647   NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second);
648   if (!CommonNode) {
649     LLVM_DEBUG(dbgs() << "  - No CommonNode identified\n");
650     return nullptr;
651   }
652 
653   NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
654   if (!UncommonNode) {
655     LLVM_DEBUG(dbgs() << "  - No UncommonNode identified\n");
656     return nullptr;
657   }
658 
659   NodePtr Node = prepareCompositeNode(
660       ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
661   Node->Rotation = Rotation;
662   Node->addOperand(CommonNode);
663   Node->addOperand(UncommonNode);
664   return submitCompositeNode(Node);
665 }
666 
667 ComplexDeinterleavingGraph::NodePtr
668 ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
669                                                Instruction *Imag) {
670   LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
671                     << "\n");
672   // Determine rotation
673   auto IsAdd = [](unsigned Op) {
674     return Op == Instruction::FAdd || Op == Instruction::Add;
675   };
676   auto IsSub = [](unsigned Op) {
677     return Op == Instruction::FSub || Op == Instruction::Sub;
678   };
679   ComplexDeinterleavingRotation Rotation;
680   if (IsAdd(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
681     Rotation = ComplexDeinterleavingRotation::Rotation_0;
682   else if (IsSub(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
683     Rotation = ComplexDeinterleavingRotation::Rotation_90;
684   else if (IsSub(Real->getOpcode()) && IsSub(Imag->getOpcode()))
685     Rotation = ComplexDeinterleavingRotation::Rotation_180;
686   else if (IsAdd(Real->getOpcode()) && IsSub(Imag->getOpcode()))
687     Rotation = ComplexDeinterleavingRotation::Rotation_270;
688   else {
689     LLVM_DEBUG(dbgs() << "  - Unhandled rotation.\n");
690     return nullptr;
691   }
692 
693   if (isa<FPMathOperator>(Real) &&
694       (!Real->getFastMathFlags().allowContract() ||
695        !Imag->getFastMathFlags().allowContract())) {
696     LLVM_DEBUG(dbgs() << "  - Contract is missing from the FastMath flags.\n");
697     return nullptr;
698   }
699 
700   Value *CR = Real->getOperand(0);
701   Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));
702   if (!RealMulI)
703     return nullptr;
704   Value *CI = Imag->getOperand(0);
705   Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));
706   if (!ImagMulI)
707     return nullptr;
708 
709   if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
710     LLVM_DEBUG(dbgs() << "  - Mul instruction has multiple uses\n");
711     return nullptr;
712   }
713 
714   Value *R0 = RealMulI->getOperand(0);
715   Value *R1 = RealMulI->getOperand(1);
716   Value *I0 = ImagMulI->getOperand(0);
717   Value *I1 = ImagMulI->getOperand(1);
718 
719   Value *CommonOperand;
720   Value *UncommonRealOp;
721   Value *UncommonImagOp;
722 
723   if (R0 == I0 || R0 == I1) {
724     CommonOperand = R0;
725     UncommonRealOp = R1;
726   } else if (R1 == I0 || R1 == I1) {
727     CommonOperand = R1;
728     UncommonRealOp = R0;
729   } else {
730     LLVM_DEBUG(dbgs() << "  - No equal operand\n");
731     return nullptr;
732   }
733 
734   UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
735   if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
736       Rotation == ComplexDeinterleavingRotation::Rotation_270)
737     std::swap(UncommonRealOp, UncommonImagOp);
738 
739   std::pair<Value *, Value *> PartialMatch(
740       (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
741        Rotation == ComplexDeinterleavingRotation::Rotation_180)
742           ? CommonOperand
743           : nullptr,
744       (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
745        Rotation == ComplexDeinterleavingRotation::Rotation_270)
746           ? CommonOperand
747           : nullptr);
748 
749   auto *CRInst = dyn_cast<Instruction>(CR);
750   auto *CIInst = dyn_cast<Instruction>(CI);
751 
752   if (!CRInst || !CIInst) {
753     LLVM_DEBUG(dbgs() << "  - Common operands are not instructions.\n");
754     return nullptr;
755   }
756 
757   NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
758   if (!CNode) {
759     LLVM_DEBUG(dbgs() << "  - No cnode identified\n");
760     return nullptr;
761   }
762 
763   NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
764   if (!UncommonRes) {
765     LLVM_DEBUG(dbgs() << "  - No UncommonRes identified\n");
766     return nullptr;
767   }
768 
769   assert(PartialMatch.first && PartialMatch.second);
770   NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second);
771   if (!CommonRes) {
772     LLVM_DEBUG(dbgs() << "  - No CommonRes identified\n");
773     return nullptr;
774   }
775 
776   NodePtr Node = prepareCompositeNode(
777       ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
778   Node->Rotation = Rotation;
779   Node->addOperand(CommonRes);
780   Node->addOperand(UncommonRes);
781   Node->addOperand(CNode);
782   return submitCompositeNode(Node);
783 }
784 
785 ComplexDeinterleavingGraph::NodePtr
786 ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
787   LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");
788 
789   // Determine rotation
790   ComplexDeinterleavingRotation Rotation;
791   if ((Real->getOpcode() == Instruction::FSub &&
792        Imag->getOpcode() == Instruction::FAdd) ||
793       (Real->getOpcode() == Instruction::Sub &&
794        Imag->getOpcode() == Instruction::Add))
795     Rotation = ComplexDeinterleavingRotation::Rotation_90;
796   else if ((Real->getOpcode() == Instruction::FAdd &&
797             Imag->getOpcode() == Instruction::FSub) ||
798            (Real->getOpcode() == Instruction::Add &&
799             Imag->getOpcode() == Instruction::Sub))
800     Rotation = ComplexDeinterleavingRotation::Rotation_270;
801   else {
802     LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
803     return nullptr;
804   }
805 
806   auto *AR = dyn_cast<Instruction>(Real->getOperand(0));
807   auto *BI = dyn_cast<Instruction>(Real->getOperand(1));
808   auto *AI = dyn_cast<Instruction>(Imag->getOperand(0));
809   auto *BR = dyn_cast<Instruction>(Imag->getOperand(1));
810 
811   if (!AR || !AI || !BR || !BI) {
812     LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
813     return nullptr;
814   }
815 
816   NodePtr ResA = identifyNode(AR, AI);
817   if (!ResA) {
818     LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
819     return nullptr;
820   }
821   NodePtr ResB = identifyNode(BR, BI);
822   if (!ResB) {
823     LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
824     return nullptr;
825   }
826 
827   NodePtr Node =
828       prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
829   Node->Rotation = Rotation;
830   Node->addOperand(ResA);
831   Node->addOperand(ResB);
832   return submitCompositeNode(Node);
833 }
834 
835 static bool isInstructionPairAdd(Instruction *A, Instruction *B) {
836   unsigned OpcA = A->getOpcode();
837   unsigned OpcB = B->getOpcode();
838 
839   return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
840          (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
841          (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
842          (OpcA == Instruction::Add && OpcB == Instruction::Sub);
843 }
844 
845 static bool isInstructionPairMul(Instruction *A, Instruction *B) {
846   auto Pattern =
847       m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value()));
848 
849   return match(A, Pattern) && match(B, Pattern);
850 }
851 
852 static bool isInstructionPotentiallySymmetric(Instruction *I) {
853   switch (I->getOpcode()) {
854   case Instruction::FAdd:
855   case Instruction::FSub:
856   case Instruction::FMul:
857   case Instruction::FNeg:
858   case Instruction::Add:
859   case Instruction::Sub:
860   case Instruction::Mul:
861     return true;
862   default:
863     return false;
864   }
865 }
866 
867 ComplexDeinterleavingGraph::NodePtr
868 ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
869                                                        Instruction *Imag) {
870   if (Real->getOpcode() != Imag->getOpcode())
871     return nullptr;
872 
873   if (!isInstructionPotentiallySymmetric(Real) ||
874       !isInstructionPotentiallySymmetric(Imag))
875     return nullptr;
876 
877   auto *R0 = Real->getOperand(0);
878   auto *I0 = Imag->getOperand(0);
879 
880   NodePtr Op0 = identifyNode(R0, I0);
881   NodePtr Op1 = nullptr;
882   if (Op0 == nullptr)
883     return nullptr;
884 
885   if (Real->isBinaryOp()) {
886     auto *R1 = Real->getOperand(1);
887     auto *I1 = Imag->getOperand(1);
888     Op1 = identifyNode(R1, I1);
889     if (Op1 == nullptr)
890       return nullptr;
891   }
892 
893   if (isa<FPMathOperator>(Real) &&
894       Real->getFastMathFlags() != Imag->getFastMathFlags())
895     return nullptr;
896 
897   auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric,
898                                    Real, Imag);
899   Node->Opcode = Real->getOpcode();
900   if (isa<FPMathOperator>(Real))
901     Node->Flags = Real->getFastMathFlags();
902 
903   Node->addOperand(Op0);
904   if (Real->isBinaryOp())
905     Node->addOperand(Op1);
906 
907   return submitCompositeNode(Node);
908 }
909 
910 ComplexDeinterleavingGraph::NodePtr
911 ComplexDeinterleavingGraph::identifyDotProduct(Value *V) {
912 
913   if (!TL->isComplexDeinterleavingOperationSupported(
914           ComplexDeinterleavingOperation::CDot, V->getType())) {
915     LLVM_DEBUG(dbgs() << "Target doesn't support complex deinterleaving "
916                          "operation CDot with the type "
917                       << *V->getType() << "\n");
918     return nullptr;
919   }
920 
921   auto *Inst = cast<Instruction>(V);
922   auto *RealUser = cast<Instruction>(*Inst->user_begin());
923 
924   NodePtr CN =
925       prepareCompositeNode(ComplexDeinterleavingOperation::CDot, Inst, nullptr);
926 
927   NodePtr ANode;
928 
929   const Intrinsic::ID PartialReduceInt =
930       Intrinsic::experimental_vector_partial_reduce_add;
931 
932   Value *AReal = nullptr;
933   Value *AImag = nullptr;
934   Value *BReal = nullptr;
935   Value *BImag = nullptr;
936   Value *Phi = nullptr;
937 
938   auto UnwrapCast = [](Value *V) -> Value * {
939     if (auto *CI = dyn_cast<CastInst>(V))
940       return CI->getOperand(0);
941     return V;
942   };
943 
944   auto PatternRot0 = m_Intrinsic<PartialReduceInt>(
945       m_Intrinsic<PartialReduceInt>(m_Value(Phi),
946                                     m_Mul(m_Value(BReal), m_Value(AReal))),
947       m_Neg(m_Mul(m_Value(BImag), m_Value(AImag))));
948 
949   auto PatternRot270 = m_Intrinsic<PartialReduceInt>(
950       m_Intrinsic<PartialReduceInt>(
951           m_Value(Phi), m_Neg(m_Mul(m_Value(BReal), m_Value(AImag)))),
952       m_Mul(m_Value(BImag), m_Value(AReal)));
953 
954   if (match(Inst, PatternRot0)) {
955     CN->Rotation = ComplexDeinterleavingRotation::Rotation_0;
956   } else if (match(Inst, PatternRot270)) {
957     CN->Rotation = ComplexDeinterleavingRotation::Rotation_270;
958   } else {
959     Value *A0, *A1;
960     // The rotations 90 and 180 share the same operation pattern, so inspect the
961     // order of the operands, identifying where the real and imaginary
962     // components of A go, to discern between the aforementioned rotations.
963     auto PatternRot90Rot180 = m_Intrinsic<PartialReduceInt>(
964         m_Intrinsic<PartialReduceInt>(m_Value(Phi),
965                                       m_Mul(m_Value(BReal), m_Value(A0))),
966         m_Mul(m_Value(BImag), m_Value(A1)));
967 
968     if (!match(Inst, PatternRot90Rot180))
969       return nullptr;
970 
971     A0 = UnwrapCast(A0);
972     A1 = UnwrapCast(A1);
973 
974     // Test if A0 is real/A1 is imag
975     ANode = identifyNode(A0, A1);
976     if (!ANode) {
977       // Test if A0 is imag/A1 is real
978       ANode = identifyNode(A1, A0);
979       // Unable to identify operand components, thus unable to identify rotation
980       if (!ANode)
981         return nullptr;
982       CN->Rotation = ComplexDeinterleavingRotation::Rotation_90;
983       AReal = A1;
984       AImag = A0;
985     } else {
986       AReal = A0;
987       AImag = A1;
988       CN->Rotation = ComplexDeinterleavingRotation::Rotation_180;
989     }
990   }
991 
992   AReal = UnwrapCast(AReal);
993   AImag = UnwrapCast(AImag);
994   BReal = UnwrapCast(BReal);
995   BImag = UnwrapCast(BImag);
996 
997   VectorType *VTy = cast<VectorType>(V->getType());
998   Type *ExpectedOperandTy = VectorType::getSubdividedVectorType(VTy, 2);
999   if (AReal->getType() != ExpectedOperandTy)
1000     return nullptr;
1001   if (AImag->getType() != ExpectedOperandTy)
1002     return nullptr;
1003   if (BReal->getType() != ExpectedOperandTy)
1004     return nullptr;
1005   if (BImag->getType() != ExpectedOperandTy)
1006     return nullptr;
1007 
1008   if (Phi->getType() != VTy && RealUser->getType() != VTy)
1009     return nullptr;
1010 
1011   NodePtr Node = identifyNode(AReal, AImag);
1012 
1013   // In the case that a node was identified to figure out the rotation, ensure
1014   // that trying to identify a node with AReal and AImag post-unwrap results in
1015   // the same node
1016   if (ANode && Node != ANode) {
1017     LLVM_DEBUG(
1018         dbgs()
1019         << "Identified node is different from previously identified node. "
1020            "Unable to confidently generate a complex operation node\n");
1021     return nullptr;
1022   }
1023 
1024   CN->addOperand(Node);
1025   CN->addOperand(identifyNode(BReal, BImag));
1026   CN->addOperand(identifyNode(Phi, RealUser));
1027 
1028   return submitCompositeNode(CN);
1029 }
1030 
1031 ComplexDeinterleavingGraph::NodePtr
1032 ComplexDeinterleavingGraph::identifyPartialReduction(Value *R, Value *I) {
1033   // Partial reductions don't support non-vector types, so check these first
1034   if (!isa<VectorType>(R->getType()) || !isa<VectorType>(I->getType()))
1035     return nullptr;
1036 
1037   if (!R->hasUseList() || !I->hasUseList())
1038     return nullptr;
1039 
1040   auto CommonUser =
1041       findCommonBetweenCollections<Value *>(R->users(), I->users());
1042   if (!CommonUser)
1043     return nullptr;
1044 
1045   auto *IInst = dyn_cast<IntrinsicInst>(*CommonUser);
1046   if (!IInst || IInst->getIntrinsicID() !=
1047                     Intrinsic::experimental_vector_partial_reduce_add)
1048     return nullptr;
1049 
1050   if (NodePtr CN = identifyDotProduct(IInst))
1051     return CN;
1052 
1053   return nullptr;
1054 }
1055 
1056 ComplexDeinterleavingGraph::NodePtr
1057 ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) {
1058   auto It = CachedResult.find({R, I});
1059   if (It != CachedResult.end()) {
1060     LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
1061     return It->second;
1062   }
1063 
1064   if (NodePtr CN = identifyPartialReduction(R, I))
1065     return CN;
1066 
1067   bool IsReduction = RealPHI == R && (!ImagPHI || ImagPHI == I);
1068   if (!IsReduction && R->getType() != I->getType())
1069     return nullptr;
1070 
1071   if (NodePtr CN = identifySplat(R, I))
1072     return CN;
1073 
1074   auto *Real = dyn_cast<Instruction>(R);
1075   auto *Imag = dyn_cast<Instruction>(I);
1076   if (!Real || !Imag)
1077     return nullptr;
1078 
1079   if (NodePtr CN = identifyDeinterleave(Real, Imag))
1080     return CN;
1081 
1082   if (NodePtr CN = identifyPHINode(Real, Imag))
1083     return CN;
1084 
1085   if (NodePtr CN = identifySelectNode(Real, Imag))
1086     return CN;
1087 
1088   auto *VTy = cast<VectorType>(Real->getType());
1089   auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
1090 
1091   bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported(
1092       ComplexDeinterleavingOperation::CMulPartial, NewVTy);
1093   bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported(
1094       ComplexDeinterleavingOperation::CAdd, NewVTy);
1095 
1096   if (HasCMulSupport && isInstructionPairMul(Real, Imag)) {
1097     if (NodePtr CN = identifyPartialMul(Real, Imag))
1098       return CN;
1099   }
1100 
1101   if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) {
1102     if (NodePtr CN = identifyAdd(Real, Imag))
1103       return CN;
1104   }
1105 
1106   if (HasCMulSupport && HasCAddSupport) {
1107     if (NodePtr CN = identifyReassocNodes(Real, Imag))
1108       return CN;
1109   }
1110 
1111   if (NodePtr CN = identifySymmetricOperation(Real, Imag))
1112     return CN;
1113 
1114   LLVM_DEBUG(dbgs() << "  - Not recognised as a valid pattern.\n");
1115   CachedResult[{R, I}] = nullptr;
1116   return nullptr;
1117 }
1118 
1119 ComplexDeinterleavingGraph::NodePtr
1120 ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
1121                                                  Instruction *Imag) {
1122   auto IsOperationSupported = [](unsigned Opcode) -> bool {
1123     return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
1124            Opcode == Instruction::FNeg || Opcode == Instruction::Add ||
1125            Opcode == Instruction::Sub;
1126   };
1127 
1128   if (!IsOperationSupported(Real->getOpcode()) ||
1129       !IsOperationSupported(Imag->getOpcode()))
1130     return nullptr;
1131 
1132   std::optional<FastMathFlags> Flags;
1133   if (isa<FPMathOperator>(Real)) {
1134     if (Real->getFastMathFlags() != Imag->getFastMathFlags()) {
1135       LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are "
1136                            "not identical\n");
1137       return nullptr;
1138     }
1139 
1140     Flags = Real->getFastMathFlags();
1141     if (!Flags->allowReassoc()) {
1142       LLVM_DEBUG(
1143           dbgs()
1144           << "the 'Reassoc' attribute is missing in the FastMath flags\n");
1145       return nullptr;
1146     }
1147   }
1148 
1149   // Collect multiplications and addend instructions from the given instruction
1150   // while traversing it operands. Additionally, verify that all instructions
1151   // have the same fast math flags.
1152   auto Collect = [&Flags](Instruction *Insn, std::vector<Product> &Muls,
1153                           std::list<Addend> &Addends) -> bool {
1154     SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}};
1155     SmallPtrSet<Value *, 8> Visited;
1156     while (!Worklist.empty()) {
1157       auto [V, IsPositive] = Worklist.pop_back_val();
1158       if (!Visited.insert(V).second)
1159         continue;
1160 
1161       Instruction *I = dyn_cast<Instruction>(V);
1162       if (!I) {
1163         Addends.emplace_back(V, IsPositive);
1164         continue;
1165       }
1166 
1167       // If an instruction has more than one user, it indicates that it either
1168       // has an external user, which will be later checked by the checkNodes
1169       // function, or it is a subexpression utilized by multiple expressions. In
1170       // the latter case, we will attempt to separately identify the complex
1171       // operation from here in order to create a shared
1172       // ComplexDeinterleavingCompositeNode.
1173       if (I != Insn && I->hasNUsesOrMore(2)) {
1174         LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n");
1175         Addends.emplace_back(I, IsPositive);
1176         continue;
1177       }
1178       switch (I->getOpcode()) {
1179       case Instruction::FAdd:
1180       case Instruction::Add:
1181         Worklist.emplace_back(I->getOperand(1), IsPositive);
1182         Worklist.emplace_back(I->getOperand(0), IsPositive);
1183         break;
1184       case Instruction::FSub:
1185         Worklist.emplace_back(I->getOperand(1), !IsPositive);
1186         Worklist.emplace_back(I->getOperand(0), IsPositive);
1187         break;
1188       case Instruction::Sub:
1189         if (isNeg(I)) {
1190           Worklist.emplace_back(getNegOperand(I), !IsPositive);
1191         } else {
1192           Worklist.emplace_back(I->getOperand(1), !IsPositive);
1193           Worklist.emplace_back(I->getOperand(0), IsPositive);
1194         }
1195         break;
1196       case Instruction::FMul:
1197       case Instruction::Mul: {
1198         Value *A, *B;
1199         if (isNeg(I->getOperand(0))) {
1200           A = getNegOperand(I->getOperand(0));
1201           IsPositive = !IsPositive;
1202         } else {
1203           A = I->getOperand(0);
1204         }
1205 
1206         if (isNeg(I->getOperand(1))) {
1207           B = getNegOperand(I->getOperand(1));
1208           IsPositive = !IsPositive;
1209         } else {
1210           B = I->getOperand(1);
1211         }
1212         Muls.push_back(Product{A, B, IsPositive});
1213         break;
1214       }
1215       case Instruction::FNeg:
1216         Worklist.emplace_back(I->getOperand(0), !IsPositive);
1217         break;
1218       default:
1219         Addends.emplace_back(I, IsPositive);
1220         continue;
1221       }
1222 
1223       if (Flags && I->getFastMathFlags() != *Flags) {
1224         LLVM_DEBUG(dbgs() << "The instruction's fast math flags are "
1225                              "inconsistent with the root instructions' flags: "
1226                           << *I << "\n");
1227         return false;
1228       }
1229     }
1230     return true;
1231   };
1232 
1233   std::vector<Product> RealMuls, ImagMuls;
1234   std::list<Addend> RealAddends, ImagAddends;
1235   if (!Collect(Real, RealMuls, RealAddends) ||
1236       !Collect(Imag, ImagMuls, ImagAddends))
1237     return nullptr;
1238 
1239   if (RealAddends.size() != ImagAddends.size())
1240     return nullptr;
1241 
1242   NodePtr FinalNode;
1243   if (!RealMuls.empty() || !ImagMuls.empty()) {
1244     // If there are multiplicands, extract positive addend and use it as an
1245     // accumulator
1246     FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
1247     FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);
1248     if (!FinalNode)
1249       return nullptr;
1250   }
1251 
1252   // Identify and process remaining additions
1253   if (!RealAddends.empty() || !ImagAddends.empty()) {
1254     FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);
1255     if (!FinalNode)
1256       return nullptr;
1257   }
1258   assert(FinalNode && "FinalNode can not be nullptr here");
1259   // Set the Real and Imag fields of the final node and submit it
1260   FinalNode->Real = Real;
1261   FinalNode->Imag = Imag;
1262   submitCompositeNode(FinalNode);
1263   return FinalNode;
1264 }
1265 
1266 bool ComplexDeinterleavingGraph::collectPartialMuls(
1267     const std::vector<Product> &RealMuls, const std::vector<Product> &ImagMuls,
1268     std::vector<PartialMulCandidate> &PartialMulCandidates) {
1269   // Helper function to extract a common operand from two products
1270   auto FindCommonInstruction = [](const Product &Real,
1271                                   const Product &Imag) -> Value * {
1272     if (Real.Multiplicand == Imag.Multiplicand ||
1273         Real.Multiplicand == Imag.Multiplier)
1274       return Real.Multiplicand;
1275 
1276     if (Real.Multiplier == Imag.Multiplicand ||
1277         Real.Multiplier == Imag.Multiplier)
1278       return Real.Multiplier;
1279 
1280     return nullptr;
1281   };
1282 
1283   // Iterating over real and imaginary multiplications to find common operands
1284   // If a common operand is found, a partial multiplication candidate is created
1285   // and added to the candidates vector The function returns false if no common
1286   // operands are found for any product
1287   for (unsigned i = 0; i < RealMuls.size(); ++i) {
1288     bool FoundCommon = false;
1289     for (unsigned j = 0; j < ImagMuls.size(); ++j) {
1290       auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
1291       if (!Common)
1292         continue;
1293 
1294       auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
1295                                                    : RealMuls[i].Multiplicand;
1296       auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier
1297                                                    : ImagMuls[j].Multiplicand;
1298 
1299       auto Node = identifyNode(A, B);
1300       if (Node) {
1301         FoundCommon = true;
1302         PartialMulCandidates.push_back({Common, Node, i, j, false});
1303       }
1304 
1305       Node = identifyNode(B, A);
1306       if (Node) {
1307         FoundCommon = true;
1308         PartialMulCandidates.push_back({Common, Node, i, j, true});
1309       }
1310     }
1311     if (!FoundCommon)
1312       return false;
1313   }
1314   return true;
1315 }
1316 
1317 ComplexDeinterleavingGraph::NodePtr
1318 ComplexDeinterleavingGraph::identifyMultiplications(
1319     std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls,
1320     NodePtr Accumulator = nullptr) {
1321   if (RealMuls.size() != ImagMuls.size())
1322     return nullptr;
1323 
1324   std::vector<PartialMulCandidate> Info;
1325   if (!collectPartialMuls(RealMuls, ImagMuls, Info))
1326     return nullptr;
1327 
1328   // Map to store common instruction to node pointers
1329   std::map<Value *, NodePtr> CommonToNode;
1330   std::vector<bool> Processed(Info.size(), false);
1331   for (unsigned I = 0; I < Info.size(); ++I) {
1332     if (Processed[I])
1333       continue;
1334 
1335     PartialMulCandidate &InfoA = Info[I];
1336     for (unsigned J = I + 1; J < Info.size(); ++J) {
1337       if (Processed[J])
1338         continue;
1339 
1340       PartialMulCandidate &InfoB = Info[J];
1341       auto *InfoReal = &InfoA;
1342       auto *InfoImag = &InfoB;
1343 
1344       auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1345       if (!NodeFromCommon) {
1346         std::swap(InfoReal, InfoImag);
1347         NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1348       }
1349       if (!NodeFromCommon)
1350         continue;
1351 
1352       CommonToNode[InfoReal->Common] = NodeFromCommon;
1353       CommonToNode[InfoImag->Common] = NodeFromCommon;
1354       Processed[I] = true;
1355       Processed[J] = true;
1356     }
1357   }
1358 
1359   std::vector<bool> ProcessedReal(RealMuls.size(), false);
1360   std::vector<bool> ProcessedImag(ImagMuls.size(), false);
1361   NodePtr Result = Accumulator;
1362   for (auto &PMI : Info) {
1363     if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
1364       continue;
1365 
1366     auto It = CommonToNode.find(PMI.Common);
1367     // TODO: Process independent complex multiplications. Cases like this:
1368     //  A.real() * B where both A and B are complex numbers.
1369     if (It == CommonToNode.end()) {
1370       LLVM_DEBUG({
1371         dbgs() << "Unprocessed independent partial multiplication:\n";
1372         for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]})
1373           dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier
1374                            << " multiplied by " << *Mul->Multiplicand << "\n";
1375       });
1376       return nullptr;
1377     }
1378 
1379     auto &RealMul = RealMuls[PMI.RealIdx];
1380     auto &ImagMul = ImagMuls[PMI.ImagIdx];
1381 
1382     auto NodeA = It->second;
1383     auto NodeB = PMI.Node;
1384     auto IsMultiplicandReal = PMI.Common == NodeA->Real;
1385     // The following table illustrates the relationship between multiplications
1386     // and rotations. If we consider the multiplication (X + iY) * (U + iV), we
1387     // can see:
1388     //
1389     // Rotation |   Real |   Imag |
1390     // ---------+--------+--------+
1391     //        0 |  x * u |  x * v |
1392     //       90 | -y * v |  y * u |
1393     //      180 | -x * u | -x * v |
1394     //      270 |  y * v | -y * u |
1395     //
1396     // Check if the candidate can indeed be represented by partial
1397     // multiplication
1398     // TODO: Add support for multiplication by complex one
1399     if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
1400         (!IsMultiplicandReal && !PMI.IsNodeInverted))
1401       continue;
1402 
1403     // Determine the rotation based on the multiplications
1404     ComplexDeinterleavingRotation Rotation;
1405     if (IsMultiplicandReal) {
1406       // Detect 0 and 180 degrees rotation
1407       if (RealMul.IsPositive && ImagMul.IsPositive)
1408         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_0;
1409       else if (!RealMul.IsPositive && !ImagMul.IsPositive)
1410         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_180;
1411       else
1412         continue;
1413 
1414     } else {
1415       // Detect 90 and 270 degrees rotation
1416       if (!RealMul.IsPositive && ImagMul.IsPositive)
1417         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_90;
1418       else if (RealMul.IsPositive && !ImagMul.IsPositive)
1419         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_270;
1420       else
1421         continue;
1422     }
1423 
1424     LLVM_DEBUG({
1425       dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n";
1426       dbgs().indent(4) << "X: " << *NodeA->Real << "\n";
1427       dbgs().indent(4) << "Y: " << *NodeA->Imag << "\n";
1428       dbgs().indent(4) << "U: " << *NodeB->Real << "\n";
1429       dbgs().indent(4) << "V: " << *NodeB->Imag << "\n";
1430       dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1431     });
1432 
1433     NodePtr NodeMul = prepareCompositeNode(
1434         ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr);
1435     NodeMul->Rotation = Rotation;
1436     NodeMul->addOperand(NodeA);
1437     NodeMul->addOperand(NodeB);
1438     if (Result)
1439       NodeMul->addOperand(Result);
1440     submitCompositeNode(NodeMul);
1441     Result = NodeMul;
1442     ProcessedReal[PMI.RealIdx] = true;
1443     ProcessedImag[PMI.ImagIdx] = true;
1444   }
1445 
1446   // Ensure all products have been processed, if not return nullptr.
1447   if (!all_of(ProcessedReal, [](bool V) { return V; }) ||
1448       !all_of(ProcessedImag, [](bool V) { return V; })) {
1449 
1450     // Dump debug information about which partial multiplications are not
1451     // processed.
1452     LLVM_DEBUG({
1453       dbgs() << "Unprocessed products (Real):\n";
1454       for (size_t i = 0; i < ProcessedReal.size(); ++i) {
1455         if (!ProcessedReal[i])
1456           dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-")
1457                            << *RealMuls[i].Multiplier << " multiplied by "
1458                            << *RealMuls[i].Multiplicand << "\n";
1459       }
1460       dbgs() << "Unprocessed products (Imag):\n";
1461       for (size_t i = 0; i < ProcessedImag.size(); ++i) {
1462         if (!ProcessedImag[i])
1463           dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-")
1464                            << *ImagMuls[i].Multiplier << " multiplied by "
1465                            << *ImagMuls[i].Multiplicand << "\n";
1466       }
1467     });
1468     return nullptr;
1469   }
1470 
1471   return Result;
1472 }
1473 
1474 ComplexDeinterleavingGraph::NodePtr
1475 ComplexDeinterleavingGraph::identifyAdditions(
1476     std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends,
1477     std::optional<FastMathFlags> Flags, NodePtr Accumulator = nullptr) {
1478   if (RealAddends.size() != ImagAddends.size())
1479     return nullptr;
1480 
1481   NodePtr Result;
1482   // If we have accumulator use it as first addend
1483   if (Accumulator)
1484     Result = Accumulator;
1485   // Otherwise find an element with both positive real and imaginary parts.
1486   else
1487     Result = extractPositiveAddend(RealAddends, ImagAddends);
1488 
1489   if (!Result)
1490     return nullptr;
1491 
1492   while (!RealAddends.empty()) {
1493     auto ItR = RealAddends.begin();
1494     auto [R, IsPositiveR] = *ItR;
1495 
1496     bool FoundImag = false;
1497     for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1498       auto [I, IsPositiveI] = *ItI;
1499       ComplexDeinterleavingRotation Rotation;
1500       if (IsPositiveR && IsPositiveI)
1501         Rotation = ComplexDeinterleavingRotation::Rotation_0;
1502       else if (!IsPositiveR && IsPositiveI)
1503         Rotation = ComplexDeinterleavingRotation::Rotation_90;
1504       else if (!IsPositiveR && !IsPositiveI)
1505         Rotation = ComplexDeinterleavingRotation::Rotation_180;
1506       else
1507         Rotation = ComplexDeinterleavingRotation::Rotation_270;
1508 
1509       NodePtr AddNode;
1510       if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
1511           Rotation == ComplexDeinterleavingRotation::Rotation_180) {
1512         AddNode = identifyNode(R, I);
1513       } else {
1514         AddNode = identifyNode(I, R);
1515       }
1516       if (AddNode) {
1517         LLVM_DEBUG({
1518           dbgs() << "Identified addition:\n";
1519           dbgs().indent(4) << "X: " << *R << "\n";
1520           dbgs().indent(4) << "Y: " << *I << "\n";
1521           dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1522         });
1523 
1524         NodePtr TmpNode;
1525         if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0) {
1526           TmpNode = prepareCompositeNode(
1527               ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1528           if (Flags) {
1529             TmpNode->Opcode = Instruction::FAdd;
1530             TmpNode->Flags = *Flags;
1531           } else {
1532             TmpNode->Opcode = Instruction::Add;
1533           }
1534         } else if (Rotation ==
1535                    llvm::ComplexDeinterleavingRotation::Rotation_180) {
1536           TmpNode = prepareCompositeNode(
1537               ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1538           if (Flags) {
1539             TmpNode->Opcode = Instruction::FSub;
1540             TmpNode->Flags = *Flags;
1541           } else {
1542             TmpNode->Opcode = Instruction::Sub;
1543           }
1544         } else {
1545           TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
1546                                          nullptr, nullptr);
1547           TmpNode->Rotation = Rotation;
1548         }
1549 
1550         TmpNode->addOperand(Result);
1551         TmpNode->addOperand(AddNode);
1552         submitCompositeNode(TmpNode);
1553         Result = TmpNode;
1554         RealAddends.erase(ItR);
1555         ImagAddends.erase(ItI);
1556         FoundImag = true;
1557         break;
1558       }
1559     }
1560     if (!FoundImag)
1561       return nullptr;
1562   }
1563   return Result;
1564 }
1565 
1566 ComplexDeinterleavingGraph::NodePtr
1567 ComplexDeinterleavingGraph::extractPositiveAddend(
1568     std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) {
1569   for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) {
1570     for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1571       auto [R, IsPositiveR] = *ItR;
1572       auto [I, IsPositiveI] = *ItI;
1573       if (IsPositiveR && IsPositiveI) {
1574         auto Result = identifyNode(R, I);
1575         if (Result) {
1576           RealAddends.erase(ItR);
1577           ImagAddends.erase(ItI);
1578           return Result;
1579         }
1580       }
1581     }
1582   }
1583   return nullptr;
1584 }
1585 
1586 bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
1587   // This potential root instruction might already have been recognized as
1588   // reduction. Because RootToNode maps both Real and Imaginary parts to
1589   // CompositeNode we should choose only one either Real or Imag instruction to
1590   // use as an anchor for generating complex instruction.
1591   auto It = RootToNode.find(RootI);
1592   if (It != RootToNode.end()) {
1593     auto RootNode = It->second;
1594     assert(RootNode->Operation ==
1595                ComplexDeinterleavingOperation::ReductionOperation ||
1596            RootNode->Operation ==
1597                ComplexDeinterleavingOperation::ReductionSingle);
1598     // Find out which part, Real or Imag, comes later, and only if we come to
1599     // the latest part, add it to OrderedRoots.
1600     auto *R = cast<Instruction>(RootNode->Real);
1601     auto *I = RootNode->Imag ? cast<Instruction>(RootNode->Imag) : nullptr;
1602 
1603     Instruction *ReplacementAnchor;
1604     if (I)
1605       ReplacementAnchor = R->comesBefore(I) ? I : R;
1606     else
1607       ReplacementAnchor = R;
1608 
1609     if (ReplacementAnchor != RootI)
1610       return false;
1611     OrderedRoots.push_back(RootI);
1612     return true;
1613   }
1614 
1615   auto RootNode = identifyRoot(RootI);
1616   if (!RootNode)
1617     return false;
1618 
1619   LLVM_DEBUG({
1620     Function *F = RootI->getFunction();
1621     BasicBlock *B = RootI->getParent();
1622     dbgs() << "Complex deinterleaving graph for " << F->getName()
1623            << "::" << B->getName() << ".\n";
1624     dump(dbgs());
1625     dbgs() << "\n";
1626   });
1627   RootToNode[RootI] = RootNode;
1628   OrderedRoots.push_back(RootI);
1629   return true;
1630 }
1631 
1632 bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) {
1633   bool FoundPotentialReduction = false;
1634 
1635   auto *Br = dyn_cast<BranchInst>(B->getTerminator());
1636   if (!Br || Br->getNumSuccessors() != 2)
1637     return false;
1638 
1639   // Identify simple one-block loop
1640   if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B)
1641     return false;
1642 
1643   for (auto &PHI : B->phis()) {
1644     if (PHI.getNumIncomingValues() != 2)
1645       continue;
1646 
1647     if (!PHI.getType()->isVectorTy())
1648       continue;
1649 
1650     auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B));
1651     if (!ReductionOp)
1652       continue;
1653 
1654     // Check if final instruction is reduced outside of current block
1655     Instruction *FinalReduction = nullptr;
1656     auto NumUsers = 0u;
1657     for (auto *U : ReductionOp->users()) {
1658       ++NumUsers;
1659       if (U == &PHI)
1660         continue;
1661       FinalReduction = dyn_cast<Instruction>(U);
1662     }
1663 
1664     if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B ||
1665         isa<PHINode>(FinalReduction))
1666       continue;
1667 
1668     ReductionInfo[ReductionOp] = {&PHI, FinalReduction};
1669     BackEdge = B;
1670     auto BackEdgeIdx = PHI.getBasicBlockIndex(B);
1671     auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0;
1672     Incoming = PHI.getIncomingBlock(IncomingIdx);
1673     FoundPotentialReduction = true;
1674 
1675     // If the initial value of PHINode is an Instruction, consider it a leaf
1676     // value of a complex deinterleaving graph.
1677     if (auto *InitPHI =
1678             dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming)))
1679       FinalInstructions.insert(InitPHI);
1680   }
1681   return FoundPotentialReduction;
1682 }
1683 
1684 void ComplexDeinterleavingGraph::identifyReductionNodes() {
1685   SmallVector<bool> Processed(ReductionInfo.size(), false);
1686   SmallVector<Instruction *> OperationInstruction;
1687   for (auto &P : ReductionInfo)
1688     OperationInstruction.push_back(P.first);
1689 
1690   // Identify a complex computation by evaluating two reduction operations that
1691   // potentially could be involved
1692   for (size_t i = 0; i < OperationInstruction.size(); ++i) {
1693     if (Processed[i])
1694       continue;
1695     for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {
1696       if (Processed[j])
1697         continue;
1698       auto *Real = OperationInstruction[i];
1699       auto *Imag = OperationInstruction[j];
1700       if (Real->getType() != Imag->getType())
1701         continue;
1702 
1703       RealPHI = ReductionInfo[Real].first;
1704       ImagPHI = ReductionInfo[Imag].first;
1705       PHIsFound = false;
1706       auto Node = identifyNode(Real, Imag);
1707       if (!Node) {
1708         std::swap(Real, Imag);
1709         std::swap(RealPHI, ImagPHI);
1710         Node = identifyNode(Real, Imag);
1711       }
1712 
1713       // If a node is identified and reduction PHINode is used in the chain of
1714       // operations, mark its operation instructions as used to prevent
1715       // re-identification and attach the node to the real part
1716       if (Node && PHIsFound) {
1717         LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: "
1718                           << *Real << " / " << *Imag << "\n");
1719         Processed[i] = true;
1720         Processed[j] = true;
1721         auto RootNode = prepareCompositeNode(
1722             ComplexDeinterleavingOperation::ReductionOperation, Real, Imag);
1723         RootNode->addOperand(Node);
1724         RootToNode[Real] = RootNode;
1725         RootToNode[Imag] = RootNode;
1726         submitCompositeNode(RootNode);
1727         break;
1728       }
1729     }
1730 
1731     auto *Real = OperationInstruction[i];
1732     // We want to check that we have 2 operands, but the function attributes
1733     // being counted as operands bloats this value.
1734     if (Processed[i] || Real->getNumOperands() < 2)
1735       continue;
1736 
1737     // Can only combined integer reductions at the moment.
1738     if (!ReductionInfo[Real].second->getType()->isIntegerTy())
1739       continue;
1740 
1741     RealPHI = ReductionInfo[Real].first;
1742     ImagPHI = nullptr;
1743     PHIsFound = false;
1744     auto Node = identifyNode(Real->getOperand(0), Real->getOperand(1));
1745     if (Node && PHIsFound) {
1746       LLVM_DEBUG(
1747           dbgs() << "Identified single reduction starting from instruction: "
1748                  << *Real << "/" << *ReductionInfo[Real].second << "\n");
1749 
1750       // Reducing to a single vector is not supported, only permit reducing down
1751       // to scalar values.
1752       // Doing this here will leave the prior node in the graph,
1753       // however with no uses the node will be unreachable by the replacement
1754       // process. That along with the usage outside the graph should prevent the
1755       // replacement process from kicking off at all for this graph.
1756       // TODO Add support for reducing to a single vector value
1757       if (ReductionInfo[Real].second->getType()->isVectorTy())
1758         continue;
1759 
1760       Processed[i] = true;
1761       auto RootNode = prepareCompositeNode(
1762           ComplexDeinterleavingOperation::ReductionSingle, Real, nullptr);
1763       RootNode->addOperand(Node);
1764       RootToNode[Real] = RootNode;
1765       submitCompositeNode(RootNode);
1766     }
1767   }
1768 
1769   RealPHI = nullptr;
1770   ImagPHI = nullptr;
1771 }
1772 
1773 bool ComplexDeinterleavingGraph::checkNodes() {
1774 
1775   bool FoundDeinterleaveNode = false;
1776   for (NodePtr N : CompositeNodes) {
1777     if (!N->areOperandsValid())
1778       return false;
1779     if (N->Operation == ComplexDeinterleavingOperation::Deinterleave)
1780       FoundDeinterleaveNode = true;
1781   }
1782 
1783   // We need a deinterleave node in order to guarantee that we're working with
1784   // complex numbers.
1785   if (!FoundDeinterleaveNode) {
1786     LLVM_DEBUG(
1787         dbgs() << "Couldn't find a deinterleave node within the graph, cannot "
1788                   "guarantee safety during graph transformation.\n");
1789     return false;
1790   }
1791 
1792   // Collect all instructions from roots to leaves
1793   SmallPtrSet<Instruction *, 16> AllInstructions;
1794   SmallVector<Instruction *, 8> Worklist;
1795   for (auto &Pair : RootToNode)
1796     Worklist.push_back(Pair.first);
1797 
1798   // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
1799   // chains
1800   while (!Worklist.empty()) {
1801     auto *I = Worklist.pop_back_val();
1802 
1803     if (!AllInstructions.insert(I).second)
1804       continue;
1805 
1806     for (Value *Op : I->operands()) {
1807       if (auto *OpI = dyn_cast<Instruction>(Op)) {
1808         if (!FinalInstructions.count(I))
1809           Worklist.emplace_back(OpI);
1810       }
1811     }
1812   }
1813 
1814   // Find instructions that have users outside of chain
1815   for (auto *I : AllInstructions) {
1816     // Skip root nodes
1817     if (RootToNode.count(I))
1818       continue;
1819 
1820     for (User *U : I->users()) {
1821       if (AllInstructions.count(cast<Instruction>(U)))
1822         continue;
1823 
1824       // Found an instruction that is not used by XCMLA/XCADD chain
1825       Worklist.emplace_back(I);
1826       break;
1827     }
1828   }
1829 
1830   // If any instructions are found to be used outside, find and remove roots
1831   // that somehow connect to those instructions.
1832   SmallPtrSet<Instruction *, 16> Visited;
1833   while (!Worklist.empty()) {
1834     auto *I = Worklist.pop_back_val();
1835     if (!Visited.insert(I).second)
1836       continue;
1837 
1838     // Found an impacted root node. Removing it from the nodes to be
1839     // deinterleaved
1840     if (RootToNode.count(I)) {
1841       LLVM_DEBUG(dbgs() << "Instruction " << *I
1842                         << " could be deinterleaved but its chain of complex "
1843                            "operations have an outside user\n");
1844       RootToNode.erase(I);
1845     }
1846 
1847     if (!AllInstructions.count(I) || FinalInstructions.count(I))
1848       continue;
1849 
1850     for (User *U : I->users())
1851       Worklist.emplace_back(cast<Instruction>(U));
1852 
1853     for (Value *Op : I->operands()) {
1854       if (auto *OpI = dyn_cast<Instruction>(Op))
1855         Worklist.emplace_back(OpI);
1856     }
1857   }
1858   return !RootToNode.empty();
1859 }
1860 
1861 ComplexDeinterleavingGraph::NodePtr
1862 ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
1863   if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) {
1864     if (Intrinsic->getIntrinsicID() != Intrinsic::vector_interleave2)
1865       return nullptr;
1866 
1867     auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(0));
1868     auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(1));
1869     if (!Real || !Imag)
1870       return nullptr;
1871 
1872     return identifyNode(Real, Imag);
1873   }
1874 
1875   auto *SVI = dyn_cast<ShuffleVectorInst>(RootI);
1876   if (!SVI)
1877     return nullptr;
1878 
1879   // Look for a shufflevector that takes separate vectors of the real and
1880   // imaginary components and recombines them into a single vector.
1881   if (!isInterleavingMask(SVI->getShuffleMask()))
1882     return nullptr;
1883 
1884   Instruction *Real;
1885   Instruction *Imag;
1886   if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
1887     return nullptr;
1888 
1889   return identifyNode(Real, Imag);
1890 }
1891 
1892 ComplexDeinterleavingGraph::NodePtr
1893 ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real,
1894                                                  Instruction *Imag) {
1895   Instruction *I = nullptr;
1896   Value *FinalValue = nullptr;
1897   if (match(Real, m_ExtractValue<0>(m_Instruction(I))) &&
1898       match(Imag, m_ExtractValue<1>(m_Specific(I))) &&
1899       match(I, m_Intrinsic<Intrinsic::vector_deinterleave2>(
1900                    m_Value(FinalValue)))) {
1901     NodePtr PlaceholderNode = prepareCompositeNode(
1902         llvm::ComplexDeinterleavingOperation::Deinterleave, Real, Imag);
1903     PlaceholderNode->ReplacementNode = FinalValue;
1904     FinalInstructions.insert(Real);
1905     FinalInstructions.insert(Imag);
1906     return submitCompositeNode(PlaceholderNode);
1907   }
1908 
1909   auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
1910   auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
1911   if (!RealShuffle || !ImagShuffle) {
1912     if (RealShuffle || ImagShuffle)
1913       LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n");
1914     return nullptr;
1915   }
1916 
1917   Value *RealOp1 = RealShuffle->getOperand(1);
1918   if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) {
1919     LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
1920     return nullptr;
1921   }
1922   Value *ImagOp1 = ImagShuffle->getOperand(1);
1923   if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) {
1924     LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
1925     return nullptr;
1926   }
1927 
1928   Value *RealOp0 = RealShuffle->getOperand(0);
1929   Value *ImagOp0 = ImagShuffle->getOperand(0);
1930 
1931   if (RealOp0 != ImagOp0) {
1932     LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
1933     return nullptr;
1934   }
1935 
1936   ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
1937   ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
1938   if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) {
1939     LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
1940     return nullptr;
1941   }
1942 
1943   if (RealMask[0] != 0 || ImagMask[0] != 1) {
1944     LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
1945     return nullptr;
1946   }
1947 
1948   // Type checking, the shuffle type should be a vector type of the same
1949   // scalar type, but half the size
1950   auto CheckType = [&](ShuffleVectorInst *Shuffle) {
1951     Value *Op = Shuffle->getOperand(0);
1952     auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType());
1953     auto *OpTy = cast<FixedVectorType>(Op->getType());
1954 
1955     if (OpTy->getScalarType() != ShuffleTy->getScalarType())
1956       return false;
1957     if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
1958       return false;
1959 
1960     return true;
1961   };
1962 
1963   auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool {
1964     if (!CheckType(Shuffle))
1965       return false;
1966 
1967     ArrayRef<int> Mask = Shuffle->getShuffleMask();
1968     int Last = *Mask.rbegin();
1969 
1970     Value *Op = Shuffle->getOperand(0);
1971     auto *OpTy = cast<FixedVectorType>(Op->getType());
1972     int NumElements = OpTy->getNumElements();
1973 
1974     // Ensure that the deinterleaving shuffle only pulls from the first
1975     // shuffle operand.
1976     return Last < NumElements;
1977   };
1978 
1979   if (RealShuffle->getType() != ImagShuffle->getType()) {
1980     LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
1981     return nullptr;
1982   }
1983   if (!CheckDeinterleavingShuffle(RealShuffle)) {
1984     LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
1985     return nullptr;
1986   }
1987   if (!CheckDeinterleavingShuffle(ImagShuffle)) {
1988     LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
1989     return nullptr;
1990   }
1991 
1992   NodePtr PlaceholderNode =
1993       prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Deinterleave,
1994                            RealShuffle, ImagShuffle);
1995   PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
1996   FinalInstructions.insert(RealShuffle);
1997   FinalInstructions.insert(ImagShuffle);
1998   return submitCompositeNode(PlaceholderNode);
1999 }
2000 
2001 ComplexDeinterleavingGraph::NodePtr
2002 ComplexDeinterleavingGraph::identifySplat(Value *R, Value *I) {
2003   auto IsSplat = [](Value *V) -> bool {
2004     // Fixed-width vector with constants
2005     if (isa<ConstantDataVector>(V))
2006       return true;
2007 
2008     if (isa<ConstantInt>(V) || isa<ConstantFP>(V))
2009       return isa<VectorType>(V->getType());
2010 
2011     VectorType *VTy;
2012     ArrayRef<int> Mask;
2013     // Splats are represented differently depending on whether the repeated
2014     // value is a constant or an Instruction
2015     if (auto *Const = dyn_cast<ConstantExpr>(V)) {
2016       if (Const->getOpcode() != Instruction::ShuffleVector)
2017         return false;
2018       VTy = cast<VectorType>(Const->getType());
2019       Mask = Const->getShuffleMask();
2020     } else if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) {
2021       VTy = Shuf->getType();
2022       Mask = Shuf->getShuffleMask();
2023     } else {
2024       return false;
2025     }
2026 
2027     // When the data type is <1 x Type>, it's not possible to differentiate
2028     // between the ComplexDeinterleaving::Deinterleave and
2029     // ComplexDeinterleaving::Splat operations.
2030     if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1)
2031       return false;
2032 
2033     return all_equal(Mask) && Mask[0] == 0;
2034   };
2035 
2036   if (!IsSplat(R) || !IsSplat(I))
2037     return nullptr;
2038 
2039   auto *Real = dyn_cast<Instruction>(R);
2040   auto *Imag = dyn_cast<Instruction>(I);
2041   if ((!Real && Imag) || (Real && !Imag))
2042     return nullptr;
2043 
2044   if (Real && Imag) {
2045     // Non-constant splats should be in the same basic block
2046     if (Real->getParent() != Imag->getParent())
2047       return nullptr;
2048 
2049     FinalInstructions.insert(Real);
2050     FinalInstructions.insert(Imag);
2051   }
2052   NodePtr PlaceholderNode =
2053       prepareCompositeNode(ComplexDeinterleavingOperation::Splat, R, I);
2054   return submitCompositeNode(PlaceholderNode);
2055 }
2056 
2057 ComplexDeinterleavingGraph::NodePtr
2058 ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
2059                                             Instruction *Imag) {
2060   if (Real != RealPHI || (ImagPHI && Imag != ImagPHI))
2061     return nullptr;
2062 
2063   PHIsFound = true;
2064   NodePtr PlaceholderNode = prepareCompositeNode(
2065       ComplexDeinterleavingOperation::ReductionPHI, Real, Imag);
2066   return submitCompositeNode(PlaceholderNode);
2067 }
2068 
2069 ComplexDeinterleavingGraph::NodePtr
2070 ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
2071                                                Instruction *Imag) {
2072   auto *SelectReal = dyn_cast<SelectInst>(Real);
2073   auto *SelectImag = dyn_cast<SelectInst>(Imag);
2074   if (!SelectReal || !SelectImag)
2075     return nullptr;
2076 
2077   Instruction *MaskA, *MaskB;
2078   Instruction *AR, *AI, *RA, *BI;
2079   if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR),
2080                             m_Instruction(RA))) ||
2081       !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI),
2082                             m_Instruction(BI))))
2083     return nullptr;
2084 
2085   if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB))
2086     return nullptr;
2087 
2088   if (!MaskA->getType()->isVectorTy())
2089     return nullptr;
2090 
2091   auto NodeA = identifyNode(AR, AI);
2092   if (!NodeA)
2093     return nullptr;
2094 
2095   auto NodeB = identifyNode(RA, BI);
2096   if (!NodeB)
2097     return nullptr;
2098 
2099   NodePtr PlaceholderNode = prepareCompositeNode(
2100       ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
2101   PlaceholderNode->addOperand(NodeA);
2102   PlaceholderNode->addOperand(NodeB);
2103   FinalInstructions.insert(MaskA);
2104   FinalInstructions.insert(MaskB);
2105   return submitCompositeNode(PlaceholderNode);
2106 }
2107 
2108 static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
2109                                    std::optional<FastMathFlags> Flags,
2110                                    Value *InputA, Value *InputB) {
2111   Value *I;
2112   switch (Opcode) {
2113   case Instruction::FNeg:
2114     I = B.CreateFNeg(InputA);
2115     break;
2116   case Instruction::FAdd:
2117     I = B.CreateFAdd(InputA, InputB);
2118     break;
2119   case Instruction::Add:
2120     I = B.CreateAdd(InputA, InputB);
2121     break;
2122   case Instruction::FSub:
2123     I = B.CreateFSub(InputA, InputB);
2124     break;
2125   case Instruction::Sub:
2126     I = B.CreateSub(InputA, InputB);
2127     break;
2128   case Instruction::FMul:
2129     I = B.CreateFMul(InputA, InputB);
2130     break;
2131   case Instruction::Mul:
2132     I = B.CreateMul(InputA, InputB);
2133     break;
2134   default:
2135     llvm_unreachable("Incorrect symmetric opcode");
2136   }
2137   if (Flags)
2138     cast<Instruction>(I)->setFastMathFlags(*Flags);
2139   return I;
2140 }
2141 
// Recursively emit the replacement IR for Node and return the resulting value.
// The result is cached in Node->ReplacementNode, so a node reachable from
// several users is materialized only once. Operand nodes are replaced before
// the node that uses them. New IR goes at Builder's insertion point, except for
// non-constant splats and the ReductionPHI case, which pick their own
// positions.
Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
                                               RawNodePtr Node) {
  // Already materialized. This also covers leaves whose replacement was
  // assigned when they were identified (e.g. Deinterleave nodes).
  if (Node->ReplacementNode)
    return Node->ReplacementNode;

  // Replace operand Idx of Node when present. Otherwise return nullptr, so
  // optional inputs (such as a missing accumulator) are passed on as null.
  auto ReplaceOperandIfExist = [&](RawNodePtr &Node, unsigned Idx) -> Value * {
    return Node->Operands.size() > Idx
               ? replaceNode(Builder, Node->Operands[Idx])
               : nullptr;
  };

  Value *ReplacementNode;
  switch (Node->Operation) {
  case ComplexDeinterleavingOperation::CDot: {
    // Lowered by the target. Unlike the case below, the accumulator's type is
    // not checked against the inputs here.
    Value *Input0 = ReplaceOperandIfExist(Node, 0);
    Value *Input1 = ReplaceOperandIfExist(Node, 1);
    Value *Accumulator = ReplaceOperandIfExist(Node, 2);
    assert(!Input1 || (Input0->getType() == Input1->getType() &&
                       "Node inputs need to be of the same type"));
    ReplacementNode = TL->createComplexDeinterleavingIR(
        Builder, Node->Operation, Node->Rotation, Input0, Input1, Accumulator);
    break;
  }
  case ComplexDeinterleavingOperation::CAdd:
  case ComplexDeinterleavingOperation::CMulPartial:
  case ComplexDeinterleavingOperation::Symmetric: {
    Value *Input0 = ReplaceOperandIfExist(Node, 0);
    Value *Input1 = ReplaceOperandIfExist(Node, 1);
    Value *Accumulator = ReplaceOperandIfExist(Node, 2);
    assert(!Input1 || (Input0->getType() == Input1->getType() &&
                       "Node inputs need to be of the same type"));
    assert(!Accumulator ||
           (Input0->getType() == Accumulator->getType() &&
            "Accumulator and input need to be of the same type"));
    // Symmetric nodes become the same plain opcode applied to the interleaved
    // vectors. CAdd/CMulPartial are lowered by the target.
    if (Node->Operation == ComplexDeinterleavingOperation::Symmetric)
      ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags,
                                             Input0, Input1);
    else
      ReplacementNode = TL->createComplexDeinterleavingIR(
          Builder, Node->Operation, Node->Rotation, Input0, Input1,
          Accumulator);
    break;
  }
  case ComplexDeinterleavingOperation::Deinterleave:
    llvm_unreachable("Deinterleave node should already have ReplacementNode");
    break;
  case ComplexDeinterleavingOperation::Splat: {
    // Interleave the real and imaginary splats into a vector with twice as
    // many elements.
    auto *NewTy = VectorType::getDoubleElementsVectorType(
        cast<VectorType>(Node->Real->getType()));
    auto *R = dyn_cast<Instruction>(Node->Real);
    auto *I = dyn_cast<Instruction>(Node->Imag);
    if (R && I) {
      // Splats that are not constant are interleaved where they are located
      Instruction *InsertPoint = (I->comesBefore(R) ? R : I)->getNextNode();
      IRBuilder<> IRB(InsertPoint);
      ReplacementNode = IRB.CreateIntrinsic(Intrinsic::vector_interleave2,
                                            NewTy, {Node->Real, Node->Imag});
    } else {
      ReplacementNode = Builder.CreateIntrinsic(
          Intrinsic::vector_interleave2, NewTy, {Node->Real, Node->Imag});
    }
    break;
  }
  case ComplexDeinterleavingOperation::ReductionPHI: {
    // If Operation is ReductionPHI, a new empty PHINode is created.
    // It is filled later when the ReductionOperation is processed.
    auto *OldPHI = cast<PHINode>(Node->Real);
    auto *VTy = cast<VectorType>(Node->Real->getType());
    auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
    auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHIIt());
    OldToNewPHI[OldPHI] = NewPHI;
    ReplacementNode = NewPHI;
    break;
  }
  // For both reduction roots, first replace the per-iteration computation,
  // then wire up the new PHI and rewrite the out-of-loop final reduction(s).
  case ComplexDeinterleavingOperation::ReductionSingle:
    ReplacementNode = replaceNode(Builder, Node->Operands[0]);
    processReductionSingle(ReplacementNode, Node);
    break;
  case ComplexDeinterleavingOperation::ReductionOperation:
    ReplacementNode = replaceNode(Builder, Node->Operands[0]);
    processReductionOperation(ReplacementNode, Node);
    break;
  case ComplexDeinterleavingOperation::ReductionSelect: {
    // Interleave the real and imaginary select conditions, then emit a single
    // select over the interleaved true/false operands.
    auto *MaskReal = cast<Instruction>(Node->Real)->getOperand(0);
    auto *MaskImag = cast<Instruction>(Node->Imag)->getOperand(0);
    auto *A = replaceNode(Builder, Node->Operands[0]);
    auto *B = replaceNode(Builder, Node->Operands[1]);
    auto *NewMaskTy = VectorType::getDoubleElementsVectorType(
        cast<VectorType>(MaskReal->getType()));
    auto *NewMask = Builder.CreateIntrinsic(Intrinsic::vector_interleave2,
                                            NewMaskTy, {MaskReal, MaskImag});
    ReplacementNode = Builder.CreateSelect(NewMask, A, B);
    break;
  }
  }

  assert(ReplacementNode && "Target failed to create Intrinsic call.");
  NumComplexTransformations += 1;
  Node->ReplacementNode = ReplacementNode;
  return ReplacementNode;
}
2242 }
2243 
2244 void ComplexDeinterleavingGraph::processReductionSingle(
2245     Value *OperationReplacement, RawNodePtr Node) {
2246   auto *Real = cast<Instruction>(Node->Real);
2247   auto *OldPHI = ReductionInfo[Real].first;
2248   auto *NewPHI = OldToNewPHI[OldPHI];
2249   auto *VTy = cast<VectorType>(Real->getType());
2250   auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2251 
2252   Value *Init = OldPHI->getIncomingValueForBlock(Incoming);
2253 
2254   IRBuilder<> Builder(Incoming->getTerminator());
2255 
2256   Value *NewInit = nullptr;
2257   if (auto *C = dyn_cast<Constant>(Init)) {
2258     if (C->isZeroValue())
2259       NewInit = Constant::getNullValue(NewVTy);
2260   }
2261 
2262   if (!NewInit)
2263     NewInit = Builder.CreateIntrinsic(Intrinsic::vector_interleave2, NewVTy,
2264                                       {Init, Constant::getNullValue(VTy)});
2265 
2266   NewPHI->addIncoming(NewInit, Incoming);
2267   NewPHI->addIncoming(OperationReplacement, BackEdge);
2268 
2269   auto *FinalReduction = ReductionInfo[Real].second;
2270   Builder.SetInsertPoint(&*FinalReduction->getParent()->getFirstInsertionPt());
2271 
2272   auto *AddReduce = Builder.CreateAddReduce(OperationReplacement);
2273   FinalReduction->replaceAllUsesWith(AddReduce);
2274 }
2275 
2276 void ComplexDeinterleavingGraph::processReductionOperation(
2277     Value *OperationReplacement, RawNodePtr Node) {
2278   auto *Real = cast<Instruction>(Node->Real);
2279   auto *Imag = cast<Instruction>(Node->Imag);
2280   auto *OldPHIReal = ReductionInfo[Real].first;
2281   auto *OldPHIImag = ReductionInfo[Imag].first;
2282   auto *NewPHI = OldToNewPHI[OldPHIReal];
2283 
2284   auto *VTy = cast<VectorType>(Real->getType());
2285   auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2286 
2287   // We have to interleave initial origin values coming from IncomingBlock
2288   Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming);
2289   Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming);
2290 
2291   IRBuilder<> Builder(Incoming->getTerminator());
2292   auto *NewInit = Builder.CreateIntrinsic(Intrinsic::vector_interleave2, NewVTy,
2293                                           {InitReal, InitImag});
2294 
2295   NewPHI->addIncoming(NewInit, Incoming);
2296   NewPHI->addIncoming(OperationReplacement, BackEdge);
2297 
2298   // Deinterleave complex vector outside of loop so that it can be finally
2299   // reduced
2300   auto *FinalReductionReal = ReductionInfo[Real].second;
2301   auto *FinalReductionImag = ReductionInfo[Imag].second;
2302 
2303   Builder.SetInsertPoint(
2304       &*FinalReductionReal->getParent()->getFirstInsertionPt());
2305   auto *Deinterleave = Builder.CreateIntrinsic(Intrinsic::vector_deinterleave2,
2306                                                OperationReplacement->getType(),
2307                                                OperationReplacement);
2308 
2309   auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0);
2310   FinalReductionReal->replaceUsesOfWith(Real, NewReal);
2311 
2312   Builder.SetInsertPoint(FinalReductionImag);
2313   auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1);
2314   FinalReductionImag->replaceUsesOfWith(Imag, NewImag);
2315 }
2316 
2317 void ComplexDeinterleavingGraph::replaceNodes() {
2318   SmallVector<Instruction *, 16> DeadInstrRoots;
2319   for (auto *RootInstruction : OrderedRoots) {
2320     // Check if this potential root went through check process and we can
2321     // deinterleave it
2322     if (!RootToNode.count(RootInstruction))
2323       continue;
2324 
2325     IRBuilder<> Builder(RootInstruction);
2326     auto RootNode = RootToNode[RootInstruction];
2327     Value *R = replaceNode(Builder, RootNode.get());
2328 
2329     if (RootNode->Operation ==
2330         ComplexDeinterleavingOperation::ReductionOperation) {
2331       auto *RootReal = cast<Instruction>(RootNode->Real);
2332       auto *RootImag = cast<Instruction>(RootNode->Imag);
2333       ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
2334       ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
2335       DeadInstrRoots.push_back(RootReal);
2336       DeadInstrRoots.push_back(RootImag);
2337     } else if (RootNode->Operation ==
2338                ComplexDeinterleavingOperation::ReductionSingle) {
2339       auto *RootInst = cast<Instruction>(RootNode->Real);
2340       auto &Info = ReductionInfo[RootInst];
2341       Info.first->removeIncomingValue(BackEdge);
2342       DeadInstrRoots.push_back(Info.second);
2343     } else {
2344       assert(R && "Unable to find replacement for RootInstruction");
2345       DeadInstrRoots.push_back(RootInstruction);
2346       RootInstruction->replaceAllUsesWith(R);
2347     }
2348   }
2349 
2350   for (auto *I : DeadInstrRoots)
2351     RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
2352 }
2353