//===- ComplexDeinterleavingPass.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Identification:
// This step is responsible for finding the patterns that can be lowered to
// complex instructions, and building a graph to represent the complex
// structures. Starting from the "Converging Shuffle" (a shuffle that
// reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
// operands are evaluated and identified as "Composite Nodes" (collections of
// instructions that can potentially be lowered to a single complex
// instruction). This is performed by checking the real and imaginary components
// and tracking the data flow for each component while following the operand
// pairs. Each node is expected to be validated upon creation, and any
// validation error should halt traversal and prevent further graph
// construction.
// Instead of relying on Shuffle operations, vector interleaving and
// deinterleaving can be represented by the vector.interleave2 and
// vector.deinterleave2 intrinsics. Scalable vectors can be represented only by
// these intrinsics, whereas fixed-width vectors are recognized in both forms
// (shufflevector instructions and the intrinsics).
//
// Replacement:
// This step traverses the graph built up by identification, delegating to the
// target to validate and generate the correct intrinsics, and plumbs them
// together, connecting each end of the new intrinsics graph to the existing
// use-def chain. This step is assumed to finish successfully, as all
// information is expected to be correct by this point.
//
//
// Internal data structure:
// ComplexDeinterleavingGraph:
// Keeps references to all the valid CompositeNodes formed as part of the
// transformation, and every Instruction contained within said nodes. It also
// holds onto references to the root Instructions and the root nodes that
// should replace them.
//
// ComplexDeinterleavingCompositeNode:
// A CompositeNode represents a single transformation point; each node should
// transform into a single complex instruction (ignoring vector splitting, which
// would generate more instructions per node). Nodes are identified in a
// depth-first manner, traversing and identifying the operands of each
// instruction in the order they appear in the IR.
// Each node maintains a reference to its Real and Imaginary instructions,
// as well as any additional instructions that make up the identified operation
// (internal instructions should only have uses within their containing node).
// A Node also contains the rotation and operation type that it represents.
// Operands contains pointers to other CompositeNodes, acting as the edges in
// the graph. ReplacementNode is the transformed Value* that has been emitted
// to the IR.
//
// Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
// ReplacementNode fields of that Node are relevant, where the ReplacementNode
// should be pre-populated.
//
//===----------------------------------------------------------------------===//

#include "llvm/CodeGen/ComplexDeinterleavingPass.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/Local.h"
#include <algorithm>

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "complex-deinterleaving"

STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");

// Command-line switch (on by default) that lets the whole pass be disabled;
// checked at the start of ComplexDeinterleaving::runOnFunction.
static cl::opt<bool> ComplexDeinterleavingEnabled(
    "enable-complex-deinterleaving",
    cl::desc("Enable generation of complex instructions"), cl::init(true),
    cl::Hidden);

/// Checks the given mask, and determines whether said mask is interleaving.
///
/// To be interleaving, a mask must have an even length and alternate between
/// `i` and `i + (Length / 2)` for every `i` in `[0..Length/2)`, thereby
/// covering every index in `[0..Length)` (e.g. a 4-element interleaving mask
/// is <0, 2, 1, 3>).
static bool isInterleavingMask(ArrayRef<int> Mask);

/// Checks the given mask, and determines whether said mask is deinterleaving.
///
/// To be deinterleaving, a mask must increment in steps of 2 from its first
/// element, which selects the real (0) or imaginary (1) lane
/// (e.g. with an 8-element input, a deinterleaving mask is either
/// <0, 2, 4, 6> or <1, 3, 5, 7>). This function does not itself check which
/// starting offset is used.
static bool isDeinterleavingMask(ArrayRef<int> Mask);

/// Returns true if \p V is a negation, matching either the floating-point
/// form (m_FNeg) or the integer form (m_Neg).
static bool isNeg(Value *V);

/// Returns the value negated by \p V; \p V must satisfy isNeg().
108 static Value *getNegOperand(Value *V); 109 110 namespace { 111 template <typename T, typename IterT> 112 std::optional<T> findCommonBetweenCollections(IterT A, IterT B) { 113 auto Common = llvm::find_if(A, [B](T I) { return llvm::is_contained(B, I); }); 114 if (Common != A.end()) 115 return std::make_optional(*Common); 116 return std::nullopt; 117 } 118 119 class ComplexDeinterleavingLegacyPass : public FunctionPass { 120 public: 121 static char ID; 122 123 ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr) 124 : FunctionPass(ID), TM(TM) { 125 initializeComplexDeinterleavingLegacyPassPass( 126 *PassRegistry::getPassRegistry()); 127 } 128 129 StringRef getPassName() const override { 130 return "Complex Deinterleaving Pass"; 131 } 132 133 bool runOnFunction(Function &F) override; 134 void getAnalysisUsage(AnalysisUsage &AU) const override { 135 AU.addRequired<TargetLibraryInfoWrapperPass>(); 136 AU.setPreservesCFG(); 137 } 138 139 private: 140 const TargetMachine *TM; 141 }; 142 143 class ComplexDeinterleavingGraph; 144 struct ComplexDeinterleavingCompositeNode { 145 146 ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op, 147 Value *R, Value *I) 148 : Operation(Op), Real(R), Imag(I) {} 149 150 private: 151 friend class ComplexDeinterleavingGraph; 152 using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>; 153 using RawNodePtr = ComplexDeinterleavingCompositeNode *; 154 bool OperandsValid = true; 155 156 public: 157 ComplexDeinterleavingOperation Operation; 158 Value *Real; 159 Value *Imag; 160 161 // This two members are required exclusively for generating 162 // ComplexDeinterleavingOperation::Symmetric operations. 
163 unsigned Opcode; 164 std::optional<FastMathFlags> Flags; 165 166 ComplexDeinterleavingRotation Rotation = 167 ComplexDeinterleavingRotation::Rotation_0; 168 SmallVector<RawNodePtr> Operands; 169 Value *ReplacementNode = nullptr; 170 171 void addOperand(NodePtr Node) { 172 if (!Node || !Node.get()) 173 OperandsValid = false; 174 Operands.push_back(Node.get()); 175 } 176 177 void dump() { dump(dbgs()); } 178 void dump(raw_ostream &OS) { 179 auto PrintValue = [&](Value *V) { 180 if (V) { 181 OS << "\""; 182 V->print(OS, true); 183 OS << "\"\n"; 184 } else 185 OS << "nullptr\n"; 186 }; 187 auto PrintNodeRef = [&](RawNodePtr Ptr) { 188 if (Ptr) 189 OS << Ptr << "\n"; 190 else 191 OS << "nullptr\n"; 192 }; 193 194 OS << "- CompositeNode: " << this << "\n"; 195 OS << " Real: "; 196 PrintValue(Real); 197 OS << " Imag: "; 198 PrintValue(Imag); 199 OS << " ReplacementNode: "; 200 PrintValue(ReplacementNode); 201 OS << " Operation: " << (int)Operation << "\n"; 202 OS << " Rotation: " << ((int)Rotation * 90) << "\n"; 203 OS << " Operands: \n"; 204 for (const auto &Op : Operands) { 205 OS << " - "; 206 PrintNodeRef(Op); 207 } 208 } 209 210 bool areOperandsValid() { return OperandsValid; } 211 }; 212 213 class ComplexDeinterleavingGraph { 214 public: 215 struct Product { 216 Value *Multiplier; 217 Value *Multiplicand; 218 bool IsPositive; 219 }; 220 221 using Addend = std::pair<Value *, bool>; 222 using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr; 223 using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr; 224 225 // Helper struct for holding info about potential partial multiplication 226 // candidates 227 struct PartialMulCandidate { 228 Value *Common; 229 NodePtr Node; 230 unsigned RealIdx; 231 unsigned ImagIdx; 232 bool IsNodeInverted; 233 }; 234 235 explicit ComplexDeinterleavingGraph(const TargetLowering *TL, 236 const TargetLibraryInfo *TLI) 237 : TL(TL), TLI(TLI) {} 238 239 private: 240 const TargetLowering *TL = nullptr; 241 const 
TargetLibraryInfo *TLI = nullptr; 242 SmallVector<NodePtr> CompositeNodes; 243 DenseMap<std::pair<Value *, Value *>, NodePtr> CachedResult; 244 245 SmallPtrSet<Instruction *, 16> FinalInstructions; 246 247 /// Root instructions are instructions from which complex computation starts 248 std::map<Instruction *, NodePtr> RootToNode; 249 250 /// Topologically sorted root instructions 251 SmallVector<Instruction *, 1> OrderedRoots; 252 253 /// When examining a basic block for complex deinterleaving, if it is a simple 254 /// one-block loop, then the only incoming block is 'Incoming' and the 255 /// 'BackEdge' block is the block itself." 256 BasicBlock *BackEdge = nullptr; 257 BasicBlock *Incoming = nullptr; 258 259 /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction 260 /// %OutsideUser as it is shown in the IR: 261 /// 262 /// vector.body: 263 /// %PHInode = phi <vector type> [ zeroinitializer, %entry ], 264 /// [ %ReductionOp, %vector.body ] 265 /// ... 266 /// %ReductionOp = fadd i64 ... 267 /// ... 268 /// br i1 %condition, label %vector.body, %middle.block 269 /// 270 /// middle.block: 271 /// %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp) 272 /// 273 /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding 274 /// `llvm.vector.reduce.fadd` when unroll factor isn't one. 275 MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo; 276 277 /// In the process of detecting a reduction, we consider a pair of 278 /// %ReductionOP, which we refer to as real and imag (or vice versa), and 279 /// traverse the use-tree to detect complex operations. As this is a reduction 280 /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds 281 /// to the %ReductionOPs that we suspect to be complex. 282 /// RealPHI and ImagPHI are used by the identifyPHINode method. 
283 PHINode *RealPHI = nullptr; 284 PHINode *ImagPHI = nullptr; 285 286 /// Set this flag to true if RealPHI and ImagPHI were reached during reduction 287 /// detection. 288 bool PHIsFound = false; 289 290 /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode. 291 /// The new PHINode corresponds to a vector of deinterleaved complex numbers. 292 /// This mapping is populated during 293 /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then 294 /// used in the ComplexDeinterleavingOperation::ReductionOperation node 295 /// replacement process. 296 std::map<PHINode *, PHINode *> OldToNewPHI; 297 298 NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation, 299 Value *R, Value *I) { 300 assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI && 301 Operation != ComplexDeinterleavingOperation::ReductionOperation) || 302 (R && I)) && 303 "Reduction related nodes must have Real and Imaginary parts"); 304 return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R, 305 I); 306 } 307 308 NodePtr submitCompositeNode(NodePtr Node) { 309 CompositeNodes.push_back(Node); 310 if (Node->Real) 311 CachedResult[{Node->Real, Node->Imag}] = Node; 312 return Node; 313 } 314 315 /// Identifies a complex partial multiply pattern and its rotation, based on 316 /// the following patterns 317 /// 318 /// 0: r: cr + ar * br 319 /// i: ci + ar * bi 320 /// 90: r: cr - ai * bi 321 /// i: ci + ai * br 322 /// 180: r: cr - ar * br 323 /// i: ci - ar * bi 324 /// 270: r: cr + ai * bi 325 /// i: ci - ai * br 326 NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag); 327 328 /// Identify the other branch of a Partial Mul, taking the CommonOperandI that 329 /// is partially known from identifyPartialMul, filling in the other half of 330 /// the complex pair. 
331 NodePtr 332 identifyNodeWithImplicitAdd(Instruction *I, Instruction *J, 333 std::pair<Value *, Value *> &CommonOperandI); 334 335 /// Identifies a complex add pattern and its rotation, based on the following 336 /// patterns. 337 /// 338 /// 90: r: ar - bi 339 /// i: ai + br 340 /// 270: r: ar + bi 341 /// i: ai - br 342 NodePtr identifyAdd(Instruction *Real, Instruction *Imag); 343 NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag); 344 NodePtr identifyPartialReduction(Value *R, Value *I); 345 NodePtr identifyDotProduct(Value *Inst); 346 347 NodePtr identifyNode(Value *R, Value *I); 348 349 /// Determine if a sum of complex numbers can be formed from \p RealAddends 350 /// and \p ImagAddens. If \p Accumulator is not null, add the result to it. 351 /// Return nullptr if it is not possible to construct a complex number. 352 /// \p Flags are needed to generate symmetric Add and Sub operations. 353 NodePtr identifyAdditions(std::list<Addend> &RealAddends, 354 std::list<Addend> &ImagAddends, 355 std::optional<FastMathFlags> Flags, 356 NodePtr Accumulator); 357 358 /// Extract one addend that have both real and imaginary parts positive. 359 NodePtr extractPositiveAddend(std::list<Addend> &RealAddends, 360 std::list<Addend> &ImagAddends); 361 362 /// Determine if sum of multiplications of complex numbers can be formed from 363 /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result 364 /// to it. Return nullptr if it is not possible to construct a complex number. 365 NodePtr identifyMultiplications(std::vector<Product> &RealMuls, 366 std::vector<Product> &ImagMuls, 367 NodePtr Accumulator); 368 369 /// Go through pairs of multiplication (one Real and one Imag) and find all 370 /// possible candidates for partial multiplication and put them into \p 371 /// Candidates. 
Returns true if all Product has pair with common operand 372 bool collectPartialMuls(const std::vector<Product> &RealMuls, 373 const std::vector<Product> &ImagMuls, 374 std::vector<PartialMulCandidate> &Candidates); 375 376 /// If the code is compiled with -Ofast or expressions have `reassoc` flag, 377 /// the order of complex computation operations may be significantly altered, 378 /// and the real and imaginary parts may not be executed in parallel. This 379 /// function takes this into consideration and employs a more general approach 380 /// to identify complex computations. Initially, it gathers all the addends 381 /// and multiplicands and then constructs a complex expression from them. 382 NodePtr identifyReassocNodes(Instruction *I, Instruction *J); 383 384 NodePtr identifyRoot(Instruction *I); 385 386 /// Identifies the Deinterleave operation applied to a vector containing 387 /// complex numbers. There are two ways to represent the Deinterleave 388 /// operation: 389 /// * Using two shufflevectors with even indices for /pReal instruction and 390 /// odd indices for /pImag instructions (only for fixed-width vectors) 391 /// * Using two extractvalue instructions applied to `vector.deinterleave2` 392 /// intrinsic (for both fixed and scalable vectors) 393 NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag); 394 395 /// identifying the operation that represents a complex number repeated in a 396 /// Splat vector. There are two possible types of splats: ConstantExpr with 397 /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an 398 /// initialization mask with all values set to zero. 
399 NodePtr identifySplat(Value *Real, Value *Imag); 400 401 NodePtr identifyPHINode(Instruction *Real, Instruction *Imag); 402 403 /// Identifies SelectInsts in a loop that has reduction with predication masks 404 /// and/or predicated tail folding 405 NodePtr identifySelectNode(Instruction *Real, Instruction *Imag); 406 407 Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node); 408 409 /// Complete IR modifications after producing new reduction operation: 410 /// * Populate the PHINode generated for 411 /// ComplexDeinterleavingOperation::ReductionPHI 412 /// * Deinterleave the final value outside of the loop and repurpose original 413 /// reduction users 414 void processReductionOperation(Value *OperationReplacement, RawNodePtr Node); 415 void processReductionSingle(Value *OperationReplacement, RawNodePtr Node); 416 417 public: 418 void dump() { dump(dbgs()); } 419 void dump(raw_ostream &OS) { 420 for (const auto &Node : CompositeNodes) 421 Node->dump(OS); 422 } 423 424 /// Returns false if the deinterleaving operation should be cancelled for the 425 /// current graph. 426 bool identifyNodes(Instruction *RootI); 427 428 /// In case \pB is one-block loop, this function seeks potential reductions 429 /// and populates ReductionInfo. Returns true if any reductions were 430 /// identified. 431 bool collectPotentialReductions(BasicBlock *B); 432 433 void identifyReductionNodes(); 434 435 /// Check that every instruction, from the roots to the leaves, has internal 436 /// uses. 437 bool checkNodes(); 438 439 /// Perform the actual replacement of the underlying instruction graph. 
440 void replaceNodes(); 441 }; 442 443 class ComplexDeinterleaving { 444 public: 445 ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli) 446 : TL(tl), TLI(tli) {} 447 bool runOnFunction(Function &F); 448 449 private: 450 bool evaluateBasicBlock(BasicBlock *B); 451 452 const TargetLowering *TL = nullptr; 453 const TargetLibraryInfo *TLI = nullptr; 454 }; 455 456 } // namespace 457 458 char ComplexDeinterleavingLegacyPass::ID = 0; 459 460 INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE, 461 "Complex Deinterleaving", false, false) 462 INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE, 463 "Complex Deinterleaving", false, false) 464 465 PreservedAnalyses ComplexDeinterleavingPass::run(Function &F, 466 FunctionAnalysisManager &AM) { 467 const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering(); 468 auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F); 469 if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F)) 470 return PreservedAnalyses::all(); 471 472 PreservedAnalyses PA; 473 PA.preserve<FunctionAnalysisManagerModuleProxy>(); 474 return PA; 475 } 476 477 FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) { 478 return new ComplexDeinterleavingLegacyPass(TM); 479 } 480 481 bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) { 482 const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering(); 483 auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); 484 return ComplexDeinterleaving(TL, &TLI).runOnFunction(F); 485 } 486 487 bool ComplexDeinterleaving::runOnFunction(Function &F) { 488 if (!ComplexDeinterleavingEnabled) { 489 LLVM_DEBUG( 490 dbgs() << "Complex deinterleaving has been explicitly disabled.\n"); 491 return false; 492 } 493 494 if (!TL->isComplexDeinterleavingSupported()) { 495 LLVM_DEBUG( 496 dbgs() << "Complex deinterleaving has been disabled, target does " 497 "not support lowering of complex number operations.\n"); 498 return 
false; 499 } 500 501 bool Changed = false; 502 for (auto &B : F) 503 Changed |= evaluateBasicBlock(&B); 504 505 return Changed; 506 } 507 508 static bool isInterleavingMask(ArrayRef<int> Mask) { 509 // If the size is not even, it's not an interleaving mask 510 if ((Mask.size() & 1)) 511 return false; 512 513 int HalfNumElements = Mask.size() / 2; 514 for (int Idx = 0; Idx < HalfNumElements; ++Idx) { 515 int MaskIdx = Idx * 2; 516 if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements)) 517 return false; 518 } 519 520 return true; 521 } 522 523 static bool isDeinterleavingMask(ArrayRef<int> Mask) { 524 int Offset = Mask[0]; 525 int HalfNumElements = Mask.size() / 2; 526 527 for (int Idx = 1; Idx < HalfNumElements; ++Idx) { 528 if (Mask[Idx] != (Idx * 2) + Offset) 529 return false; 530 } 531 532 return true; 533 } 534 535 bool isNeg(Value *V) { 536 return match(V, m_FNeg(m_Value())) || match(V, m_Neg(m_Value())); 537 } 538 539 Value *getNegOperand(Value *V) { 540 assert(isNeg(V)); 541 auto *I = cast<Instruction>(V); 542 if (I->getOpcode() == Instruction::FNeg) 543 return I->getOperand(0); 544 545 return I->getOperand(1); 546 } 547 548 bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) { 549 ComplexDeinterleavingGraph Graph(TL, TLI); 550 if (Graph.collectPotentialReductions(B)) 551 Graph.identifyReductionNodes(); 552 553 for (auto &I : *B) 554 Graph.identifyNodes(&I); 555 556 if (Graph.checkNodes()) { 557 Graph.replaceNodes(); 558 return true; 559 } 560 561 return false; 562 } 563 564 ComplexDeinterleavingGraph::NodePtr 565 ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd( 566 Instruction *Real, Instruction *Imag, 567 std::pair<Value *, Value *> &PartialMatch) { 568 LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag 569 << "\n"); 570 571 if (!Real->hasOneUse() || !Imag->hasOneUse()) { 572 LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n"); 573 return nullptr; 574 } 575 576 if ((Real->getOpcode() 
!= Instruction::FMul && 577 Real->getOpcode() != Instruction::Mul) || 578 (Imag->getOpcode() != Instruction::FMul && 579 Imag->getOpcode() != Instruction::Mul)) { 580 LLVM_DEBUG( 581 dbgs() << " - Real or imaginary instruction is not fmul or mul\n"); 582 return nullptr; 583 } 584 585 Value *R0 = Real->getOperand(0); 586 Value *R1 = Real->getOperand(1); 587 Value *I0 = Imag->getOperand(0); 588 Value *I1 = Imag->getOperand(1); 589 590 // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the 591 // rotations and use the operand. 592 unsigned Negs = 0; 593 Value *Op; 594 if (match(R0, m_Neg(m_Value(Op)))) { 595 Negs |= 1; 596 R0 = Op; 597 } else if (match(R1, m_Neg(m_Value(Op)))) { 598 Negs |= 1; 599 R1 = Op; 600 } 601 602 if (isNeg(I0)) { 603 Negs |= 2; 604 Negs ^= 1; 605 I0 = Op; 606 } else if (match(I1, m_Neg(m_Value(Op)))) { 607 Negs |= 2; 608 Negs ^= 1; 609 I1 = Op; 610 } 611 612 ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs; 613 614 Value *CommonOperand; 615 Value *UncommonRealOp; 616 Value *UncommonImagOp; 617 618 if (R0 == I0 || R0 == I1) { 619 CommonOperand = R0; 620 UncommonRealOp = R1; 621 } else if (R1 == I0 || R1 == I1) { 622 CommonOperand = R1; 623 UncommonRealOp = R0; 624 } else { 625 LLVM_DEBUG(dbgs() << " - No equal operand\n"); 626 return nullptr; 627 } 628 629 UncommonImagOp = (CommonOperand == I0) ? I1 : I0; 630 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 || 631 Rotation == ComplexDeinterleavingRotation::Rotation_270) 632 std::swap(UncommonRealOp, UncommonImagOp); 633 634 // Between identifyPartialMul and here we need to have found a complete valid 635 // pair from the CommonOperand of each part. 
636 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 || 637 Rotation == ComplexDeinterleavingRotation::Rotation_180) 638 PartialMatch.first = CommonOperand; 639 else 640 PartialMatch.second = CommonOperand; 641 642 if (!PartialMatch.first || !PartialMatch.second) { 643 LLVM_DEBUG(dbgs() << " - Incomplete partial match\n"); 644 return nullptr; 645 } 646 647 NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second); 648 if (!CommonNode) { 649 LLVM_DEBUG(dbgs() << " - No CommonNode identified\n"); 650 return nullptr; 651 } 652 653 NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp); 654 if (!UncommonNode) { 655 LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n"); 656 return nullptr; 657 } 658 659 NodePtr Node = prepareCompositeNode( 660 ComplexDeinterleavingOperation::CMulPartial, Real, Imag); 661 Node->Rotation = Rotation; 662 Node->addOperand(CommonNode); 663 Node->addOperand(UncommonNode); 664 return submitCompositeNode(Node); 665 } 666 667 ComplexDeinterleavingGraph::NodePtr 668 ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real, 669 Instruction *Imag) { 670 LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag 671 << "\n"); 672 // Determine rotation 673 auto IsAdd = [](unsigned Op) { 674 return Op == Instruction::FAdd || Op == Instruction::Add; 675 }; 676 auto IsSub = [](unsigned Op) { 677 return Op == Instruction::FSub || Op == Instruction::Sub; 678 }; 679 ComplexDeinterleavingRotation Rotation; 680 if (IsAdd(Real->getOpcode()) && IsAdd(Imag->getOpcode())) 681 Rotation = ComplexDeinterleavingRotation::Rotation_0; 682 else if (IsSub(Real->getOpcode()) && IsAdd(Imag->getOpcode())) 683 Rotation = ComplexDeinterleavingRotation::Rotation_90; 684 else if (IsSub(Real->getOpcode()) && IsSub(Imag->getOpcode())) 685 Rotation = ComplexDeinterleavingRotation::Rotation_180; 686 else if (IsAdd(Real->getOpcode()) && IsSub(Imag->getOpcode())) 687 Rotation = ComplexDeinterleavingRotation::Rotation_270; 
688 else { 689 LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n"); 690 return nullptr; 691 } 692 693 if (isa<FPMathOperator>(Real) && 694 (!Real->getFastMathFlags().allowContract() || 695 !Imag->getFastMathFlags().allowContract())) { 696 LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n"); 697 return nullptr; 698 } 699 700 Value *CR = Real->getOperand(0); 701 Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1)); 702 if (!RealMulI) 703 return nullptr; 704 Value *CI = Imag->getOperand(0); 705 Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1)); 706 if (!ImagMulI) 707 return nullptr; 708 709 if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) { 710 LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n"); 711 return nullptr; 712 } 713 714 Value *R0 = RealMulI->getOperand(0); 715 Value *R1 = RealMulI->getOperand(1); 716 Value *I0 = ImagMulI->getOperand(0); 717 Value *I1 = ImagMulI->getOperand(1); 718 719 Value *CommonOperand; 720 Value *UncommonRealOp; 721 Value *UncommonImagOp; 722 723 if (R0 == I0 || R0 == I1) { 724 CommonOperand = R0; 725 UncommonRealOp = R1; 726 } else if (R1 == I0 || R1 == I1) { 727 CommonOperand = R1; 728 UncommonRealOp = R0; 729 } else { 730 LLVM_DEBUG(dbgs() << " - No equal operand\n"); 731 return nullptr; 732 } 733 734 UncommonImagOp = (CommonOperand == I0) ? I1 : I0; 735 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 || 736 Rotation == ComplexDeinterleavingRotation::Rotation_270) 737 std::swap(UncommonRealOp, UncommonImagOp); 738 739 std::pair<Value *, Value *> PartialMatch( 740 (Rotation == ComplexDeinterleavingRotation::Rotation_0 || 741 Rotation == ComplexDeinterleavingRotation::Rotation_180) 742 ? CommonOperand 743 : nullptr, 744 (Rotation == ComplexDeinterleavingRotation::Rotation_90 || 745 Rotation == ComplexDeinterleavingRotation::Rotation_270) 746 ? 
CommonOperand 747 : nullptr); 748 749 auto *CRInst = dyn_cast<Instruction>(CR); 750 auto *CIInst = dyn_cast<Instruction>(CI); 751 752 if (!CRInst || !CIInst) { 753 LLVM_DEBUG(dbgs() << " - Common operands are not instructions.\n"); 754 return nullptr; 755 } 756 757 NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch); 758 if (!CNode) { 759 LLVM_DEBUG(dbgs() << " - No cnode identified\n"); 760 return nullptr; 761 } 762 763 NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp); 764 if (!UncommonRes) { 765 LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n"); 766 return nullptr; 767 } 768 769 assert(PartialMatch.first && PartialMatch.second); 770 NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second); 771 if (!CommonRes) { 772 LLVM_DEBUG(dbgs() << " - No CommonRes identified\n"); 773 return nullptr; 774 } 775 776 NodePtr Node = prepareCompositeNode( 777 ComplexDeinterleavingOperation::CMulPartial, Real, Imag); 778 Node->Rotation = Rotation; 779 Node->addOperand(CommonRes); 780 Node->addOperand(UncommonRes); 781 Node->addOperand(CNode); 782 return submitCompositeNode(Node); 783 } 784 785 ComplexDeinterleavingGraph::NodePtr 786 ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) { 787 LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n"); 788 789 // Determine rotation 790 ComplexDeinterleavingRotation Rotation; 791 if ((Real->getOpcode() == Instruction::FSub && 792 Imag->getOpcode() == Instruction::FAdd) || 793 (Real->getOpcode() == Instruction::Sub && 794 Imag->getOpcode() == Instruction::Add)) 795 Rotation = ComplexDeinterleavingRotation::Rotation_90; 796 else if ((Real->getOpcode() == Instruction::FAdd && 797 Imag->getOpcode() == Instruction::FSub) || 798 (Real->getOpcode() == Instruction::Add && 799 Imag->getOpcode() == Instruction::Sub)) 800 Rotation = ComplexDeinterleavingRotation::Rotation_270; 801 else { 802 LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not 
assigned.\n"); 803 return nullptr; 804 } 805 806 auto *AR = dyn_cast<Instruction>(Real->getOperand(0)); 807 auto *BI = dyn_cast<Instruction>(Real->getOperand(1)); 808 auto *AI = dyn_cast<Instruction>(Imag->getOperand(0)); 809 auto *BR = dyn_cast<Instruction>(Imag->getOperand(1)); 810 811 if (!AR || !AI || !BR || !BI) { 812 LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n"); 813 return nullptr; 814 } 815 816 NodePtr ResA = identifyNode(AR, AI); 817 if (!ResA) { 818 LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n"); 819 return nullptr; 820 } 821 NodePtr ResB = identifyNode(BR, BI); 822 if (!ResB) { 823 LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n"); 824 return nullptr; 825 } 826 827 NodePtr Node = 828 prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag); 829 Node->Rotation = Rotation; 830 Node->addOperand(ResA); 831 Node->addOperand(ResB); 832 return submitCompositeNode(Node); 833 } 834 835 static bool isInstructionPairAdd(Instruction *A, Instruction *B) { 836 unsigned OpcA = A->getOpcode(); 837 unsigned OpcB = B->getOpcode(); 838 839 return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) || 840 (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) || 841 (OpcA == Instruction::Sub && OpcB == Instruction::Add) || 842 (OpcA == Instruction::Add && OpcB == Instruction::Sub); 843 } 844 845 static bool isInstructionPairMul(Instruction *A, Instruction *B) { 846 auto Pattern = 847 m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value())); 848 849 return match(A, Pattern) && match(B, Pattern); 850 } 851 852 static bool isInstructionPotentiallySymmetric(Instruction *I) { 853 switch (I->getOpcode()) { 854 case Instruction::FAdd: 855 case Instruction::FSub: 856 case Instruction::FMul: 857 case Instruction::FNeg: 858 case Instruction::Add: 859 case Instruction::Sub: 860 case Instruction::Mul: 861 return true; 862 default: 863 return false; 864 } 865 } 866 867 
/// Identifies a pair of instructions with the same opcode (one of those
/// accepted by isInstructionPotentiallySymmetric) that apply the same
/// operation to the real and imaginary parts. The corresponding operand pairs
/// must themselves be identifiable as nodes and, for floating-point
/// operations, both halves must carry identical fast-math flags. The resulting
/// Symmetric node records the opcode and flags used to re-emit the operation.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
                                                       Instruction *Imag) {
  if (Real->getOpcode() != Imag->getOpcode())
    return nullptr;

  if (!isInstructionPotentiallySymmetric(Real) ||
      !isInstructionPotentiallySymmetric(Imag))
    return nullptr;

  auto *R0 = Real->getOperand(0);
  auto *I0 = Imag->getOperand(0);

  NodePtr Op0 = identifyNode(R0, I0);
  NodePtr Op1 = nullptr;
  if (Op0 == nullptr)
    return nullptr;

  // Unary operations (e.g. FNeg) have no second operand pair to identify.
  if (Real->isBinaryOp()) {
    auto *R1 = Real->getOperand(1);
    auto *I1 = Imag->getOperand(1);
    Op1 = identifyNode(R1, I1);
    if (Op1 == nullptr)
      return nullptr;
  }

  // A single re-emitted operation can only carry one set of flags, so both
  // halves must agree.
  if (isa<FPMathOperator>(Real) &&
      Real->getFastMathFlags() != Imag->getFastMathFlags())
    return nullptr;

  auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric,
                                   Real, Imag);
  Node->Opcode = Real->getOpcode();
  if (isa<FPMathOperator>(Real))
    Node->Flags = Real->getFastMathFlags();

  Node->addOperand(Op0);
  if (Real->isBinaryOp())
    Node->addOperand(Op1);

  return submitCompositeNode(Node);
}

ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyDotProduct(Value *V) {

  if (!TL->isComplexDeinterleavingOperationSupported(
          ComplexDeinterleavingOperation::CDot, V->getType())) {
    LLVM_DEBUG(dbgs() << "Target doesn't support complex deinterleaving "
                         "operation CDot with the type "
                      << *V->getType() << "\n");
    return nullptr;
  }

  auto *Inst = cast<Instruction>(V);
  auto *RealUser = cast<Instruction>(*Inst->user_begin());

  NodePtr CN =
      prepareCompositeNode(ComplexDeinterleavingOperation::CDot, Inst, nullptr);

  NodePtr ANode;

  const Intrinsic::ID PartialReduceInt =
      Intrinsic::experimental_vector_partial_reduce_add;

  Value *AReal = nullptr;
  Value *AImag = nullptr;
  Value *BReal = nullptr;
  Value
*BImag = nullptr; 936 Value *Phi = nullptr; 937 938 auto UnwrapCast = [](Value *V) -> Value * { 939 if (auto *CI = dyn_cast<CastInst>(V)) 940 return CI->getOperand(0); 941 return V; 942 }; 943 944 auto PatternRot0 = m_Intrinsic<PartialReduceInt>( 945 m_Intrinsic<PartialReduceInt>(m_Value(Phi), 946 m_Mul(m_Value(BReal), m_Value(AReal))), 947 m_Neg(m_Mul(m_Value(BImag), m_Value(AImag)))); 948 949 auto PatternRot270 = m_Intrinsic<PartialReduceInt>( 950 m_Intrinsic<PartialReduceInt>( 951 m_Value(Phi), m_Neg(m_Mul(m_Value(BReal), m_Value(AImag)))), 952 m_Mul(m_Value(BImag), m_Value(AReal))); 953 954 if (match(Inst, PatternRot0)) { 955 CN->Rotation = ComplexDeinterleavingRotation::Rotation_0; 956 } else if (match(Inst, PatternRot270)) { 957 CN->Rotation = ComplexDeinterleavingRotation::Rotation_270; 958 } else { 959 Value *A0, *A1; 960 // The rotations 90 and 180 share the same operation pattern, so inspect the 961 // order of the operands, identifying where the real and imaginary 962 // components of A go, to discern between the aforementioned rotations. 
963 auto PatternRot90Rot180 = m_Intrinsic<PartialReduceInt>( 964 m_Intrinsic<PartialReduceInt>(m_Value(Phi), 965 m_Mul(m_Value(BReal), m_Value(A0))), 966 m_Mul(m_Value(BImag), m_Value(A1))); 967 968 if (!match(Inst, PatternRot90Rot180)) 969 return nullptr; 970 971 A0 = UnwrapCast(A0); 972 A1 = UnwrapCast(A1); 973 974 // Test if A0 is real/A1 is imag 975 ANode = identifyNode(A0, A1); 976 if (!ANode) { 977 // Test if A0 is imag/A1 is real 978 ANode = identifyNode(A1, A0); 979 // Unable to identify operand components, thus unable to identify rotation 980 if (!ANode) 981 return nullptr; 982 CN->Rotation = ComplexDeinterleavingRotation::Rotation_90; 983 AReal = A1; 984 AImag = A0; 985 } else { 986 AReal = A0; 987 AImag = A1; 988 CN->Rotation = ComplexDeinterleavingRotation::Rotation_180; 989 } 990 } 991 992 AReal = UnwrapCast(AReal); 993 AImag = UnwrapCast(AImag); 994 BReal = UnwrapCast(BReal); 995 BImag = UnwrapCast(BImag); 996 997 VectorType *VTy = cast<VectorType>(V->getType()); 998 Type *ExpectedOperandTy = VectorType::getSubdividedVectorType(VTy, 2); 999 if (AReal->getType() != ExpectedOperandTy) 1000 return nullptr; 1001 if (AImag->getType() != ExpectedOperandTy) 1002 return nullptr; 1003 if (BReal->getType() != ExpectedOperandTy) 1004 return nullptr; 1005 if (BImag->getType() != ExpectedOperandTy) 1006 return nullptr; 1007 1008 if (Phi->getType() != VTy && RealUser->getType() != VTy) 1009 return nullptr; 1010 1011 NodePtr Node = identifyNode(AReal, AImag); 1012 1013 // In the case that a node was identified to figure out the rotation, ensure 1014 // that trying to identify a node with AReal and AImag post-unwrap results in 1015 // the same node 1016 if (ANode && Node != ANode) { 1017 LLVM_DEBUG( 1018 dbgs() 1019 << "Identified node is different from previously identified node. 
" 1020 "Unable to confidently generate a complex operation node\n"); 1021 return nullptr; 1022 } 1023 1024 CN->addOperand(Node); 1025 CN->addOperand(identifyNode(BReal, BImag)); 1026 CN->addOperand(identifyNode(Phi, RealUser)); 1027 1028 return submitCompositeNode(CN); 1029 } 1030 1031 ComplexDeinterleavingGraph::NodePtr 1032 ComplexDeinterleavingGraph::identifyPartialReduction(Value *R, Value *I) { 1033 // Partial reductions don't support non-vector types, so check these first 1034 if (!isa<VectorType>(R->getType()) || !isa<VectorType>(I->getType())) 1035 return nullptr; 1036 1037 if (!R->hasUseList() || !I->hasUseList()) 1038 return nullptr; 1039 1040 auto CommonUser = 1041 findCommonBetweenCollections<Value *>(R->users(), I->users()); 1042 if (!CommonUser) 1043 return nullptr; 1044 1045 auto *IInst = dyn_cast<IntrinsicInst>(*CommonUser); 1046 if (!IInst || IInst->getIntrinsicID() != 1047 Intrinsic::experimental_vector_partial_reduce_add) 1048 return nullptr; 1049 1050 if (NodePtr CN = identifyDotProduct(IInst)) 1051 return CN; 1052 1053 return nullptr; 1054 } 1055 1056 ComplexDeinterleavingGraph::NodePtr 1057 ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) { 1058 auto It = CachedResult.find({R, I}); 1059 if (It != CachedResult.end()) { 1060 LLVM_DEBUG(dbgs() << " - Folding to existing node\n"); 1061 return It->second; 1062 } 1063 1064 if (NodePtr CN = identifyPartialReduction(R, I)) 1065 return CN; 1066 1067 bool IsReduction = RealPHI == R && (!ImagPHI || ImagPHI == I); 1068 if (!IsReduction && R->getType() != I->getType()) 1069 return nullptr; 1070 1071 if (NodePtr CN = identifySplat(R, I)) 1072 return CN; 1073 1074 auto *Real = dyn_cast<Instruction>(R); 1075 auto *Imag = dyn_cast<Instruction>(I); 1076 if (!Real || !Imag) 1077 return nullptr; 1078 1079 if (NodePtr CN = identifyDeinterleave(Real, Imag)) 1080 return CN; 1081 1082 if (NodePtr CN = identifyPHINode(Real, Imag)) 1083 return CN; 1084 1085 if (NodePtr CN = identifySelectNode(Real, Imag)) 
1086 return CN; 1087 1088 auto *VTy = cast<VectorType>(Real->getType()); 1089 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); 1090 1091 bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported( 1092 ComplexDeinterleavingOperation::CMulPartial, NewVTy); 1093 bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported( 1094 ComplexDeinterleavingOperation::CAdd, NewVTy); 1095 1096 if (HasCMulSupport && isInstructionPairMul(Real, Imag)) { 1097 if (NodePtr CN = identifyPartialMul(Real, Imag)) 1098 return CN; 1099 } 1100 1101 if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) { 1102 if (NodePtr CN = identifyAdd(Real, Imag)) 1103 return CN; 1104 } 1105 1106 if (HasCMulSupport && HasCAddSupport) { 1107 if (NodePtr CN = identifyReassocNodes(Real, Imag)) 1108 return CN; 1109 } 1110 1111 if (NodePtr CN = identifySymmetricOperation(Real, Imag)) 1112 return CN; 1113 1114 LLVM_DEBUG(dbgs() << " - Not recognised as a valid pattern.\n"); 1115 CachedResult[{R, I}] = nullptr; 1116 return nullptr; 1117 } 1118 1119 ComplexDeinterleavingGraph::NodePtr 1120 ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real, 1121 Instruction *Imag) { 1122 auto IsOperationSupported = [](unsigned Opcode) -> bool { 1123 return Opcode == Instruction::FAdd || Opcode == Instruction::FSub || 1124 Opcode == Instruction::FNeg || Opcode == Instruction::Add || 1125 Opcode == Instruction::Sub; 1126 }; 1127 1128 if (!IsOperationSupported(Real->getOpcode()) || 1129 !IsOperationSupported(Imag->getOpcode())) 1130 return nullptr; 1131 1132 std::optional<FastMathFlags> Flags; 1133 if (isa<FPMathOperator>(Real)) { 1134 if (Real->getFastMathFlags() != Imag->getFastMathFlags()) { 1135 LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are " 1136 "not identical\n"); 1137 return nullptr; 1138 } 1139 1140 Flags = Real->getFastMathFlags(); 1141 if (!Flags->allowReassoc()) { 1142 LLVM_DEBUG( 1143 dbgs() 1144 << "the 'Reassoc' attribute is missing in the 
FastMath flags\n"); 1145 return nullptr; 1146 } 1147 } 1148 1149 // Collect multiplications and addend instructions from the given instruction 1150 // while traversing it operands. Additionally, verify that all instructions 1151 // have the same fast math flags. 1152 auto Collect = [&Flags](Instruction *Insn, std::vector<Product> &Muls, 1153 std::list<Addend> &Addends) -> bool { 1154 SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}}; 1155 SmallPtrSet<Value *, 8> Visited; 1156 while (!Worklist.empty()) { 1157 auto [V, IsPositive] = Worklist.pop_back_val(); 1158 if (!Visited.insert(V).second) 1159 continue; 1160 1161 Instruction *I = dyn_cast<Instruction>(V); 1162 if (!I) { 1163 Addends.emplace_back(V, IsPositive); 1164 continue; 1165 } 1166 1167 // If an instruction has more than one user, it indicates that it either 1168 // has an external user, which will be later checked by the checkNodes 1169 // function, or it is a subexpression utilized by multiple expressions. In 1170 // the latter case, we will attempt to separately identify the complex 1171 // operation from here in order to create a shared 1172 // ComplexDeinterleavingCompositeNode. 
1173 if (I != Insn && I->hasNUsesOrMore(2)) { 1174 LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n"); 1175 Addends.emplace_back(I, IsPositive); 1176 continue; 1177 } 1178 switch (I->getOpcode()) { 1179 case Instruction::FAdd: 1180 case Instruction::Add: 1181 Worklist.emplace_back(I->getOperand(1), IsPositive); 1182 Worklist.emplace_back(I->getOperand(0), IsPositive); 1183 break; 1184 case Instruction::FSub: 1185 Worklist.emplace_back(I->getOperand(1), !IsPositive); 1186 Worklist.emplace_back(I->getOperand(0), IsPositive); 1187 break; 1188 case Instruction::Sub: 1189 if (isNeg(I)) { 1190 Worklist.emplace_back(getNegOperand(I), !IsPositive); 1191 } else { 1192 Worklist.emplace_back(I->getOperand(1), !IsPositive); 1193 Worklist.emplace_back(I->getOperand(0), IsPositive); 1194 } 1195 break; 1196 case Instruction::FMul: 1197 case Instruction::Mul: { 1198 Value *A, *B; 1199 if (isNeg(I->getOperand(0))) { 1200 A = getNegOperand(I->getOperand(0)); 1201 IsPositive = !IsPositive; 1202 } else { 1203 A = I->getOperand(0); 1204 } 1205 1206 if (isNeg(I->getOperand(1))) { 1207 B = getNegOperand(I->getOperand(1)); 1208 IsPositive = !IsPositive; 1209 } else { 1210 B = I->getOperand(1); 1211 } 1212 Muls.push_back(Product{A, B, IsPositive}); 1213 break; 1214 } 1215 case Instruction::FNeg: 1216 Worklist.emplace_back(I->getOperand(0), !IsPositive); 1217 break; 1218 default: 1219 Addends.emplace_back(I, IsPositive); 1220 continue; 1221 } 1222 1223 if (Flags && I->getFastMathFlags() != *Flags) { 1224 LLVM_DEBUG(dbgs() << "The instruction's fast math flags are " 1225 "inconsistent with the root instructions' flags: " 1226 << *I << "\n"); 1227 return false; 1228 } 1229 } 1230 return true; 1231 }; 1232 1233 std::vector<Product> RealMuls, ImagMuls; 1234 std::list<Addend> RealAddends, ImagAddends; 1235 if (!Collect(Real, RealMuls, RealAddends) || 1236 !Collect(Imag, ImagMuls, ImagAddends)) 1237 return nullptr; 1238 1239 if (RealAddends.size() != ImagAddends.size()) 1240 
return nullptr; 1241 1242 NodePtr FinalNode; 1243 if (!RealMuls.empty() || !ImagMuls.empty()) { 1244 // If there are multiplicands, extract positive addend and use it as an 1245 // accumulator 1246 FinalNode = extractPositiveAddend(RealAddends, ImagAddends); 1247 FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode); 1248 if (!FinalNode) 1249 return nullptr; 1250 } 1251 1252 // Identify and process remaining additions 1253 if (!RealAddends.empty() || !ImagAddends.empty()) { 1254 FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode); 1255 if (!FinalNode) 1256 return nullptr; 1257 } 1258 assert(FinalNode && "FinalNode can not be nullptr here"); 1259 // Set the Real and Imag fields of the final node and submit it 1260 FinalNode->Real = Real; 1261 FinalNode->Imag = Imag; 1262 submitCompositeNode(FinalNode); 1263 return FinalNode; 1264 } 1265 1266 bool ComplexDeinterleavingGraph::collectPartialMuls( 1267 const std::vector<Product> &RealMuls, const std::vector<Product> &ImagMuls, 1268 std::vector<PartialMulCandidate> &PartialMulCandidates) { 1269 // Helper function to extract a common operand from two products 1270 auto FindCommonInstruction = [](const Product &Real, 1271 const Product &Imag) -> Value * { 1272 if (Real.Multiplicand == Imag.Multiplicand || 1273 Real.Multiplicand == Imag.Multiplier) 1274 return Real.Multiplicand; 1275 1276 if (Real.Multiplier == Imag.Multiplicand || 1277 Real.Multiplier == Imag.Multiplier) 1278 return Real.Multiplier; 1279 1280 return nullptr; 1281 }; 1282 1283 // Iterating over real and imaginary multiplications to find common operands 1284 // If a common operand is found, a partial multiplication candidate is created 1285 // and added to the candidates vector The function returns false if no common 1286 // operands are found for any product 1287 for (unsigned i = 0; i < RealMuls.size(); ++i) { 1288 bool FoundCommon = false; 1289 for (unsigned j = 0; j < ImagMuls.size(); ++j) { 1290 auto *Common = 
FindCommonInstruction(RealMuls[i], ImagMuls[j]); 1291 if (!Common) 1292 continue; 1293 1294 auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier 1295 : RealMuls[i].Multiplicand; 1296 auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier 1297 : ImagMuls[j].Multiplicand; 1298 1299 auto Node = identifyNode(A, B); 1300 if (Node) { 1301 FoundCommon = true; 1302 PartialMulCandidates.push_back({Common, Node, i, j, false}); 1303 } 1304 1305 Node = identifyNode(B, A); 1306 if (Node) { 1307 FoundCommon = true; 1308 PartialMulCandidates.push_back({Common, Node, i, j, true}); 1309 } 1310 } 1311 if (!FoundCommon) 1312 return false; 1313 } 1314 return true; 1315 } 1316 1317 ComplexDeinterleavingGraph::NodePtr 1318 ComplexDeinterleavingGraph::identifyMultiplications( 1319 std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls, 1320 NodePtr Accumulator = nullptr) { 1321 if (RealMuls.size() != ImagMuls.size()) 1322 return nullptr; 1323 1324 std::vector<PartialMulCandidate> Info; 1325 if (!collectPartialMuls(RealMuls, ImagMuls, Info)) 1326 return nullptr; 1327 1328 // Map to store common instruction to node pointers 1329 std::map<Value *, NodePtr> CommonToNode; 1330 std::vector<bool> Processed(Info.size(), false); 1331 for (unsigned I = 0; I < Info.size(); ++I) { 1332 if (Processed[I]) 1333 continue; 1334 1335 PartialMulCandidate &InfoA = Info[I]; 1336 for (unsigned J = I + 1; J < Info.size(); ++J) { 1337 if (Processed[J]) 1338 continue; 1339 1340 PartialMulCandidate &InfoB = Info[J]; 1341 auto *InfoReal = &InfoA; 1342 auto *InfoImag = &InfoB; 1343 1344 auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common); 1345 if (!NodeFromCommon) { 1346 std::swap(InfoReal, InfoImag); 1347 NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common); 1348 } 1349 if (!NodeFromCommon) 1350 continue; 1351 1352 CommonToNode[InfoReal->Common] = NodeFromCommon; 1353 CommonToNode[InfoImag->Common] = NodeFromCommon; 1354 Processed[I] = true; 
1355 Processed[J] = true; 1356 } 1357 } 1358 1359 std::vector<bool> ProcessedReal(RealMuls.size(), false); 1360 std::vector<bool> ProcessedImag(ImagMuls.size(), false); 1361 NodePtr Result = Accumulator; 1362 for (auto &PMI : Info) { 1363 if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx]) 1364 continue; 1365 1366 auto It = CommonToNode.find(PMI.Common); 1367 // TODO: Process independent complex multiplications. Cases like this: 1368 // A.real() * B where both A and B are complex numbers. 1369 if (It == CommonToNode.end()) { 1370 LLVM_DEBUG({ 1371 dbgs() << "Unprocessed independent partial multiplication:\n"; 1372 for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]}) 1373 dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier 1374 << " multiplied by " << *Mul->Multiplicand << "\n"; 1375 }); 1376 return nullptr; 1377 } 1378 1379 auto &RealMul = RealMuls[PMI.RealIdx]; 1380 auto &ImagMul = ImagMuls[PMI.ImagIdx]; 1381 1382 auto NodeA = It->second; 1383 auto NodeB = PMI.Node; 1384 auto IsMultiplicandReal = PMI.Common == NodeA->Real; 1385 // The following table illustrates the relationship between multiplications 1386 // and rotations. 
If we consider the multiplication (X + iY) * (U + iV), we 1387 // can see: 1388 // 1389 // Rotation | Real | Imag | 1390 // ---------+--------+--------+ 1391 // 0 | x * u | x * v | 1392 // 90 | -y * v | y * u | 1393 // 180 | -x * u | -x * v | 1394 // 270 | y * v | -y * u | 1395 // 1396 // Check if the candidate can indeed be represented by partial 1397 // multiplication 1398 // TODO: Add support for multiplication by complex one 1399 if ((IsMultiplicandReal && PMI.IsNodeInverted) || 1400 (!IsMultiplicandReal && !PMI.IsNodeInverted)) 1401 continue; 1402 1403 // Determine the rotation based on the multiplications 1404 ComplexDeinterleavingRotation Rotation; 1405 if (IsMultiplicandReal) { 1406 // Detect 0 and 180 degrees rotation 1407 if (RealMul.IsPositive && ImagMul.IsPositive) 1408 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_0; 1409 else if (!RealMul.IsPositive && !ImagMul.IsPositive) 1410 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_180; 1411 else 1412 continue; 1413 1414 } else { 1415 // Detect 90 and 270 degrees rotation 1416 if (!RealMul.IsPositive && ImagMul.IsPositive) 1417 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_90; 1418 else if (RealMul.IsPositive && !ImagMul.IsPositive) 1419 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_270; 1420 else 1421 continue; 1422 } 1423 1424 LLVM_DEBUG({ 1425 dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n"; 1426 dbgs().indent(4) << "X: " << *NodeA->Real << "\n"; 1427 dbgs().indent(4) << "Y: " << *NodeA->Imag << "\n"; 1428 dbgs().indent(4) << "U: " << *NodeB->Real << "\n"; 1429 dbgs().indent(4) << "V: " << *NodeB->Imag << "\n"; 1430 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n"; 1431 }); 1432 1433 NodePtr NodeMul = prepareCompositeNode( 1434 ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr); 1435 NodeMul->Rotation = Rotation; 1436 NodeMul->addOperand(NodeA); 1437 NodeMul->addOperand(NodeB); 1438 if (Result) 1439 
NodeMul->addOperand(Result); 1440 submitCompositeNode(NodeMul); 1441 Result = NodeMul; 1442 ProcessedReal[PMI.RealIdx] = true; 1443 ProcessedImag[PMI.ImagIdx] = true; 1444 } 1445 1446 // Ensure all products have been processed, if not return nullptr. 1447 if (!all_of(ProcessedReal, [](bool V) { return V; }) || 1448 !all_of(ProcessedImag, [](bool V) { return V; })) { 1449 1450 // Dump debug information about which partial multiplications are not 1451 // processed. 1452 LLVM_DEBUG({ 1453 dbgs() << "Unprocessed products (Real):\n"; 1454 for (size_t i = 0; i < ProcessedReal.size(); ++i) { 1455 if (!ProcessedReal[i]) 1456 dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-") 1457 << *RealMuls[i].Multiplier << " multiplied by " 1458 << *RealMuls[i].Multiplicand << "\n"; 1459 } 1460 dbgs() << "Unprocessed products (Imag):\n"; 1461 for (size_t i = 0; i < ProcessedImag.size(); ++i) { 1462 if (!ProcessedImag[i]) 1463 dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-") 1464 << *ImagMuls[i].Multiplier << " multiplied by " 1465 << *ImagMuls[i].Multiplicand << "\n"; 1466 } 1467 }); 1468 return nullptr; 1469 } 1470 1471 return Result; 1472 } 1473 1474 ComplexDeinterleavingGraph::NodePtr 1475 ComplexDeinterleavingGraph::identifyAdditions( 1476 std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends, 1477 std::optional<FastMathFlags> Flags, NodePtr Accumulator = nullptr) { 1478 if (RealAddends.size() != ImagAddends.size()) 1479 return nullptr; 1480 1481 NodePtr Result; 1482 // If we have accumulator use it as first addend 1483 if (Accumulator) 1484 Result = Accumulator; 1485 // Otherwise find an element with both positive real and imaginary parts. 
1486 else 1487 Result = extractPositiveAddend(RealAddends, ImagAddends); 1488 1489 if (!Result) 1490 return nullptr; 1491 1492 while (!RealAddends.empty()) { 1493 auto ItR = RealAddends.begin(); 1494 auto [R, IsPositiveR] = *ItR; 1495 1496 bool FoundImag = false; 1497 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) { 1498 auto [I, IsPositiveI] = *ItI; 1499 ComplexDeinterleavingRotation Rotation; 1500 if (IsPositiveR && IsPositiveI) 1501 Rotation = ComplexDeinterleavingRotation::Rotation_0; 1502 else if (!IsPositiveR && IsPositiveI) 1503 Rotation = ComplexDeinterleavingRotation::Rotation_90; 1504 else if (!IsPositiveR && !IsPositiveI) 1505 Rotation = ComplexDeinterleavingRotation::Rotation_180; 1506 else 1507 Rotation = ComplexDeinterleavingRotation::Rotation_270; 1508 1509 NodePtr AddNode; 1510 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 || 1511 Rotation == ComplexDeinterleavingRotation::Rotation_180) { 1512 AddNode = identifyNode(R, I); 1513 } else { 1514 AddNode = identifyNode(I, R); 1515 } 1516 if (AddNode) { 1517 LLVM_DEBUG({ 1518 dbgs() << "Identified addition:\n"; 1519 dbgs().indent(4) << "X: " << *R << "\n"; 1520 dbgs().indent(4) << "Y: " << *I << "\n"; 1521 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n"; 1522 }); 1523 1524 NodePtr TmpNode; 1525 if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0) { 1526 TmpNode = prepareCompositeNode( 1527 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr); 1528 if (Flags) { 1529 TmpNode->Opcode = Instruction::FAdd; 1530 TmpNode->Flags = *Flags; 1531 } else { 1532 TmpNode->Opcode = Instruction::Add; 1533 } 1534 } else if (Rotation == 1535 llvm::ComplexDeinterleavingRotation::Rotation_180) { 1536 TmpNode = prepareCompositeNode( 1537 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr); 1538 if (Flags) { 1539 TmpNode->Opcode = Instruction::FSub; 1540 TmpNode->Flags = *Flags; 1541 } else { 1542 TmpNode->Opcode = Instruction::Sub; 1543 } 1544 } 
else { 1545 TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, 1546 nullptr, nullptr); 1547 TmpNode->Rotation = Rotation; 1548 } 1549 1550 TmpNode->addOperand(Result); 1551 TmpNode->addOperand(AddNode); 1552 submitCompositeNode(TmpNode); 1553 Result = TmpNode; 1554 RealAddends.erase(ItR); 1555 ImagAddends.erase(ItI); 1556 FoundImag = true; 1557 break; 1558 } 1559 } 1560 if (!FoundImag) 1561 return nullptr; 1562 } 1563 return Result; 1564 } 1565 1566 ComplexDeinterleavingGraph::NodePtr 1567 ComplexDeinterleavingGraph::extractPositiveAddend( 1568 std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) { 1569 for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) { 1570 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) { 1571 auto [R, IsPositiveR] = *ItR; 1572 auto [I, IsPositiveI] = *ItI; 1573 if (IsPositiveR && IsPositiveI) { 1574 auto Result = identifyNode(R, I); 1575 if (Result) { 1576 RealAddends.erase(ItR); 1577 ImagAddends.erase(ItI); 1578 return Result; 1579 } 1580 } 1581 } 1582 } 1583 return nullptr; 1584 } 1585 1586 bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) { 1587 // This potential root instruction might already have been recognized as 1588 // reduction. Because RootToNode maps both Real and Imaginary parts to 1589 // CompositeNode we should choose only one either Real or Imag instruction to 1590 // use as an anchor for generating complex instruction. 1591 auto It = RootToNode.find(RootI); 1592 if (It != RootToNode.end()) { 1593 auto RootNode = It->second; 1594 assert(RootNode->Operation == 1595 ComplexDeinterleavingOperation::ReductionOperation || 1596 RootNode->Operation == 1597 ComplexDeinterleavingOperation::ReductionSingle); 1598 // Find out which part, Real or Imag, comes later, and only if we come to 1599 // the latest part, add it to OrderedRoots. 1600 auto *R = cast<Instruction>(RootNode->Real); 1601 auto *I = RootNode->Imag ? 
cast<Instruction>(RootNode->Imag) : nullptr; 1602 1603 Instruction *ReplacementAnchor; 1604 if (I) 1605 ReplacementAnchor = R->comesBefore(I) ? I : R; 1606 else 1607 ReplacementAnchor = R; 1608 1609 if (ReplacementAnchor != RootI) 1610 return false; 1611 OrderedRoots.push_back(RootI); 1612 return true; 1613 } 1614 1615 auto RootNode = identifyRoot(RootI); 1616 if (!RootNode) 1617 return false; 1618 1619 LLVM_DEBUG({ 1620 Function *F = RootI->getFunction(); 1621 BasicBlock *B = RootI->getParent(); 1622 dbgs() << "Complex deinterleaving graph for " << F->getName() 1623 << "::" << B->getName() << ".\n"; 1624 dump(dbgs()); 1625 dbgs() << "\n"; 1626 }); 1627 RootToNode[RootI] = RootNode; 1628 OrderedRoots.push_back(RootI); 1629 return true; 1630 } 1631 1632 bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) { 1633 bool FoundPotentialReduction = false; 1634 1635 auto *Br = dyn_cast<BranchInst>(B->getTerminator()); 1636 if (!Br || Br->getNumSuccessors() != 2) 1637 return false; 1638 1639 // Identify simple one-block loop 1640 if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B) 1641 return false; 1642 1643 for (auto &PHI : B->phis()) { 1644 if (PHI.getNumIncomingValues() != 2) 1645 continue; 1646 1647 if (!PHI.getType()->isVectorTy()) 1648 continue; 1649 1650 auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B)); 1651 if (!ReductionOp) 1652 continue; 1653 1654 // Check if final instruction is reduced outside of current block 1655 Instruction *FinalReduction = nullptr; 1656 auto NumUsers = 0u; 1657 for (auto *U : ReductionOp->users()) { 1658 ++NumUsers; 1659 if (U == &PHI) 1660 continue; 1661 FinalReduction = dyn_cast<Instruction>(U); 1662 } 1663 1664 if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B || 1665 isa<PHINode>(FinalReduction)) 1666 continue; 1667 1668 ReductionInfo[ReductionOp] = {&PHI, FinalReduction}; 1669 BackEdge = B; 1670 auto BackEdgeIdx = PHI.getBasicBlockIndex(B); 1671 auto 
IncomingIdx = BackEdgeIdx == 0 ? 1 : 0; 1672 Incoming = PHI.getIncomingBlock(IncomingIdx); 1673 FoundPotentialReduction = true; 1674 1675 // If the initial value of PHINode is an Instruction, consider it a leaf 1676 // value of a complex deinterleaving graph. 1677 if (auto *InitPHI = 1678 dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming))) 1679 FinalInstructions.insert(InitPHI); 1680 } 1681 return FoundPotentialReduction; 1682 } 1683 1684 void ComplexDeinterleavingGraph::identifyReductionNodes() { 1685 SmallVector<bool> Processed(ReductionInfo.size(), false); 1686 SmallVector<Instruction *> OperationInstruction; 1687 for (auto &P : ReductionInfo) 1688 OperationInstruction.push_back(P.first); 1689 1690 // Identify a complex computation by evaluating two reduction operations that 1691 // potentially could be involved 1692 for (size_t i = 0; i < OperationInstruction.size(); ++i) { 1693 if (Processed[i]) 1694 continue; 1695 for (size_t j = i + 1; j < OperationInstruction.size(); ++j) { 1696 if (Processed[j]) 1697 continue; 1698 auto *Real = OperationInstruction[i]; 1699 auto *Imag = OperationInstruction[j]; 1700 if (Real->getType() != Imag->getType()) 1701 continue; 1702 1703 RealPHI = ReductionInfo[Real].first; 1704 ImagPHI = ReductionInfo[Imag].first; 1705 PHIsFound = false; 1706 auto Node = identifyNode(Real, Imag); 1707 if (!Node) { 1708 std::swap(Real, Imag); 1709 std::swap(RealPHI, ImagPHI); 1710 Node = identifyNode(Real, Imag); 1711 } 1712 1713 // If a node is identified and reduction PHINode is used in the chain of 1714 // operations, mark its operation instructions as used to prevent 1715 // re-identification and attach the node to the real part 1716 if (Node && PHIsFound) { 1717 LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: " 1718 << *Real << " / " << *Imag << "\n"); 1719 Processed[i] = true; 1720 Processed[j] = true; 1721 auto RootNode = prepareCompositeNode( 1722 ComplexDeinterleavingOperation::ReductionOperation, Real, 
Imag); 1723 RootNode->addOperand(Node); 1724 RootToNode[Real] = RootNode; 1725 RootToNode[Imag] = RootNode; 1726 submitCompositeNode(RootNode); 1727 break; 1728 } 1729 } 1730 1731 auto *Real = OperationInstruction[i]; 1732 // We want to check that we have 2 operands, but the function attributes 1733 // being counted as operands bloats this value. 1734 if (Processed[i] || Real->getNumOperands() < 2) 1735 continue; 1736 1737 // Can only combined integer reductions at the moment. 1738 if (!ReductionInfo[Real].second->getType()->isIntegerTy()) 1739 continue; 1740 1741 RealPHI = ReductionInfo[Real].first; 1742 ImagPHI = nullptr; 1743 PHIsFound = false; 1744 auto Node = identifyNode(Real->getOperand(0), Real->getOperand(1)); 1745 if (Node && PHIsFound) { 1746 LLVM_DEBUG( 1747 dbgs() << "Identified single reduction starting from instruction: " 1748 << *Real << "/" << *ReductionInfo[Real].second << "\n"); 1749 1750 // Reducing to a single vector is not supported, only permit reducing down 1751 // to scalar values. 1752 // Doing this here will leave the prior node in the graph, 1753 // however with no uses the node will be unreachable by the replacement 1754 // process. That along with the usage outside the graph should prevent the 1755 // replacement process from kicking off at all for this graph. 
1756 // TODO Add support for reducing to a single vector value 1757 if (ReductionInfo[Real].second->getType()->isVectorTy()) 1758 continue; 1759 1760 Processed[i] = true; 1761 auto RootNode = prepareCompositeNode( 1762 ComplexDeinterleavingOperation::ReductionSingle, Real, nullptr); 1763 RootNode->addOperand(Node); 1764 RootToNode[Real] = RootNode; 1765 submitCompositeNode(RootNode); 1766 } 1767 } 1768 1769 RealPHI = nullptr; 1770 ImagPHI = nullptr; 1771 } 1772 1773 bool ComplexDeinterleavingGraph::checkNodes() { 1774 1775 bool FoundDeinterleaveNode = false; 1776 for (NodePtr N : CompositeNodes) { 1777 if (!N->areOperandsValid()) 1778 return false; 1779 if (N->Operation == ComplexDeinterleavingOperation::Deinterleave) 1780 FoundDeinterleaveNode = true; 1781 } 1782 1783 // We need a deinterleave node in order to guarantee that we're working with 1784 // complex numbers. 1785 if (!FoundDeinterleaveNode) { 1786 LLVM_DEBUG( 1787 dbgs() << "Couldn't find a deinterleave node within the graph, cannot " 1788 "guarantee safety during graph transformation.\n"); 1789 return false; 1790 } 1791 1792 // Collect all instructions from roots to leaves 1793 SmallPtrSet<Instruction *, 16> AllInstructions; 1794 SmallVector<Instruction *, 8> Worklist; 1795 for (auto &Pair : RootToNode) 1796 Worklist.push_back(Pair.first); 1797 1798 // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG 1799 // chains 1800 while (!Worklist.empty()) { 1801 auto *I = Worklist.pop_back_val(); 1802 1803 if (!AllInstructions.insert(I).second) 1804 continue; 1805 1806 for (Value *Op : I->operands()) { 1807 if (auto *OpI = dyn_cast<Instruction>(Op)) { 1808 if (!FinalInstructions.count(I)) 1809 Worklist.emplace_back(OpI); 1810 } 1811 } 1812 } 1813 1814 // Find instructions that have users outside of chain 1815 for (auto *I : AllInstructions) { 1816 // Skip root nodes 1817 if (RootToNode.count(I)) 1818 continue; 1819 1820 for (User *U : I->users()) { 1821 if 
(AllInstructions.count(cast<Instruction>(U))) 1822 continue; 1823 1824 // Found an instruction that is not used by XCMLA/XCADD chain 1825 Worklist.emplace_back(I); 1826 break; 1827 } 1828 } 1829 1830 // If any instructions are found to be used outside, find and remove roots 1831 // that somehow connect to those instructions. 1832 SmallPtrSet<Instruction *, 16> Visited; 1833 while (!Worklist.empty()) { 1834 auto *I = Worklist.pop_back_val(); 1835 if (!Visited.insert(I).second) 1836 continue; 1837 1838 // Found an impacted root node. Removing it from the nodes to be 1839 // deinterleaved 1840 if (RootToNode.count(I)) { 1841 LLVM_DEBUG(dbgs() << "Instruction " << *I 1842 << " could be deinterleaved but its chain of complex " 1843 "operations have an outside user\n"); 1844 RootToNode.erase(I); 1845 } 1846 1847 if (!AllInstructions.count(I) || FinalInstructions.count(I)) 1848 continue; 1849 1850 for (User *U : I->users()) 1851 Worklist.emplace_back(cast<Instruction>(U)); 1852 1853 for (Value *Op : I->operands()) { 1854 if (auto *OpI = dyn_cast<Instruction>(Op)) 1855 Worklist.emplace_back(OpI); 1856 } 1857 } 1858 return !RootToNode.empty(); 1859 } 1860 1861 ComplexDeinterleavingGraph::NodePtr 1862 ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) { 1863 if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) { 1864 if (Intrinsic->getIntrinsicID() != Intrinsic::vector_interleave2) 1865 return nullptr; 1866 1867 auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(0)); 1868 auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(1)); 1869 if (!Real || !Imag) 1870 return nullptr; 1871 1872 return identifyNode(Real, Imag); 1873 } 1874 1875 auto *SVI = dyn_cast<ShuffleVectorInst>(RootI); 1876 if (!SVI) 1877 return nullptr; 1878 1879 // Look for a shufflevector that takes separate vectors of the real and 1880 // imaginary components and recombines them into a single vector. 
1881 if (!isInterleavingMask(SVI->getShuffleMask())) 1882 return nullptr; 1883 1884 Instruction *Real; 1885 Instruction *Imag; 1886 if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag)))) 1887 return nullptr; 1888 1889 return identifyNode(Real, Imag); 1890 } 1891 1892 ComplexDeinterleavingGraph::NodePtr 1893 ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real, 1894 Instruction *Imag) { 1895 Instruction *I = nullptr; 1896 Value *FinalValue = nullptr; 1897 if (match(Real, m_ExtractValue<0>(m_Instruction(I))) && 1898 match(Imag, m_ExtractValue<1>(m_Specific(I))) && 1899 match(I, m_Intrinsic<Intrinsic::vector_deinterleave2>( 1900 m_Value(FinalValue)))) { 1901 NodePtr PlaceholderNode = prepareCompositeNode( 1902 llvm::ComplexDeinterleavingOperation::Deinterleave, Real, Imag); 1903 PlaceholderNode->ReplacementNode = FinalValue; 1904 FinalInstructions.insert(Real); 1905 FinalInstructions.insert(Imag); 1906 return submitCompositeNode(PlaceholderNode); 1907 } 1908 1909 auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real); 1910 auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag); 1911 if (!RealShuffle || !ImagShuffle) { 1912 if (RealShuffle || ImagShuffle) 1913 LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n"); 1914 return nullptr; 1915 } 1916 1917 Value *RealOp1 = RealShuffle->getOperand(1); 1918 if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) { 1919 LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n"); 1920 return nullptr; 1921 } 1922 Value *ImagOp1 = ImagShuffle->getOperand(1); 1923 if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) { 1924 LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n"); 1925 return nullptr; 1926 } 1927 1928 Value *RealOp0 = RealShuffle->getOperand(0); 1929 Value *ImagOp0 = ImagShuffle->getOperand(0); 1930 1931 if (RealOp0 != ImagOp0) { 1932 LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n"); 1933 return nullptr; 1934 } 1935 1936 
ArrayRef<int> RealMask = RealShuffle->getShuffleMask(); 1937 ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask(); 1938 if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) { 1939 LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n"); 1940 return nullptr; 1941 } 1942 1943 if (RealMask[0] != 0 || ImagMask[0] != 1) { 1944 LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n"); 1945 return nullptr; 1946 } 1947 1948 // Type checking, the shuffle type should be a vector type of the same 1949 // scalar type, but half the size 1950 auto CheckType = [&](ShuffleVectorInst *Shuffle) { 1951 Value *Op = Shuffle->getOperand(0); 1952 auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType()); 1953 auto *OpTy = cast<FixedVectorType>(Op->getType()); 1954 1955 if (OpTy->getScalarType() != ShuffleTy->getScalarType()) 1956 return false; 1957 if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements()) 1958 return false; 1959 1960 return true; 1961 }; 1962 1963 auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool { 1964 if (!CheckType(Shuffle)) 1965 return false; 1966 1967 ArrayRef<int> Mask = Shuffle->getShuffleMask(); 1968 int Last = *Mask.rbegin(); 1969 1970 Value *Op = Shuffle->getOperand(0); 1971 auto *OpTy = cast<FixedVectorType>(Op->getType()); 1972 int NumElements = OpTy->getNumElements(); 1973 1974 // Ensure that the deinterleaving shuffle only pulls from the first 1975 // shuffle operand. 
1976 return Last < NumElements; 1977 }; 1978 1979 if (RealShuffle->getType() != ImagShuffle->getType()) { 1980 LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n"); 1981 return nullptr; 1982 } 1983 if (!CheckDeinterleavingShuffle(RealShuffle)) { 1984 LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n"); 1985 return nullptr; 1986 } 1987 if (!CheckDeinterleavingShuffle(ImagShuffle)) { 1988 LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n"); 1989 return nullptr; 1990 } 1991 1992 NodePtr PlaceholderNode = 1993 prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Deinterleave, 1994 RealShuffle, ImagShuffle); 1995 PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0); 1996 FinalInstructions.insert(RealShuffle); 1997 FinalInstructions.insert(ImagShuffle); 1998 return submitCompositeNode(PlaceholderNode); 1999 } 2000 2001 ComplexDeinterleavingGraph::NodePtr 2002 ComplexDeinterleavingGraph::identifySplat(Value *R, Value *I) { 2003 auto IsSplat = [](Value *V) -> bool { 2004 // Fixed-width vector with constants 2005 if (isa<ConstantDataVector>(V)) 2006 return true; 2007 2008 if (isa<ConstantInt>(V) || isa<ConstantFP>(V)) 2009 return isa<VectorType>(V->getType()); 2010 2011 VectorType *VTy; 2012 ArrayRef<int> Mask; 2013 // Splats are represented differently depending on whether the repeated 2014 // value is a constant or an Instruction 2015 if (auto *Const = dyn_cast<ConstantExpr>(V)) { 2016 if (Const->getOpcode() != Instruction::ShuffleVector) 2017 return false; 2018 VTy = cast<VectorType>(Const->getType()); 2019 Mask = Const->getShuffleMask(); 2020 } else if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) { 2021 VTy = Shuf->getType(); 2022 Mask = Shuf->getShuffleMask(); 2023 } else { 2024 return false; 2025 } 2026 2027 // When the data type is <1 x Type>, it's not possible to differentiate 2028 // between the ComplexDeinterleaving::Deinterleave and 2029 // ComplexDeinterleaving::Splat operations. 
2030 if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1) 2031 return false; 2032 2033 return all_equal(Mask) && Mask[0] == 0; 2034 }; 2035 2036 if (!IsSplat(R) || !IsSplat(I)) 2037 return nullptr; 2038 2039 auto *Real = dyn_cast<Instruction>(R); 2040 auto *Imag = dyn_cast<Instruction>(I); 2041 if ((!Real && Imag) || (Real && !Imag)) 2042 return nullptr; 2043 2044 if (Real && Imag) { 2045 // Non-constant splats should be in the same basic block 2046 if (Real->getParent() != Imag->getParent()) 2047 return nullptr; 2048 2049 FinalInstructions.insert(Real); 2050 FinalInstructions.insert(Imag); 2051 } 2052 NodePtr PlaceholderNode = 2053 prepareCompositeNode(ComplexDeinterleavingOperation::Splat, R, I); 2054 return submitCompositeNode(PlaceholderNode); 2055 } 2056 2057 ComplexDeinterleavingGraph::NodePtr 2058 ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real, 2059 Instruction *Imag) { 2060 if (Real != RealPHI || (ImagPHI && Imag != ImagPHI)) 2061 return nullptr; 2062 2063 PHIsFound = true; 2064 NodePtr PlaceholderNode = prepareCompositeNode( 2065 ComplexDeinterleavingOperation::ReductionPHI, Real, Imag); 2066 return submitCompositeNode(PlaceholderNode); 2067 } 2068 2069 ComplexDeinterleavingGraph::NodePtr 2070 ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real, 2071 Instruction *Imag) { 2072 auto *SelectReal = dyn_cast<SelectInst>(Real); 2073 auto *SelectImag = dyn_cast<SelectInst>(Imag); 2074 if (!SelectReal || !SelectImag) 2075 return nullptr; 2076 2077 Instruction *MaskA, *MaskB; 2078 Instruction *AR, *AI, *RA, *BI; 2079 if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR), 2080 m_Instruction(RA))) || 2081 !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI), 2082 m_Instruction(BI)))) 2083 return nullptr; 2084 2085 if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB)) 2086 return nullptr; 2087 2088 if (!MaskA->getType()->isVectorTy()) 2089 return nullptr; 2090 2091 auto NodeA = identifyNode(AR, AI); 
2092 if (!NodeA) 2093 return nullptr; 2094 2095 auto NodeB = identifyNode(RA, BI); 2096 if (!NodeB) 2097 return nullptr; 2098 2099 NodePtr PlaceholderNode = prepareCompositeNode( 2100 ComplexDeinterleavingOperation::ReductionSelect, Real, Imag); 2101 PlaceholderNode->addOperand(NodeA); 2102 PlaceholderNode->addOperand(NodeB); 2103 FinalInstructions.insert(MaskA); 2104 FinalInstructions.insert(MaskB); 2105 return submitCompositeNode(PlaceholderNode); 2106 } 2107 2108 static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode, 2109 std::optional<FastMathFlags> Flags, 2110 Value *InputA, Value *InputB) { 2111 Value *I; 2112 switch (Opcode) { 2113 case Instruction::FNeg: 2114 I = B.CreateFNeg(InputA); 2115 break; 2116 case Instruction::FAdd: 2117 I = B.CreateFAdd(InputA, InputB); 2118 break; 2119 case Instruction::Add: 2120 I = B.CreateAdd(InputA, InputB); 2121 break; 2122 case Instruction::FSub: 2123 I = B.CreateFSub(InputA, InputB); 2124 break; 2125 case Instruction::Sub: 2126 I = B.CreateSub(InputA, InputB); 2127 break; 2128 case Instruction::FMul: 2129 I = B.CreateFMul(InputA, InputB); 2130 break; 2131 case Instruction::Mul: 2132 I = B.CreateMul(InputA, InputB); 2133 break; 2134 default: 2135 llvm_unreachable("Incorrect symmetric opcode"); 2136 } 2137 if (Flags) 2138 cast<Instruction>(I)->setFastMathFlags(*Flags); 2139 return I; 2140 } 2141 2142 Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder, 2143 RawNodePtr Node) { 2144 if (Node->ReplacementNode) 2145 return Node->ReplacementNode; 2146 2147 auto ReplaceOperandIfExist = [&](RawNodePtr &Node, unsigned Idx) -> Value * { 2148 return Node->Operands.size() > Idx 2149 ? 
replaceNode(Builder, Node->Operands[Idx]) 2150 : nullptr; 2151 }; 2152 2153 Value *ReplacementNode; 2154 switch (Node->Operation) { 2155 case ComplexDeinterleavingOperation::CDot: { 2156 Value *Input0 = ReplaceOperandIfExist(Node, 0); 2157 Value *Input1 = ReplaceOperandIfExist(Node, 1); 2158 Value *Accumulator = ReplaceOperandIfExist(Node, 2); 2159 assert(!Input1 || (Input0->getType() == Input1->getType() && 2160 "Node inputs need to be of the same type")); 2161 ReplacementNode = TL->createComplexDeinterleavingIR( 2162 Builder, Node->Operation, Node->Rotation, Input0, Input1, Accumulator); 2163 break; 2164 } 2165 case ComplexDeinterleavingOperation::CAdd: 2166 case ComplexDeinterleavingOperation::CMulPartial: 2167 case ComplexDeinterleavingOperation::Symmetric: { 2168 Value *Input0 = ReplaceOperandIfExist(Node, 0); 2169 Value *Input1 = ReplaceOperandIfExist(Node, 1); 2170 Value *Accumulator = ReplaceOperandIfExist(Node, 2); 2171 assert(!Input1 || (Input0->getType() == Input1->getType() && 2172 "Node inputs need to be of the same type")); 2173 assert(!Accumulator || 2174 (Input0->getType() == Accumulator->getType() && 2175 "Accumulator and input need to be of the same type")); 2176 if (Node->Operation == ComplexDeinterleavingOperation::Symmetric) 2177 ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags, 2178 Input0, Input1); 2179 else 2180 ReplacementNode = TL->createComplexDeinterleavingIR( 2181 Builder, Node->Operation, Node->Rotation, Input0, Input1, 2182 Accumulator); 2183 break; 2184 } 2185 case ComplexDeinterleavingOperation::Deinterleave: 2186 llvm_unreachable("Deinterleave node should already have ReplacementNode"); 2187 break; 2188 case ComplexDeinterleavingOperation::Splat: { 2189 auto *NewTy = VectorType::getDoubleElementsVectorType( 2190 cast<VectorType>(Node->Real->getType())); 2191 auto *R = dyn_cast<Instruction>(Node->Real); 2192 auto *I = dyn_cast<Instruction>(Node->Imag); 2193 if (R && I) { 2194 // Splats that are not constant 
are interleaved where they are located 2195 Instruction *InsertPoint = (I->comesBefore(R) ? R : I)->getNextNode(); 2196 IRBuilder<> IRB(InsertPoint); 2197 ReplacementNode = IRB.CreateIntrinsic(Intrinsic::vector_interleave2, 2198 NewTy, {Node->Real, Node->Imag}); 2199 } else { 2200 ReplacementNode = Builder.CreateIntrinsic( 2201 Intrinsic::vector_interleave2, NewTy, {Node->Real, Node->Imag}); 2202 } 2203 break; 2204 } 2205 case ComplexDeinterleavingOperation::ReductionPHI: { 2206 // If Operation is ReductionPHI, a new empty PHINode is created. 2207 // It is filled later when the ReductionOperation is processed. 2208 auto *OldPHI = cast<PHINode>(Node->Real); 2209 auto *VTy = cast<VectorType>(Node->Real->getType()); 2210 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); 2211 auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHIIt()); 2212 OldToNewPHI[OldPHI] = NewPHI; 2213 ReplacementNode = NewPHI; 2214 break; 2215 } 2216 case ComplexDeinterleavingOperation::ReductionSingle: 2217 ReplacementNode = replaceNode(Builder, Node->Operands[0]); 2218 processReductionSingle(ReplacementNode, Node); 2219 break; 2220 case ComplexDeinterleavingOperation::ReductionOperation: 2221 ReplacementNode = replaceNode(Builder, Node->Operands[0]); 2222 processReductionOperation(ReplacementNode, Node); 2223 break; 2224 case ComplexDeinterleavingOperation::ReductionSelect: { 2225 auto *MaskReal = cast<Instruction>(Node->Real)->getOperand(0); 2226 auto *MaskImag = cast<Instruction>(Node->Imag)->getOperand(0); 2227 auto *A = replaceNode(Builder, Node->Operands[0]); 2228 auto *B = replaceNode(Builder, Node->Operands[1]); 2229 auto *NewMaskTy = VectorType::getDoubleElementsVectorType( 2230 cast<VectorType>(MaskReal->getType())); 2231 auto *NewMask = Builder.CreateIntrinsic(Intrinsic::vector_interleave2, 2232 NewMaskTy, {MaskReal, MaskImag}); 2233 ReplacementNode = Builder.CreateSelect(NewMask, A, B); 2234 break; 2235 } 2236 } 2237 2238 assert(ReplacementNode && "Target 
failed to create Intrinsic call."); 2239 NumComplexTransformations += 1; 2240 Node->ReplacementNode = ReplacementNode; 2241 return ReplacementNode; 2242 } 2243 2244 void ComplexDeinterleavingGraph::processReductionSingle( 2245 Value *OperationReplacement, RawNodePtr Node) { 2246 auto *Real = cast<Instruction>(Node->Real); 2247 auto *OldPHI = ReductionInfo[Real].first; 2248 auto *NewPHI = OldToNewPHI[OldPHI]; 2249 auto *VTy = cast<VectorType>(Real->getType()); 2250 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); 2251 2252 Value *Init = OldPHI->getIncomingValueForBlock(Incoming); 2253 2254 IRBuilder<> Builder(Incoming->getTerminator()); 2255 2256 Value *NewInit = nullptr; 2257 if (auto *C = dyn_cast<Constant>(Init)) { 2258 if (C->isZeroValue()) 2259 NewInit = Constant::getNullValue(NewVTy); 2260 } 2261 2262 if (!NewInit) 2263 NewInit = Builder.CreateIntrinsic(Intrinsic::vector_interleave2, NewVTy, 2264 {Init, Constant::getNullValue(VTy)}); 2265 2266 NewPHI->addIncoming(NewInit, Incoming); 2267 NewPHI->addIncoming(OperationReplacement, BackEdge); 2268 2269 auto *FinalReduction = ReductionInfo[Real].second; 2270 Builder.SetInsertPoint(&*FinalReduction->getParent()->getFirstInsertionPt()); 2271 2272 auto *AddReduce = Builder.CreateAddReduce(OperationReplacement); 2273 FinalReduction->replaceAllUsesWith(AddReduce); 2274 } 2275 2276 void ComplexDeinterleavingGraph::processReductionOperation( 2277 Value *OperationReplacement, RawNodePtr Node) { 2278 auto *Real = cast<Instruction>(Node->Real); 2279 auto *Imag = cast<Instruction>(Node->Imag); 2280 auto *OldPHIReal = ReductionInfo[Real].first; 2281 auto *OldPHIImag = ReductionInfo[Imag].first; 2282 auto *NewPHI = OldToNewPHI[OldPHIReal]; 2283 2284 auto *VTy = cast<VectorType>(Real->getType()); 2285 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); 2286 2287 // We have to interleave initial origin values coming from IncomingBlock 2288 Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming); 
2289 Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming); 2290 2291 IRBuilder<> Builder(Incoming->getTerminator()); 2292 auto *NewInit = Builder.CreateIntrinsic(Intrinsic::vector_interleave2, NewVTy, 2293 {InitReal, InitImag}); 2294 2295 NewPHI->addIncoming(NewInit, Incoming); 2296 NewPHI->addIncoming(OperationReplacement, BackEdge); 2297 2298 // Deinterleave complex vector outside of loop so that it can be finally 2299 // reduced 2300 auto *FinalReductionReal = ReductionInfo[Real].second; 2301 auto *FinalReductionImag = ReductionInfo[Imag].second; 2302 2303 Builder.SetInsertPoint( 2304 &*FinalReductionReal->getParent()->getFirstInsertionPt()); 2305 auto *Deinterleave = Builder.CreateIntrinsic(Intrinsic::vector_deinterleave2, 2306 OperationReplacement->getType(), 2307 OperationReplacement); 2308 2309 auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0); 2310 FinalReductionReal->replaceUsesOfWith(Real, NewReal); 2311 2312 Builder.SetInsertPoint(FinalReductionImag); 2313 auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1); 2314 FinalReductionImag->replaceUsesOfWith(Imag, NewImag); 2315 } 2316 2317 void ComplexDeinterleavingGraph::replaceNodes() { 2318 SmallVector<Instruction *, 16> DeadInstrRoots; 2319 for (auto *RootInstruction : OrderedRoots) { 2320 // Check if this potential root went through check process and we can 2321 // deinterleave it 2322 if (!RootToNode.count(RootInstruction)) 2323 continue; 2324 2325 IRBuilder<> Builder(RootInstruction); 2326 auto RootNode = RootToNode[RootInstruction]; 2327 Value *R = replaceNode(Builder, RootNode.get()); 2328 2329 if (RootNode->Operation == 2330 ComplexDeinterleavingOperation::ReductionOperation) { 2331 auto *RootReal = cast<Instruction>(RootNode->Real); 2332 auto *RootImag = cast<Instruction>(RootNode->Imag); 2333 ReductionInfo[RootReal].first->removeIncomingValue(BackEdge); 2334 ReductionInfo[RootImag].first->removeIncomingValue(BackEdge); 2335 DeadInstrRoots.push_back(RootReal); 
2336 DeadInstrRoots.push_back(RootImag); 2337 } else if (RootNode->Operation == 2338 ComplexDeinterleavingOperation::ReductionSingle) { 2339 auto *RootInst = cast<Instruction>(RootNode->Real); 2340 auto &Info = ReductionInfo[RootInst]; 2341 Info.first->removeIncomingValue(BackEdge); 2342 DeadInstrRoots.push_back(Info.second); 2343 } else { 2344 assert(R && "Unable to find replacement for RootInstruction"); 2345 DeadInstrRoots.push_back(RootInstruction); 2346 RootInstruction->replaceAllUsesWith(R); 2347 } 2348 } 2349 2350 for (auto *I : DeadInstrRoots) 2351 RecursivelyDeleteTriviallyDeadInstructions(I, TLI); 2352 } 2353