xref: /freebsd/contrib/llvm-project/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp (revision ba3c1f5972d7b90feb6e6da47905ff2757e0fe57)
1 //===- ComplexDeinterleavingPass.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Identification:
10 // This step is responsible for finding the patterns that can be lowered to
11 // complex instructions, and building a graph to represent the complex
12 // structures. Starting from the "Converging Shuffle" (a shuffle that
13 // reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
14 // operands are evaluated and identified as "Composite Nodes" (collections of
15 // instructions that can potentially be lowered to a single complex
16 // instruction). This is performed by checking the real and imaginary components
17 // and tracking the data flow for each component while following the operand
// pairs. Each node is expected to be validated upon creation, and any
// validation error should halt traversal and prevent further graph
// construction.
21 //
22 // Replacement:
23 // This step traverses the graph built up by identification, delegating to the
24 // target to validate and generate the correct intrinsics, and plumbs them
25 // together connecting each end of the new intrinsics graph to the existing
26 // use-def chain. This step is assumed to finish successfully, as all
27 // information is expected to be correct by this point.
28 //
29 //
30 // Internal data structure:
31 // ComplexDeinterleavingGraph:
32 // Keeps references to all the valid CompositeNodes formed as part of the
33 // transformation, and every Instruction contained within said nodes. It also
34 // holds onto a reference to the root Instruction, and the root node that should
35 // replace it.
36 //
37 // ComplexDeinterleavingCompositeNode:
38 // A CompositeNode represents a single transformation point; each node should
39 // transform into a single complex instruction (ignoring vector splitting, which
40 // would generate more instructions per node). They are identified in a
41 // depth-first manner, traversing and identifying the operands of each
42 // instruction in the order they appear in the IR.
// Each node maintains a reference to its Real and Imaginary instructions,
44 // as well as any additional instructions that make up the identified operation
45 // (Internal instructions should only have uses within their containing node).
46 // A Node also contains the rotation and operation type that it represents.
// Operands contains pointers to other CompositeNodes, acting as the edges in
// the graph. ReplacementNode is the transformed Value* that has been emitted
// to the IR.
//
// Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
// ReplacementNode fields of that Node are relevant, where the ReplacementNode
// should be pre-populated.
54 //
55 //===----------------------------------------------------------------------===//
56 
57 #include "llvm/CodeGen/ComplexDeinterleavingPass.h"
58 #include "llvm/ADT/Statistic.h"
59 #include "llvm/Analysis/TargetLibraryInfo.h"
60 #include "llvm/Analysis/TargetTransformInfo.h"
61 #include "llvm/CodeGen/TargetLowering.h"
62 #include "llvm/CodeGen/TargetPassConfig.h"
63 #include "llvm/CodeGen/TargetSubtargetInfo.h"
64 #include "llvm/IR/IRBuilder.h"
65 #include "llvm/InitializePasses.h"
66 #include "llvm/Target/TargetMachine.h"
67 #include "llvm/Transforms/Utils/Local.h"
68 #include <algorithm>
69 
70 using namespace llvm;
71 using namespace PatternMatch;
72 
// Name of this pass; also passed to the INITIALIZE_PASS_* macros further down
// in this file.
#define DEBUG_TYPE "complex-deinterleaving"

// Incremented once for every composite node that replaceNode() lowers through
// the target hook.
STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");

// Hidden command-line switch (default: on). When it is false,
// ComplexDeinterleaving::runOnFunction returns without touching the function.
static cl::opt<bool> ComplexDeinterleavingEnabled(
    "enable-complex-deinterleaving",
    cl::desc("Enable generation of complex instructions"), cl::init(true),
    cl::Hidden);
81 
/// Checks the given mask, and determines whether said mask is interleaving.
///
/// To be interleaving, a mask must alternate between `i` and `i + (Length /
/// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
/// 4x vector interleaving mask would be <0, 2, 1, 3>). Masks with an odd
/// number of elements are never interleaving.
static bool isInterleavingMask(ArrayRef<int> Mask);

/// Checks the given mask, and determines whether said mask is deinterleaving.
///
/// A deinterleaving mask picks every second element of its source: it
/// increments in steps of 2 from its first element, which is expected to be
/// either 0 or 1.
/// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
/// <1, 3, 5, 7>).
static bool isDeinterleavingMask(ArrayRef<int> Mask);
96 
97 namespace {
98 
99 class ComplexDeinterleavingLegacyPass : public FunctionPass {
100 public:
101   static char ID;
102 
103   ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
104       : FunctionPass(ID), TM(TM) {
105     initializeComplexDeinterleavingLegacyPassPass(
106         *PassRegistry::getPassRegistry());
107   }
108 
109   StringRef getPassName() const override {
110     return "Complex Deinterleaving Pass";
111   }
112 
113   bool runOnFunction(Function &F) override;
114   void getAnalysisUsage(AnalysisUsage &AU) const override {
115     AU.addRequired<TargetLibraryInfoWrapperPass>();
116     AU.setPreservesCFG();
117   }
118 
119 private:
120   const TargetMachine *TM;
121 };
122 
123 class ComplexDeinterleavingGraph;
124 struct ComplexDeinterleavingCompositeNode {
125 
126   ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
127                                      Instruction *R, Instruction *I)
128       : Operation(Op), Real(R), Imag(I) {}
129 
130 private:
131   friend class ComplexDeinterleavingGraph;
132   using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
133   using RawNodePtr = ComplexDeinterleavingCompositeNode *;
134 
135 public:
136   ComplexDeinterleavingOperation Operation;
137   Instruction *Real;
138   Instruction *Imag;
139 
140   // Instructions that should only exist within this node, there should be no
141   // users of these instructions outside the node. An example of these would be
142   // the multiply instructions of a partial multiply operation.
143   SmallVector<Instruction *> InternalInstructions;
144   ComplexDeinterleavingRotation Rotation;
145   SmallVector<RawNodePtr> Operands;
146   Value *ReplacementNode = nullptr;
147 
148   void addInstruction(Instruction *I) { InternalInstructions.push_back(I); }
149   void addOperand(NodePtr Node) { Operands.push_back(Node.get()); }
150 
151   bool hasAllInternalUses(SmallPtrSet<Instruction *, 16> &AllInstructions);
152 
153   void dump() { dump(dbgs()); }
154   void dump(raw_ostream &OS) {
155     auto PrintValue = [&](Value *V) {
156       if (V) {
157         OS << "\"";
158         V->print(OS, true);
159         OS << "\"\n";
160       } else
161         OS << "nullptr\n";
162     };
163     auto PrintNodeRef = [&](RawNodePtr Ptr) {
164       if (Ptr)
165         OS << Ptr << "\n";
166       else
167         OS << "nullptr\n";
168     };
169 
170     OS << "- CompositeNode: " << this << "\n";
171     OS << "  Real: ";
172     PrintValue(Real);
173     OS << "  Imag: ";
174     PrintValue(Imag);
175     OS << "  ReplacementNode: ";
176     PrintValue(ReplacementNode);
177     OS << "  Operation: " << (int)Operation << "\n";
178     OS << "  Rotation: " << ((int)Rotation * 90) << "\n";
179     OS << "  Operands: \n";
180     for (const auto &Op : Operands) {
181       OS << "    - ";
182       PrintNodeRef(Op);
183     }
184     OS << "  InternalInstructions:\n";
185     for (const auto &I : InternalInstructions) {
186       OS << "    - \"";
187       I->print(OS, true);
188       OS << "\"\n";
189     }
190   }
191 };
192 
193 class ComplexDeinterleavingGraph {
194 public:
195   using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
196   using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;
197   explicit ComplexDeinterleavingGraph(const TargetLowering *tl) : TL(tl) {}
198 
199 private:
200   const TargetLowering *TL;
201   Instruction *RootValue;
202   NodePtr RootNode;
203   SmallVector<NodePtr> CompositeNodes;
204   SmallPtrSet<Instruction *, 16> AllInstructions;
205 
206   NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
207                                Instruction *R, Instruction *I) {
208     return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R,
209                                                                 I);
210   }
211 
212   NodePtr submitCompositeNode(NodePtr Node) {
213     CompositeNodes.push_back(Node);
214     AllInstructions.insert(Node->Real);
215     AllInstructions.insert(Node->Imag);
216     for (auto *I : Node->InternalInstructions)
217       AllInstructions.insert(I);
218     return Node;
219   }
220 
221   NodePtr getContainingComposite(Value *R, Value *I) {
222     for (const auto &CN : CompositeNodes) {
223       if (CN->Real == R && CN->Imag == I)
224         return CN;
225     }
226     return nullptr;
227   }
228 
229   /// Identifies a complex partial multiply pattern and its rotation, based on
230   /// the following patterns
231   ///
232   ///  0:  r: cr + ar * br
233   ///      i: ci + ar * bi
234   /// 90:  r: cr - ai * bi
235   ///      i: ci + ai * br
236   /// 180: r: cr - ar * br
237   ///      i: ci - ar * bi
238   /// 270: r: cr + ai * bi
239   ///      i: ci - ai * br
240   NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag);
241 
242   /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
243   /// is partially known from identifyPartialMul, filling in the other half of
244   /// the complex pair.
245   NodePtr identifyNodeWithImplicitAdd(
246       Instruction *I, Instruction *J,
247       std::pair<Instruction *, Instruction *> &CommonOperandI);
248 
249   /// Identifies a complex add pattern and its rotation, based on the following
250   /// patterns.
251   ///
252   /// 90:  r: ar - bi
253   ///      i: ai + br
254   /// 270: r: ar + bi
255   ///      i: ai - br
256   NodePtr identifyAdd(Instruction *Real, Instruction *Imag);
257 
258   NodePtr identifyNode(Instruction *I, Instruction *J);
259 
260   Value *replaceNode(RawNodePtr Node);
261 
262 public:
263   void dump() { dump(dbgs()); }
264   void dump(raw_ostream &OS) {
265     for (const auto &Node : CompositeNodes)
266       Node->dump(OS);
267   }
268 
269   /// Returns false if the deinterleaving operation should be cancelled for the
270   /// current graph.
271   bool identifyNodes(Instruction *RootI);
272 
273   /// Perform the actual replacement of the underlying instruction graph.
274   /// Returns false if the deinterleaving operation should be cancelled for the
275   /// current graph.
276   void replaceNodes();
277 };
278 
279 class ComplexDeinterleaving {
280 public:
281   ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli)
282       : TL(tl), TLI(tli) {}
283   bool runOnFunction(Function &F);
284 
285 private:
286   bool evaluateBasicBlock(BasicBlock *B);
287 
288   const TargetLowering *TL = nullptr;
289   const TargetLibraryInfo *TLI = nullptr;
290 };
291 
292 } // namespace
293 
// Storage for the legacy pass's static ID member, which the constructor hands
// to the FunctionPass base class.
char ComplexDeinterleavingLegacyPass::ID = 0;

// Registers the legacy pass under the DEBUG_TYPE name. Presumably this is
// what provides initializeComplexDeinterleavingLegacyPassPass(), which the
// constructor calls - confirm against llvm/InitializePasses.h.
// NOTE(review): getAnalysisUsage() requires TargetLibraryInfoWrapperPass, but
// no dependency is declared between BEGIN and END - confirm this is intended.
INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
                      "Complex Deinterleaving", false, false)
INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
                    "Complex Deinterleaving", false, false)
300 
301 PreservedAnalyses ComplexDeinterleavingPass::run(Function &F,
302                                                  FunctionAnalysisManager &AM) {
303   const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering();
304   auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F);
305   if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F))
306     return PreservedAnalyses::all();
307 
308   PreservedAnalyses PA;
309   PA.preserve<FunctionAnalysisManagerModuleProxy>();
310   return PA;
311 }
312 
313 FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) {
314   return new ComplexDeinterleavingLegacyPass(TM);
315 }
316 
317 bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {
318   const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
319   auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
320   return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);
321 }
322 
323 bool ComplexDeinterleaving::runOnFunction(Function &F) {
324   if (!ComplexDeinterleavingEnabled) {
325     LLVM_DEBUG(
326         dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
327     return false;
328   }
329 
330   if (!TL->isComplexDeinterleavingSupported()) {
331     LLVM_DEBUG(
332         dbgs() << "Complex deinterleaving has been disabled, target does "
333                   "not support lowering of complex number operations.\n");
334     return false;
335   }
336 
337   bool Changed = false;
338   for (auto &B : F)
339     Changed |= evaluateBasicBlock(&B);
340 
341   return Changed;
342 }
343 
344 static bool isInterleavingMask(ArrayRef<int> Mask) {
345   // If the size is not even, it's not an interleaving mask
346   if ((Mask.size() & 1))
347     return false;
348 
349   int HalfNumElements = Mask.size() / 2;
350   for (int Idx = 0; Idx < HalfNumElements; ++Idx) {
351     int MaskIdx = Idx * 2;
352     if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
353       return false;
354   }
355 
356   return true;
357 }
358 
359 static bool isDeinterleavingMask(ArrayRef<int> Mask) {
360   int Offset = Mask[0];
361   int HalfNumElements = Mask.size() / 2;
362 
363   for (int Idx = 1; Idx < HalfNumElements; ++Idx) {
364     if (Mask[Idx] != (Idx * 2) + Offset)
365       return false;
366   }
367 
368   return true;
369 }
370 
371 bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
372   bool Changed = false;
373 
374   SmallVector<Instruction *> DeadInstrRoots;
375 
376   for (auto &I : *B) {
377     auto *SVI = dyn_cast<ShuffleVectorInst>(&I);
378     if (!SVI)
379       continue;
380 
381     // Look for a shufflevector that takes separate vectors of the real and
382     // imaginary components and recombines them into a single vector.
383     if (!isInterleavingMask(SVI->getShuffleMask()))
384       continue;
385 
386     ComplexDeinterleavingGraph Graph(TL);
387     if (!Graph.identifyNodes(SVI))
388       continue;
389 
390     Graph.replaceNodes();
391     DeadInstrRoots.push_back(SVI);
392     Changed = true;
393   }
394 
395   for (const auto &I : DeadInstrRoots) {
396     if (!I || I->getParent() == nullptr)
397       continue;
398     llvm::RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
399   }
400 
401   return Changed;
402 }
403 
/// Matches a pair of single-use fmul instructions (\p Real, \p Imag) that
/// share one operand, where one operand of each multiply may be wrapped in an
/// fneg. On success a CMulPartial node with two operands (common pair,
/// uncommon pair) is submitted to the graph and returned; the stripped fneg
/// instructions become internal instructions of that node.
///
/// \p PartialMatch is an in/out parameter: the caller pre-fills one half with
/// the common operand it found, and this function fills the other half with
/// its own common operand. The slot is written before the later checks, so it
/// may have been modified even when nullptr is returned.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
    Instruction *Real, Instruction *Imag,
    std::pair<Instruction *, Instruction *> &PartialMatch) {
  LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
                    << "\n");

  if (!Real->hasOneUse() || !Imag->hasOneUse()) {
    LLVM_DEBUG(dbgs() << "  - Mul operand has multiple uses.\n");
    return nullptr;
  }

  if (Real->getOpcode() != Instruction::FMul ||
      Imag->getOpcode() != Instruction::FMul) {
    LLVM_DEBUG(dbgs() << "  - Real or imaginary instruction is not fmul\n");
    return nullptr;
  }

  Instruction *R0 = dyn_cast<Instruction>(Real->getOperand(0));
  Instruction *R1 = dyn_cast<Instruction>(Real->getOperand(1));
  Instruction *I0 = dyn_cast<Instruction>(Imag->getOperand(0));
  Instruction *I1 = dyn_cast<Instruction>(Imag->getOperand(1));
  if (!R0 || !R1 || !I0 || !I1) {
    LLVM_DEBUG(dbgs() << "  - Mul operand not Instruction\n");
    return nullptr;
  }

  // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
  // rotations and use the operand.
  //
  // Negs ends up as 1 when only the real multiply has an fneg, 3 when only the
  // imaginary one has, and 2 when both have; it is then cast directly to the
  // rotation enum (presumably numbered in 90-degree steps, as the node dump's
  // `(int)Rotation * 90` suggests). At most one fneg per multiply is stripped:
  // if operand 0 is an fneg, operand 1 is left as is.
  unsigned Negs = 0;
  SmallVector<Instruction *> FNegs;
  if (R0->getOpcode() == Instruction::FNeg ||
      R1->getOpcode() == Instruction::FNeg) {
    Negs |= 1;
    if (R0->getOpcode() == Instruction::FNeg) {
      FNegs.push_back(R0);
      R0 = dyn_cast<Instruction>(R0->getOperand(0));
    } else {
      FNegs.push_back(R1);
      R1 = dyn_cast<Instruction>(R1->getOperand(0));
    }
    // The value underneath the fneg has to be an instruction as well.
    if (!R0 || !R1)
      return nullptr;
  }
  if (I0->getOpcode() == Instruction::FNeg ||
      I1->getOpcode() == Instruction::FNeg) {
    Negs |= 2;
    Negs ^= 1;
    if (I0->getOpcode() == Instruction::FNeg) {
      FNegs.push_back(I0);
      I0 = dyn_cast<Instruction>(I0->getOperand(0));
    } else {
      FNegs.push_back(I1);
      I1 = dyn_cast<Instruction>(I1->getOperand(0));
    }
    // The value underneath the fneg has to be an instruction as well.
    if (!I0 || !I1)
      return nullptr;
  }

  ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs;

  Instruction *CommonOperand;
  Instruction *UncommonRealOp;
  Instruction *UncommonImagOp;

  // The two multiplies must share one operand; the remaining operand of each
  // multiply forms the "uncommon" pair.
  if (R0 == I0 || R0 == I1) {
    CommonOperand = R0;
    UncommonRealOp = R1;
  } else if (R1 == I0 || R1 == I1) {
    CommonOperand = R1;
    UncommonRealOp = R0;
  } else {
    LLVM_DEBUG(dbgs() << "  - No equal operand\n");
    return nullptr;
  }

  UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
  // For the 90/270 rotations the uncommon operands are exchanged before being
  // identified as a (real, imaginary) pair.
  if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
      Rotation == ComplexDeinterleavingRotation::Rotation_270)
    std::swap(UncommonRealOp, UncommonImagOp);

  // Between identifyPartialMul and here we need to have found a complete valid
  // pair from the CommonOperand of each part.
  if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
      Rotation == ComplexDeinterleavingRotation::Rotation_180)
    PartialMatch.first = CommonOperand;
  else
    PartialMatch.second = CommonOperand;

  if (!PartialMatch.first || !PartialMatch.second) {
    LLVM_DEBUG(dbgs() << "  - Incomplete partial match\n");
    return nullptr;
  }

  NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second);
  if (!CommonNode) {
    LLVM_DEBUG(dbgs() << "  - No CommonNode identified\n");
    return nullptr;
  }

  NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
  if (!UncommonNode) {
    LLVM_DEBUG(dbgs() << "  - No UncommonNode identified\n");
    return nullptr;
  }

  // No third (accumulator) operand is added here: this node is the bottom of
  // the multiply chain.
  NodePtr Node = prepareCompositeNode(
      ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
  Node->Rotation = Rotation;
  Node->addOperand(CommonNode);
  Node->addOperand(UncommonNode);
  Node->InternalInstructions.append(FNegs);
  return submitCompositeNode(Node);
}
518 
519 ComplexDeinterleavingGraph::NodePtr
520 ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
521                                                Instruction *Imag) {
522   LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
523                     << "\n");
524   // Determine rotation
525   ComplexDeinterleavingRotation Rotation;
526   if (Real->getOpcode() == Instruction::FAdd &&
527       Imag->getOpcode() == Instruction::FAdd)
528     Rotation = ComplexDeinterleavingRotation::Rotation_0;
529   else if (Real->getOpcode() == Instruction::FSub &&
530            Imag->getOpcode() == Instruction::FAdd)
531     Rotation = ComplexDeinterleavingRotation::Rotation_90;
532   else if (Real->getOpcode() == Instruction::FSub &&
533            Imag->getOpcode() == Instruction::FSub)
534     Rotation = ComplexDeinterleavingRotation::Rotation_180;
535   else if (Real->getOpcode() == Instruction::FAdd &&
536            Imag->getOpcode() == Instruction::FSub)
537     Rotation = ComplexDeinterleavingRotation::Rotation_270;
538   else {
539     LLVM_DEBUG(dbgs() << "  - Unhandled rotation.\n");
540     return nullptr;
541   }
542 
543   if (!Real->getFastMathFlags().allowContract() ||
544       !Imag->getFastMathFlags().allowContract()) {
545     LLVM_DEBUG(dbgs() << "  - Contract is missing from the FastMath flags.\n");
546     return nullptr;
547   }
548 
549   Value *CR = Real->getOperand(0);
550   Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));
551   if (!RealMulI)
552     return nullptr;
553   Value *CI = Imag->getOperand(0);
554   Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));
555   if (!ImagMulI)
556     return nullptr;
557 
558   if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
559     LLVM_DEBUG(dbgs() << "  - Mul instruction has multiple uses\n");
560     return nullptr;
561   }
562 
563   Instruction *R0 = dyn_cast<Instruction>(RealMulI->getOperand(0));
564   Instruction *R1 = dyn_cast<Instruction>(RealMulI->getOperand(1));
565   Instruction *I0 = dyn_cast<Instruction>(ImagMulI->getOperand(0));
566   Instruction *I1 = dyn_cast<Instruction>(ImagMulI->getOperand(1));
567   if (!R0 || !R1 || !I0 || !I1) {
568     LLVM_DEBUG(dbgs() << "  - Mul operand not Instruction\n");
569     return nullptr;
570   }
571 
572   Instruction *CommonOperand;
573   Instruction *UncommonRealOp;
574   Instruction *UncommonImagOp;
575 
576   if (R0 == I0 || R0 == I1) {
577     CommonOperand = R0;
578     UncommonRealOp = R1;
579   } else if (R1 == I0 || R1 == I1) {
580     CommonOperand = R1;
581     UncommonRealOp = R0;
582   } else {
583     LLVM_DEBUG(dbgs() << "  - No equal operand\n");
584     return nullptr;
585   }
586 
587   UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
588   if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
589       Rotation == ComplexDeinterleavingRotation::Rotation_270)
590     std::swap(UncommonRealOp, UncommonImagOp);
591 
592   std::pair<Instruction *, Instruction *> PartialMatch(
593       (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
594        Rotation == ComplexDeinterleavingRotation::Rotation_180)
595           ? CommonOperand
596           : nullptr,
597       (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
598        Rotation == ComplexDeinterleavingRotation::Rotation_270)
599           ? CommonOperand
600           : nullptr);
601   NodePtr CNode = identifyNodeWithImplicitAdd(
602       cast<Instruction>(CR), cast<Instruction>(CI), PartialMatch);
603   if (!CNode) {
604     LLVM_DEBUG(dbgs() << "  - No cnode identified\n");
605     return nullptr;
606   }
607 
608   NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
609   if (!UncommonRes) {
610     LLVM_DEBUG(dbgs() << "  - No UncommonRes identified\n");
611     return nullptr;
612   }
613 
614   assert(PartialMatch.first && PartialMatch.second);
615   NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second);
616   if (!CommonRes) {
617     LLVM_DEBUG(dbgs() << "  - No CommonRes identified\n");
618     return nullptr;
619   }
620 
621   NodePtr Node = prepareCompositeNode(
622       ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
623   Node->addInstruction(RealMulI);
624   Node->addInstruction(ImagMulI);
625   Node->Rotation = Rotation;
626   Node->addOperand(CommonRes);
627   Node->addOperand(UncommonRes);
628   Node->addOperand(CNode);
629   return submitCompositeNode(Node);
630 }
631 
632 ComplexDeinterleavingGraph::NodePtr
633 ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
634   LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");
635 
636   // Determine rotation
637   ComplexDeinterleavingRotation Rotation;
638   if ((Real->getOpcode() == Instruction::FSub &&
639        Imag->getOpcode() == Instruction::FAdd) ||
640       (Real->getOpcode() == Instruction::Sub &&
641        Imag->getOpcode() == Instruction::Add))
642     Rotation = ComplexDeinterleavingRotation::Rotation_90;
643   else if ((Real->getOpcode() == Instruction::FAdd &&
644             Imag->getOpcode() == Instruction::FSub) ||
645            (Real->getOpcode() == Instruction::Add &&
646             Imag->getOpcode() == Instruction::Sub))
647     Rotation = ComplexDeinterleavingRotation::Rotation_270;
648   else {
649     LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
650     return nullptr;
651   }
652 
653   auto *AR = dyn_cast<Instruction>(Real->getOperand(0));
654   auto *BI = dyn_cast<Instruction>(Real->getOperand(1));
655   auto *AI = dyn_cast<Instruction>(Imag->getOperand(0));
656   auto *BR = dyn_cast<Instruction>(Imag->getOperand(1));
657 
658   if (!AR || !AI || !BR || !BI) {
659     LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
660     return nullptr;
661   }
662 
663   NodePtr ResA = identifyNode(AR, AI);
664   if (!ResA) {
665     LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
666     return nullptr;
667   }
668   NodePtr ResB = identifyNode(BR, BI);
669   if (!ResB) {
670     LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
671     return nullptr;
672   }
673 
674   NodePtr Node =
675       prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
676   Node->Rotation = Rotation;
677   Node->addOperand(ResA);
678   Node->addOperand(ResB);
679   return submitCompositeNode(Node);
680 }
681 
682 static bool isInstructionPairAdd(Instruction *A, Instruction *B) {
683   unsigned OpcA = A->getOpcode();
684   unsigned OpcB = B->getOpcode();
685 
686   return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
687          (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
688          (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
689          (OpcA == Instruction::Add && OpcB == Instruction::Sub);
690 }
691 
692 static bool isInstructionPairMul(Instruction *A, Instruction *B) {
693   auto Pattern =
694       m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value()));
695 
696   return match(A, Pattern) && match(B, Pattern);
697 }
698 
699 ComplexDeinterleavingGraph::NodePtr
700 ComplexDeinterleavingGraph::identifyNode(Instruction *Real, Instruction *Imag) {
701   LLVM_DEBUG(dbgs() << "identifyNode on " << *Real << " / " << *Imag << "\n");
702   if (NodePtr CN = getContainingComposite(Real, Imag)) {
703     LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
704     return CN;
705   }
706 
707   auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
708   auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
709   if (RealShuffle && ImagShuffle) {
710     Value *RealOp1 = RealShuffle->getOperand(1);
711     if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) {
712       LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
713       return nullptr;
714     }
715     Value *ImagOp1 = ImagShuffle->getOperand(1);
716     if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) {
717       LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
718       return nullptr;
719     }
720 
721     Value *RealOp0 = RealShuffle->getOperand(0);
722     Value *ImagOp0 = ImagShuffle->getOperand(0);
723 
724     if (RealOp0 != ImagOp0) {
725       LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
726       return nullptr;
727     }
728 
729     ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
730     ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
731     if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) {
732       LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
733       return nullptr;
734     }
735 
736     if (RealMask[0] != 0 || ImagMask[0] != 1) {
737       LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
738       return nullptr;
739     }
740 
741     // Type checking, the shuffle type should be a vector type of the same
742     // scalar type, but half the size
743     auto CheckType = [&](ShuffleVectorInst *Shuffle) {
744       Value *Op = Shuffle->getOperand(0);
745       auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType());
746       auto *OpTy = cast<FixedVectorType>(Op->getType());
747 
748       if (OpTy->getScalarType() != ShuffleTy->getScalarType())
749         return false;
750       if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
751         return false;
752 
753       return true;
754     };
755 
756     auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool {
757       if (!CheckType(Shuffle))
758         return false;
759 
760       ArrayRef<int> Mask = Shuffle->getShuffleMask();
761       int Last = *Mask.rbegin();
762 
763       Value *Op = Shuffle->getOperand(0);
764       auto *OpTy = cast<FixedVectorType>(Op->getType());
765       int NumElements = OpTy->getNumElements();
766 
767       // Ensure that the deinterleaving shuffle only pulls from the first
768       // shuffle operand.
769       return Last < NumElements;
770     };
771 
772     if (RealShuffle->getType() != ImagShuffle->getType()) {
773       LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
774       return nullptr;
775     }
776     if (!CheckDeinterleavingShuffle(RealShuffle)) {
777       LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
778       return nullptr;
779     }
780     if (!CheckDeinterleavingShuffle(ImagShuffle)) {
781       LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
782       return nullptr;
783     }
784 
785     NodePtr PlaceholderNode =
786         prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Shuffle,
787                              RealShuffle, ImagShuffle);
788     PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
789     return submitCompositeNode(PlaceholderNode);
790   }
791   if (RealShuffle || ImagShuffle)
792     return nullptr;
793 
794   auto *VTy = cast<FixedVectorType>(Real->getType());
795   auto *NewVTy =
796       FixedVectorType::get(VTy->getScalarType(), VTy->getNumElements() * 2);
797 
798   if (TL->isComplexDeinterleavingOperationSupported(
799           ComplexDeinterleavingOperation::CMulPartial, NewVTy) &&
800       isInstructionPairMul(Real, Imag)) {
801     return identifyPartialMul(Real, Imag);
802   }
803 
804   if (TL->isComplexDeinterleavingOperationSupported(
805           ComplexDeinterleavingOperation::CAdd, NewVTy) &&
806       isInstructionPairAdd(Real, Imag)) {
807     return identifyAdd(Real, Imag);
808   }
809 
810   return nullptr;
811 }
812 
813 bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
814   Instruction *Real;
815   Instruction *Imag;
816   if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
817     return false;
818 
819   RootValue = RootI;
820   AllInstructions.insert(RootI);
821   RootNode = identifyNode(Real, Imag);
822 
823   LLVM_DEBUG({
824     Function *F = RootI->getFunction();
825     BasicBlock *B = RootI->getParent();
826     dbgs() << "Complex deinterleaving graph for " << F->getName()
827            << "::" << B->getName() << ".\n";
828     dump(dbgs());
829     dbgs() << "\n";
830   });
831 
832   // Check all instructions have internal uses
833   for (const auto &Node : CompositeNodes) {
834     if (!Node->hasAllInternalUses(AllInstructions)) {
835       LLVM_DEBUG(dbgs() << "  - Invalid internal uses\n");
836       return false;
837     }
838   }
839   return RootNode != nullptr;
840 }
841 
842 Value *ComplexDeinterleavingGraph::replaceNode(
843     ComplexDeinterleavingGraph::RawNodePtr Node) {
844   if (Node->ReplacementNode)
845     return Node->ReplacementNode;
846 
847   Value *Input0 = replaceNode(Node->Operands[0]);
848   Value *Input1 = replaceNode(Node->Operands[1]);
849   Value *Accumulator =
850       Node->Operands.size() > 2 ? replaceNode(Node->Operands[2]) : nullptr;
851 
852   assert(Input0->getType() == Input1->getType() &&
853          "Node inputs need to be of the same type");
854 
855   Node->ReplacementNode = TL->createComplexDeinterleavingIR(
856       Node->Real, Node->Operation, Node->Rotation, Input0, Input1, Accumulator);
857 
858   assert(Node->ReplacementNode && "Target failed to create Intrinsic call.");
859   NumComplexTransformations += 1;
860   return Node->ReplacementNode;
861 }
862 
863 void ComplexDeinterleavingGraph::replaceNodes() {
864   Value *R = replaceNode(RootNode.get());
865   assert(R && "Unable to find replacement for RootValue");
866   RootValue->replaceAllUsesWith(R);
867 }
868 
869 bool ComplexDeinterleavingCompositeNode::hasAllInternalUses(
870     SmallPtrSet<Instruction *, 16> &AllInstructions) {
871   if (Operation == ComplexDeinterleavingOperation::Shuffle)
872     return true;
873 
874   for (auto *User : Real->users()) {
875     if (!AllInstructions.contains(cast<Instruction>(User)))
876       return false;
877   }
878   for (auto *User : Imag->users()) {
879     if (!AllInstructions.contains(cast<Instruction>(User)))
880       return false;
881   }
882   for (auto *I : InternalInstructions) {
883     for (auto *User : I->users()) {
884       if (!AllInstructions.contains(cast<Instruction>(User)))
885         return false;
886     }
887   }
888   return true;
889 }
890