//===- ComplexDeinterleavingPass.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Identification:
// This step is responsible for finding the patterns that can be lowered to
// complex instructions, and building a graph to represent the complex
// structures. Starting from the "Converging Shuffle" (a shuffle that
// reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
// operands are evaluated and identified as "Composite Nodes" (collections of
// instructions that can potentially be lowered to a single complex
// instruction). This is performed by checking the real and imaginary components
// and tracking the data flow for each component while following the operand
// pairs. Each node is expected to be validated upon creation, and any
// validation error should halt traversal and prevent further graph
// construction.
// Instead of relying on Shuffle operations, vector interleaving and
// deinterleaving can be represented by vector.interleave2 and
// vector.deinterleave2 intrinsics. Scalable vectors can be represented only by
// these intrinsics, whereas fixed-width vectors are recognized in both forms:
// shufflevector instructions and intrinsics.
//
// Replacement:
// This step traverses the graph built up by identification, delegating to the
// target to validate and generate the correct intrinsics, and plumbs them
// together, connecting each end of the new intrinsics graph to the existing
// use-def chain. This step is assumed to finish successfully, as all
// information is expected to be correct by this point.
//
//
// Internal data structure:
// ComplexDeinterleavingGraph:
//   Keeps references to all the valid CompositeNodes formed as part of the
//   transformation, and every Instruction contained within said nodes. It also
//   holds onto references to the root Instructions and the nodes that should
//   replace them.
//
// ComplexDeinterleavingCompositeNode:
//   A CompositeNode represents a single transformation point; each node should
//   transform into a single complex instruction (ignoring vector splitting,
//   which would generate more instructions per node). They are identified in a
//   depth-first manner, traversing and identifying the operands of each
//   instruction in the order they appear in the IR.
//   Each node maintains references to its Real and Imaginary values; nodes for
//   Symmetric operations additionally record the Opcode and fast-math Flags to
//   re-emit. Instructions absorbed into a node should only have uses within
//   that node.
//   A Node also contains the rotation and operation type that it represents.
//   Operands contains pointers to other CompositeNodes, acting as the edges in
//   the graph. ReplacementNode is the transformed Value* that has been emitted
//   to the IR.
//
//   Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
//   ReplacementNode fields of that Node are relevant, where ReplacementNode
//   should be pre-populated.
// //===----------------------------------------------------------------------===// #include "llvm/CodeGen/ComplexDeinterleavingPass.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/CodeGen/TargetLowering.h" #include "llvm/CodeGen/TargetPassConfig.h" #include "llvm/CodeGen/TargetSubtargetInfo.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/PatternMatch.h" #include "llvm/InitializePasses.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Transforms/Utils/Local.h" #include using namespace llvm; using namespace PatternMatch; #define DEBUG_TYPE "complex-deinterleaving" STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed"); static cl::opt ComplexDeinterleavingEnabled( "enable-complex-deinterleaving", cl::desc("Enable generation of complex instructions"), cl::init(true), cl::Hidden); /// Checks the given mask, and determines whether said mask is interleaving. /// /// To be interleaving, a mask must alternate between `i` and `i + (Length / /// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a /// 4x vector interleaving mask would be <0, 2, 1, 3>). static bool isInterleavingMask(ArrayRef Mask); /// Checks the given mask, and determines whether said mask is deinterleaving. /// /// To be deinterleaving, a mask must increment in steps of 2, and either start /// with 0 or 1. /// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or /// <1, 3, 5, 7>). static bool isDeinterleavingMask(ArrayRef Mask); /// Returns true if the operation is a negation of V, and it works for both /// integers and floats. static bool isNeg(Value *V); /// Returns the operand for negation operation. 
static Value *getNegOperand(Value *V); namespace { class ComplexDeinterleavingLegacyPass : public FunctionPass { public: static char ID; ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr) : FunctionPass(ID), TM(TM) { initializeComplexDeinterleavingLegacyPassPass( *PassRegistry::getPassRegistry()); } StringRef getPassName() const override { return "Complex Deinterleaving Pass"; } bool runOnFunction(Function &F) override; void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired(); AU.setPreservesCFG(); } private: const TargetMachine *TM; }; class ComplexDeinterleavingGraph; struct ComplexDeinterleavingCompositeNode { ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op, Value *R, Value *I) : Operation(Op), Real(R), Imag(I) {} private: friend class ComplexDeinterleavingGraph; using NodePtr = std::shared_ptr; using RawNodePtr = ComplexDeinterleavingCompositeNode *; public: ComplexDeinterleavingOperation Operation; Value *Real; Value *Imag; // This two members are required exclusively for generating // ComplexDeinterleavingOperation::Symmetric operations. 
unsigned Opcode; std::optional Flags; ComplexDeinterleavingRotation Rotation = ComplexDeinterleavingRotation::Rotation_0; SmallVector Operands; Value *ReplacementNode = nullptr; void addOperand(NodePtr Node) { Operands.push_back(Node.get()); } void dump() { dump(dbgs()); } void dump(raw_ostream &OS) { auto PrintValue = [&](Value *V) { if (V) { OS << "\""; V->print(OS, true); OS << "\"\n"; } else OS << "nullptr\n"; }; auto PrintNodeRef = [&](RawNodePtr Ptr) { if (Ptr) OS << Ptr << "\n"; else OS << "nullptr\n"; }; OS << "- CompositeNode: " << this << "\n"; OS << " Real: "; PrintValue(Real); OS << " Imag: "; PrintValue(Imag); OS << " ReplacementNode: "; PrintValue(ReplacementNode); OS << " Operation: " << (int)Operation << "\n"; OS << " Rotation: " << ((int)Rotation * 90) << "\n"; OS << " Operands: \n"; for (const auto &Op : Operands) { OS << " - "; PrintNodeRef(Op); } } }; class ComplexDeinterleavingGraph { public: struct Product { Value *Multiplier; Value *Multiplicand; bool IsPositive; }; using Addend = std::pair; using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr; using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr; // Helper struct for holding info about potential partial multiplication // candidates struct PartialMulCandidate { Value *Common; NodePtr Node; unsigned RealIdx; unsigned ImagIdx; bool IsNodeInverted; }; explicit ComplexDeinterleavingGraph(const TargetLowering *TL, const TargetLibraryInfo *TLI) : TL(TL), TLI(TLI) {} private: const TargetLowering *TL = nullptr; const TargetLibraryInfo *TLI = nullptr; SmallVector CompositeNodes; DenseMap, NodePtr> CachedResult; SmallPtrSet FinalInstructions; /// Root instructions are instructions from which complex computation starts std::map RootToNode; /// Topologically sorted root instructions SmallVector OrderedRoots; /// When examining a basic block for complex deinterleaving, if it is a simple /// one-block loop, then the only incoming block is 'Incoming' and the /// 'BackEdge' block is the 
block itself." BasicBlock *BackEdge = nullptr; BasicBlock *Incoming = nullptr; /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction /// %OutsideUser as it is shown in the IR: /// /// vector.body: /// %PHInode = phi [ zeroinitializer, %entry ], /// [ %ReductionOp, %vector.body ] /// ... /// %ReductionOp = fadd i64 ... /// ... /// br i1 %condition, label %vector.body, %middle.block /// /// middle.block: /// %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp) /// /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding /// `llvm.vector.reduce.fadd` when unroll factor isn't one. MapVector> ReductionInfo; /// In the process of detecting a reduction, we consider a pair of /// %ReductionOP, which we refer to as real and imag (or vice versa), and /// traverse the use-tree to detect complex operations. As this is a reduction /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds /// to the %ReductionOPs that we suspect to be complex. /// RealPHI and ImagPHI are used by the identifyPHINode method. PHINode *RealPHI = nullptr; PHINode *ImagPHI = nullptr; /// Set this flag to true if RealPHI and ImagPHI were reached during reduction /// detection. bool PHIsFound = false; /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode. /// The new PHINode corresponds to a vector of deinterleaved complex numbers. /// This mapping is populated during /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then /// used in the ComplexDeinterleavingOperation::ReductionOperation node /// replacement process. 
std::map OldToNewPHI; NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation, Value *R, Value *I) { assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI && Operation != ComplexDeinterleavingOperation::ReductionOperation) || (R && I)) && "Reduction related nodes must have Real and Imaginary parts"); return std::make_shared(Operation, R, I); } NodePtr submitCompositeNode(NodePtr Node) { CompositeNodes.push_back(Node); if (Node->Real && Node->Imag) CachedResult[{Node->Real, Node->Imag}] = Node; return Node; } /// Identifies a complex partial multiply pattern and its rotation, based on /// the following patterns /// /// 0: r: cr + ar * br /// i: ci + ar * bi /// 90: r: cr - ai * bi /// i: ci + ai * br /// 180: r: cr - ar * br /// i: ci - ar * bi /// 270: r: cr + ai * bi /// i: ci - ai * br NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag); /// Identify the other branch of a Partial Mul, taking the CommonOperandI that /// is partially known from identifyPartialMul, filling in the other half of /// the complex pair. NodePtr identifyNodeWithImplicitAdd(Instruction *I, Instruction *J, std::pair &CommonOperandI); /// Identifies a complex add pattern and its rotation, based on the following /// patterns. /// /// 90: r: ar - bi /// i: ai + br /// 270: r: ar + bi /// i: ai - br NodePtr identifyAdd(Instruction *Real, Instruction *Imag); NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag); NodePtr identifyNode(Value *R, Value *I); /// Determine if a sum of complex numbers can be formed from \p RealAddends /// and \p ImagAddens. If \p Accumulator is not null, add the result to it. /// Return nullptr if it is not possible to construct a complex number. /// \p Flags are needed to generate symmetric Add and Sub operations. NodePtr identifyAdditions(std::list &RealAddends, std::list &ImagAddends, std::optional Flags, NodePtr Accumulator); /// Extract one addend that have both real and imaginary parts positive. 
NodePtr extractPositiveAddend(std::list &RealAddends, std::list &ImagAddends); /// Determine if sum of multiplications of complex numbers can be formed from /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result /// to it. Return nullptr if it is not possible to construct a complex number. NodePtr identifyMultiplications(std::vector &RealMuls, std::vector &ImagMuls, NodePtr Accumulator); /// Go through pairs of multiplication (one Real and one Imag) and find all /// possible candidates for partial multiplication and put them into \p /// Candidates. Returns true if all Product has pair with common operand bool collectPartialMuls(const std::vector &RealMuls, const std::vector &ImagMuls, std::vector &Candidates); /// If the code is compiled with -Ofast or expressions have `reassoc` flag, /// the order of complex computation operations may be significantly altered, /// and the real and imaginary parts may not be executed in parallel. This /// function takes this into consideration and employs a more general approach /// to identify complex computations. Initially, it gathers all the addends /// and multiplicands and then constructs a complex expression from them. NodePtr identifyReassocNodes(Instruction *I, Instruction *J); NodePtr identifyRoot(Instruction *I); /// Identifies the Deinterleave operation applied to a vector containing /// complex numbers. There are two ways to represent the Deinterleave /// operation: /// * Using two shufflevectors with even indices for /pReal instruction and /// odd indices for /pImag instructions (only for fixed-width vectors) /// * Using two extractvalue instructions applied to `vector.deinterleave2` /// intrinsic (for both fixed and scalable vectors) NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag); /// identifying the operation that represents a complex number repeated in a /// Splat vector. 
There are two possible types of splats: ConstantExpr with /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an /// initialization mask with all values set to zero. NodePtr identifySplat(Value *Real, Value *Imag); NodePtr identifyPHINode(Instruction *Real, Instruction *Imag); /// Identifies SelectInsts in a loop that has reduction with predication masks /// and/or predicated tail folding NodePtr identifySelectNode(Instruction *Real, Instruction *Imag); Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node); /// Complete IR modifications after producing new reduction operation: /// * Populate the PHINode generated for /// ComplexDeinterleavingOperation::ReductionPHI /// * Deinterleave the final value outside of the loop and repurpose original /// reduction users void processReductionOperation(Value *OperationReplacement, RawNodePtr Node); public: void dump() { dump(dbgs()); } void dump(raw_ostream &OS) { for (const auto &Node : CompositeNodes) Node->dump(OS); } /// Returns false if the deinterleaving operation should be cancelled for the /// current graph. bool identifyNodes(Instruction *RootI); /// In case \pB is one-block loop, this function seeks potential reductions /// and populates ReductionInfo. Returns true if any reductions were /// identified. bool collectPotentialReductions(BasicBlock *B); void identifyReductionNodes(); /// Check that every instruction, from the roots to the leaves, has internal /// uses. bool checkNodes(); /// Perform the actual replacement of the underlying instruction graph. 
void replaceNodes(); }; class ComplexDeinterleaving { public: ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli) : TL(tl), TLI(tli) {} bool runOnFunction(Function &F); private: bool evaluateBasicBlock(BasicBlock *B); const TargetLowering *TL = nullptr; const TargetLibraryInfo *TLI = nullptr; }; } // namespace char ComplexDeinterleavingLegacyPass::ID = 0; INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE, "Complex Deinterleaving", false, false) INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE, "Complex Deinterleaving", false, false) PreservedAnalyses ComplexDeinterleavingPass::run(Function &F, FunctionAnalysisManager &AM) { const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering(); auto &TLI = AM.getResult(F); if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F)) return PreservedAnalyses::all(); PreservedAnalyses PA; PA.preserve(); return PA; } FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) { return new ComplexDeinterleavingLegacyPass(TM); } bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) { const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering(); auto TLI = getAnalysis().getTLI(F); return ComplexDeinterleaving(TL, &TLI).runOnFunction(F); } bool ComplexDeinterleaving::runOnFunction(Function &F) { if (!ComplexDeinterleavingEnabled) { LLVM_DEBUG( dbgs() << "Complex deinterleaving has been explicitly disabled.\n"); return false; } if (!TL->isComplexDeinterleavingSupported()) { LLVM_DEBUG( dbgs() << "Complex deinterleaving has been disabled, target does " "not support lowering of complex number operations.\n"); return false; } bool Changed = false; for (auto &B : F) Changed |= evaluateBasicBlock(&B); return Changed; } static bool isInterleavingMask(ArrayRef Mask) { // If the size is not even, it's not an interleaving mask if ((Mask.size() & 1)) return false; int HalfNumElements = Mask.size() / 2; for (int Idx = 0; Idx < HalfNumElements; ++Idx) { 
int MaskIdx = Idx * 2; if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements)) return false; } return true; } static bool isDeinterleavingMask(ArrayRef Mask) { int Offset = Mask[0]; int HalfNumElements = Mask.size() / 2; for (int Idx = 1; Idx < HalfNumElements; ++Idx) { if (Mask[Idx] != (Idx * 2) + Offset) return false; } return true; } bool isNeg(Value *V) { return match(V, m_FNeg(m_Value())) || match(V, m_Neg(m_Value())); } Value *getNegOperand(Value *V) { assert(isNeg(V)); auto *I = cast(V); if (I->getOpcode() == Instruction::FNeg) return I->getOperand(0); return I->getOperand(1); } bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) { ComplexDeinterleavingGraph Graph(TL, TLI); if (Graph.collectPotentialReductions(B)) Graph.identifyReductionNodes(); for (auto &I : *B) Graph.identifyNodes(&I); if (Graph.checkNodes()) { Graph.replaceNodes(); return true; } return false; } ComplexDeinterleavingGraph::NodePtr ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd( Instruction *Real, Instruction *Imag, std::pair &PartialMatch) { LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag << "\n"); if (!Real->hasOneUse() || !Imag->hasOneUse()) { LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n"); return nullptr; } if ((Real->getOpcode() != Instruction::FMul && Real->getOpcode() != Instruction::Mul) || (Imag->getOpcode() != Instruction::FMul && Imag->getOpcode() != Instruction::Mul)) { LLVM_DEBUG( dbgs() << " - Real or imaginary instruction is not fmul or mul\n"); return nullptr; } Value *R0 = Real->getOperand(0); Value *R1 = Real->getOperand(1); Value *I0 = Imag->getOperand(0); Value *I1 = Imag->getOperand(1); // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the // rotations and use the operand. 
unsigned Negs = 0; Value *Op; if (match(R0, m_Neg(m_Value(Op)))) { Negs |= 1; R0 = Op; } else if (match(R1, m_Neg(m_Value(Op)))) { Negs |= 1; R1 = Op; } if (isNeg(I0)) { Negs |= 2; Negs ^= 1; I0 = Op; } else if (match(I1, m_Neg(m_Value(Op)))) { Negs |= 2; Negs ^= 1; I1 = Op; } ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs; Value *CommonOperand; Value *UncommonRealOp; Value *UncommonImagOp; if (R0 == I0 || R0 == I1) { CommonOperand = R0; UncommonRealOp = R1; } else if (R1 == I0 || R1 == I1) { CommonOperand = R1; UncommonRealOp = R0; } else { LLVM_DEBUG(dbgs() << " - No equal operand\n"); return nullptr; } UncommonImagOp = (CommonOperand == I0) ? I1 : I0; if (Rotation == ComplexDeinterleavingRotation::Rotation_90 || Rotation == ComplexDeinterleavingRotation::Rotation_270) std::swap(UncommonRealOp, UncommonImagOp); // Between identifyPartialMul and here we need to have found a complete valid // pair from the CommonOperand of each part. if (Rotation == ComplexDeinterleavingRotation::Rotation_0 || Rotation == ComplexDeinterleavingRotation::Rotation_180) PartialMatch.first = CommonOperand; else PartialMatch.second = CommonOperand; if (!PartialMatch.first || !PartialMatch.second) { LLVM_DEBUG(dbgs() << " - Incomplete partial match\n"); return nullptr; } NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second); if (!CommonNode) { LLVM_DEBUG(dbgs() << " - No CommonNode identified\n"); return nullptr; } NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp); if (!UncommonNode) { LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n"); return nullptr; } NodePtr Node = prepareCompositeNode( ComplexDeinterleavingOperation::CMulPartial, Real, Imag); Node->Rotation = Rotation; Node->addOperand(CommonNode); Node->addOperand(UncommonNode); return submitCompositeNode(Node); } ComplexDeinterleavingGraph::NodePtr ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real, Instruction *Imag) { LLVM_DEBUG(dbgs() << 
"identifyPartialMul " << *Real << " / " << *Imag << "\n"); // Determine rotation auto IsAdd = [](unsigned Op) { return Op == Instruction::FAdd || Op == Instruction::Add; }; auto IsSub = [](unsigned Op) { return Op == Instruction::FSub || Op == Instruction::Sub; }; ComplexDeinterleavingRotation Rotation; if (IsAdd(Real->getOpcode()) && IsAdd(Imag->getOpcode())) Rotation = ComplexDeinterleavingRotation::Rotation_0; else if (IsSub(Real->getOpcode()) && IsAdd(Imag->getOpcode())) Rotation = ComplexDeinterleavingRotation::Rotation_90; else if (IsSub(Real->getOpcode()) && IsSub(Imag->getOpcode())) Rotation = ComplexDeinterleavingRotation::Rotation_180; else if (IsAdd(Real->getOpcode()) && IsSub(Imag->getOpcode())) Rotation = ComplexDeinterleavingRotation::Rotation_270; else { LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n"); return nullptr; } if (isa(Real) && (!Real->getFastMathFlags().allowContract() || !Imag->getFastMathFlags().allowContract())) { LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n"); return nullptr; } Value *CR = Real->getOperand(0); Instruction *RealMulI = dyn_cast(Real->getOperand(1)); if (!RealMulI) return nullptr; Value *CI = Imag->getOperand(0); Instruction *ImagMulI = dyn_cast(Imag->getOperand(1)); if (!ImagMulI) return nullptr; if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) { LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n"); return nullptr; } Value *R0 = RealMulI->getOperand(0); Value *R1 = RealMulI->getOperand(1); Value *I0 = ImagMulI->getOperand(0); Value *I1 = ImagMulI->getOperand(1); Value *CommonOperand; Value *UncommonRealOp; Value *UncommonImagOp; if (R0 == I0 || R0 == I1) { CommonOperand = R0; UncommonRealOp = R1; } else if (R1 == I0 || R1 == I1) { CommonOperand = R1; UncommonRealOp = R0; } else { LLVM_DEBUG(dbgs() << " - No equal operand\n"); return nullptr; } UncommonImagOp = (CommonOperand == I0) ? 
I1 : I0; if (Rotation == ComplexDeinterleavingRotation::Rotation_90 || Rotation == ComplexDeinterleavingRotation::Rotation_270) std::swap(UncommonRealOp, UncommonImagOp); std::pair PartialMatch( (Rotation == ComplexDeinterleavingRotation::Rotation_0 || Rotation == ComplexDeinterleavingRotation::Rotation_180) ? CommonOperand : nullptr, (Rotation == ComplexDeinterleavingRotation::Rotation_90 || Rotation == ComplexDeinterleavingRotation::Rotation_270) ? CommonOperand : nullptr); auto *CRInst = dyn_cast(CR); auto *CIInst = dyn_cast(CI); if (!CRInst || !CIInst) { LLVM_DEBUG(dbgs() << " - Common operands are not instructions.\n"); return nullptr; } NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch); if (!CNode) { LLVM_DEBUG(dbgs() << " - No cnode identified\n"); return nullptr; } NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp); if (!UncommonRes) { LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n"); return nullptr; } assert(PartialMatch.first && PartialMatch.second); NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second); if (!CommonRes) { LLVM_DEBUG(dbgs() << " - No CommonRes identified\n"); return nullptr; } NodePtr Node = prepareCompositeNode( ComplexDeinterleavingOperation::CMulPartial, Real, Imag); Node->Rotation = Rotation; Node->addOperand(CommonRes); Node->addOperand(UncommonRes); Node->addOperand(CNode); return submitCompositeNode(Node); } ComplexDeinterleavingGraph::NodePtr ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) { LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n"); // Determine rotation ComplexDeinterleavingRotation Rotation; if ((Real->getOpcode() == Instruction::FSub && Imag->getOpcode() == Instruction::FAdd) || (Real->getOpcode() == Instruction::Sub && Imag->getOpcode() == Instruction::Add)) Rotation = ComplexDeinterleavingRotation::Rotation_90; else if ((Real->getOpcode() == Instruction::FAdd && Imag->getOpcode() == Instruction::FSub) 
|| (Real->getOpcode() == Instruction::Add && Imag->getOpcode() == Instruction::Sub)) Rotation = ComplexDeinterleavingRotation::Rotation_270; else { LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n"); return nullptr; } auto *AR = dyn_cast(Real->getOperand(0)); auto *BI = dyn_cast(Real->getOperand(1)); auto *AI = dyn_cast(Imag->getOperand(0)); auto *BR = dyn_cast(Imag->getOperand(1)); if (!AR || !AI || !BR || !BI) { LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n"); return nullptr; } NodePtr ResA = identifyNode(AR, AI); if (!ResA) { LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n"); return nullptr; } NodePtr ResB = identifyNode(BR, BI); if (!ResB) { LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n"); return nullptr; } NodePtr Node = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag); Node->Rotation = Rotation; Node->addOperand(ResA); Node->addOperand(ResB); return submitCompositeNode(Node); } static bool isInstructionPairAdd(Instruction *A, Instruction *B) { unsigned OpcA = A->getOpcode(); unsigned OpcB = B->getOpcode(); return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) || (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) || (OpcA == Instruction::Sub && OpcB == Instruction::Add) || (OpcA == Instruction::Add && OpcB == Instruction::Sub); } static bool isInstructionPairMul(Instruction *A, Instruction *B) { auto Pattern = m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value())); return match(A, Pattern) && match(B, Pattern); } static bool isInstructionPotentiallySymmetric(Instruction *I) { switch (I->getOpcode()) { case Instruction::FAdd: case Instruction::FSub: case Instruction::FMul: case Instruction::FNeg: case Instruction::Add: case Instruction::Sub: case Instruction::Mul: return true; default: return false; } } ComplexDeinterleavingGraph::NodePtr ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real, Instruction 
*Imag) { if (Real->getOpcode() != Imag->getOpcode()) return nullptr; if (!isInstructionPotentiallySymmetric(Real) || !isInstructionPotentiallySymmetric(Imag)) return nullptr; auto *R0 = Real->getOperand(0); auto *I0 = Imag->getOperand(0); NodePtr Op0 = identifyNode(R0, I0); NodePtr Op1 = nullptr; if (Op0 == nullptr) return nullptr; if (Real->isBinaryOp()) { auto *R1 = Real->getOperand(1); auto *I1 = Imag->getOperand(1); Op1 = identifyNode(R1, I1); if (Op1 == nullptr) return nullptr; } if (isa(Real) && Real->getFastMathFlags() != Imag->getFastMathFlags()) return nullptr; auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric, Real, Imag); Node->Opcode = Real->getOpcode(); if (isa(Real)) Node->Flags = Real->getFastMathFlags(); Node->addOperand(Op0); if (Real->isBinaryOp()) Node->addOperand(Op1); return submitCompositeNode(Node); } ComplexDeinterleavingGraph::NodePtr ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) { LLVM_DEBUG(dbgs() << "identifyNode on " << *R << " / " << *I << "\n"); assert(R->getType() == I->getType() && "Real and imaginary parts should not have different types"); auto It = CachedResult.find({R, I}); if (It != CachedResult.end()) { LLVM_DEBUG(dbgs() << " - Folding to existing node\n"); return It->second; } if (NodePtr CN = identifySplat(R, I)) return CN; auto *Real = dyn_cast(R); auto *Imag = dyn_cast(I); if (!Real || !Imag) return nullptr; if (NodePtr CN = identifyDeinterleave(Real, Imag)) return CN; if (NodePtr CN = identifyPHINode(Real, Imag)) return CN; if (NodePtr CN = identifySelectNode(Real, Imag)) return CN; auto *VTy = cast(Real->getType()); auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported( ComplexDeinterleavingOperation::CMulPartial, NewVTy); bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported( ComplexDeinterleavingOperation::CAdd, NewVTy); if (HasCMulSupport && isInstructionPairMul(Real, Imag)) { if (NodePtr CN 
= identifyPartialMul(Real, Imag)) return CN; } if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) { if (NodePtr CN = identifyAdd(Real, Imag)) return CN; } if (HasCMulSupport && HasCAddSupport) { if (NodePtr CN = identifyReassocNodes(Real, Imag)) return CN; } if (NodePtr CN = identifySymmetricOperation(Real, Imag)) return CN; LLVM_DEBUG(dbgs() << " - Not recognised as a valid pattern.\n"); CachedResult[{R, I}] = nullptr; return nullptr; } ComplexDeinterleavingGraph::NodePtr ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real, Instruction *Imag) { auto IsOperationSupported = [](unsigned Opcode) -> bool { return Opcode == Instruction::FAdd || Opcode == Instruction::FSub || Opcode == Instruction::FNeg || Opcode == Instruction::Add || Opcode == Instruction::Sub; }; if (!IsOperationSupported(Real->getOpcode()) || !IsOperationSupported(Imag->getOpcode())) return nullptr; std::optional Flags; if (isa(Real)) { if (Real->getFastMathFlags() != Imag->getFastMathFlags()) { LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are " "not identical\n"); return nullptr; } Flags = Real->getFastMathFlags(); if (!Flags->allowReassoc()) { LLVM_DEBUG( dbgs() << "the 'Reassoc' attribute is missing in the FastMath flags\n"); return nullptr; } } // Collect multiplications and addend instructions from the given instruction // while traversing it operands. Additionally, verify that all instructions // have the same fast math flags. 
auto Collect = [&Flags](Instruction *Insn, std::vector &Muls, std::list &Addends) -> bool { SmallVector> Worklist = {{Insn, true}}; SmallPtrSet Visited; while (!Worklist.empty()) { auto [V, IsPositive] = Worklist.back(); Worklist.pop_back(); if (!Visited.insert(V).second) continue; Instruction *I = dyn_cast(V); if (!I) { Addends.emplace_back(V, IsPositive); continue; } // If an instruction has more than one user, it indicates that it either // has an external user, which will be later checked by the checkNodes // function, or it is a subexpression utilized by multiple expressions. In // the latter case, we will attempt to separately identify the complex // operation from here in order to create a shared // ComplexDeinterleavingCompositeNode. if (I != Insn && I->getNumUses() > 1) { LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n"); Addends.emplace_back(I, IsPositive); continue; } switch (I->getOpcode()) { case Instruction::FAdd: case Instruction::Add: Worklist.emplace_back(I->getOperand(1), IsPositive); Worklist.emplace_back(I->getOperand(0), IsPositive); break; case Instruction::FSub: Worklist.emplace_back(I->getOperand(1), !IsPositive); Worklist.emplace_back(I->getOperand(0), IsPositive); break; case Instruction::Sub: if (isNeg(I)) { Worklist.emplace_back(getNegOperand(I), !IsPositive); } else { Worklist.emplace_back(I->getOperand(1), !IsPositive); Worklist.emplace_back(I->getOperand(0), IsPositive); } break; case Instruction::FMul: case Instruction::Mul: { Value *A, *B; if (isNeg(I->getOperand(0))) { A = getNegOperand(I->getOperand(0)); IsPositive = !IsPositive; } else { A = I->getOperand(0); } if (isNeg(I->getOperand(1))) { B = getNegOperand(I->getOperand(1)); IsPositive = !IsPositive; } else { B = I->getOperand(1); } Muls.push_back(Product{A, B, IsPositive}); break; } case Instruction::FNeg: Worklist.emplace_back(I->getOperand(0), !IsPositive); break; default: Addends.emplace_back(I, IsPositive); continue; } if (Flags && 
I->getFastMathFlags() != *Flags) { LLVM_DEBUG(dbgs() << "The instruction's fast math flags are " "inconsistent with the root instructions' flags: " << *I << "\n"); return false; } } return true; }; std::vector RealMuls, ImagMuls; std::list RealAddends, ImagAddends; if (!Collect(Real, RealMuls, RealAddends) || !Collect(Imag, ImagMuls, ImagAddends)) return nullptr; if (RealAddends.size() != ImagAddends.size()) return nullptr; NodePtr FinalNode; if (!RealMuls.empty() || !ImagMuls.empty()) { // If there are multiplicands, extract positive addend and use it as an // accumulator FinalNode = extractPositiveAddend(RealAddends, ImagAddends); FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode); if (!FinalNode) return nullptr; } // Identify and process remaining additions if (!RealAddends.empty() || !ImagAddends.empty()) { FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode); if (!FinalNode) return nullptr; } assert(FinalNode && "FinalNode can not be nullptr here"); // Set the Real and Imag fields of the final node and submit it FinalNode->Real = Real; FinalNode->Imag = Imag; submitCompositeNode(FinalNode); return FinalNode; } bool ComplexDeinterleavingGraph::collectPartialMuls( const std::vector &RealMuls, const std::vector &ImagMuls, std::vector &PartialMulCandidates) { // Helper function to extract a common operand from two products auto FindCommonInstruction = [](const Product &Real, const Product &Imag) -> Value * { if (Real.Multiplicand == Imag.Multiplicand || Real.Multiplicand == Imag.Multiplier) return Real.Multiplicand; if (Real.Multiplier == Imag.Multiplicand || Real.Multiplier == Imag.Multiplier) return Real.Multiplier; return nullptr; }; // Iterating over real and imaginary multiplications to find common operands // If a common operand is found, a partial multiplication candidate is created // and added to the candidates vector The function returns false if no common // operands are found for any product for (unsigned i = 0; i < 
RealMuls.size(); ++i) { bool FoundCommon = false; for (unsigned j = 0; j < ImagMuls.size(); ++j) { auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]); if (!Common) continue; auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier : RealMuls[i].Multiplicand; auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier : ImagMuls[j].Multiplicand; auto Node = identifyNode(A, B); if (Node) { FoundCommon = true; PartialMulCandidates.push_back({Common, Node, i, j, false}); } Node = identifyNode(B, A); if (Node) { FoundCommon = true; PartialMulCandidates.push_back({Common, Node, i, j, true}); } } if (!FoundCommon) return false; } return true; } ComplexDeinterleavingGraph::NodePtr ComplexDeinterleavingGraph::identifyMultiplications( std::vector &RealMuls, std::vector &ImagMuls, NodePtr Accumulator = nullptr) { if (RealMuls.size() != ImagMuls.size()) return nullptr; std::vector Info; if (!collectPartialMuls(RealMuls, ImagMuls, Info)) return nullptr; // Map to store common instruction to node pointers std::map CommonToNode; std::vector Processed(Info.size(), false); for (unsigned I = 0; I < Info.size(); ++I) { if (Processed[I]) continue; PartialMulCandidate &InfoA = Info[I]; for (unsigned J = I + 1; J < Info.size(); ++J) { if (Processed[J]) continue; PartialMulCandidate &InfoB = Info[J]; auto *InfoReal = &InfoA; auto *InfoImag = &InfoB; auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common); if (!NodeFromCommon) { std::swap(InfoReal, InfoImag); NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common); } if (!NodeFromCommon) continue; CommonToNode[InfoReal->Common] = NodeFromCommon; CommonToNode[InfoImag->Common] = NodeFromCommon; Processed[I] = true; Processed[J] = true; } } std::vector ProcessedReal(RealMuls.size(), false); std::vector ProcessedImag(ImagMuls.size(), false); NodePtr Result = Accumulator; for (auto &PMI : Info) { if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx]) continue; auto It = 
CommonToNode.find(PMI.Common); // TODO: Process independent complex multiplications. Cases like this: // A.real() * B where both A and B are complex numbers. if (It == CommonToNode.end()) { LLVM_DEBUG({ dbgs() << "Unprocessed independent partial multiplication:\n"; for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]}) dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier << " multiplied by " << *Mul->Multiplicand << "\n"; }); return nullptr; } auto &RealMul = RealMuls[PMI.RealIdx]; auto &ImagMul = ImagMuls[PMI.ImagIdx]; auto NodeA = It->second; auto NodeB = PMI.Node; auto IsMultiplicandReal = PMI.Common == NodeA->Real; // The following table illustrates the relationship between multiplications // and rotations. If we consider the multiplication (X + iY) * (U + iV), we // can see: // // Rotation | Real | Imag | // ---------+--------+--------+ // 0 | x * u | x * v | // 90 | -y * v | y * u | // 180 | -x * u | -x * v | // 270 | y * v | -y * u | // // Check if the candidate can indeed be represented by partial // multiplication // TODO: Add support for multiplication by complex one if ((IsMultiplicandReal && PMI.IsNodeInverted) || (!IsMultiplicandReal && !PMI.IsNodeInverted)) continue; // Determine the rotation based on the multiplications ComplexDeinterleavingRotation Rotation; if (IsMultiplicandReal) { // Detect 0 and 180 degrees rotation if (RealMul.IsPositive && ImagMul.IsPositive) Rotation = llvm::ComplexDeinterleavingRotation::Rotation_0; else if (!RealMul.IsPositive && !ImagMul.IsPositive) Rotation = llvm::ComplexDeinterleavingRotation::Rotation_180; else continue; } else { // Detect 90 and 270 degrees rotation if (!RealMul.IsPositive && ImagMul.IsPositive) Rotation = llvm::ComplexDeinterleavingRotation::Rotation_90; else if (RealMul.IsPositive && !ImagMul.IsPositive) Rotation = llvm::ComplexDeinterleavingRotation::Rotation_270; else continue; } LLVM_DEBUG({ dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n"; 
dbgs().indent(4) << "X: " << *NodeA->Real << "\n"; dbgs().indent(4) << "Y: " << *NodeA->Imag << "\n"; dbgs().indent(4) << "U: " << *NodeB->Real << "\n"; dbgs().indent(4) << "V: " << *NodeB->Imag << "\n"; dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n"; }); NodePtr NodeMul = prepareCompositeNode( ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr); NodeMul->Rotation = Rotation; NodeMul->addOperand(NodeA); NodeMul->addOperand(NodeB); if (Result) NodeMul->addOperand(Result); submitCompositeNode(NodeMul); Result = NodeMul; ProcessedReal[PMI.RealIdx] = true; ProcessedImag[PMI.ImagIdx] = true; } // Ensure all products have been processed, if not return nullptr. if (!all_of(ProcessedReal, [](bool V) { return V; }) || !all_of(ProcessedImag, [](bool V) { return V; })) { // Dump debug information about which partial multiplications are not // processed. LLVM_DEBUG({ dbgs() << "Unprocessed products (Real):\n"; for (size_t i = 0; i < ProcessedReal.size(); ++i) { if (!ProcessedReal[i]) dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-") << *RealMuls[i].Multiplier << " multiplied by " << *RealMuls[i].Multiplicand << "\n"; } dbgs() << "Unprocessed products (Imag):\n"; for (size_t i = 0; i < ProcessedImag.size(); ++i) { if (!ProcessedImag[i]) dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-") << *ImagMuls[i].Multiplier << " multiplied by " << *ImagMuls[i].Multiplicand << "\n"; } }); return nullptr; } return Result; } ComplexDeinterleavingGraph::NodePtr ComplexDeinterleavingGraph::identifyAdditions( std::list &RealAddends, std::list &ImagAddends, std::optional Flags, NodePtr Accumulator = nullptr) { if (RealAddends.size() != ImagAddends.size()) return nullptr; NodePtr Result; // If we have accumulator use it as first addend if (Accumulator) Result = Accumulator; // Otherwise find an element with both positive real and imaginary parts. 
else Result = extractPositiveAddend(RealAddends, ImagAddends); if (!Result) return nullptr; while (!RealAddends.empty()) { auto ItR = RealAddends.begin(); auto [R, IsPositiveR] = *ItR; bool FoundImag = false; for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) { auto [I, IsPositiveI] = *ItI; ComplexDeinterleavingRotation Rotation; if (IsPositiveR && IsPositiveI) Rotation = ComplexDeinterleavingRotation::Rotation_0; else if (!IsPositiveR && IsPositiveI) Rotation = ComplexDeinterleavingRotation::Rotation_90; else if (!IsPositiveR && !IsPositiveI) Rotation = ComplexDeinterleavingRotation::Rotation_180; else Rotation = ComplexDeinterleavingRotation::Rotation_270; NodePtr AddNode; if (Rotation == ComplexDeinterleavingRotation::Rotation_0 || Rotation == ComplexDeinterleavingRotation::Rotation_180) { AddNode = identifyNode(R, I); } else { AddNode = identifyNode(I, R); } if (AddNode) { LLVM_DEBUG({ dbgs() << "Identified addition:\n"; dbgs().indent(4) << "X: " << *R << "\n"; dbgs().indent(4) << "Y: " << *I << "\n"; dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n"; }); NodePtr TmpNode; if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0) { TmpNode = prepareCompositeNode( ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr); if (Flags) { TmpNode->Opcode = Instruction::FAdd; TmpNode->Flags = *Flags; } else { TmpNode->Opcode = Instruction::Add; } } else if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_180) { TmpNode = prepareCompositeNode( ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr); if (Flags) { TmpNode->Opcode = Instruction::FSub; TmpNode->Flags = *Flags; } else { TmpNode->Opcode = Instruction::Sub; } } else { TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, nullptr, nullptr); TmpNode->Rotation = Rotation; } TmpNode->addOperand(Result); TmpNode->addOperand(AddNode); submitCompositeNode(TmpNode); Result = TmpNode; RealAddends.erase(ItR); ImagAddends.erase(ItI); FoundImag = 
true; break; } } if (!FoundImag) return nullptr; } return Result; } ComplexDeinterleavingGraph::NodePtr ComplexDeinterleavingGraph::extractPositiveAddend( std::list &RealAddends, std::list &ImagAddends) { for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) { for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) { auto [R, IsPositiveR] = *ItR; auto [I, IsPositiveI] = *ItI; if (IsPositiveR && IsPositiveI) { auto Result = identifyNode(R, I); if (Result) { RealAddends.erase(ItR); ImagAddends.erase(ItI); return Result; } } } } return nullptr; } bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) { // This potential root instruction might already have been recognized as // reduction. Because RootToNode maps both Real and Imaginary parts to // CompositeNode we should choose only one either Real or Imag instruction to // use as an anchor for generating complex instruction. auto It = RootToNode.find(RootI); if (It != RootToNode.end()) { auto RootNode = It->second; assert(RootNode->Operation == ComplexDeinterleavingOperation::ReductionOperation); // Find out which part, Real or Imag, comes later, and only if we come to // the latest part, add it to OrderedRoots. auto *R = cast(RootNode->Real); auto *I = cast(RootNode->Imag); auto *ReplacementAnchor = R->comesBefore(I) ? 
I : R; if (ReplacementAnchor != RootI) return false; OrderedRoots.push_back(RootI); return true; } auto RootNode = identifyRoot(RootI); if (!RootNode) return false; LLVM_DEBUG({ Function *F = RootI->getFunction(); BasicBlock *B = RootI->getParent(); dbgs() << "Complex deinterleaving graph for " << F->getName() << "::" << B->getName() << ".\n"; dump(dbgs()); dbgs() << "\n"; }); RootToNode[RootI] = RootNode; OrderedRoots.push_back(RootI); return true; } bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) { bool FoundPotentialReduction = false; auto *Br = dyn_cast(B->getTerminator()); if (!Br || Br->getNumSuccessors() != 2) return false; // Identify simple one-block loop if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B) return false; SmallVector PHIs; for (auto &PHI : B->phis()) { if (PHI.getNumIncomingValues() != 2) continue; if (!PHI.getType()->isVectorTy()) continue; auto *ReductionOp = dyn_cast(PHI.getIncomingValueForBlock(B)); if (!ReductionOp) continue; // Check if final instruction is reduced outside of current block Instruction *FinalReduction = nullptr; auto NumUsers = 0u; for (auto *U : ReductionOp->users()) { ++NumUsers; if (U == &PHI) continue; FinalReduction = dyn_cast(U); } if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B || isa(FinalReduction)) continue; ReductionInfo[ReductionOp] = {&PHI, FinalReduction}; BackEdge = B; auto BackEdgeIdx = PHI.getBasicBlockIndex(B); auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0; Incoming = PHI.getIncomingBlock(IncomingIdx); FoundPotentialReduction = true; // If the initial value of PHINode is an Instruction, consider it a leaf // value of a complex deinterleaving graph. 
if (auto *InitPHI = dyn_cast(PHI.getIncomingValueForBlock(Incoming))) FinalInstructions.insert(InitPHI); } return FoundPotentialReduction; } void ComplexDeinterleavingGraph::identifyReductionNodes() { SmallVector Processed(ReductionInfo.size(), false); SmallVector OperationInstruction; for (auto &P : ReductionInfo) OperationInstruction.push_back(P.first); // Identify a complex computation by evaluating two reduction operations that // potentially could be involved for (size_t i = 0; i < OperationInstruction.size(); ++i) { if (Processed[i]) continue; for (size_t j = i + 1; j < OperationInstruction.size(); ++j) { if (Processed[j]) continue; auto *Real = OperationInstruction[i]; auto *Imag = OperationInstruction[j]; if (Real->getType() != Imag->getType()) continue; RealPHI = ReductionInfo[Real].first; ImagPHI = ReductionInfo[Imag].first; PHIsFound = false; auto Node = identifyNode(Real, Imag); if (!Node) { std::swap(Real, Imag); std::swap(RealPHI, ImagPHI); Node = identifyNode(Real, Imag); } // If a node is identified and reduction PHINode is used in the chain of // operations, mark its operation instructions as used to prevent // re-identification and attach the node to the real part if (Node && PHIsFound) { LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: " << *Real << " / " << *Imag << "\n"); Processed[i] = true; Processed[j] = true; auto RootNode = prepareCompositeNode( ComplexDeinterleavingOperation::ReductionOperation, Real, Imag); RootNode->addOperand(Node); RootToNode[Real] = RootNode; RootToNode[Imag] = RootNode; submitCompositeNode(RootNode); break; } } } RealPHI = nullptr; ImagPHI = nullptr; } bool ComplexDeinterleavingGraph::checkNodes() { // Collect all instructions from roots to leaves SmallPtrSet AllInstructions; SmallVector Worklist; for (auto &Pair : RootToNode) Worklist.push_back(Pair.first); // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG // chains while (!Worklist.empty()) { auto *I = Worklist.back(); 
Worklist.pop_back(); if (!AllInstructions.insert(I).second) continue; for (Value *Op : I->operands()) { if (auto *OpI = dyn_cast(Op)) { if (!FinalInstructions.count(I)) Worklist.emplace_back(OpI); } } } // Find instructions that have users outside of chain SmallVector OuterInstructions; for (auto *I : AllInstructions) { // Skip root nodes if (RootToNode.count(I)) continue; for (User *U : I->users()) { if (AllInstructions.count(cast(U))) continue; // Found an instruction that is not used by XCMLA/XCADD chain Worklist.emplace_back(I); break; } } // If any instructions are found to be used outside, find and remove roots // that somehow connect to those instructions. SmallPtrSet Visited; while (!Worklist.empty()) { auto *I = Worklist.back(); Worklist.pop_back(); if (!Visited.insert(I).second) continue; // Found an impacted root node. Removing it from the nodes to be // deinterleaved if (RootToNode.count(I)) { LLVM_DEBUG(dbgs() << "Instruction " << *I << " could be deinterleaved but its chain of complex " "operations have an outside user\n"); RootToNode.erase(I); } if (!AllInstructions.count(I) || FinalInstructions.count(I)) continue; for (User *U : I->users()) Worklist.emplace_back(cast(U)); for (Value *Op : I->operands()) { if (auto *OpI = dyn_cast(Op)) Worklist.emplace_back(OpI); } } return !RootToNode.empty(); } ComplexDeinterleavingGraph::NodePtr ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) { if (auto *Intrinsic = dyn_cast(RootI)) { if (Intrinsic->getIntrinsicID() != Intrinsic::vector_interleave2) return nullptr; auto *Real = dyn_cast(Intrinsic->getOperand(0)); auto *Imag = dyn_cast(Intrinsic->getOperand(1)); if (!Real || !Imag) return nullptr; return identifyNode(Real, Imag); } auto *SVI = dyn_cast(RootI); if (!SVI) return nullptr; // Look for a shufflevector that takes separate vectors of the real and // imaginary components and recombines them into a single vector. 
if (!isInterleavingMask(SVI->getShuffleMask())) return nullptr; Instruction *Real; Instruction *Imag; if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag)))) return nullptr; return identifyNode(Real, Imag); } ComplexDeinterleavingGraph::NodePtr ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real, Instruction *Imag) { Instruction *I = nullptr; Value *FinalValue = nullptr; if (match(Real, m_ExtractValue<0>(m_Instruction(I))) && match(Imag, m_ExtractValue<1>(m_Specific(I))) && match(I, m_Intrinsic( m_Value(FinalValue)))) { NodePtr PlaceholderNode = prepareCompositeNode( llvm::ComplexDeinterleavingOperation::Deinterleave, Real, Imag); PlaceholderNode->ReplacementNode = FinalValue; FinalInstructions.insert(Real); FinalInstructions.insert(Imag); return submitCompositeNode(PlaceholderNode); } auto *RealShuffle = dyn_cast(Real); auto *ImagShuffle = dyn_cast(Imag); if (!RealShuffle || !ImagShuffle) { if (RealShuffle || ImagShuffle) LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n"); return nullptr; } Value *RealOp1 = RealShuffle->getOperand(1); if (!isa(RealOp1) && !isa(RealOp1)) { LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n"); return nullptr; } Value *ImagOp1 = ImagShuffle->getOperand(1); if (!isa(ImagOp1) && !isa(ImagOp1)) { LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n"); return nullptr; } Value *RealOp0 = RealShuffle->getOperand(0); Value *ImagOp0 = ImagShuffle->getOperand(0); if (RealOp0 != ImagOp0) { LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n"); return nullptr; } ArrayRef RealMask = RealShuffle->getShuffleMask(); ArrayRef ImagMask = ImagShuffle->getShuffleMask(); if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) { LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n"); return nullptr; } if (RealMask[0] != 0 || ImagMask[0] != 1) { LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n"); return nullptr; } // Type checking, the shuffle 
type should be a vector type of the same // scalar type, but half the size auto CheckType = [&](ShuffleVectorInst *Shuffle) { Value *Op = Shuffle->getOperand(0); auto *ShuffleTy = cast(Shuffle->getType()); auto *OpTy = cast(Op->getType()); if (OpTy->getScalarType() != ShuffleTy->getScalarType()) return false; if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements()) return false; return true; }; auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool { if (!CheckType(Shuffle)) return false; ArrayRef Mask = Shuffle->getShuffleMask(); int Last = *Mask.rbegin(); Value *Op = Shuffle->getOperand(0); auto *OpTy = cast(Op->getType()); int NumElements = OpTy->getNumElements(); // Ensure that the deinterleaving shuffle only pulls from the first // shuffle operand. return Last < NumElements; }; if (RealShuffle->getType() != ImagShuffle->getType()) { LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n"); return nullptr; } if (!CheckDeinterleavingShuffle(RealShuffle)) { LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n"); return nullptr; } if (!CheckDeinterleavingShuffle(ImagShuffle)) { LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n"); return nullptr; } NodePtr PlaceholderNode = prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Deinterleave, RealShuffle, ImagShuffle); PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0); FinalInstructions.insert(RealShuffle); FinalInstructions.insert(ImagShuffle); return submitCompositeNode(PlaceholderNode); } ComplexDeinterleavingGraph::NodePtr ComplexDeinterleavingGraph::identifySplat(Value *R, Value *I) { auto IsSplat = [](Value *V) -> bool { // Fixed-width vector with constants if (isa(V)) return true; VectorType *VTy; ArrayRef Mask; // Splats are represented differently depending on whether the repeated // value is a constant or an Instruction if (auto *Const = dyn_cast(V)) { if (Const->getOpcode() != Instruction::ShuffleVector) return false; VTy = cast(Const->getType()); Mask = 
Const->getShuffleMask(); } else if (auto *Shuf = dyn_cast(V)) { VTy = Shuf->getType(); Mask = Shuf->getShuffleMask(); } else { return false; } // When the data type is <1 x Type>, it's not possible to differentiate // between the ComplexDeinterleaving::Deinterleave and // ComplexDeinterleaving::Splat operations. if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1) return false; return all_equal(Mask) && Mask[0] == 0; }; if (!IsSplat(R) || !IsSplat(I)) return nullptr; auto *Real = dyn_cast(R); auto *Imag = dyn_cast(I); if ((!Real && Imag) || (Real && !Imag)) return nullptr; if (Real && Imag) { // Non-constant splats should be in the same basic block if (Real->getParent() != Imag->getParent()) return nullptr; FinalInstructions.insert(Real); FinalInstructions.insert(Imag); } NodePtr PlaceholderNode = prepareCompositeNode(ComplexDeinterleavingOperation::Splat, R, I); return submitCompositeNode(PlaceholderNode); } ComplexDeinterleavingGraph::NodePtr ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real, Instruction *Imag) { if (Real != RealPHI || Imag != ImagPHI) return nullptr; PHIsFound = true; NodePtr PlaceholderNode = prepareCompositeNode( ComplexDeinterleavingOperation::ReductionPHI, Real, Imag); return submitCompositeNode(PlaceholderNode); } ComplexDeinterleavingGraph::NodePtr ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real, Instruction *Imag) { auto *SelectReal = dyn_cast(Real); auto *SelectImag = dyn_cast(Imag); if (!SelectReal || !SelectImag) return nullptr; Instruction *MaskA, *MaskB; Instruction *AR, *AI, *RA, *BI; if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR), m_Instruction(RA))) || !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI), m_Instruction(BI)))) return nullptr; if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB)) return nullptr; if (!MaskA->getType()->isVectorTy()) return nullptr; auto NodeA = identifyNode(AR, AI); if (!NodeA) return nullptr; auto NodeB = identifyNode(RA, 
BI); if (!NodeB) return nullptr; NodePtr PlaceholderNode = prepareCompositeNode( ComplexDeinterleavingOperation::ReductionSelect, Real, Imag); PlaceholderNode->addOperand(NodeA); PlaceholderNode->addOperand(NodeB); FinalInstructions.insert(MaskA); FinalInstructions.insert(MaskB); return submitCompositeNode(PlaceholderNode); } static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode, std::optional Flags, Value *InputA, Value *InputB) { Value *I; switch (Opcode) { case Instruction::FNeg: I = B.CreateFNeg(InputA); break; case Instruction::FAdd: I = B.CreateFAdd(InputA, InputB); break; case Instruction::Add: I = B.CreateAdd(InputA, InputB); break; case Instruction::FSub: I = B.CreateFSub(InputA, InputB); break; case Instruction::Sub: I = B.CreateSub(InputA, InputB); break; case Instruction::FMul: I = B.CreateFMul(InputA, InputB); break; case Instruction::Mul: I = B.CreateMul(InputA, InputB); break; default: llvm_unreachable("Incorrect symmetric opcode"); } if (Flags) cast(I)->setFastMathFlags(*Flags); return I; } Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder, RawNodePtr Node) { if (Node->ReplacementNode) return Node->ReplacementNode; auto ReplaceOperandIfExist = [&](RawNodePtr &Node, unsigned Idx) -> Value * { return Node->Operands.size() > Idx ? 
replaceNode(Builder, Node->Operands[Idx]) : nullptr; }; Value *ReplacementNode; switch (Node->Operation) { case ComplexDeinterleavingOperation::CAdd: case ComplexDeinterleavingOperation::CMulPartial: case ComplexDeinterleavingOperation::Symmetric: { Value *Input0 = ReplaceOperandIfExist(Node, 0); Value *Input1 = ReplaceOperandIfExist(Node, 1); Value *Accumulator = ReplaceOperandIfExist(Node, 2); assert(!Input1 || (Input0->getType() == Input1->getType() && "Node inputs need to be of the same type")); assert(!Accumulator || (Input0->getType() == Accumulator->getType() && "Accumulator and input need to be of the same type")); if (Node->Operation == ComplexDeinterleavingOperation::Symmetric) ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags, Input0, Input1); else ReplacementNode = TL->createComplexDeinterleavingIR( Builder, Node->Operation, Node->Rotation, Input0, Input1, Accumulator); break; } case ComplexDeinterleavingOperation::Deinterleave: llvm_unreachable("Deinterleave node should already have ReplacementNode"); break; case ComplexDeinterleavingOperation::Splat: { auto *NewTy = VectorType::getDoubleElementsVectorType( cast(Node->Real->getType())); auto *R = dyn_cast(Node->Real); auto *I = dyn_cast(Node->Imag); if (R && I) { // Splats that are not constant are interleaved where they are located Instruction *InsertPoint = (I->comesBefore(R) ? R : I)->getNextNode(); IRBuilder<> IRB(InsertPoint); ReplacementNode = IRB.CreateIntrinsic(Intrinsic::vector_interleave2, NewTy, {Node->Real, Node->Imag}); } else { ReplacementNode = Builder.CreateIntrinsic( Intrinsic::vector_interleave2, NewTy, {Node->Real, Node->Imag}); } break; } case ComplexDeinterleavingOperation::ReductionPHI: { // If Operation is ReductionPHI, a new empty PHINode is created. // It is filled later when the ReductionOperation is processed. 
auto *VTy = cast(Node->Real->getType()); auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHIIt()); OldToNewPHI[dyn_cast(Node->Real)] = NewPHI; ReplacementNode = NewPHI; break; } case ComplexDeinterleavingOperation::ReductionOperation: ReplacementNode = replaceNode(Builder, Node->Operands[0]); processReductionOperation(ReplacementNode, Node); break; case ComplexDeinterleavingOperation::ReductionSelect: { auto *MaskReal = cast(Node->Real)->getOperand(0); auto *MaskImag = cast(Node->Imag)->getOperand(0); auto *A = replaceNode(Builder, Node->Operands[0]); auto *B = replaceNode(Builder, Node->Operands[1]); auto *NewMaskTy = VectorType::getDoubleElementsVectorType( cast(MaskReal->getType())); auto *NewMask = Builder.CreateIntrinsic(Intrinsic::vector_interleave2, NewMaskTy, {MaskReal, MaskImag}); ReplacementNode = Builder.CreateSelect(NewMask, A, B); break; } } assert(ReplacementNode && "Target failed to create Intrinsic call."); NumComplexTransformations += 1; Node->ReplacementNode = ReplacementNode; return ReplacementNode; } void ComplexDeinterleavingGraph::processReductionOperation( Value *OperationReplacement, RawNodePtr Node) { auto *Real = cast(Node->Real); auto *Imag = cast(Node->Imag); auto *OldPHIReal = ReductionInfo[Real].first; auto *OldPHIImag = ReductionInfo[Imag].first; auto *NewPHI = OldToNewPHI[OldPHIReal]; auto *VTy = cast(Real->getType()); auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); // We have to interleave initial origin values coming from IncomingBlock Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming); Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming); IRBuilder<> Builder(Incoming->getTerminator()); auto *NewInit = Builder.CreateIntrinsic(Intrinsic::vector_interleave2, NewVTy, {InitReal, InitImag}); NewPHI->addIncoming(NewInit, Incoming); NewPHI->addIncoming(OperationReplacement, BackEdge); // Deinterleave complex vector 
outside of loop so that it can be finally // reduced auto *FinalReductionReal = ReductionInfo[Real].second; auto *FinalReductionImag = ReductionInfo[Imag].second; Builder.SetInsertPoint( &*FinalReductionReal->getParent()->getFirstInsertionPt()); auto *Deinterleave = Builder.CreateIntrinsic(Intrinsic::vector_deinterleave2, OperationReplacement->getType(), OperationReplacement); auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0); FinalReductionReal->replaceUsesOfWith(Real, NewReal); Builder.SetInsertPoint(FinalReductionImag); auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1); FinalReductionImag->replaceUsesOfWith(Imag, NewImag); } void ComplexDeinterleavingGraph::replaceNodes() { SmallVector DeadInstrRoots; for (auto *RootInstruction : OrderedRoots) { // Check if this potential root went through check process and we can // deinterleave it if (!RootToNode.count(RootInstruction)) continue; IRBuilder<> Builder(RootInstruction); auto RootNode = RootToNode[RootInstruction]; Value *R = replaceNode(Builder, RootNode.get()); if (RootNode->Operation == ComplexDeinterleavingOperation::ReductionOperation) { auto *RootReal = cast(RootNode->Real); auto *RootImag = cast(RootNode->Imag); ReductionInfo[RootReal].first->removeIncomingValue(BackEdge); ReductionInfo[RootImag].first->removeIncomingValue(BackEdge); DeadInstrRoots.push_back(cast(RootReal)); DeadInstrRoots.push_back(cast(RootImag)); } else { assert(R && "Unable to find replacement for RootInstruction"); DeadInstrRoots.push_back(RootInstruction); RootInstruction->replaceAllUsesWith(R); } } for (auto *I : DeadInstrRoots) RecursivelyDeleteTriviallyDeadInstructions(I, TLI); }