//===- ComplexDeinterleavingPass.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Identification:
// This step is responsible for finding the patterns that can be lowered to
// complex instructions, and building a graph to represent the complex
// structures. Starting from the "Converging Shuffle" (a shuffle that
// reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
// operands are evaluated and identified as "Composite Nodes" (collections of
// instructions that can potentially be lowered to a single complex
// instruction). This is performed by checking the real and imaginary components
// and tracking the data flow for each component while following the operand
// pairs. Validity of each node is expected to be done upon creation, and any
// validation errors should halt traversal and prevent further graph
// construction.
// Instead of relying on Shuffle operations, vector interleaving and
// deinterleaving can be represented by vector.interleave2 and
// vector.deinterleave2 intrinsics. Scalable vectors can be represented only by
// these intrinsics, whereas, fixed-width vectors are recognized for both
// shufflevector instruction and intrinsics.
//
// Replacement:
// This step traverses the graph built up by identification, delegating to the
// target to validate and generate the correct intrinsics, and plumbs them
// together connecting each end of the new intrinsics graph to the existing
// use-def chain. This step is assumed to finish successfully, as all
// information is expected to be correct by this point.
//
//
// Internal data structure:
// ComplexDeinterleavingGraph:
// Keeps references to all the valid CompositeNodes formed as part of the
// transformation, and every Instruction contained within said nodes. It also
// holds onto a reference to the root Instruction, and the root node that should
// replace it.
//
// ComplexDeinterleavingCompositeNode:
// A CompositeNode represents a single transformation point; each node should
// transform into a single complex instruction (ignoring vector splitting, which
// would generate more instructions per node). They are identified in a
// depth-first manner, traversing and identifying the operands of each
// instruction in the order they appear in the IR.
// Each node maintains a reference to its Real and Imaginary instructions,
// as well as any additional instructions that make up the identified operation
// (Internal instructions should only have uses within their containing node).
// A Node also contains the rotation and operation type that it represents.
// Operands contains pointers to other CompositeNodes, acting as the edges in
// the graph. ReplacementValue is the transformed Value* that has been emitted
// to the IR.
//
// Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
// ReplacementValue fields of that Node are relevant, where the ReplacementValue
// should be pre-populated.
//
//===----------------------------------------------------------------------===//

#include "llvm/CodeGen/ComplexDeinterleavingPass.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/Local.h"
#include <algorithm>

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "complex-deinterleaving"

STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");

/// Hidden command-line switch (`enable-complex-deinterleaving`) gating the
/// whole pass. It is initialised to true; ComplexDeinterleaving::runOnFunction
/// returns early without touching the function when it is false.
static cl::opt<bool> ComplexDeinterleavingEnabled(
    "enable-complex-deinterleaving",
    cl::desc("Enable generation of complex instructions"), cl::init(true),
    cl::Hidden);

/// Checks the given mask, and determines whether said mask is interleaving.
///
/// To be interleaving, a mask must alternate between `i` and `i + (Length /
/// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
/// 4x vector interleaving mask would be <0, 2, 1, 3>).
static bool isInterleavingMask(ArrayRef<int> Mask);

/// Checks the given mask, and determines whether said mask is deinterleaving.
///
/// To be deinterleaving, a mask must increment in steps of 2, and either start
/// with 0 or 1.
/// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
/// <1, 3, 5, 7>).
static bool isDeinterleavingMask(ArrayRef<int> Mask);

/// Returns true if \p V is itself a negation of some value, i.e. it matches
/// either the floating-point (`m_FNeg`) or the integer (`m_Neg`) negation
/// pattern.
static bool isNeg(Value *V);

/// Returns the value that is being negated by \p V. \p V must satisfy isNeg()
/// (this is asserted).
static Value *getNegOperand(Value *V);

namespace {

/// Legacy pass manager wrapper around ComplexDeinterleaving. It fetches the
/// TargetLowering from the TargetMachine and the TargetLibraryInfo from
/// TargetLibraryInfoWrapperPass, then delegates to
/// ComplexDeinterleaving::runOnFunction.
class ComplexDeinterleavingLegacyPass : public FunctionPass {
public:
  static char ID;

  // TM defaults to nullptr, but runOnFunction dereferences it unconditionally,
  // so an instance that is actually run needs a non-null TargetMachine.
  ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
      : FunctionPass(ID), TM(TM) {
    initializeComplexDeinterleavingLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  StringRef getPassName() const override {
    return "Complex Deinterleaving Pass";
  }

  bool runOnFunction(Function &F) override;

  // The pass needs TargetLibraryInfo and declares that it preserves the CFG.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetLibraryInfoWrapperPass>();
    AU.setPreservesCFG();
  }

private:
  // Target machine used to look up the per-function TargetLowering.
  const TargetMachine *TM;
};

class ComplexDeinterleavingGraph;
/// A single node of the complex deinterleaving graph. It pairs the value
/// producing the real component with the value producing the imaginary
/// component, records which complex operation the pair represents, and links
/// to the nodes of its operands.
struct ComplexDeinterleavingCompositeNode {

  /// Creates a node for operation \p Op whose real part is \p R and whose
  /// imaginary part is \p I. All remaining fields keep their default values
  /// until the identifying code fills them in.
  ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
                                     Value *R, Value *I)
      : Operation(Op), Real(R), Imag(I) {}

private:
  friend class ComplexDeinterleavingGraph;
  using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
  using RawNodePtr = ComplexDeinterleavingCompositeNode *;

public:
  ComplexDeinterleavingOperation Operation;
  Value *Real;
  Value *Imag;

  // These two members are required exclusively for generating
  // ComplexDeinterleavingOperation::Symmetric operations.
  // Opcode is given a defined default so that nodes of every other kind never
  // carry an indeterminate value.
  unsigned Opcode = 0;
  std::optional<FastMathFlags> Flags;

  ComplexDeinterleavingRotation Rotation =
      ComplexDeinterleavingRotation::Rotation_0;
  // Non-owning edges to the operand nodes.
  SmallVector<RawNodePtr> Operands;
  Value *ReplacementNode = nullptr;

  /// Appends \p Node as the next operand. Only the raw pointer is stored, so
  /// the node's lifetime must be managed elsewhere.
  void addOperand(NodePtr Node) { Operands.push_back(Node.get()); }

  /// Prints this node to the debug stream.
  void dump() { dump(dbgs()); }

  /// Prints a human-readable description of this node to \p OS.
  void dump(raw_ostream &OS) {
    auto PrintValue = [&](Value *V) {
      if (V) {
        OS << "\"";
        V->print(OS, true);
        OS << "\"\n";
      } else
        OS << "nullptr\n";
    };
    auto PrintNodeRef = [&](RawNodePtr Ptr) {
      if (Ptr)
        OS << Ptr << "\n";
      else
        OS << "nullptr\n";
    };

    OS << "- CompositeNode: " << this << "\n";
    OS << "  Real: ";
    PrintValue(Real);
    OS << "  Imag: ";
    PrintValue(Imag);
    OS << "  ReplacementNode: ";
    PrintValue(ReplacementNode);
    OS << "  Operation: " << static_cast<int>(Operation) << "\n";
    // The rotation enumerator is printed as an angle in degrees.
    OS << "  Rotation: " << (static_cast<int>(Rotation) * 90) << "\n";
    OS << "  Operands: \n";
    for (const auto &Op : Operands) {
      OS << "    - ";
      PrintNodeRef(Op);
    }
  }
};

/// Owns every CompositeNode created while analysing one basic block and
/// provides the identification (identify*), validation (checkNodes) and
/// replacement (replaceNodes) steps described in the file header.
class ComplexDeinterleavingGraph {
public:
  /// One multiplication term collected for identifyMultiplications.
  struct Product {
    Value *Multiplier;
    Value *Multiplicand;
    bool IsPositive;
  };

  // NOTE(review): the bool presumably records whether the addend is positive
  // (cf. extractPositiveAddend) - confirm against identifyReassocNodes.
  using Addend = std::pair<Value *, bool>;
  using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
  using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;

  // Helper struct for holding info about potential partial multiplication
  // candidates
  struct PartialMulCandidate {
    Value *Common;
    NodePtr Node;
    unsigned RealIdx;
    unsigned ImagIdx;
    bool IsNodeInverted;
  };

  explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
                                      const TargetLibraryInfo *TLI)
      : TL(TL), TLI(TLI) {}

private:
  const TargetLowering *TL = nullptr;
  const TargetLibraryInfo *TLI = nullptr;

  /// Every node handed to submitCompositeNode, in submission order.
  SmallVector<NodePtr> CompositeNodes;

  /// Maps a (Real, Imag) value pair to the node already built for it, so that
  /// identifyNode can reuse the node instead of identifying the pair again.
  DenseMap<std::pair<Value *, Value *>, NodePtr> CachedResult;

  SmallPtrSet<Instruction *, 16> FinalInstructions;

  /// Root instructions are instructions from which complex computation starts
  std::map<Instruction *, NodePtr> RootToNode;

  /// Topologically sorted root instructions
  SmallVector<Instruction *, 1> OrderedRoots;

  /// When examining a basic block for complex deinterleaving, if it is a simple
  /// one-block loop, then the only incoming block is 'Incoming' and the
  /// 'BackEdge' block is the block itself.
  BasicBlock *BackEdge = nullptr;
  BasicBlock *Incoming = nullptr;

  /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction
  /// %OutsideUser as it is shown in the IR:
  ///
  /// vector.body:
  ///   %PHInode = phi <vector type> [ zeroinitializer, %entry ],
  ///                                [ %ReductionOp, %vector.body ]
  ///   ...
  ///   %ReductionOp = fadd i64 ...
  ///   ...
  ///   br i1 %condition, label %vector.body, %middle.block
  ///
  /// middle.block:
  ///   %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp)
  ///
  /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding
  /// `llvm.vector.reduce.fadd` when unroll factor isn't one.
  MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;

  /// In the process of detecting a reduction, we consider a pair of
  /// %ReductionOP, which we refer to as real and imag (or vice versa), and
  /// traverse the use-tree to detect complex operations. As this is a reduction
  /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds
  /// to the %ReductionOPs that we suspect to be complex.
  /// RealPHI and ImagPHI are used by the identifyPHINode method.
  PHINode *RealPHI = nullptr;
  PHINode *ImagPHI = nullptr;

  /// Set this flag to true if RealPHI and ImagPHI were reached during reduction
  /// detection.
  bool PHIsFound = false;

  /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode.
  /// The new PHINode corresponds to a vector of deinterleaved complex numbers.
  /// This mapping is populated during
  /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then
  /// used in the ComplexDeinterleavingOperation::ReductionOperation node
  /// replacement process.
  std::map<PHINode *, PHINode *> OldToNewPHI;

  /// Allocates a new node for \p Operation with real part \p R and imaginary
  /// part \p I. Reduction-related operations must supply both parts (this is
  /// asserted). The node is not recorded in the graph until it is passed to
  /// submitCompositeNode.
  NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
                               Value *R, Value *I) {
    assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
             Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
            (R && I)) &&
           "Reduction related nodes must have Real and Imaginary parts");
    return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R,
                                                                I);
  }

  /// Records \p Node in CompositeNodes and, when it has both a real and an
  /// imaginary value, caches it under that pair. Returns \p Node so that
  /// callers can write `return submitCompositeNode(Node);`.
  NodePtr submitCompositeNode(NodePtr Node) {
    CompositeNodes.push_back(Node);
    if (Node->Real && Node->Imag)
      CachedResult[{Node->Real, Node->Imag}] = Node;
    return Node;
  }

  /// Identifies a complex partial multiply pattern and its rotation, based on
  /// the following patterns
  ///
  ///  0:  r: cr + ar * br
  ///      i: ci + ar * bi
  /// 90:  r: cr - ai * bi
  ///      i: ci + ai * br
  /// 180: r: cr - ar * br
  ///      i: ci - ar * bi
  /// 270: r: cr + ai * bi
  ///      i: ci - ai * br
  NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag);

  /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
  /// is partially known from identifyPartialMul, filling in the other half of
  /// the complex pair.
  NodePtr
  identifyNodeWithImplicitAdd(Instruction *I, Instruction *J,
                              std::pair<Value *, Value *> &CommonOperandI);

  /// Identifies a complex add pattern and its rotation, based on the following
  /// patterns.
  ///
  /// 90:  r: ar - bi
  ///      i: ai + br
  /// 270: r: ar + bi
  ///      i: ai - br
  NodePtr identifyAdd(Instruction *Real, Instruction *Imag);

  /// Identifies a pair of instructions with the same opcode whose operands
  /// pair up into complex nodes, producing a Symmetric node.
  NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag);

  /// Returns the node for the (\p R, \p I) pair, reusing a cached node when one
  /// exists.
  NodePtr identifyNode(Value *R, Value *I);

  /// Determine if a sum of complex numbers can be formed from \p RealAddends
  /// and \p ImagAddends. If \p Accumulator is not null, add the result to it.
  /// Return nullptr if it is not possible to construct a complex number.
  /// \p Flags are needed to generate symmetric Add and Sub operations.
  NodePtr identifyAdditions(std::list<Addend> &RealAddends,
                            std::list<Addend> &ImagAddends,
                            std::optional<FastMathFlags> Flags,
                            NodePtr Accumulator);

  /// Extract one addend that has both real and imaginary parts positive.
  NodePtr extractPositiveAddend(std::list<Addend> &RealAddends,
                                std::list<Addend> &ImagAddends);

  /// Determine if sum of multiplications of complex numbers can be formed from
  /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result
  /// to it. Return nullptr if it is not possible to construct a complex number.
  NodePtr identifyMultiplications(std::vector<Product> &RealMuls,
                                  std::vector<Product> &ImagMuls,
                                  NodePtr Accumulator);

  /// Go through pairs of multiplication (one Real and one Imag) and find all
  /// possible candidates for partial multiplication and put them into \p
  /// Candidates. Returns true if every Product has a pair with a common
  /// operand.
  bool collectPartialMuls(const std::vector<Product> &RealMuls,
                          const std::vector<Product> &ImagMuls,
                          std::vector<PartialMulCandidate> &Candidates);

  /// If the code is compiled with -Ofast or expressions have `reassoc` flag,
  /// the order of complex computation operations may be significantly altered,
  /// and the real and imaginary parts may not be executed in parallel. This
  /// function takes this into consideration and employs a more general approach
  /// to identify complex computations. Initially, it gathers all the addends
  /// and multiplicands and then constructs a complex expression from them.
  NodePtr identifyReassocNodes(Instruction *I, Instruction *J);

  NodePtr identifyRoot(Instruction *I);

  /// Identifies the Deinterleave operation applied to a vector containing
  /// complex numbers. There are two ways to represent the Deinterleave
  /// operation:
  /// * Using two shufflevectors with even indices for \p Real instruction and
  /// odd indices for \p Imag instructions (only for fixed-width vectors)
  /// * Using two extractvalue instructions applied to `vector.deinterleave2`
  /// intrinsic (for both fixed and scalable vectors)
  NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag);

  /// Identifies the operation that represents a complex number repeated in a
  /// Splat vector. There are two possible types of splats: ConstantExpr with
  /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an
  /// initialization mask with all values set to zero.
  NodePtr identifySplat(Value *Real, Value *Imag);

  NodePtr identifyPHINode(Instruction *Real, Instruction *Imag);

  /// Identifies SelectInsts in a loop that has reduction with predication masks
  /// and/or predicated tail folding
  NodePtr identifySelectNode(Instruction *Real, Instruction *Imag);

  Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node);

  /// Complete IR modifications after producing new reduction operation:
  /// * Populate the PHINode generated for
  /// ComplexDeinterleavingOperation::ReductionPHI
  /// * Deinterleave the final value outside of the loop and repurpose original
  /// reduction users
  void processReductionOperation(Value *OperationReplacement, RawNodePtr Node);

public:
  /// Prints every submitted node to the debug stream.
  void dump() { dump(dbgs()); }

  /// Prints every submitted node to \p OS, in submission order.
  void dump(raw_ostream &OS) {
    for (const auto &Node : CompositeNodes)
      Node->dump(OS);
  }

  /// Returns false if the deinterleaving operation should be cancelled for the
  /// current graph.
  bool identifyNodes(Instruction *RootI);

  /// In case \p B is one-block loop, this function seeks potential reductions
  /// and populates ReductionInfo. Returns true if any reductions were
  /// identified.
  bool collectPotentialReductions(BasicBlock *B);

  void identifyReductionNodes();

  /// Check that every instruction, from the roots to the leaves, has internal
  /// uses.
  bool checkNodes();

  /// Perform the actual replacement of the underlying instruction graph.
  void replaceNodes();
};

/// Pass-manager-independent driver shared by the legacy and the new pass
/// manager entry points. It checks that the transformation is enabled and
/// supported by the target, then processes the function one basic block at a
/// time.
class ComplexDeinterleaving {
public:
  ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli)
      : TL(tl), TLI(tli) {}

  /// Runs the transformation over \p F; returns true if any block was changed.
  bool runOnFunction(Function &F);

private:
  /// Builds a fresh graph for \p B and, if it validates, replaces the matched
  /// instructions. Returns true when a replacement was performed.
  bool evaluateBasicBlock(BasicBlock *B);

  const TargetLowering *TL = nullptr;
  const TargetLibraryInfo *TLI = nullptr;
};

} // namespace

// Storage for the legacy pass identifier that is handed to FunctionPass.
char ComplexDeinterleavingLegacyPass::ID = 0;

// Registers the legacy pass under DEBUG_TYPE ("complex-deinterleaving").
// NOTE(review): getAnalysisUsage requires TargetLibraryInfoWrapperPass, yet no
// INITIALIZE_PASS_DEPENDENCY is listed between BEGIN and END - confirm that
// this is intentional.
INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
                      "Complex Deinterleaving", false, false)
INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
                    "Complex Deinterleaving", false, false)

/// New pass manager entry point. Looks up the TargetLowering for \p F and the
/// TargetLibraryInfo analysis result, then runs the shared driver. When
/// nothing changed every analysis is preserved; otherwise only
/// FunctionAnalysisManagerModuleProxy is marked as preserved.
PreservedAnalyses ComplexDeinterleavingPass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  const TargetLowering *Lowering = TM->getSubtargetImpl(F)->getTargetLowering();
  const TargetLibraryInfo &LibInfo =
      AM.getResult<llvm::TargetLibraryAnalysis>(F);

  ComplexDeinterleaving Driver(Lowering, &LibInfo);
  const bool Changed = Driver.runOnFunction(F);
  if (Changed) {
    PreservedAnalyses Preserved;
    Preserved.preserve<FunctionAnalysisManagerModuleProxy>();
    return Preserved;
  }

  return PreservedAnalyses::all();
}

/// Factory for the legacy pass manager: returns a newly allocated
/// ComplexDeinterleavingLegacyPass bound to \p TM.
FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) {
  FunctionPass *NewPass = new ComplexDeinterleavingLegacyPass(TM);
  return NewPass;
}

/// Legacy pass manager entry point. Gathers the TargetLowering (via TM) and
/// the TargetLibraryInfo for \p F and forwards to the shared driver. Returns
/// true if \p F was modified.
bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {
  const TargetLowering *Lowering = TM->getSubtargetImpl(F)->getTargetLowering();
  TargetLibraryInfo LibInfo =
      getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);

  ComplexDeinterleaving Driver(Lowering, &LibInfo);
  return Driver.runOnFunction(F);
}

/// Runs complex deinterleaving over every basic block of \p F.
///
/// Returns false without inspecting the function when the pass was disabled on
/// the command line or when the target reports that it does not support
/// complex deinterleaving. Otherwise returns true if at least one block was
/// rewritten.
bool ComplexDeinterleaving::runOnFunction(Function &F) {
  if (!ComplexDeinterleavingEnabled) {
    LLVM_DEBUG(
        dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
    return false;
  }

  if (!TL->isComplexDeinterleavingSupported()) {
    LLVM_DEBUG(
        dbgs() << "Complex deinterleaving has been disabled, target does "
                  "not support lowering of complex number operations.\n");
    return false;
  }

  // Every block is evaluated, including those after the first change.
  bool AnyBlockChanged = false;
  for (BasicBlock &Block : F) {
    if (evaluateBasicBlock(&Block))
      AnyBlockChanged = true;
  }

  return AnyBlockChanged;
}

/// Returns true if \p Mask interleaves the two halves of a vector, i.e. it has
/// the form <0, H, 1, H + 1, ...> where H is half of the mask length. Masks of
/// odd length are never interleaving.
static bool isInterleavingMask(ArrayRef<int> Mask) {
  const auto NumElements = Mask.size();
  if (NumElements % 2 != 0)
    return false;

  const int Half = NumElements / 2;
  for (int Lane = 0; Lane != Half; ++Lane) {
    // Each pair of mask entries must select lane `Lane` of the low half
    // followed by lane `Lane` of the high half.
    const int LowEntry = Mask[2 * Lane];
    const int HighEntry = Mask[2 * Lane + 1];
    if (LowEntry != Lane || HighEntry != Lane + Half)
      return false;
  }

  return true;
}

/// Returns true if \p Mask is a deinterleaving mask: it starts with 0 or 1 and
/// every following element is 2 greater than its predecessor (e.g.
/// <0, 2, 4, 6> or <1, 3, 5, 7>).
///
/// An empty mask is rejected. Every element of the mask is validated, so a
/// mask whose second half breaks the pattern is rejected too.
static bool isDeinterleavingMask(ArrayRef<int> Mask) {
  if (Mask.empty())
    return false;

  // Only the even lanes (offset 0) or the odd lanes (offset 1) can be
  // selected.
  const int Offset = Mask[0];
  if (Offset != 0 && Offset != 1)
    return false;

  for (int Idx = 1, End = Mask.size(); Idx < End; ++Idx) {
    if (Mask[Idx] != (Idx * 2) + Offset)
      return false;
  }

  return true;
}

/// Returns true if \p V matches a floating-point negation (`m_FNeg`) or an
/// integer negation (`m_Neg`) of some value.
bool isNeg(Value *V) {
  if (match(V, m_FNeg(m_Value())))
    return true;
  return match(V, m_Neg(m_Value()));
}

/// Returns the value negated by \p V, which must satisfy isNeg(). For an
/// `fneg` instruction that is operand 0; for every other matched form the
/// negated value is operand 1.
Value *getNegOperand(Value *V) {
  assert(isNeg(V));
  auto *NegInst = cast<Instruction>(V);
  const unsigned NegatedIdx =
      NegInst->getOpcode() == Instruction::FNeg ? 0 : 1;
  return NegInst->getOperand(NegatedIdx);
}

/// Analyses a single basic block. A fresh graph is built for \p B: potential
/// reductions are collected first, then every instruction is offered as a
/// root. The IR is only rewritten when the resulting graph passes checkNodes().
/// Returns true if the block was rewritten.
bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
  ComplexDeinterleavingGraph Graph(TL, TLI);

  const bool HasReductions = Graph.collectPotentialReductions(B);
  if (HasReductions)
    Graph.identifyReductionNodes();

  for (Instruction &Inst : *B)
    Graph.identifyNodes(&Inst);

  if (!Graph.checkNodes())
    return false;

  Graph.replaceNodes();
  return true;
}

/// Identifies the second half of a partial multiply: \p Real and \p Imag are
/// the multiplications feeding the accumulator operands of the instructions
/// handled by identifyPartialMul.
///
/// \p PartialMatch arrives with one component of the common operand pair
/// already filled in by identifyPartialMul. This function supplies the other
/// component from the operand shared by \p Real and \p Imag. On success it
/// returns a CMulPartial node whose operands are the common node and the
/// uncommon node; on any mismatch it returns nullptr.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
    Instruction *Real, Instruction *Imag,
    std::pair<Value *, Value *> &PartialMatch) {
  LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
                    << "\n");

  if (!Real->hasOneUse() || !Imag->hasOneUse()) {
    LLVM_DEBUG(dbgs() << "  - Mul operand has multiple uses.\n");
    return nullptr;
  }

  if ((Real->getOpcode() != Instruction::FMul &&
       Real->getOpcode() != Instruction::Mul) ||
      (Imag->getOpcode() != Instruction::FMul &&
       Imag->getOpcode() != Instruction::Mul)) {
    LLVM_DEBUG(
        dbgs() << "  - Real or imaginary instruction is not fmul or mul\n");
    return nullptr;
  }

  Value *R0 = Real->getOperand(0);
  Value *R1 = Real->getOperand(1);
  Value *I0 = Imag->getOperand(0);
  Value *I1 = Imag->getOperand(1);

  // A +/+ has a rotation of 0. If any of the operands are negated, we flip the
  // rotations and use the operand.
  unsigned Negs = 0;
  Value *Op = nullptr;
  if (match(R0, m_Neg(m_Value(Op)))) {
    Negs |= 1;
    R0 = Op;
  } else if (match(R1, m_Neg(m_Value(Op)))) {
    Negs |= 1;
    R1 = Op;
  }

  // Same condition as isNeg(I0), but written with binding matchers so that Op
  // is the value negated by I0. It must not be left unset, nor keep whatever
  // the real-side match above happened to bind.
  if (match(I0, m_FNeg(m_Value(Op))) || match(I0, m_Neg(m_Value(Op)))) {
    Negs |= 2;
    Negs ^= 1;
    I0 = Op;
  } else if (match(I1, m_Neg(m_Value(Op)))) {
    Negs |= 2;
    Negs ^= 1;
    I1 = Op;
  }

  // The accumulated negation bits are used directly as the rotation value.
  ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs;

  Value *CommonOperand;
  Value *UncommonRealOp;
  Value *UncommonImagOp;

  if (R0 == I0 || R0 == I1) {
    CommonOperand = R0;
    UncommonRealOp = R1;
  } else if (R1 == I0 || R1 == I1) {
    CommonOperand = R1;
    UncommonRealOp = R0;
  } else {
    LLVM_DEBUG(dbgs() << "  - No equal operand\n");
    return nullptr;
  }

  UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
  if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
      Rotation == ComplexDeinterleavingRotation::Rotation_270)
    std::swap(UncommonRealOp, UncommonImagOp);

  // Between identifyPartialMul and here we need to have found a complete valid
  // pair from the CommonOperand of each part.
  if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
      Rotation == ComplexDeinterleavingRotation::Rotation_180)
    PartialMatch.first = CommonOperand;
  else
    PartialMatch.second = CommonOperand;

  if (!PartialMatch.first || !PartialMatch.second) {
    LLVM_DEBUG(dbgs() << "  - Incomplete partial match\n");
    return nullptr;
  }

  NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second);
  if (!CommonNode) {
    LLVM_DEBUG(dbgs() << "  - No CommonNode identified\n");
    return nullptr;
  }

  NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
  if (!UncommonNode) {
    LLVM_DEBUG(dbgs() << "  - No UncommonNode identified\n");
    return nullptr;
  }

  NodePtr Node = prepareCompositeNode(
      ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
  Node->Rotation = Rotation;
  Node->addOperand(CommonNode);
  Node->addOperand(UncommonNode);
  return submitCompositeNode(Node);
}

/// Identifies a complex partial multiply with accumulator rooted at \p Real and
/// \p Imag.
///
/// The rotation follows from the add/sub combination of the two roots
/// (add/add = 0, sub/add = 90, sub/sub = 180, add/sub = 270). Operand 1 of each
/// root must be a single-use instruction and the two must share one operand;
/// operand 0 of each root is handed to identifyNodeWithImplicitAdd. Returns a
/// CMulPartial node with operands (common, uncommon, accumulator), or nullptr
/// if any step fails.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
                                               Instruction *Imag) {
  LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
                    << "\n");

  // Classify both roots as addition or subtraction (integer or FP).
  const unsigned RealOpc = Real->getOpcode();
  const unsigned ImagOpc = Imag->getOpcode();
  const bool RealIsAdd =
      RealOpc == Instruction::FAdd || RealOpc == Instruction::Add;
  const bool RealIsSub =
      RealOpc == Instruction::FSub || RealOpc == Instruction::Sub;
  const bool ImagIsAdd =
      ImagOpc == Instruction::FAdd || ImagOpc == Instruction::Add;
  const bool ImagIsSub =
      ImagOpc == Instruction::FSub || ImagOpc == Instruction::Sub;

  if (!(RealIsAdd || RealIsSub) || !(ImagIsAdd || ImagIsSub)) {
    LLVM_DEBUG(dbgs() << "  - Unhandled rotation.\n");
    return nullptr;
  }

  ComplexDeinterleavingRotation Rotation;
  if (RealIsAdd)
    Rotation = ImagIsAdd ? ComplexDeinterleavingRotation::Rotation_0
                         : ComplexDeinterleavingRotation::Rotation_270;
  else
    Rotation = ImagIsAdd ? ComplexDeinterleavingRotation::Rotation_90
                         : ComplexDeinterleavingRotation::Rotation_180;

  // For rotations 0 and 180 the shared operand is the real component of the
  // common pair; for 90 and 270 it is the imaginary component.
  const bool CommonIsReal =
      Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
      Rotation == ComplexDeinterleavingRotation::Rotation_180;

  if (isa<FPMathOperator>(Real)) {
    const bool BothContract = Real->getFastMathFlags().allowContract() &&
                              Imag->getFastMathFlags().allowContract();
    if (!BothContract) {
      LLVM_DEBUG(
          dbgs() << "  - Contract is missing from the FastMath flags.\n");
      return nullptr;
    }
  }

  // Operand 0 is the accumulator, operand 1 the multiplication.
  Value *CR = Real->getOperand(0);
  auto *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));
  if (!RealMulI)
    return nullptr;
  Value *CI = Imag->getOperand(0);
  auto *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));
  if (!ImagMulI)
    return nullptr;

  if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
    LLVM_DEBUG(dbgs() << "  - Mul instruction has multiple uses\n");
    return nullptr;
  }

  Value *R0 = RealMulI->getOperand(0);
  Value *R1 = RealMulI->getOperand(1);
  Value *I0 = ImagMulI->getOperand(0);
  Value *I1 = ImagMulI->getOperand(1);

  // Find the operand that both multiplications share; R0 takes precedence.
  const bool R0IsShared = R0 == I0 || R0 == I1;
  const bool R1IsShared = R1 == I0 || R1 == I1;
  if (!R0IsShared && !R1IsShared) {
    LLVM_DEBUG(dbgs() << "  - No equal operand\n");
    return nullptr;
  }

  Value *CommonOperand = R0IsShared ? R0 : R1;
  Value *UncommonRealOp = R0IsShared ? R1 : R0;
  Value *UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
  if (!CommonIsReal)
    std::swap(UncommonRealOp, UncommonImagOp);

  // Only one half of the common pair is known at this point; the other half is
  // filled in by identifyNodeWithImplicitAdd.
  std::pair<Value *, Value *> PartialMatch;
  if (CommonIsReal)
    PartialMatch.first = CommonOperand;
  else
    PartialMatch.second = CommonOperand;

  auto *CRInst = dyn_cast<Instruction>(CR);
  auto *CIInst = dyn_cast<Instruction>(CI);
  if (!CRInst || !CIInst) {
    LLVM_DEBUG(dbgs() << "  - Common operands are not instructions.\n");
    return nullptr;
  }

  NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
  if (!CNode) {
    LLVM_DEBUG(dbgs() << "  - No cnode identified\n");
    return nullptr;
  }

  NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
  if (!UncommonRes) {
    LLVM_DEBUG(dbgs() << "  - No UncommonRes identified\n");
    return nullptr;
  }

  assert(PartialMatch.first && PartialMatch.second);
  NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second);
  if (!CommonRes) {
    LLVM_DEBUG(dbgs() << "  - No CommonRes identified\n");
    return nullptr;
  }

  NodePtr Node = prepareCompositeNode(
      ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
  Node->Rotation = Rotation;
  Node->addOperand(CommonRes);
  Node->addOperand(UncommonRes);
  Node->addOperand(CNode);
  return submitCompositeNode(Node);
}

/// Identifies a rotated complex addition rooted at \p Real and \p Imag.
///
/// A sub/add pair (real/imag) is a rotation of 90, an add/sub pair a rotation
/// of 270; both roots must be of the same kind (both FP or both integer). All
/// four operands have to be instructions, and the (AR, AI) and (BR, BI) pairs
/// must each be identified as nodes. Returns a CAdd node or nullptr.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
  LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");

  const unsigned RealOpc = Real->getOpcode();
  const unsigned ImagOpc = Imag->getOpcode();

  const bool IsSubAdd =
      (RealOpc == Instruction::FSub && ImagOpc == Instruction::FAdd) ||
      (RealOpc == Instruction::Sub && ImagOpc == Instruction::Add);
  const bool IsAddSub =
      (RealOpc == Instruction::FAdd && ImagOpc == Instruction::FSub) ||
      (RealOpc == Instruction::Add && ImagOpc == Instruction::Sub);

  if (!IsSubAdd && !IsAddSub) {
    LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
    return nullptr;
  }

  const ComplexDeinterleavingRotation Rotation =
      IsSubAdd ? ComplexDeinterleavingRotation::Rotation_90
               : ComplexDeinterleavingRotation::Rotation_270;

  // The real root combines AR with BI, the imaginary root AI with BR.
  auto *AR = dyn_cast<Instruction>(Real->getOperand(0));
  auto *BI = dyn_cast<Instruction>(Real->getOperand(1));
  auto *AI = dyn_cast<Instruction>(Imag->getOperand(0));
  auto *BR = dyn_cast<Instruction>(Imag->getOperand(1));

  const bool AllInstructions = AR && AI && BR && BI;
  if (!AllInstructions) {
    LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
    return nullptr;
  }

  NodePtr ResA = identifyNode(AR, AI);
  if (!ResA) {
    LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
    return nullptr;
  }

  NodePtr ResB = identifyNode(BR, BI);
  if (!ResB) {
    LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
    return nullptr;
  }

  NodePtr Node =
      prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
  Node->Rotation = Rotation;
  Node->addOperand(ResA);
  Node->addOperand(ResB);
  return submitCompositeNode(Node);
}

/// Returns true if \p A and \p B form an add/sub pair of the same kind: one is
/// an addition and the other a subtraction, either both floating-point or both
/// integer, in either order.
static bool isInstructionPairAdd(Instruction *A, Instruction *B) {
  auto IsPair = [A, B](unsigned First, unsigned Second) {
    return A->getOpcode() == First && B->getOpcode() == Second;
  };

  return IsPair(Instruction::FSub, Instruction::FAdd) ||
         IsPair(Instruction::FAdd, Instruction::FSub) ||
         IsPair(Instruction::Sub, Instruction::Add) ||
         IsPair(Instruction::Add, Instruction::Sub);
}

/// Returns true if both \p A and \p B are binary operators whose two operands
/// are each an `fmul`.
static bool isInstructionPairMul(Instruction *A, Instruction *B) {
  auto CombinesTwoFMuls = [](Instruction *Inst) {
    return match(Inst, m_BinOp(m_FMul(m_Value(), m_Value()),
                               m_FMul(m_Value(), m_Value())));
  };

  return CombinesTwoFMuls(A) && CombinesTwoFMuls(B);
}

/// Returns true when the opcode of \p I is one of the operations accepted as
/// a candidate for a symmetric node: FAdd, FSub, FMul, FNeg, Add, Sub or Mul.
static bool isInstructionPotentiallySymmetric(Instruction *I) {
  const unsigned Opcode = I->getOpcode();

  const bool IsFloatingPointCandidate =
      Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
      Opcode == Instruction::FMul || Opcode == Instruction::FNeg;
  const bool IsIntegerCandidate = Opcode == Instruction::Add ||
                                  Opcode == Instruction::Sub ||
                                  Opcode == Instruction::Mul;

  return IsFloatingPointCandidate || IsIntegerCandidate;
}

/// Tries to recognise \p Real and \p Imag as the two halves of one symmetric
/// operation: the same opcode applied to matching real/imaginary operands.
///
/// Both instructions must share an opcode accepted by
/// isInstructionPotentiallySymmetric, every operand pair must itself be
/// identified as a node, and floating-point operations must carry identical
/// fast-math flags. Returns the submitted Symmetric node, or nullptr when any
/// of these conditions fails.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
                                                       Instruction *Imag) {
  const unsigned Opcode = Real->getOpcode();
  if (Opcode != Imag->getOpcode())
    return nullptr;

  if (!(isInstructionPotentiallySymmetric(Real) &&
        isInstructionPotentiallySymmetric(Imag)))
    return nullptr;

  const bool IsBinary = Real->isBinaryOp();
  const bool IsFPOperation = isa<FPMathOperator>(Real);

  // The first operand pair is always present (unary and binary operations).
  NodePtr FirstOperand =
      identifyNode(Real->getOperand(0), Imag->getOperand(0));
  if (!FirstOperand)
    return nullptr;

  // The second operand pair only exists for binary operations.
  NodePtr SecondOperand;
  if (IsBinary) {
    SecondOperand = identifyNode(Real->getOperand(1), Imag->getOperand(1));
    if (!SecondOperand)
      return nullptr;
  }

  // The flags are compared only after the operands have been identified.
  if (IsFPOperation && Real->getFastMathFlags() != Imag->getFastMathFlags())
    return nullptr;

  NodePtr Node = prepareCompositeNode(
      ComplexDeinterleavingOperation::Symmetric, Real, Imag);
  Node->Opcode = Opcode;
  if (IsFPOperation)
    Node->Flags = Real->getFastMathFlags();

  Node->addOperand(FirstOperand);
  if (IsBinary)
    Node->addOperand(SecondOperand);

  return submitCompositeNode(Node);
}

/// Central dispatcher: tries to recognise the pair (\p R, \p I) as the real
/// and imaginary parts of a complex value.
///
/// The recognisers are tried in a fixed order: splat, deinterleave, PHI,
/// select, partial multiplication, addition, reassociated expression and
/// finally symmetric operation. The multiplication, addition and
/// reassociation recognisers are only attempted when the target reports
/// support for the corresponding complex operations on the double-width
/// vector type. A failure is recorded in CachedResult so that the same pair
/// is not analysed again.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) {
  LLVM_DEBUG(dbgs() << "identifyNode on " << *R << " / " << *I << "\n");
  assert(R->getType() == I->getType() &&
         "Real and imaginary parts should not have different types");

  // Reuse any verdict already recorded for this exact pair.
  if (auto Cached = CachedResult.find({R, I}); Cached != CachedResult.end()) {
    LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
    return Cached->second;
  }

  // Splats are the only pattern that may be made of non-instruction values.
  NodePtr Found = identifySplat(R, I);
  if (Found)
    return Found;

  auto *Real = dyn_cast<Instruction>(R);
  auto *Imag = dyn_cast<Instruction>(I);
  if (!(Real && Imag))
    return nullptr;

  Found = identifyDeinterleave(Real, Imag);
  if (Found)
    return Found;

  Found = identifyPHINode(Real, Imag);
  if (Found)
    return Found;

  Found = identifySelectNode(Real, Imag);
  if (Found)
    return Found;

  // The target is queried with the interleaved (double element count) type.
  auto *InterleavedTy = VectorType::getDoubleElementsVectorType(
      cast<VectorType>(Real->getType()));

  const bool TargetHasCMul = TL->isComplexDeinterleavingOperationSupported(
      ComplexDeinterleavingOperation::CMulPartial, InterleavedTy);
  const bool TargetHasCAdd = TL->isComplexDeinterleavingOperationSupported(
      ComplexDeinterleavingOperation::CAdd, InterleavedTy);

  if (TargetHasCMul && isInstructionPairMul(Real, Imag)) {
    Found = identifyPartialMul(Real, Imag);
    if (Found)
      return Found;
  }

  if (TargetHasCAdd && isInstructionPairAdd(Real, Imag)) {
    Found = identifyAdd(Real, Imag);
    if (Found)
      return Found;
  }

  if (TargetHasCMul && TargetHasCAdd) {
    Found = identifyReassocNodes(Real, Imag);
    if (Found)
      return Found;
  }

  Found = identifySymmetricOperation(Real, Imag);
  if (Found)
    return Found;

  // Remember the failure for this pair.
  LLVM_DEBUG(dbgs() << "  - Not recognised as a valid pattern.\n");
  CachedResult[{R, I}] = nullptr;
  return nullptr;
}

// Tries to recognise Real/Imag as the roots of two reassociable expression
// trees made of additions, subtractions, negations and multiplications.
//
// Each tree is flattened into a list of signed products and a list of signed
// addends. The products are then matched into partial complex multiplications
// and the addends into complex additions. On success the final node in the
// resulting chain is tagged with Real/Imag, submitted and returned; on any
// failure nullptr is returned.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
                                                 Instruction *Imag) {
  // Only these opcodes are accepted as the root of a reassociable expression.
  auto IsOperationSupported = [](unsigned Opcode) -> bool {
    return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
           Opcode == Instruction::FNeg || Opcode == Instruction::Add ||
           Opcode == Instruction::Sub;
  };

  if (!IsOperationSupported(Real->getOpcode()) ||
      !IsOperationSupported(Imag->getOpcode()))
    return nullptr;

  // Flags stays empty when Real is not a floating-point operation. For
  // floating-point roots, both roots must agree on their fast-math flags and
  // those flags must allow reassociation.
  std::optional<FastMathFlags> Flags;
  if (isa<FPMathOperator>(Real)) {
    if (Real->getFastMathFlags() != Imag->getFastMathFlags()) {
      LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are "
                           "not identical\n");
      return nullptr;
    }

    Flags = Real->getFastMathFlags();
    if (!Flags->allowReassoc()) {
      LLVM_DEBUG(
          dbgs()
          << "the 'Reassoc' attribute is missing in the FastMath flags\n");
      return nullptr;
    }
  }

  // Collect multiplications and addend instructions from the given instruction
  // while traversing its operands. Additionally, verify that all instructions
  // have the same fast math flags.
  //
  // Each worklist entry pairs a value with the sign it contributes to the
  // overall sum (true means positive). Returns false only when a traversed
  // instruction has fast-math flags that differ from the root flags.
  auto Collect = [&Flags](Instruction *Insn, std::vector<Product> &Muls,
                          std::list<Addend> &Addends) -> bool {
    SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}};
    SmallPtrSet<Value *, 8> Visited;
    while (!Worklist.empty()) {
      auto [V, IsPositive] = Worklist.back();
      Worklist.pop_back();
      // Each value is expanded at most once.
      if (!Visited.insert(V).second)
        continue;

      // Non-instruction values cannot be expanded further; treat them as
      // addends.
      Instruction *I = dyn_cast<Instruction>(V);
      if (!I) {
        Addends.emplace_back(V, IsPositive);
        continue;
      }

      // If an instruction has more than one user, it indicates that it either
      // has an external user, which will be later checked by the checkNodes
      // function, or it is a subexpression utilized by multiple expressions. In
      // the latter case, we will attempt to separately identify the complex
      // operation from here in order to create a shared
      // ComplexDeinterleavingCompositeNode.
      if (I != Insn && I->getNumUses() > 1) {
        LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n");
        Addends.emplace_back(I, IsPositive);
        continue;
      }
      switch (I->getOpcode()) {
      // Addition: both operands inherit the current sign.
      case Instruction::FAdd:
      case Instruction::Add:
        Worklist.emplace_back(I->getOperand(1), IsPositive);
        Worklist.emplace_back(I->getOperand(0), IsPositive);
        break;
      // Subtraction: the second operand gets the opposite sign.
      case Instruction::FSub:
        Worklist.emplace_back(I->getOperand(1), !IsPositive);
        Worklist.emplace_back(I->getOperand(0), IsPositive);
        break;
      // An integer Sub recognised by isNeg is handled as a negation of its
      // operand; any other Sub is handled like FSub.
      case Instruction::Sub:
        if (isNeg(I)) {
          Worklist.emplace_back(getNegOperand(I), !IsPositive);
        } else {
          Worklist.emplace_back(I->getOperand(1), !IsPositive);
          Worklist.emplace_back(I->getOperand(0), IsPositive);
        }
        break;
      // Multiplication: strip a negation from either operand, folding it into
      // the sign of the product, and record the product as a leaf.
      case Instruction::FMul:
      case Instruction::Mul: {
        Value *A, *B;
        if (isNeg(I->getOperand(0))) {
          A = getNegOperand(I->getOperand(0));
          IsPositive = !IsPositive;
        } else {
          A = I->getOperand(0);
        }

        if (isNeg(I->getOperand(1))) {
          B = getNegOperand(I->getOperand(1));
          IsPositive = !IsPositive;
        } else {
          B = I->getOperand(1);
        }
        Muls.push_back(Product{A, B, IsPositive});
        break;
      }
      // Negation: the operand gets the opposite sign.
      case Instruction::FNeg:
        Worklist.emplace_back(I->getOperand(0), !IsPositive);
        break;
      // Any other instruction is an addend. The 'continue' also skips the
      // fast-math flag check below.
      default:
        Addends.emplace_back(I, IsPositive);
        continue;
      }

      // Only instructions that were expanded above reach this check.
      if (Flags && I->getFastMathFlags() != *Flags) {
        LLVM_DEBUG(dbgs() << "The instruction's fast math flags are "
                             "inconsistent with the root instructions' flags: "
                          << *I << "\n");
        return false;
      }
    }
    return true;
  };

  std::vector<Product> RealMuls, ImagMuls;
  std::list<Addend> RealAddends, ImagAddends;
  if (!Collect(Real, RealMuls, RealAddends) ||
      !Collect(Imag, ImagMuls, ImagAddends))
    return nullptr;

  // Addends are matched one real to one imaginary, so the counts must agree.
  if (RealAddends.size() != ImagAddends.size())
    return nullptr;

  NodePtr FinalNode;
  if (!RealMuls.empty() || !ImagMuls.empty()) {
    // If there are multiplicands, extract positive addend and use it as an
    // accumulator. The accumulator may be null when no such addend exists.
    FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
    FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);
    if (!FinalNode)
      return nullptr;
  }

  // Identify and process remaining additions
  if (!RealAddends.empty() || !ImagAddends.empty()) {
    FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);
    if (!FinalNode)
      return nullptr;
  }
  assert(FinalNode && "FinalNode can not be nullptr here");
  // Set the Real and Imag fields of the final node and submit it
  FinalNode->Real = Real;
  FinalNode->Imag = Imag;
  submitCompositeNode(FinalNode);
  return FinalNode;
}

bool ComplexDeinterleavingGraph::collectPartialMuls(
    const std::vector<Product> &RealMuls, const std::vector<Product> &ImagMuls,
    std::vector<PartialMulCandidate> &PartialMulCandidates) {
  // Helper function to extract a common operand from two products
  auto FindCommonInstruction = [](const Product &Real,
                                  const Product &Imag) -> Value * {
    if (Real.Multiplicand == Imag.Multiplicand ||
        Real.Multiplicand == Imag.Multiplier)
      return Real.Multiplicand;

    if (Real.Multiplier == Imag.Multiplicand ||
        Real.Multiplier == Imag.Multiplier)
      return Real.Multiplier;

    return nullptr;
  };

  // Iterating over real and imaginary multiplications to find common operands
  // If a common operand is found, a partial multiplication candidate is created
  // and added to the candidates vector The function returns false if no common
  // operands are found for any product
  for (unsigned i = 0; i < RealMuls.size(); ++i) {
    bool FoundCommon = false;
    for (unsigned j = 0; j < ImagMuls.size(); ++j) {
      auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
      if (!Common)
        continue;

      auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
                                                   : RealMuls[i].Multiplicand;
      auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier
                                                   : ImagMuls[j].Multiplicand;

      auto Node = identifyNode(A, B);
      if (Node) {
        FoundCommon = true;
        PartialMulCandidates.push_back({Common, Node, i, j, false});
      }

      Node = identifyNode(B, A);
      if (Node) {
        FoundCommon = true;
        PartialMulCandidates.push_back({Common, Node, i, j, true});
      }
    }
    if (!FoundCommon)
      return false;
  }
  return true;
}

/// Combines the products collected from the real and imaginary expressions
/// into a chain of CMulPartial nodes.
///
/// \param RealMuls    Products found in the real expression.
/// \param ImagMuls    Products found in the imaginary expression.
/// \param Accumulator Optional node used as the accumulator operand of the
///                    first multiplication in the chain.
/// \returns The last node of the chain (or \p Accumulator unchanged when both
///          product lists are empty), or nullptr when the products cannot all
///          be expressed as partial multiplications.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyMultiplications(
    std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls,
    NodePtr Accumulator = nullptr) {
  // Each real product has to be paired with exactly one imaginary product.
  if (RealMuls.size() != ImagMuls.size())
    return nullptr;

  std::vector<PartialMulCandidate> Info;
  if (!collectPartialMuls(RealMuls, ImagMuls, Info))
    return nullptr;

  // Map to store common instruction to node pointers.
  // Candidates are paired up so that their common operands form a complex
  // node, trying both (real, imag) orders.
  std::map<Value *, NodePtr> CommonToNode;
  std::vector<bool> Processed(Info.size(), false);
  for (unsigned I = 0; I < Info.size(); ++I) {
    if (Processed[I])
      continue;

    PartialMulCandidate &InfoA = Info[I];
    for (unsigned J = I + 1; J < Info.size(); ++J) {
      if (Processed[J])
        continue;

      PartialMulCandidate &InfoB = Info[J];
      auto *InfoReal = &InfoA;
      auto *InfoImag = &InfoB;

      auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
      if (!NodeFromCommon) {
        std::swap(InfoReal, InfoImag);
        NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
      }
      if (!NodeFromCommon)
        continue;

      CommonToNode[InfoReal->Common] = NodeFromCommon;
      CommonToNode[InfoImag->Common] = NodeFromCommon;
      Processed[I] = true;
      Processed[J] = true;
    }
  }

  // Consume the candidates, building one CMulPartial node for each
  // (real product, imaginary product) pair that has a valid rotation.
  std::vector<bool> ProcessedReal(RealMuls.size(), false);
  std::vector<bool> ProcessedImag(ImagMuls.size(), false);
  NodePtr Result = Accumulator;
  for (auto &PMI : Info) {
    if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
      continue;

    auto It = CommonToNode.find(PMI.Common);
    // TODO: Process independent complex multiplications. Cases like this:
    //  A.real() * B where both A and B are complex numbers.
    if (It == CommonToNode.end()) {
      LLVM_DEBUG({
        dbgs() << "Unprocessed independent partial multiplication:\n";
        // Dump both halves of the candidate: the real and the imaginary
        // product.
        for (auto *Mul : {&RealMuls[PMI.RealIdx], &ImagMuls[PMI.ImagIdx]})
          dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier
                           << " multiplied by " << *Mul->Multiplicand << "\n";
      });
      return nullptr;
    }

    auto &RealMul = RealMuls[PMI.RealIdx];
    auto &ImagMul = ImagMuls[PMI.ImagIdx];

    auto NodeA = It->second;
    auto NodeB = PMI.Node;
    auto IsMultiplicandReal = PMI.Common == NodeA->Real;
    // The following table illustrates the relationship between multiplications
    // and rotations. If we consider the multiplication (X + iY) * (U + iV), we
    // can see:
    //
    // Rotation |   Real |   Imag |
    // ---------+--------+--------+
    //        0 |  x * u |  x * v |
    //       90 | -y * v |  y * u |
    //      180 | -x * u | -x * v |
    //      270 |  y * v | -y * u |
    //
    // Check if the candidate can indeed be represented by partial
    // multiplication
    // TODO: Add support for multiplication by complex one
    if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
        (!IsMultiplicandReal && !PMI.IsNodeInverted))
      continue;

    // Determine the rotation based on the multiplications
    ComplexDeinterleavingRotation Rotation;
    if (IsMultiplicandReal) {
      // Detect 0 and 180 degrees rotation
      if (RealMul.IsPositive && ImagMul.IsPositive)
        Rotation = llvm::ComplexDeinterleavingRotation::Rotation_0;
      else if (!RealMul.IsPositive && !ImagMul.IsPositive)
        Rotation = llvm::ComplexDeinterleavingRotation::Rotation_180;
      else
        continue;

    } else {
      // Detect 90 and 270 degrees rotation
      if (!RealMul.IsPositive && ImagMul.IsPositive)
        Rotation = llvm::ComplexDeinterleavingRotation::Rotation_90;
      else if (RealMul.IsPositive && !ImagMul.IsPositive)
        Rotation = llvm::ComplexDeinterleavingRotation::Rotation_270;
      else
        continue;
    }

    LLVM_DEBUG({
      dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n";
      dbgs().indent(4) << "X: " << *NodeA->Real << "\n";
      dbgs().indent(4) << "Y: " << *NodeA->Imag << "\n";
      dbgs().indent(4) << "U: " << *NodeB->Real << "\n";
      dbgs().indent(4) << "V: " << *NodeB->Imag << "\n";
      dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
    });

    // Chain the new multiplication onto the previous result (if any), which
    // becomes its third operand.
    NodePtr NodeMul = prepareCompositeNode(
        ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr);
    NodeMul->Rotation = Rotation;
    NodeMul->addOperand(NodeA);
    NodeMul->addOperand(NodeB);
    if (Result)
      NodeMul->addOperand(Result);
    submitCompositeNode(NodeMul);
    Result = NodeMul;
    ProcessedReal[PMI.RealIdx] = true;
    ProcessedImag[PMI.ImagIdx] = true;
  }

  // Ensure all products have been processed, if not return nullptr.
  if (!all_of(ProcessedReal, [](bool V) { return V; }) ||
      !all_of(ProcessedImag, [](bool V) { return V; })) {

    // Dump debug information about which partial multiplications are not
    // processed.
    LLVM_DEBUG({
      dbgs() << "Unprocessed products (Real):\n";
      for (size_t i = 0; i < ProcessedReal.size(); ++i) {
        if (!ProcessedReal[i])
          dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-")
                           << *RealMuls[i].Multiplier << " multiplied by "
                           << *RealMuls[i].Multiplicand << "\n";
      }
      dbgs() << "Unprocessed products (Imag):\n";
      for (size_t i = 0; i < ProcessedImag.size(); ++i) {
        if (!ProcessedImag[i])
          dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-")
                           << *ImagMuls[i].Multiplier << " multiplied by "
                           << *ImagMuls[i].Multiplicand << "\n";
      }
    });
    return nullptr;
  }

  return Result;
}

/// Folds the remaining real and imaginary addends into a chain of nodes.
///
/// The chain starts from \p Accumulator when one is given, otherwise from a
/// pair of positive addends extracted with extractPositiveAddend. The first
/// real addend is then repeatedly matched against the imaginary addends; the
/// signs of the two addends select the rotation. Rotations 0 and 180 create a
/// Symmetric add/sub node (floating-point when \p Flags is set, integer
/// otherwise), rotations 90 and 270 create a CAdd node. Matched addends are
/// removed from both lists.
///
/// \returns The last node of the chain, or nullptr when the lists differ in
///          size, no starting node is available, or some real addend has no
///          matching imaginary addend.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyAdditions(
    std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends,
    std::optional<FastMathFlags> Flags, NodePtr Accumulator = nullptr) {
  if (RealAddends.size() != ImagAddends.size())
    return nullptr;

  // Use the accumulator as the first term when available, otherwise look for
  // an element with both positive real and imaginary parts.
  NodePtr Sum = Accumulator ? Accumulator
                            : extractPositiveAddend(RealAddends, ImagAddends);
  if (!Sum)
    return nullptr;

  using Rot = ComplexDeinterleavingRotation;

  // Maps the signs of the real and imaginary addends to a rotation.
  auto RotationFor = [](bool RealPositive, bool ImagPositive) -> Rot {
    if (RealPositive)
      return ImagPositive ? Rot::Rotation_0 : Rot::Rotation_270;
    return ImagPositive ? Rot::Rotation_90 : Rot::Rotation_180;
  };

  // Creates the (operand-less) node that will combine the running sum with
  // the next addend for the given rotation.
  auto MakeCombiner = [&](Rot Rotation) -> NodePtr {
    if (Rotation == Rot::Rotation_90 || Rotation == Rot::Rotation_270) {
      NodePtr CAddNode = prepareCompositeNode(
          ComplexDeinterleavingOperation::CAdd, nullptr, nullptr);
      CAddNode->Rotation = Rotation;
      return CAddNode;
    }

    // Rotation 0 is a plain addition, rotation 180 a plain subtraction.
    NodePtr SymNode = prepareCompositeNode(
        ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
    const bool IsSubtraction = Rotation == Rot::Rotation_180;
    if (Flags) {
      SymNode->Opcode = IsSubtraction ? Instruction::FSub : Instruction::FAdd;
      SymNode->Flags = *Flags;
    } else {
      SymNode->Opcode = IsSubtraction ? Instruction::Sub : Instruction::Add;
    }
    return SymNode;
  };

  while (!RealAddends.empty()) {
    auto RealIt = RealAddends.begin();
    auto [RealV, RealPositive] = *RealIt;

    auto ImagIt = ImagAddends.begin();
    for (; ImagIt != ImagAddends.end(); ++ImagIt) {
      auto [ImagV, ImagPositive] = *ImagIt;
      const Rot Rotation = RotationFor(RealPositive, ImagPositive);

      // Equal signs (rotation 0 or 180) keep the (real, imag) order; mixed
      // signs (rotation 90 or 270) swap the two parts.
      const bool KeepOrder = RealPositive == ImagPositive;
      NodePtr Term = KeepOrder ? identifyNode(RealV, ImagV)
                               : identifyNode(ImagV, RealV);
      if (!Term)
        continue;

      LLVM_DEBUG({
        dbgs() << "Identified addition:\n";
        dbgs().indent(4) << "X: " << *RealV << "\n";
        dbgs().indent(4) << "Y: " << *ImagV << "\n";
        dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
      });

      NodePtr Combined = MakeCombiner(Rotation);
      Combined->addOperand(Sum);
      Combined->addOperand(Term);
      submitCompositeNode(Combined);
      Sum = Combined;
      break;
    }

    // No imaginary addend could be paired with this real addend.
    if (ImagIt == ImagAddends.end())
      return nullptr;

    RealAddends.erase(RealIt);
    ImagAddends.erase(ImagIt);
  }
  return Sum;
}

/// Searches for a real addend and an imaginary addend that are both positive
/// and that identifyNode recognises as a complex node. The first such pair is
/// removed from both lists and its node returned; nullptr is returned when no
/// pair qualifies (the lists are then left untouched).
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::extractPositiveAddend(
    std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) {
  for (auto RealIt = RealAddends.begin(); RealIt != RealAddends.end();
       ++RealIt) {
    auto [RealV, RealPositive] = *RealIt;
    // A negative real addend can never be part of a positive pair.
    if (!RealPositive)
      continue;

    for (auto ImagIt = ImagAddends.begin(); ImagIt != ImagAddends.end();
         ++ImagIt) {
      auto [ImagV, ImagPositive] = *ImagIt;
      if (!ImagPositive)
        continue;

      NodePtr Candidate = identifyNode(RealV, ImagV);
      if (!Candidate)
        continue;

      RealAddends.erase(RealIt);
      ImagAddends.erase(ImagIt);
      return Candidate;
    }
  }
  return nullptr;
}

/// Tries to register \p RootI as a root of the deinterleaving graph.
///
/// When \p RootI has not been seen before, identifyRoot is asked to build its
/// graph; on success the root is recorded in RootToNode and OrderedRoots.
/// When \p RootI is already registered (as part of a reduction), it is added
/// to OrderedRoots only if it is the later of the node's Real and Imag
/// instructions.
///
/// \returns true when \p RootI was appended to OrderedRoots.
bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
  auto Known = RootToNode.find(RootI);
  if (Known == RootToNode.end()) {
    // Not seen yet: try to build a fresh graph rooted at this instruction.
    NodePtr NewRoot = identifyRoot(RootI);
    if (!NewRoot)
      return false;

    LLVM_DEBUG({
      Function *F = RootI->getFunction();
      BasicBlock *B = RootI->getParent();
      dbgs() << "Complex deinterleaving graph for " << F->getName()
             << "::" << B->getName() << ".\n";
      dump(dbgs());
      dbgs() << "\n";
    });
    RootToNode[RootI] = NewRoot;
    OrderedRoots.push_back(RootI);
    return true;
  }

  // This potential root instruction has already been recognized as a
  // reduction. RootToNode maps both the Real and the Imaginary part to the
  // same CompositeNode, so only one of them may serve as the anchor for
  // generating the complex instruction.
  NodePtr ReductionRoot = Known->second;
  assert(ReductionRoot->Operation ==
         ComplexDeinterleavingOperation::ReductionOperation);

  // The anchor is whichever part comes later; the other part is skipped.
  auto *RealPart = cast<Instruction>(ReductionRoot->Real);
  auto *ImagPart = cast<Instruction>(ReductionRoot->Imag);
  Instruction *Anchor = RealPart->comesBefore(ImagPart) ? ImagPart : RealPart;
  if (Anchor != RootI)
    return false;

  OrderedRoots.push_back(RootI);
  return true;
}

/// Scans \p B for PHI nodes that may carry a vector reduction across a
/// single-block loop and records them in ReductionInfo.
///
/// \p B qualifies only when its terminator is a two-successor branch with
/// \p B itself as one of the successors. A PHI is recorded when it has two
/// incoming values, is of vector type, and the value coming in from \p B is
/// an instruction with exactly two users: the PHI itself and a non-PHI
/// instruction located outside of \p B.
///
/// For every recorded PHI the BackEdge and Incoming members are updated, and
/// an initial PHI value that is an instruction is added to FinalInstructions.
///
/// \returns true when at least one potential reduction was recorded.
bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) {
  bool FoundPotentialReduction = false;

  auto *Br = dyn_cast<BranchInst>(B->getTerminator());
  if (!Br || Br->getNumSuccessors() != 2)
    return false;

  // Identify simple one-block loop
  if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B)
    return false;

  for (auto &PHI : B->phis()) {
    if (PHI.getNumIncomingValues() != 2)
      continue;

    if (!PHI.getType()->isVectorTy())
      continue;

    // The value flowing around the back edge is the reduction operation.
    auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B));
    if (!ReductionOp)
      continue;

    // Check if final instruction is reduced outside of current block
    Instruction *FinalReduction = nullptr;
    auto NumUsers = 0u;
    for (auto *U : ReductionOp->users()) {
      ++NumUsers;
      if (U == &PHI)
        continue;
      FinalReduction = dyn_cast<Instruction>(U);
    }

    // Exactly two users are expected: the PHI and one non-PHI instruction
    // outside of the loop block.
    if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B ||
        isa<PHINode>(FinalReduction))
      continue;

    ReductionInfo[ReductionOp] = {&PHI, FinalReduction};
    BackEdge = B;
    // The PHI has exactly two incoming blocks, so the incoming block that is
    // not the back edge is at the other index.
    auto BackEdgeIdx = PHI.getBasicBlockIndex(B);
    auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0;
    Incoming = PHI.getIncomingBlock(IncomingIdx);
    FoundPotentialReduction = true;

    // If the initial value of PHINode is an Instruction, consider it a leaf
    // value of a complex deinterleaving graph.
    if (auto *InitPHI =
            dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming)))
      FinalInstructions.insert(InitPHI);
  }
  return FoundPotentialReduction;
}

void ComplexDeinterleavingGraph::identifyReductionNodes() {
  SmallVector<bool> Processed(ReductionInfo.size(), false);
  SmallVector<Instruction *> OperationInstruction;
  for (auto &P : ReductionInfo)
    OperationInstruction.push_back(P.first);

  // Identify a complex computation by evaluating two reduction operations that
  // potentially could be involved
  for (size_t i = 0; i < OperationInstruction.size(); ++i) {
    if (Processed[i])
      continue;
    for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {
      if (Processed[j])
        continue;

      auto *Real = OperationInstruction[i];
      auto *Imag = OperationInstruction[j];
      if (Real->getType() != Imag->getType())
        continue;

      RealPHI = ReductionInfo[Real].first;
      ImagPHI = ReductionInfo[Imag].first;
      PHIsFound = false;
      auto Node = identifyNode(Real, Imag);
      if (!Node) {
        std::swap(Real, Imag);
        std::swap(RealPHI, ImagPHI);
        Node = identifyNode(Real, Imag);
      }

      // If a node is identified and reduction PHINode is used in the chain of
      // operations, mark its operation instructions as used to prevent
      // re-identification and attach the node to the real part
      if (Node && PHIsFound) {
        LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: "
                          << *Real << " / " << *Imag << "\n");
        Processed[i] = true;
        Processed[j] = true;
        auto RootNode = prepareCompositeNode(
            ComplexDeinterleavingOperation::ReductionOperation, Real, Imag);
        RootNode->addOperand(Node);
        RootToNode[Real] = RootNode;
        RootToNode[Imag] = RootNode;
        submitCompositeNode(RootNode);
        break;
      }
    }
  }

  RealPHI = nullptr;
  ImagPHI = nullptr;
}

/// Drops every root whose chain of operations has a user outside of the
/// identified chains.
///
/// The function first collects all instructions reachable from the roots
/// through operands (not looking through FinalInstructions), then finds the
/// non-root instructions among them that have a user outside of that set,
/// and finally erases from RootToNode every root connected to such an
/// instruction through users or operands.
///
/// \returns true when at least one root remains in RootToNode.
bool ComplexDeinterleavingGraph::checkNodes() {
  // Collect all instructions from roots to leaves
  SmallPtrSet<Instruction *, 16> AllInstructions;
  SmallVector<Instruction *, 8> Worklist;
  for (auto &Pair : RootToNode)
    Worklist.push_back(Pair.first);

  // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
  // chains
  while (!Worklist.empty()) {
    auto *I = Worklist.back();
    Worklist.pop_back();

    if (!AllInstructions.insert(I).second)
      continue;

    // Final instructions are leaves of the graph: their operands are not
    // part of the chain.
    if (FinalInstructions.count(I))
      continue;

    for (Value *Op : I->operands()) {
      if (auto *OpI = dyn_cast<Instruction>(Op))
        Worklist.emplace_back(OpI);
    }
  }

  // Find instructions that have users outside of chain
  for (auto *I : AllInstructions) {
    // Skip root nodes
    if (RootToNode.count(I))
      continue;

    for (User *U : I->users()) {
      if (AllInstructions.count(cast<Instruction>(U)))
        continue;

      // Found an instruction that has a user outside of the XCMLA/XCADD
      // chain
      Worklist.emplace_back(I);
      break;
    }
  }

  // If any instructions are found to be used outside, find and remove roots
  // that somehow connect to those instructions.
  SmallPtrSet<Instruction *, 16> Visited;
  while (!Worklist.empty()) {
    auto *I = Worklist.back();
    Worklist.pop_back();
    if (!Visited.insert(I).second)
      continue;

    // Found an impacted root node. Removing it from the nodes to be
    // deinterleaved
    if (RootToNode.count(I)) {
      LLVM_DEBUG(dbgs() << "Instruction " << *I
                        << " could be deinterleaved but its chain of complex "
                           "operations have an outside user\n");
      RootToNode.erase(I);
    }

    // Do not walk past the boundaries of the collected chains.
    if (!AllInstructions.count(I) || FinalInstructions.count(I))
      continue;

    for (User *U : I->users())
      Worklist.emplace_back(cast<Instruction>(U));

    for (Value *Op : I->operands()) {
      if (auto *OpI = dyn_cast<Instruction>(Op))
        Worklist.emplace_back(OpI);
    }
  }
  return !RootToNode.empty();
}

/// Tries to build a graph starting from \p RootI, which must recombine a real
/// and an imaginary vector into one interleaved vector.
///
/// Two forms are accepted: a call to the experimental_vector_interleave2
/// intrinsic, or a shufflevector whose mask satisfies isInterleavingMask. In
/// both cases the two inputs must be instructions. Returns the node
/// identified for the (real, imaginary) pair, or nullptr.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
  Instruction *RealPart = nullptr;
  Instruction *ImagPart = nullptr;

  if (auto *Call = dyn_cast<IntrinsicInst>(RootI)) {
    // Intrinsic form: interleave2(real, imag). Any other intrinsic is
    // rejected.
    if (Call->getIntrinsicID() != Intrinsic::experimental_vector_interleave2)
      return nullptr;

    RealPart = dyn_cast<Instruction>(Call->getOperand(0));
    ImagPart = dyn_cast<Instruction>(Call->getOperand(1));
  } else {
    // Shuffle form: look for a shufflevector that takes separate vectors of
    // the real and imaginary components and recombines them into a single
    // vector.
    auto *Shuffle = dyn_cast<ShuffleVectorInst>(RootI);
    if (!Shuffle)
      return nullptr;

    if (!isInterleavingMask(Shuffle->getShuffleMask()))
      return nullptr;

    RealPart = dyn_cast<Instruction>(Shuffle->getOperand(0));
    ImagPart = dyn_cast<Instruction>(Shuffle->getOperand(1));
  }

  if (!RealPart || !ImagPart)
    return nullptr;

  return identifyNode(RealPart, ImagPart);
}

/// Tries to recognise \p Real and \p Imag as the two results of
/// deinterleaving a single vector.
///
/// Two forms are accepted:
///  - extractvalue 0 / extractvalue 1 of the same
///    experimental_vector_deinterleave2 call;
///  - two shufflevector instructions reading the same first operand, with an
///    undef or zero second operand, deinterleaving masks starting at element
///    0 (real) and 1 (imaginary), and a result that has the scalar type of
///    the source and half its element count.
///
/// On success a Deinterleave node is submitted whose ReplacementNode is the
/// original interleaved value, and both instructions are added to
/// FinalInstructions. Returns nullptr otherwise.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real,
                                                 Instruction *Imag) {
  // Intrinsic form.
  Instruction *DeinterleaveCall = nullptr;
  Value *Interleaved = nullptr;
  if (match(Real, m_ExtractValue<0>(m_Instruction(DeinterleaveCall))) &&
      match(Imag, m_ExtractValue<1>(m_Specific(DeinterleaveCall))) &&
      match(DeinterleaveCall,
            m_Intrinsic<Intrinsic::experimental_vector_deinterleave2>(
                m_Value(Interleaved)))) {
    NodePtr Leaf = prepareCompositeNode(
        llvm::ComplexDeinterleavingOperation::Deinterleave, Real, Imag);
    Leaf->ReplacementNode = Interleaved;
    FinalInstructions.insert(Real);
    FinalInstructions.insert(Imag);
    return submitCompositeNode(Leaf);
  }

  // Shuffle form: both parts have to be shuffles.
  auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
  auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
  if (!(RealShuffle && ImagShuffle)) {
    if (RealShuffle || ImagShuffle) {
      LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n");
    }
    return nullptr;
  }

  // The second shuffle operand must not contribute any data.
  auto IsUndefOrZero = [](Value *V) -> bool {
    return isa<UndefValue>(V) || isa<ConstantAggregateZero>(V);
  };

  if (!IsUndefOrZero(RealShuffle->getOperand(1))) {
    LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
    return nullptr;
  }
  if (!IsUndefOrZero(ImagShuffle->getOperand(1))) {
    LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
    return nullptr;
  }

  // Both shuffles have to read from the same source vector.
  Value *Source = RealShuffle->getOperand(0);
  if (Source != ImagShuffle->getOperand(0)) {
    LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
    return nullptr;
  }

  ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
  ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
  if (!(isDeinterleavingMask(RealMask) && isDeinterleavingMask(ImagMask))) {
    LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
    return nullptr;
  }

  // The real part starts at element 0, the imaginary part at element 1.
  if (RealMask[0] != 0 || ImagMask[0] != 1) {
    LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
    return nullptr;
  }

  // Validates one deinterleaving shuffle: the result must be a vector of the
  // same scalar type as the source but with half the elements, and the last
  // mask element must index into the first shuffle operand.
  auto IsValidDeinterleavingShuffle = [](ShuffleVectorInst *Shuffle) -> bool {
    auto *ResultTy = cast<FixedVectorType>(Shuffle->getType());
    auto *SourceTy = cast<FixedVectorType>(Shuffle->getOperand(0)->getType());

    if (SourceTy->getScalarType() != ResultTy->getScalarType())
      return false;
    if ((ResultTy->getNumElements() * 2) != SourceTy->getNumElements())
      return false;

    int LastIndex = Shuffle->getShuffleMask().back();
    int SourceElements = SourceTy->getNumElements();
    return LastIndex < SourceElements;
  };

  if (RealShuffle->getType() != ImagShuffle->getType()) {
    LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
    return nullptr;
  }
  if (!IsValidDeinterleavingShuffle(RealShuffle)) {
    LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
    return nullptr;
  }
  if (!IsValidDeinterleavingShuffle(ImagShuffle)) {
    LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
    return nullptr;
  }

  NodePtr Leaf =
      prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Deinterleave,
                           RealShuffle, ImagShuffle);
  Leaf->ReplacementNode = Source;
  FinalInstructions.insert(RealShuffle);
  FinalInstructions.insert(ImagShuffle);
  return submitCompositeNode(Leaf);
}

/// Tries to recognise \p R and \p I as a pair of splat values.
///
/// A value is accepted when it is a ConstantDataVector, or a shufflevector
/// (instruction or constant expression) whose mask elements are all equal to
/// 0, excluding fixed-width single-element vectors. Either both values are
/// instructions, in which case they must live in the same basic block and are
/// added to FinalInstructions, or neither is. Returns the submitted Splat
/// node, or nullptr.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifySplat(Value *R, Value *I) {
  auto IsSplat = [](Value *V) -> bool {
    // Fixed-width vector with constants
    if (isa<ConstantDataVector>(V))
      return true;

    // Splats are represented differently depending on whether the repeated
    // value is a constant or an Instruction
    VectorType *SplatTy = nullptr;
    ArrayRef<int> SplatMask;
    if (auto *Shuffle = dyn_cast<ShuffleVectorInst>(V)) {
      SplatTy = Shuffle->getType();
      SplatMask = Shuffle->getShuffleMask();
    } else if (auto *CE = dyn_cast<ConstantExpr>(V)) {
      if (CE->getOpcode() != Instruction::ShuffleVector)
        return false;
      SplatTy = cast<VectorType>(CE->getType());
      SplatMask = CE->getShuffleMask();
    } else {
      return false;
    }

    // When the data type is <1 x Type>, it's not possible to differentiate
    // between the ComplexDeinterleaving::Deinterleave and
    // ComplexDeinterleaving::Splat operations.
    const bool IsFixedSingleElement =
        !SplatTy->isScalableTy() &&
        SplatTy->getElementCount().getKnownMinValue() == 1;
    if (IsFixedSingleElement)
      return false;

    return all_equal(SplatMask) && SplatMask[0] == 0;
  };

  if (!(IsSplat(R) && IsSplat(I)))
    return nullptr;

  // Either both parts are instructions or neither is.
  auto *RealInst = dyn_cast<Instruction>(R);
  auto *ImagInst = dyn_cast<Instruction>(I);
  if ((RealInst == nullptr) != (ImagInst == nullptr))
    return nullptr;

  if (RealInst) {
    // Non-constant splats should be in the same basic block
    if (RealInst->getParent() != ImagInst->getParent())
      return nullptr;

    FinalInstructions.insert(RealInst);
    FinalInstructions.insert(ImagInst);
  }

  NodePtr SplatNode =
      prepareCompositeNode(ComplexDeinterleavingOperation::Splat, R, I);
  return submitCompositeNode(SplatNode);
}

/// Recognises the pair (\p Real, \p Imag) as the reduction PHIs currently
/// stored in RealPHI and ImagPHI. On a match PHIsFound is set and a
/// ReductionPHI node is submitted; otherwise nullptr is returned.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
                                            Instruction *Imag) {
  const bool IsReductionPHIPair = Real == RealPHI && Imag == ImagPHI;
  if (!IsReductionPHIPair)
    return nullptr;

  // Record that the reduction PHIs take part in the chain being identified.
  PHIsFound = true;
  NodePtr PHINodePair = prepareCompositeNode(
      ComplexDeinterleavingOperation::ReductionPHI, Real, Imag);
  return submitCompositeNode(PHINodePair);
}

/// Tries to recognise \p Real and \p Imag as a pair of vector selects that
/// choose between two complex values under the same mask.
///
/// Both instructions must be selects whose condition and both value operands
/// are instructions. The two conditions must be the same instruction or
/// identical to each other, and of vector type. The true operands and the
/// false operands must each be identified as a complex node. On success a
/// ReductionSelect node with those two operands is submitted and both mask
/// instructions are added to FinalInstructions.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
                                               Instruction *Imag) {
  if (!isa<SelectInst>(Real) || !isa<SelectInst>(Imag))
    return nullptr;

  Instruction *RealMask, *ImagMask;
  Instruction *TrueReal, *TrueImag, *FalseReal, *FalseImag;
  if (!match(Real, m_Select(m_Instruction(RealMask), m_Instruction(TrueReal),
                            m_Instruction(FalseReal))) ||
      !match(Imag, m_Select(m_Instruction(ImagMask), m_Instruction(TrueImag),
                            m_Instruction(FalseImag))))
    return nullptr;

  // Both selects have to be controlled by an equivalent mask.
  const bool MasksAgree =
      RealMask == ImagMask || RealMask->isIdenticalTo(ImagMask);
  if (!MasksAgree)
    return nullptr;

  if (!RealMask->getType()->isVectorTy())
    return nullptr;

  // Value selected when the mask is true.
  NodePtr TrueNode = identifyNode(TrueReal, TrueImag);
  if (!TrueNode)
    return nullptr;

  // Value selected when the mask is false.
  NodePtr FalseNode = identifyNode(FalseReal, FalseImag);
  if (!FalseNode)
    return nullptr;

  NodePtr SelectNode = prepareCompositeNode(
      ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
  SelectNode->addOperand(TrueNode);
  SelectNode->addOperand(FalseNode);
  FinalInstructions.insert(RealMask);
  FinalInstructions.insert(ImagMask);
  return submitCompositeNode(SelectNode);
}

/// Build the single IR operation selected by \p Opcode on \p InputA (and
/// \p InputB for the binary opcodes) using builder \p B.
///
/// Supported opcodes are FNeg, FAdd, FSub, FMul, Add, Sub and Mul; anything
/// else hits llvm_unreachable. FNeg ignores \p InputB. When \p Flags holds a
/// value, it is applied to the created instruction as its fast-math flags.
static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
                                   std::optional<FastMathFlags> Flags,
                                   Value *InputA, Value *InputB) {
  // Dispatch on the opcode; every supported case returns the new value.
  auto EmitOperation = [&]() -> Value * {
    switch (Opcode) {
    case Instruction::FNeg:
      return B.CreateFNeg(InputA);
    case Instruction::FAdd:
      return B.CreateFAdd(InputA, InputB);
    case Instruction::FSub:
      return B.CreateFSub(InputA, InputB);
    case Instruction::FMul:
      return B.CreateFMul(InputA, InputB);
    case Instruction::Add:
      return B.CreateAdd(InputA, InputB);
    case Instruction::Sub:
      return B.CreateSub(InputA, InputB);
    case Instruction::Mul:
      return B.CreateMul(InputA, InputB);
    default:
      llvm_unreachable("Incorrect symmetric opcode");
    }
  };

  Value *NewOp = EmitOperation();
  if (!Flags)
    return NewOp;

  cast<Instruction>(NewOp)->setFastMathFlags(*Flags);
  return NewOp;
}

/// Produce the replacement IR value for \p Node, recursively replacing its
/// operands first, and cache the result in Node->ReplacementNode.
///
/// A node that already carries a ReplacementNode is returned as is, so every
/// node is materialised at most once. \p Builder is used for most of the newly
/// created IR; the non-constant Splat and ReductionPHI cases choose their own
/// insertion points (see below).
///
/// \returns the value that stands for the whole complex node.
Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
                                               RawNodePtr Node) {
  if (Node->ReplacementNode)
    return Node->ReplacementNode;

  // Replace operand number Idx if the node has that many operands; otherwise
  // yield nullptr so that optional inputs can simply be absent.
  auto ReplaceOperandIfExist = [&](RawNodePtr &Node, unsigned Idx) -> Value * {
    return Node->Operands.size() > Idx
               ? replaceNode(Builder, Node->Operands[Idx])
               : nullptr;
  };

  // Start from nullptr so the assertion after the switch is meaningful even
  // if a case fails to produce a value.
  Value *ReplacementNode = nullptr;
  switch (Node->Operation) {
  case ComplexDeinterleavingOperation::CAdd:
  case ComplexDeinterleavingOperation::CMulPartial:
  case ComplexDeinterleavingOperation::Symmetric: {
    Value *Input0 = ReplaceOperandIfExist(Node, 0);
    Value *Input1 = ReplaceOperandIfExist(Node, 1);
    Value *Accumulator = ReplaceOperandIfExist(Node, 2);
    assert(!Input1 || (Input0->getType() == Input1->getType() &&
                       "Node inputs need to be of the same type"));
    assert(!Accumulator ||
           (Input0->getType() == Accumulator->getType() &&
            "Accumulator and input need to be of the same type"));
    // Symmetric nodes are rebuilt as a plain IR operation; the remaining
    // operations are delegated to the target.
    if (Node->Operation == ComplexDeinterleavingOperation::Symmetric)
      ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags,
                                             Input0, Input1);
    else
      ReplacementNode = TL->createComplexDeinterleavingIR(
          Builder, Node->Operation, Node->Rotation, Input0, Input1,
          Accumulator);
    break;
  }
  case ComplexDeinterleavingOperation::Deinterleave:
    llvm_unreachable("Deinterleave node should already have ReplacementNode");
    break;
  case ComplexDeinterleavingOperation::Splat: {
    auto *NewTy = VectorType::getDoubleElementsVectorType(
        cast<VectorType>(Node->Real->getType()));
    auto *R = dyn_cast<Instruction>(Node->Real);
    auto *I = dyn_cast<Instruction>(Node->Imag);
    if (R && I) {
      // Splats that are not constant are interleaved where they are located:
      // right after whichever of the two instructions comes later.
      Instruction *InsertPoint = (I->comesBefore(R) ? R : I)->getNextNode();
      IRBuilder<> IRB(InsertPoint);
      ReplacementNode =
          IRB.CreateIntrinsic(Intrinsic::experimental_vector_interleave2, NewTy,
                              {Node->Real, Node->Imag});
    } else {
      ReplacementNode =
          Builder.CreateIntrinsic(Intrinsic::experimental_vector_interleave2,
                                  NewTy, {Node->Real, Node->Imag});
    }
    break;
  }
  case ComplexDeinterleavingOperation::ReductionPHI: {
    // If Operation is ReductionPHI, a new empty PHINode is created.
    // It is filled later when the ReductionOperation is processed.
    // The real part is required to be a PHI here, so use a checked cast
    // instead of letting a null key slip into OldToNewPHI.
    auto *OldPHI = cast<PHINode>(Node->Real);
    auto *VTy = cast<VectorType>(Node->Real->getType());
    auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
    auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHI());
    OldToNewPHI[OldPHI] = NewPHI;
    ReplacementNode = NewPHI;
    break;
  }
  case ComplexDeinterleavingOperation::ReductionOperation:
    // The reduction is represented by its single operand; the PHI wiring and
    // the final deinterleave are handled by processReductionOperation.
    ReplacementNode = replaceNode(Builder, Node->Operands[0]);
    processReductionOperation(ReplacementNode, Node);
    break;
  case ComplexDeinterleavingOperation::ReductionSelect: {
    // Interleave the masks of the two original selects and select between the
    // replaced operands with the widened mask.
    auto *MaskReal = cast<Instruction>(Node->Real)->getOperand(0);
    auto *MaskImag = cast<Instruction>(Node->Imag)->getOperand(0);
    auto *A = replaceNode(Builder, Node->Operands[0]);
    auto *B = replaceNode(Builder, Node->Operands[1]);
    auto *NewMaskTy = VectorType::getDoubleElementsVectorType(
        cast<VectorType>(MaskReal->getType()));
    auto *NewMask =
        Builder.CreateIntrinsic(Intrinsic::experimental_vector_interleave2,
                                NewMaskTy, {MaskReal, MaskImag});
    ReplacementNode = Builder.CreateSelect(NewMask, A, B);
    break;
  }
  }

  assert(ReplacementNode && "Target failed to create Intrinsic call.");
  NumComplexTransformations += 1;
  Node->ReplacementNode = ReplacementNode;
  return ReplacementNode;
}

/// Finish the rewrite of a reduction whose loop body has already been replaced
/// by \p OperationReplacement.
///
/// For the real/imaginary instructions of \p Node this:
///  1. interleaves the values the old PHIs receive from the Incoming block
///     (emitted before that block's terminator) and adds the result to the new
///     PHI, together with \p OperationReplacement coming from BackEdge;
///  2. emits a deinterleave of \p OperationReplacement at the first insertion
///     point of the block holding the final real reduction, and rewires the
///     two final reduction instructions to use the extracted halves.
void ComplexDeinterleavingGraph::processReductionOperation(
    Value *OperationReplacement, RawNodePtr Node) {
  auto *RealOp = cast<Instruction>(Node->Real);
  auto *ImagOp = cast<Instruction>(Node->Imag);

  // ReductionInfo maps each reduction instruction to a pair: .first is used
  // below as the old PHI, .second as the final reduction user.
  auto *RealOldPHI = ReductionInfo[RealOp].first;
  auto *ImagOldPHI = ReductionInfo[ImagOp].first;
  auto *RealFinalUser = ReductionInfo[RealOp].second;
  auto *ImagFinalUser = ReductionInfo[ImagOp].second;

  // The widened PHI is looked up through the old real PHI.
  auto *WidePHI = OldToNewPHI[RealOldPHI];
  auto *WideTy = VectorType::getDoubleElementsVectorType(
      cast<VectorType>(RealOp->getType()));

  // The initial values flowing in from the Incoming block have to be
  // interleaved before they can feed the widened PHI.
  Value *RealStart = RealOldPHI->getIncomingValueForBlock(Incoming);
  Value *ImagStart = ImagOldPHI->getIncomingValueForBlock(Incoming);

  IRBuilder<> Builder(Incoming->getTerminator());
  auto *WideStart =
      Builder.CreateIntrinsic(Intrinsic::experimental_vector_interleave2,
                              WideTy, {RealStart, ImagStart});

  WidePHI->addIncoming(WideStart, Incoming);
  WidePHI->addIncoming(OperationReplacement, BackEdge);

  // Split the complex vector again in the block of the final real reduction
  // so that each half can be reduced on its own.
  Builder.SetInsertPoint(&*RealFinalUser->getParent()->getFirstInsertionPt());
  auto *Split = Builder.CreateIntrinsic(
      Intrinsic::experimental_vector_deinterleave2,
      OperationReplacement->getType(), OperationReplacement);

  auto *RealHalf = Builder.CreateExtractValue(Split, (uint64_t)0);
  RealFinalUser->replaceUsesOfWith(RealOp, RealHalf);

  // The imaginary half is extracted right before its final reduction.
  Builder.SetInsertPoint(ImagFinalUser);
  auto *ImagHalf = Builder.CreateExtractValue(Split, 1);
  ImagFinalUser->replaceUsesOfWith(ImagOp, ImagHalf);
}

void ComplexDeinterleavingGraph::replaceNodes() {
  SmallVector<Instruction *, 16> DeadInstrRoots;
  for (auto *RootInstruction : OrderedRoots) {
    // Check if this potential root went through check process and we can
    // deinterleave it
    if (!RootToNode.count(RootInstruction))
      continue;

    IRBuilder<> Builder(RootInstruction);
    auto RootNode = RootToNode[RootInstruction];
    Value *R = replaceNode(Builder, RootNode.get());

    if (RootNode->Operation ==
        ComplexDeinterleavingOperation::ReductionOperation) {
      auto *RootReal = cast<Instruction>(RootNode->Real);
      auto *RootImag = cast<Instruction>(RootNode->Imag);
      ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
      ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
      DeadInstrRoots.push_back(cast<Instruction>(RootReal));
      DeadInstrRoots.push_back(cast<Instruction>(RootImag));
    } else {
      assert(R && "Unable to find replacement for RootInstruction");
      DeadInstrRoots.push_back(RootInstruction);
      RootInstruction->replaceAllUsesWith(R);
    }
  }

  for (auto *I : DeadInstrRoots)
    RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
}